partial_reducer: specialize list version of base
[barvinok.git] / bfcounter.cc
blob 2e437d50a0ca98019f51029162a511314a0dab5f
1 #include <vector>
2 #include "bfcounter.h"
3 #include "lattice_point.h"
5 using std::vector;
7 static int lex_cmp(vec_ZZ& a, vec_ZZ& b)
9 assert(a.length() == b.length());
11 for (int j = 0; j < a.length(); ++j)
12 if (a[j] != b[j])
13 return a[j] < b[j] ? -1 : 1;
14 return 0;
17 void bf_base::add_term(bfc_term_base *t, vec_ZZ& num_orig, vec_ZZ& extra_num)
19 vec_ZZ num;
20 int d = num_orig.length();
21 num.SetLength(d-1);
22 for (int l = 0; l < d-1; ++l)
23 num[l] = num_orig[l+1] + extra_num[l];
25 add_term(t, num);
28 void bf_base::add_term(bfc_term_base *t, vec_ZZ& num)
30 int len = t->terms.NumRows();
31 int i, r;
32 for (i = 0; i < len; ++i) {
33 r = lex_cmp(t->terms[i], num);
34 if (r >= 0)
35 break;
37 if (i == len || r > 0) {
38 t->terms.SetDims(len+1, num.length());
39 insert_term(t, i);
40 t->terms[i] = num;
41 } else {
42 // i < len && r == 0
43 update_term(t, i);
47 bfc_term_base* bf_base::find_bfc_term(bfc_vec& v, int *powers, int len)
49 bfc_vec::iterator i;
50 for (i = v.begin(); i != v.end(); ++i) {
51 int j;
52 for (j = 0; j < len; ++j)
53 if ((*i)->powers[j] != powers[j])
54 break;
55 if (j == len)
56 return (*i);
57 if ((*i)->powers[j] > powers[j])
58 break;
61 bfc_term_base* t = new_bf_term(len);
62 v.insert(i, t);
63 memcpy(t->powers, powers, len * sizeof(int));
65 return t;
68 void bf_base::reduce(mat_ZZ& factors, bfc_vec& v, barvinok_options *options)
70 assert(v.size() > 0);
71 unsigned nf = factors.NumRows();
72 unsigned d = factors.NumCols();
74 if (d == lower)
75 return base(factors, v);
77 bf_reducer bfr(factors, v, this);
79 bfr.reduce(options);
81 if (bfr.vn.size() > 0)
82 reduce(bfr.nfactors, bfr.vn, options);
85 int bf_base::setup_factors(const mat_ZZ& rays, mat_ZZ& factors,
86 bfc_term_base* t, int s)
88 factors.SetDims(dim, dim);
90 int r;
92 for (r = 0; r < dim; ++r)
93 t->powers[r] = 1;
95 for (r = 0; r < dim; ++r) {
96 factors[r] = rays[r];
97 int k;
98 for (k = 0; k < dim; ++k)
99 if (factors[r][k] != 0)
100 break;
101 if (factors[r][k] < 0) {
102 factors[r] = -factors[r];
103 t->terms[0] += factors[r];
104 s = -s;
108 return s;
111 void bf_base::handle(const mat_ZZ& rays, Value *vertex, QQ c, int *closed,
112 barvinok_options *options)
114 bfc_term* t = new bfc_term(dim);
115 vector< bfc_term_base * > v;
116 v.push_back(t);
118 t->c.SetLength(1);
120 t->terms.SetDims(1, dim);
121 lattice_point(vertex, rays, t->terms[0], closed);
123 // the elements of factors are always lexpositive
124 mat_ZZ factors;
125 int s = setup_factors(rays, factors, t, 1);
127 t->c[0].n = s * c.n;
128 t->c[0].d = c.d;
130 reduce(factors, v, options);
133 bfc_term_base* bfcounter_base::new_bf_term(int len)
135 bfc_term* t = new bfc_term(len);
136 t->c.SetLength(0);
137 return t;
140 void bfcounter_base::set_factor(bfc_term_base *t, int k, int change)
142 bfc_term* bfct = static_cast<bfc_term *>(t);
143 c = bfct->c[k];
144 if (change)
145 c.n = -c.n;
148 void bfcounter_base::set_factor(bfc_term_base *t, int k, mpq_t &f, int change)
150 bfc_term* bfct = static_cast<bfc_term *>(t);
151 value2zz(mpq_numref(f), c.n);
152 value2zz(mpq_denref(f), c.d);
153 c *= bfct->c[k];
154 if (change)
155 c.n = -c.n;
158 void bfcounter_base::set_factor(bfc_term_base *t, int k, const QQ& c_factor,
159 int change)
161 bfc_term* bfct = static_cast<bfc_term *>(t);
162 c = bfct->c[k];
163 c *= c_factor;
164 if (change)
165 c.n = -c.n;
168 void bfcounter_base::insert_term(bfc_term_base *t, int i)
170 bfc_term* bfct = static_cast<bfc_term *>(t);
171 int len = t->terms.NumRows()-1; // already increased by one
173 bfct->c.SetLength(len+1);
174 for (int j = len; j > i; --j) {
175 bfct->c[j] = bfct->c[j-1];
176 t->terms[j] = t->terms[j-1];
178 bfct->c[i] = c;
181 void bfcounter_base::update_term(bfc_term_base *t, int i)
183 bfc_term* bfct = static_cast<bfc_term *>(t);
185 bfct->c[i] += c;
188 void bf_reducer::compute_extra_num(int i)
190 clear(extra_num);
191 changes = 0;
192 no_param = 0; // r from text
193 only_param = 0; // k-r-s from text
194 total_power = 0; // k from text
196 for (int j = 0; j < nf; ++j) {
197 if (v[i]->powers[j] == 0)
198 continue;
200 total_power += v[i]->powers[j];
201 if (factors[j][0] == 0) {
202 only_param += v[i]->powers[j];
203 continue;
206 if (old2new[j] == -1)
207 no_param += v[i]->powers[j];
208 else
209 extra_num += -sign[j] * v[i]->powers[j] * nfactors[old2new[j]];
210 changes += v[i]->powers[j];
214 void bf_reducer::update_powers(const std::vector<int>& powers)
216 for (int l = 0; l < nnf; ++l)
217 npowers[l] = bpowers[l];
219 l_extra_num = extra_num;
220 l_changes = changes;
222 for (int l = 0; l < powers.size(); ++l) {
223 int n = powers[l];
224 if (n == 0)
225 continue;
226 assert(old2new[l] != -1);
228 npowers[old2new[l]] += n;
229 // interpretation of sign has been inverted
230 // since we inverted the power for specialization
231 if (sign[l] == 1) {
232 l_extra_num += n * nfactors[old2new[l]];
233 l_changes += n;
239 void bf_reducer::compute_reduced_factors()
241 unsigned nf = factors.NumRows();
242 unsigned d = factors.NumCols();
243 nnf = 0;
244 nfactors.SetDims(nnf, d-1);
246 for (int i = 0; i < nf; ++i) {
247 int j;
248 int s = 1;
249 for (j = 0; j < nnf; ++j) {
250 int k;
251 for (k = 1; k < d; ++k)
252 if (factors[i][k] != 0 || nfactors[j][k-1] != 0)
253 break;
254 if (k < d && factors[i][k] == -nfactors[j][k-1])
255 s = -1;
256 for (; k < d; ++k)
257 if (factors[i][k] != s * nfactors[j][k-1])
258 break;
259 if (k == d)
260 break;
262 old2new[i] = j;
263 if (j == nnf) {
264 int k;
265 for (k = 1; k < d; ++k)
266 if (factors[i][k] != 0)
267 break;
268 if (k < d) {
269 if (factors[i][k] < 0)
270 s = -1;
271 nfactors.SetDims(++nnf, d-1);
272 for (int k = 1; k < d; ++k)
273 nfactors[j][k-1] = s * factors[i][k];
274 } else
275 old2new[i] = -1;
277 sign[i] = s;
279 npowers = new int[nnf];
280 bpowers = new int[nnf];
283 void bf_reducer::reduce(barvinok_options *options)
285 compute_reduced_factors();
287 for (int i = 0; i < v.size(); ++i) {
288 compute_extra_num(i);
290 if (no_param == 0) {
291 vec_ZZ extra_num;
292 extra_num.SetLength(d-1);
293 int changes = 0;
294 int npowers[nnf];
295 for (int k = 0; k < nnf; ++k)
296 npowers[k] = 0;
297 for (int k = 0; k < nf; ++k) {
298 assert(old2new[k] != -1);
299 npowers[old2new[k]] += v[i]->powers[k];
300 if (sign[k] == -1) {
301 extra_num += v[i]->powers[k] * nfactors[old2new[k]];
302 changes += v[i]->powers[k];
306 bfc_term_base * t = bf->find_bfc_term(vn, npowers, nnf);
307 for (int k = 0; k < v[i]->terms.NumRows(); ++k) {
308 bf->set_factor(v[i], k, changes % 2);
309 bf->add_term(t, v[i]->terms[k], extra_num);
311 } else {
312 // powers of "constant" part
313 for (int k = 0; k < nnf; ++k)
314 bpowers[k] = 0;
315 for (int k = 0; k < nf; ++k) {
316 if (factors[k][0] != 0)
317 continue;
318 assert(old2new[k] != -1);
319 bpowers[old2new[k]] += v[i]->powers[k];
320 if (sign[k] == -1) {
321 extra_num += v[i]->powers[k] * nfactors[old2new[k]];
322 changes += v[i]->powers[k];
326 int j;
327 for (j = 0; j < nf; ++j)
328 if (old2new[j] == -1 && v[i]->powers[j] > 0)
329 break;
331 dpoly D(no_param, factors[j][0], 1);
332 for (int k = 1; k < v[i]->powers[j]; ++k) {
333 dpoly fact(no_param, factors[j][0], 1);
334 D *= fact;
336 for ( ; ++j < nf; )
337 if (old2new[j] == -1)
338 for (int k = 0; k < v[i]->powers[j]; ++k) {
339 dpoly fact(no_param, factors[j][0], 1);
340 D *= fact;
343 if (no_param + only_param == total_power &&
344 bf->constant_vertex(d)) {
345 bfc_term_base * t = NULL;
346 vec_ZZ num;
347 num.SetLength(d-1);
348 ZZ cn;
349 ZZ cd;
350 for (int k = 0; k < v[i]->terms.NumRows(); ++k) {
351 dpoly n(no_param, v[i]->terms[k][0]);
352 mpq_set_si(bf->tcount, 0, 1);
353 n.div(D, bf->tcount, bf->one);
355 if (value_zero_p(mpq_numref(bf->tcount)))
356 continue;
358 if (!t)
359 t = bf->find_bfc_term(vn, bpowers, nnf);
360 bf->set_factor(v[i], k, bf->tcount, changes % 2);
361 bf->add_term(t, v[i]->terms[k], extra_num);
363 } else {
364 for (int j = 0; j < v[i]->terms.NumRows(); ++j) {
365 dpoly n(no_param, v[i]->terms[j][0]);
367 dpoly_r * r = 0;
368 if (no_param + only_param == total_power)
369 r = new dpoly_r(n, nf);
370 else
371 for (int k = 0; k < nf; ++k) {
372 if (v[i]->powers[k] == 0)
373 continue;
374 if (factors[k][0] == 0 || old2new[k] == -1)
375 continue;
377 dpoly pd(no_param-1, factors[k][0], 1);
379 for (int l = 0; l < v[i]->powers[k]; ++l) {
380 int q;
381 for (q = 0; q < k; ++q)
382 if (old2new[q] == old2new[k] &&
383 sign[q] == sign[k])
384 break;
386 if (r == 0)
387 r = new dpoly_r(n, pd, q, nf);
388 else {
389 dpoly_r *nr = new dpoly_r(r, pd, q, nf);
390 delete r;
391 r = nr;
396 dpoly_r *rc = r->div(D);
397 delete r;
398 QQ factor;
399 factor.d = rc->denom;
401 if (bf->constant_vertex(d)) {
402 dpoly_r_term_list& final = rc->c[rc->len-1];
404 dpoly_r_term_list::iterator k;
405 for (k = final.begin(); k != final.end(); ++k) {
406 if ((*k)->coeff == 0)
407 continue;
409 update_powers((*k)->powers);
411 bfc_term_base * t = bf->find_bfc_term(vn, npowers, nnf);
412 factor.n = (*k)->coeff;
413 bf->set_factor(v[i], j, factor, l_changes % 2);
414 bf->add_term(t, v[i]->terms[j], l_extra_num);
416 } else
417 bf->cum(this, v[i], j, rc, options);
419 delete rc;
423 delete v[i];