Added license info into the .py files.
[golden_search.git] / operators.py
blobd98a791863e4181318002d9ba6a0da404f009b16
1 # For licensing info see the included LICENSE file
3 # Josef Moudrik, <J dot Moudrik at standard google mail ending>, 2012
5 from math import sqrt
7 import number
8 from exc import IncompatibleUnits
# Module-level operator registries (three distinct, initially empty lists).
# OPERATORS_ALL is not filled anywhere in the visible part of this module.
# OP_CLASSES collects every class decorated with @register_operator.
# OPERATORS_ARITHMETIC is reassigned at the bottom of the module.
OPERATORS_ALL, OP_CLASSES, OPERATORS_ARITHMETIC = [], [], []


def register_operator(cl):
    """Class decorator: record ``cl`` in OP_CLASSES and hand it back unchanged."""
    registry = OP_CLASSES
    registry.append(cl)
    return cl
class GenericOp(object):
    """Base class for operators that combine numbers into a new number.Number.

    Calling an instance with a sequence of operands does the following:
    - collects each operand's value, squared standard deviation and units;
    - passes those three lists to computeNewNumber();
    - wraps the result in a number.Number named after the operator, with the
      operands recorded as its parents.

    Subclasses must implement computeNewNumber().
    """

    def __init__(self, arity, commutativity, name):
        """Store the operator's configuration.

        :param arity: number of operands the operator expects.
        :param commutativity: True if operand order does not matter.
        :param name: passed as ``string`` to the resulting number.Number.
        """
        self.arity = arity
        self.commutativity = commutativity
        self.name = name

    def __call__(self, args):
        """Apply the operator to ``args`` and return a new number.Number.

        :param args: sequence of exactly getArity() objects providing
            getValue(), getSdSquare() and getUnits().
        :raises ValueError: if ``len(args)`` differs from the operator's arity.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and the mismatch would then surface later as an
        # obscure unpacking error inside computeNewNumber().
        if len(args) != self.getArity():
            raise ValueError("%s expects %d argument(s), got %d"
                             % (self.name, self.getArity(), len(args)))

        # Comprehensions give real lists on both Python 2 and 3
        # (map() is a lazy iterator on Python 3).
        values = [n.getValue() for n in args]
        sdsqs = [n.getSdSquare() for n in args]
        units = [n.getUnits() for n in args]

        new_value, new_sdsq, new_units = self.computeNewNumber(values, sdsqs, units)
        return number.Number(new_value, sdsq=new_sdsq, string=self.name,
                             units=new_units, parents=args)

    def computeNewNumber(self, values, sdsqs, units):
        """Combine the operand data; must be overridden by subclasses.

        :param values: operand values, one per argument.
        :param sdsqs: operand squared standard deviations, one per argument.
        :param units: operand units, one per argument.
        :returns: a ``(new_value, new_sdsq, new_units)`` triple.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def getArity(self):
        """Return the number of operands this operator takes."""
        return self.arity

    def isCommutative(self):
        """Return True if the operand order does not affect the result."""
        return self.commutativity
46 # Arithmetic Operators
@register_operator
class Plus(GenericOp):
    """Binary addition n1 + n2; both operands must carry identical units."""

    def __init__(self):
        super(Plus, self).__init__(2, True, "Plus")

    def computeNewNumber(self, values, sdsqs, units):
        """Add two operands and propagate their uncertainty.

        The result's squared standard deviation is the sum of the operands'
        (presumably assuming independent operands).

        :returns: ``(n1 + n2, sdsq1 + sdsq2, u1)``.
        :raises IncompatibleUnits: if the two operands' units differ.
        """
        # Unpack in the body: tuple parameters in the signature were removed
        # in Python 3 (PEP 3113). Wrong lengths still raise ValueError.
        n1v, n2v = values
        n1sdsq, n2sdsq = sdsqs
        u1, u2 = units
        if u1 != u2:
            raise IncompatibleUnits
        new_value = n1v + n2v
        new_sdsq = n1sdsq + n2sdsq
        return new_value, new_sdsq, u1
