Bug 1826304: Update libprio-rs to 0.12.0. r=emilio,glandium,supply-chain-reviewers
[gecko.git] / third_party / rust / prio / src / field.rs
blob15369dcaf2857d2bac93afcb919f7618b5aac0ea
1 // Copyright (c) 2020 Apple Inc.
2 // SPDX-License-Identifier: MPL-2.0
4 //! Finite field arithmetic.
5 //!
6 //! Basic field arithmetic is captured in the [`FieldElement`] trait. Fields used in Prio implement
7 //! [`FftFriendlyFieldElement`], and have an associated element called the "generator" that
8 //! generates a multiplicative subgroup of order `2^n` for some `n`.
10 #[cfg(feature = "crypto-dependencies")]
11 use crate::prng::{Prng, PrngError};
12 use crate::{
13     codec::{CodecError, Decode, Encode},
14     fp::{FP128, FP32, FP64},
15     vdaf::prg::{CoinToss, SeedStream},
17 use serde::{
18     de::{DeserializeOwned, Visitor},
19     Deserialize, Deserializer, Serialize, Serializer,
21 use std::{
22     cmp::min,
23     convert::{TryFrom, TryInto},
24     fmt::{self, Debug, Display, Formatter},
25     hash::{Hash, Hasher},
26     io::{Cursor, Read},
27     marker::PhantomData,
28     ops::{Add, AddAssign, BitAnd, Div, DivAssign, Mul, MulAssign, Neg, Shl, Shr, Sub, SubAssign},
30 use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq};
32 #[cfg(feature = "experimental")]
33 mod field255;
35 #[cfg(feature = "experimental")]
36 pub use field255::Field255;
38 /// Possible errors from finite field operations.
39 #[derive(Debug, thiserror::Error)]
40 pub enum FieldError {
41     /// Input sizes do not match.
42     #[error("input sizes do not match")]
43     InputSizeMismatch,
44     /// Returned when decoding a [`FieldElement`] from a too-short byte string.
45     #[error("short read from bytes")]
46     ShortRead,
47     /// Returned when decoding a [`FieldElement`] from a byte string that encodes an integer greater
48     /// than or equal to the field modulus.
49     #[error("read from byte slice exceeds modulus")]
50     ModulusOverflow,
51     /// Error while performing I/O.
52     #[error("I/O error")]
53     Io(#[from] std::io::Error),
54     /// Error encoding or decoding a field.
55     #[error("Codec error")]
56     Codec(#[from] CodecError),
57     /// Error converting to [`FieldElementWithInteger::Integer`].
58     #[error("Integer TryFrom error")]
59     IntegerTryFrom,
62 /// Objects with this trait represent an element of `GF(p)` for some prime `p`.
63 pub trait FieldElement:
64     Sized
65     + Debug
66     + Copy
67     + PartialEq
68     + Eq
69     + ConstantTimeEq
70     + ConditionallySelectable
71     + ConditionallyNegatable
72     + Add<Output = Self>
73     + AddAssign
74     + Sub<Output = Self>
75     + SubAssign
76     + Mul<Output = Self>
77     + MulAssign
78     + Div<Output = Self>
79     + DivAssign
80     + Neg<Output = Self>
81     + Display
82     + for<'a> TryFrom<&'a [u8], Error = FieldError>
83     // NOTE Ideally we would require `Into<[u8; Self::ENCODED_SIZE]>` instead of `Into<Vec<u8>>`,
84     // since the former avoids a heap allocation and can easily be converted into Vec<u8>, but that
85     // isn't possible yet[1]. However we can provide the impl on FieldElement implementations.
86     // [1]: https://github.com/rust-lang/rust/issues/60551
87     + Into<Vec<u8>>
88     + Serialize
89     + DeserializeOwned
90     + Encode
91     + Decode
92     + 'static // NOTE This bound is needed for downcasting a `dyn Gadget<F>>` to a concrete type.
94     /// Size in bytes of an encoded field element.
95     const ENCODED_SIZE: usize;
97     /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
98     fn inv(&self) -> Self;
100     /// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the
101     /// field. The `m` most significant bits are cleared, where `m` is equal to the length of
102     /// [`Self::Integer`] in bits minus the length of the modulus in bits.
103     ///
104     /// # Errors
105     ///
106     /// An error is returned if the provided slice is too small to encode a field element or if the
107     /// result encodes an integer larger than or equal to the field modulus.
108     ///
109     /// # Warnings
110     ///
111     /// This function should only be used within [`prng::Prng`] to convert a random byte string into
112     /// a field element. Use [`Self::decode`] to deserialize field elements. Use
113     /// [`field::rand`] or [`prng::Prng`] to randomly generate field elements.
114     #[doc(hidden)]
115     fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError>;
117     /// Returns the additive identity.
118     fn zero() -> Self;
120     /// Returns the multiplicative identity.
121     fn one() -> Self;
123     /// Convert a slice of field elements into a vector of bytes.
124     ///
125     /// # Notes
126     ///
127     /// Ideally we would implement `From<&[F: FieldElement]> for Vec<u8>` or the corresponding
128     /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
129     /// impossible.
130     fn slice_into_byte_vec(values: &[Self]) -> Vec<u8> {
131         let mut vec = Vec::with_capacity(values.len() * Self::ENCODED_SIZE);
132         encode_fieldvec(values, &mut vec);
133         vec
134     }
136     /// Convert a slice of bytes into a vector of field elements. The slice is interpreted as a
137     /// sequence of [`Self::ENCODED_SIZE`]-byte sequences.
138     ///
139     /// # Errors
140     ///
141     /// Returns an error if the length of the provided byte slice is not a multiple of the size of a
142     /// field element, or if any of the values in the byte slice are invalid encodings of a field
143     /// element, because the encoded integer is larger than or equal to the field modulus.
144     ///
145     /// # Notes
146     ///
147     /// Ideally we would implement `From<&[u8]> for Vec<F: FieldElement>` or the corresponding
148     /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this
149     /// impossible.
150     fn byte_slice_into_vec(bytes: &[u8]) -> Result<Vec<Self>, FieldError> {
151         if bytes.len() % Self::ENCODED_SIZE != 0 {
152             return Err(FieldError::ShortRead);
153         }
154         let mut vec = Vec::with_capacity(bytes.len() / Self::ENCODED_SIZE);
155         for chunk in bytes.chunks_exact(Self::ENCODED_SIZE) {
156             vec.push(Self::get_decoded(chunk)?);
157         }
158         Ok(vec)
159     }
162 /// Extension trait for field elements that can be converted back and forth to an integer type.
164 /// The `Integer` associated type is an integer (primitive or otherwise) that supports various
165 /// arithmetic operations. The order of the field is guaranteed to fit inside the range of the
166 /// integer type. This trait also defines methods on field elements, `pow` and `modulus`, that make
167 /// use of the associated integer type.
168 pub trait FieldElementWithInteger: FieldElement + From<Self::Integer> {
169     /// The error returned if converting `usize` to an `Integer` fails.
170     type IntegerTryFromError: std::error::Error;
172     /// The error returned if converting an `Integer` to a `u64` fails.
173     type TryIntoU64Error: std::error::Error;
175     /// The integer representation of a field element.
176     type Integer: Copy
177         + Debug
178         + Eq
179         + Ord
180         + BitAnd<Output = Self::Integer>
181         + Div<Output = Self::Integer>
182         + Shl<usize, Output = Self::Integer>
183         + Shr<usize, Output = Self::Integer>
184         + Add<Output = Self::Integer>
185         + Sub<Output = Self::Integer>
186         + From<Self>
187         + TryFrom<usize, Error = Self::IntegerTryFromError>
188         + TryInto<u64, Error = Self::TryIntoU64Error>;
190     /// Modular exponentation, i.e., `self^exp (mod p)`.
191     fn pow(&self, exp: Self::Integer) -> Self;
193     /// Returns the prime modulus `p`.
194     fn modulus() -> Self::Integer;
197 /// Methods common to all `FieldElementWithInteger` implementations that are private to the crate.
198 pub(crate) trait FieldElementExt: FieldElementWithInteger {
199     /// Encode `input` as bitvector of elements of `Self`. Output is written into the `output` slice.
200     /// If `output.len()` is smaller than the number of bits required to respresent `input`,
201     /// an error is returned.
202     ///
203     /// # Arguments
204     ///
205     /// * `input` - The field element to encode
206     /// * `output` - The slice to write the encoded bits into. Least signicant bit comes first
207     fn fill_with_bitvector_representation(
208         input: &Self::Integer,
209         output: &mut [Self],
210     ) -> Result<(), FieldError> {
211         // Create a mutable copy of `input`. In each iteration of the following loop we take the
212         // least significant bit, and shift input to the right by one bit.
213         let mut i = *input;
215         let one = Self::Integer::from(Self::one());
216         for bit in output.iter_mut() {
217             let w = Self::from(i & one);
218             *bit = w;
219             i = i >> 1;
220         }
222         // If `i` is still not zero, this means that it cannot be encoded by `bits` bits.
223         if i != Self::Integer::from(Self::zero()) {
224             return Err(FieldError::InputSizeMismatch);
225         }
227         Ok(())
228     }
230     /// Encode `input` as `bits`-bit vector of elements of `Self` if it's small enough
231     /// to be represented with that many bits.
232     ///
233     /// # Arguments
234     ///
235     /// * `input` - The field element to encode
236     /// * `bits` - The number of bits to use for the encoding
237     fn encode_into_bitvector_representation(
238         input: &Self::Integer,
239         bits: usize,
240     ) -> Result<Vec<Self>, FieldError> {
241         let mut result = vec![Self::zero(); bits];
242         Self::fill_with_bitvector_representation(input, &mut result)?;
243         Ok(result)
244     }
246     /// Decode the bitvector-represented value `input` into a simple representation as a single
247     /// field element.
248     ///
249     /// # Errors
250     ///
251     /// This function errors if `2^input.len() - 1` does not fit into the field `Self`.
252     fn decode_from_bitvector_representation(input: &[Self]) -> Result<Self, FieldError> {
253         let fi_one = Self::Integer::from(Self::one());
255         if !Self::valid_integer_bitlength(input.len()) {
256             return Err(FieldError::ModulusOverflow);
257         }
259         let mut decoded = Self::zero();
260         for (l, bit) in input.iter().enumerate() {
261             let w = fi_one << l;
262             decoded += Self::from(w) * *bit;
263         }
264         Ok(decoded)
265     }
267     /// Interpret `i` as [`Self::Integer`] if it's representable in that type and smaller than the
268     /// field modulus.
269     fn valid_integer_try_from<N>(i: N) -> Result<Self::Integer, FieldError>
270     where
271         Self::Integer: TryFrom<N>,
272     {
273         let i_int = Self::Integer::try_from(i).map_err(|_| FieldError::IntegerTryFrom)?;
274         if Self::modulus() <= i_int {
275             return Err(FieldError::ModulusOverflow);
276         }
277         Ok(i_int)
278     }
280     /// Check if the largest number representable with `bits` bits (i.e. 2^bits - 1) is
281     /// representable in this field.
282     fn valid_integer_bitlength(bits: usize) -> bool {
283         if bits >= 8 * Self::ENCODED_SIZE {
284             return false;
285         }
286         if Self::modulus() >> bits != Self::Integer::from(Self::zero()) {
287             return true;
288         }
289         false
290     }
293 impl<F: FieldElementWithInteger> FieldElementExt for F {}
295 /// serde Visitor implementation used to generically deserialize `FieldElement`
296 /// values from byte arrays.
297 pub(crate) struct FieldElementVisitor<F: FieldElement> {
298     pub(crate) phantom: PhantomData<F>,
301 impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> {
302     type Value = F;
304     fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
305         formatter.write_fmt(format_args!("an array of {} bytes", F::ENCODED_SIZE))
306     }
308     fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
309     where
310         E: serde::de::Error,
311     {
312         Self::Value::try_from(v).map_err(E::custom)
313     }
315     fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
316     where
317         A: serde::de::SeqAccess<'de>,
318     {
319         let mut bytes = vec![];
320         while let Some(byte) = seq.next_element()? {
321             bytes.push(byte);
322         }
324         self.visit_bytes(&bytes)
325     }
328 /// Objects with this trait represent an element of `GF(p)`, where `p` is some prime and the
329 /// field's multiplicative group has a subgroup with an order that is a power of 2, and at least
330 /// `2^20`.
331 pub trait FftFriendlyFieldElement: FieldElementWithInteger {
332     /// Returns the size of the multiplicative subgroup generated by
333     /// [`FftFriendlyFieldElement::generator`].
334     fn generator_order() -> Self::Integer;
336     /// Returns the generator of the multiplicative subgroup of size
337     /// [`FftFriendlyFieldElement::generator_order`].
338     fn generator() -> Self;
340     /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
341     /// prinicpal root of unity is `1` by definition.
342     fn root(l: usize) -> Option<Self>;
345 macro_rules! make_field {
346     (
347         $(#[$meta:meta])*
348         $elem:ident, $int:ident, $fp:ident, $encoding_size:literal,
349     ) => {
350         $(#[$meta])*
351         ///
352         /// This structure represents a field element in a prime order field. The concrete
353         /// representation of the element is via the Montgomery domain. For an element `n` in
354         /// `GF(p)`, we store `n * R^-1 mod p` (where `R` is a given power of two). This
355         /// representation enables using a more efficient (and branchless) multiplication algorithm,
356         /// at the expense of having to convert elements between their Montgomery domain
357         /// representation and natural representation. For calculations with many multiplications or
358         /// exponentiations, this is worthwhile.
359         ///
360         /// As an invariant, this integer representing the field element in the Montgomery domain
361         /// must be less than the field modulus, `p`.
362         #[derive(Clone, Copy, PartialOrd, Ord, Default)]
363         pub struct $elem(u128);
365         impl $elem {
366             /// Attempts to instantiate an `$elem` from the first `Self::ENCODED_SIZE` bytes in the
367             /// provided slice. The decoded value will be bitwise-ANDed with `mask` before reducing
368             /// it using the field modulus.
369             ///
370             /// # Errors
371             ///
372             /// An error is returned if the provided slice is not long enough to encode a field
373             /// element or if the decoded value is greater than the field prime.
374             ///
375             /// # Notes
376             ///
377             /// We cannot use `u128::from_le_bytes` or `u128::from_be_bytes` because those functions
378             /// expect inputs to be exactly 16 bytes long. Our encoding of most field elements is
379             /// more compact, and does not have to correspond to the size of an integer type. For
380             /// instance,`Field96`'s encoding is 12 bytes, even though it is a 16 byte `u128` in
381             /// memory.
382             fn try_from_bytes(bytes: &[u8], mask: u128) -> Result<Self, FieldError> {
383                 if Self::ENCODED_SIZE > bytes.len() {
384                     return Err(FieldError::ShortRead);
385                 }
387                 let mut int = 0;
388                 for i in 0..Self::ENCODED_SIZE {
389                     int |= (bytes[i] as u128) << (i << 3);
390                 }
392                 int &= mask;
394                 if int >= $fp.p {
395                     return Err(FieldError::ModulusOverflow);
396                 }
397                 // FieldParameters::montgomery() will return a value that has been fully reduced
398                 // mod p, satisfying the invariant on Self.
399                 Ok(Self($fp.montgomery(int)))
400             }
401         }
403         impl PartialEq for $elem {
404             fn eq(&self, rhs: &Self) -> bool {
405                 // The fields included in this comparison MUST match the fields
406                 // used in Hash::hash
407                 // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
409                 // Check the invariant that the integer representation is fully reduced.
410                 debug_assert!(self.0 < $fp.p);
411                 debug_assert!(rhs.0 < $fp.p);
413                 self.0 == rhs.0
414             }
415         }
417         impl ConstantTimeEq for $elem {
418             fn ct_eq(&self, rhs: &Self) -> Choice {
419                 self.0.ct_eq(&rhs.0)
420             }
421         }
423         impl ConditionallySelectable for $elem {
424             fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
425                 Self(u128::conditional_select(&a.0, &b.0, choice))
426             }
427         }
429         impl Hash for $elem {
430             fn hash<H: Hasher>(&self, state: &mut H) {
431                 // The fields included in this hash MUST match the fields used
432                 // in PartialEq::eq
433                 // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq
435                 // Check the invariant that the integer representation is fully reduced.
436                 debug_assert!(self.0 < $fp.p);
438                 self.0.hash(state);
439             }
440         }
442         impl Eq for $elem {}
444         impl Add for $elem {
445             type Output = $elem;
446             fn add(self, rhs: Self) -> Self {
447                 // FieldParameters::add() returns a value that has been fully reduced
448                 // mod p, satisfying the invariant on Self.
449                 Self($fp.add(self.0, rhs.0))
450             }
451         }
453         impl Add for &$elem {
454             type Output = $elem;
455             fn add(self, rhs: Self) -> $elem {
456                 *self + *rhs
457             }
458         }
460         impl AddAssign for $elem {
461             fn add_assign(&mut self, rhs: Self) {
462                 *self = *self + rhs;
463             }
464         }
466         impl Sub for $elem {
467             type Output = $elem;
468             fn sub(self, rhs: Self) -> Self {
469                 // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub()
470                 // returns a value less than p, satisfying the invariant on Self.
471                 Self($fp.sub(self.0, rhs.0))
472             }
473         }
475         impl Sub for &$elem {
476             type Output = $elem;
477             fn sub(self, rhs: Self) -> $elem {
478                 *self - *rhs
479             }
480         }
482         impl SubAssign for $elem {
483             fn sub_assign(&mut self, rhs: Self) {
484                 *self = *self - rhs;
485             }
486         }
488         impl Mul for $elem {
489             type Output = $elem;
490             fn mul(self, rhs: Self) -> Self {
491                 // FieldParameters::mul() always returns a value less than p, so the invariant on
492                 // Self is satisfied.
493                 Self($fp.mul(self.0, rhs.0))
494             }
495         }
497         impl Mul for &$elem {
498             type Output = $elem;
499             fn mul(self, rhs: Self) -> $elem {
500                 *self * *rhs
501             }
502         }
504         impl MulAssign for $elem {
505             fn mul_assign(&mut self, rhs: Self) {
506                 *self = *self * rhs;
507             }
508         }
510         impl Div for $elem {
511             type Output = $elem;
512             #[allow(clippy::suspicious_arithmetic_impl)]
513             fn div(self, rhs: Self) -> Self {
514                 self * rhs.inv()
515             }
516         }
518         impl Div for &$elem {
519             type Output = $elem;
520             fn div(self, rhs: Self) -> $elem {
521                 *self / *rhs
522             }
523         }
525         impl DivAssign for $elem {
526             fn div_assign(&mut self, rhs: Self) {
527                 *self = *self / rhs;
528             }
529         }
531         impl Neg for $elem {
532             type Output = $elem;
533             fn neg(self) -> Self {
534                 // FieldParameters::neg() will return a value less than p because self.0 is less
535                 // than p, and neg() dispatches to sub().
536                 Self($fp.neg(self.0))
537             }
538         }
540         impl Neg for &$elem {
541             type Output = $elem;
542             fn neg(self) -> $elem {
543                 -(*self)
544             }
545         }
547         impl From<$int> for $elem {
548             fn from(x: $int) -> Self {
549                 // FieldParameters::montgomery() will return a value that has been fully reduced
550                 // mod p, satisfying the invariant on Self.
551                 Self($fp.montgomery(u128::try_from(x).unwrap()))
552             }
553         }
555         impl From<$elem> for $int {
556             fn from(x: $elem) -> Self {
557                 $int::try_from($fp.residue(x.0)).unwrap()
558             }
559         }
561         impl PartialEq<$int> for $elem {
562             fn eq(&self, rhs: &$int) -> bool {
563                 $fp.residue(self.0) == u128::try_from(*rhs).unwrap()
564             }
565         }
567         impl<'a> TryFrom<&'a [u8]> for $elem {
568             type Error = FieldError;
570             fn try_from(bytes: &[u8]) -> Result<Self, FieldError> {
571                 Self::try_from_bytes(bytes, u128::MAX)
572             }
573         }
575         impl From<$elem> for [u8; $elem::ENCODED_SIZE] {
576             fn from(elem: $elem) -> Self {
577                 let int = $fp.residue(elem.0);
578                 let mut slice = [0; $elem::ENCODED_SIZE];
579                 for i in 0..$elem::ENCODED_SIZE {
580                     slice[i] = ((int >> (i << 3)) & 0xff) as u8;
581                 }
582                 slice
583             }
584         }
586         impl From<$elem> for Vec<u8> {
587             fn from(elem: $elem) -> Self {
588                 <[u8; $elem::ENCODED_SIZE]>::from(elem).to_vec()
589             }
590         }
592         impl Display for $elem {
593             fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
594                 write!(f, "{}", $fp.residue(self.0))
595             }
596         }
598         impl Debug for $elem {
599             fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
600                 write!(f, "{}", $fp.residue(self.0))
601             }
602         }
604         // We provide custom [`serde::Serialize`] and [`serde::Deserialize`] implementations because
605         // the derived implementations would represent `FieldElement` values as the backing `u128`,
606         // which is not what we want because (1) we can be more efficient in all cases and (2) in
607         // some circumstances, [some serializers don't support `u128`](https://github.com/serde-rs/json/issues/625).
608         impl Serialize for $elem {
609             fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
610                 let bytes: [u8; $elem::ENCODED_SIZE] = (*self).into();
611                 serializer.serialize_bytes(&bytes)
612             }
613         }
615         impl<'de> Deserialize<'de> for $elem {
616             fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<$elem, D::Error> {
617                 deserializer.deserialize_bytes(FieldElementVisitor { phantom: PhantomData })
618             }
619         }
621         impl Encode for $elem {
622             fn encode(&self, bytes: &mut Vec<u8>) {
623                 let slice = <[u8; $elem::ENCODED_SIZE]>::from(*self);
624                 bytes.extend_from_slice(&slice);
625             }
627             fn encoded_len(&self) -> Option<usize> {
628                 Some(Self::ENCODED_SIZE)
629             }
630         }
632         impl Decode for $elem {
633             fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
634                 let mut value = [0u8; $elem::ENCODED_SIZE];
635                 bytes.read_exact(&mut value)?;
636                 $elem::try_from_bytes(&value, u128::MAX).map_err(|e| {
637                     CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>)
638                 })
639             }
640         }
642         impl FieldElement for $elem {
643             const ENCODED_SIZE: usize = $encoding_size;
644             fn inv(&self) -> Self {
645                 // FieldParameters::inv() ultimately relies on mul(), and will always return a
646                 // value less than p.
647                 Self($fp.inv(self.0))
648             }
650             fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> {
651                 $elem::try_from_bytes(bytes, $fp.bit_mask)
652             }
654             fn zero() -> Self {
655                 Self(0)
656             }
658             fn one() -> Self {
659                 Self($fp.roots[0])
660             }
661         }
663         impl FieldElementWithInteger for $elem {
664             type Integer = $int;
665             type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;
666             type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error;
668             fn pow(&self, exp: Self::Integer) -> Self {
669                 // FieldParameters::pow() relies on mul(), and will always return a value less
670                 // than p.
671                 Self($fp.pow(self.0, u128::try_from(exp).unwrap()))
672             }
674             fn modulus() -> Self::Integer {
675                 $fp.p as $int
676             }
677         }
679         impl FftFriendlyFieldElement for $elem {
680             fn generator() -> Self {
681                 Self($fp.g)
682             }
684             fn generator_order() -> Self::Integer {
685                 1 << (Self::Integer::try_from($fp.num_roots).unwrap())
686             }
688             fn root(l: usize) -> Option<Self> {
689                 if l < min($fp.roots.len(), $fp.num_roots+1) {
690                     Some(Self($fp.roots[l]))
691                 } else {
692                     None
693                 }
694             }
695         }
696     };
699 make_field!(
700     /// Same as Field32, but encoded in little endian for compatibility with Prio v2.
701     FieldPrio2,
702     u32,
703     FP32,
704     4,
707 make_field!(
708     /// `GF(18446744069414584321)`, a 64-bit field.
709     Field64,
710     u64,
711     FP64,
712     8,
715 /// This nested module is an implementation detail to limit the scope of a module-wide
716 /// `allow(deprecated)` attribute. [`Field96`] is marked as deprecated, and deprecation warnings
717 /// must be silenced on multiple implementation blocks, and the macro invocation itself that
718 /// defines the struct and its implementation.
719 mod field96 {
720     #![allow(deprecated)]
722     use super::{
723         FftFriendlyFieldElement, FieldElement, FieldElementVisitor, FieldElementWithInteger,
724         FieldError,
725     };
726     use crate::{
727         codec::{CodecError, Decode, Encode},
728         fp::FP96,
729     };
730     use serde::{Deserialize, Deserializer, Serialize, Serializer};
731     use std::{
732         cmp::min,
733         fmt::{Debug, Display, Formatter},
734         hash::{Hash, Hasher},
735         io::{Cursor, Read},
736         marker::PhantomData,
737         ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
738     };
739     use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
741     make_field!(
742         #[deprecated]
743         /// `GF(79228148845226978974766202881)`, a 96-bit field.
744         ///
745         /// This is deprecated because it is not currently used by either Prio v2 or any VDAF.
746         Field96,
747         u128,
748         FP96,
749         12,
750     );
753 #[allow(deprecated)]
754 pub use field96::Field96;
756 make_field!(
757     /// `GF(340282366920938462946865773367900766209)`, a 128-bit field.
758     Field128,
759     u128,
760     FP128,
761     16,
764 /// Merge two vectors of fields by summing other_vector into accumulator.
766 /// # Errors
768 /// Fails if the two vectors do not have the same length.
769 pub(crate) fn merge_vector<F: FieldElement>(
770     accumulator: &mut [F],
771     other_vector: &[F],
772 ) -> Result<(), FieldError> {
773     if accumulator.len() != other_vector.len() {
774         return Err(FieldError::InputSizeMismatch);
775     }
776     for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) {
777         *a += *o;
778     }
780     Ok(())
783 /// Outputs an additive secret sharing of the input.
784 #[cfg(all(feature = "crypto-dependencies", test))]
785 pub(crate) fn split_vector<F: FieldElement>(
786     inp: &[F],
787     num_shares: usize,
788 ) -> Result<Vec<Vec<F>>, PrngError> {
789     if num_shares == 0 {
790         return Ok(vec![]);
791     }
793     let mut outp = Vec::with_capacity(num_shares);
794     outp.push(inp.to_vec());
796     for _ in 1..num_shares {
797         let share: Vec<F> = random_vector(inp.len())?;
798         for (x, y) in outp[0].iter_mut().zip(&share) {
799             *x -= *y;
800         }
801         outp.push(share);
802     }
804     Ok(outp)
807 /// Generate a vector of uniformly distributed random field elements.
808 #[cfg(feature = "crypto-dependencies")]
809 pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> {
810     Ok(Prng::new()?.take(len).collect())
813 /// `encode_fieldvec` serializes a type that is equivalent to a vector of field elements.
814 #[inline(always)]
815 pub(crate) fn encode_fieldvec<F: FieldElement, T: AsRef<[F]>>(val: T, bytes: &mut Vec<u8>) {
816     for elem in val.as_ref() {
817         bytes.append(&mut (*elem).into());
818     }
821 /// `decode_fieldvec` deserializes some number of field elements from a cursor, and advances the
822 /// cursor's position.
823 pub(crate) fn decode_fieldvec<F: FieldElement>(
824     count: usize,
825     input: &mut Cursor<&[u8]>,
826 ) -> Result<Vec<F>, CodecError> {
827     let mut vec = Vec::with_capacity(count);
828     let mut buffer = [0u8; 64];
829     assert!(
830         buffer.len() >= F::ENCODED_SIZE,
831         "field is too big for buffer"
832     );
833     for _ in 0..count {
834         input.read_exact(&mut buffer[..F::ENCODED_SIZE])?;
835         vec.push(
836             F::try_from(&buffer[..F::ENCODED_SIZE]).map_err(|e| CodecError::Other(Box::new(e)))?,
837         );
838     }
839     Ok(vec)
842 impl<F> CoinToss for F
843 where
844     F: FieldElement,
846     fn sample<S>(seed_stream: &mut S) -> Self
847     where
848         S: SeedStream,
849     {
850         // This is analogous to `Prng::get()`, but does not make use of a persistent buffer of
851         // `SeedStream` output.
852         let mut buffer = [0u8; 64];
853         assert!(
854             buffer.len() >= F::ENCODED_SIZE,
855             "field is too big for buffer"
856         );
857         loop {
858             seed_stream.fill(&mut buffer[..F::ENCODED_SIZE]);
859             match Self::try_from_random(&buffer[..F::ENCODED_SIZE]) {
860                 Ok(x) => return x,
861                 Err(FieldError::ModulusOverflow) => continue,
862                 Err(err) => panic!("unexpected error: {err}"),
863             }
864         }
865     }
868 #[cfg(test)]
869 pub(crate) mod test_utils {
870     use super::{FieldElement, FieldElementWithInteger};
871     use crate::{codec::CodecError, field::FieldError, prng::Prng};
872     use assert_matches::assert_matches;
873     use std::{
874         collections::hash_map::DefaultHasher,
875         convert::{TryFrom, TryInto},
876         fmt::Debug,
877         hash::{Hash, Hasher},
878         io::Cursor,
879         ops::{Add, BitAnd, Div, Shl, Shr, Sub},
880     };
882     /// A test-only copy of `FieldElementWithInteger`.
883     ///
884     /// This trait is only used in tests, and it is implemented on some fields that do not have
885     /// `FieldElementWithInteger` implementations. This separate trait is used in order to avoid
886     /// affecting trait resolution with conditional compilation. Additionally, this trait only
887     /// requires the `Integer` associated type satisfy `Clone`, not `Copy`, so that it may be used
888     /// with arbitrary precision integer implementations.
889     pub(crate) trait TestFieldElementWithInteger:
890         FieldElement + From<Self::Integer>
891     {
892         type IntegerTryFromError: std::error::Error;
893         type TryIntoU64Error: std::error::Error;
894         type Integer: Clone
895             + Debug
896             + Eq
897             + Ord
898             + BitAnd<Output = Self::Integer>
899             + Div<Output = Self::Integer>
900             + Shl<usize, Output = Self::Integer>
901             + Shr<usize, Output = Self::Integer>
902             + Add<Output = Self::Integer>
903             + Sub<Output = Self::Integer>
904             + From<Self>
905             + TryFrom<usize, Error = Self::IntegerTryFromError>
906             + TryInto<u64, Error = Self::TryIntoU64Error>;
908         fn pow(&self, exp: Self::Integer) -> Self;
910         fn modulus() -> Self::Integer;
911     }
913     impl<F> TestFieldElementWithInteger for F
914     where
915         F: FieldElementWithInteger,
916     {
917         type IntegerTryFromError = <F as FieldElementWithInteger>::IntegerTryFromError;
918         type TryIntoU64Error = <F as FieldElementWithInteger>::TryIntoU64Error;
919         type Integer = <F as FieldElementWithInteger>::Integer;
921         fn pow(&self, exp: Self::Integer) -> Self {
922             <F as FieldElementWithInteger>::pow(self, exp)
923         }
925         fn modulus() -> Self::Integer {
926             <F as FieldElementWithInteger>::modulus()
927         }
928     }
930     pub(crate) fn field_element_test_common<F: TestFieldElementWithInteger>() {
931         let mut prng: Prng<F, _> = Prng::new().unwrap();
932         let int_modulus = F::modulus();
933         let int_one = F::Integer::try_from(1).unwrap();
934         let zero = F::zero();
935         let one = F::one();
936         let two = F::from(F::Integer::try_from(2).unwrap());
937         let four = F::from(F::Integer::try_from(4).unwrap());
939         // add
940         assert_eq!(F::from(int_modulus.clone() - int_one.clone()) + one, zero);
941         assert_eq!(one + one, two);
942         assert_eq!(two + F::from(int_modulus.clone()), two);
944         // add w/ assignment
945         let mut a = prng.get();
946         let b = prng.get();
947         let c = a + b;
948         a += b;
949         assert_eq!(a, c);
951         // sub
952         assert_eq!(zero - one, F::from(int_modulus.clone() - int_one.clone()));
953         #[allow(clippy::eq_op)]
954         {
955             assert_eq!(one - one, zero);
956         }
957         assert_eq!(one + (-one), zero);
958         assert_eq!(two - F::from(int_modulus.clone()), two);
959         assert_eq!(one - F::from(int_modulus.clone() - int_one.clone()), two);
961         // sub w/ assignment
962         let mut a = prng.get();
963         let b = prng.get();
964         let c = a - b;
965         a -= b;
966         assert_eq!(a, c);
968         // add + sub
969         for _ in 0..100 {
970             let f = prng.get();
971             let g = prng.get();
972             assert_eq!(f + g - f - g, zero);
973             assert_eq!(f + g - g, f);
974             assert_eq!(f + g - f, g);
975         }
977         // mul
978         assert_eq!(two * two, four);
979         assert_eq!(two * one, two);
980         assert_eq!(two * zero, zero);
981         assert_eq!(one * F::from(int_modulus.clone()), zero);
983         // mul w/ assignment
984         let mut a = prng.get();
985         let b = prng.get();
986         let c = a * b;
987         a *= b;
988         assert_eq!(a, c);
990         // integer conversion
991         assert_eq!(F::Integer::from(zero), F::Integer::try_from(0).unwrap());
992         assert_eq!(F::Integer::from(one), F::Integer::try_from(1).unwrap());
993         assert_eq!(F::Integer::from(two), F::Integer::try_from(2).unwrap());
994         assert_eq!(F::Integer::from(four), F::Integer::try_from(4).unwrap());
996         // serialization
997         let test_inputs = vec![
998             zero,
999             one,
1000             prng.get(),
1001             F::from(int_modulus.clone() - int_one.clone()),
1002         ];
1003         for want in test_inputs.iter() {
1004             let mut bytes = vec![];
1005             want.encode(&mut bytes);
1007             assert_eq!(bytes.len(), F::ENCODED_SIZE);
1008             assert_eq!(want.encoded_len().unwrap(), F::ENCODED_SIZE);
1010             let got = F::get_decoded(&bytes).unwrap();
1011             assert_eq!(got, *want);
1012         }
1014         let serialized_vec = F::slice_into_byte_vec(&test_inputs);
1015         let deserialized = F::byte_slice_into_vec(&serialized_vec).unwrap();
1016         assert_eq!(deserialized, test_inputs);
1018         let test_input = prng.get();
1019         let json = serde_json::to_string(&test_input).unwrap();
1020         let deserialized = serde_json::from_str::<F>(&json).unwrap();
1021         assert_eq!(deserialized, test_input);
1023         let value = serde_json::from_str::<serde_json::Value>(&json).unwrap();
1024         let array = value.as_array().unwrap();
1025         for element in array {
1026             element.as_u64().unwrap();
1027         }
1029         let err = F::byte_slice_into_vec(&[0]).unwrap_err();
1030         assert_matches!(err, FieldError::ShortRead);
1032         let err = F::byte_slice_into_vec(&vec![0xffu8; F::ENCODED_SIZE]).unwrap_err();
1033         assert_matches!(err, FieldError::Codec(CodecError::Other(err)) => {
1034             assert_matches!(err.downcast_ref::<FieldError>(), Some(FieldError::ModulusOverflow));
1035         });
1037         let insufficient = vec![0u8; F::ENCODED_SIZE - 1];
1038         let err = F::try_from(insufficient.as_ref()).unwrap_err();
1039         assert_matches!(err, FieldError::ShortRead);
1040         let err = F::decode(&mut Cursor::new(&insufficient)).unwrap_err();
1041         assert_matches!(err, CodecError::Io(_));
1043         let err = F::decode(&mut Cursor::new(&vec![0xffu8; F::ENCODED_SIZE])).unwrap_err();
1044         assert_matches!(err, CodecError::Other(err) => {
1045             assert_matches!(err.downcast_ref::<FieldError>(), Some(FieldError::ModulusOverflow));
1046         });
1048         // equality and hash: Generate many elements, confirm they are not equal, and confirm
1049         // various products that should be equal have the same hash. Three is chosen as a generator
1050         // here because it happens to generate fairly large subgroups of (Z/pZ)* for all four
1051         // primes.
1052         let three = F::from(F::Integer::try_from(3).unwrap());
1053         let mut powers_of_three = Vec::with_capacity(500);
1054         let mut power = one;
1055         for _ in 0..500 {
1056             powers_of_three.push(power);
1057             power *= three;
1058         }
1059         // Check all these elements are mutually not equal.
1060         for i in 0..powers_of_three.len() {
1061             let first = &powers_of_three[i];
1062             for second in &powers_of_three[0..i] {
1063                 assert_ne!(first, second);
1064             }
1065         }
1067         // Construct an element from a number that needs to be reduced, and test comparisons on it,
1068         // confirming that it is reduced correctly.
1069         let p = F::from(int_modulus.clone());
1070         assert_eq!(p, zero);
1071         let p_plus_one = F::from(int_modulus + int_one);
1072         assert_eq!(p_plus_one, one);
1073     }
1075     pub(super) fn hash_helper<H: Hash>(input: H) -> u64 {
1076         let mut hasher = DefaultHasher::new();
1077         input.hash(&mut hasher);
1078         hasher.finish()
1079     }
1082 #[cfg(test)]
1083 mod tests {
1084     use super::*;
1085     use crate::field::test_utils::{field_element_test_common, hash_helper};
1086     use crate::fp::MAX_ROOTS;
1087     use crate::prng::Prng;
1088     use assert_matches::assert_matches;
1090     #[test]
1091     fn test_accumulate() {
1092         let mut lhs = vec![FieldPrio2(1); 10];
1093         let rhs = vec![FieldPrio2(2); 10];
1095         merge_vector(&mut lhs, &rhs).unwrap();
1097         lhs.iter().for_each(|f| assert_eq!(*f, FieldPrio2(3)));
1098         rhs.iter().for_each(|f| assert_eq!(*f, FieldPrio2(2)));
1100         let wrong_len = vec![FieldPrio2::zero(); 9];
1101         let result = merge_vector(&mut lhs, &wrong_len);
1102         assert_matches!(result, Err(FieldError::InputSizeMismatch));
1103     }
1105     fn field_element_test<F: FftFriendlyFieldElement + Hash>() {
1106         field_element_test_common::<F>();
1108         let mut prng: Prng<F, _> = Prng::new().unwrap();
1109         let int_modulus = F::modulus();
1110         let int_one = F::Integer::try_from(1).unwrap();
1111         let zero = F::zero();
1112         let one = F::one();
1113         let two = F::from(F::Integer::try_from(2).unwrap());
1114         let four = F::from(F::Integer::try_from(4).unwrap());
1116         // div
1117         assert_eq!(four / two, two);
1118         #[allow(clippy::eq_op)]
1119         {
1120             assert_eq!(two / two, one);
1121         }
1122         assert_eq!(zero / two, zero);
1123         assert_eq!(two / zero, zero); // Undefined behavior
1124         assert_eq!(zero.inv(), zero); // Undefined behavior
1126         // div w/ assignment
1127         let mut a = prng.get();
1128         let b = prng.get();
1129         let c = a / b;
1130         a /= b;
1131         assert_eq!(a, c);
1132         assert_eq!(hash_helper(a), hash_helper(c));
1134         // mul + div
1135         for _ in 0..100 {
1136             let f = prng.get();
1137             if f == zero {
1138                 continue;
1139             }
1140             assert_eq!(f * f.inv(), one);
1141             assert_eq!(f.inv() * f, one);
1142         }
1144         // pow
1145         assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one);
1146         assert_eq!(two.pow(int_one), two);
1147         assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four);
1148         assert_eq!(two.pow(int_modulus - int_one), one);
1149         assert_eq!(two.pow(int_modulus), two);
1151         // roots
1152         let mut int_order = F::generator_order();
1153         for l in 0..MAX_ROOTS + 1 {
1154             assert_eq!(
1155                 F::generator().pow(int_order),
1156                 F::root(l).unwrap(),
1157                 "failure for F::root({l})"
1158             );
1159             int_order = int_order >> 1;
1160         }
1162         // formatting
1163         assert_eq!(format!("{zero}"), "0");
1164         assert_eq!(format!("{one}"), "1");
1165         assert_eq!(format!("{zero:?}"), "0");
1166         assert_eq!(format!("{one:?}"), "1");
1168         let three = F::from(F::Integer::try_from(3).unwrap());
1169         let mut powers_of_three = Vec::with_capacity(500);
1170         let mut power = one;
1171         for _ in 0..500 {
1172             powers_of_three.push(power);
1173             power *= three;
1174         }
1176         // Check that 3^i is the same whether it's calculated with pow() or repeated
1177         // multiplication, with both equality and hash equality.
1178         for (i, power) in powers_of_three.iter().enumerate() {
1179             let result = three.pow(F::Integer::try_from(i).unwrap());
1180             assert_eq!(result, *power);
1181             let hash1 = hash_helper(power);
1182             let hash2 = hash_helper(result);
1183             assert_eq!(hash1, hash2);
1184         }
1186         // Check that 3^n = (3^i)*(3^(n-i)), via both equality and hash equality.
1187         let expected_product = powers_of_three[powers_of_three.len() - 1];
1188         let expected_hash = hash_helper(expected_product);
1189         for i in 0..powers_of_three.len() {
1190             let a = powers_of_three[i];
1191             let b = powers_of_three[powers_of_three.len() - 1 - i];
1192             let product = a * b;
1193             assert_eq!(product, expected_product);
1194             assert_eq!(hash_helper(product), expected_hash);
1195         }
1196     }
1198     #[test]
1199     fn test_field_prio2() {
1200         field_element_test::<FieldPrio2>();
1201     }
1203     #[test]
1204     fn test_field64() {
1205         field_element_test::<Field64>();
1206     }
1208     #[test]
1209     fn test_field96() {
1210         #[allow(deprecated)]
1211         field_element_test::<Field96>();
1212     }
1214     #[test]
1215     fn test_field128() {
1216         field_element_test::<Field128>();
1217     }
1219     #[test]
1220     fn test_encode_into_bitvector() {
1221         let zero = Field128::zero();
1222         let one = Field128::one();
1223         let zero_enc = Field128::encode_into_bitvector_representation(&0, 4).unwrap();
1224         let one_enc = Field128::encode_into_bitvector_representation(&1, 4).unwrap();
1225         let fifteen_enc = Field128::encode_into_bitvector_representation(&15, 4).unwrap();
1226         assert_eq!(zero_enc, [zero; 4]);
1227         assert_eq!(one_enc, [one, zero, zero, zero]);
1228         assert_eq!(fifteen_enc, [one; 4]);
1229         Field128::encode_into_bitvector_representation(&16, 4).unwrap_err();
1230     }
1232     #[test]
1233     fn test_fill_bitvector() {
1234         let zero = Field128::zero();
1235         let one = Field128::one();
1236         let mut output: Vec<Field128> = vec![zero; 6];
1237         Field128::fill_with_bitvector_representation(&9, &mut output[1..5]).unwrap();
1238         assert_eq!(output, [zero, one, zero, zero, one, zero]);
1239         Field128::fill_with_bitvector_representation(&16, &mut output[1..5]).unwrap_err();
1240     }