1#include "duckdb/optimizer/join_order/query_graph.hpp"
2
3#include "duckdb/common/printer.hpp"
4#include "duckdb/common/string_util.hpp"
5
6using namespace duckdb;
7using namespace std;
8
9using QueryEdge = QueryGraph::QueryEdge;
10
11static string QueryEdgeToString(const QueryEdge *info, vector<idx_t> prefix) {
12 string result = "";
13 string source = "[";
14 for (idx_t i = 0; i < prefix.size(); i++) {
15 source += to_string(prefix[i]) + (i < prefix.size() - 1 ? ", " : "");
16 }
17 source += "]";
18 for (auto &entry : info->neighbors) {
19 result += StringUtil::Format("%s -> %s\n", source.c_str(), entry->neighbor->ToString().c_str());
20 }
21 for (auto &entry : info->children) {
22 vector<idx_t> new_prefix = prefix;
23 new_prefix.push_back(entry.first);
24 result += QueryEdgeToString(entry.second.get(), new_prefix);
25 }
26 return result;
27}
28
29string QueryGraph::ToString() const {
30 return QueryEdgeToString(&root, {});
31}
32
33QueryEdge *QueryGraph::GetQueryEdge(JoinRelationSet *left) {
34 assert(left && left->count > 0);
35 // find the EdgeInfo corresponding to the left set
36 QueryEdge *info = &root;
37 for (idx_t i = 0; i < left->count; i++) {
38 auto entry = info->children.find(left->relations[i]);
39 if (entry == info->children.end()) {
40 // node not found, create it
41 auto insert_it = info->children.insert(make_pair(left->relations[i], make_unique<QueryEdge>()));
42 entry = insert_it.first;
43 }
44 // move to the next node
45 info = entry->second.get();
46 }
47 return info;
48}
49
50void QueryGraph::CreateEdge(JoinRelationSet *left, JoinRelationSet *right, FilterInfo *filter_info) {
51 assert(left && right && left->count > 0 && right->count > 0);
52 // find the EdgeInfo corresponding to the left set
53 auto info = GetQueryEdge(left);
54 // now insert the edge to the right relation, if it does not exist
55 for (idx_t i = 0; i < info->neighbors.size(); i++) {
56 if (info->neighbors[i]->neighbor == right) {
57 if (filter_info) {
58 // neighbor already exists just add the filter, if we have any
59 info->neighbors[i]->filters.push_back(filter_info);
60 }
61 return;
62 }
63 }
64 // neighbor does not exist, create it
65 auto n = make_unique<NeighborInfo>();
66 if (filter_info) {
67 n->filters.push_back(filter_info);
68 }
69 n->neighbor = right;
70 info->neighbors.push_back(move(n));
71}
72
73void QueryGraph::EnumerateNeighbors(JoinRelationSet *node, function<bool(NeighborInfo *)> callback) {
74 for (idx_t j = 0; j < node->count; j++) {
75 QueryEdge *info = &root;
76 for (idx_t i = j; i < node->count; i++) {
77 auto entry = info->children.find(node->relations[i]);
78 if (entry == info->children.end()) {
79 // node not found
80 break;
81 }
82 // check if any subset of the other set is in this sets neighbors
83 info = entry->second.get();
84 for (auto &neighbor : info->neighbors) {
85 if (callback(neighbor.get())) {
86 return;
87 }
88 }
89 }
90 }
91}
92
93//! Returns true if a JoinRelationSet is banned by the list of exclusion_set, false otherwise
94static bool JoinRelationSetIsExcluded(JoinRelationSet *node, unordered_set<idx_t> &exclusion_set) {
95 return exclusion_set.find(node->relations[0]) != exclusion_set.end();
96}
97
98vector<idx_t> QueryGraph::GetNeighbors(JoinRelationSet *node, unordered_set<idx_t> &exclusion_set) {
99 unordered_set<idx_t> result;
100 EnumerateNeighbors(node, [&](NeighborInfo *info) -> bool {
101 if (!JoinRelationSetIsExcluded(info->neighbor, exclusion_set)) {
102 // add the smallest node of the neighbor to the set
103 result.insert(info->neighbor->relations[0]);
104 }
105 return false;
106 });
107 vector<idx_t> neighbors;
108 neighbors.insert(neighbors.end(), result.begin(), result.end());
109 return neighbors;
110}
111
112NeighborInfo *QueryGraph::GetConnection(JoinRelationSet *node, JoinRelationSet *other) {
113 NeighborInfo *connection = nullptr;
114 EnumerateNeighbors(node, [&](NeighborInfo *info) -> bool {
115 if (JoinRelationSet::IsSubset(other, info->neighbor)) {
116 connection = info;
117 return true;
118 }
119 return false;
120 });
121 return connection;
122}
123
124void QueryGraph::Print() {
125 Printer::Print(ToString());
126}
127