# Copyright (c) 2009-2012 testtools developers. See LICENSE for details.
from ..helpers import (
    dict_subtract,
    filter_values,
    map_values,
    )
from ._higherorder import (
    AnnotatedMismatch,
    MismatchesAll,
    PrefixedMismatch,
    )
from ._impl import Matcher, Mismatch
def LabelledMismatches(mismatches, details=None):
    """A collection of mismatches, each labelled.

    Used by ``MatchesAllDict`` as its ``result_mismatch`` factory: every
    child mismatch is prefixed with its label and the labelled children are
    combined into one ``MismatchesAll``.

    :param mismatches: A dict mapping labels to ``Mismatch`` objects.
    :param details: Accepted so the signature mirrors ``DictMismatches``;
        it is not used here.
    :return: A ``MismatchesAll`` over the labelled mismatches.
    """
    # Sort by label so the combined description is deterministic.
    return MismatchesAll(
        (PrefixedMismatch(k, v) for (k, v) in sorted(mismatches.items())),
        wrap=False)
class MatchesAllDict(Matcher):
    """Matches if all of the matchers it is created with match.

    A lot like ``MatchesAll``, but takes a dict of Matchers and labels any
    mismatches with the key of the dictionary.
    """

    def __init__(self, matchers):
        """Create a MatchesAllDict.

        :param matchers: A dict mapping labels to matchers.  Every matcher is
            applied to the same observed value.
        """
        super(MatchesAllDict, self).__init__()
        self.matchers = matchers

    def __str__(self):
        return 'MatchesAllDict(%s)' % (_format_matcher_dict(self.matchers),)

    def match(self, observed):
        """Match ``observed`` against every labelled matcher.

        :return: Whatever ``_dict_to_mismatch`` builds from the per-label
            results, using ``LabelledMismatches`` to combine the failures.
        """
        mismatches = {}
        for label, matcher in self.matchers.items():
            mismatches[label] = matcher.match(observed)
        # _dict_to_mismatch filters out the falsy (successful) results, so
        # only failing labels end up in the combined mismatch.
        return _dict_to_mismatch(
            mismatches, result_mismatch=LabelledMismatches)
class DictMismatches(Mismatch):
    """A mismatch with a dict of child mismatches."""

    def __init__(self, mismatches, details=None):
        """Create a DictMismatches.

        :param mismatches: A dict mapping keys to child ``Mismatch`` objects.
        :param details: Optional details, passed through to ``Mismatch``.
        """
        super(DictMismatches, self).__init__(None, details=details)
        self.mismatches = mismatches

    def describe(self):
        """Describe the children as a dict-like block sorted by key.

        The output is a ``{`` line, one ``  <key repr>: <description>,`` line
        per child mismatch, and a closing ``}`` line.
        """
        lines = ['{']
        lines.extend(
            ['  %r: %s,' % (key, mismatch.describe())
             for (key, mismatch) in sorted(self.mismatches.items())])
        lines.append('}')
        return '\n'.join(lines)
def _dict_to_mismatch(data, to_mismatch=None,
                      result_mismatch=DictMismatches):
    """Turn a dict of per-key match results into a single mismatch.

    :param data: A dict mapping keys to match results, or to raw values when
        ``to_mismatch`` is given.
    :param to_mismatch: Optional callable applied to every value of ``data``
        first, e.g. to wrap raw values in ``Mismatch`` objects.
    :param result_mismatch: Factory called with the dict of remaining
        (truthy) values to build the combined mismatch.
    :return: ``result_mismatch(mismatches)`` if any value is truthy after
        conversion, otherwise None (i.e. everything matched).
    """
    if to_mismatch is not None:
        data = map_values(to_mismatch, data)
    # Matchers report success with a falsy result (None), so keeping only
    # truthy values leaves just the real mismatches.
    mismatches = filter_values(bool, data)
    if mismatches:
        return result_mismatch(mismatches)
    return None
