1 # vim: set ts=4 et sw=4 tw=80
2 # This Source Code Form is subject to the terms of the Mozilla Public
3 # License, v. 2.0. If a copy of the MPL was not distributed with this
4 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
import copy
import datetime
import hashlib
import hmac
import ipaddress
import operator
import os
import platform
import random
import socket
import time
from functools import reduce
from string import Template

import passlib.utils  # for saslprep
from twisted.internet import protocol, reactor
from twisted.internet.address import IPv4Address
from twisted.internet.address import IPv6Address
from twisted.internet.task import LoopingCall
# Constants from the STUN (RFC 5389) and TURN (RFC 5766) specifications.
MAGIC_COOKIE = 0x2112A442

# Message classes: the two class bits (C1, C0) of the message type, read as a
# 2-bit integer.
REQUEST = 0
INDICATION = 1
SUCCESS_RESPONSE = 2
ERROR_RESPONSE = 3

# Methods
BINDING = 0x001
ALLOCATE = 0x003
REFRESH = 0x004
SEND = 0x006
DATA_MSG = 0x007
CREATE_PERMISSION = 0x008

# STUN spec chose silly values for these
STUN_IPV4 = 1
STUN_IPV6 = 2

# Attribute types
MAPPED_ADDRESS = 0x0001
USERNAME = 0x0006
MESSAGE_INTEGRITY = 0x0008
ERROR_CODE = 0x0009
UNKNOWN_ATTRIBUTES = 0x000A
LIFETIME = 0x000D
XOR_PEER_ADDRESS = 0x0012
DATA_ATTR = 0x0013
REALM = 0x0014
NONCE = 0x0015
XOR_RELAYED_ADDRESS = 0x0016
REQUESTED_TRANSPORT = 0x0019
DONT_FRAGMENT = 0x001A
XOR_MAPPED_ADDRESS = 0x0020
ALTERNATE_SERVER = 0x8023

# Listen ports. The redirect ports answer authenticated requests with a
# 300 (Try Alternate) response instead of serving them.
STUN_PORT = 3478
STUNS_PORT = 5349
TURN_REDIRECT_PORT = 3479
TURNS_REDIRECT_PORT = 5350
def unpack_uint(bytes_buf):
    """
    Decode bytes_buf as a big-endian unsigned integer.

    An empty buffer decodes to 0.
    """
    return int.from_bytes(bytes(bytes_buf), "big")
def pack_uint(value, width):
    """
    Encode value as an unsigned big-endian integer of exactly width bytes.

    Returns a bytearray of length width.
    Raises ValueError if value is negative or does not fit in width bytes
    (rather than silently dropping the high bits).
    """
    if value < 0 or value >> (8 * width):
        raise ValueError("Invalid value: {}".format(value))
    return bytearray(value.to_bytes(width, "big"))
def unpack(bytes_buf, format_array):
    """
    Split bytes_buf into consecutive big-endian unsigned integer fields.

    format_array lists the width in bytes of each field, in order. Returns a
    tuple with one integer per entry in format_array. A field extending past
    the end of bytes_buf is decoded from whatever bytes remain (none gives 0).
    """
    results = []
    offset = 0
    for width in format_array:
        field = bytes(bytes_buf[offset : offset + width])
        results.append(int.from_bytes(field, "big"))
        offset += width
    return tuple(results)
def pack(values, format_array):
    """
    Concatenate values as big-endian unsigned integers.

    format_array gives the width in bytes of each value, in order.
    Returns a bytearray.
    Raises ValueError if the two sequences differ in length, or if a value
    does not fit in its width (via pack_uint).
    """
    if len(values) != len(format_array):
        raise ValueError(
            "Expected {} values, got {}".format(len(format_array), len(values))
        )
    buf = bytearray()
    for value, width in zip(values, format_array):
        buf.extend(pack_uint(value, width))
    return buf
def bitwise_pack(source, dest, start_bit, num_bits):
    """
    Append a run of bits taken from source onto the low end of dest.

    The run is bits start_bit down to start_bit - num_bits + 1 of source
    (bit 0 is the least significant). dest is shifted left by num_bits and the
    run is placed in the vacated low bits. Returns the new dest.

    Raises ValueError unless 1 <= num_bits <= start_bit + 1.
    """
    if num_bits <= 0 or num_bits > start_bit + 1:
        raise ValueError(
            "Invalid num_bits: {}, start_bit = {}".format(num_bits, start_bit)
        )
    last_bit = start_bit - num_bits + 1
    mask = (1 << num_bits) - 1
    return (dest << num_bits) + ((source >> last_bit) & mask)
def to_ipaddress(protocol, host, port):
    """
    Wrap host/port in the matching twisted address type.

    host is an IP literal: one containing ":" is treated as IPv6, anything
    else as IPv4. protocol is the transport label stored on the address
    (e.g. "UDP").
    """
    if ":" in host:
        return IPv6Address(protocol, host, port)
    return IPv4Address(protocol, host, port)
