1 from __future__
import with_statement
8 from functools
import wraps
9 from urlparse
import urlsplit
, urlunsplit
10 from xml
.dom
.minidom
import parseString
, Node
16 from django
.conf
import settings
17 from django
.contrib
.staticfiles
.handlers
import StaticFilesHandler
18 from django
.core
import mail
19 from django
.core
.exceptions
import ValidationError
, ImproperlyConfigured
20 from django
.core
.handlers
.wsgi
import WSGIHandler
21 from django
.core
.management
import call_command
22 from django
.core
.signals
import request_started
23 from django
.core
.servers
.basehttp
import (WSGIRequestHandler
, WSGIServer
,
25 from django
.core
.urlresolvers
import clear_url_caches
26 from django
.core
.validators
import EMPTY_VALUES
27 from django
.db
import (transaction
, connection
, connections
, DEFAULT_DB_ALIAS
,
29 from django
.forms
.fields
import CharField
30 from django
.http
import QueryDict
31 from django
.test
import _doctest
as doctest
32 from django
.test
.client
import Client
33 from django
.test
.html
import HTMLParseError
, parse_html
34 from django
.test
.signals
import template_rendered
35 from django
.test
.utils
import (get_warnings_state
, restore_warnings_state
,
37 from django
.test
.utils
import ContextList
38 from django
.utils
import simplejson
, unittest
as ut2
39 from django
.utils
.encoding
import smart_str
, force_unicode
40 from django
.utils
.unittest
.util
import safe_repr
41 from django
.views
.static
import serve
43 __all__
= ('DocTestRunner', 'OutputChecker', 'TestCase', 'TransactionTestCase',
44 'SimpleTestCase', 'skipIfDBFeature', 'skipUnlessDBFeature')
46 normalize_long_ints
= lambda s
: re
.sub(r
'(?<![\w])(\d+)L(?![\w])', '\\1', s
)
47 normalize_decimals
= lambda s
: re
.sub(r
"Decimal\('(\d+(\.\d*)?)'\)",
48 lambda m
: "Decimal(\"%s\")" % m
.groups()[0], s
)
52 Puts value into a list if it's not already one.
53 Returns an empty list if value is None.
57 elif not isinstance(value
, list):
61 real_commit
= transaction
.commit
62 real_rollback
= transaction
.rollback
63 real_enter_transaction_management
= transaction
.enter_transaction_management
64 real_leave_transaction_management
= transaction
.leave_transaction_management
65 real_managed
= transaction
.managed
67 def nop(*args
, **kwargs
):
70 def disable_transaction_methods():
71 transaction
.commit
= nop
72 transaction
.rollback
= nop
73 transaction
.enter_transaction_management
= nop
74 transaction
.leave_transaction_management
= nop
75 transaction
.managed
= nop
77 def restore_transaction_methods():
78 transaction
.commit
= real_commit
79 transaction
.rollback
= real_rollback
80 transaction
.enter_transaction_management
= real_enter_transaction_management
81 transaction
.leave_transaction_management
= real_leave_transaction_management
82 transaction
.managed
= real_managed
85 def assert_and_parse_html(self
, html
, user_msg
, msg
):
87 dom
= parse_html(html
)
88 except HTMLParseError
, e
:
89 standardMsg
= u
'%s\n%s' % (msg
, e
.msg
)
90 self
.fail(self
._formatMessage
(user_msg
, standardMsg
))
94 class OutputChecker(doctest
.OutputChecker
):
95 def check_output(self
, want
, got
, optionflags
):
97 The entry method for doctest output checking. Defers to a sequence of
100 checks
= (self
.check_output_default
,
101 self
.check_output_numeric
,
102 self
.check_output_xml
,
103 self
.check_output_json
)
105 if check(want
, got
, optionflags
):
109 def check_output_default(self
, want
, got
, optionflags
):
111 The default comparator provided by doctest - not perfect, but good for
114 return doctest
.OutputChecker
.check_output(self
, want
, got
, optionflags
)
116 def check_output_numeric(self
, want
, got
, optionflags
):
117 """Doctest does an exact string comparison of output, which means that
118 some numerically equivalent values aren't equal. This check normalizes
119 * long integers (22L) so that they equal normal integers. (22)
120 * Decimals so that they are comparable, regardless of the change
121 made to __repr__ in Python 2.6.
123 return doctest
.OutputChecker
.check_output(self
,
124 normalize_decimals(normalize_long_ints(want
)),
125 normalize_decimals(normalize_long_ints(got
)),
128 def check_output_xml(self
, want
, got
, optionsflags
):
129 """Tries to do a 'xml-comparision' of want and got. Plain string
130 comparision doesn't always work because, for example, attribute
131 ordering should not be important.
133 Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
135 _norm_whitespace_re
= re
.compile(r
'[ \t\n][ \t\n]+')
136 def norm_whitespace(v
):
137 return _norm_whitespace_re
.sub(' ', v
)
139 def child_text(element
):
140 return ''.join([c
.data
for c
in element
.childNodes
141 if c
.nodeType
== Node
.TEXT_NODE
])
143 def children(element
):
144 return [c
for c
in element
.childNodes
145 if c
.nodeType
== Node
.ELEMENT_NODE
]
147 def norm_child_text(element
):
148 return norm_whitespace(child_text(element
))
150 def attrs_dict(element
):
151 return dict(element
.attributes
.items())
153 def check_element(want_element
, got_element
):
154 if want_element
.tagName
!= got_element
.tagName
:
156 if norm_child_text(want_element
) != norm_child_text(got_element
):
158 if attrs_dict(want_element
) != attrs_dict(got_element
):
160 want_children
= children(want_element
)
161 got_children
= children(got_element
)
162 if len(want_children
) != len(got_children
):
164 for want
, got
in zip(want_children
, got_children
):
165 if not check_element(want
, got
):
169 want
, got
= self
._strip
_quotes
(want
, got
)
170 want
= want
.replace('\\n','\n')
171 got
= got
.replace('\\n','\n')
173 # If the string is not a complete xml document, we may need to add a
174 # root element. This allow us to compare fragments, like "<foo/><bar/>"
175 if not want
.startswith('<?xml'):
176 wrapper
= '<root>%s</root>'
177 want
= wrapper
% want
180 # Parse the want and got strings, and compare the parsings.
182 want_root
= parseString(want
).firstChild
183 got_root
= parseString(got
).firstChild
186 return check_element(want_root
, got_root
)
188 def check_output_json(self
, want
, got
, optionsflags
):
190 Tries to compare want and got as if they were JSON-encoded data
192 want
, got
= self
._strip
_quotes
(want
, got
)
194 want_json
= simplejson
.loads(want
)
195 got_json
= simplejson
.loads(got
)
198 return want_json
== got_json
200 def _strip_quotes(self
, want
, got
):
202 Strip quotes of doctests output values:
204 >>> o = OutputChecker()
205 >>> o._strip_quotes("'foo'")
207 >>> o._strip_quotes('"foo"')
209 >>> o._strip_quotes("u'foo'")
211 >>> o._strip_quotes('u"foo"')
214 def is_quoted_string(s
):
218 and s
[0] in ('"', "'"))
220 def is_quoted_unicode(s
):
225 and s
[1] in ('"', "'"))
227 if is_quoted_string(want
) and is_quoted_string(got
):
228 want
= want
.strip()[1:-1]
229 got
= got
.strip()[1:-1]
230 elif is_quoted_unicode(want
) and is_quoted_unicode(got
):
231 want
= want
.strip()[2:-1]
232 got
= got
.strip()[2:-1]
236 class DocTestRunner(doctest
.DocTestRunner
):
237 def __init__(self
, *args
, **kwargs
):
238 doctest
.DocTestRunner
.__init
__(self
, *args
, **kwargs
)
239 self
.optionflags
= doctest
.ELLIPSIS
241 def report_unexpected_exception(self
, out
, test
, example
, exc_info
):
242 doctest
.DocTestRunner
.report_unexpected_exception(self
, out
, test
,
244 # Rollback, in case of database errors. Otherwise they'd have
245 # side effects on other tests.
246 for conn
in connections
:
247 transaction
.rollback_unless_managed(using
=conn
)
250 class _AssertNumQueriesContext(object):
251 def __init__(self
, test_case
, num
, connection
):
252 self
.test_case
= test_case
254 self
.connection
= connection
257 self
.old_debug_cursor
= self
.connection
.use_debug_cursor
258 self
.connection
.use_debug_cursor
= True
259 self
.starting_queries
= len(self
.connection
.queries
)
260 request_started
.disconnect(reset_queries
)
263 def __exit__(self
, exc_type
, exc_value
, traceback
):
264 self
.connection
.use_debug_cursor
= self
.old_debug_cursor
265 request_started
.connect(reset_queries
)
266 if exc_type
is not None:
269 final_queries
= len(self
.connection
.queries
)
270 executed
= final_queries
- self
.starting_queries
272 self
.test_case
.assertEqual(
273 executed
, self
.num
, "%d queries executed, %d expected" % (
279 class _AssertTemplateUsedContext(object):
280 def __init__(self
, test_case
, template_name
):
281 self
.test_case
= test_case
282 self
.template_name
= template_name
283 self
.rendered_templates
= []
284 self
.rendered_template_names
= []
285 self
.context
= ContextList()
287 def on_template_render(self
, sender
, signal
, template
, context
, **kwargs
):
288 self
.rendered_templates
.append(template
)
289 self
.rendered_template_names
.append(template
.name
)
290 self
.context
.append(copy(context
))
293 return self
.template_name
in self
.rendered_template_names
296 return u
'%s was not rendered.' % self
.template_name
299 template_rendered
.connect(self
.on_template_render
)
302 def __exit__(self
, exc_type
, exc_value
, traceback
):
303 template_rendered
.disconnect(self
.on_template_render
)
304 if exc_type
is not None:
308 message
= self
.message()
309 if len(self
.rendered_templates
) == 0:
310 message
+= u
' No template was rendered.'
312 message
+= u
' Following templates were rendered: %s' % (
313 ', '.join(self
.rendered_template_names
))
314 self
.test_case
.fail(message
)
317 class _AssertTemplateNotUsedContext(_AssertTemplateUsedContext
):
319 return self
.template_name
not in self
.rendered_template_names
322 return u
'%s was rendered.' % self
.template_name
325 class SimpleTestCase(ut2
.TestCase
):
326 def save_warnings_state(self
):
328 Saves the state of the warnings module
330 self
._warnings
_state
= get_warnings_state()
332 def restore_warnings_state(self
):
334 Restores the state of the warnings module to the state
335 saved by save_warnings_state()
337 restore_warnings_state(self
._warnings
_state
)
339 def settings(self
, **kwargs
):
341 A context manager that temporarily sets a setting and reverts
342 back to the original value when exiting the context.
344 return override_settings(**kwargs
)
346 def assertRaisesMessage(self
, expected_exception
, expected_message
,
347 callable_obj
=None, *args
, **kwargs
):
349 Asserts that the message in a raised exception matches the passed
353 expected_exception: Exception class expected to be raised.
354 expected_message: expected error message string value.
355 callable_obj: Function to be called.
357 kwargs: Extra kwargs.
359 return self
.assertRaisesRegexp(expected_exception
,
360 re
.escape(expected_message
), callable_obj
, *args
, **kwargs
)
362 def assertFieldOutput(self
, fieldclass
, valid
, invalid
, field_args
=None,
363 field_kwargs
=None, empty_value
=u
''):
365 Asserts that a form field behaves correctly with various inputs.
368 fieldclass: the class of the field to be tested.
369 valid: a dictionary mapping valid inputs to their expected
371 invalid: a dictionary mapping invalid inputs to one or more
372 raised error messages.
373 field_args: the args passed to instantiate the field
374 field_kwargs: the kwargs passed to instantiate the field
375 empty_value: the expected clean output for inputs in EMPTY_VALUES
378 if field_args
is None:
380 if field_kwargs
is None:
382 required
= fieldclass(*field_args
, **field_kwargs
)
383 optional
= fieldclass(*field_args
,
384 **dict(field_kwargs
, required
=False))
386 for input, output
in valid
.items():
387 self
.assertEqual(required
.clean(input), output
)
388 self
.assertEqual(optional
.clean(input), output
)
389 # test invalid inputs
390 for input, errors
in invalid
.items():
391 with self
.assertRaises(ValidationError
) as context_manager
:
392 required
.clean(input)
393 self
.assertEqual(context_manager
.exception
.messages
, errors
)
395 with self
.assertRaises(ValidationError
) as context_manager
:
396 optional
.clean(input)
397 self
.assertEqual(context_manager
.exception
.messages
, errors
)
398 # test required inputs
399 error_required
= [force_unicode(required
.error_messages
['required'])]
400 for e
in EMPTY_VALUES
:
401 with self
.assertRaises(ValidationError
) as context_manager
:
403 self
.assertEqual(context_manager
.exception
.messages
,
405 self
.assertEqual(optional
.clean(e
), empty_value
)
406 # test that max_length and min_length are always accepted
407 if issubclass(fieldclass
, CharField
):
408 field_kwargs
.update({'min_length':2, 'max_length':20})
409 self
.assertTrue(isinstance(fieldclass(*field_args
, **field_kwargs
),
412 def assertHTMLEqual(self
, html1
, html2
, msg
=None):
414 Asserts that two HTML snippets are semantically the same.
415 Whitespace in most cases is ignored, and attribute ordering is not
416 significant. The passed-in arguments must be valid HTML.
418 dom1
= assert_and_parse_html(self
, html1
, msg
,
419 u
'First argument is not valid HTML:')
420 dom2
= assert_and_parse_html(self
, html2
, msg
,
421 u
'Second argument is not valid HTML:')
424 standardMsg
= '%s != %s' % (
425 safe_repr(dom1
, True), safe_repr(dom2
, True))
426 diff
= ('\n' + '\n'.join(difflib
.ndiff(
427 unicode(dom1
).splitlines(),
428 unicode(dom2
).splitlines())))
429 standardMsg
= self
._truncateMessage
(standardMsg
, diff
)
430 self
.fail(self
._formatMessage
(msg
, standardMsg
))
432 def assertHTMLNotEqual(self
, html1
, html2
, msg
=None):
433 """Asserts that two HTML snippets are not semantically equivalent."""
434 dom1
= assert_and_parse_html(self
, html1
, msg
,
435 u
'First argument is not valid HTML:')
436 dom2
= assert_and_parse_html(self
, html2
, msg
,
437 u
'Second argument is not valid HTML:')
440 standardMsg
= '%s == %s' % (
441 safe_repr(dom1
, True), safe_repr(dom2
, True))
442 self
.fail(self
._formatMessage
(msg
, standardMsg
))
445 class TransactionTestCase(SimpleTestCase
):
446 # The class we'll use for the test client self.client.
447 # Can be overridden in derived classes.
448 client_class
= Client
450 def _pre_setup(self
):
451 """Performs any pre-test setup. This includes:
453 * Flushing the database.
454 * If the Test Case class has a 'fixtures' member, installing the
456 * If the Test Case class has a 'urls' member, replace the
457 ROOT_URLCONF with it.
458 * Clearing the mail test outbox.
460 self
._fixture
_setup
()
461 self
._urlconf
_setup
()
464 def _fixture_setup(self
):
465 # If the test case has a multi_db=True flag, flush all databases.
466 # Otherwise, just flush default.
467 if getattr(self
, 'multi_db', False):
468 databases
= connections
470 databases
= [DEFAULT_DB_ALIAS
]
472 call_command('flush', verbosity
=0, interactive
=False, database
=db
)
474 if hasattr(self
, 'fixtures'):
475 # We have to use this slightly awkward syntax due to the fact
476 # that we're using *args and **kwargs together.
477 call_command('loaddata', *self
.fixtures
,
478 **{'verbosity': 0, 'database': db
})
480 def _urlconf_setup(self
):
481 if hasattr(self
, 'urls'):
482 self
._old
_root
_urlconf
= settings
.ROOT_URLCONF
483 settings
.ROOT_URLCONF
= self
.urls
486 def __call__(self
, result
=None):
488 Wrapper around default __call__ method to perform common Django test
489 set up. This means that user-defined Test Cases aren't required to
490 include a call to super().setUp().
492 testMethod
= getattr(self
, self
._testMethodName
)
493 skipped
= (getattr(self
.__class
__, "__unittest_skip__", False) or
494 getattr(testMethod
, "__unittest_skip__", False))
497 self
.client
= self
.client_class()
500 except (KeyboardInterrupt, SystemExit):
503 result
.addError(self
, sys
.exc_info())
505 super(TransactionTestCase
, self
).__call
__(result
)
508 self
._post
_teardown
()
509 except (KeyboardInterrupt, SystemExit):
512 result
.addError(self
, sys
.exc_info())
515 def _post_teardown(self
):
516 """ Performs any post-test things. This includes:
518 * Putting back the original ROOT_URLCONF if it was changed.
519 * Force closing the connection, so that the next test gets
522 self
._fixture
_teardown
()
523 self
._urlconf
_teardown
()
524 # Some DB cursors include SQL statements as part of cursor
525 # creation. If you have a test that does rollback, the effect
526 # of these statements is lost, which can effect the operation
527 # of tests (e.g., losing a timezone setting causing objects to
528 # be created with the wrong time).
529 # To make sure this doesn't happen, get a clean connection at the
530 # start of every test.
531 for conn
in connections
.all():
534 def _fixture_teardown(self
):
537 def _urlconf_teardown(self
):
538 if hasattr(self
, '_old_root_urlconf'):
539 settings
.ROOT_URLCONF
= self
._old
_root
_urlconf
542 def assertRedirects(self
, response
, expected_url
, status_code
=302,
543 target_status_code
=200, host
=None, msg_prefix
=''):
544 """Asserts that a response redirected to a specific URL, and that the
545 redirect URL can be loaded.
547 Note that assertRedirects won't work for external links since it uses
548 TestClient to do a request.
553 if hasattr(response
, 'redirect_chain'):
554 # The request was a followed redirect
555 self
.assertTrue(len(response
.redirect_chain
) > 0,
556 msg_prefix
+ "Response didn't redirect as expected: Response"
557 " code was %d (expected %d)" %
558 (response
.status_code
, status_code
))
560 self
.assertEqual(response
.redirect_chain
[0][1], status_code
,
561 msg_prefix
+ "Initial response didn't redirect as expected:"
562 " Response code was %d (expected %d)" %
563 (response
.redirect_chain
[0][1], status_code
))
565 url
, status_code
= response
.redirect_chain
[-1]
567 self
.assertEqual(response
.status_code
, target_status_code
,
568 msg_prefix
+ "Response didn't redirect as expected: Final"
569 " Response code was %d (expected %d)" %
570 (response
.status_code
, target_status_code
))
573 # Not a followed redirect
574 self
.assertEqual(response
.status_code
, status_code
,
575 msg_prefix
+ "Response didn't redirect as expected: Response"
576 " code was %d (expected %d)" %
577 (response
.status_code
, status_code
))
579 url
= response
['Location']
580 scheme
, netloc
, path
, query
, fragment
= urlsplit(url
)
582 redirect_response
= response
.client
.get(path
, QueryDict(query
))
584 # Get the redirection page, using the same client that was used
585 # to obtain the original response.
586 self
.assertEqual(redirect_response
.status_code
, target_status_code
,
587 msg_prefix
+ "Couldn't retrieve redirection page '%s':"
588 " response code was %d (expected %d)" %
589 (path
, redirect_response
.status_code
, target_status_code
))
591 e_scheme
, e_netloc
, e_path
, e_query
, e_fragment
= urlsplit(
593 if not (e_scheme
or e_netloc
):
594 expected_url
= urlunsplit(('http', host
or 'testserver', e_path
,
595 e_query
, e_fragment
))
597 self
.assertEqual(url
, expected_url
,
598 msg_prefix
+ "Response redirected to '%s', expected '%s'" %
601 def assertContains(self
, response
, text
, count
=None, status_code
=200,
602 msg_prefix
='', html
=False):
604 Asserts that a response indicates that some content was retrieved
605 successfully, (i.e., the HTTP status code was as expected), and that
606 ``text`` occurs ``count`` times in the content of the response.
607 If ``count`` is None, the count doesn't matter - the assertion is true
608 if the text occurs at least once in the response.
611 # If the response supports deferred rendering and hasn't been rendered
612 # yet, then ensure that it does get rendered before proceeding further.
613 if (hasattr(response
, 'render') and callable(response
.render
)
614 and not response
.is_rendered
):
620 self
.assertEqual(response
.status_code
, status_code
,
621 msg_prefix
+ "Couldn't retrieve content: Response code was %d"
622 " (expected %d)" % (response
.status_code
, status_code
))
623 text
= smart_str(text
, response
._charset
)
624 content
= response
.content
626 content
= assert_and_parse_html(self
, content
, None,
627 u
"Response's content is not valid HTML:")
628 text
= assert_and_parse_html(self
, text
, None,
629 u
"Second argument is not valid HTML:")
630 real_count
= content
.count(text
)
631 if count
is not None:
632 self
.assertEqual(real_count
, count
,
633 msg_prefix
+ "Found %d instances of '%s' in response"
634 " (expected %d)" % (real_count
, text
, count
))
636 self
.assertTrue(real_count
!= 0,
637 msg_prefix
+ "Couldn't find '%s' in response" % text
)
639 def assertNotContains(self
, response
, text
, status_code
=200,
640 msg_prefix
='', html
=False):
642 Asserts that a response indicates that some content was retrieved
643 successfully, (i.e., the HTTP status code was as expected), and that
644 ``text`` doesn't occurs in the content of the response.
647 # If the response supports deferred rendering and hasn't been rendered
648 # yet, then ensure that it does get rendered before proceeding further.
649 if (hasattr(response
, 'render') and callable(response
.render
)
650 and not response
.is_rendered
):
656 self
.assertEqual(response
.status_code
, status_code
,
657 msg_prefix
+ "Couldn't retrieve content: Response code was %d"
658 " (expected %d)" % (response
.status_code
, status_code
))
659 text
= smart_str(text
, response
._charset
)
660 content
= response
.content
662 content
= assert_and_parse_html(self
, content
, None,
663 u
'Response\'s content is not valid HTML:')
664 text
= assert_and_parse_html(self
, text
, None,
665 u
'Second argument is not valid HTML:')
666 self
.assertEqual(content
.count(text
), 0,
667 msg_prefix
+ "Response should not contain '%s'" % text
)
669 def assertFormError(self
, response
, form
, field
, errors
, msg_prefix
=''):
671 Asserts that a form used to render the response has a specific field
677 # Put context(s) into a list to simplify processing.
678 contexts
= to_list(response
.context
)
680 self
.fail(msg_prefix
+ "Response did not use any contexts to "
681 "render the response")
683 # Put error(s) into a list to simplify processing.
684 errors
= to_list(errors
)
686 # Search all contexts for the error.
688 for i
,context
in enumerate(contexts
):
689 if form
not in context
:
694 if field
in context
[form
].errors
:
695 field_errors
= context
[form
].errors
[field
]
696 self
.assertTrue(err
in field_errors
,
697 msg_prefix
+ "The field '%s' on form '%s' in"
698 " context %d does not contain the error '%s'"
699 " (actual errors: %s)" %
700 (field
, form
, i
, err
, repr(field_errors
)))
701 elif field
in context
[form
].fields
:
702 self
.fail(msg_prefix
+ "The field '%s' on form '%s'"
703 " in context %d contains no errors" %
706 self
.fail(msg_prefix
+ "The form '%s' in context %d"
707 " does not contain the field '%s'" %
710 non_field_errors
= context
[form
].non_field_errors()
711 self
.assertTrue(err
in non_field_errors
,
712 msg_prefix
+ "The form '%s' in context %d does not"
713 " contain the non-field error '%s'"
714 " (actual errors: %s)" %
715 (form
, i
, err
, non_field_errors
))
717 self
.fail(msg_prefix
+ "The form '%s' was not used to render the"
720 def assertTemplateUsed(self
, response
=None, template_name
=None, msg_prefix
=''):
722 Asserts that the template with the provided name was used in rendering
723 the response. Also usable as context manager.
725 if response
is None and template_name
is None:
726 raise TypeError(u
'response and/or template_name argument must be provided')
731 # Use assertTemplateUsed as context manager.
732 if not hasattr(response
, 'templates') or (response
is None and template_name
):
734 template_name
= response
736 context
= _AssertTemplateUsedContext(self
, template_name
)
739 template_names
= [t
.name
for t
in response
.templates
]
740 if not template_names
:
741 self
.fail(msg_prefix
+ "No templates used to render the response")
742 self
.assertTrue(template_name
in template_names
,
743 msg_prefix
+ "Template '%s' was not a template used to render"
744 " the response. Actual template(s) used: %s" %
745 (template_name
, u
', '.join(template_names
)))
747 def assertTemplateNotUsed(self
, response
=None, template_name
=None, msg_prefix
=''):
749 Asserts that the template with the provided name was NOT used in
750 rendering the response. Also usable as context manager.
752 if response
is None and template_name
is None:
753 raise TypeError(u
'response and/or template_name argument must be provided')
758 # Use assertTemplateUsed as context manager.
759 if not hasattr(response
, 'templates') or (response
is None and template_name
):
761 template_name
= response
763 context
= _AssertTemplateNotUsedContext(self
, template_name
)
766 template_names
= [t
.name
for t
in response
.templates
]
767 self
.assertFalse(template_name
in template_names
,
768 msg_prefix
+ "Template '%s' was used unexpectedly in rendering"
769 " the response" % template_name
)
771 def assertQuerysetEqual(self
, qs
, values
, transform
=repr, ordered
=True):
773 return self
.assertEqual(set(map(transform
, qs
)), set(values
))
774 return self
.assertEqual(map(transform
, qs
), values
)
776 def assertNumQueries(self
, num
, func
=None, *args
, **kwargs
):
777 using
= kwargs
.pop("using", DEFAULT_DB_ALIAS
)
778 conn
= connections
[using
]
780 context
= _AssertNumQueriesContext(self
, num
, conn
)
785 func(*args
, **kwargs
)
788 def connections_support_transactions():
790 Returns True if all connections support transactions.
792 return all(conn
.features
.supports_transactions
793 for conn
in connections
.all())
796 class TestCase(TransactionTestCase
):
798 Does basically the same as TransactionTestCase, but surrounds every test
799 with a transaction, monkey-patches the real transaction management routines
800 to do nothing, and rollsback the test transaction at the end of the test.
801 You have to use TransactionTestCase, if you need transaction management
805 def _fixture_setup(self
):
806 if not connections_support_transactions():
807 return super(TestCase
, self
)._fixture
_setup
()
809 # If the test case has a multi_db=True flag, setup all databases.
810 # Otherwise, just use default.
811 if getattr(self
, 'multi_db', False):
812 databases
= connections
814 databases
= [DEFAULT_DB_ALIAS
]
817 transaction
.enter_transaction_management(using
=db
)
818 transaction
.managed(True, using
=db
)
819 disable_transaction_methods()
821 from django
.contrib
.sites
.models
import Site
822 Site
.objects
.clear_cache()
825 if hasattr(self
, 'fixtures'):
826 call_command('loaddata', *self
.fixtures
,
833 def _fixture_teardown(self
):
834 if not connections_support_transactions():
835 return super(TestCase
, self
)._fixture
_teardown
()
837 # If the test case has a multi_db=True flag, teardown all databases.
838 # Otherwise, just teardown default.
839 if getattr(self
, 'multi_db', False):
840 databases
= connections
842 databases
= [DEFAULT_DB_ALIAS
]
844 restore_transaction_methods()
846 transaction
.rollback(using
=db
)
847 transaction
.leave_transaction_management(using
=db
)
850 def _deferredSkip(condition
, reason
):
851 def decorator(test_func
):
852 if not (isinstance(test_func
, type) and
853 issubclass(test_func
, TestCase
)):
855 def skip_wrapper(*args
, **kwargs
):
857 raise ut2
.SkipTest(reason
)
858 return test_func(*args
, **kwargs
)
859 test_item
= skip_wrapper
861 test_item
= test_func
862 test_item
.__unittest
_skip
_why
__ = reason
867 def skipIfDBFeature(feature
):
869 Skip a test if a database has the named feature
871 return _deferredSkip(lambda: getattr(connection
.features
, feature
),
872 "Database has feature %s" % feature
)
875 def skipUnlessDBFeature(feature
):
877 Skip a test unless a database has the named feature
879 return _deferredSkip(lambda: not getattr(connection
.features
, feature
),
880 "Database doesn't support feature %s" % feature
)
883 class QuietWSGIRequestHandler(WSGIRequestHandler
):
885 Just a regular WSGIRequestHandler except it doesn't log to the standard
886 output any of the requests received, so as to not clutter the output for
890 def log_message(*args
):
894 class _ImprovedEvent(threading
._Event
):
896 Does the same as `threading.Event` except it overrides the wait() method
897 with some code borrowed from Python 2.7 to return the set state of the
898 event (see: http://hg.python.org/cpython/rev/b5aa8aa78c0f/). This allows
899 to know whether the wait() method exited normally or because of the
900 timeout. This class can be removed when Django supports only Python >= 2.7.
903 def wait(self
, timeout
=None):
904 self
._Event
__cond
.acquire()
906 if not self
._Event
__flag
:
907 self
._Event
__cond
.wait(timeout
)
908 return self
._Event
__flag
910 self
._Event
__cond
.release()
913 class StoppableWSGIServer(WSGIServer
):
915 The code in this class is borrowed from the `SocketServer.BaseServer` class
916 in Python 2.6. The important functionality here is that the server is non-
917 blocking and that it can be shut down at any moment. This is made possible
918 by the server regularly polling the socket and checking if it has been
920 Note for the future: Once Django stops supporting Python 2.6, this class
921 can be removed as `WSGIServer` will have this ability to shutdown on
922 demand and will not require the use of the _ImprovedEvent class whose code
923 is borrowed from Python 2.7.
926 def __init__(self
, *args
, **kwargs
):
927 super(StoppableWSGIServer
, self
).__init
__(*args
, **kwargs
)
928 self
.__is
_shut
_down
= _ImprovedEvent()
929 self
.__serving
= False
931 def serve_forever(self
, poll_interval
=0.5):
933 Handle one request at a time until shutdown.
935 Polls for shutdown every poll_interval seconds.
937 self
.__serving
= True
938 self
.__is
_shut
_down
.clear()
939 while self
.__serving
:
940 r
, w
, e
= select
.select([self
], [], [], poll_interval
)
942 self
._handle
_request
_noblock
()
943 self
.__is
_shut
_down
.set()
947 Stops the serve_forever loop.
949 Blocks until the loop has finished. This must be called while
950 serve_forever() is running in another thread, or it will
953 self
.__serving
= False
954 if not self
.__is
_shut
_down
.wait(2):
956 "Failed to shutdown the live test server in 2 seconds. The "
957 "server might be stuck or generating a slow response.")
959 def handle_request(self
):
960 """Handle one request, possibly blocking.
962 fd_sets
= select
.select([self
], [], [], None)
965 self
._handle
_request
_noblock
()
967 def _handle_request_noblock(self
):
969 Handle one request, without blocking.
971 I assume that select.select has returned that the socket is
972 readable before this function was called, so there should be
973 no risk of blocking in get_request().
976 request
, client_address
= self
.get_request()
979 if self
.verify_request(request
, client_address
):
981 self
.process_request(request
, client_address
)
983 self
.handle_error(request
, client_address
)
984 self
.close_request(request
)
987 class _MediaFilesHandler(StaticFilesHandler
):
989 Handler for serving the media files. This is a private class that is
990 meant to be used solely as a convenience by LiveServerThread.
993 def get_base_dir(self
):
994 return settings
.MEDIA_ROOT
996 def get_base_url(self
):
997 return settings
.MEDIA_URL
999 def serve(self
, request
):
1000 relative_url
= request
.path
[len(self
.base_url
[2]):]
1001 return serve(request
, relative_url
, document_root
=self
.get_base_dir())
1004 class LiveServerThread(threading
.Thread
):
1006 Thread for running a live http server while the tests are running.
1009 def __init__(self
, host
, possible_ports
, connections_override
=None):
1012 self
.possible_ports
= possible_ports
1013 self
.is_ready
= threading
.Event()
1015 self
.connections_override
= connections_override
1016 super(LiveServerThread
, self
).__init
__()
1020 Sets up the live server and databases, and then loops over handling
1023 if self
.connections_override
:
1024 from django
.db
import connections
1025 # Override this thread's database connections with the ones
1026 # provided by the main thread.
1027 for alias
, conn
in self
.connections_override
.items():
1028 connections
[alias
] = conn
1030 # Create the handler for serving static and media files
1031 handler
= StaticFilesHandler(_MediaFilesHandler(WSGIHandler()))
1033 # Go through the list of possible ports, hoping that we can find
1034 # one that is free to use for the WSGI server.
1035 for index
, port
in enumerate(self
.possible_ports
):
1037 self
.httpd
= StoppableWSGIServer(
1038 (self
.host
, port
), QuietWSGIRequestHandler
)
1039 except WSGIServerException
, e
:
1040 if sys
.version_info
< (2, 6):
1041 error_code
= e
.args
[0].args
[0]
1043 error_code
= e
.args
[0].errno
1044 if (index
+ 1 < len(self
.possible_ports
) and
1045 error_code
== errno
.EADDRINUSE
):
1046 # This port is already in use, so we go on and try with
1047 # the next one in the list.
1050 # Either none of the given ports are free or the error
1051 # is something else than "Address already in use". So
1052 # we let that error bubble up to the main thread.
1055 # A free port was found.
1059 self
.httpd
.set_app(handler
)
1061 self
.httpd
.serve_forever()
1062 except Exception, e
:
1066 def join(self
, timeout
=None):
1067 if hasattr(self
, 'httpd'):
1068 # Stop the WSGI server
1069 self
.httpd
.shutdown()
1070 self
.httpd
.server_close()
1071 super(LiveServerThread
, self
).join(timeout
)
class LiveServerTestCase(TransactionTestCase):
    """
    Does basically the same as TransactionTestCase but also launches a live
    http server in a separate thread so that the tests may use another testing
    framework, such as Selenium for example, instead of the built-in dummy
    test client.

    Note that it inherits from TransactionTestCase instead of TestCase because
    the threads do not share the same transactions (unless using in-memory
    sqlite) and each thread needs to commit all its transactions so that the
    other thread can see the changes.

    The server address is read from the DJANGO_LIVE_TEST_SERVER_ADDRESS
    environment variable (default 'localhost:8081'); see
    _parse_live_server_address() for the accepted format.
    """

    @property
    def live_server_url(self):
        """Base URL of the running live server, e.g. 'http://localhost:8081'."""
        return 'http://%s:%s' % (
            self.server_thread.host, self.server_thread.port)

    @staticmethod
    def _is_in_memory_sqlite(conn):
        # True for in-memory sqlite databases, whose connections must be
        # shared with the server thread (see setUpClass/tearDownClass).
        return (conn.settings_dict['ENGINE'] == 'django.db.backends.sqlite3'
                and conn.settings_dict['NAME'] == ':memory:')

    @staticmethod
    def _parse_live_server_address(specified_address):
        """
        Split ``'host:ports'`` into ``(host, possible_ports)``.

        The ports may be of the form '8000-8010,8080,9200-9300', i.e. a
        comma-separated list of ports or inclusive ranges of ports, which is
        expanded into a detailed list of all possible ports, in order.

        Raises ImproperlyConfigured if the address cannot be parsed.
        """
        possible_ports = []
        try:
            host, port_ranges = specified_address.split(':')
            for port_range in port_ranges.split(','):
                # A port range can be of either form: '8000' or '8000-8010'.
                extremes = [int(bound) for bound in port_range.split('-')]
                if len(extremes) == 1:
                    # Port range of the form '8000'
                    possible_ports.append(extremes[0])
                elif len(extremes) == 2 and extremes[0] <= extremes[1]:
                    # Port range of the form '8000-8010'
                    possible_ports.extend(range(extremes[0], extremes[1] + 1))
                else:
                    # An explicit check instead of ``assert`` (stripped by
                    # ``python -O``); it also rejects reversed ranges such as
                    # '8010-8000', which would otherwise yield no ports.
                    raise ValueError(port_range)
        except Exception:
            raise ImproperlyConfigured('Invalid address ("%s") for live '
                'server.' % specified_address)
        return host, possible_ports

    @classmethod
    def setUpClass(cls):
        # Launch the live server's thread
        specified_address = os.environ.get(
            'DJANGO_LIVE_TEST_SERVER_ADDRESS', 'localhost:8081')
        # Parse first, so a bad address leaves no connection flagged as
        # thread-shareable.
        host, possible_ports = cls._parse_live_server_address(
            specified_address)

        connections_override = {}
        for conn in connections.all():
            # If using in-memory sqlite databases, pass the connections to
            # the server thread.
            if cls._is_in_memory_sqlite(conn):
                # Explicitly enable thread-shareability for this connection
                conn.allow_thread_sharing = True
                connections_override[conn.alias] = conn

        cls.server_thread = LiveServerThread(
            host, possible_ports, connections_override)
        cls.server_thread.daemon = True
        cls.server_thread.start()

        # Wait for the live server to be ready
        cls.server_thread.is_ready.wait()
        if cls.server_thread.error:
            # unittest does not run tearDownClass() when setUpClass() raises,
            # so undo the thread-sharing change here before propagating.
            for conn in connections_override.values():
                conn.allow_thread_sharing = False
            raise cls.server_thread.error

        super(LiveServerTestCase, cls).setUpClass()

    @classmethod
    def tearDownClass(cls):
        # There may not be a 'server_thread' attribute if setUpClass() for
        # some reason has raised an exception.
        if hasattr(cls, 'server_thread'):
            # Terminate the live server's thread
            cls.server_thread.join()

        # Restore sqlite connections' non-sharability
        for conn in connections.all():
            if cls._is_in_memory_sqlite(conn):
                conn.allow_thread_sharing = False

        super(LiveServerTestCase, cls).tearDownClass()