1#pragma once
2
3#include <DataTypes/DataTypesNumber.h>
4#include <DataTypes/DataTypesDecimal.h>
5#include <DataTypes/Native.h>
6#include <Columns/ColumnVector.h>
7#include <Columns/ColumnDecimal.h>
8#include <Functions/IFunctionImpl.h>
9#include <Functions/FunctionHelpers.h>
10#include <Functions/castTypeToEither.h>
11#include <Common/config.h>
12
13#if USE_EMBEDDED_COMPILER
14#pragma GCC diagnostic push
15#pragma GCC diagnostic ignored "-Wunused-parameter"
16#include <llvm/IR/IRBuilder.h>
17#pragma GCC diagnostic pop
18#endif
19
20
21namespace DB
22{
23
/// Error codes thrown by the functions in this header. They are only declared
/// here (extern); their definitions live in another translation unit.
namespace ErrorCodes
{
    /// Thrown by FunctionUnaryArithmetic::getReturnTypeImpl for an unsupported argument type.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Thrown by FunctionUnaryArithmetic::executeImpl when the argument column does not match its declared type.
    extern const int LOGICAL_ERROR;
}
29
30
31template <typename A, typename Op>
32struct UnaryOperationImpl
33{
34 using ResultType = typename Op::ResultType;
35 using ColVecA = std::conditional_t<IsDecimalNumber<A>, ColumnDecimal<A>, ColumnVector<A>>;
36 using ColVecC = std::conditional_t<IsDecimalNumber<ResultType>, ColumnDecimal<ResultType>, ColumnVector<ResultType>>;
37 using ArrayA = typename ColVecA::Container;
38 using ArrayC = typename ColVecC::Container;
39
40 static void NO_INLINE vector(const ArrayA & a, ArrayC & c)
41 {
42 size_t size = a.size();
43 for (size_t i = 0; i < size; ++i)
44 c[i] = Op::apply(a[i]);
45 }
46
47 static void constant(A a, ResultType & c)
48 {
49 c = Op::apply(a);
50 }
51};
52
53
/// Monotonicity trait for a unary arithmetic function, selected by the function's Name tag.
/// Only declared here: FunctionUnaryArithmetic calls its static has() and get(left, right),
/// so a definition/specialization providing those must be visible where the class is instantiated.
template <typename FunctionName>
struct FunctionUnaryArithmeticMonotonicity;

/// Declared so FunctionUnaryArithmetic can detect (via std::is_same_v) whether its Op is
/// abs or negate - in this header, the only operations allowed to take Decimal arguments.
template <typename> struct AbsImpl;
template <typename> struct NegateImpl;

/// Used to indicate undefined operation: an Op whose ResultType is this (never defined)
/// marker type has no valid result for that argument type.
struct InvalidType;
62
63
64template <template <typename> class Op, typename Name, bool is_injective>
65class FunctionUnaryArithmetic : public IFunction
66{
67 static constexpr bool allow_decimal = std::is_same_v<Op<Int8>, NegateImpl<Int8>> || std::is_same_v<Op<Int8>, AbsImpl<Int8>>;
68
69 template <typename F>
70 static bool castType(const IDataType * type, F && f)
71 {
72 return castTypeToEither<
73 DataTypeUInt8,
74 DataTypeUInt16,
75 DataTypeUInt32,
76 DataTypeUInt64,
77 DataTypeInt8,
78 DataTypeInt16,
79 DataTypeInt32,
80 DataTypeInt64,
81 DataTypeFloat32,
82 DataTypeFloat64,
83 DataTypeDecimal<Decimal32>,
84 DataTypeDecimal<Decimal64>,
85 DataTypeDecimal<Decimal128>
86 >(type, std::forward<F>(f));
87 }
88
89public:
90 static constexpr auto name = Name::name;
91 static FunctionPtr create(const Context &) { return std::make_shared<FunctionUnaryArithmetic>(); }
92
93 String getName() const override
94 {
95 return name;
96 }
97
98 size_t getNumberOfArguments() const override { return 1; }
99 bool isInjective(const Block &) override { return is_injective; }
100
101 bool useDefaultImplementationForConstants() const override { return true; }
102
103 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
104 {
105 DataTypePtr result;
106 bool valid = castType(arguments[0].get(), [&](const auto & type)
107 {
108 using DataType = std::decay_t<decltype(type)>;
109 using T0 = typename DataType::FieldType;
110
111 if constexpr (IsDataTypeDecimal<DataType>)
112 {
113 if constexpr (!allow_decimal)
114 return false;
115 result = std::make_shared<DataType>(type.getPrecision(), type.getScale());
116 }
117 else
118 result = std::make_shared<DataTypeNumber<typename Op<T0>::ResultType>>();
119 return true;
120 });
121 if (!valid)
122 throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
123 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
124 return result;
125 }
126
127 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
128 {
129 bool valid = castType(block.getByPosition(arguments[0]).type.get(), [&](const auto & type)
130 {
131 using DataType = std::decay_t<decltype(type)>;
132 using T0 = typename DataType::FieldType;
133
134 if constexpr (IsDataTypeDecimal<DataType>)
135 {
136 if constexpr (allow_decimal)
137 {
138 if (auto col = checkAndGetColumn<ColumnDecimal<T0>>(block.getByPosition(arguments[0]).column.get()))
139 {
140 auto col_res = ColumnDecimal<typename Op<T0>::ResultType>::create(0, type.getScale());
141 auto & vec_res = col_res->getData();
142 vec_res.resize(col->getData().size());
143 UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
144 block.getByPosition(result).column = std::move(col_res);
145 return true;
146 }
147 }
148 }
149 else
150 {
151 if (auto col = checkAndGetColumn<ColumnVector<T0>>(block.getByPosition(arguments[0]).column.get()))
152 {
153 auto col_res = ColumnVector<typename Op<T0>::ResultType>::create();
154 auto & vec_res = col_res->getData();
155 vec_res.resize(col->getData().size());
156 UnaryOperationImpl<T0, Op<T0>>::vector(col->getData(), vec_res);
157 block.getByPosition(result).column = std::move(col_res);
158 return true;
159 }
160 }
161
162 return false;
163 });
164 if (!valid)
165 throw Exception(getName() + "'s argument does not match the expected data type", ErrorCodes::LOGICAL_ERROR);
166 }
167
168#if USE_EMBEDDED_COMPILER
169 bool isCompilableImpl(const DataTypes & arguments) const override
170 {
171 return castType(arguments[0].get(), [&](const auto & type)
172 {
173 using DataType = std::decay_t<decltype(type)>;
174 return !IsDataTypeDecimal<DataType> && Op<typename DataType::FieldType>::compilable;
175 });
176 }
177
178 llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
179 {
180 llvm::Value * result = nullptr;
181 castType(types[0].get(), [&](const auto & type)
182 {
183 using DataType = std::decay_t<decltype(type)>;
184 using T0 = typename DataType::FieldType;
185 using T1 = typename Op<T0>::ResultType;
186 if constexpr (!std::is_same_v<T1, InvalidType> && !IsDataTypeDecimal<DataType> && Op<T0>::compilable)
187 {
188 auto & b = static_cast<llvm::IRBuilder<> &>(builder);
189 auto * v = nativeCast(b, types[0], values[0](), std::make_shared<DataTypeNumber<T1>>());
190 result = Op<T0>::compile(b, v, is_signed_v<T1>);
191 return true;
192 }
193 return false;
194 });
195 return result;
196 }
197#endif
198
199 bool hasInformationAboutMonotonicity() const override
200 {
201 return FunctionUnaryArithmeticMonotonicity<Name>::has();
202 }
203
204 Monotonicity getMonotonicityForRange(const IDataType &, const Field & left, const Field & right) const override
205 {
206 return FunctionUnaryArithmeticMonotonicity<Name>::get(left, right);
207 }
208};
209
210
211struct PositiveMonotonicity
212{
213 static bool has() { return true; }
214 static IFunction::Monotonicity get(const Field &, const Field &)
215 {
216 return { true };
217 }
218};
219
220}
221