5 * Created by Mark Levy on 21/02/2006.
6 * Copyright 2006 Centre for Digital Music, Queen Mary, University of London.
8 This program is free software; you can redistribute it and/or
9 modify it under the terms of the GNU General Public License as
10 published by the Free Software Foundation; either version 2 of the
11 License, or (at your option) any later version. See the file
12 COPYING included with this distribution for more information.
18 #include "cluster_melt.h"
/* Default weight applied to the neighbour-count (duration) term in
   cluster_melt's soft-assignment log-probabilities. No trailing ';' so the
   macro can be used safely inside expressions. */
#define DEFAULT_LAMBDA 0.02
/* Default neighbourhood half-width used when no valid limit is supplied. */
#define DEFAULT_LIMIT 20
23 double kldist(double* a
, double* b
, int n
) {
24 /* NB assume that all a[i], b[i] are non-negative
25 because a, b represent probability distributions */
30 for (i
= 0; i
< n
; i
++)
32 q
= (a
[i
] + b
[i
]) / 2.0;
36 d
+= a
[i
] * log(a
[i
] / q
);
38 d
+= b
[i
] * log(b
[i
] / q
);
44 void cluster_melt(double *h
, int m
, int n
, double *Bsched
, int t
, int k
, int l
, int *c
) {
45 double lambda
, sum
, beta
, logsumexp
, maxlp
;
46 int i
, j
, a
, b
, b0
, b1
, limit
, B
, it
, maxiter
, maxiter0
, maxiter1
;
47 double** cl
; /* reference histograms for each cluster */
48 int** nc
; /* neighbour counts for each histogram */
49 double** lp
; /* soft assignment probs for each histogram */
50 int* oldc
; /* previous hard assignments (to check convergence) */
52 /* NB h is passed as a 1d row major array */
54 /* parameter values */
55 lambda
= DEFAULT_LAMBDA
;
59 limit
= DEFAULT_LIMIT
; /* use default if no valid neighbourhood limit supplied */
61 maxiter0
= 20; /* number of iterations at initial temperature */
62 maxiter1
= 5; /* number of iterations at subsequent temperatures */
65 cl
= (double**) malloc(k
*sizeof(double*));
66 for (i
= 0; i
< k
; i
++)
67 cl
[i
] = (double*) malloc(m
*sizeof(double));
69 nc
= (int**) malloc(n
*sizeof(int*));
70 for (i
= 0; i
< n
; i
++)
71 nc
[i
] = (int*) malloc(k
*sizeof(int));
73 lp
= (double**) malloc(n
*sizeof(double*));
74 for (i
= 0; i
< n
; i
++)
75 lp
[i
] = (double*) malloc(k
*sizeof(double));
77 oldc
= (int*) malloc(n
* sizeof(int));
80 for (i
= 0; i
< k
; i
++)
83 for (j
= 0; j
< m
; j
++)
85 cl
[i
][j
] = rand(); /* random initial reference histograms */
86 sum
+= cl
[i
][j
] * cl
[i
][j
];
89 for (j
= 0; j
< m
; j
++)
91 cl
[i
][j
] /= sum
; /* normalise */
94 //print_array(cl, k, m);
96 for (i
= 0; i
< n
; i
++)
97 c
[i
] = 1; /* initially assign all histograms to cluster 1 */
99 for (a
= 0; a
< t
; a
++)
108 for (it
= 0; it
< maxiter
; it
++)
110 //if (it == maxiter - 1)
111 // mexPrintf("hasn't converged after %d iterations\n", maxiter);
113 for (i
= 0; i
< n
; i
++)
115 /* save current hard assignments */
118 /* calculate soft assignment logprobs for each cluster */
120 for (j
= 0; j
< k
; j
++)
122 lp
[i
][ j
] = -beta
* kldist(cl
[j
], &h
[i
*m
], m
);
124 /* update matching neighbour counts for this histogram, based on current hard assignments */
127 if (i >= limit && i <= n - 1 - limit)
129 for (b = i - limit; b <= i + limit; b++)
134 nc[i][j] = B - nc[i][j];
143 nc
[i
][j
] = b1
- b0
+ 1; /* = B except at edges */
144 for (b
= b0
; b
<= b1
; b
++)
148 sum
+= exp(lp
[i
][j
]);
151 /* normalise responsibilities and add duration logprior */
152 logsumexp
= log(sum
);
153 for (j
= 0; j
< k
; j
++)
154 lp
[i
][j
] -= logsumexp
+ lambda
* nc
[i
][j
];
156 //print_array(lp, n, k);
158 for (i = 0; i < n; i++)
160 for (j = 0; j < k; j++)
161 mexPrintf("%d ", nc[i][j]);
167 /* update the assignments now that we know the duration priors
168 based on the current assignments */
169 for (i
= 0; i
< n
; i
++)
173 for (j
= 1; j
< k
; j
++)
174 if (lp
[i
][j
] > maxlp
)
181 /* break if assignments haven't changed */
183 while (i
< n
&& oldc
[i
] == c
[i
])
188 /* update reference histograms now we know new responsibilities */
189 for (j
= 0; j
< k
; j
++)
191 for (b
= 0; b
< m
; b
++)
194 for (i
= 0; i
< n
; i
++)
196 cl
[j
][b
] += exp(lp
[i
][j
]) * h
[i
*m
+b
];
201 for (i
= 0; i
< n
; i
++)
202 sum
+= exp(lp
[i
][j
]);
203 for (b
= 0; b
< m
; b
++)
204 cl
[j
][b
] /= sum
; /* normalise */
207 //print_array(cl, k, m);
213 for (i
= 0; i
< k
; i
++)
216 for (i
= 0; i
< n
; i
++)
219 for (i
= 0; i
< n
; i
++)