1#include <Interpreters/RequiredSourceColumnsVisitor.h>
2#include <Common/typeid_cast.h>
3#include <Core/Names.h>
4#include <Parsers/IAST.h>
5#include <Parsers/ASTIdentifier.h>
6#include <Parsers/ASTFunction.h>
7#include <Parsers/ASTSelectQuery.h>
8#include <Parsers/ASTSubquery.h>
9#include <Parsers/ASTTablesInSelectQuery.h>
10
11namespace DB
12{
13
/// Error codes thrown from this file. They are only declared here (`extern`);
/// the definitions live in another translation unit.
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    extern const int TYPE_MISMATCH;
}
19
20static std::vector<String> extractNamesFromLambda(const ASTFunction & node)
21{
22 if (node.arguments->children.size() != 2)
23 throw Exception("lambda requires two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
24
25 const auto * lambda_args_tuple = node.arguments->children[0]->as<ASTFunction>();
26
27 if (!lambda_args_tuple || lambda_args_tuple->name != "tuple")
28 throw Exception("First argument of lambda must be a tuple", ErrorCodes::TYPE_MISMATCH);
29
30 std::vector<String> names;
31 for (auto & child : lambda_args_tuple->arguments->children)
32 {
33 const auto * identifier = child->as<ASTIdentifier>();
34 if (!identifier)
35 throw Exception("lambda argument declarations must be identifiers", ErrorCodes::TYPE_MISMATCH);
36
37 names.push_back(identifier->name);
38 }
39
40 return names;
41}
42
43bool RequiredSourceColumnsMatcher::needChildVisit(const ASTPtr & node, const ASTPtr & child)
44{
45 if (child->as<ASTSelectQuery>())
46 return false;
47
48 /// Processed. Do not need children.
49 if (node->as<ASTTableExpression>() || node->as<ASTArrayJoin>() || node->as<ASTSelectQuery>())
50 return false;
51
52 if (const auto * f = node->as<ASTFunction>())
53 {
54 /// "indexHint" is a special function for index analysis. Everything that is inside it is not calculated. @sa KeyCondition
55 /// "lambda" visit children itself.
56 if (f->name == "indexHint" || f->name == "lambda")
57 return false;
58 }
59
60 return true;
61}
62
63void RequiredSourceColumnsMatcher::visit(const ASTPtr & ast, Data & data)
64{
65 /// results are columns
66
67 if (auto * t = ast->as<ASTIdentifier>())
68 {
69 visit(*t, ast, data);
70 return;
71 }
72 if (auto * t = ast->as<ASTFunction>())
73 {
74 data.addColumnAliasIfAny(*ast);
75 visit(*t, ast, data);
76 return;
77 }
78
79 /// results are tables
80
81 if (auto * t = ast->as<ASTTablesInSelectQueryElement>())
82 {
83 visit(*t, ast, data);
84 return;
85 }
86
87 if (auto * t = ast->as<ASTTableExpression>())
88 {
89 visit(*t, ast, data);
90 return;
91 }
92 if (auto * t = ast->as<ASTSelectQuery>())
93 {
94 data.addTableAliasIfAny(*ast);
95 visit(*t, ast, data);
96 return;
97 }
98 if (ast->as<ASTSubquery>())
99 {
100 data.addTableAliasIfAny(*ast);
101 return;
102 }
103
104 /// other
105
106 if (auto * t = ast->as<ASTArrayJoin>())
107 {
108 data.has_array_join = true;
109 visit(*t, ast, data);
110 return;
111 }
112}
113
114void RequiredSourceColumnsMatcher::visit(const ASTSelectQuery & select, const ASTPtr &, Data & data)
115{
116 /// special case for top-level SELECT items: they are publics
117 for (auto & node : select.select()->children)
118 {
119 if (const auto * identifier = node->as<ASTIdentifier>())
120 data.addColumnIdentifier(*identifier);
121 else
122 data.addColumnAliasIfAny(*node);
123 }
124
125 std::vector<ASTPtr *> out;
126 for (auto & node : select.children)
127 if (node != select.select())
128 Visitor(data).visit(node);
129
130 /// revisit select_expression_list (with children) when all the aliases are set
131 Visitor(data).visit(select.select());
132}
133
134void RequiredSourceColumnsMatcher::visit(const ASTIdentifier & node, const ASTPtr &, Data & data)
135{
136 if (node.name.empty())
137 throw Exception("Expected not empty name", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
138
139 if (!data.private_aliases.count(node.name))
140 data.addColumnIdentifier(node);
141}
142
143void RequiredSourceColumnsMatcher::visit(const ASTFunction & node, const ASTPtr &, Data & data)
144{
145 /// Do not add formal parameters of the lambda expression
146 if (node.name == "lambda")
147 {
148 Names local_aliases;
149 for (const auto & name : extractNamesFromLambda(node))
150 if (data.private_aliases.insert(name).second)
151 local_aliases.push_back(name);
152
153 /// visit child with masked local aliases
154 RequiredSourceColumnsVisitor(data).visit(node.arguments->children[1]);
155
156 for (const auto & name : local_aliases)
157 data.private_aliases.erase(name);
158 }
159}
160
161void RequiredSourceColumnsMatcher::visit(const ASTTablesInSelectQueryElement & node, const ASTPtr &, Data & data)
162{
163 ASTTableExpression * expr = nullptr;
164 ASTTableJoin * join = nullptr;
165
166 for (auto & child : node.children)
167 {
168 if (auto * e = child->as<ASTTableExpression>())
169 expr = e;
170 if (auto * j = child->as<ASTTableJoin>())
171 join = j;
172 }
173
174 if (join)
175 data.has_table_join = true;
176 data.tables.emplace_back(ColumnNamesContext::JoinedTable{expr, join});
177}
178
179/// ASTIdentifiers here are tables. Do not visit them as generic ones.
180void RequiredSourceColumnsMatcher::visit(const ASTTableExpression & node, const ASTPtr &, Data & data)
181{
182 if (node.database_and_table_name)
183 data.addTableAliasIfAny(*node.database_and_table_name);
184
185 if (node.table_function)
186 data.addTableAliasIfAny(*node.table_function);
187
188 if (node.subquery)
189 data.addTableAliasIfAny(*node.subquery);
190}
191
192void RequiredSourceColumnsMatcher::visit(const ASTArrayJoin & node, const ASTPtr &, Data & data)
193{
194 ASTPtr expression_list = node.expression_list;
195 if (!expression_list || expression_list->children.empty())
196 throw Exception("Expected not empty expression_list", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
197
198 std::vector<ASTPtr *> out;
199
200 /// Tech debt. Ignore ARRAY JOIN top-level identifiers and aliases. There's its own logic for them.
201 for (auto & expr : expression_list->children)
202 {
203 data.addArrayJoinAliasIfAny(*expr);
204
205 if (const auto * identifier = expr->as<ASTIdentifier>())
206 {
207 data.addArrayJoinIdentifier(*identifier);
208 continue;
209 }
210
211 out.push_back(&expr);
212 }
213
214 for (ASTPtr * add_node : out)
215 Visitor(data).visit(*add_node);
216}
217
218}
219