1#pragma once
2
3#include <cstdint>
4
5#if defined(__SSE2__)
6 #include <emmintrin.h>
7#endif
8#if defined(__SSE4_2__)
9 #include <nmmintrin.h>
10#endif
11
12
13/** find_first_symbols<c1, c2, ...>(begin, end):
14 *
 * Searches a string for the next character that belongs to the set 'symbols...'.
 * It is similar to 'strpbrk', 'strcspn' (and 'strchr', 'memchr' in the case of one symbol and '\0'),
 * but with the following differences:
18 * - works with any memory ranges, including containing zero bytes;
19 * - doesn't require terminating zero byte: end of memory range is passed explicitly;
20 * - if not found, returns pointer to end instead of nullptr;
21 * - maximum number of symbols to search is 16.
22 *
 * Uses SSE 2 when the number of symbols to search for is small and SSE 4.2 when it is large;
 * both have more than a 2x performance advantage over a trivial loop
 * when parsing a tab-separated dump with (probably escaped) string fields.
 * When parsing a tab-separated dump with short strings, there is no performance degradation over a trivial loop.
27 *
28 * Note: the optimal threshold to choose between SSE 2 and SSE 4.2 may depend on CPU model.
29 *
30 * find_last_symbols_or_null<c1, c2, ...>(begin, end):
31 *
 * Searches a string for the last matching character.
 * If there is no such character, returns nullptr.
34 */
35
namespace detail
{

/// Scalar membership test: returns true if `x` equals at least one of `chars...`.
template <char... chars>
inline bool is_in(char x)
{
    return ((x == chars) || ...);
}

#if defined(__SSE2__)
/// Vector membership test for a single symbol.
/// Every byte lane of the result is 0xFF where the corresponding byte of `bytes` equals `s0`, 0x00 otherwise.
template <char s0>
inline __m128i mm_is_in(__m128i bytes)
{
    return _mm_cmpeq_epi8(bytes, _mm_set1_epi8(s0));
}

/// Vector membership test for several symbols: bitwise OR of the per-symbol comparison masks.
template <char s0, char s1, char... tail>
inline __m128i mm_is_in(__m128i bytes)
{
    const __m128i eq_first = _mm_cmpeq_epi8(bytes, _mm_set1_epi8(s0));
    const __m128i eq_rest = mm_is_in<s1, tail...>(bytes);
    return _mm_or_si128(eq_first, eq_rest);
}
#endif

/// Returns `x` unchanged for a positive search and its logical negation for a negative one.
template <bool positive>
inline bool maybe_negate(bool x)
{
    if constexpr (positive)
        return x;
    else
        return !x;
}

/// Same as above for a 16-bit lane mask: the negative search inverts every bit.
template <bool positive>
inline uint16_t maybe_negate(uint16_t x)
{
    if constexpr (positive)
        return x;
    else
        return ~x;
}

/// What the search functions return when nothing is found.
enum class ReturnMode
{
    End,        /// Return the `end` pointer.
    Nullptr,    /// Return nullptr.
};


/** Finds the first byte in [begin, end) that is (positive == true) or is not (positive == false)
  * one of `symbols...`.
  * With SSE 2 available, processes 16 bytes per step; the remaining bytes are handled one by one.
  * If nothing is found, returns `end` or nullptr depending on `return_mode`.
  */
template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_first_symbols_sse2(const char * const begin, const char * const end)
{
    const char * pos = begin;

#if defined(__SSE2__)
    /// Comparing the remaining length (instead of `pos + 15 < end`) never forms a pointer beyond `end`.
    for (; end - pos >= 16; pos += 16)
    {
        const __m128i bytes = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pos));
        const __m128i eq = mm_is_in<symbols...>(bytes);

        /// Bit i of the mask corresponds to byte pos[i].
        const uint16_t bit_mask = maybe_negate<positive>(uint16_t(_mm_movemask_epi8(eq)));
        if (bit_mask)
            return pos + __builtin_ctz(bit_mask);
    }
#endif

    for (; pos < end; ++pos)
        if (maybe_negate<positive>(is_in<symbols...>(*pos)))
            return pos;

    return return_mode == ReturnMode::End ? end : nullptr;
}