class StunAttribute(object):
    """
    Represents a STUN attribute in a raw format, according to the following:

     0                   1                   2                   3
     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |   StunAttribute.attr_type     |  Length (derived as needed)   |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |           StunAttribute.data (variable length)             ....
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+

    The Length field counts data only. On the wire, data is followed by zero
    padding up to a multiple of 4 bytes.
    """

    __attr_header_fmt = [2, 2]
    __attr_header_size = sum(__attr_header_fmt)

    def __init__(self, attr_type=0, buf=None):
        self.attr_type = attr_type
        # Use a fresh bytearray per instance. A shared default object would
        # be seen (and mutated) by every attribute created without data.
        self.data = bytearray() if buf is None else buf

    def build(self):
        """Serialize this attribute, including trailing padding, to a bytearray."""
        buf = pack((self.attr_type, len(self.data)), self.__attr_header_fmt)
        buf.extend(self.data)
        # Pad to a 4-byte boundary; adds nothing when already aligned.
        buf.extend(bytearray(-len(buf) % 4))
        return buf

    def parse(self, buf):
        """
        Parse one attribute from the start of buf into attr_type and data.

        Returns the number of bytes consumed, padding included.
        Raises Exception if buf is truncated, ValueError on non-zero padding.
        """
        if self.__attr_header_size > len(buf):
            raise Exception("truncated at attribute: incomplete header")

        self.attr_type, length = unpack(buf, self.__attr_header_fmt)
        length += self.__attr_header_size

        if length > len(buf):
            raise Exception("truncated at attribute: incomplete contents")

        self.data = buf[self.__attr_header_size : length]

        # Verify the padding up to the next 4-byte boundary.
        padded_length = length + (-length % 4)
        if padded_length > len(buf):
            raise Exception("truncated at attribute: incomplete padding")
        if any(buf[length:padded_length]):
            raise ValueError("Non-zero padding")

        return padded_length