@register_operator
class Minus(GenericOp):
    """Binary subtraction n1 - n2; both operands must carry identical units."""

    def __init__(self):
        super(Minus, self).__init__(2, False, "Minus")

    def computeNewNumber(self, values, sdsqs, units):
        """Subtract the second operand from the first and propagate uncertainty.

        Squared standard deviations add, exactly as for Plus.

        :returns: ``(n1 - n2, sdsq1 + sdsq2, u1)``.
        :raises IncompatibleUnits: if the two operands' units differ.
        """
        # Unpack in the body: tuple parameters in the signature were removed
        # in Python 3 (PEP 3113). Wrong lengths still raise ValueError.
        n1v, n2v = values
        n1sdsq, n2sdsq = sdsqs
        u1, u2 = units
        if u1 != u2:
            raise IncompatibleUnits
        new_value = n1v - n2v
        new_sdsq = n1sdsq + n2sdsq
        return new_value, new_sdsq, u1
@register_operator
class Mult(GenericOp):
    """Binary multiplication n1 * n2; units are combined via number.units_join."""

    def __init__(self):
        super(Mult, self).__init__(2, True, "Mult")

    def computeNewNumber(self, values, sdsqs, units):
        """Multiply two operands and propagate their uncertainty.

        The squared standard deviation is ``n1**2 * sdsq2 + n2**2 * sdsq1``,
        the first-order propagation formula for a product (presumably
        assuming independent operands).

        :returns: ``(n1 * n2, sdsq, number.units_join(u1, u2))``.
        """
        # Unpack in the body: tuple parameters in the signature were removed
        # in Python 3 (PEP 3113). Wrong lengths still raise ValueError.
        n1v, n2v = values
        n1sdsq, n2sdsq = sdsqs
        u1, u2 = units
        new_value = n1v * n2v
        new_sdsq = n1v ** 2 * n2sdsq + n2v ** 2 * n1sdsq
        return new_value, new_sdsq, number.units_join(u1, u2)
@register_operator
class Div(GenericOp):
    """Binary division n1 / n2; units are combined via number.units_diff."""

    def __init__(self):
        super(Div, self).__init__(2, False, "Div")

    def computeNewNumber(self, values, sdsqs, units):
        """Divide the first operand by the second and propagate uncertainty.

        The squared standard deviation is
        ``sdsq1 / n2**2 + n1**2 * sdsq2 / n2**4``, the first-order propagation
        formula for a quotient (presumably assuming independent operands).

        :returns: ``(n1 / n2, sdsq, number.units_diff(u1, u2))``.
        :raises ZeroDivisionError: if ``n2`` is zero.
        """
        from operator import truediv

        # Unpack in the body: tuple parameters in the signature were removed
        # in Python 3 (PEP 3113). Wrong lengths still raise ValueError.
        n1v, n2v = values
        n1sdsq, n2sdsq = sdsqs
        u1, u2 = units
        # truediv instead of "/": on Python 2, "/" between two ints floors
        # (7 / 2 == 3). truediv always performs true division and still
        # defers to __truediv__ for types such as Fraction or Decimal.
        new_value = truediv(n1v, n2v)
        new_sdsq = n2v ** (-2) * n1sdsq + n1v ** 2 * n2v ** (-4) * n2sdsq
        return new_value, new_sdsq, number.units_diff(u1, u2)
94 # Numeric operators
96 # e.g. op(number) -> 2 * number
class Times(GenericOp):
    """Unary operator scaling a number by a fixed factor: op(x) -> num * x.

    It is not decorated with @register_operator. Presumably this is because
    the module-level instantiation calls registered classes without
    arguments, while Times needs ``num``.
    """

    def __init__(self, num):
        """:param num: the scale factor; str(num) becomes the operator name."""
        super(Times, self).__init__(1, True, str(num))
        self.num = num

    def computeNewNumber(self, values, sdsqs, units):
        """Scale the single operand by ``self.num``.

        :returns: ``(num * v, num**2 * sdsq, u)``; units are unchanged.
        """
        # Unpack in the body: tuple parameters in the signature were removed
        # in Python 3 (PEP 3113). Wrong lengths still raise ValueError.
        (n1v,) = values
        (n1sdsq,) = sdsqs
        (u1,) = units
        new_value = self.num * n1v
        # Scaling by a constant c multiplies the variance by c**2.
        new_sdsq = self.num ** 2 * n1sdsq
        return new_value, new_sdsq, u1
110 # Unit Conversions
# TODO: unit-conversion operators are not implemented yet.
116 # All together
# Instantiate each registered operator class once. The result is
# materialised as a list because map() returns a one-shot iterator on
# Python 3, which would leave the registry empty after its first traversal.
# A generator expression, rather than a list comprehension, avoids leaking
# the loop variable into the module namespace on Python 2.
OPERATORS_ARITHMETIC = list(op_class() for op_class in OP_CLASSES)