1 #!/usr/bin/env python2.5
2 # Copyright 2006 Google, Inc. All Rights Reserved.
3 # Licensed to PSF under a Contributor Agreement.
5 """Refactoring framework.
7 Used as a main program, this can refactor any number of files and/or
8 recursively descend down directories. Imported as a module, this
9 provides infrastructure to write your own refactoring tool.
12 __author__
= "Guido van Rossum <guido@python.org>"
21 from collections
import defaultdict
22 from itertools
import chain
25 from .pgen2
import driver
26 from .pgen2
import tokenize
36 Call without arguments to use sys.argv[1:] as the arguments; or
37 call with a list of arguments (excluding sys.argv[0]).
39 Returns a suggested exit status (0, 1, 2).
41 # Set up option parser
42 parser
= optparse
.OptionParser(usage
="refactor.py [options] file|dir ...")
43 parser
.add_option("-d", "--doctests_only", action
="store_true",
44 help="Fix up doctests only")
45 parser
.add_option("-f", "--fix", action
="append", default
=[],
46 help="Each FIX specifies a transformation; default all")
47 parser
.add_option("-l", "--list-fixes", action
="store_true",
48 help="List available transformations (fixes/fix_*.py)")
49 parser
.add_option("-p", "--print-function", action
="store_true",
50 help="Modify the grammar so that print() is a function")
51 parser
.add_option("-v", "--verbose", action
="store_true",
52 help="More verbose logging")
53 parser
.add_option("-w", "--write", action
="store_true",
54 help="Write back modified files")
56 # Parse command line arguments
57 options
, args
= parser
.parse_args(args
)
58 if options
.list_fixes
:
59 print "Available transformations for the -f/--fix option:"
60 for fixname
in get_all_fix_names():
65 print >>sys
.stderr
, "At least one file or directory argument required."
66 print >>sys
.stderr
, "Use --help to show usage."
69 # Set up logging handler
70 if sys
.version_info
< (2, 4):
71 hdlr
= logging
.StreamHandler()
72 fmt
= logging
.Formatter('%(name)s: %(message)s')
73 hdlr
.setFormatter(fmt
)
74 logging
.root
.addHandler(hdlr
)
76 logging
.basicConfig(format
='%(name)s: %(message)s', level
=logging
.INFO
)
78 # Initialize the refactoring tool
79 rt
= RefactoringTool(options
)
81 # Refactor all files and directories passed as arguments
83 rt
.refactor_args(args
)
86 # Return error status (0 if rt.errors is zero)
87 return int(bool(rt
.errors
))
def get_all_fix_names():
    """Return a sorted list of all available fix names.

    Every file in the directory holding the ``fixes`` package whose name
    looks like ``fix_<name>.py`` contributes ``<name>`` to the result.
    """
    fixer_dir = os.path.dirname(fixes.__file__)
    fix_names = []
    for name in os.listdir(fixer_dir):
        if name.startswith("fix_") and name.endswith(".py"):
            # Drop the leading "fix_" and the trailing ".py".
            fix_names.append(name[4:-3])
    # os.listdir() makes no ordering promise; sort to honour the contract.
    fix_names.sort()
    return fix_names
def get_head_types(pat):
    """Accept a pytree pattern and return the set of types that match first.

    - Node and leaf patterns contribute their own ``type``.
    - A negated pattern delegates to its content when it has any;
      otherwise it contributes ``None``.
    - A wildcard pattern contributes the head types of every sub-pattern
      of every alternative in its content.

    Raises Exception for any other kind of object.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        #   or a type and content -- so they don't get any farther
        # Always return leafs
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return get_head_types(pat.content)
        return set([None])  # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content
        head_types = set()
        for alternative in pat.content:
            for sub_pattern in alternative:
                head_types.update(get_head_types(sub_pattern))
        return head_types

    raise Exception("Oh no! I don't understand pattern %s" % (pat))
def get_headnode_dict(fixer_list):
    """Accept a list of fixers and return a dict of head node type --> fixer list.

    Fixers without a pattern are filed under the key ``None``.  Every other
    fixer is filed under each type reported by get_head_types() for its
    pattern.  The result is a defaultdict(list), so looking up a type that
    has no fixers yields an empty list.
    """
    head_nodes = defaultdict(list)
    for fixer in fixer_list:
        if not fixer.pattern:
            head_nodes[None].append(fixer)
            # Nothing to analyse; get_head_types() rejects a missing pattern.
            continue
        for node_type in get_head_types(fixer.pattern):
            head_nodes[node_type].append(fixer)
    return head_nodes
