1#include <iostream>
2
3#include <Common/typeid_cast.h>
4#include <Storages/IStorage.h>
5#include <Interpreters/PredicateExpressionsOptimizer.h>
6#include <Interpreters/InterpreterSelectQuery.h>
7#include <Interpreters/IdentifierSemantic.h>
8#include <AggregateFunctions/AggregateFunctionFactory.h>
9#include <Parsers/IAST.h>
10#include <Parsers/ASTFunction.h>
11#include <Parsers/ASTIdentifier.h>
12#include <Parsers/ASTSelectQuery.h>
13#include <Parsers/ASTSelectWithUnionQuery.h>
14#include <Parsers/ASTSubquery.h>
15#include <Parsers/ASTTablesInSelectQuery.h>
16#include <Parsers/ASTAsterisk.h>
17#include <Parsers/ASTQualifiedAsterisk.h>
18#include <Parsers/ASTColumnsMatcher.h>
19#include <Parsers/queryToString.h>
20#include <Interpreters/Context.h>
21#include <Interpreters/ExpressionActions.h>
22#include <Interpreters/QueryNormalizer.h>
23#include <Interpreters/QueryAliasesVisitor.h>
24#include <Interpreters/MarkTableIdentifiersVisitor.h>
25#include <Interpreters/TranslateQualifiedNamesVisitor.h>
26#include <Interpreters/FindIdentifierBestTableVisitor.h>
27#include <Interpreters/ExtractFunctionDataVisitor.h>
28#include <Interpreters/getTableExpressions.h>
29#include <Functions/FunctionFactory.h>
30
31
32namespace DB
33{
34
35namespace ErrorCodes
36{
37 extern const int LOGICAL_ERROR;
38 extern const int UNKNOWN_ELEMENT_IN_AST;
39}
40
41namespace
42{
43
44constexpr auto and_function_name = "and";
45
46String qualifiedName(ASTIdentifier * identifier, const String & prefix)
47{
48 if (identifier->isShort())
49 return prefix + identifier->getAliasOrColumnName();
50 return identifier->getAliasOrColumnName();
51}
52
53}
54
55PredicateExpressionsOptimizer::PredicateExpressionsOptimizer(
56 ASTSelectQuery * ast_select_, ExtractedSettings && settings_, const Context & context_)
57 : ast_select(ast_select_), settings(settings_), context(context_)
58{
59}
60
61bool PredicateExpressionsOptimizer::optimize()
62{
63 if (!settings.enable_optimize_predicate_expression || !ast_select || !ast_select->tables() || ast_select->tables()->children.empty())
64 return false;
65
66 if (!ast_select->where() && !ast_select->prewhere())
67 return false;
68
69 if (ast_select->array_join_expression_list())
70 return false;
71
72 SubqueriesProjectionColumns all_subquery_projection_columns = getAllSubqueryProjectionColumns();
73
74 bool is_rewrite_subqueries = false;
75 if (!all_subquery_projection_columns.empty())
76 {
77 is_rewrite_subqueries |= optimizeImpl(ast_select->where(), all_subquery_projection_columns, OptimizeKind::PUSH_TO_WHERE);
78 is_rewrite_subqueries |= optimizeImpl(ast_select->prewhere(), all_subquery_projection_columns, OptimizeKind::PUSH_TO_PREWHERE);
79 }
80
81 return is_rewrite_subqueries;
82}
83
84bool PredicateExpressionsOptimizer::optimizeImpl(
85 const ASTPtr & outer_expression, const SubqueriesProjectionColumns & subqueries_projection_columns, OptimizeKind expression_kind)
86{
87 /// split predicate with `and`
88 std::vector<ASTPtr> outer_predicate_expressions = splitConjunctionPredicate(outer_expression);
89
90 std::vector<const ASTTableExpression *> table_expressions = getTableExpressions(*ast_select);
91 std::vector<TableWithColumnNames> tables_with_columns = getDatabaseAndTablesWithColumnNames(table_expressions, context);
92
93 bool is_rewrite_subquery = false;
94 for (auto & outer_predicate : outer_predicate_expressions)
95 {
96 if (isArrayJoinFunction(outer_predicate))
97 continue;
98
99 auto outer_predicate_dependencies = getDependenciesAndQualifiers(outer_predicate, tables_with_columns);
100
101 /// TODO: remove origin expression
102 for (const auto & [subquery, projection_columns] : subqueries_projection_columns)
103 {
104 OptimizeKind optimize_kind = OptimizeKind::NONE;
105 if (allowPushDown(subquery, outer_predicate, projection_columns, outer_predicate_dependencies, optimize_kind))
106 {
107 if (optimize_kind == OptimizeKind::NONE)
108 optimize_kind = expression_kind;
109
110 ASTPtr inner_predicate = outer_predicate->clone();
111 cleanExpressionAlias(inner_predicate); /// clears the alias name contained in the outer predicate
112
113 std::vector<IdentifierWithQualifier> inner_predicate_dependencies =
114 getDependenciesAndQualifiers(inner_predicate, tables_with_columns);
115
116 setNewAliasesForInnerPredicate(projection_columns, inner_predicate_dependencies);
117
118 switch (optimize_kind)
119 {
120 case OptimizeKind::NONE: continue;
121 case OptimizeKind::PUSH_TO_WHERE:
122 is_rewrite_subquery |= optimizeExpression(inner_predicate, subquery, ASTSelectQuery::Expression::WHERE);
123 continue;
124 case OptimizeKind::PUSH_TO_HAVING:
125 is_rewrite_subquery |= optimizeExpression(inner_predicate, subquery, ASTSelectQuery::Expression::HAVING);
126 continue;
127 case OptimizeKind::PUSH_TO_PREWHERE:
128 is_rewrite_subquery |= optimizeExpression(inner_predicate, subquery, ASTSelectQuery::Expression::PREWHERE);
129 continue;
130 }
131 }
132 }
133 }
134 return is_rewrite_subquery;
135}
136
137bool PredicateExpressionsOptimizer::allowPushDown(
138 const ASTSelectQuery * subquery,
139 const ASTPtr &,
140 const std::vector<ProjectionWithAlias> & projection_columns,
141 const std::vector<IdentifierWithQualifier> & dependencies,
142 OptimizeKind & optimize_kind)
143{
144 if (!subquery
145 || (!settings.enable_optimize_predicate_expression_to_final_subquery && subquery->final())
146 || subquery->limitBy() || subquery->limitLength()
147 || subquery->with() || subquery->withFill())
148 return false;
149 else
150 {
151 ASTPtr expr_list = ast_select->select();
152 ExtractFunctionVisitor::Data extract_data;
153 ExtractFunctionVisitor(extract_data).visit(expr_list);
154
155 for (const auto & subquery_function : extract_data.functions)
156 {
157 const auto & function = FunctionFactory::instance().tryGet(subquery_function->name, context);
158
159 /// Skip lambda, tuple and other special functions
160 if (function && function->isStateful())
161 return false;
162 }
163 }
164
165 const auto * ast_join = ast_select->join();
166 const ASTTableExpression * left_table_expr = nullptr;
167 const ASTTableExpression * right_table_expr = nullptr;
168 const ASTSelectQuery * left_subquery = nullptr;
169 const ASTSelectQuery * right_subquery = nullptr;
170
171 if (ast_join)
172 {
173 left_table_expr = ast_select
174 ->tables()->as<ASTTablesInSelectQuery>()
175 ->children[0]->as<ASTTablesInSelectQueryElement>()
176 ->table_expression->as<ASTTableExpression>();
177 right_table_expr = ast_select
178 ->tables()->as<ASTTablesInSelectQuery>()
179 ->children[1]->as<ASTTablesInSelectQueryElement>()
180 ->table_expression->as<ASTTableExpression>();
181
182 if (left_table_expr && left_table_expr->subquery)
183 left_subquery = left_table_expr->subquery
184 ->children[0]->as<ASTSelectWithUnionQuery>()
185 ->list_of_selects->children[0]->as<ASTSelectQuery>();
186 if (right_table_expr && right_table_expr->subquery)
187 right_subquery = right_table_expr->subquery
188 ->children[0]->as<ASTSelectWithUnionQuery>()
189 ->list_of_selects->children[0]->as<ASTSelectQuery>();
190
191 /// NOTE: the syntactic way of pushdown has limitations and should be partially disabled in case of JOINs.
192 /// Let's take a look at the query:
193 ///
194 /// SELECT a, b FROM (SELECT 1 AS a) ANY LEFT JOIN (SELECT 1 AS a, 1 AS b) USING (a) WHERE b = 0
195 ///
196 /// The result is empty - without pushdown. But the pushdown tends to modify it in this way:
197 ///
198 /// SELECT a, b FROM (SELECT 1 AS a) ANY LEFT JOIN (SELECT 1 AS a, 1 AS b WHERE b = 0) USING (a) WHERE b = 0
199 ///
200 /// That leads to the empty result in the right subquery and changes the whole outcome to (1, 0) or (1, NULL).
201 /// It happens because the not-matching columns are replaced with a global default values on JOIN.
202 /// Same is true for RIGHT JOIN and FULL JOIN.
203
204 /// Check right side for LEFT'o'FULL JOIN
205 if (isLeftOrFull(ast_join->table_join->as<ASTTableJoin>()->kind) && right_subquery == subquery)
206 return false;
207
208 /// Check left side for RIGHT'o'FULL JOIN
209 if (isRightOrFull(ast_join->table_join->as<ASTTableJoin>()->kind) && left_subquery == subquery)
210 return false;
211 }
212
213 return checkDependencies(projection_columns, dependencies, optimize_kind);
214}
215
216bool PredicateExpressionsOptimizer::checkDependencies(
217 const std::vector<ProjectionWithAlias> & projection_columns,
218 const std::vector<IdentifierWithQualifier> & dependencies,
219 OptimizeKind & optimize_kind)
220{
221 for (const auto & [identifier, prefix] : dependencies)
222 {
223 bool is_found = false;
224 String qualified_name = qualifiedName(identifier, prefix);
225
226 for (const auto & [ast, alias] : projection_columns)
227 {
228 if (alias == qualified_name)
229 {
230 is_found = true;
231 ASTPtr projection_column = ast;
232 ExtractFunctionVisitor::Data extract_data;
233 ExtractFunctionVisitor(extract_data).visit(projection_column);
234
235 if (!extract_data.aggregate_functions.empty())
236 optimize_kind = OptimizeKind::PUSH_TO_HAVING;
237 }
238 }
239
240 if (!is_found)
241 return false;
242 }
243
244 return true;
245}
246
247std::vector<ASTPtr> PredicateExpressionsOptimizer::splitConjunctionPredicate(const ASTPtr & predicate_expression)
248{
249 std::vector<ASTPtr> predicate_expressions;
250
251 if (predicate_expression)
252 {
253 predicate_expressions.emplace_back(predicate_expression);
254
255 auto remove_expression_at_index = [&predicate_expressions] (const size_t index)
256 {
257 if (index < predicate_expressions.size() - 1)
258 std::swap(predicate_expressions[index], predicate_expressions.back());
259 predicate_expressions.pop_back();
260 };
261
262 for (size_t idx = 0; idx < predicate_expressions.size();)
263 {
264 const auto expression = predicate_expressions.at(idx);
265
266 if (const auto * function = expression->as<ASTFunction>())
267 {
268 if (function->name == and_function_name)
269 {
270 for (auto & child : function->arguments->children)
271 predicate_expressions.emplace_back(child);
272
273 remove_expression_at_index(idx);
274 continue;
275 }
276 }
277 ++idx;
278 }
279 }
280 return predicate_expressions;
281}
282
283std::vector<PredicateExpressionsOptimizer::IdentifierWithQualifier>
284PredicateExpressionsOptimizer::getDependenciesAndQualifiers(ASTPtr & expression, std::vector<TableWithColumnNames> & tables)
285{
286 FindIdentifierBestTableVisitor::Data find_data(tables);
287 FindIdentifierBestTableVisitor(find_data).visit(expression);
288
289 std::vector<IdentifierWithQualifier> dependencies;
290
291 for (const auto & [identifier, table] : find_data.identifier_table)
292 {
293 String table_alias;
294 if (table)
295 table_alias = table->getQualifiedNamePrefix();
296
297 dependencies.emplace_back(identifier, table_alias);
298 }
299
300 return dependencies;
301}
302
303void PredicateExpressionsOptimizer::setNewAliasesForInnerPredicate(
304 const std::vector<ProjectionWithAlias> & projection_columns,
305 const std::vector<IdentifierWithQualifier> & dependencies)
306{
307 for (auto & [identifier, prefix] : dependencies)
308 {
309 String qualified_name = qualifiedName(identifier, prefix);
310
311 for (auto & [ast, alias] : projection_columns)
312 {
313 if (alias == qualified_name)
314 {
315 String name;
316 if (auto * id = ast->as<ASTIdentifier>())
317 {
318 name = id->tryGetAlias();
319 if (name.empty())
320 name = id->shortName();
321 }
322 else
323 {
324 if (ast->tryGetAlias().empty())
325 ast->setAlias(ast->getColumnName());
326 name = ast->getAliasOrColumnName();
327 }
328
329 identifier->setShortName(name);
330 }
331 }
332 }
333}
334
335bool PredicateExpressionsOptimizer::isArrayJoinFunction(const ASTPtr & node)
336{
337 if (const auto * function = node->as<ASTFunction>())
338 {
339 if (function->name == "arrayJoin")
340 return true;
341 }
342
343 for (auto & child : node->children)
344 if (isArrayJoinFunction(child))
345 return true;
346
347 return false;
348}
349
350bool PredicateExpressionsOptimizer::optimizeExpression(const ASTPtr & outer_expression, ASTSelectQuery * subquery,
351 ASTSelectQuery::Expression expr)
352{
353 ASTPtr subquery_expression = subquery->getExpression(expr, false);
354 subquery_expression = subquery_expression ? makeASTFunction(and_function_name, outer_expression, subquery_expression) : outer_expression;
355
356 subquery->setExpression(expr, std::move(subquery_expression));
357 return true;
358}
359
360PredicateExpressionsOptimizer::SubqueriesProjectionColumns PredicateExpressionsOptimizer::getAllSubqueryProjectionColumns()
361{
362 SubqueriesProjectionColumns projection_columns;
363
364 for (const auto & table_expression : getTableExpressions(*ast_select))
365 if (table_expression->subquery)
366 getSubqueryProjectionColumns(table_expression->subquery, projection_columns);
367
368 return projection_columns;
369}
370
371void PredicateExpressionsOptimizer::getSubqueryProjectionColumns(const ASTPtr & subquery, SubqueriesProjectionColumns & projection_columns)
372{
373 String qualified_name_prefix = subquery->tryGetAlias();
374 if (!qualified_name_prefix.empty())
375 qualified_name_prefix += '.';
376
377 const ASTPtr & subselect = subquery->children[0];
378
379 ASTs select_with_union_projections;
380 const auto * select_with_union_query = subselect->as<ASTSelectWithUnionQuery>();
381
382 for (auto & select : select_with_union_query->list_of_selects->children)
383 {
384 std::vector<ProjectionWithAlias> subquery_projections;
385 auto select_projection_columns = getSelectQueryProjectionColumns(select);
386
387 if (!select_projection_columns.empty())
388 {
389 if (select_with_union_projections.empty())
390 select_with_union_projections = select_projection_columns;
391
392 for (size_t i = 0; i < select_projection_columns.size(); i++)
393 subquery_projections.emplace_back(std::pair(select_projection_columns[i],
394 qualified_name_prefix + select_with_union_projections[i]->getAliasOrColumnName()));
395
396 projection_columns.insert(std::pair(select->as<ASTSelectQuery>(), subquery_projections));
397 }
398 }
399}
400
401ASTs PredicateExpressionsOptimizer::getSelectQueryProjectionColumns(ASTPtr & ast)
402{
403 ASTs projection_columns;
404 auto * select_query = ast->as<ASTSelectQuery>();
405
406 /// first should normalize query tree.
407 std::unordered_map<String, ASTPtr> aliases;
408 std::vector<DatabaseAndTableWithAlias> tables = getDatabaseAndTables(*select_query, context.getCurrentDatabase());
409
410 /// TODO: get tables from evaluateAsterisk instead of tablesOnly() to extract asterisks in general way
411 std::vector<TableWithColumnNames> tables_with_columns = TranslateQualifiedNamesVisitor::Data::tablesOnly(tables);
412 TranslateQualifiedNamesVisitor::Data qn_visitor_data({}, std::move(tables_with_columns), false);
413 TranslateQualifiedNamesVisitor(qn_visitor_data).visit(ast);
414
415 QueryAliasesVisitor::Data query_aliases_data{aliases};
416 QueryAliasesVisitor(query_aliases_data).visit(ast);
417
418 MarkTableIdentifiersVisitor::Data mark_tables_data{aliases};
419 MarkTableIdentifiersVisitor(mark_tables_data).visit(ast);
420
421 QueryNormalizer::Data normalizer_data(aliases, settings);
422 QueryNormalizer(normalizer_data).visit(ast);
423
424 for (const auto & projection_column : select_query->select()->children)
425 {
426 if (projection_column->as<ASTAsterisk>() || projection_column->as<ASTQualifiedAsterisk>() || projection_column->as<ASTColumnsMatcher>())
427 {
428 ASTs evaluated_columns = evaluateAsterisk(select_query, projection_column);
429
430 for (const auto & column : evaluated_columns)
431 projection_columns.emplace_back(column);
432
433 continue;
434 }
435
436 projection_columns.emplace_back(projection_column);
437 }
438 return projection_columns;
439}
440
441ASTs PredicateExpressionsOptimizer::evaluateAsterisk(ASTSelectQuery * select_query, const ASTPtr & asterisk)
442{
443 /// SELECT *, SELECT dummy, SELECT 1 AS id
444 if (!select_query->tables() || select_query->tables()->children.empty())
445 return {};
446
447 std::vector<const ASTTableExpression *> tables_expression = getTableExpressions(*select_query);
448
449 if (const auto * qualified_asterisk = asterisk->as<ASTQualifiedAsterisk>())
450 {
451 if (qualified_asterisk->children.size() != 1)
452 throw Exception("Logical error: qualified asterisk must have exactly one child", ErrorCodes::LOGICAL_ERROR);
453
454 DatabaseAndTableWithAlias ident_db_and_name(qualified_asterisk->children[0]);
455
456 for (auto it = tables_expression.begin(); it != tables_expression.end();)
457 {
458 const ASTTableExpression * table_expression = *it;
459 DatabaseAndTableWithAlias database_and_table_with_alias(*table_expression, context.getCurrentDatabase());
460
461 if (ident_db_and_name.satisfies(database_and_table_with_alias, true))
462 ++it;
463 else
464 it = tables_expression.erase(it); /// It's not a required table
465 }
466 }
467
468 ASTs projection_columns;
469 for (auto & table_expression : tables_expression)
470 {
471 if (table_expression->subquery)
472 {
473 const auto * subquery = table_expression->subquery->as<ASTSubquery>();
474 const auto * select_with_union_query = subquery->children[0]->as<ASTSelectWithUnionQuery>();
475 const auto subquery_projections = getSelectQueryProjectionColumns(select_with_union_query->list_of_selects->children[0]);
476 projection_columns.insert(projection_columns.end(), subquery_projections.begin(), subquery_projections.end());
477 }
478 else
479 {
480 StoragePtr storage;
481
482 if (table_expression->table_function)
483 {
484 auto query_context = const_cast<Context *>(&context.getQueryContext());
485 storage = query_context->executeTableFunction(table_expression->table_function);
486 }
487 else if (table_expression->database_and_table_name)
488 {
489 const auto * database_and_table_ast = table_expression->database_and_table_name->as<ASTIdentifier>();
490 DatabaseAndTableWithAlias database_and_table_name(*database_and_table_ast);
491 storage = context.getTable(database_and_table_name.database, database_and_table_name.table);
492 }
493 else
494 throw Exception("Logical error: unexpected table expression", ErrorCodes::LOGICAL_ERROR);
495
496 const auto block = storage->getSampleBlock();
497 if (const auto * asterisk_pattern = asterisk->as<ASTColumnsMatcher>())
498 {
499 for (size_t idx = 0; idx < block.columns(); ++idx)
500 {
501 auto & col = block.getByPosition(idx);
502 if (asterisk_pattern->isColumnMatching(col.name))
503 projection_columns.emplace_back(std::make_shared<ASTIdentifier>(col.name));
504 }
505 }
506 else
507 {
508 for (size_t idx = 0; idx < block.columns(); ++idx)
509 projection_columns.emplace_back(std::make_shared<ASTIdentifier>(block.getByPosition(idx).name));
510 }
511 }
512 }
513 return projection_columns;
514}
515
516void PredicateExpressionsOptimizer::cleanExpressionAlias(ASTPtr & expression)
517{
518 const auto my_alias = expression->tryGetAlias();
519 if (!my_alias.empty())
520 expression->setAlias("");
521
522 for (auto & child : expression->children)
523 cleanExpressionAlias(child);
524}
525
526}
527