1#include <Processors/Transforms/TotalsHavingTransform.h>
2#include <Processors/Transforms/AggregatingTransform.h>
3
4#include <Columns/ColumnAggregateFunction.h>
5#include <Columns/FilterDescription.h>
6
7#include <Common/typeid_cast.h>
8#include <DataStreams/finalizeBlock.h>
9#include <Interpreters/ExpressionActions.h>
10
11namespace DB
12{
13
14void finalizeChunk(Chunk & chunk)
15{
16 auto num_rows = chunk.getNumRows();
17 auto columns = chunk.detachColumns();
18
19 for (auto & column : columns)
20 if (auto * agg_function = typeid_cast<const ColumnAggregateFunction *>(column.get()))
21 column = agg_function->convertToValues();
22
23 chunk.setColumns(std::move(columns), num_rows);
24}
25
26static Block createOutputHeader(Block block, const ExpressionActionsPtr & expression, bool final)
27{
28 if (final)
29 finalizeBlock(block);
30
31 if (expression)
32 expression->execute(block);
33
34 return block;
35}
36
37TotalsHavingTransform::TotalsHavingTransform(
38 const Block & header,
39 bool overflow_row_,
40 const ExpressionActionsPtr & expression_,
41 const std::string & filter_column_,
42 TotalsMode totals_mode_,
43 double auto_include_threshold_,
44 bool final_)
45 : ISimpleTransform(header, createOutputHeader(header, expression_, final_), true)
46 , overflow_row(overflow_row_)
47 , expression(expression_)
48 , filter_column_name(filter_column_)
49 , totals_mode(totals_mode_)
50 , auto_include_threshold(auto_include_threshold_)
51 , final(final_)
52{
53 if (!filter_column_name.empty())
54 filter_column_pos = outputs.front().getHeader().getPositionByName(filter_column_name);
55
56 finalized_header = getInputPort().getHeader();
57 finalizeBlock(finalized_header);
58
59 /// Port for Totals.
60 if (expression)
61 {
62 auto totals_header = finalized_header;
63 expression->execute(totals_header);
64 outputs.emplace_back(totals_header, this);
65 }
66 else
67 outputs.emplace_back(finalized_header, this);
68
69 /// Initialize current totals with initial state.
70 current_totals.reserve(header.columns());
71 for (const auto & elem : header)
72 {
73 MutableColumnPtr new_column = elem.type->createColumn();
74 elem.type->insertDefaultInto(*new_column);
75 current_totals.emplace_back(std::move(new_column));
76 }
77}
78
79IProcessor::Status TotalsHavingTransform::prepare()
80{
81 if (!finished_transform)
82 {
83 auto status = ISimpleTransform::prepare();
84
85 if (status != Status::Finished)
86 return status;
87
88 finished_transform = true;
89 }
90
91 auto & totals_output = getTotalsPort();
92
93 /// Check can output.
94 if (totals_output.isFinished())
95 return Status::Finished;
96
97 if (!totals_output.canPush())
98 return Status::PortFull;
99
100 if (!totals)
101 return Status::Ready;
102
103 totals_output.push(std::move(totals));
104 totals_output.finish();
105 return Status::Finished;
106}
107
108void TotalsHavingTransform::work()
109{
110 if (finished_transform)
111 prepareTotals();
112 else
113 ISimpleTransform::work();
114}
115
116void TotalsHavingTransform::transform(Chunk & chunk)
117{
118 /// Block with values not included in `max_rows_to_group_by`. We'll postpone it.
119 if (overflow_row)
120 {
121 auto & info = chunk.getChunkInfo();
122 if (!info)
123 throw Exception("Chunk info was not set for chunk in TotalsHavingTransform.", ErrorCodes::LOGICAL_ERROR);
124
125 auto * agg_info = typeid_cast<const AggregatedChunkInfo *>(info.get());
126 if (!agg_info)
127 throw Exception("Chunk should have AggregatedChunkInfo in TotalsHavingTransform.", ErrorCodes::LOGICAL_ERROR);
128
129 if (agg_info->is_overflows)
130 {
131 overflow_aggregates = std::move(chunk);
132 return;
133 }
134 }
135
136 if (!chunk)
137 return;
138
139 auto finalized = chunk.clone();
140 if (final)
141 finalizeChunk(finalized);
142
143 total_keys += finalized.getNumRows();
144
145 if (filter_column_name.empty())
146 {
147 addToTotals(chunk, nullptr);
148 chunk = std::move(finalized);
149 }
150 else
151 {
152 /// Compute the expression in HAVING.
153 auto & cur_header = final ? finalized_header : getInputPort().getHeader();
154 auto finalized_block = cur_header.cloneWithColumns(finalized.detachColumns());
155 expression->execute(finalized_block);
156 auto columns = finalized_block.getColumns();
157
158 ColumnPtr filter_column_ptr = columns[filter_column_pos];
159 ConstantFilterDescription const_filter_description(*filter_column_ptr);
160
161 if (const_filter_description.always_true)
162 {
163 addToTotals(chunk, nullptr);
164 auto num_rows = columns.front()->size();
165 chunk.setColumns(std::move(columns), num_rows);
166 return;
167 }
168
169 if (const_filter_description.always_false)
170 {
171 if (totals_mode == TotalsMode::BEFORE_HAVING)
172 addToTotals(chunk, nullptr);
173
174 chunk.clear();
175 return;
176 }
177
178 FilterDescription filter_description(*filter_column_ptr);
179
180 /// Add values to `totals` (if it was not already done).
181 if (totals_mode == TotalsMode::BEFORE_HAVING)
182 addToTotals(chunk, nullptr);
183 else
184 addToTotals(chunk, filter_description.data);
185
186 /// Filter the block by expression in HAVING.
187 for (auto & column : columns)
188 {
189 column = column->filter(*filter_description.data, -1);
190 if (column->empty())
191 {
192 chunk.clear();
193 return;
194 }
195 }
196
197 auto num_rows = columns.front()->size();
198 chunk.setColumns(std::move(columns), num_rows);
199 }
200
201 passed_keys += chunk.getNumRows();
202}
203
204void TotalsHavingTransform::addToTotals(const Chunk & chunk, const IColumn::Filter * filter)
205{
206 auto num_columns = chunk.getNumColumns();
207 for (size_t col = 0; col < num_columns; ++col)
208 {
209 const auto & current = chunk.getColumns()[col];
210
211 if (const auto * column = typeid_cast<const ColumnAggregateFunction *>(current.get()))
212 {
213 auto & totals_column = typeid_cast<ColumnAggregateFunction &>(*current_totals[col]);
214 assert(totals_column.size() == 1);
215
216 /// Accumulate all aggregate states from a column of a source chunk into
217 /// the corresponding totals column.
218 const ColumnAggregateFunction::Container & vec = column->getData();
219 size_t size = vec.size();
220
221 if (filter)
222 {
223 for (size_t row = 0; row < size; ++row)
224 if ((*filter)[row])
225 totals_column.insertMergeFrom(vec[row]);
226 }
227 else
228 {
229 for (size_t row = 0; row < size; ++row)
230 totals_column.insertMergeFrom(vec[row]);
231 }
232 }
233 }
234}
235
236void TotalsHavingTransform::prepareTotals()
237{
238 /// If totals_mode == AFTER_HAVING_AUTO, you need to decide whether to add aggregates to TOTALS for strings,
239 /// not passed max_rows_to_group_by.
240 if (overflow_aggregates)
241 {
242 if (totals_mode == TotalsMode::BEFORE_HAVING
243 || totals_mode == TotalsMode::AFTER_HAVING_INCLUSIVE
244 || (totals_mode == TotalsMode::AFTER_HAVING_AUTO
245 && static_cast<double>(passed_keys) / total_keys >= auto_include_threshold))
246 addToTotals(overflow_aggregates, nullptr);
247 }
248
249 totals = Chunk(std::move(current_totals), 1);
250 finalizeChunk(totals);
251
252 if (expression)
253 {
254 auto block = finalized_header.cloneWithColumns(totals.detachColumns());
255 expression->execute(block);
256 totals = Chunk(block.getColumns(), 1);
257 }
258}
259
260}
261