1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4#pragma once
5
#include <atomic>
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
#include <vector>

#include "core/auto_ptr.h"
#include "core/faster.h"
#include "core/thread.h"
#include "sum_store.h"
20
21namespace sum_store {
22
23class ConcurrentRecoveryTest {
24 public:
25 static constexpr uint64_t kNumUniqueKeys = (1L << 22);
26 static constexpr uint64_t kKeySpace = (1L << 14);
27 static constexpr uint64_t kNumOps = (1L << 25);
28 static constexpr uint64_t kRefreshInterval = (1L << 8);
29 static constexpr uint64_t kCompletePendingInterval = (1L << 12);
30 static constexpr uint64_t kCheckpointInterval = (1L << 22);
31
32 ConcurrentRecoveryTest(store_t& store_, size_t num_threads_)
33 : store{ store_ }
34 , num_threads{ num_threads_ }
35 , num_active_threads{ 0 }
36 , num_checkpoints{ 0 } {
37 }
38
39 private:
40 static void PopulateWorker(store_t* store, size_t thread_idx,
41 std::atomic<size_t>* num_active_threads, size_t num_threads,
42 std::atomic<uint32_t>* num_checkpoints) {
43 auto callback = [](IAsyncContext* ctxt, Status result) {
44 CallbackContext<RmwContext> context{ ctxt };
45 assert(result == Status::Ok);
46 };
47
48 auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) {
49 if(result != Status::Ok) {
50 printf("Thread %" PRIu32 " reports checkpoint failed.\n",
51 Thread::id());
52 } else {
53 printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n",
54 Thread::id(), persistent_serial_num);
55 }
56 };
57
58 // Register thread with the store
59 store->StartSession();
60
61 ++(*num_active_threads);
62
63 // Process the batch of input data
64 for(size_t idx = 0; idx < kNumOps; ++idx) {
65 RmwContext context{ idx % kNumUniqueKeys, 1 };
66 store->Rmw(context, callback, idx);
67 if(idx % kCheckpointInterval == 0 && *num_active_threads == num_threads) {
68 Guid token;
69 if(store->Checkpoint(nullptr, hybrid_log_persistence_callback, token)) {
70 printf("Thread %" PRIu32 " calling Checkpoint(), version = %" PRIu32 ", token = %s\n",
71 Thread::id(), ++(*num_checkpoints), token.ToString().c_str());
72 }
73 }
74 if(idx % kCompletePendingInterval == 0) {
75 store->CompletePending(false);
76 } else if(idx % kRefreshInterval == 0) {
77 store->Refresh();
78 }
79 }
80
81 // Make sure operations are completed
82 store->CompletePending(true);
83
84 // Deregister thread from FASTER
85 store->StopSession();
86
87 printf("Populate successful on thread %" PRIu32 ".\n", Thread::id());
88 }
89
90 public:
91 void Populate() {
92 std::deque<std::thread> threads;
93 for(size_t idx = 0; idx < num_threads; ++idx) {
94 threads.emplace_back(&PopulateWorker, &store, idx, &num_active_threads, num_threads,
95 &num_checkpoints);
96 }
97 for(auto& thread : threads) {
98 thread.join();
99 }
100 // Verify the records.
101 auto callback = [](IAsyncContext* ctxt, Status result) {
102 CallbackContext<ReadContext> context{ ctxt };
103 assert(result == Status::Ok);
104 };
105 // Create array for reading
106 auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys);
107 std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys);
108
109 // Register with thread
110 store.StartSession();
111
112 // Issue read requests
113 for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) {
114 ReadContext context{ AdId{ idx }, read_results.get() + idx };
115 store.Read(context, callback, idx);
116 }
117
118 // Complete all pending requests
119 store.CompletePending(true);
120
121 // Release
122 store.StopSession();
123 for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) {
124 uint64_t expected_result = (num_threads * kNumOps) / kNumUniqueKeys;
125 if(read_results.get()[idx] != expected_result) {
126 printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n",
127 idx,
128 expected_result,
129 read_results.get()[idx]);
130 }
131 }
132 }
133
134 void RecoverAndTest(const Guid& index_token, const Guid& hybrid_log_token) {
135 auto callback = [](IAsyncContext* ctxt, Status result) {
136 CallbackContext<ReadContext> context{ ctxt };
137 assert(result == Status::Ok);
138 };
139
140 // Recover
141 uint32_t version;
142 std::vector<Guid> session_ids;
143 FASTER::core::Status result = store.Recover(index_token, hybrid_log_token, version,
144 session_ids);
145 if(result != FASTER::core::Status::Ok) {
146 printf("Recovery failed with error %u\n", static_cast<uint8_t>(result));
147 exit(1);
148 }
149
150 std::vector<uint64_t> serial_nums;
151 for(const auto& session_id : session_ids) {
152 serial_nums.push_back(store.ContinueSession(session_id));
153 store.StopSession();
154 }
155
156 // Create array for reading
157 auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys);
158 std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys);
159
160 // Register with thread
161 store.StartSession();
162
163 // Issue read requests
164 for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) {
165 ReadContext context{ AdId{ idx}, read_results.get() + idx };
166 store.Read(context, callback, idx);
167 }
168
169 // Complete all pending requests
170 store.CompletePending(true);
171
172 // Release
173 store.StopSession();
174
175 // Test outputs
176 // Compute expected array
177 auto expected_results = alloc_aligned<uint64_t>(64,
178 sizeof(uint64_t) * kNumUniqueKeys);
179 std::memset(expected_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys);
180
181 // Sessions that were active during checkpoint:
182 for(uint64_t serial_num : serial_nums) {
183 for(uint64_t idx = 0; idx <= serial_num; ++idx) {
184 ++expected_results.get()[idx % kNumUniqueKeys];
185 }
186 }
187 // Sessions that were finished at time of checkpoint.
188 size_t num_completed = num_threads - serial_nums.size();
189 for(size_t thread_idx = 0; thread_idx < num_completed; ++thread_idx) {
190 uint64_t serial_num = kNumOps;
191 for(uint64_t idx = 0; idx < serial_num; ++idx) {
192 ++expected_results.get()[idx % kNumUniqueKeys];
193 }
194 }
195
196 // Assert if expected is same as found
197 for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) {
198 if(expected_results.get()[idx] != read_results.get()[idx]) {
199 printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n",
200 idx,
201 expected_results.get()[idx],
202 read_results.get()[idx]);
203 }
204 }
205 printf("Test successful\n");
206 }
207
208 static void ContinueWorker(store_t* store, size_t thread_idx,
209 std::atomic<size_t>* num_active_threads, size_t num_threads,
210 std::atomic<uint32_t>* num_checkpoints, Guid guid) {
211 auto callback = [](IAsyncContext* ctxt, Status result) {
212 CallbackContext<RmwContext> context{ ctxt };
213 assert(result == Status::Ok);
214 };
215
216
217 auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) {
218 if(result != Status::Ok) {
219 printf("Thread %" PRIu32 " reports checkpoint failed.\n",
220 Thread::id());
221 } else {
222 printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n",
223 Thread::id(), persistent_serial_num);
224 }
225 };
226
227 // Register thread with the store
228 uint64_t start_num = store->ContinueSession(guid);
229
230 ++(*num_active_threads);
231
232 // Process the batch of input data
233 for(size_t idx = start_num + 1; idx < kNumOps; ++idx) {
234 RmwContext context{ idx % kNumUniqueKeys, 1 };
235 store->Rmw(context, callback, idx);
236 if(idx % kCheckpointInterval == 0 && *num_active_threads == num_threads) {
237 Guid token;
238 if(store->Checkpoint(nullptr, hybrid_log_persistence_callback, token)) {
239 printf("Thread %" PRIu32 " calling Checkpoint(), version = %" PRIu32 ", token = %s\n",
240 Thread::id(), ++(*num_checkpoints), token.ToString().c_str());
241 }
242 }
243 if(idx % kCompletePendingInterval == 0) {
244 store->CompletePending(false);
245 } else if(idx % kRefreshInterval == 0) {
246 store->Refresh();
247 }
248 }
249
250 // Make sure operations are completed
251 store->CompletePending(true);
252
253 // Deregister thread from FASTER
254 store->StopSession();
255
256 printf("Populate successful on thread %" PRIu32 ".\n", Thread::id());
257 }
258
259 void Continue(const Guid& index_token, const Guid& hybrid_log_token) {
260 // Recover
261 printf("Recovering version (index_token = %s, hybrid_log_token = %s)\n",
262 index_token.ToString().c_str(), hybrid_log_token.ToString().c_str());
263 uint32_t version;
264 std::vector<Guid> session_ids;
265 FASTER::core::Status result = store.Recover(index_token, hybrid_log_token, version,
266 session_ids);
267 if(result != FASTER::core::Status::Ok) {
268 printf("Recovery failed with error %u\n", static_cast<uint8_t>(result));
269 exit(1);
270 } else {
271 printf("Recovery Done!\n");
272 }
273
274 num_checkpoints.store(version);
275 // Some threads may have already completed.
276 num_threads = session_ids.size();
277
278 std::deque<std::thread> threads;
279 for(size_t idx = 0; idx < num_threads; ++idx) {
280 threads.emplace_back(&ContinueWorker, &store, idx, &num_active_threads, num_threads,
281 &num_checkpoints, session_ids[idx]);
282 }
283 for(auto& thread : threads) {
284 thread.join();
285 }
286 }
287
288 store_t& store;
289 size_t num_threads;
290 std::atomic<size_t> num_active_threads;
291 std::atomic<uint32_t> num_checkpoints;
292};
293
294} // namespace sum_store
295