tests: Move repeated code into a helper function
[Samba.git] / python / samba / tests / __init__.py
blob61036b5247dc594b0df7c4283747bceda5e4db58
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
21 import os
22 import tempfile
23 import ldb
24 import samba
25 from samba import param
26 from samba import credentials
27 from samba.credentials import Credentials
28 from samba import gensec
29 import socket
30 import struct
31 import subprocess
32 import sys
33 import tempfile
34 import unittest
35 import re
36 import samba.auth
37 import samba.dcerpc.base
38 from samba.compat import PY3, text_type
39 from random import randint
40 if not PY3:
41 # Py2 only
42 from samba.samdb import SamDB
43 import samba.ndr
44 import samba.dcerpc.dcerpc
45 import samba.dcerpc.epmapper
# Use unittest's own SkipTest when it is available; otherwise define a
# stand-in exception with the same name so the rest of this module can raise
# and catch SkipTest unconditionally.
try:
    from unittest import SkipTest
except ImportError:
    class SkipTest(Exception):
        """Test skipped."""
# 256-entry translation table used by TestCase.hexdump(): a character whose
# repr() is exactly three characters long (quote, char, quote) maps to itself,
# everything else (escapes such as '\n' or '\x00') maps to '.'.
HEXDUMP_FILTER = ''.join(
    chr(code) if len(repr(chr(code))) == 3 else '.'
    for code in range(256)
)
class TestCase(unittest.TestCase):
    """A Samba test case.

    Extends unittest.TestCase with helpers for obtaining the test loadparm
    and credentials, hex-dumping data, cloning credentials and comparing
    strings with a diff, plus fallbacks for unittest features that did not
    exist before Python 2.7.
    """

    def setUp(self):
        """Apply the debug level given in $TEST_DEBUG_LEVEL, if it is set.

        The level that was active beforehand is remembered in
        self._old_debug_level and restored by a cleanup once the test ends.
        """
        super(TestCase, self).setUp()
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Restore the previous level; re-applying test_debug_level here
            # would leave the override in place after the test.
            self.addCleanup(samba.set_debug_level, self._old_debug_level)

    def get_loadparm(self):
        """Return the loadparm object produced by env_loadparm()."""
        return env_loadparm()

    def get_credentials(self):
        """Return the module-level cmdline_credentials object."""
        return cmdline_credentials

    def get_creds_ccache_name(self):
        """Return the name of the named ccache of the test credentials."""
        creds = self.get_credentials()
        ccache = creds.get_named_ccache(self.get_loadparm())
        ccache_name = ccache.get_name()

        return ccache_name

    def hexdump(self, src):
        """Return a printable hex dump of src.

        Each output line covers 16 items of src: the offset in hex, the
        first and second group of 8 items as two-digit hex values, then the
        same two groups as text filtered through HEXDUMP_FILTER.
        """
        N = 0
        lines = []
        while src:
            ll = src[:8]
            lr = src[8:16]
            src = src[16:]
            hl = ' '.join(["%02X" % ord(x) for x in ll])
            hr = ' '.join(["%02X" % ord(x) for x in lr])
            ll = ll.translate(HEXDUMP_FILTER)
            lr = lr.translate(HEXDUMP_FILTER)
            lines.append("[%04X] %-*s %-*s %s %s\n"
                         % (N, 8*3, hl, 8*3, hr, ll, lr))
            N += 16
        return ''.join(lines)

    def insta_creds(self, template=None, username=None, userpass=None, kerberos_state=None):
        """Build a new Credentials object derived from a template.

        :param template: Credentials to copy domain, realm and workstation
            from (and username/password/kerberos state when not overridden).
            Required.
        :param username: Optional username; must be given together with
            userpass.
        :param userpass: Optional password; must be given together with
            username.
        :param kerberos_state: Optional kerberos state; defaults to the
            template's state.
        :return: a fresh Credentials with gensec.FEATURE_SEAL added to its
            gensec features.
        :raises AssertionError: if template is missing or only one of
            username/userpass is supplied.
        """
        # Explicit raises rather than assert statements, so the checks are
        # still performed when Python runs with -O.
        if template is None:
            raise AssertionError("insta_creds() requires a template")

        if username is not None:
            if userpass is None:
                raise AssertionError("userpass is required with username")
        else:
            if userpass is not None:
                raise AssertionError("userpass given without username")
            username = template.get_username()
            userpass = template.get_password()

        if kerberos_state is None:
            kerberos_state = template.get_kerberos_state()

        # Build a copy based on the template credentials.
        c = Credentials()
        c.set_username(username)
        c.set_password(userpass)
        c.set_domain(template.get_domain())
        c.set_realm(template.get_realm())
        c.set_workstation(template.get_workstation())
        c.set_gensec_features(c.get_gensec_features()
                              | gensec.FEATURE_SEAL)
        c.set_kerberos_state(kerberos_state)
        return c

    # These functions didn't exist before Python2.7:
    if sys.version_info < (2, 7):
        import warnings

        def skipTest(self, reason):
            raise SkipTest(reason)

        def assertIn(self, member, container, msg=None):
            self.assertTrue(member in container, msg)

        def assertIs(self, a, b, msg=None):
            self.assertTrue(a is b, msg)

        def assertIsNot(self, a, b, msg=None):
            self.assertTrue(a is not b, msg)

        def assertIsNotNone(self, a, msg=None):
            # Forward msg like the sibling helpers do.
            self.assertTrue(a is not None, msg)

        def assertIsInstance(self, a, b, msg=None):
            self.assertTrue(isinstance(a, b), msg)

        def assertIsNone(self, a, msg=None):
            self.assertTrue(a is None, msg)

        def assertGreater(self, a, b, msg=None):
            self.assertTrue(a > b, msg)

        def assertGreaterEqual(self, a, b, msg=None):
            self.assertTrue(a >= b, msg)

        def assertLess(self, a, b, msg=None):
            self.assertTrue(a < b, msg)

        def assertLessEqual(self, a, b, msg=None):
            self.assertTrue(a <= b, msg)

        def addCleanup(self, fn, *args, **kwargs):
            self._cleanups = getattr(self, "_cleanups", []) + [
                (fn, args, kwargs)]

        def assertRegexpMatches(self, text, regex, msg=None):
            # PY3 note: Python 3 will never see this, but we use
            # text_type for the benefit of linters.
            if isinstance(regex, (str, text_type)):
                regex = re.compile(regex)
            if not regex.search(text):
                self.fail(msg)

        def _addSkip(self, result, reason):
            """Report a skip, or a success if result cannot record skips."""
            addSkip = getattr(result, 'addSkip', None)
            if addSkip is not None:
                addSkip(self, reason)
            else:
                warnings.warn("TestResult has no addSkip method, skips not reported",
                              RuntimeWarning, 2)
                result.addSuccess(self)

        def run(self, result=None):
            """Run setUp, the test method, tearDown and the cleanups.

            Statement order matters here: a skip or error in setUp returns
            before the test method runs, while tearDown and the cleanups
            still run after a failing test method.
            """
            if result is None:
                result = self.defaultTestResult()
            result.startTest(self)
            testMethod = getattr(self, self._testMethodName)
            try:
                try:
                    self.setUp()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    return

                ok = False
                try:
                    testMethod()
                    ok = True
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except self.failureException:
                    result.addFailure(self, self._exc_info())
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())

                try:
                    self.tearDown()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    ok = False

                # Cleanups run in reverse order of registration.
                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
                    fn(*args, **kwargs)
                if ok:
                    result.addSuccess(self)
            finally:
                result.stopTest(self)

    def assertStringsEqual(self, a, b, msg=None, strip=False):
        """Assert equality between two strings and highlight any differences.

        If strip is true, leading and trailing whitespace is ignored.
        On mismatch a unified diff is written to stderr before failing.
        """
        if strip:
            a = a.strip()
            b = b.strip()

        if a != b:
            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
                             "a diff follows\n"
                             % ('when stripped ' if strip else '',
                                len(a), len(b)))

            from difflib import unified_diff
            diff = unified_diff(a.splitlines(True),
                                b.splitlines(True),
                                'a', 'b')
            for line in diff:
                sys.stderr.write(line)

            self.fail(msg)
