Edit changelog a little for clarity and conciseness
[tor.git] / src / test / slownacl_curve25519.py
blob4dabab61b6dcd56a12f8634ade7d431aab578a65
1 # This is the curve25519 implementation from Matthew Dempsky's "Slownacl"
2 # library. It is in the public domain.
4 # It isn't constant-time. Don't use it except for testing.
6 # Nick got the slownacl source from:
7 # https://github.com/mdempsky/dnscurve/tree/master/slownacl
9 __all__ = ['smult_curve25519_base', 'smult_curve25519']
11 import sys
# Modulus used for every coordinate computation in this file.
P = (1 << 255) - 19
# Curve coefficient that appears in the doubling formula (see double()).
A = 486662
def expmod(b, e, m):
    """Return ``b`` raised to the power ``e``, reduced modulo ``m``.

    b: integer base.
    e: non-negative integer exponent.
    m: non-zero integer modulus.

    Raises ValueError if ``e`` is negative.  The earlier recursive version
    halved the exponent until it reached zero, which never happens for a
    negative value (``-1 // 2 == -1``), so it ran until the interpreter's
    recursion limit was hit.
    """
    if e < 0:
        raise ValueError('expmod exponent must be non-negative')
    if e == 0:
        # Kept as an explicit case so the result is 1 regardless of m;
        # the built-in pow(b, 0, 1) would give 0 instead.
        return 1
    return pow(b, e, m)
def inv(x):
    """Return ``x ** (P - 2)`` modulo P.

    For non-zero ``x`` this is the multiplicative inverse of ``x`` modulo P
    (P being prime); for ``x == 0`` the result is 0.
    """
    exponent = P - 2
    return expmod(x, exponent, P)
25 # Addition and doubling formulas taken from Appendix D of "Curve25519:
26 # new Diffie-Hellman speed records".
def add(n, m, d):
    """Combine two projective points given a third reference point.

    n, m: the points to add, each an ``(x, z)`` pair of integers.
    d: an ``(x, z)`` pair; curve25519() always passes the base point here,
       which is the difference between the two multiples it is adding.

    Returns the resulting ``(x, z)`` pair with both entries reduced mod P.
    """
    x_n, z_n = n
    x_m, z_m = m
    x_d, z_d = d
    straight = x_m * x_n - z_m * z_n
    crossed = x_m * z_n - z_m * x_n
    x_out = 4 * straight ** 2 * z_d
    z_out = 4 * crossed ** 2 * x_d
    return (x_out % P, z_out % P)
def double(n):
    """Double a projective point.

    n: an ``(x, z)`` pair of integers.

    Returns the doubled point as an ``(x, z)`` pair reduced mod P.
    """
    x_n, z_n = n
    x_sq = x_n * x_n
    z_sq = z_n * z_n
    x_z = x_n * z_n
    x_out = (x_sq - z_sq) ** 2
    z_out = 4 * x_z * (x_sq + A * x_z + z_sq)
    return (x_out % P, z_out % P)
def curve25519(n, base):
    """Return the x-coordinate of ``n`` times the point whose x-coordinate is ``base``.

    n: positive integer scalar.
    base: integer x-coordinate of the point being multiplied.

    Returns an integer reduced modulo P.

    Raises ValueError if ``n`` is less than 1.  The earlier recursive
    formulation only stopped when the halved scalar reached exactly 1, so
    zero or negative scalars recursed until the interpreter gave up.
    """
    if n < 1:
        raise ValueError('Curve25519 scalar must be a positive integer')
    one = (base, 1)
    # Invariant: (lo, hi) are the k-th and (k+1)-th multiples of the base
    # point, where k is the leading part of n's binary expansion consumed
    # so far.  We start with k == 1 (the top bit of n).
    lo, hi = one, double(one)
    # Walk the remaining bits of n from most to least significant.
    for shift in range(n.bit_length() - 2, -1, -1):
        if (n >> shift) & 1:
            # k -> 2k + 1
            lo, hi = add(lo, hi, one), double(hi)
        else:
            # k -> 2k
            lo, hi = double(lo), add(lo, hi, one)
    x, z = lo
    # Convert from projective (x, z) back to a single coordinate x / z.
    return (x * inv(z)) % P
# Byte-string helpers that differ between Python 2 and Python 3.
#
# The check uses sys.version_info rather than comparing the sys.version
# string: string comparison is lexicographic, so a two-digit major version
# such as '10.0' would sort before '3'.
if sys.version_info[0] < 3:
    def b2i(c):
        """Convert one element of a byte string to an integer."""
        return ord(c)

    def i2b(i):
        """Convert an integer to one element of a byte string."""
        return chr(i)

    def ba2bs(ba):
        """Join a list of i2b() results into a byte string."""
        return "".join(ba)
