1#include "duckdb/common/bind_helpers.hpp"
2#include "duckdb/common/file_system.hpp"
3#include "duckdb/common/multi_file_reader.hpp"
4#include "duckdb/common/serializer/buffered_serializer.hpp"
5#include "duckdb/common/string_util.hpp"
6#include "duckdb/common/types/column/column_data_collection.hpp"
7#include "duckdb/common/types/string_type.hpp"
8#include "duckdb/common/vector_operations/vector_operations.hpp"
9#include "duckdb/function/copy_function.hpp"
10#include "duckdb/function/scalar/string_functions.hpp"
11#include "duckdb/function/table/read_csv.hpp"
12#include "duckdb/parser/parsed_data/copy_info.hpp"
13
14#include <limits>
15
16namespace duckdb {
17
18void SubstringDetection(string &str_1, string &str_2, const string &name_str_1, const string &name_str_2) {
19 if (str_1.empty() || str_2.empty()) {
20 return;
21 }
22 if ((str_1.find(str: str_2) != string::npos || str_2.find(str: str_1) != std::string::npos)) {
23 throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2);
24 }
25}
26
27//===--------------------------------------------------------------------===//
28// Bind
29//===--------------------------------------------------------------------===//
30
31void BaseCSVData::Finalize() {
32 // verify that the options are correct in the final pass
33 if (options.escape.empty()) {
34 options.escape = options.quote;
35 }
36 // escape and delimiter must not be substrings of each other
37 if (options.has_delimiter && options.has_escape) {
38 SubstringDetection(str_1&: options.delimiter, str_2&: options.escape, name_str_1: "DELIMITER", name_str_2: "ESCAPE");
39 }
40 // delimiter and quote must not be substrings of each other
41 if (options.has_quote && options.has_delimiter) {
42 SubstringDetection(str_1&: options.quote, str_2&: options.delimiter, name_str_1: "DELIMITER", name_str_2: "QUOTE");
43 }
44 // escape and quote must not be substrings of each other (but can be the same)
45 if (options.quote != options.escape && options.has_quote && options.has_escape) {
46 SubstringDetection(str_1&: options.quote, str_2&: options.escape, name_str_1: "QUOTE", name_str_2: "ESCAPE");
47 }
48 if (!options.null_str.empty()) {
49 // null string and delimiter must not be substrings of each other
50 if (options.has_delimiter) {
51 SubstringDetection(str_1&: options.delimiter, str_2&: options.null_str, name_str_1: "DELIMITER", name_str_2: "NULL");
52 }
53 // quote/escape and nullstr must not be substrings of each other
54 if (options.has_quote) {
55 SubstringDetection(str_1&: options.quote, str_2&: options.null_str, name_str_1: "QUOTE", name_str_2: "NULL");
56 }
57 if (options.has_escape) {
58 SubstringDetection(str_1&: options.escape, str_2&: options.null_str, name_str_1: "ESCAPE", name_str_2: "NULL");
59 }
60 }
61
62 if (!options.prefix.empty() || !options.suffix.empty()) {
63 if (options.prefix.empty() || options.suffix.empty()) {
64 throw BinderException("COPY ... (FORMAT CSV) must have both PREFIX and SUFFIX, or none at all");
65 }
66 if (options.header) {
67 throw BinderException("COPY ... (FORMAT CSV)'s HEADER cannot be combined with PREFIX/SUFFIX");
68 }
69 }
70}
71
72static unique_ptr<FunctionData> WriteCSVBind(ClientContext &context, CopyInfo &info, vector<string> &names,
73 vector<LogicalType> &sql_types) {
74 auto bind_data = make_uniq<WriteCSVData>(args&: info.file_path, args&: sql_types, args&: names);
75
76 // check all the options in the copy info
77 for (auto &option : info.options) {
78 auto loption = StringUtil::Lower(str: option.first);
79 auto &set = option.second;
80 bind_data->options.SetWriteOption(loption, value: ConvertVectorToValue(set: std::move(set)));
81 }
82 // verify the parsed options
83 if (bind_data->options.force_quote.empty()) {
84 // no FORCE_QUOTE specified: initialize to false
85 bind_data->options.force_quote.resize(new_size: names.size(), x: false);
86 }
87 bind_data->Finalize();
88 bind_data->is_simple = bind_data->options.delimiter.size() == 1 && bind_data->options.escape.size() == 1 &&
89 bind_data->options.quote.size() == 1;
90 if (bind_data->is_simple) {
91 bind_data->requires_quotes = make_unsafe_uniq_array<bool>(n: 256);
92 memset(s: bind_data->requires_quotes.get(), c: 0, n: sizeof(bool) * 256);
93 bind_data->requires_quotes['\n'] = true;
94 bind_data->requires_quotes['\r'] = true;
95 bind_data->requires_quotes[bind_data->options.delimiter[0]] = true;
96 bind_data->requires_quotes[bind_data->options.quote[0]] = true;
97 }
98 if (!bind_data->options.write_newline.empty()) {
99 bind_data->newline = bind_data->options.write_newline;
100 }
101 return std::move(bind_data);
102}
103
104static unique_ptr<FunctionData> ReadCSVBind(ClientContext &context, CopyInfo &info, vector<string> &expected_names,
105 vector<LogicalType> &expected_types) {
106 auto bind_data = make_uniq<ReadCSVData>();
107 bind_data->csv_types = expected_types;
108 bind_data->csv_names = expected_names;
109 bind_data->return_types = expected_types;
110 bind_data->return_names = expected_names;
111 bind_data->files = MultiFileReader::GetFileList(context, input: Value(info.file_path), name: "CSV");
112
113 auto &options = bind_data->options;
114
115 // check all the options in the copy info
116 for (auto &option : info.options) {
117 auto loption = StringUtil::Lower(str: option.first);
118 auto &set = option.second;
119 options.SetReadOption(loption, value: ConvertVectorToValue(set: std::move(set)), expected_names);
120 }
121 // verify the parsed options
122 if (options.force_not_null.empty()) {
123 // no FORCE_QUOTE specified: initialize to false
124 options.force_not_null.resize(new_size: expected_types.size(), x: false);
125 }
126 bind_data->FinalizeRead(context);
127 if (!bind_data->single_threaded && options.auto_detect) {
128 options.file_path = bind_data->files[0];
129 options.name_list = expected_names;
130 auto initial_reader = make_uniq<BufferedCSVReader>(args&: context, args&: options, args&: expected_types);
131 options = initial_reader->options;
132 }
133 return std::move(bind_data);
134}
135
136//===--------------------------------------------------------------------===//
137// Helper writing functions
138//===--------------------------------------------------------------------===//
139static string AddEscapes(string &to_be_escaped, const string &escape, const string &val) {
140 idx_t i = 0;
141 string new_val = "";
142 idx_t found = val.find(str: to_be_escaped);
143
144 while (found != string::npos) {
145 while (i < found) {
146 new_val += val[i];
147 i++;
148 }
149 new_val += escape;
150 found = val.find(str: to_be_escaped, pos: found + escape.length());
151 }
152 while (i < val.length()) {
153 new_val += val[i];
154 i++;
155 }
156 return new_val;
157}
158
159static bool RequiresQuotes(WriteCSVData &csv_data, const char *str, idx_t len) {
160 auto &options = csv_data.options;
161 // check if the string is equal to the null string
162 if (len == options.null_str.size() && memcmp(s1: str, s2: options.null_str.c_str(), n: len) == 0) {
163 return true;
164 }
165 if (csv_data.is_simple) {
166 // simple CSV: check for newlines, quotes and delimiter all at once
167 auto str_data = reinterpret_cast<const_data_ptr_t>(str);
168 for (idx_t i = 0; i < len; i++) {
169 if (csv_data.requires_quotes[str_data[i]]) {
170 // this byte requires quotes - write a quoted string
171 return true;
172 }
173 }
174 // no newline, quote or delimiter in the string
175 // no quoting or escaping necessary
176 return false;
177 } else {
178 // CSV with complex quotes/delimiter (multiple bytes)
179
180 // first check for \n, \r, \n\r in string
181 for (idx_t i = 0; i < len; i++) {
182 if (str[i] == '\n' || str[i] == '\r') {
183 // newline, write a quoted string
184 return true;
185 }
186 }
187
188 // check for delimiter
189 if (options.delimiter.length() != 0 &&
190 ContainsFun::Find(haystack: const_uchar_ptr_cast(src: str), haystack_size: len, needle: const_uchar_ptr_cast(src: options.delimiter.c_str()),
191 needle_size: options.delimiter.size()) != DConstants::INVALID_INDEX) {
192 return true;
193 }
194 // check for quote
195 if (options.quote.length() != 0 &&
196 ContainsFun::Find(haystack: const_uchar_ptr_cast(src: str), haystack_size: len, needle: const_uchar_ptr_cast(src: options.quote.c_str()),
197 needle_size: options.quote.size()) != DConstants::INVALID_INDEX) {
198 return true;
199 }
200 return false;
201 }
202}
203
204static void WriteQuotedString(Serializer &serializer, WriteCSVData &csv_data, const char *str, idx_t len,
205 bool force_quote) {
206 auto &options = csv_data.options;
207 if (!force_quote) {
208 // force quote is disabled: check if we need to add quotes anyway
209 force_quote = RequiresQuotes(csv_data, str, len);
210 }
211 if (force_quote) {
212 // quoting is enabled: we might need to escape things in the string
213 bool requires_escape = false;
214 if (csv_data.is_simple) {
215 // simple CSV
216 // do a single loop to check for a quote or escape value
217 for (idx_t i = 0; i < len; i++) {
218 if (str[i] == options.quote[0] || str[i] == options.escape[0]) {
219 requires_escape = true;
220 break;
221 }
222 }
223 } else {
224 // complex CSV
225 // check for quote or escape separately
226 if (options.quote.length() != 0 &&
227 ContainsFun::Find(haystack: const_uchar_ptr_cast(src: str), haystack_size: len, needle: const_uchar_ptr_cast(src: options.quote.c_str()),
228 needle_size: options.quote.size()) != DConstants::INVALID_INDEX) {
229 requires_escape = true;
230 } else if (options.escape.length() != 0 &&
231 ContainsFun::Find(haystack: const_uchar_ptr_cast(src: str), haystack_size: len, needle: const_uchar_ptr_cast(src: options.escape.c_str()),
232 needle_size: options.escape.size()) != DConstants::INVALID_INDEX) {
233 requires_escape = true;
234 }
235 }
236 if (!requires_escape) {
237 // fast path: no need to escape anything
238 serializer.WriteBufferData(str: options.quote);
239 serializer.WriteData(buffer: const_data_ptr_cast(src: str), write_size: len);
240 serializer.WriteBufferData(str: options.quote);
241 return;
242 }
243
244 // slow path: need to add escapes
245 string new_val(str, len);
246 new_val = AddEscapes(to_be_escaped&: options.escape, escape: options.escape, val: new_val);
247 if (options.escape != options.quote) {
248 // need to escape quotes separately
249 new_val = AddEscapes(to_be_escaped&: options.quote, escape: options.escape, val: new_val);
250 }
251 serializer.WriteBufferData(str: options.quote);
252 serializer.WriteBufferData(str: new_val);
253 serializer.WriteBufferData(str: options.quote);
254 } else {
255 serializer.WriteData(buffer: const_data_ptr_cast(src: str), write_size: len);
256 }
257}
258
259//===--------------------------------------------------------------------===//
260// Sink
261//===--------------------------------------------------------------------===//
//! Per-thread sink state for COPY ... TO (FORMAT CSV).
struct LocalWriteCSVData : public LocalFunctionData {
	//! The thread-local buffer to write data into
	BufferedSerializer serializer;
	//! A chunk with VARCHAR columns to cast intermediates into
	DataChunk cast_chunk;
	//! Whether this thread's buffer already holds rows; used so the newline
	//! separator is only emitted between rows, never before the first one
	bool written_anything = false;
};
270
271struct GlobalWriteCSVData : public GlobalFunctionData {
272 GlobalWriteCSVData(FileSystem &fs, const string &file_path, FileCompressionType compression)
273 : fs(fs), written_anything(false) {
274 handle = fs.OpenFile(path: file_path, flags: FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW,
275 lock: FileLockType::WRITE_LOCK, compression);
276 }
277
278 //! Write generic data, e.g., CSV header
279 void WriteData(const_data_ptr_t data, idx_t size) {
280 lock_guard<mutex> flock(lock);
281 handle->Write(buffer: (void *)data, nr_bytes: size);
282 }
283
284 void WriteData(const char *data, idx_t size) {
285 WriteData(data: const_data_ptr_cast(src: data), size);
286 }
287
288 //! Write rows
289 void WriteRows(const_data_ptr_t data, idx_t size, const string &newline) {
290 lock_guard<mutex> flock(lock);
291 if (written_anything) {
292 handle->Write(buffer: (void *)newline.c_str(), nr_bytes: newline.length());
293 } else {
294 written_anything = true;
295 }
296 handle->Write(buffer: (void *)data, nr_bytes: size);
297 }
298
299 FileSystem &fs;
300 //! The mutex for writing to the physical file
301 mutex lock;
302 //! The file handle to write to
303 unique_ptr<FileHandle> handle;
304 //! If we've written any rows yet, allows us to prevent a trailing comma when writing JSON ARRAY
305 bool written_anything;
306};
307
308static unique_ptr<LocalFunctionData> WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) {
309 auto &csv_data = bind_data.Cast<WriteCSVData>();
310 auto local_data = make_uniq<LocalWriteCSVData>();
311
312 // create the chunk with VARCHAR types
313 vector<LogicalType> types;
314 types.resize(new_size: csv_data.options.name_list.size(), x: LogicalType::VARCHAR);
315
316 local_data->cast_chunk.Initialize(allocator&: Allocator::Get(context&: context.client), types);
317 return std::move(local_data);
318}
319
320static unique_ptr<GlobalFunctionData> WriteCSVInitializeGlobal(ClientContext &context, FunctionData &bind_data,
321 const string &file_path) {
322 auto &csv_data = bind_data.Cast<WriteCSVData>();
323 auto &options = csv_data.options;
324 auto global_data =
325 make_uniq<GlobalWriteCSVData>(args&: FileSystem::GetFileSystem(context), args: file_path, args&: options.compression);
326
327 if (!options.prefix.empty()) {
328 global_data->WriteData(data: options.prefix.c_str(), size: options.prefix.size());
329 }
330
331 if (options.header) {
332 BufferedSerializer serializer;
333 // write the header line to the file
334 for (idx_t i = 0; i < csv_data.options.name_list.size(); i++) {
335 if (i != 0) {
336 serializer.WriteBufferData(str: options.delimiter);
337 }
338 WriteQuotedString(serializer, csv_data, str: csv_data.options.name_list[i].c_str(),
339 len: csv_data.options.name_list[i].size(), force_quote: false);
340 }
341 serializer.WriteBufferData(str: csv_data.newline);
342
343 global_data->WriteData(data: serializer.blob.data.get(), size: serializer.blob.size);
344 }
345
346 return std::move(global_data);
347}
348
349static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk,
350 BufferedSerializer &writer, DataChunk &input, bool &written_anything) {
351 auto &csv_data = bind_data.Cast<WriteCSVData>();
352 auto &options = csv_data.options;
353
354 // first cast the columns of the chunk to varchar
355 cast_chunk.Reset();
356 cast_chunk.SetCardinality(input);
357 for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) {
358 if (csv_data.sql_types[col_idx].id() == LogicalTypeId::VARCHAR) {
359 // VARCHAR, just reinterpret (cannot reference, because LogicalTypeId::VARCHAR is used by the JSON type too)
360 cast_chunk.data[col_idx].Reinterpret(other&: input.data[col_idx]);
361 } else if (options.has_format[LogicalTypeId::DATE] && csv_data.sql_types[col_idx].id() == LogicalTypeId::DATE) {
362 // use the date format to cast the chunk
363 csv_data.options.write_date_format[LogicalTypeId::DATE].ConvertDateVector(
364 input&: input.data[col_idx], result&: cast_chunk.data[col_idx], count: input.size());
365 } else if (options.has_format[LogicalTypeId::TIMESTAMP] &&
366 (csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP ||
367 csv_data.sql_types[col_idx].id() == LogicalTypeId::TIMESTAMP_TZ)) {
368 // use the timestamp format to cast the chunk
369 csv_data.options.write_date_format[LogicalTypeId::TIMESTAMP].ConvertTimestampVector(
370 input&: input.data[col_idx], result&: cast_chunk.data[col_idx], count: input.size());
371 } else {
372 // non varchar column, perform the cast
373 VectorOperations::Cast(context, source&: input.data[col_idx], result&: cast_chunk.data[col_idx], count: input.size());
374 }
375 }
376
377 cast_chunk.Flatten();
378 // now loop over the vectors and output the values
379 for (idx_t row_idx = 0; row_idx < cast_chunk.size(); row_idx++) {
380 if (row_idx == 0 && !written_anything) {
381 written_anything = true;
382 } else {
383 writer.WriteBufferData(str: csv_data.newline);
384 }
385 // write values
386 for (idx_t col_idx = 0; col_idx < cast_chunk.ColumnCount(); col_idx++) {
387 if (col_idx != 0) {
388 writer.WriteBufferData(str: options.delimiter);
389 }
390 if (FlatVector::IsNull(vector: cast_chunk.data[col_idx], idx: row_idx)) {
391 // write null value
392 writer.WriteBufferData(str: options.null_str);
393 continue;
394 }
395
396 // non-null value, fetch the string value from the cast chunk
397 auto str_data = FlatVector::GetData<string_t>(vector&: cast_chunk.data[col_idx]);
398 // FIXME: we could gain some performance here by checking for certain types if they ever require quotes
399 // (e.g. integers only require quotes if the delimiter is a number, decimals only require quotes if the
400 // delimiter is a number or "." character)
401 WriteQuotedString(serializer&: writer, csv_data, str: str_data[row_idx].GetData(), len: str_data[row_idx].GetSize(),
402 force_quote: csv_data.options.force_quote[col_idx]);
403 }
404 }
405}
406
407static void WriteCSVSink(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate,
408 LocalFunctionData &lstate, DataChunk &input) {
409 auto &csv_data = bind_data.Cast<WriteCSVData>();
410 auto &local_data = lstate.Cast<LocalWriteCSVData>();
411 auto &global_state = gstate.Cast<GlobalWriteCSVData>();
412
413 // write data into the local buffer
414 WriteCSVChunkInternal(context&: context.client, bind_data, cast_chunk&: local_data.cast_chunk, writer&: local_data.serializer, input,
415 written_anything&: local_data.written_anything);
416
417 // check if we should flush what we have currently written
418 auto &writer = local_data.serializer;
419 if (writer.blob.size >= csv_data.flush_size) {
420 global_state.WriteRows(data: writer.blob.data.get(), size: writer.blob.size, newline: csv_data.newline);
421 writer.Reset();
422 local_data.written_anything = false;
423 }
424}
425
426//===--------------------------------------------------------------------===//
427// Combine
428//===--------------------------------------------------------------------===//
429static void WriteCSVCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate,
430 LocalFunctionData &lstate) {
431 auto &local_data = lstate.Cast<LocalWriteCSVData>();
432 auto &global_state = gstate.Cast<GlobalWriteCSVData>();
433 auto &csv_data = bind_data.Cast<WriteCSVData>();
434 auto &writer = local_data.serializer;
435 // flush the local writer
436 if (local_data.written_anything) {
437 global_state.WriteRows(data: writer.blob.data.get(), size: writer.blob.size, newline: csv_data.newline);
438 writer.Reset();
439 }
440}
441
442//===--------------------------------------------------------------------===//
443// Finalize
444//===--------------------------------------------------------------------===//
445void WriteCSVFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) {
446 auto &global_state = gstate.Cast<GlobalWriteCSVData>();
447 auto &csv_data = bind_data.Cast<WriteCSVData>();
448 auto &options = csv_data.options;
449
450 BufferedSerializer serializer;
451 if (!options.suffix.empty()) {
452 serializer.WriteBufferData(str: options.suffix);
453 } else if (global_state.written_anything) {
454 serializer.WriteBufferData(str: csv_data.newline);
455 }
456 global_state.WriteData(data: serializer.blob.data.get(), size: serializer.blob.size);
457
458 global_state.handle->Close();
459 global_state.handle.reset();
460}
461
462//===--------------------------------------------------------------------===//
463// Execution Mode
464//===--------------------------------------------------------------------===//
465CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, bool supports_batch_index) {
466 if (!preserve_insertion_order) {
467 return CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE;
468 }
469 if (supports_batch_index) {
470 return CopyFunctionExecutionMode::BATCH_COPY_TO_FILE;
471 }
472 return CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE;
473}
474//===--------------------------------------------------------------------===//
475// Prepare Batch
476//===--------------------------------------------------------------------===//
//! State for one prepared batch of COPY ... TO (FORMAT CSV).
struct WriteCSVBatchData : public PreparedBatchData {
	//! The per-batch buffer the batch's rows are serialized into
	BufferedSerializer serializer;
};
481
482unique_ptr<PreparedBatchData> WriteCSVPrepareBatch(ClientContext &context, FunctionData &bind_data,
483 GlobalFunctionData &gstate,
484 unique_ptr<ColumnDataCollection> collection) {
485 auto &csv_data = bind_data.Cast<WriteCSVData>();
486
487 // create the cast chunk with VARCHAR types
488 vector<LogicalType> types;
489 types.resize(new_size: csv_data.options.name_list.size(), x: LogicalType::VARCHAR);
490 DataChunk cast_chunk;
491 cast_chunk.Initialize(allocator&: Allocator::Get(context), types);
492
493 // write CSV chunks to the batch data
494 bool written_anything = false;
495 auto batch = make_uniq<WriteCSVBatchData>();
496 for (auto &chunk : collection->Chunks()) {
497 WriteCSVChunkInternal(context, bind_data, cast_chunk, writer&: batch->serializer, input&: chunk, written_anything);
498 }
499 return std::move(batch);
500}
501
502//===--------------------------------------------------------------------===//
503// Flush Batch
504//===--------------------------------------------------------------------===//
505void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate,
506 PreparedBatchData &batch) {
507 auto &csv_batch = batch.Cast<WriteCSVBatchData>();
508 auto &global_state = gstate.Cast<GlobalWriteCSVData>();
509 auto &csv_data = bind_data.Cast<WriteCSVData>();
510 auto &writer = csv_batch.serializer;
511 global_state.WriteRows(data: writer.blob.data.get(), size: writer.blob.size, newline: csv_data.newline);
512 writer.Reset();
513}
514
515void CSVCopyFunction::RegisterFunction(BuiltinFunctions &set) {
516 CopyFunction info("csv");
517 info.copy_to_bind = WriteCSVBind;
518 info.copy_to_initialize_local = WriteCSVInitializeLocal;
519 info.copy_to_initialize_global = WriteCSVInitializeGlobal;
520 info.copy_to_sink = WriteCSVSink;
521 info.copy_to_combine = WriteCSVCombine;
522 info.copy_to_finalize = WriteCSVFinalize;
523 info.execution_mode = WriteCSVExecutionMode;
524 info.prepare_batch = WriteCSVPrepareBatch;
525 info.flush_batch = WriteCSVFlushBatch;
526
527 info.copy_from_bind = ReadCSVBind;
528 info.copy_from_function = ReadCSVTableFunction::GetFunction();
529
530 info.extension = "csv";
531
532 set.AddFunction(function: info);
533}
534
535} // namespace duckdb
536