1#include <Functions/FunctionFactory.h>
2#include <Functions/FunctionIfBase.h>
3#include <Columns/ColumnNullable.h>
4#include <Columns/ColumnConst.h>
5#include <Columns/ColumnsNumber.h>
6#include <Interpreters/castColumn.h>
7#include <Common/typeid_cast.h>
8#include <Common/assert_cast.h>
9#include <DataTypes/DataTypeNullable.h>
10#include <DataTypes/getLeastSupertype.h>
11
12
13namespace DB
14{
15
/// Error codes thrown by this translation unit; the definitions live elsewhere.
namespace ErrorCodes
{
    /// Thrown when a condition argument of multiIf is not (Nullable) UInt8.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Thrown when multiIf receives fewer than 3 arguments or an even number of them.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
20
21/// Function multiIf, which generalizes the function if.
22///
23/// Syntax: multiIf(cond_1, then_1, ..., cond_N, then_N, else)
24/// where N >= 1.
25///
26/// For all 1 <= i <= N, "cond_i" has type UInt8.
27/// Types of all the branches "then_i" and "else" have a common type.
28///
29/// Additionally the arguments, conditions or branches, support nullable types
30/// and the NULL value, with a NULL condition treated as false.
31class FunctionMultiIf final : public FunctionIfBase</*null_is_false=*/true>
32{
33public:
34 static constexpr auto name = "multiIf";
35 static FunctionPtr create(const Context & context) { return std::make_shared<FunctionMultiIf>(context); }
36 FunctionMultiIf(const Context & context_) : context(context_) {}
37
38public:
39 String getName() const override { return name; }
40 bool isVariadic() const override { return true; }
41 size_t getNumberOfArguments() const override { return 0; }
42 bool useDefaultImplementationForNulls() const override { return false; }
43 ColumnNumbers getArgumentsThatDontImplyNullableReturnType(size_t number_of_arguments) const override
44 {
45 ColumnNumbers args;
46 for (size_t i = 0; i + 1 < number_of_arguments; i += 2)
47 args.push_back(i);
48 return args;
49 }
50
51 DataTypePtr getReturnTypeImpl(const DataTypes & args) const override
52 {
53 /// Arguments are the following: cond1, then1, cond2, then2, ... condN, thenN, else.
54
55 auto for_conditions = [&args](auto && f)
56 {
57 size_t conditions_end = args.size() - 1;
58 for (size_t i = 0; i < conditions_end; i += 2)
59 f(args[i]);
60 };
61
62 auto for_branches = [&args](auto && f)
63 {
64 size_t branches_end = args.size();
65 for (size_t i = 1; i < branches_end; i += 2)
66 f(args[i]);
67 f(args.back());
68 };
69
70 if (!(args.size() >= 3 && args.size() % 2 == 1))
71 throw Exception{"Invalid number of arguments for function " + getName(),
72 ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
73
74
75 for_conditions([&](const DataTypePtr & arg)
76 {
77 const IDataType * nested_type;
78 if (arg->isNullable())
79 {
80 if (arg->onlyNull())
81 return;
82
83 const DataTypeNullable & nullable_type = static_cast<const DataTypeNullable &>(*arg);
84 nested_type = nullable_type.getNestedType().get();
85 }
86 else
87 {
88 nested_type = arg.get();
89 }
90
91 if (!WhichDataType(nested_type).isUInt8())
92 throw Exception{"Illegal type " + arg->getName() + " of argument (condition) "
93 "of function " + getName() + ". Must be UInt8.",
94 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
95 });
96
97 DataTypes types_of_branches;
98 types_of_branches.reserve(args.size() / 2 + 1);
99
100 for_branches([&](const DataTypePtr & arg)
101 {
102 types_of_branches.emplace_back(arg);
103 });
104
105 return getLeastSupertype(types_of_branches);
106 }
107
108 void executeImpl(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) override
109 {
110 /** We will gather values from columns in branches to result column,
111 * depending on values of conditions.
112 */
113 struct Instruction
114 {
115 const IColumn * condition = nullptr;
116 const IColumn * source = nullptr;
117
118 bool condition_always_true = false;
119 bool condition_is_nullable = false;
120 bool source_is_constant = false;
121 };
122
123 std::vector<Instruction> instructions;
124 instructions.reserve(args.size() / 2 + 1);
125
126 Columns converted_columns_holder;
127 converted_columns_holder.reserve(instructions.size());
128
129 const DataTypePtr & return_type = block.getByPosition(result).type;
130
131 for (size_t i = 0; i < args.size(); i += 2)
132 {
133 Instruction instruction;
134 size_t source_idx = i + 1;
135
136 if (source_idx == args.size())
137 {
138 /// The last, "else" branch can be treated as a branch with always true condition "else if (true)".
139 --source_idx;
140 instruction.condition_always_true = true;
141 }
142 else
143 {
144 const ColumnWithTypeAndName & cond_col = block.getByPosition(args[i]);
145
146 /// We skip branches that are always false.
147 /// If we encounter a branch that is always true, we can finish.
148
149 if (cond_col.column->onlyNull())
150 continue;
151
152 if (isColumnConst(*cond_col.column))
153 {
154 Field value = typeid_cast<const ColumnConst &>(*cond_col.column).getField();
155 if (value.isNull())
156 continue;
157 if (value.get<UInt64>() == 0)
158 continue;
159 instruction.condition_always_true = true;
160 }
161 else
162 {
163 if (isColumnNullable(*cond_col.column))
164 instruction.condition_is_nullable = true;
165
166 instruction.condition = cond_col.column.get();
167 }
168 }
169
170 const ColumnWithTypeAndName & source_col = block.getByPosition(args[source_idx]);
171 if (source_col.type->equals(*return_type))
172 {
173 instruction.source = source_col.column.get();
174 }
175 else
176 {
177 /// Cast all columns to result type.
178 converted_columns_holder.emplace_back(castColumn(source_col, return_type, context));
179 instruction.source = converted_columns_holder.back().get();
180 }
181
182 if (instruction.source && isColumnConst(*instruction.source))
183 instruction.source_is_constant = true;
184
185 instructions.emplace_back(std::move(instruction));
186
187 if (instructions.back().condition_always_true)
188 break;
189 }
190
191 size_t rows = input_rows_count;
192 MutableColumnPtr res = return_type->createColumn();
193
194 for (size_t i = 0; i < rows; ++i)
195 {
196 for (const auto & instruction : instructions)
197 {
198 bool insert = false;
199
200 if (instruction.condition_always_true)
201 insert = true;
202 else if (!instruction.condition_is_nullable)
203 insert = assert_cast<const ColumnUInt8 &>(*instruction.condition).getData()[i];
204 else
205 {
206 const ColumnNullable & condition_nullable = assert_cast<const ColumnNullable &>(*instruction.condition);
207 const ColumnUInt8 & condition_nested = assert_cast<const ColumnUInt8 &>(condition_nullable.getNestedColumn());
208 const NullMap & condition_null_map = condition_nullable.getNullMapData();
209
210 insert = !condition_null_map[i] && condition_nested.getData()[i];
211 }
212
213 if (insert)
214 {
215 if (!instruction.source_is_constant)
216 res->insertFrom(*instruction.source, i);
217 else
218 res->insertFrom(assert_cast<const ColumnConst &>(*instruction.source).getDataColumn(), 0);
219
220 break;
221 }
222 }
223 }
224
225 block.getByPosition(result).column = std::move(res);
226 }
227
228private:
229 const Context & context;
230};
231
232void registerFunctionMultiIf(FunctionFactory & factory)
233{
234 factory.registerFunction<FunctionMultiIf>();
235
236 /// These are obsolete function names.
237 factory.registerFunction<FunctionMultiIf>("caseWithoutExpr");
238 factory.registerFunction<FunctionMultiIf>("caseWithoutExpression");
239}
240
241}
242
243
244
245