1 | #include <Functions/FunctionFactory.h> |
2 | #include <Functions/FunctionIfBase.h> |
3 | #include <Columns/ColumnNullable.h> |
4 | #include <Columns/ColumnConst.h> |
5 | #include <Columns/ColumnsNumber.h> |
6 | #include <Interpreters/castColumn.h> |
7 | #include <Common/typeid_cast.h> |
8 | #include <Common/assert_cast.h> |
9 | #include <DataTypes/DataTypeNullable.h> |
10 | #include <DataTypes/getLeastSupertype.h> |
11 | |
12 | |
13 | namespace DB |
14 | { |
15 | |
namespace ErrorCodes
{
    /// Error codes thrown by FunctionMultiIf::getReturnTypeImpl.
    /// ILLEGAL_TYPE_OF_ARGUMENT is used for non-UInt8 conditions and must be declared here
    /// rather than relying on some transitively included header to declare it.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
20 | |
21 | /// Function multiIf, which generalizes the function if. |
22 | /// |
23 | /// Syntax: multiIf(cond_1, then_1, ..., cond_N, then_N, else) |
24 | /// where N >= 1. |
25 | /// |
26 | /// For all 1 <= i <= N, "cond_i" has type UInt8. |
27 | /// Types of all the branches "then_i" and "else" have a common type. |
28 | /// |
29 | /// Additionally the arguments, conditions or branches, support nullable types |
30 | /// and the NULL value, with a NULL condition treated as false. |
31 | class FunctionMultiIf final : public FunctionIfBase</*null_is_false=*/true> |
32 | { |
33 | public: |
34 | static constexpr auto name = "multiIf" ; |
35 | static FunctionPtr create(const Context & context) { return std::make_shared<FunctionMultiIf>(context); } |
36 | FunctionMultiIf(const Context & context_) : context(context_) {} |
37 | |
38 | public: |
39 | String getName() const override { return name; } |
40 | bool isVariadic() const override { return true; } |
41 | size_t getNumberOfArguments() const override { return 0; } |
42 | bool useDefaultImplementationForNulls() const override { return false; } |
43 | ColumnNumbers getArgumentsThatDontImplyNullableReturnType(size_t number_of_arguments) const override |
44 | { |
45 | ColumnNumbers args; |
46 | for (size_t i = 0; i + 1 < number_of_arguments; i += 2) |
47 | args.push_back(i); |
48 | return args; |
49 | } |
50 | |
51 | DataTypePtr getReturnTypeImpl(const DataTypes & args) const override |
52 | { |
53 | /// Arguments are the following: cond1, then1, cond2, then2, ... condN, thenN, else. |
54 | |
55 | auto for_conditions = [&args](auto && f) |
56 | { |
57 | size_t conditions_end = args.size() - 1; |
58 | for (size_t i = 0; i < conditions_end; i += 2) |
59 | f(args[i]); |
60 | }; |
61 | |
62 | auto for_branches = [&args](auto && f) |
63 | { |
64 | size_t branches_end = args.size(); |
65 | for (size_t i = 1; i < branches_end; i += 2) |
66 | f(args[i]); |
67 | f(args.back()); |
68 | }; |
69 | |
70 | if (!(args.size() >= 3 && args.size() % 2 == 1)) |
71 | throw Exception{"Invalid number of arguments for function " + getName(), |
72 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; |
73 | |
74 | |
75 | for_conditions([&](const DataTypePtr & arg) |
76 | { |
77 | const IDataType * nested_type; |
78 | if (arg->isNullable()) |
79 | { |
80 | if (arg->onlyNull()) |
81 | return; |
82 | |
83 | const DataTypeNullable & nullable_type = static_cast<const DataTypeNullable &>(*arg); |
84 | nested_type = nullable_type.getNestedType().get(); |
85 | } |
86 | else |
87 | { |
88 | nested_type = arg.get(); |
89 | } |
90 | |
91 | if (!WhichDataType(nested_type).isUInt8()) |
92 | throw Exception{"Illegal type " + arg->getName() + " of argument (condition) " |
93 | "of function " + getName() + ". Must be UInt8." , |
94 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
95 | }); |
96 | |
97 | DataTypes types_of_branches; |
98 | types_of_branches.reserve(args.size() / 2 + 1); |
99 | |
100 | for_branches([&](const DataTypePtr & arg) |
101 | { |
102 | types_of_branches.emplace_back(arg); |
103 | }); |
104 | |
105 | return getLeastSupertype(types_of_branches); |
106 | } |
107 | |
108 | void executeImpl(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) override |
109 | { |
110 | /** We will gather values from columns in branches to result column, |
111 | * depending on values of conditions. |
112 | */ |
113 | struct Instruction |
114 | { |
115 | const IColumn * condition = nullptr; |
116 | const IColumn * source = nullptr; |
117 | |
118 | bool condition_always_true = false; |
119 | bool condition_is_nullable = false; |
120 | bool source_is_constant = false; |
121 | }; |
122 | |
123 | std::vector<Instruction> instructions; |
124 | instructions.reserve(args.size() / 2 + 1); |
125 | |
126 | Columns converted_columns_holder; |
127 | converted_columns_holder.reserve(instructions.size()); |
128 | |
129 | const DataTypePtr & return_type = block.getByPosition(result).type; |
130 | |
131 | for (size_t i = 0; i < args.size(); i += 2) |
132 | { |
133 | Instruction instruction; |
134 | size_t source_idx = i + 1; |
135 | |
136 | if (source_idx == args.size()) |
137 | { |
138 | /// The last, "else" branch can be treated as a branch with always true condition "else if (true)". |
139 | --source_idx; |
140 | instruction.condition_always_true = true; |
141 | } |
142 | else |
143 | { |
144 | const ColumnWithTypeAndName & cond_col = block.getByPosition(args[i]); |
145 | |
146 | /// We skip branches that are always false. |
147 | /// If we encounter a branch that is always true, we can finish. |
148 | |
149 | if (cond_col.column->onlyNull()) |
150 | continue; |
151 | |
152 | if (isColumnConst(*cond_col.column)) |
153 | { |
154 | Field value = typeid_cast<const ColumnConst &>(*cond_col.column).getField(); |
155 | if (value.isNull()) |
156 | continue; |
157 | if (value.get<UInt64>() == 0) |
158 | continue; |
159 | instruction.condition_always_true = true; |
160 | } |
161 | else |
162 | { |
163 | if (isColumnNullable(*cond_col.column)) |
164 | instruction.condition_is_nullable = true; |
165 | |
166 | instruction.condition = cond_col.column.get(); |
167 | } |
168 | } |
169 | |
170 | const ColumnWithTypeAndName & source_col = block.getByPosition(args[source_idx]); |
171 | if (source_col.type->equals(*return_type)) |
172 | { |
173 | instruction.source = source_col.column.get(); |
174 | } |
175 | else |
176 | { |
177 | /// Cast all columns to result type. |
178 | converted_columns_holder.emplace_back(castColumn(source_col, return_type, context)); |
179 | instruction.source = converted_columns_holder.back().get(); |
180 | } |
181 | |
182 | if (instruction.source && isColumnConst(*instruction.source)) |
183 | instruction.source_is_constant = true; |
184 | |
185 | instructions.emplace_back(std::move(instruction)); |
186 | |
187 | if (instructions.back().condition_always_true) |
188 | break; |
189 | } |
190 | |
191 | size_t rows = input_rows_count; |
192 | MutableColumnPtr res = return_type->createColumn(); |
193 | |
194 | for (size_t i = 0; i < rows; ++i) |
195 | { |
196 | for (const auto & instruction : instructions) |
197 | { |
198 | bool insert = false; |
199 | |
200 | if (instruction.condition_always_true) |
201 | insert = true; |
202 | else if (!instruction.condition_is_nullable) |
203 | insert = assert_cast<const ColumnUInt8 &>(*instruction.condition).getData()[i]; |
204 | else |
205 | { |
206 | const ColumnNullable & condition_nullable = assert_cast<const ColumnNullable &>(*instruction.condition); |
207 | const ColumnUInt8 & condition_nested = assert_cast<const ColumnUInt8 &>(condition_nullable.getNestedColumn()); |
208 | const NullMap & condition_null_map = condition_nullable.getNullMapData(); |
209 | |
210 | insert = !condition_null_map[i] && condition_nested.getData()[i]; |
211 | } |
212 | |
213 | if (insert) |
214 | { |
215 | if (!instruction.source_is_constant) |
216 | res->insertFrom(*instruction.source, i); |
217 | else |
218 | res->insertFrom(assert_cast<const ColumnConst &>(*instruction.source).getDataColumn(), 0); |
219 | |
220 | break; |
221 | } |
222 | } |
223 | } |
224 | |
225 | block.getByPosition(result).column = std::move(res); |
226 | } |
227 | |
228 | private: |
229 | const Context & context; |
230 | }; |
231 | |
232 | void registerFunctionMultiIf(FunctionFactory & factory) |
233 | { |
234 | factory.registerFunction<FunctionMultiIf>(); |
235 | |
236 | /// These are obsolete function names. |
237 | factory.registerFunction<FunctionMultiIf>("caseWithoutExpr" ); |
238 | factory.registerFunction<FunctionMultiIf>("caseWithoutExpression" ); |
239 | } |
240 | |
241 | } |
242 | |
243 | |
244 | |
245 | |