#include <Functions/IFunctionImpl.h>
#include <Functions/FunctionHelpers.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnNullable.h>
#include <Common/assert_cast.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/NumberTraits.h>
#include <DataTypes/DataTypeNullable.h>

#include <type_traits>
12 | |
13 | |
14 | namespace DB |
15 | { |
16 | |
/// Error codes used by this file; they are only declared here (extern) and defined elsewhere.
namespace ErrorCodes { extern const int ILLEGAL_TYPE_OF_ARGUMENT; }
21 | |
22 | |
/// SQL-visible name of the running-difference function variant.
/// is_first_line_zero == true  -> "runningDifference"
/// is_first_line_zero == false -> "runningDifferenceStartingWithFirstValue"
template <bool is_first_line_zero>
struct FunctionRunningDifferenceName
{
    static constexpr auto name = is_first_line_zero
        ? "runningDifference"
        : "runningDifferenceStartingWithFirstValue";
};
37 | |
38 | /** Calculate difference of consecutive values in block. |
39 | * So, result of function depends on partition of data to blocks and on order of data in block. |
40 | */ |
41 | template <bool is_first_line_zero> |
42 | class FunctionRunningDifferenceImpl : public IFunction |
43 | { |
44 | private: |
45 | /// It is possible to track value from previous block, to calculate continuously across all blocks. Not implemented. |
46 | |
47 | template <typename Src, typename Dst> |
48 | static void process(const PaddedPODArray<Src> & src, PaddedPODArray<Dst> & dst, const NullMap * null_map) |
49 | { |
50 | size_t size = src.size(); |
51 | dst.resize(size); |
52 | |
53 | if (size == 0) |
54 | return; |
55 | |
56 | /// It is possible to SIMD optimize this loop. By no need for that in practice. |
57 | |
58 | Src prev{}; |
59 | bool has_prev_value = false; |
60 | |
61 | for (size_t i = 0; i < size; ++i) |
62 | { |
63 | if (null_map && (*null_map)[i]) |
64 | { |
65 | dst[i] = Dst{}; |
66 | continue; |
67 | } |
68 | |
69 | if (!has_prev_value) |
70 | { |
71 | dst[i] = is_first_line_zero ? 0 : src[i]; |
72 | prev = src[i]; |
73 | has_prev_value = true; |
74 | } |
75 | else |
76 | { |
77 | auto cur = src[i]; |
78 | dst[i] = static_cast<Dst>(cur) - prev; |
79 | prev = cur; |
80 | } |
81 | } |
82 | } |
83 | |
84 | /// Result type is same as result of subtraction of argument types. |
85 | template <typename SrcFieldType> |
86 | using DstFieldType = typename NumberTraits::ResultOfSubtraction<SrcFieldType, SrcFieldType>::Type; |
87 | |
88 | /// Call polymorphic lambda with tag argument of concrete field type of src_type. |
89 | template <typename F> |
90 | void dispatchForSourceType(const IDataType & src_type, F && f) const |
91 | { |
92 | WhichDataType which(src_type); |
93 | |
94 | if (which.isUInt8()) |
95 | f(UInt8()); |
96 | else if (which.isUInt16()) |
97 | f(UInt16()); |
98 | else if (which.isUInt32()) |
99 | f(UInt32()); |
100 | else if (which.isUInt64()) |
101 | f(UInt64()); |
102 | else if (which.isInt8()) |
103 | f(Int8()); |
104 | else if (which.isInt16()) |
105 | f(Int16()); |
106 | else if (which.isInt32()) |
107 | f(Int32()); |
108 | else if (which.isInt64()) |
109 | f(Int64()); |
110 | else if (which.isFloat32()) |
111 | f(Float32()); |
112 | else if (which.isFloat64()) |
113 | f(Float64()); |
114 | else if (which.isDate()) |
115 | f(DataTypeDate::FieldType()); |
116 | else if (which.isDateTime()) |
117 | f(DataTypeDateTime::FieldType()); |
118 | else |
119 | throw Exception("Argument for function " + getName() + " must have numeric type." , ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); |
120 | } |
121 | |
122 | public: |
123 | static constexpr auto name = FunctionRunningDifferenceName<is_first_line_zero>::name; |
124 | |
125 | static FunctionPtr create(const Context &) |
126 | { |
127 | return std::make_shared<FunctionRunningDifferenceImpl<is_first_line_zero>>(); |
128 | } |
129 | |
130 | String getName() const override |
131 | { |
132 | return name; |
133 | } |
134 | |
135 | bool isStateful() const override |
136 | { |
137 | return true; |
138 | } |
139 | |
140 | size_t getNumberOfArguments() const override |
141 | { |
142 | return 1; |
143 | } |
144 | |
145 | bool isDeterministic() const override { return false; } |
146 | bool isDeterministicInScopeOfQuery() const override |
147 | { |
148 | return false; |
149 | } |
150 | |
151 | bool useDefaultImplementationForNulls() const override { return false; } |
152 | |
153 | DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override |
154 | { |
155 | DataTypePtr res; |
156 | dispatchForSourceType(*removeNullable(arguments[0]), [&](auto field_type_tag) |
157 | { |
158 | res = std::make_shared<DataTypeNumber<DstFieldType<decltype(field_type_tag)>>>(); |
159 | }); |
160 | |
161 | if (arguments[0]->isNullable()) |
162 | res = makeNullable(res); |
163 | |
164 | return res; |
165 | } |
166 | |
167 | void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override |
168 | { |
169 | auto & src = block.getByPosition(arguments.at(0)); |
170 | const auto & res_type = block.getByPosition(result).type; |
171 | |
172 | /// When column is constant, its difference is zero. |
173 | if (isColumnConst(*src.column)) |
174 | { |
175 | block.getByPosition(result).column = res_type->createColumnConstWithDefaultValue(input_rows_count); |
176 | return; |
177 | } |
178 | |
179 | auto res_column = removeNullable(res_type)->createColumn(); |
180 | auto * src_column = src.column.get(); |
181 | ColumnPtr null_map_column = nullptr; |
182 | const NullMap * null_map = nullptr; |
183 | if (auto * nullable_column = checkAndGetColumn<ColumnNullable>(src_column)) |
184 | { |
185 | src_column = &nullable_column->getNestedColumn(); |
186 | null_map_column = nullable_column->getNullMapColumnPtr(); |
187 | null_map = &nullable_column->getNullMapData(); |
188 | } |
189 | |
190 | dispatchForSourceType(*removeNullable(src.type), [&](auto field_type_tag) |
191 | { |
192 | using SrcFieldType = decltype(field_type_tag); |
193 | |
194 | process(assert_cast<const ColumnVector<SrcFieldType> &>(*src_column).getData(), |
195 | assert_cast<ColumnVector<DstFieldType<SrcFieldType>> &>(*res_column).getData(), null_map); |
196 | }); |
197 | |
198 | if (null_map_column) |
199 | block.getByPosition(result).column = ColumnNullable::create(std::move(res_column), null_map_column); |
200 | else |
201 | block.getByPosition(result).column = std::move(res_column); |
202 | } |
203 | }; |
204 | |
205 | } |
206 | |