fixed anytun-controld
[anytun.git] / src / anytun.cpp
blob8504d907caecedf9c288b5cf32c7dfed402790b4
1 /*
2 * anytun
4 * The secure anycast tunneling protocol (satp) defines a protocol used
5 * for communication between any combination of unicast and anycast
6 * tunnel endpoints. It has less protocol overhead than IPSec in Tunnel
7 * mode and allows tunneling of every ETHER TYPE protocol (e.g.
8 * ethernet, ip, arp ...). satp directly includes cryptography and
9 * message authentication based on the methods used by SRTP. It is
10 * intended to deliver a generic, scaleable and secure solution for
11 * tunneling and relaying of packets of any protocol.
14 * Copyright (C) 2007-2008 Othmar Gsenger, Erwin Nindl,
15 * Christian Pointner <satp@wirdorange.org>
17 * This file is part of Anytun.
19 * Anytun is free software: you can redistribute it and/or modify
20 * it under the terms of the GNU General Public License version 3 as
21 * published by the Free Software Foundation.
23 * Anytun is distributed in the hope that it will be useful,
24 * but WITHOUT ANY WARRANTY; without even the implied warranty of
25 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
26 * GNU General Public License for more details.
28 * You should have received a copy of the GNU General Public License
29 * along with anytun. If not, see <http://www.gnu.org/licenses/>.
32 #include <iostream>
33 #include <fstream>
34 #include <poll.h>
35 #include <fcntl.h>
36 #include <pwd.h>
37 #include <grp.h>
38 #include <sys/wait.h>
39 #include <sys/stat.h>
40 #include <unistd.h>
42 #include <boost/bind.hpp>
43 #include <boost/thread/detail/lock.hpp>
44 #include <gcrypt.h>
45 #include <cerrno> // for ENOMEM
47 #include "datatypes.h"
49 #include "log.h"
50 #include "buffer.h"
51 #include "plainPacket.h"
52 #include "encryptedPacket.h"
53 #include "cipher.h"
54 #include "keyDerivation.h"
55 #include "authAlgo.h"
56 #include "cipherFactory.h"
57 #include "authAlgoFactory.h"
58 #include "keyDerivationFactory.h"
59 #include "signalController.h"
60 #include "packetSource.h"
61 #include "tunDevice.h"
62 #include "options.h"
63 #include "seqWindow.h"
64 #include "connectionList.h"
65 #include "routingTable.h"
66 #include "networkAddress.h"
68 #include "syncQueue.h"
69 #include "syncCommand.h"
71 #ifndef ANYTUN_NOSYNC
72 #include "syncServer.h"
73 #include "syncClient.h"
74 #include "syncOnConnect.hpp"
75 #endif
77 #include "threadParam.h"
78 #define MAX_PACKET_LENGTH 1600
80 #define SESSION_KEYLEN_AUTH 20 // TODO: hardcoded size
81 #define SESSION_KEYLEN_ENCR 16 // TODO: hardcoded size
82 #define SESSION_KEYLEN_SALT 14 // TODO: hardcoded size
84 void createConnection(const std::string & remote_host, u_int16_t remote_port, ConnectionList & cl, u_int16_t seqSize, SyncQueue & queue, mux_t mux)
86 SeqWindow * seq= new SeqWindow(seqSize);
87 seq_nr_t seq_nr_=0;
88 KeyDerivation * kd = KeyDerivationFactory::create(gOpt.getKdPrf());
89 kd->init(gOpt.getKey(), gOpt.getSalt());
90 cLog.msg(Log::PRIO_NOTICE) << "added connection remote host " << remote_host << ":" << remote_port;
91 ConnectionParam connparam ( (*kd), (*seq), seq_nr_, remote_host, remote_port);
92 cl.addConnection(connparam,mux);
93 NetworkAddress addr(ipv4,gOpt.getIfconfigParamRemoteNetmask().c_str());
94 NetworkPrefix prefix(addr,32);
95 gRoutingTable.addRoute(prefix,mux);
96 SyncCommand sc (cl,mux);
97 queue.push(sc);
98 SyncCommand sc2 (prefix);
99 queue.push(sc2);
102 bool checkPacketSeqNr(EncryptedPacket& pack,ConnectionParam& conn)
104 // compare sender_id and seq with window
105 if(conn.seq_window_.hasSeqNr(pack.getSenderId(), pack.getSeqNr()))
107 cLog.msg(Log::PRIO_NOTICE) << "Replay attack from " << conn.remote_host_<<":"<< conn.remote_port_
108 << " seq:"<<pack.getSeqNr() << " sid: "<<pack.getSenderId();
109 return false;
112 conn.seq_window_.addSeqNr(pack.getSenderId(), pack.getSeqNr());
113 return true;
116 void sender(void* p)
118 try
120 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
122 std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher()));
123 std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
125 PlainPacket plain_packet(MAX_PACKET_LENGTH);
126 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
128 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
129 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
130 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
132 //TODO replace mux
133 u_int16_t mux = gOpt.getMux();
134 while(1)
136 plain_packet.setLength(MAX_PACKET_LENGTH);
137 encrypted_packet.withAuthTag(false);
138 encrypted_packet.setLength(MAX_PACKET_LENGTH);
140 // read packet from device
141 u_int32_t len = param->dev.read(plain_packet.getPayload(), plain_packet.getPayloadLength());
142 plain_packet.setPayloadLength(len);
143 // set payload type
144 if(param->dev.getType() == TYPE_TUN)
145 plain_packet.setPayloadType(PAYLOAD_TYPE_TUN);
146 else if(param->dev.getType() == TYPE_TAP)
147 plain_packet.setPayloadType(PAYLOAD_TYPE_TAP);
148 else
149 plain_packet.setPayloadType(0);
151 if(param->cl.empty())
152 continue;
153 //std::cout << "got Packet for plain "<<plain_packet.getDstAddr().toString();
154 mux = gRoutingTable.getRoute(plain_packet.getDstAddr());
155 //std::cout << " -> "<<mux << std::endl;
156 ConnectionMap::iterator cit = param->cl.getConnection(mux);
157 if(cit==param->cl.getEnd())
158 continue;
159 ConnectionParam & conn = cit->second;
161 if(conn.remote_host_==""||!conn.remote_port_)
162 continue;
163 // generate packet-key TODO: do this only when needed
164 conn.kd_.generate(LABEL_SATP_ENCRYPTION, conn.seq_nr_, session_key);
165 conn.kd_.generate(LABEL_SATP_SALT, conn.seq_nr_, session_salt);
167 c->setKey(session_key);
168 c->setSalt(session_salt);
170 // encrypt packet
171 c->encrypt(plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux);
173 encrypted_packet.setHeader(conn.seq_nr_, gOpt.getSenderId(), mux);
174 conn.seq_nr_++;
176 // add authentication tag
177 if(a->getMaxLength()) {
178 encrypted_packet.addAuthTag();
179 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
180 a->setKey(session_auth_key);
181 a->generate(encrypted_packet);
185 param->src.send(encrypted_packet.getBuf(), encrypted_packet.getLength(), conn.remote_host_, conn.remote_port_);
187 catch (std::exception& e)
189 // ignoring icmp port unreachable :) and other socket errors :(
193 catch(std::runtime_error& e)
195 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught runtime_error: " << e.what();
197 catch(std::exception& e)
199 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught exception: " << e.what();
203 #ifndef ANYTUN_NOSYNC
204 void syncConnector(void* p )
206 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
208 SyncClient sc ( param->connto.host, param->connto.port);
209 sc.run();
212 void syncListener(SyncQueue * queue )
214 // ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
218 asio::io_service io_service;
219 SyncServer server(io_service,asio::ip::tcp::endpoint(asio::ip::tcp::v4(), gOpt.getLocalSyncPort()));
220 server.onConnect=boost::bind(syncOnConnect,_1);
221 queue->setSyncServerPtr(&server);
222 io_service.run();
224 catch (std::exception& e)
226 std::cerr << e.what() << std::endl;
230 #endif
232 void receiver(void* p)
236 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
238 std::auto_ptr<Cipher> c( CipherFactory::create(gOpt.getCipher()) );
239 std::auto_ptr<AuthAlgo> a( AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
241 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
242 PlainPacket plain_packet(MAX_PACKET_LENGTH);
244 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
245 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
246 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
248 while(1)
250 std::string remote_host;
251 u_int16_t remote_port;
253 plain_packet.setLength(MAX_PACKET_LENGTH);
254 encrypted_packet.withAuthTag(false);
255 encrypted_packet.setLength(MAX_PACKET_LENGTH);
257 // read packet from socket
258 u_int32_t len = param->src.recv(encrypted_packet.getBuf(), encrypted_packet.getLength(), remote_host, remote_port);
259 encrypted_packet.setLength(len);
261 mux_t mux = encrypted_packet.getMux();
262 // autodetect peer
263 if(gOpt.getRemoteAddr() == "" && param->cl.empty())
265 cLog.msg(Log::PRIO_NOTICE) << "autodetected remote host " << remote_host << ":" << remote_port;
266 createConnection(remote_host, remote_port, param->cl, gOpt.getSeqWindowSize(),param->queue,mux);
269 ConnectionMap::iterator cit = param->cl.getConnection(mux);
270 if (cit == param->cl.getEnd())
271 continue;
272 ConnectionParam & conn = cit->second;
274 // check whether auth tag is ok or not
275 if(a->getMaxLength()) {
276 encrypted_packet.withAuthTag(true);
277 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
278 a->setKey(session_auth_key);
279 if(!a->checkTag(encrypted_packet)) {
280 cLog.msg(Log::PRIO_NOTICE) << "wrong Authentication Tag!" << std::endl;
281 continue;
283 encrypted_packet.removeAuthTag();
286 //Allow dynamic IP changes
287 //TODO: add command line option to turn this off
288 if (remote_host != conn.remote_host_ || remote_port != conn.remote_port_)
290 cLog.msg(Log::PRIO_NOTICE) << "connection "<< mux << " autodetected remote host ip changed "
291 << remote_host << ":" << remote_port;
292 conn.remote_host_=remote_host;
293 conn.remote_port_=remote_port;
294 SyncCommand sc (param->cl,mux);
295 param->queue.push(sc);
298 // Replay Protection
299 if (!checkPacketSeqNr(encrypted_packet, conn))
300 continue;
302 // generate packet-key
303 conn.kd_.generate(LABEL_SATP_ENCRYPTION, encrypted_packet.getSeqNr(), session_key);
304 conn.kd_.generate(LABEL_SATP_SALT, encrypted_packet.getSeqNr(), session_salt);
305 c->setKey(session_key);
306 c->setSalt(session_salt);
308 // decrypt packet
309 c->decrypt(encrypted_packet, plain_packet);
311 // check payload_type
312 if((param->dev.getType() == TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN4 &&
313 plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN6) ||
314 (param->dev.getType() == TYPE_TAP && plain_packet.getPayloadType() != PAYLOAD_TYPE_TAP))
315 continue;
317 // write it on the device
318 param->dev.write(plain_packet.getPayload(), plain_packet.getLength());
321 catch(std::runtime_error& e)
323 cLog.msg(Log::PRIO_ERR) << "sender thread died due to an uncaught runtime_error: " << e.what();
325 catch(std::exception& e)
327 cLog.msg(Log::PRIO_ERR) << "receiver thread died due to an uncaught exception: " << e.what();
331 // boost thread callbacks for libgcrypt
332 #if defined(BOOST_HAS_PTHREADS)
333 typedef boost::detail::thread::lock_ops<boost::mutex> mutex_ops;
335 static int boost_mutex_init(void **priv)
337 boost::mutex *lock = new boost::mutex();
338 if (!lock)
339 return ENOMEM;
340 *priv = lock;
341 return 0;
344 static int boost_mutex_destroy(void **lock)
346 delete reinterpret_cast<boost::mutex*>(*lock);
347 return 0;
350 static int boost_mutex_lock(void **lock)
352 mutex_ops::lock(*reinterpret_cast<boost::mutex*>(*lock));
353 return 0;
356 static int boost_mutex_unlock(void **lock)
358 mutex_ops::unlock(*reinterpret_cast<boost::mutex*>(*lock));
359 return 0;
362 static struct gcry_thread_cbs gcry_threads_boost =
363 { GCRY_THREAD_OPTION_USER, NULL,
364 boost_mutex_init, boost_mutex_destroy,
365 boost_mutex_lock, boost_mutex_unlock };
366 #else
367 #error this libgcrypt thread callbacks only work with pthreads
368 #endif
370 #define MIN_GCRYPT_VERSION "1.2.0"
372 bool initLibGCrypt()
374 // make libgcrypt thread safe
375 // this must be called before any other libgcrypt call
376 gcry_control( GCRYCTL_SET_THREAD_CBS, &gcry_threads_boost );
378 // this must be called right after the GCRYCTL_SET_THREAD_CBS command
379 // no other function must be called till now
380 if( !gcry_check_version( MIN_GCRYPT_VERSION ) ) {
381 std::cout << "initLibGCrypt: Invalid Version of libgcrypt, should be >= " << MIN_GCRYPT_VERSION << std::endl;
382 return false;
385 gcry_error_t err = gcry_control (GCRYCTL_DISABLE_SECMEM, 0);
386 if( err ) {
387 char buf[STERROR_TEXT_MAX];
388 buf[0] = 0;
389 std::cout << "initLibGCrypt: Failed to disable secure memory: " << gpg_strerror_r(err, buf, STERROR_TEXT_MAX) << std::endl;
390 return false;
393 // Tell Libgcrypt that initialization has completed.
394 err = gcry_control(GCRYCTL_INITIALIZATION_FINISHED);
395 if( err ) {
396 char buf[STERROR_TEXT_MAX];
397 buf[0] = 0;
398 std::cout << "initLibGCrypt: Failed to finish initialization: " << gpg_strerror_r(err, buf, STERROR_TEXT_MAX) << std::endl;
399 return false;
402 cLog.msg(Log::PRIO_NOTICE) << "initLibGCrypt: libgcrypt init finished";
403 return true;
406 void chrootAndDrop(std::string const& chrootdir, std::string const& username)
408 if (getuid() != 0)
410 std::cerr << "this programm has to be run as root in order to run in a chroot" << std::endl;
411 exit(-1);
414 struct passwd *pw = getpwnam(username.c_str());
415 if(pw) {
416 if(chroot(chrootdir.c_str()))
418 std::cerr << "can't chroot to " << chrootdir << std::endl;
419 exit(-1);
421 cLog.msg(Log::PRIO_NOTICE) << "we are in chroot jail (" << chrootdir << ") now" << std::endl;
422 chdir("/");
423 if (initgroups(pw->pw_name, pw->pw_gid) || setgid(pw->pw_gid) || setuid(pw->pw_uid))
425 std::cerr << "can't drop to user " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
426 exit(-1);
428 cLog.msg(Log::PRIO_NOTICE) << "dropped user to " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
430 else
432 std::cerr << "unknown user " << username << std::endl;
433 exit(-1);
// Forks. The parent exits with status 0 and only the child returns.
// A failed fork() is reported and terminates with -1. Previously
// `if(pid) exit(0)` also matched -1, so the process silently exited with 0.
static void forkAndExitParent()
{
  pid_t pid = fork();
  if(pid < 0) {
    std::cerr << "daemonize: fork failed" << std::endl;
    exit(-1);
  }
  if(pid > 0)
    exit(0);
}

/*
 * Detaches from the controlling terminal with the double-fork technique:
 * fork, setsid() in the child, fork again. Then closes fds 0-2, reopens them
 * on /dev/null and sets the umask to 027.
 */