class LdbTestCase(TestCase):
    """Trivial test case for running tests against a LDB."""

    def setUp(self):
        """Create a temporary file and open it as self.ldb."""
        super(LdbTestCase, self).setUp()
        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
        self.filename = self.tempfile.name
        # delete=False means nothing removes the file automatically, so
        # register a cleanup to avoid leaking one file per test.
        self.addCleanup(self._remove_tempfile)
        self.ldb = samba.Ldb(self.filename)

    def _remove_tempfile(self):
        """Close the temporary file handle and remove the file."""
        self.tempfile.close()
        try:
            os.unlink(self.filename)
        except OSError:
            # Best effort: a test may already have removed the file itself.
            pass

    def set_modules(self, modules=None):
        """Change the modules for this Ldb.

        :param modules: iterable of module names; None means no modules.
        """
        if modules is None:
            modules = []
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, "@MODULES")
        m["@LIST"] = ",".join(modules)
        self.ldb.add(m)
        # Re-open the database so the new module list takes effect.
        self.ldb = samba.Ldb(self.filename)
class TestCaseInTempDir(TestCase):
    """Test case that gets a fresh temporary directory in self.tempdir."""

    def setUp(self):
        """Create self.tempdir and schedule its removal."""
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        """Check that the test left the directory empty, then remove it."""
        # assertEqual rather than the deprecated assertEquals alias, which
        # no longer exists in Python 3.12.
        self.assertEqual([], os.listdir(self.tempdir))
        os.rmdir(self.tempdir)
        self.tempdir = None
