released version 0.5
[lunit.git] / lunit.lua
blob52d45880a128d7ce912620c39e034540fcc1edb9
--[[--------------------------------------------------------------------------

    This file is part of lunit 0.5.

    For details about lunit look at: http://www.mroth.net/lunit/

    Author: Michael Roth <mroth@nessie.de>

    Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>

    Permission is hereby granted, free of charge, to any person
    obtaining a copy of this software and associated documentation
    files (the "Software"), to deal in the Software without restriction,
    including without limitation the rights to use, copy, modify, merge,
    publish, distribute, sublicense, and/or sell copies of the Software,
    and to permit persons to whom the Software is furnished to do so,
    subject to the following conditions:

    The above copyright notice and this permission notice shall be
    included in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
    CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
    TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--]]--------------------------------------------------------------------------
-- Cache globals as locals before module() replaces this chunk's
-- environment. The original assert is kept under its own name because
-- this module defines its own 'assert' further down.
local orig_assert = assert

local pairs, ipairs, next = pairs, ipairs, next
local type, error, tostring = type, error, tostring

local string_sub, string_format = string.sub, string.format
-- Declare the module. package.seeall keeps globals reachable from inside
-- the module environment.
module("lunit", package.seeall) -- FIXME: Remove package.seeall

-- The module table itself, so functions can be defined as lunit.xxx.
local lunit = _M

-- Unique tag stored in the 'type' field of error objects raised by failed
-- assertions; it separates them from ordinary runtime errors.
local __failure__ = {}

-- Every Lua type name; drives the generated is_*, assert_* and
-- assert_not_* functions.
local typenames = {
  "nil", "boolean", "number", "string",
  "table", "function", "thread", "userdata",
}
-- Forward declarations; both locals are assigned below.
local traceback_hide -- Hides a function's frames from lunit tracebacks
local mypcall        -- Protected call to a function with own traceback

-- Set of functions whose frames are left out of tracebacks. Weak keys, so
-- registering a function here does not keep it alive.
local _tb_hide = setmetatable({}, { __mode = "k" })

-- Register func so my_traceback skips its stack frames.
traceback_hide = function(func)
  _tb_hide[func] = true
