1 | /* |
2 | * Copyright 2011-present Facebook, Inc. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | // @author: Xin Liu <xliux@fb.com> |
18 | |
19 | #include <folly/ConcurrentSkipList.h> |
20 | |
21 | #include <atomic> |
22 | #include <memory> |
23 | #include <set> |
24 | #include <system_error> |
25 | #include <thread> |
26 | #include <vector> |
27 | |
28 | #include <glog/logging.h> |
29 | |
30 | #include <folly/Memory.h> |
31 | #include <folly/String.h> |
32 | #include <folly/container/Foreach.h> |
33 | #include <folly/memory/Arena.h> |
34 | #include <folly/portability/GFlags.h> |
35 | #include <folly/portability/GTest.h> |
36 | |
// Number of worker threads spawned by testConcurrentAccess(); read there as
// FLAGS_num_threads and parsed from the command line in main().
DEFINE_int32(num_threads, 12, "num concurrent threads to test" );
38 | |
39 | namespace { |
40 | |
41 | template <typename ParentAlloc> |
42 | struct ParanoidArenaAlloc { |
43 | explicit ParanoidArenaAlloc(ParentAlloc& arena) : arena_(arena) {} |
44 | ParanoidArenaAlloc(ParanoidArenaAlloc const&) = delete; |
45 | ParanoidArenaAlloc(ParanoidArenaAlloc&&) = delete; |
46 | ParanoidArenaAlloc& operator=(ParanoidArenaAlloc const&) = delete; |
47 | ParanoidArenaAlloc& operator=(ParanoidArenaAlloc&&) = delete; |
48 | |
49 | void* allocate(size_t size) { |
50 | void* result = arena_.get().allocate(size); |
51 | allocated_.insert(result); |
52 | return result; |
53 | } |
54 | |
55 | void deallocate(void* ptr, size_t n) { |
56 | EXPECT_EQ(1, allocated_.erase(ptr)); |
57 | arena_.get().deallocate(ptr, n); |
58 | } |
59 | |
60 | bool isEmpty() const { |
61 | return allocated_.empty(); |
62 | } |
63 | |
64 | std::reference_wrapper<ParentAlloc> arena_; |
65 | std::set<void*> allocated_; |
66 | }; |
67 | } // namespace |
68 | |
69 | namespace folly { |
70 | template <typename ParentAlloc> |
71 | struct AllocatorHasTrivialDeallocate<ParanoidArenaAlloc<ParentAlloc>> |
72 | : AllocatorHasTrivialDeallocate<ParentAlloc> {}; |
73 | } // namespace folly |
74 | |
75 | namespace { |
76 | |
77 | using namespace folly; |
78 | using std::vector; |
79 | |
80 | typedef int ValueType; |
81 | typedef detail::SkipListNode<ValueType> SkipListNodeType; |
82 | typedef ConcurrentSkipList<ValueType> SkipListType; |
83 | typedef SkipListType::Accessor SkipListAccessor; |
84 | typedef vector<ValueType> VectorType; |
85 | typedef std::set<ValueType> SetType; |
86 | |
// Height argument passed to SkipListType::create() by the int-based tests.
static constexpr int kHeadHeight = 2;
// Default modulus applied to rand() when generating test values.
static constexpr int kMaxValue = 5000;
89 | |
90 | static void randomAdding( |
91 | int size, |
92 | SkipListAccessor skipList, |
93 | SetType* verifier, |
94 | int maxValue = kMaxValue) { |
95 | for (int i = 0; i < size; ++i) { |
96 | int32_t r = rand() % maxValue; |
97 | verifier->insert(r); |
98 | skipList.add(r); |
99 | } |
100 | } |
101 | |
102 | static void randomRemoval( |
103 | int size, |
104 | SkipListAccessor skipList, |
105 | SetType* verifier, |
106 | int maxValue = kMaxValue) { |
107 | for (int i = 0; i < size; ++i) { |
108 | int32_t r = rand() % maxValue; |
109 | verifier->insert(r); |
110 | skipList.remove(r); |
111 | } |
112 | } |
113 | |
114 | static void sumAllValues(SkipListAccessor skipList, int64_t* sum) { |
115 | *sum = 0; |
116 | FOR_EACH (it, skipList) { *sum += *it; } |
117 | VLOG(20) << "sum = " << sum; |
118 | } |
119 | |
120 | static void concurrentSkip( |
121 | const vector<ValueType>* values, |
122 | SkipListAccessor skipList) { |
123 | int64_t sum = 0; |
124 | SkipListAccessor::Skipper skipper(skipList); |
125 | FOR_EACH (it, *values) { |
126 | if (skipper.to(*it)) { |
127 | sum += *it; |
128 | } |
129 | } |
130 | VLOG(20) << "sum = " << sum; |
131 | } |
132 | |
133 | bool verifyEqual(SkipListAccessor skipList, const SetType& verifier) { |
134 | EXPECT_EQ(verifier.size(), skipList.size()); |
135 | FOR_EACH (it, verifier) { |
136 | CHECK(skipList.contains(*it)) << *it; |
137 | SkipListType::const_iterator iter = skipList.find(*it); |
138 | CHECK(iter != skipList.end()); |
139 | EXPECT_EQ(*iter, *it); |
140 | } |
141 | EXPECT_TRUE(std::equal(verifier.begin(), verifier.end(), skipList.begin())); |
142 | return true; |
143 | } |
144 | |
// Single-threaded walk through the accessor API: add/insert/remove/pop_back,
// first/last, find/contains, lower_bound, iteration and Skipper, followed by
// a randomized comparison against a std::set.
TEST(ConcurrentSkipList, SequentialAccess) {
  {
    LOG(INFO) << "nodetype size=" << sizeof(SkipListNodeType);

    auto skipList(SkipListType::create(kHeadHeight));
    // A freshly created list has neither a first nor a last element.
    EXPECT_TRUE(skipList.first() == nullptr);
    EXPECT_TRUE(skipList.last() == nullptr);

    // With a single element, first() and last() refer to the same value.
    skipList.add(3);
    EXPECT_TRUE(skipList.contains(3));
    EXPECT_FALSE(skipList.contains(2));
    EXPECT_EQ(3, *skipList.first());
    EXPECT_EQ(3, *skipList.last());

    // find() yields end() for a value that is not present.
    EXPECT_EQ(3, *skipList.find(3));
    EXPECT_FALSE(skipList.find(3) == skipList.end());
    EXPECT_TRUE(skipList.find(2) == skipList.end());

    {
      // A skipper moved to an existing value dereferences to that value.
      SkipListAccessor::Skipper skipper(skipList);
      skipper.to(3);
      CHECK_EQ(3, *skipper);
    }

    skipList.add(2);
    EXPECT_EQ(2, *skipList.first());
    EXPECT_EQ(3, *skipList.last());
    skipList.add(5);
    EXPECT_EQ(5, *skipList.last());
    // Adding 3 a second time leaves last() unchanged.
    skipList.add(3);
    EXPECT_EQ(5, *skipList.last());
    // insert() returns a pair; .second is true for a value not yet present...
    auto ret = skipList.insert(9);
    EXPECT_EQ(9, *ret.first);
    EXPECT_TRUE(ret.second);

    // ...and false for a value that is already in the list.
    ret = skipList.insert(5);
    EXPECT_EQ(5, *ret.first);
    EXPECT_FALSE(ret.second);

    // pop_back() drops the current last element: first 9, then 5.
    EXPECT_EQ(2, *skipList.first());
    EXPECT_EQ(9, *skipList.last());
    EXPECT_TRUE(skipList.pop_back());
    EXPECT_EQ(5, *skipList.last());
    EXPECT_TRUE(skipList.pop_back());
    EXPECT_EQ(3, *skipList.last());

    // Put the popped values back; the list now holds {2, 3, 5, 9}.
    skipList.add(9);
    skipList.add(5);

    CHECK(skipList.contains(2));
    CHECK(skipList.contains(3));
    CHECK(skipList.contains(5));
    CHECK(skipList.contains(9));
    CHECK(!skipList.contains(4));

    // lower_bound: lands on the first element not less than the argument;
    // past the largest element the returned iterator is not good().
    auto it = skipList.lower_bound(5);
    EXPECT_EQ(5, *it);
    it = skipList.lower_bound(4);
    EXPECT_EQ(5, *it);
    it = skipList.lower_bound(9);
    EXPECT_EQ(9, *it);
    it = skipList.lower_bound(12);
    EXPECT_FALSE(it.good());

    it = skipList.begin();
    EXPECT_EQ(2, *it);

    // skipper test: to() returns false for a value (7) that is not present.
    SkipListAccessor::Skipper skipper(skipList);
    skipper.to(3);
    EXPECT_EQ(3, skipper.data());
    skipper.to(5);
    EXPECT_EQ(5, skipper.data());
    CHECK(!skipper.to(7));

    // Removing values the skipper has already visited must not stop it from
    // reaching a later value.
    skipList.remove(5);
    skipList.remove(3);
    CHECK(skipper.to(9));
    EXPECT_EQ(9, skipper.data());

    // A removed value can be added again.
    CHECK(!skipList.contains(3));
    skipList.add(3);
    CHECK(skipList.contains(3));
    int pos = 0;
    for (auto entry : skipList) {
      LOG(INFO) << "pos= " << pos++ << " value= " << entry;
    }
  }

  {
    auto skipList(SkipListType::create(kHeadHeight));

    // Fill with random values and compare against the std::set mirror.
    SetType verifier;
    randomAdding(10000, skipList, &verifier);
    verifyEqual(skipList, verifier);

    // test skipper: sweep evenly spaced, increasing targets; to() must report
    // true exactly for those targets that the verifier set contains.
    SkipListAccessor::Skipper skipper(skipList);
    int num_skips = 1000;
    for (int i = 0; i < num_skips; ++i) {
      int n = i * kMaxValue / num_skips;
      bool found = skipper.to(n);
      EXPECT_EQ(found, (verifier.find(n) != verifier.end()));
    }
  }
}
252 | |
// Builds a string of `len` characters, each chosen as 'A' + rand() % 26.
// A non-positive `len` yields an empty string.
static std::string makeRandomeString(int len) {
  std::string result;
  for (int remaining = len; remaining > 0; --remaining) {
    const int letterIndex = rand() % 26;
    result.push_back(letterIndex + 'A');
  }
  return result;
}
260 | |
261 | TEST(ConcurrentSkipList, TestStringType) { |
262 | typedef folly::ConcurrentSkipList<std::string> SkipListT; |
263 | std::shared_ptr<SkipListT> skip = SkipListT::createInstance(); |
264 | SkipListT::Accessor accessor(skip); |
265 | { |
266 | for (int i = 0; i < 100000; i++) { |
267 | std::string s = makeRandomeString(7); |
268 | accessor.insert(s); |
269 | } |
270 | } |
271 | EXPECT_TRUE(std::is_sorted(accessor.begin(), accessor.end())); |
272 | } |
273 | |
// Strict "less than" over unique_ptr<int> that compares the pointed-to ints.
// A null pointer is never less than anything, and every non-null pointer is
// less than a null one.
struct UniquePtrComp {
  bool operator()(const std::unique_ptr<int>& x, const std::unique_ptr<int>& y)
      const {
    const bool lhsNull = (x == nullptr);
    const bool rhsNull = (y == nullptr);
    return !lhsNull && (rhsNull || *x < *y);
  }
};
286 | |
287 | TEST(ConcurrentSkipList, TestMovableData) { |
288 | typedef folly::ConcurrentSkipList<std::unique_ptr<int>, UniquePtrComp> |
289 | SkipListT; |
290 | auto sl = SkipListT::createInstance(); |
291 | SkipListT::Accessor accessor(sl); |
292 | |
293 | static const int N = 10; |
294 | for (int i = 0; i < N; ++i) { |
295 | accessor.insert(std::make_unique<int>(i)); |
296 | } |
297 | |
298 | for (int i = 0; i < N; ++i) { |
299 | EXPECT_TRUE( |
300 | accessor.find(std::unique_ptr<int>(new int(i))) != accessor.end()); |
301 | } |
302 | EXPECT_TRUE( |
303 | accessor.find(std::unique_ptr<int>(new int(N))) == accessor.end()); |
304 | } |
305 | |
306 | void testConcurrentAdd(int numThreads) { |
307 | auto skipList(SkipListType::create(kHeadHeight)); |
308 | |
309 | vector<std::thread> threads; |
310 | vector<SetType> verifiers(numThreads); |
311 | try { |
312 | for (int i = 0; i < numThreads; ++i) { |
313 | threads.push_back( |
314 | std::thread(&randomAdding, 100, skipList, &verifiers[i], kMaxValue)); |
315 | } |
316 | } catch (const std::system_error& e) { |
317 | LOG(WARNING) << "Caught " << exceptionStr(e) << ": could only create " |
318 | << threads.size() << " threads out of " << numThreads; |
319 | } |
320 | for (size_t i = 0; i < threads.size(); ++i) { |
321 | threads[i].join(); |
322 | } |
323 | |
324 | SetType all; |
325 | FOR_EACH (s, verifiers) { all.insert(s->begin(), s->end()); } |
326 | verifyEqual(skipList, all); |
327 | } |
328 | |
329 | TEST(ConcurrentSkipList, ConcurrentAdd) { |
330 | // test it many times |
331 | for (int numThreads = 10; numThreads < 10000; numThreads += 1000) { |
332 | testConcurrentAdd(numThreads); |
333 | } |
334 | } |
335 | |
336 | void testConcurrentRemoval(int numThreads, int maxValue) { |
337 | auto skipList = SkipListType::create(kHeadHeight); |
338 | for (int i = 0; i < maxValue; ++i) { |
339 | skipList.add(i); |
340 | } |
341 | |
342 | vector<std::thread> threads; |
343 | vector<SetType> verifiers(numThreads); |
344 | try { |
345 | for (int i = 0; i < numThreads; ++i) { |
346 | threads.push_back( |
347 | std::thread(&randomRemoval, 100, skipList, &verifiers[i], maxValue)); |
348 | } |
349 | } catch (const std::system_error& e) { |
350 | LOG(WARNING) << "Caught " << exceptionStr(e) << ": could only create " |
351 | << threads.size() << " threads out of " << numThreads; |
352 | } |
353 | FOR_EACH (t, threads) { (*t).join(); } |
354 | |
355 | SetType all; |
356 | FOR_EACH (s, verifiers) { all.insert(s->begin(), s->end()); } |
357 | |
358 | CHECK_EQ(maxValue, all.size() + skipList.size()); |
359 | for (int i = 0; i < maxValue; ++i) { |
360 | if (all.find(i) != all.end()) { |
361 | CHECK(!skipList.contains(i)) << i; |
362 | } else { |
363 | CHECK(skipList.contains(i)) << i; |
364 | } |
365 | } |
366 | } |
367 | |
368 | TEST(ConcurrentSkipList, ConcurrentRemove) { |
369 | for (int numThreads = 10; numThreads < 1000; numThreads += 100) { |
370 | testConcurrentRemoval(numThreads, 100 * numThreads); |
371 | } |
372 | } |
373 | |
374 | static void |
375 | testConcurrentAccess(int numInsertions, int numDeletions, int maxValue) { |
376 | auto skipList = SkipListType::create(kHeadHeight); |
377 | |
378 | vector<SetType> verifiers(FLAGS_num_threads); |
379 | vector<int64_t> sums(FLAGS_num_threads); |
380 | vector<vector<ValueType>> skipValues(FLAGS_num_threads); |
381 | |
382 | for (int i = 0; i < FLAGS_num_threads; ++i) { |
383 | for (int j = 0; j < numInsertions; ++j) { |
384 | skipValues[i].push_back(rand() % (maxValue + 1)); |
385 | } |
386 | std::sort(skipValues[i].begin(), skipValues[i].end()); |
387 | } |
388 | |
389 | vector<std::thread> threads; |
390 | for (int i = 0; i < FLAGS_num_threads; ++i) { |
391 | switch (i % 8) { |
392 | case 0: |
393 | case 1: |
394 | threads.push_back(std::thread( |
395 | randomAdding, numInsertions, skipList, &verifiers[i], maxValue)); |
396 | break; |
397 | case 2: |
398 | threads.push_back(std::thread( |
399 | randomRemoval, numDeletions, skipList, &verifiers[i], maxValue)); |
400 | break; |
401 | case 3: |
402 | threads.push_back( |
403 | std::thread(concurrentSkip, &skipValues[i], skipList)); |
404 | break; |
405 | default: |
406 | threads.push_back(std::thread(sumAllValues, skipList, &sums[i])); |
407 | break; |
408 | } |
409 | } |
410 | |
411 | FOR_EACH (t, threads) { (*t).join(); } |
412 | // just run through it, no need to verify the correctness. |
413 | } |
414 | |
415 | TEST(ConcurrentSkipList, ConcurrentAccess) { |
416 | testConcurrentAccess(10000, 100, kMaxValue); |
417 | testConcurrentAccess(100000, 10000, kMaxValue * 10); |
418 | testConcurrentAccess(1000000, 100000, kMaxValue); |
419 | } |
420 | |
421 | struct NonTrivialValue { |
422 | static std::atomic<int> InstanceCounter; |
423 | static const int kBadPayLoad; |
424 | |
425 | NonTrivialValue() : payload_(kBadPayLoad) { |
426 | ++InstanceCounter; |
427 | } |
428 | |
429 | explicit NonTrivialValue(int payload) : payload_(payload) { |
430 | ++InstanceCounter; |
431 | } |
432 | |
433 | NonTrivialValue(const NonTrivialValue& rhs) : payload_(rhs.payload_) { |
434 | ++InstanceCounter; |
435 | } |
436 | |
437 | NonTrivialValue& operator=(const NonTrivialValue& rhs) { |
438 | payload_ = rhs.payload_; |
439 | return *this; |
440 | } |
441 | |
442 | ~NonTrivialValue() { |
443 | --InstanceCounter; |
444 | } |
445 | |
446 | bool operator<(const NonTrivialValue& rhs) const { |
447 | EXPECT_NE(kBadPayLoad, payload_); |
448 | EXPECT_NE(kBadPayLoad, rhs.payload_); |
449 | return payload_ < rhs.payload_; |
450 | } |
451 | |
452 | private: |
453 | int payload_; |
454 | }; |
455 | |
456 | std::atomic<int> NonTrivialValue::InstanceCounter(0); |
457 | const int NonTrivialValue::kBadPayLoad = 0xDEADBEEF; |
458 | |
459 | template <typename SkipListPtrType> |
460 | void TestNonTrivialDeallocation(SkipListPtrType& list) { |
461 | { |
462 | auto accessor = typename SkipListPtrType::element_type::Accessor(list); |
463 | static const size_t N = 10000; |
464 | for (size_t i = 0; i < N; ++i) { |
465 | accessor.add(NonTrivialValue(i)); |
466 | } |
467 | list.reset(); |
468 | } |
469 | EXPECT_EQ(0, NonTrivialValue::InstanceCounter); |
470 | } |
471 | |
472 | template <typename ParentAlloc> |
473 | void NonTrivialDeallocationWithParanoid(ParentAlloc& parentAlloc) { |
474 | using ParanoidAlloc = ParanoidArenaAlloc<ParentAlloc>; |
475 | using Alloc = CxxAllocatorAdaptor<void, ParanoidAlloc>; |
476 | using ParanoidSkipListType = |
477 | ConcurrentSkipList<NonTrivialValue, std::less<NonTrivialValue>, Alloc>; |
478 | ParanoidAlloc paranoidAlloc(parentAlloc); |
479 | Alloc alloc(paranoidAlloc); |
480 | auto list = ParanoidSkipListType::createInstance(10, alloc); |
481 | TestNonTrivialDeallocation(list); |
482 | EXPECT_TRUE(paranoidAlloc.isEmpty()); |
483 | } |
484 | |
485 | TEST(ConcurrentSkipList, NonTrivialDeallocationWithParanoidSysAlloc) { |
486 | SysAllocator<void> alloc; |
487 | NonTrivialDeallocationWithParanoid(alloc); |
488 | } |
489 | |
490 | TEST(ConcurrentSkipList, NonTrivialDeallocationWithParanoidSysArena) { |
491 | SysArena arena; |
492 | SysArenaAllocator<void> alloc(arena); |
493 | NonTrivialDeallocationWithParanoid(alloc); |
494 | } |
495 | |
496 | TEST(ConcurrentSkipList, NonTrivialDeallocationWithSysArena) { |
497 | using SysArenaSkipListType = ConcurrentSkipList< |
498 | NonTrivialValue, |
499 | std::less<NonTrivialValue>, |
500 | SysArenaAllocator<void>>; |
501 | SysArena arena; |
502 | SysArenaAllocator<void> alloc(arena); |
503 | auto list = SysArenaSkipListType::createInstance(10, alloc); |
504 | TestNonTrivialDeallocation(list); |
505 | } |
506 | |
507 | } // namespace |
508 | |
509 | int main(int argc, char* argv[]) { |
510 | testing::InitGoogleTest(&argc, argv); |
511 | google::InitGoogleLogging(argv[0]); |
512 | gflags::ParseCommandLineFlags(&argc, &argv, true); |
513 | |
514 | return RUN_ALL_TESTS(); |
515 | } |
516 | |