Bundled cherrypy.
[smonitor.git] / monitor / cherrypy / test / webtest.py
blob969eab0e286088e8d86ddcdaedb07d108e3b57dc
1 """Extensions to unittest for web frameworks.
3 Use the WebCase.getPage method to request a page from your HTTP server.
5 Framework Integration
6 =====================
8 If you have control over your server process, you can handle errors
9 in the server-side of the HTTP conversation a bit better. You must run
10 both the client (your WebCase tests) and the server in the same process
11 (but in separate threads, obviously).
13 When an error occurs in the framework, call server_error. It will print
14 the traceback to stdout, and keep any assertions you have from running
15 (the assumption is that, if the server errors, the page output will not
16 be of further significance to your tests).
17 """
19 import os
20 import pprint
21 import re
22 import socket
23 import sys
24 import time
25 import traceback
26 import types
28 from unittest import *
29 from unittest import _TextTestResult
31 from cherrypy._cpcompat import basestring, HTTPConnection, HTTPSConnection, unicodestr
def interface(host):
    """Translate a server bind address into one a client can connect to.

    A server listening on the wildcard address '0.0.0.0' (INADDR_ANY) or
    '::' (IN6ADDR_ANY) should respond on the matching localhost address,
    so that address is returned instead.  Any other host is returned
    unchanged.
    """
    wildcard_to_loopback = (
        ('0.0.0.0', "127.0.0.1"),  # INADDR_ANY
        ('::', "::1"),             # IN6ADDR_ANY
    )
    for wildcard, loopback in wildcard_to_loopback:
        if host == wildcard:
            return loopback
    return host
class TerseTestResult(_TextTestResult):
    """Text test result whose error report skips a needless empty line."""

    def printErrors(self):
        """Write the ERROR and FAIL lists, but only when there are any."""
        # Overridden to avoid unnecessary empty line
        if not (self.errors or self.failures):
            return
        if self.dots or self.showAll:
            self.stream.writeln()
        for flavour, problems in (('ERROR', self.errors),
                                  ('FAIL', self.failures)):
            self.printErrorList(flavour, problems)
class TerseTestRunner(TextTestRunner):
    """A test runner class that displays results in textual form."""

    def _makeResult(self):
        """Create the TerseTestResult that collects this run's outcomes."""
        return TerseTestResult(self.stream, self.descriptions, self.verbosity)

    def run(self, test):
        """Run the given test case or test suite and return the result."""
        # Overridden to remove unnecessary empty lines and separators
        result = self._makeResult()
        test(result)
        result.printErrors()
        if result.wasSuccessful():
            return result

        failed = len(result.failures)
        errored = len(result.errors)

        # Collect the pieces of the summary in the order they are written.
        pieces = []
        if failed:
            pieces.append("failures=%d" % failed)
        if errored:
            if failed:
                pieces.append(", ")
            pieces.append("errors=%d" % errored)

        self.stream.write("FAILED (")
        for piece in pieces:
            self.stream.write(piece)
        self.stream.writeln(")")
        return result
