From 43b1c57f6601faf688ac0a7bb25d0219c7e383ac Mon Sep 17 00:00:00 2001 From: "fredrik.johansson" Date: Tue, 31 Jul 2007 01:20:03 +0000 Subject: [PATCH] made factorial_simplify work again along with binomial, rising_factorial etc in the factorials module; temporarily renamed some functions to avoid name clashes with the implementations in the core --- sympy/specfun/__init__.py | 2 +- sympy/specfun/factorials.py | 87 ++++++++++++++++++---------------- sympy/specfun/tests/test_factorials.py | 49 ++++++++++--------- sympy/specfun/tests/test_specfun.py | 4 +- 4 files changed, 73 insertions(+), 69 deletions(-) diff --git a/sympy/specfun/__init__.py b/sympy/specfun/__init__.py index 1313952..a50c087 100644 --- a/sympy/specfun/__init__.py +++ b/sympy/specfun/__init__.py @@ -4,7 +4,7 @@ as trigonometric functions, orthogonal polynomials, the gamma function, and so on. """ -from factorials import factorial, factorial2, rising_factorial, \ +from factorials import factorial_, factorial2, binomial2, rising_factorial, \ falling_factorial, gamma, lower_gamma, upper_gamma, \ factorial_simplify diff --git a/sympy/specfun/factorials.py b/sympy/specfun/factorials.py index 39e4a51..22aa63f 100644 --- a/sympy/specfun/factorials.py +++ b/sympy/specfun/factorials.py @@ -25,7 +25,7 @@ def _lanczos(z): return exp(logw) -class Factorial(DefinedFunction): +class Factorial_(DefinedFunction): """ Usage ===== @@ -57,14 +57,14 @@ class Factorial(DefinedFunction): if x.is_integer: if x < 0: return oo - y = Rational(1) + y = 1 for m in xrange(1, x.p+1): y *= m - return y + return Rational(y) if x.q == 2: n = (x.p + 1) / 2 if n < 0: - return (-1)**(-n+1) * pi * x / factorial(-x) + return (-1)**(-n+1) * pi * x / factorial_(-x) return sqrt(pi) * Rational(1, 2**n) * factorial2(2*n-1) def diff(self, sym): @@ -99,8 +99,11 @@ class Factorial(DefinedFunction): return s + "!" -def _fac(x): - return factorial(x, evaluate=False) +class UnevalatedFactorial(Factorial_): + def _eval_apply(self, x): + return None + +unfac = UnevalatedFactorial() class Factorial2(DefinedFunction): @@ -153,6 +156,9 @@ class Factorial2(DefinedFunction): # factorial_simplify helpers; could use refactoring +def _isfactorial(expr): + return isinstance(expr, Apply) and isinstance(expr[0], Factorial_) + def _collect_factors(expr): assert isinstance(expr, Mul) numer_args = [] @@ -166,16 +172,16 @@ def _collect_factors(expr): other += o elif isinstance(x, Pow): base, exp = x[:] - if isinstance(base, factorial) and \ + if _isfactorial(base) and \ isinstance(exp, Rational) and exp.is_integer: if exp > 0: - for i in xrange(exp.p): numer_args.append(base._args) + for i in xrange(exp.p): numer_args.append(base.args[0]) else: - for i in xrange(-exp.p): denom_args.append(base._args) + for i in xrange(-exp.p): denom_args.append(base.args[0]) else: other.append(x) - elif isinstance(x, factorial): - numer_args.append(x._args) + elif _isfactorial(x): + numer_args.append(x.args[0]) else: other.append(x) return numer_args, denom_args, other @@ -186,7 +192,8 @@ def _simplify_quotient(na, da, other): candidates = [] for i, y in enumerate(na): for j, x in enumerate(da): - delta = simplify(y - x) + #delta = simplify(y - x) + delta = y - x if isinstance(delta, Rational) and delta.is_integer: candidates.append((delta, i, j)) if candidates: @@ -215,16 +222,14 @@ def _simplify_recurrence(facs, other, reciprocal=False): while i < len(facs): j = 0 while j < len(other): + othr = other[j] + fact = facs[i] if reciprocal: - if simplify(other[j] - facs[i]) == 0: - facs[i] -= 1; del other[j]; j = -1 - elif simplify(1/other[j] - facs[i]) == 1: - facs[i] += 1; del other[j]; j = -1 - else: - if simplify(other[j] - facs[i]) == 1: - facs[i] += 1; del other[j]; j = -1 - elif simplify(1/other[j] - facs[i]) == 0: - facs[i] -= 1; del other[j]; j = -1 + othr = 1/othr + if othr - fact == 1: facs[i] += 1; del other[j]; j -= 1 + elif -othr - fact == 1: facs[i] += 1; del other[j]; other.append(-1); j -= 1 + elif 1/othr - fact == 0: facs[i] -= 1; del other[j]; j -= 1 + elif -1/othr - fact == 0: facs[i] -= 1; del other[j]; other.append(-1); j -= 1 j += 1 i += 1 @@ -242,25 +247,27 @@ def factorial_simplify(expr): if isinstance(expr, Add): return Add(*(factorial_simplify(x) for x in expr)) - if isinstance(expr, factorial): - return expr.eval() + if isinstance(expr, Factorial_): + #return expr.eval() + return expr if isinstance(expr, Pow): return Pow(factorial_simplify(expr[0]), expr[1]) if isinstance(expr, Mul): na, da, other = _collect_factors(expr) - _simplify_quotient(na, da, other) _simplify_recurrence(na, other) _simplify_recurrence(da, other, reciprocal=True) result = Rational(1) - for n in na: result *= factorial(n).eval() - for d in da: result /= factorial(d).eval() + for n in na: result *= factorial_(n) + for d in da: result /= factorial_(d) for o in other: result *= o return result + expr = expr.subs(unfac, factorial_) + return expr class Rising_factorial(DefinedFunction): @@ -280,7 +287,7 @@ class Rising_factorial(DefinedFunction): nofargs = 2 def _eval_apply(self, x, n): - return factorial_simplify(_fac(x+n-1) / _fac(x-1)) + return factorial_simplify(unfac(x+n-1) / unfac(x-1)) def __latex__(self): x, n = self._args @@ -304,7 +311,7 @@ class Falling_factorial(DefinedFunction): nofargs = 2 def _eval_apply(self, x, n): - return factorial_simplify(_fac(x) / _fac(x-n)) + return factorial_simplify(unfac(x) / unfac(x-n)) def __latex__(self): x, n = self._args @@ -329,22 +336,22 @@ class Binomial2(DefinedFunction): ======== >>> from sympy import * >>> from sympy.specfun.factorials import * - >>> binomial(15,8) + >>> binomial2(15,8) 6435 >>> # Building Pascal's triangle - >>> [binomial(0,k) for k in range(1)] + >>> [binomial2(0,k) for k in range(1)] [1] - >>> [binomial(1,k) for k in range(2)] + >>> [binomial2(1,k) for k in range(2)] [1, 1] - >>> [binomial(2,k) for k in range(3)] + >>> [binomial2(2,k) for k in range(3)] [1, 2, 1] - >>> [binomial(3,k) for k in range(4)] + >>> [binomial2(3,k) for k in range(4)] [1, 3, 3, 1] >>> # n can be arbitrary if k is a positive integer - >>> binomial(Rational(5,4), 3) + >>> binomial2(Rational(5,4), 3) -5/128 >>> x = Symbol('x') - >>> binomial(x, 3) + >>> binomial2(x, 3) 1/6*x*(-2+x)*(-1+x) """ @@ -356,7 +363,7 @@ class Binomial2(DefinedFunction): if n == 0 and k != 0: return sin(pi*k)/(pi*k) - return factorial_simplify(_fac(n) / _fac(k) / _fac(n-k)) + return factorial_simplify(unfac(n) / unfac(k) / unfac(n-k)) def __latex__(self): n, k = self._args @@ -389,9 +396,9 @@ class Gamma(DefinedFunction): nofargs = 1 def _eval_apply(self, x): - y = factorial(x-1) + y = factorial_(x-1) try: - if not isinstance(y.func, Factorial): + if not isinstance(y.func, Factorial_): return y except: return y @@ -441,11 +448,11 @@ class UpperGamma(DefinedFunction): #return self -factorial = Factorial() +factorial_ = Factorial_() factorial2 = Factorial2() rising_factorial = Rising_factorial() falling_factorial = Falling_factorial() -#binomial = Binomial() +binomial2 = Binomial2() upper_gamma = UpperGamma() lower_gamma = LowerGamma() gamma = Gamma() \ No newline at end of file diff --git a/sympy/specfun/tests/test_factorials.py b/sympy/specfun/tests/test_factorials.py index 9fce64f..61a49e5 100644 --- a/sympy/specfun/tests/test_factorials.py +++ b/sympy/specfun/tests/test_factorials.py @@ -6,7 +6,7 @@ y = Symbol('y') z = Symbol('z') fs = factorial_simplify -fac = factorial +fac = factorial_ def test_factorial1(): assert [fac(t) for t in [0,1,2,3,4]] == [1,1,2,6,24] @@ -57,7 +57,7 @@ def test_factorial2(): #assert latex(factorial2(-4, evaluate=False)) == "$(-4)!!$" #assert latex(factorial2(-x, evaluate=False)) == "$(- x)!!$" -def _test_factorial_simplify(): +def test_factorial_simplify(): assert fs(fac(x+5)/fac(x+5)) == 1 assert fs(fac(x+1)/fac(x)) == 1+x assert fs(fac(x+2)/fac(x)) == (1+x)*(2+x) @@ -72,7 +72,7 @@ def _test_factorial_simplify(): assert fs(fac(x)*fac(y-2)*fac(z+2)/fac(z)/fac(y+1)) == fac(x)*(z+1)*(z+2)/(y-1)/y/(y+1) assert fs(fac(x)*fac(y+1)*fac(z+2)/fac(z)/fac(y-2)) == fac(x)*(z+1)*(z+2)*(y-1)*y*(y+1) -def _test_rising_falling(): +def test_rising_falling(): assert rising_factorial(x, 0) == 1 assert rising_factorial(x, 1) == x assert rising_factorial(x, 2) == x*(x+1) @@ -81,32 +81,32 @@ def _test_rising_falling(): assert falling_factorial(x, 1) == x assert falling_factorial(x, 2) == x*(x-1) assert falling_factorial(x, 3) == x*(x-1)*(x-2) - assert rising_factorial(1, x) == factorial(x) - assert falling_factorial(1, x) == 1/factorial(1-x) + assert rising_factorial(1, x) == fac(x) + assert falling_factorial(1, x) == 1/fac(1-x) assert falling_factorial(15, 8) == 259459200 assert falling_factorial(-3,4) == 360 assert falling_factorial(3,-4) == Rational(1,840) assert rising_factorial(15, 8) == 12893126400 assert rising_factorial(-3,4) == 0 - n = Symbol('n') - assert latex(rising_factorial(x, n, evaluate=False)) == "${(x)}^{(n)}$" - assert latex(falling_factorial(x, n, evaluate=False)) == "${(x)}_{(n)}$" + #n = Symbol('n') + #assert latex(rising_factorial(x, n, evaluate=False)) == "${(x)}^{(n)}$" + #assert latex(falling_factorial(x, n, evaluate=False)) == "${(x)}_{(n)}$" -def _test_binomial(): - assert binomial(x, 0) == 1 - assert binomial(x, x) == 1 - assert binomial(0, 0) == 1 - assert binomial(0, 1) == 0 - assert binomial(0, x) == sin(pi*x)/(pi*x) - assert binomial(x, 1) == x - assert binomial(x, 2) == Rational(1,2)*x*(x-1) - assert binomial(x, 3) == Rational(1,6)*x*(x-1)*(x-2) - assert [binomial(4,k) for k in range(5)] == [1,4,6,4,1] - assert [binomial(5,k) for k in range(6)] == [1,5,10,10,5,1] - assert sum(binomial(20, k) for k in range(21)) == 2**20 - assert binomial(10**20, 10**20 - 2) == \ +def test_binomial2(): + assert binomial2(x, 0) == 1 + assert binomial2(x, x) == 1 + assert binomial2(0, 0) == 1 + assert binomial2(0, 1) == 0 + assert binomial2(0, x) == sin(pi*x)/(pi*x) + assert binomial2(x, 1) == x + assert binomial2(x, 2) == Rational(1,2)*x*(x-1) + assert binomial2(x, 3) == Rational(1,6)*x*(x-1)*(x-2) + assert [binomial2(4,k) for k in range(5)] == [1,4,6,4,1] + assert [binomial2(5,k) for k in range(6)] == [1,5,10,10,5,1] + assert sum(binomial2(20, k) for k in range(21)) == 2**20 + assert binomial2(10**20, 10**20 - 2) == \ 4999999999999999999950000000000000000000 - assert latex(binomial(8,3,evaluate=False)) == r"${{8}\choose{3}}$" + #assert latex(binomial(8,3,evaluate=False)) == r"${{8}\choose{3}}$" def test_gamma(): assert gamma(0) == oo @@ -115,9 +115,8 @@ def test_gamma(): assert gamma(3) == 2 assert gamma(Rational(1,2)) == sqrt(pi) #assert latex(gamma(3+x)) == "$\Gamma(3+x)$" - #from sympy import simplify - #assert simplify(lower_gamma(1,x) + upper_gamma(1,x)) == gamma(1) - #assert simplify(lower_gamma(5,x) + upper_gamma(5,x)) == gamma(5) + assert lower_gamma(1,x) + upper_gamma(1,x) == gamma(1) + assert lower_gamma(5,x) + upper_gamma(5,x) == gamma(5) def _test_derivatives(): x = Symbol('x') diff --git a/sympy/specfun/tests/test_specfun.py b/sympy/specfun/tests/test_specfun.py index 5e98349..64a87f4 100644 --- a/sympy/specfun/tests/test_specfun.py +++ b/sympy/specfun/tests/test_specfun.py @@ -1,6 +1,4 @@ from sympy.specfun import * def test_import(): - pass - #isn't working yet - assert factorial(3) == 6 + assert factorial_(3) == 6 -- 2.11.4.GIT