1 # Copyright 2006 Google, Inc. All Rights Reserved.
2 # Licensed to PSF under a Contributor Agreement.
"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""
11 __author__
= "Guido van Rossum <guido@python.org>"
import codecs
import difflib
import logging
import operator
import os
import sys
from collections import defaultdict
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize
from . import pytree, pygram
def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package.

    Args:
        fixer_pkg: dotted name of a package whose directory contains
            fix_*.py fixer modules.
        remove_prefix: if True, strip the leading "fix_" from each name.

    Returns:
        A list of fixer names (module basenames, without the .py suffix),
        in sorted directory order.
    """
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fixer_dir = os.path.dirname(pkg.__file__)
    fix_names = []
    for name in sorted(os.listdir(fixer_dir)):
        if name.startswith("fix_") and name.endswith(".py"):
            if remove_prefix:
                name = name[4:]  # drop "fix_"
            fix_names.append(name[:-3])  # drop ".py"
    return fix_names
def get_head_types(pat):
    """Accepts a pytree Pattern Node and returns a set
    of the pattern types which will match first.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther.
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        # Guard needed: without it the set([None]) below is unreachable.
        if pat.content:
            return get_head_types(pat.content)
        return set([None])  # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content.  NOTE(review): content is assumed
        # to be an iterable of alternatives, each itself iterable -- confirm
        # against pytree.WildcardPattern.
        r = set()
        for alternative in pat.content:
            for sub_pat in alternative:
                r.update(get_head_types(sub_pat))
        return r

    raise Exception("Oh no! I don't understand pattern %s" % (pat))
def get_headnode_dict(fixer_list):
    """Accepts a list of fixers and returns a dictionary
    of head node type --> fixer list.
    """
    head_nodes = defaultdict(list)
    for fixer in fixer_list:
        if fixer.pattern is None:
            # No pattern: the fixer must be offered every node, so file it
            # under the catch-all None key only.
            head_nodes[None].append(fixer)
            continue
        for t in get_head_types(fixer.pattern):
            head_nodes[t].append(fixer)
    return head_nodes
def get_fixers_from_package(pkg_name):
    """Return the fully qualified names for fixers in the package pkg_name."""
    # Pass remove_prefix=False so the "fix_" module prefix survives and the
    # result is a valid dotted import path.
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]
87 if sys
.version_info
< (3, 0):
89 _open_with_encoding
= codecs
.open
90 # codecs.open doesn't translate newlines sadly.
91 def _from_system_newlines(input):
92 return input.replace("\r\n", "\n")
93 def _to_system_newlines(input):
94 if os
.linesep
!= "\n":
95 return input.replace("\n", os
.linesep
)
99 _open_with_encoding
= open
100 _from_system_newlines
= _identity
101 _to_system_newlines
= _identity
# Raised by RefactoringTool.get_fixers() when a fixer module or class
# cannot be imported or located.
class FixerError(Exception):
    """A fixer could not be loaded."""
class RefactoringTool(object):
    """Refactors files, strings, and parse trees using a set of fixers."""

    # Default configuration; instances copy and update this in __init__.
    _default_options = {"print_function": False}

    CLASS_PREFIX = "Fix"   # The prefix for fixer classes
    FILE_PREFIX = "fix_"   # The prefix for modules with a fixer within
