isl_basic_set_preimage: add extra sanity check
[isl.git] / isl_dim.c
blobea772336619bd4c3e4742db940ec4657d50830d9
1 /*
2 * Copyright 2008-2009 Katholieke Universiteit Leuven
4 * Use of this software is governed by the GNU LGPLv2.1 license
6 * Written by Sven Verdoolaege, K.U.Leuven, Departement
7 * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8 */
10 #include <isl_dim.h>
11 #include "isl_name.h"
13 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
14 unsigned nparam, unsigned n_in, unsigned n_out)
16 struct isl_dim *dim;
18 dim = isl_alloc_type(ctx, struct isl_dim);
19 if (!dim)
20 return NULL;
22 dim->ctx = ctx;
23 isl_ctx_ref(ctx);
24 dim->ref = 1;
25 dim->nparam = nparam;
26 dim->n_in = n_in;
27 dim->n_out = n_out;
29 dim->n_name = 0;
30 dim->names = NULL;
32 return dim;
/* Allocate an isl_dim with "nparam" parameters, no input dimensions
 * and "dim" output dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	const unsigned n_in = 0;

	return isl_dim_alloc(ctx, nparam, n_in, dim);
}
41 static unsigned global_pos(struct isl_dim *dim,
42 enum isl_dim_type type, unsigned pos)
44 struct isl_ctx *ctx = dim->ctx;
46 switch (type) {
47 case isl_dim_param:
48 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
49 return pos;
50 case isl_dim_in:
51 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
52 return pos + dim->nparam;
53 case isl_dim_out:
54 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
55 return pos + dim->nparam + dim->n_in;
56 default:
57 isl_assert(ctx, 0, return isl_dim_total(dim));
59 return isl_dim_total(dim);
62 /* Extend length of names array to the total number of dimensions.
64 static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
66 struct isl_name **names;
67 int i;
69 if (isl_dim_total(dim) <= dim->n_name)
70 return dim;
72 if (!dim->names) {
73 dim->names = isl_calloc_array(dim->ctx,
74 struct isl_name *, isl_dim_total(dim));
75 if (!dim->names)
76 goto error;
77 } else {
78 names = isl_realloc_array(dim->ctx, dim->names,
79 struct isl_name *, isl_dim_total(dim));
80 if (!names)
81 goto error;
82 dim->names = names;
83 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
84 dim->names[i] = NULL;
87 dim->n_name = isl_dim_total(dim);
89 return dim;
90 error:
91 isl_dim_free(dim);
92 return NULL;
95 static struct isl_dim *set_name(struct isl_dim *dim,
96 enum isl_dim_type type, unsigned pos,
97 struct isl_name *name)
99 struct isl_ctx *ctx = dim->ctx;
100 dim = isl_dim_cow(dim);
102 if (!dim)
103 goto error;
105 pos = global_pos(dim, type, pos);
106 isl_assert(ctx, pos != isl_dim_total(dim), goto error);
108 if (pos >= dim->n_name) {
109 if (!name)
110 return dim;
111 dim = extend_names(dim);
112 if (!dim)
113 goto error;
116 dim->names[pos] = name;
118 return dim;
119 error:
120 isl_name_free(ctx, name);
121 isl_dim_free(dim);
122 return NULL;
125 static struct isl_name *get_name(struct isl_dim *dim,
126 enum isl_dim_type type, unsigned pos)
128 if (!dim)
129 return NULL;
131 pos = global_pos(dim, type, pos);
132 if (pos == isl_dim_total(dim))
133 return NULL;
134 if (pos >= dim->n_name)
135 return NULL;
136 return dim->names[pos];
139 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
141 switch (type) {
142 case isl_dim_param: return 0;
143 case isl_dim_in: return dim->nparam;
144 case isl_dim_out: return dim->nparam + dim->n_in;
148 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
150 switch (type) {
151 case isl_dim_param: return dim->nparam;
152 case isl_dim_in: return dim->n_in;
153 case isl_dim_out: return dim->n_out;
157 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
159 if (!dim)
160 return 0;
161 return n(dim, type);
164 static struct isl_dim *copy_names(struct isl_dim *dst,
165 enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
166 enum isl_dim_type src_type)
168 int i;
169 struct isl_name *name;
171 for (i = 0; i < n(src, src_type); ++i) {
172 name = get_name(src, src_type, i);
173 if (!name)
174 continue;
175 dst = set_name(dst, dst_type, offset + i,
176 isl_name_copy(dst->ctx, name));
177 if (!dst)
178 return NULL;
180 return dst;
183 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
185 struct isl_dim *dup;
186 dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
187 if (!dim->names)
188 return dup;
189 dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
190 dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
191 dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
192 return dup;
195 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
197 if (!dim)
198 return NULL;
200 if (dim->ref == 1)
201 return dim;
202 dim->ref--;
203 return isl_dim_dup(dim);
206 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
208 if (!dim)
209 return NULL;
211 dim->ref++;
212 return dim;
215 void isl_dim_free(struct isl_dim *dim)
217 int i;
219 if (!dim)
220 return;
222 if (--dim->ref > 0)
223 return;
225 for (i = 0; i < dim->n_name; ++i)
226 isl_name_free(dim->ctx, dim->names[i]);
227 free(dim->names);
228 isl_ctx_deref(dim->ctx);
230 free(dim);
233 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
234 enum isl_dim_type type, unsigned pos,
235 const char *s)
237 struct isl_name *name;
238 if (!dim)
239 return NULL;
240 name = isl_name_get(dim->ctx, s);
241 if (!name)
242 goto error;
243 return set_name(dim, type, pos, name);
244 error:
245 isl_dim_free(dim);
246 return NULL;
249 const char *isl_dim_get_name(struct isl_dim *dim,
250 enum isl_dim_type type, unsigned pos)
252 struct isl_name *name = get_name(dim, type, pos);
253 return name ? name->name : NULL;
256 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
257 struct isl_dim *dim2, enum isl_dim_type dim2_type)
259 int i;
261 if (n(dim1, dim1_type) != n(dim2, dim2_type))
262 return 0;
264 if (!dim1->names && !dim2->names)
265 return 1;
267 for (i = 0; i < n(dim1, dim1_type); ++i) {
268 if (get_name(dim1, dim1_type, i) !=
269 get_name(dim2, dim2_type, i))
270 return 0;
272 return 1;
275 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
276 struct isl_dim *dim2, enum isl_dim_type dim2_type)
278 return match(dim1, dim1_type, dim2, dim2_type);
281 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
282 unsigned first, unsigned n, struct isl_name **names)
284 int i;
286 for (i = 0; i < n ; ++i)
287 names[i] = get_name(dim, type, first+i);
290 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
291 unsigned nparam, unsigned n_in, unsigned n_out)
293 struct isl_name **names = NULL;
295 if (!dim)
296 return NULL;
297 if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
298 return dim;
300 isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
301 isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
302 isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
304 dim = isl_dim_cow(dim);
306 if (dim->names) {
307 names = isl_calloc_array(dim->ctx, struct isl_name *,
308 nparam + n_in + n_out);
309 if (!names)
310 goto error;
311 get_names(dim, isl_dim_param, 0, dim->nparam, names);
312 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
313 get_names(dim, isl_dim_out, 0, dim->n_out,
314 names + nparam + n_in);
315 free(dim->names);
316 dim->names = names;
317 dim->n_name = nparam + n_in + n_out;
319 dim->nparam = nparam;
320 dim->n_in = n_in;
321 dim->n_out = n_out;
323 return dim;
324 error:
325 free(names);
326 isl_dim_free(dim);
327 return NULL;
330 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
331 unsigned n)
333 switch (type) {
334 case isl_dim_param:
335 return isl_dim_extend(dim,
336 dim->nparam + n, dim->n_in, dim->n_out);
337 case isl_dim_in:
338 return isl_dim_extend(dim,
339 dim->nparam, dim->n_in + n, dim->n_out);
340 case isl_dim_out:
341 return isl_dim_extend(dim,
342 dim->nparam, dim->n_in, dim->n_out + n);
344 return dim;
347 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
348 enum isl_dim_type type, unsigned pos, unsigned n)
350 struct isl_name **names = NULL;
352 if (!dim)
353 return NULL;
354 if (n == 0)
355 return dim;
357 isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
359 dim = isl_dim_cow(dim);
360 if (!dim)
361 return NULL;
363 if (dim->names) {
364 enum isl_dim_type t;
365 int off;
366 int size[3];
367 names = isl_calloc_array(dim->ctx, struct isl_name *,
368 dim->nparam + dim->n_in + dim->n_out + n);
369 if (!names)
370 goto error;
371 off = 0;
372 size[isl_dim_param] = dim->nparam;
373 size[isl_dim_in] = dim->n_in;
374 size[isl_dim_out] = dim->n_out;
375 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
376 if (t != type) {
377 get_names(dim, t, 0, size[t], names + off);
378 off += size[t];
379 } else {
380 get_names(dim, t, 0, pos, names + off);
381 off += pos + n;
382 get_names(dim, t, pos, size[t]-pos, names+off);
383 off += size[t] - pos;
386 free(dim->names);
387 dim->names = names;
388 dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
390 switch (type) {
391 case isl_dim_param: dim->nparam += n; break;
392 case isl_dim_in: dim->n_in += n; break;
393 case isl_dim_out: dim->n_out += n; break;
396 return dim;
397 error:
398 isl_dim_free(dim);
399 return NULL;
402 __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
403 enum isl_dim_type dst_type, unsigned dst_pos,
404 enum isl_dim_type src_type, unsigned src_pos, unsigned n)
406 if (!dim)
407 return NULL;
408 if (n == 0)
409 return dim;
411 isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
412 goto error);
414 if (dst_type == src_type && dst_pos == src_pos)
415 return dim;
417 isl_assert(dim->ctx, dst_type != src_type, goto error);
419 dim = isl_dim_cow(dim);
420 if (!dim)
421 return NULL;
423 if (dim->names) {
424 struct isl_name **names;
425 enum isl_dim_type t;
426 int off;
427 int size[3];
428 names = isl_calloc_array(dim->ctx, struct isl_name *,
429 dim->nparam + dim->n_in + dim->n_out);
430 if (!names)
431 goto error;
432 off = 0;
433 size[isl_dim_param] = dim->nparam;
434 size[isl_dim_in] = dim->n_in;
435 size[isl_dim_out] = dim->n_out;
436 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
437 if (t == dst_type) {
438 get_names(dim, t, 0, dst_pos, names + off);
439 off += dst_pos;
440 get_names(dim, src_type, src_pos, n, names+off);
441 off += n;
442 get_names(dim, t, dst_pos, size[t] - dst_pos,
443 names + off);
444 off += size[t] - dst_pos;
445 } else if (t == src_type) {
446 get_names(dim, t, 0, src_pos, names + off);
447 off += src_pos;
448 get_names(dim, t, src_pos + n,
449 size[t] - src_pos - n, names + off);
450 off += size[t] - src_pos - n;
451 } else {
452 get_names(dim, t, 0, size[t], names + off);
453 off += size[t];
456 free(dim->names);
457 dim->names = names;
458 dim->n_name = dim->nparam + dim->n_in + dim->n_out;
461 switch (dst_type) {
462 case isl_dim_param: dim->nparam += n; break;
463 case isl_dim_in: dim->n_in += n; break;
464 case isl_dim_out: dim->n_out += n; break;
467 switch (src_type) {
468 case isl_dim_param: dim->nparam -= n; break;
469 case isl_dim_in: dim->n_in -= n; break;
470 case isl_dim_out: dim->n_out -= n; break;
473 return dim;
474 error:
475 isl_dim_free(dim);
476 return NULL;
479 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
481 struct isl_dim *dim;
483 if (!left || !right)
484 goto error;
486 isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
487 goto error);
488 isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
489 goto error);
491 dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
492 if (!dim)
493 goto error;
495 dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
496 dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
497 dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
499 isl_dim_free(left);
500 isl_dim_free(right);
502 return dim;
503 error:
504 isl_dim_free(left);
505 isl_dim_free(right);
506 return NULL;
509 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
511 struct isl_dim *dim;
513 if (!left || !right)
514 goto error;
516 isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
517 goto error);
519 dim = isl_dim_alloc(left->ctx, left->nparam,
520 left->n_in + right->n_in, left->n_out + right->n_out);
521 if (!dim)
522 goto error;
524 dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
525 dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
526 dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
527 dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
528 dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
530 isl_dim_free(left);
531 isl_dim_free(right);
533 return dim;
534 error:
535 isl_dim_free(left);
536 isl_dim_free(right);
537 return NULL;
540 struct isl_dim *isl_dim_map(struct isl_dim *dim)
542 struct isl_name **names = NULL;
544 if (!dim)
545 return NULL;
546 isl_assert(dim->ctx, dim->n_in == 0, goto error);
547 if (dim->n_out == 0)
548 return dim;
549 dim = isl_dim_cow(dim);
550 if (!dim)
551 return NULL;
552 if (dim->names) {
553 names = isl_calloc_array(dim->ctx, struct isl_name *,
554 dim->nparam + dim->n_out + dim->n_out);
555 if (!names)
556 goto error;
557 get_names(dim, isl_dim_param, 0, dim->nparam, names);
558 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
560 dim->n_in = dim->n_out;
561 if (names) {
562 free(dim->names);
563 dim->names = names;
564 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
565 dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
567 return dim;
568 error:
569 isl_dim_free(dim);
570 return NULL;
573 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
574 unsigned first, unsigned n, struct isl_name **names)
576 int i;
578 for (i = 0; i < n ; ++i)
579 dim = set_name(dim, type, first+i, names[i]);
581 return dim;
584 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
586 unsigned t;
587 struct isl_name **names = NULL;
589 if (!dim)
590 return NULL;
591 if (match(dim, isl_dim_in, dim, isl_dim_out))
592 return dim;
594 dim = isl_dim_cow(dim);
595 if (!dim)
596 return NULL;
598 if (dim->names) {
599 names = isl_alloc_array(dim->ctx, struct isl_name *,
600 dim->n_in + dim->n_out);
601 if (!names)
602 goto error;
603 get_names(dim, isl_dim_in, 0, dim->n_in, names);
604 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
607 t = dim->n_in;
608 dim->n_in = dim->n_out;
609 dim->n_out = t;
611 if (dim->names) {
612 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
613 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
614 free(names);
617 return dim;
618 error:
619 free(names);
620 isl_dim_free(dim);
621 return NULL;
624 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
625 unsigned first, unsigned num)
627 int i;
629 if (!dim)
630 return NULL;
632 if (n == 0)
633 return dim;
635 isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
636 dim = isl_dim_cow(dim);
637 if (!dim)
638 goto error;
639 if (dim->names) {
640 dim = extend_names(dim);
641 if (!dim)
642 goto error;
643 for (i = 0; i < num; ++i)
644 isl_name_free(dim->ctx, get_name(dim, type, first+i));
645 for (i = first+num; i < n(dim, type); ++i)
646 set_name(dim, type, i - num, get_name(dim, type, i));
647 switch (type) {
648 case isl_dim_param:
649 get_names(dim, isl_dim_in, 0, dim->n_in,
650 dim->names + offset(dim, isl_dim_in) - num);
651 case isl_dim_in:
652 get_names(dim, isl_dim_out, 0, dim->n_out,
653 dim->names + offset(dim, isl_dim_out) - num);
654 case isl_dim_out:
657 dim->n_name -= num;
659 switch (type) {
660 case isl_dim_param: dim->nparam -= num; break;
661 case isl_dim_in: dim->n_in -= num; break;
662 case isl_dim_out: dim->n_out -= num; break;
664 return dim;
665 error:
666 isl_dim_free(dim);
667 return NULL;
670 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
671 unsigned first, unsigned n)
673 return isl_dim_drop(dim, isl_dim_in, first, n);
676 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
677 unsigned first, unsigned n)
679 return isl_dim_drop(dim, isl_dim_out, first, n);
682 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
684 if (!dim)
685 return NULL;
686 dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
687 return isl_dim_reverse(dim);
690 struct isl_dim *isl_dim_range(struct isl_dim *dim)
692 if (!dim)
693 return NULL;
694 return isl_dim_drop_inputs(dim, 0, dim->n_in);
697 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
699 int i;
701 if (!dim)
702 return NULL;
703 if (n_div == 0 &&
704 dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
705 return dim;
706 dim = isl_dim_cow(dim);
707 if (!dim)
708 return NULL;
709 dim->n_out += dim->nparam + dim->n_in + n_div;
710 dim->nparam = 0;
711 dim->n_in = 0;
713 for (i = 0; i < dim->n_name; ++i)
714 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
715 dim->n_name = 0;
717 return dim;
720 unsigned isl_dim_total(struct isl_dim *dim)
722 return dim->nparam + dim->n_in + dim->n_out;
725 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
727 return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
728 n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
729 n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
732 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
734 return dim1->nparam == dim2->nparam &&
735 dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;