# TestLoader variant: while resolving a dotted test name, a module that is
# already present in sys.modules is passed through reload() instead of
# being reused as-is.
84 class ReloadingTestLoader(TestLoader):
86 def loadTestsFromName(self, name, module=None):
87 """Return a suite of all tests cases given a string specifier.
89 The name may resolve either to a module, a test case class, a
90 test method within a test case class, or a callable object which
91 returns a TestCase or TestSuite instance.
93 The method optionally resolves the names relative to a given module.
94 """
95 parts = name.split('.')
# Name components that turned out not to be part of the module path.
96 unused_parts = []
97 if module is None:
98 if not parts:
99 raise ValueError("incomplete test name: %s" % name)
100 else:
# Try the full dotted name first; after each failed import the last
# component is moved to the front of unused_parts and the shorter
# prefix is tried.
101 parts_copy = parts[:]
102 while parts_copy:
103 target = ".".join(parts_copy)
104 if target in sys.modules:
# Already imported: reload it.
105 module = reload(sys.modules[target])
106 parts = unused_parts
107 break
108 else:
109 try:
110 module = __import__(target)
111 parts = unused_parts
112 break
113 except ImportError:
114 unused_parts.insert(0,parts_copy[-1])
115 del parts_copy[-1]
# Nothing left to try: re-raise the ImportError.
116 if not parts_copy:
117 raise
# NOTE(review): indentation is not visible here, so it cannot be told
# whether the next statement runs after the while loop or inside the
# except handler -- confirm against the upstream file before editing.
118 parts = parts[1:]
# Walk the remaining attribute names down from the module.
119 obj = module
120 for part in parts:
121 obj = getattr(obj, part)
# Dispatch on the kind of object the name resolved to.
# NOTE(review): reload() as a builtin, types.ClassType,
# types.UnboundMethodType and im_class look Python 2-specific --
# confirm which interpreter versions this loader must support.
123 if type(obj) == types.ModuleType:
124 return self.loadTestsFromModule(obj)
125 elif (isinstance(obj, (type, types.ClassType)) and
126 issubclass(obj, TestCase)):
127 return self.loadTestsFromTestCase(obj)
128 elif type(obj) == types.UnboundMethodType:
129 return obj.im_class(obj.__name__)
130 elif hasattr(obj, '__call__'):
# Any other callable must itself produce a TestCase or TestSuite.
131 test = obj()
132 if not isinstance(test, TestCase) and \
133 not isinstance(test, TestSuite):
134 raise ValueError("calling %s returned %s, "
135 "not a test" % (obj,test))
136 return test
137 else:
138 raise ValueError("do not know how to make test from: %s" % obj)
# Pick a getchar() implementation for the current platform.
try:
    if sys.platform.startswith('java'):
        # Jython support
        def getchar():
            """Read one character from standard input."""
            # Hopefully this is enough
            return sys.stdin.read(1)
    else:
        # On Windows, msvcrt.getch reads a single char without output.
        import msvcrt

        def getchar():
            """Read one keypress from the console without echoing it."""
            return msvcrt.getch()
except ImportError:
    # Unix getchr
    import termios
    import tty

    def getchar():
        """Read one character from the terminal in raw mode.

        The terminal attributes in effect before the call are restored
        afterwards, even if the read raises.
        """
        stdin_fd = sys.stdin.fileno()
        saved_attrs = termios.tcgetattr(stdin_fd)
        try:
            tty.setraw(stdin_fd)
            char = sys.stdin.read(1)
        finally:
            termios.tcsetattr(stdin_fd, termios.TCSADRAIN, saved_attrs)
        return char
