connect: do not leak descriptors on failure
[kgio.git] / ext / kgio / kgio_ext.c
blob977273dd06d37b65e62032e4cdc4a394e1ef1c6f
#include <ruby.h>
#ifdef HAVE_RUBY_IO_H
#  include <ruby/io.h>
#else
#  include <rubyio.h>
#endif
#include <errno.h>
#include <stddef.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <fcntl.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <assert.h>

#include "missing/accept4.h"
#include "nonblock.h"
#include "my_fileno.h"
#include "sock_for_fd.h"
#if defined(__linux__)
/*
 * MSG_DONTWAIT is known to work properly on all stream sockets under
 * Linux, so the socket methods use recv/send with MSG_DONTWAIT and
 * accepted sockets are not created with SOCK_NONBLOCK by default.
 * USE_MSG_DONTWAIT may be defined for other platforms as people care
 * and notice.
 */
#  define USE_MSG_DONTWAIT
#  define KGIO_ACCEPT4_DEFAULT_FLAGS (SOCK_CLOEXEC)
#else /* ! linux */
/* elsewhere, accepted sockets are created non-blocking by default */
#  define KGIO_ACCEPT4_DEFAULT_FLAGS (SOCK_CLOEXEC | SOCK_NONBLOCK)
#endif /* ! linux */

