[tests] remove direct testing on JSONRPCException from individual test cases
[bitcoinplatinum.git] / test / functional / test_framework / util.py
blob44d7e04a898096b9b1dff961f89e6cab85cce665
1 #!/usr/bin/env python3
2 # Copyright (c) 2014-2016 The Bitcoin Core developers
3 # Distributed under the MIT software license, see the accompanying
4 # file COPYING or http://www.opensource.org/licenses/mit-license.php.
5 """Helpful routines for regression testing."""
7 from base64 import b64encode
8 from binascii import hexlify, unhexlify
9 from decimal import Decimal, ROUND_DOWN
10 import hashlib
11 import json
12 import logging
13 import os
14 import random
15 import re
16 from subprocess import CalledProcessError
17 import time
19 from . import coverage
20 from .authproxy import AuthServiceProxy, JSONRPCException
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger("TestFramework.utils")
24 # Assert functions
25 ##################
def assert_fee_amount(fee, tx_size, fee_per_kB):
    """Check that `fee` matches a transaction of `tx_size` bytes at `fee_per_kB`.

    Raises AssertionError if the fee is below the target fee, or above the
    fee that a transaction two bytes larger would pay.
    """
    target_fee = tx_size * fee_per_kB / 1000
    # allow the wallet's estimation to be at most 2 bytes off
    ceiling = (tx_size + 2) * fee_per_kB / 1000
    if fee < target_fee:
        raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (fee, target_fee))
    if fee > ceiling:
        raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (fee, target_fee))
def assert_equal(thing1, thing2, *args):
    """Raise AssertionError unless `thing1` equals `thing2` and every extra argument."""
    others = (thing2,) + args
    if any(thing1 != other for other in others):
        everything = (thing1,) + others
        raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in everything))
def assert_greater_than(thing1, thing2):
    """Raise AssertionError unless `thing1` is strictly greater than `thing2`."""
    if thing1 <= thing2:
        failure = "%s <= %s" % (thing1, thing2)
        raise AssertionError(failure)
def assert_greater_than_or_equal(thing1, thing2):
    """Raise AssertionError if `thing1` is less than `thing2`."""
    if thing1 < thing2:
        failure = "%s < %s" % (thing1, thing2)
        raise AssertionError(failure)
def assert_raises(exc, fun, *args, **kwds):
    """Assert that calling `fun(*args, **kwds)` raises `exc`.

    Delegates to assert_raises_message with the message check disabled.
    """
    no_message_check = None
    assert_raises_message(exc, no_message_check, fun, *args, **kwds)
def assert_raises_message(exc, message, fun, *args, **kwds):
    """Assert that `fun(*args, **kwds)` raises `exc`, optionally checking its text.

    Args:
        exc: the exception class (or tuple of classes) expected to be raised.
        message (string), optional: [a substring of] the expected error text.
            Set to None to skip the text check.
        fun (function): the function to call.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.

    Raises:
        AssertionError: if no exception is raised, if a different exception
            type is raised, or if `message` is not found in the error text.
    """
    try:
        fun(*args, **kwds)
    except exc as e:
        if message is None:
            return
        # Exceptions that carry an `error` dict (as the JSONRPC ones used in
        # this file do) keep their text under error['message']; for any other
        # exception type fall back to str(e) instead of failing with
        # AttributeError.
        error = getattr(e, "error", None)
        if isinstance(error, dict) and "message" in error:
            text = error["message"]
        else:
            text = str(e)
        if message not in text:
            raise AssertionError("Expected substring not found:" + text)
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    else:
        raise AssertionError("No exception raised")
def assert_raises_process_error(returncode, output, fun, *args, **kwds):
    """Run a process-launching callable and check how it failed.

    Calls `fun(*args, **kwds)` and expects it to raise CalledProcessError with
    the given return code and with `output` contained in the captured output.

    Args:
        returncode (int): the process return code.
        output (string): [a substring of] the process output.
        fun (function): the function to call. This should execute a process.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.

    Raises:
        AssertionError: if no CalledProcessError was raised, or if the return
            code or output differ from what was expected.
    """
    try:
        fun(*args, **kwds)
    except CalledProcessError as err:
        failure = err
    else:
        raise AssertionError("No exception raised")

    if returncode != failure.returncode:
        raise AssertionError("Unexpected returncode %i" % failure.returncode)
    if output not in failure.output:
        raise AssertionError("Expected substring not found:" + failure.output)
