1 | #include <Interpreters/LogicalExpressionsOptimizer.h> |
2 | #include <Core/Settings.h> |
3 | |
4 | #include <Parsers/ASTFunction.h> |
5 | #include <Parsers/ASTSelectQuery.h> |
6 | #include <Parsers/ASTLiteral.h> |
7 | |
8 | #include <Common/typeid_cast.h> |
9 | |
10 | #include <deque> |
11 | |
12 | |
13 | namespace DB |
14 | { |
15 | |
16 | namespace ErrorCodes |
17 | { |
18 | extern const int LOGICAL_ERROR; |
19 | } |
20 | |
21 | |
22 | LogicalExpressionsOptimizer::OrWithExpression::OrWithExpression(const ASTFunction * or_function_, |
23 | const IAST::Hash & expression_, const std::string & alias_) |
24 | : or_function(or_function_), expression(expression_), alias(alias_) |
25 | { |
26 | } |
27 | |
28 | bool LogicalExpressionsOptimizer::OrWithExpression::operator<(const OrWithExpression & rhs) const |
29 | { |
30 | return std::tie(this->or_function, this->expression) < std::tie(rhs.or_function, rhs.expression); |
31 | } |
32 | |
33 | LogicalExpressionsOptimizer::LogicalExpressionsOptimizer(ASTSelectQuery * select_query_, UInt64 optimize_min_equality_disjunction_chain_length) |
34 | : select_query(select_query_), settings(optimize_min_equality_disjunction_chain_length) |
35 | { |
36 | } |
37 | |
38 | void LogicalExpressionsOptimizer::perform() |
39 | { |
40 | if (select_query == nullptr) |
41 | return; |
42 | if (visited_nodes.count(select_query)) |
43 | return; |
44 | |
45 | size_t position = 0; |
46 | for (auto & column : select_query->select()->children) |
47 | { |
48 | bool inserted = column_to_position.emplace(column.get(), position).second; |
49 | |
50 | /// Do not run, if AST was already converted to DAG. |
51 | /// TODO This is temporary solution. We must completely eliminate conversion of AST to DAG. |
52 | /// (see ExpressionAnalyzer::normalizeTree) |
53 | if (!inserted) |
54 | return; |
55 | |
56 | ++position; |
57 | } |
58 | |
59 | collectDisjunctiveEqualityChains(); |
60 | |
61 | for (auto & chain : disjunctive_equality_chains_map) |
62 | { |
63 | if (!mayOptimizeDisjunctiveEqualityChain(chain)) |
64 | continue; |
65 | addInExpression(chain); |
66 | |
67 | auto & equalities = chain.second; |
68 | equalities.is_processed = true; |
69 | ++processed_count; |
70 | } |
71 | |
72 | if (processed_count > 0) |
73 | { |
74 | cleanupOrExpressions(); |
75 | fixBrokenOrExpressions(); |
76 | reorderColumns(); |
77 | } |
78 | } |
79 | |
80 | void LogicalExpressionsOptimizer::reorderColumns() |
81 | { |
82 | auto & columns = select_query->select()->children; |
83 | size_t cur_position = 0; |
84 | |
85 | while (cur_position < columns.size()) |
86 | { |
87 | size_t expected_position = column_to_position.at(columns[cur_position].get()); |
88 | if (cur_position != expected_position) |
89 | std::swap(columns[cur_position], columns[expected_position]); |
90 | else |
91 | ++cur_position; |
92 | } |
93 | } |
94 | |
95 | void LogicalExpressionsOptimizer::collectDisjunctiveEqualityChains() |
96 | { |
97 | if (visited_nodes.count(select_query)) |
98 | return; |
99 | |
100 | using Edge = std::pair<IAST *, IAST *>; |
101 | std::deque<Edge> to_visit; |
102 | |
103 | to_visit.emplace_back(nullptr, select_query); |
104 | while (!to_visit.empty()) |
105 | { |
106 | auto edge = to_visit.back(); |
107 | auto from_node = edge.first; |
108 | auto to_node = edge.second; |
109 | |
110 | to_visit.pop_back(); |
111 | |
112 | bool found_chain = false; |
113 | |
114 | auto * function = to_node->as<ASTFunction>(); |
115 | if (function && function->name == "or" && function->children.size() == 1) |
116 | { |
117 | const auto * expression_list = function->children[0]->as<ASTExpressionList>(); |
118 | if (expression_list) |
119 | { |
120 | /// The chain of elements of the OR expression. |
121 | for (auto & child : expression_list->children) |
122 | { |
123 | auto * equals = child->as<ASTFunction>(); |
124 | if (equals && equals->name == "equals" && equals->children.size() == 1) |
125 | { |
126 | const auto * equals_expression_list = equals->children[0]->as<ASTExpressionList>(); |
127 | if (equals_expression_list && equals_expression_list->children.size() == 2) |
128 | { |
129 | /// Equality expr = xN. |
130 | const auto * literal = equals_expression_list->children[1]->as<ASTLiteral>(); |
131 | if (literal) |
132 | { |
133 | auto expr_lhs = equals_expression_list->children[0]->getTreeHash(); |
134 | OrWithExpression or_with_expression{function, expr_lhs, function->tryGetAlias()}; |
135 | disjunctive_equality_chains_map[or_with_expression].functions.push_back(equals); |
136 | found_chain = true; |
137 | } |
138 | } |
139 | } |
140 | } |
141 | } |
142 | } |
143 | |
144 | visited_nodes.insert(to_node); |
145 | |
146 | if (found_chain) |
147 | { |
148 | if (from_node != nullptr) |
149 | { |
150 | auto res = or_parent_map.insert(std::make_pair(function, ParentNodes{from_node})); |
151 | if (!res.second) |
152 | throw Exception("LogicalExpressionsOptimizer: parent node information is corrupted" , |
153 | ErrorCodes::LOGICAL_ERROR); |
154 | } |
155 | } |
156 | else |
157 | { |
158 | for (auto & child : to_node->children) |
159 | { |
160 | if (!child->as<ASTSelectQuery>()) |
161 | { |
162 | if (!visited_nodes.count(child.get())) |
163 | to_visit.push_back(Edge(to_node, &*child)); |
164 | else |
165 | { |
166 | /// If the node is an OR function, update the information about its parents. |
167 | auto it = or_parent_map.find(&*child); |
168 | if (it != or_parent_map.end()) |
169 | { |
170 | auto & parent_nodes = it->second; |
171 | parent_nodes.push_back(to_node); |
172 | } |
173 | } |
174 | } |
175 | } |
176 | } |
177 | } |
178 | |
179 | for (auto & chain : disjunctive_equality_chains_map) |
180 | { |
181 | auto & equalities = chain.second; |
182 | auto & equality_functions = equalities.functions; |
183 | std::sort(equality_functions.begin(), equality_functions.end()); |
184 | } |
185 | } |
186 | |
187 | namespace |
188 | { |
189 | |
190 | inline ASTs & getFunctionOperands(const ASTFunction * or_function) |
191 | { |
192 | return or_function->children[0]->children; |
193 | } |
194 | |
195 | } |
196 | |
197 | bool LogicalExpressionsOptimizer::mayOptimizeDisjunctiveEqualityChain(const DisjunctiveEqualityChain & chain) const |
198 | { |
199 | const auto & equalities = chain.second; |
200 | const auto & equality_functions = equalities.functions; |
201 | |
202 | /// We eliminate too short chains. |
203 | if (equality_functions.size() < settings.optimize_min_equality_disjunction_chain_length) |
204 | return false; |
205 | |
206 | /// We check that the right-hand sides of all equalities have the same type. |
207 | auto & first_operands = getFunctionOperands(equality_functions[0]); |
208 | const auto * first_literal = first_operands[1]->as<ASTLiteral>(); |
209 | for (size_t i = 1; i < equality_functions.size(); ++i) |
210 | { |
211 | auto & operands = getFunctionOperands(equality_functions[i]); |
212 | const auto * literal = operands[1]->as<ASTLiteral>(); |
213 | |
214 | if (literal->value.getType() != first_literal->value.getType()) |
215 | return false; |
216 | } |
217 | return true; |
218 | } |
219 | |
220 | void LogicalExpressionsOptimizer::addInExpression(const DisjunctiveEqualityChain & chain) |
221 | { |
222 | const auto & or_with_expression = chain.first; |
223 | const auto & equalities = chain.second; |
224 | const auto & equality_functions = equalities.functions; |
225 | |
226 | /// 1. Create a new IN expression based on information from the OR-chain. |
227 | |
228 | /// Construct a list of literals `x1, ..., xN` from the string `expr = x1 OR ... OR expr = xN` |
229 | ASTPtr value_list = std::make_shared<ASTExpressionList>(); |
230 | for (const auto function : equality_functions) |
231 | { |
232 | const auto & operands = getFunctionOperands(function); |
233 | value_list->children.push_back(operands[1]); |
234 | } |
235 | |
236 | /// Sort the literals so that they are specified in the same order in the IN expression. |
237 | /// Otherwise, they would be specified in the order of the ASTLiteral addresses, which is nondeterministic. |
238 | std::sort(value_list->children.begin(), value_list->children.end(), [](const DB::ASTPtr & lhs, const DB::ASTPtr & rhs) |
239 | { |
240 | const auto * val_lhs = lhs->as<ASTLiteral>(); |
241 | const auto * val_rhs = rhs->as<ASTLiteral>(); |
242 | return val_lhs->value < val_rhs->value; |
243 | }); |
244 | |
245 | /// Get the expression `expr` from the chain `expr = x1 OR ... OR expr = xN` |
246 | ASTPtr equals_expr_lhs; |
247 | { |
248 | auto function = equality_functions[0]; |
249 | const auto & operands = getFunctionOperands(function); |
250 | equals_expr_lhs = operands[0]; |
251 | } |
252 | |
253 | auto tuple_function = std::make_shared<ASTFunction>(); |
254 | tuple_function->name = "tuple" ; |
255 | tuple_function->arguments = value_list; |
256 | tuple_function->children.push_back(tuple_function->arguments); |
257 | |
258 | ASTPtr expression_list = std::make_shared<ASTExpressionList>(); |
259 | expression_list->children.push_back(equals_expr_lhs); |
260 | expression_list->children.push_back(tuple_function); |
261 | |
262 | /// Construct the expression `expr IN (x1, ..., xN)` |
263 | auto in_function = std::make_shared<ASTFunction>(); |
264 | in_function->name = "in" ; |
265 | in_function->arguments = expression_list; |
266 | in_function->children.push_back(in_function->arguments); |
267 | in_function->setAlias(or_with_expression.alias); |
268 | |
269 | /// 2. Insert the new IN expression. |
270 | |
271 | auto & operands = getFunctionOperands(or_with_expression.or_function); |
272 | operands.push_back(in_function); |
273 | } |
274 | |
275 | void LogicalExpressionsOptimizer::cleanupOrExpressions() |
276 | { |
277 | /// Saves for each optimized OR-chain the iterator on the first element |
278 | /// list of operands to be deleted. |
279 | std::unordered_map<const ASTFunction *, ASTs::iterator> garbage_map; |
280 | |
281 | /// Initialization. |
282 | garbage_map.reserve(processed_count); |
283 | for (const auto & chain : disjunctive_equality_chains_map) |
284 | { |
285 | if (!chain.second.is_processed) |
286 | continue; |
287 | |
288 | const auto & or_with_expression = chain.first; |
289 | auto & operands = getFunctionOperands(or_with_expression.or_function); |
290 | garbage_map.emplace(or_with_expression.or_function, operands.end()); |
291 | } |
292 | |
293 | /// Collect garbage. |
294 | for (const auto & chain : disjunctive_equality_chains_map) |
295 | { |
296 | const auto & equalities = chain.second; |
297 | if (!equalities.is_processed) |
298 | continue; |
299 | |
300 | const auto & or_with_expression = chain.first; |
301 | auto & operands = getFunctionOperands(or_with_expression.or_function); |
302 | const auto & equality_functions = equalities.functions; |
303 | |
304 | auto it = garbage_map.find(or_with_expression.or_function); |
305 | if (it == garbage_map.end()) |
306 | throw Exception("LogicalExpressionsOptimizer: garbage map is corrupted" , |
307 | ErrorCodes::LOGICAL_ERROR); |
308 | |
309 | auto & first_erased = it->second; |
310 | first_erased = std::remove_if(operands.begin(), first_erased, [&](const ASTPtr & operand) |
311 | { |
312 | return std::binary_search(equality_functions.begin(), equality_functions.end(), &*operand); |
313 | }); |
314 | } |
315 | |
316 | /// Delete garbage. |
317 | for (const auto & entry : garbage_map) |
318 | { |
319 | auto function = entry.first; |
320 | auto first_erased = entry.second; |
321 | |
322 | auto & operands = getFunctionOperands(function); |
323 | operands.erase(first_erased, operands.end()); |
324 | } |
325 | } |
326 | |
327 | void LogicalExpressionsOptimizer::fixBrokenOrExpressions() |
328 | { |
329 | for (const auto & chain : disjunctive_equality_chains_map) |
330 | { |
331 | const auto & equalities = chain.second; |
332 | if (!equalities.is_processed) |
333 | continue; |
334 | |
335 | const auto & or_with_expression = chain.first; |
336 | auto or_function = or_with_expression.or_function; |
337 | auto & operands = getFunctionOperands(or_with_expression.or_function); |
338 | |
339 | if (operands.size() == 1) |
340 | { |
341 | auto it = or_parent_map.find(or_function); |
342 | if (it == or_parent_map.end()) |
343 | throw Exception("LogicalExpressionsOptimizer: parent node information is corrupted" , |
344 | ErrorCodes::LOGICAL_ERROR); |
345 | auto & parents = it->second; |
346 | |
347 | auto it2 = column_to_position.find(or_function); |
348 | if (it2 != column_to_position.end()) |
349 | { |
350 | size_t position = it2->second; |
351 | bool inserted = column_to_position.emplace(operands[0].get(), position).second; |
352 | if (!inserted) |
353 | throw Exception("LogicalExpressionsOptimizer: internal error" , ErrorCodes::LOGICAL_ERROR); |
354 | column_to_position.erase(it2); |
355 | } |
356 | |
357 | for (auto & parent : parents) |
358 | { |
359 | // The order of children matters if or is children of some function, e.g. minus |
360 | std::replace_if(parent->children.begin(), parent->children.end(), |
361 | [or_function](const ASTPtr & ptr) { return ptr.get() == or_function; }, |
362 | operands[0]); |
363 | } |
364 | |
365 | /// If the OR node was the root of the WHERE, PREWHERE, or HAVING expression, then update this root. |
366 | /// Due to the fact that we are dealing with a directed acyclic graph, we must check all cases. |
367 | if (select_query->where() && (or_function == &*(select_query->where()))) |
368 | select_query->setExpression(ASTSelectQuery::Expression::WHERE, operands[0]->clone()); |
369 | if (select_query->prewhere() && (or_function == &*(select_query->prewhere()))) |
370 | select_query->setExpression(ASTSelectQuery::Expression::PREWHERE, operands[0]->clone()); |
371 | if (select_query->having() && (or_function == &*(select_query->having()))) |
372 | select_query->setExpression(ASTSelectQuery::Expression::HAVING, operands[0]->clone()); |
373 | } |
374 | } |
375 | } |
376 | |
377 | } |
378 | |