App Engine Python SDK version 1.7.4 (2)
[gae.git] / python / lib / django_1_4 / django / test / testcases.py
blob1f451877ac858f493111625dcc13362a21a5dfe1
1 from __future__ import with_statement
3 import difflib
4 import os
5 import re
6 import sys
7 from copy import copy
8 from functools import wraps
9 from urlparse import urlsplit, urlunsplit
10 from xml.dom.minidom import parseString, Node
11 import select
12 import socket
13 import threading
14 import errno
16 from django.conf import settings
17 from django.contrib.staticfiles.handlers import StaticFilesHandler
18 from django.core import mail
19 from django.core.exceptions import ValidationError, ImproperlyConfigured
20 from django.core.handlers.wsgi import WSGIHandler
21 from django.core.management import call_command
22 from django.core.signals import request_started
23 from django.core.servers.basehttp import (WSGIRequestHandler, WSGIServer,
24 WSGIServerException)
25 from django.core.urlresolvers import clear_url_caches
26 from django.core.validators import EMPTY_VALUES
27 from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS,
28 reset_queries)
29 from django.forms.fields import CharField
30 from django.http import QueryDict
31 from django.test import _doctest as doctest
32 from django.test.client import Client
33 from django.test.html import HTMLParseError, parse_html
34 from django.test.signals import template_rendered
35 from django.test.utils import (get_warnings_state, restore_warnings_state,
36 override_settings)
37 from django.test.utils import ContextList
38 from django.utils import simplejson, unittest as ut2
39 from django.utils.encoding import smart_str, force_unicode
40 from django.utils.unittest.util import safe_repr
41 from django.views.static import serve
# Names exported by a star-import of this module.
__all__ = (
    'DocTestRunner',
    'OutputChecker',
    'TestCase',
    'TransactionTestCase',
    'SimpleTestCase',
    'skipIfDBFeature',
    'skipUnlessDBFeature',
)
# Patterns used to normalise doctest output before a numeric comparison.
# Compiled once at import time instead of on every call.
_LONG_INT_RE = re.compile(r'(?<![\w])(\d+)L(?![\w])')
_DECIMAL_REPR_RE = re.compile(
    r"Decimal\('(-?\d+(\.\d*)?([eE][-+]?\d+)?)'\)")


def normalize_long_ints(s):
    """
    Drop the ``L`` suffix from long-integer reprs so ``22L`` reads as ``22``.

    Only standalone numbers are rewritten: digits or an ``L`` that touch
    another word character (``x22L``, ``22Lx``) are left untouched.
    """
    return _LONG_INT_RE.sub(r'\1', s)


def normalize_decimals(s):
    """
    Rewrite ``Decimal('1.5')`` as ``Decimal("1.5")``.

    Decimal.__repr__ switched from double to single quotes in Python 2.6;
    normalising both sides to double quotes makes outputs comparable across
    versions. Signed values (``-1.5``) and exponents (``1E+2``) are covered
    as well as plain digits.
    """
    return _DECIMAL_REPR_RE.sub(lambda m: 'Decimal("%s")' % m.group(1), s)
def to_list(value):
    """
    Return ``value`` as a list.

    An existing list is returned unchanged (not copied), ``None`` becomes an
    empty list, and any other object is wrapped in a one-element list.
    """
    if isinstance(value, list):
        return value
    return [] if value is None else [value]
