Merge remote-tracking branch 'tor-gitlab/mr/513' into main
[tor.git] / src / test / ntor_v3_ref.py
blob28bc0771050cace69bf1cfd60f70694496777091
1 #!/usr/bin/python
3 import binascii
4 import hashlib
5 import os
6 import struct
8 import donna25519
9 from Crypto.Cipher import AES
10 from Crypto.Util import Counter
# Define basic wrappers.
#
# Sizes, in bytes, of the values handled by the primitives below.
DIGEST_LEN = 32    # sha3_256 output
ENC_KEY_LEN = 32   # aes256_ctr key
PUB_KEY_LEN = 32   # curve25519 public key
SEC_KEY_LEN = 32   # curve25519 secret key
IDENTITY_LEN = 32  # relay identity (ID)
def sha3_256(s):
    """Return the SHA3-256 digest of the byte string `s`."""
    digest = hashlib.sha3_256(s).digest()
    assert len(digest) == DIGEST_LEN
    return digest
def shake_256(s):
    """Return a fixed 1024 bytes of SHAKE-256 output for `s`."""
    # Note: In reality, you wouldn't want to generate more bytes than needed;
    # a fixed-size output just keeps this reference code simple.
    output_len = 1024
    xof = hashlib.shake_256(s)
    return xof.digest(output_len)
def curve25519(pk, sk):
    """Return the curve25519 Diffie-Hellman output of public key `pk` and secret key `sk`."""
    assert len(pk) == PUB_KEY_LEN
    assert len(sk) == SEC_KEY_LEN
    our_private = donna25519.PrivateKey.load(sk)
    their_public = donna25519.PublicKey(pk)
    return our_private.do_exchange(their_public)
def keygen():
    """Generate a fresh curve25519 keypair, returned as (secret_key, public_key) raw keys."""
    sk = donna25519.PrivateKey()
    pk = sk.get_public()
    return sk.private, pk.public
def aes256_ctr(k, s):
    """Encrypt (or, identically, decrypt) `s` with AES-256-CTR under key `k`.

    The 128-bit counter always starts at zero, so this is only safe when each
    key encrypts a single message, as the handshake's derived keys do.
    """
    assert len(k) == ENC_KEY_LEN
    ctr = Counter.new(128, initial_value=0)
    aes = AES.new(k, AES.MODE_CTR, counter=ctr)
    return aes.encrypt(s)
# Byte-oriented helper. We use this for decoding keystreams and messages.

class ByteSeq:
    """Consume a byte string from the front, a chosen number of bytes at a time."""

    def __init__(self, data):
        self.data = data

    def take(self, n):
        """Remove and return the next `n` bytes.

        Raises AssertionError if `n` is negative or larger than what remains.
        """
        # Without the lower bound, a negative n would pass the length check and
        # slice from the *end* of the buffer, silently returning the wrong bytes.
        assert 0 <= n <= len(self.data)
        result = self.data[:n]
        self.data = self.data[n:]
        return result

    def exhausted(self):
        """Return True once every byte has been taken."""
        return len(self.data) == 0

    def remaining(self):
        """Return how many bytes have not been taken yet."""
        return len(self.data)
# Low-level functions

MAC_KEY_LEN = 32       # bytes of phase-1 keystream taken as the MAC key
MAC_LEN = DIGEST_LEN   # a MAC here is one hash_func output

hash_func = sha3_256   # the hash underlying h() and mac()
def encapsulate(s):
    """Prefix `s` with its length, as an 8-byte big-endian integer.

    We use this whenever we need to avoid message ambiguities in
    cryptographic inputs.
    """
    assert len(s) <= 0xffffffff
    # The length is capped at 32 bits, so the top four bytes of this 64-bit
    # prefix are always zero.
    header = struct.pack("!Q", len(s))
    assert len(header) == 8
    return header + s
def h(s, tweak):
    """Hash `s` under the length-prefixed domain-separation `tweak`."""
    prefix = encapsulate(tweak)
    return hash_func(prefix + s)
