fix compat with py3
[rofl0r-nat-tunnel.git] / natsrv.py
blobc587db0d8462bcd83573ecba300c3dde3b1c9d22
1 from __future__ import print_function
2 import socket, select, os, threading, hashlib, rocksock, time, sys, codecs
PY3 = sys.version_info[0] == 3


def _b(a, b):
    """Return *a* as a byte string.

    a -- text to convert, or data that already is bytes/bytearray.
    b -- name of the codec used to encode text on Python 3 (e.g. 'utf-8').

    On Python 2 the native ``str`` type already is a byte string, so the
    value is only passed through ``bytes()`` and *b* is ignored.
    """
    # Already binary: hand it back untouched. Python 3's bytes(a, b) would
    # raise "encoding without a string argument" here, which made the helper
    # unsafe to call on values of unknown origin.
    if isinstance(a, (bytes, bytearray)):
        return bytes(a)
    if PY3:
        return bytes(a, b)
    return bytes(a)
# Size of a nonce in raw random bytes; its hex form is twice as long.
NONCE_LEN = 8


def _get_nonce():
    """Return a fresh random nonce as a hex-encoded byte string.

    NONCE_LEN bytes are read from os.urandom() and hex-encoded, so the
    result is NONCE_LEN * 2 characters long.
    """
    raw = os.urandom(NONCE_LEN)
    return codecs.encode(raw, 'hex')
def _hash(str):
    """Return the SHA-256 digest of the byte string *str* as hex digits.

    The result is itself a byte string (64 hex characters), so it can be
    concatenated with other bytes and written to a socket directly.
    """
    digest_text = hashlib.sha256(str).hexdigest()
    # The hex digest is pure ASCII, so encoding it is lossless.
    return digest_text.encode('utf-8')
def _format_addr(addr):
    """Format a socket address as the byte string b"host:port".

    addr -- address tuple whose first two items are host and port, e.g. the
            second value returned by socket.accept(). IPv6 sockets report
            4-tuples (host, port, flowinfo, scope_id); the extra fields are
            ignored instead of breaking the tuple unpacking.
    """
    host, port = addr[0], addr[1]
    # Text on Python 3 has to be encoded before it can be interpolated into
    # a bytes template; Python 2 strings are bytes already.
    if not isinstance(host, bytes):
        host = host.encode('utf-8')
    return b"%s:%d" % (host, port)
def _timestamp():
    """Return the current local time as a bytes log prefix.

    The format is b"[YYYY-mm-dd HH:MM:SS] ", including the trailing space.
    """
    # strftime() without a time tuple formats the current local time.
    stamp = time.strftime('[%Y-%m-%d %H:%M:%S] ')
    return stamp.encode('utf-8')
class Tunnel():
    """Relay bytes in both directions between two connected sockets.

    The relay runs on a daemon thread created by start(). It stops as soon
    as either socket reports end-of-stream or an error; both sockets are
    then closed and finished() starts returning True.
    """

    def __init__(self, fds, fdc, caddr):
        """Remember the two endpoints; no thread is started yet.

        fds, fdc -- the two connected socket objects to bridge.
        caddr    -- address associated with this tunnel. It is only stored
                    on the instance; the relay itself never uses it.
        """
        self.fds = fds
        self.fdc = fdc
        self.caddr = caddr
        self.done = threading.Event()
        self.t = None

    def _cleanup(self):
        """Close both sockets (if still open) and drop the references."""
        for sock in (self.fdc, self.fds):
            if sock is None:
                continue
            try:
                sock.close()
            except socket.error:
                # Best effort: a failing close must not prevent closing
                # the other endpoint.
                pass
        self.fdc = None
        self.fds = None

    def _threadfunc(self):
        """Thread body: copy data until one side closes or fails."""
        try:
            while True:
                try:
                    readable, _, _ = select.select([self.fds, self.fdc], [], [])
                    # Only the first ready socket is served per iteration;
                    # the other one is picked up by the next select().
                    src = readable[0]
                    buf = src.recv(1024)
                except (socket.error, select.error, ValueError):
                    break
                if not buf:
                    # Empty read means the peer closed its end.
                    break
                dst = self.fdc if src is self.fds else self.fds
                try:
                    # sendall() keeps writing until the whole chunk is out;
                    # plain send() may write only a part and drop the rest.
                    dst.sendall(buf)
                except socket.error:
                    break
        finally:
            # Always release the sockets and signal completion, even when
            # the loop dies from an unexpected exception, so the owner can
            # still reap this tunnel.
            self._cleanup()
            self.done.set()

    def start(self):
        """Start relaying on a new daemon thread."""
        self.t = threading.Thread(target=self._threadfunc)
        self.t.daemon = True
        self.t.start()

    def finished(self):
        """Return True once the relay thread has shut the tunnel down."""
        return self.done.is_set()

    def reap(self):
        """Wait for the relay thread to exit; no-op if it never started."""
        if self.t is not None:
            self.t.join()
