Code cleanup to avoid name collisions
[gromacs.git] / src / gmxlib / statistics / gmx_statistics.c
blobe57e4fe94296e10dbf6ccf68ec74f03bb222af3d
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*- */
2 /*
3 *
4 * This source code is part of
5 *
6 * G R O M A C S
7 *
8 * GROningen MAchine for Chemical Simulations
9 *
10 * VERSION 3.2.0
11 * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13 * Copyright (c) 2001-2004, The GROMACS development team,
14 * check out http://www.gromacs.org for more information.
16 * This program is free software; you can redistribute it and/or
17 * modify it under the terms of the GNU General Public License
18 * as published by the Free Software Foundation; either version 2
19 * of the License, or (at your option) any later version.
21 * If you want to redistribute modifications, please consider that
22 * scientific software is very special. Version control is crucial -
23 * bugs must be traceable. We will be happy to consider code for
24 * inclusion in the official distribution, but derived work must not
25 * be called official GROMACS. Details are found in the README & COPYING
26 * files - if they are missing, get the official version at www.gromacs.org.
28 * To help us fund GROMACS development, we humbly ask that you cite
29 * the papers on the package - you can find them in the top README file.
31 * For more info, check our website at http://www.gromacs.org
33 * And Hey:
34 * Green Red Orange Magenta Azure Cyan Skyblue
36 #ifdef HAVE_CONFIG_H
37 #include <config.h>
38 #endif
40 #include <math.h>
41 #include "typedefs.h"
42 #include "smalloc.h"
43 #include "vec.h"
44 #include "gmx_statistics.h"
/* Round x to the nearest integer, halfway cases away from zero.
 *
 * Adding 0.5 and truncating is only correct for x >= 0, because the
 * conversion to int truncates toward zero; negative values are therefore
 * shifted by -0.5 instead.
 */
static int gmx_dnint(double x)
{
    return (x < 0) ? (int) (x - 0.5) : (int) (x + 0.5);
}
/* Storage for a set of (x,y,dx,dy) data points together with the results
 * that gmx_stats_compute derives from them.
 */
