Matrix: speed up __getitem__ and __setitem__ for element-wise access
[sympy.git] / sympy / matrices / matrices.py
blob2f8c5829b2932729492acae5cca85b346210408e
1 import warnings
2 from sympy import Basic, Symbol
3 from sympy.core import sympify
5 from sympy.core.basic import S, C
6 from sympy.polys import Poly, roots
7 from sympy.simplify import simplify
9 # from sympy.printing import StrPrinter /cyclic/
11 import random
class NonSquareMatrixException(Exception):
    """Raised when an operation that needs a square matrix gets a non-square one."""
class ShapeError(ValueError):
    """Wrong matrix shape.

    Derives from ValueError, so handlers for ValueError also catch it.
    """
class MatrixError(Exception):
    """Generic matrix error (e.g. bad construction arguments, non-square input)."""
def _dims_to_nm( dims ):
    """Converts dimensions tuple (or any object with length 1 or 2) or scalar
    in dims to matrix dimensions n and m.

    A scalar ``k`` or a 1-element sequence ``(k,)`` means a square k x k
    matrix; a 2-element sequence ``(n, m)`` means n x m.

    Raises ValueError if any dimension is not a positive int, or if dims
    has more than two entries.
    """
    try:
        count = len( dims )
    except TypeError:
        dims = (dims,)
        count = 1

    # This will work for nd-array too when they are added to sympy.
    # Explicit check instead of ``assert`` so validation survives ``python -O``;
    # the type test comes first so non-numbers never reach the comparison.
    for dim in dims:
        if not (isinstance(dim, int) and dim > 0):
            raise ValueError("Matrix dimensions should be positive integers!")

    if count == 2:
        n, m = dims
    elif count == 1:
        n, m = dims[0], dims[0]
    else:
        raise ValueError("Matrix dimensions should be a two-element tuple of ints or a single int!")

    return n, m
49 class Matrix(object):
51 # Added just for numpy compatibility
52 # TODO: investigate about __array_priority__
53 __array_priority__ = 10.0
def __init__(self, *args):
    """
    Matrix can be constructed with values or a rule.

    Accepted forms:
      Matrix(rows, cols, rule)       -- rule(i, j) gives each entry
      Matrix(rows, cols, flat_list)  -- row-major list of rows*cols items
      Matrix(other_matrix)           -- copy
      Matrix(obj_with___array__)     -- e.g. a NumPy array
      Matrix([[...], [...]]) or Matrix((...), (...))  -- list of rows
      Matrix([a, b, c])              -- a column vector
    Every entry is passed through sympify.

    >>> from sympy import *
    >>> Matrix( (1,2+I), (3,4) ) #doctest:+NORMALIZE_WHITESPACE
    [1, 2 + I]
    [3, 4]
    >>> Matrix(2, 2, lambda i,j: (i+1)*j ) #doctest:+NORMALIZE_WHITESPACE
    [0, 1]
    [0, 2]

    Note: in SymPy we count indices from 0, and the rule is called with
    0-based indices too (above, entry (0, 1) is (0+1)*1 = 1).
    """
    if len(args) == 3 and callable(args[2]):
        operation = args[2]
        assert isinstance(args[0], int) and isinstance(args[1], int)
        self.lines = args[0]
        self.cols = args[1]
        self.mat = []
        for i in range(self.lines):
            for j in range(self.cols):
                self.mat.append(sympify(operation(i, j)))
    elif len(args)==3 and isinstance(args[0],int) and \
            isinstance(args[1],int) and isinstance(args[2], (list, tuple)):
        self.lines=args[0]
        self.cols=args[1]
        mat = args[2]
        if len(mat) != self.lines*self.cols:
            raise MatrixError('List length should be equal to rows*columns')
        # a real list on every Python version (map() is lazy on Python 3)
        self.mat = [sympify(item) for item in mat]
    else:
        if len(args) == 1:
            mat = args[0]
        else:
            mat = args
        if isinstance(mat, Matrix):
            self.lines = mat.lines
            self.cols = mat.cols
            self.mat = mat[:]
            return
        elif hasattr(mat, "__array__"):
            # NumPy array or matrix or some other object that implements
            # __array__. So let's first use this method to get a
            # numpy.array() and then make a python list out of it.
            mat = list(mat.__array__())
        elif not isinstance(mat[0], (list, tuple)):
            # make each element a singleton
            mat = [ [element] for element in mat ]
        self.lines=len(mat)
        self.cols=len(mat[0])
        self.mat=[]
        for j in range(self.lines):
            assert len(mat[j])==self.cols
            for i in range(self.cols):
                self.mat.append(sympify(mat[j][i]))
def key2ij(self,key):
    """Converts key=(4,6) to 4,6 and ensures the key is correct.

    Raises TypeError unless key is a 2-element list or tuple, and
    IndexError if either index is outside the matrix (negative indices are
    rejected). Nothing is printed.
    """
    if not (isinstance(key,(list, tuple)) and len(key) == 2):
        raise TypeError("wrong syntax: a[%s]. Use a[i,j] or a[(i,j)]"
                %repr(key))
    i,j=key
    if not (i>=0 and i<self.lines and j>=0 and j < self.cols):
        raise IndexError("Index out of range: a[%s]"%repr(key))
    return i,j
def transpose(self):
    """
    Matrix transposition: return a new cols x lines Matrix.

    >>> from sympy import *
    >>> m=Matrix(((1,2+I),(3,4)))
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [1, 2 + I]
    [3, 4]
    >>> m.transpose() #doctest: +NORMALIZE_WHITESPACE
    [ 1, 3]
    [2 + I, 4]
    >>> m.T == m.transpose()
    True
    """
    # Walk the original column by column; each column becomes a result row.
    flat = [self.mat[r*self.cols + c]
            for c in range(self.cols)
            for r in range(self.lines)]
    return Matrix(self.cols, self.lines, flat)

T = property(fget=transpose, doc="Matrix transposition.")
def conjugate(self):
    """By-element conjugation: a new Matrix of each entry's conjugate()."""
    def _conj(i, j):
        return self[i, j].conjugate()
    return Matrix(self.lines, self.cols, _conj)

C = property(fget=conjugate, doc="By-element conjugation.")
@property
def H(self):
    """
    Hermite conjugation: the transpose of the element-wise conjugate.

    >>> from sympy import *
    >>> m=Matrix(((1,2+I),(3,4)))
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [1, 2 + I]
    [3, 4]
    >>> m.H #doctest: +NORMALIZE_WHITESPACE
    [ 1, 3]
    [2 - I, 4]
    """
    transposed = self.T
    return transposed.C
@property
def D(self):
    """Dirac conjugation: ``self.H * mgamma(0)``."""
    # imported at call time (presumably to avoid an import cycle with
    # sympy.physics -- TODO confirm)
    from sympy.physics.matrices import mgamma
    gamma0 = mgamma(0)
    return self.H * gamma0
def __getitem__(self,key):
    """
    Element, submatrix or flat-list access.

    ``a[i, j]`` with integer-like i and j (anything with ``__int__`` or
    ``__index__``) returns one element. Both indices must satisfy
    0 <= i < lines and 0 <= j < cols; negative values raise IndexError.
    ``a[i, j]`` with a slice in either position returns ``submatrix(key)``.
    ``a[k]`` or ``a[slice]`` indexes the row-major flat list ``self.mat``,
    with k converted by ``a2idx`` (defined elsewhere).

    >>> from sympy import *
    >>> m=Matrix(((1,2+I),(3,4)))
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [1, 2 + I]
    [3, 4]
    >>> m[1,0]
    3
    >>> m.H[1,0]
    2 - I
    """
    # ``type(...) is`` checks (not isinstance) keep this hot path cheap;
    # tuple/slice subclasses therefore take the flat-index branch.
    if type(key) is tuple:
        i, j = key
        if type(i) is slice or type(j) is slice:
            return self.submatrix(key)

        else:
            # a2idx inlined (presumably to save a function call per access)
            try:
                i = i.__int__()
            except AttributeError:
                try:
                    i = i.__index__()
                except AttributeError:
                    raise IndexError("Invalid index a[%r]" % (key,))

            # a2idx inlined
            try:
                j = j.__int__()
            except AttributeError:
                try:
                    j = j.__index__()
                except AttributeError:
                    raise IndexError("Invalid index a[%r]" % (key,))

            if not (i>=0 and i<self.lines and j>=0 and j < self.cols):
                raise IndexError("Index out of range: a[%s]" % (key,))
            else:
                # row-major storage
                return self.mat[i*self.cols + j]

    else:
        # row-wise decomposition of matrix
        if type(key) is slice:
            return self.mat[key]
        else:
            k = a2idx(key)
            if k is not None:
                return self.mat[k]
    # reached only when a2idx could not convert the key
    raise IndexError("Invalid index: a[%s]"%repr(key))
