1#include "expression_helper.hpp"
2
3#include "duckdb.hpp"
4#include "duckdb/optimizer/rule/constant_folding.hpp"
5#include "duckdb/parser/parser.hpp"
6#include "duckdb/planner/binder.hpp"
7#include "duckdb/planner/bound_query_node.hpp"
8#include "duckdb/planner/expression_iterator.hpp"
9#include "duckdb/planner/operator/logical_projection.hpp"
10#include "duckdb/planner/planner.hpp"
11#include "duckdb/planner/query_node/bound_select_node.hpp"
12
13using namespace duckdb;
14using namespace std;
15
16ExpressionHelper::ExpressionHelper() : db(nullptr), con(db), rewriter(*con.context) {
17 con.Query("BEGIN TRANSACTION");
18}
19
20bool ExpressionHelper::VerifyRewrite(string input, string expected_output, bool silent) {
21 auto root = ParseExpression(input);
22 auto result = ApplyExpressionRule(move(root));
23 auto expected_result = ParseExpression(expected_output);
24 bool equals = Expression::Equals(result.get(), expected_result.get());
25 if (!equals && !silent) {
26 printf("Optimized result does not equal expected result!\n");
27 result->Print();
28 printf("Expected:\n");
29 expected_result->Print();
30 }
31 return equals;
32}
33
34string ExpressionHelper::AddColumns(string columns) {
35 if (!from_clause.empty()) {
36 con.Query("DROP TABLE expression_helper");
37 }
38 auto result = con.Query("CREATE TABLE expression_helper(" + columns + ")");
39 if (!result->success) {
40 return result->error;
41 }
42 from_clause = " FROM expression_helper";
43 return string();
44}
45
46unique_ptr<Expression> ExpressionHelper::ParseExpression(string expression) {
47 string query = "SELECT " + expression + from_clause;
48
49 Parser parser;
50 parser.ParseQuery(query.c_str());
51 if (parser.statements.size() == 0 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) {
52 return nullptr;
53 }
54 Binder binder(*con.context);
55 auto bound_statement = binder.Bind(*parser.statements[0]);
56 assert(bound_statement.plan->type == LogicalOperatorType::PROJECTION);
57 return move(bound_statement.plan->expressions[0]);
58}
59
60unique_ptr<LogicalOperator> ExpressionHelper::ParseLogicalTree(string query) {
61 Parser parser;
62 parser.ParseQuery(query.c_str());
63 if (parser.statements.size() == 0 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) {
64 return nullptr;
65 }
66 Planner planner(*con.context);
67 planner.CreatePlan(move(parser.statements[0]));
68 return move(planner.plan);
69}
70
71unique_ptr<Expression> ExpressionHelper::ApplyExpressionRule(unique_ptr<Expression> root) {
72 // make a logical projection
73 vector<unique_ptr<Expression>> expressions;
74 expressions.push_back(move(root));
75 auto proj = make_unique<LogicalProjection>(0, move(expressions));
76 rewriter.Apply(*proj);
77 return move(proj->expressions[0]);
78}
79