Backed out changeset 317994df7ee4 (bug 1889691) for causing dt failures @ browser_web...
[gecko.git] / third_party / rust / neqo-transport / src / packet / mod.rs
blob8458f69779dd22346969c533379e36b6376141e3
1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
7 // Encoding and decoding packets off the wire.
8 use std::{
9     cmp::min,
10     fmt,
11     ops::{Deref, DerefMut, Range},
12     time::Instant,
15 use neqo_common::{hex, hex_with_len, qtrace, qwarn, Decoder, Encoder};
16 use neqo_crypto::random;
18 use crate::{
19     cid::{ConnectionId, ConnectionIdDecoder, ConnectionIdRef, MAX_CONNECTION_ID_LEN},
20     crypto::{CryptoDxState, CryptoSpace, CryptoStates},
21     version::{Version, WireVersion},
22     Error, Res,
25 pub const PACKET_BIT_LONG: u8 = 0x80;
26 const PACKET_BIT_SHORT: u8 = 0x00;
27 const PACKET_BIT_FIXED_QUIC: u8 = 0x40;
28 const PACKET_BIT_SPIN: u8 = 0x20;
29 const PACKET_BIT_KEY_PHASE: u8 = 0x04;
31 const PACKET_HP_MASK_LONG: u8 = 0x0f;
32 const PACKET_HP_MASK_SHORT: u8 = 0x1f;
34 const SAMPLE_SIZE: usize = 16;
35 const SAMPLE_OFFSET: usize = 4;
36 const MAX_PACKET_NUMBER_LEN: usize = 4;
38 mod retry;
40 pub type PacketNumber = u64;
42 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
43 pub enum PacketType {
44     VersionNegotiation,
45     Initial,
46     Handshake,
47     ZeroRtt,
48     Retry,
49     Short,
50     OtherVersion,
53 impl PacketType {
54     #[must_use]
55     fn from_byte(t: u8, v: Version) -> Self {
56         // Version2 adds one to the type, modulo 4
57         match t.wrapping_sub(u8::from(v == Version::Version2)) & 3 {
58             0 => Self::Initial,
59             1 => Self::ZeroRtt,
60             2 => Self::Handshake,
61             3 => Self::Retry,
62             _ => panic!("packet type out of range"),
63         }
64     }
66     #[must_use]
67     fn to_byte(self, v: Version) -> u8 {
68         let t = match self {
69             Self::Initial => 0,
70             Self::ZeroRtt => 1,
71             Self::Handshake => 2,
72             Self::Retry => 3,
73             _ => panic!("not a long header packet type"),
74         };
75         // Version2 adds one to the type, modulo 4
76         (t + u8::from(v == Version::Version2)) & 3
77     }
80 impl From<PacketType> for CryptoSpace {
81     fn from(v: PacketType) -> Self {
82         match v {
83             PacketType::Initial => Self::Initial,
84             PacketType::ZeroRtt => Self::ZeroRtt,
85             PacketType::Handshake => Self::Handshake,
86             PacketType::Short => Self::ApplicationData,
87             _ => panic!("shouldn't be here"),
88         }
89     }
92 impl From<CryptoSpace> for PacketType {
93     fn from(cs: CryptoSpace) -> Self {
94         match cs {
95             CryptoSpace::Initial => Self::Initial,
96             CryptoSpace::ZeroRtt => Self::ZeroRtt,
97             CryptoSpace::Handshake => Self::Handshake,
98             CryptoSpace::ApplicationData => Self::Short,
99         }
100     }
103 struct PacketBuilderOffsets {
104     /// The bits of the first octet that need masking.
105     first_byte_mask: u8,
106     /// The offset of the length field.
107     len: usize,
108     /// The location of the packet number field.
109     pn: Range<usize>,
112 /// A packet builder that can be used to produce short packets and long packets.
113 /// This does not produce Retry or Version Negotiation.
114 pub struct PacketBuilder {
115     encoder: Encoder,
116     pn: PacketNumber,
117     header: Range<usize>,
118     offsets: PacketBuilderOffsets,
119     limit: usize,
120     /// Whether to pad the packet before construction.
121     padding: bool,
124 impl PacketBuilder {
125     /// The minimum useful frame size.  If space is less than this, we will claim to be full.
126     pub const MINIMUM_FRAME_SIZE: usize = 2;
128     fn infer_limit(encoder: &Encoder) -> usize {
129         if encoder.capacity() > 64 {
130             encoder.capacity()
131         } else {
132             2048
133         }
134     }
136     /// Start building a short header packet.
137     ///
138     /// This doesn't fail if there isn't enough space; instead it returns a builder that
139     /// has no available space left.  This allows the caller to extract the encoder
140     /// and any packets that might have been added before as adding a packet header is
141     /// only likely to fail if there are other packets already written.
142     ///
143     /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get
144     /// the encoder back.
145     #[allow(clippy::reversed_empty_ranges)]
146     pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self {
147         let mut limit = Self::infer_limit(&encoder);
148         let header_start = encoder.len();
149         // Check that there is enough space for the header.
150         // 5 = 1 (first byte) + 4 (packet number)
151         if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() {
152             encoder
153                 .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2));
154             encoder.encode(dcid.as_ref());
155         } else {
156             limit = 0;
157         }
158         Self {
159             encoder,
160             pn: u64::max_value(),
161             header: header_start..header_start,
162             offsets: PacketBuilderOffsets {
163                 first_byte_mask: PACKET_HP_MASK_SHORT,
164                 pn: 0..0,
165                 len: 0,
166             },
167             limit,
168             padding: false,
169         }
170     }
172     /// Start building a long header packet.
173     /// For an Initial packet you will need to call `initial_token()`,
174     /// even if the token is empty.
175     ///
176     /// See `short()` for more on how to handle this in cases where there is no space.
177     #[allow(clippy::reversed_empty_ranges)] // For initializing an empty range.
178     #[allow(clippy::similar_names)] // For dcid and scid, which are fine here.
179     pub fn long(
180         mut encoder: Encoder,
181         pt: PacketType,
182         version: Version,
183         dcid: impl AsRef<[u8]>,
184         scid: impl AsRef<[u8]>,
185     ) -> Self {
186         let mut limit = Self::infer_limit(&encoder);
187         let header_start = encoder.len();
188         // Check that there is enough space for the header.
189         // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number)
190         if limit > encoder.len()
191             && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len()
192         {
193             encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4);
194             encoder.encode_uint(4, version.wire_version());
195             encoder.encode_vec(1, dcid.as_ref());
196             encoder.encode_vec(1, scid.as_ref());
197         } else {
198             limit = 0;
199         }
201         Self {
202             encoder,
203             pn: u64::max_value(),
204             header: header_start..header_start,
205             offsets: PacketBuilderOffsets {
206                 first_byte_mask: PACKET_HP_MASK_LONG,
207                 pn: 0..0,
208                 len: 0,
209             },
210             limit,
211             padding: false,
212         }
213     }
215     fn is_long(&self) -> bool {
216         self.as_ref()[self.header.start] & 0x80 == PACKET_BIT_LONG
217     }
219     /// This stores a value that can be used as a limit.  This does not cause
220     /// this limit to be enforced until encryption occurs.  Prior to that, it
221     /// is only used voluntarily by users of the builder, through `remaining()`.
222     pub fn set_limit(&mut self, limit: usize) {
223         self.limit = limit;
224     }
226     /// Get the current limit.
227     #[must_use]
228     pub fn limit(&mut self) -> usize {
229         self.limit
230     }
232     /// How many bytes remain against the size limit for the builder.
233     #[must_use]
234     pub fn remaining(&self) -> usize {
235         self.limit.saturating_sub(self.encoder.len())
236     }
238     /// Returns true if the packet has no more space for frames.
239     #[must_use]
240     pub fn is_full(&self) -> bool {
241         // No useful frame is smaller than 2 bytes long.
242         self.limit < self.encoder.len() + Self::MINIMUM_FRAME_SIZE
243     }
245     /// Adjust the limit to ensure that no more data is added.
246     pub fn mark_full(&mut self) {
247         self.limit = self.encoder.len();
248     }
250     /// Mark the packet as needing padding (or not).
251     pub fn enable_padding(&mut self, needs_padding: bool) {
252         self.padding = needs_padding;
253     }
255     /// Maybe pad with "PADDING" frames.
256     /// Only does so if padding was needed and this is a short packet.
257     /// Returns true if padding was added.
258     pub fn pad(&mut self) -> bool {
259         if self.padding && !self.is_long() {
260             self.encoder.pad_to(self.limit, 0);
261             true
262         } else {
263             false
264         }
265     }
267     /// Add unpredictable values for unprotected parts of the packet.
268     pub fn scramble(&mut self, quic_bit: bool) {
269         debug_assert!(self.len() > self.header.start);
270         let mask = if quic_bit { PACKET_BIT_FIXED_QUIC } else { 0 }
271             | if self.is_long() { 0 } else { PACKET_BIT_SPIN };
272         let first = self.header.start;
273         self.encoder.as_mut()[first] ^= random::<1>()[0] & mask;
274     }
276     /// For an Initial packet, encode the token.
277     /// If you fail to do this, then you will not get a valid packet.
278     pub fn initial_token(&mut self, token: &[u8]) {
279         if Encoder::vvec_len(token.len()) < self.remaining() {
280             self.encoder.encode_vvec(token);
281         } else {
282             self.limit = 0;
283         }
284     }
286     /// Add a packet number of the given size.
287     /// For a long header packet, this also inserts a dummy length.
288     /// The length is filled in after calling `build`.
289     /// Does nothing if there isn't 4 bytes available other than render this builder
290     /// unusable; if `remaining()` returns 0 at any point, call `abort()`.
291     pub fn pn(&mut self, pn: PacketNumber, pn_len: usize) {
292         if self.remaining() < 4 {
293             self.limit = 0;
294             return;
295         }
297         // Reserve space for a length in long headers.
298         if self.is_long() {
299             self.offsets.len = self.encoder.len();
300             self.encoder.encode(&[0; 2]);
301         }
303         // This allows the input to be >4, which is absurd, but we can eat that.
304         let pn_len = min(MAX_PACKET_NUMBER_LEN, pn_len);
305         debug_assert_ne!(pn_len, 0);
306         // Encode the packet number and save its offset.
307         let pn_offset = self.encoder.len();
308         self.encoder.encode_uint(pn_len, pn);
309         self.offsets.pn = pn_offset..self.encoder.len();
311         // Now encode the packet number length and save the header length.
312         self.encoder.as_mut()[self.header.start] |= u8::try_from(pn_len - 1).unwrap();
313         self.header.end = self.encoder.len();
314         self.pn = pn;
315     }
317     #[allow(clippy::cast_possible_truncation)] // Nope.
318     fn write_len(&mut self, expansion: usize) {
319         let len = self.encoder.len() - (self.offsets.len + 2) + expansion;
320         self.encoder.as_mut()[self.offsets.len] = 0x40 | ((len >> 8) & 0x3f) as u8;
321         self.encoder.as_mut()[self.offsets.len + 1] = (len & 0xff) as u8;
322     }
324     fn pad_for_crypto(&mut self, crypto: &mut CryptoDxState) {
325         // Make sure that there is enough data in the packet.
326         // The length of the packet number plus the payload length needs to
327         // be at least 4 (MAX_PACKET_NUMBER_LEN) plus any amount by which
328         // the header protection sample exceeds the AEAD expansion.
329         let crypto_pad = crypto.extra_padding();
330         self.encoder.pad_to(
331             self.offsets.pn.start + MAX_PACKET_NUMBER_LEN + crypto_pad,
332             0,
333         );
334     }
336     /// A lot of frames here are just a collection of varints.
337     /// This helper functions writes a frame like that safely, returning `true` if
338     /// a frame was written.
339     pub fn write_varint_frame(&mut self, values: &[u64]) -> bool {
340         let write = self.remaining()
341             >= values
342                 .iter()
343                 .map(|&v| Encoder::varint_len(v))
344                 .sum::<usize>();
345         if write {
346             for v in values {
347                 self.encode_varint(*v);
348             }
349             debug_assert!(self.len() <= self.limit());
350         };
351         write
352     }
354     /// Build the packet and return the encoder.
355     pub fn build(mut self, crypto: &mut CryptoDxState) -> Res<Encoder> {
356         if self.len() > self.limit {
357             qwarn!("Packet contents are more than the limit");
358             debug_assert!(false);
359             return Err(Error::InternalError);
360         }
362         self.pad_for_crypto(crypto);
363         if self.offsets.len > 0 {
364             self.write_len(crypto.expansion());
365         }
367         let hdr = &self.encoder.as_ref()[self.header.clone()];
368         let body = &self.encoder.as_ref()[self.header.end..];
369         qtrace!(
370             "Packet build pn={} hdr={} body={}",
371             self.pn,
372             hex(hdr),
373             hex(body)
374         );
375         let ciphertext = crypto.encrypt(self.pn, hdr, body)?;
377         // Calculate the mask.
378         let offset = SAMPLE_OFFSET - self.offsets.pn.len();
379         assert!(offset + SAMPLE_SIZE <= ciphertext.len());
380         let sample = &ciphertext[offset..offset + SAMPLE_SIZE];
381         let mask = crypto.compute_mask(sample)?;
383         // Apply the mask.
384         self.encoder.as_mut()[self.header.start] ^= mask[0] & self.offsets.first_byte_mask;
385         for (i, j) in (1..=self.offsets.pn.len()).zip(self.offsets.pn) {
386             self.encoder.as_mut()[j] ^= mask[i];
387         }
389         // Finally, cut off the plaintext and add back the ciphertext.
390         self.encoder.truncate(self.header.end);
391         self.encoder.encode(&ciphertext);
392         qtrace!("Packet built {}", hex(&self.encoder));
393         Ok(self.encoder)
394     }
396     /// Abort writing of this packet and return the encoder.
397     #[must_use]
398     pub fn abort(mut self) -> Encoder {
399         self.encoder.truncate(self.header.start);
400         self.encoder
401     }
403     /// Work out if nothing was added after the header.
404     #[must_use]
405     pub fn packet_empty(&self) -> bool {
406         self.encoder.len() == self.header.end
407     }
409     /// Make a retry packet.
410     /// As this is a simple packet, this is just an associated function.
411     /// As Retry is odd (it has to be constructed with leading bytes),
412     /// this returns a [`Vec<u8>`] rather than building on an encoder.
413     #[allow(clippy::similar_names)] // scid and dcid are fine here.
414     pub fn retry(
415         version: Version,
416         dcid: &[u8],
417         scid: &[u8],
418         token: &[u8],
419         odcid: &[u8],
420     ) -> Res<Vec<u8>> {
421         let mut encoder = Encoder::default();
422         encoder.encode_vec(1, odcid);
423         let start = encoder.len();
424         encoder.encode_byte(
425             PACKET_BIT_LONG
426                 | PACKET_BIT_FIXED_QUIC
427                 | (PacketType::Retry.to_byte(version) << 4)
428                 | (random::<1>()[0] & 0xf),
429         );
430         encoder.encode_uint(4, version.wire_version());
431         encoder.encode_vec(1, dcid);
432         encoder.encode_vec(1, scid);
433         debug_assert_ne!(token.len(), 0);
434         encoder.encode(token);
435         let tag = retry::use_aead(version, |aead| {
436             let mut buf = vec![0; aead.expansion()];
437             Ok(aead.encrypt(0, encoder.as_ref(), &[], &mut buf)?.to_vec())
438         })?;
439         encoder.encode(&tag);
440         let mut complete: Vec<u8> = encoder.into();
441         Ok(complete.split_off(start))
442     }
444     /// Make a Version Negotiation packet.
445     #[allow(clippy::similar_names)] // scid and dcid are fine here.
446     pub fn version_negotiation(
447         dcid: &[u8],
448         scid: &[u8],
449         client_version: u32,
450         versions: &[Version],
451     ) -> Vec<u8> {
452         let mut encoder = Encoder::default();
453         let mut grease = random::<4>();
454         // This will not include the "QUIC bit" sometimes.  Intentionally.
455         encoder.encode_byte(PACKET_BIT_LONG | (grease[3] & 0x7f));
456         encoder.encode(&[0; 4]); // Zero version == VN.
457         encoder.encode_vec(1, dcid);
458         encoder.encode_vec(1, scid);
460         for v in versions {
461             encoder.encode_uint(4, v.wire_version());
462         }
463         // Add a greased version, using the randomness already generated.
464         for g in &mut grease[..3] {
465             *g = *g & 0xf0 | 0x0a;
466         }
468         // Ensure our greased version does not collide with the client version
469         // by making the last byte differ from the client initial.
470         grease[3] = (client_version.wrapping_add(0x10) & 0xf0) as u8 | 0x0a;
471         encoder.encode(&grease[..4]);
473         Vec::from(encoder)
474     }
477 impl Deref for PacketBuilder {
478     type Target = Encoder;
480     fn deref(&self) -> &Self::Target {
481         &self.encoder
482     }
485 impl DerefMut for PacketBuilder {
486     fn deref_mut(&mut self) -> &mut Self::Target {
487         &mut self.encoder
488     }
491 impl From<PacketBuilder> for Encoder {
492     fn from(v: PacketBuilder) -> Self {
493         v.encoder
494     }
497 /// `PublicPacket` holds information from packets that is public only.  This allows for
498 /// processing of packets prior to decryption.
499 pub struct PublicPacket<'a> {
500     /// The packet type.
501     packet_type: PacketType,
502     /// The recovered destination connection ID.
503     dcid: ConnectionIdRef<'a>,
504     /// The source connection ID, if this is a long header packet.
505     scid: Option<ConnectionIdRef<'a>>,
506     /// Any token that is included in the packet (Retry always has a token; Initial sometimes
507     /// does). This is empty when there is no token.
508     token: &'a [u8],
509     /// The size of the header, not including the packet number.
510     header_len: usize,
511     /// Protocol version, if present in header.
512     version: Option<WireVersion>,
513     /// A reference to the entire packet, including the header.
514     data: &'a [u8],
517 impl<'a> PublicPacket<'a> {
518     fn opt<T>(v: Option<T>) -> Res<T> {
519         if let Some(v) = v {
520             Ok(v)
521         } else {
522             Err(Error::NoMoreData)
523         }
524     }
526     /// Decode the type-specific portions of a long header.
527     /// This includes reading the length and the remainder of the packet.
528     /// Returns a tuple of any token and the length of the header.
529     fn decode_long(
530         decoder: &mut Decoder<'a>,
531         packet_type: PacketType,
532         version: Version,
533     ) -> Res<(&'a [u8], usize)> {
534         if packet_type == PacketType::Retry {
535             let header_len = decoder.offset();
536             let expansion = retry::expansion(version);
537             let token = Self::opt(decoder.decode(decoder.remaining() - expansion))?;
538             if token.is_empty() {
539                 return Err(Error::InvalidPacket);
540             }
541             Self::opt(decoder.decode(expansion))?;
542             return Ok((token, header_len));
543         }
544         let token = if packet_type == PacketType::Initial {
545             Self::opt(decoder.decode_vvec())?
546         } else {
547             &[]
548         };
549         let len = Self::opt(decoder.decode_varint())?;
550         let header_len = decoder.offset();
551         let _body = Self::opt(decoder.decode(usize::try_from(len)?))?;
552         Ok((token, header_len))
553     }
555     /// Decode the common parts of a packet.  This provides minimal parsing and validation.
556     /// Returns a tuple of a `PublicPacket` and a slice with any remainder from the datagram.
557     #[allow(clippy::similar_names)] // For dcid and scid, which are fine.
558     pub fn decode(data: &'a [u8], dcid_decoder: &dyn ConnectionIdDecoder) -> Res<(Self, &'a [u8])> {
559         let mut decoder = Decoder::new(data);
560         let first = Self::opt(decoder.decode_byte())?;
562         if first & 0x80 == PACKET_BIT_SHORT {
563             // Conveniently, this also guarantees that there is enough space
564             // for a connection ID of any size.
565             if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
566                 return Err(Error::InvalidPacket);
567             }
568             let dcid = Self::opt(dcid_decoder.decode_cid(&mut decoder))?;
569             if decoder.remaining() < SAMPLE_OFFSET + SAMPLE_SIZE {
570                 return Err(Error::InvalidPacket);
571             }
572             let header_len = decoder.offset();
573             return Ok((
574                 Self {
575                     packet_type: PacketType::Short,
576                     dcid,
577                     scid: None,
578                     token: &[],
579                     header_len,
580                     version: None,
581                     data,
582                 },
583                 &[],
584             ));
585         }
587         // Generic long header.
588         let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?).unwrap();
589         let dcid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
590         let scid = ConnectionIdRef::from(Self::opt(decoder.decode_vec(1))?);
592         // Version negotiation.
593         if version == 0 {
594             return Ok((
595                 Self {
596                     packet_type: PacketType::VersionNegotiation,
597                     dcid,
598                     scid: Some(scid),
599                     token: &[],
600                     header_len: decoder.offset(),
601                     version: None,
602                     data,
603                 },
604                 &[],
605             ));
606         }
608         // Check that this is a long header from a supported version.
609         let Ok(version) = Version::try_from(version) else {
610             return Ok((
611                 Self {
612                     packet_type: PacketType::OtherVersion,
613                     dcid,
614                     scid: Some(scid),
615                     token: &[],
616                     header_len: decoder.offset(),
617                     version: Some(version),
618                     data,
619                 },
620                 &[],
621             ));
622         };
624         if dcid.len() > MAX_CONNECTION_ID_LEN || scid.len() > MAX_CONNECTION_ID_LEN {
625             return Err(Error::InvalidPacket);
626         }
627         let packet_type = PacketType::from_byte((first >> 4) & 3, version);
629         // The type-specific code includes a token.  This consumes the remainder of the packet.
630         let (token, header_len) = Self::decode_long(&mut decoder, packet_type, version)?;
631         let end = data.len() - decoder.remaining();
632         let (data, remainder) = data.split_at(end);
633         Ok((
634             Self {
635                 packet_type,
636                 dcid,
637                 scid: Some(scid),
638                 token,
639                 header_len,
640                 version: Some(version.wire_version()),
641                 data,
642             },
643             remainder,
644         ))
645     }
647     /// Validate the given packet as though it were a retry.
648     pub fn is_valid_retry(&self, odcid: &ConnectionId) -> bool {
649         if self.packet_type != PacketType::Retry {
650             return false;
651         }
652         let version = self.version().unwrap();
653         let expansion = retry::expansion(version);
654         if self.data.len() <= expansion {
655             return false;
656         }
657         let (header, tag) = self.data.split_at(self.data.len() - expansion);
658         let mut encoder = Encoder::with_capacity(self.data.len());
659         encoder.encode_vec(1, odcid);
660         encoder.encode(header);
661         retry::use_aead(version, |aead| {
662             let mut buf = vec![0; expansion];
663             Ok(aead.decrypt(0, encoder.as_ref(), tag, &mut buf)?.is_empty())
664         })
665         .unwrap_or(false)
666     }
668     pub fn is_valid_initial(&self) -> bool {
669         // Packet has to be an initial, with a DCID of 8 bytes, or a token.
670         // Note: the Server class validates the token and checks the length.
671         self.packet_type == PacketType::Initial
672             && (self.dcid().len() >= 8 || !self.token.is_empty())
673     }
675     pub fn packet_type(&self) -> PacketType {
676         self.packet_type
677     }
679     pub fn dcid(&self) -> ConnectionIdRef<'a> {
680         self.dcid
681     }
683     pub fn scid(&self) -> ConnectionIdRef<'a> {
684         self.scid
685             .expect("should only be called for long header packets")
686     }
688     pub fn token(&self) -> &'a [u8] {
689         self.token
690     }
692     pub fn version(&self) -> Option<Version> {
693         self.version.and_then(|v| Version::try_from(v).ok())
694     }
696     pub fn wire_version(&self) -> WireVersion {
697         debug_assert!(self.version.is_some());
698         self.version.unwrap_or(0)
699     }
701     pub fn len(&self) -> usize {
702         self.data.len()
703     }
705     fn decode_pn(expected: PacketNumber, pn: u64, w: usize) -> PacketNumber {
706         let window = 1_u64 << (w * 8);
707         let candidate = (expected & !(window - 1)) | pn;
708         if candidate + (window / 2) <= expected {
709             candidate + window
710         } else if candidate > expected + (window / 2) {
711             match candidate.checked_sub(window) {
712                 Some(pn_sub) => pn_sub,
713                 None => candidate,
714             }
715         } else {
716             candidate
717         }
718     }
720     /// Decrypt the header of the packet.
721     fn decrypt_header(
722         &self,
723         crypto: &mut CryptoDxState,
724     ) -> Res<(bool, PacketNumber, Vec<u8>, &'a [u8])> {
725         assert_ne!(self.packet_type, PacketType::Retry);
726         assert_ne!(self.packet_type, PacketType::VersionNegotiation);
728         qtrace!(
729             "unmask hdr={}",
730             hex(&self.data[..self.header_len + SAMPLE_OFFSET])
731         );
733         let sample_offset = self.header_len + SAMPLE_OFFSET;
734         let mask = if let Some(sample) = self.data.get(sample_offset..(sample_offset + SAMPLE_SIZE))
735         {
736             crypto.compute_mask(sample)
737         } else {
738             Err(Error::NoMoreData)
739         }?;
741         // Un-mask the leading byte.
742         let bits = if self.packet_type == PacketType::Short {
743             PACKET_HP_MASK_SHORT
744         } else {
745             PACKET_HP_MASK_LONG
746         };
747         let first_byte = self.data[0] ^ (mask[0] & bits);
749         // Make a copy of the header to work on.
750         let mut hdrbytes = self.data[..self.header_len + 4].to_vec();
751         hdrbytes[0] = first_byte;
753         // Unmask the PN.
754         let mut pn_encoded: u64 = 0;
755         for i in 0..MAX_PACKET_NUMBER_LEN {
756             hdrbytes[self.header_len + i] ^= mask[1 + i];
757             pn_encoded <<= 8;
758             pn_encoded += u64::from(hdrbytes[self.header_len + i]);
759         }
761         // Now decode the packet number length and apply it, hopefully in constant time.
762         let pn_len = usize::from((first_byte & 0x3) + 1);
763         hdrbytes.truncate(self.header_len + pn_len);
764         pn_encoded >>= 8 * (MAX_PACKET_NUMBER_LEN - pn_len);
766         qtrace!("unmasked hdr={}", hex(&hdrbytes));
768         let key_phase = self.packet_type == PacketType::Short
769             && (first_byte & PACKET_BIT_KEY_PHASE) == PACKET_BIT_KEY_PHASE;
770         let pn = Self::decode_pn(crypto.next_pn(), pn_encoded, pn_len);
771         Ok((
772             key_phase,
773             pn,
774             hdrbytes,
775             &self.data[self.header_len + pn_len..],
776         ))
777     }
779     pub fn decrypt(&self, crypto: &mut CryptoStates, release_at: Instant) -> Res<DecryptedPacket> {
780         let cspace: CryptoSpace = self.packet_type.into();
781         // When we don't have a version, the crypto code doesn't need a version
782         // for lookup, so use the default, but fix it up if decryption succeeds.
783         let version = self.version().unwrap_or_default();
784         // This has to work in two stages because we need to remove header protection
785         // before picking the keys to use.
786         if let Some(rx) = crypto.rx_hp(version, cspace) {
787             // Note that this will dump early, which creates a side-channel.
788             // This is OK in this case because we the only reason this can
789             // fail is if the cryptographic module is bad or the packet is
790             // too small (which is public information).
791             let (key_phase, pn, header, body) = self.decrypt_header(rx)?;
792             qtrace!([rx], "decoded header: {:?}", header);
793             let rx = crypto.rx(version, cspace, key_phase).unwrap();
794             let version = rx.version(); // Version fixup; see above.
795             let d = rx.decrypt(pn, &header, body)?;
796             // If this is the first packet ever successfully decrypted
797             // using `rx`, make sure to initiate a key update.
798             if rx.needs_update() {
799                 crypto.key_update_received(release_at)?;
800             }
801             crypto.check_pn_overlap()?;
802             Ok(DecryptedPacket {
803                 version,
804                 pt: self.packet_type,
805                 pn,
806                 data: d,
807             })
808         } else if crypto.rx_pending(cspace) {
809             Err(Error::KeysPending(cspace))
810         } else {
811             qtrace!("keys for {:?} already discarded", cspace);
812             Err(Error::KeysDiscarded(cspace))
813         }
814     }
816     pub fn supported_versions(&self) -> Res<Vec<WireVersion>> {
817         assert_eq!(self.packet_type, PacketType::VersionNegotiation);
818         let mut decoder = Decoder::new(&self.data[self.header_len..]);
819         let mut res = Vec::new();
820         while decoder.remaining() > 0 {
821             let version = WireVersion::try_from(Self::opt(decoder.decode_uint(4))?)?;
822             res.push(version);
823         }
824         Ok(res)
825     }
828 impl fmt::Debug for PublicPacket<'_> {
829     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
830         write!(
831             f,
832             "{:?}: {} {}",
833             self.packet_type(),
834             hex_with_len(&self.data[..self.header_len]),
835             hex_with_len(&self.data[self.header_len..])
836         )
837     }
840 pub struct DecryptedPacket {
841     version: Version,
842     pt: PacketType,
843     pn: PacketNumber,
844     data: Vec<u8>,
847 impl DecryptedPacket {
848     pub fn version(&self) -> Version {
849         self.version
850     }
852     pub fn packet_type(&self) -> PacketType {
853         self.pt
854     }
856     pub fn pn(&self) -> PacketNumber {
857         self.pn
858     }
861 impl Deref for DecryptedPacket {
862     type Target = [u8];
864     fn deref(&self) -> &Self::Target {
865         &self.data[..]
866     }
869 #[cfg(all(test, not(feature = "fuzzing")))]
870 mod tests {
871     use neqo_common::Encoder;
872     use test_fixture::{fixture_init, now};
874     use crate::{
875         cid::MAX_CONNECTION_ID_LEN,
876         crypto::{CryptoDxState, CryptoStates},
877         packet::{
878             PacketBuilder, PacketType, PublicPacket, PACKET_BIT_FIXED_QUIC, PACKET_BIT_LONG,
879             PACKET_BIT_SPIN,
880         },
881         ConnectionId, EmptyConnectionIdGenerator, RandomConnectionIdGenerator, Version,
882     };
884     const CLIENT_CID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
885     const SERVER_CID: &[u8] = &[0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5];
887     /// This is a connection ID manager, which is only used for decoding short header packets.
888     fn cid_mgr() -> RandomConnectionIdGenerator {
889         RandomConnectionIdGenerator::new(SERVER_CID.len())
890     }
892     const SAMPLE_INITIAL_PAYLOAD: &[u8] = &[
893         0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, 0x03,
894         0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, 0x25, 0xdd,
895         0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, 0x0b, 0x9a, 0x04,
896         0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00,
897         0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, 0x8a, 0x60, 0x99, 0x3c, 0x14,
898         0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b,
899         0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04,
900     ];
901     const SAMPLE_INITIAL: &[u8] = &[
902         0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
903         0x00, 0x40, 0x75, 0xc0, 0xd9, 0x5a, 0x48, 0x2c, 0xd0, 0x99, 0x1c, 0xd2, 0x5b, 0x0a, 0xac,
904         0x40, 0x6a, 0x58, 0x16, 0xb6, 0x39, 0x41, 0x00, 0xf3, 0x7a, 0x1c, 0x69, 0x79, 0x75, 0x54,
905         0x78, 0x0b, 0xb3, 0x8c, 0xc5, 0xa9, 0x9f, 0x5e, 0xde, 0x4c, 0xf7, 0x3c, 0x3e, 0xc2, 0x49,
906         0x3a, 0x18, 0x39, 0xb3, 0xdb, 0xcb, 0xa3, 0xf6, 0xea, 0x46, 0xc5, 0xb7, 0x68, 0x4d, 0xf3,
907         0x54, 0x8e, 0x7d, 0xde, 0xb9, 0xc3, 0xbf, 0x9c, 0x73, 0xcc, 0x3f, 0x3b, 0xde, 0xd7, 0x4b,
908         0x56, 0x2b, 0xfb, 0x19, 0xfb, 0x84, 0x02, 0x2f, 0x8e, 0xf4, 0xcd, 0xd9, 0x37, 0x95, 0xd7,
909         0x7d, 0x06, 0xed, 0xbb, 0x7a, 0xaf, 0x2f, 0x58, 0x89, 0x18, 0x50, 0xab, 0xbd, 0xca, 0x3d,
910         0x20, 0x39, 0x8c, 0x27, 0x64, 0x56, 0xcb, 0xc4, 0x21, 0x58, 0x40, 0x7d, 0xd0, 0x74, 0xee,
911     ];
913     #[test]
914     fn sample_server_initial() {
915         fixture_init();
916         let mut prot = CryptoDxState::test_default();
918         // The spec uses PN=1, but our crypto refuses to skip packet numbers.
919         // So burn an encryption:
920         let burn = prot.encrypt(0, &[], &[]).expect("burn OK");
921         assert_eq!(burn.len(), prot.expansion());
923         let mut builder = PacketBuilder::long(
924             Encoder::new(),
925             PacketType::Initial,
926             Version::default(),
927             &ConnectionId::from(&[][..]),
928             &ConnectionId::from(SERVER_CID),
929         );
930         builder.initial_token(&[]);
931         builder.pn(1, 2);
932         builder.encode(SAMPLE_INITIAL_PAYLOAD);
933         let packet = builder.build(&mut prot).expect("build");
934         assert_eq!(packet.as_ref(), SAMPLE_INITIAL);
935     }
937     #[test]
938     fn decrypt_initial() {
939         const EXTRA: &[u8] = &[0xce; 33];
941         fixture_init();
942         let mut padded = SAMPLE_INITIAL.to_vec();
943         padded.extend_from_slice(EXTRA);
944         let (packet, remainder) = PublicPacket::decode(&padded, &cid_mgr()).unwrap();
945         assert_eq!(packet.packet_type(), PacketType::Initial);
946         assert_eq!(&packet.dcid()[..], &[] as &[u8]);
947         assert_eq!(&packet.scid()[..], SERVER_CID);
948         assert!(packet.token().is_empty());
949         assert_eq!(remainder, EXTRA);
951         let decrypted = packet
952             .decrypt(&mut CryptoStates::test_default(), now())
953             .unwrap();
954         assert_eq!(decrypted.pn(), 1);
955     }
957     #[test]
958     fn disallow_long_dcid() {
959         let mut enc = Encoder::new();
960         enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
961         enc.encode_uint(4, Version::default().wire_version());
962         enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 1]);
963         enc.encode_vec(1, &[]);
964         enc.encode(&[0xff; 40]); // junk
966         assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err());
967     }
969     #[test]
970     fn disallow_long_scid() {
971         let mut enc = Encoder::new();
972         enc.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC);
973         enc.encode_uint(4, Version::default().wire_version());
974         enc.encode_vec(1, &[]);
975         enc.encode_vec(1, &[0x00; MAX_CONNECTION_ID_LEN + 2]);
976         enc.encode(&[0xff; 40]); // junk
978         assert!(PublicPacket::decode(enc.as_ref(), &cid_mgr()).is_err());
979     }
981     const SAMPLE_SHORT: &[u8] = &[
982         0x40, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0xf4, 0xa8, 0x30, 0x39, 0xc4, 0x7d,
983         0x99, 0xe3, 0x94, 0x1c, 0x9b, 0xb9, 0x7a, 0x30, 0x1d, 0xd5, 0x8f, 0xf3, 0xdd, 0xa9,
984     ];
985     const SAMPLE_SHORT_PAYLOAD: &[u8] = &[0; 3];
987     #[test]
988     fn build_short() {
989         fixture_init();
990         let mut builder =
991             PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
992         builder.pn(0, 1);
993         builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling.
994         let packet = builder
995             .build(&mut CryptoDxState::test_default())
996             .expect("build");
997         assert_eq!(packet.as_ref(), SAMPLE_SHORT);
998     }
1000     #[test]
1001     fn scramble_short() {
1002         fixture_init();
1003         let mut firsts = Vec::new();
1004         for _ in 0..64 {
1005             let mut builder =
1006                 PacketBuilder::short(Encoder::new(), true, &ConnectionId::from(SERVER_CID));
1007             builder.scramble(true);
1008             builder.pn(0, 1);
1009             firsts.push(builder.as_ref()[0]);
1010         }
1011         let is_set = |bit| move |v| v & bit == bit;
1012         // There should be at least one value with the QUIC bit set:
1013         assert!(firsts.iter().any(is_set(PACKET_BIT_FIXED_QUIC)));
1014         // ... but not all:
1015         assert!(!firsts.iter().all(is_set(PACKET_BIT_FIXED_QUIC)));
1016         // There should be at least one value with the spin bit set:
1017         assert!(firsts.iter().any(is_set(PACKET_BIT_SPIN)));
1018         // ... but not all:
1019         assert!(!firsts.iter().all(is_set(PACKET_BIT_SPIN)));
1020     }
1022     #[test]
1023     fn decode_short() {
1024         fixture_init();
1025         let (packet, remainder) = PublicPacket::decode(SAMPLE_SHORT, &cid_mgr()).unwrap();
1026         assert_eq!(packet.packet_type(), PacketType::Short);
1027         assert!(remainder.is_empty());
1028         let decrypted = packet
1029             .decrypt(&mut CryptoStates::test_default(), now())
1030             .unwrap();
1031         assert_eq!(&decrypted[..], SAMPLE_SHORT_PAYLOAD);
1032     }
1034     /// By telling the decoder that the connection ID is shorter than it really is, we get a
1035     /// decryption error.
1036     #[test]
1037     fn decode_short_bad_cid() {
1038         fixture_init();
1039         let (packet, remainder) = PublicPacket::decode(
1040             SAMPLE_SHORT,
1041             &RandomConnectionIdGenerator::new(SERVER_CID.len() - 1),
1042         )
1043         .unwrap();
1044         assert_eq!(packet.packet_type(), PacketType::Short);
1045         assert!(remainder.is_empty());
1046         assert!(packet
1047             .decrypt(&mut CryptoStates::test_default(), now())
1048             .is_err());
1049     }
1051     /// Saying that the connection ID is longer causes the initial decode to fail.
1052     #[test]
1053     fn decode_short_long_cid() {
1054         assert!(PublicPacket::decode(
1055             SAMPLE_SHORT,
1056             &RandomConnectionIdGenerator::new(SERVER_CID.len() + 1)
1057         )
1058         .is_err());
1059     }
1061     #[test]
1062     fn build_two() {
1063         fixture_init();
1064         let mut prot = CryptoDxState::test_default();
1065         let mut builder = PacketBuilder::long(
1066             Encoder::new(),
1067             PacketType::Handshake,
1068             Version::default(),
1069             &ConnectionId::from(SERVER_CID),
1070             &ConnectionId::from(CLIENT_CID),
1071         );
1072         builder.pn(0, 1);
1073         builder.encode(&[0; 3]);
1074         let encoder = builder.build(&mut prot).expect("build");
1075         assert_eq!(encoder.len(), 45);
1076         let first = encoder.clone();
1078         let mut builder = PacketBuilder::short(encoder, false, &ConnectionId::from(SERVER_CID));
1079         builder.pn(1, 3);
1080         builder.encode(&[0]); // Minimal size (packet number is big enough).
1081         let encoder = builder.build(&mut prot).expect("build");
1082         assert_eq!(
1083             first.as_ref(),
1084             &encoder.as_ref()[..first.len()],
1085             "the first packet should be a prefix"
1086         );
1087         assert_eq!(encoder.len(), 45 + 29);
1088     }
1090     #[test]
1091     fn build_long() {
1092         const EXPECTED: &[u8] = &[
1093             0xe4, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x40, 0x14, 0xfb, 0xa9, 0x32, 0x3a, 0xf8,
1094             0xbb, 0x18, 0x63, 0xc6, 0xbd, 0x78, 0x0e, 0xba, 0x0c, 0x98, 0x65, 0x58, 0xc9, 0x62,
1095             0x31,
1096         ];
1098         fixture_init();
1099         let mut builder = PacketBuilder::long(
1100             Encoder::new(),
1101             PacketType::Handshake,
1102             Version::default(),
1103             &ConnectionId::from(&[][..]),
1104             &ConnectionId::from(&[][..]),
1105         );
1106         builder.pn(0, 1);
1107         builder.encode(&[1, 2, 3]);
1108         let packet = builder.build(&mut CryptoDxState::test_default()).unwrap();
1109         assert_eq!(packet.as_ref(), EXPECTED);
1110     }
1112     #[test]
1113     fn scramble_long() {
1114         fixture_init();
1115         let mut found_unset = false;
1116         let mut found_set = false;
1117         for _ in 1..64 {
1118             let mut builder = PacketBuilder::long(
1119                 Encoder::new(),
1120                 PacketType::Handshake,
1121                 Version::default(),
1122                 &ConnectionId::from(&[][..]),
1123                 &ConnectionId::from(&[][..]),
1124             );
1125             builder.pn(0, 1);
1126             builder.scramble(true);
1127             if (builder.as_ref()[0] & PACKET_BIT_FIXED_QUIC) == 0 {
1128                 found_unset = true;
1129             } else {
1130                 found_set = true;
1131             }
1132         }
1133         assert!(found_unset);
1134         assert!(found_set);
1135     }
1137     #[test]
1138     fn build_abort() {
1139         let mut builder = PacketBuilder::long(
1140             Encoder::new(),
1141             PacketType::Initial,
1142             Version::default(),
1143             &ConnectionId::from(&[][..]),
1144             &ConnectionId::from(SERVER_CID),
1145         );
1146         assert_ne!(builder.remaining(), 0);
1147         builder.initial_token(&[]);
1148         assert_ne!(builder.remaining(), 0);
1149         builder.pn(1, 2);
1150         assert_ne!(builder.remaining(), 0);
1151         let encoder = builder.abort();
1152         assert!(encoder.is_empty());
1153     }
1155     #[test]
1156     fn build_insufficient_space() {
1157         fixture_init();
1159         let mut builder = PacketBuilder::short(
1160             Encoder::with_capacity(100),
1161             true,
1162             &ConnectionId::from(SERVER_CID),
1163         );
1164         builder.pn(0, 1);
1165         // Pad, but not up to the full capacity. Leave enough space for the
1166         // AEAD expansion and some extra, but not for an entire long header.
1167         builder.set_limit(75);
1168         builder.enable_padding(true);
1169         assert!(builder.pad());
1170         let encoder = builder.build(&mut CryptoDxState::test_default()).unwrap();
1171         let encoder_copy = encoder.clone();
1173         let builder = PacketBuilder::long(
1174             encoder,
1175             PacketType::Initial,
1176             Version::default(),
1177             &ConnectionId::from(SERVER_CID),
1178             &ConnectionId::from(SERVER_CID),
1179         );
1180         assert_eq!(builder.remaining(), 0);
1181         assert_eq!(builder.abort(), encoder_copy);
1182     }
1184     const SAMPLE_RETRY_V2: &[u8] = &[
1185         0xcf, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
1186         0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc8, 0x64, 0x6c, 0xe8, 0xbf, 0xe3, 0x39, 0x52, 0xd9, 0x55,
1187         0x54, 0x36, 0x65, 0xdc, 0xc7, 0xb6,
1188     ];
1190     const SAMPLE_RETRY_V1: &[u8] = &[
1191         0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
1192         0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58,
1193         0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba,
1194     ];
1196     const SAMPLE_RETRY_29: &[u8] = &[
1197         0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
1198         0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a,
1199         0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49,
1200     ];
1202     const SAMPLE_RETRY_30: &[u8] = &[
1203         0xff, 0xff, 0x00, 0x00, 0x1e, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
1204         0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x2d, 0x3e, 0x04, 0x5d, 0x6d, 0x39, 0x20, 0x67, 0x89, 0x94,
1205         0x37, 0x10, 0x8c, 0xe0, 0x0a, 0x61,
1206     ];
1208     const SAMPLE_RETRY_31: &[u8] = &[
1209         0xff, 0xff, 0x00, 0x00, 0x1f, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
1210         0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xc7, 0x0c, 0xe5, 0xde, 0x43, 0x0b, 0x4b, 0xdb, 0x7d, 0xf1,
1211         0xa3, 0x83, 0x3a, 0x75, 0xf9, 0x86,
1212     ];
1214     const SAMPLE_RETRY_32: &[u8] = &[
1215         0xff, 0xff, 0x00, 0x00, 0x20, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5,
1216         0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x59, 0x75, 0x65, 0x19, 0xdd, 0x6c, 0xc8, 0x5b, 0xd9, 0x0e,
1217         0x33, 0xa9, 0x34, 0xd2, 0xff, 0x85,
1218     ];
1220     const RETRY_TOKEN: &[u8] = b"token";
1222     fn build_retry_single(version: Version, sample_retry: &[u8]) {
1223         fixture_init();
1224         let retry =
1225             PacketBuilder::retry(version, &[], SERVER_CID, RETRY_TOKEN, CLIENT_CID).unwrap();
1227         let (packet, remainder) = PublicPacket::decode(&retry, &cid_mgr()).unwrap();
1228         assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
1229         assert!(remainder.is_empty());
1231         // The builder adds randomness, which makes expectations hard.
1232         // So only do a full check when that randomness matches up.
1233         if retry[0] == sample_retry[0] {
1234             assert_eq!(&retry, &sample_retry);
1235         } else {
1236             // Otherwise, just check that the header is OK.
1237             assert_eq!(
1238                 retry[0] & 0xf0,
1239                 0xc0 | (PacketType::Retry.to_byte(version) << 4)
1240             );
1241             let header_range = 1..retry.len() - 16;
1242             assert_eq!(&retry[header_range.clone()], &sample_retry[header_range]);
1243         }
1244     }
1246     #[test]
1247     fn build_retry_v2() {
1248         build_retry_single(Version::Version2, SAMPLE_RETRY_V2);
1249     }
1251     #[test]
1252     fn build_retry_v1() {
1253         build_retry_single(Version::Version1, SAMPLE_RETRY_V1);
1254     }
1256     #[test]
1257     fn build_retry_29() {
1258         build_retry_single(Version::Draft29, SAMPLE_RETRY_29);
1259     }
1261     #[test]
1262     fn build_retry_30() {
1263         build_retry_single(Version::Draft30, SAMPLE_RETRY_30);
1264     }
1266     #[test]
1267     fn build_retry_31() {
1268         build_retry_single(Version::Draft31, SAMPLE_RETRY_31);
1269     }
1271     #[test]
1272     fn build_retry_32() {
1273         build_retry_single(Version::Draft32, SAMPLE_RETRY_32);
1274     }
1276     #[test]
1277     fn build_retry_multiple() {
1278         // Run the build_retry test a few times.
1279         // Odds are approximately 1 in 8 that the full comparison doesn't happen
1280         // for a given version.
1281         for _ in 0..32 {
1282             build_retry_v2();
1283             build_retry_v1();
1284             build_retry_29();
1285             build_retry_30();
1286             build_retry_31();
1287             build_retry_32();
1288         }
1289     }
1291     fn decode_retry(version: Version, sample_retry: &[u8]) {
1292         fixture_init();
1293         let (packet, remainder) =
1294             PublicPacket::decode(sample_retry, &RandomConnectionIdGenerator::new(5)).unwrap();
1295         assert!(packet.is_valid_retry(&ConnectionId::from(CLIENT_CID)));
1296         assert_eq!(Some(version), packet.version());
1297         assert!(packet.dcid().is_empty());
1298         assert_eq!(&packet.scid()[..], SERVER_CID);
1299         assert_eq!(packet.token(), RETRY_TOKEN);
1300         assert!(remainder.is_empty());
1301     }
1303     #[test]
1304     fn decode_retry_v2() {
1305         decode_retry(Version::Version2, SAMPLE_RETRY_V2);
1306     }
1308     #[test]
1309     fn decode_retry_v1() {
1310         decode_retry(Version::Version1, SAMPLE_RETRY_V1);
1311     }
1313     #[test]
1314     fn decode_retry_29() {
1315         decode_retry(Version::Draft29, SAMPLE_RETRY_29);
1316     }
1318     #[test]
1319     fn decode_retry_30() {
1320         decode_retry(Version::Draft30, SAMPLE_RETRY_30);
1321     }
1323     #[test]
1324     fn decode_retry_31() {
1325         decode_retry(Version::Draft31, SAMPLE_RETRY_31);
1326     }
1328     #[test]
1329     fn decode_retry_32() {
1330         decode_retry(Version::Draft32, SAMPLE_RETRY_32);
1331     }
1333     /// Check some packets that are clearly not valid Retry packets.
1334     #[test]
1335     fn invalid_retry() {
1336         fixture_init();
1337         let cid_mgr = RandomConnectionIdGenerator::new(5);
1338         let odcid = ConnectionId::from(CLIENT_CID);
1340         assert!(PublicPacket::decode(&[], &cid_mgr).is_err());
1342         let (packet, remainder) = PublicPacket::decode(SAMPLE_RETRY_V1, &cid_mgr).unwrap();
1343         assert!(remainder.is_empty());
1344         assert!(packet.is_valid_retry(&odcid));
1346         let mut damaged_retry = SAMPLE_RETRY_V1.to_vec();
1347         let last = damaged_retry.len() - 1;
1348         damaged_retry[last] ^= 66;
1349         let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
1350         assert!(remainder.is_empty());
1351         assert!(!packet.is_valid_retry(&odcid));
1353         damaged_retry.truncate(last);
1354         let (packet, remainder) = PublicPacket::decode(&damaged_retry, &cid_mgr).unwrap();
1355         assert!(remainder.is_empty());
1356         assert!(!packet.is_valid_retry(&odcid));
1358         // An invalid token should be rejected sooner.
1359         damaged_retry.truncate(last - 4);
1360         assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
1362         damaged_retry.truncate(last - 1);
1363         assert!(PublicPacket::decode(&damaged_retry, &cid_mgr).is_err());
1364     }
1366     const SAMPLE_VN: &[u8] = &[
1367         0x80, 0x00, 0x00, 0x00, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x08,
1368         0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x00, 0x00,
1369         0x01, 0xff, 0x00, 0x00, 0x20, 0xff, 0x00, 0x00, 0x1f, 0xff, 0x00, 0x00, 0x1e, 0xff, 0x00,
1370         0x00, 0x1d, 0x0a, 0x0a, 0x0a, 0x0a,
1371     ];
1373     #[test]
1374     fn build_vn() {
1375         fixture_init();
1376         let mut vn = PacketBuilder::version_negotiation(
1377             SERVER_CID,
1378             CLIENT_CID,
1379             0x0a0a_0a0a,
1380             &Version::all(),
1381         );
1382         // Erase randomness from greasing...
1383         assert_eq!(vn.len(), SAMPLE_VN.len());
1384         vn[0] &= 0x80;
1385         for v in vn.iter_mut().skip(SAMPLE_VN.len() - 4) {
1386             *v &= 0x0f;
1387         }
1388         assert_eq!(&vn, &SAMPLE_VN);
1389     }
1391     #[test]
1392     fn vn_do_not_repeat_client_grease() {
1393         fixture_init();
1394         let vn = PacketBuilder::version_negotiation(
1395             SERVER_CID,
1396             CLIENT_CID,
1397             0x0a0a_0a0a,
1398             &Version::all(),
1399         );
1400         assert_ne!(&vn[SAMPLE_VN.len() - 4..], &[0x0a, 0x0a, 0x0a, 0x0a]);
1401     }
1403     #[test]
1404     fn parse_vn() {
1405         let (packet, remainder) =
1406             PublicPacket::decode(SAMPLE_VN, &EmptyConnectionIdGenerator::default()).unwrap();
1407         assert!(remainder.is_empty());
1408         assert_eq!(&packet.dcid[..], SERVER_CID);
1409         assert!(packet.scid.is_some());
1410         assert_eq!(&packet.scid.unwrap()[..], CLIENT_CID);
1411     }
1413     /// A Version Negotiation packet can have a long connection ID.
1414     #[test]
1415     fn parse_vn_big_cid() {
1416         const BIG_DCID: &[u8] = &[0x44; MAX_CONNECTION_ID_LEN + 1];
1417         const BIG_SCID: &[u8] = &[0xee; 255];
1419         let mut enc = Encoder::from(&[0xff, 0x00, 0x00, 0x00, 0x00][..]);
1420         enc.encode_vec(1, BIG_DCID);
1421         enc.encode_vec(1, BIG_SCID);
1422         enc.encode_uint(4, 0x1a2a_3a4a_u64);
1423         enc.encode_uint(4, Version::default().wire_version());
1424         enc.encode_uint(4, 0x5a6a_7a8a_u64);
1426         let (packet, remainder) =
1427             PublicPacket::decode(enc.as_ref(), &EmptyConnectionIdGenerator::default()).unwrap();
1428         assert!(remainder.is_empty());
1429         assert_eq!(&packet.dcid[..], BIG_DCID);
1430         assert!(packet.scid.is_some());
1431         assert_eq!(&packet.scid.unwrap()[..], BIG_SCID);
1432     }
1434     #[test]
1435     fn decode_pn() {
1436         // When the expected value is low, the value doesn't go negative.
1437         assert_eq!(PublicPacket::decode_pn(0, 0, 1), 0);
1438         assert_eq!(PublicPacket::decode_pn(0, 0xff, 1), 0xff);
1439         assert_eq!(PublicPacket::decode_pn(10, 0, 1), 0);
1440         assert_eq!(PublicPacket::decode_pn(0x7f, 0, 1), 0);
1441         assert_eq!(PublicPacket::decode_pn(0x80, 0, 1), 0x100);
1442         assert_eq!(PublicPacket::decode_pn(0x80, 2, 1), 2);
1443         assert_eq!(PublicPacket::decode_pn(0x80, 0xff, 1), 0xff);
1444         assert_eq!(PublicPacket::decode_pn(0x7ff, 0xfe, 1), 0x7fe);
1446         // This is invalid by spec, as we are expected to check for overflow around 2^62-1,
1447         // but we don't need to worry about overflow
1448         // and hitting this is basically impossible in practice.
1449         assert_eq!(
1450             PublicPacket::decode_pn(0x3fff_ffff_ffff_ffff, 2, 4),
1451             0x4000_0000_0000_0002
1452         );
1453     }
1455     #[test]
1456     fn chacha20_sample() {
1457         const PACKET: &[u8] = &[
1458             0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57,
1459             0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb,
1460         ];
1461         fixture_init();
1462         let (packet, slice) =
1463             PublicPacket::decode(PACKET, &EmptyConnectionIdGenerator::default()).unwrap();
1464         assert!(slice.is_empty());
1465         let decrypted = packet
1466             .decrypt(&mut CryptoStates::test_chacha(), now())
1467             .unwrap();
1468         assert_eq!(decrypted.packet_type(), PacketType::Short);
1469         assert_eq!(decrypted.pn(), 654_360_564);
1470         assert_eq!(&decrypted[..], &[0x01]);
1471     }