Update link to codesearch tools (#93)
[debiancodesearch.git] / grpcutil / grpcutil.go
blob7d97255675f8d00d322131afbeaeb0793b01576f
// Package grpcutil encapsulates common gRPC client and server setup.
2 package grpcutil
4 import (
5 "crypto/tls"
6 "crypto/x509"
7 "flag"
8 "fmt"
9 "io/ioutil"
10 "net"
11 "net/http"
12 "strings"
14 "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
15 "golang.org/x/net/http2"
16 "golang.org/x/net/trace"
17 "google.golang.org/grpc"
18 "google.golang.org/grpc/credentials"
// requireClientAuth is set by the -tls_require_client_auth command-line flag
// (default true). When it is true, ListenAndServeTLS requires and verifies a
// TLS client certificate on every connection.
var requireClientAuth = flag.Bool("tls_require_client_auth", true, "Require TLS Client Authentication")
27 func init() {
28 // Disable grpc tracing until
29 // https://github.com/grpc/grpc-go/issues/695 is fixed.
30 grpc.EnableTracing = false
33 // grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
34 // connections or otherHandler otherwise. Copied from cockroachdb.
35 func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {
36 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37 // This is a partial recreation of gRPC's internal checks:
38 // https://github.com/grpc/grpc-go/blob/7834b974e55fbf85a5b01afb5821391c71084efd/transport/handler_server.go#L61
39 if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
40 grpcServer.ServeHTTP(w, r)
41 } else {
42 otherHandler.ServeHTTP(w, r)
47 func DialTLS(addr, certFile, keyFile string) (*grpc.ClientConn, error) {
48 cert, err := tls.LoadX509KeyPair(certFile, keyFile)
49 if err != nil {
50 return nil, err
52 roots := x509.NewCertPool()
53 contents, err := ioutil.ReadFile(certFile)
54 if err != nil {
55 return nil, err
57 if !roots.AppendCertsFromPEM(contents) {
58 return nil, fmt.Errorf("Could not parse %q as PEM file (contents: %q)", certFile, contents)
60 auth := credentials.NewTLS(&tls.Config{
61 RootCAs: roots,
62 Certificates: []tls.Certificate{cert}})
64 return grpc.Dial(addr,
65 grpc.WithTransportCredentials(auth),
66 grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor()),
67 grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor()))
70 func ListenAndServeTLS(addr, certFile, keyFile string, register func(s *grpc.Server)) error {
71 ln, err := net.Listen("tcp", addr)
72 if err != nil {
73 return err
76 auth, err := credentials.NewServerTLSFromFile(certFile, keyFile)
77 if err != nil {
78 return err
81 s := grpc.NewServer(
82 grpc.Creds(auth),
83 grpc.StreamInterceptor(grpc_opentracing.StreamServerInterceptor()),
84 grpc.UnaryInterceptor(grpc_opentracing.UnaryServerInterceptor()))
86 register(s)
88 http.Handle("/", s)
90 srv := http.Server{
91 Addr: addr,
92 Handler: grpcHandlerFunc(s, http.DefaultServeMux),
94 if err := http2.ConfigureServer(&srv, nil); err != nil {
95 return err
97 roots := x509.NewCertPool()
98 contents, err := ioutil.ReadFile(certFile)
99 if err != nil {
100 return err
102 if !roots.AppendCertsFromPEM(contents) {
103 return fmt.Errorf("Could not parse %q as PEM file (contents: %q)", certFile, contents)
106 if *requireClientAuth {
107 srv.TLSConfig.ClientCAs = roots
108 srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert
109 trace.AuthRequest = func(req *http.Request) (bool, bool) {
110 return true, true
113 srv.TLSConfig.Certificates = make([]tls.Certificate, 1)
114 srv.TLSConfig.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
115 return srv.Serve(tls.NewListener(ln, srv.TLSConfig))