[tests] do not allow assert_raises_message to be called with JSONRPCException
[bitcoinplatinum.git] / test / functional / test_framework / util.py
blobed35bf576eb83b9a3c6587594beca73a1b00bfb5
1 #!/usr/bin/env python3
2 # Copyright (c) 2014-2016 The Bitcoin Core developers
3 # Distributed under the MIT software license, see the accompanying
4 # file COPYING or http://www.opensource.org/licenses/mit-license.php.
5 """Helpful routines for regression testing."""
7 from base64 import b64encode
8 from binascii import hexlify, unhexlify
9 from decimal import Decimal, ROUND_DOWN
10 import hashlib
11 import json
12 import logging
13 import os
14 import random
15 import re
16 from subprocess import CalledProcessError
17 import time
19 from . import coverage
20 from .authproxy import AuthServiceProxy, JSONRPCException
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger("TestFramework.utils")
24 # Assert functions
25 ##################
def assert_fee_amount(fee, tx_size, fee_per_kB):
    """Assert that `fee` is within the range expected for a tx of `tx_size`.

    The fee must be at least tx_size * fee_per_kB / 1000 and at most the fee
    implied by a transaction two bytes larger (wallet estimation slack).
    """
    target_fee = tx_size * fee_per_kB / 1000
    if fee < target_fee:
        raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee)))
    # allow the wallet's estimation to be at most 2 bytes off
    high_fee_limit = (tx_size + 2) * fee_per_kB / 1000
    if fee > high_fee_limit:
        raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee)))
def assert_equal(thing1, thing2, *args):
    """Raise AssertionError unless every argument compares equal to thing1."""
    others = (thing2,) + args
    if any(thing1 != other for other in others):
        raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1,) + others))
def assert_greater_than(thing1, thing2):
    """Raise AssertionError unless thing1 > thing2."""
    if thing1 <= thing2:
        raise AssertionError("%s <= %s" % (thing1, thing2))
def assert_greater_than_or_equal(thing1, thing2):
    """Raise AssertionError unless thing1 >= thing2."""
    if thing1 < thing2:
        raise AssertionError("%s < %s" % (thing1, thing2))
def assert_raises(exc, fun, *args, **kwds):
    """Assert that fun(*args, **kwds) raises an exception of type exc.

    Convenience wrapper around assert_raises_message with no message check.
    """
    assert_raises_message(exc, None, fun, *args, **kwds)
def assert_raises_message(exc, message, fun, *args, **kwds):
    """Assert that fun(*args, **kwds) raises exc, optionally with a message.

    Args:
        exc: the expected exception type.
        message: if not None, a substring that must appear in the string
            representation of the raised exception.
        fun (function): the function to call.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.

    Raises AssertionError if no exception is raised, the wrong exception
    type is raised, the message substring is missing, or a JSONRPCException
    is raised (use assert_raises_jsonrpc() for RPC failures instead).
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException:
        raise AssertionError("Use assert_raises_jsonrpc() to test RPC failures")
    except exc as e:
        # Generic exceptions carry no .error dict (that is JSONRPC-specific,
        # and JSONRPCException is excluded above), so compare against str(e).
        if message is not None and message not in str(e):
            raise AssertionError("Expected substring not found:" + str(e))
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    else:
        raise AssertionError("No exception raised")
def assert_raises_process_error(returncode, output, fun, *args, **kwds):
    """Execute a process and assert the process return code and output.

    Calls function `fun` with arguments `args` and `kwds`. Catches a
    CalledProcessError and verifies that the return code and output are as
    expected. Throws AssertionError if no CalledProcessError was raised or
    if the return code and output are not as expected.

    Args:
        returncode (int): the process return code.
        output (string): [a substring of] the process output.
        fun (function): the function to call. This should execute a process.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    try:
        fun(*args, **kwds)
    except CalledProcessError as e:
        if e.returncode != returncode:
            raise AssertionError("Unexpected returncode %i" % e.returncode)
        if output not in e.output:
            raise AssertionError("Expected substring not found:" + e.output)
    else:
        raise AssertionError("No exception raised")
