s3:smb2_lock: let smbd_smb2_lock_cancel() trigger NT_STATUS_CANCELED
[Samba.git] / lib / dnspython / dns / query.py
blobaddee4e3f2de67adfb4f6c2953e7b37cc5aa51a9
1 # Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
3 # Permission to use, copy, modify, and distribute this software and its
4 # documentation for any purpose with or without fee is hereby granted,
5 # provided that the above copyright notice and this permission notice
6 # appear in all copies.
8 # THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
9 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
11 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
14 # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 """Talk to a DNS server."""
18 from __future__ import generators
20 import errno
21 import select
22 import socket
23 import struct
24 import sys
25 import time
27 import dns.exception
28 import dns.inet
29 import dns.name
30 import dns.message
31 import dns.rdataclass
32 import dns.rdatatype
class UnexpectedSource(dns.exception.DNSException):
    """A reply arrived from an address or port other than the one queried."""
class BadResponse(dns.exception.FormError):
    """The reply received is not a response to the query that was sent."""
42 def _compute_expiration(timeout):
43 if timeout is None:
44 return None
45 else:
46 return time.time() + timeout
48 def _poll_for(fd, readable, writable, error, timeout):
49 """
50 @param fd: File descriptor (int).
51 @param readable: Whether to wait for readability (bool).
52 @param writable: Whether to wait for writability (bool).
53 @param expiration: Deadline timeout (expiration time, in seconds (float)).
55 @return True on success, False on timeout
56 """
57 event_mask = 0
58 if readable:
59 event_mask |= select.POLLIN
60 if writable:
61 event_mask |= select.POLLOUT
62 if error:
63 event_mask |= select.POLLERR
65 pollable = select.poll()
66 pollable.register(fd, event_mask)
68 if timeout:
69 event_list = pollable.poll(long(timeout * 1000))
70 else:
71 event_list = pollable.poll()
73 return bool(event_list)
75 def _select_for(fd, readable, writable, error, timeout):
76 """
77 @param fd: File descriptor (int).
78 @param readable: Whether to wait for readability (bool).
79 @param writable: Whether to wait for writability (bool).
80 @param expiration: Deadline timeout (expiration time, in seconds (float)).
82 @return True on success, False on timeout
83 """
84 rset, wset, xset = [], [], []
86 if readable:
87 rset = [fd]
88 if writable:
89 wset = [fd]
90 if error:
91 xset = [fd]
93 if timeout is None:
94 (rcount, wcount, xcount) = select.select(rset, wset, xset)
95 else:
96 (rcount, wcount, xcount) = select.select(rset, wset, xset, timeout)
98 return bool((rcount or wcount or xcount))
def _wait_for(fd, readable, writable, error, expiration):
    """Block until fd is ready for a requested event or the deadline passes.

    @param fd: socket or file descriptor handed to the polling backend.
    @param readable: Whether to wait for readability (bool).
    @param writable: Whether to wait for writability (bool).
    @param error: Whether to wait for an error condition (bool).
    @param expiration: Absolute deadline on the time.time() scale, or None
    to wait forever.
    @raises dns.exception.Timeout: the deadline passed before fd was ready.
    """
    while True:
        if expiration is None:
            timeout = None
        else:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        try:
            ready = _polling_backend(fd, readable, writable, error, timeout)
        except select.error:
            # A signal interrupting the wait (EINTR) is not a failure and
            # does not mean fd is ready: loop, recompute the remaining time
            # and wait again.  Any other error propagates unchanged.
            e = sys.exc_info()[1]
            if not e.args or e.args[0] != errno.EINTR:
                raise
            continue
        if not ready:
            raise dns.exception.Timeout
        return
117 def _set_polling_backend(fn):
119 Internal API. Do not use.
121 global _polling_backend
123 _polling_backend = fn
# Choose how _wait_for waits.  poll() is preferred where the platform has
# it: unlike select() it places no limit on the maximum file descriptor
# value and is more efficient for high values.  Otherwise use select().
if not hasattr(select, 'poll'):
    _polling_backend = _select_for
else:
    _polling_backend = _poll_for
def _wait_for_readable(s, expiration):
    """Block until s is readable or reports an error condition.

    @param expiration: absolute deadline (time.time() scale), or None to
    wait forever.
    @raises dns.exception.Timeout: (from _wait_for) the deadline passed.
    """
    _wait_for(s, readable=True, writable=False, error=True,
              expiration=expiration)
def _wait_for_writable(s, expiration):
    """Block until s is writable or reports an error condition.

    @param expiration: absolute deadline (time.time() scale), or None to
    wait forever.
    @raises dns.exception.Timeout: (from _wait_for) the deadline passed.
    """
    _wait_for(s, readable=False, writable=True, error=True,
              expiration=expiration)
