add an extra global that is never defined; this is used to handle names that never...
[pythonc.git] / transform.py
blob90460bfe1f26eb3bc917d6df95a71afbcb9be3b4
1 ################################################################################
2 ##
3 ## Pythonc--Python to C++ translator
4 ##
5 ## Copyright 2011 Zach Wegner
6 ##
7 ## This file is part of Pythonc.
8 ##
9 ## Pythonc is free software: you can redistribute it and/or modify
10 ## it under the terms of the GNU General Public License as published by
11 ## the Free Software Foundation, either version 3 of the License, or
12 ## (at your option) any later version.
14 ## Pythonc is distributed in the hope that it will be useful,
15 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
16 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 ## GNU General Public License for more details.
19 ## You should have received a copy of the GNU General Public License
20 ## along with Pythonc. If not, see <http://www.gnu.org/licenses/>.
22 ################################################################################
24 import ast
25 import sys
27 import syntax
# Names of the builtin free functions. The driver at the bottom of this file
# expands this list into the LIST_BUILTIN_FUNCTIONS(x) macro.
builtin_functions = [
    'all', 'any', 'isinstance', 'len', 'open',
    'ord', 'print', 'print_nonl', 'repr', 'sorted',
]

# Names of the builtin classes. The driver at the bottom of this file expands
# this list into the LIST_BUILTIN_CLASSES(x) macro.
builtin_classes = [
    'bool', 'dict', 'enumerate', 'int', 'list', 'range',
    'reversed', 'set', 'str', 'tuple', 'zip',
]

