Merged revisions 73196,73278-73280,73299,73308,73312-73313,73317-73318,73321,73324...
[python/dscho.git] / Parser / asdl.py
blob28a7138797895d61389bcc011bf1b9ab346c84c1
1 """An implementation of the Zephyr Abstract Syntax Definition Language.
3 See http://asdl.sourceforge.net/ and
4 http://www.cs.princeton.edu/~danwang/Papers/dsl97/dsl97-abstract.html.
6 Only supports top level module decl, not view. I'm guessing that view
7 is intended to support the browser and I'm not interested in the
8 browser.
10 Changes for Python: Add support for module versions
11 """
13 import os
14 import sys
15 import traceback
17 import spark
def output(string, *more):
    """Write the given values to stdout, followed by a newline.

    Behaves like a minimal ``print()``: every argument is converted with
    ``str()`` and the pieces are joined with single spaces.  Callers in
    this module pass several values, ints and exception objects, which the
    old ``string + "\\n"`` form rejected with a TypeError.
    """
    pieces = [str(string)]
    pieces.extend(str(piece) for piece in more)
    sys.stdout.write(" ".join(pieces) + "\n")
class Token(object):
    """A lexical token produced by ASDLScanner.

    spark seems to dispatch in the parser based on a token's type
    attribute, so every token carries one, plus the source line number.
    """

    def __init__(self, type, lineno):
        self.type, self.lineno = type, lineno

    def __str__(self):
        return self.type

    def __repr__(self):
        # Route through str() so subclasses overriding __str__ are honoured.
        return "%s" % (self,)
class Id(Token):
    """An identifier token; value holds the matched identifier text."""

    def __init__(self, value, lineno):
        self.value = value
        self.type, self.lineno = 'Id', lineno

    def __str__(self):
        # Show the identifier itself rather than the generic type name.
        return self.value
class String(Token):
    """A string-literal token; value holds the matched text, quotes included.

    String does not override __str__, so str() of one yields its type
    name 'String' (inherited from Token), not the literal text.
    """

    def __init__(self, value, lineno):
        self.type, self.value, self.lineno = 'String', value, lineno
class ASDLSyntaxError(Exception):
    """Raised for malformed ASDL input.

    lineno is the 1-based source line number.  token is the offending
    token, if known.  msg is an optional message that, when given, replaces
    the generic "Error at ..." text.
    """

    def __init__(self, lineno, token=None, msg=None):
        self.lineno = lineno
        self.token = token
        self.msg = msg

    def __str__(self):
        if self.msg is not None:
            return "%s, line %d" % (self.msg, self.lineno)
        return "Error at '%s', line %d" % (self.token, self.lineno)
class ASDLScanner(spark.GenericScanner, object):
    """Tokenizer for ASDL source text.

    The docstring of each t_* method is the regular expression for one kind
    of token.  Those docstrings are part of the program, not documentation.
    tokenize() returns the list of Token/Id/String objects collected in
    self.rv.
    """

    def tokenize(self, input):
        # Reset per-call state: the collected tokens and the current line.
        self.rv = []
        self.lineno = 1
        # The base-class tokenize() is expected to invoke the t_* methods
        # below, which append tokens to self.rv.
        super(ASDLScanner, self).tokenize(input)
        return self.rv

    def t_id(self, s):
        r"[\w\.]+"
        # XXX doesn't distinguish upper vs. lower, which is
        # significant for ASDL.
        self.rv.append(Id(s, self.lineno))

    def t_string(self, s):
        r'"[^"]*"'
        # The token value keeps the surrounding double quotes.
        self.rv.append(String(s, self.lineno))

    def t_xxx(self, s): # not sure what this production means
        r"<="
        self.rv.append(Token(s, self.lineno))

    def t_punctuation(self, s):
        r"[\{\}\*\=\|\(\)\,\?\:]"
        # Single-character punctuation; the character itself is the type.
        self.rv.append(Token(s, self.lineno))

    def t_comment(self, s):
        r"\-\-[^\n]*"
        # "--" comments run to end of line and are discarded.
        pass

    def t_newline(self, s):
        r"\n"
        # Only line counting; no token is produced.
        self.lineno += 1

    def t_whitespace(self, s):
        r"[ \t]+"
        pass

    def t_default(self, s):
        r" . +"
        # NOTE(review): presumably spark compiles these patterns in verbose
        # mode, so the spaces are ignored and this matches any leftover
        # input -- confirm against spark.
        raise ValueError("unmatched input: %r" % s)
