1 | /* SPDX-License-Identifier: MIT */ |
2 | /* Copyright © 2022 Max Bachmann */ |
3 | |
4 | #pragma once |
5 | |
6 | #include "common.hpp" |
7 | #include "intrinsics.hpp" |
8 | |
9 | namespace duckdb_jaro_winkler { |
10 | namespace detail { |
11 | |
/**
 * @brief match flags for the case where both sequences fit into one 64 bit word
 *
 * Filled by flag_similar_characters_word and consumed by count_common_chars
 * and count_transpositions_word.
 */
struct FlaggedCharsWord {
    /* bits of the pattern P that were paired with a character of T
     * (presumably bit i stands for pattern position i - depends on the
     * pattern match vector layout, TODO confirm) */
    uint64_t P_flag;
    /* bit j is set when the j-th character of the text T found a partner in P */
    uint64_t T_flag;
};
16 | |
/**
 * @brief match flags for sequences which need more than one 64 bit word
 *
 * Filled by flag_similar_characters_block, which resizes both vectors to
 * one word per started block of 64 characters.
 */
struct FlaggedCharsMultiword {
    /* flag words for the pattern P, indexed by pattern word */
    std::vector<uint64_t> P_flag;
    /* flag words for the text T: word j / 64, bit j % 64 belongs to character j */
    std::vector<uint64_t> T_flag;
};
21 | |
/**
 * @brief search window over the pattern words used by the multiword algorithm
 *
 * The window starts at word `empty_words` and spans `words` words. The first
 * word of the window is filtered with `first_mask`, the last one with
 * `last_mask`. Everything starts out zeroed.
 */
struct SearchBoundMask {
    /* number of pattern words covered by the window */
    int64_t words {0};
    /* number of leading pattern words which are skipped entirely */
    int64_t empty_words {0};
    /* bits of the last window word which are inside the window */
    uint64_t last_mask {0};
    /* bits of the first window word which are inside the window */
    uint64_t first_mask {0};
};
28 | |
/**
 * @brief pair of a word index and a position
 *
 * NOTE(review): presumably WordPos is the bit position inside the 64 bit
 * word `Word` - not used in this part of the file, confirm against callers.
 */
struct TextPosition {
    int64_t Word;
    int64_t WordPos;

    TextPosition(int64_t Word_, int64_t WordPos_) : Word {Word_}, WordPos {WordPos_}
    {}
};
35 | |
/**
 * @brief combine the match statistics into the Jaro similarity
 *
 * @param P_len length of the pattern
 * @param T_len length of the text
 * @param CommonChars number of matched characters
 * @param Transpositions number of matched characters which are out of order
 *        (halved in here before it is used)
 *
 * @return (CommonChars / P_len + CommonChars / T_len +
 *         (CommonChars - Transpositions / 2) / CommonChars) / 3,
 *         or 0.0 when there is no common character
 */
static inline double jaro_calculate_similarity(int64_t P_len, int64_t T_len, int64_t CommonChars,
                                               int64_t Transpositions)
{
    /* without a common character every term is zero. Return early, since
     * the last term would otherwise be 0 / 0 and turn the result into NaN */
    if (CommonChars == 0) {
        return 0.0;
    }

    Transpositions /= 2;
    double Sim = 0;
    Sim += static_cast<double>(CommonChars) / static_cast<double>(P_len);
    Sim += static_cast<double>(CommonChars) / static_cast<double>(T_len);
    Sim += (static_cast<double>(CommonChars) - static_cast<double>(Transpositions)) /
           static_cast<double>(CommonChars);
    return Sim / 3.0;
}
46 | |
/**
 * @brief check based on the lengths alone whether score_cutoff is reachable
 *
 * Uses the best possible outcome: every character of the shorter sequence
 * matches and there are no transpositions.
 *
 * @return false when either sequence is empty or when even this upper
 *         bound stays below score_cutoff
 */
static inline bool jaro_length_filter(int64_t P_len, int64_t T_len, double score_cutoff)
{
    if (P_len == 0 || T_len == 0) {
        return false;
    }

    const double shorter = static_cast<double>(std::min(P_len, T_len));
    const double upper_bound =
        (shorter / static_cast<double>(P_len) + shorter / static_cast<double>(T_len) + 1.0) / 3.0;
    return upper_bound >= score_cutoff;
}
59 | |
/**
 * @brief check based on the lengths and the number of common characters
 *        whether score_cutoff is reachable
 *
 * Uses the best possible outcome for the given number of common
 * characters: no transpositions.
 *
 * @return false when there are no common characters or when even this
 *         upper bound stays below score_cutoff
 */
static inline bool jaro_common_char_filter(int64_t P_len, int64_t T_len, int64_t CommonChars,
                                           double score_cutoff)
{
    if (CommonChars == 0) {
        return false;
    }

    const double common = static_cast<double>(CommonChars);
    const double upper_bound =
        (common / static_cast<double>(P_len) + common / static_cast<double>(T_len) + 1.0) / 3.0;
    return upper_bound >= score_cutoff;
}
75 | |
76 | static inline int64_t count_common_chars(const FlaggedCharsWord& flagged) |
77 | { |
78 | return intrinsics::popcount(x: flagged.P_flag); |
79 | } |
80 | |
81 | static inline int64_t count_common_chars(const FlaggedCharsMultiword& flagged) |
82 | { |
83 | int64_t CommonChars = 0; |
84 | if (flagged.P_flag.size() < flagged.T_flag.size()) { |
85 | for (uint64_t flag : flagged.P_flag) { |
86 | CommonChars += intrinsics::popcount(x: flag); |
87 | } |
88 | } |
89 | else { |
90 | for (uint64_t flag : flagged.T_flag) { |
91 | CommonChars += intrinsics::popcount(x: flag); |
92 | } |
93 | } |
94 | return CommonChars; |
95 | } |
96 | |
97 | template <typename PM_Vec, typename InputIt1, typename InputIt2> |
98 | static inline FlaggedCharsWord |
99 | flag_similar_characters_word(const PM_Vec& PM, InputIt1 P_first, |
100 | InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, int Bound) |
101 | { |
102 | using namespace intrinsics; |
103 | int64_t P_len = std::distance(P_first, P_last); |
104 | (void)P_len; |
105 | int64_t T_len = std::distance(T_first, T_last); |
106 | assert(P_len <= 64); |
107 | assert(T_len <= 64); |
108 | assert(Bound > P_len || P_len - Bound <= T_len); |
109 | |
110 | FlaggedCharsWord flagged = {.P_flag: 0, .T_flag: 0}; |
111 | |
112 | uint64_t BoundMask = bit_mask_lsb<uint64_t>(n: Bound + 1); |
113 | |
114 | int64_t j = 0; |
115 | for (; j < std::min(static_cast<int64_t>(Bound), T_len); ++j) { |
116 | uint64_t PM_j = PM.get(T_first[j]) & BoundMask & (~flagged.P_flag); |
117 | |
118 | flagged.P_flag |= blsi(a: PM_j); |
119 | flagged.T_flag |= static_cast<uint64_t>(PM_j != 0) << j; |
120 | |
121 | BoundMask = (BoundMask << 1) | 1; |
122 | } |
123 | |
124 | for (; j < T_len; ++j) { |
125 | uint64_t PM_j = PM.get(T_first[j]) & BoundMask & (~flagged.P_flag); |
126 | |
127 | flagged.P_flag |= blsi(a: PM_j); |
128 | flagged.T_flag |= static_cast<uint64_t>(PM_j != 0) << j; |
129 | |
130 | BoundMask <<= 1; |
131 | } |
132 | |
133 | return flagged; |
134 | } |
135 | |
136 | template <typename CharT> |
137 | static inline void flag_similar_characters_step(const common::BlockPatternMatchVector& PM, |
138 | CharT T_j, FlaggedCharsMultiword& flagged, |
139 | int64_t j, SearchBoundMask BoundMask) |
140 | { |
141 | using namespace intrinsics; |
142 | |
143 | int64_t j_word = j / 64; |
144 | int64_t j_pos = j % 64; |
145 | int64_t word = BoundMask.empty_words; |
146 | int64_t last_word = word + BoundMask.words; |
147 | |
148 | if (BoundMask.words == 1) { |
149 | uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & BoundMask.first_mask & |
150 | (~flagged.P_flag[word]); |
151 | |
152 | flagged.P_flag[word] |= blsi(a: PM_j); |
153 | flagged.T_flag[j_word] |= static_cast<uint64_t>(PM_j != 0) << j_pos; |
154 | return; |
155 | } |
156 | |
157 | if (BoundMask.first_mask) { |
158 | uint64_t PM_j = PM.get(word, T_j) & BoundMask.first_mask & (~flagged.P_flag[word]); |
159 | |
160 | if (PM_j) { |
161 | flagged.P_flag[word] |= blsi(a: PM_j); |
162 | flagged.T_flag[j_word] |= 1ull << j_pos; |
163 | return; |
164 | } |
165 | word++; |
166 | } |
167 | |
168 | for (; word < last_word - 1; ++word) { |
169 | uint64_t PM_j = PM.get(word, T_j) & (~flagged.P_flag[word]); |
170 | |
171 | if (PM_j) { |
172 | flagged.P_flag[word] |= blsi(a: PM_j); |
173 | flagged.T_flag[j_word] |= 1ull << j_pos; |
174 | return; |
175 | } |
176 | } |
177 | |
178 | if (BoundMask.last_mask) { |
179 | uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & (~flagged.P_flag[word]); |
180 | |
181 | flagged.P_flag[word] |= blsi(a: PM_j); |
182 | flagged.T_flag[j_word] |= static_cast<uint64_t>(PM_j != 0) << j_pos; |
183 | } |
184 | } |
185 | |
186 | template <typename InputIt1, typename InputIt2> |
187 | static inline FlaggedCharsMultiword |
188 | flag_similar_characters_block(const common::BlockPatternMatchVector& PM, InputIt1 P_first, |
189 | InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, int64_t Bound) |
190 | { |
191 | using namespace intrinsics; |
192 | int64_t P_len = std::distance(P_first, P_last); |
193 | int64_t T_len = std::distance(T_first, T_last); |
194 | assert(P_len > 64 || T_len > 64); |
195 | assert(Bound > P_len || P_len - Bound <= T_len); |
196 | assert(Bound >= 31); |
197 | |
198 | int64_t TextWords = common::ceildiv(a: T_len, divisor: 64); |
199 | int64_t PatternWords = common::ceildiv(a: P_len, divisor: 64); |
200 | |
201 | FlaggedCharsMultiword flagged; |
202 | flagged.T_flag.resize(new_size: TextWords); |
203 | flagged.P_flag.resize(new_size: PatternWords); |
204 | |
205 | SearchBoundMask BoundMask; |
206 | int64_t start_range = std::min(Bound + 1, P_len); |
207 | BoundMask.words = 1 + start_range / 64; |
208 | BoundMask.empty_words = 0; |
209 | BoundMask.last_mask = (1ull << (start_range % 64)) - 1; |
210 | BoundMask.first_mask = ~UINT64_C(0); |
211 | |
212 | for (int64_t j = 0; j < T_len; ++j) { |
213 | flag_similar_characters_step(PM, T_first[j], flagged, j, BoundMask); |
214 | |
215 | if (j + Bound + 1 < P_len) { |
216 | BoundMask.last_mask = (BoundMask.last_mask << 1) | 1; |
217 | if (j + Bound + 2 < P_len && BoundMask.last_mask == ~UINT64_C(0)) { |
218 | BoundMask.last_mask = 0; |
219 | BoundMask.words++; |
220 | } |
221 | } |
222 | |
223 | if (j >= Bound) { |
224 | BoundMask.first_mask <<= 1; |
225 | if (BoundMask.first_mask == 0) { |
226 | BoundMask.first_mask = ~UINT64_C(0); |
227 | BoundMask.words--; |
228 | BoundMask.empty_words++; |
229 | } |
230 | } |
231 | } |
232 | |
233 | return flagged; |
234 | } |
235 | |
236 | template <typename PM_Vec, typename InputIt1> |
237 | static inline int64_t count_transpositions_word(const PM_Vec& PM, |
238 | InputIt1 T_first, InputIt1, |
239 | const FlaggedCharsWord& flagged) |
240 | { |
241 | using namespace intrinsics; |
242 | uint64_t P_flag = flagged.P_flag; |
243 | uint64_t T_flag = flagged.T_flag; |
244 | int64_t Transpositions = 0; |
245 | while (T_flag) { |
246 | uint64_t PatternFlagMask = blsi(a: P_flag); |
247 | |
248 | Transpositions += !(PM.get(T_first[tzcnt(x: T_flag)]) & PatternFlagMask); |
249 | |
250 | T_flag = blsr(x: T_flag); |
251 | P_flag ^= PatternFlagMask; |
252 | } |
253 | |
254 | return Transpositions; |
255 | } |
256 | |
257 | template <typename InputIt1> |
258 | static inline int64_t |
259 | count_transpositions_block(const common::BlockPatternMatchVector& PM, InputIt1 T_first, InputIt1, |
260 | const FlaggedCharsMultiword& flagged, int64_t FlaggedChars) |
261 | { |
262 | using namespace intrinsics; |
263 | int64_t TextWord = 0; |
264 | int64_t PatternWord = 0; |
265 | uint64_t T_flag = flagged.T_flag[TextWord]; |
266 | uint64_t P_flag = flagged.P_flag[PatternWord]; |
267 | |
268 | int64_t Transpositions = 0; |
269 | while (FlaggedChars) { |
270 | while (!T_flag) { |
271 | TextWord++; |
272 | T_first += 64; |
273 | T_flag = flagged.T_flag[TextWord]; |
274 | } |
275 | |
276 | while (T_flag) { |
277 | while (!P_flag) { |
278 | PatternWord++; |
279 | P_flag = flagged.P_flag[PatternWord]; |
280 | } |
281 | |
282 | uint64_t PatternFlagMask = blsi(a: P_flag); |
283 | |
284 | Transpositions += !(PM.get(PatternWord, T_first[tzcnt(x: T_flag)]) & PatternFlagMask); |
285 | |
286 | T_flag = blsr(x: T_flag); |
287 | P_flag ^= PatternFlagMask; |
288 | |
289 | FlaggedChars--; |
290 | } |
291 | } |
292 | |
293 | return Transpositions; |
294 | } |
295 | |
/**
 * @brief find bounds and skip out of bound parts of the sequences
 *
 * @param P_first start of the pattern
 * @param P_last end of the pattern, moved forward when the tail of P can
 *        never be inside the search window
 * @param T_first start of the text
 * @param T_last end of the text, moved forward when the tail of T can
 *        never be inside the search window
 *
 * @return size of the search window to either side:
 *         max(P_len, T_len) / 2 - 1, but never less than 0
 */
template <typename InputIt1, typename InputIt2>
int64_t jaro_bounds(InputIt1 P_first, InputIt1& P_last, InputIt2 T_first, InputIt2& T_last)
{
    int64_t P_len = std::distance(P_first, P_last);
    int64_t T_len = std::distance(T_first, T_last);

    /* for sequences of at most one character len / 2 - 1 is -1. Clamp it,
     * since a negative bound would move the end iterator in front of the
     * start iterator below
     */
    int64_t Bound = std::max<int64_t>(std::max(P_len, T_len) / 2 - 1, 0);

    /* since jaro uses a sliding window some parts of T/P might never be in
     * range and can be removed ahead of time
     */
    if (T_len > P_len) {
        if (T_len > P_len + Bound) {
            T_last = T_first + (P_len + Bound);
        }
    }
    else {
        if (P_len > T_len + Bound) {
            P_last = P_first + (T_len + Bound);
        }
    }
    return Bound;
}
324 | |
325 | template <typename InputIt1, typename InputIt2> |
326 | double jaro_similarity(InputIt1 P_first, InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, |
327 | double score_cutoff) |
328 | { |
329 | int64_t P_len = std::distance(P_first, P_last); |
330 | int64_t T_len = std::distance(T_first, T_last); |
331 | |
332 | /* filter out based on the length difference between the two strings */ |
333 | if (!jaro_length_filter(P_len, T_len, score_cutoff)) { |
334 | return 0.0; |
335 | } |
336 | |
337 | if (P_len == 1 && T_len == 1) { |
338 | return static_cast<double>(P_first[0] == T_first[0]); |
339 | } |
340 | |
341 | int64_t Bound = jaro_bounds(P_first, P_last, T_first, T_last); |
342 | |
343 | /* common prefix never includes Transpositions */ |
344 | int64_t CommonChars = common::remove_common_prefix(P_first, P_last, T_first, T_last); |
345 | int64_t Transpositions = 0; |
346 | int64_t P_view_len = std::distance(P_first, P_last); |
347 | int64_t T_view_len = std::distance(T_first, T_last); |
348 | |
349 | if (!P_view_len || !T_view_len) { |
350 | /* already has correct number of common chars and transpositions */ |
351 | } |
352 | else if (P_view_len <= 64 && T_view_len <= 64) { |
353 | common::PatternMatchVector PM(P_first, P_last); |
354 | auto flagged = flag_similar_characters_word(PM, P_first, P_last, T_first, T_last, static_cast<int>(Bound)); |
355 | CommonChars += count_common_chars(flagged); |
356 | |
357 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
358 | return 0.0; |
359 | } |
360 | |
361 | Transpositions = count_transpositions_word(PM, T_first, T_last, flagged); |
362 | } |
363 | else { |
364 | common::BlockPatternMatchVector PM(P_first, P_last); |
365 | auto flagged = flag_similar_characters_block(PM, P_first, P_last, T_first, T_last, Bound); |
366 | int64_t FlaggedChars = count_common_chars(flagged); |
367 | CommonChars += FlaggedChars; |
368 | |
369 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
370 | return 0.0; |
371 | } |
372 | |
373 | Transpositions = count_transpositions_block(PM, T_first, T_last, flagged, FlaggedChars); |
374 | } |
375 | |
376 | double Sim = jaro_calculate_similarity(P_len, T_len, CommonChars, Transpositions); |
377 | return common::result_cutoff(result: Sim, score_cutoff); |
378 | } |
379 | |
380 | template <typename InputIt1, typename InputIt2> |
381 | double jaro_similarity(const common::BlockPatternMatchVector& PM, InputIt1 P_first, InputIt1 P_last, |
382 | InputIt2 T_first, InputIt2 T_last, double score_cutoff) |
383 | { |
384 | int64_t P_len = std::distance(P_first, P_last); |
385 | int64_t T_len = std::distance(T_first, T_last); |
386 | |
387 | /* filter out based on the length difference between the two strings */ |
388 | if (!jaro_length_filter(P_len, T_len, score_cutoff)) { |
389 | return 0.0; |
390 | } |
391 | |
392 | if (P_len == 1 && T_len == 1) { |
393 | return static_cast<double>(P_first[0] == T_first[0]); |
394 | } |
395 | |
396 | int64_t Bound = jaro_bounds(P_first, P_last, T_first, T_last); |
397 | |
398 | /* common prefix never includes Transpositions */ |
399 | int64_t CommonChars = 0; |
400 | int64_t Transpositions = 0; |
401 | int64_t P_view_len = std::distance(P_first, P_last); |
402 | int64_t T_view_len = std::distance(T_first, T_last); |
403 | |
404 | if (!P_view_len || !T_view_len) { |
405 | /* already has correct number of common chars and transpositions */ |
406 | } |
407 | else if (P_view_len <= 64 && T_view_len <= 64) { |
408 | auto flagged = flag_similar_characters_word(PM, P_first, P_last, T_first, T_last, static_cast<int>(Bound)); |
409 | CommonChars += count_common_chars(flagged); |
410 | |
411 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
412 | return 0.0; |
413 | } |
414 | |
415 | Transpositions = count_transpositions_word(PM, T_first, T_last, flagged); |
416 | } |
417 | else { |
418 | auto flagged = flag_similar_characters_block(PM, P_first, P_last, T_first, T_last, Bound); |
419 | int64_t FlaggedChars = count_common_chars(flagged); |
420 | CommonChars += FlaggedChars; |
421 | |
422 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
423 | return 0.0; |
424 | } |
425 | |
426 | Transpositions = count_transpositions_block(PM, T_first, T_last, flagged, FlaggedChars); |
427 | } |
428 | |
429 | double Sim = jaro_calculate_similarity(P_len, T_len, CommonChars, Transpositions); |
430 | return common::result_cutoff(result: Sim, score_cutoff); |
431 | } |
432 | |
433 | template <typename InputIt1, typename InputIt2> |
434 | double jaro_winkler_similarity(InputIt1 P_first, InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, |
435 | double prefix_weight, double score_cutoff) |
436 | { |
437 | int64_t P_len = std::distance(P_first, P_last); |
438 | int64_t T_len = std::distance(T_first, T_last); |
439 | int64_t min_len = std::min(P_len, T_len); |
440 | int64_t prefix = 0; |
441 | int64_t max_prefix = std::min<int64_t>(min_len, 4); |
442 | |
443 | for (; prefix < max_prefix; ++prefix) { |
444 | if (T_first[prefix] != P_first[prefix]) { |
445 | break; |
446 | } |
447 | } |
448 | |
449 | double jaro_score_cutoff = score_cutoff; |
450 | if (jaro_score_cutoff > 0.7) { |
451 | double prefix_sim = prefix * prefix_weight; |
452 | |
453 | if (prefix_sim >= 1.0) { |
454 | jaro_score_cutoff = 0.7; |
455 | } |
456 | else { |
457 | jaro_score_cutoff = |
458 | std::max(0.7, (prefix_sim - jaro_score_cutoff) / (prefix_sim - 1.0)); |
459 | } |
460 | } |
461 | |
462 | double Sim = jaro_similarity(P_first, P_last, T_first, T_last, jaro_score_cutoff); |
463 | if (Sim > 0.7) { |
464 | Sim += prefix * prefix_weight * (1.0 - Sim); |
465 | } |
466 | |
467 | return common::result_cutoff(result: Sim, score_cutoff); |
468 | } |
469 | |
470 | template <typename InputIt1, typename InputIt2> |
471 | double jaro_winkler_similarity(const common::BlockPatternMatchVector& PM, InputIt1 P_first, |
472 | InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, |
473 | double prefix_weight, double score_cutoff) |
474 | { |
475 | int64_t P_len = std::distance(P_first, P_last); |
476 | int64_t T_len = std::distance(T_first, T_last); |
477 | int64_t min_len = std::min(P_len, T_len); |
478 | int64_t prefix = 0; |
479 | int64_t max_prefix = std::min<int64_t>(min_len, 4); |
480 | |
481 | for (; prefix < max_prefix; ++prefix) { |
482 | if (T_first[prefix] != P_first[prefix]) { |
483 | break; |
484 | } |
485 | } |
486 | |
487 | double jaro_score_cutoff = score_cutoff; |
488 | if (jaro_score_cutoff > 0.7) { |
489 | double prefix_sim = prefix * prefix_weight; |
490 | |
491 | if (prefix_sim >= 1.0) { |
492 | jaro_score_cutoff = 0.7; |
493 | } |
494 | else { |
495 | jaro_score_cutoff = |
496 | std::max(0.7, (prefix_sim - jaro_score_cutoff) / (prefix_sim - 1.0)); |
497 | } |
498 | } |
499 | |
500 | double Sim = jaro_similarity(PM, P_first, P_last, T_first, T_last, jaro_score_cutoff); |
501 | if (Sim > 0.7) { |
502 | Sim += prefix * prefix_weight * (1.0 - Sim); |
503 | } |
504 | |
505 | return common::result_cutoff(result: Sim, score_cutoff); |
506 | } |
507 | |
508 | } // namespace detail |
509 | } // namespace duckdb_jaro_winkler |
510 | |