1 | #include <Functions/FunctionFactory.h> |
2 | #include <DataTypes/DataTypeArray.h> |
3 | #include <DataTypes/getLeastSupertype.h> |
4 | #include <ext/map.h> |
5 | |
6 | |
7 | namespace DB |
8 | { |
9 | |
/// Error code declared here and defined elsewhere in the project; thrown when
/// caseWithExpression is invoked without any arguments.
namespace ErrorCodes { extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; }
14 | |
15 | /// Implements the CASE construction when it is |
16 | /// provided an expression. Users should not call this function. |
17 | class FunctionCaseWithExpression : public IFunction |
18 | { |
19 | public: |
20 | static constexpr auto name = "caseWithExpression" ; |
21 | static FunctionPtr create(const Context & context_) { return std::make_shared<FunctionCaseWithExpression>(context_); } |
22 | |
23 | public: |
24 | FunctionCaseWithExpression(const Context & context_) : context(context_) {} |
25 | bool isVariadic() const override { return true; } |
26 | size_t getNumberOfArguments() const override { return 0; } |
27 | String getName() const override { return name; } |
28 | |
29 | DataTypePtr getReturnTypeImpl(const DataTypes & args) const override |
30 | { |
31 | if (!args.size()) |
32 | throw Exception{"Function " + getName() + " expects at least 1 arguments" , |
33 | ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION}; |
34 | |
35 | /// See the comments in executeImpl() to understand why we actually have to |
36 | /// get the return type of a transform function. |
37 | |
38 | /// Get the types of the arrays that we pass to the transform function. |
39 | DataTypes dst_array_types; |
40 | |
41 | for (size_t i = 2; i < args.size() - 1; i += 2) |
42 | dst_array_types.push_back(args[i]); |
43 | |
44 | return getLeastSupertype(dst_array_types); |
45 | } |
46 | |
47 | void executeImpl(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) override |
48 | { |
49 | if (!args.size()) |
50 | throw Exception{"Function " + getName() + " expects at least 1 argument" , |
51 | ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION}; |
52 | |
53 | /// In the following code, we turn the construction: |
54 | /// CASE expr WHEN val[0] THEN branch[0] ... WHEN val[N-1] then branch[N-1] ELSE branchN |
55 | /// into the construction transform(expr, src, dest, branchN) |
56 | /// where: |
57 | /// src = [val[0], val[1], ..., val[N-1]] |
58 | /// dst = [branch[0], ..., branch[N-1]] |
59 | /// then we perform it. |
60 | |
61 | /// Create the arrays required by the transform function. |
62 | ColumnNumbers src_array_args; |
63 | ColumnsWithTypeAndName src_array_elems; |
64 | DataTypes src_array_types; |
65 | |
66 | ColumnNumbers dst_array_args; |
67 | ColumnsWithTypeAndName dst_array_elems; |
68 | DataTypes dst_array_types; |
69 | |
70 | for (size_t i = 1; i < (args.size() - 1); ++i) |
71 | { |
72 | if (i % 2) |
73 | { |
74 | src_array_args.push_back(args[i]); |
75 | src_array_elems.push_back(block.getByPosition(args[i])); |
76 | src_array_types.push_back(block.getByPosition(args[i]).type); |
77 | } |
78 | else |
79 | { |
80 | dst_array_args.push_back(args[i]); |
81 | dst_array_elems.push_back(block.getByPosition(args[i])); |
82 | dst_array_types.push_back(block.getByPosition(args[i]).type); |
83 | } |
84 | } |
85 | |
86 | DataTypePtr src_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(src_array_types)); |
87 | DataTypePtr dst_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(dst_array_types)); |
88 | |
89 | Block temp_block = block; |
90 | |
91 | size_t src_array_pos = temp_block.columns(); |
92 | temp_block.insert({nullptr, src_array_type, "" }); |
93 | |
94 | size_t dst_array_pos = temp_block.columns(); |
95 | temp_block.insert({nullptr, dst_array_type, "" }); |
96 | |
97 | auto fun_array = FunctionFactory::instance().get("array" , context); |
98 | |
99 | fun_array->build(src_array_elems)->execute(temp_block, src_array_args, src_array_pos, input_rows_count); |
100 | fun_array->build(dst_array_elems)->execute(temp_block, dst_array_args, dst_array_pos, input_rows_count); |
101 | |
102 | /// Execute transform. |
103 | ColumnNumbers transform_args{args.front(), src_array_pos, dst_array_pos, args.back()}; |
104 | FunctionFactory::instance().get("transform" , context)->build( |
105 | ext::map<ColumnsWithTypeAndName>(transform_args, [&](auto i){ return temp_block.getByPosition(i); })) |
106 | ->execute(temp_block, transform_args, result, input_rows_count); |
107 | |
108 | /// Put the result into the original block. |
109 | block.getByPosition(result).column = std::move(temp_block.getByPosition(result).column); |
110 | } |
111 | |
112 | private: |
113 | const Context & context; |
114 | }; |
115 | |
116 | void registerFunctionCaseWithExpression(FunctionFactory & factory) |
117 | { |
118 | factory.registerFunction<FunctionCaseWithExpression>(); |
119 | |
120 | /// These are obsolete function names. |
121 | factory.registerFunction<FunctionCaseWithExpression>("caseWithExpr" ); |
122 | } |
123 | |
124 | } |
125 | |
126 | |
127 | |
128 | |
129 | |