options parser support ipv6 now
[anytun.git] / src / anytun.cpp
blob258d98ad877f60336cc97c54efa99a80542236ee
1 /*
2 * anytun
4 * The secure anycast tunneling protocol (satp) defines a protocol used
5 * for communication between any combination of unicast and anycast
6 * tunnel endpoints. It has less protocol overhead than IPSec in Tunnel
7 * mode and allows tunneling of every ETHER TYPE protocol (e.g.
8 * ethernet, ip, arp ...). satp directly includes cryptography and
9 * message authentication based on the methods used by SRTP. It is
10 * intended to deliver a generic, scalable and secure solution for
11 * tunneling and relaying of packets of any protocol.
14 * Copyright (C) 2007-2008 Othmar Gsenger, Erwin Nindl,
15 * Christian Pointner <satp@wirdorange.org>
17 * This file is part of Anytun.
19 * Anytun is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU General Public License version 3 as
21 * published by the Free Software Foundation.
23 * Anytun is distributed in the hope that it will be useful,
24 * but WITHOUT ANY WARRANTY; without even the implied warranty of
25 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26 * GNU General Public License for more details.
28 * You should have received a copy of the GNU General Public License
29 * along with anytun. If not, see <http://www.gnu.org/licenses/>.
32 #include <iostream>
33 #include <fstream>
34 #include <poll.h>
35 #include <fcntl.h>
36 #include <pwd.h>
37 #include <grp.h>
38 #include <sys/wait.h>
39 #include <sys/stat.h>
40 #include <unistd.h>
42 #include <boost/bind.hpp>
43 #include <gcrypt.h>
44 #include <cerrno> // for ENOMEM
46 #include "datatypes.h"
48 #include "log.h"
49 #include "buffer.h"
50 #include "plainPacket.h"
51 #include "encryptedPacket.h"
52 #include "cipher.h"
53 #include "keyDerivation.h"
54 #include "authAlgo.h"
55 #include "cipherFactory.h"
56 #include "authAlgoFactory.h"
57 #include "keyDerivationFactory.h"
58 #include "signalController.h"
59 #include "packetSource.h"
60 #include "tunDevice.h"
61 #include "options.h"
62 #include "seqWindow.h"
63 #include "connectionList.h"
64 #include "routingTable.h"
65 #include "networkAddress.h"
67 #include "syncQueue.h"
68 #include "syncCommand.h"
70 #ifndef ANYTUN_NOSYNC
71 #include "syncServer.h"
72 #include "syncClient.h"
73 #include "syncOnConnect.hpp"
74 #endif
76 #include "threadParam.h"
// Largest packet (in bytes) read from the tun/tap device or the socket in one go.
#define MAX_PACKET_LENGTH 1600

// Per-packet session key lengths in bytes.
// TODO: hardcoded sizes -- these should come from the selected cipher / auth algorithm.
#define SESSION_KEYLEN_AUTH 20
#define SESSION_KEYLEN_ENCR 16
#define SESSION_KEYLEN_SALT 14
83 void createConnection(const PacketSourceEndpoint & remote_end, ConnectionList & cl, u_int16_t seqSize, SyncQueue & queue, mux_t mux)
85 SeqWindow * seq= new SeqWindow(seqSize);
86 seq_nr_t seq_nr_=0;
87 KeyDerivation * kd = KeyDerivationFactory::create(gOpt.getKdPrf());
88 kd->init(gOpt.getKey(), gOpt.getSalt());
89 cLog.msg(Log::PRIO_NOTICE) << "added connection remote host " << remote_end;
91 ConnectionParam connparam ( (*kd), (*seq), seq_nr_, remote_end);
92 cl.addConnection(connparam,mux);
93 NetworkAddress addr(ipv4,gOpt.getIfconfigParamRemoteNetmask().c_str());
94 NetworkPrefix prefix(addr,32);
95 gRoutingTable.addRoute(prefix,mux);
96 SyncCommand sc (cl,mux);
97 queue.push(sc);
98 SyncCommand sc2 (prefix);
99 queue.push(sc2);
102 bool checkPacketSeqNr(EncryptedPacket& pack,ConnectionParam& conn)
104 // compare sender_id and seq with window
105 if(conn.seq_window_.hasSeqNr(pack.getSenderId(), pack.getSeqNr()))
107 cLog.msg(Log::PRIO_NOTICE) << "Replay attack from " << conn.remote_end_
108 << " seq:"<<pack.getSeqNr() << " sid: "<<pack.getSenderId();
109 return false;
112 conn.seq_window_.addSeqNr(pack.getSenderId(), pack.getSeqNr());
113 return true;
116 void sender(void* p)
118 try
120 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
122 std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher()));
123 std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
125 PlainPacket plain_packet(MAX_PACKET_LENGTH);
126 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
128 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
129 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
130 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
132 //TODO replace mux
133 u_int16_t mux = gOpt.getMux();
134 PacketSourceEndpoint emptyEndpoint;
135 while(1)
137 plain_packet.setLength(MAX_PACKET_LENGTH);
138 encrypted_packet.withAuthTag(false);
139 encrypted_packet.setLength(MAX_PACKET_LENGTH);
141 // read packet from device
142 u_int32_t len = param->dev.read(plain_packet.getPayload(), plain_packet.getPayloadLength());
143 plain_packet.setPayloadLength(len);
144 // set payload type
145 if(param->dev.getType() == TYPE_TUN)
146 plain_packet.setPayloadType(PAYLOAD_TYPE_TUN);
147 else if(param->dev.getType() == TYPE_TAP)
148 plain_packet.setPayloadType(PAYLOAD_TYPE_TAP);
149 else
150 plain_packet.setPayloadType(0);
152 if(param->cl.empty())
153 continue;
154 //std::cout << "got Packet for plain "<<plain_packet.getDstAddr().toString();
155 mux = gRoutingTable.getRoute(plain_packet.getDstAddr());
156 //std::cout << " -> "<<mux << std::endl;
157 ConnectionMap::iterator cit = param->cl.getConnection(mux);
158 if(cit==param->cl.getEnd())
159 continue;
160 ConnectionParam & conn = cit->second;
162 if(conn.remote_end_ == emptyEndpoint)
164 //cLog.msg(Log::PRIO_INFO) << "no remote address set";
165 continue;
168 // generate packet-key TODO: do this only when needed
169 conn.kd_.generate(LABEL_SATP_ENCRYPTION, conn.seq_nr_, session_key);
170 conn.kd_.generate(LABEL_SATP_SALT, conn.seq_nr_, session_salt);
172 c->setKey(session_key);
173 c->setSalt(session_salt);
175 // encrypt packet
176 c->encrypt(plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux);
178 encrypted_packet.setHeader(conn.seq_nr_, gOpt.getSenderId(), mux);
179 conn.seq_nr_++;
181 // add authentication tag
182 if(a->getMaxLength()) {
183 encrypted_packet.addAuthTag();
184 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
185 a->setKey(session_auth_key);
186 a->generate(encrypted_packet);
190 param->src.send(encrypted_packet.getBuf(), encrypted_packet.getLength(), conn.remote_end_);
192 catch (std::exception& e)
194 // ignoring icmp port unreachable :) and other socket errors :(
198 catch(std::runtime_error& e)
200 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught runtime_error: " << e.what();
202 catch(std::exception& e)
204 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught exception: " << e.what();
208 #ifndef ANYTUN_NOSYNC
209 void syncConnector(void* p )
211 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
213 SyncClient sc ( param->connto.host, param->connto.port);
214 sc.run();
217 void syncListener(SyncQueue * queue)
221 boost::asio::io_service io_service;
222 SyncTcpConnection::proto::resolver resolver(io_service);
223 SyncTcpConnection::proto::endpoint e;
224 if(gOpt.getLocalSyncAddr()!="")
226 SyncTcpConnection::proto::resolver::query query(gOpt.getLocalSyncAddr(), gOpt.getLocalSyncPort());
227 e = *resolver.resolve(query);
228 } else {
229 SyncTcpConnection::proto::resolver::query query(gOpt.getLocalSyncPort());
230 e = *resolver.resolve(query);
234 SyncServer server(io_service,e);
235 server.onConnect=boost::bind(syncOnConnect,_1);
236 queue->setSyncServerPtr(&server);
237 io_service.run();
239 catch (std::exception& e)
241 std::string addr = gOpt.getLocalSyncAddr() == "" ? "*" : gOpt.getLocalSyncAddr();
242 cLog.msg(Log::PRIO_ERR) << "sync: cannot bind to " << addr << ":" << gOpt.getLocalSyncPort()
243 << " (" << e.what() << ")" << std::endl;
247 #endif
249 void receiver(void* p)
253 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
255 std::auto_ptr<Cipher> c( CipherFactory::create(gOpt.getCipher()) );
256 std::auto_ptr<AuthAlgo> a( AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
258 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
259 PlainPacket plain_packet(MAX_PACKET_LENGTH);
261 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
262 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
263 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
265 while(1)
267 PacketSourceEndpoint remote_end;
269 plain_packet.setLength(MAX_PACKET_LENGTH);
270 encrypted_packet.withAuthTag(false);
271 encrypted_packet.setLength(MAX_PACKET_LENGTH);
273 // read packet from socket
274 u_int32_t len = param->src.recv(encrypted_packet.getBuf(), encrypted_packet.getLength(), remote_end);
275 encrypted_packet.setLength(len);
277 mux_t mux = encrypted_packet.getMux();
278 // autodetect peer
279 if( param->cl.empty() && gOpt.getRemoteAddr() == "")
281 cLog.msg(Log::PRIO_NOTICE) << "autodetected remote host " << remote_end;
282 createConnection(remote_end, param->cl, gOpt.getSeqWindowSize(),param->queue,mux);
285 ConnectionMap::iterator cit = param->cl.getConnection(mux);
286 if (cit == param->cl.getEnd())
287 continue;
288 ConnectionParam & conn = cit->second;
290 // check whether auth tag is ok or not
291 if(a->getMaxLength()) {
292 encrypted_packet.withAuthTag(true);
293 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
294 a->setKey(session_auth_key);
295 if(!a->checkTag(encrypted_packet)) {
296 cLog.msg(Log::PRIO_NOTICE) << "wrong Authentication Tag!" << std::endl;
297 continue;
299 encrypted_packet.removeAuthTag();
302 //Allow dynamic IP changes
303 //TODO: add command line option to turn this off
304 if (remote_end != conn.remote_end_)
306 cLog.msg(Log::PRIO_NOTICE) << "connection "<< mux << " autodetected remote host ip changed " << remote_end;
307 conn.remote_end_=remote_end;
308 SyncCommand sc (param->cl,mux);
309 param->queue.push(sc);
312 // Replay Protection
313 if (!checkPacketSeqNr(encrypted_packet, conn))
314 continue;
316 // generate packet-key
317 conn.kd_.generate(LABEL_SATP_ENCRYPTION, encrypted_packet.getSeqNr(), session_key);
318 conn.kd_.generate(LABEL_SATP_SALT, encrypted_packet.getSeqNr(), session_salt);
319 c->setKey(session_key);
320 c->setSalt(session_salt);
322 // decrypt packet
323 c->decrypt(encrypted_packet, plain_packet);
325 // check payload_type
326 if((param->dev.getType() == TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN4 &&
327 plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN6) ||
328 (param->dev.getType() == TYPE_TAP && plain_packet.getPayloadType() != PAYLOAD_TYPE_TAP))
329 continue;
331 // write it on the device
332 param->dev.write(plain_packet.getPayload(), plain_packet.getLength());
335 catch(std::runtime_error& e)
337 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught runtime_error: " << e.what();
339 catch(std::exception& e)
341 cLog.msg(Log::PRIO_ERR) << "receiver thread died due to an uncaught exception: " << e.what();
345 // boost thread callbacks for libgcrypt
346 #if defined(BOOST_HAS_PTHREADS)
348 static int boost_mutex_init(void **priv)
350 boost::mutex *lock = new boost::mutex();
351 if (!lock)
352 return ENOMEM;
353 *priv = lock;
354 return 0;
357 static int boost_mutex_destroy(void **lock)
359 delete reinterpret_cast<boost::mutex*>(*lock);
360 return 0;
363 static int boost_mutex_lock(void **lock)
365 reinterpret_cast<boost::mutex*>(*lock)->lock();
366 return 0;
369 static int boost_mutex_unlock(void **lock)
371 reinterpret_cast<boost::mutex*>(*lock)->unlock();
372 return 0;
375 static struct gcry_thread_cbs gcry_threads_boost =
376 { GCRY_THREAD_OPTION_USER, NULL,
377 boost_mutex_init, boost_mutex_destroy,
378 boost_mutex_lock, boost_mutex_unlock };
379 #else
380 #error this libgcrypt thread callbacks only work with pthreads
381 #endif
383 #define MIN_GCRYPT_VERSION "1.2.0"
385 bool initLibGCrypt()
387 // make libgcrypt thread safe
388 // this must be called before any other libgcrypt call
389 gcry_control( GCRYCTL_SET_THREAD_CBS, &gcry_threads_boost );
391 // this must be called right after the GCRYCTL_SET_THREAD_CBS command
392 // no other function must be called till now
393 if( !gcry_check_version( MIN_GCRYPT_VERSION ) ) {
394 std::cout << "initLibGCrypt: Invalid Version of libgcrypt, should be >= " << MIN_GCRYPT_VERSION << std::endl;
395 return false;
398 gcry_error_t err = gcry_control (GCRYCTL_DISABLE_SECMEM, 0);
399 if( err ) {
400 char buf[STERROR_TEXT_MAX];
401 buf[0] = 0;
402 std::cout << "initLibGCrypt: Failed to disable secure memory: " << gpg_strerror_r(err, buf, STERROR_TEXT_MAX) << std::endl;
403 return false;
406 // Tell Libgcrypt that initialization has completed.
407 err = gcry_control(GCRYCTL_INITIALIZATION_FINISHED);
408 if( err ) {
409 char buf[STERROR_TEXT_MAX];
410 buf[0] = 0;
411 std::cout << "initLibGCrypt: Failed to finish initialization: " << gpg_strerror_r(err, buf, STERROR_TEXT_MAX) << std::endl;
412 return false;
415 cLog.msg(Log::PRIO_NOTICE) << "initLibGCrypt: libgcrypt init finished";
416 return true;
419 void chrootAndDrop(std::string const& chrootdir, std::string const& username)
421 if (getuid() != 0)
423 std::cerr << "this programm has to be run as root in order to run in a chroot" << std::endl;
424 exit(-1);
427 struct passwd *pw = getpwnam(username.c_str());
428 if(pw) {
429 if(chroot(chrootdir.c_str()))
431 std::cerr << "can't chroot to " << chrootdir << std::endl;
432 exit(-1);
434 cLog.msg(Log::PRIO_NOTICE) << "we are in chroot jail (" << chrootdir << ") now" << std::endl;
435 chdir("/");
436 if (initgroups(pw->pw_name, pw->pw_gid) || setgid(pw->pw_gid) || setuid(pw->pw_uid))
438 std::cerr << "can't drop to user " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
439 exit(-1);
441 cLog.msg(Log::PRIO_NOTICE) << "dropped user to " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
443 else
445 std::cerr << "unknown user " << username << std::endl;
446 exit(-1);
// Moves the process into the background (fork, setsid, fork).
//
// The original process and the intermediate child exit with status 0; only
// the second child returns from this function. Descriptors 0-2 are closed
// and reopened on /dev/null, and the umask is set to 027.
// If a fork fails, an error is printed to stderr and the process exits with -1
// (previously fork()'s -1 was mistaken for a parent pid and the process
// silently exited with success, leaving no daemon running).
void daemonize()
{
  pid_t pid;

  pid = fork();
  if(pid < 0) {
    std::cerr << "daemonize: fork failed" << std::endl;
    exit(-1);
  }
  if(pid) exit(0);
  setsid();
  pid = fork();
  if(pid < 0) {
    std::cerr << "daemonize: fork failed" << std::endl;
    exit(-1);
  }
  if(pid) exit(0);

  // std::cout << "running in background now..." << std::endl;

  int fd;
  // for (fd=getdtablesize();fd>=0;--fd) // close all file descriptors
  for (fd=0;fd<=2;fd++) // close stdin, stdout and stderr only
    close(fd);
  fd=open("/dev/null",O_RDWR); // stdin
  dup(fd); // stdout
  dup(fd); // stderr
  umask(027);
}
// Runs `script` with /bin/sh, passing `ifname` as its first argument, and
// waits for it to finish.
//
// The child closes every inherited descriptor and gets /dev/null as stdin,
// stdout and stderr. Returns the raw status from waitpid() (decode with
// WIFEXITED / WEXITSTATUS), or -1 if the fork or the wait fails.
int execScript(std::string const& script, std::string const& ifname)
{
  pid_t pid = fork();
  if(pid < 0)
    return -1; // fork failed -- do not fall through to waitpid(-1), which waits for ANY child

  if(pid == 0) {
    for (int fd = getdtablesize(); fd >= 0; --fd) // close all file descriptors
      close(fd);
    int null_fd = open("/dev/null", O_RDWR); // stdin
    dup(null_fd); // stdout
    dup(null_fd); // stderr
    // execl's argument list must be terminated by a null *char pointer*
    execl("/bin/sh", "/bin/sh", script.c_str(), ifname.c_str(), static_cast<char*>(NULL));
    // only reached if execl failed: the child must never return into the
    // caller, otherwise it would continue running a second copy of the program
    _exit(127);
  }

  int status = 0;
  while(waitpid(pid, &status, 0) < 0) {
    if(errno != EINTR)
      return -1;
  }
  return status;
}
490 int main(int argc, char* argv[])
492 bool daemonized=false;
493 try
496 // std::cout << "anytun - secure anycast tunneling protocol" << std::endl;
497 if(!gOpt.parse(argc, argv)) {
498 gOpt.printUsage();
499 exit(-1);
502 cLog.msg(Log::PRIO_NOTICE) << "anytun started...";
504 std::ofstream pidFile;
505 if(gOpt.getPidFile() != "") {
506 pidFile.open(gOpt.getPidFile().c_str());
507 if(!pidFile.is_open()) {
508 std::cout << "can't open pid file" << std::endl;
512 TunDevice dev(gOpt.getDevName() =="" ? NULL : gOpt.getDevName().c_str(),
513 gOpt.getDevType() =="" ? NULL : gOpt.getDevType().c_str(),
514 gOpt.getIfconfigParamLocal() =="" ? NULL : gOpt.getIfconfigParamLocal().c_str(),
515 gOpt.getIfconfigParamRemoteNetmask() =="" ? NULL : gOpt.getIfconfigParamRemoteNetmask().c_str());
516 cLog.msg(Log::PRIO_NOTICE) << "dev created (opened)";
517 cLog.msg(Log::PRIO_NOTICE) << "dev opened - actual name is '" << dev.getActualName() << "'";
518 cLog.msg(Log::PRIO_NOTICE) << "dev type is '" << dev.getTypeString() << "'";
519 if(gOpt.getPostUpScript() != "") {
520 int postup_ret = execScript(gOpt.getPostUpScript(), dev.getActualName());
521 cLog.msg(Log::PRIO_NOTICE) << "post up script '" << gOpt.getPostUpScript() << "' returned " << postup_ret;
524 PacketSource* src;
525 if(gOpt.getLocalAddr() == "")
526 src = new UDPPacketSource(gOpt.getLocalPort());
527 else
528 src = new UDPPacketSource(gOpt.getLocalAddr(), gOpt.getLocalPort());
530 ConnectionList & cl (gConnectionList);
531 ConnectToList connect_to = gOpt.getConnectTo();
532 SyncQueue queue;
534 if(gOpt.getRemoteAddr() != "")
536 boost::asio::io_service io_service;
537 UDPPacketSource::proto::resolver resolver(io_service);
538 UDPPacketSource::proto::resolver::query query(gOpt.getRemoteAddr(), gOpt.getRemotePort());
539 UDPPacketSource::proto::endpoint endpoint = *resolver.resolve(query);
540 createConnection(endpoint,cl,gOpt.getSeqWindowSize(), queue, gOpt.getMux());
543 if(gOpt.getChroot())
544 chrootAndDrop(gOpt.getChrootDir(), gOpt.getUsername());
545 if(gOpt.getDaemonize())
547 daemonize();
548 daemonized = true;
551 if(pidFile.is_open()) {
552 pid_t pid = getpid();
553 pidFile << pid;
554 pidFile.close();
557 SignalController sig;
558 sig.init();
560 ThreadParam p(dev, *src, cl, queue,*(new OptionConnectTo()));
562 // this must be called before any other libgcrypt call
563 if(!initLibGCrypt())
564 return -1;
566 boost::thread senderThread(boost::bind(sender,&p));
567 boost::thread receiverThread(boost::bind(receiver,&p));
568 #ifndef ANYTUN_NOSYNC
569 boost::thread * syncListenerThread;
570 if(gOpt.getLocalSyncPort() != "")
571 syncListenerThread = new boost::thread(boost::bind(syncListener,&queue));
573 std::list<boost::thread *> connectThreads;
574 for(ConnectToList::iterator it = connect_to.begin() ;it != connect_to.end(); ++it) {
575 ThreadParam * point = new ThreadParam(dev, *src, cl, queue,*it);
576 connectThreads.push_back(new boost::thread(boost::bind(syncConnector,point)));
578 #endif
580 int ret = sig.run();
582 return ret;
583 // TODO cleanup here!
585 pthread_cancel(senderThread);
586 pthread_cancel(receiverThread);
587 #ifndef ANYTUN_NOSYNC
588 if ( gOpt.getLocalSyncPort())
589 pthread_cancel(syncListenerThread);
590 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
591 pthread_cancel(*it);
592 #endif
594 pthread_join(senderThread, NULL);
595 pthread_join(receiverThread, NULL);
596 #ifndef ANYTUN_NOSYNC
597 if ( gOpt.getLocalSyncPort())
598 pthread_join(syncListenerThread, NULL);
600 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
601 pthread_join(*it, NULL);
602 #endif
603 delete src;
604 delete &p.connto;
606 return ret;
609 catch(std::runtime_error& e)
611 if(daemonized)
612 cLog.msg(Log::PRIO_ERR) << "uncaught runtime error, exiting: " << e.what();
613 else
614 std::cout << "uncaught runtime error, exiting: " << e.what() << std::endl;
616 catch(std::exception& e)
618 if(daemonized)
619 cLog.msg(Log::PRIO_ERR) << "uncaught exception, exiting: " << e.what();
620 else
621 std::cout << "uncaught exception, exiting: " << e.what() << std::endl;