1#include "duckdb/verification/prepared_statement_verifier.hpp"
2
3#include "duckdb/common/preserved_error.hpp"
4#include "duckdb/parser/expression/parameter_expression.hpp"
5#include "duckdb/parser/parsed_expression_iterator.hpp"
6#include "duckdb/parser/statement/drop_statement.hpp"
7#include "duckdb/parser/statement/execute_statement.hpp"
8#include "duckdb/parser/statement/prepare_statement.hpp"
9
10namespace duckdb {
11
12PreparedStatementVerifier::PreparedStatementVerifier(unique_ptr<SQLStatement> statement_p)
13 : StatementVerifier(VerificationType::PREPARED, "Prepared", std::move(statement_p)) {
14}
15
16unique_ptr<StatementVerifier> PreparedStatementVerifier::Create(const SQLStatement &statement) {
17 return make_uniq<PreparedStatementVerifier>(args: statement.Copy());
18}
19
20void PreparedStatementVerifier::Extract() {
21 auto &select = *statement;
22 // replace all the constants from the select statement and replace them with parameter expressions
23 ParsedExpressionIterator::EnumerateQueryNodeChildren(
24 node&: *select.node, callback: [&](unique_ptr<ParsedExpression> &child) { ConvertConstants(child); });
25 statement->n_param = values.size();
26 // create the PREPARE and EXECUTE statements
27 string name = "__duckdb_verification_prepared_statement";
28 auto prepare = make_uniq<PrepareStatement>();
29 prepare->name = name;
30 prepare->statement = std::move(statement);
31
32 auto execute = make_uniq<ExecuteStatement>();
33 execute->name = name;
34 execute->values = std::move(values);
35
36 auto dealloc = make_uniq<DropStatement>();
37 dealloc->info->type = CatalogType::PREPARED_STATEMENT;
38 dealloc->info->name = string(name);
39
40 prepare_statement = std::move(prepare);
41 execute_statement = std::move(execute);
42 dealloc_statement = std::move(dealloc);
43}
44
45void PreparedStatementVerifier::ConvertConstants(unique_ptr<ParsedExpression> &child) {
46 if (child->type == ExpressionType::VALUE_CONSTANT) {
47 // constant: extract the constant value
48 auto alias = child->alias;
49 child->alias = string();
50 // check if the value already exists
51 idx_t index = values.size();
52 for (idx_t v_idx = 0; v_idx < values.size(); v_idx++) {
53 if (values[v_idx]->Equals(other: *child)) {
54 // duplicate value! refer to the original value
55 index = v_idx;
56 break;
57 }
58 }
59 if (index == values.size()) {
60 values.push_back(x: std::move(child));
61 }
62 // replace it with an expression
63 auto parameter = make_uniq<ParameterExpression>();
64 parameter->parameter_nr = index + 1;
65 parameter->alias = alias;
66 child = std::move(parameter);
67 return;
68 }
69 ParsedExpressionIterator::EnumerateChildren(expr&: *child,
70 callback: [&](unique_ptr<ParsedExpression> &child) { ConvertConstants(child); });
71}
72
73bool PreparedStatementVerifier::Run(
74 ClientContext &context, const string &query,
75 const std::function<unique_ptr<QueryResult>(const string &, unique_ptr<SQLStatement>)> &run) {
76 bool failed = false;
77 // verify that we can extract all constants from the query and run the query as a prepared statement
78 // create the PREPARE and EXECUTE statements
79 Extract();
80 // execute the prepared statements
81 try {
82 auto prepare_result = run(string(), std::move(prepare_statement));
83 if (prepare_result->HasError()) {
84 prepare_result->ThrowError(prepended_message: "Failed prepare during verify: ");
85 }
86 auto execute_result = run(string(), std::move(execute_statement));
87 if (execute_result->HasError()) {
88 execute_result->ThrowError(prepended_message: "Failed execute during verify: ");
89 }
90 materialized_result = unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(execute_result));
91 } catch (const Exception &ex) {
92 if (ex.type != ExceptionType::PARAMETER_NOT_ALLOWED) {
93 materialized_result = make_uniq<MaterializedQueryResult>(args: PreservedError(ex));
94 }
95 failed = true;
96 } catch (std::exception &ex) {
97 materialized_result = make_uniq<MaterializedQueryResult>(args: PreservedError(ex));
98 failed = true;
99 }
100 run(string(), std::move(dealloc_statement));
101 context.interrupted = false;
102
103 return failed;
104}
105
106} // namespace duckdb
107