1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Jelmer Vernooij <jelmer@samba.org> 2007-2010
3 # Copyright (C) Stefan Metzmacher 2014,2015
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 """Samba Python tests."""
25 from samba
import param
26 from samba
import credentials
27 from samba
.credentials
import Credentials
28 from samba
import gensec
37 import samba
.dcerpc
.base
38 from samba
.compat
import PY3
, text_type
39 from random
import randint
42 from samba
.samdb
import SamDB
44 import samba
.dcerpc
.dcerpc
45 import samba
.dcerpc
.epmapper
try:
    from unittest import SkipTest
except ImportError:
    # Python < 2.7: unittest has no SkipTest, so provide a stand-in that the
    # compatibility TestCase.run() in this module catches.
    class SkipTest(Exception):
        """Test skipped."""
# 256-entry translation table used by TestCase.hexdump(): a character whose
# repr() is just the quoted character itself (length 3, e.g. 'a') is kept,
# every character that repr() escapes is mapped to '.'.
HEXDUMP_FILTER = ''.join(
    chr(code) if len(repr(chr(code))) == 3 else '.'
    for code in range(256)
)
class TestCase(unittest.TestCase):
    """A Samba test case."""

    def setUp(self):
        super(TestCase, self).setUp()
        test_debug_level = os.getenv("TEST_DEBUG_LEVEL")
        if test_debug_level is not None:
            test_debug_level = int(test_debug_level)
            self._old_debug_level = samba.get_debug_level()
            samba.set_debug_level(test_debug_level)
            # Restore the level that was active before this test; resetting
            # to test_debug_level here would leak the override into every
            # test that runs afterwards.
            self.addCleanup(samba.set_debug_level, self._old_debug_level)

    def get_loadparm(self):
        """Return a LoadParm for the test environment (see env_loadparm())."""
        return env_loadparm()

    def get_credentials(self):
        """Return the module-level cmdline_credentials (may be None)."""
        return cmdline_credentials

    def get_creds_ccache_name(self):
        """Return the name of the named ccache of get_credentials()."""
        creds = self.get_credentials()
        ccache = creds.get_named_ccache(self.get_loadparm())
        ccache_name = ccache.get_name()

        return ccache_name

    def hexdump(self, src):
        """Format src as a hex dump, 16 items per row.

        Each row is "[offset] <8 hex codes> <8 hex codes> <chars> <chars>",
        where the character columns are src mapped through HEXDUMP_FILTER.
        NOTE(review): ord() is applied to every item, so src must be a
        sequence of characters (a str); Python 3 bytes would fail here.
        """
        N = 0
        result = ''
        while src:
            ll = src[:8]
            lr = src[8:16]
            src = src[16:]
            hl = ' '.join(["%02X" % ord(x) for x in ll])
            hr = ' '.join(["%02X" % ord(x) for x in lr])
            ll = ll.translate(HEXDUMP_FILTER)
            lr = lr.translate(HEXDUMP_FILTER)
            result += "[%04X] %-*s %-*s %s %s\n" % (N, 8*3, hl, 8*3, hr, ll, lr)
            N += 16
        return result

    def insta_creds(self, template=None, username=None, userpass=None,
                    kerberos_state=None):
        """Return a new Credentials object derived from template.

        :param template: Credentials to copy domain, realm, workstation and
            (by default) username/password/kerberos state from. Required.
        :param username: Optional username; must be given together with
            userpass, or both omitted to use the template's values.
        :param userpass: Optional password (see username).
        :param kerberos_state: Optional kerberos state; defaults to the
            template's.
        :return: the new Credentials, with gensec.FEATURE_SEAL added.
        """
        assert template is not None

        if username is not None:
            assert userpass is not None
        else:
            assert userpass is None
            username = template.get_username()
            userpass = template.get_password()

        if kerberos_state is None:
            kerberos_state = template.get_kerberos_state()

        # get a copy of the global creds or the passed in creds
        c = Credentials()
        c.set_username(username)
        c.set_password(userpass)
        c.set_domain(template.get_domain())
        c.set_realm(template.get_realm())
        c.set_workstation(template.get_workstation())
        c.set_gensec_features(c.get_gensec_features()
                              | gensec.FEATURE_SEAL)
        c.set_kerberos_state(kerberos_state)
        return c

    # These functions didn't exist before Python2.7:
    if sys.version_info < (2, 7):

        def skipTest(self, reason):
            raise SkipTest(reason)

        def assertIn(self, member, container, msg=None):
            self.assertTrue(member in container, msg)

        def assertIs(self, a, b, msg=None):
            self.assertTrue(a is b, msg)

        def assertIsNot(self, a, b, msg=None):
            self.assertTrue(a is not b, msg)

        def assertIsNotNone(self, a, msg=None):
            # Pass msg through like the other compat assertions do.
            self.assertTrue(a is not None, msg)

        def assertIsInstance(self, a, b, msg=None):
            self.assertTrue(isinstance(a, b), msg)

        def assertIsNone(self, a, msg=None):
            self.assertTrue(a is None, msg)

        def assertGreater(self, a, b, msg=None):
            self.assertTrue(a > b, msg)

        def assertGreaterEqual(self, a, b, msg=None):
            self.assertTrue(a >= b, msg)

        def assertLess(self, a, b, msg=None):
            self.assertTrue(a < b, msg)

        def assertLessEqual(self, a, b, msg=None):
            self.assertTrue(a <= b, msg)

        def addCleanup(self, fn, *args, **kwargs):
            # Cleanups are run in reverse order of registration by run().
            self._cleanups = getattr(self, "_cleanups", []) + [
                (fn, args, kwargs)]

        def assertRegexpMatches(self, text, regex, msg=None):
            # PY3 note: Python 3 will never see this, but we use
            # text_type for the benefit of linters.
            if isinstance(regex, (str, text_type)):
                regex = re.compile(regex)
            if not regex.search(text):
                self.fail(msg)

        def _addSkip(self, result, reason):
            """Report a skip, falling back to success if result can't."""
            addSkip = getattr(result, 'addSkip', None)
            if addSkip is not None:
                addSkip(self, reason)
            else:
                # Imported here: a class-body import would not be visible
                # as a global name inside this method.
                import warnings
                warnings.warn("TestResult has no addSkip method, skips not reported",
                              RuntimeWarning, 2)
                result.addSuccess(self)

        def run(self, result=None):
            """Minimal unittest run() for Python < 2.7 with SkipTest support."""
            if result is None:
                result = self.defaultTestResult()
            result.startTest(self)
            testMethod = getattr(self, self._testMethodName)
            try:
                try:
                    self.setUp()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    return

                ok = False
                try:
                    testMethod()
                    ok = True
                except SkipTest as e:
                    self._addSkip(result, str(e))
                    return
                except self.failureException:
                    result.addFailure(self, self._exc_info())
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())

                try:
                    self.tearDown()
                except SkipTest as e:
                    self._addSkip(result, str(e))
                except KeyboardInterrupt:
                    raise
                except:
                    result.addError(self, self._exc_info())
                    ok = False

                for (fn, args, kwargs) in reversed(getattr(self, "_cleanups", [])):
                    fn(*args, **kwargs)
                if ok:
                    result.addSuccess(self)
            finally:
                result.stopTest(self)

    def assertStringsEqual(self, a, b, msg=None, strip=False):
        """Assert equality between two strings and highlight any differences.
        If strip is true, leading and trailing whitespace is ignored.

        On mismatch a unified diff of a vs b is written to stderr before
        failing with msg.
        """
        if strip:
            a = a.strip()
            b = b.strip()

        if a != b:
            sys.stderr.write("The strings differ %s(lengths %d vs %d); "
                             "a diff follows\n"
                             % ('when stripped ' if strip else '',
                                len(a), len(b),
                                ))

            from difflib import unified_diff
            diff = unified_diff(a.splitlines(True),
                                b.splitlines(True),
                                'a', 'b')
            for line in diff:
                sys.stderr.write(line)

            self.fail(msg)
