| 1 | /* SPDX-License-Identifier: MIT */ |
| 2 | /* Copyright © 2022 Max Bachmann */ |
| 3 | |
| 4 | #pragma once |
| 5 | |
| 6 | #include "common.hpp" |
| 7 | #include "intrinsics.hpp" |
| 8 | |
| 9 | namespace duckdb_jaro_winkler { |
| 10 | namespace detail { |
| 11 | |
/* Match bitmaps for views that fit into a single 64-bit word.
 * Bit i of P_flag is set once pattern position i has been matched;
 * bit j of T_flag is set once text position j has been matched. */
struct FlaggedCharsWord {
    uint64_t P_flag; // matched pattern positions
    uint64_t T_flag; // matched text positions
};
| 16 | |
/* Match bitmaps for views longer than 64 elements: one 64-bit word per
 * 64 positions, position k living in word k / 64 at bit k % 64. */
struct FlaggedCharsMultiword {
    std::vector<uint64_t> P_flag; // matched pattern positions, word by word
    std::vector<uint64_t> T_flag; // matched text positions, word by word
};
| 21 | |
/* Describes the sliding match window over the pattern's 64-bit words:
 * the window starts after `empty_words` words and spans `words` words;
 * `first_mask` / `last_mask` select the usable bits of its first / last word. */
struct SearchBoundMask {
    int64_t words{0};
    int64_t empty_words{0};
    uint64_t last_mask{0};
    uint64_t first_mask{0};
};
| 28 | |
/* A (word index, position within that word) pair. Presumably the word is a
 * 64-bit block as elsewhere in this file (j / 64, j % 64) — TODO confirm;
 * the type is not used in the visible part of this header. */
struct TextPosition {
    TextPosition(int64_t word, int64_t word_pos) : Word{word}, WordPos{word_pos}
    {}

    int64_t Word;
    int64_t WordPos;
};
| 35 | |
/**
 * @brief Jaro similarity from the match statistics of two strings
 *
 * @param P_len length of the first string
 * @param T_len length of the second string
 * @param CommonChars number of matching characters
 * @param Transpositions number of matched pairs whose characters differ; it is
 *        halved (integer division) as in the Jaro definition
 * @return (m/|P| + m/|T| + (m - t)/m) / 3, or 0.0 when there are no common
 *         characters
 */
static inline double jaro_calculate_similarity(int64_t P_len, int64_t T_len, int64_t CommonChars,
                                               int64_t Transpositions)
{
    /* With no matches the Jaro similarity is 0 by definition; the last term
     * below would otherwise divide 0 by 0 and return NaN. */
    if (CommonChars <= 0) return 0.0;

    Transpositions /= 2;
    double Sim = 0;
    Sim += static_cast<double>(CommonChars) / static_cast<double>(P_len);
    Sim += static_cast<double>(CommonChars) / static_cast<double>(T_len);
    Sim += (static_cast<double>(CommonChars) - static_cast<double>(Transpositions)) / static_cast<double>(CommonChars);
    return Sim / 3.0;
}
| 46 | |
/**
 * @brief filter matches below score_cutoff based on string lengths
 *
 * The highest Jaro score two strings of these lengths could reach assumes every
 * character of the shorter one matches with no transpositions. Returns false
 * when that upper bound is below score_cutoff, or when either string is empty.
 */
static inline bool jaro_length_filter(int64_t P_len, int64_t T_len, double score_cutoff)
{
    if (P_len == 0 || T_len == 0) {
        return false;
    }

    const double shorter = static_cast<double>(std::min(P_len, T_len));
    const double best_score =
        (shorter / static_cast<double>(P_len) + shorter / static_cast<double>(T_len) + 1.0) / 3.0;
    return best_score >= score_cutoff;
}
| 59 | |
/**
 * @brief filter matches below score_cutoff based on string lengths and common characters
 *
 * Evaluates the Jaro formula with zero transpositions, i.e. the best score
 * still reachable with CommonChars matches. Returns false when there are no
 * matches or when that bound is below score_cutoff.
 */
static inline bool jaro_common_char_filter(int64_t P_len, int64_t T_len, int64_t CommonChars,
                                           double score_cutoff)
{
    if (CommonChars == 0) {
        return false;
    }

    const double matches = static_cast<double>(CommonChars);
    const double best_score =
        (matches / static_cast<double>(P_len) + matches / static_cast<double>(T_len) + 1.0) / 3.0;
    return best_score >= score_cutoff;
}
| 75 | |
| 76 | static inline int64_t count_common_chars(const FlaggedCharsWord& flagged) |
| 77 | { |
| 78 | return intrinsics::popcount(x: flagged.P_flag); |
| 79 | } |
| 80 | |
| 81 | static inline int64_t count_common_chars(const FlaggedCharsMultiword& flagged) |
| 82 | { |
| 83 | int64_t CommonChars = 0; |
| 84 | if (flagged.P_flag.size() < flagged.T_flag.size()) { |
| 85 | for (uint64_t flag : flagged.P_flag) { |
| 86 | CommonChars += intrinsics::popcount(x: flag); |
| 87 | } |
| 88 | } |
| 89 | else { |
| 90 | for (uint64_t flag : flagged.T_flag) { |
| 91 | CommonChars += intrinsics::popcount(x: flag); |
| 92 | } |
| 93 | } |
| 94 | return CommonChars; |
| 95 | } |
| 96 | |
| 97 | template <typename PM_Vec, typename InputIt1, typename InputIt2> |
| 98 | static inline FlaggedCharsWord |
| 99 | flag_similar_characters_word(const PM_Vec& PM, InputIt1 P_first, |
| 100 | InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, int Bound) |
| 101 | { |
| 102 | using namespace intrinsics; |
| 103 | int64_t P_len = std::distance(P_first, P_last); |
| 104 | (void)P_len; |
| 105 | int64_t T_len = std::distance(T_first, T_last); |
| 106 | assert(P_len <= 64); |
| 107 | assert(T_len <= 64); |
| 108 | assert(Bound > P_len || P_len - Bound <= T_len); |
| 109 | |
| 110 | FlaggedCharsWord flagged = {.P_flag: 0, .T_flag: 0}; |
| 111 | |
| 112 | uint64_t BoundMask = bit_mask_lsb<uint64_t>(n: Bound + 1); |
| 113 | |
| 114 | int64_t j = 0; |
| 115 | for (; j < std::min(static_cast<int64_t>(Bound), T_len); ++j) { |
| 116 | uint64_t PM_j = PM.get(T_first[j]) & BoundMask & (~flagged.P_flag); |
| 117 | |
| 118 | flagged.P_flag |= blsi(a: PM_j); |
| 119 | flagged.T_flag |= static_cast<uint64_t>(PM_j != 0) << j; |
| 120 | |
| 121 | BoundMask = (BoundMask << 1) | 1; |
| 122 | } |
| 123 | |
| 124 | for (; j < T_len; ++j) { |
| 125 | uint64_t PM_j = PM.get(T_first[j]) & BoundMask & (~flagged.P_flag); |
| 126 | |
| 127 | flagged.P_flag |= blsi(a: PM_j); |
| 128 | flagged.T_flag |= static_cast<uint64_t>(PM_j != 0) << j; |
| 129 | |
| 130 | BoundMask <<= 1; |
| 131 | } |
| 132 | |
| 133 | return flagged; |
| 134 | } |
| 135 | |
/**
 * Tries to match the text character T_j (text position j) in the multi-word case.
 *
 * The search window covers pattern words [BoundMask.empty_words,
 * BoundMask.empty_words + BoundMask.words). Words are scanned in ascending
 * order: the first word is masked with first_mask, the last with last_mask,
 * and the ones in between use all bits. The lowest unmatched candidate bit
 * found is claimed. On success one bit is set in flagged.P_flag and bit j % 64
 * of flagged.T_flag[j / 64] is set; otherwise nothing changes.
 *
 * NOTE(review): assumes PM.get(word, ch) returns the positions inside pattern
 * word `word` that hold ch — confirm against common::BlockPatternMatchVector.
 */
template <typename CharT>
static inline void flag_similar_characters_step(const common::BlockPatternMatchVector& PM,
                                                CharT T_j, FlaggedCharsMultiword& flagged,
                                                int64_t j, SearchBoundMask BoundMask)
{
    using namespace intrinsics;

    // word / bit of text position j inside flagged.T_flag
    int64_t j_word = j / 64;
    int64_t j_pos = j % 64;
    // [word, last_word) is the range of pattern words inside the window
    int64_t word = BoundMask.empty_words;
    int64_t last_word = word + BoundMask.words;

    // window confined to a single word: both edge masks apply to it
    if (BoundMask.words == 1) {
        uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & BoundMask.first_mask &
                        (~flagged.P_flag[word]);

        // blsi(0) == 0 and (PM_j != 0) == 0, so a miss leaves both bitmaps untouched
        flagged.P_flag[word] |= blsi(PM_j);
        flagged.T_flag[j_word] |= static_cast<uint64_t>(PM_j != 0) << j_pos;
        return;
    }

    // first (partially covered) word of the window
    if (BoundMask.first_mask) {
        uint64_t PM_j = PM.get(word, T_j) & BoundMask.first_mask & (~flagged.P_flag[word]);

        if (PM_j) {
            flagged.P_flag[word] |= blsi(PM_j);
            flagged.T_flag[j_word] |= 1ull << j_pos;
            return;
        }
        word++;
    }

    // fully covered middle words
    for (; word < last_word - 1; ++word) {
        uint64_t PM_j = PM.get(word, T_j) & (~flagged.P_flag[word]);

        if (PM_j) {
            flagged.P_flag[word] |= blsi(PM_j);
            flagged.T_flag[j_word] |= 1ull << j_pos;
            return;
        }
    }

    // last (partially covered) word of the window; a zero last_mask means it is still empty
    if (BoundMask.last_mask) {
        uint64_t PM_j = PM.get(word, T_j) & BoundMask.last_mask & (~flagged.P_flag[word]);

        flagged.P_flag[word] |= blsi(PM_j);
        flagged.T_flag[j_word] |= static_cast<uint64_t>(PM_j != 0) << j_pos;
    }
}
| 185 | |
/**
 * Flags matching characters when at least one view is longer than 64 elements.
 *
 * Maintains a SearchBoundMask describing the pattern window [j - Bound, j + Bound]
 * (clipped to the pattern) and calls flag_similar_characters_step for every
 * text position j.
 *
 * @param Bound match window radius (see jaro_bounds)
 * @return per-word pattern / text match bitmaps
 */
template <typename InputIt1, typename InputIt2>
static inline FlaggedCharsMultiword
flag_similar_characters_block(const common::BlockPatternMatchVector& PM, InputIt1 P_first,
                              InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, int64_t Bound)
{
    using namespace intrinsics;
    int64_t P_len = std::distance(P_first, P_last);
    int64_t T_len = std::distance(T_first, T_last);
    assert(P_len > 64 || T_len > 64);
    assert(Bound > P_len || P_len - Bound <= T_len);
    // holds for callers in this file: a view > 64 implies max length >= 65,
    // and jaro_bounds yields Bound = max / 2 - 1 >= 31
    assert(Bound >= 31);

    int64_t TextWords = common::ceildiv(T_len, 64);
    int64_t PatternWords = common::ceildiv(P_len, 64);

    FlaggedCharsMultiword flagged;
    flagged.T_flag.resize(TextWords);
    flagged.P_flag.resize(PatternWords);

    // initial window for j = 0: pattern positions [0, min(Bound + 1, P_len))
    SearchBoundMask BoundMask;
    int64_t start_range = std::min(Bound + 1, P_len);
    BoundMask.words = 1 + start_range / 64;
    BoundMask.empty_words = 0;
    BoundMask.last_mask = (1ull << (start_range % 64)) - 1;
    BoundMask.first_mask = ~UINT64_C(0);

    for (int64_t j = 0; j < T_len; ++j) {
        flag_similar_characters_step(PM, T_first[j], flagged, j, BoundMask);

        // advance the upper edge while it is still inside the pattern
        if (j + Bound + 1 < P_len) {
            BoundMask.last_mask = (BoundMask.last_mask << 1) | 1;
            // last word is full and more pattern remains: open a new, still empty last word
            if (j + Bound + 2 < P_len && BoundMask.last_mask == ~UINT64_C(0)) {
                BoundMask.last_mask = 0;
                BoundMask.words++;
            }
        }

        // the lower edge only starts moving once j reaches Bound
        if (j >= Bound) {
            BoundMask.first_mask <<= 1;
            // first word fully left the window: drop it and start fresh on the next one
            if (BoundMask.first_mask == 0) {
                BoundMask.first_mask = ~UINT64_C(0);
                BoundMask.words--;
                BoundMask.empty_words++;
            }
        }
    }

    return flagged;
}
| 235 | |
| 236 | template <typename PM_Vec, typename InputIt1> |
| 237 | static inline int64_t count_transpositions_word(const PM_Vec& PM, |
| 238 | InputIt1 T_first, InputIt1, |
| 239 | const FlaggedCharsWord& flagged) |
| 240 | { |
| 241 | using namespace intrinsics; |
| 242 | uint64_t P_flag = flagged.P_flag; |
| 243 | uint64_t T_flag = flagged.T_flag; |
| 244 | int64_t Transpositions = 0; |
| 245 | while (T_flag) { |
| 246 | uint64_t PatternFlagMask = blsi(a: P_flag); |
| 247 | |
| 248 | Transpositions += !(PM.get(T_first[tzcnt(x: T_flag)]) & PatternFlagMask); |
| 249 | |
| 250 | T_flag = blsr(x: T_flag); |
| 251 | P_flag ^= PatternFlagMask; |
| 252 | } |
| 253 | |
| 254 | return Transpositions; |
| 255 | } |
| 256 | |
/**
 * Counts transpositions in the multi-word case.
 *
 * Walks the set bits of flagged.T_flag and flagged.P_flag in ascending order,
 * pairing the k-th matched text position with the k-th matched pattern
 * position. A pair counts as a transposition when PM has no bit for the text
 * character at that pattern position.
 *
 * @param T_first start of the text view that flagged.T_flag refers to
 * @param FlaggedChars number of matched pairs (as returned by count_common_chars);
 *        the outer loop ends once that many pairs have been processed
 * @return raw transposition count (jaro_calculate_similarity halves it)
 */
template <typename InputIt1>
static inline int64_t
count_transpositions_block(const common::BlockPatternMatchVector& PM, InputIt1 T_first, InputIt1,
                           const FlaggedCharsMultiword& flagged, int64_t FlaggedChars)
{
    using namespace intrinsics;
    int64_t TextWord = 0;
    int64_t PatternWord = 0;
    uint64_t T_flag = flagged.T_flag[TextWord];
    uint64_t P_flag = flagged.P_flag[PatternWord];

    int64_t Transpositions = 0;
    while (FlaggedChars) {
        // skip text words without matches; T_first tracks the start of the current word
        while (!T_flag) {
            TextWord++;
            T_first += 64;
            T_flag = flagged.T_flag[TextWord];
        }

        while (T_flag) {
            // skip pattern words whose matches are all consumed
            while (!P_flag) {
                PatternWord++;
                P_flag = flagged.P_flag[PatternWord];
            }

            // lowest remaining matched pattern position
            uint64_t PatternFlagMask = blsi(P_flag);

            Transpositions += !(PM.get(PatternWord, T_first[tzcnt(T_flag)]) & PatternFlagMask);

            // consume the current text / pattern match
            T_flag = blsr(T_flag);
            P_flag ^= PatternFlagMask;

            FlaggedChars--;
        }
    }

    return Transpositions;
}
| 295 | |
/**
 * @brief find bounds and skip out of bound parts of the sequences
 *
 * Computes the Jaro match window radius Bound = max(P_len, T_len) / 2 - 1,
 * clamped to be non-negative. It then shortens the longer sequence, through
 * P_last / T_last (updated in place), to the part that can ever fall inside
 * the window.
 *
 * @return the match window radius
 */
template <typename InputIt1, typename InputIt2>
int64_t jaro_bounds(InputIt1 P_first, InputIt1& P_last, InputIt2 T_first, InputIt2& T_last)
{
    int64_t P_len = std::distance(P_first, P_last);
    int64_t T_len = std::distance(T_first, T_last);

    /* since jaro uses a sliding window some parts of T/P might never be in
     * range and can be removed ahead of time.
     * Bound is clamped at 0: for a longer length <= 1 the formula yields -1,
     * which would move P_last / T_last before the start of the sequence
     * (e.g. two empty inputs).
     */
    int64_t Bound = 0;
    if (T_len > P_len) {
        Bound = std::max<int64_t>(0, T_len / 2 - 1);
        if (T_len > P_len + Bound) {
            T_last = T_first + (P_len + Bound);
        }
    }
    else {
        Bound = std::max<int64_t>(0, P_len / 2 - 1);
        if (P_len > T_len + Bound) {
            P_last = P_first + (T_len + Bound);
        }
    }
    return Bound;
}
| 324 | |
| 325 | template <typename InputIt1, typename InputIt2> |
| 326 | double jaro_similarity(InputIt1 P_first, InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, |
| 327 | double score_cutoff) |
| 328 | { |
| 329 | int64_t P_len = std::distance(P_first, P_last); |
| 330 | int64_t T_len = std::distance(T_first, T_last); |
| 331 | |
| 332 | /* filter out based on the length difference between the two strings */ |
| 333 | if (!jaro_length_filter(P_len, T_len, score_cutoff)) { |
| 334 | return 0.0; |
| 335 | } |
| 336 | |
| 337 | if (P_len == 1 && T_len == 1) { |
| 338 | return static_cast<double>(P_first[0] == T_first[0]); |
| 339 | } |
| 340 | |
| 341 | int64_t Bound = jaro_bounds(P_first, P_last, T_first, T_last); |
| 342 | |
| 343 | /* common prefix never includes Transpositions */ |
| 344 | int64_t CommonChars = common::remove_common_prefix(P_first, P_last, T_first, T_last); |
| 345 | int64_t Transpositions = 0; |
| 346 | int64_t P_view_len = std::distance(P_first, P_last); |
| 347 | int64_t T_view_len = std::distance(T_first, T_last); |
| 348 | |
| 349 | if (!P_view_len || !T_view_len) { |
| 350 | /* already has correct number of common chars and transpositions */ |
| 351 | } |
| 352 | else if (P_view_len <= 64 && T_view_len <= 64) { |
| 353 | common::PatternMatchVector PM(P_first, P_last); |
| 354 | auto flagged = flag_similar_characters_word(PM, P_first, P_last, T_first, T_last, static_cast<int>(Bound)); |
| 355 | CommonChars += count_common_chars(flagged); |
| 356 | |
| 357 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
| 358 | return 0.0; |
| 359 | } |
| 360 | |
| 361 | Transpositions = count_transpositions_word(PM, T_first, T_last, flagged); |
| 362 | } |
| 363 | else { |
| 364 | common::BlockPatternMatchVector PM(P_first, P_last); |
| 365 | auto flagged = flag_similar_characters_block(PM, P_first, P_last, T_first, T_last, Bound); |
| 366 | int64_t FlaggedChars = count_common_chars(flagged); |
| 367 | CommonChars += FlaggedChars; |
| 368 | |
| 369 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
| 370 | return 0.0; |
| 371 | } |
| 372 | |
| 373 | Transpositions = count_transpositions_block(PM, T_first, T_last, flagged, FlaggedChars); |
| 374 | } |
| 375 | |
| 376 | double Sim = jaro_calculate_similarity(P_len, T_len, CommonChars, Transpositions); |
| 377 | return common::result_cutoff(result: Sim, score_cutoff); |
| 378 | } |
| 379 | |
| 380 | template <typename InputIt1, typename InputIt2> |
| 381 | double jaro_similarity(const common::BlockPatternMatchVector& PM, InputIt1 P_first, InputIt1 P_last, |
| 382 | InputIt2 T_first, InputIt2 T_last, double score_cutoff) |
| 383 | { |
| 384 | int64_t P_len = std::distance(P_first, P_last); |
| 385 | int64_t T_len = std::distance(T_first, T_last); |
| 386 | |
| 387 | /* filter out based on the length difference between the two strings */ |
| 388 | if (!jaro_length_filter(P_len, T_len, score_cutoff)) { |
| 389 | return 0.0; |
| 390 | } |
| 391 | |
| 392 | if (P_len == 1 && T_len == 1) { |
| 393 | return static_cast<double>(P_first[0] == T_first[0]); |
| 394 | } |
| 395 | |
| 396 | int64_t Bound = jaro_bounds(P_first, P_last, T_first, T_last); |
| 397 | |
| 398 | /* common prefix never includes Transpositions */ |
| 399 | int64_t CommonChars = 0; |
| 400 | int64_t Transpositions = 0; |
| 401 | int64_t P_view_len = std::distance(P_first, P_last); |
| 402 | int64_t T_view_len = std::distance(T_first, T_last); |
| 403 | |
| 404 | if (!P_view_len || !T_view_len) { |
| 405 | /* already has correct number of common chars and transpositions */ |
| 406 | } |
| 407 | else if (P_view_len <= 64 && T_view_len <= 64) { |
| 408 | auto flagged = flag_similar_characters_word(PM, P_first, P_last, T_first, T_last, static_cast<int>(Bound)); |
| 409 | CommonChars += count_common_chars(flagged); |
| 410 | |
| 411 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
| 412 | return 0.0; |
| 413 | } |
| 414 | |
| 415 | Transpositions = count_transpositions_word(PM, T_first, T_last, flagged); |
| 416 | } |
| 417 | else { |
| 418 | auto flagged = flag_similar_characters_block(PM, P_first, P_last, T_first, T_last, Bound); |
| 419 | int64_t FlaggedChars = count_common_chars(flagged); |
| 420 | CommonChars += FlaggedChars; |
| 421 | |
| 422 | if (!jaro_common_char_filter(P_len, T_len, CommonChars, score_cutoff)) { |
| 423 | return 0.0; |
| 424 | } |
| 425 | |
| 426 | Transpositions = count_transpositions_block(PM, T_first, T_last, flagged, FlaggedChars); |
| 427 | } |
| 428 | |
| 429 | double Sim = jaro_calculate_similarity(P_len, T_len, CommonChars, Transpositions); |
| 430 | return common::result_cutoff(result: Sim, score_cutoff); |
| 431 | } |
| 432 | |
| 433 | template <typename InputIt1, typename InputIt2> |
| 434 | double jaro_winkler_similarity(InputIt1 P_first, InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, |
| 435 | double prefix_weight, double score_cutoff) |
| 436 | { |
| 437 | int64_t P_len = std::distance(P_first, P_last); |
| 438 | int64_t T_len = std::distance(T_first, T_last); |
| 439 | int64_t min_len = std::min(P_len, T_len); |
| 440 | int64_t prefix = 0; |
| 441 | int64_t max_prefix = std::min<int64_t>(min_len, 4); |
| 442 | |
| 443 | for (; prefix < max_prefix; ++prefix) { |
| 444 | if (T_first[prefix] != P_first[prefix]) { |
| 445 | break; |
| 446 | } |
| 447 | } |
| 448 | |
| 449 | double jaro_score_cutoff = score_cutoff; |
| 450 | if (jaro_score_cutoff > 0.7) { |
| 451 | double prefix_sim = prefix * prefix_weight; |
| 452 | |
| 453 | if (prefix_sim >= 1.0) { |
| 454 | jaro_score_cutoff = 0.7; |
| 455 | } |
| 456 | else { |
| 457 | jaro_score_cutoff = |
| 458 | std::max(0.7, (prefix_sim - jaro_score_cutoff) / (prefix_sim - 1.0)); |
| 459 | } |
| 460 | } |
| 461 | |
| 462 | double Sim = jaro_similarity(P_first, P_last, T_first, T_last, jaro_score_cutoff); |
| 463 | if (Sim > 0.7) { |
| 464 | Sim += prefix * prefix_weight * (1.0 - Sim); |
| 465 | } |
| 466 | |
| 467 | return common::result_cutoff(result: Sim, score_cutoff); |
| 468 | } |
| 469 | |
| 470 | template <typename InputIt1, typename InputIt2> |
| 471 | double jaro_winkler_similarity(const common::BlockPatternMatchVector& PM, InputIt1 P_first, |
| 472 | InputIt1 P_last, InputIt2 T_first, InputIt2 T_last, |
| 473 | double prefix_weight, double score_cutoff) |
| 474 | { |
| 475 | int64_t P_len = std::distance(P_first, P_last); |
| 476 | int64_t T_len = std::distance(T_first, T_last); |
| 477 | int64_t min_len = std::min(P_len, T_len); |
| 478 | int64_t prefix = 0; |
| 479 | int64_t max_prefix = std::min<int64_t>(min_len, 4); |
| 480 | |
| 481 | for (; prefix < max_prefix; ++prefix) { |
| 482 | if (T_first[prefix] != P_first[prefix]) { |
| 483 | break; |
| 484 | } |
| 485 | } |
| 486 | |
| 487 | double jaro_score_cutoff = score_cutoff; |
| 488 | if (jaro_score_cutoff > 0.7) { |
| 489 | double prefix_sim = prefix * prefix_weight; |
| 490 | |
| 491 | if (prefix_sim >= 1.0) { |
| 492 | jaro_score_cutoff = 0.7; |
| 493 | } |
| 494 | else { |
| 495 | jaro_score_cutoff = |
| 496 | std::max(0.7, (prefix_sim - jaro_score_cutoff) / (prefix_sim - 1.0)); |
| 497 | } |
| 498 | } |
| 499 | |
| 500 | double Sim = jaro_similarity(PM, P_first, P_last, T_first, T_last, jaro_score_cutoff); |
| 501 | if (Sim > 0.7) { |
| 502 | Sim += prefix * prefix_weight * (1.0 - Sim); |
| 503 | } |
| 504 | |
| 505 | return common::result_cutoff(result: Sim, score_cutoff); |
| 506 | } |
| 507 | |
| 508 | } // namespace detail |
| 509 | } // namespace duckdb_jaro_winkler |
| 510 | |