Little fix after the last commit (mostly a git fail)
[eigenmath-fx.git] / transpose.cpp
blob f1f731da0e155a9008a359aa6a29cdb075160f56
1 // Transpose tensor indices
3 #include "stdafx.h"
4 #include "defs.h"
6 void
7 eval_transpose(void)
9 push(cadr(p1));
10 eval();
11 if (cddr(p1) == symbol(NIL)) {
12 push_integer(1);
13 push_integer(2);
14 } else {
15 push(caddr(p1));
16 eval();
17 push(cadddr(p1));
18 eval();
20 transpose();
23 void
24 transpose(void)
26 int i, j, k, l, m, ndim, nelem, t;
27 int ai[MAXDIM], an[MAXDIM];
28 U **a, **b;
30 save();
32 p3 = pop();
33 p2 = pop();
34 p1 = pop();
36 if (!istensor(p1)) {
37 if (!iszero(p1))
38 stop("transpose: tensor expected, 1st arg is not a tensor");
39 push(zero);
40 restore();
41 return;
44 ndim = p1->u.tensor->ndim;
45 nelem = p1->u.tensor->nelem;
47 // vector?
49 if (ndim == 1) {
50 push(p1);
51 restore();
52 return;
55 push(p2);
56 l = pop_integer();
58 push(p3);
59 m = pop_integer();
61 if (l < 1 || l > ndim || m < 1 || m > ndim)
62 stop("transpose: index out of range");
64 l--;
65 m--;
67 p2 = alloc_tensor(nelem);
69 p2->u.tensor->ndim = ndim;
71 for (i = 0; i < ndim; i++)
72 p2->u.tensor->dim[i] = p1->u.tensor->dim[i];
74 p2->u.tensor->dim[l] = p1->u.tensor->dim[m];
75 p2->u.tensor->dim[m] = p1->u.tensor->dim[l];
77 a = p1->u.tensor->elem;
78 b = p2->u.tensor->elem;
80 // init tensor index
82 for (i = 0; i < ndim; i++) {
83 ai[i] = 0;
84 an[i] = p1->u.tensor->dim[i];
87 // copy components from a to b
89 for (i = 0; i < nelem; i++) {
91 // swap indices l and m
93 t = ai[l]; ai[l] = ai[m]; ai[m] = t;
94 t = an[l]; an[l] = an[m]; an[m] = t;
96 // convert tensor index to linear index k
98 k = 0;
99 for (j = 0; j < ndim; j++)
100 k = (k * an[j]) + ai[j];
102 // swap indices back
104 t = ai[l]; ai[l] = ai[m]; ai[m] = t;
105 t = an[l]; an[l] = an[m]; an[m] = t;
107 // copy one element
109 b[k] = a[i];
111 // increment tensor index
113 // Suppose the tensor dimensions are 2 and 3.
114 // Then the tensor index ai increments as follows:
115 // 00 -> 01
116 // 01 -> 02
117 // 02 -> 10
118 // 10 -> 11
119 // 11 -> 12
120 // 12 -> 00
122 for (j = ndim - 1; j >= 0; j--) {
123 if (++ai[j] < an[j])
124 break;
125 ai[j] = 0;
129 push(p2);
130 restore();
#if SELFTEST

// Self-test table: each input expression is followed by its expected
// printed result.

static char *s[] = {

	"transpose(0)",
	"0",

	"transpose(0.0)",
	"0",

	"transpose(((a,b),(c,d)))",
	"((a,c),(b,d))",

	"transpose(((a,b),(c,d)),1,2)",
	"((a,c),(b,d))",

	// index order does not matter: exchanging 2 and 1 equals exchanging 1 and 2
	"transpose(((a,b),(c,d)),2,1)",
	"((a,c),(b,d))",

	"transpose(((a,b,c),(d,e,f)),1,2)",
	"((a,d),(b,e),(c,f))",

	"transpose(((a,d),(b,e),(c,f)),1,2)",
	"((a,b,c),(d,e,f))",

	"transpose((a,b,c))",
	"(a,b,c)",
};

// Run the transpose self-test cases above.

void
test_transpose(void)
{
	test(__FILE__, s, sizeof s / sizeof s[0]);
}

#endif