Bug fix if a test function creates a new variable in the module table
[lunit.git] / lunit.lua
blob1bddbca5450fd607b7c66945da6f7ecadfd482ac
2 --[[--------------------------------------------------------------------------
4 This file is part of lunit 0.4.
6 For Details about lunit look at: http://www.mroth.net/lunit/
8 Author: Michael Roth <mroth@nessie.de>
10 Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>
12 Permission is hereby granted, free of charge, to any person
13 obtaining a copy of this software and associated documentation
14 files (the "Software"), to deal in the Software without restriction,
15 including without limitation the rights to use, copy, modify, merge,
16 publish, distribute, sublicense, and/or sell copies of the Software,
17 and to permit persons to whom the Software is furnished to do so,
18 subject to the following conditions:
20 The above copyright notice and this permission notice shall be
21 included in all copies or substantial portions of the Software.
23 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
26 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
27 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
28 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
29 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 --]]--------------------------------------------------------------------------
36 local orig_assert = assert
38 local pairs = pairs
39 local ipairs = ipairs
40 local next = next
41 local type = type
42 local error = error
44 local string_sub = string.sub
45 local string_format = string.format
-- Declare the "lunit" module. package.seeall makes the standard globals
-- visible inside the module environment.
module("lunit", package.seeall)     -- FIXME: Remove package.seeall

-- Handle to the module table itself.
local lunit = _M

-- Unique table used as a type tag on error objects raised by failed assertions.
local __failure__ = {}

-- Every Lua type name; drives generation of the is_* and assert_* helpers.
local typenames = {
  "nil", "boolean", "number", "string",
  "table", "function", "thread", "userdata"
}

-- Forward declarations:
--   traceback_hide: marks a function so it is left out of tracebacks
--   mypcall:        protected call using lunit's own traceback handler
local traceback_hide, mypcall

-- Set of functions hidden from tracebacks. Keys are weak.
local _tb_hide = setmetatable( {}, { __mode = "k" } )
-- Registers 'func' in the _tb_hide set so that my_traceback skips its
-- stack frames when building a traceback.
function traceback_hide(func)
  local hidden = _tb_hide
  hidden[func] = true
