| 1 | #include "duckdb/verification/prepared_statement_verifier.hpp" |
| 2 | |
| 3 | #include "duckdb/common/preserved_error.hpp" |
| 4 | #include "duckdb/parser/expression/parameter_expression.hpp" |
| 5 | #include "duckdb/parser/parsed_expression_iterator.hpp" |
| 6 | #include "duckdb/parser/statement/drop_statement.hpp" |
| 7 | #include "duckdb/parser/statement/execute_statement.hpp" |
| 8 | #include "duckdb/parser/statement/prepare_statement.hpp" |
| 9 | |
| 10 | namespace duckdb { |
| 11 | |
| 12 | PreparedStatementVerifier::PreparedStatementVerifier(unique_ptr<SQLStatement> statement_p) |
| 13 | : StatementVerifier(VerificationType::PREPARED, "Prepared" , std::move(statement_p)) { |
| 14 | } |
| 15 | |
| 16 | unique_ptr<StatementVerifier> PreparedStatementVerifier::Create(const SQLStatement &statement) { |
| 17 | return make_uniq<PreparedStatementVerifier>(args: statement.Copy()); |
| 18 | } |
| 19 | |
| 20 | void PreparedStatementVerifier::() { |
| 21 | auto &select = *statement; |
| 22 | // replace all the constants from the select statement and replace them with parameter expressions |
| 23 | ParsedExpressionIterator::EnumerateQueryNodeChildren( |
| 24 | node&: *select.node, callback: [&](unique_ptr<ParsedExpression> &child) { ConvertConstants(child); }); |
| 25 | statement->n_param = values.size(); |
| 26 | // create the PREPARE and EXECUTE statements |
| 27 | string name = "__duckdb_verification_prepared_statement" ; |
| 28 | auto prepare = make_uniq<PrepareStatement>(); |
| 29 | prepare->name = name; |
| 30 | prepare->statement = std::move(statement); |
| 31 | |
| 32 | auto execute = make_uniq<ExecuteStatement>(); |
| 33 | execute->name = name; |
| 34 | execute->values = std::move(values); |
| 35 | |
| 36 | auto dealloc = make_uniq<DropStatement>(); |
| 37 | dealloc->info->type = CatalogType::PREPARED_STATEMENT; |
| 38 | dealloc->info->name = string(name); |
| 39 | |
| 40 | prepare_statement = std::move(prepare); |
| 41 | execute_statement = std::move(execute); |
| 42 | dealloc_statement = std::move(dealloc); |
| 43 | } |
| 44 | |
| 45 | void PreparedStatementVerifier::ConvertConstants(unique_ptr<ParsedExpression> &child) { |
| 46 | if (child->type == ExpressionType::VALUE_CONSTANT) { |
| 47 | // constant: extract the constant value |
| 48 | auto alias = child->alias; |
| 49 | child->alias = string(); |
| 50 | // check if the value already exists |
| 51 | idx_t index = values.size(); |
| 52 | for (idx_t v_idx = 0; v_idx < values.size(); v_idx++) { |
| 53 | if (values[v_idx]->Equals(other: *child)) { |
| 54 | // duplicate value! refer to the original value |
| 55 | index = v_idx; |
| 56 | break; |
| 57 | } |
| 58 | } |
| 59 | if (index == values.size()) { |
| 60 | values.push_back(x: std::move(child)); |
| 61 | } |
| 62 | // replace it with an expression |
| 63 | auto parameter = make_uniq<ParameterExpression>(); |
| 64 | parameter->parameter_nr = index + 1; |
| 65 | parameter->alias = alias; |
| 66 | child = std::move(parameter); |
| 67 | return; |
| 68 | } |
| 69 | ParsedExpressionIterator::EnumerateChildren(expr&: *child, |
| 70 | callback: [&](unique_ptr<ParsedExpression> &child) { ConvertConstants(child); }); |
| 71 | } |
| 72 | |
| 73 | bool PreparedStatementVerifier::Run( |
| 74 | ClientContext &context, const string &query, |
| 75 | const std::function<unique_ptr<QueryResult>(const string &, unique_ptr<SQLStatement>)> &run) { |
| 76 | bool failed = false; |
| 77 | // verify that we can extract all constants from the query and run the query as a prepared statement |
| 78 | // create the PREPARE and EXECUTE statements |
| 79 | Extract(); |
| 80 | // execute the prepared statements |
| 81 | try { |
| 82 | auto prepare_result = run(string(), std::move(prepare_statement)); |
| 83 | if (prepare_result->HasError()) { |
| 84 | prepare_result->ThrowError(prepended_message: "Failed prepare during verify: " ); |
| 85 | } |
| 86 | auto execute_result = run(string(), std::move(execute_statement)); |
| 87 | if (execute_result->HasError()) { |
| 88 | execute_result->ThrowError(prepended_message: "Failed execute during verify: " ); |
| 89 | } |
| 90 | materialized_result = unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(execute_result)); |
| 91 | } catch (const Exception &ex) { |
| 92 | if (ex.type != ExceptionType::PARAMETER_NOT_ALLOWED) { |
| 93 | materialized_result = make_uniq<MaterializedQueryResult>(args: PreservedError(ex)); |
| 94 | } |
| 95 | failed = true; |
| 96 | } catch (std::exception &ex) { |
| 97 | materialized_result = make_uniq<MaterializedQueryResult>(args: PreservedError(ex)); |
| 98 | failed = true; |
| 99 | } |
| 100 | run(string(), std::move(dealloc_statement)); |
| 101 | context.interrupted = false; |
| 102 | |
| 103 | return failed; |
| 104 | } |
| 105 | |
| 106 | } // namespace duckdb |
| 107 | |