def assert_raises_jsonrpc(code, message, fun, *args, **kwds):
    """Run an RPC and verify that a specific JSONRPC exception code and message is raised.

    Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
    and verifies that the error code and message are as expected. Throws AssertionError if
    no JSONRPCException was raised or if the error code/message are not as expected.

    Args:
        code (int), optional: the error code returned by the RPC call (defined
            in src/rpc/protocol.h). Set to None if checking the error code is not required.
        message (string), optional: [a substring of] the error string returned by the
            RPC call. Set to None if checking the error string is not required.
        fun (function): the function to call. This should be the name of an RPC.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    # Use an explicit raise rather than an `assert` statement: asserts are
    # removed when Python runs with -O, which would skip the RPC call itself.
    if not try_rpc(code, message, fun, *args, **kwds):
        raise AssertionError("No exception raised")
def try_rpc(code, message, fun, *args, **kwds):
    """Call `fun(*args, **kwds)` and report whether it failed with a JSONRPCException.

    When a JSONRPCException is raised, its error code and message are checked
    against `code` and `message` (each check is skipped when the value is None)
    and AssertionError is raised on a mismatch. Any other exception type is
    converted to AssertionError.

    Returns True if a JSONRPCException was raised, False if the call succeeded.
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException as rpc_err:
        # The expected failure happened; now validate its details.
        if code is not None:
            if code != rpc_err.error["code"]:
                raise AssertionError("Unexpected JSONRPC error code %i" % rpc_err.error["code"])
        if message is not None:
            if message not in rpc_err.error['message']:
                raise AssertionError("Expected substring not found:" + rpc_err.error['message'])
        return True
    except Exception as other:
        raise AssertionError("Unexpected exception raised: " + type(other).__name__)
    return False
def assert_is_hex_string(string):
    """Raise AssertionError unless `int(string, 16)` can parse `string`."""
    try:
        int(string, 16)
    except Exception as err:
        failure = "Couldn't interpret %r as hexadecimal; raised: %s" % (string, err)
        raise AssertionError(failure)
def assert_is_hash_string(string, length=64):
    """Assert that `string` looks like a lowercase hexadecimal hash.

    Args:
        string: the value to check; must be a `str`.
        length (int): the required number of characters. A falsy value
            (0 or None) disables the length check.

    Raises:
        AssertionError: if `string` is not a str, has the wrong length, or
            contains anything other than the characters 0-9 and a-f.
    """
    if not isinstance(string, str):
        raise AssertionError("Expected a string, got type %r" % type(string))
    elif length and len(string) != length:
        raise AssertionError(
            "String of length %d expected; got %d" % (length, len(string)))
    # fullmatch rather than match with a trailing '$': '$' also matches just
    # before a final newline, which would let "abc...\n" through.
    elif not re.fullmatch('[abcdef0-9]+', string):
        raise AssertionError(
            "String %r contains invalid characters for a hash." % string)
def assert_array_result(object_array, to_match, expected, should_not_find=False):
    """Check the entries of `object_array` selected by `to_match`.

    `object_array` is an array of JSON objects (dicts). Every entry whose
    values equal all key/value pairs in `to_match` must also carry the
    key/value pairs in `expected`; otherwise AssertionError is raised.

    If `should_not_find` is true, `expected` must be empty and no entry may
    match `to_match`.

    Note that matches are counted once per checked `expected` pair, so with
    `should_not_find` false and an empty `expected` the call reports
    "No objects matched" even when entries match `to_match`.
    """
    if should_not_find:
        assert_equal(expected, {})
    hits = 0
    for entry in object_array:
        # Compare every to_match key (no early exit), as a missing key should
        # surface as KeyError regardless of earlier mismatches.
        mismatched_keys = [key for key, value in to_match.items() if entry[key] != value]
        if mismatched_keys:
            continue
        if should_not_find:
            hits += 1
        for key, value in expected.items():
            if entry[key] != value:
                raise AssertionError("%s : expected %s=%s" % (str(entry), str(key), str(value)))
            hits += 1
    if not should_not_find and hits == 0:
        raise AssertionError("No objects matched %s" % (str(to_match)))
    if should_not_find and hits > 0:
        raise AssertionError("Objects were found %s" % (str(to_match)))
