1#pragma once
2
3#include <DataTypes/DataTypeArray.h>
4#include <DataTypes/DataTypeFunction.h>
5#include <DataTypes/DataTypeLowCardinality.h>
6#include <Columns/ColumnArray.h>
7#include <Columns/ColumnConst.h>
8#include <Columns/ColumnFunction.h>
9#include <Common/typeid_cast.h>
10#include <Common/assert_cast.h>
11#include <Functions/IFunctionImpl.h>
12#include <Functions/FunctionHelpers.h>
13#include <IO/WriteHelpers.h>
14
15
16namespace DB
17{
18
/// Error codes thrown by FunctionArrayMapped below.
/// They are only declared here (`extern`); the definitions are not part of this header.
19namespace ErrorCodes
20{
/// Thrown by executeImpl when an argument column is neither an array nor a constant array.
21 extern const int ILLEGAL_COLUMN;
/// Thrown when an argument has the wrong type (not an array, not a function, or a non-UInt8 result where a boolean is required).
22 extern const int ILLEGAL_TYPE_OF_ARGUMENT;
/// Thrown by executeImpl when the arrays passed alongside the lambda have different offsets.
23 extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
/// Thrown when too few (or, for single-array functions, too many) arguments are passed.
24 extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
25}
26
27
28/** Higher-order functions for arrays.
29 * These functions optionally apply a map (transform) to array (or multiple arrays of identical size) by lambda function,
30 * and return some result based on that transformation.
31 *
32 * Examples:
33 * arrayMap(x1,...,xn -> expression, array1,...,arrayn) - apply the expression to each element of the array (or set of parallel arrays).
34 * arrayFilter(x -> predicate, array) - leave in the array only the elements for which the expression is true.
35 *
36 * For some functions arrayCount, arrayExists, arrayAll, an overload of the form f(array) is available, which works in the same way as f(x -> x, array).
37 *
38 * See the example of Impl template parameter in arrayMap.cpp
39 */
40template <typename Impl, typename Name>
41class FunctionArrayMapped : public IFunction
42{
43public:
44 static constexpr auto name = Name::name;
45 static FunctionPtr create(const Context &) { return std::make_shared<FunctionArrayMapped>(); }
46
47 String getName() const override
48 {
49 return name;
50 }
51
52 bool isVariadic() const override { return true; }
53 size_t getNumberOfArguments() const override { return 0; }
54
55 /// Called if at least one function argument is a lambda expression.
56 /// For argument-lambda expressions, it defines the types of arguments of these expressions.
57 void getLambdaArgumentTypes(DataTypes & arguments) const override
58 {
59 if (arguments.size() < 1)
60 throw Exception("Function " + getName() + " needs at least one argument; passed "
61 + toString(arguments.size()) + ".",
62 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
63
64 if (arguments.size() == 1)
65 throw Exception("Function " + getName() + " needs at least one array argument.",
66 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
67
68 DataTypes nested_types(arguments.size() - 1);
69 for (size_t i = 0; i < nested_types.size(); ++i)
70 {
71 const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]);
72 if (!array_type)
73 throw Exception("Argument " + toString(i + 2) + " of function " + getName() + " must be array. Found "
74 + arguments[i + 1]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
75 nested_types[i] = removeLowCardinality(array_type->getNestedType());
76 }
77
78 const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get());
79 if (!function_type || function_type->getArgumentTypes().size() != nested_types.size())
80 throw Exception("First argument for this overload of " + getName() + " must be a function with "
81 + toString(nested_types.size()) + " arguments. Found "
82 + arguments[0]->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
83
84 arguments[0] = std::make_shared<DataTypeFunction>(nested_types);
85 }
86
87 DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
88 {
89 size_t min_args = Impl::needExpression() ? 2 : 1;
90 if (arguments.size() < min_args)
91 throw Exception("Function " + getName() + " needs at least "
92 + toString(min_args) + " argument; passed "
93 + toString(arguments.size()) + ".",
94 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
95
96 if (arguments.size() == 1)
97 {
98 const auto array_type = checkAndGetDataType<DataTypeArray>(arguments[0].type.get());
99
100 if (!array_type)
101 throw Exception("The only argument for function " + getName() + " must be array. Found "
102 + arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
103
104 DataTypePtr nested_type = array_type->getNestedType();
105
106 if (Impl::needBoolean() && !WhichDataType(nested_type).isUInt8())
107 throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found "
108 + arguments[0].type->getName() + " instead.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
109
110 return Impl::getReturnType(nested_type, nested_type);
111 }
112 else
113 {
114 if (arguments.size() > 2 && Impl::needOneArray())
115 throw Exception("Function " + getName() + " needs one array argument.",
116 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
117
118 const auto data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get());
119
120 if (!data_type_function)
121 throw Exception("First argument for function " + getName() + " must be a function.",
122 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
123
124 /// The types of the remaining arguments are already checked in getLambdaArgumentTypes.
125
126 DataTypePtr return_type = removeLowCardinality(data_type_function->getReturnType());
127 if (Impl::needBoolean() && !WhichDataType(return_type).isUInt8())
128 throw Exception("Expression for function " + getName() + " must return UInt8, found "
129 + return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
130
131 const auto first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get());
132
133 return Impl::getReturnType(return_type, first_array_type->getNestedType());
134 }
135 }
136
137 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
138 {
139 if (arguments.size() == 1)
140 {
141 ColumnPtr column_array_ptr = block.getByPosition(arguments[0]).column;
142 const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
143
144 if (!column_array)
145 {
146 const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
147 if (!column_const_array)
148 throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
149 column_array_ptr = column_const_array->convertToFullColumn();
150 column_array = assert_cast<const ColumnArray *>(column_array_ptr.get());
151 }
152
153 block.getByPosition(result).column = Impl::execute(*column_array, column_array->getDataPtr());
154 }
155 else
156 {
157 const auto & column_with_type_and_name = block.getByPosition(arguments[0]);
158
159 if (!column_with_type_and_name.column)
160 throw Exception("First argument for function " + getName() + " must be a function.",
161 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
162
163 const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get());
164
165 if (!column_function)
166 throw Exception("First argument for function " + getName() + " must be a function.",
167 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
168
169 ColumnPtr offsets_column;
170
171 ColumnPtr column_first_array_ptr;
172 const ColumnArray * column_first_array = nullptr;
173
174 ColumnsWithTypeAndName arrays;
175 arrays.reserve(arguments.size() - 1);
176
177 for (size_t i = 1; i < arguments.size(); ++i)
178 {
179 const auto & array_with_type_and_name = block.getByPosition(arguments[i]);
180
181 ColumnPtr column_array_ptr = array_with_type_and_name.column;
182 const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
183
184 const DataTypePtr & array_type_ptr = array_with_type_and_name.type;
185 const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get());
186
187 if (!column_array)
188 {
189 const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get());
190 if (!column_const_array)
191 throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN);
192 column_array_ptr = column_const_array->convertToFullColumn();
193 if (column_array_ptr->lowCardinality())
194 column_array_ptr = column_array_ptr->convertToFullColumnIfLowCardinality();
195 column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get());
196 }
197
198 if (!array_type)
199 throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
200
201 if (!offsets_column)
202 {
203 offsets_column = column_array->getOffsetsPtr();
204 }
205 else
206 {
207 /// The first condition is optimization: do not compare data if the pointers are equal.
208 if (column_array->getOffsetsPtr() != offsets_column
209 && column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData())
210 throw Exception("Arrays passed to " + getName() + " must have equal size", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
211 }
212
213 if (i == 1)
214 {
215 column_first_array_ptr = column_array_ptr;
216 column_first_array = column_array;
217 }
218
219 arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(),
220 removeLowCardinality(array_type->getNestedType()),
221 array_with_type_and_name.name));
222 }
223
224 /// Put all the necessary columns multiplied by the sizes of arrays into the block.
225 auto replicated_column_function_ptr = (*column_function->replicate(column_first_array->getOffsets())).mutate();
226 auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get());
227 replicated_column_function->appendArguments(arrays);
228
229 auto lambda_result = replicated_column_function->reduce().column;
230 if (lambda_result->lowCardinality())
231 lambda_result = lambda_result->convertToFullColumnIfLowCardinality();
232
233 block.getByPosition(result).column = Impl::execute(*column_first_array, lambda_result);
234 }
235 }
236};
237
238}
239