1 """Unittests for heapq."""
5 from test
import test_support
# We do a bit of trickery here to be able to test both the C implementation
# and the Python implementation of the module.
10 import heapq
as c_heapq
11 py_heapq
= test_support
.import_fresh_module('heapq', blocked
=['_heapq'])
13 class TestHeap(unittest
.TestCase
):
16 def test_push_pop(self
):
17 # 1) Push 256 random numbers and pop them off, verifying all's OK.
20 self
.check_invariant(heap
)
22 item
= random
.random()
24 self
.module
.heappush(heap
, item
)
25 self
.check_invariant(heap
)
28 item
= self
.module
.heappop(heap
)
29 self
.check_invariant(heap
)
33 self
.assertEqual(data_sorted
, results
)
34 # 2) Check that the invariant holds for a sorted array
35 self
.check_invariant(results
)
37 self
.assertRaises(TypeError, self
.module
.heappush
, [])
39 self
.assertRaises(TypeError, self
.module
.heappush
, None, None)
40 self
.assertRaises(TypeError, self
.module
.heappop
, None)
41 except AttributeError:
def check_invariant(self, heap):
    # Check the min-heap invariant: every element except the root must be
    # greater than or equal to its parent.  For a 0-based array layout the
    # parent of index ``pos`` lives at ``(pos - 1) >> 1``.
    #
    # ``heap`` is any indexable sequence; failures are reported through
    # ``self.assertTrue``.
    for pos, item in enumerate(heap):
        if pos:  # pos 0 has no parent
            parentpos = (pos - 1) >> 1
            self.assertTrue(heap[parentpos] <= item)
def test_heapify(self):
    # heapify() must establish the heap invariant on random lists of
    # every size from 0 through 29.
    for size in range(30):
        heap = [random.random() for dummy in range(size)]
        self.module.heapify(heap)
        self.check_invariant(heap)

    # A non-list argument must be rejected with TypeError.
    self.assertRaises(TypeError, self.module.heapify, None)
59 def test_naive_nbest(self
):
60 data
= [random
.randrange(2000) for i
in range(1000)]
63 self
.module
.heappush(heap
, item
)
65 self
.module
.heappop(heap
)
67 self
.assertEqual(heap
, sorted(data
)[-10:])
69 def heapiter(self
, heap
):
70 # An iterator returning a heap's elements, smallest-first.
73 yield self
.module
.heappop(heap
)
78 # Less-naive "N-best" algorithm, much faster (if len(data) is big
79 # enough <wink>) than sorting all of data. However, if we had a max
80 # heap instead of a min heap, it could go faster still via
81 # heapify'ing all of data (linear time), then doing 10 heappops
82 # (10 log-time steps).
83 data
= [random
.randrange(2000) for i
in range(1000)]
85 self
.module
.heapify(heap
)
86 for item
in data
[10:]:
87 if item
> heap
[0]: # this gets rarer the longer we run
88 self
.module
.heapreplace(heap
, item
)
89 self
.assertEqual(list(self
.heapiter(heap
)), sorted(data
)[-10:])
91 self
.assertRaises(TypeError, self
.module
.heapreplace
, None)
92 self
.assertRaises(TypeError, self
.module
.heapreplace
, None, None)
93 self
.assertRaises(IndexError, self
.module
.heapreplace
, [], None)
95 def test_nbest_with_pushpop(self
):
96 data
= [random
.randrange(2000) for i
in range(1000)]
98 self
.module
.heapify(heap
)
99 for item
in data
[10:]:
100 self
.module
.heappushpop(heap
, item
)
101 self
.assertEqual(list(self
.heapiter(heap
)), sorted(data
)[-10:])
102 self
.assertEqual(self
.module
.heappushpop([], 'x'), 'x')
104 def test_heappushpop(self
):
106 x
= self
.module
.heappushpop(h
, 10)
107 self
.assertEqual((h
, x
), ([], 10))
110 x
= self
.module
.heappushpop(h
, 10.0)
111 self
.assertEqual((h
, x
), ([10], 10.0))
112 self
.assertEqual(type(h
[0]), int)
113 self
.assertEqual(type(x
), float)
116 x
= self
.module
.heappushpop(h
, 9)
117 self
.assertEqual((h
, x
), ([10], 9))
120 x
= self
.module
.heappushpop(h
, 11)
121 self
.assertEqual((h
, x
), ([11], 10))
123 def test_heapsort(self
):
124 # Exercise everything with repeated heapsort checks
125 for trial
in xrange(100):
126 size
= random
.randrange(50)
127 data
= [random
.randrange(25) for i
in range(size
)]
128 if trial
& 1: # Half of the time, use heapify
130 self
.module
.heapify(heap
)
131 else: # The rest of the time, use heappush
134 self
.module
.heappush(heap
, item
)
135 heap_sorted
= [self
.module
.heappop(heap
) for i
in range(size
)]
136 self
.assertEqual(heap_sorted
, sorted(data
))
138 def test_merge(self
):
140 for i
in xrange(random
.randrange(5)):
141 row
= sorted(random
.randrange(1000) for j
in range(random
.randrange(10)))
143 self
.assertEqual(sorted(chain(*inputs
)), list(self
.module
.merge(*inputs
)))
144 self
.assertEqual(list(self
.module
.merge()), [])
146 def test_merge_stability(self
):
149 inputs
= [[], [], [], []]
150 for i
in range(20000):
151 stream
= random
.randrange(4)
152 x
= random
.randrange(500)
154 obj
.pair
= (x
, stream
)
155 inputs
[stream
].append(obj
)
156 for stream
in inputs
:
158 result
= [i
.pair
for i
in self
.module
.merge(*inputs
)]
159 self
.assertEqual(result
, sorted(result
))
def test_nsmallest(self):
    # Compare nsmallest() against a sorted-slice reference, with and
    # without a key function.  Each datum is paired with its index so
    # that all elements are distinct and the expected order is
    # unambiguous.
    data = [(random.randrange(2000), i) for i in range(1000)]
    for f in (None, lambda x: x[0] * 547 % 2000):
        # n covers zero, small values, values near len(data), and a value
        # larger than len(data).
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertEqual(self.module.nsmallest(n, data),
                             sorted(data)[:n])
            self.assertEqual(self.module.nsmallest(n, data, key=f),
                             sorted(data, key=f)[:n])
