1#include <algorithm>
2#include <cassert>
3#include <numeric>
4#include <stdexcept>
5#include <typeinfo>
6
7#include "random.hh"
8#include "relmodel.hh"
9#include "grammar.hh"
10#include "schema.hh"
11#include "impedance.hh"
12
13using namespace std;
14
15shared_ptr<table_ref> table_ref::factory(prod *p) {
16 try {
17 if (p->level < 3 + d6()) {
18 if (d6() > 3 && p->level < d6())
19 return make_shared<table_subquery>(p);
20 if (d6() > 3)
21 return make_shared<joined_table>(p);
22 }
23 if (d6() > 3)
24 return make_shared<table_or_query_name>(p);
25 else
26 return make_shared<table_sample>(p);
27 } catch (runtime_error &e) {
28 p->retry();
29 }
30 return factory(p);
31}
32
33table_or_query_name::table_or_query_name(prod *p) : table_ref(p) {
34 t = random_pick(scope->tables);
35 refs.push_back(make_shared<aliased_relation>(scope->stmt_uid("ref"), t));
36}
37
38void table_or_query_name::out(std::ostream &out) {
39 out << t->ident() << " as " << refs[0]->ident();
40}
41
42target_table::target_table(prod *p, table *victim) : table_ref(p) {
43 while (!victim || victim->schema == "pg_catalog" || !victim->is_base_table || !victim->columns().size()) {
44 struct named_relation *pick = random_pick(scope->tables);
45 victim = dynamic_cast<table *>(pick);
46 retry();
47 }
48 victim_ = victim;
49 refs.push_back(make_shared<aliased_relation>(scope->stmt_uid("target"), victim));
50}
51
52void target_table::out(std::ostream &out) {
53 out << victim_->ident() << " as " << refs[0]->ident();
54}
55
56table_sample::table_sample(prod *p) : table_ref(p) {
57 match();
58 retry_limit = 1000; /* retries are cheap here */
59 do {
60 auto pick = random_pick(scope->schema->base_tables);
61 t = dynamic_cast<struct table *>(pick);
62 retry();
63 } while (!t || !t->is_base_table);
64
65 refs.push_back(make_shared<aliased_relation>(scope->stmt_uid("sample"), t));
66 percent = 0.1 * d100();
67 method = (d6() > 2) ? "system" : "bernoulli";
68}
69
70void table_sample::out(std::ostream &out) {
71 out << t->ident() << " as " << refs[0]->ident() << " tablesample " << method << " (" << percent << ") ";
72}
73
74table_subquery::table_subquery(prod *p, bool lateral) : table_ref(p), is_lateral(lateral) {
75 query = make_shared<query_spec>(this, scope, lateral);
76 string alias = scope->stmt_uid("subq");
77 relation *aliased_rel = &query->select_list->derived_table;
78 refs.push_back(make_shared<aliased_relation>(alias, aliased_rel));
79}
80
81table_subquery::~table_subquery() {
82}
83
84void table_subquery::accept(prod_visitor *v) {
85 query->accept(v);
86 v->visit(this);
87}
88
89shared_ptr<join_cond> join_cond::factory(prod *p, table_ref &lhs, table_ref &rhs) {
90 try {
91 if (d6() < 6)
92 return make_shared<expr_join_cond>(p, lhs, rhs);
93 else
94 return make_shared<simple_join_cond>(p, lhs, rhs);
95 } catch (runtime_error &e) {
96 p->retry();
97 }
98 return factory(p, lhs, rhs);
99}
100
101simple_join_cond::simple_join_cond(prod *p, table_ref &lhs, table_ref &rhs) : join_cond(p, lhs, rhs) {
102retry:
103 named_relation *left_rel = &*random_pick(lhs.refs);
104
105 if (!left_rel->columns().size()) {
106 retry();
107 goto retry;
108 }
109
110 named_relation *right_rel = &*random_pick(rhs.refs);
111
112 column &c1 = random_pick(left_rel->columns());
113
114 for (auto c2 : right_rel->columns()) {
115 if (c1.type == c2.type) {
116 condition += left_rel->ident() + "." + c1.name + " = " + right_rel->ident() + "." + c2.name + " ";
117 break;
118 }
119 }
120 if (condition == "") {
121 retry();
122 goto retry;
123 }
124}
125
126void simple_join_cond::out(std::ostream &out) {
127 out << condition;
128}
129
130expr_join_cond::expr_join_cond(prod *p, table_ref &lhs, table_ref &rhs) : join_cond(p, lhs, rhs), joinscope(p->scope) {
131 scope = &joinscope;
132 for (auto ref : lhs.refs)
133 joinscope.refs.push_back(&*ref);
134 for (auto ref : rhs.refs)
135 joinscope.refs.push_back(&*ref);
136 search = bool_expr::factory(this);
137}
138
139void expr_join_cond::out(std::ostream &out) {
140 out << *search;
141}
142
143joined_table::joined_table(prod *p) : table_ref(p) {
144 lhs = table_ref::factory(this);
145 rhs = table_ref::factory(this);
146
147 condition = join_cond::factory(this, *lhs, *rhs);
148
149 if (d6() < 4) {
150 type = "inner";
151 } else if (d6() < 4) {
152 type = "left";
153 } else {
154 type = "right";
155 }
156
157 for (auto ref : lhs->refs)
158 refs.push_back(ref);
159 for (auto ref : rhs->refs)
160 refs.push_back(ref);
161}
162
163void joined_table::out(std::ostream &out) {
164 out << *lhs;
165 indent(out);
166 out << type << " join " << *rhs;
167 indent(out);
168 out << "on (" << *condition << ")";
169}
170
171void table_subquery::out(std::ostream &out) {
172 if (is_lateral)
173 out << "lateral ";
174 out << "(" << *query << ") as " << refs[0]->ident();
175}
176
177void from_clause::out(std::ostream &out) {
178 if (!reflist.size())
179 return;
180 out << "from ";
181
182 for (auto r = reflist.begin(); r < reflist.end(); r++) {
183 indent(out);
184 out << **r;
185 if (r + 1 != reflist.end())
186 out << ",";
187 }
188}
189
190from_clause::from_clause(prod *p) : prod(p) {
191 reflist.push_back(table_ref::factory(this));
192 for (auto r : reflist.back()->refs)
193 scope->refs.push_back(&*r);
194
195 while (d6() > 5) {
196 // add a lateral subquery
197 if (!impedance::matched(typeid(lateral_subquery)))
198 break;
199 reflist.push_back(make_shared<lateral_subquery>(this));
200 for (auto r : reflist.back()->refs)
201 scope->refs.push_back(&*r);
202 }
203}
204
205select_list::select_list(prod *p) : prod(p) {
206 do {
207 shared_ptr<value_expr> e = value_expr::factory(this);
208 value_exprs.push_back(e);
209 ostringstream name;
210 name << "c" << columns++;
211 sqltype *t = e->type;
212 assert(t);
213 derived_table.columns().push_back(column(name.str(), t));
214 } while (d6() > 1);
215}
216
217void select_list::out(std::ostream &out) {
218 int i = 0;
219 for (auto expr = value_exprs.begin(); expr != value_exprs.end(); expr++) {
220 indent(out);
221 out << **expr << " as " << derived_table.columns()[i].name;
222 i++;
223 if (expr + 1 != value_exprs.end())
224 out << ", ";
225 }
226}
227
228void query_spec::out(std::ostream &out) {
229 out << "select " << set_quantifier << " " << *select_list;
230 indent(out);
231 out << *from_clause;
232 indent(out);
233 out << "where ";
234 out << *search;
235 if (limit_clause.length()) {
236 indent(out);
237 out << limit_clause;
238 }
239}
240
241struct for_update_verify : prod_visitor {
242 virtual void visit(prod *p) {
243 if (dynamic_cast<window_function *>(p))
244 throw("window function");
245 joined_table *join = dynamic_cast<joined_table *>(p);
246 if (join && join->type != "inner")
247 throw("outer join");
248 query_spec *subquery = dynamic_cast<query_spec *>(p);
249 if (subquery)
250 subquery->set_quantifier = "";
251 table_or_query_name *tab = dynamic_cast<table_or_query_name *>(p);
252 if (tab) {
253 table *actual_table = dynamic_cast<table *>(tab->t);
254 if (actual_table && !actual_table->is_insertable)
255 throw("read only");
256 if (actual_table->name.find("pg_"))
257 throw("catalog");
258 }
259 table_sample *sample = dynamic_cast<table_sample *>(p);
260 if (sample) {
261 table *actual_table = dynamic_cast<table *>(sample->t);
262 if (actual_table && !actual_table->is_insertable)
263 throw("read only");
264 if (actual_table->name.find("pg_"))
265 throw("catalog");
266 }
267 };
268};
269
270select_for_update::select_for_update(prod *p, struct scope *s, bool lateral) : query_spec(p, s, lateral) {
271 static const char *modes[] = {
272 "update",
273 "share",
274 "no key update",
275 "key share",
276 };
277
278 try {
279 for_update_verify v1;
280 this->accept(&v1);
281
282 } catch (const char *reason) {
283 lockmode = 0;
284 return;
285 }
286 lockmode = modes[d6() % (sizeof(modes) / sizeof(*modes))];
287 set_quantifier = ""; // disallow distinct
288}
289
290void select_for_update::out(std::ostream &out) {
291 query_spec::out(out);
292 if (lockmode) {
293 indent(out);
294 out << " for " << lockmode;
295 }
296}
297
298query_spec::query_spec(prod *p, struct scope *s, bool lateral) : prod(p), myscope(s) {
299 scope = &myscope;
300 scope->tables = s->tables;
301
302 if (lateral)
303 scope->refs = s->refs;
304
305 from_clause = make_shared<struct from_clause>(this);
306 select_list = make_shared<struct select_list>(this);
307
308 set_quantifier = (d100() == 1) ? "distinct" : "";
309
310 search = bool_expr::factory(this);
311
312 if (d6() > 2) {
313 ostringstream cons;
314 cons << "limit " << d100() + d100();
315 limit_clause = cons.str();
316 }
317}
318
// Out-of-class definition of the static data member prepare_stmt::seq.
// There is no explicit initializer, so as an object with static storage
// duration it starts out zero-initialized.
long prepare_stmt::seq;
320
321void modifying_stmt::pick_victim() {
322 do {
323 struct named_relation *pick = random_pick(scope->tables);
324 victim = dynamic_cast<struct table *>(pick);
325 retry();
326 } while (!victim || victim->schema == "pg_catalog" || !victim->is_base_table || !victim->columns().size());
327}
328
329modifying_stmt::modifying_stmt(prod *p, struct scope *s, table *victim) : prod(p), myscope(s) {
330 scope = &myscope;
331 scope->tables = s->tables;
332
333 if (!victim)
334 pick_victim();
335}
336
337delete_stmt::delete_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) {
338 scope->refs.push_back(victim);
339 search = bool_expr::factory(this);
340}
341
342delete_returning::delete_returning(prod *p, struct scope *s, table *victim) : delete_stmt(p, s, victim) {
343 match();
344 select_list = make_shared<struct select_list>(this);
345}
346
347insert_stmt::insert_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) {
348 match();
349
350 for (auto col : victim->columns()) {
351 auto expr = value_expr::factory(this, col.type);
352 assert(expr->type == col.type);
353 value_exprs.push_back(expr);
354 }
355}
356
357void insert_stmt::out(std::ostream &out) {
358 out << "insert into " << victim->ident() << " ";
359
360 if (!value_exprs.size()) {
361 out << "default values";
362 return;
363 }
364
365 out << "values (";
366
367 for (auto expr = value_exprs.begin(); expr != value_exprs.end(); expr++) {
368 indent(out);
369 out << **expr;
370 if (expr + 1 != value_exprs.end())
371 out << ", ";
372 }
373 out << ")";
374}
375
376set_list::set_list(prod *p, table *target) : prod(p) {
377 do {
378 for (auto col : target->columns()) {
379 if (d6() < 4)
380 continue;
381 auto expr = value_expr::factory(this, col.type);
382 value_exprs.push_back(expr);
383 names.push_back(col.name);
384 }
385 } while (!names.size());
386}
387
388void set_list::out(std::ostream &out) {
389 assert(names.size());
390 out << " set ";
391 for (size_t i = 0; i < names.size(); i++) {
392 indent(out);
393 out << names[i] << " = " << *value_exprs[i];
394 if (i + 1 != names.size())
395 out << ", ";
396 }
397}
398
399update_stmt::update_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) {
400 scope->refs.push_back(victim);
401 search = bool_expr::factory(this);
402 set_list = make_shared<struct set_list>(this, victim);
403}
404
405void update_stmt::out(std::ostream &out) {
406 out << "update " << victim->ident() << *set_list;
407}
408
409update_returning::update_returning(prod *p, struct scope *s, table *v) : update_stmt(p, s, v) {
410 match();
411
412 select_list = make_shared<struct select_list>(this);
413}
414
415upsert_stmt::upsert_stmt(prod *p, struct scope *s, table *v) : insert_stmt(p, s, v) {
416 match();
417
418 if (!victim->constraints.size())
419 fail("need table w/ constraint for upsert");
420
421 set_list = std::make_shared<struct set_list>(this, victim);
422 search = bool_expr::factory(this);
423 constraint = random_pick(victim->constraints);
424}
425
426shared_ptr<prod> statement_factory(struct scope *s) {
427 try {
428 s->new_stmt();
429 if (d42() == 1)
430 return make_shared<merge_stmt>((struct prod *)0, s);
431 if (d42() == 1)
432 return make_shared<insert_stmt>((struct prod *)0, s);
433 else if (d42() == 1)
434 return make_shared<delete_returning>((struct prod *)0, s);
435 else if (d42() == 1) {
436 return make_shared<upsert_stmt>((struct prod *)0, s);
437 } else if (d42() == 1)
438 return make_shared<update_returning>((struct prod *)0, s);
439 else if (d6() > 4)
440 return make_shared<select_for_update>((struct prod *)0, s);
441 else if (d6() > 5)
442 return make_shared<common_table_expression>((struct prod *)0, s);
443 return make_shared<query_spec>((struct prod *)0, s);
444 } catch (runtime_error &e) {
445 return statement_factory(s);
446 }
447}
448
449void common_table_expression::accept(prod_visitor *v) {
450 v->visit(this);
451 for (auto q : with_queries)
452 q->accept(v);
453 query->accept(v);
454}
455
456common_table_expression::common_table_expression(prod *parent, struct scope *s) : prod(parent), myscope(s) {
457 scope = &myscope;
458 do {
459 shared_ptr<query_spec> query = make_shared<query_spec>(this, s);
460 with_queries.push_back(query);
461 string alias = scope->stmt_uid("jennifer");
462 relation *relation = &query->select_list->derived_table;
463 auto aliased_rel = make_shared<aliased_relation>(alias, relation);
464 refs.push_back(aliased_rel);
465 scope->tables.push_back(&*aliased_rel);
466
467 } while (d6() > 2);
468
469retry:
470 do {
471 auto pick = random_pick(s->tables);
472 scope->tables.push_back(pick);
473 } while (d6() > 3);
474 try {
475 query = make_shared<query_spec>(this, scope);
476 } catch (runtime_error &e) {
477 retry();
478 goto retry;
479 }
480}
481
482void common_table_expression::out(std::ostream &out) {
483 out << "WITH ";
484 for (size_t i = 0; i < with_queries.size(); i++) {
485 indent(out);
486 out << refs[i]->ident() << " AS "
487 << "(" << *with_queries[i] << ")";
488 if (i + 1 != with_queries.size())
489 out << ", ";
490 indent(out);
491 }
492 out << *query;
493 indent(out);
494}
495
496merge_stmt::merge_stmt(prod *p, struct scope *s, table *v) : modifying_stmt(p, s, v) {
497 match();
498 target_table_ = make_shared<target_table>(this, victim);
499 data_source = table_ref::factory(this);
500 // join_condition = join_cond::factory(this, *target_table_,
501 // *data_source);
502 join_condition = make_shared<simple_join_cond>(this, *target_table_, *data_source);
503
504 /* Put data_source into scope but not target_table. Visibility of
505 the latter varies depending on kind of when clause. */
506 // for (auto r : data_source->refs)
507 // scope->refs.push_back(&*r);
508
509 clauselist.push_back(when_clause::factory(this));
510 while (d6() > 4)
511 clauselist.push_back(when_clause::factory(this));
512}
513
514void merge_stmt::out(std::ostream &out) {
515 out << "MERGE INTO " << *target_table_;
516 indent(out);
517 out << "USING " << *data_source;
518 indent(out);
519 out << "ON " << *join_condition;
520 indent(out);
521 for (auto p : clauselist) {
522 out << *p;
523 indent(out);
524 }
525}
526
527void merge_stmt::accept(prod_visitor *v) {
528 v->visit(this);
529 target_table_->accept(v);
530 data_source->accept(v);
531 join_condition->accept(v);
532 for (auto p : clauselist)
533 p->accept(v);
534}
535
536when_clause::when_clause(merge_stmt *p) : prod(p) {
537 condition = bool_expr::factory(this);
538 matched = d6() > 3;
539}
540
541void when_clause::out(std::ostream &out) {
542 out << (matched ? "WHEN MATCHED " : "WHEN NOT MATCHED");
543 indent(out);
544 out << "AND " << *condition;
545 indent(out);
546 out << " THEN ";
547 out << (matched ? "DELETE" : "DO NOTHING");
548}
549
550void when_clause::accept(prod_visitor *v) {
551 v->visit(this);
552 condition->accept(v);
553}
554
555when_clause_update::when_clause_update(merge_stmt *p) : when_clause(p), myscope(p->scope) {
556 myscope.tables = scope->tables;
557 myscope.refs = scope->refs;
558 scope = &myscope;
559 scope->refs.push_back(&*(p->target_table_->refs[0]));
560
561 set_list = std::make_shared<struct set_list>(this, p->victim);
562}
563
564void when_clause_update::out(std::ostream &out) {
565 out << "WHEN MATCHED AND " << *condition;
566 indent(out);
567 out << " THEN UPDATE " << *set_list;
568}
569
570void when_clause_update::accept(prod_visitor *v) {
571 v->visit(this);
572 set_list->accept(v);
573}
574
575when_clause_insert::when_clause_insert(struct merge_stmt *p) : when_clause(p) {
576 for (auto col : p->victim->columns()) {
577 auto expr = value_expr::factory(this, col.type);
578 assert(expr->type == col.type);
579 exprs.push_back(expr);
580 }
581}
582
583void when_clause_insert::out(std::ostream &out) {
584 out << "WHEN NOT MATCHED AND " << *condition;
585 indent(out);
586 out << " THEN INSERT VALUES ( ";
587
588 for (auto expr = exprs.begin(); expr != exprs.end(); expr++) {
589 out << **expr;
590 if (expr + 1 != exprs.end())
591 out << ", ";
592 }
593 out << ")";
594}
595
596void when_clause_insert::accept(prod_visitor *v) {
597 v->visit(this);
598 for (auto p : exprs)
599 p->accept(v);
600}
601
602shared_ptr<when_clause> when_clause::factory(struct merge_stmt *p) {
603 try {
604 switch (d6()) {
605 case 1:
606 case 2:
607 return make_shared<when_clause_insert>(p);
608 case 3:
609 case 4:
610 return make_shared<when_clause_update>(p);
611 default:
612 return make_shared<when_clause>(p);
613 }
614 } catch (runtime_error &e) {
615 p->retry();
616 }
617 return factory(p);
618}
619