2 * Implementation of OpenSSH 9.x's hybrid key exchange protocol
3 * sntrup761x25519-sha512@openssh.com .
5 * This consists of the 'Streamlined NTRU Prime' quantum-resistant
6 * cryptosystem, run in parallel with ordinary Curve25519 to generate
7 * a shared secret combining the output of both systems.
9 * (Hence, even if you don't trust this newfangled NTRU Prime thing at
10 * all, it's at least no _less_ secure than the kex you were using
13 * References for the NTRU Prime cryptosystem, up to and including
14 * binary encodings of public and private keys and the exact preimages
15 * of the hashes used in key exchange:
17 * https://ntruprime.cr.yp.to/
18 * https://ntruprime.cr.yp.to/nist/ntruprime-20201007.pdf
20 * The SSH protocol layer is not documented anywhere I could find (as
21 * of 2022-04-15, not even in OpenSSH's PROTOCOL.* files). I had to
22 * read OpenSSH's source code to find out how it worked, and the
23 * answer is as follows:
25 * This hybrid kex method is treated for SSH purposes as a form of
26 * elliptic-curve Diffie-Hellman, and shares the same SSH message
27 * sequence: client sends SSH2_MSG_KEX_ECDH_INIT containing its public
28 * half, server responds with SSH2_MSG_KEX_ECDH_REPLY containing _its_
29 * public half plus the host key and signature on the shared secret.
31 * (This is a bit of a fudge, because unlike actual ECDH, this kex
32 * method is asymmetric: one side sends a public key, and the other
33 * side encrypts something with it and sends the ciphertext back. So
34 * while the normal ECDH implementations can compute the two sides
35 * independently in parallel, this system reusing the same messages
36 * has to be serial. But the order of the messages _is_ firmly
37 * specified in SSH ECDH, so it works anyway.)
39 * For this kex method, SSH2_MSG_KEX_ECDH_INIT still contains a single
40 * SSH 'string', which consists of the concatenation of a Streamlined
41 * NTRU Prime public key with the Curve25519 public value. (Both of
42 * these have fixed length in bytes, so there's no ambiguity in the
45 * SSH2_MSG_KEX_ECDH_REPLY is mostly the same as usual. The only
46 * string in the packet that varies is the second one, which would
47 * normally contain the server's public elliptic curve point. Instead,
48 * it now contains the concatenation of
50 * - a Streamlined NTRU Prime ciphertext
51 * - the 'confirmation hash' specified in ntruprime-20201007.pdf,
52 * hashing the plaintext of that ciphertext together with the
54 * - the Curve25519 public point as usual.
56 * Again, all three of those elements have fixed lengths.
58 * The client decrypts the ciphertext, checks the confirmation hash,
59 * and if successful, generates the 'session hash' specified in
60 * ntruprime-20201007.pdf, which is 32 bytes long and is the ultimate
61 * output of the Streamlined NTRU Prime key exchange.
63 * The output of the hybrid kex method as a whole is an SSH 'string'
64 * of length 64 containing the SHA-512 hash of the concatenation of
66 * - the Streamlined NTRU Prime session hash (32 bytes)
67 * - the Curve25519 shared secret (32 bytes).
69 * That string is included directly into the SSH exchange hash and key
70 * derivation hashes, in place of the mpint that comes out of most
83 /* ----------------------------------------------------------------------
84 * Preliminaries: we're going to need to do modular arithmetic on
85 * small values (considerably smaller than 2^16), and we need to do it
86 * without using integer division which might not be time-safe.
88 * The strategy for this is the same as I used in
89 * mp_mod_known_integer: see there for the proofs. The basic idea is
90 * that we precompute the reciprocal of our modulus as a fixed-point
91 * number, and use that to get an approximate quotient which we
92 * subtract off. For these integer sizes, precomputing a fixed-point
93 * reciprocal of the form (2^48 / modulus) leaves us at most off by 1
94 * in the quotient, so there's a single (time-safe) trial subtraction
97 * (It's possible that some speed could be gained by not reducing
98 * fully at every step. But then you'd have to carefully identify all
99 * the places in the algorithm where things are compared to zero. This
100 * was the easiest way to get it all working in the first place.)
/*
 * Precompute the fixed-point reciprocal of the modulus q, i.e. the
 * integer part of 2^48 / q. The result is what reduce() and friends
 * expect as their 'qrecip' argument. q must be nonzero.
 */
static uint64_t reciprocal_for_reduction(uint16_t q)
{
    const uint64_t two_to_the_48 = (uint64_t)1 << 48;
    uint64_t recip = two_to_the_48 / q;
    return recip;
}
/*
 * Reduce x mod q without using integer division, assuming
 * qrecip == reciprocal_for_reduction(q).
 *
 * Multiplying by the precomputed reciprocal and shifting right by 48
 * gives an approximate quotient that is at most off by 1, so a single
 * branch-free trial subtraction of q finishes the job.
 */
static uint16_t reduce(uint32_t x, uint16_t q, uint64_t qrecip)
{
    uint64_t unshifted_quot = x * qrecip;
    uint64_t quot = unshifted_quot >> 48;
    uint16_t reduced = x - quot * q;

    /*
     * (q-1 - reduced) is computed in int and goes negative exactly
     * when one more subtraction of q is needed. Convert it to an
     * unsigned type before shifting: right-shifting a negative signed
     * value is implementation-defined, whereas the conversion to
     * uint32_t is fully defined and leaves bit 15 set in that case.
     */
    uint32_t diff = (uint32_t)(q - 1 - reduced);
    reduced -= q * (1 & (diff >> 15));

    return reduced;
}
/*
 * Reduce x mod q in the same division-free way as reduce(), but also
 * return the quotient through *quot_out.
 *
 * qrecip must equal reciprocal_for_reduction(q). The return value is
 * the remainder; *quot_out receives the approximate quotient plus the
 * 0-or-1 correction applied by the trial subtraction.
 */
static uint16_t reduce_with_quot(uint32_t x, uint32_t *quot_out,
                                 uint16_t q, uint64_t qrecip)
{
    uint64_t unshifted_quot = x * qrecip;
    uint64_t quot = unshifted_quot >> 48;
    uint16_t reduced = x - quot * q;

    /*
     * 1 if a further subtraction of q is needed, else 0. The
     * difference is converted to unsigned before shifting, because
     * right-shifting a negative signed value is implementation-defined.
     */
    uint64_t extraquot = 1 & ((uint32_t)(q - 1 - reduced) >> 15);

    reduced -= extraquot * q;
    *quot_out = quot + extraquot;
    return reduced;
}
/*
 * Invert x mod q, assuming x is nonzero. qrecip must equal
 * reciprocal_for_reduction(q).
 *
 * Uses Fermat inversion: x^(q-1) == 1, so x^(q-2) is the inverse.
 * For time-safety no check is made for x == 0; in that case the
 * result is simply 0.
 */
static uint16_t invert(uint16_t x, uint16_t q, uint64_t qrecip)
{
    uint32_t power = x;      /* x^(2^k) for the exponent bit being examined */
    uint32_t result = 1;     /* product of the powers selected so far */
    uint32_t remaining = q-2; /* exponent bits not yet consumed */

    /* Square-and-multiply, scanning the exponent from its low bit up.
     * The exponent depends only on q, not on the value being inverted. */
    for (uint32_t mask = 1 ;; mask <<= 1) {
        if (remaining & mask) {
            result = reduce(result * power, q, qrecip);
            remaining &= ~mask;
            if (remaining == 0)
                return result;
        }
        power = reduce(power * power, q, qrecip);
    }
}
/*
 * Time-safe test for zero: returns 1 if x == 0 and 0 otherwise,
 * without branching on x.
 */
static unsigned iszero(uint16_t x)
{
    /* Adding 0xFFFF carries into bit 16 for every x except 0 */
    unsigned carry = ((unsigned)x + 0xFFFFu) >> 16;
    return (carry ^ 1u) & 1u;
}
157 * Handy macros to cut down on all those extra function parameters. In
158 * the common case where a function is working mod the same modulus
159 * throughout (and has called it q), you can just write 'SETUP;' at
160 * the top and then call REDUCE(...) and INVERT(...) without having to
161 * write out q and qrecip every time.
163 #define SETUP uint64_t qrecip = reciprocal_for_reduction(q)
164 #define REDUCE(x) reduce(x, q, qrecip)
165 #define INVERT(x) invert(x, q, qrecip)
167 /* ----------------------------------------------------------------------
168 * Quotient-ring functions.
170 * NTRU Prime works with two similar but different quotient rings:
172 * Z_q[x] / <x^p-x-1> where p,q are the prime parameters of the system
173 * Z_3[x] / <x^p-x-1> with the same p, but coefficients mod 3.
175 * The former is a field (every nonzero element is invertible),
176 * because the system parameters are chosen such that x^p-x-1 is
177 * invertible over Z_q. The latter is not a field (or not necessarily,
178 * and in particular, not for the value of p we use here).
180 * In these core functions, you pass in the modulus you want as the
181 * parameter q, which is either the 'real' q specified in the system
182 * parameters, or 3 if you're doing one of the mod-3 parts of the
187 * Multiply two elements of a quotient ring.
189 * 'a' and 'b' are arrays of exactly p coefficients, with constant
190 * term first. 'out' is an array the same size to write the inverse
193 void ntru_ring_multiply(uint16_t *out
, const uint16_t *a
, const uint16_t *b
,
194 unsigned p
, unsigned q
)
199 * Strategy: just compute the full product with 2p coefficients,
200 * and then reduce it mod x^p-x-1 by working downwards from the
201 * top coefficient replacing x^{p+k} with (x+1)x^k for k = ...,1,0.
203 * Possibly some speed could be gained here by doing the recursive
204 * Karatsuba optimisation for the initial multiplication? But I
207 uint32_t *unreduced
= snewn(2*p
, uint32_t);
208 for (unsigned i
= 0; i
< 2*p
; i
++)
210 for (unsigned i
= 0; i
< p
; i
++)
211 for (unsigned j
= 0; j
< p
; j
++)
212 unreduced
[i
+j
] = REDUCE(unreduced
[i
+j
] + a
[i
] * b
[j
]);
214 for (unsigned i
= 2*p
- 1; i
>= p
; i
--) {
215 unreduced
[i
-p
] += unreduced
[i
];
216 unreduced
[i
-p
+1] += unreduced
[i
];
220 for (unsigned i
= 0; i
< p
; i
++)
221 out
[i
] = REDUCE(unreduced
[i
]);
223 smemclr(unreduced
, 2*p
* sizeof(*unreduced
));
228 * Invert an element of the quotient ring.
230 * 'in' is an array of exactly p coefficients, with constant term
231 * first. 'out' is an array the same size to write the inverse into.
233 * Method: essentially Stein's gcd algorithm, taking the gcd of the
234 * input (regarded as an element of Z_q[x] proper) and x^p-x-1. Given
235 * two polynomials over a field which are not both divisible by x, you
236 * can find their gcd by iterating the following procedure:
238 * - if one is divisible by x, divide off x
239 * - otherwise, subtract from the higher-degree one whatever scalar
240 * multiple of the lower-degree one will make it divisible by x,
241 * and _then_ divide off x
243 * Neither of these types of step changes the gcd of the two
246 * Each step reduces the sum of the two polynomials' degree by at
247 * least one, as long as at least one of the degrees is positive.
248 * (Maybe more than one if all the stars align in the second case, if
249 * the subtraction cancels the leading term as well as the constant
250 * term.) So in at most deg A + deg B steps, we must have reached the
251 * situation where both polys are constants; in one more step after
252 * that, one of them will be zero; and in one step after _that_, the
253 * zero one will reliably be the one we're dividing by x. Or rather,
254 * that's what happens in the case where A,B are coprime; if not, then
255 * one hits zero while the other is still nonzero.
257 * In a normal gcd algorithm, you'd track a linear combination of the
258 * two original polynomials that yields each working value, and end up
259 * with a linear combination of the inputs that yields the gcd. In
260 * this algorithm, the 'divide off x' step makes that awkward - but we
261 * can solve that by instead multiplying by the inverse of x in the
262 * ring that we want our answer to be valid in! And since the modulus
263 * polynomial of the ring is x^p-x-1, the inverse of x is easy to
264 * calculate, because it's always just x^{p-1} - 1, which is also very
265 * easy to multiply by.
267 unsigned ntru_ring_invert(uint16_t *out
, const uint16_t *in
,
268 unsigned p
, unsigned q
)
272 /* Size of the polynomial arrays we'll work with */
273 const size_t SIZE
= p
+1;
275 /* Number of steps of the algorithm is the max possible value of
276 * deg A + deg B + 2, where deg A <= p-1 and deg B = p */
277 const size_t STEPS
= 2*p
+ 1;
279 /* Our two working polynomials */
280 uint16_t *A
= snewn(SIZE
, uint16_t);
281 uint16_t *B
= snewn(SIZE
, uint16_t);
283 /* Coefficient of the input value in each one */
284 uint16_t *Ac
= snewn(SIZE
, uint16_t);
285 uint16_t *Bc
= snewn(SIZE
, uint16_t);
287 /* Initialise A to the input, and Ac correspondingly to 1 */
288 memcpy(A
, in
, p
*sizeof(uint16_t));
291 for (size_t i
= 1; i
< SIZE
; i
++)
294 /* Initialise B to the quotient polynomial of the ring, x^p-x-1
297 for (size_t i
= 2; i
< p
; i
++)
300 for (size_t i
= 0; i
< SIZE
; i
++)
303 /* Run the gcd-finding algorithm. */
304 for (size_t i
= 0; i
< STEPS
; i
++) {
306 * First swap round so that A is the one we'll be dividing by x.
308 * In the case where one of the two polys has a zero constant
309 * term, it's that one. In the other case, it's the one of
310 * smaller degree. We must compute both, and choose between
311 * them in a side-channel-safe way.
313 unsigned x_divides_A
= iszero(A
[0]);
314 unsigned x_divides_B
= iszero(B
[0]);
315 unsigned B_is_bigger
= 0;
317 unsigned not_seen_top_term_of_A
= 1, not_seen_top_term_of_B
= 1;
318 for (size_t j
= SIZE
; j
-- > 0 ;) {
319 not_seen_top_term_of_A
&= iszero(A
[j
]);
320 not_seen_top_term_of_B
&= iszero(B
[j
]);
321 B_is_bigger
|= (~not_seen_top_term_of_B
&
322 not_seen_top_term_of_A
);
325 unsigned need_swap
= x_divides_B
| (~x_divides_A
& B_is_bigger
);
326 uint16_t swap_mask
= -need_swap
;
327 for (size_t j
= 0; j
< SIZE
; j
++) {
328 uint16_t diff
= (A
[j
] ^ B
[j
]) & swap_mask
;
332 for (size_t j
= 0; j
< SIZE
; j
++) {
333 uint16_t diff
= (Ac
[j
] ^ Bc
[j
]) & swap_mask
;
339 * Replace A with a linear combination of both A and B that
340 * has constant term zero, which we do by calculating
342 * (constant term of B) * A - (constant term of A) * B
344 * In one of the two cases, A's constant term is already zero,
345 * so the coefficient of B will be zero too; hence, this will
346 * do nothing useful (it will merely scale A by some scalar
347 * value), but it will take the same length of time as doing
348 * something, which is just what we want.
350 uint16_t Amult
= B
[0], Bmult
= q
- A
[0];
351 for (size_t j
= 0; j
< SIZE
; j
++)
352 A
[j
] = REDUCE(Amult
* A
[j
] + Bmult
* B
[j
]);
353 /* And do the same transformation to Ac */
354 for (size_t j
= 0; j
< SIZE
; j
++)
355 Ac
[j
] = REDUCE(Amult
* Ac
[j
] + Bmult
* Bc
[j
]);
358 * Now divide A by x, and compensate by multiplying Ac by
359 * x^{p-1}-1 mod x^p-x-1.
361 * That multiplication is particularly easy, precisely because
362 * x^{p-1}-1 is the multiplicative inverse of x! Each x^n term
363 * for n>0 just moves down to the x^{n-1} term, and only the
364 * constant term has to be dealt with in an interesting way.
366 for (size_t j
= 1; j
< SIZE
; j
++)
369 uint16_t Ac0
= Ac
[0];
370 for (size_t j
= 1; j
< p
; j
++)
373 Ac
[0] = REDUCE(Ac
[0] + q
- Ac0
);
377 * Now we expect that A is 0, and B is a constant. If so, then
378 * they are coprime, and we're going to return success. If not,
379 * they have a common factor.
381 unsigned success
= iszero(A
[0]) & (1 ^ iszero(B
[0]));
382 for (size_t j
= 1; j
< SIZE
; j
++)
383 success
&= iszero(A
[j
]) & iszero(B
[j
]);
386 * So we're going to return Bc, but first, scale it by the
387 * multiplicative inverse of the constant we ended up with in
390 uint16_t scale
= INVERT(B
[0]);
391 for (size_t i
= 0; i
< p
; i
++)
392 out
[i
] = REDUCE(scale
* Bc
[i
]);
394 smemclr(A
, SIZE
* sizeof(*A
));
396 smemclr(B
, SIZE
* sizeof(*B
));
398 smemclr(Ac
, SIZE
* sizeof(*Ac
));
400 smemclr(Bc
, SIZE
* sizeof(*Bc
));
/*
 * Given an array of p values mod q, take the minimum-absolute-value
 * representative of each one and reduce that mod 3.
 *
 * Output values are 0, 1 and 0xFFFF, the last representing -1.
 *
 * (That is not the usual 'minimal non-negative residue' form, but it
 * lets ntru_round3 reuse this function. ntru_normalise converts back
 * to the usual representation.)
 */
void ntru_mod3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q)
{
    const uint64_t qrecip = reciprocal_for_reduction(q);
    const uint64_t recip3 = reciprocal_for_reduction(3);

    /* Adding 'bias' before reducing mod q shifts the centred range of
     * representatives up to start at 0. */
    const unsigned bias = q/2;

    /* Chosen so that (biased value + adjust) mod 3 is one more than
     * the centred value mod 3, giving a residue in {0,1,2} that maps
     * to {-1,0,+1} by subtracting 1. */
    const uint16_t adjust = 3 - reduce(bias-1, 3, recip3);

    for (unsigned i = 0; i < p; i++) {
        uint16_t biased = reduce(in[i] + bias, q, qrecip);
        uint16_t residue = reduce(biased + adjust, 3, recip3);
        out[i] = (uint16_t)(residue - 1);
    }
}
434 * Given an array of values mod q, round each one to the nearest
435 * multiple of 3 to its minimum-absolute-value representative.
437 * Output values are signed integers coerced to uint16_t, so again,
438 * use ntru_normalise afterwards to put them back to normal.
440 void ntru_round3(uint16_t *out
, const uint16_t *in
, unsigned p
, unsigned q
)
444 ntru_mod3(out
, in
, p
, q
);
445 for (unsigned i
= 0; i
< p
; i
++)
446 out
[i
] = REDUCE(in
[i
] + bias
) - bias
- out
[i
];
/*
 * Given an array of p signed integers coerced to uint16_t, in the
 * range [-q/2,+q/2], convert each back to a value mod q by adding q
 * to the negative ones.
 */
static void ntru_normalise(uint16_t *out, const uint16_t *in,
                           unsigned p, unsigned q)
{
    for (unsigned i = 0; i < p; i++) {
        uint16_t v = in[i];
        unsigned is_negative = v >> 15;   /* top bit of the 16-bit value */
        out[i] = v + q * is_negative;
    }
}
461 * Given an array of values mod q, add a constant to each one.
463 void ntru_bias(uint16_t *out
, const uint16_t *in
, unsigned bias
,
464 unsigned p
, unsigned q
)
467 for (unsigned i
= 0; i
< p
; i
++)
468 out
[i
] = REDUCE(in
[i
] + bias
);
472 * Given an array of values mod q, multiply each one by a constant.
474 void ntru_scale(uint16_t *out
, const uint16_t *in
, uint16_t scale
,
475 unsigned p
, unsigned q
)
478 for (unsigned i
= 0; i
< p
; i
++)
479 out
[i
] = REDUCE(in
[i
] * scale
);
/*
 * Given an array of p values mod 3, convert them to values mod q in a
 * way that maps -1,0,+1 to -1,0,+1: that is, 2 becomes q-1, and 0 and
 * 1 are left unchanged.
 */
static void ntru_expand(
    uint16_t *out, const uint16_t *in, unsigned p, unsigned q)
{
    /* Amount to add to 2 in order to reach q-1 */
    const unsigned lift = q-3;

    for (size_t i = 0; i < p; i++) {
        uint16_t residue = in[i];
        /* (residue >> 1) is 1 only for the input value 2 */
        out[i] = (uint16_t)(residue + (residue >> 1) * lift);
    }
}
497 /* ----------------------------------------------------------------------
498 * Implement the binary encoding from ntruprime-20201007.pdf, which is
499 * used to encode public keys and ciphertexts (though not plaintexts,
500 * which are done in a much simpler way).
502 * The general idea is that your encoder takes as input a list of
503 * small non-negative integers (r_i), and a sequence of limits (m_i)
504 * such that 0 <= r_i < m_i, and emits a sequence of bytes that encode
505 * all of these as tightly as reasonably possible.
507 * That's more general than is really needed, because in both the
508 * actual uses of this encoding, the input m_i are all the same! But
509 * the array of (r_i,m_i) pairs evolves during encoding, so they don't
510 * _stay_ all the same, so you still have to have all the generality.
512 * The encoding process makes a number of passes along the list of
513 * inputs. In each step, pairs of adjacent numbers are combined into
514 * one larger one by turning (r_i,m_i) and (r_{i+1},m_{i+1}) into the
515 * pair (r_i + m_i r_{i+1}, m_i m_{i+1}), i.e. so that the original
516 * numbers could be recovered by taking the quotient and remainder of
517 * the new r value by m_i. Then, if the new m_i is at least 2^14, we
518 * emit the low 8 bits of r_i to the output stream and reduce r_i and
519 * its limit correspondingly. So at the end of the pass, we've got
520 * half as many numbers still to encode, they're all still not too
521 * big, and we've emitted some amount of data into the output. Then do
522 * another pass, keep going until there's only one number left, and
523 * emit it little-endian.
525 * That's all very well, but how do you decode it again? DJB exhibits
526 * a pair of recursive functions that are supposed to be mutually
527 * inverse, but I didn't have any confidence that I'd be able to debug
528 * them sensibly if they turned out not to be (or rather, if I
529 * implemented one of them wrong). So I came up with my own strategy
532 * In my strategy, we start by processing just the (m_i) into an
533 * 'encoding schedule' consisting of a sequence of simple
534 * instructions. The instructions operate on a FIFO queue of numbers,
535 * initialised to the original (r_i). The three instruction types are:
537 * - 'COMBINE': consume two numbers a,b from the head of the queue,
538 * combine them by calculating a + m*b for some specified m, and
539 * push the result on the tail of the queue.
541 * - 'BYTE': divide the tail element of the queue by 2^8 and emit the
542 * low bits into the output stream.
544 * - 'COPY': pop a number from the head of the queue and push it
545 * straight back on the tail. (Used for handling the leftover
546 * element at the end of a pass if the input to the pass was a list
549 * So we effectively implement DJB's encoding process in simulation,
550 * and instead of actually processing a set of (r_i), we 'compile' the
551 * process into a sequence of instructions that can be handed just the
552 * (r_i) later and encode them in the right way. At the end of the
553 * instructions, the queue is expected to have been reduced to length
554 * 1 and contain the single integer 0.
556 * The nice thing about this system is that each of those three
557 * instructions is easy to reverse. So you can also use the same
558 * instructions for decoding: start with a queue containing 0, and
559 * process the instructions in reverse order and reverse sense. So
560 * BYTE means to _consume_ a byte from the encoded data (starting from
561 * the rightmost end) and use it to make a queue element bigger; and
562 * COMBINE run in reverse pops a single element from one end of the
563 * queue, divides it by m, and pushes the quotient and remainder on
566 * (So it's easy to debug, because the queue passes through the exact
567 * same sequence of states during decoding that it did during
568 * encoding, just in reverse order.)
570 * Also, the encoding schedule comes with information about the
571 * expected size of the encoded data, because you can find that out
572 * easily by just counting the BYTE commands.
577 * Command values appearing in the 'ops' array. ENC_COPY and
578 * ENC_BYTE are single values; values of the form
579 * (ENC_COMBINE_BASE + m) represent a COMBINE command with
582 ENC_COPY
, ENC_BYTE
, ENC_COMBINE_BASE
584 struct NTRUEncodeSchedule
{
586 * Object representing a compiled set of encoding instructions.
588 * 'nvals' is the number of r_i we expect to encode. 'nops' is the
589 * number of encoding commands in the 'ops' list; 'opsize' is the
590 * physical size of the array, used during construction.
592 * 'endpos' is used to avoid a last-minute faff during decoding.
593 * We implement our FIFO of integers as a ring buffer of size
594 * 'nvals'. Encoding cycles round it some number of times, and the
595 * final 0 element ends up at some random location in the array.
596 * If we know _where_ the 0 ends up during encoding, we can put
597 * the initial 0 there at the start of decoding, and then when we
598 * finish reversing all the instructions, we'll end up with the
599 * output numbers already arranged at their correct positions, so
600 * that there's no need to rotate the array at the last minute.
602 size_t nvals
, endpos
, nops
, opsize
;
605 static inline void sched_append(NTRUEncodeSchedule
*sched
, uint16_t op
)
607 /* Helper function to append an operation to the schedule, and
609 sgrowarray(sched
->ops
, sched
->opsize
, sched
->nops
);
610 sched
->ops
[sched
->nops
++] = op
;
612 sched
->endpos
= (sched
->endpos
+ 1) % sched
->nvals
;
616 * Take in the list of limit values (m_i) and compute the encoding
619 NTRUEncodeSchedule
*ntru_encode_schedule(const uint16_t *ms_in
, size_t n
)
621 NTRUEncodeSchedule
*sched
= snew(NTRUEncodeSchedule
);
624 sched
->nops
= sched
->opsize
= 0;
630 * 'ms' is the list of (m_i) on input to the current pass.
631 * 'ms_new' is the list output from the current pass. After each
632 * pass we swap the arrays round.
634 uint32_t *ms
= snewn(n
, uint32_t);
635 uint32_t *msnew
= snewn(n
, uint32_t);
636 for (size_t i
= 0; i
< n
; i
++)
641 for (size_t i
= 0; i
< n
; i
+= 2) {
644 * Odd element at the end of the input list: just copy
645 * it unchanged to the output.
647 sched_append(sched
, ENC_COPY
);
648 msnew
[nnew
++] = ms
[i
];
653 * Normal case: consume two elements from the input list
656 uint32_t m1
= ms
[i
], m2
= ms
[i
+1], m
= m1
*m2
;
657 sched_append(sched
, ENC_COMBINE_BASE
+ m1
);
660 * And then, as long as the combined limit is big enough,
661 * emit an output byte from the bottom of it.
663 while (m
>= (1<<14)) {
664 sched_append(sched
, ENC_BYTE
);
669 * Whatever is left after that, we emit into the output
670 * list and append to the fifo.
676 * End of pass. The output list of (m_i) now becomes the input
686 * When that loop terminates, it's because there's exactly one
687 * number left to encode. (Or, technically, _at most_ one - but we
688 * don't support encoding a completely empty list in this
689 * implementation, because what would be the point?) That number
690 * is just emitted little-endian until its limit is 1 (meaning its
691 * only possible actual value is 0).
696 sched_append(sched
, ENC_BYTE
);
706 void ntru_encode_schedule_free(NTRUEncodeSchedule
*sched
)
713 * Calculate the output length of the encoded data in bytes.
715 size_t ntru_encode_schedule_length(NTRUEncodeSchedule
*sched
)
718 for (size_t i
= 0; i
< sched
->nops
; i
++)
719 if (sched
->ops
[i
] == ENC_BYTE
)
725 * Retrieve the number of items encoded. (Used by testcrypt.)
727 size_t ntru_encode_schedule_nvals(NTRUEncodeSchedule
*sched
)
733 * Actually encode a sequence of (r_i), emitting the output bytes to
734 * an arbitrary BinarySink.
736 void ntru_encode(NTRUEncodeSchedule
*sched
, const uint16_t *rs_in
,
739 size_t n
= sched
->nvals
;
740 uint32_t *rs
= snewn(n
, uint32_t);
741 for (size_t i
= 0; i
< n
; i
++)
745 * The head and tail pointers of the queue are both 'full'. That
746 * is, rs[head] is the first element actually in the queue, and
747 * rs[tail] is the last element.
749 * So you append to the queue by first advancing 'tail' and then
750 * writing to rs[tail], whereas you consume from the queue by
751 * first reading rs[head] and _then_ advancing 'head'.
753 * The more normal thing would be to make 'tail' point to the
754 * first empty slot instead of the last full one. But then you'd
755 * have to faff about with modular arithmetic to find the last
756 * full slot for the BYTE command, so in this case, it's easier to
757 * do it the less usual way.
759 size_t head
= 0, tail
= n
-1;
761 for (size_t i
= 0; i
< sched
->nops
; i
++) {
762 uint16_t op
= sched
->ops
[i
];
765 put_byte(bs
, rs
[tail
] & 0xFF);
769 uint32_t r
= rs
[head
];
770 head
= (head
+ 1) % n
;
771 tail
= (tail
+ 1) % n
;
776 uint32_t r1
= rs
[head
];
777 head
= (head
+ 1) % n
;
778 uint32_t r2
= rs
[head
];
779 head
= (head
+ 1) % n
;
780 tail
= (tail
+ 1) % n
;
781 rs
[tail
] = r1
+ (op
- ENC_COMBINE_BASE
) * r2
;
788 * Expect that we've ended up with a single zero in the queue, at
789 * exactly the position that the setup-time analysis predicted it.
791 assert(head
== sched
->endpos
);
792 assert(tail
== sched
->endpos
);
793 assert(rs
[head
] == 0);
795 smemclr(rs
, n
* sizeof(*rs
));
800 * Decode a ptrlen of binary data into a sequence of (r_i). The data
801 * is expected to be of exactly the right length (on pain of assertion
804 void ntru_decode(NTRUEncodeSchedule
*sched
, uint16_t *rs_out
, ptrlen data
)
806 size_t n
= sched
->nvals
;
807 const uint8_t *base
= (const uint8_t *)data
.ptr
;
808 const uint8_t *pos
= base
+ data
.len
;
811 * Initialise the queue to a single zero, at the 'endpos' position
812 * that will mean the final output is correctly aligned.
814 * 'head' and 'tail' have the same meanings as in encoding. So
815 * 'tail' is the location that BYTE modifies and COPY and COMBINE
816 * consume from, and 'head' is the location that COPY and COMBINE
817 * push on to. As in encoding, they both point at the extremal
818 * full slots in the array.
820 uint32_t *rs
= snewn(n
, uint32_t);
821 size_t head
= sched
->endpos
, tail
= head
;
824 for (size_t i
= sched
->nops
; i
-- > 0 ;) {
825 uint16_t op
= sched
->ops
[i
];
829 uint8_t byte
= *--pos
;
830 rs
[tail
] = (rs
[tail
] << 8) | byte
;
834 uint32_t r
= rs
[tail
];
835 tail
= (tail
+ n
- 1) % n
;
836 head
= (head
+ n
- 1) % n
;
841 uint32_t r
= rs
[tail
];
842 tail
= (tail
+ n
- 1) % n
;
844 uint32_t m
= op
- ENC_COMBINE_BASE
;
845 uint64_t mrecip
= reciprocal_for_reduction(m
);
848 r1
= reduce_with_quot(r
, &r2
, m
, mrecip
);
850 head
= (head
+ n
- 1) % n
;
852 head
= (head
+ n
- 1) % n
;
863 for (size_t i
= 0; i
< n
; i
++)
865 smemclr(rs
, n
* sizeof(*rs
));
869 /* ----------------------------------------------------------------------
870 * The actual public-key cryptosystem.
875 uint16_t *h
; /* public key */
876 uint16_t *f3
, *ginv
; /* private key */
877 uint16_t *rho
; /* for implicit rejection */
/*
 * Helper function to free an array of p uint16_t containing a ring
 * element. The array is wiped first, since some ring elements are
 * sensitive.
 */
static void ring_free(uint16_t *val, unsigned p)
{
    size_t nbytes = p * sizeof(*val);
    smemclr(val, nbytes);
    sfree(val);
}
888 void ntru_keypair_free(NTRUKeyPair
*keypair
)
890 ring_free(keypair
->h
, keypair
->p
);
891 ring_free(keypair
->f3
, keypair
->p
);
892 ring_free(keypair
->ginv
, keypair
->p
);
893 ring_free(keypair
->rho
, keypair
->p
);
897 /* Trivial accessors used by test programs. */
898 unsigned ntru_keypair_p(NTRUKeyPair
*keypair
) { return keypair
->p
; }
899 const uint16_t *ntru_pubkey(NTRUKeyPair
*keypair
) { return keypair
->h
; }
902 * Generate a value of the class DJB describes as 'Short': it consists
903 * of p terms that are all either 0 or +1 or -1, and exactly w of them
906 * Values of this kind are used for several purposes: part of the
907 * private key, a plaintext, and the 'rho' fake-plaintext value used
908 * for deliberately returning a duff but non-revealing session hash if
911 * -1 is represented as 2 in the output array. So if you want these
912 * numbers mod 3, then they come out already in the right form.
913 * Otherwise, use ntru_expand.
915 void ntru_gen_short(uint16_t *v
, unsigned p
, unsigned w
)
918 * Get enough random data to generate a polynomial all of whose p
919 * terms are in {0,+1,-1}, and exactly w of them are nonzero.
920 * We'll do this by making up a completely random sequence of
921 * {+1,-1} and then setting a random subset of them to 0.
923 * So we'll need p random bits to choose the nonzero values, and
924 * then (doing it the simplest way) log2(p!) bits to shuffle them,
925 * plus say 128 bits to ensure any fluctuations in uniformity are
928 * log2(p!) is a pain to calculate, so we'll bound it above by
929 * p*log2(p), which we bound in turn by p*16.
931 size_t randbitpos
= 17 * p
+ 128;
932 mp_int
*randdata
= mp_resize(mp_random_bits(randbitpos
), randbitpos
+ 32);
935 * Initial value before zeroing out some terms: p randomly chosen
938 for (size_t i
= 0; i
< p
; i
++)
939 v
[i
] = 1 + mp_get_bit(randdata
, --randbitpos
);
942 * Hereafter we're going to extract random bits by multiplication,
943 * treating randdata as a large fixed-point number.
945 mp_reduce_mod_2to(randdata
, randbitpos
);
948 * Zero out some terms, leaving a randomly selected w of them
951 uint32_t nonzeros_left
= w
;
952 mp_int
*x
= mp_new(64);
953 for (size_t i
= p
; i
-- > 0 ;) {
955 * Pick a random number out of the number of terms remaining.
957 mp_mul_integer_into(randdata
, randdata
, i
+1);
958 mp_rshift_fixed_into(x
, randdata
, randbitpos
);
959 mp_reduce_mod_2to(randdata
, randbitpos
);
960 size_t j
= mp_get_integer(x
);
963 * If that's less than nonzeros_left, then we're leaving this
964 * number nonzero. Otherwise we're zeroing it out.
966 uint32_t keep
= (uint32_t)(j
- nonzeros_left
) >> 31;
967 v
[i
] &= -keep
; /* clear this field if keep == 0 */
968 nonzeros_left
-= keep
; /* decrement counter if keep == 1 */
976 * Make a single attempt at generating a key pair. This involves
977 * inventing random elements of both our quotient rings and hoping
978 * they're both invertible.
980 * They may not be, if you're unlucky. The element of Z_q/<x^p-x-1>
981 * will _almost_ certainly be invertible, because that is a field, so
982 * invertibility can only fail if you were so unlucky as to choose the
983 * all-0s element. But the element of Z_3/<x^p-x-1> may fail to be
984 * invertible because it has a common factor with x^p-x-1 (which, over
985 * Z_3, is not irreducible).
987 * So we can't guarantee to generate a key pair in constant time,
988 * because there's no predicting how many retries we'll need. However,
989 * this isn't a failure of side-channel safety, because we completely
990 * discard all the random numbers and state from each failed attempt.
991 * So if there were a side-channel leakage from a failure, the only
992 * thing it would give away would be a bunch of random numbers that
993 * turned out not to be used anyway.
995 * But a _successful_ call to this function should execute in a
996 * secret-independent manner, and this 'make a single attempt'
997 * function is exposed in the API so that 'testsc' can check that.
999 NTRUKeyPair
*ntru_keygen_attempt(unsigned p
, unsigned q
, unsigned w
)
1002 * First invent g, which is the one more likely to fail to invert.
1003 * This is simply a uniformly random polynomial with p terms over
1004 * Z_3. So we need p*log2(3) random bits for it, plus 128 for
1005 * uniformity. It's easiest to bound log2(3) above by 2.
1007 size_t randbitpos
= 2 * p
+ 128;
1008 mp_int
*randdata
= mp_resize(mp_random_bits(randbitpos
), randbitpos
+ 32);
1011 * Select p random values from {0,1,2}.
1013 uint16_t *g
= snewn(p
, uint16_t);
1014 mp_int
*x
= mp_new(64);
1015 for (size_t i
= 0; i
< p
; i
++) {
1016 mp_mul_integer_into(randdata
, randdata
, 3);
1017 mp_rshift_fixed_into(x
, randdata
, randbitpos
);
1018 mp_reduce_mod_2to(randdata
, randbitpos
);
1019 g
[i
] = mp_get_integer(x
);
1025 * Try to invert g over Z_3, and fail if it isn't invertible.
1027 uint16_t *ginv
= snewn(p
, uint16_t);
1028 if (!ntru_ring_invert(ginv
, g
, p
, 3)) {
1035 * Fine; we have g. Now make up an f, and convert it to a
1036 * polynomial over q.
1038 uint16_t *f
= snewn(p
, uint16_t);
1039 ntru_gen_short(f
, p
, w
);
1040 ntru_expand(f
, f
, p
, q
);
1045 uint16_t *f3
= snewn(p
, uint16_t);
1046 ntru_scale(f3
, f
, 3, p
, q
);
1049 * Try to invert 3*f over Z_q. This should be _almost_ guaranteed
1050 * to succeed, since Z_q/<x^p-x-1> is a field, so the only
1051 * non-invertible value is 0. Even so, there _is_ one, so check
1054 uint16_t *f3inv
= snewn(p
, uint16_t);
1055 if (!ntru_ring_invert(f3inv
, f3
, p
, q
)) {
1058 ring_free(f3inv
, p
);
1065 * Make the public key, by converting g to a polynomial over q and
1066 * then multiplying by f3inv.
1068 uint16_t *g_q
= snewn(p
, uint16_t);
1069 ntru_expand(g_q
, g
, p
, q
);
1070 uint16_t *h
= snewn(p
, uint16_t);
1071 ntru_ring_multiply(h
, g_q
, f3inv
, p
, q
);
1074 * Make up rho, used to substitute for the plaintext in the
1075 * session hash in case of confirmation failure.
1077 uint16_t *rho
= snewn(p
, uint16_t);
1078 ntru_gen_short(rho
, p
, w
);
1081 * And we're done! Free everything except the pieces we're
1084 NTRUKeyPair
*keypair
= snew(NTRUKeyPair
);
1090 keypair
->ginv
= ginv
;
1093 ring_free(f3inv
, p
);
1100 * The top-level key generation function for real use (as opposed to
1101 * testsc): keep trying to make a key until you succeed.
1103 NTRUKeyPair
*ntru_keygen(unsigned p
, unsigned q
, unsigned w
)
1106 NTRUKeyPair
*keypair
= ntru_keygen_attempt(p
, q
, w
);
1113 * Public-key encryption.
1115 void ntru_encrypt(uint16_t *ciphertext
, const uint16_t *plaintext
,
1116 uint16_t *pubkey
, unsigned p
, unsigned q
)
1118 uint16_t *r_q
= snewn(p
, uint16_t);
1119 ntru_expand(r_q
, plaintext
, p
, q
);
1121 uint16_t *unrounded
= snewn(p
, uint16_t);
1122 ntru_ring_multiply(unrounded
, r_q
, pubkey
, p
, q
);
1124 ntru_round3(ciphertext
, unrounded
, p
, q
);
1125 ntru_normalise(ciphertext
, ciphertext
, p
, q
);
1128 ring_free(unrounded
, p
);
1132 * Public-key decryption.
1134 void ntru_decrypt(uint16_t *plaintext
, const uint16_t *ciphertext
,
1135 NTRUKeyPair
*keypair
)
1137 unsigned p
= keypair
->p
, q
= keypair
->q
, w
= keypair
->w
;
1138 uint16_t *tmp
= snewn(p
, uint16_t);
1140 ntru_ring_multiply(tmp
, ciphertext
, keypair
->f3
, p
, q
);
1142 ntru_mod3(tmp
, tmp
, p
, q
);
1143 ntru_normalise(tmp
, tmp
, p
, 3);
1145 ntru_ring_multiply(plaintext
, tmp
, keypair
->ginv
, p
, 3);
1149 * With luck, this should have recovered exactly the original
1150 * plaintext. But, as per the spec, we check whether it has
1151 * exactly w nonzero coefficients, and if not, then something has
1152 * gone wrong - and in that situation we time-safely substitute a
1155 * (I don't know exactly why we do this, but I assume it's because
1156 * otherwise the mis-decoded output could be made to disgorge a
1157 * secret about the private key in some way.)
1160 unsigned weight
= p
;
1161 for (size_t i
= 0; i
< p
; i
++)
1162 weight
-= iszero(plaintext
[i
]);
1163 unsigned ok
= iszero(weight
^ w
);
1166 * The default failure return value consists of w 1s followed by
1169 unsigned mask
= ok
- 1;
1170 for (size_t i
= 0; i
< w
; i
++) {
1171 uint16_t diff
= (1 ^ plaintext
[i
]) & mask
;
1172 plaintext
[i
] ^= diff
;
1174 for (size_t i
= w
; i
< p
; i
++) {
1175 uint16_t diff
= (0 ^ plaintext
[i
]) & mask
;
1176 plaintext
[i
] ^= diff
;
1180 /* ----------------------------------------------------------------------
1181 * Encode and decode public keys, ciphertexts and plaintexts.
1183 * Public keys and ciphertexts use the complicated binary encoding
1184 * system implemented above. In both cases, the inputs are regarded as
1185 * symmetric about zero, and are first biased to map their most
1186 * negative permitted value to 0, so that they become non-negative and
1187 * hence suitable as inputs to the encoding system. In the case of a
1188 * ciphertext, where the input coefficients have also been coerced to
1189 * be multiples of 3, we divide by 3 as well, saving space by reducing
1190 * the upper bounds (m_i) on all the encoded numbers.
1194 * Compute the encoding schedule for a public key.
1196 static NTRUEncodeSchedule
*ntru_encode_pubkey_schedule(unsigned p
, unsigned q
)
1198 uint16_t *ms
= snewn(p
, uint16_t);
1199 for (size_t i
= 0; i
< p
; i
++)
1201 NTRUEncodeSchedule
*sched
= ntru_encode_schedule(ms
, p
);
1207 * Encode a public key.
1209 void ntru_encode_pubkey(const uint16_t *pubkey
, unsigned p
, unsigned q
,
1212 /* Compute the biased version for encoding */
1213 uint16_t *biased_pubkey
= snewn(p
, uint16_t);
1214 ntru_bias(biased_pubkey
, pubkey
, q
/ 2, p
, q
);
1217 NTRUEncodeSchedule
*sched
= ntru_encode_pubkey_schedule(p
, q
);
1218 ntru_encode(sched
, biased_pubkey
, bs
);
1219 ntru_encode_schedule_free(sched
);
1221 ring_free(biased_pubkey
, p
);
1225 * Decode a public key and write it into 'pubkey'. We also return a
1226 * ptrlen pointing at the chunk of data we removed from the
1229 ptrlen
ntru_decode_pubkey(uint16_t *pubkey
, unsigned p
, unsigned q
,
1232 NTRUEncodeSchedule
*sched
= ntru_encode_pubkey_schedule(p
, q
);
1234 /* Retrieve the right number of bytes from the source */
1235 size_t len
= ntru_encode_schedule_length(sched
);
1236 ptrlen encoded
= get_data(src
, len
);
1238 /* If there wasn't enough data, give up and return all-zeroes
1239 * purely for determinism. But that value should never be
1240 * used, because the caller will also check get_err(src). */
1241 memset(pubkey
, 0, p
*sizeof(*pubkey
));
1243 /* Do the decoding */
1244 ntru_decode(sched
, pubkey
, encoded
);
1246 /* Unbias the coefficients */
1247 ntru_bias(pubkey
, pubkey
, q
-q
/2, p
, q
);
1250 ntru_encode_schedule_free(sched
);
1255 * For ciphertext biasing: work out the largest absolute value a
1256 * ciphertext element can take, which is given by taking q/2 and
1257 * rounding it to the nearest multiple of 3.
1259 static inline unsigned ciphertext_bias(unsigned q
)
/*
 * The number of possible values of an encoded ciphertext coefficient
 * (for use as the m_i in encoding): the values run from
 * -ciphertext_bias(q) to +ciphertext_bias(q) inclusive, which is
 * 2*bias + 1 of them.
 */
static inline unsigned ciphertext_m(unsigned q)
{
    unsigned bias = ciphertext_bias(q);
    return 2 * bias + 1;
}
1275 * Compute the encoding schedule for a ciphertext.
1277 static NTRUEncodeSchedule
*ntru_encode_ciphertext_schedule(
1278 unsigned p
, unsigned q
)
1280 unsigned m
= ciphertext_m(q
);
1281 uint16_t *ms
= snewn(p
, uint16_t);
1282 for (size_t i
= 0; i
< p
; i
++)
1284 NTRUEncodeSchedule
*sched
= ntru_encode_schedule(ms
, p
);
1290 * Encode a ciphertext.
1292 void ntru_encode_ciphertext(const uint16_t *ciphertext
, unsigned p
, unsigned q
,
1298 * Bias the ciphertext, and scale down by 1/3, which we do by
1299 * modular multiplication by the inverse of 3 mod q. (That only
1300 * works if we know the inputs are all _exact_ multiples of 3
1303 uint16_t *biased_ciphertext
= snewn(p
, uint16_t);
1304 ntru_bias(biased_ciphertext
, ciphertext
, 3 * ciphertext_bias(q
), p
, q
);
1305 ntru_scale(biased_ciphertext
, biased_ciphertext
, INVERT(3), p
, q
);
1308 NTRUEncodeSchedule
*sched
= ntru_encode_ciphertext_schedule(p
, q
);
1309 ntru_encode(sched
, biased_ciphertext
, bs
);
1310 ntru_encode_schedule_free(sched
);
1312 ring_free(biased_ciphertext
, p
);
1315 ptrlen
ntru_decode_ciphertext(uint16_t *ct
, NTRUKeyPair
*keypair
,
1318 unsigned p
= keypair
->p
, q
= keypair
->q
;
1320 NTRUEncodeSchedule
*sched
= ntru_encode_ciphertext_schedule(p
, q
);
1322 /* Retrieve the right number of bytes from the source */
1323 size_t len
= ntru_encode_schedule_length(sched
);
1324 ptrlen encoded
= get_data(src
, len
);
1326 /* As above, return deterministic nonsense on failure */
1327 memset(ct
, 0, p
*sizeof(*ct
));
1329 /* Do the decoding */
1330 ntru_decode(sched
, ct
, encoded
);
1332 /* Undo the scaling and bias */
1333 ntru_scale(ct
, ct
, 3, p
, q
);
1334 ntru_bias(ct
, ct
, q
- 3 * ciphertext_bias(q
), p
, q
);
1337 ntru_encode_schedule_free(sched
);
1338 return encoded
; /* also useful to the caller, optionally */
1342 * Encode a plaintext.
1344 * This is a much simpler encoding than the NTRUEncodeSchedule system:
1345 * since elements of a plaintext are mod 3, we just encode each one in
1346 * 2 bits, applying the usual bias so that {-1,0,+1} map to {0,1,2}
1349 * There's no corresponding decode function, because plaintexts are
1350 * never transmitted on the wire (the whole point is that they're too
1351 * secret!). Plaintexts are only encoded in order to put them into
1354 void ntru_encode_plaintext(const uint16_t *plaintext
, unsigned p
,
1357 unsigned byte
= 0, bitpos
= 0;
1358 for (size_t i
= 0; i
< p
; i
++) {
1359 unsigned encoding
= (plaintext
[i
] + 1) * iszero(plaintext
[i
] >> 1);
1360 byte
|= encoding
<< bitpos
;
1362 if (bitpos
== 8 || i
+1 == p
) {
1370 /* ----------------------------------------------------------------------
1371 * Compute the hashes required by the key exchange layer of NTRU Prime.
1373 * There are two of these. The 'confirmation hash' is sent by the
1374 * server along with the ciphertext, and the client can recalculate it
1375 * to check whether the ciphertext was decrypted correctly. Then, the
1376 * 'session hash' is the actual output of key exchange, and if the
1377 * confirmation hash doesn't match, it gets deliberately corrupted.
1381 * Make the confirmation hash, whose inputs are the plaintext and the
1384 * This is defined as H(2 || H(3 || r) || H(4 || K)), where r is the
1385 * plaintext and K is the public key (as encoded by the above
1386 * functions), and the constants 2,3,4 are single bytes. The choice of
1387 * hash function (H itself) is SHA-512 truncated to 256 bits.
1389 * (To be clear: that is _not_ the thing that FIPS 180-4 6.7 defines
1390 * as "SHA-512/256", which varies the initialisation vector of the
1391 * SHA-512 algorithm as well as truncating the output. _This_
1392 * algorithm uses the standard SHA-512 IV, and _just_ truncates the
1393 * output, in the manner suggested by FIPS 180-4 section 7.)
1395 * 'out' should therefore expect to receive 32 bytes of data.
1397 static void ntru_confirmation_hash(
1398 uint8_t *out
, const uint16_t *plaintext
,
1399 const uint16_t *pubkey
, unsigned p
, unsigned q
)
1401 /* The outer hash object */
1402 ssh_hash
*hconfirm
= ssh_hash_new(&ssh_sha512
);
1403 put_byte(hconfirm
, 2); /* initial byte 2 */
1405 uint8_t hashdata
[64];
1407 /* Compute H(3 || r) and add it to the main hash */
1408 ssh_hash
*h3r
= ssh_hash_new(&ssh_sha512
);
1410 ntru_encode_plaintext(plaintext
, p
, BinarySink_UPCAST(h3r
));
1411 ssh_hash_final(h3r
, hashdata
);
1412 put_data(hconfirm
, hashdata
, 32);
1414 /* Compute H(4 || K) and add it to the main hash */
1415 ssh_hash
*h4K
= ssh_hash_new(&ssh_sha512
);
1417 ntru_encode_pubkey(pubkey
, p
, q
, BinarySink_UPCAST(h4K
));
1418 ssh_hash_final(h4K
, hashdata
);
1419 put_data(hconfirm
, hashdata
, 32);
1421 /* Compute the full output of the main SHA-512 hash */
1422 ssh_hash_final(hconfirm
, hashdata
);
1424 /* And copy the first 32 bytes into the caller's output array */
1425 memcpy(out
, hashdata
, 32);
1426 smemclr(hashdata
, sizeof(hashdata
));
1430 * Make the session hash, whose inputs are the plaintext, the
1431 * ciphertext, and the confirmation hash (hence, transitively, a
1432 * dependence on the public key as well).
1434 * As computed by the server, and by the client if the confirmation
1435 * hash matched, this is defined as
1437 * H(1 || H(3 || r) || ciphertext || confirmation hash)
1439 * but if the confirmation hash _didn't_ match, then the plaintext r
1440 * is replaced with the dummy plaintext-shaped value 'rho' we invented
1441 * during key generation (presumably to avoid leaking any information
1442 * about our secrets), and the initial byte 1 is replaced with 0 (to
1443 * ensure that the resulting hash preimage can't match any legitimate
1444 * preimage). So in that case, you instead get
1446 * H(0 || H(3 || rho) || ciphertext || confirmation hash)
1448 * The inputs to this function include 'ok', which is the value to use
1449 * as the initial byte (1 on success, 0 on failure), and 'plaintext'
1450 * which should already have been substituted with rho in case of
1453 * The ciphertext is provided in already-encoded form.
1455 static void ntru_session_hash(
1456 uint8_t *out
, unsigned ok
, const uint16_t *plaintext
,
1457 unsigned p
, ptrlen ciphertext
, ptrlen confirmation_hash
)
1459 /* The outer hash object */
1460 ssh_hash
*hsession
= ssh_hash_new(&ssh_sha512
);
1461 put_byte(hsession
, ok
); /* initial byte 1 or 0 */
1463 uint8_t hashdata
[64];
1465 /* Compute H(3 || r), or maybe H(3 || rho), and add it to the main hash */
1466 ssh_hash
*h3r
= ssh_hash_new(&ssh_sha512
);
1468 ntru_encode_plaintext(plaintext
, p
, BinarySink_UPCAST(h3r
));
1469 ssh_hash_final(h3r
, hashdata
);
1470 put_data(hsession
, hashdata
, 32);
1472 /* Put the ciphertext and confirmation hash in */
1473 put_datapl(hsession
, ciphertext
);
1474 put_datapl(hsession
, confirmation_hash
);
1476 /* Compute the full output of the main SHA-512 hash */
1477 ssh_hash_final(hsession
, hashdata
);
1479 /* And copy the first 32 bytes into the caller's output array */
1480 memcpy(out
, hashdata
, 32);
1481 smemclr(hashdata
, sizeof(hashdata
));
1484 /* ----------------------------------------------------------------------
1485 * Top-level key exchange and SSH integration.
1487 * Although this system borrows the ECDH packet structure, it's unlike
1488 * true ECDH in that it is completely asymmetric between client and
1489 * server. So we have two separate vtables of methods for the two
1490 * sides of the system, and a third vtable containing only the class
1491 * methods, in particular a constructor which chooses which one to
1496 * The parameters p,q,w for the system. There are other choices of
1497 * these, but OpenSSH only specifies this set. (If that ever changes,
1498 * we'll need to turn these into elements of the state structures.)
1504 static char *ssh_ntru_description(const ssh_kex
*kex
)
1506 return dupprintf("NTRU Prime / Curve25519 hybrid key exchange");
1510 * State structure for the client, which takes the role of inventing a
1511 * key pair and decrypting a secret plaintext sent to it by the server.
1513 typedef struct ntru_client_key
{
1514 NTRUKeyPair
*keypair
;
1515 ecdh_key
*curve25519
;
1520 static void ssh_ntru_client_free(ecdh_key
*dh
);
1521 static void ssh_ntru_client_getpublic(ecdh_key
*dh
, BinarySink
*bs
);
1522 static bool ssh_ntru_client_getkey(ecdh_key
*dh
, ptrlen remoteKey
,
1525 static const ecdh_keyalg ssh_ntru_client_vt
= {
1526 /* This vtable has no 'new' method, because it's constructed via
1527 * the selector vt below */
1528 .free
= ssh_ntru_client_free
,
1529 .getpublic
= ssh_ntru_client_getpublic
,
1530 .getkey
= ssh_ntru_client_getkey
,
1531 .description
= ssh_ntru_description
,
1534 static ecdh_key
*ssh_ntru_client_new(void)
1536 ntru_client_key
*nk
= snew(ntru_client_key
);
1537 nk
->ek
.vt
= &ssh_ntru_client_vt
;
1539 nk
->keypair
= ntru_keygen(p_LIVE
, q_LIVE
, w_LIVE
);
1540 nk
->curve25519
= ecdh_key_new(&ssh_ec_kex_curve25519
, false);
1545 static void ssh_ntru_client_free(ecdh_key
*dh
)
1547 ntru_client_key
*nk
= container_of(dh
, ntru_client_key
, ek
);
1548 ntru_keypair_free(nk
->keypair
);
1549 ecdh_key_free(nk
->curve25519
);
1553 static void ssh_ntru_client_getpublic(ecdh_key
*dh
, BinarySink
*bs
)
1555 ntru_client_key
*nk
= container_of(dh
, ntru_client_key
, ek
);
1558 * The client's public information is a single SSH string
1559 * containing the NTRU public key and the Curve25519 public point
1560 * concatenated. So write both of those into the output
1563 ntru_encode_pubkey(nk
->keypair
->h
, p_LIVE
, q_LIVE
, bs
);
1564 ecdh_key_getpublic(nk
->curve25519
, bs
);
1567 static bool ssh_ntru_client_getkey(ecdh_key
*dh
, ptrlen remoteKey
,
1570 ntru_client_key
*nk
= container_of(dh
, ntru_client_key
, ek
);
1573 * We expect the server to have sent us a string containing a
1574 * ciphertext, a confirmation hash, and a Curve25519 public point.
1575 * Extract all three.
1577 BinarySource src
[1];
1578 BinarySource_BARE_INIT_PL(src
, remoteKey
);
1580 uint16_t *ciphertext
= snewn(p_LIVE
, uint16_t);
1581 ptrlen ciphertext_encoded
= ntru_decode_ciphertext(
1582 ciphertext
, nk
->keypair
, src
);
1583 ptrlen confirmation_hash
= get_data(src
, 32);
1584 ptrlen curve25519_remoteKey
= get_data(src
, 32);
1586 if (get_err(src
) || get_avail(src
)) {
1587 /* Hard-fail if the input wasn't exactly the right length */
1588 ring_free(ciphertext
, p_LIVE
);
1593 * Main hash object which will combine the NTRU and Curve25519
1596 ssh_hash
*h
= ssh_hash_new(&ssh_sha512
);
1598 /* Reusable buffer for storing various hash outputs. */
1599 uint8_t hashdata
[64];
1605 /* Decrypt the ciphertext to recover the server's plaintext */
1606 uint16_t *plaintext
= snewn(p_LIVE
, uint16_t);
1607 ntru_decrypt(plaintext
, ciphertext
, nk
->keypair
);
1609 /* Make the confirmation hash */
1610 ntru_confirmation_hash(hashdata
, plaintext
, nk
->keypair
->h
,
1613 /* Check it matches the one the server sent */
1614 unsigned ok
= smemeq(hashdata
, confirmation_hash
.ptr
, 32);
1616 /* If not, substitute in rho for the plaintext in the session hash */
1617 unsigned mask
= ok
-1;
1618 for (size_t i
= 0; i
< p_LIVE
; i
++)
1619 plaintext
[i
] ^= mask
& (plaintext
[i
] ^ nk
->keypair
->rho
[i
]);
1621 /* Compute the session hash, whether or not we did that */
1622 ntru_session_hash(hashdata
, ok
, plaintext
, p_LIVE
, ciphertext_encoded
,
1625 /* Free temporary values */
1626 ring_free(plaintext
, p_LIVE
);
1627 ring_free(ciphertext
, p_LIVE
);
1629 /* And put the NTRU session hash into the main hash object. */
1630 put_data(h
, hashdata
, 32);
1637 strbuf
*otherkey
= strbuf_new_nm();
1639 /* Call out to Curve25519 to compute the shared secret from that
1641 bool ok
= ecdh_key_getkey(nk
->curve25519
, curve25519_remoteKey
,
1642 BinarySink_UPCAST(otherkey
));
1644 /* If that failed (which only happens if the other end does
1645 * something wrong, like sending a low-order curve point
1646 * outside the subgroup it's supposed to), we might as well
1647 * just abort and return failure. That's what we'd have done
1648 * in standalone Curve25519. */
1651 smemclr(hashdata
, sizeof(hashdata
));
1652 strbuf_free(otherkey
);
1657 * ecdh_key_getkey will have returned us a chunk of data
1658 * containing an encoded mpint, which is how the Curve25519
1659 * output normally goes into the exchange hash. But in this
1660 * context we want to treat it as a fixed big-endian 32 bytes,
1661 * so extract it from its encoding and put it into the main
1662 * hash object in the new format.
1664 BinarySource src
[1];
1665 BinarySource_BARE_INIT_PL(src
, ptrlen_from_strbuf(otherkey
));
1666 mp_int
*curvekey
= get_mp_ssh2(src
);
1668 for (unsigned i
= 32; i
-- > 0 ;)
1669 put_byte(h
, mp_get_byte(curvekey
, i
));
1672 strbuf_free(otherkey
);
1676 * Finish up: compute the final output hash (full 64 bytes of
1677 * SHA-512 this time), and return it encoded as a string.
1679 ssh_hash_final(h
, hashdata
);
1680 put_stringpl(bs
, make_ptrlen(hashdata
, sizeof(hashdata
)));
1681 smemclr(hashdata
, sizeof(hashdata
));
1687 * State structure for the server, which takes the role of inventing a
1688 * secret plaintext and sending it to the client encrypted with the
1689 * public key the client sent.
1691 typedef struct ntru_server_key
{
1692 uint16_t *plaintext
;
1693 strbuf
*ciphertext_encoded
, *confirmation_hash
;
1694 ecdh_key
*curve25519
;
1699 static void ssh_ntru_server_free(ecdh_key
*dh
);
1700 static void ssh_ntru_server_getpublic(ecdh_key
*dh
, BinarySink
*bs
);
1701 static bool ssh_ntru_server_getkey(ecdh_key
*dh
, ptrlen remoteKey
,
1704 static const ecdh_keyalg ssh_ntru_server_vt
= {
1705 /* This vtable has no 'new' method, because it's constructed via
1706 * the selector vt below */
1707 .free
= ssh_ntru_server_free
,
1708 .getpublic
= ssh_ntru_server_getpublic
,
1709 .getkey
= ssh_ntru_server_getkey
,
1710 .description
= ssh_ntru_description
,
1713 static ecdh_key
*ssh_ntru_server_new(void)
1715 ntru_server_key
*nk
= snew(ntru_server_key
);
1716 nk
->ek
.vt
= &ssh_ntru_server_vt
;
1718 nk
->plaintext
= snewn(p_LIVE
, uint16_t);
1719 nk
->ciphertext_encoded
= strbuf_new_nm();
1720 nk
->confirmation_hash
= strbuf_new_nm();
1721 ntru_gen_short(nk
->plaintext
, p_LIVE
, w_LIVE
);
1723 nk
->curve25519
= ecdh_key_new(&ssh_ec_kex_curve25519
, false);
1728 static void ssh_ntru_server_free(ecdh_key
*dh
)
1730 ntru_server_key
*nk
= container_of(dh
, ntru_server_key
, ek
);
1731 ring_free(nk
->plaintext
, p_LIVE
);
1732 strbuf_free(nk
->ciphertext_encoded
);
1733 strbuf_free(nk
->confirmation_hash
);
1734 ecdh_key_free(nk
->curve25519
);
1738 static bool ssh_ntru_server_getkey(ecdh_key
*dh
, ptrlen remoteKey
,
1741 ntru_server_key
*nk
= container_of(dh
, ntru_server_key
, ek
);
1744 * In the server, getkey is called first, with the public
1745 * information received from the client. We expect the client to
1746 * have sent us a string containing a public key and a Curve25519
1749 BinarySource src
[1];
1750 BinarySource_BARE_INIT_PL(src
, remoteKey
);
1752 uint16_t *pubkey
= snewn(p_LIVE
, uint16_t);
1753 ntru_decode_pubkey(pubkey
, p_LIVE
, q_LIVE
, src
);
1754 ptrlen curve25519_remoteKey
= get_data(src
, 32);
1756 if (get_err(src
) || get_avail(src
)) {
1757 /* Hard-fail if the input wasn't exactly the right length */
1758 ring_free(pubkey
, p_LIVE
);
1763 * Main hash object which will combine the NTRU and Curve25519
1766 ssh_hash
*h
= ssh_hash_new(&ssh_sha512
);
1768 /* Reusable buffer for storing various hash outputs. */
1769 uint8_t hashdata
[64];
1775 /* Encrypt the plaintext we generated at construction time,
1776 * and encode the ciphertext into a strbuf so we can reuse it
1777 * for both the session hash and sending to the client. */
1778 uint16_t *ciphertext
= snewn(p_LIVE
, uint16_t);
1779 ntru_encrypt(ciphertext
, nk
->plaintext
, pubkey
, p_LIVE
, q_LIVE
);
1780 ntru_encode_ciphertext(ciphertext
, p_LIVE
, q_LIVE
,
1781 BinarySink_UPCAST(nk
->ciphertext_encoded
));
1782 ring_free(ciphertext
, p_LIVE
);
1784 /* Compute the confirmation hash, and write it into another
1786 ntru_confirmation_hash(hashdata
, nk
->plaintext
, pubkey
,
1788 put_data(nk
->confirmation_hash
, hashdata
, 32);
1790 /* Compute the session hash (which is easy on the server side,
1791 * requiring no conditional substitution). */
1792 ntru_session_hash(hashdata
, 1, nk
->plaintext
, p_LIVE
,
1793 ptrlen_from_strbuf(nk
->ciphertext_encoded
),
1794 ptrlen_from_strbuf(nk
->confirmation_hash
));
1796 /* And put the NTRU session hash into the main hash object. */
1797 put_data(h
, hashdata
, 32);
1799 /* Now we can free the public key */
1800 ring_free(pubkey
, p_LIVE
);
1807 strbuf
*otherkey
= strbuf_new_nm();
1809 /* Call out to Curve25519 to compute the shared secret from that
1811 bool ok
= ecdh_key_getkey(nk
->curve25519
, curve25519_remoteKey
,
1812 BinarySink_UPCAST(otherkey
));
1813 /* As on the client side, abort if Curve25519 reported failure */
1816 smemclr(hashdata
, sizeof(hashdata
));
1817 strbuf_free(otherkey
);
1821 /* As on the client side, decode Curve25519's mpint so we can
1822 * re-encode it appropriately for our hash preimage */
1823 BinarySource src
[1];
1824 BinarySource_BARE_INIT_PL(src
, ptrlen_from_strbuf(otherkey
));
1825 mp_int
*curvekey
= get_mp_ssh2(src
);
1827 for (unsigned i
= 32; i
-- > 0 ;)
1828 put_byte(h
, mp_get_byte(curvekey
, i
));
1831 strbuf_free(otherkey
);
1835 * Finish up: compute the final output hash (full 64 bytes of
1836 * SHA-512 this time), and return it encoded as a string.
1838 ssh_hash_final(h
, hashdata
);
1839 put_stringpl(bs
, make_ptrlen(hashdata
, sizeof(hashdata
)));
1840 smemclr(hashdata
, sizeof(hashdata
));
1845 static void ssh_ntru_server_getpublic(ecdh_key
*dh
, BinarySink
*bs
)
1847 ntru_server_key
*nk
= container_of(dh
, ntru_server_key
, ek
);
1850 * In the server, this function is called after getkey, so we
1851 * already have all our pieces prepared. Just concatenate them all
1852 * into the 'server's public data' string to go in ECDH_REPLY.
1854 put_datapl(bs
, ptrlen_from_strbuf(nk
->ciphertext_encoded
));
1855 put_datapl(bs
, ptrlen_from_strbuf(nk
->confirmation_hash
));
1856 ecdh_key_getpublic(nk
->curve25519
, bs
);
1859 /* ----------------------------------------------------------------------
1860 * Selector vtable that instantiates the appropriate one of the above,
1861 * depending on is_server.
1863 static ecdh_key
*ssh_ntru_new(const ssh_kex
*kex
, bool is_server
)
1866 return ssh_ntru_server_new();
1868 return ssh_ntru_client_new();
1871 static const ecdh_keyalg ssh_ntru_selector_vt
= {
1872 /* This is a never-instantiated vtable which only implements the
1873 * functions that don't require an instance. */
1874 .new = ssh_ntru_new
,
1875 .description
= ssh_ntru_description
,
1878 static const ssh_kex ssh_ntru_curve25519
= {
1879 .name
= "sntrup761x25519-sha512@openssh.com",
1880 .main_type
= KEXTYPE_ECDH
,
1881 .hash
= &ssh_sha512
,
1882 .ecdh_vt
= &ssh_ntru_selector_vt
,
1885 static const ssh_kex
*const hybrid_list
[] = {
1886 &ssh_ntru_curve25519
,
1889 const ssh_kexes ssh_ntru_hybrid_kex
= { lenof(hybrid_list
), hybrid_list
};