def __setitem__(self,key,value):
    """
    Assign one element, or copy a Matrix or nested list into a block.

    ``a[i, j] = v`` stores ``sympify(v)``; indices are checked as in
    ``__getitem__``. If either index is a slice, ``value`` must be a Matrix
    or a list/tuple (copied via copyin_matrix / copyin_list); any other
    value ends in the final IndexError. ``a[k] = v`` sets an entry of the
    row-major flat list ``self.mat``. Flat slices raise IndexError.

    >>> from sympy import *
    >>> m=Matrix(((1,2+I),(3,4)))
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [1, 2 + I]
    [3, 4]
    >>> m[1,0]=9
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [1, 2 + I]
    [9, 4]
    """
    if type(key) is tuple:
        i, j = key
        if type(i) is slice or type(j) is slice:
            if isinstance(value, Matrix):
                self.copyin_matrix(key, value)
                return
            if isinstance(value, (list, tuple)):
                self.copyin_list(key, value)
                return
            # any other value type falls through to the IndexError below
        else:
            # a2idx inlined (presumably to save a function call per access)
            try:
                i = i.__int__()
            except AttributeError:
                try:
                    i = i.__index__()
                except AttributeError:
                    raise IndexError("Invalid index a[%r]" % (key,))

            # a2idx inlined
            try:
                j = j.__int__()
            except AttributeError:
                try:
                    j = j.__index__()
                except AttributeError:
                    raise IndexError("Invalid index a[%r]" % (key,))

            if not (i>=0 and i<self.lines and j>=0 and j < self.cols):
                raise IndexError("Index out of range: a[%s]" % (key,))
            else:
                # row-major storage
                self.mat[i*self.cols + j] = sympify(value)
                return

    else:
        # row-wise decomposition of matrix
        if type(key) is slice:
            raise IndexError("Vector slices not implemented yet.")
        else:
            k = a2idx(key)
            if k is not None:
                self.mat[k] = sympify(value)
                return
    raise IndexError("Invalid index: a[%s]"%repr(key))
def __array__(self):
    """NumPy array-protocol hook; conversion is delegated to matrix2numpy."""
    converted = matrix2numpy(self)
    return converted
def tolist(self):
    """
    Return the Matrix as a list of row lists (each row a fresh list).

    >>> from sympy import *
    >>> m=Matrix(3, 3, range(9))
    >>> m
    [0, 1, 2]
    [3, 4, 5]
    [6, 7, 8]
    >>> m.tolist()
    [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """
    width = self.cols
    return [self.mat[r*width:(r+1)*width] for r in range(self.lines)]
def copyin_matrix(self, key, value):
    """Copy the Matrix ``value`` into the block of self selected by ``key``.

    key is a (row_key, col_key) pair; each part is a slice or an int, as
    accepted by slice2bounds. Every copied entry is sympified.

    Raises ShapeError if value's shape differs from the selected block.
    """
    rlo, rhi = self.slice2bounds(key[0], self.lines)
    clo, chi = self.slice2bounds(key[1], self.cols)
    # Explicit check rather than ``assert``: under ``python -O`` a mismatch
    # would otherwise silently copy only part of the block.
    if value.lines != rhi - rlo or value.cols != chi - clo:
        raise ShapeError("Cannot assign a %dx%d matrix to a %dx%d block"
                         % (value.lines, value.cols, rhi - rlo, chi - clo))
    for i in range(value.lines):
        for j in range(value.cols):
            self[i+rlo, j+clo] = sympify(value[i,j])
def copyin_list(self, key, value):
    """Copy a (nested) list/tuple into the block selected by key.

    The list is first turned into a Matrix, then handed to copyin_matrix.
    """
    assert isinstance(value, (list, tuple))
    as_matrix = Matrix(value)
    self.copyin_matrix(key, as_matrix)
def hash(self):
    """Hash of the printed form, recomputed on every call.

    Not cached because the matrix is mutable and its elements may change.
    """
    printed = self.__str__()
    return hash(printed)
@property
def shape(self):
    """The (rows, columns) pair of the matrix."""
    rows, columns = self.lines, self.cols
    return (rows, columns)
def __rmul__(self,a):
    """Left multiplication ``a * self``.

    Anything exposing ``__array__`` (Matrix included) goes through
    matrix_multiply. Otherwise ``a`` is treated as a scalar and multiplied
    into every element from the left.
    """
    if not hasattr(a, "__array__"):
        scaled = [a*elem for elem in self.mat]
        return Matrix(self.lines, self.cols, scaled)
    return matrix_multiply(a, self)
def expand(self):
    """Return a new Matrix with ``expand()`` applied to every element."""
    expanded = [elem.expand() for elem in self.mat]
    return Matrix(self.lines, self.cols, expanded)

def combine(self):
    """Return a new Matrix with ``combine()`` applied to every element."""
    combined = [elem.combine() for elem in self.mat]
    return Matrix(self.lines, self.cols, combined)

def subs(self, *args):
    """Return a new Matrix with ``subs(*args)`` applied to every element."""
    substituted = [elem.subs(*args) for elem in self.mat]
    return Matrix(self.lines, self.cols, substituted)
def __sub__(self,a):
    """Return ``self - a``, computed as ``self + (-a)``."""
    negated = -a
    return self + negated
def __mul__(self,a):
    """Right multiplication ``self * a``.

    Array-like ``a`` (has ``__array__``) is a matrix product via
    matrix_multiply; otherwise every element is multiplied by ``a`` on the
    right.
    """
    if not hasattr(a, "__array__"):
        scaled = [elem*a for elem in self.mat]
        return Matrix(self.lines, self.cols, scaled)
    return matrix_multiply(self, a)
def __pow__(self, num):
    """Raise a square matrix to an integer power.

    Negative exponents invert first: ``A**-n == (A.inv())**n``.
    ``A**0`` is the identity. Uses binary exponentiation, so only
    O(log n) matrix products are needed.

    Raises NonSquareMatrixException for non-square matrices and
    NotImplementedError for non-integer exponents.
    """
    if not self.is_square:
        raise NonSquareMatrixException()
    if isinstance(num, int) or isinstance(num, Integer):
        n = int(num)
        if n < 0:
            return self.inv() ** -n # A**-2 = (A**-1)**2
        result = eye(self.cols)
        base = self # repeatedly squared; ``self`` itself is never rebound
        while n:
            if n % 2:
                result = result * base
            n //= 2
            if n:
                # skip the final, unused squaring (expensive for symbolic entries)
                base = base * base
        return result
    raise NotImplementedError('Can only raise to the power of an integer for now')
def __add__(self,a):
    """Return ``self + a`` (delegates to matrix_add)."""
    total = matrix_add(self, a)
    return total

def __radd__(self,a):
    """Return ``a + self`` (delegates to matrix_add)."""
    total = matrix_add(a, self)
    return total

def __div__(self,a):
    """Return ``self / a`` as ``self`` times the sympy reciprocal ``S.One/a``."""
    reciprocal = S.One/a
    return self * reciprocal

def __truediv__(self,a):
    """True division; identical to __div__."""
    return self.__div__(a)

def multiply(self,b):
    """Returns self*b (matrix product via matrix_multiply)."""
    product = matrix_multiply(self, b)
    return product

def add(self,b):
    """Return self+b (via matrix_add)."""
    total = matrix_add(self, b)
    return total

def __neg__(self):
    """Return ``-self``, computed as ``-1*self``."""
    minus_one = -1
    return minus_one*self
def __eq__(self, a):
    """Equality by comparing string-based hashes.

    Operands that are neither Matrix nor Basic are sympified first. The
    result is True only when the (possibly sympified) operand is a Matrix
    whose ``hash()`` matches.
    """
    if not isinstance(a, (Matrix, Basic)):
        a = sympify(a)
    return isinstance(a, Matrix) and self.hash() == a.hash()
def __ne__(self,a):
    """Inequality; the exact negation of __eq__.

    A non-Matrix operand (after sympifying) is always unequal.
    """
    if not isinstance(a, (Matrix, Basic)):
        a = sympify(a)
    if not isinstance(a, Matrix):
        return True
    return self.hash() != a.hash()
def _format_str(self, strfunc, rowsep='\n'):
    """Render as bracketed, comma-separated rows joined by rowsep.

    strfunc turns each element into a string. Each column is right-justified
    to the width of its widest entry so the columns line up.
    """
    # string form of every element, row by row
    cells = [[strfunc(self[i, j]) for j in range(self.cols)]
             for i in range(self.lines)]
    # widest entry per column (0 for an empty column)
    widths = [max([0] + [len(row[j]) for row in cells])
              for j in range(self.cols)]
    rendered = ["[" + ", ".join([cell.rjust(w) for cell, w in zip(row, widths)]) + "]"
                for row in cells]
    return rowsep.join(rendered)
def __str__(self):
    """Printable form, produced by StrPrinter."""
    text = StrPrinter.doprint(self)
    return text

def __repr__(self):
    """Same text as the printable form (also produced by StrPrinter)."""
    text = StrPrinter.doprint(self)
    return text