class ASDLParser(spark.GenericParser, object):
    """spark-based parser that turns ASDLScanner tokens into a Module.

    The docstrings of the p_* methods are the grammar rules, so they are
    part of the program and must not be edited.  Each p_* method receives
    the matched right-hand-side symbols as a sequence and returns the
    value of the left-hand side.
    """

    def __init__(self):
        # "module" is the start symbol of the grammar.
        super(ASDLParser, self).__init__("module")

    def typestring(self, tok):
        # Terminals in the grammar are matched against the token's type.
        return tok.type

    def error(self, tok):
        raise ASDLSyntaxError(tok.lineno, tok)

    def p_module_0(self, info):
        " module ::= Id Id version { } "
        module, name, version, _0, _1 = info
        if module.value != "module":
            raise ASDLSyntaxError(module.lineno,
                                  msg="expected 'module', found %s" % module)
        # An empty module: pass an empty list (not None) so that Module and
        # the visitors can iterate its definitions.
        return Module(name, [], version)

    def p_module(self, info):
        " module ::= Id Id version { definitions } "
        module, name, version, _0, definitions, _1 = info
        if module.value != "module":
            raise ASDLSyntaxError(module.lineno,
                                  msg="expected 'module', found %s" % module)
        return Module(name, definitions, version)

    def p_version(self, info):
        "version ::= Id String"
        version, V = info
        if version.value != "version":
            # "%s" (not a bare "%") so that the intended ASDLSyntaxError is
            # raised instead of a ValueError from the format operation.
            raise ASDLSyntaxError(version.lineno,
                                  msg="expected 'version', found %s" % version)
        return V

    def p_definition_0(self, definition):
        " definitions ::= definition "
        return definition[0]

    def p_definition_1(self, definitions):
        " definitions ::= definition definitions "
        # Both sides are lists of Type nodes; concatenate them.
        return definitions[0] + definitions[1]

    def p_definition(self, info):
        " definition ::= Id = type "
        id, _, type = info
        return [Type(id, type)]

    def p_type_0(self, product):
        " type ::= product "
        return product[0]

    def p_type_1(self, sum):
        " type ::= sum "
        return Sum(sum[0])

    def p_type_2(self, info):
        " type ::= sum Id ( fields ) "
        sum, id, _0, attributes, _1 = info
        if id.value != "attributes":
            raise ASDLSyntaxError(id.lineno,
                                  msg="expected attributes, found %s" % id)
        if attributes:
            # p_fields_* accumulate in reverse; restore declaration order.
            attributes.reverse()
        return Sum(sum, attributes)

    def p_product(self, info):
        " product ::= ( fields ) "
        _0, fields, _1 = info
        # XXX can't I just construct things in the right order?
        fields.reverse()
        return Product(fields)

    def p_sum_0(self, constructor):
        " sum ::= constructor "
        return [constructor[0]]

    def p_sum_1(self, info):
        " sum ::= constructor | sum "
        constructor, _, sum = info
        return [constructor] + sum

    # NOTE(review): p_sum_2 repeats p_sum_1's rule verbatim.  It is kept so
    # the set of rules given to spark is unchanged.
    def p_sum_2(self, info):
        " sum ::= constructor | sum "
        constructor, _, sum = info
        return [constructor] + sum

    def p_constructor_0(self, id):
        " constructor ::= Id "
        return Constructor(id[0])

    def p_constructor_1(self, info):
        " constructor ::= Id ( fields ) "
        id, _0, fields, _1 = info
        # XXX can't I just construct things in the right order?
        fields.reverse()
        return Constructor(id, fields)

    def p_fields_0(self, field):
        " fields ::= field "
        return [field[0]]

    def p_fields_1(self, info):
        " fields ::= field , fields "
        field, _, fields = info
        # Appending builds the list in reverse; consumers reverse it back.
        return fields + [field]

    def p_field_0(self, type_):
        " field ::= Id "
        return Field(type_[0])

    def p_field_1(self, info):
        " field ::= Id Id "
        type, name = info
        return Field(type, name)

    def p_field_2(self, info):
        " field ::= Id * Id "
        type, _, name = info
        return Field(type, name, seq=True)

    def p_field_3(self, info):
        " field ::= Id ? Id "
        type, _, name = info
        return Field(type, name, opt=True)

    def p_field_4(self, type_):
        " field ::= Id * "
        return Field(type_[0], seq=True)

    def p_field_5(self, type_):
        " field ::= Id ? "
        # Use the matched symbols (type_), not the builtin ``type``.
        return Field(type_[0], opt=True)
