1 """Extensions to unittest for web frameworks.
3 Use the WebCase.getPage method to request a page from your HTTP server.
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
28 from unittest
import *
29 from unittest
import _TextTestResult
31 from cherrypy
._cpcompat
import basestring
, HTTPConnection
, HTTPSConnection
, unicodestr
36 """Return an IP address for a client connection given the server host.
38 If the server is listening on '0.0.0.0' (INADDR_ANY)
39 or '::' (IN6ADDR_ANY), this will return the proper localhost."""
41 # INADDR_ANY, which should respond on localhost.
44 # IN6ADDR_ANY, which should respond on localhost.
49 class TerseTestResult(_TextTestResult
):
51 def printErrors(self
):
52 # Overridden to avoid unnecessary empty line
53 if self
.errors
or self
.failures
:
54 if self
.dots
or self
.showAll
:
56 self
.printErrorList('ERROR', self
.errors
)
57 self
.printErrorList('FAIL', self
.failures
)
60 class TerseTestRunner(TextTestRunner
):
61 """A test runner class that displays results in textual form."""
def _makeResult(self):
    """Create the result collector for a run.

    Returns a TerseTestResult wired to this runner's output stream and
    configured with the runner's ``descriptions`` and ``verbosity``
    settings.
    """
    stream = self.stream
    descriptions = self.descriptions
    verbosity = self.verbosity
    return TerseTestResult(stream, descriptions, verbosity)
67 "Run the given test case or test suite."
68 # Overridden to remove unnecessary empty lines and separators
69 result
= self
._makeResult
()
72 if not result
.wasSuccessful():
73 self
.stream
.write("FAILED (")
74 failed
, errored
= list(map(len, (result
.failures
, result
.errors
)))
76 self
.stream
.write("failures=%d" % failed
)
78 if failed
: self
.stream
.write(", ")
79 self
.stream
.write("errors=%d" % errored
)
80 self
.stream
.writeln(")")
84 class ReloadingTestLoader(TestLoader
):
86 def loadTestsFromName(self
, name
, module
=None):
87 """Return a suite of all tests cases given a string specifier.
89 The name may resolve either to a module, a test case class, a
90 test method within a test case class, or a callable object which
91 returns a TestCase or TestSuite instance.
93 The method optionally resolves the names relative to a given module.
95 parts
= name
.split('.')
99 raise ValueError("incomplete test name: %s" % name
)
101 parts_copy
= parts
[:]
103 target
= ".".join(parts_copy
)
104 if target
in sys
.modules
:
105 module
= reload(sys
.modules
[target
])
110 module
= __import__(target
)
114 unused_parts
.insert(0,parts_copy
[-1])
121 obj
= getattr(obj
, part
)
123 if type(obj
) == types
.ModuleType
:
124 return self
.loadTestsFromModule(obj
)
125 elif (isinstance(obj
, (type, types
.ClassType
)) and
126 issubclass(obj
, TestCase
)):
127 return self
.loadTestsFromTestCase(obj
)
128 elif type(obj
) == types
.UnboundMethodType
:
129 return obj
.im_class(obj
.__name
__)
130 elif hasattr(obj
, '__call__'):
132 if not isinstance(test
, TestCase
) and \
133 not isinstance(test
, TestSuite
):
134 raise ValueError("calling %s returned %s, "
135 "not a test" % (obj
,test
))
138 raise ValueError("do not know how to make test from: %s" % obj
)
143 if sys
.platform
[:4] == 'java':
145 # Hopefully this is enough
146 return sys
.stdin
.read(1)
148 # On Windows, msvcrt.getch reads a single char without output.
151 return msvcrt
.getch()
156 fd
= sys
.stdin
.fileno()
157 old_settings
= termios
.tcgetattr(fd
)
159 tty
.setraw(sys
.stdin
.fileno())
160 ch
= sys
.stdin
.read(1)
162 termios
.tcsetattr(fd
, termios
.TCSADRAIN
, old_settings
)
166 class WebCase(TestCase
):
# Connection used for requests. Holds either a connection class or a
# connection instance; set_persistent() switches between the two.
HTTP_CONN = HTTPConnection
# HTTP version string that getPage() passes to openURL when the caller
# does not supply a ``protocol`` argument.
PROTOCOL = "HTTP/1.1"
183 def get_conn(self
, auto_open
=False):
184 """Return a connection to our HTTP server."""
185 if self
.scheme
== "https":
186 cls
= HTTPSConnection
189 conn
= cls(self
.interface(), self
.PORT
)
190 # Automatically re-connect?
191 conn
.auto_open
= auto_open
195 def set_persistent(self
, on
=True, auto_open
=False):
196 """Make our HTTP_CONN persistent (or not).
198 If the 'on' argument is True (the default), then self.HTTP_CONN
199 will be set to an instance of HTTPConnection (or HTTPS
200 if self.scheme is "https"). This will then persist across requests.
202 We only allow for a single open connection, so if you call this
203 and we currently have an open connection, it will be closed.
206 self
.HTTP_CONN
.close()
207 except (TypeError, AttributeError):
211 self
.HTTP_CONN
= self
.get_conn(auto_open
=auto_open
)
213 if self
.scheme
== "https":
214 self
.HTTP_CONN
= HTTPSConnection
216 self
.HTTP_CONN
= HTTPConnection
218 def _get_persistent(self
):
219 return hasattr(self
.HTTP_CONN
, "__class__")
def _set_persistent(self, on):
    """Setter for the ``persistent`` property.

    Delegates to ``set_persistent(on)``; ``auto_open`` is left at that
    method's default.
    """
    self.set_persistent(on)
# Read/write attribute: reading calls _get_persistent, assigning calls
# _set_persistent.
persistent = property(_get_persistent, _set_persistent)
225 """Return an IP address for a client connection.
227 If the server is listening on '0.0.0.0' (INADDR_ANY)
228 or '::' (IN6ADDR_ANY), this will return the proper localhost."""
229 return interface(self
.HOST
)
231 def getPage(self
, url
, headers
=None, method
="GET", body
=None, protocol
=None):
232 """Open the url with debugging support. Return status, headers, body."""
233 ServerError
.on
= False
235 if isinstance(url
, unicodestr
):
236 url
= url
.encode('utf-8')
237 if isinstance(body
, unicodestr
):
238 body
= body
.encode('utf-8')
243 result
= openURL(url
, headers
, method
, body
, self
.HOST
, self
.PORT
,
244 self
.HTTP_CONN
, protocol
or self
.PROTOCOL
)
245 self
.time
= time
.time() - start
246 self
.status
, self
.headers
, self
.body
= result
248 # Build a list of request cookies from the previous response cookies.
249 self
.cookies
= [('Cookie', v
) for k
, v
in self
.headers
250 if k
.lower() == 'set-cookie']
259 def _handlewebError(self
, msg
):
261 print(" ERROR: %s" % msg
)
263 if not self
.interactive
:
264 raise self
.failureException(msg
)
266 p
= " Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
270 i
= getchar().upper()
271 if i
not in "BHSUIRX":
273 print(i
.upper()) # Also prints new line
275 for x
, line
in enumerate(self
.body
.splitlines()):
276 if (x
+ 1) % self
.console_height
== 0:
277 # The \r and comma should make the next line overwrite
278 sys
.stdout
.write("<-- More -->\r")
279 m
= getchar().lower()
280 # Erase our "More" prompt
281 sys
.stdout
.write(" \r")
286 pprint
.pprint(self
.headers
)
292 # return without raising the normal exception
295 raise self
.failureException(msg
)
304 def assertStatus(self
, status
, msg
=None):
305 """Fail if self.status != status."""
306 if isinstance(status
, basestring
):
307 if not self
.status
== status
:
309 msg
= 'Status (%r) != %r' % (self
.status
, status
)
310 self
._handlewebError
(msg
)
311 elif isinstance(status
, int):
312 code
= int(self
.status
[:3])
315 msg
= 'Status (%r) != %r' % (self
.status
, status
)
316 self
._handlewebError
(msg
)
318 # status is a tuple or list.
321 if isinstance(s
, basestring
):
325 elif int(self
.status
[:3]) == s
:
330 msg
= 'Status (%r) not in %r' % (self
.status
, status
)
331 self
._handlewebError
(msg
)
333 def assertHeader(self
, key
, value
=None, msg
=None):
334 """Fail if (key, [value]) not in self.headers."""
336 for k
, v
in self
.headers
:
337 if k
.lower() == lowkey
:
338 if value
is None or str(value
) == v
:
343 msg
= '%r not in headers' % key
345 msg
= '%r:%r not in headers' % (key
, value
)
346 self
._handlewebError
(msg
)
348 def assertHeaderItemValue(self
, key
, value
, msg
=None):
349 """Fail if the header does not contain the specified value"""
350 actual_value
= self
.assertHeader(key
, msg
=msg
)
351 header_values
= map(str.strip
, actual_value
.split(','))
352 if value
in header_values
:
356 msg
= "%r not in %r" % (value
, header_values
)
357 self
._handlewebError
(msg
)
359 def assertNoHeader(self
, key
, msg
=None):
360 """Fail if key in self.headers."""
362 matches
= [k
for k
, v
in self
.headers
if k
.lower() == lowkey
]
365 msg
= '%r in headers' % key
366 self
._handlewebError
(msg
)
368 def assertBody(self
, value
, msg
=None):
369 """Fail if value != self.body."""
370 if value
!= self
.body
:
372 msg
= 'expected body:\n%r\n\nactual body:\n%r' % (value
, self
.body
)
373 self
._handlewebError
(msg
)
375 def assertInBody(self
, value
, msg
=None):
376 """Fail if value not in self.body."""
377 if value
not in self
.body
:
379 msg
= '%r not in body: %s' % (value
, self
.body
)
380 self
._handlewebError
(msg
)
382 def assertNotInBody(self
, value
, msg
=None):
383 """Fail if value in self.body."""
384 if value
in self
.body
:
386 msg
= '%r found in body' % value
387 self
._handlewebError
(msg
)
389 def assertMatchesBody(self
, pattern
, msg
=None, flags
=0):
390 """Fail if value (a regex pattern) is not in self.body."""
391 if re
.search(pattern
, self
.body
, flags
) is None:
393 msg
= 'No match for %r in body' % pattern
394 self
._handlewebError
(msg
)
# Request methods for which cleanHeaders fills in default Content-Type
# and Content-Length headers when they are not already present.
methods_with_bodies = ("POST", "PUT")
399 def cleanHeaders(headers
, method
, body
, host
, port
):
400 """Return request headers, with required headers added (if missing)."""
404 # Add the required Host request header if not present.
405 # [This specifies the host:port of the server, not the client.]
408 if k
.lower() == 'host':
413 headers
.append(("Host", host
))
415 headers
.append(("Host", "%s:%s" % (host
, port
)))
417 if method
in methods_with_bodies
:
418 # Stick in default type and length headers if not present
421 if k
.lower() == 'content-type':
425 headers
.append(("Content-Type", "application/x-www-form-urlencoded"))
426 headers
.append(("Content-Length", str(len(body
or ""))))
432 """Return status, headers, body the way we like from a response."""
434 key
, value
= None, None
435 for line
in response
.msg
.headers
:
438 value
+= line
.strip()
441 h
.append((key
, value
))
442 key
, value
= line
.split(":", 1)
444 value
= value
.strip()
446 h
.append((key
, value
))
448 return "%s %s" % (response
.status
, response
.reason
), h
, response
.read()
451 def openURL(url
, headers
=None, method
="GET", body
=None,
452 host
="127.0.0.1", port
=8000, http_conn
=HTTPConnection
,
453 protocol
="HTTP/1.1"):
454 """Open the given HTTP resource and return status, headers, and body."""
456 headers
= cleanHeaders(headers
, method
, body
, host
, port
)
458 # Trying 10 times is simply in case of socket errors.
459 # Normal case--it should run once.
460 for trial
in range(10):
462 # Allow http_conn to be a class or an instance
463 if hasattr(http_conn
, "host"):
466 conn
= http_conn(interface(host
), port
)
468 conn
._http
_vsn
_str
= protocol
469 conn
._http
_vsn
= int("".join([x
for x
in protocol
if x
.isdigit()]))
471 # skip_accept_encoding argument added in python version 2.4
472 if sys
.version_info
< (2, 4):
473 def putheader(self
, header
, value
):
474 if header
== 'Accept-Encoding' and value
== 'identity':
476 self
.__class
__.putheader(self
, header
, value
)
478 conn
.putheader
= new
.instancemethod(putheader
, conn
, conn
.__class
__)
479 conn
.putrequest(method
.upper(), url
, skip_host
=True)
481 conn
.putrequest(method
.upper(), url
, skip_host
=True,
482 skip_accept_encoding
=True)
484 for key
, value
in headers
:
485 conn
.putheader(key
, value
)
492 response
= conn
.getresponse()
494 s
, h
, b
= shb(response
)
496 if not hasattr(http_conn
, "host"):
497 # We made our own conn instance. Close it.
# Add any exceptions which your web framework handles
# normally (that you don't want server_error to trap).
# server_error compares exc[0] against the entries of this list.
ignored_exceptions = []
# You'll want to set this to True when you can't guarantee
511 # that each response will immediately follow each request;
512 # for example, when handling requests via multiple threads.
515 class ServerError(Exception):
519 def server_error(exc
=None):
520 """Server debug hook. Return True if exception handled, False if ignored.
522 You probably want to wrap this, so you can still handle an error using
523 your framework when it's ignored.
528 if ignore_all
or exc
[0] in ignored_exceptions
:
531 ServerError
.on
= True
533 print("".join(traceback
.format_exception(*exc
)))