1 | /* SPDX-License-Identifier: MIT */ |
2 | /* Copyright © 2022 Max Bachmann */ |
3 | |
4 | #pragma once |
5 | #include <algorithm> |
6 | #include <array> |
7 | #include <cassert> |
8 | #include <cmath> |
9 | #include <cstdint> |
10 | #include <cstring> |
11 | #include <iterator> |
12 | #include <type_traits> |
13 | #include <vector> |
14 | |
15 | namespace duckdb_jaro_winkler { |
16 | |
17 | namespace common { |
18 | |
19 | /** |
20 | * @defgroup Common Common |
21 | * Common utilities shared among multiple functions |
22 | * @{ |
23 | */ |
24 | |
/* taken from https://stackoverflow.com/a/30766365/11335032 */
/**
 * Compile-time detection trait.
 *
 * `value` is true when std::iterator_traits of the type deduced from
 * std::declval<T>() exposes all five member typedefs (difference_type,
 * pointer, reference, value_type, iterator_category), and false otherwise.
 */
template <typename T>
struct is_iterator {
    // Viable only when every iterator_traits member type of U exists; if any
    // of them is missing, substitution fails and the fallback below is chosen.
    template <typename U, typename Traits = std::iterator_traits<U>,
              typename = typename Traits::difference_type,
              typename = typename Traits::pointer,
              typename = typename Traits::reference,
              typename = typename Traits::value_type,
              typename = typename Traits::iterator_category>
    static long test(U&&);

    // Fallback overload, accepts any argument.
    static char test(...);

    // The return type of the selected overload encodes the answer.
    static constexpr bool value = std::is_same<long, decltype(test(std::declval<T>()))>::value;
};
39 | |
/**
 * Applies a score cutoff to a result.
 *
 * @param result        computed score
 * @param score_cutoff  minimum score that is still reported
 * @return `result` when `result >= score_cutoff`, otherwise 0
 */
constexpr double result_cutoff(double result, double score_cutoff)
{
    // Expressed as the negation of `>=` so that an unordered comparison
    // (NaN operand) still falls on the 0 side.
    return !(result >= score_cutoff) ? 0.0 : result;
}
44 | |
/**
 * Integer division rounded towards positive infinity.
 *
 * @param a        dividend
 * @param divisor  divisor, must not be 0
 * @return ceil(a / divisor), converted to T
 */
template <typename T, typename U>
T ceildiv(T a, U divisor)
{
    const auto remainder = a % divisor;
    // Built-in division truncates towards zero, which already is the ceiling
    // whenever the exact quotient is negative. Round up only when something
    // was cut off AND the exact quotient is positive; the remainder carries
    // the sign of the dividend, so that is the case when remainder and
    // divisor have the same sign.
    const bool round_up = (remainder != 0) && ((remainder < 0) == (divisor < 0));
    return static_cast<T>(a / divisor) + static_cast<T>(round_up);
}
50 | |
/**
 * Strips the prefix shared by the ranges [first1, last1) and [first2, last2).
 *
 * `first1` and `first2` are taken by reference and advanced in place past the
 * leading elements on which both ranges agree.
 *
 * @return number of leading elements the two ranges have in common
 */
template <typename InputIt1, typename InputIt2>
int64_t remove_common_prefix(InputIt1& first1, InputIt1 last1, InputIt2& first2, InputIt2 last2)
{
    // A hand-written index loop stands in for std::mismatch: per the original
    // note, DuckDB passes raw pointers and the std:: call gave compile errors.
    const int64_t size1 = std::distance(first1, last1);
    const int64_t size2 = std::distance(first2, last2);
    const int64_t limit = (size2 < size1) ? size2 : size1;

    int64_t matched = 0;
    while (matched < limit) {
        if (first1[matched] != first2[matched]) {
            break;
        }
        ++matched;
    }

    first1 += matched;
    first2 += matched;
    return matched;
}
74 | |
/**
 * Fixed-size open-addressing hash map from a character code to a 64-bit mask.
 *
 * The table has 128 slots. A slot whose `value` is 0 is treated as empty, so
 * a key only occupies a slot once a non-zero mask has been stored for it.
 */
struct BitvectorHashmap {
    struct MapElem {
        uint64_t key = 0;
        uint64_t value = 0;
    };

    BitvectorHashmap() : m_map()
    {}

    /**
     * Sets bit `pos` in the mask stored for `key`.
     * `pos` must be in [0, 63]; shifting a 64-bit value by more is undefined.
     */
    template <typename CharT>
    void insert(CharT key, int64_t pos)
    {
        insert_mask(key, 1ull << pos);
    }

    /**
     * ORs `mask` into the mask stored for `key`, claiming a slot if needed.
     */
    template <typename CharT>
    void insert_mask(CharT key, uint64_t mask)
    {
        // Convert once so the stored key is exactly what lookup() compared.
        const uint64_t hashed_key = static_cast<uint64_t>(key);
        const uint64_t i = lookup(hashed_key);
        m_map[i].key = hashed_key;
        m_map[i].value |= mask;
    }

    /**
     * Returns the mask stored for `key`, or 0 when nothing was stored.
     */
    template <typename CharT>
    uint64_t get(CharT key) const
    {
        return m_map[lookup(static_cast<uint64_t>(key))].value;
    }

private:
    /**
     * lookup key inside the hashmap using a similar collision resolution
     * strategy to CPython and Ruby
     *
     * Returns the index of the slot that either already holds `key` or is
     * empty. The probe loop has no termination guard: it relies on such a
     * slot existing in the table.
     */
    uint64_t lookup(uint64_t key) const
    {
        uint64_t i = key % 128;

        if (!m_map[i].value || m_map[i].key == key) {
            return i;
        }

        // Collision: mix the higher key bits into the probe sequence, five
        // bits at a time, until they are used up.
        uint64_t perturb = key;
        while (true) {
            i = ((i * 5) + perturb + 1) % 128;
            if (!m_map[i].value || m_map[i].key == key) {
                return i;
            }

            perturb >>= 5;
        }
    }

    std::array<MapElem, 128> m_map;
};
130 | |
131 | struct PatternMatchVector { |
132 | struct MapElem { |
133 | uint64_t key = 0; |
134 | uint64_t value = 0; |
135 | }; |
136 | |
137 | PatternMatchVector() : m_map(), m_extendedAscii() |
138 | {} |
139 | |
140 | template <typename InputIt1> |
141 | PatternMatchVector(InputIt1 first, InputIt1 last) : m_map(), m_extendedAscii() |
142 | { |
143 | insert(first, last); |
144 | } |
145 | |
146 | template <typename InputIt1> |
147 | void insert(InputIt1 first, InputIt1 last) |
148 | { |
149 | uint64_t mask = 1; |
150 | for (int64_t i = 0; i < std::distance(first, last); ++i) { |
151 | auto key = first[i]; |
152 | if (key >= 0 && key <= 255) { |
153 | m_extendedAscii[key] |= mask; |
154 | } |
155 | else { |
156 | m_map.insert_mask(key, mask); |
157 | } |
158 | mask <<= 1; |
159 | } |
160 | } |
161 | |
162 | template <typename CharT> |
163 | void insert(CharT key, int64_t pos) |
164 | { |
165 | uint64_t mask = 1ull << pos; |
166 | if (key >= 0 && key <= 255) { |
167 | m_extendedAscii[key] |= mask; |
168 | } |
169 | else { |
170 | m_map.insert_mask(key, mask); |
171 | } |
172 | } |
173 | |
174 | template <typename CharT> |
175 | uint64_t get(CharT key) const |
176 | { |
177 | if (key >= 0 && key <= 255) { |
178 | return m_extendedAscii[key]; |
179 | } |
180 | else { |
181 | return m_map.get(key); |
182 | } |
183 | } |
184 | |
185 | /** |
186 | * combat func for BlockPatternMatchVector |
187 | */ |
188 | template <typename CharT> |
189 | uint64_t get(int64_t block, CharT key) const |
190 | { |
191 | (void)block; |
192 | assert(block == 0); |
193 | return get(key); |
194 | } |
195 | |
196 | private: |
197 | BitvectorHashmap m_map; |
198 | std::array<uint64_t, 256> m_extendedAscii; |
199 | }; |
200 | |
201 | struct BlockPatternMatchVector { |
202 | BlockPatternMatchVector() : m_block_count(0) |
203 | {} |
204 | |
205 | template <typename InputIt1> |
206 | BlockPatternMatchVector(InputIt1 first, InputIt1 last) : m_block_count(0) |
207 | { |
208 | insert(first, last); |
209 | } |
210 | |
211 | template <typename CharT> |
212 | void insert(int64_t block, CharT key, int pos) |
213 | { |
214 | uint64_t mask = 1ull << pos; |
215 | |
216 | assert(block < m_block_count); |
217 | if (key >= 0 && key <= 255) { |
218 | m_extendedAscii[key * m_block_count + block] |= mask; |
219 | } |
220 | else { |
221 | m_map[block].insert_mask(key, mask); |
222 | } |
223 | } |
224 | |
225 | template <typename InputIt1> |
226 | void insert(InputIt1 first, InputIt1 last) |
227 | { |
228 | int64_t len = std::distance(first, last); |
229 | m_block_count = ceildiv(a: len, divisor: 64); |
230 | m_map.resize(new_size: m_block_count); |
231 | m_extendedAscii.resize(new_size: m_block_count * 256); |
232 | |
233 | for (int64_t i = 0; i < len; ++i) { |
234 | int64_t block = i / 64; |
235 | int64_t pos = i % 64; |
236 | insert(block, first[i], pos); |
237 | } |
238 | } |
239 | |
240 | /** |
241 | * combat func for PatternMatchVector |
242 | */ |
243 | template <typename CharT> |
244 | uint64_t get(CharT key) const |
245 | { |
246 | return get(0, key); |
247 | } |
248 | |
249 | template <typename CharT> |
250 | uint64_t get(int64_t block, CharT key) const |
251 | { |
252 | assert(block < m_block_count); |
253 | if (key >= 0 && key <= 255) { |
254 | return m_extendedAscii[key * m_block_count + block]; |
255 | } |
256 | else { |
257 | return m_map[block].get(key); |
258 | } |
259 | } |
260 | |
261 | private: |
262 | std::vector<BitvectorHashmap> m_map; |
263 | std::vector<uint64_t> m_extendedAscii; |
264 | int64_t m_block_count; |
265 | }; |
266 | |
267 | /**@}*/ |
268 | |
269 | } // namespace common |
270 | } // namespace duckdb_jaro_winkler |
271 | |