2 --[[--------------------------------------------------------------------------
4 This file is part of lunit 0.4.
6 For Details about lunit look at: http://www.mroth.net/lunit/
8 Author: Michael Roth <mroth@nessie.de>
10 Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>
12 Permission is hereby granted, free of charge, to any person
13 obtaining a copy of this software and associated documentation
14 files (the "Software"), to deal in the Software without restriction,
15 including without limitation the rights to use, copy, modify, merge,
16 publish, distribute, sublicense, and/or sell copies of the Software,
17 and to permit persons to whom the Software is furnished to do so,
18 subject to the following conditions:
20 The above copyright notice and this permission notice shall be
21 included in all copies or substantial portions of the Software.
23 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
26 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
27 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
28 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
29 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 --]]--------------------------------------------------------------------------
36 local orig_assert
= assert
43 local tostring = tostring
45 local string_sub
= string.sub
46 local string_format
= string.format
49 module("lunit", package
.seeall
) -- FIXME: Remove package.seeall
-- Unique sentinel table: an error object whose 'type' field is this table
-- is treated as a failed assertion (Type tag for failed assertions).
local __failure__ = {}

-- Names of the basic Lua types. The is_* and assert_* helper generators
-- iterate over this list.
local typenames = {
  "nil", "boolean", "number", "string",
  "table", "function", "thread", "userdata",
}

-- Forward declarations, assigned further down:
--   traceback_hide  Traceback function which hides lunit internals
--   mypcall         Protected call to a function with own traceback
local traceback_hide, mypcall

-- Set of functions to leave out of tracebacks. Keys are weak ("k" mode),
-- so an entry does not keep its function alive.
local _tb_hide = setmetatable( {}, { __mode = "k" } )
64 function traceback_hide(func
)
68 local function my_traceback(errobj
)
69 if is_table(errobj
) and errobj
.type == __failure__
then
70 local info
= debug
.getinfo(5, "Sl") -- FIXME: Hardcoded integers are bad...
71 errobj
.where
= string_format( "%s:%d", info
.short_src
, info
.currentline
)
73 errobj
= { msg
= tostring(errobj
) }
77 local info
= debug
.getinfo(i
, "Snlf")
78 if not is_table(info
) then
81 if not _tb_hide
[info
.func
] then
82 local line
= {} -- Ripped from ldblib.c...
83 line
[#line
+1] = string_format("%s:", info
.short_src
)
84 if info
.currentline
> 0 then
85 line
[#line
+1] = string_format("%d:", info
.currentline
)
87 if info
.namewhat
~= "" then
88 line
[#line
+1] = string_format(" in function '%s'", info
.name
)
90 if info
.what
== "main" then
91 line
[#line
+1] = " in main chunk"
92 elseif info
.what
== "C" or info
.what
== "tail" then
95 line
[#line
+1] = string_format(" in function <%s:%d>", info
.short_src
, info
.linedefined
)
98 errobj
.tb
[#errobj
.tb
+1] = table.concat(line
)
106 function mypcall(func
)
107 orig_assert( is_function(func
) )
108 local ok
, errobj
= xpcall(func
, my_traceback
)
113 traceback_hide(mypcall
)
-- Type check functions
--
-- For each entry of 'typenames', publish lunit.is_<typename>(x), which
-- returns true exactly when type(x) equals that type name.

for index = 1, #typenames do
  local wanted = typenames[index]
  lunit["is_" .. wanted] = function(x)
    return type(x) == wanted
  end
end
-- Bind the type check functions to locals of the same name. The right-hand
-- sides are evaluated before the new locals come into scope, so each one
-- reads the non-local is_* value.
local is_nil, is_boolean, is_number, is_string =
      is_nil, is_boolean, is_number, is_string
local is_table, is_function, is_thread, is_userdata =
      is_table, is_function, is_thread, is_userdata
