move sections
[python/dscho.git] / Lib / test / test_itertools.py
blob e91ac28caafff51be946244de0ddf21e17525b2c
1 import unittest
2 from test import test_support
3 from itertools import *
4 from weakref import proxy
5 from decimal import Decimal
6 from fractions import Fraction
7 import sys
8 import operator
9 import random
10 import copy
11 import pickle
12 from functools import reduce
# Largest Py_ssize_t value (from test_support) and its two's-complement
# counterpart; maxsize is used below to probe count() and islice() limits.
maxsize = test_support.MAX_Py_ssize_t
minsize = -(maxsize + 1)
def onearg(x):
    """Test function of one argument: return 2*x."""
    doubled = 2 * x
    return doubled
def errfunc(*args):
    """Test function that accepts any positional arguments and always raises ValueError."""
    raise ValueError()
def gen3():
    """Non-restartable source sequence: a generator yielding 0, 1, 2 exactly once."""
    yield 0
    yield 1
    yield 2
def isEven(x):
    """Test predicate: True when x % 2 is 0."""
    return 0 == x % 2
def isOdd(x):
    """Test predicate: True when x % 2 is 1."""
    return 1 == x % 2
class StopNow:
    """Class emulating an empty iterable: it is its own, already exhausted, iterator."""

    def __iter__(self):
        # Acts as its own iterator.
        return self

    def next(self):
        # Python 2 iterator protocol: there is never a next item.
        raise StopIteration()
def take(n, seq):
    """Convenience function for partially consuming a long or infinite iterable.

    Return a list of at most the first n items of seq.
    """
    head = islice(seq, n)
    return list(head)
def prod(iterable):
    """Return the product of the items in iterable; 1 for an empty iterable."""
    total = 1
    for item in iterable:
        total = total * item
    return total
def fact(n):
    """Factorial: the product of 1..n (1 when n < 1)."""
    result = 1
    for k in range(1, n + 1):
        result = result * k
    return result
55 class TestBasicOps(unittest.TestCase):
def test_chain(self):
    # chain() and its documented pure-Python equivalent must agree.

    def chain_reference(*iterables):
        'Pure python version in the docs'
        for iterable in iterables:
            for item in iterable:
                yield item

    cases = [(('abc', 'def'), list('abcdef')),
             (('abc',), list('abc')),
             (('',), [])]
    for impl in (chain, chain_reference):
        for call_args, expected in cases:
            self.assertEqual(list(impl(*call_args)), expected)
        self.assertEqual(take(4, impl('abc', 'def')), list('abcd'))
        # Non-iterable arguments raise once the result is consumed.
        self.assertRaises(TypeError, list, impl(2, 3))
def test_chain_from_iterable(self):
    # chain.from_iterable() takes one iterable whose items are iterables.
    from_iter = chain.from_iterable
    self.assertEqual(list(from_iter(['abc', 'def'])), list('abcdef'))
    self.assertEqual(list(from_iter(['abc'])), list('abc'))
    self.assertEqual(list(from_iter([''])), [])
    self.assertEqual(take(4, from_iter(['abc', 'def'])), list('abcd'))
    # The TypeError for non-iterable inner items is expected on consumption.
    self.assertRaises(TypeError, list, from_iter([2, 3]))
