diff --git a/cmd/frontend/graphqlbackend/repository_reindex.go b/cmd/frontend/graphqlbackend/repository_reindex.go
index 66e88a868797..352445b70786 100644
--- a/cmd/frontend/graphqlbackend/repository_reindex.go
+++ b/cmd/frontend/graphqlbackend/repository_reindex.go
@@ -2,9 +2,11 @@ package graphqlbackend

 import (
 	"context"
+	"fmt"

 	"github.com/graph-gophers/graphql-go"

+	"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf"
 	"github.com/sourcegraph/sourcegraph/internal/auth"
 	"github.com/sourcegraph/sourcegraph/internal/search/zoekt"
 )
@@ -13,6 +15,9 @@ import (
 func (r *schemaResolver) ReindexRepository(ctx context.Context, args *struct {
 	Repository graphql.ID
 }) (*EmptyResponse, error) {
+	// MARK(beyang): this is triggered by the "Reindex now" button on a page like https://sourcegraph.test:3443/github.com/hashicorp/errwrap/-/settings/index
+	fmt.Printf("# schemaResolver.ReindexRepository\n")
+
 	// 🚨 SECURITY: There is no reason why non-site-admins would need to run this operation.
 	if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.db); err != nil {
 		return nil, err
@@ -23,6 +28,10 @@ func (r *schemaResolver) ReindexRepository(ctx context.Context, args *struct {
 		return nil, err
 	}

+	if err := idf.Update(ctx, repo.RepoName()); err != nil {
+		return nil, err
+	}
+
 	err = zoekt.Reindex(ctx, repo.RepoName(), repo.IDInt32())
 	if err != nil {
 		return nil, err
diff --git a/cmd/frontend/internal/codycontext/context.go b/cmd/frontend/internal/codycontext/context.go
index a0906e58561d..24a0b53398c0 100644
--- a/cmd/frontend/internal/codycontext/context.go
+++ b/cmd/frontend/internal/codycontext/context.go
@@ -3,15 +3,19 @@ package codycontext
 import (
 	"context"
 	"fmt"
+	"sort"
 	"strings"
 	"sync"

+	lg "log"
+
 	"github.com/grafana/regexp"
 	"github.com/sourcegraph/conc/pool"
 	"github.com/sourcegraph/log"
 	"go.opentelemetry.io/otel/attribute"

 	"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/cody"
+	"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf"
 	"github.com/sourcegraph/sourcegraph/internal/api"
 	"github.com/sourcegraph/sourcegraph/internal/conf"
 	"github.com/sourcegraph/sourcegraph/internal/database"
@@ -82,6 +86,7 @@ type CodyContextClient struct {

 type GetContextArgs struct {
 	Repos            []types.RepoIDName
+	RepoStats        map[api.RepoName]*idf.StatsProvider
 	Query            string
 	CodeResultsCount int32
 	TextResultsCount int32
@@ -138,13 +143,15 @@ func (c *CodyContextClient) GetCodyContext(ctx context.Context, args GetContextA

 	embeddingsArgs := GetContextArgs{
 		Repos:            embeddingRepos,
+		RepoStats:        args.RepoStats,
 		Query:            args.Query,
 		CodeResultsCount: int32(float32(args.CodeResultsCount) * embeddingsResultRatio),
 		TextResultsCount: int32(float32(args.TextResultsCount) * embeddingsResultRatio),
 	}
 	keywordArgs := GetContextArgs{
-		Repos: keywordRepos,
-		Query: args.Query,
+		Repos:     keywordRepos,
+		RepoStats: args.RepoStats,
+		Query:     args.Query,
 		// Assign the remaining result budget to keyword search
 		CodeResultsCount: args.CodeResultsCount - embeddingsArgs.CodeResultsCount,
 		TextResultsCount: args.TextResultsCount - embeddingsArgs.TextResultsCount,
@@ -277,7 +284,11 @@ func (c *CodyContextClient) getKeywordContext(ctx context.Context, args GetConte
 	// mini-HACK: pass in the scope using repo: filters. In an ideal world, we
 	// would not be using query text manipulation for this and would be using
 	// the job structs directly.
-	keywordQuery := fmt.Sprintf(`repo:%s %s %s`, reposAsRegexp(args.Repos), getKeywordContextExcludeFilePathsQuery(), args.Query)
+	const maxTermsPerWord = 5
+	transformedQuery := getTransformedQuery(args, maxTermsPerWord)
+	lg.Printf("# userQuery -> transformedQuery: %q -> %q", args.Query, transformedQuery)
+	fmt.Printf("# userQuery -> transformedQuery: %q -> %q\n", args.Query, transformedQuery)
+	keywordQuery := fmt.Sprintf(`repo:%s %s %s`, reposAsRegexp(args.Repos), getKeywordContextExcludeFilePathsQuery(), transformedQuery)

 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
@@ -371,3 +382,51 @@ func fileMatchToContextMatch(fm *result.FileMatch) FileChunkContext {
 		StartLine: startLine,
 	}
 }
+
+func getTransformedQuery(args GetContextArgs, maxTermsPerWord int) string {
+	if args.RepoStats == nil {
+		lg.Printf("# no stats set")
+		return args.Query
+	}
+
+	for _, repo := range args.Repos {
+		if _, ok := args.RepoStats[repo.Name]; !ok {
+			// Don't transform the query if any repository lacks an IDF table.
+			lg.Printf("# didn't find stats for repo %s", repo.Name)
+			return args.Query
+		}
+	}
+
+	// TODO(rishabh): currently we just pick the top-k vocabulary terms by IDF
+	// score; a better semantic ranking of terms is possible. Matching is also
+	// limited to substring containment; stemming/lemmatization might be worth
+	// considering.
+
+	var filteredToks []string
+
+	type termScore struct {
+		term  string
+		score float32
+	}
+
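+	// For each sufficiently long query word, gather the vocabulary terms from
+	// every repo's IDF table that contain the word as a substring, rank the
+	// matches by IDF score (rarest first), and keep the top maxTermsPerWord
+	// terms per word.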
+	for _, word := range strings.Fields(args.Query) {
+		if len(word) < 4 {
+			continue
+		}
+		var matches []termScore
+		for _, stats := range args.RepoStats {
+			for term, score := range stats.GetTerms() {
+				if strings.Contains(term, word) && len(term) > 4 && score > 3 {
+					matches = append(matches, termScore{term: term, score: score})
+				}
+			}
+		}
+		sort.Slice(matches, func(i, j int) bool {
+			return matches[i].score > matches[j].score
+		})
+		for i := 0; i < min(maxTermsPerWord, len(matches)); i++ {
+			filteredToks = append(filteredToks, matches[i].term)
+		}
+	}
+
+	if len(filteredToks) == 0 {
+		// Fall back to the original query rather than issuing an empty search.
+		return args.Query
+	}
+	return strings.Join(filteredToks, " ")
+}
diff --git a/cmd/frontend/internal/context/resolvers/context.go b/cmd/frontend/internal/context/resolvers/context.go
index 207767ced54b..5698b434ede9 100644
--- a/cmd/frontend/internal/context/resolvers/context.go
+++ b/cmd/frontend/internal/context/resolvers/context.go
@@ -5,6 +5,7 @@ import (
 	"context"
 	"encoding/json"
 	"io"
+	lg "log"
 	"net/http"
 	"time"

@@ -13,9 +14,11 @@ import (
 	"github.com/sourcegraph/conc/iter"
 	"github.com/sourcegraph/conc/pool"
 	"github.com/sourcegraph/log"
+
 	"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
 	"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/cody"
 	"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/codycontext"
+	"github.com/sourcegraph/sourcegraph/cmd/frontend/internal/search/idf"
 	"github.com/sourcegraph/sourcegraph/internal/api"
 	"github.com/sourcegraph/sourcegraph/internal/conf"
 	"github.com/sourcegraph/sourcegraph/internal/database"
@@ -183,6 +186,7 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo
 	}

 	repoNameIDs := make([]types.RepoIDName, len(repoIDs))
+	repoStats := make(map[api.RepoName]*idf.StatsProvider)
 	for i, repoID := range repoIDs {
 		repo, ok := repos[repoID]
 		if !ok {
@@ -191,10 +195,18 @@ func (r *Resolver) GetCodyContext(ctx context.Context, args graphqlbackend.GetCo
 		}

 		repoNameIDs[i] = types.RepoIDName{ID: repoID, Name: repo.Name}
+
+		stats, err := idf.Get(ctx, repo.Name)
+		if err != nil {
+			lg.Printf("Unexpected error getting IDF index value for repo %v: %v", repoID, err)
+			continue
+		}
+		if stats == nil {
+			// No IDF table has been computed for this repo yet.
+			continue
+		}
+		repoStats[repo.Name] = stats
 	}

 	fileChunks, err := r.contextClient.GetCodyContext(ctx, codycontext.GetContextArgs{
 		Repos:            repoNameIDs,
+		RepoStats:        repoStats,
 		Query:            args.Query,
 		CodeResultsCount: args.CodeResultsCount,
 		TextResultsCount: args.TextResultsCount,
diff --git a/cmd/frontend/internal/search/idf/idf.go b/cmd/frontend/internal/search/idf/idf.go
new file mode 100644
index 000000000000..b5347984713d
--- /dev/null
+++ b/cmd/frontend/internal/search/idf/idf.go
@@ -0,0 +1,164 @@
+// Package idf computes and stores the inverse document frequency (IDF) of a set of repositories.
+//
+// TODO(beyang): should probably move this elsewhere
+package idf
+
+import (
+	"archive/tar"
+	"bufio"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"log"
+	"math"
+	"path"
+	"strings"
+	"unicode"
+
+	"github.com/sourcegraph/sourcegraph/internal/api"
+	"github.com/sourcegraph/sourcegraph/internal/gitserver"
+	"github.com/sourcegraph/sourcegraph/internal/rcache"
+	"github.com/sourcegraph/sourcegraph/internal/redispool"
+	"github.com/sourcegraph/sourcegraph/lib/errors"
+)
+
+// Cached IDF tables expire after ten days (the TTL is in seconds).
+var redisCache = rcache.NewWithTTL(redispool.Cache, "idf-index", 10*24*60*60)
+
+func Update(ctx context.Context, repoName api.RepoName) error {
+	fmt.Printf("# idf.Update(%v)\n", repoName)
+
+	stats := NewStatsAggregator()
+
+	git := gitserver.NewClient("idf-indexer")
+	r, err := git.ArchiveReader(ctx, repoName, gitserver.ArchiveOptions{Treeish: "HEAD", Format: gitserver.ArchiveFormatTar})
+	if err != nil {
+		return err
+	}
+	defer r.Close()
+
+	permissibleExtensions := map[string]bool{
+		".py": true, ".js": true, ".ts": true, ".java": true, ".cpp": true,
+		".c": true, ".cs": true, ".go": true, ".rb": true, ".rs": true,
+		".php": true, ".html": true, ".css": true, ".scss": true, ".md": true,
+		".sh": true, ".swift": true, ".kt": true, ".m": true,
+	}
+
+	tr := tar.NewReader(r)
+	for {
+		header, err := tr.Next()
+		if err == io.EOF {
+			break // End of archive
+		}
+		if err != nil {
+			log.Printf("Error reading next tar header: %v", err)
+			continue
+		}
+
+		// Skip directories
+		if header.Typeflag == tar.TypeDir {
+			continue
+		}
+
+		// Check if the file has a permissible extension
+		ext := strings.ToLower(path.Ext(header.Name))
+		if !permissibleExtensions[ext] {
+			continue
+		}
+
+		// Read the first line of the file
+		scanner := bufio.NewScanner(tr)
+		if scanner.Scan() {
+			stats.ProcessDoc(scanner.Text())
+		} else if err := scanner.Err(); err != nil {
+			log.Printf("Error reading file content: %v", err)
+		}
+	}
+
+	statsP := stats.EvalProvider()
+	statsBytes, err := json.Marshal(statsP)
+	if err != nil {
+		return errors.Wrap(err, "idf.Update: failed to marshal IDF table")
+	}
+
+	log.Printf("# storing stats: %s", string(statsBytes))
+
+	redisCache.Set(fmt.Sprintf("repo:%v", repoName), statsBytes)
+	return nil
+}
+
+// Get returns the cached IDF table for a repo, or (nil, nil) if none exists.
+func Get(ctx context.Context, repoName api.RepoName) (*StatsProvider, error) {
+	fmt.Printf("# idf.Get(%v)\n", repoName)
+	b, ok := redisCache.Get(fmt.Sprintf("repo:%v", repoName))
+	if !ok {
+		return nil, nil
+	}
+
+	var stats StatsProvider
+	if err := json.Unmarshal(b, &stats); err != nil {
+		return nil, errors.Wrap(err, "idf.Get: failed to unmarshal IDF table")
+	}

+	log.Printf("# fetching stats: %v", stats)
+
+	return &stats, nil
+}
+
+type StatsAggregator struct {
+	TermToDocCt map[string]int
+	DocCt       int
+}
+
+func NewStatsAggregator() *StatsAggregator {
+	return &StatsAggregator{
+		TermToDocCt: make(map[string]int),
+	}
+}
+
+func isValidWord(word string) bool {
+	if len(word) < 3 || len(word) > 20 {
+		return false
+	}
+	hasLetter := false
+	for _, char := range word {
+		if !unicode.IsLetter(char) && !unicode.IsNumber(char) {
+			return false
+		}
+		if unicode.IsLetter(char) {
+			hasLetter = true
+		}
+	}
+	return hasLetter
+}
+
+func (s *StatsAggregator) ProcessDoc(text string) {
+	words := strings.Fields(text)
+	for _, word := range words {
+		// Case is deliberately preserved so camelCase identifiers stay distinct.
+		if isValidWord(word) {
+			s.TermToDocCt[word]++
+		}
+	}
+	s.DocCt++
+}
+
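+// EvalProvider converts the raw document counts into per-term IDF scores:
+// idf(t) = ln(N / (1 + df(t))), where N is the number of documents processed
+// and df(t) is the number of documents containing term t. The +1 is the usual
+// smoothing term; note it makes the score slightly negative for terms that
+// appear in every document.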
+func (s *StatsAggregator) EvalProvider() StatsProvider {
+	idf := make(map[string]float32)
+	for term, docCt := range s.TermToDocCt {
+		idf[term] = float32(math.Log(float64(s.DocCt) / (1.0 + float64(docCt))))
+	}
+	return StatsProvider{IDF: idf}
+}
+
+type StatsProvider struct {
+	IDF map[string]float32
+}
+
+func (s *StatsProvider) GetIDF(term string) float32 {
+	// The IDF table preserves case (see ProcessDoc), so the lookup must not
+	// lowercase the term.
+	return s.IDF[term]
+}
+
+func (s *StatsProvider) GetTerms() map[string]float32 {
+	return s.IDF
+}
diff --git a/cmd/frontend/internal/search/idf/tokenize.go b/cmd/frontend/internal/search/idf/tokenize.go
new file mode 100644
index 000000000000..910a28f7fc09
--- /dev/null
+++ b/cmd/frontend/internal/search/idf/tokenize.go
@@ -0,0 +1,63 @@
+package idf
+
+import (
+	"regexp"
+	"strings"
+)
+
+var (
+	camelStartRe = regexp.MustCompile(`^[A-Za-z][^A-Z]+`)
+	capStartRe   = regexp.MustCompile(`^[A-Z][A-Z0-9]*`)
+)
+
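+// tokenizeCamelCase splits an identifier at camel-case boundaries. A maximal
+// run of capitals donates its final capital to the following token, so
+// "HTMLParser" yields ["HTML", "Parser"] rather than ["HTMLP", "arser"].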
+func tokenizeCamelCase(s string) []string {
+	remainder := s
+	var toks []string
+	for len(remainder) > 0 {
+		if found := camelStartRe.FindString(remainder); found != "" {
+			toks = append(toks, found)
+			remainder = remainder[len(found):]
+			continue
+		}
+		if found := capStartRe.FindString(remainder); found != "" {
+			if len(found) == 1 || len(found) == len(remainder) {
+				toks = append(toks, found)
+				remainder = remainder[len(found):]
+			} else {
+				toks = append(toks, found[:len(found)-1])
+				remainder = remainder[len(found)-1:]
+			}
+			continue
+		}
+		remainder = remainder[1:]
+	}
+	return toks
+}
+
+func tokenizeSnakeCase(s string) []string {
+	return strings.Split(s, "_")
+}
+
+var (
+	sepRe = regexp.MustCompile(`([[:punct:]]|\s)+`)
+)
+
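+// TokenizeWord splits a single identifier on snake_case and camelCase
+// boundaries, e.g. TokenizeWord("fooBar_baz") == ["foo", "Bar", "baz"].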
+func TokenizeWord(w string) []string {
+	var toks []string
+	for _, part := range tokenizeSnakeCase(w) {
+		toks = append(toks, tokenizeCamelCase(part)...)
+	}
+	return toks
+}
+
+func Tokenize(s string) []string {
+	var toks []string
+	for _, word := range Words(s) {
+		toks = append(toks, TokenizeWord(word)...)
+	}
+	return toks
+}
+
+func Words(s string) []string {
+	return sepRe.Split(s, -1)
+}
diff --git a/cmd/frontend/internal/search/idf/tokenize_test.go b/cmd/frontend/internal/search/idf/tokenize_test.go
new file mode 100644
index 000000000000..d03b37295de1
--- /dev/null
+++ b/cmd/frontend/internal/search/idf/tokenize_test.go
@@ -0,0 +1,69 @@
+package idf
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestTokenizeCamelCase(t *testing.T) {
+	type testCase struct {
+		s       string
+		expToks []string
+	}
+	cases := []testCase{
+		{
+			s:       "FooBar",
+			expToks: []string{"Foo", "Bar"},
+		},
+		{
+			s:       "fooBarBaz",
+			expToks: []string{"foo", "Bar", "Baz"},
+		},
+		{
+			s:       "HTMLParser",
+			expToks: []string{"HTML", "Parser"},
+		},
+		{
+			s:       "parseHTML",
+			expToks: []string{"parse", "HTML"},
+		},
+		{
+			s:       "HTML5Parser",
+			expToks: []string{"HTML5", "Parser"},
+		},
+		{
+			s:       "parseHTML5",
+			expToks: []string{"parse", "HTML5"},
+		},
+	}
+	for _, c := range cases {
+		toks := tokenizeCamelCase(c.s)
+		if diff := cmp.Diff(c.expToks, toks); diff != "" {
+			t.Errorf("unexpected tokens (-want +got):\n%s", diff)
+		}
+	}
+}
+
+func TestTokenize(t *testing.T) {
+	type testCase struct {
+		s       string
+		expToks []string
+	}
+	cases := []testCase{
+		{
+			s:       "camelCase.snake_case + _weird_.",
+			expToks: []string{"camel", "Case", "snake", "case", "weird"},
+		},
+		{
+			s:       "two words camelCase--!:@withPunctuation and_snake_case",
+			expToks: []string{"two", "words", "camel", "Case", "with", "Punctuation", "and", "snake", "case"},
+		},
+	}
+	for _, c := range cases {
+		toks := Tokenize(c.s)
+		if diff := cmp.Diff(c.expToks, toks); diff != "" {
+			t.Errorf("unexpected tokens (-want +got):\n%s", diff)
+		}
+	}
+}
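+
+// Illustrative extra test (not part of the original change set): TokenizeWord
+// on an identifier mixing snake_case and camelCase boundaries.
+func TestTokenizeWord(t *testing.T) {
+	toks := TokenizeWord("fooBar_baz")
+	expToks := []string{"foo", "Bar", "baz"}
+	if diff := cmp.Diff(expToks, toks); diff != "" {
+		t.Errorf("unexpected tokens (-want +got):\n%s", diff)
+	}
+}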