split out reusable bits into separate headers
[kgio.git] / ext / kgio / kgio_ext.c
blobd0632e364f018e4afc2c389ef5c72e8dbf78f6fb
#include <ruby.h>
#ifdef HAVE_RUBY_IO_H
# include <ruby/io.h>
#else
# include <rubyio.h>
#endif
#include <errno.h>
#include <stddef.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <fcntl.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <assert.h>

#include "missing/accept4.h"
#include "nonblock.h"
#include "my_fileno.h"
#include "sock_for_fd.h"
22 #if defined(__linux__)
24 * we know MSG_DONTWAIT works properly on all stream sockets under Linux
25 * we can define this macro for other platforms as people care and
26 * notice.
28 # define USE_MSG_DONTWAIT
29 static int accept4_flags = SOCK_CLOEXEC;
30 #else
31 static int accept4_flags = SOCK_CLOEXEC | SOCK_NONBLOCK;
32 #endif
34 static VALUE cSocket;
35 static VALUE localhost;
36 static VALUE mKgio_WaitReadable, mKgio_WaitWritable;
37 static ID io_wait_rd, io_wait_wr;
38 static ID iv_kgio_addr, id_ruby;
40 struct io_args {
41 VALUE io;
42 VALUE buf;
43 char *ptr;
44 long len;
45 int fd;
48 static int maybe_wait_readable(VALUE io)
50 if (io_wait_rd) {
51 if (io_wait_rd == id_ruby) {
52 if (! rb_io_wait_readable(my_fileno(io)))
53 rb_sys_fail("wait readable");
54 errno = 0;
55 } else {
56 errno = 0;
57 (void)rb_funcall(io, io_wait_rd, 0, 0);
59 return 1;
61 errno = 0;
62 return 0;
65 static int maybe_wait_writable(VALUE io)
67 if (io_wait_wr) {
68 if (io_wait_wr == id_ruby) {
69 if (! rb_io_wait_writable(my_fileno(io)))
70 rb_sys_fail("wait writable");
71 errno = 0;
72 } else {
73 errno = 0;
74 (void)rb_funcall(io, io_wait_wr, 0, 0);
76 return 1;
78 errno = 0;
79 return 0;
82 static void prepare_read(struct io_args *a, int argc, VALUE *argv, VALUE io)
84 VALUE length;
86 a->io = io;
87 a->fd = my_fileno(io);
88 rb_scan_args(argc, argv, "11", &length, &a->buf);
89 a->len = NUM2LONG(length);
90 if (NIL_P(a->buf)) {
91 a->buf = rb_str_new(NULL, a->len);
92 } else {
93 StringValue(a->buf);
94 rb_str_resize(a->buf, a->len);
96 a->ptr = RSTRING_PTR(a->buf);
99 static int read_check(struct io_args *a, long n, const char *msg)
101 if (n == -1) {
102 if (errno == EINTR)
103 return -1;
104 rb_str_set_len(a->buf, 0);
105 if (errno == EAGAIN) {
106 if (maybe_wait_readable(a->io)) {
107 return -1;
108 } else {
109 a->buf = mKgio_WaitReadable;
110 return 0;
113 rb_sys_fail(msg);
115 rb_str_set_len(a->buf, n);
116 if (n == 0)
117 rb_eof_error();
118 return 0;
#ifdef USE_MSG_DONTWAIT
/*
 * Document-method: Kgio::SocketMethods#kgio_read
 *
 * call-seq:
 *
 *	socket.kgio_read(maxlen) => buffer or Kgio::WaitReadable
 *	socket.kgio_read(maxlen, buffer) => buffer or Kgio::WaitReadable
 *
 * Reads at most maxlen bytes from the stream socket.  Returns a newly
 * allocated buffer, or reuses +buffer+ when one is given.  This returns
 * Kgio::WaitReadable unless Kgio.wait_readable is set, in which case it
 * calls the method named by Kgio.wait_readable and retries the read.
 * Raises EOFError when the peer has closed the connection.
 *
 * recv(2) is called with MSG_DONTWAIT, so unlike
 * Kgio::PipeMethods#kgio_read this never calls set_nonblocking() on
 * the descriptor.
 */
static VALUE kgio_recv(int argc, VALUE *argv, VALUE io)
{
	struct io_args a;
	long n;

	prepare_read(&a, argc, argv, io);
retry:
	n = (long)recv(a.fd, a.ptr, a.len, MSG_DONTWAIT);
	if (read_check(&a, n, "recv") != 0)
		goto retry;
	return a.buf;
}
#else /* ! USE_MSG_DONTWAIT */
/* without MSG_DONTWAIT, sockets share the read(2)-based implementation */
# define kgio_recv kgio_read
#endif /* USE_MSG_DONTWAIT */
153 * Document-method: Kgio::PipeMethods#kgio_read
155 * call-seq:
157 * socket.kgio_read(maxlen) -> buffer or Kgio::WaitReadable
158 * socket.kgio_read(maxlen, buffer) -> buffer or Kgio::WaitReadable
160 * Reads at most maxlen bytes from the stream socket. Returns with a
161 * newly allocated buffer, or may reuse an existing buffer. This
162 * returns Kgio::WaitReadble unless Kgio.wait_readable is set, in
163 * which case it will call the method referred to by Kgio.wait_readable.
165 static VALUE kgio_read(int argc, VALUE *argv, VALUE io)
167 struct io_args a;
168 long n;
170 prepare_read(&a, argc, argv, io);
171 set_nonblocking(a.fd);
172 retry:
173 n = (long)read(a.fd, a.ptr, a.len);
174 if (read_check(&a, n, "read") != 0)
175 goto retry;
176 return a.buf;
179 static void prepare_write(struct io_args *a, VALUE io, VALUE str)
181 a->buf = (TYPE(str) == T_STRING) ? str : rb_obj_as_string(str);
182 a->ptr = RSTRING_PTR(a->buf);
183 a->len = RSTRING_LEN(a->buf);
184 a->io = io;
185 a->fd = my_fileno(io);
188 static int write_check(struct io_args *a, long n, const char *msg)
190 if (a->len == n) {
191 a->buf = Qnil;
192 } else if (n == -1) {
193 if (errno == EINTR)
194 return -1;
195 if (errno == EAGAIN) {
196 if (maybe_wait_writable(a->io))
197 return -1;
198 a->buf = mKgio_WaitWritable;
199 return 0;
201 rb_sys_fail(msg);
202 } else {
203 assert(n >= 0 && n < a->len && "write/send syscall broken?");
204 a->buf = rb_str_new(a->ptr + n, a->len - n);
206 return 0;
210 * Returns a String containing the unwritten portion if there was a
211 * partial write.
213 * Returns true if the write was completed.
215 * Returns Kgio::WaitWritable if the write would block and
216 * Kgio.wait_writable is not set
218 static VALUE kgio_write(VALUE io, VALUE str)
220 struct io_args a;
221 long n;
223 prepare_write(&a, io, str);
224 set_nonblocking(a.fd);
225 retry:
226 n = (long)write(a.fd, a.ptr, a.len);
227 if (write_check(&a, n, "write") != 0)
228 goto retry;
229 return a.buf;
#ifdef USE_MSG_DONTWAIT
/*
 * Socket flavor of Kgio::PipeMethods#kgio_write with the same return
 * values.  It uses send(2) with MSG_DONTWAIT, so the socket's flags
 * never have to be changed with fcntl(2).
 */
static VALUE kgio_send(VALUE io, VALUE str)
{
	struct io_args a;
	long n;

	prepare_write(&a, io, str);
	do {
		n = (long)send(a.fd, a.ptr, a.len, MSG_DONTWAIT);
	} while (write_check(&a, n, "send") != 0);

	return a.buf;
}
#else /* ! USE_MSG_DONTWAIT */
/* without MSG_DONTWAIT, sockets share the write(2)-based implementation */
# define kgio_send kgio_write
#endif /* ! USE_MSG_DONTWAIT */
255 * call-seq:
257 * Kgio.wait_readable = :method_name
259 * Sets a method for kgio_read to call when a read would block.
260 * This is useful for non-blocking frameworks that use Fibers,
261 * as the method referred to this may cause the current Fiber
262 * to yield execution.
264 * A special value of ":ruby" will cause Ruby to wait using the
265 * rb_io_wait_readable() function, giving kgio_read similar semantics to
266 * IO#readpartial.
268 static VALUE set_wait_rd(VALUE mod, VALUE sym)
270 switch (TYPE(sym)) {
271 case T_SYMBOL:
272 io_wait_rd = SYM2ID(sym);
273 return sym;
274 case T_NIL:
275 io_wait_rd = 0;
276 return sym;
278 rb_raise(rb_eTypeError, "must be a symbol or nil");
279 return sym;
282 static VALUE set_wait_wr(VALUE mod, VALUE sym)
284 switch (TYPE(sym)) {
285 case T_SYMBOL:
286 io_wait_wr = SYM2ID(sym);
287 return sym;
288 case T_NIL:
289 io_wait_wr = 0;
290 return sym;
292 rb_raise(rb_eTypeError, "must be a symbol or nil");
293 return sym;
296 static VALUE wait_wr(VALUE mod)
298 return io_wait_wr ? ID2SYM(io_wait_wr) : Qnil;
301 static VALUE wait_rd(VALUE mod)
303 return io_wait_rd ? ID2SYM(io_wait_rd) : Qnil;
306 static VALUE
307 my_accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
309 int client;
311 retry:
312 client = accept4(sockfd, addr, addrlen, accept4_flags);
313 if (client == -1) {
314 switch (errno) {
315 case EAGAIN:
316 #ifdef ECONNABORTED
317 case ECONNABORTED:
318 #endif /* ECONNABORTED */
319 #ifdef EPROTO
320 case EPROTO:
321 #endif /* EPROTO */
322 return Qnil;
323 case ENOMEM:
324 case EMFILE:
325 case ENFILE:
326 #ifdef ENOBUFS
327 case ENOBUFS:
328 #endif /* ENOBUFS */
329 errno = 0;
330 rb_gc();
331 client = accept4(sockfd, addr, addrlen, accept4_flags);
332 break;
333 case EINTR:
334 goto retry;
336 if (client == -1)
337 rb_sys_fail("accept");
339 return sock_for_fd(cSocket, client);
342 /* non-blocking flag should be set on this socket before accept() is called */
343 static VALUE unix_accept(VALUE io)
345 int fd = my_fileno(io);
346 VALUE rv = my_accept(fd, NULL, NULL);
348 if (! NIL_P(rv))
349 rb_ivar_set(rv, iv_kgio_addr, localhost);
351 return rv;
354 /* non-blocking flag should be set on this socket before accept() is called */
355 static VALUE tcp_accept(VALUE io)
357 int fd = my_fileno(io);
358 struct sockaddr_in addr;
359 socklen_t addrlen = sizeof(struct sockaddr_in);
360 VALUE host;
361 const char *name;
362 VALUE rv = my_accept(fd, (struct sockaddr *)&addr, &addrlen);
364 if (NIL_P(rv))
365 return rv;
367 host = rb_str_new(0, INET_ADDRSTRLEN);
368 addrlen = (socklen_t)INET_ADDRSTRLEN;
369 name = inet_ntop(AF_INET, &addr.sin_addr, RSTRING_PTR(host), addrlen);
370 if (name == NULL)
371 rb_sys_fail("inet_ntop");
372 rb_str_set_len(host, strlen(name));
373 rb_ivar_set(rv, iv_kgio_addr, host);
375 return rv;
378 static VALUE get_cloexec(VALUE mod)
380 return (accept4_flags & SOCK_CLOEXEC) == SOCK_CLOEXEC ? Qtrue : Qfalse;
383 static VALUE get_nonblock(VALUE mod)
385 return (accept4_flags & SOCK_NONBLOCK)==SOCK_NONBLOCK ? Qtrue : Qfalse;
388 static VALUE set_cloexec(VALUE mod, VALUE boolean)
390 switch (TYPE(boolean)) {
391 case T_TRUE:
392 accept4_flags |= SOCK_CLOEXEC;
393 return boolean;
394 case T_FALSE:
395 accept4_flags &= ~SOCK_CLOEXEC;
396 return boolean;
398 rb_raise(rb_eTypeError, "not true or false");
399 return Qnil;
402 static VALUE set_nonblock(VALUE mod, VALUE boolean)
404 switch (TYPE(boolean)) {
405 case T_TRUE:
406 accept4_flags |= SOCK_NONBLOCK;
407 return boolean;
408 case T_FALSE:
409 accept4_flags &= ~SOCK_NONBLOCK;
410 return boolean;
412 rb_raise(rb_eTypeError, "not true or false");
413 return Qnil;
416 static VALUE
417 my_connect(VALUE klass, int domain, void *addr, socklen_t addrlen)
419 int rc;
420 int fd = socket(domain, SOCK_STREAM, 0);
422 if (fd == -1) {
423 switch (errno) {
424 case EMFILE:
425 case ENFILE:
426 #ifdef ENOBUFS
427 case ENOBUFS:
428 #endif /* ENOBUFS */
429 errno = 0;
430 rb_gc();
431 fd = socket(domain, SOCK_STREAM, 0);
433 if (fd == -1)
434 rb_sys_fail("socket");
436 set_nonblocking(fd);
437 rc = connect(fd, addr, addrlen);
438 if (rc == -1) {
439 if (errno == EINPROGRESS) {
440 VALUE io = sock_for_fd(klass, fd);
442 (void)maybe_wait_writable(io);
443 return io;
445 rb_sys_fail("connect");
447 return sock_for_fd(klass, fd);
451 * call-seq:
453 * Kgio::TCPSocket.new('127.0.0.1', 80) -> socket
455 * Creates a new Kgio::TCPSocket object and initiates a
456 * non-blocking connection. The caller should select/poll
457 * on the socket for writability before attempting to write
458 * or optimistically attempt a write and handle Kgio::WaitWritable
459 * or Errno::EAGAIN.
461 * Unlike the TCPSocket.new in Ruby, this does NOT perform DNS
462 * lookups (which is subject to a different set of timeouts and
463 * best handled elsewhere).
465 * This is only intended as a convenience for testing,
466 * Kgio::Socket.new (along with a cached/memoized addr argument)
467 * is recommended for applications that repeatedly connect to
468 * the same backend servers.
470 static VALUE kgio_tcp_connect(VALUE klass, VALUE ip, VALUE port)
472 struct sockaddr_in addr = { 0 };
474 addr.sin_family = AF_INET;
475 addr.sin_port = htons((unsigned short)NUM2INT(port));
477 switch (inet_pton(AF_INET, StringValuePtr(ip), &addr.sin_addr)) {
478 case 1:
479 return my_connect(klass, PF_INET, &addr, sizeof(addr));
480 case -1:
481 rb_sys_fail("inet_pton");
483 rb_raise(rb_eArgError, "invalid address: %s",
484 StringValuePtr(ip));
485 return Qnil;
489 * call-seq:
491 * Kgio::UNIXSocket.new("/path/to/unix/socket") -> socket
493 * Creates a new Kgio::UNIXSocket object and initiates a
494 * non-blocking connection. The caller should select/poll
495 * on the socket for writability before attempting to write
496 * or optimistically attempt a write and handle Kgio::WaitWritable
497 * or Errno::EAGAIN.
499 * This is only intended as a convenience for testing,
500 * Kgio::Socket.new (along with a cached/memoized addr argument)
501 * is recommended for applications that repeatedly connect to
502 * the same backend servers.
504 static VALUE kgio_unix_connect(VALUE klass, VALUE path)
506 struct sockaddr_un addr = { 0 };
507 long len;
509 StringValue(path);
510 len = RSTRING_LEN(path);
511 if (sizeof(addr.sun_path) <= len)
512 rb_raise(rb_eArgError,
513 "too long unix socket path (max: %dbytes)",
514 (int)sizeof(addr.sun_path)-1);
516 memcpy(addr.sun_path, RSTRING_PTR(path), len);
517 addr.sun_family = AF_UNIX;
519 return my_connect(klass, PF_UNIX, &addr, sizeof(addr));
523 * call-seq:
525 * addr = Socket.pack_sockaddr_in(80, 'example.com')
526 * Kgio::Socket.new(addr) -> socket
528 * addr = Socket.pack_sockaddr_un("/tmp/unix.sock")
529 * Kgio::Socket.new(addr) -> socket
531 * Generic connect method for addr generated by Socket.pack_sockaddr_in
532 * or Socket.pack_sockaddr_un
534 static VALUE kgio_connect(VALUE klass, VALUE addr)
536 int domain;
537 socklen_t addrlen;
538 struct sockaddr *sockaddr;
540 if (TYPE(addr) == T_STRING) {
541 sockaddr = (struct sockaddr *)(RSTRING_PTR(addr));
542 addrlen = (socklen_t)RSTRING_LEN(addr);
543 } else {
544 rb_raise(rb_eTypeError, "invalid address");
546 switch (((struct sockaddr_in *)(sockaddr))->sin_family) {
547 case AF_UNIX: domain = PF_UNIX; break;
548 case AF_INET: domain = PF_INET; break;
549 #ifdef AF_INET6 /* IPv6 support incomplete */
550 case AF_INET6: domain = PF_INET6; break;
551 #endif /* AF_INET6 */
552 default:
553 rb_raise(rb_eArgError, "invalid address family");
556 return my_connect(klass, domain, sockaddr, addrlen);
/*
 * Extension entry point: builds the Kgio module, its configuration
 * singleton methods, the PipeMethods/SocketMethods mixins, and
 * Kgio::{Socket,UNIXServer,TCPServer,TCPSocket,UNIXSocket} as subclasses
 * of the corresponding classes from the "socket" library.
 */
void Init_kgio_ext(void)
{
	VALUE mKgio = rb_define_module("Kgio");
	VALUE mPipeMethods, mSocketMethods;
	VALUE cUNIXServer, cTCPServer, cUNIXSocket, cTCPSocket;

	/* the superclasses below come from the socket library */
	rb_require("socket");
	cSocket = rb_const_get(rb_cObject, rb_intern("Socket"));
	cSocket = rb_define_class_under(mKgio, "Socket", cSocket);

	/* kgio_addr given to UNIX-socket clients; also Kgio::LOCALHOST */
	localhost = rb_str_new2("127.0.0.1");
	rb_const_set(mKgio, rb_intern("LOCALHOST"), localhost);

	/*
	 * The kgio_read method will return this when waiting for
	 * a read is required.
	 */
	mKgio_WaitReadable = rb_define_module_under(mKgio, "WaitReadable");

	/*
	 * The kgio_write method will return this when waiting for
	 * a write is required.
	 */
	mKgio_WaitWritable = rb_define_module_under(mKgio, "WaitWritable");

	/* module-level configuration: wait hooks and accept4(2) flags */
	rb_define_singleton_method(mKgio, "wait_readable=", set_wait_rd, 1);
	rb_define_singleton_method(mKgio, "wait_writable=", set_wait_wr, 1);
	rb_define_singleton_method(mKgio, "wait_readable", wait_rd, 0);
	rb_define_singleton_method(mKgio, "wait_writable", wait_wr, 0);
	rb_define_singleton_method(mKgio, "accept_cloexec?", get_cloexec, 0);
	rb_define_singleton_method(mKgio, "accept_cloexec=", set_cloexec, 1);
	rb_define_singleton_method(mKgio, "accept_nonblock?", get_nonblock, 0);
	rb_define_singleton_method(mKgio, "accept_nonblock=", set_nonblock, 1);

	/* read(2)/write(2)-based I/O for pipes and other descriptors */
	mPipeMethods = rb_define_module_under(mKgio, "PipeMethods");
	rb_define_method(mPipeMethods, "kgio_read", kgio_read, -1);
	rb_define_method(mPipeMethods, "kgio_write", kgio_write, 1);

	/*
	 * socket I/O: kgio_recv/kgio_send, which are aliases of
	 * kgio_read/kgio_write when USE_MSG_DONTWAIT is not defined
	 */
	mSocketMethods = rb_define_module_under(mKgio, "SocketMethods");
	rb_define_method(mSocketMethods, "kgio_read", kgio_recv, -1);
	rb_define_method(mSocketMethods, "kgio_write", kgio_send, 1);

	/* kgio_addr / kgio_addr= accessors for the @kgio_addr ivar */
	rb_define_attr(mSocketMethods, "kgio_addr", 1, 1);
	rb_include_module(cSocket, mSocketMethods);
	/* Kgio::Socket.new(packed_sockaddr) starts a non-blocking connect */
	rb_define_singleton_method(cSocket, "new", kgio_connect, 1);

	/* servers gain kgio_accept, which returns Kgio::Socket or nil */
	cUNIXServer = rb_const_get(rb_cObject, rb_intern("UNIXServer"));
	cUNIXServer = rb_define_class_under(mKgio, "UNIXServer", cUNIXServer);
	rb_define_method(cUNIXServer, "kgio_accept", unix_accept, 0);

	cTCPServer = rb_const_get(rb_cObject, rb_intern("TCPServer"));
	cTCPServer = rb_define_class_under(mKgio, "TCPServer", cTCPServer);
	rb_define_method(cTCPServer, "kgio_accept", tcp_accept, 0);

	/* client classes whose .new starts a non-blocking connect */
	cTCPSocket = rb_const_get(rb_cObject, rb_intern("TCPSocket"));
	cTCPSocket = rb_define_class_under(mKgio, "TCPSocket", cTCPSocket);
	rb_include_module(cTCPSocket, mSocketMethods);
	rb_define_singleton_method(cTCPSocket, "new", kgio_tcp_connect, 2);

	cUNIXSocket = rb_const_get(rb_cObject, rb_intern("UNIXSocket"));
	cUNIXSocket = rb_define_class_under(mKgio, "UNIXSocket", cUNIXSocket);
	rb_include_module(cUNIXSocket, mSocketMethods);
	rb_define_singleton_method(cUNIXSocket, "new", kgio_unix_connect, 1);

	/* IDs used by the accept and wait paths */
	iv_kgio_addr = rb_intern("@kgio_addr");
	id_ruby = rb_intern("ruby");
	/* NOTE(review): presumably sets up sock_for_fd.h state -- not visible here */
	init_sock_for_fd();
}