1#include <Functions/IFunctionImpl.h>
2#include <Functions/FunctionHelpers.h>
3#include <Columns/ColumnsNumber.h>
4#include <Columns/ColumnNullable.h>
5#include <Common/assert_cast.h>
6#include <DataTypes/DataTypeDate.h>
7#include <DataTypes/DataTypeDateTime.h>
8#include <DataTypes/DataTypeDateTime64.h>
9#include <DataTypes/DataTypesNumber.h>
10#include <DataTypes/NumberTraits.h>
11#include <DataTypes/DataTypeNullable.h>
12
13
14namespace DB
15{
16
/// Error codes referenced in this file. They are only declared here (`extern`);
/// no definition is visible in this part of the source.
17namespace ErrorCodes
18{
/// Used by FunctionRunningDifferenceImpl::dispatchForSourceType when the argument type is not supported.
19 extern const int ILLEGAL_TYPE_OF_ARGUMENT;
20}
21
22
/// Compile-time mapping from the `is_first_line_zero` flag to the function name.
/// The primary template is intentionally left undefined: only the two
/// explicit specializations below provide a `name`.
template <bool is_first_line_zero>
struct FunctionRunningDifferenceName;

/// Variant whose first row is zero.
template <>
struct FunctionRunningDifferenceName<true>
{
    static constexpr const char * name = "runningDifference";
};

/// Variant whose first row is the source value itself.
template <>
struct FunctionRunningDifferenceName<false>
{
    static constexpr const char * name = "runningDifferenceStartingWithFirstValue";
};
37
38/** Calculate difference of consecutive values in block.
39 * So, result of function depends on partition of data to blocks and on order of data in block.
40 */
41template <bool is_first_line_zero>
42class FunctionRunningDifferenceImpl : public IFunction
43{
44private:
45 /// It is possible to track value from previous block, to calculate continuously across all blocks. Not implemented.
46
47 template <typename Src, typename Dst>
48 static void process(const PaddedPODArray<Src> & src, PaddedPODArray<Dst> & dst, const NullMap * null_map)
49 {
50 size_t size = src.size();
51 dst.resize(size);
52
53 if (size == 0)
54 return;
55
56 /// It is possible to SIMD optimize this loop. By no need for that in practice.
57
58 Src prev{};
59 bool has_prev_value = false;
60
61 for (size_t i = 0; i < size; ++i)
62 {
63 if (null_map && (*null_map)[i])
64 {
65 dst[i] = Dst{};
66 continue;
67 }
68
69 if (!has_prev_value)
70 {
71 dst[i] = is_first_line_zero ? 0 : src[i];
72 prev = src[i];
73 has_prev_value = true;
74 }
75 else
76 {
77 auto cur = src[i];
78 dst[i] = static_cast<Dst>(cur) - prev;
79 prev = cur;
80 }
81 }
82 }
83
84 /// Result type is same as result of subtraction of argument types.
85 template <typename SrcFieldType>
86 using DstFieldType = typename NumberTraits::ResultOfSubtraction<SrcFieldType, SrcFieldType>::Type;
87
88 /// Call polymorphic lambda with tag argument of concrete field type of src_type.
89 template <typename F>
90 void dispatchForSourceType(const IDataType & src_type, F && f) const
91 {
92 WhichDataType which(src_type);
93
94 if (which.isUInt8())
95 f(UInt8());
96 else if (which.isUInt16())
97 f(UInt16());
98 else if (which.isUInt32())
99 f(UInt32());
100 else if (which.isUInt64())
101 f(UInt64());
102 else if (which.isInt8())
103 f(Int8());
104 else if (which.isInt16())
105 f(Int16());
106 else if (which.isInt32())
107 f(Int32());
108 else if (which.isInt64())
109 f(Int64());
110 else if (which.isFloat32())
111 f(Float32());
112 else if (which.isFloat64())
113 f(Float64());
114 else if (which.isDate())
115 f(DataTypeDate::FieldType());
116 else if (which.isDateTime())
117 f(DataTypeDateTime::FieldType());
118 else
119 throw Exception("Argument for function " + getName() + " must have numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
120 }
121
122public:
123 static constexpr auto name = FunctionRunningDifferenceName<is_first_line_zero>::name;
124
125 static FunctionPtr create(const Context &)
126 {
127 return std::make_shared<FunctionRunningDifferenceImpl<is_first_line_zero>>();
128 }
129
130 String getName() const override
131 {
132 return name;
133 }
134
135 bool isStateful() const override
136 {
137 return true;
138 }
139
140 size_t getNumberOfArguments() const override
141 {
142 return 1;
143 }
144
145 bool isDeterministic() const override { return false; }
146 bool isDeterministicInScopeOfQuery() const override
147 {
148 return false;
149 }
150
151 bool useDefaultImplementationForNulls() const override { return false; }
152
153 DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
154 {
155 DataTypePtr res;
156 dispatchForSourceType(*removeNullable(arguments[0]), [&](auto field_type_tag)
157 {
158 res = std::make_shared<DataTypeNumber<DstFieldType<decltype(field_type_tag)>>>();
159 });
160
161 if (arguments[0]->isNullable())
162 res = makeNullable(res);
163
164 return res;
165 }
166
167 void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
168 {
169 auto & src = block.getByPosition(arguments.at(0));
170 const auto & res_type = block.getByPosition(result).type;
171
172 /// When column is constant, its difference is zero.
173 if (isColumnConst(*src.column))
174 {
175 block.getByPosition(result).column = res_type->createColumnConstWithDefaultValue(input_rows_count);
176 return;
177 }
178
179 auto res_column = removeNullable(res_type)->createColumn();
180 auto * src_column = src.column.get();
181 ColumnPtr null_map_column = nullptr;
182 const NullMap * null_map = nullptr;
183 if (auto * nullable_column = checkAndGetColumn<ColumnNullable>(src_column))
184 {
185 src_column = &nullable_column->getNestedColumn();
186 null_map_column = nullable_column->getNullMapColumnPtr();
187 null_map = &nullable_column->getNullMapData();
188 }
189
190 dispatchForSourceType(*removeNullable(src.type), [&](auto field_type_tag)
191 {
192 using SrcFieldType = decltype(field_type_tag);
193
194 process(assert_cast<const ColumnVector<SrcFieldType> &>(*src_column).getData(),
195 assert_cast<ColumnVector<DstFieldType<SrcFieldType>> &>(*res_column).getData(), null_map);
196 });
197
198 if (null_map_column)
199 block.getByPosition(result).column = ColumnNullable::create(std::move(res_column), null_map_column);
200 else
201 block.getByPosition(result).column = std::move(res_column);
202 }
203};
204
205}
206