Add NEWS entry as per RDM's suggestion (the bug was actually present
[python.git] / Lib / compiler / pyassem.py
blobbe0255d6aa2e3d425362288dc31dfa45dbe0d534
1 """A flow graph representation for Python bytecode"""
3 import dis
4 import types
5 import sys
7 from compiler import misc
8 from compiler.consts \
9 import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS
11 class FlowGraph:
12 def __init__(self):
13 self.current = self.entry = Block()
14 self.exit = Block("exit")
15 self.blocks = misc.Set()
16 self.blocks.add(self.entry)
17 self.blocks.add(self.exit)
19 def startBlock(self, block):
20 if self._debug:
21 if self.current:
22 print "end", repr(self.current)
23 print " next", self.current.next
24 print " prev", self.current.prev
25 print " ", self.current.get_children()
26 print repr(block)
27 self.current = block
29 def nextBlock(self, block=None):
30 # XXX think we need to specify when there is implicit transfer
31 # from one block to the next. might be better to represent this
32 # with explicit JUMP_ABSOLUTE instructions that are optimized
33 # out when they are unnecessary.
35 # I think this strategy works: each block has a child
36 # designated as "next" which is returned as the last of the
37 # children. because the nodes in a graph are emitted in
38 # reverse post order, the "next" block will always be emitted
39 # immediately after its parent.
40 # Worry: maintaining this invariant could be tricky
41 if block is None:
42 block = self.newBlock()
44 # Note: If the current block ends with an unconditional control
45 # transfer, then it is techically incorrect to add an implicit
46 # transfer to the block graph. Doing so results in code generation
47 # for unreachable blocks. That doesn't appear to be very common
48 # with Python code and since the built-in compiler doesn't optimize
49 # it out we don't either.
50 self.current.addNext(block)
51 self.startBlock(block)
53 def newBlock(self):
54 b = Block()
55 self.blocks.add(b)
56 return b
58 def startExitBlock(self):
59 self.startBlock(self.exit)
61 _debug = 0
63 def _enable_debug(self):
64 self._debug = 1
66 def _disable_debug(self):
67 self._debug = 0
69 def emit(self, *inst):
70 if self._debug:
71 print "\t", inst
72 if len(inst) == 2 and isinstance(inst[1], Block):
73 self.current.addOutEdge(inst[1])
74 self.current.emit(inst)
76 def getBlocksInOrder(self):
77 """Return the blocks in reverse postorder
79 i.e. each node appears before all of its successors
80 """
81 order = order_blocks(self.entry, self.exit)
82 return order
84 def getBlocks(self):
85 return self.blocks.elements()
87 def getRoot(self):
88 """Return nodes appropriate for use with dominator"""
89 return self.entry
91 def getContainedGraphs(self):
92 l = []
93 for b in self.getBlocks():
94 l.extend(b.getContainedGraphs())
95 return l
def order_blocks(start_block, exit_block):
    """Order blocks so that they are emitted in the right order.

    Returns a list containing every block reachable from start_block
    (via get_children()), arranged so that each block's "next" block
    immediately follows it and every block precedes the targets of its
    relative jumps (its get_followers()).

    In addition, exit_block is appended right after any block that has no
    next block and does not end in an unconditional transfer, so that
    execution falling off the end of such a block runs into the exit block.
    Because exit_block is not removed from the pending set when it is
    appended this way, it can appear in the result more than once.
    """
    # Rules:
    # - when a block has a next block, the next block must be emitted just after
    # - when a block has followers (relative jumps), it must be emitted before
    #   them
    # - all reachable blocks must be emitted
    order = []

    # Find all the blocks to be emitted: a stack-based traversal of
    # get_children() starting at start_block.
    remaining = set()
    todo = [start_block]
    while todo:
        b = todo.pop()
        if b in remaining:
            continue
        remaining.add(b)
        for c in b.get_children():
            if c not in remaining:
                todo.append(c)

    # A block is dominated by another block if that block must be emitted
    # before it.
    dominators = {}
    for b in remaining:
        if __debug__ and b.next:
            # Sanity check: the next/prev links must point at each other.
            assert b is b.next[0].prev[0], (b, b.next)
        # Make sure every block appears in dominators, even if no
        # other block must precede it.
        dominators.setdefault(b, set())
        # preceding blocks dominate following blocks
        for c in b.get_followers():
            while 1:
                dominators.setdefault(c, set()).add(b)
                # Any block that has a next pointer leading to c is also
                # dominated because the whole chain will be emitted at once.
                # Walk backwards and add them all.
                if c.prev and c.prev[0] is not b:
                    c = c.prev[0]
                else:
                    break

    def find_next():
        # Find a block that can be emitted next: one with no dominator
        # that is still waiting to be emitted.  The choice among several
        # candidates follows the iteration order of the `remaining` set.
        for b in remaining:
            for c in dominators[b]:
                if c in remaining:
                    break # can't emit yet, dominated by a remaining block
            else:
                return b
        assert 0, 'circular dependency, cannot find next block'

    b = start_block
    while 1:
        order.append(b)
        remaining.discard(b)
        if b.next:
            # Emit the whole chain of "next" blocks back to back.
            b = b.next[0]
            continue
        elif b is not exit_block and not b.has_unconditional_transfer():
            # Control can fall off the end of b: let it fall into the exit
            # block.  exit_block stays in `remaining`, so it may be emitted
            # again later by find_next().
            order.append(exit_block)
        if not remaining:
            break
        b = find_next()
    return order