class WebCase(TestCase):
    """TestCase with helpers for requesting pages and checking responses.

    Call getPage() to issue a request.  The response is stored on the
    instance as ``status``, ``headers`` (a list of (name, value) pairs)
    and ``body``, which the assert* methods then inspect.  When an
    assertion fails and ``interactive`` is true, the user is prompted on
    the console; otherwise ``failureException`` is raised.
    """

    # Host and port of the server under test.
    HOST = "127.0.0.1"
    PORT = 8000
    # Either a connection class (a fresh connection is made per request)
    # or a connection instance (reused across requests); see
    # set_persistent().
    HTTP_CONN = HTTPConnection
    PROTOCOL = "HTTP/1.1"

    scheme = "http"
    url = None

    # Filled in by getPage().
    status = None
    headers = None
    body = None

    encoding = 'utf-8'

    # Duration of the last getPage() request in seconds (None while a
    # request is in flight or before the first one).
    time = None

    def get_conn(self, auto_open=False):
        """Return a new, already connected connection to our HTTP server.

        An HTTPSConnection is used when self.scheme is "https", otherwise
        an HTTPConnection.  ``auto_open`` is stored on the connection to
        control whether it re-connects automatically.
        """
        if self.scheme == "https":
            cls = HTTPSConnection
        else:
            cls = HTTPConnection
        conn = cls(self.interface(), self.PORT)
        # Automatically re-connect?
        conn.auto_open = auto_open
        conn.connect()
        return conn

    def set_persistent(self, on=True, auto_open=False):
        """Make our HTTP_CONN persistent (or not).

        If the 'on' argument is True (the default), then self.HTTP_CONN
        will be set to an instance of HTTPConnection (or HTTPS
        if self.scheme is "https"). This will then persist across requests.

        We only allow for a single open connection, so if you call this
        and we currently have an open connection, it will be closed.
        """
        try:
            self.HTTP_CONN.close()
        except (TypeError, AttributeError):
            # HTTP_CONN is a class (or has no close); nothing to close.
            pass

        if on:
            self.HTTP_CONN = self.get_conn(auto_open=auto_open)
        else:
            if self.scheme == "https":
                self.HTTP_CONN = HTTPSConnection
            else:
                self.HTTP_CONN = HTTPConnection

    def _get_persistent(self):
        # Use the same class-versus-instance test as openURL(): a
        # connection instance carries a 'host' attribute.  Testing for
        # '__class__' cannot tell them apart, because new-style classes
        # have that attribute too.
        return hasattr(self.HTTP_CONN, "host")

    def _set_persistent(self, on):
        self.set_persistent(on)

    persistent = property(_get_persistent, _set_persistent)

    def interface(self):
        """Return an IP address for a client connection.

        If the server is listening on '0.0.0.0' (INADDR_ANY)
        or '::' (IN6ADDR_ANY), this will return the proper localhost."""
        return interface(self.HOST)

    def getPage(self, url, headers=None, method="GET", body=None, protocol=None):
        """Open the url with debugging support. Return status, headers, body.

        The result is also stored on self.status, self.headers and
        self.body, the elapsed time on self.time, and any Set-Cookie
        response headers are turned into request headers in self.cookies.

        Raises ServerError if the server-side hook flagged an error while
        the request was being handled.
        """
        ServerError.on = False

        if isinstance(url, unicodestr):
            url = url.encode('utf-8')
        if isinstance(body, unicodestr):
            body = body.encode('utf-8')

        self.url = url
        self.time = None
        start = time.time()
        result = openURL(url, headers, method, body, self.HOST, self.PORT,
                         self.HTTP_CONN, protocol or self.PROTOCOL)
        self.time = time.time() - start
        self.status, self.headers, self.body = result

        # Build a list of request cookies from the previous response cookies.
        self.cookies = [('Cookie', v) for k, v in self.headers
                        if k.lower() == 'set-cookie']

        if ServerError.on:
            raise ServerError()
        return result

    # When true, a failed assertion prompts on the console instead of
    # raising immediately.
    interactive = True
    # Number of body lines shown between "More" prompts.
    console_height = 30

    def _handlewebError(self, msg):
        """Report a failed assertion, interactively if enabled.

        Prints the message.  If self.interactive is false, raises
        self.failureException(msg) right away.  Otherwise the user can
        inspect the last response and then ignore the failure (this
        method returns), raise it, or exit the process.
        """
        print("")
        print(" ERROR: %s" % msg)

        if not self.interactive:
            raise self.failureException(msg)

        p = " Show: [B]ody [H]eaders [S]tatus [U]RL; [I]gnore, [R]aise, or sys.e[X]it >> "
        more_prompt = "<-- More -->"
        sys.stdout.write(p)
        sys.stdout.flush()
        while True:
            i = getchar().upper()
            if not i:
                # End of input: nobody can answer the prompt.  An empty
                # string is "in" every string, so without this check the
                # loop would spin forever re-printing the prompt.
                raise self.failureException(msg)
            if i not in "BHSUIRX":
                continue
            print(i)  # Also prints new line
            if i == "B":
                for x, line in enumerate(self.body.splitlines()):
                    if (x + 1) % self.console_height == 0:
                        # The \r returns the cursor to the start of the
                        # line so the next output overwrites the prompt.
                        sys.stdout.write(more_prompt + "\r")
                        sys.stdout.flush()
                        m = getchar().lower()
                        # Erase our "More" prompt (all of it, not just
                        # its first character).
                        sys.stdout.write(" " * len(more_prompt) + "\r")
                        if m == "q":
                            break
                    print(line)
            elif i == "H":
                pprint.pprint(self.headers)
            elif i == "S":
                print(self.status)
            elif i == "U":
                print(self.url)
            elif i == "I":
                # return without raising the normal exception
                return
            elif i == "R":
                raise self.failureException(msg)
            elif i == "X":
                self.exit()
            sys.stdout.write(p)
            sys.stdout.flush()

    def exit(self):
        """Terminate the process (the [X] choice of the error prompt)."""
        sys.exit()

    def assertStatus(self, status, msg=None):
        """Fail if self.status != status.

        ``status`` may be a full status string (compared exactly), an int
        (compared with the first three characters of self.status), or a
        tuple/list of either, any of which may match.
        """
        if isinstance(status, basestring):
            if not self.status == status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        elif isinstance(status, int):
            code = int(self.status[:3])
            if code != status:
                if msg is None:
                    msg = 'Status (%r) != %r' % (self.status, status)
                self._handlewebError(msg)
        else:
            # status is a tuple or list.
            match = False
            for s in status:
                if isinstance(s, basestring):
                    if self.status == s:
                        match = True
                        break
                elif int(self.status[:3]) == s:
                    match = True
                    break
            if not match:
                if msg is None:
                    msg = 'Status (%r) not in %r' % (self.status, status)
                self._handlewebError(msg)

    def assertHeader(self, key, value=None, msg=None):
        """Fail if (key, [value]) not in self.headers.

        Header names are compared case-insensitively.  Returns the value
        of the first matching header.
        """
        lowkey = key.lower()
        for k, v in self.headers:
            if k.lower() == lowkey:
                if value is None or str(value) == v:
                    return v

        if msg is None:
            if value is None:
                msg = '%r not in headers' % key
            else:
                msg = '%r:%r not in headers' % (key, value)
        self._handlewebError(msg)

    def assertHeaderItemValue(self, key, value, msg=None):
        """Fail if the header does not contain the specified value.

        The header value is split on commas and each item is stripped of
        surrounding whitespace before comparing.  Returns ``value`` on
        success.
        """
        actual_value = self.assertHeader(key, msg=msg)
        if actual_value is None:
            # The header is missing and that failure was ignored at the
            # interactive prompt; there is nothing left to compare.
            return None
        # Build a real list: it is searched and then shown in the failure
        # message, which a one-shot map() iterator cannot support.
        header_values = [item.strip() for item in actual_value.split(',')]
        if value in header_values:
            return value

        if msg is None:
            msg = "%r not in %r" % (value, header_values)
        self._handlewebError(msg)

    def assertNoHeader(self, key, msg=None):
        """Fail if key in self.headers (compared case-insensitively)."""
        lowkey = key.lower()
        matches = [k for k, v in self.headers if k.lower() == lowkey]
        if matches:
            if msg is None:
                msg = '%r in headers' % key
            self._handlewebError(msg)

    def assertBody(self, value, msg=None):
        """Fail if value != self.body."""
        if value != self.body:
            if msg is None:
                msg = 'expected body:\n%r\n\nactual body:\n%r' % (value, self.body)
            self._handlewebError(msg)

    def assertInBody(self, value, msg=None):
        """Fail if value not in self.body."""
        if value not in self.body:
            if msg is None:
                msg = '%r not in body: %s' % (value, self.body)
            self._handlewebError(msg)

    def assertNotInBody(self, value, msg=None):
        """Fail if value in self.body."""
        if value in self.body:
            if msg is None:
                msg = '%r found in body' % value
            self._handlewebError(msg)

    def assertMatchesBody(self, pattern, msg=None, flags=0):
        """Fail if value (a regex pattern) is not in self.body."""
        if re.search(pattern, self.body, flags) is None:
            if msg is None:
                msg = 'No match for %r in body' % pattern
            self._handlewebError(msg)