169 # Utility functions
170 ###################
def check_json_precision():
    """Make sure json library being used does not lose precision converting BTC values.

    Round-trips a large BTC amount through json.dumps/json.loads as a float
    and raises RuntimeError if the satoshi value does not survive.
    """
    amount = Decimal("20000000.00000003")
    round_tripped = json.loads(json.dumps(float(amount)))
    satoshis = int(round_tripped * 1.0e8)
    if satoshis != 2000000000000003:
        raise RuntimeError("JSON encode/decode loses precision")
def count_bytes(hex_string):
    """Return the number of bytes encoded by the hexadecimal `hex_string`."""
    decoded = bytes.fromhex(hex_string)
    return len(decoded)
def bytes_to_hex_str(byte_str):
    """Return the lowercase hexadecimal text form of `byte_str`."""
    hex_bytes = hexlify(byte_str)
    return hex_bytes.decode('ascii')
def hash256(byte_str):
    """Return the double-SHA256 digest of `byte_str` with its byte order reversed."""
    first_pass = hashlib.sha256(byte_str).digest()
    second_pass = hashlib.sha256(first_pass).digest()
    return second_pass[::-1]
def hex_str_to_bytes(hex_str):
    """Return the bytes encoded by the hexadecimal text `hex_str`."""
    ascii_hex = hex_str.encode('ascii')
    return unhexlify(ascii_hex)
def str_to_b64str(string):
    """Return the base64 text of `string` after encoding it as UTF-8."""
    raw = string.encode('utf-8')
    encoded = b64encode(raw)
    return encoded.decode('ascii')
198 def satoshi_round(amount):
199 return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
def wait_until(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None):
    """Poll `predicate` until it returns a true value.

    Args:
        predicate (function): called with no arguments on every poll.
        attempts: maximum number of polls. Unlimited by default.
        timeout: maximum number of seconds to keep polling. If neither
            `attempts` nor `timeout` is given, a 60 second timeout is used.
        lock: optional context manager held while `predicate` is evaluated.

    Returns None as soon as the predicate succeeds. If the attempts or the
    time run out, the closing assertions raise AssertionError describing
    which limit was hit.
    """
    unbounded = float('inf')
    if attempts == unbounded and timeout == unbounded:
        timeout = 60
    deadline = timeout + time.time()
    tries = 0

    while tries < attempts and time.time() < deadline:
        if lock:
            with lock:
                satisfied = bool(predicate())
        else:
            satisfied = bool(predicate())
        if satisfied:
            return
        tries += 1
        time.sleep(0.05)

    # Print the cause of the timeout
    assert_greater_than(attempts, tries)
    assert_greater_than(deadline, time.time())
    raise RuntimeError('Unreachable')
223 # RPC/P2P connection constants and functions
224 ############################################
# The maximum number of nodes a single test can spawn
MAX_NODES = 8
# Don't assign rpc or p2p ports lower than this
PORT_MIN = 11000
# The number of ports to "reserve" for p2p and rpc, each
# (rpc ports are offset from p2p ports by this amount)
PORT_RANGE = 5000
class PortSeed:
    """Holds the per-process seed that the port helpers mix into port numbers."""

    # Must be initialized with a unique integer for each process
    n = None
def get_rpc_proxy(url, node_number, timeout=None, coveragedir=None):
    """Build an RPC proxy for one node, wrapped for coverage logging.

    Args:
        url (str): URL of the RPC server to call
        node_number (int): the node number (or id) that this calls to

    Kwargs:
        timeout (int): HTTP timeout in seconds
        coveragedir: directory passed to coverage.get_filename to pick the
            coverage log file; no log file is used when this is falsy

    Returns:
        AuthServiceProxy. convenience object for making RPC calls.
    """
    # Only forward the timeout when the caller actually supplied one.
    proxy_kwargs = {} if timeout is None else {'timeout': timeout}

    proxy = AuthServiceProxy(url, **proxy_kwargs)
    proxy.url = url  # store URL on proxy for info

    if coveragedir:
        coverage_logfile = coverage.get_filename(coveragedir, node_number)
    else:
        coverage_logfile = None

    return coverage.AuthServiceProxyWrapper(proxy, coverage_logfile)