def inv(self, method="GE"):
    """
    Calculates the matrix inverse.

    According to the "method" parameter, it calls the appropriate method:

      GE .... inverse_GE()
      LU .... inverse_LU()
      ADJ ... inverse_ADJ()

    Any other value raises Exception("Inversion method unrecognized").
    """
    assert self.cols==self.lines
    # names are compared in order; the solver is looked up only when chosen
    for name, attr in (("GE", "inverse_GE"),
                       ("LU", "inverse_LU"),
                       ("ADJ", "inverse_ADJ")):
        if method == name:
            return getattr(self, attr)()
    raise Exception("Inversion method unrecognized")
def __mathml__(self):
    """Content-MathML ``<matrix>`` with one ``<matrixrow>`` per row."""
    rows = []
    for i in range(self.lines):
        cells = "".join([self[i, j].__mathml__() for j in range(self.cols)])
        rows.append("<matrixrow>" + cells + "</matrixrow>")
    return "<matrix>" + "".join(rows) + "</matrix>"
def row(self, i, f):
    """Elementary row operation in place: self[i, j] = f(self[i, j], j)."""
    for col_idx in range(self.cols):
        current = self[i, col_idx]
        self[i, col_idx] = f(current, col_idx)

def col(self, j, f):
    """Elementary column operation in place: self[i, j] = f(self[i, j], i)."""
    for row_idx in range(self.lines):
        current = self[row_idx, j]
        self[row_idx, j] = f(current, row_idx)

def row_swap(self, i, j):
    """Exchange rows i and j in place."""
    for k in range(self.cols):
        upper, lower = self[i, k], self[j, k]
        self[i, k] = lower
        self[j, k] = upper

def col_swap(self, i, j):
    """Exchange columns i and j in place."""
    for k in range(self.lines):
        left, right = self[k, i], self[k, j]
        self[k, i] = right
        self[k, j] = left
def row_del(self, i):
    """Delete row i in place (self.mat is replaced by a new list)."""
    start = i*self.cols
    stop = start + self.cols
    self.mat = self.mat[:start] + self.mat[stop:]
    self.lines -= 1
def col_del(self, i):
    """
    Delete column i in place.

    >>> import sympy
    >>> M = sympy.matrices.eye(3)
    >>> M.col_del(1)
    >>> M #doctest: +NORMALIZE_WHITESPACE
    [1, 0]
    [0, 0]
    [0, 1]
    """
    # Delete from the last row upward so earlier flat positions stay valid.
    for r in reversed(range(self.lines)):
        del self.mat[r*self.cols + i]
    self.cols -= 1
def row_join(self, rhs):
    """
    Concatenates two matrices along self's last and rhs's first column

    >>> from sympy import *
    >>> M = Matrix(3,3,lambda i,j: i+j)
    >>> V = Matrix(3,1,lambda i,j: 3+i+j)
    >>> M.row_join(V)
    [0, 1, 2, 3]
    [1, 2, 3, 4]
    [2, 3, 4, 5]
    """
    assert self.lines == rhs.lines
    width = self.cols
    joined = self.zeros((self.lines, width + rhs.cols))
    joined[:, :width] = self[:,:]
    joined[:, width:] = rhs
    return joined
def col_join(self, bott):
    """
    Concatenates two matrices along self's last and bott's first row

    >>> from sympy import *
    >>> M = Matrix(3,3,lambda i,j: i+j)
    >>> V = Matrix(1,3,lambda i,j: 3+i+j)
    >>> M.col_join(V)
    [0, 1, 2]
    [1, 2, 3]
    [2, 3, 4]
    [3, 4, 5]
    """
    assert self.cols == bott.cols
    height = self.lines
    stacked = self.zeros((height + bott.lines, self.cols))
    stacked[:height, :] = self[:,:]
    stacked[height:, :] = bott
    return stacked
def row_insert(self, pos, mti):
    """
    Return a new matrix with the rows of ``mti`` inserted before row ``pos``.

    >>> from sympy import *
    >>> M = Matrix(3,3,lambda i,j: i+j)
    >>> M
    [0, 1, 2]
    [1, 2, 3]
    [2, 3, 4]
    >>> V = zeros((1, 3))
    >>> V
    [0, 0, 0]
    >>> M.row_insert(1,V)
    [0, 1, 2]
    [0, 0, 0]
    [1, 2, 3]
    [2, 3, 4]
    """
    # ``==`` not ``is``: identity with an int literal only works through
    # CPython's small-int cache and never for a long or sympy Integer zero.
    if pos == 0:
        return mti.col_join(self)
    assert self.cols == mti.cols
    newmat = self.zeros((self.lines + mti.lines, self.cols))
    newmat[:pos,:] = self[:pos,:]
    newmat[pos:pos+mti.lines,:] = mti[:,:]
    newmat[pos+mti.lines:,:] = self[pos:,:]
    return newmat
def col_insert(self, pos, mti):
    """
    Return a new matrix with the columns of ``mti`` inserted before column ``pos``.

    >>> from sympy import *
    >>> M = Matrix(3,3,lambda i,j: i+j)
    >>> M
    [0, 1, 2]
    [1, 2, 3]
    [2, 3, 4]
    >>> V = zeros((3, 1))
    >>> V
    [0]
    [0]
    [0]
    >>> M.col_insert(1,V)
    [0, 0, 1, 2]
    [1, 0, 2, 3]
    [2, 0, 3, 4]
    """
    # ``==`` not ``is``: identity with an int literal only works through
    # CPython's small-int cache and never for a long or sympy Integer zero.
    if pos == 0:
        return mti.row_join(self)
    assert self.lines == mti.lines
    newmat = self.zeros((self.lines, self.cols + mti.cols))
    newmat[:,:pos] = self[:,:pos]
    newmat[:,pos:pos+mti.cols] = mti[:,:]
    newmat[:,pos+mti.cols:] = self[:,pos:]
    return newmat
def trace(self):
    """Sum of the diagonal entries (square matrices only), starting from 0."""
    assert self.cols == self.lines
    diagonal = [self[k, k] for k in range(self.cols)]
    return sum(diagonal, 0)
def submatrix(self, keys):
    """
    Return the block selected by a (row_key, col_key) pair, at least one
    of which is a slice (the other may be an int).

    Raises IndexError if the selection falls outside the matrix.

    >>> from sympy import *
    >>> m = Matrix(4,4,lambda i,j: i+j)
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [0, 1, 2, 3]
    [1, 2, 3, 4]
    [2, 3, 4, 5]
    [3, 4, 5, 6]
    >>> m[0:1, 1] #doctest: +NORMALIZE_WHITESPACE
    [1]
    >>> m[0:2, 0:1] #doctest: +NORMALIZE_WHITESPACE
    [0]
    [1]
    >>> m[2:4, 2:4] #doctest: +NORMALIZE_WHITESPACE
    [4, 5]
    [5, 6]
    """
    assert isinstance(keys[0], slice) or isinstance(keys[1], slice)
    rlo, rhi = self.slice2bounds(keys[0], self.lines)
    clo, chi = self.slice2bounds(keys[1], self.cols)
    # Upper bounds are checked too: a column stop past the last column
    # would otherwise read into the next row of the flat storage and
    # silently return wrong elements.
    if not (0 <= rlo <= rhi <= self.lines and 0 <= clo <= chi <= self.cols):
        raise IndexError("Slice indices out of range: a[%s]"%repr(keys))
    outLines, outCols = rhi-rlo, chi-clo
    outMat = []
    for r in range(rlo, rhi):
        rowStart = r*self.cols
        outMat.extend(self.mat[rowStart+clo:rowStart+chi])
    return Matrix(outLines,outCols,outMat)
def slice2bounds(self, key, defmax):
    """
    Turn a slice or an int into a half-open (lo, hi) range for iteration.

    defmax is used for a missing slice stop (so ':' gives (0, defmax)).
    Negative starts/stops and negative ints count back from defmax.
    The slice step is ignored. Any other key type raises IndexError.
    """
    if isinstance(key, slice):
        def _resolve(bound, default):
            if bound is None:
                return default
            return bound if bound >= 0 else defmax + bound
        return _resolve(key.start, 0), _resolve(key.stop, defmax)
    if isinstance(key, int):
        lo = key if key >= 0 else defmax + key
        return lo, lo + 1
    raise IndexError("Improper index type")
def applyfunc(self, f):
    """
    Return a new Matrix with f applied to every element.

    >>> from sympy import *
    >>> m = Matrix(2,2,lambda i,j: i*2+j)
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [0, 1]
    [2, 3]
    >>> m.applyfunc(lambda i: 2*i) #doctest: +NORMALIZE_WHITESPACE
    [0, 2]
    [4, 6]
    """
    assert callable(f)
    mapped = [f(elem) for elem in self.mat]
    return Matrix(self.lines, self.cols, mapped)
def evalf(self):
    """Return a new Matrix with ``Basic.evalf`` applied to every element."""
    evaluator = Basic.evalf
    return self.applyfunc(evaluator)