def _addresses_equal(af, a1, a2):
    """Compare two socket address tuples of family af.

    The host part (element 0) is compared in binary form, so different
    textual spellings of the same address are equal.  The remaining
    elements (port, and for IPv6 any further fields) must match exactly.
    """
    if dns.inet.inet_pton(af, a1[0]) != dns.inet.inet_pton(af, a2[0]):
        return False
    return a1[1:] == a2[1:]
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        ignore_unexpected=False, one_rr_per_rrset=False):
    """Return the response obtained after sending a query via UDP.

    @param q: the query
    @type q: dns.message.Message
    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.
    @type timeout: float
    @param port: The port to which to send the message. The default is 53.
    @type port: int
    @param af: the address family to use. The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @rtype: dns.message.Message object
    @param source: source address. If source_port is given without a
    source, the wildcard address of the chosen family is used.  If neither
    is given, the socket is not explicitly bound.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param ignore_unexpected: If True, ignore responses from unexpected
    sources. The default is False.
    @type ignore_unexpected: bool
    @param one_rr_per_rrset: Put each RR into its own RRset
    @type one_rr_per_rrset: bool
    @raise ValueError: af is neither AF_INET nor AF_INET6.
    @raise dns.exception.Timeout: no response arrived in time.
    @raise UnexpectedSource: a response came from another address and
    ignore_unexpected is False.
    @raise BadResponse: the response does not answer q.
    """
    wire = q.to_wire()
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        # Honour source_port even when no source address was supplied;
        # previously it was silently ignored in that case.
        if source is not None or source_port != 0:
            if source is None:
                source = '0.0.0.0'
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None or source_port != 0:
            if source is None:
                source = '::'
            source = (source, source_port, 0, 0)
    else:
        # Fail before opening a socket instead of hitting an unbound
        # 'destination' later.
        raise ValueError('unsupported address family %r' % (af,))
    s = socket.socket(af, socket.SOCK_DGRAM, 0)
    try:
        expiration = _compute_expiration(timeout)
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        _wait_for_writable(s, expiration)
        s.sendto(wire, destination)
        while True:
            _wait_for_readable(s, expiration)
            (wire, from_address) = s.recvfrom(65535)
            # Accept the reply from the queried address, or, for a
            # multicast query, from any host using the expected port.
            if _addresses_equal(af, from_address, destination) or \
               (dns.inet.is_multicast(where) and
                from_address[1:] == destination[1:]):
                break
            if not ignore_unexpected:
                raise UnexpectedSource('got a response from '
                                       '%s instead of %s' % (from_address,
                                                              destination))
    finally:
        s.close()
    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset)
    if not q.is_response(r):
        raise BadResponse
    return r
