Remove unused Python imports
[bitcoinplatinum.git] / test / functional / test_framework / key.py
blobaa91fb5b0d83272ab24255f85a73c59bcee1e8c4
1 # Copyright (c) 2011 Sam Rushing
2 """ECC secp256k1 OpenSSL wrapper.
4 WARNING: This module does not mlock() secrets; your private keys may end up on
5 disk in swap! Use with caution!
7 This file is modified from python-bitcoinlib.
8 """
10 import ctypes
11 import ctypes.util
12 import hashlib
13 import sys
ssl = ctypes.cdll.LoadLibrary(ctypes.util.find_library('ssl') or 'libeay32')

# Explicit C prototypes for the OpenSSL entry points used by this module.
# Without a declared restype ctypes assumes a C int, which would truncate
# pointer results on 64-bit platforms, so every pointer-returning function is
# declared with ctypes.c_void_p.
for _name, _restype, _argtypes in (
    ('BN_new', ctypes.c_void_p, []),
    ('BN_bin2bn', ctypes.c_void_p, [ctypes.c_char_p, ctypes.c_int, ctypes.c_void_p]),
    ('BN_CTX_free', None, [ctypes.c_void_p]),
    ('BN_CTX_new', ctypes.c_void_p, []),
    # int ECDH_compute_key(void *out, size_t outlen, const EC_POINT *pub_key,
    #                      EC_KEY *ecdh, void *(*KDF)(...));
    # All five parameters are declared; the last one is the KDF callback.
    ('ECDH_compute_key', ctypes.c_int,
     [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]),
    ('ECDSA_sign', ctypes.c_int,
     [ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]),
    ('ECDSA_verify', ctypes.c_int,
     [ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p]),
    ('EC_KEY_free', None, [ctypes.c_void_p]),
    ('EC_KEY_new_by_curve_name', ctypes.c_void_p, [ctypes.c_int]),
    ('EC_KEY_get0_group', ctypes.c_void_p, [ctypes.c_void_p]),
    ('EC_KEY_get0_public_key', ctypes.c_void_p, [ctypes.c_void_p]),
    ('EC_KEY_set_private_key', ctypes.c_int, [ctypes.c_void_p, ctypes.c_void_p]),
    ('EC_KEY_set_conv_form', None, [ctypes.c_void_p, ctypes.c_int]),
    ('EC_KEY_set_public_key', ctypes.c_int, [ctypes.c_void_p, ctypes.c_void_p]),
    # i2o_ECPublicKey returns the encoded length (a C int), not a pointer.
    ('i2o_ECPublicKey', ctypes.c_int, [ctypes.c_void_p, ctypes.c_void_p]),
    ('EC_POINT_new', ctypes.c_void_p, [ctypes.c_void_p]),
    ('EC_POINT_free', None, [ctypes.c_void_p]),
    ('EC_POINT_mul', ctypes.c_int,
     [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]),
):
    _func = getattr(ssl, _name)
    _func.restype = _restype
    _func.argtypes = _argtypes
# Do not leave the loop variables behind in the module namespace.
del _name, _restype, _argtypes, _func
# OpenSSL numeric identifier of the curve used with ECDSA
# (value from openssl/obj_mac.h).
NID_secp256k1 = 714

# Group order n of secp256k1, and floor(n / 2). The half order is the
# threshold CECKey.sign() uses to decide whether an S value is "low".
SECP256K1_ORDER = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
SECP256K1_ORDER_HALF = SECP256K1_ORDER >> 1
# Thx to Sam Devlin for the ctypes magic 64-bit fix.
def _check_result(val, func, args):
    """ctypes ``errcheck`` hook for OpenSSL functions returning a pointer.

    ctypes calls it as ``errcheck(result, func, args)``, where ``result`` is
    the return value already converted according to ``func.restype``.

    Returns:
        ``val`` wrapped in a ``ctypes.c_void_p``, so the full pointer value
        can be handed back to OpenSSL (including by ``ctypes.byref``).

    Raises:
        ValueError: if the foreign function returned NULL.
    """
    # With restype = c_void_p ctypes converts a NULL result to None, not 0,
    # so a plain falsiness test is required; ``val == 0`` never matched None
    # and let NULL pointers through unchecked.
    if not val:
        raise ValueError('%s returned NULL' % getattr(func, '__name__', 'OpenSSL function'))
    return ctypes.c_void_p(val)