class _MatchCommonKeys(Matcher):
    """Match on keys in a dictionary.

    Given a dictionary where the values are matchers, this will look for
    common keys in the matched dictionary and match if and only if all common
    keys match the given matchers.

    Thus::

      >>> structure = {'a': Equals('x'), 'b': Equals('y')}
      >>> print(_MatchCommonKeys(structure).match({'a': 'x', 'c': 'z'}))
      None
    """

    def __init__(self, dict_of_matchers):
        """Create a _MatchCommonKeys.

        :param dict_of_matchers: A dict mapping keys to the matcher that the
            observed value at that key must satisfy.
        """
        super(_MatchCommonKeys, self).__init__()
        self._matchers = dict_of_matchers

    def _compare_dicts(self, expected, observed):
        """Match ``observed[k]`` against ``expected[k]`` for each shared key.

        Keys present in only one of the two dicts are ignored.

        :return: A dict mapping each failing shared key to its mismatch;
            empty when every shared key matched.
        """
        common_keys = set(expected.keys()) & set(observed.keys())
        mismatches = {}
        for key in common_keys:
            mismatch = expected[key].match(observed[key])
            # A falsy result (None) means this key matched.
            if mismatch:
                mismatches[key] = mismatch
        return mismatches

    def match(self, observed):
        """Return a ``DictMismatches`` for failing shared keys, else None."""
        mismatches = self._compare_dicts(self._matchers, observed)
        if mismatches:
            return DictMismatches(mismatches)
        return None
class _SubDictOf(Matcher):
    """Matches if the matched dict only has keys that are in given dict.

    Every entry of the observed dict whose key the reference ("super") dict
    lacks is reported, with its value rendered by ``format_value``.
    """

    def __init__(self, super_dict, format_value=repr):
        """Create a _SubDictOf.

        :param super_dict: The dict whose keys bound the allowed keys.
        :param format_value: Callable used to render each excess value in the
            mismatch description.  Defaults to ``repr``.
        """
        super(_SubDictOf, self).__init__()
        self.super_dict = super_dict
        self.format_value = format_value

    def match(self, observed):
        def wrap_excess(value):
            # Turn each unexpected value into a describable mismatch.
            return Mismatch(self.format_value(value))

        extra_entries = dict_subtract(observed, self.super_dict)
        return _dict_to_mismatch(extra_entries, wrap_excess)
class _SuperDictOf(Matcher):
    """Matches if all of the keys in the given dict are in the matched dict.

    Implemented as ``_SubDictOf`` with the roles swapped: any entry of the
    reference dict whose key is absent from the observed dict is reported.
    """

    def __init__(self, sub_dict, format_value=repr):
        """Create a _SuperDictOf.

        :param sub_dict: The dict whose keys must all appear in the observed
            dict.
        :param format_value: Callable used to render each missing value in
            the mismatch description.  Defaults to ``repr``.
        """
        super(_SuperDictOf, self).__init__()
        self.sub_dict = sub_dict
        self.format_value = format_value

    def match(self, super_dict):
        # "sub_dict is a sub-dict of the observed dict" is exactly the
        # _SubDictOf check with the observed dict as the reference.
        inverse_check = _SubDictOf(super_dict, format_value=self.format_value)
        return inverse_check.match(self.sub_dict)
134 def _format_matcher_dict(matchers
):
136 ', '.join(sorted('%r: %s' % (k
, v
) for k
, v
in matchers
.items())))
class _CombinedMatcher(Matcher):
    """Many matchers labelled and combined into one uber-matcher.

    Subclass this and then specify a dict of matcher factories that take a
    single 'expected' value and return a matcher.  The subclass will match
    only if all of the matchers made from factories match.

    Not **entirely** dissimilar from ``MatchesAll``.
    """

    # Maps a label (used to prefix any mismatch) to a callable that takes the
    # expected value and returns a matcher.  Subclasses override this.
    matcher_factories = {}

    def __init__(self, expected):
        super(_CombinedMatcher, self).__init__()
        self._expected = expected

    def format_expected(self, expected):
        """Return the text used for ``expected`` in ``__str__``."""
        return repr(expected)

    def __str__(self):
        return '%s(%s)' % (
            self.__class__.__name__, self.format_expected(self._expected))

    def match(self, observed):
        """Build one matcher per factory and require all of them to match."""
        matchers = {
            label: factory(self._expected)
            for label, factory in self.matcher_factories.items()}
        return MatchesAllDict(matchers).match(observed)
