1 | #include <Common/typeid_cast.h> |
2 | #include <Functions/FunctionsComparison.h> |
3 | #include <Functions/FunctionsLogical.h> |
4 | #include <Interpreters/CrossToInnerJoinVisitor.h> |
5 | #include <Interpreters/DatabaseAndTableWithAlias.h> |
6 | #include <Interpreters/IdentifierSemantic.h> |
7 | #include <Interpreters/misc.h> |
8 | #include <Parsers/ASTSelectQuery.h> |
9 | #include <Parsers/ASTTablesInSelectQuery.h> |
10 | #include <Parsers/ASTIdentifier.h> |
11 | #include <Parsers/ASTFunction.h> |
12 | #include <Parsers/ASTExpressionList.h> |
13 | #include <Parsers/ParserTablesInSelectQuery.h> |
14 | #include <Parsers/ExpressionListParsers.h> |
15 | #include <Parsers/parseQuery.h> |
16 | #include <IO/WriteHelpers.h> |
17 | |
18 | namespace DB |
19 | { |
20 | |
21 | namespace ErrorCodes |
22 | { |
23 | extern const int LOGICAL_ERROR; |
24 | extern const int NOT_IMPLEMENTED; |
25 | } |
26 | |
27 | namespace |
28 | { |
29 | |
30 | struct JoinedTable |
31 | { |
32 | DatabaseAndTableWithAlias table; |
33 | ASTTablesInSelectQueryElement * element = nullptr; |
34 | ASTTableJoin * join = nullptr; |
35 | ASTPtr array_join = nullptr; |
36 | bool has_using = false; |
37 | |
38 | JoinedTable(ASTPtr table_element) |
39 | { |
40 | element = table_element->as<ASTTablesInSelectQueryElement>(); |
41 | if (!element) |
42 | throw Exception("Logical error: TablesInSelectQueryElement expected" , ErrorCodes::LOGICAL_ERROR); |
43 | |
44 | if (element->table_join) |
45 | { |
46 | join = element->table_join->as<ASTTableJoin>(); |
47 | if (join->kind == ASTTableJoin::Kind::Cross || |
48 | join->kind == ASTTableJoin::Kind::Comma) |
49 | { |
50 | if (!join->children.empty()) |
51 | throw Exception("Logical error: CROSS JOIN has expressions" , ErrorCodes::LOGICAL_ERROR); |
52 | } |
53 | |
54 | if (join->using_expression_list) |
55 | has_using = true; |
56 | } |
57 | |
58 | if (element->table_expression) |
59 | { |
60 | const auto & expr = element->table_expression->as<ASTTableExpression &>(); |
61 | table = DatabaseAndTableWithAlias(expr); |
62 | } |
63 | |
64 | array_join = element->array_join; |
65 | } |
66 | |
67 | void rewriteCommaToCross() |
68 | { |
69 | if (join) |
70 | join->kind = ASTTableJoin::Kind::Cross; |
71 | } |
72 | |
73 | bool canAttachOnExpression() const { return join && !join->on_expression; } |
74 | }; |
75 | |
76 | bool isComparison(const String & name) |
77 | { |
78 | return name == NameEquals::name || |
79 | name == NameNotEquals::name || |
80 | name == NameLess::name || |
81 | name == NameGreater::name || |
82 | name == NameLessOrEquals::name || |
83 | name == NameGreaterOrEquals::name; |
84 | } |
85 | |
86 | /// It checks if where expression could be moved to JOIN ON expression partially or entirely. |
87 | class CheckExpressionVisitorData |
88 | { |
89 | public: |
90 | using TypeToVisit = const ASTFunction; |
91 | |
92 | CheckExpressionVisitorData(const std::vector<JoinedTable> & tables_) |
93 | : joined_tables(tables_) |
94 | , ands_only(true) |
95 | { |
96 | for (auto & joined : joined_tables) |
97 | tables.push_back(joined.table); |
98 | } |
99 | |
100 | void visit(const ASTFunction & node, const ASTPtr & ast) |
101 | { |
102 | if (!ands_only) |
103 | return; |
104 | |
105 | if (node.name == NameAnd::name) |
106 | { |
107 | if (!node.arguments || node.arguments->children.empty()) |
108 | throw Exception("Logical error: function requires argiment" , ErrorCodes::LOGICAL_ERROR); |
109 | |
110 | for (auto & child : node.arguments->children) |
111 | { |
112 | if (const auto * func = child->as<ASTFunction>()) |
113 | visit(*func, child); |
114 | else |
115 | ands_only = false; |
116 | } |
117 | } |
118 | else if (node.name == NameEquals::name) |
119 | { |
120 | if (size_t min_table = canMoveEqualsToJoinOn(node)) |
121 | asts_to_join_on[min_table].push_back(ast); |
122 | } |
123 | else if (isComparison(node.name)) |
124 | { |
125 | /// leave other comparisons as is |
126 | } |
127 | else if (functionIsInOperator(node.name)) /// IN, NOT IN |
128 | { |
129 | if (auto ident = node.arguments->children.at(0)->as<ASTIdentifier>()) |
130 | if (size_t min_table = checkIdentifier(*ident)) |
131 | asts_to_join_on[min_table].push_back(ast); |
132 | } |
133 | else |
134 | { |
135 | ands_only = false; |
136 | asts_to_join_on.clear(); |
137 | } |
138 | } |
139 | |
140 | bool complex() const { return !ands_only; } |
141 | bool matchAny(size_t t) const { return asts_to_join_on.count(t); } |
142 | |
143 | ASTPtr makeOnExpression(size_t table_pos) |
144 | { |
145 | if (!asts_to_join_on.count(table_pos)) |
146 | return {}; |
147 | |
148 | std::vector<ASTPtr> & expressions = asts_to_join_on[table_pos]; |
149 | |
150 | if (expressions.size() == 1) |
151 | return expressions[0]->clone(); |
152 | |
153 | std::vector<ASTPtr> arguments; |
154 | arguments.reserve(expressions.size()); |
155 | for (auto & ast : expressions) |
156 | arguments.emplace_back(ast->clone()); |
157 | |
158 | return makeASTFunction(NameAnd::name, std::move(arguments)); |
159 | } |
160 | |
161 | private: |
162 | const std::vector<JoinedTable> & joined_tables; |
163 | std::vector<DatabaseAndTableWithAlias> tables; |
164 | std::map<size_t, std::vector<ASTPtr>> asts_to_join_on; |
165 | bool ands_only; |
166 | |
167 | size_t canMoveEqualsToJoinOn(const ASTFunction & node) |
168 | { |
169 | if (!node.arguments) |
170 | throw Exception("Logical error: function requires arguments" , ErrorCodes::LOGICAL_ERROR); |
171 | if (node.arguments->children.size() != 2) |
172 | return false; |
173 | |
174 | const auto * left = node.arguments->children[0]->as<ASTIdentifier>(); |
175 | const auto * right = node.arguments->children[1]->as<ASTIdentifier>(); |
176 | if (!left || !right) |
177 | return false; |
178 | |
179 | return checkIdentifiers(*left, *right); |
180 | } |
181 | |
182 | /// Check if the identifiers are from different joined tables. If it's a self joint, tables should have aliases. |
183 | /// select * from t1 a cross join t2 b where a.x = b.x |
184 | /// @return table position to attach expression to or 0. |
185 | size_t checkIdentifiers(const ASTIdentifier & left, const ASTIdentifier & right) |
186 | { |
187 | size_t left_table_pos = 0; |
188 | bool left_match = IdentifierSemantic::chooseTable(left, tables, left_table_pos); |
189 | |
190 | size_t right_table_pos = 0; |
191 | bool right_match = IdentifierSemantic::chooseTable(right, tables, right_table_pos); |
192 | |
193 | if (left_match && right_match && (left_table_pos != right_table_pos)) |
194 | { |
195 | size_t table_pos = std::max(left_table_pos, right_table_pos); |
196 | if (joined_tables[table_pos].canAttachOnExpression()) |
197 | return table_pos; |
198 | } |
199 | return 0; |
200 | } |
201 | |
202 | size_t checkIdentifier(const ASTIdentifier & identifier) |
203 | { |
204 | size_t best_table_pos = 0; |
205 | bool match = IdentifierSemantic::chooseTable(identifier, tables, best_table_pos); |
206 | |
207 | if (match && joined_tables[best_table_pos].canAttachOnExpression()) |
208 | return best_table_pos; |
209 | return 0; |
210 | } |
211 | }; |
212 | |
213 | using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, false>; |
214 | using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>; |
215 | |
216 | |
217 | bool getTables(ASTSelectQuery & select, std::vector<JoinedTable> & joined_tables, size_t & num_comma) |
218 | { |
219 | if (!select.tables()) |
220 | return false; |
221 | |
222 | const auto * tables = select.tables()->as<ASTTablesInSelectQuery>(); |
223 | if (!tables) |
224 | return false; |
225 | |
226 | size_t num_tables = tables->children.size(); |
227 | if (num_tables < 2) |
228 | return false; |
229 | |
230 | joined_tables.reserve(num_tables); |
231 | size_t num_array_join = 0; |
232 | size_t num_using = 0; |
233 | |
234 | for (auto & child : tables->children) |
235 | { |
236 | joined_tables.emplace_back(JoinedTable(child)); |
237 | JoinedTable & t = joined_tables.back(); |
238 | if (t.array_join) |
239 | { |
240 | ++num_array_join; |
241 | continue; |
242 | } |
243 | |
244 | if (t.has_using) |
245 | { |
246 | ++num_using; |
247 | continue; |
248 | } |
249 | |
250 | if (auto * join = t.join) |
251 | if (join->kind == ASTTableJoin::Kind::Comma) |
252 | ++num_comma; |
253 | } |
254 | |
255 | if (num_using && (num_tables - num_array_join) > 2) |
256 | throw Exception("Multiple CROSS/COMMA JOIN do not support USING" , ErrorCodes::NOT_IMPLEMENTED); |
257 | |
258 | if (num_comma && (num_comma != (joined_tables.size() - 1))) |
259 | throw Exception("Mix of COMMA and other JOINS is not supported" , ErrorCodes::NOT_IMPLEMENTED); |
260 | |
261 | if (num_array_join || num_using) |
262 | return false; |
263 | return true; |
264 | } |
265 | |
266 | } |
267 | |
268 | |
269 | void CrossToInnerJoinMatcher::visit(ASTPtr & ast, Data & data) |
270 | { |
271 | if (auto * t = ast->as<ASTSelectQuery>()) |
272 | visit(*t, ast, data); |
273 | } |
274 | |
275 | void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & data) |
276 | { |
277 | size_t num_comma = 0; |
278 | std::vector<JoinedTable> joined_tables; |
279 | if (!getTables(select, joined_tables, num_comma)) |
280 | return; |
281 | |
282 | /// COMMA to CROSS |
283 | |
284 | if (num_comma) |
285 | { |
286 | for (auto & table : joined_tables) |
287 | table.rewriteCommaToCross(); |
288 | } |
289 | |
290 | /// CROSS to INNER |
291 | |
292 | if (!select.where()) |
293 | return; |
294 | |
295 | CheckExpressionVisitor::Data visitor_data{joined_tables}; |
296 | CheckExpressionVisitor(visitor_data).visit(select.where()); |
297 | |
298 | if (visitor_data.complex()) |
299 | return; |
300 | |
301 | for (size_t i = 1; i < joined_tables.size(); ++i) |
302 | { |
303 | if (visitor_data.matchAny(i)) |
304 | { |
305 | ASTTableJoin & join = *joined_tables[i].join; |
306 | join.kind = ASTTableJoin::Kind::Inner; |
307 | join.strictness = ASTTableJoin::Strictness::All; |
308 | |
309 | join.on_expression = visitor_data.makeOnExpression(i); |
310 | join.children.push_back(join.on_expression); |
311 | data.done = true; |
312 | } |
313 | } |
314 | } |
315 | |
316 | } |
317 | |