def p2p_port(n):
    """Return the p2p port for node `n`, derived from PORT_MIN and PortSeed.n."""
    assert(n <= MAX_NODES)
    seed_offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + n + seed_offset
def rpc_port(n):
    """Return the rpc port for node `n`: the p2p formula shifted up by PORT_RANGE."""
    seed_offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + PORT_RANGE + n + seed_offset
def rpc_url(datadir, i, rpchost=None):
    """Return the authenticated RPC URL for node `i`.

    Credentials come from get_auth_cookie(datadir). The host defaults to
    127.0.0.1 and the port to rpc_port(i); `rpchost` overrides the host, and
    also the port when it has the form "host:port".
    """
    rpc_u, rpc_p = get_auth_cookie(datadir)
    host, port = '127.0.0.1', rpc_port(i)
    if rpchost:
        pieces = rpchost.split(':')
        if len(pieces) == 2:
            host, port = pieces
        else:
            host = rpchost
    return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port))
281 # Node functions
282 ################
def initialize_datadir(dirname, n):
    """Create the data directory for node `n` under `dirname` and write its bitcoin.conf.

    Returns the path of the node's data directory.
    """
    datadir = os.path.join(dirname, "node" + str(n))
    if not os.path.isdir(datadir):
        os.makedirs(datadir)
    conf_path = os.path.join(datadir, "bitcoin.conf")
    with open(conf_path, 'w', encoding='utf8') as conf:
        conf.write("regtest=1\n")
        conf.write("port=%s\n" % str(p2p_port(n)))
        conf.write("rpcport=%s\n" % str(rpc_port(n)))
        conf.write("listenonion=0\n")
    return datadir
def get_datadir_path(dirname, n):
    """Return the data directory path for node `n` under `dirname`."""
    node_dir = "node" + str(n)
    return os.path.join(dirname, node_dir)
def get_auth_cookie(datadir):
    """Return the RPC credentials for the node using `datadir`.

    Reads `rpcuser=` / `rpcpassword=` lines from `<datadir>/bitcoin.conf` if
    that file exists, then `<datadir>/regtest/.cookie` if that file exists.
    Values from the cookie file take precedence over the config file.

    Args:
        datadir (str): path of the node's data directory.

    Returns:
        tuple: (user, password).

    Raises:
        AssertionError: if bitcoin.conf holds more than one rpcuser or
            rpcpassword line.
        ValueError: if no complete set of credentials was found.
    """
    user = None
    password = None
    conf_path = os.path.join(datadir, "bitcoin.conf")
    if os.path.isfile(conf_path):
        with open(conf_path, 'r', encoding='utf8') as f:
            for line in f:
                if line.startswith("rpcuser="):
                    if user is not None:  # Ensure that there is only one rpcuser line
                        raise AssertionError("Multiple rpcuser lines in bitcoin.conf")
                    # Split on the first '=' only so values containing '=' survive.
                    user = line.split("=", 1)[1].strip("\n")
                if line.startswith("rpcpassword="):
                    if password is not None:  # Ensure that there is only one rpcpassword line
                        raise AssertionError("Multiple rpcpassword lines in bitcoin.conf")
                    password = line.split("=", 1)[1].strip("\n")
    cookie_path = os.path.join(datadir, "regtest", ".cookie")
    if os.path.isfile(cookie_path):
        with open(cookie_path, 'r', encoding='utf8') as f:
            userpass = f.read()
            # Split on the first ':' only so a password containing ':' survives.
            split_userpass = userpass.split(':', 1)
            user = split_userpass[0]
            password = split_userpass[1]
    if user is None or password is None:
        raise ValueError("No RPC credentials")
    return user, password
def log_filename(dirname, n_node, logname):
    """Return the path of `logname` inside node `n_node`'s regtest directory."""
    node_dir = "node" + str(n_node)
    return os.path.join(dirname, node_dir, "regtest", logname)
def get_bip9_status(node, key):
    """Return the `bip9_softforks` entry named `key` from the node's getblockchaininfo."""
    softforks = node.getblockchaininfo()['bip9_softforks']
    return softforks[key]
