Little fix after the last commit (mostly a git fail)
[eigenmath-fx.git] / det.cpp
blob 6ea251ee8b67dfc1fb06628344113fbc5c55c56f
1 //-----------------------------------------------------------------------------
2 //
3 // Input: Matrix on stack
4 //
5 // Output: Determinant on stack
6 //
7 // Example:
8 //
9 // > det(((1,2),(3,4)))
10 // -2
12 // Note:
14 // Uses Gaussian elimination for numerical matrices.
16 //-----------------------------------------------------------------------------
18 #include "stdafx.h"
19 #include "defs.h"
21 static int
22 check_arg(void)
24 if (!istensor(p1))
25 return 0;
26 else if (p1->u.tensor->ndim != 2)
27 return 0;
28 else if (p1->u.tensor->dim[0] != p1->u.tensor->dim[1])
29 return 0;
30 else
31 return 1;
34 void
35 det(void)
37 int i, n;
38 U **a;
40 save();
42 p1 = pop();
44 if (check_arg() == 0) {
45 push_symbol(DET);
46 push(p1);
47 list(2);
48 restore();
49 return;
52 n = p1->u.tensor->nelem;
54 a = p1->u.tensor->elem;
56 for (i = 0; i < n; i++)
57 if (!isnum(a[i]))
58 break;
60 if (i == n)
61 yydetg();
62 else {
63 for (i = 0; i < p1->u.tensor->nelem; i++)
64 push(p1->u.tensor->elem[i]);
65 determinant(p1->u.tensor->dim[0]);
68 restore();
71 // determinant of n * n matrix elements on the stack
73 void
74 determinant(int n)
76 int h, i, j, k, q, s, sign, t;
77 int *a, *c, *d;
79 h = tos - n * n;
81 a = (int *) malloc(3 * n * sizeof (int));
83 if (a == NULL)
84 out_of_memory();
86 c = a + n;
88 d = c + n;
90 for (i = 0; i < n; i++) {
91 a[i] = i;
92 c[i] = 0;
93 d[i] = 1;
96 sign = 1;
98 push(zero);
100 for (;;) {
102 if (sign == 1)
103 push_integer(1);
104 else
105 push_integer(-1);
107 for (i = 0; i < n; i++) {
108 k = n * a[i] + i;
109 push(stack[h + k]);
110 multiply(); // FIXME -- problem here
113 add();
115 /* next permutation (Knuth's algorithm P) */
117 j = n - 1;
118 s = 0;
119 P4: q = c[j] + d[j];
120 if (q < 0) {
121 d[j] = -d[j];
122 j--;
123 goto P4;
125 if (q == j + 1) {
126 if (j == 0)
127 break;
128 s++;
129 d[j] = -d[j];
130 j--;
131 goto P4;
134 t = a[j - c[j] + s];
135 a[j - c[j] + s] = a[j - q + s];
136 a[j - q + s] = t;
137 c[j] = q;
139 sign = -sign;
142 free(a);
144 stack[h] = stack[tos - 1];
146 tos = h + 1;
149 //-----------------------------------------------------------------------------
151 // Input: Matrix on stack
153 // Output: Determinant on stack
155 // Note:
157 // Uses Gaussian elimination which is faster for numerical matrices.
159 // Gaussian Elimination works by walking down the diagonal and clearing
160 // out the columns below it.
162 //-----------------------------------------------------------------------------
164 void
165 detg(void)
167 save();
169 p1 = pop();
171 if (check_arg() == 0) {
172 push_symbol(DET);
173 push(p1);
174 list(2);
175 restore();
176 return;
179 yydetg();
181 restore();
184 void
185 yydetg(void)
187 int i, n;
189 n = p1->u.tensor->dim[0];
191 for (i = 0; i < n * n; i++)
192 push(p1->u.tensor->elem[i]);
194 lu_decomp(n);
196 tos -= n * n;
198 push(p1);
201 //-----------------------------------------------------------------------------
203 // Input: n * n matrix elements on stack
205 // Output: p1 determinant
207 // p2 mangled
209 // upper diagonal matrix on stack
211 //-----------------------------------------------------------------------------
213 #define M(i, j) stack[h + n * (i) + (j)]
215 void
216 lu_decomp(int n)
218 int d, h, i, j;
220 h = tos - n * n;
222 p1 = one;
224 for (d = 0; d < n - 1; d++) {
226 // diagonal element zero?
228 if (equal(M(d, d), zero)) {
230 // find a new row
232 for (i = d + 1; i < n; i++)
233 if (!equal(M(i, d), zero))
234 break;
236 if (i == n) {
237 p1 = zero;
238 break;
241 // exchange rows
243 for (j = d; j < n; j++) {
244 p2 = M(d, j);
245 M(d, j) = M(i, j);
246 M(i, j) = p2;
249 // negate det
251 push(p1);
252 negate();
253 p1 = pop();
256 // update det
258 push(p1);
259 push(M(d, d));
260 multiply();
261 p1 = pop();
263 // update lower diagonal matrix
265 for (i = d + 1; i < n; i++) {
267 // multiplier
269 push(M(i, d));
270 push(M(d, d));
271 divide();
272 negate();
274 p2 = pop();
276 // update one row
278 M(i, d) = zero; // clear column below pivot d
280 for (j = d + 1; j < n; j++) {
281 push(M(d, j));
282 push(p2);
283 multiply();
284 push(M(i, j));
285 add();
286 M(i, j) = pop();
291 // last diagonal element
293 push(p1);
294 push(M(n - 1, n - 1));
295 multiply();
296 p1 = pop();