Remove all unnecessary HAVE_CONFIG_H
[gromacs.git] / src / gromacs / statistics / statistics.c
blob797176b41eadb08332d2c0bb34f113af6b8d138d
1 /*
2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5 * Copyright (c) 2001-2004, The GROMACS development team.
6 * Copyright (c) 2012,2014, by the GROMACS development team, led by
7 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8 * and including many others, as listed in the AUTHORS file in the
9 * top-level source directory and at http://www.gromacs.org.
11 * GROMACS is free software; you can redistribute it and/or
12 * modify it under the terms of the GNU Lesser General Public License
13 * as published by the Free Software Foundation; either version 2.1
14 * of the License, or (at your option) any later version.
16 * GROMACS is distributed in the hope that it will be useful,
17 * but WITHOUT ANY WARRANTY; without even the implied warranty of
18 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19 * Lesser General Public License for more details.
21 * You should have received a copy of the GNU Lesser General Public
22 * License along with GROMACS; if not, see
23 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
26 * If you want to redistribute modifications to GROMACS, please
27 * consider that scientific software is very special. Version
28 * control is crucial - bugs must be traceable. We will be happy to
29 * consider code for inclusion in the official distribution, but
30 * derived work must not be called official GROMACS. Details are found
31 * in the README & COPYING files - if they are missing, get the
32 * official version at http://www.gromacs.org.
34 * To help us fund GROMACS development, we humbly ask that you cite
35 * the research papers on the package. Check out http://www.gromacs.org.
37 #include "statistics.h"
39 #include "config.h"
40 #include <math.h>
42 #include "gromacs/math/vec.h"
43 #include "gromacs/utility/real.h"
44 #include "gromacs/utility/smalloc.h"
/* Round a double to the nearest integer, with halfway cases rounded
 * away from zero.
 *
 * Adding 0.5 and truncating is only correct for non-negative input
 * because the conversion to int truncates toward zero; negative input
 * therefore needs 0.5 subtracted instead.
 *
 * x       value to round; must be within the range of int after rounding
 * returns the nearest integer to x
 */
static int gmx_dnint(double x)
{
    return (int) ((x < 0) ? (x - 0.5) : (x + 0.5));
}
/* Private state behind the gmx_stats_t handle: the stored data points
 * plus the results cached by gmx_stats_compute().
 */