typedef struct gmx_stats {
    /* aa: slope of the fit y = aa*x; a,b: slope and intercept of the fit
     * y = a*x + b; sigma_*: uncertainties of those parameters;
     * aver, sigma_aver, error: mean of y, its standard deviation and the
     * standard error of the mean.
     */
    double aa,a,b,sigma_aa,sigma_a,sigma_b,aver,sigma_aver,error;
    /* rmsd: root mean square difference between x and y; Rdata:
     * correlation coefficient of the data; Rfit/Rfitaa and chi2/chi2aa:
     * quality measures for the two straight-line fits.
     */
    double rmsd,Rdata,Rfit,Rfitaa,chi2,chi2aa;
    /* The data points; all four arrays hold nalloc elements. */
    double *x,*y,*dx,*dy;
    /* Non-zero when the cached results above match the stored points. */
    int    computed;
    /* np: number of stored points; np_c: read cursor used by
     * gmx_stats_get_point; nalloc: allocated length of the arrays.
     */
    int    np,np_c,nalloc;
} gmx_stats;
59 gmx_stats_t gmx_stats_init()
61 gmx_stats *stats;
63 snew(stats,1);
65 return (gmx_stats_t) stats;
68 int gmx_stats_get_npoints(gmx_stats_t gstats, int *N)
70 gmx_stats *stats = (gmx_stats *) gstats;
72 *N = stats->np;
74 return estatsOK;
77 int gmx_stats_done(gmx_stats_t gstats)
79 gmx_stats *stats = (gmx_stats *) gstats;
81 sfree(stats->x);
82 stats->x = NULL;
83 sfree(stats->y);
84 stats->y = NULL;
85 sfree(stats->dx);
86 stats->dx = NULL;
87 sfree(stats->dy);
88 stats->dy = NULL;
90 return estatsOK;
93 int gmx_stats_add_point(gmx_stats_t gstats,double x,double y,
94 double dx,double dy)
96 gmx_stats *stats = (gmx_stats *) gstats;
97 int i;
99 if (stats->np+1 >= stats->nalloc)
101 if (stats->nalloc == 0)
102 stats->nalloc = 1024;
103 else
104 stats->nalloc *= 2;
105 srenew(stats->x,stats->nalloc);
106 srenew(stats->y,stats->nalloc);
107 srenew(stats->dx,stats->nalloc);
108 srenew(stats->dy,stats->nalloc);
109 for(i=stats->np; (i<stats->nalloc); i++)
111 stats->x[i] = 0;
112 stats->y[i] = 0;
113 stats->dx[i] = 0;
114 stats->dy[i] = 0;
117 stats->x[stats->np] = x;
118 stats->y[stats->np] = y;
119 stats->dx[stats->np] = dx;
120 stats->dy[stats->np] = dy;
121 stats->np++;
122 stats->computed = 0;
124 return estatsOK;
127 int gmx_stats_get_point(gmx_stats_t gstats,real *x,real *y,
128 real *dx,real *dy)
130 gmx_stats *stats = (gmx_stats *) gstats;
132 if (stats->np_c < stats->np)
134 if (NULL != x) *x = stats->x[stats->np_c];
135 if (NULL != y) *y = stats->y[stats->np_c];
136 if (NULL != dx) *dx = stats->dx[stats->np_c];
137 if (NULL != dy) *dy = stats->dy[stats->np_c];
138 stats->np_c++;
140 return estatsOK;
142 stats->np_c = 0;
144 return estatsNO_POINTS;
147 int gmx_stats_add_points(gmx_stats_t gstats,int n,real *x,real *y,
148 real *dx,real *dy)
150 int i,ok;
152 for(i=0; (i<n); i++)
154 if ((ok = gmx_stats_add_point(gstats,x[i],y[i],
155 (NULL != dx) ? dx[i] : 0,
156 (NULL != dy) ? dy[i] : 0)) != estatsOK)
158 return ok;
161 return estatsOK;
164 static int gmx_stats_compute(gmx_stats *stats,int weight)
166 double yy,yx,xx,sx,sy,dy,chi2,chi2aa,d2;
167 double ssxx,ssyy,ssxy;
168 double w,wtot,yx_nw,sy_nw,sx_nw,yy_nw,xx_nw,dx2,dy2;
169 int i,N;
171 N = stats->np;
172 if (stats->computed == 0)
174 if (N < 1)
176 return estatsNO_POINTS;
179 xx = xx_nw = 0;
180 yy = yy_nw = 0;
181 yx = yx_nw = 0;
182 sx = sx_nw = 0;
183 sy = sy_nw = 0;
184 wtot = 0;
185 d2 = 0;
186 for(i=0; (i<N); i++)
188 d2 += dsqr(stats->x[i]-stats->y[i]);
189 if ((stats->dy[i]) && (weight == elsqWEIGHT_Y))
191 w = 1/dsqr(stats->dy[i]);
193 else
195 w = 1;
198 wtot += w;
200 xx += w*dsqr(stats->x[i]);
201 xx_nw += dsqr(stats->x[i]);
203 yy += w*dsqr(stats->y[i]);
204 yy_nw += dsqr(stats->y[i]);
206 yx += w*stats->y[i]*stats->x[i];
207 yx_nw += stats->y[i]*stats->x[i];
209 sx += w*stats->x[i];
210 sx_nw += stats->x[i];
212 sy += w*stats->y[i];
213 sy_nw += stats->y[i];
216 /* Compute average, sigma and error */
217 stats->aver = sy_nw/N;
218 stats->sigma_aver = sqrt(yy_nw/N - dsqr(sy_nw/N));
219 stats->error = stats->sigma_aver/sqrt(N);
221 /* Compute RMSD between x and y */
222 stats->rmsd = sqrt(d2/N);
224 /* Correlation coefficient for data */
225 yx_nw /= N;
226 xx_nw /= N;
227 yy_nw /= N;
228 sx_nw /= N;
229 sy_nw /= N;
230 ssxx = N*(xx_nw - dsqr(sx_nw));
231 ssyy = N*(yy_nw - dsqr(sy_nw));
232 ssxy = N*(yx_nw - (sx_nw*sy_nw));
233 stats->Rdata = sqrt(dsqr(ssxy)/(ssxx*ssyy));
235 /* Compute straight line through datapoints, either with intercept
236 zero (result in aa) or with intercept variable (results in a
237 and b) */
238 yx = yx/wtot;
239 xx = xx/wtot;
240 sx = sx/wtot;
241 sy = sy/wtot;
243 stats->aa = (yx/xx);
244 stats->a = (yx-sx*sy)/(xx-sx*sx);
245 stats->b = (sy)-(stats->a)*(sx);
247 /* Compute chi2, deviation from a line y = ax+b. Also compute
248 chi2aa which returns the deviation from a line y = ax. */
249 chi2 = 0;
250 chi2aa = 0;
251 for(i=0; (i<N); i++)
253 if (stats->dy[i] > 0)
255 dy = stats->dy[i];
257 else
259 dy = 1;
261 chi2aa += dsqr((stats->y[i]-(stats->aa*stats->x[i]))/dy);
262 chi2 += dsqr((stats->y[i]-(stats->a*stats->x[i]+stats->b))/dy);
264 if (N > 2)
266 stats->chi2 = sqrt(chi2/(N-2));
267 stats->chi2aa = sqrt(chi2aa/(N-2));
269 /* Look up equations! */
270 dx2 = (xx-sx*sx);
271 dy2 = (yy-sy*sy);
272 stats->sigma_a = sqrt(stats->chi2/((N-2)*dx2));
273 stats->sigma_b = stats->sigma_a*sqrt(xx);
274 stats->Rfit = fabs(ssxy)/sqrt(ssxx*ssyy);
275 /*stats->a*sqrt(dx2/dy2);*/
276 stats->Rfitaa = stats->aa*sqrt(dx2/dy2);
278 else
280 stats->chi2 = 0;
281 stats->chi2aa = 0;
282 stats->sigma_a = 0;
283 stats->sigma_b = 0;
284 stats->Rfit = 0;
285 stats->Rfitaa = 0;
288 stats->computed = 1;
291 return estatsOK;
294 int gmx_stats_get_ab(gmx_stats_t gstats,int weight,
295 real *a,real *b,real *da,real *db,
296 real *chi2,real *Rfit)
298 gmx_stats *stats = (gmx_stats *) gstats;
299 int ok;
301 if ((ok = gmx_stats_compute(stats,weight)) != estatsOK)
302 return ok;
303 if (NULL != a)
305 *a = stats->a;
307 if (NULL != b)
309 *b = stats->b;
311 if (NULL != da)
313 *da = stats->sigma_a;
315 if (NULL != db)
317 *db = stats->sigma_b;
319 if (NULL != chi2)
321 *chi2 = stats->chi2;
323 if (NULL != Rfit)
325 *Rfit = stats->Rfit;
328 return estatsOK;
331 int gmx_stats_get_a(gmx_stats_t gstats,int weight,real *a,real *da,
332 real *chi2,real *Rfit)
334 gmx_stats *stats = (gmx_stats *) gstats;
335 int ok;
337 if ((ok = gmx_stats_compute(stats,weight)) != estatsOK)
338 return ok;
339 if (NULL != a) *a = stats->aa;
340 if (NULL != da) *da = stats->sigma_aa;
341 if (NULL != chi2) *chi2 = stats->chi2aa;
342 if (NULL != Rfit) *Rfit = stats->Rfitaa;
344 return estatsOK;
347 int gmx_stats_get_average(gmx_stats_t gstats,real *aver)
349 gmx_stats *stats = (gmx_stats *) gstats;
350 int ok;
352 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
354 return ok;
357 *aver = stats->aver;
359 return estatsOK;
362 int gmx_stats_get_ase(gmx_stats_t gstats,real *aver,real *sigma,real *error)
364 gmx_stats *stats = (gmx_stats *) gstats;
365 int ok;
367 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
369 return ok;
372 if (NULL != aver)
374 *aver = stats->aver;
376 if (NULL != sigma)
378 *sigma = stats->sigma_aver;
380 if (NULL != error)
382 *error = stats->error;
385 return estatsOK;
388 int gmx_stats_get_sigma(gmx_stats_t gstats,real *sigma)
390 gmx_stats *stats = (gmx_stats *) gstats;
391 int ok;
393 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
394 return ok;
396 *sigma = stats->sigma_aver;
398 return estatsOK;
401 int gmx_stats_get_error(gmx_stats_t gstats,real *error)
403 gmx_stats *stats = (gmx_stats *) gstats;
404 int ok;
406 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
407 return ok;
409 *error = stats->error;
411 return estatsOK;
414 int gmx_stats_get_corr_coeff(gmx_stats_t gstats,real *R)
416 gmx_stats *stats = (gmx_stats *) gstats;
417 int ok;
419 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
420 return ok;
422 *R = stats->Rdata;
424 return estatsOK;
427 int gmx_stats_get_rmsd(gmx_stats_t gstats,real *rmsd)
429 gmx_stats *stats = (gmx_stats *) gstats;
430 int ok;
432 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
434 return ok;
437 *rmsd = stats->rmsd;
439 return estatsOK;
442 int gmx_stats_dump_xy(gmx_stats_t gstats,FILE *fp)
444 gmx_stats *stats = (gmx_stats *) gstats;
445 int i,ok;
447 for(i=0; (i<stats->np); i++)
449 fprintf(fp,"%12g %12g %12g %12g\n",stats->x[i],stats->y[i],
450 stats->dx[i],stats->dy[i]);
453 return estatsOK;
456 int gmx_stats_remove_outliers(gmx_stats_t gstats,double level)
458 gmx_stats *stats = (gmx_stats *) gstats;
459 int i,iter=1,done=0,ok;
460 real rmsd,r;
462 while ((stats->np >= 10) && !done)
464 if ((ok = gmx_stats_get_rmsd(gstats,&rmsd)) != estatsOK)
466 return ok;
468 done = 1;
469 for(i=0; (i<stats->np); )
471 r = fabs(stats->x[i]-stats->y[i]);
472 if (r > level*rmsd)
474 fprintf(stderr,"Removing outlier, iter = %d, rmsd = %g, x = %g, y = %g\n",
475 iter,rmsd,stats->x[i],stats->y[i]);
476 if (i < stats->np-1)
478 stats->x[i] = stats->x[stats->np-1];
479 stats->y[i] = stats->y[stats->np-1];
480 stats->dx[i] = stats->dx[stats->np-1];
481 stats->dy[i] = stats->dy[stats->np-1];
483 stats->np--;
484 done = 0;
486 else
488 i++;
491 iter++;
494 return estatsOK;
497 int gmx_stats_make_histogram(gmx_stats_t gstats,real binwidth,int *nb,
498 int ehisto,int normalized,real **x,real **y)
500 gmx_stats *stats = (gmx_stats *) gstats;
501 int i,ok,index=0,nbins=*nb,*nindex;
502 double minx,maxx,maxy,miny,delta,dd,minh;
504 if (((binwidth <= 0) && (nbins <= 0)) ||
505 ((binwidth > 0) && (nbins > 0)))
507 return estatsINVALID_INPUT;
509 if (stats->np <= 2)
511 return estatsNO_POINTS;
513 minx = maxx = stats->x[0];
514 miny = maxy = stats->y[0];
515 for(i=1; (i<stats->np); i++)
517 miny = (stats->y[i] < miny) ? stats->y[i] : miny;
518 maxy = (stats->y[i] > maxy) ? stats->y[i] : maxy;
519 minx = (stats->x[i] < minx) ? stats->x[i] : minx;
520 maxx = (stats->x[i] > maxx) ? stats->x[i] : maxx;
522 if (ehisto == ehistoX)
524 delta = maxx-minx;
525 minh = minx;
527 else if (ehisto == ehistoY)
529 delta = maxy-miny;
530 minh = miny;
532 else
533 return estatsINVALID_INPUT;
535 if (binwidth == 0)
537 binwidth = (delta)/nbins;
539 else
541 nbins = gmx_dnint((delta)/binwidth + 0.5);
543 snew(*x,nbins);
544 snew(nindex,nbins);
545 for(i=0; (i<nbins); i++)
547 (*x)[i] = minh + binwidth*(i+0.5);
549 if (normalized == 0)
551 dd = 1;
553 else
555 dd = 1.0/(binwidth*stats->np);
558 snew(*y,nbins);
559 for(i=0; (i<stats->np); i++)
561 if (ehisto == ehistoY)
562 index = (stats->y[i]-miny)/binwidth;
563 else if (ehisto == ehistoX)
564 index = (stats->x[i]-minx)/binwidth;
565 if (index<0)
567 index = 0;
569 if (index>nbins-1)
571 index = nbins-1;
573 (*y)[index] += dd;
574 nindex[index]++;
576 if (*nb == 0)
577 *nb = nbins;
578 for(i=0; (i<nbins); i++)
579 if (nindex[i] > 0)
580 (*y)[i] /= nindex[i];
582 sfree(nindex);
584 return estatsOK;
587 static const char *stats_error[estatsNR] =
589 "All well in STATS land",
590 "No points",
591 "Not enough memory",
592 "Invalid histogram input",
593 "Unknown error",
594 "Not implemented yet"
597 const char *gmx_stats_message(int estats)
599 if ((estats >= 0) && (estats < estatsNR))
601 return stats_error[estats];
603 else
605 return stats_error[estatsERROR];
609 /* Old convenience functions, should be merged with the core
610 statistics above. */
611 int lsq_y_ax(int n, real x[], real y[], real *a)
613 gmx_stats_t lsq = gmx_stats_init();
614 int ok;
615 real da,chi2,Rfit;
617 gmx_stats_add_points(lsq,n,x,y,0,0);
618 if ((ok = gmx_stats_get_a(lsq,elsqWEIGHT_NONE,a,&da,&chi2,&Rfit)) != estatsOK)
620 return ok;
623 /* int i;
624 double xx,yx;
626 yx=xx=0.0;
627 for (i=0; i<n; i++) {
628 yx+=y[i]*x[i];
629 xx+=x[i]*x[i];
631 *a=yx/xx;
633 return estatsOK;
636 static int low_lsq_y_ax_b(int n, real *xr, double *xd, real yr[],
637 real *a, real *b,real *r,real *chi2)
639 int i,ok;
640 gmx_stats_t lsq;
642 lsq = gmx_stats_init();
643 for(i=0; (i<n); i++)
645 if ((ok = gmx_stats_add_point(lsq,(NULL != xd) ? xd[i] : xr[i],yr[i],0,0))
646 != estatsOK)
648 return ok;
651 if ((ok = gmx_stats_get_ab(lsq,elsqWEIGHT_NONE,a,b,NULL,NULL,chi2,r)) != estatsOK)
653 return ok;
656 return estatsOK;
658 double x,y,yx,xx,yy,sx,sy,chi2;
660 yx=xx=yy=sx=sy=0.0;
661 for (i=0; i<n; i++) {
662 if (xd != NULL) {
663 x = xd[i];
664 } else {
665 x = xr[i];
667 y = yr[i];
669 yx += y*x;
670 xx += x*x;
671 yy += y*y;
672 sx += x;
673 sy += y;
675 *a = (n*yx-sy*sx)/(n*xx-sx*sx);
676 *b = (sy-(*a)*sx)/n;
677 *r = sqrt((xx-sx*sx)/(yy-sy*sy));
679 chi2 = 0;
680 if (xd != NULL) {
681 for(i=0; i<n; i++)
682 chi2 += dsqr(yr[i] - ((*a)*xd[i] + (*b)));
683 } else {
684 for(i=0; i<n; i++)
685 chi2 += dsqr(yr[i] - ((*a)*xr[i] + (*b)));
688 if (n > 2)
689 return sqrt(chi2/(n-2));
690 else
691 return 0;
695 int lsq_y_ax_b(int n, real x[], real y[], real *a, real *b,real *r,real *chi2)
697 return low_lsq_y_ax_b(n,x,NULL,y,a,b,r,chi2);
700 int lsq_y_ax_b_xdouble(int n, double x[], real y[], real *a, real *b,
701 real *r,real *chi2)
703 return low_lsq_y_ax_b(n,NULL,x,y,a,b,r,chi2);
706 int lsq_y_ax_b_error(int n, real x[], real y[], real dy[],
707 real *a, real *b, real *da, real *db,
708 real *r,real *chi2)
710 gmx_stats_t lsq;
711 int i,ok;
713 lsq = gmx_stats_init();
714 for(i=0; (i<n); i++)
716 if ((ok = gmx_stats_add_point(lsq,x[i],y[i],0,dy[i])) != estatsOK)
718 return ok;
721 if ((ok = gmx_stats_get_ab(lsq,elsqWEIGHT_Y,a,b,da,db,chi2,r)) != estatsOK)
723 return ok;
725 if ((ok = gmx_stats_done(lsq)) != estatsOK)
727 return ok;
729 sfree(lsq);
731 return estatsOK;
733 double sxy,sxx,syy,sx,sy,w,s_2,dx2,dy2,mins;
735 sxy=sxx=syy=sx=sy=w=0.0;
736 mins = dy[0];
737 for(i=1; (i<n); i++)
738 mins = min(mins,dy[i]);
739 if (mins <= 0)
740 gmx_fatal(FARGS,"Zero or negative weigths in linear regression analysis");
742 for (i=0; i<n; i++) {
743 s_2 = dsqr(1.0/dy[i]);
744 sxx += s_2*dsqr(x[i]);
745 sxy += s_2*y[i]*x[i];
746 syy += s_2*dsqr(y[i]);
747 sx += s_2*x[i];
748 sy += s_2*y[i];
749 w += s_2;
751 sxx = sxx/w;
752 sxy = sxy/w;
753 syy = syy/w;
754 sx = sx/w;
755 sy = sy/w;
756 dx2 = (sxx-sx*sx);
757 dy2 = (syy-sy*sy);
758 *a=(sxy-sy*sx)/dx2;
759 *b=(sy-(*a)*sx);
761 *chi2=0;
762 for(i=0; i<n; i++)
763 *chi2+=dsqr((y[i]-((*a)*x[i]+(*b)))/dy[i]);
764 *chi2 = *chi2/w;
766 *da = sqrt(*chi2/((n-2)*dx2));
767 *db = *da*sqrt(sxx);
768 *r = *a*sqrt(dx2/dy2);
770 if (debug)
771 fprintf(debug,"sx = %g, sy = %g, sxy = %g, sxx = %g, w = %g\n"
772 "chi2 = %g, dx2 = %g\n",
773 sx,sy,sxy,sxx,w,*chi2,dx2);
775 if (n > 2)
776 *chi2 = sqrt(*chi2/(n-2));
777 else
778 *chi2 = 0;