Merge pull request #113 from gitter-badger/gitter-badge
[sddekit.git] / src / sd_scheme.c
blob33ff02a61f0273ea3e33093ccc6966deccb336ae
/* Copyright 2016 sddekit authors. Licensed under the Apache License 2.0. */
3 #include "sddekit.h"
5 typedef struct one_step_data {
6 int nx;
7 double *f, *g, *z;
8 sd_sch sch;
9 } one_step_data;
11 static uint32_t one_step_get_nx(sd_sch *s) { return ((one_step_data*)s->ptr)->nx; }
13 void one_step_free(sd_sch *sch)
15 one_step_data *d = sch->ptr;
16 sd_free(d->f);
17 sd_free(d->g);
18 sd_free(d->z);
19 sd_free(d);
22 static sd_sch *
23 new_one_step(uint32_t nx)
25 bool err;
26 one_step_data *d;
27 if ((d = sd_malloc(sizeof(one_step_data))) == NULL) {
28 sd_err("memory alloc for one step sch failed.");
29 return NULL;
32 one_step_data zero = { 0 };
33 *d = zero;
35 err = (d->f = sd_malloc(sizeof(double)*nx)) == NULL;
36 err |= (d->g = sd_malloc(sizeof(double)*nx)) == NULL;
37 err |= (d->z = sd_malloc(sizeof(double)*nx)) == NULL;
38 if (err) {
39 if (d->f!=NULL) sd_free(d->f);
40 if (d->g!=NULL) sd_free(d->g);
41 if (d->z!=NULL) sd_free(d->z);
42 if (d != NULL) sd_free(d);
43 sd_err("memory alloc for one step sch failed.");
44 return NULL;
46 d->nx = nx;
47 d->sch.ptr = d;
48 d->sch.get_nx = &one_step_get_nx;
49 d->sch.free = &one_step_free;
50 return &(d->sch);
53 static sd_stat one_step_apply(
54 sd_sch *sch, sd_hist *hist, sd_rng *rng, sd_sys *sys,
55 double t, double dt,
56 uint32_t nx, double * restrict x,
57 uint32_t nc, double * restrict c)
59 (void) dt;
60 sd_stat stat;
61 one_step_data *d = sch->ptr;
62 hist->get(hist, t, c);
63 sd_sys_in in = { .t = t, .nx = nx, .x = x, .nc = nc, .i = c, .hist = hist, .rng = rng };
64 sd_sys_out out = {.f=d->f, .g=d->g, .o=c};
65 if ((stat=sys->apply(sys, &in, &out))!=SD_OK)
66 return stat;
67 rng->fill_norm(rng, nx, d->z);
68 /* compute step & set history */
69 return SD_OK;
72 static sd_stat id_apply(
73 sd_sch *sch, sd_hist *hist, sd_rng *rng, sd_sys *sys,
74 double t, double dt,
75 uint32_t nx, double * restrict x,
76 uint32_t nc, double * restrict c)
78 uint32_t i;
79 sd_stat stat;
80 one_step_data *d = sch->ptr;
81 if ((stat = one_step_apply(sch, hist, rng, sys, t, dt, nx, x, nc, c))!=SD_OK)
82 return stat;
83 for (i=0; i<nx; i++)
84 x[i] = d->f[i] + d->g[i] * d->z[i];
85 hist->set(hist, t, c);
86 return stat;
89 static sd_stat em_apply(
90 sd_sch *sch, sd_hist *hist, sd_rng *rng, sd_sys *sys,
91 double t, double dt,
92 uint32_t nx, double * restrict x,
93 uint32_t nc, double * restrict c)
95 uint32_t i;
96 sd_stat stat;
97 double sqrt_dt;
98 one_step_data *d = sch->ptr;
99 if ((stat = one_step_apply(sch, hist, rng, sys, t, dt, nx, x, nc, c)) != SD_OK)
100 return stat;
101 sqrt_dt = sqrt(dt);
102 for (i=0; i<nx; i++)
103 x[i] += dt * d->f[i] + sqrt_dt * d->g[i] * d->z[i];
104 hist->set(hist, t, c);
105 return 0;
109 sd_sch *sd_sch_new_id(uint32_t nx)
111 sd_sch *new = new_one_step(nx);
112 new->apply = &id_apply;
113 return new;
116 sd_sch *sd_sch_new_em(uint32_t nx)
118 sd_sch *new = new_one_step(nx);
119 new->apply = &em_apply;
120 return new;
/* State for the Euler-Maruyama scheme driven by exponentially correlated
 * ("colored") noise; see emc_apply. */
typedef struct emc_data {
	bool first_call;   /* true until emc_apply has initialized eps */
	uint32_t nx;       /* state dimension */
	double *f;         /* nx-long buffer: sys drift output */
	double *g;         /* nx-long buffer: sys diffusion output */
	double *z;         /* nx-long buffer: samples from rng->fill_norm */
	double *eps;       /* nx-long buffer: current colored-noise values */
	double lam;        /* decay rate: eps is scaled by exp(-lam*dt) each step */
} emc_data;
130 static uint32_t
131 emc_get_nx(sd_sch *sch)
133 return ((emc_data*)sch->ptr)->nx;
136 double sd_sch_emc_get_lam(sd_sch *sch)
138 return ((emc_data*)sch)->lam;
141 static void
142 emc_free(sd_sch *sch)
144 emc_data *d = sch->ptr;
145 if (d==NULL) {
146 sd_err("returning early due to NULL instance pointer");
147 return;
149 if (d->f!=NULL) sd_free(d->f);
150 if (d->g!=NULL) sd_free(d->g);
151 if (d->z!=NULL) sd_free(d->z);
152 if (d->eps!=NULL) sd_free(d->eps);
153 sd_free(d);
154 sd_free(sch);
157 static sd_stat emc_apply(
158 sd_sch *sch, sd_hist *hist, sd_rng *rng, sd_sys *sys,
159 double t, double dt,
160 uint32_t nx, double * restrict x,
161 uint32_t nc, double * restrict c)
163 uint32_t i;
164 sd_stat stat;
165 double E;
166 emc_data *d = sch->ptr;
167 sd_sys_in in = { .t = t, .nx = nx, .x = x, .nc = nc, .i = c, .hist = hist, .rng = rng };
168 sd_sys_out out = {.f=d->f, .g=d->g, .o=c};
169 if (d->first_call) {
170 rng->fill_norm(rng, nx, d->z);
171 if ((stat = sys->apply(sys, &in, &out)) != SD_OK)
172 return stat;
173 for (i=0; i<nx; i++)
174 d->eps[i] = sqrt(d->g[i] * d->lam) * d->z[i];
175 d->first_call = false;
177 E = exp(-d->lam * dt);
178 rng->fill_norm(rng, nx, d->z);
179 hist->get(hist, t, c);
180 if ((stat = sys->apply(sys, &in, &out)) != SD_OK)
181 return stat;
182 for (i=0; i<nx; i++) {
183 x[i] += dt * (d->f[i] + d->eps[i]);
184 d->eps[i] *= E;
185 d->eps[i] += sqrt(d->g[i] * d->lam * (1 - E*E)) * d->z[i];
187 hist->set(hist, t, c);
188 return SD_OK;
191 sd_sch * sd_sch_new_emc(uint32_t nx, double lam)
193 bool err;
194 emc_data *d;
195 sd_sch *sch;
196 if ((d = sd_malloc(sizeof(emc_data))) == NULL) {
197 sd_err("memory alloc for one step sch failed.");
198 return NULL;
201 emc_data zero = { 0 };
202 *d = zero;
204 d->nx = nx;
205 err = (sch = sd_malloc(sizeof(sd_sch))) == NULL;
206 err |= (d->f=sd_malloc(sizeof(double)*nx))==NULL;
207 err |= (d->g=sd_malloc(sizeof(double)*nx))==NULL;
208 err |= (d->z=sd_malloc(sizeof(double)*nx))==NULL;
209 err |= (d->eps=sd_malloc(sizeof(double)*nx))==NULL;
210 if (err) {
211 if (d->f!=NULL) sd_free(d->f);
212 if (d->g!=NULL) sd_free(d->g);
213 if (d->z!=NULL) sd_free(d->z);
214 if (d->eps!=NULL) sd_free(d->eps);
215 if (d != NULL) sd_free(d);
216 if (sch != NULL) sd_free(sch);
217 sd_err("memory alloc for sch em color failed.");
218 return NULL;
220 d->first_call = true;
221 d->lam = lam;
222 sch->ptr = d;
223 sch->get_nx = &emc_get_nx;
224 sch->apply = &emc_apply;
225 sch->free = &emc_free;
226 return sch;
/* State for the stochastic Heun scheme (see heun_apply). */
typedef struct heun_data {
	uint32_t nx;  /* state dimension; uint32_t to match heun_get_nx() and
	               * sd_sch_new_heun(), avoiding a signed round-trip */
	double *fl, *fr;  /* drift at the start point and at the predictor point */
	double *gl, *gr;  /* diffusion at the start point and at the predictor point */
	double *z;        /* samples from rng->fill_norm */
	double *xr;       /* predictor state */
} heun_data;
234 static uint32_t
235 heun_get_nx(sd_sch *sch)
237 return ((heun_data*)sch->ptr)->nx;
240 static void
241 heun_free(sd_sch *sch)
243 heun_data *d = sch->ptr;
244 sd_free(d->fl);;
245 sd_free(d->fr);;
246 sd_free(d->gl);;
247 sd_free(d->gr);;
248 sd_free(d->z);;
249 sd_free(d->xr);;
250 sd_free(d);
251 sd_free(sch);
254 static sd_stat heun_apply(
255 sd_sch *sch, sd_hist *hist, sd_rng *rng, sd_sys *sys,
256 double t, double dt,
257 uint32_t nx, double * restrict x,
258 uint32_t nc, double * restrict c)
260 uint32_t i;
261 sd_stat stat;
262 double sqrt_dt;
263 heun_data *d = sch->ptr;
264 sd_sys_in in = { .t = t, .nx = nx, .x = x, .nc = nc, .i = c, .hist = hist, .rng = rng };
265 sd_sys_out out = { .f = d->fl, .g = d->gl, .o = c };
266 /* predictor */
267 hist->get(hist, t, c);
268 if ((stat = sys->apply(sys, &in, &out)) != SD_OK)
269 return stat;
270 for (i=0; i<nx; i++)
271 d->xr[i] = x[i] + dt * d->fl[i];
272 hist->set(hist, t, c);
273 /* corrector */
274 hist->get(hist, t + dt, c);
275 in.t = t + dt;
276 in.x = d->xr;
277 out.f = d->fr;
278 out.g = d->gr;
279 if ((stat = sys->apply(sys, &in, &out)) != SD_OK)
280 return stat;
281 rng->fill_norm(rng, nx, d->z);
282 sqrt_dt = sqrt(dt);
283 for (i=0; i<nx; i++)
284 x[i] += 0.5 * (dt*(d->fl[i] + d->fr[i])
285 + sqrt_dt*(d->gl[i] + d->gr[i])*d->z[i]);
286 hist->set(hist, t + dt, c);
287 return 0;
290 sd_sch *
291 sd_sch_new_heun(uint32_t nx)
293 bool err;
294 heun_data *d;
295 sd_sch *sch;
296 if ((d = sd_malloc(sizeof(heun_data))) == NULL)
298 sd_err("memory alloc for heun scheme failed.");
299 return NULL;
302 heun_data zero = { 0 };
303 *d = zero;
305 d->nx = nx;
306 err = (sch = sd_malloc(sizeof(sd_sch)))==NULL;
307 err |= (d->fl=sd_malloc(sizeof(double)*nx))==NULL;
308 err |= (d->fr=sd_malloc(sizeof(double)*nx))==NULL;
309 err |= (d->gl=sd_malloc(sizeof(double)*nx))==NULL;
310 err |= (d->gr=sd_malloc(sizeof(double)*nx))==NULL;
311 err |= (d->z=sd_malloc(sizeof(double)*nx))==NULL;
312 err |= (d->xr=sd_malloc(sizeof(double)*nx))==NULL;
313 if (err) {
314 if (d->fl!=NULL) sd_free(d->fl);;
315 if (d->fr!=NULL) sd_free(d->fr);;
316 if (d->gl!=NULL) sd_free(d->gl);;
317 if (d->gr!=NULL) sd_free(d->gr);;
318 if (d->z!=NULL) sd_free(d->z);;
319 if (d->xr!=NULL) sd_free(d->xr);;
320 if (d != NULL) sd_free(d);
321 if (sch != NULL) sd_free(sch);
322 sd_err("memory alloc durong sch em init failed.");
323 return NULL;
325 sch->ptr = d;
326 sch->get_nx = &heun_get_nx;
327 sch->apply = &heun_apply;
328 sch->free = &heun_free;
329 return sch;