use opaque pointers
[sddekit.git] / src / sk_net.c
blobd2b65366c90acc64ee6470ade84d8486df47ba13
1 /* Apache 2.0 INS-AMU 2015 */
3 #include <string.h> /* memcpy */
5 #include "sk_net.h"
6 #include "sk_malloc.h"
8 struct sk_net_data {
9 int n, m, nnz, *M, *Ms, *Me, ns, ne, *Or, *Ic;
10 double *w, *d, * restrict cn;
11 sk_sys *models;
12 void **models_data;
13 /* flag for init1 use */
14 int _init1;
17 sk_net_data *sk_net_alloc() {
18 return sk_malloc (sizeof(sk_net_data));
21 SK_DEFSYS(sk_net_sys)
23 int l, j, mi;
24 double *xi, *fi, *gi, *ci;
25 sk_net_data *d = data;
27 /* unused arguments */
28 (void) nx; (void) Jf; (void) Jg; (void) Jce; (void) nc;
30 /* compute coupling
32 * solver deals with c following Ic,d, but here we wanted ce
34 for (l=0; l<d->ne; l++)
35 for (d->cn[l]=0.0, j=d->Or[l]; j<d->Or[l+1]; j++)
36 /* c is len(ui), not ne! Ic[j] can't index it directly */
37 d->cn[l] += c[sk_hist_get_vi2i(hist, d->Ic[j])] * d->w[j];
39 /* evaluate system(s) */
40 if (Jf==NULL) {
41 for (l=0, xi=x, fi=f, gi=g, ci=d->cn; l<d->n; l++,
42 xi+=d->Ms[mi], fi+=d->Ms[mi], gi+=d->Ms[mi], ci+=d->Me[mi])
44 mi = d->M[l];
45 (*(d->models[mi]))(d->models_data[mi], hist, t, i,
46 d->Ms[mi], xi, fi, gi, NULL, NULL,
47 d->Me[mi], ci, NULL);
49 } else {
50 /* TODO evaluate & compute Jf/Jg/Jc
51 * in net, we expect sys to evaluate only its own Jf/Jg, not whole.
52 * we are responsible for Jf/Jg ? or need Jc? Jca, Jce?
54 * Jf/Jg include ca, i.e. [ df0/dx0 df0/dx1 .. df0/dca0 .. ]
56 * may need Jc [ dce0/dx0 .. dce0/dca0 .. ]
58 * system defines Jf/Jg as block, we handle full network Jf/Jg
59 * which may be sparse if have large network.
61 * assume sparse J by default: for small systems, penalty is small
62 * for large systems, need sparse; overall complexity is reduced.
65 /* compute coupling
67 * solver deals with c following Ic,d, but here we wanted ce
68 * so we need to pack our cn into c.
70 for (l=0; l<d->nnz; l++)
71 c[l] = d->cn[d->Ic[l]];
72 return 0;
/* Parameters of the sk_net_regmap system. */
struct sk_net_regmap_data {
	/* which entry of n is used as the divisor */
	int i;
	/* divisors for the summed coupling input; an entry of 1 gives a
	 * plain sum instead of an average */
	int *n;
};
79 SK_DEFSYS(sk_net_regmap)
81 int l;
82 sk_net_regmap_data *d = data;
83 /* unused arguments */
84 (void) nx;(void) t; (void) Jf; (void) Jg; (void) Jce; (void) hist; (void) i;
85 f[0] = 0.0;
86 g[0] = 0.0;
87 x[0] = 0.0;
88 for (l=0; l<nc; l++)
89 x[0] += c[l];
90 c[0] = x[0] / d->n[d->i];
91 return 0;
94 int sk_net_init1(sk_net_data *net, int n, sk_sys sys, void *data,
95 int ns, int ne, int nnz, int *Or, int *Ic, double *w, double *d)
97 int i, *M, *Ms, *Me;
98 sk_sys *models;
99 void **model_data;
100 M = sk_malloc (sizeof(int) * n);
101 Ms = sk_malloc (sizeof(int));
102 Me = sk_malloc (sizeof(int));
103 models = sk_malloc (sizeof(sk_sys));
104 model_data = sk_malloc (sizeof(void*));
105 Ms[0] = ns;
106 Me[0] = ne;
107 for(i=0; i<n; i++)
108 M[i] = 0;
109 models[0] = sys;
110 model_data[0] = data;
111 sk_net_initn(net, n, 1, M, Ms, Me, models, model_data, nnz, Or, Ic, w, d);
112 net->_init1 = 1;
113 return 0;
116 void sk_net_free(sk_net_data *net)
118 if (net->_init1) {
119 sk_free(net->M);
120 sk_free(net->Ms);
121 sk_free(net->Me);
122 sk_free(net->models);
123 sk_free((void*) net->models_data);
124 sk_free(net->cn);
126 sk_free(net);
130 int sk_net_initn(sk_net_data *net, int n, int m,
131 int *M, int *Ms, int *Me, sk_sys *models, void **models_data,
132 int nnz, int *Or, int *Ic, double *w, double *d)
134 int i;
135 net->n = n;
136 net->m = m;
137 net->nnz = nnz;
138 net->M = M;
139 net->Ms = Ms;
140 net->Me = Me;
141 net->models = models;
142 net->models_data = models_data;
143 net->Or = Or;
144 net->Ic = Ic; /* TODO ? same as Ie ? */
145 net->w = w;
146 net->d = d;
147 /* intialize based on passed attributes: Ie, cne, cna */
148 net->ns = 0;
149 net->ne = 0;
150 for (i=0; i<n; i++) {
151 net->ns += net->Ms[net->M[i]];
152 net->ne += net->Me[net->M[i]];
154 net->cn = sk_malloc (sizeof(double) * net->ne);
155 net->_init1 = 0;
156 return 0;
159 int sk_net_get_n(sk_net_data *net) {
160 return net->n;
163 int sk_net_get_m(sk_net_data *net) {
164 return net->m;
167 int sk_net_get_nnz(sk_net_data *net) {
168 return net->nnz;
171 int *sk_net_get_or(sk_net_data *net) {
172 return net->Or;
175 int sk_net_get_or_i(sk_net_data *net, int i) {
176 return net->Or[i];
179 int *sk_net_get_ic(sk_net_data *net) {
180 return net->Ic;
183 int sk_net_get_ic_i(sk_net_data *net, int i) {
184 return net->Ic[i];
187 double *sk_net_get_w(sk_net_data *net) {
188 return net->w;
191 double sk_net_get_w_i(sk_net_data *net, int i) {
192 return net->w[i];
195 double *sk_net_get_d(sk_net_data *net) {
196 return net->d;
199 double sk_net_get_d_i(sk_net_data *net, int i) {
200 return net->d[i];
203 int sk_net_get_ns(sk_net_data *net) {
204 return net->ns;
207 int sk_net_get_ne(sk_net_data *net) {
208 return net->ne;
211 int sk_net_cn_is_null(sk_net_data *net) {
212 return net->cn == NULL;
215 int sk_net_get_Ms_i(sk_net_data *net, int i) {
216 return net->Ms[i];
219 int sk_net_get_Me_i(sk_net_data *net, int i) {
220 return net->Me[i];
223 int sk_net_get_M_i(sk_net_data *net, int i) {
224 return net->M[i];
227 sk_sys sk_net_get_models_i(sk_net_data *net, int i) {
228 return net->models[i];
231 void *sk_net_get_models_data_i(sk_net_data *net, int i) {
232 return net->models_data[i];
235 int sk_net_get__init1(sk_net_data *net) {
236 return net->_init1;