1/*=========================================================================*\
2* Socket compatibilization module for Unix
3* LuaSocket toolkit
4*
5* The code is now interrupt-safe.
* The penalty of calling select to avoid busy-wait is only paid when
* the I/O call fails in the first place.
8\*=========================================================================*/
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <string.h>

#include "socket.h"
#include "pierror.h"
14
15/*-------------------------------------------------------------------------*\
16* Wait for readable/writable/connected socket with timeout
17\*-------------------------------------------------------------------------*/
#ifndef SOCKET_SELECT
/* Default build: socket_waitfd is implemented with poll(2). Defining
 * SOCKET_SELECT selects the select(2)-based variant in the #else branch. */
#include <sys/poll.h>

/* Wait modes for the `sw` argument of socket_waitfd. In this branch they
 * are poll(2) event bits and are stored directly into pfd.events. */
#define WAITFD_R POLLIN
#define WAITFD_W POLLOUT
/* Used by socket_connect while a non-blocking connect is pending. */
#define WAITFD_C (POLLIN|POLLOUT)
24int socket_waitfd(p_socket ps, int sw, p_timeout tm) {
25 int ret;
26 struct pollfd pfd;
27 pfd.fd = *ps;
28 pfd.events = sw;
29 pfd.revents = 0;
30 if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */
31 do {
32 int t = (int)(timeout_getretry(tm)*1e3);
33 ret = poll(&pfd, 1, t >= 0? t: -1);
34 } while (ret == -1 && errno == EINTR);
35 if (ret == -1) return errno;
36 if (ret == 0) return IO_TIMEOUT;
37 if (sw == WAITFD_C && (pfd.revents & (POLLIN|POLLERR))) return IO_CLOSED;
38 return IO_DONE;
39}
#else
/* select(2)-based variant, compiled when SOCKET_SELECT is defined. */

/* Wait modes for the `sw` argument of socket_waitfd. In this branch they
 * are private bit flags tested with `sw & WAITFD_R` / `sw & WAITFD_W`. */
#define WAITFD_R 1
#define WAITFD_W 2
#define WAITFD_C (WAITFD_R|WAITFD_W)

