1 # Copyright 2006 Google, Inc. All Rights Reserved.
2 # Licensed to PSF under a Contributor Agreement.
"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""
# Module author attribution.
__author__ = "Guido van Rossum <guido@python.org>"
# Python imports
import os
import sys
import logging
import operator
import collections
import StringIO
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from . import pytree, pygram
def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package.

    Args:
        fixer_pkg: dotted name of the package containing fix_*.py modules.
        remove_prefix: if true, strip the leading "fix_" from each name.

    Returns:
        A list of fixer names (module names without the ".py" suffix).
    """
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fixer_dir = os.path.dirname(pkg.__file__)
    # Bug fixes: the accumulator was never initialized, remove_prefix was
    # ignored, and the result was never returned.
    fix_names = []
    for name in sorted(os.listdir(fixer_dir)):
        if name.startswith("fix_") and name.endswith(".py"):
            if remove_prefix:
                name = name[4:]
            fix_names.append(name[:-3])
    return fix_names
class _EveryNode(Exception):
    """Signals that a pattern could match any node type at its head."""
def _get_head_types(pat):
    """ Accepts a pytree Pattern Node and returns a set
        of the pattern types which will match first.

    Raises:
        _EveryNode: if the pattern can match any node type at its head.
    """
    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther
        # than the type here.
        if pat.type is None:
            # A typeless pattern matches anything.
            raise _EveryNode
        return set([pat.type])

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content; `r` accumulates the union.
        r = set()
        for p in pat.content:
            for x in p:
                r.update(_get_head_types(x))
        return r

    raise Exception("Oh no! I don't understand pattern %s" % (pat))
def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
        of head node type --> fixer list.

    Fixers whose pattern can match any head type (signalled by
    _EveryNode) are registered under every known symbol and token.
    """
    head_nodes = collections.defaultdict(list)
    # Fixers that must be tried against every node type.
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    # .values() (not the Py2-only .itervalues()) works on both 2.x and 3.x.
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)
def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]
108 if sys
.version_info
< (3, 0):
110 _open_with_encoding
= codecs
.open
111 # codecs.open doesn't translate newlines sadly.
112 def _from_system_newlines(input):
113 return input.replace(u
"\r\n", u
"\n")
114 def _to_system_newlines(input):
115 if os
.linesep
!= "\n":
116 return input.replace(u
"\n", os
.linesep
)
120 _open_with_encoding
= open
121 _from_system_newlines
= _identity
122 _to_system_newlines
= _identity
def _detect_future_print(source):
    """Return True if source contains 'from __future__ import print_function'.

    Scans the token stream of the module prologue only: a docstring is
    allowed, but any other statement before the __future__ import ends
    the scan.
    """
    have_docstring = False
    gen = tokenize.generate_tokens(StringIO.StringIO(source).readline)
    def advance():
        # next(gen) works on both Python 2.6+ and 3.x (gen.next() is 2-only).
        tok = next(gen)
        return tok[0], tok[1]
    ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT))
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == u"from":
                tp, value = advance()
                # Bug fix: these three checks used `and`, which let e.g.
                # "from foo import print_function" count as a __future__
                # import.  The scan must stop unless BOTH the token type
                # and the value match, hence `or`.
                if tp != token.NAME or value != u"__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != u"import":
                    break
                tp, value = advance()
                if tp == token.OP and value == u"(":
                    tp, value = advance()
                while tp == token.NAME:
                    if value == u"print_function":
                        return True
                    tp, value = advance()
                    if tp != token.OP or value != u",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
    return False
class FixerError(Exception):
    """A fixer could not be loaded."""
class RefactoringTool(object):
    """Refactors Python source by applying a configurable set of fixers."""

    # Default configuration; overridden per-instance via the options dict.
    _default_options = {"print_function" : False}

    CLASS_PREFIX = "Fix" # The prefix for fixer classes
    FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