/* flags given to accept4(); changed via Kgio.accept_cloexec=/accept_nonblock= */
static int accept4_flags = KGIO_ACCEPT4_DEFAULT_FLAGS;
34 static VALUE cSocket;
35 static VALUE localhost;
36 static VALUE mKgio_WaitReadable, mKgio_WaitWritable;
37 static ID io_wait_rd, io_wait_wr;
38 static ID iv_kgio_addr;
40 struct io_args {
41 VALUE io;
42 VALUE buf;
43 char *ptr;
44 long len;
45 int fd;
48 static void wait_readable(VALUE io)
50 if (io_wait_rd) {
51 (void)rb_funcall(io, io_wait_rd, 0, 0);
52 } else {
53 int fd = my_fileno(io);
55 if (!rb_io_wait_readable(fd))
56 rb_sys_fail("wait readable");
60 static void wait_writable(VALUE io)
62 if (io_wait_wr) {
63 (void)rb_funcall(io, io_wait_wr, 0, 0);
64 } else {
65 int fd = my_fileno(io);
67 if (!rb_io_wait_writable(fd))
68 rb_sys_fail("wait writable");
72 static void prepare_read(struct io_args *a, int argc, VALUE *argv, VALUE io)
74 VALUE length;
76 a->io = io;
77 a->fd = my_fileno(io);
78 rb_scan_args(argc, argv, "11", &length, &a->buf);
79 a->len = NUM2LONG(length);
80 if (NIL_P(a->buf)) {
81 a->buf = rb_str_new(NULL, a->len);
82 } else {
83 StringValue(a->buf);
84 rb_str_resize(a->buf, a->len);
86 a->ptr = RSTRING_PTR(a->buf);
89 static int read_check(struct io_args *a, long n, const char *msg, int io_wait)
91 if (n == -1) {
92 if (errno == EINTR)
93 return -1;
94 rb_str_set_len(a->buf, 0);
95 if (errno == EAGAIN) {
96 if (io_wait) {
97 wait_readable(a->io);
98 return -1;
99 } else {
100 a->buf = mKgio_WaitReadable;
101 return 0;
104 rb_sys_fail(msg);
106 rb_str_set_len(a->buf, n);
107 if (n == 0)
108 a->buf = Qnil;
109 return 0;
113 * Document-method: Kgio::PipeMethods#kgio_read
115 * call-seq:
117 * socket.kgio_read(maxlen) -> buffer
118 * socket.kgio_read(maxlen, buffer) -> buffer
120 * Reads at most maxlen bytes from the stream socket. Returns with a
121 * newly allocated buffer, or may reuse an existing buffer. This
122 * calls the method identified by Kgio.wait_readable, or uses
123 * the normal, thread-safe Ruby function to wait for readability.
124 * This returns nil on EOF.
126 * This behaves like read(2) and IO#readpartial, NOT fread(3) or
127 * IO#read which possess read-in-full behavior.
129 static VALUE my_read(int io_wait, int argc, VALUE *argv, VALUE io)
131 struct io_args a;
132 long n;
134 prepare_read(&a, argc, argv, io);
135 set_nonblocking(a.fd);
136 retry:
137 n = (long)read(a.fd, a.ptr, a.len);
138 if (read_check(&a, n, "read", io_wait) != 0)
139 goto retry;
140 return a.buf;
143 static VALUE kgio_read(int argc, VALUE *argv, VALUE io)
145 return my_read(1, argc, argv, io);
148 static VALUE kgio_tryread(int argc, VALUE *argv, VALUE io)
150 return my_read(0, argc, argv, io);
#ifdef USE_MSG_DONTWAIT
/*
 * Socket counterpart of my_read().  recv(2) with MSG_DONTWAIT makes each
 * call non-blocking, so unlike my_read() no set_nonblocking() is needed.
 */
static VALUE my_recv(int io_wait, int argc, VALUE *argv, VALUE io)
{
	struct io_args a;
	long nread;

	prepare_read(&a, argc, argv, io);

	do {
		nread = (long)recv(a.fd, a.ptr, a.len, MSG_DONTWAIT);
	} while (read_check(&a, nread, "recv", io_wait) != 0);

	return a.buf;
}

/* Kgio::SocketMethods#kgio_read */
static VALUE kgio_recv(int argc, VALUE *argv, VALUE io)
{
	return my_recv(1, argc, argv, io);
}

/* Kgio::SocketMethods#kgio_tryread */
static VALUE kgio_tryrecv(int argc, VALUE *argv, VALUE io)
{
	return my_recv(0, argc, argv, io);
}
#else /* ! USE_MSG_DONTWAIT */
/* fall back to the read(2)-based methods, which call set_nonblocking() */
#  define kgio_recv kgio_read
#  define kgio_tryrecv kgio_tryread
#endif /* USE_MSG_DONTWAIT */
181 static void prepare_write(struct io_args *a, VALUE io, VALUE str)
183 a->buf = (TYPE(str) == T_STRING) ? str : rb_obj_as_string(str);
184 a->ptr = RSTRING_PTR(a->buf);
185 a->len = RSTRING_LEN(a->buf);
186 a->io = io;
187 a->fd = my_fileno(io);
190 static int write_check(struct io_args *a, long n, const char *msg, int io_wait)
192 if (a->len == n) {
193 a->buf = Qnil;
194 } else if (n == -1) {
195 if (errno == EINTR)
196 return -1;
197 if (errno == EAGAIN) {
198 if (io_wait) {
199 wait_writable(a->io);
200 return -1;
201 } else {
202 a->buf = mKgio_WaitWritable;
203 return 0;
206 rb_sys_fail(msg);
207 } else {
208 assert(n >= 0 && n < a->len && "write/send syscall broken?");
209 if (io_wait) {
210 a->ptr += n;
211 a->len -= n;
212 return -1;
214 a->buf = rb_str_new(a->ptr + n, a->len - n);
216 return 0;
219 static VALUE my_write(VALUE io, VALUE str, int io_wait)
221 struct io_args a;
222 long n;
224 prepare_write(&a, io, str);
225 set_nonblocking(a.fd);
226 retry:
227 n = (long)write(a.fd, a.ptr, a.len);
228 if (write_check(&a, n, "write", io_wait) != 0)
229 goto retry;
230 return a.buf;
234 * Returns true if the write was completed.
236 * Calls the method Kgio.wait_writable is not set
238 static VALUE kgio_write(VALUE io, VALUE str)
240 return my_write(io, str, 1);
244 * Returns a String containing the unwritten portion if there was a
245 * partial write. Will return Kgio::WaitReadable if EAGAIN is
246 * encountered.
248 * Returns true if the write completed in full.
250 static VALUE kgio_trywrite(VALUE io, VALUE str)
252 return my_write(io, str, 0);
#ifdef USE_MSG_DONTWAIT
/*
 * This behaves like Kgio::PipeMethods#kgio_write, except it uses
 * send(2) with the MSG_DONTWAIT flag on sockets to avoid unnecessary
 * calls to fcntl(2).
 */
static VALUE my_send(VALUE io, VALUE str, int io_wait)
{
	struct io_args a;
	long nsent;

	prepare_write(&a, io, str);

	do {
		nsent = (long)send(a.fd, a.ptr, a.len, MSG_DONTWAIT);
	} while (write_check(&a, nsent, "send", io_wait) != 0);

	return a.buf;
}

/* Kgio::SocketMethods#kgio_write */
static VALUE kgio_send(VALUE io, VALUE str)
{
	return my_send(io, str, 1);
}

/* Kgio::SocketMethods#kgio_trywrite */
static VALUE kgio_trysend(VALUE io, VALUE str)
{
	return my_send(io, str, 0);
}
#else /* ! USE_MSG_DONTWAIT */
/* fall back to the write(2)-based methods */
#  define kgio_send kgio_write
#  define kgio_trysend kgio_trywrite
#endif /* ! USE_MSG_DONTWAIT */
289 * call-seq:
291 * Kgio.wait_readable = :method_name
293 * Sets a method for kgio_read to call when a read would block.
294 * This is useful for non-blocking frameworks that use Fibers,
295 * as the method referred to this may cause the current Fiber
296 * to yield execution.
298 * A special value of nil will cause Ruby to wait using the
299 * rb_io_wait_readable() function, giving kgio_read similar semantics to
300 * IO#readpartial.
302 static VALUE set_wait_rd(VALUE mod, VALUE sym)
304 switch (TYPE(sym)) {
305 case T_SYMBOL:
306 io_wait_rd = SYM2ID(sym);
307 return sym;
308 case T_NIL:
309 io_wait_rd = 0;
310 return sym;
312 rb_raise(rb_eTypeError, "must be a symbol or nil");
313 return sym;
316 static VALUE set_wait_wr(VALUE mod, VALUE sym)
318 switch (TYPE(sym)) {
319 case T_SYMBOL:
320 io_wait_wr = SYM2ID(sym);
321 return sym;
322 case T_NIL:
323 io_wait_wr = 0;
324 return sym;
326 rb_raise(rb_eTypeError, "must be a symbol or nil");
327 return sym;
330 static VALUE wait_wr(VALUE mod)
332 return io_wait_wr ? ID2SYM(io_wait_wr) : Qnil;
335 static VALUE wait_rd(VALUE mod)
337 return io_wait_rd ? ID2SYM(io_wait_rd) : Qnil;
340 static VALUE
341 my_accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
343 int client;
345 retry:
346 client = accept4(sockfd, addr, addrlen, accept4_flags);
347 if (client == -1) {
348 switch (errno) {
349 case EAGAIN:
350 #ifdef ECONNABORTED
351 case ECONNABORTED:
352 #endif /* ECONNABORTED */
353 #ifdef EPROTO
354 case EPROTO:
355 #endif /* EPROTO */
356 return Qnil;
357 case ENOMEM:
358 case EMFILE:
359 case ENFILE:
360 #ifdef ENOBUFS
361 case ENOBUFS:
362 #endif /* ENOBUFS */
363 errno = 0;
364 rb_gc();
365 client = accept4(sockfd, addr, addrlen, accept4_flags);
366 break;
367 case EINTR:
368 goto retry;
370 if (client == -1)
371 rb_sys_fail("accept");
373 return sock_for_fd(cSocket, client);
376 /* non-blocking flag should be set on this socket before accept() is called */
377 static VALUE unix_accept(VALUE io)
379 int fd = my_fileno(io);
380 VALUE rv = my_accept(fd, NULL, NULL);
382 if (! NIL_P(rv))
383 rb_ivar_set(rv, iv_kgio_addr, localhost);
385 return rv;
388 /* non-blocking flag should be set on this socket before accept() is called */
389 static VALUE tcp_accept(VALUE io)
391 int fd = my_fileno(io);
392 struct sockaddr_in addr;
393 socklen_t addrlen = sizeof(struct sockaddr_in);
394 VALUE host;
395 const char *name;
396 VALUE rv = my_accept(fd, (struct sockaddr *)&addr, &addrlen);
398 if (NIL_P(rv))
399 return rv;
401 host = rb_str_new(0, INET_ADDRSTRLEN);
402 addrlen = (socklen_t)INET_ADDRSTRLEN;
403 name = inet_ntop(AF_INET, &addr.sin_addr, RSTRING_PTR(host), addrlen);
404 if (name == NULL)
405 rb_sys_fail("inet_ntop");
406 rb_str_set_len(host, strlen(name));
407 rb_ivar_set(rv, iv_kgio_addr, host);
409 return rv;
412 static VALUE get_cloexec(VALUE mod)
414 return (accept4_flags & SOCK_CLOEXEC) == SOCK_CLOEXEC ? Qtrue : Qfalse;
417 static VALUE get_nonblock(VALUE mod)
419 return (accept4_flags & SOCK_NONBLOCK)==SOCK_NONBLOCK ? Qtrue : Qfalse;
422 static VALUE set_cloexec(VALUE mod, VALUE boolean)
424 switch (TYPE(boolean)) {
425 case T_TRUE:
426 accept4_flags |= SOCK_CLOEXEC;
427 return boolean;
428 case T_FALSE:
429 accept4_flags &= ~SOCK_CLOEXEC;
430 return boolean;
432 rb_raise(rb_eTypeError, "not true or false");
433 return Qnil;
436 static VALUE set_nonblock(VALUE mod, VALUE boolean)
438 switch (TYPE(boolean)) {
439 case T_TRUE:
440 accept4_flags |= SOCK_NONBLOCK;
441 return boolean;
442 case T_FALSE:
443 accept4_flags &= ~SOCK_NONBLOCK;
444 return boolean;
446 rb_raise(rb_eTypeError, "not true or false");
447 return Qnil;
/*
 * Closes +fd+ and then raises the SystemCallError for the errno value
 * that was current on entry, with +msg+.  errno is saved first so that
 * close(2) cannot clobber it.
 */
static void close_fail(int fd, const char *msg)
{
	const int err = errno;

	(void)close(fd);
	errno = err;
	rb_sys_fail(msg);
}
458 static VALUE
459 my_connect(VALUE klass, int io_wait, int domain, void *addr, socklen_t addrlen)
461 int fd = socket(domain, SOCK_STREAM, 0);
463 if (fd == -1) {
464 switch (errno) {
465 case EMFILE:
466 case ENFILE:
467 #ifdef ENOBUFS
468 case ENOBUFS:
469 #endif /* ENOBUFS */
470 errno = 0;
471 rb_gc();
472 fd = socket(domain, SOCK_STREAM, 0);
474 if (fd == -1)
475 rb_sys_fail("socket");
478 if (fcntl(fd, F_SETFL, O_RDWR | O_NONBLOCK) == -1)
479 close_fail(fd, "fcntl(F_SETFL, O_RDWR | O_NONBLOCK)");
481 if (connect(fd, addr, addrlen) == -1) {
482 if (errno == EINPROGRESS) {
483 VALUE io = sock_for_fd(klass, fd);
485 if (io_wait) {
486 errno = EAGAIN;
487 wait_writable(io);
489 return io;
491 close_fail(fd, "connect");
493 return sock_for_fd(klass, fd);
496 static VALUE tcp_connect(VALUE klass, VALUE ip, VALUE port, int io_wait)
498 struct sockaddr_in addr = { 0 };
500 addr.sin_family = AF_INET;
501 addr.sin_port = htons((unsigned short)NUM2INT(port));
503 switch (inet_pton(AF_INET, StringValuePtr(ip), &addr.sin_addr)) {
504 case 1:
505 return my_connect(klass, io_wait, PF_INET, &addr, sizeof(addr));
506 case -1:
507 rb_sys_fail("inet_pton");
509 rb_raise(rb_eArgError, "invalid address: %s", StringValuePtr(ip));
511 return Qnil;
515 * call-seq:
517 * Kgio::TCPSocket.new('127.0.0.1', 80) -> socket
519 * Creates a new Kgio::TCPSocket object and initiates a
520 * non-blocking connection. The caller should select/poll
521 * on the socket for writability before attempting to write
522 * or optimistically attempt a write and handle Kgio::WaitWritable
523 * or Errno::EAGAIN.
525 * Unlike the TCPSocket.new in Ruby, this does NOT perform DNS
526 * lookups (which is subject to a different set of timeouts and
527 * best handled elsewhere).
529 * This is only intended as a convenience for testing,
530 * Kgio::Socket.new (along with a cached/memoized addr argument)
531 * is recommended for applications that repeatedly connect to
532 * the same backend servers.
534 static VALUE kgio_tcp_connect(VALUE klass, VALUE ip, VALUE port)
536 return tcp_connect(klass, ip, port, 1);
539 static VALUE kgio_tcp_start(VALUE klass, VALUE ip, VALUE port)
541 return tcp_connect(klass, ip, port, 0);
544 static VALUE unix_connect(VALUE klass, VALUE path, int io_wait)
546 struct sockaddr_un addr = { 0 };
547 long len;
549 StringValue(path);
550 len = RSTRING_LEN(path);
551 if (sizeof(addr.sun_path) <= len)
552 rb_raise(rb_eArgError,
553 "too long unix socket path (max: %dbytes)",
554 (int)sizeof(addr.sun_path)-1);
556 memcpy(addr.sun_path, RSTRING_PTR(path), len);
557 addr.sun_family = AF_UNIX;
559 return my_connect(klass, io_wait, PF_UNIX, &addr, sizeof(addr));
563 * call-seq:
565 * Kgio::UNIXSocket.new("/path/to/unix/socket") -> socket
567 * Creates a new Kgio::UNIXSocket object and initiates a
568 * non-blocking connection. The caller should select/poll
569 * on the socket for writability before attempting to write
570 * or optimistically attempt a write and handle Kgio::WaitWritable
571 * or Errno::EAGAIN.
573 * This is only intended as a convenience for testing,
574 * Kgio::Socket.new (along with a cached/memoized addr argument)
575 * is recommended for applications that repeatedly connect to
576 * the same backend servers.
578 static VALUE kgio_unix_connect(VALUE klass, VALUE path)
580 return unix_connect(klass, path, 1);
583 static VALUE kgio_unix_start(VALUE klass, VALUE path)
585 return unix_connect(klass, path, 0);
588 static VALUE stream_connect(VALUE klass, VALUE addr, int io_wait)
590 int domain;
591 socklen_t addrlen;
592 struct sockaddr *sockaddr;
594 if (TYPE(addr) == T_STRING) {
595 sockaddr = (struct sockaddr *)(RSTRING_PTR(addr));
596 addrlen = (socklen_t)RSTRING_LEN(addr);
597 } else {
598 rb_raise(rb_eTypeError, "invalid address");
600 switch (((struct sockaddr_in *)(sockaddr))->sin_family) {
601 case AF_UNIX: domain = PF_UNIX; break;
602 case AF_INET: domain = PF_INET; break;
603 #ifdef AF_INET6 /* IPv6 support incomplete */
604 case AF_INET6: domain = PF_INET6; break;
605 #endif /* AF_INET6 */
606 default:
607 rb_raise(rb_eArgError, "invalid address family");
610 return my_connect(klass, io_wait, domain, sockaddr, addrlen);
613 static VALUE kgio_connect(VALUE klass, VALUE addr)
615 return stream_connect(klass, addr, 1);
618 static VALUE kgio_start(VALUE klass, VALUE addr)
620 return stream_connect(klass, addr, 0);
/*
 * call-seq:
 *
 *	addr = Socket.pack_sockaddr_in(80, 'example.com')
 *	Kgio::Socket.new(addr) -> socket
 *
 *	addr = Socket.pack_sockaddr_un("/tmp/unix.sock")
 *	Kgio::Socket.new(addr) -> socket
 *
 * Generic connect method for an addr generated by
 * Socket.pack_sockaddr_in or Socket.pack_sockaddr_un.
 */
637 void Init_kgio_ext(void)
639 VALUE mKgio = rb_define_module("Kgio");
640 VALUE mPipeMethods, mSocketMethods;
641 VALUE cUNIXServer, cTCPServer, cUNIXSocket, cTCPSocket;
643 rb_require("socket");
644 cSocket = rb_const_get(rb_cObject, rb_intern("Socket"));
645 cSocket = rb_define_class_under(mKgio, "Socket", cSocket);
647 localhost = rb_str_new2("127.0.0.1");
648 rb_const_set(mKgio, rb_intern("LOCALHOST"), localhost);
651 * The kgio_read method will return this when waiting for
652 * a read is required.
654 mKgio_WaitReadable = rb_define_module_under(mKgio, "WaitReadable");
657 * The kgio_write method will return this when waiting for
658 * a write is required.
660 mKgio_WaitWritable = rb_define_module_under(mKgio, "WaitWritable");
662 rb_define_singleton_method(mKgio, "wait_readable=", set_wait_rd, 1);
663 rb_define_singleton_method(mKgio, "wait_writable=", set_wait_wr, 1);
664 rb_define_singleton_method(mKgio, "wait_readable", wait_rd, 0);
665 rb_define_singleton_method(mKgio, "wait_writable", wait_wr, 0);
666 rb_define_singleton_method(mKgio, "accept_cloexec?", get_cloexec, 0);
667 rb_define_singleton_method(mKgio, "accept_cloexec=", set_cloexec, 1);
668 rb_define_singleton_method(mKgio, "accept_nonblock?", get_nonblock, 0);
669 rb_define_singleton_method(mKgio, "accept_nonblock=", set_nonblock, 1);
671 mPipeMethods = rb_define_module_under(mKgio, "PipeMethods");
672 rb_define_method(mPipeMethods, "kgio_read", kgio_read, -1);
673 rb_define_method(mPipeMethods, "kgio_write", kgio_write, 1);
674 rb_define_method(mPipeMethods, "kgio_tryread", kgio_tryread, -1);
675 rb_define_method(mPipeMethods, "kgio_trywrite", kgio_trywrite, 1);
677 mSocketMethods = rb_define_module_under(mKgio, "SocketMethods");
678 rb_define_method(mSocketMethods, "kgio_read", kgio_recv, -1);
679 rb_define_method(mSocketMethods, "kgio_write", kgio_send, 1);
680 rb_define_method(mSocketMethods, "kgio_tryread", kgio_tryrecv, -1);
681 rb_define_method(mSocketMethods, "kgio_trywrite", kgio_trysend, 1);
683 rb_define_attr(mSocketMethods, "kgio_addr", 1, 1);
684 rb_include_module(cSocket, mSocketMethods);
685 rb_define_singleton_method(cSocket, "new", kgio_connect, 1);
686 rb_define_singleton_method(cSocket, "start", kgio_start, 1);
688 cUNIXServer = rb_const_get(rb_cObject, rb_intern("UNIXServer"));
689 cUNIXServer = rb_define_class_under(mKgio, "UNIXServer", cUNIXServer);
690 rb_define_method(cUNIXServer, "kgio_accept", unix_accept, 0);
692 cTCPServer = rb_const_get(rb_cObject, rb_intern("TCPServer"));
693 cTCPServer = rb_define_class_under(mKgio, "TCPServer", cTCPServer);
694 rb_define_method(cTCPServer, "kgio_accept", tcp_accept, 0);
696 cTCPSocket = rb_const_get(rb_cObject, rb_intern("TCPSocket"));
697 cTCPSocket = rb_define_class_under(mKgio, "TCPSocket", cTCPSocket);
698 rb_include_module(cTCPSocket, mSocketMethods);
699 rb_define_singleton_method(cTCPSocket, "new", kgio_tcp_connect, 2);
700 rb_define_singleton_method(cTCPSocket, "start", kgio_tcp_start, 2);
702 cUNIXSocket = rb_const_get(rb_cObject, rb_intern("UNIXSocket"));
703 cUNIXSocket = rb_define_class_under(mKgio, "UNIXSocket", cUNIXSocket);
704 rb_include_module(cUNIXSocket, mSocketMethods);
705 rb_define_singleton_method(cUNIXSocket, "new", kgio_unix_connect, 1);
706 rb_define_singleton_method(cUNIXSocket, "start", kgio_unix_start, 1);
708 iv_kgio_addr = rb_intern("@kgio_addr");
709 init_sock_for_fd();