void daemonize()
{
  forkAndExitParent();
  setsid();
  forkAndExitParent();

  // std::cout << "running in background now..." << std::endl;

  int fd;
  // for (fd=getdtablesize();fd>=0;--fd) // close all file descriptors
  for (fd=0;fd<=2;fd++) // close stdin, stdout and stderr
    close(fd);
  fd=open("/dev/null",O_RDWR); // stdin
  dup(fd); // stdout
  dup(fd); // stderr
  umask(027);
}
/*
 * Runs `/bin/sh <script> <ifname>` in a child process and waits for it.
 * In the child all file descriptors are closed and 0-2 are reopened on
 * /dev/null before the exec.
 * @return the raw wait status from waitpid() (inspect with WIFEXITED /
 *         WEXITSTATUS), or -1 if fork() or waitpid() fails.
 */
int execScript(std::string const& script, std::string const& ifname)
{
  pid_t pid = fork();
  if(pid < 0)
    return -1; // previously fell through to waitpid(-1, ...), reaping an arbitrary child

  if(pid == 0) {
    for (int fd = getdtablesize(); fd >= 0; --fd) // close all file descriptors
      close(fd);
    int devnull = open("/dev/null", O_RDWR); // stdin
    dup(devnull); // stdout
    dup(devnull); // stderr
    // execl's argument list must end with a null *pointer*; a bare NULL may be int 0.
    execl("/bin/sh", "/bin/sh", script.c_str(), ifname.c_str(), static_cast<char*>(0));
    // Only reached if execl failed. Never return into the caller: that would leave
    // a second copy of the program running. 127 mirrors the shell's convention.
    _exit(127);
  }

  int status = 0;
  while(waitpid(pid, &status, 0) < 0) {
    if(errno != EINTR) // retry when interrupted by a signal, otherwise give up
      return -1;
  }
  return status;
}
477 int main(int argc, char* argv[])
479 bool daemonized=false;
480 try
483 // std::cout << "anytun - secure anycast tunneling protocol" << std::endl;
484 if(!gOpt.parse(argc, argv)) {
485 gOpt.printUsage();
486 exit(-1);
489 cLog.msg(Log::PRIO_NOTICE) << "anytun started...";
491 std::ofstream pidFile;
492 if(gOpt.getPidFile() != "") {
493 pidFile.open(gOpt.getPidFile().c_str());
494 if(!pidFile.is_open()) {
495 std::cout << "can't open pid file" << std::endl;
499 TunDevice dev(gOpt.getDevName() =="" ? NULL : gOpt.getDevName().c_str(),
500 gOpt.getDevType() =="" ? NULL : gOpt.getDevType().c_str(),
501 gOpt.getIfconfigParamLocal() =="" ? NULL : gOpt.getIfconfigParamLocal().c_str(),
502 gOpt.getIfconfigParamRemoteNetmask() =="" ? NULL : gOpt.getIfconfigParamRemoteNetmask().c_str());
503 cLog.msg(Log::PRIO_NOTICE) << "dev created (opened)";
504 cLog.msg(Log::PRIO_NOTICE) << "dev opened - actual name is '" << dev.getActualName() << "'";
505 cLog.msg(Log::PRIO_NOTICE) << "dev type is '" << dev.getTypeString() << "'";
506 if(gOpt.getPostUpScript() != "") {
507 int postup_ret = execScript(gOpt.getPostUpScript(), dev.getActualName());
508 cLog.msg(Log::PRIO_NOTICE) << "post up script '" << gOpt.getPostUpScript() << "' returned " << postup_ret;
511 if(gOpt.getChroot())
512 chrootAndDrop(gOpt.getChrootDir(), gOpt.getUsername());
513 if(gOpt.getDaemonize())
515 daemonize();
516 daemonized = true;
519 if(pidFile.is_open()) {
520 pid_t pid = getpid();
521 pidFile << pid;
522 pidFile.close();
525 SignalController sig;
526 sig.init();
528 PacketSource* src;
529 if(gOpt.getLocalAddr() == "")
530 src = new UDPPacketSource(gOpt.getLocalPort());
531 else
532 src = new UDPPacketSource(gOpt.getLocalAddr(), gOpt.getLocalPort());
534 ConnectionList & cl (gConnectionList);
535 ConnectToList connect_to = gOpt.getConnectTo();
536 SyncQueue queue;
538 if(gOpt.getRemoteAddr() != "")
539 createConnection(gOpt.getRemoteAddr(),gOpt.getRemotePort(),cl,gOpt.getSeqWindowSize(), queue, gOpt.getMux());
541 ThreadParam p(dev, *src, cl, queue,*(new OptionConnectTo()));
543 // this must be called before any other libgcrypt call
544 if(!initLibGCrypt())
545 return -1;
547 boost::thread senderThread(boost::bind(sender,&p));
548 boost::thread receiverThread(boost::bind(receiver,&p));
549 #ifndef ANYTUN_NOSYNC
550 boost::thread * syncListenerThread;
551 if ( gOpt.getLocalSyncPort())
552 syncListenerThread = new boost::thread(boost::bind(syncListener,&queue));
554 std::list<boost::thread *> connectThreads;
555 for(ConnectToList::iterator it = connect_to.begin() ;it != connect_to.end(); ++it) {
556 ThreadParam * point = new ThreadParam(dev, *src, cl, queue,*it);
557 connectThreads.push_back(new boost::thread(boost::bind(syncConnector,point)));
559 #endif
561 int ret = sig.run();
563 return ret;
564 // TODO cleanup here!
566 pthread_cancel(senderThread);
567 pthread_cancel(receiverThread);
568 #ifndef ANYTUN_NOSYNC
569 if ( gOpt.getLocalSyncPort())
570 pthread_cancel(syncListenerThread);
571 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
572 pthread_cancel(*it);
573 #endif
575 pthread_join(senderThread, NULL);
576 pthread_join(receiverThread, NULL);
577 #ifndef ANYTUN_NOSYNC
578 if ( gOpt.getLocalSyncPort())
579 pthread_join(syncListenerThread, NULL);
581 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
582 pthread_join(*it, NULL);
583 #endif
584 delete src;
585 delete &p.connto;
587 return ret;
590 catch(std::runtime_error& e)
592 if(daemonized)
593 cLog.msg(Log::PRIO_ERR) << "uncaught runtime error, exiting: " << e.what();
594 else
595 std::cout << "uncaught runtime error, exiting: " << e.what() << std::endl;
597 catch(std::exception& e)
599 if(daemonized)
600 cLog.msg(Log::PRIO_ERR) << "uncaught exception, exiting: " << e.what();
601 else
602 std::cout << "uncaught exception, exiting: " << e.what() << std::endl;