1 | #include "duckdb/function/scalar/regexp.hpp" |
2 | #include "duckdb/execution/expression_executor.hpp" |
3 | #include "duckdb/planner/expression/bound_function_expression.hpp" |
4 | #include "duckdb/function/scalar/string_functions.hpp" |
5 | #include "re2/re2.h" |
6 | |
7 | namespace duckdb { |
8 | |
9 | using regexp_util::CreateStringPiece; |
10 | using regexp_util::Extract; |
11 | using regexp_util::ParseRegexOptions; |
12 | using regexp_util::TryParseConstantPattern; |
13 | |
14 | unique_ptr<FunctionLocalState> |
15 | RegexpExtractAll::(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { |
16 | auto &info = bind_data->Cast<RegexpBaseBindData>(); |
17 | if (info.constant_pattern) { |
18 | return make_uniq<RegexLocalState>(args&: info, args: true); |
19 | } |
20 | return nullptr; |
21 | } |
22 | |
23 | // Forwards startpos automatically |
24 | bool (duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos, |
25 | duckdb_re2::StringPiece *groups, int ngroups) { |
26 | |
27 | D_ASSERT(pattern.ok()); |
28 | D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups); |
29 | |
30 | if (!pattern.Match(text: input, startpos: *startpos, endpos: input.size(), re_anchor: pattern.Anchored(), submatch: groups, nsubmatch: ngroups + 1)) { |
31 | return false; |
32 | } |
33 | idx_t consumed = static_cast<size_t>(groups[0].end() - (input.begin() + *startpos)); |
34 | if (!consumed) { |
35 | // Empty match found, have to manually forward the input |
36 | // to avoid an infinite loop |
37 | // FIXME: support unicode characters |
38 | consumed++; |
39 | while (*startpos + consumed < input.length() && !LengthFun::IsCharacter(c: input[*startpos + consumed])) { |
40 | consumed++; |
41 | } |
42 | } |
43 | *startpos += consumed; |
44 | return true; |
45 | } |
46 | |
47 | void (const string_t &string, duckdb_re2::RE2 &pattern, int32_t group, RegexStringPieceArgs &args, |
48 | Vector &result, idx_t row) { |
49 | auto input = CreateStringPiece(input: string); |
50 | |
51 | auto &child_vector = ListVector::GetEntry(vector&: result); |
52 | auto list_content = FlatVector::GetData<string_t>(vector&: child_vector); |
53 | auto &child_validity = FlatVector::Validity(vector&: child_vector); |
54 | |
55 | auto current_list_size = ListVector::GetListSize(vector: result); |
56 | auto current_list_capacity = ListVector::GetListCapacity(vector: result); |
57 | |
58 | auto result_data = FlatVector::GetData<list_entry_t>(vector&: result); |
59 | auto &list_entry = result_data[row]; |
60 | list_entry.offset = current_list_size; |
61 | |
62 | if (group < 0) { |
63 | list_entry.length = 0; |
64 | return; |
65 | } |
66 | // If the requested group index is out of bounds |
67 | // we want to throw only if there is a match |
68 | bool throw_on_group_found = (idx_t)group > args.size; |
69 | |
70 | idx_t startpos = 0; |
71 | for (idx_t iteration = 0; ExtractAll(input, pattern, startpos: &startpos, groups: args.group_buffer, ngroups: args.size); iteration++) { |
72 | if (!iteration && throw_on_group_found) { |
73 | throw InvalidInputException("Pattern has %d groups. Cannot access group %d" , args.size, group); |
74 | } |
75 | |
76 | // Make sure we have enough room for the new entries |
77 | if (current_list_size + 1 >= current_list_capacity) { |
78 | ListVector::Reserve(vec&: result, required_capacity: current_list_capacity * 2); |
79 | current_list_capacity = ListVector::GetListCapacity(vector: result); |
80 | list_content = FlatVector::GetData<string_t>(vector&: child_vector); |
81 | } |
82 | |
83 | // Write the captured groups into the list-child vector |
84 | auto &match_group = args.group_buffer[group]; |
85 | |
86 | idx_t child_idx = current_list_size; |
87 | if (match_group.empty()) { |
88 | // This group was not matched |
89 | list_content[child_idx] = string_t(string.GetData(), 0); |
90 | if (match_group.begin() == nullptr) { |
91 | // This group is optional |
92 | child_validity.SetInvalid(child_idx); |
93 | } |
94 | } else { |
95 | // Every group is a substring of the original, we can find out the offset using the pointer |
96 | // the 'match_group' address is guaranteed to be bigger than that of the source |
97 | D_ASSERT(const_char_ptr_cast(match_group.begin()) >= string.GetData()); |
98 | idx_t offset = match_group.begin() - string.GetData(); |
99 | list_content[child_idx] = string_t(string.GetData() + offset, match_group.size()); |
100 | } |
101 | current_list_size++; |
102 | if (startpos > input.size()) { |
103 | // Empty match found at the end of the string |
104 | break; |
105 | } |
106 | } |
107 | list_entry.length = current_list_size - list_entry.offset; |
108 | ListVector::SetListSize(vec&: result, size: current_list_size); |
109 | } |
110 | |
111 | int32_t GetGroupIndex(DataChunk &args, idx_t row, int32_t &result) { |
112 | if (args.ColumnCount() < 3) { |
113 | result = 0; |
114 | return true; |
115 | } |
116 | UnifiedVectorFormat format; |
117 | args.data[2].ToUnifiedFormat(count: args.size(), data&: format); |
118 | idx_t index = format.sel->get_index(idx: row); |
119 | if (!format.validity.RowIsValid(row_idx: index)) { |
120 | return false; |
121 | } |
122 | result = UnifiedVectorFormat::GetData<int32_t>(format)[index]; |
123 | return true; |
124 | } |
125 | |
126 | duckdb_re2::RE2 &GetPattern(const RegexpBaseBindData &info, ExpressionState &state, |
127 | unique_ptr<duckdb_re2::RE2> &pattern_p) { |
128 | if (info.constant_pattern) { |
129 | auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast<RegexLocalState>(); |
130 | return lstate.constant_pattern; |
131 | } |
132 | D_ASSERT(pattern_p); |
133 | return *pattern_p; |
134 | } |
135 | |
136 | RegexStringPieceArgs &GetGroupsBuffer(const RegexpBaseBindData &info, ExpressionState &state, |
137 | unique_ptr<RegexStringPieceArgs> &groups_p) { |
138 | if (info.constant_pattern) { |
139 | auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast<RegexLocalState>(); |
140 | return lstate.group_buffer; |
141 | } |
142 | D_ASSERT(groups_p); |
143 | return *groups_p; |
144 | } |
145 | |
146 | void RegexpExtractAll::(DataChunk &args, ExpressionState &state, Vector &result) { |
147 | auto &func_expr = state.expr.Cast<BoundFunctionExpression>(); |
148 | const auto &info = func_expr.bind_info->Cast<RegexpBaseBindData>(); |
149 | |
150 | auto &strings = args.data[0]; |
151 | auto &patterns = args.data[1]; |
152 | D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); |
153 | auto &output_child = ListVector::GetEntry(vector&: result); |
154 | |
155 | UnifiedVectorFormat strings_data; |
156 | strings.ToUnifiedFormat(count: args.size(), data&: strings_data); |
157 | |
158 | UnifiedVectorFormat pattern_data; |
159 | patterns.ToUnifiedFormat(count: args.size(), data&: pattern_data); |
160 | |
161 | ListVector::Reserve(vec&: result, STANDARD_VECTOR_SIZE); |
162 | // Reference the 'strings' StringBuffer, because we won't need to allocate new data |
163 | // for the result, all returned strings are substrings of the originals |
164 | output_child.SetAuxiliary(strings.GetAuxiliary()); |
165 | |
166 | // Avoid doing extra work if all the inputs are constant |
167 | idx_t tuple_count = args.AllConstant() ? 1 : args.size(); |
168 | |
169 | unique_ptr<RegexStringPieceArgs> non_const_args; |
170 | unique_ptr<duckdb_re2::RE2> stored_re; |
171 | if (!info.constant_pattern) { |
172 | non_const_args = make_uniq<RegexStringPieceArgs>(); |
173 | } else { |
174 | // Verify that the constant pattern is valid |
175 | auto &re = GetPattern(info, state, pattern_p&: stored_re); |
176 | auto group_count_p = re.NumberOfCapturingGroups(); |
177 | if (group_count_p == -1) { |
178 | throw InvalidInputException("Pattern failed to parse, error: '%s'" , re.error()); |
179 | } |
180 | } |
181 | |
182 | for (idx_t row = 0; row < tuple_count; row++) { |
183 | bool pattern_valid = true; |
184 | if (!info.constant_pattern) { |
185 | // Check if the pattern is NULL or not, |
186 | // and compile the pattern if it's not constant |
187 | auto pattern_idx = pattern_data.sel->get_index(idx: row); |
188 | if (!pattern_data.validity.RowIsValid(row_idx: pattern_idx)) { |
189 | pattern_valid = false; |
190 | } else { |
191 | auto &pattern_p = UnifiedVectorFormat::GetData<string_t>(format: pattern_data)[pattern_idx]; |
192 | auto pattern_strpiece = CreateStringPiece(input: pattern_p); |
193 | stored_re = make_uniq<duckdb_re2::RE2>(args&: pattern_strpiece, args: info.options); |
194 | |
195 | // Increase the size of the args buffer if needed |
196 | auto group_count_p = stored_re->NumberOfCapturingGroups(); |
197 | if (group_count_p == -1) { |
198 | throw InvalidInputException("Pattern failed to parse, error: '%s'" , stored_re->error()); |
199 | } |
200 | non_const_args->SetSize(group_count_p); |
201 | } |
202 | } |
203 | |
204 | auto string_idx = strings_data.sel->get_index(idx: row); |
205 | int32_t group_index; |
206 | if (!pattern_valid || !strings_data.validity.RowIsValid(row_idx: string_idx) || !GetGroupIndex(args, row, result&: group_index)) { |
207 | // If something is NULL, the result is NULL |
208 | // FIXME: do we even need 'SPECIAL_HANDLING'? |
209 | auto result_data = FlatVector::GetData<list_entry_t>(vector&: result); |
210 | auto &result_validity = FlatVector::Validity(vector&: result); |
211 | result_data[row].length = 0; |
212 | result_data[row].offset = ListVector::GetListSize(vector: result); |
213 | result_validity.SetInvalid(row); |
214 | continue; |
215 | } |
216 | |
217 | auto &re = GetPattern(info, state, pattern_p&: stored_re); |
218 | auto &groups = GetGroupsBuffer(info, state, groups_p&: non_const_args); |
219 | auto &string = UnifiedVectorFormat::GetData<string_t>(format: strings_data)[string_idx]; |
220 | ExtractSingleTuple(string, pattern&: re, group: group_index, args&: groups, result, row); |
221 | } |
222 | |
223 | if (args.AllConstant()) { |
224 | result.SetVectorType(VectorType::CONSTANT_VECTOR); |
225 | } |
226 | } |
227 | |
228 | unique_ptr<FunctionData> RegexpExtractAll::(ClientContext &context, ScalarFunction &bound_function, |
229 | vector<unique_ptr<Expression>> &arguments) { |
230 | D_ASSERT(arguments.size() >= 2); |
231 | |
232 | duckdb_re2::RE2::Options options; |
233 | |
234 | string constant_string; |
235 | bool constant_pattern = TryParseConstantPattern(context, expr&: *arguments[1], constant_string); |
236 | |
237 | if (arguments.size() >= 4) { |
238 | ParseRegexOptions(context, expr&: *arguments[3], target&: options); |
239 | } |
240 | return make_uniq<RegexpExtractBindData>(args&: options, args: std::move(constant_string), args&: constant_pattern, args: "" ); |
241 | } |
242 | |
243 | } // namespace duckdb |
244 | |