def mac(s, key, tweak):
    """MAC `s` under `key` as hash(encap(tweak) | encap(key) | s)."""
    return hash_func(b"".join((encapsulate(tweak), encapsulate(key), s)))
def kdf(s, tweak):
    """Derive a keystream from `s` under `tweak`, wrapped in a ByteSeq for partitioning."""
    return ByteSeq(shake_256(encapsulate(tweak) + s))
def enc(s, k):
    """Apply the stream cipher to `s` with key `k`; used to both encrypt and decrypt."""
    return aes256_ctr(k=k, s=s)
# Tweaked wrappers
#
# Every tweak is PROTOID, a colon, and a purpose label, so each use of
# h/mac/kdf below hashes a distinct, length-prefixed domain prefix.
PROTOID = b"ntor3-curve25519-sha3_256-1"
T_KDF_PHASE1 = b":".join((PROTOID, b"kdf_phase1"))
T_MAC_PHASE1 = b":".join((PROTOID, b"msg_mac"))
T_KDF_FINAL = b":".join((PROTOID, b"kdf_final"))
T_KEY_SEED = b":".join((PROTOID, b"key_seed"))
T_VERIFY = b":".join((PROTOID, b"verify"))
T_AUTH = b":".join((PROTOID, b"auth_final"))
def kdf_phase1(s):
    """KDF producing the phase-1 (client message) encryption and MAC keys."""
    return kdf(s, tweak=T_KDF_PHASE1)
def kdf_final(s):
    """KDF expanding the key seed into the final keystream."""
    return kdf(s, tweak=T_KDF_FINAL)
def mac_phase1(s, key):
    """MAC over the client's handshake message, keyed with the phase-1 MAC key."""
    return mac(s, key=key, tweak=T_MAC_PHASE1)
def h_key_seed(s):
    """Hash the final secret input into the key seed."""
    return h(s, tweak=T_KEY_SEED)
def h_verify(s):
    """Hash the final secret input into the `verify` value used by h_auth."""
    return h(s, tweak=T_VERIFY)
def h_auth(s):
    """Hash the auth input into the server's authenticator."""
    return h(s, tweak=T_AUTH)
# Handshake.

def client_phase1(msg, verification, B, ID):
    """Start a handshake to the relay with onion key `B` and identity `ID`.

    Encrypts `msg` under phase-1 keys derived from a fresh ephemeral key and
    `verification`, then MACs the result.

    Returns (client_handshake, state). client_handshake is
    ID | B | X | encrypted msg | MAC; state is the dict client_phase2() needs.
    Intermediate values are printed as test vectors via p().
    """
    assert len(B) == PUB_KEY_LEN
    assert len(ID) == IDENTITY_LEN

    x, X = keygen()
    p(["x", "X"], locals())
    p(["msg", "verification"], locals())

    Bx = curve25519(B, x)
    phase1_input = b"".join((Bx, ID, X, B, PROTOID, encapsulate(verification)))

    phase1_stream = kdf_phase1(phase1_input)
    enc_key = phase1_stream.take(ENC_KEY_LEN)
    mac_key = phase1_stream.take(MAC_KEY_LEN)
    p(["enc_key", "mac_key"], locals())

    mac_covered = b"".join((ID, B, X, enc(msg, enc_key)))
    # Named msg_mac to avoid shadowing the module-level mac() function;
    # still printed under the label "mac".
    msg_mac = mac_phase1(mac_covered, mac_key)
    p(["mac"], {"mac": msg_mac})

    client_handshake = mac_covered + msg_mac
    state = {
        "x": x,
        "X": X,
        "B": B,
        "ID": ID,
        "Bx": Bx,
        "mac": msg_mac,
        "verification": verification,
    }

    p(["client_handshake"], locals())

    return client_handshake, state
# Server side.

class Reject(Exception):
    """Raised when a handshake message fails a length, identity, MAC or auth check."""
