3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Volker Lendecke 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 # Used by selftest to proxy DNS queries to the correct testenv DC.
20 # See selftest/target/README for more details.
21 # Based on the EchoServer example from python docs
29 from samba
.dcerpc
import dns
30 import samba
.ndr
as ndr
# Bind the alias `sserver` to the socket-server module that matches the
# running interpreter's major version; the handler and server classes in this
# file derive from sserver.BaseRequestHandler and sserver.UDPServer.
32 if sys
.version_info
[0] < 3:
34 sserver
= SocketServer
37 sserver
= socketserver
# Timeout applied with settimeout() to the socket used for forwarded queries;
# a fifth of it is also compared against the time spent handling a request.
39 DNS_REQUEST_TIMEOUT
= 10
41 # make sure the script dies immediately when hitting control-C,
42 # rather than raising KeyboardInterrupt. As we do all database
43 # operations using transactions, this is safe.
# SIG_DFL restores the default SIGINT disposition for the whole process.
45 signal
.signal(signal
.SIGINT
, signal
.SIG_DFL
)
47 class DnsHandler(sserver
.BaseRequestHandler
):
# Reverse lookup table: numeric DNS_QTYPE_* value -> constant name, built
# from the names exported by the `dns` module.
dns_qtype_strings = {
    value: const_name
    for const_name, value in vars(dns).items()
    if const_name.startswith('DNS_QTYPE_')
}
def dns_qtype_string(self, qtype):
    """Return a readable qtype code.

    qtype is looked up in the dns_qtype_strings table.  A value that has no
    DNS_QTYPE_* name yields a placeholder string instead of raising
    KeyError, so that logging a request with an unusual question type
    cannot abort the caller.
    """
    try:
        return self.dns_qtype_strings[qtype]
    except KeyError:
        return "DNS_QTYPE_UNKNOWN(%s)" % (qtype,)
# Reverse lookup table: numeric DNS_RCODE_* value -> constant name, built
# from the names exported by the `dns` module.
dns_rcode_strings = {
    value: const_name
    for const_name, value in vars(dns).items()
    if const_name.startswith('DNS_RCODE_')
}
def dns_rcode_string(self, rcode):
    """Return a readable error code.

    rcode is looked up in the dns_rcode_strings table.  A value that has no
    DNS_RCODE_* name yields a placeholder string instead of raising
    KeyError, so that logging an unusual response code cannot abort the
    caller.
    """
    try:
        return self.dns_rcode_strings[rcode]
    except KeyError:
        return "DNS_RCODE_UNKNOWN(%s)" % (rcode,)