# References to the genuine transaction functions, captured at import time,
# so restore_transaction_methods() can reinstall them after
# disable_transaction_methods() has swapped them for no-ops.
real_commit = transaction.commit
real_rollback = transaction.rollback
real_managed = transaction.managed
real_enter_transaction_management = transaction.enter_transaction_management
real_leave_transaction_management = transaction.leave_transaction_management
def nop(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None
def disable_transaction_methods():
    """
    Replace the transaction-control functions on ``transaction`` with ``nop``.

    Used by TestCase so that code under test cannot commit, roll back, or
    change the management state of the transaction wrapping each test.
    """
    for attr_name in ('commit',
                      'rollback',
                      'enter_transaction_management',
                      'leave_transaction_management',
                      'managed'):
        setattr(transaction, attr_name, nop)
def restore_transaction_methods():
    """
    Undo disable_transaction_methods() by putting back the original
    functions held in the module-level ``real_*`` globals.
    """
    originals = (
        ('commit', real_commit),
        ('rollback', real_rollback),
        ('enter_transaction_management', real_enter_transaction_management),
        ('leave_transaction_management', real_leave_transaction_management),
        ('managed', real_managed),
    )
    for attr_name, func in originals:
        setattr(transaction, attr_name, func)
def assert_and_parse_html(self, html, user_msg, msg):
    """
    Parse ``html`` with parse_html() and return the resulting DOM.

    ``self`` is the running test case. If parsing raises HTMLParseError, the
    test fails with ``msg`` followed by the parser's message, combined with
    the caller-supplied ``user_msg`` through ``self._formatMessage``.
    """
    try:
        return parse_html(html)
    except HTMLParseError:
        # sys.exc_info() keeps this valid on every supported Python version.
        parse_error = sys.exc_info()[1]
        detailed = u'%s\n%s' % (msg, parse_error.msg)
        self.fail(self._formatMessage(user_msg, detailed))
class OutputChecker(doctest.OutputChecker):
    """
    Doctest output checker that accepts ``got`` when it matches ``want``
    exactly, after numeric normalisation, as equivalent XML, or as
    equivalent JSON.
    """

    # Runs of two or more spaces/tabs/newlines collapse to one space when
    # comparing XML text content.
    _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')

    def check_output(self, want, got, optionflags):
        """
        The entry method for doctest output checking. Defers to a sequence of
        child checkers and returns True as soon as one of them accepts.
        """
        checks = (self.check_output_default,
                  self.check_output_numeric,
                  self.check_output_xml,
                  self.check_output_json)
        for check in checks:
            if check(want, got, optionflags):
                return True
        return False

    def check_output_default(self, want, got, optionflags):
        """
        The default comparator provided by doctest - not perfect, but good for
        most purposes.
        """
        return doctest.OutputChecker.check_output(self, want, got, optionflags)

    def check_output_numeric(self, want, got, optionflags):
        """Doctest does an exact string comparison of output, which means that
        some numerically equivalent values aren't equal. This check normalizes
         * long integers (22L) so that they equal normal integers. (22)
         * Decimals so that they are comparable, regardless of the change
           made to __repr__ in Python 2.6.
        """
        return doctest.OutputChecker.check_output(self,
            normalize_decimals(normalize_long_ints(want)),
            normalize_decimals(normalize_long_ints(got)),
            optionflags)

    def check_output_xml(self, want, got, optionsflags):
        """Tries to do an 'xml-comparison' of want and got. Plain string
        comparison doesn't always work because, for example, attribute
        ordering should not be important.

        Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
        """
        norm_re = self._norm_whitespace_re

        def norm_whitespace(v):
            return norm_re.sub(' ', v)

        def child_text(element):
            return ''.join([c.data for c in element.childNodes
                            if c.nodeType == Node.TEXT_NODE])

        def children(element):
            return [c for c in element.childNodes
                    if c.nodeType == Node.ELEMENT_NODE]

        def norm_child_text(element):
            return norm_whitespace(child_text(element))

        def attrs_dict(element):
            return dict(element.attributes.items())

        def check_element(want_element, got_element):
            if want_element.tagName != got_element.tagName:
                return False
            if norm_child_text(want_element) != norm_child_text(got_element):
                return False
            if attrs_dict(want_element) != attrs_dict(got_element):
                return False
            want_children = children(want_element)
            got_children = children(got_element)
            if len(want_children) != len(got_children):
                return False
            for want, got in zip(want_children, got_children):
                if not check_element(want, got):
                    return False
            return True

        want, got = self._strip_quotes(want, got)
        want = want.replace('\\n', '\n')
        got = got.replace('\\n', '\n')

        # If the string is not a complete xml document, we may need to add a
        # root element. This allows us to compare fragments, like "<foo/><bar/>"
        if not want.startswith('<?xml'):
            wrapper = '<root>%s</root>'
            want = wrapper % want
            got = wrapper % got

        # Parse the want and got strings, and compare the parsings.
        # documentElement (not firstChild) is used so that a leading comment
        # or doctype in a full document does not get compared as an element.
        try:
            want_root = parseString(want).documentElement
            got_root = parseString(got).documentElement
        except Exception:
            return False
        return check_element(want_root, got_root)

    def check_output_json(self, want, got, optionsflags):
        """
        Tries to compare want and got as if they were JSON-encoded data.
        Returns False if either side is not valid JSON.
        """
        want, got = self._strip_quotes(want, got)
        try:
            want_json = simplejson.loads(want)
            got_json = simplejson.loads(got)
        except Exception:
            return False
        return want_json == got_json

    def _strip_quotes(self, want, got):
        """
        Strip the quotes of doctest output values when both ``want`` and
        ``got`` are quoted the same way; returns the (want, got) pair.

        >>> o = OutputChecker()
        >>> o._strip_quotes("'foo'", "'bar'")
        ('foo', 'bar')
        >>> o._strip_quotes('u"foo"', "u'bar'")
        ('foo', 'bar')
        >>> o._strip_quotes("'foo'", 'bar')
        ("'foo'", 'bar')
        """
        def is_quoted_string(s):
            s = s.strip()
            return (len(s) >= 2
                    and s[0] == s[-1]
                    and s[0] in ('"', "'"))

        def is_quoted_unicode(s):
            s = s.strip()
            return (len(s) >= 3
                    and s[0] == 'u'
                    and s[1] == s[-1]
                    and s[1] in ('"', "'"))

        if is_quoted_string(want) and is_quoted_string(got):
            want = want.strip()[1:-1]
            got = got.strip()[1:-1]
        elif is_quoted_unicode(want) and is_quoted_unicode(got):
            want = want.strip()[2:-1]
            got = got.strip()[2:-1]
        return want, got
class DocTestRunner(doctest.DocTestRunner):
    """
    doctest runner that always uses ELLIPSIS matching and rolls back every
    database connection after an example raises an unexpected exception.
    """

    def __init__(self, *args, **kwargs):
        # Explicit base call rather than super(): the base class is not
        # assumed to be new-style.
        doctest.DocTestRunner.__init__(self, *args, **kwargs)
        # Overrides whatever optionflags the caller passed in.
        self.optionflags = doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        base_report = doctest.DocTestRunner.report_unexpected_exception
        base_report(self, out, test, example, exc_info)
        # Rollback, in case of database errors. Otherwise they'd have
        # side effects on other tests.
        for alias in connections:
            transaction.rollback_unless_managed(using=alias)
class _AssertNumQueriesContext(object):
    """
    Context manager used by assertNumQueries(): compares the growth of
    ``connection.queries`` across the ``with`` block against ``num``.
    """

    def __init__(self, test_case, num, connection):
        self.test_case = test_case
        self.num = num
        self.connection = connection

    def __enter__(self):
        conn = self.connection
        self.old_debug_cursor = conn.use_debug_cursor
        # NOTE(review): presumably this makes the connection record queries
        # regardless of DEBUG -- confirm against the backend.
        conn.use_debug_cursor = True
        self.starting_queries = len(conn.queries)
        # Detach reset_queries (presumably it clears the query log) so a
        # request made inside the block cannot reset the count.
        request_started.disconnect(reset_queries)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.use_debug_cursor = self.old_debug_cursor
        request_started.connect(reset_queries)
        # Only check the count when the block completed without an error;
        # returning None lets any exception propagate.
        if exc_type is None:
            executed = len(self.connection.queries) - self.starting_queries
            self.test_case.assertEqual(
                executed, self.num,
                "%d queries executed, %d expected" % (executed, self.num))
class _AssertTemplateUsedContext(object):
    """
    Context manager used by assertTemplateUsed(): records every template
    rendered inside the ``with`` block and fails ``test_case`` on exit if
    ``template_name`` was not among them.
    """

    def __init__(self, test_case, template_name):
        self.test_case = test_case
        self.template_name = template_name
        self.rendered_templates = []
        self.rendered_template_names = []
        self.context = ContextList()

    def on_template_render(self, sender, signal, template, context, **kwargs):
        # Receiver connected to template_rendered in __enter__.
        self.rendered_templates.append(template)
        self.rendered_template_names.append(template.name)
        self.context.append(copy(context))

    def test(self):
        return self.template_name in self.rendered_template_names

    def message(self):
        return u'%s was not rendered.' % self.template_name

    def __enter__(self):
        template_rendered.connect(self.on_template_render)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        template_rendered.disconnect(self.on_template_render)
        # Let exceptions from the block propagate untouched; otherwise only
        # act when the check failed.
        if exc_type is not None or self.test():
            return

        failure = self.message()
        if self.rendered_templates:
            failure += u' Following templates were rendered: %s' % (
                ', '.join(self.rendered_template_names))
        else:
            failure += u' No template was rendered.'
        self.test_case.fail(failure)
class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext):
    """
    Inverse of _AssertTemplateUsedContext: the check fails when
    ``template_name`` *was* rendered inside the ``with`` block.
    """

    def message(self):
        return u'%s was rendered.' % self.template_name

    def test(self):
        return not (self.template_name in self.rendered_template_names)
