1 | #include <algorithm> |
2 | #include <cassert> |
3 | #include <numeric> |
4 | #include <stdexcept> |
5 | #include <typeinfo> |
6 | |
7 | #include "random.hh" |
8 | #include "relmodel.hh" |
9 | #include "grammar.hh" |
10 | #include "schema.hh" |
11 | #include "impedance.hh" |
12 | |
13 | using namespace std; |
14 | |
15 | shared_ptr<table_ref> table_ref::factory(prod *p) { |
16 | try { |
17 | if (p->level < 3 + d6()) { |
18 | if (d6() > 3 && p->level < d6()) |
19 | return make_shared<table_subquery>(p); |
20 | if (d6() > 3) |
21 | return make_shared<joined_table>(p); |
22 | } |
23 | if (d6() > 3) |
24 | return make_shared<table_or_query_name>(p); |
25 | else |
26 | return make_shared<table_sample>(p); |
27 | } catch (runtime_error &e) { |
28 | p->retry(); |
29 | } |
30 | return factory(p); |
31 | } |
32 | |
33 | table_or_query_name::table_or_query_name(prod *p) : table_ref(p) { |
34 | t = random_pick(scope->tables); |
35 | refs.push_back(make_shared<aliased_relation>(scope->stmt_uid("ref" ), t)); |
36 | } |
37 | |
38 | void table_or_query_name::out(std::ostream &out) { |
39 | out << t->ident() << " as " << refs[0]->ident(); |
40 | } |
41 | |
42 | target_table::target_table(prod *p, table *victim) : table_ref(p) { |
43 | while (!victim || victim->schema == "pg_catalog" || !victim->is_base_table || !victim->columns().size()) { |
44 | struct named_relation *pick = random_pick(scope->tables); |
45 | victim = dynamic_cast<table *>(pick); |
46 | retry(); |
47 | } |
48 | victim_ = victim; |
49 | refs.push_back(make_shared<aliased_relation>(scope->stmt_uid("target" ), victim)); |
50 | } |
51 | |
52 | void target_table::out(std::ostream &out) { |
53 | out << victim_->ident() << " as " << refs[0]->ident(); |
54 | } |
55 | |
56 | table_sample::table_sample(prod *p) : table_ref(p) { |
57 | match(); |
58 | retry_limit = 1000; /* retries are cheap here */ |
59 | do { |
60 | auto pick = random_pick(scope->schema->base_tables); |
61 | t = dynamic_cast<struct table *>(pick); |
62 | retry(); |
63 | } while (!t || !t->is_base_table); |
64 | |
65 | refs.push_back(make_shared<aliased_relation>(scope->stmt_uid("sample" ), t)); |
66 | percent = 0.1 * d100(); |
67 | method = (d6() > 2) ? "system" : "bernoulli" ; |
68 | } |
69 | |
70 | void table_sample::out(std::ostream &out) { |
71 | out << t->ident() << " as " << refs[0]->ident() << " tablesample " << method << " (" << percent << ") " ; |
72 | } |
73 | |
74 | table_subquery::table_subquery(prod *p, bool lateral) : table_ref(p), is_lateral(lateral) { |
75 | query = make_shared<query_spec>(this, scope, lateral); |
76 | string alias = scope->stmt_uid("subq" ); |
77 | relation *aliased_rel = &query->select_list->derived_table; |
78 | refs.push_back(make_shared<aliased_relation>(alias, aliased_rel)); |
79 | } |
80 | |
81 | table_subquery::~table_subquery() { |
82 | } |
83 | |
84 | void table_subquery::accept(prod_visitor *v) { |
85 | query->accept(v); |
86 | v->visit(this); |
87 | } |
88 | |
89 | shared_ptr<join_cond> join_cond::factory(prod *p, table_ref &lhs, table_ref &rhs) { |
90 | try { |
91 | if (d6() < 6) |
92 | return make_shared<expr_join_cond>(p, lhs, rhs); |
93 | else |
94 | return make_shared<simple_join_cond>(p, lhs, rhs); |
95 | } catch (runtime_error &e) { |
96 | p->retry(); |
97 | } |
98 | return factory(p, lhs, rhs); |
99 | } |
100 | |
101 | simple_join_cond::simple_join_cond(prod *p, table_ref &lhs, table_ref &rhs) : join_cond(p, lhs, rhs) { |
102 | retry: |
103 | named_relation *left_rel = &*random_pick(lhs.refs); |
104 | |
105 | if (!left_rel->columns().size()) { |
106 | retry(); |
107 | goto retry; |
108 | } |
109 | |
110 | named_relation *right_rel = &*random_pick(rhs.refs); |
111 | |
112 | column &c1 = random_pick(left_rel->columns()); |
113 | |
114 | for (auto c2 : right_rel->columns()) { |
115 | if (c1.type == c2.type) { |
116 | condition += left_rel->ident() + "." + c1.name + " = " + right_rel->ident() + "." + c2.name + " " ; |
117 | break; |
118 | } |
119 | } |
120 | if (condition == "" ) { |
121 | retry(); |
122 | goto retry; |
123 | } |
124 | } |
125 | |
126 | void simple_join_cond::out(std::ostream &out) { |
127 | out << condition; |
128 | } |
129 | |
130 | expr_join_cond::expr_join_cond(prod *p, table_ref &lhs, table_ref &rhs) : join_cond(p, lhs, rhs), joinscope(p->scope) { |
131 | scope = &joinscope; |
132 | for (auto ref : lhs.refs) |
133 | joinscope.refs.push_back(&*ref); |
134 | for (auto ref : rhs.refs) |
135 | joinscope.refs.push_back(&*ref); |
136 | search = bool_expr::factory(this); |
137 | } |
138 | |
139 | void expr_join_cond::out(std::ostream &out) { |
140 | out << *search; |
141 | } |
142 | |
143 | joined_table::joined_table(prod *p) : table_ref(p) { |
144 | lhs = table_ref::factory(this); |
145 | rhs = table_ref::factory(this); |
146 | |
147 | condition = join_cond::factory(this, *lhs, *rhs); |
148 | |
149 | if (d6() < 4) { |
150 | type = "inner" ; |
151 | } else if (d6() < 4) { |
152 | type = "left" ; |
153 | } else { |
154 | type = "right" ; |
155 | } |
156 | |
157 | for (auto ref : lhs->refs) |
158 | refs.push_back(ref); |
159 | for (auto ref : rhs->refs) |
160 | refs.push_back(ref); |
161 | } |
162 | |
163 | void joined_table::out(std::ostream &out) { |
164 | out << *lhs; |
165 | indent(out); |
166 | out << type << " join " << *rhs; |
167 | indent(out); |
168 | out << "on (" << *condition << ")" ; |
169 | } |
170 | |
171 | void table_subquery::out(std::ostream &out) { |
172 | if (is_lateral) |
173 | out << "lateral " ; |
174 | out << "(" << *query << ") as " << refs[0]->ident(); |
175 | } |
176 | |
177 | void from_clause::out(std::ostream &out) { |
178 | if (!reflist.size()) |
179 | return; |
180 | out << "from " ; |
181 | |
182 | for (auto r = reflist.begin(); r < reflist.end(); r++) { |
183 | indent(out); |
184 | out << **r; |
185 | if (r + 1 != reflist.end()) |
186 | out << "," ; |
187 | } |
188 | } |
189 | |
190 | from_clause::from_clause(prod *p) : prod(p) { |
191 | reflist.push_back(table_ref::factory(this)); |
192 | for (auto r : reflist.back()->refs) |
193 | scope->refs.push_back(&*r); |
194 | |
195 | while (d6() > 5) { |
196 | // add a lateral subquery |
197 | if (!impedance::matched(typeid(lateral_subquery))) |
198 | break; |
199 | reflist.push_back(make_shared<lateral_subquery>(this)); |
200 | for (auto r : reflist.back()->refs) |
201 | scope->refs.push_back(&*r); |
202 | } |
203 | } |
204 | |
205 | select_list::select_list(prod *p) : prod(p) { |
206 | do { |
207 | shared_ptr<value_expr> e = value_expr::factory(this); |
208 | value_exprs.push_back(e); |
209 | ostringstream name; |
210 | name << "c" << columns++; |
211 | sqltype *t = e->type; |
212 | assert(t); |
213 | derived_table.columns().push_back(column(name.str(), t)); |
214 | } while (d6() > 1); |
215 | } |
216 | |
217 | void select_list::out(std::ostream &out) { |
218 | int i = 0; |
219 | for (auto expr = value_exprs.begin(); expr != value_exprs.end(); expr++) { |
220 | indent(out); |
221 | out << **expr << " as " << derived_table.columns()[i].name; |
222 | i++; |
223 | if (expr + 1 != value_exprs.end()) |
224 | out << ", " ; |
225 | } |
226 | } |
227 | |
228 | void query_spec::out(std::ostream &out) { |
229 | out << "select " << set_quantifier << " " << *select_list; |
230 | indent(out); |
231 | out << *from_clause; |
232 | indent(out); |
233 | out << "where " ; |
234 | out << *search; |
235 | if (limit_clause.length()) { |
236 | indent(out); |
237 | out << limit_clause; |
238 | } |
239 | } |
240 | |
241 | struct for_update_verify : prod_visitor { |
242 | virtual void visit(prod *p) { |
243 | if (dynamic_cast<window_function *>(p)) |
244 | throw("window function" ); |
245 | joined_table *join = dynamic_cast<joined_table *>(p); |
246 | if (join && join->type != "inner" ) |
247 | throw("outer join" ); |
248 | query_spec *subquery = dynamic_cast<query_spec *>(p); |
249 | if (subquery) |
250 | subquery->set_quantifier = "" ; |
251 | table_or_query_name *tab = dynamic_cast<table_or_query_name *>(p); |
252 | if (tab) { |
253 | table *actual_table = dynamic_cast<table *>(tab->t); |
254 | if (actual_table && !actual_table->is_insertable) |
255 | throw("read only" ); |
256 | if (actual_table->name.find("pg_" )) |
257 | throw("catalog" ); |
258 | } |
259 | table_sample *sample = dynamic_cast<table_sample *>(p); |
260 | if (sample) { |
261 | table *actual_table = dynamic_cast<table *>(sample->t); |
262 | if (actual_table && !actual_table->is_insertable) |
263 | throw("read only" ); |
264 | if (actual_table->name.find("pg_" )) |
265 | throw("catalog" ); |
266 | } |
267 | }; |
268 | }; |
269 | |
270 | select_for_update::select_for_update(prod *p, struct scope *s, bool lateral) : query_spec(p, s, lateral) { |
271 | static const char *modes[] = { |
272 | "update" , |
273 | "share" , |
274 | "no key update" , |
275 | "key share" , |
276 | }; |
277 | |
278 | try { |
279 | for_update_verify v1; |
280 | this->accept(&v1); |
281 | |
282 | } catch (const char *reason) { |
283 | lockmode = 0; |
284 | return; |
285 | } |
286 | lockmode = modes[d6() % (sizeof(modes) / sizeof(*modes))]; |
287 | set_quantifier = "" ; // disallow distinct |
288 | } |
289 | |
290 | void select_for_update::out(std::ostream &out) { |
291 | query_spec::out(out); |
292 | if (lockmode) { |
293 | indent(out); |
294 | out << " for " << lockmode; |
295 | } |
296 | } |
297 | |
298 | query_spec::query_spec(prod *p, struct scope *s, bool lateral) : prod(p), myscope(s) { |
299 | scope = &myscope; |
300 | scope->tables = s->tables; |
301 | |
302 | if (lateral) |
303 | scope->refs = s->refs; |
304 | |
305 | from_clause = make_shared<struct from_clause>(this); |
306 | select_list = make_shared<struct select_list>(this); |
307 | |
308 | set_quantifier = (d100() == 1) ? "distinct" : "" ; |
309 | |
310 | search = bool_expr::factory(this); |
311 | |
312 | if (d6() > 2) { |
313 | ostringstream cons; |
314 | cons << "limit " << d100() + d100(); |
315 | limit_clause = cons.str(); |
316 | } |
317 | } |
318 | |
319 | long prepare_stmt::seq; |
320 | |
321 | void modifying_stmt::pick_victim() { |
322 | do { |
323 | struct named_relation *pick = random_pick(scope->tables); |
324 | victim = dynamic_cast<struct table *>(pick); |
325 | retry(); |
326 | } while (!victim || victim->schema == "pg_catalog" || !victim->is_base_table || !victim->columns().size()); |
327 | } |
328 | |
329 | modifying_stmt::modifying_stmt(prod *p, struct scope *s, table *victim) : prod(p), myscope(s) { |
330 | scope = &myscope; |
331 | scope->tables = s->tables; |
332 | |
333 | if (!victim) |
334 | pick_victim(); |
335 | } |
336 | |
337 | delete_stmt::delete_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { |
338 | scope->refs.push_back(victim); |
339 | search = bool_expr::factory(this); |
340 | } |
341 | |
342 | delete_returning::delete_returning(prod *p, struct scope *s, table *victim) : delete_stmt(p, s, victim) { |
343 | match(); |
344 | select_list = make_shared<struct select_list>(this); |
345 | } |
346 | |
347 | insert_stmt::insert_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { |
348 | match(); |
349 | |
350 | for (auto col : victim->columns()) { |
351 | auto expr = value_expr::factory(this, col.type); |
352 | assert(expr->type == col.type); |
353 | value_exprs.push_back(expr); |
354 | } |
355 | } |
356 | |
357 | void insert_stmt::out(std::ostream &out) { |
358 | out << "insert into " << victim->ident() << " " ; |
359 | |
360 | if (!value_exprs.size()) { |
361 | out << "default values" ; |
362 | return; |
363 | } |
364 | |
365 | out << "values (" ; |
366 | |
367 | for (auto expr = value_exprs.begin(); expr != value_exprs.end(); expr++) { |
368 | indent(out); |
369 | out << **expr; |
370 | if (expr + 1 != value_exprs.end()) |
371 | out << ", " ; |
372 | } |
373 | out << ")" ; |
374 | } |
375 | |
376 | set_list::set_list(prod *p, table *target) : prod(p) { |
377 | do { |
378 | for (auto col : target->columns()) { |
379 | if (d6() < 4) |
380 | continue; |
381 | auto expr = value_expr::factory(this, col.type); |
382 | value_exprs.push_back(expr); |
383 | names.push_back(col.name); |
384 | } |
385 | } while (!names.size()); |
386 | } |
387 | |
388 | void set_list::out(std::ostream &out) { |
389 | assert(names.size()); |
390 | out << " set " ; |
391 | for (size_t i = 0; i < names.size(); i++) { |
392 | indent(out); |
393 | out << names[i] << " = " << *value_exprs[i]; |
394 | if (i + 1 != names.size()) |
395 | out << ", " ; |
396 | } |
397 | } |
398 | |
399 | update_stmt::update_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { |
400 | scope->refs.push_back(victim); |
401 | search = bool_expr::factory(this); |
402 | set_list = make_shared<struct set_list>(this, victim); |
403 | } |
404 | |
405 | void update_stmt::out(std::ostream &out) { |
406 | out << "update " << victim->ident() << *set_list; |
407 | } |
408 | |
409 | update_returning::update_returning(prod *p, struct scope *s, table *v) : update_stmt(p, s, v) { |
410 | match(); |
411 | |
412 | select_list = make_shared<struct select_list>(this); |
413 | } |
414 | |
415 | upsert_stmt::upsert_stmt(prod *p, struct scope *s, table *v) : insert_stmt(p, s, v) { |
416 | match(); |
417 | |
418 | if (!victim->constraints.size()) |
419 | fail("need table w/ constraint for upsert" ); |
420 | |
421 | set_list = std::make_shared<struct set_list>(this, victim); |
422 | search = bool_expr::factory(this); |
423 | constraint = random_pick(victim->constraints); |
424 | } |
425 | |
426 | shared_ptr<prod> statement_factory(struct scope *s) { |
427 | try { |
428 | s->new_stmt(); |
429 | if (d42() == 1) |
430 | return make_shared<merge_stmt>((struct prod *)0, s); |
431 | if (d42() == 1) |
432 | return make_shared<insert_stmt>((struct prod *)0, s); |
433 | else if (d42() == 1) |
434 | return make_shared<delete_returning>((struct prod *)0, s); |
435 | else if (d42() == 1) { |
436 | return make_shared<upsert_stmt>((struct prod *)0, s); |
437 | } else if (d42() == 1) |
438 | return make_shared<update_returning>((struct prod *)0, s); |
439 | else if (d6() > 4) |
440 | return make_shared<select_for_update>((struct prod *)0, s); |
441 | else if (d6() > 5) |
442 | return make_shared<common_table_expression>((struct prod *)0, s); |
443 | return make_shared<query_spec>((struct prod *)0, s); |
444 | } catch (runtime_error &e) { |
445 | return statement_factory(s); |
446 | } |
447 | } |
448 | |
449 | void common_table_expression::accept(prod_visitor *v) { |
450 | v->visit(this); |
451 | for (auto q : with_queries) |
452 | q->accept(v); |
453 | query->accept(v); |
454 | } |
455 | |
456 | common_table_expression::common_table_expression(prod *parent, struct scope *s) : prod(parent), myscope(s) { |
457 | scope = &myscope; |
458 | do { |
459 | shared_ptr<query_spec> query = make_shared<query_spec>(this, s); |
460 | with_queries.push_back(query); |
461 | string alias = scope->stmt_uid("jennifer" ); |
462 | relation *relation = &query->select_list->derived_table; |
463 | auto aliased_rel = make_shared<aliased_relation>(alias, relation); |
464 | refs.push_back(aliased_rel); |
465 | scope->tables.push_back(&*aliased_rel); |
466 | |
467 | } while (d6() > 2); |
468 | |
469 | retry: |
470 | do { |
471 | auto pick = random_pick(s->tables); |
472 | scope->tables.push_back(pick); |
473 | } while (d6() > 3); |
474 | try { |
475 | query = make_shared<query_spec>(this, scope); |
476 | } catch (runtime_error &e) { |
477 | retry(); |
478 | goto retry; |
479 | } |
480 | } |
481 | |
482 | void common_table_expression::out(std::ostream &out) { |
483 | out << "WITH " ; |
484 | for (size_t i = 0; i < with_queries.size(); i++) { |
485 | indent(out); |
486 | out << refs[i]->ident() << " AS " |
487 | << "(" << *with_queries[i] << ")" ; |
488 | if (i + 1 != with_queries.size()) |
489 | out << ", " ; |
490 | indent(out); |
491 | } |
492 | out << *query; |
493 | indent(out); |
494 | } |
495 | |
496 | merge_stmt::merge_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) { |
497 | match(); |
498 | target_table_ = make_shared<target_table>(this, victim); |
499 | data_source = table_ref::factory(this); |
500 | // join_condition = join_cond::factory(this, *target_table_, |
501 | // *data_source); |
502 | join_condition = make_shared<simple_join_cond>(this, *target_table_, *data_source); |
503 | |
504 | /* Put data_source into scope but not target_table. Visibility of |
505 | the latter varies depending on kind of when clause. */ |
506 | // for (auto r : data_source->refs) |
507 | // scope->refs.push_back(&*r); |
508 | |
509 | clauselist.push_back(when_clause::factory(this)); |
510 | while (d6() > 4) |
511 | clauselist.push_back(when_clause::factory(this)); |
512 | } |
513 | |
514 | void merge_stmt::out(std::ostream &out) { |
515 | out << "MERGE INTO " << *target_table_; |
516 | indent(out); |
517 | out << "USING " << *data_source; |
518 | indent(out); |
519 | out << "ON " << *join_condition; |
520 | indent(out); |
521 | for (auto p : clauselist) { |
522 | out << *p; |
523 | indent(out); |
524 | } |
525 | } |
526 | |
527 | void merge_stmt::accept(prod_visitor *v) { |
528 | v->visit(this); |
529 | target_table_->accept(v); |
530 | data_source->accept(v); |
531 | join_condition->accept(v); |
532 | for (auto p : clauselist) |
533 | p->accept(v); |
534 | } |
535 | |
536 | when_clause::when_clause(merge_stmt *p) : prod(p) { |
537 | condition = bool_expr::factory(this); |
538 | matched = d6() > 3; |
539 | } |
540 | |
541 | void when_clause::out(std::ostream &out) { |
542 | out << (matched ? "WHEN MATCHED " : "WHEN NOT MATCHED" ); |
543 | indent(out); |
544 | out << "AND " << *condition; |
545 | indent(out); |
546 | out << " THEN " ; |
547 | out << (matched ? "DELETE" : "DO NOTHING" ); |
548 | } |
549 | |
550 | void when_clause::accept(prod_visitor *v) { |
551 | v->visit(this); |
552 | condition->accept(v); |
553 | } |
554 | |
555 | when_clause_update::when_clause_update(merge_stmt *p) : when_clause(p), myscope(p->scope) { |
556 | myscope.tables = scope->tables; |
557 | myscope.refs = scope->refs; |
558 | scope = &myscope; |
559 | scope->refs.push_back(&*(p->target_table_->refs[0])); |
560 | |
561 | set_list = std::make_shared<struct set_list>(this, p->victim); |
562 | } |
563 | |
564 | void when_clause_update::out(std::ostream &out) { |
565 | out << "WHEN MATCHED AND " << *condition; |
566 | indent(out); |
567 | out << " THEN UPDATE " << *set_list; |
568 | } |
569 | |
570 | void when_clause_update::accept(prod_visitor *v) { |
571 | v->visit(this); |
572 | set_list->accept(v); |
573 | } |
574 | |
575 | when_clause_insert::when_clause_insert(struct merge_stmt *p) : when_clause(p) { |
576 | for (auto col : p->victim->columns()) { |
577 | auto expr = value_expr::factory(this, col.type); |
578 | assert(expr->type == col.type); |
579 | exprs.push_back(expr); |
580 | } |
581 | } |
582 | |
583 | void when_clause_insert::out(std::ostream &out) { |
584 | out << "WHEN NOT MATCHED AND " << *condition; |
585 | indent(out); |
586 | out << " THEN INSERT VALUES ( " ; |
587 | |
588 | for (auto expr = exprs.begin(); expr != exprs.end(); expr++) { |
589 | out << **expr; |
590 | if (expr + 1 != exprs.end()) |
591 | out << ", " ; |
592 | } |
593 | out << ")" ; |
594 | } |
595 | |
596 | void when_clause_insert::accept(prod_visitor *v) { |
597 | v->visit(this); |
598 | for (auto p : exprs) |
599 | p->accept(v); |
600 | } |
601 | |
602 | shared_ptr<when_clause> when_clause::factory(struct merge_stmt *p) { |
603 | try { |
604 | switch (d6()) { |
605 | case 1: |
606 | case 2: |
607 | return make_shared<when_clause_insert>(p); |
608 | case 3: |
609 | case 4: |
610 | return make_shared<when_clause_update>(p); |
611 | default: |
612 | return make_shared<when_clause>(p); |
613 | } |
614 | } catch (runtime_error &e) { |
615 | p->retry(); |
616 | } |
617 | return factory(p); |
618 | } |
619 | |