class MatchesDict(_CombinedMatcher):
    """Match a dictionary exactly, by its keys.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  Any dictionary that matches this must have exactly
    the same keys, and the values must match the corresponding matchers in the
    expected dict.
    """

    matcher_factories = {
        # Keys in the observed dict that the expected dict does not have.
        'Extra': _SubDictOf,
        # Keys in the expected dict that the observed dict lacks; the values
        # reported are matchers, so they are rendered with str.
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        # Shared keys whose observed value fails the expected matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        """Render the expected dict of matchers for ``__str__``."""
        return _format_matcher_dict(expected)
class ContainsDict(_CombinedMatcher):
    """Match a dictionary that contains a specified sub-dictionary.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  Any dictionary that matches this must have **at
    least** these keys, and the values must match the corresponding matchers
    in the expected dict.  Dictionaries that have more keys will also match.

    In other words, any matching dictionary must contain the dictionary given
    to the constructor.

    Does not check for strict sub-dictionary.  That is, equal dictionaries
    match.
    """

    matcher_factories = {
        # Keys in the expected dict that the observed dict lacks; the values
        # reported are matchers, so they are rendered with str.
        'Missing': lambda m: _SuperDictOf(m, format_value=str),
        # Shared keys whose observed value fails the expected matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        """Render the expected dict of matchers for ``__str__``."""
        return _format_matcher_dict(expected)
class ContainedByDict(_CombinedMatcher):
    """Match a dictionary for which this is a super-dictionary.

    Specify a dictionary mapping keys (often strings) to matchers.  This is
    the 'expected' dict.  Any dictionary that matches this must have **only**
    these keys, and the values must match the corresponding matchers in the
    expected dict.  Dictionaries that have fewer keys can also match.

    In other words, any matching dictionary must be contained by the
    dictionary given to the constructor.

    Does not check for strict super-dictionary.  That is, equal dictionaries
    match.
    """

    matcher_factories = {
        # Keys in the observed dict that the expected dict does not have.
        'Extra': _SubDictOf,
        # Shared keys whose observed value fails the expected matcher.
        'Differences': _MatchCommonKeys,
        }

    def format_expected(self, expected):
        """Render the expected dict of matchers for ``__str__``."""
        return _format_matcher_dict(expected)
class KeysEqual(Matcher):
    """Checks whether a dict has particular keys."""

    def __init__(self, *expected):
        """Create a `KeysEqual` Matcher.

        :param expected: The keys the dict is expected to have, passed as
            separate arguments (``KeysEqual('a', 'b')``).  Alternatively pass
            a single dict (``KeysEqual({'a': 1, 'b': 2})``) and its keys are
            used.  With no arguments, only an empty dict matches.
        """
        super(KeysEqual, self).__init__()
        try:
            # ``expected`` is always the *args tuple, so a dict argument is
            # expected[0]; calling .keys() on the tuple itself never works.
            self.expected = expected[0].keys()
        except (AttributeError, IndexError):
            # Keys given individually, or no keys at all.
            self.expected = list(expected)

    def __str__(self):
        return "KeysEqual(%s)" % ', '.join(map(repr, self.expected))

    def match(self, matchee):
        """Return None if ``matchee`` has exactly the expected keys.

        Otherwise return a mismatch annotated with 'Keys not equal'.
        """
        # Imported locally rather than at module level; presumably to avoid
        # an import cycle with ._basic -- confirm before hoisting.
        from ._basic import _BinaryMismatch, Equals
        expected = sorted(self.expected)
        matched = Equals(expected).match(sorted(matchee.keys()))
        if matched:
            return AnnotatedMismatch(
                'Keys not equal',
                _BinaryMismatch(expected, 'does not match', matchee))
        return None