def set_node_times(nodes, t):
    """Call setmocktime(t) on every node in `nodes`, in order."""
    for each_node in nodes:
        each_node.setmocktime(t)
def disconnect_nodes(from_connection, node_num):
    """Disconnect `from_connection` from every peer identifying as test node `node_num`.

    Peers are selected by looking for "testnode<node_num>" in their `subver`.
    After requesting the disconnects, polls up to 50 times (0.1s apart) until
    no such peer remains; raises AssertionError if they are still connected.
    """
    def matching_peer_ids():
        return [peer['id'] for peer in from_connection.getpeerinfo()
                if "testnode%d" % node_num in peer['subver']]

    for peer_id in matching_peer_ids():
        from_connection.disconnectnode(nodeid=peer_id)

    for _ in range(50):
        if not matching_peer_ids():
            break
        time.sleep(0.1)
    else:
        raise AssertionError("timed out waiting for disconnect")
def connect_nodes(from_connection, node_num):
    """Ask `from_connection` to connect once to node `node_num` on 127.0.0.1.

    Blocks until no peer of `from_connection` reports version 0.
    """
    target = "127.0.0.1:" + str(p2p_port(node_num))
    from_connection.addnode(target, "onetry")
    # poll until version handshake complete to avoid race conditions
    # with transaction relaying
    while True:
        handshaking = any(peer['version'] == 0 for peer in from_connection.getpeerinfo())
        if not handshaking:
            break
        time.sleep(0.1)
def connect_nodes_bi(nodes, a, b):
    """Connect nodes[a] to node b, then nodes[b] to node a."""
    for source, target in ((a, b), (b, a)):
        connect_nodes(nodes[source], target)
def sync_blocks(rpc_connections, *, wait=1, timeout=60):
    """Wait until everybody has the same tip.

    sync_blocks needs to be called with an rpc_connections set that has least
    one node already synced to the latest, stable tip, otherwise there's a
    chance it might return before all nodes are stably synced.

    `wait` is the per-node waitforblockheight timeout in seconds and `timeout`
    the overall limit in seconds. Raises AssertionError if the nodes reach the
    target height with different hashes, or if the overall limit passes.
    """
    # Use getblockcount() instead of waitforblockheight() to determine the
    # initial max height because the two RPCs look at different internal global
    # variables (chainActive vs latestBlock) and the former gets updated
    # earlier.
    maxheight = max(conn.getblockcount() for conn in rpc_connections)
    start_time = now = time.time()
    while now <= start_time + timeout:
        tips = [conn.waitforblockheight(maxheight, int(wait * 1000)) for conn in rpc_connections]
        heights_agree = all(tip["height"] == maxheight for tip in tips)
        if heights_agree:
            hashes_agree = all(tip["hash"] == tips[0]["hash"] for tip in tips)
            if not hashes_agree:
                raise AssertionError("Block sync failed, mismatched block hashes:{}".format(
                    "".join("\n {!r}".format(tip) for tip in tips)))
            return
        now = time.time()
    raise AssertionError("Block sync to height {} timed out:{}".format(
        maxheight, "".join("\n {!r}".format(tip) for tip in tips)))
def sync_chain(rpc_connections, *, wait=1, timeout=60):
    """Wait until everybody has the same best block.

    Polls getbestblockhash on every connection, sleeping `wait` seconds
    between polls, and raises AssertionError once `timeout` is used up.
    """
    remaining = timeout
    while remaining > 0:
        hashes = [conn.getbestblockhash() for conn in rpc_connections]
        reference = hashes[0]
        if all(block_hash == reference for block_hash in hashes):
            return
        time.sleep(wait)
        remaining -= wait
    raise AssertionError("Chain sync failed: Best block hashes don't match")
def sync_mempools(rpc_connections, *, wait=1, timeout=60):
    """Wait until everybody has the same transactions in their memory pools.

    Compares each connection's getrawmempool (as a set) with that of the
    first connection, sleeping `wait` seconds between polls, and raises
    AssertionError once `timeout` is used up.
    """
    remaining = timeout
    while remaining > 0:
        reference = set(rpc_connections[0].getrawmempool())
        # The first connection trivially matches itself.
        in_sync = 1 + sum(
            1 for idx in range(1, len(rpc_connections))
            if set(rpc_connections[idx].getrawmempool()) == reference)
        if in_sync == len(rpc_connections):
            return
        time.sleep(wait)
        remaining -= wait
    raise AssertionError("Mempool sync failed")