# Route every EC_KEY_new_by_curve_name() result through _check_result, which
# is meant to reject a NULL EC_KEY and wraps the pointer in ctypes.c_void_p.
# (The c_void_p restype is already declared with the other OpenSSL prototypes
# above, so it is not repeated here.)
ssl.EC_KEY_new_by_curve_name.errcheck = _check_result
class CECKey():
    """Wrapper around OpenSSL's EC_KEY for the secp256k1 curve.

    The OpenSSL EC_KEY handle is held in ``self.k``; it is created in
    __init__ and released in __del__.
    """

    # Point conversion forms passed to EC_KEY_set_conv_form() by set_compressed().
    POINT_CONVERSION_COMPRESSED = 2
    POINT_CONVERSION_UNCOMPRESSED = 4

    def __init__(self):
        self.k = ssl.EC_KEY_new_by_curve_name(NID_secp256k1)

    def __del__(self):
        # ``ssl`` is checked presumably because module globals may already be
        # gone during interpreter shutdown; ``k`` is absent if __init__ raised.
        k = getattr(self, 'k', None)
        if ssl and k is not None:
            ssl.EC_KEY_free(k)
        self.k = None

    def set_secretbytes(self, secret):
        """Load a private key from ``secret`` and derive its public key.

        Exactly 32 bytes are read from ``secret`` and interpreted by
        BN_bin2bn as a big-endian integer.
        NOTE(review): a shorter ``secret`` makes OpenSSL read past the end of
        the buffer; callers are expected to pass at least 32 bytes.

        Returns:
            The EC_KEY handle ``self.k``.

        Raises:
            ValueError: if the public key cannot be derived from ``secret``.
        """
        priv_key = ssl.BN_bin2bn(secret, 32, ssl.BN_new())
        group = ssl.EC_KEY_get0_group(self.k)
        pub_key = ssl.EC_POINT_new(group)
        ctx = ssl.BN_CTX_new()
        try:
            # pub_key = priv_key * G (a single multiplication is enough).
            if not ssl.EC_POINT_mul(group, pub_key, priv_key, None, None, ctx):
                raise ValueError("Could not derive public key from the supplied secret.")
            ssl.EC_KEY_set_private_key(self.k, priv_key)
            ssl.EC_KEY_set_public_key(self.k, pub_key)
        finally:
            # EC_KEY_set_private_key/EC_KEY_set_public_key store copies of
            # their arguments (per OpenSSL), so the temporaries are released
            # here on both the success and the error path. The private
            # BIGNUM is cleared before being freed since it holds the secret.
            ssl.EC_POINT_free(pub_key)
            ssl.BN_CTX_free(ctx)
            ssl.BN_clear_free(ctypes.c_void_p(priv_key))
        return self.k

    def set_privkey(self, key):
        """Decode ``key`` into this EC_KEY with OpenSSL's d2i_ECPrivateKey.

        Returns d2i_ECPrivateKey's result; zero means failure.
        NOTE(review): no restype is declared for d2i_ECPrivateKey, so ctypes
        returns it as a C int; only its zero/non-zero value is meaningful.
        """
        # The buffer is kept on the instance, presumably so it outlives the
        # call that reads from it.
        self.mb = ctypes.create_string_buffer(key)
        return ssl.d2i_ECPrivateKey(ctypes.byref(self.k), ctypes.byref(ctypes.pointer(self.mb)), len(key))

    def set_pubkey(self, key):
        """Decode the public key point ``key`` with OpenSSL's o2i_ECPublicKey.

        Returns o2i_ECPublicKey's result; zero means failure.
        NOTE(review): no restype is declared for o2i_ECPublicKey, so ctypes
        returns it as a C int; only its zero/non-zero value is meaningful.
        """
        self.mb = ctypes.create_string_buffer(key)
        return ssl.o2i_ECPublicKey(ctypes.byref(self.k), ctypes.byref(ctypes.pointer(self.mb)), len(key))

    def get_privkey(self):
        """Return this key's private key serialized by i2d_ECPrivateKey."""
        # First call with a NULL output pointer only reports the needed size.
        size = ssl.i2d_ECPrivateKey(self.k, 0)
        mb_pri = ctypes.create_string_buffer(size)
        ssl.i2d_ECPrivateKey(self.k, ctypes.byref(ctypes.pointer(mb_pri)))
        return mb_pri.raw

    def get_pubkey(self):
        """Return the public key point serialized by i2o_ECPublicKey.

        The encoding follows the conversion form chosen with set_compressed().

        Raises:
            ValueError: if OpenSSL cannot encode the public key.
        """
        # First call with a NULL output pointer only reports the needed size.
        # A failure yields 0 (or None when the restype is c_void_p).
        size = ssl.i2o_ECPublicKey(self.k, 0)
        if not size:
            raise ValueError('Could not encode the public key.')
        mb = ctypes.create_string_buffer(size)
        ssl.i2o_ECPublicKey(self.k, ctypes.byref(ctypes.pointer(mb)))
        return mb.raw

    def get_raw_ecdh_key(self, other_pubkey):
        """Return the 32-byte raw ECDH shared secret with ``other_pubkey``.

        ``other_pubkey`` must expose an EC_KEY handle as ``.k`` (e.g. another
        CECKey); its public point is combined with this key's private key.

        Raises:
            Exception: if ECDH_compute_key() does not produce 32 bytes.
        """
        ecdh_keybuffer = ctypes.create_string_buffer(32)
        # The trailing 0 is the KDF argument: no KDF, return the raw secret.
        r = ssl.ECDH_compute_key(ctypes.pointer(ecdh_keybuffer), 32,
                                 ssl.EC_KEY_get0_public_key(other_pubkey.k),
                                 self.k, 0)
        if r != 32:
            raise Exception('CECKey.get_raw_ecdh_key(): ECDH_compute_key() failed')
        return ecdh_keybuffer.raw

    def get_ecdh_key(self, other_pubkey, kdf=lambda k: hashlib.sha256(k).digest()):
        """Return ``kdf`` applied to the raw ECDH secret with ``other_pubkey``.

        The default KDF is a single SHA-256 of the raw secret.
        """
        # FIXME: be warned it's not clear what the kdf should be as a default
        r = self.get_raw_ecdh_key(other_pubkey)
        return kdf(r)

    def sign(self, hash, low_s = True):
        """Return a DER-encoded ECDSA signature of the 32-byte digest ``hash``.

        When ``low_s`` is true and the signature's S value exceeds half the
        curve order, S is replaced by ``SECP256K1_ORDER - S`` and the DER
        lengths are rewritten accordingly.

        Raises:
            TypeError: if ``hash`` is not a bytes instance.
            ValueError: if ``hash`` is not exactly 32 bytes long.
        """
        # FIXME: need unit tests for below cases
        if not isinstance(hash, bytes):
            raise TypeError('Hash must be bytes instance; got %r' % hash.__class__)
        if len(hash) != 32:
            raise ValueError('Hash must be exactly 32 bytes long')

        sig_size0 = ctypes.c_uint32()
        sig_size0.value = ssl.ECDSA_size(self.k)
        mb_sig = ctypes.create_string_buffer(sig_size0.value)
        result = ssl.ECDSA_sign(0, hash, len(hash), mb_sig, ctypes.byref(sig_size0), self.k)
        assert 1 == result
        # Take one snapshot of the buffer: every ``.raw`` access copies it.
        raw = mb_sig.raw
        siglen = sig_size0.value
        # DER layout: 0x30 <len> 0x02 <r len> <r> 0x02 <s len> <s>
        assert raw[0] == 0x30
        assert raw[1] == siglen - 2
        total_size = raw[1]
        assert raw[2] == 2
        r_size = raw[3]
        assert raw[4 + r_size] == 2
        s_size = raw[5 + r_size]
        s_value = int.from_bytes(raw[6 + r_size:6 + r_size + s_size], byteorder='big')
        if (not low_s) or s_value <= SECP256K1_ORDER_HALF:
            return raw[:siglen]

        low_s_value = SECP256K1_ORDER - s_value
        low_s_bytes = low_s_value.to_bytes(33, byteorder='big')
        # Minimal DER INTEGER: drop leading zero bytes, but keep one when the
        # next byte has its high bit set so the value stays positive.
        while len(low_s_bytes) > 1 and low_s_bytes[0] == 0 and low_s_bytes[1] < 0x80:
            low_s_bytes = low_s_bytes[1:]
        new_s_size = len(low_s_bytes)
        new_total_size_byte = (total_size + new_s_size - s_size).to_bytes(1, byteorder='big')
        new_s_size_byte = new_s_size.to_bytes(1, byteorder='big')
        # Reuse "0x02 <r len> <r> 0x02" unchanged, then splice in the new S.
        return b'\x30' + new_total_size_byte + raw[2:5 + r_size] + new_s_size_byte + low_s_bytes

    def verify(self, hash, sig):
        """Verify a DER signature.

        Returns True only when ECDSA_verify() reports 1 (valid signature).
        """
        return ssl.ECDSA_verify(0, hash, len(hash), sig, len(sig), self.k) == 1

    def set_compressed(self, compressed):
        """Select compressed (true) or uncompressed (false) public key encoding."""
        if compressed:
            form = self.POINT_CONVERSION_COMPRESSED
        else:
            form = self.POINT_CONVERSION_UNCOMPRESSED
        ssl.EC_KEY_set_conv_form(self.k, form)
