10 from pykickstart
.version
import versionMap
, returnClassForVersion
11 from pykickstart
.errors
import *
gettext.textdomain("pykickstart")


def _(x):
    '''Translate x using the "pykickstart" gettext domain.'''
    return gettext.ldgettext("pykickstart", x)
# Base class for any test case
class CommandTest(unittest.TestCase):
    '''Base class for kickstart command tests.

    The kickstart version under test is derived from the subclass name: the
    part before the first "_" is passed to returnClassForVersion() to obtain
    the handler class used by getParser().
    '''

    # Handler class for the version under test.  None until getParser()
    # resolves it from the class name; a subclass may preset it instead.
    handler = None

    def setUp(self):
        '''Perform any command setup'''
        unittest.TestCase.setUp(self)

        # ignore DeprecationWarning
        warnings.simplefilter("ignore", category=DeprecationWarning, append=0)

    def tearDown(self):
        '''Undo anything performed by setUp(self)'''
        # setUp() inserted its filter at the front (append=0); drop it again.
        warnings.filters = warnings.filters[1:]

        unittest.TestCase.tearDown(self)

    def getParser(self, inputStr):
        '''Find a handler using the class name and return the requested
        command object.

        inputStr -- a kickstart line; its first shell-style token names the
                    command looked up in the handler's "commands" mapping.

        The returned object has currentLine set to inputStr and currentCmd
        set to the command name.
        '''
        args = shlex.split(inputStr)
        cmd = args[0]

        if self.handler is None:
            # The version is everything before the first "_" in the name of
            # the concrete test class.
            version = self.__class__.__name__.split("_")[0]
            self.handler = returnClassForVersion(version)

        parser = self.handler().commands[cmd]
        parser.currentLine = inputStr
        parser.currentCmd = cmd

        return parser

    def assert_parse(self, inputStr, expectedStr=None, ignoreComments=True):
        '''Assert that parsing inputStr does not raise and, when expectedStr
        is supplied, that str() of the parse result equals expectedStr.

        ignoreComments -- when true (the default), every line of the parsed
                          output that starts with "#" is removed before the
                          comparison.
        '''
        parser = self.getParser(inputStr)
        args = shlex.split(inputStr)

        # If expectedStr supplied, we want to ensure the parsed result matches
        if expectedStr is not None:
            result = str(parser.parse(args[1:]))

            # Strip any comment lines ... we only match on non-comments.
            # (?m) makes ^ match at the start of every line; without it
            # only a comment on the very first line would be removed.
            if ignoreComments:
                result = re.sub(r"(?m)^#[^\n]*\n", "", result)

            # Ensure we parsed as expected
            self.assertEqual(result, expectedStr)
        else:
            # No expectedStr supplied, just make sure it does not raise an
            # exception
            try:
                parser.parse(args[1:])
            except Exception as e:
                self.fail("Failed while parsing: %s" % e)

    def assert_parse_error(self, inputStr, exception=KickstartParseError):
        '''Assert that parsing the supplied string raises exception
        (KickstartParseError by default).'''
        parser = self.getParser(inputStr)
        args = shlex.split(inputStr)

        self.assertRaises(exception, parser.parse, args[1:])

    def assert_deprecated(self, cmd, opt):
        '''Ensure that the provided option is listed as deprecated.

        cmd -- kickstart line whose first token selects the command.
        opt -- option string, compared against op.get_opt_string().

        NOTE(review): passes vacuously when no option matches opt.
        '''
        parser = self.getParser(cmd)

        for op in parser.op.option_list:
            if op.get_opt_string() == opt:
                self.assertTrue(op.deprecated)

    def assert_removed(self, cmd, opt):
        '''Ensure that the provided option is not present in option_list.

        cmd -- kickstart line whose first token selects the command.
        opt -- compared against each option's "dest" attribute (not its
               option string, unlike the other assert_* helpers).
        '''
        parser = self.getParser(cmd)
        for op in parser.op.option_list:
            self.assertNotEqual(op.dest, opt)

    def assert_required(self, cmd, opt):
        '''Ensure that the provided option is labelled as required in
        option_list.

        cmd -- kickstart line whose first token selects the command.
        opt -- option string, compared against op.get_opt_string().

        NOTE(review): passes vacuously when no option matches opt.
        '''
        parser = self.getParser(cmd)
        for op in parser.op.option_list:
            if op.get_opt_string() == opt:
                self.assertTrue(op.required)

    def assert_type(self, cmd, opt, opt_type):
        '''Ensure that the provided option is of the requested type.

        cmd      -- kickstart line whose first token selects the command.
        opt      -- option string, compared against op.get_opt_string().
        opt_type -- expected value of the matching option's "type".

        NOTE(review): passes vacuously when no option matches opt.
        '''
        parser = self.getParser(cmd)
        for op in parser.op.option_list:
            if op.get_opt_string() == opt:
                self.assertEqual(op.type, opt_type)
