added 2nd part (SATP) of acn presentation
[anytun.git] / src / anytun.cpp
blobfdeaead4afff1f3ed494c556ed6d95d784308a9e
1 /*
2 * anytun
4 * The secure anycast tunneling protocol (satp) defines a protocol used
5 * for communication between any combination of unicast and anycast
6 * tunnel endpoints. It has less protocol overhead than IPSec in Tunnel
7 * mode and allows tunneling of every ETHER TYPE protocol (e.g.
8 * ethernet, ip, arp ...). satp directly includes cryptography and
9 * message authentication based on the methods used by SRTP. It is
10 * intended to deliver a generic, scalable and secure solution for
11 * tunneling and relaying of packets of any protocol.
14 * Copyright (C) 2007 anytun.org <satp@wirdorange.org>
16 * This program is free software; you can redistribute it and/or modify
17 * it under the terms of the GNU General Public License version 2
18 * as published by the Free Software Foundation.
20 * This program is distributed in the hope that it will be useful,
21 * but WITHOUT ANY WARRANTY; without even the implied warranty of
22 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23 * GNU General Public License for more details.
25 * You should have received a copy of the GNU General Public License
26 * along with this program (see the file COPYING included with this
27 * distribution); if not, write to the Free Software Foundation, Inc.,
28 * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
31 #include <iostream>
32 #include <fstream>
33 #include <poll.h>
34 #include <fcntl.h>
35 #include <pwd.h>
36 #include <grp.h>
37 #include <sys/wait.h>
38 #include <sys/stat.h>
39 #include <unistd.h>
41 #include <pthread.h>
42 #include <gcrypt.h>
43 #include <cerrno> // for ENOMEM
45 #include "datatypes.h"
47 #include "log.h"
48 #include "buffer.h"
49 #include "plainPacket.h"
50 #include "encryptedPacket.h"
51 #include "cipher.h"
52 #include "keyDerivation.h"
53 #include "authAlgo.h"
54 #include "cipherFactory.h"
55 #include "authAlgoFactory.h"
56 #include "keyDerivationFactory.h"
57 #include "signalController.h"
58 #include "packetSource.h"
59 #include "tunDevice.h"
60 #include "options.h"
61 #include "seqWindow.h"
62 #include "connectionList.h"
63 #include "routingTable.h"
64 #include "networkAddress.h"
66 #include "syncQueue.h"
67 #include "syncCommand.h"
69 #ifndef ANYTUN_NOSYNC
70 #include "syncSocketHandler.h"
71 #include "syncListenSocket.h"
73 #include "syncSocket.h"
74 #include "syncClientSocket.h"
75 #endif
77 #include "threadParam.h"
// Size in bytes of the packet buffers allocated by sender() and receiver().
#define MAX_PACKET_LENGTH 1600

// Lengths (bytes) of the per-packet session keys produced by the key
// derivation in sender() and receiver().
// TODO: hardcoded sizes - these should come from the selected cipher /
//       auth algorithm instead.
#define SESSION_KEYLEN_AUTH 20  // message authentication key
#define SESSION_KEYLEN_ENCR 16  // encryption key
#define SESSION_KEYLEN_SALT 14  // salt
85 void createConnection(const std::string & remote_host, u_int16_t remote_port, ConnectionList & cl, u_int16_t seqSize, SyncQueue & queue, mux_t mux)
87 SeqWindow * seq= new SeqWindow(seqSize);
88 seq_nr_t seq_nr_=0;
89 KeyDerivation * kd = KeyDerivationFactory::create(gOpt.getKdPrf());
90 kd->init(gOpt.getKey(), gOpt.getSalt());
91 cLog.msg(Log::PRIO_NOTICE) << "added connection remote host " << remote_host << ":" << remote_port;
92 ConnectionParam connparam ( (*kd), (*seq), seq_nr_, remote_host, remote_port);
93 cl.addConnection(connparam,mux);
94 NetworkAddress addr(ipv4,gOpt.getIfconfigParamRemoteNetmask().c_str());
95 NetworkPrefix prefix(addr,32);
96 gRoutingTable.addRoute(prefix,mux);
97 SyncCommand sc (cl,mux);
98 queue.push(sc);
99 SyncCommand sc2 (prefix);
100 queue.push(sc2);
103 bool checkPacketSeqNr(EncryptedPacket& pack,ConnectionParam& conn)
105 // compare sender_id and seq with window
106 if(conn.seq_window_.hasSeqNr(pack.getSenderId(), pack.getSeqNr()))
108 cLog.msg(Log::PRIO_NOTICE) << "Replay attack from " << conn.remote_host_<<":"<< conn.remote_port_
109 << " seq:"<<pack.getSeqNr() << " sid: "<<pack.getSenderId();
110 return false;
113 conn.seq_window_.addSeqNr(pack.getSenderId(), pack.getSeqNr());
114 return true;
117 void* sender(void* p)
119 try
121 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
123 std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher()));
124 std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
126 PlainPacket plain_packet(MAX_PACKET_LENGTH);
127 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
129 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
130 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
131 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
133 //TODO replace mux
134 u_int16_t mux = gOpt.getMux();
135 while(1)
137 plain_packet.setLength(MAX_PACKET_LENGTH);
138 encrypted_packet.withAuthTag(false);
139 encrypted_packet.setLength(MAX_PACKET_LENGTH);
141 // read packet from device
142 u_int32_t len = param->dev.read(plain_packet.getPayload(), plain_packet.getPayloadLength());
143 plain_packet.setPayloadLength(len);
144 // set payload type
145 if(param->dev.getType() == TYPE_TUN)
146 plain_packet.setPayloadType(PAYLOAD_TYPE_TUN);
147 else if(param->dev.getType() == TYPE_TAP)
148 plain_packet.setPayloadType(PAYLOAD_TYPE_TAP);
149 else
150 plain_packet.setPayloadType(0);
152 if(param->cl.empty())
153 continue;
154 //std::cout << "got Packet for plain "<<plain_packet.getDstAddr().toString();
155 mux = gRoutingTable.getRoute(plain_packet.getDstAddr());
156 //std::cout << " -> "<<mux << std::endl;
157 ConnectionMap::iterator cit = param->cl.getConnection(mux);
158 if(cit==param->cl.getEnd())
159 continue;
160 ConnectionParam & conn = cit->second;
162 if(conn.remote_host_==""||!conn.remote_port_)
163 continue;
164 // generate packet-key TODO: do this only when needed
165 conn.kd_.generate(LABEL_SATP_ENCRYPTION, conn.seq_nr_, session_key);
166 conn.kd_.generate(LABEL_SATP_SALT, conn.seq_nr_, session_salt);
168 c->setKey(session_key);
169 c->setSalt(session_salt);
171 // encrypt packet
172 c->encrypt(plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux);
174 encrypted_packet.setHeader(conn.seq_nr_, gOpt.getSenderId(), mux);
175 conn.seq_nr_++;
177 // add authentication tag
178 if(a->getMaxLength()) {
179 encrypted_packet.addAuthTag();
180 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
181 a->setKey(session_auth_key);
182 a->generate(encrypted_packet);
186 param->src.send(encrypted_packet.getBuf(), encrypted_packet.getLength(), conn.remote_host_, conn.remote_port_);
188 catch (std::exception e)
190 // ignoring icmp port unreachable :) and other socket errors :(
194 catch(std::runtime_error e)
196 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught runtime_error: " << e.what();
198 catch(std::exception e)
200 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught exception: " << e.what();
202 pthread_exit(NULL);
205 #ifndef ANYTUN_NOSYNC
206 void* syncConnector(void* p )
208 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
210 SocketHandler h;
211 SyncClientSocket sock(h,param->cl);
212 // sock.EnableSSL();
213 sock.Open( param->connto.host, param->connto.port);
214 h.Add(&sock);
215 while (h.GetCount())
217 h.Select();
219 pthread_exit(NULL);
222 void* syncListener(void* p )
224 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
226 SyncSocketHandler h(param->queue);
227 SyncListenSocket<SyncSocket,ConnectionList> l(h,param->cl);
229 if (l.Bind(gOpt.getLocalSyncPort()))
230 pthread_exit(NULL);
232 Utility::ResolveLocal(); // resolve local hostname
233 h.Add(&l);
234 h.Select(1,0);
235 while (1) {
236 h.Select(1,0);
239 #endif
241 void* receiver(void* p)
245 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
247 std::auto_ptr<Cipher> c( CipherFactory::create(gOpt.getCipher()) );
248 std::auto_ptr<AuthAlgo> a( AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
250 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
251 PlainPacket plain_packet(MAX_PACKET_LENGTH);
253 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
254 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
255 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
257 while(1)
259 string remote_host;
260 u_int16_t remote_port;
262 plain_packet.setLength(MAX_PACKET_LENGTH);
263 encrypted_packet.withAuthTag(false);
264 encrypted_packet.setLength(MAX_PACKET_LENGTH);
266 // read packet from socket
267 u_int32_t len = param->src.recv(encrypted_packet.getBuf(), encrypted_packet.getLength(), remote_host, remote_port);
268 encrypted_packet.setLength(len);
270 mux_t mux = encrypted_packet.getMux();
271 // autodetect peer
272 if(gOpt.getRemoteAddr() == "" && param->cl.empty())
274 cLog.msg(Log::PRIO_NOTICE) << "autodetected remote host " << remote_host << ":" << remote_port;
275 createConnection(remote_host, remote_port, param->cl, gOpt.getSeqWindowSize(),param->queue,mux);
278 ConnectionMap::iterator cit = param->cl.getConnection(mux);
279 if (cit == param->cl.getEnd())
280 continue;
281 ConnectionParam & conn = cit->second;
283 // check whether auth tag is ok or not
284 if(a->getMaxLength()) {
285 encrypted_packet.withAuthTag(true);
286 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
287 a->setKey(session_auth_key);
288 if(!a->checkTag(encrypted_packet)) {
289 cLog.msg(Log::PRIO_NOTICE) << "wrong Authentication Tag!" << std::endl;
290 continue;
292 encrypted_packet.removeAuthTag();
295 //Allow dynamic IP changes
296 //TODO: add command line option to turn this off
297 if (remote_host != conn.remote_host_ || remote_port != conn.remote_port_)
299 cLog.msg(Log::PRIO_NOTICE) << "connection "<< mux << " autodetected remote host ip changed "
300 << remote_host << ":" << remote_port;
301 conn.remote_host_=remote_host;
302 conn.remote_port_=remote_port;
303 SyncCommand sc (param->cl,mux);
304 param->queue.push(sc);
307 // Replay Protection
308 if (!checkPacketSeqNr(encrypted_packet, conn))
309 continue;
311 // generate packet-key
312 conn.kd_.generate(LABEL_SATP_ENCRYPTION, encrypted_packet.getSeqNr(), session_key);
313 conn.kd_.generate(LABEL_SATP_SALT, encrypted_packet.getSeqNr(), session_salt);
314 c->setKey(session_key);
315 c->setSalt(session_salt);
317 // decrypt packet
318 c->decrypt(encrypted_packet, plain_packet);
320 // check payload_type
321 if((param->dev.getType() == TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN4 &&
322 plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN6) ||
323 (param->dev.getType() == TYPE_TAP && plain_packet.getPayloadType() != PAYLOAD_TYPE_TAP))
324 continue;
326 // write it on the device
327 param->dev.write(plain_packet.getPayload(), plain_packet.getLength());
330 catch(std::runtime_error e)
332 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught runtime_error: " << e.what();
334 catch(std::exception e)
336 cLog.msg(Log::PRIO_ERR) << "receiver thread died due to an uncaught exception: " << e.what();
338 pthread_exit(NULL);
341 #define MIN_GCRYPT_VERSION "1.2.0"
342 #if defined(__GNUC__) && !defined(__OpenBSD__) // TODO: thread-safety on OpenBSD
343 // make libgcrypt thread safe
344 extern "C" {
345 GCRY_THREAD_OPTION_PTHREAD_IMPL;
347 #endif
349 bool initLibGCrypt()
351 // make libgcrypt thread safe
352 // this must be called before any other libgcrypt call
353 #if defined(__GNUC__) && !defined(__OpenBSD__) // TODO: thread-safety on OpenBSD
354 gcry_control( GCRYCTL_SET_THREAD_CBS, &gcry_threads_pthread );
355 #endif
357 // this must be called right after the GCRYCTL_SET_THREAD_CBS command
358 // no other function must be called till now
359 if( !gcry_check_version( MIN_GCRYPT_VERSION ) ) {
360 std::cout << "initLibGCrypt: Invalid Version of libgcrypt, should be >= " << MIN_GCRYPT_VERSION << std::endl;
361 return false;
364 gcry_error_t err = gcry_control (GCRYCTL_DISABLE_SECMEM, 0);
365 if( err ) {
366 char buf[STERROR_TEXT_MAX];
367 buf[0] = 0;
368 std::cout << "initLibGCrypt: Failed to disable secure memory: " << gpg_strerror_r(err, buf, STERROR_TEXT_MAX) << std::endl;
369 return false;
372 // Tell Libgcrypt that initialization has completed.
373 err = gcry_control(GCRYCTL_INITIALIZATION_FINISHED);
374 if( err ) {
375 char buf[STERROR_TEXT_MAX];
376 buf[0] = 0;
377 std::cout << "initLibGCrypt: Failed to finish initialization: " << gpg_strerror_r(err, buf, STERROR_TEXT_MAX) << std::endl;
378 return false;
381 cLog.msg(Log::PRIO_NOTICE) << "initLibGCrypt: libgcrypt init finished";
382 return true;
385 void chrootAndDrop(std::string const& chrootdir, std::string const& username)
387 if (getuid() != 0)
389 std::cerr << "this programm has to be run as root in order to run in a chroot" << std::endl;
390 exit(-1);
393 struct passwd *pw = getpwnam(username.c_str());
394 if(pw) {
395 if(chroot(chrootdir.c_str()))
397 std::cerr << "can't chroot to " << chrootdir << std::endl;
398 exit(-1);
400 cLog.msg(Log::PRIO_NOTICE) << "we are in chroot jail (" << chrootdir << ") now" << std::endl;
401 chdir("/");
402 if (initgroups(pw->pw_name, pw->pw_gid) || setgid(pw->pw_gid) || setuid(pw->pw_uid))
404 std::cerr << "can't drop to user " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
405 exit(-1);
407 cLog.msg(Log::PRIO_NOTICE) << "dropped user to " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
409 else
411 std::cerr << "unknown user " << username << std::endl;
412 exit(-1);
/*
 * Detach from the controlling terminal and continue in the background.
 *
 * Forks twice (with setsid() in between); the intermediate processes exit
 * with 0. It then redirects stdin/stdout/stderr to /dev/null and sets the
 * umask to 027. If a fork fails, an error is printed to stderr and the
 * process exits with -1 instead of silently exiting with 0.
 */