else:
    def b2i(c):
        """Convert one element of a byte string to an integer.

        Indexing ``bytes`` already yields integers, so this is the identity.
        """
        return c

    def i2b(i):
        """Convert an integer to one element of a byte string (identity)."""
        return i

    def ba2bs(ba):
        """Turn a list of i2b() results into a ``bytes`` object."""
        return bytes(ba)
def unpack(s):
    """Decode a 32-element byte string as a little-endian integer.

    s: byte string of length 32; element 0 is the least significant byte.

    Raises ValueError if ``s`` does not have exactly 32 elements.
    """
    if len(s) != 32:
        raise ValueError('Invalid Curve25519 argument')
    total = 0
    for index in range(32):
        total += b2i(s[index]) << (8 * index)
    return total
def pack(n):
    """Encode the low 256 bits of integer ``n`` as a 32-element byte string.

    The least significant byte comes first.
    """
    octets = []
    for index in range(32):
        octets.append(i2b((n >> (8 * index)) & 255))
    return ba2bs(octets)
def clamp(n):
    """Force integer ``n`` into the shape used for Curve25519 scalars.

    The three lowest bits and bit 255 are cleared, and bit 254 is set.
    """
    low_three = 7
    bit_255 = 128 << (8 * 31)
    bit_254 = 64 << (8 * 31)
    return (n & ~low_three & ~bit_255) | bit_254
def smult_curve25519(n, p):
    """Multiply the point encoded in ``p`` by the scalar encoded in ``n``.

    n: 32-element byte string; decoded with unpack() and then clamped.
    p: 32-element byte string; decoded with unpack() as an x-coordinate.

    Returns the resulting x-coordinate as a 32-element byte string.
    Raises ValueError (from unpack) if either argument is not 32 long.
    """
    scalar = clamp(unpack(n))
    point_x = unpack(p)
    product = curve25519(scalar, point_x)
    return pack(product)
def smult_curve25519_base(n):
    """Multiply the fixed base point (x-coordinate 9) by the scalar in ``n``.

    n: 32-element byte string; decoded with unpack() and then clamped.

    Returns the resulting x-coordinate as a 32-element byte string.
    Raises ValueError (from unpack) if ``n`` is not 32 long.
    """
    base_x = 9
    scalar = clamp(unpack(n))
    return pack(curve25519(scalar, base_x))
93 # This part I'm adding in for compatibility with the curve25519 python
94 # module. -Nick
96 import os
class Private:
    """Curve25519 private key, shaped like the curve25519 module's Private.

    The clamped 32-byte key is stored in ``self.private``.
    """

    def __init__(self, secret=None, seed=None):
        """Create a private key.

        secret: optional 32-element byte string to use as the key material.
            It is clamped before being stored.
        seed: optional byte string.  When given without ``secret``, the key
            material is derived from it deterministically by hashing.

        With neither argument the key material comes from os.urandom().
        If both are given, ``secret`` takes precedence and ``seed`` is
        ignored.  (Both arguments used to be ignored entirely, so every
        instance silently got a fresh random key.)

        Raises ValueError (from unpack) if ``secret`` is not 32 long.
        """
        if secret is None:
            if seed is None:
                secret = os.urandom(32)
            else:
                # Imported here because only the seeded path needs it.
                import hashlib
                # NOTE(review): intended to mirror the seed derivation of
                # the curve25519 python module; confirm the prefix against
                # that module if seeded keys must match it exactly.
                secret = hashlib.sha256(b"curve25519-private:" + seed).digest()
        self.private = pack(clamp(unpack(secret)))

    def get_public(self):
        """Return the Public key corresponding to this private key."""
        return Public(smult_curve25519_base(self.private))

    def get_shared_key(self, public, hashfn):
        """Return ``hashfn`` applied to the raw shared secret.

        public: object whose ``public`` attribute is a 32-element byte string.
        hashfn: callable taking the 32-element shared secret.
        """
        return hashfn(smult_curve25519(self.private, public.public))

    def serialize(self):
        """Return the stored (clamped) private key bytes."""
        return self.private
class Public:
    """Holder for a public key value.

    The value passed to the constructor is stored unchanged in
    ``self.public``.
    """

    def __init__(self, public):
        """Store ``public`` as given; no validation or copying is done."""
        self.public = public

    def serialize(self):
        """Return the stored public key value."""
        stored = self.public
        return stored