add isl_space_check_is_wrapping
[isl.git] / isl_test_python.py
blob3894ad1cced9199815924e23eb6f8b44716e8b05
1 # Copyright 2016-2017 Tobias Grosser
3 # Use of this software is governed by the MIT license
5 # Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
7 import sys
8 import isl
# Check that isl objects can be built in every supported way.
#
# Covered here:
# - building from a string
# - building from an integer
# - a static constructor that takes no argument
# - a converting constructor (basic_set -> set)
# - building an empty union set
#
# The string and integer cases overlap with the parameter-type tests, but
# this test additionally exercises overload resolution among several
# constructors of the same class.

def test_constructors():
    for zero in (isl.val("0"), isl.val(0), isl.val.zero()):
        assert zero.is_zero()

    expected_set = isl.set("{ [1] }")
    converted = isl.set(isl.basic_set("{ [1] }"))
    assert converted.is_equal(expected_set)

    union = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert union.is_equal(union.union(nothing))
# Check that integer "i" and its decimal string spelling produce equal
# isl.val objects.

def test_int(i):
    from_int, from_text = isl.val(i), isl.val(str(i))
    assert from_int.eq(from_text)
# Check integer parameters at both extremes of the native integer range
# (sys.maxsize and -sys.maxsize - 1) and at zero.

def test_parameters_int():
    for value in (sys.maxsize, -sys.maxsize - 1, 0):
        test_int(value)
# Check passing isl objects as parameters.
#
# Objects are passed both through named temporaries (lvalues) and as the
# direct result of another call (rvalues).  A basic_set must be accepted
# where a set is expected, since the two classes are related by
# inheritance.  Finally, a method that takes no argument besides the object
# it is invoked on is exercised.

def test_parameters_obj():
    s0, s1, s2 = isl.set("{ [0] }"), isl.set("{ [1] }"), isl.set("{ [2] }")
    target = isl.set("{ [i] : 0 <= i <= 2 }")

    partial = s0.union(s1)
    assert partial.union(s2).is_equal(target)

    assert s0.union(s1).union(s2).is_equal(target)

    assert s0.is_equal(isl.basic_set("{ [0] }"))

    reciprocal = isl.val(2).inv()
    assert reciprocal.eq(isl.val("1/2"))
# Run all parameter-passing tests: integer parameters first, then isl
# object parameters.

def test_parameters():
    for check in (test_parameters_int, test_parameters_obj):
        check()
# Check that an isl object produced by combining two others is returned
# to the caller intact (1 + 2 must equal 3).

def test_return_obj():
    total = isl.val("1").add(isl.val("2"))
    assert total.eq(isl.val("3"))
# Check that integer results come back correctly, using the sign of a
# positive, a negative and a zero value.

def test_return_int():
    positive, negative, nil = isl.val("1"), isl.val("-1"), isl.val("0")
    assert positive.sgn() > 0
    assert negative.sgn() < 0
    assert nil.sgn() == 0
# Check that isl_bool results convert to Python truth values, covering
# both the true and the false case.

def test_return_bool():
    empty_set = isl.set("{ : false }")
    universe = isl.set("{ : }")
    assert empty_set.is_empty()
    assert not universe.is_empty()
# Test that strings are returned correctly.
#
# Both overloads of isl.ast_build.expr_from (one taking a pw_aff, one
# taking a set) are called, and the C rendering of each resulting AST
# expression is compared against the expected text.

def test_return_string():
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Named to avoid shadowing the builtin "set".
    non_negative = isl.set("[n] -> { : n >= 0 }")

    expr = build.expr_from(pw_aff)
    expected_string = "n"
    assert expected_string == expr.to_C_str()

    expr = build.expr_from(non_negative)
    expected_string = "n >= 0"
    assert expected_string == expr.to_C_str()
# Check return-value handling for every kind of result: isl objects,
# integers, booleans and strings.

def test_return():
    for check in (test_return_obj, test_return_int, test_return_bool,
                  test_return_string):
        check()
# Helper class whose instances test_user attaches to an isl.id as user
# data and then retrieves again.

class S:
    """Plain object carrying the marker value 42 in ``value``."""

    def __init__(self):
        self.value: int = 42
# Check isl.id.user: an identifier returns the object attached to it when
# it was created (an integer or an S instance), and None when nothing was
# attached.