void daemonize()
{
  pid_t pid = fork();
  if(pid < 0) {
    std::cerr << "daemonize: fork failed" << std::endl;
    exit(-1);
  }
  if(pid) exit(0);

  setsid();

  pid = fork();
  if(pid < 0) {
    std::cerr << "daemonize: fork failed" << std::endl;
    exit(-1);
  }
  if(pid) exit(0);

  // std::cout << "running in background now..." << std::endl;

  // replace stdin, stdout and stderr with /dev/null
  for (int std_fd = 0; std_fd <= 2; ++std_fd)
    close(std_fd);
  int fd = open("/dev/null", O_RDWR); // lowest free descriptor: 0 (stdin)
  if(fd >= 0) {
    dup(fd); // 1 (stdout)
    dup(fd); // 2 (stderr)
  }
  umask(027);
}
/*
 * Run `script` with /bin/sh, passing `ifname` as its only argument, and
 * wait for it to finish.
 *
 * In the child, all inherited file descriptors (e.g. the tun/tap device)
 * are closed and stdin/stdout/stderr point to /dev/null.
 *
 * Returns the raw wait status filled in by waitpid() (inspect it with
 * WIFEXITED/WEXITSTATUS), or -1 if fork() or waitpid() failed.
 */
int execScript(std::string const& script, std::string const& ifname)
{
  pid_t pid = fork();
  if(pid < 0)
    return -1;

  if(pid == 0) {
    for (int fd = getdtablesize(); fd >= 0; --fd) // close all file descriptors
      close(fd);
    int fd = open("/dev/null", O_RDWR); // stdin
    if(fd >= 0) {
      dup(fd); // stdout
      dup(fd); // stderr
    }
    // variadic sentinel must be a null *pointer*, not a plain 0/NULL int
    execl("/bin/sh", "/bin/sh", script.c_str(), ifname.c_str(), static_cast<char*>(NULL));
    // only reached if execl() failed: never return into the caller's code
    // from the forked child (that would run a second copy of the program)
    _exit(127);
  }

  int status = 0;
  while(waitpid(pid, &status, 0) < 0) {
    if(errno != EINTR)
      return -1;
  }
  return status;
}
456 int main(int argc, char* argv[])
458 bool daemonized=false;
459 try
462 // std::cout << "anytun - secure anycast tunneling protocol" << std::endl;
463 if(!gOpt.parse(argc, argv)) {
464 gOpt.printUsage();
465 exit(-1);
468 cLog.msg(Log::PRIO_NOTICE) << "anytun started...";
470 std::ofstream pidFile;
471 if(gOpt.getPidFile() != "") {
472 pidFile.open(gOpt.getPidFile().c_str());
473 if(!pidFile.is_open()) {
474 std::cout << "can't open pid file" << std::endl;
478 TunDevice dev(gOpt.getDevName() =="" ? NULL : gOpt.getDevName().c_str(),
479 gOpt.getDevType() =="" ? NULL : gOpt.getDevType().c_str(),
480 gOpt.getIfconfigParamLocal() =="" ? NULL : gOpt.getIfconfigParamLocal().c_str(),
481 gOpt.getIfconfigParamRemoteNetmask() =="" ? NULL : gOpt.getIfconfigParamRemoteNetmask().c_str());
482 cLog.msg(Log::PRIO_NOTICE) << "dev created (opened)";
483 cLog.msg(Log::PRIO_NOTICE) << "dev opened - actual name is '" << dev.getActualName() << "'";
484 cLog.msg(Log::PRIO_NOTICE) << "dev type is '" << dev.getTypeString() << "'";
485 if(gOpt.getPostUpScript() != "") {
486 int postup_ret = execScript(gOpt.getPostUpScript(), dev.getActualName());
487 cLog.msg(Log::PRIO_NOTICE) << "post up script '" << gOpt.getPostUpScript() << "' returned " << postup_ret;
491 // Buffer buff(u_int32_t(1600));
492 // int len;
493 // while(1)
494 // {
495 // len = dev.read(buff.getBuf(), buff.getLength());
496 // std::cout << "read " << len << " bytes from interface " << dev.getActualName() << std::endl;
497 // dev.write(buff.getBuf(), len);
498 // }
500 // return 0;
504 if(gOpt.getChroot())
505 chrootAndDrop(gOpt.getChrootDir(), gOpt.getUsername());
506 if(gOpt.getDaemonize())
507 daemonize();
508 daemonized = true;
510 if(pidFile.is_open()) {
511 pid_t pid = getpid();
512 pidFile << pid;
513 pidFile.close();
516 SignalController sig;
517 sig.init();
519 PacketSource* src;
520 if(gOpt.getLocalAddr() == "")
521 src = new UDPPacketSource(gOpt.getLocalPort());
522 else
523 src = new UDPPacketSource(gOpt.getLocalAddr(), gOpt.getLocalPort());
525 ConnectionList cl;
526 ConnectToList connect_to = gOpt.getConnectTo();
527 SyncQueue queue;
529 if(gOpt.getRemoteAddr() != "")
530 createConnection(gOpt.getRemoteAddr(),gOpt.getRemotePort(),cl,gOpt.getSeqWindowSize(), queue, gOpt.getMux());
532 ThreadParam p(dev, *src, cl, queue,*(new OptionConnectTo()));
534 // this must be called before any other libgcrypt call
535 if(!initLibGCrypt())
536 return -1;
538 pthread_t senderThread;
539 pthread_create(&senderThread, NULL, sender, &p);
540 pthread_t receiverThread;
541 pthread_create(&receiverThread, NULL, receiver, &p);
542 #ifndef ANYTUN_NOSYNC
543 pthread_t syncListenerThread;
544 if ( gOpt.getLocalSyncPort())
545 pthread_create(&syncListenerThread, NULL, syncListener, &p);
547 std::list<pthread_t> connectThreads;
548 for(ConnectToList::iterator it = connect_to.begin() ;it != connect_to.end(); ++it) {
549 connectThreads.push_back(pthread_t());
550 ThreadParam * point = new ThreadParam(dev, *src, cl, queue,*it);
551 pthread_create(& connectThreads.back(), NULL, syncConnector, point);
553 #endif
555 int ret = sig.run();
557 pthread_cancel(senderThread);
558 pthread_cancel(receiverThread);
559 #ifndef ANYTUN_NOSYNC
560 if ( gOpt.getLocalSyncPort())
561 pthread_cancel(syncListenerThread);
562 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
563 pthread_cancel(*it);
564 #endif
566 pthread_join(senderThread, NULL);
567 pthread_join(receiverThread, NULL);
568 #ifndef ANYTUN_NOSYNC
569 if ( gOpt.getLocalSyncPort())
570 pthread_join(syncListenerThread, NULL);
572 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
573 pthread_join(*it, NULL);
574 #endif
575 delete src;
576 delete &p.connto;
578 return ret;
580 catch(std::runtime_error e)
582 if(daemonized)
583 cLog.msg(Log::PRIO_ERR) << "uncaught runtime error, exiting: " << e.what();
584 else
585 std::cout << "uncaught runtime error, exiting: " << e.what() << std::endl;
587 catch(std::exception e)
589 if(daemonized)
590 cLog.msg(Log::PRIO_ERR) << "uncaught exception, exiting: " << e.what();
591 else
592 std::cout << "uncaught exception, exiting: " << e.what() << std::endl;