1#include <Functions/FunctionFactory.h>
2#include <DataTypes/DataTypeArray.h>
3#include <DataTypes/getLeastSupertype.h>
4#include <ext/map.h>
5
6
7namespace DB
8{
9
/// Error codes referenced in this file. They are only declared here (extern);
/// their definitions live outside this file.
namespace ErrorCodes
{
    extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
}
14
15/// Implements the CASE construction when it is
16/// provided an expression. Users should not call this function.
17class FunctionCaseWithExpression : public IFunction
18{
19public:
20 static constexpr auto name = "caseWithExpression";
21 static FunctionPtr create(const Context & context_) { return std::make_shared<FunctionCaseWithExpression>(context_); }
22
23public:
24 FunctionCaseWithExpression(const Context & context_) : context(context_) {}
25 bool isVariadic() const override { return true; }
26 size_t getNumberOfArguments() const override { return 0; }
27 String getName() const override { return name; }
28
29 DataTypePtr getReturnTypeImpl(const DataTypes & args) const override
30 {
31 if (!args.size())
32 throw Exception{"Function " + getName() + " expects at least 1 arguments",
33 ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
34
35 /// See the comments in executeImpl() to understand why we actually have to
36 /// get the return type of a transform function.
37
38 /// Get the types of the arrays that we pass to the transform function.
39 DataTypes dst_array_types;
40
41 for (size_t i = 2; i < args.size() - 1; i += 2)
42 dst_array_types.push_back(args[i]);
43
44 return getLeastSupertype(dst_array_types);
45 }
46
47 void executeImpl(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) override
48 {
49 if (!args.size())
50 throw Exception{"Function " + getName() + " expects at least 1 argument",
51 ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
52
53 /// In the following code, we turn the construction:
54 /// CASE expr WHEN val[0] THEN branch[0] ... WHEN val[N-1] then branch[N-1] ELSE branchN
55 /// into the construction transform(expr, src, dest, branchN)
56 /// where:
57 /// src = [val[0], val[1], ..., val[N-1]]
58 /// dst = [branch[0], ..., branch[N-1]]
59 /// then we perform it.
60
61 /// Create the arrays required by the transform function.
62 ColumnNumbers src_array_args;
63 ColumnsWithTypeAndName src_array_elems;
64 DataTypes src_array_types;
65
66 ColumnNumbers dst_array_args;
67 ColumnsWithTypeAndName dst_array_elems;
68 DataTypes dst_array_types;
69
70 for (size_t i = 1; i < (args.size() - 1); ++i)
71 {
72 if (i % 2)
73 {
74 src_array_args.push_back(args[i]);
75 src_array_elems.push_back(block.getByPosition(args[i]));
76 src_array_types.push_back(block.getByPosition(args[i]).type);
77 }
78 else
79 {
80 dst_array_args.push_back(args[i]);
81 dst_array_elems.push_back(block.getByPosition(args[i]));
82 dst_array_types.push_back(block.getByPosition(args[i]).type);
83 }
84 }
85
86 DataTypePtr src_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(src_array_types));
87 DataTypePtr dst_array_type = std::make_shared<DataTypeArray>(getLeastSupertype(dst_array_types));
88
89 Block temp_block = block;
90
91 size_t src_array_pos = temp_block.columns();
92 temp_block.insert({nullptr, src_array_type, ""});
93
94 size_t dst_array_pos = temp_block.columns();
95 temp_block.insert({nullptr, dst_array_type, ""});
96
97 auto fun_array = FunctionFactory::instance().get("array", context);
98
99 fun_array->build(src_array_elems)->execute(temp_block, src_array_args, src_array_pos, input_rows_count);
100 fun_array->build(dst_array_elems)->execute(temp_block, dst_array_args, dst_array_pos, input_rows_count);
101
102 /// Execute transform.
103 ColumnNumbers transform_args{args.front(), src_array_pos, dst_array_pos, args.back()};
104 FunctionFactory::instance().get("transform", context)->build(
105 ext::map<ColumnsWithTypeAndName>(transform_args, [&](auto i){ return temp_block.getByPosition(i); }))
106 ->execute(temp_block, transform_args, result, input_rows_count);
107
108 /// Put the result into the original block.
109 block.getByPosition(result).column = std::move(temp_block.getByPosition(result).column);
110 }
111
112private:
113 const Context & context;
114};
115
116void registerFunctionCaseWithExpression(FunctionFactory & factory)
117{
118 factory.registerFunction<FunctionCaseWithExpression>();
119
120 /// These are obsolete function names.
121 factory.registerFunction<FunctionCaseWithExpression>("caseWithExpr");
122}
123
124}
125
126
127
128
129