Little fix after the last commit (mostly a git fail)
[eigenmath-fx.git] / inv.cpp
blob5dd40f43d4716182c6276a0437179516b4fc345d
//-----------------------------------------------------------------------------
//
//	Input:		Matrix on stack
//
//	Output:		Inverse on stack
//
//	Example:
//
//	> inv(((1,2),(3,4)))
//	((-2,1),(3/2,-1/2))
//
//	Note:
//
//	Uses Gaussian elimination when every entry is numeric, and adj(A)/det(A)
//	otherwise. An argument that is not a square matrix is returned
//	unevaluated as inv(arg).
//
//-----------------------------------------------------------------------------
18 #include "stdafx.h"
19 #include "defs.h"
21 static int
22 check_arg(void)
24 if (!istensor(p1))
25 return 0;
26 else if (p1->u.tensor->ndim != 2)
27 return 0;
28 else if (p1->u.tensor->dim[0] != p1->u.tensor->dim[1])
29 return 0;
30 else
31 return 1;
34 void
35 inv(void)
37 int i, n;
38 U **a;
40 save();
42 p1 = pop();
44 if (check_arg() == 0) {
45 push_symbol(INV);
46 push(p1);
47 list(2);
48 restore();
49 return;
52 n = p1->u.tensor->nelem;
54 a = p1->u.tensor->elem;
56 for (i = 0; i < n; i++)
57 if (!isnum(a[i]))
58 break;
60 if (i == n)
61 yyinvg();
62 else {
63 push(p1);
64 adj();
65 push(p1);
66 det();
67 p2 = pop();
68 if (iszero(p2))
69 stop("inverse of singular matrix");
70 push(p2);
71 divide();
74 restore();
77 void
78 invg(void)
80 save();
82 p1 = pop();
84 if (check_arg() == 0) {
85 push_symbol(INVG);
86 push(p1);
87 list(2);
88 restore();
89 return;
92 yyinvg();
94 restore();
97 // inverse using gaussian elimination
99 void
100 yyinvg(void)
102 int h, i, j, n;
104 n = p1->u.tensor->dim[0];
106 h = tos;
108 for (i = 0; i < n; i++)
109 for (j = 0; j < n; j++)
110 if (i == j)
111 push(one);
112 else
113 push(zero);
115 for (i = 0; i < n * n; i++)
116 push(p1->u.tensor->elem[i]);
118 decomp(n);
120 p1 = alloc_tensor(n * n);
122 p1->u.tensor->ndim = 2;
123 p1->u.tensor->dim[0] = n;
124 p1->u.tensor->dim[1] = n;
126 for (i = 0; i < n * n; i++)
127 p1->u.tensor->elem[i] = stack[h + i];
129 tos -= 2 * n * n;
131 push(p1);
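//-----------------------------------------------------------------------------
//
//	Gauss-Jordan elimination on two n * n blocks on the stack.
//
//	Input:		n * n unit matrix on stack (pushed first)
//
//			n * n operand on stack (pushed last, on top)
//
//	Output:		n * n inverse matrix on stack (in place of the unit matrix)
//
//			n * n garbage on stack (in place of the operand)
//
//			p2 mangled
//
//	Calls stop("inverse of singular matrix") if no nonzero pivot exists.
//
//-----------------------------------------------------------------------------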
148 #define A(i, j) stack[a + n * (i) + (j)]
149 #define U(i, j) stack[u + n * (i) + (j)]
151 void
152 decomp(int n)
154 int a, d, i, j, u;
156 a = tos - n * n;
158 u = a - n * n;
160 for (d = 0; d < n; d++) {
162 // diagonal element zero?
164 if (equal(A(d, d), zero)) {
166 // find a new row
168 for (i = d + 1; i < n; i++)
169 if (!equal(A(i, d), zero))
170 break;
172 if (i == n)
173 stop("inverse of singular matrix");
175 // exchange rows
177 for (j = 0; j < n; j++) {
179 p2 = A(d, j);
180 A(d, j) = A(i, j);
181 A(i, j) = p2;
183 p2 = U(d, j);
184 U(d, j) = U(i, j);
185 U(i, j) = p2;
189 // multiply the pivot row by 1 / pivot
191 p2 = A(d, d);
193 for (j = 0; j < n; j++) {
195 if (j > d) {
196 push(A(d, j));
197 push(p2);
198 divide();
199 A(d, j) = pop();
202 push(U(d, j));
203 push(p2);
204 divide();
205 U(d, j) = pop();
208 // clear out the column above and below the pivot
210 for (i = 0; i < n; i++) {
212 if (i == d)
213 continue;
215 // multiplier
217 p2 = A(i, d);
219 // add pivot row to i-th row
221 for (j = 0; j < n; j++) {
223 if (j > d) {
224 push(A(i, j));
225 push(A(d, j));
226 push(p2);
227 multiply();
228 subtract();
229 A(i, j) = pop();
232 push(U(i, j));
233 push(U(d, j));
234 push(p2);
235 multiply();
236 subtract();
237 U(i, j) = pop();