Re-sync with internal repository
[hiphop-php.git] / third-party / fizz / src / fizz / test / BogoShim.cpp
blobf2d2fda19665e4eebade0a3c59061d53e1c55a85
1 /*
2 * Copyright (c) 2018-present, Facebook, Inc.
3 * All rights reserved.
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
9 #include <fizz/client/AsyncFizzClient.h>
10 #include <fizz/crypto/Utils.h>
11 #include <fizz/crypto/aead/AESGCM128.h>
12 #include <fizz/crypto/aead/OpenSSLEVPCipher.h>
13 #include <fizz/server/AsyncFizzServer.h>
14 #include <fizz/server/TicketTypes.h>
15 #include <folly/String.h>
16 #include <folly/io/async/AsyncSSLSocket.h>
17 #include <folly/io/async/SSLContext.h>
18 #include <folly/portability/GFlags.h>
20 using namespace fizz;
21 using namespace fizz::client;
22 using namespace fizz::server;
23 using namespace folly;
24 using namespace folly::ssl;
26 DEFINE_int32(port, 0, "port to connect to");
27 DEFINE_bool(server, false, "act as a server, otherwise act as a client");
28 DEFINE_string(key_file, "", "key file");
29 DEFINE_string(cert_file, "", "cert file");
30 DEFINE_int32(resume_count, 0, "number of additional connections to open");
32 static constexpr int kUnimplemented = 89;
34 static std::vector<std::string> kKnownFlags{
35 "port",
36 "server",
37 "key_file",
38 "cert_file",
39 "resume_count"};
41 class BogoTestServer : public AsyncSocket::ConnectCallback,
42 public AsyncFizzServer::HandshakeCallback,
43 public AsyncSSLSocket::HandshakeCB,
44 public AsyncTransportWrapper::ReadCallback {
45 public:
46 BogoTestServer(
47 EventBase* evb,
48 uint16_t port,
49 std::shared_ptr<FizzServerContext> serverContext,
50 std::shared_ptr<SSLContext> sslContext)
51 : evb_(evb), serverContext_(serverContext), sslContext_(sslContext) {
52 socket_ = AsyncSocket::UniquePtr(new AsyncSocket(evb));
53 socket_->connect(this, "::", port, 1000);
56 void connectSuccess() noexcept override {
57 transport_ = AsyncFizzServer::UniquePtr(
58 new AsyncFizzServer(std::move(socket_), serverContext_));
59 transport_->accept(this);
62 void connectErr(const AsyncSocketException& ex) noexcept override {
63 LOG(INFO) << "TCP connect failed: " << ex.what();
64 socket_.reset();
65 success_ = false;
68 void fizzHandshakeSuccess(AsyncFizzServer*) noexcept override {
69 success_ = true;
70 transport_->setReadCB(this);
73 void fizzHandshakeError(
74 AsyncFizzServer*,
75 folly::exception_wrapper ex) noexcept override {
76 LOG(INFO) << "Handshake error: " << ex.what();
77 transport_.reset();
78 success_ = false;
81 void fizzHandshakeAttemptFallback(
82 std::unique_ptr<folly::IOBuf> clientHello) override {
83 auto fd = transport_->getUnderlyingTransport<AsyncSocket>()
84 ->detachNetworkSocket()
85 .toFd();
86 transport_.reset();
87 if (!sslContext_) {
88 unimplemented_ = true;
89 } else {
90 sslSocket_ = AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(
91 sslContext_, evb_, folly::NetworkSocket::fromFd(fd)));
92 sslSocket_->setPreReceivedData(std::move(clientHello));
93 sslSocket_->sslAccept(this);
97 void getReadBuffer(void** /* bufReturn */, size_t* /* lenReturn */) override {
98 throw std::runtime_error("getReadBuffer not implemented");
101 void readDataAvailable(size_t /* len */) noexcept override {
102 CHECK(false) << "readDataAvailable not implemented";
105 bool isBufferMovable() noexcept override {
106 return true;
109 void readBufferAvailable(std::unique_ptr<IOBuf> buf) noexcept override {
110 io::Cursor cursor(buf.get());
111 std::unique_ptr<IOBuf> write = IOBuf::create(0);
112 io::Appender appender(write.get(), 50);
113 while (!cursor.isAtEnd()) {
114 uint8_t byte;
115 cursor.pull(&byte, 1);
116 byte ^= 0xff;
117 appender.push(&byte, 1);
119 transport_->writeChain(nullptr, std::move(write));
122 void readEOF() noexcept override {}
124 void readErr(const AsyncSocketException&) noexcept override {}
126 void handshakeSuc(folly::AsyncSSLSocket*) noexcept override {
127 success_ = true;
130 void handshakeErr(
131 folly::AsyncSSLSocket*,
132 const folly::AsyncSocketException& ex) noexcept override {
133 LOG(INFO) << "SSL Handshake error: " << ex.what();
134 sslSocket_.reset();
135 success_ = false;
138 bool unimplemented() const {
139 return unimplemented_;
142 bool success() const {
143 return *success_;
146 private:
147 EventBase* evb_;
148 AsyncSocket::UniquePtr socket_;
150 std::shared_ptr<FizzServerContext> serverContext_;
151 AsyncFizzServer::UniquePtr transport_;
153 std::shared_ptr<SSLContext> sslContext_;
154 AsyncSSLSocket::UniquePtr sslSocket_;
156 bool unimplemented_{false};
157 Optional<bool> success_;
160 class BogoTestClient : public AsyncSocket::ConnectCallback,
161 public AsyncFizzClient::HandshakeCallback,
162 public AsyncTransportWrapper::ReadCallback {
163 public:
164 BogoTestClient(
165 EventBase* evb,
166 uint16_t port,
167 std::shared_ptr<const FizzClientContext> clientContext)
168 : clientContext_(clientContext) {
169 socket_ = AsyncSocket::UniquePtr(new AsyncSocket(evb));
170 socket_->connect(this, "::", port, 1000);
173 void connectSuccess() noexcept override {
174 transport_ = AsyncFizzClient::UniquePtr(
175 new AsyncFizzClient(std::move(socket_), clientContext_));
176 transport_->connect(
177 this,
178 nullptr,
179 folly::none,
180 std::string("resumption-id"),
181 folly::Optional<std::vector<fizz::ech::ECHConfig>>(folly::none));
184 void connectErr(const AsyncSocketException& ex) noexcept override {
185 LOG(INFO) << "TCP connect failed: " << ex.what();
186 socket_.reset();
187 success_ = false;
190 void fizzHandshakeSuccess(AsyncFizzClient*) noexcept override {
191 success_ = true;
192 transport_->setReadCB(this);
195 void fizzHandshakeError(
196 AsyncFizzClient*,
197 folly::exception_wrapper ex) noexcept override {
198 LOG(INFO) << "Handshake error: " << ex.what();
199 transport_.reset();
201 // If the server sent us a protocol_version alert assume that
202 if (ex.what().find(
203 "received alert: protocol_version, in state ExpectingServerHello") !=
204 std::string::npos) {
205 unimplemented_ = true;
207 success_ = false;
210 void getReadBuffer(void** /* bufReturn */, size_t* /* lenReturn */) override {
211 throw std::runtime_error("getReadBuffer not implemented");
214 void readDataAvailable(size_t /* len */) noexcept override {
215 CHECK(false) << "readDataAvailable not implemented";
218 bool isBufferMovable() noexcept override {
219 return true;
222 void readBufferAvailable(std::unique_ptr<IOBuf> buf) noexcept override {
223 io::Cursor cursor(buf.get());
224 std::unique_ptr<IOBuf> write = IOBuf::create(0);
225 io::Appender appender(write.get(), 50);
226 while (!cursor.isAtEnd()) {
227 uint8_t byte;
228 cursor.pull(&byte, 1);
229 byte ^= 0xff;
230 appender.push(&byte, 1);
232 transport_->writeChain(nullptr, std::move(write));
235 void readEOF() noexcept override {}
237 void readErr(const AsyncSocketException&) noexcept override {}
239 bool unimplemented() const {
240 return unimplemented_;
243 bool success() const {
244 return *success_;
247 private:
248 AsyncSocket::UniquePtr socket_;
250 std::shared_ptr<const FizzClientContext> clientContext_;
251 AsyncFizzClient::UniquePtr transport_;
253 bool unimplemented_{false};
254 Optional<bool> success_;
257 class TestRsaCert : public SelfCertImpl<KeyType::RSA> {
258 public:
259 using SelfCertImpl<KeyType::RSA>::SelfCertImpl;
260 std::string getIdentity() const override {
261 return "testrsacert";
265 class TestP256Cert : public SelfCertImpl<KeyType::P256> {
266 public:
267 using SelfCertImpl<KeyType::P256>::SelfCertImpl;
268 std::string getIdentity() const override {
269 return "testp256cert";
273 std::unique_ptr<SelfCert> readSelfCert() {
274 BioUniquePtr b(BIO_new(BIO_s_file()));
275 BIO_read_filename(b.get(), FLAGS_cert_file.c_str());
276 std::vector<X509UniquePtr> certs;
277 while (true) {
278 X509UniquePtr x509(PEM_read_bio_X509(b.get(), nullptr, nullptr, nullptr));
279 if (!x509) {
280 break;
281 } else {
282 certs.push_back(std::move(x509));
285 if (certs.empty()) {
286 throw std::runtime_error("could not read cert");
289 b.reset(BIO_new(BIO_s_file()));
290 BIO_read_filename(b.get(), FLAGS_key_file.c_str());
291 EvpPkeyUniquePtr key(
292 PEM_read_bio_PrivateKey(b.get(), nullptr, nullptr, nullptr));
294 std::unique_ptr<SelfCert> cert;
295 if (EVP_PKEY_id(key.get()) == EVP_PKEY_RSA) {
296 return std::make_unique<TestRsaCert>(std::move(key), std::move(certs));
297 } else if (EVP_PKEY_id(key.get()) == EVP_PKEY_EC) {
298 return std::make_unique<TestP256Cert>(std::move(key), std::move(certs));
299 } else {
300 throw std::runtime_error("unknown cert type");
304 int serverTest() {
305 auto certManager = std::make_shared<CertManager>();
306 certManager->addCert(readSelfCert(), true);
308 auto serverContext = std::make_shared<FizzServerContext>();
309 serverContext->setCertManager(certManager);
310 serverContext->setSupportedAlpns({"h2", "http/1.1"});
311 serverContext->setVersionFallbackEnabled(true);
313 auto ticketCipher = std::make_shared<AES128TicketCipher>(
314 serverContext->getFactoryPtr(), std::move(certManager));
315 auto ticketSeed = RandomGenerator<32>().generateRandom();
316 ticketCipher->setTicketSecrets({{range(ticketSeed)}});
317 server::TicketPolicy policy;
318 policy.setTicketValidity(std::chrono::seconds(60));
319 ticketCipher->setPolicy(std::move(policy));
321 serverContext->setTicketCipher(ticketCipher);
323 EventBase evb;
324 std::vector<std::unique_ptr<BogoTestServer>> servers;
325 for (size_t i = 0; i <= size_t(FLAGS_resume_count); i++) {
326 servers.push_back(std::make_unique<BogoTestServer>(
327 &evb, FLAGS_port, serverContext, nullptr));
329 evb.loop();
330 for (const auto& server : servers) {
331 if (server->unimplemented()) {
332 LOG(INFO) << "Testing unimplemented feature.";
333 return kUnimplemented;
336 for (const auto& server : servers) {
337 if (!server->success()) {
338 LOG(INFO) << "Connection failed.";
339 return 1;
343 return 0;
346 int clientTest() {
347 auto clientContext = std::make_shared<FizzClientContext>();
348 clientContext->setCompatibilityMode(true);
350 if (!FLAGS_cert_file.empty()) {
351 clientContext->setClientCertificate(readSelfCert());
354 EventBase evb;
355 if (FLAGS_resume_count >= 1) {
356 return kUnimplemented;
358 auto client =
359 std::make_unique<BogoTestClient>(&evb, FLAGS_port, clientContext);
360 evb.loop();
361 if (client->unimplemented()) {
362 LOG(INFO) << "Testing unimplemented feature.";
363 return kUnimplemented;
365 if (!client->success()) {
366 LOG(INFO) << "Connection failed.";
367 return 1;
370 return 0;
373 int main(int argc, char** argv) {
374 // Convert "-" in args to "_" so that we can use GFLAGS.
375 for (int i = 1; i < argc; i++) {
376 if (argv[i][0] == '-') {
377 for (char* j = argv[i] + 2; *j; j++) {
378 if (*j == '-') {
379 *j = '_';
382 if (std::find(
383 kKnownFlags.begin(),
384 kKnownFlags.end(),
385 std::string(argv[i] + 1)) == kKnownFlags.end()) {
386 LOG(INFO) << "unknown flag: " << argv[i];
387 return kUnimplemented;
392 gflags::ParseCommandLineFlags(&argc, &argv, true);
393 google::InitGoogleLogging(argv[0]);
394 CryptoUtils::init();
396 if (FLAGS_port == 0) {
397 throw std::runtime_error("must specify port");
400 if (FLAGS_server) {
401 return serverTest();
402 } else {
403 return clientTest();