class LdbTestCase(TestCase):
    """Trivial test case for running tests against a LDB."""

    def setUp(self):
        super(LdbTestCase, self).setUp()
        # NOTE(review): the temporary file is created with delete=False and
        # its handle is kept open on self.tempfile; nothing here removes it.
        self.tempfile = tempfile.NamedTemporaryFile(delete=False)
        self.filename = self.tempfile.name
        self.ldb = samba.Ldb(self.filename)

    def set_modules(self, modules=None):
        """Change the modules for this Ldb.

        Writes an @MODULES record whose @LIST is the comma-joined module
        names, then reopens self.ldb on the same file.

        :param modules: Iterable of module names; defaults to none.
        """
        # None instead of a shared mutable [] default.
        if modules is None:
            modules = []
        m = ldb.Message()
        m.dn = ldb.Dn(self.ldb, "@MODULES")
        m["@LIST"] = ",".join(modules)
        self.ldb.add(m)
        # Reopen, presumably so the new module list is picked up.
        self.ldb = samba.Ldb(self.filename)
class TestCaseInTempDir(TestCase):
    """Test case that gets a fresh temporary directory in self.tempdir.

    The directory must be empty again by cleanup time: any leftover entry
    fails the test instead of being deleted.
    """

    def setUp(self):
        super(TestCaseInTempDir, self).setUp()
        self.tempdir = tempfile.mkdtemp()
        self.addCleanup(self._remove_tempdir)

    def _remove_tempdir(self):
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual([], os.listdir(self.tempdir))
        os.rmdir(self.tempdir)
