Little fix after the last commit (mostly a git fail)
[eigenmath-fx.git] / nroots.cpp
blob d4613a7865b481eb2e8057f897d39884156ad070
1 // find the roots of a polynomial numerically
3 #include "stdafx.h"
4 #include "defs.h"
// Tuning constants for the numerical root finder.

enum { YMAX = 101 };			// capacity of c[]: most coefficients accepted (degree + 1)

static const double DELTA = 1.0e-6;	// "close enough to zero" for roots, constant term, residual
static const double EPSILON = 1.0e-9;	// convergence threshold on |p(a)| in the secant iteration

// modulus of a value with .r/.i members
#define ABS(z) (sqrt((z).r * (z).r + (z).i * (z).i))

// uniformly distributed double in [-2, 2]
#define RANDOM (-2.0 + 4.0 * (double) rand() / (double) RAND_MAX)

// Complex number as a (real, imaginary) pair of doubles.
struct nroots_complex {
	double r, i;
};

// File-level working state shared by the root-finding helpers:
//   a, b     current secant points (a ends up holding the root)
//   fa, fb   polynomial values at a and b
//   x, y     scratch (powers of a, swap temp, secant quotient, leading coefficient)
//   dx, df   secant differences b - a and fb - fa
//   c        polynomial coefficients, constant term first
static struct nroots_complex a, b;
static struct nroots_complex fa, fb;
static struct nroots_complex x, y;
static struct nroots_complex dx, df;
static struct nroots_complex c[YMAX];
16 void
17 eval_nroots(void)
19 int h, i, k, n;
21 push(cadr(p1));
22 eval();
24 push(caddr(p1));
25 eval();
26 p2 = pop();
27 if (p2 == symbol(NIL))
28 guess();
29 else
30 push(p2);
32 p2 = pop();
33 p1 = pop();
35 if (!ispoly(p1, p2))
36 stop("nroots: polynomial?");
38 // mark the stack
40 h = tos;
42 // get the coefficients
44 push(p1);
45 push(p2);
46 n = coeff();
47 if (n > YMAX)
48 stop("nroots: degree?");
50 // convert the coefficients to real and imaginary doubles
52 for (i = 0; i < n; i++) {
53 push(stack[h + i]);
54 real();
55 yyfloat();
56 eval();
57 p1 = pop();
58 push(stack[h + i]);
59 imag();
60 yyfloat();
61 eval();
62 p2 = pop();
63 if (!isdouble(p1) || !isdouble(p2))
64 stop("nroots: coefficients?");
65 c[i].r = p1->u.d;
66 c[i].i = p2->u.d;
69 // pop the coefficients
71 tos = h;
73 // n is the number of coefficients, n = deg(p) + 1
75 monic(n);
77 for (k = n; k > 1; k--) {
78 findroot(k);
79 if (fabs(a.r) < DELTA)
80 a.r = 0.0;
81 if (fabs(a.i) < DELTA)
82 a.i = 0.0;
83 push_double(a.r);
84 push_double(a.i);
85 push(imaginaryunit);
86 multiply();
87 add();
88 divpoly(k);
91 // now make n equal to the number of roots
93 n = tos - h;
95 if (n > 1) {
96 sort_stack(n);
97 p1 = alloc_tensor(n);
98 p1->u.tensor->ndim = 1;
99 p1->u.tensor->dim[0] = n;
100 for (i = 0; i < n; i++)
101 p1->u.tensor->elem[i] = stack[h + i];
102 tos = h;
103 push(p1);
107 // divide the polynomial by its leading coefficient
109 void
110 monic(int n)
112 int k;
113 double t;
114 y = c[n - 1];
115 t = y.r * y.r + y.i * y.i;
116 for (k = 0; k < n - 1; k++) {
117 c[k].r = (c[k].r * y.r + c[k].i * y.i) / t;
118 c[k].i = (c[k].i * y.r - c[k].r * y.i) / t;
120 c[n - 1].r = 1.0;
121 c[n - 1].i = 0.0;
124 // uses the secant method
126 void
127 findroot(int n)
129 int j, k;
130 double t;
132 if (ABS(c[0]) < DELTA) {
133 a.r = 0.0;
134 a.i = 0.0;
135 return;
138 for (j = 0; j < 100; j++) {
140 a.r = RANDOM;
141 a.i = RANDOM;
143 compute_fa(n);
145 b = a;
146 fb = fa;
148 a.r = RANDOM;
149 a.i = RANDOM;
151 for (k = 0; k < 1000; k++) {
153 compute_fa(n);
155 if (ABS(fa) < EPSILON)
156 return;
158 if (ABS(fa) < ABS(fb)) {
159 x = a;
160 a = b;
161 b = x;
162 x = fa;
163 fa = fb;
164 fb = x;
167 // dx = b - a
169 dx.r = b.r - a.r;
170 dx.i = b.i - a.i;
172 // df = fb - fa
174 df.r = fb.r - fa.r;
175 df.i = fb.i - fa.i;
177 // y = dx / df
179 t = df.r * df.r + df.i * df.i;
181 if (t == 0.0)
182 break;
184 y.r = (dx.r * df.r + dx.i * df.i) / t;
185 y.i = (dx.i * df.r - dx.r * df.i) / t;
187 // a = b - y * fb
189 a.r = b.r - (y.r * fb.r - y.i * fb.i);
190 a.i = b.i - (y.r * fb.i + y.i * fb.r);
194 stop("nroots: convergence error");
197 void
198 compute_fa(int n)
200 int k;
201 double t;
203 // x = a
205 x.r = a.r;
206 x.i = a.i;
208 // fa = c0 + c1 * x
210 fa.r = c[0].r + c[1].r * x.r - c[1].i * x.i;
211 fa.i = c[0].i + c[1].r * x.i + c[1].i * x.r;
213 for (k = 2; k < n; k++) {
215 // x = a * x
217 t = a.r * x.r - a.i * x.i;
218 x.i = a.r * x.i + a.i * x.r;
219 x.r = t;
221 // fa += c[k] * x
223 fa.r += c[k].r * x.r - c[k].i * x.i;
224 fa.i += c[k].r * x.i + c[k].i * x.r;
228 // divide the polynomial by x - a
230 void
231 divpoly(int n)
233 int k;
234 for (k = n - 1; k > 0; k--) {
235 c[k - 1].r += c[k].r * a.r - c[k].i * a.i;
236 c[k - 1].i += c[k].i * a.r + c[k].r * a.i;
238 if (ABS(c[0]) > DELTA)
239 stop("nroots: residual error");
240 for (k = 0; k < n - 1; k++) {
241 c[k].r = c[k + 1].r;
242 c[k].i = c[k + 1].i;
246 #if SELFTEST
248 static char *s[] = {
250 "nroots(x)",
251 "0",
253 "nroots((1+i)*x^2+1)",
254 "(-0.17178-0.727673*i,0.17178+0.727673*i)",
256 "nroots(sqrt(2)*exp(i*pi/4)*x^2+1)",
257 "(-0.17178-0.727673*i,0.17178+0.727673*i)",
259 // "nroots(x^4+1)",
260 // "(-0.707107+0.707107*i,-0.707107-0.707107*i,0.707107+0.707107*i,0.707107-0.707107*i)",
263 void
264 test_nroots(void)
266 test(__FILE__, s, sizeof s / sizeof (char *));
269 #endif