1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/socks_client_socket.h"
7 #include "base/memory/scoped_ptr.h"
8 #include "net/base/address_list.h"
9 #include "net/base/test_completion_callback.h"
10 #include "net/base/winsock_init.h"
11 #include "net/dns/host_resolver.h"
12 #include "net/dns/mock_host_resolver.h"
13 #include "net/log/net_log.h"
14 #include "net/log/net_log_unittest.h"
15 #include "net/socket/client_socket_factory.h"
16 #include "net/socket/socket_test_util.h"
17 #include "net/socket/tcp_client_socket.h"
18 #include "testing/gtest/include/gtest/gtest.h"
19 #include "testing/platform_test.h"
21 //-----------------------------------------------------------------------------
// SOCKS4 CONNECT request the tests expect the client to write: version 0x04,
// command 0x01, port 0x0050 (80), IPv4 address 127.0.0.1, then a single 0
// byte (presumably the empty, null-terminated user ID).
25 const char kSOCKSOkRequest
[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 };
// Successful 8-byte SOCKS4 reply returned by the mock server: a leading 0x00
// byte, status 0x5A (SOCKS4 "request granted"), and six zero bytes.
26 const char kSOCKSOkReply
[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
// Test fixture for SOCKSClientSocket. Each test builds a mock TCP socket from
// canned MockRead/MockWrite data with BuildMockSocket() and stores the
// resulting SOCKSClientSocket in |user_sock_|.
28 class SOCKSClientSocketTest
: public PlatformTest
{
30 SOCKSClientSocketTest();
31 // Create a SOCKSClientSocket on top of a MockSocket.
32 scoped_ptr
<SOCKSClientSocket
> BuildMockSocket(
33 MockRead reads
[], size_t reads_count
,
34 MockWrite writes
[], size_t writes_count
,
35 HostResolver
* host_resolver
,
36 const std::string
& hostname
, int port
,
38 void SetUp() override
;
// The SOCKS socket under test, as returned by BuildMockSocket().
41 scoped_ptr
<SOCKSClientSocket
> user_sock_
;
// Address list handed to the MockTCPClientSocket in BuildMockSocket().
42 AddressList address_list_
;
43 // Filled in by BuildMockSocket() and owned by its return value
44 // (which |user_sock_| is set to).
45 StreamSocket
* tcp_sock_
;
// Completion callback the tests share for Connect()/Read()/Write() calls.
46 TestCompletionCallback callback_
;
// Default resolver created by the constructor; tests may add rules to it
// (e.g. simulated failures).
47 scoped_ptr
<MockHostResolver
> host_resolver_
;
// Owns the canned socket data; BuildMockSocket() passes only a raw pointer
// (data_.get()) to the mock TCP socket.
48 scoped_ptr
<SocketDataProvider
> data_
;
// Creates the default MockHostResolver that the tests pass to BuildMockSocket().
51 SOCKSClientSocketTest::SOCKSClientSocketTest()
52 : host_resolver_(new MockHostResolver
) {
55 // Set up platform before every test case by delegating to
// PlatformTest::SetUp().
56 void SOCKSClientSocketTest::SetUp() {
57 PlatformTest::SetUp();
// Builds a SOCKSClientSocket targeting |hostname|:|port| on top of a
// MockTCPClientSocket fed by |reads| / |writes|. The mock TCP socket is
// connected (waiting for the asynchronous result) before being handed to a
// ClientSocketHandle; |tcp_sock_| keeps a non-owning pointer to it.
60 scoped_ptr
<SOCKSClientSocket
> SOCKSClientSocketTest::BuildMockSocket(
65 HostResolver
* host_resolver
,
66 const std::string
& hostname
,
70 TestCompletionCallback callback
;
// |data_| keeps the data provider alive; the mock socket receives only a raw
// pointer to it.
71 data_
.reset(new StaticSocketDataProvider(reads
, reads_count
,
72 writes
, writes_count
));
73 tcp_sock_
= new MockTCPClientSocket(address_list_
, net_log
, data_
.get());
// The mock TCP connect is expected to complete asynchronously.
75 int rv
= tcp_sock_
->Connect(callback
.callback());
76 EXPECT_EQ(ERR_IO_PENDING
, rv
);
77 rv
= callback
.WaitForResult();
79 EXPECT_TRUE(tcp_sock_
->IsConnected());
81 scoped_ptr
<ClientSocketHandle
> connection(new ClientSocketHandle
);
82 // |connection| takes ownership of |tcp_sock_|, but keep a
83 // non-owning pointer to it.
84 connection
->SetSocket(scoped_ptr
<StreamSocket
>(tcp_sock_
));
85 return scoped_ptr
<SOCKSClientSocket
>(new SOCKSClientSocket(
87 HostResolver::RequestInfo(HostPortPair(hostname
, port
)),
92 // Implementation of HostResolver that never completes its resolve request.
93 // We use this in the test "DisconnectWhileHostResolveInProgress" to make
94 // sure that the outstanding resolve request gets cancelled.
95 class HangingHostResolverWithCancel
: public HostResolver
{
97 HangingHostResolverWithCancel() : outstanding_request_(NULL
) {}
// Records a single fake outstanding request and always returns
// ERR_IO_PENDING. |callback| is neither stored nor run, so the resolve never
// completes. Expects no request to be outstanding already.
99 int Resolve(const RequestInfo
& info
,
100 RequestPriority priority
,
101 AddressList
* addresses
,
102 const CompletionCallback
& callback
,
103 RequestHandle
* out_req
,
104 const BoundNetLog
& net_log
) override
{
106 DCHECK_EQ(false, callback
.is_null());
107 EXPECT_FALSE(HasOutstandingRequest());
// Dummy non-null handle; it is only compared, never dereferenced.
108 outstanding_request_
= reinterpret_cast<RequestHandle
>(1);
109 *out_req
= outstanding_request_
;
110 return ERR_IO_PENDING
;
// Cache lookups always fail with ERR_UNEXPECTED.
113 int ResolveFromCache(const RequestInfo
& info
,
114 AddressList
* addresses
,
115 const BoundNetLog
& net_log
) override
{
117 return ERR_UNEXPECTED
;
// Expects |req| to be the handle handed out by Resolve(), then clears it.
120 void CancelRequest(RequestHandle req
) override
{
121 EXPECT_TRUE(HasOutstandingRequest());
122 EXPECT_EQ(outstanding_request_
, req
);
123 outstanding_request_
= NULL
;
// True after Resolve() until CancelRequest() is called.
126 bool HasOutstandingRequest() {
127 return outstanding_request_
!= NULL
;
// The dummy handle (value 1) while a request is outstanding, NULL otherwise.
131 RequestHandle outstanding_request_
;
133 DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel
);
136 // Tests a complete handshake, a payload write and read through the SOCKS
// socket, and the disconnection.
137 TEST_F(SOCKSClientSocketTest
, CompleteHandshake
) {
138 const std::string payload_write
= "random data";
139 const std::string payload_read
= "moar random data";
141 MockWrite data_writes
[] = {
142 MockWrite(ASYNC
, kSOCKSOkRequest
, arraysize(kSOCKSOkRequest
)),
143 MockWrite(ASYNC
, payload_write
.data(), payload_write
.size()) };
144 MockRead data_reads
[] = {
145 MockRead(ASYNC
, kSOCKSOkReply
, arraysize(kSOCKSOkReply
)),
146 MockRead(ASYNC
, payload_read
.data(), payload_read
.size()) };
// Build the SOCKS socket over a mock TCP connection fed with the data above.
149 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
150 data_writes
, arraysize(data_writes
),
151 host_resolver_
.get(),
155 // At this state the TCP connection is completed but not the SOCKS handshake.
156 EXPECT_TRUE(tcp_sock_
->IsConnected());
157 EXPECT_FALSE(user_sock_
->IsConnected());
159 int rv
= user_sock_
->Connect(callback_
.callback());
160 EXPECT_EQ(ERR_IO_PENDING
, rv
);
162 TestNetLog::CapturedEntryList entries
;
163 log
.GetEntries(&entries
);
165 LogContainsBeginEvent(entries
, 0, NetLog::TYPE_SOCKS_CONNECT
));
166 EXPECT_FALSE(user_sock_
->IsConnected());
168 rv
= callback_
.WaitForResult();
170 EXPECT_TRUE(user_sock_
->IsConnected());
171 log
.GetEntries(&entries
);
172 EXPECT_TRUE(LogContainsEndEvent(
173 entries
, -1, NetLog::TYPE_SOCKS_CONNECT
));
// Send the payload through the now-connected SOCKS socket.
175 scoped_refptr
<IOBuffer
> buffer(new IOBuffer(payload_write
.size()));
176 memcpy(buffer
->data(), payload_write
.data(), payload_write
.size());
177 rv
= user_sock_
->Write(
178 buffer
.get(), payload_write
.size(), callback_
.callback());
179 EXPECT_EQ(ERR_IO_PENDING
, rv
);
180 rv
= callback_
.WaitForResult();
181 EXPECT_EQ(static_cast<int>(payload_write
.size()), rv
);
// Read back the payload the mock server sends after the handshake reply.
183 buffer
= new IOBuffer(payload_read
.size());
185 user_sock_
->Read(buffer
.get(), payload_read
.size(), callback_
.callback());
186 EXPECT_EQ(ERR_IO_PENDING
, rv
);
187 rv
= callback_
.WaitForResult();
188 EXPECT_EQ(static_cast<int>(payload_read
.size()), rv
);
189 EXPECT_EQ(payload_read
, std::string(buffer
->data(), payload_read
.size()));
// Disconnecting the SOCKS socket must also disconnect the TCP socket below it.
191 user_sock_
->Disconnect();
192 EXPECT_FALSE(tcp_sock_
->IsConnected());
193 EXPECT_FALSE(user_sock_
->IsConnected());
196 // Tests that malformed SOCKS server replies make Connect() fail with the
197 // expected error codes.
198 TEST_F(SOCKSClientSocketTest
, HandshakeFailures
) {
200 const char fail_reply
[8];
203 // Failure of the null byte: first reply byte is 0x01 instead of 0x00.
205 { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
206 ERR_SOCKS_CONNECTION_FAILED
,
208 // Failure of the server response code: 0x5B instead of 0x5A.
210 { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
211 ERR_SOCKS_CONNECTION_FAILED
,
215 //---------------------------------------
// For each case the correct request is written synchronously and the bad
// reply is read synchronously.
217 for (size_t i
= 0; i
< arraysize(tests
); ++i
) {
218 MockWrite data_writes
[] = {
219 MockWrite(SYNCHRONOUS
, kSOCKSOkRequest
, arraysize(kSOCKSOkRequest
)) };
220 MockRead data_reads
[] = {
221 MockRead(SYNCHRONOUS
, tests
[i
].fail_reply
,
222 arraysize(tests
[i
].fail_reply
)) };
225 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
226 data_writes
, arraysize(data_writes
),
227 host_resolver_
.get(),
231 int rv
= user_sock_
->Connect(callback_
.callback());
232 EXPECT_EQ(ERR_IO_PENDING
, rv
);
234 TestNetLog::CapturedEntryList entries
;
235 log
.GetEntries(&entries
);
236 EXPECT_TRUE(LogContainsBeginEvent(
237 entries
, 0, NetLog::TYPE_SOCKS_CONNECT
));
// The SOCKS handshake fails, but the underlying TCP socket stays connected.
239 rv
= callback_
.WaitForResult();
240 EXPECT_EQ(tests
[i
].fail_code
, rv
);
241 EXPECT_FALSE(user_sock_
->IsConnected());
242 EXPECT_TRUE(tcp_sock_
->IsConnected());
243 log
.GetEntries(&entries
);
244 EXPECT_TRUE(LogContainsEndEvent(
245 entries
, -1, NetLog::TYPE_SOCKS_CONNECT
));
249 // Tests scenario when the server sends the handshake response in
250 // more than one packet.
251 TEST_F(SOCKSClientSocketTest
, PartialServerReads
) {
// kSOCKSOkReply split into a 1-byte read followed by a 7-byte read.
252 const char kSOCKSPartialReply1
[] = { 0x00 };
253 const char kSOCKSPartialReply2
[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
255 MockWrite data_writes
[] = {
256 MockWrite(ASYNC
, kSOCKSOkRequest
, arraysize(kSOCKSOkRequest
)) };
257 MockRead data_reads
[] = {
258 MockRead(ASYNC
, kSOCKSPartialReply1
, arraysize(kSOCKSPartialReply1
)),
259 MockRead(ASYNC
, kSOCKSPartialReply2
, arraysize(kSOCKSPartialReply2
)) };
262 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
263 data_writes
, arraysize(data_writes
),
264 host_resolver_
.get(),
268 int rv
= user_sock_
->Connect(callback_
.callback());
269 EXPECT_EQ(ERR_IO_PENDING
, rv
);
270 TestNetLog::CapturedEntryList entries
;
271 log
.GetEntries(&entries
);
272 EXPECT_TRUE(LogContainsBeginEvent(
273 entries
, 0, NetLog::TYPE_SOCKS_CONNECT
));
// The split reply must still complete the handshake.
275 rv
= callback_
.WaitForResult();
277 EXPECT_TRUE(user_sock_
->IsConnected());
278 log
.GetEntries(&entries
);
279 EXPECT_TRUE(LogContainsEndEvent(
280 entries
, -1, NetLog::TYPE_SOCKS_CONNECT
));
283 // Tests scenario when the client sends the handshake request in
284 // more than one packet.
285 TEST_F(SOCKSClientSocketTest
, PartialClientWrites
) {
// kSOCKSOkRequest split into a 2-byte write followed by a 7-byte write.
286 const char kSOCKSPartialRequest1
[] = { 0x04, 0x01 };
287 const char kSOCKSPartialRequest2
[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
289 MockWrite data_writes
[] = {
290 MockWrite(ASYNC
, kSOCKSPartialRequest1
, arraysize(kSOCKSPartialRequest1
)),
291 // simulate some empty writes
294 MockWrite(ASYNC
, kSOCKSPartialRequest2
, arraysize(kSOCKSPartialRequest2
)),
296 MockRead data_reads
[] = {
297 MockRead(ASYNC
, kSOCKSOkReply
, arraysize(kSOCKSOkReply
)) };
300 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
301 data_writes
, arraysize(data_writes
),
302 host_resolver_
.get(),
306 int rv
= user_sock_
->Connect(callback_
.callback());
307 EXPECT_EQ(ERR_IO_PENDING
, rv
);
308 TestNetLog::CapturedEntryList entries
;
309 log
.GetEntries(&entries
);
310 EXPECT_TRUE(LogContainsBeginEvent(
311 entries
, 0, NetLog::TYPE_SOCKS_CONNECT
));
// The split request must still complete the handshake.
313 rv
= callback_
.WaitForResult();
315 EXPECT_TRUE(user_sock_
->IsConnected());
316 log
.GetEntries(&entries
);
317 EXPECT_TRUE(LogContainsEndEvent(
318 entries
, -1, NetLog::TYPE_SOCKS_CONNECT
));
321 // Tests the case when the server sends a smaller sized handshake data
322 // and closes the connection.
323 TEST_F(SOCKSClientSocketTest
, FailedSocketRead
) {
// Only the first 6 of the 8 reply bytes arrive before the connection closes.
324 MockWrite data_writes
[] = {
325 MockWrite(ASYNC
, kSOCKSOkRequest
, arraysize(kSOCKSOkRequest
)) };
326 MockRead data_reads
[] = {
327 MockRead(ASYNC
, kSOCKSOkReply
, arraysize(kSOCKSOkReply
) - 2),
328 // Close the connection unexpectedly (synchronous read result of 0).
329 MockRead(SYNCHRONOUS
, 0) };
332 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
333 data_writes
, arraysize(data_writes
),
334 host_resolver_
.get(),
338 int rv
= user_sock_
->Connect(callback_
.callback());
339 EXPECT_EQ(ERR_IO_PENDING
, rv
);
340 TestNetLog::CapturedEntryList entries
;
341 log
.GetEntries(&entries
);
342 EXPECT_TRUE(LogContainsBeginEvent(
343 entries
, 0, NetLog::TYPE_SOCKS_CONNECT
));
// The truncated reply must surface as ERR_CONNECTION_CLOSED.
345 rv
= callback_
.WaitForResult();
346 EXPECT_EQ(ERR_CONNECTION_CLOSED
, rv
);
347 EXPECT_FALSE(user_sock_
->IsConnected());
348 log
.GetEntries(&entries
);
349 EXPECT_TRUE(LogContainsEndEvent(
350 entries
, -1, NetLog::TYPE_SOCKS_CONNECT
));
353 // Tries to connect to an unknown hostname. Should fail rather than
354 // falling back to SOCKS4a.
355 TEST_F(SOCKSClientSocketTest
, FailedDNS
) {
356 const char hostname
[] = "unresolved.ipv4.address";
// Make the mock resolver fail for |hostname|.
358 host_resolver_
->rules()->AddSimulatedFailure(hostname
);
362 user_sock_
= BuildMockSocket(NULL
, 0,
364 host_resolver_
.get(),
368 int rv
= user_sock_
->Connect(callback_
.callback());
369 EXPECT_EQ(ERR_IO_PENDING
, rv
);
370 TestNetLog::CapturedEntryList entries
;
371 log
.GetEntries(&entries
);
372 EXPECT_TRUE(LogContainsBeginEvent(
373 entries
, 0, NetLog::TYPE_SOCKS_CONNECT
));
// The resolver failure is reported as the Connect() result.
375 rv
= callback_
.WaitForResult();
376 EXPECT_EQ(ERR_NAME_NOT_RESOLVED
, rv
);
377 EXPECT_FALSE(user_sock_
->IsConnected());
378 log
.GetEntries(&entries
);
379 EXPECT_TRUE(LogContainsEndEvent(
380 entries
, -1, NetLog::TYPE_SOCKS_CONNECT
));
383 // Calls Disconnect() while a host resolve is in progress. The outstanding host
384 // resolve should be cancelled.
385 TEST_F(SOCKSClientSocketTest
, DisconnectWhileHostResolveInProgress
) {
386 scoped_ptr
<HangingHostResolverWithCancel
> hanging_resolver(
387 new HangingHostResolverWithCancel());
389 // Doesn't matter what the socket data is, we will never use it -- garbage.
390 MockWrite data_writes
[] = { MockWrite(SYNCHRONOUS
, "", 0) };
391 MockRead data_reads
[] = { MockRead(SYNCHRONOUS
, "", 0) };
393 user_sock_
= BuildMockSocket(data_reads
, arraysize(data_reads
),
394 data_writes
, arraysize(data_writes
),
395 hanging_resolver
.get(),
399 // Start connecting (will get stuck waiting for the host to resolve).
400 int rv
= user_sock_
->Connect(callback_
.callback());
401 EXPECT_EQ(ERR_IO_PENDING
, rv
);
403 EXPECT_FALSE(user_sock_
->IsConnected());
404 EXPECT_FALSE(user_sock_
->IsConnectedAndIdle());
406 // The host resolver should have received the resolve request.
407 EXPECT_TRUE(hanging_resolver
->HasOutstandingRequest());
409 // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
410 user_sock_
->Disconnect();
// HangingHostResolverWithCancel::CancelRequest() clears the pending handle.
412 EXPECT_FALSE(hanging_resolver
->HasOutstandingRequest());
414 EXPECT_FALSE(user_sock_
->IsConnected());
415 EXPECT_FALSE(user_sock_
->IsConnectedAndIdle());
418 // Tries to connect to an IPv6 IP. Should fail, as SOCKS4 does not support
// IPv6 destination addresses.
420 TEST_F(SOCKSClientSocketTest
, NoIPv6
) {
421 const char kHostName
[] = "::1";
423 user_sock_
= BuildMockSocket(NULL
, 0,
425 host_resolver_
.get(),
// Connect() must end with ERR_NAME_NOT_RESOLVED for the IPv6 literal.
429 EXPECT_EQ(ERR_NAME_NOT_RESOLVED
,
430 callback_
.GetResult(user_sock_
->Connect(callback_
.callback())));
433 // Same as above, but with a real resolver, to protect against regressions.
434 TEST_F(SOCKSClientSocketTest
, NoIPv6RealResolver
) {
435 const char kHostName
[] = "::1";
437 scoped_ptr
<HostResolver
> host_resolver(
438 HostResolver::CreateSystemResolver(HostResolver::Options(), NULL
));
440 user_sock_
= BuildMockSocket(NULL
, 0,
446 EXPECT_EQ(ERR_NAME_NOT_RESOLVED
,
447 callback_
.GetResult(user_sock_
->Connect(callback_
.callback())));