class Block:
    """A basic block: a straight-line list of instruction tuples.

    Each instruction is a tuple whose first element is the opcode name and
    whose optional second element is the (still symbolic) argument.
    Explicit control-flow edges are kept in outEdges; the implicit
    fall-through successor/predecessor live in next/prev, each of which
    holds at most one block (enforced by the assertions in addNext).
    """
    # Global counter used to give every block a unique id (bid).
    _count = 0

    def __init__(self, label=''):
        self.insts = []
        self.outEdges = set()
        self.label = label
        self.bid = Block._count
        self.next = []
        self.prev = []
        Block._count = Block._count + 1

    def __repr__(self):
        if self.label:
            return "<block %s id=%d>" % (self.label, self.bid)
        else:
            return "<block id=%d>" % (self.bid)

    def __str__(self):
        insts = map(str, self.insts)
        return "<block %s %d:\n%s>" % (self.label, self.bid,
                                       '\n'.join(insts))

    def emit(self, inst):
        """Append the instruction tuple *inst* to this block."""
        self.insts.append(inst)

    def getInstructions(self):
        return self.insts

    def addOutEdge(self, block):
        """Record an explicit control-flow edge to *block*."""
        self.outEdges.add(block)

    def addNext(self, block):
        """Make *block* the implicit fall-through successor of this block."""
        self.next.append(block)
        assert len(self.next) == 1, map(str, self.next)
        block.prev.append(self)
        assert len(block.prev) == 1, map(str, block.prev)

    # Opcodes after which execution never continues with the following
    # instruction.
    _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS',
                        'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP',
                        )

    def has_unconditional_transfer(self):
        """Returns True if there is an unconditional transfer to an other block
        at the end of this block. This means there is no risk for the bytecode
        executer to go past this block's bytecode.

        Instructions may be 1-tuples (e.g. ('RETURN_VALUE',)) or 2-tuples,
        so only the opcode name is inspected.  An empty block returns False.
        """
        # Indexing (rather than unpacking `op, arg`) is required: argument-
        # less opcodes such as RETURN_VALUE are stored as 1-tuples and would
        # otherwise never be recognised.
        try:
            op = self.insts[-1][0]
        except IndexError:
            return False
        return op in self._uncond_transfer

    def get_children(self):
        """Return explicit out-edge targets followed by the next block."""
        return list(self.outEdges) + self.next

    def get_followers(self):
        """Get the whole list of followers, including the next block."""
        followers = set(self.next)
        # Blocks that must be emitted *after* this one, because of
        # bytecode offsets (e.g. relative jumps) pointing to them.
        for inst in self.insts:
            if inst[0] in PyFlowGraph.hasjrel:
                followers.add(inst[1])
        return followers

    def getContainedGraphs(self):
        """Return all graphs contained within this block.

        For example, a MAKE_FUNCTION block will contain a reference to
        the graph for the function body.
        """
        contained = []
        for inst in self.insts:
            if len(inst) == 1:
                continue
            op = inst[1]
            if hasattr(op, 'graph'):
                contained.append(op.graph)
        return contained
