Tweak the comments and formatting.
[python.git] / Lib / lib2to3 / refactor.py
blob6057caf79a468dfcd2bc6f241cb36c1e03e7889a
1 #!/usr/bin/env python2.5
2 # Copyright 2006 Google, Inc. All Rights Reserved.
3 # Licensed to PSF under a Contributor Agreement.
5 """Refactoring framework.
7 Used as a main program, this can refactor any number of files and/or
8 recursively descend down directories. Imported as a module, this
9 provides infrastructure to write your own refactoring tool.
10 """
12 __author__ = "Guido van Rossum <guido@python.org>"
15 # Python imports
16 import os
17 import sys
18 import difflib
19 import optparse
20 import logging
21 from collections import defaultdict
22 from itertools import chain
24 # Local imports
25 from .pgen2 import driver
26 from .pgen2 import tokenize
28 from . import pytree
29 from . import patcomp
30 from . import fixes
31 from . import pygram
def main(args=None):
    """Main program.

    Call without arguments to use sys.argv[1:] as the arguments; or
    call with a list of arguments (excluding sys.argv[0]).

    Args:
        args: optional list of command line arguments; None makes optparse
            fall back to sys.argv[1:].

    Returns:
        A suggested exit status: 0 on success (or when only --list-fixes
        was requested), 1 if any error was recorded while refactoring,
        2 if no file or directory argument was given.
    """
    # Set up option parser
    parser = optparse.OptionParser(usage="refactor.py [options] file|dir ...")
    parser.add_option("-d", "--doctests_only", action="store_true",
                      help="Fix up doctests only")
    parser.add_option("-f", "--fix", action="append", default=[],
                      help="Each FIX specifies a transformation; default all")
    parser.add_option("-l", "--list-fixes", action="store_true",
                      help="List available transformations (fixes/fix_*.py)")
    parser.add_option("-p", "--print-function", action="store_true",
                      help="Modify the grammar so that print() is a function")
    parser.add_option("-v", "--verbose", action="store_true",
                      help="More verbose logging")
    parser.add_option("-w", "--write", action="store_true",
                      help="Write back modified files")

    # Parse command line arguments
    options, args = parser.parse_args(args)
    if options.list_fixes:
        sys.stdout.write(
            "Available transformations for the -f/--fix option:\n")
        for fixname in get_all_fix_names():
            sys.stdout.write(fixname + "\n")
        # Listing the fixes is a complete request on its own.
        if not args:
            return 0
    if not args:
        sys.stderr.write(
            "At least one file or directory argument required.\n")
        sys.stderr.write("Use --help to show usage.\n")
        return 2

    # Set up logging handler.  This module cannot even be imported before
    # Python 2.5 (it uses explicit relative imports), so basicConfig() with
    # keyword arguments is always available and no fallback is needed.
    logging.basicConfig(format='%(name)s: %(message)s', level=logging.INFO)

    # Initialize the refactoring tool
    rt = RefactoringTool(options)

    # Refactor all files and directories passed as arguments, unless
    # setting up the tool already recorded errors.
    if not rt.errors:
        rt.refactor_args(args)
        rt.summarize()

    # Return error status (0 if rt.errors is empty)
    return int(bool(rt.errors))
def get_all_fix_names():
    """Return a sorted list of all available fix names.

    A fix name is the part of a "fix_<name>.py" file name between the
    "fix_" prefix and the ".py" suffix, taken from the directory that
    contains the fixes package.
    """
    fixes_dir = os.path.dirname(fixes.__file__)
    return sorted(entry[4:-3]
                  for entry in os.listdir(fixes_dir)
                  if entry.startswith("fix_") and entry.endswith(".py"))