def assert_raises_jsonrpc(code, message, fun, *args, **kwds):
    """Run an RPC and verify that a specific JSONRPC exception code and message is raised.

    Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
    and verifies that the error code and message are as expected. Throws AssertionError if
    no JSONRPCException was raised or if the error code/message are not as expected.

    Args:
        code (int), optional: the error code returned by the RPC call (defined
            in src/rpc/protocol.h). Set to None if checking the error code is not required.
        message (string), optional: [a substring of] the error string returned by the
            RPC call. Set to None if checking the error string is not required.
        fun (function): the function to call. This should be the name of an RPC.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    assert try_rpc(code, message, fun, *args, **kwds), "No exception raised"
def try_rpc(code, message, fun, *args, **kwds):
    """Attempt to run an RPC command.

    If the command fails with a JSONRPCException, validate the error code
    and message substring when `code`/`message` are not None. Returns True
    when a JSONRPCException was raised, False when the call succeeded. Any
    other exception is converted into an AssertionError.
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException as e:
        # The RPC failed as anticipated; check code and message when requested.
        if code is not None and code != e.error["code"]:
            raise AssertionError("Unexpected JSONRPC error code %i" % e.error["code"])
        if message is not None and message not in e.error['message']:
            raise AssertionError("Expected substring not found:" + e.error['message'])
        return True
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    return False
def assert_is_hex_string(string):
    """Assert that `string` parses as a base-16 integer."""
    try:
        int(string, 16)
    except Exception as err:
        raise AssertionError(
            "Couldn't interpret %r as hexadecimal; raised: %s" % (string, err))
def assert_is_hash_string(string, length=64):
    """Assert that `string` looks like a lowercase hex hash of `length` chars.

    Pass a falsy `length` to skip the length check.
    """
    if not isinstance(string, str):
        raise AssertionError("Expected a string, got type %r" % type(string))
    if length and len(string) != length:
        raise AssertionError(
            "String of length %d expected; got %d" % (length, len(string)))
    if not re.match('[abcdef0-9]+$', string):
        raise AssertionError(
            "String %r contains invalid characters for a hash." % string)
def assert_array_result(object_array, to_match, expected, should_not_find=False):
    """
    Pass in array of JSON objects, a dictionary with key/value pairs
    to match against, and another dictionary with expected key/value
    pairs.
    If the should_not_find flag is true, to_match should not be found
    in object_array
    """
    if should_not_find:
        assert_equal(expected, {})
    num_matched = 0
    for item in object_array:
        # Evaluate every to_match key (not short-circuited, so a missing key
        # always raises KeyError) and skip items that don't fully match.
        mismatches = [key for key, value in to_match.items() if item[key] != value]
        if mismatches:
            continue
        if should_not_find:
            num_matched += 1
        for key, value in expected.items():
            if item[key] != value:
                raise AssertionError("%s : expected %s=%s" % (str(item), str(key), str(value)))
        num_matched += 1
    if num_matched == 0 and not should_not_find:
        raise AssertionError("No objects matched %s" % (str(to_match)))
    if num_matched > 0 and should_not_find:
        raise AssertionError("Objects were found %s" % (str(to_match)))
171 # Utility functions
172 ###################
def check_json_precision():
    """Make sure json library being used does not lose precision converting BTC values"""
    value = Decimal("20000000.00000003")
    roundtrip = json.loads(json.dumps(float(value)))
    if int(roundtrip * 1.0e8) != 2000000000000003:
        raise RuntimeError("JSON encode/decode loses precision")
def count_bytes(hex_string):
    """Return the number of bytes encoded by a hex string."""
    return len(bytes.fromhex(hex_string))
def bytes_to_hex_str(byte_str):
    """Encode a bytes-like object as a lowercase ASCII hex string."""
    return hexlify(byte_str).decode('ascii')