# A PyFlowGraph is transformed in place and is always in exactly one of
# these stages, which advance in the order listed.
RAW, FLAT, CONV, DONE = "RAW", "FLAT", "CONV", "DONE"
254 class PyFlowGraph(FlowGraph):
255 super_init = FlowGraph.__init__
257 def __init__(self, name, filename, args=(), optimized=0, klass=None):
258 self.super_init()
259 self.name = name
260 self.filename = filename
261 self.docstring = None
262 self.args = args # XXX
263 self.argcount = getArgCount(args)
264 self.klass = klass
265 if optimized:
266 self.flags = CO_OPTIMIZED | CO_NEWLOCALS
267 else:
268 self.flags = 0
269 self.consts = []
270 self.names = []
271 # Free variables found by the symbol table scan, including
272 # variables used only in nested scopes, are included here.
273 self.freevars = []
274 self.cellvars = []
275 # The closure list is used to track the order of cell
276 # variables and free variables in the resulting code object.
277 # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
278 # kinds of variables.
279 self.closure = []
280 self.varnames = list(args) or []
281 for i in range(len(self.varnames)):
282 var = self.varnames[i]
283 if isinstance(var, TupleArg):
284 self.varnames[i] = var.getName()
285 self.stage = RAW
287 def setDocstring(self, doc):
288 self.docstring = doc
290 def setFlag(self, flag):
291 self.flags = self.flags | flag
292 if flag == CO_VARARGS:
293 self.argcount = self.argcount - 1
295 def checkFlag(self, flag):
296 if self.flags & flag:
297 return 1
299 def setFreeVars(self, names):
300 self.freevars = list(names)
302 def setCellVars(self, names):
303 self.cellvars = names
305 def getCode(self):
306 """Get a Python code object"""
307 assert self.stage == RAW
308 self.computeStackDepth()
309 self.flattenGraph()
310 assert self.stage == FLAT
311 self.convertArgs()
312 assert self.stage == CONV
313 self.makeByteCode()
314 assert self.stage == DONE
315 return self.newCodeObject()
317 def dump(self, io=None):
318 if io:
319 save = sys.stdout
320 sys.stdout = io
321 pc = 0
322 for t in self.insts:
323 opname = t[0]
324 if opname == "SET_LINENO":
325 print
326 if len(t) == 1:
327 print "\t", "%3d" % pc, opname
328 pc = pc + 1
329 else:
330 print "\t", "%3d" % pc, opname, t[1]
331 pc = pc + 3
332 if io:
333 sys.stdout = save
335 def computeStackDepth(self):
336 """Compute the max stack depth.
338 Approach is to compute the stack effect of each basic block.
339 Then find the path through the code with the largest total
340 effect.
342 depth = {}
343 exit = None
344 for b in self.getBlocks():
345 depth[b] = findDepth(b.getInstructions())
347 seen = {}
349 def max_depth(b, d):
350 if b in seen:
351 return d
352 seen[b] = 1
353 d = d + depth[b]
354 children = b.get_children()
355 if children:
356 return max([max_depth(c, d) for c in children])
357 else:
358 if not b.label == "exit":
359 return max_depth(self.exit, d)
360 else:
361 return d
363 self.stacksize = max_depth(self.entry, 0)
365 def flattenGraph(self):
366 """Arrange the blocks in order and resolve jumps"""
367 assert self.stage == RAW
368 self.insts = insts = []
369 pc = 0
370 begin = {}
371 end = {}
372 for b in self.getBlocksInOrder():
373 begin[b] = pc
374 for inst in b.getInstructions():
375 insts.append(inst)
376 if len(inst) == 1:
377 pc = pc + 1
378 elif inst[0] != "SET_LINENO":
379 # arg takes 2 bytes
380 pc = pc + 3
381 end[b] = pc
382 pc = 0
383 for i in range(len(insts)):
384 inst = insts[i]
385 if len(inst) == 1:
386 pc = pc + 1
387 elif inst[0] != "SET_LINENO":
388 pc = pc + 3
389 opname = inst[0]
390 if opname in self.hasjrel:
391 oparg = inst[1]
392 offset = begin[oparg] - pc
393 insts[i] = opname, offset
394 elif opname in self.hasjabs:
395 insts[i] = opname, begin[inst[1]]
396 self.stage = FLAT
398 hasjrel = set()
399 for i in dis.hasjrel:
400 hasjrel.add(dis.opname[i])
401 hasjabs = set()
402 for i in dis.hasjabs:
403 hasjabs.add(dis.opname[i])
405 def convertArgs(self):
406 """Convert arguments from symbolic to concrete form"""
407 assert self.stage == FLAT
408 self.consts.insert(0, self.docstring)
409 self.sort_cellvars()
410 for i in range(len(self.insts)):
411 t = self.insts[i]
412 if len(t) == 2:
413 opname, oparg = t
414 conv = self._converters.get(opname, None)
415 if conv:
416 self.insts[i] = opname, conv(self, oparg)
417 self.stage = CONV
419 def sort_cellvars(self):
420 """Sort cellvars in the order of varnames and prune from freevars.
422 cells = {}
423 for name in self.cellvars:
424 cells[name] = 1
425 self.cellvars = [name for name in self.varnames
426 if name in cells]
427 for name in self.cellvars:
428 del cells[name]
429 self.cellvars = self.cellvars + cells.keys()
430 self.closure = self.cellvars + self.freevars
432 def _lookupName(self, name, list):
433 """Return index of name in list, appending if necessary
435 This routine uses a list instead of a dictionary, because a
436 dictionary can't store two different keys if the keys have the
437 same value but different types, e.g. 2 and 2L. The compiler
438 must treat these two separately, so it does an explicit type
439 comparison before comparing the values.
441 t = type(name)
442 for i in range(len(list)):
443 if t == type(list[i]) and list[i] == name:
444 return i
445 end = len(list)
446 list.append(name)
447 return end
449 _converters = {}
450 def _convert_LOAD_CONST(self, arg):
451 if hasattr(arg, 'getCode'):
452 arg = arg.getCode()
453 return self._lookupName(arg, self.consts)
455 def _convert_LOAD_FAST(self, arg):
456 self._lookupName(arg, self.names)
457 return self._lookupName(arg, self.varnames)
458 _convert_STORE_FAST = _convert_LOAD_FAST
459 _convert_DELETE_FAST = _convert_LOAD_FAST
461 def _convert_LOAD_NAME(self, arg):
462 if self.klass is None:
463 self._lookupName(arg, self.varnames)
464 return self._lookupName(arg, self.names)
466 def _convert_NAME(self, arg):
467 if self.klass is None:
468 self._lookupName(arg, self.varnames)
469 return self._lookupName(arg, self.names)
470 _convert_STORE_NAME = _convert_NAME
471 _convert_DELETE_NAME = _convert_NAME
472 _convert_IMPORT_NAME = _convert_NAME
473 _convert_IMPORT_FROM = _convert_NAME
474 _convert_STORE_ATTR = _convert_NAME
475 _convert_LOAD_ATTR = _convert_NAME
476 _convert_DELETE_ATTR = _convert_NAME
477 _convert_LOAD_GLOBAL = _convert_NAME
478 _convert_STORE_GLOBAL = _convert_NAME
479 _convert_DELETE_GLOBAL = _convert_NAME
481 def _convert_DEREF(self, arg):
482 self._lookupName(arg, self.names)
483 self._lookupName(arg, self.varnames)
484 return self._lookupName(arg, self.closure)
485 _convert_LOAD_DEREF = _convert_DEREF
486 _convert_STORE_DEREF = _convert_DEREF
488 def _convert_LOAD_CLOSURE(self, arg):
489 self._lookupName(arg, self.varnames)
490 return self._lookupName(arg, self.closure)
492 _cmp = list(dis.cmp_op)
493 def _convert_COMPARE_OP(self, arg):
494 return self._cmp.index(arg)
496 # similarly for other opcodes...
498 for name, obj in locals().items():
499 if name[:9] == "_convert_":
500 opname = name[9:]
501 _converters[opname] = obj
502 del name, obj, opname
504 def makeByteCode(self):
505 assert self.stage == CONV
506 self.lnotab = lnotab = LineAddrTable()
507 for t in self.insts:
508 opname = t[0]
509 if len(t) == 1:
510 lnotab.addCode(self.opnum[opname])
511 else:
512 oparg = t[1]
513 if opname == "SET_LINENO":
514 lnotab.nextLine(oparg)
515 continue
516 hi, lo = twobyte(oparg)
517 try:
518 lnotab.addCode(self.opnum[opname], lo, hi)
519 except ValueError:
520 print opname, oparg
521 print self.opnum[opname], lo, hi
522 raise
523 self.stage = DONE
525 opnum = {}
526 for num in range(len(dis.opname)):
527 opnum[dis.opname[num]] = num
528 del num
530 def newCodeObject(self):
531 assert self.stage == DONE
532 if (self.flags & CO_NEWLOCALS) == 0:
533 nlocals = 0
534 else:
535 nlocals = len(self.varnames)
536 argcount = self.argcount
537 if self.flags & CO_VARKEYWORDS:
538 argcount = argcount - 1
539 return types.CodeType(argcount, nlocals, self.stacksize, self.flags,
540 self.lnotab.getCode(), self.getConsts(),
541 tuple(self.names), tuple(self.varnames),
542 self.filename, self.name, self.lnotab.firstline,
543 self.lnotab.getTable(), tuple(self.freevars),
544 tuple(self.cellvars))
546 def getConsts(self):
547 """Return a tuple for the const slot of the code object
549 Must convert references to code (MAKE_FUNCTION) to code
550 objects recursively.
552 l = []
553 for elt in self.consts:
554 if isinstance(elt, PyFlowGraph):
555 elt = elt.getCode()
556 l.append(elt)
557 return tuple(l)
def isJump(opname):
    """Return 1 when *opname* starts with 'JUMP', otherwise None."""
    if opname[:4] != 'JUMP':
        return None
    return 1