# Every symbol that is always present in the global symbol table: all of the
# builtin functions and classes, followed by two special module-level names.
builtin_symbols = builtin_functions + builtin_classes + ['__name__', '__args__']
class Transformer(ast.NodeTransformer):
    """Rewrite a Python ``ast`` tree into flattened ``syntax`` nodes.

    Each ``visit_*`` method returns either a ``syntax`` node, a list of
    ``syntax`` statements, an operator-name string (for operator nodes) or
    None (for nodes that produce no output).  Non-atomic sub-expressions are
    hoisted into temporaries; the assignments to those temporaries are
    accumulated in ``self.statements`` and emitted ahead of the statement
    that uses them (see flatten_node/flatten_list).
    """

    def __init__(self):
        # Counter used to generate unique temp_NN names.
        self.temp_id = 0
        # Statements hoisted out of the expression currently being flattened.
        self.statements = []
        # Function definitions collected for the driver script to emit; this
        # class only creates the list, it never appends to it itself.
        self.functions = []
        # Scope flags consulted by get_binding().
        self.in_class = False
        self.in_function = False
        # Names that resolve to globals inside the function being visited
        # (None while not inside a function).
        self.globals_set = None

    def get_temp_name(self):
        """Return a fresh temporary name of the form ``temp_NN``."""
        self.temp_id += 1
        return 'temp_%02i' % self.temp_id

    def get_temp(self):
        """Return a fresh temporary wrapped in a ``syntax.Identifier``."""
        return syntax.Identifier(self.get_temp_name())

    def flatten_node(self, node, statements=None):
        """Visit ``node`` and return an atom that stands for its value.

        If the visited result is not an atom (per its ``is_atom()``), it is
        assigned to a new temporary and the temporary is returned instead.
        Hoisted statements go to ``statements`` when given, otherwise to the
        current ``self.statements``; ``self.statements`` is restored before
        returning.
        """
        old_stmts = self.statements
        if statements is not None:
            self.statements = statements
        node = self.visit(node)
        if node.is_atom():
            r = node
        else:
            temp = self.get_temp()
            self.statements.append(syntax.Assign(temp, node))
            r = temp
        self.statements = old_stmts
        return r

    def flatten_list(self, node_list):
        """Visit each statement in ``node_list`` and return a flat list.

        For every statement, the statements hoisted while visiting it are
        placed before its own result.  A statement whose visit returns a
        falsy value (e.g. None from visit_Pass) contributes nothing.
        """
        old_stmts = self.statements
        statements = []
        for stmt in node_list:
            # Each statement gets its own fresh list of hoisted statements.
            self.statements = []
            stmts = self.visit(stmt)
            if stmts:
                if isinstance(stmts, list):
                    statements += self.statements + stmts
                else:
                    statements += self.statements + [stmts]
        self.statements = old_stmts
        return statements

    def index_global_class_symbols(self, node, globals_set, class_set):
        """Recursively collect candidate global and class symbol names.

        Adds names to ``globals_set`` / ``class_set`` in place.  As a side
        effect, every loop or comprehension node gets an ``iter_temp``
        attribute holding a freshly allocated temporary name.
        """
        if isinstance(node, ast.Global):
            globals_set.update(node.names)
        # XXX make this check scope
        elif isinstance(node, ast.Name) and isinstance(node.ctx,
                (ast.Store, ast.AugStore)):
            globals_set.add(node.id)
            class_set.add(node.id)
        elif isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            globals_set.add(node.name)
            class_set.add(node.name)
        elif isinstance(node, ast.Import):
            for name in node.names:
                globals_set.add(name.name)
        elif isinstance(node, (ast.For, ast.ListComp, ast.DictComp, ast.SetComp,
                ast.GeneratorExp)):
            # HACK: set self.iter_temp for the space in the symbol table
            node.iter_temp = self.get_temp_name()
            globals_set.add(node.iter_temp)
        for i in ast.iter_child_nodes(node):
            self.index_global_class_symbols(i, globals_set, class_set)

    def get_globals(self, node, globals_set, locals_set, all_vars_set):
        """Recursively classify the names used under ``node``.

        ``globals_set`` receives names declared with ``global``;
        ``locals_set`` receives stored names, arguments and loop temporaries;
        ``all_vars_set`` receives every name and argument seen.  All three
        sets are updated in place.  Relies on ``iter_temp`` having been set
        by index_global_class_symbols().
        """
        if isinstance(node, ast.Global):
            globals_set.update(node.names)
        elif isinstance(node, ast.Name):
            all_vars_set.add(node.id)
            if isinstance(node.ctx, (ast.Store, ast.AugStore)):
                locals_set.add(node.id)
        elif isinstance(node, ast.arg):
            all_vars_set.add(node.arg)
            locals_set.add(node.arg)
        elif isinstance(node, (ast.For, ast.ListComp, ast.DictComp, ast.SetComp,
                ast.GeneratorExp)):
            locals_set.add(node.iter_temp)
        for i in ast.iter_child_nodes(node):
            self.get_globals(i, globals_set, locals_set, all_vars_set)

    def get_binding(self, name):
        """Return ``(scope, index)`` describing where ``name`` lives.

        Inside a function the scope is 'global' for names in
        ``self.globals_set`` and 'local' otherwise; inside a class body it is
        'class'; elsewhere it is 'global'.  A name missing from the chosen
        scope's table is resolved against the global table, falling back to
        the global '$undefined' slot.
        """
        if self.in_function:
            scope = 'global' if name in self.globals_set else 'local'
        elif self.in_class:
            scope = 'class'
        else:
            scope = 'global'

        table = self.symbol_idx[scope]
        if name in table:
            return (scope, table[name])

        # Only the global table contains '$undefined' (see visit_Module), so
        # a miss in the class/local table must be resolved globally rather
        # than indexing '$undefined' in a table that does not have it.
        global_table = self.symbol_idx['global']
        return ('global', global_table.get(name, global_table['$undefined']))

    def generic_visit(self, node):
        """Reject any node type that has no dedicated ``visit_*`` method.

        Raises:
            RuntimeError: always.
        """
        # Not every node carries a line number; only print it when present so
        # that the RuntimeError below is never masked by an AttributeError.
        lineno = getattr(node, 'lineno', None)
        if lineno is not None:
            print(lineno)
        raise RuntimeError('can\'t translate %s' % node)

    def visit_children(self, node):
        """Visit every direct child of ``node`` and return the results."""
        return [self.visit(i) for i in ast.iter_child_nodes(node)]

    def visit_Name(self, node):
        """Translate a name load; True/False/None become constants."""
        assert isinstance(node.ctx, ast.Load)
        if node.id in ['True', 'False']:
            return syntax.BoolConst(node.id == 'True')
        elif node.id == 'None':
            return syntax.NoneConst()
        return syntax.Load(node.id, self.get_binding(node.id))

    def visit_Num(self, node):
        """Translate an integer literal; float literals are rejected."""
        if isinstance(node.n, float):
            raise RuntimeError('Pythonc currently does not support float literals')
        assert isinstance(node.n, int)
        return syntax.IntConst(node.n)

    def visit_Str(self, node):
        """Translate a string literal."""
        assert isinstance(node.s, str)
        return syntax.StringConst(node.s)

    def visit_Bytes(self, node):
        """Reject bytes literals."""
        raise RuntimeError('Pythonc currently does not support bytes literals')

    # Unary Ops
    def visit_Invert(self, node): return '__invert__'
    def visit_Not(self, node): return '__not__'
    def visit_UAdd(self, node): return '__pos__'
    def visit_USub(self, node): return '__neg__'

    def visit_UnaryOp(self, node):
        """Translate a unary operation on a flattened operand."""
        op = self.visit(node.op)
        rhs = self.flatten_node(node.operand)
        return syntax.UnaryOp(op, rhs)

    # Binary Ops
    def visit_Add(self, node): return '__add__'
    def visit_BitAnd(self, node): return '__and__'
    def visit_BitOr(self, node): return '__or__'
    def visit_BitXor(self, node): return '__xor__'
    def visit_Div(self, node): return '__truediv__'
    def visit_FloorDiv(self, node): return '__floordiv__'
    def visit_LShift(self, node): return '__lshift__'
    def visit_Mod(self, node): return '__mod__'
    def visit_Mult(self, node): return '__mul__'
    def visit_Pow(self, node): return '__pow__'
    def visit_RShift(self, node): return '__rshift__'
    def visit_Sub(self, node): return '__sub__'

    def visit_BinOp(self, node):
        """Translate a binary operation on flattened operands."""
        op = self.visit(node.op)
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.right)
        return syntax.BinaryOp(op, lhs, rhs)

    # Comparisons
    def visit_Eq(self, node): return '__eq__'
    def visit_NotEq(self, node): return '__ne__'
    def visit_Lt(self, node): return '__lt__'
    def visit_LtE(self, node): return '__le__'
    def visit_Gt(self, node): return '__gt__'
    def visit_GtE(self, node): return '__ge__'
    def visit_In(self, node): return '__contains__'
    def visit_NotIn(self, node): return '__ncontains__'
    def visit_Is(self, node): return '__is__'
    def visit_IsNot(self, node): return '__isnot__'

    def visit_Compare(self, node):
        """Translate a single (non-chained) comparison."""
        assert len(node.ops) == 1
        assert len(node.comparators) == 1
        op = self.visit(node.ops[0])
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.comparators[0])
        # Sigh--Python has these ordered weirdly
        if op in ['__contains__', '__ncontains__']:
            lhs, rhs = rhs, lhs
        return syntax.BinaryOp(op, lhs, rhs)

    # Bool ops
    def visit_And(self, node): return 'and'
    def visit_Or(self, node): return 'or'

    def visit_BoolOp(self, node):
        """Translate an and/or chain, folding from the rightmost operand.

        Each operand's hoisted statements are kept in their own list and
        handed to ``syntax.BoolOp`` together with the operand.
        """
        assert len(node.values) >= 2
        op = self.visit(node.op)
        rhs_stmts = []
        rhs_expr = self.flatten_node(node.values[-1], statements=rhs_stmts)
        for v in reversed(node.values[:-1]):
            lhs_stmts = []
            lhs = self.flatten_node(v, statements=lhs_stmts)
            bool_op = syntax.BoolOp(op, lhs, rhs_stmts, rhs_expr)
            rhs_expr = bool_op.flatten(self, lhs_stmts)
            rhs_stmts = lhs_stmts
        self.statements += rhs_stmts
        return rhs_expr

    def visit_IfExp(self, node):
        """Translate ``a if test else b`` with per-branch statement lists."""
        expr = self.flatten_node(node.test)
        true_stmts = []
        true_expr = self.flatten_node(node.body, statements=true_stmts)
        false_stmts = []
        false_expr = self.flatten_node(node.orelse, statements=false_stmts)
        if_exp = syntax.IfExp(expr, true_stmts, true_expr, false_stmts, false_expr)
        return if_exp.flatten(self)

    def visit_List(self, node):
        """Translate a list display."""
        items = [self.flatten_node(i) for i in node.elts]
        l = syntax.List(items)
        return l.flatten(self)

    def visit_Tuple(self, node):
        """Translate a tuple display."""
        items = [self.flatten_node(i) for i in node.elts]
        l = syntax.Tuple(items)
        return l.flatten(self)

    def visit_Dict(self, node):
        """Translate a dict display (all keys are flattened before values)."""
        keys = [self.flatten_node(i) for i in node.keys]
        values = [self.flatten_node(i) for i in node.values]
        d = syntax.Dict(keys, values)
        return d.flatten(self)

    def visit_Set(self, node):
        """Translate a set display."""
        items = [self.flatten_node(i) for i in node.elts]
        d = syntax.Set(items)
        return d.flatten(self)

    def visit_Subscript(self, node):
        """Translate ``value[index]`` or ``value[lower:upper:step]``.

        Missing slice bounds become ``syntax.NoneConst()``.

        Raises:
            RuntimeError: for any other kind of slice node.
        """
        l = self.flatten_node(node.value)
        if isinstance(node.slice, ast.Index):
            index = self.flatten_node(node.slice.value)
            return syntax.Subscript(l, index)
        elif isinstance(node.slice, ast.Slice):
            [start, end, step] = [self.flatten_node(a) if a else syntax.NoneConst() for a in
                    [node.slice.lower, node.slice.upper, node.slice.step]]
            return syntax.Slice(l, start, end, step)
        # Previously this fell through and silently returned None.
        raise RuntimeError('Pythonc currently does not support this kind of subscript')

    def visit_Attribute(self, node):
        """Translate an attribute load."""
        assert isinstance(node.ctx, ast.Load)
        l = self.flatten_node(node.value)
        attr = syntax.Attribute(l, syntax.StringConst(node.attr))
        return attr

    def visit_Call(self, node):
        """Translate a call into ``syntax.Call(fn, args, kwargs)``.

        ``f(*seq)`` is only supported with no other arguments.
        """
        fn = self.flatten_node(node.func)

        if node.starargs:
            assert not node.args
            assert not node.kwargs
            # NOTE(review): this passes the flattened starargs node itself to
            # syntax.Tuple, whereas the other branch passes a list of nodes --
            # confirm syntax.Tuple accepts both.
            args = syntax.Tuple(self.flatten_node(node.starargs))
            args = args.flatten(self)
            kwargs = syntax.Dict([], [])
        else:
            args = syntax.Tuple([self.flatten_node(a) for a in node.args])
            args = args.flatten(self)

            keys = [syntax.StringConst(i.arg) for i in node.keywords]
            values = [self.flatten_node(i.value) for i in node.keywords]
            kwargs = syntax.Dict(keys, values)

        kwargs = kwargs.flatten(self)
        return syntax.Call(fn, args, kwargs)

    def visit_Assign(self, node):
        """Translate a single-target assignment into a list of stores.

        Supported targets: a name, a tuple of names, an attribute, or an
        indexed subscript.
        """
        assert len(node.targets) == 1
        target = node.targets[0]
        value = self.flatten_node(node.value)
        if isinstance(target, ast.Name):
            return [syntax.Store(target.id, value, self.get_binding(target.id))]
        elif isinstance(target, ast.Tuple):
            assert all(isinstance(t, ast.Name) for t in target.elts)
            # Unpack by subscripting the (already flattened) value by position.
            return [syntax.Store(t.id, syntax.Subscript(value, syntax.IntConst(i)),
                                 self.get_binding(t.id))
                    for i, t in enumerate(target.elts)]
        elif isinstance(target, ast.Attribute):
            base = self.flatten_node(target.value)
            return [syntax.StoreAttr(base, syntax.StringConst(target.attr), value)]
        elif isinstance(target, ast.Subscript):
            assert isinstance(target.slice, ast.Index)
            base = self.flatten_node(target.value)
            index = self.flatten_node(target.slice.value)
            return [syntax.StoreSubscript(base, index, value)]
        else:
            assert False

    def visit_AugAssign(self, node):
        """Translate ``target op= value`` as a load, binary op and store."""
        op = self.visit(node.op)
        value = self.flatten_node(node.value)
        if isinstance(node.target, ast.Name):
            target = node.target.id
            # XXX HACK: doesn't modify in place
            binop = syntax.BinaryOp(op, syntax.Load(target, self.get_binding(target)), value)
            return [syntax.Store(target, binop, self.get_binding(target))]
        elif isinstance(node.target, ast.Attribute):
            l = self.flatten_node(node.target.value)
            attr_name = syntax.StringConst(node.target.attr)
            attr = syntax.Attribute(l, attr_name)
            binop = syntax.BinaryOp(op, attr, value)
            return [syntax.StoreAttr(l, attr_name, binop)]
        elif isinstance(node.target, ast.Subscript):
            assert isinstance(node.target.slice, ast.Index)
            base = self.flatten_node(node.target.value)
            index = self.flatten_node(node.target.slice.value)
            old = syntax.Subscript(base, index)
            binop = syntax.BinaryOp(op, old, value)
            return [syntax.StoreSubscript(base, index, binop)]
        else:
            assert False

    def visit_Delete(self, node):
        """Translate ``del base[index]`` (the only supported form)."""
        assert len(node.targets) == 1
        target = node.targets[0]
        assert isinstance(target, ast.Subscript)
        assert isinstance(target.slice, ast.Index)

        name = self.flatten_node(target.value)
        value = self.flatten_node(target.slice.value)
        return [syntax.DeleteSubscript(name, value)]

    def visit_If(self, node):
        """Translate an if statement; a missing else block is passed as None."""
        expr = self.flatten_node(node.test)
        stmts = self.flatten_list(node.body)
        else_block = self.flatten_list(node.orelse) if node.orelse else None
        return syntax.If(expr, stmts, else_block)

    def visit_Break(self, node):
        return syntax.Break()

    def visit_Continue(self, node):
        return syntax.Continue()

    def visit_For(self, node):
        """Translate a for loop (no ``else``) over a name or tuple of names."""
        assert not node.orelse
        iter_expr = self.flatten_node(node.iter)
        stmts = self.flatten_list(node.body)

        if isinstance(node.target, ast.Name):
            target = (node.target.id, self.get_binding(node.target.id))
        elif isinstance(node.target, ast.Tuple):
            target = [(t.id, self.get_binding(t.id)) for t in node.target.elts]
        else:
            assert False
        # HACK: self.iter_temp gets set when enumerating symbols
        for_loop = syntax.For(target, iter_expr, stmts, node.iter_temp,
                              self.get_binding(node.iter_temp))
        return for_loop.flatten(self)

    def visit_While(self, node):
        """Translate a while loop (no ``else``).

        The statements needed to evaluate the test are kept in a separate
        list and handed to ``syntax.While`` alongside the test itself.
        """
        assert not node.orelse
        test_stmts = []
        test = self.flatten_node(node.test, statements=test_stmts)
        stmts = self.flatten_list(node.body)
        return syntax.While(test_stmts, test, stmts)

    # XXX We are just flattening "with x as y:" into "y = x" (this works in some simple cases with open()).
    def visit_With(self, node):
        assert node.optional_vars
        expr = self.flatten_node(node.context_expr)
        stmts = [syntax.Store(node.optional_vars.id, expr, self.get_binding(node.optional_vars.id))]
        stmts += self.flatten_list(node.body)
        return stmts

    def visit_Comprehension(self, node, comp_type):
        """Translate a comprehension with one generator and at most one ``if``.

        ``comp_type`` is one of 'list', 'set', 'dict' or 'generator'; for
        'dict' both the key and the value expression are flattened.
        """
        assert len(node.generators) == 1
        gen = node.generators[0]
        assert len(gen.ifs) <= 1

        if isinstance(gen.target, ast.Name):
            target = (gen.target.id, self.get_binding(gen.target.id))
        elif isinstance(gen.target, ast.Tuple):
            target = [(t.id, self.get_binding(t.id)) for t in gen.target.elts]
        else:
            assert False

        iter_expr = self.flatten_node(gen.iter)
        cond_stmts = []
        expr_stmts = []
        cond = None
        if gen.ifs:
            cond = self.flatten_node(gen.ifs[0], statements=cond_stmts)
        if comp_type == 'dict':
            expr = self.flatten_node(node.key, statements=expr_stmts)
            expr2 = self.flatten_node(node.value, statements=expr_stmts)
        else:
            expr = self.flatten_node(node.elt, statements=expr_stmts)
            expr2 = None
        comp = syntax.Comprehension(comp_type, target, iter_expr, node.iter_temp,
                self.get_binding(node.iter_temp), cond_stmts, cond, expr_stmts,
                expr, expr2)
        return comp.flatten(self)

    def visit_ListComp(self, node):
        return self.visit_Comprehension(node, 'list')

    def visit_SetComp(self, node):
        return self.visit_Comprehension(node, 'set')

    def visit_DictComp(self, node):
        return self.visit_Comprehension(node, 'dict')

    def visit_GeneratorExp(self, node):
        return self.visit_Comprehension(node, 'generator')

    def visit_Return(self, node):
        """Translate a return statement; a bare ``return`` carries None."""
        if node.value is None:
            return syntax.Return(None)
        return syntax.Return(self.flatten_node(node.value))

    def visit_Assert(self, node):
        """Translate an assert, recording its source line number."""
        expr = self.flatten_node(node.test)
        return syntax.Assert(expr, node.lineno)

    def visit_Raise(self, node):
        """Translate ``raise exc`` (no ``from`` clause), recording its line."""
        assert not node.cause
        expr = self.flatten_node(node.exc)
        return syntax.Raise(expr, node.lineno)

    def visit_arguments(self, node):
        """Translate a positional-only parameter list (no *args/**kwargs)."""
        assert not node.vararg
        assert not node.kwarg

        args = [a.arg for a in node.args]
        binding = [self.get_binding(a) for a in args]
        defaults = self.flatten_list(node.defaults)
        args = syntax.Arguments(args, binding, defaults)
        return args.flatten(self)

    def visit_FunctionDef(self, node):
        """Translate a (non-nested) function definition.

        Rebuilds the 'local' symbol table for this function, visits the
        arguments and body with ``in_function`` set, then resolves the
        function's own name in the enclosing scope.
        """
        assert not self.in_function

        # Get bindings of all variables. Globals are the variables that have "global x"
        # somewhere in the function, or are never written in the function.
        globals_set = set()
        locals_set = set()
        all_vars_set = set()
        self.get_globals(node, globals_set, locals_set, all_vars_set)
        globals_set |= (all_vars_set - locals_set)

        self.symbol_idx['local'] = {symbol: idx for idx, symbol in enumerate(sorted(locals_set))}

        # Set some state and recursively visit child nodes, then restore state
        self.globals_set = globals_set
        self.in_function = True
        args = self.visit(node.args)
        body = self.flatten_list(node.body)
        self.globals_set = None
        self.in_function = False

        # exp_name is only present on methods (set by visit_ClassDef).
        exp_name = getattr(node, 'exp_name', None)
        fn = syntax.FunctionDef(node.name, args, body, exp_name,
                                self.get_binding(node.name), len(locals_set))
        return fn.flatten(self)

    def visit_ClassDef(self, node):
        """Translate a plain top-level class (no bases/decorators/nesting).

        Every function defined directly in the class body is tagged with an
        ``exp_name`` of the form ``_<class>_<function>`` before the body is
        visited.
        """
        assert not node.bases
        assert not node.keywords
        assert not node.starargs
        assert not node.kwargs
        assert not node.decorator_list
        assert not self.in_class
        assert not self.in_function

        for fn in node.body:
            if isinstance(fn, ast.FunctionDef):
                fn.exp_name = '_%s_%s' % (node.name, fn.name)

        self.in_class = True
        body = self.flatten_list(node.body)
        self.in_class = False

        c = syntax.ClassDef(node.name, self.get_binding(node.name), body)
        return c.flatten(self)

    # XXX This just turns "import x" into "x = 0". It's certainly not what we really want...
    def visit_Import(self, node):
        statements = []
        for name in node.names:
            assert not name.asname
            assert name.name
            statements.append(syntax.Store(name.name, syntax.IntConst(0), self.get_binding(name.name)))
        return statements

    def visit_Expr(self, node):
        """An expression statement translates to its expression."""
        return self.visit(node.value)

    def visit_Module(self, node):
        """Build the global/class symbol tables, then translate the module.

        Sets ``self.symbol_idx`` (name -> index per scope, indices assigned
        in sorted name order), ``self.global_sym_count`` and
        ``self.class_sym_count``.  Returns the flattened statement list.
        """
        # Set up an index of all possible global/class symbols
        all_global_syms = set()
        all_class_syms = set()
        self.index_global_class_symbols(node, all_global_syms, all_class_syms)

        # '$undefined' is the slot get_binding() falls back to for names
        # that are not found in any symbol table.
        all_global_syms.add('$undefined')
        all_global_syms |= set(builtin_symbols)

        self.symbol_idx = {
            scope: {symbol: idx for idx, symbol in enumerate(sorted(symbols))}
            for scope, symbols in [['class', all_class_syms], ['global', all_global_syms]]
        }
        self.global_sym_count = len(all_global_syms)
        self.class_sym_count = len(all_class_syms)

        return self.flatten_list(node.body)

    # Nodes that produce no output of their own.
    def visit_Pass(self, node): pass
    def visit_Load(self, node): pass
    def visit_Store(self, node): pass
    def visit_Global(self, node): pass
