Merge pull request #113 from gitter-badger/gitter-badge
[sddekit.git] / src / sys_net.c
blob 039b602b78d36557c99afa6776f89853f9b93cdb
1 /* copyright 2016 Apache 2 sddekit authors */
3 #include "sddekit.h"
5 typedef struct netd
7 sd_sys sys_if;
8 sd_net net_if;
9 uint32_t n, m, nnz, * restrict M, * restrict Ms,
10 * restrict Ma, * restrict Me, ns, ne,
11 * restrict Or, * restrict Ic;
12 double * restrict w, * restrict d, * restrict cn;
13 sd_sys **models;
14 /* flag for init1 use */
15 bool _init1;
16 } netd;
18 uint32_t get_n (sd_net *net) { return ((netd*)net->ptr)->n; }
19 uint32_t get_m (sd_net *net) { return ((netd*)net->ptr)->m; }
20 uint32_t get_nnz (sd_net *net) { return ((netd*)net->ptr)->nnz; }
21 uint32_t *get_or (sd_net *net) { return ((netd*)net->ptr)->Or; }
22 uint32_t get_or_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->Or[i]; }
23 uint32_t *get_ic (sd_net *net) { return ((netd*)net->ptr)->Ic; }
24 uint32_t get_ic_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->Ic[i]; }
25 double * get_w (sd_net *net) { return ((netd*)net->ptr)->w; }
26 double get_w_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->w[i]; }
27 double * get_d (sd_net *net) { return ((netd*)net->ptr)->d; }
28 double get_d_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->d[i]; }
29 uint32_t get_ns (sd_net *net) { return ((netd*)net->ptr)->ns; }
30 uint32_t get_ne (sd_net *net) { return ((netd*)net->ptr)->ne; }
31 bool cn_is_null (sd_net *net) { return ((netd*)net->ptr)->cn == NULL; }
32 uint32_t get_Ms_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->Ms[i]; }
33 uint32_t get_Ma_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->Ma[i]; }
34 uint32_t get_Me_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->Me[i]; }
35 uint32_t get_M_i (sd_net *net, uint32_t i) { return ((netd*)net->ptr)->M[i]; }
36 sd_sys * get_models_i(sd_net *net, uint32_t i) { return ((netd*)net->ptr)->models[i]; }
37 bool get__init1 (sd_net *net) { return ((netd*)net->ptr)->_init1; }
39 static sd_stat apply(sd_sys *sys, sd_sys_in *in, sd_sys_out *out)
41 sd_stat stat;
42 uint32_t l, j;
43 netd *d = sys->ptr;
44 sd_sys **sysi = d->models;
45 sd_sys_in in_l = *in;
46 sd_sys_out out_l = *out;
47 /* compute (sparse) inputs */
48 for (l=0; l<d->ne; l++)
49 for (d->cn[l]=0.0, j=d->Or[l]; j<d->Or[l+1]; j++)
50 d->cn[l] += in->i[in->hist->get_vi2i(in->hist, d->Ic[j])] * d->w[j];
51 /* TODO redo, this is not restrict */
52 in_l.i = out_l.o = d->cn;
53 for (l = 0; l < d->n; l++)
55 uint32_t ml = d->M[l];
56 if ((stat = (*sysi)->apply(*sysi, &in_l, &out_l)) != SD_OK)
57 return stat;
58 /* TODO double check */
59 in_l.id += 1;
60 in_l.i += d->Ma[ml];
61 in_l.x += d->Ms[ml];
62 out_l.f += d->Ms[ml];
63 out_l.g += d->Ms[ml];
64 out_l.o += d->Me[ml];
66 /* compute outputs */
67 for (l=0; l<d->nnz; l++)
68 out->o[l] = d->cn[d->Ic[l]];
69 return SD_OK;
72 sd_net *sd_net_new_hom(uint32_t n, sd_sys *sys,
73 uint32_t ns, uint32_t na, uint32_t ne,
74 uint32_t nnz,
75 uint32_t * restrict Or,
76 uint32_t * restrict Ic,
77 double * restrict w,
78 double * restrict d)
80 uint32_t i, *M, *Ms, *Me, *Ma;
81 sd_sys **models;
82 char *errmsg;
83 sd_net *net;
84 M = Ms = Me = Ma = NULL;
85 models = NULL;
86 if ((M = sd_malloc (sizeof(uint32_t) * n))==NULL
87 || (Ms = sd_malloc (sizeof(uint32_t)))==NULL
88 || (Ma = sd_malloc (sizeof(uint32_t)))==NULL
89 || (Me = sd_malloc (sizeof(uint32_t)))==NULL
90 || (models = sd_malloc (sizeof(sd_sys*)))==NULL)
92 errmsg = "failed to allocate net init1 storage.";
93 goto fail;
95 Ms[0] = ns;
96 Ma[0] = na;
97 Me[0] = ne;
98 for(i=0; i<n; i++)
99 M[i] = 0;
100 models[0] = sys;
101 if ((net = sd_net_new_het(n, 1, M, Ms, Ma, Me, models,
102 nnz, Or, Ic, w, d)) == NULL)
104 errmsg = "net initn failed.";
105 goto fail;
107 ((netd*)net->ptr)->_init1 = true;
108 return net;
109 fail:
110 if (M!=NULL) sd_free(M);
111 if (Ms!=NULL) sd_free(Ms);
112 if (Me!=NULL) sd_free(Me);
113 if (Ma!=NULL) sd_free(Ma);
114 if (models!=NULL) sd_free(models);
115 sd_err(errmsg);
116 return NULL;
119 static void free_ptr(netd *d)
121 if (d->_init1) {
122 sd_free(d->M);
123 sd_free(d->Ms);
124 sd_free(d->Me);
125 sd_free(d->Ma);
126 sd_free(d->models);
127 sd_free(d->cn);
129 sd_free(d);
132 static void free_net(sd_net *net) { free_ptr(net->ptr); }
133 static void free_sys(sd_sys *sys) { free_ptr(sys->ptr); }
135 static uint32_t sys_ndim(sd_sys*sys) { return ((netd*)sys->ptr)->ns; }
136 static uint32_t sys_ndc(sd_sys*sys) { return ((netd*)sys->ptr)->nnz; }
137 static uint32_t sys_nobs(sd_sys*sys) { return ((netd*)sys->ptr)->ne; }
138 static uint32_t sys_nrpar(sd_sys*sys){ (void) sys; return 0; }
139 static uint32_t sys_nipar(sd_sys*sys){ (void) sys; return 0; }
141 static sd_sys net_sys_defaults = {
142 .ndim = &sys_ndim,
143 .ndc = &sys_ndc,
144 .nobs = &sys_nobs,
145 .nrpar = &sys_nrpar,
146 .nipar = &sys_nipar,
147 .apply = &apply,
148 .free = &free_sys,
149 .ptr = NULL
152 static sd_sys *net_to_sys(sd_net *net) { return &(((netd*)net->ptr)->sys_if); }
154 static sd_net net_defaults = {
155 .ptr = NULL,
156 .sys = &net_to_sys,
157 .get_n = &get_n,
158 .get_m = &get_m,
159 .get_nnz = &get_nnz,
160 .get_or = &get_or,
161 .get_or_i = &get_or_i,
162 .get_ic = &get_ic,
163 .get_ic_i = &get_ic_i,
164 .get_w = &get_w,
165 .get_w_i = &get_w_i,
166 .get_d = &get_d,
167 .get_d_i = &get_d_i,
168 .get_ns = &get_ns,
169 .get_ne = &get_ne,
170 .cn_is_null = &cn_is_null,
171 .get_Ms_i = &get_Ms_i,
172 .get_Ma_i = &get_Ma_i,
173 .get_Me_i = &get_Me_i,
174 .get_M_i = &get_M_i,
175 .get_models_i = &get_models_i,
176 .get__init1 = &get__init1,
177 .free = &free_net
180 sd_net *
181 sd_net_new_het(uint32_t n, uint32_t m,
182 uint32_t * restrict M, uint32_t * restrict Ms,
183 uint32_t * restrict Ma, uint32_t * restrict Me,
184 sd_sys **models,
185 uint32_t nnz,
186 uint32_t * restrict Or,
187 uint32_t * restrict Ic,
188 double * restrict w,
189 double * restrict d)
191 uint32_t i;
192 netd *net = sd_malloc(sizeof(netd));
193 if (net == NULL)
195 sd_err("net alloc failed.");
196 return NULL;
198 net->net_if = net_defaults;
199 net->sys_if = net_sys_defaults;
200 net->net_if.ptr = net->sys_if.ptr = net;
201 net->n = n;
202 net->m = m;
203 net->nnz = nnz;
204 net->M = M;
205 net->Ms = Ms;
206 net->Ma = Ma;
207 net->Me = Me;
208 net->models = models;
209 net->Or = Or;
210 net->Ic = Ic; /* TODO ? same as Ie ? */
211 net->w = w;
212 net->d = d;
213 /* intialize based on passed attributes: Ie, cne, cna */
214 net->ns = 0;
215 net->ne = 0;
216 for (i=0; i<n; i++) {
217 net->ns += net->Ms[net->M[i]];
218 net->ne += net->Me[net->M[i]];
220 if ((net->cn = sd_malloc (sizeof(double) * net->ne))==NULL)
222 sd_err("failed to allocate memory for network.");
223 return NULL;
225 net->_init1 = 0;
226 return &(net->net_if);