1 | #pragma once |
2 | |
3 | #include <DataTypes/DataTypeArray.h> |
4 | #include <DataTypes/DataTypeFunction.h> |
5 | #include <DataTypes/DataTypeLowCardinality.h> |
6 | #include <Columns/ColumnArray.h> |
7 | #include <Columns/ColumnConst.h> |
8 | #include <Columns/ColumnFunction.h> |
9 | #include <Common/typeid_cast.h> |
10 | #include <Common/assert_cast.h> |
11 | #include <Functions/IFunctionImpl.h> |
12 | #include <Functions/FunctionHelpers.h> |
13 | #include <IO/WriteHelpers.h> |
14 | |
15 | |
16 | namespace DB |
17 | { |
18 | |
namespace ErrorCodes
{
    /// Error codes thrown by FunctionArrayMapped below; these are only declarations,
    /// the definitions live elsewhere in the project.
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
}
26 | |
27 | |
28 | /** Higher-order functions for arrays. |
29 | * These functions optionally apply a map (transform) to array (or multiple arrays of identical size) by lambda function, |
30 | * and return some result based on that transformation. |
31 | * |
32 | * Examples: |
33 | * arrayMap(x1,...,xn -> expression, array1,...,arrayn) - apply the expression to each element of the array (or set of parallel arrays). |
34 | * arrayFilter(x -> predicate, array) - leave in the array only the elements for which the expression is true. |
35 | * |
36 | * For some functions arrayCount, arrayExists, arrayAll, an overload of the form f(array) is available, which works in the same way as f(x -> x, array). |
37 | * |
38 | * See the example of Impl template parameter in arrayMap.cpp |
39 | */ |
40 | template <typename Impl, typename Name> |
41 | class FunctionArrayMapped : public IFunction |
42 | { |
43 | public: |
44 | static constexpr auto name = Name::name; |
45 | static FunctionPtr create(const Context &) { return std::make_shared<FunctionArrayMapped>(); } |
46 | |
47 | String getName() const override |
48 | { |
49 | return name; |
50 | } |
51 | |
52 | bool isVariadic() const override { return true; } |
53 | size_t getNumberOfArguments() const override { return 0; } |
54 | |
55 | /// Called if at least one function argument is a lambda expression. |
56 | /// For argument-lambda expressions, it defines the types of arguments of these expressions. |
57 | void getLambdaArgumentTypes(DataTypes & arguments) const override |
58 | { |
59 | if (arguments.size() < 1) |
60 | throw Exception("Function " + getName() + " needs at least one argument; passed " |
61 | + toString(arguments.size()) + "." , |
62 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
63 | |
64 | if (arguments.size() == 1) |
65 | throw Exception("Function " + getName() + " needs at least one array argument." , |
66 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
67 | |
68 | DataTypes nested_types(arguments.size() - 1); |
69 | for (size_t i = 0; i < nested_types.size(); ++i) |
70 | { |
71 | const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(&*arguments[i + 1]); |
72 | if (!array_type) |
73 | throw Exception("Argument " + toString(i + 2) + " of function " + getName() + " must be array. Found " |
74 | + arguments[i + 1]->getName() + " instead." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
75 | nested_types[i] = removeLowCardinality(array_type->getNestedType()); |
76 | } |
77 | |
78 | const DataTypeFunction * function_type = checkAndGetDataType<DataTypeFunction>(arguments[0].get()); |
79 | if (!function_type || function_type->getArgumentTypes().size() != nested_types.size()) |
80 | throw Exception("First argument for this overload of " + getName() + " must be a function with " |
81 | + toString(nested_types.size()) + " arguments. Found " |
82 | + arguments[0]->getName() + " instead." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
83 | |
84 | arguments[0] = std::make_shared<DataTypeFunction>(nested_types); |
85 | } |
86 | |
87 | DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override |
88 | { |
89 | size_t min_args = Impl::needExpression() ? 2 : 1; |
90 | if (arguments.size() < min_args) |
91 | throw Exception("Function " + getName() + " needs at least " |
92 | + toString(min_args) + " argument; passed " |
93 | + toString(arguments.size()) + "." , |
94 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
95 | |
96 | if (arguments.size() == 1) |
97 | { |
98 | const auto array_type = checkAndGetDataType<DataTypeArray>(arguments[0].type.get()); |
99 | |
100 | if (!array_type) |
101 | throw Exception("The only argument for function " + getName() + " must be array. Found " |
102 | + arguments[0].type->getName() + " instead." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
103 | |
104 | DataTypePtr nested_type = array_type->getNestedType(); |
105 | |
106 | if (Impl::needBoolean() && !WhichDataType(nested_type).isUInt8()) |
107 | throw Exception("The only argument for function " + getName() + " must be array of UInt8. Found " |
108 | + arguments[0].type->getName() + " instead." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
109 | |
110 | return Impl::getReturnType(nested_type, nested_type); |
111 | } |
112 | else |
113 | { |
114 | if (arguments.size() > 2 && Impl::needOneArray()) |
115 | throw Exception("Function " + getName() + " needs one array argument." , |
116 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
117 | |
118 | const auto data_type_function = checkAndGetDataType<DataTypeFunction>(arguments[0].type.get()); |
119 | |
120 | if (!data_type_function) |
121 | throw Exception("First argument for function " + getName() + " must be a function." , |
122 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
123 | |
124 | /// The types of the remaining arguments are already checked in getLambdaArgumentTypes. |
125 | |
126 | DataTypePtr return_type = removeLowCardinality(data_type_function->getReturnType()); |
127 | if (Impl::needBoolean() && !WhichDataType(return_type).isUInt8()) |
128 | throw Exception("Expression for function " + getName() + " must return UInt8, found " |
129 | + return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
130 | |
131 | const auto first_array_type = checkAndGetDataType<DataTypeArray>(arguments[1].type.get()); |
132 | |
133 | return Impl::getReturnType(return_type, first_array_type->getNestedType()); |
134 | } |
135 | } |
136 | |
137 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override |
138 | { |
139 | if (arguments.size() == 1) |
140 | { |
141 | ColumnPtr column_array_ptr = block.getByPosition(arguments[0]).column; |
142 | const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); |
143 | |
144 | if (!column_array) |
145 | { |
146 | const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get()); |
147 | if (!column_const_array) |
148 | throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN); |
149 | column_array_ptr = column_const_array->convertToFullColumn(); |
150 | column_array = assert_cast<const ColumnArray *>(column_array_ptr.get()); |
151 | } |
152 | |
153 | block.getByPosition(result).column = Impl::execute(*column_array, column_array->getDataPtr()); |
154 | } |
155 | else |
156 | { |
157 | const auto & column_with_type_and_name = block.getByPosition(arguments[0]); |
158 | |
159 | if (!column_with_type_and_name.column) |
160 | throw Exception("First argument for function " + getName() + " must be a function." , |
161 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
162 | |
163 | const auto * column_function = typeid_cast<const ColumnFunction *>(column_with_type_and_name.column.get()); |
164 | |
165 | if (!column_function) |
166 | throw Exception("First argument for function " + getName() + " must be a function." , |
167 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
168 | |
169 | ColumnPtr offsets_column; |
170 | |
171 | ColumnPtr column_first_array_ptr; |
172 | const ColumnArray * column_first_array = nullptr; |
173 | |
174 | ColumnsWithTypeAndName arrays; |
175 | arrays.reserve(arguments.size() - 1); |
176 | |
177 | for (size_t i = 1; i < arguments.size(); ++i) |
178 | { |
179 | const auto & array_with_type_and_name = block.getByPosition(arguments[i]); |
180 | |
181 | ColumnPtr column_array_ptr = array_with_type_and_name.column; |
182 | const auto * column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); |
183 | |
184 | const DataTypePtr & array_type_ptr = array_with_type_and_name.type; |
185 | const auto * array_type = checkAndGetDataType<DataTypeArray>(array_type_ptr.get()); |
186 | |
187 | if (!column_array) |
188 | { |
189 | const ColumnConst * column_const_array = checkAndGetColumnConst<ColumnArray>(column_array_ptr.get()); |
190 | if (!column_const_array) |
191 | throw Exception("Expected array column, found " + column_array_ptr->getName(), ErrorCodes::ILLEGAL_COLUMN); |
192 | column_array_ptr = column_const_array->convertToFullColumn(); |
193 | if (column_array_ptr->lowCardinality()) |
194 | column_array_ptr = column_array_ptr->convertToFullColumnIfLowCardinality(); |
195 | column_array = checkAndGetColumn<ColumnArray>(column_array_ptr.get()); |
196 | } |
197 | |
198 | if (!array_type) |
199 | throw Exception("Expected array type, found " + array_type_ptr->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
200 | |
201 | if (!offsets_column) |
202 | { |
203 | offsets_column = column_array->getOffsetsPtr(); |
204 | } |
205 | else |
206 | { |
207 | /// The first condition is optimization: do not compare data if the pointers are equal. |
208 | if (column_array->getOffsetsPtr() != offsets_column |
209 | && column_array->getOffsets() != typeid_cast<const ColumnArray::ColumnOffsets &>(*offsets_column).getData()) |
210 | throw Exception("Arrays passed to " + getName() + " must have equal size" , ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); |
211 | } |
212 | |
213 | if (i == 1) |
214 | { |
215 | column_first_array_ptr = column_array_ptr; |
216 | column_first_array = column_array; |
217 | } |
218 | |
219 | arrays.emplace_back(ColumnWithTypeAndName(column_array->getDataPtr(), |
220 | removeLowCardinality(array_type->getNestedType()), |
221 | array_with_type_and_name.name)); |
222 | } |
223 | |
224 | /// Put all the necessary columns multiplied by the sizes of arrays into the block. |
225 | auto replicated_column_function_ptr = (*column_function->replicate(column_first_array->getOffsets())).mutate(); |
226 | auto * replicated_column_function = typeid_cast<ColumnFunction *>(replicated_column_function_ptr.get()); |
227 | replicated_column_function->appendArguments(arrays); |
228 | |
229 | auto lambda_result = replicated_column_function->reduce().column; |
230 | if (lambda_result->lowCardinality()) |
231 | lambda_result = lambda_result->convertToFullColumnIfLowCardinality(); |
232 | |
233 | block.getByPosition(result).column = Impl::execute(*column_first_array, lambda_result); |
234 | } |
235 | } |
236 | }; |
237 | |
238 | } |
239 | |