1#pragma once
2
3#include <time.h>
4#include <cstdlib>
5#include <climits>
6#include <random>
7#include <functional>
8#include <common/Types.h>
9#include <ext/scope_guard.h>
10#include <Core/Types.h>
11#include <Common/PoolBase.h>
12#include <Common/ProfileEvents.h>
13#include <Common/NetException.h>
14#include <Common/Exception.h>
15#include <Common/randomSeed.h>
16
17
namespace DB
{
/// Error codes used by the code in this header. They are only declared here (`extern`);
/// the definitions live in another translation unit.
namespace ErrorCodes
{
    /// Thrown by getMany() when fewer than min_entries usable connections were obtained.
    extern const int ALL_CONNECTION_TRIES_FAILED;
    /// Thrown by getMany() when there are not enough up-to-date replicas and falling back to stale ones is disallowed.
    extern const int ALL_REPLICAS_ARE_STALE;
    /// Thrown by get() and reportError() on conditions that are not expected to happen.
    extern const int LOGICAL_ERROR;
}
}
27
/// Profile event counters incremented by getMany(). They are only declared here (`extern`).
namespace ProfileEvents
{
    /// Incremented for every single failed try to get a connection from a nested pool.
    extern const Event DistributedConnectionFailTry;
    /// Incremented when a nested pool has accumulated max_tries failed tries within one getMany() call.
    extern const Event DistributedConnectionFailAtAll;
}
33
34/// This class provides a pool with fault tolerance. It is used for pooling of connections to replicated DB.
35/// Initialized by several PoolBase objects.
36/// When a connection is requested, tries to create or choose an alive connection from one of the nested pools.
37/// Pools are tried in the order consistent with lexicographical order of (error count, priority, random number) tuples.
38/// Number of tries for a single pool is limited by max_tries parameter.
39/// The client can set nested pool priority by passing a GetPriority functor.
40///
41/// NOTE: if one of the nested pools blocks because it is empty, this pool will also block.
42///
43/// The client must provide a TryGetEntryFunc functor, which should perform a single try to get a connection from a nested pool.
44/// This functor can also check if the connection satisfies some eligibility criterion (e.g. check if
45/// the replica is up-to-date).
46
template <typename TNestedPool>
class PoolWithFailoverBase : private boost::noncopyable
{
public:
    using NestedPool = TNestedPool;
    using NestedPoolPtr = std::shared_ptr<NestedPool>;
    using Entry = typename NestedPool::Entry;
    using NestedPools = std::vector<NestedPoolPtr>;

    /// nested_pools_          - the pools to choose connections from.
    /// decrease_error_period_ - error counts are halved once per this many seconds (see updatePoolStates()).
    /// max_error_cap_         - upper bound applied to the error count stored for each pool.
    /// log_                   - logger used to report failed connection tries.
    PoolWithFailoverBase(
            NestedPools nested_pools_,
            time_t decrease_error_period_,
            size_t max_error_cap_,
            Logger * log_)
        : nested_pools(std::move(nested_pools_))
        , decrease_error_period(decrease_error_period_)
        , max_error_cap(max_error_cap_)
        , shared_pool_states(nested_pools.size())   /// One state per nested pool; relies on nested_pools being initialized first.
        , log(log_)
    {
    }

    /// The outcome of a single try to get a connection from a nested pool.
    struct TryResult
    {
        /// An empty result: default-constructed entry, not usable, not up-to-date.
        TryResult() = default;

        /// Wraps an entry and marks it as both usable and up-to-date.
        explicit TryResult(Entry entry_)
            : entry(std::move(entry_))
            , is_usable(true)
            , is_up_to_date(true)
        {
        }

        /// Returns the result to the same state as a default-constructed one.
        void reset()
        {
            entry = Entry();
            is_usable = false;
            is_up_to_date = false;
            staleness = 0.0;
        }

