1 | #include "duckdb/verification/prepared_statement_verifier.hpp" |
2 | |
3 | #include "duckdb/common/preserved_error.hpp" |
4 | #include "duckdb/parser/expression/parameter_expression.hpp" |
5 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
6 | #include "duckdb/parser/statement/drop_statement.hpp" |
7 | #include "duckdb/parser/statement/execute_statement.hpp" |
8 | #include "duckdb/parser/statement/prepare_statement.hpp" |
9 | |
10 | namespace duckdb { |
11 | |
12 | PreparedStatementVerifier::PreparedStatementVerifier(unique_ptr<SQLStatement> statement_p) |
13 | : StatementVerifier(VerificationType::PREPARED, "Prepared" , std::move(statement_p)) { |
14 | } |
15 | |
16 | unique_ptr<StatementVerifier> PreparedStatementVerifier::Create(const SQLStatement &statement) { |
17 | return make_uniq<PreparedStatementVerifier>(args: statement.Copy()); |
18 | } |
19 | |
20 | void PreparedStatementVerifier::() { |
21 | auto &select = *statement; |
22 | // replace all the constants from the select statement and replace them with parameter expressions |
23 | ParsedExpressionIterator::EnumerateQueryNodeChildren( |
24 | node&: *select.node, callback: [&](unique_ptr<ParsedExpression> &child) { ConvertConstants(child); }); |
25 | statement->n_param = values.size(); |
26 | // create the PREPARE and EXECUTE statements |
27 | string name = "__duckdb_verification_prepared_statement" ; |
28 | auto prepare = make_uniq<PrepareStatement>(); |
29 | prepare->name = name; |
30 | prepare->statement = std::move(statement); |
31 | |
32 | auto execute = make_uniq<ExecuteStatement>(); |
33 | execute->name = name; |
34 | execute->values = std::move(values); |
35 | |
36 | auto dealloc = make_uniq<DropStatement>(); |
37 | dealloc->info->type = CatalogType::PREPARED_STATEMENT; |
38 | dealloc->info->name = string(name); |
39 | |
40 | prepare_statement = std::move(prepare); |
41 | execute_statement = std::move(execute); |
42 | dealloc_statement = std::move(dealloc); |
43 | } |
44 | |
45 | void PreparedStatementVerifier::ConvertConstants(unique_ptr<ParsedExpression> &child) { |
46 | if (child->type == ExpressionType::VALUE_CONSTANT) { |
47 | // constant: extract the constant value |
48 | auto alias = child->alias; |
49 | child->alias = string(); |
50 | // check if the value already exists |
51 | idx_t index = values.size(); |
52 | for (idx_t v_idx = 0; v_idx < values.size(); v_idx++) { |
53 | if (values[v_idx]->Equals(other: *child)) { |
54 | // duplicate value! refer to the original value |
55 | index = v_idx; |
56 | break; |
57 | } |
58 | } |
59 | if (index == values.size()) { |
60 | values.push_back(x: std::move(child)); |
61 | } |
62 | // replace it with an expression |
63 | auto parameter = make_uniq<ParameterExpression>(); |
64 | parameter->parameter_nr = index + 1; |
65 | parameter->alias = alias; |
66 | child = std::move(parameter); |
67 | return; |
68 | } |
69 | ParsedExpressionIterator::EnumerateChildren(expr&: *child, |
70 | callback: [&](unique_ptr<ParsedExpression> &child) { ConvertConstants(child); }); |
71 | } |
72 | |
73 | bool PreparedStatementVerifier::Run( |
74 | ClientContext &context, const string &query, |
75 | const std::function<unique_ptr<QueryResult>(const string &, unique_ptr<SQLStatement>)> &run) { |
76 | bool failed = false; |
77 | // verify that we can extract all constants from the query and run the query as a prepared statement |
78 | // create the PREPARE and EXECUTE statements |
79 | Extract(); |
80 | // execute the prepared statements |
81 | try { |
82 | auto prepare_result = run(string(), std::move(prepare_statement)); |
83 | if (prepare_result->HasError()) { |
84 | prepare_result->ThrowError(prepended_message: "Failed prepare during verify: " ); |
85 | } |
86 | auto execute_result = run(string(), std::move(execute_statement)); |
87 | if (execute_result->HasError()) { |
88 | execute_result->ThrowError(prepended_message: "Failed execute during verify: " ); |
89 | } |
90 | materialized_result = unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(execute_result)); |
91 | } catch (const Exception &ex) { |
92 | if (ex.type != ExceptionType::PARAMETER_NOT_ALLOWED) { |
93 | materialized_result = make_uniq<MaterializedQueryResult>(args: PreservedError(ex)); |
94 | } |
95 | failed = true; |
96 | } catch (std::exception &ex) { |
97 | materialized_result = make_uniq<MaterializedQueryResult>(args: PreservedError(ex)); |
98 | failed = true; |
99 | } |
100 | run(string(), std::move(dealloc_statement)); |
101 | context.interrupted = false; |
102 | |
103 | return failed; |
104 | } |
105 | |
106 | } // namespace duckdb |
107 | |