def reshape(self, _rows, _cols):
    """
    Return a new _rows x _cols matrix holding the same elements in
    row-major order.

    Raises ShapeError if _rows*_cols differs from the number of elements.

    >>> from sympy import *
    >>> m = Matrix(2,3,lambda i,j: 1)
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [1, 1, 1]
    [1, 1, 1]
    >>> m.reshape(1,6) #doctest: +NORMALIZE_WHITESPACE
    [1, 1, 1, 1, 1, 1]
    >>> m.reshape(3,2) #doctest: +NORMALIZE_WHITESPACE
    [1, 1]
    [1, 1]
    [1, 1]
    """
    if self.lines*self.cols != _rows*_cols:
        # Fail loudly: continuing would either drop elements or end in an
        # unrelated IndexError from the element rule below.
        raise ShapeError("Invalid reshape parameters %d %d" % (_rows, _cols))
    return Matrix(_rows, _cols, lambda i,j: self.mat[i*_cols + j])
def print_nonzero (self, symb="X"):
    """
    Print one bracketed line per row: symb where the entry is nonzero,
    a space where it equals 0. Useful for a quick look at the sparsity
    pattern.

    >>> from sympy import *
    >>> m = Matrix(2,3,lambda i,j: i*3+j)
    >>> m #doctest: +NORMALIZE_WHITESPACE
    [0, 1, 2]
    [3, 4, 5]
    >>> m.print_nonzero() #doctest: +NORMALIZE_WHITESPACE
    [ XX]
    [XXX]
    >>> m = matrices.eye(4)
    >>> m.print_nonzero("x") #doctest: +NORMALIZE_WHITESPACE
    [x   ]
    [ x  ]
    [  x ]
    [   x]
    """
    lines = []
    for i in range(self.lines):
        marks = [" " if self[i, j] == 0 else symb for j in range(self.cols)]
        lines.append("[" + "".join(marks) + "]\n")
    print("".join(lines))
def LUsolve(self, rhs):
    """
    Solve the linear system Ax = b.
    self is the coefficient matrix A and rhs is the right side b.

    Uses the combined LU factors from LUdecomposition_Simple: first the row
    swaps are applied to a copy of rhs, then forward substitution with the
    unit-diagonal L, then back substitution with U. Returns the solution
    (same shape as rhs); rhs is not modified.
    """
    assert rhs.lines == self.lines
    A, perm = self.LUdecomposition_Simple()
    n = self.lines
    b = rhs.permuteFwd(perm)

    def eliminate(target, source):
        # row target of b -= A[target, source] * row source of b
        factor = A[target, source]
        b.row(target, lambda x, k: x - b[source, k]*factor)

    # forward substitution, all diag entries are scaled to 1
    for i in range(n):
        for j in range(i):
            eliminate(i, j)
    # backward substitution
    for i in reversed(range(n)):
        for j in range(i+1, n):
            eliminate(i, j)
        pivot = A[i, i]
        b.row(i, lambda x, k: x / pivot)
    return b
def LUdecomposition(self):
    """
    Returns (L, U, p): L is lower triangular with a unit diagonal, U is
    upper triangular, and p is the list of row swaps, all split out of
    the combined factor returned by LUdecomposition_Simple.
    """
    combined, p = self.LUdecomposition_Simple()
    n = self.lines
    L = self.zeros(n)
    U = self.zeros(n)
    for i in range(n):
        L[i, i] = 1
        for j in range(n):
            if i > j:
                L[i, j] = combined[i, j]
            else:
                U[i, j] = combined[i, j]
    return L, U, p
def LUdecomposition_Simple(self):
    """
    Returns A composed of L,U (L's diag entries are 1) and
    p which is the list of the row swaps (in order).

    The strictly lower part of the returned matrix holds L and the upper
    part (diagonal included) holds U. Each swap is recorded as
    [pivot_row, j].

    Raises ValueError if no nonzero pivot exists in some column
    (non-invertible matrix).
    """
    assert self.lines == self.cols
    n = self.lines
    A = self[:,:]
    p = []
    # factorization
    for j in range(n):
        for i in range(j):
            for k in range(i):
                A[i,j] = A[i,j] - A[i,k]*A[k,j]
        pivot = -1
        for i in range(j,n):
            for k in range(j):
                A[i,j] = A[i,j] - A[i,k]*A[k,j]
            # find the first non-zero pivot, includes any expression
            if pivot == -1 and A[i,j] != 0:
                pivot = i
        if pivot < 0:
            # a real exception class: raising a bare string is not valid
            raise ValueError("Error: non-invertible matrix passed to LUdecomposition_Simple()")
        if pivot != j: # row must be swapped
            A.row_swap(pivot,j)
            p.append([pivot,j])
        assert A[j,j] != 0
        scale = 1 / A[j,j]
        for i in range(j+1,n):
            A[i,j] = A[i,j] * scale
    return A, p
def LUdecompositionFF(self):
    """
    Returns 4 matrices P, L, D, U such that PA = L D**-1 U.

    From the paper "fraction-free matrix factors..." by Zhou and Jeffrey

    P is a permutation matrix recording the row swaps. D is diagonal but is
    stored as a full matrix.

    Raises ValueError("Matrix is not full rank") when some column has no
    nonzero pivot candidate on or below the diagonal.
    """
    n, m = self.lines, self.cols
    U, L, P = self[:,:], eye(n), eye(n)
    DD = zeros(n) # store it smarter since it's just diagonal
    oldpivot = 1

    for k in range(n-1):
        if U[k,k] == 0:
            # look below the diagonal for a row with a nonzero entry in column k
            kpivot = k+1
            Notfound = True
            while kpivot < n and Notfound:
                if U[kpivot, k] != 0:
                    Notfound = False
                else:
                    kpivot = kpivot + 1
            # The search ends with kpivot == n when nothing is found, so
            # test the flag (a comparison with n+1 can never be true).
            if Notfound:
                raise ValueError("Matrix is not full rank")
            # Rows k and kpivot of U are already zero left of column k, so
            # swapping whole rows is the same as swapping U[., k:]. P must
            # swap whole rows to stay a permutation matrix, and the
            # multipliers already stored in L's first k columns must move
            # with their rows so that PA = L D**-1 U keeps holding.
            U.row_swap(k, kpivot)
            P.row_swap(k, kpivot)
            for c in range(k):
                L[k, c], L[kpivot, c] = L[kpivot, c], L[k, c]
        assert U[k, k] != 0
        L[k,k] = U[k,k]
        DD[k,k] = oldpivot * U[k,k]
        assert DD[k,k] != 0
        Ukk = U[k,k]
        for i in range(k+1, n):
            L[i,k] = U[i,k]
            Uik = U[i,k]
            for j in range(k+1, m):
                # fraction-free update: the division by the previous pivot is exact
                U[i,j] = (Ukk * U[i,j] - U[k,j]*Uik) / oldpivot
            U[i,k] = 0
        oldpivot = U[k,k]
    DD[n-1,n-1] = oldpivot
    return P, L, DD, U
def cofactorMatrix(self, method="berkowitz"):
    """Matrix whose (i, j) entry is ``self.cofactor(i, j, method)``."""
    def _entry(i, j):
        return self.cofactor(i, j, method)
    return Matrix(self.lines, self.cols, _entry)

def minorEntry(self, i, j, method="berkowitz"):
    """Determinant (via ``det(method)``) of self without row i and column j."""
    assert 0 <= i < self.lines and 0 <= j < self.cols
    reduced = self.minorMatrix(i, j)
    return reduced.det(method)

def minorMatrix(self, i, j):
    """Copy of self with row i and column j removed."""
    assert 0 <= i < self.lines and 0 <= j < self.cols
    return self.delRowCol(i, j)

def cofactor(self, i, j, method="berkowitz"):
    """Signed minor: the (i, j) minor, negated when i + j is odd."""
    minor = self.minorEntry(i, j, method)
    if (i + j) % 2:
        return -1 * minor
    return minor
def jacobian(self, varlist):
    """
    Calculates the Jacobian matrix (derivative of a vectorial function).

    self is a vector of expression representing functions f_i(x_1, ...,
    x_n). varlist is the set of x_i's in order.

    self must be a single row. varlist may be a one-row Matrix, a list or a
    tuple. Returns an m x n matrix J with J[i, j] = d f_i / d x_j.

    Raises TypeError for any other varlist type, or when an element of self
    has no ``diff`` method.
    """
    assert self.lines == 1
    m = self.cols
    if isinstance(varlist, Matrix):
        assert varlist.lines == 1
        n = varlist.cols
    elif isinstance(varlist, (list, tuple)):
        n = len(varlist)
    else:
        # without this, ``n`` would be unbound below
        raise TypeError("varlist must be a Matrix, list or tuple")
    assert n > 0 # need to diff by something
    J = self.zeros((m, n)) # maintain subclass type
    for i in range(m):
        if isinstance(self[i], (float, int)):
            continue # constant function, jacobian row is zero
        try:
            tmp = self[i].diff(varlist[0]) # check differentiability
        except AttributeError:
            # a real exception class: raising a bare string is not valid
            raise TypeError("Function %d is not differentiable" % i)
        J[i,0] = tmp
        for j in range(1,n):
            J[i,j] = self[i].diff(varlist[j])
    return J
