Move invtrig tests from test_functions.py to test_trigonometric.py
[sympy.git] / sympy / core / pattern_tools.py
blob73ca81f54d41a28e7246ae89c075fdbd85316c6e
1 """
2 Tools for constructing patterns.
4 """
6 import re
class Pattern:
    """
    A composable regular-expression pattern paired with a BNF-like label.

    Operator summary::

        p1 | p2          -> <p1> | <p2>
        p1 + p2          -> <p1> <p2>        (optional whitespace in between)
        p1 & p2          -> <p1><p2>         (no whitespace in between)
        ~p1              -> [ <p1> ]
        ~~p1             -> [ <p1> ]...
        ~~~p1            -> <p1> [ <p1> ]...
        ~~~~p1           -> ~~~p1
        abs(p1)          -> whole string match of <p1>
        p1.named(name)   -> match of <p1> has name
        p1.match(string) -> return string match with <p1>
        p1.flags(<re.I,..>)
        p1.rsplit(..)    -> split a string from the rightmost p1 occurrence
        p1.lsplit(..)    -> split a string from the leftmost p1 occurrence
    """

    # Single-character string operands of ``+`` are looked up here so that
    # regex metacharacters are matched literally.
    _special_symbol_map = {'.': '[.]',
                           '*': '[*]',
                           '+': '[+]',
                           '|': '[|]',
                           '(': r'\(',
                           ')': r'\)',
                           '[': r'\[',
                           ']': r'\]',
                           '^': '[^]',
                           '$': '[$]',
                           '?': '[?]',
                           '{': r'\{',
                           '}': r'\}',
                           '>': '[>]',
                           '<': '[<]',
                           '=': '[=]'
                           }

    def __init__(self, label, pattern, optional=0, flags=0, value=None):
        self.label = label        # human-readable BNF-like description
        self.pattern = pattern    # regular expression source string
        # 0 = plain; 1, 2, 3 = result of applying ~ once, twice, three times
        self.optional = optional
        self._flags = flags       # re flags passed to re.compile
        self.value = value        # if not None, returned by __call__ on a match

    def flags(self, *flags):
        """Return a copy of this pattern with ``flags`` OR-ed into its re flags."""
        f = self._flags
        for f1 in flags:
            f = f | f1
        return Pattern(self.label, self.pattern, optional=self.optional, flags=f, value=self.value)

    def get_compiled(self):
        """Return the compiled regex, compiling and caching it on first use."""
        try:
            return self._compiled_pattern
        except AttributeError:
            self._compiled_pattern = compiled = re.compile(self.pattern, self._flags)
            return compiled

    def match(self, string):
        """Match the pattern at the beginning of ``string`` (``re.match`` semantics)."""
        return self.get_compiled().match(string)

    def search(self, string):
        """Search for the pattern anywhere in ``string`` (``re.search`` semantics)."""
        return self.get_compiled().search(string)

    def rsplit(self, string):
        """
        Return (<lhs>, <pattern_match>, <rhs>) where
        string = lhs + pattern_match + rhs (modulo surrounding whitespace,
        which is stripped from each part) and rhs does not contain
        pattern_match, i.e. the split happens at the rightmost occurrence.

        Return None if no pattern_match is found in string, or if re.split
        produces an empty separator/middle piece (e.g. adjacent matches).

        The split relies on ``re.split`` returning the separator, so the
        pattern is expected to have exactly one capturing group (as produced
        by ``named()``).
        """
        compiled = self.get_compiled()
        t = compiled.split(string)
        if len(t) < 3:
            return
        if '' in t[1:-1]:
            return
        rhs = t[-1].strip()
        pattern_match = t[-2].strip()
        assert abs(self).match(pattern_match), repr((self, string, t, pattern_match))
        lhs = (''.join(t[:-2])).strip()
        return lhs, pattern_match, rhs

    def lsplit(self, string):
        """
        Return (<lhs>, <pattern_match>, <rhs>) where
        string = lhs + pattern_match + rhs (modulo surrounding whitespace,
        which is stripped from each part) and lhs does not contain
        pattern_match, i.e. the split happens at the leftmost occurrence.

        Return None if no pattern_match is found in string.

        Like ``rsplit``, the pattern is expected to have exactly one
        capturing group (as produced by ``named()``).
        """
        compiled = self.get_compiled()
        t = compiled.split(string)  # can be optimized
        if len(t) < 3:
            return
        lhs = t[0].strip()
        pattern_match = t[1].strip()
        rhs = (''.join(t[2:])).strip()
        assert abs(self).match(pattern_match), repr(pattern_match)
        return lhs, pattern_match, rhs

    def __abs__(self):
        """Return a pattern that must match the whole string.

        The original pattern is wrapped in a non-capturing group so that the
        anchors apply to every top-level alternative (without it, ``A|B``
        would become ``\\AA|B\\Z``, which accepts e.g. ``'AX'``).
        """
        return Pattern(self.label, r'\A(?:' + self.pattern + r')\Z', flags=self._flags, value=self.value)

    def __repr__(self):
        return '%s(%r, %r)' % (self.__class__.__name__, self.label, self.pattern)

    def __or__(self, other):
        """Alternation of two patterns (identical patterns are not duplicated)."""
        label = '( %s OR %s )' % (self.label, other.label)
        if self.pattern == other.pattern:
            pattern = self.pattern
            flags = self._flags
        else:
            pattern = '(%s|%s)' % (self.pattern, other.pattern)
            flags = self._flags | other._flags
        return Pattern(label, pattern, flags=flags)

    def __and__(self, other):
        """Concatenation with no whitespace allowed in between.

        ``other`` may be a Pattern or a raw regex string (used verbatim).
        """
        if isinstance(other, Pattern):
            label = '%s%s' % (self.label, other.label)
            pattern = self.pattern + other.pattern
            flags = self._flags | other._flags
        else:
            assert isinstance(other, str), repr(other)
            label = '%s%s' % (self.label, other)
            pattern = self.pattern + other
            flags = self._flags
        return Pattern(label, pattern, flags=flags)

    def __rand__(self, other):
        """``str & Pattern``: prepend a raw regex string with no whitespace."""
        assert isinstance(other, str), repr(other)
        label = '%s%s' % (other, self.label)
        pattern = other + self.pattern
        return Pattern(label, pattern, flags=self._flags)

    def __invert__(self):
        """Make the pattern optional; repeated application yields repetition.

        ~p -> ``(p)?``, ~~p -> ``(p)*``, ~~~p -> ``(p)+``; further ~ is a no-op.
        """
        if self.optional:
            if self.optional == 1:
                return Pattern(self.label + '...', self.pattern[:-1] + '*', optional=2, flags=self._flags)
            if self.optional == 2:
                # label is '[ X ]...'; [1:-4] recovers ' X ' for 'X [ X ]...'
                return Pattern('%s %s' % (self.label[1:-4].strip(), self.label), self.pattern[:-1] + '+',
                               optional=3, flags=self._flags)
            return self
        label = '[ %s ]' % (self.label)
        pattern = '(%s)?' % (self.pattern)
        return Pattern(label, pattern, optional=1, flags=self._flags)

    def __add__(self, other):
        """Concatenation allowing optional whitespace (``\\s*``) in between.

        A str operand found in ``_special_symbol_map`` is escaped first.
        """
        if isinstance(other, Pattern):
            label = '%s %s' % (self.label, other.label)
            pattern = self.pattern + r'\s*' + other.pattern
            flags = self._flags | other._flags
        else:
            assert isinstance(other, str), repr(other)
            label = '%s %s' % (self.label, other)
            other = self._special_symbol_map.get(other, other)
            pattern = self.pattern + r'\s*' + other
            flags = self._flags
        return Pattern(label, pattern, flags=flags)

    def __radd__(self, other):
        """``str + Pattern``: like ``__add__`` with the string on the left."""
        assert isinstance(other, str), repr(other)
        label = '%s %s' % (other, self.label)
        other = self._special_symbol_map.get(other, other)
        pattern = other + r'\s*' + self.pattern
        return Pattern(label, pattern, flags=self._flags)

    def named(self, name=None):
        """Wrap the pattern in a named group.

        With no ``name``, the current label (which must look like '<...>'
        and contain no spaces) is used. Hyphens become underscores because
        regex group names must be identifiers.
        """
        if name is None:
            label = self.label
            assert label[0] + label[-1] == '<>' and ' ' not in label, repr(label)
        else:
            label = '<%s>' % (name)
        pattern = '(?P%s%s)' % (label.replace('-', '_'), self.pattern)
        return Pattern(label, pattern, flags=self._flags, value=self.value)

    def rename(self, label):
        """Return a copy with a new label (wrapped in '<...>' if needed)."""
        if label[0] + label[-1] != '<>':
            label = '<%s>' % (label)
        return Pattern(label, self.pattern, optional=self.optional, flags=self._flags, value=self.value)

    def __call__(self, string):
        """Return ``self.value`` (if set) or the matched text; None if no match."""
        m = self.match(string)
        if m is None:
            return
        if self.value is not None:
            return self.value
        return m.group()