def get_head_types(pat):
    """Return the set of pattern types which will match first.

    Accepts a pytree pattern object (NodePattern, LeafPattern,
    NegatedPattern or WildcardPattern).  A None member in the result
    stands for "no specific type".  Any other kind of object raises
    Exception.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther.
        # Always return leafs.
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        if not pat.content:
            return set([None])  # Negated Patterns don't have a type
        return get_head_types(pat.content)

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node of each alternative in content
        heads = set()
        for alternative in pat.content:
            for sub_pattern in alternative:
                heads |= get_head_types(sub_pattern)
        return heads

    raise Exception("Oh no! I don't understand pattern %s" %(pat))
def get_headnode_dict(fixer_list):
    """Group fixers by the node types their patterns can start with.

    Accepts a list of fixers and returns a defaultdict(list) mapping
    head node type --> list of fixers.  Fixers without a pattern are
    filed under the key None.
    """
    by_head = defaultdict(list)
    for fixer in fixer_list:
        if fixer.pattern:
            head_types = get_head_types(fixer.pattern)
        else:
            head_types = (None,)
        for head in head_types:
            by_head[head].append(fixer)
    return by_head
class RefactoringTool(object):
    """Applies a set of fixers to files, directories, stdin or strings.

    Instance state:
        options: the optparse.Values instance passed to the constructor.
        errors: list of (msg, args, kwds) tuples recorded by log_error().
        fixer_log: list of messages handed to every fixer for its own use.
        pre_order, post_order: mappings of head node type --> fixer list.
        files: names of the files that were or should be modified.

    Exceptions are fetched with sys.exc_info() rather than bound in the
    except clause so that this code parses on both Python 2.5 and 3.
    """

    def __init__(self, options):
        """Initializer.

        The argument is an optparse.Values instance.
        """
        self.options = options
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        if self.options.print_function:
            # The grammar is module-level state shared by all instances;
            # an earlier instance may already have removed the keyword.
            keywords = pygram.python_grammar.keywords
            if "print" in keywords:
                del keywords["print"]
        self.driver = driver.Driver(pygram.python_grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()

        self.pre_order = get_headnode_dict(self.pre_order)
        self.post_order = get_headnode_dict(self.post_order)

        self.files = []  # List of files that were or should be modified

    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list that want
          post-order traversal.  Both lists are sorted by run_order.

        Raises:
          ValueError: if a fixer declares an order other than "pre" or "post".
        """
        pre_order_fixers = []
        post_order_fixers = []
        fix_names = self.options.fix
        if not fix_names or "all" in fix_names:
            fix_names = get_all_fix_names()
        for fix_name in fix_names:
            try:
                mod = __import__("lib2to3.fixes.fix_" + fix_name, {}, {}, ["*"])
            except ImportError:
                self.log_error("Can't find transformation %s", fix_name)
                continue
            # fix_foo_bar.py is expected to define class FixFooBar.
            parts = fix_name.split("_")
            class_name = "Fix" + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                self.log_error("Can't find fixes.fix_%s.%s",
                               fix_name, class_name)
                continue
            try:
                fixer = fix_class(self.options, self.fixer_log)
            except Exception:
                self.log_error("Can't instantiate fixes.fix_%s.%s()",
                               fix_name, class_name, exc_info=True)
                continue
            # Explicit fixers only run when named on the command line.
            if fixer.explicit and fix_name not in self.options.fix:
                self.log_message("Skipping implicit fixer: %s", fix_name)
                continue

            if self.options.verbose:
                self.log_message("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise ValueError("Illegal fixer order: %r" % fixer.order)

        pre_order_fixers.sort(key=lambda x: x.run_order)
        post_order_fixers.sort(key=lambda x: x.run_order)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Records an error and logs it.

        The (msg, args, kwds) triple is kept in self.errors so that
        summarize() can repeat the message later.
        """
        self.errors.append((msg, args, kwds))
        self.logger.error(msg, *args, **kwds)

    def log_message(self, msg, *args):
        """Hook to log a message.

        msg is %-formatted with args only when args are given.
        """
        if args:
            msg = msg % args
        self.logger.info(msg)

    def refactor_args(self, args):
        """Refactors files and directories from an argument list.

        The argument "-" stands for standard input.
        """
        for arg in args:
            if arg == "-":
                self.refactor_stdin()
            elif os.path.isdir(arg):
                self.refactor_dir(arg)
            else:
                self.refactor_file(arg)

    def refactor_dir(self, arg):
        """Descends down a directory and refactor every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        for dirpath, dirnames, filenames in os.walk(arg):
            if self.options.verbose:
                self.log_message("Descending into %s", dirpath)
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                # Require the dot: a bare "py" suffix would also match
                # extension-less names that merely end in those letters.
                if not name.startswith(".") and name.endswith(".py"):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]

    def refactor_file(self, filename):
        """Refactors a file.

        Read failures are recorded with log_error(); changes are handed
        to write_file().
        """
        try:
            f = open(filename)
        except IOError:
            err = sys.exc_info()[1]
            self.log_error("Can't open %s: %s", filename, err)
            return
        try:
            source = f.read() + "\n"  # Silence certain parse errors
        finally:
            f.close()
        if self.options.doctests_only:
            if self.options.verbose:
                self.log_message("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(source, filename)
            if output != source:
                # The [:-1] takes off the \n we added earlier, so the file
                # does not grow by one newline on every rewrite.
                self.write_file(output[:-1], filename, source[:-1])
            elif self.options.verbose:
                self.log_message("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(source, filename)
            if tree and tree.was_changed:
                # The [:-1] is to take off the \n we added earlier
                self.write_file(str(tree)[:-1], filename)
            elif self.options.verbose:
                self.log_message("No changes in %s", filename)

    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        try:
            tree = self.driver.parse_string(data, 1)
        except Exception:
            err = sys.exc_info()[1]
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        if self.options.verbose:
            self.log_message("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree

    def refactor_stdin(self):
        """Refactors the text read from standard input.

        Combining this with the write option is recorded as an error,
        since there is no file to write the changes back to.
        """
        if self.options.write:
            self.log_error("Can't write changes back to stdin")
            return
        source = sys.stdin.read()
        if self.options.doctests_only:
            if self.options.verbose:
                self.log_message("Refactoring doctests in stdin")
            output = self.refactor_docstring(source, "<stdin>")
            if output != source:
                self.write_file(output, "<stdin>", source)
            elif self.options.verbose:
                self.log_message("No doctest changes in stdin")
        else:
            tree = self.refactor_string(source, "<stdin>")
            if tree and tree.was_changed:
                self.write_file(str(tree), "<stdin>", source)
            elif self.options.verbose:
                self.log_message("No changes in stdin")

    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        Every fixer gets one start_tree() call before the traversals and
        one finish_tree() call after them.

        Args:
            tree: a pytree.Node instance representing the root of the tree
                  to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """
        # pre_order/post_order map head type --> list of fixers, so the
        # values are lists of lists and a fixer registered under several
        # head types shows up more than once.  Build a real list (a bare
        # iterator would be exhausted by the first loop below) holding
        # each fixer a single time.
        all_fixers = []
        seen = set()
        for group in chain(self.pre_order.values(),
                           self.post_order.values()):
            for fixer in group:
                if id(fixer) not in seen:
                    seen.add(id(fixer))
                    all_fixers.append(fixer)

        for fixer in all_fixers:
            fixer.start_tree(tree, name)

        self.traverse_by(self.pre_order, tree.pre_order())
        self.traverse_by(self.post_order, tree.post_order())

        for fixer in all_fixers:
            fixer.finish_tree(tree, name)
        return tree.was_changed

    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a mapping of node type --> list of fixer instances, as
                built by get_headnode_dict(); the key None holds the
                fixers that are tried on every node.
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            for fixer in fixers[node.type] + fixers[None]:
                results = fixer.match(node)
                if results:
                    new = fixer.transform(node, results)
                    if new is not None and (new != node or
                                            str(new) != str(node)):
                        node.replace(new)
                        # Later fixers in this list see the replacement.
                        node = new

    def write_file(self, new_text, filename, old_text=None):
        """Writes a string to a file.

        If there are no changes, nothing is shown or written.

        Otherwise, it first shows a unified diff between the old text
        and the new text, and then rewrites the file; the latter is
        only done if the write option is set.  Before rewriting, the
        existing file is renamed to filename + ".bak" on a best-effort
        basis.

        Args:
            new_text: the refactored text.
            filename: the file to rewrite; always recorded in self.files.
            old_text: the original text; read from filename when None.
        """
        self.files.append(filename)
        if old_text is None:
            try:
                f = open(filename, "r")
            except IOError:
                err = sys.exc_info()[1]
                self.log_error("Can't read %s: %s", filename, err)
                return
            try:
                old_text = f.read()
            finally:
                f.close()
        if old_text == new_text:
            if self.options.verbose:
                self.log_message("No changes to %s", filename)
            return
        diff_texts(old_text, new_text, filename)
        if not self.options.write:
            if self.options.verbose:
                self.log_message("Not writing changes to %s", filename)
            return
        # Backup handling is best effort: failures are reported but do
        # not prevent the rewrite.
        backup = filename + ".bak"
        if os.path.lexists(backup):
            try:
                os.remove(backup)
            except os.error:
                self.log_message("Can't remove backup %s", backup)
        try:
            os.rename(filename, backup)
        except os.error:
            self.log_message("Can't rename %s to %s", filename, backup)
        # open() and write() raise IOError on Python 2, which os.error
        # (OSError) does not cover; EnvironmentError catches both.
        try:
            f = open(filename, "w")
        except EnvironmentError:
            err = sys.exc_info()[1]
            self.log_error("Can't create %s: %s", filename, err)
            return
        try:
            try:
                f.write(new_text)
            except EnvironmentError:
                err = sys.exc_info()[1]
                self.log_error("Can't write %s: %s", filename, err)
                return
        finally:
            f.close()
        if self.options.verbose:
            self.log_message("Wrote changes to %s", filename)

    PS1 = ">>> "
    PS2 = "... "

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None         # lines of the doctest being collected, if any
        block_lineno = None  # 1-based line number of the block's first line
        indent = None        # whitespace in front of the block's ">>>"
        lineno = 0
        for line in input.splitlines(True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                # A new prompt ends any block collected so far.
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                # Continuation line, possibly a bare "..." without the
                # trailing space.
                block.append(line)
            else:
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        Returns the list of lines to use in place of block; this is the
        block itself when it cannot be parsed or needs no changes.
        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception:
            err = sys.exc_info()[1]
            if self.options.verbose:
                for line in block:
                    self.log_message("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        """Logs the modified files, the fixer messages and the errors."""
        if self.options.write:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for filename in self.files:
                self.log_message(filename)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                # kwds holds logging keywords such as exc_info that were
                # consumed when the error was first logged; log_message()
                # does not accept keywords, so they are not replayed.
                self.log_message(msg, *args)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).next)
        for tok_type, value, (line0, col0), (line1, col1), line_text in tokens:
            # Shift line numbers so the block's first line is lineno.
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield tok_type, value, (line0, col0), (line1, col1), line_text

    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.

        Raises:
            AssertionError: if a line does not start with the expected
                prompt (PS1 for the first line, PS2 for the others).
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + "\n":
                # Bare prompt without trailing space: an empty source line.
                yield "\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            prefix = prefix2
        # Keep answering with "" (end of input) for as long as asked.
        while True:
            yield ""
def diff_texts(a, b, filename):
    """Prints a unified diff of two strings.

    The diff is written to standard output, one line at a time, with
    filename used for both sides and the labels "(original)" and
    "(refactored)" in the header.
    """
    old_lines = a.splitlines()
    new_lines = b.splitlines()
    diff = difflib.unified_diff(old_lines, new_lines, filename, filename,
                                "(original)", "(refactored)",
                                lineterm="")
    for line in diff:
        sys.stdout.write(line + "\n")
# When run as a script, exit with the status suggested by main().
if __name__ == "__main__":
    sys.exit(main())