135 local function failure(name
, usermsg
, defaultmsg
, ...)
139 msg
= string_format(defaultmsg
,...),
144 traceback_hide( failure
)
147 local function format_arg(arg
)
148 local argtype
= type(arg
)
149 if argtype
== "string" then
151 elseif argtype
== "number" or argtype
== "boolean" or argtype
== "nil" then
154 return "["..tostring(arg
).."]"
160 stats
.assertions
= stats
.assertions
+ 1
161 failure( "fail", msg
, "failure" )
163 traceback_hide( fail
)
166 function assert(assertion
, msg
)
167 stats
.assertions
= stats
.assertions
+ 1
168 if not assertion
then
169 failure( "assert", msg
, "assertion failed" )
173 traceback_hide( assert )
176 function assert_true(actual
, msg
)
177 stats
.assertions
= stats
.assertions
+ 1
178 local actualtype
= type(actual
)
179 if actualtype
~= "boolean" then
180 failure( "assert_true", msg
, "true expected but was a "..actualtype
)
182 if actual
~= true then
183 failure( "assert_true", msg
, "true expected but was false" )
187 traceback_hide( assert_true
)
190 function assert_false(actual
, msg
)
191 stats
.assertions
= stats
.assertions
+ 1
192 local actualtype
= type(actual
)
193 if actualtype
~= "boolean" then
194 failure( "assert_false", msg
, "false expected but was a "..actualtype
)
196 if actual
~= false then
197 failure( "assert_false", msg
, "false expected but was true" )
201 traceback_hide( assert_false
)
204 function assert_equal(expected
, actual
, msg
)
205 stats
.assertions
= stats
.assertions
+ 1
206 if expected
~= actual
then
207 failure( "assert_equal", msg
, "expected %s but was %s", format_arg(expected
), format_arg(actual
) )
211 traceback_hide( assert_equal
)
214 function assert_not_equal(unexpected
, actual
, msg
)
215 stats
.assertions
= stats
.assertions
+ 1
216 if unexpected
== actual
then
217 failure( "assert_not_equal", msg
, "%s not expected but was one", format_arg(unexpected
) )
221 traceback_hide( assert_not_equal
)
224 function assert_match(pattern
, actual
, msg
)
225 stats
.assertions
= stats
.assertions
+ 1
226 local patterntype
= type(pattern
)
227 if patterntype
~= "string" then
228 failure( "assert_match", msg
, "expected the pattern as a string but was a "..patterntype
)
230 local actualtype
= type(actual
)
231 if actualtype
~= "string" then
232 failure( "assert_match", msg
, "expected a string to match pattern '%s' but was a %s", pattern
, actualtype
)
234 if not string.find(actual
, pattern
) then
235 failure( "assert_match", msg
, "expected '%s' to match pattern '%s' but doesn't", actual
, pattern
)
239 traceback_hide( assert_match
)
242 function assert_not_match(pattern
, actual
, msg
)
243 stats
.assertions
= stats
.assertions
+ 1
244 local patterntype
= type(pattern
)
245 if patterntype
~= "string" then
246 failure( "assert_not_match", msg
, "expected the pattern as a string but was a "..patterntype
)
248 local actualtype
= type(actual
)
249 if actualtype
~= "string" then
250 failure( "assert_not_match", msg
, "expected a string to not match pattern '%s' but was a %s", pattern
, actualtype
)
252 if string.find(actual
, pattern
) then
253 failure( "assert_not_match", msg
, "expected '%s' to not match pattern '%s' but it does", actual
, pattern
)
257 traceback_hide( assert_not_match
)
260 function assert_error(msg
, func
)
261 stats
.assertions
= stats
.assertions
+ 1
265 local functype
= type(func
)
266 if functype
~= "function" then
267 failure( "assert_error", msg
, "expected a function as last argument but was a "..functype
)
269 local ok
, errmsg
= pcall(func
)
271 failure( "assert_error", msg
, "error expected but no error occurred" )
274 traceback_hide( assert_error
)
277 function assert_pass(msg
, func
)
278 stats
.assertions
= stats
.assertions
+ 1
282 local functype
= type(func
)
283 if functype
~= "function" then
284 failure( "assert_pass", msg
, "expected a function as last argument but was a %s", functype
)
286 local ok
, errmsg
= pcall(func
)
288 failure( "assert_pass", msg
, "no error expected but error was: '%s'", errmsg
)
291 traceback_hide( assert_pass
)
294 -- lunit.assert_typename functions
296 for _
, typename
in ipairs(typenames
) do
297 local assert_typename
= "assert_"..typename
298 lunit
[assert_typename
] = function(actual
, msg
)
299 stats
.assertions
= stats
.assertions
+ 1
300 local actualtype
= type(actual
)
301 if actualtype
~= typename
then
302 failure( assert_typename
, msg
, typename
.." expected but was a "..actualtype
)
306 traceback_hide( lunit
[assert_typename
] )
-- lunit.assert_not_typename functions
--
-- For each entry of 'typenames', define lunit.assert_not_<typename>(actual, msg).
-- Every generated assertion:
--   * increments stats.assertions,
--   * raises a failure (tagged with the assertion's own name, the optional
--     user message 'msg' and a default message) when type(actual) equals
--     the rejected type name,
--   * returns nothing.
-- Each generated function is passed to traceback_hide.