def env_loadparm():
    """Return a param.LoadParm loaded from the file named by $SMB_CONF_PATH.

    :raises KeyError: if SMB_CONF_PATH is not set in the environment.
    """
    # Look the variable up on its own, so only a missing SMB_CONF_PATH is
    # reported as "SMB_CONF_PATH not set"; errors raised while loading the
    # file itself propagate unchanged.
    try:
        smb_conf_path = os.environ["SMB_CONF_PATH"]
    except KeyError:
        raise KeyError("SMB_CONF_PATH not set")
    lp = param.LoadParm()
    lp.load(smb_conf_path)
    return lp
def env_get_var_value(var_name, allow_missing=False):
    """Returns value for variable in os.environ

    Function throws AssertionError if variable is not defined, unless
    allow_missing is set, in which case None is returned instead.
    Unit-test based python tests require certain input params
    to be set in environment, otherwise they can't be run
    """
    if var_name in os.environ:
        return os.environ[var_name]
    if allow_missing:
        return None
    # Raise explicitly so the check survives `python -O`, which strips
    # assert statements.
    raise AssertionError("Please supply %s in environment" % var_name)
# Credentials returned by TestCase.get_credentials() and used as the default
# credentials by connect_samdb(). None here; presumably assigned by the test
# runner before tests execute — confirm.
cmdline_credentials = None
class RpcInterfaceTestCase(TestCase):
    """DCE/RPC Test case.

    Adds nothing to TestCase here; presumably a common base class for
    DCE/RPC interface tests — confirm against its subclasses.
    """
class ValidNetbiosNameTests(TestCase):
    """Checks for samba.valid_netbios_name()."""

    def test_valid(self):
        name = "FOO"
        self.assertTrue(samba.valid_netbios_name(name))

    def test_too_long(self):
        # 30 characters.
        name = "FOO" * 10
        self.assertFalse(samba.valid_netbios_name(name))

    def test_invalid_characters(self):
        name = "*BLA"
        self.assertFalse(samba.valid_netbios_name(name))
