1#include "duckdb/main/connection.hpp"
2
3#include "duckdb/common/types/column/column_data_collection.hpp"
4#include "duckdb/execution/operator/persistent/parallel_csv_reader.hpp"
5#include "duckdb/function/table/read_csv.hpp"
6#include "duckdb/main/appender.hpp"
7#include "duckdb/main/client_context.hpp"
8#include "duckdb/main/connection_manager.hpp"
9#include "duckdb/main/database.hpp"
10#include "duckdb/main/query_profiler.hpp"
11#include "duckdb/main/relation/query_relation.hpp"
12#include "duckdb/main/relation/read_csv_relation.hpp"
13#include "duckdb/main/relation/table_function_relation.hpp"
14#include "duckdb/main/relation/table_relation.hpp"
15#include "duckdb/main/relation/value_relation.hpp"
16#include "duckdb/main/relation/view_relation.hpp"
17#include "duckdb/parser/parser.hpp"
18#include "duckdb/planner/logical_operator.hpp"
19
20namespace duckdb {
21
22Connection::Connection(DatabaseInstance &database) : context(make_shared<ClientContext>(args: database.shared_from_this())) {
23 ConnectionManager::Get(db&: database).AddConnection(context&: *context);
24#ifdef DEBUG
25 EnableProfiling();
26 context->config.emit_profiler_output = false;
27#endif
28}
29
30Connection::Connection(DuckDB &database) : Connection(*database.instance) {
31}
32
33Connection::~Connection() {
34 ConnectionManager::Get(db&: *context->db).RemoveConnection(context&: *context);
35}
36
37string Connection::GetProfilingInformation(ProfilerPrintFormat format) {
38 auto &profiler = QueryProfiler::Get(context&: *context);
39 if (format == ProfilerPrintFormat::JSON) {
40 return profiler.ToJSON();
41 } else {
42 return profiler.QueryTreeToString();
43 }
44}
45
46void Connection::Interrupt() {
47 context->Interrupt();
48}
49
50void Connection::EnableProfiling() {
51 context->EnableProfiling();
52}
53
54void Connection::DisableProfiling() {
55 context->DisableProfiling();
56}
57
58void Connection::EnableQueryVerification() {
59 ClientConfig::GetConfig(context&: *context).query_verification_enabled = true;
60}
61
62void Connection::DisableQueryVerification() {
63 ClientConfig::GetConfig(context&: *context).query_verification_enabled = false;
64}
65
66void Connection::ForceParallelism() {
67 ClientConfig::GetConfig(context&: *context).verify_parallelism = true;
68}
69
70unique_ptr<QueryResult> Connection::SendQuery(const string &query) {
71 return context->Query(query, allow_stream_result: true);
72}
73
74unique_ptr<MaterializedQueryResult> Connection::Query(const string &query) {
75 auto result = context->Query(query, allow_stream_result: false);
76 D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT);
77 return unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(result));
78}
79
80DUCKDB_API string Connection::GetSubstrait(const string &query) {
81 vector<Value> params;
82 params.emplace_back(args: query);
83 auto result = TableFunction(tname: "get_substrait", values: params)->Execute();
84 auto protobuf = result->FetchRaw()->GetValue(col_idx: 0, index: 0);
85 return protobuf.GetValueUnsafe<string_t>().GetString();
86}
87
88DUCKDB_API unique_ptr<QueryResult> Connection::FromSubstrait(const string &proto) {
89 vector<Value> params;
90 params.emplace_back(args: Value::BLOB_RAW(data: proto));
91 return TableFunction(tname: "from_substrait", values: params)->Execute();
92}
93
94DUCKDB_API string Connection::GetSubstraitJSON(const string &query) {
95 vector<Value> params;
96 params.emplace_back(args: query);
97 auto result = TableFunction(tname: "get_substrait_json", values: params)->Execute();
98 auto protobuf = result->FetchRaw()->GetValue(col_idx: 0, index: 0);
99 return protobuf.GetValueUnsafe<string_t>().GetString();
100}
101
102DUCKDB_API unique_ptr<QueryResult> Connection::FromSubstraitJSON(const string &json) {
103 vector<Value> params;
104 params.emplace_back(args: json);
105 return TableFunction(tname: "from_substrait_json", values: params)->Execute();
106}
107
108unique_ptr<MaterializedQueryResult> Connection::Query(unique_ptr<SQLStatement> statement) {
109 auto result = context->Query(statement: std::move(statement), allow_stream_result: false);
110 D_ASSERT(result->type == QueryResultType::MATERIALIZED_RESULT);
111 return unique_ptr_cast<QueryResult, MaterializedQueryResult>(src: std::move(result));
112}
113
114unique_ptr<PendingQueryResult> Connection::PendingQuery(const string &query, bool allow_stream_result) {
115 return context->PendingQuery(query, allow_stream_result);
116}
117
118unique_ptr<PendingQueryResult> Connection::PendingQuery(unique_ptr<SQLStatement> statement, bool allow_stream_result) {
119 return context->PendingQuery(statement: std::move(statement), allow_stream_result);
120}
121
122unique_ptr<PreparedStatement> Connection::Prepare(const string &query) {
123 return context->Prepare(query);
124}
125
126unique_ptr<PreparedStatement> Connection::Prepare(unique_ptr<SQLStatement> statement) {
127 return context->Prepare(statement: std::move(statement));
128}
129
130unique_ptr<QueryResult> Connection::QueryParamsRecursive(const string &query, vector<Value> &values) {
131 auto statement = Prepare(query);
132 if (statement->HasError()) {
133 return make_uniq<MaterializedQueryResult>(args&: statement->error);
134 }
135 return statement->Execute(values, allow_stream_result: false);
136}
137
138unique_ptr<TableDescription> Connection::TableInfo(const string &table_name) {
139 return TableInfo(INVALID_SCHEMA, table_name);
140}
141
142unique_ptr<TableDescription> Connection::TableInfo(const string &schema_name, const string &table_name) {
143 return context->TableInfo(schema_name, table_name);
144}
145
146vector<unique_ptr<SQLStatement>> Connection::ExtractStatements(const string &query) {
147 return context->ParseStatements(query);
148}
149
150unique_ptr<LogicalOperator> Connection::ExtractPlan(const string &query) {
151 return context->ExtractPlan(query);
152}
153
154void Connection::Append(TableDescription &description, DataChunk &chunk) {
155 if (chunk.size() == 0) {
156 return;
157 }
158 ColumnDataCollection collection(Allocator::Get(context&: *context), chunk.GetTypes());
159 collection.Append(new_chunk&: chunk);
160 Append(description, collection);
161}
162
163void Connection::Append(TableDescription &description, ColumnDataCollection &collection) {
164 context->Append(description, collection);
165}
166
167shared_ptr<Relation> Connection::Table(const string &table_name) {
168 return Table(DEFAULT_SCHEMA, table_name);
169}
170
171shared_ptr<Relation> Connection::Table(const string &schema_name, const string &table_name) {
172 auto table_info = TableInfo(schema_name, table_name);
173 if (!table_info) {
174 throw CatalogException("Table '%s' does not exist!", table_name);
175 }
176 return make_shared<TableRelation>(args&: context, args: std::move(table_info));
177}
178
179shared_ptr<Relation> Connection::View(const string &tname) {
180 return View(DEFAULT_SCHEMA, table_name: tname);
181}
182
183shared_ptr<Relation> Connection::View(const string &schema_name, const string &table_name) {
184 return make_shared<ViewRelation>(args&: context, args: schema_name, args: table_name);
185}
186
187shared_ptr<Relation> Connection::TableFunction(const string &fname) {
188 vector<Value> values;
189 named_parameter_map_t named_parameters;
190 return TableFunction(tname: fname, values, named_parameters);
191}
192
193shared_ptr<Relation> Connection::TableFunction(const string &fname, const vector<Value> &values,
194 const named_parameter_map_t &named_parameters) {
195 return make_shared<TableFunctionRelation>(args&: context, args: fname, args: values, args: named_parameters);
196}
197
198shared_ptr<Relation> Connection::TableFunction(const string &fname, const vector<Value> &values) {
199 return make_shared<TableFunctionRelation>(args&: context, args: fname, args: values);
200}
201
202shared_ptr<Relation> Connection::Values(const vector<vector<Value>> &values) {
203 vector<string> column_names;
204 return Values(values, column_names);
205}
206
207shared_ptr<Relation> Connection::Values(const vector<vector<Value>> &values, const vector<string> &column_names,
208 const string &alias) {
209 return make_shared<ValueRelation>(args&: context, args: values, args: column_names, args: alias);
210}
211
212shared_ptr<Relation> Connection::Values(const string &values) {
213 vector<string> column_names;
214 return Values(values, column_names);
215}
216
217shared_ptr<Relation> Connection::Values(const string &values, const vector<string> &column_names, const string &alias) {
218 return make_shared<ValueRelation>(args&: context, args: values, args: column_names, args: alias);
219}
220
221shared_ptr<Relation> Connection::ReadCSV(const string &csv_file) {
222 BufferedCSVReaderOptions options;
223 return ReadCSV(csv_file, options);
224}
225
226shared_ptr<Relation> Connection::ReadCSV(const string &csv_file, BufferedCSVReaderOptions &options) {
227 options.file_path = csv_file;
228 options.auto_detect = true;
229 return make_shared<ReadCSVRelation>(args&: context, args: csv_file, args&: options);
230}
231
232shared_ptr<Relation> Connection::ReadCSV(const string &csv_file, const vector<string> &columns) {
233 // parse columns
234 vector<ColumnDefinition> column_list;
235 for (auto &column : columns) {
236 auto col_list = Parser::ParseColumnList(column_list: column, options: context->GetParserOptions());
237 if (col_list.LogicalColumnCount() != 1) {
238 throw ParserException("Expected a single column definition");
239 }
240 column_list.push_back(x: std::move(col_list.GetColumnMutable(index: LogicalIndex(0))));
241 }
242 return make_shared<ReadCSVRelation>(args&: context, args: csv_file, args: std::move(column_list));
243}
244
245shared_ptr<Relation> Connection::ReadParquet(const string &parquet_file, bool binary_as_string) {
246 vector<Value> params;
247 params.emplace_back(args: parquet_file);
248 named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(value: binary_as_string)}});
249 return TableFunction(fname: "parquet_scan", values: params, named_parameters)->Alias(alias: parquet_file);
250}
251
252unordered_set<string> Connection::GetTableNames(const string &query) {
253 return context->GetTableNames(query);
254}
255
256shared_ptr<Relation> Connection::RelationFromQuery(const string &query, const string &alias, const string &error) {
257 return RelationFromQuery(select_stmt: QueryRelation::ParseStatement(context&: *context, query, error), alias);
258}
259
260shared_ptr<Relation> Connection::RelationFromQuery(unique_ptr<SelectStatement> select_stmt, const string &alias) {
261 return make_shared<QueryRelation>(args&: context, args: std::move(select_stmt), args: alias);
262}
263
264void Connection::BeginTransaction() {
265 auto result = Query(query: "BEGIN TRANSACTION");
266 if (result->HasError()) {
267 result->ThrowError();
268 }
269}
270
271void Connection::Commit() {
272 auto result = Query(query: "COMMIT");
273 if (result->HasError()) {
274 result->ThrowError();
275 }
276}
277
278void Connection::Rollback() {
279 auto result = Query(query: "ROLLBACK");
280 if (result->HasError()) {
281 result->ThrowError();
282 }
283}
284
285void Connection::SetAutoCommit(bool auto_commit) {
286 context->transaction.SetAutoCommit(auto_commit);
287}
288
289bool Connection::IsAutoCommit() {
290 return context->transaction.IsAutoCommit();
291}
292bool Connection::HasActiveTransaction() {
293 return context->transaction.HasActiveTransaction();
294}
295
296} // namespace duckdb
297