class NATClient():
    """Client half of the NAT tunnel.

    Holds one authenticated control connection to the server plus one spare,
    already authenticated tunnel connection. Whenever the control connection
    delivers a line starting with b"CONN:", the spare connection is bridged
    to a new connection to the local service and a fresh spare is opened.
    """

    def __init__(self, secret, upstream_ip, upstream_port, localserv_ip, localserv_port):
        """Store the configuration; nothing is connected until setup().

        secret         -- shared secret as bytes (it is concatenated with
                          bytes when the handshake hash is built).
        upstream_ip    -- host of the server's admin/control port.
        upstream_port  -- port of the server's admin/control port.
        localserv_ip   -- host of the local service to expose.
        localserv_port -- port of the local service to expose.
        """
        self.secret = secret
        self.localserv_ip = localserv_ip
        self.localserv_port = localserv_port
        self.upstream_ip = upstream_ip
        self.upstream_port = upstream_port
        self.controlsock = None
        self.next_csock = None
        self.threads = []

    @staticmethod
    def _to_text(data):
        """Return *data* in a form print() shows verbatim.

        On Python 3, printing a bytes object shows its repr (b'...'), so
        bytes are decoded first; undecodable bytes become backslash escapes
        rather than raising. Native strings are returned unchanged.
        """
        if isinstance(data, str):
            return data
        return data.decode('ascii', 'backslashreplace')

    def _setup_sock(self, cmd):
        """Connect to the server and authenticate the connection as *cmd*.

        cmd -- bytes tag naming the connection's role; this class uses
               b'adm' for the control channel and b'skt' for a tunnel.

        Reads the nonce line sent by the server and answers with the hex
        SHA-256 of cmd + secret + nonce. Returns the connected Rocksock.
        """
        sock = rocksock.Rocksock(host=self.upstream_ip, port=self.upstream_port)
        sock.connect()
        # The nonce arrives as NONCE_LEN*2 hex characters plus a newline.
        nonce = sock.recv(NONCE_LEN*2 + 1).rstrip(b'\n')
        sock.send(_hash(cmd + self.secret + nonce) + b'\n')
        return sock

    def setup(self):
        """Open the control connection and the first spare tunnel."""
        self.controlsock = self._setup_sock(b'adm')
        self.next_csock = self._setup_sock(b'skt')

    def _reap_finished(self):
        """Join tunnels whose relay has ended and forget about them."""
        still_running = []
        for tunnel in self.threads:
            if tunnel.finished():
                tunnel.reap()
            else:
                still_running.append(tunnel)
        self.threads = still_running

    def doit(self):
        """Serve control messages forever (does not return normally)."""
        while True:
            self._reap_finished()

            l = self.controlsock.recvline()
            line = l.rstrip(b'\n')
            print(self._to_text(_timestamp() + line))
            if line.startswith(b'CONN:'):
                # Second field of the announcement is the remote host.
                addr = line.split(b':')[1]
                local_conn = rocksock.Rocksock(host=self.localserv_ip, port=self.localserv_port)
                local_conn.connect()
                # Bridge the local service with the spare tunnel socket,
                # then immediately prepare a spare for the next client.
                thread = Tunnel(local_conn.sock, self.next_csock.sock, addr)
                thread.start()
                self.threads.append(thread)
                self.next_csock = self._setup_sock(b'skt')