# Forward `packet` to `host` over UDP port 53 and return the decoded reply.
# The packet is serialised with ndr_pack() and sent on a socket connected to
# the one address getaddrinfo() yields for `host` (exactly one result is
# asserted).  A single datagram of up to 2048 bytes is read back and decoded
# as a dns.name_packet.  A socket.error is reported on stdout with its errno.
# NOTE(review): confirm that the socket is closed on every path and that the
# error is propagated to the caller after being printed.
58 def dns_transaction_udp(self
, packet
, host
):
59 "send a DNS query and read the reply"
61 flags
= socket
.AddressInfo
.AI_NUMERICHOST
62 flags |
= socket
.AddressInfo
.AI_NUMERICSERV
63 flags |
= socket
.AddressInfo
.AI_PASSIVE
64 addr_info
= socket
.getaddrinfo(host
, int(53),
65 type=socket
.SocketKind
.SOCK_DGRAM
,
67 assert len(addr_info
) == 1
69 send_packet
= ndr
.ndr_pack(packet
)
# addr_info[0] supplies the address family [0], socket type [1] and sockaddr [4]
70 s
= socket
.socket(addr_info
[0][0], addr_info
[0][1], 0)
# bound the wait for the reply so a silent forwarder cannot block forever
71 s
.settimeout(DNS_REQUEST_TIMEOUT
)
72 s
.connect(addr_info
[0][4])
73 s
.sendall(send_packet
, 0)
74 recv_packet
= s
.recv(2048, 0)
75 return ndr
.ndr_unpack(dns
.name_packet
, recv_packet
)
76 except socket
.error
as err
:
77 print("Error sending to host %s for name %s: %s\n" %
78 (host
, packet
.questions
[0].name
, err
.errno
))
def get_pdc_ipv4_addr(self, lookup_name):
    """Maps a DNS realm to the IPv4 address of the PDC for that testenv.

    lookup_name is compared, as a plain string suffix, against every realm
    in self.server.realm_to_ip_mappings.  The longest matching realm wins.
    Returns the value mapped to that realm, or None if no realm matches.
    """
    realm_to_ip_mappings = self.server.realm_to_ip_mappings

    # Try the longest realms first so that a more specific realm is
    # preferred over a shorter one that is also a suffix of the name.
    # Two different realms of equal length can never both be a suffix of
    # the same name, so the relative order of equal-length realms is
    # irrelevant to the result.
    for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
        if lookup_name.endswith(realm):
            # return the corresponding IP address for this realm's PDC
            return realm_to_ip_mappings[realm]

    # no testenv realm matches the requested name
    return None
# Decide how a query for `name` is to be answered.  Names ending in one of
# the test-only suffixes below are special-cased; every other name is passed
# to get_pdc_ipv4_addr() and that result is returned.
# NOTE(review): the values returned for the special cases are presumably the
# 'ignore' / 'fail' / 'torture' markers the request handler compares against
# -- confirm which suffix maps to which marker.
100 def forwarder(self
, name
):
103 # check for special cases used by tests (e.g. dns_forwarder.py)
104 if lname
.endswith('an-address-that-will-not-resolve'):
106 if lname
.endswith('dsfsdfs'):
# endswith() with end=len(lname)-2 ignores the last two characters, so this
# matches names that consist of "...torture1" followed by two more characters
108 if lname
.endswith("torture1", 0, len(lname
)-2):
109 # CATCH TORTURE100, TORTURE101, ...
111 if lname
.endswith('_none_.example.com'):
113 if lname
.endswith('torturedom.samba.example.com'):
116 # return the testenv PDC matching the realm being requested
117 return self
.get_pdc_ipv4_addr(lname
)
# Per-request handler body: decode the query from self.request, ask
# forwarder() how the first question's name is to be treated, build or fetch
# the response, log it, and send it back to self.client_address.
# NOTE(review): presumably the body of BaseRequestHandler.handle() -- confirm.
120 start
= time
.monotonic()
# self.request is unpacked as (datagram bytes, socket)
121 data
, sock
= self
.request
122 query
= ndr
.ndr_unpack(dns
.name_packet
, data
)
123 name
= query
.questions
[0].name
124 forwarder
= self
.forwarder(name
)
# 'ignore', 'fail' and 'torture'/None are handled locally; any other value is
# passed to dns_transaction_udp() as the host to forward the query to
127 if forwarder
== 'ignore':
129 elif forwarder
== 'fail':
131 elif forwarder
in ['torture', None]:
# mark the packet as a reply carrying rcode NXDOMAIN
133 response
.operation |
= dns
.DNS_FLAG_REPLY
134 response
.operation |
= dns
.DNS_FLAG_RECURSION_AVAIL
135 response
.operation |
= dns
.DNS_RCODE_NXDOMAIN
138 response
= self
.dns_transaction_udp(query
, forwarder
)
139 except OSError as err
:
140 print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
141 (forwarder
, name
, err
))
# mark the packet as a reply carrying rcode SERVFAIL
145 response
.operation |
= dns
.DNS_FLAG_REPLY
146 response
.operation |
= dns
.DNS_FLAG_RECURSION_AVAIL
147 response
.operation |
= dns
.DNS_RCODE_SERVFAIL
149 send_packet
= ndr
.ndr_pack(response
)
151 end
= time
.monotonic()
# keep only the rcode bits of the operation field for logging
153 errcode
= response
.operation
& dns
.DNS_RCODE
154 if tdiff
> (DNS_REQUEST_TIMEOUT
/5):
159 print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
160 (forwarder
, self
.client_address
, name
,
161 self
.dns_qtype_string(query
.questions
[0].question_type
),
162 self
.dns_rcode_string(errcode
), response
.operation
, tdiff
))
# a failure to deliver the reply is only reported, not raised
165 sock
.sendto(send_packet
, self
.client_address
)
166 except socket
.error
as err
:
167 print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
168 (self
.client_address
, name
, tdiff
, err
))
# Thread wrapper around one socket server.  The thread is created with the
# given name; it calls self.server.serve_forever() with a log line before and
# after, and stopping it calls self.server.shutdown() (also logged) followed
# by self.server.server_close().
171 class server_thread(threading
.Thread
):
172 def __init__(self
, server
, name
):
173 threading
.Thread
.__init
__(self
, name
=name
)
# blocks until shutdown() is called on the server
177 print("dns_hub[%s]: before serve_forever()" % self
.name
)
178 self
.server
.serve_forever()
179 print("dns_hub[%s]: after serve_forever()" % self
.name
)
# stop the serve_forever() loop first, then release the listening socket
182 print("dns_hub[%s]: before shutdown()" % self
.name
)
183 self
.server
.shutdown()
184 print("dns_hub[%s]: after shutdown()" % self
.name
)
185 self
.server
.server_close()
class UDPV4Server(sserver.UDPServer):
    """UDP socket server whose listening socket uses the AF_INET family."""

    address_family = socket.AF_INET
