2 * Copyright (c) 2018-present, Facebook, Inc.
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
9 #include <fizz/client/AsyncFizzClient.h>
10 #include <fizz/crypto/Utils.h>
11 #include <fizz/crypto/aead/AESGCM128.h>
12 #include <fizz/crypto/aead/OpenSSLEVPCipher.h>
13 #include <fizz/server/AsyncFizzServer.h>
14 #include <fizz/server/TicketTypes.h>
15 #include <folly/String.h>
16 #include <folly/io/async/AsyncSSLSocket.h>
17 #include <folly/io/async/SSLContext.h>
18 #include <folly/portability/GFlags.h>
21 using namespace fizz::client
;
22 using namespace fizz::server
;
23 using namespace folly
;
24 using namespace folly::ssl
;
26 DEFINE_int32(port
, 0, "port to connect to");
27 DEFINE_bool(server
, false, "act as a server, otherwise act as a client");
28 DEFINE_string(key_file
, "", "key file");
29 DEFINE_string(cert_file
, "", "cert file");
30 DEFINE_int32(resume_count
, 0, "number of additional connections to open");
32 static constexpr int kUnimplemented
= 89;
34 static std::vector
<std::string
> kKnownFlags
{
41 class BogoTestServer
: public AsyncSocket::ConnectCallback
,
42 public AsyncFizzServer::HandshakeCallback
,
43 public AsyncSSLSocket::HandshakeCB
,
44 public AsyncTransportWrapper::ReadCallback
{
49 std::shared_ptr
<FizzServerContext
> serverContext
,
50 std::shared_ptr
<SSLContext
> sslContext
)
51 : evb_(evb
), serverContext_(serverContext
), sslContext_(sslContext
) {
52 socket_
= AsyncSocket::UniquePtr(new AsyncSocket(evb
));
53 socket_
->connect(this, "::", port
, 1000);
56 void connectSuccess() noexcept override
{
57 transport_
= AsyncFizzServer::UniquePtr(
58 new AsyncFizzServer(std::move(socket_
), serverContext_
));
59 transport_
->accept(this);
62 void connectErr(const AsyncSocketException
& ex
) noexcept override
{
63 LOG(INFO
) << "TCP connect failed: " << ex
.what();
68 void fizzHandshakeSuccess(AsyncFizzServer
*) noexcept override
{
70 transport_
->setReadCB(this);
73 void fizzHandshakeError(
75 folly::exception_wrapper ex
) noexcept override
{
76 LOG(INFO
) << "Handshake error: " << ex
.what();
81 void fizzHandshakeAttemptFallback(
82 std::unique_ptr
<folly::IOBuf
> clientHello
) override
{
83 auto fd
= transport_
->getUnderlyingTransport
<AsyncSocket
>()
84 ->detachNetworkSocket()
88 unimplemented_
= true;
90 sslSocket_
= AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(
91 sslContext_
, evb_
, folly::NetworkSocket::fromFd(fd
)));
92 sslSocket_
->setPreReceivedData(std::move(clientHello
));
93 sslSocket_
->sslAccept(this);
97 void getReadBuffer(void** /* bufReturn */, size_t* /* lenReturn */) override
{
98 throw std::runtime_error("getReadBuffer not implemented");
101 void readDataAvailable(size_t /* len */) noexcept override
{
102 CHECK(false) << "readDataAvailable not implemented";
105 bool isBufferMovable() noexcept override
{
109 void readBufferAvailable(std::unique_ptr
<IOBuf
> buf
) noexcept override
{
110 io::Cursor
cursor(buf
.get());
111 std::unique_ptr
<IOBuf
> write
= IOBuf::create(0);
112 io::Appender
appender(write
.get(), 50);
113 while (!cursor
.isAtEnd()) {
115 cursor
.pull(&byte
, 1);
117 appender
.push(&byte
, 1);
119 transport_
->writeChain(nullptr, std::move(write
));
122 void readEOF() noexcept override
{}
124 void readErr(const AsyncSocketException
&) noexcept override
{}
126 void handshakeSuc(folly::AsyncSSLSocket
*) noexcept override
{
131 folly::AsyncSSLSocket
*,
132 const folly::AsyncSocketException
& ex
) noexcept override
{
133 LOG(INFO
) << "SSL Handshake error: " << ex
.what();
138 bool unimplemented() const {
139 return unimplemented_
;
142 bool success() const {
148 AsyncSocket::UniquePtr socket_
;
150 std::shared_ptr
<FizzServerContext
> serverContext_
;
151 AsyncFizzServer::UniquePtr transport_
;
153 std::shared_ptr
<SSLContext
> sslContext_
;
154 AsyncSSLSocket::UniquePtr sslSocket_
;
156 bool unimplemented_
{false};
157 Optional
<bool> success_
;
160 class BogoTestClient
: public AsyncSocket::ConnectCallback
,
161 public AsyncFizzClient::HandshakeCallback
,
162 public AsyncTransportWrapper::ReadCallback
{
167 std::shared_ptr
<const FizzClientContext
> clientContext
)
168 : clientContext_(clientContext
) {
169 socket_
= AsyncSocket::UniquePtr(new AsyncSocket(evb
));
170 socket_
->connect(this, "::", port
, 1000);
173 void connectSuccess() noexcept override
{
174 transport_
= AsyncFizzClient::UniquePtr(
175 new AsyncFizzClient(std::move(socket_
), clientContext_
));
180 std::string("resumption-id"),
181 folly::Optional
<std::vector
<fizz::ech::ECHConfig
>>(folly::none
));
184 void connectErr(const AsyncSocketException
& ex
) noexcept override
{
185 LOG(INFO
) << "TCP connect failed: " << ex
.what();
190 void fizzHandshakeSuccess(AsyncFizzClient
*) noexcept override
{
192 transport_
->setReadCB(this);
195 void fizzHandshakeError(
197 folly::exception_wrapper ex
) noexcept override
{
198 LOG(INFO
) << "Handshake error: " << ex
.what();
201 // If the server sent us a protocol_version alert assume that
203 "received alert: protocol_version, in state ExpectingServerHello") !=
205 unimplemented_
= true;
210 void getReadBuffer(void** /* bufReturn */, size_t* /* lenReturn */) override
{
211 throw std::runtime_error("getReadBuffer not implemented");
214 void readDataAvailable(size_t /* len */) noexcept override
{
215 CHECK(false) << "readDataAvailable not implemented";
218 bool isBufferMovable() noexcept override
{
222 void readBufferAvailable(std::unique_ptr
<IOBuf
> buf
) noexcept override
{
223 io::Cursor
cursor(buf
.get());
224 std::unique_ptr
<IOBuf
> write
= IOBuf::create(0);
225 io::Appender
appender(write
.get(), 50);
226 while (!cursor
.isAtEnd()) {
228 cursor
.pull(&byte
, 1);
230 appender
.push(&byte
, 1);
232 transport_
->writeChain(nullptr, std::move(write
));
235 void readEOF() noexcept override
{}
237 void readErr(const AsyncSocketException
&) noexcept override
{}
239 bool unimplemented() const {
240 return unimplemented_
;
243 bool success() const {
248 AsyncSocket::UniquePtr socket_
;
250 std::shared_ptr
<const FizzClientContext
> clientContext_
;
251 AsyncFizzClient::UniquePtr transport_
;
253 bool unimplemented_
{false};
254 Optional
<bool> success_
;
257 class TestRsaCert
: public SelfCertImpl
<KeyType::RSA
> {
259 using SelfCertImpl
<KeyType::RSA
>::SelfCertImpl
;
260 std::string
getIdentity() const override
{
261 return "testrsacert";
265 class TestP256Cert
: public SelfCertImpl
<KeyType::P256
> {
267 using SelfCertImpl
<KeyType::P256
>::SelfCertImpl
;
268 std::string
getIdentity() const override
{
269 return "testp256cert";
273 std::unique_ptr
<SelfCert
> readSelfCert() {
274 BioUniquePtr
b(BIO_new(BIO_s_file()));
275 BIO_read_filename(b
.get(), FLAGS_cert_file
.c_str());
276 std::vector
<X509UniquePtr
> certs
;
278 X509UniquePtr
x509(PEM_read_bio_X509(b
.get(), nullptr, nullptr, nullptr));
282 certs
.push_back(std::move(x509
));
286 throw std::runtime_error("could not read cert");
289 b
.reset(BIO_new(BIO_s_file()));
290 BIO_read_filename(b
.get(), FLAGS_key_file
.c_str());
291 EvpPkeyUniquePtr
key(
292 PEM_read_bio_PrivateKey(b
.get(), nullptr, nullptr, nullptr));
294 std::unique_ptr
<SelfCert
> cert
;
295 if (EVP_PKEY_id(key
.get()) == EVP_PKEY_RSA
) {
296 return std::make_unique
<TestRsaCert
>(std::move(key
), std::move(certs
));
297 } else if (EVP_PKEY_id(key
.get()) == EVP_PKEY_EC
) {
298 return std::make_unique
<TestP256Cert
>(std::move(key
), std::move(certs
));
300 throw std::runtime_error("unknown cert type");
305 auto certManager
= std::make_shared
<CertManager
>();
306 certManager
->addCert(readSelfCert(), true);
308 auto serverContext
= std::make_shared
<FizzServerContext
>();
309 serverContext
->setCertManager(certManager
);
310 serverContext
->setSupportedAlpns({"h2", "http/1.1"});
311 serverContext
->setVersionFallbackEnabled(true);
313 auto ticketCipher
= std::make_shared
<AES128TicketCipher
>(
314 serverContext
->getFactoryPtr(), std::move(certManager
));
315 auto ticketSeed
= RandomGenerator
<32>().generateRandom();
316 ticketCipher
->setTicketSecrets({{range(ticketSeed
)}});
317 server::TicketPolicy policy
;
318 policy
.setTicketValidity(std::chrono::seconds(60));
319 ticketCipher
->setPolicy(std::move(policy
));
321 serverContext
->setTicketCipher(ticketCipher
);
324 std::vector
<std::unique_ptr
<BogoTestServer
>> servers
;
325 for (size_t i
= 0; i
<= size_t(FLAGS_resume_count
); i
++) {
326 servers
.push_back(std::make_unique
<BogoTestServer
>(
327 &evb
, FLAGS_port
, serverContext
, nullptr));
330 for (const auto& server
: servers
) {
331 if (server
->unimplemented()) {
332 LOG(INFO
) << "Testing unimplemented feature.";
333 return kUnimplemented
;
336 for (const auto& server
: servers
) {
337 if (!server
->success()) {
338 LOG(INFO
) << "Connection failed.";
347 auto clientContext
= std::make_shared
<FizzClientContext
>();
348 clientContext
->setCompatibilityMode(true);
350 if (!FLAGS_cert_file
.empty()) {
351 clientContext
->setClientCertificate(readSelfCert());
355 if (FLAGS_resume_count
>= 1) {
356 return kUnimplemented
;
359 std::make_unique
<BogoTestClient
>(&evb
, FLAGS_port
, clientContext
);
361 if (client
->unimplemented()) {
362 LOG(INFO
) << "Testing unimplemented feature.";
363 return kUnimplemented
;
365 if (!client
->success()) {
366 LOG(INFO
) << "Connection failed.";
373 int main(int argc
, char** argv
) {
374 // Convert "-" in args to "_" so that we can use GFLAGS.
375 for (int i
= 1; i
< argc
; i
++) {
376 if (argv
[i
][0] == '-') {
377 for (char* j
= argv
[i
] + 2; *j
; j
++) {
385 std::string(argv
[i
] + 1)) == kKnownFlags
.end()) {
386 LOG(INFO
) << "unknown flag: " << argv
[i
];
387 return kUnimplemented
;
392 gflags::ParseCommandLineFlags(&argc
, &argv
, true);
393 google::InitGoogleLogging(argv
[0]);
396 if (FLAGS_port
== 0) {
397 throw std::runtime_error("must specify port");