Merge Chromium + Blink git repositories
[chromium-blink-merge.git] / net / tools / quic / quic_dispatcher_test.cc
blobc45fe3a8b59f9354d289743f3552f64e622f89d8
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/tools/quic/quic_dispatcher.h"
7 #include <ostream>
8 #include <string>
10 #include "base/strings/string_piece.h"
11 #include "net/quic/crypto/crypto_handshake.h"
12 #include "net/quic/crypto/quic_crypto_server_config.h"
13 #include "net/quic/crypto/quic_random.h"
14 #include "net/quic/quic_connection_helper.h"
15 #include "net/quic/quic_crypto_stream.h"
16 #include "net/quic/quic_flags.h"
17 #include "net/quic/quic_utils.h"
18 #include "net/quic/test_tools/quic_test_utils.h"
19 #include "net/tools/epoll_server/epoll_server.h"
20 #include "net/tools/quic/quic_epoll_connection_helper.h"
21 #include "net/tools/quic/quic_packet_writer_wrapper.h"
22 #include "net/tools/quic/quic_time_wait_list_manager.h"
23 #include "net/tools/quic/test_tools/mock_quic_time_wait_list_manager.h"
24 #include "net/tools/quic/test_tools/quic_dispatcher_peer.h"
25 #include "testing/gmock/include/gmock/gmock.h"
26 #include "testing/gtest/include/gtest/gtest.h"
28 using base::StringPiece;
29 using net::EpollServer;
30 using net::test::ConstructEncryptedPacket;
31 using net::test::MockConnection;
32 using net::test::ValueRestore;
33 using net::test::TestWriterFactory;
34 using std::string;
35 using std::vector;
36 using testing::DoAll;
37 using testing::InSequence;
38 using testing::Invoke;
39 using testing::WithoutArgs;
40 using testing::_;
42 namespace net {
43 namespace tools {
44 namespace test {
45 namespace {
47 class TestQuicSpdyServerSession : public QuicServerSession {
48 public:
49 TestQuicSpdyServerSession(const QuicConfig& config,
50 QuicConnection* connection,
51 const QuicCryptoServerConfig* crypto_config)
52 : QuicServerSession(config, connection, nullptr, crypto_config),
53 crypto_stream_(QuicServerSession::GetCryptoStream()) {}
54 ~TestQuicSpdyServerSession() override{};
56 MOCK_METHOD2(OnConnectionClosed, void(QuicErrorCode error, bool from_peer));
57 MOCK_METHOD1(CreateIncomingDynamicStream, QuicDataStream*(QuicStreamId id));
58 MOCK_METHOD0(CreateOutgoingDynamicStream, QuicDataStream*());
60 void SetCryptoStream(QuicCryptoServerStream* crypto_stream) {
61 crypto_stream_ = crypto_stream;
64 QuicCryptoServerStream* GetCryptoStream() override { return crypto_stream_; }
66 private:
67 QuicCryptoServerStream* crypto_stream_;
69 DISALLOW_COPY_AND_ASSIGN(TestQuicSpdyServerSession);
72 class TestDispatcher : public QuicDispatcher {
73 public:
74 TestDispatcher(const QuicConfig& config,
75 const QuicCryptoServerConfig* crypto_config,
76 EpollServer* eps)
77 : QuicDispatcher(config,
78 crypto_config,
79 QuicSupportedVersions(),
80 new QuicDispatcher::DefaultPacketWriterFactory(),
81 new QuicEpollConnectionHelper(eps)) {}
83 MOCK_METHOD3(CreateQuicSession,
84 QuicServerSession*(QuicConnectionId connection_id,
85 const IPEndPoint& server_address,
86 const IPEndPoint& client_address));
88 using QuicDispatcher::current_server_address;
89 using QuicDispatcher::current_client_address;
92 // A Connection class which unregisters the session from the dispatcher when
93 // sending connection close.
94 // It'd be slightly more realistic to do this from the Session but it would
95 // involve a lot more mocking.
96 class MockServerConnection : public MockConnection {
97 public:
98 MockServerConnection(QuicConnectionId connection_id,
99 QuicDispatcher* dispatcher)
100 : MockConnection(connection_id, Perspective::IS_SERVER),
101 dispatcher_(dispatcher) {}
103 void UnregisterOnConnectionClosed() {
104 LOG(ERROR) << "Unregistering " << connection_id();
105 dispatcher_->OnConnectionClosed(connection_id(), QUIC_NO_ERROR);
108 private:
109 QuicDispatcher* dispatcher_;
112 QuicServerSession* CreateSession(QuicDispatcher* dispatcher,
113 const QuicConfig& config,
114 QuicConnectionId connection_id,
115 const IPEndPoint& client_address,
116 const QuicCryptoServerConfig* crypto_config,
117 TestQuicSpdyServerSession** session) {
118 MockServerConnection* connection =
119 new MockServerConnection(connection_id, dispatcher);
120 *session = new TestQuicSpdyServerSession(config, connection, crypto_config);
121 connection->set_visitor(*session);
122 ON_CALL(*connection, SendConnectionClose(_)).WillByDefault(
123 WithoutArgs(Invoke(
124 connection, &MockServerConnection::UnregisterOnConnectionClosed)));
125 EXPECT_CALL(*reinterpret_cast<MockConnection*>((*session)->connection()),
126 ProcessUdpPacket(_, client_address, _));
128 return *session;
131 class QuicDispatcherTest : public ::testing::Test {
132 public:
133 QuicDispatcherTest()
134 : helper_(&eps_),
135 crypto_config_(QuicCryptoServerConfig::TESTING,
136 QuicRandom::GetInstance()),
137 dispatcher_(config_, &crypto_config_, &eps_),
138 time_wait_list_manager_(nullptr),
139 session1_(nullptr),
140 session2_(nullptr) {
141 dispatcher_.InitializeWithWriter(new QuicDefaultPacketWriter(1));
144 ~QuicDispatcherTest() override {}
146 MockConnection* connection1() {
147 return reinterpret_cast<MockConnection*>(session1_->connection());
150 MockConnection* connection2() {
151 return reinterpret_cast<MockConnection*>(session2_->connection());
154 void ProcessPacket(IPEndPoint client_address,
155 QuicConnectionId connection_id,
156 bool has_version_flag,
157 const string& data) {
158 ProcessPacket(client_address, connection_id, has_version_flag, data,
159 PACKET_8BYTE_CONNECTION_ID, PACKET_6BYTE_PACKET_NUMBER);
162 void ProcessPacket(IPEndPoint client_address,
163 QuicConnectionId connection_id,
164 bool has_version_flag,
165 const string& data,
166 QuicConnectionIdLength connection_id_length,
167 QuicPacketNumberLength packet_number_length) {
168 ProcessPacket(client_address, connection_id, has_version_flag, data,
169 connection_id_length, packet_number_length, 1);
172 void ProcessPacket(IPEndPoint client_address,
173 QuicConnectionId connection_id,
174 bool has_version_flag,
175 const string& data,
176 QuicConnectionIdLength connection_id_length,
177 QuicPacketNumberLength packet_number_length,
178 QuicPacketNumber packet_number) {
179 scoped_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket(
180 connection_id, has_version_flag, false, packet_number, data,
181 connection_id_length, packet_number_length));
182 data_ = string(packet->data(), packet->length());
183 dispatcher_.ProcessPacket(server_address_, client_address, *packet);
186 void ValidatePacket(const QuicEncryptedPacket& packet) {
187 EXPECT_EQ(data_.length(), packet.AsStringPiece().length());
188 EXPECT_EQ(data_, packet.AsStringPiece());
191 void CreateTimeWaitListManager() {
192 time_wait_list_manager_ = new MockTimeWaitListManager(
193 QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &helper_);
194 // dispatcher_ takes the ownership of time_wait_list_manager_.
195 QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
196 time_wait_list_manager_);
199 EpollServer eps_;
200 QuicEpollConnectionHelper helper_;
201 QuicConfig config_;
202 QuicCryptoServerConfig crypto_config_;
203 IPEndPoint server_address_;
204 TestDispatcher dispatcher_;
205 MockTimeWaitListManager* time_wait_list_manager_;
206 TestQuicSpdyServerSession* session1_;
207 TestQuicSpdyServerSession* session2_;
208 string data_;
211 TEST_F(QuicDispatcherTest, ProcessPackets) {
212 IPEndPoint client_address(net::test::Loopback4(), 1);
213 server_address_ = IPEndPoint(net::test::Any4(), 5);
215 EXPECT_CALL(dispatcher_, CreateQuicSession(1, _, client_address))
216 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_, 1,
217 client_address, &crypto_config_,
218 &session1_)));
219 ProcessPacket(client_address, 1, true, "foo");
220 EXPECT_EQ(client_address, dispatcher_.current_client_address());
221 EXPECT_EQ(server_address_, dispatcher_.current_server_address());
223 EXPECT_CALL(dispatcher_, CreateQuicSession(2, _, client_address))
224 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_, 2,
225 client_address, &crypto_config_,
226 &session2_)));
227 ProcessPacket(client_address, 2, true, "bar");
229 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
230 ProcessUdpPacket(_, _, _)).Times(1).
231 WillOnce(testing::WithArgs<2>(Invoke(
232 this, &QuicDispatcherTest::ValidatePacket)));
233 ProcessPacket(client_address, 1, false, "eep");
236 TEST_F(QuicDispatcherTest, Shutdown) {
237 IPEndPoint client_address(net::test::Loopback4(), 1);
239 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
240 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_, 1,
241 client_address, &crypto_config_,
242 &session1_)));
244 ProcessPacket(client_address, 1, true, "foo");
246 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
247 SendConnectionClose(QUIC_PEER_GOING_AWAY));
249 dispatcher_.Shutdown();
252 TEST_F(QuicDispatcherTest, TimeWaitListManager) {
253 CreateTimeWaitListManager();
255 // Create a new session.
256 IPEndPoint client_address(net::test::Loopback4(), 1);
257 QuicConnectionId connection_id = 1;
258 EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, _, client_address))
259 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_,
260 connection_id, client_address,
261 &crypto_config_, &session1_)));
262 ProcessPacket(client_address, connection_id, true, "foo");
264 // Close the connection by sending public reset packet.
265 QuicPublicResetPacket packet;
266 packet.public_header.connection_id = connection_id;
267 packet.public_header.reset_flag = true;
268 packet.public_header.version_flag = false;
269 packet.rejected_packet_number = 19191;
270 packet.nonce_proof = 132232;
271 scoped_ptr<QuicEncryptedPacket> encrypted(
272 QuicFramer::BuildPublicResetPacket(packet));
273 EXPECT_CALL(*session1_, OnConnectionClosed(QUIC_PUBLIC_RESET, true)).Times(1)
274 .WillOnce(WithoutArgs(Invoke(
275 reinterpret_cast<MockServerConnection*>(session1_->connection()),
276 &MockServerConnection::UnregisterOnConnectionClosed)));
277 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
278 ProcessUdpPacket(_, _, _))
279 .WillOnce(Invoke(
280 reinterpret_cast<MockConnection*>(session1_->connection()),
281 &MockConnection::ReallyProcessUdpPacket));
282 dispatcher_.ProcessPacket(IPEndPoint(), client_address, *encrypted);
283 EXPECT_TRUE(time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id));
285 // Dispatcher forwards subsequent packets for this connection_id to the time
286 // wait list manager.
287 EXPECT_CALL(*time_wait_list_manager_,
288 ProcessPacket(_, _, connection_id, _, _)).Times(1);
289 EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
290 .Times(0);
291 ProcessPacket(client_address, connection_id, true, "foo");
294 TEST_F(QuicDispatcherTest, NoVersionPacketToTimeWaitListManager) {
295 CreateTimeWaitListManager();
297 IPEndPoint client_address(net::test::Loopback4(), 1);
298 QuicConnectionId connection_id = 1;
299 // Dispatcher forwards all packets for this connection_id to the time wait
300 // list manager.
301 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, _)).Times(0);
302 EXPECT_CALL(*time_wait_list_manager_,
303 ProcessPacket(_, _, connection_id, _, _)).Times(1);
304 EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
305 .Times(1);
306 ProcessPacket(client_address, connection_id, false, "data");
309 // Enables mocking of the handshake-confirmation for stateless rejects.
310 class MockQuicCryptoServerStream : public QuicCryptoServerStream {
311 public:
312 MockQuicCryptoServerStream(const QuicCryptoServerConfig& crypto_config,
313 QuicSession* session)
314 : QuicCryptoServerStream(&crypto_config, session) {}
315 void set_handshake_confirmed_for_testing(bool handshake_confirmed) {
316 handshake_confirmed_ = handshake_confirmed;
319 private:
320 DISALLOW_COPY_AND_ASSIGN(MockQuicCryptoServerStream);
// One combination of settings for the parameterized stateless-reject tests.
struct StatelessRejectTestParams {
  StatelessRejectTestParams(bool enable_stateless_rejects_via_flag,
                            bool use_stateless_rejects_if_peer_supported,
                            bool client_supports_statelesss_rejects,
                            bool crypto_handshake_successful)
      : enable_stateless_rejects_via_flag(enable_stateless_rejects_via_flag),
        use_stateless_rejects_if_peer_supported(
            use_stateless_rejects_if_peer_supported),
        client_supports_statelesss_rejects(client_supports_statelesss_rejects),
        crypto_handshake_successful(crypto_handshake_successful) {}

  // Prints the parameters as one brace-delimited record, one field per line,
  // so a failing parameterized test identifies its combination.  The record
  // opens with a single "{" and closes with a single "}".
  friend std::ostream& operator<<(std::ostream& os,
                                  const StatelessRejectTestParams& p) {
    // '\n' rather than std::endl: there is no need to flush after each field.
    os << "{ enable_stateless_rejects_via_flag: "
       << p.enable_stateless_rejects_via_flag << '\n';
    os << "  use_stateless_rejects_if_peer_supported: "
       << p.use_stateless_rejects_if_peer_supported << '\n';
    os << "  client_supports_statelesss_rejects: "
       << p.client_supports_statelesss_rejects << '\n';
    os << "  crypto_handshake_successful: " << p.crypto_handshake_successful
       << " }";
    return os;
  }

  // This only enables the stateless reject feature via the feature-flag.
  // It does not force the crypto server to emit stateless rejects.
  bool enable_stateless_rejects_via_flag;
  // If true, this forces the server to send a stateless reject when rejecting
  // messages.  This should be a no-op if enable_stateless_rejects_via_flag is
  // false or the peer does not support them.
  bool use_stateless_rejects_if_peer_supported;
  // Whether or not the client supports stateless rejects.
  bool client_supports_statelesss_rejects;
  // Should the initial crypto handshake succeed or not.
  bool crypto_handshake_successful;
};
360 // Constructs various test permutations for stateless rejects.
361 vector<StatelessRejectTestParams> GetStatelessRejectTestParams() {
362 vector<StatelessRejectTestParams> params;
363 for (bool enable_stateless_rejects_via_flag : {true, false}) {
364 for (bool use_stateless_rejects_if_peer_supported : {true, false}) {
365 for (bool client_supports_statelesss_rejects : {true, false}) {
366 for (bool crypto_handshake_successful : {true, false}) {
367 params.push_back(StatelessRejectTestParams(
368 enable_stateless_rejects_via_flag,
369 use_stateless_rejects_if_peer_supported,
370 client_supports_statelesss_rejects, crypto_handshake_successful));
375 return params;
378 class QuicDispatcherStatelessRejectTest
379 : public QuicDispatcherTest,
380 public ::testing::WithParamInterface<StatelessRejectTestParams> {
381 public:
382 QuicDispatcherStatelessRejectTest() : crypto_stream1_(nullptr) {}
384 ~QuicDispatcherStatelessRejectTest() override {
385 if (crypto_stream1_) {
386 delete crypto_stream1_;
390 // This test setup assumes that all testing will be done using
391 // crypto_stream1_.
392 void SetUp() override {
393 FLAGS_enable_quic_stateless_reject_support =
394 GetParam().enable_stateless_rejects_via_flag;
397 // Returns true or false, depending on whether the server will emit
398 // a stateless reject, depending upon the parameters of the test.
399 bool ExpectStatelessReject() {
400 return GetParam().enable_stateless_rejects_via_flag &&
401 GetParam().use_stateless_rejects_if_peer_supported &&
402 !GetParam().crypto_handshake_successful &&
403 GetParam().client_supports_statelesss_rejects;
406 // Sets up dispatcher_, sesession1_, and crypto_stream1_ based on
407 // the test parameters.
408 QuicServerSession* CreateSessionBasedOnTestParams(
409 QuicConnectionId connection_id,
410 const IPEndPoint& client_address) {
411 CreateSession(&dispatcher_, config_, connection_id, client_address,
412 &crypto_config_, &session1_);
414 crypto_stream1_ = new MockQuicCryptoServerStream(crypto_config_, session1_);
415 session1_->SetCryptoStream(crypto_stream1_);
416 crypto_stream1_->set_use_stateless_rejects_if_peer_supported(
417 GetParam().use_stateless_rejects_if_peer_supported);
418 crypto_stream1_->set_handshake_confirmed_for_testing(
419 GetParam().crypto_handshake_successful);
420 crypto_stream1_->set_peer_supports_stateless_rejects(
421 GetParam().client_supports_statelesss_rejects);
422 return session1_;
425 MockQuicCryptoServerStream* crypto_stream1_;
428 TEST_F(QuicDispatcherTest, ProcessPacketWithZeroPort) {
429 CreateTimeWaitListManager();
431 IPEndPoint client_address(net::test::Loopback4(), 0);
432 server_address_ = IPEndPoint(net::test::Any4(), 5);
434 // dispatcher_ should drop this packet.
435 EXPECT_CALL(dispatcher_, CreateQuicSession(1, _, client_address)).Times(0);
436 EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _)).Times(0);
437 EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
438 .Times(0);
439 ProcessPacket(client_address, 1, true, "foo");
442 TEST_F(QuicDispatcherTest, OKSeqNoPacketProcessed) {
443 IPEndPoint client_address(net::test::Loopback4(), 1);
444 QuicConnectionId connection_id = 1;
445 server_address_ = IPEndPoint(net::test::Any4(), 5);
447 EXPECT_CALL(dispatcher_, CreateQuicSession(1, _, client_address))
448 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_, 1,
449 client_address, &crypto_config_,
450 &session1_)));
451 // A packet whose packet number is the largest that is allowed to start a
452 // connection.
453 ProcessPacket(client_address, connection_id, true, "data",
454 PACKET_8BYTE_CONNECTION_ID, PACKET_6BYTE_PACKET_NUMBER,
455 QuicDispatcher::kMaxReasonableInitialPacketNumber);
456 EXPECT_EQ(client_address, dispatcher_.current_client_address());
457 EXPECT_EQ(server_address_, dispatcher_.current_server_address());
460 TEST_F(QuicDispatcherTest, TooBigSeqNoPacketToTimeWaitListManager) {
461 CreateTimeWaitListManager();
463 IPEndPoint client_address(net::test::Loopback4(), 1);
464 QuicConnectionId connection_id = 1;
465 // Dispatcher forwards this packet for this connection_id to the time wait
466 // list manager.
467 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, _)).Times(0);
468 EXPECT_CALL(*time_wait_list_manager_,
469 ProcessPacket(_, _, connection_id, _, _)).Times(1);
470 EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
471 .Times(1);
472 // A packet whose packet number is one to large to be allowed to start a
473 // connection.
474 ProcessPacket(client_address, connection_id, true, "data",
475 PACKET_8BYTE_CONNECTION_ID, PACKET_6BYTE_PACKET_NUMBER,
476 QuicDispatcher::kMaxReasonableInitialPacketNumber + 1);
479 INSTANTIATE_TEST_CASE_P(QuicDispatcherStatelessRejectTests,
480 QuicDispatcherStatelessRejectTest,
481 ::testing::ValuesIn(GetStatelessRejectTestParams()));
483 // Parameterized test for stateless rejects. Should test all
484 // combinations of enabling/disabling, reject/no-reject for stateless
485 // rejects.
486 TEST_P(QuicDispatcherStatelessRejectTest, ParameterizedBasicTest) {
487 CreateTimeWaitListManager();
489 IPEndPoint client_address(net::test::Loopback4(), 1);
490 QuicConnectionId connection_id = 1;
491 EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, _, client_address))
492 .WillOnce(testing::Return(
493 CreateSessionBasedOnTestParams(connection_id, client_address)));
495 // Process the first packet for the connection.
496 if (ExpectStatelessReject()) {
497 // If this is a stateless reject, we expect the connection to close.
498 EXPECT_CALL(*session1_, OnConnectionClosed(_, _))
499 .Times(1)
500 .WillOnce(WithoutArgs(Invoke(
501 reinterpret_cast<MockServerConnection*>(session1_->connection()),
502 &MockServerConnection::UnregisterOnConnectionClosed)));
504 ProcessPacket(client_address, connection_id, true, "foo");
506 // Send a second packet and check the results. If this is a stateless reject,
507 // the existing connection_id will go on the time-wait list.
508 EXPECT_EQ(ExpectStatelessReject(),
509 time_wait_list_manager_->IsConnectionIdInTimeWait(connection_id));
510 if (ExpectStatelessReject()) {
511 // The second packet will be processed on the time-wait list.
512 EXPECT_CALL(*time_wait_list_manager_,
513 ProcessPacket(_, _, connection_id, _, _)).Times(1);
514 } else {
515 // The second packet will trigger a packet-validation
516 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
517 ProcessUdpPacket(_, _, _))
518 .Times(1)
519 .WillOnce(testing::WithArgs<2>(
520 Invoke(this, &QuicDispatcherTest::ValidatePacket)));
522 ProcessPacket(client_address, connection_id, true, "foo");
525 // Verify the stopgap test: Packets with truncated connection IDs should be
526 // dropped.
527 class QuicDispatcherTestStrayPacketConnectionId
528 : public QuicDispatcherTest,
529 public ::testing::WithParamInterface<QuicConnectionIdLength> {};
531 // Packets with truncated connection IDs should be dropped.
532 TEST_P(QuicDispatcherTestStrayPacketConnectionId,
533 StrayPacketTruncatedConnectionId) {
534 const QuicConnectionIdLength connection_id_length = GetParam();
536 CreateTimeWaitListManager();
538 IPEndPoint client_address(net::test::Loopback4(), 1);
539 QuicConnectionId connection_id = 1;
540 // Dispatcher drops this packet.
541 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, _)).Times(0);
542 EXPECT_CALL(*time_wait_list_manager_,
543 ProcessPacket(_, _, connection_id, _, _)).Times(0);
544 EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _, _))
545 .Times(0);
546 ProcessPacket(client_address, connection_id, true, "data",
547 connection_id_length, PACKET_6BYTE_PACKET_NUMBER);
550 INSTANTIATE_TEST_CASE_P(ConnectionIdLength,
551 QuicDispatcherTestStrayPacketConnectionId,
552 ::testing::Values(PACKET_0BYTE_CONNECTION_ID,
553 PACKET_1BYTE_CONNECTION_ID,
554 PACKET_4BYTE_CONNECTION_ID));
556 class BlockingWriter : public QuicPacketWriterWrapper {
557 public:
558 BlockingWriter() : write_blocked_(false) {}
560 bool IsWriteBlocked() const override { return write_blocked_; }
561 void SetWritable() override { write_blocked_ = false; }
563 WriteResult WritePacket(const char* buffer,
564 size_t buf_len,
565 const IPAddressNumber& self_client_address,
566 const IPEndPoint& peer_client_address) override {
567 // It would be quite possible to actually implement this method here with
568 // the fake blocked status, but it would be significantly more work in
569 // Chromium, and since it's not called anyway, don't bother.
570 LOG(DFATAL) << "Not supported";
571 return WriteResult();
574 bool write_blocked_;
577 class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTest {
578 public:
579 void SetUp() override {
580 writer_ = new BlockingWriter;
581 QuicDispatcherPeer::SetPacketWriterFactory(&dispatcher_,
582 new TestWriterFactory());
583 QuicDispatcherPeer::UseWriter(&dispatcher_, writer_);
585 IPEndPoint client_address(net::test::Loopback4(), 1);
587 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
588 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_, 1,
589 client_address, &crypto_config_,
590 &session1_)));
591 ProcessPacket(client_address, 1, true, "foo");
593 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
594 .WillOnce(testing::Return(CreateSession(&dispatcher_, config_, 2,
595 client_address, &crypto_config_,
596 &session2_)));
597 ProcessPacket(client_address, 2, true, "bar");
599 blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(&dispatcher_);
602 void TearDown() override {
603 EXPECT_CALL(*connection1(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
604 EXPECT_CALL(*connection2(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
605 dispatcher_.Shutdown();
608 void SetBlocked() {
609 writer_->write_blocked_ = true;
612 void BlockConnection2() {
613 writer_->write_blocked_ = true;
614 dispatcher_.OnWriteBlocked(connection2());
617 protected:
618 BlockingWriter* writer_;
619 QuicDispatcher::WriteBlockedList* blocked_list_;
622 TEST_F(QuicDispatcherWriteBlockedListTest, BasicOnCanWrite) {
623 // No OnCanWrite calls because no connections are blocked.
624 dispatcher_.OnCanWrite();
626 // Register connection 1 for events, and make sure it's notified.
627 SetBlocked();
628 dispatcher_.OnWriteBlocked(connection1());
629 EXPECT_CALL(*connection1(), OnCanWrite());
630 dispatcher_.OnCanWrite();
632 // It should get only one notification.
633 EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
634 dispatcher_.OnCanWrite();
635 EXPECT_FALSE(dispatcher_.HasPendingWrites());
638 TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteOrder) {
639 // Make sure we handle events in order.
640 InSequence s;
641 SetBlocked();
642 dispatcher_.OnWriteBlocked(connection1());
643 dispatcher_.OnWriteBlocked(connection2());
644 EXPECT_CALL(*connection1(), OnCanWrite());
645 EXPECT_CALL(*connection2(), OnCanWrite());
646 dispatcher_.OnCanWrite();
648 // Check the other ordering.
649 SetBlocked();
650 dispatcher_.OnWriteBlocked(connection2());
651 dispatcher_.OnWriteBlocked(connection1());
652 EXPECT_CALL(*connection2(), OnCanWrite());
653 EXPECT_CALL(*connection1(), OnCanWrite());
654 dispatcher_.OnCanWrite();
657 TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteRemove) {
658 // Add and remove one connction.
659 SetBlocked();
660 dispatcher_.OnWriteBlocked(connection1());
661 blocked_list_->erase(connection1());
662 EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
663 dispatcher_.OnCanWrite();
665 // Add and remove one connction and make sure it doesn't affect others.
666 SetBlocked();
667 dispatcher_.OnWriteBlocked(connection1());
668 dispatcher_.OnWriteBlocked(connection2());
669 blocked_list_->erase(connection1());
670 EXPECT_CALL(*connection2(), OnCanWrite());
671 dispatcher_.OnCanWrite();
673 // Add it, remove it, and add it back and make sure things are OK.
674 SetBlocked();
675 dispatcher_.OnWriteBlocked(connection1());
676 blocked_list_->erase(connection1());
677 dispatcher_.OnWriteBlocked(connection1());
678 EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
679 dispatcher_.OnCanWrite();
682 TEST_F(QuicDispatcherWriteBlockedListTest, DoubleAdd) {
683 // Make sure a double add does not necessitate a double remove.
684 SetBlocked();
685 dispatcher_.OnWriteBlocked(connection1());
686 dispatcher_.OnWriteBlocked(connection1());
687 blocked_list_->erase(connection1());
688 EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
689 dispatcher_.OnCanWrite();
691 // Make sure a double add does not result in two OnCanWrite calls.
692 SetBlocked();
693 dispatcher_.OnWriteBlocked(connection1());
694 dispatcher_.OnWriteBlocked(connection1());
695 EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
696 dispatcher_.OnCanWrite();
699 TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlock) {
700 // Finally make sure if we write block on a write call, we stop calling.
701 InSequence s;
702 SetBlocked();
703 dispatcher_.OnWriteBlocked(connection1());
704 dispatcher_.OnWriteBlocked(connection2());
705 EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(
706 Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
707 EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
708 dispatcher_.OnCanWrite();
710 // And we'll resume where we left off when we get another call.
711 EXPECT_CALL(*connection2(), OnCanWrite());
712 dispatcher_.OnCanWrite();
715 TEST_F(QuicDispatcherWriteBlockedListTest, LimitedWrites) {
716 // Make sure we call both writers. The first will register for more writing
717 // but should not be immediately called due to limits.
718 InSequence s;
719 SetBlocked();
720 dispatcher_.OnWriteBlocked(connection1());
721 dispatcher_.OnWriteBlocked(connection2());
722 EXPECT_CALL(*connection1(), OnCanWrite());
723 EXPECT_CALL(*connection2(), OnCanWrite()).WillOnce(
724 Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2));
725 dispatcher_.OnCanWrite();
726 EXPECT_TRUE(dispatcher_.HasPendingWrites());
728 // Now call OnCanWrite again, and connection1 should get its second chance
729 EXPECT_CALL(*connection2(), OnCanWrite());
730 dispatcher_.OnCanWrite();
731 EXPECT_FALSE(dispatcher_.HasPendingWrites());
734 TEST_F(QuicDispatcherWriteBlockedListTest, TestWriteLimits) {
735 // Finally make sure if we write block on a write call, we stop calling.
736 InSequence s;
737 SetBlocked();
738 dispatcher_.OnWriteBlocked(connection1());
739 dispatcher_.OnWriteBlocked(connection2());
740 EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(
741 Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
742 EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
743 dispatcher_.OnCanWrite();
744 EXPECT_TRUE(dispatcher_.HasPendingWrites());
746 // And we'll resume where we left off when we get another call.
747 EXPECT_CALL(*connection2(), OnCanWrite());
748 dispatcher_.OnCanWrite();
749 EXPECT_FALSE(dispatcher_.HasPendingWrites());
752 } // namespace
753 } // namespace test
754 } // namespace tools
755 } // namespace net