Bug 1821144 - remove old windows worker definitions. r=aryx
[gecko.git] / third_party / rust / neqo-crypto / src / ext.rs
blob5c6dc1c8ff7b7523c9dc8a188cda377f169add33
1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
7 use crate::agentio::as_c_void;
8 use crate::constants::{
9     Extension, HandshakeMessage, TLS_HS_CLIENT_HELLO, TLS_HS_ENCRYPTED_EXTENSIONS,
11 use crate::err::Res;
12 use crate::ssl::{
13     PRBool, PRFileDesc, SECFailure, SECStatus, SECSuccess, SSLAlertDescription,
14     SSLExtensionHandler, SSLExtensionWriter, SSLHandshakeType,
17 use std::cell::RefCell;
18 use std::convert::TryFrom;
19 use std::os::raw::{c_uint, c_void};
20 use std::pin::Pin;
21 use std::rc::Rc;
23 experimental_api!(SSL_InstallExtensionHooks(
24     fd: *mut PRFileDesc,
25     extension: u16,
26     writer: SSLExtensionWriter,
27     writer_arg: *mut c_void,
28     handler: SSLExtensionHandler,
29     handler_arg: *mut c_void,
30 ));
32 pub enum ExtensionWriterResult {
33     Write(usize),
34     Skip,
37 pub enum ExtensionHandlerResult {
38     Ok,
39     Alert(crate::constants::Alert),
42 pub trait ExtensionHandler {
43     fn write(&mut self, msg: HandshakeMessage, _d: &mut [u8]) -> ExtensionWriterResult {
44         match msg {
45             TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionWriterResult::Write(0),
46             _ => ExtensionWriterResult::Skip,
47         }
48     }
50     fn handle(&mut self, msg: HandshakeMessage, _d: &[u8]) -> ExtensionHandlerResult {
51         match msg {
52             TLS_HS_CLIENT_HELLO | TLS_HS_ENCRYPTED_EXTENSIONS => ExtensionHandlerResult::Ok,
53             _ => ExtensionHandlerResult::Alert(110), // unsupported_extension
54         }
55     }
58 type BoxedExtensionHandler = Box<Rc<RefCell<dyn ExtensionHandler>>>;
60 pub struct ExtensionTracker {
61     extension: Extension,
62     handler: Pin<Box<BoxedExtensionHandler>>,
65 impl ExtensionTracker {
66     // Technically the as_mut() call here is the only unsafe bit,
67     // but don't call this function lightly.
68     unsafe fn wrap_handler_call<F, T>(arg: *mut c_void, f: F) -> T
69     where
70         F: FnOnce(&mut dyn ExtensionHandler) -> T,
71     {
72         let rc = arg.cast::<BoxedExtensionHandler>().as_mut().unwrap();
73         f(&mut *rc.borrow_mut())
74     }
76     #[allow(clippy::cast_possible_truncation)]
77     unsafe extern "C" fn extension_writer(
78         _fd: *mut PRFileDesc,
79         message: SSLHandshakeType::Type,
80         data: *mut u8,
81         len: *mut c_uint,
82         max_len: c_uint,
83         arg: *mut c_void,
84     ) -> PRBool {
85         let d = std::slice::from_raw_parts_mut(data, max_len as usize);
86         Self::wrap_handler_call(arg, |handler| {
87             // Cast is safe here because the message type is always part of the enum
88             match handler.write(message as HandshakeMessage, d) {
89                 ExtensionWriterResult::Write(sz) => {
90                     *len = c_uint::try_from(sz).expect("integer overflow from extension writer");
91                     1
92                 }
93                 ExtensionWriterResult::Skip => 0,
94             }
95         })
96     }
98     unsafe extern "C" fn extension_handler(
99         _fd: *mut PRFileDesc,
100         message: SSLHandshakeType::Type,
101         data: *const u8,
102         len: c_uint,
103         alert: *mut SSLAlertDescription,
104         arg: *mut c_void,
105     ) -> SECStatus {
106         let d = std::slice::from_raw_parts(data, len as usize);
107         #[allow(clippy::cast_possible_truncation)]
108         Self::wrap_handler_call(arg, |handler| {
109             // Cast is safe here because the message type is always part of the enum
110             match handler.handle(message as HandshakeMessage, d) {
111                 ExtensionHandlerResult::Ok => SECSuccess,
112                 ExtensionHandlerResult::Alert(a) => {
113                     *alert = a;
114                     SECFailure
115                 }
116             }
117         })
118     }
120     /// Use the provided handler to manage an extension.  This is quite unsafe.
121     ///
122     /// # Safety
123     /// The holder of this `ExtensionTracker` needs to ensure that it lives at
124     /// least as long as the file descriptor, as NSS provides no way to remove
125     /// an extension handler once it is configured.
126     ///
127     /// # Errors
128     /// If the underlying NSS API fails to register a handler.
129     pub unsafe fn new(
130         fd: *mut PRFileDesc,
131         extension: Extension,
132         handler: Rc<RefCell<dyn ExtensionHandler>>,
133     ) -> Res<Self> {
134         // The ergonomics here aren't great for users of this API, but it's
135         // horrific here. The pinned outer box gives us a stable pointer to the inner
136         // box.  This is the pointer that is passed to NSS.
137         //
138         // The inner box points to the reference-counted object.  This inner box is
139         // what we end up with a reference to in callbacks.  That extra wrapper around
140         // the Rc avoid any touching of reference counts in callbacks, which would
141         // inevitably lead to leaks as we don't control how many times the callback
142         // is invoked.
143         //
144         // This way, only this "outer" code deals with the reference count.
145         let mut tracker = Self {
146             extension,
147             handler: Box::pin(Box::new(handler)),
148         };
149         SSL_InstallExtensionHooks(
150             fd,
151             extension,
152             Some(Self::extension_writer),
153             as_c_void(&mut tracker.handler),
154             Some(Self::extension_handler),
155             as_c_void(&mut tracker.handler),
156         )?;
157         Ok(tracker)
158     }
161 impl std::fmt::Debug for ExtensionTracker {
162     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
163         write!(f, "ExtensionTracker: {:?}", self.extension)
164     }