# Driver: translate the Python file named by argv[1] and write the generated
# C++ program to the file named by argv[2].
with open(sys.argv[1]) as f:
    node = ast.parse(f.read())

transformer = Transformer()
node = transformer.visit(node)

with open(sys.argv[2], 'w') as f:
    # X-macro lists of the builtin functions and classes.
    function_entries = ' '.join(['x(%s)' % name for name in builtin_functions])
    f.write('#define LIST_BUILTIN_FUNCTIONS(x) %s\n' % function_entries)
    class_entries = ' '.join(['x(%s)' % name for name in builtin_classes])
    f.write('#define LIST_BUILTIN_CLASSES(x) %s\n' % class_entries)

    # One sym_id_<name> define per builtin symbol, giving its global index.
    global_index = transformer.symbol_idx['global']
    for name in builtin_symbols:
        f.write('#define sym_id_%s %s\n' % (name, global_index[name]))

    f.write('#include "backend.cpp"\n')
    syntax.export_consts(f)

    # Collected function definitions are written ahead of main().
    for func in transformer.functions:
        f.write('%s\n' % func)

    global_count = transformer.global_sym_count
    f.write('int main(int argc, char **argv) {\n')
    f.write(' node *global_syms[%s] = {0};\n' % global_count)
    f.write(' context ctx(%s, global_syms), *globals = &ctx;\n' % global_count)
    f.write(' init_context(&ctx, argc, argv);\n')

    # The translated module-level statements form the body of main().
    for stmt in node:
        f.write(' %s;\n' % stmt)

    f.write('}\n')