#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/common/algorithm.hpp"

namespace duckdb {

//! Shared information about a collection of distinct aggregates
DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector<unique_ptr<Expression>> &aggregates,
                                                                 vector<idx_t> indices)
    : indices(std::move(indices)), aggregates(aggregates) {
	table_count = CreateTableIndexMap();

	const idx_t aggregate_count = aggregates.size();

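	// Count the total number of payload columns: only aggregates marked DISTINCT
	// contribute their children, since only those are materialized in a distinct table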
	total_child_count = 0;
	for (idx_t i = 0; i < aggregate_count; i++) {
		auto &aggregate = aggregates[i]->Cast<BoundAggregateExpression>();

		if (!aggregate.IsDistinct()) {
			continue;
		}
		total_child_count += aggregate.children.size();
	}
}

//! Stateful data for the distinct aggregates

DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client)
    : child_executor(client) {

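	// Allocate one slot per distinct table; the slots are filled below for every
	// table that is actually backed by a radix-partitioned hash table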
	radix_states.resize(data.info.table_count);
	distinct_output_chunks.resize(data.info.table_count);

	idx_t aggregate_count = data.info.aggregates.size();
	for (idx_t i = 0; i < aggregate_count; i++) {
		auto &aggregate = data.info.aggregates[i]->Cast<BoundAggregateExpression>();

		// Initialize the child executor and get the payload types for every aggregate
		for (auto &child : aggregate.children) {
			child_executor.AddExpression(*child);
		}
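		// Non-distinct aggregates still feed the child executor (so payload indices
		// line up), but they are not backed by a distinct table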
		if (!aggregate.IsDistinct()) {
			continue;
		}
		D_ASSERT(data.info.table_map.count(i));
		idx_t table_idx = data.info.table_map.at(i);
		if (data.radix_tables[table_idx] == nullptr) {
			//! This table is unused because the aggregate shares its data with another
			continue;
		}

		// Get the global sink state for the aggregate
		auto &radix_table = *data.radix_tables[table_idx];
		radix_states[table_idx] = radix_table.GetGlobalSinkState(client);

		// Fill the chunk_types (group_by + children)
		vector<LogicalType> chunk_types;
		for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) {
			chunk_types.push_back(group_type);
		}

		// This is used in Finalize to get the data from the radix table
		distinct_output_chunks[table_idx] = make_uniq<DataChunk>();
		distinct_output_chunks[table_idx]->Initialize(client, chunk_types);
	}
}

//! Persistent + shared (read-only) data for the distinct aggregates
DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info)
    : DistinctAggregateData(info, {}, nullptr) {
}

DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups,
                                             const vector<unique_ptr<Expression>> *group_expressions)
    : info(info) {
	grouped_aggregate_data.resize(info.table_count);
	radix_tables.resize(info.table_count);
	grouping_sets.resize(info.table_count);

	for (auto &i : info.indices) {
		auto &aggregate = info.aggregates[i]->Cast<BoundAggregateExpression>();

		D_ASSERT(info.table_map.count(i));
		idx_t table_idx = info.table_map.at(i);
		if (radix_tables[table_idx] != nullptr) {
			//! This aggregate shares a table with another aggregate, and the table is already initialized
			continue;
		}
		// The grouping set contains the indices of the chunk that correspond to the data vector
		// that will be used to figure out in which bucket the payload should be put
		auto &grouping_set = grouping_sets[table_idx];
		//! First add the indices of the regular group-by columns
		for (auto &group : groups) {
			grouping_set.insert(group);
		}
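		//! Then add the aggregate's children, which are laid out directly after the group-by columns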
		idx_t group_by_size = group_expressions ? group_expressions->size() : 0;
		for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) {
			grouping_set.insert(set_idx + group_by_size);
		}
		// Create the hashtable for the aggregate
		grouped_aggregate_data[table_idx] = make_uniq<GroupedAggregateData>();
		grouped_aggregate_data[table_idx]->InitializeDistinct(info.aggregates[i], group_expressions);
		radix_tables[table_idx] =
		    make_uniq<RadixPartitionedHashTable>(grouping_set, *grouped_aggregate_data[table_idx]);

		// Fill the chunk_types (only contains the payload of the distinct aggregates)
		vector<LogicalType> chunk_types;
		for (auto &child_p : aggregate.children) {
			chunk_types.push_back(child_p->return_type);
		}
	}
}

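// Predicate used to deduplicate distinct tables: two distinct aggregates can share
// a hash table when they have the same input columns and the same (or no) filter.
// For example, COUNT(DISTINCT x) and SUM(DISTINCT x) both hash on column x only.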
using aggr_ref_t = reference<BoundAggregateExpression>;

struct FindMatchingAggregate {
	explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) {
	}
	bool operator()(const aggr_ref_t other_r) {
		auto &other = other_r.get();
		auto &aggr = aggr_r.get();
		if (other.children.size() != aggr.children.size()) {
			return false;
		}
		if (!Expression::Equals(aggr.filter, other.filter)) {
			return false;
		}
		for (idx_t i = 0; i < aggr.children.size(); i++) {
			auto &other_child = other.children[i]->Cast<BoundReferenceExpression>();
			auto &aggr_child = aggr.children[i]->Cast<BoundReferenceExpression>();
			if (other_child.index != aggr_child.index) {
				return false;
			}
		}
		return true;
	}
	const aggr_ref_t aggr_r;
};

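// Assign a table index to every distinct aggregate, reusing an existing table whenever
// FindMatchingAggregate reports a match; returns the number of distinct tables needed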
idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() {
	vector<aggr_ref_t> table_inputs;

	D_ASSERT(table_map.empty());
	for (auto &agg_idx : indices) {
		D_ASSERT(agg_idx < aggregates.size());
		auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>();

		auto matching_inputs =
		    std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate)));
		if (matching_inputs != table_inputs.end()) {
			//! Assign the existing table to the aggregate
			idx_t found_idx = std::distance(table_inputs.begin(), matching_inputs);
			table_map[agg_idx] = found_idx;
			continue;
		}
		//! Create a new table and assign its index to the aggregate
		table_map[agg_idx] = table_inputs.size();
		table_inputs.push_back(std::ref(aggregate));
	}
	//! Every distinct aggregate needs to be assigned an index
	D_ASSERT(table_map.size() == indices.size());
	//! There cannot be more tables than there are distinct aggregates
	D_ASSERT(table_inputs.size() <= indices.size());

	return table_inputs.size();
}

bool DistinctAggregateCollectionInfo::AnyDistinct() const {
	return !indices.empty();
}

const unsafe_vector<idx_t> &DistinctAggregateCollectionInfo::Indices() const {
	return this->indices;
}

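// Collect the positions of all aggregates that are marked DISTINCT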
static vector<idx_t> GetDistinctIndices(vector<unique_ptr<Expression>> &aggregates) {
	vector<idx_t> distinct_indices;
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = aggregates[i];
		auto &aggr = aggregate->Cast<BoundAggregateExpression>();
		if (aggr.IsDistinct()) {
			distinct_indices.push_back(i);
		}
	}
	return distinct_indices;
}

unique_ptr<DistinctAggregateCollectionInfo>
DistinctAggregateCollectionInfo::Create(vector<unique_ptr<Expression>> &aggregates) {
	vector<idx_t> indices = GetDistinctIndices(aggregates);
	if (indices.empty()) {
		return nullptr;
	}
	return make_uniq<DistinctAggregateCollectionInfo>(aggregates, std::move(indices));
}

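// Sketch of how a caller is expected to use Create (a hypothetical usage example,
// not taken from this file; the aggregate operators are the actual consumers):
//
//   auto distinct_info = DistinctAggregateCollectionInfo::Create(aggregates);
//   if (distinct_info) {
//       // only allocate distinct state when at least one aggregate is DISTINCT
//       auto distinct_data = make_uniq<DistinctAggregateData>(*distinct_info);
//   }
//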
bool DistinctAggregateData::IsDistinct(idx_t index) const {
	bool is_distinct = !radix_tables.empty() && info.table_map.count(index);
#ifdef DEBUG
	//! Make sure that if it is distinct, it's also in the indices
	//! And if it's not distinct, that it's also not in the indices
	bool found = false;
	for (auto &idx : info.indices) {
		if (idx == index) {
			found = true;
			break;
		}
	}
	D_ASSERT(found == is_distinct);
#endif
	return is_distinct;
}

} // namespace duckdb