/** Finds the last byte in [begin, end) that is (positive == true) or is not (positive == false)
  * one of `symbols...`. The range is scanned backwards from `end`.
  * If nothing is found, returns `end` or nullptr depending on `return_mode`.
  * An empty range (including begin == end == nullptr) is handled without touching memory.
  */
template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_last_symbols_sse2(const char * const begin, const char * const end)
{
    const char * pos = end;

#if defined(__SSE2__)
    /// Comparing the remaining length (instead of `pos - 16 >= begin`) never forms a pointer before `begin`.
    for (; pos - begin >= 16; pos -= 16)
    {
        const __m128i bytes = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pos - 16));
        const __m128i eq = mm_is_in<symbols...>(bytes);

        const uint16_t bit_mask = maybe_negate<positive>(uint16_t(_mm_movemask_epi8(eq)));
        if (bit_mask)
            return pos - 1 - (__builtin_clz(bit_mask) - 16); /// because __builtin_clz works with mask as uint32.
    }
#endif

    /// Decrement inside the loop so that `pos` never goes below `begin`.
    while (pos > begin)
    {
        --pos;
        if (maybe_negate<positive>(is_in<symbols...>(*pos)))
            return pos;
    }

    return return_mode == ReturnMode::End ? end : nullptr;
}


/** Same contract as find_first_symbols_sse2, for up to 16 symbols passed as separate template arguments.
  * Only the first `num_chars` of c01..c16 are meaningful; the rest are padding and never match.
  * With SSE 4.2 available, uses the explicit-length string comparison instructions for 16-byte steps;
  * the remaining bytes are handled one by one.
  */
template <bool positive, ReturnMode return_mode, size_t num_chars,
    char c01,     char c02 = 0, char c03 = 0, char c04 = 0,
    char c05 = 0, char c06 = 0, char c07 = 0, char c08 = 0,
    char c09 = 0, char c10 = 0, char c11 = 0, char c12 = 0,
    char c13 = 0, char c14 = 0, char c15 = 0, char c16 = 0>
inline const char * find_first_symbols_sse42_impl(const char * const begin, const char * const end)
{
    const char * pos = begin;

#if defined(__SSE4_2__)
    constexpr int mode = _SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT;

    const __m128i set = _mm_setr_epi8(c01, c02, c03, c04, c05, c06, c07, c08, c09, c10, c11, c12, c13, c14, c15, c16);

    /// Comparing the remaining length (instead of `pos + 15 < end`) never forms a pointer beyond `end`.
    for (; end - pos >= 16; pos += 16)
    {
        const __m128i bytes = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pos));

        if constexpr (positive)
        {
            if (_mm_cmpestrc(set, num_chars, bytes, 16, mode))
                return pos + _mm_cmpestri(set, num_chars, bytes, 16, mode);
        }
        else
        {
            if (_mm_cmpestrc(set, num_chars, bytes, 16, mode | _SIDD_NEGATIVE_POLARITY))
                return pos + _mm_cmpestri(set, num_chars, bytes, 16, mode | _SIDD_NEGATIVE_POLARITY);
        }
    }
#endif

    for (; pos < end; ++pos)
    {
        const char c = *pos;

        /// Membership is computed first and negated as a whole afterwards.
        /// Negating each comparison separately would accept every byte as soon as
        /// the set contains two different symbols.
        const bool in_set
            =  (num_chars >= 1 && c == c01)
            || (num_chars >= 2 && c == c02)
            || (num_chars >= 3 && c == c03)
            || (num_chars >= 4 && c == c04)
            || (num_chars >= 5 && c == c05)
            || (num_chars >= 6 && c == c06)
            || (num_chars >= 7 && c == c07)
            || (num_chars >= 8 && c == c08)
            || (num_chars >= 9 && c == c09)
            || (num_chars >= 10 && c == c10)
            || (num_chars >= 11 && c == c11)
            || (num_chars >= 12 && c == c12)
            || (num_chars >= 13 && c == c13)
            || (num_chars >= 14 && c == c14)
            || (num_chars >= 15 && c == c15)
            || (num_chars >= 16 && c == c16);

        if (maybe_negate<positive>(in_set))
            return pos;
    }

    return return_mode == ReturnMode::End ? end : nullptr;
}


