1#include "duckdb/parser/statement/insert_statement.hpp"
2#include "duckdb/parser/query_node/select_node.hpp"
3#include "duckdb/parser/tableref/expressionlistref.hpp"
4#include "duckdb/parser/statement/update_statement.hpp"
5
6namespace duckdb {
7
8OnConflictInfo::OnConflictInfo() : action_type(OnConflictAction::THROW) {
9}
10
11OnConflictInfo::OnConflictInfo(const OnConflictInfo &other)
12 : action_type(other.action_type), indexed_columns(other.indexed_columns) {
13 if (other.set_info) {
14 set_info = other.set_info->Copy();
15 }
16 if (other.condition) {
17 condition = other.condition->Copy();
18 }
19}
20
21unique_ptr<OnConflictInfo> OnConflictInfo::Copy() const {
22 return unique_ptr<OnConflictInfo>(new OnConflictInfo(*this));
23}
24
25InsertStatement::InsertStatement()
26 : SQLStatement(StatementType::INSERT_STATEMENT), schema(DEFAULT_SCHEMA), catalog(INVALID_CATALOG) {
27}
28
29InsertStatement::InsertStatement(const InsertStatement &other)
30 : SQLStatement(other), select_statement(unique_ptr_cast<SQLStatement, SelectStatement>(
31 src: other.select_statement ? other.select_statement->Copy() : nullptr)),
32 columns(other.columns), table(other.table), schema(other.schema), catalog(other.catalog),
33 default_values(other.default_values), column_order(other.column_order) {
34 cte_map = other.cte_map.Copy();
35 for (auto &expr : other.returning_list) {
36 returning_list.emplace_back(args: expr->Copy());
37 }
38 if (other.table_ref) {
39 table_ref = other.table_ref->Copy();
40 }
41 if (other.on_conflict_info) {
42 on_conflict_info = other.on_conflict_info->Copy();
43 }
44}
45
46string InsertStatement::OnConflictActionToString(OnConflictAction action) {
47 switch (action) {
48 case OnConflictAction::NOTHING:
49 return "DO NOTHING";
50 case OnConflictAction::REPLACE:
51 case OnConflictAction::UPDATE:
52 return "DO UPDATE";
53 case OnConflictAction::THROW:
54 // Explicitly left empty, for ToString purposes
55 return "";
56 default: {
57 throw NotImplementedException("type not implemented for OnConflictActionType");
58 }
59 }
60}
61
62string InsertStatement::ToString() const {
63 bool or_replace_shorthand_set = false;
64 string result;
65
66 result = cte_map.ToString();
67 result += "INSERT";
68 if (on_conflict_info && on_conflict_info->action_type == OnConflictAction::REPLACE) {
69 or_replace_shorthand_set = true;
70 result += " OR REPLACE";
71 }
72 result += " INTO ";
73 if (!catalog.empty()) {
74 result += KeywordHelper::WriteOptionallyQuoted(text: catalog) + ".";
75 }
76 if (!schema.empty()) {
77 result += KeywordHelper::WriteOptionallyQuoted(text: schema) + ".";
78 }
79 result += KeywordHelper::WriteOptionallyQuoted(text: table);
80 // Write the (optional) alias of the insert target
81 if (table_ref && !table_ref->alias.empty()) {
82 result += StringUtil::Format(fmt_str: " AS %s", params: KeywordHelper::WriteOptionallyQuoted(text: table_ref->alias));
83 }
84 if (column_order == InsertColumnOrder::INSERT_BY_NAME) {
85 result += " BY NAME";
86 }
87 if (!columns.empty()) {
88 result += " (";
89 for (idx_t i = 0; i < columns.size(); i++) {
90 if (i > 0) {
91 result += ", ";
92 }
93 result += KeywordHelper::WriteOptionallyQuoted(text: columns[i]);
94 }
95 result += " )";
96 }
97 result += " ";
98 auto values_list = GetValuesList();
99 if (values_list) {
100 D_ASSERT(!default_values);
101 values_list->alias = string();
102 result += values_list->ToString();
103 } else if (select_statement) {
104 D_ASSERT(!default_values);
105 result += select_statement->ToString();
106 } else {
107 D_ASSERT(default_values);
108 result += "DEFAULT VALUES";
109 }
110 if (!or_replace_shorthand_set && on_conflict_info) {
111 auto &conflict_info = *on_conflict_info;
112 result += " ON CONFLICT ";
113 // (optional) conflict target
114 if (!conflict_info.indexed_columns.empty()) {
115 result += "(";
116 auto &columns = conflict_info.indexed_columns;
117 for (auto it = columns.begin(); it != columns.end();) {
118 result += StringUtil::Lower(str: *it);
119 if (++it != columns.end()) {
120 result += ", ";
121 }
122 }
123 result += " )";
124 }
125
126 // (optional) where clause
127 if (conflict_info.condition) {
128 result += " WHERE " + conflict_info.condition->ToString();
129 }
130 result += " " + OnConflictActionToString(action: conflict_info.action_type);
131 if (conflict_info.set_info) {
132 D_ASSERT(conflict_info.action_type == OnConflictAction::UPDATE);
133 result += " SET ";
134 auto &set_info = *conflict_info.set_info;
135 D_ASSERT(set_info.columns.size() == set_info.expressions.size());
136 // SET <column_name> = <expression>
137 for (idx_t i = 0; i < set_info.columns.size(); i++) {
138 auto &column = set_info.columns[i];
139 auto &expr = set_info.expressions[i];
140 if (i) {
141 result += ", ";
142 }
143 result += StringUtil::Lower(str: column) + " = " + expr->ToString();
144 }
145 // (optional) where clause
146 if (set_info.condition) {
147 result += " WHERE " + set_info.condition->ToString();
148 }
149 }
150 }
151 if (!returning_list.empty()) {
152 result += " RETURNING ";
153 for (idx_t i = 0; i < returning_list.size(); i++) {
154 if (i > 0) {
155 result += ", ";
156 }
157 result += returning_list[i]->ToString();
158 }
159 }
160 return result;
161}
162
163unique_ptr<SQLStatement> InsertStatement::Copy() const {
164 return unique_ptr<InsertStatement>(new InsertStatement(*this));
165}
166
167optional_ptr<ExpressionListRef> InsertStatement::GetValuesList() const {
168 if (!select_statement) {
169 return nullptr;
170 }
171 if (select_statement->node->type != QueryNodeType::SELECT_NODE) {
172 return nullptr;
173 }
174 auto &node = select_statement->node->Cast<SelectNode>();
175 if (node.where_clause || node.qualify || node.having) {
176 return nullptr;
177 }
178 if (!node.cte_map.map.empty()) {
179 return nullptr;
180 }
181 if (!node.groups.grouping_sets.empty()) {
182 return nullptr;
183 }
184 if (node.aggregate_handling != AggregateHandling::STANDARD_HANDLING) {
185 return nullptr;
186 }
187 if (node.select_list.size() != 1 || node.select_list[0]->type != ExpressionType::STAR) {
188 return nullptr;
189 }
190 if (!node.from_table || node.from_table->type != TableReferenceType::EXPRESSION_LIST) {
191 return nullptr;
192 }
193 return &node.from_table->Cast<ExpressionListRef>();
194}
195
196} // namespace duckdb
197