class SimpleTestCase(ut2.TestCase):
    """
    Test case base class adding Django-specific helpers and assertions
    (settings overrides, form-field checks, HTML comparison). It performs no
    database setup of its own.
    """

    def save_warnings_state(self):
        """
        Saves the state of the warnings module
        """
        self._warnings_state = get_warnings_state()

    def restore_warnings_state(self):
        """
        Restores the state of the warnings module to the state
        saved by save_warnings_state()
        """
        # Calls the module-level restore_warnings_state() helper, not this
        # method (method names are not in the function's scope).
        restore_warnings_state(self._warnings_state)

    def settings(self, **kwargs):
        """
        A context manager that temporarily sets a setting and reverts
        back to the original value when exiting the context.
        """
        return override_settings(**kwargs)

    def assertRaisesMessage(self, expected_exception, expected_message,
                            callable_obj=None, *args, **kwargs):
        """
        Asserts that the message in a raised exception matches the passed
        value.

        Args:
            expected_exception: Exception class expected to be raised.
            expected_message: expected error message string value.
            callable_obj: Function to be called.
            args: Extra args.
            kwargs: Extra kwargs.
        """
        return self.assertRaisesRegexp(expected_exception,
                re.escape(expected_message), callable_obj, *args, **kwargs)

    def assertFieldOutput(self, fieldclass, valid, invalid, field_args=None,
                          field_kwargs=None, empty_value=u''):
        """
        Asserts that a form field behaves correctly with various inputs.

        Args:
            fieldclass: the class of the field to be tested.
            valid: a dictionary mapping valid inputs to their expected
                    cleaned values.
            invalid: a dictionary mapping invalid inputs to one or more
                    raised error messages.
            field_args: the args passed to instantiate the field
            field_kwargs: the kwargs passed to instantiate the field
            empty_value: the expected clean output for inputs in EMPTY_VALUES

        The caller's ``field_args`` / ``field_kwargs`` are never modified.
        """
        if field_args is None:
            field_args = []
        if field_kwargs is None:
            field_kwargs = {}
        required = fieldclass(*field_args, **field_kwargs)
        optional = fieldclass(*field_args,
                              **dict(field_kwargs, required=False))
        # test valid inputs
        for value, output in valid.items():
            self.assertEqual(required.clean(value), output)
            self.assertEqual(optional.clean(value), output)
        # test invalid inputs
        for value, errors in invalid.items():
            with self.assertRaises(ValidationError) as context_manager:
                required.clean(value)
            self.assertEqual(context_manager.exception.messages, errors)

            with self.assertRaises(ValidationError) as context_manager:
                optional.clean(value)
            self.assertEqual(context_manager.exception.messages, errors)
        # test required inputs
        error_required = [force_unicode(required.error_messages['required'])]
        for e in EMPTY_VALUES:
            with self.assertRaises(ValidationError) as context_manager:
                required.clean(e)
            self.assertEqual(context_manager.exception.messages,
                             error_required)
            self.assertEqual(optional.clean(e), empty_value)
        # test that max_length and min_length are always accepted
        if issubclass(fieldclass, CharField):
            # Build a fresh dict instead of update(): field_kwargs may be the
            # caller's own dictionary and must not be mutated.
            length_kwargs = dict(field_kwargs, min_length=2, max_length=20)
            self.assertTrue(isinstance(fieldclass(*field_args, **length_kwargs),
                                       fieldclass))

    def assertHTMLEqual(self, html1, html2, msg=None):
        """
        Asserts that two HTML snippets are semantically the same.
        Whitespace in most cases is ignored, and attribute ordering is not
        significant. The passed-in arguments must be valid HTML.
        """
        dom1 = assert_and_parse_html(self, html1, msg,
            u'First argument is not valid HTML:')
        dom2 = assert_and_parse_html(self, html2, msg,
            u'Second argument is not valid HTML:')

        if dom1 != dom2:
            standardMsg = '%s != %s' % (
                safe_repr(dom1, True), safe_repr(dom2, True))
            diff = ('\n' + '\n'.join(difflib.ndiff(
                           unicode(dom1).splitlines(),
                           unicode(dom2).splitlines())))
            standardMsg = self._truncateMessage(standardMsg, diff)
            self.fail(self._formatMessage(msg, standardMsg))

    def assertHTMLNotEqual(self, html1, html2, msg=None):
        """Asserts that two HTML snippets are not semantically equivalent."""
        dom1 = assert_and_parse_html(self, html1, msg,
            u'First argument is not valid HTML:')
        dom2 = assert_and_parse_html(self, html2, msg,
            u'Second argument is not valid HTML:')

        if dom1 == dom2:
            standardMsg = '%s == %s' % (
                safe_repr(dom1, True), safe_repr(dom2, True))
            self.fail(self._formatMessage(msg, standardMsg))
