1 | #pragma once |
2 | |
3 | #include <DataTypes/DataTypesNumber.h> |
4 | #include <Columns/ColumnVector.h> |
5 | #include <Functions/IFunctionImpl.h> |
6 | #include <Functions/FunctionHelpers.h> |
7 | #include <IO/WriteHelpers.h> |
8 | #include <ext/range.h> |
9 | |
10 | |
11 | namespace DB |
12 | { |
13 | |
14 | namespace ErrorCodes |
15 | { |
16 | extern const int ILLEGAL_DIVISION; |
17 | extern const int ILLEGAL_COLUMN; |
18 | extern const int ILLEGAL_TYPE_OF_ARGUMENT; |
19 | extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; |
20 | } |
21 | |
22 | |
23 | template <typename Impl, typename Name> |
24 | struct FunctionBitTestMany : public IFunction |
25 | { |
26 | public: |
27 | static constexpr auto name = Name::name; |
28 | static FunctionPtr create(const Context &) { return std::make_shared<FunctionBitTestMany>(); } |
29 | |
30 | String getName() const override { return name; } |
31 | |
32 | bool isVariadic() const override { return true; } |
33 | size_t getNumberOfArguments() const override { return 0; } |
34 | |
35 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
36 | { |
37 | if (arguments.size() < 2) |
38 | throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed " |
39 | + toString(arguments.size()) + ", should be at least 2." , ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION}; |
40 | |
41 | const auto & first_arg = arguments.front(); |
42 | |
43 | if (!isInteger(first_arg)) |
44 | throw Exception{"Illegal type " + first_arg->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
45 | |
46 | |
47 | for (const auto i : ext::range(1, arguments.size())) |
48 | { |
49 | const auto & pos_arg = arguments[i]; |
50 | |
51 | if (!isUnsignedInteger(pos_arg)) |
52 | throw Exception{"Illegal type " + pos_arg->getName() + " of " + toString(i) + " argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; |
53 | } |
54 | |
55 | return std::make_shared<DataTypeUInt8>(); |
56 | } |
57 | |
58 | void executeImpl(Block & block , const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override |
59 | { |
60 | const auto value_col = block.getByPosition(arguments.front()).column.get(); |
61 | |
62 | if (!execute<UInt8>(block, arguments, result, value_col) |
63 | && !execute<UInt16>(block, arguments, result, value_col) |
64 | && !execute<UInt32>(block, arguments, result, value_col) |
65 | && !execute<UInt64>(block, arguments, result, value_col) |
66 | && !execute<Int8>(block, arguments, result, value_col) |
67 | && !execute<Int16>(block, arguments, result, value_col) |
68 | && !execute<Int32>(block, arguments, result, value_col) |
69 | && !execute<Int64>(block, arguments, result, value_col)) |
70 | throw Exception{"Illegal column " + value_col->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN}; |
71 | } |
72 | |
73 | private: |
74 | template <typename T> |
75 | bool execute( |
76 | Block & block, const ColumnNumbers & arguments, const size_t result, |
77 | const IColumn * const value_col_untyped) |
78 | { |
79 | if (const auto value_col = checkAndGetColumn<ColumnVector<T>>(value_col_untyped)) |
80 | { |
81 | const auto size = value_col->size(); |
82 | bool is_const; |
83 | const auto const_mask = createConstMaskIfConst<T>(block, arguments, is_const); |
84 | const auto & val = value_col->getData(); |
85 | |
86 | auto out_col = ColumnVector<UInt8>::create(size); |
87 | auto & out = out_col->getData(); |
88 | |
89 | if (is_const) |
90 | { |
91 | for (const auto i : ext::range(0, size)) |
92 | out[i] = Impl::apply(val[i], const_mask); |
93 | } |
94 | else |
95 | { |
96 | const auto mask = createMask<T>(size, block, arguments); |
97 | |
98 | for (const auto i : ext::range(0, size)) |
99 | out[i] = Impl::apply(val[i], mask[i]); |
100 | } |
101 | |
102 | block.getByPosition(result).column = std::move(out_col); |
103 | return true; |
104 | } |
105 | else if (const auto value_col_const = checkAndGetColumnConst<ColumnVector<T>>(value_col_untyped)) |
106 | { |
107 | const auto size = value_col_const->size(); |
108 | bool is_const; |
109 | const auto const_mask = createConstMaskIfConst<T>(block, arguments, is_const); |
110 | const auto val = value_col_const->template getValue<T>(); |
111 | |
112 | if (is_const) |
113 | { |
114 | block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(size, toField(Impl::apply(val, const_mask))); |
115 | } |
116 | else |
117 | { |
118 | const auto mask = createMask<T>(size, block, arguments); |
119 | auto out_col = ColumnVector<UInt8>::create(size); |
120 | |
121 | auto & out = out_col->getData(); |
122 | |
123 | for (const auto i : ext::range(0, size)) |
124 | out[i] = Impl::apply(val, mask[i]); |
125 | |
126 | block.getByPosition(result).column = std::move(out_col); |
127 | } |
128 | |
129 | return true; |
130 | } |
131 | |
132 | return false; |
133 | } |
134 | |
135 | template <typename ValueType> |
136 | ValueType createConstMaskIfConst(const Block & block, const ColumnNumbers & arguments, bool & out_is_const) |
137 | { |
138 | out_is_const = true; |
139 | ValueType mask = 0; |
140 | |
141 | for (const auto i : ext::range(1, arguments.size())) |
142 | { |
143 | if (auto pos_col_const = checkAndGetColumnConst<ColumnVector<ValueType>>(block.getByPosition(arguments[i]).column.get())) |
144 | { |
145 | const auto pos = pos_col_const->template getValue<ValueType>(); |
146 | mask = mask | (1 << pos); |
147 | } |
148 | else |
149 | { |
150 | out_is_const = false; |
151 | return {}; |
152 | } |
153 | } |
154 | |
155 | return mask; |
156 | } |
157 | |
158 | template <typename ValueType> |
159 | PaddedPODArray<ValueType> createMask(const size_t size, const Block & block, const ColumnNumbers & arguments) |
160 | { |
161 | PaddedPODArray<ValueType> mask(size, ValueType{}); |
162 | |
163 | for (const auto i : ext::range(1, arguments.size())) |
164 | { |
165 | const auto pos_col = block.getByPosition(arguments[i]).column.get(); |
166 | |
167 | if (!addToMaskImpl<UInt8>(mask, pos_col) |
168 | && !addToMaskImpl<UInt16>(mask, pos_col) |
169 | && !addToMaskImpl<UInt32>(mask, pos_col) |
170 | && !addToMaskImpl<UInt64>(mask, pos_col)) |
171 | throw Exception{"Illegal column " + pos_col->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN}; |
172 | } |
173 | |
174 | return mask; |
175 | } |
176 | |
177 | template <typename PosType, typename ValueType> |
178 | bool addToMaskImpl(PaddedPODArray<ValueType> & mask, const IColumn * const pos_col_untyped) |
179 | { |
180 | if (const auto pos_col = checkAndGetColumn<ColumnVector<PosType>>(pos_col_untyped)) |
181 | { |
182 | const auto & pos = pos_col->getData(); |
183 | |
184 | for (const auto i : ext::range(0, mask.size())) |
185 | mask[i] = mask[i] | (1 << pos[i]); |
186 | |
187 | return true; |
188 | } |
189 | else if (const auto pos_col_const = checkAndGetColumnConst<ColumnVector<PosType>>(pos_col_untyped)) |
190 | { |
191 | const auto & pos = pos_col_const->template getValue<PosType>(); |
192 | const auto new_mask = 1 << pos; |
193 | |
194 | for (const auto i : ext::range(0, mask.size())) |
195 | mask[i] = mask[i] | new_mask; |
196 | |
197 | return true; |
198 | } |
199 | |
200 | return false; |
201 | } |
202 | }; |
203 | |
204 | } |
205 | |