1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
7 use crate::agentio::as_c_void;
8 use crate::constants::{
9 Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS,
13 PRBool, PRFileDesc, SECFailure, SECStatus, SECSuccess, SSLAlertDescription,
14 SSLExtensionHandler, SSLExtensionWriter, SSLHandshakeType,
17 use std::cell::RefCell;
18 use std::convert::TryFrom;
19 use std::os::raw::{c_uint, c_void};
23 experimental_api!(SSL_InstallExtensionHooks(
26 writer: SSLExtensionWriter,
27 writer_arg: *mut c_void,
28 handler: SSLExtensionHandler,
29 handler_arg: *mut c_void,
32 pub enum ExtensionWriterResult {
37 pub enum ExtensionHandlerResult {
39 Alert(crate::constants::Alert),
42 pub trait ExtensionHandler {
43 fn write(&mut self, msg: HandshakeMessage, _d: &mut [u8]) -> ExtensionWriterResult {
45 TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionWriterResult::Write(0),
46 _ => ExtensionWriterResult::Skip,
50 fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult {
52 TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionHandlerResult::Ok,
53 _ => ExtensionHandlerResult::Alert(110), // unsupported_extension
58 type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>;
60 pub struct ExtensionTracker {
62 handler: Pin<Box<BoxedExtensionHandler>>,
65 impl ExtensionTracker {
66 // Technically the as_mut() call here is the only unsafe bit,
67 // but don't call this function lightly.
68 unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T
70 F: FnOnce(&mut dyn ExtensionHandler) -> T,
72 let rc = arg.cast::<BoxedExtensionHandler>().as_mut().unwrap();
73 f(&mut *rc.borrow_mut())
76 #[allow(clippy::cast_possible_truncation)]
77 unsafe extern "C" fn extension_writer(
79 message: SSLHandshakeType::Type,
85 let d = std::slice::from_raw_parts_mut(data, max_len as usize);
86 Self::wrap_handler_call(arg, |handler| {
87 // Cast is safe here because the message type is always part of the enum
88 match handler.write(message as HandshakeMessage, d) {
89 ExtensionWriterResult::Write(sz) => {
90 *len = c_uint::try_from(sz).expect("integer overflow from extension writer");
93 ExtensionWriterResult::Skip => 0,
98 unsafe extern "C" fn extension_handler(
100 message: SSLHandshakeType::Type,
103 alert: *mut SSLAlertDescription,
106 let d = std::slice::from_raw_parts(data, len as usize);
107 #[allow(clippy::cast_possible_truncation)]
108 Self::wrap_handler_call(arg, |handler| {
109 // Cast is safe here because the message type is always part of the enum
110 match handler.handle(message as HandshakeMessage, d) {
111 ExtensionHandlerResult::Ok => SECSuccess,
112 ExtensionHandlerResult::Alert(a) => {
120 /// Use the provided handler to manage an extension. This is quite unsafe.
123 /// The holder of this `ExtensionTracker` needs to ensure that it lives at
124 /// least as long as the file descriptor, as NSS provides no way to remove
125 /// an extension handler once it is configured.
128 /// If the underlying NSS API fails to register a handler.
131 extension: Extension,
132 handler: Rc<RefCell<dyn ExtensionHandler>>,
134 // The ergonomics here aren't great for users of this API, but it's
135 // horrific here. The pinned outer box gives us a stable pointer to the inner
136 // box. This is the pointer that is passed to NSS.
138 // The inner box points to the reference-counted object. This inner box is
139 // what we end up with a reference to in callbacks. That extra wrapper around
// the Rc avoids any touching of reference counts in callbacks, which would
141 // inevitably lead to leaks as we don't control how many times the callback
144 // This way, only this "outer" code deals with the reference count.
145 let mut tracker = Self {
147 handler: Box::pin(Box::new(handler)),
149 SSL_InstallExtensionHooks(
152 Some(Self::extension_writer),
153 as_c_void(&mut tracker.handler),
154 Some(Self::extension_handler),
155 as_c_void(&mut tracker.handler),
161 impl std::fmt::Debug for ExtensionTracker {
162 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
163 write!(f, "ExtensionTracker: {:?}", self.extension)