1#include <Functions/FunctionsExternalModels.h>
2#include <Functions/FunctionHelpers.h>
3#include <Functions/FunctionFactory.h>
4
5#include <Interpreters/Context.h>
6#include <Interpreters/ExternalModelsLoader.h>
7#include <DataTypes/DataTypeString.h>
8#include <DataTypes/DataTypesNumber.h>
9#include <Columns/ColumnString.h>
10#include <ext/range.h>
11#include <string>
12#include <memory>
13#include <DataTypes/DataTypeNullable.h>
14#include <Columns/ColumnNullable.h>
15#include <Columns/ColumnTuple.h>
16#include <DataTypes/DataTypeTuple.h>
17#include <Common/assert_cast.h>
18
19
20namespace DB
21{
22
23FunctionPtr FunctionModelEvaluate::create(const Context & context)
24{
25 return std::make_shared<FunctionModelEvaluate>(context.getExternalModelsLoader());
26}
27
/// Error codes thrown by modelEvaluate; the definitions live elsewhere in the project.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
}
34
35DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const
36{
37 if (arguments.size() < 2)
38 throw Exception("Function " + getName() + " expects at least 2 arguments",
39 ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION);
40
41 if (!isString(arguments[0].type))
42 throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName()
43 + ", expected a string.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
44
45 const auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
46 if (!name_col)
47 throw Exception("First argument of function " + getName() + " must be a constant string",
48 ErrorCodes::ILLEGAL_COLUMN);
49
50 bool has_nullable = false;
51 for (size_t i = 1; i < arguments.size(); ++i)
52 has_nullable = has_nullable || arguments[i].type->isNullable();
53
54 auto model = models_loader.getModel(name_col->getValue<String>());
55 auto type = model->getReturnType();
56
57 if (has_nullable)
58 {
59 if (auto * tuple = typeid_cast<const DataTypeTuple *>(type.get()))
60 {
61 auto elements = tuple->getElements();
62 for (auto & element : elements)
63 element = makeNullable(element);
64
65 type = std::make_shared<DataTypeTuple>(elements);
66 }
67 else
68 type = makeNullable(type);
69 }
70
71 return type;
72}
73
74void FunctionModelEvaluate::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/)
75{
76 const auto name_col = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[0]).column.get());
77 if (!name_col)
78 throw Exception("First argument of function " + getName() + " must be a constant string",
79 ErrorCodes::ILLEGAL_COLUMN);
80
81 auto model = models_loader.getModel(name_col->getValue<String>());
82
83 ColumnRawPtrs columns;
84 Columns materialized_columns;
85 ColumnPtr null_map;
86
87 columns.reserve(arguments.size());
88 for (auto arg : ext::range(1, arguments.size()))
89 {
90 auto & column = block.getByPosition(arguments[arg]).column;
91 columns.push_back(column.get());
92 if (auto full_column = column->convertToFullColumnIfConst())
93 {
94 materialized_columns.push_back(full_column);
95 columns.back() = full_column.get();
96 }
97 if (auto * col_nullable = checkAndGetColumn<ColumnNullable>(*columns.back()))
98 {
99 if (!null_map)
100 null_map = col_nullable->getNullMapColumnPtr();
101 else
102 {
103 auto mut_null_map = (*std::move(null_map)).mutate();
104
105 NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mut_null_map).getData();
106 const NullMap & src_null_map = col_nullable->getNullMapColumn().getData();
107
108 for (size_t i = 0, size = result_null_map.size(); i < size; ++i)
109 if (src_null_map[i])
110 result_null_map[i] = 1;
111
112 null_map = std::move(mut_null_map);
113 }
114
115 columns.back() = &col_nullable->getNestedColumn();
116 }
117 }
118
119 auto res = model->evaluate(columns);
120
121 if (null_map)
122 {
123 if (auto * tuple = typeid_cast<const ColumnTuple *>(res.get()))
124 {
125 auto nested = tuple->getColumns();
126 for (auto & col : nested)
127 col = ColumnNullable::create(col, null_map);
128
129 res = ColumnTuple::create(nested);
130 }
131 else
132 res = ColumnNullable::create(res, null_map);
133 }
134
135 block.getByPosition(result).column = res;
136}
137
138void registerFunctionsExternalModels(FunctionFactory & factory)
139{
140 factory.registerFunction<FunctionModelEvaluate>();
141}
142
143}
144