Don't use the nonexistent myerror()
[lunit.git] / lunit.lua
blobc18b3c1d03191e216dd98ca19b895c9e7130e191
2 --[[--------------------------------------------------------------------------
4 This file is part of lunit 0.4.
6 For Details about lunit look at: http://www.mroth.net/lunit/
8 Author: Michael Roth <mroth@nessie.de>
10 Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>
12 Permission is hereby granted, free of charge, to any person
13 obtaining a copy of this software and associated documentation
14 files (the "Software"), to deal in the Software without restriction,
15 including without limitation the rights to use, copy, modify, merge,
16 publish, distribute, sublicense, and/or sell copies of the Software,
17 and to permit persons to whom the Software is furnished to do so,
18 subject to the following conditions:
20 The above copyright notice and this permission notice shall be
21 included in all copies or substantial portions of the Software.
23 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
26 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
27 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
28 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
29 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 --]]--------------------------------------------------------------------------
36 local orig_assert = assert
38 local pairs = pairs
39 local ipairs = ipairs
40 local next = next
41 local type = type
42 local error = error
43 local tostring = tostring
45 local string_sub = string.sub
46 local string_format = string.format
49 module("lunit", package.seeall) -- FIXME: Remove package.seeall
51 local lunit = _M
-- Private tag stored in the `type` field of assertion-failure error objects;
-- my_traceback compares against it to tell failures from ordinary errors.
local __failure__ = {}

-- Every name type() can return. Used to generate the is_*, assert_* and
-- assert_not_* function families.
local typenames = {
  "nil", "boolean", "number", "string",
  "table", "function", "thread", "userdata",
}

local traceback_hide -- Registers a function to be left out of tracebacks
local mypcall        -- Protected call to a function with own traceback

-- Set of functions omitted from tracebacks. Keys are weak, so registering a
-- function here does not by itself keep it alive.
local _tb_hide = setmetatable({}, { __mode = "k" })

--- Marks `func` so that its stack frames are skipped by my_traceback.
-- @param func function to hide
traceback_hide = function(func)
  _tb_hide[func] = true
end
-- Formats one debug.getinfo() record as a single traceback line
-- (layout adapted from ldblib.c).
local function format_frame(info)
  local location = string_format("%s:", info.short_src)
  if info.currentline > 0 then
    location = location .. string_format("%d:", info.currentline)
  end
  local description
  if info.namewhat ~= "" then
    description = string_format(" in function '%s'", info.name)
  elseif info.what == "main" then
    description = " in main chunk"
  elseif info.what == "C" or info.what == "tail" then
    description = " ?"
  else
    description = string_format(" in function <%s:%d>", info.short_src, info.linedefined)
  end
  return location .. description
end