def QRdecomposition(self):
    """
    Return (Q, R) with self == Q*R, where Q is orthogonal and R is upper
    triangular, computed by classical Gram-Schmidt.

    Assumes full-rank square (for now).
    """
    assert self.lines == self.cols
    n = self.lines
    Q, R = self.zeros(n), self.zeros(n)
    for j in range(n): # for each column vector
        tmp = self[:,j] # take original v
        for i in range(j):
            # subtract the projection of self's column on the new vector
            tmp -= Q[:,i] * self[:,j].dot(Q[:,i])
            # Matrix.expand() returns a new Matrix; the result must be
            # rebound or the call has no effect.
            tmp = tmp.expand()
        # normalize it
        R[j,j] = tmp.norm()
        Q[:,j] = tmp / R[j,j]
        assert Q[:,j].norm() == 1
        for i in range(j):
            R[i,j] = Q[:,i].dot(self[:,j])
    return Q,R

# TODO: QRsolve
# Utility functions
def simplify(self):
    """Simplify every element in place, using sympy's module-level simplify()."""
    for pos, elem in enumerate(self.mat):
        self.mat[pos] = simplify(elem)
916 #def evaluate(self): # no more eval() so should be removed
917 # for i in range(self.lines):
918 # for j in range(self.cols):
919 # self[i,j] = self[i,j].eval()
def cross(self, b):
    """Cross product of two 3-element vectors, returned as a 1 x 3 Matrix.

    self must be 1 x 3 or 3 x 1. b may be such a Matrix or a 3-element
    list/tuple. Raises ShapeError if either operand has the wrong shape.
    """
    assert isinstance(b, (list, tuple, Matrix))

    def _is_3vector(v):
        # plain sequences have no lines/cols attributes
        if isinstance(v, (list, tuple)):
            return len(v) == 3
        return (v.lines, v.cols) in ((1, 3), (3, 1))

    # Both operands must be checked. The previous expression negated only
    # the first test, so b was validated only when self was already wrong.
    if not (_is_3vector(self) and _is_3vector(b)):
        raise ShapeError("Dimensions incorrect for cross product")
    return Matrix(1,3,((self[1]*b[2] - self[2]*b[1]),
                       (self[2]*b[0] - self[0]*b[2]),
                       (self[0]*b[1] - self[1]*b[0])))
def dot(self, b):
    """Sum of self[k]*b[k] over the flattened (row-major) elements.

    b may be a Matrix or a list/tuple; both sides must hold the same number
    of elements.
    """
    assert isinstance(b, (list, tuple, Matrix))
    if isinstance(b, (list, tuple)):
        size = len(b)
    else:
        size = b.lines * b.cols
    assert self.cols*self.lines == size
    return sum([self[k] * b[k] for k in range(size)], 0)
def norm(self):
    """Euclidean norm of a row or column vector.

    Computed as sqrt(sum of x*x). Entries are not conjugated.
    """
    assert self.lines == 1 or self.cols == 1
    squares = sympify(0)
    for k in range(self.lines * self.cols):
        entry = self[k]
        squares += entry*entry
    return squares**S.Half
def normalized(self):
    """Return a new vector equal to self divided element-wise by its norm()."""
    assert self.lines == 1 or self.cols == 1
    length = self.norm()
    return self.applyfunc(lambda entry: entry / length)
def project(self, v):
    """Project onto v: ``v * (self.dot(v) / v.dot(v))``."""
    coefficient = self.dot(v) / v.dot(v)
    return v * coefficient
def permuteBkwd(self, perm):
    """Copy of self with the row swaps in perm applied last-to-first (undoing them)."""
    copy = self[:,:]
    for swap in reversed(perm):
        copy.row_swap(swap[0], swap[1])
    return copy

def permuteFwd(self, perm):
    """Copy of self with the row swaps in perm applied in order."""
    copy = self[:,:]
    for swap in perm:
        copy.row_swap(swap[0], swap[1])
    return copy
def delRowCol(self, i, j):
    """Copy of self without row i and column j; self is left unchanged."""
    # used only for cofactors
    reduced = self[:,:]
    reduced.row_del(i)
    reduced.col_del(j)
    return reduced
def zeronm(self, n, m):
    """Deprecated n x m zero matrix; warns and suggests zeros()."""
    # used so that certain functions above can use this
    # then only this func need be overloaded in subclasses
    warnings.warn( 'Deprecated: use zeros() instead.' )
    entries = [S.Zero]*n*m
    return Matrix(n, m, entries)

def zero(self, n):
    """Deprecated n x n zero matrix; warns and suggests zeros()."""
    warnings.warn( 'Deprecated: use zeros() instead.' )
    entries = [S.Zero]*n*n
    return Matrix(n, n, entries)

def zeros(self, dims):
    """Zero matrix; dims is an int k (k x k) or a pair (d1, d2)."""
    rows, cols = _dims_to_nm( dims )
    entries = [S.Zero]*rows*cols
    return Matrix(rows, cols, entries)

def eye(self, n):
    """Returns the identity matrix of size n."""
    identity = self.zeros(n)
    for k in range(identity.lines):
        identity[k, k] = S.One
    return identity
@property
def is_square(self):
    """True when the number of rows equals the number of columns."""
    rows, cols = self.lines, self.cols
    return rows == cols
def is_upper(self):
    """True if every entry below the main diagonal (row > column) is zero.

    Valid for non-square matrices too: rows are walked over self.lines and
    columns over self.cols.
    """
    for i in range(self.lines):
        # only the entries strictly left of the diagonal in row i
        for j in range(min(i, self.cols)):
            if self[i, j] != 0:
                return False
    return True
def is_lower(self):
    """True if every entry above the main diagonal (row < column) is zero.

    Valid for non-square matrices too: rows are walked over self.lines and
    columns over self.cols.
    """
    for i in range(self.lines):
        # only the entries strictly right of the diagonal in row i
        for j in range(i + 1, self.cols):
            if self[i, j] != 0:
                return False
    return True
def is_symbolic(self):
    """True if any element contains a Symbol (checked via ``atoms(Symbol)``).

    Scans the flat element list, so it works for any shape. (Indexing
    self[i, j] with i over the columns failed on non-square matrices.)
    """
    for elem in self.mat:
        if elem.atoms(Symbol):
            return True
    return False
def clone(self):
    """Element-by-element copy of self as a new Matrix."""
    def _entry(row, col):
        return self[row, col]
    return Matrix(self.lines, self.cols, _entry)
def det(self, method="bareis"):
    """
    Computes the matrix determinant using the method "method".

    Possible values for "method":
      bareis ... det_bareis
      berkowitz ... berkowitz_det

    Any other value raises Exception("Determinant method unrecognized").
    """
    if method == "bareis":
        return self.det_bareis()
    if method == "berkowitz":
        return self.berkowitz_det()
    raise Exception("Determinant method unrecognized")
def det_bareis(self):
    """Compute matrix determinant using Bareis' fraction-free
    algorithm which is an extension of the well known Gaussian
    elimination method. This approach is best suited for dense
    symbolic matrices and will result in a determinant with
    minimal number of fractions. It means that less term
    rewriting is needed on resulting formulae.

    Works on a copy (self[:,:]), so self is not modified. Raises
    NonSquareMatrixException for non-square input. Returns S.Zero as
    soon as a column has no nonzero pivot candidate; otherwise returns
    the expanded determinant.

    TODO: Implement algorithm for sparse matrices (SFF).
    """
    if not self.is_square:
        raise NonSquareMatrixException()

    M, n = self[:,:], self.lines

    if n == 1:
        det = M[0, 0]
    elif n == 2:
        det = M[0, 0]*M[1, 1] - M[0, 1]*M[1, 0]
    else:
        sign = 1 # track current sign in case of row swap

        for k in range(n-1):
            # look for a pivot in the current column
            # and assume det == 0 if none is found
            if M[k, k] == 0:
                for i in range(k+1, n):
                    if M[i, k] != 0:
                        M.row_swap(i, k)
                        sign *= -1
                        break
                else:
                    # for-else: the loop found no nonzero entry below the diagonal
                    return S.Zero

            # proceed with Bareis' fraction-free (FF)
            # form of Gaussian elimination algorithm
            for i in range(k+1, n):
                for j in range(k+1, n):
                    D = M[k, k]*M[i, j] - M[i, k]*M[k, j]

                    # divide by the previous pivot (the division is exact)
                    if k > 0:
                        D /= M[k-1, k-1]

                    # Poly.cancel simplifies the quotient for non-atomic results
                    if D.is_Atom:
                        M[i, j] = D
                    else:
                        M[i, j] = Poly.cancel(D)

        # the last pivot of the fraction-free elimination is the determinant
        det = sign * M[n-1, n-1]

    return det.expand()
def adjugate(self, method="berkowitz"):
    """
    Returns the adjugate matrix.

    Adjugate matrix is the transpose of the cofactor matrix.

    http://en.wikipedia.org/wiki/Adjugate

    See also: .cofactorMatrix(), .T
    """
    cofactors = self.cofactorMatrix(method)
    return cofactors.T