def server_part1(cmsg, verification, b, B, ID):
    """Process the client's handshake message on the relay side.

    Parses `cmsg` as ID | B | X | encrypted message | MAC and checks that it
    is addressed to this relay (`ID`, `B`). It then derives the phase-1 keys
    from X, the relay's secret key `b` and `verification`, verifies the MAC,
    and decrypts the client's message.

    Returns (client_msg, state), where state is the dict server_part2() needs.
    Raises Reject if the message is too short, is addressed to a different
    relay, or fails MAC verification.
    """
    import hmac  # compare_digest: constant-time byte comparison

    assert len(B) == PUB_KEY_LEN
    assert len(ID) == IDENTITY_LEN
    assert len(b) == SEC_KEY_LEN

    if len(cmsg) < (IDENTITY_LEN + PUB_KEY_LEN * 2 + MAC_LEN):
        raise Reject()

    mac_covered_portion = cmsg[0:-MAC_LEN]
    cmsg = ByteSeq(cmsg)
    cmsg_id = cmsg.take(IDENTITY_LEN)
    cmsg_B = cmsg.take(PUB_KEY_LEN)
    cmsg_X = cmsg.take(PUB_KEY_LEN)
    cmsg_msg = cmsg.take(cmsg.remaining() - MAC_LEN)
    cmsg_mac = cmsg.take(MAC_LEN)

    assert cmsg.exhausted()

    # Constant-time checks: evaluate both comparisons fully so timing does not
    # reveal how many leading bytes matched.
    id_ok = hmac.compare_digest(cmsg_id, ID)
    key_ok = hmac.compare_digest(cmsg_B, B)
    if not (id_ok and key_ok):
        raise Reject()

    Xb = curve25519(cmsg_X, b)
    secret_input_phase1 = Xb + ID + cmsg_X + B + PROTOID + encapsulate(verification)

    phase1_keys = kdf_phase1(secret_input_phase1)
    enc_key = phase1_keys.take(ENC_KEY_LEN)
    mac_key = phase1_keys.take(MAC_KEY_LEN)

    # This is the MAC we expect; a timing leak in comparing it against the
    # received one could help an attacker forge a MAC byte by byte.
    mac_expected = mac_phase1(mac_covered_portion, mac_key)
    if not hmac.compare_digest(mac_expected, cmsg_mac):
        raise Reject()

    client_msg = enc(cmsg_msg, enc_key)
    state = dict(
        b=b,
        B=B,
        X=cmsg_X,
        # Key name is what server_part2() reads; having passed verification,
        # the expected MAC equals the one the client sent.
        mac_received=mac_expected,
        Xb=Xb,
        ID=ID,
        verification=verification)

    return (client_msg, state)
def server_part2(state, server_msg):
    """Finish the relay's side of the handshake and encrypt its reply.

    `state` is the dict returned by server_part1(); `server_msg` is the
    plaintext reply. Returns (server_handshake, keys).
    - server_handshake is Y | auth | encrypted reply.
    - keys is the ByteSeq holding what remains of the final keystream after
      the reply-encryption key.

    Intermediate values are printed as test vectors via p().
    """
    # Every key is looked up, in this order, even "b", which is not needed
    # below; a state missing any of them raises KeyError here.
    (X, Xb, B, _b, ID, mac_received, verification) = (
        state[key]
        for key in ("X", "Xb", "B", "b", "ID", "mac_received", "verification")
    )

    p(["server_msg"], locals())

    (y, Y) = keygen()
    p(["y", "Y"], locals())
    Xy = curve25519(X, y)

    final_input = b"".join((Xy, Xb, ID, B, X, Y, PROTOID, encapsulate(verification)))
    key_seed = h_key_seed(final_input)
    verify = h_verify(final_input)
    p(["key_seed", "verify"], locals())

    final_keys = kdf_final(key_seed)
    server_enc_key = final_keys.take(ENC_KEY_LEN)
    p(["server_enc_key"], locals())

    encrypted_reply = enc(server_msg, server_enc_key)

    auth = h_auth(b"".join((
        verify, ID, B, Y, X, mac_received,
        encapsulate(encrypted_reply), PROTOID, b"Server",
    )))
    server_handshake = Y + auth + encrypted_reply
    p(["auth", "server_handshake"], locals())

    return (server_handshake, final_keys)