        Entry entry;
        bool is_usable = false; /// If false, the entry is unusable for current request
                                /// (but may be usable for other requests, so error counts are not incremented)
        bool is_up_to_date = false; /// If true, the entry is a connection to up-to-date replica.
        double staleness = 0.0; /// Helps choosing the "least stale" option when all replicas are stale.
    };

    /// This functor must be provided by a client. It must perform a single try that takes a connection
    /// from the provided pool and checks that it is good.
    /// A non-empty fail_message is appended to the log included in the "all tries failed" exception.
    using TryGetEntryFunc = std::function<TryResult(NestedPool & pool, std::string & fail_message)>;

    /// The client can provide this functor to affect load balancing - the index of a pool is passed to
    /// this functor. The pools with lower result value will be tried first.
    using GetPriorityFunc = std::function<size_t(size_t index)>;

    /// Returns a single connection.
    /// Implemented as getMany() with min_entries = max_entries = max_tries = 1.
    Entry get(const TryGetEntryFunc & try_get_entry, const GetPriorityFunc & get_priority = GetPriorityFunc());


    /// Returns at least min_entries and at most max_entries connections (at most one connection per nested pool).
    /// The method will throw if it is unable to get min_entries alive connections or
    /// if fallback_to_stale_replicas is false and it is unable to get min_entries connections to up-to-date replicas.
    std::vector<TryResult> getMany(
            size_t min_entries, size_t max_entries, size_t max_tries,
            const TryGetEntryFunc & try_get_entry,
            const GetPriorityFunc & get_priority = GetPriorityFunc(),
            bool fallback_to_stale_replicas = true);

    /// Increments the shared error count of the nested pool that contains the entry.
    /// Throws LOGICAL_ERROR if no nested pool contains it.
    void reportError(const Entry & entry);

protected:
    struct PoolState;

    using PoolStates = std::vector<PoolState>;

    /// Refreshes the random numbers of the shared pool states and decays their error counts over time.
    /// This function returns a copy of pool states to avoid race conditions when modifying shared pool states.
    PoolStates updatePoolStates();
    /// Returns a copy of the shared pool states taken under pool_states_mutex, without modifying them.
    PoolStates getPoolStates() const;

    NestedPools nested_pools;

    const time_t decrease_error_period;
    const size_t max_error_cap;

    /// Guards shared_pool_states and last_error_decrease_time.
    mutable std::mutex pool_states_mutex;
    /// State of the nested pool with the same index in nested_pools.
    PoolStates shared_pool_states;
    /// The time when error counts were last decreased.
    time_t last_error_decrease_time = 0;

