# Commit message: refactorings, tests added
# Source: sympyx.git / test_basic.py
# Blob: a53d39ba409bd634d5034899fc5cd3de812df40d
1 from sympy import Symbol, Add, Mul, Pow, Integer, SYMBOL, ADD, MUL, POW, \
2 INTEGER
def test_eq():
    """Check structural equality for Symbol, Add, Mul, Pow and Integer.

    ``same_x`` is a separately constructed Symbol with the name "x"; it must
    compare equal to ``x`` and be interchangeable with it inside compound
    expressions.
    """
    x, y, z = map(Symbol, "xyz")
    same_x = Symbol("x")

    # Symbols built from the same name are equal; different names are not.
    assert x == x
    assert not (x != x)
    assert x == same_x
    assert not (x != same_x)
    assert x != y

    # Add: operand order does not matter.
    assert Add(x, y) == Add(x, y)
    assert Add(same_x, y) == Add(x, y)
    assert Add(x, y) == Add(y, x)
    assert Add(x, y) != Add(y, z)

    # Mul: operand order does not matter.
    assert Mul(x, y) == Mul(x, y)
    assert Mul(same_x, y) == Mul(x, y)
    assert Mul(x, y) == Mul(y, x)
    assert Mul(x, y) != Mul(y, z)

    # Pow: base and exponent are NOT interchangeable.
    assert Pow(x, y) == Pow(x, y)
    assert Pow(same_x, y) == Pow(x, y)
    assert Pow(x, y) != Pow(y, x)
    assert Pow(x, y) != Pow(y, z)
    assert Pow(same_x, y) != Pow(x, z)

    # Integers compare by value.
    assert Integer(3) == Integer(3)
    assert Integer(3) != Integer(4)
def test_add():
    """Add must be associative and commutative, and must collect repeated terms."""
    x, y, z = map(Symbol, "xyz")

    # Re-bracketing and reordering yield the same sum.
    assert Add(Add(x, y), z) == Add(x, Add(y, z))
    assert Add(Add(z, x), y) == Add(x, Add(y, z))
    assert Add(Add(z, x), x) != Add(x, Add(y, z))

    # x + x collapses to 2*x, also when nested inside a longer sum.
    assert Add(x, x) == Mul(Integer(2), x)
    assert Add(Add(Add(x, y), z), x) == Add(Add(Mul(Integer(2), x), y), z)
def test_mul():
    """Mul must be associative and commutative, and must collect repeated factors."""
    x, y, z = map(Symbol, "xyz")

    # Re-bracketing and reordering yield the same product.
    assert Mul(Mul(x, y), z) == Mul(x, Mul(y, z))
    assert Mul(Mul(z, x), y) == Mul(x, Mul(y, z))
    assert Mul(Mul(z, x), x) != Mul(x, Mul(y, z))

    # x * x collapses to x**2, also when nested inside a longer product.
    assert Mul(x, x) == Pow(x, Integer(2))
    assert Mul(Mul(Mul(x, y), z), x) == Mul(Mul(Pow(x, Integer(2)), y), z)
def test_arit():
    """Check that Python operators build the expected Add/Mul/Pow trees."""
    x, y, z = map(Symbol, "xyz")

    # Binary +
    assert x+y == Add(x, y)
    assert x+y+z == Add(Add(x, y), z)

    # Binary - is addition of (-1) * operand.
    assert x-y == Add(x, Mul(Integer(-1), y))
    assert y-x == Add(Mul(Integer(-1), x), y)

    # Binary *
    assert x*y == Mul(x, y)
    assert x*y*z == Mul(z, Mul(x, y))

    # Binary / is multiplication by operand ** -1.
    assert x/y == Mul(x, Pow(y, Integer(-1)))
    assert y/x == Mul(Pow(x, Integer(-1)), y)

    # Binary **
    assert x**Integer(2) == Pow(x, Integer(2))

    # Unary - and +
    assert -x == Mul(Integer(-1), x)
    assert +x == x
def test_int_conversion():
    """Plain Python ints must be accepted as operands alongside expressions."""
    sym_x = Symbol("x")
    x = sym_x

    # Mixed expression/int arithmetic.
    assert x+1 == Add(x, 1)
    assert x/2 == Mul(x, Pow(2, -1))

    # Multiplying by 1 and raising to the power 1 are identities.
    assert x*1 == x
    assert x**1 == x
def test_expand1():
    """expand() on integer powers of sums must yield the multinomial expansion."""
    x, y, z = map(Symbol, "xyz")

    # Binomial squares and cubes.
    assert ((x + y)**2).expand() == x**2 + 2*x*y + y**2
    assert ((x + y)**3).expand() == x**3 + 3*x**2*y + 3*x*y**2 + y**3

    # Trinomial square.
    assert ((x + y + z)**2).expand() == \
        x**2 + y**2 + z**2 + 2*x*y + 2*x*z + 2*y*z
