1#include "duckdb/execution/operator/set/physical_recursive_cte.hpp"
2
3#include "duckdb/common/vector_operations/vector_operations.hpp"
4
5#include "duckdb/common/types/chunk_collection.hpp"
6#include "duckdb/execution/aggregate_hashtable.hpp"
7
8using namespace duckdb;
9using namespace std;
10
11class PhysicalRecursiveCTEState : public PhysicalOperatorState {
12public:
13 PhysicalRecursiveCTEState() : PhysicalOperatorState(nullptr), top_done(false) {
14 }
15 unique_ptr<PhysicalOperatorState> top_state;
16 unique_ptr<PhysicalOperatorState> bottom_state;
17 unique_ptr<SuperLargeHashTable> ht;
18
19 bool top_done = false;
20
21 bool recursing = false;
22 bool intermediate_empty = true;
23};
24
25PhysicalRecursiveCTE::PhysicalRecursiveCTE(LogicalOperator &op, bool union_all, unique_ptr<PhysicalOperator> top,
26 unique_ptr<PhysicalOperator> bottom)
27 : PhysicalOperator(PhysicalOperatorType::RECURSIVE_CTE, op.types), union_all(union_all) {
28 children.push_back(move(top));
29 children.push_back(move(bottom));
30}
31
32// first exhaust non recursive term, then exhaust recursive term iteratively until no (new) rows are generated.
33void PhysicalRecursiveCTE::GetChunkInternal(ClientContext &context, DataChunk &chunk, PhysicalOperatorState *state_) {
34 auto state = reinterpret_cast<PhysicalRecursiveCTEState *>(state_);
35
36 if (!state->recursing) {
37 do {
38 children[0]->GetChunk(context, chunk, state->top_state.get());
39 if (!union_all) {
40 idx_t match_count = ProbeHT(chunk, state);
41 if (match_count > 0) {
42 working_table->Append(chunk);
43 }
44 } else {
45 working_table->Append(chunk);
46 }
47
48 if (chunk.size() != 0)
49 return;
50 } while (chunk.size() != 0);
51 state->recursing = true;
52 }
53
54 while (true) {
55 children[1]->GetChunk(context, chunk, state->bottom_state.get());
56
57 if (chunk.size() == 0) {
58 // Done if there is nothing in the intermediate table
59 if (state->intermediate_empty) {
60 state->finished = true;
61 break;
62 }
63
64 working_table->count = 0;
65 working_table->chunks.clear();
66
67 working_table->count = intermediate_table.count;
68 working_table->chunks = move(intermediate_table.chunks);
69
70 intermediate_table.count = 0;
71 intermediate_table.chunks.clear();
72
73 state->bottom_state = children[1]->GetOperatorState();
74
75 state->intermediate_empty = true;
76 continue;
77 }
78
79 if (!union_all) {
80 // If we evaluate using UNION semantics, we have to eliminate duplicates before appending them to
81 // intermediate tables.
82 idx_t match_count = ProbeHT(chunk, state);
83 if (match_count > 0) {
84 intermediate_table.Append(chunk);
85 state->intermediate_empty = false;
86 }
87 } else {
88 intermediate_table.Append(chunk);
89 state->intermediate_empty = false;
90 }
91
92 return;
93 }
94}
95
96idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, PhysicalOperatorState *state_) {
97 auto state = reinterpret_cast<PhysicalRecursiveCTEState *>(state_);
98
99 Vector dummy_addresses(TypeId::POINTER);
100
101 // Use the HT to eliminate duplicate rows
102 SelectionVector new_groups(STANDARD_VECTOR_SIZE);
103 idx_t new_group_count = state->ht->FindOrCreateGroups(chunk, dummy_addresses, new_groups);
104
105 // we only return entries we have not seen before (i.e. new groups)
106 chunk.Slice(new_groups, new_group_count);
107
108 return new_group_count;
109}
110
111unique_ptr<PhysicalOperatorState> PhysicalRecursiveCTE::GetOperatorState() {
112 auto state = make_unique<PhysicalRecursiveCTEState>();
113 state->top_state = children[0]->GetOperatorState();
114 state->bottom_state = children[1]->GetOperatorState();
115 state->ht = make_unique<SuperLargeHashTable>(1024, types, vector<TypeId>(), vector<BoundAggregateExpression *>());
116 return (move(state));
117}
118