1#include <Interpreters/evaluateConstantExpression.h>
2
3#include <Columns/ColumnConst.h>
4#include <Columns/ColumnsNumber.h>
5#include <Core/Block.h>
6#include <DataTypes/DataTypesNumber.h>
7#include <Interpreters/Context.h>
8#include <Interpreters/convertFieldToType.h>
9#include <Interpreters/ExpressionActions.h>
10#include <Interpreters/ExpressionAnalyzer.h>
11#include <Interpreters/SyntaxAnalyzer.h>
12#include <Parsers/ASTFunction.h>
13#include <Parsers/ASTIdentifier.h>
14#include <Parsers/ASTLiteral.h>
15#include <Parsers/ExpressionElementParsers.h>
16#include <TableFunctions/TableFunctionFactory.h>
17#include <Common/typeid_cast.h>
18#include <Interpreters/ReplaceQueryParameterVisitor.h>
19
20
21namespace DB
22{
23
24namespace ErrorCodes
25{
26 extern const int LOGICAL_ERROR;
27 extern const int BAD_ARGUMENTS;
28}
29
30
31std::pair<Field, std::shared_ptr<const IDataType>> evaluateConstantExpression(const ASTPtr & node, const Context & context)
32{
33 NamesAndTypesList source_columns = {{ "_dummy", std::make_shared<DataTypeUInt8>() }};
34 auto ast = node->clone();
35 ReplaceQueryParameterVisitor param_visitor(context.getQueryParameters());
36 param_visitor.visit(ast);
37 String name = ast->getColumnName();
38 auto syntax_result = SyntaxAnalyzer(context).analyze(ast, source_columns);
39 ExpressionActionsPtr expr_for_constant_folding = ExpressionAnalyzer(ast, syntax_result, context).getConstActions();
40
41 /// There must be at least one column in the block so that it knows the number of rows.
42 Block block_with_constants{{ ColumnConst::create(ColumnUInt8::create(1, 0), 1), std::make_shared<DataTypeUInt8>(), "_dummy" }};
43
44 expr_for_constant_folding->execute(block_with_constants);
45
46 if (!block_with_constants || block_with_constants.rows() == 0)
47 throw Exception("Logical error: empty block after evaluation of constant expression for IN, VALUES or LIMIT", ErrorCodes::LOGICAL_ERROR);
48
49 if (!block_with_constants.has(name))
50 throw Exception("Element of set in IN, VALUES or LIMIT is not a constant expression (result column not found): " + name, ErrorCodes::BAD_ARGUMENTS);
51
52 const ColumnWithTypeAndName & result = block_with_constants.getByName(name);
53 const IColumn & result_column = *result.column;
54
55 /// Expressions like rand() or now() are not constant
56 if (!isColumnConst(result_column))
57 throw Exception("Element of set in IN, VALUES or LIMIT is not a constant expression (result column is not const): " + name, ErrorCodes::BAD_ARGUMENTS);
58
59 return std::make_pair(result_column[0], result.type);
60}
61
62
63ASTPtr evaluateConstantExpressionAsLiteral(const ASTPtr & node, const Context & context)
64{
65 /// If it's already a literal.
66 if (node->as<ASTLiteral>())
67 return node;
68
69 /// Skip table functions.
70 if (const auto * table_func_ptr = node->as<ASTFunction>())
71 if (TableFunctionFactory::instance().isTableFunctionName(table_func_ptr->name))
72 return node;
73
74 return std::make_shared<ASTLiteral>(evaluateConstantExpression(node, context).first);
75}
76
77ASTPtr evaluateConstantExpressionOrIdentifierAsLiteral(const ASTPtr & node, const Context & context)
78{
79 if (const auto * id = node->as<ASTIdentifier>())
80 return std::make_shared<ASTLiteral>(id->name);
81
82 return evaluateConstantExpressionAsLiteral(node, context);
83}
84
85namespace
86{
87 using Conjunction = ColumnsWithTypeAndName;
88 using Disjunction = std::vector<Conjunction>;
89
90 Disjunction analyzeEquals(const ASTIdentifier * identifier, const ASTLiteral * literal, const ExpressionActionsPtr & expr)
91 {
92 if (!identifier || !literal)
93 {
94 return {};
95 }
96
97 for (const auto & name_and_type : expr->getRequiredColumnsWithTypes())
98 {
99 const auto & name = name_and_type.name;
100 const auto & type = name_and_type.type;
101
102 if (name == identifier->name)
103 {
104 ColumnWithTypeAndName column;
105 // FIXME: what to do if field is not convertable?
106 column.column = type->createColumnConst(1, convertFieldToType(literal->value, *type));
107 column.name = name;
108 column.type = type;
109 return {{std::move(column)}};
110 }
111 }
112
113 return {};
114 }
115
116 Disjunction andDNF(const Disjunction & left, const Disjunction & right)
117 {
118 if (left.empty())
119 {
120 return right;
121 }
122
123 Disjunction result;
124
125 for (const auto & conjunct1 : left)
126 {
127 for (const auto & conjunct2 : right)
128 {
129 Conjunction new_conjunct{conjunct1};
130 new_conjunct.insert(new_conjunct.end(), conjunct2.begin(), conjunct2.end());
131 result.emplace_back(new_conjunct);
132 }
133 }
134
135 return result;
136 }
137
138 Disjunction analyzeFunction(const ASTFunction * fn, const ExpressionActionsPtr & expr)
139 {
140 if (!fn)
141 {
142 return {};
143 }
144
145 // TODO: enumerate all possible function names!
146
147 if (fn->name == "equals")
148 {
149 const auto * left = fn->arguments->children.front().get();
150 const auto * right = fn->arguments->children.back().get();
151 const auto * identifier = left->as<ASTIdentifier>() ? left->as<ASTIdentifier>() : right->as<ASTIdentifier>();
152 const auto * literal = left->as<ASTLiteral>() ? left->as<ASTLiteral>() : right->as<ASTLiteral>();
153
154 return analyzeEquals(identifier, literal, expr);
155 }
156 else if (fn->name == "in")
157 {
158 const auto * left = fn->arguments->children.front().get();
159 const auto * right = fn->arguments->children.back().get();
160 const auto * identifier = left->as<ASTIdentifier>();
161 const auto * inner_fn = right->as<ASTFunction>();
162
163 if (!inner_fn)
164 {
165 return {};
166 }
167
168 const auto * tuple = inner_fn->children.front()->as<ASTExpressionList>();
169
170 if (!tuple)
171 {
172 return {};
173 }
174
175 Disjunction result;
176
177 for (const auto & child : tuple->children)
178 {
179 const auto * literal = child->as<ASTLiteral>();
180 const auto dnf = analyzeEquals(identifier, literal, expr);
181
182 if (dnf.empty())
183 {
184 return {};
185 }
186
187 result.insert(result.end(), dnf.begin(), dnf.end());
188 }
189
190 return result;
191 }
192 else if (fn->name == "or")
193 {
194 const auto * args = fn->children.front()->as<ASTExpressionList>();
195
196 if (!args)
197 {
198 return {};
199 }
200
201 Disjunction result;
202
203 for (const auto & arg : args->children)
204 {
205 const auto dnf = analyzeFunction(arg->as<ASTFunction>(), expr);
206
207 if (dnf.empty())
208 {
209 return {};
210 }
211
212 result.insert(result.end(), dnf.begin(), dnf.end());
213 }
214
215 return result;
216 }
217 else if (fn->name == "and")
218 {
219 const auto * args = fn->children.front()->as<ASTExpressionList>();
220
221 if (!args)
222 {
223 return {};
224 }
225
226 Disjunction result;
227
228 for (const auto & arg : args->children)
229 {
230 const auto dnf = analyzeFunction(arg->as<ASTFunction>(), expr);
231
232 if (dnf.empty())
233 {
234 continue;
235 }
236
237 result = andDNF(result, dnf);
238 }
239
240 return result;
241 }
242
243 return {};
244 }
245}
246
247std::optional<Blocks> evaluateExpressionOverConstantCondition(const ASTPtr & node, const ExpressionActionsPtr & target_expr)
248{
249 Blocks result;
250
251 // TODO: `node` may be always-false literal.
252
253 if (const auto * fn = node->as<ASTFunction>())
254 {
255 const auto dnf = analyzeFunction(fn, target_expr);
256
257 if (dnf.empty())
258 {
259 return {};
260 }
261
262 auto hasRequiredColumns = [&target_expr](const Block & block) -> bool
263 {
264 for (const auto & name : target_expr->getRequiredColumns())
265 {
266 bool hasColumn = false;
267 for (const auto & column_name : block.getNames())
268 {
269 if (column_name == name)
270 {
271 hasColumn = true;
272 break;
273 }
274 }
275
276 if (!hasColumn)
277 return false;
278 }
279
280 return true;
281 };
282
283 for (const auto & conjunct : dnf)
284 {
285 Block block(conjunct);
286
287 // Block should contain all required columns from `target_expr`
288 if (!hasRequiredColumns(block))
289 {
290 return {};
291 }
292
293 target_expr->execute(block);
294
295 if (block.rows() == 1)
296 {
297 result.push_back(block);
298 }
299 else if (block.rows() == 0)
300 {
301 // filter out cases like "WHERE a = 1 AND a = 2"
302 continue;
303 }
304 else
305 {
306 // FIXME: shouldn't happen
307 return {};
308 }
309 }
310 }
311
312 return {result};
313}
314
315}
316