408 # Transaction/Block functions
409 #############################
def find_output(node, txid, amount):
    """Return index to output of txid with value amount.

    The transaction is fetched with node.getrawtransaction(txid, 1) and the
    first output whose "value" equals `amount` wins.
    Raises RuntimeError if there is none.
    """
    txdata = node.getrawtransaction(txid, 1)
    for index, txout in enumerate(txdata["vout"]):
        if txout["value"] == amount:
            return index
    raise RuntimeError("find_output txid %s : %s not found" % (txid, str(amount)))
def gather_inputs(from_node, amount_needed, confirmations_required=1):
    """Return a random set of unspent txouts that are enough to pay amount_needed.

    Args:
        from_node: node whose `listunspent(confirmations_required)` supplies
            the candidate outputs.
        amount_needed: the total amount the selected outputs must reach.
        confirmations_required (int): value passed to listunspent; must not
            be negative.

    Returns:
        tuple: (total_in, inputs) where `total_in` is the summed amount of the
        selected outputs and `inputs` is a list of dicts with the keys
        "txid", "vout" and "address".

    Raises:
        AssertionError: if `confirmations_required` is negative.
        RuntimeError: if the unspent outputs do not add up to `amount_needed`.
    """
    if not confirmations_required >= 0:
        raise AssertionError("confirmations_required must not be negative")
    utxo = from_node.listunspent(confirmations_required)
    random.shuffle(utxo)
    inputs = []
    total_in = Decimal("0.00000000")
    while total_in < amount_needed and utxo:
        t = utxo.pop()
        total_in += t["amount"]
        inputs.append({"txid": t["txid"], "vout": t["vout"], "address": t["address"]})
    if total_in < amount_needed:
        # %s rather than %d: the amounts are fractional and %d would
        # truncate them to whole coins in the message.
        raise RuntimeError("Insufficient funds: need %s, have %s" % (amount_needed, total_in))
    return (total_in, inputs)
def make_change(from_node, amount_in, amount_out, fee):
    """Create change output(s), return them.

    Returns a dict mapping fresh addresses from `from_node` to change
    amounts for spending `amount_out` plus `fee` out of `amount_in`. The dict
    is empty when nothing is left over.
    """
    spent = amount_out + fee
    change = amount_in - spent
    outputs = {}
    if change > spent * 2:
        # Create an extra change output to break up big inputs
        extra_address = from_node.getnewaddress()
        # Split change in two, being careful of rounding:
        first_half = Decimal(change / 2).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)
        outputs[extra_address] = first_half
        change = amount_in - spent - first_half
    if change > 0:
        outputs[from_node.getnewaddress()] = change
    return outputs
def random_transaction(nodes, amount, min_fee, fee_increment, fee_variants):
    """Create a random transaction.

    Picks a random sender and a random receiver from `nodes` and a fee of
    `min_fee` plus a random multiple (0..fee_variants) of `fee_increment`,
    then builds, signs and broadcasts the transaction from the sender.

    Returns (txid, hex-encoded-transaction-data, fee)
    """
    sender = random.choice(nodes)
    receiver = random.choice(nodes)
    fee = min_fee + fee_increment * random.randint(0, fee_variants)

    total_in, inputs = gather_inputs(sender, amount + fee)
    outputs = make_change(sender, total_in, amount, fee)
    outputs[receiver.getnewaddress()] = float(amount)

    unsigned_hex = sender.createrawtransaction(inputs, outputs)
    signresult = sender.signrawtransaction(unsigned_hex)
    signed_hex = signresult["hex"]
    txid = sender.sendrawtransaction(signed_hex, True)

    return (txid, signresult["hex"], fee)
