Bug 1796551 [wpt PR 36570] - WebKit export of https://bugs.webkit.org/show_bug.cgi...
[gecko.git] / netwerk / dns / GetAddrInfo.cpp
blob2f0354bd960f1d11eb6c83079ae6ff71dbe39b30
1 /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
2 /* vim: set ts=8 sts=2 et sw=2 tw=80: */
3 /* This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
7 #include "GetAddrInfo.h"
9 #ifdef DNSQUERY_AVAILABLE
10 // There is a bug in windns.h where the type of parameter ppQueryResultsSet for
11 // DnsQuery_A is dependent on UNICODE being set. It should *always* be
12 // PDNS_RECORDA, but if UNICODE is set it is PDNS_RECORDW. To get around this
13 // we make sure that UNICODE is unset.
14 # undef UNICODE
15 # include <ws2tcpip.h>
16 # undef GetAddrInfo
17 # include <windns.h>
18 #endif // DNSQUERY_AVAILABLE
20 #include "mozilla/ClearOnShutdown.h"
21 #include "mozilla/net/DNS.h"
22 #include "NativeDNSResolverOverrideParent.h"
23 #include "prnetdb.h"
24 #include "nsIOService.h"
25 #include "nsHostResolver.h"
26 #include "nsError.h"
27 #include "mozilla/net/DNS.h"
28 #include <algorithm>
29 #include "prerror.h"
31 #include "mozilla/Logging.h"
32 #include "mozilla/StaticPrefs_network.h"
34 namespace mozilla::net {
// Process-wide override service. It is created lazily by
// NativeDNSResolverOverride::GetSingleton() and released via ClearOnShutdown.
// When non-null, GetAddrInfo() consults it before doing a real lookup.
static StaticRefPtr<NativeDNSResolverOverride> gOverrideService;

static LazyLogModule gGetAddrInfoLog("GetAddrInfo");
// Logging helpers for the "GetAddrInfo" log module. Every message is prefixed
// with "[DNS]: ". LOG logs at Debug level, LOG_WARNING at Warning level.
#define LOG(msg, ...) \
  MOZ_LOG(gGetAddrInfoLog, LogLevel::Debug, ("[DNS]: " msg, ##__VA_ARGS__))
#define LOG_WARNING(msg, ...) \
  MOZ_LOG(gGetAddrInfoLog, LogLevel::Warning, ("[DNS]: " msg, ##__VA_ARGS__))
#ifdef DNSQUERY_AVAILABLE

// Capacity of sDNSComputerName (passed to GetComputerNameExA in
// GetAddrInfoInit()).
# define COMPUTER_NAME_BUFFER_SIZE 100
// This machine's DNS host name and NetBIOS name. Both are filled in by
// GetAddrInfoInit() and set to empty strings if GetComputerNameExA fails.
// Single-label lookups of either name skip the DnsQuery_A path in
// _GetAddrInfo_Portable().
static char sDNSComputerName[COMPUTER_NAME_BUFFER_SIZE];
static char sNETBIOSComputerName[MAX_COMPUTERNAME_LENGTH + 1];

////////////////////////////
// WINDOWS IMPLEMENTATION //
////////////////////////////

// Ensure consistency of PR_* and AF_* constants to allow for legacy usage of
// PR_* constants with this API.
static_assert(PR_AF_INET == AF_INET && PR_AF_INET6 == AF_INET6 &&
                  PR_AF_UNSPEC == AF_UNSPEC,
              "PR_AF_* must match AF_*");
60 // If successful, returns in aResult a TTL value that is smaller or
61 // equal with the one already there. Gets the TTL value by calling
62 // to DnsQuery_A and iterating through the returned
63 // records to find the one with the smallest TTL value.
64 static MOZ_ALWAYS_INLINE nsresult _CallDnsQuery_A_Windows(
65 const nsACString& aHost, uint16_t aAddressFamily, DWORD aFlags,
66 std::function<void(PDNS_RECORDA)> aCallback) {
67 NS_ConvertASCIItoUTF16 name(aHost);
69 auto callDnsQuery_A = [&](uint16_t reqFamily) {
70 PDNS_RECORDA dnsData = nullptr;
71 DNS_STATUS status = DnsQuery_A(aHost.BeginReading(), reqFamily, aFlags,
72 nullptr, &dnsData, nullptr);
73 if (status == DNS_INFO_NO_RECORDS || status == DNS_ERROR_RCODE_NAME_ERROR ||
74 !dnsData) {
75 LOG("No DNS records found for %s. status=%lX. reqFamily = %X\n",
76 aHost.BeginReading(), status, reqFamily);
77 return NS_ERROR_FAILURE;
78 } else if (status != NOERROR) {
79 LOG_WARNING("DnsQuery_A failed with status %lX.\n", status);
80 return NS_ERROR_UNEXPECTED;
83 for (PDNS_RECORDA curRecord = dnsData; curRecord;
84 curRecord = curRecord->pNext) {
85 // Only records in the answer section are important
86 if (curRecord->Flags.S.Section != DnsSectionAnswer) {
87 continue;
89 if (curRecord->wType != reqFamily) {
90 continue;
93 aCallback(curRecord);
96 DnsFree(dnsData, DNS_FREE_TYPE::DnsFreeRecordList);
97 return NS_OK;
100 if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET) {
101 callDnsQuery_A(DNS_TYPE_A);
104 if (aAddressFamily == PR_AF_UNSPEC || aAddressFamily == PR_AF_INET6) {
105 callDnsQuery_A(DNS_TYPE_AAAA);
107 return NS_OK;
110 bool recordTypeMatchesRequest(uint16_t wType, uint16_t aAddressFamily) {
111 if (aAddressFamily == PR_AF_UNSPEC) {
112 return wType == DNS_TYPE_A || wType == DNS_TYPE_AAAA;
114 if (aAddressFamily == PR_AF_INET) {
115 return wType == DNS_TYPE_A;
117 if (aAddressFamily == PR_AF_INET6) {
118 return wType == DNS_TYPE_AAAA;
120 return false;
123 static MOZ_ALWAYS_INLINE nsresult _GetTTLData_Windows(const nsACString& aHost,
124 uint32_t* aResult,
125 uint16_t aAddressFamily) {
126 MOZ_ASSERT(!aHost.IsEmpty());
127 MOZ_ASSERT(aResult);
128 if (aAddressFamily != PR_AF_UNSPEC && aAddressFamily != PR_AF_INET &&
129 aAddressFamily != PR_AF_INET6) {
130 return NS_ERROR_UNEXPECTED;
133 // In order to avoid using ANY records which are not always implemented as a
134 // "Gimme what you have" request in hostname resolvers, we should send A
135 // and/or AAAA requests, based on the address family requested.
136 const DWORD ttlFlags =
137 (DNS_QUERY_STANDARD | DNS_QUERY_NO_NETBT | DNS_QUERY_NO_HOSTS_FILE |
138 DNS_QUERY_NO_MULTICAST | DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE |
139 DNS_QUERY_DONT_RESET_TTL_VALUES);
140 unsigned int ttl = (unsigned int)-1;
141 _CallDnsQuery_A_Windows(
142 aHost, aAddressFamily, ttlFlags,
143 [&ttl, &aHost, aAddressFamily](PDNS_RECORDA curRecord) {
144 if (recordTypeMatchesRequest(curRecord->wType, aAddressFamily)) {
145 ttl = std::min<unsigned int>(ttl, curRecord->dwTtl);
146 } else {
147 LOG("Received unexpected record type %u in response for %s.\n",
148 curRecord->wType, aHost.BeginReading());
152 if (ttl == (unsigned int)-1) {
153 LOG("No useable TTL found.");
154 return NS_ERROR_FAILURE;
157 *aResult = ttl;
158 return NS_OK;
161 static MOZ_ALWAYS_INLINE nsresult
162 _DNSQuery_A_SingleLabel(const nsACString& aCanonHost, uint16_t aAddressFamily,
163 uint16_t aFlags, AddrInfo** aAddrInfo) {
164 bool setCanonName = aFlags & nsHostResolver::RES_CANON_NAME;
165 nsAutoCString canonName;
166 const DWORD flags = (DNS_QUERY_STANDARD | DNS_QUERY_NO_MULTICAST |
167 DNS_QUERY_ACCEPT_TRUNCATED_RESPONSE);
168 nsTArray<NetAddr> addresses;
170 _CallDnsQuery_A_Windows(
171 aCanonHost, aAddressFamily, flags, [&](PDNS_RECORDA curRecord) {
172 MOZ_DIAGNOSTIC_ASSERT(curRecord->wType == DNS_TYPE_A ||
173 curRecord->wType == DNS_TYPE_AAAA);
174 if (setCanonName) {
175 canonName.Assign(curRecord->pName);
177 NetAddr addr{};
178 addr.inet.family = AF_INET;
179 addr.inet.ip = curRecord->Data.A.IpAddress;
180 addresses.AppendElement(addr);
183 LOG("Query for: %s has %zu results", aCanonHost.BeginReading(),
184 addresses.Length());
185 if (addresses.IsEmpty()) {
186 return NS_ERROR_UNKNOWN_HOST;
188 RefPtr<AddrInfo> ai(new AddrInfo(
189 aCanonHost, canonName, DNSResolverType::Native, 0, std::move(addresses)));
190 ai.forget(aAddrInfo);
192 return NS_OK;
195 #endif
197 ////////////////////////////////////
198 // PORTABLE RUNTIME IMPLEMENTATION//
199 ////////////////////////////////////
201 static MOZ_ALWAYS_INLINE nsresult
202 _GetAddrInfo_Portable(const nsACString& aCanonHost, uint16_t aAddressFamily,
203 uint16_t aFlags, AddrInfo** aAddrInfo) {
204 MOZ_ASSERT(!aCanonHost.IsEmpty());
205 MOZ_ASSERT(aAddrInfo);
207 // We accept the same aFlags that nsHostResolver::ResolveHost accepts, but we
208 // need to translate the aFlags into a form that PR_GetAddrInfoByName
209 // accepts.
210 int prFlags = PR_AI_ADDRCONFIG;
211 if (!(aFlags & nsHostResolver::RES_CANON_NAME)) {
212 prFlags |= PR_AI_NOCANONNAME;
215 // We need to remove IPv4 records manually because PR_GetAddrInfoByName
216 // doesn't support PR_AF_INET6.
217 bool disableIPv4 = aAddressFamily == PR_AF_INET6;
218 if (disableIPv4) {
219 aAddressFamily = PR_AF_UNSPEC;
222 #if defined(DNSQUERY_AVAILABLE)
223 if (StaticPrefs::network_dns_dns_query_single_label() &&
224 !aCanonHost.Contains('.') && aCanonHost != "localhost"_ns) {
225 // For some reason we can't use DnsQuery_A to get the computer's IP.
226 if (!aCanonHost.Equals(nsDependentCString(sDNSComputerName),
227 nsCaseInsensitiveCStringComparator) &&
228 !aCanonHost.Equals(nsDependentCString(sNETBIOSComputerName),
229 nsCaseInsensitiveCStringComparator)) {
230 // This is a single label name resolve without a dot.
231 // We use DNSQuery_A for these.
232 LOG("Resolving %s using DnsQuery_A (computername: %s)\n",
233 aCanonHost.BeginReading(), sDNSComputerName);
234 return _DNSQuery_A_SingleLabel(aCanonHost, aAddressFamily, aFlags,
235 aAddrInfo);
238 #endif
240 LOG("Resolving %s using PR_GetAddrInfoByName", aCanonHost.BeginReading());
241 PRAddrInfo* prai =
242 PR_GetAddrInfoByName(aCanonHost.BeginReading(), aAddressFamily, prFlags);
244 if (!prai) {
245 LOG("PR_GetAddrInfoByName returned null PR_GetError:%d PR_GetOSErrpr:%d",
246 PR_GetError(), PR_GetOSError());
247 return NS_ERROR_UNKNOWN_HOST;
250 nsAutoCString canonName;
251 if (aFlags & nsHostResolver::RES_CANON_NAME) {
252 canonName.Assign(PR_GetCanonNameFromAddrInfo(prai));
255 bool filterNameCollision =
256 !(aFlags & nsHostResolver::RES_ALLOW_NAME_COLLISION);
257 RefPtr<AddrInfo> ai(new AddrInfo(aCanonHost, prai, disableIPv4,
258 filterNameCollision, canonName));
259 PR_FreeAddrInfo(prai);
260 if (ai->Addresses().IsEmpty()) {
261 LOG("PR_GetAddrInfoByName returned empty address list");
262 return NS_ERROR_UNKNOWN_HOST;
265 ai.forget(aAddrInfo);
267 LOG("PR_GetAddrInfoByName resolved successfully");
268 return NS_OK;
271 //////////////////////////////////////
272 // COMMON/PLATFORM INDEPENDENT CODE //
273 //////////////////////////////////////
274 nsresult GetAddrInfoInit() {
275 LOG("Initializing GetAddrInfo.\n");
277 #ifdef DNSQUERY_AVAILABLE
278 DWORD namesize = COMPUTER_NAME_BUFFER_SIZE;
279 if (!GetComputerNameExA(ComputerNameDnsHostname, sDNSComputerName,
280 &namesize)) {
281 sDNSComputerName[0] = 0;
283 namesize = MAX_COMPUTERNAME_LENGTH + 1;
284 if (!GetComputerNameExA(ComputerNameNetBIOS, sNETBIOSComputerName,
285 &namesize)) {
286 sNETBIOSComputerName[0] = 0;
288 #endif
289 return NS_OK;
// Counterpart to GetAddrInfoInit(). It only logs and always returns NS_OK.
nsresult GetAddrInfoShutdown() {
  LOG("Shutting down GetAddrInfo.\n");
  return NS_OK;
}
297 bool FindAddrOverride(const nsACString& aHost, uint16_t aAddressFamily,
298 uint16_t aFlags, AddrInfo** aAddrInfo) {
299 RefPtr<NativeDNSResolverOverride> overrideService = gOverrideService;
300 if (!overrideService) {
301 return false;
303 AutoReadLock lock(overrideService->mLock);
304 auto overrides = overrideService->mOverrides.Lookup(aHost);
305 if (!overrides) {
306 return false;
308 nsCString* cname = nullptr;
309 if (aFlags & nsHostResolver::RES_CANON_NAME) {
310 cname = overrideService->mCnames.Lookup(aHost).DataPtrOrNull();
313 RefPtr<AddrInfo> ai;
315 nsTArray<NetAddr> addresses;
316 for (const auto& ip : *overrides) {
317 if (aAddressFamily != AF_UNSPEC && ip.raw.family != aAddressFamily) {
318 continue;
320 addresses.AppendElement(ip);
323 if (!cname) {
324 ai = new AddrInfo(aHost, DNSResolverType::Native, 0, std::move(addresses));
325 } else {
326 ai = new AddrInfo(aHost, *cname, DNSResolverType::Native, 0,
327 std::move(addresses));
330 ai.forget(aAddrInfo);
331 return true;
334 nsresult GetAddrInfo(const nsACString& aHost, uint16_t aAddressFamily,
335 uint16_t aFlags, AddrInfo** aAddrInfo, bool aGetTtl) {
336 if (NS_WARN_IF(aHost.IsEmpty()) || NS_WARN_IF(!aAddrInfo)) {
337 return NS_ERROR_NULL_POINTER;
339 *aAddrInfo = nullptr;
341 if (StaticPrefs::network_dns_disabled()) {
342 return NS_ERROR_UNKNOWN_HOST;
345 #ifdef DNSQUERY_AVAILABLE
346 // The GetTTLData needs the canonical name to function properly
347 if (aGetTtl) {
348 aFlags |= nsHostResolver::RES_CANON_NAME;
350 #endif
352 // If there is an override for this host, then we synthetize a result.
353 if (gOverrideService &&
354 FindAddrOverride(aHost, aAddressFamily, aFlags, aAddrInfo)) {
355 LOG("Returning IP address from NativeDNSResolverOverride");
356 return (*aAddrInfo)->Addresses().Length() ? NS_OK : NS_ERROR_UNKNOWN_HOST;
359 nsAutoCString host;
360 if (StaticPrefs::network_dns_copy_string_before_call()) {
361 host = Substring(aHost.BeginReading(), aHost.Length());
362 MOZ_ASSERT(aHost.BeginReading() != host.BeginReading());
363 } else {
364 host = aHost;
367 if (gNativeIsLocalhost) {
368 // pretend we use the given host but use IPv4 localhost instead!
369 host = "localhost"_ns;
370 aAddressFamily = PR_AF_INET;
373 RefPtr<AddrInfo> info;
374 nsresult rv =
375 _GetAddrInfo_Portable(host, aAddressFamily, aFlags, getter_AddRefs(info));
377 #ifdef DNSQUERY_AVAILABLE
378 if (aGetTtl && NS_SUCCEEDED(rv)) {
379 // Figure out the canonical name, or if that fails, just use the host name
380 // we have.
381 nsAutoCString name;
382 if (info && !info->CanonicalHostname().IsEmpty()) {
383 name = info->CanonicalHostname();
384 } else {
385 name = host;
388 LOG("Getting TTL for %s (cname = %s).", host.get(), name.get());
389 uint32_t ttl = 0;
390 nsresult ttlRv = _GetTTLData_Windows(name, &ttl, aAddressFamily);
391 if (NS_SUCCEEDED(ttlRv)) {
392 auto builder = info->Build();
393 builder.SetTTL(ttl);
394 info = builder.Finish();
395 LOG("Got TTL %u for %s (name = %s).", ttl, host.get(), name.get());
396 } else {
397 LOG_WARNING("Could not get TTL for %s (cname = %s).", host.get(),
398 name.get());
401 #endif
403 info.forget(aAddrInfo);
404 return rv;
407 // static
408 already_AddRefed<nsINativeDNSResolverOverride>
409 NativeDNSResolverOverride::GetSingleton() {
410 if (nsIOService::UseSocketProcess() && XRE_IsParentProcess()) {
411 return NativeDNSResolverOverrideParent::GetSingleton();
414 if (gOverrideService) {
415 return do_AddRef(gOverrideService);
418 gOverrideService = new NativeDNSResolverOverride();
419 ClearOnShutdown(&gOverrideService);
420 return do_AddRef(gOverrideService);
// XPCOM refcounting and QueryInterface for NativeDNSResolverOverride, exposing
// nsINativeDNSResolverOverride.
NS_IMPL_ISUPPORTS(NativeDNSResolverOverride, nsINativeDNSResolverOverride)
425 NS_IMETHODIMP NativeDNSResolverOverride::AddIPOverride(
426 const nsACString& aHost, const nsACString& aIPLiteral) {
427 NetAddr tempAddr;
429 if (aIPLiteral.Equals("N/A"_ns)) {
430 AutoWriteLock lock(mLock);
431 auto& overrides = mOverrides.LookupOrInsert(aHost);
432 overrides.Clear();
433 return NS_OK;
436 if (NS_FAILED(tempAddr.InitFromString(aIPLiteral))) {
437 return NS_ERROR_UNEXPECTED;
440 AutoWriteLock lock(mLock);
441 auto& overrides = mOverrides.LookupOrInsert(aHost);
442 overrides.AppendElement(tempAddr);
444 return NS_OK;
447 NS_IMETHODIMP NativeDNSResolverOverride::SetCnameOverride(
448 const nsACString& aHost, const nsACString& aCNAME) {
449 if (aCNAME.IsEmpty()) {
450 return NS_ERROR_UNEXPECTED;
453 AutoWriteLock lock(mLock);
454 mCnames.InsertOrUpdate(aHost, nsCString(aCNAME));
456 return NS_OK;
459 NS_IMETHODIMP NativeDNSResolverOverride::ClearHostOverride(
460 const nsACString& aHost) {
461 AutoWriteLock lock(mLock);
462 mCnames.Remove(aHost);
463 auto overrides = mOverrides.Extract(aHost);
464 if (!overrides) {
465 return NS_OK;
468 overrides->Clear();
469 return NS_OK;
472 NS_IMETHODIMP NativeDNSResolverOverride::ClearOverrides() {
473 AutoWriteLock lock(mLock);
474 mOverrides.Clear();
475 mCnames.Clear();
476 return NS_OK;
479 } // namespace mozilla::net