Fixed missing cast to Unicode
[zeroinstall/solver.git] / zeroinstall / injector / handler.py
blobabea2fbfac882b4dd7ee9596ce520c289695bc30
1 """
2 Integrates download callbacks with an external mainloop.
3 While things are being downloaded, Zero Install returns control to your program.
4 Your mainloop is responsible for monitoring the state of the downloads and notifying
5 Zero Install when they are complete.
7 To do this, you supply a L{Handler} to the L{policy}.
8 """
10 # Copyright (C) 2009, Thomas Leonard
11 # See the README file for details, or visit http://0install.net.
13 from zeroinstall import _
14 import sys
15 from logging import debug, warn, info
17 from zeroinstall import NeedDownload, SafeException
18 from zeroinstall.support import tasks
19 from zeroinstall.injector import download
class NoTrustedKeys(SafeException):
    """Raised when the user declines to trust any of the keys that signed a feed.

    Thrown by L{Handler.confirm_trust_keys} on failure."""
25 class Handler(object):
26 """
27 This implementation uses the GLib mainloop. Note that QT4 can use the GLib mainloop too.
29 @ivar monitored_downloads: dict of downloads in progress
30 @type monitored_downloads: {URL: L{download.Download}}
31 @ivar n_completed_downloads: number of downloads which have finished for GUIs, etc (can be reset as desired).
32 @type n_completed_downloads: int
33 @ivar total_bytes_downloaded: informational counter for GUIs, etc (can be reset as desired). Updated when download finishes.
34 @type total_bytes_downloaded: int
35 @ivar dry_run: instead of starting a download, just report what we would have downloaded
36 @type dry_run: bool
37 """
39 __slots__ = ['monitored_downloads', '_loop', 'dry_run', 'total_bytes_downloaded', 'n_completed_downloads', '_current_confirm']
41 def __init__(self, mainloop = None, dry_run = False):
42 self.monitored_downloads = {}
43 self._loop = None
44 self.dry_run = dry_run
45 self.n_completed_downloads = 0
46 self.total_bytes_downloaded = 0
47 self._current_confirm = None
49 def monitor_download(self, dl):
50 """Called when a new L{download} is started.
51 This is mainly used by the GUI to display the progress bar."""
52 dl.start()
53 self.monitored_downloads[dl.url] = dl
54 self.downloads_changed()
56 @tasks.async
57 def download_done_stats():
58 yield dl.downloaded
59 # NB: we don't check for exceptions here; someone else should be doing that
60 try:
61 self.n_completed_downloads += 1
62 self.total_bytes_downloaded += dl.get_bytes_downloaded_so_far()
63 del self.monitored_downloads[dl.url]
64 self.downloads_changed()
65 except Exception, ex:
66 self.report_error(ex)
67 download_done_stats()
69 def impl_added_to_store(self, impl):
70 """Called by the L{fetch.Fetcher} when adding an implementation.
71 The GUI uses this to update its display.
72 @param impl: the implementation which has been added
73 @type impl: L{model.Implementation}
74 """
75 pass
77 def downloads_changed(self):
78 """This is just for the GUI to override to update its display."""
79 pass
81 def wait_for_blocker(self, blocker):
82 """Run a recursive mainloop until blocker is triggered.
83 @param blocker: event to wait on
84 @type blocker: L{tasks.Blocker}"""
85 if not blocker.happened:
86 import gobject
88 def quitter():
89 yield blocker
90 self._loop.quit()
91 quit = tasks.Task(quitter(), "quitter")
93 assert self._loop is None # Avoid recursion
94 self._loop = gobject.MainLoop(gobject.main_context_default())
95 try:
96 debug(_("Entering mainloop, waiting for %s"), blocker)
97 self._loop.run()
98 finally:
99 self._loop = None
101 assert blocker.happened, "Someone quit the main loop!"
103 tasks.check(blocker)
105 def get_download(self, url, force = False, hint = None, factory = None):
106 """Return the Download object currently downloading 'url'.
107 If no download for this URL has been started, start one now (and
108 start monitoring it).
109 If the download failed and force is False, return it anyway.
110 If force is True, abort any current or failed download and start
111 a new one.
112 @rtype: L{download.Download}
114 if self.dry_run:
115 raise NeedDownload(url)
117 try:
118 dl = self.monitored_downloads[url]
119 if dl and force:
120 dl.abort()
121 raise KeyError
122 except KeyError:
123 if factory is None:
124 dl = download.Download(url, hint)
125 else:
126 dl = factory(url, hint)
127 self.monitor_download(dl)
128 return dl
130 def confirm_keys(self, pending, fetch_key_info):
131 """We don't trust any of the signatures yet. Ask the user.
132 When done update the L{trust} database, and then call L{trust.TrustDB.notify}.
133 This method just calls L{confirm_import_feed} if the handler (self) is
134 new-style, or L{confirm_trust_keys} for older classes. A class
135 is considered old-style if it overrides confirm_trust_keys and
136 not confirm_import_feed.
137 @since: 0.42
138 @arg pending: an object holding details of the updated feed
139 @type pending: L{PendingFeed}
140 @arg fetch_key_info: a function which can be used to fetch information about a key fingerprint
141 @type fetch_key_info: str -> L{Blocker}
142 @return: A blocker that triggers when the user has chosen, or None if already done.
143 @rtype: None | L{Blocker}"""
145 assert pending.sigs
147 if hasattr(self.confirm_trust_keys, 'original') or not hasattr(self.confirm_import_feed, 'original'):
148 # new-style class
149 from zeroinstall.injector import gpg
150 valid_sigs = [s for s in pending.sigs if isinstance(s, gpg.ValidSig)]
151 if not valid_sigs:
152 def format_sig(sig):
153 msg = str(sig)
154 if sig.messages:
155 msg += "\nMessages from GPG:\n" + sig.messages
156 return msg
157 raise SafeException(_('No valid signatures found on "%(url)s". Signatures:%(signatures)s') %
158 {'url': pending.url, 'signatures': ''.join(['\n- ' + format_sig(s) for s in pending.sigs])})
160 # Start downloading information about the keys...
161 kfs = {}
162 for sig in valid_sigs:
163 kfs[sig] = fetch_key_info(sig.fingerprint)
165 return self._queue_confirm_import_feed(pending, kfs)
166 else:
167 # old-style class
168 from zeroinstall.injector import iface_cache
169 import warnings
170 warnings.warn("Should override confirm_import_feed(); using old confirm_trust_keys() for now", DeprecationWarning, stacklevel = 2)
172 iface = iface_cache.iface_cache.get_interface(pending.url)
173 return self.confirm_trust_keys(iface, pending.sigs, pending.new_xml)
175 @tasks.async
176 def _queue_confirm_import_feed(self, pending, valid_sigs):
177 # If we're already confirming something else, wait for that to finish...
178 while self._current_confirm is not None:
179 yield self._current_confirm
181 # Check whether we still need to confirm. The user may have
182 # already approved one of the keys while dealing with another
183 # feed.
184 from zeroinstall.injector import trust
185 domain = trust.domain_from_url(pending.url)
186 for sig in valid_sigs:
187 is_trusted = trust.trust_db.is_trusted(sig.fingerprint, domain)
188 if is_trusted:
189 return
191 # Take the lock and confirm this feed
192 self._current_confirm = lock = tasks.Blocker('confirm key lock')
193 try:
194 done = self.confirm_import_feed(pending, valid_sigs)
195 if done is not None:
196 yield done
197 tasks.check(done)
198 finally:
199 self._current_confirm = None
200 lock.trigger()
202 @tasks.async
203 def confirm_import_feed(self, pending, valid_sigs):
204 """Sub-classes should override this method to interact with the user about new feeds.
205 If multiple feeds need confirmation, L{confirm_keys} will only invoke one instance of this
206 method at a time.
207 @param pending: the new feed to be imported
208 @type pending: L{PendingFeed}
209 @param valid_sigs: maps signatures to a list of fetchers collecting information about the key
210 @type valid_sigs: {L{gpg.ValidSig} : L{fetch.KeyInfoFetcher}}
211 @since: 0.42
212 @see: L{confirm_keys}"""
213 from zeroinstall.injector import trust
215 assert valid_sigs
217 domain = trust.domain_from_url(pending.url)
219 # Ask on stderr, because we may be writing XML to stdout
220 print >>sys.stderr, _("Feed: %s") % pending.url
221 print >>sys.stderr, _("The feed is correctly signed with the following keys:")
222 for x in valid_sigs:
223 print >>sys.stderr, "-", unicode(x).encode('ascii', 'xmlcharrefreplace')
225 def text(parent):
226 text = ""
227 for node in parent.childNodes:
228 if node.nodeType == node.TEXT_NODE:
229 text = text + node.data
230 return text
232 shown = set()
233 key_info_fetchers = valid_sigs.values()
234 while key_info_fetchers:
235 old_kfs = key_info_fetchers
236 key_info_fetchers = []
237 for kf in old_kfs:
238 infos = set(kf.info) - shown
239 if infos:
240 if len(valid_sigs) > 1:
241 print "%s: " % kf.fingerprint
242 for info in infos:
243 print >>sys.stderr, "-", text(info)
244 shown.add(info)
245 if kf.blocker:
246 key_info_fetchers.append(kf)
247 if key_info_fetchers:
248 for kf in key_info_fetchers: print >>sys.stderr, kf.status
249 #stdin = tasks.InputBlocker(0, 'console')
250 blockers = [kf.blocker for kf in key_info_fetchers] #+ [stdin]
251 yield blockers
252 for b in blockers:
253 try:
254 tasks.check(b)
255 except Exception, ex:
256 warn(_("Failed to get key info: %s"), ex)
257 #if stdin.happened:
258 # print >>sys.stderr, _("Skipping remaining key lookups due to input from user")
259 # break
261 if len(valid_sigs) == 1:
262 print >>sys.stderr, _("Do you want to trust this key to sign feeds from '%s'?") % domain
263 else:
264 print >>sys.stderr, _("Do you want to trust all of these keys to sign feeds from '%s'?") % domain
265 while True:
266 print >>sys.stderr, _("Trust [Y/N] ")
267 sys.stderr.flush()
268 i = raw_input()
269 if not i: continue
270 if i in 'Nn':
271 raise NoTrustedKeys(_('Not signed with a trusted key'))
272 if i in 'Yy':
273 break
274 for key in valid_sigs:
275 print >>sys.stderr, _("Trusting %(key_fingerprint)s for %(domain)s") % {'key_fingerprint': key.fingerprint, 'domain': domain}
276 trust.trust_db.trust_key(key.fingerprint, domain)
278 confirm_import_feed.original = True
280 def confirm_trust_keys(self, interface, sigs, iface_xml):
281 """We don't trust any of the signatures yet. Ask the user.
282 When done update the L{trust} database, and then call L{trust.TrustDB.notify}.
283 @deprecated: see L{confirm_keys}
284 @arg interface: the interface being updated
285 @arg sigs: a list of signatures (from L{gpg.check_stream})
286 @arg iface_xml: the downloaded data (not yet trusted)
287 @return: a blocker, if confirmation will happen asynchronously, or None
288 @rtype: L{tasks.Blocker}"""
289 import warnings
290 warnings.warn("Use confirm_keys, not confirm_trust_keys", DeprecationWarning, stacklevel = 2)
291 from zeroinstall.injector import trust, gpg
292 assert sigs
293 valid_sigs = [s for s in sigs if isinstance(s, gpg.ValidSig)]
294 if not valid_sigs:
295 raise SafeException('No valid signatures found on "%s". Signatures:%s' %
296 (interface.uri, ''.join(['\n- ' + str(s) for s in sigs])))
298 domain = trust.domain_from_url(interface.uri)
300 # Ask on stderr, because we may be writing XML to stdout
301 print >>sys.stderr, "\nInterface:", interface.uri
302 print >>sys.stderr, _("The feed is correctly signed with the following keys:")
303 for x in valid_sigs:
304 print >>sys.stderr, "-", unicode(x).encode('ascii', 'xmlcharrefreplace')
306 if len(valid_sigs) == 1:
307 print >>sys.stderr, _("Do you want to trust this key to sign feeds from '%s'?") % domain
308 else:
309 print >>sys.stderr, _("Do you want to trust all of these keys to sign feeds from '%s'?") % domain
310 while True:
311 print >>sys.stderr, _("Trust [Y/N] ")
312 sys.stderr.flush()
313 i = raw_input()
314 if not i: continue
315 if i in 'Nn':
316 raise NoTrustedKeys(_('Not signed with a trusted key'))
317 if i in 'Yy':
318 break
319 for key in valid_sigs:
320 print >>sys.stderr, _("Trusting %s for %s") % (key.fingerprint, domain)
321 trust.trust_db.trust_key(key.fingerprint, domain)
323 trust.trust_db.notify()
325 confirm_trust_keys.original = True # Detect if someone overrides it
327 @tasks.async
328 def confirm_install(self, msg):
329 """We need to check something with the user before continuing with the install.
330 @raise download.DownloadAborted: if the user cancels"""
331 yield
332 print >>sys.stderr, msg
333 while True:
334 sys.stderr.write(_("Install [Y/N] "))
335 sys.stderr.flush()
336 i = raw_input()
337 if not i: continue
338 if i in 'Nn':
339 raise download.DownloadAborted()
340 if i in 'Yy':
341 break
343 def report_error(self, exception, tb = None):
344 """Report an exception to the user.
345 @param exception: the exception to report
346 @type exception: L{SafeException}
347 @param tb: optional traceback
348 @since: 0.25"""
349 warn("%s", str(exception) or type(exception))
350 #import traceback
351 #traceback.print_exception(exception, None, tb)
353 class ConsoleHandler(Handler):
354 """A Handler that displays progress on stdout (a tty).
355 @since: 0.44"""
356 last_msg_len = None
357 update = None
358 disable_progress = 0
359 screen_width = None
361 def downloads_changed(self):
362 import gobject
363 if self.monitored_downloads and self.update is None:
364 if self.screen_width is None:
365 try:
366 import curses
367 curses.setupterm()
368 self.screen_width = curses.tigetnum('cols') or 80
369 except Exception, ex:
370 info("Failed to initialise curses library: %s", ex)
371 self.screen_width = 80
372 self.show_progress()
373 self.update = gobject.timeout_add(200, self.show_progress)
374 elif len(self.monitored_downloads) == 0:
375 if self.update:
376 gobject.source_remove(self.update)
377 self.update = None
378 print
379 self.last_msg_len = None
381 def show_progress(self):
382 urls = self.monitored_downloads.keys()
383 if not urls: return True
385 if self.disable_progress: return True
387 screen_width = self.screen_width - 2
388 item_width = max(16, screen_width / len(self.monitored_downloads))
389 url_width = item_width - 7
391 msg = ""
392 for url in sorted(urls):
393 dl = self.monitored_downloads[url]
394 so_far = dl.get_bytes_downloaded_so_far()
395 leaf = url.rsplit('/', 1)[-1]
396 if len(leaf) >= url_width:
397 display = leaf[:url_width]
398 else:
399 display = url[-url_width:]
400 if dl.expected_size:
401 msg += "[%s %d%%] " % (display, int(so_far * 100 / dl.expected_size))
402 else:
403 msg += "[%s] " % (display)
404 msg = msg[:screen_width]
406 if self.last_msg_len is None:
407 sys.stdout.write(msg)
408 else:
409 sys.stdout.write(chr(13) + msg)
410 if len(msg) < self.last_msg_len:
411 sys.stdout.write(" " * (self.last_msg_len - len(msg)))
413 self.last_msg_len = len(msg)
414 sys.stdout.flush()
416 return True
418 def clear_display(self):
419 if self.last_msg_len != None:
420 sys.stdout.write(chr(13) + " " * self.last_msg_len + chr(13))
421 sys.stdout.flush()
422 self.last_msg_len = None
424 def report_error(self, exception, tb = None):
425 self.clear_display()
426 Handler.report_error(self, exception, tb)
428 def confirm_import_feed(self, pending, valid_sigs):
429 self.clear_display()
430 self.disable_progress += 1
431 blocker = Handler.confirm_import_feed(self, pending, valid_sigs)
432 @tasks.async
433 def enable():
434 yield blocker
435 self.disable_progress -= 1
436 self.show_progress()
437 enable()
438 return blocker
class BatchHandler(Handler):
    """A Handler that writes easily parseable data to stderr.
    Each question to the user is preceded by a "QUESTION:" line."""

    def confirm_import_feed(self, pending, valid_sigs):
        """Emit the marker line, then ask as L{Handler.confirm_import_feed} does."""
        print >>sys.stderr, "QUESTION:"
        result = Handler.confirm_import_feed(self, pending, valid_sigs)
        return result

    def confirm_trust_keys(self, interface, sigs, iface_xml):
        """Emit the marker line, then ask as L{Handler.confirm_trust_keys} does."""
        print >>sys.stderr, "QUESTION:"
        result = Handler.confirm_trust_keys(self, interface, sigs, iface_xml)
        return result