1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4#pragma once
5
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "core/auto_ptr.h"
#include "core/faster.h"
#include "sum_store.h"
16
17using namespace FASTER;
18
19namespace sum_store {
20
21class SingleThreadedRecoveryTest {
22 public:
23 static constexpr uint64_t kNumUniqueKeys = (1L << 23);
24 static constexpr uint64_t kNumOps = (1L << 25);
25 static constexpr uint64_t kRefreshInterval = (1L << 8);
26 static constexpr uint64_t kCompletePendingInterval = (1L << 12);
27 static constexpr uint64_t kCheckpointInterval = (1L << 20);
28
29 SingleThreadedRecoveryTest(store_t& store_)
30 : store{ store_ } {
31 }
32
33 private:
34
35 public:
36 void Populate() {
37 auto callback = [](IAsyncContext* ctxt, Status result) {
38 CallbackContext<RmwContext> context{ ctxt };
39 assert(result == Status::Ok);
40 };
41
42 auto hybrid_log_persistence_callback = [](Status result, uint64_t persistent_serial_num) {
43 if(result != Status::Ok) {
44 printf("Thread %" PRIu32 " reports checkpoint failed.\n",
45 Thread::id());
46 } else {
47 printf("Thread %" PRIu32 " reports persistence until %" PRIu64 "\n",
48 Thread::id(), persistent_serial_num);
49 }
50 };
51
52 // Register thread with FASTER
53 store.StartSession();
54
55 // Process the batch of input data
56 for(uint64_t idx = 0; idx < kNumOps; ++idx) {
57 RmwContext context{ AdId{ idx % kNumUniqueKeys}, 1 };
58 store.Rmw(context, callback, idx);
59
60 if(idx % kCheckpointInterval == 0) {
61 Guid token;
62 store.Checkpoint(nullptr, hybrid_log_persistence_callback, token);
63 printf("Calling Checkpoint(), token = %s\n", token.ToString().c_str());
64 }
65 if(idx % kCompletePendingInterval == 0) {
66 store.CompletePending(false);
67 } else if(idx % kRefreshInterval == 0) {
68 store.Refresh();
69 }
70 }
71 // Make sure operations are completed
72 store.CompletePending(true);
73
74 // Deregister thread from FASTER
75 store.StopSession();
76
77 printf("Populate successful\n");
78
79 std::string discard;
80 std::getline(std::cin, discard);
81 }
82
83 void RecoverAndTest(const Guid& index_token, const Guid& hybrid_log_token) {
84 auto callback = [](IAsyncContext* ctxt, Status result) {
85 CallbackContext<ReadContext> context{ ctxt };
86 assert(result == Status::Ok);
87 };
88
89 // Recover
90 uint32_t version;
91 std::vector<Guid> session_ids;
92 store.Recover(index_token, hybrid_log_token, version, session_ids);
93
94 // Create array for reading
95 auto read_results = alloc_aligned<uint64_t>(64, sizeof(uint64_t) * kNumUniqueKeys);
96 std::memset(read_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys);
97
98 Guid session_id = session_ids[0];
99
100 // Register with thread
101 uint64_t sno = store.ContinueSession(session_id);
102
103 // Issue read requests
104 for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) {
105 ReadContext context{ AdId{ idx}, read_results.get() + idx };
106 store.Read(context, callback, idx);
107 }
108
109 // Complete all pending requests
110 store.CompletePending(true);
111
112 // Release
113 store.StopSession();
114
115 // Test outputs
116 // Compute expected array
117 auto expected_results = alloc_aligned<uint64_t>(64,
118 sizeof(uint64_t) * kNumUniqueKeys);
119 std::memset(expected_results.get(), 0, sizeof(uint64_t) * kNumUniqueKeys);
120
121 for(uint64_t idx = 0; idx <= sno; ++idx) {
122 ++expected_results.get()[idx % kNumUniqueKeys];
123 }
124
125 // Assert if expected is same as found
126 for(uint64_t idx = 0; idx < kNumUniqueKeys; ++idx) {
127 if(expected_results.get()[idx] != read_results.get()[idx]) {
128 printf("Debug error for AdId %" PRIu64 ": Expected (%" PRIu64 "), Found(%" PRIu64 ")\n",
129 idx,
130 expected_results.get()[idx],
131 read_results.get()[idx]);
132 }
133 }
134 printf("Test successful\n");
135
136 std::string discard;
137 std::getline(std::cin, discard);
138 }
139
140 void Continue() {
141 // Not implemented.
142 assert(false);
143 }
144
145 store_t& store;
146};
147
148} // namespace sum_store
149