1 | #include <Parsers/ASTSelectQuery.h> |
2 | #include <Parsers/ParserSelectQuery.h> |
3 | #include <Parsers/parseQuery.h> |
4 | #include <Parsers/queryToString.h> |
5 | #include <Interpreters/LogicalExpressionsOptimizer.h> |
6 | #include <Core/Settings.h> |
7 | #include <Common/typeid_cast.h> |
8 | |
9 | #include <iostream> |
10 | #include <vector> |
11 | #include <utility> |
12 | #include <string> |
13 | #include <algorithm> |
14 | |
15 | |
16 | namespace |
17 | { |
18 | |
/// One test case: a query to optimize, the query it is expected to turn into,
/// and the numeric parameter handed to the optimizer.
struct TestEntry
{
    /// Text of the SELECT query that is parsed and then passed through the optimizer.
    std::string input;
    /// Text of the query whose parsed form the optimized query is compared against.
    std::string expected_output;
    /// Passed as the second constructor argument of DB::LogicalExpressionsOptimizer in check().
    /// NOTE(review): the test data suggests this is the minimum number of equality
    /// comparisons an OR chain must contain to be rewritten into IN - confirm against the optimizer.
    UInt64 limit;
};
25 | |
/// The full list of test cases, executed in order.
using TestEntries = std::vector<TestEntry>;
/// Outcome of a single test: `first` tells whether the test passed, `second` holds either
/// the optimized query rendered as text or a description of the error that occurred.
using TestResult = std::pair<bool, std::string>;

/// Forward declarations of the helpers defined below.
void run();
void performTests(const TestEntries & entries);
TestResult check(const TestEntry & entry);
bool parse(DB::ASTPtr & ast, const std::string & query);
bool equals(const DB::ASTPtr & lhs, const DB::ASTPtr & rhs);
void reorder(DB::IAST * ast);
35 | |
36 | |
37 | void run() |
38 | { |
39 | /// NOTE: Queries are not always realistic, but we are only interested in the syntax. |
40 | TestEntries entries = |
41 | { |
42 | { |
43 | "SELECT 1" , |
44 | "SELECT 1" , |
45 | 3 |
46 | }, |
47 | |
48 | // WHERE |
49 | |
50 | { |
51 | "SELECT name, value FROM report WHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
52 | "SELECT name, value FROM report WHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
53 | 4 |
54 | }, |
55 | |
56 | { |
57 | "SELECT name, value FROM report WHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
58 | "SELECT name, value FROM report WHERE name IN ('Alice', 'Bob', 'Carol')" , |
59 | 3 |
60 | }, |
61 | |
62 | { |
63 | "SELECT name, value FROM report WHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
64 | "SELECT name, value FROM report WHERE name IN ('Alice', 'Bob', 'Carol')" , |
65 | 2 |
66 | }, |
67 | |
68 | { |
69 | "SELECT name, value FROM report WHERE (name = 'Alice') OR (value = 1000) OR (name = 'Bob') OR (name = 'Carol')" , |
70 | "SELECT name, value FROM report WHERE (value = 1000) OR name IN ('Alice', 'Bob', 'Carol')" , |
71 | 2 |
72 | }, |
73 | |
74 | { |
75 | "SELECT name, value FROM report WHERE (name = 'Alice') OR (value = 1000) OR (name = 'Bob') OR (name = 'Carol') OR (value = 2000)" , |
76 | "SELECT name, value FROM report WHERE name IN ('Alice', 'Bob', 'Carol') OR value IN (1000, 2000)" , |
77 | 2 |
78 | }, |
79 | |
80 | { |
81 | "SELECT value FROM report WHERE ((value + 1) = 1000) OR ((2 * value) = 2000) OR ((2 * value) = 4000) OR ((value + 1) = 3000)" , |
82 | "SELECT value FROM report WHERE ((value + 1) IN (1000, 3000)) OR ((2 * value) IN (2000, 4000))" , |
83 | 2 |
84 | }, |
85 | |
86 | { |
87 | "SELECT name, value FROM report WHERE ((name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')) AND ((value = 1000) OR (value = 2000))" , |
88 | "SELECT name, value FROM report WHERE name IN ('Alice', 'Bob', 'Carol') AND ((value = 1000) OR (value = 2000))" , |
89 | 3 |
90 | }, |
91 | |
92 | // PREWHERE |
93 | |
94 | { |
95 | "SELECT name, value FROM report PREWHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
96 | "SELECT name, value FROM report PREWHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
97 | 4 |
98 | }, |
99 | |
100 | { |
101 | "SELECT name, value FROM report PREWHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
102 | "SELECT name, value FROM report PREWHERE name IN ('Alice', 'Bob', 'Carol')" , |
103 | 3 |
104 | }, |
105 | |
106 | { |
107 | "SELECT name, value FROM report PREWHERE (name = 'Alice') OR (name = 'Bob') OR (name = 'Carol')" , |
108 | "SELECT name, value FROM report PREWHERE name IN ('Alice', 'Bob', 'Carol')" , |
109 | 2 |
110 | }, |
111 | |
112 | { |
113 | "SELECT name, value FROM report PREWHERE (name = 'Alice') OR (value = 1000) OR (name = 'Bob') OR (name = 'Carol')" , |
114 | "SELECT name, value FROM report PREWHERE (value = 1000) OR name IN ('Alice', 'Bob', 'Carol')" , |
115 | 2 |
116 | }, |
117 | |
118 | { |
119 | "SELECT name, value FROM report PREWHERE (name = 'Alice') OR (value = 1000) OR (name = 'Bob') OR (name = 'Carol') OR (value = 2000)" , |
120 | "SELECT name, value FROM report PREWHERE name IN ('Alice', 'Bob', 'Carol') OR value IN (1000, 2000)" , |
121 | 2 |
122 | }, |
123 | |
124 | { |
125 | "SELECT value FROM report PREWHERE ((value + 1) = 1000) OR ((2 * value) = 2000) OR ((2 * value) = 4000) OR ((value + 1) = 3000)" , |
126 | "SELECT value FROM report PREWHERE (value + 1) IN (1000, 3000) OR (2 * value) IN (2000, 4000)" , |
127 | 2 |
128 | }, |
129 | |
130 | // HAVING |
131 | |
132 | { |
133 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING number = 1" , |
134 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING number = 1" , |
135 | 2 |
136 | }, |
137 | |
138 | { |
139 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING (number = 1) OR (number = 2)" , |
140 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING number IN (1, 2)" , |
141 | 2 |
142 | }, |
143 | |
144 | { |
145 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING (number = 1) OR (number = 2)" , |
146 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING (number = 1) OR (number = 2)" , |
147 | 3 |
148 | }, |
149 | |
150 | { |
151 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING ((number + 1) = 1) OR ((number + 1) = 2) OR ((number + 3) = 7)" , |
152 | "SELECT number, count() FROM (SELECT * FROM system.numbers LIMIT 10) GROUP BY number HAVING ((number + 3) = 7) OR (number + 1) IN (1, 2)" , |
153 | 2 |
154 | }, |
155 | |
156 | // PREWHERE + WHERE + HAVING |
157 | |
158 | { |
159 | "SELECT number, count(), 1 AS T, 2 AS U FROM (SELECT * FROM system.numbers LIMIT 10) PREWHERE (U = 1) OR (U = 2) " |
160 | "WHERE (T = 1) OR (T = 2) GROUP BY number HAVING (number = 1) OR (number = 2)" , |
161 | "SELECT number, count(), 1 AS T, 2 AS U FROM (SELECT * FROM system.numbers LIMIT 10) PREWHERE U IN (1, 2) " |
162 | "WHERE T IN (1, 2) GROUP BY number HAVING number IN (1, 2)" , |
163 | 2 |
164 | }, |
165 | |
166 | { |
167 | "SELECT number, count(), 1 AS T, 2 AS U FROM (SELECT * FROM system.numbers LIMIT 10) PREWHERE (U = 1) OR (U = 2) OR (U = 3) " |
168 | "WHERE (T = 1) OR (T = 2) GROUP BY number HAVING (number = 1) OR (number = 2)" , |
169 | "SELECT number, count(), 1 AS T, 2 AS U FROM (SELECT * FROM system.numbers LIMIT 10) PREWHERE U IN (1, 2, 3) " |
170 | "WHERE (T = 1) OR (T = 2) GROUP BY number HAVING (number = 1) OR (number = 2)" , |
171 | 3 |
172 | }, |
173 | |
174 | { |
175 | "SELECT x = 1 OR x=2 OR (x = 3 AS x3) AS y, 4 AS x" , |
176 | "SELECT x IN (1, 2, 3) AS y, 4 AS x" , |
177 | 2 |
178 | } |
179 | }; |
180 | |
181 | performTests(entries); |
182 | } |
183 | |
184 | void performTests(const TestEntries & entries) |
185 | { |
186 | unsigned int count = 0; |
187 | unsigned int i = 1; |
188 | |
189 | for (const auto & entry : entries) |
190 | { |
191 | auto res = check(entry); |
192 | if (res.first) |
193 | { |
194 | ++count; |
195 | std::cout << "Test " << i << " passed.\n" ; |
196 | } |
197 | else |
198 | std::cout << "Test " << i << " failed. Expected: " << entry.expected_output << ". Received: " << res.second << "\n" ; |
199 | |
200 | ++i; |
201 | } |
202 | std::cout << count << " out of " << entries.size() << " test(s) passed.\n" ; |
203 | } |
204 | |
205 | TestResult check(const TestEntry & entry) |
206 | { |
207 | try |
208 | { |
209 | /// Parse and optimize the incoming query. |
210 | DB::ASTPtr ast_input; |
211 | if (!parse(ast_input, entry.input)) |
212 | return TestResult(false, "parse error" ); |
213 | |
214 | auto select_query = typeid_cast<DB::ASTSelectQuery *>(&*ast_input); |
215 | |
216 | DB::LogicalExpressionsOptimizer optimizer(select_query, entry.limit); |
217 | optimizer.perform(); |
218 | |
219 | /// Parse the expected result. |
220 | DB::ASTPtr ast_expected; |
221 | if (!parse(ast_expected, entry.expected_output)) |
222 | return TestResult(false, "parse error" ); |
223 | |
224 | /// Compare the optimized query and the expected result. |
225 | bool res = equals(ast_input, ast_expected); |
226 | std::string output = DB::queryToString(ast_input); |
227 | |
228 | return TestResult(res, output); |
229 | } |
230 | catch (DB::Exception & e) |
231 | { |
232 | return TestResult(false, e.displayText()); |
233 | } |
234 | } |
235 | |
236 | bool parse(DB::ASTPtr & ast, const std::string & query) |
237 | { |
238 | DB::ParserSelectQuery parser; |
239 | std::string message; |
240 | auto begin = query.data(); |
241 | auto end = begin + query.size(); |
242 | ast = DB::tryParseQuery(parser, begin, end, message, false, "" , false, 0); |
243 | return ast != nullptr; |
244 | } |
245 | |
246 | bool equals(const DB::ASTPtr & lhs, const DB::ASTPtr & rhs) |
247 | { |
248 | DB::ASTPtr lhs_reordered = lhs->clone(); |
249 | reorder(&*lhs_reordered); |
250 | |
251 | DB::ASTPtr rhs_reordered = rhs->clone(); |
252 | reorder(&*rhs_reordered); |
253 | |
254 | return lhs_reordered->getTreeHash() == rhs_reordered->getTreeHash(); |
255 | } |
256 | |
257 | void reorderImpl(DB::IAST * ast) |
258 | { |
259 | if (ast == nullptr) |
260 | return; |
261 | |
262 | auto & children = ast->children; |
263 | if (children.empty()) |
264 | return; |
265 | |
266 | for (auto & child : children) |
267 | reorderImpl(&*child); |
268 | |
269 | std::sort(children.begin(), children.end(), [](const DB::ASTPtr & lhs, const DB::ASTPtr & rhs) |
270 | { |
271 | return lhs->getTreeHash() < rhs->getTreeHash(); |
272 | }); |
273 | } |
274 | |
275 | void reorder(DB::IAST * ast) |
276 | { |
277 | if (ast == nullptr) |
278 | return; |
279 | |
280 | auto select_query = typeid_cast<DB::ASTSelectQuery *>(ast); |
281 | if (select_query == nullptr) |
282 | return; |
283 | |
284 | reorderImpl(select_query->where().get()); |
285 | reorderImpl(select_query->prewhere().get()); |
286 | reorderImpl(select_query->having().get()); |
287 | } |
288 | |
289 | } |
290 | |
/// Entry point: runs all tests; the results are printed to stdout.
/// NOTE(review): the exit status is always 0, even if some tests fail -
/// failures are only visible in the printed output.
int main()
{
    run();
    return 0;
}
296 | |