for _, typename in ipairs(typenames) do
  local assert_not_name = "assert_not_"..typename
  lunit[assert_not_name] = function(actual, msg)
    stats.assertions = stats.assertions + 1
    if type(actual) == typename then
      -- Report under the generated assertion name. This call used to pass
      -- 'assert_not_typename', a variable that is never defined, so the
      -- failure was raised with a nil name.
      failure( assert_not_name, msg, typename.." not expected but was one" )
    end
  end
  traceback_hide( lunit[assert_not_name] )
end
324 function lunit
.clearstats()
334 local report
, reporterrobj
338 function lunit
.setrunner(newrunner
)
339 if not ( is_table(newrunner
) or is_nil(newrunner
) ) then
340 return error("lunit.setrunner: Invalid argument", 0)
342 local oldrunner
= testrunner
343 testrunner
= newrunner
347 function lunit
.loadrunner(name
)
348 if not is_string(name
) then
349 return error("lunit.loadrunner: Invalid argument", 0)
351 local ok
, runner
= pcall( require
, name
)
353 return error("lunit.loadrunner: Can't load test runner: "..runner
, 0)
355 return setrunner(runner
)
358 function report(event
, ...)
359 local f
= testrunner
and testrunner
[event
]
360 if is_function(f
) then
-- Accounts for an error object produced while running part of a test and
-- forwards it to the test runner via report().
--
--   context   "setup" or "teardown" appends ":<hook name>" to the reported
--             name; any other value leaves the name as "tcname.testname"
--   tcname    name of the testcase
--   testname  name of the test
--   errobj    error object; 'type', 'where', 'msg', 'usermsg' and 'tb'
--             are the fields read here
function reporterrobj(context, tcname, testname, errobj)
  local label = tcname .. "." .. testname
  if context == "setup" then
    label = label .. ":" .. setupname(tcname, testname)
  elseif context == "teardown" then
    label = label .. ":" .. teardownname(tcname, testname)
  end

  -- Anything not tagged as an assertion failure counts as an error.
  if errobj.type ~= __failure__ then
    stats.errors = stats.errors + 1
    report("err", label, errobj.msg, errobj.tb)
    return
  end

  stats.failed = stats.failed + 1
  report("fail", label, errobj.where, errobj.msg, errobj.usermsg)
end
384 local function key_iter(t
, k
)
391 -- Array with all registered testcases
392 local _testcases
= {}
394 -- Marks a module as a testcase.
395 -- Applied over a module from module("xyz", lunit.testcase).
396 function lunit
.testcase(m
)
397 orig_assert( is_table(m
) )
398 --orig_assert( m._M == m )
399 orig_assert( is_string(m
._NAME
) )
400 --orig_assert( is_string(m._PACKAGE) )
402 -- Register the module as a testcase
403 _testcases
[m
._NAME
] = m
405 -- Import lunit, fail, assert* and is_* function to the module/testcase
408 for funcname
, func
in pairs(lunit
) do
409 if "assert" == string_sub(funcname
, 1, 6) or "is_" == string_sub(funcname
, 1, 3) then
-- Iterator (testcasename) over all Testcases
function lunit.testcases()
  -- Iterate over a snapshot of the registered names, so that testcases
  -- defined while the iteration is running cannot confuse the iterator.
  local snapshot = {}
  for name in pairs(_testcases) do
    snapshot[name] = true
  end
  return key_iter, snapshot, nil
