1#include "duckdb/function/table/table_scan.hpp"
2
3#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
4#include "duckdb/common/field_writer.hpp"
5#include "duckdb/common/mutex.hpp"
6#include "duckdb/main/client_config.hpp"
7#include "duckdb/optimizer/matcher/expression_matcher.hpp"
8#include "duckdb/planner/expression/bound_between_expression.hpp"
9#include "duckdb/planner/expression_iterator.hpp"
10#include "duckdb/planner/operator/logical_get.hpp"
11#include "duckdb/storage/data_table.hpp"
12#include "duckdb/transaction/local_storage.hpp"
13#include "duckdb/transaction/duck_transaction.hpp"
14#include "duckdb/main/attached_database.hpp"
15#include "duckdb/catalog/dependency_list.hpp"
16#include "duckdb/function/function_set.hpp"
17#include "duckdb/storage/table/scan_state.hpp"
18
19namespace duckdb {
20
21//===--------------------------------------------------------------------===//
22// Table Scan
23//===--------------------------------------------------------------------===//
24bool TableScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p,
25 LocalTableFunctionState *local_state, GlobalTableFunctionState *gstate);
26
27struct TableScanLocalState : public LocalTableFunctionState {
28 //! The current position in the scan
29 TableScanState scan_state;
30 //! The DataChunk containing all read columns (even filter columns that are immediately removed)
31 DataChunk all_columns;
32};
33
34static storage_t GetStorageIndex(TableCatalogEntry &table, column_t column_id) {
35 if (column_id == DConstants::INVALID_INDEX) {
36 return column_id;
37 }
38 auto &col = table.GetColumn(idx: LogicalIndex(column_id));
39 return col.StorageOid();
40}
41
42struct TableScanGlobalState : public GlobalTableFunctionState {
43 TableScanGlobalState(ClientContext &context, const FunctionData *bind_data_p) {
44 D_ASSERT(bind_data_p);
45 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
46 max_threads = bind_data.table.GetStorage().MaxThreads(context);
47 }
48
49 ParallelTableScanState state;
50 idx_t max_threads;
51
52 vector<idx_t> projection_ids;
53 vector<LogicalType> scanned_types;
54
55 idx_t MaxThreads() const override {
56 return max_threads;
57 }
58
59 bool CanRemoveFilterColumns() const {
60 return !projection_ids.empty();
61 }
62};
63
64static unique_ptr<LocalTableFunctionState> TableScanInitLocal(ExecutionContext &context, TableFunctionInitInput &input,
65 GlobalTableFunctionState *gstate) {
66 auto result = make_uniq<TableScanLocalState>();
67 auto &bind_data = input.bind_data->Cast<TableScanBindData>();
68 vector<column_t> column_ids = input.column_ids;
69 for (auto &col : column_ids) {
70 auto storage_idx = GetStorageIndex(table&: bind_data.table, column_id: col);
71 col = storage_idx;
72 }
73 result->scan_state.Initialize(column_ids: std::move(column_ids), table_filters: input.filters.get());
74 TableScanParallelStateNext(context&: context.client, bind_data_p: input.bind_data.get(), local_state: result.get(), gstate);
75 if (input.CanRemoveFilterColumns()) {
76 auto &tsgs = gstate->Cast<TableScanGlobalState>();
77 result->all_columns.Initialize(context&: context.client, types: tsgs.scanned_types);
78 }
79 return std::move(result);
80}
81
82unique_ptr<GlobalTableFunctionState> TableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) {
83
84 D_ASSERT(input.bind_data);
85 auto &bind_data = input.bind_data->Cast<TableScanBindData>();
86 auto result = make_uniq<TableScanGlobalState>(args&: context, args: input.bind_data.get());
87 bind_data.table.GetStorage().InitializeParallelScan(context, state&: result->state);
88 if (input.CanRemoveFilterColumns()) {
89 result->projection_ids = input.projection_ids;
90 const auto &columns = bind_data.table.GetColumns();
91 for (const auto &col_idx : input.column_ids) {
92 if (col_idx == COLUMN_IDENTIFIER_ROW_ID) {
93 result->scanned_types.emplace_back(args: LogicalType::ROW_TYPE);
94 } else {
95 result->scanned_types.push_back(x: columns.GetColumn(index: LogicalIndex(col_idx)).Type());
96 }
97 }
98 }
99 return std::move(result);
100}
101
102static unique_ptr<BaseStatistics> TableScanStatistics(ClientContext &context, const FunctionData *bind_data_p,
103 column_t column_id) {
104 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
105 auto &local_storage = LocalStorage::Get(context, catalog&: bind_data.table.catalog);
106 if (local_storage.Find(table&: bind_data.table.GetStorage())) {
107 // we don't emit any statistics for tables that have outstanding transaction-local data
108 return nullptr;
109 }
110 return bind_data.table.GetStatistics(context, column_id);
111}
112
113static void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
114 auto &bind_data = data_p.bind_data->Cast<TableScanBindData>();
115 auto &gstate = data_p.global_state->Cast<TableScanGlobalState>();
116 auto &state = data_p.local_state->Cast<TableScanLocalState>();
117 auto &transaction = DuckTransaction::Get(context, catalog&: bind_data.table.catalog);
118 auto &storage = bind_data.table.GetStorage();
119 do {
120 if (bind_data.is_create_index) {
121 storage.CreateIndexScan(state&: state.scan_state, result&: output,
122 type: TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED);
123 } else if (gstate.CanRemoveFilterColumns()) {
124 state.all_columns.Reset();
125 storage.Scan(transaction, result&: state.all_columns, state&: state.scan_state);
126 output.ReferenceColumns(other&: state.all_columns, column_ids: gstate.projection_ids);
127 } else {
128 storage.Scan(transaction, result&: output, state&: state.scan_state);
129 }
130 if (output.size() > 0) {
131 return;
132 }
133 if (!TableScanParallelStateNext(context, bind_data_p: data_p.bind_data.get(), local_state: data_p.local_state.get(),
134 gstate: data_p.global_state.get())) {
135 return;
136 }
137 } while (true);
138}
139
140bool TableScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p,
141 LocalTableFunctionState *local_state, GlobalTableFunctionState *global_state) {
142 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
143 auto &parallel_state = global_state->Cast<TableScanGlobalState>();
144 auto &state = local_state->Cast<TableScanLocalState>();
145 auto &storage = bind_data.table.GetStorage();
146
147 return storage.NextParallelScan(context, state&: parallel_state.state, scan_state&: state.scan_state);
148}
149
150double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p,
151 const GlobalTableFunctionState *gstate_p) {
152 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
153 auto &gstate = gstate_p->Cast<TableScanGlobalState>();
154 auto &storage = bind_data.table.GetStorage();
155 idx_t total_rows = storage.GetTotalRows();
156 if (total_rows == 0) {
157 //! Table is either empty or smaller than a vector size, so it is finished
158 return 100;
159 }
160 idx_t scanned_rows = gstate.state.scan_state.processed_rows;
161 scanned_rows += gstate.state.local_state.processed_rows;
162 auto percentage = 100 * (double(scanned_rows) / total_rows);
163 if (percentage > 100) {
164 //! In case the last chunk has less elements than STANDARD_VECTOR_SIZE, if our percentage is over 100
165 //! It means we finished this table.
166 return 100;
167 }
168 return percentage;
169}
170
171idx_t TableScanGetBatchIndex(ClientContext &context, const FunctionData *bind_data_p,
172 LocalTableFunctionState *local_state, GlobalTableFunctionState *gstate_p) {
173 auto &state = local_state->Cast<TableScanLocalState>();
174 if (state.scan_state.table_state.row_group) {
175 return state.scan_state.table_state.batch_index;
176 }
177 if (state.scan_state.local_state.row_group) {
178 return state.scan_state.table_state.batch_index + state.scan_state.local_state.batch_index;
179 }
180 return 0;
181}
182
183BindInfo TableScanGetBindInfo(const FunctionData *bind_data) {
184 return BindInfo(ScanType::TABLE);
185}
186
187void TableScanDependency(DependencyList &entries, const FunctionData *bind_data_p) {
188 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
189 entries.AddDependency(entry&: bind_data.table);
190}
191
192unique_ptr<NodeStatistics> TableScanCardinality(ClientContext &context, const FunctionData *bind_data_p) {
193 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
194 auto &local_storage = LocalStorage::Get(context, catalog&: bind_data.table.catalog);
195 auto &storage = bind_data.table.GetStorage();
196 idx_t estimated_cardinality = storage.info->cardinality + local_storage.AddedRows(table&: bind_data.table.GetStorage());
197 return make_uniq<NodeStatistics>(args&: storage.info->cardinality, args&: estimated_cardinality);
198}
199
200//===--------------------------------------------------------------------===//
201// Index Scan
202//===--------------------------------------------------------------------===//
203struct IndexScanGlobalState : public GlobalTableFunctionState {
204 explicit IndexScanGlobalState(data_ptr_t row_id_data) : row_ids(LogicalType::ROW_TYPE, row_id_data) {
205 }
206
207 Vector row_ids;
208 ColumnFetchState fetch_state;
209 TableScanState local_storage_state;
210 vector<storage_t> column_ids;
211 bool finished;
212};
213
214static unique_ptr<GlobalTableFunctionState> IndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) {
215 auto &bind_data = input.bind_data->Cast<TableScanBindData>();
216 data_ptr_t row_id_data = nullptr;
217 if (!bind_data.result_ids.empty()) {
218 row_id_data = (data_ptr_t)&bind_data.result_ids[0]; // NOLINT - this is not pretty
219 }
220 auto result = make_uniq<IndexScanGlobalState>(args&: row_id_data);
221 auto &local_storage = LocalStorage::Get(context, catalog&: bind_data.table.catalog);
222
223 result->column_ids.reserve(n: input.column_ids.size());
224 for (auto &id : input.column_ids) {
225 result->column_ids.push_back(x: GetStorageIndex(table&: bind_data.table, column_id: id));
226 }
227 result->local_storage_state.Initialize(column_ids: result->column_ids, table_filters: input.filters.get());
228 local_storage.InitializeScan(table&: bind_data.table.GetStorage(), state&: result->local_storage_state.local_state, table_filters: input.filters);
229
230 result->finished = false;
231 return std::move(result);
232}
233
234static void IndexScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
235 auto &bind_data = data_p.bind_data->Cast<TableScanBindData>();
236 auto &state = data_p.global_state->Cast<IndexScanGlobalState>();
237 auto &transaction = DuckTransaction::Get(context, catalog&: bind_data.table.catalog);
238 auto &local_storage = LocalStorage::Get(transaction);
239
240 if (!state.finished) {
241 bind_data.table.GetStorage().Fetch(transaction, result&: output, column_ids: state.column_ids, row_ids: state.row_ids,
242 fetch_count: bind_data.result_ids.size(), state&: state.fetch_state);
243 state.finished = true;
244 }
245 if (output.size() == 0) {
246 local_storage.Scan(state&: state.local_storage_state.local_state, column_ids: state.column_ids, result&: output);
247 }
248}
249
250static void RewriteIndexExpression(Index &index, LogicalGet &get, Expression &expr, bool &rewrite_possible) {
251 if (expr.type == ExpressionType::BOUND_COLUMN_REF) {
252 auto &bound_colref = expr.Cast<BoundColumnRefExpression>();
253 // bound column ref: rewrite to fit in the current set of bound column ids
254 bound_colref.binding.table_index = get.table_index;
255 column_t referenced_column = index.column_ids[bound_colref.binding.column_index];
256 // search for the referenced column in the set of column_ids
257 for (idx_t i = 0; i < get.column_ids.size(); i++) {
258 if (get.column_ids[i] == referenced_column) {
259 bound_colref.binding.column_index = i;
260 return;
261 }
262 }
263 // column id not found in bound columns in the LogicalGet: rewrite not possible
264 rewrite_possible = false;
265 }
266 ExpressionIterator::EnumerateChildren(
267 expression&: expr, callback: [&](Expression &child) { RewriteIndexExpression(index, get, expr&: child, rewrite_possible); });
268}
269
270void TableScanPushdownComplexFilter(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p,
271 vector<unique_ptr<Expression>> &filters) {
272 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
273 auto &table = bind_data.table;
274 auto &storage = table.GetStorage();
275
276 auto &config = ClientConfig::GetConfig(context);
277 if (!config.enable_optimizer) {
278 // we only push index scans if the optimizer is enabled
279 return;
280 }
281 if (bind_data.is_index_scan) {
282 return;
283 }
284 if (filters.empty()) {
285 // no indexes or no filters: skip the pushdown
286 return;
287 }
288 // behold
289 storage.info->indexes.Scan(callback: [&](Index &index) {
290 // first rewrite the index expression so the ColumnBindings align with the column bindings of the current table
291
292 if (index.unbound_expressions.size() > 1) {
293 // NOTE: index scans are not (yet) supported for compound index keys
294 return false;
295 }
296
297 auto index_expression = index.unbound_expressions[0]->Copy();
298 bool rewrite_possible = true;
299 RewriteIndexExpression(index, get, expr&: *index_expression, rewrite_possible);
300 if (!rewrite_possible) {
301 // could not rewrite!
302 return false;
303 }
304
305 Value low_value, high_value, equal_value;
306 ExpressionType low_comparison_type = ExpressionType::INVALID, high_comparison_type = ExpressionType::INVALID;
307 // try to find a matching index for any of the filter expressions
308 for (auto &filter : filters) {
309 auto &expr = *filter;
310
311 // create a matcher for a comparison with a constant
312 ComparisonExpressionMatcher matcher;
313 // match on a comparison type
314 matcher.expr_type = make_uniq<ComparisonExpressionTypeMatcher>();
315 // match on a constant comparison with the indexed expression
316 matcher.matchers.push_back(x: make_uniq<ExpressionEqualityMatcher>(args&: *index_expression));
317 matcher.matchers.push_back(x: make_uniq<ConstantExpressionMatcher>());
318
319 matcher.policy = SetMatcher::Policy::UNORDERED;
320
321 vector<reference<Expression>> bindings;
322 if (matcher.Match(expr_&: expr, bindings)) {
323 // range or equality comparison with constant value
324 // we can use our index here
325 // bindings[0] = the expression
326 // bindings[1] = the index expression
327 // bindings[2] = the constant
328 auto &comparison = bindings[0].get().Cast<BoundComparisonExpression>();
329 auto constant_value = bindings[2].get().Cast<BoundConstantExpression>().value;
330 auto comparison_type = comparison.type;
331 if (comparison.left->type == ExpressionType::VALUE_CONSTANT) {
332 // the expression is on the right side, we flip them around
333 comparison_type = FlipComparisonExpression(type: comparison_type);
334 }
335 if (comparison_type == ExpressionType::COMPARE_EQUAL) {
336 // equality value
337 // equality overrides any other bounds so we just break here
338 equal_value = constant_value;
339 break;
340 } else if (comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO ||
341 comparison_type == ExpressionType::COMPARE_GREATERTHAN) {
342 // greater than means this is a lower bound
343 low_value = constant_value;
344 low_comparison_type = comparison_type;
345 } else {
346 // smaller than means this is an upper bound
347 high_value = constant_value;
348 high_comparison_type = comparison_type;
349 }
350 } else if (expr.type == ExpressionType::COMPARE_BETWEEN) {
351 // BETWEEN expression
352 auto &between = expr.Cast<BoundBetweenExpression>();
353 if (!between.input->Equals(other: *index_expression)) {
354 // expression doesn't match the current index expression
355 continue;
356 }
357 if (between.lower->type != ExpressionType::VALUE_CONSTANT ||
358 between.upper->type != ExpressionType::VALUE_CONSTANT) {
359 // not a constant comparison
360 continue;
361 }
362 low_value = (between.lower->Cast<BoundConstantExpression>()).value;
363 low_comparison_type = between.lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO
364 : ExpressionType::COMPARE_GREATERTHAN;
365 high_value = (between.upper->Cast<BoundConstantExpression>()).value;
366 high_comparison_type = between.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO
367 : ExpressionType::COMPARE_LESSTHAN;
368 break;
369 }
370 }
371 if (!equal_value.IsNull() || !low_value.IsNull() || !high_value.IsNull()) {
372 // we can scan this index using this predicate: try a scan
373 auto &transaction = Transaction::Get(context, catalog&: bind_data.table.catalog);
374 unique_ptr<IndexScanState> index_state;
375 if (!equal_value.IsNull()) {
376 // equality predicate
377 index_state =
378 index.InitializeScanSinglePredicate(transaction, value: equal_value, expression_type: ExpressionType::COMPARE_EQUAL);
379 } else if (!low_value.IsNull() && !high_value.IsNull()) {
380 // two-sided predicate
381 index_state = index.InitializeScanTwoPredicates(transaction, low_value, low_expression_type: low_comparison_type, high_value,
382 high_expression_type: high_comparison_type);
383 } else if (!low_value.IsNull()) {
384 // less than predicate
385 index_state = index.InitializeScanSinglePredicate(transaction, value: low_value, expression_type: low_comparison_type);
386 } else {
387 D_ASSERT(!high_value.IsNull());
388 index_state = index.InitializeScanSinglePredicate(transaction, value: high_value, expression_type: high_comparison_type);
389 }
390 if (index.Scan(transaction, table: storage, state&: *index_state, STANDARD_VECTOR_SIZE, result_ids&: bind_data.result_ids)) {
391 // use an index scan!
392 bind_data.is_index_scan = true;
393 get.function = TableScanFunction::GetIndexScanFunction();
394 } else {
395 bind_data.result_ids.clear();
396 }
397 return true;
398 }
399 return false;
400 });
401}
402
403string TableScanToString(const FunctionData *bind_data_p) {
404 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
405 string result = bind_data.table.name;
406 return result;
407}
408
409static void TableScanSerialize(FieldWriter &writer, const FunctionData *bind_data_p, const TableFunction &function) {
410 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
411
412 writer.WriteString(val: bind_data.table.schema.name);
413 writer.WriteString(val: bind_data.table.name);
414 writer.WriteField<bool>(element: bind_data.is_index_scan);
415 writer.WriteField<bool>(element: bind_data.is_create_index);
416 writer.WriteList<row_t>(elements: bind_data.result_ids);
417 writer.WriteString(val: bind_data.table.schema.catalog.GetName());
418}
419
420static unique_ptr<FunctionData> TableScanDeserialize(PlanDeserializationState &state, FieldReader &reader,
421 TableFunction &function) {
422 auto schema_name = reader.ReadRequired<string>();
423 auto table_name = reader.ReadRequired<string>();
424 auto is_index_scan = reader.ReadRequired<bool>();
425 auto is_create_index = reader.ReadRequired<bool>();
426 auto result_ids = reader.ReadRequiredList<row_t>();
427 auto catalog_name = reader.ReadField<string>(INVALID_CATALOG);
428
429 auto &catalog_entry = Catalog::GetEntry<TableCatalogEntry>(context&: state.context, catalog_name, schema_name, name: table_name);
430 if (catalog_entry.type != CatalogType::TABLE_ENTRY) {
431 throw SerializationException("Cant find table for %s.%s", schema_name, table_name);
432 }
433
434 auto result = make_uniq<TableScanBindData>(args&: catalog_entry.Cast<DuckTableEntry>());
435 result->is_index_scan = is_index_scan;
436 result->is_create_index = is_create_index;
437 result->result_ids = std::move(result_ids);
438 return std::move(result);
439}
440
441TableFunction TableScanFunction::GetIndexScanFunction() {
442 TableFunction scan_function("index_scan", {}, IndexScanFunction);
443 scan_function.init_local = nullptr;
444 scan_function.init_global = IndexScanInitGlobal;
445 scan_function.statistics = TableScanStatistics;
446 scan_function.dependency = TableScanDependency;
447 scan_function.cardinality = TableScanCardinality;
448 scan_function.pushdown_complex_filter = nullptr;
449 scan_function.to_string = TableScanToString;
450 scan_function.table_scan_progress = nullptr;
451 scan_function.get_batch_index = nullptr;
452 scan_function.projection_pushdown = true;
453 scan_function.filter_pushdown = false;
454 scan_function.serialize = TableScanSerialize;
455 scan_function.deserialize = TableScanDeserialize;
456 return scan_function;
457}
458
459TableFunction TableScanFunction::GetFunction() {
460 TableFunction scan_function("seq_scan", {}, TableScanFunc);
461 scan_function.init_local = TableScanInitLocal;
462 scan_function.init_global = TableScanInitGlobal;
463 scan_function.statistics = TableScanStatistics;
464 scan_function.dependency = TableScanDependency;
465 scan_function.cardinality = TableScanCardinality;
466 scan_function.pushdown_complex_filter = TableScanPushdownComplexFilter;
467 scan_function.to_string = TableScanToString;
468 scan_function.table_scan_progress = TableScanProgress;
469 scan_function.get_batch_index = TableScanGetBatchIndex;
470 scan_function.get_batch_info = TableScanGetBindInfo;
471 scan_function.projection_pushdown = true;
472 scan_function.filter_pushdown = true;
473 scan_function.filter_prune = true;
474 scan_function.serialize = TableScanSerialize;
475 scan_function.deserialize = TableScanDeserialize;
476 return scan_function;
477}
478
479optional_ptr<TableCatalogEntry> TableScanFunction::GetTableEntry(const TableFunction &function,
480 const optional_ptr<FunctionData> bind_data_p) {
481 if (function.function != TableScanFunc || !bind_data_p) {
482 return nullptr;
483 }
484 auto &bind_data = bind_data_p->Cast<TableScanBindData>();
485 return &bind_data.table;
486}
487
488void TableScanFunction::RegisterFunction(BuiltinFunctions &set) {
489 TableFunctionSet table_scan_set("seq_scan");
490 table_scan_set.AddFunction(function: GetFunction());
491 set.AddFunction(set: std::move(table_scan_set));
492
493 set.AddFunction(function: GetIndexScanFunction());
494}
495
496void BuiltinFunctions::RegisterTableScanFunctions() {
497 TableScanFunction::RegisterFunction(set&: *this);
498}
499
500} // namespace duckdb
501