1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4#pragma once
5
#include <atomic>
#include <cassert>
#include <cstdint>
#include <new>
#include <thread>

#include "alloc.h"
#include "constants.h"
#include "key_hash.h"
14
15namespace FASTER {
16namespace core {
17
/// A single 64-bit checkpoint-lock word. The same storage is exposed two ways:
/// as the raw `control_` value, or as a pair of 32-bit counters (holders of the
/// old version of a record, holders of the new version).
/// NOTE(review): which counter lands in the low half of `control_` follows
/// from the union layout and the platform's byte order.
struct CheckpointLock {
  union {
    struct {
      uint32_t old_lock_count_;
      uint32_t new_lock_count_;
    };
    uint64_t control_;
  };

  /// An unlocked word: both counters are zero.
  CheckpointLock()
    : CheckpointLock{ uint64_t{ 0 } } {
  }

  /// Wraps a raw 64-bit control word, e.g. one just loaded from an atomic.
  CheckpointLock(uint64_t control)
    : control_{ control } {
  }

  /// Builds a word from explicit old-version / new-version counter values.
  CheckpointLock(uint32_t old_lock_count, uint32_t new_lock_count)
    : old_lock_count_{ old_lock_count }
    , new_lock_count_{ new_lock_count } {
  }
};
static_assert(sizeof(CheckpointLock) == 8, "sizeof(CheckpointLock) != 8");
39
40class AtomicCheckpointLock {
41 public:
42 AtomicCheckpointLock()
43 : control_{ 0 } {
44 }
45
46 /// Try to lock the old version of a record.
47 inline bool try_lock_old() {
48 CheckpointLock expected{ control_.load() };
49 while(expected.new_lock_count_ == 0) {
50 CheckpointLock desired{ expected.old_lock_count_ + 1, 0 };
51 if(control_.compare_exchange_strong(expected.control_, desired.control_)) {
52 return true;
53 }
54 }
55 return false;
56 }
57 inline void unlock_old() {
58 control_ -= CheckpointLock{ 1, 0 } .control_;
59 }
60
61 /// Try to lock the new version of a record.
62 inline bool try_lock_new() {
63 CheckpointLock expected{ control_.load() };
64 while(expected.old_lock_count_ == 0) {
65 CheckpointLock desired{ 0, expected.new_lock_count_ + 1 };
66 if(control_.compare_exchange_strong(expected.control_, desired.control_)) {
67 return true;
68 }
69 }
70 return false;
71 }
72 inline void unlock_new() {
73 control_ -= CheckpointLock{ 0, 1 } .control_;
74 }
75
76 inline bool old_locked() const {
77 CheckpointLock result{ control_ };
78 return result.old_lock_count_ > 0;
79 }
80 inline bool new_locked() const {
81 CheckpointLock result{ control_ };
82 return result.new_lock_count_ > 0;
83 }
84
85 private:
86 union {
87 std::atomic<uint64_t> control_;
88 };
89};
90static_assert(sizeof(AtomicCheckpointLock) == 8, "sizeof(AtomicCheckpointLock) != 8");
91
92class CheckpointLocks {
93 public:
94 CheckpointLocks()
95 : size_{ 0 }
96 , locks_{ nullptr } {
97 }
98
99 ~CheckpointLocks() {
100 if(locks_) {
101 aligned_free(locks_);
102 }
103 }
104
105 void Initialize(uint64_t size) {
106 assert(size < INT32_MAX);
107 assert(Utility::IsPowerOfTwo(size));
108 if(locks_) {
109 aligned_free(locks_);
110 }
111 size_ = size;
112 locks_ = reinterpret_cast<AtomicCheckpointLock*>(aligned_alloc(Constants::kCacheLineBytes,
113 size_ * sizeof(AtomicCheckpointLock)));
114 std::memset(locks_, 0, size_ * sizeof(AtomicCheckpointLock));
115 }
116
117 void Free() {
118 assert(locks_);
119#ifdef _DEBUG
120 for(uint64_t idx = 0; idx < size_; ++idx) {
121 assert(!locks_[idx].old_locked());
122 assert(!locks_[idx].new_locked());
123 }
124#endif
125 aligned_free(locks_);
126 size_ = 0;
127 locks_ = nullptr;
128 }
129
130 inline uint64_t size() const {
131 return size_;
132 }
133
134 inline AtomicCheckpointLock& get_lock(KeyHash hash) {
135 return locks_[hash.idx(size_)];
136 }
137
138 private:
139 uint64_t size_;
140 AtomicCheckpointLock* locks_;
141};
142
143class CheckpointLockGuard {
144 public:
145 CheckpointLockGuard(CheckpointLocks& locks, KeyHash hash)
146 : lock_{ nullptr }
147 , locked_old_{ false }
148 , locked_new_{ false } {
149 if(locks.size() > 0) {
150 lock_ = &locks.get_lock(hash);
151 }
152 }
153 ~CheckpointLockGuard() {
154 if(lock_) {
155 if(locked_old_) {
156 lock_->unlock_old();
157 }
158 if(locked_new_) {
159 lock_->unlock_new();
160 }
161 }
162 }
163 inline bool try_lock_old() {
164 assert(lock_);
165 assert(!locked_old_);
166 locked_old_ = lock_->try_lock_old();
167 return locked_old_;
168 }
169 inline bool try_lock_new() {
170 assert(lock_);
171 assert(!locked_new_);
172 locked_new_ = lock_->try_lock_new();
173 return locked_new_;
174 }
175
176 inline bool old_locked() const {
177 assert(lock_);
178 return lock_->old_locked();
179 }
180 inline bool new_locked() const {
181 assert(lock_);
182 return lock_->new_locked();
183 }
184
185 private:
186 AtomicCheckpointLock* lock_;
187 bool locked_old_;
188 bool locked_new_;
189};
190
191}
192} // namespace FASTER::core
193