smbd: Fix crossing automounter mount points
[Samba.git] / python / samba / tests / complex_expressions.py
blob 4cb6330c895514510fa15435f1c86aa91334c796
1 # -*- coding: utf-8 -*-
3 # Copyright Andrew Bartlett 2018
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 import optparse
20 import samba
21 import samba.getopt as options
22 import sys
23 import os
24 import time
25 from samba.auth import system_session
26 from samba.tests import TestCase
27 import ldb
# LDAP result codes (RFC 4511) checked by the tests below.
ERRCODE_ENTRY_EXISTS = 68       # entryAlreadyExists
ERRCODE_OPERATIONS_ERROR = 1    # operationsError
ERRCODE_INVALID_VALUE = 21      # invalidAttributeSyntax
ERRCODE_CLASS_VIOLATION = 65    # objectClassViolation

parser = optparse.OptionParser("{0} <host>".format(sys.argv[0]))
sambaopts = options.SambaOptions(parser)
# Register the Samba option group like the credentials group below, so its
# options are listed by --help as well as being accepted on the command line.
parser.add_option_group(sambaopts)

# use command line creds if available
credopts = options.CredentialsOptions(parser)
parser.add_option_group(credopts)
parser.add_option("-v", action="store_true", dest="verbose",
                  help="print successful expression outputs")
opts, args = parser.parse_args()

# The host (LDAP URL, hostname or tdb path) is a required positional arg.
if len(args) < 1:
    parser.print_usage()
    sys.exit(1)

lp = sambaopts.get_loadparm()
creds = credopts.get_credentials(lp)

# Set properly at end of file.
host = None