def hash256(byte_str):
    """Return the byte-reversed double-SHA256 digest of byte_str."""
    first_round = hashlib.sha256(byte_str).digest()
    second_round = hashlib.sha256(first_round).digest()
    return second_round[::-1]
def hex_str_to_bytes(hex_str):
    """Decode an ASCII hex string into the bytes it represents."""
    return unhexlify(hex_str.encode('ascii'))
def str_to_b64str(string):
    """Encode a text string as base64 (UTF-8 in, ASCII out)."""
    return b64encode(string.encode('utf-8')).decode('ascii')
def satoshi_round(amount):
    """Round `amount` down to eight decimal places (one satoshi)."""
    one_satoshi = Decimal('0.00000001')
    return Decimal(amount).quantize(one_satoshi, rounding=ROUND_DOWN)
def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None):
    """Poll `predicate` until it returns a truthy value.

    Stops after `attempts` evaluations or `timeout` seconds, whichever comes
    first; when neither limit is given the timeout defaults to 60 seconds.
    If `lock` is supplied it is held while the predicate runs. On timeout,
    the assert helpers below raise, reporting which limit was exceeded.
    """
    if attempts == float('inf') and timeout == float('inf'):
        timeout = 60
    attempt = 0
    # From here on, `timeout` holds the absolute deadline.
    timeout += time.time()

    while attempt < attempts and time.time() < timeout:
        if lock:
            with lock:
                satisfied = predicate()
        else:
            satisfied = predicate()
        if satisfied:
            return
        attempt += 1
        time.sleep(0.05)

    # Print the cause of the timeout
    assert_greater_than(attempts, attempt)
    assert_greater_than(timeout, time.time())
    raise RuntimeError('Unreachable')
225 # RPC/P2P connection constants and functions
226 ############################################
# The maximum number of nodes a single test can spawn
MAX_NODES = 8
# Don't assign rpc or p2p ports lower than this
PORT_MIN = 11000
# The number of ports to "reserve" for p2p and rpc, each
# (see p2p_port()/rpc_port() below for how ports are derived)
PORT_RANGE = 5000
class PortSeed:
    # Must be initialized with a unique integer for each process;
    # p2p_port() and rpc_port() mix this value into their port calculation.
    n = None
def get_rpc_proxy(url, node_number, timeout=None, coveragedir=None):
    """Build a coverage-wrapped AuthServiceProxy for the given RPC URL.

    Args:
        url (str): URL of the RPC server to call
        node_number (int): the node number (or id) that this calls to

    Kwargs:
        timeout (int): HTTP timeout in seconds
        coveragedir (str): directory for coverage logs, or None to disable

    Returns:
        AuthServiceProxy. convenience object for making RPC calls.
    """
    proxy_kwargs = {}
    if timeout is not None:
        proxy_kwargs['timeout'] = timeout

    proxy = AuthServiceProxy(url, **proxy_kwargs)
    proxy.url = url  # store URL on proxy for info

    if coveragedir:
        coverage_logfile = coverage.get_filename(coveragedir, node_number)
    else:
        coverage_logfile = None

    return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile)
def p2p_port(n):
    """Return the p2p listen port for node number n (derived from PortSeed.n)."""
    assert n <= MAX_NODES
    offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + n + offset
def rpc_port(n):
    """Return the rpc port for node number n (derived from PortSeed.n)."""
    offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + PORT_RANGE + n + offset
def rpc_url(datadir, i, rpchost=None):
    """Return the authenticated RPC URL for node i.

    Credentials are read from the node's datadir. `rpchost` may override
    the default host, optionally as "host:port".
    """
    rpc_u, rpc_p = get_auth_cookie(datadir)
    host = '127.0.0.1'
    port = rpc_port(i)
    if rpchost:
        parts = rpchost.split(':')
        if len(parts) == 2:
            # "host:port" form overrides both.
            host, port = parts
        else:
            host = rpchost
    return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port))
