1#include <DataStreams/AggregatingSortedBlockInputStream.h>
2#include <Common/typeid_cast.h>
3#include <Common/StringUtils/StringUtils.h>
4#include <Common/Arena.h>
5#include <DataTypes/DataTypeAggregateFunction.h>
6#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
7#include <DataTypes/DataTypeLowCardinality.h>
8
9
10namespace DB
11{
12
13namespace ErrorCodes
14{
15 extern const int LOGICAL_ERROR;
16}
17
18
19class RemovingLowCardinalityBlockInputStream : public IBlockInputStream
20{
21public:
22 RemovingLowCardinalityBlockInputStream(BlockInputStreamPtr input_, ColumnNumbers positions_)
23 : input(std::move(input_)), positions(std::move(positions_))
24 {
25 header = transform(input->getHeader());
26 }
27
28 Block transform(Block block)
29 {
30 if (block)
31 {
32 for (auto & pos : positions)
33 {
34 auto & col = block.safeGetByPosition(pos);
35 col.column = recursiveRemoveLowCardinality(col.column);
36 col.type = recursiveRemoveLowCardinality(col.type);
37 }
38 }
39
40 return block;
41 }
42
43 String getName() const override { return "RemovingLowCardinality"; }
44 Block getHeader() const override { return header; }
45 const BlockMissingValues & getMissingValues() const override { return input->getMissingValues(); }
46 bool isSortedOutput() const override { return input->isSortedOutput(); }
47 const SortDescription & getSortDescription() const override { return input->getSortDescription(); }
48
49protected:
50 Block readImpl() override { return transform(input->read()); }
51
52private:
53 Block header;
54 BlockInputStreamPtr input;
55 ColumnNumbers positions;
56};
57
58
59AggregatingSortedBlockInputStream::AggregatingSortedBlockInputStream(
60 const BlockInputStreams & inputs_, const SortDescription & description_, size_t max_block_size_)
61 : MergingSortedBlockInputStream(inputs_, description_, max_block_size_)
62{
63 ColumnNumbers positions;
64
65 /// Fill in the column numbers that need to be aggregated.
66 for (size_t i = 0; i < num_columns; ++i)
67 {
68 ColumnWithTypeAndName & column = header.safeGetByPosition(i);
69
70 /// We leave only states of aggregate functions.
71 if (!dynamic_cast<const DataTypeAggregateFunction *>(column.type.get()) && !dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
72 {
73 column_numbers_not_to_aggregate.push_back(i);
74 continue;
75 }
76
77 /// Included into PK?
78 SortDescription::const_iterator it = description.begin();
79 for (; it != description.end(); ++it)
80 if (it->column_name == column.name || (it->column_name.empty() && it->column_number == i))
81 break;
82
83 if (it != description.end())
84 {
85 column_numbers_not_to_aggregate.push_back(i);
86 continue;
87 }
88
89 if (auto simple_aggr = dynamic_cast<const DataTypeCustomSimpleAggregateFunction *>(column.type->getCustomName()))
90 {
91 // simple aggregate function
92 SimpleAggregateDescription desc{simple_aggr->getFunction(), i};
93 if (desc.function->allocatesMemoryInArena())
94 allocatesMemoryInArena = true;
95
96 columns_to_simple_aggregate.emplace_back(std::move(desc));
97
98 if (recursiveRemoveLowCardinality(column.type).get() != column.type.get())
99 positions.emplace_back(i);
100 }
101 else
102 {
103 // standard aggregate function
104 column_numbers_to_aggregate.push_back(i);
105 }
106 }
107
108 if (!positions.empty())
109 {
110 for (auto & input : children)
111 input = std::make_shared<RemovingLowCardinalityBlockInputStream>(input, positions);
112
113 header = children.at(0)->getHeader();
114 }
115}
116
117
118Block AggregatingSortedBlockInputStream::readImpl()
119{
120 if (finished)
121 return Block();
122
123 MutableColumns merged_columns;
124 init(merged_columns);
125
126 if (has_collation)
127 throw Exception("Logical error: " + getName() + " does not support collations", ErrorCodes::LOGICAL_ERROR);
128
129 if (merged_columns.empty())
130 return Block();
131
132 columns_to_aggregate.resize(column_numbers_to_aggregate.size());
133 for (size_t i = 0, size = columns_to_aggregate.size(); i < size; ++i)
134 columns_to_aggregate[i] = typeid_cast<ColumnAggregateFunction *>(merged_columns[column_numbers_to_aggregate[i]].get());
135
136 merge(merged_columns, queue_without_collation);
137 return header.cloneWithColumns(std::move(merged_columns));
138}
139
140
141void AggregatingSortedBlockInputStream::merge(MutableColumns & merged_columns, std::priority_queue<SortCursor> & queue)
142{
143 size_t merged_rows = 0;
144
145 /// We take the rows in the correct order and put them in `merged_block`, while the rows are no more than `max_block_size`
146 while (!queue.empty())
147 {
148 SortCursor current = queue.top();
149
150 setPrimaryKeyRef(next_key, current);
151
152 bool key_differs;
153
154 if (current_key.empty()) /// The first key encountered.
155 {
156 setPrimaryKeyRef(current_key, current);
157 key_differs = true;
158 }
159 else
160 key_differs = next_key != current_key;
161
162 /// if there are enough rows accumulated and the last one is calculated completely
163 if (key_differs && merged_rows >= max_block_size)
164 {
165 /// Write the simple aggregation result for the previous group.
166 insertSimpleAggregationResult(merged_columns);
167 return;
168 }
169
170 queue.pop();
171
172 if (key_differs)
173 {
174 current_key.swap(next_key);
175
176 /// We will write the data for the group. We copy the values of ordinary columns.
177 for (size_t i = 0, size = column_numbers_not_to_aggregate.size(); i < size; ++i)
178 {
179 size_t j = column_numbers_not_to_aggregate[i];
180 merged_columns[j]->insertFrom(*current->all_columns[j], current->pos);
181 }
182
183 /// Add the empty aggregation state to the aggregate columns. The state will be updated in the `addRow` function.
184 for (auto & column_to_aggregate : columns_to_aggregate)
185 column_to_aggregate->insertDefault();
186
187 /// Write the simple aggregation result for the previous group.
188 if (merged_rows > 0)
189 insertSimpleAggregationResult(merged_columns);
190
191 /// Reset simple aggregation states for next row
192 for (auto & desc : columns_to_simple_aggregate)
193 desc.createState();
194
195 if (allocatesMemoryInArena)
196 arena = std::make_unique<Arena>();
197
198 ++merged_rows;
199 }
200
201 addRow(current);
202
203 if (!current->isLast())
204 {
205 current->next();
206 queue.push(current);
207 }
208 else
209 {
210 /// We fetch the next block from the appropriate source, if there is one.
211 fetchNextBlock(current, queue);
212 }
213 }
214
215 /// Write the simple aggregation result for the previous group.
216 if (merged_rows > 0)
217 insertSimpleAggregationResult(merged_columns);
218
219 finished = true;
220}
221
222
223void AggregatingSortedBlockInputStream::addRow(SortCursor & cursor)
224{
225 for (size_t i = 0, size = column_numbers_to_aggregate.size(); i < size; ++i)
226 {
227 size_t j = column_numbers_to_aggregate[i];
228 columns_to_aggregate[i]->insertMergeFrom(*cursor->all_columns[j], cursor->pos);
229 }
230
231 for (auto & desc : columns_to_simple_aggregate)
232 {
233 auto & col = cursor->all_columns[desc.column_number];
234 desc.add_function(desc.function.get(), desc.state.data(), &col, cursor->pos, arena.get());
235 }
236}
237
238void AggregatingSortedBlockInputStream::insertSimpleAggregationResult(MutableColumns & merged_columns)
239{
240 for (auto & desc : columns_to_simple_aggregate)
241 {
242 desc.function->insertResultInto(desc.state.data(), *merged_columns[desc.column_number]);
243 desc.destroyState();
244 }
245}
246
247}
248