# Predefined patterns
#
# The labels below look like Fortran syntax-rule names (kind-param,
# boz-literal-constant, char-literal-constant, ...); each pattern is built
# from the ones defined before it, so the definition order matters.

# --- single characters and digit strings ---
letter = Pattern('<letter>', '[A-Z]', flags=re.I)
name = Pattern('<name>', r'[A-Z]\w*', flags=re.I)
digit = Pattern('<digit>', r'\d')
underscore = Pattern('<underscore>', '_')
binary_digit = Pattern('<binary-digit>', r'[01]')
octal_digit = Pattern('<octal-digit>', r'[0-7]')
hex_digit = Pattern('<hex-digit>', r'[\dA-F]', flags=re.I)

digit_string = Pattern('<digit-string>', r'\d+')
binary_digit_string = Pattern('<binary-digit-string>', r'[01]+')
octal_digit_string = Pattern('<octal-digit-string>', r'[0-7]+')
hex_digit_string = Pattern('<hex-digit-string>', r'[\dA-F]+', flags=re.I)

sign = Pattern('<sign>', r'[+-]')
exponent_letter = Pattern('<exponent-letter>', r'[ED]', flags=re.I)

alphanumeric_character = Pattern('<alphanumeric-character>', r'\w')  # [A-Z0-9_]
# The '-' is escaped on purpose: unescaped, '+-*' is parsed as a reversed
# character range and re.compile raises "bad character range".
special_character = Pattern('<special-character>', r'[ =+\-*/\()[\]{},.:;!"%&~<>?,\'`^|$#@]')
character = alphanumeric_character | special_character