283 # Node functions
284 ################
def initialize_datadir(dirname, n):
    """Create (if needed) and write the config for node n's datadir; return its path."""
    datadir = os.path.join(dirname, "node" + str(n))
    if not os.path.isdir(datadir):
        os.makedirs(datadir)
    config_lines = [
        "regtest=1",
        "port=" + str(p2p_port(n)),
        "rpcport=" + str(rpc_port(n)),
        "listenonion=0",
    ]
    with open(os.path.join(datadir, "bitcoin.conf"), 'w', encoding='utf8') as f:
        f.write("\n".join(config_lines) + "\n")
    return datadir
def get_datadir_path(dirname, n):
    """Return the datadir path for node n (without creating it)."""
    return os.path.join(dirname, "node" + str(n))
def get_auth_cookie(datadir):
    """Read RPC credentials for the node in `datadir`.

    Reads rpcuser/rpcpassword from bitcoin.conf; a regtest .cookie file,
    when present, overrides both. Raises ValueError if no credentials
    can be found.
    """
    user = None
    password = None
    conf_path = os.path.join(datadir, "bitcoin.conf")
    if os.path.isfile(conf_path):
        with open(conf_path, 'r', encoding='utf8') as f:
            for line in f:
                if line.startswith("rpcuser="):
                    # Ensure that there is only one rpcuser line
                    assert user is None
                    user = line.split("=")[1].strip("\n")
                if line.startswith("rpcpassword="):
                    # Ensure that there is only one rpcpassword line
                    assert password is None
                    password = line.split("=")[1].strip("\n")
    cookie_path = os.path.join(datadir, "regtest", ".cookie")
    if os.path.isfile(cookie_path):
        with open(cookie_path, 'r') as f:
            split_userpass = f.read().split(':')
            user = split_userpass[0]
            password = split_userpass[1]
    if user is None or password is None:
        raise ValueError("No RPC credentials")
    return user, password
def log_filename(dirname, n_node, logname):
    """Return the path of `logname` inside node n_node's regtest directory."""
    return os.path.join(dirname, "node" + str(n_node), "regtest", logname)
def get_bip9_status(node, key):
    """Return the bip9_softforks entry for `key` from the node's getblockchaininfo."""
    return node.getblockchaininfo()['bip9_softforks'][key]
def set_node_times(nodes, t):
    """Set mocktime `t` on every node in `nodes`."""
    for rpc_node in nodes:
        rpc_node.setmocktime(t)
def disconnect_nodes(from_connection, node_num):
    """Disconnect every peer of from_connection whose subver tags node_num.

    Waits (polling up to 50 times at 0.1s) for the matching peers to
    disappear from getpeerinfo before returning.
    """
    tag = "testnode%d" % node_num
    matching_ids = [peer['id'] for peer in from_connection.getpeerinfo() if tag in peer['subver']]
    for peer_id in matching_ids:
        from_connection.disconnectnode(nodeid=peer_id)

    for _ in range(50):
        remaining = [peer['id'] for peer in from_connection.getpeerinfo() if tag in peer['subver']]
        if remaining == []:
            break
        time.sleep(0.1)
    else:
        raise AssertionError("timed out waiting for disconnect")
def connect_nodes(from_connection, node_num):
    """Ask from_connection to connect to node node_num and wait for the handshake."""
    target = "127.0.0.1:" + str(p2p_port(node_num))
    from_connection.addnode(target, "onetry")
    # poll until version handshake complete to avoid race conditions
    # with transaction relaying
    while not all(peer['version'] != 0 for peer in from_connection.getpeerinfo()):
        time.sleep(0.1)
def connect_nodes_bi(nodes, a, b):
    """Connect nodes[a] and nodes[b] to each other (both directions)."""
    for src, dst in ((a, b), (b, a)):
        connect_nodes(nodes[src], dst)
