test isl_space_add_unnamed_tuple_ui export
[isl.git] / isl_test_python.py
blob92c10a87af17c2bc90055e566558346a625fe240
# Copyright 2016-2017 Tobias Grosser
#
# Use of this software is governed by the MIT license
#
# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
7 import sys
8 import isl
# Check that isl objects can be built through each kind of constructor:
#  - from a string,
#  - from an integer,
#  - through a static constructor that takes no parameter,
#  - through conversion (here: basic_set -> set).
#
# String and integer construction are also covered by the parameter-type
# tests; what this adds is coverage of several overloaded constructors
# living side by side, i.e. that overload resolution picks the right one.
def test_constructors():
    # Three different routes to the value zero, tried in order.
    zero_makers = (
        lambda: isl.val("0"),
        lambda: isl.val(0),
        lambda: isl.val.zero(),
    )
    for make_zero in zero_makers:
        assert make_zero().is_zero()

    # Conversion constructor: a set built from a basic_set.
    source = isl.basic_set("{ [1] }")
    expected = isl.set("{ [1] }")
    converted = isl.set(source)
    assert converted.is_equal(expected)
# Check integer parameter passing for one specific value "i": building an
# isl.val from the Python int must agree with building it from the
# decimal string form of that int.
def test_int(i):
    from_string = str(i)
    assert isl.val(i).eq(isl.val(from_string))
# Check integer parameter passing at the boundaries: the largest and
# smallest values bounded by sys.maxsize, plus zero.
def test_parameters_int():
    boundary_values = (sys.maxsize, -sys.maxsize - 1, 0)
    for value in boundary_values:
        test_int(value)
# Check passing isl objects as parameters.
#
# Objects must be accepted both as named values (lvalues) and as
# temporaries (rvalues). Parameters must be converted automatically along
# an inheritance relation (a basic_set where a set is expected). Methods
# that take no argument besides the object they are called on must work too.
def test_parameters_obj():
    first = isl.set("{ [0] }")
    second = isl.set("{ [1] }")
    third = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # The intermediate union held in a named variable ...
    first_two = first.union(second)
    assert first_two.union(third).is_equal(expected)

    # ... and the same union passed on as a temporary.
    assert first.union(second).union(third).is_equal(expected)

    # basic_set argument where a set parameter is declared.
    same_as_first = isl.basic_set("{ [0] }")
    assert first.is_equal(same_as_first)

    # Method whose only input is the object itself.
    two = isl.val(2)
    half = isl.val("1/2")
    assert two.inv().eq(half)
# Check the different kinds of function parameters: integers first,
# then isl objects.
def test_parameters():
    for run_check in (test_parameters_int, test_parameters_obj):
        run_check()
# Check that an isl object comes back correctly from a method call. Only
# the result of combining two objects (1 + 2) is checked against 3.
def test_return_obj():
    one, two, three = isl.val("1"), isl.val("2"), isl.val("3")
    assert one.add(two).eq(three)
# Check that plain integer results come back correctly, using the sign
# of a positive, a negative and a zero value.
def test_return_int():
    positive = isl.val("1")
    negative = isl.val("-1")
    nothing = isl.val("0")

    assert positive.sgn() > 0
    assert negative.sgn() < 0
    assert nothing.sgn() == 0
# Check that isl_bool results come back correctly, in particular that both
# a true and a false result convert properly when used as a Python bool.
def test_return_bool():
    nothing = isl.set("{ : false }")
    everything = isl.set("{ : }")

    should_be_true = nothing.is_empty()
    should_be_false = everything.is_empty()

    assert should_be_true
    assert not should_be_false
# Check that strings come back correctly. The AST expressions are produced
# through the overloaded isl.ast_build.expr_from method (once with a pw_aff,
# once with a set) and then printed as C.
def test_return_string():
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    non_negative = isl.set("[n] -> { : n >= 0 }")

    cases = ((pw_aff, "n"), (non_negative, "n >= 0"))
    for source, expected_string in cases:
        expr = build.expr_from(source)
        assert expected_string == expr.to_C_str()
# Check return-value handling for every kind of result: isl objects,
# integers, booleans and strings, in that order.
def test_return():
    checks = (
        test_return_obj,
        test_return_int,
        test_return_bool,
        test_return_string,
    )
    for run_check in checks:
        run_check()
# Test that foreach functions are modeled correctly.
#
# Verify that closures are correctly called as callback of a 'foreach'
# function and that variables captured by the closure work correctly. Also
# check that the foreach function handles an exception raised from the
# closure and propagates it to the caller.
def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # Collected basic sets; named so as not to shadow the builtin "list".
    collected = []

    def add(bs):
        collected.append(bs)
    s.foreach_basic_set(add)

    assert len(collected) == 3
    assert collected[0].is_subset(s)
    assert collected[1].is_subset(s)
    assert collected[2].is_subset(s)
    assert not collected[0].is_equal(collected[1])
    assert not collected[0].is_equal(collected[2])
    assert not collected[1].is_equal(collected[2])

    def fail(bs):
        # Raise a real exception object: in Python 3, "raise 'fail'" is
        # itself a TypeError, so the old code only passed by accident.
        raise Exception("fail")

    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        # Not a bare except: KeyboardInterrupt/SystemExit must still escape.
        caught = True
    assert caught
# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
# test that exceptions are properly propagated.
def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()
    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()
    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(not_in_A)

    def fail(s):
        # A real exception object; raising a plain string is a TypeError
        # in Python 3.
        raise Exception("fail")

    caught = False
    try:
        # Previously misspelled "ever_set", whose AttributeError was
        # swallowed here, so propagation through every_set was never
        # actually exercised.
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught
# Check basic construction of spaces. Starting from the unit space, a
# 3-dimensional named tuple "A" gives a set space, and adding a
# 2-dimensional named tuple "B" gives a map space.
def test_space():
    unit = isl.space.unit()
    a_space = unit.add_named_tuple("A", 3)
    a_to_b_space = a_space.add_named_tuple("B", 2)

    universe_set = isl.set.universe(a_space)
    universe_map = isl.map.universe(a_to_b_space)
    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))
# Build a small schedule tree: an outer sequence node over the A and B
# domains, and under each branch a one-dimensional band node. Only the
# band in the A branch has its member 0 marked coincident.
# Returns the resulting schedule.
def construct_schedule_tree():
    A = isl.union_set("{ A[i] : 0 <= i < 10 }")
    B = isl.union_set("{ B[i] : 0 <= i < 20 }")

    node = isl.schedule_node.from_domain(A.union(B)).child(0)
    node = node.insert_sequence(isl.union_set_list(A).add(B))

    # Branch 0 (A): band with coincident member 0, then back up two levels.
    f_A = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    node = (node.child(0)
                .child(0)
                .insert_partial_schedule(f_A)
                .member_set_coincident(0, True)
                .ancestor(2))

    # Branch 1 (B): plain band, then back up two levels.
    f_B = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    node = (node.child(1)
                .child(0)
                .insert_partial_schedule(f_B)
                .ancestor(2))

    return node.schedule()
# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
# - check that the root node is a domain node
# - test map_descendant_bottom_up
# - test foreach_descendant_top_down
# - test every_descendant
# and check that exceptions raised from callbacks are propagated.
def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    # Exact type check on purpose: the binding must return the specific
    # node subclass, not just a generic schedule_node.
    assert type(root) == isl.schedule_node_domain

    # The tree from construct_schedule_tree has 8 nodes in total
    # (presumably: domain, sequence, 2 filters, 2 bands, 2 leaves).
    count = [0]
    def inc_count_map(node):
        count[0] += 1
        return node
    root = root.map_descendant_bottom_up(inc_count_map)
    assert count[0] == 8

    def fail_map(node):
        # A real exception object ("raise 'fail'" is a TypeError in
        # Python 3); the former unreachable "return node" is dropped.
        raise Exception("fail")
    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        caught = True
    assert caught

    # Returning True keeps descending, so every node is visited ...
    count = [0]
    def inc_count_descend(node):
        count[0] += 1
        return True
    root.foreach_descendant_top_down(inc_count_descend)
    assert count[0] == 8

    # ... while returning False stops below the first node visited.
    count = [0]
    def inc_count_stop(node):
        count[0] += 1
        return False
    root.foreach_descendant_top_down(inc_count_stop)
    assert count[0] == 1

    # every_descendant includes the node it is called on: only the
    # root is a domain node.
    def is_not_domain(node):
        return type(node) != isl.schedule_node_domain
    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    def fail(node):
        raise Exception("fail")
    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught

    # The union of all filters must cover exactly the schedule domain.
    domain = root.domain()
    filters = [isl.union_set("{}")]
    def collect_filters(node):
        if type(node) == isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True
    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])
# Check that band members marked for unrolling are unrolled by the AST
# generator. "schedule" is expected to be the schedule produced by
# construct_schedule_tree: two statements with 10 and 20 instances, so
# full unrolling must lead to 30 at-each-domain callbacks.
def test_ast_build_unroll(schedule):
    def mark_unroll(node):
        is_band = type(node) == isl.schedule_node_band
        return node.member_set_ast_loop_unroll(0) if is_band else node

    marked_root = schedule.root().map_descendant_bottom_up(mark_unroll)
    unrolled = marked_root.schedule()

    domain_calls = [0]
    def count_domain_call(node, build):
        domain_calls[0] += 1
        return node

    build = isl.ast_build().set_at_each_domain(count_domain_call)
    build.node_from(unrolled)
    assert domain_calls[0] == 30
# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
# - generate an AST from the schedule tree
# - test at_each_domain (including that setting it returns a new build
#   and leaves the original untouched)
# - test that an exception raised in the callback is propagated
# - test unrolling
def test_ast_build():
    schedule = construct_schedule_tree()

    count_ast = [0]
    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # set_at_each_domain must not modify "build" itself.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # "do_fail" is read through the closure at call time, so flipping it
    # below changes the callback's behavior for later AST generations.
    do_fail = True
    count_ast_fail = [0]
    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            # A real exception object; "raise 'fail'" is a TypeError
            # in Python 3.
            raise Exception("fail")
        return node
    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        caught = True
    assert caught
    assert count_ast_fail[0] > 0

    # Replacing the callback on a copy does not affect the original build.
    build_copy = build
    build_copy = build_copy.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)
# Check AST expression generation from an affine expression: "n + 1" must
# come back as an addition operation with two arguments.
def test_ast_build_expr():
    n_plus_one = isl.pw_aff("[n] -> { [n + 1] }")
    context = n_plus_one.domain()
    build = isl.ast_build.from_context(context)

    expr = build.expr_from(n_plus_one)
    assert type(expr) == isl.ast_expr_op_add
    assert expr.n_arg() == 2
# Exercise the isl Python interface, covering:
# - object construction
# - the different parameter types
# - the different return types
# - foreach functions
# - every functions
# - spaces
# - schedule trees
# - AST generation
# - AST expression generation
for _run_test in (
    test_constructors,
    test_parameters,
    test_return,
    test_foreach,
    test_every,
    test_space,
    test_schedule_tree,
    test_ast_build,
    test_ast_build_expr,
):
    _run_test()