connwrap - initialize gnutls session in cw_connect
[centerim.git] / connwrap / connwrap.c
blob2da9acc4c849f38a90ba86e332994dd08d1c511e
1 #include "connwrap.h"
3 #include <config.h>
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <netdb.h>
9 #include <netinet/in.h>
10 #include <string.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <sys/time.h>
14 #include <unistd.h>
/* Timeout, in seconds, for the HTTP proxy to answer a CONNECT request
 * (each select() wait while reading the proxy's reply uses this value). */
#define PROXY_TIMEOUT 10
18 #ifdef HAVE_OPENSSL
20 #define OPENSSL_NO_KRB5 1
21 #include <openssl/ssl.h>
22 #include <openssl/err.h>
24 #elif HAVE_GNUTLS
26 #include <gnutls/gnutls.h>
28 #elif HAVE_NSS_COMPAT
30 #include <nss_compat_ossl/nss_compat_ossl.h>
32 #endif
/* Non-zero while cw_http_connect() runs, so that its nested cw_connect()
 * call connects straight to the proxy instead of re-entering the proxy path. */
static int in_http_connect = 0;

#if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)

/* Client SSL context shared by every TLS socket; created lazily by addsock(). */
static SSL_CTX *ctx = NULL;

/* One entry of the TLS socket table: a descriptor and its SSL object. */
typedef struct {
    int fd;
    SSL *ssl;
} sslsock;

#elif HAVE_GNUTLS

/* One entry of the TLS socket table: a descriptor and its GnuTLS session. */
typedef struct {
    int fd;
    gnutls_session_t session;
} sslsock;

