From acd35e70048494cee8adcd7b41b38d1334d0e22d Mon Sep 17 00:00:00 2001 From: Peter Grayson Date: Mon, 15 Feb 2021 17:07:53 -0500 Subject: [PATCH] Use tempfile.TemporaryDirectory in imprt.py Using TemporaryDirectory as a context manager is cleaner than writing our own try/finally flows and is available on all supported Python versions. Also drop compatibility code when importing from urllib. This was needed for Python 2 support. Signed-off-by: Peter Grayson --- stgit/commands/imprt.py | 38 ++++++++++++-------------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/stgit/commands/imprt.py b/stgit/commands/imprt.py index 48aa827..2439d9f 100644 --- a/stgit/commands/imprt.py +++ b/stgit/commands/imprt.py @@ -1,6 +1,7 @@ import os import re import sys +import tempfile from stgit import argparse from stgit.argparse import opt @@ -328,16 +329,10 @@ def __import_series(filename, options): def __import_mail(filename, options): """Import a patch from an email file or mbox""" - import shutil - import tempfile - - tmpdir = tempfile.mkdtemp('.stg') - try: + with tempfile.TemporaryDirectory('.stg') as tmpdir: mail_paths = __mailsplit(tmpdir, filename, options) for mail_path in mail_paths: __import_mail_path(mail_path, filename, options) - finally: - shutil.rmtree(tmpdir) def __mailsplit(tmpdir, filename, options): @@ -399,27 +394,19 @@ def __import_mail_path(mail_path, filename, options): def __import_url(url, options): """Import a patch from a URL""" - try: - from urllib.parse import unquote - from urllib.request import urlretrieve - except ImportError: - from urllib import unquote, urlretrieve - import tempfile - - if not url: - raise CmdException('URL argument required') + from urllib.parse import unquote + from urllib.request import urlretrieve - patch = os.path.basename(unquote(url)) - filename = os.path.join(tempfile.gettempdir(), patch) - urlretrieve(url, filename) - __import_file(filename, options) + with tempfile.TemporaryDirectory('.stg') as tmpdir: + patch = os.path.basename(unquote(url)) + filename = os.path.join(tmpdir, patch) + urlretrieve(url, filename) + __import_file(filename, options) def __import_tarfile(tarpath, options): """Import patch series from a tar archive""" - import shutil import tarfile - import tempfile assert tarfile.is_tarfile(tarpath) @@ -441,12 +428,9 @@ def __import_tarfile(tarpath, options): raise CmdException("no 'series' file found in %s" % tarpath) # unpack into a tmp dir - tmpdir = tempfile.mkdtemp('.stg') - try: + with tempfile.TemporaryDirectory('.stg') as tmpdir: tar.extractall(tmpdir) __import_series(os.path.join(tmpdir, seriesfile), options) - finally: - shutil.rmtree(tmpdir) def func(parser, options, args): @@ -454,6 +438,8 @@ def func(parser, options, args): parser.error('incorrect number of arguments') elif len(args) == 1: filename = args[0] + elif options.url: + raise CmdException('URL argument required') else: filename = None -- 2.11.4.GIT