def sync_blocks(rpc_connections, *, wait=1, timeout=60):
    """
    Wait until everybody has the same tip.

    sync_blocks needs to be called with an rpc_connections set that has least
    one node already synced to the latest, stable tip, otherwise there's a
    chance it might return before all nodes are stably synced.
    """
    # Use getblockcount() instead of waitforblockheight() to determine the
    # initial max height because the two RPCs look at different internal global
    # variables (chainActive vs latestBlock) and the former gets updated
    # earlier.
    maxheight = max(conn.getblockcount() for conn in rpc_connections)
    start_time = cur_time = time.time()
    deadline = start_time + timeout
    while cur_time <= deadline:
        tips = [conn.waitforblockheight(maxheight, int(wait * 1000)) for conn in rpc_connections]
        if all(tip["height"] == maxheight for tip in tips):
            if all(tip["hash"] == tips[0]["hash"] for tip in tips):
                return
            raise AssertionError("Block sync failed, mismatched block hashes:{}".format(
                "".join("\n {!r}".format(tip) for tip in tips)))
        cur_time = time.time()
    raise AssertionError("Block sync to height {} timed out:{}".format(
        maxheight, "".join("\n {!r}".format(tip) for tip in tips)))
def sync_chain(rpc_connections, *, wait=1, timeout=60):
    """
    Wait until everybody has the same best block
    """
    while timeout > 0:
        best_hash = [conn.getbestblockhash() for conn in rpc_connections]
        if best_hash.count(best_hash[0]) == len(best_hash):
            return
        time.sleep(wait)
        timeout -= wait
    raise AssertionError("Chain sync failed: Best block hashes don't match")
def sync_mempools(rpc_connections, *, wait=1, timeout=60):
    """
    Wait until everybody has the same transactions in their memory
    pools
    """
    while timeout > 0:
        reference_pool = set(rpc_connections[0].getrawmempool())
        if all(set(conn.getrawmempool()) == reference_pool for conn in rpc_connections[1:]):
            return
        time.sleep(wait)
        timeout -= wait
    raise AssertionError("Mempool sync failed")
410 # Transaction/Block functions
411 #############################
def find_output(node, txid, amount):
    """
    Return index to output of txid with value amount
    Raises exception if there is none.
    """
    txdata = node.getrawtransaction(txid, 1)
    for index, vout in enumerate(txdata["vout"]):
        if vout["value"] == amount:
            return index
    raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount)))
def gather_inputs(from_node, amount_needed, confirmations_required=1):
    """
    Return a random set of unspent txouts that are enough to pay amount_needed
    """
    assert confirmations_required >= 0
    utxo = from_node.listunspent(confirmations_required)
    random.shuffle(utxo)
    inputs = []
    total_in = Decimal("0.00000000")
    while total_in < amount_needed and utxo:
        candidate = utxo.pop()
        total_in += candidate["amount"]
        inputs.append({"txid": candidate["txid"], "vout": candidate["vout"], "address": candidate["address"]})
    if total_in < amount_needed:
        raise RuntimeError("Insufficient funds: need %d, have %d" % (amount_needed, total_in))
    return (total_in, inputs)
def make_change(from_node, amount_in, amount_out, fee):
    """
    Create change output(s), return them
    """
    outputs = {}
    spent = amount_out + fee
    change = amount_in - spent
    if change > 2 * spent:
        # Create an extra change output to break up big inputs
        change_address = from_node.getnewaddress()
        # Split change in two, being careful of rounding:
        outputs[change_address] = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
        change = amount_in - spent - outputs[change_address]
    if change > 0:
        outputs[from_node.getnewaddress()] = change
    return outputs
def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants):
    """
    Create a random transaction.
    Returns (txid, hex-encoded-transaction-data, fee)
    """
    sender = random.choice(nodes)
    recipient = random.choice(nodes)
    fee = min_fee + fee_increment * random.randint(0, fee_variants)

    (total_in, inputs) = gather_inputs(sender, amount + fee)
    outputs = make_change(sender, total_in, amount, fee)
    outputs[recipient.getnewaddress()] = float(amount)

    rawtx = sender.createrawtransaction(inputs, outputs)
    signresult = sender.signrawtransaction(rawtx)
    txid = sender.sendrawtransaction(signresult["hex"], True)

    return (txid, signresult["hex"], fee)
