Fixed compilation issue.
[gromacs/adressmacs.git] / src / gmxlib / statistics / gmx_statistics.c
blob4fd7ae4d5320227f6f09c2e267aa458d20a7003b
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*- */
2 /*
3 *
4 * This source code is part of
5 *
6 * G R O M A C S
7 *
8 * GROningen MAchine for Chemical Simulations
9 *
10 * VERSION 3.2.0
11 * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13 * Copyright (c) 2001-2004, The GROMACS development team,
14 * check out http://www.gromacs.org for more information.
16 * This program is free software; you can redistribute it and/or
17 * modify it under the terms of the GNU General Public License
18 * as published by the Free Software Foundation; either version 2
19 * of the License, or (at your option) any later version.
21 * If you want to redistribute modifications, please consider that
22 * scientific software is very special. Version control is crucial -
23 * bugs must be traceable. We will be happy to consider code for
24 * inclusion in the official distribution, but derived work must not
25 * be called official GROMACS. Details are found in the README & COPYING
26 * files - if they are missing, get the official version at www.gromacs.org.
28 * To help us fund GROMACS development, we humbly ask that you cite
29 * the papers on the package - you can find them in the top README file.
31 * For more info, check our website at http://www.gromacs.org
33 * And Hey:
34 * Green Red Orange Magenta Azure Cyan Skyblue
36 #ifdef HAVE_CONFIG_H
37 #include <config.h>
38 #endif
40 #include <math.h>
41 #include "typedefs.h"
42 #include "smalloc.h"
43 #include "gmx_statistics.h"
/* Return the square of x. */
static double sqr(double x)
{
    const double squared = x*x;

    return squared;
}
/* Round x to the nearest integer, halves away from zero.
   Adding 0.5 and truncating is only correct for x >= 0: for negative x the
   truncation goes toward zero, so e.g. -1.6 would become -1.  Subtract the
   half for negative arguments instead. */
static int gmx_nint(double x)
{
    return (x < 0) ? (int) (x - 0.5) : (int) (x + 0.5);
}
/* Accumulator for a set of (x, y) points with optional errors dx, dy, plus
   the summary statistics derived from them by gmx_stats_compute().
   Field order is significant only for layout; it is unchanged. */
