Remove noise.socket, now unused.
[champa.git] / champa-server / main.go
blob0ac2e85f3f54d3b2f214afe15c44af1155000ef7
1 package main
3 import (
4 "bytes"
5 "encoding/base64"
6 "errors"
7 "flag"
8 "fmt"
9 "io"
10 "log"
11 "net"
12 "net/http"
13 "os"
14 "path"
15 "strings"
16 "sync"
17 "time"
19 "github.com/xtaci/kcp-go/v5"
20 "github.com/xtaci/smux"
21 "www.bamsoftware.com/git/champa.git/armor"
22 "www.bamsoftware.com/git/champa.git/encapsulation"
23 "www.bamsoftware.com/git/champa.git/noise"
24 "www.bamsoftware.com/git/champa.git/turbotunnel"
// Tunnel-level timing.
const (
	// idleTimeout is how long an smux stream may go without receiving
	// data before it is closed.
	idleTimeout = 2 * time.Minute

	// maxResponseDelay bounds how long an HTTP response is held back
	// waiting for downstream data; once it elapses an empty response is
	// sent.
	maxResponseDelay = 100 * time.Millisecond

	// upstreamDialTimeout bounds how long establishing the TCP connection
	// to upstream may take.
	upstreamDialTimeout = 30 * time.Second
)

// net/http Server timeouts. Requests are expected to arrive through an AMP
// cache, so both requests and responses are small and never streamed.
const (
	// serverReadTimeout is Server.ReadTimeout: the maximum time allowed to
	// read an entire request, including the body.
	serverReadTimeout = 10 * time.Second

	// serverWriteTimeout is Server.WriteTimeout: the maximum time allowed
	// to write an entire response, including the body.
	serverWriteTimeout = 20 * time.Second

	// serverIdleTimeout is Server.IdleTimeout: how long a keep-alive HTTP
	// connection stays open awaiting another request.
	serverIdleTimeout = idleTimeout
)
53 // handleStream bidirectionally connects a client stream with a TCP socket
54 // addressed by upstream.
55 func handleStream(stream *smux.Stream, upstream string, conv uint32) error {
56 dialer := net.Dialer{
57 Timeout: upstreamDialTimeout,
59 upstreamConn, err := dialer.Dial("tcp", upstream)
60 if err != nil {
61 return fmt.Errorf("stream %08x:%d connect upstream: %v", conv, stream.ID(), err)
63 defer upstreamConn.Close()
64 upstreamTCPConn := upstreamConn.(*net.TCPConn)
66 var wg sync.WaitGroup
67 wg.Add(2)
68 go func() {
69 defer wg.Done()
70 _, err := io.Copy(stream, upstreamTCPConn)
71 if err == io.EOF {
72 // smux Stream.Write may return io.EOF.
73 err = nil
75 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
76 log.Printf("stream %08x:%d copy stream←upstream: %v", conv, stream.ID(), err)
78 upstreamTCPConn.CloseRead()
79 stream.Close()
80 }()
81 go func() {
82 defer wg.Done()
83 _, err := io.Copy(upstreamTCPConn, stream)
84 if err == io.EOF {
85 // smux Stream.WriteTo may return io.EOF.
86 err = nil
88 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
89 log.Printf("stream %08x:%d copy upstream←stream: %v", conv, stream.ID(), err)
91 upstreamTCPConn.CloseWrite()
92 }()
93 wg.Wait()
95 return nil
98 // acceptStreams wraps a KCP session in a Noise channel and an smux.Session,
99 // then awaits smux streams. It passes each stream to handleStream.
100 func acceptStreams(conn *kcp.UDPSession, upstream string) error {
101 // Put an smux session on top of the KCP connection.
102 smuxConfig := smux.DefaultConfig()
103 smuxConfig.Version = 2
104 smuxConfig.KeepAliveTimeout = idleTimeout
105 smuxConfig.MaxReceiveBuffer = 16 * 1024 * 1024 // default is 4 * 1024 * 1024
106 smuxConfig.MaxStreamBuffer = 1 * 1024 * 1024 // default is 65536
107 sess, err := smux.Server(conn, smuxConfig)
108 if err != nil {
109 return err
111 defer sess.Close()
113 for {
114 stream, err := sess.AcceptStream()
115 if err != nil {
116 if err, ok := err.(net.Error); ok && err.Temporary() {
117 continue
119 if err == io.ErrClosedPipe {
120 // We don't want to report this error.
121 err = nil
123 return err
125 log.Printf("begin stream %08x:%d", conn.GetConv(), stream.ID())
126 go func() {
127 defer func() {
128 log.Printf("end stream %08x:%d", conn.GetConv(), stream.ID())
129 stream.Close()
131 err := handleStream(stream, upstream, conn.GetConv())
132 if err != nil {
133 log.Printf("stream %08x:%d handleStream: %v", conn.GetConv(), stream.ID(), err)
139 // acceptSessions listens for incoming KCP connections and passes them to
140 // acceptStreams.
141 func acceptSessions(ln *kcp.Listener, upstream string) error {
142 for {
143 conn, err := ln.AcceptKCP()
144 if err != nil {
145 if err, ok := err.(net.Error); ok && err.Temporary() {
146 continue
148 return err
150 log.Printf("begin session %08x", conn.GetConv())
151 // Permit coalescing the payloads of consecutive sends.
152 conn.SetStreamMode(true)
153 // Disable the dynamic congestion window (limit only by the
154 // maximum of local and remote static windows).
155 conn.SetNoDelay(
156 0, // default nodelay
157 0, // default interval
158 0, // default resend
159 1, // nc=1 => congestion window off
161 conn.SetWindowSize(1024, 1024) // Default is 32, 32.
162 go func() {
163 defer func() {
164 log.Printf("end session %08x", conn.GetConv())
165 conn.Close()
167 err := acceptStreams(conn, upstream)
168 if err != nil && !errors.Is(err, io.ErrClosedPipe) {
169 log.Printf("session %08x acceptStreams: %v", conn.GetConv(), err)
175 type Handler struct {
176 pconn *turbotunnel.QueuePacketConn
179 // decodeRequest extracts a ClientID and a payload from an incoming HTTP
180 // request. In case of a decoding failure, the returned payload slice will be
181 // nil. The payload is always non-nil after a successful decoding, even if the
182 // payload is empty.
183 func decodeRequest(req *http.Request) (turbotunnel.ClientID, []byte) {
184 // Check the version indicator of the incoming client–server protocol.
185 switch {
186 case strings.HasPrefix(req.URL.Path, "/0"):
187 // Version "0"'s payload is base64-encoded, using the URL-safe
188 // alphabet without padding, in the final path component
189 // (earlier path components are ignored).
190 _, encoded := path.Split(req.URL.Path[2:]) // Remove "/0" prefix.
191 decoded, err := base64.RawURLEncoding.DecodeString(encoded)
192 if err != nil {
193 return turbotunnel.ClientID{}, nil
195 var clientID turbotunnel.ClientID
196 n := copy(clientID[:], decoded)
197 if n != len(clientID) {
198 return turbotunnel.ClientID{}, nil
200 payload := decoded[n:]
201 return clientID, payload
202 default:
203 return turbotunnel.ClientID{}, nil
207 func (handler *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
208 const maxPayloadLength = 5000
210 if req.Method != "GET" {
211 http.Error(rw, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
212 return
215 rw.Header().Set("Content-Type", "text/html")
216 // Attempt to hint to an AMP cache not to waste resources caching this
217 // document. "The Google AMP Cache considers any document fresh for at
218 // least 15 seconds."
219 // https://developers.google.com/amp/cache/overview#google-amp-cache-updates
220 rw.Header().Set("Cache-Control", "max-age=15")
221 rw.WriteHeader(http.StatusOK)
223 enc, err := armor.NewEncoder(rw)
224 if err != nil {
225 log.Printf("armor.NewEncoder: %v", err)
226 return
228 defer enc.Close()
230 clientID, payload := decodeRequest(req)
231 if payload == nil {
232 // Could not decode the client request. We do not even have a
233 // meaningful clientID or nonce. This may be a result of the
234 // client deliberately sending a short request for traffic
235 // shaping purposes. Send back a dummy, though still
236 // AMP-compatible, response.
237 // TODO: random padding.
238 return
241 // Read incoming packets from the payload.
242 r := bytes.NewReader(payload)
243 for {
244 p, err := encapsulation.ReadData(r)
245 if err != nil {
246 break
248 handler.pconn.QueueIncoming(p, clientID)
251 limit := maxPayloadLength
252 // We loop and bundle as many outgoing packets as will fit, up to
253 // maxPayloadLength. We wait up to maxResponseDelay for the first
254 // available packet; after that we only include whatever packets are
255 // immediately available.
256 timer := time.NewTimer(maxResponseDelay)
257 defer timer.Stop()
258 first := true
259 for {
260 var p []byte
261 unstash := handler.pconn.Unstash(clientID)
262 outgoing := handler.pconn.OutgoingQueue(clientID)
263 // Prioritize taking a packet first from the stash, then from
264 // the outgoing queue, then finally check for expiration of the
265 // timer. (We continue to bundle packets even after the timer
266 // expires, as long as the packets are immediately available.)
267 select {
268 case p = <-unstash:
269 default:
270 select {
271 case p = <-unstash:
272 case p = <-outgoing:
273 default:
274 select {
275 case p = <-unstash:
276 case p = <-outgoing:
277 case <-timer.C:
281 // We wait for the first packet only. Later packets must be
282 // immediately available.
283 timer.Reset(0)
285 if len(p) == 0 {
286 // Timer expired, we are done bundling packets into this
287 // response.
288 break
291 limit -= len(p)
292 if !first && limit < 0 {
293 // This packet doesn't fit in the payload size limit.
294 // Stash it so that it will be first in line for the
295 // next response.
296 handler.pconn.Stash(p, clientID)
297 break
299 first = false
301 // Write the packet to the AMP response.
302 _, err := encapsulation.WriteData(enc, p)
303 if err != nil {
304 log.Printf("encapsulation.WriteData: %v", err)
305 break
307 if rw, ok := rw.(http.Flusher); ok {
308 rw.Flush()
313 // noiseLoop is the Noise interface between an external noiseConn, which sends
314 // and receives encrypted Noise messages, and an internal plainConn, which sends
315 // and receives normal plaintext packets. This function tracks the state of
316 // Noise handshakes and a map of ongoing sessions, proxies packets between the
317 // connections while a session is active, and removes session from the map when
318 // they are finished.
319 func noiseLoop(noiseConn net.PacketConn, plainConn *turbotunnel.QueuePacketConn, privkey []byte) error {
320 sessions := make(map[turbotunnel.ClientID]*noise.Session)
321 var sessionsLock sync.RWMutex
323 for {
324 msgType, msg, addr, err := noise.ReadMessageFrom(noiseConn)
325 if err != nil {
326 if err, ok := err.(net.Error); ok && err.Temporary() {
327 continue
329 return err
332 sessionsLock.RLock()
333 sess := sessions[addr.(turbotunnel.ClientID)]
334 sessionsLock.RUnlock()
336 switch msgType {
337 // If the msgType of the incoming Noise message is
338 // MsgTypeHandshakeInit, send back a MsgTypeHandshakeResp and
339 // begin a new session for addr.
340 case noise.MsgTypeHandshakeInit:
341 if sess != nil {
342 // Already have a session for this addr.
343 continue
346 // Send back a MsgTypeHandshakeResp to permit the
347 // initiator to complete the Noise handshake.
348 p := []byte{noise.MsgTypeHandshakeResp}
349 sess, p, err := noise.AcceptHandshake(p, msg, privkey)
350 if err != nil {
351 return err
353 _, err = noiseConn.WriteTo(p, addr)
354 if err != nil {
355 if err, ok := err.(net.Error); ok && err.Temporary() {
356 continue
358 return err
361 // We have enough information at this point to start a
362 // session. Store it in the map.
363 sessionsLock.Lock()
364 sessions[addr.(turbotunnel.ClientID)] = sess
365 sessionsLock.Unlock()
367 // Start a goroutine for sending to the peer on this
368 // session. Reading from the peer is handled in the
369 // MsgTypeTransport case in the top-level switch.
370 go func() {
371 defer func() {
372 sessionsLock.Lock()
373 delete(sessions, addr.(turbotunnel.ClientID))
374 sessionsLock.Unlock()
376 for p := range plainConn.OutgoingQueue(addr) {
377 buf := []byte{noise.MsgTypeTransport}
378 buf, err := sess.Encrypt(buf, p)
379 if err != nil {
380 log.Printf("Encrypt: %v", err)
381 break
383 _, err = noiseConn.WriteTo(buf, addr)
384 if err != nil {
385 log.Printf("WriteTo: %v", err)
386 if err, ok := err.(net.Error); ok && err.Temporary() {
387 continue
389 break
394 // If the msgType of the incoming Noise message is
395 // MsgTypeTransport, decrypt the message and queue the contents
396 // with plainConn.
397 case noise.MsgTypeTransport:
398 if sess == nil {
399 // No session yet for this addr.
400 continue
402 p, err := sess.Decrypt(nil, msg)
403 if err != nil {
404 log.Printf("Decrypt: %v", err)
405 continue
407 plainConn.QueueIncoming(p, addr)
409 default:
410 log.Printf("unknown msgType %d", msgType)
415 func run(listen, upstream string, privkey []byte) error {
416 // noiseConn is the packet interface that communicates with the AMP/HTTP
417 // Handler; it deals in encrypted Noise messages. plainConn is the
418 // packet interface that communicates with KCP. noiseLoop sits in the
419 // middle, handling Noise handshakes and sessions, and
420 // encrypting/decrypting between the two net.PacketConns.
421 noiseConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
422 plainConn := turbotunnel.NewQueuePacketConn(turbotunnel.DummyAddr{}, idleTimeout*2)
423 go func() {
424 err := noiseLoop(noiseConn, plainConn, privkey)
425 if err != nil {
426 fmt.Printf("noiseLoop: %v", err)
430 ln, err := kcp.ServeConn(nil, 0, 0, plainConn)
431 if err != nil {
432 return fmt.Errorf("opening KCP listener: %v", err)
434 defer ln.Close()
435 go func() {
436 err := acceptSessions(ln, upstream)
437 if err != nil {
438 log.Printf("acceptSessions: %v", err)
442 handler := &Handler{
443 pconn: noiseConn,
446 server := &http.Server{
447 Addr: listen,
448 Handler: handler,
449 ReadTimeout: serverReadTimeout,
450 WriteTimeout: serverWriteTimeout,
451 IdleTimeout: serverIdleTimeout,
452 // The default MaxHeaderBytes is plenty for our purposes.
454 defer server.Close()
456 return server.ListenAndServe()
459 func main() {
460 var genKey bool
461 var privkeyFilename string
462 var privkeyString string
463 var pubkeyFilename string
465 flag.Usage = func() {
466 fmt.Fprintf(flag.CommandLine.Output(), `Usage:
467 %[1]s -gen-key -privkey-file PRIVKEYFILE -pubkey-file PUBKEYFILE
468 %[1]s -privkey-file PRIVKEYFILE LISTENADDR UPSTREAMADDR
470 Example:
471 %[1]s -gen-key -privkey-file server.key -pubkey-file server.pub
472 %[1]s -privkey-file server.key 127.0.0.1:8080 127.0.0.1:7001
474 `, os.Args[0])
475 flag.PrintDefaults()
477 flag.BoolVar(&genKey, "gen-key", false, "generate a server keypair; print to stdout or save to files")
478 flag.StringVar(&privkeyString, "privkey", "", fmt.Sprintf("server private key (%d hex digits)", noise.KeyLen*2))
479 flag.StringVar(&privkeyFilename, "privkey-file", "", "read server private key from file (with -gen-key, write to file)")
480 flag.StringVar(&pubkeyFilename, "pubkey-file", "", "with -gen-key, write server public key to file")
481 flag.Parse()
483 log.SetFlags(log.LstdFlags | log.LUTC)
485 if genKey {
486 // -gen-key mode.
488 if flag.NArg() != 0 || privkeyString != "" {
489 flag.Usage()
490 os.Exit(1)
492 if err := generateKeypair(privkeyFilename, pubkeyFilename); err != nil {
493 fmt.Fprintf(os.Stderr, "cannot generate keypair: %v\n", err)
494 os.Exit(1)
496 } else {
497 // Ordinary server mode.
499 if flag.NArg() != 2 {
500 flag.Usage()
501 os.Exit(1)
503 listen := flag.Arg(0)
504 upstream := flag.Arg(1)
505 // We keep upstream as a string in order to eventually pass it to
506 // net.Dial in handleStream. But we do a preliminary resolution of the
507 // name here, in order to exit with a quick error at startup if the
508 // address cannot be parsed or resolved.
510 upstreamTCPAddr, err := net.ResolveTCPAddr("tcp", upstream)
511 if err == nil && upstreamTCPAddr.IP == nil {
512 err = fmt.Errorf("missing host in address")
514 if err != nil {
515 fmt.Fprintf(os.Stderr, "cannot parse upstream address: %v\n", err)
516 os.Exit(1)
520 var privkey []byte
521 if privkeyFilename != "" && privkeyString != "" {
522 fmt.Fprintf(os.Stderr, "only one of -privkey and -privkey-file may be used\n")
523 os.Exit(1)
524 } else if privkeyFilename != "" {
525 var err error
526 privkey, err = readKeyFromFile(privkeyFilename)
527 if err != nil {
528 fmt.Fprintf(os.Stderr, "cannot read privkey from file: %v\n", err)
529 os.Exit(1)
531 } else if privkeyString != "" {
532 var err error
533 privkey, err = noise.DecodeKey(privkeyString)
534 if err != nil {
535 fmt.Fprintf(os.Stderr, "privkey format error: %v\n", err)
536 os.Exit(1)
538 } else {
539 log.Println("generating a temporary one-time keypair")
540 log.Println("use the -privkey or -privkey-file option for a persistent server keypair")
541 var err error
542 privkey, err = noise.GeneratePrivkey()
543 if err != nil {
544 fmt.Fprintln(os.Stderr, err)
545 os.Exit(1)
547 log.Printf("pubkey %x", noise.PubkeyFromPrivkey(privkey))
550 err := run(listen, upstream, privkey)
551 if err != nil {
552 log.Fatal(err)