class TupleArg:
    """Helper for marking func defs with nested tuples in arglist.

    count -- integer used to build the synthetic local name ".<count>"
    names -- the (possibly nested) names bound by unpacking the tuple
    """

    def __init__(self, count, names):
        self.count, self.names = count, names

    def getName(self):
        """Return the synthetic local-variable name for this argument."""
        return ".%d" % self.count

    def __repr__(self):
        return "TupleArg(%s, %s)" % (self.count, self.names)
def getArgCount(args):
    """Return len(args) minus the names unpacked by nested-tuple arguments.

    Each TupleArg in *args* reduces the count by the number of names in its
    flattened `names`.
    """
    unpacked = 0
    for arg in args:
        if isinstance(arg, TupleArg):
            unpacked = unpacked + len(misc.flatten(arg.names))
    return len(args) - unpacked
def twobyte(val):
    """Convert an int argument into high and low bytes.

    Returns (val // 256, val % 256); shifting and masking give the same
    floor-division / modulo results for every Python int.
    """
    assert isinstance(val, int)
    return val >> 8, val & 0xFF
class LineAddrTable:
    """Accumulate bytecode and build the line-number table (lnotab).

    The lnotab format is documented in compile.c.  In brief: for every
    SET_LINENO after the first, a pair of bytes is appended -- the number
    of bytecode bytes since the previously recorded SET_LINENO, then the
    increase in line number.  An increment larger than 255 is spread over
    several pairs (see compile.c for the delicate details).  A decrease in
    line number is not recorded, because lnotab bytes are unsigned.
    """

    def __init__(self):
        self.code = []          # bytecode, one single-character string per byte
        self.codeOffset = 0     # number of bytecode bytes emitted so far
        self.firstline = 0      # line of the first SET_LINENO (0 = none yet)
        self.lastline = 0       # last line number recorded in the table
        self.lastoff = 0        # codeOffset at the last recorded line
        self.lnotab = []        # (address delta, line delta) bytes as ints

    def addCode(self, *args):
        """Append the given byte values to the bytecode."""
        for byte in args:
            self.code.append(chr(byte))
        self.codeOffset = self.codeOffset + len(args)

    def nextLine(self, lineno):
        """Record that the code emitted from now on belongs to *lineno*."""
        if self.firstline == 0:
            self.firstline = lineno
            self.lastline = lineno
            return

        addr_delta = self.codeOffset - self.lastoff
        line_delta = lineno - self.lastline
        # Python assumes that lineno always increases with
        # increasing bytecode address (lnotab is unsigned char).
        # Depending on when SET_LINENO instructions are emitted
        # this is not always true.  Consider the code:
        #     a = (1,
        #          b)
        # In the bytecode stream, the assignment to "a" occurs
        # after the loading of "b".  This works with the C Python
        # compiler because it only generates a SET_LINENO instruction
        # for the assignment.
        if line_delta < 0:
            return

        put = self.lnotab.append
        # Large address jumps: emit (255, 0) pairs until the rest fits.
        while addr_delta > 255:
            put(255)
            put(0)
            addr_delta = addr_delta - 255
        # Large line jumps: the first pair carries the address delta,
        # later ones carry 0.
        while line_delta > 255:
            put(addr_delta)
            put(255)
            line_delta = line_delta - 255
            addr_delta = 0
        if addr_delta > 0 or line_delta > 0:
            put(addr_delta)
            put(line_delta)
        self.lastline = lineno
        self.lastoff = self.codeOffset

    def getCode(self):
        """Return the accumulated bytecode as a string."""
        return "".join(self.code)

    def getTable(self):
        """Return the lnotab as a string."""
        return "".join([chr(value) for value in self.lnotab])
