1#pragma once
2
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnVector.h>
#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionHelpers.h>
#include <IO/WriteHelpers.h>
#include <ext/range.h>
#include <type_traits>
9
10
11namespace DB
12{
13
14namespace ErrorCodes
15{
16 extern const int ILLEGAL_DIVISION;
17 extern const int ILLEGAL_COLUMN;
18 extern const int ILLEGAL_TYPE_OF_ARGUMENT;
19 extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
20}
21
22
23template <typename Impl, typename Name>
24struct FunctionBitTestMany : public IFunction
25{
26public:
27 static constexpr auto name = Name::name;
28 static FunctionPtr create(const Context &) { return std::make_shared<FunctionBitTestMany>(); }
29
30 String getName() const override { return name; }
31
32 bool isVariadic() const override { return true; }
33 size_t getNumberOfArguments() const override { return 0; }
34
35 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
36 {
37 if (arguments.size() < 2)
38 throw Exception{"Number of arguments for function " + getName() + " doesn't match: passed "
39 + toString(arguments.size()) + ", should be at least 2.", ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
40
41 const auto & first_arg = arguments.front();
42
43 if (!isInteger(first_arg))
44 throw Exception{"Illegal type " + first_arg->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
45
46
47 for (const auto i : ext::range(1, arguments.size()))
48 {
49 const auto & pos_arg = arguments[i];
50
51 if (!isUnsignedInteger(pos_arg))
52 throw Exception{"Illegal type " + pos_arg->getName() + " of " + toString(i) + " argument of function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
53 }
54
55 return std::make_shared<DataTypeUInt8>();
56 }
57
58 void executeImpl(Block & block , const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
59 {
60 const auto value_col = block.getByPosition(arguments.front()).column.get();
61
62 if (!execute<UInt8>(block, arguments, result, value_col)
63 && !execute<UInt16>(block, arguments, result, value_col)
64 && !execute<UInt32>(block, arguments, result, value_col)
65 && !execute<UInt64>(block, arguments, result, value_col)
66 && !execute<Int8>(block, arguments, result, value_col)
67 && !execute<Int16>(block, arguments, result, value_col)
68 && !execute<Int32>(block, arguments, result, value_col)
69 && !execute<Int64>(block, arguments, result, value_col))
70 throw Exception{"Illegal column " + value_col->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
71 }
72
73private:
74 template <typename T>
75 bool execute(
76 Block & block, const ColumnNumbers & arguments, const size_t result,
77 const IColumn * const value_col_untyped)
78 {
79 if (const auto value_col = checkAndGetColumn<ColumnVector<T>>(value_col_untyped))
80 {
81 const auto size = value_col->size();
82 bool is_const;
83 const auto const_mask = createConstMaskIfConst<T>(block, arguments, is_const);
84 const auto & val = value_col->getData();
85
86 auto out_col = ColumnVector<UInt8>::create(size);
87 auto & out = out_col->getData();
88
89 if (is_const)
90 {
91 for (const auto i : ext::range(0, size))
92 out[i] = Impl::apply(val[i], const_mask);
93 }
94 else
95 {
96 const auto mask = createMask<T>(size, block, arguments);
97
98 for (const auto i : ext::range(0, size))
99 out[i] = Impl::apply(val[i], mask[i]);
100 }
101
102 block.getByPosition(result).column = std::move(out_col);
103 return true;
104 }
105 else if (const auto value_col_const = checkAndGetColumnConst<ColumnVector<T>>(value_col_untyped))
106 {
107 const auto size = value_col_const->size();
108 bool is_const;
109 const auto const_mask = createConstMaskIfConst<T>(block, arguments, is_const);
110 const auto val = value_col_const->template getValue<T>();
111
112 if (is_const)
113 {
114 block.getByPosition(result).column = block.getByPosition(result).type->createColumnConst(size, toField(Impl::apply(val, const_mask)));
115 }
116 else
117 {
118 const auto mask = createMask<T>(size, block, arguments);
119 auto out_col = ColumnVector<UInt8>::create(size);
120
121 auto & out = out_col->getData();
122
123 for (const auto i : ext::range(0, size))
124 out[i] = Impl::apply(val, mask[i]);
125
126 block.getByPosition(result).column = std::move(out_col);
127 }
128
129 return true;
130 }
131
132 return false;
133 }
134
135 template <typename ValueType>
136 ValueType createConstMaskIfConst(const Block & block, const ColumnNumbers & arguments, bool & out_is_const)
137 {
138 out_is_const = true;
139 ValueType mask = 0;
140
141 for (const auto i : ext::range(1, arguments.size()))
142 {
143 if (auto pos_col_const = checkAndGetColumnConst<ColumnVector<ValueType>>(block.getByPosition(arguments[i]).column.get()))
144 {
145 const auto pos = pos_col_const->template getValue<ValueType>();
146 mask = mask | (1 << pos);
147 }
148 else
149 {
150 out_is_const = false;
151 return {};
152 }
153 }
154
155 return mask;
156 }
157
158 template <typename ValueType>
159 PaddedPODArray<ValueType> createMask(const size_t size, const Block & block, const ColumnNumbers & arguments)
160 {
161 PaddedPODArray<ValueType> mask(size, ValueType{});
162
163 for (const auto i : ext::range(1, arguments.size()))
164 {
165 const auto pos_col = block.getByPosition(arguments[i]).column.get();
166
167 if (!addToMaskImpl<UInt8>(mask, pos_col)
168 && !addToMaskImpl<UInt16>(mask, pos_col)
169 && !addToMaskImpl<UInt32>(mask, pos_col)
170 && !addToMaskImpl<UInt64>(mask, pos_col))
171 throw Exception{"Illegal column " + pos_col->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN};
172 }
173
174 return mask;
175 }
176
177 template <typename PosType, typename ValueType>
178 bool addToMaskImpl(PaddedPODArray<ValueType> & mask, const IColumn * const pos_col_untyped)
179 {
180 if (const auto pos_col = checkAndGetColumn<ColumnVector<PosType>>(pos_col_untyped))
181 {
182 const auto & pos = pos_col->getData();
183
184 for (const auto i : ext::range(0, mask.size()))
185 mask[i] = mask[i] | (1 << pos[i]);
186
187 return true;
188 }
189 else if (const auto pos_col_const = checkAndGetColumnConst<ColumnVector<PosType>>(pos_col_untyped))
190 {
191 const auto & pos = pos_col_const->template getValue<PosType>();
192 const auto new_mask = 1 << pos;
193
194 for (const auto i : ext::range(0, mask.size()))
195 mask[i] = mask[i] | new_mask;
196
197 return true;
198 }
199
200 return false;
201 }
202};
203
204}
205