1 | #include <Functions/FunctionHelpers.h> |
2 | #include <Functions/IFunctionImpl.h> |
3 | #include <Columns/ColumnTuple.h> |
4 | #include <Columns/ColumnString.h> |
5 | #include <Columns/ColumnFixedString.h> |
6 | #include <Columns/ColumnNullable.h> |
7 | #include <Columns/ColumnLowCardinality.h> |
8 | #include <Common/assert_cast.h> |
9 | #include <DataTypes/DataTypeNullable.h> |
10 | #include <DataTypes/DataTypeLowCardinality.h> |
11 | #include <IO/WriteHelpers.h> |
12 | |
13 | |
14 | namespace DB |
15 | { |
16 | |
17 | namespace ErrorCodes |
18 | { |
19 | extern const int ILLEGAL_COLUMN; |
20 | extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; |
21 | extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; |
22 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
23 | } |
24 | |
25 | const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * column) |
26 | { |
27 | if (!isColumnConst(*column)) |
28 | return {}; |
29 | |
30 | const ColumnConst * res = assert_cast<const ColumnConst *>(column); |
31 | |
32 | if (checkColumn<ColumnString>(&res->getDataColumn()) |
33 | || checkColumn<ColumnFixedString>(&res->getDataColumn())) |
34 | return res; |
35 | |
36 | return {}; |
37 | } |
38 | |
39 | |
40 | Columns convertConstTupleToConstantElements(const ColumnConst & column) |
41 | { |
42 | const ColumnTuple & src_tuple = assert_cast<const ColumnTuple &>(column.getDataColumn()); |
43 | const auto & src_tuple_columns = src_tuple.getColumns(); |
44 | size_t tuple_size = src_tuple_columns.size(); |
45 | size_t rows = column.size(); |
46 | |
47 | Columns res(tuple_size); |
48 | for (size_t i = 0; i < tuple_size; ++i) |
49 | res[i] = ColumnConst::create(src_tuple_columns[i], rows); |
50 | |
51 | return res; |
52 | } |
53 | |
54 | |
55 | static Block createBlockWithNestedColumnsImpl(const Block & block, const std::unordered_set<size_t> & args) |
56 | { |
57 | Block res; |
58 | size_t columns = block.columns(); |
59 | |
60 | for (size_t i = 0; i < columns; ++i) |
61 | { |
62 | const auto & col = block.getByPosition(i); |
63 | |
64 | if (args.count(i) && col.type->isNullable()) |
65 | { |
66 | const DataTypePtr & nested_type = static_cast<const DataTypeNullable &>(*col.type).getNestedType(); |
67 | |
68 | if (!col.column) |
69 | { |
70 | res.insert({nullptr, nested_type, col.name}); |
71 | } |
72 | else if (auto * nullable = checkAndGetColumn<ColumnNullable>(*col.column)) |
73 | { |
74 | const auto & nested_col = nullable->getNestedColumnPtr(); |
75 | res.insert({nested_col, nested_type, col.name}); |
76 | } |
77 | else if (auto * const_column = checkAndGetColumn<ColumnConst>(*col.column)) |
78 | { |
79 | const auto & nested_col = checkAndGetColumn<ColumnNullable>(const_column->getDataColumn())->getNestedColumnPtr(); |
80 | res.insert({ ColumnConst::create(nested_col, col.column->size()), nested_type, col.name}); |
81 | } |
82 | else |
83 | throw Exception("Illegal column for DataTypeNullable" , ErrorCodes::ILLEGAL_COLUMN); |
84 | } |
85 | else |
86 | res.insert(col); |
87 | } |
88 | |
89 | return res; |
90 | } |
91 | |
92 | |
93 | Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & args) |
94 | { |
95 | std::unordered_set<size_t> args_set(args.begin(), args.end()); |
96 | return createBlockWithNestedColumnsImpl(block, args_set); |
97 | } |
98 | |
99 | Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & args, size_t result) |
100 | { |
101 | std::unordered_set<size_t> args_set(args.begin(), args.end()); |
102 | args_set.insert(result); |
103 | return createBlockWithNestedColumnsImpl(block, args_set); |
104 | } |
105 | |
106 | void validateArgumentType(const IFunction & func, const DataTypes & arguments, |
107 | size_t argument_index, bool (* validator_func)(const IDataType &), |
108 | const char * expected_type_description) |
109 | { |
110 | if (arguments.size() <= argument_index) |
111 | throw Exception("Incorrect number of arguments of function " + func.getName(), |
112 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
113 | |
114 | const auto & argument = arguments[argument_index]; |
115 | if (validator_func(*argument) == false) |
116 | throw Exception("Illegal type " + argument->getName() + |
117 | " of " + std::to_string(argument_index) + |
118 | " argument of function " + func.getName() + |
119 | " expected " + expected_type_description, |
120 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
121 | } |
122 | |
123 | namespace |
124 | { |
125 | void validateArgumentsImpl(const IFunction & func, |
126 | const ColumnsWithTypeAndName & arguments, |
127 | size_t argument_offset, |
128 | const FunctionArgumentDescriptors & descriptors) |
129 | { |
130 | for (size_t i = 0; i < descriptors.size(); ++i) |
131 | { |
132 | const auto argument_index = i + argument_offset; |
133 | if (argument_index >= arguments.size()) |
134 | { |
135 | break; |
136 | } |
137 | |
138 | const auto & arg = arguments[i + argument_offset]; |
139 | const auto descriptor = descriptors[i]; |
140 | if (int errorCode = descriptor.isValid(arg.type, arg.column); errorCode != 0) |
141 | throw Exception("Illegal type of argument #" + std::to_string(i) |
142 | + (descriptor.argument_name ? " '" + std::string(descriptor.argument_name) + "'" : String{}) |
143 | + " of function " + func.getName() |
144 | + (descriptor.expected_type_description ? String(", expected " ) + descriptor.expected_type_description : String{}) |
145 | + (arg.type ? ", got " + arg.type->getName() : String{}), |
146 | errorCode); |
147 | } |
148 | } |
149 | |
150 | } |
151 | |
152 | int FunctionArgumentDescriptor::isValid(const DataTypePtr & data_type, const ColumnPtr & column) const |
153 | { |
154 | if (type_validator_func && (data_type == nullptr || type_validator_func(*data_type) == false)) |
155 | return ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT; |
156 | |
157 | if (column_validator_func && (column == nullptr || column_validator_func(*column) == false)) |
158 | return ErrorCodes::ILLEGAL_COLUMN; |
159 | |
160 | return 0; |
161 | } |
162 | |
163 | void validateFunctionArgumentTypes(const IFunction & func, |
164 | const ColumnsWithTypeAndName & arguments, |
165 | const FunctionArgumentDescriptors & mandatory_args, |
166 | const FunctionArgumentDescriptors & optional_args) |
167 | { |
168 | if (arguments.size() < mandatory_args.size() || arguments.size() > mandatory_args.size() + optional_args.size()) |
169 | { |
170 | auto joinArgumentTypes = [](const auto & args, const String sep = ", " ) -> String |
171 | { |
172 | String result; |
173 | for (const auto & a : args) |
174 | { |
175 | using A = std::decay_t<decltype(a)>; |
176 | if constexpr (std::is_same_v<A, FunctionArgumentDescriptor>) |
177 | { |
178 | if (a.argument_name) |
179 | result += "'" + std::string(a.argument_name) + "' : " ; |
180 | if (a.expected_type_description) |
181 | result += a.expected_type_description; |
182 | } |
183 | else if constexpr (std::is_same_v<A, ColumnWithTypeAndName>) |
184 | result += a.type->getName(); |
185 | |
186 | result += sep; |
187 | } |
188 | |
189 | if (args.size() != 0) |
190 | result.erase(result.end() - sep.length(), result.end()); |
191 | |
192 | return result; |
193 | }; |
194 | |
195 | throw Exception("Incorrect number of arguments for function " + func.getName() |
196 | + " provided " + std::to_string(arguments.size()) |
197 | + (arguments.size() ? " (" + joinArgumentTypes(arguments) + ")" : String{}) |
198 | + ", expected " + std::to_string(mandatory_args.size()) |
199 | + (optional_args.size() ? " to " + std::to_string(mandatory_args.size() + optional_args.size()) : "" ) |
200 | + " (" + joinArgumentTypes(mandatory_args) |
201 | + (optional_args.size() ? ", [" + joinArgumentTypes(optional_args) + "]" : "" ) |
202 | + ")" , |
203 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); |
204 | } |
205 | |
206 | validateArgumentsImpl(func, arguments, 0, mandatory_args); |
207 | if (optional_args.size()) |
208 | { |
209 | validateArgumentsImpl(func, arguments, mandatory_args.size(), optional_args); |
210 | } |
211 | } |
212 | |
213 | std::pair<std::vector<const IColumn *>, const ColumnArray::Offset *> |
214 | checkAndGetNestedArrayOffset(const IColumn ** columns, size_t num_arguments) |
215 | { |
216 | assert(num_arguments > 0); |
217 | std::vector<const IColumn *> nested_columns(num_arguments); |
218 | const ColumnArray::Offsets * offsets = nullptr; |
219 | for (size_t i = 0; i < num_arguments; ++i) |
220 | { |
221 | const ColumnArray::Offsets * offsets_i = nullptr; |
222 | if (const ColumnArray * arr = checkAndGetColumn<const ColumnArray>(columns[i])) |
223 | { |
224 | nested_columns[i] = &arr->getData(); |
225 | offsets_i = &arr->getOffsets(); |
226 | } |
227 | else |
228 | throw Exception("Illegal column " + columns[i]->getName() + " as argument of function" , ErrorCodes::ILLEGAL_COLUMN); |
229 | if (i == 0) |
230 | offsets = offsets_i; |
231 | else if (*offsets_i != *offsets) |
232 | throw Exception("Lengths of all arrays passed to aggregate function must be equal." , ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); |
233 | } |
234 | return {nested_columns, offsets->data()}; |
235 | } |
236 | |
237 | } |
238 | |