Auto-throw exceptions when resuming tasks
[zeroinstall.git] / zeroinstall / injector / handler.py
blob70b12a619acf7695bf77d84f0e02ad31a289d00e
1 """
2 Integrates download callbacks with an external mainloop.
3 While things are being downloaded, Zero Install returns control to your program.
4 Your mainloop is responsible for monitoring the state of the downloads and notifying
5 Zero Install when they are complete.
7 To do this, you supply a L{Handler} to the L{policy}.
8 """
10 # Copyright (C) 2009, Thomas Leonard
11 # See the README file for details, or visit http://0install.net.
13 from __future__ import print_function
15 from zeroinstall import _
16 import sys
17 from logging import warn, info
19 from zeroinstall import NeedDownload, SafeException
20 from zeroinstall.support import tasks
21 from zeroinstall.injector import download
class NoTrustedKeys(SafeException):
    """Signals that trust in a feed's signing keys was not confirmed.

    Raised by L{Handler.confirm_import_feed}, e.g. when the user answers
    'N' at the "Trust [Y/N]" prompt.
    """
27 class Handler(object):
28 """
29 A Handler is used to interact with the user (e.g. to confirm keys, display download progress, etc).
31 @ivar monitored_downloads: set of downloads in progress
32 @type monitored_downloads: {L{download.Download}}
33 @ivar n_completed_downloads: number of downloads which have finished for GUIs, etc (can be reset as desired).
34 @type n_completed_downloads: int
35 @ivar total_bytes_downloaded: informational counter for GUIs, etc (can be reset as desired). Updated when download finishes.
36 @type total_bytes_downloaded: int
37 @ivar dry_run: instead of starting a download, just report what we would have downloaded
38 @type dry_run: bool
39 """
41 __slots__ = ['monitored_downloads', 'dry_run', 'total_bytes_downloaded', 'n_completed_downloads']
43 def __init__(self, mainloop = None, dry_run = False):
44 self.monitored_downloads = set()
45 self.dry_run = dry_run
46 self.n_completed_downloads = 0
47 self.total_bytes_downloaded = 0
49 def monitor_download(self, dl):
50 """Called when a new L{download} is started.
51 This is mainly used by the GUI to display the progress bar."""
52 self.monitored_downloads.add(dl)
53 self.downloads_changed()
55 @tasks.async
56 def download_done_stats():
57 try:
58 yield dl.downloaded
59 except:
60 # NB: we don't check for exceptions here; someone else should be doing that
61 pass
62 try:
63 self.n_completed_downloads += 1
64 self.total_bytes_downloaded += dl.get_bytes_downloaded_so_far()
65 self.monitored_downloads.remove(dl)
66 self.downloads_changed()
67 except Exception as ex:
68 self.report_error(ex)
69 download_done_stats()
71 def impl_added_to_store(self, impl):
72 """Called by the L{fetch.Fetcher} when adding an implementation.
73 The GUI uses this to update its display.
74 @param impl: the implementation which has been added
75 @type impl: L{model.Implementation}
76 """
77 pass
79 def downloads_changed(self):
80 """This is just for the GUI to override to update its display."""
81 pass
83 def wait_for_blocker(self, blocker):
84 """@deprecated: use tasks.wait_for_blocker instead"""
85 tasks.wait_for_blocker(blocker)
87 @tasks.async
88 def confirm_import_feed(self, pending, valid_sigs):
89 """Sub-classes should override this method to interact with the user about new feeds.
90 If multiple feeds need confirmation, L{trust.TrustMgr.confirm_keys} will only invoke one instance of this
91 method at a time.
92 @param pending: the new feed to be imported
93 @type pending: L{PendingFeed}
94 @param valid_sigs: maps signatures to a list of fetchers collecting information about the key
95 @type valid_sigs: {L{gpg.ValidSig} : L{fetch.KeyInfoFetcher}}
96 @since: 0.42"""
97 from zeroinstall.injector import trust
99 assert valid_sigs
101 domain = trust.domain_from_url(pending.url)
103 # Ask on stderr, because we may be writing XML to stdout
104 print(_("Feed: %s") % pending.url, file=sys.stderr)
105 print(_("The feed is correctly signed with the following keys:"), file=sys.stderr)
106 for x in valid_sigs:
107 print("-", x, file=sys.stderr)
109 def text(parent):
110 text = ""
111 for node in parent.childNodes:
112 if node.nodeType == node.TEXT_NODE:
113 text = text + node.data
114 return text
116 shown = set()
117 key_info_fetchers = valid_sigs.values()
118 while key_info_fetchers:
119 old_kfs = key_info_fetchers
120 key_info_fetchers = []
121 for kf in old_kfs:
122 infos = set(kf.info) - shown
123 if infos:
124 if len(valid_sigs) > 1:
125 print("%s: " % kf.fingerprint)
126 for key_info in infos:
127 print("-", text(key_info), file=sys.stderr)
128 shown.add(key_info)
129 if kf.blocker:
130 key_info_fetchers.append(kf)
131 if key_info_fetchers:
132 for kf in key_info_fetchers: print(kf.status, file=sys.stderr)
133 stdin = tasks.InputBlocker(0, 'console')
134 blockers = [kf.blocker for kf in key_info_fetchers] + [stdin]
135 try:
136 yield blockers
137 except:
138 pass
139 for b in blockers:
140 try:
141 tasks.check(b)
142 except Exception as ex:
143 warn(_("Failed to get key info: %s"), ex)
144 if stdin.happened:
145 print(_("Skipping remaining key lookups due to input from user"), file=sys.stderr)
146 break
147 if not shown:
148 print(_("Warning: Nothing known about this key!"), file=sys.stderr)
150 if len(valid_sigs) == 1:
151 print(_("Do you want to trust this key to sign feeds from '%s'?") % domain, file=sys.stderr)
152 else:
153 print(_("Do you want to trust all of these keys to sign feeds from '%s'?") % domain, file=sys.stderr)
154 while True:
155 print(_("Trust [Y/N] "), end=' ', file=sys.stderr)
156 i = raw_input()
157 if not i: continue
158 if i in 'Nn':
159 raise NoTrustedKeys(_('Not signed with a trusted key'))
160 if i in 'Yy':
161 break
162 for key in valid_sigs:
163 print(_("Trusting %(key_fingerprint)s for %(domain)s") % {'key_fingerprint': key.fingerprint, 'domain': domain}, file=sys.stderr)
164 trust.trust_db.trust_key(key.fingerprint, domain)
166 @tasks.async
167 def confirm_install(self, msg):
168 """We need to check something with the user before continuing with the install.
169 @raise download.DownloadAborted: if the user cancels"""
170 yield
171 print(msg, file=sys.stderr)
172 while True:
173 sys.stderr.write(_("Install [Y/N] "))
174 i = raw_input()
175 if not i: continue
176 if i in 'Nn':
177 raise download.DownloadAborted()
178 if i in 'Yy':
179 break
181 def report_error(self, exception, tb = None):
182 """Report an exception to the user.
183 @param exception: the exception to report
184 @type exception: L{SafeException}
185 @param tb: optional traceback
186 @since: 0.25"""
187 warn("%s", str(exception) or type(exception))
188 #import traceback
189 #traceback.print_exception(exception, None, tb)
191 class ConsoleHandler(Handler):
192 """A Handler that displays progress on stdout (a tty).
193 @since: 0.44"""
194 last_msg_len = None
195 update = None
196 disable_progress = 0
197 screen_width = None
199 def downloads_changed(self):
200 import gobject
201 if self.monitored_downloads and self.update is None:
202 if self.screen_width is None:
203 try:
204 import curses
205 curses.setupterm()
206 self.screen_width = curses.tigetnum('cols') or 80
207 except Exception as ex:
208 info("Failed to initialise curses library: %s", ex)
209 self.screen_width = 80
210 self.show_progress()
211 self.update = gobject.timeout_add(200, self.show_progress)
212 elif len(self.monitored_downloads) == 0:
213 if self.update:
214 gobject.source_remove(self.update)
215 self.update = None
216 print()
217 self.last_msg_len = None
219 def show_progress(self):
220 if not self.monitored_downloads: return True
221 urls = [(dl.url, dl) for dl in self.monitored_downloads]
223 if self.disable_progress: return True
225 screen_width = self.screen_width - 2
226 item_width = max(16, screen_width / len(self.monitored_downloads))
227 url_width = item_width - 7
229 msg = ""
230 for url, dl in sorted(urls):
231 so_far = dl.get_bytes_downloaded_so_far()
232 leaf = url.rsplit('/', 1)[-1]
233 if len(leaf) >= url_width:
234 display = leaf[:url_width]
235 else:
236 display = url[-url_width:]
237 if dl.expected_size:
238 msg += "[%s %d%%] " % (display, int(so_far * 100 / dl.expected_size))
239 else:
240 msg += "[%s] " % (display)
241 msg = msg[:screen_width]
243 if self.last_msg_len is None:
244 sys.stdout.write(msg)
245 else:
246 sys.stdout.write(chr(13) + msg)
247 if len(msg) < self.last_msg_len:
248 sys.stdout.write(" " * (self.last_msg_len - len(msg)))
250 self.last_msg_len = len(msg)
251 sys.stdout.flush()
253 return True
255 def clear_display(self):
256 if self.last_msg_len != None:
257 sys.stdout.write(chr(13) + " " * self.last_msg_len + chr(13))
258 sys.stdout.flush()
259 self.last_msg_len = None
261 def report_error(self, exception, tb = None):
262 self.clear_display()
263 Handler.report_error(self, exception, tb)
265 def confirm_import_feed(self, pending, valid_sigs):
266 self.clear_display()
267 self.disable_progress += 1
268 blocker = Handler.confirm_import_feed(self, pending, valid_sigs)
269 @tasks.async
270 def enable():
271 try:
272 yield blocker
273 except:
274 pass
275 self.disable_progress -= 1
276 self.show_progress()
277 enable()
278 return blocker