Fixed console download progress indicator on Python 3
[zeroinstall/solver.git] / zeroinstall / injector / handler.py
blob 940f521927a2e03b7748702eb563f65058cccee2
1 """
2 Integrates download callbacks with an external mainloop.
3 While things are being downloaded, Zero Install returns control to your program.
4 Your mainloop is responsible for monitoring the state of the downloads and notifying
5 Zero Install when they are complete.
7 To do this, you supply a L{Handler} to the L{policy}.
8 """
10 # Copyright (C) 2009, Thomas Leonard
11 # See the README file for details, or visit http://0install.net.
13 from __future__ import print_function
15 from zeroinstall import _
16 import sys
17 from logging import warn, info
19 from zeroinstall import SafeException
20 from zeroinstall import support
21 from zeroinstall.support import tasks
22 from zeroinstall.injector import download
class NoTrustedKeys(SafeException):
    """Raised by L{Handler.confirm_import_feed} when the user refuses to trust the feed's signing key(s)."""
28 class Handler(object):
29 """
30 A Handler is used to interact with the user (e.g. to confirm keys, display download progress, etc).
32 @ivar monitored_downloads: set of downloads in progress
33 @type monitored_downloads: {L{download.Download}}
34 @ivar n_completed_downloads: number of downloads which have finished for GUIs, etc (can be reset as desired).
35 @type n_completed_downloads: int
36 @ivar total_bytes_downloaded: informational counter for GUIs, etc (can be reset as desired). Updated when download finishes.
37 @type total_bytes_downloaded: int
38 @ivar dry_run: instead of starting a download, just report what we would have downloaded
39 @type dry_run: bool
40 """
42 __slots__ = ['monitored_downloads', 'dry_run', 'total_bytes_downloaded', 'n_completed_downloads']
44 def __init__(self, mainloop = None, dry_run = False):
45 self.monitored_downloads = set()
46 self.dry_run = dry_run
47 self.n_completed_downloads = 0
48 self.total_bytes_downloaded = 0
50 def monitor_download(self, dl):
51 """Called when a new L{download} is started.
52 This is mainly used by the GUI to display the progress bar."""
53 self.monitored_downloads.add(dl)
54 self.downloads_changed()
56 @tasks.async
57 def download_done_stats():
58 yield dl.downloaded
59 # NB: we don't check for exceptions here; someone else should be doing that
60 try:
61 self.n_completed_downloads += 1
62 self.total_bytes_downloaded += dl.get_bytes_downloaded_so_far()
63 self.monitored_downloads.remove(dl)
64 self.downloads_changed()
65 except Exception as ex:
66 self.report_error(ex)
67 download_done_stats()
69 def impl_added_to_store(self, impl):
70 """Called by the L{fetch.Fetcher} when adding an implementation.
71 The GUI uses this to update its display.
72 @param impl: the implementation which has been added
73 @type impl: L{model.Implementation}
74 """
75 pass
77 def downloads_changed(self):
78 """This is just for the GUI to override to update its display."""
79 pass
81 def wait_for_blocker(self, blocker):
82 """@deprecated: use tasks.wait_for_blocker instead"""
83 tasks.wait_for_blocker(blocker)
85 @tasks.async
86 def confirm_import_feed(self, pending, valid_sigs):
87 """Sub-classes should override this method to interact with the user about new feeds.
88 If multiple feeds need confirmation, L{trust.TrustMgr.confirm_keys} will only invoke one instance of this
89 method at a time.
90 @param pending: the new feed to be imported
91 @type pending: L{PendingFeed}
92 @param valid_sigs: maps signatures to a list of fetchers collecting information about the key
93 @type valid_sigs: {L{gpg.ValidSig} : L{fetch.KeyInfoFetcher}}
94 @since: 0.42"""
95 from zeroinstall.injector import trust
97 assert valid_sigs
99 domain = trust.domain_from_url(pending.url)
101 # Ask on stderr, because we may be writing XML to stdout
102 print(_("Feed: %s") % pending.url, file=sys.stderr)
103 print(_("The feed is correctly signed with the following keys:"), file=sys.stderr)
104 for x in valid_sigs:
105 print("-", x, file=sys.stderr)
107 def text(parent):
108 text = ""
109 for node in parent.childNodes:
110 if node.nodeType == node.TEXT_NODE:
111 text = text + node.data
112 return text
114 shown = set()
115 key_info_fetchers = valid_sigs.values()
116 while key_info_fetchers:
117 old_kfs = key_info_fetchers
118 key_info_fetchers = []
119 for kf in old_kfs:
120 infos = set(kf.info) - shown
121 if infos:
122 if len(valid_sigs) > 1:
123 print("%s: " % kf.fingerprint)
124 for key_info in infos:
125 print("-", text(key_info), file=sys.stderr)
126 shown.add(key_info)
127 if kf.blocker:
128 key_info_fetchers.append(kf)
129 if key_info_fetchers:
130 for kf in key_info_fetchers: print(kf.status, file=sys.stderr)
131 stdin = tasks.InputBlocker(0, 'console')
132 blockers = [kf.blocker for kf in key_info_fetchers] + [stdin]
133 yield blockers
134 for b in blockers:
135 try:
136 tasks.check(b)
137 except Exception as ex:
138 warn(_("Failed to get key info: %s"), ex)
139 if stdin.happened:
140 print(_("Skipping remaining key lookups due to input from user"), file=sys.stderr)
141 break
142 if not shown:
143 print(_("Warning: Nothing known about this key!"), file=sys.stderr)
145 if len(valid_sigs) == 1:
146 print(_("Do you want to trust this key to sign feeds from '%s'?") % domain, file=sys.stderr)
147 else:
148 print(_("Do you want to trust all of these keys to sign feeds from '%s'?") % domain, file=sys.stderr)
149 while True:
150 print(_("Trust [Y/N] "), end=' ', file=sys.stderr)
151 i = support.raw_input()
152 if not i: continue
153 if i in 'Nn':
154 raise NoTrustedKeys(_('Not signed with a trusted key'))
155 if i in 'Yy':
156 break
157 for key in valid_sigs:
158 print(_("Trusting %(key_fingerprint)s for %(domain)s") % {'key_fingerprint': key.fingerprint, 'domain': domain}, file=sys.stderr)
159 trust.trust_db.trust_key(key.fingerprint, domain)
161 @tasks.async
162 def confirm_install(self, msg):
163 """We need to check something with the user before continuing with the install.
164 @raise download.DownloadAborted: if the user cancels"""
165 yield
166 print(msg, file=sys.stderr)
167 while True:
168 sys.stderr.write(_("Install [Y/N] "))
169 i = raw_input()
170 if not i: continue
171 if i in 'Nn':
172 raise download.DownloadAborted()
173 if i in 'Yy':
174 break
176 def report_error(self, exception, tb = None):
177 """Report an exception to the user.
178 @param exception: the exception to report
179 @type exception: L{SafeException}
180 @param tb: optional traceback
181 @since: 0.25"""
182 warn("%s", str(exception) or type(exception))
183 #import traceback
184 #traceback.print_exception(exception, None, tb)
186 class ConsoleHandler(Handler):
187 """A Handler that displays progress on stdout (a tty).
188 @since: 0.44"""
189 last_msg_len = None
190 update = None
191 disable_progress = 0
192 screen_width = None
194 def downloads_changed(self):
195 from zeroinstall import gobject
196 if self.monitored_downloads and self.update is None:
197 if self.screen_width is None:
198 try:
199 import curses
200 curses.setupterm()
201 self.screen_width = curses.tigetnum('cols') or 80
202 except Exception as ex:
203 info("Failed to initialise curses library: %s", ex)
204 self.screen_width = 80
205 self.show_progress()
206 self.update = gobject.timeout_add(200, self.show_progress)
207 elif len(self.monitored_downloads) == 0:
208 if self.update:
209 gobject.source_remove(self.update)
210 self.update = None
211 print()
212 self.last_msg_len = None
214 def show_progress(self):
215 if not self.monitored_downloads: return True
216 urls = [(dl.url, dl) for dl in self.monitored_downloads]
218 if self.disable_progress: return True
220 screen_width = self.screen_width - 2
221 item_width = max(16, screen_width // len(self.monitored_downloads))
222 url_width = item_width - 7
224 msg = ""
225 for url, dl in sorted(urls):
226 so_far = dl.get_bytes_downloaded_so_far()
227 leaf = url.rsplit('/', 1)[-1]
228 if len(leaf) >= url_width:
229 display = leaf[:url_width]
230 else:
231 display = url[-url_width:]
232 if dl.expected_size:
233 msg += "[%s %d%%] " % (display, int(so_far * 100 / dl.expected_size))
234 else:
235 msg += "[%s] " % (display)
236 msg = msg[:screen_width]
238 if self.last_msg_len is None:
239 sys.stdout.write(msg)
240 else:
241 sys.stdout.write(chr(13) + msg)
242 if len(msg) < self.last_msg_len:
243 sys.stdout.write(" " * (self.last_msg_len - len(msg)))
245 self.last_msg_len = len(msg)
246 sys.stdout.flush()
248 return True
250 def clear_display(self):
251 if self.last_msg_len != None:
252 sys.stdout.write(chr(13) + " " * self.last_msg_len + chr(13))
253 sys.stdout.flush()
254 self.last_msg_len = None
256 def report_error(self, exception, tb = None):
257 self.clear_display()
258 Handler.report_error(self, exception, tb)
260 def confirm_import_feed(self, pending, valid_sigs):
261 self.clear_display()
262 self.disable_progress += 1
263 blocker = Handler.confirm_import_feed(self, pending, valid_sigs)
264 @tasks.async
265 def enable():
266 yield blocker
267 self.disable_progress -= 1
268 self.show_progress()
269 enable()
270 return blocker