1#include "duckdb/function/scalar/regexp.hpp"
2#include "duckdb/execution/expression_executor.hpp"
3#include "duckdb/planner/expression/bound_function_expression.hpp"
4#include "duckdb/function/scalar/string_functions.hpp"
5#include "re2/re2.h"
6
7namespace duckdb {
8
9using regexp_util::CreateStringPiece;
10using regexp_util::Extract;
11using regexp_util::ParseRegexOptions;
12using regexp_util::TryParseConstantPattern;
13
14unique_ptr<FunctionLocalState>
15RegexpExtractAll::InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) {
16 auto &info = bind_data->Cast<RegexpBaseBindData>();
17 if (info.constant_pattern) {
18 return make_uniq<RegexLocalState>(args&: info, args: true);
19 }
20 return nullptr;
21}
22
23// Forwards startpos automatically
24bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos,
25 duckdb_re2::StringPiece *groups, int ngroups) {
26
27 D_ASSERT(pattern.ok());
28 D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups);
29
30 if (!pattern.Match(text: input, startpos: *startpos, endpos: input.size(), re_anchor: pattern.Anchored(), submatch: groups, nsubmatch: ngroups + 1)) {
31 return false;
32 }
33 idx_t consumed = static_cast<size_t>(groups[0].end() - (input.begin() + *startpos));
34 if (!consumed) {
35 // Empty match found, have to manually forward the input
36 // to avoid an infinite loop
37 // FIXME: support unicode characters
38 consumed++;
39 while (*startpos + consumed < input.length() && !LengthFun::IsCharacter(c: input[*startpos + consumed])) {
40 consumed++;
41 }
42 }
43 *startpos += consumed;
44 return true;
45}
46
47void ExtractSingleTuple(const string_t &string, duckdb_re2::RE2 &pattern, int32_t group, RegexStringPieceArgs &args,
48 Vector &result, idx_t row) {
49 auto input = CreateStringPiece(input: string);
50
51 auto &child_vector = ListVector::GetEntry(vector&: result);
52 auto list_content = FlatVector::GetData<string_t>(vector&: child_vector);
53 auto &child_validity = FlatVector::Validity(vector&: child_vector);
54
55 auto current_list_size = ListVector::GetListSize(vector: result);
56 auto current_list_capacity = ListVector::GetListCapacity(vector: result);
57
58 auto result_data = FlatVector::GetData<list_entry_t>(vector&: result);
59 auto &list_entry = result_data[row];
60 list_entry.offset = current_list_size;
61
62 if (group < 0) {
63 list_entry.length = 0;
64 return;
65 }
66 // If the requested group index is out of bounds
67 // we want to throw only if there is a match
68 bool throw_on_group_found = (idx_t)group > args.size;
69
70 idx_t startpos = 0;
71 for (idx_t iteration = 0; ExtractAll(input, pattern, startpos: &startpos, groups: args.group_buffer, ngroups: args.size); iteration++) {
72 if (!iteration && throw_on_group_found) {
73 throw InvalidInputException("Pattern has %d groups. Cannot access group %d", args.size, group);
74 }
75
76 // Make sure we have enough room for the new entries
77 if (current_list_size + 1 >= current_list_capacity) {
78 ListVector::Reserve(vec&: result, required_capacity: current_list_capacity * 2);
79 current_list_capacity = ListVector::GetListCapacity(vector: result);
80 list_content = FlatVector::GetData<string_t>(vector&: child_vector);
81 }
82
83 // Write the captured groups into the list-child vector
84 auto &match_group = args.group_buffer[group];
85
86 idx_t child_idx = current_list_size;
87 if (match_group.empty()) {
88 // This group was not matched
89 list_content[child_idx] = string_t(string.GetData(), 0);
90 if (match_group.begin() == nullptr) {
91 // This group is optional
92 child_validity.SetInvalid(child_idx);
93 }
94 } else {
95 // Every group is a substring of the original, we can find out the offset using the pointer
96 // the 'match_group' address is guaranteed to be bigger than that of the source
97 D_ASSERT(const_char_ptr_cast(match_group.begin()) >= string.GetData());
98 idx_t offset = match_group.begin() - string.GetData();
99 list_content[child_idx] = string_t(string.GetData() + offset, match_group.size());
100 }
101 current_list_size++;
102 if (startpos > input.size()) {
103 // Empty match found at the end of the string
104 break;
105 }
106 }
107 list_entry.length = current_list_size - list_entry.offset;
108 ListVector::SetListSize(vec&: result, size: current_list_size);
109}
110
111int32_t GetGroupIndex(DataChunk &args, idx_t row, int32_t &result) {
112 if (args.ColumnCount() < 3) {
113 result = 0;
114 return true;
115 }
116 UnifiedVectorFormat format;
117 args.data[2].ToUnifiedFormat(count: args.size(), data&: format);
118 idx_t index = format.sel->get_index(idx: row);
119 if (!format.validity.RowIsValid(row_idx: index)) {
120 return false;
121 }
122 result = UnifiedVectorFormat::GetData<int32_t>(format)[index];
123 return true;
124}
125
126duckdb_re2::RE2 &GetPattern(const RegexpBaseBindData &info, ExpressionState &state,
127 unique_ptr<duckdb_re2::RE2> &pattern_p) {
128 if (info.constant_pattern) {
129 auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast<RegexLocalState>();
130 return lstate.constant_pattern;
131 }
132 D_ASSERT(pattern_p);
133 return *pattern_p;
134}
135
136RegexStringPieceArgs &GetGroupsBuffer(const RegexpBaseBindData &info, ExpressionState &state,
137 unique_ptr<RegexStringPieceArgs> &groups_p) {
138 if (info.constant_pattern) {
139 auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast<RegexLocalState>();
140 return lstate.group_buffer;
141 }
142 D_ASSERT(groups_p);
143 return *groups_p;
144}
145
146void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector &result) {
147 auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
148 const auto &info = func_expr.bind_info->Cast<RegexpBaseBindData>();
149
150 auto &strings = args.data[0];
151 auto &patterns = args.data[1];
152 D_ASSERT(result.GetType().id() == LogicalTypeId::LIST);
153 auto &output_child = ListVector::GetEntry(vector&: result);
154
155 UnifiedVectorFormat strings_data;
156 strings.ToUnifiedFormat(count: args.size(), data&: strings_data);
157
158 UnifiedVectorFormat pattern_data;
159 patterns.ToUnifiedFormat(count: args.size(), data&: pattern_data);
160
161 ListVector::Reserve(vec&: result, STANDARD_VECTOR_SIZE);
162 // Reference the 'strings' StringBuffer, because we won't need to allocate new data
163 // for the result, all returned strings are substrings of the originals
164 output_child.SetAuxiliary(strings.GetAuxiliary());
165
166 // Avoid doing extra work if all the inputs are constant
167 idx_t tuple_count = args.AllConstant() ? 1 : args.size();
168
169 unique_ptr<RegexStringPieceArgs> non_const_args;
170 unique_ptr<duckdb_re2::RE2> stored_re;
171 if (!info.constant_pattern) {
172 non_const_args = make_uniq<RegexStringPieceArgs>();
173 } else {
174 // Verify that the constant pattern is valid
175 auto &re = GetPattern(info, state, pattern_p&: stored_re);
176 auto group_count_p = re.NumberOfCapturingGroups();
177 if (group_count_p == -1) {
178 throw InvalidInputException("Pattern failed to parse, error: '%s'", re.error());
179 }
180 }
181
182 for (idx_t row = 0; row < tuple_count; row++) {
183 bool pattern_valid = true;
184 if (!info.constant_pattern) {
185 // Check if the pattern is NULL or not,
186 // and compile the pattern if it's not constant
187 auto pattern_idx = pattern_data.sel->get_index(idx: row);
188 if (!pattern_data.validity.RowIsValid(row_idx: pattern_idx)) {
189 pattern_valid = false;
190 } else {
191 auto &pattern_p = UnifiedVectorFormat::GetData<string_t>(format: pattern_data)[pattern_idx];
192 auto pattern_strpiece = CreateStringPiece(input: pattern_p);
193 stored_re = make_uniq<duckdb_re2::RE2>(args&: pattern_strpiece, args: info.options);
194
195 // Increase the size of the args buffer if needed
196 auto group_count_p = stored_re->NumberOfCapturingGroups();
197 if (group_count_p == -1) {
198 throw InvalidInputException("Pattern failed to parse, error: '%s'", stored_re->error());
199 }
200 non_const_args->SetSize(group_count_p);
201 }
202 }
203
204 auto string_idx = strings_data.sel->get_index(idx: row);
205 int32_t group_index;
206 if (!pattern_valid || !strings_data.validity.RowIsValid(row_idx: string_idx) || !GetGroupIndex(args, row, result&: group_index)) {
207 // If something is NULL, the result is NULL
208 // FIXME: do we even need 'SPECIAL_HANDLING'?
209 auto result_data = FlatVector::GetData<list_entry_t>(vector&: result);
210 auto &result_validity = FlatVector::Validity(vector&: result);
211 result_data[row].length = 0;
212 result_data[row].offset = ListVector::GetListSize(vector: result);
213 result_validity.SetInvalid(row);
214 continue;
215 }
216
217 auto &re = GetPattern(info, state, pattern_p&: stored_re);
218 auto &groups = GetGroupsBuffer(info, state, groups_p&: non_const_args);
219 auto &string = UnifiedVectorFormat::GetData<string_t>(format: strings_data)[string_idx];
220 ExtractSingleTuple(string, pattern&: re, group: group_index, args&: groups, result, row);
221 }
222
223 if (args.AllConstant()) {
224 result.SetVectorType(VectorType::CONSTANT_VECTOR);
225 }
226}
227
228unique_ptr<FunctionData> RegexpExtractAll::Bind(ClientContext &context, ScalarFunction &bound_function,
229 vector<unique_ptr<Expression>> &arguments) {
230 D_ASSERT(arguments.size() >= 2);
231
232 duckdb_re2::RE2::Options options;
233
234 string constant_string;
235 bool constant_pattern = TryParseConstantPattern(context, expr&: *arguments[1], constant_string);
236
237 if (arguments.size() >= 4) {
238 ParseRegexOptions(context, expr&: *arguments[3], target&: options);
239 }
240 return make_uniq<RegexpExtractBindData>(args&: options, args: std::move(constant_string), args&: constant_pattern, args: "");
241}
242
243} // namespace duckdb
244