# Type names that check() accepts without a definition in the module.
builtin_types = (
    "identifier",
    "string",
    "int",
    "bool",
    "object",
)
241 # below is a collection of classes to capture the AST of an AST :-)
242 # not sure if any of the methods are useful yet, but I'm adding them
243 # piecemeal as they seem helpful
class AST(object):
    """Marker base class for the nodes of a parsed ASDL description."""
class Module(AST):
    """Root node of a parsed ASDL description.

    name is the module's Id token.  dfns is the list of Type definitions;
    None is accepted for an empty module and stored as [].  version is the
    token returned by the version clause.  types maps each definition's
    name to its value.
    """

    def __init__(self, name, dfns, version):
        self.name = name
        # The empty-module grammar rule may supply None; normalise it so that
        # iteration here and in visitors does not fail.
        self.dfns = dfns if dfns is not None else []
        self.version = version
        self.types = {}  # maps type name to value (from dfns)
        for dfn in self.dfns:
            self.types[dfn.name.value] = dfn.value

    def __repr__(self):
        return "Module(%s, %s)" % (self.name, self.dfns)
class Type(AST):
    """A named ASDL definition ``name = value``.

    value is the Sum or Product node built for the right-hand side.
    """

    def __init__(self, name, value):
        self.name, self.value = name, value

    def __repr__(self):
        parts = (self.name, self.value)
        return "Type(%s, %s)" % parts
class Constructor(AST):
    """One alternative of a sum type, with its (possibly empty) field list."""

    def __init__(self, name, fields=None):
        self.name = name
        # Any falsy argument (None or an empty sequence) becomes a new list.
        if not fields:
            fields = []
        self.fields = fields

    def __repr__(self):
        parts = (self.name, self.fields)
        return "Constructor(%s, %s)" % parts
class Field(AST):
    """A field: its type, optional name and the seq (``*``) / opt (``?``) flags."""

    def __init__(self, type, name=None, seq=False, opt=False):
        self.type = type
        self.name = name
        self.seq = seq
        self.opt = opt

    def __repr__(self):
        # seq takes precedence over opt when both are set.
        suffix = ""
        if self.opt:
            suffix = ", opt=True"
        if self.seq:
            suffix = ", seq=True"
        if self.name is not None:
            return "Field(%s, %s%s)" % (self.type, self.name, suffix)
        return "Field(%s%s)" % (self.type, suffix)
class Sum(AST):
    """A sum type: its constructor alternatives plus optional attributes.

    types is the list of Constructor alternatives.  attributes is the list
    of Field nodes shared by all alternatives; it is always a list, [] when
    none are declared.
    """

    def __init__(self, types, attributes=None):
        self.types = types
        self.attributes = attributes or []

    def __repr__(self):
        # attributes is never None (see __init__), so test for emptiness;
        # the old "is None" check made the short form unreachable.
        if not self.attributes:
            return "Sum(%s)" % (self.types,)
        return "Sum(%s, %s)" % (self.types, self.attributes)
class Product(AST):
    """A product type: a single list of fields, in declaration order."""

    def __init__(self, fields):
        self.fields = fields

    def __repr__(self):
        template = "Product(%s)"
        return template % self.fields
