| 1 | #pragma once | 
| 2 |  | 
| 3 | #include <unordered_map> | 
| 4 | #include <list> | 
| 5 | #include <memory> | 
| 6 | #include <chrono> | 
| 7 | #include <mutex> | 
| 8 | #include <atomic> | 
| 9 |  | 
| 10 | #include <common/logger_useful.h> | 
| 11 |  | 
| 12 |  | 
| 13 | namespace DB | 
| 14 | { | 
| 15 |  | 
/// Default weight functor: every value weighs exactly 1 regardless of its contents,
/// so a cache that uses it measures its capacity in number of values.
template <typename T>
struct TrivialWeightFunction
{
    size_t operator()(const T & /* value */) const
    {
        constexpr size_t unit_weight = 1;
        return unit_weight;
    }
};
| 24 |  | 
| 25 |  | 
| 26 | /// Thread-safe cache that evicts entries which are not used for a long time. | 
| 27 | /// WeightFunction is a functor that takes Mapped as a parameter and returns "weight" (approximate size) | 
| 28 | /// of that value. | 
| 29 | /// Cache starts to evict entries when their total weight exceeds max_size. | 
| 30 | /// Value weight should not change after insertion. | 
| 31 | template <typename TKey, typename TMapped, typename HashFunction = std::hash<TKey>, typename WeightFunction = TrivialWeightFunction<TMapped>> | 
| 32 | class LRUCache | 
| 33 | { | 
| 34 | public: | 
| 35 |     using Key = TKey; | 
| 36 |     using Mapped = TMapped; | 
| 37 |     using MappedPtr = std::shared_ptr<Mapped>; | 
| 38 |  | 
| 39 | private: | 
| 40 |     using Clock = std::chrono::steady_clock; | 
| 41 |  | 
| 42 | public: | 
| 43 |     LRUCache(size_t max_size_) | 
| 44 |         : max_size(std::max(static_cast<size_t>(1), max_size_)) {} | 
| 45 |  | 
| 46 |     MappedPtr get(const Key & key) | 
| 47 |     { | 
| 48 |         std::lock_guard lock(mutex); | 
| 49 |  | 
| 50 |         auto res = getImpl(key, lock); | 
| 51 |         if (res) | 
| 52 |             ++hits; | 
| 53 |         else | 
| 54 |             ++misses; | 
| 55 |  | 
| 56 |         return res; | 
| 57 |     } | 
| 58 |  | 
| 59 |     void set(const Key & key, const MappedPtr & mapped) | 
| 60 |     { | 
| 61 |         std::lock_guard lock(mutex); | 
| 62 |  | 
| 63 |         setImpl(key, mapped, lock); | 
| 64 |     } | 
| 65 |  | 
| 66 |     /// If the value for the key is in the cache, returns it. If it is not, calls load_func() to | 
| 67 |     /// produce it, saves the result in the cache and returns it. | 
| 68 |     /// Only one of several concurrent threads calling getOrSet() will call load_func(), | 
| 69 |     /// others will wait for that call to complete and will use its result (this helps prevent cache stampede). | 
| 70 |     /// Exceptions occuring in load_func will be propagated to the caller. Another thread from the | 
| 71 |     /// set of concurrent threads will then try to call its load_func etc. | 
| 72 |     /// | 
| 73 |     /// Returns std::pair of the cached value and a bool indicating whether the value was produced during this call. | 
| 74 |     template <typename LoadFunc> | 
| 75 |     std::pair<MappedPtr, bool> getOrSet(const Key & key, LoadFunc && load_func) | 
| 76 |     { | 
| 77 |         InsertTokenHolder token_holder; | 
| 78 |         { | 
| 79 |             std::lock_guard cache_lock(mutex); | 
| 80 |  | 
| 81 |             auto val = getImpl(key, cache_lock); | 
| 82 |             if (val) | 
| 83 |             { | 
| 84 |                 ++hits; | 
| 85 |                 return std::make_pair(val, false); | 
| 86 |             } | 
| 87 |  | 
| 88 |             auto & token = insert_tokens[key]; | 
| 89 |             if (!token) | 
| 90 |                 token = std::make_shared<InsertToken>(*this); | 
| 91 |  | 
| 92 |             token_holder.acquire(&key, token, cache_lock); | 
| 93 |         } | 
| 94 |  | 
| 95 |         InsertToken * token = token_holder.token.get(); | 
| 96 |  | 
| 97 |         std::lock_guard token_lock(token->mutex); | 
| 98 |  | 
| 99 |         token_holder.cleaned_up = token->cleaned_up; | 
| 100 |  | 
| 101 |         if (token->value) | 
| 102 |         { | 
| 103 |             /// Another thread already produced the value while we waited for token->mutex. | 
| 104 |             ++hits; | 
| 105 |             return std::make_pair(token->value, false); | 
| 106 |         } | 
| 107 |  | 
| 108 |         ++misses; | 
| 109 |         token->value = load_func(); | 
| 110 |  | 
| 111 |         std::lock_guard cache_lock(mutex); | 
| 112 |  | 
| 113 |         /// Insert the new value only if the token is still in present in insert_tokens. | 
| 114 |         /// (The token may be absent because of a concurrent reset() call). | 
| 115 |         bool result = false; | 
| 116 |         auto token_it = insert_tokens.find(key); | 
| 117 |         if (token_it != insert_tokens.end() && token_it->second.get() == token) | 
| 118 |         { | 
| 119 |             setImpl(key, token->value, cache_lock); | 
| 120 |             result = true; | 
| 121 |         } | 
| 122 |  | 
| 123 |         if (!token->cleaned_up) | 
| 124 |             token_holder.cleanup(token_lock, cache_lock); | 
| 125 |  | 
| 126 |         return std::make_pair(token->value, result); | 
| 127 |     } | 
| 128 |  | 
| 129 |     void getStats(size_t & out_hits, size_t & out_misses) const | 
| 130 |     { | 
| 131 |         std::lock_guard lock(mutex); | 
| 132 |         out_hits = hits; | 
| 133 |         out_misses = misses; | 
| 134 |     } | 
| 135 |  | 
| 136 |     size_t weight() const | 
| 137 |     { | 
| 138 |         std::lock_guard lock(mutex); | 
| 139 |         return current_size; | 
| 140 |     } | 
| 141 |  | 
| 142 |     size_t count() const | 
| 143 |     { | 
| 144 |         std::lock_guard lock(mutex); | 
| 145 |         return cells.size(); | 
| 146 |     } | 
| 147 |  | 
| 148 |     void reset() | 
| 149 |     { | 
| 150 |         std::lock_guard lock(mutex); | 
| 151 |         queue.clear(); | 
| 152 |         cells.clear(); | 
| 153 |         insert_tokens.clear(); | 
| 154 |         current_size = 0; | 
| 155 |         hits = 0; | 
| 156 |         misses = 0; | 
| 157 |     } | 
| 158 |  | 
| 159 |     virtual ~LRUCache() {} | 
| 160 |  | 
| 161 | protected: | 
| 162 |     using LRUQueue = std::list<Key>; | 
| 163 |     using LRUQueueIterator = typename LRUQueue::iterator; | 
| 164 |  | 
| 165 |     struct Cell | 
| 166 |     { | 
| 167 |         MappedPtr value; | 
| 168 |         size_t size; | 
| 169 |         LRUQueueIterator queue_iterator; | 
| 170 |     }; | 
| 171 |  | 
| 172 |     using Cells = std::unordered_map<Key, Cell, HashFunction>; | 
| 173 |  | 
| 174 |     Cells cells; | 
| 175 |  | 
| 176 |     mutable std::mutex mutex; | 
| 177 | private: | 
| 178 |  | 
| 179 |     /// Represents pending insertion attempt. | 
| 180 |     struct InsertToken | 
| 181 |     { | 
| 182 |         explicit InsertToken(LRUCache & cache_) : cache(cache_) {} | 
| 183 |  | 
| 184 |         std::mutex mutex; | 
| 185 |         bool cleaned_up = false; /// Protected by the token mutex | 
| 186 |         MappedPtr value; /// Protected by the token mutex | 
| 187 |  | 
| 188 |         LRUCache & cache; | 
| 189 |         size_t refcount = 0; /// Protected by the cache mutex | 
| 190 |     }; | 
| 191 |  | 
| 192 |     using InsertTokenById = std::unordered_map<Key, std::shared_ptr<InsertToken>, HashFunction>; | 
| 193 |  | 
| 194 |     /// This class is responsible for removing used insert tokens from the insert_tokens map. | 
| 195 |     /// Among several concurrent threads the first successful one is responsible for removal. But if they all | 
| 196 |     /// fail, then the last one is responsible. | 
| 197 |     struct InsertTokenHolder | 
| 198 |     { | 
| 199 |         const Key * key = nullptr; | 
| 200 |         std::shared_ptr<InsertToken> token; | 
| 201 |         bool cleaned_up = false; | 
| 202 |  | 
| 203 |         InsertTokenHolder() = default; | 
| 204 |  | 
| 205 |         void acquire(const Key * key_, const std::shared_ptr<InsertToken> & token_, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock) | 
| 206 |         { | 
| 207 |             key = key_; | 
| 208 |             token = token_; | 
| 209 |             ++token->refcount; | 
| 210 |         } | 
| 211 |  | 
| 212 |         void cleanup([[maybe_unused]] std::lock_guard<std::mutex> & token_lock, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock) | 
| 213 |         { | 
| 214 |             token->cache.insert_tokens.erase(*key); | 
| 215 |             token->cleaned_up = true; | 
| 216 |             cleaned_up = true; | 
| 217 |         } | 
| 218 |  | 
| 219 |         ~InsertTokenHolder() | 
| 220 |         { | 
| 221 |             if (!token) | 
| 222 |                 return; | 
| 223 |  | 
| 224 |             if (cleaned_up) | 
| 225 |                 return; | 
| 226 |  | 
| 227 |             std::lock_guard token_lock(token->mutex); | 
| 228 |  | 
| 229 |             if (token->cleaned_up) | 
| 230 |                 return; | 
| 231 |  | 
| 232 |             std::lock_guard cache_lock(token->cache.mutex); | 
| 233 |  | 
| 234 |             --token->refcount; | 
| 235 |             if (token->refcount == 0) | 
| 236 |                 cleanup(token_lock, cache_lock); | 
| 237 |         } | 
| 238 |     }; | 
| 239 |  | 
| 240 |     friend struct InsertTokenHolder; | 
| 241 |  | 
| 242 |  | 
| 243 |     InsertTokenById insert_tokens; | 
| 244 |  | 
| 245 |     LRUQueue queue; | 
| 246 |  | 
| 247 |     /// Total weight of values. | 
| 248 |     size_t current_size = 0; | 
| 249 |     const size_t max_size; | 
| 250 |  | 
| 251 |     std::atomic<size_t> hits {0}; | 
| 252 |     std::atomic<size_t> misses {0}; | 
| 253 |  | 
| 254 |     WeightFunction weight_function; | 
| 255 |  | 
| 256 |     MappedPtr getImpl(const Key & key, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock) | 
| 257 |     { | 
| 258 |         auto it = cells.find(key); | 
| 259 |         if (it == cells.end()) | 
| 260 |         { | 
| 261 |             return MappedPtr(); | 
| 262 |         } | 
| 263 |  | 
| 264 |         Cell & cell = it->second; | 
| 265 |  | 
| 266 |         /// Move the key to the end of the queue. The iterator remains valid. | 
| 267 |         queue.splice(queue.end(), queue, cell.queue_iterator); | 
| 268 |  | 
| 269 |         return cell.value; | 
| 270 |     } | 
| 271 |  | 
| 272 |     void setImpl(const Key & key, const MappedPtr & mapped, [[maybe_unused]] std::lock_guard<std::mutex> & cache_lock) | 
| 273 |     { | 
| 274 |         auto res = cells.emplace(std::piecewise_construct, | 
| 275 |             std::forward_as_tuple(key), | 
| 276 |             std::forward_as_tuple()); | 
| 277 |  | 
| 278 |         Cell & cell = res.first->second; | 
| 279 |         bool inserted = res.second; | 
| 280 |  | 
| 281 |         if (inserted) | 
| 282 |         { | 
| 283 |             cell.queue_iterator = queue.insert(queue.end(), key); | 
| 284 |         } | 
| 285 |         else | 
| 286 |         { | 
| 287 |             current_size -= cell.size; | 
| 288 |             queue.splice(queue.end(), queue, cell.queue_iterator); | 
| 289 |         } | 
| 290 |  | 
| 291 |         cell.value = mapped; | 
| 292 |         cell.size = cell.value ? weight_function(*cell.value) : 0; | 
| 293 |         current_size += cell.size; | 
| 294 |  | 
| 295 |         removeOverflow(); | 
| 296 |     } | 
| 297 |  | 
| 298 |     void removeOverflow() | 
| 299 |     { | 
| 300 |         size_t current_weight_lost = 0; | 
| 301 |         size_t queue_size = cells.size(); | 
| 302 |         while ((current_size > max_size) && (queue_size > 1)) | 
| 303 |         { | 
| 304 |             const Key & key = queue.front(); | 
| 305 |  | 
| 306 |             auto it = cells.find(key); | 
| 307 |             if (it == cells.end()) | 
| 308 |             { | 
| 309 |                 LOG_ERROR(&Logger::get("LRUCache" ), "LRUCache became inconsistent. There must be a bug in it." ); | 
| 310 |                 abort(); | 
| 311 |             } | 
| 312 |  | 
| 313 |             const auto & cell = it->second; | 
| 314 |  | 
| 315 |             current_size -= cell.size; | 
| 316 |             current_weight_lost += cell.size; | 
| 317 |  | 
| 318 |             cells.erase(it); | 
| 319 |             queue.pop_front(); | 
| 320 |             --queue_size; | 
| 321 |         } | 
| 322 |  | 
| 323 |         onRemoveOverflowWeightLoss(current_weight_lost); | 
| 324 |  | 
| 325 |         if (current_size > (1ull << 63)) | 
| 326 |         { | 
| 327 |             LOG_ERROR(&Logger::get("LRUCache" ), "LRUCache became inconsistent. There must be a bug in it." ); | 
| 328 |             abort(); | 
| 329 |         } | 
| 330 |     } | 
| 331 |  | 
| 332 |     /// Override this method if you want to track how much weight was lost in removeOverflow method. | 
| 333 |     virtual void onRemoveOverflowWeightLoss(size_t /*weight_loss*/) {} | 
| 334 | }; | 
| 335 |  | 
| 336 |  | 
| 337 | } | 
| 338 |  |