fmt with go1.19 conventions.
[dnstt.git] / dnstt-server / main.go
blob047683c4a55becd407acd7192ed23f1be815d9f9
1 // dnstt-server is the server end of a DNS tunnel.
2 //
3 // Usage:
4 //
5 // dnstt-server -gen-key [-privkey-file PRIVKEYFILE] [-pubkey-file PUBKEYFILE]
6 // dnstt-server -udp ADDR [-privkey PRIVKEY|-privkey-file PRIVKEYFILE] DOMAIN UPSTREAMADDR
7 //
8 // Example:
9 //
10 // dnstt-server -gen-key -privkey-file server.key -pubkey-file server.pub
11 // dnstt-server -udp :53 -privkey-file server.key t.example.com 127.0.0.1:8000
13 // To generate a persistent server private key, first run with the -gen-key
14 // option. By default the generated private and public keys are printed to
15 // standard output. To save them to files instead, use the -privkey-file and
16 // -pubkey-file options.
18 // dnstt-server -gen-key
19 // dnstt-server -gen-key -privkey-file server.key -pubkey-file server.pub
21 // You can give the server's private key as a file or as a hex string.
23 // -privkey-file server.key
24 // -privkey 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
// The -udp option sets the UDP address on which to listen for incoming DNS
// queries.
29 // The -mtu option controls the maximum size of response UDP payloads.
// Queries that do not advertise requester support for responses of at least
// this size will be responded to with a FORMERR. The default value is
// maxUDPPayload (1232 bytes).
34 // DOMAIN is the root of the DNS zone reserved for the tunnel. See README for
35 // instructions on setting it up.
37 // UPSTREAMADDR is the TCP address to which incoming tunnelled streams will be
38 // forwarded.
39 package main
41 import (
42 "bytes"
43 "encoding/base32"
44 "encoding/binary"
45 "errors"
46 "flag"
47 "fmt"
48 "io"
49 "io/ioutil"
50 "log"
51 "net"
52 "os"
53 "sync"
54 "time"
56 "github.com/xtaci/kcp-go/v5"
57 "github.com/xtaci/smux"
58 "www.bamsoftware.com/git/dnstt.git/dns"
59 "www.bamsoftware.com/git/dnstt.git/noise"
60 "www.bamsoftware.com/git/dnstt.git/turbotunnel"
const (
	// responseTTL is the TTL placed in Answer resource records.
	responseTTL = 60

	// idleTimeout is used as the smux KeepAliveTimeout (see acceptStreams)
	// and, doubled, as the turbotunnel queue timeout (see run). smux
	// streams will be closed after this much time without receiving data.
	idleTimeout = 120 * time.Second

	// upstreamDialTimeout bounds how long we wait for a TCP connection to
	// upstream to be established.
	upstreamDialTimeout = 30 * time.Second

	// maxResponseDelay is how long we may wait for downstream data before
	// sending an empty response. If another query arrives while we are
	// waiting, we send an empty response anyway and restart the delay timer
	// for the next response.
	//
	// Keep this under 2 seconds, which in 2019 was reported to be the query
	// timeout of the Quad9 DoH server.
	// https://dnsencryption.info/imc19-doe.html Section 4.2, Finding 2.4
	maxResponseDelay = time.Second
)
// maxUDPPayload is the largest UDP payload we send, chosen in an attempt to
// avoid network-layer fragmentation: 1280 (the minimum IPv6 MTU) minus 40
// (an IPv6 header, ignoring any extension headers) minus 8 (a UDP header).
//
// Control this value with the -mtu command-line option.
//
// https://dnsflagday.net/2020/#message-size-considerations
// "An EDNS buffer size of 1232 bytes will avoid fragmentation on nearly
// all current networks."
//
// On 2020-04-19, the Quad9 resolver was seen to have a UDP payload size
// of 1232. Cloudflare's was 1452, and Google's was 4096.
var maxUDPPayload = 1280 - 40 - 8
// base32Encoding is the standard RFC 4648 base32 alphabet with padding
// disabled, so encoded data never ends in '=' characters.
var base32Encoding = base32.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZ234567").WithPadding(base32.NoPadding)
104 // generateKeypair generates a private key and the corresponding public key. If
105 // privkeyFilename and pubkeyFilename are respectively empty, it prints the
106 // corresponding key to standard output; otherwise it saves the key to the given
107 // file name. The private key is saved with mode 0400 and the public key is
108 // saved with 0666 (before umask). In case of any error, it attempts to delete
109 // any files it has created before returning.
110 func generateKeypair(privkeyFilename, pubkeyFilename string) (err error) {
111 // Filenames to delete in case of error (avoid leaving partially written
112 // files).
113 var toDelete []string
114 defer func() {
115 for _, filename := range toDelete {
116 fmt.Fprintf(os.Stderr, "deleting partially written file %s\n", filename)
117 if closeErr := os.Remove(filename); closeErr != nil {
118 fmt.Fprintf(os.Stderr, "cannot remove %s: %v\n", filename, closeErr)
119 if err == nil {
120 err = closeErr
126 privkey, err := noise.GeneratePrivkey()
127 if err != nil {
128 return err
130 pubkey := noise.PubkeyFromPrivkey(privkey)
132 if privkeyFilename != "" {
133 // Save the privkey to a file.
134 f, err := os.OpenFile(privkeyFilename, os.O_RDWR|os.O_CREATE, 0400)
135 if err != nil {
136 return err
138 toDelete = append(toDelete, privkeyFilename)
139 err = noise.WriteKey(f, privkey)
140 if err2 := f.Close(); err == nil {
141 err = err2
143 if err != nil {
144 return err
148 if pubkeyFilename != "" {
149 // Save the pubkey to a file.
150 f, err := os.Create(pubkeyFilename)
151 if err != nil {
152 return err
154 toDelete = append(toDelete, pubkeyFilename)
155 err = noise.WriteKey(f, pubkey)
156 if err2 := f.Close(); err == nil {
157 err = err2
159 if err != nil {
160 return err
164 // All good, allow the written files to remain.
165 toDelete = nil
167 if privkeyFilename != "" {
168 fmt.Printf("privkey written to %s\n", privkeyFilename)
169 } else {
170 fmt.Printf("privkey %x\n", privkey)
172 if pubkeyFilename != "" {
173 fmt.Printf("pubkey written to %s\n", pubkeyFilename)
174 } else {
175 fmt.Printf("pubkey %x\n", pubkey)
178 return nil
181 // readKeyFromFile reads a key from a named file.
182 func readKeyFromFile(filename string) ([]byte, error) {
183 f, err := os.Open(filename)
184 if err != nil {
185 return nil, err
187 defer f.Close()
188 return noise.ReadKey(f)
191 // handleStream bidirectionally connects a client stream with a TCP socket
192 // addressed by upstream.
193 func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
194 dialer := net.Dialer{
195 Timeout: upstreamDialTimeout,
197 upstreamConn, err := dialer.Dial("tcp", upstream)
198 if err != nil {
199 return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
201 defer upstreamConn.Close()
202 upstreamTCPConn := upstreamConn.(*net.TCPConn)
204 var wg sync.WaitGroup
205 wg.Add(2)
206 go func() {
207 defer wg.Done()
208 _, err := io.Copy(stream, upstreamTCPConn)
209 if err == io.EOF {
210 // smux Stream.Write may return io.EOF.
211 err = nil
213 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
214 log.Printf("stream %08x:%d copy streamā†upstream: %v", conv, stream.ID(), err)
216 upstreamTCPConn.CloseRead()
217 stream.Close()
219 go func() {
220 defer wg.Done()
221 _, err := io.Copy(upstreamTCPConn, stream)
222 if err == io.EOF {
223 // smux Stream.WriteTo may return io.EOF.
224 err = nil
226 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
227 log.Printf("stream %08x:%d copy upstreamā†stream: %v", conv, stream.ID(), err)
229 upstreamTCPConn.CloseWrite()
231 wg.Wait()
233 return nil
236 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
237 // then awaits smux streams. It passes each stream to handleStream.
238 func acceptStreams(conn *kcp.UDPSession, privkey []byte, upstream string) error {
239 // Put a Noise channel on top of the KCP conn.
240 rw, err := noise.NewServer(conn, privkey)
241 if err != nil {
242 return err
245 // Put an smux session on top of the encrypted Noise channel.
246 smuxConfig := smux.DefaultConfig()
247 smuxConfig.Version = 2
248 smuxConfig.KeepAliveTimeout = idleTimeout
249 smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024 // default is 65536
250 sess, err := smux.Server(rw, smuxConfig)
251 if err != nil {
252 return err
254 defer sess.Close()
256 for {
257 stream, err := sess.AcceptStream()
258 if err != nil {
259 if err, ok := err.(net.Error); ok && err.Temporary() {
260 continue
262 return err
264 log.Printf("begin stream %08x:%d", conn.GetConv(), stream.ID())
265 go func() {
266 defer func() {
267 log.Printf("end stream %08x:%d", conn.GetConv(), stream.ID())
268 stream.Close()
270 err := handleStream(stream, upstream, conn.GetConv())
271 if err != nil {
272 log.Printf("stream %08x:%d handleStream: %v", conn.GetConv(), stream.ID(), err)
278 // acceptSessions listens for incoming KCP connections and passes them to
279 // acceptStreams.
280 func acceptSessions(ln *kcp.Listener, privkey []byte, mtu int, upstream string) error {
281 for {
282 conn, err := ln.AcceptKCP()
283 if err != nil {
284 if err, ok := err.(net.Error); ok && err.Temporary() {
285 continue
287 return err
289 log.Printf("begin session %08x", conn.GetConv())
290 // Permit coalescing the payloads of consecutive sends.
291 conn.SetStreamMode(true)
292 // Disable the dynamic congestion window (limit only by the
293 // maximum of local and remote static windows).
294 conn.SetNoDelay(
295 0, // default nodelay
296 0, // default interval
297 0, // default resend
298 1, // nc=1 => congestion window off
300 conn.SetWindowSize(turbotunnel.QueueSize/2, turbotunnel.QueueSize/2)
301 if rc := conn.SetMtu(mtu); !rc {
302 panic(rc)
304 go func() {
305 defer func() {
306 log.Printf("end session %08x", conn.GetConv())
307 conn.Close()
309 err := acceptStreams(conn, privkey, upstream)
310 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
311 log.Printf("session %08x acceptStreams: %v", conn.GetConv(), err)
// nextPacket reads the next length-prefixed packet from r, skipping any
// padding before it. It returns a nil error only when a packet was read
// successfully; whenever the error is non-nil, the returned packet is nil.
// It returns io.EOF only when r is exhausted at a packet boundary (possibly
// after skipping padding). It returns io.ErrUnexpectedEOF when EOF occurs in
// the middle of an encoded packet or padding run.
//
// The prefixing scheme is as follows. A length prefix L < 0xe0 means a data
// packet of L bytes. A length prefix L >= 0xe0 means padding of L - 0xe0 bytes
// (not counting the length of the length prefix itself).
func nextPacket(r *bytes.Reader) ([]byte, error) {
	// Convert io.EOF to io.ErrUnexpectedEOF.
	eof := func(err error) error {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return err
	}

	for {
		prefix, err := r.ReadByte()
		if err != nil {
			// We may return a real io.EOF only here.
			return nil, err
		}
		if prefix >= 0xe0 {
			paddingLen := prefix - 0xe0
			if _, err := io.CopyN(ioutil.Discard, r, int64(paddingLen)); err != nil {
				return nil, eof(err)
			}
			continue
		}
		p := make([]byte, int(prefix))
		if _, err := io.ReadFull(r, p); err != nil {
			// Don't hand back a partially filled packet.
			return nil, eof(err)
		}
		return p, nil
	}
}
354 // responseFor constructs a response dns.Message that is appropriate for query.
355 // Along with the dns.Message, it returns the query's decoded data payload. If
356 // the returned dns.Message is nil, it means that there should be no response to
357 // this query. If the returned dns.Message has an Rcode() of dns.RcodeNoError,
358 // the message is a candidate for for carrying downstream data in a TXT record.
359 func responseFor(query *dns.Message, domain dns.Name) (*dns.Message, []byte) {
360 resp := &dns.Message{
361 ID: query.ID,
362 Flags: 0x8000, // QR = 1, RCODE = no error
363 Question: query.Question,
366 if query.Flags&0x8000 != 0 {
367 // QR != 0, this is not a query. Don't even send a response.
368 return nil, nil
371 // Check for EDNS(0) support. Include our own OPT RR only if we receive
372 // one from the requester.
373 // https://tools.ietf.org/html/rfc6891#section-6.1.1
374 // "Lack of presence of an OPT record in a request MUST be taken as an
375 // indication that the requester does not implement any part of this
376 // specification and that the responder MUST NOT include an OPT record
377 // in its response."
378 payloadSize := 0
379 for _, rr := range query.Additional {
380 if rr.Type != dns.RRTypeOPT {
381 continue
383 if len(resp.Additional) != 0 {
384 // https://tools.ietf.org/html/rfc6891#section-6.1.1
385 // "If a query message with more than one OPT RR is
386 // received, a FORMERR (RCODE=1) MUST be returned."
387 resp.Flags |= dns.RcodeFormatError
388 log.Printf("FORMERR: more than one OPT RR")
389 return resp, nil
391 resp.Additional = append(resp.Additional, dns.RR{
392 Name: dns.Name{},
393 Type: dns.RRTypeOPT,
394 Class: 4096, // responder's UDP payload size
395 TTL: 0,
396 Data: []byte{},
398 additional := &resp.Additional[0]
400 version := (rr.TTL >> 16) & 0xff
401 if version != 0 {
402 // https://tools.ietf.org/html/rfc6891#section-6.1.1
403 // "If a responder does not implement the VERSION level
404 // of the request, then it MUST respond with
405 // RCODE=BADVERS."
406 resp.Flags |= dns.ExtendedRcodeBadVers & 0xf
407 additional.TTL = (dns.ExtendedRcodeBadVers >> 4) << 24
408 log.Printf("BADVERS: EDNS version %d != 0", version)
409 return resp, nil
412 payloadSize = int(rr.Class)
414 if payloadSize < 512 {
415 // https://tools.ietf.org/html/rfc6891#section-6.1.1 "Values
416 // lower than 512 MUST be treated as equal to 512."
417 payloadSize = 512
419 // We will return RcodeFormatError if payloadSize is too small, but
420 // first, check the name in order to set the AA bit properly.
422 // There must be exactly one question.
423 if len(query.Question) != 1 {
424 resp.Flags |= dns.RcodeFormatError
425 log.Printf("FORMERR: too few or too many questions (%d)", len(query.Question))
426 return resp, nil
428 question := query.Question[0]
429 // Check the name to see if it ends in our chosen domain, and extract
430 // all that comes before the domain if it does. If it does not, we will
431 // return RcodeNameError below, but prefer to return RcodeFormatError
432 // for payload size if that applies as well.
433 prefix, ok := question.Name.TrimSuffix(domain)
434 if !ok {
435 // Not a name we are authoritative for.
436 resp.Flags |= dns.RcodeNameError
437 log.Printf("NXDOMAIN: not authoritative for %s", question.Name)
438 return resp, nil
440 resp.Flags |= 0x0400 // AA = 1
442 if query.Opcode() != 0 {
443 // We don't support OPCODE != QUERY.
444 resp.Flags |= dns.RcodeNotImplemented
445 log.Printf("NOTIMPL: unrecognized OPCODE %d", query.Opcode())
446 return resp, nil
449 if question.Type != dns.RRTypeTXT {
450 // We only support QTYPE == TXT.
451 resp.Flags |= dns.RcodeNameError
452 // No log message here; it's common for recursive resolvers to
453 // send NS or A queries when the client only asked for a TXT. I
454 // suspect this is related to QNAME minimization, but I'm not
455 // sure. https://tools.ietf.org/html/rfc7816
456 // log.Printf("NXDOMAIN: QTYPE %d != TXT", question.Type)
457 return resp, nil
460 encoded := bytes.ToUpper(bytes.Join(prefix, nil))
461 payload := make([]byte, base32Encoding.DecodedLen(len(encoded)))
462 n, err := base32Encoding.Decode(payload, encoded)
463 if err != nil {
464 // Base32 error, make like the name doesn't exist.
465 resp.Flags |= dns.RcodeNameError
466 log.Printf("NXDOMAIN: base32 decoding: %v", err)
467 return resp, nil
469 payload = payload[:n]
471 // We require clients to support EDNS(0) with a minimum payload size;
472 // otherwise we would have to set a small KCP MTU (only around 200
473 // bytes). https://tools.ietf.org/html/rfc6891#section-7 "If there is a
474 // problem with processing the OPT record itself, such as an option
475 // value that is badly formatted or that includes out-of-range values, a
476 // FORMERR MUST be returned."
477 if payloadSize < maxUDPPayload {
478 resp.Flags |= dns.RcodeFormatError
479 log.Printf("FORMERR: requester payload size %d is too small (minimum %d)", payloadSize, maxUDPPayload)
480 return resp, nil
483 return resp, payload
486 // record represents a DNS message appropriate for a response to a previously
487 // received query, along with metadata necessary for sending the response.
488 // recvLoop sends instances of record to sendLoop via a channel. sendLoop
489 // receives instances of record and may fill in the message's Answer section
490 // before sending it.
491 type record struct {
492 Resp *dns.Message
493 Addr net.Addr
494 ClientID turbotunnel.ClientID
497 // recvLoop repeatedly calls dnsConn.ReadFrom, extracts the packets contained in
498 // the incoming DNS queries, and puts them on ttConn's incoming queue. Whenever
499 // a query calls for a response, constructs a partial response and passes it to
500 // sendLoop over ch.
501 func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- *record) error {
502 for {
503 var buf [4096]byte
504 n, addr, err := dnsConn.ReadFrom(buf[:])
505 if err != nil {
506 if err, ok := err.(net.Error); ok && err.Temporary() {
507 log.Printf("ReadFrom temporary error: %v", err)
508 continue
510 return err
513 // Got a UDP packet. Try to parse it as a DNS message.
514 query, err := dns.MessageFromWireFormat(buf[:n])
515 if err != nil {
516 log.Printf("cannot parse DNS query: %v", err)
517 continue
520 resp, payload := responseFor(&query, domain)
521 // Extract the ClientID from the payload.
522 var clientID turbotunnel.ClientID
523 n = copy(clientID[:], payload)
524 payload = payload[n:]
525 if n == len(clientID) {
526 // Discard padding and pull out the packets contained in
527 // the payload.
528 r := bytes.NewReader(payload)
529 for {
530 p, err := nextPacket(r)
531 if err != nil {
532 break
534 // Feed the incoming packet to KCP.
535 ttConn.QueueIncoming(p, clientID)
537 } else {
538 // Payload is not long enough to contain a ClientID.
539 if resp != nil && resp.Rcode() == dns.RcodeNoError {
540 resp.Flags |= dns.RcodeNameError
541 log.Printf("NXDOMAIN: %d bytes are too short to contain a ClientID", n)
544 // If a response is called for, pass it to sendLoop via the channel.
545 if resp != nil {
546 select {
547 case ch <- &record{resp, addr, clientID}:
548 default:
554 // sendLoop repeatedly receives records from ch. Those that represent an error
555 // response, it sends on the network immediately. Those that represent a
556 // response capable of carrying data, it packs full of as many packets as will
557 // fit while keeping the total size under maxEncodedPayload, then sends it.
558 func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan *record, maxEncodedPayload int) error {
559 var nextRec *record
560 for {
561 rec := nextRec
562 nextRec = nil
564 if rec == nil {
565 var ok bool
566 rec, ok = <-ch
567 if !ok {
568 break
572 if rec.Resp.Rcode() == dns.RcodeNoError && len(rec.Resp.Question) == 1 {
573 // If it's a non-error response, we can fill the Answer
574 // section with downstream packets.
576 // Any changes to how responses are built need to happen
577 // also in computeMaxEncodedPayload.
578 rec.Resp.Answer = []dns.RR{
580 Name: rec.Resp.Question[0].Name,
581 Type: rec.Resp.Question[0].Type,
582 Class: rec.Resp.Question[0].Class,
583 TTL: responseTTL,
584 Data: nil, // will be filled in below
588 var payload bytes.Buffer
589 limit := maxEncodedPayload
590 // We loop and bundle as many packets from OutgoingQueue
591 // into the response as will fit. Any packet that would
592 // overflow the capacity of the DNS response, we stash
593 // to be bundled into a future response.
594 timer := time.NewTimer(maxResponseDelay)
595 for {
596 var p []byte
597 unstash := ttConn.Unstash(rec.ClientID)
598 outgoing := ttConn.OutgoingQueue(rec.ClientID)
599 // Prioritize taking a packet first from the
600 // stash, then from the outgoing queue, then
601 // finally check for the expiration of the timer
602 // or for a receive on ch (indicating a new
603 // query that we must respond to).
604 select {
605 case p = <-unstash:
606 default:
607 select {
608 case p = <-unstash:
609 case p = <-outgoing:
610 default:
611 select {
612 case p = <-unstash:
613 case p = <-outgoing:
614 case <-timer.C:
615 case nextRec = <-ch:
619 // We wait for the first packet in a bundle
620 // only. The second and later packets must be
621 // immediately available or they will be omitted
622 // from this bundle.
623 timer.Reset(0)
625 if len(p) == 0 {
626 // timer expired or receive on ch, we
627 // are done with this response.
628 break
631 limit -= 2 + len(p)
632 if payload.Len() == 0 {
633 // No packet length check for the first
634 // packet; if it's too large, we allow
635 // it to be truncated and dropped by the
636 // receiver.
637 } else if limit < 0 {
638 // Stash this packet to send in the next
639 // response.
640 ttConn.Stash(p, rec.ClientID)
641 break
643 if int(uint16(len(p))) != len(p) {
644 panic(len(p))
646 binary.Write(&payload, binary.BigEndian, uint16(len(p)))
647 payload.Write(p)
649 timer.Stop()
651 rec.Resp.Answer[0].Data = dns.EncodeRDataTXT(payload.Bytes())
654 buf, err := rec.Resp.WireFormat()
655 if err != nil {
656 log.Printf("resp WireFormat: %v", err)
657 continue
659 // Truncate if necessary.
660 // https://tools.ietf.org/html/rfc1035#section-4.1.1
661 if len(buf) > maxUDPPayload {
662 log.Printf("truncating response of %d bytes to max of %d", len(buf), maxUDPPayload)
663 buf = buf[:maxUDPPayload]
664 buf[2] |= 0x02 // TC = 1
667 // Now we actually send the message as a UDP packet.
668 _, err = dnsConn.WriteTo(buf, rec.Addr)
669 if err != nil {
670 if err, ok := err.(net.Error); ok && err.Temporary() {
671 log.Printf("WriteTo temporary error: %v", err)
672 continue
674 return err
677 return nil
680 // computeMaxEncodedPayload computes the maximum amount of downstream TXT RR
681 // data that keep the overall response size less than maxUDPPayload, in the
682 // worst case when the response answers a query that has a maximum-length name
683 // in its Question section. Returns 0 in the case that no amount of data makes
684 // the overall response size small enough.
686 // This function needs to be kept in sync with sendLoop with regard to how it
687 // builds candidate responses.
688 func computeMaxEncodedPayload(limit int) int {
689 // 64+64+64+62 octets, needs to be base32-decodable.
690 maxLengthName, err := dns.NewName([][]byte{
691 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
692 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
693 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
694 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
696 if err != nil {
697 panic(err)
700 // Compute the encoded length of maxLengthName and that its
701 // length is actually at the maximum of 255 octets.
702 n := 0
703 for _, label := range maxLengthName {
704 n += len(label) + 1
706 n += 1 // For the terminating null label.
707 if n != 255 {
708 panic(fmt.Sprintf("max-length name is %d octets, should be %d %s", n, 255, maxLengthName))
712 queryLimit := uint16(limit)
713 if int(queryLimit) != limit {
714 queryLimit = 0xffff
716 query := &dns.Message{
717 Question: []dns.Question{
719 Name: maxLengthName,
720 Type: dns.RRTypeTXT,
721 Class: dns.RRTypeTXT,
724 // EDNS(0)
725 Additional: []dns.RR{
727 Name: dns.Name{},
728 Type: dns.RRTypeOPT,
729 Class: queryLimit, // requester's UDP payload size
730 TTL: 0, // extended RCODE and flags
731 Data: []byte{},
735 resp, _ := responseFor(query, dns.Name([][]byte{}))
736 // As in sendLoop.
737 resp.Answer = []dns.RR{
739 Name: query.Question[0].Name,
740 Type: query.Question[0].Type,
741 Class: query.Question[0].Class,
742 TTL: responseTTL,
743 Data: nil, // will be filled in below
747 // Binary search to find the maximum payload length that does not result
748 // in a wire-format message whose length exceeds the limit.
749 low := 0
750 high := 32768
751 for low+1 < high {
752 mid := (low + high) / 2
753 resp.Answer[0].Data = dns.EncodeRDataTXT(make([]byte, mid))
754 buf, err := resp.WireFormat()
755 if err != nil {
756 panic(err)
758 if len(buf) <= limit {
759 low = mid
760 } else {
761 high = mid
765 return low
768 func run(privkey []byte, domain dns.Name, upstream string, dnsConn net.PacketConn) error {
769 defer dnsConn.Close()
771 log.Printf("pubkey %x", noise.PubkeyFromPrivkey(privkey))
773 // We have a variable amount of room in which to encode downstream
774 // packets in each response, because each response must contain the
775 // query's Question section, which is of variable length. But we cannot
776 // give dynamic packet size limits to KCP; the best we can do is set a
777 // global maximum which no packet will exceed. We choose that maximum to
778 // keep the UDP payload size under maxUDPPayload, even in the worst case
779 // of a maximum-length name in the query's Question section.
780 maxEncodedPayload := computeMaxEncodedPayload(maxUDPPayload)
781 // 2 bytes accounts for a packet length prefix.
782 mtu := maxEncodedPayload - 2
783 if mtu < 80 {
784 if mtu < 0 {
785 mtu = 0
787 return fmt.Errorf("maximum UDP payload size of %d leaves only %d bytes for payload", maxUDPPayload, mtu)
789 log.Printf("effective MTU %d", mtu)
791 // Start up the virtual PacketConn for turbotunnel.
792 ttConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
793 ln, err := kcp.ServeConn(nil, 0, 0, ttConn)
794 if err != nil {
795 return fmt.Errorf("opening KCP listener: %v", err)
797 defer ln.Close()
798 go func() {
799 err := acceptSessions(ln, privkey, mtu, upstream)
800 if err != nil {
801 log.Printf("acceptSessions: %v", err)
805 ch := make(chan *record, 100)
806 defer close(ch)
808 // We could run multiple copies of sendLoop; that would allow more time
809 // for each response to collect downstream data before being evicted by
810 // another response that needs to be sent.
811 go func() {
812 err := sendLoop(dnsConn, ttConn, ch, maxEncodedPayload)
813 if err != nil {
814 log.Printf("sendLoop: %v", err)
818 return recvLoop(domain, dnsConn, ttConn, ch)
821 func main() {
822 var genKey bool
823 var privkeyFilename string
824 var privkeyString string
825 var pubkeyFilename string
826 var udpAddr string
828 flag.Usage = func() {
829 fmt.Fprintf(flag.CommandLine.Output(), `Usage:
830 %[1]s -gen-key -privkey-file PRIVKEYFILE -pubkey-file PUBKEYFILE
831 %[1]s -udp ADDR -privkey-file PRIVKEYFILE DOMAIN UPSTREAMADDR
833 Example:
834 %[1]s -gen-key -privkey-file server.key -pubkey-file server.pub
835 %[1]s -udp :53 -privkey-file server.key t.example.com 127.0.0.1:8000
837 `, os.Args[0])
838 flag.PrintDefaults()
840 flag.BoolVar(&genKey, "gen-key", false, "generate a server keypair; print to stdout or save to files")
841 flag.IntVar(&maxUDPPayload, "mtu", maxUDPPayload, "maximum size of DNS responses")
842 flag.StringVar(&privkeyString, "privkey", "", fmt.Sprintf("server private key (%d hex digits)", noise.KeyLen*2))
843 flag.StringVar(&privkeyFilename, "privkey-file", "", "read server private key from file (with -gen-key, write to file)")
844 flag.StringVar(&pubkeyFilename, "pubkey-file", "", "with -gen-key, write server public key to file")
845 flag.StringVar(&udpAddr, "udp", "", "UDP address to listen on (required)")
846 flag.Parse()
848 log.SetFlags(log.LstdFlags | log.LUTC)
850 if genKey {
851 // -gen-key mode.
852 if flag.NArg() != 0 || privkeyString != "" || udpAddr != "" {
853 flag.Usage()
854 os.Exit(1)
856 if err := generateKeypair(privkeyFilename, pubkeyFilename); err != nil {
857 fmt.Fprintf(os.Stderr, "cannot generate keypair: %v\n", err)
858 os.Exit(1)
860 } else {
861 // Ordinary server mode.
862 if flag.NArg() != 2 {
863 flag.Usage()
864 os.Exit(1)
866 domain, err := dns.ParseName(flag.Arg(0))
867 if err != nil {
868 fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
869 os.Exit(1)
871 upstream := flag.Arg(1)
872 // We keep upstream as a string in order to eventually pass it
873 // to net.Dial in handleStream. But for the sake of displaying
874 // an error or warning at startup, rather than only when the
875 // first stream occurs, we apply some parsing and name
876 // resolution checks here.
878 upstreamHost, _, err := net.SplitHostPort(upstream)
879 if err != nil {
880 // host:port format is required in all cases, so
881 // this is a fatal error.
882 fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: %v\n", upstream, err)
883 os.Exit(1)
885 upstreamIPAddr, err := net.ResolveIPAddr("ip", upstreamHost)
886 if err != nil {
887 // Failure to resolve the host portion is only a
888 // warning. The name will be re-resolved on each
889 // net.Dial in handleStream.
890 log.Printf("warning: cannot resolve upstream host %+q: %v", upstreamHost, err)
891 } else if upstreamIPAddr.IP == nil {
892 // Handle the special case of an empty string
893 // for the host portion, which resolves to a nil
894 // IP. This is a fatal error as we will not be
895 // able to dial this address.
896 fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: missing host in address\n", upstream)
897 os.Exit(1)
901 if udpAddr == "" {
902 fmt.Fprintf(os.Stderr, "the -udp option is required\n")
903 os.Exit(1)
905 dnsConn, err := net.ListenPacket("udp", udpAddr)
906 if err != nil {
907 fmt.Fprintf(os.Stderr, "opening UDP listener: %v\n", err)
908 os.Exit(1)
911 if pubkeyFilename != "" {
912 fmt.Fprintf(os.Stderr, "-pubkey-file may only be used with -gen-key\n")
913 os.Exit(1)
916 var privkey []byte
917 if privkeyFilename != "" && privkeyString != "" {
918 fmt.Fprintf(os.Stderr, "only one of -privkey and -privkey-file may be used\n")
919 os.Exit(1)
920 } else if privkeyFilename != "" {
921 var err error
922 privkey, err = readKeyFromFile(privkeyFilename)
923 if err != nil {
924 fmt.Fprintf(os.Stderr, "cannot read privkey from file: %v\n", err)
925 os.Exit(1)
927 } else if privkeyString != "" {
928 var err error
929 privkey, err = noise.DecodeKey(privkeyString)
930 if err != nil {
931 fmt.Fprintf(os.Stderr, "privkey format error: %v\n", err)
932 os.Exit(1)
935 if len(privkey) == 0 {
936 log.Println("generating a temporary one-time keypair")
937 log.Println("use the -privkey or -privkey-file option for a persistent server keypair")
938 var err error
939 privkey, err = noise.GeneratePrivkey()
940 if err != nil {
941 fmt.Fprintln(os.Stderr, err)
942 os.Exit(1)
946 err = run(privkey, domain, upstream, dnsConn)
947 if err != nil {
948 log.Fatal(err)