1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include <random> |
19 | #include <string> |
20 | #include <vector> |
21 | |
22 | #include <gtest/gtest.h> |
23 | |
24 | #include "arrow/util/string.h" |
25 | #include "arrow/util/utf8.h" |
26 | |
27 | namespace arrow { |
28 | namespace util { |
29 | |
30 | class UTF8Test : public ::testing::Test { |
31 | protected: |
32 | static void SetUpTestCase() { |
33 | InitializeUTF8(); |
34 | |
35 | all_valid_sequences.clear(); |
36 | for (const auto& v : |
37 | {valid_sequences_1, valid_sequences_2, valid_sequences_3, valid_sequences_4}) { |
38 | all_valid_sequences.insert(all_valid_sequences.end(), v.begin(), v.end()); |
39 | } |
40 | |
41 | all_invalid_sequences.clear(); |
42 | for (const auto& v : {invalid_sequences_1, invalid_sequences_2, invalid_sequences_3, |
43 | invalid_sequences_4}) { |
44 | all_invalid_sequences.insert(all_invalid_sequences.end(), v.begin(), v.end()); |
45 | } |
46 | } |
47 | |
48 | static std::vector<std::string> valid_sequences_1; |
49 | static std::vector<std::string> valid_sequences_2; |
50 | static std::vector<std::string> valid_sequences_3; |
51 | static std::vector<std::string> valid_sequences_4; |
52 | |
53 | static std::vector<std::string> all_valid_sequences; |
54 | |
55 | static std::vector<std::string> invalid_sequences_1; |
56 | static std::vector<std::string> invalid_sequences_2; |
57 | static std::vector<std::string> invalid_sequences_3; |
58 | static std::vector<std::string> invalid_sequences_4; |
59 | |
60 | static std::vector<std::string> all_invalid_sequences; |
61 | }; |
62 | |
63 | std::vector<std::string> UTF8Test::valid_sequences_1 = {"a" , "\x7f" }; |
64 | std::vector<std::string> UTF8Test::valid_sequences_2 = {"\xc2\x80" , "\xc3\xbf" , |
65 | "\xdf\xbf" }; |
66 | std::vector<std::string> UTF8Test::valid_sequences_3 = {"\xe0\xa0\x80" , "\xe8\x9d\xa5" , |
67 | "\xef\xbf\xbf" }; |
68 | std::vector<std::string> UTF8Test::valid_sequences_4 = { |
69 | "\xf0\x90\x80\x80" , "\xf0\x9f\xbf\xbf" , "\xf4\x80\x80\x80" , "\xf4\x8f\xbf\xbf" }; |
70 | |
71 | std::vector<std::string> UTF8Test::all_valid_sequences; |
72 | |
73 | std::vector<std::string> UTF8Test::invalid_sequences_1 = {"\x80" , "\xa0" , "\xbf" , "\xc0" , |
74 | "\xc1" }; |
75 | std::vector<std::string> UTF8Test::invalid_sequences_2 = { |
76 | "\x80\x80" , "\x80\xbf" , "\xbf\x80" , "\xbf\xbf" , |
77 | "\xc1\x80" , "\xc2\x7f" , "\xc3\xff" , "\xdf\xc0" }; |
78 | std::vector<std::string> UTF8Test::invalid_sequences_3 = { |
79 | "\xe0\x80\x80" , "\xe0\x9f\x80" , "\xef\xbf\xc0" , "\xef\xc0\xbf" , "\xef\xff\xff" , |
80 | // Surrogates |
81 | "\xed\xa0\x80" , "\xed\xbf\xbf" }; |
82 | std::vector<std::string> UTF8Test::invalid_sequences_4 = { |
83 | "\xf0\x80\x80\x80" , "\xf0\x8f\x80\x80" , "\xf4\x8f\xbf\xc0" , "\xf4\x8f\xc0\xbf" , |
84 | "\xf4\x90\x80\x80" }; |
85 | |
86 | std::vector<std::string> UTF8Test::all_invalid_sequences; |
87 | |
88 | class UTF8ValidationTest : public UTF8Test {}; |
89 | |
90 | ::testing::AssertionResult IsValidUTF8(const std::string& s) { |
91 | if (ValidateUTF8(reinterpret_cast<const uint8_t*>(s.data()), s.size())) { |
92 | return ::testing::AssertionSuccess(); |
93 | } else { |
94 | std::string h = HexEncode(reinterpret_cast<const uint8_t*>(s.data()), |
95 | static_cast<int32_t>(s.size())); |
96 | return ::testing::AssertionFailure() |
97 | << "string '" << h << "' didn't validate as UTF8" ; |
98 | } |
99 | } |
100 | |
101 | ::testing::AssertionResult IsInvalidUTF8(const std::string& s) { |
102 | if (!ValidateUTF8(reinterpret_cast<const uint8_t*>(s.data()), s.size())) { |
103 | return ::testing::AssertionSuccess(); |
104 | } else { |
105 | std::string h = HexEncode(reinterpret_cast<const uint8_t*>(s.data()), |
106 | static_cast<int32_t>(s.size())); |
107 | return ::testing::AssertionFailure() << "string '" << h << "' validated as UTF8" ; |
108 | } |
109 | } |
110 | |
111 | void AssertValidUTF8(const std::string& s) { ASSERT_TRUE(IsValidUTF8(s)); } |
112 | |
113 | void AssertInvalidUTF8(const std::string& s) { ASSERT_TRUE(IsInvalidUTF8(s)); } |
114 | |
115 | TEST_F(UTF8ValidationTest, EmptyString) { AssertValidUTF8("" ); } |
116 | |
117 | TEST_F(UTF8ValidationTest, OneCharacterValid) { |
118 | for (const auto& s : all_valid_sequences) { |
119 | AssertValidUTF8(s); |
120 | } |
121 | } |
122 | |
123 | TEST_F(UTF8ValidationTest, TwoCharacterValid) { |
124 | for (const auto& s1 : all_valid_sequences) { |
125 | for (const auto& s2 : all_valid_sequences) { |
126 | AssertValidUTF8(s1 + s2); |
127 | } |
128 | } |
129 | } |
130 | |
131 | TEST_F(UTF8ValidationTest, RandomValid) { |
132 | #ifdef ARROW_VALGRIND |
133 | const int niters = 50; |
134 | #else |
135 | const int niters = 1000; |
136 | #endif |
137 | const int nchars = 100; |
138 | std::default_random_engine gen(42); |
139 | std::uniform_int_distribution<size_t> valid_dist(0, all_valid_sequences.size() - 1); |
140 | |
141 | for (int i = 0; i < niters; ++i) { |
142 | std::string s; |
143 | s.reserve(nchars * 4); |
144 | for (int j = 0; j < nchars; ++j) { |
145 | s += all_valid_sequences[valid_dist(gen)]; |
146 | } |
147 | AssertValidUTF8(s); |
148 | } |
149 | } |
150 | |
151 | TEST_F(UTF8ValidationTest, OneCharacterTruncated) { |
152 | for (const auto& s : all_valid_sequences) { |
153 | if (s.size() > 1) { |
154 | AssertInvalidUTF8(s.substr(0, s.size() - 1)); |
155 | } |
156 | } |
157 | } |
158 | |
159 | TEST_F(UTF8ValidationTest, TwoCharacterTruncated) { |
160 | for (const auto& s1 : all_valid_sequences) { |
161 | for (const auto& s2 : all_valid_sequences) { |
162 | if (s2.size() > 1) { |
163 | AssertInvalidUTF8(s1 + s2.substr(0, s2.size() - 1)); |
164 | AssertInvalidUTF8(s2.substr(0, s2.size() - 1) + s1); |
165 | } |
166 | } |
167 | } |
168 | } |
169 | |
170 | TEST_F(UTF8ValidationTest, OneCharacterInvalid) { |
171 | for (const auto& s : all_invalid_sequences) { |
172 | AssertInvalidUTF8(s); |
173 | } |
174 | } |
175 | |
176 | TEST_F(UTF8ValidationTest, TwoCharacterInvalid) { |
177 | for (const auto& s1 : all_valid_sequences) { |
178 | for (const auto& s2 : all_invalid_sequences) { |
179 | AssertInvalidUTF8(s1 + s2); |
180 | AssertInvalidUTF8(s2 + s1); |
181 | } |
182 | } |
183 | for (const auto& s1 : all_invalid_sequences) { |
184 | for (const auto& s2 : all_invalid_sequences) { |
185 | AssertInvalidUTF8(s1 + s2); |
186 | } |
187 | } |
188 | } |
189 | |
190 | TEST_F(UTF8ValidationTest, RandomInvalid) { |
191 | #ifdef ARROW_VALGRIND |
192 | const int niters = 50; |
193 | #else |
194 | const int niters = 1000; |
195 | #endif |
196 | const int nchars = 100; |
197 | std::default_random_engine gen(42); |
198 | std::uniform_int_distribution<size_t> valid_dist(0, all_valid_sequences.size() - 1); |
199 | std::uniform_int_distribution<int> invalid_pos_dist(0, nchars - 1); |
200 | std::uniform_int_distribution<size_t> invalid_dist(0, all_invalid_sequences.size() - 1); |
201 | |
202 | for (int i = 0; i < niters; ++i) { |
203 | std::string s; |
204 | s.reserve(nchars * 4); |
205 | // Stuff a single invalid sequence somewhere in a valid UTF8 stream |
206 | int invalid_pos = invalid_pos_dist(gen); |
207 | for (int j = 0; j < nchars; ++j) { |
208 | if (j == invalid_pos) { |
209 | s += all_invalid_sequences[invalid_dist(gen)]; |
210 | } else { |
211 | s += all_valid_sequences[valid_dist(gen)]; |
212 | } |
213 | } |
214 | AssertInvalidUTF8(s); |
215 | } |
216 | } |
217 | |
218 | TEST_F(UTF8ValidationTest, RandomTruncated) { |
219 | #ifdef ARROW_VALGRIND |
220 | const int niters = 50; |
221 | #else |
222 | const int niters = 1000; |
223 | #endif |
224 | const int nchars = 100; |
225 | std::default_random_engine gen(42); |
226 | std::uniform_int_distribution<size_t> valid_dist(0, all_valid_sequences.size() - 1); |
227 | std::uniform_int_distribution<int> invalid_pos_dist(0, nchars - 1); |
228 | |
229 | for (int i = 0; i < niters; ++i) { |
230 | std::string s; |
231 | s.reserve(nchars * 4); |
232 | // Truncate a single sequence somewhere in a valid UTF8 stream |
233 | int invalid_pos = invalid_pos_dist(gen); |
234 | for (int j = 0; j < nchars; ++j) { |
235 | if (j == invalid_pos) { |
236 | while (true) { |
237 | // Ensure we truncate a 2-byte or more sequence |
238 | const std::string& t = all_valid_sequences[valid_dist(gen)]; |
239 | if (t.size() > 1) { |
240 | s += t.substr(0, t.size() - 1); |
241 | break; |
242 | } |
243 | } |
244 | } else { |
245 | s += all_valid_sequences[valid_dist(gen)]; |
246 | } |
247 | } |
248 | AssertInvalidUTF8(s); |
249 | } |
250 | } |
251 | |
252 | } // namespace util |
253 | } // namespace arrow |
254 | |