def test_user():
    int_id = isl.id("test", 5)
    none_id = isl.id("test2")
    obj_id = isl.id("S", S())

    assert int_id.user() == 5, f"unexpected user object {int_id.user()}"
    assert none_id.user() is None, f"unexpected user object {none_id.user()}"

    stored = obj_id.user()
    assert isinstance(stored, S), f"unexpected user object {stored}"
    assert stored.value == 42, f"unexpected user object {stored}"
# Test that foreach functions are modeled correctly.
#
# Verify that a closure passed as the callback of a "foreach" function is
# called once per element and can update variables it captures.  Also
# check that an exception raised inside the callback is propagated out of
# the foreach call.

def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # Named to avoid shadowing the builtin "list".
    pieces = []
    def add(bs):
        pieces.append(bs)
    s.foreach_basic_set(add)

    assert len(pieces) == 3
    for piece in pieces:
        assert piece.is_subset(s)
    assert not pieces[0].is_equal(pieces[1])
    assert not pieces[0].is_equal(pieces[2])
    assert not pieces[1].is_equal(pieces[2])

    def fail(bs):
        raise Exception("fail")

    # Only the exception raised by "fail" counts.  A bare "except" would
    # also accept unrelated errors (e.g. an AttributeError), letting the
    # check pass without exercising propagation at all.
    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception as exc:
        caught = str(exc) == "fail"
    assert caught
# Test the functionality of "foreach_scc" functions.
#
# In particular, test it on a list of elements that can be completely sorted
# but where two of the elements ("a" and "b") are incomparable.
# The expected resulting order is "b", "c", "a", with each strongly
# connected component consisting of a single element.

def test_foreach_scc():
    data = {
        'a': isl.map("{ [0] -> [1] }"),
        'b': isl.map("{ [1] -> [0] }"),
        'c': isl.map("{ [i = 0:1] -> [i] }"),
    }
    # Local names avoid shadowing the builtins list, sorted, id and map.
    ids = isl.id_list(3)
    for name in data:
        ids = ids.add(name)
    identity = data['a'].space().domain().identity_multi_pw_aff_on_domain()

    # Ordering callback passed to foreach_scc; it compares the maps that
    # "data" associates with the identifiers a and b.
    def follows(a, b):
        composed = data[b.name()].apply_domain(data[a.name()])
        return not composed.lex_ge_at(identity).is_empty()

    # One-element list used as a mutable cell so that the callback can
    # replace the accumulated result.
    ordered = [isl.id_list(3)]
    def add_single(scc):
        assert scc.size() == 1
        ordered[0] = ordered[0].concat(scc)

    ids.foreach_scc(follows, add_single)
    assert ordered[0].size() == 3
    assert ordered[0].at(0).name() == "b"
    assert ordered[0].at(1).name() == "c"
    assert ordered[0].at(2).name() == "a"
# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
# test that exceptions raised by the callback are propagated.

def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()
    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()
    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(not_in_A)

    def fail(s):
        raise Exception("fail")

    # Call the real method ("every_set") and accept only the exception
    # raised by "fail".  Any other error, such as an AttributeError from a
    # misspelled method name, must not make this check pass.
    caught = False
    try:
        us.every_set(fail)
    except Exception as exc:
        caught = str(exc) == "fail"
    assert caught
# Check basic construction of spaces.
#
# Starting from the unit space, add a named tuple "A" with 3 dimensions to
# obtain a set space, then a named tuple "B" with 2 dimensions to obtain a
# map space.  Compare the universe of each space with its textual form.

def test_space():
    set_space = isl.space.unit().add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    set_universe = isl.set.universe(set_space)
    map_universe = isl.map.universe(map_space)
    assert set_universe.is_equal(isl.set("{ A[*,*,*] }"))
    assert map_universe.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))
# Build a small schedule tree and return it as a schedule.
#
# The root is a domain node covering statements A (10 instances) and
# B (20 instances).  Below it is a sequence node with one filter per
# statement, and each branch holds a single-dimensional band node.  Only
# the band in the A branch has its member marked coincident.

def construct_schedule_tree():
    dom_a = isl.union_set("{ A[i] : 0 <= i < 10 }")
    dom_b = isl.union_set("{ B[i] : 0 <= i < 20 }")

    root = isl.schedule_node.from_domain(dom_a.union(dom_b))
    branches = isl.union_set_list(dom_a).add(dom_b)
    seq = root.child(0).insert_sequence(branches)

    # Branch 0 (statement A): band marked coincident.
    sched_a = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    band_a = seq.child(0).child(0).insert_partial_schedule(sched_a)
    band_a = band_a.member_set_coincident(0, True)
    seq = band_a.ancestor(2)

    # Branch 1 (statement B): plain band.
    sched_b = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    band_b = seq.child(1).child(0).insert_partial_schedule(sched_b)
    return band_b.ancestor(2).schedule()
# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
# - check that the root node is a domain node
# - test map_descendant_bottom_up
# - test foreach_descendant_top_down
# - test every_descendant
# For the callback-taking methods, also check that an exception raised
# inside the callback is propagated to the caller.

def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    assert type(root) == isl.schedule_node_domain

    def fail(node):
        raise Exception("fail")

    # Return whether call() raises the exception raised by "fail".
    # Only that exception counts: a bare "except" would also accept
    # unrelated errors and let the check pass vacuously.
    def raises_fail(call):
        try:
            call()
        except Exception as exc:
            return str(exc) == "fail"
        return False

    count = [0]
    def inc_count(node):
        count[0] += 1
        return node
    root = root.map_descendant_bottom_up(inc_count)
    assert count[0] == 8

    assert raises_fail(lambda: root.map_descendant_bottom_up(fail))

    # Returning True lets the traversal continue below each visited node,
    # so all 8 nodes are visited.
    count = [0]
    def count_and_continue(node):
        count[0] += 1
        return True
    root.foreach_descendant_top_down(count_and_continue)
    assert count[0] == 8

    # Returning False does not descend below the visited node, so only
    # the root itself is visited.
    count = [0]
    def count_and_stop(node):
        count[0] += 1
        return False
    root.foreach_descendant_top_down(count_and_stop)
    assert count[0] == 1

    def is_not_domain(node):
        return type(node) != isl.schedule_node_domain
    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    assert raises_fail(lambda: root.every_descendant(fail))

    # The union of all filters in the tree must equal the tree's domain.
    domain = root.domain()
    filters = [isl.union_set("{}")]
    def collect_filters(node):
        if type(node) == isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True
    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])
# Check unrolling of band members.
#
# "schedule" is expected to be the schedule built by
# construct_schedule_tree, which covers 10 instances of A and 20 of B.
# After every band member is marked for unrolling, generating an AST must
# invoke the at-each-domain callback once per instance, i.e. 30 times.

def test_ast_build_unroll(schedule):
    def mark_unroll(node):
        if type(node) != isl.schedule_node_band:
            return node
        return node.member_set_ast_loop_unroll(0)

    unrolled_root = schedule.root().map_descendant_bottom_up(mark_unroll)
    unrolled = unrolled_root.schedule()

    calls = [0]
    def on_domain(node, build):
        calls[0] += 1
        return node

    builder = isl.ast_build().set_at_each_domain(on_domain)
    tree = builder.node_from(unrolled)
    assert calls[0] == 30
# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
# - generate an AST from the schedule tree
# - test at_each_domain, including that set_at_each_domain leaves the
#   build it is called on unchanged and that an exception raised by the
#   callback is propagated out of node_from
# - test unrolling

def test_ast_build():
    schedule = construct_schedule_tree()

    count_ast = [0]
    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # set_at_each_domain returns a new build; the original one must not
    # invoke the callback.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    do_fail = True
    count_ast_fail = [0]
    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node
    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)

    # Only the exception raised by the callback counts; a bare "except"
    # would also accept unrelated errors.
    caught = False
    try:
        build.node_from(schedule)
    except Exception as exc:
        caught = str(exc) == "fail"
    assert caught
    assert count_ast_fail[0] > 0

    # Replacing the callback on a derived build must not affect "build".
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2

    count_ast_fail[0] = 0
    # fail_inc_count_ast reads do_fail each time it is called, so this
    # disables the failure for the next generation.
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)
# Check AST expression generation from an affine expression: "n + 1" must
# be rendered as an addition operation with two arguments.

def test_ast_build_expr():
    affine = isl.pw_aff("[n] -> { [n + 1] }")
    builder = isl.ast_build.from_context(affine.domain())

    generated = builder.expr_from(affine)
    assert type(generated) == isl.ast_expr_op_add
    assert generated.n_arg() == 2
# Run the tests of the isl Python interface, in order, covering:
# - object construction
# - the different parameter types
# - the different return types
# - isl.id.user
# - foreach functions
# - the foreach SCC function
# - every functions
# - spaces
# - schedule trees
# - AST generation
# - AST expression generation

for run in (
    test_constructors,
    test_parameters,
    test_return,
    test_user,
    test_foreach,
    test_foreach_scc,
    test_every,
    test_space,
    test_schedule_tree,
    test_ast_build,
    test_ast_build_expr,
):
    run()