end
-- Message handler passed to xpcall by mypcall.
--
-- A failed lunit assertion (a table whose 'type' is __failure__) gets a
-- 'where' field ("file:line") naming the first stack frame that is neither
-- a C function nor registered with traceback_hide, i.e. the code that
-- called the assertion. Any other error value is replaced by a new table
-- { msg = tostring(err), tb = { ... } }, where tb lists the non-hidden
-- stack frames formatted like ldblib.c's traceback.
-- Returns the (possibly new) error object.
local function my_traceback(errobj)
  if is_table(errobj) and errobj.type == __failure__ then
    -- Walk outward instead of using a hardcoded level: skip this handler
    -- (level 1), C frames such as error(), and the lunit internals
    -- (failure, fail, assert_*), which all register with traceback_hide.
    local level = 2
    while true do
      local info = debug.getinfo(level, "Slf")
      if not is_table(info) then
        break -- ran off the stack; 'where' stays unset
      end
      if info.what ~= "C" and not _tb_hide[info.func] then
        errobj.where = string_format("%s:%d", info.short_src, info.currentline)
        break
      end
      level = level + 1
    end
  else
    errobj = { msg = tostring(errobj) }
    errobj.tb = {}
    local i = 2
    while true do
      local info = debug.getinfo(i, "Snlf")
      if not is_table(info) then
        break
      end
      if not _tb_hide[info.func] then
        local line = {} -- Ripped from ldblib.c...
        line[#line+1] = string_format("%s:", info.short_src)
        if info.currentline > 0 then
          line[#line+1] = string_format("%d:", info.currentline)
        end
        if info.namewhat ~= "" then
          line[#line+1] = string_format(" in function '%s'", info.name)
        else
          if info.what == "main" then
            line[#line+1] = " in main chunk"
          elseif info.what == "C" or info.what == "tail" then
            line[#line+1] = " ?"
          else
            line[#line+1] = string_format(" in function <%s:%d>", info.short_src, info.linedefined)
          end
        end
        errobj.tb[#errobj.tb+1] = table.concat(line)
      end
      i = i + 1
    end
  end
  return errobj
end
-- Call func (no arguments) in protected mode with my_traceback as the
-- message handler. Returns nothing when func succeeds; otherwise returns
-- the error object produced by my_traceback.
mypcall = function(func)
  orig_assert(is_function(func))
  local succeeded, err = xpcall(func, my_traceback)
  if succeeded then
    return
  end
  return err
end
traceback_hide(mypcall)
-- Type check functions
--
-- Defines lunit.is_nil, lunit.is_boolean, ..., lunit.is_userdata: one
-- predicate per entry of typenames, true iff type(x) is that name.
for i = 1, #typenames do
  local typename = typenames[i]
  lunit["is_" .. typename] = function(x)
    return type(x) == typename
  end
end

-- Local aliases of the predicates for use inside this file.
local is_nil, is_boolean, is_number, is_string =
  lunit.is_nil, lunit.is_boolean, lunit.is_number, lunit.is_string
local is_table, is_function, is_thread, is_userdata =
  lunit.is_table, lunit.is_function, lunit.is_thread, lunit.is_userdata
-- Raise a failed-assertion error object (never returns).
--   name        name of the failing assertion, e.g. "assert_equal"
--   usermsg     optional message supplied by the test author
--   defaultmsg  string.format template describing the failure; the
--               remaining arguments fill it in
-- The object carries the __failure__ tag and is raised with level 0.
local function failure(name, usermsg, defaultmsg, ...)
  error({
    type = __failure__,
    name = name,
    msg = string_format(defaultmsg, ...),
    usermsg = usermsg,
  }, 0)
end
traceback_hide(failure)
-- Render a value for assertion messages: strings in single quotes,
-- numbers, booleans and nil via tostring, anything else as "[tostring]".
local function format_arg(arg)
  local kind = type(arg)
  if kind == "string" then
    return "'" .. arg .. "'"
  end
  local text = tostring(arg)
  if kind == "number" or kind == "boolean" or kind == "nil" then
    return text
  end
  return "[" .. text .. "]"
end
-- Unconditionally fail the current test, with an optional user message.
-- Counts as one assertion.
function lunit.fail(msg)
  stats.assertions = stats.assertions + 1
  failure("fail", msg, "failure")
end
traceback_hide(lunit.fail)
-- Fail unless assertion is truthy (anything other than nil or false).
-- Returns assertion unchanged on success.
function lunit.assert(assertion, msg)
  stats.assertions = stats.assertions + 1
  if assertion then
    return assertion
  end
  failure("assert", msg, "assertion failed")
end
traceback_hide(lunit.assert)
-- Fail unless actual is exactly the boolean true (other truthy values
-- fail too, with a message naming their type). Returns actual.
function lunit.assert_true(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual == true then
    return actual
  end
  local kind = type(actual)
  if kind ~= "boolean" then
    failure("assert_true", msg, "true expected but was a "..kind)
  end
  failure("assert_true", msg, "true expected but was false")
end
traceback_hide(lunit.assert_true)
-- Fail unless actual is exactly the boolean false (nil fails too, with a
-- message naming its type). Returns actual.
function lunit.assert_false(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual == false then
    return actual
  end
  local kind = type(actual)
  if kind ~= "boolean" then
    failure("assert_false", msg, "false expected but was a "..kind)
  end
  failure("assert_false", msg, "false expected but was true")
end
traceback_hide(lunit.assert_false)
-- Fail unless expected == actual (Lua equality, so tables compare by
-- identity unless __eq says otherwise). Returns actual.
function lunit.assert_equal(expected, actual, msg)
  stats.assertions = stats.assertions + 1
  if expected == actual then
    return actual
  end
  failure("assert_equal", msg, "expected %s but was %s", format_arg(expected), format_arg(actual))
end
traceback_hide(lunit.assert_equal)
-- Fail if unexpected == actual. Returns actual.
function lunit.assert_not_equal(unexpected, actual, msg)
  stats.assertions = stats.assertions + 1
  if unexpected ~= actual then
    return actual
  end
  failure("assert_not_equal", msg, "%s not expected but was one", format_arg(unexpected))
end
traceback_hide(lunit.assert_not_equal)
-- Fail unless actual is a string in which the Lua pattern 'pattern' is
-- found (string.find). Both arguments are type-checked first.
-- Returns actual.
function lunit.assert_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  local patterntype, actualtype = type(pattern), type(actual)
  if patterntype ~= "string" then
    failure("assert_match", msg, "expected the pattern as a string but was a "..patterntype)
  end
  if actualtype ~= "string" then
    failure("assert_match", msg, "expected a string to match pattern '%s' but was a %s", pattern, actualtype)
  end
  local found = string.find(actual, pattern)
  if not found then
    failure("assert_match", msg, "expected '%s' to match pattern '%s' but doesn't", actual, pattern)
  end
  return actual
end
traceback_hide(lunit.assert_match)
-- Fail if the Lua pattern 'pattern' is found in the string actual
-- (string.find). Both arguments are type-checked first. Returns actual.
function lunit.assert_not_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  local patterntype, actualtype = type(pattern), type(actual)
  if patterntype ~= "string" then
    failure("assert_not_match", msg, "expected the pattern as a string but was a "..patterntype)
  end
  if actualtype ~= "string" then
    failure("assert_not_match", msg, "expected a string to not match pattern '%s' but was a %s", pattern, actualtype)
  end
  local found = string.find(actual, pattern)
  if found then
    failure("assert_not_match", msg, "expected '%s' to not match pattern '%s' but it does", actual, pattern)
  end
  return actual
end
traceback_hide(lunit.assert_not_match)
-- Fail unless calling func raises an error. Accepts assert_error(func) or
-- assert_error(msg, func); with one argument, that argument is the function.
function lunit.assert_error(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    msg, func = nil, msg
  end
  local functype = type(func)
  if functype ~= "function" then
    failure("assert_error", msg, "expected a function as last argument but was a "..functype)
  end
  if pcall(func) then
    failure("assert_error", msg, "error expected but no error occurred")
  end
end
traceback_hide(lunit.assert_error)
-- Fail unless calling func raises a string error in which the Lua pattern
-- 'pattern' is found. Accepts assert_error_match(pattern, func) or
-- assert_error_match(msg, pattern, func).
function lunit.assert_error_match(msg, pattern, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    msg, pattern, func = nil, msg, pattern
  end
  local patterntype = type(pattern)
  if patterntype ~= "string" then
    failure("assert_error_match", msg, "expected the pattern as a string but was a "..patterntype)
  end
  local functype = type(func)
  if functype ~= "function" then
    failure("assert_error_match", msg, "expected a function as last argument but was a "..functype)
  end
  local succeeded, err = pcall(func)
  if succeeded then
    failure("assert_error_match", msg, "error expected but no error occurred")
  end
  local errtype = type(err)
  if errtype ~= "string" then
    failure("assert_error_match", msg, "error as string expected but was a "..errtype)
  end
  if not string.find(err, pattern) then
    failure("assert_error_match", msg, "expected error '%s' to match pattern '%s' but doesn't", err, pattern)
  end
end
traceback_hide(lunit.assert_error_match)
-- Fail if calling func raises an error. Accepts assert_pass(func) or
-- assert_pass(msg, func); with one argument, that argument is the function.
--
-- The error value is turned into text before it reaches string.format:
-- Lua 5.1's "%s" only accepts strings, so a table error (such as a failed
-- lunit assertion inside func) or a nil error would otherwise make this
-- assertion die with a format error instead of reporting a failure.
-- A nested lunit failure is reported by its own message.
function assert_pass(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure( "assert_pass", msg, "expected a function as last argument but was a %s", functype )
  end
  local ok, errmsg = pcall(func)
  if not ok then
    local errtext
    if is_table(errmsg) and errmsg.type == __failure__ then
      errtext = errmsg.msg
    else
      errtext = tostring(errmsg)
    end
    failure( "assert_pass", msg, "no error expected but error was: '%s'", errtext )
  end
end
traceback_hide( assert_pass )
-- lunit.assert_typename functions
--
-- For each name in typenames defines lunit.assert_<name>(actual, msg),
-- which fails unless type(actual) is <name> and otherwise returns actual.
for i = 1, #typenames do
  local typename = typenames[i]
  local funcname = "assert_" .. typename
  local checker = function(actual, msg)
    stats.assertions = stats.assertions + 1
    local actualtype = type(actual)
    if actualtype ~= typename then
      failure(funcname, msg, typename.." expected but was a "..actualtype)
    end
    return actual
  end
  lunit[funcname] = checker
  traceback_hide(checker)
end
-- lunit.assert_not_typename functions
--
-- For each name in typenames defines lunit.assert_not_<name>(actual, msg),
-- which fails when type(actual) is <name>. Like the assert_<name> family it
-- returns actual on success, so e.g. assert_not_nil(x) can be used inline.
for _, typename in ipairs(typenames) do
  local assert_not_typename = "assert_not_"..typename
  lunit[assert_not_typename] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      failure( assert_not_typename, msg, typename.." not expected but was one" )
    end
    return actual
  end
  traceback_hide( lunit[assert_not_typename] )
end
-- Replace the module's stats table with a fresh one whose counters
-- (assertions, passed, failed, errors) are all zero.
function lunit.clearstats()
  stats = { assertions = 0, passed = 0, failed = 0, errors = 0 }
end
-- Forward declarations; both are assigned further below.
local report, reporterrobj

-- Current test runner: a table whose fields are looked up by event name
-- (see report), or nil for none.
local testrunner

-- Install newrunner (a table or nil) and return the previously installed
-- runner. Any other argument type raises an error.
function lunit.setrunner(newrunner)
  if not is_table(newrunner) and not is_nil(newrunner) then
    return error("lunit.setrunner: Invalid argument", 0)
  end
  local previous = testrunner
  testrunner = newrunner
  return previous
end
-- require() the runner module 'name' and install it via setrunner.
-- Returns the previously installed runner; raises an error if name is not
-- a string or the module cannot be loaded.
function lunit.loadrunner(name)
  if not is_string(name) then
    return error("lunit.loadrunner: Invalid argument", 0)
  end
  local loaded, result = pcall(require, name)
  if not loaded then
    return error("lunit.loadrunner: Can't load test runner: "..result, 0)
  end
  return lunit.setrunner(result)
end
-- Forward event and its arguments to the current runner's field of the
-- same name, if a runner is set and that field is a function. Errors raised
-- by the handler are swallowed by pcall.
function report(event, ...)
  local handler = testrunner and testrunner[event]
  if not is_function(handler) then
    return
  end
  pcall(handler, ...)
end
-- Count and report an error object raised while running testname of
-- testcase tcname. context is "setup", "test" or "teardown"; for setup and
-- teardown the actual function name is appended to the reported name.
-- Objects tagged __failure__ increment stats.failed and emit "fail";
-- anything else increments stats.errors and emits "err".
function reporterrobj(context, tcname, testname, errobj)
  local fullname = tcname .. "." .. testname
  if context == "setup" or context == "teardown" then
    local lookup = (context == "setup") and setupname or teardownname
    fullname = fullname .. ":" .. lookup(tcname, testname)
  end
  if errobj.type ~= __failure__ then
    stats.errors = stats.errors + 1
    report("err", fullname, errobj.msg, errobj.tb)
    return
  end
  stats.failed = stats.failed + 1
  report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
end
-- Generic-for iterator over the keys of t: yields exactly one value, the
-- key that follows k (nil when the traversal is done).
local function key_iter(t, k)
  local nextkey = next(t, k)
  return nextkey
end
-- Lookup function for registered testcases; forward-declared here and
-- assigned further down as testcase(tcname).
local testcase

-- Registered testcases, keyed by module name (m._NAME).
local _testcases = {}

-- Marks a module as a testcase.
-- Applied over a module from module("xyz", lunit.testcase).
-- Registers m under m._NAME and copies lunit, lunit.fail and every lunit
-- field whose name starts with "assert" or "is_" into m.
function lunit.testcase(m)
  orig_assert( is_table(m) )
  --orig_assert( m._M == m )
  orig_assert( is_string(m._NAME) )
  --orig_assert( is_string(m._PACKAGE) )

  _testcases[m._NAME] = m

  m.lunit = lunit
  m.fail = lunit.fail
  for funcname, func in pairs(lunit) do
    local wanted = string_sub(funcname, 1, 6) == "assert"
      or string_sub(funcname, 1, 3) == "is_"
    if wanted then
      m[funcname] = func
    end
  end
end
-- Iterator (testcasename) over all Testcases
--
-- Iterates over a snapshot of the registered names, so testcases defined
-- during the iteration do not confuse the traversal.
function lunit.testcases()
  local names = {}
  for name in pairs(_testcases) do
    names[name] = true
  end
  return key_iter, names, nil
end
-- Return the testcase module registered as tcname, or nil.
-- Assigns the 'testcase' local forward-declared before lunit.testcase.
testcase = function(tcname)
  return _testcases[tcname]
end
-- Find a function in a testcase by case-insensitive name: returns the
-- actual (string) key of the first function field of testcase tcname
-- whose lower-cased key equals name (name must already be lower case),
-- or nothing if there is none.
local function findfuncname(tcname, name)
  local tc = testcase(tcname)
  for key, value in pairs(tc) do
    if is_function(value) and is_string(key) and string.lower(key) == name then
      return key
    end
  end
end
-- Defines lunit.setupname(tcname) and lunit.teardownname(tcname): each
-- returns the actual name of the testcase's "setup" / "teardown" function
-- (matched case-insensitively), or nothing if the testcase has none.
for _, what in ipairs({ "setup", "teardown" }) do
  lunit[what .. "name"] = function(tcname)
    return findfuncname(tcname, what)
  end
end
-- Iterator over all test names in a testcase.
-- A test is any function field with a string key whose lower-cased name
-- starts or ends with "test". The names are collected first in case one of
-- the test functions creates a new global and throws off the iteration.
function lunit.tests(tcname)
  local names = {}
  for key, value in pairs(testcase(tcname)) do
    if is_string(key) and is_function(value) then
      local lowered = string.lower(key)
      if string.find(lowered, "^test") or string.find(lowered, "test$") then
        names[key] = true
      end
    end
  end
  return key_iter, names, nil
end
-- Run one test: the testcase's setup (if any), the test function, then its
-- teardown (if any). If setup fails, neither test nor teardown runs;
-- otherwise both run. Only when all of them succeed is stats.passed
-- incremented and a "pass" event reported. Errors go to reporterrobj.
function lunit.runtest(tcname, testname)
  orig_assert( is_string(tcname) )
  orig_assert( is_string(testname) )

  -- Protected call of func (a nil func counts as success). On error,
  -- report it in the given context and return false; otherwise true.
  local function callit(context, func)
    if not func then
      return true
    end
    local err = mypcall(func)
    if not err then
      return true
    end
    reporterrobj(context, tcname, testname, err)
    return false
  end
  traceback_hide(callit)

  report("run", tcname, testname)

  local tc = testcase(tcname)
  local setup = tc[setupname(tcname)]
  local test = tc[testname]
  local teardown = tc[teardownname(tcname)]

  if callit("setup", setup) then
    local test_ok = callit("test", test)
    local teardown_ok = callit("teardown", teardown)
    if test_ok and teardown_ok then
      stats.passed = stats.passed + 1
      report("pass", tcname, testname)
    end
  end
end
traceback_hide(lunit.runtest)
-- Reset the statistics, then run every test of every registered testcase,
-- with "begin" and "done" events around them. Returns the stats table.
-- NOTE(review): main() calls run(testpatterns), but this function takes no
-- parameters, so test patterns given with --test/-t are currently ignored.
function lunit.run()
  lunit.clearstats()
  report("begin")
  for tcname in lunit.testcases() do
    -- Run tests in the testcases
    for testname in lunit.tests(tcname) do
      lunit.runtest(tcname, testname)
    end
  end
  report("done")
  return lunit.stats
end
traceback_hide(lunit.run)
-- Like run() but executes no tests: resets the statistics, emits "begin"
-- and "done", and returns the (zeroed) stats table.
function lunit.loadonly()
  lunit.clearstats()
  report("begin")
  report("done")
  return lunit.stats
end
-- Convert an lunit wildcard pattern into an anchored Lua pattern: '*'
-- matches any run of characters, '?' any single character, and every other
-- Lua magic character is escaped so that it matches literally.
local lunitpat2luapat
do
  -- Replacement for each non-alphanumeric character that needs one;
  -- characters missing from this table are kept unchanged by gsub.
  local conv = {
    ["^"] = "%^",
    ["$"] = "%$",
    ["("] = "%(",
    [")"] = "%)",
    ["%"] = "%%",
    ["."] = "%.",
    ["["] = "%[",
    ["]"] = "%]",
    ["+"] = "%+",
    ["-"] = "%-",
    ["?"] = ".",
    ["*"] = ".*",
  }
  function lunitpat2luapat(str)
    local converted = string.gsub(str, "%W", conv)
    return "^" .. converted .. "$"
  end
end
-- True if map[name] is exactly true, or if any Lua pattern stored in map's
-- array part is found in name (string.find); false otherwise.
local function in_patternmap(map, name)
  if map[name] == true then
    return true
  end
  for _, pattern in ipairs(map) do
    if string.find(name, pattern) then
      return true
    end
  end
  return false
end
-- Called from 'lunit' shell script.
--
-- argv: array of command-line arguments (nil is treated as empty).
--   --loadonly          load the test files but run no tests
--   --runner, -r NAME   test runner module to load (default "lunit-console")
--   --test, -t PATTERN  add PATTERN to testpatterns (may be repeated)
--   --                  every remaining argument is a test file
--   anything else       a test file, loaded and executed immediately
-- Returns the stats table from loadonly() or run().
-- NOTE(review): testpatterns is handed to run(), but lunit.run in this file
-- declares no parameters, so --test/-t currently has no effect.
function main(argv)
  argv = argv or {}

  -- FIXME: Error handling and error messages aren't nice.

  local function checkarg(optname, arg)
    if not is_string(arg) then
      return error("lunit.main: option "..optname..": argument missing.", 0)
    end
  end

  -- Load and execute one test file. Errors are raised with level 0, like
  -- checkarg's, so the message is not prefixed with a position inside
  -- lunit.lua.
  local function loadtestcase(filename)
    if not is_string(filename) then
      return error("lunit.main: invalid argument", 0)
    end
    local chunk, err = loadfile(filename)
    if err then
      return error(err, 0)
    else
      chunk()
    end
  end

  local testpatterns = nil
  local doloadonly = false
  local runner = nil

  local i = 0
  while i < #argv do
    i = i + 1
    local arg = argv[i]
    if arg == "--loadonly" then
      doloadonly = true
    elseif arg == "--runner" or arg == "-r" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      runner = arg
    elseif arg == "--test" or arg == "-t" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      testpatterns = testpatterns or {}
      testpatterns[#testpatterns+1] = arg
    elseif arg == "--" then
      while i < #argv do
        i = i + 1; arg = argv[i]
        loadtestcase(arg)
      end
    else
      loadtestcase(arg)
    end
  end

  loadrunner(runner or "lunit-console")

  if doloadonly then
    return loadonly()
  else
    return run(testpatterns)
  end
end
-- Start with zeroed statistics as soon as the module is loaded.
lunit.clearstats()