1 | #include <Columns/ColumnString.h> |
2 | #include <Poco/UTF8Encoding.h> |
3 | #include <Common/UTF8Helpers.h> |
4 | |
5 | #ifdef __SSE2__ |
6 | #include <emmintrin.h> |
7 | #endif |
8 | |
9 | |
10 | namespace DB |
11 | { |
12 | |
namespace ErrorCodes
{
    /// Error code of the exception thrown when a FixedString argument is passed (see vector_fixed); defined elsewhere.
    extern const int BAD_ARGUMENTS;
}
17 | |
18 | namespace |
19 | { |
20 | /// xor or do nothing |
21 | template <bool> |
22 | UInt8 xor_or_identity(const UInt8 c, const int mask) |
23 | { |
24 | return c ^ mask; |
25 | } |
26 | |
27 | template <> |
28 | inline UInt8 xor_or_identity<false>(const UInt8 c, const int) |
29 | { |
30 | return c; |
31 | } |
32 | |
33 | /// It is caller's responsibility to ensure the presence of a valid cyrillic sequence in array |
34 | template <bool to_lower> |
35 | inline void UTF8CyrillicToCase(const UInt8 *& src, UInt8 *& dst) |
36 | { |
37 | if (src[0] == 0xD0u && (src[1] >= 0x80u && src[1] <= 0x8Fu)) |
38 | { |
39 | /// ЀЁЂЃЄЅІЇЈЉЊЋЌЍЎЏ |
40 | *dst++ = xor_or_identity<to_lower>(*src++, 0x1); |
41 | *dst++ = xor_or_identity<to_lower>(*src++, 0x10); |
42 | } |
43 | else if (src[0] == 0xD1u && (src[1] >= 0x90u && src[1] <= 0x9Fu)) |
44 | { |
45 | /// ѐёђѓєѕіїјљњћќѝўџ |
46 | *dst++ = xor_or_identity<!to_lower>(*src++, 0x1); |
47 | *dst++ = xor_or_identity<!to_lower>(*src++, 0x10); |
48 | } |
49 | else if (src[0] == 0xD0u && (src[1] >= 0x90u && src[1] <= 0x9Fu)) |
50 | { |
51 | /// А-П |
52 | *dst++ = *src++; |
53 | *dst++ = xor_or_identity<to_lower>(*src++, 0x20); |
54 | } |
55 | else if (src[0] == 0xD0u && (src[1] >= 0xB0u && src[1] <= 0xBFu)) |
56 | { |
57 | /// а-п |
58 | *dst++ = *src++; |
59 | *dst++ = xor_or_identity<!to_lower>(*src++, 0x20); |
60 | } |
61 | else if (src[0] == 0xD0u && (src[1] >= 0xA0u && src[1] <= 0xAFu)) |
62 | { |
63 | /// Р-Я |
64 | *dst++ = xor_or_identity<to_lower>(*src++, 0x1); |
65 | *dst++ = xor_or_identity<to_lower>(*src++, 0x20); |
66 | } |
67 | else if (src[0] == 0xD1u && (src[1] >= 0x80u && src[1] <= 0x8Fu)) |
68 | { |
69 | /// р-я |
70 | *dst++ = xor_or_identity<!to_lower>(*src++, 0x1); |
71 | *dst++ = xor_or_identity<!to_lower>(*src++, 0x20); |
72 | } |
73 | } |
74 | } |
75 | |
76 | |
77 | /** If the string contains UTF-8 encoded text, convert it to the lower (upper) case. |
78 | * Note: It is assumed that after the character is converted to another case, |
79 | * the length of its multibyte sequence in UTF-8 does not change. |
80 | * Otherwise, the behavior is undefined. |
81 | */ |
82 | template <char not_case_lower_bound, |
83 | char not_case_upper_bound, |
84 | int to_case(int), |
85 | void cyrillic_to_case(const UInt8 *&, UInt8 *&)> |
86 | struct LowerUpperUTF8Impl |
87 | { |
88 | static void vector( |
89 | const ColumnString::Chars & data, |
90 | const ColumnString::Offsets & offsets, |
91 | ColumnString::Chars & res_data, |
92 | ColumnString::Offsets & res_offsets) |
93 | { |
94 | res_data.resize(data.size()); |
95 | res_offsets.assign(offsets); |
96 | array(data.data(), data.data() + data.size(), res_data.data()); |
97 | } |
98 | |
99 | static void vector_fixed(const ColumnString::Chars &, size_t, ColumnString::Chars &) |
100 | { |
101 | throw Exception("Functions lowerUTF8 and upperUTF8 cannot work with FixedString argument" , ErrorCodes::BAD_ARGUMENTS); |
102 | } |
103 | |
104 | /** Converts a single code point starting at `src` to desired case, storing result starting at `dst`. |
105 | * `src` and `dst` are incremented by corresponding sequence lengths. */ |
106 | static void toCase(const UInt8 *& src, const UInt8 * src_end, UInt8 *& dst) |
107 | { |
108 | if (src[0] <= ascii_upper_bound) |
109 | { |
110 | if (*src >= not_case_lower_bound && *src <= not_case_upper_bound) |
111 | *dst++ = *src++ ^ flip_case_mask; |
112 | else |
113 | *dst++ = *src++; |
114 | } |
115 | else if (src + 1 < src_end |
116 | && ((src[0] == 0xD0u && (src[1] >= 0x80u && src[1] <= 0xBFu)) || (src[0] == 0xD1u && (src[1] >= 0x80u && src[1] <= 0x9Fu)))) |
117 | { |
118 | cyrillic_to_case(src, dst); |
119 | } |
120 | else if (src + 1 < src_end && src[0] == 0xC2u) |
121 | { |
122 | /// Punctuation U+0080 - U+00BF, UTF-8: C2 80 - C2 BF |
123 | *dst++ = *src++; |
124 | *dst++ = *src++; |
125 | } |
126 | else if (src + 2 < src_end && src[0] == 0xE2u) |
127 | { |
128 | /// Characters U+2000 - U+2FFF, UTF-8: E2 80 80 - E2 BF BF |
129 | *dst++ = *src++; |
130 | *dst++ = *src++; |
131 | *dst++ = *src++; |
132 | } |
133 | else |
134 | { |
135 | static const Poco::UTF8Encoding utf8; |
136 | |
137 | int src_sequence_length = UTF8::seqLength(*src); |
138 | |
139 | int src_code_point = utf8.queryConvert(src, src_end - src); |
140 | if (src_code_point > 0) |
141 | { |
142 | int dst_code_point = to_case(src_code_point); |
143 | if (dst_code_point > 0) |
144 | { |
145 | int dst_sequence_length = utf8.convert(dst_code_point, dst, src_end - src); |
146 | |
147 | /// We don't support cases when lowercase and uppercase characters occupy different number of bytes in UTF-8. |
148 | /// As an example, this happens for ß and ẞ. |
149 | if (dst_sequence_length == src_sequence_length) |
150 | { |
151 | src += dst_sequence_length; |
152 | dst += dst_sequence_length; |
153 | return; |
154 | } |
155 | } |
156 | } |
157 | |
158 | *dst++ = *src++; |
159 | } |
160 | } |
161 | |
162 | private: |
163 | static constexpr auto ascii_upper_bound = '\x7f'; |
164 | static constexpr auto flip_case_mask = 'A' ^ 'a'; |
165 | |
166 | static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst) |
167 | { |
168 | #ifdef __SSE2__ |
169 | static constexpr auto bytes_sse = sizeof(__m128i); |
170 | auto src_end_sse = src + (src_end - src) / bytes_sse * bytes_sse; |
171 | |
172 | /// SSE2 packed comparison operate on signed types, hence compare (c < 0) instead of (c > 0x7f) |
173 | const auto v_zero = _mm_setzero_si128(); |
174 | const auto v_not_case_lower_bound = _mm_set1_epi8(not_case_lower_bound - 1); |
175 | const auto v_not_case_upper_bound = _mm_set1_epi8(not_case_upper_bound + 1); |
176 | const auto v_flip_case_mask = _mm_set1_epi8(flip_case_mask); |
177 | |
178 | while (src < src_end_sse) |
179 | { |
180 | const auto chars = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src)); |
181 | |
182 | /// check for ASCII |
183 | const auto is_not_ascii = _mm_cmplt_epi8(chars, v_zero); |
184 | const auto mask_is_not_ascii = _mm_movemask_epi8(is_not_ascii); |
185 | |
186 | /// ASCII |
187 | if (mask_is_not_ascii == 0) |
188 | { |
189 | const auto is_not_case |
190 | = _mm_and_si128(_mm_cmpgt_epi8(chars, v_not_case_lower_bound), _mm_cmplt_epi8(chars, v_not_case_upper_bound)); |
191 | const auto mask_is_not_case = _mm_movemask_epi8(is_not_case); |
192 | |
193 | /// everything in correct case ASCII |
194 | if (mask_is_not_case == 0) |
195 | _mm_storeu_si128(reinterpret_cast<__m128i *>(dst), chars); |
196 | else |
197 | { |
198 | /// ASCII in mixed case |
199 | /// keep `flip_case_mask` only where necessary, zero out elsewhere |
200 | const auto xor_mask = _mm_and_si128(v_flip_case_mask, is_not_case); |
201 | |
202 | /// flip case by applying calculated mask |
203 | const auto cased_chars = _mm_xor_si128(chars, xor_mask); |
204 | |
205 | /// store result back to destination |
206 | _mm_storeu_si128(reinterpret_cast<__m128i *>(dst), cased_chars); |
207 | } |
208 | |
209 | src += bytes_sse; |
210 | dst += bytes_sse; |
211 | } |
212 | else |
213 | { |
214 | /// UTF-8 |
215 | const auto expected_end = src + bytes_sse; |
216 | |
217 | while (src < expected_end) |
218 | toCase(src, src_end, dst); |
219 | |
220 | /// adjust src_end_sse by pushing it forward or backward |
221 | const auto diff = src - expected_end; |
222 | if (diff != 0) |
223 | { |
224 | if (src_end_sse + diff < src_end) |
225 | src_end_sse += diff; |
226 | else |
227 | src_end_sse -= bytes_sse - diff; |
228 | } |
229 | } |
230 | } |
231 | #endif |
232 | /// handle remaining symbols |
233 | while (src < src_end) |
234 | toCase(src, src_end, dst); |
235 | } |
236 | }; |
237 | |
238 | } |
239 | |