// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory>
#include <new>
#include <thread>

#include "alloc.h"
#include "async.h"
#include "constants.h"
#include "phase.h"
#include "thread.h"
#include "utility.h"

namespace FASTER {
namespace core {

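/// Epoch protection framework: threads enter and exit the protected region via Protect() /
/// Unprotect() (or the reentrant variants), and BumpCurrentEpoch() advances the global epoch,
/// optionally registering a callback to run once the prior epoch becomes safe to reclaim.
///
/// Illustrative usage sketch only (the real call sites live elsewhere in FASTER; the callback
/// and context names below are placeholders):
///
///   LightEpoch epoch;
///   epoch.ProtectAndDrain();
///   // ... access epoch-protected shared state ...
///   epoch.Unprotect();
///   // Defer reclamation until no protected thread can still observe the old state:
///   epoch.BumpCurrentEpoch(free_callback, free_context);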
class LightEpoch {
 private:
  /// Entry in epoch table
  struct alignas(Constants::kCacheLineBytes) Entry {
    Entry()
      : local_current_epoch{ 0 }
      , reentrant{ 0 }
      , phase_finished{ Phase::REST } {
    }

    uint64_t local_current_epoch;
    uint32_t reentrant;
    std::atomic<Phase> phase_finished;
  };
  static_assert(sizeof(Entry) == 64, "sizeof(Entry) != 64");

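  /// A deferred action: a callback plus its context, tagged with the epoch after which the
  /// action may run. The atomic epoch field doubles as the slot state: kFree, kLocked, or a
  /// pending trigger epoch.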
  struct EpochAction {
    typedef void(*callback_t)(IAsyncContext*);

    static constexpr uint64_t kFree = UINT64_MAX;
    static constexpr uint64_t kLocked = UINT64_MAX - 1;

    EpochAction()
      : epoch{ kFree }
      , callback{ nullptr }
      , context{ nullptr } {
    }

    void Initialize() {
      callback = nullptr;
      context = nullptr;
      epoch = kFree;
    }

    bool IsFree() const {
      return epoch.load() == kFree;
    }

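    /// Claim the slot if it still holds expected_epoch, then run its action. Returns false if
    /// another thread claimed the slot first.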
    bool TryPop(uint64_t expected_epoch) {
      bool retval = epoch.compare_exchange_strong(expected_epoch, kLocked);
      if(retval) {
        callback_t callback_ = callback;
        IAsyncContext* context_ = context;
        callback = nullptr;
        context = nullptr;
        // Release the lock.
        epoch.store(kFree);
        // Perform the action.
        callback_(context_);
      }
      return retval;
    }

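    /// Claim a free slot and install an action that triggers once prior_epoch becomes safe to
    /// reclaim. Returns false if the slot was not free.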
    bool TryPush(uint64_t prior_epoch, callback_t new_callback, IAsyncContext* new_context) {
      uint64_t expected_epoch = kFree;
      bool retval = epoch.compare_exchange_strong(expected_epoch, kLocked);
      if(retval) {
        callback = new_callback;
        context = new_context;
        // Release the lock.
        epoch.store(prior_epoch);
      }
      return retval;
    }

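    /// Replace a slot whose existing action is expected to be safe to run already: execute the
    /// existing action, then install the new one in its place.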
    bool TrySwap(uint64_t expected_epoch, uint64_t prior_epoch, callback_t new_callback,
                 IAsyncContext* new_context) {
      bool retval = epoch.compare_exchange_strong(expected_epoch, kLocked);
      if(retval) {
        callback_t existing_callback = callback;
        IAsyncContext* existing_context = context;
        callback = new_callback;
        context = new_context;
        // Release the lock.
        epoch.store(prior_epoch);
        // Perform the action.
        existing_callback(existing_context);
      }
      return retval;
    }

    /// The epoch field is atomic--always read it first and write it last.
    std::atomic<uint64_t> epoch;

    void(*callback)(IAsyncContext* context);
    IAsyncContext* context;
  };

 public:
  /// Default invalid page_index entry.
  static constexpr uint32_t kInvalidIndex = 0;
  /// This thread is not protecting any epoch.
  static constexpr uint64_t kUnprotected = 0;

 private:
  /// Default number of entries in the entries table
  static constexpr uint32_t kTableSize = Thread::kMaxNumThreads;
  /// Default drain list size
  static constexpr uint32_t kDrainListSize = 256;
  /// Epoch table
  Entry* table_;
  /// Number of entries in epoch table.
  uint32_t num_entries_;

  /// List of (action, epoch) pairs containing actions to be performed when an epoch becomes
  /// safe to reclaim.
  EpochAction drain_list_[kDrainListSize];
  /// Count of drain actions
  std::atomic<uint32_t> drain_count_;

 public:
  /// Current system epoch (global state)
  std::atomic<uint64_t> current_epoch;
  /// Cached value of epoch that is safe to reclaim
  std::atomic<uint64_t> safe_to_reclaim_epoch;

  LightEpoch(uint32_t size = kTableSize)
    : table_{ nullptr }
    , num_entries_{ 0 }
    , drain_list_{}
    , drain_count_{ 0 } {
    Initialize(size);
  }

  ~LightEpoch() {
    Uninitialize();
  }

 private:
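  /// Allocate and placement-construct the cache-line-aligned epoch table; reset the drain list.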
  void Initialize(uint32_t size) {
    num_entries_ = size;
    // do cache-line alignment
    table_ = reinterpret_cast<Entry*>(aligned_alloc(Constants::kCacheLineBytes,
                                      (size + 2) * sizeof(Entry)));
    new(table_) Entry[size + 2];
    current_epoch = 1;
    safe_to_reclaim_epoch = 0;
    for(uint32_t idx = 0; idx < kDrainListSize; ++idx) {
      drain_list_[idx].Initialize();
    }
    drain_count_ = 0;
  }

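  /// Free the epoch table and reset the global epoch state.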
  void Uninitialize() {
    aligned_free(table_);
    table_ = nullptr;
    num_entries_ = 0;
    current_epoch = 1;
    safe_to_reclaim_epoch = 0;
  }

 public:
  /// Enter the thread into the protected code region
  inline uint64_t Protect() {
    uint32_t entry = Thread::id();
    table_[entry].local_current_epoch = current_epoch.load();
    return table_[entry].local_current_epoch;
  }

  /// Enter the thread into the protected code region
  /// Process entries in drain list if possible
  inline uint64_t ProtectAndDrain() {
    uint32_t entry = Thread::id();
    table_[entry].local_current_epoch = current_epoch.load();
    if(drain_count_.load() > 0) {
      Drain(table_[entry].local_current_epoch);
    }
    return table_[entry].local_current_epoch;
  }

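  /// Enter the protected code region, tracking nesting depth so that only the outermost
  /// ReentrantUnprotect() releases protection.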
  uint64_t ReentrantProtect() {
    uint32_t entry = Thread::id();
    if(table_[entry].local_current_epoch != kUnprotected)
      return table_[entry].local_current_epoch;
    table_[entry].local_current_epoch = current_epoch.load();
    table_[entry].reentrant++;
    return table_[entry].local_current_epoch;
  }

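  /// Whether the calling thread is currently inside the protected code region.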
  inline bool IsProtected() {
    uint32_t entry = Thread::id();
    return table_[entry].local_current_epoch != kUnprotected;
  }

  /// Exit the thread from the protected code region.
  void Unprotect() {
    table_[Thread::id()].local_current_epoch = kUnprotected;
  }

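  /// Exit one level of reentrant protection; the epoch is released only when the outermost
  /// level exits.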
  void ReentrantUnprotect() {
    uint32_t entry = Thread::id();
    if(--(table_[entry].reentrant) == 0) {
      table_[entry].local_current_epoch = kUnprotected;
    }
  }

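  /// Run any drain-list actions whose trigger epoch has become safe to reclaim.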
  void Drain(uint64_t nextEpoch) {
    ComputeNewSafeToReclaimEpoch(nextEpoch);
    for(uint32_t idx = 0; idx < kDrainListSize; ++idx) {
      uint64_t trigger_epoch = drain_list_[idx].epoch.load();
      if(trigger_epoch <= safe_to_reclaim_epoch) {
        if(drain_list_[idx].TryPop(trigger_epoch)) {
          if(--drain_count_ == 0) {
            break;
          }
        }
      }
    }
  }

  /// Increment the current epoch (global system state)
  uint64_t BumpCurrentEpoch() {
    uint64_t nextEpoch = ++current_epoch;
    if(drain_count_ > 0) {
      Drain(nextEpoch);
    }
    return nextEpoch;
  }

  /// Increment the current epoch (global system state) and register
  /// a trigger action for when the prior epoch becomes safe to reclaim.
  uint64_t BumpCurrentEpoch(EpochAction::callback_t callback, IAsyncContext* context) {
    uint64_t prior_epoch = BumpCurrentEpoch() - 1;
    uint32_t i = 0, j = 0;
    while(true) {
      uint64_t trigger_epoch = drain_list_[i].epoch.load();
      if(trigger_epoch == EpochAction::kFree) {
        if(drain_list_[i].TryPush(prior_epoch, callback, context)) {
          ++drain_count_;
          break;
        }
      } else if(trigger_epoch <= safe_to_reclaim_epoch.load()) {
        if(drain_list_[i].TrySwap(trigger_epoch, prior_epoch, callback, context)) {
          break;
        }
      }
      if(++i == kDrainListSize) {
        i = 0;
        if(++j == 500) {
          j = 0;
          std::this_thread::sleep_for(std::chrono::seconds(1));
          fprintf(stderr, "Slowdown: Unable to add trigger to epoch\n");
        }
      }
    }
    return prior_epoch + 1;
  }

  /// Compute latest epoch that is safe to reclaim, by scanning the epoch table
  uint64_t ComputeNewSafeToReclaimEpoch(uint64_t current_epoch_) {
    uint64_t oldest_ongoing_call = current_epoch_;
    for(uint32_t index = 1; index <= num_entries_; ++index) {
      uint64_t entry_epoch = table_[index].local_current_epoch;
      if(entry_epoch != kUnprotected && entry_epoch < oldest_ongoing_call) {
        oldest_ongoing_call = entry_epoch;
      }
    }
    safe_to_reclaim_epoch = oldest_ongoing_call - 1;
    return safe_to_reclaim_epoch;
  }

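  /// Spin, recomputing the safe-to-reclaim epoch, until it reaches at least
  /// safe_to_reclaim_epoch_.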
  void SpinWaitForSafeToReclaim(uint64_t current_epoch_, uint64_t safe_to_reclaim_epoch_) {
    do {
      ComputeNewSafeToReclaimEpoch(current_epoch_);
    } while(safe_to_reclaim_epoch_ > safe_to_reclaim_epoch);
  }

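  /// Whether the given epoch is at or below the cached safe-to-reclaim epoch.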
  bool IsSafeToReclaim(uint64_t epoch) const {
    return (epoch <= safe_to_reclaim_epoch);
  }

  /// CPR checkpoint functions.
  inline void ResetPhaseFinished() {
    for(uint32_t idx = 1; idx <= num_entries_; ++idx) {
      assert(table_[idx].phase_finished.load() == Phase::REST ||
             table_[idx].phase_finished.load() == Phase::INDEX_CHKPT ||
             table_[idx].phase_finished.load() == Phase::PERSISTENCE_CALLBACK ||
             table_[idx].phase_finished.load() == Phase::GC_IN_PROGRESS ||
             table_[idx].phase_finished.load() == Phase::GROW_IN_PROGRESS);
      table_[idx].phase_finished.store(Phase::REST);
    }
  }
  /// This thread has completed the specified phase.
  inline bool FinishThreadPhase(Phase phase) {
    uint32_t entry = Thread::id();
    table_[entry].phase_finished = phase;
    // Check if other threads have reported complete.
    for(uint32_t idx = 1; idx <= num_entries_; ++idx) {
      Phase entry_phase = table_[idx].phase_finished.load();
      uint64_t entry_epoch = table_[idx].local_current_epoch;
      if(entry_epoch != 0 && entry_phase != phase) {
        return false;
      }
    }
    return true;
  }
  /// Has this thread completed the specified phase (i.e., is it waiting for other threads to
  /// finish the specified phase, before it can advance the global phase)?
  inline bool HasThreadFinishedPhase(Phase phase) const {
    uint32_t entry = Thread::id();
    return table_[entry].phase_finished == phase;
  }
};

}  // namespace core
}  // namespace FASTER