CHANGELOG for v1.20210812.0.
[dnstt.git] / dnstt-server / main.go
blobcf40cb7e9ca081ff000746df4052bd2d67c0f8d3
1 // dnstt-server is the server end of a DNS tunnel.
2 //
3 // Usage:
4 // dnstt-server -gen-key [-privkey-file PRIVKEYFILE] [-pubkey-file PUBKEYFILE]
5 // dnstt-server -udp ADDR [-privkey PRIVKEY|-privkey-file PRIVKEYFILE] DOMAIN UPSTREAMADDR
6 //
7 // Example:
8 // dnstt-server -gen-key -privkey-file server.key -pubkey-file server.pub
9 // dnstt-server -udp :53 -privkey-file server.key t.example.com 127.0.0.1:8000
11 // To generate a persistent server private key, first run with the -gen-key
12 // option. By default the generated private and public keys are printed to
13 // standard output. To save them to files instead, use the -privkey-file and
14 // -pubkey-file options.
15 // dnstt-server -gen-key
16 // dnstt-server -gen-key -privkey-file server.key -pubkey-file server.pub
18 // You can give the server's private key as a file or as a hex string.
19 // -privkey-file server.key
20 // -privkey 0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
22 // The -udp option controls the address that will listen for incoming DNS
23 // queries.
25 // The -mtu option controls the maximum size of response UDP payloads.
26 // Queries that do not advertise requester support for responses of at least
27 // this size will be responded to with a FORMERR. The default
28 // value is maxUDPPayload.
30 // DOMAIN is the root of the DNS zone reserved for the tunnel. See README for
31 // instructions on setting it up.
33 // UPSTREAMADDR is the TCP address to which incoming tunnelled streams will be
34 // forwarded.
35 package main
37 import (
38 "bytes"
39 "encoding/base32"
40 "encoding/binary"
41 "errors"
42 "flag"
43 "fmt"
44 "io"
45 "io/ioutil"
46 "log"
47 "net"
48 "os"
49 "sync"
50 "time"
52 "github.com/xtaci/kcp-go/v5"
53 "github.com/xtaci/smux"
54 "www.bamsoftware.com/git/dnstt.git/dns"
55 "www.bamsoftware.com/git/dnstt.git/noise"
56 "www.bamsoftware.com/git/dnstt.git/turbotunnel"
const (
	// idleTimeout: smux streams will be closed after this much time
	// without receiving data. It is used as the smux KeepAliveTimeout in
	// acceptStreams, and twice this value is passed as the timeout of the
	// turbotunnel QueuePacketConn in run.
	idleTimeout = 2 * time.Minute

	// responseTTL is how to set the TTL field in Answer resource records.
	responseTTL = 60

	// maxResponseDelay is how long we may wait for downstream data before
	// sending an empty response. If another query comes in while we are
	// waiting, we'll send an empty response anyway and restart the delay
	// timer for the next response.
	//
	// This number should be less than 2 seconds, which in 2019 was reported
	// to be the query timeout of the Quad9 DoH server.
	// https://dnsencryption.info/imc19-doe.html Section 4.2, Finding 2.4
	maxResponseDelay = 1 * time.Second

	// upstreamDialTimeout is how long to wait for a TCP connection to
	// upstream to be established.
	upstreamDialTimeout = 30 * time.Second
)
var (
	// maxUDPPayload is the largest UDP payload we will send. We don't send
	// UDP payloads larger than this, in an attempt to avoid network-layer
	// fragmentation. 1280 is the minimum IPv6 MTU, 40 bytes is the size of
	// an IPv6 header (though without any extension headers), and 8 bytes is
	// the size of a UDP header.
	//
	// Control this value with the -mtu command-line option (main binds the
	// flag directly to this variable, which is why it is a var and not a
	// const).
	//
	// https://dnsflagday.net/2020/#message-size-considerations
	// "An EDNS buffer size of 1232 bytes will avoid fragmentation on nearly
	// all current networks."
	//
	// On 2020-04-19, the Quad9 resolver was seen to have a UDP payload size
	// of 1232. Cloudflare's was 1452, and Google's was 4096.
	maxUDPPayload = 1280 - 40 - 8
)
// base32Encoding is a base32 encoding without padding. responseFor uses it to
// decode the (upper-cased) labels that precede the tunnel domain in a query
// name.
var base32Encoding = base32.StdEncoding.WithPadding(base32.NoPadding)
100 // generateKeypair generates a private key and the corresponding public key. If
101 // privkeyFilename and pubkeyFilename are respectively empty, it prints the
102 // corresponding key to standard output; otherwise it saves the key to the given
103 // file name. The private key is saved with mode 0400 and the public key is
104 // saved with 0666 (before umask). In case of any error, it attempts to delete
105 // any files it has created before returning.
106 func generateKeypair(privkeyFilename, pubkeyFilename string) (err error) {
107 // Filenames to delete in case of error (avoid leaving partially written
108 // files).
109 var toDelete []string
110 defer func() {
111 for _, filename := range toDelete {
112 fmt.Fprintf(os.Stderr, "deleting partially written file %s\n", filename)
113 if closeErr := os.Remove(filename); closeErr != nil {
114 fmt.Fprintf(os.Stderr, "cannot remove %s: %v\n", filename, closeErr)
115 if err == nil {
116 err = closeErr
122 privkey, err := noise.GeneratePrivkey()
123 if err != nil {
124 return err
126 pubkey := noise.PubkeyFromPrivkey(privkey)
128 if privkeyFilename != "" {
129 // Save the privkey to a file.
130 f, err := os.OpenFile(privkeyFilename, os.O_RDWR|os.O_CREATE, 0400)
131 if err != nil {
132 return err
134 toDelete = append(toDelete, privkeyFilename)
135 err = noise.WriteKey(f, privkey)
136 if err2 := f.Close(); err == nil {
137 err = err2
139 if err != nil {
140 return err
144 if pubkeyFilename != "" {
145 // Save the pubkey to a file.
146 f, err := os.Create(pubkeyFilename)
147 if err != nil {
148 return err
150 toDelete = append(toDelete, pubkeyFilename)
151 err = noise.WriteKey(f, pubkey)
152 if err2 := f.Close(); err == nil {
153 err = err2
155 if err != nil {
156 return err
160 // All good, allow the written files to remain.
161 toDelete = nil
163 if privkeyFilename != "" {
164 fmt.Printf("privkey written to %s\n", privkeyFilename)
165 } else {
166 fmt.Printf("privkey %x\n", privkey)
168 if pubkeyFilename != "" {
169 fmt.Printf("pubkey written to %s\n", pubkeyFilename)
170 } else {
171 fmt.Printf("pubkey %x\n", pubkey)
174 return nil
177 // readKeyFromFile reads a key from a named file.
178 func readKeyFromFile(filename string) ([]byte, error) {
179 f, err := os.Open(filename)
180 if err != nil {
181 return nil, err
183 defer f.Close()
184 return noise.ReadKey(f)
187 // handleStream bidirectionally connects a client stream with a TCP socket
188 // addressed by upstream.
189 func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
190 dialer := net.Dialer{
191 Timeout: upstreamDialTimeout,
193 upstreamConn, err := dialer.Dial("tcp", upstream)
194 if err != nil {
195 return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
197 defer upstreamConn.Close()
198 upstreamTCPConn := upstreamConn.(*net.TCPConn)
200 var wg sync.WaitGroup
201 wg.Add(2)
202 go func() {
203 defer wg.Done()
204 _, err := io.Copy(stream, upstreamTCPConn)
205 if err == io.EOF {
206 // smux Stream.Write may return io.EOF.
207 err = nil
209 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
210 log.Printf("stream %08x:%d copy streamā†upstream: %v", conv, stream.ID(), err)
212 upstreamTCPConn.CloseRead()
213 stream.Close()
215 go func() {
216 defer wg.Done()
217 _, err := io.Copy(upstreamTCPConn, stream)
218 if err == io.EOF {
219 // smux Stream.WriteTo may return io.EOF.
220 err = nil
222 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
223 log.Printf("stream %08x:%d copy upstreamā†stream: %v", conv, stream.ID(), err)
225 upstreamTCPConn.CloseWrite()
227 wg.Wait()
229 return nil
232 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
233 // then awaits smux streams. It passes each stream to handleStream.
234 func acceptStreams(conn *kcp.UDPSession, privkey []byte, upstream string) error {
235 // Put a Noise channel on top of the KCP conn.
236 rw, err := noise.NewServer(conn, privkey)
237 if err != nil {
238 return err
241 // Put an smux session on top of the encrypted Noise channel.
242 smuxConfig := smux.DefaultConfig()
243 smuxConfig.Version = 2
244 smuxConfig.KeepAliveTimeout = idleTimeout
245 smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024 // default is 65536
246 sess, err := smux.Server(rw, smuxConfig)
247 if err != nil {
248 return err
250 defer sess.Close()
252 for {
253 stream, err := sess.AcceptStream()
254 if err != nil {
255 if err, ok := err.(net.Error); ok && err.Temporary() {
256 continue
258 return err
260 log.Printf("begin stream %08x:%d", conn.GetConv(), stream.ID())
261 go func() {
262 defer func() {
263 log.Printf("end stream %08x:%d", conn.GetConv(), stream.ID())
264 stream.Close()
266 err := handleStream(stream, upstream, conn.GetConv())
267 if err != nil {
268 log.Printf("stream %08x:%d handleStream: %v", conn.GetConv(), stream.ID(), err)
274 // acceptSessions listens for incoming KCP connections and passes them to
275 // acceptStreams.
276 func acceptSessions(ln *kcp.Listener, privkey []byte, mtu int, upstream string) error {
277 for {
278 conn, err := ln.AcceptKCP()
279 if err != nil {
280 if err, ok := err.(net.Error); ok && err.Temporary() {
281 continue
283 return err
285 log.Printf("begin session %08x", conn.GetConv())
286 // Permit coalescing the payloads of consecutive sends.
287 conn.SetStreamMode(true)
288 // Disable the dynamic congestion window (limit only by the
289 // maximum of local and remote static windows).
290 conn.SetNoDelay(
291 0, // default nodelay
292 0, // default interval
293 0, // default resend
294 1, // nc=1 => congestion window off
296 conn.SetWindowSize(turbotunnel.QueueSize/2, turbotunnel.QueueSize/2)
297 if rc := conn.SetMtu(mtu); !rc {
298 panic(rc)
300 go func() {
301 defer func() {
302 log.Printf("end session %08x", conn.GetConv())
303 conn.Close()
305 err := acceptStreams(conn, privkey, upstream)
306 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
307 log.Printf("session %08x acceptStreams: %v", conn.GetConv(), err)
// nextPacket returns the next length-prefixed data packet in r, silently
// skipping over any padding that precedes it.
//
// The error is nil only when a complete packet was read. It is io.EOF only
// when r was already exhausted at a point where a new prefix was expected, and
// io.ErrUnexpectedEOF when r runs out in the middle of padding or of a packet.
//
// Encoding: each item starts with a one-byte prefix L. When L < 0xe0 the item
// is a data packet of L bytes. When L >= 0xe0 the item is L - 0xe0 bytes of
// padding (the prefix byte itself is not counted).
func nextPacket(r *bytes.Reader) ([]byte, error) {
	// Smallest prefix value that denotes padding rather than data.
	const paddingBase = 0xe0

	for {
		prefix, err := r.ReadByte()
		if err != nil {
			// The only place a genuine io.EOF may escape.
			return nil, err
		}

		if prefix < paddingBase {
			// Data packet: read exactly prefix bytes.
			packet := make([]byte, int(prefix))
			_, err := io.ReadFull(r, packet)
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return packet, err
		}

		// Padding: discard it and look for the next prefix.
		skip := int64(prefix - paddingBase)
		if _, err := io.CopyN(ioutil.Discard, r, skip); err != nil {
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return nil, err
		}
	}
}
// responseFor constructs a response dns.Message that is appropriate for query.
// Along with the dns.Message, it returns the query's decoded data payload. If
// the returned dns.Message is nil, it means that there should be no response to
// this query. If the returned dns.Message has an Rcode() of dns.RcodeNoError,
// the message is a candidate for carrying downstream data in a TXT record.
//
// domain is the zone the server is authoritative for; the labels of the
// question name that precede domain are joined, upper-cased, and base32-decoded
// to produce the payload. The payload is non-nil only on the final, successful
// return; every error response is returned with a nil payload.
//
// The checks are applied in a fixed order, and the first one that fails
// determines the RCODE: multiple OPT RRs (FORMERR), unsupported EDNS version
// (BADVERS), question count (FORMERR), name not under domain (NXDOMAIN),
// non-QUERY opcode (NOTIMPL), non-TXT QTYPE (NXDOMAIN), base32 decoding
// (NXDOMAIN), and finally requester payload size below maxUDPPayload
// (FORMERR).
func responseFor(query *dns.Message, domain dns.Name) (*dns.Message, []byte) {
	resp := &dns.Message{
		ID:       query.ID,
		Flags:    0x8000, // QR = 1, RCODE = no error
		Question: query.Question,
	}

	if query.Flags&0x8000 != 0 {
		// QR != 0, this is not a query. Don't even send a response.
		return nil, nil
	}

	// Check for EDNS(0) support. Include our own OPT RR only if we receive
	// one from the requester.
	// https://tools.ietf.org/html/rfc6891#section-6.1.1
	// "Lack of presence of an OPT record in a request MUST be taken as an
	// indication that the requester does not implement any part of this
	// specification and that the responder MUST NOT include an OPT record
	// in its response."
	payloadSize := 0
	for _, rr := range query.Additional {
		if rr.Type != dns.RRTypeOPT {
			continue
		}
		// resp.Additional is non-empty only if an earlier iteration
		// already saw an OPT RR.
		if len(resp.Additional) != 0 {
			// https://tools.ietf.org/html/rfc6891#section-6.1.1
			// "If a query message with more than one OPT RR is
			// received, a FORMERR (RCODE=1) MUST be returned."
			resp.Flags |= dns.RcodeFormatError
			log.Printf("FORMERR: more than one OPT RR")
			return resp, nil
		}
		resp.Additional = append(resp.Additional, dns.RR{
			Name:  dns.Name{},
			Type:  dns.RRTypeOPT,
			Class: 4096, // responder's UDP payload size
			TTL:   0,
			Data:  []byte{},
		})
		additional := &resp.Additional[0]

		// The EDNS version is read from bits 16-23 of the OPT RR's TTL
		// field.
		version := (rr.TTL >> 16) & 0xff
		if version != 0 {
			// https://tools.ietf.org/html/rfc6891#section-6.1.1
			// "If a responder does not implement the VERSION level
			// of the request, then it MUST respond with
			// RCODE=BADVERS."
			// The extended RCODE is split: its low 4 bits go in the
			// header flags, the remaining bits in the top byte of
			// the OPT RR's TTL.
			resp.Flags |= dns.ExtendedRcodeBadVers & 0xf
			additional.TTL = (dns.ExtendedRcodeBadVers >> 4) << 24
			log.Printf("BADVERS: EDNS version %d != 0", version)
			return resp, nil
		}

		// The OPT RR's Class field carries the requester's UDP payload
		// size.
		payloadSize = int(rr.Class)
	}
	if payloadSize < 512 {
		// https://tools.ietf.org/html/rfc6891#section-6.1.1 "Values
		// lower than 512 MUST be treated as equal to 512."
		payloadSize = 512
	}
	// We will return RcodeFormatError if payloadSize is too small, but
	// first, check the name in order to set the AA bit properly.

	// There must be exactly one question.
	if len(query.Question) != 1 {
		resp.Flags |= dns.RcodeFormatError
		log.Printf("FORMERR: too few or too many questions (%d)", len(query.Question))
		return resp, nil
	}
	question := query.Question[0]
	// Check the name to see if it ends in our chosen domain, and extract
	// all that comes before the domain if it does. If it does not, we will
	// return RcodeNameError below, but prefer to return RcodeFormatError
	// for payload size if that applies as well.
	prefix, ok := question.Name.TrimSuffix(domain)
	if !ok {
		// Not a name we are authoritative for.
		resp.Flags |= dns.RcodeNameError
		log.Printf("NXDOMAIN: not authoritative for %s", question.Name)
		return resp, nil
	}
	resp.Flags |= 0x0400 // AA = 1

	if query.Opcode() != 0 {
		// We don't support OPCODE != QUERY.
		resp.Flags |= dns.RcodeNotImplemented
		log.Printf("NOTIMPL: unrecognized OPCODE %d", query.Opcode())
		return resp, nil
	}

	if question.Type != dns.RRTypeTXT {
		// We only support QTYPE == TXT.
		resp.Flags |= dns.RcodeNameError
		// No log message here; it's common for recursive resolvers to
		// send NS or A queries when the client only asked for a TXT. I
		// suspect this is related to QNAME minimization, but I'm not
		// sure. https://tools.ietf.org/html/rfc7816
		// log.Printf("NXDOMAIN: QTYPE %d != TXT", question.Type)
		return resp, nil
	}

	// Join the labels before the domain into one string and upper-case it
	// before decoding with the unpadded base32 encoding.
	encoded := bytes.ToUpper(bytes.Join(prefix, nil))
	payload := make([]byte, base32Encoding.DecodedLen(len(encoded)))
	n, err := base32Encoding.Decode(payload, encoded)
	if err != nil {
		// Base32 error, make like the name doesn't exist.
		resp.Flags |= dns.RcodeNameError
		log.Printf("NXDOMAIN: base32 decoding: %v", err)
		return resp, nil
	}
	payload = payload[:n]

	// We require clients to support EDNS(0) with a minimum payload size;
	// otherwise we would have to set a small KCP MTU (only around 200
	// bytes). https://tools.ietf.org/html/rfc6891#section-7 "If there is a
	// problem with processing the OPT record itself, such as an option
	// value that is badly formatted or that includes out-of-range values, a
	// FORMERR MUST be returned."
	if payloadSize < maxUDPPayload {
		resp.Flags |= dns.RcodeFormatError
		log.Printf("FORMERR: requester payload size %d is too small (minimum %d)", payloadSize, maxUDPPayload)
		return resp, nil
	}

	return resp, payload
}
// record represents a DNS message appropriate for a response to a previously
// received query, along with metadata necessary for sending the response.
// recvLoop sends instances of record to sendLoop via a channel. sendLoop
// receives instances of record and may fill in the message's Answer section
// before sending it.
type record struct {
	// Resp is the partially built response produced by responseFor.
	Resp *dns.Message
	// Addr is the source address of the query; the response is written
	// back to it.
	Addr net.Addr
	// ClientID is the client identifier extracted from the front of the
	// query's payload; sendLoop uses it to select that client's stash and
	// outgoing queue.
	ClientID turbotunnel.ClientID
}
493 // recvLoop repeatedly calls dnsConn.ReadFrom, extracts the packets contained in
494 // the incoming DNS queries, and puts them on ttConn's incoming queue. Whenever
495 // a query calls for a response, constructs a partial response and passes it to
496 // sendLoop over ch.
497 func recvLoop(domain dns.Name, dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch chan<- *record) error {
498 for {
499 var buf [4096]byte
500 n, addr, err := dnsConn.ReadFrom(buf[:])
501 if err != nil {
502 if err, ok := err.(net.Error); ok && err.Temporary() {
503 log.Printf("ReadFrom temporary error: %v", err)
504 continue
506 return err
509 // Got a UDP packet. Try to parse it as a DNS message.
510 query, err := dns.MessageFromWireFormat(buf[:n])
511 if err != nil {
512 log.Printf("cannot parse DNS query: %v", err)
513 continue
516 resp, payload := responseFor(&query, domain)
517 // Extract the ClientID from the payload.
518 var clientID turbotunnel.ClientID
519 n = copy(clientID[:], payload)
520 payload = payload[n:]
521 if n == len(clientID) {
522 // Discard padding and pull out the packets contained in
523 // the payload.
524 r := bytes.NewReader(payload)
525 for {
526 p, err := nextPacket(r)
527 if err != nil {
528 break
530 // Feed the incoming packet to KCP.
531 ttConn.QueueIncoming(p, clientID)
533 } else {
534 // Payload is not long enough to contain a ClientID.
535 if resp != nil && resp.Rcode() == dns.RcodeNoError {
536 resp.Flags |= dns.RcodeNameError
537 log.Printf("NXDOMAIN: %d bytes are too short to contain a ClientID", n)
540 // If a response is called for, pass it to sendLoop via the channel.
541 if resp != nil {
542 select {
543 case ch <- &record{resp, addr, clientID}:
544 default:
// sendLoop repeatedly receives records from ch. Those that represent an error
// response, it sends on the network immediately. Those that represent a
// response capable of carrying data, it packs full of as many packets as will
// fit while keeping the total size under maxEncodedPayload, then sends it.
//
// While waiting for the first downstream packet of a response, sendLoop also
// listens on ch; a record that arrives during that wait is held in nextRec and
// handled on the next iteration instead of being read from ch again.
//
// sendLoop returns nil when ch is closed, and returns the error when WriteTo
// fails with a non-temporary error.
func sendLoop(dnsConn net.PacketConn, ttConn *turbotunnel.QueuePacketConn, ch <-chan *record, maxEncodedPayload int) error {
	// nextRec holds a record received from ch while the previous response
	// was still waiting for data.
	var nextRec *record
	for {
		rec := nextRec
		nextRec = nil

		if rec == nil {
			var ok bool
			rec, ok = <-ch
			if !ok {
				break
			}
		}

		if rec.Resp.Rcode() == dns.RcodeNoError && len(rec.Resp.Question) == 1 {
			// If it's a non-error response, we can fill the Answer
			// section with downstream packets.

			// Any changes to how responses are built need to happen
			// also in computeMaxEncodedPayload.
			rec.Resp.Answer = []dns.RR{
				{
					Name:  rec.Resp.Question[0].Name,
					Type:  rec.Resp.Question[0].Type,
					Class: rec.Resp.Question[0].Class,
					TTL:   responseTTL,
					Data:  nil, // will be filled in below
				},
			}

			var payload bytes.Buffer
			// limit is the remaining room; it is decremented by the
			// size of each packet plus its 2-byte length prefix.
			limit := maxEncodedPayload
			// We loop and bundle as many packets from OutgoingQueue
			// into the response as will fit. Any packet that would
			// overflow the capacity of the DNS response, we stash
			// to be bundled into a future response.
			timer := time.NewTimer(maxResponseDelay)
			for {
				var p []byte
				unstash := ttConn.Unstash(rec.ClientID)
				outgoing := ttConn.OutgoingQueue(rec.ClientID)
				// Prioritize taking a packet first from the
				// stash, then from the outgoing queue, then
				// finally check for the expiration of the timer
				// or for a receive on ch (indicating a new
				// query that we must respond to).
				select {
				case p = <-unstash:
				default:
					select {
					case p = <-unstash:
					case p = <-outgoing:
					default:
						// Nothing immediately available:
						// block until a packet, the timer,
						// or a new query arrives.
						select {
						case p = <-unstash:
						case p = <-outgoing:
						case <-timer.C:
						case nextRec = <-ch:
						}
					}
				}
				// We wait for the first packet in a bundle
				// only. The second and later packets must be
				// immediately available or they will be omitted
				// from this bundle.
				timer.Reset(0)

				if len(p) == 0 {
					// timer expired or receive on ch, we
					// are done with this response.
					break
				}

				limit -= 2 + len(p)
				if payload.Len() == 0 {
					// No packet length check for the first
					// packet; if it's too large, we allow
					// it to be truncated and dropped by the
					// receiver.
				} else if limit < 0 {
					// Stash this packet to send in the next
					// response.
					ttConn.Stash(p, rec.ClientID)
					break
				}
				// The length prefix is 16 bits; a packet whose
				// length does not fit is a programming error.
				if int(uint16(len(p))) != len(p) {
					panic(len(p))
				}
				binary.Write(&payload, binary.BigEndian, uint16(len(p)))
				payload.Write(p)
			}
			timer.Stop()

			rec.Resp.Answer[0].Data = dns.EncodeRDataTXT(payload.Bytes())
		}

		buf, err := rec.Resp.WireFormat()
		if err != nil {
			log.Printf("resp WireFormat: %v", err)
			continue
		}
		// Truncate if necessary.
		// https://tools.ietf.org/html/rfc1035#section-4.1.1
		if len(buf) > maxUDPPayload {
			log.Printf("truncating response of %d bytes to max of %d", len(buf), maxUDPPayload)
			buf = buf[:maxUDPPayload]
			buf[2] |= 0x02 // TC = 1
		}

		// Now we actually send the message as a UDP packet.
		_, err = dnsConn.WriteTo(buf, rec.Addr)
		if err != nil {
			if err, ok := err.(net.Error); ok && err.Temporary() {
				log.Printf("WriteTo temporary error: %v", err)
				continue
			}
			return err
		}
	}
	return nil
}
676 // computeMaxEncodedPayload computes the maximum amount of downstream TXT RR
677 // data that keep the overall response size less than maxUDPPayload, in the
678 // worst case when the response answers a query that has a maximum-length name
679 // in its Question section. Returns 0 in the case that no amount of data makes
680 // the overall response size small enough.
682 // This function needs to be kept in sync with sendLoop with regard to how it
683 // builds candidate responses.
684 func computeMaxEncodedPayload(limit int) int {
685 // 64+64+64+62 octets, needs to be base32-decodable.
686 maxLengthName, err := dns.NewName([][]byte{
687 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
688 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
689 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
690 []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
692 if err != nil {
693 panic(err)
696 // Compute the encoded length of maxLengthName and that its
697 // length is actually at the maximum of 255 octets.
698 n := 0
699 for _, label := range maxLengthName {
700 n += len(label) + 1
702 n += 1 // For the terminating null label.
703 if n != 255 {
704 panic(fmt.Sprintf("max-length name is %d octets, should be %d %s", n, 255, maxLengthName))
708 queryLimit := uint16(limit)
709 if int(queryLimit) != limit {
710 queryLimit = 0xffff
712 query := &dns.Message{
713 Question: []dns.Question{
715 Name: maxLengthName,
716 Type: dns.RRTypeTXT,
717 Class: dns.RRTypeTXT,
720 // EDNS(0)
721 Additional: []dns.RR{
723 Name: dns.Name{},
724 Type: dns.RRTypeOPT,
725 Class: queryLimit, // requester's UDP payload size
726 TTL: 0, // extended RCODE and flags
727 Data: []byte{},
731 resp, _ := responseFor(query, dns.Name([][]byte{}))
732 // As in sendLoop.
733 resp.Answer = []dns.RR{
735 Name: query.Question[0].Name,
736 Type: query.Question[0].Type,
737 Class: query.Question[0].Class,
738 TTL: responseTTL,
739 Data: nil, // will be filled in below
743 // Binary search to find the maximum payload length that does not result
744 // in a wire-format message whose length exceeds the limit.
745 low := 0
746 high := 32768
747 for low+1 < high {
748 mid := (low + high) / 2
749 resp.Answer[0].Data = dns.EncodeRDataTXT(make([]byte, mid))
750 buf, err := resp.WireFormat()
751 if err != nil {
752 panic(err)
754 if len(buf) <= limit {
755 low = mid
756 } else {
757 high = mid
761 return low
764 func run(privkey []byte, domain dns.Name, upstream string, dnsConn net.PacketConn) error {
765 defer dnsConn.Close()
767 log.Printf("pubkey %x", noise.PubkeyFromPrivkey(privkey))
769 // We have a variable amount of room in which to encode downstream
770 // packets in each response, because each response must contain the
771 // query's Question section, which is of variable length. But we cannot
772 // give dynamic packet size limits to KCP; the best we can do is set a
773 // global maximum which no packet will exceed. We choose that maximum to
774 // keep the UDP payload size under maxUDPPayload, even in the worst case
775 // of a maximum-length name in the query's Question section.
776 maxEncodedPayload := computeMaxEncodedPayload(maxUDPPayload)
777 // 2 bytes accounts for a packet length prefix.
778 mtu := maxEncodedPayload - 2
779 if mtu < 80 {
780 if mtu < 0 {
781 mtu = 0
783 return fmt.Errorf("maximum UDP payload size of %d leaves only %d bytes for payload", maxUDPPayload, mtu)
785 log.Printf("effective MTU %d", mtu)
787 // Start up the virtual PacketConn for turbotunnel.
788 ttConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
789 ln, err := kcp.ServeConn(nil, 0, 0, ttConn)
790 if err != nil {
791 return fmt.Errorf("opening KCP listener: %v", err)
793 defer ln.Close()
794 go func() {
795 err := acceptSessions(ln, privkey, mtu, upstream)
796 if err != nil {
797 log.Printf("acceptSessions: %v", err)
801 ch := make(chan *record, 100)
802 defer close(ch)
804 // We could run multiple copies of sendLoop; that would allow more time
805 // for each response to collect downstream data before being evicted by
806 // another response that needs to be sent.
807 go func() {
808 err := sendLoop(dnsConn, ttConn, ch, maxEncodedPayload)
809 if err != nil {
810 log.Printf("sendLoop: %v", err)
814 return recvLoop(domain, dnsConn, ttConn, ch)
817 func main() {
818 var genKey bool
819 var privkeyFilename string
820 var privkeyString string
821 var pubkeyFilename string
822 var udpAddr string
824 flag.Usage = func() {
825 fmt.Fprintf(flag.CommandLine.Output(), `Usage:
826 %[1]s -gen-key -privkey-file PRIVKEYFILE -pubkey-file PUBKEYFILE
827 %[1]s -udp ADDR -privkey-file PRIVKEYFILE DOMAIN UPSTREAMADDR
829 Example:
830 %[1]s -gen-key -privkey-file server.key -pubkey-file server.pub
831 %[1]s -udp :53 -privkey-file server.key t.example.com 127.0.0.1:8000
833 `, os.Args[0])
834 flag.PrintDefaults()
836 flag.BoolVar(&genKey, "gen-key", false, "generate a server keypair; print to stdout or save to files")
837 flag.IntVar(&maxUDPPayload, "mtu", maxUDPPayload, "maximum size of DNS responses")
838 flag.StringVar(&privkeyString, "privkey", "", fmt.Sprintf("server private key (%d hex digits)", noise.KeyLen*2))
839 flag.StringVar(&privkeyFilename, "privkey-file", "", "read server private key from file (with -gen-key, write to file)")
840 flag.StringVar(&pubkeyFilename, "pubkey-file", "", "with -gen-key, write server public key to file")
841 flag.StringVar(&udpAddr, "udp", "", "UDP address to listen on (required)")
842 flag.Parse()
844 log.SetFlags(log.LstdFlags | log.LUTC)
846 if genKey {
847 // -gen-key mode.
848 if flag.NArg() != 0 || privkeyString != "" || udpAddr != "" {
849 flag.Usage()
850 os.Exit(1)
852 if err := generateKeypair(privkeyFilename, pubkeyFilename); err != nil {
853 fmt.Fprintf(os.Stderr, "cannot generate keypair: %v\n", err)
854 os.Exit(1)
856 } else {
857 // Ordinary server mode.
858 if flag.NArg() != 2 {
859 flag.Usage()
860 os.Exit(1)
862 domain, err := dns.ParseName(flag.Arg(0))
863 if err != nil {
864 fmt.Fprintf(os.Stderr, "invalid domain %+q: %v\n", flag.Arg(0), err)
865 os.Exit(1)
867 upstream := flag.Arg(1)
868 // We keep upstream as a string in order to eventually pass it
869 // to net.Dial in handleStream. But for the sake of displaying
870 // an error or warning at startup, rather than only when the
871 // first stream occurs, we apply some parsing and name
872 // resolution checks here.
874 upstreamHost, _, err := net.SplitHostPort(upstream)
875 if err != nil {
876 // host:port format is required in all cases, so
877 // this is a fatal error.
878 fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: %v\n", upstream, err)
879 os.Exit(1)
881 upstreamIPAddr, err := net.ResolveIPAddr("ip", upstreamHost)
882 if err != nil {
883 // Failure to resolve the host portion is only a
884 // warning. The name will be re-resolved on each
885 // net.Dial in handleStream.
886 log.Printf("warning: cannot resolve upstream host %+q: %v", upstreamHost, err)
887 } else if upstreamIPAddr.IP == nil {
888 // Handle the special case of an empty string
889 // for the host portion, which resolves to a nil
890 // IP. This is a fatal error as we will not be
891 // able to dial this address.
892 fmt.Fprintf(os.Stderr, "cannot parse upstream address %+q: missing host in address\n", upstream)
893 os.Exit(1)
897 if udpAddr == "" {
898 fmt.Fprintf(os.Stderr, "the -udp option is required\n")
899 os.Exit(1)
901 dnsConn, err := net.ListenPacket("udp", udpAddr)
902 if err != nil {
903 fmt.Fprintf(os.Stderr, "opening UDP listener: %v\n", err)
904 os.Exit(1)
907 if pubkeyFilename != "" {
908 fmt.Fprintf(os.Stderr, "-pubkey-file may only be used with -gen-key\n")
909 os.Exit(1)
912 var privkey []byte
913 if privkeyFilename != "" && privkeyString != "" {
914 fmt.Fprintf(os.Stderr, "only one of -privkey and -privkey-file may be used\n")
915 os.Exit(1)
916 } else if privkeyFilename != "" {
917 var err error
918 privkey, err = readKeyFromFile(privkeyFilename)
919 if err != nil {
920 fmt.Fprintf(os.Stderr, "cannot read privkey from file: %v\n", err)
921 os.Exit(1)
923 } else if privkeyString != "" {
924 var err error
925 privkey, err = noise.DecodeKey(privkeyString)
926 if err != nil {
927 fmt.Fprintf(os.Stderr, "privkey format error: %v\n", err)
928 os.Exit(1)
931 if len(privkey) == 0 {
932 log.Println("generating a temporary one-time keypair")
933 log.Println("use the -privkey or -privkey-file option for a persistent server keypair")
934 var err error
935 privkey, err = noise.GeneratePrivkey()
936 if err != nil {
937 fmt.Fprintln(os.Stderr, err)
938 os.Exit(1)
942 err = run(privkey, domain, upstream, dnsConn)
943 if err != nil {
944 log.Fatal(err)