typedef struct gmx_stats {
    double aa;          /* slope of the fit y = aa*x (through the origin)   */
    double a;           /* slope of the fit y = a*x + b                      */
    double b;           /* intercept of the fit y = a*x + b                  */
    double sigma_aa;    /* uncertainty of aa (not set by gmx_stats_compute)  */
    double sigma_a;     /* uncertainty of a                                  */
    double sigma_b;     /* uncertainty of b                                  */
    double aver;        /* mean of the y values                              */
    double sigma_aver;  /* standard deviation of the y values                */
    double error;       /* sigma_aver/sqrt(np)                               */
    double rmsd;        /* root mean square of x - y                         */
    double Rdata;       /* correlation coefficient of the x, y data          */
    double Rfit;        /* correlation measure for the y = a*x + b fit       */
    double Rfitaa;      /* correlation measure for the y = aa*x fit          */
    double chi2;        /* sqrt(chi^2/(np-2)) for y = a*x + b                */
    double chi2aa;      /* sqrt(chi^2/(np-2)) for y = aa*x                   */
    double *x;          /* point data, each array nalloc long                */
    double *y;
    double *dx;
    double *dy;
    int    computed;    /* nonzero while the statistics above are valid      */
    int    np;          /* number of stored points                           */
    int    np_c;        /* scan position used by gmx_stats_get_point         */
    int    nalloc;      /* allocated length of x, y, dx, dy                  */
} gmx_stats;
63 gmx_stats_t gmx_stats_init()
65 gmx_stats *stats;
67 snew(stats,1);
69 return (gmx_stats_t) stats;
72 int gmx_stats_get_npoints(gmx_stats_t gstats, int *N)
74 gmx_stats *stats = (gmx_stats *) gstats;
76 *N = stats->np;
78 return estatsOK;
81 int gmx_stats_done(gmx_stats_t gstats)
83 gmx_stats *stats = (gmx_stats *) gstats;
85 sfree(stats->x);
86 stats->x = NULL;
87 sfree(stats->y);
88 stats->y = NULL;
89 sfree(stats->dx);
90 stats->dx = NULL;
91 sfree(stats->dy);
92 stats->dy = NULL;
94 return estatsOK;
97 int gmx_stats_add_point(gmx_stats_t gstats,double x,double y,
98 double dx,double dy)
100 gmx_stats *stats = (gmx_stats *) gstats;
101 int i;
103 if (stats->np+1 >= stats->nalloc)
105 if (stats->nalloc == 0)
106 stats->nalloc = 1024;
107 else
108 stats->nalloc *= 2;
109 srenew(stats->x,stats->nalloc);
110 srenew(stats->y,stats->nalloc);
111 srenew(stats->dx,stats->nalloc);
112 srenew(stats->dy,stats->nalloc);
113 for(i=stats->np; (i<stats->nalloc); i++)
115 stats->x[i] = 0;
116 stats->y[i] = 0;
117 stats->dx[i] = 0;
118 stats->dy[i] = 0;
121 stats->x[stats->np] = x;
122 stats->y[stats->np] = y;
123 stats->dx[stats->np] = dx;
124 stats->dy[stats->np] = dy;
125 stats->np++;
126 stats->computed = 0;
128 return estatsOK;
131 int gmx_stats_get_point(gmx_stats_t gstats,real *x,real *y,
132 real *dx,real *dy,real level)
134 gmx_stats *stats = (gmx_stats *) gstats;
135 int ok,outlier;
136 real rmsd,r;
138 if ((ok = gmx_stats_get_rmsd(gstats,&rmsd)) != estatsOK)
140 return ok;
142 outlier = 0;
143 while ((outlier == 0) && (stats->np_c < stats->np))
145 r = fabs(stats->x[stats->np_c] - stats->y[stats->np_c]);
146 outlier = (r > rmsd*level);
147 if (outlier)
149 if (NULL != x) *x = stats->x[stats->np_c];
150 if (NULL != y) *y = stats->y[stats->np_c];
151 if (NULL != dx) *dx = stats->dx[stats->np_c];
152 if (NULL != dy) *dy = stats->dy[stats->np_c];
154 stats->np_c++;
156 if (outlier)
157 return estatsOK;
160 stats->np_c = 0;
162 return estatsNO_POINTS;
165 int gmx_stats_add_points(gmx_stats_t gstats,int n,real *x,real *y,
166 real *dx,real *dy)
168 int i,ok;
170 for(i=0; (i<n); i++)
172 if ((ok = gmx_stats_add_point(gstats,x[i],y[i],
173 (NULL != dx) ? dx[i] : 0,
174 (NULL != dy) ? dy[i] : 0)) != estatsOK)
176 return ok;
179 return estatsOK;
182 static int gmx_stats_compute(gmx_stats *stats,int weight)
184 double yy,yx,xx,sx,sy,dy,chi2,chi2aa,d2;
185 double ssxx,ssyy,ssxy;
186 double w,wtot,yx_nw,sy_nw,sx_nw,yy_nw,xx_nw,dx2,dy2;
187 int i,N;
189 N = stats->np;
190 if (stats->computed == 0)
192 if (N < 1)
194 return estatsNO_POINTS;
197 if (weight != elsqWEIGHT_NONE)
199 return estatsNOT_IMPLEMENTED;
202 xx = xx_nw = 0;
203 yy = yy_nw = 0;
204 yx = yx_nw = 0;
205 sx = sx_nw = 0;
206 sy = sy_nw = 0;
207 wtot = 0;
208 d2 = 0;
209 for(i=0; (i<N); i++)
211 d2 += sqr(stats->x[i]-stats->y[i]);
212 if ((stats->dy[i]) && (weight == elsqWEIGHT_Y))
214 w = 1/sqr(stats->dy[i]);
216 else
218 w = 1;
221 wtot += w;
223 xx += w*sqr(stats->x[i]);
224 xx_nw += sqr(stats->x[i]);
226 yy += w*sqr(stats->y[i]);
227 yy_nw += sqr(stats->y[i]);
229 yx += w*stats->y[i]*stats->x[i];
230 yx_nw += stats->y[i]*stats->x[i];
232 sx += w*stats->x[i];
233 sx_nw += stats->x[i];
235 sy += w*stats->y[i];
236 sy_nw += stats->y[i];
239 /* Compute average, sigma and error */
240 stats->aver = sy_nw/N;
241 stats->sigma_aver = sqrt(yy_nw/N - sqr(sy_nw/N));
242 stats->error = stats->sigma_aver/sqrt(N);
244 /* Compute RMSD between x and y */
245 stats->rmsd = sqrt(d2/N);
247 /* Correlation coefficient for data */
248 yx_nw /= N;
249 xx_nw /= N;
250 yy_nw /= N;
251 sx_nw /= N;
252 sy_nw /= N;
253 ssxx = N*(xx_nw - sqr(sx_nw));
254 ssyy = N*(yy_nw - sqr(sy_nw));
255 ssxy = N*(yx_nw - (sx_nw*sy_nw));
256 stats->Rdata = sqrt(sqr(ssxy)/(ssxx*ssyy));
258 /* Compute straight line through datapoints, either with intercept
259 zero (result in aa) or with intercept variable (results in a
260 and b) */
261 yx = yx/wtot;
262 xx = xx/wtot;
263 sx = sx/wtot;
264 sy = sy/wtot;
266 stats->aa = (yx/xx);
267 stats->a = (yx-sx*sy)/(xx-sx*sx);
268 stats->b = (sy)-(stats->a)*(sx);
270 /* Compute chi2, deviation from a line y = ax+b. Also compute
271 chi2aa which returns the deviation from a line y = ax. */
272 chi2 = 0;
273 chi2aa = 0;
274 for(i=0; (i<N); i++)
276 if (stats->dy[i] > 0)
278 dy = stats->dy[i];
280 else
282 dy = 1;
284 chi2aa += sqr((stats->y[i]-(stats->aa*stats->x[i]))/dy);
285 chi2 += sqr((stats->y[i]-(stats->a*stats->x[i]+stats->b))/dy);
287 if (N > 2)
289 stats->chi2 = sqrt(chi2/(N-2));
290 stats->chi2aa = sqrt(chi2aa/(N-2));
292 /* Look up equations! */
293 dx2 = (xx-sx*sx);
294 dy2 = (yy-sy*sy);
295 stats->sigma_a = sqrt(stats->chi2/((N-2)*dx2));
296 stats->sigma_b = stats->sigma_a*sqrt(xx);
297 stats->Rfit = fabs(ssxy)/sqrt(ssxx*ssyy);
298 /*stats->a*sqrt(dx2/dy2);*/
299 stats->Rfitaa = stats->aa*sqrt(dx2/dy2);
301 else
303 stats->chi2 = 0;
304 stats->chi2aa = 0;
305 stats->sigma_a = 0;
306 stats->sigma_b = 0;
307 stats->Rfit = 0;
308 stats->Rfitaa = 0;
311 stats->computed = 1;
314 return estatsOK;
317 int gmx_stats_get_ab(gmx_stats_t gstats,int weight,
318 real *a,real *b,real *da,real *db,
319 real *chi2,real *Rfit)
321 gmx_stats *stats = (gmx_stats *) gstats;
322 int ok;
324 if ((ok = gmx_stats_compute(stats,weight)) != estatsOK)
325 return ok;
326 if (NULL != a)
328 *a = stats->a;
330 if (NULL != b)
332 *b = stats->b;
334 if (NULL != da)
336 *da = stats->sigma_a;
338 if (NULL != db)
340 *db = stats->sigma_b;
342 if (NULL != chi2)
344 *chi2 = stats->chi2;
346 if (NULL != Rfit)
348 *Rfit = stats->Rfit;
351 return estatsOK;
354 int gmx_stats_get_a(gmx_stats_t gstats,int weight,real *a,real *da,
355 real *chi2,real *Rfit)
357 gmx_stats *stats = (gmx_stats *) gstats;
358 int ok;
360 if ((ok = gmx_stats_compute(stats,weight)) != estatsOK)
361 return ok;
362 if (NULL != a) *a = stats->aa;
363 if (NULL != da) *da = stats->sigma_aa;
364 if (NULL != chi2) *chi2 = stats->chi2aa;
365 if (NULL != Rfit) *Rfit = stats->Rfitaa;
367 return estatsOK;
370 int gmx_stats_get_average(gmx_stats_t gstats,real *aver)
372 gmx_stats *stats = (gmx_stats *) gstats;
373 int ok;
375 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
377 return ok;
380 *aver = stats->aver;
382 return estatsOK;
385 int gmx_stats_get_ase(gmx_stats_t gstats,real *aver,real *sigma,real *error)
387 gmx_stats *stats = (gmx_stats *) gstats;
388 int ok;
390 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
392 return ok;
395 if (NULL != aver)
397 *aver = stats->aver;
399 if (NULL != sigma)
401 *sigma = stats->sigma_aver;
403 if (NULL != error)
405 *error = stats->error;
408 return estatsOK;
411 int gmx_stats_get_sigma(gmx_stats_t gstats,real *sigma)
413 gmx_stats *stats = (gmx_stats *) gstats;
414 int ok;
416 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
417 return ok;
419 *sigma = stats->sigma_aver;
421 return estatsOK;
424 int gmx_stats_get_error(gmx_stats_t gstats,real *error)
426 gmx_stats *stats = (gmx_stats *) gstats;
427 int ok;
429 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
430 return ok;
432 *error = stats->error;
434 return estatsOK;
437 int gmx_stats_get_corr_coeff(gmx_stats_t gstats,real *R)
439 gmx_stats *stats = (gmx_stats *) gstats;
440 int ok;
442 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
443 return ok;
445 *R = stats->Rdata;
447 return estatsOK;
450 int gmx_stats_get_rmsd(gmx_stats_t gstats,real *rmsd)
452 gmx_stats *stats = (gmx_stats *) gstats;
453 int ok;
455 if ((ok = gmx_stats_compute(stats,elsqWEIGHT_NONE)) != estatsOK)
457 return ok;
460 *rmsd = stats->rmsd;
462 return estatsOK;
465 int gmx_stats_dump_xy(gmx_stats_t gstats,FILE *fp)
467 gmx_stats *stats = (gmx_stats *) gstats;
468 int i,ok;
470 for(i=0; (i<stats->np); i++)
472 fprintf(fp,"%12g %12g %12g %12g\n",stats->x[i],stats->y[i],
473 stats->dx[i],stats->dy[i]);
476 return estatsOK;
479 int gmx_stats_remove_outliers(gmx_stats_t gstats,double level)
481 gmx_stats *stats = (gmx_stats *) gstats;
482 int i,iter=1,done=0,ok;
483 real rmsd,r;
485 while ((stats->np >= 10) && !done)
487 if ((ok = gmx_stats_get_rmsd(gstats,&rmsd)) != estatsOK)
489 return ok;
491 done = 1;
492 for(i=0; (i<stats->np); )
494 r = fabs(stats->x[i]-stats->y[i]);
495 if (r > level*rmsd)
497 fprintf(stderr,"Removing outlier, iter = %d, rmsd = %g, x = %g, y = %g\n",
498 iter,rmsd,stats->x[i],stats->y[i]);
499 if (i < stats->np-1)
501 stats->x[i] = stats->x[stats->np-1];
502 stats->y[i] = stats->y[stats->np-1];
503 stats->dx[i] = stats->dx[stats->np-1];
504 stats->dy[i] = stats->dy[stats->np-1];
506 stats->np--;
507 done = 0;
509 else
511 i++;
514 iter++;
517 return estatsOK;
520 int gmx_stats_make_histogram(gmx_stats_t gstats,real binwidth,int *nb,
521 int ehisto,int normalized,real **x,real **y)
523 gmx_stats *stats = (gmx_stats *) gstats;
524 int i,ok,index=0,nbins=*nb,*nindex;
525 double minx,maxx,maxy,miny,delta,dd,minh;
527 if (((binwidth <= 0) && (nbins <= 0)) ||
528 ((binwidth > 0) && (nbins > 0)))
530 return estatsINVALID_INPUT;
532 if (stats->np <= 2)
534 return estatsNO_POINTS;
536 minx = maxx = stats->x[0];
537 miny = maxy = stats->y[0];
538 for(i=1; (i<stats->np); i++)
540 miny = (stats->y[i] < miny) ? stats->y[i] : miny;
541 maxy = (stats->y[i] > maxy) ? stats->y[i] : maxy;
542 minx = (stats->x[i] < minx) ? stats->x[i] : minx;
543 maxx = (stats->x[i] > maxx) ? stats->x[i] : maxx;
545 if (ehisto == ehistoX)
547 delta = maxx-minx;
548 minh = minx;
550 else if (ehisto == ehistoY)
552 delta = maxy-miny;
553 minh = miny;
555 else
556 return estatsINVALID_INPUT;
558 if (binwidth == 0)
560 binwidth = (delta)/nbins;
562 else
564 nbins = gmx_nint((delta)/binwidth + 0.5);
566 snew(*x,nbins);
567 snew(nindex,nbins);
568 for(i=0; (i<nbins); i++)
570 (*x)[i] = minh + binwidth*(i+0.5);
572 if (normalized == 0)
574 dd = 1;
576 else
578 dd = 1.0/(binwidth*stats->np);
581 snew(*y,nbins);
582 for(i=0; (i<stats->np); i++)
584 if (ehisto == ehistoY)
585 index = (stats->y[i]-miny)/binwidth;
586 else if (ehisto == ehistoX)
587 index = (stats->x[i]-minx)/binwidth;
588 if (index<0)
590 index = 0;
592 if (index>nbins-1)
594 index = nbins-1;
596 (*y)[index] += dd;
597 nindex[index]++;
599 if (*nb == 0)
600 *nb = nbins;
601 for(i=0; (i<nbins); i++)
602 if (nindex[i] > 0)
603 (*y)[i] /= nindex[i];
605 sfree(nindex);
607 return estatsOK;
610 static const char *stats_error[estatsNR] =
612 "All well in STATS land",
613 "No points",
614 "Not enough memory",
615 "Invalid histogram input",
616 "Unknown error",
617 "Not implemented yet"
620 const char *gmx_stats_message(int estats)
622 if ((estats >= 0) && (estats < estatsNR))
624 return stats_error[estats];
626 else
628 return stats_error[estatsERROR];
632 /* Old convenience functions, should be merged with the core
633 statistics above. */
634 int lsq_y_ax(int n, real x[], real y[], real *a)
636 gmx_stats_t lsq = gmx_stats_init();
637 int ok;
638 real da,chi2,Rfit;
640 gmx_stats_add_points(lsq,n,x,y,0,0);
641 if ((ok = gmx_stats_get_a(lsq,elsqWEIGHT_NONE,a,&da,&chi2,&Rfit)) != estatsOK)
643 return ok;
646 /* int i;
647 double xx,yx;
649 yx=xx=0.0;
650 for (i=0; i<n; i++) {
651 yx+=y[i]*x[i];
652 xx+=x[i]*x[i];
654 *a=yx/xx;
656 return estatsOK;
659 static int low_lsq_y_ax_b(int n, real *xr, double *xd, real yr[],
660 real *a, real *b,real *r,real *chi2)
662 int i,ok;
663 gmx_stats_t lsq;
665 lsq = gmx_stats_init();
666 for(i=0; (i<n); i++)
668 if ((ok = gmx_stats_add_point(lsq,(NULL != xd) ? xd[i] : xr[i],yr[i],0,0))
669 != estatsOK)
671 return ok;
674 if ((ok = gmx_stats_get_ab(lsq,elsqWEIGHT_NONE,a,b,NULL,NULL,chi2,r)) != estatsOK)
676 return ok;
679 return estatsOK;
681 double x,y,yx,xx,yy,sx,sy,chi2;
683 yx=xx=yy=sx=sy=0.0;
684 for (i=0; i<n; i++) {
685 if (xd != NULL) {
686 x = xd[i];
687 } else {
688 x = xr[i];
690 y = yr[i];
692 yx += y*x;
693 xx += x*x;
694 yy += y*y;
695 sx += x;
696 sy += y;
698 *a = (n*yx-sy*sx)/(n*xx-sx*sx);
699 *b = (sy-(*a)*sx)/n;
700 *r = sqrt((xx-sx*sx)/(yy-sy*sy));
702 chi2 = 0;
703 if (xd != NULL) {
704 for(i=0; i<n; i++)
705 chi2 += sqr(yr[i] - ((*a)*xd[i] + (*b)));
706 } else {
707 for(i=0; i<n; i++)
708 chi2 += sqr(yr[i] - ((*a)*xr[i] + (*b)));
711 if (n > 2)
712 return sqrt(chi2/(n-2));
713 else
714 return 0;
718 int lsq_y_ax_b(int n, real x[], real y[], real *a, real *b,real *r,real *chi2)
720 return low_lsq_y_ax_b(n,x,NULL,y,a,b,r,chi2);
723 int lsq_y_ax_b_xdouble(int n, double x[], real y[], real *a, real *b,
724 real *r,real *chi2)
726 return low_lsq_y_ax_b(n,NULL,x,y,a,b,r,chi2);
729 int lsq_y_ax_b_error(int n, real x[], real y[], real dy[],
730 real *a, real *b, real *da, real *db,
731 real *r,real *chi2)
733 gmx_stats_t lsq;
734 int i,ok;
736 lsq = gmx_stats_init();
737 for(i=0; (i<n); i++)
739 if ((ok = gmx_stats_add_point(lsq,x[i],y[i],0,dy[i])) != estatsOK)
741 return ok;
744 if ((ok = gmx_stats_get_ab(lsq,elsqWEIGHT_Y,a,b,da,db,chi2,r)) != estatsOK)
746 return ok;
748 if ((ok = gmx_stats_done(lsq)) != estatsOK)
750 return ok;
752 sfree(lsq);
754 return estatsOK;
756 double sxy,sxx,syy,sx,sy,w,s_2,dx2,dy2,mins;
758 sxy=sxx=syy=sx=sy=w=0.0;
759 mins = dy[0];
760 for(i=1; (i<n); i++)
761 mins = min(mins,dy[i]);
762 if (mins <= 0)
763 gmx_fatal(FARGS,"Zero or negative weigths in linear regression analysis");
765 for (i=0; i<n; i++) {
766 s_2 = sqr(1.0/dy[i]);
767 sxx += s_2*sqr(x[i]);
768 sxy += s_2*y[i]*x[i];
769 syy += s_2*sqr(y[i]);
770 sx += s_2*x[i];
771 sy += s_2*y[i];
772 w += s_2;
774 sxx = sxx/w;
775 sxy = sxy/w;
776 syy = syy/w;
777 sx = sx/w;
778 sy = sy/w;
779 dx2 = (sxx-sx*sx);
780 dy2 = (syy-sy*sy);
781 *a=(sxy-sy*sx)/dx2;
782 *b=(sy-(*a)*sx);
784 *chi2=0;
785 for(i=0; i<n; i++)
786 *chi2+=sqr((y[i]-((*a)*x[i]+(*b)))/dy[i]);
787 *chi2 = *chi2/w;
789 *da = sqrt(*chi2/((n-2)*dx2));
790 *db = *da*sqrt(sxx);
791 *r = *a*sqrt(dx2/dy2);
793 if (debug)
794 fprintf(debug,"sx = %g, sy = %g, sxy = %g, sxx = %g, w = %g\n"
795 "chi2 = %g, dx2 = %g\n",
796 sx,sy,sxy,sxx,w,*chi2,dx2);
798 if (n > 2)
799 *chi2 = sqrt(*chi2/(n-2));
800 else
801 *chi2 = 0;