9 #include <sys/socket.h>
11 #include <netinet/in.h>
14 #include <arpa/inet.h>
17 #include "missing/accept4.h"
19 #include "my_fileno.h"
20 #include "sock_for_fd.h"
22 #if defined(__linux__)
24 * we know MSG_DONTWAIT works properly on all stream sockets under Linux
25 * we can define this macro for other platforms as people care and
28 # define USE_MSG_DONTWAIT
29 static int accept4_flags
= SOCK_CLOEXEC
;
31 static int accept4_flags
= SOCK_CLOEXEC
| SOCK_NONBLOCK
;
35 static VALUE localhost
;
36 static VALUE mKgio_WaitReadable
, mKgio_WaitWritable
;
37 static ID io_wait_rd
, io_wait_wr
;
38 static ID iv_kgio_addr
;
48 static void wait_readable(VALUE io
)
51 (void)rb_funcall(io
, io_wait_rd
, 0, 0);
53 int fd
= my_fileno(io
);
55 if (!rb_io_wait_readable(fd
))
56 rb_sys_fail("wait readable");
60 static void wait_writable(VALUE io
)
63 (void)rb_funcall(io
, io_wait_wr
, 0, 0);
65 int fd
= my_fileno(io
);
67 if (!rb_io_wait_writable(fd
))
68 rb_sys_fail("wait writable");
72 static void prepare_read(struct io_args
*a
, int argc
, VALUE
*argv
, VALUE io
)
77 a
->fd
= my_fileno(io
);
78 rb_scan_args(argc
, argv
, "11", &length
, &a
->buf
);
79 a
->len
= NUM2LONG(length
);
81 a
->buf
= rb_str_new(NULL
, a
->len
);
84 rb_str_resize(a
->buf
, a
->len
);
86 a
->ptr
= RSTRING_PTR(a
->buf
);
89 static int read_check(struct io_args
*a
, long n
, const char *msg
, int io_wait
)
94 rb_str_set_len(a
->buf
, 0);
95 if (errno
== EAGAIN
) {
100 a
->buf
= mKgio_WaitReadable
;
106 rb_str_set_len(a
->buf
, n
);
113 * Document-method: Kgio::PipeMethods#kgio_read
117 * socket.kgio_read(maxlen) -> buffer
118 * socket.kgio_read(maxlen, buffer) -> buffer
120 * Reads at most maxlen bytes from the stream socket. Returns a
121 * newly allocated buffer, or reuses the buffer passed as the second
122 * argument. When a read would block, this calls the method identified
123 * by Kgio.wait_readable, or uses the normal, thread-safe Ruby function
124 * to wait for readability. This returns nil on EOF.
126 * This behaves like read(2) and IO#readpartial, NOT fread(3) or
127 * IO#read which possess read-in-full behavior.
129 static VALUE
my_read(int io_wait
, int argc
, VALUE
*argv
, VALUE io
)
134 prepare_read(&a
, argc
, argv
, io
);
135 set_nonblocking(a
.fd
);
137 n
= (long)read(a
.fd
, a
.ptr
, a
.len
);
138 if (read_check(&a
, n
, "read", io_wait
) != 0)
143 static VALUE
kgio_read(int argc
, VALUE
*argv
, VALUE io
)
145 return my_read(1, argc
, argv
, io
);
148 static VALUE
kgio_tryread(int argc
, VALUE
*argv
, VALUE io
)
150 return my_read(0, argc
, argv
, io
);
153 #ifdef USE_MSG_DONTWAIT
154 static VALUE
my_recv(int io_wait
, int argc
, VALUE
*argv
, VALUE io
)
159 prepare_read(&a
, argc
, argv
, io
);
161 n
= (long)recv(a
.fd
, a
.ptr
, a
.len
, MSG_DONTWAIT
);
162 if (read_check(&a
, n
, "recv", io_wait
) != 0)
167 static VALUE
kgio_recv(int argc
, VALUE
*argv
, VALUE io
)
169 return my_recv(1, argc
, argv
, io
);
172 static VALUE
kgio_tryrecv(int argc
, VALUE
*argv
, VALUE io
)
174 return my_recv(0, argc
, argv
, io
);
176 #else /* ! USE_MSG_DONTWAIT */
177 # define kgio_recv kgio_read
178 # define kgio_tryrecv kgio_tryread
179 #endif /* USE_MSG_DONTWAIT */
181 static void prepare_write(struct io_args
*a
, VALUE io
, VALUE str
)
183 a
->buf
= (TYPE(str
) == T_STRING
) ? str
: rb_obj_as_string(str
);
184 a
->ptr
= RSTRING_PTR(a
->buf
);
185 a
->len
= RSTRING_LEN(a
->buf
);
187 a
->fd
= my_fileno(io
);
190 static int write_check(struct io_args
*a
, long n
, const char *msg
, int io_wait
)
194 } else if (n
== -1) {
197 if (errno
== EAGAIN
) {
199 wait_writable(a
->io
);
202 a
->buf
= mKgio_WaitWritable
;
208 assert(n
>= 0 && n
< a
->len
&& "write/send syscall broken?");
214 a
->buf
= rb_str_new(a
->ptr
+ n
, a
->len
- n
);
219 static VALUE
my_write(VALUE io
, VALUE str
, int io_wait
)
224 prepare_write(&a
, io
, str
);
225 set_nonblocking(a
.fd
);
227 n
= (long)write(a
.fd
, a
.ptr
, a
.len
);
228 if (write_check(&a
, n
, "write", io_wait
) != 0)
234 * Returns true if the write was completed.
236 * Calls the method identified by Kgio.wait_writable (if set) when a write would block.
238 static VALUE
kgio_write(VALUE io
, VALUE str
)
240 return my_write(io
, str
, 1);
244 * Returns a String containing the unwritten portion if there was a
245 * partial write. Will return Kgio::WaitWritable if EAGAIN is
248 * Returns true if the write completed in full.
250 static VALUE
kgio_trywrite(VALUE io
, VALUE str
)
252 return my_write(io
, str
, 0);
255 #ifdef USE_MSG_DONTWAIT
257 * This method behaves like Kgio::PipeMethods#kgio_write, except
258 * it will use send(2) with the MSG_DONTWAIT flag on sockets to
259 * avoid unnecessary calls to fcntl(2).
261 static VALUE
my_send(VALUE io
, VALUE str
, int io_wait
)
266 prepare_write(&a
, io
, str
);
268 n
= (long)send(a
.fd
, a
.ptr
, a
.len
, MSG_DONTWAIT
);
269 if (write_check(&a
, n
, "send", io_wait
) != 0)
274 static VALUE
kgio_send(VALUE io
, VALUE str
)
276 return my_send(io
, str
, 1);
279 static VALUE
kgio_trysend(VALUE io
, VALUE str
)
281 return my_send(io
, str
, 0);
283 #else /* ! USE_MSG_DONTWAIT */
284 # define kgio_send kgio_write
285 # define kgio_trysend kgio_trywrite
286 #endif /* ! USE_MSG_DONTWAIT */
291 * Kgio.wait_readable = :method_name
293 * Sets a method for kgio_read to call when a read would block.
294 * This is useful for non-blocking frameworks that use Fibers,
295 * as the method referred to by this may cause the current Fiber
296 * to yield execution.
298 * A special value of nil will cause Ruby to wait using the
299 * rb_io_wait_readable() function, giving kgio_read similar semantics to
302 static VALUE
set_wait_rd(VALUE mod
, VALUE sym
)
306 io_wait_rd
= SYM2ID(sym
);
312 rb_raise(rb_eTypeError
, "must be a symbol or nil");
316 static VALUE
set_wait_wr(VALUE mod
, VALUE sym
)
320 io_wait_wr
= SYM2ID(sym
);
326 rb_raise(rb_eTypeError
, "must be a symbol or nil");
330 static VALUE
wait_wr(VALUE mod
)
332 return io_wait_wr
? ID2SYM(io_wait_wr
) : Qnil
;
335 static VALUE
wait_rd(VALUE mod
)
337 return io_wait_rd
? ID2SYM(io_wait_rd
) : Qnil
;
341 my_accept(int sockfd
, struct sockaddr
*addr
, socklen_t
*addrlen
)
346 client
= accept4(sockfd
, addr
, addrlen
, accept4_flags
);
352 #endif /* ECONNABORTED */
365 client
= accept4(sockfd
, addr
, addrlen
, accept4_flags
);
371 rb_sys_fail("accept");
373 return sock_for_fd(cSocket
, client
);
376 /* non-blocking flag should be set on this socket before accept() is called */
377 static VALUE
unix_accept(VALUE io
)
379 int fd
= my_fileno(io
);
380 VALUE rv
= my_accept(fd
, NULL
, NULL
);
383 rb_ivar_set(rv
, iv_kgio_addr
, localhost
);
388 /* non-blocking flag should be set on this socket before accept() is called */
389 static VALUE
tcp_accept(VALUE io
)
391 int fd
= my_fileno(io
);
392 struct sockaddr_in addr
;
393 socklen_t addrlen
= sizeof(struct sockaddr_in
);
396 VALUE rv
= my_accept(fd
, (struct sockaddr
*)&addr
, &addrlen
);
401 host
= rb_str_new(0, INET_ADDRSTRLEN
);
402 addrlen
= (socklen_t
)INET_ADDRSTRLEN
;
403 name
= inet_ntop(AF_INET
, &addr
.sin_addr
, RSTRING_PTR(host
), addrlen
);
405 rb_sys_fail("inet_ntop");
406 rb_str_set_len(host
, strlen(name
));
407 rb_ivar_set(rv
, iv_kgio_addr
, host
);
412 static VALUE
get_cloexec(VALUE mod
)
414 return (accept4_flags
& SOCK_CLOEXEC
) == SOCK_CLOEXEC
? Qtrue
: Qfalse
;
417 static VALUE
get_nonblock(VALUE mod
)
419 return (accept4_flags
& SOCK_NONBLOCK
)==SOCK_NONBLOCK
? Qtrue
: Qfalse
;
422 static VALUE
set_cloexec(VALUE mod
, VALUE boolean
)
424 switch (TYPE(boolean
)) {
426 accept4_flags
|= SOCK_CLOEXEC
;
429 accept4_flags
&= ~SOCK_CLOEXEC
;
432 rb_raise(rb_eTypeError
, "not true or false");
436 static VALUE
set_nonblock(VALUE mod
, VALUE boolean
)
438 switch (TYPE(boolean
)) {
440 accept4_flags
|= SOCK_NONBLOCK
;
443 accept4_flags
&= ~SOCK_NONBLOCK
;
446 rb_raise(rb_eTypeError
, "not true or false");
451 my_connect(VALUE klass
, int domain
, void *addr
, socklen_t addrlen
)
454 int fd
= socket(domain
, SOCK_STREAM
, 0);
465 fd
= socket(domain
, SOCK_STREAM
, 0);
468 rb_sys_fail("socket");
471 rc
= connect(fd
, addr
, addrlen
);
473 if (errno
== EINPROGRESS
) {
474 VALUE io
= sock_for_fd(klass
, fd
);
480 rb_sys_fail("connect");
482 return sock_for_fd(klass
, fd
);
488 * Kgio::TCPSocket.new('127.0.0.1', 80) -> socket
490 * Creates a new Kgio::TCPSocket object and initiates a
491 * non-blocking connection. The caller should select/poll
492 * on the socket for writability before attempting to write
493 * or optimistically attempt a write and handle Kgio::WaitWritable
496 * Unlike Ruby's TCPSocket.new, this does NOT perform DNS
497 * lookups (which are subject to a different set of timeouts and
498 * best handled elsewhere).
500 * This is only intended as a convenience for testing;
501 * Kgio::Socket.new (along with a cached/memoized addr argument)
502 * is recommended for applications that repeatedly connect to
503 * the same backend servers.
505 static VALUE
kgio_tcp_connect(VALUE klass
, VALUE ip
, VALUE port
)
507 struct sockaddr_in addr
= { 0 };
509 addr
.sin_family
= AF_INET
;
510 addr
.sin_port
= htons((unsigned short)NUM2INT(port
));
512 switch (inet_pton(AF_INET
, StringValuePtr(ip
), &addr
.sin_addr
)) {
514 return my_connect(klass
, PF_INET
, &addr
, sizeof(addr
));
516 rb_sys_fail("inet_pton");
518 rb_raise(rb_eArgError
, "invalid address: %s",
526 * Kgio::UNIXSocket.new("/path/to/unix/socket") -> socket
528 * Creates a new Kgio::UNIXSocket object and initiates a
529 * non-blocking connection. The caller should select/poll
530 * on the socket for writability before attempting to write
531 * or optimistically attempt a write and handle Kgio::WaitWritable
534 * This is only intended as a convenience for testing;
535 * Kgio::Socket.new (along with a cached/memoized addr argument)
536 * is recommended for applications that repeatedly connect to
537 * the same backend servers.
539 static VALUE
kgio_unix_connect(VALUE klass
, VALUE path
)
541 struct sockaddr_un addr
= { 0 };
545 len
= RSTRING_LEN(path
);
546 if (sizeof(addr
.sun_path
) <= len
)
547 rb_raise(rb_eArgError
,
548 "too long unix socket path (max: %dbytes)",
549 (int)sizeof(addr
.sun_path
)-1);
551 memcpy(addr
.sun_path
, RSTRING_PTR(path
), len
);
552 addr
.sun_family
= AF_UNIX
;
554 return my_connect(klass
, PF_UNIX
, &addr
, sizeof(addr
));
560 * addr = Socket.pack_sockaddr_in(80, 'example.com')
561 * Kgio::Socket.new(addr) -> socket
563 * addr = Socket.pack_sockaddr_un("/tmp/unix.sock")
564 * Kgio::Socket.new(addr) -> socket
566 * Generic connect method for addr generated by Socket.pack_sockaddr_in
567 * or Socket.pack_sockaddr_un
569 static VALUE
kgio_connect(VALUE klass
, VALUE addr
)
573 struct sockaddr
*sockaddr
;
575 if (TYPE(addr
) == T_STRING
) {
576 sockaddr
= (struct sockaddr
*)(RSTRING_PTR(addr
));
577 addrlen
= (socklen_t
)RSTRING_LEN(addr
);
579 rb_raise(rb_eTypeError
, "invalid address");
581 switch (((struct sockaddr_in
*)(sockaddr
))->sin_family
) {
582 case AF_UNIX
: domain
= PF_UNIX
; break;
583 case AF_INET
: domain
= PF_INET
; break;
584 #ifdef AF_INET6 /* IPv6 support incomplete */
585 case AF_INET6
: domain
= PF_INET6
; break;
586 #endif /* AF_INET6 */
588 rb_raise(rb_eArgError
, "invalid address family");
591 return my_connect(klass
, domain
, sockaddr
, addrlen
);
594 void Init_kgio_ext(void)
596 VALUE mKgio
= rb_define_module("Kgio");
597 VALUE mPipeMethods
, mSocketMethods
;
598 VALUE cUNIXServer
, cTCPServer
, cUNIXSocket
, cTCPSocket
;
600 rb_require("socket");
601 cSocket
= rb_const_get(rb_cObject
, rb_intern("Socket"));
602 cSocket
= rb_define_class_under(mKgio
, "Socket", cSocket
);
604 localhost
= rb_str_new2("127.0.0.1");
605 rb_const_set(mKgio
, rb_intern("LOCALHOST"), localhost
);
608 * The kgio_tryread method will return this when waiting for
609 * a read is required.
611 mKgio_WaitReadable
= rb_define_module_under(mKgio
, "WaitReadable");
614 * The kgio_trywrite method will return this when waiting for
615 * a write is required.
617 mKgio_WaitWritable
= rb_define_module_under(mKgio
, "WaitWritable");
619 rb_define_singleton_method(mKgio
, "wait_readable=", set_wait_rd
, 1);
620 rb_define_singleton_method(mKgio
, "wait_writable=", set_wait_wr
, 1);
621 rb_define_singleton_method(mKgio
, "wait_readable", wait_rd
, 0);
622 rb_define_singleton_method(mKgio
, "wait_writable", wait_wr
, 0);
623 rb_define_singleton_method(mKgio
, "accept_cloexec?", get_cloexec
, 0);
624 rb_define_singleton_method(mKgio
, "accept_cloexec=", set_cloexec
, 1);
625 rb_define_singleton_method(mKgio
, "accept_nonblock?", get_nonblock
, 0);
626 rb_define_singleton_method(mKgio
, "accept_nonblock=", set_nonblock
, 1);
628 mPipeMethods
= rb_define_module_under(mKgio
, "PipeMethods");
629 rb_define_method(mPipeMethods
, "kgio_read", kgio_read
, -1);
630 rb_define_method(mPipeMethods
, "kgio_write", kgio_write
, 1);
631 rb_define_method(mPipeMethods
, "kgio_tryread", kgio_tryread
, -1);
632 rb_define_method(mPipeMethods
, "kgio_trywrite", kgio_trywrite
, 1);
634 mSocketMethods
= rb_define_module_under(mKgio
, "SocketMethods");
635 rb_define_method(mSocketMethods
, "kgio_read", kgio_recv
, -1);
636 rb_define_method(mSocketMethods
, "kgio_write", kgio_send
, 1);
637 rb_define_method(mSocketMethods
, "kgio_tryread", kgio_tryrecv
, -1);
638 rb_define_method(mSocketMethods
, "kgio_trywrite", kgio_trysend
, 1);
640 rb_define_attr(mSocketMethods
, "kgio_addr", 1, 1);
641 rb_include_module(cSocket
, mSocketMethods
);
642 rb_define_singleton_method(cSocket
, "new", kgio_connect
, 1);
644 cUNIXServer
= rb_const_get(rb_cObject
, rb_intern("UNIXServer"));
645 cUNIXServer
= rb_define_class_under(mKgio
, "UNIXServer", cUNIXServer
);
646 rb_define_method(cUNIXServer
, "kgio_accept", unix_accept
, 0);
648 cTCPServer
= rb_const_get(rb_cObject
, rb_intern("TCPServer"));
649 cTCPServer
= rb_define_class_under(mKgio
, "TCPServer", cTCPServer
);
650 rb_define_method(cTCPServer
, "kgio_accept", tcp_accept
, 0);
652 cTCPSocket
= rb_const_get(rb_cObject
, rb_intern("TCPSocket"));
653 cTCPSocket
= rb_define_class_under(mKgio
, "TCPSocket", cTCPSocket
);
654 rb_include_module(cTCPSocket
, mSocketMethods
);
655 rb_define_singleton_method(cTCPSocket
, "new", kgio_tcp_connect
, 2);
657 cUNIXSocket
= rb_const_get(rb_cObject
, rb_intern("UNIXSocket"));
658 cUNIXSocket
= rb_define_class_under(mKgio
, "UNIXSocket", cUNIXSocket
);
659 rb_include_module(cUNIXSocket
, mSocketMethods
);
660 rb_define_singleton_method(cUNIXSocket
, "new", kgio_unix_connect
, 1);
662 iv_kgio_addr
= rb_intern("@kgio_addr");