9 from Crypto
.Cipher
import AES
10 from Crypto
.Util
import Counter
# Define basic wrappers.
21 d
= hashlib
.sha3_256(s
).digest()
22 assert len(d
) == DIGEST_LEN
26 # Note: In reality, you wouldn't want to generate more bytes than needed.
28 return hashlib
.shake_256(s
).digest(MAX_KEY_BYTES
)
def curve25519(pk, sk):
    """Run a curve25519 key exchange using the donna25519 library.

    Args:
        pk: the peer's public key; must be exactly PUB_KEY_LEN bytes.
        sk: our secret key; must be exactly SEC_KEY_LEN bytes.

    Returns:
        The result of ``donna25519.PrivateKey.do_exchange`` for this key
        pair. Callers in this module concatenate it with other byte
        strings, so it is used as the raw shared secret bytes.

    Raises:
        AssertionError: if either key has the wrong length (only when
            assertions are enabled).
    """
    assert len(pk) == PUB_KEY_LEN
    assert len(sk) == SEC_KEY_LEN
    # Load our secret key first, then wrap the peer's public key.
    our_key = donna25519.PrivateKey.load(sk)
    their_key = donna25519.PublicKey(pk)
    shared = our_key.do_exchange(their_key)
    return shared
38 private
= donna25519
.PrivateKey()
39 public
= private
.get_public()
40 return (private
.private
, public
.public
)
43 assert len(k
) == ENC_KEY_LEN
44 cipher
= AES
.new(k
, AES
.MODE_CTR
, counter
=Counter
.new(128, initial_value
=0))
45 return cipher
.encrypt(s
)
# Byte-oriented helper. We use this for decoding keystreams and messages.
50 def __init__(self
, data
):
54 assert n
<= len(self
.data
)
55 result
= self
.data
[:n
]
56 self
.data
= self
.data
[n
:]
60 return len(self
.data
) == 0
73 """encapsulate `s` with a length prefix.
75 We use this whenever we need to avoid message ambiguities in
78 assert len(s
) <= 0xffffffff
79 header
= b
"\0\0\0\0" + struct
.pack("!L", len(s
))
80 assert len(header
) == 8
84 return hash_func(encapsulate(tweak
) + s
)
def mac(s, key, tweak):
    """Compute a keyed digest of `s`.

    The digest is ``hash_func`` applied to
    ``encapsulate(tweak) || encapsulate(key) || s``. The tweak and the key
    are each length-prefixed by ``encapsulate`` so their boundaries are
    unambiguous. The message `s` itself is appended raw.
    """
    keyed_prefix = encapsulate(tweak) + encapsulate(key)
    return hash_func(keyed_prefix + s)
90 data
= shake_256(encapsulate(tweak
) + s
)
94 return aes256_ctr(k
, s
)
# Protocol identifier for this handshake. Every domain-separation tweak
# below is PROTOID, a ":" separator, then a per-use label. Each hash, MAC
# and KDF invocation therefore gets a distinct, protocol-bound tweak.
PROTOID = b"ntor3-curve25519-sha3_256-1"
T_KDF_PHASE1 = b":".join((PROTOID, b"kdf_phase1"))
T_MAC_PHASE1 = b":".join((PROTOID, b"msg_mac"))
T_KDF_FINAL = b":".join((PROTOID, b"kdf_final"))
T_KEY_SEED = b":".join((PROTOID, b"key_seed"))
T_VERIFY = b":".join((PROTOID, b"verify"))
T_AUTH = b":".join((PROTOID, b"auth_final"))
107 return kdf(s
, T_KDF_PHASE1
)
110 return kdf(s
, T_KDF_FINAL
)
def mac_phase1(s, key):
    """MAC `s` under `key` with the phase-1 tweak ``T_MAC_PHASE1``.

    This is a thin wrapper around ``mac``. It fixes the domain-separation
    tweak so phase-1 MACs cannot collide with other uses of ``mac``.
    """
    return mac(s, key, tweak=T_MAC_PHASE1)
116 return h(s
, T_KEY_SEED
)
119 return h(s
, T_VERIFY
)
126 def client_phase1(msg
, verification
, B
, ID
):
127 assert len(B
) == PUB_KEY_LEN
128 assert len(ID
) == IDENTITY_LEN
131 p(["x", "X"], locals())
132 p(["msg", "verification"], locals())
133 Bx
= curve25519(B
, x
)
134 secret_input_phase1
= Bx
+ ID
+ X
+ B
+ PROTOID
+ encapsulate(verification
)
136 phase1_keys
= kdf_phase1(secret_input_phase1
)
137 enc_key
= phase1_keys
.take(ENC_KEY_LEN
)
138 mac_key
= phase1_keys
.take(MAC_KEY_LEN
)
139 p(["enc_key", "mac_key"], locals())
141 msg_0
= ID
+ B
+ X
+ enc(msg
, enc_key
)
142 mac
= mac_phase1(msg_0
, mac_key
)
145 client_handshake
= msg_0
+ mac
146 state
= dict(x
=x
, X
=X
, B
=B
, ID
=ID
, Bx
=Bx
, mac
=mac
, verification
=verification
)
148 p(["client_handshake"], locals())
150 return (client_handshake
, state
)
154 class Reject(Exception):
157 def server_part1(cmsg
, verification
, b
, B
, ID
):
158 assert len(B
) == PUB_KEY_LEN
159 assert len(ID
) == IDENTITY_LEN
160 assert len(b
) == SEC_KEY_LEN
162 if len(cmsg
) < (IDENTITY_LEN
+ PUB_KEY_LEN
* 2 + MAC_LEN
):
165 mac_covered_portion
= cmsg
[0:-MAC_LEN
]
167 cmsg_id
= cmsg
.take(IDENTITY_LEN
)
168 cmsg_B
= cmsg
.take(PUB_KEY_LEN
)
169 cmsg_X
= cmsg
.take(PUB_KEY_LEN
)
170 cmsg_msg
= cmsg
.take(cmsg
.remaining() - MAC_LEN
)
171 cmsg_mac
= cmsg
.take(MAC_LEN
)
173 assert cmsg
.exhausted()
175 # XXXX for real purposes, you would use constant-time checks here
176 if cmsg_id
!= ID
or cmsg_B
!= B
:
179 Xb
= curve25519(cmsg_X
, b
)
180 secret_input_phase1
= Xb
+ ID
+ cmsg_X
+ B
+ PROTOID
+ encapsulate(verification
)
182 phase1_keys
= kdf_phase1(secret_input_phase1
)
183 enc_key
= phase1_keys
.take(ENC_KEY_LEN
)
184 mac_key
= phase1_keys
.take(MAC_KEY_LEN
)
186 mac_received
= mac_phase1(mac_covered_portion
, mac_key
)
187 if mac_received
!= cmsg_mac
:
190 client_msg
= enc(cmsg_msg
, enc_key
)
195 mac_received
=mac_received
,
198 verification
=verification
)
200 return (client_msg
, state
)
202 def server_part2(state
, server_msg
):
208 mac_received
= state
['mac_received']
209 verification
= state
['verification']
211 p(["server_msg"], locals())
214 p(["y", "Y"], locals())
215 Xy
= curve25519(X
, y
)
217 secret_input
= Xy
+ Xb
+ ID
+ B
+ X
+ Y
+ PROTOID
+ encapsulate(verification
)
218 key_seed
= h_key_seed(secret_input
)
219 verify
= h_verify(secret_input
)
220 p(["key_seed", "verify"], locals())
222 keys
= kdf_final(key_seed
)
223 server_enc_key
= keys
.take(ENC_KEY_LEN
)
224 p(["server_enc_key"], locals())
226 smsg_msg
= enc(server_msg
, server_enc_key
)
228 auth_input
= verify
+ ID
+ B
+ Y
+ X
+ mac_received
+ encapsulate(smsg_msg
) + PROTOID
+ b
"Server"
230 auth
= h_auth(auth_input
)
231 server_handshake
= Y
+ auth
+ smsg_msg
232 p(["auth", "server_handshake"], locals())
234 return (server_handshake
, keys
)
236 def client_phase2(state
, smsg
):
242 mac_sent
= state
['mac']
243 verification
= state
['verification']
245 if len(smsg
) < PUB_KEY_LEN
+ DIGEST_LEN
:
249 Y
= smsg
.take(PUB_KEY_LEN
)
250 auth_received
= smsg
.take(DIGEST_LEN
)
251 server_msg
= smsg
.take(smsg
.remaining())
255 secret_input
= Yx
+ Bx
+ ID
+ B
+ X
+ Y
+ PROTOID
+ encapsulate(verification
)
256 key_seed
= h_key_seed(secret_input
)
257 verify
= h_verify(secret_input
)
259 auth_input
= verify
+ ID
+ B
+ Y
+ X
+ mac_sent
+ encapsulate(server_msg
) + PROTOID
+ b
"Server"
261 auth
= h_auth(auth_input
)
262 if auth
!= auth_received
:
265 keys
= kdf_final(key_seed
)
266 enc_key
= keys
.take(ENC_KEY_LEN
)
268 server_msg_decrypted
= enc(server_msg
, enc_key
)
270 return (keys
, server_msg_decrypted
)
272 def p(varnames
, localvars
):
275 val
= localvars
[label
]
276 print('{} = "{}"'.format(label
, binascii
.b2a_hex(val
).decode("ascii")))
280 ID
= os
.urandom(IDENTITY_LEN
)
282 p(["b", "B", "ID"], locals())
284 print("# ============")
285 (c_handshake
, c_state
) = client_phase1(b
"hello world", b
"xyzzy", B
, ID
)
287 print("# ============")
289 (c_msg_got
, s_state
) = server_part1(c_handshake
, b
"xyzzy", b
, B
, ID
)
291 #print(repr(c_msg_got))
293 (s_handshake
, s_keys
) = server_part2(s_state
, b
"Hola Mundo")
295 print("# ============")
297 (c_keys
, s_msg_got
) = client_phase2(c_state
, s_handshake
)
299 #print(repr(s_msg_got))
301 c_keys_256
= c_keys
.take(256)
302 p(["c_keys_256"], locals())
304 assert (c_keys_256
== s_keys
.take(256))
307 if __name__
== '__main__':