1#include <DataStreams/TotalsHavingBlockInputStream.h>
2#include <DataStreams/finalizeBlock.h>
3#include <Interpreters/ExpressionActions.h>
4#include <DataTypes/DataTypeAggregateFunction.h>
5#include <Columns/ColumnAggregateFunction.h>
6#include <Columns/FilterDescription.h>
7#include <Common/typeid_cast.h>
8#include <Common/assert_cast.h>
9#include <Common/Arena.h>
10
11
12namespace DB
13{
14
15
16TotalsHavingBlockInputStream::TotalsHavingBlockInputStream(
17 const BlockInputStreamPtr & input_,
18 bool overflow_row_, const ExpressionActionsPtr & expression_,
19 const std::string & filter_column_, TotalsMode totals_mode_, double auto_include_threshold_, bool final_)
20 : overflow_row(overflow_row_),
21 expression(expression_), filter_column_name(filter_column_), totals_mode(totals_mode_),
22 auto_include_threshold(auto_include_threshold_), final(final_)
23{
24 children.push_back(input_);
25
26 /// Initialize current totals with initial state.
27
28 Block source_header = children.at(0)->getHeader();
29
30 current_totals.reserve(source_header.columns());
31 for (const auto & elem : source_header)
32 {
33 // Create a column with default value
34 MutableColumnPtr new_column = elem.type->createColumn();
35 elem.type->insertDefaultInto(*new_column);
36 current_totals.emplace_back(std::move(new_column));
37 }
38}
39
40
41Block TotalsHavingBlockInputStream::getTotals()
42{
43 if (!totals)
44 {
45 /** If totals_mode == AFTER_HAVING_AUTO, you need to decide whether to add aggregates to TOTALS for strings,
46 * not passed max_rows_to_group_by.
47 */
48 if (overflow_aggregates)
49 {
50 if (totals_mode == TotalsMode::BEFORE_HAVING
51 || totals_mode == TotalsMode::AFTER_HAVING_INCLUSIVE
52 || (totals_mode == TotalsMode::AFTER_HAVING_AUTO
53 && static_cast<double>(passed_keys) / total_keys >= auto_include_threshold))
54 addToTotals(overflow_aggregates, nullptr);
55 }
56
57 totals = children.at(0)->getHeader().cloneWithColumns(std::move(current_totals));
58 finalizeBlock(totals);
59 }
60
61 if (totals && expression)
62 expression->execute(totals);
63
64 return totals;
65}
66
67
68Block TotalsHavingBlockInputStream::getHeader() const
69{
70 Block res = children.at(0)->getHeader();
71 if (final)
72 finalizeBlock(res);
73 if (expression)
74 expression->execute(res);
75 return res;
76}
77
78
79Block TotalsHavingBlockInputStream::readImpl()
80{
81 Block finalized;
82 Block block;
83
84 while (1)
85 {
86 block = children[0]->read();
87
88 /// Block with values not included in `max_rows_to_group_by`. We'll postpone it.
89 if (overflow_row && block && block.info.is_overflows)
90 {
91 overflow_aggregates = block;
92 continue;
93 }
94
95 if (!block)
96 return finalized;
97
98 finalized = block;
99 if (final)
100 finalizeBlock(finalized);
101
102 total_keys += finalized.rows();
103
104 if (filter_column_name.empty())
105 {
106 addToTotals(block, nullptr);
107 }
108 else
109 {
110 /// Compute the expression in HAVING.
111 expression->execute(finalized);
112
113 size_t filter_column_pos = finalized.getPositionByName(filter_column_name);
114 ColumnPtr filter_column_ptr = finalized.safeGetByPosition(filter_column_pos).column->convertToFullColumnIfConst();
115
116 FilterDescription filter_description(*filter_column_ptr);
117
118 /// Add values to `totals` (if it was not already done).
119 if (totals_mode == TotalsMode::BEFORE_HAVING)
120 addToTotals(block, nullptr);
121 else
122 addToTotals(block, filter_description.data);
123
124 /// Filter the block by expression in HAVING.
125 size_t columns = finalized.columns();
126
127 for (size_t i = 0; i < columns; ++i)
128 {
129 ColumnWithTypeAndName & current_column = finalized.safeGetByPosition(i);
130 current_column.column = current_column.column->filter(*filter_description.data, -1);
131 if (current_column.column->empty())
132 {
133 finalized.clear();
134 break;
135 }
136 }
137 }
138
139 if (!finalized)
140 continue;
141
142 passed_keys += finalized.rows();
143 return finalized;
144 }
145}
146
147
148void TotalsHavingBlockInputStream::addToTotals(const Block & source_block, const IColumn::Filter * filter)
149{
150 for (size_t i = 0, num_columns = source_block.columns(); i < num_columns; ++i)
151 {
152 const auto * source_column = typeid_cast<const ColumnAggregateFunction *>(
153 source_block.getByPosition(i).column.get());
154 if (!source_column)
155 {
156 continue;
157 }
158
159 auto & totals_column = assert_cast<ColumnAggregateFunction &>(*current_totals[i]);
160 assert(totals_column.size() == 1);
161
162 /// Accumulate all aggregate states from a column of a source block into
163 /// the corresponding totals column.
164 const auto & vec = source_column->getData();
165 size_t size = vec.size();
166
167 if (filter)
168 {
169 for (size_t j = 0; j < size; ++j)
170 if ((*filter)[j])
171 totals_column.insertMergeFrom(vec[j]);
172 }
173 else
174 {
175 for (size_t j = 0; j < size; ++j)
176 totals_column.insertMergeFrom(vec[j]);
177 }
178 }
179}
180
181}
182