1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/quic/test_tools/quic_test_utils.h"
8 #include "base/stl_util.h"
9 #include "base/strings/string_number_conversions.h"
10 #include "net/quic/crypto/crypto_framer.h"
11 #include "net/quic/crypto/crypto_handshake.h"
12 #include "net/quic/crypto/crypto_utils.h"
13 #include "net/quic/crypto/null_encrypter.h"
14 #include "net/quic/crypto/quic_decrypter.h"
15 #include "net/quic/crypto/quic_encrypter.h"
16 #include "net/quic/quic_framer.h"
17 #include "net/quic/quic_packet_creator.h"
18 #include "net/quic/quic_utils.h"
19 #include "net/quic/test_tools/quic_connection_peer.h"
20 #include "net/spdy/spdy_frame_builder.h"
22 using base::StringPiece
;
26 using testing::AnyNumber
;
33 // No-op alarm implementation used by MockHelper.
34 class TestAlarm
: public QuicAlarm
{
36 explicit TestAlarm(QuicAlarm::Delegate
* delegate
)
37 : QuicAlarm(delegate
) {
40 virtual void SetImpl() OVERRIDE
{}
41 virtual void CancelImpl() OVERRIDE
{}
46 QuicAckFrame
MakeAckFrame(QuicPacketSequenceNumber largest_observed
) {
48 ack
.largest_observed
= largest_observed
;
53 QuicAckFrame
MakeAckFrameWithNackRanges(
54 size_t num_nack_ranges
, QuicPacketSequenceNumber least_unacked
) {
55 QuicAckFrame ack
= MakeAckFrame(2 * num_nack_ranges
+ least_unacked
);
56 // Add enough missing packets to get num_nack_ranges nack ranges.
57 for (QuicPacketSequenceNumber i
= 1; i
< 2 * num_nack_ranges
; i
+= 2) {
58 ack
.missing_packets
.insert(least_unacked
+ i
);
63 SerializedPacket
BuildUnsizedDataPacket(QuicFramer
* framer
,
64 const QuicPacketHeader
& header
,
65 const QuicFrames
& frames
) {
66 const size_t max_plaintext_size
= framer
->GetMaxPlaintextSize(kMaxPacketSize
);
67 size_t packet_size
= GetPacketHeaderSize(header
);
68 for (size_t i
= 0; i
< frames
.size(); ++i
) {
69 DCHECK_LE(packet_size
, max_plaintext_size
);
70 bool first_frame
= i
== 0;
71 bool last_frame
= i
== frames
.size() - 1;
72 const size_t frame_size
= framer
->GetSerializedFrameLength(
73 frames
[i
], max_plaintext_size
- packet_size
, first_frame
, last_frame
,
74 header
.is_in_fec_group
,
75 header
.public_header
.sequence_number_length
);
77 packet_size
+= frame_size
;
79 return framer
->BuildDataPacket(header
, frames
, packet_size
);
82 uint64
SimpleRandom::RandUint64() {
83 unsigned char hash
[base::kSHA1Length
];
84 base::SHA1HashBytes(reinterpret_cast<unsigned char*>(&seed_
), sizeof(seed_
),
86 memcpy(&seed_
, hash
, sizeof(seed_
));
90 MockFramerVisitor::MockFramerVisitor() {
91 // By default, we want to accept packets.
92 ON_CALL(*this, OnProtocolVersionMismatch(_
))
93 .WillByDefault(testing::Return(false));
95 // By default, we want to accept packets.
96 ON_CALL(*this, OnUnauthenticatedHeader(_
))
97 .WillByDefault(testing::Return(true));
99 ON_CALL(*this, OnUnauthenticatedPublicHeader(_
))
100 .WillByDefault(testing::Return(true));
102 ON_CALL(*this, OnPacketHeader(_
))
103 .WillByDefault(testing::Return(true));
105 ON_CALL(*this, OnStreamFrame(_
))
106 .WillByDefault(testing::Return(true));
108 ON_CALL(*this, OnAckFrame(_
))
109 .WillByDefault(testing::Return(true));
111 ON_CALL(*this, OnCongestionFeedbackFrame(_
))
112 .WillByDefault(testing::Return(true));
114 ON_CALL(*this, OnStopWaitingFrame(_
))
115 .WillByDefault(testing::Return(true));
117 ON_CALL(*this, OnPingFrame(_
))
118 .WillByDefault(testing::Return(true));
120 ON_CALL(*this, OnRstStreamFrame(_
))
121 .WillByDefault(testing::Return(true));
123 ON_CALL(*this, OnConnectionCloseFrame(_
))
124 .WillByDefault(testing::Return(true));
126 ON_CALL(*this, OnGoAwayFrame(_
))
127 .WillByDefault(testing::Return(true));
130 MockFramerVisitor::~MockFramerVisitor() {
133 bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version
) {
137 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
138 const QuicPacketPublicHeader
& header
) {
142 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
143 const QuicPacketHeader
& header
) {
147 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader
& header
) {
151 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame
& frame
) {
155 bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame
& frame
) {
159 bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
160 const QuicCongestionFeedbackFrame
& frame
) {
164 bool NoOpFramerVisitor::OnStopWaitingFrame(
165 const QuicStopWaitingFrame
& frame
) {
169 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame
& frame
) {
173 bool NoOpFramerVisitor::OnRstStreamFrame(
174 const QuicRstStreamFrame
& frame
) {
178 bool NoOpFramerVisitor::OnConnectionCloseFrame(
179 const QuicConnectionCloseFrame
& frame
) {
183 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame
& frame
) {
187 bool NoOpFramerVisitor::OnWindowUpdateFrame(
188 const QuicWindowUpdateFrame
& frame
) {
192 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame
& frame
) {
196 MockConnectionVisitor::MockConnectionVisitor() {
199 MockConnectionVisitor::~MockConnectionVisitor() {
202 MockHelper::MockHelper() {
205 MockHelper::~MockHelper() {
208 const QuicClock
* MockHelper::GetClock() const {
212 QuicRandom
* MockHelper::GetRandomGenerator() {
213 return &random_generator_
;
216 QuicAlarm
* MockHelper::CreateAlarm(QuicAlarm::Delegate
* delegate
) {
217 return new TestAlarm(delegate
);
220 void MockHelper::AdvanceTime(QuicTime::Delta delta
) {
221 clock_
.AdvanceTime(delta
);
225 class NiceMockPacketWriterFactory
226 : public QuicConnection::PacketWriterFactory
{
228 NiceMockPacketWriterFactory() {}
229 virtual ~NiceMockPacketWriterFactory() {}
231 virtual QuicPacketWriter
* Create(
232 QuicConnection
* /*connection*/) const OVERRIDE
{
233 return new testing::NiceMock
<MockPacketWriter
>();
237 DISALLOW_COPY_AND_ASSIGN(NiceMockPacketWriterFactory
);
241 MockConnection::MockConnection(bool is_server
)
242 : QuicConnection(kTestConnectionId
,
243 IPEndPoint(TestPeerIPAddress(), kTestPort
),
244 new testing::NiceMock
<MockHelper
>(),
245 NiceMockPacketWriterFactory(),
246 /* owns_writer= */ true,
247 is_server
, QuicSupportedVersions()),
251 MockConnection::MockConnection(IPEndPoint address
,
253 : QuicConnection(kTestConnectionId
, address
,
254 new testing::NiceMock
<MockHelper
>(),
255 NiceMockPacketWriterFactory(),
256 /* owns_writer= */ true,
257 is_server
, QuicSupportedVersions()),
261 MockConnection::MockConnection(QuicConnectionId connection_id
,
263 : QuicConnection(connection_id
,
264 IPEndPoint(TestPeerIPAddress(), kTestPort
),
265 new testing::NiceMock
<MockHelper
>(),
266 NiceMockPacketWriterFactory(),
267 /* owns_writer= */ true,
268 is_server
, QuicSupportedVersions()),
272 MockConnection::MockConnection(bool is_server
,
273 const QuicVersionVector
& supported_versions
)
274 : QuicConnection(kTestConnectionId
,
275 IPEndPoint(TestPeerIPAddress(), kTestPort
),
276 new testing::NiceMock
<MockHelper
>(),
277 NiceMockPacketWriterFactory(),
278 /* owns_writer= */ true,
279 is_server
, supported_versions
),
283 MockConnection::~MockConnection() {
286 void MockConnection::AdvanceTime(QuicTime::Delta delta
) {
287 static_cast<MockHelper
*>(helper())->AdvanceTime(delta
);
290 PacketSavingConnection::PacketSavingConnection(bool is_server
)
291 : MockConnection(is_server
) {
294 PacketSavingConnection::PacketSavingConnection(
296 const QuicVersionVector
& supported_versions
)
297 : MockConnection(is_server
, supported_versions
) {
300 PacketSavingConnection::~PacketSavingConnection() {
301 STLDeleteElements(&packets_
);
302 STLDeleteElements(&encrypted_packets_
);
305 void PacketSavingConnection::SendOrQueuePacket(QueuedPacket packet
) {
306 packets_
.push_back(packet
.serialized_packet
.packet
);
307 QuicEncryptedPacket
* encrypted
= QuicConnectionPeer::GetFramer(this)->
308 EncryptPacket(packet
.encryption_level
,
309 packet
.serialized_packet
.sequence_number
,
310 *packet
.serialized_packet
.packet
);
311 encrypted_packets_
.push_back(encrypted
);
312 // Transfer ownership of the packet to the SentPacketManager and the
313 // ack notifier to the AckNotifierManager.
314 sent_packet_manager_
.OnPacketSent(
315 &packet
.serialized_packet
, 0, QuicTime::Zero(), 1000,
316 NOT_RETRANSMISSION
, HAS_RETRANSMITTABLE_DATA
);
319 MockSession::MockSession(QuicConnection
* connection
)
320 : QuicSession(connection
, DefaultQuicConfig()) {
322 ON_CALL(*this, WritevData(_
, _
, _
, _
, _
, _
))
323 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
326 MockSession::~MockSession() {
329 TestSession::TestSession(QuicConnection
* connection
, const QuicConfig
& config
)
330 : QuicSession(connection
, config
),
331 crypto_stream_(NULL
) {
335 TestSession::~TestSession() {}
337 void TestSession::SetCryptoStream(QuicCryptoStream
* stream
) {
338 crypto_stream_
= stream
;
341 QuicCryptoStream
* TestSession::GetCryptoStream() {
342 return crypto_stream_
;
345 TestClientSession::TestClientSession(QuicConnection
* connection
,
346 const QuicConfig
& config
)
347 : QuicClientSessionBase(connection
, config
),
348 crypto_stream_(NULL
) {
349 EXPECT_CALL(*this, OnProofValid(_
)).Times(AnyNumber());
353 TestClientSession::~TestClientSession() {}
355 void TestClientSession::SetCryptoStream(QuicCryptoStream
* stream
) {
356 crypto_stream_
= stream
;
359 QuicCryptoStream
* TestClientSession::GetCryptoStream() {
360 return crypto_stream_
;
363 MockPacketWriter::MockPacketWriter() {
366 MockPacketWriter::~MockPacketWriter() {
369 MockSendAlgorithm::MockSendAlgorithm() {
372 MockSendAlgorithm::~MockSendAlgorithm() {
375 MockLossAlgorithm::MockLossAlgorithm() {
378 MockLossAlgorithm::~MockLossAlgorithm() {
381 MockAckNotifierDelegate::MockAckNotifierDelegate() {
384 MockAckNotifierDelegate::~MockAckNotifierDelegate() {
387 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {
390 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {
395 string
HexDumpWithMarks(const char* data
, int length
,
396 const bool* marks
, int mark_length
) {
397 static const char kHexChars
[] = "0123456789abcdef";
398 static const int kColumns
= 4;
400 const int kSizeLimit
= 1024;
401 if (length
> kSizeLimit
|| mark_length
> kSizeLimit
) {
402 LOG(ERROR
) << "Only dumping first " << kSizeLimit
<< " bytes.";
403 length
= min(length
, kSizeLimit
);
404 mark_length
= min(mark_length
, kSizeLimit
);
408 for (const char* row
= data
; length
> 0;
409 row
+= kColumns
, length
-= kColumns
) {
410 for (const char *p
= row
; p
< row
+ 4; ++p
) {
411 if (p
< row
+ length
) {
413 (marks
&& (p
- data
) < mark_length
&& marks
[p
- data
]);
414 hex
+= mark
? '*' : ' ';
415 hex
+= kHexChars
[(*p
& 0xf0) >> 4];
416 hex
+= kHexChars
[*p
& 0x0f];
417 hex
+= mark
? '*' : ' ';
424 for (const char *p
= row
; p
< row
+ 4 && p
< row
+ length
; ++p
)
425 hex
+= (*p
>= 0x20 && *p
<= 0x7f) ? (*p
) : '.';
434 IPAddressNumber
TestPeerIPAddress() { return Loopback4(); }
436 QuicVersion
QuicVersionMax() { return QuicSupportedVersions().front(); }
438 QuicVersion
QuicVersionMin() { return QuicSupportedVersions().back(); }
440 IPAddressNumber
Loopback4() {
441 IPAddressNumber addr
;
442 CHECK(ParseIPLiteralToNumber("127.0.0.1", &addr
));
446 IPAddressNumber
Loopback6() {
447 IPAddressNumber addr
;
448 CHECK(ParseIPLiteralToNumber("::1", &addr
));
452 void GenerateBody(string
* body
, int length
) {
454 body
->reserve(length
);
455 for (int i
= 0; i
< length
; ++i
) {
456 body
->append(1, static_cast<char>(32 + i
% (126 - 32)));
460 QuicEncryptedPacket
* ConstructEncryptedPacket(
461 QuicConnectionId connection_id
,
464 QuicPacketSequenceNumber sequence_number
,
465 const string
& data
) {
466 QuicPacketHeader header
;
467 header
.public_header
.connection_id
= connection_id
;
468 header
.public_header
.connection_id_length
= PACKET_8BYTE_CONNECTION_ID
;
469 header
.public_header
.version_flag
= version_flag
;
470 header
.public_header
.reset_flag
= reset_flag
;
471 header
.public_header
.sequence_number_length
= PACKET_6BYTE_SEQUENCE_NUMBER
;
472 header
.packet_sequence_number
= sequence_number
;
473 header
.entropy_flag
= false;
474 header
.entropy_hash
= 0;
475 header
.fec_flag
= false;
476 header
.is_in_fec_group
= NOT_IN_FEC_GROUP
;
477 header
.fec_group
= 0;
478 QuicStreamFrame
stream_frame(1, false, 0, MakeIOVector(data
));
479 QuicFrame
frame(&stream_frame
);
481 frames
.push_back(frame
);
482 QuicFramer
framer(QuicSupportedVersions(), QuicTime::Zero(), false);
483 scoped_ptr
<QuicPacket
> packet(
484 BuildUnsizedDataPacket(&framer
, header
, frames
).packet
);
485 EXPECT_TRUE(packet
!= NULL
);
486 QuicEncryptedPacket
* encrypted
= framer
.EncryptPacket(ENCRYPTION_NONE
,
489 EXPECT_TRUE(encrypted
!= NULL
);
493 void CompareCharArraysWithHexError(
494 const string
& description
,
496 const int actual_len
,
497 const char* expected
,
498 const int expected_len
) {
499 EXPECT_EQ(actual_len
, expected_len
);
500 const int min_len
= min(actual_len
, expected_len
);
501 const int max_len
= max(actual_len
, expected_len
);
502 scoped_ptr
<bool[]> marks(new bool[max_len
]);
503 bool identical
= (actual_len
== expected_len
);
504 for (int i
= 0; i
< min_len
; ++i
) {
505 if (actual
[i
] != expected
[i
]) {
512 for (int i
= min_len
; i
< max_len
; ++i
) {
515 if (identical
) return;
520 << HexDumpWithMarks(expected
, expected_len
, marks
.get(), max_len
)
522 << HexDumpWithMarks(actual
, actual_len
, marks
.get(), max_len
);
525 bool DecodeHexString(const base::StringPiece
& hex
, std::string
* bytes
) {
529 std::vector
<uint8
> v
;
530 if (!base::HexStringToBytes(hex
.as_string(), &v
))
533 bytes
->assign(reinterpret_cast<const char*>(&v
[0]), v
.size());
537 static QuicPacket
* ConstructPacketFromHandshakeMessage(
538 QuicConnectionId connection_id
,
539 const CryptoHandshakeMessage
& message
,
540 bool should_include_version
) {
541 CryptoFramer crypto_framer
;
542 scoped_ptr
<QuicData
> data(crypto_framer
.ConstructHandshakeMessage(message
));
543 QuicFramer
quic_framer(QuicSupportedVersions(), QuicTime::Zero(), false);
545 QuicPacketHeader header
;
546 header
.public_header
.connection_id
= connection_id
;
547 header
.public_header
.reset_flag
= false;
548 header
.public_header
.version_flag
= should_include_version
;
549 header
.packet_sequence_number
= 1;
550 header
.entropy_flag
= false;
551 header
.entropy_hash
= 0;
552 header
.fec_flag
= false;
553 header
.fec_group
= 0;
555 QuicStreamFrame
stream_frame(kCryptoStreamId
, false, 0,
556 MakeIOVector(data
->AsStringPiece()));
558 QuicFrame
frame(&stream_frame
);
560 frames
.push_back(frame
);
561 return BuildUnsizedDataPacket(&quic_framer
, header
, frames
).packet
;
564 QuicPacket
* ConstructHandshakePacket(QuicConnectionId connection_id
,
566 CryptoHandshakeMessage message
;
567 message
.set_tag(tag
);
568 return ConstructPacketFromHandshakeMessage(connection_id
, message
, false);
571 size_t GetPacketLengthForOneStream(
573 bool include_version
,
574 QuicSequenceNumberLength sequence_number_length
,
575 InFecGroup is_in_fec_group
,
576 size_t* payload_length
) {
578 const size_t stream_length
=
579 NullEncrypter().GetCiphertextSize(*payload_length
) +
580 QuicPacketCreator::StreamFramePacketOverhead(
581 PACKET_8BYTE_CONNECTION_ID
, include_version
,
582 sequence_number_length
, 0u, is_in_fec_group
);
583 const size_t ack_length
= NullEncrypter().GetCiphertextSize(
584 QuicFramer::GetMinAckFrameSize(
585 sequence_number_length
, PACKET_1BYTE_SEQUENCE_NUMBER
)) +
586 GetPacketHeaderSize(PACKET_8BYTE_CONNECTION_ID
, include_version
,
587 sequence_number_length
, is_in_fec_group
);
588 if (stream_length
< ack_length
) {
589 *payload_length
= 1 + ack_length
- stream_length
;
592 return NullEncrypter().GetCiphertextSize(*payload_length
) +
593 QuicPacketCreator::StreamFramePacketOverhead(
594 PACKET_8BYTE_CONNECTION_ID
, include_version
,
595 sequence_number_length
, 0u, is_in_fec_group
);
598 TestEntropyCalculator::TestEntropyCalculator() {}
600 TestEntropyCalculator::~TestEntropyCalculator() {}
602 QuicPacketEntropyHash
TestEntropyCalculator::EntropyHash(
603 QuicPacketSequenceNumber sequence_number
) const {
607 MockEntropyCalculator::MockEntropyCalculator() {}
609 MockEntropyCalculator::~MockEntropyCalculator() {}
611 QuicConfig
DefaultQuicConfig() {
613 config
.SetDefaults();
614 config
.SetInitialFlowControlWindowToSend(
615 kInitialSessionFlowControlWindowForTest
);
616 config
.SetInitialStreamFlowControlWindowToSend(
617 kInitialStreamFlowControlWindowForTest
);
618 config
.SetInitialSessionFlowControlWindowToSend(
619 kInitialSessionFlowControlWindowForTest
);
623 QuicVersionVector
SupportedVersions(QuicVersion version
) {
624 QuicVersionVector versions
;
625 versions
.push_back(version
);
629 TestWriterFactory::TestWriterFactory() : current_writer_(NULL
) {}
630 TestWriterFactory::~TestWriterFactory() {}
632 QuicPacketWriter
* TestWriterFactory::Create(QuicServerPacketWriter
* writer
,
633 QuicConnection
* connection
) {
634 return new PerConnectionPacketWriter(this, writer
, connection
);
637 void TestWriterFactory::OnPacketSent(WriteResult result
) {
638 if (current_writer_
!= NULL
&& result
.status
== WRITE_STATUS_ERROR
) {
639 current_writer_
->connection()->OnWriteError(result
.error_code
);
640 current_writer_
= NULL
;
644 void TestWriterFactory::Unregister(PerConnectionPacketWriter
* writer
) {
645 if (current_writer_
== writer
) {
646 current_writer_
= NULL
;
650 TestWriterFactory::PerConnectionPacketWriter::PerConnectionPacketWriter(
651 TestWriterFactory
* factory
,
652 QuicServerPacketWriter
* writer
,
653 QuicConnection
* connection
)
654 : QuicPerConnectionPacketWriter(writer
, connection
),
658 TestWriterFactory::PerConnectionPacketWriter::~PerConnectionPacketWriter() {
659 factory_
->Unregister(this);
662 WriteResult
TestWriterFactory::PerConnectionPacketWriter::WritePacket(
665 const IPAddressNumber
& self_address
,
666 const IPEndPoint
& peer_address
) {
667 // A DCHECK(factory_current_writer_ == NULL) would be wrong here -- this class
668 // may be used in a setting where connection()->OnPacketSent() is called in a
669 // different way, so TestWriterFactory::OnPacketSent might never be called.
670 factory_
->current_writer_
= this;
671 return QuicPerConnectionPacketWriter::WritePacket(buffer
,