def test_nlargest(self):
    # Compare nlargest() against a reverse-sorted-slice reference, with
    # and without a key function.  Each datum is paired with its index so
    # that all elements are distinct and the expected order is
    # unambiguous.
    data = [(random.randrange(2000), i) for i in range(1000)]
    for f in (None, lambda x: x[0] * 547 % 2000):
        # n covers zero, small values, values near len(data), and a value
        # larger than len(data).
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertEqual(self.module.nlargest(n, data),
                             sorted(data, reverse=True)[:n])
            self.assertEqual(self.module.nlargest(n, data, key=f),
                             sorted(data, key=f, reverse=True)[:n])
178 class TestHeapPython(TestHeap
):
# As an early adopter, we sanity check the
# test_support.import_fresh_module utility function
def test_pure_python(self):
    # The module under test must not be the object registered in
    # sys.modules under 'heapq', and its heapify must expose a
    # 'func_code' attribute.
    self.assertFalse(sys.modules['heapq'] is self.module)
    self.assertTrue(hasattr(self.module.heapify, 'func_code'))
188 class TestHeapC(TestHeap
):
191 def test_comparison_operator(self
):
192 # Issue 3501: Make sure heapq works with both __lt__ and __le__
193 def hsort(data
, comp
):
194 data
= map(comp
, data
)
195 self
.module
.heapify(data
)
196 return [self
.module
.heappop(data
).x
for i
in range(len(data
))]
198 def __init__(self
, x
):
200 def __lt__(self
, other
):
201 return self
.x
> other
.x
203 def __init__(self
, x
):
205 def __le__(self
, other
):
206 return self
.x
>= other
.x
207 data
= [random
.random() for i
in range(100)]
208 target
= sorted(data
, reverse
=True)
209 self
.assertEqual(hsort(data
, LT
), target
)
210 self
.assertEqual(hsort(data
, LE
), target
)
# As an early adopter, we sanity check the
# test_support.import_fresh_module utility function
def test_accelerated(self):
    # The module under test must be the very object registered in
    # sys.modules under 'heapq', and its heapify must not expose a
    # 'func_code' attribute.
    self.assertTrue(sys.modules['heapq'] is self.module)
    self.assertFalse(hasattr(self.module.heapify, 'func_code'))
219 #==============================================================================
222 "Dummy sequence class defining __len__ but not __getitem__."
227 "Dummy sequence class defining __getitem__ but not __len__."
228 def __getitem__(self
, ndx
):
232 "Dummy element that always raises an error during comparison"
233 def __cmp__(self
, other
):
234 raise ZeroDivisionError
242 'Sequence using __getitem__'
243 def __init__(self
, seqn
):
245 def __getitem__(self
, i
):
249 'Sequence using iterator protocol'
250 def __init__(self
, seqn
):
256 if self
.i
>= len(self
.seqn
): raise StopIteration
257 v
= self
.seqn
[self
.i
]
262 'Sequence using iterator protocol defined with a generator'
263 def __init__(self
, seqn
):
267 for val
in self
.seqn
:
271 'Missing __getitem__ and __iter__'
272 def __init__(self
, seqn
):
276 if self
.i
>= len(self
.seqn
): raise StopIteration
277 v
= self
.seqn
[self
.i
]
282 'Iterator missing next()'
283 def __init__(self
, seqn
):
290 'Test propagation of exceptions'
291 def __init__(self
, seqn
):
300 'Test immediate stop'
301 def __init__(self
, seqn
):
308 from itertools
import chain
, imap
310 'Test multiple tiers of iterators'
311 return chain(imap(lambda x
:x
, R(Ig(G(seqn
)))))
313 class TestErrorHandling(unittest
.TestCase
):
314 # only for C implementation
def test_non_sequence(self):
    # Passing an int where a heap/iterable is expected must raise
    # TypeError.  One-argument functions first ...
    for f in (self.module.heapify, self.module.heappop):
        self.assertRaises(TypeError, f, 10)
    # ... then the two-argument functions.
    for f in (self.module.heappush, self.module.heapreplace,
              self.module.nlargest, self.module.nsmallest):
        self.assertRaises(TypeError, f, 10, 10)
