1 | #pragma once |
2 | |
3 | #include <AggregateFunctions/IAggregateFunction.h> |
4 | #include <Columns/ColumnArray.h> |
5 | #include <DataTypes/DataTypeArray.h> |
6 | #include <Common/assert_cast.h> |
7 | |
8 | |
9 | namespace DB |
10 | { |
11 | |
/// Error codes referenced in this header. The constant is only declared here
/// (`extern`); its definition is not part of this file.
namespace ErrorCodes
{
    extern const int ARGUMENT_OUT_OF_BOUND;
}
16 | |
17 | template <typename Key> |
18 | class AggregateFunctionResample final : public IAggregateFunctionHelper<AggregateFunctionResample<Key>> |
19 | { |
20 | private: |
21 | const size_t MAX_ELEMENTS = 4096; |
22 | |
23 | AggregateFunctionPtr nested_function; |
24 | |
25 | size_t last_col; |
26 | |
27 | Key begin; |
28 | Key end; |
29 | size_t step; |
30 | |
31 | size_t total; |
32 | size_t align_of_data; |
33 | size_t size_of_data; |
34 | |
35 | public: |
36 | AggregateFunctionResample( |
37 | AggregateFunctionPtr nested_function_, |
38 | Key begin_, |
39 | Key end_, |
40 | size_t step_, |
41 | const DataTypes & arguments, |
42 | const Array & params) |
43 | : IAggregateFunctionHelper<AggregateFunctionResample<Key>>{arguments, params} |
44 | , nested_function{nested_function_} |
45 | , last_col{arguments.size() - 1} |
46 | , begin{begin_} |
47 | , end{end_} |
48 | , step{step_} |
49 | , total{0} |
50 | , align_of_data{nested_function->alignOfData()} |
51 | , size_of_data{(nested_function->sizeOfData() + align_of_data - 1) / align_of_data * align_of_data} |
52 | { |
53 | // notice: argument types has been checked before |
54 | if (step == 0) |
55 | throw Exception("The step given in function " |
56 | + getName() + " should not be zero" , |
57 | ErrorCodes::ARGUMENT_OUT_OF_BOUND); |
58 | |
59 | if (end < begin) |
60 | total = 0; |
61 | else |
62 | total = (end - begin + step - 1) / step; |
63 | |
64 | if (total > MAX_ELEMENTS) |
65 | throw Exception("The range given in function " |
66 | + getName() + " contains too many elements" , |
67 | ErrorCodes::ARGUMENT_OUT_OF_BOUND); |
68 | } |
69 | |
70 | String getName() const override |
71 | { |
72 | return nested_function->getName() + "Resample" ; |
73 | } |
74 | |
75 | bool isState() const override |
76 | { |
77 | return nested_function->isState(); |
78 | } |
79 | |
80 | bool allocatesMemoryInArena() const override |
81 | { |
82 | return nested_function->allocatesMemoryInArena(); |
83 | } |
84 | |
85 | bool hasTrivialDestructor() const override |
86 | { |
87 | return nested_function->hasTrivialDestructor(); |
88 | } |
89 | |
90 | size_t sizeOfData() const override |
91 | { |
92 | return total * size_of_data; |
93 | } |
94 | |
95 | size_t alignOfData() const override |
96 | { |
97 | return align_of_data; |
98 | } |
99 | |
100 | void create(AggregateDataPtr place) const override |
101 | { |
102 | for (size_t i = 0; i < total; ++i) |
103 | { |
104 | try |
105 | { |
106 | nested_function->create(place + i * size_of_data); |
107 | } |
108 | catch (...) |
109 | { |
110 | for (size_t j = 0; j < i; ++j) |
111 | nested_function->destroy(place + j * size_of_data); |
112 | throw; |
113 | } |
114 | } |
115 | } |
116 | |
117 | void destroy(AggregateDataPtr place) const noexcept override |
118 | { |
119 | for (size_t i = 0; i < total; ++i) |
120 | nested_function->destroy(place + i * size_of_data); |
121 | } |
122 | |
123 | void add( |
124 | AggregateDataPtr place, |
125 | const IColumn ** columns, |
126 | size_t row_num, |
127 | Arena * arena) const override |
128 | { |
129 | Key key; |
130 | |
131 | if constexpr (static_cast<Key>(-1) < 0) |
132 | key = columns[last_col]->getInt(row_num); |
133 | else |
134 | key = columns[last_col]->getUInt(row_num); |
135 | |
136 | if (key < begin || key >= end) |
137 | return; |
138 | |
139 | size_t pos = (key - begin) / step; |
140 | |
141 | nested_function->add(place + pos * size_of_data, columns, row_num, arena); |
142 | } |
143 | |
144 | void merge( |
145 | AggregateDataPtr place, |
146 | ConstAggregateDataPtr rhs, |
147 | Arena * arena) const override |
148 | { |
149 | for (size_t i = 0; i < total; ++i) |
150 | nested_function->merge(place + i * size_of_data, rhs + i * size_of_data, arena); |
151 | } |
152 | |
153 | void serialize( |
154 | ConstAggregateDataPtr place, |
155 | WriteBuffer & buf) const override |
156 | { |
157 | for (size_t i = 0; i < total; ++i) |
158 | nested_function->serialize(place + i * size_of_data, buf); |
159 | } |
160 | |
161 | void deserialize( |
162 | AggregateDataPtr place, |
163 | ReadBuffer & buf, |
164 | Arena * arena) const override |
165 | { |
166 | for (size_t i = 0; i < total; ++i) |
167 | nested_function->deserialize(place + i * size_of_data, buf, arena); |
168 | } |
169 | |
170 | DataTypePtr getReturnType() const override |
171 | { |
172 | return std::make_shared<DataTypeArray>(nested_function->getReturnType()); |
173 | } |
174 | |
175 | void insertResultInto( |
176 | ConstAggregateDataPtr place, |
177 | IColumn & to) const override |
178 | { |
179 | auto & col = assert_cast<ColumnArray &>(to); |
180 | auto & col_offsets = assert_cast<ColumnArray::ColumnOffsets &>(col.getOffsetsColumn()); |
181 | |
182 | for (size_t i = 0; i < total; ++i) |
183 | nested_function->insertResultInto(place + i * size_of_data, col.getData()); |
184 | |
185 | col_offsets.getData().push_back(col.getData().size()); |
186 | } |
187 | }; |
188 | |
189 | } |
190 | |