improve throughput in copyloop() using bigger buffer
[rofl0r-microsocks.git] / sockssrv.c
blobdbad9c2eb6d094a673fd32074a7d2ebe686c490d
/*
   MicroSocks - multithreaded, small, efficient SOCKS5 server.

   Copyright (C) 2017 rofl0r.

   This is the successor of "rocksocks5", and it was written with
   different goals in mind:

   - prefer usage of standard libc functions over homegrown ones
   - no artificial limits
   - do not aim for minimal binary size, but for minimal source code size,
     and maximal readability, reusability, and extensibility.

   as a result of that, ipv4, dns, and ipv6 are supported out of the box
   and can use the same code, while rocksocks5 has several compile time
   defines to bring down the size of the resulting binary to extreme values
   like 10 KB statically linked when only ipv4 support is enabled.

   still, if optimized for size, *this* program when statically linked against
   musl libc is not even 50 KB. that's easily usable even on the cheapest routers.
*/
24 #define _GNU_SOURCE
25 #include <unistd.h>
26 #define _POSIX_C_SOURCE 200809L
27 #include <stdlib.h>
28 #include <string.h>
29 #include <stdio.h>
30 #include <pthread.h>
31 #include <signal.h>
32 #include <poll.h>
33 #include <arpa/inet.h>
34 #include <errno.h>
35 #include <limits.h>
36 #include "server.h"
37 #include "sblist.h"
/* timeout in microseconds on resource exhaustion to prevent excessive
   cpu usage. */
#ifndef FAILURE_TIMEOUT
#define FAILURE_TIMEOUT 64
#endif

/* MAX and MIN are guarded separately: a system header may provide one
   without the other, and the copy buffer size is computed with MIN even
   when MAX is already defined. */
#ifndef MAX
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#endif
#ifndef MIN
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#endif

/* stack size requested for each client thread. every expansion is
   parenthesized so that uses such as THREAD_STACK_SIZE/2 or
   n % THREAD_STACK_SIZE group as a single operand. */
#ifdef PTHREAD_STACK_MIN
#define THREAD_STACK_SIZE MAX(16*1024, PTHREAD_STACK_MIN)
#else
#define THREAD_STACK_SIZE (64*1024)
#endif

#if defined(__APPLE__)
#undef THREAD_STACK_SIZE
#define THREAD_STACK_SIZE (64*1024)
#elif defined(__GLIBC__) || defined(__FreeBSD__) || defined(__sun__)
#undef THREAD_STACK_SIZE
#define THREAD_STACK_SIZE (32*1024)
#endif
64 static int quiet;
65 static const char* auth_user;
66 static const char* auth_pass;
67 static sblist* auth_ips;
68 static pthread_rwlock_t auth_ips_lock = PTHREAD_RWLOCK_INITIALIZER;
69 static const struct server* server;
70 static union sockaddr_union bind_addr = {.v4.sin_family = AF_UNSPEC};
/* how far a client has progressed through the negotiation in handshake() */
enum socksstate {
	SS_1_CONNECTED = 0,
	SS_2_NEED_AUTH = 1, /* skipped if NO_AUTH method supported */
	SS_3_AUTHED    = 2,
};

/* authentication method codes (RFC 1928, method selection) */
enum authmethod {
	AM_NO_AUTH  = 0x00,
	AM_GSSAPI   = 0x01,
	AM_USERNAME = 0x02,
	AM_INVALID  = 0xFF, /* "no acceptable methods" */
};

