selftest: Use TestCaseInTempDir as base class in dns tests
[Samba.git] / python / samba / tests / dns_base.py
blob2a40d999c36f4b01a0172789799d9314fbef472d
1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme <slow@samba.org> 2016
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from samba.tests import TestCaseInTempDir
20 from samba.dcerpc import dns, dnsp
21 from samba import gensec, tests
22 from samba import credentials
23 import struct
24 import samba.ndr as ndr
25 import random
26 import socket
27 import uuid
28 import time
class DNSTest(TestCaseInTempDir):
    """Shared helpers for DNS tests: packet building, transport and asserts.

    Some helpers read attributes that this class never assigns:
    ``self.creds`` (see get_dns_domain) and ``self.server_ip`` (see
    check_query_txt).  They have to be provided by the concrete test class.
    ``self.hexdump`` is not defined here either; presumably it is inherited
    from the base test case -- confirm before relying on ``dump=True``.
    """

    # Printable names for RCODE values, indexed by the numeric code.
    # Codes without a mnemonic are shown as their hex value.
    _RCODE_NAMES = (
        "OK",
        "FORMERR",
        "SERVFAIL",
        "NXDOMAIN",
        "NOTIMP",
        "REFUSED",
        "YXDOMAIN",
        "YXRRSET",
        "NXRRSET",
        "NOTAUTH",
        "NOTZONE",
        "0x0B",
        "0x0C",
        "0x0D",
        "0x0E",
        "0x0F",
        "BADSIG",
        "BADKEY",
    )

    def setUp(self):
        super(DNSTest, self).setUp()
        # Default socket timeout used by the dns_transaction_* helpers when
        # the caller does not pass one; None leaves the socket blocking.
        self.timeout = None

    def errstr(self, errcode):
        """Return a readable name for a numeric error code.

        Codes outside the known table (including negative values) are
        rendered as ``UNKNOWN(<code>)`` instead of raising IndexError or
        wrapping around to the end of the table, so that building an
        assertion message can never hide the real failure.
        """
        if 0 <= errcode < len(self._RCODE_NAMES):
            return self._RCODE_NAMES[errcode]
        return "UNKNOWN(%r)" % (errcode,)

    def assert_rcode_equals(self, rcode, expected):
        "Helper function to check return code"
        self.assertEqual(rcode, expected, "Expected RCODE %s, got %s" %
                         (self.errstr(expected), self.errstr(rcode)))

    def assert_dns_rcode_equals(self, packet, rcode):
        "Helper function to check the return code carried by a packet"
        # The RCODE is the low four bits of the operation field.
        p_errcode = packet.operation & 0x000F
        self.assertEqual(p_errcode, rcode, "Expected RCODE %s, got %s" %
                         (self.errstr(rcode), self.errstr(p_errcode)))

    def assert_dns_opcode_equals(self, packet, opcode):
        "Helper function to check opcode"
        # The comparison is done on the masked, unshifted operation bits,
        # so 'opcode' must be given in the same (unshifted) form.
        p_opcode = packet.operation & 0x7800
        self.assertEqual(p_opcode, opcode, "Expected OPCODE %s, got %s" %
                         (opcode, p_opcode))

    def make_name_packet(self, opcode, qid=None):
        """Helper creating a dns.name_packet.

        opcode -- value stored in the packet's operation field
        qid    -- query id to use; a random id is chosen when it is None
        """
        p = dns.name_packet()
        if qid is None:
            p.id = random.randint(0x0, 0xff00)
        else:
            # An explicitly requested id used to be dropped silently.
            p.id = qid
        p.operation = opcode
        p.questions = []
        p.additional = []
        return p

    def finish_name_packet(self, packet, questions):
        "Helper to finalize a dns.name_packet"
        packet.qdcount = len(questions)
        packet.questions = questions

    def make_name_question(self, name, qtype, qclass):
        "Helper creating a dns.name_question"
        q = dns.name_question()
        q.name = name
        q.question_type = qtype
        q.question_class = qclass
        return q

    def make_txt_record(self, records):
        "Helper wrapping a list of strings into a dns.txt_record"
        s_list = dnsp.string_list()
        s_list.count = len(records)
        s_list.str = records
        rdata_txt = dns.txt_record()
        rdata_txt.txt = s_list
        return rdata_txt

    def get_dns_domain(self):
        "Helper to get dns domain (lower-cased realm of self.creds)"
        return self.creds.get_realm().lower()

    def dns_transaction_udp(self, packet, host,
                            dump=False, timeout=None):
        """Send a DNS query over UDP and read the reply.

        packet  -- dns.name_packet to send
        host    -- server address; port 53 is used
        dump    -- print a hexdump of the request and the reply
        timeout -- socket timeout; defaults to self.timeout

        Returns a (parsed response, raw response bytes) tuple.
        """
        s = None
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            s.sendall(send_packet, 0)
            recv_packet = s.recv(2048, 0)
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet)
            return (response, recv_packet)
        finally:
            if s is not None:
                s.close()

    def _recv_tcp_message(self, s):
        """Read one length-prefixed DNS message from a connected TCP socket.

        Returns everything received, including the two byte length prefix.
        Reading stops early if the peer closes the connection, in which
        case the returned data is shorter than announced.
        """
        data = b""
        wanted = 2  # until the prefix is complete only the prefix is needed
        while len(data) < wanted:
            chunk = s.recv(wanted - len(data), 0)
            if not chunk:
                break
            data += chunk
            if wanted == 2 and len(data) >= 2:
                wanted = 2 + struct.unpack('!H', data[:2])[0]
        return data

    def dns_transaction_tcp(self, packet, host,
                            dump=False, timeout=None):
        """Send a DNS query over TCP and read the reply.

        packet  -- dns.name_packet to send
        host    -- server address; port 53 is used
        dump    -- print a hexdump of the request and the reply
        timeout -- socket timeout; defaults to self.timeout

        Returns a (parsed response, raw response bytes) tuple; the raw
        bytes do not include the two byte TCP length prefix.  The helper
        also asserts that re-packing the parsed response reproduces the
        received bytes.
        """
        s = None
        if timeout is None:
            timeout = self.timeout
        try:
            send_packet = ndr.ndr_pack(packet)
            if dump:
                print(self.hexdump(send_packet))
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
            s.settimeout(timeout)
            s.connect((host, 53))
            tcp_packet = struct.pack('!H', len(send_packet))
            tcp_packet += send_packet
            s.sendall(tcp_packet)

            # A single recv() may return only part of the message on a
            # stream socket, so keep reading until the announced length
            # has arrived.
            recv_packet = self._recv_tcp_message(s)
            if dump:
                print(self.hexdump(recv_packet))
            response = ndr.ndr_unpack(dns.name_packet, recv_packet[2:])

        finally:
            if s is not None:
                s.close()

        # unpacking and packing again should produce same bytestream
        my_packet = ndr.ndr_pack(response)
        self.assertEqual(my_packet, recv_packet[2:])

        return (response, recv_packet[2:])

    def make_txt_update(self, prefix, txt_array):
        """Build an UPDATE packet adding a TXT record below the DNS domain.

        prefix    -- leftmost label(s); the record name is prefix.<domain>
        txt_array -- list of strings for the TXT record
        """
        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        updates = []

        name = self.get_dns_domain()
        u = self.make_name_question(name, dns.DNS_QTYPE_SOA, dns.DNS_QCLASS_IN)
        updates.append(u)
        self.finish_name_packet(p, updates)

        updates = []
        r = dns.res_rec()
        r.name = "%s.%s" % (prefix, self.get_dns_domain())
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 900
        r.length = 0xffff
        rdata = self.make_txt_record(txt_array)
        r.rdata = rdata
        updates.append(r)
        p.nscount = len(updates)
        p.nsrecs = updates

        return p

    def check_query_txt(self, prefix, txt_array):
        """Query prefix.<domain> for TXT over UDP and check the answer.

        Asserts an OK rcode, exactly one answer, and that the answer's
        strings equal txt_array.  Uses self.server_ip as the server.
        """
        name = "%s.%s" % (prefix, self.get_dns_domain())
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        questions.append(q)

        self.finish_name_packet(p, questions)
        (response, response_packet) = self.dns_transaction_udp(p, host=self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)
        self.assertEqual(response.ancount, 1)
        self.assertEqual(response.answers[0].rdata.txt.str, txt_array)
