1/* SPDX-License-Identifier: MIT */
2/* Copyright © 2022 Max Bachmann */
3
4#pragma once
5#include <algorithm>
6#include <array>
7#include <cassert>
8#include <cmath>
9#include <cstdint>
10#include <cstring>
11#include <iterator>
12#include <type_traits>
13#include <vector>
14
15namespace duckdb_jaro_winkler {
16
17namespace common {
18
19/**
20 * @defgroup Common Common
21 * Common utilities shared among multiple functions
22 * @{
23 */
24
/* taken from https://stackoverflow.com/a/30766365/11335032 */
/**
 * Compile-time detection of iterator types.
 *
 * `value` is true when the overload of `test` that is constrained on
 * std::iterator_traits is selected for an argument of type T, and false
 * when only the variadic fallback is viable.
 */
template <typename T>
struct is_iterator {
    /* catch-all overload, chosen whenever the constrained overload drops out */
    static char test(...);

    /* only viable if std::iterator_traits<U> names all five iterator member
     * types; otherwise substitution fails and the overload is discarded */
    template <typename U, typename Traits = std::iterator_traits<U>,
              typename = typename Traits::difference_type,
              typename = typename Traits::pointer,
              typename = typename Traits::reference,
              typename = typename Traits::value_type,
              typename = typename Traits::iterator_category>
    static long test(U&&);

    /* the return type of the selected overload reveals which one won */
    static constexpr bool value = std::is_same<long, decltype(test(std::declval<T>()))>::value;
};
39
/**
 * Applies a score cutoff.
 *
 * @param result computed score
 * @param score_cutoff minimum score that is still reported
 * @return `result` when it is at least `score_cutoff`, otherwise 0
 */
constexpr double result_cutoff(double result, double score_cutoff)
{
    /* negated >= (not <) so that an unordered comparison (NaN) still yields 0 */
    return !(result >= score_cutoff) ? 0.0 : result;
}
44
/**
 * Integer division that rounds towards positive infinity.
 *
 * @param a dividend
 * @param divisor divisor, must not be zero
 * @return the smallest integer not less than a / divisor, converted to T
 */
template <typename T, typename U>
T ceildiv(T a, U divisor)
{
    const auto quotient = a / divisor;
    const auto remainder = a % divisor;

    /* `/` truncates towards zero. An inexact quotient therefore lies below
     * the ceiling only when the exact result is positive, i.e. when the
     * remainder (which carries the sign of `a`) and the divisor share their
     * sign. A negative inexact quotient is already the ceiling. */
    const bool round_up = (remainder != 0) && ((remainder < 0) == (divisor < 0));

    return static_cast<T>(quotient) + static_cast<T>(round_up);
}
50
/**
 * Strips the prefix shared by the ranges [first1, last1) and [first2, last2).
 *
 * `first1` and `first2` are taken by reference and are both advanced past
 * the shared prefix. The iterators are indexed and advanced with `+=`, so
 * they have to support random access.
 *
 * @return number of leading elements the two ranges have in common
 */
template <typename InputIt1, typename InputIt2>
int64_t remove_common_prefix(InputIt1& first1, InputIt1 last1, InputIt2& first2, InputIt2 last2)
{
    // Elements are compared by index instead of through the four-iterator
    // std::mismatch that used to be here: according to the original note,
    // DuckDB passes raw pointers, which caused compile errors with the std::
    // variant (the note is cut off after "std::"; presumably std::mismatch).
    const int64_t limit =
        std::min<int64_t>(std::distance(first1, last1), std::distance(first2, last2));

    int64_t matched = 0;
    while (matched < limit && !(first1[matched] != first2[matched])) {
        ++matched;
    }

    first1 += matched;
    first2 += matched;
    return matched;
}
74
/**
 * Fixed-size open-addressing hash map from a character code to a 64 bit mask.
 *
 * The table has 128 slots. A slot whose `value` is 0 counts as unused, so a
 * key only occupies a slot once a non-zero mask has been stored for it.
 * Probing only stops at an unused slot or at the slot of the requested key;
 * the table must therefore never be filled completely with other keys.
 */
struct BitvectorHashmap {
    struct MapElem {
        uint64_t key = 0;
        uint64_t value = 0;
    };

    BitvectorHashmap() : m_map()
    {}

    /**
     * Sets bit `pos` in the mask stored for `key`.
     * `pos` is used as a shift count for a 64 bit value, so it has to be in [0, 64).
     */
    template <typename CharT>
    void insert(CharT key, int64_t pos)
    {
        insert_mask(key, 1ull << pos);
    }

    /**
     * ORs `mask` into the mask stored for `key`.
     */
    template <typename CharT>
    void insert_mask(CharT key, uint64_t mask)
    {
        /* keys are stored and compared in their uint64_t representation */
        const uint64_t hashed_key = static_cast<uint64_t>(key);
        const uint64_t i = lookup(hashed_key);
        m_map[i].key = hashed_key;
        m_map[i].value |= mask;
    }

    /**
     * Returns the mask stored for `key`, or 0 if nothing was stored for it.
     */
    template <typename CharT>
    uint64_t get(CharT key) const
    {
        return m_map[lookup(static_cast<uint64_t>(key))].value;
    }

private:
    /**
     * lookup key inside the hashmap using a similar collision resolution
     * strategy to CPython and Ruby
     *
     * @return index of the slot holding `key`, or of the first unused slot
     *         met on its probe sequence
     */
    uint64_t lookup(uint64_t key) const
    {
        uint64_t i = key % 128;

        if (!m_map[i].value || m_map[i].key == key) {
            return i;
        }

        /* the perturbation mixes the upper bits of the key into the probe
         * sequence and fades out as it is shifted towards 0 */
        uint64_t perturb = key;
        while (true) {
            i = ((i * 5) + perturb + 1) % 128;
            if (!m_map[i].value || m_map[i].key == key) {
                return i;
            }

            perturb >>= 5;
        }
    }

    std::array<MapElem, 128> m_map;
};
130
131struct PatternMatchVector {
132 struct MapElem {
133 uint64_t key = 0;
134 uint64_t value = 0;
135 };
136
137 PatternMatchVector() : m_map(), m_extendedAscii()
138 {}
139
140 template <typename InputIt1>
141 PatternMatchVector(InputIt1 first, InputIt1 last) : m_map(), m_extendedAscii()
142 {
143 insert(first, last);
144 }
145
146 template <typename InputIt1>
147 void insert(InputIt1 first, InputIt1 last)
148 {
149 uint64_t mask = 1;
150 for (int64_t i = 0; i < std::distance(first, last); ++i) {
151 auto key = first[i];
152 if (key >= 0 && key <= 255) {
153 m_extendedAscii[key] |= mask;
154 }
155 else {
156 m_map.insert_mask(key, mask);
157 }
158 mask <<= 1;
159 }
160 }
161
162 template <typename CharT>
163 void insert(CharT key, int64_t pos)
164 {
165 uint64_t mask = 1ull << pos;
166 if (key >= 0 && key <= 255) {
167 m_extendedAscii[key] |= mask;
168 }
169 else {
170 m_map.insert_mask(key, mask);
171 }
172 }
173
174 template <typename CharT>
175 uint64_t get(CharT key) const
176 {
177 if (key >= 0 && key <= 255) {
178 return m_extendedAscii[key];
179 }
180 else {
181 return m_map.get(key);
182 }
183 }
184
185 /**
186 * combat func for BlockPatternMatchVector
187 */
188 template <typename CharT>
189 uint64_t get(int64_t block, CharT key) const
190 {
191 (void)block;
192 assert(block == 0);
193 return get(key);
194 }
195
196private:
197 BitvectorHashmap m_map;
198 std::array<uint64_t, 256> m_extendedAscii;
199};
200
201struct BlockPatternMatchVector {
202 BlockPatternMatchVector() : m_block_count(0)
203 {}
204
205 template <typename InputIt1>
206 BlockPatternMatchVector(InputIt1 first, InputIt1 last) : m_block_count(0)
207 {
208 insert(first, last);
209 }
210
211 template <typename CharT>
212 void insert(int64_t block, CharT key, int pos)
213 {
214 uint64_t mask = 1ull << pos;
215
216 assert(block < m_block_count);
217 if (key >= 0 && key <= 255) {
218 m_extendedAscii[key * m_block_count + block] |= mask;
219 }
220 else {
221 m_map[block].insert_mask(key, mask);
222 }
223 }
224
225 template <typename InputIt1>
226 void insert(InputIt1 first, InputIt1 last)
227 {
228 int64_t len = std::distance(first, last);
229 m_block_count = ceildiv(a: len, divisor: 64);
230 m_map.resize(new_size: m_block_count);
231 m_extendedAscii.resize(new_size: m_block_count * 256);
232
233 for (int64_t i = 0; i < len; ++i) {
234 int64_t block = i / 64;
235 int64_t pos = i % 64;
236 insert(block, first[i], pos);
237 }
238 }
239
240 /**
241 * combat func for PatternMatchVector
242 */
243 template <typename CharT>
244 uint64_t get(CharT key) const
245 {
246 return get(0, key);
247 }
248
249 template <typename CharT>
250 uint64_t get(int64_t block, CharT key) const
251 {
252 assert(block < m_block_count);
253 if (key >= 0 && key <= 255) {
254 return m_extendedAscii[key * m_block_count + block];
255 }
256 else {
257 return m_map[block].get(key);
258 }
259 }
260
261private:
262 std::vector<BitvectorHashmap> m_map;
263 std::vector<uint64_t> m_extendedAscii;
264 int64_t m_block_count;
265};
266
267/**@}*/
268
269} // namespace common
270} // namespace duckdb_jaro_winkler
271