1#include <Common/typeid_cast.h>
2#include <Functions/FunctionsComparison.h>
3#include <Functions/FunctionsLogical.h>
4#include <Interpreters/CrossToInnerJoinVisitor.h>
5#include <Interpreters/DatabaseAndTableWithAlias.h>
6#include <Interpreters/IdentifierSemantic.h>
7#include <Interpreters/misc.h>
8#include <Parsers/ASTSelectQuery.h>
9#include <Parsers/ASTTablesInSelectQuery.h>
10#include <Parsers/ASTIdentifier.h>
11#include <Parsers/ASTFunction.h>
12#include <Parsers/ASTExpressionList.h>
13#include <Parsers/ParserTablesInSelectQuery.h>
14#include <Parsers/ExpressionListParsers.h>
15#include <Parsers/parseQuery.h>
16#include <IO/WriteHelpers.h>
17
18namespace DB
19{
20
21namespace ErrorCodes
22{
23 extern const int LOGICAL_ERROR;
24 extern const int NOT_IMPLEMENTED;
25}
26
27namespace
28{
29
30struct JoinedTable
31{
32 DatabaseAndTableWithAlias table;
33 ASTTablesInSelectQueryElement * element = nullptr;
34 ASTTableJoin * join = nullptr;
35 ASTPtr array_join = nullptr;
36 bool has_using = false;
37
38 JoinedTable(ASTPtr table_element)
39 {
40 element = table_element->as<ASTTablesInSelectQueryElement>();
41 if (!element)
42 throw Exception("Logical error: TablesInSelectQueryElement expected", ErrorCodes::LOGICAL_ERROR);
43
44 if (element->table_join)
45 {
46 join = element->table_join->as<ASTTableJoin>();
47 if (join->kind == ASTTableJoin::Kind::Cross ||
48 join->kind == ASTTableJoin::Kind::Comma)
49 {
50 if (!join->children.empty())
51 throw Exception("Logical error: CROSS JOIN has expressions", ErrorCodes::LOGICAL_ERROR);
52 }
53
54 if (join->using_expression_list)
55 has_using = true;
56 }
57
58 if (element->table_expression)
59 {
60 const auto & expr = element->table_expression->as<ASTTableExpression &>();
61 table = DatabaseAndTableWithAlias(expr);
62 }
63
64 array_join = element->array_join;
65 }
66
67 void rewriteCommaToCross()
68 {
69 if (join)
70 join->kind = ASTTableJoin::Kind::Cross;
71 }
72
73 bool canAttachOnExpression() const { return join && !join->on_expression; }
74};
75
76bool isComparison(const String & name)
77{
78 return name == NameEquals::name ||
79 name == NameNotEquals::name ||
80 name == NameLess::name ||
81 name == NameGreater::name ||
82 name == NameLessOrEquals::name ||
83 name == NameGreaterOrEquals::name;
84}
85
86/// It checks if where expression could be moved to JOIN ON expression partially or entirely.
87class CheckExpressionVisitorData
88{
89public:
90 using TypeToVisit = const ASTFunction;
91
92 CheckExpressionVisitorData(const std::vector<JoinedTable> & tables_)
93 : joined_tables(tables_)
94 , ands_only(true)
95 {
96 for (auto & joined : joined_tables)
97 tables.push_back(joined.table);
98 }
99
100 void visit(const ASTFunction & node, const ASTPtr & ast)
101 {
102 if (!ands_only)
103 return;
104
105 if (node.name == NameAnd::name)
106 {
107 if (!node.arguments || node.arguments->children.empty())
108 throw Exception("Logical error: function requires argiment", ErrorCodes::LOGICAL_ERROR);
109
110 for (auto & child : node.arguments->children)
111 {
112 if (const auto * func = child->as<ASTFunction>())
113 visit(*func, child);
114 else
115 ands_only = false;
116 }
117 }
118 else if (node.name == NameEquals::name)
119 {
120 if (size_t min_table = canMoveEqualsToJoinOn(node))
121 asts_to_join_on[min_table].push_back(ast);
122 }
123 else if (isComparison(node.name))
124 {
125 /// leave other comparisons as is
126 }
127 else if (functionIsInOperator(node.name)) /// IN, NOT IN
128 {
129 if (auto ident = node.arguments->children.at(0)->as<ASTIdentifier>())
130 if (size_t min_table = checkIdentifier(*ident))
131 asts_to_join_on[min_table].push_back(ast);
132 }
133 else
134 {
135 ands_only = false;
136 asts_to_join_on.clear();
137 }
138 }
139
140 bool complex() const { return !ands_only; }
141 bool matchAny(size_t t) const { return asts_to_join_on.count(t); }
142
143 ASTPtr makeOnExpression(size_t table_pos)
144 {
145 if (!asts_to_join_on.count(table_pos))
146 return {};
147
148 std::vector<ASTPtr> & expressions = asts_to_join_on[table_pos];
149
150 if (expressions.size() == 1)
151 return expressions[0]->clone();
152
153 std::vector<ASTPtr> arguments;
154 arguments.reserve(expressions.size());
155 for (auto & ast : expressions)
156 arguments.emplace_back(ast->clone());
157
158 return makeASTFunction(NameAnd::name, std::move(arguments));
159 }
160
161private:
162 const std::vector<JoinedTable> & joined_tables;
163 std::vector<DatabaseAndTableWithAlias> tables;
164 std::map<size_t, std::vector<ASTPtr>> asts_to_join_on;
165 bool ands_only;
166
167 size_t canMoveEqualsToJoinOn(const ASTFunction & node)
168 {
169 if (!node.arguments)
170 throw Exception("Logical error: function requires arguments", ErrorCodes::LOGICAL_ERROR);
171 if (node.arguments->children.size() != 2)
172 return false;
173
174 const auto * left = node.arguments->children[0]->as<ASTIdentifier>();
175 const auto * right = node.arguments->children[1]->as<ASTIdentifier>();
176 if (!left || !right)
177 return false;
178
179 return checkIdentifiers(*left, *right);
180 }
181
182 /// Check if the identifiers are from different joined tables. If it's a self joint, tables should have aliases.
183 /// select * from t1 a cross join t2 b where a.x = b.x
184 /// @return table position to attach expression to or 0.
185 size_t checkIdentifiers(const ASTIdentifier & left, const ASTIdentifier & right)
186 {
187 size_t left_table_pos = 0;
188 bool left_match = IdentifierSemantic::chooseTable(left, tables, left_table_pos);
189
190 size_t right_table_pos = 0;
191 bool right_match = IdentifierSemantic::chooseTable(right, tables, right_table_pos);
192
193 if (left_match && right_match && (left_table_pos != right_table_pos))
194 {
195 size_t table_pos = std::max(left_table_pos, right_table_pos);
196 if (joined_tables[table_pos].canAttachOnExpression())
197 return table_pos;
198 }
199 return 0;
200 }
201
202 size_t checkIdentifier(const ASTIdentifier & identifier)
203 {
204 size_t best_table_pos = 0;
205 bool match = IdentifierSemantic::chooseTable(identifier, tables, best_table_pos);
206
207 if (match && joined_tables[best_table_pos].canAttachOnExpression())
208 return best_table_pos;
209 return 0;
210 }
211};
212
213using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, false>;
214using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>;
215
216
217bool getTables(ASTSelectQuery & select, std::vector<JoinedTable> & joined_tables, size_t & num_comma)
218{
219 if (!select.tables())
220 return false;
221
222 const auto * tables = select.tables()->as<ASTTablesInSelectQuery>();
223 if (!tables)
224 return false;
225
226 size_t num_tables = tables->children.size();
227 if (num_tables < 2)
228 return false;
229
230 joined_tables.reserve(num_tables);
231 size_t num_array_join = 0;
232 size_t num_using = 0;
233
234 for (auto & child : tables->children)
235 {
236 joined_tables.emplace_back(JoinedTable(child));
237 JoinedTable & t = joined_tables.back();
238 if (t.array_join)
239 {
240 ++num_array_join;
241 continue;
242 }
243
244 if (t.has_using)
245 {
246 ++num_using;
247 continue;
248 }
249
250 if (auto * join = t.join)
251 if (join->kind == ASTTableJoin::Kind::Comma)
252 ++num_comma;
253 }
254
255 if (num_using && (num_tables - num_array_join) > 2)
256 throw Exception("Multiple CROSS/COMMA JOIN do not support USING", ErrorCodes::NOT_IMPLEMENTED);
257
258 if (num_comma && (num_comma != (joined_tables.size() - 1)))
259 throw Exception("Mix of COMMA and other JOINS is not supported", ErrorCodes::NOT_IMPLEMENTED);
260
261 if (num_array_join || num_using)
262 return false;
263 return true;
264}
265
266}
267
268
269void CrossToInnerJoinMatcher::visit(ASTPtr & ast, Data & data)
270{
271 if (auto * t = ast->as<ASTSelectQuery>())
272 visit(*t, ast, data);
273}
274
275void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & data)
276{
277 size_t num_comma = 0;
278 std::vector<JoinedTable> joined_tables;
279 if (!getTables(select, joined_tables, num_comma))
280 return;
281
282 /// COMMA to CROSS
283
284 if (num_comma)
285 {
286 for (auto & table : joined_tables)
287 table.rewriteCommaToCross();
288 }
289
290 /// CROSS to INNER
291
292 if (!select.where())
293 return;
294
295 CheckExpressionVisitor::Data visitor_data{joined_tables};
296 CheckExpressionVisitor(visitor_data).visit(select.where());
297
298 if (visitor_data.complex())
299 return;
300
301 for (size_t i = 1; i < joined_tables.size(); ++i)
302 {
303 if (visitor_data.matchAny(i))
304 {
305 ASTTableJoin & join = *joined_tables[i].join;
306 join.kind = ASTTableJoin::Kind::Inner;
307 join.strictness = ASTTableJoin::Strictness::All;
308
309 join.on_expression = visitor_data.makeOnExpression(i);
310 join.children.push_back(join.on_expression);
311 data.done = true;
312 }
313 }
314}
315
316}
317