/* reply codes (RFC 1928 REP field); connect_socks_target() returns them negated */
enum errorcode {
	EC_SUCCESS                   = 0x00,
	EC_GENERAL_FAILURE           = 0x01,
	EC_NOT_ALLOWED               = 0x02,
	EC_NET_UNREACHABLE           = 0x03,
	EC_HOST_UNREACHABLE          = 0x04,
	EC_CONN_REFUSED              = 0x05,
	EC_TTL_EXPIRED               = 0x06,
	EC_COMMAND_NOT_SUPPORTED     = 0x07,
	EC_ADDRESSTYPE_NOT_SUPPORTED = 0x08,
};
97 struct thread {
98 pthread_t pt;
99 struct client client;
100 enum socksstate state;
101 volatile int done;
#ifndef CONFIG_LOG
#define CONFIG_LOG 1
#endif

#if CONFIG_LOG
/* messages go to stderr (fd 2) through dprintf(), which writes straight to
   the file descriptor. unlike buffered stdio this needs no malloc'd buffer,
   which would require locking when several client threads log at once.
   nothing is printed when -q set `quiet`. */
#define dolog(...) do { if(quiet) break; dprintf(2, __VA_ARGS__); } while(0)
#else
/* logging compiled out: accept the same arguments and do nothing */
static void dolog(const char* fmt, ...) { (void) fmt; }
#endif
116 static struct addrinfo* addr_choose(struct addrinfo* list, union sockaddr_union* bindaddr) {
117 int af = SOCKADDR_UNION_AF(bindaddr);
118 if(af == AF_UNSPEC) return list;
119 struct addrinfo* p;
120 for(p=list; p; p=p->ai_next)
121 if(p->ai_family == af) return p;
122 return list;
/* parse a SOCKS5 request (RFC 1928: VER CMD RSV ATYP DST.ADDR DST.PORT) held
   in buf[0..n) and open a TCP connection to the requested destination.
   only the CONNECT command is supported.
   returns the connected socket fd on success, or a negated enum errorcode
   (e.g. -EC_CONN_REFUSED) which the caller reports back to the client. */
static int connect_socks_target(unsigned char *buf, size_t n, struct client *client) {
	if(n < 5) return -EC_GENERAL_FAILURE;
	if(buf[0] != 5) return -EC_GENERAL_FAILURE;
	if(buf[1] != 1) return -EC_COMMAND_NOT_SUPPORTED; /* we support only CONNECT method */
	if(buf[2] != 0) return -EC_GENERAL_FAILURE; /* malformed packet */

	int af = AF_INET;
	/* minimum request length for the address type; the port is always the
	   last two bytes of that length. */
	size_t minlen = 4 + 4 + 2, l;
	char namebuf[256];
	struct addrinfo* remote;

	switch(buf[3]) {
		case 4: /* ipv6 */
			af = AF_INET6;
			minlen = 4 + 2 + 16;
			/* fall through */
		case 1: /* ipv4 */
			if(n < minlen) return -EC_GENERAL_FAILURE;
			if(namebuf != inet_ntop(af, buf+4, namebuf, sizeof namebuf))
				return -EC_GENERAL_FAILURE; /* malformed or too long addr */
			break;
		case 3: /* dns name: one length byte, then up to 255 name bytes */
			l = buf[4];
			minlen = 4 + 2 + l + 1;
			if(n < minlen) return -EC_GENERAL_FAILURE;
			memcpy(namebuf, buf+4+1, l);
			namebuf[l] = 0;
			break;
		default:
			return -EC_ADDRESSTYPE_NOT_SUPPORTED;
	}
	unsigned short port;
	port = (buf[minlen-2] << 8) | buf[minlen-1];
	/* there's no suitable errorcode in rfc1928 for dns lookup failure */
	if(resolve(namebuf, port, &remote)) return -EC_GENERAL_FAILURE;
	struct addrinfo* raddr = addr_choose(remote, &bind_addr);
	int err;
	int fd = socket(raddr->ai_family, SOCK_STREAM, 0);
	if(fd == -1)
		goto eval_errno;
	if(SOCKADDR_UNION_AF(&bind_addr) == raddr->ai_family &&
	   bindtoip(fd, &bind_addr) == -1)
		goto eval_errno;
	if(connect(fd, raddr->ai_addr, raddr->ai_addrlen) == -1)
		goto eval_errno;

	freeaddrinfo(remote);
	if(CONFIG_LOG) {
		char clientname[256];
		af = SOCKADDR_UNION_AF(&client->addr);
		void *ipdata = SOCKADDR_UNION_ADDRESS(&client->addr);
		inet_ntop(af, ipdata, clientname, sizeof clientname);
		dolog("client[%d] %s: connected to %s:%d\n", client->fd, clientname, namebuf, port);
	}
	return fd;

eval_errno:
	/* capture errno from the failed socket()/bindtoip()/connect() first:
	   close() and freeaddrinfo() below are allowed to overwrite it. */
	err = errno;
	if(fd != -1) close(fd);
	freeaddrinfo(remote);
	switch(err) {
		case ETIMEDOUT:
			return -EC_TTL_EXPIRED;
		case EPROTOTYPE:
		case EPROTONOSUPPORT:
		case EAFNOSUPPORT:
			return -EC_ADDRESSTYPE_NOT_SUPPORTED;
		case ECONNREFUSED:
			return -EC_CONN_REFUSED;
		case ENETDOWN:
		case ENETUNREACH:
			return -EC_NET_UNREACHABLE;
		case EHOSTUNREACH:
			return -EC_HOST_UNREACHABLE;
		case EBADF:
		default:
			errno = err; /* make perror() describe the original failure */
			perror("socket/connect");
			return -EC_GENERAL_FAILURE;
	}
}
/* return 1 if `client` has the same address family and IP address as the
   whitelist entry `authedip` (4 bytes compared for AF_INET, 16 otherwise),
   else 0. ports are not compared. */
static int is_authed(union sockaddr_union *client, union sockaddr_union *authedip) {
	int family = SOCKADDR_UNION_AF(authedip);
	if(family != SOCKADDR_UNION_AF(client))
		return 0;
	size_t addrlen = (family == AF_INET) ? 4 : 16;
	return memcmp(SOCKADDR_UNION_ADDRESS(client),
	              SOCKADDR_UNION_ADDRESS(authedip), addrlen) == 0;
}
214 static int is_in_authed_list(union sockaddr_union *caddr) {
215 size_t i;
216 for(i=0;i<sblist_getsize(auth_ips);i++)
217 if(is_authed(caddr, sblist_get(auth_ips, i)))
218 return 1;
219 return 0;
222 static void add_auth_ip(union sockaddr_union *caddr) {
223 sblist_add(auth_ips, caddr);
226 static enum authmethod check_auth_method(unsigned char *buf, size_t n, struct client*client) {
227 if(buf[0] != 5) return AM_INVALID;
228 size_t idx = 1;
229 if(idx >= n ) return AM_INVALID;
230 int n_methods = buf[idx];
231 idx++;
232 while(idx < n && n_methods > 0) {
233 if(buf[idx] == AM_NO_AUTH) {
234 if(!auth_user) return AM_NO_AUTH;
235 else if(auth_ips) {
236 int authed = 0;
237 if(pthread_rwlock_rdlock(&auth_ips_lock) == 0) {
238 authed = is_in_authed_list(&client->addr);
239 pthread_rwlock_unlock(&auth_ips_lock);
241 if(authed) return AM_NO_AUTH;
243 } else if(buf[idx] == AM_USERNAME) {
244 if(auth_user) return AM_USERNAME;
246 idx++;
247 n_methods--;
249 return AM_INVALID;
252 static void send_auth_response(int fd, int version, enum authmethod meth) {
253 unsigned char buf[2];
254 buf[0] = version;
255 buf[1] = meth;
256 write(fd, buf, 2);
259 static void send_error(int fd, enum errorcode ec) {
260 /* position 4 contains ATYP, the address type, which is the same as used in the connect
261 request. we're lazy and return always IPV4 address type in errors. */
262 char buf[10] = { 5, ec, 0, 1 /*AT_IPV4*/, 0,0,0,0, 0,0 };
263 write(fd, buf, 10);
266 static void copyloop(int fd1, int fd2) {
267 struct pollfd fds[2] = {
268 [0] = {.fd = fd1, .events = POLLIN},
269 [1] = {.fd = fd2, .events = POLLIN},
272 while(1) {
273 /* inactive connections are reaped after 15 min to free resources.
274 usually programs send keep-alive packets so this should only happen
275 when a connection is really unused. */
276 switch(poll(fds, 2, 60*15*1000)) {
277 case 0:
278 return;
279 case -1:
280 if(errno == EINTR || errno == EAGAIN) continue;
281 else perror("poll");
282 return;
284 int infd = (fds[0].revents & POLLIN) ? fd1 : fd2;
285 int outfd = infd == fd2 ? fd1 : fd2;
286 /* since the biggest stack consumer in the entire code is
287 libc's getaddrinfo(), we can safely use at least half the
288 available stacksize to improve throughput. */
289 char buf[MIN(16*1024, THREAD_STACK_SIZE/2)];
290 ssize_t sent = 0, n = read(infd, buf, sizeof buf);
291 if(n <= 0) return;
292 while(sent < n) {
293 ssize_t m = write(outfd, buf+sent, n-sent);
294 if(m < 0) return;
295 sent += m;
300 static enum errorcode check_credentials(unsigned char* buf, size_t n) {
301 if(n < 5) return EC_GENERAL_FAILURE;
302 if(buf[0] != 1) return EC_GENERAL_FAILURE;
303 unsigned ulen, plen;
304 ulen=buf[1];
305 if(n < 2 + ulen + 2) return EC_GENERAL_FAILURE;
306 plen=buf[2+ulen];
307 if(n < 2 + ulen + 1 + plen) return EC_GENERAL_FAILURE;
308 char user[256], pass[256];
309 memcpy(user, buf+2, ulen);
310 memcpy(pass, buf+2+ulen+1, plen);
311 user[ulen] = 0;
312 pass[plen] = 0;
313 if(!strcmp(user, auth_user) && !strcmp(pass, auth_pass)) return EC_SUCCESS;
314 return EC_NOT_ALLOWED;
317 static int handshake(struct thread *t) {
318 unsigned char buf[1024];
319 ssize_t n;
320 int ret;
321 enum authmethod am;
322 t->state = SS_1_CONNECTED;
323 while((n = recv(t->client.fd, buf, sizeof buf, 0)) > 0) {
324 switch(t->state) {
325 case SS_1_CONNECTED:
326 am = check_auth_method(buf, n, &t->client);
327 if(am == AM_NO_AUTH) t->state = SS_3_AUTHED;
328 else if (am == AM_USERNAME) t->state = SS_2_NEED_AUTH;
329 send_auth_response(t->client.fd, 5, am);
330 if(am == AM_INVALID) return -1;
331 break;
332 case SS_2_NEED_AUTH:
333 ret = check_credentials(buf, n);
334 send_auth_response(t->client.fd, 1, ret);
335 if(ret != EC_SUCCESS)
336 return -1;
337 t->state = SS_3_AUTHED;
338 if(auth_ips && !pthread_rwlock_wrlock(&auth_ips_lock)) {
339 if(!is_in_authed_list(&t->client.addr))
340 add_auth_ip(&t->client.addr);
341 pthread_rwlock_unlock(&auth_ips_lock);
343 break;
344 case SS_3_AUTHED:
345 ret = connect_socks_target(buf, n, &t->client);
346 if(ret < 0) {
347 send_error(t->client.fd, ret*-1);
348 return -1;
350 send_error(t->client.fd, EC_SUCCESS);
351 return ret;
354 return -1;
357 static void* clientthread(void *data) {
358 struct thread *t = data;
359 int remotefd = handshake(t);
360 if(remotefd != -1) {
361 copyloop(t->client.fd, remotefd);
362 close(remotefd);
364 close(t->client.fd);
365 t->done = 1;
366 return 0;
369 static void collect(sblist *threads) {
370 size_t i;
371 for(i=0;i<sblist_getsize(threads);) {
372 struct thread* thread = *((struct thread**)sblist_get(threads, i));
373 if(thread->done) {
374 pthread_join(thread->pt, 0);
375 sblist_delete(threads, i);
376 free(thread);
377 } else
378 i++;
/* print command-line help to stderr (fd 2).
   always returns 1, so main() can `return usage();` as its exit status.
   the -w example previously listed the invalid address 192.168.1.1.1, which
   fails to resolve if copied verbatim. */
static int usage(void) {
	dprintf(2,
		"MicroSocks SOCKS5 Server\n"
		"------------------------\n"
		"usage: microsocks -1 -q -i listenip -p port -u user -P pass -b bindaddr -w ips\n"
		"all arguments are optional.\n"
		"by default listenip is 0.0.0.0 and port 1080.\n\n"
		"option -q disables logging.\n"
		"option -b specifies which ip outgoing connections are bound to\n"
		"option -w allows to specify a comma-separated whitelist of ip addresses,\n"
		" that may use the proxy without user/pass authentication.\n"
		" e.g. -w 127.0.0.1,192.168.1.1,::1 or just -w 10.0.0.1\n"
		" to allow access ONLY to those ips, choose an impossible to guess user/pw combo.\n"
		"option -1 activates auth_once mode: once a specific ip address\n"
		" authed successfully with user/pass, it is added to a whitelist\n"
		" and may use the proxy without auth.\n"
		" this is handy for programs like firefox that don't support\n"
		" user/pass auth. for it to work you'd basically make one connection\n"
		" with another program that supports it, and then you can use firefox too.\n"
	);
	return 1;
}
/* overwrite the string s in place with NUL bytes. main() calls this on the
   -u/-P arguments after copying them, intending to keep the username and
   password out of process listings such as top. */
static void zero_arg(char *s) {
	memset(s, 0, strlen(s));
}
411 int main(int argc, char** argv) {
412 int ch;
413 const char *listenip = "0.0.0.0";
414 char *p, *q;
415 unsigned port = 1080;
416 while((ch = getopt(argc, argv, ":1qb:i:p:u:P:w:")) != -1) {
417 switch(ch) {
418 case 'w': /* fall-through */
419 case '1':
420 if(!auth_ips)
421 auth_ips = sblist_new(sizeof(union sockaddr_union), 8);
422 if(ch == '1') break;
423 p = optarg;
424 while(1) {
425 union sockaddr_union ca;
426 if((q = strchr(p, ','))) *q = 0;
427 if(resolve_sa(p, 0, &ca)) {
428 dprintf(2, "error: failed to resolve %s\n", p);
429 return 1;
431 add_auth_ip(&ca);
432 if(q) *(q++) = ',', p = q;
433 else break;
435 break;
436 case 'q':
437 quiet = 1;
438 break;
439 case 'b':
440 resolve_sa(optarg, 0, &bind_addr);
441 break;
442 case 'u':
443 auth_user = strdup(optarg);
444 zero_arg(optarg);
445 break;
446 case 'P':
447 auth_pass = strdup(optarg);
448 zero_arg(optarg);
449 break;
450 case 'i':
451 listenip = optarg;
452 break;
453 case 'p':
454 port = atoi(optarg);
455 break;
456 case ':':
457 dprintf(2, "error: option -%c requires an operand\n", optopt);
458 /* fall through */
459 case '?':
460 return usage();
463 if((auth_user && !auth_pass) || (!auth_user && auth_pass)) {
464 dprintf(2, "error: user and pass must be used together\n");
465 return 1;
467 if(auth_ips && !auth_pass) {
468 dprintf(2, "error: -1/-w options must be used together with user/pass\n");
469 return 1;
471 signal(SIGPIPE, SIG_IGN);
472 struct server s;
473 sblist *threads = sblist_new(sizeof (struct thread*), 8);
474 if(server_setup(&s, listenip, port)) {
475 perror("server_setup");
476 return 1;
478 server = &s;
480 while(1) {
481 collect(threads);
482 struct client c;
483 struct thread *curr = malloc(sizeof (struct thread));
484 if(!curr) goto oom;
485 curr->done = 0;
486 if(server_waitclient(&s, &c)) {
487 dolog("failed to accept connection\n");
488 free(curr);
489 usleep(FAILURE_TIMEOUT);
490 continue;
492 curr->client = c;
493 if(!sblist_add(threads, &curr)) {
494 close(curr->client.fd);
495 free(curr);
496 oom:
497 dolog("rejecting connection due to OOM\n");
498 usleep(FAILURE_TIMEOUT); /* prevent 100% CPU usage in OOM situation */
499 continue;
501 pthread_attr_t *a = 0, attr;
502 if(pthread_attr_init(&attr) == 0) {
503 a = &attr;
504 pthread_attr_setstacksize(a, THREAD_STACK_SIZE);
506 if(pthread_create(&curr->pt, a, clientthread, curr) != 0)
507 dolog("pthread_create failed. OOM?\n");
508 if(a) pthread_attr_destroy(&attr);