/// Adapter: passes the number of symbols and the symbols themselves to find_first_symbols_sse42_impl.
template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_first_symbols_sse42(const char * begin, const char * end)
{
    return find_first_symbols_sse42_impl<positive, return_mode, sizeof...(symbols), symbols...>(begin, end);
}

/// NOTE There is no SSE 4.2 implementation for find_last_symbols_or_null. It is not worth doing.

/// Chooses the implementation: SSE 4.2 (when compiled in) for 5 or more symbols, SSE 2 / scalar otherwise.
template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_first_symbols_dispatch(const char * begin, const char * end)
{
#if defined(__SSE4_2__)
    if constexpr (sizeof...(symbols) >= 5)
        return find_first_symbols_sse42<positive, return_mode, symbols...>(begin, end);
    else
#endif
        return find_first_symbols_sse2<positive, return_mode, symbols...>(begin, end);
}

}
219
220
221template <char... symbols>
222inline const char * find_first_symbols(const char * begin, const char * end)
223{
224 return detail::find_first_symbols_dispatch<true, detail::ReturnMode::End, symbols...>(begin, end);
225}
226
227/// Returning non const result for non const arguments.
228/// It is convenient when you are using this function to iterate through non-const buffer.
229template <char... symbols>
230inline char * find_first_symbols(char * begin, char * end)
231{
232 return const_cast<char *>(detail::find_first_symbols_dispatch<true, detail::ReturnMode::End, symbols...>(begin, end));
233}
234
235template <char... symbols>
236inline const char * find_first_not_symbols(const char * begin, const char * end)
237{
238 return detail::find_first_symbols_dispatch<false, detail::ReturnMode::End, symbols...>(begin, end);
239}
240
241template <char... symbols>
242inline char * find_first_not_symbols(char * begin, char * end)
243{
244 return const_cast<char *>(detail::find_first_symbols_dispatch<false, detail::ReturnMode::End, symbols...>(begin, end));
245}
246
247template <char... symbols>
248inline const char * find_first_symbols_or_null(const char * begin, const char * end)
249{
250 return detail::find_first_symbols_dispatch<true, detail::ReturnMode::Nullptr, symbols...>(begin, end);
251}
252
253template <char... symbols>
254inline char * find_first_symbols_or_null(char * begin, char * end)
255{
256 return const_cast<char *>(detail::find_first_symbols_dispatch<true, detail::ReturnMode::Nullptr, symbols...>(begin, end));
257}
258
259template <char... symbols>
260inline const char * find_first_not_symbols_or_null(const char * begin, const char * end)
261{
262 return detail::find_first_symbols_dispatch<false, detail::ReturnMode::Nullptr, symbols...>(begin, end);
263}
264
265template <char... symbols>
266inline char * find_first_not_symbols_or_null(char * begin, char * end)
267{
268 return const_cast<char *>(detail::find_first_symbols_dispatch<false, detail::ReturnMode::Nullptr, symbols...>(begin, end));
269}
270
271
272template <char... symbols>
273inline const char * find_last_symbols_or_null(const char * begin, const char * end)
274{
275 return detail::find_last_symbols_sse2<true, detail::ReturnMode::Nullptr, symbols...>(begin, end);
276}
277
278template <char... symbols>
279inline char * find_last_symbols_or_null(char * begin, char * end)
280{
281 return const_cast<char *>(detail::find_last_symbols_sse2<true, detail::ReturnMode::Nullptr, symbols...>(begin, end));
282}
283
284template <char... symbols>
285inline const char * find_last_not_symbols_or_null(const char * begin, const char * end)
286{
287 return detail::find_last_symbols_sse2<false, detail::ReturnMode::Nullptr, symbols...>(begin, end);
288}
289
290template <char... symbols>
291inline char * find_last_not_symbols_or_null(char * begin, char * end)
292{
293 return const_cast<char *>(detail::find_last_symbols_sse2<false, detail::ReturnMode::Nullptr, symbols...>(begin, end));
294}
295