From 60adf4f659e68b87d04517b5ede243c90aa07ba7 Mon Sep 17 00:00:00 2001 From: Ondrej Certik Date: Mon, 7 Jan 2008 00:35:20 +0100 Subject: [PATCH] Lambda support works now (#427) Currently only one variable lambdas work. Example: In [1]: f=Lambda(x, x**2) In [2]: f(4) Out[2]: 16 --- sympy/__init__.py | 1 + sympy/core/ast_parser.py | 4 +- sympy/core/function.py | 176 +++++++++++-------------------------- sympy/core/tests/test_functions.py | 19 +++- sympy/core/tests/test_sympify.py | 5 ++ 5 files changed, 75 insertions(+), 130 deletions(-) diff --git a/sympy/__init__.py b/sympy/__init__.py index 033a695..156e00c 100644 --- a/sympy/__init__.py +++ b/sympy/__init__.py @@ -18,6 +18,7 @@ import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), "thirdparty", \ "pyglet")) +import symbol as stdlib_symbol from sympy.core import * from series import * diff --git a/sympy/core/ast_parser.py b/sympy/core/ast_parser.py index c036046..cb1b22a 100644 --- a/sympy/core/ast_parser.py +++ b/sympy/core/ast_parser.py @@ -61,8 +61,8 @@ class SymPyTransformer(Transformer): def lambdef(self, nodelist): #this is never executed #this is python stdlib symbol, not SymPy symbol: - from symbol import varargslist - if nodelist[2][0] == varargslist: + from sympy import stdlib_symbol + if nodelist[2][0] == stdlib_symbol.varargslist: names, defaults, flags = self.com_arglist(nodelist[2][1:]) else: names = defaults = () diff --git a/sympy/core/function.py b/sympy/core/function.py index 2a04749..b657584 100644 --- a/sympy/core/function.py +++ b/sympy/core/function.py @@ -386,119 +386,6 @@ class WildFunction(Function, Atom): def _eval_apply_evalf(cls, arg): return -class Lambda(Function): - """ - Lambda(expr, arg1, arg2, ...) -> lambda arg1, arg2,... : expr - - Lambda instance has the same assumptions as its body. - - """ - precedence = Basic.Lambda_precedence - name = None - has_derivative = True - - def __new__(cls, expr, *args): - expr = Basic.sympify(expr) - args = tuple(map(Basic.sympify, args)) - # XXX - #if isinstance(expr, Apply): - # if expr[:]==args: - # return expr.func - dummy_args = [] - for a in args: - if not isinstance(a, Basic.Symbol): - raise TypeError("%s %s-th argument must be Symbol instance (got %r)" \ - % (cls.__name__, len(dummy_args)+1,a)) - d = a.as_dummy() - expr = expr.subs(a, d) - dummy_args.append(d) - obj = Basic.__new__(cls, expr, *dummy_args, **expr._assumptions) - return obj - - def _hashable_content(self): - return self._args - - @property - def nargs(self): - return len(self._args)-1 - - def __getitem__(self, iter): - return self._args[1:][iter] - - def __len__(self): - return len(self[:]) - - @property - def body(self): - return self._args[0] - - def tostr(self, level=0): - precedence = self.precedence - r = 'lambda %s: %s' % (', '.join([a.tostr() for a in self]), - self.body.tostr(precedence)) - if precedence <= level: - return '(%s)' % r - return r - - def torepr(self): - return '%s(%s)' % (self.__class__.__name__, ', '.join([a.torepr() for a in self])) - - def as_coeff_terms(self, x=None): - c,t = self.body.as_coeff_terms(x) - return c, [Lambda(Basic.Mul(*t),*self[:])] - - def _eval_power(b, e): - """ - (lambda x:f(x))**e -> (lambda x:f(x)**e) - """ - return Lambda(b.body**e, *b[:]) - - def _eval_fpower(b, e): - """ - FPow(lambda x:f(x), 2) -> lambda x:f(f(x))) - """ - if isinstance(e, Basic.Integer) and e.is_positive and e.p < 10 and len(b)==1: - r = b.body - for i in xrange(e.p-1): - r = b(r) - return Lambda(r, *b[:]) - - def with_dummy_arguments(self, args = None): - if args is None: - args = tuple([a.as_dummy() for a in self]) - if len(args) != len(self): - raise TypeError("different number of arguments in Lambda functions: %s, %s" % (len(args), len(self))) - expr = self.body - for a,na in zip(self, args): - expr = expr.subs(a, na) - return expr, args - - def _eval_expand_basic(self, *args): - return Lambda(self.body._eval_expand_basic(*args), *self[:]) - - def diff(self, *symbols): - return Lambda(self.body.diff(*symbols), *self[:]) - - def fdiff(self, argindex=1): - if not (1<=argindex<=len(self)): - raise TypeError("%s.fderivative() argindex %r not in the range [1,%s]"\ - % (self.__class__, argindex, len(self))) - s = self[argindex-1] - expr = self.body.diff(s) - return Lambda(expr, *self[:]) - - _eval_subs = Basic._seq_subs - - def canonize(cls, *args): - n = cls.nargs - if n!=len(args): - raise TypeError('%s takes exactly %s arguments (got %s)'\ - % (cls, n, len(args))) - expr = cls.body - for da,a in zip(cls, args): - expr = expr.subs(da,a) - return expr - class Derivative(Basic, ArithMeths, RelMeths): """ Carries out differentation of the given expression with respect to symbols. @@ -536,20 +423,6 @@ class Derivative(Basic, ArithMeths, RelMeths): return expr return Basic.__new__(cls, expr, *unevaluated_symbols) - # FIXME is this needed - def xas_apply(self): - # Derivative(f(x),x) -> Apply(Lambda(f(_x),_x), x) - symbols = [] - indices = [] - for s in self.symbols: - if s not in symbols: - symbols.append(s) - indices.append(len(symbols)) - else: - indices.append(symbols.index(s)+1) - stop - return Apply(FApply(FDerivative(*indices), Lambda(self.expr, *symbols)), *symbols) - def _eval_derivative(self, s): #print #print self @@ -608,6 +481,55 @@ class Derivative(Basic, ArithMeths, RelMeths): repl_dict[pattern] = expr return repl_dict +class Lambda(Function): + """ + Lambda(x, expr) represents a lambda function similar to Python's + 'lambda x: expr'. A function of several variables is written as + Lambda((x, y, ...), expr). + + A simple example: + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> f = Lambda(x, x**2) + >>> f(4) + 16 + """ + + #XXX currently only one argument Lambda is supported + + nargs = 2 + + @classmethod + def canonize(cls, x, expr): + obj = Basic.__new__(cls, x, expr) + #use dummy variables internally, just to be sure + tmp = Basic.Symbol("x", dummy=True) + obj._args = (tmp, expr.subs(x, tmp)) + return obj + + def apply(self, x): + """Applies the Lambda function "self" to the "x" + + Example: + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> f = Lambda(x, x**2) + >>> f.apply(4) + 16 + + """ + return self[1].subs(self[0], x) + + def __call__(self, *args): + return self.apply(*args) + + def __eq__(self, other): + if isinstance(other, Lambda): + if self[1] == other[1].subs(other[0], self[0]): + return True + return False + + def diff(f, x, times = 1, evaluate=True): """Differentiate f with respect to x diff --git a/sympy/core/tests/test_functions.py b/sympy/core/tests/test_functions.py index c142fd7..4c5c292 100644 --- a/sympy/core/tests/test_functions.py +++ b/sympy/core/tests/test_functions.py @@ -1,6 +1,8 @@ -from sympy import * +from sympy import Lambda, Symbol, Function, WildFunction, Derivative, sqrt, \ + log, exp, Rational, sign, Basic from sympy.utilities.pytest import XFAIL from sympy.utilities.test import REPR0 +from sympy.abc import x, y def test_log(): @@ -157,3 +159,18 @@ def test_unapplied_function_str(): assert repr(f) == "Function('f')" # this does not work assert str(f) == "f" # this does not work + +def test_Lambda(): + e = Lambda(x, x**2) + assert e(4) == 16 + assert e(x) == x**2 + assert e(y) == y**2 + + assert Lambda(x, x**2) == Lambda(x, x**2) + assert Lambda(x, x**2) == Lambda(y, y**2) + assert Lambda(x, x**2) != Lambda(y, y**2+1) + + #doesn't work yet: + #class F(Function): + # pass + #assert Lambda(x, F(x)) == F diff --git a/sympy/core/tests/test_sympify.py b/sympy/core/tests/test_sympify.py index bce74f1..6126d73 100644 --- a/sympy/core/tests/test_sympify.py +++ b/sympy/core/tests/test_sympify.py @@ -44,3 +44,8 @@ def test_sage(): def test_bug496(): a_ = sympify("a_") _a = sympify("_a") + +#def test_lambda(): +# x = Symbol('x') +# assert sympify('lambda : 1')==Lambda(x, 1) +# assert sympify('lambda x: 2*x')==Lambda(x, 2*x) -- 2.11.4.GIT