1#pragma once
2
3#include <unordered_map>
4#include <list>
5#include <memory>
6#include <chrono>
7#include <mutex>
8#include <atomic>
9
10#include <common/logger_useful.h>
11
12
13namespace DB
14{
15
/// Default weight functor: every cached value weighs exactly 1,
/// which makes the cache's max_size act as a limit on the number of entries.
template <typename T>
struct TrivialWeightFunction
{
    size_t operator()(const T &) const { return 1; }
};
24
25
/// Thread-safe cache that evicts entries which are not used for a long time.
/// WeightFunction is a functor that takes Mapped as a parameter and returns "weight" (approximate size)
/// of that value.
/// Cache starts to evict entries when their total weight exceeds max_size.
/// Value weight should not change after insertion.
template <typename TKey, typename TMapped, typename HashFunction = std::hash<TKey>, typename WeightFunction = TrivialWeightFunction<TMapped>>
class LRUCache
{
public:
    using Key = TKey;
    using Mapped = TMapped;
    using MappedPtr = std::shared_ptr<Mapped>;

private:
    using Clock = std::chrono::steady_clock;

public:
    /// max_size_ is clamped to at least 1, so the cache can always hold at least one entry.
    LRUCache(size_t max_size_)
        : max_size(std::max(static_cast<size_t>(1), max_size_)) {}

    /// Returns the value for the key, or nullptr if it is absent.
    /// A successful lookup marks the entry as most-recently-used.
    MappedPtr get(const Key & key)
    {
        std::lock_guard lock(mutex);

        auto res = getImpl(key, lock);
        if (res)
            ++hits;
        else
            ++misses;

        return res;
    }

    /// Inserts (or replaces) the value for the key, then evicts least-recently-used
    /// entries while the total weight exceeds max_size.
    void set(const Key & key, const MappedPtr & mapped)
    {
        std::lock_guard lock(mutex);

        setImpl(key, mapped, lock);
    }

    /// If the value for the key is in the cache, returns it. If it is not, calls load_func() to
    /// produce it, saves the result in the cache and returns it.
    /// Only one of several concurrent threads calling getOrSet() will call load_func(),
    /// others will wait for that call to complete and will use its result (this helps prevent cache stampede).
    /// Exceptions occurring in load_func will be propagated to the caller. Another thread from the
    /// set of concurrent threads will then try to call its load_func etc.
    ///
    /// Returns std::pair of the cached value and a bool indicating whether the value was produced during this call.
    template <typename LoadFunc>
    std::pair<MappedPtr, bool> getOrSet(const Key & key, LoadFunc && load_func)
    {
        InsertTokenHolder token_holder;
        {
            std::lock_guard cache_lock(mutex);

            /// Fast path: the value is already cached.
            auto val = getImpl(key, cache_lock);
            if (val)
            {
                ++hits;
                return std::make_pair(val, false);
            }

            /// Create (or join) the per-key insert token shared by all threads
            /// concurrently trying to produce the value for this key.
            auto & token = insert_tokens[key];
            if (!token)
                token = std::make_shared<InsertToken>(*this);

            token_holder.acquire(&key, token, cache_lock);
        }

        InsertToken * token = token_holder.token.get();

        /// Lock order: token->mutex is taken before the cache mutex.
        /// ~InsertTokenHolder relies on the same order, which prevents deadlock.
        std::lock_guard token_lock(token->mutex);

        /// Remember whether some other thread already cleaned this token up, so that
        /// the holder's destructor can return early without re-taking locks.
        token_holder.cleaned_up = token->cleaned_up;

        if (token->value)
        {
            /// Another thread already produced the value while we waited for token->mutex.
            ++hits;
            return std::make_pair(token->value, false);
        }

        ++misses;
        /// If load_func throws, the exception propagates and token_holder's destructor
        /// releases the token so another waiting thread can retry its own load_func.
        token->value = load_func();

        std::lock_guard cache_lock(mutex);

        /// Insert the new value only if the token is still present in insert_tokens.
        /// (The token may be absent because of a concurrent reset() call).
        bool result = false;
        auto token_it = insert_tokens.find(key);
        if (token_it != insert_tokens.end() && token_it->second.get() == token)
        {
            setImpl(key, token->value, cache_lock);
            result = true;
        }

        if (!token->cleaned_up)
            token_holder.cleanup(token_lock, cache_lock);

        return std::make_pair(token->value, result);
    }

    /// Reads the accumulated hit/miss counters under the cache mutex,
    /// so the two values are mutually consistent.
    void getStats(size_t & out_hits, size_t & out_misses) const
    {
        std::lock_guard lock(mutex);
        out_hits = hits;
        out_misses = misses;
    }

    /// Total weight of all cached values.
    size_t weight() const
    {
        std::lock_guard lock(mutex);
        return current_size;
    }

    /// Number of cached entries.
    size_t count() const
    {
        std::lock_guard lock(mutex);
        return cells.size();
    }

    /// Drops all entries, all pending insert tokens, and the statistics counters.
    void reset()
    {
        std::lock_guard lock(mutex);
        queue.clear();
        cells.clear();
        insert_tokens.clear();
        current_size = 0;
        hits = 0;
        misses = 0;
    }

    virtual ~LRUCache() {}

protected:
    using LRUQueue = std::list<Key>;
    using LRUQueueIterator = typename LRUQueue::iterator;

    /// Map payload: the value, its weight as computed at insertion time,
    /// and this key's position in the LRU queue.
    struct Cell
    {
        MappedPtr value;
        size_t size;
        LRUQueueIterator queue_iterator;
    };

    using Cells = std::unordered_map<Key, Cell, HashFunction>;

    Cells cells;

    /// Guards all cache state: cells, queue, current_size, insert_tokens, and token refcounts.
    mutable std::mutex mutex;
private:

    /// Represents pending insertion attempt.
    struct InsertToken
    {
        explicit InsertToken(LRUCache & cache_) : cache(cache_) {}

        std::mutex mutex;
        bool cleaned_up = false; /// Protected by the token mutex
        MappedPtr value; /// Protected by the token mutex

        LRUCache & cache;
        size_t refcount = 0; /// Protected by the cache mutex
    };

    using InsertTokenById = std::unordered_map<Key, std::shared_ptr<InsertToken>, HashFunction>;

    /// This class is responsible for removing used insert tokens from the insert_tokens map.
    /// Among several concurrent threads the first successful one is responsible for removal. But if they all
    /// fail, then the last one is responsible.
    struct InsertTokenHolder
    {
        const Key * key = nullptr;
        std::shared_ptr<InsertToken> token;
        bool cleaned_up = false;

        InsertTokenHolder() = default;

        /// Registers this thread as a user of the token (bumps refcount).
        /// The cache_lock parameter only proves the cache mutex is held by the caller.
        void acquire(const Key * key_, const std::shared_ptr<InsertToken> & token_, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
        {
            key = key_;
            token = token_;
            ++token->refcount;
        }

        /// Removes the token from the cache's insert_tokens map and marks it cleaned up.
        /// Both lock_guard parameters only prove the caller holds the token mutex
        /// and the cache mutex, in that order.
        void cleanup([[maybe_unused]] std::lock_guard<std::mutex> & token_lock, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
        {
            token->cache.insert_tokens.erase(*key);
            token->cleaned_up = true;
            cleaned_up = true;
        }

        /// If nobody cleaned the token up (every producer failed), the last holder
        /// to be destroyed does it. Locks are taken in the same order as in
        /// getOrSet(): token mutex first, then the cache mutex.
        ~InsertTokenHolder()
        {
            if (!token)
                return;

            if (cleaned_up)
                return;

            std::lock_guard token_lock(token->mutex);

            if (token->cleaned_up)
                return;

            std::lock_guard cache_lock(token->cache.mutex);

            --token->refcount;
            if (token->refcount == 0)
                cleanup(token_lock, cache_lock);
        }
    };

    friend struct InsertTokenHolder;


    /// Pending insertions keyed by cache key. Protected by the cache mutex.
    InsertTokenById insert_tokens;

    /// Keys ordered from least- (front) to most-recently used (back); evicted from the front.
    LRUQueue queue;

    /// Total weight of values.
    size_t current_size = 0;
    const size_t max_size;

    /// Atomics, but in this file they are only ever modified while holding the cache mutex.
    std::atomic<size_t> hits {0};
    std::atomic<size_t> misses {0};

    WeightFunction weight_function;

    /// Looks the key up and, on a hit, promotes it to most-recently-used.
    /// The cache_lock parameter only proves the cache mutex is held by the caller.
    MappedPtr getImpl(const Key & key, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
    {
        auto it = cells.find(key);
        if (it == cells.end())
        {
            return MappedPtr();
        }

        Cell & cell = it->second;

        /// Move the key to the end of the queue. The iterator remains valid.
        queue.splice(queue.end(), queue, cell.queue_iterator);

        return cell.value;
    }

    /// Inserts or replaces the value, updates the total weight, then evicts overflow.
    /// The cache_lock parameter only proves the cache mutex is held by the caller.
    void setImpl(const Key & key, const MappedPtr & mapped, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock)
    {
        /// emplace leaves the map unchanged if the key already exists; the cell is then reused.
        auto res = cells.emplace(std::piecewise_construct,
            std::forward_as_tuple(key),
            std::forward_as_tuple());

        Cell & cell = res.first->second;
        bool inserted = res.second;

        if (inserted)
        {
            cell.queue_iterator = queue.insert(queue.end(), key);
        }
        else
        {
            /// Replacing an existing entry: subtract its old weight and refresh its LRU position.
            current_size -= cell.size;
            queue.splice(queue.end(), queue, cell.queue_iterator);
        }

        cell.value = mapped;
        /// A null mapped pointer weighs 0 (the entry is still stored).
        cell.size = cell.value ? weight_function(*cell.value) : 0;
        current_size += cell.size;

        removeOverflow();
    }

    /// Evicts least-recently-used entries while the total weight exceeds max_size.
    /// The (queue_size > 1) condition always keeps at least one entry, so a single
    /// value heavier than max_size can still reside in the cache.
    void removeOverflow()
    {
        size_t current_weight_lost = 0;
        size_t queue_size = cells.size();
        while ((current_size > max_size) && (queue_size > 1))
        {
            const Key & key = queue.front();

            auto it = cells.find(key);
            if (it == cells.end())
            {
                LOG_ERROR(&Logger::get("LRUCache"), "LRUCache became inconsistent. There must be a bug in it.");
                abort();
            }

            const auto & cell = it->second;

            current_size -= cell.size;
            current_weight_lost += cell.size;

            cells.erase(it);
            queue.pop_front();
            --queue_size;
        }

        onRemoveOverflowWeightLoss(current_weight_lost);

        /// A total weight above 2^63 almost certainly means current_size underflowed
        /// (more weight was subtracted than had been added) — treat it as corruption.
        if (current_size > (1ull << 63))
        {
            LOG_ERROR(&Logger::get("LRUCache"), "LRUCache became inconsistent. There must be a bug in it.");
            abort();
        }
    }

    /// Override this method if you want to track how much weight was lost in removeOverflow method.
    virtual void onRemoveOverflowWeightLoss(size_t /*weight_loss*/) {}
};
335
336
337}
338