isl_tab: introduce support for "big parameters"
[isl.git] / isl_dim.c
blob b6af7e20772a5f7b1dded8fd6c3f3b3150731130
1 #include "isl_dim.h"
2 #include "isl_name.h"
4 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
5 unsigned nparam, unsigned n_in, unsigned n_out)
7 struct isl_dim *dim;
9 dim = isl_alloc_type(ctx, struct isl_dim);
10 if (!dim)
11 return NULL;
13 dim->ctx = ctx;
14 isl_ctx_ref(ctx);
15 dim->ref = 1;
16 dim->nparam = nparam;
17 dim->n_in = n_in;
18 dim->n_out = n_out;
20 dim->n_name = 0;
21 dim->names = NULL;
23 return dim;
/*
 * Allocate the dimension specification of a set with "nparam" parameters
 * and "dim" set dimensions.  A set is represented as a map without
 * input dimensions, so the set dimensions become the output dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	const unsigned no_inputs = 0;

	return isl_dim_alloc(ctx, nparam, no_inputs, dim);
}
32 static unsigned global_pos(struct isl_dim *dim,
33 enum isl_dim_type type, unsigned pos)
35 struct isl_ctx *ctx = dim->ctx;
37 switch (type) {
38 case isl_dim_param:
39 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
40 return pos;
41 case isl_dim_in:
42 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
43 return pos + dim->nparam;
44 case isl_dim_out:
45 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
46 return pos + dim->nparam + dim->n_in;
47 default:
48 isl_assert(ctx, 0, goto error);
50 return isl_dim_total(dim);
53 static struct isl_dim *set_name(struct isl_dim *dim,
54 enum isl_dim_type type, unsigned pos,
55 struct isl_name *name)
57 struct isl_ctx *ctx = dim->ctx;
58 dim = isl_dim_cow(dim);
60 if (!dim)
61 goto error;
63 pos = global_pos(dim, type, pos);
64 isl_assert(ctx, pos != isl_dim_total(dim), goto error);
66 if (pos >= dim->n_name) {
67 if (!name)
68 return dim;
69 if (!dim->names) {
70 dim->names = isl_calloc_array(dim->ctx,
71 struct isl_name *, isl_dim_total(dim));
72 if (!dim->names)
73 goto error;
74 } else {
75 int i;
76 dim->names = isl_realloc_array(dim->ctx, dim->names,
77 struct isl_name *, isl_dim_total(dim));
78 if (!dim->names)
79 goto error;
80 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
81 dim->names[i] = NULL;
83 dim->n_name = isl_dim_total(dim);
86 dim->names[pos] = name;
88 return dim;
89 error:
90 isl_name_free(ctx, name);
91 isl_dim_free(dim);
92 return NULL;
95 static struct isl_name *get_name(struct isl_dim *dim,
96 enum isl_dim_type type, unsigned pos)
98 if (!dim)
99 return NULL;
101 pos = global_pos(dim, type, pos);
102 if (pos == isl_dim_total(dim))
103 return NULL;
104 if (pos >= dim->n_name)
105 return NULL;
106 return dim->names[pos];
109 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
111 switch (type) {
112 case isl_dim_param: return 0;
113 case isl_dim_in: return dim->nparam;
114 case isl_dim_out: return dim->nparam + dim->n_in;
118 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
120 switch (type) {
121 case isl_dim_param: return dim->nparam;
122 case isl_dim_in: return dim->n_in;
123 case isl_dim_out: return dim->n_out;
127 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
129 return n(dim, type);
132 static struct isl_dim *copy_names(struct isl_dim *dst,
133 enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
134 enum isl_dim_type src_type)
136 int i;
137 struct isl_name *name;
139 for (i = 0; i < n(src, src_type); ++i) {
140 name = get_name(src, src_type, i);
141 if (!name)
142 continue;
143 dst = set_name(dst, dst_type, offset + i,
144 isl_name_copy(dst->ctx, name));
145 if (!dst)
146 return NULL;
148 return dst;
151 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
153 struct isl_dim *dup;
154 dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
155 if (!dim->names)
156 return dup;
157 dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
158 dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
159 dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
160 return dup;
163 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
165 if (!dim)
166 return NULL;
168 if (dim->ref == 1)
169 return dim;
170 dim->ref--;
171 return isl_dim_dup(dim);
174 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
176 if (!dim)
177 return NULL;
179 dim->ref++;
180 return dim;
183 void isl_dim_free(struct isl_dim *dim)
185 int i;
187 if (!dim)
188 return;
190 if (--dim->ref > 0)
191 return;
193 for (i = 0; i < dim->n_name; ++i)
194 isl_name_free(dim->ctx, dim->names[i]);
195 free(dim->names);
196 isl_ctx_deref(dim->ctx);
198 free(dim);
201 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
202 enum isl_dim_type type, unsigned pos,
203 const char *s)
205 struct isl_name *name;
206 if (!dim)
207 return NULL;
208 name = isl_name_get(dim->ctx, s);
209 if (!name)
210 goto error;
211 return set_name(dim, type, pos, name);
212 error:
213 isl_dim_free(dim);
214 return NULL;
217 const char *isl_dim_get_name(struct isl_dim *dim,
218 enum isl_dim_type type, unsigned pos)
220 struct isl_name *name = get_name(dim, type, pos);
221 return name ? name->name : NULL;
224 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
225 struct isl_dim *dim2, enum isl_dim_type dim2_type)
227 int i;
229 if (n(dim1, dim1_type) != n(dim2, dim2_type))
230 return 0;
232 if (!dim1->names && !dim2->names)
233 return 1;
235 for (i = 0; i < n(dim1, dim1_type); ++i) {
236 if (get_name(dim1, dim1_type, i) !=
237 get_name(dim2, dim2_type, i))
238 return 0;
240 return 1;
243 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
244 struct isl_dim *dim2, enum isl_dim_type dim2_type)
246 return match(dim1, dim1_type, dim2, dim2_type);
249 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
250 unsigned first, unsigned n, struct isl_name **names)
252 int i;
254 for (i = 0; i < n ; ++i)
255 names[i] = get_name(dim, type, first+i);
258 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
259 unsigned nparam, unsigned n_in, unsigned n_out)
261 struct isl_name **names = NULL;
263 if (!dim)
264 return NULL;
265 if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
266 return dim;
268 isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
269 isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
270 isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
272 dim = isl_dim_cow(dim);
274 if (dim->names) {
275 names = isl_calloc_array(dim->ctx, struct isl_name *,
276 nparam + n_in + n_out);
277 if (!names)
278 goto error;
279 get_names(dim, isl_dim_param, 0, dim->nparam, names);
280 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
281 get_names(dim, isl_dim_out, 0, dim->n_out,
282 names + nparam + n_in);
283 free(dim->names);
284 dim->names = names;
285 dim->n_name = nparam + n_in + n_out;
287 dim->nparam = nparam;
288 dim->n_in = n_in;
289 dim->n_out = n_out;
291 return dim;
292 error:
293 free(names);
294 isl_dim_free(dim);
295 return NULL;
298 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
299 unsigned n)
301 switch (type) {
302 case isl_dim_param:
303 return isl_dim_extend(dim,
304 dim->nparam + n, dim->n_in, dim->n_out);
305 case isl_dim_in:
306 return isl_dim_extend(dim,
307 dim->nparam, dim->n_in + n, dim->n_out);
308 case isl_dim_out:
309 return isl_dim_extend(dim,
310 dim->nparam, dim->n_in, dim->n_out + n);
312 return dim;
315 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
317 struct isl_dim *dim;
319 if (!left || !right)
320 goto error;
322 isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
323 goto error);
324 isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
325 goto error);
327 dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
328 if (!dim)
329 goto error;
331 dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
332 dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
333 dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
335 isl_dim_free(left);
336 isl_dim_free(right);
338 return dim;
339 error:
340 isl_dim_free(left);
341 isl_dim_free(right);
342 return NULL;
345 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
347 struct isl_dim *dim;
349 if (!left || !right)
350 goto error;
352 isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
353 goto error);
355 dim = isl_dim_alloc(left->ctx, left->nparam,
356 left->n_in + right->n_in, left->n_out + right->n_out);
357 if (!dim)
358 goto error;
360 dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
361 dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
362 dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
363 dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
364 dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
366 isl_dim_free(left);
367 isl_dim_free(right);
369 return dim;
370 error:
371 isl_dim_free(left);
372 isl_dim_free(right);
373 return NULL;
376 struct isl_dim *isl_dim_map(struct isl_dim *dim)
378 struct isl_name **names = NULL;
380 if (!dim)
381 return NULL;
382 isl_assert(dim->ctx, dim->n_in == 0, goto error);
383 if (dim->n_out == 0)
384 return dim;
385 dim = isl_dim_cow(dim);
386 if (!dim)
387 return NULL;
388 if (dim->names) {
389 names = isl_calloc_array(dim->ctx, struct isl_name *,
390 dim->nparam + dim->n_out + dim->n_out);
391 if (!names)
392 goto error;
393 get_names(dim, isl_dim_param, 0, dim->nparam, names);
394 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
396 dim->n_in = dim->n_out;
397 if (names) {
398 copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
399 free(dim->names);
400 dim->names = names;
401 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
403 return dim;
404 error:
405 isl_dim_free(dim);
406 return NULL;
409 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
410 unsigned first, unsigned n, struct isl_name **names)
412 int i;
414 for (i = 0; i < n ; ++i)
415 dim = set_name(dim, type, first+i, names[i]);
417 return dim;
420 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
422 unsigned t;
423 struct isl_name **names = NULL;
425 if (!dim)
426 return NULL;
427 if (match(dim, isl_dim_in, dim, isl_dim_out))
428 return dim;
430 dim = isl_dim_cow(dim);
431 if (!dim)
432 return NULL;
434 if (dim->names) {
435 names = isl_alloc_array(dim->ctx, struct isl_name *,
436 dim->n_in + dim->n_out);
437 if (!names)
438 goto error;
439 get_names(dim, isl_dim_in, 0, dim->n_in, names);
440 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
443 t = dim->n_in;
444 dim->n_in = dim->n_out;
445 dim->n_out = t;
447 if (dim->names) {
448 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
449 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
450 free(names);
453 return dim;
454 error:
455 free(names);
456 isl_dim_free(dim);
457 return NULL;
460 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
461 unsigned first, unsigned num)
463 int i;
465 if (!dim)
466 return NULL;
468 if (n == 0)
469 return dim;
471 isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
472 dim = isl_dim_cow(dim);
473 if (!dim)
474 goto error;
475 if (dim->names) {
476 for (i = 0; i < num; ++i)
477 isl_name_free(dim->ctx, get_name(dim, type, first+i));
478 for (i = first+num; i < n(dim, type); ++i)
479 set_name(dim, type, i - num, get_name(dim, type, i));
480 switch (type) {
481 case isl_dim_param:
482 get_names(dim, isl_dim_in, 0, dim->n_in,
483 dim->names + offset(dim, isl_dim_in) - num);
484 case isl_dim_in:
485 get_names(dim, isl_dim_out, 0, dim->n_out,
486 dim->names + offset(dim, isl_dim_out) - num);
487 case isl_dim_out:
491 switch (type) {
492 case isl_dim_param: dim->nparam -= num; break;
493 case isl_dim_in: dim->n_in -= num; break;
494 case isl_dim_out: dim->n_out -= num; break;
496 return dim;
497 error:
498 isl_dim_free(dim);
499 return NULL;
502 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
503 unsigned first, unsigned n)
505 return isl_dim_drop(dim, isl_dim_in, first, n);
508 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
509 unsigned first, unsigned n)
511 return isl_dim_drop(dim, isl_dim_out, first, n);
514 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
516 if (!dim)
517 return NULL;
518 dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
519 return isl_dim_reverse(dim);
522 struct isl_dim *isl_dim_range(struct isl_dim *dim)
524 if (!dim)
525 return NULL;
526 return isl_dim_drop_inputs(dim, 0, dim->n_in);
529 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
531 int i;
533 if (!dim)
534 return NULL;
535 if (n_div == 0 &&
536 dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
537 return dim;
538 dim = isl_dim_cow(dim);
539 if (!dim)
540 return NULL;
541 dim->n_out += dim->nparam + dim->n_in + n_div;
542 dim->nparam = 0;
543 dim->n_in = 0;
545 for (i = 0; i < dim->n_name; ++i)
546 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
547 dim->n_name = 0;
549 return dim;
552 unsigned isl_dim_total(struct isl_dim *dim)
554 return dim->nparam + dim->n_in + dim->n_out;
557 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
559 return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
560 match(dim1, isl_dim_in, dim2, isl_dim_in) &&
561 match(dim1, isl_dim_out, dim2, isl_dim_out);
564 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
566 return dim1->nparam == dim2->nparam &&
567 dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;