def loadModules(moduleDir, cls_pattern="_TestCase", skip_list=["__init__", "baseclass"]):
    '''Load every *.py module in moduleDir and collect its test classes.

    Adapted from firstboot/loader.py.

    moduleDir   -- directory scanned for *.py files; it is prepended to
                   sys.path when missing so modules can be found by name.
    cls_pattern -- suffix an attribute name must end with to be collected
                   (default: "_TestCase").
    skip_list   -- module base names to ignore; "__init__" is always
                   skipped.  Neither the caller's list nor the default list
                   is modified.

    Returns a list of the matching attribute values from every module that
    loaded successfully.  Modules that fail to import, or that contain no
    matches, are reported on stdout and skipped.
    '''
    # Work on a copy: appending to skip_list directly would mutate the
    # caller's list, or the default list shared by every call.
    skip_list = list(skip_list)

    # Guarantee that __init__ is skipped
    if "__init__" not in skip_list:
        skip_list.append("__init__")

    tstList = []

    # Make sure moduleDir is in the system path so imputil works.
    if moduleDir not in sys.path:
        sys.path.insert(0, moduleDir)

    # Base names (no directory, no ".py") of all *.py files in moduleDir
    moduleNames = [os.path.splitext(os.path.basename(path))[0]
                   for path in glob.glob(moduleDir + "/*.py")]

    # Inspect each .py file found
    for module in moduleNames:
        if module in skip_list:
            continue

        # Attempt to load the found module.
        found = None
        try:
            found = imputil.imp.find_module(module)
            loaded = imputil.imp.load_module(module, found[0], found[1], found[2])
        except ImportError:
            print(_("Error loading module %s.") % module)
            # Without this, `loaded` would be unbound or would still refer
            # to the previous module, duplicating its test classes.
            continue
        finally:
            # find_module() returns an open file (or None); the caller must
            # close it, even when load_module() raises.
            if found is not None and found[0] is not None:
                found[0].close()

        # Find class names that match the supplied pattern (default: "_TestCase")
        matches = [value for name, value in vars(loaded).items()
                   if name.endswith(cls_pattern)]

        # Warn if no tests found
        if not matches:
            print(_("Module %s does not contain any test cases; skipping.") % module)
            continue

        tstList.extend(matches)

    return tstList
if __name__ == "__main__":
    # Create a test suite
    PyKickstartTestSuite = unittest.TestSuite()

    # Locate tests relative to the launch directory.  $PWD is preferred (as
    # before), but when it is unset os.path.join(None, ...) would raise
    # TypeError, so fall back to the process working directory.
    baseDir = os.environ.get("PWD") or os.getcwd()

    # Find tests to add
    tstList = loadModules(os.path.join(baseDir, "tests/"))
    tstList.extend(loadModules(os.path.join(baseDir, "tests/commands")))
    for tst in tstList:
        PyKickstartTestSuite.addTest(tst())

    # Run the tests; unittest resolves the suite by name in __main__.
    unittest.main(defaultTest="PyKickstartTestSuite")