2 --[[--------------------------------------------------------------------------
4 This file is part of lunit 0.4.
6 For Details about lunit look at: http://www.mroth.net/lunit/
8 Author: Michael Roth <mroth@nessie.de>
10 Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>
12 Permission is hereby granted, free of charge, to any person
13 obtaining a copy of this software and associated documentation
14 files (the "Software"), to deal in the Software without restriction,
15 including without limitation the rights to use, copy, modify, merge,
16 publish, distribute, sublicense, and/or sell copies of the Software,
17 and to permit persons to whom the Software is furnished to do so,
18 subject to the following conditions:
20 The above copyright notice and this permission notice shall be
21 included in all copies or substantial portions of the Software.
23 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
26 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
27 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
28 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
29 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 --]]--------------------------------------------------------------------------
36 local orig_assert
= assert
44 local string_sub
= string.sub
45 local string_format
= string.format
48 module("lunit", package
.seeall
) -- FIXME: Remove package.seeall
52 local __failure__
= {} -- Type tag for failed assertions
54 local typenames
= { "nil", "boolean", "number", "string", "table", "function", "thread", "userdata" }
58 local traceback_hide
-- Traceback function which hides lunit internals
59 local mypcall
-- Protected call to a function with own traceback
61 local _tb_hide
= setmetatable( {}, {__mode
="k"} )
63 function traceback_hide(func
)
67 local function my_traceback(errobj
)
68 if is_table(errobj
) and errobj
.type == __failure__
then
69 local info
= debug
.getinfo(5, "Sl") -- FIXME: Hardcoded integers are bad...
70 errobj
.where
= string_format( "%s:%d", info
.short_src
, info
.currentline
)
72 errobj
= { msg
= tostring(errobj
) }
76 local info
= debug
.getinfo(i
, "Snlf")
77 if not is_table(info
) then
80 if not _tb_hide
[info
.func
] then
81 local line
= {} -- Ripped from ldblib.c...
82 line
[#line
+1] = string_format("%s:", info
.short_src
)
83 if info
.currentline
> 0 then
84 line
[#line
+1] = string_format("%d:", info
.currentline
)
86 if info
.namewhat
~= "" then
87 line
[#line
+1] = string_format(" in function '%s'", info
.name
)
89 if info
.what
== "main" then
90 line
[#line
+1] = " in main chunk"
91 elseif info
.what
== "C" or info
.what
== "tail" then
94 line
[#line
+1] = string_format(" in function <%s:%d>", info
.short_src
, info
.linedefined
)
97 errobj
.tb
[#errobj
.tb
+1] = table.concat(line
)
105 function mypcall(func
)
106 orig_assert( is_function(func
) )
107 local ok
, errobj
= xpcall(func
, my_traceback
)
112 traceback_hide(mypcall
)
116 -- Type check functions
118 for _
, typename
in ipairs(typenames
) do
119 lunit
["is_"..typename
] = function(x
)
120 return type(x
) == typename
124 local is_nil
= is_nil
125 local is_boolean
= is_boolean
126 local is_number
= is_number
127 local is_string
= is_string
128 local is_table
= is_table
129 local is_function
= is_function
130 local is_thread
= is_thread
131 local is_userdata
= is_userdata
134 local function failure(name
, usermsg
, defaultmsg
, ...)
138 msg
= string_format(defaultmsg
,...),
143 traceback_hide( failure
)
147 stats
.assertions
= stats
.assertions
+ 1
148 failure( "fail", msg
, "failure" )
150 traceback_hide( fail
)
153 function assert(assertion
, msg
)
154 stats
.assertions
= stats
.assertions
+ 1
155 if not assertion
then
156 failure( "assert", msg
, "assertion failed" )
160 traceback_hide( assert )
163 function assert_true(actual
, msg
)
164 stats
.assertions
= stats
.assertions
+ 1
165 local actualtype
= type(actual
)
166 if actualtype
~= "boolean" then
167 failure( "assert_true", msg
, "true expected but was a "..actualtype
)
169 if actual
~= true then
170 failure( "assert_true", msg
, "true expected but was false" )
174 traceback_hide( assert_true
)
177 function assert_false(actual
, msg
)
178 stats
.assertions
= stats
.assertions
+ 1
179 local actualtype
= type(actual
)
180 if actualtype
~= "boolean" then
181 failure( "assert_false", msg
, "false expected but was a "..actualtype
)
183 if actual
~= false then
184 failure( "assert_false", msg
, "false expected but was true" )
188 traceback_hide( assert_false
)
191 function assert_equal(expected
, actual
, msg
)
192 stats
.assertions
= stats
.assertions
+ 1
193 if expected
~= actual
then
194 failure( "assert_equal", msg
, "expected '%s' but was '%s'", expected
, actual
)
198 traceback_hide( assert_equal
)
201 function assert_not_equal(unexpected
, actual
, msg
)
202 stats
.assertions
= stats
.assertions
+ 1
203 if unexpected
== actual
then
204 failure( "assert_not_equal", msg
, "'%s' not expected but was one", unexpected
)
208 traceback_hide( assert_not_equal
)
211 function assert_match(pattern
, actual
, msg
)
212 stats
.assertions
= stats
.assertions
+ 1
213 if not is_string(pattern
) then
214 failure( "assert_match", msg
, "expected the pattern as a string but was '%s'", pattern
)
216 if not is_string(actual
) then
217 failure( "assert_match", msg
, "expected a string to match pattern '%s' but was '%s'", pattern
, actual
)
219 if not string.find(actual
, pattern
) then
220 failure( "assert_match", msg
, "expected '%s' to match pattern '%s' but doesn't", actual
, pattern
)
224 traceback_hide( assert_match
)
227 function assert_not_match(pattern
, actual
, msg
)
228 stats
.assertions
= stats
.assertions
+ 1
229 if not is_string(pattern
) then
230 failure( "assert_not_match", msg
, "expected the pattern as a string but was '%s'", pattern
)
232 if not is_string(actual
) then
233 failure( "assert_not_match", msg
, "expected a string to not match pattern '%s' but was '%s'", pattern
, actual
)
235 if string.find(actual
, pattern
) then
236 failure( "assert_not_match", msg
, "expected '%s' to not match pattern '%s' but it does", actual
, pattern
)
240 traceback_hide( assert_not_match
)
243 function assert_error(msg
, func
)
244 stats
.assertions
= stats
.assertions
+ 1
248 local functype
= type(func
)
249 if functype
~= "function" then
250 failure( "assert_error", msg
, "expected a function as last argument but was a "..functype
)
252 local ok
, errmsg
= pcall(func
)
254 failure( "assert_error", msg
, "error expected but no error occurred" )
257 traceback_hide( assert_error
)
260 function assert_pass(msg
, func
)
261 stats
.assertions
= stats
.assertions
+ 1
265 local functype
= type(func
)
266 if functype
~= "function" then
267 failure( "assert_pass", msg
, "expected a function as last argument but was a %s", functype
)
269 local ok
, errmsg
= pcall(func
)
271 failure( "assert_pass", msg
, "no error expected but error was: '%s'", errmsg
)
274 traceback_hide( assert_pass
)
277 -- lunit.assert_typename functions
279 for _
, typename
in ipairs(typenames
) do
280 local assert_typename
= "assert_"..typename
281 lunit
[assert_typename
] = function(actual
, msg
)
282 stats
.assertions
= stats
.assertions
+ 1
283 local actualtype
= type(actual
)
284 if actualtype
~= typename
then
285 failure( assert_typename
, msg
, typename
.." expected but was a "..actualtype
)
289 traceback_hide( lunit
[assert_typename
] )
293 -- lunit.assert_not_typename functions
295 for _
, typename
in ipairs(typenames
) do
296 local assert_not_name
= "assert_not_"..typename
297 lunit
[assert_not_name
] = function(actual
, msg
)
298 stats
.assertions
= stats
.assertions
+ 1
299 if type(actual
) == typename
then
300 failure( assert_not_typename
, msg
, typename
.." not expected but was one" )
303 traceback_hide( lunit
[assert_not_name
] )
307 function lunit
.clearstats()
317 local report
, reporterrobj
321 function lunit
.setrunner(newrunner
)
322 if not ( is_table(newrunner
) or is_nil(newrunner
) ) then
323 return error("lunit.setrunner: Invalid argument", 0)
325 local oldrunner
= testrunner
326 testrunner
= newrunner
330 function lunit
.loadrunner(name
)
331 if not is_string(name
) then
332 return error("lunit.loadrunner: Invalid argument", 0)
334 local ok
, runner
= pcall( require
, name
)
336 return error("lunit.loadrunner: Can't load test runner: "..runner
, 0)
338 return setrunner(runner
)
341 function report(event
, ...)
342 local f
= testrunner
and testrunner
[event
]
343 if is_function(f
) then
348 function reporterrobj(context
, tcname
, testname
, errobj
)
349 local fullname
= tcname
.. "." .. testname
350 if context
== "setup" then
351 fullname
= fullname
.. ":" .. setupname(tcname
, testname
)
352 elseif context
== "teardown" then
353 fullname
= fullname
.. ":" .. teardownname(tcname
, testname
)
355 if errobj
.type == __failure__
then
356 stats
.failed
= stats
.failed
+ 1
357 report("fail", fullname
, errobj
.where
, errobj
.msg
, errobj
.usermsg
)
359 stats
.errors
= stats
.errors
+ 1
360 report("err", fullname
, errobj
.msg
, errobj
.tb
)
367 local function key_iter(t
, k
)
374 -- Array with all registered testcases
375 local _testcases
= {}
377 -- Marks a module as a testcase.
378 -- Applied over a module from module("xyz", lunit.testcase).
379 function lunit
.testcase(m
)
380 orig_assert( is_table(m
) )
381 --orig_assert( m._M == m )
382 orig_assert( is_string(m
._NAME
) )
383 --orig_assert( is_string(m._PACKAGE) )
385 -- Register the module as a testcase
386 _testcases
[m
._NAME
] = m
388 -- Import lunit, fail, assert* and is_* function to the module/testcase
391 for funcname
, func
in pairs(lunit
) do
392 if "assert" == string_sub(funcname
, 1, 6) or "is_" == string_sub(funcname
, 1, 3) then
398 -- Iterator (testcasename) over all Testcases
399 function lunit
.testcases()
400 -- Make a copy of testcases to prevent confusing the iterator when
401 -- new testcase are defined
402 local _testcases2
= {}
403 for k
,v
in pairs(_testcases
) do
404 _testcases2
[k
] = true
406 return key_iter
, _testcases2
, nil
409 function testcase(tcname
)
410 return _testcases
[tcname
]
416 -- Finds a function in a testcase case insensitiv
417 local function findfuncname(tcname
, name
)
418 for key
, value
in pairs(testcase(tcname
)) do
419 if is_string(key
) and is_function(value
) and string.lower(key
) == name
then
425 function lunit
.setupname(tcname
)
426 return findfuncname(tcname
, "setup")
429 function lunit
.teardownname(tcname
)
430 return findfuncname(tcname
, "teardown")
433 -- Iterator over all test names in a testcase.
434 -- Have to collect the names first in case one of the test
435 -- functions creates a new global and throws off the iteration.
436 function lunit
.tests(tcname
)
438 for key
, value
in pairs(testcase(tcname
)) do
439 if is_string(key
) and is_function(value
) then
440 local lfn
= string.lower(key
)
441 if string.sub(lfn
, 1, 4) == "test" or string.sub(lfn
, -4) == "test" then
442 testnames
[key
] = true
446 return key_iter
, testnames
, nil
453 function lunit
.runtest(tcname
, testname
)
454 orig_assert( is_string(tcname
) )
455 orig_assert( is_string(testname
) )
457 local function callit(context
, func
)
459 local err
= mypcall(func
)
461 reporterrobj(context
, tcname
, testname
, err
)
467 traceback_hide(callit
)
469 report("run", tcname
, testname
)
471 local tc
= testcase(tcname
)
472 local setup
= tc
[setupname(tcname
)]
473 local test
= tc
[testname
]
474 local teardown
= tc
[teardownname(tcname
)]
476 local setup_ok
= callit( "setup", setup
)
477 local test_ok
= setup_ok
and callit( "test", test
)
478 local teardown_ok
= setup_ok
and callit( "teardown", teardown
)
480 if setup_ok
and test_ok
and teardown_ok
then
481 stats
.passed
= stats
.passed
+ 1
482 report("pass", tcname
, testname
)
485 traceback_hide(runtest
)
492 for testcasename
in lunit
.testcases() do
493 -- Run tests in the testcases
494 for testname
in lunit
.tests(testcasename
) do
495 runtest(testcasename
, testname
)
503 function lunit
.loadonly()
517 local lunitpat2luapat
533 function lunitpat2luapat(str
)
534 return "^" .. string.gsub(str
, "%W", conv
) .. "$"
540 local function in_patternmap(map
, name
)
541 if map
[name
] == true then
544 for _
, pat
in ipairs(map
) do
545 if string.find(name
, pat
) then
560 -- Called from 'lunit' shell script.
565 -- FIXME: Error handling and error messages aren't nice.
567 local function checkarg(optname
, arg
)
568 if not is_string(arg
) then
569 return error("lunit.main: option "..optname
..": argument missing.", 0)
573 local function loadtestcase(filename
)
574 if not is_string(filename
) then
575 return error("lunit.main: invalid argument")
577 local chunk
, err
= loadfile(filename
)
585 local testpatterns
= nil
586 local doloadonly
= false
593 if arg
== "--loadonly" then
595 elseif arg
== "--runner" or arg
== "-r" then
596 local optname
= arg
; i
= i
+ 1; arg
= argv
[i
]
597 checkarg(optname
, arg
)
599 elseif arg
== "--test" or arg
== "-t" then
600 local optname
= arg
; i
= i
+ 1; arg
= argv
[i
]
601 checkarg(optname
, arg
)
602 testpatterns
= testpatterns
or {}
603 testpatterns
[#testpatterns
+1] = arg
604 elseif arg
== "--" then
606 i
= i
+ 1; arg
= argv
[i
]
614 loadrunner(runner
or "lunit-console")
619 return run(testpatterns
)