def client_phase2(state, smsg):
    """Finish the client's side of the handshake.

    `state` is the dict returned by client_phase1(); `smsg` is the relay's
    reply, Y | auth | encrypted message. Recomputes and checks the relay's
    auth value, then decrypts the relay's message.

    Returns (keys, server_msg_decrypted), where keys is the ByteSeq holding
    what remains of the final keystream after the message-encryption key.
    Raises Reject if `smsg` is too short or its auth value does not match.
    """
    import hmac  # compare_digest: constant-time byte comparison

    x = state['x']
    X = state['X']
    B = state['B']
    ID = state['ID']
    Bx = state['Bx']
    mac_sent = state['mac']
    verification = state['verification']

    if len(smsg) < PUB_KEY_LEN + DIGEST_LEN:
        raise Reject()

    smsg = ByteSeq(smsg)
    Y = smsg.take(PUB_KEY_LEN)
    auth_received = smsg.take(DIGEST_LEN)
    encrypted_msg = smsg.take(smsg.remaining())

    Yx = curve25519(Y, x)

    secret_input = Yx + Bx + ID + B + X + Y + PROTOID + encapsulate(verification)
    key_seed = h_key_seed(secret_input)
    verify = h_verify(secret_input)

    auth_input = (verify + ID + B + Y + X + mac_sent
                  + encapsulate(encrypted_msg) + PROTOID + b"Server")

    auth = h_auth(auth_input)
    # Constant-time comparison so timing does not reveal how much of a
    # forged auth value matched.
    if not hmac.compare_digest(auth, auth_received):
        raise Reject()

    keys = kdf_final(key_seed)
    enc_key = keys.take(ENC_KEY_LEN)

    server_msg_decrypted = enc(encrypted_msg, enc_key)

    return (keys, server_msg_decrypted)
def p(varnames, localvars):
    """Print each named byte value from `localvars` as a line `name = "hex"`."""
    for name in varnames:
        hex_value = binascii.hexlify(localvars[name]).decode("ascii")
        print(f'{name} = "{hex_value}"')
def test():
    """Run one full handshake with random keys, printing its test vectors.

    Checks these things:
    - each side decrypts the other's message correctly;
    - both sides derive the same 256 bytes of final keystream;
    - each side rejects a peer handshake with one bit flipped.
    """
    (b, B) = keygen()
    ID = os.urandom(IDENTITY_LEN)

    p(["b", "B", "ID"], locals())

    print("# ============")
    (c_handshake, c_state) = client_phase1(b"hello world", b"xyzzy", B, ID)

    print("# ============")

    (c_msg_got, s_state) = server_part1(c_handshake, b"xyzzy", b, B, ID)
    assert c_msg_got == b"hello world"

    (s_handshake, s_keys) = server_part2(s_state, b"Hola Mundo")

    print("# ============")

    (c_keys, s_msg_got) = client_phase2(c_state, s_handshake)
    assert s_msg_got == b"Hola Mundo"

    c_keys_256 = c_keys.take(256)
    p(["c_keys_256"], locals())

    assert (c_keys_256 == s_keys.take(256))

    # Negative checks (they print nothing, so the emitted vectors are
    # unchanged). Flipping a bit of the client's MAC must make the server
    # reject the handshake...
    bad_c_handshake = c_handshake[:-1] + bytes([c_handshake[-1] ^ 1])
    try:
        server_part1(bad_c_handshake, b"xyzzy", b, B, ID)
    except Reject:
        pass
    else:
        raise AssertionError("server accepted a tampered client handshake")

    # ...and flipping a bit of the server's reply must make the client reject it.
    bad_s_handshake = s_handshake[:-1] + bytes([s_handshake[-1] ^ 1])
    try:
        client_phase2(c_state, bad_s_handshake)
    except Reject:
        pass
    else:
        raise AssertionError("client accepted a tampered server handshake")
if __name__ == "__main__":
    # Running this file directly prints one fresh (randomly keyed) set of
    # handshake test vectors.
    test()