1#include "catch.hpp"
2#include "duckdb/common/helper.hpp"
3#include "expression_helper.hpp"
4#include "duckdb/optimizer/cse_optimizer.hpp"
5#include "duckdb/planner/expression/bound_comparison_expression.hpp"
6#include "duckdb/planner/expression/bound_function_expression.hpp"
7#include "duckdb/planner/expression/common_subexpression.hpp"
8#include "test_helpers.hpp"
9
10using namespace duckdb;
11using namespace std;
12
13TEST_CASE("Test CSE Optimizer", "[optimizer]") {
14 ExpressionHelper helper;
15 auto &con = helper.con;
16
17 REQUIRE_NO_FAIL(con.Query("CREATE TABLE integers(i INTEGER)"));
18
19 CommonSubExpressionOptimizer optimizer;
20
21 // simple CSE
22 auto tree = helper.ParseLogicalTree("SELECT i+1, i+1 FROM integers");
23 optimizer.VisitOperator(*tree);
24 REQUIRE(tree->type == LogicalOperatorType::PROJECTION);
25 REQUIRE(tree->expressions[0]->type == ExpressionType::COMMON_SUBEXPRESSION);
26 REQUIRE(tree->expressions[1]->type == ExpressionType::COMMON_SUBEXPRESSION);
27
28 // more complex CSE
29 tree = helper.ParseLogicalTree("SELECT i+(i+1), i+1 FROM integers");
30 optimizer.VisitOperator(*tree);
31 REQUIRE(tree->type == LogicalOperatorType::PROJECTION);
32 REQUIRE(tree->expressions[1]->type == ExpressionType::COMMON_SUBEXPRESSION);
33
34 // more CSEs
35 tree = helper.ParseLogicalTree("SELECT i*2, i+1, i*2, i+1, (i+1)+(i*2) FROM integers");
36 optimizer.VisitOperator(*tree);
37 REQUIRE(tree->type == LogicalOperatorType::PROJECTION);
38 REQUIRE(tree->expressions[0]->type == ExpressionType::COMMON_SUBEXPRESSION);
39 REQUIRE(tree->expressions[1]->type == ExpressionType::COMMON_SUBEXPRESSION);
40 REQUIRE(tree->expressions[2]->type == ExpressionType::COMMON_SUBEXPRESSION);
41 REQUIRE(tree->expressions[3]->type == ExpressionType::COMMON_SUBEXPRESSION);
42 REQUIRE(tree->expressions[4]->type == ExpressionType::BOUND_FUNCTION);
43 auto &op = (BoundFunctionExpression &)*tree->expressions[4];
44 REQUIRE(op.children[0]->type == ExpressionType::COMMON_SUBEXPRESSION);
45 REQUIRE(op.children[1]->type == ExpressionType::COMMON_SUBEXPRESSION);
46
47 // test CSEs in WHERE clause
48 tree = helper.ParseLogicalTree("SELECT i FROM integers WHERE i+1>10 AND i+1<20");
49 optimizer.VisitOperator(*tree);
50 REQUIRE(tree->type == LogicalOperatorType::PROJECTION);
51 REQUIRE(tree->children[0]->type == LogicalOperatorType::FILTER);
52 auto &filter = *tree->children[0];
53 REQUIRE(filter.expressions[0]->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON);
54 REQUIRE(filter.expressions[1]->GetExpressionClass() == ExpressionClass::BOUND_COMPARISON);
55 auto &lcomp = (BoundComparisonExpression &)*filter.expressions[0];
56 auto &rcomp = (BoundComparisonExpression &)*filter.expressions[1];
57 REQUIRE(lcomp.left->type == ExpressionType::COMMON_SUBEXPRESSION);
58 REQUIRE(rcomp.left->type == ExpressionType::COMMON_SUBEXPRESSION);
59}
60
61TEST_CASE("CSE NULL*MIN(42) defense", "[optimizer]") {
62 DuckDB db(nullptr);
63 Connection con(db);
64 con.EnableQueryVerification();
65
66 auto result = con.Query("SELECT NULL * MIN(42);");
67 REQUIRE(result->success);
68 REQUIRE(CHECK_COLUMN(result, 0, {Value()}));
69}
70