9 #include <sys/socket.h>
11 #include <netinet/in.h>
14 #include <arpa/inet.h>
17 #include "missing/accept4.h"
19 #include "my_fileno.h"
20 #include "sock_for_fd.h"
22 #if defined(__linux__)
24 * we know MSG_DONTWAIT works properly on all stream sockets under Linux
25 * we can define this macro for other platforms as people care and
28 # define USE_MSG_DONTWAIT
29 static int accept4_flags
= SOCK_CLOEXEC
;
31 static int accept4_flags
= SOCK_CLOEXEC
| SOCK_NONBLOCK
;
35 static VALUE localhost
;
36 static VALUE mKgio_WaitReadable
, mKgio_WaitWritable
;
37 static ID io_wait_rd
, io_wait_wr
;
38 static ID iv_kgio_addr
, id_ruby
;
48 static int maybe_wait_readable(VALUE io
)
51 if (io_wait_rd
== id_ruby
) {
52 if (! rb_io_wait_readable(my_fileno(io
)))
53 rb_sys_fail("wait readable");
57 (void)rb_funcall(io
, io_wait_rd
, 0, 0);
65 static int maybe_wait_writable(VALUE io
)
68 if (io_wait_wr
== id_ruby
) {
69 if (! rb_io_wait_writable(my_fileno(io
)))
70 rb_sys_fail("wait writable");
74 (void)rb_funcall(io
, io_wait_wr
, 0, 0);
82 static void prepare_read(struct io_args
*a
, int argc
, VALUE
*argv
, VALUE io
)
87 a
->fd
= my_fileno(io
);
88 rb_scan_args(argc
, argv
, "11", &length
, &a
->buf
);
89 a
->len
= NUM2LONG(length
);
91 a
->buf
= rb_str_new(NULL
, a
->len
);
94 rb_str_resize(a
->buf
, a
->len
);
96 a
->ptr
= RSTRING_PTR(a
->buf
);
99 static int read_check(struct io_args
*a
, long n
, const char *msg
)
104 rb_str_set_len(a
->buf
, 0);
105 if (errno
== EAGAIN
) {
106 if (maybe_wait_readable(a
->io
)) {
109 a
->buf
= mKgio_WaitReadable
;
115 rb_str_set_len(a
->buf
, n
);
121 #ifdef USE_MSG_DONTWAIT
124 * Document-method: Kgio::SocketMethods#kgio_read
128 * socket.kgio_read(maxlen) => buffer or Kgio::WaitReadable
129 * socket.kgio_read(maxlen, buffer) => buffer or Kgio::WaitReadable
131 * Reads at most maxlen bytes from the stream socket. Returns with a
132 * newly allocated buffer, or may reuse an existing buffer. This
133 * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
134 * which case it will call the method referred to by Kgio.wait_readable.
136 static VALUE
kgio_recv(int argc
, VALUE
*argv
, VALUE io
)
141 prepare_read(&a
, argc
, argv
, io
);
143 n
= (long)recv(a
.fd
, a
.ptr
, a
.len
, MSG_DONTWAIT
);
144 if (read_check(&a
, n
, "recv") != 0)
148 #else /* ! USE_MSG_DONTWAIT */
149 # define kgio_recv kgio_read
150 #endif /* USE_MSG_DONTWAIT */
153 * Document-method: Kgio::PipeMethods#kgio_read
157 * socket.kgio_read(maxlen) -> buffer or Kgio::WaitReadable
158 * socket.kgio_read(maxlen, buffer) -> buffer or Kgio::WaitReadable
160 * Reads at most maxlen bytes from the stream socket. Returns with a
161 * newly allocated buffer, or may reuse an existing buffer. This
162 * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
163 * which case it will call the method referred to by Kgio.wait_readable.
165 static VALUE
kgio_read(int argc
, VALUE
*argv
, VALUE io
)
170 prepare_read(&a
, argc
, argv
, io
);
171 set_nonblocking(a
.fd
);
173 n
= (long)read(a
.fd
, a
.ptr
, a
.len
);
174 if (read_check(&a
, n
, "read") != 0)
179 static void prepare_write(struct io_args
*a
, VALUE io
, VALUE str
)
181 a
->buf
= (TYPE(str
) == T_STRING
) ? str
: rb_obj_as_string(str
);
182 a
->ptr
= RSTRING_PTR(a
->buf
);
183 a
->len
= RSTRING_LEN(a
->buf
);
185 a
->fd
= my_fileno(io
);
188 static int write_check(struct io_args
*a
, long n
, const char *msg
)
192 } else if (n
== -1) {
195 if (errno
== EAGAIN
) {
196 if (maybe_wait_writable(a
->io
))
198 a
->buf
= mKgio_WaitWritable
;
203 assert(n
>= 0 && n
< a
->len
&& "write/send syscall broken?");
204 a
->buf
= rb_str_new(a
->ptr
+ n
, a
->len
- n
);
210 * Returns a String containing the unwritten portion if there was a
213 * Returns true if the write was completed.
215 * Returns Kgio::WaitWritable if the write would block and
216 * Kgio.wait_writable is not set
218 static VALUE
kgio_write(VALUE io
, VALUE str
)
223 prepare_write(&a
, io
, str
);
224 set_nonblocking(a
.fd
);
226 n
= (long)write(a
.fd
, a
.ptr
, a
.len
);
227 if (write_check(&a
, n
, "write") != 0)
232 #ifdef USE_MSG_DONTWAIT
234 * This method behaves like Kgio::PipeMethods#kgio_write, except
235 * it will use send(2) with the MSG_DONTWAIT flag on sockets to
236 * avoid unnecessary calls to fcntl(2).
238 static VALUE
kgio_send(VALUE io
, VALUE str
)
243 prepare_write(&a
, io
, str
);
245 n
= (long)send(a
.fd
, a
.ptr
, a
.len
, MSG_DONTWAIT
);
246 if (write_check(&a
, n
, "send") != 0)
250 #else /* ! USE_MSG_DONTWAIT */
251 # define kgio_send kgio_write
252 #endif /* ! USE_MSG_DONTWAIT */
257 * Kgio.wait_readable = :method_name
259 * Sets a method for kgio_read to call when a read would block.
260 * This is useful for non-blocking frameworks that use Fibers,
261 * as the method referred to this may cause the current Fiber
262 * to yield execution.
264 * A special value of ":ruby" will cause Ruby to wait using the
265 * rb_io_wait_readable() function, giving kgio_read similar semantics to
268 static VALUE
set_wait_rd(VALUE mod
, VALUE sym
)
272 io_wait_rd
= SYM2ID(sym
);
278 rb_raise(rb_eTypeError
, "must be a symbol or nil");
282 static VALUE
set_wait_wr(VALUE mod
, VALUE sym
)
286 io_wait_wr
= SYM2ID(sym
);
292 rb_raise(rb_eTypeError
, "must be a symbol or nil");
296 static VALUE
wait_wr(VALUE mod
)
298 return io_wait_wr
? ID2SYM(io_wait_wr
) : Qnil
;
301 static VALUE
wait_rd(VALUE mod
)
303 return io_wait_rd
? ID2SYM(io_wait_rd
) : Qnil
;
307 my_accept(int sockfd
, struct sockaddr
*addr
, socklen_t
*addrlen
)
312 client
= accept4(sockfd
, addr
, addrlen
, accept4_flags
);
318 #endif /* ECONNABORTED */
331 client
= accept4(sockfd
, addr
, addrlen
, accept4_flags
);
337 rb_sys_fail("accept");
339 return sock_for_fd(cSocket
, client
);
342 /* non-blocking flag should be set on this socket before accept() is called */
343 static VALUE
unix_accept(VALUE io
)
345 int fd
= my_fileno(io
);
346 VALUE rv
= my_accept(fd
, NULL
, NULL
);
349 rb_ivar_set(rv
, iv_kgio_addr
, localhost
);
354 /* non-blocking flag should be set on this socket before accept() is called */
355 static VALUE
tcp_accept(VALUE io
)
357 int fd
= my_fileno(io
);
358 struct sockaddr_in addr
;
359 socklen_t addrlen
= sizeof(struct sockaddr_in
);
362 VALUE rv
= my_accept(fd
, (struct sockaddr
*)&addr
, &addrlen
);
367 host
= rb_str_new(0, INET_ADDRSTRLEN
);
368 addrlen
= (socklen_t
)INET_ADDRSTRLEN
;
369 name
= inet_ntop(AF_INET
, &addr
.sin_addr
, RSTRING_PTR(host
), addrlen
);
371 rb_sys_fail("inet_ntop");
372 rb_str_set_len(host
, strlen(name
));
373 rb_ivar_set(rv
, iv_kgio_addr
, host
);
378 static VALUE
get_cloexec(VALUE mod
)
380 return (accept4_flags
& SOCK_CLOEXEC
) == SOCK_CLOEXEC
? Qtrue
: Qfalse
;
383 static VALUE
get_nonblock(VALUE mod
)
385 return (accept4_flags
& SOCK_NONBLOCK
)==SOCK_NONBLOCK
? Qtrue
: Qfalse
;
388 static VALUE
set_cloexec(VALUE mod
, VALUE boolean
)
390 switch (TYPE(boolean
)) {
392 accept4_flags
|= SOCK_CLOEXEC
;
395 accept4_flags
&= ~SOCK_CLOEXEC
;
398 rb_raise(rb_eTypeError
, "not true or false");
402 static VALUE
set_nonblock(VALUE mod
, VALUE boolean
)
404 switch (TYPE(boolean
)) {
406 accept4_flags
|= SOCK_NONBLOCK
;
409 accept4_flags
&= ~SOCK_NONBLOCK
;
412 rb_raise(rb_eTypeError
, "not true or false");
417 my_connect(VALUE klass
, int domain
, void *addr
, socklen_t addrlen
)
420 int fd
= socket(domain
, SOCK_STREAM
, 0);
431 fd
= socket(domain
, SOCK_STREAM
, 0);
434 rb_sys_fail("socket");
437 rc
= connect(fd
, addr
, addrlen
);
439 if (errno
== EINPROGRESS
) {
440 VALUE io
= sock_for_fd(klass
, fd
);
442 (void)maybe_wait_writable(io
);
445 rb_sys_fail("connect");
447 return sock_for_fd(klass
, fd
);
453 * Kgio::TCPSocket.new('127.0.0.1', 80) -> socket
455 * Creates a new Kgio::TCPSocket object and initiates a
456 * non-blocking connection. The caller should select/poll
457 * on the socket for writability before attempting to write
458 * or optimistically attempt a write and handle Kgio::WaitWritable
461 * Unlike the TCPSocket.new in Ruby, this does NOT perform DNS
462 * lookups (which is subject to a different set of timeouts and
463 * best handled elsewhere).
465 * This is only intended as a convenience for testing,
466 * Kgio::Socket.new (along with a cached/memoized addr argument)
467 * is recommended for applications that repeatedly connect to
468 * the same backend servers.
470 static VALUE
kgio_tcp_connect(VALUE klass
, VALUE ip
, VALUE port
)
472 struct sockaddr_in addr
= { 0 };
474 addr
.sin_family
= AF_INET
;
475 addr
.sin_port
= htons((unsigned short)NUM2INT(port
));
477 switch (inet_pton(AF_INET
, StringValuePtr(ip
), &addr
.sin_addr
)) {
479 return my_connect(klass
, PF_INET
, &addr
, sizeof(addr
));
481 rb_sys_fail("inet_pton");
483 rb_raise(rb_eArgError
, "invalid address: %s",
491 * Kgio::UNIXSocket.new("/path/to/unix/socket") -> socket
493 * Creates a new Kgio::UNIXSocket object and initiates a
494 * non-blocking connection. The caller should select/poll
495 * on the socket for writability before attempting to write
496 * or optimistically attempt a write and handle Kgio::WaitWritable
499 * This is only intended as a convenience for testing,
500 * Kgio::Socket.new (along with a cached/memoized addr argument)
501 * is recommended for applications that repeatedly connect to
502 * the same backend servers.
504 static VALUE
kgio_unix_connect(VALUE klass
, VALUE path
)
506 struct sockaddr_un addr
= { 0 };
510 len
= RSTRING_LEN(path
);
511 if (sizeof(addr
.sun_path
) <= len
)
512 rb_raise(rb_eArgError
,
513 "too long unix socket path (max: %dbytes)",
514 (int)sizeof(addr
.sun_path
)-1);
516 memcpy(addr
.sun_path
, RSTRING_PTR(path
), len
);
517 addr
.sun_family
= AF_UNIX
;
519 return my_connect(klass
, PF_UNIX
, &addr
, sizeof(addr
));
525 * addr = Socket.pack_sockaddr_in(80, 'example.com')
526 * Kgio::Socket.new(addr) -> socket
528 * addr = Socket.pack_sockaddr_un("/tmp/unix.sock")
529 * Kgio::Socket.new(addr) -> socket
531 * Generic connect method for addr generated by Socket.pack_sockaddr_in
532 * or Socket.pack_sockaddr_un
534 static VALUE
kgio_connect(VALUE klass
, VALUE addr
)
538 struct sockaddr
*sockaddr
;
540 if (TYPE(addr
) == T_STRING
) {
541 sockaddr
= (struct sockaddr
*)(RSTRING_PTR(addr
));
542 addrlen
= (socklen_t
)RSTRING_LEN(addr
);
544 rb_raise(rb_eTypeError
, "invalid address");
546 switch (((struct sockaddr_in
*)(sockaddr
))->sin_family
) {
547 case AF_UNIX
: domain
= PF_UNIX
; break;
548 case AF_INET
: domain
= PF_INET
; break;
549 #ifdef AF_INET6 /* IPv6 support incomplete */
550 case AF_INET6
: domain
= PF_INET6
; break;
551 #endif /* AF_INET6 */
553 rb_raise(rb_eArgError
, "invalid address family");
556 return my_connect(klass
, domain
, sockaddr
, addrlen
);
559 void Init_kgio_ext(void)
561 VALUE mKgio
= rb_define_module("Kgio");
562 VALUE mPipeMethods
, mSocketMethods
;
563 VALUE cUNIXServer
, cTCPServer
, cUNIXSocket
, cTCPSocket
;
565 rb_require("socket");
566 cSocket
= rb_const_get(rb_cObject
, rb_intern("Socket"));
567 cSocket
= rb_define_class_under(mKgio
, "Socket", cSocket
);
569 localhost
= rb_str_new2("127.0.0.1");
570 rb_const_set(mKgio
, rb_intern("LOCALHOST"), localhost
);
573 * The kgio_read method will return this when waiting for
574 * a read is required.
576 mKgio_WaitReadable
= rb_define_module_under(mKgio
, "WaitReadable");
579 * The kgio_write method will return this when waiting for
580 * a write is required.
582 mKgio_WaitWritable
= rb_define_module_under(mKgio
, "WaitWritable");
584 rb_define_singleton_method(mKgio
, "wait_readable=", set_wait_rd
, 1);
585 rb_define_singleton_method(mKgio
, "wait_writable=", set_wait_wr
, 1);
586 rb_define_singleton_method(mKgio
, "wait_readable", wait_rd
, 0);
587 rb_define_singleton_method(mKgio
, "wait_writable", wait_wr
, 0);
588 rb_define_singleton_method(mKgio
, "accept_cloexec?", get_cloexec
, 0);
589 rb_define_singleton_method(mKgio
, "accept_cloexec=", set_cloexec
, 1);
590 rb_define_singleton_method(mKgio
, "accept_nonblock?", get_nonblock
, 0);
591 rb_define_singleton_method(mKgio
, "accept_nonblock=", set_nonblock
, 1);
593 mPipeMethods
= rb_define_module_under(mKgio
, "PipeMethods");
594 rb_define_method(mPipeMethods
, "kgio_read", kgio_read
, -1);
595 rb_define_method(mPipeMethods
, "kgio_write", kgio_write
, 1);
597 mSocketMethods
= rb_define_module_under(mKgio
, "SocketMethods");
598 rb_define_method(mSocketMethods
, "kgio_read", kgio_recv
, -1);
599 rb_define_method(mSocketMethods
, "kgio_write", kgio_send
, 1);
601 rb_define_attr(mSocketMethods
, "kgio_addr", 1, 1);
602 rb_include_module(cSocket
, mSocketMethods
);
603 rb_define_singleton_method(cSocket
, "new", kgio_connect
, 1);
605 cUNIXServer
= rb_const_get(rb_cObject
, rb_intern("UNIXServer"));
606 cUNIXServer
= rb_define_class_under(mKgio
, "UNIXServer", cUNIXServer
);
607 rb_define_method(cUNIXServer
, "kgio_accept", unix_accept
, 0);
609 cTCPServer
= rb_const_get(rb_cObject
, rb_intern("TCPServer"));
610 cTCPServer
= rb_define_class_under(mKgio
, "TCPServer", cTCPServer
);
611 rb_define_method(cTCPServer
, "kgio_accept", tcp_accept
, 0);
613 cTCPSocket
= rb_const_get(rb_cObject
, rb_intern("TCPSocket"));
614 cTCPSocket
= rb_define_class_under(mKgio
, "TCPSocket", cTCPSocket
);
615 rb_include_module(cTCPSocket
, mSocketMethods
);
616 rb_define_singleton_method(cTCPSocket
, "new", kgio_tcp_connect
, 2);
618 cUNIXSocket
= rb_const_get(rb_cObject
, rb_intern("UNIXSocket"));
619 cUNIXSocket
= rb_define_class_under(mKgio
, "UNIXSocket", cUNIXSocket
);
620 rb_include_module(cUNIXSocket
, mSocketMethods
);
621 rb_define_singleton_method(cUNIXSocket
, "new", kgio_unix_connect
, 1);
623 iv_kgio_addr
= rb_intern("@kgio_addr");
624 id_ruby
= rb_intern("ruby");