| 1 | #include <Functions/FunctionFactory.h> |
| 2 | #include <Functions/FunctionIfBase.h> |
| 3 | #include <Columns/ColumnNullable.h> |
| 4 | #include <Columns/ColumnConst.h> |
| 5 | #include <Columns/ColumnsNumber.h> |
| 6 | #include <Interpreters/castColumn.h> |
| 7 | #include <Common/typeid_cast.h> |
| 8 | #include <Common/assert_cast.h> |
| 9 | #include <DataTypes/DataTypeNullable.h> |
| 10 | #include <DataTypes/getLeastSupertype.h> |
| 11 | |
| 12 | |
| 13 | namespace DB |
| 14 | { |
| 15 | |
/// Error codes thrown from this translation unit.
/// Each code used below is declared here explicitly, so the file does not
/// depend on another header happening to declare it.
namespace ErrorCodes
{
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;          /// A condition argument is not UInt8.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;  /// Argument count is even or less than 3.
}
| 20 | |
| 21 | /// Function multiIf, which generalizes the function if. |
| 22 | /// |
| 23 | /// Syntax: multiIf(cond_1, then_1, ..., cond_N, then_N, else) |
| 24 | /// where N >= 1. |
| 25 | /// |
| 26 | /// For all 1 <= i <= N, "cond_i" has type UInt8. |
| 27 | /// Types of all the branches "then_i" and "else" have a common type. |
| 28 | /// |
| 29 | /// Additionally the arguments, conditions or branches, support nullable types |
| 30 | /// and the NULL value, with a NULL condition treated as false. |
| 31 | class FunctionMultiIf final : public FunctionIfBase</*null_is_false=*/true> |
| 32 | { |
| 33 | public: |
| 34 | static constexpr auto name = "multiIf" ; |
| 35 | static FunctionPtr create(const Context & context) { return std::make_shared<FunctionMultiIf>(context); } |
| 36 | FunctionMultiIf(const Context & context_) : context(context_) {} |
| 37 | |
| 38 | public: |
| 39 | String getName() const override { return name; } |
| 40 | bool isVariadic() const override { return true; } |
| 41 | size_t getNumberOfArguments() const override { return 0; } |
| 42 | bool useDefaultImplementationForNulls() const override { return false; } |
| 43 | ColumnNumbers getArgumentsThatDontImplyNullableReturnType(size_t number_of_arguments) const override |
| 44 | { |
| 45 | ColumnNumbers args; |
| 46 | for (size_t i = 0; i + 1 < number_of_arguments; i += 2) |
| 47 | args.push_back(i); |
| 48 | return args; |
| 49 | } |
| 50 | |
| 51 | DataTypePtr getReturnTypeImpl(const DataTypes & args) const override |
| 52 | { |
| 53 | /// Arguments are the following: cond1, then1, cond2, then2, ... condN, thenN, else. |
| 54 | |
| 55 | auto for_conditions = [&args](auto && f) |
| 56 | { |
| 57 | size_t conditions_end = args.size() - 1; |
| 58 | for (size_t i = 0; i < conditions_end; i += 2) |
| 59 | f(args[i]); |
| 60 | }; |
| 61 | |
| 62 | auto for_branches = [&args](auto && f) |
| 63 | { |
| 64 | size_t branches_end = args.size(); |
| 65 | for (size_t i = 1; i < branches_end; i += 2) |
| 66 | f(args[i]); |
| 67 | f(args.back()); |
| 68 | }; |
| 69 | |
| 70 | if (!(args.size() >= 3 && args.size() % 2 == 1)) |
| 71 | throw Exception{"Invalid number of arguments for function " + getName(), |
| 72 | ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH}; |
| 73 | |
| 74 | |
| 75 | for_conditions([&](const DataTypePtr & arg) |
| 76 | { |
| 77 | const IDataType * nested_type; |
| 78 | if (arg->isNullable()) |
| 79 | { |
| 80 | if (arg->onlyNull()) |
| 81 | return; |
| 82 | |
| 83 | const DataTypeNullable & nullable_type = static_cast<const DataTypeNullable &>(*arg); |
| 84 | nested_type = nullable_type.getNestedType().get(); |
| 85 | } |
| 86 | else |
| 87 | { |
| 88 | nested_type = arg.get(); |
| 89 | } |
| 90 | |
| 91 | if (!WhichDataType(nested_type).isUInt8()) |
| 92 | throw Exception{"Illegal type " + arg->getName() + " of argument (condition) " |
| 93 | "of function " + getName() + ". Must be UInt8." , |
| 94 | ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
| 95 | }); |
| 96 | |
| 97 | DataTypes types_of_branches; |
| 98 | types_of_branches.reserve(args.size() / 2 + 1); |
| 99 | |
| 100 | for_branches([&](const DataTypePtr & arg) |
| 101 | { |
| 102 | types_of_branches.emplace_back(arg); |
| 103 | }); |
| 104 | |
| 105 | return getLeastSupertype(types_of_branches); |
| 106 | } |
| 107 | |
| 108 | void executeImpl(Block & block, const ColumnNumbers & args, size_t result, size_t input_rows_count) override |
| 109 | { |
| 110 | /** We will gather values from columns in branches to result column, |
| 111 | * depending on values of conditions. |
| 112 | */ |
| 113 | struct Instruction |
| 114 | { |
| 115 | const IColumn * condition = nullptr; |
| 116 | const IColumn * source = nullptr; |
| 117 | |
| 118 | bool condition_always_true = false; |
| 119 | bool condition_is_nullable = false; |
| 120 | bool source_is_constant = false; |
| 121 | }; |
| 122 | |
| 123 | std::vector<Instruction> instructions; |
| 124 | instructions.reserve(args.size() / 2 + 1); |
| 125 | |
| 126 | Columns converted_columns_holder; |
| 127 | converted_columns_holder.reserve(instructions.size()); |
| 128 | |
| 129 | const DataTypePtr & return_type = block.getByPosition(result).type; |
| 130 | |
| 131 | for (size_t i = 0; i < args.size(); i += 2) |
| 132 | { |
| 133 | Instruction instruction; |
| 134 | size_t source_idx = i + 1; |
| 135 | |
| 136 | if (source_idx == args.size()) |
| 137 | { |
| 138 | /// The last, "else" branch can be treated as a branch with always true condition "else if (true)". |
| 139 | --source_idx; |
| 140 | instruction.condition_always_true = true; |
| 141 | } |
| 142 | else |
| 143 | { |
| 144 | const ColumnWithTypeAndName & cond_col = block.getByPosition(args[i]); |
| 145 | |
| 146 | /// We skip branches that are always false. |
| 147 | /// If we encounter a branch that is always true, we can finish. |
| 148 | |
| 149 | if (cond_col.column->onlyNull()) |
| 150 | continue; |
| 151 | |
| 152 | if (isColumnConst(*cond_col.column)) |
| 153 | { |
| 154 | Field value = typeid_cast<const ColumnConst &>(*cond_col.column).getField(); |
| 155 | if (value.isNull()) |
| 156 | continue; |
| 157 | if (value.get<UInt64>() == 0) |
| 158 | continue; |
| 159 | instruction.condition_always_true = true; |
| 160 | } |
| 161 | else |
| 162 | { |
| 163 | if (isColumnNullable(*cond_col.column)) |
| 164 | instruction.condition_is_nullable = true; |
| 165 | |
| 166 | instruction.condition = cond_col.column.get(); |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | const ColumnWithTypeAndName & source_col = block.getByPosition(args[source_idx]); |
| 171 | if (source_col.type->equals(*return_type)) |
| 172 | { |
| 173 | instruction.source = source_col.column.get(); |
| 174 | } |
| 175 | else |
| 176 | { |
| 177 | /// Cast all columns to result type. |
| 178 | converted_columns_holder.emplace_back(castColumn(source_col, return_type, context)); |
| 179 | instruction.source = converted_columns_holder.back().get(); |
| 180 | } |
| 181 | |
| 182 | if (instruction.source && isColumnConst(*instruction.source)) |
| 183 | instruction.source_is_constant = true; |
| 184 | |
| 185 | instructions.emplace_back(std::move(instruction)); |
| 186 | |
| 187 | if (instructions.back().condition_always_true) |
| 188 | break; |
| 189 | } |
| 190 | |
| 191 | size_t rows = input_rows_count; |
| 192 | MutableColumnPtr res = return_type->createColumn(); |
| 193 | |
| 194 | for (size_t i = 0; i < rows; ++i) |
| 195 | { |
| 196 | for (const auto & instruction : instructions) |
| 197 | { |
| 198 | bool insert = false; |
| 199 | |
| 200 | if (instruction.condition_always_true) |
| 201 | insert = true; |
| 202 | else if (!instruction.condition_is_nullable) |
| 203 | insert = assert_cast<const ColumnUInt8 &>(*instruction.condition).getData()[i]; |
| 204 | else |
| 205 | { |
| 206 | const ColumnNullable & condition_nullable = assert_cast<const ColumnNullable &>(*instruction.condition); |
| 207 | const ColumnUInt8 & condition_nested = assert_cast<const ColumnUInt8 &>(condition_nullable.getNestedColumn()); |
| 208 | const NullMap & condition_null_map = condition_nullable.getNullMapData(); |
| 209 | |
| 210 | insert = !condition_null_map[i] && condition_nested.getData()[i]; |
| 211 | } |
| 212 | |
| 213 | if (insert) |
| 214 | { |
| 215 | if (!instruction.source_is_constant) |
| 216 | res->insertFrom(*instruction.source, i); |
| 217 | else |
| 218 | res->insertFrom(assert_cast<const ColumnConst &>(*instruction.source).getDataColumn(), 0); |
| 219 | |
| 220 | break; |
| 221 | } |
| 222 | } |
| 223 | } |
| 224 | |
| 225 | block.getByPosition(result).column = std::move(res); |
| 226 | } |
| 227 | |
| 228 | private: |
| 229 | const Context & context; |
| 230 | }; |
| 231 | |
| 232 | void registerFunctionMultiIf(FunctionFactory & factory) |
| 233 | { |
| 234 | factory.registerFunction<FunctionMultiIf>(); |
| 235 | |
| 236 | /// These are obsolete function names. |
| 237 | factory.registerFunction<FunctionMultiIf>("caseWithoutExpr" ); |
| 238 | factory.registerFunction<FunctionMultiIf>("caseWithoutExpression" ); |
| 239 | } |
| 240 | |
| 241 | } |
| 242 | |
| 243 | |
| 244 | |
| 245 | |