#endif
47 #ifdef HAVE_SSL
49 static sslsock *socks = 0;
50 static int sockcount = 0;
52 static sslsock *getsock(int fd) {
53 int i;
55 for(i = 0; i < sockcount; i++)
56 if(socks[i].fd == fd)
57 return &socks[i];
59 return 0;
62 static sslsock *addsock(int fd) {
63 sslsock *p;
65 #ifdef HAVE_GNUTLS
66 gnutls_certificate_credentials_t xcred;
67 #endif
69 if (socks)
70 socks = (sslsock *) realloc(socks, sizeof(sslsock)*++sockcount);
71 else
72 socks = (sslsock *) malloc(sizeof(sslsock)*++sockcount);
74 p = &socks[sockcount-1];
76 #if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
77 if(!ctx) {
78 SSL_library_init();
80 SSL_load_error_strings();
82 #ifdef HAVE_SSLEAY
83 SSLeay_add_all_algorithms();
84 #elif HAVE_OPENSSL
85 OpenSSL_add_all_algorithms();
86 #endif
87 ctx = SSL_CTX_new(SSLv23_client_method());
89 p->ssl = SSL_new(ctx);
90 SSL_set_fd(p->ssl, p->fd = fd);
91 #elif HAVE_GNUTLS
92 gnutls_global_init ();
93 gnutls_certificate_allocate_credentials (&xcred);
94 gnutls_init (&(p->session), GNUTLS_CLIENT);
95 gnutls_set_default_priority (p->session);
96 gnutls_credentials_set (p->session, GNUTLS_CRD_CERTIFICATE, xcred);
97 p->fd = fd;
98 gnutls_transport_set_ptr(p->session,(gnutls_transport_ptr_t)fd);
99 #endif
101 return p;
104 static void delsock(int fd) {
105 int i, nsockcount;
106 sslsock *nsocks;
108 if (sockcount > 0)
110 nsockcount = 0;
112 for(i = 0; i < sockcount; i++) {
113 if(socks[i].fd != fd) {
114 socks[nsockcount++] = socks[i];
116 #if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
117 else {
118 SSL_free(socks[i].ssl);
120 #elif HAVE_GNUTLS
121 else {
122 gnutls_bye( socks[i].session, GNUTLS_SHUT_WR);
123 gnutls_deinit(socks[i].session);
125 #endif
128 socks = realloc(socks, sizeof(sslsock)*(nsockcount));
129 sockcount = nsockcount;
133 #endif
/* Connection settings, configured through cw_setbind() and cw_setproxy(). */
static char *bindaddr = NULL;   /* local IPv4 address to bind before connecting */
static char *proxyhost = NULL;  /* HTTP proxy host; NULL means no proxy */
static char *proxyuser = NULL;  /* proxy user name; NULL means no authentication */
static char *proxypass = NULL;  /* proxy password; may be NULL */
static int proxyport = 3128;    /* HTTP proxy port */
static int proxy_ssl = 0;       /* passed as the ssl flag when connecting to the proxy */

/* Write the NUL-terminated string s to the descriptor named `sockfd`
 * in the caller's scope. */
#define SOCKOUT(s) write(sockfd, (s), strlen(s))
141 int cw_http_connect(int sockfd, const struct sockaddr *serv_addr, int addrlen) {
142 int err, pos, fl;
143 struct hostent *server;
144 struct sockaddr_in paddr;
145 char buf[512];
146 fd_set rfds;
148 err = 0;
149 in_http_connect = 1;
151 if(!(server = gethostbyname(proxyhost))) {
152 errno = h_errno;
153 err = -1;
156 if(!err) {
157 memset(&paddr, 0, sizeof(paddr));
158 paddr.sin_family = AF_INET;
159 memcpy(&paddr.sin_addr.s_addr, *server->h_addr_list, server->h_length);
160 paddr.sin_port = htons(proxyport);
162 fl = fcntl(sockfd, F_GETFL);
163 fcntl(sockfd, F_SETFL, fl & ~O_NONBLOCK);
165 buf[0] = 0;
167 err = cw_connect(sockfd, (struct sockaddr *) &paddr, sizeof(paddr), proxy_ssl);
170 errno = ECONNREFUSED;
172 if(!err) {
173 struct sockaddr_in *sin = (struct sockaddr_in *) serv_addr;
174 char *ip = inet_ntoa(sin->sin_addr), c;
175 struct timeval tv;
177 sprintf(buf, "%d", ntohs(sin->sin_port));
178 SOCKOUT("CONNECT ");
179 SOCKOUT(ip);
180 SOCKOUT(":");
181 SOCKOUT(buf);
182 SOCKOUT(" HTTP/1.0\r\n");
184 if(proxyuser) {
185 char *b;
186 SOCKOUT("Proxy-Authorization: Basic ");
188 sprintf(buf, "%s:%s", proxyuser, proxypass);
189 b = cw_base64_encode(buf);
190 SOCKOUT(b);
191 free(b);
193 SOCKOUT("\r\n");
196 SOCKOUT("\r\n");
198 buf[0] = 0;
200 while(err != -1) {
201 FD_ZERO(&rfds);
202 FD_SET(sockfd, &rfds);
204 tv.tv_sec = PROXY_TIMEOUT;
205 tv.tv_usec = 0;
207 err = select(sockfd+1, &rfds, 0, 0, &tv);
209 if(err < 1) err = -1;
211 if(err != -1 && FD_ISSET(sockfd, &rfds)) {
212 err = read(sockfd, &c, 1);
213 if(!err) err = -1;
215 if(err != -1) {
216 pos = strlen(buf);
217 buf[pos] = c;
218 buf[pos+1] = 0;
220 if(strlen(buf) > 4)
221 if(!strcmp(buf+strlen(buf)-4, "\r\n\r\n"))
222 break;
228 if(err != -1 && strlen(buf)) {
229 char *p = strstr(buf, " ");
231 err = -1;
233 if(p)
234 if(atoi(++p) == 200)
235 err = 0;
237 fcntl(sockfd, F_SETFL, fl);
238 if(fl & O_NONBLOCK) {
239 errno = EINPROGRESS;
240 err = -1;
244 in_http_connect = 0;
246 return err;
249 int cw_connect(int sockfd, const struct sockaddr *serv_addr, int addrlen, int ssl) {
250 int rc;
251 struct sockaddr_in ba;
253 if(bindaddr)
254 if(strlen(bindaddr)) {
255 #ifdef HAVE_INET_ATON
256 struct in_addr addr;
257 rc = inet_aton(bindaddr, &addr);
258 ba.sin_addr.s_addr = addr.s_addr;
259 #else
260 rc = inet_pton(AF_INET, bindaddr, &ba);
261 #endif
263 if(rc) {
264 ba.sin_port = 0;
265 rc = bind(sockfd, (struct sockaddr *) &ba, sizeof(ba));
266 if ((rc==-1) && (errno == EINVAL))
267 rc=0;
268 } else {
269 rc = -1;
272 if(rc) return rc;
275 if(proxyhost && !in_http_connect) {
276 rc = cw_http_connect(sockfd, serv_addr, addrlen);
278 else{
279 rc = connect(sockfd, serv_addr, addrlen);
282 #ifdef HAVE_SSL
283 if(ssl && !rc) {
284 sslsock *p = addsock(sockfd);
285 #if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
286 if(SSL_connect(p->ssl) != 1){
287 //fprintf(stderr, "cw_connect(%d) - cannot connect to SSL\n", sockfd);
288 return -1;
290 #elif defined(HAVE_GNUTLS)
291 int ret;
292 do {
293 ret = gnutls_handshake(p->session);
294 } while ((ret == GNUTLS_E_AGAIN) || (ret == GNUTLS_E_INTERRUPTED));
295 if (ret < 0) {
296 gnutls_perror (ret);
297 return -1;
299 #endif
301 #endif
302 return rc;
305 int cw_nb_connect(int sockfd, const struct sockaddr *serv_addr, int addrlen, int ssl, int *state) {
306 int rc = 0;
307 int ret;
308 struct sockaddr_in ba;
310 if(bindaddr)
311 if(strlen(bindaddr)) {
312 #ifdef HAVE_INET_ATON
313 struct in_addr addr;
314 rc = inet_aton(bindaddr, &addr);
315 ba.sin_addr.s_addr = addr.s_addr;
316 #else
317 rc = inet_pton(AF_INET, bindaddr, &ba);
318 #endif
320 if(rc) {
321 ba.sin_port = 0;
322 rc = bind(sockfd, (struct sockaddr *) &ba, sizeof(ba));
323 if ((rc==-1) && (errno == EINVAL))
324 rc=0;
325 } else {
326 rc = -1;
329 if(rc) return rc;
332 #ifdef HAVE_SSL
333 if(ssl) {
334 if ( !(*state & CW_CONNECT_WANT_SOMETHING)){
335 rc = cw_connect(sockfd, serv_addr, addrlen, 0);
337 else{ /* check if the socket is connected correctly */
338 int optlen = sizeof(int), optval;
339 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, &optlen) || optval){
341 /* Look for a better solution to print errors */
342 //fprintf(stderr,"cw_nb_connect(%d): getsockopt error!!\n", sockfd);
343 return -1;
347 if(!rc) {
348 sslsock *p;
349 if (*state & CW_CONNECT_SSL)
350 p = getsock(sockfd);
351 else
352 p = addsock(sockfd);
354 #ifdef HAVE_GNUTLS
356 ret = gnutls_handshake(p->session);
357 }while ((ret == GNUTLS_E_AGAIN) || (ret == GNUTLS_E_INTERRUPTED));
358 if (ret < 0) {
359 /* gnutls_deinit(p->session);
360 will be dealt with in delsock()
362 gnutls_perror (ret);
363 return -1;
365 else{
366 *state = 1;
367 return 0;
370 #elif defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
371 rc = SSL_connect(p->ssl);
372 switch(rc){
373 case 1:
374 *state = 0;
375 return 0;
376 case 0:
377 return -1;
378 default:
379 switch (SSL_get_error(p->ssl, rc)){
380 case SSL_ERROR_WANT_READ:
381 *state = CW_CONNECT_SSL | CW_CONNECT_WANT_READ;
382 return 0;
383 case SSL_ERROR_WANT_WRITE:
384 *state = CW_CONNECT_SSL | CW_CONNECT_WANT_WRITE;
385 return 0;
386 default:
387 return -1;
391 #endif
392 else{ // catch EINPROGRESS error from the connect call
393 if (errno == EINPROGRESS){
394 *state = CW_CONNECT_STARTED | CW_CONNECT_WANT_WRITE;
395 return 0;
398 return rc;
400 #endif
402 if ( !(*state & CW_CONNECT_WANT_SOMETHING)){
403 rc = connect(sockfd, serv_addr, addrlen);
405 else{ /* check if the socket is connected correctly */
406 int optlen = sizeof(int), optval;
407 if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, &optlen) || optval) {
408 //fprintf(stderr,"getsockopt error!!");
409 return -1;
411 *state = 0;
412 return 0;
414 if (rc)
415 if (errno == EINPROGRESS){
416 *state = CW_CONNECT_STARTED | CW_CONNECT_WANT_WRITE;
417 return 0;
419 return rc;
/*
 * Accept a connection on listening socket s.
 *
 * With OpenSSL/NSS and ssl non-zero, the accepted descriptor is registered
 * in the TLS table and a server-side handshake is run on it. If the
 * handshake fails, the new descriptor is released and closed.
 *
 * Returns the accepted descriptor, or -1 on failure.
 * addrlen is forwarded to accept() as its socklen_t* argument.
 */
int cw_accept(int s, struct sockaddr *addr, int *addrlen, int ssl) {
#if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
    if(ssl) {
        sslsock *p;
        int fd = accept(s, addr, (socklen_t *) addrlen);

        if(fd < 0)
            return fd;

        /* TLS belongs on the accepted descriptor, not the listening one. */
        p = addsock(fd);
        if(!p || SSL_accept(p->ssl) != 1) {
            delsock(fd);
            close(fd);
            return -1;
        }
        return fd;
    }
#endif
    return accept(s, addr, (socklen_t *) addrlen);
}
/*
 * Write count bytes from buf to fd.
 *
 * When ssl is non-zero and fd has a registered TLS session, the data goes
 * through TLS. Otherwise it is a plain write(). With GnuTLS, errors come
 * back as negative GNUTLS_E_* codes rather than -1/errno.
 */
int cw_write(int fd, const void *buf, int count, int ssl) {
#ifdef HAVE_SSL
    if(ssl) {
        sslsock *p = getsock(fd);

        if(p) {
#if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
            return SSL_write(p->ssl, buf, count);
#elif defined(HAVE_GNUTLS)
            int ret;

            /* An interrupted send must be repeated with the same arguments. */
            do {
                ret = gnutls_record_send(p->session, buf, count);
            } while(ret == GNUTLS_E_INTERRUPTED);
            return ret;
#endif
        }
    }
#endif
    return write(fd, buf, count);
}
/*
 * Read up to count bytes from fd into buf.
 *
 * When ssl is non-zero and fd has a registered TLS session, the data comes
 * from TLS. Otherwise it is a plain read(). With GnuTLS the call retries on
 * GNUTLS_E_INTERRUPTED and GNUTLS_E_AGAIN, and other errors come back as
 * negative GNUTLS_E_* codes.
 */
int cw_read(int fd, void *buf, int count, int ssl) {
#ifdef HAVE_SSL
    if(ssl) {
        sslsock *p = getsock(fd);

        if(p) {
#if defined(HAVE_OPENSSL) || defined(HAVE_NSS_COMPAT)
            return SSL_read(p->ssl, buf, count);
#elif defined(HAVE_GNUTLS)
            int ret;

            do {
                ret = gnutls_record_recv(p->session, buf, count);
            } while(ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
            return ret;
#endif
        }
    }
#endif
    return read(fd, buf, count);
}
/*
 * Close fd. Under HAVE_SSL, fd is first dropped from the TLS socket table
 * (releasing its TLS object, if any). Returns the result of close().
 */
int cw_close(int fd) {
    int result;

#ifdef HAVE_SSL
    delsock(fd);
#endif
    result = close(fd);
    return result;
}
489 #define FREEVAR(v) if(v) free(v), v = 0;
491 void cw_setbind(const char *abindaddr) {
492 FREEVAR(bindaddr);
493 bindaddr = strdup(abindaddr);
496 void cw_setproxy(const char *aproxyhost, int aproxyport, const char *aproxyuser, const char *aproxypass) {
497 FREEVAR(proxyhost);
498 FREEVAR(proxyuser);
499 FREEVAR(proxypass);
501 if(aproxyhost && strlen(aproxyhost)) proxyhost = strdup(aproxyhost);
502 if(aproxyuser && strlen(aproxyuser)) proxyuser = strdup(aproxyuser);
503 if(aproxypass && strlen(aproxypass)) proxypass = strdup(aproxypass);
504 proxyport = aproxyport;
/*
 * Base64-encode the NUL-terminated string in.
 *
 * NOTE: this is a non-standard variant. Digits 62/63 are '.' and '_'
 * (not '+' and '/'), and padding uses '-' (not '='). The variant is kept
 * for existing callers; it is not RFC 4648 base64.
 *
 * Input bytes are treated as unsigned, so bytes >= 0x80 encode correctly.
 * Returns a malloc()ed string the caller must free(), or NULL if the
 * allocation fails.
 */
char *cw_base64_encode(const char *in) {
    static const char base64digits[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._";
    const unsigned char *src = (const unsigned char *) in;
    size_t inlen = strlen(in);
    char *out = malloc((inlen + 2) / 3 * 4 + 1);
    char *dst = out;

    if(!out)
        return NULL;

    for(; inlen >= 3; inlen -= 3, src += 3) {
        *dst++ = base64digits[src[0] >> 2];
        *dst++ = base64digits[((src[0] << 4) & 0x30) | (src[1] >> 4)];
        *dst++ = base64digits[((src[1] << 2) & 0x3c) | (src[2] >> 6)];
        *dst++ = base64digits[src[2] & 0x3f];
    }

    if(inlen > 0) {
        unsigned char fragment = (src[0] << 4) & 0x30;

        if(inlen > 1)
            fragment |= src[1] >> 4;

        *dst++ = base64digits[src[0] >> 2];
        *dst++ = base64digits[fragment];
        *dst++ = (inlen < 2) ? '-' : base64digits[(src[1] << 2) & 0x3c];
        *dst++ = '-';
    }

    *dst = '\0';
    return out;
}