| 1 | #include "duckdb/planner/binder.hpp" |
| 2 | #include "duckdb/parser/tableref/pivotref.hpp" |
| 3 | #include "duckdb/parser/tableref/subqueryref.hpp" |
| 4 | #include "duckdb/parser/query_node/select_node.hpp" |
| 5 | #include "duckdb/parser/expression/case_expression.hpp" |
| 6 | #include "duckdb/parser/expression/cast_expression.hpp" |
| 7 | #include "duckdb/parser/expression/columnref_expression.hpp" |
| 8 | #include "duckdb/parser/expression/comparison_expression.hpp" |
| 9 | #include "duckdb/parser/expression/conjunction_expression.hpp" |
| 10 | #include "duckdb/parser/expression/constant_expression.hpp" |
| 11 | #include "duckdb/parser/expression/function_expression.hpp" |
| 12 | #include "duckdb/planner/query_node/bound_select_node.hpp" |
| 13 | #include "duckdb/parser/expression/star_expression.hpp" |
| 14 | #include "duckdb/common/types/value_map.hpp" |
| 15 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
| 16 | #include "duckdb/parser/expression/operator_expression.hpp" |
| 17 | #include "duckdb/planner/tableref/bound_subqueryref.hpp" |
| 18 | #include "duckdb/planner/tableref/bound_pivotref.hpp" |
| 19 | #include "duckdb/planner/expression/bound_aggregate_expression.hpp" |
| 20 | #include "duckdb/main/client_config.hpp" |
| 21 | |
| 22 | namespace duckdb { |
| 23 | |
| 24 | static void ConstructPivots(PivotRef &ref, vector<PivotValueElement> &pivot_values, idx_t pivot_idx = 0, |
| 25 | const PivotValueElement ¤t_value = PivotValueElement()) { |
| 26 | auto &pivot = ref.pivots[pivot_idx]; |
| 27 | bool last_pivot = pivot_idx + 1 == ref.pivots.size(); |
| 28 | for (auto &entry : pivot.entries) { |
| 29 | PivotValueElement new_value = current_value; |
| 30 | string name = entry.alias; |
| 31 | D_ASSERT(entry.values.size() == pivot.pivot_expressions.size()); |
| 32 | for (idx_t v = 0; v < entry.values.size(); v++) { |
| 33 | auto &value = entry.values[v]; |
| 34 | new_value.values.push_back(x: value); |
| 35 | if (entry.alias.empty()) { |
| 36 | if (name.empty()) { |
| 37 | name = value.ToString(); |
| 38 | } else { |
| 39 | name += "_" + value.ToString(); |
| 40 | } |
| 41 | } |
| 42 | } |
| 43 | if (!current_value.name.empty()) { |
| 44 | new_value.name = current_value.name + "_" + name; |
| 45 | } else { |
| 46 | new_value.name = std::move(name); |
| 47 | } |
| 48 | if (last_pivot) { |
| 49 | pivot_values.push_back(x: std::move(new_value)); |
| 50 | } else { |
| 51 | // need to recurse |
| 52 | ConstructPivots(ref, pivot_values, pivot_idx: pivot_idx + 1, current_value: new_value); |
| 53 | } |
| 54 | } |
| 55 | } |
| 56 | |
| 57 | static void (ParsedExpression &expr, case_insensitive_set_t &handled_columns) { |
| 58 | if (expr.type == ExpressionType::COLUMN_REF) { |
| 59 | auto &child_colref = expr.Cast<ColumnRefExpression>(); |
| 60 | if (child_colref.IsQualified()) { |
| 61 | throw BinderException("PIVOT expression cannot contain qualified columns" ); |
| 62 | } |
| 63 | handled_columns.insert(x: child_colref.GetColumnName()); |
| 64 | } |
| 65 | ParsedExpressionIterator::EnumerateChildren( |
| 66 | expr, callback: [&](ParsedExpression &child) { ExtractPivotExpressions(expr&: child, handled_columns); }); |
| 67 | } |
| 68 | |
//! Names gathered while generating the pivot stages and handed from one stage builder to the next.
//! The `internal_*` vectors hold the aliases used inside the generated subqueries; the other two hold
//! the names that end up in the final PIVOT result.
struct PivotBindState {
	//! alias of each stage-1 group column (generated as "__internal_pivot_group<N>" when it had none)
	vector<string> internal_group_names;
	//! name of each group column as reported by its expression's GetName()
	vector<string> group_names;
	//! alias the user gave each aggregate (empty when none was given)
	vector<string> aggregate_names;
	//! generated alias of each aggregate inside the subqueries ("__internal_pivot_aggregate<N>")
	vector<string> internal_aggregate_names;
};
| 75 | |
| 76 | static unique_ptr<SelectNode> PivotInitialAggregate(PivotBindState &bind_state, PivotRef &ref, |
| 77 | vector<unique_ptr<ParsedExpression>> all_columns, |
| 78 | const case_insensitive_set_t &handled_columns) { |
| 79 | auto subquery_stage1 = make_uniq<SelectNode>(); |
| 80 | subquery_stage1->from_table = std::move(ref.source); |
| 81 | if (ref.groups.empty()) { |
| 82 | // if rows are not specified any columns that are not pivoted/aggregated on are added to the GROUP BY clause |
| 83 | for (auto &entry : all_columns) { |
| 84 | if (entry->type != ExpressionType::COLUMN_REF) { |
| 85 | throw InternalException("Unexpected child of pivot source - not a ColumnRef" ); |
| 86 | } |
| 87 | auto &columnref = entry->Cast<ColumnRefExpression>(); |
| 88 | if (handled_columns.find(x: columnref.GetColumnName()) == handled_columns.end()) { |
| 89 | // not handled - add to grouping set |
| 90 | subquery_stage1->groups.group_expressions.push_back( |
| 91 | x: make_uniq<ConstantExpression>(args: Value::INTEGER(value: subquery_stage1->select_list.size() + 1))); |
| 92 | subquery_stage1->select_list.push_back(x: make_uniq<ColumnRefExpression>(args: columnref.GetColumnName())); |
| 93 | } |
| 94 | } |
| 95 | } else { |
| 96 | // if rows are specified only the columns mentioned in rows are added as groups |
| 97 | for (auto &row : ref.groups) { |
| 98 | subquery_stage1->groups.group_expressions.push_back( |
| 99 | x: make_uniq<ConstantExpression>(args: Value::INTEGER(value: subquery_stage1->select_list.size() + 1))); |
| 100 | subquery_stage1->select_list.push_back(x: make_uniq<ColumnRefExpression>(args&: row)); |
| 101 | } |
| 102 | } |
| 103 | idx_t group_count = 0; |
| 104 | for (auto &expr : subquery_stage1->select_list) { |
| 105 | bind_state.group_names.push_back(x: expr->GetName()); |
| 106 | if (expr->alias.empty()) { |
| 107 | expr->alias = "__internal_pivot_group" + std::to_string(val: ++group_count); |
| 108 | } |
| 109 | bind_state.internal_group_names.push_back(x: expr->alias); |
| 110 | } |
| 111 | // group by all of the pivot values |
| 112 | idx_t pivot_count = 0; |
| 113 | for (auto &pivot_column : ref.pivots) { |
| 114 | for (auto &pivot_expr : pivot_column.pivot_expressions) { |
| 115 | if (pivot_expr->alias.empty()) { |
| 116 | pivot_expr->alias = "__internal_pivot_ref" + std::to_string(val: ++pivot_count); |
| 117 | } |
| 118 | auto pivot_alias = pivot_expr->alias; |
| 119 | subquery_stage1->groups.group_expressions.push_back( |
| 120 | x: make_uniq<ConstantExpression>(args: Value::INTEGER(value: subquery_stage1->select_list.size() + 1))); |
| 121 | subquery_stage1->select_list.push_back(x: std::move(pivot_expr)); |
| 122 | pivot_expr = make_uniq<ColumnRefExpression>(args: std::move(pivot_alias)); |
| 123 | } |
| 124 | } |
| 125 | idx_t aggregate_count = 0; |
| 126 | // finally add the aggregates |
| 127 | for (auto &aggregate : ref.aggregates) { |
| 128 | auto aggregate_alias = "__internal_pivot_aggregate" + std::to_string(val: ++aggregate_count); |
| 129 | bind_state.aggregate_names.push_back(x: aggregate->alias); |
| 130 | bind_state.internal_aggregate_names.push_back(x: aggregate_alias); |
| 131 | aggregate->alias = std::move(aggregate_alias); |
| 132 | subquery_stage1->select_list.push_back(x: std::move(aggregate)); |
| 133 | } |
| 134 | return subquery_stage1; |
| 135 | } |
| 136 | |
| 137 | static unique_ptr<SelectNode> PivotListAggregate(PivotBindState &bind_state, PivotRef &ref, |
| 138 | unique_ptr<SelectNode> subquery_stage1) { |
| 139 | auto subquery_stage2 = make_uniq<SelectNode>(); |
| 140 | // wrap the subquery of stage 1 |
| 141 | auto subquery_select = make_uniq<SelectStatement>(); |
| 142 | subquery_select->node = std::move(subquery_stage1); |
| 143 | auto subquery_ref = make_uniq<SubqueryRef>(args: std::move(subquery_select)); |
| 144 | |
| 145 | // add all of the groups |
| 146 | for (idx_t gr = 0; gr < bind_state.internal_group_names.size(); gr++) { |
| 147 | subquery_stage2->groups.group_expressions.push_back( |
| 148 | x: make_uniq<ConstantExpression>(args: Value::INTEGER(value: subquery_stage2->select_list.size() + 1))); |
| 149 | auto group_reference = make_uniq<ColumnRefExpression>(args&: bind_state.internal_group_names[gr]); |
| 150 | group_reference->alias = bind_state.internal_group_names[gr]; |
| 151 | subquery_stage2->select_list.push_back(x: std::move(group_reference)); |
| 152 | } |
| 153 | |
| 154 | // construct the list aggregates |
| 155 | for (idx_t aggr = 0; aggr < bind_state.internal_aggregate_names.size(); aggr++) { |
| 156 | auto colref = make_uniq<ColumnRefExpression>(args&: bind_state.internal_aggregate_names[aggr]); |
| 157 | vector<unique_ptr<ParsedExpression>> list_children; |
| 158 | list_children.push_back(x: std::move(colref)); |
| 159 | auto aggregate = make_uniq<FunctionExpression>(args: "list" , args: std::move(list_children)); |
| 160 | aggregate->alias = bind_state.internal_aggregate_names[aggr]; |
| 161 | subquery_stage2->select_list.push_back(x: std::move(aggregate)); |
| 162 | } |
| 163 | // construct the pivot list |
| 164 | auto pivot_name = "__internal_pivot_name" ; |
| 165 | unique_ptr<ParsedExpression> expr; |
| 166 | for (auto &pivot : ref.pivots) { |
| 167 | for (auto &pivot_expr : pivot.pivot_expressions) { |
| 168 | // coalesce(pivot::VARCHAR, 'NULL') |
| 169 | auto cast = make_uniq<CastExpression>(args: LogicalType::VARCHAR, args: std::move(pivot_expr)); |
| 170 | vector<unique_ptr<ParsedExpression>> coalesce_children; |
| 171 | coalesce_children.push_back(x: std::move(cast)); |
| 172 | coalesce_children.push_back(x: make_uniq<ConstantExpression>(args: Value("NULL" ))); |
| 173 | auto coalesce = |
| 174 | make_uniq<OperatorExpression>(args: ExpressionType::OPERATOR_COALESCE, args: std::move(coalesce_children)); |
| 175 | |
| 176 | if (!expr) { |
| 177 | expr = std::move(coalesce); |
| 178 | } else { |
| 179 | // string concat |
| 180 | vector<unique_ptr<ParsedExpression>> concat_children; |
| 181 | concat_children.push_back(x: std::move(expr)); |
| 182 | concat_children.push_back(x: make_uniq<ConstantExpression>(args: Value("_" ))); |
| 183 | concat_children.push_back(x: std::move(coalesce)); |
| 184 | auto concat = make_uniq<FunctionExpression>(args: "concat" , args: std::move(concat_children)); |
| 185 | expr = std::move(concat); |
| 186 | } |
| 187 | } |
| 188 | } |
| 189 | // list(coalesce) |
| 190 | vector<unique_ptr<ParsedExpression>> list_children; |
| 191 | list_children.push_back(x: std::move(expr)); |
| 192 | auto aggregate = make_uniq<FunctionExpression>(args: "list" , args: std::move(list_children)); |
| 193 | |
| 194 | aggregate->alias = pivot_name; |
| 195 | subquery_stage2->select_list.push_back(x: std::move(aggregate)); |
| 196 | |
| 197 | subquery_stage2->from_table = std::move(subquery_ref); |
| 198 | return subquery_stage2; |
| 199 | } |
| 200 | |
| 201 | static unique_ptr<SelectNode> PivotFinalOperator(PivotBindState &bind_state, PivotRef &ref, |
| 202 | unique_ptr<SelectNode> subquery, |
| 203 | vector<PivotValueElement> pivot_values) { |
| 204 | auto final_pivot_operator = make_uniq<SelectNode>(); |
| 205 | // wrap the subquery of stage 1 |
| 206 | auto subquery_select = make_uniq<SelectStatement>(); |
| 207 | subquery_select->node = std::move(subquery); |
| 208 | auto subquery_ref = make_uniq<SubqueryRef>(args: std::move(subquery_select)); |
| 209 | |
| 210 | auto bound_pivot = make_uniq<PivotRef>(); |
| 211 | bound_pivot->bound_pivot_values = std::move(pivot_values); |
| 212 | bound_pivot->bound_group_names = std::move(bind_state.group_names); |
| 213 | bound_pivot->bound_aggregate_names = std::move(bind_state.aggregate_names); |
| 214 | bound_pivot->source = std::move(subquery_ref); |
| 215 | |
| 216 | final_pivot_operator->select_list.push_back(x: make_uniq<StarExpression>()); |
| 217 | final_pivot_operator->from_table = std::move(bound_pivot); |
| 218 | return final_pivot_operator; |
| 219 | } |
| 220 | |
| 221 | void (BoundTableRef &node, vector<unique_ptr<Expression>> &aggregates) { |
| 222 | if (node.type != TableReferenceType::SUBQUERY) { |
| 223 | throw InternalException("Pivot - Expected a subquery" ); |
| 224 | } |
| 225 | auto &subq = node.Cast<BoundSubqueryRef>(); |
| 226 | if (subq.subquery->type != QueryNodeType::SELECT_NODE) { |
| 227 | throw InternalException("Pivot - Expected a select node" ); |
| 228 | } |
| 229 | auto &select = subq.subquery->Cast<BoundSelectNode>(); |
| 230 | if (select.from_table->type != TableReferenceType::SUBQUERY) { |
| 231 | throw InternalException("Pivot - Expected another subquery" ); |
| 232 | } |
| 233 | auto &subq2 = select.from_table->Cast<BoundSubqueryRef>(); |
| 234 | if (subq2.subquery->type != QueryNodeType::SELECT_NODE) { |
| 235 | throw InternalException("Pivot - Expected another select node" ); |
| 236 | } |
| 237 | auto &select2 = subq2.subquery->Cast<BoundSelectNode>(); |
| 238 | for (auto &aggr : select2.aggregates) { |
| 239 | aggregates.push_back(x: aggr->Copy()); |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | unique_ptr<BoundTableRef> Binder::BindBoundPivot(PivotRef &ref) { |
| 244 | // bind the child table in a child binder |
| 245 | auto result = make_uniq<BoundPivotRef>(); |
| 246 | result->bind_index = GenerateTableIndex(); |
| 247 | result->child_binder = Binder::CreateBinder(context, parent: this); |
| 248 | result->child = result->child_binder->Bind(ref&: *ref.source); |
| 249 | |
| 250 | auto &aggregates = result->bound_pivot.aggregates; |
| 251 | ExtractPivotAggregates(node&: *result->child, aggregates); |
| 252 | if (aggregates.size() != ref.bound_aggregate_names.size()) { |
| 253 | throw BinderException("Pivot aggregate count mismatch. Expected %llu aggregates but found %llu. Are all pivot " |
| 254 | "expressions aggregate functions?" , |
| 255 | ref.bound_aggregate_names.size(), aggregates.size()); |
| 256 | } |
| 257 | |
| 258 | vector<string> child_names; |
| 259 | vector<LogicalType> child_types; |
| 260 | result->child_binder->bind_context.GetTypesAndNames(result_names&: child_names, result_types&: child_types); |
| 261 | |
| 262 | vector<string> names; |
| 263 | vector<LogicalType> types; |
| 264 | // emit the groups |
| 265 | for (idx_t i = 0; i < ref.bound_group_names.size(); i++) { |
| 266 | names.push_back(x: ref.bound_group_names[i]); |
| 267 | types.push_back(x: child_types[i]); |
| 268 | } |
| 269 | // emit the pivot columns |
| 270 | for (auto &pivot_value : ref.bound_pivot_values) { |
| 271 | for (idx_t aggr_idx = 0; aggr_idx < ref.bound_aggregate_names.size(); aggr_idx++) { |
| 272 | auto &aggr = aggregates[aggr_idx]; |
| 273 | auto &aggr_name = ref.bound_aggregate_names[aggr_idx]; |
| 274 | auto name = pivot_value.name; |
| 275 | if (aggregates.size() > 1 || !aggr_name.empty()) { |
| 276 | // if there are multiple aggregates specified we add the name of the aggregate as well |
| 277 | name += "_" + (aggr_name.empty() ? aggr->GetName() : aggr_name); |
| 278 | } |
| 279 | string pivot_str; |
| 280 | for (auto &value : pivot_value.values) { |
| 281 | auto str = value.ToString(); |
| 282 | if (pivot_str.empty()) { |
| 283 | pivot_str = std::move(str); |
| 284 | } else { |
| 285 | pivot_str += "_" + str; |
| 286 | } |
| 287 | } |
| 288 | result->bound_pivot.pivot_values.push_back(x: std::move(pivot_str)); |
| 289 | names.push_back(x: std::move(name)); |
| 290 | types.push_back(x: aggr->return_type); |
| 291 | } |
| 292 | } |
| 293 | result->bound_pivot.group_count = ref.bound_group_names.size(); |
| 294 | result->bound_pivot.types = types; |
| 295 | auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; |
| 296 | bind_context.AddGenericBinding(index: result->bind_index, alias: subquery_alias, names, types); |
| 297 | MoveCorrelatedExpressions(other&: *result->child_binder); |
| 298 | return std::move(result); |
| 299 | } |
| 300 | |
| 301 | unique_ptr<SelectNode> Binder::BindPivot(PivotRef &ref, vector<unique_ptr<ParsedExpression>> all_columns) { |
| 302 | // keep track of the columns by which we pivot/aggregate |
| 303 | // any columns which are not pivoted/aggregated on are added to the GROUP BY clause |
| 304 | case_insensitive_set_t handled_columns; |
| 305 | // parse the aggregate, and extract the referenced columns from the aggregate |
| 306 | for (auto &aggr : ref.aggregates) { |
| 307 | if (aggr->type != ExpressionType::FUNCTION) { |
| 308 | throw BinderException(FormatError(expr_context&: *aggr, message: "Pivot expression must be an aggregate" )); |
| 309 | } |
| 310 | if (aggr->HasSubquery()) { |
| 311 | throw BinderException(FormatError(expr_context&: *aggr, message: "Pivot expression cannot contain subqueries" )); |
| 312 | } |
| 313 | if (aggr->IsWindow()) { |
| 314 | throw BinderException(FormatError(expr_context&: *aggr, message: "Pivot expression cannot contain window functions" )); |
| 315 | } |
| 316 | ExtractPivotExpressions(expr&: *aggr, handled_columns); |
| 317 | } |
| 318 | value_set_t pivots; |
| 319 | |
| 320 | // first add all pivots to the set of handled columns, and check for duplicates |
| 321 | idx_t total_pivots = 1; |
| 322 | for (auto &pivot : ref.pivots) { |
| 323 | if (!pivot.pivot_enum.empty()) { |
| 324 | auto type = Catalog::GetType(context, INVALID_CATALOG, INVALID_SCHEMA, name: pivot.pivot_enum); |
| 325 | if (type.id() != LogicalTypeId::ENUM) { |
| 326 | throw BinderException( |
| 327 | FormatError(ref_context&: ref, message: StringUtil::Format(fmt_str: "Pivot must reference an ENUM type: \"%s\" is of type \"%s\"" , |
| 328 | params: pivot.pivot_enum, params: type.ToString()))); |
| 329 | } |
| 330 | auto enum_size = EnumType::GetSize(type); |
| 331 | for (idx_t i = 0; i < enum_size; i++) { |
| 332 | auto enum_value = EnumType::GetValue(val: Value::ENUM(value: i, original_type: type)); |
| 333 | PivotColumnEntry entry; |
| 334 | entry.values.emplace_back(args&: enum_value); |
| 335 | entry.alias = std::move(enum_value); |
| 336 | pivot.entries.push_back(x: std::move(entry)); |
| 337 | } |
| 338 | } |
| 339 | total_pivots *= pivot.entries.size(); |
| 340 | // add the pivoted column to the columns that have been handled |
| 341 | for (auto &pivot_name : pivot.pivot_expressions) { |
| 342 | ExtractPivotExpressions(expr&: *pivot_name, handled_columns); |
| 343 | } |
| 344 | value_set_t pivots; |
| 345 | for (auto &entry : pivot.entries) { |
| 346 | D_ASSERT(!entry.star_expr); |
| 347 | Value val; |
| 348 | if (entry.values.size() == 1) { |
| 349 | val = entry.values[0]; |
| 350 | } else { |
| 351 | val = Value::LIST(child_type: LogicalType::VARCHAR, values: entry.values); |
| 352 | } |
| 353 | if (pivots.find(x: val) != pivots.end()) { |
| 354 | throw BinderException(FormatError( |
| 355 | ref_context&: ref, message: StringUtil::Format(fmt_str: "The value \"%s\" was specified multiple times in the IN clause" , |
| 356 | params: val.ToString()))); |
| 357 | } |
| 358 | if (entry.values.size() != pivot.pivot_expressions.size()) { |
| 359 | throw ParserException("PIVOT IN list - inconsistent amount of rows - expected %d but got %d" , |
| 360 | pivot.pivot_expressions.size(), entry.values.size()); |
| 361 | } |
| 362 | pivots.insert(x: val); |
| 363 | } |
| 364 | } |
| 365 | auto pivot_limit = ClientConfig::GetConfig(context).pivot_limit; |
| 366 | if (total_pivots >= pivot_limit) { |
| 367 | throw BinderException("Pivot column limit of %llu exceeded. Use SET pivot_limit=X to increase the limit." , |
| 368 | ClientConfig::GetConfig(context).pivot_limit); |
| 369 | } |
| 370 | |
| 371 | // construct the required pivot values recursively |
| 372 | vector<PivotValueElement> pivot_values; |
| 373 | ConstructPivots(ref, pivot_values); |
| 374 | |
| 375 | // pivots have three components |
| 376 | // - the pivots (i.e. future column names) |
| 377 | // - the groups (i.e. the future row names |
| 378 | // - the aggregates (i.e. the values of the pivot columns) |
| 379 | |
| 380 | // executing a pivot statement happens in three stages |
| 381 | // 1) execute the query "SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots} |
| 382 | // this computes all values that are required in the final result, but not yet in the correct orientation |
| 383 | // 2) execute the query "SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups} |
| 384 | // this pushes all pivots and aggregates that belong to a specific group together in an aligned manner |
| 385 | // 3) push a PIVOT operator, that performs the actual pivoting of the values into the different columns |
| 386 | |
| 387 | PivotBindState bind_state; |
| 388 | // Pivot Stage 1 |
| 389 | // SELECT {groups}, {pivots}, {aggregates} FROM {from_clause} GROUP BY {groups}, {pivots} |
| 390 | auto subquery_stage1 = PivotInitialAggregate(bind_state, ref, all_columns: std::move(all_columns), handled_columns); |
| 391 | |
| 392 | // Pivot stage 2 |
| 393 | // SELECT {groups}, LIST({pivots}), LIST({aggregates}) FROM [Q1] GROUP BY {groups} |
| 394 | auto subquery_stage2 = PivotListAggregate(bind_state, ref, subquery_stage1: std::move(subquery_stage1)); |
| 395 | |
| 396 | // Pivot stage 3 |
| 397 | // construct the final pivot operator |
| 398 | auto pivot_node = PivotFinalOperator(bind_state, ref, subquery: std::move(subquery_stage2), pivot_values: std::move(pivot_values)); |
| 399 | return pivot_node; |
| 400 | } |
| 401 | |
| 402 | unique_ptr<SelectNode> Binder::BindUnpivot(Binder &child_binder, PivotRef &ref, |
| 403 | vector<unique_ptr<ParsedExpression>> all_columns, |
| 404 | unique_ptr<ParsedExpression> &where_clause) { |
| 405 | D_ASSERT(ref.groups.empty()); |
| 406 | D_ASSERT(ref.pivots.size() == 1); |
| 407 | |
| 408 | unique_ptr<ParsedExpression> expr; |
| 409 | auto select_node = make_uniq<SelectNode>(); |
| 410 | select_node->from_table = std::move(ref.source); |
| 411 | |
| 412 | // handle the pivot |
| 413 | auto &unpivot = ref.pivots[0]; |
| 414 | |
| 415 | // handle star expressions in any entries |
| 416 | vector<PivotColumnEntry> new_entries; |
| 417 | for (auto &entry : unpivot.entries) { |
| 418 | if (entry.star_expr) { |
| 419 | D_ASSERT(entry.values.empty()); |
| 420 | vector<unique_ptr<ParsedExpression>> star_columns; |
| 421 | child_binder.ExpandStarExpression(expr: std::move(entry.star_expr), new_select_list&: star_columns); |
| 422 | |
| 423 | for (auto &col : star_columns) { |
| 424 | if (col->type != ExpressionType::COLUMN_REF) { |
| 425 | throw InternalException("Unexpected child of unpivot star - not a ColumnRef" ); |
| 426 | } |
| 427 | auto &columnref = col->Cast<ColumnRefExpression>(); |
| 428 | PivotColumnEntry new_entry; |
| 429 | new_entry.values.emplace_back(args: columnref.GetColumnName()); |
| 430 | new_entry.alias = columnref.GetColumnName(); |
| 431 | new_entries.push_back(x: std::move(new_entry)); |
| 432 | } |
| 433 | } else { |
| 434 | new_entries.push_back(x: std::move(entry)); |
| 435 | } |
| 436 | } |
| 437 | unpivot.entries = std::move(new_entries); |
| 438 | |
| 439 | case_insensitive_set_t handled_columns; |
| 440 | case_insensitive_map_t<string> name_map; |
| 441 | for (auto &entry : unpivot.entries) { |
| 442 | for (auto &value : entry.values) { |
| 443 | handled_columns.insert(x: value.ToString()); |
| 444 | } |
| 445 | } |
| 446 | |
| 447 | for (auto &col_expr : all_columns) { |
| 448 | if (col_expr->type != ExpressionType::COLUMN_REF) { |
| 449 | throw InternalException("Unexpected child of pivot source - not a ColumnRef" ); |
| 450 | } |
| 451 | auto &columnref = col_expr->Cast<ColumnRefExpression>(); |
| 452 | auto &column_name = columnref.GetColumnName(); |
| 453 | auto entry = handled_columns.find(x: column_name); |
| 454 | if (entry == handled_columns.end()) { |
| 455 | // not handled - add to the set of regularly selected columns |
| 456 | select_node->select_list.push_back(x: std::move(col_expr)); |
| 457 | } else { |
| 458 | name_map[column_name] = column_name; |
| 459 | handled_columns.erase(position: entry); |
| 460 | } |
| 461 | } |
| 462 | if (!handled_columns.empty()) { |
| 463 | for (auto &entry : handled_columns) { |
| 464 | throw BinderException("Column \"%s\" referenced in UNPIVOT but no matching entry was found in the table" , |
| 465 | entry); |
| 466 | } |
| 467 | } |
| 468 | vector<Value> unpivot_names; |
| 469 | for (auto &entry : unpivot.entries) { |
| 470 | string generated_name; |
| 471 | for (auto &val : entry.values) { |
| 472 | auto name_entry = name_map.find(x: val.ToString()); |
| 473 | if (name_entry == name_map.end()) { |
| 474 | throw InternalException("Unpivot - could not find column name in name map" ); |
| 475 | } |
| 476 | if (!generated_name.empty()) { |
| 477 | generated_name += "_" ; |
| 478 | } |
| 479 | generated_name += name_entry->second; |
| 480 | } |
| 481 | unpivot_names.emplace_back(args&: !entry.alias.empty() ? entry.alias : generated_name); |
| 482 | } |
| 483 | vector<vector<unique_ptr<ParsedExpression>>> unpivot_expressions; |
| 484 | for (idx_t v_idx = 1; v_idx < unpivot.entries.size(); v_idx++) { |
| 485 | if (unpivot.entries[v_idx].values.size() != unpivot.entries[0].values.size()) { |
| 486 | throw BinderException( |
| 487 | "UNPIVOT value count mismatch - entry has %llu values, but expected all entries to have %llu values" , |
| 488 | unpivot.entries[v_idx].values.size(), unpivot.entries[0].values.size()); |
| 489 | } |
| 490 | } |
| 491 | |
| 492 | for (idx_t v_idx = 0; v_idx < unpivot.entries[0].values.size(); v_idx++) { |
| 493 | vector<unique_ptr<ParsedExpression>> expressions; |
| 494 | expressions.reserve(n: unpivot.entries.size()); |
| 495 | for (auto &entry : unpivot.entries) { |
| 496 | expressions.push_back(x: make_uniq<ColumnRefExpression>(args: entry.values[v_idx].ToString())); |
| 497 | } |
| 498 | unpivot_expressions.push_back(x: std::move(expressions)); |
| 499 | } |
| 500 | |
| 501 | // construct the UNNEST expression for the set of names (constant) |
| 502 | auto unpivot_list = Value::LIST(child_type: LogicalType::VARCHAR, values: std::move(unpivot_names)); |
| 503 | auto unpivot_name_expr = make_uniq<ConstantExpression>(args: std::move(unpivot_list)); |
| 504 | vector<unique_ptr<ParsedExpression>> unnest_name_children; |
| 505 | unnest_name_children.push_back(x: std::move(unpivot_name_expr)); |
| 506 | auto unnest_name_expr = make_uniq<FunctionExpression>(args: "unnest" , args: std::move(unnest_name_children)); |
| 507 | unnest_name_expr->alias = unpivot.unpivot_names[0]; |
| 508 | select_node->select_list.push_back(x: std::move(unnest_name_expr)); |
| 509 | |
| 510 | // construct the UNNEST expression for the set of unpivoted columns |
| 511 | if (ref.unpivot_names.size() != unpivot_expressions.size()) { |
| 512 | throw BinderException("UNPIVOT name count mismatch - got %d names but %d expressions" , ref.unpivot_names.size(), |
| 513 | unpivot_expressions.size()); |
| 514 | } |
| 515 | for (idx_t i = 0; i < unpivot_expressions.size(); i++) { |
| 516 | auto list_expr = make_uniq<FunctionExpression>(args: "list_value" , args: std::move(unpivot_expressions[i])); |
| 517 | vector<unique_ptr<ParsedExpression>> unnest_val_children; |
| 518 | unnest_val_children.push_back(x: std::move(list_expr)); |
| 519 | auto unnest_val_expr = make_uniq<FunctionExpression>(args: "unnest" , args: std::move(unnest_val_children)); |
| 520 | auto unnest_name = i < ref.column_name_alias.size() ? ref.column_name_alias[i] : ref.unpivot_names[i]; |
| 521 | unnest_val_expr->alias = unnest_name; |
| 522 | select_node->select_list.push_back(x: std::move(unnest_val_expr)); |
| 523 | if (!ref.include_nulls) { |
| 524 | // if we are running with EXCLUDE NULLS we need to add an IS NOT NULL filter |
| 525 | auto colref = make_uniq<ColumnRefExpression>(args&: unnest_name); |
| 526 | auto filter = make_uniq<OperatorExpression>(args: ExpressionType::OPERATOR_IS_NOT_NULL, args: std::move(colref)); |
| 527 | if (where_clause) { |
| 528 | where_clause = make_uniq<ConjunctionExpression>(args: ExpressionType::CONJUNCTION_AND, |
| 529 | args: std::move(where_clause), args: std::move(filter)); |
| 530 | } else { |
| 531 | where_clause = std::move(filter); |
| 532 | } |
| 533 | } |
| 534 | } |
| 535 | return select_node; |
| 536 | } |
| 537 | |
| 538 | unique_ptr<BoundTableRef> Binder::Bind(PivotRef &ref) { |
| 539 | if (!ref.source) { |
| 540 | throw InternalException("Pivot without a source!?" ); |
| 541 | } |
| 542 | if (!ref.bound_pivot_values.empty() || !ref.bound_group_names.empty() || !ref.bound_aggregate_names.empty()) { |
| 543 | // bound pivot |
| 544 | return BindBoundPivot(ref); |
| 545 | } |
| 546 | |
| 547 | // bind the source of the pivot |
| 548 | // we need to do this to be able to expand star expressions |
| 549 | if (ref.source->type == TableReferenceType::SUBQUERY && ref.source->alias.empty()) { |
| 550 | ref.source->alias = "__internal_pivot_alias_" + to_string(val: GenerateTableIndex()); |
| 551 | } |
| 552 | auto copied_source = ref.source->Copy(); |
| 553 | auto star_binder = Binder::CreateBinder(context, parent: this); |
| 554 | star_binder->Bind(ref&: *copied_source); |
| 555 | |
| 556 | // figure out the set of column names that are in the source of the pivot |
| 557 | vector<unique_ptr<ParsedExpression>> all_columns; |
| 558 | star_binder->ExpandStarExpression(expr: make_uniq<StarExpression>(), new_select_list&: all_columns); |
| 559 | |
| 560 | unique_ptr<SelectNode> select_node; |
| 561 | unique_ptr<ParsedExpression> where_clause; |
| 562 | if (!ref.aggregates.empty()) { |
| 563 | select_node = BindPivot(ref, all_columns: std::move(all_columns)); |
| 564 | } else { |
| 565 | select_node = BindUnpivot(child_binder&: *star_binder, ref, all_columns: std::move(all_columns), where_clause); |
| 566 | } |
| 567 | // bind the generated select node |
| 568 | auto child_binder = Binder::CreateBinder(context, parent: this); |
| 569 | auto bound_select_node = child_binder->BindNode(node&: *select_node); |
| 570 | auto root_index = bound_select_node->GetRootIndex(); |
| 571 | BoundQueryNode *bound_select_ptr = bound_select_node.get(); |
| 572 | |
| 573 | unique_ptr<BoundTableRef> result; |
| 574 | MoveCorrelatedExpressions(other&: *child_binder); |
| 575 | result = make_uniq<BoundSubqueryRef>(args: std::move(child_binder), args: std::move(bound_select_node)); |
| 576 | auto subquery_alias = ref.alias.empty() ? "__unnamed_pivot" : ref.alias; |
| 577 | SubqueryRef subquery_ref(nullptr, subquery_alias); |
| 578 | subquery_ref.column_name_alias = std::move(ref.column_name_alias); |
| 579 | if (where_clause) { |
| 580 | // if a WHERE clause was provided - bind a subquery holding the WHERE clause |
| 581 | // we need to bind a new subquery here because the WHERE clause has to be applied AFTER the unnest |
| 582 | child_binder = Binder::CreateBinder(context, parent: this); |
| 583 | child_binder->bind_context.AddSubquery(index: root_index, alias: subquery_ref.alias, ref&: subquery_ref, subquery&: *bound_select_ptr); |
| 584 | auto where_query = make_uniq<SelectNode>(); |
| 585 | where_query->select_list.push_back(x: make_uniq<StarExpression>()); |
| 586 | where_query->where_clause = std::move(where_clause); |
| 587 | bound_select_node = child_binder->BindSelectNode(statement&: *where_query, from_table: std::move(result)); |
| 588 | bound_select_ptr = bound_select_node.get(); |
| 589 | root_index = bound_select_node->GetRootIndex(); |
| 590 | result = make_uniq<BoundSubqueryRef>(args: std::move(child_binder), args: std::move(bound_select_node)); |
| 591 | } |
| 592 | bind_context.AddSubquery(index: root_index, alias: subquery_ref.alias, ref&: subquery_ref, subquery&: *bound_select_ptr); |
| 593 | return result; |
| 594 | } |
| 595 | |
| 596 | } // namespace duckdb |
| 597 | |