139 class RefactoringTool(object):
141 def __init__(self
, options
):
144 The argument is an optparse.Values instance.
146 self
.options
= options
148 self
.logger
= logging
.getLogger("RefactoringTool")
150 if self
.options
.print_function
:
151 del pygram
.python_grammar
.keywords
["print"]
152 self
.driver
= driver
.Driver(pygram
.python_grammar
,
153 convert
=pytree
.convert
,
155 self
.pre_order
, self
.post_order
= self
.get_fixers()
157 self
.pre_order
= get_headnode_dict(self
.pre_order
)
158 self
.post_order
= get_headnode_dict(self
.post_order
)
160 self
.files
= [] # List of files that were or should be modified
162 def get_fixers(self
):
163 """Inspects the options to load the requested patterns and handlers.
166 (pre_order, post_order), where pre_order is the list of fixers that
167 want a pre-order AST traversal, and post_order is the list that want
168 post-order traversal.
170 pre_order_fixers
= []
171 post_order_fixers
= []
172 fix_names
= self
.options
.fix
173 if not fix_names
or "all" in fix_names
:
174 fix_names
= get_all_fix_names()
175 for fix_name
in fix_names
:
177 mod
= __import__("lib2to3.fixes.fix_" + fix_name
, {}, {}, ["*"])
179 self
.log_error("Can't find transformation %s", fix_name
)
181 parts
= fix_name
.split("_")
182 class_name
= "Fix" + "".join([p
.title() for p
in parts
])
184 fix_class
= getattr(mod
, class_name
)
185 except AttributeError:
186 self
.log_error("Can't find fixes.fix_%s.%s",
187 fix_name
, class_name
)
190 fixer
= fix_class(self
.options
, self
.fixer_log
)
191 except Exception, err
:
192 self
.log_error("Can't instantiate fixes.fix_%s.%s()",
193 fix_name
, class_name
, exc_info
=True)
195 if fixer
.explicit
and fix_name
not in self
.options
.fix
:
196 self
.log_message("Skipping implicit fixer: %s", fix_name
)
199 if self
.options
.verbose
:
200 self
.log_message("Adding transformation: %s", fix_name
)
201 if fixer
.order
== "pre":
202 pre_order_fixers
.append(fixer
)
203 elif fixer
.order
== "post":
204 post_order_fixers
.append(fixer
)
206 raise ValueError("Illegal fixer order: %r" % fixer
.order
)
208 pre_order_fixers
.sort(key
=lambda x
: x
.run_order
)
209 post_order_fixers
.sort(key
=lambda x
: x
.run_order
)
210 return (pre_order_fixers
, post_order_fixers
)
def log_error(self, msg, *args, **kwds):
    """Record an error and log it.

    The ``(msg, args, kwds)`` triple is appended to ``self.errors`` and the
    same arguments are then passed on to ``self.logger.error()``.
    """
    self.errors.append((msg, args, kwds))
    self.logger.error(msg, *args, **kwds)
def log_message(self, msg, *args, **kwds):
    """Hook to log a message.

    When positional arguments are given, ``msg`` is treated as a %-format
    string and they are interpolated into it; a message passed alone is
    logged verbatim.  Keyword arguments are forwarded to
    ``self.logger.info()``, so that error records replayed with their
    original keywords are accepted.
    """
    if args:
        msg = msg % args
    self.logger.info(msg, **kwds)
223 def refactor_args(self
, args
):
224 """Refactors files and directories from an argument list."""
227 self
.refactor_stdin()
228 elif os
.path
.isdir(arg
):
229 self
.refactor_dir(arg
)
231 self
.refactor_file(arg
)
def refactor_dir(self, arg):
    """Descend down a directory and refactor every Python file found.

    Python files are assumed to have a .py extension.

    Files and subdirectories starting with '.' are skipped.
    """
    for dirpath, dirnames, filenames in os.walk(arg):
        if self.options.verbose:
            self.log_message("Descending into %s", dirpath)
        # Sort so files are visited in the same order on every run,
        # whatever order the directory listing happens to come back in.
        dirnames.sort()
        filenames.sort()
        for name in filenames:
            # Test for the full ".py" suffix: a bare "py" would also pick
            # up extension-less names such as "happy".
            if not name.startswith(".") and name.endswith(".py"):
                fullname = os.path.join(dirpath, name)
                self.refactor_file(fullname)
        # Modify dirnames in-place to remove subdirs with leading dots
        dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
252 def refactor_file(self
, filename
):
253 """Refactors a file."""
257 self
.log_error("Can't open %s: %s", filename
, err
)
260 input = f
.read() + "\n" # Silence certain parse errors
263 if self
.options
.doctests_only
:
264 if self
.options
.verbose
:
265 self
.log_message("Refactoring doctests in %s", filename
)
266 output
= self
.refactor_docstring(input, filename
)
268 self
.write_file(output
, filename
, input)
269 elif self
.options
.verbose
:
270 self
.log_message("No doctest changes in %s", filename
)
272 tree
= self
.refactor_string(input, filename
)
273 if tree
and tree
.was_changed
:
274 # The [:-1] is to take off the \n we added earlier
275 self
.write_file(str(tree
)[:-1], filename
)
276 elif self
.options
.verbose
:
277 self
.log_message("No changes in %s", filename
)
def refactor_string(self, data, name):
    """Refactor a given input string.

    Args:
        data: a string holding the code to be refactored.
        name: a human-readable name for use in error/log messages.

    Returns:
        An AST corresponding to the refactored input stream; None if
        there were errors during the parse.
    """
    try:
        # NOTE(review): the second positional argument is passed through
        # unchanged; presumably the driver's debug flag -- confirm.
        tree = self.driver.parse_string(data, 1)
    except Exception:
        # sys.exc_info() instead of "except ..., err" keeps this block
        # parseable by Python 2.5 as well as by newer interpreters.
        err = sys.exc_info()[1]
        self.log_error("Can't parse %s: %s: %s",
                       name, err.__class__.__name__, err)
        # No tree to work on: report the failure to the caller.
        return None
    if self.options.verbose:
        self.log_message("Refactoring %s", name)
    self.refactor_tree(tree, name)
    return tree
301 def refactor_stdin(self
):
302 if self
.options
.write
:
303 self
.log_error("Can't write changes back to stdin")
305 input = sys
.stdin
.read()
306 if self
.options
.doctests_only
:
307 if self
.options
.verbose
:
308 self
.log_message("Refactoring doctests in stdin")
309 output
= self
.refactor_docstring(input, "<stdin>")
311 self
.write_file(output
, "<stdin>", input)
312 elif self
.options
.verbose
:
313 self
.log_message("No doctest changes in stdin")
315 tree
= self
.refactor_string(input, "<stdin>")
316 if tree
and tree
.was_changed
:
317 self
.write_file(str(tree
), "<stdin>", input)
318 elif self
.options
.verbose
:
319 self
.log_message("No changes in stdin")
def refactor_tree(self, tree, name):
    """Refactor a parse tree (modifying the tree in place).

    Every fixer gets start_tree() before the traversals and finish_tree()
    after them.

    Args:
        tree: a pytree.Node instance representing the root of the tree
        name: a human-readable name for this tree.

    Returns:
        True if the tree was modified, False otherwise.
    """
    # Two calls to chain are required because pre_order.values()
    # will be a list of lists of fixers:
    # [[<fixer ...>, <fixer ...>], [<fixer ...>]]
    # The result is materialised as a list because it is walked twice:
    # a bare chain iterator is exhausted by the first loop, which would
    # leave finish_tree() uncalled.
    all_fixers = list(chain(chain(*self.pre_order.values()),
                            chain(*self.post_order.values())))
    for fixer in all_fixers:
        fixer.start_tree(tree, name)

    self.traverse_by(self.pre_order, tree.pre_order())
    self.traverse_by(self.post_order, tree.post_order())

    for fixer in all_fixers:
        fixer.finish_tree(tree, name)
    return tree.was_changed
347 def traverse_by(self
, fixers
, traversal
):
348 """Traverse an AST, applying a set of fixers to each node.
350 This is a helper method for refactor_tree().
353 fixers: a list of fixer instances.
354 traversal: a generator that yields AST nodes.
361 for node
in traversal
:
362 for fixer
in fixers
[node
.type] + fixers
[None]:
363 results
= fixer
.match(node
)
365 new
= fixer
.transform(node
, results
)
366 if new
is not None and (new
!= node
or
367 str(new
) != str(node
)):
371 def write_file(self
, new_text
, filename
, old_text
=None):
372 """Writes a string to a file.
374 If there are no changes, this is a no-op.
376 Otherwise, it first shows a unified diff between the old text
377 and the new text, and then rewrites the file; the latter is
378 only done if the write option is set.
380 self
.files
.append(filename
)
383 f
= open(filename
, "r")
385 self
.log_error("Can't read %s: %s", filename
, err
)
391 if old_text
== new_text
:
392 if self
.options
.verbose
:
393 self
.log_message("No changes to %s", filename
)
395 diff_texts(old_text
, new_text
, filename
)
396 if not self
.options
.write
:
397 if self
.options
.verbose
:
398 self
.log_message("Not writing changes to %s", filename
)
400 backup
= filename
+ ".bak"
401 if os
.path
.lexists(backup
):
404 except os
.error
, err
:
405 self
.log_message("Can't remove backup %s", backup
)
407 os
.rename(filename
, backup
)
408 except os
.error
, err
:
409 self
.log_message("Can't rename %s to %s", filename
, backup
)
411 f
= open(filename
, "w")
412 except os
.error
, err
:
413 self
.log_error("Can't create %s: %s", filename
, err
)
418 except os
.error
, err
:
419 self
.log_error("Can't write %s: %s", filename
, err
)
422 if self
.options
.verbose
:
423 self
.log_message("Wrote changes to %s", filename
)
428 def refactor_docstring(self
, input, filename
):
429 """Refactors a docstring, looking for doctests.
431 This returns a modified version of the input string. It looks
432 for doctests, which start with a ">>>" prompt, and may be
433 continued with "..." prompts, as long as the "..." is indented
434 the same as the ">>>".
436 (Unfortunately we can't use the doctest module's parser,
437 since, like most parsers, it is not geared towards preserving
438 the original source.)
445 for line
in input.splitlines(True):
447 if line
.lstrip().startswith(self
.PS1
):
448 if block
is not None:
449 result
.extend(self
.refactor_doctest(block
, block_lineno
,
451 block_lineno
= lineno
453 i
= line
.find(self
.PS1
)
455 elif (indent
is not None and
456 (line
.startswith(indent
+ self
.PS2
) or
457 line
== indent
+ self
.PS2
.rstrip() + "\n")):
460 if block
is not None:
461 result
.extend(self
.refactor_doctest(block
, block_lineno
,
466 if block
is not None:
467 result
.extend(self
.refactor_doctest(block
, block_lineno
,
469 return "".join(result
)
471 def refactor_doctest(self
, block
, lineno
, indent
, filename
):
472 """Refactors one doctest.
474 A doctest is given as a block of lines, the first of which starts
475 with ">>>" (possibly indented), while the remaining lines start
476 with "..." (identically indented).
480 tree
= self
.parse_block(block
, lineno
, indent
)
481 except Exception, err
:
482 if self
.options
.verbose
:
484 self
.log_message("Source: %s", line
.rstrip("\n"))
485 self
.log_error("Can't parse docstring in %s line %s: %s: %s",
486 filename
, lineno
, err
.__class
__.__name
__, err
)
488 if self
.refactor_tree(tree
, filename
):
489 new
= str(tree
).splitlines(True)
490 # Undo the adjustment of the line numbers in wrap_toks() below.
491 clipped
, new
= new
[:lineno
-1], new
[lineno
-1:]
492 assert clipped
== ["\n"] * (lineno
-1), clipped
493 if not new
[-1].endswith("\n"):
495 block
= [indent
+ self
.PS1
+ new
.pop(0)]
497 block
+= [indent
+ self
.PS2
+ line
for line
in new
]
501 if self
.options
.write
:
506 self
.log_message("No files %s modified.", were
)
508 self
.log_message("Files that %s modified:", were
)
509 for file in self
.files
:
510 self
.log_message(file)
512 self
.log_message("Warnings/messages while refactoring:")
513 for message
in self
.fixer_log
:
514 self
.log_message(message
)
516 if len(self
.errors
) == 1:
517 self
.log_message("There was 1 error:")
519 self
.log_message("There were %d errors:", len(self
.errors
))
520 for msg
, args
, kwds
in self
.errors
:
521 self
.log_message(msg
, *args
, **kwds
)
def parse_block(self, block, lineno, indent):
    """Parse a block into a tree.

    The token stream produced by wrap_toks() is handed to the driver rather
    than the raw text.  This is necessary to get correct line number /
    offset information in the parser diagnostics and embedded into the
    parse tree.
    """
    tokens = self.wrap_toks(block, lineno, indent)
    return self.driver.parse_tokens(tokens)
def wrap_toks(self, block, lineno, indent):
    """Wrap a tokenize stream to systematically modify start/end.

    Tokens are produced from the lines yielded by gen_lines().  The row of
    each token's start and end position is shifted by ``lineno - 1``, so a
    block whose first line is number ``lineno`` is reported at that line.
    """
    tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next)
    for tok_type, value, (line0, col0), (line1, col1), line_text in tokens:
        line0 += lineno - 1
        line1 += lineno - 1
        # Don't bother updating the columns; this is too complicated
        # since line_text would also have to be updated and it would
        # still break for tokens spanning lines.  Let the user guess
        # that the column numbers for doctests are relative to the
        # end of the prompt string (PS1 or PS2).
        yield tok_type, value, (line0, col0), (line1, col1), line_text
545 def gen_lines(self
, block
, indent
):
546 """Generates lines as expected by tokenize from a list of lines.
548 This strips the first len(indent + self.PS1) characters off each line.
550 prefix1
= indent
+ self
.PS1
551 prefix2
= indent
+ self
.PS2
554 if line
.startswith(prefix
):
555 yield line
[len(prefix
):]
556 elif line
== prefix
.rstrip() + "\n":
559 raise AssertionError("line=%r, prefix=%r" % (line
, prefix
))
565 def diff_texts(a
, b
, filename
):
566 """Prints a unified diff of two strings."""
569 for line
in difflib
.unified_diff(a
, b
, filename
, filename
,
570 "(original)", "(refactored)",
575 if __name__
== "__main__":