1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionFactory.h>
3#include <Functions/FunctionHelpers.h>
4#include <DataTypes/DataTypeAggregateFunction.h>
5#include <DataTypes/DataTypesNumber.h>
6#include <Columns/ColumnAggregateFunction.h>
7#include <Common/typeid_cast.h>
8
9#include <Columns/ColumnVector.h>
10#include <Columns/ColumnsNumber.h>
11#include <iostream>
12
13#include <Common/PODArray.h>
14#include <Columns/ColumnArray.h>
15
16namespace DB
17{
18
/// Error codes referenced in this translation unit; the definitions live elsewhere in the project.
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;             /// thrown when the function is called without arguments
    extern const int ILLEGAL_COLUMN;            /// thrown when the first argument column is not an aggregate-state column
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;  /// thrown when the first argument type is not AggregateFunction
}
24
25
26/** finalizeAggregation(agg_state) - get the result from the aggregation state.
27* Takes state of aggregate function. Returns result of aggregation (finalized state).
28*/
29class FunctionEvalMLMethod : public IFunction
30{
31public:
32 static constexpr auto name = "evalMLMethod";
33 static FunctionPtr create(const Context & context)
34 {
35 return std::make_shared<FunctionEvalMLMethod>(context);
36 }
37 FunctionEvalMLMethod(const Context & context_) : context(context_)
38 {}
39
40 String getName() const override
41 {
42 return name;
43 }
44
45 bool isVariadic() const override
46 {
47 return true;
48 }
49 size_t getNumberOfArguments() const override
50 {
51 return 0;
52 }
53
54 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
55 {
56 if (arguments.empty())
57 throw Exception("Function " + getName() + " requires at least one argument", ErrorCodes::BAD_ARGUMENTS);
58
59 const auto * type = checkAndGetDataType<DataTypeAggregateFunction>(arguments[0].get());
60 if (!type)
61 throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
62 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
63
64 return type->getReturnTypeToPredict();
65 }
66
67 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
68 {
69 if (arguments.empty())
70 throw Exception("Function " + getName() + " requires at least one argument", ErrorCodes::BAD_ARGUMENTS);
71
72 const auto * model = block.getByPosition(arguments[0]).column.get();
73
74 if (const auto * column_with_states = typeid_cast<const ColumnConst *>(model))
75 model = column_with_states->getDataColumnPtr().get();
76
77 const auto * agg_function = typeid_cast<const ColumnAggregateFunction *>(model);
78
79 if (!agg_function)
80 throw Exception("Illegal column " + block.getByPosition(arguments[0]).column->getName()
81 + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN);
82
83 block.getByPosition(result).column = agg_function->predictValues(block, arguments, context);
84 }
85
86 const Context & context;
87};
88
89void registerFunctionEvalMLMethod(FunctionFactory & factory)
90{
91 factory.registerFunction<FunctionEvalMLMethod>();
92}
93
94}
95