def test_expand2():
    """expand() must distribute products over sums, including nested powers."""
    x, y, z = map(Symbol, "xyz")

    # Already expanded: unchanged.
    assert (2*x*y).expand() == 2*x*y

    # Product of two sums.
    assert ((x + y)*(x + z)).expand() == x**2 + x*y + x*z + y*z

    # Factor times a power of a sum, alone and summed.
    assert (x*(x + y)**2).expand() == x**3 + 2*x**2*y + x*y**2
    assert (x*(x + y)**2 + z*(x + y)**2).expand() == \
        x**3 + 2*x**2*y + y**2*z + x**2*z + x*y**2 + 2*x*y*z

    # Distribution over a sum of products.
    assert (2*x*(y*x + y*z)).expand() == 2*x**2*y + 2*x*y*z

    # Same polynomial as above, reached by expanding (x+y)**2 * (x+z).
    assert ((x + y)**2*(x + z)).expand() == \
        x**3 + 2*x**2*y + y**2*z + x**2*z + x*y**2 + 2*x*y*z
def test_canonicalization():
    """Check the trivial simplifications applied at construction time.

    Only ``x`` is needed here. The original also created unused symbols
    ``y`` and ``z``, which have been removed.
    """
    x = Symbol("x")

    # Additive / multiplicative identities and inverses.
    assert x-x == 0
    assert x*1 == x
    assert x+0 == x
    assert x-0 == x

    # Power identities.
    assert x**1 == x
    assert 1**x == 1
    # NOTE: 0**x -> 0 is not a mathematical identity (0**0 == 1 and 0**x is
    # undefined for negative x). This line pins the library's simplification
    # rule as it currently stands.
    assert 0**x == 0
def test_pow():
    """Nested powers must be collapsed into a single power with a product exponent."""
    x, y, z = map(Symbol, "xyz")

    # Integer outer exponents: (x**a)**n == x**(a*n) is a genuine identity.
    assert (x**2)**3 == x**6
    assert (x**y)**3 == x**(3*y)

    # A symbolic outer exponent is collapsed too, but this is NOT valid for all
    # complex values. For example, x=-1, y=2, z=1/2 gives ((-1)**2)**(1/2) == 1
    # while (-1)**(2*1/2) == -1.
    assert (x**y)**z == x**(y*z)
def test_args_type():
    """Check the ``type`` tag and ``args`` tuple exposed by each node kind.

    Add and Mul args are compared as sets, because operand order is not part
    of the checked contract. Pow args are compared as the ordered tuple
    (base, exponent).
    """
    x, y, z = map(Symbol, "xyz")

    # Add
    assert (x + y).type == ADD
    assert set((x + y).args) == {x, y}
    assert set((x + y).args) != {x, z}

    # Mul: x*y*z is a single MUL node holding all three factors.
    assert (x * y * z).type == MUL
    assert set((x * y * z).args) == {x, y, z}

    # Pow: (base, exponent)
    assert (x ** y).type == POW
    assert (x ** y).args == (x, y)

    # Atoms have no args.
    assert x.type == SYMBOL
    assert x.args == ()
    assert Integer(5).type == INTEGER
    assert Integer(5).args == ()

    # Subtraction is an ADD whose second operand is the negation.
    assert (x - y).type == ADD
    assert set((x - y).args) == {x, -y}
def test_hash():
    """Check that equal expressions hash equally, independent of factor order.

    ``a`` is a separately constructed Symbol with the name "x", so it must
    hash like ``x``.
    """
    x = Symbol("x")
    y = Symbol("y")
    z = Symbol("z")
    a = Symbol("x")

    # NOTE: Python only guarantees "equal => equal hash". The != checks below
    # pin this implementation's hash choices rather than a language guarantee.
    assert hash(x) != hash(y)
    assert hash(x) != hash(z)
    assert hash(x) == hash(a)

    assert hash(Integer(3)) == hash(Integer(3))
    assert hash(Integer(3)) != hash(Integer(4))

    # Mul hashes must not depend on factor order or Symbol identity.
    # hash(x*y) != hash(y*z) is intentionally not asserted: distinct
    # expressions are allowed to collide. (This replaces a previously
    # commented-out assertion.)
    assert hash(x*y) == hash(y*x)
    assert hash(x*y) == hash(y*a)
    assert hash(x*y*z) == hash(y*z*x)
    assert hash(x*y*z) == hash(y*z*a)
def test_hash2():
    """Like terms differing only in factor order must combine when added or subtracted.

    ``a`` is a separately constructed Symbol with the name "x", so ``y*a``
    must be recognised as the same term as ``x*y``. Presumably this relies
    on Mul hashing, hence the test name. The original also created an unused
    symbol ``z``, which has been removed.
    """
    x = Symbol("x")
    y = Symbol("y")
    a = Symbol("x")

    assert x*y+y*x == 2*x*y
    assert x*y-y*x == 0
    assert x*y+y*a == 2*x*y
    assert x*y-y*a == 0