1 | /*=========================================================================*\ |
2 | * Socket compatibilization module for Unix |
3 | * LuaSocket toolkit |
4 | * |
5 | * The code is now interrupt-safe. |
6 | * The penalty of calling select to avoid busy-wait is only paid when |
* the I/O call fails in the first place.
8 | \*=========================================================================*/ |
#include <limits.h>
#include <signal.h>
#include <string.h>

#include "socket.h"
#include "pierror.h"
14 | |
15 | /*-------------------------------------------------------------------------*\ |
16 | * Wait for readable/writable/connected socket with timeout |
17 | \*-------------------------------------------------------------------------*/ |
18 | #ifndef SOCKET_SELECT |
19 | #include <sys/poll.h> |
20 | |
21 | #define WAITFD_R POLLIN |
22 | #define WAITFD_W POLLOUT |
23 | #define WAITFD_C (POLLIN|POLLOUT) |
24 | int socket_waitfd(p_socket ps, int sw, p_timeout tm) { |
25 | int ret; |
26 | struct pollfd pfd; |
27 | pfd.fd = *ps; |
28 | pfd.events = sw; |
29 | pfd.revents = 0; |
30 | if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ |
31 | do { |
32 | int t = (int)(timeout_getretry(tm)*1e3); |
33 | ret = poll(&pfd, 1, t >= 0? t: -1); |
34 | } while (ret == -1 && errno == EINTR); |
35 | if (ret == -1) return errno; |
36 | if (ret == 0) return IO_TIMEOUT; |
37 | if (sw == WAITFD_C && (pfd.revents & (POLLIN|POLLERR))) return IO_CLOSED; |
38 | return IO_DONE; |
39 | } |
40 | #else |
41 | |
42 | #define WAITFD_R 1 |
43 | #define WAITFD_W 2 |
44 | #define WAITFD_C (WAITFD_R|WAITFD_W) |
45 | |
46 | int socket_waitfd(p_socket ps, int sw, p_timeout tm) { |
47 | int ret; |
48 | fd_set rfds, wfds, *rp, *wp; |
49 | struct timeval tv, *tp; |
50 | double t; |
51 | if (*ps >= FD_SETSIZE) return EINVAL; |
52 | if (timeout_iszero(tm)) return IO_TIMEOUT; /* optimize timeout == 0 case */ |
53 | do { |
54 | /* must set bits within loop, because select may have modifed them */ |
55 | rp = wp = NULL; |
56 | if (sw & WAITFD_R) { FD_ZERO(&rfds); FD_SET(*ps, &rfds); rp = &rfds; } |
57 | if (sw & WAITFD_W) { FD_ZERO(&wfds); FD_SET(*ps, &wfds); wp = &wfds; } |
58 | t = timeout_getretry(tm); |
59 | tp = NULL; |
60 | if (t >= 0.0) { |
61 | tv.tv_sec = (int)t; |
62 | tv.tv_usec = (int)((t-tv.tv_sec)*1.0e6); |
63 | tp = &tv; |
64 | } |
65 | ret = select(*ps+1, rp, wp, NULL, tp); |
66 | } while (ret == -1 && errno == EINTR); |
67 | if (ret == -1) return errno; |
68 | if (ret == 0) return IO_TIMEOUT; |
69 | if (sw == WAITFD_C && FD_ISSET(*ps, &rfds)) return IO_CLOSED; |
70 | return IO_DONE; |
71 | } |
72 | #endif |
73 | |
74 | |
75 | /*-------------------------------------------------------------------------*\ |
76 | * Initializes module |
77 | \*-------------------------------------------------------------------------*/ |
/*
 * Module start-up hook. Sets the disposition of SIGPIPE to "ignore"
 * and always returns 1.
 */
int socket_open(void) {
    /* SIGPIPE must be ignored, otherwise it would take the process down */
    (void) signal(SIGPIPE, SIG_IGN);
    return 1;
}
83 | |
84 | /*-------------------------------------------------------------------------*\ |
85 | * Close module |
86 | \*-------------------------------------------------------------------------*/ |
/*
 * Module shutdown hook. Performs no work here and always returns 1.
 */
