1 | // Licensed to the Apache Software Foundation (ASF) under one |
2 | // or more contributor license agreements. See the NOTICE file |
3 | // distributed with this work for additional information |
4 | // regarding copyright ownership. The ASF licenses this file |
5 | // to you under the Apache License, Version 2.0 (the |
6 | // "License"); you may not use this file except in compliance |
7 | // with the License. You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, |
12 | // software distributed under the License is distributed on an |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | // KIND, either express or implied. See the License for the |
15 | // specific language governing permissions and limitations |
16 | // under the License. |
17 | |
18 | #include <algorithm> |
19 | #include <chrono> |
20 | #include <cstdint> |
21 | #include <cstring> |
22 | #include <functional> |
23 | #include <memory> |
24 | #include <mutex> |
25 | #include <set> |
26 | #include <string> |
27 | #include <thread> |
28 | #include <utility> |
29 | #include <vector> |
30 | |
31 | #include <gtest/gtest.h> |
32 | |
33 | #include "arrow/buffer.h" |
34 | #include "arrow/io/interfaces.h" |
35 | #include "arrow/io/memory.h" |
36 | #include "arrow/io/readahead.h" |
37 | #include "arrow/memory_pool.h" |
38 | #include "arrow/status.h" |
39 | #include "arrow/test-util.h" |
40 | #include "arrow/util/checked_cast.h" |
41 | |
42 | namespace arrow { |
43 | |
44 | using internal::checked_cast; |
45 | |
46 | namespace io { |
47 | namespace internal { |
48 | |
49 | class LockedInputStream : public InputStream { |
50 | public: |
51 | explicit LockedInputStream(const std::shared_ptr<InputStream>& stream) |
52 | : stream_(stream) {} |
53 | |
54 | Status Close() override { |
55 | std::lock_guard<std::mutex> lock(mutex_); |
56 | return stream_->Close(); |
57 | } |
58 | |
59 | bool closed() const override { |
60 | std::lock_guard<std::mutex> lock(mutex_); |
61 | return stream_->closed(); |
62 | } |
63 | |
64 | Status Tell(int64_t* position) const override { |
65 | std::lock_guard<std::mutex> lock(mutex_); |
66 | return stream_->Tell(position); |
67 | } |
68 | |
69 | Status Read(int64_t nbytes, int64_t* bytes_read, void* buffer) override { |
70 | std::lock_guard<std::mutex> lock(mutex_); |
71 | return stream_->Read(nbytes, bytes_read, buffer); |
72 | } |
73 | |
74 | Status Read(int64_t nbytes, std::shared_ptr<Buffer>* out) override { |
75 | std::lock_guard<std::mutex> lock(mutex_); |
76 | return stream_->Read(nbytes, out); |
77 | } |
78 | |
79 | bool supports_zero_copy() const override { |
80 | std::lock_guard<std::mutex> lock(mutex_); |
81 | return stream_->supports_zero_copy(); |
82 | } |
83 | |
84 | util::string_view Peek(int64_t nbytes) const override { |
85 | std::lock_guard<std::mutex> lock(mutex_); |
86 | return stream_->Peek(nbytes); |
87 | } |
88 | |
89 | protected: |
90 | std::shared_ptr<InputStream> stream_; |
91 | mutable std::mutex mutex_; |
92 | }; |
93 | |
// Blocks the calling thread for `seconds` seconds. The delay is converted to
// whole nanoseconds, truncating any sub-nanosecond fraction.
static void sleep_for(double seconds) {
  using FloatSeconds = std::chrono::duration<double>;
  const auto delay =
      std::chrono::duration_cast<std::chrono::nanoseconds>(FloatSeconds(seconds));
  std::this_thread::sleep_for(delay);
}
98 | |
99 | static void busy_wait(double seconds, std::function<bool()> predicate) { |
100 | const double period = 0.001; |
101 | for (int i = 0; !predicate() && i * period < seconds; ++i) { |
102 | sleep_for(period); |
103 | } |
104 | } |
105 | |
106 | std::shared_ptr<InputStream> DataReader(const std::string& data) { |
107 | std::shared_ptr<Buffer> buffer; |
108 | ABORT_NOT_OK(Buffer::FromString(data, &buffer)); |
109 | return std::make_shared<LockedInputStream>(std::make_shared<BufferReader>(buffer)); |
110 | } |
111 | |
112 | static int64_t WaitForPosition(const FileInterface& file, int64_t expected, |
113 | double seconds = 0.2) { |
114 | int64_t pos = -1; |
115 | busy_wait(seconds, [&]() -> bool { |
116 | ABORT_NOT_OK(file.Tell(&pos)); |
117 | return pos >= expected; |
118 | }); |
119 | return pos; |
120 | } |
121 | |
122 | static void AssertEventualPosition(const FileInterface& file, int64_t expected) { |
123 | int64_t pos = WaitForPosition(file, expected); |
124 | ASSERT_EQ(pos, expected) << "File didn't reach expected position" ; |
125 | } |
126 | |
127 | static void AssertPosition(const FileInterface& file, int64_t expected) { |
128 | int64_t pos = -1; |
129 | ABORT_NOT_OK(file.Tell(&pos)); |
130 | ASSERT_EQ(pos, expected) << "File didn't reach expected position" ; |
131 | } |
132 | |
133 | template <typename Expected> |
134 | static void AssertReadaheadBuffer(const ReadaheadBuffer& buf, |
135 | std::set<int64_t> left_paddings, |
136 | std::set<int64_t> right_paddings, |
137 | const Expected& expected_data) { |
138 | ASSERT_TRUE(left_paddings.count(buf.left_padding)) |
139 | << "Left padding (" << buf.left_padding << ") not amongst expected values" ; |
140 | ASSERT_TRUE(right_paddings.count(buf.right_padding)) |
141 | << "Right padding (" << buf.right_padding << ") not amongst expected values" ; |
142 | auto actual_data = |
143 | SliceBuffer(buf.buffer, buf.left_padding, |
144 | buf.buffer->size() - buf.left_padding - buf.right_padding); |
145 | AssertBufferEqual(*actual_data, expected_data); |
146 | } |
147 | |
// Asserts that `buf` is the spooler's end-of-stream marker, i.e. a
// ReadaheadBuffer whose buffer pointer is null.
static void AssertReadaheadBufferEOF(const ReadaheadBuffer& buf) {
  ASSERT_EQ(buf.buffer.get(), nullptr) << "Expected EOF signalled by null buffer pointer";
}
151 | |
// Reads a 10-byte stream in 2-byte chunks and checks that:
//  - the chunks are delivered in order with no padding;
//  - EOF is signalled (repeatedly) by a null buffer;
//  - the underlying stream position tracks the readahead. It reaches 3 chunks
//    before any Read(), then advances one chunk per consumed chunk.
TEST(ReadaheadSpooler, BasicReads) {
  // Test basic reads
  auto data_reader = DataReader("0123456789");
  // Presumably (stream, read size = 2, readahead depth = 3). TODO confirm
  // against ReadaheadSpooler's constructor; the 3 * 2 check below relies on it.
  ReadaheadSpooler spooler(data_reader, 2, 3);
  ReadaheadBuffer buf;

  // Before any Read(), the background reader is expected to have consumed 3 chunks.
  AssertEventualPosition(*data_reader, 3 * 2);

  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "01");
  // Consuming a chunk lets the reader fetch one more.
  AssertEventualPosition(*data_reader, 4 * 2);
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "23");
  // Now the whole stream (10 bytes) has been read ahead.
  AssertEventualPosition(*data_reader, 5 * 2);
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "45");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "67");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "89");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBufferEOF(buf);
  // Reading again after EOF still succeeds and still reports EOF.
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBufferEOF(buf);
}
177 | |
// A 5-byte stream read in 3-byte chunks. The final chunk is short ("34"),
// followed by EOF.
TEST(ReadaheadSpooler, ShortReadAtEnd) {
  auto data_reader = DataReader("01234");
  ReadaheadSpooler spooler(data_reader, 3, 2);
  ReadaheadBuffer buf;

  // The test expects the whole stream to be read ahead before any Read().
  AssertEventualPosition(*data_reader, 5);

  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "012");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "34");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBufferEOF(buf);
}
192 | |
// Closing the spooler stops further reads from the underlying stream:
//  - chunks already read ahead before Close() are still returned by Read();
//  - Read() then reports EOF;
//  - the stream position stays where it was at Close() time (2 chunks).
// Close() must also succeed when called a second time.
TEST(ReadaheadSpooler, Close) {
  // Closing should stop reads
  auto data_reader = DataReader("0123456789");
  ReadaheadSpooler spooler(data_reader, 2, 2);
  ReadaheadBuffer buf;

  AssertEventualPosition(*data_reader, 2 * 2);
  ASSERT_OK(spooler.Close());

  // XXX not sure this makes sense: already-buffered data is still served
  // after Close() instead of being discarded.
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "01");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {0}, {0}, "23");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBufferEOF(buf);
  // No reads happened after Close().
  AssertPosition(*data_reader, 2 * 2);

  // Idempotency
  ASSERT_OK(spooler.Close());
}
214 | |
// Padding changes made through the setters apply only to chunks read after
// the change:
//  - the two chunks read ahead before the setters were called keep the
//    constructor's paddings (1, 4);
//  - later chunks use (3, 2).
// The final SetLeftPadding(4) may or may not reach the last chunk, depending
// on when the background reader fetched it, so both 3 and 4 are accepted there.
TEST(ReadaheadSpooler, Paddings) {
  auto data_reader = DataReader("0123456789");
  ReadaheadSpooler spooler(data_reader, 2, 2, 1 /* left_padding */,
                           4 /* right_padding */);
  ReadaheadBuffer buf;

  // Wait until two chunks have been read with the initial paddings.
  AssertEventualPosition(*data_reader, 2 * 2);
  ASSERT_EQ(spooler.GetLeftPadding(), 1);
  ASSERT_EQ(spooler.GetRightPadding(), 4);
  spooler.SetLeftPadding(3);
  spooler.SetRightPadding(2);
  ASSERT_EQ(spooler.GetLeftPadding(), 3);
  ASSERT_EQ(spooler.GetRightPadding(), 2);

  // Already-queued chunks: old paddings.
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {1}, {4}, "01");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {1}, {4}, "23");
  // Chunks read after the setters: new paddings.
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {3}, {2}, "45");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {3}, {2}, "67");
  spooler.SetLeftPadding(4);
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBuffer(buf, {3, 4}, {2}, "89");
  ASSERT_OK(spooler.Read(&buf));
  AssertReadaheadBufferEOF(buf);
}
243 | |
244 | TEST(ReadaheadSpooler, StressReads) { |
245 | // NBYTES % READ_SIZE != 0 ensures a short read at end |
246 | #if defined(ARROW_VALGRIND) |
247 | const int64_t NBYTES = 101; |
248 | #else |
249 | const int64_t NBYTES = 50001; |
250 | #endif |
251 | const int64_t READ_SIZE = 2; |
252 | |
253 | std::shared_ptr<ResizableBuffer> data; |
254 | ASSERT_OK(MakeRandomByteBuffer(NBYTES, default_memory_pool(), &data)); |
255 | auto data_reader = std::make_shared<BufferReader>(data); |
256 | |
257 | ReadaheadSpooler spooler(data_reader, READ_SIZE, 7); |
258 | int64_t pos = 0; |
259 | std::vector<ReadaheadBuffer> readahead_buffers; |
260 | |
261 | // Stress Read() calls while the background thread is reading ahead |
262 | while (pos < NBYTES) { |
263 | ReadaheadBuffer buf; |
264 | ASSERT_OK(spooler.Read(&buf)); |
265 | ASSERT_NE(buf.buffer.get(), nullptr) << "Got premature EOF at index " << pos; |
266 | pos += buf.buffer->size() - buf.left_padding - buf.right_padding; |
267 | readahead_buffers.push_back(std::move(buf)); |
268 | } |
269 | // At EOF |
270 | { |
271 | ReadaheadBuffer buf; |
272 | ASSERT_OK(spooler.Read(&buf)); |
273 | AssertReadaheadBufferEOF(buf); |
274 | } |
275 | |
276 | pos = 0; |
277 | for (const auto& buf : readahead_buffers) { |
278 | auto expected_data = SliceBuffer(data, pos, std::min(READ_SIZE, NBYTES - pos)); |
279 | AssertReadaheadBuffer(buf, {0}, {0}, *expected_data); |
280 | pos += expected_data->size(); |
281 | } |
282 | // Got exactly the total bytes length |
283 | ASSERT_EQ(pos, NBYTES); |
284 | } |
285 | |
286 | } // namespace internal |
287 | } // namespace io |
288 | } // namespace arrow |
289 | |