// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once

#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <thread>

#include "alloc.h"
#include "constants.h"
#include "key_hash.h"
#include "utility.h"
namespace FASTER {
namespace core {

/// Lock state for one slot during checkpointing: counts of threads currently
/// operating on the old version and on the new version of records that hash to
/// the slot, packed into a single 64-bit word so both counts can be read and
/// updated atomically.
struct CheckpointLock {
  CheckpointLock()
    : control_{ 0 } {
  }
  CheckpointLock(uint64_t control)
    : control_{ control } {
  }
  CheckpointLock(uint32_t old_lock_count, uint32_t new_lock_count)
    : old_lock_count_{ old_lock_count }
    , new_lock_count_{ new_lock_count } {
  }

  union {
    struct {
      uint32_t old_lock_count_;
      uint32_t new_lock_count_;
    };
    uint64_t control_;
  };
};
static_assert(sizeof(CheckpointLock) == 8, "sizeof(CheckpointLock) != 8");
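
// Illustrative note (assumes the little-endian layout FASTER targets):
// old_lock_count_ occupies the low 32 bits of control_ and new_lock_count_
// the high 32 bits, e.g. CheckpointLock{ 2, 1 }.control_ == (1ull << 32) | 2.
// This packing is what lets unlock_old() / unlock_new() below decrement one
// count with a single 64-bit atomic subtraction.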

/// Atomic wrapper over CheckpointLock: old-version and new-version locks are
/// mutually exclusive, but each side is shared (counted), so many threads may
/// hold the same side at once.
class AtomicCheckpointLock {
 public:
  AtomicCheckpointLock()
    : control_{ 0 } {
  }

  /// Try to lock the old version of a record. Succeeds only while no
  /// new-version lock is held.
  inline bool try_lock_old() {
    CheckpointLock expected{ control_.load() };
    while(expected.new_lock_count_ == 0) {
      CheckpointLock desired{ expected.old_lock_count_ + 1, 0 };
      // On failure, compare_exchange_strong() refreshes 'expected' with the
      // current value, so the loop condition is re-checked before retrying.
      if(control_.compare_exchange_strong(expected.control_, desired.control_)) {
        return true;
      }
    }
    return false;
  }
  inline void unlock_old() {
    control_ -= CheckpointLock{ 1, 0 }.control_;
  }

  /// Try to lock the new version of a record. Succeeds only while no
  /// old-version lock is held.
  inline bool try_lock_new() {
    CheckpointLock expected{ control_.load() };
    while(expected.old_lock_count_ == 0) {
      CheckpointLock desired{ 0, expected.new_lock_count_ + 1 };
      if(control_.compare_exchange_strong(expected.control_, desired.control_)) {
        return true;
      }
    }
    return false;
  }
  inline void unlock_new() {
    control_ -= CheckpointLock{ 0, 1 }.control_;
  }

  inline bool old_locked() const {
    CheckpointLock result{ control_ };
    return result.old_lock_count_ > 0;
  }
  inline bool new_locked() const {
    CheckpointLock result{ control_ };
    return result.new_lock_count_ > 0;
  }

 private:
  union {
    std::atomic<uint64_t> control_;
  };
};
static_assert(sizeof(AtomicCheckpointLock) == 8, "sizeof(AtomicCheckpointLock) != 8");
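
// Example (illustrative only): the two sides exclude each other, but the same
// side can be taken multiple times.
//
//   AtomicCheckpointLock lock;
//   assert(lock.try_lock_old());   // first old-version lock succeeds
//   assert(lock.try_lock_old());   // old side is counted; two holders now
//   assert(!lock.try_lock_new());  // fails while any old-version lock is held
//   lock.unlock_old();
//   lock.unlock_old();
//   assert(lock.try_lock_new());   // succeeds once the old count reaches 0
//   lock.unlock_new();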

/// Table of checkpoint locks, one per slot; a key is mapped to a slot via its
/// hash. Backing storage is cache-line aligned, and the size must be a power
/// of two so KeyHash::idx() can reduce the hash to an index with a mask.
class CheckpointLocks {
 public:
  CheckpointLocks()
    : size_{ 0 }
    , locks_{ nullptr } {
  }

  ~CheckpointLocks() {
    if(locks_) {
      aligned_free(locks_);
    }
  }

  void Initialize(uint64_t size) {
    assert(size < INT32_MAX);
    assert(Utility::IsPowerOfTwo(size));
    if(locks_) {
      aligned_free(locks_);
    }
    size_ = size;
    locks_ = reinterpret_cast<AtomicCheckpointLock*>(aligned_alloc(Constants::kCacheLineBytes,
             size_ * sizeof(AtomicCheckpointLock)));
    std::memset(locks_, 0, size_ * sizeof(AtomicCheckpointLock));
  }

  void Free() {
    assert(locks_);
#ifdef _DEBUG
    // In debug builds, verify that no lock is still held before freeing.
    for(uint64_t idx = 0; idx < size_; ++idx) {
      assert(!locks_[idx].old_locked());
      assert(!locks_[idx].new_locked());
    }
#endif
    aligned_free(locks_);
    size_ = 0;
    locks_ = nullptr;
  }

  inline uint64_t size() const {
    return size_;
  }

  inline AtomicCheckpointLock& get_lock(KeyHash hash) {
    return locks_[hash.idx(size_)];
  }

 private:
  uint64_t size_;
  AtomicCheckpointLock* locks_;
};
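
// Example (illustrative only; 'hash' stands for some KeyHash value):
//
//   CheckpointLocks locks;
//   locks.Initialize(1024);  // size must be a power of two
//   AtomicCheckpointLock& lock = locks.get_lock(hash);
//   if(lock.try_lock_old()) {
//     // ... safe to touch the pre-checkpoint version ...
//     lock.unlock_old();
//   }
//   locks.Free();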

/// RAII guard over the checkpoint lock for one key hash: remembers which side
/// it took and releases it on destruction. If the lock table is empty
/// (size() == 0), the guard is a no-op.
class CheckpointLockGuard {
 public:
  CheckpointLockGuard(CheckpointLocks& locks, KeyHash hash)
    : lock_{ nullptr }
    , locked_old_{ false }
    , locked_new_{ false } {
    if(locks.size() > 0) {
      lock_ = &locks.get_lock(hash);
    }
  }
  ~CheckpointLockGuard() {
    if(lock_) {
      if(locked_old_) {
        lock_->unlock_old();
      }
      if(locked_new_) {
        lock_->unlock_new();
      }
    }
  }
  inline bool try_lock_old() {
    assert(lock_);
    assert(!locked_old_);
    locked_old_ = lock_->try_lock_old();
    return locked_old_;
  }
  inline bool try_lock_new() {
    assert(lock_);
    assert(!locked_new_);
    locked_new_ = lock_->try_lock_new();
    return locked_new_;
  }

  inline bool old_locked() const {
    assert(lock_);
    return lock_->old_locked();
  }
  inline bool new_locked() const {
    assert(lock_);
    return lock_->new_locked();
  }

 private:
  AtomicCheckpointLock* lock_;
  bool locked_old_;
  bool locked_new_;
};
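
// Example (illustrative only; 'hash' stands for some KeyHash value):
//
//   {
//     CheckpointLockGuard guard{ locks, hash };
//     if(guard.try_lock_old()) {
//       // ... operate on the old version of the record ...
//     }
//   }  // any lock taken is released when the guard goes out of scope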

} // namespace core
} // namespace FASTER