1/* SPDX-License-Identifier: MIT */
2/* Copyright © 2022 Max Bachmann */
3
4#pragma once
5
6#include "common.hpp"
7#include "intrinsics.hpp"
8
9namespace duckdb_jaro_winkler {
10namespace detail {
11
/**
 * @brief Match flags for a pattern/text pair where each side fits in one 64-bit word.
 *
 * Bit i of P_flag marks pattern position i as matched; bit j of T_flag marks
 * text position j as matched.
 */
struct FlaggedCharsWord {
    uint64_t P_flag; ///< matched pattern positions
    uint64_t T_flag; ///< matched text positions
};
16
/**
 * @brief Match flags for inputs longer than 64 characters.
 *
 * Same meaning as FlaggedCharsWord, but the bits are spread over consecutive
 * 64-bit words: position k lives in word k / 64, bit k % 64.
 */
struct FlaggedCharsMultiword {
    std::vector<uint64_t> P_flag; ///< matched pattern positions, one entry per 64 characters
    std::vector<uint64_t> T_flag; ///< matched text positions, one entry per 64 characters
};
21
/**
 * @brief Search window over a multi-word pattern, as used by flag_similar_characters_step.
 *
 * The window covers `words` pattern words starting at word index `empty_words`.
 * first_mask restricts the first of those words and last_mask the last one.
 * Every field starts out as zero.
 */
struct SearchBoundMask {
    int64_t words{0};        ///< number of pattern words touched by the window
    int64_t empty_words{0};  ///< leading pattern words already left behind by the window
    uint64_t last_mask{0};   ///< usable bits of the last window word
    uint64_t first_mask{0};  ///< usable bits of the first window word
};
28
/**
 * @brief Pair of a word index (Word) and an offset inside that word (WordPos).
 *
 * NOTE(review): not referenced in this part of the file; presumably the words are the
 * 64-bit flag words used above. Confirm against other users.
 */
struct TextPosition {
    TextPosition(int64_t Word_, int64_t WordPos_)
        : Word{Word_}, WordPos{WordPos_}
    {}

    int64_t Word;
    int64_t WordPos;
};
35
/**
 * @brief Turn match statistics into the Jaro similarity.
 *
 * sim = (m / |P| + m / |T| + (m - t) / m) / 3
 * Here m is the number of common characters and t is half the number of
 * matched characters that appear out of order.
 *
 * @param P_len          length of the pattern
 * @param T_len          length of the text
 * @param CommonChars    number of common characters m
 * @param Transpositions number of out-of-order matched characters (halved here to get t)
 * @return the Jaro similarity, or 0.0 when CommonChars is 0
 */
static inline double jaro_calculate_similarity(int64_t P_len, int64_t T_len, int64_t CommonChars,
                                               int64_t Transpositions)
{
    /* Jaro similarity is defined as 0 when nothing matches; without this guard
     * the last term would evaluate 0 / 0 and yield NaN. */
    if (CommonChars == 0) {
        return 0.0;
    }

    Transpositions /= 2;
    const double m = static_cast<double>(CommonChars);
    double Sim = 0;
    Sim += m / static_cast<double>(P_len);
    Sim += m / static_cast<double>(T_len);
    Sim += (m - static_cast<double>(Transpositions)) / m;
    return Sim / 3.0;
}
46
/**
 * @brief Reject pairs whose lengths alone rule out reaching score_cutoff.
 *
 * At most min(|P|, |T|) characters can be common, and in the best case none
 * are transposed. The similarity is therefore bounded by
 * (min / |P| + min / |T| + 1) / 3.
 *
 * @return false when either string is empty or the bound is below score_cutoff
 */
static inline bool jaro_length_filter(int64_t P_len, int64_t T_len, double score_cutoff)
{
    if (P_len == 0 || T_len == 0) {
        return false;
    }

    const double best_common = static_cast<double>(P_len < T_len ? P_len : T_len);
    const double upper_bound =
        (best_common / static_cast<double>(P_len) + best_common / static_cast<double>(T_len) + 1.0) / 3.0;
    return upper_bound >= score_cutoff;
}
59
/**
 * @brief Reject pairs whose common-character count rules out reaching score_cutoff.
 *
 * This is evaluated before transpositions are counted. It assumes the best case
 * of zero transpositions, which bounds the similarity by
 * (m / |P| + m / |T| + 1) / 3.
 *
 * @return false when CommonChars is 0 or the bound is below score_cutoff
 */
static inline bool jaro_common_char_filter(int64_t P_len, int64_t T_len, int64_t CommonChars,
                                           double score_cutoff)
{
    if (CommonChars == 0) {
        return false;
    }

    const double common = static_cast<double>(CommonChars);
    const double upper_bound =
        (common / static_cast<double>(P_len) + common / static_cast<double>(T_len) + 1.0) / 3.0;
    return upper_bound >= score_cutoff;
}
75
76static inline int64_t count_common_chars(const FlaggedCharsWord& flagged)
77{
78 return intrinsics::popcount(x: flagged.P_flag);
79}
80
81static inline int64_t count_common_chars(const FlaggedCharsMultiword& flagged)
82{
83 int64_t CommonChars = 0;
84 if (flagged.P_flag.size() < flagged.T_flag.size()) {
85 for (uint64_t flag : flagged.P_flag) {
86 CommonChars += intrinsics::popcount(x: flag);
87 }
88 }
89 else {
90 for (uint64_t flag : flagged.T_flag) {
91 CommonChars += intrinsics::popcount(x: flag);
92 }
93 }
94 return CommonChars;
95}
96
97template <typename PM_Vec, typename InputIt1, typename InputIt2>
98static inline FlaggedCharsWord
99flag_similar_characters_word(const PM_Vec& PM, InputIt1 P_first,
100 InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, int Bound)
101{
102 using namespace intrinsics;
103 int64_t P_len = std::distance(P_first, P_last);
104 (void)P_len;
105 int64_t T_len = std::distance(T_first, T_last);
106 assert(P_len <= 64);
107 assert(T_len <= 64);
108 assert(Bound > P_len || P_len - Bound <= T_len);
109
110 FlaggedCharsWord flagged = {.P_flag: 0, .T_flag: 0};
111
112 uint64_t BoundMask = bit_mask_lsb<uint64_t>(n: Bound + 1);
113
114 int64_t j = 0;
115 for (; j < std::min(static_cast<int64_t>(Bound), T_len); ++j) {
116 uint64_t PM_j = PM.get(T_first[j]) & BoundMask & (~flagged.P_flag);
117
118 flagged.P_flag |= blsi(a: PM_j);
119 flagged.T_flag |= static_cast<uint64_t>(PM_j != 0) << j;
120
121 BoundMask = (BoundMask << 1) | 1;
122 }
123
124 for (; j < T_len; ++j) {
125 uint64_t PM_j = PM.get(T_first[j]) & BoundMask & (~flagged.P_flag);
126
127 flagged.P_flag |= blsi(a: PM_j);
128 flagged.T_flag |= static_cast<uint64_t>(PM_j != 0) << j;
129
130 BoundMask <<= 1;
131 }
132
133 return flagged;
134}
135
/**
 * @brief Match one text character T_j (text position j) against a multi-word pattern.
 *
 * The search window spans BoundMask.words pattern words, starting at word index
 * BoundMask.empty_words. first_mask limits the first word of the window,
 * last_mask limits the last word, and the words in between are searched in full.
 * Words are scanned from low to high. In the first word with an unmatched
 * candidate, the lowest candidate bit (blsi) is set in flagged.P_flag. Bit j of
 * flagged.T_flag is set only when such a match was found.
 *
 * @param PM        block pattern match vector; PM.get(word, ch) is presumably the bitmask
 *                  of positions in pattern word `word` that hold ch (TODO confirm in common.hpp)
 * @param T_j       text character at position j
 * @param flagged   match flags, updated in place
 * @param j         position of T_j in the text
 * @param BoundMask current search window (taken by value; the caller's copy is not changed)
 */
template <typename CharT>
static inline void flag_similar_characters_step(const common::BlockPatternMatchVector& PM,
                                                CharT T_j, FlaggedCharsMultiword& flagged,
                                                int64_t j, SearchBoundMask BoundMask)
{
    using namespace intrinsics;

    /* location of bit j inside flagged.T_flag */
    int64_t j_word = j / 64;
    int64_t j_pos = j % 64;
    /* the window covers pattern words [word, last_word) */
    int64_t word = BoundMask.empty_words;
    int64_t last_word = word + BoundMask.words;

    /* window lies inside a single word: both edge masks apply to it */
    if (BoundMask.words == 1) {
        uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & BoundMask.first_mask &
                        (~flagged.P_flag[word]);

        flagged.P_flag[word] |= blsi(PM_j);
        flagged.T_flag[j_word] |= static_cast<uint64_t>(PM_j != 0) << j_pos;
        return;
    }

    /* first window word, restricted by first_mask; stop at the first match */
    if (BoundMask.first_mask) {
        uint64_t PM_j = PM.get(word, T_j) & BoundMask.first_mask & (~flagged.P_flag[word]);

        if (PM_j) {
            flagged.P_flag[word] |= blsi(PM_j);
            flagged.T_flag[j_word] |= 1ull << j_pos;
            return;
        }
        word++;
    }

    /* fully covered middle words; the final word is handled separately below */
    for (; word < last_word - 1; ++word) {
        uint64_t PM_j = PM.get(word, T_j) & (~flagged.P_flag[word]);

        if (PM_j) {
            flagged.P_flag[word] |= blsi(PM_j);
            flagged.T_flag[j_word] |= 1ull << j_pos;
            return;
        }
    }

    /* last window word, restricted by last_mask; flag_similar_characters_block resets
     * last_mask to 0 when it opens a new word, in which case nothing is searched here */
    if (BoundMask.last_mask) {
        uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & (~flagged.P_flag[word]);

        flagged.P_flag[word] |= blsi(PM_j);
        flagged.T_flag[j_word] |= static_cast<uint64_t>(PM_j != 0) << j_pos;
    }
}
185
/**
 * @brief Greedy Jaro matching when the pattern or the text is longer than 64 characters.
 *
 * This applies the same matching rule as flag_similar_characters_word, but stores
 * the flags as vectors of 64-bit words. The window for each text position is
 * tracked in a SearchBoundMask. After each call to flag_similar_characters_step,
 * that mask is advanced by one position.
 *
 * @param PM    block pattern match vector for the pattern
 * @param Bound match window radius (asserted to be >= 31)
 * @return per-word flags of the matched pattern and text positions
 */
template <typename InputIt1, typename InputIt2>
static inline FlaggedCharsMultiword
flag_similar_characters_block(const common::BlockPatternMatchVector& PM, InputIt1 P_first,
                              InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, int64_t Bound)
{
    using namespace intrinsics;
    int64_t P_len = std::distance(P_first, P_last);
    int64_t T_len = std::distance(T_first, T_last);
    assert(P_len > 64 || T_len > 64);
    assert(Bound > P_len || P_len - Bound <= T_len);
    assert(Bound >= 31);

    int64_t TextWords = common::ceildiv(T_len, 64);
    int64_t PatternWords = common::ceildiv(P_len, 64);

    FlaggedCharsMultiword flagged;
    flagged.T_flag.resize(TextWords);
    flagged.P_flag.resize(PatternWords);

    /* window for j = 0: pattern positions [0, start_range), i.e. [0, Bound] clamped to P_len.
     * That is start_range / 64 full words plus a partial last word holding
     * start_range % 64 low bits (last_mask may be 0 when start_range is a multiple of 64). */
    SearchBoundMask BoundMask;
    int64_t start_range = std::min(Bound + 1, P_len);
    BoundMask.words = 1 + start_range / 64;
    BoundMask.empty_words = 0;
    BoundMask.last_mask = (1ull << (start_range % 64)) - 1;
    BoundMask.first_mask = ~UINT64_C(0);

    for (int64_t j = 0; j < T_len; ++j) {
        flag_similar_characters_step(PM, T_first[j], flagged, j, BoundMask);

        /* upper edge: position j + Bound + 1 joins the window for j + 1 while it is inside P */
        if (j + Bound + 1 < P_len) {
            BoundMask.last_mask = (BoundMask.last_mask << 1) | 1;
            /* last word is now full and the window keeps growing: open a new, empty last word */
            if (j + Bound + 2 < P_len && BoundMask.last_mask == ~UINT64_C(0)) {
                BoundMask.last_mask = 0;
                BoundMask.words++;
            }
        }

        /* lower edge: once j >= Bound, position j - Bound drops out of the window for j + 1 */
        if (j >= Bound) {
            BoundMask.first_mask <<= 1;
            /* first word fully consumed: drop it and let the window start at the next word */
            if (BoundMask.first_mask == 0) {
                BoundMask.first_mask = ~UINT64_C(0);
                BoundMask.words--;
                BoundMask.empty_words++;
            }
        }
    }

    return flagged;
}
235
236template <typename PM_Vec, typename InputIt1>
237static inline int64_t count_transpositions_word(const PM_Vec& PM,
238 InputIt1 T_first, InputIt1,
239 const FlaggedCharsWord& flagged)
240{
241 using namespace intrinsics;
242 uint64_t P_flag = flagged.P_flag;
243 uint64_t T_flag = flagged.T_flag;
244 int64_t Transpositions = 0;
245 while (T_flag) {
246 uint64_t PatternFlagMask = blsi(a: P_flag);
247
248 Transpositions += !(PM.get(T_first[tzcnt(x: T_flag)]) & PatternFlagMask);
249
250 T_flag = blsr(x: T_flag);
251 P_flag ^= PatternFlagMask;
252 }
253
254 return Transpositions;
255}
256
257template <typename InputIt1>
258static inline int64_t
259count_transpositions_block(const common::BlockPatternMatchVector& PM, InputIt1 T_first, InputIt1,
260 const FlaggedCharsMultiword& flagged, int64_t FlaggedChars)
261{
262 using namespace intrinsics;
263 int64_t TextWord = 0;
264 int64_t PatternWord = 0;
265 uint64_t T_flag = flagged.T_flag[TextWord];
266 uint64_t P_flag = flagged.P_flag[PatternWord];
267
268 int64_t Transpositions = 0;
269 while (FlaggedChars) {
270 while (!T_flag) {
271 TextWord++;
272 T_first += 64;
273 T_flag = flagged.T_flag[TextWord];
274 }
275
276 while (T_flag) {
277 while (!P_flag) {
278 PatternWord++;
279 P_flag = flagged.P_flag[PatternWord];
280 }
281
282 uint64_t PatternFlagMask = blsi(a: P_flag);
283
284 Transpositions += !(PM.get(PatternWord, T_first[tzcnt(x: T_flag)]) & PatternFlagMask);
285
286 T_flag = blsr(x: T_flag);
287 P_flag ^= PatternFlagMask;
288
289 FlaggedChars--;
290 }
291 }
292
293 return Transpositions;
294}
295
/**
 * @brief Compute the Jaro match window and trim what can never fall inside it.
 *
 * The window radius is max(|P|, |T|) / 2 - 1. Any element of the longer sequence
 * at index >= len(shorter) + Bound is more than Bound away from every index of
 * the shorter sequence. The end iterator of the longer sequence is moved back to
 * exclude those elements. When the lengths are equal, P is treated as the longer one.
 *
 * @param P_last end of the pattern; may be moved back
 * @param T_last end of the text; may be moved back
 * @return the window radius Bound (negative when the longer input has at most one element)
 */
template <typename InputIt1, typename InputIt2>
int64_t jaro_bounds(InputIt1 P_first, InputIt1& P_last, InputIt2 T_first, InputIt2& T_last)
{
    const int64_t P_len = std::distance(P_first, P_last);
    const int64_t T_len = std::distance(T_first, T_last);

    if (T_len > P_len) {
        const int64_t bound = T_len / 2 - 1;
        const int64_t reachable = P_len + bound;
        if (T_len > reachable) {
            T_last = T_first + reachable;
        }
        return bound;
    }

    const int64_t bound = P_len / 2 - 1;
    const int64_t reachable = T_len + bound;
    if (P_len > reachable) {
        P_last = P_first + reachable;
    }
    return bound;
}
324
325template <typename InputIt1, typename InputIt2>
326double jaro_similarity(InputIt1 P_first, InputIt1 P_last, InputIt2 T_first, InputIt2 T_last,
327 double score_cutoff)
328{
329 int64_t P_len = std::distance(P_first, P_last);
330 int64_t T_len = std::distance(T_first, T_last);
331
332 /* filter out based on the length difference between the two strings */
333 if (!jaro_length_filter(P_len, T_len, score_cutoff)) {
334 return 0.0;
335 }
336
337 if (P_len == 1 && T_len == 1) {
338 return static_cast<double>(P_first[0] == T_first[0]);
339 }
340
341 int64_t Bound = jaro_bounds(P_first, P_last, T_first, T_last);
342
343 /* common prefix never includes Transpositions */
344 int64_t CommonChars = common::remove_common_prefix(P_first, P_last, T_first, T_last);
345 int64_t Transpositions = 0;
346 int64_t P_view_len = std::distance(P_first, P_last);
347 int64_t T_view_len = std::distance(T_first, T_last);
348
349 if (!P_view_len || !T_view_len) {
350 /* already has correct number of common chars and transpositions */
351 }
352 else if (P_view_len <= 64 && T_view_len <= 64) {
353 common::PatternMatchVector PM(P_first, P_last);
354 auto flagged = flag_similar_characters_word(PM, P_first, P_last, T_first, T_last, static_cast<int>(Bound));
355 CommonChars += count_common_chars(flagged);
356
357 if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) {
358 return 0.0;
359 }
360
361 Transpositions = count_transpositions_word(PM, T_first, T_last, flagged);
362 }
363 else {
364 common::BlockPatternMatchVector PM(P_first, P_last);
365 auto flagged = flag_similar_characters_block(PM, P_first, P_last, T_first, T_last, Bound);
366 int64_t FlaggedChars = count_common_chars(flagged);
367 CommonChars += FlaggedChars;
368
369 if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) {
370 return 0.0;
371 }
372
373 Transpositions = count_transpositions_block(PM, T_first, T_last, flagged, FlaggedChars);
374 }
375
376 double Sim = jaro_calculate_similarity(P_len, T_len, CommonChars, Transpositions);
377 return common::result_cutoff(result: Sim, score_cutoff);
378}
379
380template <typename InputIt1, typename InputIt2>
381double jaro_similarity(const common::BlockPatternMatchVector& PM, InputIt1 P_first, InputIt1 P_last,
382 InputIt2 T_first, InputIt2 T_last, double score_cutoff)
383{
384 int64_t P_len = std::distance(P_first, P_last);
385 int64_t T_len = std::distance(T_first, T_last);
386
387 /* filter out based on the length difference between the two strings */
388 if (!jaro_length_filter(P_len, T_len, score_cutoff)) {
389 return 0.0;
390 }
391
392 if (P_len == 1 && T_len == 1) {
393 return static_cast<double>(P_first[0] == T_first[0]);
394 }
395
396 int64_t Bound = jaro_bounds(P_first, P_last, T_first, T_last);
397
398 /* common prefix never includes Transpositions */
399 int64_t CommonChars = 0;
400 int64_t Transpositions = 0;
401 int64_t P_view_len = std::distance(P_first, P_last);
402 int64_t T_view_len = std::distance(T_first, T_last);
403
404 if (!P_view_len || !T_view_len) {
405 /* already has correct number of common chars and transpositions */
406 }
407 else if (P_view_len <= 64 && T_view_len <= 64) {
408 auto flagged = flag_similar_characters_word(PM, P_first, P_last, T_first, T_last, static_cast<int>(Bound));
409 CommonChars += count_common_chars(flagged);
410
411 if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) {
412 return 0.0;
413 }
414
415 Transpositions = count_transpositions_word(PM, T_first, T_last, flagged);
416 }
417 else {
418 auto flagged = flag_similar_characters_block(PM, P_first, P_last, T_first, T_last, Bound);
419 int64_t FlaggedChars = count_common_chars(flagged);
420 CommonChars += FlaggedChars;
421
422 if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) {
423 return 0.0;
424 }
425
426 Transpositions = count_transpositions_block(PM, T_first, T_last, flagged, FlaggedChars);
427 }
428
429 double Sim = jaro_calculate_similarity(P_len, T_len, CommonChars, Transpositions);
430 return common::result_cutoff(result: Sim, score_cutoff);
431}
432
433template <typename InputIt1, typename InputIt2>
434double jaro_winkler_similarity(InputIt1 P_first, InputIt1 P_last, InputIt2 T_first, InputIt2 T_last,
435 double prefix_weight, double score_cutoff)
436{
437 int64_t P_len = std::distance(P_first, P_last);
438 int64_t T_len = std::distance(T_first, T_last);
439 int64_t min_len = std::min(P_len, T_len);
440 int64_t prefix = 0;
441 int64_t max_prefix = std::min<int64_t>(min_len, 4);
442
443 for (; prefix < max_prefix; ++prefix) {
444 if (T_first[prefix] != P_first[prefix]) {
445 break;
446 }
447 }
448
449 double jaro_score_cutoff = score_cutoff;
450 if (jaro_score_cutoff > 0.7) {
451 double prefix_sim = prefix * prefix_weight;
452
453 if (prefix_sim >= 1.0) {
454 jaro_score_cutoff = 0.7;
455 }
456 else {
457 jaro_score_cutoff =
458 std::max(0.7, (prefix_sim - jaro_score_cutoff) / (prefix_sim - 1.0));
459 }
460 }
461
462 double Sim = jaro_similarity(P_first, P_last, T_first, T_last, jaro_score_cutoff);
463 if (Sim > 0.7) {
464 Sim += prefix * prefix_weight * (1.0 - Sim);
465 }
466
467 return common::result_cutoff(result: Sim, score_cutoff);
468}
469
470template <typename InputIt1, typename InputIt2>
471double jaro_winkler_similarity(const common::BlockPatternMatchVector& PM, InputIt1 P_first,
472 InputIt1 P_last, InputIt2 T_first, InputIt2 T_last,
473 double prefix_weight, double score_cutoff)
474{
475 int64_t P_len = std::distance(P_first, P_last);
476 int64_t T_len = std::distance(T_first, T_last);
477 int64_t min_len = std::min(P_len, T_len);
478 int64_t prefix = 0;
479 int64_t max_prefix = std::min<int64_t>(min_len, 4);
480
481 for (; prefix < max_prefix; ++prefix) {
482 if (T_first[prefix] != P_first[prefix]) {
483 break;
484 }
485 }
486
487 double jaro_score_cutoff = score_cutoff;
488 if (jaro_score_cutoff > 0.7) {
489 double prefix_sim = prefix * prefix_weight;
490
491 if (prefix_sim >= 1.0) {
492 jaro_score_cutoff = 0.7;
493 }
494 else {
495 jaro_score_cutoff =
496 std::max(0.7, (prefix_sim - jaro_score_cutoff) / (prefix_sim - 1.0));
497 }
498 }
499
500 double Sim = jaro_similarity(PM, P_first, P_last, T_first, T_last, jaro_score_cutoff);
501 if (Sim > 0.7) {
502 Sim += prefix * prefix_weight * (1.0 - Sim);
503 }
504
505 return common::result_cutoff(result: Sim, score_cutoff);
506}
507
508} // namespace detail
509} // namespace duckdb_jaro_winkler
510