end
-- Error handler passed to xpcall by mypcall.
--
-- For an error object tagged with __failure__ it records the source
-- position of stack level 5 in errobj.where and returns the same object.
-- For any other error value it returns a new table { msg = ..., tb = ... }
-- where tb is an array of traceback lines, omitting frames of functions
-- registered through traceback_hide.
local function my_traceback(errobj)
  if is_table(errobj) and errobj.type == __failure__ then
    local info = debug.getinfo(5, "Sl")   -- FIXME: Hardcoded integers are bad...
    errobj.where = string_format( "%s:%d", info.short_src, info.currentline)
    return errobj
  end

  local wrapped = { msg = tostring(errobj) }
  local tb = {}
  wrapped.tb = tb

  -- Walk the stack starting at level 2 until getinfo stops returning a table.
  local level = 2
  local info = debug.getinfo(level, "Snlf")
  while is_table(info) do
    if not _tb_hide[info.func] then
      -- Line layout ripped from ldblib.c...
      local text = string_format("%s:", info.short_src)
      if info.currentline > 0 then
        text = text .. string_format("%d:", info.currentline)
      end
      if info.namewhat ~= "" then
        text = text .. string_format(" in function '%s'", info.name)
      elseif info.what == "main" then
        text = text .. " in main chunk"
      elseif info.what == "C" or info.what == "tail" then
        text = text .. " ?"
      else
        text = text .. string_format(" in function <%s:%d>", info.short_src, info.linedefined)
      end
      tb[#tb+1] = text
    end
    level = level + 1
    info = debug.getinfo(level, "Snlf")
  end

  return wrapped
end
-- Calls 'func' in protected mode with my_traceback as the message handler.
-- Returns nothing when the call succeeds, otherwise the error object
-- produced by my_traceback.
function mypcall(func)
  orig_assert( is_function(func) )
  local succeeded, err = xpcall(func, my_traceback)
  if succeeded then
    return
  end
  return err
end
traceback_hide(mypcall)
-- Type check functions
--
-- Generates lunit.is_nil, lunit.is_boolean, ... one predicate per entry
-- of 'typenames'; each returns true when type(x) equals that name.
for idx = 1, #typenames do
  local wanted = typenames[idx]
  lunit["is_"..wanted] = function(x)
    return type(x) == wanted
  end
end

-- Cache the generated predicates in locals for use in the rest of the file.
local is_nil, is_boolean, is_number, is_string =
      is_nil, is_boolean, is_number, is_string
local is_table, is_function, is_thread, is_userdata =
      is_table, is_function, is_thread, is_userdata
-- Raises a failed-assertion error object.
--
--   name        name of the assertion that failed (e.g. "assert_equal")
--   usermsg     optional message supplied by the test author
--   defaultmsg  string.format template for the built-in message
--   ...         values substituted into defaultmsg
--
-- The raised value is a table tagged with type = __failure__ and is
-- thrown with level 0 so no position prefix is added.
--
-- Lua 5.1's string.format rejects non-string/number values for "%s"
-- (nil, booleans, tables, functions, ...). Such values are converted
-- with tostring first, so that e.g. assert_equal(1, nil) is reported
-- as a failure instead of raising an unrelated formatting error.
local function failure(name, usermsg, defaultmsg, ...)
  local nargs = select("#", ...)    -- true count, including trailing nils
  local args = { ... }
  for idx = 1, nargs do
    local argtype = type(args[idx])
    if argtype ~= "string" and argtype ~= "number" then
      args[idx] = tostring(args[idx])
    end
  end

  local errobj = {
    type    = __failure__,
    name    = name,
    msg     = string_format(defaultmsg, unpack(args, 1, nargs)),
    usermsg = usermsg
  }
  error(errobj, 0)
end
traceback_hide( failure )
-- Unconditionally fails the running test with the optional user message
-- 'msg'. Counts as one assertion.
function fail(msg)
  local counters = stats
  counters.assertions = counters.assertions + 1
  failure( "fail", msg, "failure" )
end
traceback_hide( fail )
-- Fails when 'assertion' is false or nil; otherwise returns 'assertion'.
-- 'msg' is an optional user message attached to the failure.
function assert(assertion, msg)
  stats.assertions = stats.assertions + 1
  if assertion then
    return assertion
  end
  failure( "assert", msg, "assertion failed" )
  return assertion
end
traceback_hide( assert )
-- Passes only when 'actual' is the boolean true, and then returns it.
-- Non-boolean values and false fail with distinct messages.
function assert_true(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual == true then
    return actual
  end
  local kind = type(actual)
  if kind ~= "boolean" then
    failure( "assert_true", msg, "true expected but was a "..kind )
  end
  failure( "assert_true", msg, "true expected but was false" )
  return actual
end
traceback_hide( assert_true )
-- Passes only when 'actual' is the boolean false, and then returns it.
-- Non-boolean values and true fail with distinct messages.
function assert_false(actual, msg)
  stats.assertions = stats.assertions + 1
  if actual == false then
    return actual
  end
  local kind = type(actual)
  if kind ~= "boolean" then
    failure( "assert_false", msg, "false expected but was a "..kind )
  end
  failure( "assert_false", msg, "false expected but was true" )
  return actual
end
traceback_hide( assert_false )
-- Fails unless expected == actual; returns 'actual' on success.
function assert_equal(expected, actual, msg)
  stats.assertions = stats.assertions + 1
  if expected == actual then
    return actual
  end
  failure( "assert_equal", msg, "expected '%s' but was '%s'", expected, actual )
  return actual
end
traceback_hide( assert_equal )
-- Fails when unexpected == actual; returns 'actual' on success.
function assert_not_equal(unexpected, actual, msg)
  stats.assertions = stats.assertions + 1
  if unexpected ~= actual then
    return actual
  end
  failure( "assert_not_equal", msg, "'%s' not expected but was one", unexpected )
  return actual
end
traceback_hide( assert_not_equal )
-- Fails unless both arguments are strings and 'actual' matches the Lua
-- pattern 'pattern' (via string.find). Returns 'actual' on success.
function assert_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  if type(pattern) ~= "string" then
    failure( "assert_match", msg, "expected the pattern as a string but was '%s'", pattern )
  end
  if type(actual) ~= "string" then
    failure( "assert_match", msg, "expected a string to match pattern '%s' but was '%s'", pattern, actual )
  end
  local found = string.find(actual, pattern)
  if not found then
    failure( "assert_match", msg, "expected '%s' to match pattern '%s' but doesn't", actual, pattern )
  end
  return actual
end
traceback_hide( assert_match )
-- Fails unless both arguments are strings and 'actual' does NOT match the
-- Lua pattern 'pattern' (via string.find). Returns 'actual' on success.
function assert_not_match(pattern, actual, msg)
  stats.assertions = stats.assertions + 1
  if type(pattern) ~= "string" then
    failure( "assert_not_match", msg, "expected the pattern as a string but was '%s'", pattern )
  end
  if type(actual) ~= "string" then
    failure( "assert_not_match", msg, "expected a string to not match pattern '%s' but was '%s'", pattern, actual )
  end
  local found = string.find(actual, pattern)
  if found then
    failure( "assert_not_match", msg, "expected '%s' to not match pattern '%s' but it does", actual, pattern )
  end
  return actual
end
traceback_hide( assert_not_match )
-- Fails unless calling 'func' raises an error.
-- May be called as assert_error(func) or assert_error(msg, func); with a
-- single argument that argument is taken as the function.
function assert_error(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    -- Only one argument given: it is the function, there is no message.
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure( "assert_error", msg, "expected a function as last argument but was a "..functype )
  end
  if pcall(func) then
    failure( "assert_error", msg, "error expected but no error occurred" )
  end
end
traceback_hide( assert_error )
-- Fails when calling 'func' raises an error.
-- May be called as assert_pass(func) or assert_pass(msg, func); with a
-- single argument that argument is taken as the function.
function assert_pass(msg, func)
  stats.assertions = stats.assertions + 1
  if func == nil then
    -- Only one argument given: it is the function, there is no message.
    func, msg = msg, nil
  end
  local functype = type(func)
  if functype ~= "function" then
    failure( "assert_pass", msg, "expected a function as last argument but was a %s", functype )
  end
  local succeeded, errmsg = pcall(func)
  if succeeded then
    return
  end
  failure( "assert_pass", msg, "no error expected but error was: '%s'", errmsg )
end
traceback_hide( assert_pass )
-- lunit.assert_typename functions
--
-- Generates lunit.assert_nil, lunit.assert_boolean, ... Each one fails
-- unless type(actual) equals the type name, and returns 'actual' otherwise.
for idx = 1, #typenames do
  local typename = typenames[idx]
  local assert_typename = "assert_"..typename

  local function checker(actual, msg)
    stats.assertions = stats.assertions + 1
    local actualtype = type(actual)
    if actualtype == typename then
      return actual
    end
    failure( assert_typename, msg, typename.." expected but was a "..actualtype )
    return actual
  end

  lunit[assert_typename] = checker
  traceback_hide( checker )
end
-- lunit.assert_not_typename functions
--
-- Generates lunit.assert_not_nil, lunit.assert_not_boolean, ... Each one
-- fails when type(actual) equals the type name.
for _, typename in ipairs(typenames) do
  local assert_not_name = "assert_not_"..typename
  lunit[assert_not_name] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      -- Pass the generated assertion name; the previous code referenced an
      -- undefined variable here, so failures were reported without a name.
      failure( assert_not_name, msg, typename.." not expected but was one" )
    end
  end
  traceback_hide( lunit[assert_not_name] )
end
-- Replaces the module-level 'stats' table with fresh zeroed counters for
-- assertions, passed tests, failed tests and errors.
function lunit.clearstats()
  local fresh = { assertions = 0, passed = 0, failed = 0, errors = 0 }
  stats = fresh
end
-- Forward declarations of the reporting helpers defined further below.
local report, reporterrobj

-- Currently installed test runner (a table of event handlers) or nil.
local testrunner

-- Installs 'newrunner' (a table, or nil to remove the runner) and returns
-- the previously installed runner. Any other argument type raises an error.
function lunit.setrunner(newrunner)
  if not is_table(newrunner) and not is_nil(newrunner) then
    return error("lunit.setrunner: Invalid argument", 0)
  end
  local previous
  previous, testrunner = testrunner, newrunner
  return previous
end
-- Loads the runner module 'name' with require and installs it through
-- setrunner. Raises an error when 'name' is not a string or the module
-- cannot be loaded. Returns whatever setrunner returns.
function lunit.loadrunner(name)
  if not is_string(name) then
    return error("lunit.loadrunner: Invalid argument", 0)
  end
  local loaded, runner = pcall( require, name )
  if loaded then
    return setrunner(runner)
  end
  -- On failure 'runner' holds the error message from require.
  return error("lunit.loadrunner: Can't load test runner: "..runner, 0)
end
-- Forwards 'event' and its arguments to the handler of that name on the
-- installed test runner, if there is one. The handler runs under pcall and
-- its result (including any error) is discarded.
function report(event, ...)
  if not testrunner then
    return
  end
  local handler = testrunner[event]
  if not is_function(handler) then
    return
  end
  pcall(handler, ...)
end
-- Records and reports an error object produced while running a test.
--
--   context   "setup", "test" or "teardown"
--   tcname    testcase name
--   testname  test name
--   errobj    error object as returned by mypcall
--
-- Objects tagged with __failure__ count as failures ("fail" event); all
-- others count as errors ("err" event). For setup/teardown the name of
-- the corresponding function is appended to the reported name.
function reporterrobj(context, tcname, testname, errobj)
  local fullname = tcname .. "." .. testname
  if context == "setup" then
    fullname = fullname .. ":" .. setupname(tcname, testname)
  elseif context == "teardown" then
    fullname = fullname .. ":" .. teardownname(tcname, testname)
  end

  if errobj.type ~= __failure__ then
    stats.errors = stats.errors + 1
    report("err", fullname, errobj.msg, errobj.tb)
    return
  end

  stats.failed = stats.failed + 1
  report("fail", fullname, errobj.where, errobj.msg, errobj.usermsg)
end
-- Iterator step function yielding only the keys of table 't':
-- returns the key following 'k' (exactly one value), or nil at the end.
local function key_iter(t, k)
  local following_key = next(t, k)
  return following_key
end
-- Forward declaration of the registry lookup function.
local testcase

-- Array with all registered testcases
local _testcases = {}

-- Marks a module as a testcase.
-- Applied over a module from module("xyz", lunit.testcase).
--
-- Registers 'm' under m._NAME and copies lunit itself, fail, and every
-- lunit member whose name starts with "assert" or "is_" into the module.
function lunit.testcase(m)
  orig_assert( is_table(m) )
  --orig_assert( m._M == m )
  orig_assert( is_string(m._NAME) )
  --orig_assert( is_string(m._PACKAGE) )

  -- Register the module as a testcase
  _testcases[m._NAME] = m

  -- Import lunit, fail, assert* and is_* function to the module/testcase
  m.lunit = lunit
  m.fail = lunit.fail
  for membername, member in pairs(lunit) do
    local is_assertion = ( string_sub(membername, 1, 6) == "assert" )
    local is_predicate = ( string_sub(membername, 1, 3) == "is_" )
    if is_assertion or is_predicate then
      m[membername] = member
    end
  end
end
-- Iterator (testcasename) over all Testcases
--
-- Iterates over a snapshot of the registered names, so testcases defined
-- while iterating do not confuse the iterator.
function lunit.testcases()
  local snapshot = {}
  for registered_name in pairs(_testcases) do
    snapshot[registered_name] = true
  end
  return key_iter, snapshot, nil
end
-- Returns the testcase module registered under 'tcname', or nil.
function testcase(tcname)
  local registered = _testcases[tcname]
  return registered
end
-- Finds a function in a testcase, case-insensitively.
--
--   tcname  name of a registered testcase
--   name    lower-case function name to look for
--
-- Returns the actual key (the function's NAME as spelled in the testcase)
-- of the first string-keyed function whose lower-cased key equals 'name',
-- or nothing when there is none.
--
-- Callers index the testcase with the result and concatenate it into
-- report names, so the key must be returned, not the function value.
local function findfuncname(tcname, name)
  for key, value in pairs(testcase(tcname)) do
    if is_string(key) and is_function(value) and string.lower(key) == name then
      return key
    end
  end
end

-- Name of the setup function of testcase 'tcname', if it has one.
function lunit.setupname(tcname)
  return findfuncname(tcname, "setup")
end

-- Name of the teardown function of testcase 'tcname', if it has one.
function lunit.teardownname(tcname)
  return findfuncname(tcname, "teardown")
end
-- Iterator over all test names in a testcase.
-- Have to collect the names first in case one of the test
-- functions creates a new global and throws off the iteration.
--
-- A test is any string-keyed function whose lower-cased name starts or
-- ends with "test".
function lunit.tests(tcname)
  local collected = {}
  for key, value in pairs(testcase(tcname)) do
    if is_string(key) and is_function(value) then
      local lowered = string.lower(key)
      local has_prefix = ( string.sub(lowered, 1, 4) == "test" )
      local has_suffix = ( string.sub(lowered, -4) == "test" )
      if has_prefix or has_suffix then
        collected[key] = true
      end
    end
  end
  return key_iter, collected, nil
end
-- Runs a single test: setup, the test function, then teardown.
--
-- Emits a "run" event first. The test and teardown only run when setup
-- succeeded (or is absent). When every phase succeeded the passed counter
-- is incremented and a "pass" event is emitted; problems are reported by
-- reporterrobj.
function lunit.runtest(tcname, testname)
  orig_assert( is_string(tcname) )
  orig_assert( is_string(testname) )

  -- Runs one phase. A missing function counts as success. Returns false
  -- after reporting when the call produced an error object.
  local function callit(context, func)
    if not func then
      return true
    end
    local err = mypcall(func)
    if not err then
      return true
    end
    reporterrobj(context, tcname, testname, err)
    return false
  end
  traceback_hide(callit)

  report("run", tcname, testname)

  local tc       = testcase(tcname)
  local setup    = tc[setupname(tcname)]
  local test     = tc[testname]
  local teardown = tc[teardownname(tcname)]

  local setup_ok = callit( "setup", setup )
  local test_ok, teardown_ok = false, false
  if setup_ok then
    test_ok     = callit( "test", test )
    teardown_ok = callit( "teardown", teardown )
  end

  if setup_ok and test_ok and teardown_ok then
    stats.passed = stats.passed + 1
    report("pass", tcname, testname)
  end
end
traceback_hide(runtest)
-- Runs every test of every registered testcase.
-- Resets the statistics, emits "begin", runs the tests and emits "done".
function lunit.run()
  clearstats()
  report("begin")
  for tcname in lunit.testcases() do
    -- Run tests in the testcases
    for tname in lunit.tests(tcname) do
      runtest(tcname, tname)
    end
  end
  report("done")
end
traceback_hide(run)
-- Load-only mode: resets the statistics and emits the "begin" and "done"
-- events without running any test.
function lunit.loadonly()
  clearstats()
  for _, event in ipairs({ "begin", "done" }) do
    report(event)
  end
end
-- Converts a lunit wildcard pattern into an anchored Lua pattern:
-- "?" becomes ".", "*" becomes ".*", the Lua magic characters listed
-- below are escaped with "%", and the result is wrapped in "^" ... "$".
local lunitpat2luapat
do
  -- Wildcards first, then every magic character mapped to its escaped form.
  local conv = { ["?"] = ".", ["*"] = ".*" }
  local magic = "^$()%.[]+-"
  for pos = 1, #magic do
    local ch = string.sub(magic, pos, pos)
    conv[ch] = "%" .. ch
  end

  function lunitpat2luapat(str)
    -- Characters without an entry in 'conv' are left unchanged by gsub.
    local converted = string.gsub(str, "%W", conv)
    return "^" .. converted .. "$"
  end
end
-- Tells whether 'name' is selected by 'map': either map[name] is exactly
-- true, or 'name' matches one of the Lua patterns stored in the array
-- part of 'map'.
local function in_patternmap(map, name)
  if map[name] == true then
    return true
  end
  for _, candidate in ipairs(map) do
    local hit = string.find(name, candidate)
    if hit then
      return true
    end
  end
  return false
end
-- Called from 'lunit' shell script.
--
-- Parses the command line in 'argv' (an array of strings, defaults to {}):
--   --loadonly          only load the testcases, do not run them
--   --runner, -r NAME   test runner module to load (default "lunit-console")
--   --test, -t PATTERN  collect a test pattern (may be repeated)
--   --                  treat all remaining arguments as testcase files
--   anything else       testcase file to load and execute
--
-- Afterwards loads the runner and returns the result of loadonly() or
-- run(testpatterns).
function main(argv)
  argv = argv or {}

  -- FIXME: Error handling and error messages aren't nice.

  -- Raises an error when the argument of option 'optname' is missing.
  local function checkarg(optname, arg)
    if not is_string(arg) then
      return error("lunit.main: option "..optname..": argument missing.", 0)
    end
  end

  -- Loads and executes the testcase file 'filename'.
  local function loadtestcase(filename)
    if not is_string(filename) then
      return error("lunit.main: invalid argument")
    end
    local chunk, err = loadfile(filename)
    if err then
      -- Propagate the loadfile message unchanged. The previous code called
      -- an undefined function here, hiding the real reason of the failure.
      return error(err, 0)
    else
      chunk()
    end
  end

  local testpatterns = nil
  local doloadonly = false
  local runner = nil

  local i = 0
  while i < #argv do
    i = i + 1
    local arg = argv[i]
    if arg == "--loadonly" then
      doloadonly = true
    elseif arg == "--runner" or arg == "-r" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      runner = arg
    elseif arg == "--test" or arg == "-t" then
      local optname = arg; i = i + 1; arg = argv[i]
      checkarg(optname, arg)
      testpatterns = testpatterns or {}
      testpatterns[#testpatterns+1] = arg
    elseif arg == "--" then
      -- Everything after "--" is a testcase file, even if it looks like
      -- an option.
      while i < #argv do
        i = i + 1; arg = argv[i]
        loadtestcase(arg)
      end
    else
      loadtestcase(arg)
    end
  end

  loadrunner(runner or "lunit-console")

  if doloadonly then
    return loadonly()
  else
    return run(testpatterns)
  end
end

-- Start with zeroed statistics as soon as the module is loaded.
clearstats()