46int socket_waitfd(p_socket ps, int sw, p_timeout tm) {
47 int ret;
48 fd_set rfds, wfds, *rp, *wp;
49 struct timeval tv, *tp;
50 double t;
51 if (*ps >= FD_SETSIZE) return EINVAL;
52 if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */
53 do {
54 /* must set bits within loop, because select may have modifed them */
55 rp = wp = NULL;
56 if (sw & WAITFD_R) { FD_ZERO(&rfds); FD_SET(*ps, &rfds); rp = &rfds; }
57 if (sw & WAITFD_W) { FD_ZERO(&wfds); FD_SET(*ps, &wfds); wp = &wfds; }
58 t = timeout_getretry(tm);
59 tp = NULL;
60 if (t >= 0.0) {
61 tv.tv_sec = (int)t;
62 tv.tv_usec = (int)((t-tv.tv_sec)*1.0e6);
63 tp = &tv;
64 }
65 ret = select(*ps+1, rp, wp, NULL, tp);
66 } while (ret == -1 && errno == EINTR);
67 if (ret == -1) return errno;
68 if (ret == 0) return IO_TIMEOUT;
69 if (sw == WAITFD_C && FD_ISSET(*ps, &rfds)) return IO_CLOSED;
70 return IO_DONE;
71}
#endif /* SOCKET_SELECT */
73
74
75/*-------------------------------------------------------------------------*\
76* Initializes module
77\*-------------------------------------------------------------------------*/
int socket_open(void) {
    /* Ignore SIGPIPE process-wide. Its default action would terminate us
     * when writing to a connection the peer has closed; with it ignored,
     * send()/write() fail with EPIPE instead. */
    (void) signal(SIGPIPE, SIG_IGN);
    return 1;
}
83
84/*-------------------------------------------------------------------------*\
85* Close module
86\*-------------------------------------------------------------------------*/
int socket_close(void) {
    /* Nothing to release at module shutdown on Unix; always succeeds. */
    const int ok = 1;
    return ok;
}
90
91/*-------------------------------------------------------------------------*\
* Closes the socket and marks it invalid
93\*-------------------------------------------------------------------------*/
94void socket_destroy(p_socket ps) {
95 if (*ps != SOCKET_INVALID) {
96 close(*ps);
97 *ps = SOCKET_INVALID;
98 }
99}
100
101/*-------------------------------------------------------------------------*\
102* Select with timeout control
103\*-------------------------------------------------------------------------*/
104int socket_select(t_socket n, fd_set *rfds, fd_set *wfds, fd_set *efds,
105 p_timeout tm) {
106 int ret;
107 do {
108 struct timeval tv;
109 double t = timeout_getretry(tm);
110 tv.tv_sec = (int) t;
111 tv.tv_usec = (int) ((t - tv.tv_sec) * 1.0e6);
112 /* timeout = 0 means no wait */
113 ret = select(n, rfds, wfds, efds, t >= 0.0 ? &tv: NULL);
114 } while (ret < 0 && errno == EINTR);
115 return ret;
116}
117
118/*-------------------------------------------------------------------------*\
119* Creates and sets up a socket
120\*-------------------------------------------------------------------------*/
121int socket_create(p_socket ps, int domain, int type, int protocol) {
122 *ps = socket(domain, type, protocol);
123 if (*ps != SOCKET_INVALID) return IO_DONE;
124 else return errno;
125}
126
127/*-------------------------------------------------------------------------*\
128* Binds or returns error message
129\*-------------------------------------------------------------------------*/
130int socket_bind(p_socket ps, SA *addr, socklen_t len) {
131 int err = IO_DONE;
132 socket_setblocking(ps);
133 if (bind(*ps, addr, len) < 0) err = errno;
134 socket_setnonblocking(ps);
135 return err;
136}
137
138/*-------------------------------------------------------------------------*\
* Puts socket into listening state or returns error code
140\*-------------------------------------------------------------------------*/
141int socket_listen(p_socket ps, int backlog) {
142 int err = IO_DONE;
143 if (listen(*ps, backlog)) err = errno;
144 return err;
145}
146
147/*-------------------------------------------------------------------------*\
* Shuts down the socket in the direction(s) selected by 'how'
149\*-------------------------------------------------------------------------*/
150void socket_shutdown(p_socket ps, int how) {
151 shutdown(*ps, how);
152}
153
154/*-------------------------------------------------------------------------*\
155* Connects or returns error message
156\*-------------------------------------------------------------------------*/
157int socket_connect(p_socket ps, SA *addr, socklen_t len, p_timeout tm) {
158 int err;
159 /* avoid calling on closed sockets */
160 if (*ps == SOCKET_INVALID) return IO_CLOSED;
161 /* call connect until done or failed without being interrupted */
162 do if (connect(*ps, addr, len) == 0) return IO_DONE;
163 while ((err = errno) == EINTR);
164 /* if connection failed immediately, return error code */
165 if (err != EINPROGRESS && err != EAGAIN) return err;
166 /* zero timeout case optimization */
167 if (timeout_iszero(tm)) return IO_TIMEOUT;
168 /* wait until we have the result of the connection attempt or timeout */
169 err = socket_waitfd(ps, WAITFD_C, tm);
170 if (err == IO_CLOSED) {
171 if (recv(*ps, (char *) &err, 0, 0) == 0) return IO_DONE;
172 else return errno;
173 } else return err;
174}
175
176/*-------------------------------------------------------------------------*\
177* Accept with timeout
178\*-------------------------------------------------------------------------*/
179int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, p_timeout tm) {
180 if (*ps == SOCKET_INVALID) return IO_CLOSED;
181 for ( ;; ) {
182 int err;
183 if ((*pa = accept(*ps, addr, len)) != SOCKET_INVALID) return IO_DONE;
184 err = errno;
185 if (err == EINTR) continue;
186 if (err != EAGAIN && err != ECONNABORTED) return err;
187 if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err;
188 }
189 /* can't reach here */
190 return IO_UNKNOWN;
191}
192
193/*-------------------------------------------------------------------------*\
194* Send with timeout
195\*-------------------------------------------------------------------------*/
196int socket_send(p_socket ps, const char *data, size_t count,
197 size_t *sent, p_timeout tm)
198{
199 int err;
200 *sent = 0;
201 /* avoid making system calls on closed sockets */
202 if (*ps == SOCKET_INVALID) return IO_CLOSED;
203 /* loop until we send something or we give up on error */
204 for ( ;; ) {
205 long put = (long) send(*ps, data, count, 0);
206 /* if we sent anything, we are done */
207 if (put >= 0) {
208 *sent = put;
209 return IO_DONE;
210 }
211 err = errno;
212 /* EPIPE means the connection was closed */
213 if (err == EPIPE) return IO_CLOSED;
214 /* EPROTOTYPE means the connection is being closed (on Yosemite!)*/
215 if (err == EPROTOTYPE) continue;
216 /* we call was interrupted, just try again */
217 if (err == EINTR) continue;
218 /* if failed fatal reason, report error */
219 if (err != EAGAIN) return err;
220 /* wait until we can send something or we timeout */
221 if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err;
222 }
223 /* can't reach here */
224 return IO_UNKNOWN;
225}
226
227/*-------------------------------------------------------------------------*\
228* Sendto with timeout
229\*-------------------------------------------------------------------------*/
230int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent,
231 SA *addr, socklen_t len, p_timeout tm)
232{
233 int err;
234 *sent = 0;
235 if (*ps == SOCKET_INVALID) return IO_CLOSED;
236 for ( ;; ) {
237 long put = (long) sendto(*ps, data, count, 0, addr, len);
238 if (put >= 0) {
239 *sent = put;
240 return IO_DONE;
241 }
242 err = errno;
243 if (err == EPIPE) return IO_CLOSED;
244 if (err == EPROTOTYPE) continue;
245 if (err == EINTR) continue;
246 if (err != EAGAIN) return err;
247 if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err;
248 }
249 return IO_UNKNOWN;
250}
251
252/*-------------------------------------------------------------------------*\
253* Receive with timeout
254\*-------------------------------------------------------------------------*/
255int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) {
256 int err;
257 *got = 0;
258 if (*ps == SOCKET_INVALID) return IO_CLOSED;
259 for ( ;; ) {
260 long taken = (long) recv(*ps, data, count, 0);
261 if (taken > 0) {
262 *got = taken;
263 return IO_DONE;
264 }
265 err = errno;
266 if (taken == 0) return IO_CLOSED;
267 if (err == EINTR) continue;
268 if (err != EAGAIN) return err;
269 if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err;
270 }
271 return IO_UNKNOWN;
272}
273
274/*-------------------------------------------------------------------------*\
275* Recvfrom with timeout
276\*-------------------------------------------------------------------------*/
277int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got,
278 SA *addr, socklen_t *len, p_timeout tm) {
279 int err;
280 *got = 0;
281 if (*ps == SOCKET_INVALID) return IO_CLOSED;
282 for ( ;; ) {
283 long taken = (long) recvfrom(*ps, data, count, 0, addr, len);
284 if (taken > 0) {
285 *got = taken;
286 return IO_DONE;
287 }
288 err = errno;
289 if (taken == 0) return IO_CLOSED;
290 if (err == EINTR) continue;
291 if (err != EAGAIN) return err;
292 if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err;
293 }
294 return IO_UNKNOWN;
295}
296
297
298/*-------------------------------------------------------------------------*\
299* Write with timeout
300*
* socket_read and socket_write are cut-n-paste copies of socket_recv and
* socket_send, with recv/send replaced by read/write. We can't just use
* read/write in the socket versions, because their behaviour when size is
* zero is different.
304\*-------------------------------------------------------------------------*/
305int socket_write(p_socket ps, const char *data, size_t count,
306 size_t *sent, p_timeout tm)
307{
308 int err;
309 *sent = 0;
310 /* avoid making system calls on closed sockets */
311 if (*ps == SOCKET_INVALID) return IO_CLOSED;
312 /* loop until we send something or we give up on error */
313 for ( ;; ) {
314 long put = (long) write(*ps, data, count);
315 /* if we sent anything, we are done */
316 if (put >= 0) {
317 *sent = put;
318 return IO_DONE;
319 }
320 err = errno;
321 /* EPIPE means the connection was closed */
322 if (err == EPIPE) return IO_CLOSED;
323 /* EPROTOTYPE means the connection is being closed (on Yosemite!)*/
324 if (err == EPROTOTYPE) continue;
325 /* we call was interrupted, just try again */
326 if (err == EINTR) continue;
327 /* if failed fatal reason, report error */
328 if (err != EAGAIN) return err;
329 /* wait until we can send something or we timeout */
330 if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err;
331 }
332 /* can't reach here */
333 return IO_UNKNOWN;
334}
335
336/*-------------------------------------------------------------------------*\
337* Read with timeout
338* See note for socket_write
339\*-------------------------------------------------------------------------*/
340int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) {
341 int err;
342 *got = 0;
343 if (*ps == SOCKET_INVALID) return IO_CLOSED;
344 for ( ;; ) {
345 long taken = (long) read(*ps, data, count);
346 if (taken > 0) {
347 *got = taken;
348 return IO_DONE;
349 }
350 err = errno;
351 if (taken == 0) return IO_CLOSED;
352 if (err == EINTR) continue;
353 if (err != EAGAIN) return err;
354 if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err;
355 }
356 return IO_UNKNOWN;
357}
358
359/*-------------------------------------------------------------------------*\
360* Put socket into blocking mode
361\*-------------------------------------------------------------------------*/
362void socket_setblocking(p_socket ps) {
363 int flags = fcntl(*ps, F_GETFL, 0);
364 flags &= (~(O_NONBLOCK));
365 fcntl(*ps, F_SETFL, flags);
366}
367
368/*-------------------------------------------------------------------------*\
369* Put socket into non-blocking mode
370\*-------------------------------------------------------------------------*/
371void socket_setnonblocking(p_socket ps) {
372 int flags = fcntl(*ps, F_GETFL, 0);
373 flags |= O_NONBLOCK;
374 fcntl(*ps, F_SETFL, flags);
375}
376
377/*-------------------------------------------------------------------------*\
378* DNS helpers
379\*-------------------------------------------------------------------------*/
380int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) {
381 *hp = gethostbyaddr(addr, len, AF_INET);
382 if (*hp) return IO_DONE;
383 else if (h_errno) return h_errno;
384 else if (errno) return errno;
385 else return IO_UNKNOWN;
386}
387
388int socket_gethostbyname(const char *addr, struct hostent **hp) {
389 *hp = gethostbyname(addr);
390 if (*hp) return IO_DONE;
391 else if (h_errno) return h_errno;
392 else if (errno) return errno;
393 else return IO_UNKNOWN;
394}
395
396/*-------------------------------------------------------------------------*\
397* Error translation functions
398* Make sure important error messages are standard
399\*-------------------------------------------------------------------------*/
400const char *socket_hoststrerror(int err) {
401 if (err <= 0) return io_strerror(err);
402 switch (err) {
403 case HOST_NOT_FOUND: return PIE_HOST_NOT_FOUND;
404 default: return hstrerror(err);
405 }
406}
407
/* Describes an error code. Non-positive codes go to io_strerror; a few
 * common errno values get standardized PIE_* messages, and any other errno
 * falls back to strerror(). */
