1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_socket_pool.h"
7 #include "base/logging.h"
8 #include "base/rand_util.h"
9 #include "base/stl_util.h"
10 #include "net/base/address_list.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/rand_callback.h"
14 #include "net/socket/client_socket_factory.h"
15 #include "net/socket/stream_socket.h"
16 #include "net/udp/datagram_client_socket.h"
22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
23 // When we allocate a socket, we ensure we have at least kAllocateMinSize
24 // sockets to choose from. Freed sockets are not retained.
26 // On Windows, we can't request specific (random) ports, since that will
27 // trigger firewall prompts, so request default ones, but keep a pile of
28 // them. Everywhere else, request fresh, random ports each time.
// Pool tuning constants. Two alternative groups of definitions follow.
// NOTE(review): the preprocessor conditional that selects between the two
// groups is not visible in this excerpt; per the comment above, the first
// group is presumably the Windows configuration and the second is used
// everywhere else -- confirm against the full file.

// Group 1: default port binding, with a 256-socket pool both at
// initialization and as the minimum to choose from at allocation time.
30 const DatagramSocket::BindType kBindType
= DatagramSocket::DEFAULT_BIND
;
31 const unsigned kInitialPoolSize
= 256;
32 const unsigned kAllocateMinSize
= 256;
// Group 2: random port binding, with no sockets pre-allocated and a single
// socket required at allocation time.
34 const DatagramSocket::BindType kBindType
= DatagramSocket::RANDOM_BIND
;
35 const unsigned kInitialPoolSize
= 0;
36 const unsigned kAllocateMinSize
= 1;
// Constructs the pool base, storing |socket_factory| in |socket_factory_|.
// The pointer is stored as-is (not copied or owned here).
// NOTE(review): the initializer list continues past what is visible in this
// excerpt (it ends on a trailing comma) -- confirm the remaining member
// initializers against the full file.
41 DnsSocketPool::DnsSocketPool(ClientSocketFactory
* socket_factory
)
42 : socket_factory_(socket_factory
),
// Shared initialization helper called by the Initialize() overrides.
// Stores the |nameservers| pointer in |nameservers_| without copying the
// vector, so the vector must remain valid for as long as it is dereferenced.
// NOTE(review): callers pass a second argument (net_log), but that parameter
// and part of the body are not visible in this excerpt -- confirm against the
// full file.
48 void DnsSocketPool::InitializeInternal(
49 const std::vector
<IPEndPoint
>* nameservers
,
// Initializing twice is a programming error.
52 DCHECK(!initialized_
);
55 nameservers_
= nameservers
;
59 scoped_ptr
<StreamSocket
> DnsSocketPool::CreateTCPSocket(
60 unsigned server_index
,
61 const NetLog::Source
& source
) {
62 DCHECK_LT(server_index
, nameservers_
->size());
64 return scoped_ptr
<StreamSocket
>(
65 socket_factory_
->CreateTransportClientSocket(
66 AddressList((*nameservers_
)[server_index
]), net_log_
, source
));
// Creates a datagram client socket through |socket_factory_| and connects it
// to the nameserver at |server_index|.
// |server_index| must be a valid index into the nameserver list (DCHECKed).
// NOTE(review): the conditionals guarding the two log statements below and
// the return statement are not visible in this excerpt; presumably each
// message is logged only on the corresponding failure -- confirm against the
// full file.
69 scoped_ptr
<DatagramClientSocket
> DnsSocketPool::CreateConnectedSocket(
70 unsigned server_index
) {
71 DCHECK_LT(server_index
, nameservers_
->size());
73 scoped_ptr
<DatagramClientSocket
> socket
;
// A default-constructed source is used; no caller-provided source exists here.
75 NetLog::Source no_source
;
// The bind policy comes from kBindType; base::RandInt is supplied as the
// factory's random-number callback.
76 socket
= socket_factory_
->CreateDatagramClientSocket(
77 kBindType
, base::Bind(&base::RandInt
), net_log_
, no_source
);
// |rv| is the result code of Connect(); it is included in the log message.
80 int rv
= socket
->Connect((*nameservers_
)[server_index
]);
82 VLOG(1) << "Failed to connect socket: " << rv
;
86 LOG(WARNING
) << "Failed to create socket.";
92 class NullDnsSocketPool
: public DnsSocketPool
{
94 NullDnsSocketPool(ClientSocketFactory
* factory
)
95 : DnsSocketPool(factory
) {
98 void Initialize(const std::vector
<IPEndPoint
>* nameservers
,
99 NetLog
* net_log
) override
{
100 InitializeInternal(nameservers
, net_log
);
103 scoped_ptr
<DatagramClientSocket
> AllocateSocket(
104 unsigned server_index
) override
{
105 return CreateConnectedSocket(server_index
);
108 void FreeSocket(unsigned server_index
,
109 scoped_ptr
<DatagramClientSocket
> socket
) override
{}
112 DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool
);
116 scoped_ptr
<DnsSocketPool
> DnsSocketPool::CreateNull(
117 ClientSocketFactory
* factory
) {
118 return scoped_ptr
<DnsSocketPool
>(new NullDnsSocketPool(factory
));
121 class DefaultDnsSocketPool
: public DnsSocketPool
{
123 DefaultDnsSocketPool(ClientSocketFactory
* factory
)
124 : DnsSocketPool(factory
) {
127 ~DefaultDnsSocketPool() override
;
129 void Initialize(const std::vector
<IPEndPoint
>* nameservers
,
130 NetLog
* net_log
) override
;
132 scoped_ptr
<DatagramClientSocket
> AllocateSocket(
133 unsigned server_index
) override
;
135 void FreeSocket(unsigned server_index
,
136 scoped_ptr
<DatagramClientSocket
> socket
) override
;
139 void FillPool(unsigned server_index
, unsigned size
);
141 typedef std::vector
<DatagramClientSocket
*> SocketVector
;
143 std::vector
<SocketVector
> pools_
;
145 DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool
);
149 scoped_ptr
<DnsSocketPool
> DnsSocketPool::CreateDefault(
150 ClientSocketFactory
* factory
) {
151 return scoped_ptr
<DnsSocketPool
>(new DefaultDnsSocketPool(factory
));
154 void DefaultDnsSocketPool::Initialize(
155 const std::vector
<IPEndPoint
>* nameservers
,
157 InitializeInternal(nameservers
, net_log
);
159 DCHECK(pools_
.empty());
160 const unsigned num_servers
= nameservers
->size();
161 pools_
.resize(num_servers
);
162 for (unsigned server_index
= 0; server_index
< num_servers
; ++server_index
)
163 FillPool(server_index
, kInitialPoolSize
);
166 DefaultDnsSocketPool::~DefaultDnsSocketPool() {
167 unsigned num_servers
= pools_
.size();
168 for (unsigned server_index
= 0; server_index
< num_servers
; ++server_index
) {
169 SocketVector
& pool
= pools_
[server_index
];
170 STLDeleteElements(&pool
);
174 scoped_ptr
<DatagramClientSocket
> DefaultDnsSocketPool::AllocateSocket(
175 unsigned server_index
) {
176 DCHECK_LT(server_index
, pools_
.size());
177 SocketVector
& pool
= pools_
[server_index
];
179 FillPool(server_index
, kAllocateMinSize
);
180 if (pool
.size() == 0) {
181 LOG(WARNING
) << "No DNS sockets available in pool " << server_index
<< "!";
182 return scoped_ptr
<DatagramClientSocket
>();
185 if (pool
.size() < kAllocateMinSize
) {
186 LOG(WARNING
) << "Low DNS port entropy: wanted " << kAllocateMinSize
187 << " sockets to choose from, but only have " << pool
.size()
188 << " in pool " << server_index
<< ".";
191 unsigned socket_index
= base::RandInt(0, pool
.size() - 1);
192 DatagramClientSocket
* socket
= pool
[socket_index
];
193 pool
[socket_index
] = pool
.back();
196 return scoped_ptr
<DatagramClientSocket
>(socket
);
199 void DefaultDnsSocketPool::FreeSocket(
200 unsigned server_index
,
201 scoped_ptr
<DatagramClientSocket
> socket
) {
202 DCHECK_LT(server_index
, pools_
.size());
205 void DefaultDnsSocketPool::FillPool(unsigned server_index
, unsigned size
) {
206 SocketVector
& pool
= pools_
[server_index
];
208 for (unsigned pool_index
= pool
.size(); pool_index
< size
; ++pool_index
) {
209 DatagramClientSocket
* socket
=
210 CreateConnectedSocket(server_index
).release();
213 pool
.push_back(socket
);