class StunMessage(object):
    """
    Represents a STUN message. Contains a method, msg_class, cookie,
    transaction_id, and attributes (as an array of StunAttribute).

    Has various functions for getting/adding attributes.
    """

    def __init__(self):
        self.method = 0
        self.msg_class = 0
        self.cookie = MAGIC_COOKIE
        self.transaction_id = 0
        self.attributes = []

    #      0                   1                   2                   3
    #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |0 0|M M M M M|C|M M M|C|M M M M|         Message Length        |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                         Magic Cookie                          |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                                                               |
    #     |                     Transaction ID (96 bits)                  |
    #     |                                                               |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __header_fmt = [2, 2, 4, 12]
    __header_size = sum(__header_fmt)

    def parse(self, buf):
        """
        Parse one STUN message from the start of buf into this object.

        Returns how many bytes were parsed if buf was large enough, or how many
        bytes would have been needed if not; callers compare the result with
        len(buf) to tell the two apart. Raises if buf is malformed (wrong
        cookie, truncated or badly padded attributes).
        """
        min_buf_size = self.__header_size
        if len(buf) < min_buf_size:
            return min_buf_size

        message_type, length, cookie, self.transaction_id = unpack(
            buf, self.__header_fmt
        )
        min_buf_size += length
        if len(buf) < min_buf_size:
            return min_buf_size

        # Method (M) and class (C) bits are interleaved in the message type
        # (see the diagram above); gather each back into one integer.
        self.method = bitwise_pack(message_type, 0, 13, 5)
        self.msg_class = bitwise_pack(message_type, 0, 8, 1)
        self.method = bitwise_pack(message_type, self.method, 7, 3)
        self.msg_class = bitwise_pack(message_type, self.msg_class, 4, 1)
        self.method = bitwise_pack(message_type, self.method, 3, 4)

        if cookie != self.cookie:
            raise Exception("Invalid cookie: {}".format(cookie))

        buf = buf[self.__header_size : min_buf_size]
        while buf:
            attr = StunAttribute()
            consumed = attr.parse(buf)
            buf = buf[consumed:]
            self.attributes.append(attr)

        return min_buf_size

    def build(self, stop_after_attr_type=0):
        """
        Serialize this message to a bytearray.

        Serialization stops right after the first attribute whose type equals
        stop_after_attr_type, and the header's length field covers only the
        attributes written. This is what calculate_message_digest needs for
        MESSAGE-INTEGRITY.
        """
        attrs = bytearray()
        for attr in self.attributes:
            attrs.extend(attr.build())
            if attr.attr_type == stop_after_attr_type:
                break

        # Interleave the method and class bits into the message type.
        message_type = bitwise_pack(self.method, 0, 11, 5)
        message_type = bitwise_pack(self.msg_class, message_type, 1, 1)
        message_type = bitwise_pack(self.method, message_type, 6, 3)
        message_type = bitwise_pack(self.msg_class, message_type, 0, 1)
        message_type = bitwise_pack(self.method, message_type, 3, 4)

        message = pack(
            (message_type, len(attrs), self.cookie, self.transaction_id),
            self.__header_fmt,
        )
        message.extend(attrs)
        return message

    def add_error_code(self, code, phrase=None):
        """Append an ERROR-CODE attribute for code, with an optional reason phrase."""
        #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        #     |           Reserved, should be 0         |Class|     Number    |
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        #     |      Reason Phrase (variable)                                ..
        #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
        error_code_fmt = [3, 1]
        error_code = pack((code // 100, code % 100), error_code_fmt)
        if phrase is not None:
            error_code.extend(bytearray(phrase, "utf-8"))
        self.attributes.append(StunAttribute(ERROR_CODE, error_code))

    #      0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |x x x x x x x x|    Family     |         X-Port                |
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    #     |                X-Address (Variable)
    #     +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    __v4addr_fmt = [1, 1, 2, 4]
    __v6addr_fmt = [1, 1, 2, 16]
    __v4addr_size = sum(__v4addr_fmt)
    __v6addr_size = sum(__v6addr_fmt)

    def add_address(self, ip_address, version, port, attr_type):
        """
        Append an address attribute of type attr_type.

        ip_address and port are integers written as-is; callers XOR them first
        when building XOR-* attributes. version is STUN_IPV4 or STUN_IPV6.
        """
        if version == STUN_IPV4:
            address = pack((0, STUN_IPV4, port, ip_address), self.__v4addr_fmt)
        elif version == STUN_IPV6:
            address = pack((0, STUN_IPV6, port, ip_address), self.__v6addr_fmt)
        else:
            raise ValueError("Invalid ip version: {}".format(version))
        self.attributes.append(StunAttribute(attr_type, address))

    def get_xaddr(self, ip_addr, version):
        """
        XOR an integer IP address with the cookie (IPv4), or with the cookie
        followed by the transaction ID (IPv6). Applying it twice restores the
        original value.
        """
        if version == STUN_IPV4:
            return self.cookie ^ ip_addr
        elif version == STUN_IPV6:
            return ((self.cookie << 96) + self.transaction_id) ^ ip_addr
        raise ValueError("Invalid family: {}".format(version))

    def get_xport(self, port):
        """XOR a port with the top 16 bits of the cookie (self-inverse)."""
        return (self.cookie >> 16) ^ port

    def add_xor_address(self, addr_port, attr_type):
        """
        Append an XOR-encoded address attribute of type attr_type.

        addr_port must provide .host (an IP literal) and .port, e.g. a twisted
        IPv4Address/IPv6Address.
        """
        ip_address = ipaddress.ip_address(addr_port.host)
        version = STUN_IPV6 if ip_address.version == 6 else STUN_IPV4
        xaddr = self.get_xaddr(int(ip_address), version)
        xport = self.get_xport(addr_port.port)
        self.add_address(xaddr, version, xport, attr_type)

    def add_data(self, buf):
        self.attributes.append(StunAttribute(DATA_ATTR, buf))

    def find(self, attr_type):
        """Return the first attribute of type attr_type, or None."""
        for attr in self.attributes:
            if attr.attr_type == attr_type:
                return attr
        return None

    def get_xor_address(self, attr_type):
        """
        Decode the XOR-encoded address attribute of type attr_type.

        Returns a twisted IPv4Address/IPv6Address (protocol "UDP"), or None if
        the attribute is absent. Raises ValueError for an unknown family.
        """
        addr_attr = self.find(attr_type)
        if addr_attr is None:
            return None

        _, family, xport, xaddr = unpack(addr_attr.data, self.__v4addr_fmt)
        if family == STUN_IPV4:
            addr_ctor, ip_ctor = IPv4Address, ipaddress.IPv4Address
        elif family == STUN_IPV6:
            _, family, xport, xaddr = unpack(addr_attr.data, self.__v6addr_fmt)
            addr_ctor, ip_ctor = IPv6Address, ipaddress.IPv6Address
        else:
            raise ValueError("Invalid family: {}".format(family))

        # Build the IP from the family explicitly: guessing from the integer
        # value would misread small IPv6 addresses (e.g. ::1) as IPv4.
        host = str(ip_ctor(self.get_xaddr(xaddr, family)))
        return addr_ctor("UDP", host, self.get_xport(xport))

    def add_nonce(self, nonce):
        self.attributes.append(StunAttribute(NONCE, bytearray(nonce, "utf-8")))

    def add_realm(self, realm):
        self.attributes.append(StunAttribute(REALM, bytearray(realm, "utf-8")))

    def calculate_message_digest(self, username, realm, password):
        """
        Compute the MESSAGE-INTEGRITY value (HMAC-SHA1) for this message.

        The key is MD5("username:realm:SASLprep(password)"). The HMAC covers
        the message up to, but excluding, the first MESSAGE-INTEGRITY
        attribute, which must already be present. The header length still
        counts that attribute.
        """
        digest_buf = self.build(MESSAGE_INTEGRITY)
        # Trim off the MESSAGE-INTEGRITY attr: 4-byte header + 20-byte HMAC.
        digest_buf = digest_buf[: len(digest_buf) - 24]
        password = passlib.utils.saslprep(str(password))
        key_string = "{}:{}:{}".format(username, realm, password)
        key = hashlib.md5(key_string.encode("utf-8")).digest()
        return bytearray(hmac.new(key, bytes(digest_buf), hashlib.sha1).digest())

    def add_lifetime(self, lifetime):
        self.attributes.append(StunAttribute(LIFETIME, pack_uint(lifetime, 4)))

    def get_lifetime(self):
        """Return the LIFETIME attribute in seconds, or None if absent."""
        lifetime_attr = self.find(LIFETIME)
        if lifetime_attr is None:
            return None
        return unpack_uint(lifetime_attr.data[0:4])

    def get_username(self):
        """Return the USERNAME attribute decoded as UTF-8, or None if absent."""
        username = self.find(USERNAME)
        if username is None:
            return None
        # str(bytearray) would give "bytearray(b'...')", not the name.
        return bytes(username.data).decode("utf-8", errors="replace")

    def add_message_integrity(self, username, realm, password):
        """Append a MESSAGE-INTEGRITY attribute computed over this message."""
        # Append a zeroed placeholder first, so the digest is computed over a
        # header whose length already includes this attribute.
        integrity = StunAttribute(MESSAGE_INTEGRITY, bytearray(20))
        self.attributes.append(integrity)
        integrity.data = self.calculate_message_digest(username, realm, password)

    def add_alternate_server(self, host, port):
        """Append a (non-XOR) ALTERNATE-SERVER attribute for host:port."""
        address = ipaddress.ip_address(host)
        version = STUN_IPV6 if address.version == 6 else STUN_IPV4
        self.add_address(int(address), version, port, ALTERNATE_SERVER)
class Allocation(protocol.DatagramProtocol):
    """
    Comprises the socket for a TURN allocation, a back-reference to the
    transport we will forward received traffic on, the allocator's address and
    username, the set of permissions for the allocation, and the allocation's
    expiry.
    """

    def __init__(self, other_transport_handler, allocator_address, username):
        # Peer host strings permitted to send through this allocation.
        self.permissions = set()
        # Handler to use when sending stuff that arrives on the allocation
        self.other_transport_handler = other_transport_handler
        self.allocator_address = allocator_address
        self.username = username
        # Absolute deadline (time.time() scale); allocate/refresh push it out.
        self.expiry = time.time()
        # Relay socket on an ephemeral port on the IPv4 default-route address.
        self.port = reactor.listenUDP(0, self, interface=v4_address)

    def datagramReceived(self, data, address):
        """Relay a datagram from a permitted peer to the client as a Data indication."""
        host, port = address[0], address[1]
        if host not in self.permissions:
            print(
                "Dropping packet from {}:{}, no permission on allocation {}".format(
                    host, port, self.transport.getHost()
                )
            )
            return

        data_indication = StunMessage()
        data_indication.method = DATA_MSG
        data_indication.msg_class = INDICATION
        data_indication.transaction_id = random.getrandbits(96)

        # Only handles UDP allocations. Doubtful that we need more than this.
        data_indication.add_xor_address(
            to_ipaddress("UDP", host, port), XOR_PEER_ADDRESS
        )
        data_indication.add_data(data)

        self.other_transport_handler.write(
            data_indication.build(), self.allocator_address
        )

    def close(self):
        """Stop listening on the relay socket; safe to call more than once."""
        if self.port is not None:
            self.port.stopListening()
            self.port = None
class StunHandler(object):
    """
    Frames and handles STUN messages. This is the core logic of the TURN
    server, along with Allocation.

    Relies on the module-level allocations dict and the turn_user,
    turn_realm and turn_pass credentials.
    """

    def __init__(self, transport_handler):
        # Source address of the message currently being handled.
        self.client_address = None
        # Received bytes not yet consumed by a complete STUN message.
        self.data = bytearray()
        # Used for write(data, address) and transport.getHost().
        self.transport_handler = transport_handler

    def data_received(self, data, address):
        """
        Buffer data and handle every complete STUN message in the buffer.

        Responses are written back to address through transport_handler. A
        trailing partial message stays buffered for the next call; a malformed
        message makes StunMessage.parse raise.
        """
        self.data += bytearray(data)
        while True:
            stun_message = StunMessage()
            parsed_len = stun_message.parse(self.data)
            if parsed_len > len(self.data):
                break  # incomplete message; wait for more bytes
            self.data = self.data[parsed_len:]

            response = self.handle_stun(stun_message, address)
            if response:
                self.transport_handler.write(response, address)

    def handle_stun(self, stun_message, address):
        """
        Dispatch one parsed message. Returns the serialized response, or None
        when nothing should be sent (indications and responses).
        """
        self.client_address = address
        if stun_message.msg_class == INDICATION:
            if stun_message.method == SEND:
                self.handle_send_indication(stun_message)
            else:
                print(
                    "Dropping unknown indication method: {}".format(
                        stun_message.method
                    )
                )
            return None

        if stun_message.msg_class != REQUEST:
            print("Dropping STUN response, method: {}".format(stun_message.method))
            return None

        request_handlers = {
            BINDING: self.make_success_response,
            ALLOCATE: self.handle_allocation,
            REFRESH: self.handle_refresh,
            CREATE_PERMISSION: self.handle_permission,
        }
        handler = request_handlers.get(stun_message.method)
        if handler is None:
            return self.make_error_response(
                stun_message,
                400,
                "Unsupported STUN request, method: {}".format(stun_message.method),
            ).build()
        return handler(stun_message).build()

    def get_allocation_tuple(self):
        """
        Key for the current client's allocation: client host and port, then
        the server-side transport type, host and port.
        """
        server = self.transport_handler.transport.getHost()
        return (
            self.client_address.host,
            self.client_address.port,
            server.type,
            server.host,
            server.port,
        )

    def handle_allocation(self, request):
        """
        Handle an Allocate request and return the response StunMessage.

        On success, opens a relay socket, records it in allocations and
        answers with XOR-RELAYED-ADDRESS, LIFETIME (capped at 3600 s) and
        MESSAGE-INTEGRITY.
        """
        allocate_response = self.check_long_term_auth(request)
        if allocate_response.msg_class != SUCCESS_RESPONSE:
            return allocate_response

        allocation_tuple = self.get_allocation_tuple()
        if allocation_tuple in allocations:
            return self.make_error_response(
                request,
                437,
                "Duplicate allocation request for tuple {}".format(allocation_tuple),
            )

        # Validate before creating the Allocation: it starts listening on a
        # UDP port immediately, and a rejected request would otherwise leave
        # that port open with nothing tracking it.
        lifetime = request.get_lifetime()
        if lifetime is None:
            return self.make_error_response(
                request, 400, "Missing lifetime attribute in allocation request"
            )
        lifetime = min(lifetime, 3600)

        allocation = Allocation(
            self.transport_handler, self.client_address, request.get_username()
        )
        allocate_response.add_xor_address(
            allocation.transport.getHost(), XOR_RELAYED_ADDRESS
        )
        allocate_response.add_lifetime(lifetime)
        allocation.expiry = time.time() + lifetime

        allocate_response.add_message_integrity(turn_user, turn_realm, turn_pass)
        allocations[allocation_tuple] = allocation
        return allocate_response

    def handle_refresh(self, request):
        """
        Handle a Refresh request: extend the client's allocation by LIFETIME
        seconds (capped at 3600). Returns the response StunMessage.
        """
        refresh_response = self.check_long_term_auth(request)
        if refresh_response.msg_class != SUCCESS_RESPONSE:
            return refresh_response

        allocation = allocations.get(self.get_allocation_tuple())
        if allocation is None:
            return self.make_error_response(
                request,
                437,
                "Refresh request for non-existing allocation, tuple {}".format(
                    self.get_allocation_tuple()
                ),
            )

        if allocation.username != request.get_username():
            return self.make_error_response(
                request,
                441,
                "Refresh request with wrong user, exp {}, got {}".format(
                    allocation.username, request.get_username()
                ),
            )

        lifetime = request.get_lifetime()
        if lifetime is None:
            return self.make_error_response(
                request, 400, "Missing lifetime attribute in refresh request"
            )

        lifetime = min(lifetime, 3600)
        refresh_response.add_lifetime(lifetime)
        allocation.expiry = time.time() + lifetime

        refresh_response.add_message_integrity(turn_user, turn_realm, turn_pass)
        return refresh_response

    def handle_permission(self, request):
        """
        Handle a CreatePermission request: allow the host in XOR-PEER-ADDRESS
        to exchange data with the client's allocation.
        """
        permission_response = self.check_long_term_auth(request)
        if permission_response.msg_class != SUCCESS_RESPONSE:
            return permission_response

        allocation = allocations.get(self.get_allocation_tuple())
        if allocation is None:
            return self.make_error_response(
                request,
                437,
                "No such allocation for permission request, tuple {}".format(
                    self.get_allocation_tuple()
                ),
            )

        if allocation.username != request.get_username():
            return self.make_error_response(
                request,
                441,
                "Permission request with wrong user, exp {}, got {}".format(
                    allocation.username, request.get_username()
                ),
            )

        # TODO: Handle multiple XOR-PEER-ADDRESS
        peer_address = request.get_xor_address(XOR_PEER_ADDRESS)
        if peer_address is None:
            return self.make_error_response(
                request, 400, "Missing XOR-PEER-ADDRESS on permission request"
            )

        permission_response.add_message_integrity(turn_user, turn_realm, turn_pass)
        allocation.permissions.add(peer_address.host)
        return permission_response

    def handle_send_indication(self, indication):
        """Relay a Send indication's DATA to its XOR-PEER-ADDRESS, if permitted."""
        allocation = allocations.get(self.get_allocation_tuple())
        if allocation is None:
            print(
                "Dropping send indication; no allocation for tuple {}".format(
                    self.get_allocation_tuple()
                )
            )
            return

        peer_address = indication.get_xor_address(XOR_PEER_ADDRESS)
        if peer_address is None:
            print("Dropping send indication, missing XOR-PEER-ADDRESS")
            return

        data_attr = indication.find(DATA_ATTR)
        if data_attr is None:
            print("Dropping send indication, missing DATA")
            return

        if indication.find(DONT_FRAGMENT):
            print("Dropping send indication, DONT-FRAGMENT set")
            return

        if peer_address.host not in allocation.permissions:
            print(
                "Dropping send indication, no permission for {} on tuple {}".format(
                    peer_address.host, self.get_allocation_tuple()
                )
            )
            return

        allocation.transport.write(
            data_attr.data, (peer_address.host, peer_address.port)
        )

    def make_success_response(self, request):
        """Success response to request, carrying the client's XOR-MAPPED-ADDRESS."""
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_xor_address(self.client_address, XOR_MAPPED_ADDRESS)
        response.msg_class = SUCCESS_RESPONSE
        return response

    def make_error_response(self, request, code, reason=None):
        """Error response to request with ERROR-CODE code and optional reason."""
        if reason:
            print("{}: rejecting with {}".format(reason, code))
        response = copy.deepcopy(request)
        response.attributes = []
        response.add_error_code(code, reason)
        response.msg_class = ERROR_RESPONSE
        return response

    def make_challenge_response(self, request, reason=None):
        """401 response carrying a fresh NONCE and the server REALM."""
        response = self.make_error_response(request, 401, reason)
        # 65 means the hex encoding will need padding half the time
        response.add_nonce("{:x}".format(random.getrandbits(65)))
        response.add_realm(turn_realm)
        return response

    def check_long_term_auth(self, request):
        """
        Check request against the server's long-term credentials.

        Returns a success response when USERNAME and MESSAGE-INTEGRITY match,
        otherwise a 401 challenge or a 400 error; callers inspect msg_class.
        NONCE and REALM must be present, but their values are not checked.
        """
        message_integrity = request.find(MESSAGE_INTEGRITY)
        if message_integrity is None:
            return self.make_challenge_response(request)

        username = request.find(USERNAME)
        realm = request.find(REALM)
        nonce = request.find(NONCE)
        if username is None or realm is None or nonce is None:
            return self.make_error_response(
                request, 400, "Missing either USERNAME, NONCE, or REALM"
            )

        user = bytes(username.data).decode("utf-8", errors="replace")
        if user != turn_user:
            return self.make_challenge_response(
                request, "Wrong user {}, exp {}".format(user, turn_user)
            )

        expected_message_digest = request.calculate_message_digest(
            turn_user, turn_realm, turn_pass
        )
        # Constant-time comparison, so response timing doesn't leak the HMAC.
        if not hmac.compare_digest(
            bytes(message_integrity.data), bytes(expected_message_digest)
        ):
            return self.make_challenge_response(request, "Incorrect message digest")

        return self.make_success_response(request)
class StunRedirectHandler(StunHandler):
    """
    Frames and handles STUN messages by redirecting to the "real" server port.
    Performs the redirect with auth, so does a 401 to unauthed requests.
    Can be used to test port-based redirect handling.
    """

    def handle_stun(self, stun_message, address):
        """
        Answer a request with a challenge/error, or with a redirect once it
        authenticates. Returns the serialized response, or None for
        non-requests.
        """
        self.client_address = address
        if stun_message.msg_class != REQUEST:
            return None

        challenge_response = self.check_long_term_auth(stun_message)
        if challenge_response.msg_class == SUCCESS_RESPONSE:
            return self.make_redirect_response(stun_message).build()
        return challenge_response.build()

    def make_redirect_response(self, request):
        """300 (Try Alternate) response pointing at the matching real port."""
        response = self.make_error_response(request, 300, "Try alternate")
        local = self.transport_handler.transport.getHost()
        # The TLS redirect port sends clients to the TLS port; all others to
        # the plain STUN port.
        port = STUNS_PORT if local.port == TURNS_REDIRECT_PORT else STUN_PORT
        response.add_alternate_server(local.host, port)
        response.add_message_integrity(turn_user, turn_realm, turn_pass)
        return response
class UdpStunHandler(protocol.DatagramProtocol):
    """
    Represents a UDP listen port for TURN.

    Every datagram is handed to a brand-new StunHandler, so no buffered
    state is carried from one datagram to the next.
    """

    def datagramReceived(self, data, address):
        sender = to_ipaddress("UDP", address[0], address[1])
        handler = StunHandler(self)
        handler.data_received(data, sender)

    def write(self, data, address):
        destination = (address.host, address.port)
        self.transport.write(bytes(data), destination)
class UdpStunRedirectHandler(protocol.DatagramProtocol):
    """
    Represents a UDP listen port for TURN that will redirect.

    Every datagram is handed to a brand-new StunRedirectHandler.
    """

    def datagramReceived(self, data, address):
        sender = to_ipaddress("UDP", address[0], address[1])
        redirector = StunRedirectHandler(self)
        redirector.data_received(data, sender)

    def write(self, data, address):
        payload = bytes(data)
        self.transport.write(payload, (address.host, address.port))
class TcpStunHandlerFactory(protocol.Factory):
    """
    Represents a TCP listen port for TURN.

    Builds one TcpStunHandler per accepted connection.
    """

    def buildProtocol(self, addr):
        connection_handler = TcpStunHandler(addr)
        return connection_handler
class TcpStunHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN.
    """

    def __init__(self, addr):
        # Remote address of this connection, supplied by the factory.
        self.address = addr
        # Created on first data and kept for the connection's lifetime.
        self.stun_handler = None

    def dataReceived(self, data):
        # This needs to persist, since it handles framing
        if self.stun_handler is None:
            self.stun_handler = StunHandler(self)
        self.stun_handler.data_received(data, self.address)

    def connectionLost(self, reason):
        print("Lost connection from {}".format(self.address))
        # Destroy allocations that this connection made. Iterate over a
        # snapshot so entries can be deleted along the way.
        for key, allocation in list(allocations.items()):
            if allocation.other_transport_handler is self:
                print("Closing allocation due to dropped connection: {}".format(key))
                allocation.close()
                del allocations[key]

    def write(self, data, address):
        # The connection already determines the destination; address is unused.
        self.transport.write(bytes(data))
class TcpStunRedirectHandlerFactory(protocol.Factory):
    """
    Represents a TCP listen port for TURN that will redirect.

    Builds one TcpStunRedirectHandler per accepted connection.
    """

    def buildProtocol(self, addr):
        connection_handler = TcpStunRedirectHandler(addr)
        return connection_handler
class TcpStunRedirectHandler(protocol.Protocol):
    """
    Represents a connected TCP port for TURN that will redirect.

    Built by TcpStunRedirectHandlerFactory, so it is a stream Protocol (like
    TcpStunHandler), not a DatagramProtocol.
    """

    def __init__(self, addr):
        # Remote address of this connection, supplied by the factory.
        self.address = addr
        self.stun_handler = None

    def dataReceived(self, data):
        # This needs to persist, since it handles framing. Framing matters here
        # because we do a round of auth before redirecting.
        if self.stun_handler is None:
            self.stun_handler = StunRedirectHandler(self)
        self.stun_handler.data_received(data, self.address)

    def write(self, data, address):
        # The connection already determines the destination; address is unused.
        self.transport.write(bytes(data))

    def connectionLost(self, reason):
        print("Lost connection from {}".format(self.address))
def get_default_route(family):
    """
    Return the local IP address the OS would use to reach a public host.

    family is socket.AF_INET or socket.AF_INET6. Connecting a UDP socket sends
    no packets; it only makes the kernel choose a route and source address,
    which getsockname() then reports.
    Raises OSError if there is no route for that family.
    """
    if family == socket.AF_INET:
        target = ("8.8.8.8", 53)
    else:
        target = ("2001:4860:4860::8888", 53)

    # The context manager closes the socket even if connect() fails.
    with socket.socket(family, socket.SOCK_DGRAM) as dummy_socket:
        dummy_socket.connect(target)
        return dummy_socket.getsockname()[0]
# Fixed long-term credentials accepted by this test server.
turn_user = "foo"
turn_pass = "bar"
turn_realm = "mozilla.invalid"

# Live allocations, keyed by StunHandler.get_allocation_tuple().
allocations = {}

# Default-route addresses. Allocation relay sockets bind to v4_address.
v4_address = get_default_route(socket.AF_INET)
try:
    v6_address = get_default_route(socket.AF_INET6)
except OSError:
    # Host has no IPv6 route; leave the IPv6 address empty.
    v6_address = ""
def prune_allocations():
    """
    Close and forget every allocation whose expiry time has passed.

    Driven periodically by a LoopingCall from the main block.
    """
    now = time.time()
    # Iterate over a snapshot so entries can be deleted along the way.
    for key, allocation in list(allocations.items()):
        if allocation.expiry < now:
            print("Allocation expired: {}".format(key))
            allocation.close()
            del allocations[key]
# Paths where create_self_signed_cert() writes (and later reuses) the PEM
# certificate and its unencrypted private key.
CERT_FILE, KEY_FILE = "selfsigned.crt", "private.key"
def create_self_signed_cert(name):
    """
    Write a self-signed certificate and its private key (both PEM) to
    CERT_FILE and KEY_FILE, unless both files already exist.

    name becomes the certificate's common name and its only DNS
    subjectAltName. The certificate is valid from now for 365 days.
    """
    # pyOpenSSL used to have some wrappers to help with this, but those have
    # been deprecated, and they have instructed users to use stuff from
    # cryptography.hazmat directly. This strikes me as a bad idea, but here we
    # go.
    from cryptography import x509
    from cryptography.hazmat.primitives import hashes, serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.x509.oid import NameOID

    # Not ideal, but in order to avoid generating certs with duplicate serial
    # numbers, we don't regenerate if there's one there already. If we wanted
    # to regenerate, we'd need to load the cert if it was there, determine its
    # serial number, and then make a new cert with a higher serial number.
    if os.path.isfile(CERT_FILE) and os.path.isfile(KEY_FILE):
        return

    # Key size does not need to be big, this is a self-signed cert for testing,
    # but I'm going to use something common to avoid warnings that might come
    # up.
    # Why 65537? Because the documentation says so, citing a document written
    # by Colin Percival in 2009. Will this ever be out of date? Is it out of
    # date already? Who knows!
    key = rsa.generate_private_key(key_size=2048, public_exponent=65537)

    # For a self-signed certificate, subject and issuer are the same name.
    subject = x509.Name(
        [
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "TX"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "Dallas"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Mozilla test iceserver"),
            x509.NameAttribute(NameOID.COMMON_NAME, name),
        ]
    )

    # create a self-signed cert
    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(subject)
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=365))
        .public_key(key.public_key())
        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(name)]),
            critical=False,
        )
        .sign(key, hashes.SHA256())
    )

    # Context managers make sure both files are flushed and closed.
    with open(CERT_FILE, "wb") as cert_file:
        cert_file.write(cert.public_bytes(encoding=serialization.Encoding.PEM))
    with open(KEY_FILE, "wb") as key_file:
        key_file.write(
            key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )
930 if __name__
== "__main__":
933 if platform
.system() == "Windows":
934 # Windows is finicky about allowing real interfaces to talk to loopback.
935 interface_4
= v4_address
936 interface_6
= v6_address
937 hostname
= socket
.gethostname()
939 # Our linux builders do not have a hostname that resolves to the real
941 interface_4
= "127.0.0.1"
943 hostname
= "localhost"
945 reactor
.listenUDP(STUN_PORT
, UdpStunHandler(), interface
=interface_4
)
946 reactor
.listenTCP(STUN_PORT
, TcpStunHandlerFactory(), interface
=interface_4
)
949 TURN_REDIRECT_PORT
, UdpStunRedirectHandler(), interface
=interface_4
952 TURN_REDIRECT_PORT
, TcpStunRedirectHandlerFactory(), interface
=interface_4
956 reactor
.listenUDP(STUN_PORT
, UdpStunHandler(), interface
=interface_6
)
957 reactor
.listenTCP(STUN_PORT
, TcpStunHandlerFactory(), interface
=interface_6
)
960 TURN_REDIRECT_PORT
, UdpStunRedirectHandler(), interface
=interface_6
963 TURN_REDIRECT_PORT
, TcpStunRedirectHandlerFactory(), interface
=interface_6
969 from twisted
.internet
import ssl
970 from OpenSSL
import SSL
972 create_self_signed_cert(hostname
)
973 tls_context_factory
= ssl
.DefaultOpenSSLContextFactory(
974 KEY_FILE
, CERT_FILE
, SSL
.TLSv1_2_METHOD
978 TcpStunHandlerFactory(),
980 interface
=interface_4
,
986 TcpStunHandlerFactory(),
988 interface
=interface_6
,
993 TcpStunRedirectHandlerFactory(),
995 interface
=interface_6
,
1000 f
= open(CERT_FILE
, "r")
1001 lines
= f
.readlines()
1002 lines
.pop(0) # Remove BEGIN CERTIFICATE
1003 lines
.pop() # Remove END CERTIFICATE
1004 # pylint --py3k: W1636 W1649
1005 lines
= list(map(str.strip
, lines
))
1006 certbase64
= "".join(lines
) # pylint --py3k: W1649
1008 turns_url
= ', "turns:' + hostname
+ '"'
1009 cert_prop
= ', "cert":"' + certbase64
+ '"'
1015 allocation_pruner
= LoopingCall(prune_allocations
)
1016 allocation_pruner
.start(1)
1018 template
= Template(
1020 {"urls":["stun:$hostname", "stun:$hostname?transport=tcp"]}, \
1021 {"username":"$user","credential":"$pwd","turn_redirect_port":"$TURN_REDIRECT_PORT","turns_redirect_port":"$TURNS_REDIRECT_PORT","urls": \
1022 ["turn:$hostname", "turn:$hostname?transport=tcp" $turns_url] \
1023 $cert_prop}]' # Hack to make it easier to override cert checks
1027 template
.substitute(
1031 turns_url
=turns_url
,
1032 cert_prop
=cert_prop
,
1033 TURN_REDIRECT_PORT
=TURN_REDIRECT_PORT
,
1034 TURNS_REDIRECT_PORT
=TURNS_REDIRECT_PORT
,