1#include "duckdb/execution/operator/set/physical_recursive_cte.hpp"
2
3#include "duckdb/common/types/column/column_data_collection.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/execution/aggregate_hashtable.hpp"
6#include "duckdb/execution/executor.hpp"
7#include "duckdb/parallel/event.hpp"
8#include "duckdb/parallel/meta_pipeline.hpp"
9#include "duckdb/parallel/pipeline.hpp"
10#include "duckdb/parallel/task_scheduler.hpp"
11#include "duckdb/storage/buffer_manager.hpp"
12
13namespace duckdb {
14
15PhysicalRecursiveCTE::PhysicalRecursiveCTE(vector<LogicalType> types, bool union_all, unique_ptr<PhysicalOperator> top,
16 unique_ptr<PhysicalOperator> bottom, idx_t estimated_cardinality)
17 : PhysicalOperator(PhysicalOperatorType::RECURSIVE_CTE, std::move(types), estimated_cardinality),
18 union_all(union_all) {
19 children.push_back(x: std::move(top));
20 children.push_back(x: std::move(bottom));
21}
22
23PhysicalRecursiveCTE::~PhysicalRecursiveCTE() {
24}
25
26//===--------------------------------------------------------------------===//
27// Sink
28//===--------------------------------------------------------------------===//
29class RecursiveCTEState : public GlobalSinkState {
30public:
31 explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op)
32 : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) {
33 ht = make_uniq<GroupedAggregateHashTable>(args&: context, args&: Allocator::Get(context), args: op.types, args: vector<LogicalType>(),
34 args: vector<BoundAggregateExpression *>());
35 }
36
37 unique_ptr<GroupedAggregateHashTable> ht;
38
39 bool intermediate_empty = true;
40 ColumnDataCollection intermediate_table;
41 ColumnDataScanState scan_state;
42 bool initialized = false;
43 bool finished_scan = false;
44 SelectionVector new_groups;
45 AggregateHTAppendState append_state;
46};
47
48unique_ptr<GlobalSinkState> PhysicalRecursiveCTE::GetGlobalSinkState(ClientContext &context) const {
49 return make_uniq<RecursiveCTEState>(args&: context, args: *this);
50}
51
52idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const {
53 Vector dummy_addresses(LogicalType::POINTER);
54
55 // Use the HT to eliminate duplicate rows
56 idx_t new_group_count = state.ht->FindOrCreateGroups(state&: state.append_state, groups&: chunk, addresses_out&: dummy_addresses, new_groups_out&: state.new_groups);
57
58 // we only return entries we have not seen before (i.e. new groups)
59 chunk.Slice(sel_vector: state.new_groups, count: new_group_count);
60
61 return new_group_count;
62}
63
64SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const {
65 auto &gstate = input.global_state.Cast<RecursiveCTEState>();
66 if (!union_all) {
67 idx_t match_count = ProbeHT(chunk, state&: gstate);
68 if (match_count > 0) {
69 gstate.intermediate_table.Append(new_chunk&: chunk);
70 }
71 } else {
72 gstate.intermediate_table.Append(new_chunk&: chunk);
73 }
74 return SinkResultType::NEED_MORE_INPUT;
75}
76
77//===--------------------------------------------------------------------===//
78// Source
79//===--------------------------------------------------------------------===//
80SourceResultType PhysicalRecursiveCTE::GetData(ExecutionContext &context, DataChunk &chunk,
81 OperatorSourceInput &input) const {
82 auto &gstate = sink_state->Cast<RecursiveCTEState>();
83 if (!gstate.initialized) {
84 gstate.intermediate_table.InitializeScan(state&: gstate.scan_state);
85 gstate.finished_scan = false;
86 gstate.initialized = true;
87 }
88 while (chunk.size() == 0) {
89 if (!gstate.finished_scan) {
90 // scan any chunks we have collected so far
91 gstate.intermediate_table.Scan(state&: gstate.scan_state, result&: chunk);
92 if (chunk.size() == 0) {
93 gstate.finished_scan = true;
94 } else {
95 break;
96 }
97 } else {
98 // we have run out of chunks
99 // now we need to recurse
100 // we set up the working table as the data we gathered in this iteration of the recursion
101 working_table->Reset();
102 working_table->Combine(other&: gstate.intermediate_table);
103 // and we clear the intermediate table
104 gstate.finished_scan = false;
105 gstate.intermediate_table.Reset();
106 // now we need to re-execute all of the pipelines that depend on the recursion
107 ExecuteRecursivePipelines(context);
108
109 // check if we obtained any results
110 // if not, we are done
111 if (gstate.intermediate_table.Count() == 0) {
112 gstate.finished_scan = true;
113 break;
114 }
115 // set up the scan again
116 gstate.intermediate_table.InitializeScan(state&: gstate.scan_state);
117 }
118 }
119
120 return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
121}
122
123void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const {
124 if (!recursive_meta_pipeline) {
125 throw InternalException("Missing meta pipeline for recursive CTE");
126 }
127 D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE());
128
129 // get and reset pipelines
130 vector<shared_ptr<Pipeline>> pipelines;
131 recursive_meta_pipeline->GetPipelines(result&: pipelines, recursive: true);
132 for (auto &pipeline : pipelines) {
133 auto sink = pipeline->GetSink();
134 if (sink.get() != this) {
135 sink->sink_state.reset();
136 }
137 for (auto &op_ref : pipeline->GetOperators()) {
138 auto &op = op_ref.get();
139 op.op_state.reset();
140 }
141 pipeline->ClearSource();
142 }
143
144 // get the MetaPipelines in the recursive_meta_pipeline and reschedule them
145 vector<shared_ptr<MetaPipeline>> meta_pipelines;
146 recursive_meta_pipeline->GetMetaPipelines(result&: meta_pipelines, recursive: true, skip: false);
147 auto &executor = recursive_meta_pipeline->GetExecutor();
148 vector<shared_ptr<Event>> events;
149 executor.ReschedulePipelines(pipelines: meta_pipelines, events);
150
151 while (true) {
152 executor.WorkOnTasks();
153 if (executor.HasError()) {
154 executor.ThrowException();
155 }
156 bool finished = true;
157 for (auto &event : events) {
158 if (!event->IsFinished()) {
159 finished = false;
160 break;
161 }
162 }
163 if (finished) {
164 // all pipelines finished: done!
165 break;
166 }
167 }
168}
169
170//===--------------------------------------------------------------------===//
171// Pipeline Construction
172//===--------------------------------------------------------------------===//
173void PhysicalRecursiveCTE::BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) {
174 op_state.reset();
175 sink_state.reset();
176 recursive_meta_pipeline.reset();
177
178 auto &state = meta_pipeline.GetState();
179 state.SetPipelineSource(pipeline&: current, op&: *this);
180
181 auto &executor = meta_pipeline.GetExecutor();
182 executor.AddRecursiveCTE(rec_cte&: *this);
183
184 // the LHS of the recursive CTE is our initial state
185 auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op&: *this);
186 initial_state_pipeline.Build(op&: *children[0]);
187
188 // the RHS is the recursive pipeline
189 recursive_meta_pipeline = make_shared<MetaPipeline>(args&: executor, args&: state, args: this);
190 recursive_meta_pipeline->SetRecursiveCTE();
191 recursive_meta_pipeline->Build(op&: *children[1]);
192}
193
194vector<const_reference<PhysicalOperator>> PhysicalRecursiveCTE::GetSources() const {
195 return {*this};
196}
197
198} // namespace duckdb
199