def _net_read(sock, count, expiration):
    """Read exactly count bytes from sock.

    Keeps calling recv() until count bytes have arrived, waiting for
    readability before each call.

    @raise EOFError: the peer closed the connection before count bytes
    arrived.
    @raise dns.exception.Timeout: (from _wait_for_readable) the absolute
    deadline expiration passed first.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        _wait_for_readable(sock, expiration)
        chunk = sock.recv(remaining)
        if chunk == '':
            raise EOFError
        remaining -= len(chunk)
        chunks.append(chunk)
    return ''.join(chunks)
def _net_write(sock, data, expiration):
    """Send all of data on sock, waiting for writability before each send.

    @raise dns.exception.Timeout: (from _wait_for_writable) the absolute
    deadline expiration passed before everything was sent.
    """
    pending = data
    while pending:
        _wait_for_writable(sock, expiration)
        written = sock.send(pending)
        pending = pending[written:]
245 def _connect(s, address):
246 try:
247 s.connect(address)
248 except socket.error:
249 (ty, v) = sys.exc_info()[:2]
250 if v[0] != errno.EINPROGRESS and \
251 v[0] != errno.EWOULDBLOCK and \
252 v[0] != errno.EALREADY:
253 raise v
def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
        one_rr_per_rrset=False):
    """Return the response obtained after sending a query via TCP.

    @param q: the query
    @type q: dns.message.Message object
    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param timeout: The number of seconds to wait before the query times out.
    If None, the default, wait forever.
    @type timeout: float
    @param port: The port to which to send the message. The default is 53.
    @type port: int
    @param af: the address family to use. The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @rtype: dns.message.Message object
    @param source: source address. If source_port is given without a
    source, the wildcard address of the chosen family is used.  If neither
    is given, the socket is not explicitly bound.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param one_rr_per_rrset: Put each RR into its own RRset
    @type one_rr_per_rrset: bool
    @raise ValueError: af is neither AF_INET nor AF_INET6.
    @raise dns.exception.Timeout: the exchange did not finish in time.
    @raise EOFError: the server closed the connection mid-response.
    @raise BadResponse: the response does not answer q.
    """
    wire = q.to_wire()
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        # Honour source_port even when no source address was supplied.
        if source is not None or source_port != 0:
            if source is None:
                source = '0.0.0.0'
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None or source_port != 0:
            if source is None:
                source = '::'
            source = (source, source_port, 0, 0)
    else:
        raise ValueError('unsupported address family %r' % (af,))
    s = socket.socket(af, socket.SOCK_STREAM, 0)
    try:
        expiration = _compute_expiration(timeout)
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        _connect(s, destination)

        # DNS over TCP prefixes each message with its 16-bit length.
        # Copying the wire into tcpmsg is inefficient, but lets us avoid
        # writev() or doing a short write that would get pushed onto the
        # net.
        tcpmsg = struct.pack("!H", len(wire)) + wire
        _net_write(s, tcpmsg, expiration)
        ldata = _net_read(s, 2, expiration)
        (l,) = struct.unpack("!H", ldata)
        wire = _net_read(s, l, expiration)
    finally:
        s.close()
    r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac,
                              one_rr_per_rrset=one_rr_per_rrset)
    if not q.is_response(r):
        raise BadResponse
    return r
def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
        timeout=None, port=53, keyring=None, keyname=None, relativize=True,
        af=None, lifetime=None, source=None, source_port=0, serial=0,
        use_udp=False, keyalgorithm=dns.tsig.default_algorithm):
    """Return a generator for the responses to a zone transfer.

    @param where: where to send the message
    @type where: string containing an IPv4 or IPv6 address
    @param zone: The name of the zone to transfer
    @type zone: dns.name.Name object or string
    @param rdtype: The type of zone transfer. The default is
    dns.rdatatype.AXFR.
    @type rdtype: int or string
    @param rdclass: The class of the zone transfer. The default is
    dns.rdataclass.IN.
    @type rdclass: int or string
    @param timeout: The number of seconds to wait for each response message.
    If None, the default, wait forever.
    @type timeout: float
    @param port: The port to which to send the message. The default is 53.
    @type port: int
    @param keyring: The TSIG keyring to use
    @type keyring: dict
    @param keyname: The name of the TSIG key to use
    @type keyname: dns.name.Name object or string
    @param relativize: If True, all names in the zone will be relativized to
    the zone origin. It is essential that the relativize setting matches
    the one specified to dns.zone.from_xfr().
    @type relativize: bool
    @param af: the address family to use. The default is None, which
    causes the address family to use to be inferred from the form of where.
    If the inference attempt fails, AF_INET is used.
    @type af: int
    @param lifetime: The total number of seconds to spend doing the transfer.
    If None, the default, then there is no limit on the time the transfer may
    take.
    @type lifetime: float
    @rtype: generator of dns.message.Message objects.
    @param source: source address. If source_port is given without a
    source, the wildcard address of the chosen family is used.  If neither
    is given, the socket is not explicitly bound.
    @type source: string
    @param source_port: The port from which to send the message.
    The default is 0.
    @type source_port: int
    @param serial: The SOA serial number to use as the base for an IXFR diff
    sequence (only meaningful if rdtype == dns.rdatatype.IXFR).
    @type serial: int
    @param use_udp: Use UDP (only meaningful for IXFR)
    @type use_udp: bool
    @param keyalgorithm: The TSIG algorithm to use; defaults to
    dns.tsig.default_algorithm
    @type keyalgorithm: string
    @raise ValueError: af is unsupported, or use_udp was requested for a
    non-IXFR transfer.
    @raise dns.exception.FormError: the response stream is malformed.
    @raise dns.exception.Timeout: a message or the whole transfer timed out.
    """
    if isinstance(zone, (str, unicode)):
        zone = dns.name.from_text(zone)
    if isinstance(rdtype, (str, unicode)):
        rdtype = dns.rdatatype.from_text(rdtype)
    q = dns.message.make_query(zone, rdtype, rdclass)
    if rdtype == dns.rdatatype.IXFR:
        # The authority section carries the client's current SOA serial.
        # NOTE(review): the class is hard-coded to 'IN' regardless of
        # rdclass -- confirm this is intended.
        rrset = dns.rrset.from_text(zone, 0, 'IN', 'SOA',
                                    '. . %u 0 0 0 0' % serial)
        q.authority.append(rrset)
    if keyring is not None:
        q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    wire = q.to_wire()
    if af is None:
        try:
            af = dns.inet.af_for_address(where)
        except Exception:
            af = dns.inet.AF_INET
    if af == dns.inet.AF_INET:
        destination = (where, port)
        # Honour source_port even when no source address was supplied.
        if source is not None or source_port != 0:
            if source is None:
                source = '0.0.0.0'
            source = (source, source_port)
    elif af == dns.inet.AF_INET6:
        destination = (where, port, 0, 0)
        if source is not None or source_port != 0:
            if source is None:
                source = '::'
            source = (source, source_port, 0, 0)
    else:
        raise ValueError('unsupported address family %r' % (af,))
    if use_udp:
        if rdtype != dns.rdatatype.IXFR:
            raise ValueError('cannot do a UDP AXFR')
        s = socket.socket(af, socket.SOCK_DGRAM, 0)
    else:
        s = socket.socket(af, socket.SOCK_STREAM, 0)
    # try/except-and-reraise rather than try/finally: a yield inside
    # try/finally needs Python 2.5+.  The bare except is deliberate so the
    # socket is also closed on GeneratorExit (caller abandons the transfer)
    # and on any error, then the exception propagates unchanged.
    try:
        s.setblocking(0)
        if source is not None:
            s.bind(source)
        expiration = _compute_expiration(lifetime)
        _connect(s, destination)
        if use_udp:
            _wait_for_writable(s, expiration)
            s.send(wire)
        else:
            tcpmsg = struct.pack("!H", len(wire)) + wire
            _net_write(s, tcpmsg, expiration)
        done = False
        # Transfer-parsing state.  It must persist across response
        # messages: a server may split the transfer between any RRsets, so
        # resetting it per message broke multi-message IXFR parsing.
        delete_mode = False
        expecting_SOA = False
        soa_rrset = None
        if relativize:
            origin = zone
            oname = dns.name.empty
        else:
            origin = None
            oname = zone
        tsig_ctx = None
        first = True
        while not done:
            # Per-message deadline, capped by the overall lifetime if one
            # was given.  Comparing against a None lifetime would discard
            # the per-message timeout (Python 2) or raise (Python 3).
            mexpiration = _compute_expiration(timeout)
            if expiration is not None and \
               (mexpiration is None or mexpiration > expiration):
                mexpiration = expiration
            if use_udp:
                _wait_for_readable(s, mexpiration)
                (wire, from_address) = s.recvfrom(65535)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (l,) = struct.unpack("!H", ldata)
                wire = _net_read(s, l, mexpiration)
            r = dns.message.from_wire(wire, keyring=q.keyring,
                                      request_mac=q.mac, xfr=True,
                                      origin=origin, tsig_ctx=tsig_ctx,
                                      multi=True, first=first,
                                      one_rr_per_rrset=(rdtype ==
                                                        dns.rdatatype.IXFR))
            tsig_ctx = r.tsig_ctx
            first = False
            answer_index = 0
            if soa_rrset is None:
                # The first RRset of the first message must be the zone SOA.
                if not r.answer or r.answer[0].name != oname:
                    raise dns.exception.FormError
                rrset = r.answer[0]
                if rrset.rdtype != dns.rdatatype.SOA:
                    raise dns.exception.FormError("first RRset is not an SOA")
                answer_index = 1
                soa_rrset = rrset.copy()
                if rdtype == dns.rdatatype.IXFR:
                    if soa_rrset[0].serial == serial:
                        # We're already up-to-date.
                        done = True
                    else:
                        expecting_SOA = True
            # Process SOAs in the answer section (other than the initial
            # SOA in the first message).
            for rrset in r.answer[answer_index:]:
                if done:
                    raise dns.exception.FormError("answers after final SOA")
                if rrset.rdtype == dns.rdatatype.SOA and rrset.name == oname:
                    if expecting_SOA:
                        if rrset[0].serial != serial:
                            raise dns.exception.FormError(
                                "IXFR base serial mismatch")
                        expecting_SOA = False
                    elif rdtype == dns.rdatatype.IXFR:
                        delete_mode = not delete_mode
                    if rrset == soa_rrset and not delete_mode:
                        done = True
                elif expecting_SOA:
                    # We made an IXFR request and are expecting another
                    # SOA RR, but saw something else, so this must be an
                    # AXFR response.
                    rdtype = dns.rdatatype.AXFR
                    expecting_SOA = False
            if done and q.keyring and not r.had_tsig:
                raise dns.exception.FormError("missing TSIG")
            yield r
    except:
        s.close()
        raise
    s.close()