class UDPV6Server(sserver.UDPServer):
    """UDP socket server whose listening socket uses the AF_INET6 family."""

    address_family = socket.AF_INET6
# Command-line handling: three arguments are required (TIMEOUT, the listen
# addresses and the realm mappings); a usage line is printed when fewer are
# given.
194 if len(sys
.argv
) < 4:
195 print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
# presumably TIMEOUT is in seconds and is scaled to milliseconds for poll()
# -- confirm against the caller
198 timeout
= int(sys
.argv
[1]) * 1000
199 timeout
= min(timeout
, 2**31 - 1) # poll with 32-bit int can't take more
200 # we pass in the listen addresses as a comma-separated string.
201 listenaddresses
= sys
.argv
[2].split(',')
202 # we pass in the realm-to-IP mappings as a comma-separated key=value
203 # string. Convert this back into a dictionary that the DnsHandler can use
# each "key=value" item is split on '=' and the resulting pairs feed the
# OrderedDict, so the mappings keep their command-line order
204 realm_mappings
= collections
.OrderedDict(kv
.split('=') for kv
in sys
.argv
[3].split(','))
# Build the UDP server for one listen address and wrap it in a server_thread.
# The address is resolved with getaddrinfo() for port 53 / SOCK_DGRAM and
# exactly one result is asserted; UDPV6Server is used when the resolved
# family is AF_INET6, UDPV4Server is the other choice.  The realm mappings
# are attached to the server object, where DnsHandler reads them as
# self.server.realm_to_ip_mappings.
206 def prepare_server_thread(listenaddress
, realm_mappings
):
208 flags
= socket
.AddressInfo
.AI_NUMERICHOST
209 flags |
= socket
.AddressInfo
.AI_NUMERICSERV
210 flags |
= socket
.AddressInfo
.AI_PASSIVE
211 addr_info
= socket
.getaddrinfo(listenaddress
, int(53),
212 type=socket
.SocketKind
.SOCK_DGRAM
,
214 assert len(addr_info
) == 1
215 if addr_info
[0][0] == socket
.AddressFamily
.AF_INET6
:
216 server
= UDPV6Server(addr_info
[0][4], DnsHandler
)
218 server
= UDPV4Server(addr_info
[0][4], DnsHandler
)
# NOTE: realm_mappings is already a dictionary at this point; the statement
# below only stores it on the server object.
220 # we pass in the realm-to-IP mappings as a comma-separated key=value
221 # string. Convert this back into a dictionary that the DnsHandler can use
222 server
.realm_to_ip_mappings
= realm_mappings
# the thread name embeds the listen address and shows up in its log lines
223 t
= server_thread(server
, name
="UDP[%s]" % listenaddress
)
# Start-up reporting and serving: list the realm -> address mappings, create
# one server thread per listen address, then register stdin for polling.
# NOTE(review): presumably the script waits in poll() on stdin (bounded by
# `timeout`) and stops the server threads afterwards -- confirm.
226 print("dns_hub will proxy DNS requests for the following realms:")
227 for realm
, ip
in realm_mappings
.items():
228 print(" {0} ==> {1}".format(realm
, ip
))
230 print("dns_hub will listen on the following UDP addresses:")
232 for listenaddress
in listenaddresses
:
233 print(" %s" % listenaddress
)
234 t
= prepare_server_thread(listenaddress
, realm_mappings
)
# watch stdin for readability (POLLIN)
240 stdin
= sys
.stdin
.fileno()
241 p
.register(stdin
, select
.POLLIN
)
243 print("dns_hub: after poll()")
248 print("dns_hub: before exit()")