end
-- Returns the testcase registered under the name 'tcname', or nil when no
-- such testcase is registered.
-- NOTE(review): this is a global definition inside module("lunit", ...), so
-- it appears to share its name with lunit.testcase(m) defined earlier;
-- confirm which definition is meant to be reachable as lunit.testcase.
function testcase(tcname)
  local registered = _testcases[tcname]
  return registered
end
-- Finds a function in a testcase, matching its name case-insensitively
434 local function findfuncname(tcname
, name
)
435 for key
, value
in pairs(testcase(tcname
)) do
436 if is_string(key
) and is_function(value
) and string.lower(key
) == name
then
-- Name of the testcase's setup function, as found by findfuncname
-- (looked up under the lower-case name "setup").
function lunit.setupname(tcname)
  local wanted = "setup"
  return findfuncname(tcname, wanted)
end

-- Name of the testcase's teardown function, as found by findfuncname
-- (looked up under the lower-case name "teardown").
function lunit.teardownname(tcname)
  local wanted = "teardown"
  return findfuncname(tcname, wanted)
end
450 -- Iterator over all test names in a testcase.
451 -- Have to collect the names first in case one of the test
452 -- functions creates a new global and throws off the iteration.
453 function lunit
.tests(tcname
)
455 for key
, value
in pairs(testcase(tcname
)) do
456 if is_string(key
) and is_function(value
) then
457 local lfn
= string.lower(key
)
458 if string.sub(lfn
, 1, 4) == "test" or string.sub(lfn
, -4) == "test" then
459 testnames
[key
] = true
463 return key_iter
, testnames
, nil
470 function lunit
.runtest(tcname
, testname
)
471 orig_assert( is_string(tcname
) )
472 orig_assert( is_string(testname
) )
474 local function callit(context
, func
)
476 local err
= mypcall(func
)
478 reporterrobj(context
, tcname
, testname
, err
)
484 traceback_hide(callit
)
486 report("run", tcname
, testname
)
488 local tc
= testcase(tcname
)
489 local setup
= tc
[setupname(tcname
)]
490 local test
= tc
[testname
]
491 local teardown
= tc
[teardownname(tcname
)]
493 local setup_ok
= callit( "setup", setup
)
494 local test_ok
= setup_ok
and callit( "test", test
)
495 local teardown_ok
= setup_ok
and callit( "teardown", teardown
)
497 if setup_ok
and test_ok
and teardown_ok
then
498 stats
.passed
= stats
.passed
+ 1
499 report("pass", tcname
, testname
)
502 traceback_hide(runtest
)
509 for testcasename
in lunit
.testcases() do
510 -- Run tests in the testcases
511 for testname
in lunit
.tests(testcasename
) do
512 runtest(testcasename
, testname
)
521 function lunit
.loadonly()
536 local lunitpat2luapat
552 function lunitpat2luapat(str
)
553 return "^" .. string.gsub(str
, "%W", conv
) .. "$"
559 local function in_patternmap(map
, name
)
560 if map
[name
] == true then
563 for _
, pat
in ipairs(map
) do
564 if string.find(name
, pat
) then
579 -- Called from 'lunit' shell script.
584 -- FIXME: Error handling and error messages aren't nice.
586 local function checkarg(optname
, arg
)
587 if not is_string(arg
) then
588 return error("lunit.main: option "..optname
..": argument missing.", 0)
592 local function loadtestcase(filename
)
593 if not is_string(filename
) then
594 return error("lunit.main: invalid argument")
596 local chunk
, err
= loadfile(filename
)
604 local testpatterns
= nil
605 local doloadonly
= false
612 if arg
== "--loadonly" then
614 elseif arg
== "--runner" or arg
== "-r" then
615 local optname
= arg
; i
= i
+ 1; arg
= argv
[i
]
616 checkarg(optname
, arg
)
618 elseif arg
== "--test" or arg
== "-t" then
619 local optname
= arg
; i
= i
+ 1; arg
= argv
[i
]
620 checkarg(optname
, arg
)
621 testpatterns
= testpatterns
or {}
622 testpatterns
[#testpatterns
+1] = arg
623 elseif arg
== "--" then
625 i
= i
+ 1; arg
= argv
[i
]
633 loadrunner(runner
or "lunit-console")
638 return run(testpatterns
)