class TransactionTestCase(SimpleTestCase):
    """
    Test case that flushes the database(s) and loads fixtures before every
    test, swaps in a per-class ROOT_URLCONF when ``urls`` is set, and closes
    all connections afterwards. Also provides response assertions for use
    with the test client.
    """

    # The class we'll use for the test client self.client.
    # Can be overridden in derived classes.
    client_class = Client

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database.
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * If the Test Case class has a 'urls' member, replace the
              ROOT_URLCONF with it.
            * Clearing the mail test outbox.
        """
        self._fixture_setup()
        self._urlconf_setup()
        mail.outbox = []

    def _fixture_setup(self):
        # If the test case has a multi_db=True flag, flush all databases.
        # Otherwise, just flush default.
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]
        for db in databases:
            call_command('flush', verbosity=0, interactive=False, database=db)

            # Reload fixtures into every database that was just flushed (as
            # TestCase._fixture_setup does), not only the last one iterated.
            if hasattr(self, 'fixtures'):
                # We have to use this slightly awkward syntax due to the fact
                # that we're using *args and **kwargs together.
                call_command('loaddata', *self.fixtures,
                             **{'verbosity': 0, 'database': db})

    def _urlconf_setup(self):
        if hasattr(self, 'urls'):
            self._old_root_urlconf = settings.ROOT_URLCONF
            settings.ROOT_URLCONF = self.urls
            clear_url_caches()

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().
        """
        testMethod = getattr(self, self._testMethodName)
        skipped = (getattr(self.__class__, "__unittest_skip__", False) or
                   getattr(testMethod, "__unittest_skip__", False))

        if not skipped:
            self.client = self.client_class()
            try:
                self._pre_setup()
            except (KeyboardInterrupt, SystemExit):
                raise
            except Exception:
                result.addError(self, sys.exc_info())
                return
        super(TransactionTestCase, self).__call__(result)
        if not skipped:
            try:
                self._post_teardown()
            except (KeyboardInterrupt, SystemExit):
                raise
            except Exception:
                result.addError(self, sys.exc_info())
                return

    def _post_teardown(self):
        """ Performs any post-test things. This includes:

            * Putting back the original ROOT_URLCONF if it was changed.
            * Force closing the connection, so that the next test gets
              a clean cursor.
        """
        self._fixture_teardown()
        self._urlconf_teardown()
        # Some DB cursors include SQL statements as part of cursor
        # creation. If you have a test that does rollback, the effect
        # of these statements is lost, which can affect the operation
        # of tests (e.g., losing a timezone setting causing objects to
        # be created with the wrong time).
        # To make sure this doesn't happen, get a clean connection at the
        # start of every test.
        for conn in connections.all():
            conn.close()

    def _fixture_teardown(self):
        pass

    def _urlconf_teardown(self):
        if hasattr(self, '_old_root_urlconf'):
            settings.ROOT_URLCONF = self._old_root_urlconf
            clear_url_caches()

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, host=None, msg_prefix=''):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        if msg_prefix:
            msg_prefix += ": "

        if hasattr(response, 'redirect_chain'):
            # The request was a followed redirect
            self.assertTrue(len(response.redirect_chain) > 0,
                msg_prefix + "Response didn't redirect as expected: Response"
                " code was %d (expected %d)" %
                    (response.status_code, status_code))

            self.assertEqual(response.redirect_chain[0][1], status_code,
                msg_prefix + "Initial response didn't redirect as expected:"
                " Response code was %d (expected %d)" %
                    (response.redirect_chain[0][1], status_code))

            url, status_code = response.redirect_chain[-1]

            self.assertEqual(response.status_code, target_status_code,
                msg_prefix + "Response didn't redirect as expected: Final"
                " Response code was %d (expected %d)" %
                    (response.status_code, target_status_code))

        else:
            # Not a followed redirect
            self.assertEqual(response.status_code, status_code,
                msg_prefix + "Response didn't redirect as expected: Response"
                " code was %d (expected %d)" %
                    (response.status_code, status_code))

            url = response['Location']
            scheme, netloc, path, query, fragment = urlsplit(url)

            # Get the redirection page, using the same client that was used
            # to obtain the original response.
            redirect_response = response.client.get(path, QueryDict(query))
            self.assertEqual(redirect_response.status_code, target_status_code,
                msg_prefix + "Couldn't retrieve redirection page '%s':"
                " response code was %d (expected %d)" %
                    (path, redirect_response.status_code, target_status_code))

        # A relative expected_url is compared as an absolute URL on
        # ``host`` (default 'testserver').
        e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(
                                                              expected_url)
        if not (e_scheme or e_netloc):
            expected_url = urlunsplit(('http', host or 'testserver', e_path,
                e_query, e_fragment))

        self.assertEqual(url, expected_url,
            msg_prefix + "Response redirected to '%s', expected '%s'" %
                (url, expected_url))

    def assertContains(self, response, text, count=None, status_code=200,
                       msg_prefix='', html=False):
        """
        Asserts that a response indicates that some content was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """

        # If the response supports deferred rendering and hasn't been rendered
        # yet, then ensure that it does get rendered before proceeding further.
        if (hasattr(response, 'render') and callable(response.render)
            and not response.is_rendered):
            response.render()

        if msg_prefix:
            msg_prefix += ": "

        self.assertEqual(response.status_code, status_code,
            msg_prefix + "Couldn't retrieve content: Response code was %d"
            " (expected %d)" % (response.status_code, status_code))
        text = smart_str(text, response._charset)
        content = response.content
        if html:
            content = assert_and_parse_html(self, content, None,
                u"Response's content is not valid HTML:")
            text = assert_and_parse_html(self, text, None,
                u"Second argument is not valid HTML:")
        real_count = content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                msg_prefix + "Found %d instances of '%s' in response"
                " (expected %d)" % (real_count, text, count))
        else:
            self.assertTrue(real_count != 0,
                msg_prefix + "Couldn't find '%s' in response" % text)

    def assertNotContains(self, response, text, status_code=200,
                          msg_prefix='', html=False):
        """
        Asserts that a response indicates that some content was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` doesn't occur in the content of the response.
        """

        # If the response supports deferred rendering and hasn't been rendered
        # yet, then ensure that it does get rendered before proceeding further.
        if (hasattr(response, 'render') and callable(response.render)
            and not response.is_rendered):
            response.render()

        if msg_prefix:
            msg_prefix += ": "

        self.assertEqual(response.status_code, status_code,
            msg_prefix + "Couldn't retrieve content: Response code was %d"
            " (expected %d)" % (response.status_code, status_code))
        text = smart_str(text, response._charset)
        content = response.content
        if html:
            content = assert_and_parse_html(self, content, None,
                u'Response\'s content is not valid HTML:')
            text = assert_and_parse_html(self, text, None,
                u'Second argument is not valid HTML:')
        self.assertEqual(content.count(text), 0,
            msg_prefix + "Response should not contain '%s'" % text)

    def assertFormError(self, response, form, field, errors, msg_prefix=''):
        """
        Asserts that a form used to render the response has a specific field
        error. A false ``field`` checks the form's non-field errors instead.
        """
        if msg_prefix:
            msg_prefix += ": "

        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail(msg_prefix + "Response did not use any contexts to "
                      "render the response")

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.assertTrue(err in field_errors,
                            msg_prefix + "The field '%s' on form '%s' in"
                            " context %d does not contain the error '%s'"
                            " (actual errors: %s)" %
                                (field, form, i, err, repr(field_errors)))
                    elif field in context[form].fields:
                        self.fail(msg_prefix + "The field '%s' on form '%s'"
                                  " in context %d contains no errors" %
                                      (field, form, i))
                    else:
                        self.fail(msg_prefix + "The form '%s' in context %d"
                                  " does not contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    self.assertTrue(err in non_field_errors,
                        msg_prefix + "The form '%s' in context %d does not"
                        " contain the non-field error '%s'"
                        " (actual errors: %s)" %
                            (form, i, err, non_field_errors))
        if not found_form:
            self.fail(msg_prefix + "The form '%s' was not used to render the"
                      " response" % form)

    def assertTemplateUsed(self, response=None, template_name=None, msg_prefix=''):
        """
        Asserts that the template with the provided name was used in rendering
        the response. Also usable as context manager.
        """
        if response is None and template_name is None:
            raise TypeError(u'response and/or template_name argument must be provided')

        if msg_prefix:
            msg_prefix += ": "

        # Use assertTemplateUsed as context manager. A non-response first
        # argument is taken to be the template name.
        if not hasattr(response, 'templates') or (response is None and template_name):
            if response:
                template_name = response
                response = None
            context = _AssertTemplateUsedContext(self, template_name)
            return context

        template_names = [t.name for t in response.templates]
        if not template_names:
            self.fail(msg_prefix + "No templates used to render the response")
        self.assertTrue(template_name in template_names,
            msg_prefix + "Template '%s' was not a template used to render"
            " the response. Actual template(s) used: %s" %
                (template_name, u', '.join(template_names)))

    def assertTemplateNotUsed(self, response=None, template_name=None, msg_prefix=''):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response. Also usable as context manager.
        """
        if response is None and template_name is None:
            raise TypeError(u'response and/or template_name argument must be provided')

        if msg_prefix:
            msg_prefix += ": "

        # Use assertTemplateNotUsed as context manager. A non-response first
        # argument is taken to be the template name.
        if not hasattr(response, 'templates') or (response is None and template_name):
            if response:
                template_name = response
                response = None
            context = _AssertTemplateNotUsedContext(self, template_name)
            return context

        template_names = [t.name for t in response.templates]
        self.assertFalse(template_name in template_names,
            msg_prefix + "Template '%s' was used unexpectedly in rendering"
            " the response" % template_name)

    def assertQuerysetEqual(self, qs, values, transform=repr, ordered=True):
        """
        Asserts that ``transform`` applied to each item of ``qs`` gives
        ``values`` (compared as sets when ``ordered`` is False).
        """
        if not ordered:
            return self.assertEqual(set(map(transform, qs)), set(values))
        # list() so the comparison is against a list on every Python version.
        return self.assertEqual(list(map(transform, qs)), values)

    def assertNumQueries(self, num, func=None, *args, **kwargs):
        """
        Asserts that ``num`` queries are executed on the connection named by
        the ``using`` keyword (default DEFAULT_DB_ALIAS). Without ``func`` a
        context manager is returned; otherwise ``func(*args, **kwargs)`` is
        called inside it.
        """
        using = kwargs.pop("using", DEFAULT_DB_ALIAS)
        conn = connections[using]

        context = _AssertNumQueriesContext(self, num, conn)
        if func is None:
            return context

        with context:
            func(*args, **kwargs)
def connections_support_transactions():
    """
    Return True only if every configured connection reports
    ``features.supports_transactions``; True when there are no connections.
    """
    for conn in connections.all():
        if not conn.features.supports_transactions:
            return False
    return True
class TestCase(TransactionTestCase):
    """
    Does basically the same as TransactionTestCase, but surrounds every test
    with a transaction, monkey-patches the real transaction management routines
    to do nothing, and rolls back the test transaction at the end of the test.
    You have to use TransactionTestCase, if you need transaction management
    inside a test.
    """

    def _fixture_setup(self):
        # Without transaction support on every connection, fall back to the
        # flush-based behaviour of TransactionTestCase.
        if not connections_support_transactions():
            return super(TestCase, self)._fixture_setup()

        # multi_db=True means every database takes part; otherwise only the
        # default one does.
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]

        for db in databases:
            transaction.enter_transaction_management(using=db)
            transaction.managed(True, using=db)
        # From here on, test code cannot commit or roll back.
        disable_transaction_methods()

        from django.contrib.sites.models import Site
        Site.objects.clear_cache()

        if hasattr(self, 'fixtures'):
            for db in databases:
                call_command('loaddata', *self.fixtures,
                             **{
                                'verbosity': 0,
                                'commit': False,
                                'database': db
                             })

    def _fixture_teardown(self):
        if not connections_support_transactions():
            return super(TestCase, self)._fixture_teardown()

        # multi_db=True means every database takes part; otherwise only the
        # default one does.
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]

        # The real rollback/leave functions must be back before use.
        restore_transaction_methods()
        for db in databases:
            transaction.rollback(using=db)
            transaction.leave_transaction_management(using=db)
class _CheckCondition(object):
    """
    Descriptor holding one or more condition callables. Reading it as a
    class attribute returns True if any condition is currently true, so a
    class-level skip decision is made when the tests run, not at import time.
    """

    def __init__(self, *conditions):
        self.conditions = conditions

    def __get__(self, obj, objtype=None):
        return any(cond() for cond in self.conditions)


def _deferredSkip(condition, reason):
    """
    Build a decorator that skips a test when ``condition()`` is true.

    ``condition`` is evaluated lazily, when the test runs rather than when
    the decorator is applied.

    * Test methods/functions are wrapped so that calling them raises
      ``SkipTest(reason)`` if the condition holds.
    * unittest TestCase subclasses are returned unchanged, with
      ``__unittest_skip__`` set to a lazily evaluated descriptor (previously
      classes only received ``__unittest_skip_why__`` and were never skipped).
    """
    def decorator(test_func):
        if isinstance(test_func, type) and issubclass(test_func, ut2.TestCase):
            test_item = test_func
            existing = test_item.__dict__.get('__unittest_skip__')
            if isinstance(existing, _CheckCondition):
                # Stacked deferred skips: skip if any of the conditions hold.
                test_item.__unittest_skip__ = _CheckCondition(
                    *(existing.conditions + (condition,)))
            elif existing is not True:
                # Leave an unconditional skip (e.g. @skip) in place.
                test_item.__unittest_skip__ = _CheckCondition(condition)
        else:
            @wraps(test_func)
            def skip_wrapper(*args, **kwargs):
                if condition():
                    raise ut2.SkipTest(reason)
                return test_func(*args, **kwargs)
            test_item = skip_wrapper
        test_item.__unittest_skip_why__ = reason
        return test_item
    return decorator
def skipIfDBFeature(feature):
    """
    Skip a test if the default ``connection`` has the named feature flag
    set (checked lazily, when the test runs).
    """
    def has_feature():
        return getattr(connection.features, feature)
    return _deferredSkip(has_feature, "Database has feature %s" % feature)
def skipUnlessDBFeature(feature):
    """
    Skip a test unless the default ``connection`` has the named feature
    flag set (checked lazily, when the test runs).
    """
    def lacks_feature():
        return not getattr(connection.features, feature)
    return _deferredSkip(lacks_feature,
                         "Database doesn't support feature %s" % feature)
class QuietWSGIRequestHandler(WSGIRequestHandler):
    """
    A WSGIRequestHandler whose log_message() discards everything, so the
    requests served during a test run do not clutter the test output.
    """

    def log_message(*args):
        # Deliberately silent; *args also absorbs ``self``.
        return None
class _ImprovedEvent(threading._Event):
    """
    Does the same as `threading.Event` except it overrides the wait() method
    with some code borrowed from Python 2.7 to return the set state of the
    event (see: http://hg.python.org/cpython/rev/b5aa8aa78c0f/). This allows
    to know whether the wait() method exited normally or because of the
    timeout. This class can be removed when Django supports only Python >= 2.7.
    """

    def wait(self, timeout=None):
        # _Event__cond / _Event__flag are the name-mangled private attributes
        # of the base _Event class.
        cond = self._Event__cond
        cond.acquire()
        try:
            if not self._Event__flag:
                cond.wait(timeout)
            # True if the event is set, False if the wait timed out.
            return self._Event__flag
        finally:
            cond.release()
class StoppableWSGIServer(WSGIServer):
    """
    The code in this class is borrowed from the `SocketServer.BaseServer` class
    in Python 2.6. The important functionality here is that the server is non-
    blocking and that it can be shut down at any moment. This is made possible
    by the server regularly polling the socket and checking if it has been
    asked to stop.
    Note for the future: Once Django stops supporting Python 2.6, this class
    can be removed as `WSGIServer` will have this ability to shutdown on
    demand and will not require the use of the _ImprovedEvent class whose code
    is borrowed from Python 2.7.
    """

    def __init__(self, *args, **kwargs):
        super(StoppableWSGIServer, self).__init__(*args, **kwargs)
        self.__is_shut_down = _ImprovedEvent()
        self.__serving = False

    def serve_forever(self, poll_interval=0.5):
        """
        Handle one request at a time until shutdown.

        Polls for shutdown every poll_interval seconds.
        """
        self.__serving = True
        self.__is_shut_down.clear()
        try:
            while self.__serving:
                readable = select.select([self], [], [], poll_interval)[0]
                if readable:
                    self._handle_request_noblock()
        finally:
            # Always signal completion, even if select() or a request handler
            # raised, so that shutdown() is not left waiting for the timeout.
            self.__serving = False
            self.__is_shut_down.set()

    def shutdown(self):
        """
        Stops the serve_forever loop.

        Blocks until the loop has finished. This must be called while
        serve_forever() is running in another thread, or it will
        deadlock.

        Raises RuntimeError if the loop has not finished within 2 seconds.
        """
        self.__serving = False
        if not self.__is_shut_down.wait(2):
            raise RuntimeError(
                "Failed to shutdown the live test server in 2 seconds. The "
                "server might be stuck or generating a slow response.")

    def handle_request(self):
        """Handle one request, possibly blocking.
        """
        fd_sets = select.select([self], [], [], None)
        if not fd_sets[0]:
            return
        self._handle_request_noblock()

    def _handle_request_noblock(self):
        """
        Handle one request, without blocking.

        I assume that select.select has returned that the socket is
        readable before this function was called, so there should be
        no risk of blocking in get_request().
        """
        try:
            request, client_address = self.get_request()
        except socket.error:
            return
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except Exception:
                self.handle_error(request, client_address)
                self.close_request(request)
        else:
            # A rejected connection must still be closed, otherwise its
            # socket leaks.
            self.close_request(request)
class _MediaFilesHandler(StaticFilesHandler):
    """
    Private handler serving the files found under MEDIA_ROOT at MEDIA_URL.

    It only exists as a convenience for LiveServerThread and is not meant to
    be used anywhere else.
    """

    def get_base_dir(self):
        return settings.MEDIA_ROOT

    def get_base_url(self):
        return settings.MEDIA_URL

    def serve(self, request):
        # Drop the media URL's path prefix from the request path; what is left
        # is the file's location relative to the media root. (Assumes
        # `self.base_url` is the parsed base URL set up by StaticFilesHandler,
        # with the path component at index 2.)
        prefix_length = len(self.base_url[2])
        return serve(request, request.path[prefix_length:],
                     document_root=self.get_base_dir())
1004 class LiveServerThread(threading.Thread):
1006 Thread for running a live http server while the tests are running.
1009 def __init__(self, host, possible_ports, connections_override=None):
1010 self.host = host
1011 self.port = None
1012 self.possible_ports = possible_ports
1013 self.is_ready = threading.Event()
1014 self.error = None
1015 self.connections_override = connections_override
1016 super(LiveServerThread, self).__init__()
1018 def run(self):
1020 Sets up the live server and databases, and then loops over handling
1021 http requests.
1023 if self.connections_override:
1024 from django.db import connections
1025 # Override this thread's database connections with the ones
1026 # provided by the main thread.
1027 for alias, conn in self.connections_override.items():
1028 connections[alias] = conn
1029 try:
1030 # Create the handler for serving static and media files
1031 handler = StaticFilesHandler(_MediaFilesHandler(WSGIHandler()))
1033 # Go through the list of possible ports, hoping that we can find
1034 # one that is free to use for the WSGI server.
1035 for index, port in enumerate(self.possible_ports):
1036 try:
1037 self.httpd = StoppableWSGIServer(
1038 (self.host, port), QuietWSGIRequestHandler)
1039 except WSGIServerException, e:
1040 if sys.version_info < (2, 6):
1041 error_code = e.args[0].args[0]
1042 else:
1043 error_code = e.args[0].errno
1044 if (index + 1 < len(self.possible_ports) and
1045 error_code == errno.EADDRINUSE):
1046 # This port is already in use, so we go on and try with
1047 # the next one in the list.
1048 continue
1049 else:
1050 # Either none of the given ports are free or the error
1051 # is something else than "Address already in use". So
1052 # we let that error bubble up to the main thread.
1053 raise
1054 else:
1055 # A free port was found.
1056 self.port = port
1057 break
1059 self.httpd.set_app(handler)
1060 self.is_ready.set()
1061 self.httpd.serve_forever()
1062 except Exception, e:
1063 self.error = e
1064 self.is_ready.set()
1066 def join(self, timeout=None):
1067 if hasattr(self, 'httpd'):
1068 # Stop the WSGI server
1069 self.httpd.shutdown()
1070 self.httpd.server_close()
1071 super(LiveServerThread, self).join(timeout)
class LiveServerTestCase(TransactionTestCase):
    """
    Does basically the same as TransactionTestCase but also launches a live
    http server in a separate thread so that the tests may use another testing
    framework, such as Selenium for example, instead of the built-in dummy
    client.
    Note that it inherits from TransactionTestCase instead of TestCase because
    the threads do not share the same transactions (unless if using in-memory
    sqlite) and each thread needs to commit all their transactions so that the
    other thread can see the changes.
    """

    @property
    def live_server_url(self):
        return 'http://%s:%s' % (
            self.server_thread.host, self.server_thread.port)

    @staticmethod
    def _is_in_memory_sqlite(conn):
        """
        Returns True if `conn` uses an in-memory sqlite database. Only such
        connections are handed over to the server thread.
        """
        return (conn.settings_dict['ENGINE'] == 'django.db.backends.sqlite3'
                and conn.settings_dict['NAME'] == ':memory:')

    @staticmethod
    def _parse_live_server_address(specified_address):
        """
        Splits an address such as 'localhost:8000-8010,8080,9200-9300' into
        its host and the detailed list of all the ports it describes, in
        order.

        Raises ImproperlyConfigured if the address is malformed, if a range
        has more than two bounds, if a range's lower bound exceeds its upper
        bound, or if a port is above 65535.
        """
        try:
            host, port_ranges = specified_address.split(':')
            possible_ports = []
            for port_range in port_ranges.split(','):
                # A port range can be of either form: '8000' or '8000-8010'.
                extremes = [int(bound) for bound in port_range.split('-')]
                if len(extremes) == 1:
                    first = last = extremes[0]
                elif len(extremes) == 2:
                    first, last = extremes
                else:
                    raise ValueError(port_range)
                # An explicit check rather than an `assert`, which would be
                # stripped when running with -O. A reversed range would
                # otherwise silently contribute no port at all.
                if first > last or last > 65535:
                    raise ValueError(port_range)
                possible_ports.extend(range(first, last + 1))
        except ValueError:
            raise ImproperlyConfigured('Invalid address ("%s") for live '
                                       'server.' % specified_address)
        return host, possible_ports

    @classmethod
    def setUpClass(cls):
        # Validate the address first, so that an invalid value does not leave
        # any connection flagged as thread-shareable.
        specified_address = os.environ.get(
            'DJANGO_LIVE_TEST_SERVER_ADDRESS', 'localhost:8081')
        host, possible_ports = cls._parse_live_server_address(
            specified_address)

        connections_override = {}
        for conn in connections.all():
            # If using in-memory sqlite databases, pass the connections to
            # the server thread.
            if cls._is_in_memory_sqlite(conn):
                # Explicitly enable thread-shareability for this connection
                conn.allow_thread_sharing = True
                connections_override[conn.alias] = conn

        # Launch the live server's thread
        cls.server_thread = LiveServerThread(
            host, possible_ports, connections_override)
        cls.server_thread.daemon = True
        cls.server_thread.start()

        # Wait for the live server to be ready
        cls.server_thread.is_ready.wait()
        if cls.server_thread.error:
            error = cls.server_thread.error
            # unittest does not call tearDownClass() when setUpClass() raises,
            # so clean up behind ourselves before reporting the error.
            cls._tearDownClassInternal()
            raise error

        super(LiveServerTestCase, cls).setUpClass()

    @classmethod
    def _tearDownClassInternal(cls):
        """
        Stops the live server's thread (if any) and restores the sqlite
        connections' non-shareability.
        """
        # There may not be a 'server_thread' attribute if setUpClass() for some
        # reasons has raised an exception.
        if hasattr(cls, 'server_thread'):
            # Terminate the live server's thread
            cls.server_thread.join()

        # Restore sqlite connections' non-sharability
        for conn in connections.all():
            if cls._is_in_memory_sqlite(conn):
                conn.allow_thread_sharing = False

    @classmethod
    def tearDownClass(cls):
        cls._tearDownClassInternal()
        super(LiveServerTestCase, cls).tearDownClass()