# --- integer and BOZ literal constants ---
kind_param = digit_string | name
kind_param_named = kind_param.named('kind-param')
signed_digit_string = ~sign + digit_string
int_literal_constant = digit_string + ~('_' + kind_param)
signed_int_literal_constant = ~sign + int_literal_constant
int_literal_constant_named = digit_string.named('value') + ~('_' + kind_param_named)
signed_int_literal_constant_named = (~sign + digit_string).named('value') + ~('_' + kind_param_named)

binary_constant = ('B' + ("'" & binary_digit_string & "'" | '"' & binary_digit_string & '"')).flags(re.I)
octal_constant = ('O' + ("'" & octal_digit_string & "'" | '"' & octal_digit_string & '"')).flags(re.I)
hex_constant = ('Z' + ("'" & hex_digit_string & "'" | '"' & hex_digit_string & '"')).flags(re.I)
boz_literal_constant = binary_constant | octal_constant | hex_constant

# --- real and complex literal constants ---
exponent = signed_digit_string
significand = digit_string + '.' + ~digit_string | '.' + digit_string
real_literal_constant = significand + ~(exponent_letter + exponent) + ~('_' + kind_param) | \
    digit_string + exponent_letter + exponent + ~('_' + kind_param)
real_literal_constant_named = (significand + ~(exponent_letter + exponent) |
                               digit_string + exponent_letter + exponent).named('value') + ~('_' + kind_param_named)
signed_real_literal_constant_named = (~sign + (significand + ~(exponent_letter + exponent) |
                                               digit_string + exponent_letter + exponent)).named('value') + ~('_' + kind_param_named)
signed_real_literal_constant = ~sign + real_literal_constant

named_constant = name
real_part = signed_int_literal_constant | signed_real_literal_constant | named_constant
imag_part = real_part
complex_literal_constant = '(' + real_part + ',' + imag_part + ')'

# --- character and logical literal constants ---
a_n_rep_char = Pattern('<alpha-numeric-rep-char>', r'\w')
rep_char = Pattern('<rep-char>', r'.')
char_literal_constant = ~(kind_param + '_') + ("'" + ~~rep_char + "'" | '"' + ~~rep_char + '"')
a_n_char_literal_constant_named1 = ~(kind_param_named + '_') + (~~~("'" + ~~a_n_rep_char + "'")).named('value')
a_n_char_literal_constant_named2 = ~(kind_param_named + '_') + (~~~('"' + ~~a_n_rep_char + '"')).named('value')

logical_literal_constant = ('[.](TRUE|FALSE)[.]' + ~('_' + kind_param)).flags(re.I)
logical_literal_constant_named = Pattern('<value>', r'[.](TRUE|FALSE)[.]', flags=re.I).named() + ~('_' + kind_param_named)