def inverse_LU(self):
    """
    Calculates the inverse using LU decomposition (solves self * X = I).
    """
    identity = self.eye(self.lines)
    return self.LUsolve(identity)

def inverse_GE(self):
    """
    Calculates the inverse using Gaussian elimination: reduce [self | I]
    to row-echelon form and keep the right half.
    """
    assert self.lines == self.cols
    assert self.det() != 0
    augmented = self.row_join(self.eye(self.lines))
    reduced, _pivots = augmented.rref()
    return reduced[:, augmented.lines:]

def inverse_ADJ(self):
    """
    Calculates the inverse using the adjugate matrix and a determinant
    (adjugate / Berkowitz determinant).
    """
    assert self.lines == self.cols
    determinant = self.berkowitz_det()
    assert determinant != 0
    return self.adjugate()/determinant
def rref(self,simplified=False):
    """
    Take any matrix and return reduced row-echelon form and indices of pivot vars

    To simplify elements before finding nonzero pivots set simplified=True

    Works on a copy (self[:,:]). Returns (reduced_matrix, pivotlist), where
    pivotlist holds the column indices that received a pivot.
    """
    # TODO: rewrite inverse_GE to use this
    pivots, r = 0, self[:,:] # pivot: index of next row to contain a pivot
    pivotlist = [] # indices of pivot variables (non-free)
    for i in range(r.cols):
        # every row already has a pivot: nothing left to reduce
        if pivots == r.lines:
            break
        if simplified:
            r[pivots,i] = simplify(r[pivots,i])
        if r[pivots,i] == 0:
            # search this column, from the pivot row down, for a nonzero entry
            for k in range(pivots, r.lines):
                if simplified and k>pivots:
                    r[k,i] = simplify(r[k,i])
                if r[k,i] != 0:
                    break
            # no nonzero entry: column i is free, move on to the next column
            if k == r.lines - 1 and r[k,i] == 0:
                continue
            r.row_swap(pivots,k)
        # scale the pivot row so the pivot becomes 1
        scale = r[pivots,i]
        r.row(pivots, lambda x, _: x/scale)
        # clear column i in every other row; row() calls the lambda
        # immediately, so capturing the loop variables here is safe
        for j in range(r.lines):
            if j == pivots:
                continue
            scale = r[j,i]
            r.row(j, lambda x, k: x - r[pivots,k]*scale)
        pivotlist.append(i)
        pivots += 1
    return r, pivotlist
def nullspace(self,simplified=False):
    """
    Returns list of vectors (Matrix objects) that span nullspace of self

    One cols x 1 column vector is produced per free (non-pivot) column of
    the reduced row-echelon form. Matrices of any shape are accepted: the
    construction only reads the pivot rows of rref, so tall matrices work
    as well as wide ones.
    """
    reduced, pivots = self.rref(simplified)
    # free (non-pivot) variables in increasing column order; the position
    # of a free variable in this list is the index of its basis vector
    free = [i for i in range(self.cols) if i not in pivots]
    basis = [zeros((self.cols, 1)) for _ in free]
    for i in range(self.cols):
        if i not in pivots: # free var, just set vector's ith place to 1
            basis[free.index(i)][i,0] = 1
        else: # add negative of nonpivot entry to corr vector
            line = pivots.index(i)
            for j in range(i+1, self.cols):
                if reduced[line, j] != 0:
                    assert j not in pivots
                    basis[free.index(j)][i,0] = -1 * reduced[line, j]
    return basis
def berkowitz(self):
    """The Berkowitz algorithm.

    Given N x N matrix with symbolic content, compute efficiently
    coefficients of characteristic polynomials of 'self' and all
    its square sub-matrices composed by removing both i-th row
    and column, without division in the ground domain.

    This method is particularly useful for computing determinant,
    principal minors and characteristic polynomial, when 'self'
    has complicated coefficients eg. polynomials. Semi-direct
    usage of this algorithm is also important in computing
    efficiently subresultant PRS.

    Assuming that M is a square matrix of dimension N x N and
    I is N x N identity matrix, then the following
    definition of characteristic polynomial is being used:

                      charpoly(M) = det(t*I - M)

    As a consequence, all polynomials generated by Berkowitz
    algorithm are monic.

    Returns a tuple with one coefficient tuple per leading sub-matrix
    (1 x 1, 2 x 2, ..., N x N), highest-degree coefficient first.
    Raises MatrixError for a non-square matrix.

    >>> from sympy import *
    >>> x,y,z = symbols('xyz')

    >>> M = Matrix([ [x,y,z], [1,0,0], [y,z,x] ])

    >>> p, q, r = M.berkowitz()

    >>> print p # 1 x 1 M's sub-matrix
    (1, -x)

    >>> print q # 2 x 2 M's sub-matrix
    (1, -x, -y)

    >>> print r # 3 x 3 M's sub-matrix
    (1, -2*x, -y - y*z + x**2, x*y - z**2)

    For more information on the implemented algorithm refer to:

    [1] S.J. Berkowitz, On computing the determinant in small
        parallel time using a small number of processors, ACM,
        Information Processing Letters 18, 1984, pp. 147-150

    [2] M. Keber, Division-Free computation of subresultants
        using Bezout matrices, Tech. Report MPI-I-2006-1-006,
        Saarbrucken, 2006

    """
    if not self.is_square:
        raise MatrixError

    A, N = self, self.lines
    transforms = [0] * (N-1)

    # Peel off the last row/column at each step, building the (n+1) x n
    # Toeplitz-like transform T for the current n x n leading block.
    for n in xrange(N, 1, -1):
        T, k = zeros((n+1,n)), n - 1

        R, C = -A[k,:k], A[:k,k]
        A, a = A[:k,:k], -A[k,k]

        items = [ C ]

        for i in xrange(0, n-2):
            items.append(A * items[i])

        for i, B in enumerate(items):
            items[i] = (R * B)[0,0]

        items = [ S.One, a ] + items

        for i in xrange(n):
            T[i:,i] = items[:n-i+1]

        transforms[k-1] = T

    # characteristic polynomial of the 1 x 1 top-left block: t - A[0,0]
    polys = [ Matrix(S.One, -A[0,0]) ]

    for i, T in enumerate(transforms):
        polys.append(T * polys[i])

    return tuple(map(tuple, polys))
def berkowitz_det(self):
    """Computes determinant using Berkowitz method.

    The last tuple from berkowitz() holds the coefficients of
    det(t*I - M) for the whole matrix; its constant term, multiplied by
    (-1)**degree, is det(M).
    """
    charpoly_coeffs = self.berkowitz()[-1]
    degree = len(charpoly_coeffs) - 1
    constant_term = charpoly_coeffs[-1]
    return (-1)**degree * constant_term
def berkowitz_minors(self):
    """Computes principal minors using Berkowitz method.

    Entry i of the returned tuple is the determinant of the leading
    (i+1) x (i+1) sub-matrix: the constant term of its characteristic
    polynomial, with the sign alternating -1, +1, -1, ... by size.
    """
    return tuple((S.NegativeOne if index % 2 == 0 else S.One) * coeffs[-1]
                 for index, coeffs in enumerate(self.berkowitz()))
def berkowitz_charpoly(self, x):
    """Computes characteristic polynomial using Berkowitz method.

    The coefficients from berkowitz() come highest degree first, so
    they are paired with exponents self.lines, self.lines - 1, ..., 0
    and handed to Poly in the generator x.
    """
    coefficients = self.berkowitz()[-1]
    exponents = range(self.lines, -1, -1)
    return Poly(list(zip(coefficients, exponents)), x)

charpoly = berkowitz_charpoly
def berkowitz_eigenvals(self, **flags):
    """Computes eigenvalues of a Matrix using Berkowitz method.

    The characteristic polynomial is built in a fresh dummy symbol and
    its roots are computed with roots(); every keyword flag is passed
    through to roots() unchanged.
    """
    t = Symbol('x', dummy=True)
    poly = self.berkowitz_charpoly(t)
    return roots(poly, **flags)

eigenvals = berkowitz_eigenvals
def eigenvects(self, **flags):
    """Return list of triples (eigenval, multiplicity, basis).

    Keyword flags are forwarded to eigenvals(), except 'multiple', which
    is discarded. For each eigenvalue r, basis is the nullspace of
    self - r*I.
    """
    # The loop below iterates eigenvals() as an {eigenvalue: multiplicity}
    # mapping; 'multiple' presumably switches roots() to a flat list, so
    # drop it. (dict.has_key is deprecated and gone in Python 3.)
    flags.pop('multiple', None)

    out, vlist = [], self.eigenvals(**flags)

    for r, k in vlist.items():
        tmp = self - eye(self.lines)*r
        basis = tmp.nullspace()
        # check if basis is right size, don't do it if symbolic - too many solutions
        if not tmp.is_symbolic():
            assert len(basis) == k
        elif len(basis) != k:
            # The nullspace routine failed, try it again with simplification
            basis = tmp.nullspace(simplified=True)
        out.append((r, k, basis))
    return out