    Logger * log;
};
139
140template <typename TNestedPool>
141typename TNestedPool::Entry
142PoolWithFailoverBase<TNestedPool>::get(const TryGetEntryFunc & try_get_entry, const GetPriorityFunc & get_priority)
143{
144 std::vector<TryResult> results = getMany(1, 1, 1, try_get_entry, get_priority);
145 if (results.empty() || results[0].entry.isNull())
146 throw DB::Exception(
147 "PoolWithFailoverBase::getMany() returned less than min_entries entries.",
148 DB::ErrorCodes::LOGICAL_ERROR);
149 return results[0].entry;
150}
151
152template <typename TNestedPool>
153std::vector<typename PoolWithFailoverBase<TNestedPool>::TryResult>
154PoolWithFailoverBase<TNestedPool>::getMany(
155 size_t min_entries, size_t max_entries, size_t max_tries,
156 const TryGetEntryFunc & try_get_entry,
157 const GetPriorityFunc & get_priority,
158 bool fallback_to_stale_replicas)
159{
160 /// Update random numbers and error counts.
161 PoolStates pool_states = updatePoolStates();
162 if (get_priority)
163 {
164 for (size_t i = 0; i < pool_states.size(); ++i)
165 pool_states[i].priority = get_priority(i);
166 }
167
168 struct ShuffledPool
169 {
170 NestedPool * pool{};
171 const PoolState * state{};
172 size_t index = 0;
173 size_t error_count = 0;
174 };
175
176 /// Sort the pools into order in which they will be tried (based on respective PoolStates).
177 std::vector<ShuffledPool> shuffled_pools;
178 shuffled_pools.reserve(nested_pools.size());
179 for (size_t i = 0; i < nested_pools.size(); ++i)
180 shuffled_pools.push_back(ShuffledPool{nested_pools[i].get(), &pool_states[i], i, 0});
181 std::sort(
182 shuffled_pools.begin(), shuffled_pools.end(),
183 [](const ShuffledPool & lhs, const ShuffledPool & rhs)
184 {
185 return PoolState::compare(*lhs.state, *rhs.state);
186 });
187
188 /// We will try to get a connection from each pool until a connection is produced or max_tries is reached.
189 std::vector<TryResult> try_results(shuffled_pools.size());
190 size_t entries_count = 0;
191 size_t usable_count = 0;
192 size_t up_to_date_count = 0;
193 size_t failed_pools_count = 0;
194
195 /// At exit update shared error counts with error counts occurred during this call.
196 SCOPE_EXIT(
197 {
198 std::lock_guard lock(pool_states_mutex);
199 for (const ShuffledPool & pool: shuffled_pools)
200 {
201 auto & pool_state = shared_pool_states[pool.index];
202 pool_state.error_count = std::min<UInt64>(max_error_cap, pool_state.error_count + pool.error_count);
203 }
204 });
205
206 std::string fail_messages;
207 bool finished = false;
208 while (!finished)
209 {
210 for (size_t i = 0; i < shuffled_pools.size(); ++i)
211 {
212 if (up_to_date_count >= max_entries /// Already enough good entries.
213 || entries_count + failed_pools_count >= nested_pools.size()) /// No more good entries will be produced.
214 {
215 finished = true;
216 break;
217 }
218
219 ShuffledPool & shuffled_pool = shuffled_pools[i];
220 TryResult & result = try_results[i];
221 if (shuffled_pool.error_count >= max_tries || !result.entry.isNull())
222 continue;
223
224 std::string fail_message;
225 result = try_get_entry(*shuffled_pool.pool, fail_message);
226
227 if (!fail_message.empty())
228 fail_messages += fail_message + '\n';
229
230 if (!result.entry.isNull())
231 {
232 ++entries_count;
233 if (result.is_usable)
234 {
235 ++usable_count;
236 if (result.is_up_to_date)
237 ++up_to_date_count;
238 }
239 }
240 else
241 {
242 LOG_WARNING(log, "Connection failed at try №"
243 << (shuffled_pool.error_count + 1) << ", reason: " << fail_message);
244 ProfileEvents::increment(ProfileEvents::DistributedConnectionFailTry);
245
246 shuffled_pool.error_count = std::min(max_error_cap, shuffled_pool.error_count + 1);
247
248 if (shuffled_pool.error_count >= max_tries)
249 {
250 ++failed_pools_count;
251 ProfileEvents::increment(ProfileEvents::DistributedConnectionFailAtAll);
252 }
253 }
254 }
255 }
256
257 if (usable_count < min_entries)
258 throw DB::NetException(
259 "All connection tries failed. Log: \n\n" + fail_messages + "\n",
260 DB::ErrorCodes::ALL_CONNECTION_TRIES_FAILED);
261
262 try_results.erase(
263 std::remove_if(
264 try_results.begin(), try_results.end(),
265 [](const TryResult & r) { return r.entry.isNull() || !r.is_usable; }),
266 try_results.end());
267
268 /// Sort so that preferred items are near the beginning.
269 std::stable_sort(
270 try_results.begin(), try_results.end(),
271 [](const TryResult & left, const TryResult & right)
272 {
273 return std::forward_as_tuple(!left.is_up_to_date, left.staleness)
274 < std::forward_as_tuple(!right.is_up_to_date, right.staleness);
275 });
276
277 if (up_to_date_count >= min_entries)
278 {
279 /// There is enough up-to-date entries.
280 try_results.resize(up_to_date_count);
281 }
282 else if (fallback_to_stale_replicas)
283 {
284 /// There is not enough up-to-date entries but we are allowed to return stale entries.
285 /// Gather all up-to-date ones and least-bad stale ones.
286
287 size_t size = std::min(try_results.size(), max_entries);
288 try_results.resize(size);
289 }
290 else
291 throw DB::Exception(
292 "Could not find enough connections to up-to-date replicas. Got: " + std::to_string(up_to_date_count)
293 + ", needed: " + std::to_string(min_entries),
294 DB::ErrorCodes::ALL_REPLICAS_ARE_STALE);
295
296 return try_results;
297}
298
299template <typename TNestedPool>
300void PoolWithFailoverBase<TNestedPool>::reportError(const Entry & entry)
301{
302 for (size_t i = 0; i < nested_pools.size(); ++i)
303 {
304 if (nested_pools[i]->contains(entry))
305 {
306 std::lock_guard lock(pool_states_mutex);
307 auto & pool_state = shared_pool_states[i];
308 pool_state.error_count = std::min(max_error_cap, pool_state.error_count + 1);
309 return;
310 }
311 }
312 throw DB::Exception("Can't find pool to report error", DB::ErrorCodes::LOGICAL_ERROR);
313}
314
315template <typename TNestedPool>
316struct PoolWithFailoverBase<TNestedPool>::PoolState
317{
318 UInt64 error_count = 0;
319 Int64 priority = 0;
320 UInt32 random = 0;
321
322 void randomize()
323 {
324 random = rng();
325 }
326
327 static bool compare(const PoolState & lhs, const PoolState & rhs)
328 {
329 return std::forward_as_tuple(lhs.error_count, lhs.priority, lhs.random)
330 < std::forward_as_tuple(rhs.error_count, rhs.priority, rhs.random);
331 }
332
333private:
334 std::minstd_rand rng = std::minstd_rand(randomSeed());
335};
336
337template <typename TNestedPool>
338typename PoolWithFailoverBase<TNestedPool>::PoolStates
339PoolWithFailoverBase<TNestedPool>::updatePoolStates()
340{
341 PoolStates result;
342 result.reserve(nested_pools.size());
343
344 {
345 std::lock_guard lock(pool_states_mutex);
346
347 for (auto & state : shared_pool_states)
348 state.randomize();
349
350 time_t current_time = time(nullptr);
351
352 if (last_error_decrease_time)
353 {
354 time_t delta = current_time - last_error_decrease_time;
355
356 if (delta >= 0)
357 {
358 /// Divide error counts by 2 every decrease_error_period seconds.
359 size_t shift_amount = delta / decrease_error_period;
360 /// Update time but don't do it more often than once a period.
361 /// Else if the function is called often enough, error count will never decrease.
362 if (shift_amount)
363 last_error_decrease_time = current_time;
364
365 if (shift_amount >= sizeof(UInt64) * CHAR_BIT)
366 {
367 for (auto & state : shared_pool_states)
368 state.error_count = 0;
369 }
370 else if (shift_amount)
371 {
372 for (auto & state : shared_pool_states)
373 state.error_count >>= shift_amount;
374 }
375 }
376 }
377 else
378 last_error_decrease_time = current_time;
379
380 result.assign(shared_pool_states.begin(), shared_pool_states.end());
381 }
382 return result;
383}
384
385template <typename TNestedPool>
386typename PoolWithFailoverBase<TNestedPool>::PoolStates
387PoolWithFailoverBase<TNestedPool>::getPoolStates() const
388{
389 std::lock_guard lock(pool_states_mutex);
390 return shared_pool_states;
391}
392