systemd unit
[ddos.git] / dnsq.ml
blob1663df144898a5159b87518443d339a72425ca16
3 open Printf
4 open Devkit
5 open ExtLib
6 open Lwt_unix
(** Logger used by this module. *)
let log = Log.from "dnsq"

(* DNS query ids are drawn with [Random.int]; seed the generator so the ids
   differ from one run to the next. *)
let () = Random.self_init ()
(** Debug verbosity for received replies: above 1 each reply is
    pretty-printed, above 2 it is also hexdumped (see [get_reply_exn]). *)
let verbose : int ref = ref 0

(** Receive buffer size in bytes used when no EDNS value is supplied. *)
let maxlen : int = 1024
(** How the answer for a domain was obtained. *)
type dres =
  | DNS of Dns.rcode  (** a query was sent and answered; carries the answer's rcode *)
  | TIMEOUT  (** a query was sent but no answer was received *)
  | ADDR  (** no query needed: the input was converted directly to an address *)
  | INVALID  (** no query needed: the domain name is invalid *)
  | BLACKDOMAIN  (** no query needed: the domain name is blacklisted *)
  | BLACKIP  (** all ips of the domain are blacklisted *)
(** Printable name of a [dres]; DNS answers are shown by their rcode name. *)
let show_dres (r : dres) : string =
  match r with
  | TIMEOUT -> "TIMEOUT"
  | ADDR -> "ADDR"
  | INVALID -> "INVALID"
  | BLACKDOMAIN -> "BLACKDOMAIN"
  | BLACKIP -> "BLACKIP"
  | DNS rcode -> Dns.string_of_rcode rcode
(** Validity of a resolved address: [Forever] for addresses taken literally
    from the input, [Seconds n] for the TTL carried by a DNS reply. *)
type ttl =
  | Forever
  | Seconds of int

(** One resolved address together with its validity. *)
type ip = {
  ip : Network.ipv4;
  ttl : ttl;
}

(** Outcome of looking up one domain. *)
type result = {
  domain : string;  (** domain the answer refers to *)
  dres : dres;  (** how the answer was obtained *)
  cname : string option;  (** canonical name, when the reply carried one *)
  ips : ip list;  (** resolved addresses *)
  txt : string list list;  (** TXT data from the reply *)
}
(** Handle for one upstream DNS server. *)
type t = {
  sock : file_descr;  (** UDP socket connected to [addr] *)
  addr : sockaddr;  (** upstream server address; replies from other peers are rejected *)
  timeout : float;  (** per-query timeout, in seconds *)
  buf : bytes;  (** receive buffer reused for every reply *)
  edns : int option;  (** EDNS value passed to [Dns.make_query]; also sizes [buf] *)
  h : (int * string, result Lwt.u) Hashtbl.t Lazy.t;
      (** in-flight async queries keyed by (dns id, domain); forced by [setup] *)
  mutable queries : int;  (** number of queries sent so far *)
}
(** [upstream_addr ?timeout ?edns addr] opens a UDP socket connected to the
    DNS server at [addr] and wraps it in a [t].
    @param timeout per-query timeout in seconds (default [3.])
    @param edns EDNS value passed to every query; it is also the receive
    buffer size, which defaults to [maxlen] when [edns] is absent.
    If any socket setup step fails, the socket is closed before the error is
    propagated, so a failed call does not leak a file descriptor. *)
let upstream_addr ?(timeout=3.) ?edns addr =
  let%lwt proto = getprotobyname "udp" in
  let sock = socket PF_INET SOCK_DGRAM proto.p_proto in
  let%lwt () =
    Lwt.catch
      (fun () ->
        Lwt_unix.set_close_on_exec sock;
        (* NOTE(review): Lwt_unix puts sockets in non-blocking mode, so these
           kernel timeouts likely do not bound Lwt reads/writes -- confirm. *)
        setsockopt_float sock SO_RCVTIMEO timeout;
        setsockopt_float sock SO_SNDTIMEO timeout;
        connect sock addr)
      (fun exn ->
        (* best-effort close; the original failure is what the caller sees *)
        let%lwt () = Lwt.catch (fun () -> Lwt_unix.close sock) (fun _ -> Lwt.return_unit) in
        Lwt.fail exn)
  in
  Lwt.return {
    sock;
    addr;
    timeout;
    buf = Bytes.create (Option.default maxlen edns);
    edns;
    h = lazy (Hashtbl.create 13);
    queries = 0;
  }
(** [upstream ?port ?edns ?timeout host] is [upstream_addr] applied to the
    address of [host] on [port] (default 53). *)
let upstream ?(port=53) ?edns ?timeout host =
  let server = Nix.make_inet_addr_exn host port in
  upstream_addr ?timeout ?edns server
(** Receive one datagram on the upstream socket and decode it.
    Resolves to [(id, domain, rcode, cname, addrs, txt)] when the datagram is a
    reply to an A, TXT or MX query.
    Fails when the datagram came from a peer other than the configured
    [addr], or when it is not such a reply; [Dns.parse] may also raise on
    undecodable input. With [!verbose > 2] the raw bytes are hexdumped and with
    [!verbose > 1] the decoded packet is printed. *)
let get_reply_exn { sock; addr; buf; _ } =
  let%lwt msg =
    let%lwt (len, peer) = recvfrom sock buf 0 (Bytes.length buf) [] in
    if peer = addr then
      Lwt.return (Bytes.sub_string buf 0 len)
    else
      Exn_lwt.fail "Wrong peer : expected %s got %s" (Nix.show_addr addr) (Nix.show_addr peer)
  in
  if !verbose > 2 then print_endline (Action.hexdump msg);
  if !verbose > 1 then print_endline (Dns.pkt_out_s (Dns.to_pkt msg));
  let (info, typ, cname, addrs, txt) = Dns.parse msg in
  match typ, info with
  | `Reply (rcode, _aa, _ra), { Dns.id; qtype = Dns.A | TXT | MX; domain; } ->
    Lwt.return (id, Inet.string_of_domain domain, rcode, Option.map Inet.string_of_domain cname, addrs, txt)
  | _ ->
    (* message lists every query type accepted by the pattern above *)
    Exn_lwt.fail "expected reply to A, TXT or MX query"
(** Answer [domain] without a network query when possible.
    - [Some] result with [dres = ADDR] when [domain] parses as an IPv4 address
      (the address is returned with [ttl = Forever]);
    - [Some] result with [dres = INVALID] when it is not a valid DNS domain;
    - [None] when a real DNS query is needed. *)
let resolve_immediate domain =
  match Network.ipv4_of_string_exn domain with
  | addr -> Some { domain; dres = ADDR; cname = None; ips = [{ ip = addr; ttl = Forever; }]; txt = [] } (* do not query IP address *)
  | exception _ ->
    (* not an IP literal: only well-formed domain names are sent upstream *)
    if Inet.is_dns_domain domain then None
    else Some { domain; dres = INVALID; cname = None; ips = []; txt = [] }
(** Encode a [qtype] query for [domain] with DNS id [dns_id] and send it to
    the upstream server. The [queries] counter is bumped before sending.
    Fails if the datagram was not sent in full. *)
let send_query srv dns_id qtype domain =
  let pkt = Bytes.unsafe_of_string (Dns.make_query ?edns:srv.edns dns_id qtype domain) in
  let size = Bytes.length pkt in
  srv.queries <- succ srv.queries;
  let%lwt sent = sendto srv.sock pkt 0 size [] srv.addr in
  if sent <> size then
    Exn_lwt.fail "can't send full packet for %S to %s" domain (Nix.show_addr srv.addr)
  else
    Lwt.return_unit
(** Synchronous lookup of [domain] (query type [qtype], default [Dns.A]) on
    [srv], blocking the caller through [Lwt_main.run].
    Answers locally when [resolve_immediate] can. Otherwise it sends one query
    and waits for the reply carrying the same DNS id, skipping other replies.
    Fails via [Exn.fail] with "timeouted" once [srv.timeout] seconds pass
    without a matching reply; errors from [get_reply_exn] propagate. *)
let query_exn srv ?(qtype=Dns.A) domain =
  match resolve_immediate domain with
  | Some x -> x
  | None ->
    let dns_id = Random.int (succ Dns.max_id) in
    let t = new Action.timer in
    Lwt_main.run (send_query srv dns_id qtype domain);
    let rec loop () =
      (* Bound every wait by the time left: the socket is non-blocking under
         Lwt, so nothing else stops [get_reply_exn] from waiting forever when
         no datagram arrives at all. *)
      let remaining = srv.timeout -. t#get in
      if remaining <= 0. then Exn.fail "timeouted";
      match Lwt_main.run (Lwt_unix.with_timeout remaining (fun () -> get_reply_exn srv)) with
      | exception Lwt_unix.Timeout -> Exn.fail "timeouted"
      | (id', domain, rcode, cname, addrs, txt) ->
        if dns_id <> id' then
          loop () (* reply to some other query: keep waiting *)
        else
          let ips = List.map (fun (ip, ttl) -> { ip; ttl = Seconds (Int32.to_int ttl); }) addrs in
          { domain; dres = DNS rcode; cname; ips; txt }
    in
    loop ()
115 (* asynchronous interface *)
(** Number of async queries still awaiting a reply; 0 before [setup] has
    created the pending table. *)
let running t =
  match Lazy.is_val t.h with
  | false -> 0
  | true -> Hashtbl.length (Lazy.force t.h)
(** Return the table of in-flight async queries of [srv]. On first use it also
    starts a background loop that reads replies and wakes the waiter
    registered under the reply's (id, domain). Receive errors are logged and
    the loop keeps going; replies nobody waits for are dropped. *)
let setup srv =
  if Lazy.is_val srv.h then Lazy.force srv.h
  else begin
    let pending = Lazy.force srv.h in
    (* hand a decoded reply to its waiter, unregistering it first *)
    let deliver (id, domain, rcode, cname, addrs, txt) waiter =
      Hashtbl.remove_all pending (id, domain);
      let ips = List.map (fun (ip, ttl) -> { ip; ttl = Seconds (Int32.to_int ttl); }) addrs in
      Lwt.wakeup waiter { domain; dres = DNS rcode; cname; ips; txt }
    in
    let rec receive () =
      match%lwt Exn_lwt.map get_reply_exn srv with
      | `Exn Lwt.Canceled -> Lwt.fail Lwt.Canceled
      | `Exn exn -> log #warn ~exn "receive loop %s" (Nix.show_addr srv.addr); receive ()
      | `Ok ((id, domain, _, _, _, _) as reply) ->
        (match Hashtbl.find_option pending (id, domain) with
         | None -> () (* orphan reply; presumably its query already timed out *)
         | Some waiter -> deliver reply waiter);
        receive ()
    in
    Lwt.ignore_result (receive ()); (* run loop *)
    pending
  end
(** One-line summary: upstream address, async queries in flight, and total
    queries sent. *)
let status t =
  let server = Nix.show_addr t.addr in
  sprintf "dns %s running %d done %d" server (running t) t.queries
(** Send a [qtype] query (default [Dns.A]) for [domain] under a fresh random
    DNS id without registering any waiter for the reply. *)
let send_query_forget t ?(qtype=Dns.A) domain =
  send_query t (Random.int (Dns.max_id + 1)) qtype domain
(** Send a [qtype] query for [domain] and wait for the matching reply, which
    is delivered by the receive loop started in [setup].
    Resolves to a result with [dres = TIMEOUT] when no reply arrives within
    [t.timeout] seconds or the wait is cancelled. The pending-table entry is
    removed on every exit path, including a failed send, whose error is then
    propagated. *)
let do_real_query t qtype domain =
  let h = setup t in
  let dns_id = Random.int (succ Dns.max_id) in
  let req = (dns_id, domain) in
  if Hashtbl.mem h req then log #warn "(%d,%s) query queued already (use cache!)" dns_id domain;
  let (result, u) = Lwt.task () in
  Hashtbl.replace h req u;
  let%lwt () =
    Lwt.catch
      (fun () -> send_query t dns_id qtype domain)
      (fun exn ->
        (* no reply can come for a query that was not sent: drop the entry so
           it is not counted by [running] forever *)
        Hashtbl.remove_all h req;
        Lwt.fail exn)
  in
  try%lwt
    Lwt.pick [Lwt_unix.timeout t.timeout; result]
  with
    Lwt_unix.Timeout | Lwt.Canceled ->
      Hashtbl.remove_all h req;
      Lwt.return { domain; dres = TIMEOUT; cname = None; ips = []; txt = [] }
(** Asynchronous lookup of [domain] (query type [qtype], default [Dns.A]):
    answered locally via [resolve_immediate] when possible, otherwise queried
    upstream through [do_real_query]. *)
let do_query t ?(qtype=Dns.A) domain =
  match resolve_immediate domain with
  | None -> do_real_query t qtype domain
  | Some immediate -> Lwt.return immediate
(* NOTE(review): this definition cannot type-check against the code visible
   here. It reads [t.rate] and [t.source], which are not fields of type [t].
   It iterates an unbound [e] (perhaps [s] was meant) while [s] goes unused.
   It passes a callback [cb] to [do_query], which takes no such argument.
   Its [with] handlers have no visible matching [try].
   It looks like leftover code from an older implementation, probably
   commented out in the full file; confirm before reviving it.
   Apparent intent: keep up to [t.rate] queries running by draining queued
   (domains, callback) sources, dropping each source once exhausted. [Exit]
   (rate limit reached) and [Queue.Empty] (no sources left) end the refill
   quietly; any other exception is logged. *)
165 let refill t =
167 while running t < t.rate do
168 let (s,cb) = Queue.peek t.source in
169 Enum.iter (fun domain -> do_query t domain cb; if running t >= t.rate then raise Exit) e;
170 ignore (Queue.pop t.source) (* exhausted this enum, drop it *)
171 done
172 with
173 | Exit | Queue.Empty -> (*log #info "%s" (status t);*) ()
174 | exn -> log #warn ~exn "refill"
(* NOTE(review): legacy code -- [t.source] is not a field of type [t], so this
   does not type-check against the code visible here.
   It pushes [(domain, k)] onto the source queue and then calls [refill].
   [refill] iterates the first component with [Enum.iter], so [domain] is
   presumably an enum of domains and [k] the per-result callback; confirm. *)
178 let query t domain k =
179 Queue.push (domain, k) t.source;
180 refill t
(* NOTE(review): legacy code -- [t.stop], [t.ev] and [t.timer] are not fields of
   type [t]. [t.h] is a lazy table in the current type, so [Hashtbl.clear t.h]
   does not type-check either.
   Apparent intent: idempotent shutdown guarded by the [t.stop] flag. It
   unregisters the [Ev] events and forgets all pending queries. *)
184 let stop t =
185 log #debug "stop %s" (Nix.show_addr t.addr);
186 if not !(t.stop) then
187 begin t.stop := true; Ev.del t.ev; Ev.del t.timer; Hashtbl.clear t.h end