1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionFactory.h>
3#include <Functions/FunctionHelpers.h>
4#include <DataTypes/DataTypeArray.h>
5#include <Columns/ColumnArray.h>
6#include <Columns/ColumnString.h>
7#include <Columns/ColumnAggregateFunction.h>
8#include <IO/WriteHelpers.h>
9#include <AggregateFunctions/AggregateFunctionFactory.h>
10#include <AggregateFunctions/AggregateFunctionState.h>
11#include <AggregateFunctions/IAggregateFunction.h>
12#include <AggregateFunctions/parseAggregateFunctionParameters.h>
13#include <Common/AlignedBuffer.h>
14#include <Common/Arena.h>
15
16#include <ext/scope_guard.h>
17
18
19namespace DB
20{
21
/// Error codes thrown by this translation unit.
/// They are only declared here (`extern`); the definitions live elsewhere.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
}
30
31
32/** Applies an aggregate function to array and returns its result.
33 * If aggregate function has multiple arguments, then this function can be applied to multiple arrays of the same size.
34 *
35 * arrayReduce('agg', arr1, ...) - apply the aggregate function `agg` to arrays `arr1...`
36 * If multiple arrays passed, then elements on corresponding positions are passed as multiple arguments to the aggregate function.
37 */
38class FunctionArrayReduce : public IFunction
39{
40public:
41 static constexpr auto name = "arrayReduce";
42 static FunctionPtr create(const Context &) { return std::make_shared<FunctionArrayReduce>(); }
43
44 String getName() const override { return name; }
45
46 bool isVariadic() const override { return true; }
47 size_t getNumberOfArguments() const override { return 0; }
48
49 bool useDefaultImplementationForConstants() const override { return true; }
50 ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {0}; }
51
52 DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override;
53
54 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override;
55
56private:
57 /// lazy initialization in getReturnTypeImpl
58 /// TODO: init in OverloadResolver
59 mutable AggregateFunctionPtr aggregate_function;
60};
61
62
63DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
64{
65 /// The first argument is a constant string with the name of the aggregate function
66 /// (possibly with parameters in parentheses, for example: "quantile(0.99)").
67
68 if (arguments.size() < 2)
69 throw Exception("Number of arguments for function " + getName() + " doesn't match: passed "
70 + toString(arguments.size()) + ", should be at least 2.",
71 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
72
73 const ColumnConst * aggregate_function_name_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
74 if (!aggregate_function_name_column)
75 throw Exception("First argument for function " + getName() + " must be constant string: name of aggregate function.",
76 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
77
78 DataTypes argument_types(arguments.size() - 1);
79 for (size_t i = 1, size = arguments.size(); i < size; ++i)
80 {
81 const DataTypeArray * arg = checkAndGetDataType<DataTypeArray>(arguments[i].type.get());
82 if (!arg)
83 throw Exception("Argument " + toString(i) + " for function " + getName() + " must be an array but it has type "
84 + arguments[i].type->getName() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
85
86 argument_types[i - 1] = arg->getNestedType();
87 }
88
89 if (!aggregate_function)
90 {
91 String aggregate_function_name_with_params = aggregate_function_name_column->getValue<String>();
92
93 if (aggregate_function_name_with_params.empty())
94 throw Exception("First argument for function " + getName() + " (name of aggregate function) cannot be empty.",
95 ErrorCodes::BAD_ARGUMENTS);
96
97 String aggregate_function_name;
98 Array params_row;
99 getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
100 aggregate_function_name, params_row, "function " + getName());
101
102 aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row);
103 }
104
105 return aggregate_function->getReturnType();
106}
107
108
109void FunctionArrayReduce::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count)
110{
111 IAggregateFunction & agg_func = *aggregate_function.get();
112 std::unique_ptr<Arena> arena = std::make_unique<Arena>();
113
114 /// Aggregate functions do not support constant columns. Therefore, we materialize them.
115 std::vector<ColumnPtr> materialized_columns;
116
117 const size_t num_arguments_columns = arguments.size() - 1;
118
119 std::vector<const IColumn *> aggregate_arguments_vec(num_arguments_columns);
120 const ColumnArray::Offsets * offsets = nullptr;
121
122 for (size_t i = 0; i < num_arguments_columns; ++i)
123 {
124 const IColumn * col = block.getByPosition(arguments[i + 1]).column.get();
125
126 const ColumnArray::Offsets * offsets_i = nullptr;
127 if (const ColumnArray * arr = checkAndGetColumn<ColumnArray>(col))
128 {
129 aggregate_arguments_vec[i] = &arr->getData();
130 offsets_i = &arr->getOffsets();
131 }
132 else if (const ColumnConst * const_arr = checkAndGetColumnConst<ColumnArray>(col))
133 {
134 materialized_columns.emplace_back(const_arr->convertToFullColumn());
135 const auto & materialized_arr = typeid_cast<const ColumnArray &>(*materialized_columns.back().get());
136 aggregate_arguments_vec[i] = &materialized_arr.getData();
137 offsets_i = &materialized_arr.getOffsets();
138 }
139 else
140 throw Exception("Illegal column " + col->getName() + " as argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
141
142 if (i == 0)
143 offsets = offsets_i;
144 else if (*offsets_i != *offsets)
145 throw Exception("Lengths of all arrays passed to " + getName() + " must be equal.",
146 ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
147 }
148 const IColumn ** aggregate_arguments = aggregate_arguments_vec.data();
149
150 MutableColumnPtr result_holder = block.getByPosition(result).type->createColumn();
151 IColumn & res_col = *result_holder;
152
153 /// AggregateFunction's states should be inserted into column using specific way
154 auto res_col_aggregate_function = typeid_cast<ColumnAggregateFunction *>(&res_col);
155
156 if (!res_col_aggregate_function && agg_func.isState())
157 throw Exception("State function " + agg_func.getName() + " inserts results into non-state column "
158 + block.getByPosition(result).type->getName(), ErrorCodes::ILLEGAL_COLUMN);
159
160 PODArray<AggregateDataPtr> places(input_rows_count);
161 for (size_t i = 0; i < input_rows_count; ++i)
162 {
163 places[i] = arena->alignedAlloc(agg_func.sizeOfData(), agg_func.alignOfData());
164 try
165 {
166 agg_func.create(places[i]);
167 }
168 catch (...)
169 {
170 for (size_t j = 0; j < i; ++j)
171 agg_func.destroy(places[j]);
172 throw;
173 }
174 }
175
176 SCOPE_EXIT({
177 for (size_t i = 0; i < input_rows_count; ++i)
178 agg_func.destroy(places[i]);
179 });
180
181 {
182 auto that = &agg_func;
183 /// Unnest consecutive trailing -State combinators
184 while (auto func = typeid_cast<AggregateFunctionState *>(that))
185 that = func->getNestedFunction().get();
186
187 that->addBatchArray(input_rows_count, places.data(), 0, aggregate_arguments, offsets->data(), arena.get());
188 }
189
190 for (size_t i = 0; i < input_rows_count; ++i)
191 if (!res_col_aggregate_function)
192 agg_func.insertResultInto(places[i], res_col);
193 else
194 res_col_aggregate_function->insertFrom(places[i]);
195 block.getByPosition(result).column = std::move(result_holder);
196}
197
198
199void registerFunctionArrayReduce(FunctionFactory & factory)
200{
201 factory.registerFunction<FunctionArrayReduce>();
202}
203
204}
205