class BlackboxProcessError(Exception):
    """Raised when a blackbox command exits with an unexpected status.

    Raised by BlackboxTestCase.check_output() (non-zero exit) and
    check_exit_code() (exit code other than the expected one). The
    instance carries the exit code (S.returncode), the command line
    (S.cmd), the process output (S.stdout), the process error stream
    (S.stderr) and an optional caller message (S.msg).
    """

    def __init__(self, returncode, cmd, stdout, stderr, msg=None):
        self.returncode = returncode
        self.cmd = cmd
        self.stdout = stdout
        self.stderr = stderr
        self.msg = msg

    def __str__(self):
        details = (self.cmd, self.returncode, self.stdout, self.stderr)
        text = "Command '%s'; exit status %d; stdout: '%s'; stderr: '%s'" % details
        if self.msg is None:
            return text
        return "%s; message: %s" % (text, self.msg)
class BlackboxTestCase(TestCaseInTempDir):
    """Base test case for blackbox tests."""

    def _make_cmdline(self, line):
        """Return line with its first word resolved against the build bin/.

        If a file named like the first space-separated word exists in
        ../../../../bin relative to this module, that word is replaced by
        its absolute path there; otherwise line is returned unchanged.
        """
        bin_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "../../../../bin"))
        words = line.split(" ")
        local_binary = os.path.join(bin_dir, words[0])
        if os.path.exists(local_binary):
            words[0] = local_binary
        return " ".join(words)

    def check_run(self, line, msg=None):
        """Run line and require exit status 0 (see check_exit_code())."""
        self.check_exit_code(line, 0, msg=msg)

    def check_exit_code(self, line, expected, msg=None):
        """Run line through the shell and require the given exit status.

        :raises BlackboxProcessError: if the exit status is not expected.
        """
        cmdline = self._make_cmdline(line)
        # NOTE: shell=True — the command line is shell-interpreted, so only
        # pass trusted, test-controlled strings.
        proc = subprocess.Popen(cmdline,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True)
        out, err = proc.communicate()
        exit_status = proc.returncode
        if exit_status == expected:
            return
        raise BlackboxProcessError(exit_status, cmdline, out, err, msg)

    def check_output(self, line):
        """Run line through the shell and return its captured stdout.

        :raises BlackboxProcessError: if the command exits non-zero.
        """
        cmdline = self._make_cmdline(line)
        # NOTE: shell=True — see check_exit_code().
        proc = subprocess.Popen(cmdline,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                shell=True,
                                close_fds=True)
        out, err = proc.communicate()
        exit_status = proc.returncode
        if not exit_status:
            return out
        raise BlackboxProcessError(exit_status, cmdline, out, err)
def connect_samdb(samdb_url, lp=None, session_info=None, credentials=None,
                  flags=0, ldb_options=None, ldap_only=False, global_schema=True):
    """Create SamDB instance and connects to samdb_url database.

    :param samdb_url: Url for database to connect to. A value without a
        "://" scheme is treated as a local tdb file if such a file exists
        (and ldap_only is not set), otherwise as an LDAP host.
    :param lp: Optional loadparm object; defaults to env_loadparm().
    :param session_info: Optional session information; defaults to
        samba.auth.system_session(lp).
    :param credentials: Optional credentials; defaults to the module-level
        cmdline_credentials (not anonymous).
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options. For ldap:// URLs these are
        replaced by ["modules:paged_searches"].
    :param ldap_only: If set, only remote LDAP connection will be created.
    :param global_schema: Whether to use global schema.
    :return: SamDB connection
    :raises AssertionError: if ldap_only is set but the URL is not ldap://.

    Added value for tests is that we have a shorthand function
    to make proper URL for ldb.connect() while using default
    parameters for connection based on test environment
    """
    if "://" not in samdb_url:
        if not ldap_only and os.path.isfile(samdb_url):
            samdb_url = "tdb://%s" % samdb_url
        else:
            samdb_url = "ldap://%s" % samdb_url
    # use 'paged_search' module when connecting remotely
    if samdb_url.startswith("ldap://"):
        # NOTE(review): this discards any ldb_options the caller passed for
        # an ldap:// URL — confirm that is intended.
        ldb_options = ["modules:paged_searches"]
    elif ldap_only:
        raise AssertionError("Trying to connect to %s while remote "
                             "connection is required" % samdb_url)

    # set defaults for test environment
    if lp is None:
        lp = env_loadparm()
    if session_info is None:
        session_info = samba.auth.system_session(lp)
    if credentials is None:
        credentials = cmdline_credentials

    return SamDB(url=samdb_url,
                 lp=lp,
                 session_info=session_info,
                 credentials=credentials,
                 flags=flags,
                 options=ldb_options,
                 global_schema=global_schema)