475 # Helper to create at least "count" utxos
476 # Pass in a fee that is sufficient for relay and mining new transactions.
def create_confirmed_utxos(fee, node, count):
    """Mine blocks and split outputs until `node` lists at least `count` utxos.

    `fee` is deducted from every splitting transaction. Returns the result
    of node.listunspent().
    """
    blocks_left = int(0.5 * count) + 101
    # Mine in batches of at most 25 blocks.
    while blocks_left > 0:
        node.generate(min(25, blocks_left))
        blocks_left -= 25
    utxos = node.listunspent()
    iterations = count - len(utxos)
    addr1 = node.getnewaddress()
    addr2 = node.getnewaddress()
    if iterations <= 0:
        return utxos
    for _ in range(iterations):
        t = utxos.pop()
        inputs = [{"txid": t["txid"], "vout": t["vout"]}]
        send_value = t['amount'] - fee
        # Each spent output becomes two halves, one per address.
        outputs = {}
        outputs[addr1] = satoshi_round(send_value / 2)
        outputs[addr2] = satoshi_round(send_value / 2)
        raw_tx = node.createrawtransaction(inputs, outputs)
        signed_tx = node.signrawtransaction(raw_tx)["hex"]
        node.sendrawtransaction(signed_tx)

    # Keep mining single blocks until the mempool is empty.
    while node.getmempoolinfo()['size'] > 0:
        node.generate(1)

    utxos = node.listunspent()
    assert(len(utxos) >= count)
    return utxos
507 # Create large OP_RETURN txouts that can be appended to a transaction
508 # to make it large (helper for constructing large transactions).
def gen_return_txouts():
    """Return the hex text of an output count followed by 128 large OP_RETURN txouts.

    Used to pad transactions so that they are big (and therefore only a few
    of them fit into each block).
    """
    # One script_pubkey: OP_RETURN OP_PUSH2 512 bytes
    script_pubkey = "6a4d0200" + "01" * 512
    # One txout: value, then length of script_pubkey, then script_pubkey
    one_txout = "0000000000000000" + "fd0402" + script_pubkey
    # 128 of those txouts, to be inserted before the txout for change
    return "81" + one_txout * 128
def create_tx(node, coinbase, to_address, amount):
    """Return the signed hex of a transaction spending output 0 of `coinbase`.

    The whole `amount` goes to `to_address`; signing must complete.
    """
    spend_from = [{"txid": coinbase, "vout": 0}]
    pay_to = {to_address: amount}
    unsigned_hex = node.createrawtransaction(spend_from, pay_to)
    signresult = node.signrawtransaction(unsigned_hex)
    assert_equal(signresult["complete"], True)
    return signresult["hex"]
535 # Create a spend of each passed-in utxo, splicing in "txouts" to each raw
536 # transaction to make it large. See gen_return_txouts() above.
def create_lots_of_big_transactions(node, txouts, utxos, num, fee):
    """Broadcast `num` transactions, each spending one utxo and padded with `txouts`.

    Each transaction pops one entry from `utxos`, sends its amount minus
    `fee` to a single fresh address, and has `txouts` spliced into the raw
    hex. Returns the list of txids.
    """
    addr = node.getnewaddress()
    txids = []
    for _ in range(num):
        t = utxos.pop()
        inputs = [{"txid": t["txid"], "vout": t["vout"]}]
        outputs = {addr: satoshi_round(t['amount'] - fee)}
        rawtx = node.createrawtransaction(inputs, outputs)
        # Replace hex chars 92-93 with txouts, which begins with its own
        # count byte (presumably the original output count -- TODO confirm).
        newtx = rawtx[0:92] + txouts + rawtx[94:]
        signresult = node.signrawtransaction(newtx, None, None, "NONE")
        txids.append(node.sendrawtransaction(signresult["hex"], True))
    return txids
def mine_large_block(node, utxos=None):
    """Fill the mempool with 14 large transactions and mine one block.

    `utxos` is an optional list of spendable outputs; when it holds fewer
    than 14 entries it is refilled in place from node.listunspent().
    """
    # generate a 66k transaction,
    # and 14 of them is close to the 1MB block limit
    num = 14
    txouts = gen_return_txouts()
    if utxos is None:
        utxos = []
    if len(utxos) < num:
        utxos.clear()
        utxos.extend(node.listunspent())
    relay_fee = node.getnetworkinfo()["relayfee"]
    fee = 100 * relay_fee
    create_lots_of_big_transactions(node, txouts, utxos, num, fee=fee)
    node.generate(1)