653 class StackDepthTracker:
654 # XXX 1. need to keep track of stack depth on jumps
655 # XXX 2. at least partly as a result, this code is broken
657 def findDepth(self, insts, debug=0):
658 depth = 0
659 maxDepth = 0
660 for i in insts:
661 opname = i[0]
662 if debug:
663 print i,
664 delta = self.effect.get(opname, None)
665 if delta is not None:
666 depth = depth + delta
667 else:
668 # now check patterns
669 for pat, pat_delta in self.patterns:
670 if opname[:len(pat)] == pat:
671 delta = pat_delta
672 depth = depth + delta
673 break
674 # if we still haven't found a match
675 if delta is None:
676 meth = getattr(self, opname, None)
677 if meth is not None:
678 depth = depth + meth(i[1])
679 if depth > maxDepth:
680 maxDepth = depth
681 if debug:
682 print depth, maxDepth
683 return maxDepth
685 effect = {
686 'POP_TOP': -1,
687 'DUP_TOP': 1,
688 'LIST_APPEND': -2,
689 'SLICE+1': -1,
690 'SLICE+2': -1,
691 'SLICE+3': -2,
692 'STORE_SLICE+0': -1,
693 'STORE_SLICE+1': -2,
694 'STORE_SLICE+2': -2,
695 'STORE_SLICE+3': -3,
696 'DELETE_SLICE+0': -1,
697 'DELETE_SLICE+1': -2,
698 'DELETE_SLICE+2': -2,
699 'DELETE_SLICE+3': -3,
700 'STORE_SUBSCR': -3,
701 'DELETE_SUBSCR': -2,
702 # PRINT_EXPR?
703 'PRINT_ITEM': -1,
704 'RETURN_VALUE': -1,
705 'YIELD_VALUE': -1,
706 'EXEC_STMT': -3,
707 'BUILD_CLASS': -2,
708 'STORE_NAME': -1,
709 'STORE_ATTR': -2,
710 'DELETE_ATTR': -1,
711 'STORE_GLOBAL': -1,
712 'BUILD_MAP': 1,
713 'COMPARE_OP': -1,
714 'STORE_FAST': -1,
715 'IMPORT_STAR': -1,
716 'IMPORT_NAME': -1,
717 'IMPORT_FROM': 1,
718 'LOAD_ATTR': 0, # unlike other loads
719 # close enough...
720 'SETUP_EXCEPT': 3,
721 'SETUP_FINALLY': 3,
722 'FOR_ITER': 1,
723 'WITH_CLEANUP': -1,
725 # use pattern match
726 patterns = [
727 ('BINARY_', -1),
728 ('LOAD_', 1),
731 def UNPACK_SEQUENCE(self, count):
732 return count-1
733 def BUILD_TUPLE(self, count):
734 return -count+1
735 def BUILD_LIST(self, count):
736 return -count+1
737 def CALL_FUNCTION(self, argc):
738 hi, lo = divmod(argc, 256)
739 return -(lo + hi * 2)
740 def CALL_FUNCTION_VAR(self, argc):
741 return self.CALL_FUNCTION(argc)-1
742 def CALL_FUNCTION_KW(self, argc):
743 return self.CALL_FUNCTION(argc)-1
744 def CALL_FUNCTION_VAR_KW(self, argc):
745 return self.CALL_FUNCTION(argc)-2
746 def MAKE_FUNCTION(self, argc):
747 return -argc
748 def MAKE_CLOSURE(self, argc):
749 # XXX need to account for free variables too!
750 return -argc
751 def BUILD_SLICE(self, argc):
752 if argc == 2:
753 return -1
754 elif argc == 3:
755 return -2
756 def DUP_TOPX(self, argc):
757 return argc
759 findDepth = StackDepthTracker().findDepth