def connect_samdb_ex(samdb_url, lp=None, session_info=None, credentials=None,
                     flags=0, ldb_options=None, ldap_only=False):
    """Connect to samdb_url and also fetch its rootDSE record.

    Arguments have the same meaning as for connect_samdb(); in particular
    credentials default to the module-level cmdline_credentials.

    :param samdb_url: Url for database to connect to.
    :param lp: Optional loadparm object
    :param session_info: Optional session information
    :param credentials: Optional credentials
    :param flags: Optional LDB flags
    :param ldb_options: Optional LDB options
    :param ldap_only: If set, only remote LDAP connection will be created.
    :return: (sam_db_connection, rootDse_record) tuple
    """
    sam_db = connect_samdb(samdb_url, lp=lp, session_info=session_info,
                           credentials=credentials, flags=flags,
                           ldb_options=ldb_options, ldap_only=ldap_only)

    # rootDSE: base-scope search on the empty DN, all attributes.
    root_dse = sam_db.search(base="", expression="", scope=ldb.SCOPE_BASE,
                             attrs=["*"])[0]
    return (sam_db, root_dse)
def connect_samdb_env(env_url, env_username, env_password, lp=None):
    """Connect to SamDB by getting URL and Credentials from environment

    :param env_url: Environment variable name to get ldb url from
    :param env_username: Username environment variable
    :param env_password: Password environment variable
    :param lp: Optional loadparm object; a fresh param.LoadParm() is used
        when not given.
    :return: sam_db_connection
    """
    samdb_url = env_get_var_value(env_url)
    creds = credentials.Credentials()
    if lp is None:
        lp = param.LoadParm()

    # guess Credentials parameters here. Otherwise workstation
    # and domain fields are NULL and gencache code segfaults.
    # Done for a caller-supplied lp as well, not only the default one.
    creds.guess(lp)
    creds.set_username(env_get_var_value(env_username))
    creds.set_password(env_get_var_value(env_password))
    return connect_samdb(samdb_url, credentials=creds, lp=lp)
def delete_force(samdb, dn, **kwargs):
    """Delete dn from samdb, treating "no such object" as success.

    Extra keyword arguments are passed straight to samdb.delete(). Any
    other ldb.LdbError fails an assertion carrying the ldb error string
    (unless assertions are disabled with -O).
    """
    try:
        samdb.delete(dn, **kwargs)
    except ldb.LdbError as err:
        code, message = err.args
        assert code == ldb.ERR_NO_SUCH_OBJECT, "ldb.delete() failed: %s" % message
def create_test_ou(samdb, name):
    """Creates a unique OU for the test.

    Adds an organizationalUnit named name followed by a random integer in
    [1, 10000000] under samdb.get_default_basedn().

    :param samdb: SamDB connection to add the OU to.
    :param name: Prefix for the OU's name.
    """

    # Add some randomness to the test OU. Replication between the testenvs is
    # constantly happening in the background. Deletion of the last test's
    # objects can be slow to replicate out. So the OU created by a previous
    # testenv may still exist at the point that tests start on another testenv.
    rand = randint(1, 10000000)
    dn = "OU=%s%d,%s" %(name, rand, samdb.get_default_basedn())
    samdb.add({ "dn": dn, "objectclass": "organizationalUnit"})