fixed issue with pid file and chrooting too early
[anytun.git] / anytun.cpp
blobb1279fbe2ae65337580769764841da30254e0851
1 /*
2 * anytun
4 * The secure anycast tunneling protocol (satp) defines a protocol used
5 * for communication between any combination of unicast and anycast
6 * tunnel endpoints. It has less protocol overhead than IPSec in Tunnel
7 * mode and allows tunneling of every ETHER TYPE protocol (e.g.
8 * ethernet, ip, arp ...). satp directly includes cryptography and
9 * message authentication based on the methods used by SRTP. It is
10 * intended to deliver a generic, scalable and secure solution for
11 * tunneling and relaying of packets of any protocol.
14 * Copyright (C) 2007 anytun.org <satp@wirdorange.org>
16 * This program is free software; you can redistribute it and/or modify
17 * it under the terms of the GNU General Public License version 2
18 * as published by the Free Software Foundation.
20 * This program is distributed in the hope that it will be useful,
21 * but WITHOUT ANY WARRANTY; without even the implied warranty of
22 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23 * GNU General Public License for more details.
25 * You should have received a copy of the GNU General Public License
26 * along with this program (see the file COPYING included with this
27 * distribution); if not, write to the Free Software Foundation, Inc.,
28 * 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
31 #include <iostream>
32 #include <fstream>
33 #include <poll.h>
34 #include <fcntl.h>
35 #include <pwd.h>
36 #include <grp.h>
38 #include <gcrypt.h>
39 #include <cerrno> // for ENOMEM
41 #include "datatypes.h"
43 #include "log.h"
44 #include "buffer.h"
45 #include "plainPacket.h"
46 #include "encryptedPacket.h"
47 #include "cipher.h"
48 #include "keyDerivation.h"
49 #include "authAlgo.h"
50 #include "cipherFactory.h"
51 #include "authAlgoFactory.h"
52 #include "keyDerivationFactory.h"
53 #include "signalController.h"
54 #include "packetSource.h"
55 #include "tunDevice.h"
56 #include "options.h"
57 #include "seqWindow.h"
58 #include "connectionList.h"
59 #include "routingTable.h"
60 #include "networkAddress.h"
62 #include "syncQueue.h"
63 #include "syncSocketHandler.h"
64 #include "syncListenSocket.h"
66 #include "syncSocket.h"
67 #include "syncClientSocket.h"
68 #include "syncCommand.h"
70 #include "threadParam.h"
// Size of the packet buffers used by the sender and receiver threads.
static const int MAX_PACKET_LENGTH = 1600;

// Per-packet session key lengths fed to the key derivation.
// TODO: hardcoded sizes - these should come from the selected cipher/auth algo.
static const int SESSION_KEYLEN_AUTH = 20;
static const int SESSION_KEYLEN_ENCR = 16;
static const int SESSION_KEYLEN_SALT = 14;
78 void createConnection(const std::string & remote_host, u_int16_t remote_port, ConnectionList & cl, u_int16_t seqSize, SyncQueue & queue, mux_t mux)
80 SeqWindow * seq= new SeqWindow(seqSize);
81 seq_nr_t seq_nr_=0;
82 KeyDerivation * kd = KeyDerivationFactory::create(gOpt.getKdPrf());
83 kd->init(gOpt.getKey(), gOpt.getSalt());
84 cLog.msg(Log::PRIO_NOTICE) << "added connection remote host " << remote_host << ":" << remote_port;
85 ConnectionParam connparam ( (*kd), (*seq), seq_nr_, remote_host, remote_port);
86 cl.addConnection(connparam,mux);
87 NetworkAddress addr(ipv4,gOpt.getIfconfigParamRemoteNetmask().c_str());
88 NetworkPrefix prefix(addr,32);
89 gRoutingTable.addRoute(prefix,mux);
90 SyncCommand sc (cl,mux);
91 queue.push(sc);
92 SyncCommand sc2 (prefix);
93 queue.push(sc2);
96 bool checkPacketSeqNr(EncryptedPacket& pack,ConnectionParam& conn)
98 // compare sender_id and seq with window
99 if(conn.seq_window_.hasSeqNr(pack.getSenderId(), pack.getSeqNr()))
101 cLog.msg(Log::PRIO_NOTICE) << "Replay attack from " << conn.remote_host_<<":"<< conn.remote_port_
102 << " seq:"<<pack.getSeqNr() << " sid: "<<pack.getSenderId();
103 return false;
106 conn.seq_window_.addSeqNr(pack.getSenderId(), pack.getSeqNr());
107 return true;
110 void* sender(void* p)
112 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
114 std::auto_ptr<Cipher> c(CipherFactory::create(gOpt.getCipher()));
115 std::auto_ptr<AuthAlgo> a(AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
117 PlainPacket plain_packet(MAX_PACKET_LENGTH);
118 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
120 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
121 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
122 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
124 //TODO replace mux
125 u_int16_t mux = gOpt.getMux();
126 while(1)
128 plain_packet.setLength(MAX_PACKET_LENGTH);
129 encrypted_packet.withAuthTag(false);
130 encrypted_packet.setLength(MAX_PACKET_LENGTH);
132 // read packet from device
133 u_int32_t len = param->dev.read(plain_packet.getPayload(), plain_packet.getPayloadLength());
134 plain_packet.setPayloadLength(len);
135 // set payload type
136 if(param->dev.getType() == TunDevice::TYPE_TUN)
137 plain_packet.setPayloadType(PAYLOAD_TYPE_TUN);
138 else if(param->dev.getType() == TunDevice::TYPE_TAP)
139 plain_packet.setPayloadType(PAYLOAD_TYPE_TAP);
140 else
141 plain_packet.setPayloadType(0);
143 if(param->cl.empty())
144 continue;
145 //std::cout << "got Packet for plain "<<plain_packet.getDstAddr().toString();
146 mux = gRoutingTable.getRoute(plain_packet.getDstAddr());
147 //std::cout << " -> "<<mux << std::endl;
148 ConnectionMap::iterator cit = param->cl.getConnection(mux);
149 if(cit==param->cl.getEnd())
150 continue;
151 ConnectionParam & conn = cit->second;
153 if(conn.remote_host_==""||!conn.remote_port_)
154 continue;
155 // generate packet-key TODO: do this only when needed
156 conn.kd_.generate(LABEL_SATP_ENCRYPTION, conn.seq_nr_, session_key);
157 conn.kd_.generate(LABEL_SATP_SALT, conn.seq_nr_, session_salt);
159 c->setKey(session_key);
160 c->setSalt(session_salt);
162 // encrypt packet
163 c->encrypt(plain_packet, encrypted_packet, conn.seq_nr_, gOpt.getSenderId(), mux);
165 encrypted_packet.setHeader(conn.seq_nr_, gOpt.getSenderId(), mux);
166 conn.seq_nr_++;
168 // add authentication tag
169 if(a->getMaxLength()) {
170 encrypted_packet.addAuthTag();
171 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
172 a->setKey(session_auth_key);
173 a->generate(encrypted_packet);
177 param->src.send(encrypted_packet.getBuf(), encrypted_packet.getLength(), conn.remote_host_, conn.remote_port_);
179 catch (Exception e)
183 pthread_exit(NULL);
186 void* syncConnector(void* p )
188 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
190 SocketHandler h;
191 SyncClientSocket sock(h,param->cl);
192 // sock.EnableSSL();
193 sock.Open( param->connto.host, param->connto.port);
194 h.Add(&sock);
195 while (h.GetCount())
197 h.Select();
199 pthread_exit(NULL);
202 void* syncListener(void* p )
204 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
206 SyncSocketHandler h(param->queue);
207 SyncListenSocket<SyncSocket,ConnectionList> l(h,param->cl);
209 if (l.Bind(gOpt.getLocalSyncPort()))
210 pthread_exit(NULL);
212 Utility::ResolveLocal(); // resolve local hostname
213 h.Add(&l);
214 h.Select(1,0);
215 while (1) {
216 h.Select(1,0);
220 void* receiver(void* p)
222 ThreadParam* param = reinterpret_cast<ThreadParam*>(p);
224 std::auto_ptr<Cipher> c( CipherFactory::create(gOpt.getCipher()) );
225 std::auto_ptr<AuthAlgo> a( AuthAlgoFactory::create(gOpt.getAuthAlgo()) );
227 EncryptedPacket encrypted_packet(MAX_PACKET_LENGTH);
228 PlainPacket plain_packet(MAX_PACKET_LENGTH);
230 Buffer session_key(u_int32_t(SESSION_KEYLEN_ENCR)); // TODO: hardcoded size
231 Buffer session_salt(u_int32_t(SESSION_KEYLEN_SALT)); // TODO: hardcoded size
232 Buffer session_auth_key(u_int32_t(SESSION_KEYLEN_AUTH)); // TODO: hardcoded size
234 while(1)
236 string remote_host;
237 u_int16_t remote_port;
239 plain_packet.setLength(MAX_PACKET_LENGTH);
240 encrypted_packet.withAuthTag(false);
241 encrypted_packet.setLength(MAX_PACKET_LENGTH);
243 // read packet from socket
244 u_int32_t len = param->src.recv(encrypted_packet.getBuf(), encrypted_packet.getLength(), remote_host, remote_port);
245 encrypted_packet.setLength(len);
247 mux_t mux = encrypted_packet.getMux();
248 // autodetect peer
249 if(gOpt.getRemoteAddr() == "" && param->cl.empty())
251 cLog.msg(Log::PRIO_NOTICE) << "autodetected remote host " << remote_host << ":" << remote_port;
252 createConnection(remote_host, remote_port, param->cl, gOpt.getSeqWindowSize(),param->queue,mux);
255 ConnectionMap::iterator cit = param->cl.getConnection(mux);
256 if (cit == param->cl.getEnd())
257 continue;
258 ConnectionParam & conn = cit->second;
260 // check whether auth tag is ok or not
261 if(a->getMaxLength()) {
262 encrypted_packet.withAuthTag(true);
263 conn.kd_.generate(LABEL_SATP_MSG_AUTH, encrypted_packet.getSeqNr(), session_auth_key);
264 a->setKey(session_auth_key);
265 if(!a->checkTag(encrypted_packet)) {
266 cLog.msg(Log::PRIO_NOTICE) << "wrong Authentication Tag!" << std::endl;
267 continue;
269 encrypted_packet.removeAuthTag();
272 //Allow dynamic IP changes
273 //TODO: add command line option to turn this off
274 if (remote_host != conn.remote_host_ || remote_port != conn.remote_port_)
276 cLog.msg(Log::PRIO_NOTICE) << "connection "<< mux << " autodetected remote host ip changed " << remote_host << ":" << remote_port;
277 conn.remote_host_=remote_host;
278 conn.remote_port_=remote_port;
279 SyncCommand sc (param->cl,mux);
280 param->queue.push(sc);
283 // Replay Protection
284 if (!checkPacketSeqNr(encrypted_packet, conn))
285 continue;
287 // generate packet-key
288 conn.kd_.generate(LABEL_SATP_ENCRYPTION, encrypted_packet.getSeqNr(), session_key);
289 conn.kd_.generate(LABEL_SATP_SALT, encrypted_packet.getSeqNr(), session_salt);
290 c->setKey(session_key);
291 c->setSalt(session_salt);
293 // decrypt packet
294 c->decrypt(encrypted_packet, plain_packet);
296 // check payload_type
297 if((param->dev.getType() == TunDevice::TYPE_TUN && plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN4 &&
298 plain_packet.getPayloadType() != PAYLOAD_TYPE_TUN6) ||
299 (param->dev.getType() == TunDevice::TYPE_TAP && plain_packet.getPayloadType() != PAYLOAD_TYPE_TAP))
300 continue;
302 // write it on the device
303 param->dev.write(plain_packet.getPayload(), plain_packet.getLength());
305 pthread_exit(NULL);
308 #define MIN_GCRYPT_VERSION "1.2.0"
309 // make libgcrypt thread safe
310 extern "C" {
311 GCRY_THREAD_OPTION_PTHREAD_IMPL;
314 bool initLibGCrypt()
316 // make libgcrypt thread safe
317 // this must be called before any other libgcrypt call
318 gcry_control( GCRYCTL_SET_THREAD_CBS, &gcry_threads_pthread );
320 // this must be called right after the GCRYCTL_SET_THREAD_CBS command
321 // no other function must be called till now
322 if( !gcry_check_version( MIN_GCRYPT_VERSION ) ) {
323 std::cout << "initLibGCrypt: Invalid Version of libgcrypt, should be >= " << MIN_GCRYPT_VERSION << std::endl;
324 return false;
327 gcry_error_t err = gcry_control (GCRYCTL_DISABLE_SECMEM, 0);
328 if( err ) {
329 std::cout << "initLibGCrypt: Failed to disable secure memory: " << gpg_strerror( err ) << std::endl;
330 return false;
333 // Tell Libgcrypt that initialization has completed.
334 err = gcry_control(GCRYCTL_INITIALIZATION_FINISHED);
335 if( err ) {
336 std::cout << "initLibGCrypt: Failed to finish the initialization of libgcrypt: " << gpg_strerror( err ) << std::endl;
337 return false;
340 cLog.msg(Log::PRIO_NOTICE) << "initLibGCrypt: libgcrypt init finished";
341 return true;
344 void chrootAndDrop(string const& chrootdir, string const& username)
346 if (getuid() != 0)
348 std::cerr << "this programm has to be run as root in order to run in a chroot" << std::endl;
349 exit(-1);
352 struct passwd *pw = getpwnam(username.c_str());
353 if(pw) {
354 if(chroot(chrootdir.c_str()))
356 std::cerr << "can't chroot to " << chrootdir << std::endl;
357 exit(-1);
359 cLog.msg(Log::PRIO_NOTICE) << "we are in chroot jail (" << chrootdir << ") now" << std::endl;
360 chdir("/");
361 if (initgroups(pw->pw_name, pw->pw_gid) || setgid(pw->pw_gid) || setuid(pw->pw_uid))
363 std::cerr << "can't drop to user " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
364 exit(-1);
366 cLog.msg(Log::PRIO_NOTICE) << "dropped user to " << username << " " << pw->pw_uid << ":" << pw->pw_gid << std::endl;
368 else
370 std::cerr << "unknown user " << username << std::endl;
371 exit(-1);
375 void daemonize()
377 pid_t pid;
379 pid = fork();
380 if(pid) exit(0);
381 setsid();
382 pid = fork();
383 if(pid) exit(0);
385 // std::cout << "running in background now..." << std::endl;
387 int fd;
388 for (fd=getdtablesize();fd>=0;--fd) // close all file descriptors
389 close(fd);
390 fd=open("/dev/null",O_RDWR); // stdin
391 dup(fd); // stdout
392 dup(fd); // stderr
393 umask(027);
// Writes the current process id, followed by a newline, to `pidFilename`.
// The file is opened and written in one step, so call this while the path is
// still reachable (e.g. before chrooting if it lies outside the jail).
// If the file can't be opened, an error is printed to stderr and nothing is
// written.
void writePid(std::string const& pidFilename)
{
  std::ofstream pidFile(pidFilename.c_str());
  if(!pidFile.is_open()) {
    std::cerr << "can't open pid file " << pidFilename << std::endl;
    return;
  }
  pidFile << getpid() << std::endl;
}
400 int main(int argc, char* argv[])
402 // std::cout << "anytun - secure anycast tunneling protocol" << std::endl;
403 if(!gOpt.parse(argc, argv))
405 gOpt.printUsage();
406 exit(-1);
409 cLog.msg(Log::PRIO_NOTICE) << "anytun started...";
411 std::ofstream pidFile;
412 if(gOpt.getPidFile() != "") {
413 pidFile.open(gOpt.getPidFile().c_str());
414 if(!pidFile.is_open()) {
415 std::cout << "can't open pid file" << std::endl;
419 std::string dev_type(gOpt.getDevType());
420 TunDevice dev(gOpt.getDevName().c_str(), dev_type=="" ? NULL : dev_type.c_str(),
421 gOpt.getIfconfigParamLocal() =="" ? NULL : gOpt.getIfconfigParamLocal().c_str(),
422 gOpt.getIfconfigParamRemoteNetmask() =="" ? NULL : gOpt.getIfconfigParamRemoteNetmask().c_str());
423 cLog.msg(Log::PRIO_NOTICE) << "dev created (opened)";
424 cLog.msg(Log::PRIO_NOTICE) << "dev opened - actual name is '" << dev.getActualName() << "'";
425 cLog.msg(Log::PRIO_NOTICE) << "dev type is '" << dev.getTypeString() << "'";
427 if(gOpt.getChroot())
428 chrootAndDrop(gOpt.getChrootDir(), gOpt.getUsername());
429 if(gOpt.getDaemonize())
430 daemonize();
431 if(pidFile.is_open()) {
432 pid_t pid = getpid();
433 pidFile << pid;
434 pidFile.close();
437 SignalController sig;
438 sig.init();
440 PacketSource* src;
441 if(gOpt.getLocalAddr() == "")
442 src = new UDPPacketSource(gOpt.getLocalPort());
443 else
444 src = new UDPPacketSource(gOpt.getLocalAddr(), gOpt.getLocalPort());
446 ConnectionList cl;
447 ConnectToList connect_to = gOpt.getConnectTo();
448 SyncQueue queue;
450 if(gOpt.getRemoteAddr() != "")
451 createConnection(gOpt.getRemoteAddr(),gOpt.getRemotePort(),cl,gOpt.getSeqWindowSize(), queue, gOpt.getMux());
453 ThreadParam p(dev, *src, cl, queue,*(new OptionConnectTo()));
455 // this must be called before any other libgcrypt call
456 if(!initLibGCrypt())
457 return -1;
459 pthread_t senderThread;
460 pthread_create(&senderThread, NULL, sender, &p);
461 pthread_t receiverThread;
462 pthread_create(&receiverThread, NULL, receiver, &p);
464 pthread_t syncListenerThread;
465 if ( gOpt.getLocalSyncPort())
466 pthread_create(&syncListenerThread, NULL, syncListener, &p);
468 std::list<pthread_t> connectThreads;
469 for(ConnectToList::iterator it = connect_to.begin() ;it != connect_to.end(); ++it)
471 connectThreads.push_back(pthread_t());
472 ThreadParam * point = new ThreadParam(dev, *src, cl, queue,*it);
473 pthread_create(& connectThreads.back(), NULL, syncConnector, point);
476 int ret = sig.run();
478 pthread_cancel(senderThread);
479 pthread_cancel(receiverThread);
480 if ( gOpt.getLocalSyncPort())
481 pthread_cancel(syncListenerThread);
482 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
483 pthread_cancel(*it);
485 pthread_join(senderThread, NULL);
486 pthread_join(receiverThread, NULL);
487 if ( gOpt.getLocalSyncPort())
488 pthread_join(syncListenerThread, NULL);
490 for( std::list<pthread_t>::iterator it = connectThreads.begin() ;it != connectThreads.end(); ++it)
491 pthread_join(*it, NULL);
493 delete src;
494 delete &p.connto;
496 return ret;