def create_confirmed_utxos(fee, node, count):
    """Ensure `node` holds at least `count` confirmed utxos and return them.

    Pass in a fee that is sufficient for relay and mining new transactions.
    Existing utxos are split in two (minus the fee) until enough exist, then
    everything is confirmed by mining out the mempool.
    """
    to_generate = int(0.5 * count) + 101
    while to_generate > 0:
        node.generate(min(25, to_generate))
        to_generate -= 25
    utxos = node.listunspent()
    iterations = count - len(utxos)
    addr1 = node.getnewaddress()
    addr2 = node.getnewaddress()
    if iterations <= 0:
        return utxos
    for _ in range(iterations):
        utxo = utxos.pop()
        send_value = utxo['amount'] - fee
        # Split the value evenly across two fresh addresses.
        outputs = {
            addr1: satoshi_round(send_value / 2),
            addr2: satoshi_round(send_value / 2),
        }
        raw_tx = node.createrawtransaction([{"txid": utxo["txid"], "vout": utxo["vout"]}], outputs)
        signed_tx = node.signrawtransaction(raw_tx)["hex"]
        node.sendrawtransaction(signed_tx)

    while node.getmempoolinfo()['size'] > 0:
        node.generate(1)

    utxos = node.listunspent()
    assert len(utxos) >= count
    return utxos
def gen_return_txouts():
    """Create large OP_RETURN txouts that can be appended to a transaction
    to make it large (helper for constructing large transactions).

    Returns the hex for 128 txouts, preceded by the "81" count prefix that
    create_lots_of_big_transactions() splices into a raw transaction.
    """
    # One script_pubkey: OP_RETURN OP_PUSH2 followed by 512 bytes of 0x01.
    script_pubkey = "6a4d0200" + "01" * 512
    # Each txout: zero value, "fd0402" script length, then the script.
    single_txout = "0000000000000000" + "fd0402" + script_pubkey
    return "81" + single_txout * 128
def create_tx(node, coinbase, to_address, amount):
    """Create, sign and return (as hex) a tx spending output 0 of `coinbase`."""
    spend_inputs = [{"txid": coinbase, "vout": 0}]
    spend_outputs = {to_address: amount}
    rawtx = node.createrawtransaction(spend_inputs, spend_outputs)
    signresult = node.signrawtransaction(rawtx)
    assert_equal(signresult["complete"], True)
    return signresult["hex"]
def create_lots_of_big_transactions(node, txouts, utxos, num, fee):
    """Create `num` large transactions by splicing `txouts` into each raw tx.

    Spends one utxo per transaction (change minus `fee` back to one address);
    see gen_return_txouts() above for the spliced payload. Returns the txids.
    """
    addr = node.getnewaddress()
    txids = []
    for _ in range(num):
        utxo = utxos.pop()
        inputs = [{"txid": utxo["txid"], "vout": utxo["vout"]}]
        change = utxo['amount'] - fee
        outputs = {addr: satoshi_round(change)}
        rawtx = node.createrawtransaction(inputs, outputs)
        # Splice the large txouts in at a fixed offset of the raw hex.
        newtx = rawtx[0:92] + txouts + rawtx[94:]
        signresult = node.signrawtransaction(newtx, None, None, "NONE")
        txid = node.sendrawtransaction(signresult["hex"], True)
        txids.append(txid)
    return txids
def mine_large_block(node, utxos=None):
    """Fill a block with large transactions and mine it.

    Creates 14 large transactions (each ~66k, together close to the 1MB
    block limit), drawing from `utxos` or the node's wallet, then mines
    one block.
    """
    # generate a 66k transaction,
    # and 14 of them is close to the 1MB block limit
    num = 14
    large_txouts = gen_return_txouts()
    utxos = utxos if utxos is not None else []
    if len(utxos) < num:
        # Not enough supplied utxos; refresh from the wallet.
        utxos.clear()
        utxos.extend(node.listunspent())
    fee = 100 * node.getnetworkinfo()["relayfee"]
    create_lots_of_big_transactions(node, large_txouts, utxos, num, fee=fee)
    node.generate(1)