int socket_close(void) {
    /* nothing to undo */
    return 1;
}
90 | |
91 | /*-------------------------------------------------------------------------*\ |
92 | * Close and inutilize socket |
93 | \*-------------------------------------------------------------------------*/ |
94 | void socket_destroy(p_socket ps) { |
95 | if (*ps != SOCKET_INVALID) { |
96 | close(*ps); |
97 | *ps = SOCKET_INVALID; |
98 | } |
99 | } |
100 | |
101 | /*-------------------------------------------------------------------------*\ |
102 | * Select with timeout control |
103 | \*-------------------------------------------------------------------------*/ |
104 | int socket_select(t_socket n, fd_set *rfds, fd_set *wfds, fd_set *efds, |
105 | p_timeout tm) { |
106 | int ret; |
107 | do { |
108 | struct timeval tv; |
109 | double t = timeout_getretry(tm); |
110 | tv.tv_sec = (int) t; |
111 | tv.tv_usec = (int) ((t - tv.tv_sec) * 1.0e6); |
112 | /* timeout = 0 means no wait */ |
113 | ret = select(n, rfds, wfds, efds, t >= 0.0 ? &tv: NULL); |
114 | } while (ret < 0 && errno == EINTR); |
115 | return ret; |
116 | } |
117 | |
118 | /*-------------------------------------------------------------------------*\ |
119 | * Creates and sets up a socket |
120 | \*-------------------------------------------------------------------------*/ |
121 | int socket_create(p_socket ps, int domain, int type, int protocol) { |
122 | *ps = socket(domain, type, protocol); |
123 | if (*ps != SOCKET_INVALID) return IO_DONE; |
124 | else return errno; |
125 | } |
126 | |
127 | /*-------------------------------------------------------------------------*\ |
128 | * Binds or returns error message |
129 | \*-------------------------------------------------------------------------*/ |
130 | int socket_bind(p_socket ps, SA *addr, socklen_t len) { |
131 | int err = IO_DONE; |
132 | socket_setblocking(ps); |
133 | if (bind(*ps, addr, len) < 0) err = errno; |
134 | socket_setnonblocking(ps); |
135 | return err; |
136 | } |
137 | |
138 | /*-------------------------------------------------------------------------*\ |
* Puts the socket in listening mode or returns an error code
140 | \*-------------------------------------------------------------------------*/ |
141 | int socket_listen(p_socket ps, int backlog) { |
142 | int err = IO_DONE; |
143 | if (listen(*ps, backlog)) err = errno; |
144 | return err; |
145 | } |
146 | |
147 | /*-------------------------------------------------------------------------*\ |
* Shuts down part or all of a connection
149 | \*-------------------------------------------------------------------------*/ |
150 | void socket_shutdown(p_socket ps, int how) { |
151 | shutdown(*ps, how); |
152 | } |
153 | |
154 | /*-------------------------------------------------------------------------*\ |
155 | * Connects or returns error message |
156 | \*-------------------------------------------------------------------------*/ |
157 | int socket_connect(p_socket ps, SA *addr, socklen_t len, p_timeout tm) { |
158 | int err; |
159 | /* avoid calling on closed sockets */ |
160 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
161 | /* call connect until done or failed without being interrupted */ |
162 | do if (connect(*ps, addr, len) == 0) return IO_DONE; |
163 | while ((err = errno) == EINTR); |
164 | /* if connection failed immediately, return error code */ |
165 | if (err != EINPROGRESS && err != EAGAIN) return err; |
166 | /* zero timeout case optimization */ |
167 | if (timeout_iszero(tm)) return IO_TIMEOUT; |
168 | /* wait until we have the result of the connection attempt or timeout */ |
169 | err = socket_waitfd(ps, WAITFD_C, tm); |
170 | if (err == IO_CLOSED) { |
171 | if (recv(*ps, (char *) &err, 0, 0) == 0) return IO_DONE; |
172 | else return errno; |
173 | } else return err; |
174 | } |
175 | |
176 | /*-------------------------------------------------------------------------*\ |
177 | * Accept with timeout |
178 | \*-------------------------------------------------------------------------*/ |
179 | int socket_accept(p_socket ps, p_socket pa, SA *addr, socklen_t *len, p_timeout tm) { |
180 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
181 | for ( ;; ) { |
182 | int err; |
183 | if ((*pa = accept(*ps, addr, len)) != SOCKET_INVALID) return IO_DONE; |
184 | err = errno; |
185 | if (err == EINTR) continue; |
186 | if (err != EAGAIN && err != ECONNABORTED) return err; |
187 | if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; |
188 | } |
189 | /* can't reach here */ |
190 | return IO_UNKNOWN; |
191 | } |
192 | |
193 | /*-------------------------------------------------------------------------*\ |
194 | * Send with timeout |
195 | \*-------------------------------------------------------------------------*/ |
196 | int socket_send(p_socket ps, const char *data, size_t count, |
197 | size_t *sent, p_timeout tm) |
198 | { |
199 | int err; |
200 | *sent = 0; |
201 | /* avoid making system calls on closed sockets */ |
202 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
203 | /* loop until we send something or we give up on error */ |
204 | for ( ;; ) { |
205 | long put = (long) send(*ps, data, count, 0); |
206 | /* if we sent anything, we are done */ |
207 | if (put >= 0) { |
208 | *sent = put; |
209 | return IO_DONE; |
210 | } |
211 | err = errno; |
212 | /* EPIPE means the connection was closed */ |
213 | if (err == EPIPE) return IO_CLOSED; |
214 | /* EPROTOTYPE means the connection is being closed (on Yosemite!)*/ |
215 | if (err == EPROTOTYPE) continue; |
216 | /* we call was interrupted, just try again */ |
217 | if (err == EINTR) continue; |
218 | /* if failed fatal reason, report error */ |
219 | if (err != EAGAIN) return err; |
220 | /* wait until we can send something or we timeout */ |
221 | if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; |
222 | } |
223 | /* can't reach here */ |
224 | return IO_UNKNOWN; |
225 | } |
226 | |
227 | /*-------------------------------------------------------------------------*\ |
228 | * Sendto with timeout |
229 | \*-------------------------------------------------------------------------*/ |
230 | int socket_sendto(p_socket ps, const char *data, size_t count, size_t *sent, |
231 | SA *addr, socklen_t len, p_timeout tm) |
232 | { |
233 | int err; |
234 | *sent = 0; |
235 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
236 | for ( ;; ) { |
237 | long put = (long) sendto(*ps, data, count, 0, addr, len); |
238 | if (put >= 0) { |
239 | *sent = put; |
240 | return IO_DONE; |
241 | } |
242 | err = errno; |
243 | if (err == EPIPE) return IO_CLOSED; |
244 | if (err == EPROTOTYPE) continue; |
245 | if (err == EINTR) continue; |
246 | if (err != EAGAIN) return err; |
247 | if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; |
248 | } |
249 | return IO_UNKNOWN; |
250 | } |
251 | |
252 | /*-------------------------------------------------------------------------*\ |
253 | * Receive with timeout |
254 | \*-------------------------------------------------------------------------*/ |
255 | int socket_recv(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) { |
256 | int err; |
257 | *got = 0; |
258 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
259 | for ( ;; ) { |
260 | long taken = (long) recv(*ps, data, count, 0); |
261 | if (taken > 0) { |
262 | *got = taken; |
263 | return IO_DONE; |
264 | } |
265 | err = errno; |
266 | if (taken == 0) return IO_CLOSED; |
267 | if (err == EINTR) continue; |
268 | if (err != EAGAIN) return err; |
269 | if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; |
270 | } |
271 | return IO_UNKNOWN; |
272 | } |
273 | |
274 | /*-------------------------------------------------------------------------*\ |
275 | * Recvfrom with timeout |
276 | \*-------------------------------------------------------------------------*/ |
277 | int socket_recvfrom(p_socket ps, char *data, size_t count, size_t *got, |
278 | SA *addr, socklen_t *len, p_timeout tm) { |
279 | int err; |
280 | *got = 0; |
281 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
282 | for ( ;; ) { |
283 | long taken = (long) recvfrom(*ps, data, count, 0, addr, len); |
284 | if (taken > 0) { |
285 | *got = taken; |
286 | return IO_DONE; |
287 | } |
288 | err = errno; |
289 | if (taken == 0) return IO_CLOSED; |
290 | if (err == EINTR) continue; |
291 | if (err != EAGAIN) return err; |
292 | if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; |
293 | } |
294 | return IO_UNKNOWN; |
295 | } |
296 | |
297 | |
298 | /*-------------------------------------------------------------------------*\ |
299 | * Write with timeout |
300 | * |
301 | * socket_read and socket_write are cut-n-paste of socket_send and socket_recv, |
302 | * with send/recv replaced with write/read. We can't just use write/read |
303 | * in the socket version, because behaviour when size is zero is different. |
304 | \*-------------------------------------------------------------------------*/ |
305 | int socket_write(p_socket ps, const char *data, size_t count, |
306 | size_t *sent, p_timeout tm) |
307 | { |
308 | int err; |
309 | *sent = 0; |
310 | /* avoid making system calls on closed sockets */ |
311 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
312 | /* loop until we send something or we give up on error */ |
313 | for ( ;; ) { |
314 | long put = (long) write(*ps, data, count); |
315 | /* if we sent anything, we are done */ |
316 | if (put >= 0) { |
317 | *sent = put; |
318 | return IO_DONE; |
319 | } |
320 | err = errno; |
321 | /* EPIPE means the connection was closed */ |
322 | if (err == EPIPE) return IO_CLOSED; |
323 | /* EPROTOTYPE means the connection is being closed (on Yosemite!)*/ |
324 | if (err == EPROTOTYPE) continue; |
325 | /* we call was interrupted, just try again */ |
326 | if (err == EINTR) continue; |
327 | /* if failed fatal reason, report error */ |
328 | if (err != EAGAIN) return err; |
329 | /* wait until we can send something or we timeout */ |
330 | if ((err = socket_waitfd(ps, WAITFD_W, tm)) != IO_DONE) return err; |
331 | } |
332 | /* can't reach here */ |
333 | return IO_UNKNOWN; |
334 | } |
335 | |
336 | /*-------------------------------------------------------------------------*\ |
337 | * Read with timeout |
338 | * See note for socket_write |
339 | \*-------------------------------------------------------------------------*/ |
340 | int socket_read(p_socket ps, char *data, size_t count, size_t *got, p_timeout tm) { |
341 | int err; |
342 | *got = 0; |
343 | if (*ps == SOCKET_INVALID) return IO_CLOSED; |
344 | for ( ;; ) { |
345 | long taken = (long) read(*ps, data, count); |
346 | if (taken > 0) { |
347 | *got = taken; |
348 | return IO_DONE; |
349 | } |
350 | err = errno; |
351 | if (taken == 0) return IO_CLOSED; |
352 | if (err == EINTR) continue; |
353 | if (err != EAGAIN) return err; |
354 | if ((err = socket_waitfd(ps, WAITFD_R, tm)) != IO_DONE) return err; |
355 | } |
356 | return IO_UNKNOWN; |
357 | } |
358 | |
359 | /*-------------------------------------------------------------------------*\ |
360 | * Put socket into blocking mode |
361 | \*-------------------------------------------------------------------------*/ |
362 | void socket_setblocking(p_socket ps) { |
363 | int flags = fcntl(*ps, F_GETFL, 0); |
364 | flags &= (~(O_NONBLOCK)); |
365 | fcntl(*ps, F_SETFL, flags); |
366 | } |
367 | |
368 | /*-------------------------------------------------------------------------*\ |
369 | * Put socket into non-blocking mode |
370 | \*-------------------------------------------------------------------------*/ |
371 | void socket_setnonblocking(p_socket ps) { |
372 | int flags = fcntl(*ps, F_GETFL, 0); |
373 | flags |= O_NONBLOCK; |
374 | fcntl(*ps, F_SETFL, flags); |
375 | } |
376 | |
377 | /*-------------------------------------------------------------------------*\ |
378 | * DNS helpers |
379 | \*-------------------------------------------------------------------------*/ |
380 | int socket_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) { |
381 | *hp = gethostbyaddr(addr, len, AF_INET); |
382 | if (*hp) return IO_DONE; |
383 | else if (h_errno) return h_errno; |
384 | else if (errno) return errno; |
385 | else return IO_UNKNOWN; |
386 | } |
387 | |
388 | int socket_gethostbyname(const char *addr, struct hostent **hp) { |
389 | *hp = gethostbyname(addr); |
390 | if (*hp) return IO_DONE; |
391 | else if (h_errno) return h_errno; |
392 | else if (errno) return errno; |
393 | else return IO_UNKNOWN; |
394 | } |
395 | |
396 | /*-------------------------------------------------------------------------*\ |
397 | * Error translation functions |
398 | * Make sure important error messages are standard |
399 | \*-------------------------------------------------------------------------*/ |
400 | const char *socket_hoststrerror(int err) { |
401 | if (err <= 0) return io_strerror(err); |
402 | switch (err) { |
403 | case HOST_NOT_FOUND: return PIE_HOST_NOT_FOUND; |
404 | default: return hstrerror(err); |
405 | } |
406 | } |
407 | |
/*
 * Translates a socket error code into a message.
 * Codes <= 0 are delegated to io_strerror(). A handful of errno values
 * map to fixed PIE_* texts so that the important messages are the same
 * everywhere; everything else goes through strerror().
 */