const char *socket_strerror(int err) {
    if (err <= 0) return io_strerror(err);
    switch (err) {
        case EACCES:       return PIE_ACCESS;
        case EADDRINUSE:   return PIE_ADDRINUSE;
        case ECONNABORTED: return PIE_CONNABORTED;
        case ECONNREFUSED: return PIE_CONNREFUSED;
        case ECONNRESET:   return PIE_CONNRESET;
        case EISCONN:      return PIE_ISCONN;
        case ETIMEDOUT:    return PIE_TIMEDOUT;
        default:           break;
    }
    return strerror(err);
}
423
424const char *socket_ioerror(p_socket ps, int err) {
425 (void) ps;
426 return socket_strerror(err);
427}
428
/* Describes a getaddrinfo()/getnameinfo() result code. Returns NULL for 0
 * (success), standardized PIE_* messages for the common EAI_* codes, the
 * errno message for EAI_SYSTEM, and gai_strerror() for anything else.
 * EAI_BADHINTS and EAI_PROTOCOL are only handled where the platform defines
 * them. */
const char *socket_gaistrerror(int err) {
    switch (err) {
        case 0:            return NULL;
        case EAI_AGAIN:    return PIE_AGAIN;
        case EAI_BADFLAGS: return PIE_BADFLAGS;
#ifdef EAI_BADHINTS
        case EAI_BADHINTS: return PIE_BADHINTS;
#endif
        case EAI_FAIL:     return PIE_FAIL;
        case EAI_FAMILY:   return PIE_FAMILY;
        case EAI_MEMORY:   return PIE_MEMORY;
        case EAI_NONAME:   return PIE_NONAME;
        case EAI_OVERFLOW: return PIE_OVERFLOW;
#ifdef EAI_PROTOCOL
        case EAI_PROTOCOL: return PIE_PROTOCOL;
#endif
        case EAI_SERVICE:  return PIE_SERVICE;
        case EAI_SOCKTYPE: return PIE_SOCKTYPE;
        case EAI_SYSTEM:   return strerror(errno);
        default:           break;
    }
    return gai_strerror(err);
}
451
452