class NATSrv():
    """Server half of the NAT tunnel.

    Listens on two ports: an "upstream" port on which the tunnel client
    authenticates its control and tunnel connections, and a "client" port
    for regular users. Each accepted user connection is announced on the
    control connection and bridged to the spare upstream socket.
    """

    def _isnumericipv4(self, ip):
        """Return True if *ip* is a dotted quad with every part in 0..255."""
        try:
            octets = [int(part) for part in ip.split('.')]
        except (AttributeError, TypeError, ValueError):
            return False
        # Negative parts are rejected too; "-1.2.3.4" is not an address.
        return len(octets) == 4 and all(0 <= octet < 256 for octet in octets)

    def _resolve(self, host, port, want_v4=True):
        """Resolve host/port to (address_family, sockaddr).

        Numeric IPv4 addresses are returned as-is without a lookup.
        Otherwise the first getaddrinfo() result that is IPv4 (or, when
        want_v4 is false, IPv4 or IPv6) wins. Returns (None, None) when no
        result qualifies.
        """
        if self._isnumericipv4(host):
            return socket.AF_INET, (host, port)
        candidates = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
                                        socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
        for af, socktype, proto, canonname, sa in candidates:
            if want_v4 and af != socket.AF_INET:
                continue
            if af in (socket.AF_INET, socket.AF_INET6):
                return af, sa
        return None, None

    def __init__(self, secret, upstream_listen_ip, upstream_port, client_listen_ip, client_port):
        """Store the configuration; no socket is opened until setup().

        secret             -- shared secret as bytes (it is concatenated
                              with bytes when handshake hashes are built).
        upstream_listen_ip -- address to listen on for the tunnel client.
        upstream_port      -- port to listen on for the tunnel client.
        client_listen_ip   -- address to listen on for regular users.
        client_port        -- port to listen on for regular users.
        """
        self.up_port = upstream_port
        self.up_ip = upstream_listen_ip
        self.client_port = client_port
        self.client_ip = client_listen_ip
        self.secret = secret
        self.threads = []
        self.su = None
        self.sc = None
        self.control_socket = None
        self.next_upstream_socket = None
        # Length of a handshake answer; every hash has the same length.
        self.hashlen = len(_hash(b''))

    @staticmethod
    def _to_text(data):
        """Return *data* in a form print() shows verbatim.

        On Python 3, printing a bytes object shows its repr (b'...'), so
        bytes are decoded first; undecodable bytes become backslash escapes
        rather than raising. Native strings are returned unchanged.
        """
        if isinstance(data, str):
            return data
        return data.decode('ascii', 'backslashreplace')

    def _setup_listen_socket(self, listenip, port):
        """Create, bind and return a listening TCP socket.

        Raises socket.error if the address cannot be resolved or bound.
        """
        af, sa = self._resolve(listenip, port)
        if af is None:
            raise socket.error("cannot resolve listen address %s:%s" % (listenip, port))
        s = socket.socket(af, socket.SOCK_STREAM)
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind((sa[0], sa[1]))
            s.listen(1)
        except Exception:
            # Do not leak the descriptor when binding fails.
            s.close()
            raise
        return s

    def setup(self):
        """Open both listening sockets."""
        self.su = self._setup_listen_socket(self.up_ip, self.up_port)
        self.sc = self._setup_listen_socket(self.client_ip, self.client_port)

    def wait_conn_up(self):
        """Accept one connection on the upstream port and authenticate it.

        The peer receives a random nonce and must answer with the hex
        SHA-256 of role + secret + nonce, where role is b'adm' (becomes the
        control socket, replacing any previous one) or b'skt' (becomes the
        spare tunnel socket, but only while a control socket exists).
        Anything else is rejected and closed.
        """
        conn, addr = self.su.accept()
        nonce = _get_nonce()
        banner = _timestamp() + b"CONN: %s (nonce: %s) ... " % (_format_addr(addr), nonce)
        print(self._to_text(banner), end='')
        try:
            conn.sendall(nonce + b'\n')
            cmd = conn.recv(1 + self.hashlen).rstrip(b'\n')
        except socket.error:
            # Peer vanished mid-handshake. An empty answer never equals a
            # hash, so this falls through to the "rejected!" branch instead
            # of taking the whole server down.
            cmd = b''
        if cmd == _hash(b'adm' + self.secret + nonce):
            if self.control_socket:
                self.control_socket.close()
            self.control_socket = conn
            print("OK (admin)")
        elif cmd == _hash(b'skt' + self.secret + nonce):
            print("OK (tunnel)")
            if not self.control_socket:
                conn.close()
            else:
                self.next_upstream_socket = conn
        else:
            print("rejected!")
            conn.close()

    def wait_conn_client(self):
        """Accept one user connection and bridge it to the spare tunnel."""
        conn, addr = self.sc.accept()
        # Tell the tunnel client who connected; it reacts by attaching its
        # side of the spare socket to the local service.
        self.control_socket.sendall(b"CONN:%s\n" % _format_addr(addr))
        thread = Tunnel(self.next_upstream_socket, conn, addr)
        thread.start()
        self.threads.append(thread)
        # The spare is now owned by the tunnel; doit() waits for a new one.
        self.next_upstream_socket = None

    def _reap_finished(self):
        """Join tunnels whose relay has ended and forget about them."""
        still_running = []
        for tunnel in self.threads:
            if tunnel.finished():
                tunnel.reap()
            else:
                still_running.append(tunnel)
        self.threads = still_running

    def doit(self):
        """Run the server loop forever (does not return normally)."""
        while True:
            self._reap_finished()
            if not self.control_socket:
                self.wait_conn_up()
            if not self.next_upstream_socket:
                self.wait_conn_up()
            if not (self.control_socket and self.next_upstream_socket):
                continue
            # The spare upstream socket has to be watched as well, otherwise
            # the "lost spare upstream socket" branch below can never fire.
            # Neither upstream socket is read from here after the handshake,
            # so becoming readable is treated as the connection being lost.
            watched = [self.sc, self.control_socket, self.next_upstream_socket]
            readable, _, _ = select.select(watched, [], [])
            if self.control_socket in readable:
                print("lost control socket")
                self.control_socket.close()
                self.control_socket = None
                continue
            if self.next_upstream_socket in readable:
                print("lost spare upstream socket")
                self.next_upstream_socket.close()
                self.next_upstream_socket = None
                continue
            if self.sc in readable:
                self.wait_conn_client()