176 def __init__(self
, fixer_names
, options
=None, explicit
=None):
180 fixer_names: a list of fixers to import
181 options: an dict with configuration.
182 explicit: a list of fixers to run even if they are explicit.
184 self
.fixers
= fixer_names
185 self
.explicit
= explicit
or []
186 self
.options
= self
._default
_options
.copy()
187 if options
is not None:
188 self
.options
.update(options
)
189 if self
.options
["print_function"]:
190 self
.grammar
= pygram
.python_grammar_no_print_statement
192 self
.grammar
= pygram
.python_grammar
194 self
.logger
= logging
.getLogger("RefactoringTool")
197 self
.driver
= driver
.Driver(self
.grammar
,
198 convert
=pytree
.convert
,
200 self
.pre_order
, self
.post_order
= self
.get_fixers()
202 self
.pre_order_heads
= _get_headnode_dict(self
.pre_order
)
203 self
.post_order_heads
= _get_headnode_dict(self
.post_order
)
205 self
.files
= [] # List of files that were or should be modified
207 def get_fixers(self
):
208 """Inspects the options to load the requested patterns and handlers.
211 (pre_order, post_order), where pre_order is the list of fixers that
212 want a pre-order AST traversal, and post_order is the list that want
213 post-order traversal.
215 pre_order_fixers
= []
216 post_order_fixers
= []
217 for fix_mod_path
in self
.fixers
:
218 mod
= __import__(fix_mod_path
, {}, {}, ["*"])
219 fix_name
= fix_mod_path
.rsplit(".", 1)[-1]
220 if fix_name
.startswith(self
.FILE_PREFIX
):
221 fix_name
= fix_name
[len(self
.FILE_PREFIX
):]
222 parts
= fix_name
.split("_")
223 class_name
= self
.CLASS_PREFIX
+ "".join([p
.title() for p
in parts
])
225 fix_class
= getattr(mod
, class_name
)
226 except AttributeError:
227 raise FixerError("Can't find %s.%s" % (fix_name
, class_name
))
228 fixer
= fix_class(self
.options
, self
.fixer_log
)
229 if fixer
.explicit
and self
.explicit
is not True and \
230 fix_mod_path
not in self
.explicit
:
231 self
.log_message("Skipping implicit fixer: %s", fix_name
)
234 self
.log_debug("Adding transformation: %s", fix_name
)
235 if fixer
.order
== "pre":
236 pre_order_fixers
.append(fixer
)
237 elif fixer
.order
== "post":
238 post_order_fixers
.append(fixer
)
240 raise FixerError("Illegal fixer order: %r" % fixer
.order
)
242 key_func
= operator
.attrgetter("run_order")
243 pre_order_fixers
.sort(key
=key_func
)
244 post_order_fixers
.sort(key
=key_func
)
245 return (pre_order_fixers
, post_order_fixers
)
247 def log_error(self
, msg
, *args
, **kwds
):
248 """Called when an error occurs."""
251 def log_message(self
, msg
, *args
):
252 """Hook to log a message."""
255 self
.logger
.info(msg
)
257 def log_debug(self
, msg
, *args
):
260 self
.logger
.debug(msg
)
262 def print_output(self
, old_text
, new_text
, filename
, equal
):
263 """Called with the old version, new version, and filename of a
267 def refactor(self
, items
, write
=False, doctests_only
=False):
268 """Refactor a list of files and directories."""
269 for dir_or_file
in items
:
270 if os
.path
.isdir(dir_or_file
):
271 self
.refactor_dir(dir_or_file
, write
, doctests_only
)
273 self
.refactor_file(dir_or_file
, write
, doctests_only
)
275 def refactor_dir(self
, dir_name
, write
=False, doctests_only
=False):
276 """Descends down a directory and refactor every Python file found.
278 Python files are assumed to have a .py extension.
280 Files and subdirectories starting with '.' are skipped.
282 for dirpath
, dirnames
, filenames
in os
.walk(dir_name
):
283 self
.log_debug("Descending into %s", dirpath
)
286 for name
in filenames
:
287 if not name
.startswith(".") and \
288 os
.path
.splitext(name
)[1].endswith("py"):
289 fullname
= os
.path
.join(dirpath
, name
)
290 self
.refactor_file(fullname
, write
, doctests_only
)
291 # Modify dirnames in-place to remove subdirs with leading dots
292 dirnames
[:] = [dn
for dn
in dirnames
if not dn
.startswith(".")]
294 def _read_python_source(self
, filename
):
296 Do our best to decode a Python source file correctly.
299 f
= open(filename
, "rb")
301 self
.log_error("Can't open %s: %s", filename
, err
)
304 encoding
= tokenize
.detect_encoding(f
.readline
)[0]
307 with
_open_with_encoding(filename
, "r", encoding
=encoding
) as f
:
308 return _from_system_newlines(f
.read()), encoding
310 def refactor_file(self
, filename
, write
=False, doctests_only
=False):
311 """Refactors a file."""
312 input, encoding
= self
._read
_python
_source
(filename
)
314 # Reading the file failed.
316 input += u
"\n" # Silence certain parse errors
318 self
.log_debug("Refactoring doctests in %s", filename
)
319 output
= self
.refactor_docstring(input, filename
)
321 self
.processed_file(output
, filename
, input, write
, encoding
)
323 self
.log_debug("No doctest changes in %s", filename
)
325 tree
= self
.refactor_string(input, filename
)
326 if tree
and tree
.was_changed
:
327 # The [:-1] is to take off the \n we added earlier
328 self
.processed_file(unicode(tree
)[:-1], filename
,
329 write
=write
, encoding
=encoding
)
331 self
.log_debug("No changes in %s", filename
)
333 def refactor_string(self
, data
, name
):
334 """Refactor a given input string.
337 data: a string holding the code to be refactored.
338 name: a human-readable name for use in error/log messages.
341 An AST corresponding to the refactored input stream; None if
342 there were errors during the parse.
344 if _detect_future_print(data
):
345 self
.driver
.grammar
= pygram
.python_grammar_no_print_statement
347 tree
= self
.driver
.parse_string(data
)
348 except Exception, err
:
349 self
.log_error("Can't parse %s: %s: %s",
350 name
, err
.__class
__.__name
__, err
)
353 self
.driver
.grammar
= self
.grammar
354 self
.log_debug("Refactoring %s", name
)
355 self
.refactor_tree(tree
, name
)
358 def refactor_stdin(self
, doctests_only
=False):
359 input = sys
.stdin
.read()
361 self
.log_debug("Refactoring doctests in stdin")
362 output
= self
.refactor_docstring(input, "<stdin>")
364 self
.processed_file(output
, "<stdin>", input)
366 self
.log_debug("No doctest changes in stdin")
368 tree
= self
.refactor_string(input, "<stdin>")
369 if tree
and tree
.was_changed
:
370 self
.processed_file(unicode(tree
), "<stdin>", input)
372 self
.log_debug("No changes in stdin")
374 def refactor_tree(self
, tree
, name
):
375 """Refactors a parse tree (modifying the tree in place).
378 tree: a pytree.Node instance representing the root of the tree
380 name: a human-readable name for this tree.
383 True if the tree was modified, False otherwise.
385 for fixer
in chain(self
.pre_order
, self
.post_order
):
386 fixer
.start_tree(tree
, name
)
388 self
.traverse_by(self
.pre_order_heads
, tree
.pre_order())
389 self
.traverse_by(self
.post_order_heads
, tree
.post_order())
391 for fixer
in chain(self
.pre_order
, self
.post_order
):
392 fixer
.finish_tree(tree
, name
)
393 return tree
.was_changed
395 def traverse_by(self
, fixers
, traversal
):
396 """Traverse an AST, applying a set of fixers to each node.
398 This is a helper method for refactor_tree().
401 fixers: a list of fixer instances.
402 traversal: a generator that yields AST nodes.
409 for node
in traversal
:
410 for fixer
in fixers
[node
.type]:
411 results
= fixer
.match(node
)
413 new
= fixer
.transform(node
, results
)
418 def processed_file(self
, new_text
, filename
, old_text
=None, write
=False,
421 Called when a file has been refactored, and there are changes.
423 self
.files
.append(filename
)
425 old_text
= self
._read
_python
_source
(filename
)[0]
428 equal
= old_text
== new_text
429 self
.print_output(old_text
, new_text
, filename
, equal
)
431 self
.log_debug("No changes to %s", filename
)
434 self
.write_file(new_text
, filename
, old_text
, encoding
)
436 self
.log_debug("Not writing changes to %s", filename
)
438 def write_file(self
, new_text
, filename
, old_text
, encoding
=None):
439 """Writes a string to a file.
441 It first shows a unified diff between the old text and the new text, and
442 then rewrites the file; the latter is only done if the write option is
446 f
= _open_with_encoding(filename
, "w", encoding
=encoding
)
447 except os
.error
, err
:
448 self
.log_error("Can't create %s: %s", filename
, err
)
451 f
.write(_to_system_newlines(new_text
))
452 except os
.error
, err
:
453 self
.log_error("Can't write %s: %s", filename
, err
)
456 self
.log_debug("Wrote changes to %s", filename
)
462 def refactor_docstring(self
, input, filename
):
463 """Refactors a docstring, looking for doctests.
465 This returns a modified version of the input string. It looks
466 for doctests, which start with a ">>>" prompt, and may be
467 continued with "..." prompts, as long as the "..." is indented
468 the same as the ">>>".
470 (Unfortunately we can't use the doctest module's parser,
471 since, like most parsers, it is not geared towards preserving
472 the original source.)
479 for line
in input.splitlines(True):
481 if line
.lstrip().startswith(self
.PS1
):
482 if block
is not None:
483 result
.extend(self
.refactor_doctest(block
, block_lineno
,
485 block_lineno
= lineno
487 i
= line
.find(self
.PS1
)
489 elif (indent
is not None and
490 (line
.startswith(indent
+ self
.PS2
) or
491 line
== indent
+ self
.PS2
.rstrip() + u
"\n")):
494 if block
is not None:
495 result
.extend(self
.refactor_doctest(block
, block_lineno
,
500 if block
is not None:
501 result
.extend(self
.refactor_doctest(block
, block_lineno
,
503 return u
"".join(result
)
505 def refactor_doctest(self
, block
, lineno
, indent
, filename
):
506 """Refactors one doctest.
508 A doctest is given as a block of lines, the first of which starts
509 with ">>>" (possibly indented), while the remaining lines start
510 with "..." (identically indented).
514 tree
= self
.parse_block(block
, lineno
, indent
)
515 except Exception, err
:
516 if self
.log
.isEnabledFor(logging
.DEBUG
):
518 self
.log_debug("Source: %s", line
.rstrip(u
"\n"))
519 self
.log_error("Can't parse docstring in %s line %s: %s: %s",
520 filename
, lineno
, err
.__class
__.__name
__, err
)
522 if self
.refactor_tree(tree
, filename
):
523 new
= unicode(tree
).splitlines(True)
524 # Undo the adjustment of the line numbers in wrap_toks() below.
525 clipped
, new
= new
[:lineno
-1], new
[lineno
-1:]
526 assert clipped
== [u
"\n"] * (lineno
-1), clipped
527 if not new
[-1].endswith(u
"\n"):
529 block
= [indent
+ self
.PS1
+ new
.pop(0)]
531 block
+= [indent
+ self
.PS2
+ line
for line
in new
]
540 self
.log_message("No files %s modified.", were
)
542 self
.log_message("Files that %s modified:", were
)
543 for file in self
.files
:
544 self
.log_message(file)
546 self
.log_message("Warnings/messages while refactoring:")
547 for message
in self
.fixer_log
:
548 self
.log_message(message
)
550 if len(self
.errors
) == 1:
551 self
.log_message("There was 1 error:")
553 self
.log_message("There were %d errors:", len(self
.errors
))
554 for msg
, args
, kwds
in self
.errors
:
555 self
.log_message(msg
, *args
, **kwds
)
    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        return self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
565 def wrap_toks(self
, block
, lineno
, indent
):
566 """Wraps a tokenize stream to systematically modify start/end."""
567 tokens
= tokenize
.generate_tokens(self
.gen_lines(block
, indent
).next
)
568 for type, value
, (line0
, col0
), (line1
, col1
), line_text
in tokens
:
571 # Don't bother updating the columns; this is too complicated
572 # since line_text would also have to be updated and it would
573 # still break for tokens spanning lines. Let the user guess
574 # that the column numbers for doctests are relative to the
575 # end of the prompt string (PS1 or PS2).
576 yield type, value
, (line0
, col0
), (line1
, col1
), line_text
579 def gen_lines(self
, block
, indent
):
580 """Generates lines as expected by tokenize from a list of lines.
582 This strips the first len(indent + self.PS1) characters off each line.
584 prefix1
= indent
+ self
.PS1
585 prefix2
= indent
+ self
.PS2
588 if line
.startswith(prefix
):
589 yield line
[len(prefix
):]
590 elif line
== prefix
.rstrip() + u
"\n":
593 raise AssertionError("line=%r, prefix=%r" % (line
, prefix
))
class MultiprocessingUnsupported(Exception):
    """Raised when the multiprocessing module is not available."""
class MultiprocessRefactoringTool(RefactoringTool):
    """RefactoringTool that can distribute files over worker processes."""
605 def __init__(self
, *args
, **kwargs
):
606 super(MultiprocessRefactoringTool
, self
).__init
__(*args
, **kwargs
)
609 def refactor(self
, items
, write
=False, doctests_only
=False,
611 if num_processes
== 1:
612 return super(MultiprocessRefactoringTool
, self
).refactor(
613 items
, write
, doctests_only
)
615 import multiprocessing
617 raise MultiprocessingUnsupported
618 if self
.queue
is not None:
619 raise RuntimeError("already doing multiple processes")
620 self
.queue
= multiprocessing
.JoinableQueue()
621 processes
= [multiprocessing
.Process(target
=self
._child
)
622 for i
in xrange(num_processes
)]
626 super(MultiprocessRefactoringTool
, self
).refactor(items
, write
,
630 for i
in xrange(num_processes
):
638 task
= self
.queue
.get()
639 while task
is not None:
642 super(MultiprocessRefactoringTool
, self
).refactor_file(
645 self
.queue
.task_done()
646 task
= self
.queue
.get()
648 def refactor_file(self
, *args
, **kwargs
):
649 if self
.queue
is not None:
650 self
.queue
.put((args
, kwargs
))
652 return super(MultiprocessRefactoringTool
, self
).refactor_file(