1 # Unix SMB/CIFS implementation.
2 # Copyright (C) Kai Blin <kai@samba.org> 2011
3 # Copyright (C) Ralph Boehme <slow@samba.org> 2016
5 # This program is free software; you can redistribute it and/or modify
6 # it under the terms of the GNU General Public License as published by
7 # the Free Software Foundation; either version 3 of the License, or
8 # (at your option) any later version.
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
15 # You should have received a copy of the GNU General Public License
16 # along with this program. If not, see <http://www.gnu.org/licenses/>.
19 from samba
.tests
import TestCaseInTempDir
20 from samba
.dcerpc
import dns
, dnsp
21 from samba
import gensec
, tests
22 from samba
import credentials
24 import samba
.ndr
as ndr
30 class DNSTest(TestCaseInTempDir
):
33 super(DNSTest
, self
).setUp()
36 def errstr(self
, errcode
):
37 "Return a readable error code"
59 return string_codes
[errcode
]
61 def assert_rcode_equals(self
, rcode
, expected
):
62 "Helper function to check return code"
63 self
.assertEquals(rcode
, expected
, "Expected RCODE %s, got %s" %
64 (self
.errstr(expected
), self
.errstr(rcode
)))
66 def assert_dns_rcode_equals(self
, packet
, rcode
):
67 "Helper function to check return code"
68 p_errcode
= packet
.operation
& 0x000F
69 self
.assertEquals(p_errcode
, rcode
, "Expected RCODE %s, got %s" %
70 (self
.errstr(rcode
), self
.errstr(p_errcode
)))
72 def assert_dns_opcode_equals(self
, packet
, opcode
):
73 "Helper function to check opcode"
74 p_opcode
= packet
.operation
& 0x7800
75 self
.assertEquals(p_opcode
, opcode
, "Expected OPCODE %s, got %s" %
78 def make_name_packet(self
, opcode
, qid
=None):
79 "Helper creating a dns.name_packet"
82 p
.id = random
.randint(0x0, 0xff00)
88 def finish_name_packet(self
, packet
, questions
):
89 "Helper to finalize a dns.name_packet"
90 packet
.qdcount
= len(questions
)
91 packet
.questions
= questions
93 def make_name_question(self
, name
, qtype
, qclass
):
94 "Helper creating a dns.name_question"
95 q
= dns
.name_question()
97 q
.question_type
= qtype
98 q
.question_class
= qclass
101 def make_txt_record(self
, records
):
102 rdata_txt
= dns
.txt_record()
103 s_list
= dnsp
.string_list()
104 s_list
.count
= len(records
)
106 rdata_txt
.txt
= s_list
109 def get_dns_domain(self
):
110 "Helper to get dns domain"
111 return self
.creds
.get_realm().lower()
113 def dns_transaction_udp(self
, packet
, host
,
114 dump
=False, timeout
=None):
115 "send a DNS query and read the reply"
118 timeout
= self
.timeout
120 send_packet
= ndr
.ndr_pack(packet
)
122 print self
.hexdump(send_packet
)
123 s
= socket
.socket(socket
.AF_INET
, socket
.SOCK_DGRAM
, 0)
124 s
.settimeout(timeout
)
125 s
.connect((host
, 53))
126 s
.sendall(send_packet
, 0)
127 recv_packet
= s
.recv(2048, 0)
129 print self
.hexdump(recv_packet
)
130 response
= ndr
.ndr_unpack(dns
.name_packet
, recv_packet
)
131 return (response
, recv_packet
)
136 def dns_transaction_tcp(self
, packet
, host
,
137 dump
=False, timeout
=None):
138 "send a DNS query and read the reply, also return the raw packet"
141 timeout
= self
.timeout
143 send_packet
= ndr
.ndr_pack(packet
)
145 print self
.hexdump(send_packet
)
146 s
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
, 0)
147 s
.settimeout(timeout
)
148 s
.connect((host
, 53))
149 tcp_packet
= struct
.pack('!H', len(send_packet
))
150 tcp_packet
+= send_packet
151 s
.sendall(tcp_packet
)
153 recv_packet
= s
.recv(0xffff + 2, 0)
155 print self
.hexdump(recv_packet
)
156 response
= ndr
.ndr_unpack(dns
.name_packet
, recv_packet
[2:])
162 # unpacking and packing again should produce same bytestream
163 my_packet
= ndr
.ndr_pack(response
)
164 self
.assertEquals(my_packet
, recv_packet
[2:])
166 return (response
, recv_packet
[2:])
168 def make_txt_update(self
, prefix
, txt_array
):
169 p
= self
.make_name_packet(dns
.DNS_OPCODE_UPDATE
)
172 name
= self
.get_dns_domain()
173 u
= self
.make_name_question(name
, dns
.DNS_QTYPE_SOA
, dns
.DNS_QCLASS_IN
)
175 self
.finish_name_packet(p
, updates
)
179 r
.name
= "%s.%s" % (prefix
, self
.get_dns_domain())
180 r
.rr_type
= dns
.DNS_QTYPE_TXT
181 r
.rr_class
= dns
.DNS_QCLASS_IN
184 rdata
= self
.make_txt_record(txt_array
)
187 p
.nscount
= len(updates
)
192 def check_query_txt(self
, prefix
, txt_array
):
193 name
= "%s.%s" % (prefix
, self
.get_dns_domain())
194 p
= self
.make_name_packet(dns
.DNS_OPCODE_QUERY
)
197 q
= self
.make_name_question(name
, dns
.DNS_QTYPE_TXT
, dns
.DNS_QCLASS_IN
)
200 self
.finish_name_packet(p
, questions
)
201 (response
, response_packet
) = self
.dns_transaction_udp(p
, host
=self
.server_ip
)
202 self
.assert_dns_rcode_equals(response
, dns
.DNS_RCODE_OK
)
203 self
.assertEquals(response
.ancount
, 1)
204 self
.assertEquals(response
.answers
[0].rdata
.txt
.str, txt_array
)
207 class DNSTKeyTest(DNSTest
):
209 super(DNSTKeyTest
, self
).setUp()
211 self
.settings
["lp_ctx"] = self
.lp_ctx
= tests
.env_loadparm()
212 self
.settings
["target_hostname"] = self
.server
214 self
.creds
= credentials
.Credentials()
215 self
.creds
.guess(self
.lp_ctx
)
216 self
.creds
.set_username(tests
.env_get_var_value('USERNAME'))
217 self
.creds
.set_password(tests
.env_get_var_value('PASSWORD'))
218 self
.creds
.set_kerberos_state(credentials
.MUST_USE_KERBEROS
)
219 self
.newrecname
= "tkeytsig.%s" % self
.get_dns_domain()
221 def tkey_trans(self
, creds
=None):
222 "Do a TKEY transaction and establish a gensec context"
227 self
.key_name
= "%s.%s" % (uuid
.uuid4(), self
.get_dns_domain())
229 p
= self
.make_name_packet(dns
.DNS_OPCODE_QUERY
)
230 q
= self
.make_name_question(self
.key_name
,
235 self
.finish_name_packet(p
, questions
)
238 r
.name
= self
.key_name
239 r
.rr_type
= dns
.DNS_QTYPE_TKEY
240 r
.rr_class
= dns
.DNS_QCLASS_IN
243 rdata
= dns
.tkey_record()
244 rdata
.algorithm
= "gss-tsig"
245 rdata
.inception
= int(time
.time())
246 rdata
.expiration
= int(time
.time()) + 60*60
247 rdata
.mode
= dns
.DNS_TKEY_MODE_GSSAPI
251 self
.g
= gensec
.Security
.start_client(self
.settings
)
252 self
.g
.set_credentials(creds
)
253 self
.g
.set_target_service("dns")
254 self
.g
.set_target_hostname(self
.server
)
255 self
.g
.want_feature(gensec
.FEATURE_SIGN
)
256 self
.g
.start_mech_by_name("spnego")
259 client_to_server
= ""
261 (finished
, server_to_client
) = self
.g
.update(client_to_server
)
262 self
.assertFalse(finished
)
264 data
= [ord(x
) for x
in list(server_to_client
)]
265 rdata
.key_data
= data
266 rdata
.key_size
= len(data
)
271 p
.additional
= additional
273 (response
, response_packet
) = self
.dns_transaction_tcp(p
, self
.server_ip
)
274 self
.assert_dns_rcode_equals(response
, dns
.DNS_RCODE_OK
)
276 tkey_record
= response
.answers
[0].rdata
277 data
= [chr(x
) for x
in tkey_record
.key_data
]
278 server_to_client
= ''.join(data
)
279 (finished
, client_to_server
) = self
.g
.update(server_to_client
)
280 self
.assertTrue(finished
)
282 self
.verify_packet(response
, response_packet
)
284 def verify_packet(self
, response
, response_packet
, request_mac
=""):
285 self
.assertEqual(response
.additional
[0].rr_type
, dns
.DNS_QTYPE_TSIG
)
287 tsig_record
= response
.additional
[0].rdata
288 mac
= ''.join([chr(x
) for x
in tsig_record
.mac
])
290 # Cut off tsig record from dns response packet for MAC verification
291 # and reset additional record count.
292 key_name_len
= len(self
.key_name
) + 2
293 tsig_record_len
= len(ndr
.ndr_pack(tsig_record
)) + key_name_len
+ 10
295 response_packet_list
= list(response_packet
)
296 del response_packet_list
[-tsig_record_len
:]
297 response_packet_list
[11] = chr(0)
298 response_packet_wo_tsig
= ''.join(response_packet_list
)
300 fake_tsig
= dns
.fake_tsig_rec()
301 fake_tsig
.name
= self
.key_name
302 fake_tsig
.rr_class
= dns
.DNS_QCLASS_ANY
304 fake_tsig
.time_prefix
= tsig_record
.time_prefix
305 fake_tsig
.time
= tsig_record
.time
306 fake_tsig
.algorithm_name
= tsig_record
.algorithm_name
307 fake_tsig
.fudge
= tsig_record
.fudge
309 fake_tsig
.other_size
= 0
310 fake_tsig_packet
= ndr
.ndr_pack(fake_tsig
)
312 data
= request_mac
+ response_packet_wo_tsig
+ fake_tsig_packet
313 self
.g
.check_packet(data
, data
, mac
)
315 def sign_packet(self
, packet
, key_name
):
316 "Sign a packet, calculate a MAC and add TSIG record"
317 packet_data
= ndr
.ndr_pack(packet
)
319 fake_tsig
= dns
.fake_tsig_rec()
320 fake_tsig
.name
= key_name
321 fake_tsig
.rr_class
= dns
.DNS_QCLASS_ANY
323 fake_tsig
.time_prefix
= 0
324 fake_tsig
.time
= int(time
.time())
325 fake_tsig
.algorithm_name
= "gss-tsig"
326 fake_tsig
.fudge
= 300
328 fake_tsig
.other_size
= 0
329 fake_tsig_packet
= ndr
.ndr_pack(fake_tsig
)
331 data
= packet_data
+ fake_tsig_packet
332 mac
= self
.g
.sign_packet(data
, data
)
333 mac_list
= [ord(x
) for x
in list(mac
)]
335 rdata
= dns
.tsig_record()
336 rdata
.algorithm_name
= "gss-tsig"
337 rdata
.time_prefix
= 0
338 rdata
.time
= fake_tsig
.time
340 rdata
.original_id
= packet
.id
344 rdata
.mac_size
= len(mac_list
)
348 r
.rr_type
= dns
.DNS_QTYPE_TSIG
349 r
.rr_class
= dns
.DNS_QCLASS_ANY
355 packet
.additional
= additional
360 def bad_sign_packet(self
, packet
, key_name
):
361 '''Add bad signature for a packet by bitflipping
362 the final byte in the MAC'''
364 mac_list
= [ord(x
) for x
in list("badmac")]
366 rdata
= dns
.tsig_record()
367 rdata
.algorithm_name
= "gss-tsig"
368 rdata
.time_prefix
= 0
369 rdata
.time
= int(time
.time())
371 rdata
.original_id
= packet
.id
375 rdata
.mac_size
= len(mac_list
)
379 r
.rr_type
= dns
.DNS_QTYPE_TSIG
380 r
.rr_class
= dns
.DNS_QCLASS_ANY
386 packet
.additional
= additional
389 def search_record(self
, name
):
390 p
= self
.make_name_packet(dns
.DNS_OPCODE_QUERY
)
393 q
= self
.make_name_question(name
, dns
.DNS_QTYPE_TXT
, dns
.DNS_QCLASS_IN
)
396 self
.finish_name_packet(p
, questions
)
397 (response
, response_packet
) = self
.dns_transaction_udp(p
, self
.server_ip
)
398 return response
.operation
& 0x000F
400 def make_update_request(self
, delete
=False):
401 "Create a DNS update request"
403 rr_class
= dns
.DNS_QCLASS_IN
407 rr_class
= dns
.DNS_QCLASS_NONE
410 p
= self
.make_name_packet(dns
.DNS_OPCODE_UPDATE
)
411 q
= self
.make_name_question(self
.get_dns_domain(),
416 self
.finish_name_packet(p
, questions
)
420 r
.name
= self
.newrecname
421 r
.rr_type
= dns
.DNS_QTYPE_TXT
422 r
.rr_class
= rr_class
425 rdata
= self
.make_txt_record(['"This is a test"'])
428 p
.nscount
= len(updates
)