#include <Core/NamesAndTypes.h>

#include <Interpreters/Context.h>
#include <Interpreters/SyntaxAnalyzer.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/IdentifierSemantic.h>

#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTExpressionList.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTSelectQuery.h>

#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnsCommon.h>
#include <Columns/FilterDescription.h>

#include <Storages/VirtualColumnUtils.h>
#include <IO/WriteHelpers.h>
#include <Common/typeid_cast.h>
21
22
23namespace DB
24{
25
26namespace
27{
28
29/// Verifying that the function depends only on the specified columns
30bool isValidFunction(const ASTPtr & expression, const NameSet & columns)
31{
32 for (size_t i = 0; i < expression->children.size(); ++i)
33 if (!isValidFunction(expression->children[i], columns))
34 return false;
35
36 if (auto opt_name = IdentifierSemantic::getColumnName(expression))
37 return columns.count(*opt_name);
38
39 return true;
40}
41
42/// Extract all subfunctions of the main conjunction, but depending only on the specified columns
43void extractFunctions(const ASTPtr & expression, const NameSet & columns, std::vector<ASTPtr> & result)
44{
45 const auto * function = expression->as<ASTFunction>();
46 if (function && function->name == "and")
47 {
48 for (size_t i = 0; i < function->arguments->children.size(); ++i)
49 extractFunctions(function->arguments->children[i], columns, result);
50 }
51 else if (isValidFunction(expression, columns))
52 {
53 result.push_back(expression->clone());
54 }
55}
56
57/// Construct a conjunction from given functions
58ASTPtr buildWhereExpression(const ASTs & functions)
59{
60 if (functions.size() == 0)
61 return nullptr;
62 if (functions.size() == 1)
63 return functions[0];
64 return makeASTFunction("and", functions);
65}
66
67}
68
69namespace VirtualColumnUtils
70{
71
72void rewriteEntityInAst(ASTPtr ast, const String & column_name, const Field & value, const String & func)
73{
74 auto & select = ast->as<ASTSelectQuery &>();
75 if (!select.with())
76 select.setExpression(ASTSelectQuery::Expression::WITH, std::make_shared<ASTExpressionList>());
77
78
79 if (func.empty())
80 {
81 auto literal = std::make_shared<ASTLiteral>(value);
82 literal->alias = column_name;
83 literal->prefer_alias_to_column_name = true;
84 select.with()->children.push_back(literal);
85 }
86 else
87 {
88 auto literal = std::make_shared<ASTLiteral>(value);
89 literal->prefer_alias_to_column_name = true;
90
91 auto function = makeASTFunction(func, literal);
92 function->alias = column_name;
93 function->prefer_alias_to_column_name = true;
94 select.with()->children.push_back(function);
95 }
96}
97
98void filterBlockWithQuery(const ASTPtr & query, Block & block, const Context & context)
99{
100 const auto & select = query->as<ASTSelectQuery &>();
101 if (!select.where() && !select.prewhere())
102 return;
103
104 NameSet columns;
105 for (const auto & it : block.getNamesAndTypesList())
106 columns.insert(it.name);
107
108 /// We will create an expression that evaluates the expressions in WHERE and PREWHERE, depending only on the existing columns.
109 std::vector<ASTPtr> functions;
110 if (select.where())
111 extractFunctions(select.where(), columns, functions);
112 if (select.prewhere())
113 extractFunctions(select.prewhere(), columns, functions);
114
115 ASTPtr expression_ast = buildWhereExpression(functions);
116 if (!expression_ast)
117 return;
118
119 /// Let's analyze and calculate the expression.
120 auto syntax_result = SyntaxAnalyzer(context).analyze(expression_ast, block.getNamesAndTypesList());
121 ExpressionAnalyzer analyzer(expression_ast, syntax_result, context);
122 ExpressionActionsPtr actions = analyzer.getActions(false);
123
124 Block block_with_filter = block;
125 actions->execute(block_with_filter);
126
127 /// Filter the block.
128 String filter_column_name = expression_ast->getColumnName();
129 ColumnPtr filter_column = block_with_filter.getByName(filter_column_name).column->convertToFullColumnIfConst();
130 const IColumn::Filter & filter = typeid_cast<const ColumnUInt8 &>(*filter_column).getData();
131
132 for (size_t i = 0; i < block.columns(); ++i)
133 {
134 ColumnPtr & column = block.safeGetByPosition(i).column;
135 column = column->filter(filter, -1);
136 }
137}
138
139}
140
141}
142