1#include <Interpreters/LogicalExpressionsOptimizer.h>
2#include <Core/Settings.h>
3
4#include <Parsers/ASTFunction.h>
5#include <Parsers/ASTSelectQuery.h>
6#include <Parsers/ASTLiteral.h>
7
8#include <Common/typeid_cast.h>
9
10#include <deque>
11
12
13namespace DB
14{
15
/// Error codes referenced in this file. The constant is only declared here (extern);
/// its definition lives elsewhere.
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
}
20
21
22LogicalExpressionsOptimizer::OrWithExpression::OrWithExpression(const ASTFunction * or_function_,
23 const IAST::Hash & expression_, const std::string & alias_)
24 : or_function(or_function_), expression(expression_), alias(alias_)
25{
26}
27
28bool LogicalExpressionsOptimizer::OrWithExpression::operator<(const OrWithExpression & rhs) const
29{
30 return std::tie(this->or_function, this->expression) < std::tie(rhs.or_function, rhs.expression);
31}
32
33LogicalExpressionsOptimizer::LogicalExpressionsOptimizer(ASTSelectQuery * select_query_, UInt64 optimize_min_equality_disjunction_chain_length)
34 : select_query(select_query_), settings(optimize_min_equality_disjunction_chain_length)
35{
36}
37
38void LogicalExpressionsOptimizer::perform()
39{
40 if (select_query == nullptr)
41 return;
42 if (visited_nodes.count(select_query))
43 return;
44
45 size_t position = 0;
46 for (auto & column : select_query->select()->children)
47 {
48 bool inserted = column_to_position.emplace(column.get(), position).second;
49
50 /// Do not run, if AST was already converted to DAG.
51 /// TODO This is temporary solution. We must completely eliminate conversion of AST to DAG.
52 /// (see ExpressionAnalyzer::normalizeTree)
53 if (!inserted)
54 return;
55
56 ++position;
57 }
58
59 collectDisjunctiveEqualityChains();
60
61 for (auto & chain : disjunctive_equality_chains_map)
62 {
63 if (!mayOptimizeDisjunctiveEqualityChain(chain))
64 continue;
65 addInExpression(chain);
66
67 auto & equalities = chain.second;
68 equalities.is_processed = true;
69 ++processed_count;
70 }
71
72 if (processed_count > 0)
73 {
74 cleanupOrExpressions();
75 fixBrokenOrExpressions();
76 reorderColumns();
77 }
78}
79
80void LogicalExpressionsOptimizer::reorderColumns()
81{
82 auto & columns = select_query->select()->children;
83 size_t cur_position = 0;
84
85 while (cur_position < columns.size())
86 {
87 size_t expected_position = column_to_position.at(columns[cur_position].get());
88 if (cur_position != expected_position)
89 std::swap(columns[cur_position], columns[expected_position]);
90 else
91 ++cur_position;
92 }
93}
94
95void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains()
96{
97 if (visited_nodes.count(select_query))
98 return;
99
100 using Edge = std::pair<IAST *, IAST *>;
101 std::deque<Edge> to_visit;
102
103 to_visit.emplace_back(nullptr, select_query);
104 while (!to_visit.empty())
105 {
106 auto edge = to_visit.back();
107 auto from_node = edge.first;
108 auto to_node = edge.second;
109
110 to_visit.pop_back();
111
112 bool found_chain = false;
113
114 auto * function = to_node->as<ASTFunction>();
115 if (function && function->name == "or" && function->children.size() == 1)
116 {
117 const auto * expression_list = function->children[0]->as<ASTExpressionList>();
118 if (expression_list)
119 {
120 /// The chain of elements of the OR expression.
121 for (auto & child : expression_list->children)
122 {
123 auto * equals = child->as<ASTFunction>();
124 if (equals && equals->name == "equals" && equals->children.size() == 1)
125 {
126 const auto * equals_expression_list = equals->children[0]->as<ASTExpressionList>();
127 if (equals_expression_list && equals_expression_list->children.size() == 2)
128 {
129 /// Equality expr = xN.
130 const auto * literal = equals_expression_list->children[1]->as<ASTLiteral>();
131 if (literal)
132 {
133 auto expr_lhs = equals_expression_list->children[0]->getTreeHash();
134 OrWithExpression or_with_expression{function, expr_lhs, function->tryGetAlias()};
135 disjunctive_equality_chains_map[or_with_expression].functions.push_back(equals);
136 found_chain = true;
137 }
138 }
139 }
140 }
141 }
142 }
143
144 visited_nodes.insert(to_node);
145
146 if (found_chain)
147 {
148 if (from_node != nullptr)
149 {
150 auto res = or_parent_map.insert(std::make_pair(function, ParentNodes{from_node}));
151 if (!res.second)
152 throw Exception("LogicalExpressionsOptimizer: parent node information is corrupted",
153 ErrorCodes::LOGICAL_ERROR);
154 }
155 }
156 else
157 {
158 for (auto & child : to_node->children)
159 {
160 if (!child->as<ASTSelectQuery>())
161 {
162 if (!visited_nodes.count(child.get()))
163 to_visit.push_back(Edge(to_node, &*child));
164 else
165 {
166 /// If the node is an OR function, update the information about its parents.
167 auto it = or_parent_map.find(&*child);
168 if (it != or_parent_map.end())
169 {
170 auto & parent_nodes = it->second;
171 parent_nodes.push_back(to_node);
172 }
173 }
174 }
175 }
176 }
177 }
178
179 for (auto & chain : disjunctive_equality_chains_map)
180 {
181 auto & equalities = chain.second;
182 auto & equality_functions = equalities.functions;
183 std::sort(equality_functions.begin(), equality_functions.end());
184 }
185}
186
187namespace
188{
189
190inline ASTs & getFunctionOperands(const ASTFunction * or_function)
191{
192 return or_function->children[0]->children;
193}
194
195}
196
197bool LogicalExpressionsOptimizer::mayOptimizeDisjunctiveEqualityChain(const DisjunctiveEqualityChain & chain) const
198{
199 const auto & equalities = chain.second;
200 const auto & equality_functions = equalities.functions;
201
202 /// We eliminate too short chains.
203 if (equality_functions.size() < settings.optimize_min_equality_disjunction_chain_length)
204 return false;
205
206 /// We check that the right-hand sides of all equalities have the same type.
207 auto & first_operands = getFunctionOperands(equality_functions[0]);
208 const auto * first_literal = first_operands[1]->as<ASTLiteral>();
209 for (size_t i = 1; i < equality_functions.size(); ++i)
210 {
211 auto & operands = getFunctionOperands(equality_functions[i]);
212 const auto * literal = operands[1]->as<ASTLiteral>();
213
214 if (literal->value.getType() != first_literal->value.getType())
215 return false;
216 }
217 return true;
218}
219
220void LogicalExpressionsOptimizer::addInExpression(const DisjunctiveEqualityChain & chain)
221{
222 const auto & or_with_expression = chain.first;
223 const auto & equalities = chain.second;
224 const auto & equality_functions = equalities.functions;
225
226 /// 1. Create a new IN expression based on information from the OR-chain.
227
228 /// Construct a list of literals `x1, ..., xN` from the string `expr = x1 OR ... OR expr = xN`
229 ASTPtr value_list = std::make_shared<ASTExpressionList>();
230 for (const auto function : equality_functions)
231 {
232 const auto & operands = getFunctionOperands(function);
233 value_list->children.push_back(operands[1]);
234 }
235
236 /// Sort the literals so that they are specified in the same order in the IN expression.
237 /// Otherwise, they would be specified in the order of the ASTLiteral addresses, which is nondeterministic.
238 std::sort(value_list->children.begin(), value_list->children.end(), [](const DB::ASTPtr & lhs, const DB::ASTPtr & rhs)
239 {
240 const auto * val_lhs = lhs->as<ASTLiteral>();
241 const auto * val_rhs = rhs->as<ASTLiteral>();
242 return val_lhs->value < val_rhs->value;
243 });
244
245 /// Get the expression `expr` from the chain `expr = x1 OR ... OR expr = xN`
246 ASTPtr equals_expr_lhs;
247 {
248 auto function = equality_functions[0];
249 const auto & operands = getFunctionOperands(function);
250 equals_expr_lhs = operands[0];
251 }
252
253 auto tuple_function = std::make_shared<ASTFunction>();
254 tuple_function->name = "tuple";
255 tuple_function->arguments = value_list;
256 tuple_function->children.push_back(tuple_function->arguments);
257
258 ASTPtr expression_list = std::make_shared<ASTExpressionList>();
259 expression_list->children.push_back(equals_expr_lhs);
260 expression_list->children.push_back(tuple_function);
261
262 /// Construct the expression `expr IN (x1, ..., xN)`
263 auto in_function = std::make_shared<ASTFunction>();
264 in_function->name = "in";
265 in_function->arguments = expression_list;
266 in_function->children.push_back(in_function->arguments);
267 in_function->setAlias(or_with_expression.alias);
268
269 /// 2. Insert the new IN expression.
270
271 auto & operands = getFunctionOperands(or_with_expression.or_function);
272 operands.push_back(in_function);
273}
274
275void LogicalExpressionsOptimizer::cleanupOrExpressions()
276{
277 /// Saves for each optimized OR-chain the iterator on the first element
278 /// list of operands to be deleted.
279 std::unordered_map<const ASTFunction *, ASTs::iterator> garbage_map;
280
281 /// Initialization.
282 garbage_map.reserve(processed_count);
283 for (const auto & chain : disjunctive_equality_chains_map)
284 {
285 if (!chain.second.is_processed)
286 continue;
287
288 const auto & or_with_expression = chain.first;
289 auto & operands = getFunctionOperands(or_with_expression.or_function);
290 garbage_map.emplace(or_with_expression.or_function, operands.end());
291 }
292
293 /// Collect garbage.
294 for (const auto & chain : disjunctive_equality_chains_map)
295 {
296 const auto & equalities = chain.second;
297 if (!equalities.is_processed)
298 continue;
299
300 const auto & or_with_expression = chain.first;
301 auto & operands = getFunctionOperands(or_with_expression.or_function);
302 const auto & equality_functions = equalities.functions;
303
304 auto it = garbage_map.find(or_with_expression.or_function);
305 if (it == garbage_map.end())
306 throw Exception("LogicalExpressionsOptimizer: garbage map is corrupted",
307 ErrorCodes::LOGICAL_ERROR);
308
309 auto & first_erased = it->second;
310 first_erased = std::remove_if(operands.begin(), first_erased, [&](const ASTPtr & operand)
311 {
312 return std::binary_search(equality_functions.begin(), equality_functions.end(), &*operand);
313 });
314 }
315
316 /// Delete garbage.
317 for (const auto & entry : garbage_map)
318 {
319 auto function = entry.first;
320 auto first_erased = entry.second;
321
322 auto & operands = getFunctionOperands(function);
323 operands.erase(first_erased, operands.end());
324 }
325}
326
327void LogicalExpressionsOptimizer::fixBrokenOrExpressions()
328{
329 for (const auto & chain : disjunctive_equality_chains_map)
330 {
331 const auto & equalities = chain.second;
332 if (!equalities.is_processed)
333 continue;
334
335 const auto & or_with_expression = chain.first;
336 auto or_function = or_with_expression.or_function;
337 auto & operands = getFunctionOperands(or_with_expression.or_function);
338
339 if (operands.size() == 1)
340 {
341 auto it = or_parent_map.find(or_function);
342 if (it == or_parent_map.end())
343 throw Exception("LogicalExpressionsOptimizer: parent node information is corrupted",
344 ErrorCodes::LOGICAL_ERROR);
345 auto & parents = it->second;
346
347 auto it2 = column_to_position.find(or_function);
348 if (it2 != column_to_position.end())
349 {
350 size_t position = it2->second;
351 bool inserted = column_to_position.emplace(operands[0].get(), position).second;
352 if (!inserted)
353 throw Exception("LogicalExpressionsOptimizer: internal error", ErrorCodes::LOGICAL_ERROR);
354 column_to_position.erase(it2);
355 }
356
357 for (auto & parent : parents)
358 {
359 // The order of children matters if or is children of some function, e.g. minus
360 std::replace_if(parent->children.begin(), parent->children.end(),
361 [or_function](const ASTPtr & ptr) { return ptr.get() == or_function; },
362 operands[0]);
363 }
364
365 /// If the OR node was the root of the WHERE, PREWHERE, or HAVING expression, then update this root.
366 /// Due to the fact that we are dealing with a directed acyclic graph, we must check all cases.
367 if (select_query->where() && (or_function == &*(select_query->where())))
368 select_query->setExpression(ASTSelectQuery::Expression::WHERE, operands[0]->clone());
369 if (select_query->prewhere() && (or_function == &*(select_query->prewhere())))
370 select_query->setExpression(ASTSelectQuery::Expression::PREWHERE, operands[0]->clone());
371 if (select_query->having() && (or_function == &*(select_query->having())))
372 select_query->setExpression(ASTSelectQuery::Expression::HAVING, operands[0]->clone());
373 }
374 }
375}
376
377}
378