#pragma once

#include <unordered_map>
#include <list>
#include <memory>
#include <chrono>
#include <mutex>
#include <atomic>

#include <common/logger_useful.h>


namespace DB
{

template <typename T>
struct TrivialWeightFunction
{
    size_t operator()(const T &) const
    {
        return 1;
    }
};


/// Thread-safe LRU cache that evicts the entries which have not been used for the longest time.
/// WeightFunction is a functor that takes Mapped as a parameter and returns the "weight" (approximate size)
/// of that value.
/// The cache starts to evict entries when their total weight exceeds max_size.
/// The weight of a value must not change after it has been inserted.
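///
/// A minimal usage sketch (the key/value types and the SizeWeight functor below are illustrative,
/// not part of this header):
///
///     struct SizeWeight
///     {
///         size_t operator()(const std::string & s) const { return s.size(); }
///     };
///
///     LRUCache<int, std::string, std::hash<int>, SizeWeight> cache(1024);
///     cache.set(1, std::make_shared<std::string>("value"));   /// adds an entry of weight 5
///     if (auto value = cache.get(1))
///     {
///         /// *value == "value"; the entry is now the most recently used one.
///     }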
template <typename TKey, typename TMapped, typename HashFunction = std::hash<TKey>, typename WeightFunction = TrivialWeightFunction<TMapped>>
class LRUCache
{
public:
    using Key = TKey;
    using Mapped = TMapped;
    using MappedPtr = std::shared_ptr<Mapped>;

private:
    using Clock = std::chrono::steady_clock;

public:
    LRUCache(size_t max_size_)
        : max_size(std::max(static_cast<size_t>(1), max_size_)) {}

    MappedPtr get(const Key & key)
    {
        std::lock_guard lock(mutex);

        auto res = getImpl(key, lock);
        if (res)
            ++hits;
        else
            ++misses;

        return res;
    }

    void set(const Key & key, const MappedPtr & mapped)
    {
        std::lock_guard lock(mutex);

        setImpl(key, mapped, lock);
    }

    /// If the value for the key is in the cache, returns it. If it is not, calls load_func() to
    /// produce it, saves the result in the cache and returns it.
    /// Only one of several concurrent threads calling getOrSet() will call load_func(),
    /// the others will wait for that call to complete and will use its result (this helps prevent cache stampede).
    /// Exceptions occurring in load_func will be propagated to the caller. Another thread from the
    /// set of concurrent threads will then try to call its own load_func, and so on.
    ///
    /// Returns std::pair of the cached value and a bool indicating whether the value was produced during this call.
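    ///
    /// An illustrative sketch (the `cache` instance and loadFromDisk() below are hypothetical):
    ///
    ///     auto [value, loaded] = cache.getOrSet(key, [&] { return std::make_shared<Mapped>(loadFromDisk(key)); });
    ///     /// `loaded` is true if this call produced the value, false if it was already cached
    ///     /// or a concurrent caller produced it first.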
    template <typename LoadFunc>
    std::pair<MappedPtr, bool> getOrSet(const Key & key, LoadFunc && load_func)
    {
        InsertTokenHolder token_holder;
        {
            std::lock_guard cache_lock(mutex);

            auto val = getImpl(key, cache_lock);
            if (val)
            {
                ++hits;
                return std::make_pair(val, false);
            }

            auto & token = insert_tokens[key];
            if (!token)
                token = std::make_shared<InsertToken>(*this);

            token_holder.acquire(&key, token, cache_lock);
        }

        InsertToken * token = token_holder.token.get();

        std::lock_guard token_lock(token->mutex);

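        /// Remember whether the token has already been cleaned up by another thread, so that the
        /// holder's destructor can skip the cleanup without re-acquiring the locks.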
        token_holder.cleaned_up = token->cleaned_up;

        if (token->value)
        {
            /// Another thread already produced the value while we waited for token->mutex.
            ++hits;
            return std::make_pair(token->value, false);
        }

        ++misses;
        token->value = load_func();

        std::lock_guard cache_lock(mutex);

        /// Insert the new value only if the token is still present in insert_tokens.
        /// (The token may be absent because of a concurrent reset() call.)
        bool result = false;
        auto token_it = insert_tokens.find(key);
        if (token_it != insert_tokens.end() && token_it->second.get() == token)
        {
            setImpl(key, token->value, cache_lock);
            result = true;
        }

        if (!token->cleaned_up)
            token_holder.cleanup(token_lock, cache_lock);

        return std::make_pair(token->value, result);
    }

    void getStats(size_t & out_hits, size_t & out_misses) const
    {
        std::lock_guard lock(mutex);
        out_hits = hits;
        out_misses = misses;
    }

    size_t weight() const
    {
        std::lock_guard lock(mutex);
        return current_size;
    }

    size_t count() const
    {
        std::lock_guard lock(mutex);
        return cells.size();
    }

    void reset()
    {
        std::lock_guard lock(mutex);
        queue.clear();
        cells.clear();
        insert_tokens.clear();
        current_size = 0;
        hits = 0;
        misses = 0;
    }

    virtual ~LRUCache() {}

protected:
    using LRUQueue = std::list<Key>;
    using LRUQueueIterator = typename LRUQueue::iterator;

    struct Cell
    {
        MappedPtr value;
        size_t size;
        LRUQueueIterator queue_iterator;
    };

    using Cells = std::unordered_map<Key, Cell, HashFunction>;

    Cells cells;

    mutable std::mutex mutex;
private:

    /// Represents a pending insertion attempt.
    struct InsertToken
    {
        explicit InsertToken(LRUCache & cache_) : cache(cache_) {}

        std::mutex mutex;
        bool cleaned_up = false; /// Protected by the token mutex
        MappedPtr value; /// Protected by the token mutex

        LRUCache & cache;
        size_t refcount = 0; /// Protected by the cache mutex
    };

    using InsertTokenById = std::unordered_map<Key, std::shared_ptr<InsertToken>, HashFunction>;

    /// This class is responsible for removing used insert tokens from the insert_tokens map.
    /// Among several concurrent threads the first successful one is responsible for removal. But if they all
    /// fail, then the last one is responsible.
    struct InsertTokenHolder
    {
        const Key * key = nullptr;
        std::shared_ptr<InsertToken> token;
        bool cleaned_up = false;

        InsertTokenHolder() = default;

        void acquire(const Key * key_, const std::shared_ptr<InsertToken> & token_, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
        {
            key = key_;
            token = token_;
            ++token->refcount;
        }

        void cleanup([[maybe_unused]] std::lock_guard<std::mutex> & token_lock, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
        {
            token->cache.insert_tokens.erase(*key);
            token->cleaned_up = true;
            cleaned_up = true;
        }

        ~InsertTokenHolder()
        {
            if (!token)
                return;

            if (cleaned_up)
                return;

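            /// Locks are taken in the same order as in getOrSet() (first the token mutex, then the
            /// cache mutex), so the two code paths cannot deadlock.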
            std::lock_guard token_lock(token->mutex);

            if (token->cleaned_up)
                return;

            std::lock_guard cache_lock(token->cache.mutex);

            --token->refcount;
            if (token->refcount == 0)
                cleanup(token_lock, cache_lock);
        }
    };

    friend struct InsertTokenHolder;


    InsertTokenById insert_tokens;

    LRUQueue queue;

    /// Total weight of values.
    size_t current_size = 0;
    const size_t max_size;

    std::atomic<size_t> hits {0};
    std::atomic<size_t> misses {0};

    WeightFunction weight_function;

    MappedPtr getImpl(const Key & key, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
    {
        auto it = cells.find(key);
        if (it == cells.end())
        {
            return MappedPtr();
        }

        Cell & cell = it->second;

        /// Move the key to the end of the queue. The iterator remains valid.
        queue.splice(queue.end(), queue, cell.queue_iterator);

        return cell.value;
    }

    void setImpl(const Key & key, const MappedPtr & mapped, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
    {
        auto res = cells.emplace(std::piecewise_construct,
            std::forward_as_tuple(key),
            std::forward_as_tuple());

        Cell & cell = res.first->second;
        bool inserted = res.second;

        if (inserted)
        {
            cell.queue_iterator = queue.insert(queue.end(), key);
        }
        else
        {
            current_size -= cell.size;
            queue.splice(queue.end(), queue, cell.queue_iterator);
        }

        cell.value = mapped;
        cell.size = cell.value ? weight_function(*cell.value) : 0;
        current_size += cell.size;

        removeOverflow();
    }

    void removeOverflow()
    {
        size_t current_weight_lost = 0;
        size_t queue_size = cells.size();
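        /// The loop below keeps at least one entry in the cache, even if its weight alone exceeds max_size.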
        while ((current_size > max_size) && (queue_size > 1))
        {
            const Key & key = queue.front();

            auto it = cells.find(key);
            if (it == cells.end())
            {
                LOG_ERROR(&Logger::get("LRUCache"), "LRUCache became inconsistent. There must be a bug in it.");
                abort();
            }

            const auto & cell = it->second;

            current_size -= cell.size;
            current_weight_lost += cell.size;

            cells.erase(it);
            queue.pop_front();
            --queue_size;
        }

        onRemoveOverflowWeightLoss(current_weight_lost);

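        /// Sanity check: such a huge current_size almost certainly means the counter wrapped around
        /// (underflowed), i.e. the size accounting is broken.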
        if (current_size > (1ull << 63))
        {
            LOG_ERROR(&Logger::get("LRUCache"), "LRUCache became inconsistent. There must be a bug in it.");
            abort();
        }
    }

    /// Override this method if you want to track how much weight was lost in the removeOverflow() method.
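    ///
    /// For example, a derived cache (hypothetical, not part of this header) could accumulate the lost weight:
    ///
    ///     class MyCache : public LRUCache<Key, Value>
    ///     {
    ///         void onRemoveOverflowWeightLoss(size_t weight_loss) override { evicted_weight += weight_loss; }
    ///         size_t evicted_weight = 0;   /// updated under the cache mutex
    ///     };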
    virtual void onRemoveOverflowWeightLoss(size_t /*weight_loss*/) {}
};


}