# Request methods for which default entity headers are supplied.
methods_with_bodies = ("POST", "PUT")


def _has_header(headers, name):
    """Return True if *headers* has a field named *name* (lower-case)."""
    for key, _value in headers:
        if key.lower() == name:
            return True
    return False


def cleanHeaders(headers, method, body, host, port):
    """Return request headers, with required headers added (if missing).

    headers: iterable of (name, value) pairs, or None for no headers.
        It is never modified; a new list is returned.
    method: request method; for the methods in methods_with_bodies the
        default entity headers described below are considered.
    body: request body (or None), used only to compute Content-Length.
    host, port: the server's address, used for the Host header.

    A Host header is appended when none is present (the port is left out
    for port 80).  For methods with bodies, if no Content-Type header is
    present, a default Content-Type AND a Content-Length are appended; a
    caller that supplies its own Content-Type is expected to supply
    Content-Length as well.
    """
    # Work on a copy.  Appending to the caller's list would leak the
    # added headers (for example a stale Content-Length) into later
    # requests that reuse the same list.
    if headers is None:
        headers = []
    else:
        headers = list(headers)

    # Add the required Host request header if not present.
    # [This specifies the host:port of the server, not the client.]
    if not _has_header(headers, 'host'):
        if port == 80:
            headers.append(("Host", host))
        else:
            headers.append(("Host", "%s:%s" % (host, port)))

    if method in methods_with_bodies:
        # Stick in default type and length headers if not present
        if not _has_header(headers, 'content-type'):
            headers.append(("Content-Type", "application/x-www-form-urlencoded"))
            headers.append(("Content-Length", str(len(body or ""))))

    return headers