def test_len_only(self):
    # A LenOnly() instance in the heap/iterable position must be
    # rejected with TypeError by every function.
    for f in (self.module.heapify, self.module.heappop):
        self.assertRaises(TypeError, f, LenOnly())
    for f in (self.module.heappush, self.module.heapreplace):
        self.assertRaises(TypeError, f, LenOnly(), 10)
    # nlargest/nsmallest take the count first and the iterable second.
    for f in (self.module.nlargest, self.module.nsmallest):
        self.assertRaises(TypeError, f, 2, LenOnly())
def test_get_only(self):
    # A GetOnly() instance in the heap/iterable position must be
    # rejected with TypeError by every function.
    for f in (self.module.heapify, self.module.heappop):
        self.assertRaises(TypeError, f, GetOnly())
    for f in (self.module.heappush, self.module.heapreplace):
        self.assertRaises(TypeError, f, GetOnly(), 10)
    # nlargest/nsmallest take the count first and the iterable second.
    for f in (self.module.nlargest, self.module.nsmallest):
        self.assertRaises(TypeError, f, 2, GetOnly())
def test_cmp_err(self):
    # NOTE: this method was previously also named test_get_only, which
    # shadowed the GetOnly test defined just before it in the same class
    # so that test never ran.  It is named after what it actually
    # exercises: CmpErr elements.
    #
    # Every function is expected to raise ZeroDivisionError when handed
    # a list of CmpErr() instances.
    seq = [CmpErr(), CmpErr(), CmpErr()]
    for f in (self.module.heapify, self.module.heappop):
        self.assertRaises(ZeroDivisionError, f, seq)
    for f in (self.module.heappush, self.module.heapreplace):
        self.assertRaises(ZeroDivisionError, f, seq, 10)
    # nlargest/nsmallest take the count first and the iterable second.
    for f in (self.module.nlargest, self.module.nsmallest):
        self.assertRaises(ZeroDivisionError, f, 2, seq)
def test_arg_parsing(self):
    # Calling any of these functions with the single argument 10 must
    # raise TypeError.
    for f in (self.module.heapify, self.module.heappop,
              self.module.heappush, self.module.heapreplace,
              self.module.nlargest, self.module.nsmallest):
        self.assertRaises(TypeError, f, 10)
def test_iterable_args(self):
    # nlargest()/nsmallest() must give the same answer for a sequence
    # and for that sequence wrapped in each of the helper iterable
    # classes, and must raise the expected errors for the deliberately
    # broken wrappers.
    for f in (self.module.nlargest, self.module.nsmallest):
        for s in ("123", "", range(1000), ('do', 1.2),
                  xrange(2000, 2200, 5)):
            for g in (G, I, Ig, L, R):
                # ('do', 1.2) mixes types; silence the py3k comparison
                # warning instead of letting it fail the run.
                with test_support.check_py3k_warnings(
                        ("comparing unequal types not supported",
                         DeprecationWarning), quiet=True):
                    self.assertEqual(f(2, g(s)), f(2, s))
            # NOTE(review): indentation was lost in this copy; these
            # checks use ``s`` but not ``g``, so they are placed once per
            # sequence -- confirm against the canonical file.
            self.assertEqual(f(2, S(s)), [])
            self.assertRaises(TypeError, f, 2, X(s))
            self.assertRaises(TypeError, f, 2, N(s))
            self.assertRaises(ZeroDivisionError, f, 2, E(s))
369 #==============================================================================
372 def test_main(verbose
=None):
373 test_classes
= [TestHeapPython
, TestHeapC
, TestErrorHandling
]
374 test_support
.run_unittest(*test_classes
)
376 # verify reference counting
377 if verbose
and hasattr(sys
, "gettotalrefcount"):
380 for i
in xrange(len(counts
)):
381 test_support
.run_unittest(*test_classes
)
383 counts
[i
] = sys
.gettotalrefcount()
# Run the whole suite, with verbose=True, when executed as a script.
if __name__ == "__main__":
    test_main(verbose=True)