1#pragma once
2
3#include <common/arithmeticOverflow.h>
4#include <Core/Block.h>
5#include <Core/AccurateComparison.h>
6#include <Core/callOnTypeIndex.h>
7#include <DataTypes/DataTypesNumber.h>
8#include <DataTypes/DataTypesDecimal.h>
9#include <Columns/ColumnVector.h>
10#include <Columns/ColumnsNumber.h>
11#include <Columns/ColumnConst.h>
12#include <Functions/FunctionHelpers.h> /// TODO Core should not depend on Functions
13
14
15namespace DB
16{
17
/// Error codes thrown from this header. LOGICAL_ERROR is declared here too because
/// DecimalComparison throws it; do not rely on another header declaring it transitively.
namespace ErrorCodes
{
    extern const int DECIMAL_OVERFLOW;
    extern const int LOGICAL_ERROR;
}
22
23///
24inline bool allowDecimalComparison(const DataTypePtr & left_type, const DataTypePtr & right_type)
25{
26 if (isColumnedAsDecimal(left_type))
27 {
28 if (isColumnedAsDecimal(right_type) || isNotDecimalButComparableToDecimal(right_type))
29 return true;
30 }
31 else if (isNotDecimalButComparableToDecimal(left_type) && isColumnedAsDecimal(right_type))
32 return true;
33 return false;
34}
35
36template <size_t > struct ConstructDecInt { using Type = Int32; };
37template <> struct ConstructDecInt<8> { using Type = Int64; };
38template <> struct ConstructDecInt<16> { using Type = Int128; };
39
40template <typename T, typename U>
41struct DecCompareInt
42{
43 using Type = typename ConstructDecInt<(!IsDecimalNumber<U> || sizeof(T) > sizeof(U)) ? sizeof(T) : sizeof(U)>::Type;
44 using TypeA = Type;
45 using TypeB = Type;
46};
47
48///
49template <typename A, typename B, template <typename, typename> typename Operation, bool _check_overflow = true,
50 bool _actual = IsDecimalNumber<A> || IsDecimalNumber<B>>
51class DecimalComparison
52{
53public:
54 using CompareInt = typename DecCompareInt<A, B>::Type;
55 using Op = Operation<CompareInt, CompareInt>;
56 using ColVecA = std::conditional_t<IsDecimalNumber<A>, ColumnDecimal<A>, ColumnVector<A>>;
57 using ColVecB = std::conditional_t<IsDecimalNumber<B>, ColumnDecimal<B>, ColumnVector<B>>;
58 using ArrayA = typename ColVecA::Container;
59 using ArrayB = typename ColVecB::Container;
60
61 DecimalComparison(Block & block, size_t result, const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right)
62 {
63 if (!apply(block, result, col_left, col_right))
64 throw Exception("Wrong decimal comparison with " + col_left.type->getName() + " and " + col_right.type->getName(),
65 ErrorCodes::LOGICAL_ERROR);
66 }
67
68 static bool apply(Block & block, size_t result [[maybe_unused]],
69 const ColumnWithTypeAndName & col_left, const ColumnWithTypeAndName & col_right)
70 {
71 if constexpr (_actual)
72 {
73 ColumnPtr c_res;
74 Shift shift = getScales<A, B>(col_left.type, col_right.type);
75
76 c_res = applyWithScale(col_left.column, col_right.column, shift);
77 if (c_res)
78 block.getByPosition(result).column = std::move(c_res);
79 return true;
80 }
81 return false;
82 }
83
84 static bool compare(A a, B b, UInt32 scale_a, UInt32 scale_b)
85 {
86 static const UInt32 max_scale = DecimalUtils::maxPrecision<Decimal128>();
87 if (scale_a > max_scale || scale_b > max_scale)
88 throw Exception("Bad scale of decimal field", ErrorCodes::DECIMAL_OVERFLOW);
89
90 Shift shift;
91 if (scale_a < scale_b)
92 shift.a = B::getScaleMultiplier(scale_b - scale_a);
93 if (scale_a > scale_b)
94 shift.b = A::getScaleMultiplier(scale_a - scale_b);
95
96 return applyWithScale(a, b, shift);
97 }
98
99private:
100 struct Shift
101 {
102 CompareInt a = 1;
103 CompareInt b = 1;
104
105 bool none() const { return a == 1 && b == 1; }
106 bool left() const { return a != 1; }
107 bool right() const { return b != 1; }
108 };
109
110 template <typename T, typename U>
111 static auto applyWithScale(T a, U b, const Shift & shift)
112 {
113 if (shift.left())
114 return apply<true, false>(a, b, shift.a);
115 else if (shift.right())
116 return apply<false, true>(a, b, shift.b);
117 return apply<false, false>(a, b, 1);
118 }
119
120 template <typename T, typename U>
121 static std::enable_if_t<IsDecimalNumber<T> && IsDecimalNumber<U>, Shift>
122 getScales(const DataTypePtr & left_type, const DataTypePtr & right_type)
123 {
124 const DataTypeDecimal<T> * decimal0 = checkDecimal<T>(*left_type);
125 const DataTypeDecimal<U> * decimal1 = checkDecimal<U>(*right_type);
126
127 Shift shift;
128 if (decimal0 && decimal1)
129 {
130 auto result_type = decimalResultType(*decimal0, *decimal1, false, false);
131 shift.a = result_type.scaleFactorFor(*decimal0, false);
132 shift.b = result_type.scaleFactorFor(*decimal1, false);
133 }
134 else if (decimal0)
135 shift.b = decimal0->getScaleMultiplier();
136 else if (decimal1)
137 shift.a = decimal1->getScaleMultiplier();
138
139 return shift;
140 }
141
142 template <typename T, typename U>
143 static std::enable_if_t<IsDecimalNumber<T> && !IsDecimalNumber<U>, Shift>
144 getScales(const DataTypePtr & left_type, const DataTypePtr &)
145 {
146 Shift shift;
147 const DataTypeDecimal<T> * decimal0 = checkDecimal<T>(*left_type);
148 if (decimal0)
149 shift.b = decimal0->getScaleMultiplier();
150 return shift;
151 }
152
153 template <typename T, typename U>
154 static std::enable_if_t<!IsDecimalNumber<T> && IsDecimalNumber<U>, Shift>
155 getScales(const DataTypePtr &, const DataTypePtr & right_type)
156 {
157 Shift shift;
158 const DataTypeDecimal<U> * decimal1 = checkDecimal<U>(*right_type);
159 if (decimal1)
160 shift.a = decimal1->getScaleMultiplier();
161 return shift;
162 }
163
164 template <bool scale_left, bool scale_right>
165 static ColumnPtr apply(const ColumnPtr & c0, const ColumnPtr & c1, CompareInt scale)
166 {
167 auto c_res = ColumnUInt8::create();
168
169 if constexpr (_actual)
170 {
171 bool c0_is_const = isColumnConst(*c0);
172 bool c1_is_const = isColumnConst(*c1);
173
174 if (c0_is_const && c1_is_const)
175 {
176 const ColumnConst * c0_const = checkAndGetColumnConst<ColVecA>(c0.get());
177 const ColumnConst * c1_const = checkAndGetColumnConst<ColVecB>(c1.get());
178
179 A a = c0_const->template getValue<A>();
180 B b = c1_const->template getValue<B>();
181 UInt8 res = apply<scale_left, scale_right>(a, b, scale);
182 return DataTypeUInt8().createColumnConst(c0->size(), toField(res));
183 }
184
185 ColumnUInt8::Container & vec_res = c_res->getData();
186 vec_res.resize(c0->size());
187
188 if (c0_is_const)
189 {
190 const ColumnConst * c0_const = checkAndGetColumnConst<ColVecA>(c0.get());
191 A a = c0_const->template getValue<A>();
192 if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
193 constant_vector<scale_left, scale_right>(a, c1_vec->getData(), vec_res, scale);
194 else
195 throw Exception("Wrong column in Decimal comparison", ErrorCodes::LOGICAL_ERROR);
196 }
197 else if (c1_is_const)
198 {
199 const ColumnConst * c1_const = checkAndGetColumnConst<ColVecB>(c1.get());
200 B b = c1_const->template getValue<B>();
201 if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
202 vector_constant<scale_left, scale_right>(c0_vec->getData(), b, vec_res, scale);
203 else
204 throw Exception("Wrong column in Decimal comparison", ErrorCodes::LOGICAL_ERROR);
205 }
206 else
207 {
208 if (const ColVecA * c0_vec = checkAndGetColumn<ColVecA>(c0.get()))
209 {
210 if (const ColVecB * c1_vec = checkAndGetColumn<ColVecB>(c1.get()))
211 vector_vector<scale_left, scale_right>(c0_vec->getData(), c1_vec->getData(), vec_res, scale);
212 else
213 throw Exception("Wrong column in Decimal comparison", ErrorCodes::LOGICAL_ERROR);
214 }
215 else
216 throw Exception("Wrong column in Decimal comparison", ErrorCodes::LOGICAL_ERROR);
217 }
218 }
219
220 return c_res;
221 }
222
223 template <bool scale_left, bool scale_right>
224 static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
225 {
226 CompareInt x = a;
227 CompareInt y = b;
228
229 if constexpr (_check_overflow)
230 {
231 bool overflow = false;
232
233 if constexpr (sizeof(A) > sizeof(CompareInt))
234 overflow |= (A(x) != a);
235 if constexpr (sizeof(B) > sizeof(CompareInt))
236 overflow |= (B(y) != b);
237 if constexpr (is_unsigned_v<A>)
238 overflow |= (x < 0);
239 if constexpr (is_unsigned_v<B>)
240 overflow |= (y < 0);
241
242 if constexpr (scale_left)
243 overflow |= common::mulOverflow(x, scale, x);
244 if constexpr (scale_right)
245 overflow |= common::mulOverflow(y, scale, y);
246
247 if (overflow)
248 throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW);
249 }
250 else
251 {
252 if constexpr (scale_left)
253 x *= scale;
254 if constexpr (scale_right)
255 y *= scale;
256 }
257
258 return Op::apply(x, y);
259 }
260
261 template <bool scale_left, bool scale_right>
262 static void NO_INLINE vector_vector(const ArrayA & a, const ArrayB & b, PaddedPODArray<UInt8> & c,
263 CompareInt scale)
264 {
265 size_t size = a.size();
266 const A * a_pos = a.data();
267 const B * b_pos = b.data();
268 UInt8 * c_pos = c.data();
269 const A * a_end = a_pos + size;
270
271 while (a_pos < a_end)
272 {
273 *c_pos = apply<scale_left, scale_right>(*a_pos, *b_pos, scale);
274 ++a_pos;
275 ++b_pos;
276 ++c_pos;
277 }
278 }
279
280 template <bool scale_left, bool scale_right>
281 static void NO_INLINE vector_constant(const ArrayA & a, B b, PaddedPODArray<UInt8> & c, CompareInt scale)
282 {
283 size_t size = a.size();
284 const A * a_pos = a.data();
285 UInt8 * c_pos = c.data();
286 const A * a_end = a_pos + size;
287
288 while (a_pos < a_end)
289 {
290 *c_pos = apply<scale_left, scale_right>(*a_pos, b, scale);
291 ++a_pos;
292 ++c_pos;
293 }
294 }
295
296 template <bool scale_left, bool scale_right>
297 static void NO_INLINE constant_vector(A a, const ArrayB & b, PaddedPODArray<UInt8> & c, CompareInt scale)
298 {
299 size_t size = b.size();
300 const B * b_pos = b.data();
301 UInt8 * c_pos = c.data();
302 const B * b_end = b_pos + size;
303
304 while (b_pos < b_end)
305 {
306 *c_pos = apply<scale_left, scale_right>(a, *b_pos, scale);
307 ++b_pos;
308 ++c_pos;
309 }
310 }
311};
312
313}
314