suifmath/named_sc_fm.cc: don't include sys/times.h
[suif.git] / src / baseparsuif / suifmath / matrix.cc
blob90e7b904e1c331a0dc2f3d2d3b14662529893bef
1 /* file "matrix.cc" */
3 /* Copyright (c) 1994,95 Stanford University
5 All rights reserved.
7 This software is provided under the terms described in
8 the "suif_copyright.h" include file. */
10 #include <suif_copyright.h>
12 #pragma implementation "matrix.h"
14 #define RCS_BASE_FILE matrix_cc
16 #include "matrix.h"
18 RCS_BASE("$Id: matrix.cc,v 1.1.1.1 1998/07/07 05:09:27 brm Exp $")
/*
 * Private member functions
 */
24 void matrix::init_decomp()
26 lu_decomp = fract_matrix();
27 factored = FALSE;
28 interchanges = NULL;
29 pivot_cols = NULL;
32 void matrix::clear_decomp()
34 lu_decomp.clear();
35 factored = FALSE;
37 if (interchanges) {
38 delete[] interchanges;
39 interchanges = NULL;
42 if (pivot_cols) {
43 delete[] pivot_cols;
44 pivot_cols = NULL;
48 void matrix::copy_decomp(const matrix &mat)
50 assert(factored && mat.factored);
52 lu_decomp = mat.lu_decomp;
54 int rows = lu_decomp.m();
55 int cols = lu_decomp.n();
57 assert(mat.interchanges);
58 interchanges = new int[rows];
59 int i;
60 for (i = 0; i < rows; i++)
61 interchanges[i] = mat.interchanges[i];
63 assert(mat.pivot_cols);
64 pivot_cols = new boolean[cols];
65 for (i = 0; i < cols; i++)
66 pivot_cols[i] = mat.pivot_cols[i];
69 void matrix::copy_matrix(const matrix &mat)
71 elems = mat.elems;
73 if (mat.factored) {
74 clear_decomp();
75 factored = TRUE;
76 copy_decomp(mat);
78 else {
79 init_decomp();
85 // Linear Algebra support
88 // Factors column 'j' in elems and places it in lu_decomp
90 void matrix::factor_col_and_insert(int j)
92 assert(j < elems.n());
94 int rows = elems.m(); // number of rows in resulting decomp
95 int cols = j; // column being factored = number of prev columns
97 fract_vector new_col(elems.col_list[j]);
99 int cur_pivot_row = 0;
100 for (int i = 0; i < cols; i++) // pivot row = number of pivot cols
101 cur_pivot_row += pivot_cols[i];
103 assert(cur_pivot_row <= rows);
105 int new_pivot_row = factor_col(&new_col, cols, cur_pivot_row);
106 assert((cur_pivot_row <= new_pivot_row && new_pivot_row < rows) ||
107 (new_pivot_row == rows && cur_pivot_row == new_pivot_row));
109 if (cur_pivot_row < rows) {
110 interchanges[cur_pivot_row] = new_pivot_row;
112 if (new_pivot_row != cur_pivot_row) {
114 // new_col has swapped rows, now swap prev cols
115 for(int i = 0; i < cols; i++) {
116 fract temp = lu_decomp.elt(new_pivot_row, i);
117 lu_decomp.elt(new_pivot_row, i) = lu_decomp.elt(cur_pivot_row,i);
118 lu_decomp.elt(cur_pivot_row, i) = temp;
122 pivot_cols[cols] = (new_col[cur_pivot_row] != fract(0));
124 else pivot_cols[cols] = FALSE;
126 lu_decomp.col_list[j] = new_col;
130 // Factor the given fract_vector column. Returns the row that the pivot row
131 // is exchanged with.
133 int matrix::factor_col(fract_vector *vec, int cols, int pivot_row)
135 int rows = elems.m();
137 apply_L(vec, cols); // Apply the L we have so far to the column vec
139 if (pivot_row == rows) return rows;
140 assert(pivot_row < rows);
142 int new_pivot_row = pivot_row;
144 fract max = vec->elt(pivot_row);
146 if (max == fract(0)) { // Find new pivot row to interchange
147 for (int i = pivot_row+1; i < rows; i++) {
148 max = vec->elt(i).abs();
150 if (max > fract(0)) {
151 new_pivot_row = i;
152 break;
157 assert(pivot_row <= new_pivot_row && new_pivot_row < rows);
159 if (new_pivot_row != pivot_row) { // swap
160 fract temp = vec->elt(pivot_row);
161 vec->elt(pivot_row) = vec->elt(new_pivot_row);
162 vec->elt(new_pivot_row) = temp;
165 if (vec->elt(pivot_row) != fract(0)) {
167 if (pivot_row == cols) {
168 for (int i = pivot_row+1; i < rows; i++) {
169 vec->elt(i) /= vec->elt(pivot_row);
172 else {
173 assert(pivot_row < cols);
174 for (int i = pivot_row+1; i < rows; i++) {
175 lu_decomp.elt(i,pivot_row) = vec->elt(i) / vec->elt(pivot_row);
176 vec->elt(i) = 0;
181 return new_pivot_row;
184 // Apply the L matrix to the column.
186 void matrix::apply_L(fract_vector *vec, int cols)
188 int rows = elems.m();
190 // Perform interchanges we've done so far
192 for (int i = 0; i < rows; i++) {
193 if (interchanges[i] != i) {
194 assert(interchanges[i] > i);
195 fract temp = vec->elt(i);
196 vec->elt(i) = vec->elt(interchanges[i]);
197 vec->elt(interchanges[i]) = temp;
201 // Operations on column as dictated by L.
203 for (int j = 0; j < cols; j++) {
204 for (int i = j+1; i < rows; i++) {
205 vec->elt(i) -= vec->elt(j) * lu_decomp.elt(i,j);
211 // Backsolve the column 'vec' using U.
212 // All free variables are assumed to be zero except the var specified by
213 // 'free_var'.
215 fract_vector matrix::backsolve_U(fract_vector &vec, int free_var,
216 boolean *valid)
218 assert(factored);
220 int rows = lu_decomp.m();
221 int cols = lu_decomp.n();
223 fract_vector result(cols);
225 // First examine all zero rows:
226 int last_row = last_nonzero_Urow();
228 for (int ii = last_row + 1; ii < rows; ii++) {
229 if (vec[ii] != fract(0)) {
230 *valid = FALSE;
231 return result; // All zeros
235 int i = last_row;
236 for (int j = cols-1; j >= 0; j--) {
237 if (!pivot_cols[j]) {
238 if(j == free_var) result[j] = fract(1);
239 else result[j] = fract(0);
242 else {
243 fract temp = vec[i];
245 for (int jj = j+1; jj < cols; jj++) {
246 temp -= lu_decomp.elt(i,jj) * result[jj];
248 result[j] = temp/lu_decomp.elt(i,j);
249 i--;
253 *valid = TRUE;
254 return result;
257 int matrix::last_nonzero_Urow()
259 assert(factored);
261 int result = -1;
262 for (int i = 0; i < lu_decomp.n(); i++)
263 result += pivot_cols[i];
265 return result;
/*
 * Public member functions
 */
274 // Constructors
277 matrix::matrix()
279 elems = fract_matrix();
280 init_decomp();
283 matrix::matrix(int r, int c)
285 elems = fract_matrix(r, c);
286 init_decomp();
289 matrix::matrix(const integer_matrix &mat)
291 elems = fract_matrix(mat);
292 init_decomp();
295 matrix::matrix(const integer_matrix &mat, int div)
297 elems = fract_matrix(mat, div);
298 init_decomp();
301 matrix::matrix(const fract_matrix &mat)
303 elems = mat;
304 init_decomp();
307 matrix::matrix(const fract_matrix &mat, int c)
309 elems = fract_matrix(mat, c);
310 init_decomp();
314 matrix::matrix(const matrix &mat)
316 init_decomp();
317 copy_matrix(mat);
320 matrix::matrix(const matrix &mat, int c)
322 elems = fract_matrix(mat.elems, c);
323 init_decomp();
328 // Equality and assignment operators:
331 boolean matrix::operator==(const matrix &mat) const
333 return (elems == mat.elems);
336 matrix &matrix::operator=(const matrix &mat)
338 copy_matrix(mat);
339 return *this;
344 // Matrix-Matrix operations:
347 matrix matrix::operator*(matrix &mat)
349 return matrix(elems * mat.elems);
352 matrix matrix::operator+(matrix &mat)
354 return matrix(elems + mat.elems);
357 matrix matrix::operator-(matrix &mat)
359 return matrix(elems - mat.elems);
362 matrix &matrix::operator+=(matrix &mat)
364 elems += mat.elems;
365 factored = FALSE;
367 return *this;
370 matrix &matrix::operator-=(matrix &mat)
372 elems -= mat.elems;
373 factored = FALSE;
375 return *this;
380 // Element-wise operations:
382 matrix matrix::operator*(fract val)
384 return matrix(elems * val);
387 matrix matrix::operator/(fract val)
389 return matrix(elems / val);
392 matrix matrix::operator+(fract val)
394 return matrix(elems + val);
397 matrix matrix::operator-(fract val)
399 return matrix(elems - val);
402 matrix &matrix::operator+=(fract val)
404 elems += val;
405 factored = FALSE;
407 return *this;
410 matrix &matrix::operator-=(fract val)
412 elems -= val;
413 factored = FALSE;
415 return *this;
418 matrix &matrix::operator*=(fract val)
420 elems *= val;
421 factored = FALSE;
423 return *this;
426 matrix &matrix::operator/=(fract val)
428 elems /= val;
429 factored = FALSE;
431 return *this;
436 // Matrix-vector
438 fract_vector matrix::operator*(fract_vector &vec)
440 return (elems * vec);
445 // Other useful functions:
447 void matrix::ident(int n)
449 elems.ident(n);
450 factored = FALSE;
453 matrix matrix::transpose()
455 return matrix(elems.transpose());
460 // Linear Algebra Operations
463 boolean matrix::is_pivot(int i)
465 factor();
466 return pivot_cols[i];
470 int matrix::rank()
472 int val = 0;
473 int cols = elems.n();
475 factor();
476 for (int i = 0; i < cols; i++)
477 if (pivot_cols[i]) val++;
479 return (val);
483 void matrix::factor()
485 if (factored) return;
487 int rows = m();
488 int cols = n();
490 clear_decomp();
491 lu_decomp = fract_matrix(0, cols); // Say rows are size 0 so no storage
492 lu_decomp.rows = rows; // is allocated.
494 interchanges = new int[rows];
495 pivot_cols = new boolean[cols];
497 for (int r = 0; r < rows; r++)
498 interchanges[r] = r;
500 for (int j = 0; j < cols; j++) {
501 factor_col_and_insert(j);
504 factored = TRUE;
507 matrix matrix::inverse()
509 int rows = elems.m();
510 int cols = elems.n();
512 assert_msg(rows == cols,
513 ("Matrix of size %d x %d is not square", rows, cols));
515 factor();
517 int i;
518 for (i = 0; i < m(); i++) {
519 assert_msg(pivot_cols[i],
520 ("Matrix is singular, col %d is not a pivot col", i));
523 fract_matrix result(0, cols); // Say rows are size 0 so no storage
524 result.rows = rows; // is allocated.
526 for (i = 0; i < rows; i++) {
527 fract_vector vec(rows);
528 vec[i] = fract(1);
530 apply_L(&vec, cols);
531 boolean valid;
532 fract_vector solve_col = backsolve_U(vec, -1, &valid);
533 assert(valid);
535 result.col_list[i] = solve_col;
538 return matrix(result);
541 vector_space matrix::kernel()
543 int rows = m();
544 int cols = n();
546 factor();
548 vector_space result(cols);
549 fract_vector zeros(rows);
551 for (int free_var = 0; free_var < cols; free_var++) {
552 // the kernel has a dimension for each non-pivot column
554 if (!pivot_cols[free_var]) {
555 boolean valid;
556 fract_vector solve_col = backsolve_U(zeros, free_var, &valid);
557 assert(valid);
558 valid = result.insert(solve_col);
559 assert(valid);
563 return result;
566 fract_vector matrix::particular_solution(const fract_vector &vec,
567 boolean *valid)
569 int cols = n();
570 factor();
572 fract_vector new_vec(vec);
574 apply_L(&new_vec, cols);
575 fract_vector result = backsolve_U(new_vec, -1, valid);
577 return result;
581 vector_space matrix::range()
583 return (vector_space(*this));
587 vector_space matrix::domain()
589 matrix trans = transpose();
590 return (vector_space(trans));
594 matrix matrix::operator%(integer_row &rw)
596 return matrix(elems % rw);
599 matrix matrix::del_row(int i, int j)
601 return matrix(elems.del_row(i,j));
605 matrix matrix::del_col(int i, int j)
607 return matrix(elems.del_col(i,j));
611 matrix matrix::del_columns(integer_row &rw)
613 return matrix(elems.del_columns(rw));
617 matrix matrix::insert_col(int i)
619 return matrix(elems.insert_col(i));
622 matrix matrix::swap_col(int i, int j)
624 return matrix(elems.swap_col(i,j));
627 matrix matrix::swap_row(int i, int j)
629 return matrix(elems.swap_row(i,j));
632 fract_vector matrix::get_row(int i)
634 return elems.get_row(i);
637 fract_vector matrix::get_col(int i)
639 return elems.get_col(i);
642 void matrix::set_col(int i, fract_vector &vec)
644 elems.set_col(i, vec);
645 factored = FALSE;
648 void matrix::set_row(int i, fract_vector &vec)
650 elems.set_row(i, vec);
651 factored = FALSE;
654 matrix matrix::resize_offset(int r1, int r2, int c1, int c2, int fill)
656 return matrix(elems.resize_offset(r1, r2, c1, c2, fill));
659 matrix matrix::r_merge(matrix &mat)
661 return matrix(elems.r_merge(mat.elems));
665 matrix matrix::c_merge(matrix &mat)
667 return matrix(elems.c_merge(mat.elems));
670 immed_list *matrix::cvt_immed_list()
672 return elems.cvt_immed_list();
675 matrix cvt_matrix(immed_list *il)
677 return matrix(cvt_fract_matrix(il));
680 void matrix::print(FILE *fp)
682 elems.print(fp);
685 void matrix::print_full(FILE *fp)
687 fprintf(fp, "Elements:\n");
688 elems.print(fp);
690 if (factored) {
691 fprintf(fp, "LU decomp:\n");
692 lu_decomp.print(fp);
694 assert(interchanges);
695 fprintf(fp, "Interchanges [");
696 int i;
697 for (i = 0; i < lu_decomp.m(); i++)
698 fprintf(fp, "%s%d", (i == 0 ? "" : " "), interchanges[i]);
699 fprintf(fp, "]\n");
701 assert(pivot_cols);
702 fprintf(fp, "Pivot Columns [");
703 for (i = 0; i < lu_decomp.n(); i++)
704 fprintf(fp, "%s%d", (i == 0 ? "" : " "), pivot_cols[i]);
705 fprintf(fp, "]\n");
708 else {
709 fprintf(fp, "Not factored\n");
713 void matrix::clear()
715 elems.clear();
716 clear_decomp();