1 ################################################################################
3 ## Pythonc--Python to C++ translator
5 ## Copyright 2011 Zach Wegner
7 ## This file is part of Pythonc.
9 ## Pythonc is free software: you can redistribute it and/or modify
10 ## it under the terms of the GNU General Public License as published by
11 ## the Free Software Foundation, either version 3 of the License, or
12 ## (at your option) any later version.
14 ## Pythonc is distributed in the hope that it will be useful,
15 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
16 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 ## GNU General Public License for more details.
19 ## You should have received a copy of the GNU General Public License
20 ## along with Pythonc. If not, see <http://www.gnu.org/licenses/>.
22 ################################################################################
# All builtin names known to the translator: the builtin functions, the
# builtin classes, and further entries in a list whose contents are not
# visible here. visit_Module adds every one of these to the global symbol
# table, and the driver emits a `#define sym_id_<name>` for each.
54 builtin_symbols
= builtin_functions
+ builtin_classes
+ [
# Lowers a Python AST into the `syntax` module's node classes. Each visit_*
# method handles one AST node type. Sub-expressions are flattened through
# flatten_node, which may bind them to temporaries appended to self.statements.
59 class Transformer(ast
.NodeTransformer
):
# Part of the constructor's state initialisation (the def line and the other
# attributes are not visible here):
#   in_function - True while a function body is being translated
#                 (set/reset by visit_FunctionDef)
#   globals_set - names treated as global inside the current function;
#                 None outside a function
65 self
.in_function
= False
66 self
.globals_set
= None
# Return a temporary-variable name of the form 'temp_NN' built from
# self.temp_id. Used by index_global_class_symbols to reserve a symbol-table
# slot for loop/comprehension iterators.
# NOTE(review): the update of self.temp_id is not visible here; presumably it
# is incremented before formatting so each name is fresh -- confirm.
68 def get_temp_name(self
):
70 return 'temp_%02i' % self
.temp_id
# Body of get_temp() (called from flatten_node). It returns a
# syntax.Identifier named 'temp_NN' from self.temp_id.
# NOTE(review): the def line and any counter update for this method are not
# visible in this view.
74 return syntax
.Identifier('temp_%02i' % self
.temp_id
)
# Visit `node` and reduce it to a simple operand expression.
# If `statements` is given, helper statements produced while visiting go into
# that list; otherwise they go into the current self.statements. When the
# result is bound to a fresh temporary, an Assign(temp, visited_node) is
# appended to that same list. self.statements is restored before returning.
# NOTE(review): the condition that decides between returning the visited node
# directly and spilling it to a temporary, and the final return, are not
# visible here.
76 def flatten_node(self
, node
, statements
=None):
77 old_stmts
= self
.statements
78 if statements
is not None:
79 self
.statements
= statements
80 node
= self
.visit(node
)
84 temp
= self
.get_temp()
85 self
.statements
.append(syntax
.Assign(temp
, node
))
87 self
.statements
= old_stmts
# Translate a sequence of statement nodes into one flat list of syntax
# statements. For each node, the helper statements it generated
# (self.statements) come first, then its own result. A list result is
# spliced in; a single node is appended. self.statements is restored at the end.
# NOTE(review): the initialisation of `statements`, the per-node reset of
# self.statements and the return are not visible here.
90 def flatten_list(self
, node_list
):
91 old_stmts
= self
.statements
93 for stmt
in node_list
:
95 stmts
= self
.visit(stmt
)
97 if isinstance(stmts
, list):
98 statements
+= self
.statements
+ stmts
100 statements
+= self
.statements
+ [stmts
]
101 self
.statements
= old_stmts
# Recursively walk `node` and record every name that may be bound at module or
# class level:
#   - names listed in `global` statements       -> globals_set
#   - stored names, function and class names    -> globals_set and class_set
#   - imported module names                     -> globals_set
#   - loop/comprehension nodes: a temporary name is allocated on
#     node.iter_temp and added to globals_set
# All child nodes are then visited. The scope is not checked (see XXX), so
# names stored inside function bodies are indexed as well.
# NOTE(review): the end of the loop/comprehension type tuple is not visible.
104 def index_global_class_symbols(self
, node
, globals_set
, class_set
):
105 if isinstance(node
, ast
.Global
):
106 for name
in node
.names
:
107 globals_set
.add(name
)
108 # XXX make this check scope
109 elif isinstance(node
, ast
.Name
) and isinstance(node
.ctx
,
110 (ast
.Store
, ast
.AugStore
)):
111 globals_set
.add(node
.id)
112 class_set
.add(node
.id)
113 elif isinstance(node
, (ast
.FunctionDef
, ast
.ClassDef
)):
114 globals_set
.add(node
.name
)
115 class_set
.add(node
.name
)
116 elif isinstance(node
, ast
.Import
):
117 for name
in node
.names
:
118 globals_set
.add(name
.name
)
119 elif isinstance(node
, (ast
.For
, ast
.ListComp
, ast
.DictComp
, ast
.SetComp
,
121 # HACK: set self.iter_temp for the space in the symbol table
122 node
.iter_temp
= self
.get_temp_name()
123 globals_set
.add(node
.iter_temp
)
124 for i
in ast
.iter_child_nodes(node
):
125 self
.index_global_class_symbols(i
, globals_set
, class_set
)
# Recursively collect variable usage within a function definition:
#   globals_set  - names declared with `global`
#   all_vars_set - every Name id and every parameter name seen
#   locals_set   - names stored to, parameter names, and the iter_temp of
#                  each loop/comprehension (assigned by
#                  index_global_class_symbols)
# NOTE(review): the end of the loop/comprehension type tuple is not visible.
127 def get_globals(self
, node
, globals_set
, locals_set
, all_vars_set
):
128 if isinstance(node
, ast
.Global
):
129 for name
in node
.names
:
130 globals_set
.add(name
)
131 elif isinstance(node
, ast
.Name
):
132 all_vars_set
.add(node
.id)
133 if isinstance(node
.ctx
, (ast
.Store
, ast
.AugStore
)):
134 locals_set
.add(node
.id)
135 elif isinstance(node
, ast
.arg
):
136 all_vars_set
.add(node
.arg
)
137 locals_set
.add(node
.arg
)
138 elif isinstance(node
, (ast
.For
, ast
.ListComp
, ast
.DictComp
, ast
.SetComp
,
140 locals_set
.add(node
.iter_temp
)
141 for i
in ast
.iter_child_nodes(node
):
142 self
.get_globals(i
, globals_set
, locals_set
, all_vars_set
)
# Resolve `name` to a (scope, index) pair using the table
# self.symbol_idx[scope]. A name missing from that table maps to the table's
# '$undefined' entry. Membership in self.globals_set is tested first.
# NOTE(review): the logic that chooses `scope` (and what happens when
# self.globals_set is None at module level) is not visible here.
144 def get_binding(self
, name
):
146 if name
in self
.globals_set
:
154 if name
in self
.symbol_idx
[scope
]:
155 return (scope
, self
.symbol_idx
[scope
][name
])
156 return (scope
, self
.symbol_idx
[scope
]['$undefined'])
# Fallback for AST node types with no visit_* method. It fails loudly rather
# than using NodeTransformer's default of recursing into the children.
158 def generic_visit(self
, node
):
160 raise RuntimeError('can\'t translate %s' % node
)
def visit_children(self, node):
    """Visit each direct child of *node* in order and return the results."""
    results = []
    for child in ast.iter_child_nodes(node):
        results.append(self.visit(child))
    return results
def visit_Name(self, node):
    """Translate a name load.

    ``True``/``False`` become boolean constants and ``None`` becomes the
    None constant. Any other name becomes a load through its symbol binding.
    Only load contexts are accepted.
    """
    assert isinstance(node.ctx, ast.Load)
    name = node.id
    if name == 'None':
        return syntax.NoneConst()
    if name in ('True', 'False'):
        return syntax.BoolConst(name == 'True')
    return syntax.Load(name, self.get_binding(name))
def visit_Num(self, node):
    """Translate a numeric literal into an integer constant.

    Only ``int`` literals are supported. Any other numeric literal (float or
    complex) raises RuntimeError naming the unsupported type, e.g.
    'Pythonc currently does not support float literals'. Previously a complex
    literal tripped a bare assertion, and would have been emitted as an
    IntConst under ``python -O``.
    """
    value = node.n
    if not isinstance(value, int):
        raise RuntimeError('Pythonc currently does not support %s literals'
                           % type(value).__name__)
    return syntax.IntConst(value)
def visit_Str(self, node):
    """Translate a text string literal into a string constant node."""
    text = node.s
    assert isinstance(text, str)
    return syntax.StringConst(text)
def visit_Bytes(self, node):
    """Reject bytes literals, which Pythonc does not support."""
    message = 'Pythonc currently does not support bytes literals'
    raise RuntimeError(message)
def visit_Invert(self, node):
    """``~x`` dispatches to ``__invert__``."""
    return '__invert__'

def visit_Not(self, node):
    """``not x`` dispatches to ``__not__``."""
    return '__not__'

def visit_UAdd(self, node):
    """``+x`` dispatches to ``__pos__``."""
    return '__pos__'

def visit_USub(self, node):
    """``-x`` dispatches to ``__neg__``."""
    return '__neg__'
def visit_UnaryOp(self, node):
    """Translate a unary operation.

    The operator is mapped to its method name, and the operand is flattened
    to a simple expression, before the syntax.UnaryOp node is built.
    """
    method_name = self.visit(node.op)
    operand = self.flatten_node(node.operand)
    return syntax.UnaryOp(method_name, operand)
def visit_Add(self, node):
    """``a + b`` dispatches to ``__add__``."""
    return '__add__'

def visit_BitAnd(self, node):
    """``a & b`` dispatches to ``__and__``."""
    return '__and__'

def visit_BitOr(self, node):
    """``a | b`` dispatches to ``__or__``."""
    return '__or__'

def visit_BitXor(self, node):
    """``a ^ b`` dispatches to ``__xor__``."""
    return '__xor__'

def visit_Div(self, node):
    """``a / b`` dispatches to ``__truediv__``."""
    return '__truediv__'

def visit_FloorDiv(self, node):
    """``a // b`` dispatches to ``__floordiv__``."""
    return '__floordiv__'

def visit_LShift(self, node):
    """``a << b`` dispatches to ``__lshift__``."""
    return '__lshift__'

def visit_Mod(self, node):
    """``a % b`` dispatches to ``__mod__``."""
    return '__mod__'

def visit_Mult(self, node):
    """``a * b`` dispatches to ``__mul__``."""
    return '__mul__'

def visit_Pow(self, node):
    """``a ** b`` dispatches to ``__pow__``."""
    return '__pow__'

def visit_RShift(self, node):
    """``a >> b`` dispatches to ``__rshift__``."""
    return '__rshift__'

def visit_Sub(self, node):
    """``a - b`` dispatches to ``__sub__``."""
    return '__sub__'
def visit_BinOp(self, node):
    """Translate a binary operation.

    The operator becomes its method name. The left operand is flattened
    before the right one, so any temporaries are emitted in source order.
    """
    method_name = self.visit(node.op)
    left_operand = self.flatten_node(node.left)
    right_operand = self.flatten_node(node.right)
    return syntax.BinaryOp(method_name, left_operand, right_operand)
def visit_Eq(self, node):
    """``a == b`` dispatches to ``__eq__``."""
    return '__eq__'

def visit_NotEq(self, node):
    """``a != b`` dispatches to ``__ne__``."""
    return '__ne__'

def visit_Lt(self, node):
    """``a < b`` dispatches to ``__lt__``."""
    return '__lt__'

def visit_LtE(self, node):
    """``a <= b`` dispatches to ``__le__``."""
    return '__le__'

def visit_Gt(self, node):
    """``a > b`` dispatches to ``__gt__``."""
    return '__gt__'

def visit_GtE(self, node):
    """``a >= b`` dispatches to ``__ge__``."""
    return '__ge__'

def visit_In(self, node):
    """``a in b`` dispatches to ``__contains__``."""
    return '__contains__'

def visit_NotIn(self, node):
    """``a not in b`` dispatches to the backend's ``__ncontains__``."""
    return '__ncontains__'

def visit_Is(self, node):
    """``a is b`` dispatches to the backend's ``__is__``."""
    return '__is__'

def visit_IsNot(self, node):
    """``a is not b`` dispatches to the backend's ``__isnot__``."""
    return '__isnot__'
# Translate a single binary comparison. Chained comparisons such as
# `a < b < c` are rejected by the asserts. Both operands are flattened
# (left first) and combined into a syntax.BinaryOp named by the operator.
# For `in` / `not in` (__contains__ / __ncontains__) the operand order needs
# adjusting, as the existing comment notes.
# NOTE(review): the statement under that `if` (presumably swapping lhs and
# rhs) is not visible here.
228 def visit_Compare(self
, node
):
229 assert len(node
.ops
) == 1
230 assert len(node
.comparators
) == 1
231 op
= self
.visit(node
.ops
[0])
232 lhs
= self
.flatten_node(node
.left
)
233 rhs
= self
.flatten_node(node
.comparators
[0])
234 # Sigh--Python has these ordered weirdly
235 if op
in ['__contains__', '__ncontains__']:
237 return syntax
.BinaryOp(op
, lhs
, rhs
)
def visit_And(self, node):
    """The ``and`` boolean operator keeps its keyword spelling."""
    return 'and'

def visit_Or(self, node):
    """The ``or`` boolean operator keeps its keyword spelling."""
    return 'or'
# Translate `a and b and ...` / `a or b or ...` (at least two operands).
# The last operand is flattened into its own statement list (rhs_stmts).
# The remaining operands are then folded right-to-left: each left operand is
# flattened into a fresh list (lhs_stmts) and wrapped with the accumulated
# right side in a syntax.BoolOp, so a right-hand side's statements are kept
# separate from the enclosing statements. The outermost list is appended to
# self.statements.
# NOTE(review): rhs_stmts / lhs_stmts are used without a visible assignment,
# and no return is visible -- lines appear to be missing from this view.
243 def visit_BoolOp(self
, node
):
244 assert len(node
.values
) >= 2
245 op
= self
.visit(node
.op
)
247 rhs_expr
= self
.flatten_node(node
.values
[-1], statements
=rhs_stmts
)
248 for v
in reversed(node
.values
[:-1]):
250 lhs
= self
.flatten_node(v
, statements
=lhs_stmts
)
251 bool_op
= syntax
.BoolOp(op
, lhs
, rhs_stmts
, rhs_expr
)
252 rhs_expr
= bool_op
.flatten(self
, lhs_stmts
)
253 rhs_stmts
= lhs_stmts
254 self
.statements
+= rhs_stmts
# Translate `body if test else orelse`. The test is flattened into the
# enclosing statements. Each branch is flattened into its own statement list
# (true_stmts / false_stmts), which is handed to syntax.IfExp along with the
# branch expressions.
# NOTE(review): true_stmts / false_stmts are initialised on lines not visible
# here.
257 def visit_IfExp(self
, node
):
258 expr
= self
.flatten_node(node
.test
)
260 true_expr
= self
.flatten_node(node
.body
, statements
=true_stmts
)
262 false_expr
= self
.flatten_node(node
.orelse
, statements
=false_stmts
)
263 if_exp
= syntax
.IfExp(expr
, true_stmts
, true_expr
, false_stmts
, false_expr
)
264 return if_exp
.flatten(self
)
def visit_List(self, node):
    """Translate a list display.

    Every element is flattened in source order; the resulting syntax.List is
    then flattened itself.
    """
    elements = []
    for element in node.elts:
        elements.append(self.flatten_node(element))
    list_node = syntax.List(elements)
    return list_node.flatten(self)
def visit_Tuple(self, node):
    """Translate a tuple display.

    Every element is flattened in source order; the resulting syntax.Tuple is
    then flattened itself.
    """
    elements = []
    for element in node.elts:
        elements.append(self.flatten_node(element))
    tuple_node = syntax.Tuple(elements)
    return tuple_node.flatten(self)
def visit_Dict(self, node):
    """Translate a dict display.

    All keys are flattened first, then all values, which fixes the order in
    which temporaries are emitted. The resulting syntax.Dict is then
    flattened itself.
    """
    flat_keys = list(map(self.flatten_node, node.keys))
    flat_values = list(map(self.flatten_node, node.values))
    dict_node = syntax.Dict(flat_keys, flat_values)
    return dict_node.flatten(self)
def visit_Set(self, node):
    """Translate a set display.

    Every element is flattened in source order; the resulting syntax.Set is
    then flattened itself.
    """
    elements = []
    for element in node.elts:
        elements.append(self.flatten_node(element))
    set_node = syntax.Set(elements)
    return set_node.flatten(self)
def visit_Subscript(self, node):
    """Translate a subscript load: ``value[index]`` or ``value[lo:hi:step]``.

    The subscripted value is flattened first, then the index. For a slice,
    each bound is flattened, and a missing bound becomes a None constant.

    Raises:
        RuntimeError: for any other slice form, e.g. extended slices such as
            ``x[1:2, 3]``. Previously such nodes fell off the end of the
            if/elif chain and the method silently returned None.
    """
    value = self.flatten_node(node.value)
    subscript = node.slice
    if isinstance(subscript, ast.Index):
        index = self.flatten_node(subscript.value)
        return syntax.Subscript(value, index)
    if isinstance(subscript, ast.Slice):
        start, end, step = [
            self.flatten_node(bound) if bound is not None else syntax.NoneConst()
            for bound in (subscript.lower, subscript.upper, subscript.step)]
        return syntax.Slice(value, start, end, step)
    raise RuntimeError("can't translate subscript slice %s"
                       % type(subscript).__name__)
# Translate an attribute load `value.attr`. The object expression is
# flattened, and the attribute name becomes a syntax.StringConst.
# NOTE(review): `attr` is built but no return is visible here.
297 def visit_Attribute(self
, node
):
298 assert isinstance(node
.ctx
, ast
.Load
)
299 l
= self
.flatten_node(node
.value
)
300 attr
= syntax
.Attribute(l
, syntax
.StringConst(node
.attr
))
# Translate a function call. The callee is flattened first. Two paths are
# visible:
#   - with *args: the flattened starargs expression is wrapped in a
#     syntax.Tuple, and keyword arguments become an empty syntax.Dict
#     (**kwargs is rejected by the assert);
#   - otherwise: positional arguments are flattened into a syntax.Tuple, and
#     keyword arguments into a syntax.Dict mapping name constants to
#     flattened values.
# NOTE(review): the branch selecting between these paths is not visible here.
# Call.starargs / Call.kwargs exist only in the Python <= 3.4 AST.
303 def visit_Call(self
, node
):
304 fn
= self
.flatten_node(node
.func
)
308 assert not node
.kwargs
309 args
= syntax
.Tuple(self
.flatten_node(node
.starargs
))
310 args
= args
.flatten(self
)
311 kwargs
= syntax
.Dict([], [])
313 args
= syntax
.Tuple([self
.flatten_node(a
) for a
in node
.args
])
314 args
= args
.flatten(self
)
316 keys
= [syntax
.StringConst(i
.arg
) for i
in node
.keywords
]
317 values
= [self
.flatten_node(i
.value
) for i
in node
.keywords
]
318 kwargs
= syntax
.Dict(keys
, values
)
320 kwargs
= kwargs
.flatten(self
)
321 return syntax
.Call(fn
, args
, kwargs
)
# Translate a single-target assignment (`a = b = c` is rejected by the
# assert). The value is flattened once and then stored to one of:
#   - a name: syntax.Store through the name's binding;
#   - a tuple of names: one Store per element, reading value[i];
#   - an attribute: syntax.StoreAttr on the flattened base object;
#   - a simple-index subscript: syntax.StoreSubscript on the flattened
#     base and index.
# NOTE(review): the initialisation/return of `stmts` in the tuple case, and
# any branch for other target types, are not visible here.
323 def visit_Assign(self
, node
):
324 assert len(node
.targets
) == 1
325 target
= node
.targets
[0]
326 value
= self
.flatten_node(node
.value
)
327 if isinstance(target
, ast
.Name
):
328 return [syntax
.Store(target
.id, value
, self
.get_binding(target
.id))]
329 elif isinstance(target
, ast
.Tuple
):
330 assert all(isinstance(t
, ast
.Name
) for t
in target
.elts
)
332 for i
, t
in enumerate(target
.elts
):
333 stmts
+= [syntax
.Store(t
.id, syntax
.Subscript(value
, syntax
.IntConst(i
)), self
.get_binding(t
.id))]
335 elif isinstance(target
, ast
.Attribute
):
336 base
= self
.flatten_node(target
.value
)
337 return [syntax
.StoreAttr(base
, syntax
.StringConst(target
.attr
), value
)]
338 elif isinstance(target
, ast
.Subscript
):
339 assert isinstance(target
.slice, ast
.Index
)
340 base
= self
.flatten_node(target
.value
)
341 index
= self
.flatten_node(target
.slice.value
)
342 return [syntax
.StoreSubscript(base
, index
, value
)]
# Translate `target op= value` as `target = target op value`, not as an
# in-place update (see the XXX HACK). Names, attributes and simple-index
# subscripts are handled. For attribute/subscript targets, the base (and
# index) are flattened once and reused for both the read and the write.
# NOTE(review): `value` is flattened before the target's base/index. CPython
# evaluates an augmented-assignment target first, so side-effecting
# expressions such as `f().x += g()` run in a different order here.
# NOTE(review): no branch for other target types is visible here.
346 def visit_AugAssign(self
, node
):
347 op
= self
.visit(node
.op
)
348 value
= self
.flatten_node(node
.value
)
349 if isinstance(node
.target
, ast
.Name
):
350 target
= node
.target
.id
351 # XXX HACK: doesn't modify in place
352 binop
= syntax
.BinaryOp(op
, syntax
.Load(target
, self
.get_binding(target
)), value
)
353 return [syntax
.Store(target
, binop
, self
.get_binding(target
))]
354 elif isinstance(node
.target
, ast
.Attribute
):
355 l
= self
.flatten_node(node
.target
.value
)
356 attr_name
= syntax
.StringConst(node
.target
.attr
)
357 attr
= syntax
.Attribute(l
, attr_name
)
358 binop
= syntax
.BinaryOp(op
, attr
, value
)
359 return [syntax
.StoreAttr(l
, attr_name
, binop
)]
360 elif isinstance(node
.target
, ast
.Subscript
):
361 assert isinstance(node
.target
.slice, ast
.Index
)
362 base
= self
.flatten_node(node
.target
.value
)
363 index
= self
.flatten_node(node
.target
.slice.value
)
364 old
= syntax
.Subscript(base
, index
)
365 binop
= syntax
.BinaryOp(op
, old
, value
)
366 return [syntax
.StoreSubscript(base
, index
, binop
)]
# Translate `del base[index]`. Only a single target that is a simple-index
# subscript is accepted (enforced by the asserts). The base and the index are
# flattened in that order and wrapped in syntax.DeleteSubscript.
370 def visit_Delete(self
, node
):
371 assert len(node
.targets
) == 1
372 target
= node
.targets
[0]
373 assert isinstance(target
, ast
.Subscript
)
374 assert isinstance(target
.slice, ast
.Index
)
376 name
= self
.flatten_node(target
.value
)
377 value
= self
.flatten_node(target
.slice.value
)
378 return [syntax
.DeleteSubscript(name
, value
)]
# Translate an if statement. The test is flattened into the enclosing
# statements. The body and the else branch are each translated into their
# own flat statement lists and passed to syntax.If.
# NOTE(review): the lines around the else-branch handling (e.g. what
# else_block is when there is no `else`) are not visible here.
380 def visit_If(self
, node
):
381 expr
= self
.flatten_node(node
.test
)
382 stmts
= self
.flatten_list(node
.body
)
384 else_block
= self
.flatten_list(node
.orelse
)
387 return syntax
.If(expr
, stmts
, else_block
)
def visit_Break(self, node):
    """Translate ``break`` into a syntax.Break statement."""
    statement = syntax.Break()
    return statement
def visit_Continue(self, node):
    """Translate ``continue`` into a syntax.Continue statement."""
    statement = syntax.Continue()
    return statement
# Translate a for loop (an `else:` clause is rejected). The iterable is
# flattened into the enclosing statements and the body is translated into a
# flat list. The target is either a single name or a tuple whose elements are
# used as names; each is paired with its binding. The hidden iterator
# temporary is node.iter_temp, assigned earlier by index_global_class_symbols.
# NOTE(review): no branch for other target types is visible here.
395 def visit_For(self
, node
):
396 assert not node
.orelse
397 iter = self
.flatten_node(node
.iter)
398 stmts
= self
.flatten_list(node
.body
)
400 if isinstance(node
.target
, ast
.Name
):
401 target
= (node
.target
.id, self
.get_binding(node
.target
.id))
402 elif isinstance(node
.target
, ast
.Tuple
):
403 target
= [(t
.id, self
.get_binding(t
.id)) for t
in node
.target
.elts
]
406 # HACK: self.iter_temp gets set when enumerating symbols
407 for_loop
= syntax
.For(target
, iter, stmts
, node
.iter_temp
, self
.get_binding(node
.iter_temp
))
408 return for_loop
.flatten(self
)
# Translate a while loop (an `else:` clause is rejected). The test is
# flattened into its own statement list (test_stmts), which is passed to
# syntax.While separately from the body. Presumably this lets the backend
# re-run those statements on every iteration -- confirm in syntax.While.
# NOTE(review): test_stmts is initialised on a line not visible here.
410 def visit_While(self
, node
):
411 assert not node
.orelse
413 test
= self
.flatten_node(node
.test
, statements
=test_stmts
)
414 stmts
= self
.flatten_list(node
.body
)
415 return syntax
.While(test_stmts
, test
, stmts
)
# Only the `with expr as name:` form is accepted (optional_vars is asserted).
# No context-manager enter/exit calls are emitted.
# NOTE(review): With.context_expr / With.optional_vars are the Python <= 3.2
# AST fields (3.3+ uses With.items). The return of `stmts` is not visible.
417 # XXX We are just flattening "with x as y:" into "y = x" (this works in some simple cases with open()).
418 def visit_With(self
, node
):
419 assert node
.optional_vars
420 expr
= self
.flatten_node(node
.context_expr
)
421 stmts
= [syntax
.Store(node
.optional_vars
.id, expr
, self
.get_binding(node
.optional_vars
.id))]
422 stmts
+= self
.flatten_list(node
.body
)
# Shared translation for list/set/dict comprehensions and generator
# expressions (comp_type is 'list', 'set', 'dict' or 'generator').
# Only one `for` clause with at most one `if` is supported. The target is a
# name or a tuple of names, each paired with its binding. The iterable is
# flattened into the enclosing statements. The condition is flattened into
# cond_stmts and the element expression(s) into expr_stmts; for 'dict',
# key and value both go into expr_stmts. Everything is passed to
# syntax.Comprehension together with the iterator temporary node.iter_temp.
# NOTE(review): several lines (list initialisation, the no-`if` case, the
# else for non-dict element and the tail of the Comprehension(...) call) are
# not visible here.
425 def visit_Comprehension(self
, node
, comp_type
):
426 assert len(node
.generators
) == 1
427 gen
= node
.generators
[0]
428 assert len(gen
.ifs
) <= 1
430 if isinstance(gen
.target
, ast
.Name
):
431 target
= (gen
.target
.id, self
.get_binding(gen
.target
.id))
432 elif isinstance(gen
.target
, ast
.Tuple
):
433 target
= [(t
.id, self
.get_binding(t
.id)) for t
in gen
.target
.elts
]
437 iter = self
.flatten_node(gen
.iter)
442 cond
= self
.flatten_node(gen
.ifs
[0], statements
=cond_stmts
)
443 if comp_type
== 'dict':
444 expr
= self
.flatten_node(node
.key
, statements
=expr_stmts
)
445 expr2
= self
.flatten_node(node
.value
, statements
=expr_stmts
)
447 expr
= self
.flatten_node(node
.elt
, statements
=expr_stmts
)
449 comp
= syntax
.Comprehension(comp_type
, target
, iter, node
.iter_temp
,
450 self
.get_binding(node
.iter_temp
), cond_stmts
, cond
, expr_stmts
,
452 return comp
.flatten(self
)
def visit_ListComp(self, node):
    """List comprehensions go through the shared comprehension visitor."""
    kind = 'list'
    return self.visit_Comprehension(node, kind)

def visit_SetComp(self, node):
    """Set comprehensions go through the shared comprehension visitor."""
    kind = 'set'
    return self.visit_Comprehension(node, kind)

def visit_DictComp(self, node):
    """Dict comprehensions go through the shared comprehension visitor."""
    kind = 'dict'
    return self.visit_Comprehension(node, kind)

def visit_GeneratorExp(self, node):
    """Generator expressions go through the shared comprehension visitor."""
    kind = 'generator'
    return self.visit_Comprehension(node, kind)
# Translate `return expr` (the value is flattened first) or a bare `return`,
# which becomes syntax.Return(None).
466 def visit_Return(self
, node
):
467 if node
.value
is not None:
468 expr
= self
.flatten_node(node
.value
)
469 return syntax
.Return(expr
)
471 return syntax
.Return(None)
def visit_Assert(self, node):
    """Translate ``assert test``.

    The test is flattened, and the statement's source line number is passed
    to syntax.Assert. An assertion message (node.msg), if present, is not
    used.
    """
    condition = self.flatten_node(node.test)
    line_number = node.lineno
    return syntax.Assert(condition, line_number)
def visit_Raise(self, node):
    """Translate ``raise exc``.

    ``raise ... from ...`` is rejected by the assertion. The exception
    expression is flattened, and the source line number is passed to
    syntax.Raise.
    """
    assert not node.cause
    exception = self.flatten_node(node.exc)
    line_number = node.lineno
    return syntax.Raise(exception, line_number)
# Translate a function's parameter list. Only plain positional parameters
# are supported: *args and **kwargs are rejected by the asserts. Each
# parameter name is paired with its binding, and the default-value
# expressions are translated with flatten_list.
# NOTE(review): keyword-only parameters (node.kwonlyargs) are never read here.
482 def visit_arguments(self
, node
):
483 assert not node
.vararg
484 assert not node
.kwarg
486 args
= [a
.arg
for a
in node
.args
]
487 binding
= [self
.get_binding(a
) for a
in args
]
488 defaults
= self
.flatten_list(node
.defaults
)
489 args
= syntax
.Arguments(args
, binding
, defaults
)
490 return args
.flatten(self
)
# Translate a function (or method) definition. Nested functions are rejected
# by the assert, because in_function stays True while a body is translated.
# Before the body is visited:
#   - get_globals collects the function's variable usage;
#   - names read but never stored locally are added to the globals;
#   - self.symbol_idx['local'] maps each local name to a 0-based index in
#     sorted order.
# A method's exported name (exp_name, set by visit_ClassDef) is passed
# through when present.
# NOTE(review): initialisation of globals_set/locals_set/all_vars_set is not
# visible. get_binding(node.name) runs after self.globals_set was reset to
# None; this relies on get_binding not testing `name in self.globals_set` at
# module level (its guard is not visible here).
492 def visit_FunctionDef(self
, node
):
493 assert not self
.in_function
495 # Get bindings of all variables. Globals are the variables that have "global x"
496 # somewhere in the function, or are never written in the function.
500 self
.get_globals(node
, globals_set
, locals_set
, all_vars_set
)
501 globals_set |
= (all_vars_set
- locals_set
)
503 self
.symbol_idx
['local'] = {symbol
: idx
for idx
, symbol
in enumerate(sorted(locals_set
))}
505 # Set some state and recursively visit child nodes, then restore state
506 self
.globals_set
= globals_set
507 self
.in_function
= True
508 args
= self
.visit(node
.args
)
509 body
= self
.flatten_list(node
.body
)
510 self
.globals_set
= None
511 self
.in_function
= False
513 exp_name
= node
.exp_name
if 'exp_name' in dir(node
) else None
514 fn
= syntax
.FunctionDef(node
.name
, args
, body
, exp_name
, self
.get_binding(node
.name
), len(locals_set
))
515 return fn
.flatten(self
)
# Translate a class definition. The class must have no bases, keywords,
# *args/**kwargs or decorators, and must not be nested in a class or function.
# Each FunctionDef in the body is given the exported name '_<Class>_<method>'
# (fn.exp_name) before the body is translated. self.in_class is reset to
# False afterwards.
# NOTE(review): the loop binding `fn` and the line setting self.in_class to
# True are not visible here. ClassDef.starargs / ClassDef.kwargs exist only
# in the Python <= 3.4 AST.
517 def visit_ClassDef(self
, node
):
518 assert not node
.bases
519 assert not node
.keywords
520 assert not node
.starargs
521 assert not node
.kwargs
522 assert not node
.decorator_list
523 assert not self
.in_class
524 assert not self
.in_function
527 if isinstance(fn
, ast
.FunctionDef
):
528 fn
.exp_name
= '_%s_%s' % (node
.name
, fn
.name
)
531 body
= self
.flatten_list(node
.body
)
532 self
.in_class
= False
534 c
= syntax
.ClassDef(node
.name
, self
.get_binding(node
.name
), body
)
535 return c
.flatten(self
)
# Each imported module name (`as` aliases are rejected by the assert) is
# stored as the integer 0 under its own binding.
# NOTE(review): `statements` is initialised and returned on lines not visible
# here.
537 # XXX This just turns "import x" into "x = 0". It's certainly not what we really want...
538 def visit_Import(self
, node
):
540 for name
in node
.names
:
541 assert not name
.asname
543 statements
.append(syntax
.Store(name
.name
, syntax
.IntConst(0), self
.get_binding(name
.name
)))
def visit_Expr(self, node):
    """An expression statement translates to its visited inner expression."""
    inner = node.value
    return self.visit(inner)
# Translate the whole module:
#   1. index every global and class-level symbol
#      (index_global_class_symbols);
#   2. add '$undefined' and all builtin_symbols to the global set;
#   3. build per-scope tables {'class': ..., 'global': ...} that map each
#      symbol to a 0-based index in sorted order (presumably assigned to
#      self.symbol_idx -- the opening of that statement is not visible);
#   4. record the table sizes;
#   5. translate the top-level statements.
# NOTE(review): '$undefined' is added only to the global set, yet get_binding
# falls back to symbol_idx[scope]['$undefined'] for whichever scope it picks.
# Confirm the 'class' table can never reach that fallback.
549 def visit_Module(self
, node
):
550 # Set up an index of all possible global/class symbols
551 all_global_syms
= set()
552 all_class_syms
= set()
553 self
.index_global_class_symbols(node
, all_global_syms
, all_class_syms
)
555 all_global_syms
.add('$undefined')
556 all_global_syms |
= set(builtin_symbols
)
559 scope
: {symbol
: idx
for idx
, symbol
in enumerate(sorted(symbols
))}
560 for scope
, symbols
in [['class', all_class_syms
], ['global', all_global_syms
]]
562 self
.global_sym_count
= len(all_global_syms
)
563 self
.class_sym_count
= len(all_class_syms
)
565 return self
.flatten_list(node
.body
)
def visit_Pass(self, node):
    """``pass`` produces no output."""
    return None

def visit_Load(self, node):
    """The Load expression context has no translation of its own."""
    return None

def visit_Store(self, node):
    """The Store expression context has no translation of its own."""
    return None

def visit_Global(self, node):
    """``global`` produces no output; declarations are collected separately."""
    return None
# Driver. Parses the Python source file named by sys.argv[1], translates it
# with Transformer, and writes generated C++ to sys.argv[2]. The output
# contains, in order:
#   - X-macro lists of the builtin functions and builtin classes;
#   - a `#define sym_id_<name>` per builtin symbol, giving its index in the
#     global symbol table;
#   - the backend include and the exported constants;
#   - every translated function;
#   - a main() that allocates the global symbol array and context.
# NOTE(review): this script continues past the end of this view.
572 with
open(sys
.argv
[1]) as f
:
573 node
= ast
.parse(f
.read())
575 transformer
= Transformer()
576 node
= transformer
.visit(node
)
578 with
open(sys
.argv
[2], 'w') as f
:
579 f
.write('#define LIST_BUILTIN_FUNCTIONS(x) %s\n' % ' '.join('x(%s)' % x
580 for x
in builtin_functions
))
581 f
.write('#define LIST_BUILTIN_CLASSES(x) %s\n' % ' '.join('x(%s)' % x
582 for x
in builtin_classes
))
583 for x
in builtin_symbols
:
584 f
.write('#define sym_id_%s %s\n' % (x
, transformer
.symbol_idx
['global'][x
]))
585 f
.write('#include "backend.cpp"\n')
586 syntax
.export_consts(f
)
588 for func
in transformer
.functions
:
589 f
.write('%s\n' % func
)
591 f
.write('int main(int argc, char **argv) {\n')
592 f
.write(' node *global_syms[%s] = {0};\n' % (transformer
.global_sym_count
))
593 f
.write(' context ctx(%s, global_syms), *globals = &ctx;\n' % (transformer
.global_sym_count
))
594 f
.write(' init_context(&ctx, argc, argv);\n')
597 f
.write(' %s;\n' % stmt
)