# --- aggregates ---
literal_constant = int_literal_constant | real_literal_constant | complex_literal_constant | logical_literal_constant | char_literal_constant | boz_literal_constant
constant = literal_constant | named_constant
int_constant = int_literal_constant | boz_literal_constant | named_constant
char_constant = char_literal_constant | named_constant

# assume that replace_string_map is applied:
part_ref = name + ~((r'[(]' + name + r'[)]'))
data_ref = part_ref + ~~~(r'[%]' + part_ref)
primary = constant | name | data_ref | (r'[(]' + name + r'[)]')
# --- operators ---
# Look-arounds keep '*' from matching half of '**' and '/' half of '//'.
power_op = Pattern('<power-op>', r'(?<![*])[*]{2}(?![*])')
mult_op = Pattern('<mult-op>', r'(?<![*])[*](?![*])|(?<![/])[/](?![/])')
add_op = Pattern('<add-op>', r'[+-]')
concat_op = Pattern('<concat-op>', r'(?<![/])[/]{2}(?![/])')
# Two-character relational operators are listed before their one-character
# prefixes so alternation prefers the longer token.
rel_op = Pattern('<rel-op>', '[.]EQ[.]|[.]NE[.]|[.]LT[.]|[.]LE[.]|[.]GT[.]|[.]GE[.]|[=]{2}|/[=]|[<][=]|[<]|[>][=]|[>]', flags=re.I)
not_op = Pattern('<not-op>', '[.]NOT[.]', flags=re.I)
and_op = Pattern('<and-op>', '[.]AND[.]', flags=re.I)
or_op = Pattern('<or-op>', '[.]OR[.]', flags=re.I)
equiv_op = Pattern('<equiv-op>', '[.]EQV[.]|[.]NEQV[.]', flags=re.I)
percent_op = Pattern('<percent-op>', r'%', flags=re.I)
intrinsic_operator = power_op | mult_op | add_op | concat_op | rel_op | not_op | and_op | or_op | equiv_op
extended_intrinsic_operator = intrinsic_operator

defined_unary_op = Pattern('<defined-unary-op>', '[.][A-Z]+[.]', flags=re.I)
defined_binary_op = Pattern('<defined-binary-op>', '[.][A-Z]+[.]', flags=re.I)
defined_operator = defined_unary_op | defined_binary_op | extended_intrinsic_operator
abs_defined_operator = abs(defined_operator)
defined_op = Pattern('<defined-op>', '[.][A-Z]+[.]', flags=re.I)
abs_defined_op = abs(defined_op)

non_defined_binary_op = intrinsic_operator | logical_literal_constant

# --- statement labels and keywords ---
# Raw string: '\d' in a plain literal is an invalid escape sequence.
label = Pattern('<label>', r'\d{1,5}')
abs_label = abs(label)

keyword = name
keyword_equal = keyword + '='

# --- whole-string (anchored) variants of the patterns above ---
abs_constant = abs(constant)
abs_literal_constant = abs(literal_constant)
abs_int_literal_constant = abs(int_literal_constant)
abs_signed_int_literal_constant = abs(signed_int_literal_constant)
abs_signed_int_literal_constant_named = abs(signed_int_literal_constant_named)
abs_int_literal_constant_named = abs(int_literal_constant_named)
abs_real_literal_constant = abs(real_literal_constant)
abs_signed_real_literal_constant = abs(signed_real_literal_constant)
abs_signed_real_literal_constant_named = abs(signed_real_literal_constant_named)
abs_real_literal_constant_named = abs(real_literal_constant_named)
abs_complex_literal_constant = abs(complex_literal_constant)
abs_logical_literal_constant = abs(logical_literal_constant)
abs_char_literal_constant = abs(char_literal_constant)
abs_boz_literal_constant = abs(boz_literal_constant)
abs_name = abs(name)
abs_a_n_char_literal_constant_named1 = abs(a_n_char_literal_constant_named1)
abs_a_n_char_literal_constant_named2 = abs(a_n_char_literal_constant_named2)
abs_logical_literal_constant_named = abs(logical_literal_constant_named)
abs_binary_constant = abs(binary_constant)
abs_octal_constant = abs(octal_constant)
abs_hex_constant = abs(hex_constant)