def env_loadparm():
    """Return a LoadParm loaded from the file named by $SMB_CONF_PATH.

    :return: the loaded param.LoadParm object
    :raises KeyError: if SMB_CONF_PATH is not set in the environment
    """
    # Guard only the environment lookup, so that a KeyError raised while
    # loading the file is not misreported as a missing variable.
    try:
        smb_conf_path = os.environ["SMB_CONF_PATH"]
    except KeyError:
        raise KeyError("SMB_CONF_PATH not set")
    lp = param.LoadParm()
    lp.load(smb_conf_path)
    return lp
def env_get_var_value(var_name, allow_missing=False):
    """Return the value of a variable from os.environ.

    Unit-test based python tests require certain input params
    to be set in environment, otherwise they can't be run.

    :param var_name: name of the environment variable
    :param allow_missing: if true, return None instead of failing when the
        variable is not defined
    :return: the variable's value, or None (see allow_missing)
    :raises AssertionError: if the variable is not defined and
        allow_missing is false
    """
    try:
        return os.environ[var_name]
    except KeyError:
        if allow_missing:
            return None
        # Raised explicitly (not via assert) so the check survives python -O.
        raise AssertionError("Please supply %s in environment" % var_name)
# Credentials returned by TestCase.get_credentials() and used by
# connect_samdb() when no credentials are passed. It is None here; whatever
# assigns the real value is not visible in this part of the file.
cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """DCE/RPC Test case.

    Adds nothing to TestCase itself; it only provides a distinct base class.
    """
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        """A short plain name is accepted."""
        accepted = samba.valid_netbios_name("FOO")
        self.assertTrue(accepted)

    def test_too_long(self):
        """A 30-character name is rejected."""
        accepted = samba.valid_netbios_name("FOO"*10)
        self.assertFalse(accepted)

    def test_invalid_characters(self):
        """A name containing '*' is rejected."""
        accepted = samba.valid_netbios_name("*BLA")
        self.assertFalse(accepted)