# Counter giving each test OU created by make_test_objects a unique name.
# Functions that modify it declare it ``global`` themselves; a ``global``
# statement at module scope would have no effect.
ou_count = 0
class ComplexExpressionTests(TestCase):
    """Compare LDAP comparison filters against equivalent python expressions.

    Tests create their own sub-OUs of ``user`` objects whose test attribute
    takes a known set of values, search that OU with an LDAP filter, and
    check the names returned against the names for which an equivalent
    python expression is true (see ``assertLDAPQuery``). The time taken by
    every search is printed.
    """

    # setUpClass rather than setUp: the connection and the base OU are
    # shared by all tests, and each test adds its objects to a fresh sub-OU
    # of its own, so tests never modify each other's records.
    @classmethod
    def setUpClass(cls):
        """Connect to ``host`` and (re)create the top-level test OU."""
        super().setUpClass()
        cls.samdb = samba.samdb.SamDB(host, lp=lp,
                                      session_info=system_session(),
                                      credentials=creds)

        ou_name = "ComplexExprTest"
        cls.base_dn = "OU={0},{1}".format(ou_name, cls.samdb.domain_dn())

        # Best-effort removal of an OU left over from an earlier run. Any
        # LdbError (e.g. the OU not existing) is ignored here; a genuine
        # problem will surface when the OU is created below.
        try:
            cls.samdb.delete(cls.base_dn, ["tree_delete:1"])
        except ldb.LdbError:
            pass

        try:
            cls.samdb.create_ou(cls.base_dn)
        except ldb.LdbError as e:
            if e.args[0] == ERRCODE_ENTRY_EXISTS:
                print(('test ou {ou} already exists. Delete with '
                       '"samba-tool ou delete OU={ou} '
                       '--force-subtree-delete"').format(ou=ou_name))
            raise

        cls.name_template = "testuser{0}"
        cls.default_n = 10

        # These fields are carefully hand-picked from the schema. They have
        # syntax and handling appropriate for our test structure.
        cls.largeint_f = "accountExpires"
        cls.str_f = "accountNameHistory"
        cls.int_f = "flags"
        cls.enum_f = "preferredDeliveryMethod"
        cls.time_f = "msTSExpireDate"
        cls.ranged_int_f = "countryCode"

    @classmethod
    def tearDownClass(cls):
        """Remove the test OU and everything the tests created beneath it."""
        try:
            cls.samdb.delete(cls.base_dn, ["tree_delete:1"])
        finally:
            super().tearDownClass()

    def make_test_objects(self, field, vals):
        """Create a new sub-OU holding one user per value in *vals*.

        The user for value ``n`` is named ``self.name_template.format(n)``
        and has *field* set to ``n``. Returns ``(ou_dn, ldap_objects)``,
        where ``ldap_objects`` is the list of attribute dicts that were
        added. The dicts keep the original python values (not strings) so
        they can be substituted into python expressions later.

        Raises ldb.LdbError if the OU or any object cannot be added.
        """
        global ou_count
        ou_count += 1
        ou_dn = "OU=testou{0},{1}".format(ou_count, self.base_dn)
        self.samdb.create_ou(ou_dn)

        ldap_objects = []
        for n in vals:
            name = self.name_template.format(n)
            ldap_objects.append({"dn": "CN={0},{1}".format(name, ou_dn),
                                 "name": name,
                                 "objectClass": "user",
                                 field: n})

        for ldap_object in ldap_objects:
            # It's useful to keep appropriate python types in the ldap_object
            # dict but samdb's 'add' function expects strings.
            stringed_ldap_object = {k: str(v)
                                    for k, v in ldap_object.items()}
            try:
                self.samdb.add(stringed_ldap_object)
            except ldb.LdbError:
                print("failed to add %s" % (stringed_ldap_object,))
                raise

        return ou_dn, ldap_objects

    def time_ldap_search(self, expr, dn):
        """Run a subtree search for *expr* under *dn*, printing its duration.

        This function should be used for almost all searching. Returns
        ``(res, seconds_taken)``. On failure the expression is printed and
        the original exception is re-raised.
        """
        try:
            # perf_counter is monotonic and meant for measuring durations,
            # unlike the wall-clock time.time().
            start_time = time.perf_counter()
            res = self.samdb.search(base=dn,
                                    scope=ldb.SCOPE_SUBTREE,
                                    expression=expr)
            time_taken = time.perf_counter() - start_time
        except Exception:
            print("failed expr " + expr)
            raise
        print("{0} took {1}s".format(expr, time_taken))
        return res, time_taken

    def assertLDAPQuery(self, ldap_expr, ou_dn, py_expr, ldap_objects):
        """Check an LDAP filter against an equivalent python expression.

        *py_expr* is a ``str.format`` template over the keys of each dict in
        *ldap_objects* (e.g. ``"{flags} >= 5"``); after substitution it is
        evaluated with ``eval``. The set of names returned by searching
        *ldap_expr* under *ou_dn* must equal the set of names of objects for
        which the python expression is true. Objects missing a key that
        *py_expr* references are expected not to match.
        """
        # run (and time) the LDAP search expression over the DB
        res, time_taken = self.time_ldap_search(ldap_expr, ou_dn)
        results = {str(row.get('name')[0]) for row in res}

        # build the set of expected results by evaluating the python-equivalent
        # of the search expression over the same set of objects
        expected_results = set()
        for ldap_object in ldap_objects:
            try:
                final_expr = py_expr.format(**ldap_object)
            except KeyError:
                # If the format on the py_expr hits a key error, then
                # ldap_object doesn't have the field, so it shouldn't match.
                continue

            # NOTE: eval() is acceptable only because both py_expr and the
            # substituted values are generated by these tests themselves.
            if eval(final_expr):
                expected_results.add(str(ldap_object['name']))

        self.assertEqual(results, expected_results)

        if opts.verbose:
            ldap_object_names = {str(obj['name']) for obj in ldap_objects}
            excluded = "\n ".join(ldap_object_names - results) or "[NOTHING]"
            returned = "\n ".join(expected_results) or "[NOTHING]"

            print("PASS: Expression {0} took {1}s and returned:"
                  "\n {2}\n"
                  "Excluded:\n {3}\n".format(ldap_expr,
                                             time_taken,
                                             returned,
                                             excluded))

    # Basic integer range test
    def test_int_range(self, field=None):
        n = self.default_n
        field = field or self.int_f
        ou_dn, ldap_objects = self.make_test_objects(field, range(n))

        expr = "(&(%s>=%s)(%s<=%s))" % (field, n-1, field, n+1)
        py_expr = "%d <= {%s} <= %d" % (n-1, field, n+1)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        half_n = n // 2

        expr = "(%s<=%s)" % (field, half_n)
        py_expr = "{%s} <= %d" % (field, half_n)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        expr = "(%s>=%s)" % (field, half_n)
        py_expr = "{%s} >= %d" % (field, half_n)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    # Same test again for largeint and enum
    def test_largeint_range(self):
        self.test_int_range(self.largeint_f)

    def test_enum_range(self):
        self.test_int_range(self.enum_f)

    # Special range test for integer field with upper and lower bounds defined.
    # The bounds are checked on insertion, not search, so we should be able
    # to compare to a constant that is outside bounds.
    def test_ranged_int_range(self):
        field = self.ranged_int_f
        ubound = 2**16
        width = 8

        vals = list(range(ubound-width, ubound))
        ou_dn, ldap_objects = self.make_test_objects(field, vals)

        # Check <= value above overflow returns all vals
        expr = "(%s<=%d)" % (field, ubound+5)
        py_expr = "{%s} <= %d" % (field, ubound+5)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    # Test range also works for time fields
    def test_time_range(self):
        n = self.default_n
        field = self.time_f
        width = n // 2

        # Time strings of the form "YYYYMMDDhhmmss.0Z" around base_time.
        base_time = 20050116175514
        time_range = [str(base_time + t) + ".0Z"
                      for t in range(-width, width)]
        ou_dn, ldap_objects = self.make_test_objects(field, time_range)

        # The python side strips the 3-character ".0Z" suffix and compares
        # the remaining digits as an integer.
        expr = "(%s<=%s)" % (field, str(base_time) + ".0Z")
        py_expr = 'int("{%s}"[:-3]) <= %d' % (field, base_time)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        expr = "(&(%s>=%s)(%s<=%s))" % (field, str(base_time-1) + ".0Z",
                                        field, str(base_time+1) + ".0Z")
        py_expr = '%d <= int("{%s}"[:-3]) <= %d' % (base_time-1,
                                                    field,
                                                    base_time+1)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    # Run each comparison op on a simple test set. Time taken will be printed.
    def test_int_single_cmp_op_speeds(self, field=None):
        n = self.default_n
        field = field or self.int_f
        ou_dn, ldap_objects = self.make_test_objects(field, range(n))

        # (LDAP operator, equivalent python operator)
        comp_ops = [('=', '=='), ('<=', '<='), ('>=', '>=')]
        for ldap_op, py_op in comp_ops:
            expr = "(%s%s%d)" % (field, ldap_op, n)
            py_expr = "{%s}%s%d" % (field, py_op, n)
            self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    def test_largeint_single_cmp_op_speeds(self):
        self.test_int_single_cmp_op_speeds(self.largeint_f)

    def test_enum_single_cmp_op_speeds(self):
        self.test_int_single_cmp_op_speeds(self.enum_f)

    # Check strings are ordered using a naive ordering.
    def test_str_ordering(self):
        field = self.str_f
        a_ord = ord('A')
        n = 10
        str_range = ['abc{0}d'.format(chr(c)) for c in range(a_ord, a_ord+n)]
        ou_dn, ldap_objects = self.make_test_objects(field, str_range)
        half_n = a_ord + n // 2

        # Basic <= and >= statements
        expr = "(%s>=abc%s)" % (field, chr(half_n))
        py_expr = "'{%s}' >= 'abc%s'" % (field, chr(half_n))
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        expr = "(%s<=abc%s)" % (field, chr(half_n))
        py_expr = "'{%s}' <= 'abc%s'" % (field, chr(half_n))
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        # String range
        expr = "(&(%s>=abc%s)(%s<=abc%s))" % (field, chr(half_n-2),
                                              field, chr(half_n+2))
        py_expr = "'abc%s' <= '{%s}' <= 'abc%s'" % (chr(half_n-2),
                                                    field,
                                                    chr(half_n+2))
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        # Integers treated as string
        expr = "(%s>=1)" % (field)
        py_expr = "'{%s}' >= '1'" % (field)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    # Windows returns nothing for invalid expressions. Expected fail on samba.
    def test_invalid_expressions(self, field=None):
        field = field or self.int_f
        n = self.default_n
        ou_dn, ldap_objects = self.make_test_objects(field, list(range(n)))
        int_expressions = ["(%s>=abc)",
                           "(%s<=abc)",
                           "(%s=abc)"]

        for template in int_expressions:
            expr = template % (field)
            self.assertLDAPQuery(expr, ou_dn, "False", ldap_objects)

    def test_largeint_invalid_expressions(self):
        self.test_invalid_expressions(self.largeint_f)

    def test_enum_invalid_expressions(self):
        self.test_invalid_expressions(self.enum_f)

    def test_case_insensitive(self):
        str_range = ["äbc"+str(n) for n in range(10)]
        ou_dn, ldap_objects = self.make_test_objects(self.str_f, str_range)

        expr = "(%s=äbc1)" % (self.str_f)
        pyexpr = '"{%s}"=="äbc1"' % (self.str_f)
        self.assertLDAPQuery(expr, ou_dn, pyexpr, ldap_objects)

        # Same match expected with different letter case in the filter.
        expr = "(%s=ÄbC1)" % (self.str_f)
        self.assertLDAPQuery(expr, ou_dn, pyexpr, ldap_objects)

    # Check negative numbers can be entered and compared
    def test_negative_cmp(self, field=None):
        field = field or self.int_f
        width = 6
        around_zero = list(range(-width, width))
        ou_dn, ldap_objects = self.make_test_objects(field, around_zero)

        expr = "(%s>=-3)" % (field)
        py_expr = "{%s} >= -3" % (field)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    def test_negative_cmp_largeint(self):
        self.test_negative_cmp(self.largeint_f)

    def test_negative_cmp_enum(self):
        self.test_negative_cmp(self.enum_f)

    # Check behaviour on insertion and comparison of zero-prefixed numbers.
    # Samba errors on insertion, Windows strips the leading zeroes.
    def test_zero_prefix(self, field=None):
        field = field or self.int_f

        # Test comparison with 0-prefixed constants.
        n = self.default_n
        half_n = n // 2
        ou_dn, ldap_objects = self.make_test_objects(field, list(range(n)))

        expr = "(%s>=00%d)" % (field, half_n)
        py_expr = "{%s} >= %d" % (field, half_n)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        # Delete the test OU so we don't mix it up with the next one.
        self.samdb.delete(ou_dn, ["tree_delete:1"])

        # Try inserting 0-prefixed numbers, check it fails.
        zero_pref_nums = ['00' + str(num) for num in range(n)]
        try:
            ou_dn, ldap_objects = self.make_test_objects(field, zero_pref_nums)
        except ldb.LdbError as e:
            if e.args[0] != ERRCODE_INVALID_VALUE:
                raise
            return

        # Samba doesn't get this far - the exception is raised. Windows allows
        # the insertion and removes the leading 0s as tested below.
        # Either behaviour is fine.
        print("LDAP allowed insertion of 0-prefixed nums for field " + field)

        res = self.samdb.search(base=ou_dn,
                                scope=ldb.SCOPE_SUBTREE,
                                expression="(objectClass=user)")
        returned_nums = [str(r.get(field)[0]) for r in res]
        expect = [str(i) for i in range(n)]
        self.assertEqual(set(returned_nums), set(expect))

        expr = "(%s>=%d)" % (field, half_n)
        py_expr = "{%s} >= %d" % (field, half_n)
        # The stored values had their zeroes stripped; compare numerically.
        for ldap_object in ldap_objects:
            ldap_object[field] = int(ldap_object[field])

        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

    def test_zero_prefix_largeint(self):
        self.test_zero_prefix(self.largeint_f)

    def test_zero_prefix_enum(self):
        self.test_zero_prefix(self.enum_f)

    # Check integer overflow is handled as best it can be.
    def test_int_overflow(self, field=None, of=None):
        field = field or self.int_f
        of = of or 2**31-1
        width = 8

        vals = list(range(of-width, of+width))
        ou_dn, ldap_objects = self.make_test_objects(field, vals)

        # Check ">=overflow" doesn't return vals past overflow
        expr = "(%s>=%d)" % (field, of-3)
        py_expr = "%d <= {%s} <= %d" % (of-3, field, of)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        # "<=overflow" returns everything
        expr = "(%s<=%d)" % (field, of)
        py_expr = "True"
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        # Values past overflow should be negative
        expr = "(&(%s<=%d)(%s>=0))" % (field, of, field)
        py_expr = "{%s} <= %d" % (field, of)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
        expr = "(%s<=0)" % (field)
        py_expr = "{%s} >= %d" % (field, of+1)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        # Get the values back out and check vals past overflow are negative.
        res = self.samdb.search(base=ou_dn,
                                scope=ldb.SCOPE_SUBTREE,
                                expression="(objectClass=user)")
        returned_nums = [str(r.get(field)[0]) for r in res]

        # Note: range(a,b) == [a..b-1] (confusing)
        up_to_overflow = list(range(of-width, of+1))
        negatives = list(range(-of-1, -of+width-2))

        expect = [str(num) for num in up_to_overflow + negatives]
        self.assertEqual(set(returned_nums), set(expect))

    def test_enum_overflow(self):
        self.test_int_overflow(self.enum_f, 2**31-1)

    # Check cmp works on uSNChanged. We can't insert uSNChanged vals, they get
    # added automatically so we'll just insert some objects and go with what
    # we get.
    def test_usnchanged(self):
        field = "uSNChanged"
        n = 10
        # Note we can't actually set uSNChanged via LDAP (LDB ignores it),
        # so the input val range doesn't matter here
        ou_dn, _ = self.make_test_objects(field, list(range(n)))

        # Get the assigned uSNChanged values
        res = self.samdb.search(base=ou_dn,
                                scope=ldb.SCOPE_SUBTREE,
                                expression="(objectClass=user)")

        # Our vals got ignored so make ldap_objects from search result
        ldap_objects = [{'name': str(r['name'][0]),
                         field: int(r[field][0])}
                        for r in res]

        # Get the median val and use as the number in the test search expr.
        nums = sorted(obj[field] for obj in ldap_objects)
        search_num = nums[len(nums) // 2]

        expr = "(&(%s<=%d)(objectClass=user))" % (field, search_num)
        py_expr = "{%s} <= %d" % (field, search_num)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)

        expr = "(&(%s>=%d)(objectClass=user))" % (field, search_num)
        py_expr = "{%s} >= %d" % (field, search_num)
        self.assertLDAPQuery(expr, ou_dn, py_expr, ldap_objects)
# If we're called independently then load the subunit test runner, get host
# from the first arg and run. Otherwise, subunit ran us so just set host from
# the SERVER environment variable.
# A bare hostname is turned into an ldap:// URL so that search timings are
# not impacted by opening and closing the tdb file; only an argument that is
# an existing local file is opened directly as tdb://.
if __name__ == "__main__":
    from samba.tests.subunitrun import TestProgram
    host = args[0]

    if "://" not in host:
        if os.path.isfile(host):
            host = "tdb://%s" % host
        else:
            host = "ldap://%s" % host
    TestProgram(module=__name__)
else:
    server = os.getenv("SERVER")
    # Without this check an unset SERVER fails with an opaque TypeError
    # from concatenating str and None.
    if not server:
        raise RuntimeError("SERVER environment variable must name the "
                           "server to run the LDAP tests against")
    host = "ldap://" + server