| 1 | #include <Functions/FunctionsExternalModels.h> |
| 2 | #include <Functions/FunctionHelpers.h> |
| 3 | #include <Functions/FunctionFactory.h> |
| 4 | |
| 5 | #include <Interpreters/Context.h> |
| 6 | #include <Interpreters/ExternalModelsLoader.h> |
| 7 | #include <DataTypes/DataTypeString.h> |
| 8 | #include <DataTypes/DataTypesNumber.h> |
| 9 | #include <Columns/ColumnString.h> |
| 10 | #include <ext/range.h> |
| 11 | #include <string> |
| 12 | #include <memory> |
| 13 | #include <DataTypes/DataTypeNullable.h> |
| 14 | #include <Columns/ColumnNullable.h> |
| 15 | #include <Columns/ColumnTuple.h> |
| 16 | #include <DataTypes/DataTypeTuple.h> |
| 17 | #include <Common/assert_cast.h> |
| 18 | |
| 19 | |
| 20 | namespace DB |
| 21 | { |
| 22 | |
| 23 | FunctionPtr FunctionModelEvaluate::create(const Context & context) |
| 24 | { |
| 25 | return std::make_shared<FunctionModelEvaluate>(context.getExternalModelsLoader()); |
| 26 | } |
| 27 | |
namespace ErrorCodes
{
    /// Error codes thrown by the functions below. They are declared `extern`
    /// here; their definitions are outside this file.
    extern const int ILLEGAL_COLUMN;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
}
| 34 | |
| 35 | DataTypePtr FunctionModelEvaluate::getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const |
| 36 | { |
| 37 | if (arguments.size() < 2) |
| 38 | throw Exception("Function " + getName() + " expects at least 2 arguments" , |
| 39 | ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION); |
| 40 | |
| 41 | if (!isString(arguments[0].type)) |
| 42 | throw Exception("Illegal type " + arguments[0].type->getName() + " of first argument of function " + getName() |
| 43 | + ", expected a string." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
| 44 | |
| 45 | const auto name_col = checkAndGetColumnConst<ColumnString>(arguments[0].column.get()); |
| 46 | if (!name_col) |
| 47 | throw Exception("First argument of function " + getName() + " must be a constant string" , |
| 48 | ErrorCodes::ILLEGAL_COLUMN); |
| 49 | |
| 50 | bool has_nullable = false; |
| 51 | for (size_t i = 1; i < arguments.size(); ++i) |
| 52 | has_nullable = has_nullable || arguments[i].type->isNullable(); |
| 53 | |
| 54 | auto model = models_loader.getModel(name_col->getValue<String>()); |
| 55 | auto type = model->getReturnType(); |
| 56 | |
| 57 | if (has_nullable) |
| 58 | { |
| 59 | if (auto * tuple = typeid_cast<const DataTypeTuple *>(type.get())) |
| 60 | { |
| 61 | auto elements = tuple->getElements(); |
| 62 | for (auto & element : elements) |
| 63 | element = makeNullable(element); |
| 64 | |
| 65 | type = std::make_shared<DataTypeTuple>(elements); |
| 66 | } |
| 67 | else |
| 68 | type = makeNullable(type); |
| 69 | } |
| 70 | |
| 71 | return type; |
| 72 | } |
| 73 | |
| 74 | void FunctionModelEvaluate::executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) |
| 75 | { |
| 76 | const auto name_col = checkAndGetColumnConst<ColumnString>(block.getByPosition(arguments[0]).column.get()); |
| 77 | if (!name_col) |
| 78 | throw Exception("First argument of function " + getName() + " must be a constant string" , |
| 79 | ErrorCodes::ILLEGAL_COLUMN); |
| 80 | |
| 81 | auto model = models_loader.getModel(name_col->getValue<String>()); |
| 82 | |
| 83 | ColumnRawPtrs columns; |
| 84 | Columns materialized_columns; |
| 85 | ColumnPtr null_map; |
| 86 | |
| 87 | columns.reserve(arguments.size()); |
| 88 | for (auto arg : ext::range(1, arguments.size())) |
| 89 | { |
| 90 | auto & column = block.getByPosition(arguments[arg]).column; |
| 91 | columns.push_back(column.get()); |
| 92 | if (auto full_column = column->convertToFullColumnIfConst()) |
| 93 | { |
| 94 | materialized_columns.push_back(full_column); |
| 95 | columns.back() = full_column.get(); |
| 96 | } |
| 97 | if (auto * col_nullable = checkAndGetColumn<ColumnNullable>(*columns.back())) |
| 98 | { |
| 99 | if (!null_map) |
| 100 | null_map = col_nullable->getNullMapColumnPtr(); |
| 101 | else |
| 102 | { |
| 103 | auto mut_null_map = (*std::move(null_map)).mutate(); |
| 104 | |
| 105 | NullMap & result_null_map = assert_cast<ColumnUInt8 &>(*mut_null_map).getData(); |
| 106 | const NullMap & src_null_map = col_nullable->getNullMapColumn().getData(); |
| 107 | |
| 108 | for (size_t i = 0, size = result_null_map.size(); i < size; ++i) |
| 109 | if (src_null_map[i]) |
| 110 | result_null_map[i] = 1; |
| 111 | |
| 112 | null_map = std::move(mut_null_map); |
| 113 | } |
| 114 | |
| 115 | columns.back() = &col_nullable->getNestedColumn(); |
| 116 | } |
| 117 | } |
| 118 | |
| 119 | auto res = model->evaluate(columns); |
| 120 | |
| 121 | if (null_map) |
| 122 | { |
| 123 | if (auto * tuple = typeid_cast<const ColumnTuple *>(res.get())) |
| 124 | { |
| 125 | auto nested = tuple->getColumns(); |
| 126 | for (auto & col : nested) |
| 127 | col = ColumnNullable::create(col, null_map); |
| 128 | |
| 129 | res = ColumnTuple::create(nested); |
| 130 | } |
| 131 | else |
| 132 | res = ColumnNullable::create(res, null_map); |
| 133 | } |
| 134 | |
| 135 | block.getByPosition(result).column = res; |
| 136 | } |
| 137 | |
| 138 | void registerFunctionsExternalModels(FunctionFactory & factory) |
| 139 | { |
| 140 | factory.registerFunction<FunctionModelEvaluate>(); |
| 141 | } |
| 142 | |
| 143 | } |
| 144 | |