# Command line entry point: parse the options and run either the server or
# the client until interrupted.
if __name__ == "__main__":
    import argparse
    desc=(
        "NAT Tunnel v0.01\n"
        "----------------\n"
        "If you have access to a server with public IP and unfiltered ports\n"
        "you can run NAT Tunnel (NT) server on the server, and NT client\n"
        "on your box behind NAT.\n"
        "the server requires 2 open ports: one for communication with the\n"
        "NT client (--admin), the other for regular clients to connect to\n"
        "(--public: this is the port you want your users to use).\n"
        "\n"
        "The NT client opens a connection to the server's admin ip/port.\n"
        "As soon as the server receives a new connection, it signals the\n"
        "NT client, which then creates a new tunnel connection to the\n"
        "server, which is then connected to the desired service on the\n"
        "NT client's side (--local)\n"
        "\n"
        "The connection between NT Client and NT Server on the admin\n"
        "interface is protected by a shared secret against unauthorized use.\n"
        "An adversary who can intercept packets could crack the secret\n"
        "if it's of insufficient complexity. At least 10 random\n"
        "characters and numbers are recommended.\n"
        "\n"
        "Example:\n"
        "You have a HTTP server listening on your local machine on port 80.\n"
        "You want to make it available on your cloud server/VPS/etc's public\n"
        "IP on port 7000.\n"
        "We use port 8000 on the cloud server for the control channel.\n"
        "\n"
        "Server:\n"
        " %s --mode server --secret s3cretP4ss --public 0.0.0.0:7000 --admin 0.0.0.0:8000\n"
        "Client:\n"
        " %s --mode client --secret s3cretP4ss --local localhost:80 --admin example.com:8000\n"
    ) % (sys.argv[0], sys.argv[0])
    # The long description is printed by hand, in addition to argparse's
    # own usage/help output.
    if len(sys.argv) < 2 or sys.argv[1] in ('-h', '--help'):
        print(desc)
    parser = argparse.ArgumentParser(description='')
    # Required options take no default: argparse would never use it.
    parser.add_argument('--secret', help='shared secret between natserver/client', type=str, required=True)
    # choices= rejects typos such as "sever", which previously fell through
    # to client mode silently.
    parser.add_argument('--mode', help='work mode: server or client', type=str, choices=('server', 'client'), required=True)
    parser.add_argument('--public', help='(server only) ip:port where we will listen for regular clients', type=str, default='0.0.0.0:8080', required=False)
    parser.add_argument('--local', help='(client only) ip:port of the local target service', type=str, default="localhost:80", required=False)
    parser.add_argument('--admin', help='ip:port tuple for admin/upstream/control connection', type=str, default="0.0.0.0:8081", required=False)
    args = parser.parse_args()

    def _split_hostport(value, option):
        """Split "host:port" at the last colon; exit with usage on error."""
        host, sep, port = value.rpartition(':')
        try:
            if not sep:
                raise ValueError(value)
            return host, int(port)
        except ValueError:
            parser.error("%s expects ip:port, got %r" % (option, value))

    adminip, adminport = _split_hostport(args.admin, '--admin')
    if args.mode == 'server':
        clientip, clientport = _split_hostport(args.public, '--public')
        srv = NATSrv(_b(args.secret, 'utf-8'), adminip, adminport, clientip, clientport)
        srv.setup()
        srv.doit()
    else:
        localip, localport = _split_hostport(args.local, '--local')
        cl = NATClient(_b(args.secret, 'utf-8'), adminip, adminport, localip, localport)
        cl.setup()
        cl.doit()
269 cl.doit()