class VisitorBase(object):
    """Visitor that dispatches on the AST node's class name.

    For a node of class ``Foo`` the method ``visitFoo`` is looked up on the
    visitor, and the result is cached per class.  With skip=True, nodes
    without a matching method are ignored.  Otherwise the lookup raises
    AttributeError.
    """

    def __init__(self, skip=False):
        self.cache = {}  # node class -> bound visit method (None if skipped)
        self.skip = skip

    def visit(self, object, *args):
        """Dispatch *object* and *args* to the matching visit method.

        If the visit method raises, the error and a traceback are reported,
        ``self.file`` is flushed when present, and the process exits with
        ``os._exit(1)``.
        """
        meth = self._dispatch(object)
        if meth is None:
            return
        try:
            meth(object, *args)
        except Exception:
            # Build each message as one string so reporting cannot itself
            # fail and hide the original error.
            output("Error visiting %r" % (object,))
            output(str(sys.exc_info()[1]))
            traceback.print_exc()
            # XXX hack
            if hasattr(self, 'file'):
                self.file.flush()
            os._exit(1)

    def _dispatch(self, object):
        """Return the visit method for *object*'s class (None when skipped)."""
        assert isinstance(object, AST), repr(object)
        klass = object.__class__
        # Test membership, not the value, so that a cached None (no method
        # in skip mode) is not looked up again on every call.
        if klass in self.cache:
            return self.cache[klass]
        methname = "visit" + klass.__name__
        if self.skip:
            meth = getattr(self, methname, None)
        else:
            meth = getattr(self, methname)
        self.cache[klass] = meth
        return meth
class Check(VisitorBase):
    """Visitor that records constructor definitions and field-type uses.

    Redefined constructors are reported immediately and counted in errors.
    types maps each field type name to the list of places that use it, and
    check() later inspects it for undefined types.
    """

    def __init__(self):
        super(Check, self).__init__(skip=True)
        self.cons = {}   # constructor name -> name of the defining type
        self.errors = 0  # number of problems reported so far
        self.types = {}  # field type name -> list of users

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, str(type.name))

    def visitSum(self, sum, name):
        for alternative in sum.types:
            self.visit(alternative, name)

    def visitConstructor(self, cons, name):
        key = str(cons.name)
        previous = self.cons.get(key)
        if previous is not None:
            output("Redefinition of constructor %s" % key)
            output("Defined in %s and %s" % (previous, name))
            self.errors += 1
        else:
            self.cons[key] = name
        for field in cons.fields:
            self.visit(field, key)

    def visitField(self, field, name):
        self.types.setdefault(str(field.type), []).append(name)

    def visitProduct(self, prod, name):
        for field in prod.fields:
            self.visit(field, name)
def check(mod):
    """Validate *mod* and return True when no errors were found.

    Runs the Check visitor over the module.  It then reports, and counts
    as errors, field types that are neither defined in the module nor
    listed in builtin_types.
    """
    checker = Check()
    checker.visit(mod)

    for type_name, users in checker.types.items():
        if type_name in mod.types or type_name in builtin_types:
            continue
        checker.errors += 1
        output("Undefined type %s, used in %s" % (type_name, ", ".join(users)))

    return checker.errors == 0
def parse(file):
    """Parse the ASDL file at path *file* and return the resulting Module.

    On an ASDLSyntaxError, the error and the offending source line are
    written to stdout and None is returned.  Errors raised while
    tokenizing (e.g. ValueError for unmatched input) propagate.
    """
    scanner = ASDLScanner()
    parser = ASDLParser()

    # Read with a context manager so the file handle is closed promptly.
    with open(file) as f:
        buf = f.read()
    tokens = scanner.tokenize(buf)
    try:
        return parser.parse(tokens)
    except ASDLSyntaxError as err:
        output(str(err))
        lines = buf.split("\n")
        # lines starts at 0, files at 1; skip the echo if lineno is out of range.
        if 1 <= err.lineno <= len(lines):
            output(lines[err.lineno - 1])
    return None
if __name__ == "__main__":
    import glob

    # Parse the files named on the command line, or every *.asdl file in
    # ./tests when none are given, and report on each one.
    if len(sys.argv) > 1:
        files = sys.argv[1:]
    else:
        testdir = "tests"
        files = glob.glob(testdir + "/*.asdl")

    for file in files:
        output(file)
        mod = parse(file)
        if not mod:
            # parse() has already reported the syntax error.
            break
        output("module %s" % mod.name)
        output("%d definitions" % len(mod.dfns))
        if not check(mod):
            output("Check failed")
        else:
            # Type nodes have name/value, not .type; list each definition name.
            for dfn in mod.dfns:
                output(str(dfn.name))