/* Update copyrights to 2021, using "make update-copyright"
 * [tor.git] / src / test / test_ntor_cl.c
 * blob 94270f1fd61a66bd036db70b8af05dfd67476235 */
1 /* Copyright (c) 2012-2021, The Tor Project, Inc. */
2 /* See LICENSE for licensing information */
#include "orconfig.h"
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#define ONION_NTOR_PRIVATE
#include "core/or/or.h"
#include "lib/crypt_ops/crypto_cipher.h"
#include "lib/crypt_ops/crypto_curve25519.h"
#include "lib/crypt_ops/crypto_init.h"
#include "core/crypto/onion_ntor.h"
/* Argument-checking helpers for the subcommand functions below.  Each macro
 * expects `argc` and `argv` to be in scope, and makes the enclosing function
 * print a message to stderr and `return 1` when its check fails. */

/** Fail unless at least <b>n</b> entries of argv are present (argv[0] and
 * the subcommand name count toward <b>n</b>).  main() has already checked
 * that argv[1] exists. */
#define N_ARGS(n) STMT_BEGIN {                                          \
    if (argc < (n)) {                                                   \
      fprintf(stderr, "%s needs %d arguments.\n", argv[1], (n));        \
      return 1;                                                         \
    }                                                                   \
  } STMT_END

/** Hex-decode argv[<b>idx</b>] into the <b>n</b>-byte buffer <b>var</b>;
 * fail unless at least <b>n</b> bytes were decoded. */
#define BASE16(idx, var, n) STMT_BEGIN {                                \
    const char *s = argv[(idx)];                                        \
    if (base16_decode((char*)(var), (n), s, strlen(s)) < (int)(n)) {    \
      fprintf(stderr, "couldn't decode argument %d (%s)\n", (idx), s);  \
      return 1;                                                         \
    }                                                                   \
  } STMT_END

/**
 * Parse <b>s</b> as a base-10 integer in the range [1, (INT_MAX-1)/2].
 *
 * On success, store the value in *<b>out</b> and return 0.  Return -1, and
 * leave *<b>out</b> untouched, if <b>s</b> is empty, has trailing
 * characters, does not fit in a long, or is out of range.  The upper bound
 * lets callers compute "value*2+1" (the size of a hex buffer) without
 * signed overflow.
 */
static int
parse_positive_int(const char *s, int *out)
{
  char *end = NULL;
  long v;

  errno = 0;
  v = strtol(s, &end, 10);
  if (end == s || *end != '\0' || errno == ERANGE)
    return -1;
  if (v <= 0 || v > (INT_MAX - 1) / 2)
    return -1;
  *out = (int)v;
  return 0;
}

/** Parse argv[<b>idx</b>] as a positive integer into <b>var</b>; fail on
 * anything parse_positive_int() rejects.  (This check previously printed an
 * error but carried on with the bad value.) */
#define INT(idx, var) STMT_BEGIN {                                      \
    if (parse_positive_int(argv[(idx)], &(var)) < 0) {                  \
      fprintf(stderr, "bad integer argument %d (%s)\n",                 \
              (idx), argv[(idx)]);                                      \
      return 1;                                                         \
    }                                                                   \
  } STMT_END
35 static int
36 client1(int argc, char **argv)
38 /* client1 nodeID B -> msg state */
39 curve25519_public_key_t B;
40 uint8_t node_id[DIGEST_LEN];
41 ntor_handshake_state_t *state = NULL;
42 uint8_t msg[NTOR_ONIONSKIN_LEN];
44 char buf[1024];
46 N_ARGS(4);
47 BASE16(2, node_id, DIGEST_LEN);
48 BASE16(3, B.public_key, CURVE25519_PUBKEY_LEN);
50 if (onion_skin_ntor_create(node_id, &B, &state, msg)<0) {
51 fprintf(stderr, "handshake failed");
52 return 2;
55 base16_encode(buf, sizeof(buf), (const char*)msg, sizeof(msg));
56 printf("%s\n", buf);
57 base16_encode(buf, sizeof(buf), (void*)state, sizeof(*state));
58 printf("%s\n", buf);
60 ntor_handshake_state_free(state);
61 return 0;
64 static int
65 server1(int argc, char **argv)
67 uint8_t msg_in[NTOR_ONIONSKIN_LEN];
68 curve25519_keypair_t kp;
69 di_digest256_map_t *keymap=NULL;
70 uint8_t node_id[DIGEST_LEN];
71 int keybytes;
73 uint8_t msg_out[NTOR_REPLY_LEN];
74 uint8_t *keys = NULL;
75 char *hexkeys = NULL;
76 int result = 0;
78 char buf[256];
80 /* server1: b nodeID msg N -> msg keys */
81 N_ARGS(6);
82 BASE16(2, kp.seckey.secret_key, CURVE25519_SECKEY_LEN);
83 BASE16(3, node_id, DIGEST_LEN);
84 BASE16(4, msg_in, NTOR_ONIONSKIN_LEN);
85 INT(5, keybytes);
87 curve25519_public_key_generate(&kp.pubkey, &kp.seckey);
88 dimap_add_entry(&keymap, kp.pubkey.public_key, &kp);
90 keys = tor_malloc(keybytes);
91 hexkeys = tor_malloc(keybytes*2+1);
92 if (onion_skin_ntor_server_handshake(
93 msg_in, keymap, NULL, node_id, msg_out, keys,
94 (size_t)keybytes)<0) {
95 fprintf(stderr, "handshake failed");
96 result = 2;
97 goto done;
100 base16_encode(buf, sizeof(buf), (const char*)msg_out, sizeof(msg_out));
101 printf("%s\n", buf);
102 base16_encode(hexkeys, keybytes*2+1, (const char*)keys, keybytes);
103 printf("%s\n", hexkeys);
105 done:
106 tor_free(keys);
107 tor_free(hexkeys);
108 dimap_free(keymap, NULL);
109 return result;
112 static int
113 client2(int argc, char **argv)
115 struct ntor_handshake_state_t state;
116 uint8_t msg[NTOR_REPLY_LEN];
117 int keybytes;
118 uint8_t *keys;
119 char *hexkeys;
120 int result = 0;
122 N_ARGS(5);
123 BASE16(2, (&state), sizeof(state));
124 BASE16(3, msg, sizeof(msg));
125 INT(4, keybytes);
127 keys = tor_malloc(keybytes);
128 hexkeys = tor_malloc(keybytes*2+1);
129 if (onion_skin_ntor_client_handshake(&state, msg, keys, keybytes, NULL)<0) {
130 fprintf(stderr, "handshake failed");
131 result = 2;
132 goto done;
135 base16_encode(hexkeys, keybytes*2+1, (const char*)keys, keybytes);
136 printf("%s\n", hexkeys);
138 done:
139 tor_free(keys);
140 tor_free(hexkeys);
141 return result;
145 main(int argc, char **argv)
148 client1: nodeID B -> msg state
149 server1: b nodeID msg N -> msg keys
150 client2: state msg N -> keys
152 if (argc < 2) {
153 fprintf(stderr, "I need arguments. Read source for more info.\n");
154 return 1;
157 init_logging(1);
158 curve25519_init();
159 if (crypto_global_init(0, NULL, NULL) < 0)
160 return 1;
162 if (!strcmp(argv[1], "client1")) {
163 return client1(argc, argv);
164 } else if (!strcmp(argv[1], "server1")) {
165 return server1(argc, argv);
166 } else if (!strcmp(argv[1], "client2")) {
167 return client2(argc, argv);
168 } else {
169 fprintf(stderr, "What's a %s?\n", argv[1]);
170 return 1;