class CPubKey(bytes):
    """An encapsulated public key

    Attributes:

    is_valid      - Corresponds to CPubKey.IsValid()
    is_fullyvalid - Corresponds to CPubKey.IsFullyValid()
    is_compressed - Corresponds to CPubKey.IsCompressed()
    """

    def __new__(cls, buf, _cec_key=None):
        """Create the key from the serialized bytes ``buf``.

        ``_cec_key`` is the CECKey used to parse and verify; a fresh CECKey
        is created when it is omitted.
        """
        self = super(CPubKey, cls).__new__(cls, buf)
        if _cec_key is None:
            _cec_key = CECKey()
        self._cec_key = _cec_key
        # A zero return from set_pubkey() marks the key as not fully valid.
        self.is_fullyvalid = _cec_key.set_pubkey(self) != 0
        return self

    @property
    def is_valid(self):
        """True when the serialized key is non-empty."""
        return len(self) > 0

    @property
    def is_compressed(self):
        """True when the serialized key is 33 bytes long."""
        return len(self) == 33

    def verify(self, hash, sig):
        """Verify DER signature ``sig`` of ``hash`` with the wrapped CECKey."""
        return self._cec_key.verify(hash, sig)

    def __str__(self):
        return repr(self)

    def __repr__(self):
        # Always have represent as b'<secret>' so test cases don't have to
        # change for py2/3.
        # Compare sys.version_info rather than the sys.version string: string
        # comparison is lexicographic and misorders versions such as '10.0'.
        if sys.version_info[0] >= 3:
            return '%s(%s)' % (self.__class__.__name__, super(CPubKey, self).__repr__())
        else:
            return '%s(b%s)' % (self.__class__.__name__, super(CPubKey, self).__repr__())