1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#include "arrow/util/task-group.h"
19
20#include <atomic>
21#include <condition_variable>
22#include <cstdint>
23#include <mutex>
24#include <utility>
25
26#include "arrow/util/logging.h"
27#include "arrow/util/thread-pool.h"
28
29namespace arrow {
30namespace internal {
31
32////////////////////////////////////////////////////////////////////////
33// Serial TaskGroup implementation
34
35class SerialTaskGroup : public TaskGroup {
36 public:
37 void AppendReal(std::function<Status()> task) override {
38 DCHECK(!finished_);
39 if (status_.ok()) {
40 status_ &= task();
41 }
42 }
43
44 Status current_status() override { return status_; }
45
46 bool ok() override { return status_.ok(); }
47
48 Status Finish() override {
49 if (!finished_) {
50 finished_ = true;
51 if (parent_) {
52 parent_->status_ &= status_;
53 }
54 }
55 return status_;
56 }
57
58 int parallelism() override { return 1; }
59
60 std::shared_ptr<TaskGroup> MakeSubGroup() override {
61 auto child = new SerialTaskGroup();
62 child->parent_ = this;
63 return std::shared_ptr<TaskGroup>(child);
64 }
65
66 protected:
67 Status status_;
68 bool finished_ = false;
69 SerialTaskGroup* parent_ = nullptr;
70};
71
72////////////////////////////////////////////////////////////////////////
73// Threaded TaskGroup implementation
74
75class ThreadedTaskGroup : public TaskGroup {
76 public:
77 explicit ThreadedTaskGroup(ThreadPool* thread_pool)
78 : thread_pool_(thread_pool), nremaining_(0), ok_(true) {}
79
80 ~ThreadedTaskGroup() override {
81 // Make sure all pending tasks are finished, so that dangling references
82 // to this don't persist.
83 ARROW_UNUSED(Finish());
84 }
85
86 void AppendReal(std::function<Status()> task) override {
87 // The hot path is unlocked thanks to atomics
88 // Only if an error occurs is the lock taken
89 if (ok_.load(std::memory_order_acquire)) {
90 nremaining_.fetch_add(1, std::memory_order_acquire);
91 Status st = thread_pool_->Spawn([this, task]() {
92 if (ok_.load(std::memory_order_acquire)) {
93 // XXX what about exceptions?
94 Status st = task();
95 UpdateStatus(std::move(st));
96 }
97 OneTaskDone();
98 });
99 UpdateStatus(std::move(st));
100 }
101 }
102
103 Status current_status() override {
104 std::lock_guard<std::mutex> lock(mutex_);
105 return status_;
106 }
107
108 bool ok() override { return ok_.load(); }
109
110 Status Finish() override {
111 std::unique_lock<std::mutex> lock(mutex_);
112 if (!finished_) {
113 cv_.wait(lock, [&]() { return nremaining_.load() == 0; });
114 // Current tasks may start other tasks, so only set this when done
115 finished_ = true;
116 if (parent_) {
117 parent_->OneTaskDone();
118 }
119 }
120 return status_;
121 }
122
123 int parallelism() override { return thread_pool_->GetCapacity(); }
124
125 std::shared_ptr<TaskGroup> MakeSubGroup() override {
126 std::lock_guard<std::mutex> lock(mutex_);
127 auto child = new ThreadedTaskGroup(thread_pool_);
128 child->parent_ = this;
129 nremaining_.fetch_add(1, std::memory_order_acquire);
130 return std::shared_ptr<TaskGroup>(child);
131 }
132
133 protected:
134 void UpdateStatus(Status&& st) {
135 // Must be called unlocked, only locks on error
136 if (ARROW_PREDICT_FALSE(!st.ok())) {
137 std::lock_guard<std::mutex> lock(mutex_);
138 ok_.store(false, std::memory_order_release);
139 status_ &= std::move(st);
140 }
141 }
142
143 void OneTaskDone() {
144 // Can be called unlocked thanks to atomics
145 auto nremaining = nremaining_.fetch_sub(1, std::memory_order_release) - 1;
146 DCHECK_GE(nremaining, 0);
147 if (nremaining == 0) {
148 // Take the lock so that ~ThreadedTaskGroup cannot destroy cv
149 // before cv.notify_one() has returned
150 std::unique_lock<std::mutex> lock(mutex_);
151 cv_.notify_one();
152 }
153 }
154
155 // These members are usable unlocked
156 ThreadPool* thread_pool_;
157 std::atomic<int32_t> nremaining_;
158 std::atomic<bool> ok_;
159
160 // These members use locking
161 std::mutex mutex_;
162 std::condition_variable cv_;
163 Status status_;
164 bool finished_ = false;
165 ThreadedTaskGroup* parent_ = nullptr;
166};
167
168std::shared_ptr<TaskGroup> TaskGroup::MakeSerial() {
169 return std::shared_ptr<TaskGroup>(new SerialTaskGroup);
170}
171
172std::shared_ptr<TaskGroup> TaskGroup::MakeThreaded(ThreadPool* thread_pool) {
173 return std::shared_ptr<TaskGroup>(new ThreadedTaskGroup(thread_pool));
174}
175
176} // namespace internal
177} // namespace arrow
178