Warn against replacing PyNumber_Add with PyNumber_InPlaceAdd in sum
[python.git] / Lib / collections.py
blobabf6f8927d9c4015fe04281243505f270a8f05f2
# Public names exported by this module; the ABC names from _abcoll are
# appended below.
__all__ = ['Counter', 'deque', 'defaultdict', 'namedtuple', 'OrderedDict']
# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py.
# They should however be considered an integral part of collections.py.
from _abcoll import *
import _abcoll
__all__ += _abcoll.__all__

# Helpers are bound to underscore-prefixed aliases to mark them as
# module-internal.
from _collections import deque, defaultdict
from operator import itemgetter as _itemgetter, eq as _eq
from keyword import iskeyword as _iskeyword
import sys as _sys
import heapq as _heapq
from weakref import proxy as _proxy
from itertools import repeat as _repeat, chain as _chain, starmap as _starmap, \
                      ifilter as _ifilter, imap as _imap, izip as _izip
17 ################################################################################
18 ### OrderedDict
19 ################################################################################
21 class _Link(object):
22 __slots__ = 'prev', 'next', 'key', '__weakref__'
class OrderedDict(dict, MutableMapping):
    'Dictionary that remembers insertion order'
    # An inherited dict maps keys to values.
    # The inherited dict provides __getitem__, __len__, __contains__, and get.
    # The remaining methods are order-aware.
    # Big-O running times for all methods are the same as for regular dictionaries.

    # The internal self.__map dictionary maps keys to links in a doubly linked list.
    # The circular doubly linked list starts and ends with a sentinel element.
    # The sentinel element never gets deleted (this simplifies the algorithm).
    # The prev/next links are weakref proxies (to prevent circular references).
    # Individual links are kept alive by the hard reference in self.__map.
    # Those hard references disappear when a key is deleted from an OrderedDict.

    # (id(od), thread ident) pairs whose repr is currently being built.
    # __repr__ consults it so a self-referential OrderedDict renders as '...'
    # instead of recursing until the interpreter's recursion limit is hit.
    __repr_running = {}

    def __init__(self, *args, **kwds):
        '''Initialize an ordered dictionary.  Signature is the same as for
        regular dictionaries, but keyword arguments are not recommended
        because their insertion order is arbitrary.

        '''
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        # Calling __init__ again on an existing instance keeps its current
        # links; only a fresh instance gets a new sentinel and map.
        try:
            self.__root
        except AttributeError:
            self.__root = root = _Link()    # sentinel node for the doubly linked list
            root.prev = root.next = root
            self.__map = {}
        self.update(*args, **kwds)

    def clear(self):
        'od.clear() -> None.  Remove all items from od.'
        root = self.__root
        root.prev = root.next = root
        self.__map.clear()
        dict.clear(self)

    def __setitem__(self, key, value):
        'od.__setitem__(i, y) <==> od[i]=y'
        # Setting a new item creates a new link which goes at the end of the linked
        # list, and the inherited dictionary is updated with the new key/value pair.
        if key not in self:
            self.__map[key] = link = _Link()
            root = self.__root
            last = root.prev
            link.prev, link.next, link.key = last, root, key
            last.next = root.prev = _proxy(link)
        dict.__setitem__(self, key, value)

    def __delitem__(self, key):
        'od.__delitem__(y) <==> del od[y]'
        # Deleting an existing item uses self.__map to find the link which is
        # then removed by updating the links in the predecessor and successor nodes.
        dict.__delitem__(self, key)
        link = self.__map.pop(key)
        link.prev.next = link.next
        link.next.prev = link.prev

    def __iter__(self):
        'od.__iter__() <==> iter(od)'
        # Traverse the linked list in order.
        root = self.__root
        curr = root.next
        while curr is not root:
            yield curr.key
            curr = curr.next

    def __reversed__(self):
        'od.__reversed__() <==> reversed(od)'
        # Traverse the linked list in reverse order.
        root = self.__root
        curr = root.prev
        while curr is not root:
            yield curr.key
            curr = curr.prev

    def __reduce__(self):
        'Return state information for pickling'
        items = [[k, self[k]] for k in self]
        # Temporarily detach the private link structures so vars(self) only
        # reports other instance attributes; the links are rebuilt from
        # `items` when the class is called on unpickling.
        tmp = self.__map, self.__root
        del self.__map, self.__root
        inst_dict = vars(self).copy()
        self.__map, self.__root = tmp
        if inst_dict:
            return (self.__class__, (items,), inst_dict)
        return self.__class__, (items,)

    setdefault = MutableMapping.setdefault
    update = MutableMapping.update
    pop = MutableMapping.pop
    keys = MutableMapping.keys
    values = MutableMapping.values
    items = MutableMapping.items
    iterkeys = MutableMapping.iterkeys
    itervalues = MutableMapping.itervalues
    iteritems = MutableMapping.iteritems
    __ne__ = MutableMapping.__ne__

    def popitem(self, last=True):
        '''od.popitem() -> (k, v), return and remove a (key, value) pair.
        Pairs are returned in LIFO order if last is true or FIFO order if false.

        '''
        if not self:
            raise KeyError('dictionary is empty')
        key = next(reversed(self) if last else iter(self))
        value = self.pop(key)
        return key, value

    def __repr__(self):
        'od.__repr__() <==> repr(od)'
        # Key the guard on the thread as well as the object so that concurrent
        # reprs of the same OrderedDict from different threads don't collide.
        try:
            from thread import get_ident
        except ImportError:                 # interpreter built without threads
            from dummy_thread import get_ident
        call_key = id(self), get_ident()
        running = self.__repr_running
        if call_key in running:
            return '...'
        running[call_key] = 1
        try:
            if not self:
                return '%s()' % (self.__class__.__name__,)
            return '%s(%r)' % (self.__class__.__name__, self.items())
        finally:
            del running[call_key]

    def copy(self):
        'od.copy() -> a shallow copy of od'
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
        and values equal to v (which defaults to None).

        '''
        d = cls()
        for key in iterable:
            d[key] = value
        return d

    def __eq__(self, other):
        '''od.__eq__(y) <==> od==y.  Comparison to another OD is order-sensitive
        while comparison to a regular mapping is order-insensitive.

        '''
        if isinstance(other, OrderedDict):
            return len(self)==len(other) and \
                   all(_imap(_eq, self.iteritems(), other.iteritems()))
        return dict.__eq__(self, other)
166 ################################################################################
167 ### namedtuple
168 ################################################################################
170 def namedtuple(typename, field_names, verbose=False, rename=False):
171 """Returns a new subclass of tuple with named fields.
173 >>> Point = namedtuple('Point', 'x y')
174 >>> Point.__doc__ # docstring for the new class
175 'Point(x, y)'
176 >>> p = Point(11, y=22) # instantiate with positional args or keywords
177 >>> p[0] + p[1] # indexable like a plain tuple
179 >>> x, y = p # unpack like a regular tuple
180 >>> x, y
181 (11, 22)
182 >>> p.x + p.y # fields also accessable by name
184 >>> d = p._asdict() # convert to a dictionary
185 >>> d['x']
187 >>> Point(**d) # convert from a dictionary
188 Point(x=11, y=22)
189 >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
190 Point(x=100, y=22)
194 # Parse and validate the field names. Validation serves two purposes,
195 # generating informative error messages and preventing template injection attacks.
196 if isinstance(field_names, basestring):
197 field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas
198 field_names = tuple(map(str, field_names))
199 if rename:
200 names = list(field_names)
201 seen = set()
202 for i, name in enumerate(names):
203 if (not all(c.isalnum() or c=='_' for c in name) or _iskeyword(name)
204 or not name or name[0].isdigit() or name.startswith('_')
205 or name in seen):
206 names[i] = '_%d' % i
207 seen.add(name)
208 field_names = tuple(names)
209 for name in (typename,) + field_names:
210 if not all(c.isalnum() or c=='_' for c in name):
211 raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
212 if _iskeyword(name):
213 raise ValueError('Type names and field names cannot be a keyword: %r' % name)
214 if name[0].isdigit():
215 raise ValueError('Type names and field names cannot start with a number: %r' % name)
216 seen_names = set()
217 for name in field_names:
218 if name.startswith('_') and not rename:
219 raise ValueError('Field names cannot start with an underscore: %r' % name)
220 if name in seen_names:
221 raise ValueError('Encountered duplicate field name: %r' % name)
222 seen_names.add(name)
224 # Create and fill-in the class template
225 numfields = len(field_names)
226 argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes
227 reprtxt = ', '.join('%s=%%r' % name for name in field_names)
228 template = '''class %(typename)s(tuple):
229 '%(typename)s(%(argtxt)s)' \n
230 __slots__ = () \n
231 _fields = %(field_names)r \n
232 def __new__(_cls, %(argtxt)s):
233 return _tuple.__new__(_cls, (%(argtxt)s)) \n
234 @classmethod
235 def _make(cls, iterable, new=tuple.__new__, len=len):
236 'Make a new %(typename)s object from a sequence or iterable'
237 result = new(cls, iterable)
238 if len(result) != %(numfields)d:
239 raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
240 return result \n
241 def __repr__(self):
242 return '%(typename)s(%(reprtxt)s)' %% self \n
243 def _asdict(self):
244 'Return a new OrderedDict which maps field names to their values'
245 return OrderedDict(zip(self._fields, self)) \n
246 def _replace(_self, **kwds):
247 'Return a new %(typename)s object replacing specified fields with new values'
248 result = _self._make(map(kwds.pop, %(field_names)r, _self))
249 if kwds:
250 raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
251 return result \n
252 def __getnewargs__(self):
253 return tuple(self) \n\n''' % locals()
254 for i, name in enumerate(field_names):
255 template += ' %s = _property(_itemgetter(%d))\n' % (name, i)
256 if verbose:
257 print template
259 # Execute the template string in a temporary namespace and
260 # support tracing utilities by setting a value for frame.f_globals['__name__']
261 namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
262 OrderedDict=OrderedDict, _property=property, _tuple=tuple)
263 try:
264 exec template in namespace
265 except SyntaxError, e:
266 raise SyntaxError(e.message + ':\n' + template)
267 result = namespace[typename]
269 # For pickling to work, the __module__ variable needs to be set to the frame
270 # where the named tuple is created. Bypass this step in enviroments where
271 # sys._getframe is not defined (Jython for example) or sys._getframe is not
272 # defined for arguments greater than 0 (IronPython).
273 try:
274 result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
275 except (AttributeError, ValueError):
276 pass
278 return result
281 ########################################################################
282 ### Counter
283 ########################################################################
285 class Counter(dict):
286 '''Dict subclass for counting hashable items. Sometimes called a bag
287 or multiset. Elements are stored as dictionary keys and their counts
288 are stored as dictionary values.
290 >>> c = Counter('abracadabra') # count elements from a string
292 >>> c.most_common(3) # three most common elements
293 [('a', 5), ('r', 2), ('b', 2)]
294 >>> sorted(c) # list all unique elements
295 ['a', 'b', 'c', 'd', 'r']
296 >>> ''.join(sorted(c.elements())) # list elements with repetitions
297 'aaaaabbcdrr'
298 >>> sum(c.values()) # total of all counts
301 >>> c['a'] # count of letter 'a'
303 >>> for elem in 'shazam': # update counts from an iterable
304 ... c[elem] += 1 # by adding 1 to each element's count
305 >>> c['a'] # now there are seven 'a'
307 >>> del c['r'] # remove all 'r'
308 >>> c['r'] # now there are zero 'r'
311 >>> d = Counter('simsalabim') # make another counter
312 >>> c.update(d) # add in the second counter
313 >>> c['a'] # now there are nine 'a'
316 >>> c.clear() # empty the counter
317 >>> c
318 Counter()
320 Note: If a count is set to zero or reduced to zero, it will remain
321 in the counter until the entry is deleted or the counter is cleared:
323 >>> c = Counter('aaabbc')
324 >>> c['b'] -= 2 # reduce the count of 'b' by two
325 >>> c.most_common() # 'b' is still in, but its count is zero
326 [('a', 3), ('c', 1), ('b', 0)]
329 # References:
330 # http://en.wikipedia.org/wiki/Multiset
331 # http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
332 # http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
333 # http://code.activestate.com/recipes/259174/
334 # Knuth, TAOCP Vol. II section 4.6.3
336 def __init__(self, iterable=None, **kwds):
337 '''Create a new, empty Counter object. And if given, count elements
338 from an input iterable. Or, initialize the count from another mapping
339 of elements to their counts.
341 >>> c = Counter() # a new, empty counter
342 >>> c = Counter('gallahad') # a new counter from an iterable
343 >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
344 >>> c = Counter(a=4, b=2) # a new counter from keyword args
347 self.update(iterable, **kwds)
349 def __missing__(self, key):
350 'The count of elements not in the Counter is zero.'
351 # Needed so that self[missing_item] does not raise KeyError
352 return 0
354 def most_common(self, n=None):
355 '''List the n most common elements and their counts from the most
356 common to the least. If n is None, then list all element counts.
358 >>> Counter('abracadabra').most_common(3)
359 [('a', 5), ('r', 2), ('b', 2)]
362 # Emulate Bag.sortedByCount from Smalltalk
363 if n is None:
364 return sorted(self.iteritems(), key=_itemgetter(1), reverse=True)
365 return _heapq.nlargest(n, self.iteritems(), key=_itemgetter(1))
367 def elements(self):
368 '''Iterator over elements repeating each as many times as its count.
370 >>> c = Counter('ABCABC')
371 >>> sorted(c.elements())
372 ['A', 'A', 'B', 'B', 'C', 'C']
374 # Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1
375 >>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
376 >>> product = 1
377 >>> for factor in prime_factors.elements(): # loop over factors
378 ... product *= factor # and multiply them
379 >>> product
380 1836
382 Note, if an element's count has been set to zero or is a negative
383 number, elements() will ignore it.
386 # Emulate Bag.do from Smalltalk and Multiset.begin from C++.
387 return _chain.from_iterable(_starmap(_repeat, self.iteritems()))
389 # Override dict methods where necessary
391 @classmethod
392 def fromkeys(cls, iterable, v=None):
393 # There is no equivalent method for counters because setting v=1
394 # means that no element can have a count greater than one.
395 raise NotImplementedError(
396 'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')
398 def update(self, iterable=None, **kwds):
399 '''Like dict.update() but add counts instead of replacing them.
401 Source can be an iterable, a dictionary, or another Counter instance.
403 >>> c = Counter('which')
404 >>> c.update('witch') # add elements from another iterable
405 >>> d = Counter('watch')
406 >>> c.update(d) # add elements from another counter
407 >>> c['h'] # four 'h' in which, witch, and watch
411 # The regular dict.update() operation makes no sense here because the
412 # replace behavior results in the some of original untouched counts
413 # being mixed-in with all of the other counts for a mismash that
414 # doesn't have a straight-forward interpretation in most counting
415 # contexts. Instead, we implement straight-addition. Both the inputs
416 # and outputs are allowed to contain zero and negative counts.
418 if iterable is not None:
419 if isinstance(iterable, Mapping):
420 if self:
421 self_get = self.get
422 for elem, count in iterable.iteritems():
423 self[elem] = self_get(elem, 0) + count
424 else:
425 dict.update(self, iterable) # fast path when counter is empty
426 else:
427 self_get = self.get
428 for elem in iterable:
429 self[elem] = self_get(elem, 0) + 1
430 if kwds:
431 self.update(kwds)
433 def copy(self):
434 'Like dict.copy() but returns a Counter instance instead of a dict.'
435 return Counter(self)
437 def __delitem__(self, elem):
438 'Like dict.__delitem__() but does not raise KeyError for missing values.'
439 if elem in self:
440 dict.__delitem__(self, elem)
442 def __repr__(self):
443 if not self:
444 return '%s()' % self.__class__.__name__
445 items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
446 return '%s({%s})' % (self.__class__.__name__, items)
448 # Multiset-style mathematical operations discussed in:
449 # Knuth TAOCP Volume II section 4.6.3 exercise 19
450 # and at http://en.wikipedia.org/wiki/Multiset
452 # Outputs guaranteed to only include positive counts.
454 # To strip negative and zero counts, add-in an empty counter:
455 # c += Counter()
457 def __add__(self, other):
458 '''Add counts from two counters.
460 >>> Counter('abbb') + Counter('bcc')
461 Counter({'b': 4, 'c': 2, 'a': 1})
464 if not isinstance(other, Counter):
465 return NotImplemented
466 result = Counter()
467 for elem in set(self) | set(other):
468 newcount = self[elem] + other[elem]
469 if newcount > 0:
470 result[elem] = newcount
471 return result
473 def __sub__(self, other):
474 ''' Subtract count, but keep only results with positive counts.
476 >>> Counter('abbbc') - Counter('bccd')
477 Counter({'b': 2, 'a': 1})
480 if not isinstance(other, Counter):
481 return NotImplemented
482 result = Counter()
483 for elem in set(self) | set(other):
484 newcount = self[elem] - other[elem]
485 if newcount > 0:
486 result[elem] = newcount
487 return result
489 def __or__(self, other):
490 '''Union is the maximum of value in either of the input counters.
492 >>> Counter('abbb') | Counter('bcc')
493 Counter({'b': 3, 'c': 2, 'a': 1})
496 if not isinstance(other, Counter):
497 return NotImplemented
498 result = Counter()
499 for elem in set(self) | set(other):
500 p, q = self[elem], other[elem]
501 newcount = q if p < q else p
502 if newcount > 0:
503 result[elem] = newcount
504 return result
506 def __and__(self, other):
507 ''' Intersection is the minimum of corresponding counts.
509 >>> Counter('abbb') & Counter('bcc')
510 Counter({'b': 1})
513 if not isinstance(other, Counter):
514 return NotImplemented
515 result = Counter()
516 if len(self) < len(other):
517 self, other = other, self
518 for elem in _ifilter(self.__contains__, other):
519 p, q = self[elem], other[elem]
520 newcount = p if p < q else q
521 if newcount > 0:
522 result[elem] = newcount
523 return result
526 if __name__ == '__main__':
527 # verify that instances can be pickled
528 from cPickle import loads, dumps
529 Point = namedtuple('Point', 'x, y', True)
530 p = Point(x=10, y=20)
531 assert p == loads(dumps(p))
533 # test and demonstrate ability to override methods
534 class Point(namedtuple('Point', 'x y')):
535 __slots__ = ()
536 @property
537 def hypot(self):
538 return (self.x ** 2 + self.y ** 2) ** 0.5
539 def __str__(self):
540 return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, self.hypot)
542 for p in Point(3, 4), Point(14, 5/7.):
543 print p
545 class Point(namedtuple('Point', 'x y')):
546 'Point class with optimized _make() and _replace() without error-checking'
547 __slots__ = ()
548 _make = classmethod(tuple.__new__)
549 def _replace(self, _map=map, **kwds):
550 return self._make(_map(kwds.get, ('x', 'y'), self))
552 print Point(11, 22)._replace(x=100)
554 Point3D = namedtuple('Point3D', Point._fields + ('z',))
555 print Point3D.__doc__
557 import doctest
558 TestResults = namedtuple('TestResults', 'failed attempted')
559 print TestResults(*doctest.testmod())