3 /* Copyright (c) 1994,95 Stanford University
7 This software is provided under the terms described in
8 the "suif_copyright.h" include file. */
10 #include <suif_copyright.h>
12 #pragma implementation "matrix.h"
14 #define RCS_BASE_FILE matrix_cc
18 RCS_BASE("$Id: matrix.cc,v 1.1.1.1 1998/07/07 05:09:27 brm Exp $")
21 * Private member functions
24 void matrix::init_decomp()
26 lu_decomp
= fract_matrix();
32 void matrix::clear_decomp()
38 delete[] interchanges
;
48 void matrix::copy_decomp(const matrix
&mat
)
50 assert(factored
&& mat
.factored
);
52 lu_decomp
= mat
.lu_decomp
;
54 int rows
= lu_decomp
.m();
55 int cols
= lu_decomp
.n();
57 assert(mat
.interchanges
);
58 interchanges
= new int[rows
];
60 for (i
= 0; i
< rows
; i
++)
61 interchanges
[i
] = mat
.interchanges
[i
];
63 assert(mat
.pivot_cols
);
64 pivot_cols
= new boolean
[cols
];
65 for (i
= 0; i
< cols
; i
++)
66 pivot_cols
[i
] = mat
.pivot_cols
[i
];
69 void matrix::copy_matrix(const matrix
&mat
)
85 // Linear Algebra support
88 // Factors column 'j' in elems and places it in lu_decomp
90 void matrix::factor_col_and_insert(int j
)
92 assert(j
< elems
.n());
94 int rows
= elems
.m(); // number of rows in resulting decomp
95 int cols
= j
; // column being factored = number of prev columns
97 fract_vector
new_col(elems
.col_list
[j
]);
99 int cur_pivot_row
= 0;
100 for (int i
= 0; i
< cols
; i
++) // pivot row = number of pivot cols
101 cur_pivot_row
+= pivot_cols
[i
];
103 assert(cur_pivot_row
<= rows
);
105 int new_pivot_row
= factor_col(&new_col
, cols
, cur_pivot_row
);
106 assert((cur_pivot_row
<= new_pivot_row
&& new_pivot_row
< rows
) ||
107 (new_pivot_row
== rows
&& cur_pivot_row
== new_pivot_row
));
109 if (cur_pivot_row
< rows
) {
110 interchanges
[cur_pivot_row
] = new_pivot_row
;
112 if (new_pivot_row
!= cur_pivot_row
) {
114 // new_col has swapped rows, now swap prev cols
115 for(int i
= 0; i
< cols
; i
++) {
116 fract temp
= lu_decomp
.elt(new_pivot_row
, i
);
117 lu_decomp
.elt(new_pivot_row
, i
) = lu_decomp
.elt(cur_pivot_row
,i
);
118 lu_decomp
.elt(cur_pivot_row
, i
) = temp
;
122 pivot_cols
[cols
] = (new_col
[cur_pivot_row
] != fract(0));
124 else pivot_cols
[cols
] = FALSE
;
126 lu_decomp
.col_list
[j
] = new_col
;
130 // Factor the given fract_vector column. Returns the row that the pivot row
131 // is exchanged with.
133 int matrix::factor_col(fract_vector
*vec
, int cols
, int pivot_row
)
135 int rows
= elems
.m();
137 apply_L(vec
, cols
); // Apply the L we have so far to the column vec
139 if (pivot_row
== rows
) return rows
;
140 assert(pivot_row
< rows
);
142 int new_pivot_row
= pivot_row
;
144 fract max
= vec
->elt(pivot_row
);
146 if (max
== fract(0)) { // Find new pivot row to interchange
147 for (int i
= pivot_row
+1; i
< rows
; i
++) {
148 max
= vec
->elt(i
).abs();
150 if (max
> fract(0)) {
157 assert(pivot_row
<= new_pivot_row
&& new_pivot_row
< rows
);
159 if (new_pivot_row
!= pivot_row
) { // swap
160 fract temp
= vec
->elt(pivot_row
);
161 vec
->elt(pivot_row
) = vec
->elt(new_pivot_row
);
162 vec
->elt(new_pivot_row
) = temp
;
165 if (vec
->elt(pivot_row
) != fract(0)) {
167 if (pivot_row
== cols
) {
168 for (int i
= pivot_row
+1; i
< rows
; i
++) {
169 vec
->elt(i
) /= vec
->elt(pivot_row
);
173 assert(pivot_row
< cols
);
174 for (int i
= pivot_row
+1; i
< rows
; i
++) {
175 lu_decomp
.elt(i
,pivot_row
) = vec
->elt(i
) / vec
->elt(pivot_row
);
181 return new_pivot_row
;
184 // Apply the L matrix to the column.
186 void matrix::apply_L(fract_vector
*vec
, int cols
)
188 int rows
= elems
.m();
190 // Perform interchanges we've done so far
192 for (int i
= 0; i
< rows
; i
++) {
193 if (interchanges
[i
] != i
) {
194 assert(interchanges
[i
] > i
);
195 fract temp
= vec
->elt(i
);
196 vec
->elt(i
) = vec
->elt(interchanges
[i
]);
197 vec
->elt(interchanges
[i
]) = temp
;
201 // Operations on column as dictated by L.
203 for (int j
= 0; j
< cols
; j
++) {
204 for (int i
= j
+1; i
< rows
; i
++) {
205 vec
->elt(i
) -= vec
->elt(j
) * lu_decomp
.elt(i
,j
);
211 // Backsolve the column 'vec' using U.
212 // All free variables are assumed to be zero except the var specified by
215 fract_vector
matrix::backsolve_U(fract_vector
&vec
, int free_var
,
220 int rows
= lu_decomp
.m();
221 int cols
= lu_decomp
.n();
223 fract_vector
result(cols
);
225 // First examine all zero rows:
226 int last_row
= last_nonzero_Urow();
228 for (int ii
= last_row
+ 1; ii
< rows
; ii
++) {
229 if (vec
[ii
] != fract(0)) {
231 return result
; // All zeros
236 for (int j
= cols
-1; j
>= 0; j
--) {
237 if (!pivot_cols
[j
]) {
238 if(j
== free_var
) result
[j
] = fract(1);
239 else result
[j
] = fract(0);
245 for (int jj
= j
+1; jj
< cols
; jj
++) {
246 temp
-= lu_decomp
.elt(i
,jj
) * result
[jj
];
248 result
[j
] = temp
/lu_decomp
.elt(i
,j
);
257 int matrix::last_nonzero_Urow()
262 for (int i
= 0; i
< lu_decomp
.n(); i
++)
263 result
+= pivot_cols
[i
];
270 * Public member functions
279 elems
= fract_matrix();
283 matrix::matrix(int r
, int c
)
285 elems
= fract_matrix(r
, c
);
289 matrix::matrix(const integer_matrix
&mat
)
291 elems
= fract_matrix(mat
);
295 matrix::matrix(const integer_matrix
&mat
, int div
)
297 elems
= fract_matrix(mat
, div
);
301 matrix::matrix(const fract_matrix
&mat
)
307 matrix::matrix(const fract_matrix
&mat
, int c
)
309 elems
= fract_matrix(mat
, c
);
314 matrix::matrix(const matrix
&mat
)
320 matrix::matrix(const matrix
&mat
, int c
)
322 elems
= fract_matrix(mat
.elems
, c
);
328 // Equality and assignment operators:
331 boolean
matrix::operator==(const matrix
&mat
) const
333 return (elems
== mat
.elems
);
336 matrix
&matrix::operator=(const matrix
&mat
)
344 // Matrix-Matrix operations:
347 matrix
matrix::operator*(matrix
&mat
)
349 return matrix(elems
* mat
.elems
);
352 matrix
matrix::operator+(matrix
&mat
)
354 return matrix(elems
+ mat
.elems
);
357 matrix
matrix::operator-(matrix
&mat
)
359 return matrix(elems
- mat
.elems
);
362 matrix
&matrix::operator+=(matrix
&mat
)
370 matrix
&matrix::operator-=(matrix
&mat
)
380 // Element-wise operations:
382 matrix
matrix::operator*(fract val
)
384 return matrix(elems
* val
);
387 matrix
matrix::operator/(fract val
)
389 return matrix(elems
/ val
);
392 matrix
matrix::operator+(fract val
)
394 return matrix(elems
+ val
);
397 matrix
matrix::operator-(fract val
)
399 return matrix(elems
- val
);
402 matrix
&matrix::operator+=(fract val
)
410 matrix
&matrix::operator-=(fract val
)
418 matrix
&matrix::operator*=(fract val
)
426 matrix
&matrix::operator/=(fract val
)
438 fract_vector
matrix::operator*(fract_vector
&vec
)
440 return (elems
* vec
);
445 // Other useful functions:
447 void matrix::ident(int n
)
453 matrix
matrix::transpose()
455 return matrix(elems
.transpose());
460 // Linear Algebra Operations
463 boolean
matrix::is_pivot(int i
)
466 return pivot_cols
[i
];
473 int cols
= elems
.n();
476 for (int i
= 0; i
< cols
; i
++)
477 if (pivot_cols
[i
]) val
++;
483 void matrix::factor()
485 if (factored
) return;
491 lu_decomp
= fract_matrix(0, cols
); // Say rows are size 0 so no storage
492 lu_decomp
.rows
= rows
; // is allocated.
494 interchanges
= new int[rows
];
495 pivot_cols
= new boolean
[cols
];
497 for (int r
= 0; r
< rows
; r
++)
500 for (int j
= 0; j
< cols
; j
++) {
501 factor_col_and_insert(j
);
507 matrix
matrix::inverse()
509 int rows
= elems
.m();
510 int cols
= elems
.n();
512 assert_msg(rows
== cols
,
513 ("Matrix of size %d x %d is not square", rows
, cols
));
518 for (i
= 0; i
< m(); i
++) {
519 assert_msg(pivot_cols
[i
],
520 ("Matrix is singular, col %d is not a pivot col", i
));
523 fract_matrix
result(0, cols
); // Say rows are size 0 so no storage
524 result
.rows
= rows
; // is allocated.
526 for (i
= 0; i
< rows
; i
++) {
527 fract_vector
vec(rows
);
532 fract_vector solve_col
= backsolve_U(vec
, -1, &valid
);
535 result
.col_list
[i
] = solve_col
;
538 return matrix(result
);
541 vector_space
matrix::kernel()
548 vector_space
result(cols
);
549 fract_vector
zeros(rows
);
551 for (int free_var
= 0; free_var
< cols
; free_var
++) {
552 // the kernel has a dimension for each non-pivot column
554 if (!pivot_cols
[free_var
]) {
556 fract_vector solve_col
= backsolve_U(zeros
, free_var
, &valid
);
558 valid
= result
.insert(solve_col
);
566 fract_vector
matrix::particular_solution(const fract_vector
&vec
,
572 fract_vector
new_vec(vec
);
574 apply_L(&new_vec
, cols
);
575 fract_vector result
= backsolve_U(new_vec
, -1, valid
);
581 vector_space
matrix::range()
583 return (vector_space(*this));
587 vector_space
matrix::domain()
589 matrix trans
= transpose();
590 return (vector_space(trans
));
594 matrix
matrix::operator%(integer_row
&rw
)
596 return matrix(elems
% rw
);
599 matrix
matrix::del_row(int i
, int j
)
601 return matrix(elems
.del_row(i
,j
));
605 matrix
matrix::del_col(int i
, int j
)
607 return matrix(elems
.del_col(i
,j
));
611 matrix
matrix::del_columns(integer_row
&rw
)
613 return matrix(elems
.del_columns(rw
));
617 matrix
matrix::insert_col(int i
)
619 return matrix(elems
.insert_col(i
));
622 matrix
matrix::swap_col(int i
, int j
)
624 return matrix(elems
.swap_col(i
,j
));
627 matrix
matrix::swap_row(int i
, int j
)
629 return matrix(elems
.swap_row(i
,j
));
632 fract_vector
matrix::get_row(int i
)
634 return elems
.get_row(i
);
637 fract_vector
matrix::get_col(int i
)
639 return elems
.get_col(i
);
642 void matrix::set_col(int i
, fract_vector
&vec
)
644 elems
.set_col(i
, vec
);
648 void matrix::set_row(int i
, fract_vector
&vec
)
650 elems
.set_row(i
, vec
);
654 matrix
matrix::resize_offset(int r1
, int r2
, int c1
, int c2
, int fill
)
656 return matrix(elems
.resize_offset(r1
, r2
, c1
, c2
, fill
));
659 matrix
matrix::r_merge(matrix
&mat
)
661 return matrix(elems
.r_merge(mat
.elems
));
665 matrix
matrix::c_merge(matrix
&mat
)
667 return matrix(elems
.c_merge(mat
.elems
));
670 immed_list
*matrix::cvt_immed_list()
672 return elems
.cvt_immed_list();
675 matrix
cvt_matrix(immed_list
*il
)
677 return matrix(cvt_fract_matrix(il
));
680 void matrix::print(FILE *fp
)
685 void matrix::print_full(FILE *fp
)
687 fprintf(fp
, "Elements:\n");
691 fprintf(fp
, "LU decomp:\n");
694 assert(interchanges
);
695 fprintf(fp
, "Interchanges [");
697 for (i
= 0; i
< lu_decomp
.m(); i
++)
698 fprintf(fp
, "%s%d", (i
== 0 ? "" : " "), interchanges
[i
]);
702 fprintf(fp
, "Pivot Columns [");
703 for (i
= 0; i
< lu_decomp
.n(); i
++)
704 fprintf(fp
, "%s%d", (i
== 0 ? "" : " "), pivot_cols
[i
]);
709 fprintf(fp
, "Not factored\n");