1#include <Functions/FunctionHelpers.h>
2#include <Functions/IFunctionImpl.h>
3#include <Columns/ColumnTuple.h>
4#include <Columns/ColumnString.h>
5#include <Columns/ColumnFixedString.h>
6#include <Columns/ColumnNullable.h>
7#include <Columns/ColumnLowCardinality.h>
8#include <Common/assert_cast.h>
9#include <DataTypes/DataTypeNullable.h>
10#include <DataTypes/DataTypeLowCardinality.h>
11#include <IO/WriteHelpers.h>
12
13
14namespace DB
15{
16
/// Error codes raised by the helpers below. They are only declared here (extern);
/// the definitions are not part of this file.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
}
24
25const ColumnConst * checkAndGetColumnConstStringOrFixedString(const IColumn * column)
26{
27 if (!isColumnConst(*column))
28 return {};
29
30 const ColumnConst * res = assert_cast<const ColumnConst *>(column);
31
32 if (checkColumn<ColumnString>(&res->getDataColumn())
33 || checkColumn<ColumnFixedString>(&res->getDataColumn()))
34 return res;
35
36 return {};
37}
38
39
40Columns convertConstTupleToConstantElements(const ColumnConst & column)
41{
42 const ColumnTuple & src_tuple = assert_cast<const ColumnTuple &>(column.getDataColumn());
43 const auto & src_tuple_columns = src_tuple.getColumns();
44 size_t tuple_size = src_tuple_columns.size();
45 size_t rows = column.size();
46
47 Columns res(tuple_size);
48 for (size_t i = 0; i < tuple_size; ++i)
49 res[i] = ColumnConst::create(src_tuple_columns[i], rows);
50
51 return res;
52}
53
54
55static Block createBlockWithNestedColumnsImpl(const Block & block, const std::unordered_set<size_t> & args)
56{
57 Block res;
58 size_t columns = block.columns();
59
60 for (size_t i = 0; i < columns; ++i)
61 {
62 const auto & col = block.getByPosition(i);
63
64 if (args.count(i) && col.type->isNullable())
65 {
66 const DataTypePtr & nested_type = static_cast<const DataTypeNullable &>(*col.type).getNestedType();
67
68 if (!col.column)
69 {
70 res.insert({nullptr, nested_type, col.name});
71 }
72 else if (auto * nullable = checkAndGetColumn<ColumnNullable>(*col.column))
73 {
74 const auto & nested_col = nullable->getNestedColumnPtr();
75 res.insert({nested_col, nested_type, col.name});
76 }
77 else if (auto * const_column = checkAndGetColumn<ColumnConst>(*col.column))
78 {
79 const auto & nested_col = checkAndGetColumn<ColumnNullable>(const_column->getDataColumn())->getNestedColumnPtr();
80 res.insert({ ColumnConst::create(nested_col, col.column->size()), nested_type, col.name});
81 }
82 else
83 throw Exception("Illegal column for DataTypeNullable", ErrorCodes::ILLEGAL_COLUMN);
84 }
85 else
86 res.insert(col);
87 }
88
89 return res;
90}
91
92
93Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & args)
94{
95 std::unordered_set<size_t> args_set(args.begin(), args.end());
96 return createBlockWithNestedColumnsImpl(block, args_set);
97}
98
99Block createBlockWithNestedColumns(const Block & block, const ColumnNumbers & args, size_t result)
100{
101 std::unordered_set<size_t> args_set(args.begin(), args.end());
102 args_set.insert(result);
103 return createBlockWithNestedColumnsImpl(block, args_set);
104}
105
106void validateArgumentType(const IFunction & func, const DataTypes & arguments,
107 size_t argument_index, bool (* validator_func)(const IDataType &),
108 const char * expected_type_description)
109{
110 if (arguments.size() <= argument_index)
111 throw Exception("Incorrect number of arguments of function " + func.getName(),
112 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
113
114 const auto & argument = arguments[argument_index];
115 if (validator_func(*argument) == false)
116 throw Exception("Illegal type " + argument->getName() +
117 " of " + std::to_string(argument_index) +
118 " argument of function " + func.getName() +
119 " expected " + expected_type_description,
120 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
121}
122
123namespace
124{
125void validateArgumentsImpl(const IFunction & func,
126 const ColumnsWithTypeAndName & arguments,
127 size_t argument_offset,
128 const FunctionArgumentDescriptors & descriptors)
129{
130 for (size_t i = 0; i < descriptors.size(); ++i)
131 {
132 const auto argument_index = i + argument_offset;
133 if (argument_index >= arguments.size())
134 {
135 break;
136 }
137
138 const auto & arg = arguments[i + argument_offset];
139 const auto descriptor = descriptors[i];
140 if (int errorCode = descriptor.isValid(arg.type, arg.column); errorCode != 0)
141 throw Exception("Illegal type of argument #" + std::to_string(i)
142 + (descriptor.argument_name ? " '" + std::string(descriptor.argument_name) + "'" : String{})
143 + " of function " + func.getName()
144 + (descriptor.expected_type_description ? String(", expected ") + descriptor.expected_type_description : String{})
145 + (arg.type ? ", got " + arg.type->getName() : String{}),
146 errorCode);
147 }
148}
149
150}
151
152int FunctionArgumentDescriptor::isValid(const DataTypePtr & data_type, const ColumnPtr & column) const
153{
154 if (type_validator_func && (data_type == nullptr || type_validator_func(*data_type) == false))
155 return ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT;
156
157 if (column_validator_func && (column == nullptr || column_validator_func(*column) == false))
158 return ErrorCodes::ILLEGAL_COLUMN;
159
160 return 0;
161}
162
163void validateFunctionArgumentTypes(const IFunction & func,
164 const ColumnsWithTypeAndName & arguments,
165 const FunctionArgumentDescriptors & mandatory_args,
166 const FunctionArgumentDescriptors & optional_args)
167{
168 if (arguments.size() < mandatory_args.size() || arguments.size() > mandatory_args.size() + optional_args.size())
169 {
170 auto joinArgumentTypes = [](const auto & args, const String sep = ", ") -> String
171 {
172 String result;
173 for (const auto & a : args)
174 {
175 using A = std::decay_t<decltype(a)>;
176 if constexpr (std::is_same_v<A, FunctionArgumentDescriptor>)
177 {
178 if (a.argument_name)
179 result += "'" + std::string(a.argument_name) + "' : ";
180 if (a.expected_type_description)
181 result += a.expected_type_description;
182 }
183 else if constexpr (std::is_same_v<A, ColumnWithTypeAndName>)
184 result += a.type->getName();
185
186 result += sep;
187 }
188
189 if (args.size() != 0)
190 result.erase(result.end() - sep.length(), result.end());
191
192 return result;
193 };
194
195 throw Exception("Incorrect number of arguments for function " + func.getName()
196 + " provided " + std::to_string(arguments.size())
197 + (arguments.size() ? " (" + joinArgumentTypes(arguments) + ")" : String{})
198 + ", expected " + std::to_string(mandatory_args.size())
199 + (optional_args.size() ? " to " + std::to_string(mandatory_args.size() + optional_args.size()) : "")
200 + " (" + joinArgumentTypes(mandatory_args)
201 + (optional_args.size() ? ", [" + joinArgumentTypes(optional_args) + "]" : "")
202 + ")",
203 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
204 }
205
206 validateArgumentsImpl(func, arguments, 0, mandatory_args);
207 if (optional_args.size())
208 {
209 validateArgumentsImpl(func, arguments, mandatory_args.size(), optional_args);
210 }
211}
212
213std::pair<std::vector<const IColumn *>, const ColumnArray::Offset *>
214checkAndGetNestedArrayOffset(const IColumn ** columns, size_t num_arguments)
215{
216 assert(num_arguments > 0);
217 std::vector<const IColumn *> nested_columns(num_arguments);
218 const ColumnArray::Offsets * offsets = nullptr;
219 for (size_t i = 0; i < num_arguments; ++i)
220 {
221 const ColumnArray::Offsets * offsets_i = nullptr;
222 if (const ColumnArray * arr = checkAndGetColumn<const ColumnArray>(columns[i]))
223 {
224 nested_columns[i] = &arr->getData();
225 offsets_i = &arr->getOffsets();
226 }
227 else
228 throw Exception("Illegal column " + columns[i]->getName() + " as argument of function", ErrorCodes::ILLEGAL_COLUMN);
229 if (i == 0)
230 offsets = offsets_i;
231 else if (*offsets_i != *offsets)
232 throw Exception("Lengths of all arrays passed to aggregate function must be equal.", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
233 }
234 return {nested_columns, offsets->data()};
235}
236
237}
238