def shb(response):
    """Return status, headers, body the way we like from a response.

    The raw header lines in response.msg.headers are folded into a list
    of (name, value) pairs: a line starting with a space or tab continues
    the previous value, and a field whose name or value ends up empty is
    left out.  The status is rendered as "<status> <reason>" and the body
    is whatever response.read() returns.
    """
    pairs = []
    name = text = None
    for raw in response.msg.headers:
        if not raw:
            continue
        if raw[0] in " \t":
            # Continuation line: append it to the value being collected.
            text += raw.strip()
            continue
        # A new field begins; store the one collected so far.
        if name and text:
            pairs.append((name, text))
        name, text = raw.split(":", 1)
        name = name.strip()
        text = text.strip()
    # Store the last field, which no following line flushed.
    if name and text:
        pairs.append((name, text))

    status_line = "%s %s" % (response.status, response.reason)
    return status_line, pairs, response.read()
def openURL(url, headers=None, method="GET", body=None,
            host="127.0.0.1", port=8000, http_conn=HTTPConnection,
            protocol="HTTP/1.1"):
    """Open the given HTTP resource and return status, headers, and body.

    url, method, body: the request to send (the method is upper-cased).
    headers: list of (name, value) pairs or None; completed by
        cleanHeaders() before sending.
    host, port: address of the server.
    http_conn: a connection class, in which case a connection is created
        for the request and closed afterwards, or a connection instance
        (recognised by its 'host' attribute), which is used as-is and
        left open.
    protocol: HTTP version string placed on the request line.

    The request is attempted up to 10 times when socket errors occur,
    sleeping half a second between attempts; the socket error from the
    final attempt is raised to the caller.
    """

    headers = cleanHeaders(headers, method, body, host, port)

    # Allow http_conn to be a class or an instance
    reuse_conn = hasattr(http_conn, "host")

    # Trying 10 times is simply in case of socket errors.
    # Normal case--it should run once.
    attempts = 10
    for trial in range(attempts):
        conn = None
        try:
            try:
                if reuse_conn:
                    conn = http_conn
                else:
                    conn = http_conn(interface(host), port)

                conn._http_vsn_str = protocol
                conn._http_vsn = int("".join([x for x in protocol if x.isdigit()]))

                # skip_accept_encoding argument added in python version 2.4
                if sys.version_info < (2, 4):
                    def putheader(self, header, value):
                        if header == 'Accept-Encoding' and value == 'identity':
                            return
                        self.__class__.putheader(self, header, value)
                    import new
                    conn.putheader = new.instancemethod(putheader, conn, conn.__class__)
                    conn.putrequest(method.upper(), url, skip_host=True)
                else:
                    conn.putrequest(method.upper(), url, skip_host=True,
                                    skip_accept_encoding=True)

                for key, value in headers:
                    conn.putheader(key, value)
                conn.endheaders()

                if body is not None:
                    conn.send(body)

                # Handle response.  shb() reads the body, so the
                # connection may be closed once it returns.
                response = conn.getresponse()
                return shb(response)
            except socket.error:
                # Re-raise from inside the handler on the last attempt; a
                # bare "raise" after the loop has no active exception to
                # re-raise.
                if trial == attempts - 1:
                    raise
                time.sleep(0.5)
        finally:
            if conn is not None and not reuse_conn:
                # We made our own conn instance. Close it, whether the
                # request succeeded or failed.
                conn.close()
506 # Add any exceptions which your web framework handles
507 # normally (that you don't want server_error to trap).
# server_error() returns False for an exception whose type is listed here.
508 ignored_exceptions = []
510 # You'll want to set this to True when you can't guarantee
511 # that each response will immediately follow each request;
512 # for example, when handling requests via multiple threads.
# While this is true, server_error() returns False for every exception.
513 ignore_all = False
class ServerError(Exception):
    """Raised by WebCase.getPage when the server side reported an error.

    ``on`` is a class-level flag: server_error() sets it to True, while
    getPage() resets it before each request and checks it afterwards.
    """

    on = False
def server_error(exc=None):
    """Server debug hook. Return True if exception handled, False if ignored.

    You probably want to wrap this, so you can still handle an error using
    your framework when it's ignored.

    exc is an exception triple as returned by sys.exc_info(); when it is
    None, the exception currently being handled is used.  A handled
    exception sets ServerError.on and has its traceback printed to
    stdout.
    """
    if exc is None:
        exc = sys.exc_info()

    # Ignored exceptions are left for the framework to deal with.
    if ignore_all or exc[0] in ignored_exceptions:
        return False

    ServerError.on = True
    print("")
    print("".join(traceback.format_exception(*exc)))
    return True