1/*
2 * Copyright 2014-present Facebook, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <thread>
18
19#include <folly/Conv.h>
20#include <folly/SocketAddress.h>
21#include <folly/String.h>
22#include <folly/io/IOBuf.h>
23#include <folly/io/async/AsyncTimeout.h>
24#include <folly/io/async/AsyncUDPServerSocket.h>
25#include <folly/io/async/AsyncUDPSocket.h>
26#include <folly/io/async/EventBase.h>
27#include <folly/portability/GMock.h>
28#include <folly/portability/GTest.h>
29#include <folly/portability/Sockets.h>
30
31using folly::AsyncTimeout;
32using folly::AsyncUDPServerSocket;
33using folly::AsyncUDPSocket;
34using folly::errnoStr;
35using folly::EventBase;
36using folly::IOBuf;
37using folly::SocketAddress;
38using namespace testing;
39
40class UDPAcceptor : public AsyncUDPServerSocket::Callback {
41 public:
42 UDPAcceptor(EventBase* evb, int n, bool changePortForWrites)
43 : evb_(evb), n_(n), changePortForWrites_(changePortForWrites) {}
44
45 void onListenStarted() noexcept override {}
46
47 void onListenStopped() noexcept override {}
48
49 void onDataAvailable(
50 std::shared_ptr<folly::AsyncUDPSocket> socket,
51 const folly::SocketAddress& client,
52 std::unique_ptr<folly::IOBuf> data,
53 bool truncated) noexcept override {
54 lastClient_ = client;
55 lastMsg_ = data->clone()->moveToFbString().toStdString();
56
57 auto len = data->computeChainDataLength();
58 VLOG(4) << "Worker " << n_ << " read " << len << " bytes "
59 << "(trun:" << truncated << ") from " << client.describe() << " - "
60 << lastMsg_;
61
62 sendPong(socket);
63 }
64
65 void sendPong(std::shared_ptr<folly::AsyncUDPSocket> socket) noexcept {
66 try {
67 auto writeSocket = socket;
68 if (changePortForWrites_) {
69 writeSocket = std::make_shared<folly::AsyncUDPSocket>(evb_);
70 writeSocket->setReuseAddr(false);
71 writeSocket->bind(folly::SocketAddress("127.0.0.1", 0));
72 }
73 writeSocket->write(lastClient_, folly::IOBuf::copyBuffer(lastMsg_));
74 } catch (const std::exception& ex) {
75 VLOG(4) << "Failed to send PONG " << ex.what();
76 }
77 }
78
79 private:
80 EventBase* const evb_{nullptr};
81 const int n_{-1};
82 // Whether to create a new port per write.
83 bool changePortForWrites_{true};
84
85 folly::SocketAddress lastClient_;
86 std::string lastMsg_;
87};
88
89class UDPServer {
90 public:
91 UDPServer(EventBase* evb, folly::SocketAddress addr, int n)
92 : evb_(evb), addr_(addr), evbs_(n) {}
93
94 void start() {
95 CHECK(evb_->isInEventBaseThread());
96
97 socket_ = std::make_unique<AsyncUDPServerSocket>(evb_, 1500);
98
99 try {
100 socket_->bind(addr_);
101 VLOG(4) << "Server listening on " << socket_->address().describe();
102 } catch (const std::exception& ex) {
103 LOG(FATAL) << ex.what();
104 }
105
106 acceptors_.reserve(evbs_.size());
107 threads_.reserve(evbs_.size());
108
109 // Add numWorkers thread
110 int i = 0;
111 for (auto& evb : evbs_) {
112 acceptors_.emplace_back(&evb, i, changePortForWrites_);
113
114 std::thread t([&]() { evb.loopForever(); });
115
116 evb.waitUntilRunning();
117
118 socket_->addListener(&evb, &acceptors_[i]);
119 threads_.emplace_back(std::move(t));
120 ++i;
121 }
122
123 socket_->listen();
124 }
125
126 folly::SocketAddress address() const {
127 return socket_->address();
128 }
129
130 void shutdown() {
131 CHECK(evb_->isInEventBaseThread());
132 socket_->close();
133 socket_.reset();
134
135 for (auto& evb : evbs_) {
136 evb.terminateLoopSoon();
137 }
138
139 for (auto& t : threads_) {
140 t.join();
141 }
142 }
143
144 void pauseAccepting() {
145 socket_->pauseAccepting();
146 }
147
148 void resumeAccepting() {
149 socket_->resumeAccepting();
150 }
151
152 // Whether writes from the UDP server should change the port for each message.
153 void setChangePortForWrites(bool changePortForWrites) {
154 changePortForWrites_ = changePortForWrites;
155 }
156
157 private:
158 EventBase* const evb_{nullptr};
159 const folly::SocketAddress addr_;
160
161 std::unique_ptr<AsyncUDPServerSocket> socket_;
162 std::vector<std::thread> threads_;
163 std::vector<folly::EventBase> evbs_;
164 std::vector<UDPAcceptor> acceptors_;
165 bool changePortForWrites_{true};
166};
167
168class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
169 public:
170 explicit UDPClient(EventBase* evb) : AsyncTimeout(evb), evb_(evb) {}
171
172 void start(const folly::SocketAddress& server, int n) {
173 CHECK(evb_->isInEventBaseThread());
174 server_ = server;
175 socket_ = std::make_unique<AsyncUDPSocket>(evb_);
176
177 try {
178 socket_->bind(folly::SocketAddress("127.0.0.1", 0));
179 if (connectAddr_) {
180 connect();
181 }
182 VLOG(2) << "Client bound to " << socket_->address().describe();
183 } catch (const std::exception& ex) {
184 LOG(FATAL) << ex.what();
185 }
186
187 socket_->resumeRead(this);
188
189 n_ = n;
190
191 // Start playing ping pong
192 sendPing();
193 }
194
195 void connect() {
196 int ret = socket_->connect(*connectAddr_);
197 if (ret != 0) {
198 throw folly::AsyncSocketException(
199 folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
200 }
201 VLOG(2) << "Client connected to address=" << *connectAddr_;
202 }
203
204 void shutdown() {
205 CHECK(evb_->isInEventBaseThread());
206 socket_->pauseRead();
207 socket_->close();
208 socket_.reset();
209 evb_->terminateLoopSoon();
210 }
211
212 void sendPing() {
213 if (n_ == 0) {
214 shutdown();
215 return;
216 }
217
218 --n_;
219 scheduleTimeout(5);
220 writePing(folly::IOBuf::copyBuffer(folly::to<std::string>("PING ", n_)));
221 }
222
223 virtual void writePing(std::unique_ptr<folly::IOBuf> buf) {
224 socket_->write(server_, std::move(buf));
225 }
226
227 void getReadBuffer(void** buf, size_t* len) noexcept override {
228 *buf = buf_;
229 *len = 1024;
230 }
231
232 void onDataAvailable(
233 const folly::SocketAddress& client,
234 size_t len,
235 bool truncated) noexcept override {
236 VLOG(4) << "Read " << len << " bytes (trun:" << truncated << ") from "
237 << client.describe() << " - " << std::string(buf_, len);
238 VLOG(4) << n_ << " left";
239
240 ++pongRecvd_;
241
242 sendPing();
243 }
244
245 void onReadError(const folly::AsyncSocketException& ex) noexcept override {
246 VLOG(4) << ex.what();
247
248 // Start listening for next PONG
249 socket_->resumeRead(this);
250 }
251
252 void onReadClosed() noexcept override {
253 CHECK(false) << "We unregister reads before closing";
254 }
255
256 void timeoutExpired() noexcept override {
257 VLOG(4) << "Timeout expired";
258 sendPing();
259 }
260
261 int pongRecvd() const {
262 return pongRecvd_;
263 }
264
265 AsyncUDPSocket& getSocket() {
266 return *socket_;
267 }
268
269 void setShouldConnect(const folly::SocketAddress& connectAddr) {
270 connectAddr_ = connectAddr;
271 }
272
273 protected:
274 folly::Optional<folly::SocketAddress> connectAddr_;
275 EventBase* const evb_{nullptr};
276
277 folly::SocketAddress server_;
278 std::unique_ptr<AsyncUDPSocket> socket_;
279
280 private:
281 int pongRecvd_{0};
282
283 int n_{0};
284 char buf_[1024];
285};
286
287class ConnectedWriteUDPClient : public UDPClient {
288 public:
289 ~ConnectedWriteUDPClient() override = default;
290
291 ConnectedWriteUDPClient(EventBase* evb) : UDPClient(evb) {}
292
293 // When the socket is connected you don't need to supply the address to send
294 // msg. This will test that connect worked.
295 void writePing(std::unique_ptr<folly::IOBuf> buf) override {
296 iovec vec[16];
297 size_t iovec_len =
298 buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])).numIovecs;
299 if (UNLIKELY(iovec_len == 0)) {
300 buf->coalesce();
301 vec[0].iov_base = const_cast<uint8_t*>(buf->data());
302 vec[0].iov_len = buf->length();
303 iovec_len = 1;
304 }
305
306 struct msghdr msg;
307 msg.msg_name = nullptr;
308 msg.msg_namelen = 0;
309 msg.msg_iov = const_cast<struct iovec*>(vec);
310 msg.msg_iovlen = iovec_len;
311 msg.msg_control = nullptr;
312 msg.msg_controllen = 0;
313 msg.msg_flags = 0;
314
315 ssize_t ret = ::sendmsg(socket_->getFD(), &msg, 0);
316 if (ret == -1) {
317 if (errno != EAGAIN || errno != EWOULDBLOCK) {
318 throw folly::AsyncSocketException(
319 folly::AsyncSocketException::NOT_OPEN, "WriteFail", errno);
320 }
321 }
322 connect();
323 }
324};
325
326class AsyncSocketIntegrationTest : public Test {
327 public:
328 void SetUp() override {
329 server = std::make_unique<UDPServer>(
330 &sevb, folly::SocketAddress("127.0.0.1", 0), 4);
331
332 // Start event loop in a separate thread
333 serverThread =
334 std::make_unique<std::thread>([this]() { sevb.loopForever(); });
335
336 // Wait for event loop to start
337 sevb.waitUntilRunning();
338 }
339
340 void startServer() {
341 // Start the server
342 sevb.runInEventBaseThreadAndWait([&]() { server->start(); });
343 LOG(INFO) << "Server listening=" << server->address();
344 }
345
346 void TearDown() override {
347 // Shutdown server
348 sevb.runInEventBaseThread([&]() {
349 server->shutdown();
350 sevb.terminateLoopSoon();
351 });
352
353 // Wait for server thread to join
354 serverThread->join();
355 }
356
357 std::unique_ptr<UDPClient> performPingPongTest(
358 folly::Optional<folly::SocketAddress> connectedAddress,
359 bool useConnectedWrite);
360
361 folly::EventBase sevb;
362 folly::EventBase cevb;
363 std::unique_ptr<std::thread> serverThread;
364 std::unique_ptr<UDPServer> server;
365 std::unique_ptr<UDPClient> client;
366};
367
368std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
369 folly::Optional<folly::SocketAddress> connectedAddress,
370 bool useConnectedWrite) {
371 if (useConnectedWrite) {
372 CHECK(connectedAddress.hasValue());
373 client = std::make_unique<ConnectedWriteUDPClient>(&cevb);
374 client->setShouldConnect(*connectedAddress);
375 } else {
376 client = std::make_unique<UDPClient>(&cevb);
377 if (connectedAddress) {
378 client->setShouldConnect(*connectedAddress);
379 }
380 }
381 // Start event loop in a separate thread
382 auto clientThread = std::thread([this]() { cevb.loopForever(); });
383
384 // Wait for event loop to start
385 cevb.waitUntilRunning();
386
387 // Send ping
388 cevb.runInEventBaseThread([&]() { client->start(server->address(), 100); });
389
390 // Wait for client to finish
391 clientThread.join();
392 return std::move(client);
393}
394
395TEST_F(AsyncSocketIntegrationTest, PingPong) {
396 startServer();
397 auto pingClient = performPingPongTest(folly::none, false);
398 // This should succeed.
399 ASSERT_GT(pingClient->pongRecvd(), 0);
400}
401
402TEST_F(AsyncSocketIntegrationTest, ConnectedPingPong) {
403 server->setChangePortForWrites(false);
404 startServer();
405 auto pingClient = performPingPongTest(server->address(), false);
406 // This should succeed
407 ASSERT_GT(pingClient->pongRecvd(), 0);
408}
409
410TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongServerWrongAddress) {
411 server->setChangePortForWrites(true);
412 startServer();
413 auto pingClient = performPingPongTest(server->address(), false);
414 // This should fail.
415 ASSERT_EQ(pingClient->pongRecvd(), 0);
416}
417
418TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongClientWrongAddress) {
419 server->setChangePortForWrites(false);
420 startServer();
421 folly::SocketAddress connectAddr(
422 server->address().getIPAddress(), server->address().getPort() + 1);
423 auto pingClient = performPingPongTest(connectAddr, false);
424 // This should fail.
425 ASSERT_EQ(pingClient->pongRecvd(), 0);
426}
427
428TEST_F(AsyncSocketIntegrationTest, PingPongUseConnectedSendMsg) {
429 server->setChangePortForWrites(false);
430 startServer();
431 auto pingClient = performPingPongTest(server->address(), true);
432 // This should succeed.
433 ASSERT_GT(pingClient->pongRecvd(), 0);
434}
435
436TEST_F(AsyncSocketIntegrationTest, PingPongUseConnectedSendMsgServerWrongAddr) {
437 server->setChangePortForWrites(true);
438 startServer();
439 auto pingClient = performPingPongTest(server->address(), true);
440 // This should fail.
441 ASSERT_EQ(pingClient->pongRecvd(), 0);
442}
443
444TEST_F(AsyncSocketIntegrationTest, PingPongPauseResumeListening) {
445 startServer();
446
447 // Exchange should not happen when paused.
448 server->pauseAccepting();
449 auto pausedClient = performPingPongTest(folly::none, false);
450 ASSERT_EQ(pausedClient->pongRecvd(), 0);
451
452 // Exchange does occur after resuming.
453 server->resumeAccepting();
454 auto pingClient = performPingPongTest(folly::none, false);
455 ASSERT_GT(pingClient->pongRecvd(), 0);
456}
457
458class TestAsyncUDPSocket : public AsyncUDPSocket {
459 public:
460 explicit TestAsyncUDPSocket(EventBase* evb) : AsyncUDPSocket(evb) {}
461
462 MOCK_METHOD3(
463 sendmsg,
464 ssize_t(folly::NetworkSocket, const struct msghdr*, int));
465};
466
467class MockErrMessageCallback : public AsyncUDPSocket::ErrMessageCallback {
468 public:
469 ~MockErrMessageCallback() override = default;
470
471 MOCK_METHOD1(errMessage_, void(const cmsghdr&));
472 void errMessage(const cmsghdr& cmsg) noexcept override {
473 errMessage_(cmsg);
474 }
475
476 MOCK_METHOD1(errMessageError_, void(const folly::AsyncSocketException&));
477 void errMessageError(
478 const folly::AsyncSocketException& ex) noexcept override {
479 errMessageError_(ex);
480 }
481};
482
483class MockUDPReadCallback : public AsyncUDPSocket::ReadCallback {
484 public:
485 ~MockUDPReadCallback() override = default;
486
487 MOCK_METHOD2(getReadBuffer_, void(void**, size_t*));
488 void getReadBuffer(void** buf, size_t* len) noexcept override {
489 getReadBuffer_(buf, len);
490 }
491
492 MOCK_METHOD3(
493 onDataAvailable_,
494 void(const folly::SocketAddress&, size_t, bool));
495 void onDataAvailable(
496 const folly::SocketAddress& client,
497 size_t len,
498 bool truncated) noexcept override {
499 onDataAvailable_(client, len, truncated);
500 }
501
502 MOCK_METHOD1(onReadError_, void(const folly::AsyncSocketException&));
503 void onReadError(const folly::AsyncSocketException& ex) noexcept override {
504 onReadError_(ex);
505 }
506
507 MOCK_METHOD0(onReadClosed_, void());
508 void onReadClosed() noexcept override {
509 onReadClosed_();
510 }
511};
512
513class AsyncUDPSocketTest : public Test {
514 public:
515 void SetUp() override {
516 socket_ = std::make_shared<AsyncUDPSocket>(&evb_);
517 addr_ = folly::SocketAddress("127.0.0.1", 0);
518 socket_->bind(addr_);
519 }
520
521 EventBase evb_;
522 MockErrMessageCallback err;
523 MockUDPReadCallback readCb;
524 std::shared_ptr<AsyncUDPSocket> socket_;
525 folly::SocketAddress addr_;
526};
527
528TEST_F(AsyncUDPSocketTest, TestConnect) {
529 EXPECT_EQ(socket_->connect(addr_), 0);
530}
531
532TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) {
533 socket_->resumeRead(&readCb);
534 socket_->setErrMessageCallback(&err);
535 folly::SocketAddress addr("127.0.0.1", 10000);
536 bool errRecvd = false;
537#ifdef FOLLY_HAVE_MSG_ERRQUEUE
538 EXPECT_CALL(err, errMessage_(_))
539 .WillOnce(Invoke([this, &errRecvd](auto& cmsg) {
540 if ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
541 (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
542 const struct sock_extended_err* serr =
543 reinterpret_cast<const struct sock_extended_err*>(
544 CMSG_DATA(&cmsg));
545 errRecvd =
546 (serr->ee_origin == SO_EE_ORIGIN_ICMP || SO_EE_ORIGIN_ICMP6);
547 LOG(ERROR) << "errno " << errnoStr(serr->ee_errno);
548 }
549 evb_.terminateLoopSoon();
550 }));
551#endif // FOLLY_HAVE_MSG_ERRQUEUE
552 socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
553 evb_.loopForever();
554 EXPECT_TRUE(errRecvd);
555}
556
557TEST_F(AsyncUDPSocketTest, TestUnsetErrCallback) {
558 socket_->resumeRead(&readCb);
559 socket_->setErrMessageCallback(&err);
560 socket_->setErrMessageCallback(nullptr);
561 folly::SocketAddress addr("127.0.0.1", 10000);
562 EXPECT_CALL(err, errMessage_(_)).Times(0);
563 socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
564 evb_.timer().scheduleTimeoutFn(
565 [&] { evb_.terminateLoopSoon(); }, std::chrono::milliseconds(30));
566 evb_.loopForever();
567}
568
569TEST_F(AsyncUDPSocketTest, CloseInErrorCallback) {
570 socket_->resumeRead(&readCb);
571 socket_->setErrMessageCallback(&err);
572 folly::SocketAddress addr("127.0.0.1", 10000);
573 bool errRecvd = false;
574 EXPECT_CALL(err, errMessage_(_)).WillOnce(Invoke([this, &errRecvd](auto&) {
575 errRecvd = true;
576 socket_->close();
577 evb_.terminateLoopSoon();
578 }));
579 socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
580 socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
581 evb_.loopForever();
582 EXPECT_TRUE(errRecvd);
583}
584
585TEST_F(AsyncUDPSocketTest, TestNonExistentServerNoErrCb) {
586 socket_->resumeRead(&readCb);
587 folly::SocketAddress addr("127.0.0.1", 10000);
588 bool errRecvd = false;
589 folly::IOBufQueue readBuf;
590 EXPECT_CALL(readCb, getReadBuffer_(_, _))
591 .WillRepeatedly(Invoke([&readBuf](void** buf, size_t* len) {
592 auto readSpace = readBuf.preallocate(2000, 10000);
593 *buf = readSpace.first;
594 *len = readSpace.second;
595 }));
596 ON_CALL(readCb, onReadError_(_)).WillByDefault(Invoke([&errRecvd](auto& ex) {
597 LOG(ERROR) << ex.what();
598 errRecvd = true;
599 }));
600 socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
601 evb_.timer().scheduleTimeoutFn(
602 [&] { evb_.terminateLoopSoon(); }, std::chrono::milliseconds(30));
603 evb_.loopForever();
604 EXPECT_FALSE(errRecvd);
605}
606
607TEST_F(AsyncUDPSocketTest, TestBound) {
608 AsyncUDPSocket socket(&evb_);
609 EXPECT_FALSE(socket.isBound());
610 folly::SocketAddress address("0.0.0.0", 0);
611 socket.bind(address);
612 EXPECT_TRUE(socket.isBound());
613}
614
615TEST_F(AsyncUDPSocketTest, TestAttachAfterDetachEvbWithReadCallback) {
616 socket_->resumeRead(&readCb);
617 EXPECT_TRUE(socket_->isHandlerRegistered());
618 socket_->detachEventBase();
619 EXPECT_FALSE(socket_->isHandlerRegistered());
620 socket_->attachEventBase(&evb_);
621 EXPECT_TRUE(socket_->isHandlerRegistered());
622}
623
624TEST_F(AsyncUDPSocketTest, TestAttachAfterDetachEvbNoReadCallback) {
625 EXPECT_FALSE(socket_->isHandlerRegistered());
626 socket_->detachEventBase();
627 EXPECT_FALSE(socket_->isHandlerRegistered());
628 socket_->attachEventBase(&evb_);
629 EXPECT_FALSE(socket_->isHandlerRegistered());
630}
631
632TEST_F(AsyncUDPSocketTest, TestDetachAttach) {
633 folly::EventBase evb2;
634 auto writeSocket = std::make_shared<folly::AsyncUDPSocket>(&evb_);
635 folly::SocketAddress address("127.0.0.1", 0);
636 writeSocket->bind(address);
637 std::array<uint8_t, 1024> data;
638 std::atomic<int> packetsRecvd{0};
639 EXPECT_CALL(readCb, getReadBuffer_(_, _))
640 .WillRepeatedly(Invoke([&](void** buf, size_t* len) {
641 *buf = data.data();
642 *len = 1024;
643 }));
644 EXPECT_CALL(readCb, onDataAvailable_(_, _, _))
645 .WillRepeatedly(Invoke(
646 [&](const folly::SocketAddress&, size_t, bool) { packetsRecvd++; }));
647 socket_->resumeRead(&readCb);
648 writeSocket->write(socket_->address(), folly::IOBuf::copyBuffer("hello"));
649 while (packetsRecvd != 1) {
650 evb_.loopOnce();
651 }
652 EXPECT_EQ(packetsRecvd, 1);
653
654 socket_->detachEventBase();
655 std::thread t([&] { evb2.loopForever(); });
656 evb2.runInEventBaseThreadAndWait([&] { socket_->attachEventBase(&evb2); });
657 writeSocket->write(socket_->address(), folly::IOBuf::copyBuffer("hello"));
658 auto now = std::chrono::steady_clock::now();
659 while (packetsRecvd != 2 ||
660 std::chrono::steady_clock::now() <
661 now + std::chrono::milliseconds(10)) {
662 std::this_thread::sleep_for(std::chrono::milliseconds(1));
663 }
664 evb2.runInEventBaseThread([&] {
665 socket_ = nullptr;
666 evb2.terminateLoopSoon();
667 });
668 t.join();
669 EXPECT_EQ(packetsRecvd, 2);
670}
671