1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
2 | // Licensed under the MIT license. |
3 | |
4 | #pragma once |
5 | |
#include <atomic>
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
15 | |
16 | #include "core/auto_ptr.h" |
17 | #include "core/faster.h" |
18 | #include "core/thread.h" |
19 | #include "sum_store.h" |
20 | |
21 | namespace sum_store { |
22 | |
23 | class ConcurrentRecoveryTest { |
24 | public: |
25 | static constexpr uint64_t kNumUniqueKeys = (1L << 22); |
26 | static constexpr uint64_t kKeySpace = (1L << 14); |
27 | static constexpr uint64_t kNumOps = (1L << 25); |
28 | static constexpr uint64_t kRefreshInterval = (1L << 8); |
29 | static constexpr uint64_t kCompletePendingInterval = (1L << 12); |
30 | static constexpr uint64_t kCheckpointInterval = (1L << 22); |
31 | |
32 | ConcurrentRecoveryTest(store_t& store_, size_t num_threads_) |
33 | : store{ store_ } |
34 | , num_threads{ num_threads_ } |
35 | , num_active_threads{ 0 } |
36 | , num_checkpoints{ 0 } { |
37 | } |
38 | |
39 | private: |
40 | static void PopulateWorker(store_t* store, size_t thread_idx, |
41 | std::atomic<size_t>* num_active_threads, size_t num_threads, |
42 | std::atomic<uint32_t>* num_checkpoints) { |
43 | auto callback = [](IAsyncContext* ctxt, Status result) { |
44 | CallbackContext<RmwContext> context{ ctxt }; |
45 | assert(result == Status::Ok); |
46 | }; |
47 | |
48 | auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) { |
49 | if(result != Status::Ok) { |
50 | printf("Thread %" PRIu32 " reports checkpoint failed.\n" , |
51 | Thread::id()); |
52 | } else { |
53 | printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n" , |
54 | Thread::id(), persistent_serial_num); |
55 | } |
56 | }; |
57 | |
58 | // Register thread with the store |
59 | store->StartSession(); |
60 | |
61 | ++(*num_active_threads); |
62 | |
63 | // Process the batch of input data |
64 | for(size_t idx = 0; idx < kNumOps; ++idx) { |
65 | RmwContext context{ idx % kNumUniqueKeys, 1 }; |
66 | store->Rmw(context, callback, idx); |
67 | if(idx % kCheckpointInterval == 0 && *num_active_threads == num_threads) { |
68 | Guid token; |
69 | if(store->Checkpoint(nullptr, hybrid_log_persistence_callback, token)) { |
70 | printf("Thread %" PRIu32 " calling Checkpoint(), version = %" PRIu32 ", token = %s\n" , |
71 | Thread::id(), ++(*num_checkpoints), token.ToString().c_str()); |
72 | } |
73 | } |
74 | if(idx % kCompletePendingInterval == 0) { |
75 | store->CompletePending(false); |
76 | } else if(idx % kRefreshInterval == 0) { |
77 | store->Refresh(); |
78 | } |
79 | } |
80 | |
81 | // Make sure operations are completed |
82 | store->CompletePending(true); |
83 | |
84 | // Deregister thread from FASTER |
85 | store->StopSession(); |
86 | |
87 | printf("Populate successful on thread %" PRIu32 ".\n" , Thread::id()); |
88 | } |
89 | |
90 | public: |
91 | void Populate() { |
92 | std::deque<std::thread> threads; |
93 | for(size_t idx = 0; idx < num_threads; ++idx) { |
94 | threads.emplace_back(&PopulateWorker, &store, idx, &num_active_threads, num_threads, |
95 | &num_checkpoints); |
96 | } |
97 | for(auto& thread : threads) { |
98 | thread.join(); |
99 | } |
100 | // Verify the records. |
101 | auto callback = [](IAsyncContext* ctxt, Status result) { |
102 | CallbackContext<ReadContext> context{ ctxt }; |
103 | assert(result == Status::Ok); |
104 | }; |
105 | // Create array for reading |
106 | auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys); |
107 | std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys); |
108 | |
109 | // Register with thread |
110 | store.StartSession(); |
111 | |
112 | // Issue read requests |
113 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
114 | ReadContext context{ AdId{ idx }, read_results.get() + idx }; |
115 | store.Read(context, callback, idx); |
116 | } |
117 | |
118 | // Complete all pending requests |
119 | store.CompletePending(true); |
120 | |
121 | // Release |
122 | store.StopSession(); |
123 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
124 | uint64_t expected_result = (num_threads * kNumOps) / kNumUniqueKeys; |
125 | if(read_results.get()[idx] != expected_result) { |
126 | printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n" , |
127 | idx, |
128 | expected_result, |
129 | read_results.get()[idx]); |
130 | } |
131 | } |
132 | } |
133 | |
134 | void RecoverAndTest(const Guid& index_token, const Guid& hybrid_log_token) { |
135 | auto callback = [](IAsyncContext* ctxt, Status result) { |
136 | CallbackContext<ReadContext> context{ ctxt }; |
137 | assert(result == Status::Ok); |
138 | }; |
139 | |
140 | // Recover |
141 | uint32_t version; |
142 | std::vector<Guid> session_ids; |
143 | FASTER::core::Status result = store.Recover(index_token, hybrid_log_token, version, |
144 | session_ids); |
145 | if(result != FASTER::core::Status::Ok) { |
146 | printf("Recovery failed with error %u\n" , static_cast<uint8_t>(result)); |
147 | exit(1); |
148 | } |
149 | |
150 | std::vector<uint64_t> serial_nums; |
151 | for(const auto& session_id : session_ids) { |
152 | serial_nums.push_back(store.ContinueSession(session_id)); |
153 | store.StopSession(); |
154 | } |
155 | |
156 | // Create array for reading |
157 | auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys); |
158 | std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys); |
159 | |
160 | // Register with thread |
161 | store.StartSession(); |
162 | |
163 | // Issue read requests |
164 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
165 | ReadContext context{ AdId{ idx}, read_results.get() + idx }; |
166 | store.Read(context, callback, idx); |
167 | } |
168 | |
169 | // Complete all pending requests |
170 | store.CompletePending(true); |
171 | |
172 | // Release |
173 | store.StopSession(); |
174 | |
175 | // Test outputs |
176 | // Compute expected array |
177 | auto expected_results = alloc_aligned<uint64_t>(64, |
178 | sizeof(uint64_t) * kNumUniqueKeys); |
179 | std::memset(expected_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys); |
180 | |
181 | // Sessions that were active during checkpoint: |
182 | for(uint64_t serial_num : serial_nums) { |
183 | for(uint64_t idx = 0; idx <= serial_num; ++idx) { |
184 | ++expected_results.get()[idx % kNumUniqueKeys]; |
185 | } |
186 | } |
187 | // Sessions that were finished at time of checkpoint. |
188 | size_t num_completed = num_threads - serial_nums.size(); |
189 | for(size_t thread_idx = 0; thread_idx < num_completed; ++thread_idx) { |
190 | uint64_t serial_num = kNumOps; |
191 | for(uint64_t idx = 0; idx < serial_num; ++idx) { |
192 | ++expected_results.get()[idx % kNumUniqueKeys]; |
193 | } |
194 | } |
195 | |
196 | // Assert if expected is same as found |
197 | for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) { |
198 | if(expected_results.get()[idx] != read_results.get()[idx]) { |
199 | printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n" , |
200 | idx, |
201 | expected_results.get()[idx], |
202 | read_results.get()[idx]); |
203 | } |
204 | } |
205 | printf("Test successful\n" ); |
206 | } |
207 | |
208 | static void ContinueWorker(store_t* store, size_t thread_idx, |
209 | std::atomic<size_t>* num_active_threads, size_t num_threads, |
210 | std::atomic<uint32_t>* num_checkpoints, Guid guid) { |
211 | auto callback = [](IAsyncContext* ctxt, Status result) { |
212 | CallbackContext<RmwContext> context{ ctxt }; |
213 | assert(result == Status::Ok); |
214 | }; |
215 | |
216 | |
217 | auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) { |
218 | if(result != Status::Ok) { |
219 | printf("Thread %" PRIu32 " reports checkpoint failed.\n" , |
220 | Thread::id()); |
221 | } else { |
222 | printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n" , |
223 | Thread::id(), persistent_serial_num); |
224 | } |
225 | }; |
226 | |
227 | // Register thread with the store |
228 | uint64_t start_num = store->ContinueSession(guid); |
229 | |
230 | ++(*num_active_threads); |
231 | |
232 | // Process the batch of input data |
233 | for(size_t idx = start_num + 1; idx < kNumOps; ++idx) { |
234 | RmwContext context{ idx % kNumUniqueKeys, 1 }; |
235 | store->Rmw(context, callback, idx); |
236 | if(idx % kCheckpointInterval == 0 && *num_active_threads == num_threads) { |
237 | Guid token; |
238 | if(store->Checkpoint(nullptr, hybrid_log_persistence_callback, token)) { |
239 | printf("Thread %" PRIu32 " calling Checkpoint(), version = %" PRIu32 ", token = %s\n" , |
240 | Thread::id(), ++(*num_checkpoints), token.ToString().c_str()); |
241 | } |
242 | } |
243 | if(idx % kCompletePendingInterval == 0) { |
244 | store->CompletePending(false); |
245 | } else if(idx % kRefreshInterval == 0) { |
246 | store->Refresh(); |
247 | } |
248 | } |
249 | |
250 | // Make sure operations are completed |
251 | store->CompletePending(true); |
252 | |
253 | // Deregister thread from FASTER |
254 | store->StopSession(); |
255 | |
256 | printf("Populate successful on thread %" PRIu32 ".\n" , Thread::id()); |
257 | } |
258 | |
259 | void Continue(const Guid& index_token, const Guid& hybrid_log_token) { |
260 | // Recover |
261 | printf("Recovering version (index_token = %s, hybrid_log_token = %s)\n" , |
262 | index_token.ToString().c_str(), hybrid_log_token.ToString().c_str()); |
263 | uint32_t version; |
264 | std::vector<Guid> session_ids; |
265 | FASTER::core::Status result = store.Recover(index_token, hybrid_log_token, version, |
266 | session_ids); |
267 | if(result != FASTER::core::Status::Ok) { |
268 | printf("Recovery failed with error %u\n" , static_cast<uint8_t>(result)); |
269 | exit(1); |
270 | } else { |
271 | printf("Recovery Done!\n" ); |
272 | } |
273 | |
274 | num_checkpoints.store(version); |
275 | // Some threads may have already completed. |
276 | num_threads = session_ids.size(); |
277 | |
278 | std::deque<std::thread> threads; |
279 | for(size_t idx = 0; idx < num_threads; ++idx) { |
280 | threads.emplace_back(&ContinueWorker, &store, idx, &num_active_threads, num_threads, |
281 | &num_checkpoints, session_ids[idx]); |
282 | } |
283 | for(auto& thread : threads) { |
284 | thread.join(); |
285 | } |
286 | } |
287 | |
288 | store_t& store; |
289 | size_t num_threads; |
290 | std::atomic<size_t> num_active_threads; |
291 | std::atomic<uint32_t> num_checkpoints; |
292 | }; |
293 | |
294 | } // namespace sum_store |
295 | |