1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/quic/test_tools/quic_test_utils.h"
7 #include "base/stl_util.h"
8 #include "base/strings/string_number_conversions.h"
9 #include "net/quic/crypto/crypto_framer.h"
10 #include "net/quic/crypto/crypto_handshake.h"
11 #include "net/quic/crypto/crypto_utils.h"
12 #include "net/quic/crypto/null_encrypter.h"
13 #include "net/quic/crypto/quic_decrypter.h"
14 #include "net/quic/crypto/quic_encrypter.h"
15 #include "net/quic/quic_framer.h"
16 #include "net/quic/quic_packet_creator.h"
17 #include "net/quic/quic_utils.h"
18 #include "net/quic/test_tools/quic_connection_peer.h"
19 #include "net/spdy/spdy_frame_builder.h"
21 using base::StringPiece;
26 using testing::AnyNumber;
// QuicAlarm subclass whose SetImpl() and CancelImpl() overrides have empty
// bodies. MockHelper::CreateAlarm() returns instances of this class.
32 // No-op alarm implementation used by MockHelper.
33 class TestAlarm : public QuicAlarm {
// Forwards |delegate| to the QuicAlarm base-class constructor.
35 explicit TestAlarm(QuicAlarm::Delegate* delegate)
36 : QuicAlarm(delegate) {
// Both hooks are overridden to do nothing.
39 virtual void SetImpl() OVERRIDE {}
40 virtual void CancelImpl() OVERRIDE {}
// Test helper that fills in a QuicAckFrame: |largest_observed| is stored in
// received_info, |least_unacked| in sent_info, and both entropy hashes are
// set to 0.
45 QuicAckFrame MakeAckFrame(QuicPacketSequenceNumber largest_observed,
46 QuicPacketSequenceNumber least_unacked) {
48 ack.received_info.largest_observed = largest_observed;
49 ack.received_info.entropy_hash = 0;
50 ack.sent_info.least_unacked = least_unacked;
51 ack.sent_info.entropy_hash = 0;
// Installs gmock default actions (ON_CALL ... WillByDefault) for the visitor
// callbacks: OnProtocolVersionMismatch returns false by default, and every
// other callback configured below returns true by default.
55 MockFramerVisitor::MockFramerVisitor() {
56 // By default, we want to accept packets.
57 ON_CALL(*this, OnProtocolVersionMismatch(_))
58 .WillByDefault(testing::Return(false));
60 // By default, we want to accept packets.
61 ON_CALL(*this, OnUnauthenticatedHeader(_))
62 .WillByDefault(testing::Return(true));
64 ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
65 .WillByDefault(testing::Return(true));
67 ON_CALL(*this, OnPacketHeader(_))
68 .WillByDefault(testing::Return(true));
// Per-frame callbacks: each one defaults to returning true.
70 ON_CALL(*this, OnStreamFrame(_))
71 .WillByDefault(testing::Return(true));
73 ON_CALL(*this, OnAckFrame(_))
74 .WillByDefault(testing::Return(true));
76 ON_CALL(*this, OnCongestionFeedbackFrame(_))
77 .WillByDefault(testing::Return(true));
79 ON_CALL(*this, OnStopWaitingFrame(_))
80 .WillByDefault(testing::Return(true));
82 ON_CALL(*this, OnPingFrame(_))
83 .WillByDefault(testing::Return(true));
85 ON_CALL(*this, OnRstStreamFrame(_))
86 .WillByDefault(testing::Return(true));
88 ON_CALL(*this, OnConnectionCloseFrame(_))
89 .WillByDefault(testing::Return(true));
91 ON_CALL(*this, OnGoAwayFrame(_))
92 .WillByDefault(testing::Return(true));
// Out-of-line destructor definition for MockFramerVisitor.
95 MockFramerVisitor::~MockFramerVisitor() {
// Out-of-line definitions of the NoOpFramerVisitor callbacks. Every callback
// returns bool; only the signatures are visible in this excerpt, so the
// returned values are not documented here.

// Callback taking the QuicVersion of a version-mismatched packet.
98 bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) {
// Header callbacks: public header, unauthenticated header, packet header.
102 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
103 const QuicPacketPublicHeader& header) {
107 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
108 const QuicPacketHeader& header) {
112 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) {
// Per-frame callbacks, one per frame type, each taking the frame by const
// reference.
116 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) {
120 bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) {
124 bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
125 const QuicCongestionFeedbackFrame& frame) {
129 bool NoOpFramerVisitor::OnStopWaitingFrame(
130 const QuicStopWaitingFrame& frame) {
134 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& frame) {
138 bool NoOpFramerVisitor::OnRstStreamFrame(
139 const QuicRstStreamFrame& frame) {
143 bool NoOpFramerVisitor::OnConnectionCloseFrame(
144 const QuicConnectionCloseFrame& frame) {
148 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
152 bool NoOpFramerVisitor::OnWindowUpdateFrame(
153 const QuicWindowUpdateFrame& frame) {
157 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& frame) {
// Out-of-line constructor and destructor definitions for
// MockConnectionVisitor.
161 MockConnectionVisitor::MockConnectionVisitor() {
164 MockConnectionVisitor::~MockConnectionVisitor() {
// Out-of-line constructor and destructor definitions for MockHelper.
167 MockHelper::MockHelper() {
170 MockHelper::~MockHelper() {
// Const accessor returning a pointer to a QuicClock; the return statement is
// not visible in this excerpt.
173 const QuicClock* MockHelper::GetClock() const {
// Returns the address of the helper's random_generator_ member.
177 QuicRandom* MockHelper::GetRandomGenerator() {
178 return &random_generator_;
// Allocates a new TestAlarm for |delegate| and returns it as a QuicAlarm*.
// NOTE(review): the alarm is created with raw new; presumably the caller
// takes ownership -- confirm against the QuicConnectionHelperInterface
// contract.
181 QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) {
182 return new TestAlarm(delegate);
// Moves the helper's clock_ forward by |delta|.
185 void MockHelper::AdvanceTime(QuicTime::Delta delta) {
186 clock_.AdvanceTime(delta);
// Builds a QuicConnection with kTestConnectionId, the peer endpoint
// TestPeerIPAddress():kTestPort, newly allocated NiceMock helper and packet
// writer, every version from QuicSupportedVersions(), and
// kInitialFlowControlWindowForTest. Only |is_server| is caller-supplied.
189 MockConnection::MockConnection(bool is_server)
190 : QuicConnection(kTestConnectionId,
191 IPEndPoint(TestPeerIPAddress(), kTestPort),
192 new testing::NiceMock<MockHelper>(),
193 new testing::NiceMock<MockPacketWriter>(),
194 is_server, QuicSupportedVersions(),
195 kInitialFlowControlWindowForTest),
// Caches the writer obtained from the base connection via QuicConnectionPeer.
196 writer_(QuicConnectionPeer::GetWriter(this)),
// Same as the single-argument constructor, except that the peer endpoint is
// the caller-supplied |address| instead of the default test endpoint.
200 MockConnection::MockConnection(IPEndPoint address,
202 : QuicConnection(kTestConnectionId, address,
203 new testing::NiceMock<MockHelper>(),
204 new testing::NiceMock<MockPacketWriter>(),
205 is_server, QuicSupportedVersions(),
206 kInitialFlowControlWindowForTest),
// Caches the writer obtained from the base connection via QuicConnectionPeer.
207 writer_(QuicConnectionPeer::GetWriter(this)),
// Same as the single-argument constructor, except that the connection uses
// the caller-supplied |connection_id| instead of kTestConnectionId.
211 MockConnection::MockConnection(QuicConnectionId connection_id,
213 : QuicConnection(connection_id,
214 IPEndPoint(TestPeerIPAddress(), kTestPort),
215 new testing::NiceMock<MockHelper>(),
216 new testing::NiceMock<MockPacketWriter>(),
217 is_server, QuicSupportedVersions(),
218 kInitialFlowControlWindowForTest),
// Caches the writer obtained from the base connection via QuicConnectionPeer.
219 writer_(QuicConnectionPeer::GetWriter(this)),
// Same as the single-argument constructor, except that the connection is
// restricted to the caller-supplied |supported_versions| instead of
// QuicSupportedVersions().
223 MockConnection::MockConnection(bool is_server,
224 const QuicVersionVector& supported_versions)
225 : QuicConnection(kTestConnectionId,
226 IPEndPoint(TestPeerIPAddress(), kTestPort),
227 new testing::NiceMock<MockHelper>(),
228 new testing::NiceMock<MockPacketWriter>(),
229 is_server, supported_versions,
230 kInitialFlowControlWindowForTest),
// Caches the writer obtained from the base connection via QuicConnectionPeer.
231 writer_(QuicConnectionPeer::GetWriter(this)),
// Out-of-line destructor definition for MockConnection.
235 MockConnection::~MockConnection() {
// Advances the connection's notion of time by |delta| by forwarding to the
// helper's AdvanceTime(). The static_cast assumes helper() is a MockHelper;
// the constructors visible in this file all install a NiceMock<MockHelper>.
238 void MockConnection::AdvanceTime(QuicTime::Delta delta) {
239 static_cast<MockHelper*>(helper())->AdvanceTime(delta);
// Forwards |is_server| to the MockConnection(bool) constructor.
242 PacketSavingConnection::PacketSavingConnection(bool is_server)
243 : MockConnection(is_server) {
// Forwards |is_server| and |supported_versions| to the matching
// MockConnection constructor.
246 PacketSavingConnection::PacketSavingConnection(
248 const QuicVersionVector& supported_versions)
249 : MockConnection(is_server, supported_versions) {
// Releases the packets saved by SendOrQueuePacket(): both containers hold
// raw pointers and are cleaned up with STLDeleteElements().
252 PacketSavingConnection::~PacketSavingConnection() {
253 STLDeleteElements(&packets_);
254 STLDeleteElements(&encrypted_packets_);
// Saves the packet in two forms: the packet pointer itself and an encrypted
// copy. |transmission_type| is not used by the visible lines.
257 bool PacketSavingConnection::SendOrQueuePacket(
258 EncryptionLevel level,
259 const SerializedPacket& packet,
260 TransmissionType transmission_type) {
// Keep the unencrypted packet pointer.
261 packets_.push_back(packet.packet);
// Encrypt with the connection's own framer at |level| and keep the result.
262 QuicEncryptedPacket* encrypted = QuicConnectionPeer::GetFramer(this)->
263 EncryptPacket(level, packet.sequence_number, *packet.packet);
264 encrypted_packets_.push_back(encrypted);
// Builds a QuicSession over |connection| using DefaultQuicConfig(), and makes
// WritevData() return QuicConsumedData(0, false) unless a test overrides it.
268 MockSession::MockSession(QuicConnection* connection)
269 : QuicSession(connection, DefaultQuicConfig()) {
270 ON_CALL(*this, WritevData(_, _, _, _, _))
271 .WillByDefault(testing::Return(QuicConsumedData(0, false)));
// Out-of-line destructor definition for MockSession.
274 MockSession::~MockSession() {
// Builds a QuicSession over |connection| with |config|. crypto_stream_ starts
// out NULL until SetCryptoStream() is called.
277 TestSession::TestSession(QuicConnection* connection,
278 const QuicConfig& config)
279 : QuicSession(connection, config),
280 crypto_stream_(NULL) {
283 TestSession::~TestSession() {}
// Stores |stream| as the pointer later returned by GetCryptoStream().
285 void TestSession::SetCryptoStream(QuicCryptoStream* stream) {
286 crypto_stream_ = stream;
// Returns the pointer set by SetCryptoStream(), or NULL if it was never set
// (the constructor initializes crypto_stream_ to NULL).
289 QuicCryptoStream* TestSession::GetCryptoStream() {
290 return crypto_stream_;
// Builds a QuicClientSessionBase over |connection| with |config|.
// crypto_stream_ starts out NULL.
293 TestClientSession::TestClientSession(QuicConnection* connection,
294 const QuicConfig& config)
295 : QuicClientSessionBase(connection, config),
296 crypto_stream_(NULL) {
// Allow OnProofValid() to be invoked any number of times (including zero).
297 EXPECT_CALL(*this, OnProofValid(_)).Times(AnyNumber());
300 TestClientSession::~TestClientSession() {}
// Stores |stream| as the pointer later returned by GetCryptoStream().
302 void TestClientSession::SetCryptoStream(QuicCryptoStream* stream) {
303 crypto_stream_ = stream;
// Returns the pointer set by SetCryptoStream(), or NULL if it was never set
// (the constructor initializes crypto_stream_ to NULL).
306 QuicCryptoStream* TestClientSession::GetCryptoStream() {
307 return crypto_stream_;
// Out-of-line constructor/destructor definitions for the remaining mock
// classes. No body statements are visible in this excerpt.
310 MockPacketWriter::MockPacketWriter() {
313 MockPacketWriter::~MockPacketWriter() {
316 MockSendAlgorithm::MockSendAlgorithm() {
319 MockSendAlgorithm::~MockSendAlgorithm() {
322 MockLossAlgorithm::MockLossAlgorithm() {
325 MockLossAlgorithm::~MockLossAlgorithm() {
328 MockAckNotifierDelegate::MockAckNotifierDelegate() {
331 MockAckNotifierDelegate::~MockAckNotifierDelegate() {
336 string HexDumpWithMarks(const char* data, int length,
337 const bool* marks, int mark_length) {
338 static const char kHexChars[] = "0123456789abcdef";
339 static const int kColumns = 4;
341 const int kSizeLimit = 1024;
342 if (length > kSizeLimit || mark_length > kSizeLimit) {
343 LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes.";
344 length = min(length, kSizeLimit);
345 mark_length = min(mark_length, kSizeLimit);
349 for (const char* row = data; length > 0;
350 row += kColumns, length -= kColumns) {
351 for (const char *p = row; p < row + 4; ++p) {
352 if (p < row + length) {
354 (marks && (p - data) < mark_length && marks[p - data]);
355 hex += mark ? '*' : ' ';
356 hex += kHexChars[(*p & 0xf0) >> 4];
357 hex += kHexChars[*p & 0x0f];
358 hex += mark ? '*' : ' ';
365 for (const char *p = row; p < row + 4 && p < row + length; ++p)
366 hex += (*p >= 0x20 && *p <= 0x7f) ? (*p) : '.';
375 IPAddressNumber TestPeerIPAddress() { return Loopback4(); }
377 QuicVersion QuicVersionMax() { return QuicSupportedVersions().front(); }
379 QuicVersion QuicVersionMin() { return QuicSupportedVersions().back(); }
// Parses the literal "127.0.0.1" into an IPAddressNumber. The parse is
// wrapped in CHECK, so a parse failure aborts rather than returning an
// unset address.
381 IPAddressNumber Loopback4() {
382 IPAddressNumber addr;
383 CHECK(ParseIPLiteralToNumber("127.0.0.1", &addr));
// Fills |body| with |length| generated characters. Character i is
// 32 + i % (126 - 32), so the output cycles through the 94 character codes
// 32 (' ') through 125 ('}').
387 void GenerateBody(string* body, int length) {
// Reserve up front so the appends below do not reallocate.
389 body->reserve(length);
390 for (int i = 0; i < length; ++i) {
391 body->append(1, static_cast<char>(32 + i % (126 - 32)));
// Builds a packet that carries |data| in a single stream frame and encrypts
// it at ENCRYPTION_NONE. The header uses an 8-byte connection id, a 6-byte
// sequence number, no entropy and no FEC. Failures to build or encrypt are
// reported through EXPECT_TRUE rather than aborting.
395 QuicEncryptedPacket* ConstructEncryptedPacket(
396 QuicConnectionId connection_id,
399 QuicPacketSequenceNumber sequence_number,
400 const string& data) {
// Public header fields.
401 QuicPacketHeader header;
402 header.public_header.connection_id = connection_id;
403 header.public_header.connection_id_length = PACKET_8BYTE_CONNECTION_ID;
404 header.public_header.version_flag = version_flag;
405 header.public_header.reset_flag = reset_flag;
406 header.public_header.sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER;
// Private header fields: no entropy, not part of any FEC group.
407 header.packet_sequence_number = sequence_number;
408 header.entropy_flag = false;
409 header.entropy_hash = 0;
410 header.fec_flag = false;
411 header.is_in_fec_group = NOT_IN_FEC_GROUP;
412 header.fec_group = 0;
// One stream frame constructed with arguments (1, false, 0, data).
413 QuicStreamFrame stream_frame(1, false, 0, MakeIOVector(data));
414 QuicFrame frame(&stream_frame);
416 frames.push_back(frame);
// Serialize with a framer built from QuicSupportedVersions(); the unencrypted
// packet is held in a scoped_ptr.
417 QuicFramer framer(QuicSupportedVersions(), QuicTime::Zero(), false);
418 scoped_ptr<QuicPacket> packet(
419 framer.BuildUnsizedDataPacket(header, frames).packet);
420 EXPECT_TRUE(packet != NULL);
421 QuicEncryptedPacket* encrypted = framer.EncryptPacket(ENCRYPTION_NONE,
424 EXPECT_TRUE(encrypted != NULL);
// Compares |actual| against |expected| byte by byte. A length mismatch is
// reported via EXPECT_EQ. When the buffers differ, both are printed with
// HexDumpWithMarks() using a shared array of per-byte marks sized to the
// longer buffer.
428 void CompareCharArraysWithHexError(
429 const string& description,
431 const int actual_len,
432 const char* expected,
433 const int expected_len) {
434 EXPECT_EQ(actual_len, expected_len);
// min_len bounds the byte-wise comparison; max_len sizes the marks array so
// it covers whichever buffer is longer.
435 const int min_len = min(actual_len, expected_len);
436 const int max_len = max(actual_len, expected_len);
437 scoped_ptr<bool[]> marks(new bool[max_len]);
// Start from "identical" only if the lengths already agree.
438 bool identical = (actual_len == expected_len);
// Compare the overlapping prefix.
439 for (int i = 0; i < min_len; ++i) {
440 if (actual[i] != expected[i]) {
// Handle the tail that exists in only one of the buffers.
447 for (int i = min_len; i < max_len; ++i) {
// Nothing to report when no difference was found.
450 if (identical) return;
455 << HexDumpWithMarks(expected, expected_len, marks.get(), max_len)
457 << HexDumpWithMarks(actual, actual_len, marks.get(), max_len);
// Decodes the hex text in |hex| into raw bytes stored in |*bytes|, using
// base::HexStringToBytes() for the conversion.
460 bool DecodeHexString(const base::StringPiece& hex, std::string* bytes) {
464 std::vector<uint8> v;
465 if (!base::HexStringToBytes(hex.as_string(), &v))
// NOTE(review): &v[0] is only valid for a non-empty vector; any guard for the
// empty case is not visible in this excerpt -- confirm one exists.
468 bytes->assign(reinterpret_cast<const char*>(&v[0]), v.size());
// File-local helper: serializes |message| with a CryptoFramer and wraps the
// bytes in a single stream frame on kCryptoStreamId inside a packet with
// sequence number 1. |should_include_version| controls the version flag of
// the public header. Returns the packet built by
// QuicFramer::BuildUnsizedDataPacket().
472 static QuicPacket* ConstructPacketFromHandshakeMessage(
473 QuicConnectionId connection_id,
474 const CryptoHandshakeMessage& message,
475 bool should_include_version) {
// Serialized handshake message; the scoped_ptr holds it while the packet is
// built.
476 CryptoFramer crypto_framer;
477 scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message));
478 QuicFramer quic_framer(QuicSupportedVersions(), QuicTime::Zero(), false);
// Header: no reset, no entropy, no FEC.
480 QuicPacketHeader header;
481 header.public_header.connection_id = connection_id;
482 header.public_header.reset_flag = false;
483 header.public_header.version_flag = should_include_version;
484 header.packet_sequence_number = 1;
485 header.entropy_flag = false;
486 header.entropy_hash = 0;
487 header.fec_flag = false;
488 header.fec_group = 0;
// Stream frame on the crypto stream, built with arguments (false, 0) and the
// serialized message as payload.
490 QuicStreamFrame stream_frame(kCryptoStreamId, false, 0,
491 MakeIOVector(data->AsStringPiece()));
493 QuicFrame frame(&stream_frame);
495 frames.push_back(frame);
496 return quic_framer.BuildUnsizedDataPacket(header, frames).packet;
// Builds a handshake packet for |connection_id| whose message carries only
// |tag|. The version flag is always passed as false.
499 QuicPacket* ConstructHandshakePacket(QuicConnectionId connection_id,
501 CryptoHandshakeMessage message;
502 message.set_tag(tag);
503 return ConstructPacketFromHandshakeMessage(connection_id, message, false);
// Returns the encrypted length (NullEncrypter ciphertext size plus stream
// frame packet overhead) of a packet carrying one stream frame with
// |*payload_length| bytes. |*payload_length| is an in/out value: if the
// stream packet would be shorter than a minimal ack packet, it is increased
// before the final length is computed.
506 size_t GetPacketLengthForOneStream(
508 bool include_version,
509 QuicSequenceNumberLength sequence_number_length,
510 InFecGroup is_in_fec_group,
511 size_t* payload_length) {
// Length of the stream packet with the current payload size.
513 const size_t stream_length =
514 NullEncrypter().GetCiphertextSize(*payload_length) +
515 QuicPacketCreator::StreamFramePacketOverhead(
516 version, PACKET_8BYTE_CONNECTION_ID, include_version,
517 sequence_number_length, is_in_fec_group);
// Length of a packet holding only a minimum-size ack frame.
518 const size_t ack_length = NullEncrypter().GetCiphertextSize(
519 QuicFramer::GetMinAckFrameSize(
520 version, sequence_number_length, PACKET_1BYTE_SEQUENCE_NUMBER)) +
521 GetPacketHeaderSize(PACKET_8BYTE_CONNECTION_ID, include_version,
522 sequence_number_length, is_in_fec_group);
// Grow the payload by the shortfall so the stream packet is at least as long
// as the ack packet.
523 if (stream_length < ack_length) {
524 *payload_length = 1 + ack_length - stream_length;
// Recompute with the (possibly adjusted) payload size.
527 return NullEncrypter().GetCiphertextSize(*payload_length) +
528 QuicPacketCreator::StreamFramePacketOverhead(
529 version, PACKET_8BYTE_CONNECTION_ID, include_version,
530 sequence_number_length, is_in_fec_group);
533 TestEntropyCalculator::TestEntropyCalculator() { }
535 TestEntropyCalculator::~TestEntropyCalculator() { }
// Const member returning a QuicPacketEntropyHash for |sequence_number|; the
// return statement is not visible in this excerpt.
537 QuicPacketEntropyHash TestEntropyCalculator::EntropyHash(
538 QuicPacketSequenceNumber sequence_number) const {
542 MockEntropyCalculator::MockEntropyCalculator() { }
544 MockEntropyCalculator::~MockEntropyCalculator() { }
// Returns a QuicConfig on which SetDefaults() has been called. MockSession
// uses this as its session configuration.
546 QuicConfig DefaultQuicConfig() {
548 config.SetDefaults();
552 QuicVersionVector SupportedVersions(QuicVersion version) {
553 QuicVersionVector versions;
554 versions.push_back(version);