# --- type names and statement keywords ---
intrinsic_type_name = Pattern('<intrinsic-type-name>', r'(INTEGER|REAL|COMPLEX|LOGICAL|CHARACTER|DOUBLE\s*COMPLEX|DOUBLE\s*PRECISION|BYTE)', flags=re.I)
abs_intrinsic_type_name = abs(intrinsic_type_name)
# value= normalizes the spelling returned by Pattern.__call__ regardless of
# the whitespace/case in the source. Raw strings: '\s' is an invalid escape
# sequence in a plain literal.
double_complex_name = Pattern('<double-complex-name>', r'DOUBLE\s*COMPLEX', flags=re.I, value='DOUBLE COMPLEX')
double_precision_name = Pattern('<double-precision-name>', r'DOUBLE\s*PRECISION', flags=re.I, value='DOUBLE PRECISION')
abs_double_complex_name = abs(double_complex_name)
abs_double_precision_name = abs(double_precision_name)

access_spec = Pattern('<access-spec>', r'PUBLIC|PRIVATE', flags=re.I)
abs_access_spec = abs(access_spec)

implicit_none = Pattern('<implicit-none>', r'IMPLICIT\s*NONE', flags=re.I, value='IMPLICIT NONE')
abs_implicit_none = abs(implicit_none)

attr_spec = Pattern('<attr-spec>', r'ALLOCATABLE|ASYNCHRONOUS|EXTERNAL|INTENT|INTRINSIC|OPTIONAL|PARAMETER|POINTER|PROTECTED|SAVE|TARGET|VALUE|VOLATILE', flags=re.I)
abs_attr_spec = abs(attr_spec)

dimension = Pattern('<dimension>', r'DIMENSION', flags=re.I)
abs_dimension = abs(dimension)

intent = Pattern('<intent>', r'INTENT', flags=re.I)
abs_intent = abs(intent)

# INOUT is listed before IN so the longer keyword is tried first.
intent_spec = Pattern('<intent-spec>', r'INOUT|IN|OUT', flags=re.I)
abs_intent_spec = abs(intent_spec)

subroutine = Pattern('<subroutine>', r'SUBROUTINE', flags=re.I)

select_case = Pattern('<select-case>', r'SELECT\s*CASE', flags=re.I, value='SELECT CASE')
abs_select_case = abs(select_case)
def _test():
    """Self-check for the predefined patterns; prints 'ok' when all pass."""
    assert name.match('a1_a')
    assert abs(name).match('a1_a')
    assert not abs(name).match('a1_a[]')

    m = abs(kind_param)
    assert m.match('23')
    assert m.match('SHORT')

    m = abs(signed_digit_string)
    assert m.match('23')
    assert m.match('+ 23')
    assert m.match('- 23')
    assert m.match('-23')
    assert not m.match('+n')

    m = ~sign.named() + digit_string.named('number')
    r = m.match('23')
    assert r.groupdict() == {'number': '23', 'sign': None}
    r = m.match('- 23')
    assert r.groupdict() == {'number': '23', 'sign': '-'}

    m = abs(char_literal_constant)
    assert m.match('"adadfa"')
    assert m.match('"adadfa""adad"')
    assert m.match('HEY_"adadfa"')
    assert m.match('HEY _ "ad\tadfa"')
    assert not m.match('adadfa')

    def assert_equal(result, expect):
        # Re-raise with both values in the message to make failures readable.
        try:
            assert result == expect
        except AssertionError as msg:
            raise AssertionError("Expected %r but got %r: %s"
                                 % (expect, result, msg))

    m = mult_op.named()
    assert m.rsplit('a * b')
    assert_equal(m.lsplit('a * c* b'), ('a', '*', 'c* b'))
    assert_equal(m.rsplit('a * c* b'), ('a * c', '*', 'b'))
    assert_equal(m.lsplit('a * b ** c'), ('a', '*', 'b ** c'))
    assert_equal(m.rsplit('a * b ** c'), ('a', '*', 'b ** c'))
    assert_equal(m.lsplit('a * b ** c * d'), ('a', '*', 'b ** c * d'))
    assert_equal(m.rsplit('a * b ** c * d'), ('a * b ** c', '*', 'd'))

    m = power_op.named()
    assert m.rsplit('a ** b')
    assert_equal(m.lsplit('a * b ** c'), ('a * b', '**', 'c'))
    assert_equal(m.rsplit('a * b ** c'), ('a * b', '**', 'c'))
    assert_equal(m.lsplit('a ** b ** c'), ('a', '**', 'b ** c'))
    assert_equal(m.rsplit('a ** b ** c'), ('a ** b', '**', 'c'))
    print('ok')


if __name__ == '__main__':
    _test()