const char *socket_strerror(int err) {
    if (err <= 0) return io_strerror(err);
    if (err == EADDRINUSE) return PIE_ADDRINUSE;
    if (err == EISCONN) return PIE_ISCONN;
    if (err == EACCES) return PIE_ACCESS;
    if (err == ECONNREFUSED) return PIE_CONNREFUSED;
    if (err == ECONNABORTED) return PIE_CONNABORTED;
    if (err == ECONNRESET) return PIE_CONNRESET;
    if (err == ETIMEDOUT) return PIE_TIMEDOUT;
    return strerror(err);
}
423 | |
424 | const char *socket_ioerror(p_socket ps, int err) { |
425 | (void) ps; |
426 | return socket_strerror(err); |
427 | } |
428 | |
/*
 * Translates a getaddrinfo()-style error code into a message.
 * Returns NULL for 0, a fixed PIE_* text for the EAI_* codes listed below,
 * strerror(errno) for EAI_SYSTEM, and gai_strerror(err) for anything else.
 */
const char *socket_gaistrerror(int err) {
    switch (err) {
        case 0:
            return NULL;            /* no error, no message */
        case EAI_AGAIN:
            return PIE_AGAIN;
        case EAI_BADFLAGS:
            return PIE_BADFLAGS;
#ifdef EAI_BADHINTS
        case EAI_BADHINTS:
            return PIE_BADHINTS;
#endif
        case EAI_FAIL:
            return PIE_FAIL;
        case EAI_FAMILY:
            return PIE_FAMILY;
        case EAI_MEMORY:
            return PIE_MEMORY;
        case EAI_NONAME:
            return PIE_NONAME;
        case EAI_OVERFLOW:
            return PIE_OVERFLOW;
#ifdef EAI_PROTOCOL
        case EAI_PROTOCOL:
            return PIE_PROTOCOL;
#endif
        case EAI_SERVICE:
            return PIE_SERVICE;
        case EAI_SOCKTYPE:
            return PIE_SOCKTYPE;
        case EAI_SYSTEM:
            /* the real reason is in errno */
            return strerror(errno);
        default:
            return gai_strerror(err);
    }
}
451 | |
452 | |