Turn Files into a streaming RPC, reduce maximum message size limit
[debiancodesearch.git] / cmd / dcs-source-backend / source-backend.go
blob 21db062cffd2258a3140d38ae01866d34b6b63ac
// vim:ts=4:sw=4:noexpandtab
package main

import (
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"math/rand"
	"net/http"
	_ "net/http/pprof"
	"net/url"
	"os"
	"path"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/Debian/dcs/grpcutil"
	"github.com/Debian/dcs/proto"
	"github.com/Debian/dcs/ranking"
	"github.com/Debian/dcs/regexp"
	_ "github.com/Debian/dcs/varz"
	opentracing "github.com/opentracing/opentracing-go"
	olog "github.com/opentracing/opentracing-go/log"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/uber/jaeger-client-go"
	jaegercfg "github.com/uber/jaeger-client-go/config"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
)

var (
	listenAddress = flag.String("listen_address", ":28082", "listen address ([host]:port)")
	unpackedPath  = flag.String("unpacked_path",
		"/dcs-ssd/unpacked/",
		"Path to the unpacked sources")
	rankingDataPath = flag.String("ranking_data_path",
		"/var/dcs/ranking.json",
		"Path to the JSON containing ranking data")
	tlsCertPath = flag.String("tls_cert_path", "", "Path to a .pem file containing the TLS certificate.")
	tlsKeyPath  = flag.String("tls_key_path", "", "Path to a .pem file containing the TLS private key.")
	jaegerAgent = flag.String("jaeger_agent",
		"localhost:5775",
		"host:port of a github.com/uber/jaeger agent")

	indexBackend proto.IndexBackendClient
)

type SourceReply struct {
	// The number of the last used filename, needed for pagination
	LastUsedFilename int

	AllMatches []regexp.Match
}

type server struct {
}

// Serves a single file for displaying it in /show
func (s *server) File(ctx context.Context, in *proto.FileRequest) (*proto.FileReply, error) {
	log.Printf("requested filename *%s*\n", in.Path)
	// path.Join calls path.Clean so we get the shortest path without any "..".
	absPath := path.Join(*unpackedPath, in.Path)
	log.Printf("clean, absolute path is *%s*\n", absPath)
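	// Note that main() guarantees *unpackedPath ends in a trailing slash, so
	// this prefix check cannot be bypassed via a sibling directory such as
	// /dcs-ssd/unpacked-evil.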
	if !strings.HasPrefix(absPath, *unpackedPath) {
		return nil, fmt.Errorf("Path traversal is bad, mhkay?")
	}

	contents, err := ioutil.ReadFile(absPath)
	if err != nil {
		return nil, err
	}
	return &proto.FileReply{
		Contents: contents,
	}, nil
}

func filterByKeywords(rewritten *url.URL, files []ranking.ResultPath) []ranking.ResultPath {
	// The "package:" keyword, if specified.
	pkg := rewritten.Query().Get("package")
	// The "-package:" keywords, if specified.
	npkgs := rewritten.Query()["npackage"]
	// The "path:" keywords, if specified.
	paths := rewritten.Query()["path"]
	// The "-path:" keywords, if specified.
	npaths := rewritten.Query()["npath"]

	// Filter the filenames if the "package:" keyword was specified.
	if pkg != "" {
		fmt.Printf("Filtering for package %q\n", pkg)
		filtered := make(ranking.ResultPaths, 0, len(files))
		for _, file := range files {
			// XXX: Do we want this to be a regular expression match, too?
			if file.Path[file.SourcePkgIdx[0]:file.SourcePkgIdx[1]] != pkg {
				continue
			}

			filtered = append(filtered, file)
		}

		files = filtered
	}

	// Filter the filenames if the "-package:" keyword was specified.
	for _, npkg := range npkgs {
		fmt.Printf("Excluding matches for package %q\n", npkg)
		filtered := make(ranking.ResultPaths, 0, len(files))
		for _, file := range files {
			// XXX: Do we want this to be a regular expression match, too?
			if file.Path[file.SourcePkgIdx[0]:file.SourcePkgIdx[1]] == npkg {
				continue
			}

			filtered = append(filtered, file)
		}

		files = filtered
	}

	for _, path := range paths {
		fmt.Printf("Filtering for path %q\n", path)
		pathRegexp, err := regexp.Compile(path)
		if err != nil {
			return files
			// TODO: perform this validation before accepting the query, i.e. in dcs-web
			//err := common.Templates.ExecuteTemplate(w, "error.html", map[string]interface{}{
			//	"q":          r.URL.Query().Get("q"),
			//	"errormsg":   fmt.Sprintf(`%v`, err),
			//	"suggestion": template.HTML(`See <a href="http://codesearch.debian.net/faq#regexp">http://codesearch.debian.net/faq#regexp</a> for help on regular expressions.`),
			//})
			//if err != nil {
			//	http.Error(w, err.Error(), http.StatusInternalServerError)
			//}
		}

		filtered := make(ranking.ResultPaths, 0, len(files))
		for _, file := range files {
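			// MatchString returns the end position of the match within the
			// path, or -1 if the regexp does not match at all.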
			if pathRegexp.MatchString(file.Path, true, true) == -1 {
				continue
			}

			filtered = append(filtered, file)
		}

		files = filtered
	}

	for _, path := range npaths {
		fmt.Printf("Excluding matches for path %q\n", path)
		pathRegexp, err := regexp.Compile(path)
		if err != nil {
			return files
			// TODO: perform this validation before accepting the query, i.e. in dcs-web
			//err := common.Templates.ExecuteTemplate(w, "error.html", map[string]interface{}{
			//	"q":          r.URL.Query().Get("q"),
			//	"errormsg":   fmt.Sprintf(`%v`, err),
			//	"suggestion": template.HTML(`See <a href="http://codesearch.debian.net/faq#regexp">http://codesearch.debian.net/faq#regexp</a> for help on regular expressions.`),
			//})
			//if err != nil {
			//	http.Error(w, err.Error(), http.StatusInternalServerError)
			//}
		}

		filtered := make(ranking.ResultPaths, 0, len(files))
		for _, file := range files {
			if pathRegexp.MatchString(file.Path, true, true) != -1 {
				continue
			}

			filtered = append(filtered, file)
		}

		files = filtered
	}

	return files
}

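// sendProgressUpdate sends a progress update message on the stream. connMu
// serializes all writes to the stream: gRPC does not allow concurrent Send
// calls on the same stream from multiple goroutines.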
func sendProgressUpdate(stream proto.SourceBackend_SearchServer, connMu *sync.Mutex, filesProcessed, filesTotal int) error {
	connMu.Lock()
	defer connMu.Unlock()
	return stream.Send(&proto.SearchReply{
		Type: proto.SearchReply_PROGRESS_UPDATE,
		ProgressUpdate: &proto.ProgressUpdate{
			FilesProcessed: uint64(filesProcessed),
			FilesTotal:     uint64(filesTotal),
		},
	})
}

// Search performs the search for a single query and streams results back over
// the gRPC connection as they appear.
func (s *server) Search(in *proto.SearchRequest, stream proto.SourceBackend_SearchServer) error {
	ctx := stream.Context()
	connMu := new(sync.Mutex)
	logprefix := fmt.Sprintf("[%q]", in.Query)
	span := opentracing.SpanFromContext(ctx)

	// Ask the local index backend for all the filenames.
	fstream, err := indexBackend.Files(ctx, &proto.FilesRequest{Query: in.Query})
	if err != nil {
		return fmt.Errorf("%s Error querying index backend for query %q: %v\n", logprefix, in.Query, err)
	}

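	// Files is a streaming RPC: each message carries one possible filename,
	// so keep receiving until the index backend signals io.EOF.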
	var possible []string
	for {
		resp, err := fstream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		possible = append(possible, resp.Path)
	}

	span.LogFields(olog.Int("files.possible", len(possible)))

	// Parse the (rewritten) URL to extract all ranking options/keywords.
	rewritten, err := url.Parse(in.RewrittenUrl)
	if err != nil {
		return err
	}
	rankingopts := ranking.RankingOptsFromQuery(rewritten.Query())
	span.LogFields(olog.String("rankingopts", fmt.Sprintf("%+v", rankingopts)))

	// Rank all the paths.
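	// Paths whose ranking remains at -1 after Rank() are dropped entirely.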
	rankspan, _ := opentracing.StartSpanFromContext(ctx, "Rank")
	files := make(ranking.ResultPaths, 0, len(possible))
	for _, filename := range possible {
		result := ranking.ResultPath{Path: filename}
		result.Rank(&rankingopts)
		if result.Ranking > -1 {
			files = append(files, result)
		}
	}
	rankspan.Finish()

	// Filter all files that should be excluded.
	filterspan, _ := opentracing.StartSpanFromContext(ctx, "Filter")
	files = filterByKeywords(rewritten, files)
	filterspan.Finish()

	span.LogFields(olog.Int("files.filtered", len(files)))

	// While not strictly necessary, this will lead to better results being
	// discovered (and returned!) earlier, so let’s spend a few cycles on
	// sorting the list of potential files first.
	sort.Sort(files)

	re, err := regexp.Compile(in.Query)
	if err != nil {
		return fmt.Errorf("%s Could not compile regexp: %v\n", logprefix, err)
	}

	span.LogFields(olog.String("regexp", re.String()))

	log.Printf("%s regexp = %q, %d possible files\n", logprefix, re, len(files))

	// Send the first progress update so that clients know how many files are
	// going to be searched.
	if err := sendProgressUpdate(stream, connMu, 0, len(files)); err != nil {
		return fmt.Errorf("%s %v\n", logprefix, err)
	}

	// The tricky part here is “flow control”: if we just start grepping like
	// crazy, we will eventually run out of memory because all our writes are
	// blocked on the connection (and the goroutines need to keep the write
	// buffer in memory until the write is done).
	//
	// So instead, we start 1000 worker goroutines and feed them work through a
	// single channel. Due to these goroutines being blocked on writing,
	// the grepping will naturally become slower.
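	// Both channels are unbuffered, so the feeder goroutine below blocks
	// until a worker is ready to take the next file — exactly the
	// backpressure described above.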
	work := make(chan ranking.ResultPath)
	progress := make(chan int)

	var wg sync.WaitGroup
	// We add the additional 1 for the progress updater goroutine. It also
	// needs to be done before we can return, otherwise it will try to use the
	// (already closed) network connection, which is a fatal error.
	wg.Add(len(files) + 1)

	go func() {
		for _, file := range files {
			work <- file
		}
		close(work)
	}()

	go func() {
		cnt := 0
		errorShown := false
		var lastProgressUpdate time.Time
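		// Jitter the interval, presumably so that progress updates from
		// concurrently running queries do not all fire at the same time.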
		progressInterval := 2*time.Second + time.Duration(rand.Int63n(int64(500*time.Millisecond)))
		for cnt < len(files) {
			add := <-progress
			cnt += add

			if time.Since(lastProgressUpdate) > progressInterval {
				if err := sendProgressUpdate(stream, connMu, cnt, len(files)); err != nil {
					if !errorShown {
						log.Printf("%s %v\n", logprefix, err)
						// We need to read the 'progress' channel, so we cannot
						// just exit the loop here. Instead, we suppress all
						// error messages after the first one.
						errorShown = true
					}
				}
				lastProgressUpdate = time.Now()
			}
		}

		if err := sendProgressUpdate(stream, connMu, len(files), len(files)); err != nil {
			log.Printf("%s %v\n", logprefix, err)
		}
		close(progress)

		wg.Done()
	}()

	querystr := ranking.NewQueryStr(in.Query)

	numWorkers := 1000
	if len(files) < 1000 {
		numWorkers = len(files)
	}
	for i := 0; i < numWorkers; i++ {
		go func() {
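			// Each worker compiles its own copy of the regexp: a dcs/regexp
			// apparently cannot safely be shared between goroutines (note the
			// TODO below about safely cloning one).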
			re, err := regexp.Compile(in.Query)
			if err != nil {
				log.Printf("%s\n", err)
				return
			}

			grep := regexp.Grep{
				Regexp: re,
				Stdout: os.Stdout,
				Stderr: os.Stderr,
			}

			for file := range work {
				sourcePkgName := file.Path[file.SourcePkgIdx[0]:file.SourcePkgIdx[1]]
				if rankingopts.Pathmatch {
					file.Ranking += querystr.Match(&file.Path)
				}
				if rankingopts.Sourcepkgmatch {
					file.Ranking += querystr.Match(&sourcePkgName)
				}
				if rankingopts.Weighted {
					file.Ranking += 0.1460 * querystr.Match(&file.Path)
					file.Ranking += 0.0008 * querystr.Match(&sourcePkgName)
				}

				// TODO: figure out how to safely clone a dcs/regexp
				matches := grep.File(path.Join(*unpackedPath, file.Path))
				for _, match := range matches {
					match.Ranking = ranking.PostRank(rankingopts, &match, &querystr)
					match.PathRank = file.Ranking
					//match.Path = match.Path[len(*unpackedPath):]
					// NB: populating match.Ranking happens in
					// cmd/dcs-web/querymanager because it depends on at least
					// one other result.

					// TODO: ideally, we’d get proto.Match structs from grep.File(), let’s do that after profiling the decoding performance

					path := match.Path[len(*unpackedPath):]
					connMu.Lock()
					if err := stream.Send(&proto.SearchReply{
						Type: proto.SearchReply_MATCH,
						Match: &proto.Match{
							Path:     path,
							Line:     uint32(match.Line),
							Package:  path[:strings.Index(path, "/")],
							Ctxp2:    match.Ctxp2,
							Ctxp1:    match.Ctxp1,
							Context:  match.Context,
							Ctxn1:    match.Ctxn1,
							Ctxn2:    match.Ctxn2,
							Pathrank: match.PathRank,
							Ranking:  match.Ranking,
						},
					}); err != nil {
						connMu.Unlock()
						log.Printf("%s %v\n", logprefix, err)
						// Drain the work channel, but without doing any work.
						// This effectively exits the worker goroutine(s)
						// cleanly.
						for range work {
						}
						break
					}
					connMu.Unlock()
				}

				progress <- 1

				wg.Done()
			}
		}()
	}

	wg.Wait()

	log.Printf("%s Sent all results.\n", logprefix)
	return nil
}

func main() {
	log.SetFlags(log.LstdFlags | log.Lshortfile)
	flag.Parse()

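	// A "const" sampler with Param 1 samples every single trace.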
	cfg := jaegercfg.Configuration{
		Sampler: &jaegercfg.SamplerConfig{
			Type:  "const",
			Param: 1,
		},
		Reporter: &jaegercfg.ReporterConfig{
			BufferFlushInterval: 1 * time.Second,
			LocalAgentHostPort:  *jaegerAgent,
		},
	}
	closer, err := cfg.InitGlobalTracer(
		"dcs-source-backend",
		jaegercfg.Logger(jaeger.StdLogger),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer closer.Close()

	rand.Seed(time.Now().UnixNano())
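	// The trailing slash matters: the path traversal check in File() above
	// relies on *unpackedPath ending in a slash.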
	if !strings.HasSuffix(*unpackedPath, "/") {
		*unpackedPath = *unpackedPath + "/"
	}
	fmt.Println("Debian Code Search source-backend")

	if err := ranking.ReadRankingData(*rankingDataPath); err != nil {
		log.Fatal(err)
	}

	conn, err := grpcutil.DialTLS("localhost:28081", *tlsCertPath, *tlsKeyPath)
	if err != nil {
		log.Fatalf("could not connect to %q: %v", "localhost:28081", err)
	}
	defer conn.Close()
	indexBackend = proto.NewIndexBackendClient(conn)

	http.Handle("/metrics", prometheus.Handler())
	log.Fatal(grpcutil.ListenAndServeTLS(*listenAddress,
		*tlsCertPath,
		*tlsKeyPath,
		func(s *grpc.Server) {
			proto.RegisterSourceBackendServer(s, &server{})
		}))
}