1/*
2 * Copyright 2013-present Facebook, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <folly/net/NetOps.h>
18
#include <errno.h>
#include <fcntl.h>

#include <cstddef>
#include <cstring>
23
24#include <folly/Portability.h>
25#include <folly/net/detail/SocketFileDescriptorMap.h>
26
27#if _WIN32
28#include <event2/util.h> // @manual
29
30#include <MSWSock.h> // @manual
31
32#include <folly/ScopeGuard.h>
33#endif
34
35namespace folly {
36namespace netops {
37
38namespace {
39#if _WIN32
40// WSA has to be explicitly initialized.
41static struct WinSockInit {
42 WinSockInit() {
43 WSADATA dat;
44 WSAStartup(MAKEWORD(2, 2), &dat);
45 }
46 ~WinSockInit() {
47 WSACleanup();
48 }
49} winsockInit;
50
51int translate_wsa_error(int wsaErr) {
52 switch (wsaErr) {
53 case WSAEWOULDBLOCK:
54 return EAGAIN;
55 default:
56 return wsaErr;
57 }
58}
59#endif
60
61template <class R, class F, class... Args>
62static R wrapSocketFunction(F f, NetworkSocket s, Args... args) {
63 R ret = f(s.data, args...);
64#if _WIN32
65 errno = translate_wsa_error(WSAGetLastError());
66#endif
67 return ret;
68}
69} // namespace
70
71NetworkSocket accept(NetworkSocket s, sockaddr* addr, socklen_t* addrlen) {
72 return NetworkSocket(wrapSocketFunction<NetworkSocket::native_handle_type>(
73 ::accept, s, addr, addrlen));
74}
75
76int bind(NetworkSocket s, const sockaddr* name, socklen_t namelen) {
77 if (kIsWindows && name->sa_family == AF_UNIX) {
78 // Windows added support for AF_UNIX sockets, but didn't add
79 // support for autobind sockets, so detect requests for autobind
80 // sockets and treat them as invalid. (otherwise they don't trigger
81 // an error, but also don't actually work)
82 if (name->sa_data[0] == '\0') {
83 errno = EINVAL;
84 return -1;
85 }
86 }
87 return wrapSocketFunction<int>(::bind, s, name, namelen);
88}
89
90int close(NetworkSocket s) {
91 return netops::detail::SocketFileDescriptorMap::close(s.data);
92}
93
94int connect(NetworkSocket s, const sockaddr* name, socklen_t namelen) {
95 auto r = wrapSocketFunction<int>(::connect, s, name, namelen);
96#if _WIN32
97 if (r == -1 && WSAGetLastError() == WSAEWOULDBLOCK) {
98 errno = EINPROGRESS;
99 }
100#endif
101 return r;
102}
103
104int getpeername(NetworkSocket s, sockaddr* name, socklen_t* namelen) {
105 return wrapSocketFunction<int>(::getpeername, s, name, namelen);
106}
107
108int getsockname(NetworkSocket s, sockaddr* name, socklen_t* namelen) {
109 return wrapSocketFunction<int>(::getsockname, s, name, namelen);
110}
111
112int getsockopt(
113 NetworkSocket s,
114 int level,
115 int optname,
116 void* optval,
117 socklen_t* optlen) {
118 auto ret = wrapSocketFunction<int>(
119 ::getsockopt, s, level, optname, (char*)optval, optlen);
120#if _WIN32
121 if (optname == TCP_NODELAY && *optlen == 1) {
122 // Windows is weird about this value, and documents it as a
123 // BOOL (ie. int) but expects the variable to be bool (1-byte),
124 // so we get to adapt the interface to work that way.
125 *(int*)optval = *(uint8_t*)optval;
126 *optlen = sizeof(int);
127 }
128#endif
129 return ret;
130}
131
// Parses the IPv4 address string `cp` into `inp` using inet_addr.
//
// Returns 1 on success and 0 if the string is not a valid address. As
// before, `inp->s_addr` is always overwritten with inet_addr's result, so it
// holds INADDR_NONE after a failure.
//
// inet_addr signals failure by returning INADDR_NONE, which is also the
// encoding of the valid address 255.255.255.255. The canonical dotted-quad
// spelling of that address is therefore recognised explicitly instead of
// being reported as a parse error. Other spellings of the same address are
// still indistinguishable from an error.
int inet_aton(const char* cp, in_addr* inp) {
  inp->s_addr = inet_addr(cp);
  if (inp->s_addr != INADDR_NONE) {
    return 1;
  }
  return std::strcmp(cp, "255.255.255.255") == 0 ? 1 : 0;
}
136
137int listen(NetworkSocket s, int backlog) {
138 return wrapSocketFunction<int>(::listen, s, backlog);
139}
140
141int poll(PollDescriptor fds[], nfds_t nfds, int timeout) {
142 // Make sure that PollDescriptor is byte-for-byte identical to pollfd,
143 // so we don't need extra allocations just for the safety of this shim.
144 static_assert(
145 alignof(PollDescriptor) == alignof(pollfd),
146 "PollDescriptor is misaligned");
147 static_assert(
148 sizeof(PollDescriptor) == sizeof(pollfd),
149 "PollDescriptor is the wrong size");
150 static_assert(
151 offsetof(PollDescriptor, fd) == offsetof(pollfd, fd),
152 "PollDescriptor.fd is at the wrong place");
153 static_assert(
154 sizeof(decltype(PollDescriptor().fd)) == sizeof(decltype(pollfd().fd)),
155 "PollDescriptor.fd is the wrong size");
156 static_assert(
157 offsetof(PollDescriptor, events) == offsetof(pollfd, events),
158 "PollDescriptor.events is at the wrong place");
159 static_assert(
160 sizeof(decltype(PollDescriptor().events)) ==
161 sizeof(decltype(pollfd().events)),
162 "PollDescriptor.events is the wrong size");
163 static_assert(
164 offsetof(PollDescriptor, revents) == offsetof(pollfd, revents),
165 "PollDescriptor.revents is at the wrong place");
166 static_assert(
167 sizeof(decltype(PollDescriptor().revents)) ==
168 sizeof(decltype(pollfd().revents)),
169 "PollDescriptor.revents is the wrong size");
170
171 // Pun it through
172 pollfd* files = reinterpret_cast<pollfd*>(reinterpret_cast<void*>(fds));
173#if _WIN32
174 return ::WSAPoll(files, (ULONG)nfds, timeout);
175#else
176 return ::poll(files, nfds, timeout);
177#endif
178}
179
180ssize_t recv(NetworkSocket s, void* buf, size_t len, int flags) {
181#if _WIN32
182 if ((flags & MSG_DONTWAIT) == MSG_DONTWAIT) {
183 flags &= ~MSG_DONTWAIT;
184
185 u_long pendingRead = 0;
186 if (ioctlsocket(s.data, FIONREAD, &pendingRead)) {
187 errno = translate_wsa_error(WSAGetLastError());
188 return -1;
189 }
190
191 fd_set readSet;
192 FD_ZERO(&readSet);
193 FD_SET(s.data, &readSet);
194 timeval timeout{0, 0};
195 auto ret = select(1, &readSet, nullptr, nullptr, &timeout);
196 if (ret == 0) {
197 errno = EWOULDBLOCK;
198 return -1;
199 }
200 }
201 return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, (int)len, flags);
202#else
203 return wrapSocketFunction<ssize_t>(::recv, s, buf, len, flags);
204#endif
205}
206
207ssize_t recvfrom(
208 NetworkSocket s,
209 void* buf,
210 size_t len,
211 int flags,
212 sockaddr* from,
213 socklen_t* fromlen) {
214#if _WIN32
215 if ((flags & MSG_TRUNC) == MSG_TRUNC) {
216 SOCKET h = s.data;
217
218 WSABUF wBuf{};
219 wBuf.buf = (CHAR*)buf;
220 wBuf.len = (ULONG)len;
221 WSAMSG wMsg{};
222 wMsg.dwBufferCount = 1;
223 wMsg.lpBuffers = &wBuf;
224 wMsg.name = from;
225 if (fromlen != nullptr) {
226 wMsg.namelen = *fromlen;
227 }
228
229 // WSARecvMsg is an extension, so we don't get
230 // the convenience of being able to call it directly, even though
231 // WSASendMsg is part of the normal API -_-...
232 LPFN_WSARECVMSG WSARecvMsg;
233 GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
234 DWORD recMsgBytes;
235 WSAIoctl(
236 h,
237 SIO_GET_EXTENSION_FUNCTION_POINTER,
238 &WSARecgMsg_GUID,
239 sizeof(WSARecgMsg_GUID),
240 &WSARecvMsg,
241 sizeof(WSARecvMsg),
242 &recMsgBytes,
243 nullptr,
244 nullptr);
245
246 DWORD bytesReceived;
247 int res = WSARecvMsg(h, &wMsg, &bytesReceived, nullptr, nullptr);
248 errno = translate_wsa_error(WSAGetLastError());
249 if (res == 0) {
250 return bytesReceived;
251 }
252 if (fromlen != nullptr) {
253 *fromlen = wMsg.namelen;
254 }
255 if ((wMsg.dwFlags & MSG_TRUNC) == MSG_TRUNC) {
256 return wBuf.len + 1;
257 }
258 return -1;
259 }
260 return wrapSocketFunction<ssize_t>(
261 ::recvfrom, s, (char*)buf, (int)len, flags, from, fromlen);
262#else
263 return wrapSocketFunction<ssize_t>(
264 ::recvfrom, s, buf, len, flags, from, fromlen);
265#endif
266}
267
268ssize_t recvmsg(NetworkSocket s, msghdr* message, int flags) {
269#if _WIN32
270 (void)flags;
271 SOCKET h = s.data;
272
273 // Don't currently support the name translation.
274 if (message->msg_name != nullptr || message->msg_namelen != 0) {
275 return (ssize_t)-1;
276 }
277 WSAMSG msg;
278 msg.name = nullptr;
279 msg.namelen = 0;
280 msg.Control.buf = (CHAR*)message->msg_control;
281 msg.Control.len = (ULONG)message->msg_controllen;
282 msg.dwFlags = 0;
283 msg.dwBufferCount = (DWORD)message->msg_iovlen;
284 msg.lpBuffers = new WSABUF[message->msg_iovlen];
285 SCOPE_EXIT {
286 delete[] msg.lpBuffers;
287 };
288 for (size_t i = 0; i < message->msg_iovlen; i++) {
289 msg.lpBuffers[i].buf = (CHAR*)message->msg_iov[i].iov_base;
290 msg.lpBuffers[i].len = (ULONG)message->msg_iov[i].iov_len;
291 }
292
293 // WSARecvMsg is an extension, so we don't get
294 // the convenience of being able to call it directly, even though
295 // WSASendMsg is part of the normal API -_-...
296 LPFN_WSARECVMSG WSARecvMsg;
297 GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
298 DWORD recMsgBytes;
299 WSAIoctl(
300 h,
301 SIO_GET_EXTENSION_FUNCTION_POINTER,
302 &WSARecgMsg_GUID,
303 sizeof(WSARecgMsg_GUID),
304 &WSARecvMsg,
305 sizeof(WSARecvMsg),
306 &recMsgBytes,
307 nullptr,
308 nullptr);
309
310 DWORD bytesReceived;
311 int res = WSARecvMsg(h, &msg, &bytesReceived, nullptr, nullptr);
312 errno = translate_wsa_error(WSAGetLastError());
313 return res == 0 ? (ssize_t)bytesReceived : -1;
314#else
315 return wrapSocketFunction<ssize_t>(::recvmsg, s, message, flags);
316#endif
317}
318
319ssize_t send(NetworkSocket s, const void* buf, size_t len, int flags) {
320#if _WIN32
321 return wrapSocketFunction<ssize_t>(
322 ::send, s, (const char*)buf, (int)len, flags);
323#else
324 return wrapSocketFunction<ssize_t>(::send, s, buf, len, flags);
325#endif
326}
327
328ssize_t sendmsg(NetworkSocket socket, const msghdr* message, int flags) {
329#if _WIN32
330 (void)flags;
331 SOCKET h = socket.data;
332
333 // Unfortunately, WSASendMsg requires the socket to have been opened
334 // as either SOCK_DGRAM or SOCK_RAW, but sendmsg has no such requirement,
335 // so we have to implement it based on send instead :(
336 ssize_t bytesSent = 0;
337 for (size_t i = 0; i < message->msg_iovlen; i++) {
338 int r = -1;
339 if (message->msg_name != nullptr) {
340 r = ::sendto(
341 h,
342 (const char*)message->msg_iov[i].iov_base,
343 (int)message->msg_iov[i].iov_len,
344 message->msg_flags,
345 (const sockaddr*)message->msg_name,
346 (int)message->msg_namelen);
347 } else {
348 r = ::send(
349 h,
350 (const char*)message->msg_iov[i].iov_base,
351 (int)message->msg_iov[i].iov_len,
352 message->msg_flags);
353 }
354 if (r == -1 || size_t(r) != message->msg_iov[i].iov_len) {
355 errno = translate_wsa_error(WSAGetLastError());
356 if (WSAGetLastError() == WSAEWOULDBLOCK && bytesSent > 0) {
357 return bytesSent;
358 }
359 return -1;
360 }
361 bytesSent += r;
362 }
363 return bytesSent;
364#else
365 return wrapSocketFunction<ssize_t>(::sendmsg, socket, message, flags);
366#endif
367}
368
369int sendmmsg(
370 NetworkSocket socket,
371 mmsghdr* msgvec,
372 unsigned int vlen,
373 int flags) {
374#if FOLLY_HAVE_SENDMMSG
375 return wrapSocketFunction<int>(::sendmmsg, socket, msgvec, vlen, flags);
376#else
377 // implement via sendmsg
378 for (unsigned int i = 0; i < vlen; i++) {
379 ssize_t ret = sendmsg(socket, &msgvec[i].msg_hdr, flags);
380 // in case of an error
381 // we return the number of msgs sent if > 0
382 // or an error if no msg was sent
383 if (ret < 0) {
384 if (i) {
385 return static_cast<int>(i);
386 }
387
388 return static_cast<int>(ret);
389 }
390 }
391
392 return static_cast<int>(vlen);
393#endif
394}
395
396ssize_t sendto(
397 NetworkSocket s,
398 const void* buf,
399 size_t len,
400 int flags,
401 const sockaddr* to,
402 socklen_t tolen) {
403#if _WIN32
404 return wrapSocketFunction<ssize_t>(
405 ::sendto, s, (const char*)buf, (int)len, flags, to, (int)tolen);
406#else
407 return wrapSocketFunction<ssize_t>(::sendto, s, buf, len, flags, to, tolen);
408#endif
409}
410
411int setsockopt(
412 NetworkSocket s,
413 int level,
414 int optname,
415 const void* optval,
416 socklen_t optlen) {
417#if _WIN32
418 if (optname == SO_REUSEADDR) {
419 // We don't have an equivelent to the Linux & OSX meaning of this
420 // on Windows, so ignore it.
421 return 0;
422 } else if (optname == SO_REUSEPORT) {
423 // Windows's SO_REUSEADDR option is closer to SO_REUSEPORT than
424 // it is to the Linux & OSX meaning of SO_REUSEADDR.
425 return -1;
426 }
427 return wrapSocketFunction<int>(
428 ::setsockopt, s, level, optname, (char*)optval, optlen);
429#else
430 return wrapSocketFunction<int>(
431 ::setsockopt, s, level, optname, optval, optlen);
432#endif
433}
434
435int shutdown(NetworkSocket s, int how) {
436 return wrapSocketFunction<int>(::shutdown, s, how);
437}
438
439NetworkSocket socket(int af, int type, int protocol) {
440 return NetworkSocket(::socket(af, type, protocol));
441}
442
443int socketpair(int domain, int type, int protocol, NetworkSocket sv[2]) {
444#if _WIN32
445 if (domain != PF_UNIX || type != SOCK_STREAM || protocol != 0) {
446 return -1;
447 }
448 intptr_t pair[2];
449 auto r = evutil_socketpair(AF_INET, type, protocol, pair);
450 if (r == -1) {
451 return r;
452 }
453 sv[0] = NetworkSocket(static_cast<SOCKET>(pair[0]));
454 sv[1] = NetworkSocket(static_cast<SOCKET>(pair[1]));
455 return r;
456#else
457 int pair[2];
458 auto r = ::socketpair(domain, type, protocol, pair);
459 if (r == -1) {
460 return r;
461 }
462 sv[0] = NetworkSocket(pair[0]);
463 sv[1] = NetworkSocket(pair[1]);
464 return r;
465#endif
466}
467
468int set_socket_non_blocking(NetworkSocket s) {
469#if _WIN32
470 u_long nonBlockingEnabled = 1;
471 return ioctlsocket(s.data, FIONBIO, &nonBlockingEnabled);
472#else
473 int flags = fcntl(s.data, F_GETFL, 0);
474 if (flags == -1) {
475 return -1;
476 }
477 return fcntl(s.data, F_SETFL, flags | O_NONBLOCK);
478#endif
479}
480
481int set_socket_close_on_exec(NetworkSocket s) {
482#if _WIN32
483 if (SetHandleInformation((HANDLE)s.data, HANDLE_FLAG_INHERIT, 0)) {
484 return 0;
485 }
486 return -1;
487#else
488 return fcntl(s.data, F_SETFD, FD_CLOEXEC);
489#endif
490}
491} // namespace netops
492} // namespace folly
493