def fill(self, value):
    """Fill the matrix with the scalar value.

    The flat entry list self.mat is replaced by a new list of
    lines*cols references to the same value object.
    """
    size = self.lines * self.cols
    self.mat = [value] * size
def matrix_multiply(A,B):
    """
    Return A*B.

    Raises ShapeError when the column count of A differs from the row
    count of B. Every entry of the product is expanded.
    """
    if A.shape[1] != B.shape[0]:
        raise ShapeError()
    b_cols = B.T.tolist()   # columns of B, as the rows of its transpose
    a_rows = A.tolist()

    def entry(i, j):
        products = [x*y for x, y in zip(a_rows[i], b_cols[j])]
        return reduce(lambda acc, term: acc + term, products).expand()

    return Matrix(A.shape[0], B.shape[1], entry)
def matrix_add(A,B):
    """Return A+B

    Raises ShapeError unless A and B have the same shape; the sum is
    formed row by row from the tolist() representations.
    """
    if A.shape != B.shape:
        raise ShapeError()
    a_rows = A.tolist()
    b_rows = B.tolist()
    summed = [[x + y for x, y in zip(row_a, row_b)]
              for row_a, row_b in zip(a_rows, b_rows)]
    return Matrix(summed)
def zero(n):
    """Create square zero matrix n x n

    Deprecated: warns via warnings.warn (no category given, so a
    UserWarning) and delegates to zeronm(n, n).
    """
    message = 'Deprecated: use zeros() instead.'
    warnings.warn( message )
    return zeronm(n, n)
def zeronm(n,m):
    """Create zero matrix n x m

    Deprecated in favour of zeros(); both dimensions must be positive.
    """
    message = 'Deprecated: use zeros() instead.'
    warnings.warn( message )
    assert n > 0
    assert m > 0
    return Matrix(n, m, [S.Zero] * (m * n))
def zeros(dims):
    """Create zero matrix of dimensions dims = (d1,d2)

    A single positive int (or 1-element sequence) yields a square
    matrix; see _dims_to_nm for validation.
    """
    rows, cols = _dims_to_nm( dims )
    return Matrix(rows, cols, [S.Zero] * (rows * cols))
def one(n):
    """Create square all-one matrix n x n

    Deprecated in favour of ones(); warns via warnings.warn first.
    """
    message = 'Deprecated: use ones() instead.'
    warnings.warn( message )
    return Matrix(n, n, [S.One] * (n * n))
def ones(dims):
    """Create all-one matrix of dimensions dims = (d1,d2)

    A single positive int (or 1-element sequence) yields a square
    matrix; see _dims_to_nm for validation.
    """
    rows, cols = _dims_to_nm( dims )
    return Matrix(rows, cols, [S.One] * (rows * cols))
def eye(n):
    """Create square identity matrix n x n

    n must be a positive integer. The matrix starts as zeros(n) and
    gets S.One written along its diagonal.
    """
    assert n > 0
    identity = zeros(n)
    for d in range(n):
        identity[d, d] = S.One
    return identity
def randMatrix(r,c,min=0,max=99,seed=None):
    """Create random matrix r x c

    Entries are integers from prng.randint(min, max), i.e. both bounds
    are inclusive. When seed is None (or the legacy value []), a
    random.Random() with default seeding is used; any other seed value
    makes the result reproducible.
    """
    # None replaces the old mutable default []; [] is still honoured so
    # existing callers passing it explicitly keep their behaviour.
    if seed is None or seed == []:
        prng = random.Random()  # default seeding (system time / OS entropy)
    else:
        prng = random.Random(seed)
    return Matrix(r,c,lambda i,j: prng.randint(min,max))
def hessian(f, varlist):
    """Compute Hessian matrix for a function f

    see: http://en.wikipedia.org/wiki/Hessian_matrix

    f is an expression with a .diff method; varlist is a list/tuple of
    variables or a single-row Matrix of them. Returns the symmetric
    m x m matrix of second derivatives, where m is the number of
    variables.

    Raises TypeError for an unsupported varlist type and ValueError when
    f has no .diff method.
    """
    # f is the expression representing a function f, return regular matrix
    if isinstance(varlist, (list, tuple)):
        m = len(varlist)
    elif isinstance(varlist, Matrix):
        m = varlist.cols
        assert varlist.lines == 1
    else:
        # Raising a bare string is invalid since Python 2.6 (it surfaced
        # as a TypeError anyway), so raise a real TypeError.
        raise TypeError("Improper variable list in hessian function")
    assert m > 0
    try:
        f.diff(varlist[0]) # check differentiability
    except AttributeError:
        raise ValueError("Function %r is not differentiable" % (f,))
    out = zeros(m)
    for i in range(m):
        # differentiate once per row rather than once per (i, j) pair
        df_i = f.diff(varlist[i])
        for j in range(i,m):
            out[i,j] = df_i.diff(varlist[j])
    # mirror the upper triangle into the lower one
    for i in range(m):
        for j in range(i):
            out[i,j] = out[j,i]
    return out
def GramSchmidt(vlist, orthog=False):
    """Apply the Gram-Schmidt process to the vectors in vlist.

    Returns a new list of mutually orthogonal vectors, in the same order.
    When orthog is True each result is additionally passed through
    .normalized().

    Raises ValueError if some vector reduces to the zero vector, i.e. the
    set is not linearly independent. Any vector shape is supported.
    """
    out = []
    for v in vlist:
        tmp = v
        for u in out:
            # Rebind instead of using -=, so the caller's vector can never
            # be modified even if the matrix type has an in-place __isub__.
            tmp = tmp - v.project(u)
        # Compare against a zero matrix of the same shape; the old test
        # was hard-coded to the 1 x 3 zero row vector.
        if tmp == zeros(tmp.shape):
            raise ValueError("GramSchmidt: vector set not linearly independent")
        out.append(tmp)
    if orthog:
        out = [u.normalized() for u in out]
    return out
def wronskian(functions, var):
    """Compute wronskian for [] of functions

                     | f1         f2         ...  fn         |
                     | f1'        f2'        ...  fn'        |
                     |  .          .          .    .         |
    W(f1,...,fn) =   |  .          .          .    .         |
                     |  .          .          .    .         |
                     |  (n-1)      (n-1)           (n-1)     |
                     | D    (f1)  D    (f2)  ...  D    (fn)  |

    see: http://en.wikipedia.org/wiki/Wronskian

    functions may be any sequence (list or tuple) of expressions; each
    is sympified. The caller's sequence is left untouched. Returns the
    determinant.
    """
    # Build a new list: assigning back into `functions` would mutate the
    # caller's list (and fail outright for a tuple).
    functions = [sympify(fn) for fn in functions]
    n = len(functions)
    # Entry (i, j) is the j-th derivative of f_i. That is the transpose of
    # the layout drawn above, which leaves the determinant unchanged.
    W = Matrix(n, n, lambda i,j: functions[i].diff(var, j) )
    return W.det()
def casoratian(seqs, n, zero=True):
    """Given linear difference operator L of order 'k' and homogeneous
       equation Ly = 0 we want to compute kernel of L, which is a set
       of 'k' sequences: a(n), b(n), ... z(n).

       Solutions of L are linearly independent iff their Casoratian,
       denoted as C(a, b, ..., z), does not vanish for n = 0.

       Casoratian is defined by k x k determinant:

                  +  a(n)     b(n)     . . . z(n)     +
                  |  a(n+1)   b(n+1)   . . . z(n+1)   |
                  |    .         .     .        .     |
                  |    .         .       .      .     |
                  |    .         .         .    .     |
                  +  a(n+k-1) b(n+k-1) . . . z(n+k-1) +

       It proves very useful in rsolve_hyper() where it is applied
       to a generating set of a recurrence to factor out linearly
       dependent solutions and return a basis.

       With zero=True (the default) row i holds the values a(i), b(i),
       ..., i.e. the determinant is taken at n = 0; with zero=False row
       i holds a(n+i), b(n+i), ... and the result stays symbolic in n.

       >>> from sympy import *
       >>> n = Symbol('n', integer=True)

       Exponential and factorial are linearly independent:

       >>> casoratian([2**n, factorial(n)], n) != 0
       True

    """
    # A real list is required: it is indexed by the entry rule and len()
    # is taken of it (map() returns a lazy iterator on Python 3).
    seqs = [sympify(s) for s in seqs]

    if not zero:
        f = lambda i, j: seqs[j].subs(n, n+i)
    else:
        f = lambda i, j: seqs[j].subs(n, i)

    k = len(seqs)

    return Matrix(k, k, f).det()
