pidl/lib: Add recursion detection logic to prevent looping.
[Samba.git] / selftest / target / dns_hub.py
blob9f5e3dcd271ce4d52aa215b7f10e6ff0d678c4d6
1 #!/usr/bin/env python3
3 # Unix SMB/CIFS implementation.
4 # Copyright (C) Volker Lendecke 2017
6 # This program is free software; you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 # GNU General Public License for more details.
16 # You should have received a copy of the GNU General Public License
17 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 # Used by selftest to proxy DNS queries to the correct testenv DC.
20 # See selftest/target/README for more details.
21 # Based on the EchoServer example from python docs
23 import threading
24 import sys
25 import select
26 import socket
27 import collections
28 import time
29 from samba.dcerpc import dns
30 import samba.ndr as ndr
# This script runs under Python 3 only: the shebang selects python3, and the
# code uses Python 3-only APIs (socket.AddressInfo, time.monotonic).  The
# former Python 2 "SocketServer" fallback was therefore unreachable.
import socketserver
sserver = socketserver

# Seconds to wait for an upstream DC to answer a forwarded query.  The
# request handler also logs any request that takes longer than a fifth of
# this value.
DNS_REQUEST_TIMEOUT = 10

# Make sure the script dies immediately when hitting control-C, rather than
# raising KeyboardInterrupt in the main thread, which would leave the
# (non-daemon) server threads running.
import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)
class DnsHandler(sserver.BaseRequestHandler):
    """Handle one UDP DNS request by proxying it to the matching testenv DC.

    Names under a configured realm (``self.server.realm_to_ip_mappings``)
    are forwarded to that realm's PDC; a few name suffixes used by the
    tests get canned behaviour instead (see ``forwarder``).
    """

    # value -> constant-name reverse lookup, used only for log messages
    dns_qtype_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_QTYPE_'))

    def dns_qtype_string(self, qtype):
        """Return a readable qtype code.

        Codes without a DNS_QTYPE_* constant are rendered numerically
        instead of raising: this is only used for logging, and a KeyError
        here would stop the reply from being sent to the client.
        """
        return self.dns_qtype_strings.get(qtype, "UNKNOWN_QTYPE(%s)" % qtype)

    # value -> constant-name reverse lookup, used only for log messages
    dns_rcode_strings = dict((v, k) for k, v in vars(dns).items() if k.startswith('DNS_RCODE_'))

    def dns_rcode_string(self, rcode):
        """Return a readable error code.

        Codes without a DNS_RCODE_* constant are rendered numerically
        instead of raising, for the same reason as dns_qtype_string().
        """
        return self.dns_rcode_strings.get(rcode, "UNKNOWN_RCODE(%s)" % rcode)

    def dns_transaction_udp(self, packet, host):
        """Send a DNS query to ``host`` port 53 over UDP and read the reply.

        :param packet: the dns.name_packet to forward.
        :param host: numeric address of the server to query.
        :return: the reply, unpacked as a dns.name_packet.
        :raises OSError: (socket.error) if sending or receiving fails,
            including a timeout after DNS_REQUEST_TIMEOUT seconds.
        """
        flags = socket.AddressInfo.AI_NUMERICHOST
        flags |= socket.AddressInfo.AI_NUMERICSERV
        flags |= socket.AddressInfo.AI_PASSIVE
        addr_info = socket.getaddrinfo(host, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, socktype, _, _, sockaddr = addr_info[0]
        send_packet = ndr.ndr_pack(packet)
        try:
            with socket.socket(family, socktype, 0) as s:
                s.settimeout(DNS_REQUEST_TIMEOUT)
                s.connect(sockaddr)
                s.sendall(send_packet, 0)
                # A datagram larger than the buffer is silently truncated,
                # so read up to the maximum UDP payload size.
                recv_packet = s.recv(65535, 0)
        except socket.error as err:
            # Log the exception itself: err.errno is None for timeouts,
            # which made the previous message useless in the common case.
            print("Error sending to host %s for name %s: %s\n" %
                  (host, packet.questions[0].name, err))
            raise
        return ndr.ndr_unpack(dns.name_packet, recv_packet)

    def get_pdc_ipv4_addr(self, lookup_name):
        """Maps a DNS realm to the IPv4 address of the PDC for that testenv.

        :return: the address mapped to the longest realm that is a suffix
            of ``lookup_name``, or None if no realm matches.
        """
        realm_to_ip_mappings = self.server.realm_to_ip_mappings

        # try the longest realms first so the most specific one wins
        for realm in sorted(realm_to_ip_mappings, key=len, reverse=True):
            if lookup_name.endswith(realm):
                # return the corresponding IP address for this realm's PDC
                return realm_to_ip_mappings[realm]

        return None

    def forwarder(self, name):
        """Decide how handle() should answer a query for ``name``.

        :return: 'ignore' (send no reply), 'fail' (reply SERVFAIL),
            'torture' (reply NXDOMAIN), the address of the PDC to forward
            to, or None when no realm matches (reply NXDOMAIN).
        """
        lname = name.lower()

        # check for special cases used by tests (e.g. dns_forwarder.py)
        if lname.endswith('an-address-that-will-not-resolve'):
            return 'ignore'
        if lname.endswith('dsfsdfs'):
            return 'fail'
        if lname.endswith("torture1", 0, len(lname)-2):
            # "torture1" followed by exactly two more characters:
            # CATCH TORTURE100, TORTURE101, ...
            return 'torture'
        if lname.endswith('_none_.example.com'):
            return 'torture'
        if lname.endswith('torturedom.samba.example.com'):
            return 'torture'

        # return the testenv PDC matching the realm being requested
        return self.get_pdc_ipv4_addr(lname)

    def _error_reply(self, query, rcode):
        """Turn ``query`` into a reply carrying ``rcode``.

        The packet is modified in place and returned.
        """
        query.operation |= dns.DNS_FLAG_REPLY
        query.operation |= dns.DNS_FLAG_RECURSION_AVAIL
        query.operation |= rcode
        return query

    def handle(self):
        """Answer one request: forward it, fake an error, or drop it."""
        start = time.monotonic()
        data, sock = self.request
        query = ndr.ndr_unpack(dns.name_packet, data)
        name = query.questions[0].name
        forwarder = self.forwarder(name)
        response = None

        if forwarder == 'ignore':
            # deliberately send nothing back
            return
        elif forwarder == 'fail':
            # leave response unset so a SERVFAIL reply is built below
            pass
        elif forwarder in ['torture', None]:
            response = self._error_reply(query, dns.DNS_RCODE_NXDOMAIN)
        else:
            try:
                response = self.dns_transaction_udp(query, forwarder)
            except OSError as err:
                print("dns_hub: Error sending dns query to forwarder[%s] for name[%s]: %s" %
                      (forwarder, name, err))

        if response is None:
            response = self._error_reply(query, dns.DNS_RCODE_SERVFAIL)

        send_packet = ndr.ndr_pack(response)

        end = time.monotonic()
        tdiff = end - start
        errcode = response.operation & dns.DNS_RCODE
        # only log requests that took suspiciously long
        if tdiff > (DNS_REQUEST_TIMEOUT/5):
            print("dns_hub: forwarder[%s] client[%s] name[%s][%s] %s response.operation[0x%x] tdiff[%s]\n" %
                  (forwarder, self.client_address, name,
                   self.dns_qtype_string(query.questions[0].question_type),
                   self.dns_rcode_string(errcode), response.operation, tdiff))

        try:
            sock.sendto(send_packet, self.client_address)
        except socket.error as err:
            print("dns_hub: Error sending response to client[%s] for name[%s] tdiff[%s]: %s\n" %
                  (self.client_address, name, tdiff, err))
class server_thread(threading.Thread):
    """Thread that runs one server's serve_forever() loop.

    The server is stored as ``self.server``; stop() ends the loop and then
    closes the server.
    """

    def __init__(self, server, name):
        super().__init__(name=name)
        self.server = server

    def run(self):
        """Serve requests until another thread calls stop()."""
        srv = self.server
        print("dns_hub[%s]: before serve_forever()" % self.name)
        srv.serve_forever()
        print("dns_hub[%s]: after serve_forever()" % self.name)

    def stop(self):
        """Shut the server loop down, then release the server socket."""
        srv = self.server
        print("dns_hub[%s]: before shutdown()" % self.name)
        srv.shutdown()
        print("dns_hub[%s]: after shutdown()" % self.name)
        srv.server_close()
class UDPV4Server(sserver.UDPServer):
    """UDP server whose listening socket uses the IPv4 address family."""
    address_family = socket.AddressFamily.AF_INET
class UDPV6Server(sserver.UDPServer):
    """UDP server whose listening socket uses the IPv6 address family."""
    address_family = socket.AddressFamily.AF_INET6
def main():
    """Start one UDP DNS proxy thread per listen address and run them.

    Arguments: TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]
    where TIMEOUT is in seconds and each MAPPING is written key=value
    (realm=IP).  The servers run until something arrives on stdin
    (including EOF) or the timeout expires; then they are shut down and
    the process exits with status 0.
    """
    if len(sys.argv) < 4:
        print("Usage: dns_hub.py TIMEOUT LISTENADDRESS[,LISTENADDRESS,...] MAPPING[,MAPPING,...]")
        sys.exit(1)

    timeout_arg, addresses_arg, mappings_arg = sys.argv[1:4]

    # poll() takes milliseconds as a 32-bit int, so clamp to INT_MAX
    poll_timeout_ms = min(int(timeout_arg) * 1000, 2**31 - 1)
    listenaddresses = addresses_arg.split(',')
    # "realm=ip,realm=ip,..." -> ordered {realm: ip} for DnsHandler
    realm_mappings = collections.OrderedDict(
        pair.split('=') for pair in mappings_arg.split(','))

    def prepare_server_thread(listenaddress, realm_mappings):
        # Resolve the numeric listen address; its family picks the server class.
        flags = (socket.AddressInfo.AI_NUMERICHOST |
                 socket.AddressInfo.AI_NUMERICSERV |
                 socket.AddressInfo.AI_PASSIVE)
        addr_info = socket.getaddrinfo(listenaddress, 53,
                                       type=socket.SocketKind.SOCK_DGRAM,
                                       flags=flags)
        assert len(addr_info) == 1
        family, _, _, _, sockaddr = addr_info[0]
        if family == socket.AddressFamily.AF_INET6:
            server_cls = UDPV6Server
        else:
            server_cls = UDPV4Server
        server = server_cls(sockaddr, DnsHandler)

        # DnsHandler reads the realm table through self.server
        server.realm_to_ip_mappings = realm_mappings
        return server_thread(server, name="UDP[%s]" % listenaddress)

    print("dns_hub will proxy DNS requests for the following realms:")
    for realm, ip in realm_mappings.items():
        print(" {0} ==> {1}".format(realm, ip))

    print("dns_hub will listen on the following UDP addresses:")
    threads = []
    for listenaddress in listenaddresses:
        print(" %s" % listenaddress)
        threads.append(prepare_server_thread(listenaddress, realm_mappings))

    for worker in threads:
        worker.start()

    # Block until stdin becomes readable (data or EOF) or the timeout expires.
    poller = select.poll()
    poller.register(sys.stdin.fileno(), select.POLLIN)
    poller.poll(poll_timeout_ms)
    print("dns_hub: after poll()")

    for worker in threads:
        worker.stop()
    for worker in threads:
        worker.join()
    print("dns_hub: before exit()")
    sys.exit(0)
# NOTE(review): main() runs unconditionally, including when this module is
# imported rather than executed; consider an `if __name__ == "__main__":`
# guard if it is ever imported.
main()