Little fix after the last commit (mostly a git fail)
[eigenmath-fx.git] / inner.cpp
bloba6ba1a461d798ca9b3b2b259cf1330e912d5fa68
1 // Do the inner product of tensors.
3 #include "stdafx.h"
4 #include "defs.h"
// Forward declaration: tensor-by-tensor case of inner(), defined below.
static void inner_f(void);
8 void
9 eval_inner(void)
11 p1 = cdr(p1);
12 push(car(p1));
13 eval();
14 p1 = cdr(p1);
15 while (iscons(p1)) {
16 push(car(p1));
17 eval();
18 inner();
19 p1 = cdr(p1);
23 void
24 inner(void)
26 save();
27 p2 = pop();
28 p1 = pop();
29 if (istensor(p1) && istensor(p2))
30 inner_f();
31 else {
32 push(p1);
33 push(p2);
34 if (istensor(p1))
35 tensor_times_scalar();
36 else if (istensor(p2))
37 scalar_times_tensor();
38 else
39 multiply();
41 restore();
44 // inner product of tensors p1 and p2
46 static void
47 inner_f(void)
49 int ak, bk, i, j, k, n, ndim;
50 U **a, **b, **c;
52 n = p1->u.tensor->dim[p1->u.tensor->ndim - 1];
54 if (n != p2->u.tensor->dim[0])
55 stop("inner: tensor dimension check");
57 ndim = p1->u.tensor->ndim + p2->u.tensor->ndim - 2;
59 if (ndim > MAXDIM)
60 stop("inner: rank of result exceeds maximum");
62 a = p1->u.tensor->elem;
63 b = p2->u.tensor->elem;
65 //---------------------------------------------------------------------
67 // ak is the number of rows in tensor A
69 // bk is the number of columns in tensor B
71 // Example:
73 // A[3][3][4] B[4][4][3]
75 // 3 3 ak = 3 * 3 = 9
77 // 4 3 bk = 4 * 3 = 12
79 //---------------------------------------------------------------------
81 ak = 1;
82 for (i = 0; i < p1->u.tensor->ndim - 1; i++)
83 ak *= p1->u.tensor->dim[i];
85 bk = 1;
86 for (i = 1; i < p2->u.tensor->ndim; i++)
87 bk *= p2->u.tensor->dim[i];
89 p3 = alloc_tensor(ak * bk);
91 c = p3->u.tensor->elem;
93 // new method copied from ginac
94 #if 1
95 for (i = 0; i < ak; i++) {
96 for (j = 0; j < n; j++) {
97 if (iszero(a[i * n + j]))
98 continue;
99 for (k = 0; k < bk; k++) {
100 push(a[i * n + j]);
101 push(b[j * bk + k]);
102 multiply();
103 push(c[i * bk + k]);
104 add();
105 c[i * bk + k] = pop();
109 #else
110 for (i = 0; i < ak; i++) {
111 for (j = 0; j < bk; j++) {
112 push(zero);
113 for (k = 0; k < n; k++) {
114 push(a[i * n + k]);
115 push(b[k * bk + j]);
116 multiply();
117 add();
119 c[i * bk + j] = pop();
122 #endif
123 //---------------------------------------------------------------------
125 // Note on understanding "k * bk + j"
127 // k * bk because each element of a column is bk locations apart
129 // + j because the beginnings of all columns are in the first bk
130 // locations
132 // Example: n = 2, bk = 6
134 // b111 <- 1st element of 1st column
135 // b112 <- 1st element of 2nd column
136 // b113 <- 1st element of 3rd column
137 // b121 <- 1st element of 4th column
138 // b122 <- 1st element of 5th column
139 // b123 <- 1st element of 6th column
141 // b211 <- 2nd element of 1st column
142 // b212 <- 2nd element of 2nd column
143 // b213 <- 2nd element of 3rd column
144 // b221 <- 2nd element of 4th column
145 // b222 <- 2nd element of 5th column
146 // b223 <- 2nd element of 6th column
148 //---------------------------------------------------------------------
150 if (ndim == 0)
151 push(p3->u.tensor->elem[0]);
152 else {
153 p3->u.tensor->ndim = ndim;
154 for (i = 0; i < p1->u.tensor->ndim - 1; i++)
155 p3->u.tensor->dim[i] = p1->u.tensor->dim[i];
156 j = i;
157 for (i = 0; i < p2->u.tensor->ndim - 1; i++)
158 p3->u.tensor->dim[j + i] = p2->u.tensor->dim[i + 1];
159 push(p3);
#if SELFTEST

// Self-test cases for inner(). Entries are laid out in pairs: an input
// expression followed by its expected printed result (presumably test()
// evaluates each input and compares it with the following string).

static char *s[] = {

	"inner(a,b)",
	"a*b",

	"inner(a,(b1,b2))",
	"(a*b1,a*b2)",

	"inner((a1,a2),b)",
	"(a1*b,a2*b)",

	"inner(((a11,a12),(a21,a22)),(x1,x2))",
	"(a11*x1+a12*x2,a21*x1+a22*x2)",

	"inner((1,2),(3,4))",
	"11",

	"inner(inner((1,2),((3,4),(5,6))),(7,8))",
	"219",

	"inner((1,2),inner(((3,4),(5,6)),(7,8)))",
	"219",

	"inner((1,2),((3,4),(5,6)),(7,8))",
	"219",

	// matrix times matrix: rank-2 result, exercises the dim-copy path
	"inner(((1,2),(3,4)),((5,6),(7,8)))",
	"((19,22),(43,50))",
};

void
test_inner(void)
{
	test(__FILE__, s, sizeof s / sizeof s[0]);
}

#endif