typedef struct gmx_stats {
    /* Fit results */
    double  aa;         /* slope of the fit y = aa*x */
    double  a;          /* slope of the fit y = a*x + b */
    double  b;          /* intercept of the fit y = a*x + b */
    double  sigma_aa;   /* reported by gmx_stats_get_a(); not assigned in this file as shown */
    double  sigma_a;    /* uncertainty estimate for a */
    double  sigma_b;    /* uncertainty estimate for b */
    double  aver;       /* mean of the y values */
    double  sigma_aver; /* standard deviation of the y values */
    double  error;      /* sigma_aver / sqrt(np) */
    double  rmsd;       /* root mean square deviation between x and y */
    double  Rdata;      /* correlation coefficient of the data */
    double  Rfit;       /* correlation coefficient reported for y = a*x + b */
    double  Rfitaa;     /* correlation coefficient reported for y = aa*x */
    double  chi2;       /* deviation from the line y = a*x + b */
    double  chi2aa;     /* deviation from the line y = aa*x */
    /* Data arrays, each holding nalloc elements */
    double *x;
    double *y;
    double *dx;
    double *dy;
    /* Bookkeeping */
    int     computed;   /* non-zero once the cached results above are valid */
    int     np;         /* number of stored points */
    int     np_c;       /* read cursor used by gmx_stats_get_point() */
    int     nalloc;     /* allocated length of the data arrays */
} gmx_stats;
59 gmx_stats_t gmx_stats_init()
61 gmx_stats *stats;
63 snew(stats, 1);
65 return (gmx_stats_t) stats;
68 int gmx_stats_get_npoints(gmx_stats_t gstats, int *N)
70 gmx_stats *stats = (gmx_stats *) gstats;
72 *N = stats->np;
74 return estatsOK;
77 int gmx_stats_done(gmx_stats_t gstats)
79 gmx_stats *stats = (gmx_stats *) gstats;
81 sfree(stats->x);
82 stats->x = NULL;
83 sfree(stats->y);
84 stats->y = NULL;
85 sfree(stats->dx);
86 stats->dx = NULL;
87 sfree(stats->dy);
88 stats->dy = NULL;
90 return estatsOK;
93 int gmx_stats_add_point(gmx_stats_t gstats, double x, double y,
94 double dx, double dy)
96 gmx_stats *stats = (gmx_stats *) gstats;
97 int i;
99 if (stats->np+1 >= stats->nalloc)
101 if (stats->nalloc == 0)
103 stats->nalloc = 1024;
105 else
107 stats->nalloc *= 2;
109 srenew(stats->x, stats->nalloc);
110 srenew(stats->y, stats->nalloc);
111 srenew(stats->dx, stats->nalloc);
112 srenew(stats->dy, stats->nalloc);
113 for (i = stats->np; (i < stats->nalloc); i++)
115 stats->x[i] = 0;
116 stats->y[i] = 0;
117 stats->dx[i] = 0;
118 stats->dy[i] = 0;
121 stats->x[stats->np] = x;
122 stats->y[stats->np] = y;
123 stats->dx[stats->np] = dx;
124 stats->dy[stats->np] = dy;
125 stats->np++;
126 stats->computed = 0;
128 return estatsOK;
131 int gmx_stats_get_point(gmx_stats_t gstats, real *x, real *y,
132 real *dx, real *dy, real level)
134 gmx_stats *stats = (gmx_stats *) gstats;
135 int ok, outlier;
136 real rmsd, r;
138 if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
140 return ok;
142 outlier = 0;
143 while ((outlier == 0) && (stats->np_c < stats->np))
145 r = fabs(stats->x[stats->np_c] - stats->y[stats->np_c]);
146 outlier = (r > rmsd*level);
147 if (outlier)
149 if (NULL != x)
151 *x = stats->x[stats->np_c];
153 if (NULL != y)
155 *y = stats->y[stats->np_c];
157 if (NULL != dx)
159 *dx = stats->dx[stats->np_c];
161 if (NULL != dy)
163 *dy = stats->dy[stats->np_c];
166 stats->np_c++;
168 if (outlier)
170 return estatsOK;
174 stats->np_c = 0;
176 return estatsNO_POINTS;
179 int gmx_stats_add_points(gmx_stats_t gstats, int n, real *x, real *y,
180 real *dx, real *dy)
182 int i, ok;
184 for (i = 0; (i < n); i++)
186 if ((ok = gmx_stats_add_point(gstats, x[i], y[i],
187 (NULL != dx) ? dx[i] : 0,
188 (NULL != dy) ? dy[i] : 0)) != estatsOK)
190 return ok;
193 return estatsOK;
196 static int gmx_stats_compute(gmx_stats *stats, int weight)
198 double yy, yx, xx, sx, sy, dy, chi2, chi2aa, d2;
199 double ssxx, ssyy, ssxy;
200 double w, wtot, yx_nw, sy_nw, sx_nw, yy_nw, xx_nw, dx2, dy2;
201 int i, N;
203 N = stats->np;
204 if (stats->computed == 0)
206 if (N < 1)
208 return estatsNO_POINTS;
211 xx = xx_nw = 0;
212 yy = yy_nw = 0;
213 yx = yx_nw = 0;
214 sx = sx_nw = 0;
215 sy = sy_nw = 0;
216 wtot = 0;
217 d2 = 0;
218 for (i = 0; (i < N); i++)
220 d2 += dsqr(stats->x[i]-stats->y[i]);
221 if ((stats->dy[i]) && (weight == elsqWEIGHT_Y))
223 w = 1/dsqr(stats->dy[i]);
225 else
227 w = 1;
230 wtot += w;
232 xx += w*dsqr(stats->x[i]);
233 xx_nw += dsqr(stats->x[i]);
235 yy += w*dsqr(stats->y[i]);
236 yy_nw += dsqr(stats->y[i]);
238 yx += w*stats->y[i]*stats->x[i];
239 yx_nw += stats->y[i]*stats->x[i];
241 sx += w*stats->x[i];
242 sx_nw += stats->x[i];
244 sy += w*stats->y[i];
245 sy_nw += stats->y[i];
248 /* Compute average, sigma and error */
249 stats->aver = sy_nw/N;
250 stats->sigma_aver = sqrt(yy_nw/N - dsqr(sy_nw/N));
251 stats->error = stats->sigma_aver/sqrt(N);
253 /* Compute RMSD between x and y */
254 stats->rmsd = sqrt(d2/N);
256 /* Correlation coefficient for data */
257 yx_nw /= N;
258 xx_nw /= N;
259 yy_nw /= N;
260 sx_nw /= N;
261 sy_nw /= N;
262 ssxx = N*(xx_nw - dsqr(sx_nw));
263 ssyy = N*(yy_nw - dsqr(sy_nw));
264 ssxy = N*(yx_nw - (sx_nw*sy_nw));
265 stats->Rdata = sqrt(dsqr(ssxy)/(ssxx*ssyy));
267 /* Compute straight line through datapoints, either with intercept
268 zero (result in aa) or with intercept variable (results in a
269 and b) */
270 yx = yx/wtot;
271 xx = xx/wtot;
272 sx = sx/wtot;
273 sy = sy/wtot;
275 stats->aa = (yx/xx);
276 stats->a = (yx-sx*sy)/(xx-sx*sx);
277 stats->b = (sy)-(stats->a)*(sx);
279 /* Compute chi2, deviation from a line y = ax+b. Also compute
280 chi2aa which returns the deviation from a line y = ax. */
281 chi2 = 0;
282 chi2aa = 0;
283 for (i = 0; (i < N); i++)
285 if (stats->dy[i] > 0)
287 dy = stats->dy[i];
289 else
291 dy = 1;
293 chi2aa += dsqr((stats->y[i]-(stats->aa*stats->x[i]))/dy);
294 chi2 += dsqr((stats->y[i]-(stats->a*stats->x[i]+stats->b))/dy);
296 if (N > 2)
298 stats->chi2 = sqrt(chi2/(N-2));
299 stats->chi2aa = sqrt(chi2aa/(N-2));
301 /* Look up equations! */
302 dx2 = (xx-sx*sx);
303 dy2 = (yy-sy*sy);
304 stats->sigma_a = sqrt(stats->chi2/((N-2)*dx2));
305 stats->sigma_b = stats->sigma_a*sqrt(xx);
306 stats->Rfit = fabs(ssxy)/sqrt(ssxx*ssyy);
307 /*stats->a*sqrt(dx2/dy2);*/
308 stats->Rfitaa = stats->aa*sqrt(dx2/dy2);
310 else
312 stats->chi2 = 0;
313 stats->chi2aa = 0;
314 stats->sigma_a = 0;
315 stats->sigma_b = 0;
316 stats->Rfit = 0;
317 stats->Rfitaa = 0;
320 stats->computed = 1;
323 return estatsOK;
326 int gmx_stats_get_ab(gmx_stats_t gstats, int weight,
327 real *a, real *b, real *da, real *db,
328 real *chi2, real *Rfit)
330 gmx_stats *stats = (gmx_stats *) gstats;
331 int ok;
333 if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
335 return ok;
337 if (NULL != a)
339 *a = stats->a;
341 if (NULL != b)
343 *b = stats->b;
345 if (NULL != da)
347 *da = stats->sigma_a;
349 if (NULL != db)
351 *db = stats->sigma_b;
353 if (NULL != chi2)
355 *chi2 = stats->chi2;
357 if (NULL != Rfit)
359 *Rfit = stats->Rfit;
362 return estatsOK;
365 int gmx_stats_get_a(gmx_stats_t gstats, int weight, real *a, real *da,
366 real *chi2, real *Rfit)
368 gmx_stats *stats = (gmx_stats *) gstats;
369 int ok;
371 if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
373 return ok;
375 if (NULL != a)
377 *a = stats->aa;
379 if (NULL != da)
381 *da = stats->sigma_aa;
383 if (NULL != chi2)
385 *chi2 = stats->chi2aa;
387 if (NULL != Rfit)
389 *Rfit = stats->Rfitaa;
392 return estatsOK;
395 int gmx_stats_get_average(gmx_stats_t gstats, real *aver)
397 gmx_stats *stats = (gmx_stats *) gstats;
398 int ok;
400 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
402 return ok;
405 *aver = stats->aver;
407 return estatsOK;
410 int gmx_stats_get_ase(gmx_stats_t gstats, real *aver, real *sigma, real *error)
412 gmx_stats *stats = (gmx_stats *) gstats;
413 int ok;
415 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
417 return ok;
420 if (NULL != aver)
422 *aver = stats->aver;
424 if (NULL != sigma)
426 *sigma = stats->sigma_aver;
428 if (NULL != error)
430 *error = stats->error;
433 return estatsOK;
436 int gmx_stats_get_sigma(gmx_stats_t gstats, real *sigma)
438 gmx_stats *stats = (gmx_stats *) gstats;
439 int ok;
441 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
443 return ok;
446 *sigma = stats->sigma_aver;
448 return estatsOK;
451 int gmx_stats_get_error(gmx_stats_t gstats, real *error)
453 gmx_stats *stats = (gmx_stats *) gstats;
454 int ok;
456 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
458 return ok;
461 *error = stats->error;
463 return estatsOK;
466 int gmx_stats_get_corr_coeff(gmx_stats_t gstats, real *R)
468 gmx_stats *stats = (gmx_stats *) gstats;
469 int ok;
471 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
473 return ok;
476 *R = stats->Rdata;
478 return estatsOK;
481 int gmx_stats_get_rmsd(gmx_stats_t gstats, real *rmsd)
483 gmx_stats *stats = (gmx_stats *) gstats;
484 int ok;
486 if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
488 return ok;
491 *rmsd = stats->rmsd;
493 return estatsOK;
496 int gmx_stats_dump_xy(gmx_stats_t gstats, FILE *fp)
498 gmx_stats *stats = (gmx_stats *) gstats;
499 int i, ok;
501 for (i = 0; (i < stats->np); i++)
503 fprintf(fp, "%12g %12g %12g %12g\n", stats->x[i], stats->y[i],
504 stats->dx[i], stats->dy[i]);
507 return estatsOK;
510 int gmx_stats_remove_outliers(gmx_stats_t gstats, double level)
512 gmx_stats *stats = (gmx_stats *) gstats;
513 int i, iter = 1, done = 0, ok;
514 real rmsd, r;
516 while ((stats->np >= 10) && !done)
518 if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
520 return ok;
522 done = 1;
523 for (i = 0; (i < stats->np); )
525 r = fabs(stats->x[i]-stats->y[i]);
526 if (r > level*rmsd)
528 fprintf(stderr, "Removing outlier, iter = %d, rmsd = %g, x = %g, y = %g\n",
529 iter, rmsd, stats->x[i], stats->y[i]);
530 if (i < stats->np-1)
532 stats->x[i] = stats->x[stats->np-1];
533 stats->y[i] = stats->y[stats->np-1];
534 stats->dx[i] = stats->dx[stats->np-1];
535 stats->dy[i] = stats->dy[stats->np-1];
537 stats->np--;
538 done = 0;
540 else
542 i++;
545 iter++;
548 return estatsOK;
551 int gmx_stats_make_histogram(gmx_stats_t gstats, real binwidth, int *nb,
552 int ehisto, int normalized, real **x, real **y)
554 gmx_stats *stats = (gmx_stats *) gstats;
555 int i, ok, index = 0, nbins = *nb, *nindex;
556 double minx, maxx, maxy, miny, delta, dd, minh;
558 if (((binwidth <= 0) && (nbins <= 0)) ||
559 ((binwidth > 0) && (nbins > 0)))
561 return estatsINVALID_INPUT;
563 if (stats->np <= 2)
565 return estatsNO_POINTS;
567 minx = maxx = stats->x[0];
568 miny = maxy = stats->y[0];
569 for (i = 1; (i < stats->np); i++)
571 miny = (stats->y[i] < miny) ? stats->y[i] : miny;
572 maxy = (stats->y[i] > maxy) ? stats->y[i] : maxy;
573 minx = (stats->x[i] < minx) ? stats->x[i] : minx;
574 maxx = (stats->x[i] > maxx) ? stats->x[i] : maxx;
576 if (ehisto == ehistoX)
578 delta = maxx-minx;
579 minh = minx;
581 else if (ehisto == ehistoY)
583 delta = maxy-miny;
584 minh = miny;
586 else
588 return estatsINVALID_INPUT;
591 if (binwidth == 0)
593 binwidth = (delta)/nbins;
595 else
597 nbins = gmx_dnint((delta)/binwidth + 0.5);
599 snew(*x, nbins);
600 snew(nindex, nbins);
601 for (i = 0; (i < nbins); i++)
603 (*x)[i] = minh + binwidth*(i+0.5);
605 if (normalized == 0)
607 dd = 1;
609 else
611 dd = 1.0/(binwidth*stats->np);
614 snew(*y, nbins);
615 for (i = 0; (i < stats->np); i++)
617 if (ehisto == ehistoY)
619 index = (stats->y[i]-miny)/binwidth;
621 else if (ehisto == ehistoX)
623 index = (stats->x[i]-minx)/binwidth;
625 if (index < 0)
627 index = 0;
629 if (index > nbins-1)
631 index = nbins-1;
633 (*y)[index] += dd;
634 nindex[index]++;
636 if (*nb == 0)
638 *nb = nbins;
640 for (i = 0; (i < nbins); i++)
642 if (nindex[i] > 0)
644 (*y)[i] /= nindex[i];
648 sfree(nindex);
650 return estatsOK;
653 static const char *stats_error[estatsNR] =
655 "All well in STATS land",
656 "No points",
657 "Not enough memory",
658 "Invalid histogram input",
659 "Unknown error",
660 "Not implemented yet"
663 const char *gmx_stats_message(int estats)
665 if ((estats >= 0) && (estats < estatsNR))
667 return stats_error[estats];
669 else
671 return stats_error[estatsERROR];
675 /* Old convenience functions, should be merged with the core
676 statistics above. */
677 int lsq_y_ax(int n, real x[], real y[], real *a)
679 gmx_stats_t lsq = gmx_stats_init();
680 int ok;
681 real da, chi2, Rfit;
683 gmx_stats_add_points(lsq, n, x, y, 0, 0);
684 if ((ok = gmx_stats_get_a(lsq, elsqWEIGHT_NONE, a, &da, &chi2, &Rfit)) != estatsOK)
686 return ok;
689 /* int i;
690 double xx,yx;
692 yx=xx=0.0;
693 for (i=0; i<n; i++) {
694 yx+=y[i]*x[i];
695 xx+=x[i]*x[i];
697 * a=yx/xx;
699 return estatsOK;
702 static int low_lsq_y_ax_b(int n, real *xr, double *xd, real yr[],
703 real *a, real *b, real *r, real *chi2)
705 int i, ok;
706 gmx_stats_t lsq;
708 lsq = gmx_stats_init();
709 for (i = 0; (i < n); i++)
711 if ((ok = gmx_stats_add_point(lsq, (NULL != xd) ? xd[i] : xr[i], yr[i], 0, 0))
712 != estatsOK)
714 return ok;
717 if ((ok = gmx_stats_get_ab(lsq, elsqWEIGHT_NONE, a, b, NULL, NULL, chi2, r)) != estatsOK)
719 return ok;
722 return estatsOK;
724 double x,y,yx,xx,yy,sx,sy,chi2;
726 yx=xx=yy=sx=sy=0.0;
727 for (i=0; i<n; i++) {
728 if (xd != NULL) {
729 x = xd[i];
730 } else {
731 x = xr[i];
733 y = yr[i];
735 yx += y*x;
736 xx += x*x;
737 yy += y*y;
738 sx += x;
739 sy += y;
741 * a = (n*yx-sy*sx)/(n*xx-sx*sx);
742 * b = (sy-(*a)*sx)/n;
743 * r = sqrt((xx-sx*sx)/(yy-sy*sy));
745 chi2 = 0;
746 if (xd != NULL) {
747 for(i=0; i<n; i++)
748 chi2 += dsqr(yr[i] - ((*a)*xd[i] + (*b)));
749 } else {
750 for(i=0; i<n; i++)
751 chi2 += dsqr(yr[i] - ((*a)*xr[i] + (*b)));
754 if (n > 2)
755 return sqrt(chi2/(n-2));
756 else
757 return 0;
761 int lsq_y_ax_b(int n, real x[], real y[], real *a, real *b, real *r, real *chi2)
763 return low_lsq_y_ax_b(n, x, NULL, y, a, b, r, chi2);
766 int lsq_y_ax_b_xdouble(int n, double x[], real y[], real *a, real *b,
767 real *r, real *chi2)
769 return low_lsq_y_ax_b(n, NULL, x, y, a, b, r, chi2);
772 int lsq_y_ax_b_error(int n, real x[], real y[], real dy[],
773 real *a, real *b, real *da, real *db,
774 real *r, real *chi2)
776 gmx_stats_t lsq;
777 int i, ok;
779 lsq = gmx_stats_init();
780 for (i = 0; (i < n); i++)
782 if ((ok = gmx_stats_add_point(lsq, x[i], y[i], 0, dy[i])) != estatsOK)
784 return ok;
787 if ((ok = gmx_stats_get_ab(lsq, elsqWEIGHT_Y, a, b, da, db, chi2, r)) != estatsOK)
789 return ok;
791 if ((ok = gmx_stats_done(lsq)) != estatsOK)
793 return ok;
795 sfree(lsq);
797 return estatsOK;
799 double sxy,sxx,syy,sx,sy,w,s_2,dx2,dy2,mins;
801 sxy=sxx=syy=sx=sy=w=0.0;
802 mins = dy[0];
803 for(i=1; (i<n); i++)
804 mins = min(mins,dy[i]);
805 if (mins <= 0)
806 gmx_fatal(FARGS,"Zero or negative weigths in linear regression analysis");
808 for (i=0; i<n; i++) {
809 s_2 = dsqr(1.0/dy[i]);
810 sxx += s_2*dsqr(x[i]);
811 sxy += s_2*y[i]*x[i];
812 syy += s_2*dsqr(y[i]);
813 sx += s_2*x[i];
814 sy += s_2*y[i];
815 w += s_2;
817 sxx = sxx/w;
818 sxy = sxy/w;
819 syy = syy/w;
820 sx = sx/w;
821 sy = sy/w;
822 dx2 = (sxx-sx*sx);
823 dy2 = (syy-sy*sy);
824 * a=(sxy-sy*sx)/dx2;
825 * b=(sy-(*a)*sx);
827 * chi2=0;
828 for(i=0; i<n; i++)
829 * chi2+=dsqr((y[i]-((*a)*x[i]+(*b)))/dy[i]);
830 * chi2 = *chi2/w;
832 * da = sqrt(*chi2/((n-2)*dx2));
833 * db = *da*sqrt(sxx);
834 * r = *a*sqrt(dx2/dy2);
836 if (debug)
837 fprintf(debug,"sx = %g, sy = %g, sxy = %g, sxx = %g, w = %g\n"
838 "chi2 = %g, dx2 = %g\n",
839 sx,sy,sxy,sxx,w,*chi2,dx2);
841 if (n > 2)
842 * chi2 = sqrt(*chi2/(n-2));
843 else
844 * chi2 = 0;