class SMatrix(Matrix):
    """Sparse matrix.

    Only non-zero entries are stored: self.mat is a dict mapping
    (row, column) tuples to values, and a missing key reads as 0.
    self.lines and self.cols hold the number of rows and columns.

    Accepted constructor forms:

    * SMatrix(lines, cols, f)    -- entry (i, j) is sympify(f(i, j))
    * SMatrix(lines, cols, seq)  -- seq is a flat, row-major list/tuple
    * SMatrix(lines, cols, d)    -- d is a {(i, j): value} dict, copied as-is
    * SMatrix(rows) / SMatrix(*rows) -- nested row sequences; a flat
      sequence of scalars becomes a column vector
    """

    def __init__(self, *args):
        if len(args) == 3 and callable(args[2]):
            op = args[2]
            assert isinstance(args[0], int) and isinstance(args[1], int)
            self.lines = args[0]
            self.cols = args[1]
            self.mat = {}
            for i in range(self.lines):
                for j in range(self.cols):
                    value = sympify(op(i,j))
                    if value != 0:
                        self.mat[(i,j)] = value
        elif len(args)==3 and isinstance(args[0],int) and \
                isinstance(args[1],int) and isinstance(args[2], (list, tuple)):
            self.lines = args[0]
            self.cols = args[1]
            mat = args[2]
            self.mat = {}
            for i in range(self.lines):
                for j in range(self.cols):
                    value = sympify(mat[i*self.cols+j])
                    if value != 0:
                        self.mat[(i,j)] = value
        elif len(args)==3 and isinstance(args[0],int) and \
                isinstance(args[1],int) and isinstance(args[2], dict):
            self.lines = args[0]
            self.cols = args[1]
            # shallow copy of the caller's dict (copy.deepcopy() doesn't
            # work here); values are stored without sympify/zero filtering
            self.mat = dict(args[2])
        else:
            if len(args) == 1:
                mat = args[0]
            else:
                mat = args
            if not isinstance(mat[0], (list, tuple)):
                # flat sequence of scalars -> column vector
                mat = [ [element] for element in mat ]
            self.lines = len(mat)
            self.cols = len(mat[0])
            self.mat = {}
            for i in range(self.lines):
                assert len(mat[i]) == self.cols
                for j in range(self.cols):
                    value = sympify(mat[i][j])
                    if value != 0:
                        self.mat[(i,j)] = value

    def __getitem__(self, key):
        """Read entries.

        An int or slice indexes the entries in flat row-major order
        (a single hit is returned as a scalar, several as a list); an
        (int, int) pair returns one entry; a pair containing a slice
        returns a sub-matrix.
        """
        if isinstance(key, (slice, int)):
            lo, hi = self.slice2bounds(key, self.lines*self.cols)
            L = [self.mat.get(self.rowdecomp(i), 0) for i in range(lo, hi)]
            if len(L) == 1:
                return L[0]
            else:
                return L
        assert len(key) == 2
        if isinstance(key[0], int) and isinstance(key[1], int):
            i,j=self.key2ij(key)
            return self.mat.get((i,j), 0)
        elif isinstance(key[0], slice) or isinstance(key[1], slice):
            return self.submatrix(key)
        else:
            raise IndexError("Index out of range: a[%s]"%repr(key))

    def rowdecomp(self, num):
        """Convert a flat row-major index into an (i, j) pair.

        Negative indices count back from the end, like Python sequences
        (the old loop accepted them in its assert but decoded them wrongly).
        """
        size = self.lines * self.cols
        assert -size <= num < size
        if num < 0:
            num += size
        return divmod(num, self.cols)

    def __setitem__(self, key, value):
        """Write entries; assigning 0 removes any stored entry."""
        assert len(key) == 2
        if isinstance(key[0], slice) or isinstance(key[1], slice):
            if isinstance(value, Matrix):
                self.copyin_matrix(key, value)
            elif isinstance(value, (list, tuple)):
                self.copyin_list(key, value)
        else:
            i,j=self.key2ij(key)
            testval = sympify(value)
            if testval != 0:
                self.mat[(i,j)] = testval
            else:
                # zeros are never stored in the sparse dict
                self.mat.pop((i,j), None)

    def row_del(self, k):
        """Delete row k in place, shifting the rows below it up by one."""
        newD = {}
        for (i,j), value in self.mat.items():
            if i < k:
                newD[i,j] = value
            elif i > k:
                newD[i-1,j] = value
        self.mat = newD
        self.lines -= 1

    def col_del(self, k):
        """Delete column k in place, shifting later columns left by one."""
        newD = {}
        for (i,j), value in self.mat.items():
            if j < k:
                newD[i,j] = value
            elif j > k:
                newD[i,j-1] = value
        self.mat = newD
        self.cols -= 1

    def toMatrix(self):
        """Return a dense Matrix with the same entries (missing ones = 0)."""
        l = [[self.mat.get((i, j), 0) for j in range(self.cols)]
             for i in range(self.lines)]
        return Matrix(l)

    # Per the original author, the methods below mirror their Matrix
    # counterparts with Matrix replaced by SMatrix.
    def copyin_list(self, key, value):
        assert isinstance(value, (list, tuple))
        self.copyin_matrix(key, SMatrix(value))

    def multiply(self,b):
        """Returns self*b

        A 1 x 1 product is returned as its single entry rather than as
        an SMatrix.
        """
        # shape check once, instead of once per computed entry
        assert self.cols == b.lines

        def dotprod(i, j):
            r = 0  # the accumulator must be initialised before +=
            for x in range(self.cols):
                r += self[i,x]*b[x,j]
            return r

        r = SMatrix(self.lines, b.cols, dotprod)
        if r.lines == 1 and r.cols ==1:
            return r[0,0]
        return r

    def submatrix(self, keys):
        assert isinstance(keys[0], slice) or isinstance(keys[1], slice)
        rlo, rhi = self.slice2bounds(keys[0], self.lines)
        clo, chi = self.slice2bounds(keys[1], self.cols)
        if not ( 0<=rlo<=rhi and 0<=clo<=chi ):
            raise IndexError("Slice indices out of range: a[%s]"%repr(keys))
        return SMatrix(rhi-rlo, chi-clo, lambda i,j: self[i+rlo, j+clo])

    def reshape(self, _rows, _cols):
        """Return a new _rows x _cols SMatrix with the same entries in
        row-major order.

        Raises ShapeError if the total number of entries would change
        (previously this only printed a message and carried on).
        """
        if self.lines*self.cols != _rows*_cols:
            raise ShapeError("Invalid reshape parameters %d %d" % (_rows, _cols))
        newD = {}
        # only the stored (non-zero) entries need to move
        for (m, n), value in self.mat.items():
            newD[divmod(m*self.cols + n, _cols)] = value
        return SMatrix(_rows, _cols, newD)

    def cross(self, b):
        """Return the cross product self x b as a 1 x 3 SMatrix.

        self must be a 1 x 3 or 3 x 1 matrix; b may also be a list/tuple
        of length 3. Raises ShapeError otherwise.
        """
        assert isinstance(b, (list, tuple, Matrix))

        def is_3vector(v):
            if isinstance(v, (list, tuple)):
                return len(v) == 3
            return (v.lines, v.cols) in ((1, 3), (3, 1))

        if not (is_3vector(self) and is_3vector(b)):
            raise ShapeError("Dimensions incorrect for cross product")
        return SMatrix(1,3,((self[1]*b[2] - self[2]*b[1]),
                            (self[2]*b[0] - self[0]*b[2]),
                            (self[0]*b[1] - self[1]*b[0])))

    def zeronm(self,n,m):
        """Deprecated: use zeros() instead."""
        warnings.warn( 'Deprecated: use zeros() instead.' )
        return SMatrix(n,m,{})

    def zero(self, n):
        """Deprecated: use zeros() instead."""
        warnings.warn( 'Deprecated: use zeros() instead.' )
        return SMatrix(n,n,{})

    def zeros(self, dims):
        """Returns a dims = (d1,d2) matrix of zeros."""
        n, m = _dims_to_nm( dims )
        return SMatrix(n,m,{})

    def eye(self, n):
        """Return the n x n sparse identity matrix."""
        tmp = SMatrix(n,n,lambda i,j:0)
        for i in range(tmp.lines):
            tmp[i,i] = 1
        return tmp
def list2numpy(l):
    """Converts python list of SymPy expressions to a NumPy array.

    The result is a 1-D array with dtype=object that holds the very same
    element objects, so no numeric conversion takes place.
    """
    # imported here, presumably so numpy stays an optional dependency
    from numpy import empty
    out = empty(len(l), dtype=object)
    for idx, expr in enumerate(l):
        out[idx] = expr
    return out
def matrix2numpy(m):
    """Converts SymPy's matrix to a NumPy array.

    The array has shape m.shape and dtype=object; element [i, j] is the
    object m[i, j], copied in row-major order.
    """
    # imported here, presumably so numpy stays an optional dependency
    from numpy import empty
    result = empty(m.shape, dtype=object)
    positions = ((row, col) for row in range(m.lines) for col in range(m.cols))
    for row, col in positions:
        result[row, col] = m[row, col]
    return result
1736 def a2idx(a):
1738 Tries to convert "a" to an index, returns None on failure.
1740 The result of a2idx() (if not None) can be safely used as an index to
1741 arrays/matrices.
1743 if hasattr(a, "__int__"):
1744 return int(a)
1745 if hasattr(a, "__index__"):
1746 return a.__index__()