--- Error handler installed by mypcall via xpcall.
-- Assertion failures (tables whose `type` is __failure__) are returned with
-- an added `where` field ("file:line"). Any other error value is replaced by
-- a new table { msg = tostring(err), tb = { <one line per frame> } } that
-- skips frames of functions registered through traceback_hide.
local function my_traceback(errobj)
  if is_table(errobj) and errobj.type == __failure__ then
    -- FIXME: Hardcoded integers are bad...
    -- Level 5 is meant to be the caller of the assert_* function
    -- (handler -> error -> failure -> assert_* -> caller).
    local info = debug.getinfo(5, "Sl")
    errobj.where = string_format("%s:%d", info.short_src, info.currentline)
    return errobj
  end

  local wrapped = { msg = tostring(errobj), tb = {} }
  -- debug.getinfo must be called from this frame so that level 2 skips
  -- only my_traceback itself.
  local level = 2
  local info = debug.getinfo(level, "Snlf")
  while is_table(info) do
    if not _tb_hide[info.func] then
      wrapped.tb[#wrapped.tb + 1] = format_frame(info)
    end
    level = level + 1
    info = debug.getinfo(level, "Snlf")
  end
  return wrapped
end
--- Calls `func` (with no arguments) under xpcall using my_traceback.
-- @param func function to call
-- @return nothing on success; the error object built by my_traceback on error
mypcall = function(func)
  orig_assert(is_function(func))
  local succeeded, errobj = xpcall(func, my_traceback)
  if succeeded then
    return
  end
  return errobj
end
traceback_hide(mypcall)
-- Type check functions: lunit.is_nil, lunit.is_boolean, ... are generated,
-- one per entry of typenames, each returning type(x) == <that name>.
local function make_type_check(typename)
  return function(x)
    return type(x) == typename
  end
end

for i = 1, #typenames do
  local typename = typenames[i]
  lunit["is_" .. typename] = make_type_check(typename)
end

-- Local aliases of the generated checks for use further down this file.
local is_nil      = lunit.is_nil
local is_boolean  = lunit.is_boolean
local is_number   = lunit.is_number
local is_string   = lunit.is_string
local is_table    = lunit.is_table
local is_function = lunit.is_function
local is_thread   = lunit.is_thread
local is_userdata = lunit.is_userdata
--- Raises an assertion-failure error object; never returns.
-- @param name       name of the assertion that failed
-- @param usermsg    optional message supplied by the test author
-- @param defaultmsg string.format template for lunit's own message
-- @param ...        arguments for defaultmsg
local function failure(name, usermsg, defaultmsg, ...)
  local message = string_format(defaultmsg, ...)
  -- Level 0: the location is recorded separately by my_traceback.
  error({ type = __failure__, name = name, msg = message, usermsg = usermsg }, 0)
end
traceback_hide(failure)
-- Types whose tostring() form is shown unadorned in assertion messages.
local plain_argtypes = { number = true, boolean = true, ["nil"] = true }

--- Renders a value for an assertion message: strings in single quotes,
-- numbers/booleans/nil as-is, anything else wrapped in square brackets.
-- @param arg any value
-- @return string
local function format_arg(arg)
  local argtype = type(arg)
  if argtype == "string" then
    return "'" .. arg .. "'"
  elseif plain_argtypes[argtype] then
    return tostring(arg)
  end
  return "[" .. tostring(arg) .. "]"
end
--- Counts an assertion and fails unconditionally.
-- @param msg optional user message attached to the failure
function fail(msg)
  local counters = stats
  counters.assertions = counters.assertions + 1
  failure("fail", msg, "failure")
end
traceback_hide(fail)
--- lunit's replacement for assert: fails when `assertion` is nil or false.
-- @param assertion value to test for truthiness
-- @param msg       optional user message
-- @return assertion (when it is truthy)
function assert(assertion, msg)
  local counters = stats
  counters.assertions = counters.assertions + 1
  if assertion then
    return assertion
  end
  failure("assert", msg, "assertion failed")
  return assertion
end
traceback_hide(assert)
--- Fails unless `actual` is exactly the boolean true.
-- @return actual
function assert_true(actual, msg)
  stats.assertions = stats.assertions + 1
  local actualtype = type(actual)
  if actualtype ~= "boolean" then
    failure("assert_true", msg, "true expected but was a " .. actualtype)
  elseif not actual then
    failure("assert_true", msg, "true expected but was false")
  end
  return actual
end
traceback_hide(assert_true)
--- Fails unless `actual` is exactly the boolean false.
-- @return actual
function assert_false(actual, msg)
  stats.assertions = stats.assertions + 1
  local actualtype = type(actual)
  if actualtype ~= "boolean" then
    failure("assert_false", msg, "false expected but was a " .. actualtype)
  elseif actual then
    failure("assert_false", msg, "false expected but was true")
  end
  return actual
end
traceback_hide(assert_false)
--- Fails unless expected == actual (Lua equality, i.e. reference equality
-- for tables unless __eq applies).
-- @return actual
function assert_equal(expected, actual, msg)
  stats.assertions = stats.assertions + 1
  local equal = (expected == actual)
  if not equal then
    failure("assert_equal", msg, "expected %s but was %s",
            format_arg(expected), format_arg(actual))
  end
  return actual
end
traceback_hide(assert_equal)
--- Fails when unexpected == actual.
-- @return actual
function assert_not_equal(unexpected, actual, msg)
  stats.assertions = stats.assertions + 1
  local equal = (unexpected == actual)
  if equal then
    failure("assert_not_equal", msg, "%s not expected but was one",
            format_arg(unexpected))
  end
  return actual
end
traceback_hide(assert_not_equal)
--- Fails unless the string `actual` matches the Lua pattern `pattern`
-- (tested with string.find, so the match is unanchored unless the pattern
-- itself uses ^ or $).
-- @return actual
function assert_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  local patterntype = type(pattern)
  local actualtype = type(actual)
  if patterntype ~= "string" then
    failure("assert_match", msg, "expected the pattern as a string but was a " .. patterntype)
  end
  if actualtype ~= "string" then
    failure("assert_match", msg, "expected a string to match pattern '%s' but was a %s",
            pattern, actualtype)
  end
  local found = string.find(actual, pattern)
  if not found then
    failure("assert_match", msg, "expected '%s' to match pattern '%s' but doesn't",
            actual, pattern)
  end
  return actual
end
traceback_hide(assert_match)
--- Fails when the string `actual` matches the Lua pattern `pattern`
-- (tested with string.find).
-- @return actual
function assert_not_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  local patterntype = type(pattern)
  local actualtype = type(actual)
  if patterntype ~= "string" then
    failure("assert_not_match", msg, "expected the pattern as a string but was a " .. patterntype)
  end
  if actualtype ~= "string" then
    failure("assert_not_match", msg, "expected a string to not match pattern '%s' but was a %s",
            pattern, actualtype)
  end
  local found = string.find(actual, pattern)
  if found then
    failure("assert_not_match", msg, "expected '%s' to not match pattern '%s' but it does",
            actual, pattern)
  end
  return actual
end
traceback_hide(assert_not_match)
--- Fails unless calling `func` raises an error.
-- Callable as assert_error(func) or assert_error(msg, func).
function assert_error(msg, func)
  stats.assertions = stats.assertions + 1
  -- Single-argument form: the only argument is the function.
  if func == nil then
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure("assert_error", msg, "expected a function as last argument but was a " .. functype)
  end
  if pcall(func) then
    failure("assert_error", msg, "error expected but no error occurred")
  end
end
traceback_hide(assert_error)
--- Fails if calling `func` raises an error.
-- Callable as assert_pass(func) or assert_pass(msg, func).
function assert_pass(msg, func)
  stats.assertions = stats.assertions + 1
  -- Single-argument form: the only argument is the function.
  if func == nil then
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure("assert_pass", msg, "expected a function as last argument but was a %s", functype)
  end
  local ok, errmsg = pcall(func)
  if not ok then
    -- The error value may be a table, nil or another non-string. Lua 5.1's
    -- string.format("%s") rejects those, which would raise a confusing
    -- "bad argument to 'format'" error instead of this failure.
    failure("assert_pass", msg, "no error expected but error was: '%s'", tostring(errmsg))
  end
end
traceback_hide(assert_pass)
-- lunit.assert_<typename> functions: one per entry of typenames, each
-- failing unless type(actual) is that name, and returning actual.
local function make_type_assertion(typename)
  local assertname = "assert_" .. typename
  local check = function(actual, msg)
    stats.assertions = stats.assertions + 1
    local actualtype = type(actual)
    if actualtype ~= typename then
      failure(assertname, msg, typename .. " expected but was a " .. actualtype)
    end
    return actual
  end
  return assertname, check
end

for _, typename in ipairs(typenames) do
  local assertname, check = make_type_assertion(typename)
  lunit[assertname] = check
  traceback_hide(check)
end
-- lunit.assert_not_<typename> functions: one per entry of typenames, each
-- failing when type(actual) is that name.
for _, typename in ipairs(typenames) do
  local assert_not_name = "assert_not_" .. typename
  local check = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      -- Report under this assertion's own name ("assert_not_<typename>").
      failure(assert_not_name, msg, typename .. " not expected but was one")
    end
    -- Return the checked value, like the assert_<typename> functions do.
    return actual
  end
  lunit[assert_not_name] = check
  traceback_hide(check)
end
--- Resets the module-level `stats` counters to zero.
function lunit.clearstats()
  stats = { assertions = 0, passed = 0, failed = 0, errors = 0 }
end
local report, reporterrobj

local testrunner -- Currently installed runner table, or nil

--- Installs a new test runner.
-- @param newrunner runner table, or nil to remove the current one
-- @return the previously installed runner
function lunit.setrunner(newrunner)
  local valid = is_table(newrunner) or is_nil(newrunner)
  if not valid then
    return error("lunit.setrunner: Invalid argument", 0)
  end
  local previous = testrunner
  testrunner = newrunner
  return previous
end
--- Loads the runner module `name` with require and installs it.
-- @param name module name of the runner
-- @return the previously installed runner
function lunit.loadrunner(name)
  if not is_string(name) then
    return error("lunit.loadrunner: Invalid argument", 0)
  end
  local loaded, result = pcall(require, name)
  if loaded then
    return setrunner(result)
  end
  return error("lunit.loadrunner: Can't load test runner: " .. result, 0)
end
--- Forwards `event` and its arguments to the runner's handler of the same
-- name, if the runner defines one. Errors raised by the handler are
-- swallowed on purpose (reporting is best-effort).
report = function(event, ...)
  if not testrunner then
    return
  end
  local handler = testrunner[event]
  if is_function(handler) then
    pcall(handler, ...)
  end
end
--- Counts and reports an error object produced by mypcall.
-- @param context  "setup", "test" or "teardown"
-- @param tcname   testcase name
-- @param testname test name
-- @param errobj   failure object (type == __failure__) or wrapped error
reporterrobj = function(context, tcname, testname, errobj)
  local fullname = tcname .. "." .. testname
  if context == "setup" or context == "teardown" then
    local namefunc = (context == "setup") and setupname or teardownname
    fullname = fullname .. ":" .. namefunc(tcname, testname)
  end

  if errobj.type ~= __failure__ then
    stats.errors = stats.errors + 1
    report("err", fullname, errobj.msg, errobj.tb)
    return
  end
  stats.failed = stats.failed + 1
  report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
end
-- Stateless iterator yielding only the keys of `t` (truncates next's
-- second return value).
local function key_iter(t, k)
  local nextkey = next(t, k)
  return nextkey
end
local testcase

-- Registered testcases, keyed by module name (_NAME).
local _testcases = {}

--- Marks a module as a testcase.
-- Applied over a module from module("xyz", lunit.testcase).
-- Registers the module and copies lunit, fail, and every lunit function
-- whose name starts with "assert" or "is_" into it.
-- @param m module table (must have a string _NAME)
function lunit.testcase(m)
  orig_assert(is_table(m))
  -- (m._M == m and a string m._PACKAGE are not checked.)
  orig_assert(is_string(m._NAME))

  _testcases[m._NAME] = m

  m.lunit = lunit
  m.fail = lunit.fail
  for funcname, func in pairs(lunit) do
    local exported = string_sub(funcname, 1, 6) == "assert"
                  or string_sub(funcname, 1, 3) == "is_"
    if exported then
      m[funcname] = func
    end
  end
end
--- Iterator over the names of all registered testcases.
-- Iterates over a snapshot, so testcases registered while iterating do not
-- disturb the traversal.
function lunit.testcases()
  local snapshot = {}
  for name in pairs(_testcases) do
    snapshot[name] = true
  end
  return key_iter, snapshot, nil
end
-- Returns the registered testcase module named `tcname`, or nil.
testcase = function(tcname)
  local tc = _testcases[tcname]
  return tc
end
--- Finds a function in a testcase whose name matches `name`
-- case-insensitively (e.g. "SetUp" matches "setup").
-- Returns the KEY (function name), not the function: callers index the
-- testcase with it (tc[setupname(tcname)]) and concatenate it into
-- report names, so returning the function broke both.
-- @param tcname testcase name
-- @param name   lower-case function name to look for
-- @return the matching key, or nil if there is none
local function findfuncname(tcname, name)
  for key, value in pairs(testcase(tcname)) do
    if is_string(key) and is_function(value) and string.lower(key) == name then
      return key
    end
  end
end
-- Looks up the testcase's setup function via findfuncname
-- (case-insensitive match on "setup").
lunit.setupname = function(tcname)
  return findfuncname(tcname, "setup")
end

-- Looks up the testcase's teardown function via findfuncname
-- (case-insensitive match on "teardown").
lunit.teardownname = function(tcname)
  return findfuncname(tcname, "teardown")
end
--- Iterator over all test names in a testcase.
-- A test is any function-valued string key whose lower-cased name starts or
-- ends with "test". Names are collected first, in case one of the test
-- functions creates a new global and throws off the iteration.
function lunit.tests(tcname)
  local testnames = {}
  for key, value in pairs(testcase(tcname)) do
    if is_string(key) and is_function(value) then
      local lowered = string.lower(key)
      local looks_like_test = string_sub(lowered, 1, 4) == "test"
                           or string_sub(lowered, -4) == "test"
      if looks_like_test then
        testnames[key] = true
      end
    end
  end
  return key_iter, testnames, nil
end
--- Runs one test: setup, then (if setup passed) the test and teardown.
-- Teardown still runs when the test itself failed. Counts a pass and
-- reports "pass" only if all three phases succeeded.
function lunit.runtest(tcname, testname)
  orig_assert(is_string(tcname))
  orig_assert(is_string(testname))

  -- Runs one phase. A nil `func` counts as success; an error is reported
  -- through reporterrobj and makes the phase fail.
  local function callit(context, func)
    if not func then
      return true
    end
    local err = mypcall(func)
    if not err then
      return true
    end
    reporterrobj(context, tcname, testname, err)
    return false
  end
  traceback_hide(callit)

  report("run", tcname, testname)

  local tc = testcase(tcname)
  local setup = tc[setupname(tcname)]
  local test = tc[testname]
  local teardown = tc[teardownname(tcname)]

  local setup_ok = callit("setup", setup)
  local test_ok, teardown_ok = false, false
  if setup_ok then
    test_ok = callit("test", test)
    teardown_ok = callit("teardown", teardown)
  end

  if setup_ok and test_ok and teardown_ok then
    stats.passed = stats.passed + 1
    report("pass", tcname, testname)
  end
end
traceback_hide(runtest)
--- Resets the counters and runs every test of every registered testcase,
-- bracketed by "begin" and "done" reports.
-- @return the stats table
function lunit.run()
  lunit.clearstats()
  report("begin")
  for tcname in lunit.testcases() do
    for name in lunit.tests(tcname) do
      lunit.runtest(tcname, name)
    end
  end
  report("done")
  return stats
end
traceback_hide(lunit.run)
--- Resets the counters and emits "begin"/"done" without running any test.
-- @return the (freshly cleared) stats table
function lunit.loadonly()
  lunit.clearstats()
  report("begin")
  report("done")
  return stats
end
local lunitpat2luapat

-- Translation of non-alphanumeric characters in lunit (glob-like) patterns:
-- "?" matches one character, "*" any run, and Lua's magic characters are
-- escaped so they match literally. Other characters pass through unchanged.
local conv = { ["?"] = ".", ["*"] = ".*" }
for magic in ("^$()%.[]+-"):gmatch(".") do
  conv[magic] = "%" .. magic
end

--- Converts an lunit pattern into an anchored Lua pattern.
-- @param str lunit pattern, e.g. "test*"
-- @return Lua pattern, e.g. "^test.*$"
function lunitpat2luapat(str)
  local body = string.gsub(str, "%W", conv)
  return "^" .. body .. "$"
end
--- Tests whether `name` is selected by `map`.
-- `map[name] == true` selects by exact name; otherwise each Lua pattern in
-- the array part of `map` is tried with string.find.
-- @return true or false
local function in_patternmap(map, name)
  if map[name] == true then
    return true
  end
  for _, luapattern in ipairs(map) do
    if string.find(name, luapattern) then
      return true
    end
  end
  return false
end
-- Called from 'lunit' shell script.
--
-- Parses the command line, loads the test files it names, installs a test
-- runner and then runs the tests (or only loads them).
--
-- Options:
--   --loadonly          load test files but do not run any test
--   --runner, -r NAME   runner module for lunit.loadrunner
--                       (default "lunit-console")
--   --test, -t PATTERN  collect a test pattern (may be repeated)
--   --                  treat every following argument as a test file
-- Any other argument is loaded as a test file.
--
-- @param argv array of command-line arguments (nil is treated as empty)
-- @return the stats table returned by run() or loadonly()
function main(argv)
  argv = argv or {}

  -- FIXME: Error handling and error messages aren't nice.

  -- Raises an error unless option `optname` was given a string argument.
  local function checkarg(optname, arg)
    if not is_string(arg) then
      return error("lunit.main: option "..optname..": argument missing.", 0)
    end
  end

  -- Compiles and runs one test file. Errors use level 0 like the rest of
  -- lunit, so lunit's own source position is not prepended to the message
  -- (loadfile's message already identifies the file).
  local function loadtestcase(filename)
    if not is_string(filename) then
      return error("lunit.main: invalid argument", 0)
    end
    local chunk, err = loadfile(filename)
    if not chunk then
      return error(err, 0)
    end
    chunk()
  end

  local testpatterns = nil
  local doloadonly = false
  local runner = nil

  local i = 0
  while i < #argv do
    i = i + 1
    local arg = argv[i]
    if arg == "--loadonly" then
      doloadonly = true
    elseif arg == "--runner" or arg == "-r" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      runner = arg
    elseif arg == "--test" or arg == "-t" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      testpatterns = testpatterns or {}
      testpatterns[#testpatterns+1] = arg
    elseif arg == "--" then
      -- Everything after "--" is a test file, even if it looks like an option.
      while i < #argv do
        i = i + 1; arg = argv[i]
        loadtestcase(arg)
      end
    else
      loadtestcase(arg)
    end
  end

  loadrunner(runner or "lunit-console")

  if doloadonly then
    return loadonly()
  else
    -- NOTE(review): lunit.run in this file declares no parameters, so the
    -- collected testpatterns are currently ignored -- TODO implement filtering.
    return run(testpatterns)
  end
end
-- Create the `stats` counters at load time so assertions can count even
-- before lunit.run() or lunit.loadonly() is called.
lunit.clearstats()