class DNSTKeyTest(DNSTest):
    """Helpers for tests that use TKEY negotiation and TSIG signed packets.

    setUp reads ``self.server`` and the transaction helpers read
    ``self.server_ip``; neither is assigned in this class, so the concrete
    test class presumably has to set both before calling setUp -- confirm
    against the callers.
    """

    def setUp(self):
        super(DNSTKeyTest, self).setUp()
        self.settings = {}
        self.settings["lp_ctx"] = self.lp_ctx = tests.env_loadparm()
        self.settings["target_hostname"] = self.server

        self.creds = credentials.Credentials()
        self.creds.guess(self.lp_ctx)
        self.creds.set_username(tests.env_get_var_value('USERNAME'))
        self.creds.set_password(tests.env_get_var_value('PASSWORD'))
        self.creds.set_kerberos_state(credentials.MUST_USE_KERBEROS)
        # Name of the record created/deleted by make_update_request().
        self.newrecname = "tkeytsig.%s" % self.get_dns_domain()

    def tkey_trans(self, creds=None):
        """Do a TKEY transaction and establish a gensec context.

        creds -- credentials to authenticate with; defaults to self.creds

        On return self.key_name holds the freshly generated key name and
        self.g the gensec context.  The TSIG of the server's TKEY response
        has been verified.
        """
        if creds is None:
            creds = self.creds

        self.key_name = "%s.%s" % (uuid.uuid4(), self.get_dns_domain())

        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        q = self.make_name_question(self.key_name,
                                    dns.DNS_QTYPE_TKEY,
                                    dns.DNS_QCLASS_IN)
        self.finish_name_packet(p, [q])

        r = dns.res_rec()
        r.name = self.key_name
        r.rr_type = dns.DNS_QTYPE_TKEY
        r.rr_class = dns.DNS_QCLASS_IN
        r.ttl = 0
        r.length = 0xffff
        rdata = dns.tkey_record()
        rdata.algorithm = "gss-tsig"
        # Take the clock once so inception and expiration are exactly one
        # hour apart.
        now = int(time.time())
        rdata.inception = now
        rdata.expiration = now + 60*60
        rdata.mode = dns.DNS_TKEY_MODE_GSSAPI
        rdata.error = 0
        rdata.other_size = 0

        self.g = gensec.Security.start_client(self.settings)
        self.g.set_credentials(creds)
        self.g.set_target_service("dns")
        self.g.set_target_hostname(self.server)
        self.g.want_feature(gensec.FEATURE_SIGN)
        self.g.start_mech_by_name("spnego")

        # First leg: an empty input yields the initial token, which is
        # sent to the server as the TKEY key data.
        (finished, token) = self.g.update("")
        self.assertFalse(finished)

        data = [ord(x) for x in token]
        rdata.key_data = data
        rdata.key_size = len(data)
        r.rdata = rdata

        p.arcount = 1
        p.additional = [r]

        (response, response_packet) = self.dns_transaction_tcp(p, self.server_ip)
        self.assert_dns_rcode_equals(response, dns.DNS_RCODE_OK)

        # Second leg: feed the server's token back; the exchange is
        # expected to be complete afterwards.
        tkey_record = response.answers[0].rdata
        server_to_client = ''.join([chr(x) for x in tkey_record.key_data])
        (finished, client_to_server) = self.g.update(server_to_client)
        self.assertTrue(finished)

        self.verify_packet(response, response_packet)

    def verify_packet(self, response, response_packet, request_mac=""):
        """Check the TSIG MAC of a response with the gensec context.

        response        -- parsed response; its first additional record
                           must be the TSIG record
        response_packet -- raw bytes of the same response
        request_mac     -- MAC of the request this answers, prepended to
                           the data that is verified
        """
        self.assertEqual(response.additional[0].rr_type, dns.DNS_QTYPE_TSIG)

        tsig_record = response.additional[0].rdata
        mac = ''.join([chr(x) for x in tsig_record.mac])

        # Cut off tsig record from dns response packet for MAC verification
        # and reset additional record count.
        # NOTE(review): the "+ 2" and "+ 10" presumably account for the
        # wire encoding overhead of the name and the fixed RR header
        # fields -- confirm.
        key_name_len = len(self.key_name) + 2
        tsig_record_len = len(ndr.ndr_pack(tsig_record)) + key_name_len + 10

        response_packet_list = list(response_packet)
        del response_packet_list[-tsig_record_len:]
        response_packet_list[11] = chr(0)
        response_packet_wo_tsig = ''.join(response_packet_list)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = self.key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = tsig_record.time_prefix
        fake_tsig.time = tsig_record.time
        fake_tsig.algorithm_name = tsig_record.algorithm_name
        fake_tsig.fudge = tsig_record.fudge
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = request_mac + response_packet_wo_tsig + fake_tsig_packet
        self.g.check_packet(data, data, mac)

    def _attach_tsig(self, packet, key_name, mac_list, timestamp):
        """Make a TSIG record the only additional record of 'packet'.

        mac_list  -- MAC as a list of integer byte values
        timestamp -- value for the record's time field
        """
        rdata = dns.tsig_record()
        rdata.algorithm_name = "gss-tsig"
        rdata.time_prefix = 0
        rdata.time = timestamp
        rdata.fudge = 300
        rdata.original_id = packet.id
        rdata.error = 0
        rdata.other_size = 0
        rdata.mac = mac_list
        rdata.mac_size = len(mac_list)

        r = dns.res_rec()
        r.name = key_name
        r.rr_type = dns.DNS_QTYPE_TSIG
        r.rr_class = dns.DNS_QCLASS_ANY
        r.ttl = 0
        r.length = 0xffff
        r.rdata = rdata

        packet.additional = [r]
        packet.arcount = 1

    def sign_packet(self, packet, key_name):
        """Sign a packet, calculate a MAC and add TSIG record.

        The MAC is computed with self.g, so tkey_trans() must have been
        called before.  Returns the MAC as returned by gensec.
        """
        # Pack before the TSIG record is attached: the MAC covers the
        # packet without it.
        packet_data = ndr.ndr_pack(packet)

        fake_tsig = dns.fake_tsig_rec()
        fake_tsig.name = key_name
        fake_tsig.rr_class = dns.DNS_QCLASS_ANY
        fake_tsig.ttl = 0
        fake_tsig.time_prefix = 0
        fake_tsig.time = int(time.time())
        fake_tsig.algorithm_name = "gss-tsig"
        fake_tsig.fudge = 300
        fake_tsig.error = 0
        fake_tsig.other_size = 0
        fake_tsig_packet = ndr.ndr_pack(fake_tsig)

        data = packet_data + fake_tsig_packet
        mac = self.g.sign_packet(data, data)
        mac_list = [ord(x) for x in mac]

        # The real record must carry the same timestamp that was signed.
        self._attach_tsig(packet, key_name, mac_list, fake_tsig.time)

        return mac

    def bad_sign_packet(self, packet, key_name):
        '''Add a TSIG record with an invalid signature to a packet.

        The MAC is not derived from the packet at all; it is the constant
        byte string "badmac".'''
        mac_list = [ord(x) for x in "badmac"]
        self._attach_tsig(packet, key_name, mac_list, int(time.time()))

    def search_record(self, name):
        "Query 'name' for TXT over UDP and return the response's RCODE"
        p = self.make_name_packet(dns.DNS_OPCODE_QUERY)
        questions = []

        q = self.make_name_question(name, dns.DNS_QTYPE_TXT, dns.DNS_QCLASS_IN)
        questions.append(q)

        self.finish_name_packet(p, questions)
        (response, response_packet) = self.dns_transaction_udp(p, self.server_ip)
        return response.operation & 0x000F

    def make_update_request(self, delete=False):
        """Create a DNS update request for the TXT record self.newrecname.

        delete -- build a request with class NONE and ttl 0 instead of
                  class IN and ttl 900
        """
        rr_class = dns.DNS_QCLASS_IN
        ttl = 900

        if delete:
            rr_class = dns.DNS_QCLASS_NONE
            ttl = 0

        p = self.make_name_packet(dns.DNS_OPCODE_UPDATE)
        q = self.make_name_question(self.get_dns_domain(),
                                    dns.DNS_QTYPE_SOA,
                                    dns.DNS_QCLASS_IN)
        questions = []
        questions.append(q)
        self.finish_name_packet(p, questions)

        updates = []
        r = dns.res_rec()
        r.name = self.newrecname
        r.rr_type = dns.DNS_QTYPE_TXT
        r.rr_class = rr_class
        r.ttl = ttl
        r.length = 0xffff
        rdata = self.make_txt_record(['"This is a test"'])
        r.rdata = rdata
        updates.append(r)
        p.nscount = len(updates)
        p.nsrecs = updates

        return p