1 | #include "duckdb/execution/operator/set/physical_recursive_cte.hpp" |
2 | |
3 | #include "duckdb/common/types/column/column_data_collection.hpp" |
4 | #include "duckdb/common/vector_operations/vector_operations.hpp" |
5 | #include "duckdb/execution/aggregate_hashtable.hpp" |
6 | #include "duckdb/execution/executor.hpp" |
7 | #include "duckdb/parallel/event.hpp" |
8 | #include "duckdb/parallel/meta_pipeline.hpp" |
9 | #include "duckdb/parallel/pipeline.hpp" |
10 | #include "duckdb/parallel/task_scheduler.hpp" |
11 | #include "duckdb/storage/buffer_manager.hpp" |
12 | |
13 | namespace duckdb { |
14 | |
15 | PhysicalRecursiveCTE::PhysicalRecursiveCTE(vector<LogicalType> types, bool union_all, unique_ptr<PhysicalOperator> top, |
16 | unique_ptr<PhysicalOperator> bottom, idx_t estimated_cardinality) |
17 | : PhysicalOperator(PhysicalOperatorType::RECURSIVE_CTE, std::move(types), estimated_cardinality), |
18 | union_all(union_all) { |
19 | children.push_back(x: std::move(top)); |
20 | children.push_back(x: std::move(bottom)); |
21 | } |
22 | |
23 | PhysicalRecursiveCTE::~PhysicalRecursiveCTE() { |
24 | } |
25 | |
26 | //===--------------------------------------------------------------------===// |
27 | // Sink |
28 | //===--------------------------------------------------------------------===// |
29 | class RecursiveCTEState : public GlobalSinkState { |
30 | public: |
31 | explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) |
32 | : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) { |
33 | ht = make_uniq<GroupedAggregateHashTable>(args&: context, args&: Allocator::Get(context), args: op.types, args: vector<LogicalType>(), |
34 | args: vector<BoundAggregateExpression *>()); |
35 | } |
36 | |
37 | unique_ptr<GroupedAggregateHashTable> ht; |
38 | |
39 | bool intermediate_empty = true; |
40 | ColumnDataCollection intermediate_table; |
41 | ColumnDataScanState scan_state; |
42 | bool initialized = false; |
43 | bool finished_scan = false; |
44 | SelectionVector new_groups; |
45 | AggregateHTAppendState append_state; |
46 | }; |
47 | |
48 | unique_ptr<GlobalSinkState> PhysicalRecursiveCTE::GetGlobalSinkState(ClientContext &context) const { |
49 | return make_uniq<RecursiveCTEState>(args&: context, args: *this); |
50 | } |
51 | |
52 | idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const { |
53 | Vector dummy_addresses(LogicalType::POINTER); |
54 | |
55 | // Use the HT to eliminate duplicate rows |
56 | idx_t new_group_count = state.ht->FindOrCreateGroups(state&: state.append_state, groups&: chunk, addresses_out&: dummy_addresses, new_groups_out&: state.new_groups); |
57 | |
58 | // we only return entries we have not seen before (i.e. new groups) |
59 | chunk.Slice(sel_vector: state.new_groups, count: new_group_count); |
60 | |
61 | return new_group_count; |
62 | } |
63 | |
64 | SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { |
65 | auto &gstate = input.global_state.Cast<RecursiveCTEState>(); |
66 | if (!union_all) { |
67 | idx_t match_count = ProbeHT(chunk, state&: gstate); |
68 | if (match_count > 0) { |
69 | gstate.intermediate_table.Append(new_chunk&: chunk); |
70 | } |
71 | } else { |
72 | gstate.intermediate_table.Append(new_chunk&: chunk); |
73 | } |
74 | return SinkResultType::NEED_MORE_INPUT; |
75 | } |
76 | |
77 | //===--------------------------------------------------------------------===// |
78 | // Source |
79 | //===--------------------------------------------------------------------===// |
80 | SourceResultType PhysicalRecursiveCTE::GetData(ExecutionContext &context, DataChunk &chunk, |
81 | OperatorSourceInput &input) const { |
82 | auto &gstate = sink_state->Cast<RecursiveCTEState>(); |
83 | if (!gstate.initialized) { |
84 | gstate.intermediate_table.InitializeScan(state&: gstate.scan_state); |
85 | gstate.finished_scan = false; |
86 | gstate.initialized = true; |
87 | } |
88 | while (chunk.size() == 0) { |
89 | if (!gstate.finished_scan) { |
90 | // scan any chunks we have collected so far |
91 | gstate.intermediate_table.Scan(state&: gstate.scan_state, result&: chunk); |
92 | if (chunk.size() == 0) { |
93 | gstate.finished_scan = true; |
94 | } else { |
95 | break; |
96 | } |
97 | } else { |
98 | // we have run out of chunks |
99 | // now we need to recurse |
100 | // we set up the working table as the data we gathered in this iteration of the recursion |
101 | working_table->Reset(); |
102 | working_table->Combine(other&: gstate.intermediate_table); |
103 | // and we clear the intermediate table |
104 | gstate.finished_scan = false; |
105 | gstate.intermediate_table.Reset(); |
106 | // now we need to re-execute all of the pipelines that depend on the recursion |
107 | ExecuteRecursivePipelines(context); |
108 | |
109 | // check if we obtained any results |
110 | // if not, we are done |
111 | if (gstate.intermediate_table.Count() == 0) { |
112 | gstate.finished_scan = true; |
113 | break; |
114 | } |
115 | // set up the scan again |
116 | gstate.intermediate_table.InitializeScan(state&: gstate.scan_state); |
117 | } |
118 | } |
119 | |
120 | return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; |
121 | } |
122 | |
123 | void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { |
124 | if (!recursive_meta_pipeline) { |
125 | throw InternalException("Missing meta pipeline for recursive CTE" ); |
126 | } |
127 | D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE()); |
128 | |
129 | // get and reset pipelines |
130 | vector<shared_ptr<Pipeline>> pipelines; |
131 | recursive_meta_pipeline->GetPipelines(result&: pipelines, recursive: true); |
132 | for (auto &pipeline : pipelines) { |
133 | auto sink = pipeline->GetSink(); |
134 | if (sink.get() != this) { |
135 | sink->sink_state.reset(); |
136 | } |
137 | for (auto &op_ref : pipeline->GetOperators()) { |
138 | auto &op = op_ref.get(); |
139 | op.op_state.reset(); |
140 | } |
141 | pipeline->ClearSource(); |
142 | } |
143 | |
144 | // get the MetaPipelines in the recursive_meta_pipeline and reschedule them |
145 | vector<shared_ptr<MetaPipeline>> meta_pipelines; |
146 | recursive_meta_pipeline->GetMetaPipelines(result&: meta_pipelines, recursive: true, skip: false); |
147 | auto &executor = recursive_meta_pipeline->GetExecutor(); |
148 | vector<shared_ptr<Event>> events; |
149 | executor.ReschedulePipelines(pipelines: meta_pipelines, events); |
150 | |
151 | while (true) { |
152 | executor.WorkOnTasks(); |
153 | if (executor.HasError()) { |
154 | executor.ThrowException(); |
155 | } |
156 | bool finished = true; |
157 | for (auto &event : events) { |
158 | if (!event->IsFinished()) { |
159 | finished = false; |
160 | break; |
161 | } |
162 | } |
163 | if (finished) { |
164 | // all pipelines finished: done! |
165 | break; |
166 | } |
167 | } |
168 | } |
169 | |
170 | //===--------------------------------------------------------------------===// |
171 | // Pipeline Construction |
172 | //===--------------------------------------------------------------------===// |
173 | void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { |
174 | op_state.reset(); |
175 | sink_state.reset(); |
176 | recursive_meta_pipeline.reset(); |
177 | |
178 | auto &state = meta_pipeline.GetState(); |
179 | state.SetPipelineSource(pipeline&: current, op&: *this); |
180 | |
181 | auto &executor = meta_pipeline.GetExecutor(); |
182 | executor.AddRecursiveCTE(rec_cte&: *this); |
183 | |
184 | // the LHS of the recursive CTE is our initial state |
185 | auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op&: *this); |
186 | initial_state_pipeline.Build(op&: *children[0]); |
187 | |
188 | // the RHS is the recursive pipeline |
189 | recursive_meta_pipeline = make_shared<MetaPipeline>(args&: executor, args&: state, args: this); |
190 | recursive_meta_pipeline->SetRecursiveCTE(); |
191 | recursive_meta_pipeline->Build(op&: *children[1]); |
192 | } |
193 | |
194 | vector<const_reference<PhysicalOperator>> PhysicalRecursiveCTE::GetSources() const { |
195 | return {*this}; |
196 | } |
197 | |
198 | } // namespace duckdb |
199 | |