Start of the 0.5.12 development cycle
[sympy.git] / examples / fem.py
blob 10da81d2fee8d8f50a259fabdb40db3a45259ca3
1 from sympy import *
# Spatial coordinate symbols shared by the functions below.  They are
# created one at a time because current SymPy treats symbols('xyz') as a
# single symbol named "xyz" rather than splitting it per character.
x, y, z = Symbol('x'), Symbol('y'), Symbol('z')
class ReferenceSimplex:
    """The reference (unit) simplex in ``nsd`` space dimensions.

    The region is the set of points whose coordinates are all non-negative
    and sum to at most 1; ``integrate`` uses exactly these limits.

    Attributes:
        nsd: number of space dimensions.
        coords: list of coordinate symbols -- ``x``, ``y``, ``z`` (truncated
            to ``nsd``) when ``nsd <= 3``, otherwise ``x_0`` ... ``x_<nsd-1>``.
    """

    def __init__(self, nsd):
        if nsd < 0:
            # A negative slice below would silently yield a bogus coordinate list.
            raise ValueError("nsd must be non-negative, got %r" % (nsd,))
        self.nsd = nsd
        if nsd <= 3:
            # Spell the names out: symbols('xyz') no longer splits per
            # character in current SymPy.
            coords = [Symbol("x"), Symbol("y"), Symbol("z")][:nsd]
        else:
            coords = [Symbol("x_%d" % d) for d in range(nsd)]
        self.coords = coords

    def integrate(self, f):
        """Return the exact integral of ``f`` over the reference simplex.

        The integral is evaluated as an iterated integral: ``coords[0]`` is
        integrated innermost from 0 to ``1 - (sum of the other coords)``,
        then ``coords[1]`` from 0 to ``1 - (sum of the remaining coords)``,
        and so on, with the last coordinate running from 0 to 1.

        Args:
            f: SymPy expression in ``self.coords``.

        Returns:
            The integrated SymPy expression.
        """
        coords = self.coords
        # Upper limit for the innermost variable: 1 minus every coordinate
        # (the variable itself is added back at the start of the loop).
        limit = 1
        for p in coords:
            limit -= p
        intf = f
        for p in coords:
            # Once p is integrated out it no longer bounds the next variable.
            limit += p
            intf = integrate(intf, (p, 0, limit))
        return intf
def bernstein_space(order, nsd):
    """Build the Bernstein polynomial space of degree ``order``.

    The barycentric coordinates used are ``(x, 1-x)`` in 1D,
    ``(x, y, 1-x-y)`` in 2D and ``(x, y, z, 1-x-y-z)`` in 3D.  For every
    multi-index ``(o_1, ..., o_{nsd+1})`` of non-negative integers summing
    to ``order`` there is one basis function

        order!/(o_1! ... o_{nsd+1}!) * b_1**o_1 * ... * b_{nsd+1}**o_{nsd+1}

    Multi-indices are enumerated lexicographically with ``o_1`` varying
    slowest.

    Args:
        order: polynomial degree.
        nsd: number of space dimensions (1, 2 or 3).

    Returns:
        ``(pol, coeff, basis)``:
        - ``coeff`` holds one unknown symbol ``a_<o_1>_..._<o_{nsd+1}>`` per
          basis function;
        - ``basis`` holds the basis functions, in the same order as ``coeff``;
        - ``pol`` is the generic polynomial ``sum(coeff[k] * basis[k])``.

    Raises:
        RuntimeError: if ``nsd`` is not 1, 2 or 3.
    """
    import itertools

    if nsd < 1 or nsd > 3:
        raise RuntimeError("Bernstein only implemented in 1D, 2D, and 3D")

    spatial = (x, y, z)[:nsd]
    # The last barycentric coordinate is 1 minus all spatial coordinates.
    last = 1
    for s in spatial:
        last -= s
    bary = list(spatial) + [last]

    pol = 0
    basis = []
    coeff = []
    # Same order as nested loops over o_1 (outermost) ... o_{nsd+1}.
    for exps in itertools.product(range(order + 1), repeat=nsd + 1):
        if sum(exps) != order:
            continue
        denom = 1
        monomial = 1
        for b, o in zip(bary, exps):
            denom *= factorial(o)
            monomial *= b**o
        # Multinomial coefficient; equals binomial(order, o_1) in 1D.
        phi = factorial(order) / denom * monomial
        a = Symbol("a_" + "_".join(["%d" % o for o in exps]))
        pol += a*phi
        basis.append(phi)
        coeff.append(a)

    return pol, coeff, basis