115 def __init__(self
, fixer_names
, options
=None, explicit
=None):
119 fixer_names: a list of fixers to import
120 options: an dict with configuration.
121 explicit: a list of fixers to run even if they are explicit.
123 self
.fixers
= fixer_names
124 self
.explicit
= explicit
or []
125 self
.options
= self
._default
_options
.copy()
126 if options
is not None:
127 self
.options
.update(options
)
129 self
.logger
= logging
.getLogger("RefactoringTool")
132 if self
.options
["print_function"]:
133 del pygram
.python_grammar
.keywords
["print"]
134 self
.driver
= driver
.Driver(pygram
.python_grammar
,
135 convert
=pytree
.convert
,
137 self
.pre_order
, self
.post_order
= self
.get_fixers()
139 self
.pre_order_heads
= get_headnode_dict(self
.pre_order
)
140 self
.post_order_heads
= get_headnode_dict(self
.post_order
)
142 self
.files
= [] # List of files that were or should be modified
144 def get_fixers(self
):
145 """Inspects the options to load the requested patterns and handlers.
148 (pre_order, post_order), where pre_order is the list of fixers that
149 want a pre-order AST traversal, and post_order is the list that want
150 post-order traversal.
152 pre_order_fixers
= []
153 post_order_fixers
= []
154 for fix_mod_path
in self
.fixers
:
155 mod
= __import__(fix_mod_path
, {}, {}, ["*"])
156 fix_name
= fix_mod_path
.rsplit(".", 1)[-1]
157 if fix_name
.startswith(self
.FILE_PREFIX
):
158 fix_name
= fix_name
[len(self
.FILE_PREFIX
):]
159 parts
= fix_name
.split("_")
160 class_name
= self
.CLASS_PREFIX
+ "".join([p
.title() for p
in parts
])
162 fix_class
= getattr(mod
, class_name
)
163 except AttributeError:
164 raise FixerError("Can't find %s.%s" % (fix_name
, class_name
))
165 fixer
= fix_class(self
.options
, self
.fixer_log
)
166 if fixer
.explicit
and self
.explicit
is not True and \
167 fix_mod_path
not in self
.explicit
:
168 self
.log_message("Skipping implicit fixer: %s", fix_name
)
171 self
.log_debug("Adding transformation: %s", fix_name
)
172 if fixer
.order
== "pre":
173 pre_order_fixers
.append(fixer
)
174 elif fixer
.order
== "post":
175 post_order_fixers
.append(fixer
)
177 raise FixerError("Illegal fixer order: %r" % fixer
.order
)
179 key_func
= operator
.attrgetter("run_order")
180 pre_order_fixers
.sort(key
=key_func
)
181 post_order_fixers
.sort(key
=key_func
)
182 return (pre_order_fixers
, post_order_fixers
)
184 def log_error(self
, msg
, *args
, **kwds
):
185 """Called when an error occurs."""
188 def log_message(self
, msg
, *args
):
189 """Hook to log a message."""
192 self
.logger
.info(msg
)
194 def log_debug(self
, msg
, *args
):
197 self
.logger
.debug(msg
)
199 def print_output(self
, lines
):
200 """Called with lines of output to give to the user."""
203 def refactor(self
, items
, write
=False, doctests_only
=False):
204 """Refactor a list of files and directories."""
205 for dir_or_file
in items
:
206 if os
.path
.isdir(dir_or_file
):
207 self
.refactor_dir(dir_or_file
, write
, doctests_only
)
209 self
.refactor_file(dir_or_file
, write
, doctests_only
)
211 def refactor_dir(self
, dir_name
, write
=False, doctests_only
=False):
212 """Descends down a directory and refactor every Python file found.
214 Python files are assumed to have a .py extension.
216 Files and subdirectories starting with '.' are skipped.
218 for dirpath
, dirnames
, filenames
in os
.walk(dir_name
):
219 self
.log_debug("Descending into %s", dirpath
)
222 for name
in filenames
:
223 if not name
.startswith(".") and name
.endswith("py"):
224 fullname
= os
.path
.join(dirpath
, name
)
225 self
.refactor_file(fullname
, write
, doctests_only
)
226 # Modify dirnames in-place to remove subdirs with leading dots
227 dirnames
[:] = [dn
for dn
in dirnames
if not dn
.startswith(".")]
229 def _read_python_source(self
, filename
):
231 Do our best to decode a Python source file correctly.
234 f
= open(filename
, "rb")
235 except IOError as err
:
236 self
.log_error("Can't open %s: %s", filename
, err
)
239 encoding
= tokenize
.detect_encoding(f
.readline
)[0]
242 with
_open_with_encoding(filename
, "r", encoding
=encoding
) as f
:
243 return _from_system_newlines(f
.read()), encoding
245 def refactor_file(self
, filename
, write
=False, doctests_only
=False):
246 """Refactors a file."""
247 input, encoding
= self
._read
_python
_source
(filename
)
249 # Reading the file failed.
251 input += "\n" # Silence certain parse errors
253 self
.log_debug("Refactoring doctests in %s", filename
)
254 output
= self
.refactor_docstring(input, filename
)
256 self
.processed_file(output
, filename
, input, write
, encoding
)
258 self
.log_debug("No doctest changes in %s", filename
)
260 tree
= self
.refactor_string(input, filename
)
261 if tree
and tree
.was_changed
:
262 # The [:-1] is to take off the \n we added earlier
263 self
.processed_file(str(tree
)[:-1], filename
,
264 write
=write
, encoding
=encoding
)
266 self
.log_debug("No changes in %s", filename
)
268 def refactor_string(self
, data
, name
):
269 """Refactor a given input string.
272 data: a string holding the code to be refactored.
273 name: a human-readable name for use in error/log messages.
276 An AST corresponding to the refactored input stream; None if
277 there were errors during the parse.
280 tree
= self
.driver
.parse_string(data
)
281 except Exception as err
:
282 self
.log_error("Can't parse %s: %s: %s",
283 name
, err
.__class
__.__name
__, err
)
285 self
.log_debug("Refactoring %s", name
)
286 self
.refactor_tree(tree
, name
)
289 def refactor_stdin(self
, doctests_only
=False):
290 input = sys
.stdin
.read()
292 self
.log_debug("Refactoring doctests in stdin")
293 output
= self
.refactor_docstring(input, "<stdin>")
295 self
.processed_file(output
, "<stdin>", input)
297 self
.log_debug("No doctest changes in stdin")
299 tree
= self
.refactor_string(input, "<stdin>")
300 if tree
and tree
.was_changed
:
301 self
.processed_file(str(tree
), "<stdin>", input)
303 self
.log_debug("No changes in stdin")
305 def refactor_tree(self
, tree
, name
):
306 """Refactors a parse tree (modifying the tree in place).
309 tree: a pytree.Node instance representing the root of the tree
311 name: a human-readable name for this tree.
314 True if the tree was modified, False otherwise.
316 for fixer
in chain(self
.pre_order
, self
.post_order
):
317 fixer
.start_tree(tree
, name
)
319 self
.traverse_by(self
.pre_order_heads
, tree
.pre_order())
320 self
.traverse_by(self
.post_order_heads
, tree
.post_order())
322 for fixer
in chain(self
.pre_order
, self
.post_order
):
323 fixer
.finish_tree(tree
, name
)
324 return tree
.was_changed
326 def traverse_by(self
, fixers
, traversal
):
327 """Traverse an AST, applying a set of fixers to each node.
329 This is a helper method for refactor_tree().
332 fixers: a list of fixer instances.
333 traversal: a generator that yields AST nodes.
340 for node
in traversal
:
341 for fixer
in fixers
[node
.type] + fixers
[None]:
342 results
= fixer
.match(node
)
344 new
= fixer
.transform(node
, results
)
345 if new
is not None and (new
!= node
or
346 str(new
) != str(node
)):
350 def processed_file(self
, new_text
, filename
, old_text
=None, write
=False,
353 Called when a file has been refactored, and there are changes.
355 self
.files
.append(filename
)
357 old_text
= self
._read
_python
_source
(filename
)[0]
360 if old_text
== new_text
:
361 self
.log_debug("No changes to %s", filename
)
363 self
.print_output(diff_texts(old_text
, new_text
, filename
))
365 self
.write_file(new_text
, filename
, old_text
, encoding
)
367 self
.log_debug("Not writing changes to %s", filename
)
369 def write_file(self
, new_text
, filename
, old_text
, encoding
=None):
370 """Writes a string to a file.
372 It first shows a unified diff between the old text and the new text, and
373 then rewrites the file; the latter is only done if the write option is
377 f
= _open_with_encoding(filename
, "w", encoding
=encoding
)
378 except os
.error
as err
:
379 self
.log_error("Can't create %s: %s", filename
, err
)
382 f
.write(_to_system_newlines(new_text
))
383 except os
.error
as err
:
384 self
.log_error("Can't write %s: %s", filename
, err
)
387 self
.log_debug("Wrote changes to %s", filename
)
393 def refactor_docstring(self
, input, filename
):
394 """Refactors a docstring, looking for doctests.
396 This returns a modified version of the input string. It looks
397 for doctests, which start with a ">>>" prompt, and may be
398 continued with "..." prompts, as long as the "..." is indented
399 the same as the ">>>".
401 (Unfortunately we can't use the doctest module's parser,
402 since, like most parsers, it is not geared towards preserving
403 the original source.)
410 for line
in input.splitlines(True):
412 if line
.lstrip().startswith(self
.PS1
):
413 if block
is not None:
414 result
.extend(self
.refactor_doctest(block
, block_lineno
,
416 block_lineno
= lineno
418 i
= line
.find(self
.PS1
)
420 elif (indent
is not None and
421 (line
.startswith(indent
+ self
.PS2
) or
422 line
== indent
+ self
.PS2
.rstrip() + "\n")):
425 if block
is not None:
426 result
.extend(self
.refactor_doctest(block
, block_lineno
,
431 if block
is not None:
432 result
.extend(self
.refactor_doctest(block
, block_lineno
,
434 return "".join(result
)
436 def refactor_doctest(self
, block
, lineno
, indent
, filename
):
437 """Refactors one doctest.
439 A doctest is given as a block of lines, the first of which starts
440 with ">>>" (possibly indented), while the remaining lines start
441 with "..." (identically indented).
445 tree
= self
.parse_block(block
, lineno
, indent
)
446 except Exception as err
:
447 if self
.log
.isEnabledFor(logging
.DEBUG
):
449 self
.log_debug("Source: %s", line
.rstrip("\n"))
450 self
.log_error("Can't parse docstring in %s line %s: %s: %s",
451 filename
, lineno
, err
.__class
__.__name
__, err
)
453 if self
.refactor_tree(tree
, filename
):
454 new
= str(tree
).splitlines(True)
455 # Undo the adjustment of the line numbers in wrap_toks() below.
456 clipped
, new
= new
[:lineno
-1], new
[lineno
-1:]
457 assert clipped
== ["\n"] * (lineno
-1), clipped
458 if not new
[-1].endswith("\n"):
460 block
= [indent
+ self
.PS1
+ new
.pop(0)]
462 block
+= [indent
+ self
.PS2
+ line
for line
in new
]
471 self
.log_message("No files %s modified.", were
)
473 self
.log_message("Files that %s modified:", were
)
474 for file in self
.files
:
475 self
.log_message(file)
477 self
.log_message("Warnings/messages while refactoring:")
478 for message
in self
.fixer_log
:
479 self
.log_message(message
)
481 if len(self
.errors
) == 1:
482 self
.log_message("There was 1 error:")
484 self
.log_message("There were %d errors:", len(self
.errors
))
485 for msg
, args
, kwds
in self
.errors
:
486 self
.log_message(msg
, *args
, **kwds
)
488 def parse_block(self
, block
, lineno
, indent
):
489 """Parses a block into a tree.
491 This is necessary to get correct line number / offset information
492 in the parser diagnostics and embedded into the parse tree.
494 return self
.driver
.parse_tokens(self
.wrap_toks(block
, lineno
, indent
))
496 def wrap_toks(self
, block
, lineno
, indent
):
497 """Wraps a tokenize stream to systematically modify start/end."""
498 tokens
= tokenize
.generate_tokens(self
.gen_lines(block
, indent
).__next
__)
499 for type, value
, (line0
, col0
), (line1
, col1
), line_text
in tokens
:
502 # Don't bother updating the columns; this is too complicated
503 # since line_text would also have to be updated and it would
504 # still break for tokens spanning lines. Let the user guess
505 # that the column numbers for doctests are relative to the
506 # end of the prompt string (PS1 or PS2).
507 yield type, value
, (line0
, col0
), (line1
, col1
), line_text
510 def gen_lines(self
, block
, indent
):
511 """Generates lines as expected by tokenize from a list of lines.
513 This strips the first len(indent + self.PS1) characters off each line.
515 prefix1
= indent
+ self
.PS1
516 prefix2
= indent
+ self
.PS2
519 if line
.startswith(prefix
):
520 yield line
[len(prefix
):]
521 elif line
== prefix
.rstrip() + "\n":
524 raise AssertionError("line=%r, prefix=%r" % (line
, prefix
))
class MultiprocessingUnsupported(Exception):
    """Raised when multi-process refactoring is requested but the
    multiprocessing module cannot be imported."""
class MultiprocessRefactoringTool(RefactoringTool):
    """RefactoringTool that can fan file refactoring out to worker
    processes via a multiprocessing.JoinableQueue."""

    def __init__(self, *args, **kwargs):
        super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs)
        # None means "not currently running multi-process"; refactor_file()
        # keys off this to decide between queueing and direct work.
        self.queue = None

    def refactor(self, items, write=False, doctests_only=False,
                 num_processes=1):
        """Refactor items, optionally across num_processes workers."""
        if num_processes == 1:
            return super(MultiprocessRefactoringTool, self).refactor(
                items, write, doctests_only)
        try:
            import multiprocessing
        except ImportError:
            raise MultiprocessingUnsupported
        if self.queue is not None:
            raise RuntimeError("already doing multiple processes")
        self.queue = multiprocessing.JoinableQueue()
        processes = [multiprocessing.Process(target=self._child)
                     for i in range(num_processes)]
        try:
            for p in processes:
                p.start()
            # The parent walks the items; refactor_file() enqueues tasks.
            super(MultiprocessRefactoringTool, self).refactor(items, write,
                                                              doctests_only)
        finally:
            self.queue.join()
            # One None sentinel per worker tells it to exit.
            for i in range(num_processes):
                self.queue.put(None)
            for p in processes:
                if p.is_alive():
                    p.join()
            self.queue = None

    def _child(self):
        """Worker loop: pull (args, kwargs) tasks until a None sentinel."""
        task = self.queue.get()
        while task is not None:
            args, kwargs = task
            try:
                super(MultiprocessRefactoringTool, self).refactor_file(
                    *args, **kwargs)
            finally:
                self.queue.task_done()
            task = self.queue.get()

    def refactor_file(self, *args, **kwargs):
        """Queue the file for a worker when multi-processing, else refactor
        it directly."""
        if self.queue is not None:
            self.queue.put((args, kwargs))
        else:
            return super(MultiprocessRefactoringTool, self).refactor_file(
                *args, **kwargs)
def diff_texts(a, b, filename):
    """Return a unified diff of two strings.

    Args:
        a: the original text.
        b: the refactored text.
        filename: shown in both diff headers.

    Returns:
        An iterator of diff lines (no trailing newlines; lineterm="").
    """
    a = a.splitlines()
    b = b.splitlines()
    return difflib.unified_diff(a, b, filename, filename,
                                "(original)", "(refactored)",
                                lineterm="")