Initialize the file descriptor in the files_struct before trying to close it. Otherwi...
[Samba/gebeck_regimport.git] / lib / testtools / testtools / matchers / _dict.py
blobff05199e6c14b4a33422f43ed9c7a852988af991
1 # Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
# Public API of this module.  Every non-underscore matcher defined below is
# listed so that ``from ... import *`` and documentation tools see them all.
__all__ = [
    'ContainedByDict',
    'ContainsDict',
    'KeysEqual',
    'MatchesAllDict',
    'MatchesDict',
    ]
7 from ..helpers import (
8 dict_subtract,
9 filter_values,
10 map_values,
12 from ._higherorder import (
13 AnnotatedMismatch,
14 PrefixedMismatch,
15 MismatchesAll,
17 from ._impl import Matcher, Mismatch
def LabelledMismatches(mismatches, details=None):
    """A collection of mismatches, each labelled.

    :param mismatches: A dict mapping labels to mismatch objects.  The
        children are produced in sorted label order, each wrapped in a
        ``PrefixedMismatch`` carrying its label.
    :param details: Accepted so this callable has the same shape as
        ``DictMismatches`` (both are used as ``result_mismatch`` factories
        by ``_dict_to_mismatch``); it is currently ignored.
    :return: A ``MismatchesAll`` over the labelled children, built with
        ``wrap=False``.
    """
    # Materialise the children as a list rather than a generator: the
    # returned MismatchesAll keeps hold of this sequence, and a generator
    # could only be traversed once (a second describe() would then see no
    # children at all).
    labelled = [
        PrefixedMismatch(label, mismatch)
        for (label, mismatch) in sorted(mismatches.items())]
    return MismatchesAll(labelled, wrap=False)
class MatchesAllDict(Matcher):
    """Match only when every matcher in a dict of matchers matches.

    Much like ``MatchesAll``, except that the matchers are supplied as a
    dict and any mismatch is labelled with the dict key it came from.
    """

    def __init__(self, matchers):
        super(MatchesAllDict, self).__init__()
        self.matchers = matchers

    def __str__(self):
        rendered = _format_matcher_dict(self.matchers)
        return 'MatchesAllDict(%s)' % (rendered,)

    def match(self, observed):
        # Run every matcher against the same observed value and keep each
        # result (a mismatch or None) under its label; _dict_to_mismatch
        # then discards the falsy results and labels the rest.
        results = dict(
            (label, matcher.match(observed))
            for label, matcher in self.matchers.items())
        return _dict_to_mismatch(results, result_mismatch=LabelledMismatches)
class DictMismatches(Mismatch):
    """A mismatch that aggregates child mismatches keyed by name."""

    def __init__(self, mismatches, details=None):
        super(DictMismatches, self).__init__(None, details=details)
        self.mismatches = mismatches

    def describe(self):
        # Render as a brace-delimited block with one ``key: description,``
        # line per child, in sorted key order.
        entries = [
            ' %r: %s,' % (key, child.describe())
            for key, child in sorted(self.mismatches.items())]
        return '\n'.join(['{'] + entries + ['}'])
def _dict_to_mismatch(data, to_mismatch=None,
                      result_mismatch=DictMismatches):
    """Collapse a dict of per-key results into one mismatch, or None.

    :param data: A dict of results keyed by label.
    :param to_mismatch: Optional callable applied to every value first
        (through ``map_values``) to turn it into a mismatch.
    :param result_mismatch: Factory called with the surviving entries to
        build the combined mismatch.
    :return: ``result_mismatch(surviving)`` when any entry survives the
        ``filter_values(bool, ...)`` pass (presumably: the truthy ones),
        otherwise None.
    """
    converted = map_values(to_mismatch, data) if to_mismatch else data
    failures = filter_values(bool, converted)
    if not failures:
        return None
    return result_mismatch(failures)
class _MatchCommonKeys(Matcher):
    """Match on keys in a dictionary.

    Given a dictionary where the values are matchers, this will look for
    common keys in the matched dictionary and match if and only if all common
    keys match the given matchers.  Keys present in only one of the two
    dictionaries are not examined.

    Thus::

        >>> from testtools.matchers._basic import Equals
        >>> structure = {'a': Equals('x'), 'b': Equals('y')}
        >>> print(_MatchCommonKeys(structure).match({'a': 'x', 'c': 'z'}))
        None
    """

    def __init__(self, dict_of_matchers):
        super(_MatchCommonKeys, self).__init__()
        self._matchers = dict_of_matchers

    def _compare_dicts(self, expected, observed):
        """Match the values under the keys both dicts share.

        :param expected: A dict mapping keys to matchers.
        :param observed: The dict being matched.
        :return: A dict mapping each shared key whose matcher failed to the
            (truthy) result that matcher returned; empty if none failed.
        """
        common_keys = set(expected.keys()) & set(observed.keys())
        mismatches = {}
        for key in common_keys:
            mismatch = expected[key].match(observed[key])
            if mismatch:
                mismatches[key] = mismatch
        return mismatches

    def match(self, observed):
        """Return a ``DictMismatches`` for failing shared keys, else None."""
        mismatches = self._compare_dicts(self._matchers, observed)
        if mismatches:
            return DictMismatches(mismatches)
        return None
class _SubDictOf(Matcher):
    """Match when the observed dict has no keys beyond those of a given dict.
    """

    def __init__(self, super_dict, format_value=repr):
        super(_SubDictOf, self).__init__()
        self.super_dict = super_dict
        self.format_value = format_value

    def match(self, observed):
        # ``dict_subtract`` presumably yields the entries of ``observed``
        # whose keys are not in ``super_dict``; each such value is wrapped
        # in a plain Mismatch described with ``format_value``.
        extra = dict_subtract(observed, self.super_dict)
        render = self.format_value
        return _dict_to_mismatch(extra, lambda value: Mismatch(render(value)))
class _SuperDictOf(Matcher):
    """Match when every key of a given dict is present in the observed dict.
    """

    def __init__(self, sub_dict, format_value=repr):
        super(_SuperDictOf, self).__init__()
        self.sub_dict = sub_dict
        self.format_value = format_value

    def match(self, super_dict):
        # Swap roles: "observed is a super-dict of sub_dict" is the same
        # question as "sub_dict is a sub-dict of observed", so delegate to
        # _SubDictOf with the observed dict as the reference.
        inverse = _SubDictOf(super_dict, self.format_value)
        return inverse.match(self.sub_dict)
134 def _format_matcher_dict(matchers):
135 return '{%s}' % (
136 ', '.join(sorted('%r: %s' % (k, v) for k, v in matchers.items())))
class _CombinedMatcher(Matcher):
    """Many matchers labelled and combined into one uber-matcher.

    Subclasses set ``matcher_factories`` to a dict mapping labels to
    callables.  Each callable takes the single 'expected' value and returns
    a matcher.  An instance matches only if every matcher built from those
    factories matches.

    Not **entirely** dissimilar from ``MatchesAll``.
    """

    matcher_factories = {}

    def __init__(self, expected):
        super(_CombinedMatcher, self).__init__()
        self._expected = expected

    def format_expected(self, expected):
        return repr(expected)

    def __str__(self):
        name = self.__class__.__name__
        return '%s(%s)' % (name, self.format_expected(self._expected))

    def match(self, observed):
        # Build a fresh matcher per label from the class-level factories,
        # then let MatchesAllDict run them all and label any failures.
        matchers = {}
        for label, factory in self.matcher_factories.items():
            matchers[label] = factory(self._expected)
        return MatchesAllDict(matchers).match(observed)
class MatchesDict(_CombinedMatcher):
    """Match a dictionary exactly, by its keys.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  A dictionary matches only if it has exactly the
    same keys, and each of its values matches the corresponding matcher in
    the expected dict.
    """

    matcher_factories = {
        # Keys in the observed dict that the expected dict does not have.
        'Extra': _SubDictOf,
        # Keys in the expected dict that the observed dict does not have;
        # their values are rendered with ``str`` rather than ``repr``.
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        # Shared keys whose observed value fails the expected matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)
class ContainsDict(_CombinedMatcher):
    """Match a dictionary that contains a specified sub-dictionary.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  A matching dictionary must have **at least** these
    keys, and its values under them must match the corresponding matchers.
    Dictionaries with additional keys also match.

    In other words, any matching dictionary must contain the dictionary
    given to the constructor.

    This is not a strict sub-dictionary check: equal dictionaries match.
    """

    matcher_factories = {
        # Expected keys absent from the observed dict.
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        # Shared keys whose observed value fails the expected matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)
class ContainedByDict(_CombinedMatcher):
    """Match a dictionary for which this is a super-dictionary.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  A matching dictionary may have **only** these
    keys, and its values under them must match the corresponding matchers.
    Dictionaries with fewer keys can also match.

    In other words, any matching dictionary must be contained by the
    dictionary given to the constructor.

    This is not a strict super-dictionary check: equal dictionaries match.
    """

    matcher_factories = {
        # Observed keys that the expected dict does not have.
        'Extra': _SubDictOf,
        # Shared keys whose observed value fails the expected matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        return _format_matcher_dict(expected)
class KeysEqual(Matcher):
    """Checks whether a dict has particular keys."""

    def __init__(self, *expected):
        """Create a `KeysEqual` Matcher.

        :param expected: The keys the dict is expected to have.  If exactly
            one argument is given and it has a ``keys()`` method (e.g. a
            dict), the keys of that object are used.  Otherwise each
            positional argument is taken to be one expected key.
        """
        super(KeysEqual, self).__init__()
        # ``expected`` is always the varargs tuple, which has no ``keys()``
        # method, so the "pass a dict" form must inspect the sole argument
        # itself.  Calling ``expected.keys()`` would always fail and wrap
        # the dict as if it were a single key.
        if len(expected) == 1:
            try:
                expected = expected[0].keys()
            except AttributeError:
                pass
        self.expected = list(expected)

    def __str__(self):
        return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))

    def match(self, matchee):
        """Compare the sorted keys of ``matchee`` with the expected keys.

        :return: An ``AnnotatedMismatch`` labelled 'Keys not equal' when the
            key lists differ, otherwise None.
        """
        # Imported here rather than at module level; presumably to avoid a
        # circular import between matcher modules.
        from ._basic import _BinaryMismatch, Equals
        expected = sorted(self.expected)
        # A truthy result from Equals(...).match signals a mismatch.
        matched = Equals(expected).match(sorted(matchee.keys()))
        if matched:
            return AnnotatedMismatch(
                'Keys not equal',
                _BinaryMismatch(expected, 'does not match', matchee))
        return None