def create_point_set(order, nsd):
    """Return the equispaced nodes of degree ``order`` on the reference simplex.

    The nodes are ``(i_1*h, ..., i_nsd*h)`` with ``h = 1/order``, for all
    non-negative integers with ``i_1 + ... + i_nsd <= order``.  They are
    listed in lexicographic order of ``(i_1, ..., i_nsd)`` (``i_1`` varies
    slowest).

    Args:
        order: polynomial degree; must be at least 1.
        nsd: number of space dimensions.

    Returns:
        A list of ``nsd``-tuples of SymPy rationals (empty when ``nsd < 1``).

    Raises:
        ValueError: if ``order < 1``, since the spacing ``1/order`` is
            undefined.
    """
    import itertools

    if nsd < 1:
        return []
    if order < 1:
        raise ValueError("create_point_set needs order >= 1, got %r" % (order,))

    h = Rational(1, order)
    points = []
    # Test the integer indices instead of the rational coordinates:
    # i_1 + ... + i_nsd <= order is exactly x_1 + ... + x_nsd <= 1.
    for idx in itertools.product(range(order + 1), repeat=nsd):
        if sum(idx) <= order:
            points.append(tuple(i*h for i in idx))
    return points
def create_matrix(equations, coeffs):
    """Return the coefficient matrix of a system of equations in ``coeffs``.

    Entry ``[i, j]`` is the quotient of dividing ``equations[i]`` by
    ``coeffs[j]`` with ``div`` (the remainder is discarded).  This assumes
    every equation is a linear combination of ``coeffs`` with numeric
    coefficients, as the equations built in ``Lagrange.compute_basis`` are.
    Under that assumption the quotient is the coefficient of ``coeffs[j]``
    in ``equations[i]``.

    Args:
        equations: sequence of SymPy expressions, one per matrix row.
        coeffs: sequence of unknown symbols, one per matrix column.

    Returns:
        A ``len(equations) x len(coeffs)`` SymPy ``Matrix``.
    """
    # Columns are sized by the number of unknowns, not the number of
    # equations, so non-square systems are indexed correctly.
    return Matrix(len(equations), len(coeffs),
                  lambda i, j: div(equations[i], coeffs[j])[0])
class Lagrange:
    """Nodal (Lagrange) finite element on the reference simplex.

    The shape functions ``self.N`` are polynomials of total degree ``order``
    in ``nsd`` space dimensions.  ``N[i]`` equals 1 at the i-th node of
    ``create_point_set(order, nsd)`` and 0 at every other node.
    """

    def __init__(self, nsd, order):
        self.nsd = nsd
        self.order = order
        self.compute_basis()

    def nbf(self):
        """Return the number of basis (shape) functions."""
        return len(self.N)

    def compute_basis(self):
        """Compute the nodal shape functions and store them in ``self.N``.

        The generic Bernstein polynomial ``sum(a_j * B_j)`` is evaluated at
        every node ``p_k``, giving ``A[k, j] = B_j(p_k)``.  Column ``i`` of
        ``A**-1`` holds the Bernstein coefficients of the polynomial that is
        1 at ``p_i`` and 0 at the other nodes.
        """
        pol, coeffs, basis = bernstein_space(self.order, self.nsd)
        points = create_point_set(self.order, self.nsd)

        spatial = (x, y, z)[:self.nsd]
        equations = []
        for p in points:
            ex = pol
            for s, value in zip(spatial, p):
                ex = ex.subs(s, value)
            equations.append(ex)

        A = create_matrix(equations, coeffs)
        Ainv = A.inv()

        # Assemble each shape function straight from the Bernstein basis
        # using the matching column of the inverse.
        N = []
        for i in range(len(equations)):
            Ni = 0
            for j, Bj in enumerate(basis):
                Ni += Ainv[j, i]*Bj
            N.append(Ni)
        self.N = N