1// LAF Base Library
2// Copyright (C) 2019-2022 Igara Studio S.A.
3//
4// This file is released under the terms of the MIT license.
5// Read LICENSE.txt for more information.
6
7#ifdef HAVE_CONFIG_H
8#include "config.h"
9#endif
10
11#include "base/debug.h"
12#include "base/log.h"
13#include "base/thread_pool.h"
14
15namespace base {
16
17thread_pool::thread_pool(const size_t n)
18 : m_running(true)
19 , m_threads(n)
20 , m_doingWork(0)
21{
22 std::unique_lock<std::mutex> lock(m_mutex);
23 for (size_t i=0; i<n; ++i)
24 m_threads[i] = std::thread([this]{ worker(); });
25}
26
27thread_pool::~thread_pool()
28{
29 join_all();
30}
31
32void thread_pool::execute(std::function<void()>&& func)
33{
34 std::unique_lock<std::mutex> lock(m_mutex);
35 ASSERT(m_running);
36 m_work.push(std::move(func));
37 m_cv.notify_one();
38}
39
40void thread_pool::wait_all()
41{
42 std::unique_lock<std::mutex> lock(m_mutex);
43 m_cvWait.wait(lock, [this]() -> bool {
44 return
45 !m_running ||
46 (m_work.empty() && m_doingWork == 0);
47 });
48}
49
50void thread_pool::join_all()
51{
52 {
53 std::unique_lock<std::mutex> lock(m_mutex);
54 m_running = false;
55 }
56 m_cv.notify_all();
57
58 for (auto& j : m_threads) {
59 try {
60 if (j.joinable())
61 j.join();
62 }
63 catch (const std::exception& ex) {
64 LOG(FATAL, "Exception joining threads: %s\n", ex.what());
65 ASSERT(false);
66 }
67 catch (...) {
68 LOG(FATAL, "Exception joining threads\n");
69 ASSERT(false);
70 }
71 }
72}
73
74void thread_pool::worker()
75{
76 bool running;
77 {
78 std::unique_lock<std::mutex> lock(m_mutex);
79 running = m_running;
80 }
81 while (running) {
82 std::function<void()> func;
83 {
84 std::unique_lock<std::mutex> lock(m_mutex);
85 m_cv.wait(lock, [this]() -> bool {
86 return !m_running || !m_work.empty();
87 });
88 running = m_running;
89 if (m_running && !m_work.empty()) {
90 func = std::move(m_work.front());
91 ++m_doingWork;
92 m_work.pop();
93 }
94 }
95 try {
96 if (func)
97 func();
98 }
99 // TODO handle exceptions in a better way
100 catch (const std::exception& e) {
101 LOG(FATAL, "Exception from worker: %s", e.what());
102 ASSERT(false);
103 }
104 catch (...) {
105 LOG(FATAL, "Exception from worker\n");
106 ASSERT(false);
107 }
108
109 // Decrement m_doingWork only if we've incremented it
110 if (func) {
111 std::unique_lock<std::mutex> lock(m_mutex);
112 --m_doingWork;
113 m_cvWait.notify_all();
114 }
115 }
116}
117
118}
119