class BlackboxProcessError(Exception):
    """Raised when a blackbox command exits with an unexpected status.

    The instance records the exit code (S.returncode), the command line
    (S.cmd), the captured output (S.stdout), the captured error stream
    (S.stderr) and an optional extra message (S.msg).
    """

    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
        self.returncode, self.cmd = returncode, cmd
        self.stdout, self.stderr = stdout, stderr
        self.msg = msg

    def __str__(self):
        summary = ("Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" %
                   (self.cmd, self.returncode, self.stdout, self.stderr))
        if self.msg is None:
            return summary
        # Append the caller-supplied message when there is one.
        return "%s; message: %s" % (summary, self.msg)
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests."""

    def _make_cmdline(self, line):
        """Return line with its first word resolved against the build's bin dir.

        If a file named like the first word exists in ../../../../bin
        (relative to this file), that word is replaced by its full path;
        otherwise line is returned unchanged.
        """
        here = os.path.dirname(__file__)
        bindir = os.path.abspath(os.path.join(here, "../../../../bin"))
        words = line.split(" ")
        candidate = os.path.join(bindir, words[0])
        if os.path.exists(candidate):
            words[0] = candidate
        return " ".join(words)

    def check_run(self, line, msg=None):
        """Run line through the shell and require exit status 0."""
        self.check_exit_code(line, 0, msg=msg)

    def check_exit_code(self, line, expected, msg=None):
        """Run line through the shell and require exit status expected.

        :raises BlackboxProcessError: if the exit status differs
        """
        cmdline = self._make_cmdline(line)
        proc = subprocess.Popen(cmdline,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        out, err = proc.communicate()
        status = proc.returncode
        if status == expected:
            return
        raise BlackboxProcessError(status, cmdline, out, err, msg)

    def check_output(self, line):
        """Run line through the shell and return what it wrote to stdout.

        :raises BlackboxProcessError: if the exit status is non-zero
        """
        cmdline = self._make_cmdline(line)
        proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, shell=True,
                                close_fds=True)
        out, err = proc.communicate()
        status = proc.returncode
        if status:
            raise BlackboxProcessError(status, cmdline, out, err)
        return out
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
                  flags=0, ldb_options=None, ldap_only=False, global_schema=True):
    """Create a SamDB instance connected to samdb_url.

    This is a shorthand for tests: it builds a proper URL for the
    connection and fills in defaults taken from the test environment.

    :param samdb_url: Url for database to connect to. Without a "://" it
        becomes a tdb:// URL when it names an existing file (and ldap_only
        is not set), otherwise an ldap:// URL.
    :param lp: Optional loadparm object, defaults to env_loadparm()
    :param session_info: Optional session information, defaults to the
        system session for lp
    :param credentials: Optional credentials, defaults to
        cmdline_credentials
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options; replaced by
        ["modules:paged_searches"] for ldap:// URLs
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema.
    :raises AssertionError: if ldap_only is set but the URL is not ldap://
    """
    if "://" not in samdb_url:
        is_local_file = (not ldap_only) and os.path.isfile(samdb_url)
        if is_local_file:
            samdb_url = "tdb://%s" % samdb_url
        else:
            samdb_url = "ldap://%s" % samdb_url

    is_remote = samdb_url.startswith("ldap://")
    if ldap_only and not is_remote:
        raise AssertionError("Trying to connect to %s while remote "
                             "connection is required" % samdb_url)
    if is_remote:
        # use the 'paged_searches' module when connecting remotely
        ldb_options = ["modules:paged_searches"]

    # set defaults for test environment
    if lp is None:
        lp = env_loadparm()
    if session_info is None:
        session_info = samba.auth.system_session(lp)
    if credentials is None:
        credentials = cmdline_credentials

    connection_args = dict(url=samdb_url,
                           lp=lp,
                           session_info=session_info,
                           credentials=credentials,
                           flags=flags,
                           options=ldb_options,
                           global_schema=global_schema)
    return SamDB(**connection_args)
def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
                     flags=0, ldb_options=None, ldap_only=False):
    """Connect to samdb_url and also fetch its RootDSE record.

    :param samdb_url: Url for database to connect to.
    :param lp: Optional loadparm object
    :param session_info: Optional session information
    :param credentials: Optional credentials
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options
    :param ldap_only: If set, only remote LDAP connection will be created.
    :return: (sam_db_connection, rootDse_record) tuple
    """
    sam_db = connect_samdb(samdb_url,
                           lp=lp,
                           session_info=session_info,
                           credentials=credentials,
                           flags=flags,
                           ldb_options=ldb_options,
                           ldap_only=ldap_only)
    # A base-scope search on the empty DN returns the RootDSE.
    root_dse_result = sam_db.search(base="", expression="",
                                    scope=ldb.SCOPE_BASE, attrs=["*"])
    root_dse = root_dse_result[0]
    return (sam_db, root_dse)
def connect_samdb_env(env_url, env_username, env_password, lp=None):
    """Connect to SamDB by getting URL and Credentials from environment

    :param env_url: Environment variable name to get ldb url from
    :param env_username: Username environment variable
    :param env_password: Password environment variable
    :param lp: Optional loadparm object; a fresh param.LoadParm() is
        created when it is not given
    :return: sam_db_connection
    """
    samdb_url = env_get_var_value(env_url)
    creds = credentials.Credentials()
    if lp is None:
        # guess Credentials parameters here. Otherwise workstation
        # and domain fields are NULL and gencache code segfaults
        # NOTE(review): guess() is only reached when no lp was passed in;
        # confirm that callers supplying lp do not need it as well.
        lp = param.LoadParm()
        creds.guess(lp)
    creds.set_username(env_get_var_value(env_username))
    creds.set_password(env_get_var_value(env_password))
    return connect_samdb(samdb_url, credentials=creds, lp=lp)
def delete_force(samdb, dn, **kwargs):
    """Delete dn from samdb, ignoring the case where it does not exist.

    :param samdb: database object providing delete()
    :param dn: DN of the object to delete
    :param kwargs: passed through to samdb.delete()
    :raises AssertionError: if the delete fails with any error other than
        ldb.ERR_NO_SUCH_OBJECT
    """
    try:
        samdb.delete(dn, **kwargs)
    except ldb.LdbError as error:
        (num, errstr) = error.args
        # Raised explicitly rather than with an assert statement: under
        # python -O an assert would vanish and every error would be
        # silently swallowed.
        if num != ldb.ERR_NO_SUCH_OBJECT:
            raise AssertionError("ldb.delete() failed: %s" % errstr)
def create_test_ou(samdb, name):
    """Create a uniquely named OU under the default base DN.

    :param samdb: database object to add the OU to
    :param name: prefix for the OU name; a random number is appended
    :return: the DN of the new OU
    """
    # The random suffix matters because replication between the testenvs
    # runs constantly in the background and deletions of an earlier test's
    # objects can be slow to replicate out, so an OU created by a previous
    # testenv may still exist when tests start on another testenv.
    suffix = randint(1, 10000000)
    base_dn = samdb.get_default_basedn()
    ou_dn = "OU=%s%d,%s" % (name, suffix, base_dn)
    record = {"dn": ou_dn, "objectclass": "organizationalUnit"}
    samdb.add(record)
    return ou_dn