1 | #include <Functions/FunctionsExternalModels.h> |
2 | #include <Functions/FunctionHelpers.h> |
3 | #include <Functions/FunctionFactory.h> |
4 | |
5 | #include <Interpreters/Context.h> |
6 | #include <Interpreters/ExternalModelsLoader.h> |
7 | #include <DataTypes/DataTypeString.h> |
8 | #include <DataTypes/DataTypesNumber.h> |
9 | #include <Columns/ColumnString.h> |
10 | #include <ext/range.h> |
11 | #include <string> |
12 | #include <memory> |
13 | #include <DataTypes/DataTypeNullable.h> |
14 | #include <Columns/ColumnNullable.h> |
15 | #include <Columns/ColumnTuple.h> |
16 | #include <DataTypes/DataTypeTuple.h> |
17 | #include <Common/assert_cast.h> |
18 | |
19 | |
20 | namespace DB |
21 | { |
22 | |
23 | FunctionPtr FunctionModelEvaluate::create(const Context & context) |
24 | { |
25 | return std::make_shared<FunctionModelEvaluate>(context.getExternalModelsLoader()); |
26 | } |
27 | |
/// Error codes thrown from this file; they are only declared here (extern) and
/// defined elsewhere in the project.
namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
}
34 | |
35 | DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const |
36 | { |
37 | if (arguments.size() < 2) |
38 | throw Exception("Function " + getName() + " expects at least 2 arguments" , |
39 | ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION); |
40 | |
41 | if (!isString(arguments[0].type)) |
42 | throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName() |
43 | + ", expected a string." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
44 | |
45 | const auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()); |
46 | if (!name_col) |
47 | throw Exception("First argument of function " + getName() + " must be a constant string" , |
48 | ErrorCodes::ILLEGAL_COLUMN); |
49 | |
50 | bool has_nullable = false; |
51 | for (size_t i = 1; i < arguments.size(); ++i) |
52 | has_nullable = has_nullable || arguments[i].type->isNullable(); |
53 | |
54 | auto model = models_loader.getModel(name_col->getValue<String>()); |
55 | auto type = model->getReturnType(); |
56 | |
57 | if (has_nullable) |
58 | { |
59 | if (auto * tuple = typeid_cast<const DataTypeTuple *>(type.get())) |
60 | { |
61 | auto elements = tuple->getElements(); |
62 | for (auto & element : elements) |
63 | element = makeNullable(element); |
64 | |
65 | type = std::make_shared<DataTypeTuple>(elements); |
66 | } |
67 | else |
68 | type = makeNullable(type); |
69 | } |
70 | |
71 | return type; |
72 | } |
73 | |
74 | void FunctionModelEvaluate::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) |
75 | { |
76 | const auto name_col = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[0]).column.get()); |
77 | if (!name_col) |
78 | throw Exception("First argument of function " + getName() + " must be a constant string" , |
79 | ErrorCodes::ILLEGAL_COLUMN); |
80 | |
81 | auto model = models_loader.getModel(name_col->getValue<String>()); |
82 | |
83 | ColumnRawPtrs columns; |
84 | Columns materialized_columns; |
85 | ColumnPtr null_map; |
86 | |
87 | columns.reserve(arguments.size()); |
88 | for (auto arg : ext::range(1, arguments.size())) |
89 | { |
90 | auto & column = block.getByPosition(arguments[arg]).column; |
91 | columns.push_back(column.get()); |
92 | if (auto full_column = column->convertToFullColumnIfConst()) |
93 | { |
94 | materialized_columns.push_back(full_column); |
95 | columns.back() = full_column.get(); |
96 | } |
97 | if (auto * col_nullable = checkAndGetColumn<ColumnNullable>(*columns.back())) |
98 | { |
99 | if (!null_map) |
100 | null_map = col_nullable->getNullMapColumnPtr(); |
101 | else |
102 | { |
103 | auto mut_null_map = (*std::move(null_map)).mutate(); |
104 | |
105 | NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mut_null_map).getData(); |
106 | const NullMap & src_null_map = col_nullable->getNullMapColumn().getData(); |
107 | |
108 | for (size_t i = 0, size = result_null_map.size(); i < size; ++i) |
109 | if (src_null_map[i]) |
110 | result_null_map[i] = 1; |
111 | |
112 | null_map = std::move(mut_null_map); |
113 | } |
114 | |
115 | columns.back() = &col_nullable->getNestedColumn(); |
116 | } |
117 | } |
118 | |
119 | auto res = model->evaluate(columns); |
120 | |
121 | if (null_map) |
122 | { |
123 | if (auto * tuple = typeid_cast<const ColumnTuple *>(res.get())) |
124 | { |
125 | auto nested = tuple->getColumns(); |
126 | for (auto & col : nested) |
127 | col = ColumnNullable::create(col, null_map); |
128 | |
129 | res = ColumnTuple::create(nested); |
130 | } |
131 | else |
132 | res = ColumnNullable::create(res, null_map); |
133 | } |
134 | |
135 | block.getByPosition(result).column = res; |
136 | } |
137 | |
138 | void registerFunctionsExternalModels(FunctionFactory & factory) |
139 | { |
140 | factory.registerFunction<FunctionModelEvaluate>(); |
141 | } |
142 | |
143 | } |
144 | |