def test_combinations(self):
    # Argument validation plus one small known-answer case.
    self.assertRaises(TypeError, combinations, 'abc')       # missing r argument
    self.assertRaises(TypeError, combinations, 'abc', 2, 1) # too many arguments
    self.assertRaises(TypeError, combinations, None)        # pool is not iterable
    self.assertRaises(ValueError, combinations, 'abc', -2)  # r is negative
    self.assertEqual(list(combinations('abc', 32)), [])     # r > n
    self.assertEqual(list(combinations(range(4), 3)),
                                       [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])

    def combinations1(iterable, r):
        'Pure python version shown in the docs'
        pool = tuple(iterable)
        n = len(pool)
        if r > n:
            return
        # Relies on Python 2 range() returning a mutable list.
        indices = range(r)
        yield tuple(pool[i] for i in indices)
        while 1:
            # Find the rightmost index that has not reached its maximum.
            for i in reversed(range(r)):
                if indices[i] != i + n - r:
                    break
            else:
                return
            indices[i] += 1
            # Indices right of i restart as consecutive values after indices[i].
            for j in range(i+1, r):
                indices[j] = indices[j-1] + 1
            yield tuple(pool[i] for i in indices)

    def combinations2(iterable, r):
        'Pure python version shown in the docs'
        pool = tuple(iterable)
        n = len(pool)
        # Combinations are exactly the permutations whose indices are sorted.
        for indices in permutations(range(n), r):
            if sorted(indices) == list(indices):
                yield tuple(pool[i] for i in indices)

    def combinations3(iterable, r):
        'Pure python version from cwr()'
        pool = tuple(iterable)
        n = len(pool)
        # Combinations are the combinations-with-replacement without repeated indices.
        for indices in combinations_with_replacement(range(n), r):
            if len(set(indices)) == r:
                yield tuple(pool[i] for i in indices)

    # Exhaustive cross-check for small n, including r one past n.
    for n in range(7):
        values = [5*x-12 for x in range(n)]
        for r in range(n+2):
            result = list(combinations(values, r))
            self.assertEqual(len(result), 0 if r>n else fact(n) // fact(r) // fact(n-r)) # right number of combs
            self.assertEqual(len(result), len(set(result)))         # no repeats
            self.assertEqual(result, sorted(result))                # lexicographic order
            for c in result:
                self.assertEqual(len(c), r)                         # r-length combinations
                self.assertEqual(len(set(c)), r)                    # no duplicate elements
                self.assertEqual(list(c), sorted(c))                # keep original ordering
                self.assertTrue(all(e in values for e in c))        # elements taken from input iterable
                self.assertEqual(list(c),
                                 [e for e in values if e in c])     # comb is a subsequence of the input iterable
            self.assertEqual(result, list(combinations1(values, r))) # matches first pure python version
            self.assertEqual(result, list(combinations2(values, r))) # matches second pure python version
            self.assertEqual(result, list(combinations3(values, r))) # matches third pure python version

    # Test implementation detail: tuple re-use
    # (discarded results share one id; retained results do not).
    self.assertEqual(len(set(map(id, combinations('abcde', 3)))), 1)
    self.assertNotEqual(len(set(map(id, list(combinations('abcde', 3))))), 1)
def test_combinations_with_replacement(self):
    # Validate combinations_with_replacement() against two pure-Python
    # references and against combinations() for small inputs.
    cwr = combinations_with_replacement
    self.assertRaises(TypeError, cwr, 'abc')       # missing r argument
    self.assertRaises(TypeError, cwr, 'abc', 2, 1) # too many arguments
    self.assertRaises(TypeError, cwr, None)        # pool is not iterable
    self.assertRaises(ValueError, cwr, 'abc', -2)  # r is negative
    self.assertEqual(list(cwr('ABC', 2)),
                     [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])

    def cwr1(iterable, r):
        'Pure python version shown in the docs'
        # number items returned:  (n+r-1)! / r! / (n-1)! when n>0
        pool = tuple(iterable)
        n = len(pool)
        if not n and r:
            return
        indices = [0] * r
        yield tuple(pool[i] for i in indices)
        while 1:
            # Rightmost index that can still be incremented.
            for i in reversed(range(r)):
                if indices[i] != n - 1:
                    break
            else:
                return
            # Increment it and make every index to its right equal to it.
            indices[i:] = [indices[i] + 1] * (r - i)
            yield tuple(pool[i] for i in indices)

    def cwr2(iterable, r):
        'Pure python version shown in the docs'
        pool = tuple(iterable)
        n = len(pool)
        # cwr results are exactly the product tuples whose indices are sorted.
        for indices in product(range(n), repeat=r):
            if sorted(indices) == list(indices):
                yield tuple(pool[i] for i in indices)

    def numcombs(n, r):
        # Expected result count; an empty pool yields one empty tuple only for r == 0.
        if not n:
            return 0 if r else 1
        return fact(n+r-1) // fact(r) // fact(n-1)

    for n in range(7):
        values = [5*x-12 for x in range(n)]
        for r in range(n+2):
            result = list(cwr(values, r))

            self.assertEqual(len(result), numcombs(n, r))           # right number of combs
            self.assertEqual(len(result), len(set(result)))         # no repeats
            self.assertEqual(result, sorted(result))                # lexicographic order

            regular_combs = list(combinations(values, r))           # compare to combs without replacement
            if n == 0 or r <= 1:
                self.assertEqual(result, regular_combs)             # cases that should be identical
            else:
                self.assertTrue(set(result) >= set(regular_combs))  # rest should be supersets of regular combs

            for c in result:
                self.assertEqual(len(c), r)                         # r-length combinations
                noruns = [k for k,v in groupby(c)]                  # combo without consecutive repeats
                self.assertEqual(len(noruns), len(set(noruns)))     # no repeats other than consecutive
                self.assertEqual(list(c), sorted(c))                # keep original ordering
                self.assertTrue(all(e in values for e in c))        # elements taken from input iterable
                self.assertEqual(noruns,
                                 [e for e in values if e in c])     # comb is a subsequence of the input iterable
            self.assertEqual(result, list(cwr1(values, r)))         # matches first pure python version
            self.assertEqual(result, list(cwr2(values, r)))         # matches second pure python version

    # Test implementation detail: tuple re-use
    self.assertEqual(len(set(map(id, cwr('abcde', 3)))), 1)
    self.assertNotEqual(len(set(map(id, list(cwr('abcde', 3))))), 1)
def test_permutations(self):
    # Argument validation plus one small known-answer case.
    self.assertRaises(TypeError, permutations)              # too few arguments
    self.assertRaises(TypeError, permutations, 'abc', 2, 1) # too many arguments
    self.assertRaises(TypeError, permutations, None)        # pool is not iterable
    self.assertRaises(ValueError, permutations, 'abc', -2)  # r is negative
    self.assertEqual(list(permutations('abc', 32)), [])     # r > n
    self.assertRaises(TypeError, permutations, 'abc', 's')  # r is not an int or None
    self.assertEqual(list(permutations(range(3), 2)),
                     [(0,1), (0,2), (1,0), (1,2), (2,0), (2,1)])

    def permutations1(iterable, r=None):
        'Pure python version shown in the docs'
        pool = tuple(iterable)
        n = len(pool)
        r = n if r is None else r
        if r > n:
            return
        # Both rely on Python 2 range() returning mutable lists.
        indices = range(n)
        cycles = range(n, n-r, -1)
        yield tuple(pool[i] for i in indices[:r])
        while n:
            for i in reversed(range(r)):
                cycles[i] -= 1
                if cycles[i] == 0:
                    # Rotate position i to the end and reset its cycle counter.
                    indices[i:] = indices[i+1:] + indices[i:i+1]
                    cycles[i] = n - i
                else:
                    j = cycles[i]
                    indices[i], indices[-j] = indices[-j], indices[i]
                    yield tuple(pool[i] for i in indices[:r])
                    break
            else:
                return

    def permutations2(iterable, r=None):
        'Pure python version shown in the docs'
        pool = tuple(iterable)
        n = len(pool)
        r = n if r is None else r
        # Permutations are exactly the product tuples with no repeated index.
        for indices in product(range(n), repeat=r):
            if len(set(indices)) == r:
                yield tuple(pool[i] for i in indices)

    # Exhaustive cross-check for small n, including r one past n.
    for n in range(7):
        values = [5*x-12 for x in range(n)]
        for r in range(n+2):
            result = list(permutations(values, r))
            self.assertEqual(len(result), 0 if r>n else fact(n) // fact(n-r)) # right number of perms
            self.assertEqual(len(result), len(set(result)))         # no repeats
            self.assertEqual(result, sorted(result))                # lexicographic order
            for p in result:
                self.assertEqual(len(p), r)                         # r-length permutations
                self.assertEqual(len(set(p)), r)                    # no duplicate elements
                self.assertTrue(all(e in values for e in p))        # elements taken from input iterable
            self.assertEqual(result, list(permutations1(values, r))) # matches first pure python version
            self.assertEqual(result, list(permutations2(values, r))) # matches second pure python version
            if r == n:
                self.assertEqual(result, list(permutations(values, None))) # test r as None
                self.assertEqual(result, list(permutations(values)))       # test default r

    # Test implementation detail: tuple re-use
    self.assertEqual(len(set(map(id, permutations('abcde', 3)))), 1)
    self.assertNotEqual(len(set(map(id, list(permutations('abcde', 3))))), 1)
def test_combinatorics(self):
    # Test relationships between product(), permutations(),
    # combinations() and combinations_with_replacement().
    # (The product list is named `prods` so the module-level prod() helper
    # is not shadowed inside this method.)

    for n in range(6):
        s = 'ABCDEFG'[:n]
        for r in range(8):
            prods = list(product(s, repeat=r))
            cwr = list(combinations_with_replacement(s, r))
            perm = list(permutations(s, r))
            comb = list(combinations(s, r))

            # Check size
            self.assertEqual(len(prods), n**r)
            self.assertEqual(len(cwr), (fact(n+r-1) // fact(r) // fact(n-1)) if n else (not r))
            self.assertEqual(len(perm), 0 if r>n else fact(n) // fact(n-r))
            self.assertEqual(len(comb), 0 if r>n else fact(n) // fact(r) // fact(n-r))

            # Check lexicographic order without repeated tuples
            self.assertEqual(prods, sorted(set(prods)))
            self.assertEqual(cwr, sorted(set(cwr)))
            self.assertEqual(perm, sorted(set(perm)))
            self.assertEqual(comb, sorted(set(comb)))

            # Check interrelationships
            self.assertEqual(cwr, [t for t in prods if sorted(t)==list(t)])   # cwr: prods which are sorted
            self.assertEqual(perm, [t for t in prods if len(set(t))==r])      # perm: prods with no dups
            self.assertEqual(comb, [t for t in perm if sorted(t)==list(t)])   # comb: perms that are sorted
            self.assertEqual(comb, [t for t in cwr if len(set(t))==r])        # comb: cwrs without dups
            self.assertEqual(comb, filter(set(cwr).__contains__, perm))       # comb: perm that is a cwr
            self.assertEqual(comb, filter(set(perm).__contains__, cwr))       # comb: cwr that is a perm
            self.assertEqual(comb, sorted(set(cwr) & set(perm)))              # comb: both a cwr and a perm
311 def test_compress(self):
312 self.assertEqual(list(compress(data='ABCDEF', selectors=[1,0,1,0,1,1])), list('ACEF'))
313 self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
314 self.assertEqual(list(compress('ABCDEF', [0,0,0,0,0,0])), list(''))
315 self.assertEqual(list(compress('ABCDEF', [1,1,1,1,1,1])), list('ABCDEF'))
316 self.assertEqual(list(compress('ABCDEF', [1,0,1])), list('AC'))
317 self.assertEqual(list(compress('ABC', [0,1,1,1,1,1])), list('BC'))
318 n = 10000
319 data = chain.from_iterable(repeat(range(6), n))
320 selectors = chain.from_iterable(repeat((0, 1)))
321 self.assertEqual(list(compress(data, selectors)), [1,3,5] * n)
322 self.assertRaises(TypeError, compress, None, range(6)) # 1st arg not iterable
323 self.assertRaises(TypeError, compress, range(6), None) # 2nd arg not iterable
324 self.assertRaises(TypeError, compress, range(6)) # too few args
325 self.assertRaises(TypeError, compress, range(6), None) # too many args
def test_count(self):
    # Zipping with a finite string bounds the infinite counter.
    self.assertEqual(zip('abc', count()), [('a', 0), ('b', 1), ('c', 2)])
    self.assertEqual(zip('abc', count(3)), [('a', 3), ('b', 4), ('c', 5)])
    for start, expected in [(3, [('a', 3), ('b', 4)]),
                            (-1, [('a', -1), ('b', 0)]),
                            (-3, [('a', -3), ('b', -2)])]:
        self.assertEqual(take(2, zip('abc', count(start))), expected)

    # Too many arguments, or a non-numeric start.
    self.assertRaises(TypeError, count, 2, 3, 4)
    self.assertRaises(TypeError, count, 'a')

    # Counting continues across the Py_ssize_t limits.
    for start in (maxsize - 5, -maxsize - 5):
        self.assertEqual(list(islice(count(start), 10)), range(start, start + 10))

    # repr() reflects the value that will be produced next.
    counter = count(3)
    self.assertEqual(repr(counter), 'count(3)')
    counter.next()
    self.assertEqual(repr(counter), 'count(4)')
    counter = count(-9)
    self.assertEqual(repr(counter), 'count(-9)')
    counter.next()
    self.assertEqual(counter.next(), -8)
    self.assertEqual(repr(count(10.25)), 'count(10.25)')

    for value in (-sys.maxint-5, -sys.maxint+5, -10, -1, 0, 10, sys.maxint-5, sys.maxint+5):
        # Compare reprs with the long-integer 'L' suffix stripped.
        actual = repr(count(value)).replace('L', '')
        expected = ('count(%r)' % value).replace('L', '')
        self.assertEqual(actual, expected)

    # copy, deepcopy and a pickle round-trip all preserve the next value.
    cloners = (copy.copy, copy.deepcopy,
               lambda it: pickle.loads(pickle.dumps(it)))
    for value in (-3, 3, sys.maxint-5, sys.maxint+5):
        counter = count(value)
        for clone in cloners:
            self.assertEqual(next(clone(counter)), value)
def test_count_with_stride(self):
    # Explicit start/step, positionally and by keyword.
    for counter, expected in [
            (count(2, 3), [('a', 2), ('b', 5), ('c', 8)]),
            (count(start=2, step=3), [('a', 2), ('b', 5), ('c', 8)]),
            (count(step=-1), [('a', 0), ('b', -1), ('c', -2)]),
            (count(2, 0), [('a', 2), ('b', 2), ('c', 2)]),
            (count(2, 1), [('a', 2), ('b', 3), ('c', 4)])]:
        self.assertEqual(zip('abc', counter), expected)

    # Strided counting across the Py_ssize_t limits.
    for start in (maxsize - 15, -maxsize - 15):
        self.assertEqual(take(20, count(start, 3)),
                         take(20, range(start, start + 115, 3)))

    # Non-integer steps: complex, Decimal, Fraction and float.
    self.assertEqual(take(3, count(2, 3.25-4j)), [2, 5.25-4j, 8.5-8j])
    self.assertEqual(take(3, count(Decimal('1.1'), Decimal('.1'))),
                     [Decimal('1.1'), Decimal('1.2'), Decimal('1.3')])
    self.assertEqual(take(3, count(Fraction(2,3), Fraction(1,7))),
                     [Fraction(2,3), Fraction(17,21), Fraction(20,21)])
    self.assertEqual(repr(take(3, count(10, 2.5))), repr([10, 12.5, 15.0]))

    # repr() tracks the current value and always shows a non-1 step.
    for start, step, before, after in [(3, 5, 'count(3, 5)', 'count(8, 5)'),
                                       (-9, 0, 'count(-9, 0)', 'count(-9, 0)'),
                                       (-9, -3, 'count(-9, -3)', 'count(-12, -3)')]:
        counter = count(start, step)
        self.assertEqual(repr(counter), before)
        counter.next()
        self.assertEqual(repr(counter), after)
    # A second repr() of the last counter must give the same text.
    self.assertEqual(repr(counter), 'count(-12, -3)')
    self.assertEqual(repr(count(10.5, 1.25)), 'count(10.5, 1.25)')
    self.assertEqual(repr(count(10.5, 1)), 'count(10.5)')          # suppress step=1 when it's an int
    self.assertEqual(repr(count(10.5, 1.00)), 'count(10.5, 1.0)')  # do show float values like 1.0

    starts = (-sys.maxint-5, -sys.maxint+5, -10, -1, 0, 10, sys.maxint-5, sys.maxint+5)
    steps = (-sys.maxint-5, -sys.maxint+5, -10, -1, 0, 1, 10, sys.maxint-5, sys.maxint+5)
    for i, j in product(starts, steps):
        # Test repr (ignoring the L in longs); an int step of 1 is omitted.
        actual = repr(count(i, j)).replace('L', '')
        if j == 1:
            expected = ('count(%r)' % i).replace('L', '')
        else:
            expected = ('count(%r, %r)' % (i, j)).replace('L', '')
        self.assertEqual(actual, expected)
def test_cycle(self):
    # cycle() repeats its input forever; an empty input stays empty.
    self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
    self.assertEqual(list(cycle('')), [])
    # Exactly one iterable argument is required.
    for bad_args in [(), (5,)]:
        self.assertRaises(TypeError, cycle, *bad_args)
    # A non-restartable generator still cycles, so cycle() must remember its items.
    first_ten = islice(cycle(gen3()), 10)
    self.assertEqual(list(first_ten), [0,1,2,0,1,2,0,1,2,0])
def test_groupby(self):
    # Check whether it accepts arguments correctly
    self.assertEqual([], list(groupby([])))
    self.assertEqual([], list(groupby([], key=id)))
    # A non-callable key only fails once iteration starts.
    self.assertRaises(TypeError, list, groupby('abc', []))
    self.assertRaises(TypeError, groupby, None)
    self.assertRaises(TypeError, groupby, 'abc', lambda x:x, 10)

    # Check normal input
    s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22),
         (2,15,22), (3,16,23), (3,17,23)]
    dup = []
    for k, g in groupby(s, lambda r:r[0]):
        for elem in g:
            self.assertEqual(k, elem[0])
            dup.append(elem)
    # Concatenating all groups reproduces the input exactly.
    self.assertEqual(s, dup)

    # Check nested case
    dup = []
    for k, g in groupby(s, lambda r:r[0]):
        for ik, ig in groupby(g, lambda r:r[2]):
            for elem in ig:
                self.assertEqual(k, elem[0])
                self.assertEqual(ik, elem[2])
                dup.append(elem)
    self.assertEqual(s, dup)

    # Check case where inner iterator is not used
    keys = [k for k, g in groupby(s, lambda r:r[0])]
    expectedkeys = set([r[0] for r in s])
    self.assertEqual(set(keys), expectedkeys)
    # Each key appears once because s is already sorted by r[0].
    self.assertEqual(len(keys), len(expectedkeys))

    # Exercise pipes and filters style
    s = 'abracadabra'
    # sort s | uniq
    r = [k for k, g in groupby(sorted(s))]
    self.assertEqual(r, ['a', 'b', 'c', 'd', 'r'])
    # sort s | uniq -d   (keep keys whose group has a second element)
    r = [k for k, g in groupby(sorted(s)) if list(islice(g,1,2))]
    self.assertEqual(r, ['a', 'b', 'r'])
    # sort s | uniq -c
    r = [(len(list(g)), k) for k, g in groupby(sorted(s))]
    self.assertEqual(r, [(5, 'a'), (2, 'b'), (1, 'c'), (1, 'd'), (2, 'r')])
    # sort s | uniq -c | sort -rn | head -3
    r = sorted([(len(list(g)) , k) for k, g in groupby(sorted(s))], reverse=True)[:3]
    self.assertEqual(r, [(5, 'a'), (2, 'r'), (2, 'b')])

    # iter.next failure
    class ExpectedError(Exception):
        pass
    def delayed_raise(n=0):
        # Yield n items, then raise from the source iterator.
        for i in range(n):
            yield 'yo'
        raise ExpectedError
    def gulp(iterable, keyp=None, func=list):
        # Drive groupby to completion, applying func to every group.
        return [func(g) for k, g in groupby(iterable, keyp)]

    # iter.next failure on outer object
    self.assertRaises(ExpectedError, gulp, delayed_raise(0))
    # iter.next failure on inner object
    self.assertRaises(ExpectedError, gulp, delayed_raise(1))

    # __cmp__ failure: comparing these keys raises ExpectedError, and the
    # test requires that error to propagate out of groupby.
    class DummyCmp:
        def __cmp__(self, dst):
            raise ExpectedError
    s = [DummyCmp(), DummyCmp(), None]

    # __cmp__ failure on outer object
    self.assertRaises(ExpectedError, gulp, s, func=id)
    # __cmp__ failure on inner object
    self.assertRaises(ExpectedError, gulp, s)

    # keyfunc failure: returns obj for the first keyfunc.skip calls, then raises.
    def keyfunc(obj):
        if keyfunc.skip > 0:
            keyfunc.skip -= 1
            return obj
        else:
            raise ExpectedError

    # keyfunc failure on outer object
    keyfunc.skip = 0
    self.assertRaises(ExpectedError, gulp, [None], keyfunc)
    keyfunc.skip = 1
    self.assertRaises(ExpectedError, gulp, [None, None], keyfunc)
def test_ifilter(self):
    # Keep items whose predicate value is true; None means plain truthiness.
    self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4])
    for predicate in (None, bool):
        self.assertEqual(list(ifilter(predicate, [0,1,0,2,0])), [1,2])
    self.assertEqual(take(4, ifilter(isEven, count())), [0,2,4,6])
    # Wrong argument counts or a non-iterable source fail at construction.
    for bad_args in [(), (lambda x:x,), (lambda x:x, range(6), 7), (isEven, 3)]:
        self.assertRaises(TypeError, ifilter, *bad_args)
    # A non-callable predicate is only detected when an item is requested.
    self.assertRaises(TypeError, ifilter(range(6), range(6)).next)
def test_ifilterfalse(self):
    # Keep items whose predicate value is false; None means plain truthiness.
    self.assertEqual(list(ifilterfalse(isEven, range(6))), [1,3,5])
    for predicate in (None, bool):
        self.assertEqual(list(ifilterfalse(predicate, [0,1,0,2,0])), [0,0,0])
    self.assertEqual(take(4, ifilterfalse(isEven, count())), [1,3,5,7])
    # Wrong argument counts or a non-iterable source fail at construction.
    for bad_args in [(), (lambda x:x,), (lambda x:x, range(6), 7), (isEven, 3)]:
        self.assertRaises(TypeError, ifilterfalse, *bad_args)
    # A non-callable predicate is only detected when an item is requested.
    self.assertRaises(TypeError, ifilterfalse(range(6), range(6)).next)
def test_izip(self):
    # izip() stops at the shortest input and agrees with the eager zip().
    pairs = [(x, y) for x, y in izip('abc', count())]
    self.assertEqual(pairs, [('a', 0), ('b', 1), ('c', 2)])
    for call_args in [('abc', range(6)), ('abcdef', range(3))]:
        self.assertEqual(list(izip(*call_args)), zip(*call_args))
    self.assertEqual(take(3, izip('abcdef', count())), zip('abcdef', range(3)))
    for call_args in [('abcdef',), ()]:
        self.assertEqual(list(izip(*call_args)), zip(*call_args))
    for bad_args in [(3,), (range(3), 3)]:
        self.assertRaises(TypeError, izip, *bad_args)

    # Check tuple re-use (implementation detail)
    expected = zip('abc', 'def')
    self.assertEqual([tuple(list(pair)) for pair in izip('abc', 'def')], expected)
    self.assertEqual([pair for pair in izip('abc', 'def')], expected)
    # Results discarded immediately all report the same id...
    ids = map(id, izip('abc', 'def'))
    self.assertEqual(len(set(ids)), 1)
    # ...while results that are kept alive are all distinct objects.
    ids = map(id, list(izip('abc', 'def')))
    self.assertEqual(len(set(ids)), len(ids))
def test_iziplongest(self):
    # izip_longest() must pad shorter inputs with None (or fillvalue)
    # until the longest input is exhausted.
    for args in [
            ['abc', range(6)],
            [range(6), 'abc'],
            [range(1000), range(2000,2100), range(3000,3050)],
            [range(1000), range(0), range(3000,3050), range(1200), range(1500)],
            [range(1000), range(0), range(3000,3050), range(1200), range(1500), range(0)],
        ]:
        # target = map(None, *args) <- this raises a py3k warning
        # this is the replacement:
        target = [tuple([arg[i] if i < len(arg) else None for arg in args])
                  for i in range(max(map(len, args)))]
        self.assertEqual(list(izip_longest(*args)), target)
        self.assertEqual(list(izip_longest(*args, **{})), target)
        # Replace None fills with 'X'
        target = [tuple(('X' if e is None else e) for e in t) for t in target]
        self.assertEqual(list(izip_longest(*args, **dict(fillvalue='X'))), target)

    # take 3 from infinite input
    self.assertEqual(take(3, izip_longest('abcdef', count())), zip('abcdef', range(3)))

    self.assertEqual(list(izip_longest()), zip())
    self.assertEqual(list(izip_longest([])), zip([]))
    self.assertEqual(list(izip_longest('abcdef')), zip('abcdef'))

    self.assertEqual(list(izip_longest('abc', 'defg', **{})),
                     zip(list('abc') + [None], 'defg'))  # empty keyword dict
    self.assertRaises(TypeError, izip_longest, 3)
    self.assertRaises(TypeError, izip_longest, range(3), 3)

    # Unknown keyword arguments are rejected.
    for kwargs in [dict(fv=1), dict(fillvalue=1, bogus_keyword=None)]:
        self.assertRaises(TypeError, izip_longest, 'abc', **kwargs)

    # Check tuple re-use (implementation detail)
    self.assertEqual([tuple(list(pair)) for pair in izip_longest('abc', 'def')],
                     zip('abc', 'def'))
    self.assertEqual([pair for pair in izip_longest('abc', 'def')],
                     zip('abc', 'def'))
    ids = map(id, izip_longest('abc', 'def'))
    self.assertEqual(min(ids), max(ids))
    ids = map(id, list(izip_longest('abc', 'def')))
    self.assertEqual(len(dict.fromkeys(ids)), len(ids))
def test_bug_7244(self):

    class Repeater(object):
        # Like itertools.repeat(o, t), but raises e once t items are used up.
        def __init__(self, o, t, e):
            self.value = o
            self.remaining = int(t)
            self.exc = e

        def __iter__(self):
            # Its iterator is itself.
            return self

        def next(self):
            if self.remaining <= 0:
                raise self.exc
            self.remaining -= 1
            return self.value

    # Formerly this code would fail in debug mode
    # with Undetected Error and Stop Iteration
    def run(r1, r2):
        result = []
        for i, j in izip_longest(r1, r2, fillvalue=0):
            with test_support.captured_output('stdout'):
                print (i, j)
            result.append((i, j))
        return result
    self.assertEqual(run(Repeater(1, 3, StopIteration), Repeater(2, 4, StopIteration)),
                     [(1,2), (1,2), (1,2), (0,2)])

    # Formerly, the RuntimeError would be lost
    # and StopIteration would stop as expected
    it = izip_longest(Repeater(1, 3, RuntimeError), Repeater(2, 4, StopIteration),
                      fillvalue=0)
    for _ in range(3):
        self.assertEqual(next(it), (1, 2))
    self.assertRaises(RuntimeError, next, it)
def test_product(self):
    # Known answers, including zero-length inputs, and the repeat= keyword.
    for args, result in [
            ([], [()]),                     # zero iterables
            (['ab'], [('a',), ('b',)]),     # one iterable
            ([range(2), range(3)], [(0,0), (0,1), (0,2), (1,0), (1,1), (1,2)]),     # two iterables
            ([range(0), range(2), range(3)], []),           # first iterable with zero length
            ([range(2), range(0), range(3)], []),           # middle iterable with zero length
            ([range(2), range(3), range(0)], []),           # last iterable with zero length
        ]:
        self.assertEqual(list(product(*args)), result)
        for r in range(4):
            self.assertEqual(list(product(*(args*r))),
                             list(product(*args, **dict(repeat=r))))
    self.assertEqual(len(list(product(*[range(7)]*6))), 7**6)
    self.assertRaises(TypeError, product, range(6), None)

    def product1(*args, **kwds):
        # Odometer-style reference: advance the rightmost index that can move.
        pools = map(tuple, args) * kwds.get('repeat', 1)
        n = len(pools)
        if n == 0:
            yield ()
            return
        if any(len(pool) == 0 for pool in pools):
            return
        indices = [0] * n
        yield tuple(pool[i] for pool, i in zip(pools, indices))
        while 1:
            for i in reversed(range(n)):  # right to left
                if indices[i] == len(pools[i]) - 1:
                    continue
                indices[i] += 1
                for j in range(i+1, n):
                    indices[j] = 0
                yield tuple(pool[i] for pool, i in zip(pools, indices))
                break
            else:
                return

    def product2(*args, **kwds):
        'Pure python version used in docs'
        pools = map(tuple, args) * kwds.get('repeat', 1)
        result = [[]]
        for pool in pools:
            result = [x+[y] for x in result for y in pool]
        # Loop variable is not named `prod` to avoid shadowing the prod() helper.
        for items in result:
            yield tuple(items)

    # Random cross-check against both references over mixed iterable types.
    argtypes = ['', 'abc', '', xrange(0), xrange(4), dict(a=1, b=2, c=3),
                set('abcdefg'), range(11), tuple(range(13))]
    for i in range(100):
        args = [random.choice(argtypes) for j in range(random.randrange(5))]
        expected_len = prod(map(len, args))
        self.assertEqual(len(list(product(*args))), expected_len)
        self.assertEqual(list(product(*args)), list(product1(*args)))
        self.assertEqual(list(product(*args)), list(product2(*args)))
        # One-shot iterators must work too.
        args = map(iter, args)
        self.assertEqual(len(list(product(*args))), expected_len)

    # Test implementation detail:  tuple re-use
    self.assertEqual(len(set(map(id, product('abc', 'def')))), 1)
    self.assertNotEqual(len(set(map(id, list(product('abc', 'def'))))), 1)
def test_repeat(self):
    triple_a = ['a', 'a', 'a']
    # Keyword and positional forms, bounded and unbounded.
    self.assertEqual(list(repeat(object='a', times=3)), triple_a)
    self.assertEqual(zip(xrange(3), repeat('a')),
                     [(0, 'a'), (1, 'a'), (2, 'a')])
    self.assertEqual(list(repeat('a', 3)), triple_a)
    self.assertEqual(take(3, repeat('a')), triple_a)
    # Zero or negative counts produce nothing.
    for times in (0, -3):
        self.assertEqual(list(repeat('a', times)), [])
    # Argument validation.
    for bad_args in [(), (None, 3, 4), (None, 'a')]:
        self.assertRaises(TypeError, repeat, *bad_args)
    # repr() shows the object and, when bounded, the remaining count.
    self.assertEqual(repr(repeat(1+0j)), 'repeat((1+0j))')
    bounded = repeat(1+0j, 5)
    self.assertEqual(repr(bounded), 'repeat((1+0j), 5)')
    list(bounded)   # exhaust it
    self.assertEqual(repr(bounded), 'repeat((1+0j), 0)')
def test_imap(self):
    # imap() stops at the shortest input; a None function builds tuples.
    self.assertEqual(list(imap(operator.pow, range(3), range(1,7))),
                     [0**1, 1**2, 2**3])
    pairs = [('a',0),('b',1),('c',2)]
    for second in (range(5), count()):
        self.assertEqual(list(imap(None, 'abc', second)), pairs)
    self.assertEqual(take(2, imap(None, 'abc', count())), pairs[:2])
    self.assertEqual(list(imap(operator.pow, [])), [])
    # Too few arguments fail at construction time...
    self.assertRaises(TypeError, imap)
    self.assertRaises(TypeError, imap, operator.neg)
    # ...while problems calling the function surface on the first next().
    for func, call_args, exc in [(10, (range(5),), TypeError),
                                 (errfunc, ([4], [5]), ValueError),
                                 (onearg, ([4], [5]), TypeError)]:
        self.assertRaises(exc, imap(func, *call_args).next)
def test_starmap(self):
    powers = [0**1, 1**2, 2**3]
    # Each item is unpacked as the function's positional arguments.
    self.assertEqual(list(starmap(operator.pow, zip(range(3), range(1,7)))), powers)
    self.assertEqual(take(3, starmap(operator.pow, izip(count(), count(1)))), powers)
    self.assertEqual(list(starmap(operator.pow, [])), [])
    # Any iterable item can supply the arguments; a non-iterable item cannot.
    self.assertEqual(list(starmap(operator.pow, [iter([4,5])])), [4**5])
    self.assertRaises(TypeError, list, starmap(operator.pow, [None]))
    # Argument validation at construction time.
    self.assertRaises(TypeError, starmap)
    self.assertRaises(TypeError, starmap, operator.pow, [(4,5)], 'extra')
    # Errors from calling the function surface on the first next().
    for func, exc in [(10, TypeError), (errfunc, ValueError), (onearg, TypeError)]:
        self.assertRaises(exc, starmap(func, [(4,5)]).next)
def test_islice(self):
    for args in [          # islice(args) should agree with range(args)
            (10, 20, 3),
            (10, 3, 20),
            (10, 20),
            (10, 3),
            (20,)
            ]:
        self.assertEqual(list(islice(xrange(100), *args)), range(*args))

    for args, tgtargs in [  # Stop when seqn is exhausted
            ((10, 110, 3), (10, 100, 3)),
            ((10, 110), (10, 100)),
            ((110,), (100,))
            ]:
        self.assertEqual(list(islice(xrange(100), *args)), range(*tgtargs))

    # Test stop=None
    self.assertEqual(list(islice(xrange(10), None)), range(10))
    self.assertEqual(list(islice(xrange(10), None, None)), range(10))
    self.assertEqual(list(islice(xrange(10), None, None, None)), range(10))
    self.assertEqual(list(islice(xrange(10), 2, None)), range(2, 10))
    self.assertEqual(list(islice(xrange(10), 1, None, 2)), range(1, 10, 2))

    # Test number of items consumed     SF #1171417
    it = iter(range(10))
    self.assertEqual(list(islice(it, 3)), range(3))
    self.assertEqual(list(it), range(3, 10))

    # Test invalid arguments
    self.assertRaises(TypeError, islice, xrange(10))
    self.assertRaises(TypeError, islice, xrange(10), 1, 2, 3, 4)
    self.assertRaises(ValueError, islice, xrange(10), -5, 10, 1)
    self.assertRaises(ValueError, islice, xrange(10), 1, -5, -1)
    self.assertRaises(ValueError, islice, xrange(10), 1, 10, -1)
    self.assertRaises(ValueError, islice, xrange(10), 1, 10, 0)
    self.assertRaises(ValueError, islice, xrange(10), 'a')
    self.assertRaises(ValueError, islice, xrange(10), 'a', 1)
    self.assertRaises(ValueError, islice, xrange(10), 1, 'a')
    self.assertRaises(ValueError, islice, xrange(10), 'a', 1, 1)
    self.assertRaises(ValueError, islice, xrange(10), 1, 'a', 1)
    # A step as large as maxsize must not overflow: exactly one item is produced.
    self.assertEqual(len(list(islice(count(), 1, 10, maxsize))), 1)
def test_takewhile(self):
    data = [1, 3, 5, 20, 2, 4, 6, 8]
    def underten(x):
        return x < 10
    # Items are produced until the predicate first fails.
    self.assertEqual(list(takewhile(underten, data)), [1, 3, 5])
    self.assertEqual(list(takewhile(underten, [])), [])
    for bad_args in [(), (operator.pow,), (operator.pow, [(4,5)], 'extra')]:
        self.assertRaises(TypeError, takewhile, *bad_args)
    # Predicate errors surface on the first next().
    for func, exc in [(10, TypeError), (errfunc, ValueError)]:
        self.assertRaises(exc, takewhile(func, [(4,5)]).next)
    # After the predicate fails the iterator stays exhausted.
    truthy_prefix = takewhile(bool, [1, 1, 1, 0, 0, 0])
    self.assertEqual(list(truthy_prefix), [1, 1, 1])
    self.assertRaises(StopIteration, truthy_prefix.next)
def test_dropwhile(self):
    data = [1, 3, 5, 20, 2, 4, 6, 8]
    def underten(x):
        return x < 10
    # Leading items are skipped until the predicate first fails; the rest pass through.
    self.assertEqual(list(dropwhile(underten, data)), [20, 2, 4, 6, 8])
    self.assertEqual(list(dropwhile(underten, [])), [])
    for bad_args in [(), (operator.pow,), (operator.pow, [(4,5)], 'extra')]:
        self.assertRaises(TypeError, dropwhile, *bad_args)
    # Predicate errors surface on the first next().
    for func, exc in [(10, TypeError), (errfunc, ValueError)]:
        self.assertRaises(exc, dropwhile(func, [(4,5)]).next)
806 def test_tee(self):
807 n = 200
808 def irange(n):
809 for i in xrange(n):
810 yield i
812 a, b = tee([]) # test empty iterator
813 self.assertEqual(list(a), [])
814 self.assertEqual(list(b), [])
816 a, b = tee(irange(n)) # test 100% interleaved
817 self.assertEqual(zip(a,b), zip(range(n),range(n)))
819 a, b = tee(irange(n)) # test 0% interleaved
820 self.assertEqual(list(a), range(n))
821 self.assertEqual(list(b), range(n))
823 a, b = tee(irange(n)) # test dealloc of leading iterator
824 for i in xrange(100):
825 self.assertEqual(a.next(), i)
826 del a
827 self.assertEqual(list(b), range(n))
829 a, b = tee(irange(n)) # test dealloc of trailing iterator
830 for i in xrange(100):
831 self.assertEqual(a.next(), i)
832 del b
833 self.assertEqual(list(a), range(100, n))
835 for j in xrange(5): # test randomly interleaved
836 order = [0]*n + [1]*n
837 random.shuffle(order)
838 lists = ([], [])
839 its = tee(irange(n))
840 for i in order:
841 value = its[i].next()
842 lists[i].append(value)
843 self.assertEqual(lists[0], range(n))
844 self.assertEqual(lists[1], range(n))
846 # test argument format checking
847 self.assertRaises(TypeError, tee)
848 self.assertRaises(TypeError, tee, 3)
849 self.assertRaises(TypeError, tee, [1,2], 'x')
850 self.assertRaises(TypeError, tee, [1,2], 3, 'x')
852 # tee object should be instantiable
853 a, b = tee('abc')
854 c = type(a)('def')
855 self.assertEqual(list(c), list('def'))
857 # test long-lagged and multi-way split
858 a, b, c = tee(xrange(2000), 3)
859 for i in xrange(100):
860 self.assertEqual(a.next(), i)
861 self.assertEqual(list(b), range(2000))
862 self.assertEqual([c.next(), c.next()], range(2))
863 self.assertEqual(list(a), range(100,2000))
864 self.assertEqual(list(c), range(2,2000))
866 # test values of n
867 self.assertRaises(TypeError, tee, 'abc', 'invalid')
868 self.assertRaises(ValueError, tee, [], -1)
869 for n in xrange(5):
870 result = tee('abc', n)
871 self.assertEqual(type(result), tuple)
872 self.assertEqual(len(result), n)
873 self.assertEqual(map(list, result), [list('abc')]*n)
875 # tee pass-through to copyable iterator
876 a, b = tee('abc')
877 c, d = tee(a)
878 self.assertTrue(a is c)
880 # test tee_new
881 t1, t2 = tee('abc')
882 tnew = type(t1)
883 self.assertRaises(TypeError, tnew)
884 self.assertRaises(TypeError, tnew, 10)
885 t3 = tnew(t1)
886 self.assertTrue(list(t1) == list(t2) == list(t3) == list('abc'))
888 # test that tee objects are weak referencable
889 a, b = tee(xrange(10))
890 p = proxy(a)
891 self.assertEqual(getattr(p, '__class__'), type(b))
892 del a
893 self.assertRaises(ReferenceError, getattr, p, '__class__')
def test_StopIteration(self):
    """Every iterator built over an empty source stops on its first next()."""
    self.assertRaises(StopIteration, izip().next)

    def identity(x):
        return x

    # list() builds a fresh empty list; StopNow() is an always-empty iterator.
    for make_empty in (list, StopNow):
        for ctor in (chain, cycle, izip, groupby):
            self.assertRaises(StopIteration, ctor(make_empty()).next)

        self.assertRaises(StopIteration, islice(make_empty(), None).next)

        for half in tee(make_empty()):
            self.assertRaises(StopIteration, half.next)

        for ctor in (ifilter, ifilterfalse, imap, takewhile, dropwhile, starmap):
            self.assertRaises(StopIteration, ctor(identity, make_empty()).next)

    self.assertRaises(StopIteration, repeat(None, 0).next)
918 class TestExamples(unittest.TestCase):
920 def test_chain(self):
921 self.assertEqual(''.join(chain('ABC', 'DEF')), 'ABCDEF')
923 def test_chain_from_iterable(self):
924 self.assertEqual(''.join(chain.from_iterable(['ABC', 'DEF'])), 'ABCDEF')
926 def test_combinations(self):
927 self.assertEqual(list(combinations('ABCD', 2)),
928 [('A','B'), ('A','C'), ('A','D'), ('B','C'), ('B','D'), ('C','D')])
929 self.assertEqual(list(combinations(range(4), 3)),
930 [(0,1,2), (0,1,3), (0,2,3), (1,2,3)])
932 def test_combinations_with_replacement(self):
933 self.assertEqual(list(combinations_with_replacement('ABC', 2)),
934 [('A','A'), ('A','B'), ('A','C'), ('B','B'), ('B','C'), ('C','C')])
936 def test_compress(self):
937 self.assertEqual(list(compress('ABCDEF', [1,0,1,0,1,1])), list('ACEF'))
939 def test_count(self):
940 self.assertEqual(list(islice(count(10), 5)), [10, 11, 12, 13, 14])
942 def test_cycle(self):
943 self.assertEqual(list(islice(cycle('ABCD'), 12)), list('ABCDABCDABCD'))
945 def test_dropwhile(self):
946 self.assertEqual(list(dropwhile(lambda x: x<5, [1,4,6,4,1])), [6,4,1])
948 def test_groupby(self):
949 self.assertEqual([k for k, g in groupby('AAAABBBCCDAABBB')],
950 list('ABCDAB'))
951 self.assertEqual([(list(g)) for k, g in groupby('AAAABBBCCD')],
952 [list('AAAA'), list('BBB'), list('CC'), list('D')])
954 def test_ifilter(self):
955 self.assertEqual(list(ifilter(lambda x: x%2, range(10))), [1,3,5,7,9])
957 def test_ifilterfalse(self):
958 self.assertEqual(list(ifilterfalse(lambda x: x%2, range(10))), [0,2,4,6,8])
960 def test_imap(self):
961 self.assertEqual(list(imap(pow, (2,3,10), (5,2,3))), [32, 9, 1000])
963 def test_islice(self):
964 self.assertEqual(list(islice('ABCDEFG', 2)), list('AB'))
965 self.assertEqual(list(islice('ABCDEFG', 2, 4)), list('CD'))
966 self.assertEqual(list(islice('ABCDEFG', 2, None)), list('CDEFG'))
967 self.assertEqual(list(islice('ABCDEFG', 0, None, 2)), list('ACEG'))
969 def test_izip(self):
970 self.assertEqual(list(izip('ABCD', 'xy')), [('A', 'x'), ('B', 'y')])
972 def test_izip_longest(self):
973 self.assertEqual(list(izip_longest('ABCD', 'xy', fillvalue='-')),
974 [('A', 'x'), ('B', 'y'), ('C', '-'), ('D', '-')])
976 def test_permutations(self):
977 self.assertEqual(list(permutations('ABCD', 2)),
978 map(tuple, 'AB AC AD BA BC BD CA CB CD DA DB DC'.split()))
979 self.assertEqual(list(permutations(range(3))),
980 [(0,1,2), (0,2,1), (1,0,2), (1,2,0), (2,0,1), (2,1,0)])
982 def test_product(self):
983 self.assertEqual(list(product('ABCD', 'xy')),
984 map(tuple, 'Ax Ay Bx By Cx Cy Dx Dy'.split()))
985 self.assertEqual(list(product(range(2), repeat=3)),
986 [(0,0,0), (0,0,1), (0,1,0), (0,1,1),
987 (1,0,0), (1,0,1), (1,1,0), (1,1,1)])
989 def test_repeat(self):
990 self.assertEqual(list(repeat(10, 3)), [10, 10, 10])
992 def test_stapmap(self):
993 self.assertEqual(list(starmap(pow, [(2,5), (3,2), (10,3)])),
994 [32, 9, 1000])
996 def test_takewhile(self):
997 self.assertEqual(list(takewhile(lambda x: x<5, [1,4,6,4,1])), [1,4])
1000 class TestGC(unittest.TestCase):
1002 def makecycle(self, iterator, container):
1003 container.append(iterator)
1004 iterator.next()
1005 del container, iterator
1007 def test_chain(self):
1008 a = []
1009 self.makecycle(chain(a), a)
1011 def test_chain_from_iterable(self):
1012 a = []
1013 self.makecycle(chain.from_iterable([a]), a)
1015 def test_combinations(self):
1016 a = []
1017 self.makecycle(combinations([1,2,a,3], 3), a)
1019 def test_combinations_with_replacement(self):
1020 a = []
1021 self.makecycle(combinations_with_replacement([1,2,a,3], 3), a)
1023 def test_compress(self):
1024 a = []
1025 self.makecycle(compress('ABCDEF', [1,0,1,0,1,0]), a)
1027 def test_count(self):
1028 a = []
1029 Int = type('Int', (int,), dict(x=a))
1030 self.makecycle(count(Int(0), Int(1)), a)
1032 def test_cycle(self):
1033 a = []
1034 self.makecycle(cycle([a]*2), a)
1036 def test_dropwhile(self):
1037 a = []
1038 self.makecycle(dropwhile(bool, [0, a, a]), a)
1040 def test_groupby(self):
1041 a = []
1042 self.makecycle(groupby([a]*2, lambda x:x), a)
1044 def test_issue2246(self):
1045 # Issue 2246 -- the _grouper iterator was not included in GC
1046 n = 10
1047 keyfunc = lambda x: x
1048 for i, j in groupby(xrange(n), key=keyfunc):
1049 keyfunc.__dict__.setdefault('x',[]).append(j)
1051 def test_ifilter(self):
1052 a = []
1053 self.makecycle(ifilter(lambda x:True, [a]*2), a)
1055 def test_ifilterfalse(self):
1056 a = []
1057 self.makecycle(ifilterfalse(lambda x:False, a), a)
1059 def test_izip(self):
1060 a = []
1061 self.makecycle(izip([a]*2, [a]*3), a)
1063 def test_izip_longest(self):
1064 a = []
1065 self.makecycle(izip_longest([a]*2, [a]*3), a)
1066 b = [a, None]
1067 self.makecycle(izip_longest([a]*2, [a]*3, fillvalue=b), a)
1069 def test_imap(self):
1070 a = []
1071 self.makecycle(imap(lambda x:x, [a]*2), a)
1073 def test_islice(self):
1074 a = []
1075 self.makecycle(islice([a]*2, None), a)
1077 def test_permutations(self):
1078 a = []
1079 self.makecycle(permutations([1,2,a,3], 3), a)
1081 def test_product(self):
1082 a = []
1083 self.makecycle(product([1,2,a,3], repeat=3), a)
1085 def test_repeat(self):
1086 a = []
1087 self.makecycle(repeat(a), a)
1089 def test_starmap(self):
1090 a = []
1091 self.makecycle(starmap(lambda *t: t, [(a,a)]*2), a)
1093 def test_takewhile(self):
1094 a = []
1095 self.makecycle(takewhile(bool, [1, 0, a, a]), a)
def R(seqn):
    'Regular generator: re-yields every element of seqn'
    source = iter(seqn)
    for element in source:
        yield element
class G:
    'Sequence that supports iteration only through __getitem__'

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, index):
        return self.seqn[index]
class I:
    'Iterator that implements the protocol by hand via next()'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        position = self.i
        if position >= len(self.seqn):
            raise StopIteration
        value = self.seqn[position]
        self.i = position + 1
        return value
class Ig:
    'Iterable whose __iter__ is a generator method'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        for item in self.seqn:
            yield item
class X:
    'Has next() but neither __iter__ nor __getitem__, so it is not iterable'

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def next(self):
        position = self.i
        if position >= len(self.seqn):
            raise StopIteration
        value = self.seqn[position]
        self.i = position + 1
        return value
class N:
    'Returns itself from __iter__ but provides no next() method'

    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self
class E:
    'Iterator whose next() raises ZeroDivisionError, to check propagation'

    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self

    def next(self):
        return 3 // 0
class S:
    'Iterator that is exhausted from the start; its argument is ignored'

    def __init__(self, seqn):
        pass

    def __iter__(self):
        return self

    def next(self):
        raise StopIteration
def L(seqn):
    'Stack several iterator layers (G -> Ig -> R -> imap -> chain) over seqn'
    layered = R(Ig(G(seqn)))
    return chain(imap(lambda item: item, layered))
class TestVariousIteratorArgs(unittest.TestCase):
    """Feed each itertools function every kind of input iterable.

    For the well-behaved wrappers G, I, Ig, S, L and R, results are compared
    against list(g(s)) or a builtin equivalent.  X (not iterable) and N
    (iterator without next()) must raise TypeError; E (next() divides by
    zero) must let ZeroDivisionError propagate.
    """

    def test_chain(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(chain(g(s))), list(g(s)))
                self.assertEqual(list(chain(g(s), g(s))), list(g(s))+list(g(s)))
            self.assertRaises(TypeError, list, chain(X(s)))
            self.assertRaises(TypeError, list, chain(N(s)))
            self.assertRaises(ZeroDivisionError, list, chain(E(s)))

    def test_compress(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(compress(g(s), repeat(1))), list(g(s)))
            self.assertRaises(TypeError, compress, X(s), repeat(1))
            self.assertRaises(TypeError, list, compress(N(s), repeat(1)))
            self.assertRaises(ZeroDivisionError, list, compress(E(s), repeat(1)))

    def test_product(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                # With a single input, product() yields 1-tuples of its items.
                self.assertEqual(list(product(g(s))), [(x,) for x in g(s)])
            self.assertRaises(TypeError, product, X(s))
            self.assertRaises(TypeError, product, N(s))
            self.assertRaises(ZeroDivisionError, product, E(s))

    def test_cycle(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            tgtlen = len(s) * 3
            for g in (G, I, Ig, S, L, R):
                expected = list(g(s))*3
                actual = list(islice(cycle(g(s)), tgtlen))
                self.assertEqual(actual, expected)
            self.assertRaises(TypeError, cycle, X(s))
            self.assertRaises(TypeError, list, cycle(N(s)))
            self.assertRaises(ZeroDivisionError, list, cycle(E(s)))

    def test_groupby(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual([k for k, sb in groupby(g(s))], list(g(s)))
            self.assertRaises(TypeError, groupby, X(s))
            self.assertRaises(TypeError, list, groupby(N(s)))
            self.assertRaises(ZeroDivisionError, list, groupby(E(s)))

    def test_ifilter(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(ifilter(isEven, g(s))), filter(isEven, g(s)))
            self.assertRaises(TypeError, ifilter, isEven, X(s))
            self.assertRaises(TypeError, list, ifilter(isEven, N(s)))
            self.assertRaises(ZeroDivisionError, list, ifilter(isEven, E(s)))

    def test_ifilterfalse(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(ifilterfalse(isEven, g(s))), filter(isOdd, g(s)))
            self.assertRaises(TypeError, ifilterfalse, isEven, X(s))
            self.assertRaises(TypeError, list, ifilterfalse(isEven, N(s)))
            self.assertRaises(ZeroDivisionError, list, ifilterfalse(isEven, E(s)))

    def test_izip(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(izip(g(s))), zip(g(s)))
                self.assertEqual(list(izip(g(s), g(s))), zip(g(s), g(s)))
            self.assertRaises(TypeError, izip, X(s))
            self.assertRaises(TypeError, list, izip(N(s)))
            self.assertRaises(ZeroDivisionError, list, izip(E(s)))

    def test_iziplongest(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(izip_longest(g(s))), zip(g(s)))
                self.assertEqual(list(izip_longest(g(s), g(s))), zip(g(s), g(s)))
            self.assertRaises(TypeError, izip_longest, X(s))
            self.assertRaises(TypeError, list, izip_longest(N(s)))
            self.assertRaises(ZeroDivisionError, list, izip_longest(E(s)))

    def test_imap(self):
        for s in (range(10), range(0), range(100), (7,11), xrange(20,50,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(imap(onearg, g(s))), map(onearg, g(s)))
                self.assertEqual(list(imap(operator.pow, g(s), g(s))),
                                 map(operator.pow, g(s), g(s)))
            self.assertRaises(TypeError, imap, onearg, X(s))
            self.assertRaises(TypeError, list, imap(onearg, N(s)))
            self.assertRaises(ZeroDivisionError, list, imap(onearg, E(s)))

    def test_islice(self):
        for s in ("12345", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(islice(g(s),1,None,2)), list(g(s))[1::2])
            self.assertRaises(TypeError, islice, X(s), 10)
            self.assertRaises(TypeError, list, islice(N(s), 10))
            self.assertRaises(ZeroDivisionError, list, islice(E(s), 10))

    def test_starmap(self):
        for s in (range(10), range(0), range(100), (7,11), xrange(20,50,5)):
            ss = zip(s, s)
            for g in (G, I, Ig, S, L, R):
                self.assertEqual(list(starmap(operator.pow, g(ss))),
                                 map(operator.pow, g(s), g(s)))
            self.assertRaises(TypeError, starmap, operator.pow, X(ss))
            self.assertRaises(TypeError, list, starmap(operator.pow, N(ss)))
            self.assertRaises(ZeroDivisionError, list, starmap(operator.pow, E(ss)))

    def test_takewhile(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                # Expected result: the leading run of even elements.
                tgt = []
                for elem in g(s):
                    if not isEven(elem): break
                    tgt.append(elem)
                self.assertEqual(list(takewhile(isEven, g(s))), tgt)
            self.assertRaises(TypeError, takewhile, isEven, X(s))
            self.assertRaises(TypeError, list, takewhile(isEven, N(s)))
            self.assertRaises(ZeroDivisionError, list, takewhile(isEven, E(s)))

    def test_dropwhile(self):
        for s in (range(10), range(0), range(1000), (7,11), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                # Expected result: everything after the leading run of odd elements.
                tgt = []
                for elem in g(s):
                    if not tgt and isOdd(elem): continue
                    tgt.append(elem)
                self.assertEqual(list(dropwhile(isOdd, g(s))), tgt)
            self.assertRaises(TypeError, dropwhile, isOdd, X(s))
            self.assertRaises(TypeError, list, dropwhile(isOdd, N(s)))
            self.assertRaises(ZeroDivisionError, list, dropwhile(isOdd, E(s)))

    def test_tee(self):
        for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
            for g in (G, I, Ig, S, L, R):
                it1, it2 = tee(g(s))
                self.assertEqual(list(it1), list(g(s)))
                self.assertEqual(list(it2), list(g(s)))
            self.assertRaises(TypeError, tee, X(s))
            self.assertRaises(TypeError, list, tee(N(s))[0])
            self.assertRaises(ZeroDivisionError, list, tee(E(s))[0])
1313 class LengthTransparency(unittest.TestCase):
1315 def test_repeat(self):
1316 from test.test_iterlen import len
1317 self.assertEqual(len(repeat(None, 50)), 50)
1318 self.assertRaises(TypeError, len, repeat(None))
1320 class RegressionTests(unittest.TestCase):
1322 def test_sf_793826(self):
1323 # Fix Armin Rigo's successful efforts to wreak havoc
1325 def mutatingtuple(tuple1, f, tuple2):
1326 # this builds a tuple t which is a copy of tuple1,
1327 # then calls f(t), then mutates t to be equal to tuple2
1328 # (needs len(tuple1) == len(tuple2)).
1329 def g(value, first=[1]):
1330 if first:
1331 del first[:]
1332 f(z.next())
1333 return value
1334 items = list(tuple2)
1335 items[1:1] = list(tuple1)
1336 gen = imap(g, items)
1337 z = izip(*[gen]*len(tuple1))
1338 z.next()
1340 def f(t):
1341 global T
1342 T = t
1343 first[:] = list(T)
1345 first = []
1346 mutatingtuple((1,2,3), f, (4,5,6))
1347 second = list(T)
1348 self.assertEqual(first, second)
1351 def test_sf_950057(self):
1352 # Make sure that chain() and cycle() catch exceptions immediately
1353 # rather than when shifting between input sources
1355 def gen1():
1356 hist.append(0)
1357 yield 1
1358 hist.append(1)
1359 raise AssertionError
1360 hist.append(2)
1362 def gen2(x):
1363 hist.append(3)
1364 yield 2
1365 hist.append(4)
1366 if x:
1367 raise StopIteration
1369 hist = []
1370 self.assertRaises(AssertionError, list, chain(gen1(), gen2(False)))
1371 self.assertEqual(hist, [0,1])
1373 hist = []
1374 self.assertRaises(AssertionError, list, chain(gen1(), gen2(True)))
1375 self.assertEqual(hist, [0,1])
1377 hist = []
1378 self.assertRaises(AssertionError, list, cycle(gen1()))
1379 self.assertEqual(hist, [0,1])
1381 class SubclassWithKwargsTest(unittest.TestCase):
1382 def test_keywords_in_subclass(self):
1383 # count is not subclassable...
1384 for cls in (repeat, izip, ifilter, ifilterfalse, chain, imap,
1385 starmap, islice, takewhile, dropwhile, cycle, compress):
1386 class Subclass(cls):
1387 def __init__(self, newarg=None, *args):
1388 cls.__init__(self, *args)
1389 try:
1390 Subclass(newarg=1)
1391 except TypeError, err:
1392 # we expect type errors because of wrong argument count
1393 self.assertNotIn("does not take keyword arguments", err.args[0])
1396 libreftest = """ Doctest for examples in the library reference: libitertools.tex
1399 >>> amounts = [120.15, 764.05, 823.14]
1400 >>> for checknum, amount in izip(count(1200), amounts):
1401 ... print 'Check %d is for $%.2f' % (checknum, amount)
1403 Check 1200 is for $120.15
1404 Check 1201 is for $764.05
1405 Check 1202 is for $823.14
1407 >>> import operator
1408 >>> for cube in imap(operator.pow, xrange(1,4), repeat(3)):
1409 ... print cube
1415 >>> reportlines = ['EuroPython', 'Roster', '', 'alex', '', 'laura', '', 'martin', '', 'walter', '', 'samuele']
1416 >>> for name in islice(reportlines, 3, None, 2):
1417 ... print name.title()
1419 Alex
1420 Laura
1421 Martin
1422 Walter
1423 Samuele
1425 >>> from operator import itemgetter
1426 >>> d = dict(a=1, b=2, c=1, d=2, e=1, f=2, g=3)
1427 >>> di = sorted(sorted(d.iteritems()), key=itemgetter(1))
1428 >>> for k, g in groupby(di, itemgetter(1)):
1429 ... print k, map(itemgetter(0), g)
1431 1 ['a', 'c', 'e']
1432 2 ['b', 'd', 'f']
1433 3 ['g']
1435 # Find runs of consecutive numbers using groupby. The key to the solution
1436 # is differencing with a range so that consecutive numbers all appear in
1437 # the same group.
1438 >>> data = [ 1, 4,5,6, 10, 15,16,17,18, 22, 25,26,27,28]
1439 >>> for k, g in groupby(enumerate(data), lambda t:t[0]-t[1]):
1440 ... print map(operator.itemgetter(1), g)
1443 [4, 5, 6]
1444 [10]
1445 [15, 16, 17, 18]
1446 [22]
1447 [25, 26, 27, 28]
1449 >>> def take(n, iterable):
1450 ... "Return first n items of the iterable as a list"
1451 ... return list(islice(iterable, n))
1453 >>> def enumerate(iterable, start=0):
1454 ... return izip(count(start), iterable)
1456 >>> def tabulate(function, start=0):
1457 ... "Return function(0), function(1), ..."
1458 ... return imap(function, count(start))
1460 >>> def nth(iterable, n, default=None):
1461 ... "Returns the nth item or a default value"
1462 ... return next(islice(iterable, n, None), default)
1464 >>> def quantify(iterable, pred=bool):
1465 ... "Count how many times the predicate is true"
1466 ... return sum(imap(pred, iterable))
1468 >>> def padnone(iterable):
1469 ... "Returns the sequence elements and then returns None indefinitely"
1470 ... return chain(iterable, repeat(None))
1472 >>> def ncycles(iterable, n):
1473 ... "Returns the sequence elements n times"
1474 ... return chain(*repeat(iterable, n))
1476 >>> def dotproduct(vec1, vec2):
1477 ... return sum(imap(operator.mul, vec1, vec2))
1479 >>> def flatten(listOfLists):
1480 ... return list(chain.from_iterable(listOfLists))
1482 >>> def repeatfunc(func, times=None, *args):
1483 ... "Repeat calls to func with specified arguments."
1484 ... " Example: repeatfunc(random.random)"
1485 ... if times is None:
1486 ... return starmap(func, repeat(args))
1487 ... else:
1488 ... return starmap(func, repeat(args, times))
1490 >>> def pairwise(iterable):
1491 ... "s -> (s0,s1), (s1,s2), (s2, s3), ..."
1492 ... a, b = tee(iterable)
1493 ... for elem in b:
1494 ... break
1495 ... return izip(a, b)
1497 >>> def grouper(n, iterable, fillvalue=None):
1498 ... "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
1499 ... args = [iter(iterable)] * n
1500 ... return izip_longest(fillvalue=fillvalue, *args)
1502 >>> def roundrobin(*iterables):
1503 ... "roundrobin('ABC', 'D', 'EF') --> A D E B F C"
1504 ... # Recipe credited to George Sakkis
1505 ... pending = len(iterables)
1506 ... nexts = cycle(iter(it).next for it in iterables)
1507 ... while pending:
1508 ... try:
1509 ... for next in nexts:
1510 ... yield next()
1511 ... except StopIteration:
1512 ... pending -= 1
1513 ... nexts = cycle(islice(nexts, pending))
1515 >>> def powerset(iterable):
1516 ... "powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
1517 ... s = list(iterable)
1518 ... return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
1520 >>> def unique_everseen(iterable, key=None):
1521 ... "List unique elements, preserving order. Remember all elements ever seen."
1522 ... # unique_everseen('AAAABBBCCDAABBB') --> A B C D
1523 ... # unique_everseen('ABBCcAD', str.lower) --> A B C D
1524 ... seen = set()
1525 ... seen_add = seen.add
1526 ... if key is None:
1527 ... for element in iterable:
1528 ... if element not in seen:
1529 ... seen_add(element)
1530 ... yield element
1531 ... else:
1532 ... for element in iterable:
1533 ... k = key(element)
1534 ... if k not in seen:
1535 ... seen_add(k)
1536 ... yield element
1538 >>> def unique_justseen(iterable, key=None):
1539 ... "List unique elements, preserving order. Remember only the element just seen."
1540 ... # unique_justseen('AAAABBBCCDAABBB') --> A B C D A B
1541 ... # unique_justseen('ABBCcAD', str.lower) --> A B C A D
1542 ... return imap(next, imap(itemgetter(1), groupby(iterable, key)))
1544 This is not part of the examples but it tests to make sure the definitions
1545 perform as purported.
1547 >>> take(10, count())
1548 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
1550 >>> list(enumerate('abc'))
1551 [(0, 'a'), (1, 'b'), (2, 'c')]
1553 >>> list(islice(tabulate(lambda x: 2*x), 4))
1554 [0, 2, 4, 6]
1556 >>> nth('abcde', 3)
1559 >>> nth('abcde', 9) is None
1560 True
1562 >>> quantify(xrange(99), lambda x: x%2==0)
1565 >>> a = [[1, 2, 3], [4, 5, 6]]
1566 >>> flatten(a)
1567 [1, 2, 3, 4, 5, 6]
1569 >>> list(repeatfunc(pow, 5, 2, 3))
1570 [8, 8, 8, 8, 8]
1572 >>> import random
1573 >>> take(5, imap(int, repeatfunc(random.random)))
1574 [0, 0, 0, 0, 0]
1576 >>> list(pairwise('abcd'))
1577 [('a', 'b'), ('b', 'c'), ('c', 'd')]
1579 >>> list(pairwise([]))
1582 >>> list(pairwise('a'))
1585 >>> list(islice(padnone('abc'), 0, 6))
1586 ['a', 'b', 'c', None, None, None]
1588 >>> list(ncycles('abc', 3))
1589 ['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c']
1591 >>> dotproduct([1,2,3], [4,5,6])
1594 >>> list(grouper(3, 'abcdefg', 'x'))
1595 [('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'x', 'x')]
1597 >>> list(roundrobin('abc', 'd', 'ef'))
1598 ['a', 'd', 'e', 'b', 'f', 'c']
1600 >>> list(powerset([1,2,3]))
1601 [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
1603 >>> all(len(list(powerset(range(n)))) == 2**n for n in range(18))
1604 True
1606 >>> list(powerset('abcde')) == sorted(sorted(set(powerset('abcde'))), key=len)
1607 True
1609 >>> list(unique_everseen('AAAABBBCCDAABBB'))
1610 ['A', 'B', 'C', 'D']
1612 >>> list(unique_everseen('ABBCcAD', str.lower))
1613 ['A', 'B', 'C', 'D']
1615 >>> list(unique_justseen('AAAABBBCCDAABBB'))
1616 ['A', 'B', 'C', 'D', 'A', 'B']
1618 >>> list(unique_justseen('ABBCcAD', str.lower))
1619 ['A', 'B', 'C', 'A', 'D']
# Expose the library-reference examples to doctest under the name 'libreftest'.
__test__ = dict(libreftest=libreftest)
1625 def test_main(verbose=None):
1626 test_classes = (TestBasicOps, TestVariousIteratorArgs, TestGC,
1627 RegressionTests, LengthTransparency,
1628 SubclassWithKwargsTest, TestExamples)
1629 test_support.run_unittest(*test_classes)
1631 # verify reference counting
1632 if verbose and hasattr(sys, "gettotalrefcount"):
1633 import gc
1634 counts = [None] * 5
1635 for i in xrange(len(counts)):
1636 test_support.run_unittest(*test_classes)
1637 gc.collect()
1638 counts[i] = sys.gettotalrefcount()
1639 print counts
1641 # doctest the examples in the library reference
1642 test_support.run_doctest(sys.modules[__name__], verbose)
if __name__ == "__main__":
    # Verbose mode also enables the refcount loop in test_main.
    test_main(True)