diff --git a/.air.toml b/.air.toml index 038991a..fae7f37 100644 --- a/.air.toml +++ b/.air.toml @@ -3,9 +3,9 @@ testdata_dir = "testdata" tmp_dir = "tmp" [build] - args_bin = ["serve"] - bin = "./tmp/main" - cmd = "go build -o ./tmp/main ./cmd/tld" + args_bin = ["watch"] + bin = "./tmp/tlddebug" + cmd = "go build -o ./tmp/tlddebug ./cmd/tld" delay = 1000 exclude_dir = ["assets", "tmp", "vendor", "frontend", "data"] exclude_file = [] diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6bafe07..9b2eaba 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,6 +8,7 @@ on: permissions: contents: read + checks: write jobs: test: @@ -49,11 +50,16 @@ jobs: git diff --exit-code - name: Run unit tests with race detection - run: go test -race -shuffle=on -coverprofile=coverage.txt ./... + run: | + go install gotest.tools/gotestsum@latest + gotestsum \ + --junitfile results.xml \ + --format testdox \ + ./... - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v5 + - name: Publish Test Results + uses: dorny/test-reporter@v1 with: - token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.txt - fail_ci_if_error: false + name: Go Tests + path: results.xml + reporter: java-junit diff --git a/.gitignore b/.gitignore index cba93e3..aa67864 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ frontend/public tmp frontend/*.tgz .agents +.claude/skills/impeccable +skills-lock.json +.DS_Store diff --git a/Makefile b/Makefile index 5c0532a..68238da 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: frontend-deps frontend-build lint-be lint-fe build run clean dev test-backend build-go setup-hooks make-be make-fe +.PHONY: frontend-deps frontend-build lint-be lint-fe build run clean dev dev-stop test-backend build-go setup-hooks make-be make-fe setup-hooks: chmod +x scripts/pre-commit.sh @@ -28,6 +28,10 @@ dev: @echo "Starting development stack..." @$(MAKE) -j 2 be fe +dev-stop: + @echo "Stopping development backend..." + -pkill -x tlddebug + proto: ## Update go.mod to latest BSR-published proto versions (run after buf push in proto/) go get buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go@$(shell buf registry sdk version --module=buf.build/tldiagramcom/diagram --plugin=buf.build/protocolbuffers/go) go get buf.build/gen/go/tldiagramcom/diagram/connectrpc/go@$(shell buf registry sdk version --module=buf.build/tldiagramcom/diagram --plugin=buf.build/connectrpc/go) diff --git a/README.md b/README.md index 77088ba..7a299c4 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -# tld +[![Logo](./frontend/logo/tld.svg)](https://tldiagram.com) [![Go Version](https://img.shields.io/github/go-mod/go-version/mertcikla/tld)](https://go.dev/) [![License](https://img.shields.io/github/license/mertcikla/tld)](./LICENSE) [![Build Status](https://img.shields.io/github/actions/workflow/status/mertcikla/tld/test.yml?branch=main)](https://github.com/mertcikla/tld/actions) [![Go Report Card](https://goreportcard.com/badge/github.com/mertcikla/tld)](https://goreportcard.com/report/github.com/mertcikla/tld) -`tld` provides a complete software architecture management platform that bundles a high-performance Go backend with an interactive React frontend into a single, standalone binary. Includes a CLI to enable managing diagrams from the shell or in CI. +`tld` provides a complete software architecture management platform that bundles a high-performance Go backend with an interactive React frontend into a single, standalone binary. 
Includes a CLI to enable managing diagrams from the shell or in CI. Designed for local-first development and private self-hosting, `tld` allows teams to visualize, document, and manage their system architecture using a combination of a rich web UI and "Diagrams as Code" workflows. @@ -11,6 +11,7 @@ Designed for local-first development and private self-hosting, `tld` allows team ## Key Features - **Full-Featured Web UI**: A React frontend designed, polished and optimized to handle complex architectures while attempting to intelligently show and hide details. +- **Git diff visualization**: Seamlessly sync and visualize the changes you or your agent are making live in diagram form. Inspect the dependencies and intervene when necessary. - **Bi-directional Sync**: Seamlessly sync changes between your local YAML files, the self-hosted web UI, and the cloud version at tlDiagram.com. - **Standalone Distribution**: A single, dependency-free binary containing both the server and the web application. - **CLI built that speaks agent**: Use the [agent skill](./skills/create-diagram/SKILL.md) and teach your agent how to create a diagram of your codebase with the exact detail level you need. You can prompt your agent to add/remove details as needed. @@ -29,20 +30,6 @@ Here are some examples that were generated using the agent skill. --- -## Table of Contents - -1. [Quick Start](#quick-start) -2. [Deployment & Self-Hosting](#deployment--self-hosting) -3. [The tlDiagram Workflow](#the-tldiagram-workflow) -4. [Tech Stack](#tech-stack) -5. [Development Setup](#development-setup) -6. [Commands Reference](#commands-reference) -7. [Workspace Structure](#workspace-structure) -8. [Environment Variables](#environment-variables) -9. [Troubleshooting](#troubleshooting) - ---- - ## Quick Start ### Single line install and start ```bash @@ -86,6 +73,9 @@ Run `tld serve` in any directory to start a local instance that uses your curren 1. Provide a persistent volume for the `.tld/` directory (where YAMLs and the SQLite cache are stored). 2. Set `TLD_ADDR=0.0.0.0` and `PORT=8060`. 
+### Configuration +Various configuration options are available in `~/.config/tldiagram/tld.yaml` + --- ## The tlDiagram Workflow @@ -98,8 +88,6 @@ Run `tld serve` in any directory to start a local instance that uses your curren --- -## Tech Stack - - **Backend**: Go 1.26+ - *CLI*: Cobra - *API*: Connect RPC (gRPC compatible) @@ -178,7 +166,7 @@ Flags: --format string output format: text or json (default "text") -h, --help help for tld -v, --version version for tld - -w, --workspace string workspace directory (default "tld") + -w, --workspace string workspace directory (prefers .tld, then tld; empty when neither exists) Use "tld [command] --help" for more information about a command diff --git a/build-assets/icons.tar.gz b/build-assets/icons.tar.gz index d4334c8..f0b8ffc 100644 Binary files a/build-assets/icons.tar.gz and b/build-assets/icons.tar.gz differ diff --git a/cmd/add/add.go b/cmd/add/add.go index 22834cc..3148a83 100644 --- a/cmd/add/add.go +++ b/cmd/add/add.go @@ -5,6 +5,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/completion" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -71,8 +72,8 @@ func NewAddCmd(wdir, format *string, compact *bool) *cobra.Command { if cmdutil.WantsJSON(*format) { return cmdutil.WriteMutation(cmd.OutOrStdout(), *compact, "add", "add", r) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Updated elements.yaml (upserted %s)\n", r) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally in elements.yaml. Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Added element %s to elements.yaml", r) + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") return nil }, } diff --git a/cmd/analyze/analyze.go b/cmd/analyze/analyze.go index e712fef..a8c7d4d 100644 --- a/cmd/analyze/analyze.go +++ b/cmd/analyze/analyze.go @@ -1,47 +1,44 @@ package analyze import ( + "encoding/json" "fmt" "io" "os" "path/filepath" - "sort" "strings" + "sync" "time" - "github.com/mertcikla/tld/internal/analyzer" - "github.com/mertcikla/tld/internal/cmdutil" - "github.com/mertcikla/tld/internal/git" - "github.com/mertcikla/tld/internal/ignore" + assets "github.com/mertcikla/tld" + "github.com/mertcikla/tld/internal/localserver" + "github.com/mertcikla/tld/internal/store" "github.com/mertcikla/tld/internal/term" + watchpkg "github.com/mertcikla/tld/internal/watch" + "github.com/mertcikla/tld/internal/watch/exportyaml" "github.com/mertcikla/tld/internal/workspace" "github.com/schollz/progressbar/v3" "github.com/spf13/cobra" - "gopkg.in/yaml.v3" ) -var analyzeSpinnerFrames = []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} - -var analyzeService analyzer.Service = analyzer.DefaultService() - func NewAnalyzeCmd(wdir *string) *cobra.Command { - var deep bool var dryRun bool - var changedSince string + var dataDirFlag string + var embeddingProvider, embeddingEndpoint, embeddingModel string + var embeddingDimension int + var languageFlags []string + var rescan, failOnDrift bool + var maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup int c := &cobra.Command{ Use: "analyze ", - Short: "Extract symbols from source files and upsert them as workspace elements", - Long: `Walks the given path, extracts code symbols (functions, classes, types) using -tree-sitter grammar modules, and upserts each symbol as an Element in elements.yaml. 
-References and imports found between files, folders, and symbols are upserted as Connectors in connectors.yaml. - -By default only the given path is scanned. Use --deep to scan the entire git repo -for cross-file call references.`, + Short: "Scan and materialize a source repository into workspace YAML", + Long: `Scans the git repository containing the given path through the watch pipeline, +materializes the canonical SQLite code representation, and exports generated resources +to elements.yaml and connectors.yaml. Manual YAML resources are preserved.`, Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { scanPath := args[0] - absPath, err := filepath.Abs(scanPath) if err != nil { return fmt.Errorf("resolve path: %w", err) @@ -49,1080 +46,269 @@ for cross-file call references.`, if _, err := os.Stat(absPath); err != nil { return fmt.Errorf("path %q not found: %w", scanPath, err) } - ws, err := workspace.Load(*wdir) if err != nil { return fmt.Errorf("load workspace: %w", err) } - - repoScopes, err := cmdutil.ResolveAnalyzeRepoScopes(ws, absPath) + cfg, err := workspace.LoadGlobalConfig() if err != nil { return err } - - ctx := cmd.Context() - totalElements := 0 - totalConnectors := 0 - incrementalFiles := 0 - totalEntries := 0 - knownElements := buildAnalyzeElementIndex(ws) - usedNames := buildAnalyzeElementNameOwners(ws) - modeLabel := "shallow" - if deep { - modeLabel = "deep" - } - linePrefix := "" - if dryRun { - linePrefix = "[dry-run] " - } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%sAnalyzing %s (%s)...\n", linePrefix, scanPath, modeLabel) - workspaceRoot, _ := filepath.Abs(ws.Dir) - workspaceScan := cmdutil.SamePath(absPath, workspaceRoot) - scanConfiguredRepositories := workspaceScan && ws.WorkspaceConfig != nil && len(ws.WorkspaceConfig.Repositories) > 0 - countTasks := 0 - for _, repoCtx := range repoScopes { - countTasks++ - if deep && repoCtx.Active() && !cmdutil.SamePath(absPath, workspaceRoot) { - countTasks++ - } + dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) + if err != nil { + return err } - countProgress := newAnalyzeProgressBar(cmd.ErrOrStderr(), countTasks) - if countProgress != nil { - defer func() { - if !countProgress.IsFinished() { - _ = countProgress.Clear() - } - }() - countProgress.Describe(fmt.Sprintf("%s Counting scan plan", analyzeSpinnerFrames[0])) + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return fmt.Errorf("create data dir: %w", err) } - for _, repoCtx := range repoScopes { - rules := ws.IgnoreRulesForRepository(repoCtx.Name) - scanRoot := absPath - if scanConfiguredRepositories { - scanRoot = repoCtx.Root - } - entries, err := countAnalyzeEntries(scanRoot, rules) + embeddingCfg := resolveAnalyzeEmbeddingConfig(cfg, embeddingProvider, embeddingEndpoint, embeddingModel, embeddingDimension) + settings := resolveAnalyzeWatchSettings(cfg, languageFlags, maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup) + if embeddingCfg.Provider != "none" { + checked, _, err := watchpkg.CheckEmbeddingHealth(cmd.Context(), embeddingCfg) if err != nil { - return fmt.Errorf("count entries: %w", err) - } - totalEntries += entries - if countProgress != nil { - countProgress.Describe(fmt.Sprintf("%s Counting scan plan for %s", analyzeSpinnerFrames[entries%len(analyzeSpinnerFrames)], repoCtx.Name)) - _ = countProgress.Add(1) - } - if deep && repoCtx.Active() && !cmdutil.SamePath(absPath, workspaceRoot) { - deepEntries, err := countAnalyzeEntries(repoCtx.Root, rules) - if err != nil { - return fmt.Errorf("count 
deep entries: %w", err) - } - totalEntries += deepEntries - if countProgress != nil { - countProgress.Describe(fmt.Sprintf("%s Counting deep scan for %s", analyzeSpinnerFrames[deepEntries%len(analyzeSpinnerFrames)], repoCtx.Name)) - _ = countProgress.Add(1) - } + return fmt.Errorf("embedding healthcheck failed: %w", err) } + embeddingCfg = checked } - if countProgress != nil { - _ = countProgress.Finish() + sqliteStore, err := store.Open(localserver.DatabasePath(dataDir), assets.FS) + if err != nil { + return err } - - progress := newAnalyzeProgressBar(cmd.ErrOrStderr(), totalEntries) - if progress != nil { - defer func() { - if !progress.IsFinished() { - _ = progress.Clear() - } - }() + defer func() { _ = sqliteStore.DB().Close() }() + watchStore := watchpkg.NewStore(sqliteStore.DB()) + progress := newAnalyzeWatchProgress(cmd.ErrOrStderr()) + linePrefix := "" + if dryRun { + linePrefix = "[dry-run] " } - processedEntries := 0 - - for i, repoCtx := range repoScopes { - if progress != nil { - progress.Describe(fmt.Sprintf("%s Scanning %s (%d/%d)", analyzeSpinnerFrames[processedEntries%len(analyzeSpinnerFrames)], repoCtx.Name, i+1, len(repoScopes))) - } - rules := ws.IgnoreRulesForRepository(repoCtx.Name) - scanRoot := absPath - if scanConfiguredRepositories { - scanRoot = repoCtx.Root - } - - var repoURL, branch string - if repoCtx.Active() { - if url, err := git.DetectRemoteURL(repoCtx.Root); err == nil { - repoURL = url - } - if b, err := git.DetectBranch(repoCtx.Root); err == nil { - branch = b - } - } - - scanResult, err := analyzeService.ExtractPath(ctx, scanRoot, rules, func(path string, isDir bool) { - processedEntries++ - if progress == nil { - return - } - spinner := analyzeSpinnerFrames[processedEntries%len(analyzeSpinnerFrames)] - progress.Describe(fmt.Sprintf("%s Scanning %s (%d/%d)", spinner, repoCtx.Name, processedEntries, totalEntries)) - _ = progress.Add(1) - }) - if err != nil { - return fmt.Errorf("extract symbols: %w", err) - } - - changedFileSet := map[string]struct{}{} - if changedSince != "" && repoCtx.Active() { - changed, err := git.FilesChangedSince(repoCtx.Root, changedSince) - if err != nil { - return fmt.Errorf("git changed-since: %w", err) - } - incrementalFiles += len(changed) - for _, file := range changed { - changedFileSet[filepath.Clean(file)] = struct{}{} - } - } - - if deep && repoCtx.Active() && !workspaceScan { - deepResult, err := analyzeService.ExtractPath(ctx, repoCtx.Root, rules, func(path string, isDir bool) { - processedEntries++ - if progress == nil { - return - } - spinner := analyzeSpinnerFrames[processedEntries%len(analyzeSpinnerFrames)] - progress.Describe(fmt.Sprintf("%s Scanning %s (%d/%d)", spinner, repoCtx.Name, processedEntries, totalEntries)) - _ = progress.Add(1) - }) - if err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: deep scan failed: %v\n", err) - } else { - scanResult.Refs = append(scanResult.Refs, deepResult.Refs...) 
- } - } - - filtered := filterSymbols(scanResult.Symbols, rules) - if len(changedFileSet) > 0 { - filtered = filterSymbolsByFiles(filtered, changedFileSet) - scanResult.Refs = filterRefsByFiles(scanResult.Refs, changedFileSet) - } - - if len(filtered) == 0 { - continue - } - - if progress != nil && !dryRun { - _ = progress.Finish() - } - - elementRoot := analyzeElementRoot(scanRoot, repoCtx.Root, repoCtx.Active()) - filePaths := uniqueFilePaths(filtered, elementRoot) - folderPaths := uniqueFolderPaths(filePaths) - plannedElementWrites := 1 + len(folderPaths) + len(filePaths) + len(filtered) - writeProgress := newAnalyzeProgressBar(cmd.ErrOrStderr(), plannedElementWrites) - elementWriteAttempts := 0 - - usedRefs := make(map[string]struct{}, len(ws.Elements)) - for ref := range ws.Elements { - usedRefs[ref] = struct{}{} - } - - repoName := filepath.Base(repoCtx.Root) - repoRef, err := ensureAnalyzeElement(*wdir, dryRun, ws, knownElements, usedRefs, usedNames, analyzeElementSpec{ - Name: repoName, - Kind: "repository", - Owner: repoCtx.Name, - Repo: repoURL, - Branch: branch, - HasView: len(filtered) > 0, - Technology: "Git Repository", - ViewLabel: repoName, - ParentRef: "root", - Identity: analyzeElementIdentity{ - Repo: repoURL, - Branch: branch, - FilePath: "", - Symbol: "", - Kind: "repository", - Name: repoName, - }, - }) - if err != nil { - return fmt.Errorf("ensure repository element: %w", err) - } - if writeProgress != nil { - elementWriteAttempts++ - advanceAnalyzeWriteProgress(writeProgress, "elements.yaml", elementWriteAttempts, plannedElementWrites) - } - - folderRefs := make(map[string]string) - fileRefs := make(map[string]string) - symbolRefs := make(map[analyzeElementLookupKey]string) - symbolRefsByName := make(map[string][]string) - symbolFiles := make(map[string]string) - repoElements := 1 - - for _, relPath := range folderPaths { - folderName := filepath.Base(relPath) - parentRef := repoRef - if parentPath := filepath.Dir(relPath); parentPath != "." { - if existingParentRef := folderRefs[parentPath]; existingParentRef != "" { - parentRef = existingParentRef - } - } - - folderRef, err := ensureAnalyzeElement(*wdir, dryRun, ws, knownElements, usedRefs, usedNames, analyzeElementSpec{ - Name: folderName, - Kind: "folder", - Owner: repoCtx.Name, - Repo: repoURL, - Branch: branch, - FilePath: relPath, - Technology: "Folder", - ParentRef: parentRef, - Identity: analyzeElementIdentity{ - Repo: repoURL, - Branch: branch, - FilePath: relPath, - Kind: "folder", - Name: folderName, - }, - }) - if err != nil { - return fmt.Errorf("ensure folder element %q: %w", relPath, err) - } - if writeProgress != nil { - elementWriteAttempts++ - advanceAnalyzeWriteProgress(writeProgress, "elements.yaml", elementWriteAttempts, plannedElementWrites) - } - folderRefs[relPath] = folderRef - repoElements++ - } - - for _, relPath := range filePaths { - fileName := filepath.Base(relPath) - parentRef := repoRef - if parentPath := filepath.Dir(relPath); parentPath != "." 
{ - if folderRef := folderRefs[parentPath]; folderRef != "" { - parentRef = folderRef - } - } - fileTech := "" - if lang, ok := analyzer.DetectLanguage(relPath); ok { - fileTech = string(lang) - } - hasFileSymbols := false - for _, sym := range filtered { - if filepath.Clean(sym.FilePath) == filepath.Clean(relPath) { - hasFileSymbols = true - break - } - } - fileRef, err := ensureAnalyzeElement(*wdir, dryRun, ws, knownElements, usedRefs, usedNames, analyzeElementSpec{ - Name: fileName, - Kind: "file", - Owner: repoCtx.Name, - Repo: repoURL, - Branch: branch, - FilePath: relPath, - HasView: hasFileSymbols, - ViewLabel: fileName, - Technology: fileTech, - ParentRef: parentRef, - Identity: analyzeElementIdentity{ - Repo: repoURL, - Branch: branch, - FilePath: relPath, - Symbol: "", - Kind: "file", - Name: fileName, - }, - }) - if err != nil { - return fmt.Errorf("ensure file element %q: %w", relPath, err) - } - if writeProgress != nil { - elementWriteAttempts++ - advanceAnalyzeWriteProgress(writeProgress, "elements.yaml", elementWriteAttempts, plannedElementWrites) - } - fileRefs[relPath] = fileRef - repoElements++ - } - - for _, sym := range filtered { - relPath := analyzeRelativeFilePath(sym.FilePath, elementRoot) - fileRef := fileRefs[relPath] - if fileRef == "" { - continue - } - - parentRef := fileRef - if sym.Parent != "" { - if refs := symbolRefsByName[sym.Parent]; len(refs) == 1 { - p := refs[0] - parentRef = p - } - } - - ref, err := ensureAnalyzeElement(*wdir, dryRun, ws, knownElements, usedRefs, usedNames, analyzeElementSpec{ - Name: sym.Name, - Kind: sym.Kind, - Owner: repoCtx.Name, - Repo: repoURL, - Branch: branch, - FilePath: relPath, - Symbol: sym.Name, - ParentName: sym.Parent, - Technology: sym.Technology, - Description: sym.Description, - ParentRef: parentRef, - Identity: analyzeElementIdentity{ - Repo: repoURL, - Branch: branch, - FilePath: relPath, - Symbol: sym.Name, - Kind: sym.Kind, - Name: sym.Name, - }, - }) - if progress != nil && !dryRun { - elementWriteAttempts++ - advanceAnalyzeWriteProgress(progress, "elements.yaml", elementWriteAttempts, plannedElementWrites) - } - if err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: upsert element %q: %v\n", sym.Name, err) - continue - } - symbolRefs[analyzeSymbolLookupKey(sym)] = ref - symbolRefsByName[sym.Name] = append(symbolRefsByName[sym.Name], ref) - symbolFiles[ref] = relPath - repoElements++ - } - - if !dryRun && repoElements > 0 { - if err := workspace.Save(ws); err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: save elements: %v\n", err) - } - } - - resolverRoot := scanRoot - if repoCtx.Active() { - resolverRoot = repoCtx.Root - } - resolver := newAnalyzeLSPResolver(resolverRoot) - modulePath := analyzeModulePath(resolverRoot) - - plannedResolutionSteps := len(scanResult.Refs) - if writeProgress != nil && plannedResolutionSteps > 0 { - writeProgress.AddMax(plannedResolutionSteps) - describeAnalyzeResolutionProgress(writeProgress, repoCtx.DisplayName(), 0, plannedResolutionSteps) - } - resolvedSteps := 0 - plannedConnectors := make([]*workspace.Connector, 0, len(scanResult.Refs)) - for _, ref := range scanResult.Refs { - resolvedSteps++ - if writeProgress != nil && plannedResolutionSteps > 0 { - describeAnalyzeResolutionProgress(writeProgress, repoCtx.DisplayName(), resolvedSteps, plannedResolutionSteps) - } - if ref.Kind != "import" && rules.ShouldIgnoreSymbol(ref.Name) { - if writeProgress != nil && plannedResolutionSteps > 0 { - _ = writeProgress.Add(1) - } - continue - } - 
plannedConnectors = append(plannedConnectors, buildAnalyzeConnectorsForRef( - ctx, - resolver, - ref, - ws, - filtered, - symbolRefs, - symbolRefsByName, - fileRefs, - folderRefs, - symbolFiles, - repoRef, - elementRoot, - modulePath, - )...) - if writeProgress != nil && plannedResolutionSteps > 0 { - _ = writeProgress.Add(1) - } - } - _ = resolver.Close() - - plannedConnectors = uniqueAnalyzeConnectors(plannedConnectors) - - repoConnectors := 0 - if writeProgress != nil && len(plannedConnectors) > 0 { - writeProgress.AddMax(len(plannedConnectors)) - } - - if !dryRun { - if err := workspace.AppendConnectors(*wdir, plannedConnectors); err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: write connectors: %v\n", err) - } else { - repoConnectors = len(plannedConnectors) - } - if writeProgress != nil { - advanceAnalyzeWriteProgress(writeProgress, "connectors.yaml", len(plannedConnectors), len(plannedConnectors)) - } - } else { - repoConnectors = len(plannedConnectors) - } - - totalElements += repoElements - totalConnectors += repoConnectors + if formatFlag(cmd) != "json" { + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%sAnalyzing %s (watch pipeline)...\n", linePrefix, scanPath) } - if progress != nil { - _ = progress.Finish() + once, err := watchpkg.NewRunner(watchStore).RunOnce(cmd.Context(), watchpkg.OneShotOptions{Path: absPath, Rescan: rescan, Embedding: embeddingCfg, Settings: settings, Progress: progress}) + if err != nil { + return err } - if !dryRun { - if err := ensureAnalyzeRepositoriesRegistered(ws, repoScopes); err != nil { - return fmt.Errorf("register analyzed repositories: %w", err) - } + exported, exportResult, err := exportyaml.ExportWithProgress(cmd.Context(), sqliteStore, watchStore, ws, once.Scan.RepositoryID, progress) + if err != nil { + return fmt.Errorf("export yaml: %w", err) } - - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK %d elements written to elements.yaml\n", linePrefix, totalElements) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK %d connectors written to connectors.yaml\n", linePrefix, totalConnectors) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK %d repositories scanned\n", linePrefix, len(repoScopes)) - if changedSince != "" { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK Incremental scan: %d files changed since %s\n", linePrefix, incrementalFiles, changedSince) + changed := hasAnalyzeDrift(once.Diffs) + if formatFlag(cmd) == "json" { + payload := map[string]any{"changed": changed, "scan": once.Scan, "representation": once.Representation, "export": exportResult, "diffs": once.Diffs} + if err := json.NewEncoder(cmd.OutOrStdout()).Encode(payload); err != nil { + return err + } + } else { + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK %d elements written to elements.yaml\n", linePrefix, exportResult.ElementsWritten) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK %d connectors written to connectors.yaml\n", linePrefix, exportResult.ConnectorsWritten) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s OK 1 repository scanned\n", linePrefix) + _, _ = fmt.Fprintln(cmd.OutOrStdout()) } - _, _ = fmt.Fprintln(cmd.OutOrStdout()) if dryRun { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%sNo files written. Remove --dry-run to apply.\n", linePrefix) - } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Tip: run `tld plan` to preview what will be applied.") + if formatFlag(cmd) != "json" { + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%sNo files written. 
Remove --dry-run to apply.\n", linePrefix) + } + } else if err := workspace.Save(exported); err != nil { + return fmt.Errorf("save workspace: %w", err) + } + if failOnDrift && changed { + return fmt.Errorf("watch representation drift detected") } return nil }, } - c.Flags().BoolVar(&deep, "deep", false, "scan entire git repo for cross-file references (slower)") - c.Flags().BoolVar(&dryRun, "dry-run", false, "print what would be written without modifying workspace") - c.Flags().StringVar(&changedSince, "changed-since", "", "only re-analyse files changed since this git SHA") + c.Flags().BoolVar(&dryRun, "dry-run", false, "scan, materialize, and print drift without modifying workspace YAML") + c.Flags().StringVar(&dataDirFlag, "data-dir", "", "directory for the local app database") + c.Flags().BoolVar(&rescan, "rescan", false, "force reparsing files even if cached") + c.Flags().BoolVar(&failOnDrift, "fail-on-drift", false, "exit nonzero when representation drift is detected") + c.Flags().StringSliceVar(&languageFlags, "language", nil, "source language to scan (repeatable)") + c.Flags().StringVar(&embeddingProvider, "embedding-provider", "", "embedding provider for representation") + c.Flags().StringVar(&embeddingEndpoint, "embedding-endpoint", "", "embedding endpoint for representation") + c.Flags().StringVar(&embeddingModel, "embedding-model", "", "embedding model for representation") + c.Flags().IntVar(&embeddingDimension, "embedding-dimension", 0, "embedding vector dimension") + c.Flags().IntVar(&maxElements, "max-elements-per-view", 0, "maximum generated elements per view") + c.Flags().IntVar(&maxConnectors, "max-connectors-per-view", 0, "maximum generated connectors per view") + c.Flags().IntVar(&maxIncoming, "max-incoming-per-element", 0, "maximum incoming references per element before collapsing") + c.Flags().IntVar(&maxOutgoing, "max-outgoing-per-element", 0, "maximum outgoing references per element before collapsing") + c.Flags().IntVar(&maxExpandedGroup, "max-expanded-connectors-per-group", 0, "maximum file-pair connectors to expand before collapsing to a folder connector") return c } -func ensureAnalyzeRepositoriesRegistered(ws *workspace.Workspace, repoScopes []cmdutil.RepoScope) error { - if ws == nil { - return nil - } - - config := ws.WorkspaceConfig - if config == nil { - config = &workspace.WorkspaceConfig{} - } - if config.Repositories == nil { - config.Repositories = make(map[string]workspace.Repository) - } - - workspaceRoot, err := filepath.Abs(ws.Dir) - if err != nil { - return fmt.Errorf("resolve workspace root: %w", err) - } - - changed := false - for _, repoScope := range repoScopes { - if repoScope.Name == "" { - continue - } - if _, exists := config.Repositories[repoScope.Name]; exists { - continue +func resolveAnalyzeEmbeddingConfig(cfg *workspace.Config, provider, endpoint, model string, dimension int) watchpkg.EmbeddingConfig { + embedding := watchpkg.EmbeddingConfig{Provider: "none"} + if cfg != nil { + embedding = watchpkg.EmbeddingConfig{ + Provider: cfg.Watch.Embedding.Provider, + Endpoint: cfg.Watch.Embedding.Endpoint, + Model: cfg.Watch.Embedding.Model, + Dimension: cfg.Watch.Embedding.Dimension, + HealthThreshold: cfg.Watch.Embedding.HealthThreshold, } - config.Repositories[repoScope.Name] = workspace.Repository{ - URL: repoScope.RemoteURL, - LocalDir: analyzeRepositoryLocalDir(workspaceRoot, repoScope.Root), - } - changed = true - } - if !changed { - return nil - } - - data, err := yaml.Marshal(config) - if err != nil { - return fmt.Errorf("marshal 
workspace config: %w", err) - } - if err := os.WriteFile(workspace.WorkspaceConfigPath(ws.Dir), data, 0600); err != nil { - return fmt.Errorf("write .tld.yaml: %w", err) - } - ws.WorkspaceConfig = config - return nil -} - -func analyzeRepositoryLocalDir(workspaceRoot, repoRoot string) string { - if workspaceRoot == "" || repoRoot == "" { - return repoRoot } - parent := filepath.Dir(workspaceRoot) - if cmdutil.SamePath(repoRoot, parent) { - return "" + if provider != "" { + embedding.Provider = provider } - if rel, err := filepath.Rel(workspaceRoot, repoRoot); err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { - return rel + if endpoint != "" { + embedding.Endpoint = endpoint } - return repoRoot -} - -func newAnalyzeProgressBar(out io.Writer, total int) *progressbar.ProgressBar { - if total <= 0 || !term.IsTerminal(out) { - return nil - } - return progressbar.NewOptions(total, - progressbar.OptionSetWriter(out), - progressbar.OptionSetVisibility(true), - progressbar.OptionSetDescription("⠋ Scanning"), - progressbar.OptionShowCount(), - progressbar.OptionShowIts(), - progressbar.OptionSetWidth(12), - progressbar.OptionFullWidth(), - progressbar.OptionClearOnFinish(), - progressbar.OptionThrottle(60*time.Millisecond), - ) -} - -func advanceAnalyzeWriteProgress(progress *progressbar.ProgressBar, fileName string, completed, total int) { - if progress == nil || total <= 0 { - return + if model != "" { + embedding.Model = model } - spinner := analyzeSpinnerFrames[completed%len(analyzeSpinnerFrames)] - progress.Describe(fmt.Sprintf("%s Writing %s (%d/%d)", spinner, fileName, completed, total)) - _ = progress.Add(1) -} - -func describeAnalyzeResolutionProgress(progress *progressbar.ProgressBar, repoName string, completed, total int) { - if progress == nil || total <= 0 { - return + if dimension > 0 { + embedding.Dimension = dimension } - spinner := analyzeSpinnerFrames[completed%len(analyzeSpinnerFrames)] - progress.Describe(fmt.Sprintf("%s Resolving symbols via LSP in %s (%d/%d)", spinner, repoName, completed, total)) + return watchpkg.NormalizeEmbeddingConfig(embedding) } -func countAnalyzeEntries(path string, rules *ignore.Rules) (int, error) { - info, err := os.Stat(path) - if err != nil { - return 0, err - } - if !info.IsDir() { - if rules.ShouldIgnoreFile(path) { - return 0, nil +func resolveAnalyzeWatchSettings(cfg *workspace.Config, languages []string, maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup int) watchpkg.Settings { + settings := watchpkg.DefaultSettings() + if cfg != nil { + settings.Languages = cfg.Watch.Languages + settings.Watcher = cfg.Watch.Watcher + settings.PollInterval = parseAnalyzeDurationOrZero(cfg.Watch.PollInterval) + settings.Debounce = parseAnalyzeDurationOrZero(cfg.Watch.Debounce) + settings.Thresholds = watchpkg.Thresholds{ + MaxElementsPerView: cfg.Watch.Thresholds.MaxElementsPerView, + MaxConnectorsPerView: cfg.Watch.Thresholds.MaxConnectorsPerView, + MaxIncomingPerElement: cfg.Watch.Thresholds.MaxIncomingPerElement, + MaxOutgoingPerElement: cfg.Watch.Thresholds.MaxOutgoingPerElement, + MaxExpandedConnectorsPerGroup: cfg.Watch.Thresholds.MaxExpandedConnectorsPerGroup, } - return 1, nil - } - - count := 0 - err = filepath.WalkDir(path, func(currentPath string, d os.DirEntry, walkErr error) error { - if walkErr != nil { - return walkErr + settings.Visibility = watchpkg.VisibilityConfig{ + CoreThresholdEnabled: cfg.Watch.Visibility.CoreThresholdEnabled, + CoreThreshold: cfg.Watch.Visibility.CoreThreshold, + 
TierMultiplier: cfg.Watch.Visibility.TierMultiplier, + MaxExpansionMultiplier: cfg.Watch.Visibility.MaxExpansionMultiplier, + CoreThresholdSet: true, + WeightsSet: true, + Weights: watchpkg.VisibilityWeights{ + Changed: cfg.Watch.Visibility.Weights.Changed, + Selected: cfg.Watch.Visibility.Weights.Selected, + UserShow: cfg.Watch.Visibility.Weights.UserShow, + UserHide: cfg.Watch.Visibility.Weights.UserHide, + HighSignalFact: cfg.Watch.Visibility.Weights.HighSignalFact, + RelationshipProximity: cfg.Watch.Visibility.Weights.RelationshipProximity, + DependencyFact: cfg.Watch.Visibility.Weights.DependencyFact, + UtilityNoise: cfg.Watch.Visibility.Weights.UtilityNoise, + HighDegreeNoise: cfg.Watch.Visibility.Weights.HighDegreeNoise, + }, } - if d.IsDir() { - rel, _ := filepath.Rel(path, currentPath) - if rules.ShouldIgnorePath(rel) || rules.ShouldIgnorePath(d.Name()) { - return filepath.SkipDir - } - count++ - return nil - } - if rules.ShouldIgnorePath(currentPath) { - return nil - } - count++ - return nil - }) - if err != nil { - return 0, err } - return count, nil -} - -func filterSymbols(symbols []analyzer.Symbol, rules *ignore.Rules) []analyzer.Symbol { - var out []analyzer.Symbol - for _, s := range symbols { - if rules.ShouldIgnoreSymbol(s.Name) { - continue - } - out = append(out, s) - } - return out -} - -func filterSymbolsByFiles(symbols []analyzer.Symbol, changedFiles map[string]struct{}) []analyzer.Symbol { - var out []analyzer.Symbol - for _, sym := range symbols { - if _, ok := changedFiles[filepath.Clean(sym.FilePath)]; ok { - out = append(out, sym) - } - } - return out -} - -func filterRefsByFiles(refs []analyzer.Ref, changedFiles map[string]struct{}) []analyzer.Ref { - var out []analyzer.Ref - for _, ref := range refs { - if _, ok := changedFiles[filepath.Clean(ref.FilePath)]; ok { - out = append(out, ref) - } - } - return out -} - -func refByFileAndLine(filePath string, line int, refMap map[analyzeElementLookupKey]string, symbols []analyzer.Symbol) string { - symbol, ok := symbolByFileAndLine(filePath, line, symbols) - if !ok { - return "" - } - return refMap[analyzeSymbolLookupKey(symbol)] -} - -type analyzeElementIdentity struct { - Repo string - Branch string - FilePath string - Symbol string - Kind string - Name string -} - -type analyzeElementLookupKey struct { - Branch string - FilePath string - Symbol string - Kind string -} - -type analyzeElementSpec struct { - Name string - Kind string - Owner string - Repo string - Branch string - FilePath string - Symbol string - ParentName string - HasView bool - ViewLabel string - ParentRef string - Technology string - Description string - Identity analyzeElementIdentity -} - -func buildAnalyzeElementIndex(ws *workspace.Workspace) map[analyzeElementLookupKey]string { - index := make(map[analyzeElementLookupKey]string, len(ws.Elements)) - for ref, element := range ws.Elements { - if element == nil { - continue - } - index[analyzeElementLookupKey{ - Branch: element.Branch, - FilePath: normalizeAnalyzePath(element.FilePath), - Symbol: element.Symbol, - Kind: element.Kind, - }] = ref - } - return index -} - -func buildAnalyzeElementNameOwners(ws *workspace.Workspace) map[string]map[string]struct{} { - owners := make(map[string]map[string]struct{}, len(ws.Elements)) - for ref, element := range ws.Elements { - if element == nil || element.Name == "" { - continue - } - if owners[element.Name] == nil { - owners[element.Name] = make(map[string]struct{}) - } - owners[element.Name][ref] = struct{}{} - } - return owners -} - -func 
normalizeAnalyzePath(p string) string { - return filepath.ToSlash(filepath.Clean(p)) -} - -func normalizeAnalyzeElementLookupKey(identity analyzeElementIdentity) analyzeElementLookupKey { - return analyzeElementLookupKey{ - Branch: identity.Branch, - FilePath: normalizeAnalyzePath(identity.FilePath), - Symbol: identity.Symbol, - Kind: identity.Kind, - } -} - -func ensureAnalyzeElement(wdir string, dryRun bool, ws *workspace.Workspace, known map[analyzeElementLookupKey]string, usedRefs map[string]struct{}, usedNames map[string]map[string]struct{}, spec analyzeElementSpec) (string, error) { - identity := normalizeAnalyzeElementLookupKey(spec.Identity) - ref := "" - if knownRef, ok := known[identity]; ok { - ref = knownRef - } else if existingRef, ok := findAnalyzeElementRef(ws, analyzeElementIdentity{ - Branch: identity.Branch, - FilePath: identity.FilePath, - Symbol: identity.Symbol, - Kind: identity.Kind, - }); ok { - ref = existingRef - known[identity] = ref - } else { - ref = uniqueAnalyzeRef(spec.Name, spec.FilePath, usedRefs) - usedRefs[ref] = struct{}{} - known[identity] = ref + if len(languages) > 0 { + settings.Languages = languages } - if ref == "" { - ref = uniqueAnalyzeRef(spec.Name, spec.FilePath, usedRefs) - usedRefs[ref] = struct{}{} - known[identity] = ref + if maxElements > 0 { + settings.Thresholds.MaxElementsPerView = maxElements } - - if ws.Elements != nil { - if existing := ws.Elements[ref]; existing != nil && existing.Name != "" { - releaseAnalyzeElementName(usedNames, existing.Name, ref) - } + if maxConnectors > 0 { + settings.Thresholds.MaxConnectorsPerView = maxConnectors } - spec.Name = uniqueAnalyzeElementName(ref, spec, usedNames) - claimAnalyzeElementName(usedNames, spec.Name, ref) - if dryRun { - return ref, nil + if maxIncoming > 0 { + settings.Thresholds.MaxIncomingPerElement = maxIncoming } - elementSpec := analyzeElementToWorkspaceElement(spec) - if existing := ws.Elements[ref]; existing != nil { - if elementSpec.Description == "" { - elementSpec.Description = existing.Description - } - if elementSpec.Technology == "" { - elementSpec.Technology = existing.Technology - } - if elementSpec.URL == "" { - elementSpec.URL = existing.URL - } + if maxOutgoing > 0 { + settings.Thresholds.MaxOutgoingPerElement = maxOutgoing } - if ws.Elements == nil { - ws.Elements = make(map[string]*workspace.Element) + if maxExpandedGroup > 0 { + settings.Thresholds.MaxExpandedConnectorsPerGroup = maxExpandedGroup } - ws.Elements[ref] = elementSpec - return ref, nil + return watchpkg.NormalizeSettings(settings) } -func analyzeSymbolLookupKey(symbol analyzer.Symbol) analyzeElementLookupKey { - return analyzeElementLookupKey{ - FilePath: filepath.Clean(symbol.FilePath), - Symbol: symbol.Name, - Kind: symbol.Kind, - } -} - -func analyzeElementToWorkspaceElement(spec analyzeElementSpec) *workspace.Element { - return &workspace.Element{ - Name: spec.Name, - Kind: spec.Kind, - Owner: spec.Owner, - Repo: spec.Repo, - Branch: spec.Branch, - FilePath: spec.FilePath, - Symbol: spec.Symbol, - HasView: spec.HasView, - ViewLabel: spec.ViewLabel, - Technology: spec.Technology, - Description: spec.Description, - Placements: []workspace.ViewPlacement{{ - ParentRef: spec.ParentRef, - }}, +func parseAnalyzeDurationOrZero(value string) time.Duration { + parsed, err := time.ParseDuration(strings.TrimSpace(value)) + if err != nil { + return 0 } + return parsed } -func findAnalyzeElementRef(ws *workspace.Workspace, identity analyzeElementIdentity) (string, bool) { - targetPath := 
normalizeAnalyzePath(identity.FilePath) - - // 1. Try exact match including branch - for ref, element := range ws.Elements { - if element == nil { - continue - } - if element.Kind == identity.Kind && - normalizeAnalyzePath(element.FilePath) == targetPath && - element.Symbol == identity.Symbol && - element.Branch == identity.Branch { - return ref, true - } - } - - // 2. Try lenient match excluding branch - for ref, element := range ws.Elements { - if element == nil { - continue - } - if element.Kind == identity.Kind && - normalizeAnalyzePath(element.FilePath) == targetPath && - element.Symbol == identity.Symbol { - return ref, true - } +func formatFlag(cmd *cobra.Command) string { + flag := cmd.Root().Flag("format") + if flag == nil || strings.TrimSpace(flag.Value.String()) == "" { + return "text" } - - return "", false + return strings.ToLower(strings.TrimSpace(flag.Value.String())) } -func uniqueAnalyzeElementName(ref string, spec analyzeElementSpec, usedNames map[string]map[string]struct{}) string { - for _, candidate := range analyzeElementNameCandidates(spec) { - candidate = strings.TrimSpace(candidate) - if candidate == "" { - continue - } - if owners := usedNames[candidate]; len(owners) == 0 || (len(owners) == 1 && containsAnalyzeNameOwner(owners, ref)) { - return candidate - } - } - base := spec.Name - if candidates := analyzeElementNameCandidates(spec); len(candidates) > 0 { - base = candidates[len(candidates)-1] - } - for i := 2; ; i++ { - candidate := fmt.Sprintf("%s (%d)", base, i) - if owners := usedNames[candidate]; len(owners) == 0 || (len(owners) == 1 && containsAnalyzeNameOwner(owners, ref)) { - return candidate +func hasAnalyzeDrift(diffs []watchpkg.RepresentationDiff) bool { + for _, diff := range diffs { + if diff.ChangeType != "initialized" && diff.OwnerType != "repository" { + return true } } + return false } -func analyzeElementNameCandidates(spec analyzeElementSpec) []string { - rawName := strings.TrimSpace(spec.Name) - filePath := analyzeQualifiedElementPath(spec.Owner, spec.FilePath) - qualifiedSymbol := rawName - if spec.ParentName != "" { - qualifiedSymbol = spec.ParentName + "." + rawName - } - - candidates := []string{rawName} - switch { - case spec.Kind == "repository": - if spec.Owner != "" { - candidates = append(candidates, spec.Owner) - } - case spec.Symbol != "": - if qualifiedSymbol != rawName { - candidates = append(candidates, qualifiedSymbol) - } - if spec.FilePath != "" { - candidates = append(candidates, filepath.ToSlash(filepath.Clean(spec.FilePath))+"::"+qualifiedSymbol) - } - if filePath != "" { - candidates = append(candidates, filePath+"::"+qualifiedSymbol) - } - case spec.FilePath != "": - candidates = append(candidates, filepath.ToSlash(filepath.Clean(spec.FilePath))) - if filePath != "" { - candidates = append(candidates, filePath) - } - } - return candidates +type analyzeWatchProgress struct { + out io.Writer + bar *progressbar.ProgressBar + mu sync.Mutex } -func analyzeQualifiedElementPath(owner, path string) string { - cleanPath := filepath.ToSlash(filepath.Clean(path)) - if cleanPath == "" || cleanPath == "." 
{ - return strings.TrimSpace(owner) - } - if owner == "" { - return cleanPath - } - return owner + "/" + cleanPath +func newAnalyzeWatchProgress(out io.Writer) *analyzeWatchProgress { + return &analyzeWatchProgress{out: out} } -func claimAnalyzeElementName(usedNames map[string]map[string]struct{}, name, ref string) { - if name == "" || ref == "" { +func (p *analyzeWatchProgress) Start(label string, total int) { + if p == nil || p.out == nil || total <= 0 || !term.IsTerminal(p.out) { return } - if usedNames[name] == nil { - usedNames[name] = make(map[string]struct{}) + p.mu.Lock() + defer p.mu.Unlock() + if p.bar != nil { + _ = p.bar.Finish() + p.bar = nil } - usedNames[name][ref] = struct{}{} -} - -func releaseAnalyzeElementName(usedNames map[string]map[string]struct{}, name, ref string) { - owners := usedNames[name] - if len(owners) == 0 { - return - } - delete(owners, ref) - if len(owners) == 0 { - delete(usedNames, name) - } -} - -func containsAnalyzeNameOwner(owners map[string]struct{}, ref string) bool { - _, ok := owners[ref] - return ok -} - -func uniqueAnalyzeRef(name, filePath string, used map[string]struct{}) string { - base := workspace.Slugify(name) - if base == "" { - base = "element" - } - if _, taken := used[base]; !taken { - return base - } - fileBase := strings.TrimSuffix(filepath.Base(filePath), filepath.Ext(filePath)) - candidate := workspace.Slugify(fileBase + "-" + name) - if candidate == "" { - candidate = base - } - if _, taken := used[candidate]; !taken { - return candidate - } - for i := 2; ; i++ { - withSuffix := fmt.Sprintf("%s-%d", candidate, i) - if _, taken := used[withSuffix]; !taken { - return withSuffix - } - } -} - -func uniqueFilePaths(symbols []analyzer.Symbol, root string) []string { - seen := make(map[string]struct{}) - paths := make([]string, 0, len(symbols)) - for _, sym := range symbols { - relPath := analyzeRelativeFilePath(sym.FilePath, root) - if _, ok := seen[relPath]; ok { - continue - } - seen[relPath] = struct{}{} - paths = append(paths, relPath) + p.bar = newAnalyzeProgressBar(p.out, total) + if p.bar != nil { + p.bar.Describe(label) } - return paths } -func uniqueFolderPaths(filePaths []string) []string { - seen := make(map[string]struct{}) - folders := make([]string, 0, len(filePaths)) - for _, filePath := range filePaths { - for dir := filepath.Dir(filePath); dir != "." 
&& dir != string(filepath.Separator); dir = filepath.Dir(dir) { - dir = filepath.Clean(dir) - if _, ok := seen[dir]; ok { - if next := filepath.Dir(dir); next == dir { - break - } - continue - } - seen[dir] = struct{}{} - folders = append(folders, dir) - } - } - sort.Slice(folders, func(i, j int) bool { - leftDepth := strings.Count(filepath.ToSlash(folders[i]), "/") - rightDepth := strings.Count(filepath.ToSlash(folders[j]), "/") - if leftDepth != rightDepth { - return leftDepth < rightDepth - } - return folders[i] < folders[j] - }) - return folders -} - -func analyzeElementRoot(scanRoot, repoRoot string, activeRepo bool) string { - cleanScanRoot := filepath.Clean(scanRoot) - if activeRepo && cmdutil.PathWithin(cleanScanRoot, filepath.Clean(repoRoot)) { - return filepath.Clean(repoRoot) - } - info, err := os.Stat(cleanScanRoot) - if err == nil && !info.IsDir() { - return filepath.Dir(cleanScanRoot) - } - return cleanScanRoot -} - -func analyzeRelativeFilePath(path, root string) string { - cleanPath := filepath.Clean(path) - if root == "" || cleanPath == "" || !filepath.IsAbs(cleanPath) { - return normalizeAnalyzePath(cleanPath) - } - if relPath, ok := analyzePathWithinRoot(root, cleanPath); ok { - return normalizeAnalyzePath(relPath) - } - resolvedRoot, rootErr := filepath.EvalSymlinks(root) - resolvedPath, pathErr := filepath.EvalSymlinks(cleanPath) - if rootErr == nil && pathErr == nil { - if relPath, ok := analyzePathWithinRoot(resolvedRoot, resolvedPath); ok { - return normalizeAnalyzePath(relPath) - } - } - return normalizeAnalyzePath(cleanPath) -} - -func analyzePathWithinRoot(root, path string) (string, bool) { - relPath, err := filepath.Rel(root, path) - if err != nil { - return "", false +func (p *analyzeWatchProgress) Advance(label string) { + if p == nil { + return } - relPath = filepath.Clean(relPath) - if relPath == "." { - return relPath, true + p.mu.Lock() + defer p.mu.Unlock() + if p.bar == nil { + return } - if relPath == ".." 
|| strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) { - return "", false + if label != "" { + p.bar.Describe(label) } - return relPath, true + _ = p.bar.Add(1) } -func uniqueAnalyzeConnectors(connectors []*workspace.Connector) []*workspace.Connector { - seen := make(map[string]*workspace.Connector, len(connectors)) - unique := make([]*workspace.Connector, 0, len(connectors)) - for _, connector := range connectors { - if connector == nil { - continue - } - if merged := analyzeMergeDuplicateConnector(seen, connector); merged { - continue - } - key := workspace.ConnectorKey(connector) - reverseKey := analyzeReverseConnectorKey(connector) - if existing, ok := seen[reverseKey]; ok { - existing.Direction = "both" - continue - } - seen[key] = connector - unique = append(unique, connector) - } - return unique -} - -func analyzeMergeDuplicateConnector(seen map[string]*workspace.Connector, connector *workspace.Connector) bool { - if connector == nil { - return false - } - key := workspace.ConnectorKey(connector) - if existing, ok := seen[key]; ok { - if existing.Relationship == "depends_on" && connector.Relationship == "depends_on" { - existing.Label = analyzeMergeDependencyLabels(existing.Label, connector.Label) - delete(seen, key) - seen[workspace.ConnectorKey(existing)] = existing - } - return true - } - mergeKey := analyzeDependencyMergeKey(connector) - if mergeKey == "" { - return false - } - for existingKey, existing := range seen { - if existing == nil || analyzeDependencyMergeKey(existing) != mergeKey { - continue - } - existing.Label = analyzeMergeDependencyLabels(existing.Label, connector.Label) - delete(seen, existingKey) - seen[workspace.ConnectorKey(existing)] = existing - return true +func (p *analyzeWatchProgress) Finish() { + if p == nil { + return } - return false -} - -func analyzeDependencyMergeKey(connector *workspace.Connector) string { - if connector == nil || connector.Relationship != "depends_on" || connector.Direction != "forward" { - return "" + p.mu.Lock() + defer p.mu.Unlock() + if p.bar == nil { + return } - return connector.View + ":" + connector.Source + ":" + connector.Target + ":" + connector.Relationship + _ = p.bar.Finish() + p.bar = nil } -func analyzeReverseConnectorKey(connector *workspace.Connector) string { - if connector == nil { - return "" +func newAnalyzeProgressBar(out io.Writer, total int) *progressbar.ProgressBar { + if total <= 0 || !term.IsTerminal(out) { + return nil } - return connector.View + ":" + connector.Target + ":" + connector.Source + ":" + connector.Label + return progressbar.NewOptions(total, + progressbar.OptionSetWriter(out), + progressbar.OptionSetVisibility(true), + progressbar.OptionSetDescription("Scanning"), + progressbar.OptionShowCount(), + progressbar.OptionShowIts(), + progressbar.OptionSetWidth(12), + progressbar.OptionFullWidth(), + progressbar.OptionClearOnFinish(), + progressbar.OptionUseANSICodes(true), + progressbar.OptionThrottle(60*time.Millisecond), + ) } diff --git a/cmd/analyze/analyze_connectors.go b/cmd/analyze/analyze_connectors.go deleted file mode 100644 index aa5810b..0000000 --- a/cmd/analyze/analyze_connectors.go +++ /dev/null @@ -1,357 +0,0 @@ -package analyze - -import ( - "context" - "os" - "path/filepath" - "strings" - - "github.com/mertcikla/tld/internal/analyzer" - "github.com/mertcikla/tld/internal/workspace" -) - -func buildAnalyzeConnectorsForRef( - ctx context.Context, - resolver analyzeDefinitionResolver, - ref analyzer.Ref, - ws *workspace.Workspace, - symbols []analyzer.Symbol, - 
-	symbolRefs map[analyzeElementLookupKey]string,
-	symbolRefsByName map[string][]string,
-	fileRefs map[string]string,
-	folderRefs map[string]string,
-	symbolFiles map[string]string,
-	repoRef string,
-	elementRoot string,
-	modulePath string,
-) []*workspace.Connector {
-	kind := strings.TrimSpace(ref.Kind)
-	if kind == "import" {
-		return buildAnalyzeImportConnectors(ref, ws, fileRefs, folderRefs, repoRef, elementRoot, modulePath)
-	}
-	return buildAnalyzeReferenceConnectors(ctx, resolver, ref, ws, symbols, symbolRefs, symbolRefsByName, fileRefs, folderRefs, symbolFiles, repoRef)
-}
-
-func buildAnalyzeReferenceConnectors(
-	ctx context.Context,
-	resolver analyzeDefinitionResolver,
-	ref analyzer.Ref,
-	ws *workspace.Workspace,
-	symbols []analyzer.Symbol,
-	symbolRefs map[analyzeElementLookupKey]string,
-	symbolRefsByName map[string][]string,
-	fileRefs map[string]string,
-	folderRefs map[string]string,
-	symbolFiles map[string]string,
-	repoRef string,
-) []*workspace.Connector {
-	toRef := resolveAnalyzeTargetRef(ctx, resolver, ref, symbols, symbolRefs, symbolRefsByName)
-	if toRef == "" {
-		return nil
-	}
-
-	fromRef := refByFileAndLine(ref.FilePath, ref.Line, symbolRefs, symbols)
-	if fromRef == "" || fromRef == toRef {
-		return nil
-	}
-
-	connectors := []*workspace.Connector{{
-		View:         analyzeCommonConnectorView(ws, fromRef, toRef, repoRef),
-		Source:       fromRef,
-		Target:       toRef,
-		Label:        "calls",
-		Relationship: "uses",
-		Direction:    "forward",
-	}}
-
-	sourceFile := symbolFiles[fromRef]
-	targetFile := symbolFiles[toRef]
-	if sourceFile == "" || targetFile == "" || sourceFile == targetFile {
-		return connectors
-	}
-
-	sourceFileRef := fileRefs[sourceFile]
-	targetFileRef := fileRefs[targetFile]
-	if sourceFileRef != "" && targetFileRef != "" && sourceFileRef != targetFileRef {
-		connectors = append(connectors, &workspace.Connector{
-			View:         analyzeCommonConnectorView(ws, sourceFileRef, targetFileRef, repoRef),
-			Source:       sourceFileRef,
-			Target:       targetFileRef,
-			Label:        analyzeDependencyLabelReference,
-			Relationship: "depends_on",
-			Direction:    "forward",
-		})
-	}
-
-	sourceFolderRef := analyzeFolderRefForFile(sourceFile, folderRefs, repoRef)
-	targetFolderRef := analyzeFolderRefForFile(targetFile, folderRefs, repoRef)
-	if sourceFolderRef != "" && targetFolderRef != "" && sourceFolderRef != targetFolderRef {
-		connectors = append(connectors, &workspace.Connector{
-			View:         analyzeCommonConnectorView(ws, sourceFolderRef, targetFolderRef, repoRef),
-			Source:       sourceFolderRef,
-			Target:       targetFolderRef,
-			Label:        analyzeDependencyLabelReference,
-			Relationship: "depends_on",
-			Direction:    "forward",
		})
-	}
-
-	return connectors
-}
-
-func buildAnalyzeImportConnectors(
-	ref analyzer.Ref,
-	ws *workspace.Workspace,
-	fileRefs map[string]string,
-	folderRefs map[string]string,
-	repoRef string,
-	elementRoot string,
-	modulePath string,
-) []*workspace.Connector {
-	targetDir := analyzeRepoRelativeImportDir(ref.TargetPath, modulePath)
-	if targetDir == "" && strings.HasSuffix(ref.FilePath, ".py") {
-		targetDir = resolvePythonImportDir(ref.FilePath, ref.TargetPath, elementRoot)
-	}
-
-	if targetDir == "" {
-		return nil
-	}
-
-	sourceFile := analyzeRelativeFilePath(ref.FilePath, elementRoot)
-	sourceFileRef := fileRefs[sourceFile]
-	targetFolderRef := analyzeFolderRefForDir(targetDir, folderRefs, repoRef)
-	if sourceFileRef == "" || targetFolderRef == "" || sourceFileRef == targetFolderRef {
-		return nil
-	}
-
-	connectors := []*workspace.Connector{{
-		View:         analyzeCommonConnectorView(ws, sourceFileRef, targetFolderRef, repoRef),
-		Source:       sourceFileRef,
-		Target:       targetFolderRef,
-		Label:        analyzeDependencyLabelImport,
-		Relationship: "depends_on",
-		Direction:    "forward",
-	}}
-
-	sourceFolderRef := analyzeFolderRefForFile(sourceFile, folderRefs, repoRef)
-	if sourceFolderRef != "" && sourceFolderRef != targetFolderRef {
-		connectors = append(connectors, &workspace.Connector{
-			View:         analyzeCommonConnectorView(ws, sourceFolderRef, targetFolderRef, repoRef),
-			Source:       sourceFolderRef,
-			Target:       targetFolderRef,
-			Label:        analyzeDependencyLabelImport,
-			Relationship: "depends_on",
-			Direction:    "forward",
-		})
-	}
-
-	return connectors
-}
-
-func resolvePythonImportDir(filePath, targetPath, elementRoot string) string {
-	if strings.HasPrefix(targetPath, ".") {
-		return resolvePythonRelativeImport(filePath, targetPath, elementRoot)
-	}
-	return resolvePythonAbsoluteImport(targetPath, elementRoot)
-}
-
-func resolvePythonRelativeImport(filePath, targetPath, elementRoot string) string {
-	dir := filepath.Dir(filePath)
-	dots := 0
-	for strings.HasPrefix(targetPath[dots:], ".") {
-		dots++
-	}
-	for i := 0; i < dots-1; i++ {
-		dir = filepath.Dir(dir)
-	}
-	importPath := strings.ReplaceAll(targetPath[dots:], ".", "/")
-	fullPath := filepath.Join(dir, importPath)
-	rel, err := filepath.Rel(elementRoot, fullPath)
-	if err != nil {
-		return ""
-	}
-	return filepath.Clean(rel)
-}
-
-func resolvePythonAbsoluteImport(targetPath, elementRoot string) string {
-	importPath := strings.ReplaceAll(targetPath, ".", "/")
-	fullPath := filepath.Join(elementRoot, importPath)
-
-	// Check if it's a directory (package)
-	if info, err := os.Stat(fullPath); err == nil && info.IsDir() {
-		return importPath
-	}
-	// Check if it's a file (module)
-	if info, err := os.Stat(fullPath + ".py"); err == nil && !info.IsDir() {
-		return filepath.Dir(importPath)
-	}
-
-	return ""
-}
-
-func analyzeFolderRefForFile(filePath string, folderRefs map[string]string, repoRef string) string {
-	return analyzeFolderRefForDir(filepath.Dir(filePath), folderRefs, repoRef)
-}
-
-func analyzeFolderRefForDir(dir string, folderRefs map[string]string, repoRef string) string {
-	cleanDir := filepath.Clean(dir)
-	if cleanDir == "." || cleanDir == string(os.PathSeparator) || cleanDir == "" {
-		return repoRef
-	}
-	if ref := folderRefs[cleanDir]; ref != "" {
-		return ref
-	}
-	return ""
-}
-
-func analyzeRepoRelativeImportDir(importPath, modulePath string) string {
-	cleanImportPath := strings.TrimSpace(importPath)
-	cleanModulePath := strings.TrimSpace(modulePath)
-	if cleanImportPath == "" || cleanModulePath == "" {
-		return ""
-	}
-	if cleanImportPath == cleanModulePath {
-		return "."
-	}
-	prefix := cleanModulePath + "/"
-	if !strings.HasPrefix(cleanImportPath, prefix) {
-		return ""
-	}
-	return filepath.Clean(filepath.FromSlash(strings.TrimPrefix(cleanImportPath, prefix)))
-}
-
-func analyzeModulePath(repoRoot string) string {
-	if repoRoot == "" {
-		return ""
-	}
-	data, err := os.ReadFile(filepath.Join(repoRoot, "go.mod"))
-	if err != nil {
-		return ""
-	}
-	for line := range strings.SplitSeq(string(data), "\n") {
-		trimmed := strings.TrimSpace(line)
-		if trimmed == "" || strings.HasPrefix(trimmed, "//") {
-			continue
-		}
-		if !strings.HasPrefix(trimmed, "module ") {
-			continue
-		}
-		fields := strings.Fields(trimmed)
-		if len(fields) < 2 {
-			return ""
-		}
-		return strings.Trim(fields[1], "\"")
-	}
-	return ""
-}
-
-const (
-	analyzeDependencyLabelImport    = "imports"
-	analyzeDependencyLabelReference = "references"
-	analyzeDependencyLabelBoth      = "depends_on"
-)
-
-func analyzeDependencyKindsForLabel(label string) (hasImport bool, hasReference bool) {
-	switch strings.TrimSpace(label) {
-	case analyzeDependencyLabelImport:
-		return true, false
-	case analyzeDependencyLabelReference:
-		return false, true
-	case analyzeDependencyLabelBoth:
-		return true, true
-	default:
-		return false, false
-	}
-}
-
-func analyzeMergeDependencyLabels(left, right string) string {
-	leftImport, leftReference := analyzeDependencyKindsForLabel(left)
-	rightImport, rightReference := analyzeDependencyKindsForLabel(right)
-	hasImport := leftImport || rightImport
-	hasReference := leftReference || rightReference
-	switch {
-	case hasImport && hasReference:
-		return analyzeDependencyLabelBoth
-	case hasImport:
-		return analyzeDependencyLabelImport
-	case hasReference:
-		return analyzeDependencyLabelReference
-	default:
-		return left
-	}
-}
-
-func analyzeCommonConnectorView(ws *workspace.Workspace, fromRef, toRef, fallback string) string {
-	if fallback == "" {
-		fallback = "root"
-	}
-	if ws == nil {
-		return fallback
-	}
-	fromAncestors := analyzeAncestorDepths(ws, fromRef)
-	toAncestors := analyzeAncestorDepths(ws, toRef)
-	bestRef := fallback
-	bestScore := int(^uint(0) >> 1)
-	for ref, fromDepth := range fromAncestors {
-		toDepth, ok := toAncestors[ref]
-		if !ok {
-			continue
-		}
-		score := fromDepth + toDepth
-		if score < bestScore || (score == bestScore && bestRef == "root" && ref != "root") {
-			bestRef = ref
-			bestScore = score
-		}
-	}
-	return bestRef
-}
-
-func analyzeAncestorDepths(ws *workspace.Workspace, ref string) map[string]int {
-	depths := map[string]int{"root": 1 << 20}
-	type queueItem struct {
-		ref   string
-		depth int
-	}
-	queue := make([]queueItem, 0, 4)
-	seedParents := []string{"root"}
-	if element := ws.Elements[ref]; element != nil && len(element.Placements) > 0 {
-		seedParents = seedParents[:0]
-		for _, placement := range element.Placements {
-			parentRef := placement.ParentRef
-			if parentRef == "" {
-				parentRef = "root"
-			}
-			seedParents = append(seedParents, parentRef)
-		}
-	}
-	for _, parentRef := range seedParents {
-		queue = append(queue, queueItem{ref: parentRef, depth: 0})
-	}
-	for len(queue) > 0 {
-		current := queue[0]
-		queue = queue[1:]
-		if existingDepth, ok := depths[current.ref]; ok && existingDepth <= current.depth {
-			continue
-		}
-		depths[current.ref] = current.depth
-		if current.ref == "root" {
-			continue
-		}
-		element := ws.Elements[current.ref]
-		if element == nil || len(element.Placements) == 0 {
-			queue = append(queue, queueItem{ref: "root", depth: current.depth + 1})
-			continue
-		}
-		for _, placement := range element.Placements {
-			parentRef := placement.ParentRef
-			if parentRef == "" {
-				parentRef = "root"
-			}
-			queue = append(queue, queueItem{ref: parentRef, depth: current.depth + 1})
-		}
-	}
-	if _, ok := depths["root"]; !ok {
-		depths["root"] = 0
-	}
-	return depths
-}
diff --git a/cmd/analyze/analyze_internal_test.go b/cmd/analyze/analyze_internal_test.go
deleted file mode 100644
index 3e6ac72..0000000
--- a/cmd/analyze/analyze_internal_test.go
+++ /dev/null
@@ -1,159 +0,0 @@
-package analyze
-
-import (
-	"context"
-	"path/filepath"
-	"testing"
-
-	"github.com/mertcikla/tld/internal/analyzer"
-	"github.com/mertcikla/tld/internal/workspace"
-)
-
-func TestEnsureAnalyzeElement_ReusesIdentityWithinRun(t *testing.T) {
-	dir := t.TempDir()
-	ws := &workspace.Workspace{
-		Dir:      dir,
-		Elements: map[string]*workspace.Element{},
-	}
-	known := buildAnalyzeElementIndex(ws)
-	usedRefs := map[string]struct{}{}
-	usedNames := buildAnalyzeElementNameOwners(ws)
-	spec := analyzeElementSpec{
-		Name:      "alias_for_table",
-		Kind:      "function",
-		Owner:     "digitaltwin-poc",
-		Repo:      "git@gitlab.btsgrp.com:ridvan.zengin/dtwin.git",
-		Branch:    "main",
-		FilePath:  filepath.Clean("backhaul_analysis/backhaul_analysis.py"),
-		Symbol:    "alias_for_table",
-		ParentRef: "backhaul-analysis-py",
-		Identity: analyzeElementIdentity{
-			Repo:     "git@gitlab.btsgrp.com:ridvan.zengin/dtwin.git",
-			Branch:   "main",
-			FilePath: "backhaul_analysis/backhaul_analysis.py",
-			Symbol:   "alias_for_table",
-			Kind:     "function",
-			Name:     "alias_for_table",
-		},
-	}
-
-	firstRef, err := ensureAnalyzeElement(dir, false, ws, known, usedRefs, usedNames, spec)
-	if err != nil {
-		t.Fatalf("first ensureAnalyzeElement: %v", err)
-	}
-	secondRef, err := ensureAnalyzeElement(dir, false, ws, known, usedRefs, usedNames, spec)
-	if err != nil {
-		t.Fatalf("second ensureAnalyzeElement: %v", err)
-	}
-	if firstRef != secondRef {
-		t.Fatalf("refs differ: first=%q second=%q", firstRef, secondRef)
-	}
-	if len(ws.Elements) != 1 {
-		t.Fatalf("elements = %d, want 1", len(ws.Elements))
-	}
-	if _, ok := ws.Elements[firstRef]; !ok {
-		t.Fatalf("missing expected ref %q in ws.Elements", firstRef)
-	}
-}
-
-func TestEnsureAnalyzeElement_DoesNotReuseNameAcrossFiles(t *testing.T) {
-	dir := t.TempDir()
-	ws := &workspace.Workspace{
-		Dir:      dir,
-		Elements: map[string]*workspace.Element{},
-	}
-	known := buildAnalyzeElementIndex(ws)
-	usedRefs := map[string]struct{}{}
-	usedNames := buildAnalyzeElementNameOwners(ws)
-
-	firstSpec := analyzeElementSpec{
-		Name:      "Load",
-		Kind:      "function",
-		Owner:     "tld",
-		Repo:      "git@github.com:Mertcikla/tld-cli.git",
-		Branch:    "main",
-		FilePath:  "workspace/loader.go",
-		Symbol:    "Load",
-		ParentRef: "loader-go",
-		Identity: analyzeElementIdentity{
-			Repo:     "git@github.com:Mertcikla/tld-cli.git",
-			Branch:   "main",
-			FilePath: "workspace/loader.go",
-			Symbol:   "Load",
-			Kind:     "function",
-			Name:     "Load",
-		},
-	}
-	secondSpec := firstSpec
-	secondSpec.FilePath = "planner/loader.go"
-	secondSpec.ParentRef = "planner-loader-go"
-	secondSpec.Identity.FilePath = "planner/loader.go"
-
-	firstRef, err := ensureAnalyzeElement(dir, false, ws, known, usedRefs, usedNames, firstSpec)
-	if err != nil {
-		t.Fatalf("first ensureAnalyzeElement: %v", err)
-	}
-	secondRef, err := ensureAnalyzeElement(dir, false, ws, known, usedRefs, usedNames, secondSpec)
-	if err != nil {
-		t.Fatalf("second ensureAnalyzeElement: %v", err)
-	}
-	if firstRef == secondRef {
-		t.Fatalf("expected distinct refs for duplicate symbol names, got %q", firstRef)
-	}
-	if len(ws.Elements) != 2 {
-		t.Fatalf("elements = %d, want 2", len(ws.Elements))
-	}
-	if ws.Elements[firstRef].Name == ws.Elements[secondRef].Name {
-		t.Fatalf("expected unique element names, both were %q", ws.Elements[firstRef].Name)
-	}
-}
-
-func TestResolveAnalyzeTargetRef_UsesDefinitionLocation(t *testing.T) {
-	symbols := []analyzer.Symbol{
-		{Name: "Load", Kind: "function", FilePath: "cmd/loader.go", Line: 10, EndLine: 20},
-		{Name: "Load", Kind: "function", FilePath: "workspace/loader.go", Line: 30, EndLine: 40},
-	}
-	refBySymbol := map[analyzeElementLookupKey]string{
-		analyzeSymbolLookupKey(symbols[0]): "cmd-load",
-		analyzeSymbolLookupKey(symbols[1]): "workspace-load",
-	}
-	refsByName := map[string][]string{"Load": {"cmd-load", "workspace-load"}}
-
-	resolved := resolveAnalyzeTargetRef(context.Background(), fakeAnalyzeDefinitionResolver{
-		locations: []analyzeDefinitionLocation{{FilePath: "workspace/loader.go", Line: 30}},
-	}, analyzer.Ref{Name: "Load", FilePath: "cmd/analyze.go", Line: 5, Column: 12}, symbols, refBySymbol, refsByName)
-
-	if resolved != "workspace-load" {
-		t.Fatalf("resolved ref = %q, want workspace-load", resolved)
-	}
-}
-
-func TestResolveAnalyzeTargetRef_DropsAmbiguousFallback(t *testing.T) {
-	symbols := []analyzer.Symbol{
-		{Name: "Load", Kind: "function", FilePath: "cmd/loader.go", Line: 10, EndLine: 20},
-		{Name: "Load", Kind: "function", FilePath: "workspace/loader.go", Line: 30, EndLine: 40},
-	}
-	refBySymbol := map[analyzeElementLookupKey]string{
-		analyzeSymbolLookupKey(symbols[0]): "cmd-load",
-		analyzeSymbolLookupKey(symbols[1]): "workspace-load",
-	}
-	refsByName := map[string][]string{"Load": {"cmd-load", "workspace-load"}}
-
-	resolved := resolveAnalyzeTargetRef(context.Background(), nil, analyzer.Ref{Name: "Load", FilePath: "cmd/analyze.go", Line: 5}, symbols, refBySymbol, refsByName)
-	if resolved != "" {
-		t.Fatalf("resolved ref = %q, want empty", resolved)
-	}
-}
-
-type fakeAnalyzeDefinitionResolver struct {
-	locations []analyzeDefinitionLocation
-	err       error
-}
-
-func (r fakeAnalyzeDefinitionResolver) ResolveDefinitions(context.Context, analyzer.Ref) ([]analyzeDefinitionLocation, error) {
-	return r.locations, r.err
-}
-
-func (r fakeAnalyzeDefinitionResolver) Close() error {
-	return nil
-}
diff --git a/cmd/analyze/analyze_repo_scope_test.go b/cmd/analyze/analyze_repo_scope_test.go
deleted file mode 100644
index 8237210..0000000
--- a/cmd/analyze/analyze_repo_scope_test.go
+++ /dev/null
@@ -1,157 +0,0 @@
-package analyze_test
-
-import (
-	"os"
-	"os/exec"
-	"path/filepath"
-	"strings"
-	"testing"
-
-	"github.com/mertcikla/tld/cmd"
-)
-
-func TestAnalyzeCmd_RejectsRepoOutsideConfiguredRepositories(t *testing.T) {
-	workspaceDir := t.TempDir()
-	cmd.MustInitWorkspace(t, workspaceDir)
-
-	workspaceCfg := strings.Join([]string{
-		"project_name: Demo",
-		"repositories:",
-		"  frontend:",
-		"    url: github.com/example/frontend",
-		"    localDir: frontend",
-		"exclude: []",
-		"",
-	}, "\n")
-	if err := os.WriteFile(filepath.Join(workspaceDir, ".tld.yaml"), []byte(workspaceCfg), 0600); err != nil {
-		t.Fatalf("write workspace config: %v", err)
-	}
-	cmd.InitGitRepo(t, filepath.Join(workspaceDir, "frontend"), "frontend.go", "package frontend\nfunc FrontendService() {}\n")
-
-	repoDir := t.TempDir()
-	git := func(args ...string) {
-		cmd := exec.Command("git", args...)
- cmd.Dir = repoDir - cmd.Env = append(os.Environ(), "GIT_AUTHOR_NAME=Test", "GIT_AUTHOR_EMAIL=test@example.com", "GIT_COMMITTER_NAME=Test", "GIT_COMMITTER_EMAIL=test@example.com") - out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("git %v: %v\n%s", args, err, out) - } - } - - git("init", "-b", "main") - if err := os.WriteFile(filepath.Join(repoDir, "main.go"), []byte("package main\n"), 0600); err != nil { - t.Fatalf("write main.go: %v", err) - } - git("add", "main.go") - git("commit", "-m", "initial") - - stdout, stderr, err := cmd.RunCmd(t, workspaceDir, "analyze", filepath.Join(repoDir, "main.go")) - if err == nil { - t.Fatalf("expected analyze to fail for repo outside configured repositories\nstdout: %s\nstderr: %s", stdout, stderr) - } - if !strings.Contains(err.Error(), "repo") || !strings.Contains(err.Error(), "repository") { - t.Fatalf("unexpected error: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) - } - if strings.Contains(stdout, "Analyzed:") { - t.Fatalf("analyze should not write workspace changes\nstdout: %s\nstderr: %s", stdout, stderr) - } -} - -func TestAnalyzeCmd_DiscoversConfiguredRepositories(t *testing.T) { - workspaceDir := t.TempDir() - cmd.MustInitWorkspace(t, workspaceDir) - - workspaceCfg := strings.Join([]string{ - "project_name: Demo", - "repositories:", - " frontend:", - " url: github.com/example/frontend", - " localDir: frontend", - " backend:", - " url: github.com/example/backend", - " localDir: backend", - "exclude: []", - "", - }, "\n") - if err := os.WriteFile(filepath.Join(workspaceDir, ".tld.yaml"), []byte(workspaceCfg), 0600); err != nil { - t.Fatalf("write workspace config: %v", err) - } - - cmd.InitGitRepo(t, filepath.Join(workspaceDir, "frontend"), "frontend.go", "package frontend\nfunc FrontendService() {}\n") - cmd.InitGitRepo(t, filepath.Join(workspaceDir, "backend"), "backend.go", "package backend\nfunc BackendService() {}\n") - - stdout, stderr, err := cmd.RunCmd(t, workspaceDir, "analyze", workspaceDir, "--dry-run") - if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) - } - if !strings.Contains(stdout, "Analyzing "+workspaceDir+" (shallow)...") { - t.Fatalf("stdout does not show analyze header\nstdout: %s\nstderr: %s", stdout, stderr) - } - if !strings.Contains(stdout, "[dry-run] OK 2 repositories scanned") { - t.Fatalf("stdout does not summarize both repos\nstdout: %s\nstderr: %s", stdout, stderr) - } - if !strings.Contains(stdout, "[dry-run] No files written. 
Remove --dry-run to apply.") { - t.Fatalf("stdout missing dry-run guidance\nstdout: %s\nstderr: %s", stdout, stderr) - } -} - -func TestAnalyzeCmd_ChangedSinceLimitsScan(t *testing.T) { - workspaceDir := t.TempDir() - cmd.MustInitWorkspace(t, workspaceDir) - - workspaceCfg := strings.Join([]string{ - "project_name: Demo", - "repositories:", - " frontend:", - " url: github.com/example/frontend", - " localDir: frontend", - "exclude: []", - "", - }, "\n") - if err := os.WriteFile(filepath.Join(workspaceDir, ".tld.yaml"), []byte(workspaceCfg), 0600); err != nil { - t.Fatalf("write workspace config: %v", err) - } - - repoDir := filepath.Join(workspaceDir, "frontend") - cmd.InitGitRepo(t, repoDir, "frontend.go", "package frontend\nfunc FrontendService() {}\n") - - baseCmd := exec.Command("git", "rev-parse", "HEAD") - baseCmd.Dir = repoDir - base, err := baseCmd.Output() - if err != nil { - t.Fatalf("rev-parse HEAD: %v", err) - } - - if err := os.WriteFile(filepath.Join(repoDir, "frontend.go"), []byte("package frontend\nfunc FrontendService() {}\nfunc NewFrontendService() {}\n"), 0600); err != nil { - t.Fatalf("write frontend.go: %v", err) - } - commit := exec.Command("git", "commit", "-am", "update") - commit.Dir = repoDir - commit.Env = append(os.Environ(), - "GIT_AUTHOR_NAME=Test", - "GIT_AUTHOR_EMAIL=test@example.com", - "GIT_COMMITTER_NAME=Test", - "GIT_COMMITTER_EMAIL=test@example.com", - ) - if out, err := commit.CombinedOutput(); err != nil { - t.Fatalf("git commit: %v\n%s", err, out) - } - - stdout, stderr, err := cmd.RunCmd(t, workspaceDir, "analyze", workspaceDir, "--dry-run", "--changed-since", strings.TrimSpace(string(base))) - if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) - } - if !strings.Contains(stdout, "[dry-run] OK Incremental scan: 1 files changed since ") { - t.Fatalf("stdout missing incremental summary\nstdout: %s\nstderr: %s", stdout, stderr) - } - if !strings.Contains(stdout, "[dry-run] OK 4 elements written to elements.yaml") { - t.Fatalf("stdout missing changed-file element count\nstdout: %s\nstderr: %s", stdout, stderr) - } - if !strings.Contains(stdout, "[dry-run] OK 1 repositories scanned") { - t.Fatalf("stdout missing repository count\nstdout: %s\nstderr: %s", stdout, stderr) - } - if !strings.Contains(stdout, "[dry-run] No files written. 
Remove --dry-run to apply.") { - t.Fatalf("stdout missing dry-run guidance\nstdout: %s\nstderr: %s", stdout, stderr) - } -} diff --git a/cmd/analyze/analyze_test.go b/cmd/analyze/analyze_test.go index 1922be0..92f8b9e 100644 --- a/cmd/analyze/analyze_test.go +++ b/cmd/analyze/analyze_test.go @@ -1,59 +1,24 @@ package analyze_test import ( + "encoding/json" "os" "path/filepath" "strings" "testing" "github.com/mertcikla/tld/cmd" - "github.com/mertcikla/tld/internal/workspace" ) -const ( - testAnalyzeDependencyLabelImport = "imports" - testAnalyzeDependencyLabelReference = "references" - testAnalyzeDependencyLabelBoth = "depends_on" -) - -func TestAnalyzeCmd_DryRun_NoWrite(t *testing.T) { - dir := t.TempDir() - cmd.MustInitWorkspace(t, dir) - file := filepath.Join(dir, "service.go") - if err := os.WriteFile(file, []byte("package main\nfunc Service() {}\n"), 0600); err != nil { - t.Fatal(err) - } - before, err := os.ReadFile(filepath.Join(dir, "elements.yaml")) - if err != nil { - t.Fatal(err) - } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", file, "--dry-run") - if err != nil { - t.Fatalf("analyze --dry-run: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) - } - after, err := os.ReadFile(filepath.Join(dir, "elements.yaml")) - if err != nil { - t.Fatal(err) - } - if string(before) != string(after) { - t.Fatalf("elements.yaml changed during dry-run") - } - if !strings.Contains(stdout, "[dry-run] OK 3 elements written to elements.yaml") { - t.Fatalf("unexpected stdout: %s", stdout) - } -} - -func TestAnalyzeCmd_WritesElements(t *testing.T) { +func TestAnalyzeCmd_WatchPipelineWritesYAML(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - file := filepath.Join(dir, "service.go") - if err := os.WriteFile(file, []byte("package main\nfunc Foo() {}\nfunc Bar() {}\n"), 0600); err != nil { - t.Fatal(err) - } + repoDir := filepath.Join(dir, "app") + cmd.InitGitRepo(t, repoDir, "service.go", "package main\nfunc Foo() {}\nfunc Bar() {}\n") - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", file) + stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } @@ -61,177 +26,161 @@ func TestAnalyzeCmd_WritesElements(t *testing.T) { if err != nil { t.Fatal(err) } - if len(ws.Elements) != 4 { - t.Fatalf("elements = %d, want 4", len(ws.Elements)) + if countKind(ws, "repository") != 1 || countKind(ws, "file") != 1 || countKind(ws, "function") != 2 { + t.Fatalf("unexpected analyzed elements: %+v", ws.Elements) } for ref, element := range ws.Elements { - if element.Symbol == "" { - continue - } - if len(element.Placements) == 0 { + if element.Kind == "function" && len(element.Placements) == 0 { t.Fatalf("symbol %q (%s) has no placement", element.Name, ref) } - if element.Placements[0].ParentRef == "root" { - t.Fatalf("symbol %q (%s) was created at root", element.Name, ref) - } } -} -func TestAnalyzeCmd_CreatesFolderHierarchy(t *testing.T) { - dir := t.TempDir() - cmd.MustInitWorkspace(t, dir) - repoDir := filepath.Join(dir, "app") - if err := os.MkdirAll(filepath.Join(repoDir, "internal", "service"), 0750); err != nil { - t.Fatal(err) - } - cmd.InitGitRepo(t, repoDir, filepath.Join("internal", "service", "service.go"), "package service\nfunc Run() {}\n") - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir) + stdout, stderr, err = cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, 
"--embedding-provider", "none") if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + t.Fatalf("second analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - ws, err := workspace.Load(dir) + ws, err = workspace.Load(dir) if err != nil { t.Fatal(err) } - - var folderRefs []string - var fileRef string - for ref, element := range ws.Elements { - if element.Kind == "folder" { - folderRefs = append(folderRefs, ref) - } - if element.Kind == "file" && element.FilePath == filepath.Join("internal", "service", "service.go") { - fileRef = ref + for _, element := range ws.Elements { + if element.Kind == "file" && strings.HasPrefix(element.FilePath, ".tld/") { + t.Fatalf("generated workspace YAML should not be scanned as source: %+v", element) } } - if len(folderRefs) != 2 { - t.Fatalf("folder elements = %d, want 2: %+v", len(folderRefs), ws.Elements) - } - if fileRef == "" { - t.Fatalf("expected nested file element, got %+v", ws.Elements) - } - fileElement := ws.Elements[fileRef] - if len(fileElement.Placements) == 0 { - t.Fatalf("file element has no placements: %+v", fileElement) - } - parentRef := fileElement.Placements[0].ParentRef - parent := ws.Elements[parentRef] - if parent == nil || parent.Kind != "folder" || parent.FilePath != filepath.Join("internal", "service") { - t.Fatalf("file parent = %q (%+v), want folder internal/service", parentRef, parent) - } - grandparent := ws.Elements[parent.Placements[0].ParentRef] - if grandparent == nil || grandparent.Kind != "folder" || grandparent.FilePath != "internal" { - t.Fatalf("folder parent = %+v, want folder internal", grandparent) - } } -func TestAnalyzeCmd_ReusesExistingElements(t *testing.T) { +func TestAnalyzeCmd_RuntimeArtifactsUseArchitectureView(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - - repoDir := filepath.Join(dir, "backhaul_analysis") - if err := os.MkdirAll(repoDir, 0750); err != nil { - t.Fatal(err) - } - cmd.InitGitRepo(t, repoDir, "backhaul_analysis.py", "from collections import OrderedDict\n\n\ndef get_columns():\n return []\n") - - if err := workspace.UpsertElement(dir, "backhaul-analysis", &workspace.Element{ - Name: "backhaul_analysis", - Kind: "repository", - Branch: "main", - HasView: true, - ViewLabel: "backhaul_analysis", - Placements: []workspace.ViewPlacement{{ParentRef: "root"}}, - }); err != nil { - t.Fatal(err) - } - if err := workspace.UpsertElement(dir, "backhaul-analysis-py", &workspace.Element{ - Name: "backhaul_analysis.py", - Kind: "file", - Branch: "main", - FilePath: "backhaul_analysis.py", - HasView: true, - ViewLabel: "backhaul_analysis.py", - Placements: []workspace.ViewPlacement{{ParentRef: "backhaul-analysis"}}, - }); err != nil { - t.Fatal(err) - } - if err := workspace.UpsertElement(dir, "existing-get-columns", &workspace.Element{ - Name: "get_columns", - Kind: "function", - Branch: "main", - FilePath: "backhaul_analysis.py", - Symbol: "get_columns", - Placements: []workspace.ViewPlacement{{ParentRef: "backhaul-analysis-py"}}, - }); err != nil { - t.Fatal(err) - } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir) + repoDir := filepath.Join(dir, "runtime-app") + cmd.InitGitRepo(t, repoDir, "generated/client.go", "package generated\n\n// Code generated by protoc. 
DO NOT EDIT.\nfunc NoisyStub() {}\n") + writeAnalyzeTestFile(t, repoDir, "deploy/topology.yaml", ` +apiVersion: apps/v1 +kind: Deployment +metadata: + name: alpha +spec: + template: + spec: + containers: + - name: app + image: example/alpha:latest + ports: + - containerPort: 8080 + env: + - name: PEER + value: "beta:9090" + - name: CACHE + value: "cache:6379" +--- +apiVersion: v1 +kind: Service +metadata: + name: alpha +spec: + type: LoadBalancer + ports: + - port: 80 + targetPort: 8080 +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: beta +spec: + template: + spec: + containers: + - name: app + image: example/beta:latest + ports: + - containerPort: 9090 +--- +apiVersion: v1 +kind: Service +metadata: + name: beta +spec: + ports: + - port: 9090 +--- +apiVersion: v1 +kind: Service +metadata: + name: cache +spec: + ports: + - port: 6379 +`) + + stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - ws, err := workspace.Load(dir) if err != nil { t.Fatal(err) } - if _, ok := ws.Elements["existing-get-columns"]; !ok { - t.Fatalf("expected existing symbol ref to be reused, got keys: %v", ws.Elements) + architectureRef := refByElementName(ws, "Architecture") + structuralRef := refByElementName(ws, "Structural") + repositoryRef := refByKind(ws, "repository") + if architectureRef == "" || structuralRef == "" || repositoryRef == "" { + t.Fatalf("missing repository sections: %+v", ws.Elements) } - if len(ws.Elements) != 3 { - t.Fatalf("elements = %d, want 3", len(ws.Elements)) + if !hasPlacementParent(ws, architectureRef, repositoryRef) || !hasPlacementParent(ws, structuralRef, repositoryRef) { + t.Fatalf("architecture and structural sections should be siblings under repository: %+v", ws.Elements) } - element := ws.Elements["existing-get-columns"] - if len(element.Placements) == 0 || element.Placements[0].ParentRef != "backhaul-analysis-py" { - t.Fatalf("reused symbol placement = %+v, want parent backhaul-analysis-py", element.Placements) + for _, name := range []string{"alpha", "beta", "cache", "External traffic"} { + ref := refByElementNameWithParent(ws, name, architectureRef) + if ref == "" { + t.Fatalf("missing architecture element %q in %+v", name, ws.Elements) + } } -} - -func TestAnalyzeCmd_WritesConnectors(t *testing.T) { - dir := t.TempDir() - cmd.MustInitWorkspace(t, dir) - if err := os.WriteFile(filepath.Join(dir, "foo.go"), []byte("package main\nfunc Foo() {}\n"), 0600); err != nil { - t.Fatal(err) + deployRef := refByElementName(ws, "deploy") + if deployRef == "" || !hasPlacementParent(ws, deployRef, structuralRef) { + t.Fatalf("top-level structural folder should be under Structural, ref=%q elements=%+v", deployRef, ws.Elements) } - if err := os.WriteFile(filepath.Join(dir, "bar.go"), []byte("package main\nfunc Bar() { Foo() }\n"), 0600); err != nil { - t.Fatal(err) + if ref := refByElementName(ws, "topology.yaml"); ref == "" || !hasPlacementParent(ws, ref, deployRef) { + t.Fatalf("structural file should be under its folder view, ref=%q elements=%+v", ref, ws.Elements) } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", dir) - if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + alphaRef := refByElementNameWithParent(ws, "alpha", architectureRef) + topologyRef := refByElementName(ws, "topology.yaml") + if alphaRef == "" || topologyRef == "" || !ws.Elements[alphaRef].HasView || 
!hasPlacementParent(ws, topologyRef, alphaRef) { + t.Fatalf("bound architecture component should own a deep-dive view with structural targets: alpha=%q topology=%q %+v", alphaRef, topologyRef, ws.Elements) } - ws, err := workspace.Load(dir) - if err != nil { - t.Fatal(err) + if !connectorByElementNamesInParent(ws, "alpha", "beta", architectureRef) || !connectorByElementNamesInParent(ws, "alpha", "cache", architectureRef) || !connectorByElementNamesInParent(ws, "External traffic", "alpha", architectureRef) { + t.Fatalf("missing expected architecture connectors: %+v", ws.Connectors) } - if len(ws.Connectors) == 0 { - t.Fatalf("expected at least one connector, stdout=%s stderr=%s", stdout, stderr) + for _, connector := range ws.Connectors { + if connector.Description == "" { + t.Fatalf("architecture connector should include provenance/confidence: %+v", connector) + } } } -func TestAnalyzeCmd_AddsCrossFileAndCrossFolderConnectors(t *testing.T) { +func TestAnalyzeCmd_ComposeInferenceDoesNotDependOnFolderNames(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - - repoDir := filepath.Join(dir, "app") - cmd.InitGitRepo(t, repoDir, "go.mod", "module example.com/demo\n\ngo 1.23.0\n") - if err := os.MkdirAll(filepath.Join(repoDir, "cmd", "app"), 0750); err != nil { - t.Fatal(err) - } - if err := os.MkdirAll(filepath.Join(repoDir, "internal", "service"), 0750); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(repoDir, "internal", "service", "service.go"), []byte("package service\n\nfunc Run() {}\n"), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(repoDir, "cmd", "app", "main.go"), []byte("package main\n\nimport \"example.com/demo/internal/service\"\n\nfunc main() {\n\tservice.Run()\n}\n"), 0600); err != nil { - t.Fatal(err) - } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir) + repoDir := filepath.Join(dir, "odd-layout") + cmd.InitGitRepo(t, repoDir, "main.go", "package main\nfunc Main() {}\n") + writeAnalyzeTestFile(t, repoDir, "ops/runtime.yml", ` +services: + worker: + image: example/worker + environment: + TARGET_URL: "http://api:8080" + api: + image: example/api + ports: + - "8080:8080" + datastore: + image: redis:7 +`) + + stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } @@ -239,103 +188,130 @@ func TestAnalyzeCmd_AddsCrossFileAndCrossFolderConnectors(t *testing.T) { if err != nil { t.Fatal(err) } - - mainFileRef := findAnalyzeElementRefByKindAndPath(t, ws, "file", filepath.Join("cmd", "app", "main.go")) - serviceFileRef := findAnalyzeElementRefByKindAndPath(t, ws, "file", filepath.Join("internal", "service", "service.go")) - cmdFolderRef := findAnalyzeElementRefByKindAndPath(t, ws, "folder", filepath.Join("cmd", "app")) - serviceFolderRef := findAnalyzeElementRefByKindAndPath(t, ws, "folder", filepath.Join("internal", "service")) - - assertAnalyzeConnectorExists(t, ws, mainFileRef, serviceFileRef, testAnalyzeDependencyLabelReference) - assertAnalyzeConnectorExists(t, ws, mainFileRef, serviceFolderRef, testAnalyzeDependencyLabelImport) - assertAnalyzeConnectorExists(t, ws, cmdFolderRef, serviceFolderRef, testAnalyzeDependencyLabelBoth) - assertAnalyzeConnectorCount(t, ws, cmdFolderRef, serviceFolderRef, testAnalyzeDependencyLabelBoth, 1) - if len(ws.Connectors) < 4 { - t.Fatalf("expected at least 4 connectors, got %d: %+v", 
len(ws.Connectors), ws.Connectors) + if refByElementName(ws, "worker") == "" || refByElementName(ws, "api") == "" { + t.Fatalf("compose services were not inferred generically: %+v", ws.Elements) } -} - -func TestAnalyzeCmd_MergesReverseConnectorsAsBidirectional(t *testing.T) { - dir := t.TempDir() - cmd.MustInitWorkspace(t, dir) - if err := os.WriteFile(filepath.Join(dir, "foo.go"), []byte("package main\nfunc Foo() { Bar() }\n"), 0600); err != nil { - t.Fatal(err) + if !connectorByElementNames(ws, "worker", "api") { + t.Fatalf("expected env endpoint connector from worker to api, got %+v", ws.Connectors) } - if err := os.WriteFile(filepath.Join(dir, "bar.go"), []byte("package main\nfunc Bar() { Foo() }\n"), 0600); err != nil { - t.Fatal(err) + architectureRef := refByElementName(ws, "Architecture") + structuralRef := refByElementName(ws, "Structural") + if architectureRef == "" || structuralRef == "" { + t.Fatalf("missing repository sections: %+v", ws.Elements) } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", dir) - if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + for _, name := range []string{"worker", "api"} { + ref := refByElementNameWithParent(ws, name, architectureRef) + if ref == "" { + t.Fatalf("architecture element %q should be under Architecture, ref=%q placements=%+v", name, ref, ws.Elements[ref]) + } } - ws, err := workspace.Load(dir) - if err != nil { - t.Fatal(err) + if ref := refByElementName(ws, "Main"); ref == "" { + t.Fatalf("structural symbol should still be materialized: %+v", ws.Elements) + } else if !hasPlacementParent(ws, ref, refByElementName(ws, "main.go")) { + t.Fatalf("symbol should remain nested under its file view: %+v", ws.Elements[ref].Placements) + } + if ref := refByElementName(ws, "main.go"); ref == "" || !hasPlacementParent(ws, ref, structuralRef) { + t.Fatalf("top-level structural file should be under Structural, ref=%q elements=%+v", ref, ws.Elements) } - - fooRef := findAnalyzeElementRefBySymbol(t, ws, "Foo") - barRef := findAnalyzeElementRefBySymbol(t, ws, "Bar") - assertBidirectionalAnalyzeConnector(t, ws, fooRef, barRef, "calls") - assertAnalyzeConnectorCountUnordered(t, ws, fooRef, barRef, "calls", 1) - assertAnalyzeConnectorCountByLabel(t, ws, "calls", 1) - assertAnalyzeConnectorCountByLabel(t, ws, testAnalyzeDependencyLabelReference, 1) } -func TestAnalyzeCmd_DeepModeDoesNotDoubleConnectorCounts(t *testing.T) { +func TestAnalyzeCmd_CrossRepositoryArchitectureLinksReuseStructuralElements(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - if err := os.WriteFile(filepath.Join(dir, "foo.go"), []byte("package main\nfunc Foo() {}\n"), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(dir, "bar.go"), []byte("package main\nfunc Bar() { Foo() }\n"), 0600); err != nil { - t.Fatal(err) - } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", dir, "--dry-run", "--deep") + sourceRepo := filepath.Join(dir, "source-repo") + cmd.InitGitRepo(t, sourceRepo, "modules/cart/main.go", "package main\nfunc ServeCart() {}\n") + runtimeRepo := filepath.Join(dir, "runtime-repo") + cmd.InitGitRepo(t, runtimeRepo, "deploy/cart.yaml", ` +apiVersion: apps/v1 +kind: Deployment +metadata: + name: cart +spec: + template: + spec: + containers: + - name: app + image: example/cart + env: + - name: PEER + value: "cache:6379" +--- +apiVersion: v1 +kind: Service +metadata: + name: cache +spec: + ports: + - port: 6379 +`) + + stdout, stderr, err := cmd.RunCmd(t, dir, 
"analyze", sourceRepo, "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { - t.Fatalf("analyze --deep: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + t.Fatalf("analyze source: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if !strings.Contains(stdout, "[dry-run] OK 2 connectors written to connectors.yaml") { - t.Fatalf("unexpected connector count in deep mode\nstdout: %s\nstderr: %s", stdout, stderr) - } -} - -func TestAnalyzeCmd_WorkspaceRootWithoutConfiguredReposUsesWorkspaceFiles(t *testing.T) { - dir := t.TempDir() - cmd.MustInitWorkspace(t, dir) - if err := os.WriteFile(filepath.Join(dir, "foo.go"), []byte("package main\nfunc Foo() {}\n"), 0600); err != nil { - t.Fatal(err) + stdout, stderr, err = cmd.RunCmd(t, dir, "analyze", runtimeRepo, "--data-dir", dataDir, "--embedding-provider", "none") + if err != nil { + t.Fatalf("analyze runtime: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if err := os.WriteFile(filepath.Join(dir, "bar.go"), []byte("package main\nfunc Bar() { Foo() }\n"), 0600); err != nil { + ws, err := workspace.Load(dir) + if err != nil { t.Fatal(err) } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", dir, "--dry-run") - if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + runtimeRef := refByElementName(ws, "runtime-repo") + architectureRef := refByElementNameWithParent(ws, "Architecture", runtimeRef) + cartArchRef := refByElementNameWithParent(ws, "cart", architectureRef) + cartFolderRef := refByKindAndFilePath(ws, "folder", "modules/cart") + if runtimeRef == "" || architectureRef == "" || cartArchRef == "" || cartFolderRef == "" { + t.Fatalf("missing cross-repo test elements: runtime=%q architecture=%q cartArch=%q cartFolder=%q elements=%+v", runtimeRef, architectureRef, cartArchRef, cartFolderRef, ws.Elements) } - if !strings.Contains(stdout, "[dry-run] OK 5 elements written to elements.yaml") { - t.Fatalf("unexpected element count\nstdout: %s\nstderr: %s", stdout, stderr) + if !ws.Elements[cartArchRef].HasView || !hasPlacementParent(ws, cartFolderRef, cartArchRef) { + t.Fatalf("runtime architecture component should deep-link to source repo structural folder: arch=%+v folder=%+v", ws.Elements[cartArchRef], ws.Elements[cartFolderRef]) } - if !strings.Contains(stdout, "[dry-run] OK 2 connectors written to connectors.yaml") { - t.Fatalf("unexpected connector count\nstdout: %s\nstderr: %s", stdout, stderr) + if placementCount(ws, cartFolderRef) < 2 { + t.Fatalf("source structural folder should remain in its original structural view and be reused in runtime deep-dive view: %+v", ws.Elements[cartFolderRef]) } } -func TestAnalyzeCmd_PythonImports(t *testing.T) { +func TestAnalyzeCmd_PrunesDisconnectedArchitectureComponents(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - - repoDir := filepath.Join(dir, "pyapp") - if err := os.MkdirAll(filepath.Join(repoDir, "myapp", "utils"), 0750); err != nil { - t.Fatal(err) - } - cmd.InitGitRepo(t, repoDir, filepath.Join("myapp", "utils", "helper.py"), "def run():\n pass\n") - if err := os.WriteFile(filepath.Join(repoDir, "myapp", "main.py"), []byte("from .utils import helper\nfrom myapp.utils.helper import run\n\ndef main():\n run()\n"), 0600); err != nil { - t.Fatal(err) - } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir) + repoDir := filepath.Join(dir, "runtime-app") + cmd.InitGitRepo(t, repoDir, "main.go", "package main\nfunc Main() {}\n") + writeAnalyzeTestFile(t, repoDir, 
"deploy/topology.yaml", ` +apiVersion: apps/v1 +kind: Deployment +metadata: + name: connected +spec: + template: + spec: + containers: + - name: app + image: example/connected:latest + env: + - name: PEER + value: "target:9090" +--- +apiVersion: v1 +kind: Service +metadata: + name: target +spec: + ports: + - port: 9090 +--- +apiVersion: v1 +kind: Service +metadata: + name: isolated +spec: + ports: + - port: 9999 +`) + + stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } @@ -343,227 +319,200 @@ func TestAnalyzeCmd_PythonImports(t *testing.T) { if err != nil { t.Fatal(err) } - - mainFileRef := findAnalyzeElementRefByKindAndPath(t, ws, "file", filepath.Join("myapp", "main.py")) - helperFileRef := findAnalyzeElementRefByKindAndPath(t, ws, "file", filepath.Join("myapp", "utils", "helper.py")) - utilsFolderRef := findAnalyzeElementRefByKindAndPath(t, ws, "folder", filepath.Join("myapp", "utils")) - - assertAnalyzeConnectorExists(t, ws, mainFileRef, utilsFolderRef, testAnalyzeDependencyLabelImport) - assertAnalyzeConnectorExists(t, ws, mainFileRef, helperFileRef, testAnalyzeDependencyLabelReference) + architectureRef := refByElementName(ws, "Architecture") + for ref, element := range ws.Elements { + if element.Name == "isolated" && hasPlacementParent(ws, ref, architectureRef) { + t.Fatalf("disconnected architecture element should be pruned from Architecture: %+v", ws.Elements) + } + } + if !connectorByElementNamesInParent(ws, "connected", "target", architectureRef) { + t.Fatalf("expected connected architecture edge, got %+v", ws.Connectors) + } } -func TestAnalyzeCmd_ExcludeRules(t *testing.T) { +func TestAnalyzeCmd_PreservesManualYAML(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - config := "project_name: Demo\nexclude:\n - '*_test.go'\n" - if err := os.WriteFile(filepath.Join(dir, ".tld.yaml"), []byte(config), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(dir, "prod.go"), []byte("package main\nfunc Prod() {}\n"), 0600); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(dir, "prod_test.go"), []byte("package main\nfunc TestOnly() {}\n"), 0600); err != nil { - t.Fatal(err) - } + cmd.MustRunCmd(t, dir, "add", "Manual API", "--ref", "manual-api", "--kind", "service") + repoDir := filepath.Join(dir, "app") + cmd.InitGitRepo(t, repoDir, "main.go", "package main\nfunc Main() {}\n") - if _, _, err := cmd.RunCmd(t, dir, "analyze", dir); err != nil { - t.Fatalf("analyze: %v", err) + if stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, "--embedding-provider", "none"); err != nil { + t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } ws, err := workspace.Load(dir) if err != nil { t.Fatal(err) } - for _, element := range ws.Elements { - if element.Name == "TestOnly" { - t.Fatalf("unexpected test symbol in elements.yaml: %+v", ws.Elements) - } + if ws.Elements["manual-api"] == nil { + t.Fatalf("manual element was not preserved: %+v", ws.Elements) } } -func TestAnalyzeCmd_GeneratedNamesAreGloballyUnique(t *testing.T) { +func TestAnalyzeCmd_DryRunDoesNotWriteYAML(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - repoDir := filepath.Join(dir, "app") - if err := os.MkdirAll(filepath.Join(repoDir, "cmd"), 0750); err != nil { - t.Fatal(err) - } - cmd.InitGitRepo(t, repoDir, 
filepath.Join("cmd", "main.go"), "package main\nfunc main() {}\n") - if err := os.MkdirAll(filepath.Join(repoDir, "tools"), 0750); err != nil { - t.Fatal(err) - } - if err := os.WriteFile(filepath.Join(repoDir, "tools", "main.go"), []byte("package main\nfunc main() {}\n"), 0600); err != nil { + cmd.InitGitRepo(t, repoDir, "service.go", "package main\nfunc Service() {}\n") + before, err := os.ReadFile(filepath.Join(dir, "elements.yaml")) + if err != nil { t.Fatal(err) } - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir) + stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--dry-run", "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { - t.Fatalf("analyze: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + t.Fatalf("analyze --dry-run: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - ws, err := workspace.Load(dir) + after, err := os.ReadFile(filepath.Join(dir, "elements.yaml")) if err != nil { t.Fatal(err) } - for _, validationErr := range ws.ValidateWithOpts(workspace.ValidationOptions{SkipSymbols: true}) { - if strings.Contains(validationErr.Message, "duplicate element name") { - t.Fatalf("unexpected duplicate-name validation error: %v", validationErr) - } + if string(before) != string(after) { + t.Fatal("elements.yaml changed during dry-run") } } -func TestAnalyzeCmd_PathNotExist(t *testing.T) { +func TestAnalyzeCmd_RemovedFlagsFail(t *testing.T) { dir := t.TempDir() cmd.MustInitWorkspace(t, dir) - _, _, err := cmd.RunCmd(t, dir, "analyze", filepath.Join(dir, "missing.go")) - if err == nil { - t.Fatal("expected error for missing path") + if _, _, err := cmd.RunCmd(t, dir, "analyze", dir, "--deep"); err == nil { + t.Fatal("expected --deep to fail") + } + if _, _, err := cmd.RunCmd(t, dir, "analyze", dir, "--changed-since", "HEAD"); err == nil { + t.Fatal("expected --changed-since to fail") } } -func TestAnalyzeCmd_SetsTechnology(t *testing.T) { +func TestAnalyzeCmd_JSONDryRunUsesWatchDiffShape(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - file := filepath.Join(dir, "service.go") - if err := os.WriteFile(file, []byte("package main\nfunc Foo() {}\n"), 0600); err != nil { - t.Fatal(err) - } + repoDir := filepath.Join(dir, "app") + cmd.InitGitRepo(t, repoDir, "main.go", "package main\nfunc Main() {}\n") - if _, _, err := cmd.RunCmd(t, dir, "analyze", file); err != nil { - t.Fatalf("analyze: %v", err) - } - ws, err := workspace.Load(dir) + stdout, stderr, err := cmd.RunCmd(t, dir, "--format", "json", "analyze", repoDir, "--dry-run", "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { - t.Fatal(err) + t.Fatalf("analyze --format json --dry-run: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - - fileRef := findAnalyzeElementRefByKindAndPath(t, ws, "file", "service.go") - fileElement := ws.Elements[fileRef] - if fileElement.Technology != "go" { - t.Fatalf("file technology = %q, want go", fileElement.Technology) + var payload struct { + Changed bool `json:"changed"` + Scan map[string]any `json:"scan"` + Representation map[string]any `json:"representation"` + Export map[string]any `json:"export"` + Diffs []map[string]any `json:"diffs"` } - - fooRef := findAnalyzeElementRefBySymbol(t, ws, "Foo") - fooElement := ws.Elements[fooRef] - if fooElement.Technology != "go" { - t.Fatalf("symbol technology = %q, want go", fooElement.Technology) + if err := json.Unmarshal([]byte(stdout), &payload); err != nil { + t.Fatalf("decode json: %v\n%s", err, stdout) + } + if payload.Scan["repository_id"] 
== nil || payload.Representation["representation_hash"] == nil || payload.Export["elements_written"] == nil { + t.Fatalf("unexpected payload: %+v", payload) } } -func TestAnalyzeCmd_SetsTechnologyPython(t *testing.T) { - dir := t.TempDir() - cmd.MustInitWorkspace(t, dir) - file := filepath.Join(dir, "app.py") - if err := os.WriteFile(file, []byte("def hello():\n pass\n"), 0600); err != nil { - t.Fatal(err) +func countKind(ws *workspace.Workspace, kind string) int { + count := 0 + for _, element := range ws.Elements { + if element.Kind == kind { + count++ + } } + return count +} - if _, _, err := cmd.RunCmd(t, dir, "analyze", file); err != nil { - t.Fatalf("analyze: %v", err) - } - ws, err := workspace.Load(dir) - if err != nil { +func writeAnalyzeTestFile(t *testing.T, root, name, content string) { + t.Helper() + path := filepath.Join(root, name) + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { t.Fatal(err) } - - fileRef := findAnalyzeElementRefByKindAndPath(t, ws, "file", "app.py") - fileElement := ws.Elements[fileRef] - if fileElement.Technology != "python" { - t.Fatalf("file technology = %q, want python", fileElement.Technology) - } - - helloRef := findAnalyzeElementRefBySymbol(t, ws, "hello") - helloElement := ws.Elements[helloRef] - if helloElement.Technology != "python" { - t.Fatalf("symbol technology = %q, want python", helloElement.Technology) + if err := os.WriteFile(path, []byte(content), 0600); err != nil { + t.Fatal(err) } } -func findAnalyzeElementRefByKindAndPath(t *testing.T, ws *workspace.Workspace, kind, filePath string) string { - t.Helper() +func refByElementName(ws *workspace.Workspace, name string) string { for ref, element := range ws.Elements { - if element.Kind == kind && element.FilePath == filePath { + if element.Name == name { return ref } } - t.Fatalf("expected %s element for %s, got %+v", kind, filePath, ws.Elements) return "" } -func findAnalyzeElementRefBySymbol(t *testing.T, ws *workspace.Workspace, symbol string) string { - t.Helper() +func refByElementNameWithParent(ws *workspace.Workspace, name, parentRef string) string { for ref, element := range ws.Elements { - if element.Symbol == symbol { + if element.Name == name && hasPlacementParent(ws, ref, parentRef) { return ref } } - t.Fatalf("expected symbol element for %s, got %+v", symbol, ws.Elements) return "" } -func assertAnalyzeConnectorExists(t *testing.T, ws *workspace.Workspace, source, target, label string) { - t.Helper() - for _, connector := range ws.Connectors { - if connector.Source == source && connector.Target == target && connector.Label == label { - return +func refByKind(ws *workspace.Workspace, kind string) string { + for ref, element := range ws.Elements { + if element.Kind == kind { + return ref } } - t.Fatalf("expected connector %s -> %s (%s), got %+v", source, target, label, ws.Connectors) + return "" } -func assertAnalyzeConnectorCount(t *testing.T, ws *workspace.Workspace, source, target, label string, want int) { - t.Helper() - got := 0 - for _, connector := range ws.Connectors { - if connector.Source == source && connector.Target == target && connector.Label == label { - got++ +func refByKindAndFilePath(ws *workspace.Workspace, kind, filePath string) string { + for ref, element := range ws.Elements { + if element.Kind == kind && element.FilePath == filePath { + return ref } } - if got != want { - t.Fatalf("connector count %s -> %s (%s) = %d, want %d: %+v", source, target, label, got, want, ws.Connectors) - } + return "" } -func assertAnalyzeConnectorCountUnordered(t 
*testing.T, ws *workspace.Workspace, left, right, label string, want int) { - t.Helper() - got := 0 - for _, connector := range ws.Connectors { - if connector.Label != label { - continue - } - if (connector.Source == left && connector.Target == right) || (connector.Source == right && connector.Target == left) { - got++ +func hasPlacementParent(ws *workspace.Workspace, ref, parentRef string) bool { + element := ws.Elements[ref] + if element == nil { + return false + } + for _, placement := range element.Placements { + if placement.ParentRef == parentRef { + return true } } - if got != want { - t.Fatalf("unordered connector count %s <-> %s (%s) = %d, want %d: %+v", left, right, label, got, want, ws.Connectors) + return false +} + +func placementCount(ws *workspace.Workspace, ref string) int { + element := ws.Elements[ref] + if element == nil { + return 0 } + return len(element.Placements) } -func assertAnalyzeConnectorCountByLabel(t *testing.T, ws *workspace.Workspace, label string, want int) { - t.Helper() - got := 0 +func connectorByElementNamesInParent(ws *workspace.Workspace, sourceName, targetName, parentRef string) bool { + sourceRef := refByElementNameWithParent(ws, sourceName, parentRef) + targetRef := refByElementNameWithParent(ws, targetName, parentRef) + if sourceRef == "" || targetRef == "" { + return false + } for _, connector := range ws.Connectors { - if connector.Label == label { - got++ + if connector.Source == sourceRef && connector.Target == targetRef { + return true } } - if got != want { - t.Fatalf("connector count for label %s = %d, want %d: %+v", label, got, want, ws.Connectors) - } + return false } -func assertBidirectionalAnalyzeConnector(t *testing.T, ws *workspace.Workspace, left, right, label string) { - t.Helper() +func connectorByElementNames(ws *workspace.Workspace, sourceName, targetName string) bool { + sourceRef := refByElementName(ws, sourceName) + targetRef := refByElementName(ws, targetName) + if sourceRef == "" || targetRef == "" { + return false + } for _, connector := range ws.Connectors { - if connector.Label != label { - continue - } - if (connector.Source == left && connector.Target == right) || (connector.Source == right && connector.Target == left) { - if connector.Direction != "both" { - t.Fatalf("connector %s <-> %s (%s) direction = %s, want both: %+v", left, right, label, connector.Direction, connector) - } - return + if connector.Source == sourceRef && connector.Target == targetRef { + return true } } - t.Fatalf("expected bidirectional connector %s <-> %s (%s), got %+v", left, right, label, ws.Connectors) + return false } diff --git a/cmd/apply/apply.go b/cmd/apply/apply.go index 73bf167..bb877a7 100644 --- a/cmd/apply/apply.go +++ b/cmd/apply/apply.go @@ -15,6 +15,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/planner" "github.com/mertcikla/tld/internal/reporter" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" "google.golang.org/protobuf/proto" @@ -92,17 +93,15 @@ func NewApplyCmd(wdir *string) *cobra.Command { return fmt.Errorf("server drift check failed: %w", err) } if hasDrift { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Warning: The server has changes that are not in your local YAML.") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " Run `tld pull` to merge them first, or use --force-apply to overwrite.") - if _, err := fmt.Fprint(cmd.OutOrStdout(), " Continue anyway? 
[yes/no]: "); err != nil { - return err - } + term.Warn(cmd.OutOrStdout(), "The server has changes that are not in your local YAML.") + term.Hint(cmd.OutOrStdout(), "Run `tld pull` to merge them first, or use --force-apply to overwrite.") + _, _ = fmt.Fprint(cmd.OutOrStdout(), " Continue anyway? [yes/no]: ") if !scanner.Scan() { return errors.New("aborted") } answer := strings.TrimSpace(strings.ToLower(scanner.Text())) if answer != "yes" && answer != "y" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Apply cancelled.") + term.Info(cmd.OutOrStdout(), "Apply cancelled.") return nil } } @@ -134,13 +133,13 @@ func NewApplyCmd(wdir *string) *cobra.Command { } if !force { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Apply %d resources? [yes/no]: ", total) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), " Apply %d resources? [yes/no]: ", total) if !scanner.Scan() { return errors.New("aborted") } answer := strings.TrimSpace(strings.ToLower(scanner.Text())) if answer != "yes" && answer != "y" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Apply cancelled.") + term.Info(cmd.OutOrStdout(), "Apply cancelled.") return nil } } @@ -151,11 +150,11 @@ func NewApplyCmd(wdir *string) *cobra.Command { if cmdutil.WantsJSON(cmd.Root().PersistentFlags().Lookup("format").Value.String()) { return cmdutil.WriteCommandError(cmd.OutOrStdout(), cmd.Root().PersistentFlags().Lookup("compact").Value.String() == "true", "apply", err) } - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Apply failed:", err) - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " Target URL: %s\n", client.NormalizeURL(ws.Config.ServerURL)) + term.Fail(cmd.ErrOrStderr(), fmt.Sprintf("Apply failed: %v", err)) + term.Label(cmd.ErrOrStderr(), 12, "Target URL", client.NormalizeURL(ws.Config.ServerURL)) if connectErr := new(connect.Error); errors.As(err, &connectErr) { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " Code: %s\n", connectErr.Code().String()) + term.Label(cmd.ErrOrStderr(), 12, "Code", connectErr.Code().String()) if len(connectErr.Details()) > 0 { _, _ = fmt.Fprintln(cmd.ErrOrStderr(), " Details:") for _, detail := range connectErr.Details() { @@ -164,7 +163,7 @@ func NewApplyCmd(wdir *string) *cobra.Command { } } - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Transaction rolled back.") + term.Info(cmd.ErrOrStderr(), "Transaction rolled back.") reporter.RenderExecutionMarkdown(cmd.ErrOrStderr(), plan, nil, false, false) return cmdutil.WithUnauthorizedHint("apply failed", err) } @@ -195,7 +194,7 @@ func NewApplyCmd(wdir *string) *cobra.Command { } currentWS.Meta = meta if err := workspace.Save(currentWS); err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: failed to rewrite workspace metadata: %v\n", err) + term.Warnf(cmd.ErrOrStderr(), "Failed to rewrite workspace metadata: %v", err) } if cmdutil.WantsJSON(cmd.Root().PersistentFlags().Lookup("format").Value.String()) { @@ -205,7 +204,7 @@ func NewApplyCmd(wdir *string) *cobra.Command { reporter.RenderExecutionMarkdown(cmd.OutOrStdout(), plan, resp.Msg, true, verbose) if len(resp.Msg.Drift) > 0 { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: %d drift item(s) detected\n", len(resp.Msg.Drift)) + term.Warnf(cmd.ErrOrStderr(), "%d drift item(s) detected", len(resp.Msg.Drift)) return fmt.Errorf("%d drift item(s) detected", len(resp.Msg.Drift)) } return nil @@ -243,7 +242,7 @@ func autoPullAndRebuild(cmd *cobra.Command, ws *workspace.Workspace, lockFile *w return nil, 0, nil } if !cmdutil.WantsJSON(cmd.Root().PersistentFlags().Lookup("format").Value.String()) { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Version conflict 
detected during force apply. Pulling and retrying once...") + term.Warnf(cmd.ErrOrStderr(), "Version conflict detected during force apply. Pulling and retrying once...") } newPlan, err := pullAndRebuildPlan(cmd, ws, lockFile, wdir, recreateIDs) if err != nil { @@ -279,22 +278,22 @@ func detectAndHandleConflicts( _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Version conflict detected:\n") if resp.Msg.Version != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "- Remote has newer version %s (%s) via %s\n", + term.Warnf(cmd.ErrOrStderr(), "Remote has newer version %s (%s) via %s", resp.Msg.Version.VersionId, resp.Msg.Version.CreatedAt.AsTime().Format(time.RFC3339), resp.Msg.Version.CreatedBy) } - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "- %d conflicts detected:\n", len(resp.Msg.Conflicts)) + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " %d conflicts detected:\n", len(resp.Msg.Conflicts)) for _, conflict := range resp.Msg.Conflicts { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " * %s \"%s\" (local %s, remote %s)\n", + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " * %s %q (local %s, remote %s)\n", conflict.ResourceType, conflict.Ref, conflict.LocalUpdatedAt.AsTime().Format(time.RFC3339), conflict.RemoteUpdatedAt.AsTime().Format(time.RFC3339)) } _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nOptions:\n") - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "[1] Abort and review changes\n") - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "[2] Pull & Merge (fetch server state and merge locally)\n") - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "[3] Force Apply (overwrite remote changes)\n") + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " [1] Abort and review changes\n") + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " [2] Pull & Merge (fetch server state and merge locally)\n") + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " [3] Force Apply (overwrite remote changes)\n") _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nChoose option [1-3]: ") if !scanner.Scan() { @@ -304,20 +303,20 @@ func detectAndHandleConflicts( choice := strings.TrimSpace(scanner.Text()) switch choice { case "1": - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Apply aborted.") + term.Info(cmd.OutOrStdout(), "Apply aborted.") return nil, errors.New("apply aborted by user") case "2": - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Pulling server state and merging locally...") + term.Info(cmd.OutOrStdout(), "Pulling server state and merging locally...") newPlan, err := pullAndRebuildPlan(cmd, ws, lockFile, wdir, recreateIDs) if err != nil { return nil, fmt.Errorf("pull & merge: %w", err) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Merge complete. Proceeding with apply...") + term.Success(cmd.OutOrStdout(), "Merge complete. 
Proceeding with apply...") return newPlan, nil case "3": - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Proceeding with force apply...") + term.Info(cmd.OutOrStdout(), "Proceeding with force apply...") return nil, nil default: diff --git a/cmd/apply/apply_test.go b/cmd/apply/apply_test.go index 8257626..f8ca6ba 100644 --- a/cmd/apply/apply_test.go +++ b/cmd/apply/apply_test.go @@ -187,14 +187,6 @@ func TestApplyCmd_InteractiveDecline(t *testing.T) { } } -func TestApplyCmd_MissingConfig(t *testing.T) { - dir := t.TempDir() - _, _, err := cmd.RunCmd(t, dir, "apply", "--force") - if err == nil || !strings.Contains(err.Error(), "load workspace") { - t.Fatalf("expected missing config error, got %v", err) - } -} - func TestApplyCmd_CreatedResourcesInOutput(t *testing.T) { svc := &cmd.MockDiagramService{} serverURL := cmd.NewMockServer(t, svc) diff --git a/cmd/check/check.go b/cmd/check/check.go index fec8769..b93a361 100644 --- a/cmd/check/check.go +++ b/cmd/check/check.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/mertcikla/tld/internal/cmdutil" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -28,24 +29,24 @@ func NewCheckCmd(wdir *string) *cobra.Command { errs := ws.Validate() if len(errs) > 0 { allPassed = false - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "FAIL Validation") + term.Fail(cmd.ErrOrStderr(), "Validation") for _, e := range errs { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", e) } } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "PASS Validation") + term.Success(cmd.OutOrStdout(), "Validation") } // 2. Check Symbols broken := cmdutil.CheckSymbols(cmd.Context(), ws, repoCtx, rules) if len(broken) > 0 { allPassed = false - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "FAIL Symbol Verification") + term.Fail(cmd.ErrOrStderr(), "Symbol Verification") for _, msg := range broken { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", msg) } } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "PASS Symbol Verification") + term.Success(cmd.OutOrStdout(), "Symbol Verification") } // 2. 
Check Freshness @@ -54,19 +55,19 @@ func NewCheckCmd(wdir *string) *cobra.Command { if strict { allPassed = false } - label := "WARN " if strict { - label = "FAIL " + term.Fail(cmd.ErrOrStderr(), "Outdated Diagrams") + } else { + term.Warn(cmd.ErrOrStderr(), "Outdated Diagrams") } - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "%s Outdated Diagrams\n", label) for _, msg := range outdated { _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", msg) } if strict { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), " (use `tld apply` to sync diagram metadata)") + term.Hint(cmd.ErrOrStderr(), "use `tld apply` to sync diagram metadata") } } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "PASS Outdated Diagrams") + term.Success(cmd.OutOrStdout(), "Outdated Diagrams") } if !allPassed { diff --git a/cmd/check/check_repo_scope_test.go b/cmd/check/check_repo_scope_test.go index 073e38f..552677e 100644 --- a/cmd/check/check_repo_scope_test.go +++ b/cmd/check/check_repo_scope_test.go @@ -35,7 +35,7 @@ foreign: if err != nil { t.Fatalf("check: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if !strings.Contains(stdout, "PASS Symbol Verification") { + if !strings.Contains(stdout, "Symbol Verification") { t.Errorf("stdout %q does not contain symbol verification pass", stdout) } if strings.Contains(stderr, "Foreign Service") || strings.Contains(stderr, "doesNotExist") { @@ -69,7 +69,7 @@ foreign: if err != nil { t.Fatalf("validate: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if !strings.Contains(stdout, "Symbol verification: passed") { + if !strings.Contains(stdout, "Symbol verification") { t.Errorf("stdout %q does not contain symbol verification pass", stdout) } if strings.Contains(stderr, "Foreign Service") || strings.Contains(stderr, "doesNotExist") { diff --git a/cmd/check/check_test.go b/cmd/check/check_test.go index 5f244e6..fd16b7d 100644 --- a/cmd/check/check_test.go +++ b/cmd/check/check_test.go @@ -38,7 +38,7 @@ func TestCheckCmd_AllPass(t *testing.T) { if err != nil { t.Fatalf("check: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if !strings.Contains(stdout, "PASS Validation") || !strings.Contains(stdout, "PASS Symbol Verification") || !strings.Contains(stdout, "PASS Outdated Diagrams") { + if !strings.Contains(stdout, "Validation") || !strings.Contains(stdout, "Symbol Verification") || !strings.Contains(stdout, "Outdated Diagrams") { t.Fatalf("unexpected stdout: %s", stdout) } } @@ -57,7 +57,7 @@ func TestCheckCmd_BrokenSymbol(t *testing.T) { if err == nil { t.Fatalf("expected check failure\nstdout: %s\nstderr: %s", stdout, stderr) } - if !strings.Contains(stderr, "FAIL Symbol Verification") { + if !strings.Contains(stderr, "Symbol Verification") { t.Fatalf("unexpected stderr: %s", stderr) } } @@ -76,7 +76,7 @@ func TestCheckCmd_OutdatedStrict(t *testing.T) { if err == nil { t.Fatalf("expected strict check failure\nstdout: %s\nstderr: %s", stdout, stderr) } - if !strings.Contains(stderr, "FAIL Outdated Diagrams") { + if !strings.Contains(stderr, "Outdated Diagrams") { t.Fatalf("unexpected stderr: %s", stderr) } } @@ -95,7 +95,7 @@ func TestCheckCmd_ValidationFail(t *testing.T) { if err == nil { t.Fatalf("expected validation failure\nstdout: %s\nstderr: %s", stdout, stderr) } - if !strings.Contains(stderr, "FAIL Validation") { + if !strings.Contains(stderr, "Validation") { t.Fatalf("unexpected stderr: %s", stderr) } } @@ -114,7 +114,7 @@ func TestCheckCmd_OutdatedWarn(t *testing.T) { if err != nil { t.Fatalf("expected warning-only check\nstdout: %s\nstderr: %s\nerr: %v", stdout, stderr, err) 
} - if !strings.Contains(stderr, "WARN Outdated Diagrams") { + if !strings.Contains(stderr, "Outdated Diagrams") { t.Fatalf("unexpected stderr: %s", stderr) } _ = time.Now() diff --git a/cmd/cli_port_test.go b/cmd/cli_port_test.go index bbc97ef..d83dc01 100644 --- a/cmd/cli_port_test.go +++ b/cmd/cli_port_test.go @@ -9,6 +9,7 @@ import ( "github.com/mertcikla/tld/cmd/version" "github.com/mertcikla/tld/cmd" + "github.com/mertcikla/tld/internal/workspace" ) func TestRootCmd_HelpMatchesReferenceSurface(t *testing.T) { @@ -87,39 +88,24 @@ func TestConnectCmd_HelpHidesStyleFlag(t *testing.T) { func TestAnalyzeCmd_EmptyGoFileDoesNotChangeWorkspaceContents(t *testing.T) { dir := t.TempDir() + dataDir := t.TempDir() cmd.MustInitWorkspace(t, dir) - file := filepath.Join(dir, "empty.go") - if err := os.WriteFile(file, []byte("package main\n"), 0600); err != nil { - t.Fatalf("write empty.go: %v", err) - } + repoDir := filepath.Join(dir, "repo") + cmd.InitGitRepo(t, repoDir, "empty.go", "package main\n") - beforeElements, err := os.ReadFile(filepath.Join(dir, "elements.yaml")) - if err != nil { - t.Fatalf("read elements.yaml before analyze: %v", err) - } - beforeConnectors, err := os.ReadFile(filepath.Join(dir, "connectors.yaml")) - if err != nil { - t.Fatalf("read connectors.yaml before analyze: %v", err) - } - - stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", file) + stdout, stderr, err := cmd.RunCmd(t, dir, "analyze", repoDir, "--data-dir", dataDir, "--embedding-provider", "none") if err != nil { t.Fatalf("analyze empty.go: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - - afterElements, err := os.ReadFile(filepath.Join(dir, "elements.yaml")) - if err != nil { - t.Fatalf("read elements.yaml after analyze: %v", err) - } - afterConnectors, err := os.ReadFile(filepath.Join(dir, "connectors.yaml")) + ws, err := workspace.Load(dir) if err != nil { - t.Fatalf("read connectors.yaml after analyze: %v", err) + t.Fatal(err) } - if string(beforeElements) != string(afterElements) { - t.Fatalf("elements.yaml changed for an empty Go file\nbefore:\n%s\nafter:\n%s", beforeElements, afterElements) + if len(ws.Connectors) != 0 { + t.Fatalf("empty Go file should not create connectors: %+v", ws.Connectors) } - if string(beforeConnectors) != string(afterConnectors) { - t.Fatalf("connectors.yaml changed for an empty Go file\nbefore:\n%s\nafter:\n%s", beforeConnectors, afterConnectors) + if len(ws.Elements) == 0 { + t.Fatalf("watch-backed analyze should materialize repository/file context") } } diff --git a/cmd/config/config.go b/cmd/config/config.go new file mode 100644 index 0000000..c64a451 --- /dev/null +++ b/cmd/config/config.go @@ -0,0 +1,169 @@ +package config + +import ( + "encoding/json" + "fmt" + "strings" + "text/tabwriter" + + "github.com/mertcikla/tld/internal/workspace" + "github.com/spf13/cobra" +) + +func NewConfigCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "config", + Short: "Inspect and update the global tld configuration", + Long: "Inspect and update the global tld.yaml configuration file.", + } + cmd.AddCommand(newPathCmd(), newListCmd(), newGetCmd(), newSetCmd(), newValidateCmd()) + return cmd +} + +func newPathCmd() *cobra.Command { + return &cobra.Command{ + Use: "path", + Short: "Print the active global config path", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + path, err := workspace.ConfigPath() + if err != nil { + return err + } + _, _ = fmt.Fprintln(cmd.OutOrStdout(), path) + return nil + }, + } +} + +func newListCmd() *cobra.Command { + var 
showSecrets bool + cmd := &cobra.Command{ + Use: "list", + Short: "List effective global config values", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + state, err := workspace.LoadGlobalConfigState() + if err != nil { + return err + } + values := redactConfigValues(state.Values, showSecrets) + if wantsJSON(cmd) { + return writeJSON(cmd, values) + } + w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0) + _, _ = fmt.Fprintln(w, "KEY\tVALUE\tSOURCE\tENV\tDESCRIPTION") + for _, value := range values { + _, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", value.Key, value.Value, value.Source, value.Env, value.Description) + } + return w.Flush() + }, + } + cmd.Flags().BoolVar(&showSecrets, "show-secrets", false, "show secret values instead of redacting them") + return cmd +} + +func newGetCmd() *cobra.Command { + return &cobra.Command{ + Use: "get ", + Short: "Print one effective global config value", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + key := strings.ToLower(strings.TrimSpace(args[0])) + if _, ok := workspace.ConfigDefinitionForKey(key); !ok { + return fmt.Errorf("unknown global config key %q", args[0]) + } + state, err := workspace.LoadGlobalConfigState() + if err != nil { + return err + } + for _, value := range state.Values { + if value.Key == key { + if wantsJSON(cmd) { + return writeJSON(cmd, value) + } + _, _ = fmt.Fprintln(cmd.OutOrStdout(), value.Value) + return nil + } + } + return fmt.Errorf("unknown global config key %q", args[0]) + }, + } +} + +func newSetCmd() *cobra.Command { + return &cobra.Command{ + Use: "set ", + Short: "Set one global config value", + Args: cobra.ExactArgs(2), + RunE: func(cmd *cobra.Command, args []string) error { + if err := workspace.SetGlobalConfigValue(args[0], args[1]); err != nil { + return err + } + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Set %s\n", strings.ToLower(strings.TrimSpace(args[0]))) + return nil + }, + } +} + +func newValidateCmd() *cobra.Command { + return &cobra.Command{ + Use: "validate", + Short: "Validate the global config file and env overrides", + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + state, err := workspace.LoadGlobalConfigState() + if err != nil { + return err + } + issues := workspace.ValidateGlobalConfig(state.Config) + if wantsJSON(cmd) { + payload := struct { + OK bool `json:"ok"` + Issues []workspace.ConfigValidationError `json:"issues"` + }{OK: len(issues) == 0, Issues: []workspace.ConfigValidationError(issues)} + if err := writeJSON(cmd, payload); err != nil { + return err + } + } + if len(issues) > 0 { + return issues + } + if !wantsJSON(cmd) { + _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Global config valid.") + } + return nil + }, + } +} + +func redactConfigValues(values []workspace.ConfigValue, showSecrets bool) []workspace.ConfigValue { + out := append([]workspace.ConfigValue(nil), values...) 
+ if showSecrets { + return out + } + for i := range out { + if out[i].Secret && out[i].Value != "" { + out[i].Value = "********" + } + } + return out +} + +func wantsJSON(cmd *cobra.Command) bool { + flag := cmd.Root().PersistentFlags().Lookup("format") + return flag != nil && flag.Value.String() == "json" +} + +func compactJSON(cmd *cobra.Command) bool { + flag := cmd.Root().PersistentFlags().Lookup("compact") + return flag != nil && flag.Value.String() == "true" +} + +func writeJSON(cmd *cobra.Command, payload any) error { + enc := json.NewEncoder(cmd.OutOrStdout()) + if !compactJSON(cmd) { + enc.SetIndent("", " ") + } + return enc.Encode(payload) +} diff --git a/cmd/config_test.go b/cmd/config_test.go new file mode 100644 index 0000000..43d0587 --- /dev/null +++ b/cmd/config_test.go @@ -0,0 +1,112 @@ +package cmd + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/mertcikla/tld/internal/workspace" +) + +func TestConfigCommandPathSetGetAndListRedactsSecrets(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + dir := t.TempDir() + + stdout, _, err := RunCmd(t, dir, "config", "path") + if err != nil { + t.Fatalf("config path: %v", err) + } + if strings.TrimSpace(stdout) != filepath.Join(configDir, "tld.yaml") { + t.Fatalf("config path = %q", stdout) + } + + if stdout, _, err = RunCmd(t, dir, "config", "set", "api_key", "secret-value"); err != nil { + t.Fatalf("config set api_key: %v\nstdout: %s", err, stdout) + } + stdout, _, err = RunCmd(t, dir, "config", "get", "api_key") + if err != nil { + t.Fatalf("config get api_key: %v", err) + } + if strings.TrimSpace(stdout) != "secret-value" { + t.Fatalf("api_key get = %q, want secret-value", stdout) + } + + stdout, _, err = RunCmd(t, dir, "config", "list") + if err != nil { + t.Fatalf("config list: %v", err) + } + if strings.Contains(stdout, "secret-value") || !strings.Contains(stdout, "********") { + t.Fatalf("config list did not redact api_key:\n%s", stdout) + } + + stdout, _, err = RunCmd(t, dir, "config", "list", "--show-secrets") + if err != nil { + t.Fatalf("config list --show-secrets: %v", err) + } + if !strings.Contains(stdout, "secret-value") { + t.Fatalf("config list --show-secrets did not include api_key:\n%s", stdout) + } +} + +func TestConfigCommandJSONAndValidation(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + dir := t.TempDir() + + if _, _, err := RunCmd(t, dir, "config", "set", "watch.languages", "go,typescript"); err != nil { + t.Fatalf("config set languages: %v", err) + } + + stdout, _, err := RunCmd(t, dir, "--format", "json", "config", "get", "watch.languages") + if err != nil { + t.Fatalf("config get json: %v", err) + } + var value workspace.ConfigValue + if err := json.Unmarshal([]byte(stdout), &value); err != nil { + t.Fatalf("unmarshal config value: %v\n%s", err, stdout) + } + if value.Key != "watch.languages" || value.Value != "go,typescript" || value.Source != workspace.ConfigSourceFile { + t.Fatalf("unexpected config value: %+v", value) + } + + stdout, _, err = RunCmd(t, dir, "--format", "json", "config", "validate") + if err != nil { + t.Fatalf("config validate json: %v", err) + } + if !strings.Contains(stdout, `"ok": true`) { + t.Fatalf("validate json did not report ok:\n%s", stdout) + } + + if _, _, err := RunCmd(t, dir, "config", "set", "TLD_CONFIG_DIR", "/tmp/nope"); err == nil { + t.Fatal("expected runtime locator key to be rejected") + } + if _, _, err := RunCmd(t, dir, "config", "set", 
"serve.port", "99999"); err == nil { + t.Fatal("expected invalid port to be rejected") + } +} + +func TestConfigCommandEnvOverrideSource(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + t.Setenv("PORT", "7777") + dir := t.TempDir() + + if err := os.WriteFile(filepath.Join(configDir, "tld.yaml"), []byte("serve:\n port: \"8888\"\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + stdout, _, err := RunCmd(t, dir, "--format", "json", "config", "get", "serve.port") + if err != nil { + t.Fatalf("config get serve.port: %v", err) + } + var value workspace.ConfigValue + if err := json.Unmarshal([]byte(stdout), &value); err != nil { + t.Fatalf("unmarshal config value: %v\n%s", err, stdout) + } + if value.Value != "7777" || value.Source != workspace.ConfigSourceEnv || value.Env != "PORT" { + t.Fatalf("serve.port = %+v, want PORT env override", value) + } +} diff --git a/cmd/connect/connect.go b/cmd/connect/connect.go index 70ec34c..ae4ad7a 100644 --- a/cmd/connect/connect.go +++ b/cmd/connect/connect.go @@ -6,6 +6,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/completion" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -59,8 +60,8 @@ func NewConnectCmd(wdir, format *string, compact *bool) *cobra.Command { if cmdutil.WantsJSON(*format) { return cmdutil.WriteMutation(cmd.OutOrStdout(), *compact, "connect", "connect", fmt.Sprintf("%s:%s", from, to)) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Appended connector %s -> %s in view %s to connectors.yaml\n", from, to, view) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally in connectors.yaml. Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Connector %s → %s added in view %s", from, to, view) + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") return nil }, } diff --git a/cmd/initialize/init.go b/cmd/initialize/init.go index 20632b7..d6521bb 100644 --- a/cmd/initialize/init.go +++ b/cmd/initialize/init.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/mertcikla/tld/internal/git" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" "gopkg.in/yaml.v3" @@ -108,22 +109,18 @@ func NewInitCmd() *cobra.Command { } if _, err := os.Stat(cfgPath); err == nil { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Initialized workspace at %s (global config already exists at %s)\n", dir, cfgPath) + term.Successf(cmd.OutOrStdout(), "Workspace initialized at %s", term.Path(cmd.OutOrStdout(), dir)) + term.Infof(cmd.OutOrStdout(), "Global config already exists at %s", term.Path(cmd.OutOrStdout(), cfgPath)) } else { - cfg := `# tld global configuration -server_url: https://tldiagram.com -api_key: "" # or set TLD_API_KEY env var -org_id: "" # UUID of your organisation -` - if err := os.WriteFile(cfgPath, []byte(cfg), 0600); err != nil { - return fmt.Errorf("write tld.yaml: %w", err) + if err := workspace.EnsureGlobalConfig(); err != nil { + return fmt.Errorf("ensure global config: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Initialized workspace at %s\n", dir) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Global configuration file created: %s\n", cfgPath) + term.Successf(cmd.OutOrStdout(), "Workspace initialized at %s", term.Path(cmd.OutOrStdout(), dir)) + term.Infof(cmd.OutOrStdout(), "Global config created at %s", term.Path(cmd.OutOrStdout(), cfgPath)) } if !wizard { - _, _ = 
fmt.Fprintln(cmd.OutOrStdout(), "Run `tld login` to authenticate with tldiagram.com") + term.Hint(cmd.OutOrStdout(), "Run 'tld login' to authenticate with tldiagram.com") } return nil }, @@ -155,7 +152,7 @@ func runInitWizard(cmd *cobra.Command, dir string) error { if remoteURL, err := git.DetectRemoteURL(parentDir); err == nil { defaultRepoURL = remoteURL } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Warning: current folder is not a git repository. You can still initialize tld, but automatic source code linking will require manual configuration of repository URL and localDir in .tld.yaml.") + term.Warn(cmd.OutOrStdout(), "Current folder is not a git repository. Automatic source linking requires manual configuration of repository URL and localDir in .tld.yaml.") } for { @@ -206,11 +203,12 @@ func runInitWizard(cmd *cobra.Command, dir string) error { return fmt.Errorf("write .tld.yaml: %w", err) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Next steps:") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 1. tld login - authenticate with tlDiagram.com") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 2. tld analyze . - extract symbols from your repo") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 3. tld plan - preview what will be created") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 4. tld apply - push to tlDiagram.com") + term.Separator(cmd.OutOrStdout()) + term.Info(cmd.OutOrStdout(), "Next steps:") + _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 1. tld login - authenticate with tlDiagram.com") + _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 2. tld analyze . - extract symbols from your repo") + _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 3. tld plan - preview what will be created") + _, _ = fmt.Fprintln(cmd.OutOrStdout(), " 4. tld apply - push to tlDiagram.com") return nil } diff --git a/cmd/initialize/init_test.go b/cmd/initialize/init_test.go index c12ac4b..21e7ea5 100644 --- a/cmd/initialize/init_test.go +++ b/cmd/initialize/init_test.go @@ -19,8 +19,8 @@ func TestInitCmd_CreatesWorkspace(t *testing.T) { if err != nil { t.Fatalf("init: %v", err) } - if !strings.Contains(stdout, "Initialized workspace") { - t.Errorf("stdout %q does not contain 'Initialized workspace'", stdout) + if !strings.Contains(stdout, "Workspace initialized at") { + t.Errorf("stdout %q does not contain 'Workspace initialized at'", stdout) } workspaceCfgPath := filepath.Join(dir, ".tld.yaml") @@ -99,7 +99,7 @@ func TestInitCmd_AlreadyInitialized(t *testing.T) { if err != nil { t.Fatalf("second init: %v", err) } - if !strings.Contains(stdout, "Initialized workspace at") || !strings.Contains(stdout, "config already exists") { + if !strings.Contains(stdout, "Workspace initialized at") || !strings.Contains(stdout, "config already exists") { t.Errorf("stdout %q does not contain 'Initialized' or 'config already exists'", stdout) } } diff --git a/cmd/login/login.go b/cmd/login/login.go index 6c5c034..740230a 100644 --- a/cmd/login/login.go +++ b/cmd/login/login.go @@ -12,6 +12,7 @@ import ( "connectrpc.com/connect" "github.com/mertcikla/tld/internal/client" "github.com/mertcikla/tld/internal/cmdutil" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" "gopkg.in/yaml.v3" @@ -51,9 +52,12 @@ enter the displayed code at /app/device.`, } // Step 2: inform the user. 
- _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nOpen the following URL to log in:\n\n %s\n\n", auth.VerificationUriComplete) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Or navigate to %s and enter the code:\n\n %s\n\n", auth.VerificationUri, auth.UserCode) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Waiting for authorisation… (press Ctrl+C to cancel)") + term.Separator(cmd.OutOrStdout()) + term.Infof(cmd.OutOrStdout(), "Open the following URL to log in:\n\n %s", term.URL(cmd.OutOrStdout(), auth.VerificationUriComplete)) + term.Separator(cmd.OutOrStdout()) + term.Infof(cmd.OutOrStdout(), "Or navigate to %s and enter the code:\n\n %s", auth.VerificationUri, auth.UserCode) + term.Separator(cmd.OutOrStdout()) + term.Info(cmd.OutOrStdout(), "Waiting for authorisation… (press Ctrl+C to cancel)") // Step 3: optionally open the browser. if !noBrowser { @@ -75,7 +79,8 @@ enter the displayed code at /app/device.`, } cfgPath, _ := workspace.ConfigPath() - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nAuthorised! Config written to %s\n", cfgPath) + term.Separator(cmd.OutOrStdout()) + term.Successf(cmd.OutOrStdout(), "Authorised! Config written to %s", term.Path(cmd.OutOrStdout(), cfgPath)) return nil }, } diff --git a/cmd/mcp/mcp.go b/cmd/mcp/mcp.go index a3400cd..9255e42 100644 --- a/cmd/mcp/mcp.go +++ b/cmd/mcp/mcp.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/exec" + "slices" "strconv" "github.com/mertcikla/tld/cmd/apply" @@ -249,9 +250,6 @@ func registerTools(server *mcpsdk.Server, wdir *string) { repoCtx := cmdutil.DetectRepoScope(cmdutil.GetWorkingDir(), *wdir) rules := ws.IgnoreRulesForRepository(repoCtx.Name) if a.Strictness > 0 { - if ws.Config.Validation == nil { - ws.Config.Validation = &workspace.ValidationConfig{} - } ws.Config.Validation.Level = a.Strictness } out := "" @@ -400,10 +398,8 @@ func inferView(ws *workspace.Workspace, from, to string) string { fp := parents(fromEl) tp := parents(toEl) for _, f := range fp { - for _, t := range tp { - if f == t { - return f - } + if slices.Contains(tp, f) { + return f } } return "root" @@ -455,13 +451,17 @@ Accepts the same --host, --port, --data-dir flags as 'tld serve'.`, port, _ := cmd.Flags().GetString("port") dataDirFlag, _ := cmd.Flags().GetString("data-dir") - cfg, _ := workspace.LoadGlobalConfig() + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) if err != nil { return err } + serveCfg := workspace.ResolveServeOptions(cfg, host, port) - if err := ensureServeRunning(cmd, host, port, dataDir); err != nil { + if err := ensureServeRunning(cmd, serveCfg.Host, serveCfg.Port, dataDir); err != nil { fmt.Fprintf(os.Stderr, "warning: failed to start tld serve in background: %v\n", err) } diff --git a/cmd/plan/plan.go b/cmd/plan/plan.go index 2ca8c17..31cf34f 100644 --- a/cmd/plan/plan.go +++ b/cmd/plan/plan.go @@ -34,9 +34,6 @@ func NewPlanCmd(wdir *string) *cobra.Command { // Override strictness if flag is set if strictness > 0 { - if ws.Config.Validation == nil { - ws.Config.Validation = &workspace.ValidationConfig{} - } ws.Config.Validation.Level = strictness } @@ -103,9 +100,9 @@ func NewPlanCmd(wdir *string) *cobra.Command { // Evaluate Diagram warnings if len(warnings) > 0 { - level := workspace.DefaultValidationLevel - if ws.Config.Validation != nil && ws.Config.Validation.Level > 0 { - level = ws.Config.Validation.Level + level := ws.Config.Validation.Level + if level == 0 { + level = workspace.DefaultValidationLevel } levelNames := map[int]string{1: "Minimal", 2: 
"Standard", 3: "Strict"} _, _ = fmt.Fprintf(out, "\n## Architectural Warnings (Level %d: %s)\n\n", level, levelNames[level]) diff --git a/cmd/plan/plan_test.go b/cmd/plan/plan_test.go index e971966..f0ad846 100644 --- a/cmd/plan/plan_test.go +++ b/cmd/plan/plan_test.go @@ -128,14 +128,3 @@ func TestPlanCmd_InvalidWorkspaceErrors(t *testing.T) { t.Fatal("expected error for invalid workspace") } } - -func TestPlanCmd_MissingConfig(t *testing.T) { - dir := t.TempDir() - _, _, err := cmd.RunCmd(t, dir, "plan") - if err == nil { - t.Fatal("expected error for missing config") - } - if !strings.Contains(err.Error(), "load workspace") { - t.Errorf("error %q does not contain 'load workspace'", err.Error()) - } -} diff --git a/cmd/pull/pull.go b/cmd/pull/pull.go index bac180f..798bfd0 100644 --- a/cmd/pull/pull.go +++ b/cmd/pull/pull.go @@ -11,6 +11,7 @@ import ( "connectrpc.com/connect" "github.com/mertcikla/tld/internal/client" "github.com/mertcikla/tld/internal/cmdutil" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -53,15 +54,15 @@ you before overwriting them. Use --force to skip the prompt.`, return fmt.Errorf("calculate hash: %w", err) } if currentHash != lockFile.WorkspaceHash { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Warning: local workspace has uncommitted changes.\n") - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Pull will overwrite them. Continue? [yes/no]: ") + term.Warn(cmd.OutOrStdout(), "Local workspace has uncommitted changes. Pull will overwrite them.") + _, _ = fmt.Fprint(cmd.OutOrStdout(), " Continue? [yes/no]: ") scanner := bufio.NewScanner(cmd.InOrStdin()) if !scanner.Scan() { return errors.New("aborted") } answer := strings.TrimSpace(strings.ToLower(scanner.Text())) if answer != "yes" && answer != "y" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Pull cancelled.") + term.Infof(cmd.OutOrStdout(), "Pull cancelled.") return nil } } @@ -78,7 +79,7 @@ you before overwriting them. Use --force to skip the prompt.`, newWS := cmdutil.ConvertExportResponse(ws, resp.Msg) if dryRun { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Would pull: %d elements, %d diagrams, %d connectors\n", + term.Infof(cmd.OutOrStdout(), "Would pull: %d elements, %d diagrams, %d connectors", len(newWS.Elements), cmdutil.CountElementDiagrams(newWS), len(newWS.Connectors)) return nil } @@ -123,7 +124,7 @@ you before overwriting them. Use --force to skip the prompt.`, return fmt.Errorf("write lock file: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Pulled %d elements, %d diagrams, %d connectors\n", + term.Successf(cmd.OutOrStdout(), "Pulled %d elements, %d diagrams, %d connectors", len(newWS.Elements), cmdutil.CountElementDiagrams(newWS), len(newWS.Connectors)) return nil diff --git a/cmd/remove/remove.go b/cmd/remove/remove.go index 61ef20e..5b16a71 100644 --- a/cmd/remove/remove.go +++ b/cmd/remove/remove.go @@ -5,6 +5,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/completion" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -47,8 +48,8 @@ func newElementCmd(wdir, format *string, compact *bool) *cobra.Command { if cmdutil.WantsJSON(*format) { return cmdutil.WriteMutation(cmd.OutOrStdout(), *compact, "remove element", "remove", ref) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Removed %s from elements.yaml\n", ref) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally in elements.yaml. 
Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Removed %s from elements.yaml", ref) + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") return nil }, } @@ -77,10 +78,10 @@ func newConnectorCmd(wdir, format *string, compact *bool) *cobra.Command { return cmdutil.WriteMutation(cmd.OutOrStdout(), *compact, "remove connector", "remove", fmt.Sprintf("%s:%s:%s", view, from, to)) } if n == 0 { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "No matching connectors found - nothing removed.") + term.Info(cmd.OutOrStdout(), "No matching connectors found — nothing removed.") } else { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Removed %d connector(s) from connectors.yaml\n", n) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally in connectors.yaml. Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Removed %d connector(s) from connectors.yaml", n) + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") } return nil }, diff --git a/cmd/rename/rename.go b/cmd/rename/rename.go index 9a6615c..f756963 100644 --- a/cmd/rename/rename.go +++ b/cmd/rename/rename.go @@ -5,6 +5,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/completion" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -30,9 +31,9 @@ func NewRenameCmd(wdir *string) *cobra.Command { if cmdutil.WantsJSON(cmd.Root().PersistentFlags().Lookup("format").Value.String()) { return cmdutil.WriteMutation(cmd.OutOrStdout(), cmd.Root().PersistentFlags().Lookup("compact").Value.String() == "true", "rename", "rename", fmt.Sprintf("%s -> %s", from, to)) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Renamed element %s to %s in elements.yaml\n", from, to) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Note: References in connectors.yaml and other diagrams were updated automatically.") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally. 
Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Renamed element %s → %s", from, to) + term.Hint(cmd.OutOrStdout(), "References in connectors.yaml were updated automatically.") + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") return nil }, } diff --git a/cmd/rename/rename_test.go b/cmd/rename/rename_test.go new file mode 100644 index 0000000..d478668 --- /dev/null +++ b/cmd/rename/rename_test.go @@ -0,0 +1,48 @@ +package rename_test + +import ( + "strings" + "testing" + + "github.com/mertcikla/tld/cmd" + "github.com/mertcikla/tld/internal/workspace" +) + +func TestRenameCmdCascadesElementReferences(t *testing.T) { + dir := t.TempDir() + cmd.MustInitWorkspace(t, dir) + cmd.SeedElementWorkspace(t, dir) + + stdout, stderr, err := cmd.RunCmd(t, dir, "rename", "--from", "api", "--to", "service-api") + if err != nil { + t.Fatalf("rename: %v\nstdout:%s\nstderr:%s", err, stdout, stderr) + } + if !strings.Contains(stdout, "Renamed element") { + t.Fatalf("stdout = %q, want rename confirmation", stdout) + } + + ws, err := workspace.Load(dir) + if err != nil { + t.Fatal(err) + } + if _, ok := ws.Elements["api"]; ok { + t.Fatalf("old api ref still exists: %+v", ws.Elements) + } + if ws.Elements["service-api"] == nil { + t.Fatalf("new service-api ref missing: %+v", ws.Elements) + } + connector := ws.Connectors["platform:service-api:db:reads"] + if connector == nil { + t.Fatalf("renamed connector key missing: %+v", ws.Connectors) + } + if connector.Source != "service-api" || connector.Target != "db" { + t.Fatalf("connector endpoints = %s -> %s, want service-api -> db", connector.Source, connector.Target) + } +} + +func TestRenameCmdRequiresBothRefs(t *testing.T) { + _, _, err := cmd.RunCmd(t, t.TempDir(), "rename", "--from", "api") + if err == nil || !strings.Contains(err.Error(), `required flag(s) "to" not set`) { + t.Fatalf("err = %v, want cobra required flag validation", err) + } +} diff --git a/cmd/root.go b/cmd/root.go index c1f847a..2eef71d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -8,6 +8,7 @@ import ( "github.com/mertcikla/tld/cmd/analyze" "github.com/mertcikla/tld/cmd/apply" "github.com/mertcikla/tld/cmd/check" + configcmd "github.com/mertcikla/tld/cmd/config" "github.com/mertcikla/tld/cmd/connect" "github.com/mertcikla/tld/cmd/diff" "github.com/mertcikla/tld/cmd/export" @@ -25,6 +26,7 @@ import ( "github.com/mertcikla/tld/cmd/validate" "github.com/mertcikla/tld/cmd/version" "github.com/mertcikla/tld/cmd/views" + watchcmd "github.com/mertcikla/tld/cmd/watch" "github.com/mertcikla/tld/internal/completion" "github.com/spf13/cobra" ) @@ -70,7 +72,7 @@ and apply them atomically with 'tld apply'.`, } var wdir string - defaultWdir := "." 
+ defaultWdir := "" if _, err := os.Stat(".tld"); err == nil { defaultWdir = ".tld" } else if _, err := os.Stat("tld"); err == nil { @@ -147,6 +149,12 @@ and apply them atomically with 'tld apply'.`, checkCmd := check.NewCheckCmd(&wdir) checkCmd.GroupID = secondaryGroup.ID + configCmd := configcmd.NewConfigCmd() + configCmd.GroupID = secondaryGroup.ID + + watchCmd := watchcmd.NewWatchCmd() + watchCmd.GroupID = secondaryGroup.ID + serveCmd := serve.NewServeCmd(nil) serveCmd.GroupID = secondaryGroup.ID @@ -174,6 +182,8 @@ and apply them atomically with 'tld apply'.`, renameCmd, analyzeCmd, checkCmd, + configCmd, + watchCmd, serveCmd, mcpCmd, stopCmd, diff --git a/cmd/root_workspace_test.go b/cmd/root_workspace_test.go new file mode 100644 index 0000000..7a5b77f --- /dev/null +++ b/cmd/root_workspace_test.go @@ -0,0 +1,62 @@ +package cmd + +import ( + "os" + "path/filepath" + "testing" +) + +func TestRootWorkspaceDefaultPrecedence(t *testing.T) { + t.Run("prefers .tld over tld", func(t *testing.T) { + dir := t.TempDir() + if err := os.Mkdir(filepath.Join(dir, ".tld"), 0o750); err != nil { + t.Fatal(err) + } + if err := os.Mkdir(filepath.Join(dir, "tld"), 0o750); err != nil { + t.Fatal(err) + } + + cwd, _ := os.Getwd() + defer func() { _ = os.Chdir(cwd) }() + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + + root := NewRootCmd() + if got := root.PersistentFlags().Lookup("workspace").DefValue; got != ".tld" { + t.Fatalf("workspace default = %q, want %q", got, ".tld") + } + }) + + t.Run("uses tld when .tld is missing", func(t *testing.T) { + dir := t.TempDir() + if err := os.Mkdir(filepath.Join(dir, "tld"), 0o750); err != nil { + t.Fatal(err) + } + + cwd, _ := os.Getwd() + defer func() { _ = os.Chdir(cwd) }() + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + + root := NewRootCmd() + if got := root.PersistentFlags().Lookup("workspace").DefValue; got != "tld" { + t.Fatalf("workspace default = %q, want %q", got, "tld") + } + }) + + t.Run("uses empty default when no workspace exists", func(t *testing.T) { + dir := t.TempDir() + cwd, _ := os.Getwd() + defer func() { _ = os.Chdir(cwd) }() + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + + root := NewRootCmd() + if got := root.PersistentFlags().Lookup("workspace").DefValue; got != "" { + t.Fatalf("workspace default = %q, want empty string", got) + } + }) +} diff --git a/cmd/serve/serve.go b/cmd/serve/serve.go index 3b65dc1..1a2c74e 100644 --- a/cmd/serve/serve.go +++ b/cmd/serve/serve.go @@ -20,6 +20,11 @@ import ( "github.com/spf13/cobra" ) +const ( + backgroundReadyTimeout = 30 * time.Second + readyRequestTimeout = 10 * time.Second +) + func defaultServeRunE(cmd *cobra.Command, args []string) error { _ = workspace.EnsureGlobalConfig() @@ -29,7 +34,10 @@ func defaultServeRunE(cmd *cobra.Command, args []string) error { port, _ := cmd.Flags().GetString("port") dataDirFlag, _ := cmd.Flags().GetString("data-dir") - cfg, _ := workspace.LoadGlobalConfig() + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) if err != nil { return err @@ -43,7 +51,11 @@ func defaultServeRunE(cmd *cobra.Command, args []string) error { func runForeground(cmd *cobra.Command, host, port, dataDir string, openBrowser bool) error { started := time.Now() - opts := resolveServeOptions(host, port) + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } + opts := resolveServeOptions(cfg, host, port) app, err := localserver.Bootstrap(dataDir, opts) 
if err != nil { @@ -85,7 +97,11 @@ func runForeground(cmd *cobra.Command, host, port, dataDir string, openBrowser b func runBackground(cmd *cobra.Command, host, port, dataDir string, openBrowser bool) error { started := time.Now() pidPath := localserver.PIDPath(dataDir) - opts := resolveServeOptions(host, port) + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } + opts := resolveServeOptions(cfg, host, port) addr := localserver.ResolveAddr(opts) url := "http://" + addr initializedData := databaseWillBeInitialized(dataDir) @@ -98,7 +114,7 @@ func runBackground(cmd *cobra.Command, host, port, dataDir string, openBrowser b Mode: "background", InitializedData: initializedData, Resources: readyResources(ready), - PID: intPtr(pid), + PID: new(pid), BindAddr: addr, Startup: 0, DBPath: localserver.DatabasePath(dataDir), @@ -119,11 +135,11 @@ func runBackground(cmd *cobra.Command, host, port, dataDir string, openBrowser b } fwdArgs := []string{"serve", "--foreground"} - if host != "" { - fwdArgs = append(fwdArgs, "--host", host) + if opts.Host != "" { + fwdArgs = append(fwdArgs, "--host", opts.Host) } - if port != "" { - fwdArgs = append(fwdArgs, "--port", port) + if opts.Port != "" { + fwdArgs = append(fwdArgs, "--port", opts.Port) } // Always pass resolved dataDir to foreground child fwdArgs = append(fwdArgs, "--data-dir", dataDir) @@ -148,7 +164,7 @@ func runBackground(cmd *cobra.Command, host, port, dataDir string, openBrowser b return fmt.Errorf("write pid file: %w", err) } - ready, err := waitReady(url+"/api/ready", 10*time.Second) + ready, err := waitReady(url+"/api/ready", backgroundReadyTimeout) if err != nil { _ = child.Process.Kill() _ = os.Remove(pidPath) @@ -165,7 +181,7 @@ func runBackground(cmd *cobra.Command, host, port, dataDir string, openBrowser b Mode: "background", InitializedData: initializedData, Resources: readyResources(ready), - PID: intPtr(child.Process.Pid), + PID: new(child.Process.Pid), BindAddr: addr, Startup: time.Since(started), DBPath: localserver.DatabasePath(dataDir), @@ -189,28 +205,28 @@ type serveStatus struct { func printServeInfo(out io.Writer, url string, status serveStatus) { cfgPath, _ := workspace.ConfigPath() - _, _ = fmt.Fprintf(out, "Mode: %s\n", printableMode(status.Mode)) + term.Label(out, 20, "Mode", printableMode(status.Mode)) if status.PID != nil { - _, _ = fmt.Fprintf(out, "PID: %d\n", *status.PID) + term.Label(out, 20, "PID", fmt.Sprintf("%d", *status.PID)) } - _, _ = fmt.Fprintf(out, "Server status: %s\n", dataStatus(status.InitializedData)) - _, _ = fmt.Fprintf(out, "Bind address: %s\n", status.BindAddr) + term.Label(out, 20, "Server status", dataStatus(status.InitializedData)) + term.Label(out, 20, "Bind address", status.BindAddr) if !status.InitializedData { - _, _ = fmt.Fprintf(out, "Resource counts: %d views, %d elements, %d connectors\n", status.Resources.Views, status.Resources.Elements, status.Resources.Connectors) + term.Label(out, 20, "Resource counts", fmt.Sprintf("%d views, %d elements, %d connectors", status.Resources.Views, status.Resources.Elements, status.Resources.Connectors)) } if status.Startup > 0 { - _, _ = fmt.Fprintf(out, "Ready in: %s\n", status.Startup.Round(time.Millisecond)) + term.Label(out, 20, "Ready in", status.Startup.Round(time.Millisecond).String()) } - _, _ = fmt.Fprintf(out, "DB: %s\n", styledLocalPath(out, status.DBPath)) + term.Label(out, 20, "DB", term.Path(out, status.DBPath)) if info, err := os.Stat(status.DBPath); err == nil { - _, _ = fmt.Fprintf(out, "DB size: %s\n", 
humanBytes(info.Size())) - _, _ = fmt.Fprintf(out, "DB last modified: %s\n", info.ModTime().Format(time.RFC3339)) - } - _, _ = fmt.Fprintf(out, "Config path: %s\n", styledLocalPath(out, cfgPath)) - _, _ = fmt.Fprintln(out) - _, _ = fmt.Fprintf(out, "tlDiagram available at: %s\n", styledWebappURL(out, url)) - _, _ = fmt.Fprintln(out) - _, _ = fmt.Fprintln(out, "Run 'tld stop' to shut down the server") + term.Label(out, 20, "DB size", humanBytes(info.Size())) + term.Label(out, 20, "DB last modified", info.ModTime().Format(time.RFC3339)) + } + term.Label(out, 20, "Config path", term.Path(out, cfgPath)) + term.Separator(out) + _, _ = fmt.Fprintf(out, " tlDiagram available at: %s\n", term.URL(out, url)) + term.Separator(out) + term.Hint(out, "Run 'tld stop' to shut down the server") } func databaseWillBeInitialized(dataDir string) bool { @@ -232,14 +248,6 @@ func printableMode(mode string) string { return mode } -func styledLocalPath(out io.Writer, path string) string { - return formatLocalPath(path, term.IsColorEnabled(out)) -} - -func styledWebappURL(out io.Writer, url string) string { - return formatWebappURL(url, term.IsColorEnabled(out)) -} - func formatLocalPath(path string, colorEnabled bool) string { if !colorEnabled { return path @@ -267,8 +275,6 @@ func humanBytes(size int64) string { return fmt.Sprintf("%.1f %cB", float64(size)/float64(div), "KMGTPE"[exp]) } -func intPtr(v int) *int { return &v } - type readyInfo struct { OK bool `json:"ok"` Resources struct { @@ -302,7 +308,7 @@ func waitReady(url string, timeout time.Duration) (*readyInfo, error) { } func getReady(url string) (*readyInfo, error) { - client := &http.Client{Timeout: 2 * time.Second} + client := &http.Client{Timeout: readyRequestTimeout} resp, err := client.Get(url) if err != nil { return nil, err @@ -318,17 +324,9 @@ func getReady(url string) (*readyInfo, error) { return &ready, nil } -func resolveServeOptions(flagHost, flagPort string) localserver.ServeOptions { - cfg, _ := workspace.LoadGlobalConfig() - host := cfg.Serve.Host - port := cfg.Serve.Port - if flagHost != "" { - host = flagHost - } - if flagPort != "" { - port = flagPort - } - return localserver.ServeOptions{Host: host, Port: port} +func resolveServeOptions(cfg *workspace.Config, flagHost, flagPort string) localserver.ServeOptions { + serve := workspace.ResolveServeOptions(cfg, flagHost, flagPort) + return localserver.ServeOptions{Host: serve.Host, Port: serve.Port} } func NewServeCmd(runE func(*cobra.Command, []string) error) *cobra.Command { diff --git a/cmd/serve/serve_internal_test.go b/cmd/serve/serve_internal_test.go index 82842bd..5845cc0 100644 --- a/cmd/serve/serve_internal_test.go +++ b/cmd/serve/serve_internal_test.go @@ -23,26 +23,26 @@ func TestPrintServeInfoIncludesCoreFields(t *testing.T) { }) got := out.String() - if !strings.Contains(got, "Mode: foreground") { + if !strings.Contains(got, "Mode:") || !strings.Contains(got, "foreground") { t.Fatalf("missing mode in output: %q", got) } - if !strings.Contains(got, "Server status: initialized new local data") { + if !strings.Contains(got, "Server status:") || !strings.Contains(got, "initialized new local data") { t.Fatalf("missing initialized data status in output: %q", got) } - if !strings.Contains(got, "Bind address: 127.0.0.1:8060") { + if !strings.Contains(got, "Bind address:") || !strings.Contains(got, "127.0.0.1:8060") { t.Fatalf("missing bind address in output: %q", got) } - if !strings.Contains(got, "DB: "+dbPath) { + if !strings.Contains(got, "DB:") || !strings.Contains(got, dbPath) 
{ t.Fatalf("missing database path in output: %q", got) } if strings.Contains(got, "Data path:") { t.Fatalf("data path should not be printed anymore: %q", got) } - if !strings.Contains(got, "\n\ntlDiagram available at:") { - t.Fatalf("webapp line should be preceded by a blank line: %q", got) + if !strings.Contains(got, "tlDiagram available at:") { + t.Fatalf("webapp line should be present: %q", got) } - if !strings.Contains(got, "tlDiagram available at: http://127.0.0.1:8060\n\nRun 'tld stop'") { - t.Fatalf("webapp line should be followed by a blank line before stop hint: %q", got) + if !strings.Contains(got, "http://127.0.0.1:8060") { + t.Fatalf("webapp url should be present: %q", got) } } @@ -61,13 +61,13 @@ func TestPrintServeInfoIncludesExistingDataCounts(t *testing.T) { }) got := out.String() - if !strings.Contains(got, "Mode: background") { + if !strings.Contains(got, "Mode:") || !strings.Contains(got, "background") { t.Fatalf("missing background mode in output: %q", got) } - if !strings.Contains(got, "Server status: using existing local data") { + if !strings.Contains(got, "Server status:") || !strings.Contains(got, "using existing local data") { t.Fatalf("missing existing data status in output: %q", got) } - if !strings.Contains(got, "Resource counts: 2 views, 7 elements, 3 connectors") { + if !strings.Contains(got, "Resource counts:") || !strings.Contains(got, "2 views, 7 elements, 3 connectors") { t.Fatalf("missing resource counts in output: %q", got) } if !strings.Contains(got, "DB:") { @@ -89,7 +89,7 @@ func TestPrintServeInfoShowsDBPathWhileInitializing(t *testing.T) { DBPath: dbPath, }) - if !strings.Contains(out.String(), "DB: "+dbPath) { + if !strings.Contains(out.String(), "DB:") || !strings.Contains(out.String(), dbPath) { t.Fatalf("expected db path even when initializing, got %q", out.String()) } if strings.Contains(out.String(), "DB size:") { diff --git a/cmd/status/status.go b/cmd/status/status.go index a1969c1..263c389 100644 --- a/cmd/status/status.go +++ b/cmd/status/status.go @@ -69,33 +69,33 @@ any drift from manual changes in the frontend.`, return cmdutil.WriteJSON(cmd.OutOrStdout(), cmd.Root().PersistentFlags().Lookup("compact").Value.String() == "true", cmdutil.BuildStatusJSON(lockFile, localModified, serverDrift, conflicts, respOrNil(serverResp))) } printStatusHeader(cmd.OutOrStdout(), localModified, serverDrift, conflicts) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Last sync: %s\n", lockFile.LastApply.Format(time.RFC3339)) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Applied by: %s\n", lockFile.AppliedBy) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Version: %s\n", lockFile.VersionID) + term.Label(cmd.OutOrStdout(), 15, "Last sync", lockFile.LastApply.Format(time.RFC3339)) + term.Label(cmd.OutOrStdout(), 15, "Applied by", lockFile.AppliedBy) + term.Label(cmd.OutOrStdout(), 15, "Version", lockFile.VersionID) if hashErr == nil { if localModified { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Local changes: Modified") + term.Label(cmd.OutOrStdout(), 15, "Local changes", "Modified") } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Local changes: Clean") + term.Label(cmd.OutOrStdout(), 15, "Local changes", "Clean") } } if conflicts > 0 { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Merge conflicts: %d found (run 'tld diff' to review)\n", conflicts) + term.Label(cmd.OutOrStdout(), 15, "Merge conflicts", fmt.Sprintf("%d found (run 'tld diff' to review)", conflicts)) } if checkServer { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "\nChecking server drift...") + 
term.Infof(cmd.OutOrStdout(), "Checking server drift...") if len(serverResp.Msg.Drift) == 0 && len(serverResp.Msg.Conflicts) == 0 { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Server state: In sync") + term.Label(cmd.OutOrStdout(), 15, "Server state", "In sync") } else { if len(serverResp.Msg.Drift) > 0 { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server state: %d drift items found (run 'tld pull' to sync)\n", len(serverResp.Msg.Drift)) + term.Label(cmd.OutOrStdout(), 15, "Server state", fmt.Sprintf("%d drift items found (run 'tld pull' to sync)", len(serverResp.Msg.Drift))) for _, d := range serverResp.Msg.Drift { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), " - %s: %s (%s)\n", d.ResourceType, d.Ref, d.Reason) + _, _ = fmt.Fprintf(cmd.OutOrStdout(), " - %s: %s (%s)\n", d.ResourceType, d.Ref, d.Reason) } } if len(serverResp.Msg.Conflicts) > 0 { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server state: %d conflicts found (run 'tld pull' or 'tld apply' to resolve)\n", len(serverResp.Msg.Conflicts)) + term.Label(cmd.OutOrStdout(), 15, "Server state", fmt.Sprintf("%d conflicts found (run 'tld pull' or 'tld apply' to resolve)", len(serverResp.Msg.Conflicts))) } } } @@ -103,7 +103,7 @@ any drift from manual changes in the frontend.`, if cmdutil.WantsJSON(cmd.Root().PersistentFlags().Lookup("format").Value.String()) { return cmdutil.WriteJSON(cmd.OutOrStdout(), cmd.Root().PersistentFlags().Lookup("compact").Value.String() == "true", cmdutil.BuildStatusJSON(nil, false, false, 0, nil)) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "No sync history found.") + term.Info(cmd.OutOrStdout(), "No sync history found.") } return nil @@ -122,22 +122,22 @@ func respOrNil(resp *connect.Response[diagv1.ApplyPlanResponse]) *diagv1.ApplyPl } func printStatusHeader(out interface{ Write([]byte) (int, error) }, localModified, serverDrift bool, conflicts int) { - message := "* IN SYNC - workspace matches last applied state" - color := term.ColorGreen - - if serverDrift { - message = "* DRIFTED - server has changes not in YAML (run tld pull)" - color = term.ColorRed - } else if localModified || conflicts > 0 { + switch { + case serverDrift: + _, _ = fmt.Fprintln(out, term.Colorize(out, term.ColorRed, "✗ DRIFTED"), + " server has changes not in YAML (run tld pull)") + case localModified || conflicts > 0: if conflicts > 0 { - message = fmt.Sprintf("* MODIFIED - %d merge conflicts (run tld diff to review)", conflicts) + _, _ = fmt.Fprintf(out, "%s %d merge conflicts (run tld diff to review)\n", + term.Colorize(out, term.ColorYellow, "! MODIFIED"), conflicts) } else { - message = "* MODIFIED - local changes not pushed (run tld apply)" + _, _ = fmt.Fprintln(out, term.Colorize(out, term.ColorYellow, "! 
MODIFIED"), + " local changes not pushed (run tld apply)") } - color = term.ColorYellow + default: + _, _ = fmt.Fprintln(out, term.Colorize(out, term.ColorGreen, "✓ IN SYNC"), + " workspace matches last applied state") } - - _, _ = fmt.Fprintln(out, term.Colorize(out, color, message)) } func countWorkspaceConflicts(ws *workspace.Workspace) int { diff --git a/cmd/status/status_test.go b/cmd/status/status_test.go index d1a1d47..ba7ec97 100644 --- a/cmd/status/status_test.go +++ b/cmd/status/status_test.go @@ -41,7 +41,7 @@ func TestStatusCmd_Clean(t *testing.T) { if !strings.Contains(stdout, "IN SYNC") { t.Fatalf("missing IN SYNC header:\n%s", stdout) } - if !strings.Contains(stdout, "Local changes: Clean") { + if !strings.Contains(stdout, "Local changes:") { t.Fatalf("missing clean status detail:\n%s", stdout) } } @@ -70,7 +70,7 @@ func TestStatusCmd_Modified(t *testing.T) { if !strings.Contains(stdout, "MODIFIED") { t.Fatalf("missing MODIFIED header:\n%s", stdout) } - if !strings.Contains(stdout, "Local changes: Modified") { + if !strings.Contains(stdout, "Local changes:") { t.Fatalf("missing modified detail:\n%s", stdout) } } @@ -144,7 +144,7 @@ func TestStatusCmd_ConflictCount(t *testing.T) { if err != nil { t.Fatalf("status: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if !strings.Contains(stdout, "Merge conflicts: 2") { + if !strings.Contains(stdout, "Merge conflicts:") || !strings.Contains(stdout, "2") { t.Fatalf("missing conflict count: %s", stdout) } } @@ -167,7 +167,7 @@ func TestStatusCmd_CheckServer_InSync(t *testing.T) { if err != nil { t.Fatalf("status --check-server: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) } - if !strings.Contains(stdout, "Server state: In sync") { + if !strings.Contains(stdout, "Server state:") || !strings.Contains(stdout, "In sync") { t.Fatalf("missing in-sync server output: %s", stdout) } } diff --git a/cmd/stop/stop.go b/cmd/stop/stop.go index 9ae4c2f..38384b5 100644 --- a/cmd/stop/stop.go +++ b/cmd/stop/stop.go @@ -7,6 +7,7 @@ import ( "time" "github.com/mertcikla/tld/internal/localserver" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -23,7 +24,10 @@ Sends SIGTERM and waits up to 10 seconds for a graceful shutdown. 
Use --kill to send SIGKILL immediately.`, RunE: func(cmd *cobra.Command, _ []string) error { dataDirFlag, _ := cmd.Flags().GetString("data-dir") - cfg, _ := workspace.LoadGlobalConfig() + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) if err != nil { return err @@ -59,7 +63,7 @@ func runStop(cmd *cobra.Command, forceKill bool, dataDir string) error { return fmt.Errorf("kill: %w", err) } _ = os.Remove(pidPath) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Server killed.") + term.Success(cmd.OutOrStdout(), "Server killed.") return nil } @@ -71,7 +75,7 @@ func runStop(cmd *cobra.Command, forceKill bool, dataDir string) error { for time.Now().Before(deadline) { if !localserver.IsRunning(pid) { _ = os.Remove(pidPath) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Server stopped.") + term.Success(cmd.OutOrStdout(), "Server stopped.") return nil } time.Sleep(200 * time.Millisecond) diff --git a/cmd/stop/stop_test.go b/cmd/stop/stop_test.go new file mode 100644 index 0000000..351d949 --- /dev/null +++ b/cmd/stop/stop_test.go @@ -0,0 +1,22 @@ +package stop_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/mertcikla/tld/cmd/stop" +) + +func TestStopCmdReportsNoServerRunningForEmptyDataDir(t *testing.T) { + cmd := stop.NewStopCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--data-dir", t.TempDir()}) + + err := cmd.Execute() + if err == nil || !strings.Contains(err.Error(), "no server running") { + t.Fatalf("err = %v, want no server running", err) + } +} diff --git a/cmd/update/update.go b/cmd/update/update.go index b4f4a8b..9a672c4 100644 --- a/cmd/update/update.go +++ b/cmd/update/update.go @@ -5,6 +5,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/completion" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -53,8 +54,8 @@ func newElementCmd(wdir, format *string, compact *bool) *cobra.Command { if cmdutil.WantsJSON(*format) { return cmdutil.WriteMutation(cmd.OutOrStdout(), *compact, "update element", "update", ref) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Updated element %q: %s=%q\n", ref, field, value) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally in elements.yaml. Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Updated element %q: %s=%q", ref, field, value) + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") return nil }, } @@ -91,8 +92,8 @@ func newConnectorCmd(wdir, format *string, compact *bool) *cobra.Command { if cmdutil.WantsJSON(*format) { return cmdutil.WriteMutation(cmd.OutOrStdout(), *compact, "update connector", "update", ref) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Updated connector %q: %s=%q\n", ref, field, value) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Change recorded locally in connectors.yaml. 
Run 'tld apply' to push to cloud.") + term.Successf(cmd.OutOrStdout(), "Updated connector %q: %s=%q", ref, field, value) + term.Hint(cmd.OutOrStdout(), "Run 'tld apply' to push to cloud.") return nil }, } diff --git a/cmd/update/update_test.go b/cmd/update/update_test.go new file mode 100644 index 0000000..e758631 --- /dev/null +++ b/cmd/update/update_test.go @@ -0,0 +1,61 @@ +package update_test + +import ( + "strings" + "testing" + + "github.com/mertcikla/tld/cmd" + "github.com/mertcikla/tld/internal/workspace" +) + +func TestUpdateElementCmdUpdatesScalarField(t *testing.T) { + dir := t.TempDir() + cmd.MustInitWorkspace(t, dir) + cmd.SeedElementWorkspace(t, dir) + + stdout, stderr, err := cmd.RunCmd(t, dir, "update", "element", "api", "description", "Handles traffic") + if err != nil { + t.Fatalf("update element: %v\nstdout:%s\nstderr:%s", err, stdout, stderr) + } + if !strings.Contains(stdout, "Updated element") { + t.Fatalf("stdout = %q, want update confirmation", stdout) + } + ws, err := workspace.Load(dir) + if err != nil { + t.Fatal(err) + } + if got := ws.Elements["api"].Description; got != "Handles traffic" { + t.Fatalf("description = %q, want Handles traffic", got) + } +} + +func TestUpdateConnectorCmdUpdatesDirection(t *testing.T) { + dir := t.TempDir() + cmd.MustInitWorkspace(t, dir) + cmd.SeedElementWorkspace(t, dir) + + stdout, stderr, err := cmd.RunCmd(t, dir, "update", "connector", "platform:api:db:reads", "direction", "bidirectional") + if err != nil { + t.Fatalf("update connector: %v\nstdout:%s\nstderr:%s", err, stdout, stderr) + } + if !strings.Contains(stdout, "Updated connector") { + t.Fatalf("stdout = %q, want update confirmation", stdout) + } + ws, err := workspace.Load(dir) + if err != nil { + t.Fatal(err) + } + if got := ws.Connectors["platform:api:db:reads"].Direction; got != "bidirectional" { + t.Fatalf("direction = %q, want bidirectional", got) + } +} + +func TestUpdateCmdShowsHelpWithNoSubcommand(t *testing.T) { + stdout, stderr, err := cmd.RunCmd(t, t.TempDir(), "update") + if err != nil { + t.Fatalf("update help: %v\nstdout:%s\nstderr:%s", err, stdout, stderr) + } + if !strings.Contains(stdout, "Update a resource field") { + t.Fatalf("stdout = %q, want update help", stdout) + } +} diff --git a/cmd/validate/validate.go b/cmd/validate/validate.go index 11aa1ea..07540de 100644 --- a/cmd/validate/validate.go +++ b/cmd/validate/validate.go @@ -5,6 +5,7 @@ import ( "github.com/mertcikla/tld/internal/cmdutil" "github.com/mertcikla/tld/internal/planner" + "github.com/mertcikla/tld/internal/term" "github.com/mertcikla/tld/internal/workspace" "github.com/spf13/cobra" ) @@ -26,45 +27,40 @@ func NewValidateCmd(wdir *string) *cobra.Command { // Override strictness if flag is set if strictness > 0 { - if ws.Config.Validation == nil { - ws.Config.Validation = &workspace.ValidationConfig{} - } ws.Config.Validation.Level = strictness } errs := ws.Validate() if len(errs) > 0 { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Validation errors:") + term.Fail(cmd.ErrOrStderr(), "Validation errors:") for _, e := range errs { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", e) + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", e) } return fmt.Errorf("%d validation error(s)", len(errs)) } broken := cmdutil.CheckSymbols(cmd.Context(), ws, repoCtx, rules) if len(broken) > 0 { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Symbol verification errors:") + term.Fail(cmd.ErrOrStderr(), "Symbol verification errors:") for _, msg := range broken { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", 
msg) + _, _ = fmt.Fprintf(cmd.ErrOrStderr(), " - %s\n", msg) } return fmt.Errorf("%d symbol verification error(s)", len(broken)) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Symbol verification: passed") + term.Success(cmd.OutOrStdout(), "Symbol verification passed") if len(ws.Elements) > 0 || len(ws.Connectors) > 0 { diagramCount := cmdutil.CountElementDiagrams(ws) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Workspace valid: %d elements, %d diagrams, %d connectors\n", - len(ws.Elements), diagramCount, len(ws.Connectors)) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Element workspace: %d elements, %d diagrams, %d connectors\n", + term.Successf(cmd.OutOrStdout(), "Workspace valid: %d elements, %d diagrams, %d connectors", len(ws.Elements), diagramCount, len(ws.Connectors)) } // Evaluate Architectural warnings warnings := planner.AnalyzePlan(ws) if len(warnings) > 0 { - level := workspace.DefaultValidationLevel - if ws.Config.Validation != nil && ws.Config.Validation.Level > 0 { - level = ws.Config.Validation.Level + level := ws.Config.Validation.Level + if level == 0 { + level = workspace.DefaultValidationLevel } levelNames := map[int]string{1: "Minimal", 2: "Standard", 3: "Strict"} _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\n## Architectural Warnings (Level %d: %s)\n\n", level, levelNames[level]) diff --git a/cmd/validate/validate_test.go b/cmd/validate/validate_test.go index f19239f..5d3df93 100644 --- a/cmd/validate/validate_test.go +++ b/cmd/validate/validate_test.go @@ -23,7 +23,7 @@ func TestValidateCmd_ValidWorkspace(t *testing.T) { if !strings.Contains(stdout, "Workspace valid") { t.Errorf("stdout %q does not contain 'Workspace valid'", stdout) } - if !strings.Contains(stdout, "Element workspace: 1 elements, 1 diagrams, 0 connectors") { + if !strings.Contains(stdout, "1 elements") || !strings.Contains(stdout, "1 diagrams") { t.Errorf("stdout %q does not contain count summary", stdout) } } @@ -43,15 +43,3 @@ func TestValidateCmd_InvalidWorkspace(t *testing.T) { t.Errorf("stderr %q does not contain 'Validation errors'", stderr) } } - -func TestValidateCmd_MissingConfig(t *testing.T) { - dir := t.TempDir() - // No .tld.yaml - _, _, err := cmd.RunCmd(t, dir, "validate") - if err == nil { - t.Fatal("expected error for missing config") - } - if !strings.Contains(err.Error(), "load workspace") { - t.Errorf("error %q does not contain 'load workspace'", err.Error()) - } -} diff --git a/cmd/version/version.go b/cmd/version/version.go index ca167c0..c07fb46 100644 --- a/cmd/version/version.go +++ b/cmd/version/version.go @@ -1,8 +1,7 @@ package version import ( - "fmt" - + "github.com/mertcikla/tld/internal/term" "github.com/spf13/cobra" ) @@ -15,7 +14,7 @@ func NewVersionCmd() *cobra.Command { Use: "version", Short: "Print the version number of tld", Run: func(cmd *cobra.Command, _ []string) { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "tld version %s\n", Version) + term.Infof(cmd.OutOrStdout(), "tld version %s", Version) }, } } diff --git a/cmd/version/version_test.go b/cmd/version/version_test.go new file mode 100644 index 0000000..0e92351 --- /dev/null +++ b/cmd/version/version_test.go @@ -0,0 +1,24 @@ +package version_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/mertcikla/tld/cmd/version" +) + +func TestVersionCmdPrintsCurrentVersion(t *testing.T) { + cmd := version.NewVersionCmd() + var out bytes.Buffer + cmd.SetOut(&out) + + err := cmd.Execute() + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if got := out.String(); !strings.Contains(got, "tld version 
"+version.Version) { + t.Fatalf("stdout = %q, want current version", got) + } +} diff --git a/cmd/watch/watch.go b/cmd/watch/watch.go new file mode 100644 index 0000000..fd31736 --- /dev/null +++ b/cmd/watch/watch.go @@ -0,0 +1,920 @@ +package watch + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + assets "github.com/mertcikla/tld" + "github.com/mertcikla/tld/internal/cmdutil" + "github.com/mertcikla/tld/internal/localserver" + "github.com/mertcikla/tld/internal/store" + "github.com/mertcikla/tld/internal/term" + "github.com/mertcikla/tld/internal/watch" + "github.com/mertcikla/tld/internal/workspace" + "github.com/schollz/progressbar/v3" + "github.com/spf13/cobra" +) + +func NewWatchCmd() *cobra.Command { + var host, port, dataDirFlag string + var embeddingProvider, embeddingEndpoint, embeddingModel string + var embeddingDimension int + var languageFlags []string + var watcherMode, pollInterval, debounce string + var maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup int + var noServe, openBrowser, rescan, verbose, dryRun, failOnDrift bool + c := &cobra.Command{ + Use: "watch [path]", + Short: "Scan and materialize source repositories into the local workspace", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := "." + if len(args) > 0 { + path = args[0] + } + if dryRun { + return runWatchDiff(cmd, path, watchDiffOptions{ + DataDirFlag: dataDirFlag, + EmbeddingProvider: embeddingProvider, + EmbeddingEndpoint: embeddingEndpoint, + EmbeddingModel: embeddingModel, + EmbeddingDimension: embeddingDimension, + LanguageFlags: languageFlags, + MaxElements: maxElements, + MaxConnectors: maxConnectors, + MaxIncoming: maxIncoming, + MaxOutgoing: maxOutgoing, + MaxExpandedGroup: maxExpandedGroup, + Rescan: rescan, + FailOnDrift: failOnDrift, + GroupDiffs: true, + }) + } + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } + dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) + if err != nil { + return err + } + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return fmt.Errorf("create data dir: %w", err) + } + embeddingCfg := resolveEmbeddingConfig(cfg, embeddingProvider, embeddingEndpoint, embeddingModel, embeddingDimension) + watchSettings := resolveWatchSettings(cfg, languageFlags, watcherMode, pollInterval, debounce, maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup) + term.Infof(cmd.OutOrStdout(), "watch booting: data=%s embeddings=%s/%s", term.Path(cmd.OutOrStdout(), dataDir), embeddingCfg.Provider, embeddingCfg.Model) + progress := newCLIProgress(cmd.ErrOrStderr()) + if embeddingCfg.Provider != "none" { + term.Infof(cmd.OutOrStdout(), "embedding healthcheck: %s %s", embeddingCfg.Endpoint, embeddingCfg.Model) + checked, health, err := watch.CheckEmbeddingHealth(cmd.Context(), embeddingCfg) + if err != nil { + return fmt.Errorf("embedding healthcheck failed: %w", err) + } + embeddingCfg = checked + term.Successf(cmd.OutOrStdout(), "embedding healthcheck ok: dimension=%d similarity=%.3f", health.Dimension, health.Similarity) + } + serveCfg := workspace.ResolveServeOptions(cfg, host, port) + serveOpts := localserver.ServeOptions{Host: serveCfg.Host, Port: serveCfg.Port} + addr := localserver.ResolveAddr(serveOpts) + url := "http://" + addr + var srv *http.Server + if !noServe { + if !serverReady(url) { + term.Infof(cmd.OutOrStdout(), "server booting: %s", url) + app, err := 
localserver.Bootstrap(dataDir, serveOpts) + if err != nil { + return err + } + srv = &http.Server{Addr: app.Addr, Handler: app.Handler} + go func() { + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + term.Failf(cmd.ErrOrStderr(), "server error: %v", err) + } + }() + url = "http://" + app.Addr + } + term.Successf(cmd.OutOrStdout(), "server ready: %s", term.URL(cmd.OutOrStdout(), url)) + if openBrowser { + _ = cmdutil.OpenBrowser(url) + } + } + defer func() { + if srv != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(ctx) + } + }() + + sqliteStore, err := store.Open(localserver.DatabasePath(dataDir), assets.FS) + if err != nil { + return err + } + watchStore := watch.NewStore(sqliteStore.DB()) + ctx, stop := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + events := make(chan watch.Event, 16) + ready := make(chan watch.RunnerResult, 1) + watchProgress := newWatchActivityProgress(cmd.ErrOrStderr(), watchClientCounter(url)) + defer func() { + if watchProgress != nil { + watchProgress.Stop() + } + }() + go func() { + for event := range events { + if logWatchEvent(cmd, event, watchProgress) { + continue + } + if verbose || event.Type == "watch.error" || event.Type == "version.created" { + if event.Message != "" { + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s: %s\n", event.Type, event.Message) + } else { + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", event.Type) + } + } + } + }() + errCh := make(chan error, 1) + go func() { + _, runErr := watch.NewRunner(watchStore).Run(ctx, watch.RunnerOptions{Path: path, Rescan: rescan, Verbose: verbose, Embedding: embeddingCfg, Settings: watchSettings, Progress: progress, Events: events, Ready: ready}) + errCh <- runErr + close(events) + }() + var result watch.RunnerResult + select { + case result = <-ready: + case err := <-errCh: + if err != nil { + return err + } + return nil + } + repo := result.Repository + term.Separator(cmd.OutOrStdout()) + term.Label(cmd.OutOrStdout(), 20, "Watching", repo.RepoRoot) + term.Label(cmd.OutOrStdout(), 20, "Repository", repoIdentity(repo)) + term.Label(cmd.OutOrStdout(), 20, "Branch", result.GitStatus.Branch) + term.Label(cmd.OutOrStdout(), 20, "HEAD", result.GitStatus.HeadCommit) + term.Label(cmd.OutOrStdout(), 20, "Mode", "watch") + term.Label(cmd.OutOrStdout(), 20, "tlDiagram available at", term.URL(cmd.OutOrStdout(), url)) + term.Separator(cmd.OutOrStdout()) + term.Hint(cmd.OutOrStdout(), "Press Ctrl-C to stop watching.") + if err := <-errCh; err != nil { + return err + } + return nil + }, + } + c.Flags().StringVar(&host, "host", "", "host for the local app server") + c.Flags().StringVar(&port, "port", "", "port for the local app server") + c.Flags().StringVar(&dataDirFlag, "data-dir", "", "directory for the local app database") + c.Flags().BoolVar(&noServe, "no-serve", false, "do not start the local app server") + c.Flags().BoolVar(&openBrowser, "open", false, "open the webapp in a browser") + c.Flags().BoolVar(&dryRun, "dry-run", false, "scan, materialize, print frontend-equivalent watch diffs as JSON, and exit") + c.Flags().StringVar(&embeddingProvider, "embedding-provider", "", "embedding provider for representation") + c.Flags().StringVar(&embeddingEndpoint, "embedding-endpoint", "", "embedding endpoint for representation") + c.Flags().StringVar(&embeddingModel, "embedding-model", "", "embedding model for representation") + c.Flags().IntVar(&embeddingDimension, "embedding-dimension", 
0, "embedding vector dimension") + c.Flags().StringSliceVar(&languageFlags, "language", nil, "source language to watch (repeatable)") + c.Flags().StringVar(&watcherMode, "watcher", "", "watcher backend: auto, fsnotify, or poll") + c.Flags().StringVar(&pollInterval, "poll-interval", "", "poll interval (for example 1s)") + c.Flags().StringVar(&debounce, "debounce", "", "change debounce duration (for example 500ms)") + c.Flags().IntVar(&maxElements, "max-elements-per-view", 0, "maximum generated elements per view") + c.Flags().IntVar(&maxConnectors, "max-connectors-per-view", 0, "maximum generated connectors per view") + c.Flags().IntVar(&maxIncoming, "max-incoming-per-element", 0, "maximum incoming references per element before collapsing") + c.Flags().IntVar(&maxOutgoing, "max-outgoing-per-element", 0, "maximum outgoing references per element before collapsing") + c.Flags().IntVar(&maxExpandedGroup, "max-expanded-connectors-per-group", 0, "maximum file-pair connectors to expand before collapsing to a folder connector") + c.Flags().BoolVar(&rescan, "rescan", false, "force a rescan before watching") + c.Flags().BoolVar(&failOnDrift, "fail-on-drift", false, "with --dry-run, exit nonzero when representation drift is detected") + c.Flags().BoolVar(&verbose, "verbose", false, "print watch events") + c.AddCommand(newScanCmd()) + c.AddCommand(newRepresentCmd()) + c.AddCommand(newDiffCmd()) + return c +} + +func logWatchEvent(cmd *cobra.Command, event watch.Event, activity *watchActivityProgress) bool { + out := cmd.OutOrStdout() + switch event.Type { + case "watch.started": + if activity != nil { + activity.Start("watching for changes") + } + _, _ = fmt.Fprintf(out, "%s started\n", term.Colorize(out, term.ColorGreen, "watch")) + return true + case "watch.stopped": + if activity != nil { + activity.Stop() + } + _, _ = fmt.Fprintf(out, "%s stopped\n", term.Colorize(out, term.ColorYellow, "watch")) + return true + case "scan.started": + _, _ = fmt.Fprintf(out, "%s scanning source graph\n", term.Colorize(out, term.ColorBlue, "watch")) + return true + case "scan.completed": + if scan, ok := event.Data.(watch.ScanResult); ok { + _, _ = fmt.Fprintf(out, "%s scan complete: %d files, %d parsed, %d skipped\n", term.Colorize(out, term.ColorGreen, "watch"), scan.FilesSeen, scan.FilesParsed, scan.FilesSkipped) + return true + } + return false + case "representation.started": + _, _ = fmt.Fprintf(out, "%s materializing representation\n", term.Colorize(out, term.ColorBlue, "watch")) + return true + case "representation.updated": + if rep, ok := event.Data.(watch.RepresentResult); ok { + _, _ = fmt.Fprintf(out, "%s representation updated: elements +%d/%d, connectors +%d/%d, embeddings +%d/%d cached\n", + term.Colorize(out, term.ColorGreen, "watch"), + rep.ElementsCreated, rep.ElementsUpdated, + rep.ConnectorsCreated, rep.ConnectorsUpdated, + rep.EmbeddingsCreated, rep.EmbeddingCacheHits) + return true + } + return false + case "source.changed": + result, ok := event.Data.(watch.SourceFileChangeResult) + if !ok { + return false + } + if activity != nil { + activity.Advance("") + } + status := term.Colorize(out, term.ColorYellow, "no representation update") + if result.RepresentationChanged { + status = term.Colorize(out, term.ColorGreen, "representation updated") + } + _, _ = fmt.Fprintf(out, "%s %s %s: %s (%s)\n", + term.Colorize(out, term.ColorBlue, "source"), + term.Colorize(out, term.ColorYellow, result.Change.ChangeType), + result.Change.Path, + status, + representationChangeSummary(result.Representation, 
result.GitTags), + ) + return true + case "watch.changeCounter": + counter, ok := event.Data.(watch.ChangeCounter) + if !ok { + return false + } + if activity != nil { + if counter.IntervalChangesProcessed > 0 { + activity.Advance(fmt.Sprintf("watching: %d total, %d in last minute", counter.TotalChangesProcessed, counter.IntervalChangesProcessed)) + } else { + activity.Advance(fmt.Sprintf("watching: %d total", counter.TotalChangesProcessed)) + } + return true + } + if counter.IntervalChangesProcessed > 0 { + _, _ = fmt.Fprintf(out, "%s changes processed: %d total, %d in the last minute\n", + term.Colorize(out, term.ColorBlue, "watch"), + counter.TotalChangesProcessed, + counter.IntervalChangesProcessed, + ) + } else { + _, _ = fmt.Fprintf(out, "%s changes processed: %d total\n", + term.Colorize(out, term.ColorBlue, "watch"), + counter.TotalChangesProcessed, + ) + } + return true + case "watch.error": + message := event.Message + if message == "" { + message = "unknown error" + } + _, _ = fmt.Fprintf(out, "%s %s\n", term.Colorize(out, term.ColorRed, "watch.error:"), message) + return true + case "version.created": + _, _ = fmt.Fprintf(out, "%s version created\n", term.Colorize(out, term.ColorGreen, "watch")) + return true + default: + return false + } +} + +func representationChangeSummary(rep watch.RepresentResult, tags watch.GitTagUpdateResult) string { + return fmt.Sprintf("elements +%d/%d, connectors +%d/%d, views +%d, tags +%d/-%d", + rep.ElementsCreated, + rep.ElementsUpdated, + rep.ConnectorsCreated, + rep.ConnectorsUpdated, + rep.ViewsCreated, + tags.TagsAdded, + tags.TagsRemoved, + ) +} + +func repoIdentity(repo watch.Repository) string { + if repo.RemoteURL.Valid && repo.RemoteURL.String != "" { + return repo.RemoteURL.String + } + return repo.RepoRoot +} + +func serverReady(url string) bool { + client := &http.Client{Timeout: 500 * time.Millisecond} + resp, err := client.Get(url + "/api/ready") + if err != nil { + return false + } + defer func() { _ = resp.Body.Close() }() + return resp.StatusCode == http.StatusOK +} + +func watchClientCounter(url string) func() int { + client := &http.Client{Timeout: 500 * time.Millisecond} + var mu sync.Mutex + var cached int + var checkedAt time.Time + return func() int { + mu.Lock() + defer mu.Unlock() + if time.Since(checkedAt) < time.Second { + return cached + } + checkedAt = time.Now() + resp, err := client.Get(url + "/api/watch/status") + if err != nil { + cached = watch.WatchWebSocketClientCount() + return cached + } + defer func() { _ = resp.Body.Close() }() + var status struct { + ConnectedClients int `json:"connected_clients"` + } + if resp.StatusCode != http.StatusOK || json.NewDecoder(resp.Body).Decode(&status) != nil { + cached = watch.WatchWebSocketClientCount() + return cached + } + cached = status.ConnectedClients + return cached + } +} + +func newScanCmd() *cobra.Command { + var dataDirFlag string + var languageFlags []string + var jsonOut, rescan bool + c := &cobra.Command{ + Use: "scan [path]", + Short: "Scan a repository into the local raw code graph", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := "." 
+ if len(args) > 0 { + path = args[0] + } + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } + dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) + if err != nil { + return err + } + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return fmt.Errorf("create data dir: %w", err) + } + watchSettings := resolveWatchSettings(cfg, languageFlags, "", "", "", 0, 0, 0, 0, 0) + sqliteStore, err := store.Open(localserver.DatabasePath(dataDir), assets.FS) + if err != nil { + return err + } + defer func() { _ = sqliteStore.DB().Close() }() + scanner := watch.NewScanner(watch.NewStore(sqliteStore.DB())) + scanner.Settings = watchSettings + scanner.Progress = newCLIProgress(cmd.ErrOrStderr()) + result, err := scanner.ScanWithOptions(cmd.Context(), path, watch.ScanOptions{Force: rescan}) + if err != nil { + return err + } + if jsonOut { + return json.NewEncoder(cmd.OutOrStdout()).Encode(result) + } + term.Label(cmd.OutOrStdout(), 15, "Repository", fmt.Sprintf("%d", result.RepositoryID)) + term.Label(cmd.OutOrStdout(), 15, "Scan run", fmt.Sprintf("%d", result.ScanRunID)) + term.Label(cmd.OutOrStdout(), 15, "Files", fmt.Sprintf("%d seen, %d parsed, %d skipped", result.FilesSeen, result.FilesParsed, result.FilesSkipped)) + term.Label(cmd.OutOrStdout(), 15, "Symbols", fmt.Sprintf("%d", result.SymbolsSeen)) + term.Label(cmd.OutOrStdout(), 15, "References", fmt.Sprintf("%d", result.ReferencesSeen)) + if result.Warning != "" { + term.Warn(cmd.OutOrStdout(), result.Warning) + } + return nil + }, + } + c.Flags().StringVar(&dataDirFlag, "data-dir", "", "directory for the local app database") + c.Flags().StringSliceVar(&languageFlags, "language", nil, "source language to scan (repeatable)") + c.Flags().BoolVar(&rescan, "rescan", false, "force reparsing files even if cached") + c.Flags().BoolVar(&jsonOut, "json", false, "print machine-readable JSON") + return c +} + +func newRepresentCmd() *cobra.Command { + var dataDirFlag string + var embeddingProvider, embeddingEndpoint, embeddingModel string + var embeddingDimension int + var languageFlags []string + var jsonOut, rescan bool + var maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup int + c := &cobra.Command{ + Use: "represent [path]", + Short: "Materialize a scanned repository into the local workspace", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := "." 
+ if len(args) > 0 { + path = args[0] + } + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } + dataDir, err := workspace.ResolveDataDir(cfg, dataDirFlag) + if err != nil { + return err + } + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return fmt.Errorf("create data dir: %w", err) + } + embeddingCfg := resolveEmbeddingConfig(cfg, embeddingProvider, embeddingEndpoint, embeddingModel, embeddingDimension) + watchSettings := resolveWatchSettings(cfg, languageFlags, "", "", "", maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup) + progress := newCLIProgress(cmd.ErrOrStderr()) + if embeddingCfg.Provider != "none" { + checked, health, err := watch.CheckEmbeddingHealth(cmd.Context(), embeddingCfg) + if err != nil { + return fmt.Errorf("embedding healthcheck failed: %w", err) + } + embeddingCfg = checked + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Embedding: %s/%s dimension=%d similarity=%.3f\n", embeddingCfg.Provider, embeddingCfg.Model, health.Dimension, health.Similarity) + } + sqliteStore, err := store.Open(localserver.DatabasePath(dataDir), assets.FS) + if err != nil { + return err + } + defer func() { _ = sqliteStore.DB().Close() }() + watchStore := watch.NewStore(sqliteStore.DB()) + scanner := watch.NewScanner(watchStore) + scanner.Settings = watchSettings + scanner.Progress = progress + scanResult, err := scanner.ScanWithOptions(cmd.Context(), path, watch.ScanOptions{Force: rescan}) + if err != nil { + return err + } + result, err := watch.NewRepresenter(watchStore).Represent(cmd.Context(), scanResult.RepositoryID, watch.RepresentRequest{Embedding: embeddingCfg, Thresholds: watchSettings.Thresholds, Visibility: watchSettings.Visibility, Progress: progress}) + if err != nil { + return err + } + if jsonOut { + return json.NewEncoder(cmd.OutOrStdout()).Encode(struct { + Scan watch.ScanResult `json:"scan"` + Representation watch.RepresentResult `json:"representation"` + }{Scan: scanResult, Representation: result}) + } + term.Label(cmd.OutOrStdout(), 18, "Repository", fmt.Sprintf("%d", result.RepositoryID)) + term.Label(cmd.OutOrStdout(), 18, "Scan run", fmt.Sprintf("%d", scanResult.ScanRunID)) + term.Label(cmd.OutOrStdout(), 18, "Filter run", fmt.Sprintf("%d", result.FilterRunID)) + term.Label(cmd.OutOrStdout(), 18, "Represent run", fmt.Sprintf("%d", result.RepresentationRun)) + term.Label(cmd.OutOrStdout(), 18, "Elements", fmt.Sprintf("%d created, %d updated", result.ElementsCreated, result.ElementsUpdated)) + term.Label(cmd.OutOrStdout(), 18, "Connectors", fmt.Sprintf("%d created, %d updated", result.ConnectorsCreated, result.ConnectorsUpdated)) + term.Label(cmd.OutOrStdout(), 18, "Views", fmt.Sprintf("%d created", result.ViewsCreated)) + term.Label(cmd.OutOrStdout(), 18, "Raw graph hash", result.RawGraphHash) + term.Label(cmd.OutOrStdout(), 18, "Representation", result.RepresentationHash) + return nil + }, + } + c.Flags().StringVar(&dataDirFlag, "data-dir", "", "directory for the local app database") + c.Flags().StringVar(&embeddingProvider, "embedding-provider", "", "embedding provider for representation") + c.Flags().StringVar(&embeddingEndpoint, "embedding-endpoint", "", "embedding endpoint for representation") + c.Flags().StringVar(&embeddingModel, "embedding-model", "", "embedding model for representation") + c.Flags().IntVar(&embeddingDimension, "embedding-dimension", 0, "embedding vector dimension") + c.Flags().StringSliceVar(&languageFlags, "language", nil, "source language to scan (repeatable)") + c.Flags().IntVar(&maxElements, 
"max-elements-per-view", 0, "maximum generated elements per view") + c.Flags().IntVar(&maxConnectors, "max-connectors-per-view", 0, "maximum generated connectors per view") + c.Flags().IntVar(&maxIncoming, "max-incoming-per-element", 0, "maximum incoming references per element before collapsing") + c.Flags().IntVar(&maxOutgoing, "max-outgoing-per-element", 0, "maximum outgoing references per element before collapsing") + c.Flags().IntVar(&maxExpandedGroup, "max-expanded-connectors-per-group", 0, "maximum file-pair connectors to expand before collapsing to a folder connector") + c.Flags().BoolVar(&rescan, "rescan", false, "force reparsing files even if cached") + c.Flags().BoolVar(&jsonOut, "json", false, "print machine-readable JSON") + return c +} + +func newDiffCmd() *cobra.Command { + var dataDirFlag string + var embeddingProvider, embeddingEndpoint, embeddingModel string + var embeddingDimension int + var languageFlags []string + var failOnDrift bool + var maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup int + c := &cobra.Command{ + Use: "diff [path]", + Short: "Scan and report watch representation drift as JSON", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + path := "." + if len(args) > 0 { + path = args[0] + } + return runWatchDiff(cmd, path, watchDiffOptions{ + DataDirFlag: dataDirFlag, + EmbeddingProvider: embeddingProvider, + EmbeddingEndpoint: embeddingEndpoint, + EmbeddingModel: embeddingModel, + EmbeddingDimension: embeddingDimension, + LanguageFlags: languageFlags, + MaxElements: maxElements, + MaxConnectors: maxConnectors, + MaxIncoming: maxIncoming, + MaxOutgoing: maxOutgoing, + MaxExpandedGroup: maxExpandedGroup, + FailOnDrift: failOnDrift, + }) + }, + } + c.Flags().StringVar(&dataDirFlag, "data-dir", "", "directory for the local app database") + c.Flags().StringVar(&embeddingProvider, "embedding-provider", "", "embedding provider for representation") + c.Flags().StringVar(&embeddingEndpoint, "embedding-endpoint", "", "embedding endpoint for representation") + c.Flags().StringVar(&embeddingModel, "embedding-model", "", "embedding model for representation") + c.Flags().IntVar(&embeddingDimension, "embedding-dimension", 0, "embedding vector dimension") + c.Flags().StringSliceVar(&languageFlags, "language", nil, "source language to scan (repeatable)") + c.Flags().IntVar(&maxElements, "max-elements-per-view", 0, "maximum generated elements per view") + c.Flags().IntVar(&maxConnectors, "max-connectors-per-view", 0, "maximum generated connectors per view") + c.Flags().IntVar(&maxIncoming, "max-incoming-per-element", 0, "maximum incoming references per element before collapsing") + c.Flags().IntVar(&maxOutgoing, "max-outgoing-per-element", 0, "maximum outgoing references per element before collapsing") + c.Flags().IntVar(&maxExpandedGroup, "max-expanded-connectors-per-group", 0, "maximum file-pair connectors to expand before collapsing to a folder connector") + c.Flags().BoolVar(&failOnDrift, "fail-on-drift", false, "exit nonzero when representation drift is detected") + return c +} + +type watchDiffOptions struct { + DataDirFlag string + EmbeddingProvider string + EmbeddingEndpoint string + EmbeddingModel string + EmbeddingDimension int + LanguageFlags []string + MaxElements int + MaxConnectors int + MaxIncoming int + MaxOutgoing int + MaxExpandedGroup int + Rescan bool + FailOnDrift bool + GroupDiffs bool +} + +type watchDiffPayload struct { + Changed bool `json:"changed"` + Scan watch.ScanResult `json:"scan"` + 
Representation watch.RepresentResult `json:"representation"` + Diffs []watch.RepresentationDiff `json:"diffs"` +} + +type watchGroupedDiffPayload struct { + Changed bool `json:"changed"` + Scan watch.ScanResult `json:"scan"` + Representation watch.RepresentResult `json:"representation"` + Diffs map[string]map[string][]watch.RepresentationDiff `json:"diffs"` +} + +func runWatchDiff(cmd *cobra.Command, path string, opts watchDiffOptions) error { + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + return err + } + dataDir, err := workspace.ResolveDataDir(cfg, opts.DataDirFlag) + if err != nil { + return err + } + if err := os.MkdirAll(dataDir, 0o755); err != nil { + return fmt.Errorf("create data dir: %w", err) + } + embeddingCfg := resolveEmbeddingConfig(cfg, opts.EmbeddingProvider, opts.EmbeddingEndpoint, opts.EmbeddingModel, opts.EmbeddingDimension) + watchSettings := resolveWatchSettings(cfg, opts.LanguageFlags, "", "", "", opts.MaxElements, opts.MaxConnectors, opts.MaxIncoming, opts.MaxOutgoing, opts.MaxExpandedGroup) + sqliteStore, err := store.Open(localserver.DatabasePath(dataDir), assets.FS) + if err != nil { + return err + } + defer func() { _ = sqliteStore.DB().Close() }() + watchStore := watch.NewStore(sqliteStore.DB()) + once, err := watch.NewRunner(watchStore).RunOnce(cmd.Context(), watch.OneShotOptions{Path: path, Rescan: opts.Rescan, Embedding: embeddingCfg, Settings: watchSettings}) + if err != nil { + return err + } + latest, found, err := watchStore.LatestWatchVersion(cmd.Context(), once.Scan.RepositoryID) + if err != nil { + return err + } + changed := found && latest.RepresentationHash != once.Representation.RepresentationHash || hasWatchDriftDiffs(once.Diffs) + var payload any = watchDiffPayload{Changed: changed, Scan: once.Scan, Representation: once.Representation, Diffs: once.Diffs} + if opts.GroupDiffs { + payload = watchGroupedDiffPayload{Changed: changed, Scan: once.Scan, Representation: once.Representation, Diffs: groupWatchDiffs(once.Diffs)} + } + if err := json.NewEncoder(cmd.OutOrStdout()).Encode(payload); err != nil { + return err + } + if opts.FailOnDrift && changed { + return fmt.Errorf("watch representation drift detected") + } + return nil +} + +func groupWatchDiffs(diffs []watch.RepresentationDiff) map[string]map[string][]watch.RepresentationDiff { + grouped := map[string]map[string][]watch.RepresentationDiff{} + for _, diff := range diffs { + changeType := strings.TrimSpace(diff.ChangeType) + if changeType == "" { + changeType = "updated" + } + resourceType := diffResourceType(diff) + if _, ok := grouped[changeType]; !ok { + grouped[changeType] = map[string][]watch.RepresentationDiff{} + } + grouped[changeType][resourceType] = append(grouped[changeType][resourceType], diff) + } + return grouped +} + +func diffResourceType(diff watch.RepresentationDiff) string { + if diff.ResourceType != nil && strings.TrimSpace(*diff.ResourceType) != "" { + return strings.TrimSpace(*diff.ResourceType) + } + if strings.TrimSpace(diff.OwnerType) != "" { + return strings.TrimSpace(diff.OwnerType) + } + return "unknown" +} + +func hasWatchDriftDiffs(diffs []watch.RepresentationDiff) bool { + for _, diff := range diffs { + if diff.ChangeType != "initialized" && diff.OwnerType != "repository" { + return true + } + } + return false +} + +func resolveEmbeddingConfig(cfg *workspace.Config, provider, endpoint, model string, dimension int) watch.EmbeddingConfig { + embedding := watch.EmbeddingConfig{} + if cfg != nil { + embedding.Provider = cfg.Watch.Embedding.Provider + 
embedding.Endpoint = cfg.Watch.Embedding.Endpoint + embedding.Model = cfg.Watch.Embedding.Model + embedding.Dimension = cfg.Watch.Embedding.Dimension + embedding.HealthThreshold = cfg.Watch.Embedding.HealthThreshold + } + if provider != "" { + embedding.Provider = provider + } + if endpoint != "" { + embedding.Endpoint = endpoint + } + if model != "" { + embedding.Model = model + } + if dimension > 0 { + embedding.Dimension = dimension + } + return watch.NormalizeEmbeddingConfig(embedding) +} + +func resolveWatchSettings(cfg *workspace.Config, languages []string, watcherMode, pollInterval, debounce string, maxElements, maxConnectors, maxIncoming, maxOutgoing, maxExpandedGroup int) watch.Settings { + settings := watch.DefaultSettings() + if cfg != nil { + settings.Languages = cfg.Watch.Languages + settings.Watcher = cfg.Watch.Watcher + settings.PollInterval = parseDurationOrZero(cfg.Watch.PollInterval) + settings.Debounce = parseDurationOrZero(cfg.Watch.Debounce) + settings.Thresholds = watch.Thresholds{ + MaxElementsPerView: cfg.Watch.Thresholds.MaxElementsPerView, + MaxConnectorsPerView: cfg.Watch.Thresholds.MaxConnectorsPerView, + MaxIncomingPerElement: cfg.Watch.Thresholds.MaxIncomingPerElement, + MaxOutgoingPerElement: cfg.Watch.Thresholds.MaxOutgoingPerElement, + MaxExpandedConnectorsPerGroup: cfg.Watch.Thresholds.MaxExpandedConnectorsPerGroup, + } + settings.Visibility = watch.VisibilityConfig{ + CoreThresholdEnabled: cfg.Watch.Visibility.CoreThresholdEnabled, + CoreThreshold: cfg.Watch.Visibility.CoreThreshold, + TierMultiplier: cfg.Watch.Visibility.TierMultiplier, + MaxExpansionMultiplier: cfg.Watch.Visibility.MaxExpansionMultiplier, + CoreThresholdSet: true, + WeightsSet: true, + Weights: watch.VisibilityWeights{ + Changed: cfg.Watch.Visibility.Weights.Changed, + Selected: cfg.Watch.Visibility.Weights.Selected, + UserShow: cfg.Watch.Visibility.Weights.UserShow, + UserHide: cfg.Watch.Visibility.Weights.UserHide, + HighSignalFact: cfg.Watch.Visibility.Weights.HighSignalFact, + RelationshipProximity: cfg.Watch.Visibility.Weights.RelationshipProximity, + DependencyFact: cfg.Watch.Visibility.Weights.DependencyFact, + UtilityNoise: cfg.Watch.Visibility.Weights.UtilityNoise, + HighDegreeNoise: cfg.Watch.Visibility.Weights.HighDegreeNoise, + }, + } + } + if len(languages) > 0 { + settings.Languages = languages + } + if watcherMode != "" { + settings.Watcher = watcherMode + } + if pollInterval != "" { + settings.PollInterval = parseDurationOrZero(pollInterval) + } + if debounce != "" { + settings.Debounce = parseDurationOrZero(debounce) + } + if maxElements > 0 { + settings.Thresholds.MaxElementsPerView = maxElements + } + if maxConnectors > 0 { + settings.Thresholds.MaxConnectorsPerView = maxConnectors + } + if maxIncoming > 0 { + settings.Thresholds.MaxIncomingPerElement = maxIncoming + } + if maxOutgoing > 0 { + settings.Thresholds.MaxOutgoingPerElement = maxOutgoing + } + if maxExpandedGroup > 0 { + settings.Thresholds.MaxExpandedConnectorsPerGroup = maxExpandedGroup + } + return watch.NormalizeSettings(settings) +} + +func parseDurationOrZero(value string) time.Duration { + parsed, err := time.ParseDuration(strings.TrimSpace(value)) + if err != nil { + return 0 + } + return parsed +} + +type cliProgress struct { + out io.Writer + bar *progressbar.ProgressBar + mu sync.Mutex +} + +type watchActivityProgress struct { + out io.Writer + mu sync.Mutex + ticker *time.Ticker + stopCh chan struct{} + startTime time.Time + dots int + label string + clientCount func() int +} + +func 
newCLIProgress(out io.Writer) watch.ProgressSink { + if !term.IsTerminal(out) { + return nil + } + return &cliProgress{out: out} +} + +func newWatchActivityProgress(out io.Writer, clientCount func() int) *watchActivityProgress { + if !term.IsTerminal(out) { + return nil + } + return &watchActivityProgress{out: out, clientCount: clientCount} +} + +func (p *watchActivityProgress) Start(label string) { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if p.ticker != nil { + if label != "" { + p.label = label + p.renderLocked(false) + } + return + } + p.label = label + p.startTime = time.Now() + p.ticker = time.NewTicker(1 * time.Second) + p.stopCh = make(chan struct{}) + p.renderLocked(false) + go func() { + for { + select { + case <-p.ticker.C: + p.mu.Lock() + p.renderLocked(true) + p.mu.Unlock() + case <-p.stopCh: + return + } + } + }() +} + +func (p *watchActivityProgress) renderLocked(incrementDots bool) { + if incrementDots { + p.dots = (p.dots + 1) % 4 + } + dotsStr := strings.Repeat(".", p.dots) + strings.Repeat(" ", 3-p.dots) + elapsed := time.Since(p.startTime).Round(time.Second) + clientLabel := "" + if p.clientCount != nil { + clients := p.clientCount() + plural := "s" + if clients == 1 { + plural = "" + } + clientLabel = fmt.Sprintf(" · %d client%s connected", clients, plural) + } + _, _ = fmt.Fprintf(p.out, "\r\033[K%s%s [%s]%s", term.Colorize(p.out, term.ColorCyan, p.label), dotsStr, elapsed, clientLabel) +} + +func (p *watchActivityProgress) Advance(label string) { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if p.ticker == nil { + return + } + if label != "" { + p.label = label + p.renderLocked(false) + } +} + +func (p *watchActivityProgress) Stop() { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if p.ticker != nil { + p.ticker.Stop() + p.ticker = nil + } + if p.stopCh != nil { + close(p.stopCh) + p.stopCh = nil + } + _, _ = fmt.Fprintf(p.out, "\r\033[K") +} + +func (p *cliProgress) Start(label string, total int) { + if p == nil || total <= 0 { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.bar = progressbar.NewOptions(total, + progressbar.OptionSetWriter(p.out), + progressbar.OptionSetVisibility(true), + progressbar.OptionSetDescription(label), + progressbar.OptionShowCount(), + progressbar.OptionSetWidth(12), + progressbar.OptionFullWidth(), + progressbar.OptionClearOnFinish(), + progressbar.OptionUseANSICodes(true), + progressbar.OptionThrottle(60*time.Millisecond), + ) +} + +func (p *cliProgress) Advance(label string) { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if p.bar == nil { + return + } + if label != "" { + p.bar.Describe(label) + } + _ = p.bar.Add(1) +} + +func (p *cliProgress) Finish() { + if p == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + if p.bar == nil { + return + } + _ = p.bar.Finish() + p.bar = nil +} diff --git a/cmd/watch/watch_test.go b/cmd/watch/watch_test.go new file mode 100644 index 0000000..19cf33e --- /dev/null +++ b/cmd/watch/watch_test.go @@ -0,0 +1,423 @@ +package watch + +import ( + "bytes" + "encoding/json" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "testing" + "time" + + assets "github.com/mertcikla/tld" + "github.com/mertcikla/tld/internal/localserver" + storepkg "github.com/mertcikla/tld/internal/store" + watchpkg "github.com/mertcikla/tld/internal/watch" + "github.com/mertcikla/tld/internal/workspace" +) + +func TestWatchSubcommandsFailClearlyOutsideGitRepositoryWithoutRepositoryRows(t *testing.T) { + for _, subcommand := 
range []string{"scan", "represent", "diff"} { + t.Run(subcommand, func(t *testing.T) { + dataDir := t.TempDir() + cmd := NewWatchCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + args := []string{subcommand, t.TempDir(), "--data-dir", dataDir} + if subcommand == "represent" || subcommand == "diff" { + args = append(args, "--embedding-provider", "none") + } + cmd.SetArgs(args) + + err := cmd.Execute() + if err == nil || !strings.Contains(err.Error(), "not inside a git repository") { + t.Fatalf("expected outside-git error, got %v\n%s", err, out.String()) + } + + sqliteStore, err := storepkg.Open(localserver.DatabasePath(dataDir), assets.FS) + if err != nil { + t.Fatal(err) + } + defer func() { _ = sqliteStore.DB().Close() }() + var repositories int + if err := sqliteStore.DB().QueryRow(`SELECT COUNT(*) FROM watch_repositories`).Scan(&repositories); err != nil { + t.Fatal(err) + } + if repositories != 0 { + t.Fatalf("expected no watch repository rows after failed %s, found %d", subcommand, repositories) + } + }) + } +} + +func TestScanCommandPrintsCountsAndSkipsRepeatScan(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func main() { + helper() +} + +func helper() {} +`) + dataDir := t.TempDir() + + first := runScanCommand(t, repo, dataDir) + if !strings.Contains(first, "Files:") || + !strings.Contains(first, "1 seen, 1 parsed, 0 skipped") || + !strings.Contains(first, "Symbols:") || + !strings.Contains(first, "2") || + !strings.Contains(first, "References:") || + !strings.Contains(first, "1") { + t.Fatalf("unexpected first scan output:\n%s", first) + } + + second := runScanCommand(t, repo, dataDir) + if !strings.Contains(second, "Files:") || !strings.Contains(second, "1 seen, 0 parsed, 1 skipped") { + t.Fatalf("unexpected repeat scan output:\n%s", second) + } +} + +func TestRepresentCommandPrintsMaterializationCounts(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + dataDir := t.TempDir() + + cmd := NewWatchCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"represent", repo, "--data-dir", dataDir, "--embedding-provider", "none"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("represent command: %v\n%s", err, out.String()) + } + text := out.String() + for _, expected := range []string{"Filter run:", "Represent run:", "Elements:", "Connectors:", "Representation:"} { + if !strings.Contains(text, expected) { + t.Fatalf("represent output missing %q:\n%s", expected, text) + } + } +} + +func TestScanCommandJSONRespectsLanguageFlag(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + writeFile(t, repo, "web/app.ts", "export function render() { return 1 }\n") + dataDir := t.TempDir() + + cmd := NewWatchCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"scan", repo, "--data-dir", dataDir, "--language", "typescript", "--json"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("scan command: %v\n%s", err, out.String()) + } + var result struct { + FilesSeen int `json:"files_seen"` + FilesParsed int `json:"files_parsed"` + SymbolsSeen int `json:"symbols_seen"` + } + if err := json.Unmarshal(out.Bytes(), &result); err != nil { + t.Fatalf("invalid JSON output %q: %v", out.String(), err) + } + if result.FilesSeen != 1 || result.FilesParsed != 1 || result.SymbolsSeen == 0 { + t.Fatalf("expected 
only TypeScript file in JSON scan result, got %+v\n%s", result, out.String()) + } +} + +func TestDiffCommandJSONAndFailOnDrift(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + dataDir := t.TempDir() + + cmd := NewWatchCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"diff", repo, "--data-dir", dataDir, "--embedding-provider", "none"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("diff command: %v\n%s", err, out.String()) + } + var payload struct { + Changed bool `json:"changed"` + Scan struct { + FilesSeen int `json:"files_seen"` + } `json:"scan"` + Diffs []struct { + ChangeType string `json:"change_type"` + ResourceType *string `json:"resource_type"` + } `json:"diffs"` + } + if err := json.Unmarshal(out.Bytes(), &payload); err != nil { + t.Fatalf("invalid JSON output %q: %v", out.String(), err) + } + if !payload.Changed || payload.Scan.FilesSeen != 1 || len(payload.Diffs) == 0 { + t.Fatalf("unexpected diff payload: %+v\n%s", payload, out.String()) + } + + cmd = NewWatchCmd() + out.Reset() + var errOut bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&errOut) + cmd.SetArgs([]string{"diff", repo, "--data-dir", dataDir, "--embedding-provider", "none", "--fail-on-drift"}) + err := cmd.Execute() + if err == nil || !strings.Contains(err.Error(), "drift detected") { + t.Fatalf("expected fail-on-drift error, got %v\nstdout:\n%s\nstderr:\n%s", err, out.String(), errOut.String()) + } + var driftPayload struct { + Changed bool `json:"changed"` + } + if err := json.NewDecoder(strings.NewReader(out.String())).Decode(&driftPayload); err != nil || !driftPayload.Changed { + t.Fatalf("fail-on-drift should print a JSON payload before usage text, payload=%+v err=%v output=%q", driftPayload, err, out.String()) + } +} + +func TestWatchDryRunGroupsSameDiffPayloadAsDiffCommand(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + dataDir := t.TempDir() + + dryRunCmd := NewWatchCmd() + var dryRunOut bytes.Buffer + dryRunCmd.SetOut(&dryRunOut) + dryRunCmd.SetErr(&dryRunOut) + dryRunCmd.SetArgs([]string{"--dry-run", repo, "--data-dir", dataDir, "--embedding-provider", "none"}) + if err := dryRunCmd.Execute(); err != nil { + t.Fatalf("watch --dry-run command: %v\n%s", err, dryRunOut.String()) + } + var dryRunPayload struct { + Changed bool `json:"changed"` + Diffs map[string]map[string][]watchpkg.RepresentationDiff `json:"diffs"` + } + if err := json.Unmarshal(dryRunOut.Bytes(), &dryRunPayload); err != nil { + t.Fatalf("invalid dry-run JSON output %q: %v", dryRunOut.String(), err) + } + if !dryRunPayload.Changed || len(dryRunPayload.Diffs) == 0 { + t.Fatalf("unexpected dry-run payload: %+v\n%s", dryRunPayload, dryRunOut.String()) + } + if _, ok := dryRunPayload.Diffs["added"]["element"]; !ok { + t.Fatalf("expected dry-run diffs to be grouped by change_type then resource_type, got %+v", dryRunPayload.Diffs) + } + if strings.Contains(dryRunOut.String(), "Watching") { + t.Fatalf("dry-run should exit after printing JSON, got watch output:\n%s", dryRunOut.String()) + } + + diffCmd := NewWatchCmd() + var diffOut bytes.Buffer + diffCmd.SetOut(&diffOut) + diffCmd.SetErr(&diffOut) + diffCmd.SetArgs([]string{"diff", repo, "--data-dir", dataDir, "--embedding-provider", "none"}) + if err := diffCmd.Execute(); err != nil { + t.Fatalf("watch diff command: %v\n%s", err, diffOut.String()) + } + var diffPayload struct { + Diffs []watchpkg.RepresentationDiff 
`json:"diffs"` + } + if err := json.Unmarshal(diffOut.Bytes(), &diffPayload); err != nil { + t.Fatalf("invalid diff JSON output %q: %v", diffOut.String(), err) + } + if !sameDiffPayload(flattenGroupedDiffPayload(dryRunPayload.Diffs), diffPayload.Diffs) { + t.Fatalf("watch --dry-run diffs should match watch diff diffs\n dry-run: %+v\n diff: %+v", dryRunPayload.Diffs, diffPayload.Diffs) + } +} + +func TestWatchDryRunCleanHeadInitializesWithoutDrift(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + dataDir := t.TempDir() + + cmd := NewWatchCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"--dry-run", repo, "--data-dir", dataDir, "--embedding-provider", "none"}) + if err := cmd.Execute(); err != nil { + t.Fatalf("watch --dry-run command: %v\n%s", err, out.String()) + } + var payload struct { + Changed bool `json:"changed"` + Diffs map[string]map[string][]struct { + ChangeType string `json:"change_type"` + ResourceType *string `json:"resource_type"` + } `json:"diffs"` + } + if err := json.Unmarshal(out.Bytes(), &payload); err != nil { + t.Fatalf("invalid dry-run JSON output %q: %v", out.String(), err) + } + if payload.Changed { + t.Fatalf("clean HEAD dry-run should initialize without drift, got %+v", payload) + } + for _, byResource := range payload.Diffs { + for _, diffs := range byResource { + for _, diff := range diffs { + if diff.ResourceType != nil && diff.ChangeType == "added" { + t.Fatalf("clean HEAD dry-run should not include added resource diffs, got %+v", payload.Diffs) + } + } + } + } +} + +func flattenGroupedDiffPayload(grouped map[string]map[string][]watchpkg.RepresentationDiff) []watchpkg.RepresentationDiff { + var out []watchpkg.RepresentationDiff + for _, byResource := range grouped { + for _, diffs := range byResource { + out = append(out, diffs...) 
+ } + } + return out +} + +func sameDiffPayload(a, b []watchpkg.RepresentationDiff) bool { + canonical := func(diffs []watchpkg.RepresentationDiff) []string { + out := make([]string, 0, len(diffs)) + for _, diff := range diffs { + data, _ := json.Marshal(diff) + out = append(out, string(data)) + } + sort.Strings(out) + return out + } + left := canonical(a) + right := canonical(b) + if len(left) != len(right) { + return false + } + for i := range left { + if left[i] != right[i] { + return false + } + } + return true +} + +func TestResolveEmbeddingConfigPrecedence(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + t.Setenv("TLD_EMBEDDING_PROVIDER", "local-deterministic-test") + t.Setenv("TLD_EMBEDDING_MODEL", "env-model") + t.Setenv("TLD_EMBEDDING_DIMENSION", "7") + + // Write a config file to test that env overrides it + writeFile(t, configDir, "tld.yaml", "watch:\n embedding:\n provider: ollama\n model: config-model\n") + + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + t.Fatalf("LoadGlobalConfig: %v", err) + } + + resolved := resolveEmbeddingConfig(cfg, "none", "", "", 0) + if resolved.Provider != "none" { + t.Fatalf("flag provider should win over env/config, got %+v", resolved) + } + + resolved = resolveEmbeddingConfig(cfg, "", "", "", 0) + if resolved.Provider != "local-deterministic-test" || resolved.Model != "env-model" || resolved.Dimension != 7 { + t.Fatalf("env should win over config, got %+v", resolved) + } +} + +func TestResolveWatchSettingsPrecedence(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + t.Setenv("TLD_WATCH_LANGUAGES", "python,typescript") + t.Setenv("TLD_WATCH_WATCHER", "poll") + t.Setenv("TLD_WATCH_POLL_INTERVAL", "3s") + t.Setenv("TLD_WATCH_DEBOUNCE", "250ms") + + // Write a config file to test that env overrides it + writeFile(t, configDir, "tld.yaml", "watch:\n languages: [go]\n watcher: fsnotify\n poll_interval: 9s\n debounce: 8s\n thresholds:\n max_elements_per_view: 11\n max_connectors_per_view: 12\n") + + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + t.Fatalf("LoadGlobalConfig: %v", err) + } + + envResolved := resolveWatchSettings(cfg, nil, "", "", "", 0, 0, 0, 0, 0) + if strings.Join(envResolved.Languages, ",") != "python,typescript" || + envResolved.Watcher != "poll" || + envResolved.PollInterval != 3*time.Second || + envResolved.Debounce != 250*time.Millisecond || + envResolved.Thresholds.MaxElementsPerView != 11 || + envResolved.Thresholds.MaxConnectorsPerView != 12 { + t.Fatalf("env/config precedence resolved incorrectly: %+v", envResolved) + } + + flagResolved := resolveWatchSettings(cfg, []string{"java"}, "fsnotify", "1s", "2s", 21, 22, 23, 24, 25) + if strings.Join(flagResolved.Languages, ",") != "java" || + flagResolved.Watcher != "fsnotify" || + flagResolved.PollInterval != time.Second || + flagResolved.Debounce != 2*time.Second || + flagResolved.Thresholds.MaxElementsPerView != 21 || + flagResolved.Thresholds.MaxConnectorsPerView != 22 || + flagResolved.Thresholds.MaxIncomingPerElement != 23 || + flagResolved.Thresholds.MaxOutgoingPerElement != 24 || + flagResolved.Thresholds.MaxExpandedConnectorsPerGroup != 25 { + t.Fatalf("flag precedence resolved incorrectly: %+v", flagResolved) + } +} + +func runScanCommand(t *testing.T, repo, dataDir string) string { + t.Helper() + cmd := NewWatchCmd() + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&out) + cmd.SetArgs([]string{"scan", repo, "--data-dir", dataDir}) + if err := cmd.Execute(); err != nil { + 
t.Fatalf("scan command: %v\n%s", err, out.String()) + } + return out.String() +} + +func initGitRepoNoCommit(t *testing.T) string { + t.Helper() + dir := t.TempDir() + runGit(t, dir, "init") + runGit(t, dir, "config", "user.email", "test@example.com") + runGit(t, dir, "config", "user.name", "Test User") + return dir +} + +func runGit(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v failed: %v\n%s", args, err, out) + } +} + +func writeFile(t *testing.T, root, name, content string) { + t.Helper() + path := filepath.Join(root, name) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} diff --git a/density.md b/density.md new file mode 100644 index 0000000..4c83309 --- /dev/null +++ b/density.md @@ -0,0 +1,84 @@ +# Density Slider with Per-View Visibility Overrides + +## Summary +Implement a persisted `-2..2` density system for every view, plus element+connector promote/demote controls in ViewEditor. Density remains a read-time projection: moving the slider or changing overrides does not delete or rematerialize generated objects. Default density is `0`. + +Density levels: +- `-2 Essential`: soft target 4 elements / 8 connectors. +- `-1 Compact`: soft target 8 / 16. +- `0 Balanced`: soft target 12 / 24. +- `1 Expanded`: soft target 32 / 64. +- `2 Full`: no projection cap. + +## Key Changes +- Add persisted view density: + - Add `views.density_level INTEGER NOT NULL DEFAULT 0`. + - Add local JSON endpoints: + - `GET /api/views/{id}/density` + - `PUT /api/views/{id}/density { density_level }` + - Validate density is `-2..2`. + +- Add per-view object overrides: + - Add `view_visibility_overrides` table keyed by `view_id`, `resource_type` (`element` or `connector`), and `resource_id`. + - Store `level_delta INTEGER NOT NULL DEFAULT 0`, clamped to `-4..4`, plus timestamps. + - Positive delta promotes visibility; negative delta demotes visibility. + - Effective level is computed from backend semantic/structural score plus delta. + - Delta `0` removes the override row. + +- Add override APIs: + - `GET /api/views/{id}/visibility-overrides` + - `PUT /api/views/{id}/visibility-overrides { resource_type, resource_id, level_delta }` + - `POST /api/views/{id}/visibility-overrides/{resource_type}/{resource_id}/promote` + - `POST /api/views/{id}/visibility-overrides/{resource_type}/{resource_id}/demote` + - `DELETE /api/views/{id}/visibility-overrides/{resource_type}/{resource_id}` + +- Add density projection: + - Project placements/connectors at read time using density level, soft caps, and overrides. + - Watch-backed views rank with `watch_filter_decisions`, fact confidence, materialization owner type, and architecture link confidence. + - Manual/non-watch views rank structurally using degree, connectivity, selected/focused state, and connector endpoint preservation. + - Connectors are visible only when both endpoints are visible, unless projection promotes a missing endpoint to preserve a user-promoted connector. + +- Update ViewEditor: + - Show a Density slider/segmented control for every view. + - Add quick Promote / Demote / Reset controls in `ElementPanel` and `ConnectorPanel`. + - Promoted/demoted objects get a small visual indicator in the panel and optionally on-canvas. 
+  - Replace view-level Show Context / Clean Noise with Density for the main toolbar.
+  - Keep element-level Hide Context only as an advanced watch-specific action.
+
+## Behavior Rules
+- Overrides are scoped to the current view only.
+- Promote/demote is relative, not an exact level picker.
+- Level `2 Full` shows all current view content, except explicit demotions may still hide objects unless the user resets overrides.
+- Promoted connectors pull in endpoints if needed.
+- Manual edits and generated watch rematerialization must not erase override rows unless the view or resource is deleted.
+- Existing enricher metadata stays unchanged.
+
+## Test Plan
+- Backend tests:
+  - Density and override validation.
+  - Promote/demote clamps deltas and reset removes rows.
+  - Projection applies overrides after base scoring.
+  - Promoted connector preserves both endpoints.
+  - Demoted element hides incident connectors unless those connectors are separately promoted.
+  - Manual views use structural projection without watch metadata.
+  - Watch views use filter decisions and confidence inputs.
+
+- Integration tests:
+  - Persisted density survives reload.
+  - Overrides survive watch re-representation.
+  - Deleting a view/resource cleans override rows.
+  - `2 Full` returns full content except explicit demotions.
+
+- Frontend tests:
+  - Density slider loads/saves per view.
+  - Element and connector panels promote/demote/reset objects.
+  - View content refreshes after density or override changes.
+  - Override indicators render for promoted/demoted objects.
+
+## Assumptions
+- Use local JSON endpoints to avoid proto churn.
+- Per-view persistence is required.
+- Overrides are current-view scoped.
+- Promote/demote uses relative deltas.
+- Soft caps are preferred over hard caps.
+- `watch_filter_decisions` becomes projection input, but remains the explainability/audit trail.
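To make the projection rule above concrete, a minimal Go sketch of the read-time element projection under these soft targets and override deltas could look like the following. It is illustrative only, not the planned implementation: `Override`, `softTargets`, `effectiveScore`, and `projectElements` are hypothetical names, the backend's semantic/structural inputs (watch filter decisions, degree, confidence) are collapsed into a single `base` score, and the endpoint-pulling rule for promoted connectors is omitted.

```go
package density

import "sort"

// Hypothetical override shape for illustration; the real data lives in
// views.density_level and the view_visibility_overrides table.
type Override struct {
	ResourceType string // "element" or "connector"
	ResourceID   string
	LevelDelta   int // promote (>0) or demote (<0), clamped to -4..4
}

// Soft element/connector targets per density level. Level 2 ("Full") is
// intentionally absent from the map, meaning no projection cap.
var softTargets = map[int]struct{ Elements, Connectors int }{
	-2: {Elements: 4, Connectors: 8},
	-1: {Elements: 8, Connectors: 16},
	0:  {Elements: 12, Connectors: 24},
	1:  {Elements: 32, Connectors: 64},
}

func clamp(v, lo, hi int) int {
	if v < lo {
		return lo
	}
	if v > hi {
		return hi
	}
	return v
}

// effectiveScore adds the per-view override delta to the backend score,
// so promote/demote stays relative rather than an exact level picker.
func effectiveScore(base float64, o *Override) float64 {
	if o == nil {
		return base
	}
	return base + float64(clamp(o.LevelDelta, -4, 4))
}

// projectElements is a pure read-time projection: it never deletes or
// rematerializes generated objects, it only ranks element IDs by
// effective score and trims to the soft target for the view's level.
func projectElements(scores map[string]float64, overrides map[string]*Override, level int) []string {
	ids := make([]string, 0, len(scores))
	for id := range scores {
		ids = append(ids, id)
	}
	sort.Slice(ids, func(i, j int) bool {
		si := effectiveScore(scores[ids[i]], overrides[ids[i]])
		sj := effectiveScore(scores[ids[j]], overrides[ids[j]])
		if si != sj {
			return si > sj
		}
		return ids[i] < ids[j] // deterministic tiebreak
	})
	target, capped := softTargets[clamp(level, -2, 2)]
	if !capped || len(ids) <= target.Elements {
		return ids // level 2 "Full", or already within the soft target
	}
	return ids[:target.Elements]
}
```

Because the delta is folded into the ranking score instead of forcing a fixed position, an explicit promote competes against unmarked objects rather than dictating an exact level, which is the intent of the relative promote/demote rule; the sketch deliberately leaves out the level-2 exception for explicit demotions and connector endpoint preservation described above.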
diff --git a/frontend/logo/tld.svg b/frontend/logo/tld.svg new file mode 100644 index 0000000..366e447 --- /dev/null +++ b/frontend/logo/tld.svg @@ -0,0 +1 @@ + diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 88be1ae..46b0b87 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,7 +8,7 @@ "name": "@tldiagram/core-ui", "version": "1.95.1", "dependencies": { - "@buf/tldiagramcom_diagram.bufbuild_es": "^2.11.0-20260419172603-8192c519478e.1", + "@buf/tldiagramcom_diagram.bufbuild_es": "^2.12.0-20260503002426-45e3166b5ec1.1", "@bufbuild/protobuf": "^2.11.0", "esbuild": "^0.25.12", "zustand": "^5.0.12" @@ -450,16 +450,16 @@ } }, "node_modules/@buf/tldiagramcom_diagram.bufbuild_es": { - "version": "2.11.0-20260419172603-8192c519478e.1", - "resolved": "https://buf.build/gen/npm/v1/@buf/tldiagramcom_diagram.bufbuild_es/-/tldiagramcom_diagram.bufbuild_es-2.11.0-20260419172603-8192c519478e.1.tgz", + "version": "2.12.0-20260503002426-45e3166b5ec1.1", + "resolved": "https://buf.build/gen/npm/v1/@buf/tldiagramcom_diagram.bufbuild_es/-/tldiagramcom_diagram.bufbuild_es-2.12.0-20260503002426-45e3166b5ec1.1.tgz", "peerDependencies": { - "@bufbuild/protobuf": "^2.11.0" + "@bufbuild/protobuf": "^2.12.0" } }, "node_modules/@bufbuild/protobuf": { - "version": "2.11.0", - "resolved": "https://registry.npmjs.org/@bufbuild/protobuf/-/protobuf-2.11.0.tgz", - "integrity": "sha512-sBXGT13cpmPR5BMgHE6UEEfEaShh5Ror6rfN3yEK5si7QVrtZg8LEPQb0VVhiLRUslD2yLnXtnRzG035J/mZXQ==", + "version": "2.12.0", + "resolved": "https://registry.npmjs.org/@bufbuild/protobuf/-/protobuf-2.12.0.tgz", + "integrity": "sha512-B/XlCaFIP8LOwzo+bz5uFzATYokcwCKQcghqnlfwSmM5eX/qTkvDBnDPs+gXtX/RyjxJ4DRikECcPJbyALA8FA==", "license": "(Apache-2.0 AND BSD-3-Clause)" }, "node_modules/@chakra-ui/anatomy": { diff --git a/frontend/package.json b/frontend/package.json index 92595e6..e5276e6 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -36,7 +36,7 @@ "test": "vitest run" }, "dependencies": { - "@buf/tldiagramcom_diagram.bufbuild_es": "^2.11.0-20260419172603-8192c519478e.1", + "@buf/tldiagramcom_diagram.bufbuild_es": "^2.12.0-20260503002426-45e3166b5ec1.1", "@bufbuild/protobuf": "^2.11.0", "esbuild": "^0.25.12", "zustand": "^5.0.12" diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index dc26d71..d2d152f 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,6 +1,7 @@ import { useEffect, useState } from 'react' import { Routes, Route, Navigate, Outlet } from 'react-router-dom' -import { Box, Spinner, Center } from '@chakra-ui/react' +import { Box, Spinner, Center, IconButton, Tooltip } from '@chakra-ui/react' +import { ViewIcon, ViewOffIcon } from '@chakra-ui/icons' import { api } from './api/client' import ViewEditor from './pages/ViewEditor' import ViewsPage from './pages/Views' @@ -10,17 +11,20 @@ import Settings from './pages/Settings' import AppearanceSettings from './pages/AppearanceSettings' import { HeaderProvider, useHeader } from './components/HeaderContext' import TopMenuBar from './components/TopMenuBar' +import WorkspacePanel from './components/WorkspacePanel' import { ThemeProvider } from './context/ThemeContext' +import { WorkspaceVersionProvider } from './context/WorkspaceVersionContext' import { ACCENT_DEFAULT, BACKGROUND_DEFAULT, ELEMENT_DEFAULT, hexToRgba } from './constants/colors' import { platform } from './platform/local' function AppLayout() { const header = useHeader() + const [workspacePanelVisible, setWorkspacePanelVisible] = useState(true) 
const node = header && typeof header === 'object' && 'node' in header ? (header as { node: React.ReactNode }).node : header const hideMobileBar = header && typeof header === 'object' && 'hideMobileBar' in header ? !!(header as { hideMobileBar?: boolean }).hideMobileBar : false return ( - + {node} @@ -29,8 +33,9 @@ function AppLayout() { mb={{ base: 'var(--topbar-content-gap)', sm: '0px' }} flexShrink={0} /> - + + {workspacePanelVisible && } ) @@ -77,7 +82,7 @@ function HomeRedirect() { if (loading) { return ( -
+
) @@ -97,7 +102,7 @@ export default function App() { if (!ready) { return ( -
+
) @@ -105,15 +110,17 @@ export default function App() { return ( - + {platform.getRoutes({ user: null })} - } /> + } /> - + + + } > diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index decb739..53d40c5 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -9,11 +9,13 @@ import type { LibraryElement, PlacedElement, Tag, + TechnologyConnector, View, ViewConnector, ViewLayer, ViewPlacement, ViewTreeNode, + VisibilityOverride, } from '../types' import { WorkspaceService, @@ -46,19 +48,161 @@ import { import { ImportService, } from '@buf/tldiagramcom_diagram.bufbuild_es/diag/v1/import_service_pb' +import { + WorkspaceVersionService, + type WorkspaceVersionInfo, +} from '@buf/tldiagramcom_diagram.bufbuild_es/diag/v1/workspace_version_service_pb' +import { + OrgService, + ListTagColorsResponseSchema, +} from '@buf/tldiagramcom_diagram.bufbuild_es/diag/v1/org_service_pb' import { transport } from './transport' import { apiUrl, fetchApiAsset } from '../config/runtime' +async function responseError(res: Response, fallback: string): Promise { + const body = await res.json().catch(() => null) as { error?: string } | null + return new Error(body?.error || `${fallback}: ${res.statusText}`) +} + export interface DependenciesResponse { elements: DependencyElement[] connectors: DependencyConnector[] + totalCount?: number +} + +export interface WatchRepository { + id: number + remote_url: string | null + repo_root: string + display_name: string + branch: string | null + head_commit: string | null + identity_status: string +} + +export interface WatchLock { + id: number + repository_id: number + pid: number + started_at: string + heartbeat_at: string + status: 'active' | 'paused' | 'stopping' | 'stale' | 'released' | string +} + +export interface WatchStatus { + active: boolean + repository?: WatchRepository + lock?: WatchLock + connected_clients?: number +} + +export interface WatchRepresentationSummary { + repository_id: number + raw_graph_hash?: string + filter_settings_hash?: string + representation_hash?: string + last_status?: string + last_started_at?: string + last_finished_at?: string + elements_created: number + elements_updated: number + connectors_created: number + connectors_updated: number + views_created: number + diffs?: WatchDiff[] +} + +export interface WatchContextActionResponse { + repository_id: number + action: 'show' | 'hide' | 'clean' | string + policies_created: number + policies_updated: number + policies_deactivated: number + owners_affected: number + tier_before: number + tier_after: number + max_tier: number + elements_added: number + connectors_added: number + views_added: number + elements_removed: number + connectors_removed: number + views_removed: number + representation: { + repository_id: number + representation_run_id: number + filter_run_id: number + raw_graph_hash: string + filter_settings_hash: string + representation_hash: string + } + summary: WatchRepresentationSummary +} + +export interface WatchEvent { + type: string + repository_id?: number + message?: string + at: string + data?: unknown + phase?: string + watcher_mode?: string + languages?: string[] + changed_files?: number + warnings?: string[] } +export interface WatchVersion { + id: number + repository_id: number + commit_hash: string + commit_message?: string + parent_commit_hash?: string + branch?: string + representation_hash: string + workspace_version_id?: number + created_at: string +} + +export interface WatchDiff { + id: number + version_id: number + 
owner_type: string + owner_key: string + change_type: string + before_hash?: string + after_hash?: string + resource_type?: string + resource_id?: number + language?: string + summary?: string + added_lines?: number + removed_lines?: number +} + +export interface WorkspaceVersion { + id: string + version_id: string + source: string + parent_version_id?: string + view_count: number + element_count: number + connector_count: number + description?: string + workspace_hash?: string + created_at: string +} + +export type SourceEditor = 'zed' | 'vscode' + // ─── RPC clients ───────────────────────────────────────────────────────────── const workspaceClient = createClient(WorkspaceService, transport) const dependencyClient = createClient(DependencyService, transport) const importClient = createClient(ImportService, transport) +const workspaceVersionClient = createClient(WorkspaceVersionService, transport) +const orgClient = createClient(OrgService, transport) +let dependencyConnectorsCache: Promise | null = null // ─── Helpers ───────────────────────────────────────────────────────────────── @@ -75,6 +219,49 @@ function j(schema: Parameters[0], msg: Parameters) { + const res = await fetchApiAsset(apiUrl('/diag.v1.WorkspaceService/GetWorkspace'), { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Connect-Protocol-Version': '1', + }, + body: JSON.stringify(body), + }) + if (!res.ok) { + throw new Error(`GetWorkspace failed: ${res.statusText}`) + } + return res.json() as Promise<{ + views?: ProtoDiagram[] + total_count?: number + totalCount?: number + content?: Record[]; connectors?: Record[] }> + navigations?: Record[] + }> +} + // ─── Proto → frontend type mappers ─────────────────────────────────────────── interface ProtoDiagram { @@ -139,10 +326,10 @@ function protoElementToLibrary(e: Record): LibraryElement { technology: (e.technology ?? null) as string | null, url: (e.url ?? null) as string | null, logo_url: (e.logo_url ?? e.logoUrl ?? null) as string | null, - technology_connectors: ((e.technology_connectors ?? e.technologyLinks ?? []) as any[]).map(tl => ({ - type: tl.type, + technology_connectors: ((e.technology_connectors ?? e.technologyLinks ?? []) as Array<{ type?: string; slug?: string; label?: string; is_primary_icon?: boolean; isPrimaryIcon?: boolean }>).map(tl => ({ + type: (tl.type ?? 'custom') as TechnologyConnector['type'], slug: tl.slug, - label: tl.label, + label: tl.label ?? '', is_primary_icon: !!(tl.is_primary_icon ?? tl.isPrimaryIcon), })), tags: (e.tags ?? []) as string[], @@ -157,6 +344,26 @@ function protoElementToLibrary(e: Record): LibraryElement { } } +function libraryElementToDependency(element: LibraryElement): DependencyElement { + return { + id: String(element.id), + name: element.name, + type: element.kind, + description: element.description, + technology: element.technology, + url: element.url, + logo_url: element.logo_url, + technology_connectors: element.technology_connectors, + tags: element.tags, + repo: element.repo, + branch: element.branch, + language: element.language, + file_path: element.file_path, + created_at: element.created_at, + updated_at: element.updated_at, + } +} + function protoPlacedElement(p: Record): PlacedElement { return { id: Number(p.id ?? 0), @@ -170,10 +377,10 @@ function protoPlacedElement(p: Record): PlacedElement { technology: (p.technology ?? null) as string | null, url: (p.url ?? null) as string | null, logo_url: (p.logo_url ?? p.logoUrl ?? 
null) as string | null, - technology_connectors: ((p.technology_connect_ors ?? p.technology_connectors ?? p.technologyLinks ?? []) as any[]).map(tl => ({ - type: tl.type, + technology_connectors: ((p.technology_connect_ors ?? p.technology_connectors ?? p.technologyLinks ?? []) as Array<{ type?: string; slug?: string; label?: string; is_primary_icon?: boolean; isPrimaryIcon?: boolean }>).map(tl => ({ + type: (tl.type ?? 'custom') as TechnologyConnector['type'], slug: tl.slug, - label: tl.label, + label: tl.label ?? '', is_primary_icon: !!(tl.is_primary_icon ?? tl.isPrimaryIcon), })), tags: (p.tags ?? []) as string[], @@ -205,6 +412,25 @@ function protoConnector(e: Record): Connector { } } +function protoDependencyConnector(e: Record): DependencyConnector { + return { + id: String(e.id ?? 0), + view_id: String(e.view_id ?? e.viewId ?? 0), + source_element_id: String(e.source_element_id ?? e.sourceElementId ?? 0), + target_element_id: String(e.target_element_id ?? e.targetElementId ?? 0), + label: (e.label ?? null) as string | null, + description: (e.description ?? null) as string | null, + relationship_type: (e.relationship_type ?? e.relationshipType ?? e.relationship ?? null) as string | null, + direction: String(e.direction ?? 'forward'), + connector_type: String(e.connector_type ?? e.connectorType ?? e.style ?? 'solid'), + url: (e.url ?? null) as string | null, + source_handle: (e.source_handle ?? e.sourceHandle ?? null) as string | null, + target_handle: (e.target_handle ?? e.targetHandle ?? null) as string | null, + created_at: String(e.created_at ?? e.createdAt ?? ''), + updated_at: String(e.updated_at ?? e.updatedAt ?? ''), + } +} + function protoNavigation(n: Record): ViewConnector { return { id: Number(n.id ?? 0), @@ -339,7 +565,20 @@ export const api = { workspace: { orgs: { tagColors: { - list: (): Promise => Promise.resolve([]), + list: (): Promise> => + rpc(async () => { + const res = await orgClient.listTagColors({}) + const json = j<{ tags?: Record }>(ListTagColorsResponseSchema, res) + const tags: Record = {} + Object.entries(json.tags ?? {}).forEach(([name, tag]) => { + tags[name] = { name, color: tag.color ?? '#A0AEC0', description: tag.description ?? null } + }) + return tags + }), + update: (name: string, color: string, description?: string | null): Promise => + rpc(async () => { + await orgClient.updateTag({ tag: name, color, description: description ?? undefined }) + }), }, }, @@ -384,18 +623,15 @@ export const api = { }), content: (id: number): Promise<{ placements: PlacedElement[]; connectors: Connector[] }> => - rpc(async () => { - const [placementsRes, connectorsRes] = await Promise.all([ - workspaceClient.listPlacements({ viewId: id }), - workspaceClient.listConnectors({ viewId: id }), - ]) - const placementJson = j<{ placements: Record[] }>(ListPlacementsResponseSchema, placementsRes) - const connectorJson = j<{ connectors: Record[] }>(ListConnectorsResponseSchema, connectorsRes) + (async () => { + const res = await fetch(apiUrl(`/views/${id}/projected-content`)) + if (!res.ok) throw new Error('Failed to load view content') + const json = await res.json() as { placements?: Record[]; connectors?: Record[] } return { - placements: (placementJson.placements ?? []).map(protoPlacedElement), - connectors: (connectorJson.connectors ?? []).map(protoConnector), + placements: (json.placements ?? []).map(protoPlacedElement), + connectors: (json.connectors ?? 
[]).map(protoConnector), } - }), + })(), tree: (): Promise => rpc(async () => { @@ -434,6 +670,60 @@ export const api = { return (json.views ?? []).map(mapDiagram) }), + treeAround: async ( + viewId: number, + opts: { ancestorLevels?: number; descendantLevels?: number } = {}, + ): Promise => { + const ancestorLevels = opts.ancestorLevels ?? 2 + const descendantLevels = opts.descendantLevels ?? 2 + const current = await api.workspace.views.get(viewId) + + const ancestors: ViewTreeNode[] = [] + let cursor: ViewTreeNode = current + for (let depth = 0; depth < ancestorLevels && cursor.parent_view_id != null; depth += 1) { + const parent = await api.workspace.views.get(cursor.parent_view_id) + ancestors.unshift(parent) + cursor = parent + } + + const withDescendants = async (node: ViewTreeNode, remainingDepth: number): Promise => { + const scoped: ViewTreeNode = { ...node, children: [] } + if (remainingDepth <= 0) return scoped + const children = await api.workspace.views.treeChildren(node.id) + scoped.children = await Promise.all(children.map((child) => withDescendants(child, remainingDepth - 1))) + return scoped + } + + let scoped = await withDescendants(current, descendantLevels) + for (let index = ancestors.length - 1; index >= 0; index -= 1) { + scoped = { ...ancestors[index], children: [scoped] } + } + return [scoped] + }, + + gridData: (): Promise<{ + views: ViewTreeNode[] + content: Record + }> => + rpc(async () => { + const json = await fetchWorkspaceRaw({ + includeContent: true, + hasView: true, + }) + return { + views: (json.views ?? []).map(mapDiagram), + content: Object.fromEntries( + Object.entries(json.content ?? {}).map(([key, value]) => [ + Number(key), + { + placements: (value.placements ?? []).map(protoPlacedElement), + connectors: (value.connectors ?? []).map(protoConnector), + }, + ]) + ), + } + }), + get: (id: number): Promise => rpc(async () => { const res = await workspaceClient.getView({ viewId: id }) @@ -470,6 +760,60 @@ export const api = { setLevel: (id: number, level: number): Promise => rpc(async () => { await workspaceClient.setViewLevel({ viewId: id, level }) }), + density: { + get: async (id: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/density`)) + if (!res.ok) throw new Error('Failed to load density') + const json = await res.json() as { density_level?: number } + return Number(json.density_level ?? 0) + }, + set: async (id: number, densityLevel: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/density`), { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ density_level: densityLevel }), + }) + if (!res.ok) throw new Error('Failed to save density') + const json = await res.json() as { density_level?: number } + return Number(json.density_level ?? densityLevel) + }, + }, + + visibilityOverrides: { + list: async (id: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/visibility-overrides`)) + if (!res.ok) throw new Error('Failed to load visibility overrides') + const json = await res.json() as { overrides?: VisibilityOverride[] } + return json.overrides ?? 
[] + }, + set: async (id: number, resourceType: VisibilityOverride['resource_type'], resourceId: number, levelDelta: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/visibility-overrides`), { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ resource_type: resourceType, resource_id: resourceId, level_delta: levelDelta }), + }) + if (!res.ok) throw new Error('Failed to save visibility override') + const json = await res.json() as { override?: VisibilityOverride } + return json.override ?? { view_id: id, resource_type: resourceType, resource_id: resourceId, level_delta: levelDelta } + }, + promote: async (id: number, resourceType: VisibilityOverride['resource_type'], resourceId: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/visibility-overrides/${resourceType}/${resourceId}/promote`), { method: 'POST' }) + if (!res.ok) throw new Error('Failed to promote visibility') + const json = await res.json() as { override?: VisibilityOverride } + return json.override ?? { view_id: id, resource_type: resourceType, resource_id: resourceId, level_delta: 1 } + }, + demote: async (id: number, resourceType: VisibilityOverride['resource_type'], resourceId: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/visibility-overrides/${resourceType}/${resourceId}/demote`), { method: 'POST' }) + if (!res.ok) throw new Error('Failed to demote visibility') + const json = await res.json() as { override?: VisibilityOverride } + return json.override ?? { view_id: id, resource_type: resourceType, resource_id: resourceId, level_delta: -1 } + }, + reset: async (id: number, resourceType: VisibilityOverride['resource_type'], resourceId: number): Promise => { + const res = await fetch(apiUrl(`/views/${id}/visibility-overrides/${resourceType}/${resourceId}`), { method: 'DELETE' }) + if (!res.ok) throw new Error('Failed to reset visibility override') + }, + }, + delete: (_orgId: string, id: number): Promise => rpc(async () => { await workspaceClient.deleteView({ orgId: '', viewId: id }) }), @@ -620,8 +964,36 @@ export const api = { }, dependencies: { - list: (): Promise => + list: (params?: { limit?: number; offset?: number; search?: string }): Promise => rpc(async () => { + if (params) { + if (!dependencyConnectorsCache) { + dependencyConnectorsCache = workspaceClient.listConnectors({ viewId: 0 }) + .then((res) => { + const connectorJson = j<{ connectors: Record[] }>(ListConnectorsResponseSchema, res) + return (connectorJson.connectors ?? []).map(protoDependencyConnector) + }) + } + const [elements, connectors] = await Promise.all([ + workspaceClient.listElements({ + limit: params.limit ?? 0, + offset: params.offset ?? 0, + search: params.search ?? '', + }).then((res) => { + const json = j<{ elements: Record[] }>(ListElementsResponseSchema, res) + return { + elements: (json.elements ?? []).map(protoElementToLibrary), + totalCount: res.pagination ? 
Number(res.pagination.totalCount) : undefined, + } + }), + dependencyConnectorsCache, + ]) + return { + elements: elements.elements.map(libraryElementToDependency), + connectors, + totalCount: elements.totalCount, + } + } const res = await dependencyClient.listDependencies({}) return j(ListDependenciesResponseSchema, res) }), @@ -665,8 +1037,8 @@ export const api = { throw new Error(`Failed to load shared diagram: ${res.statusText}`) } const data = await res.json() as { - tree: any[] - views: Record + tree: ProtoDiagram[] + views: Record[]; connectors: Record[] }> password_required?: boolean } @@ -683,7 +1055,7 @@ export const api = { // Ensure that the share root is treated as a root (no parent) so that computeLayout // picks it up even if it was nested in the original workspace. - const sharedRoot = tree.find(n => String(n.id) === String(data.views[token]?.elements?.[0]?.view_id ?? '')) + const _sharedRoot = tree.find(n => String(n.id) === String(data.views[token]?.elements?.[0]?.view_id ?? '')) // Backend actually returns the shareToken.ViewID as the root of the tree it builds. // We should find the node in 'tree' that has no parent *within the returned set*. // For shared explore, the backend typically returns a tree starting at the shared view. @@ -695,9 +1067,9 @@ export const api = { } }) const navigations: ViewConnector[] = [] - const elementToChildView = new Map() - const allViews: any[] = [] - const flatTree = (nodes: any[]) => { + const elementToChildView = new Map() + const allViews: ViewTreeNode[] = [] + const flatTree = (nodes: ViewTreeNode[]) => { nodes.forEach(n => { allViews.push(n) if (n.owner_element_id) elementToChildView.set(n.owner_element_id, n) @@ -706,8 +1078,8 @@ export const api = { } flatTree(tree) - Object.values(views).forEach((v: any) => { - v.placements.forEach((p: any) => { + Object.values(views).forEach((v) => { + v.placements.forEach((p) => { const childView = elementToChildView.get(p.element_id) if (childView) { navigations.push({ @@ -751,4 +1123,91 @@ export const api = { } }), }, + + versions: { + list: (limit = 50): Promise => + rpc(async () => { + const res = await workspaceVersionClient.listVersions({ limit }) + return (res.versions ?? []).map(mapWorkspaceVersion) + }), + }, + + watch: { + status: async (): Promise => { + const res = await fetch(apiUrl('/watch/status')) + if (!res.ok) throw new Error(`Failed to load watch status: ${res.statusText}`) + return res.json() + }, + websocketUrl: (): string => { + const url = new URL(apiUrl('/watch/ws'), window.location.href) + url.protocol = url.protocol === 'https:' ? 
'wss:' : 'ws:' + return url.toString() + }, + repositories: async (): Promise => { + const res = await fetch(apiUrl('/watch/repositories')) + if (!res.ok) throw new Error(`Failed to load watch repositories: ${res.statusText}`) + return res.json() + }, + versions: async (repositoryId: number): Promise => { + const res = await fetch(apiUrl(`/watch/repositories/${repositoryId}/versions`)) + if (!res.ok) throw new Error(`Failed to load watch versions: ${res.statusText}`) + return res.json() + }, + diffs: async (versionId: number, filters?: { owner_type?: string; change_type?: string; resource_type?: string; language?: string }): Promise => { + const params = new URLSearchParams() + if (filters?.owner_type) params.set('owner_type', filters.owner_type) + if (filters?.change_type) params.set('change_type', filters.change_type) + if (filters?.resource_type) params.set('resource_type', filters.resource_type) + if (filters?.language) params.set('language', filters.language) + const suffix = params.toString() ? `?${params}` : '' + const res = await fetch(apiUrl(`/watch/versions/${versionId}/diffs${suffix}`)) + if (!res.ok) throw new Error(`Failed to load watch diffs: ${res.statusText}`) + return res.json() + }, + showContext: async (repositoryId: number, input: { resource_type: 'element' | 'view'; resource_id: number }): Promise => { + const res = await fetch(apiUrl(`/watch/repositories/${repositoryId}/context/show`), { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(input), + }) + if (!res.ok) throw await responseError(res, 'Failed to show watch context') + return res.json() + }, + cleanContext: async (repositoryId: number, input: { resource_type: 'element' | 'view'; resource_id: number }): Promise => { + const res = await fetch(apiUrl(`/watch/repositories/${repositoryId}/context/clean`), { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(input), + }) + if (!res.ok) throw await responseError(res, 'Failed to clean watch context') + return res.json() + }, + hideContext: async (repositoryId: number, input: { resource_type: 'element' | 'view'; resource_id: number }): Promise => { + const res = await fetch(apiUrl(`/watch/repositories/${repositoryId}/context/hide`), { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(input), + }) + if (!res.ok) throw await responseError(res, 'Failed to hide watch context') + return res.json() + }, + }, + + editor: { + open: async (input: { editor: SourceEditor; repo?: string | null; file_path: string; line?: number | null }): Promise => { + const res = await fetch(apiUrl('/editor/open'), { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + editor: input.editor, + repo: input.repo ?? '', + file_path: input.file_path, + line: input.line ?? 
0, + }), + }) + if (!res.ok) { + throw await responseError(res, 'Failed to open editor') + } + }, + }, } diff --git a/frontend/src/components/CodePreviewPanel.tsx b/frontend/src/components/CodePreviewPanel.tsx index 8e63ede..5c107b0 100644 --- a/frontend/src/components/CodePreviewPanel.tsx +++ b/frontend/src/components/CodePreviewPanel.tsx @@ -2,6 +2,24 @@ import { useEffect, useState, useRef } from 'react' import type { SVGProps } from 'react' import { Box, Button, CloseButton, HStack, Icon, Spinner, Text, Tooltip, VStack } from '@chakra-ui/react' import { ExternalLinkIcon } from '@chakra-ui/icons' +import CodeMirror, { ReactCodeMirrorRef } from '@uiw/react-codemirror' +import { EditorView } from '@codemirror/view' +import { oneDark } from '@codemirror/theme-one-dark' +import { javascript } from '@codemirror/lang-javascript' +import { python } from '@codemirror/lang-python' +import { cpp } from '@codemirror/lang-cpp' +import { java } from '@codemirror/lang-java' +import { rust } from '@codemirror/lang-rust' + +import SlidingPanel from './SlidingPanel' +import { api } from '../api/client' +import { findSymbolByName, getParser, detectLanguage, type SupportedLanguage } from '../utils/treesitter' +import { githubCache } from '../utils/githubCache' +import { getGithubRepoVisibility } from '../utils/githubApi' +import { parseRepoSlug } from '../utils/url' +import { useSourceEditor } from '../utils/sourceEditor' +import { toast } from '../utils/toast' +import type { PlacedElement } from '../types' const GithubIcon = (props: SVGProps) => ( 0 ? line : null +} export default function CodePreviewPanel({ isOpen, onClose, element, hasBackdrop = true }: Props) { const [code, setCode] = useState('') @@ -80,6 +92,8 @@ export default function CodePreviewPanel({ isOpen, onClose, element, hasBackdrop const [resolvedStartLine, setResolvedStartLine] = useState(null) const [resolvedEndLine, setResolvedEndLine] = useState(null) const [isPrivateRepo, setIsPrivateRepo] = useState(false) + const [openingEditor, setOpeningEditor] = useState(false) + const { editor: sourceEditor } = useSourceEditor() const editorRef = useRef(null) @@ -88,6 +102,10 @@ export default function CodePreviewPanel({ isOpen, onClose, element, hasBackdrop const basePath = hashIdx >= 0 ? filePath.slice(0, hashIdx) : filePath const symbolInfoStr = hashIdx >= 0 ? filePath.slice(hashIdx + 1) : '' const repoSlug = element?.repo ? parseRepoSlug(element.repo) : '' + const anchor = parseAnchor(symbolInfoStr) + const anchorStartLine = anchor.kind === 'lines' ? anchor.startLine : null + const fallbackStartLine = inferLineFromDescription(element?.description, basePath) + const editorStartLine = resolvedStartLine ?? anchorStartLine ?? fallbackStartLine useEffect(() => { if (!isOpen || !element || !repoSlug || !basePath) return @@ -205,9 +223,31 @@ export default function CodePreviewPanel({ isOpen, onClose, element, hasBackdrop const githubUrl = element?.repo && basePath ? `https://github.com/${repoSlug}/blob/${element.branch || 'main'}/${basePath}` - + (resolvedStartLine ? `#L${resolvedStartLine}-L${resolvedEndLine ?? resolvedStartLine}` : '') + + (editorStartLine ? `#L${editorStartLine}-L${resolvedEndLine ?? editorStartLine}` : '') : null + const handleOpenInEditor = async () => { + if (!basePath) return + setOpeningEditor(true) + try { + await api.editor.open({ + editor: sourceEditor, + repo: element?.repo ?? 
'', + file_path: basePath, + line: editorStartLine, + }) + } catch (err) { + toast({ + title: 'Failed to open editor', + description: err instanceof Error ? err.message : String(err), + status: 'error', + duration: 4000, + }) + } finally { + setOpeningEditor(false) + } + } + const getLanguageExtension = () => { const extensions = [customCodeTheme] const effectiveLanguage = element?.language || detectLanguage(basePath) @@ -333,6 +373,40 @@ export default function CodePreviewPanel({ isOpen, onClose, element, hasBackdrop )} + {basePath && ( + + } + size="xs" + variant="outline" + color="whiteAlpha.700" + borderColor="whiteAlpha.200" + h="24px" + px={2.5} + fontSize="11px" + fontWeight="600" + bg="whiteAlpha.50" + isLoading={openingEditor} + onClick={handleOpenInEditor} + _hover={{ + color: 'white', + bg: 'whiteAlpha.100', + borderColor: 'whiteAlpha.400', + textDecoration: 'none', + transform: 'translateY(-0.5px)', + boxShadow: '0 2px 4px rgba(0,0,0,0.2)' + }} + _active={{ + bg: 'whiteAlpha.200', + transform: 'translateY(0)', + }} + transition="all 0.1s" + > + Open in Editor + + + )} diff --git a/frontend/src/components/ConnectorPanel.tsx b/frontend/src/components/ConnectorPanel.tsx index df9daec..11292f0 100644 --- a/frontend/src/components/ConnectorPanel.tsx +++ b/frontend/src/components/ConnectorPanel.tsx @@ -1,6 +1,7 @@ import { memo, useEffect, useRef, useState, useCallback } from 'react' import type { ConnectorPanelSlots } from '../slots' import { + Badge, Box, Button, Divider, @@ -68,6 +69,10 @@ export interface ConnectorPanelProps extends ConnectorPanelSlots { onSave: (connector: Connector) => void autoSave?: boolean onDelete: (edgeId: number) => void + visibilityOverrideDelta?: number + onPromoteVisibility?: (id: number) => Promise | void + onDemoteVisibility?: (id: number) => Promise | void + onResetVisibility?: (id: number) => Promise | void hasBackdrop?: boolean } @@ -77,7 +82,7 @@ export interface ConnectorPanelProps extends ConnectorPanelSlots { * Location: Right side of the screen on desktop. Overlays screen on mobile. * Aliases: Connector Properties, Connector Details. */ -function ConnectorPanel({ isOpen, onClose, connector, orgId, onSave, autoSave = false, onDelete, hasBackdrop = true, connectorPanelAfterContentSlot }: ConnectorPanelProps) { +function ConnectorPanel({ isOpen, onClose, connector, orgId, onSave, autoSave = false, onDelete, visibilityOverrideDelta = 0, onPromoteVisibility, onDemoteVisibility, onResetVisibility, hasBackdrop = true, connectorPanelAfterContentSlot }: ConnectorPanelProps) { const { canEdit, viewId } = useViewEditorContext() const isReadOnly = !canEdit const autoSaveEdit = autoSave && !!connector && !isReadOnly @@ -350,6 +355,32 @@ function ConnectorPanel({ isOpen, onClose, connector, orgId, onSave, autoSave = /> + {connector && (onPromoteVisibility || onDemoteVisibility || onResetVisibility) && ( + + + DENSITY + {visibilityOverrideDelta !== 0 && ( + 0 ? 'teal' : 'orange'} variant="subtle"> + {visibilityOverrideDelta > 0 ? 
`+${visibilityOverrideDelta}` : visibilityOverrideDelta} + + )} + + + + + {visibilityOverrideDelta !== 0 && ( + + )} + + + )} + {connectorPanelAfterContentSlot} diff --git a/frontend/src/components/ContextNeighborElement.tsx b/frontend/src/components/ContextNeighborElement.tsx index a9af157..1323fae 100644 --- a/frontend/src/components/ContextNeighborElement.tsx +++ b/frontend/src/components/ContextNeighborElement.tsx @@ -20,7 +20,7 @@ import { import { ChevronDownIcon, ChevronLeftIcon, ChevronRightIcon, ChevronUpIcon, LinkIcon } from '@chakra-ui/icons' import type { PlacedElement } from '../types' import { TYPE_COLORS } from '../types' -import { resolveIconPath } from '../utils/url' +import { resolveElementIconUrl } from '../utils/elementIcon' import { ElementBody } from './NodeBody' import { ElementContainer } from './NodeContainer' @@ -61,10 +61,7 @@ function ContextNeighborNode({ data }: Props) { const color = TYPE_COLORS[data.kind ?? ''] ?? 'gray' const logoUrl = useMemo(() => { - if (data.logo_url) return resolveIconPath(data.logo_url) - const selected = data.technology_connectors?.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && !!link.slug) - if (!selected?.slug) return undefined - return resolveIconPath(`/icons/${selected.slug}.png`) + return resolveElementIconUrl(data.logo_url, data.technology_connectors) ?? undefined }, [data.logo_url, data.technology_connectors]) const primaryOwnerViewId = data.ownerViewIds[0] ?? data.commonAncestorViewId ?? null diff --git a/frontend/src/components/CrossBranchControls.tsx b/frontend/src/components/CrossBranchControls.tsx index d87b254..1384b1b 100644 --- a/frontend/src/components/CrossBranchControls.tsx +++ b/frontend/src/components/CrossBranchControls.tsx @@ -16,26 +16,29 @@ import { Text, VStack, } from '@chakra-ui/react' -import { CROSS_BRANCH_DEPTH_ALL, CROSS_BRANCH_DEPTH_MAX, CROSS_BRANCH_DEPTH_MIN } from '../crossBranch/types' -import type { CrossBranchContextSettings } from '../crossBranch/types' +import { + CROSS_BRANCH_CONNECTOR_BUDGET_MAX, + CROSS_BRANCH_CONNECTOR_BUDGET_MIN, +} from '../crossBranch/types' +import type { CrossBranchConnectorPriority, CrossBranchContextSettings } from '../crossBranch/types' interface Props { settings: CrossBranchContextSettings onEnabledChange: (enabled: boolean) => void - onDepthChange: (depth: number) => void + onBudgetChange: (budget: number) => void + onPriorityChange: (priority: CrossBranchConnectorPriority) => void label?: string } -function depthLabel(depth: number) { - return depth >= CROSS_BRANCH_DEPTH_ALL ? 'All' : String(depth) -} - export default function CrossBranchControls({ settings, onEnabledChange, - onDepthChange, + onBudgetChange, + onPriorityChange, label = 'Cross-Branch', }: Props) { + const connectorBudget = settings.connectorBudget + return ( @@ -50,7 +53,7 @@ export default function CrossBranchControls({ {label} - {settings.enabled ? depthLabel(settings.depth) : 'Off'} + {settings.enabled ? 
connectorBudget : 'Off'} @@ -71,20 +74,46 @@ export default function CrossBranchControls({ Show cross-branch context onEnabledChange(event.target.checked)} colorScheme="blue" /> + + + Priority + + + {(['external', 'internal'] as const).map((priority) => { + const active = settings.connectorPriority === priority + return ( + + ) + })} + + - Descendant Depth + Connector Budget - {depthLabel(settings.depth)} + {connectorBudget} @@ -92,8 +121,8 @@ export default function CrossBranchControls({ - Near - All + {CROSS_BRANCH_CONNECTOR_BUDGET_MIN} + {CROSS_BRANCH_CONNECTOR_BUDGET_MAX} diff --git a/frontend/src/components/ElementNode.tsx b/frontend/src/components/ElementNode.tsx index ea1309f..3f02815 100644 --- a/frontend/src/components/ElementNode.tsx +++ b/frontend/src/components/ElementNode.tsx @@ -7,7 +7,7 @@ import { useAccentColor } from '../context/ThemeContext' import type { PlacedElement, ViewConnector, Tag } from '../types' import { ElementContainer } from './NodeContainer' import { ElementBody } from './NodeBody' -import { resolveIconPath } from '../utils/url' +import { resolveElementIconUrl } from '../utils/elementIcon' import { ZoomInIcon, ZoomOutIcon, TrashIcon as TrashSvg, EditIcon as EditSvg } from './Icons' import { vscodeBridge } from '../lib/vscodeBridge' import type { ExtensionToWebviewMessage } from '../types/vscode-messages' @@ -155,6 +155,8 @@ interface NodeData extends PlacedElement { selectedHandleIds?: readonly string[] reconnectCandidates?: readonly { handleId: string; edgeId: string; endpoint: 'source' | 'target'; selected: boolean }[] isConnectorHighlighted?: boolean + versionChangeType?: 'added' | 'updated' | 'deleted' | 'initialized' + versionLineDelta?: { added: number; removed: number } } interface Props { @@ -316,13 +318,7 @@ function ElementNode({ data, selected }: Props) { return next }, [data.reconnectCandidates]) - const derivedPrimaryIconPath = (() => { - const selected = data.technology_connectors?.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && !!link.slug) - if (!selected?.slug) return undefined - return resolveIconPath(`/icons/${selected.slug}.png`) - })() - - const nodeLogoUrl = data.logo_url ? resolveIconPath(data.logo_url) : derivedPrimaryIconPath + const nodeLogoUrl = resolveElementIconUrl(data.logo_url, data.technology_connectors) ?? undefined const technologyLinkCount = (data.technology_connectors || []).filter((l) => !!l.label).length const technologyParts = (data.technology || '') @@ -456,6 +452,13 @@ function ElementNode({ data, selected }: Props) { const isTarget = !!data.interactionSourceId && !isSource const bodyCursor = isSource ? 'crosshair' : isTarget ? 'cell' : 'pointer' + const versionColor = data.versionChangeType === 'added' + ? 'green.300' + : data.versionChangeType === 'deleted' + ? 'red.300' + : data.versionChangeType + ? 'yellow.300' + : undefined return ( {HANDLE_CONFIGS.flatMap(({ side, position }) => @@ -702,62 +707,108 @@ function ElementNode({ data, selected }: Props) { )} {/* Code Preview Icon/Link in Bottom Right Corner */} - {((data.repo || data.url) && !window.__TLD_VSCODE__) && ( - - { try { return JSON.parse(data.file_path.split('#')[1]).name } catch { return 'Link' } })() : 'Link'}${data.url ? 
' / URL' : ''}` - : 'Open Link' - } - placement="top" - isDisabled={data.isCanvasMoving} - > - { - e.stopPropagation() - if (data.repo) { - data.onOpenCodePreview?.(data.element_id) - } else if (data.url) { - window.open(data.url, '_blank', 'noopener,noreferrer') - } - }} - onPointerDown={(e: React.PointerEvent) => e.stopPropagation()} + bg="rgba(var(--bg-main-rgb), 0.86)" + border="1px solid" + borderColor="whiteAlpha.300" + boxShadow="0 4px 12px rgba(0,0,0,0.28)" + pointerEvents="none" > - - - - + {data.versionLineDelta.added > 0 && ( + +{data.versionLineDelta.added} + )} + {data.versionLineDelta.removed > 0 && ( + -{data.versionLineDelta.removed} + )} + + )} + {(data.repo || data.url) && !window.__TLD_VSCODE__ && ( + { try { return JSON.parse(data.file_path.split('#')[1]).name } catch { return 'Link' } })() : 'Link'}${data.url ? ' / URL' : ''}` + : 'Open Link' + } + placement="top" + isDisabled={data.isCanvasMoving} + > + { + e.stopPropagation() + if (data.repo) { + data.onOpenCodePreview?.(data.element_id) + } else if (data.url) { + window.open(data.url, '_blank', 'noopener,noreferrer') + } + }} + onPointerDown={(e: React.PointerEvent) => e.stopPropagation()} + > + + + + )} + )} {/* VSCode specific file link with hover preview */} {window.__TLD_VSCODE__ && data.file_path && ( - + {data.versionLineDelta && ( + + {data.versionLineDelta.added > 0 && ( + +{data.versionLineDelta.added} + )} + {data.versionLineDelta.removed > 0 && ( + -{data.versionLineDelta.removed} + )} + + )} - + )} {selected && !isSource && ( diff --git a/frontend/src/components/ElementPanel.tsx b/frontend/src/components/ElementPanel.tsx index c075d6a..03cb46d 100644 --- a/frontend/src/components/ElementPanel.tsx +++ b/frontend/src/components/ElementPanel.tsx @@ -73,8 +73,8 @@ function dedupeTechnologyLinks(links: TechnologyConnector[]): TechnologyConnecto // Sort links to process primary ones first, ensuring they are preserved during deduping const sortedLinks = [...links].sort((a, b) => { - const aPrimary = !!(a.is_primary_icon ?? (a as any).isPrimaryIcon) - const bPrimary = !!(b.is_primary_icon ?? (b as any).isPrimaryIcon) + const aPrimary = !!(a.is_primary_icon ?? a.isPrimaryIcon) + const bPrimary = !!(b.is_primary_icon ?? b.isPrimaryIcon) if (aPrimary && !bPrimary) return -1 if (!aPrimary && bPrimary) return 1 return 0 @@ -84,7 +84,7 @@ function dedupeTechnologyLinks(links: TechnologyConnector[]): TechnologyConnecto const label = link.label.trim() if (!label) continue - const isPrimary = !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) + const isPrimary = !!(link.is_primary_icon ?? link.isPrimaryIcon) if (link.type === 'catalog' && link.slug) { const slug = link.slug.trim() @@ -141,7 +141,7 @@ async function normalizeInitialTechnologyLinks(element: LibraryElement): Promise type: 'catalog', slug: link.slug, label: match?.name ?? link.label, - is_primary_icon: !!(link.is_primary_icon ?? (link as any).isPrimaryIcon), + is_primary_icon: !!(link.is_primary_icon ?? 
link.isPrimaryIcon), }) } else { const parts = splitTechnologyLabel(link.label) @@ -218,6 +218,12 @@ export interface ElementPanelProps extends ElementPanelSlots { autoSave?: boolean onDelete?: (id: number) => void onPermanentDelete?: (id: number) => void + onShowContext?: (id: number) => Promise | void + onHideContext?: (id: number) => Promise | void + visibilityOverrideDelta?: number + onPromoteVisibility?: (id: number) => Promise | void + onDemoteVisibility?: (id: number) => Promise | void + onResetVisibility?: (id: number) => Promise | void orgId?: string links?: ViewConnector[] parentLinks?: ViewConnector[] @@ -231,7 +237,7 @@ export interface ElementPanelProps extends ElementPanelSlots { * Location: Right side of the screen on desktop. Overlays screen on mobile. * Aliases: Element Properties, Element Details. */ -function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDelete, onPermanentDelete, orgId, links = [], parentLinks = [], hasBackdrop = true, availableTags = [], elementPanelAfterContentSlot }: ElementPanelProps) { +function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDelete, onPermanentDelete, onShowContext, onHideContext, visibilityOverrideDelta = 0, onPromoteVisibility, onDemoteVisibility, onResetVisibility, orgId, links = [], parentLinks = [], hasBackdrop = true, availableTags = [], elementPanelAfterContentSlot }: ElementPanelProps) { const { canEdit, viewId } = useViewEditorContext() const isEdit = !!element const isReadOnly = !canEdit @@ -280,7 +286,7 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe const linksFromElement = (element.technology_connectors ?? []).map(tl => ({ ...tl, - is_primary_icon: !!(tl.is_primary_icon ?? (tl as any).isPrimaryIcon), + is_primary_icon: !!(tl.is_primary_icon ?? tl.isPrimaryIcon), })) const fallbackLinks: TechnologyConnector[] = linksFromElement.length > 0 ? linksFromElement @@ -333,14 +339,14 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe }, [element, isOpen]) const buildPayloadAndFingerprint = useCallback(async () => { - const primaryLink = technologyLinks.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && link.slug) + const primaryLink = technologyLinks.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? link.isPrimaryIcon) && link.slug) const primarySlug = primaryLink?.slug const normalizedLinks = technologyLinks.map((link) => ({ type: link.type, slug: link.type === 'catalog' ? link.slug : undefined, label: link.label, - is_primary_icon: !!(link.is_primary_icon ?? (link as any).isPrimaryIcon), + is_primary_icon: !!(link.is_primary_icon ?? link.isPrimaryIcon), })) const normalizedType = type.trim().toLowerCase() @@ -572,7 +578,7 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe scheduleAutoSave() } - const selectedPrimarySlug = technologyLinks.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && !!link.slug)?.slug ?? '' + const selectedPrimarySlug = technologyLinks.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? link.isPrimaryIcon) && !!link.slug)?.slug ?? 
'' const commitTypeFromQuery = () => { if (isReadOnly) return @@ -595,7 +601,7 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe if (isReadOnly || !name.trim()) return setLoading(true) try { - const primaryLink = technologyLinks.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && link.slug) + const primaryLink = technologyLinks.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? link.isPrimaryIcon) && link.slug) const primaryMetadata = primaryLink?.slug ? (technologyMeta[primaryLink.slug] ?? await getTechnologyCatalogItemBySlug(primaryLink.slug)) : null @@ -604,7 +610,7 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe type: link.type, slug: link.type === 'catalog' ? link.slug : undefined, label: link.label, - is_primary_icon: !!(link.is_primary_icon ?? (link as any).isPrimaryIcon), + is_primary_icon: !!(link.is_primary_icon ?? link.isPrimaryIcon), })) const normalizedType = type.trim().toLowerCase() @@ -862,7 +868,7 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe const meta = link.slug ? technologyMeta[link.slug] : undefined const sourceUrl = meta?.websiteUrl || meta?.docsUrl const isSelectable = link.type === 'catalog' && !!link.slug && !isReadOnly - const isPrimaryIcon = link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && !!link.slug + const isPrimaryIcon = link.type === 'catalog' && !!(link.is_primary_icon ?? link.isPrimaryIcon) && !!link.slug return ( @@ -1034,6 +1040,47 @@ function ElementPanel({ isOpen, onClose, element, onSave, autoSave = false, onDe {elementPanelAfterContentSlot} + {element && (onPromoteVisibility || onDemoteVisibility || onResetVisibility) && ( + + + DENSITY + {visibilityOverrideDelta !== 0 && ( + 0 ? 'teal' : 'orange'} variant="subtle"> + {visibilityOverrideDelta > 0 ? `+${visibilityOverrideDelta}` : visibilityOverrideDelta} + + )} + + + + + {visibilityOverrideDelta !== 0 && ( + + )} + + + )} + + {element && (onShowContext || onHideContext) && element.file_path && ( + + {onShowContext && ( + + )} + {onHideContext && ( + + )} + + )} + {isEdit && canEdit && (
-
- {text} -
+ {text && ( +
+ {text} +
+ )} + {proxyBadgeText && ( + + )} + {versionBadgeText && ( +
+ {versionBadgeText} +
+ )}
)} diff --git a/frontend/src/components/ViewExplorer/index.tsx b/frontend/src/components/ViewExplorer/index.tsx index 2cb27a5..c967944 100644 --- a/frontend/src/components/ViewExplorer/index.tsx +++ b/frontend/src/components/ViewExplorer/index.tsx @@ -43,7 +43,7 @@ interface Props { tagColors: Record selectedElement?: LibraryElement | null onUpdateTags?: (elementId: number, tags: string[]) => Promise - onCreateTag: (tag: string, color?: string) => Promise + onCreateTag: (tag: string, color?: string, description?: string) => Promise layers: ViewLayer[] onHoverLayer: (tags: string[] | null, color?: string | null) => void onCreateLayer: (name: string, tags: string[], color: string) => Promise diff --git a/frontend/src/components/ViewFloatingMenu-vscode.tsx b/frontend/src/components/ViewFloatingMenu-vscode.tsx index d20c468..f3be7ba 100644 --- a/frontend/src/components/ViewFloatingMenu-vscode.tsx +++ b/frontend/src/components/ViewFloatingMenu-vscode.tsx @@ -25,6 +25,8 @@ interface ViewFloatingMenuProps { activeTags?: string[] setActiveTags?: (tags: string[]) => void availableTags?: string[] + onShowViewContext?: () => void + onHideViewContext?: () => void } function LayerIcon() { @@ -77,6 +79,8 @@ export default function ViewFloatingMenu({ activeTags = [], setActiveTags, availableTags = [], + onShowViewContext, + onHideViewContext, }: ViewFloatingMenuProps) { return ( + {(onShowViewContext || onHideViewContext) && ( + <> + + {onShowViewContext && ( + + + + )} + {onHideViewContext && ( + + + + )} + + )} + {/* Draw mode toggle */} diff --git a/frontend/src/components/ViewFloatingMenu.tsx b/frontend/src/components/ViewFloatingMenu.tsx index 351ee0c..f832c6f 100644 --- a/frontend/src/components/ViewFloatingMenu.tsx +++ b/frontend/src/components/ViewFloatingMenu.tsx @@ -2,7 +2,7 @@ import React, { memo } from 'react' import type { ViewFloatingMenuSlots } from '../slots' import { - HStack, Tooltip, Button, Box, Text, Popover, PopoverTrigger, Portal, PopoverContent, PopoverBody, IconButton, useDisclosure + HStack, Tooltip, Button, Box, Text, Popover, PopoverTrigger, Portal, PopoverContent, PopoverBody, IconButton, Slider, SliderTrack, SliderFilledTrack, SliderThumb, useDisclosure } from '@chakra-ui/react' import { DownloadIcon } from '@chakra-ui/icons' import { @@ -35,6 +35,10 @@ export interface ViewFloatingMenuProps extends ViewFloatingMenuSlots { onShare?: () => void focusMode: boolean onFocusModeChange: (enabled: boolean) => void + onShowViewContext?: () => void + onHideViewContext?: () => void + densityLevel?: number + onDensityLevelChange?: (level: number) => void // Tag-related props allTags: string[] @@ -73,6 +77,8 @@ function ViewFloatingMenu({ onExport, focusMode, onFocusModeChange, + densityLevel = 0, + onDensityLevelChange, allTags, layers, tagColors, @@ -90,6 +96,11 @@ function ViewFloatingMenu({ }: ViewFloatingMenuProps) { const { canEdit } = useViewEditorContext() const { isOpen: isTagsOpen, onClose: onTagsClose, onToggle: onTagsToggle } = useDisclosure() + const [draftDensityLevel, setDraftDensityLevel] = React.useState(densityLevel) + + React.useEffect(() => { + setDraftDensityLevel(densityLevel) + }, [densityLevel]) return ( )} + {onDensityLevelChange && ( + <> + + + + { + setDraftDensityLevel(value) + onDensityLevelChange(value) + }} + focusThumbOnChange={false} + > + + + + {[-2, -1, 0, 1, 2].map((value) => ( + = value ? 
'var(--accent)' : 'whiteAlpha.400'} + pointerEvents="none" + /> + ))} + + + + + + )} + {/* Draw mode toggle */} { if (!isMenuOpen && !isTooltipOpen) return @@ -114,6 +118,7 @@ export default function ViewGridNode({ data }: { data: ViewGridNodeData }) { }, []) useEffect(() => { + if (isCluster) return if (!hasRequested) return let active = true @@ -141,13 +146,15 @@ export default function ViewGridNode({ data }: { data: ViewGridNodeData }) { URL.revokeObjectURL(url) } } - }, [hasRequested, data.id]) + }, [hasRequested, data.id, isCluster]) const borderColor = data.focused ? accent : 'rgba(255,255,255,0.14)' const boxShadow = data.focused ? `0 0 24px ${hexToRgba(accent, 0.4)}` - : '0 8px 24px rgba(0,0,0,0.4), inset 0 1px 0 rgba(255,255,255,0.05)' + : isCluster + ? '0 14px 34px rgba(0,0,0,0.42), inset 0 1px 0 rgba(255,255,255,0.05)' + : '0 8px 24px rgba(0,0,0,0.4), inset 0 1px 0 rgba(255,255,255,0.05)' return ( // Outer container: sizing + group context, overflow visible for the "New Child" hover button @@ -160,7 +167,26 @@ export default function ViewGridNode({ data }: { data: ViewGridNodeData }) { h="150px" position="relative" userSelect="none" + opacity={data.dimmed ? 0.5 : 1} transition="opacity 0.3s cubic-bezier(0.16, 1, 0.3, 1)" + _before={isCluster ? { + content: '""', + position: 'absolute', + inset: '8px -9px -8px 9px', + borderRadius: '12px', + border: '1px solid rgba(255,255,255,0.08)', + bg: 'rgba(var(--bg-element-rgb), 0.55)', + boxShadow: '0 8px 20px rgba(0,0,0,0.28)', + } : undefined} + _after={isCluster ? { + content: '""', + position: 'absolute', + inset: '16px -18px -16px 18px', + borderRadius: '12px', + border: '1px solid rgba(255,255,255,0.06)', + bg: 'rgba(var(--bg-element-rgb), 0.35)', + boxShadow: '0 8px 20px rgba(0,0,0,0.2)', + } : undefined} > - {thumbnailUrl ? ( + {isCluster ? ( + + {Array.from({ length: Math.min(18, Math.max(6, data.collapsedCount ?? 6)) }).map((_, i) => ( + + ))} + + ) : thumbnailUrl ? ( )} - {!data.isEditing && ( + {!data.isEditing && !isCluster && ( e.stopPropagation()} mt="-2px"> - {data.counts - ? `${data.counts.nodes}n · ${data.counts.edges}e` - : '-'} + {isCluster && data.collapsedCount + ? `${data.collapsedCount} views` + : data.counts + ? 
`${data.counts.nodes}n · ${data.counts.edges}e` + : '-'} diff --git a/frontend/src/components/WorkspacePanel.tsx b/frontend/src/components/WorkspacePanel.tsx new file mode 100644 index 0000000..492f0e5 --- /dev/null +++ b/frontend/src/components/WorkspacePanel.tsx @@ -0,0 +1,880 @@ +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useLocation, useNavigate } from 'react-router-dom' +import { useQueryClient } from '@tanstack/react-query' +import { + Badge, + Box, + Button, + Collapse, + HStack, + IconButton, + Menu, + MenuButton, + MenuList, + MenuItem, + Portal, + Text, + Tooltip, + VStack, +} from '@chakra-ui/react' +import { ChevronDownIcon, ChevronLeftIcon, ChevronRightIcon, ChevronUpIcon, CloseIcon, RepeatIcon, TimeIcon, ViewIcon, ViewOffIcon } from '@chakra-ui/icons' +import { + api, + type WatchDiff, + type WatchEvent, + type WatchLock, + type WatchRepresentationSummary, + type WatchRepository, + type WatchVersion, + type WorkspaceVersion, +} from '../api/client' +import { buildWorkspaceVersionPreview, useWorkspaceVersionPreview } from '../context/WorkspaceVersionContext' +import { + buildWatchDiffLocations, + formatTldStatLine, + summarizeWatchDiffs, + type WatchDiffLocation, + type WatchDiffSummary, +} from '../utils/watchDiffSummary' + +export const WATCH_REPRESENTATION_UPDATED_EVENT = 'tld:watch-representation-updated' + +// ─── Watch helpers ──────────────────────────────────────────────────────────── + +type WatchLine = { + id: number + at: string + text: string + tone: 'info' | 'success' | 'warning' | 'error' +} + +function PauseGlyph() { + return ( + + ) +} + +function summarizeEvent(event: WatchEvent): WatchLine | null { + const id = Date.now() + Math.random() + const at = event.at || new Date().toISOString() + const type = event.type + if (type === 'watch.heartbeat') return null + if (type === 'watch.connected') return { id, at, text: 'Watch stream connected', tone: 'success' } + if (type === 'watch.paused') return { id, at, text: 'Watch paused', tone: 'warning' } + if (type === 'watch.stopped') return { id, at, text: 'Watch stopped', tone: 'warning' } + if (type === 'watch.error') return { id, at, text: event.message || 'Watch error', tone: 'error' } + if (type === 'lock.disabled') return null + if (type === 'lock.enabled') return { id, at, text: 'Workspace locked for watch updates', tone: 'info' } + if (type === 'version.created') return null + if (type === 'representation.updated') { + const data = event.data as Partial | undefined + const changed = [ + data?.views_created ? `views +${data.views_created}` : '', + data?.elements_created || data?.elements_updated ? `elements +${data.elements_created ?? 0}/${data.elements_updated ?? 0}` : '', + data?.connectors_created || data?.connectors_updated ? `connectors +${data.connectors_created ?? 0}/${data.connectors_updated ?? 0}` : '', + ].filter(Boolean).join(', ') + return { id, at, text: changed ? `Workspace updated: ${changed}` : 'Workspace refreshed', tone: 'success' } + } + if (type === 'scan.started') { + const files = event.changed_files ? ` · ${event.changed_files} files` : '' + return { id, at, text: `Scanning${files}`, tone: 'info' } + } + if (type === 'scan.completed') { + const warnings = event.warnings?.length ? ` · ${event.warnings[0]}` : '' + return { id, at, text: `Scan complete${warnings}`, tone: event.warnings?.length ? 
'warning' : 'success' } + } + if (type === 'source.changed') { + const data = event.data as { change?: { path?: string; change_type?: string }; representation_changed?: boolean } | undefined + const path = data?.change?.path ?? 'source file' + const suffix = data?.representation_changed ? 'changed the diagram' : 'did not change the diagram' + return { id, at, text: `${path} ${suffix}`, tone: data?.representation_changed ? 'success' : 'info' } + } + return { id, at, text: type, tone: 'info' } +} + +function shortPath(path: string | undefined): string { + if (!path) return 'repository' + const parts = path.split(/[/\\]/).filter(Boolean) + return parts.slice(-2).join('/') || path +} + +function versionLabel(version: WatchVersion) { + const subject = version.commit_message?.trim() + return subject || `Version ${new Date(version.created_at).toLocaleTimeString()}` +} + +function changeLabel(diffs: WatchDiff[]) { + const summary = summarizeWatchDiffs(diffs) + const total = summary.elements.added + summary.elements.updated + summary.elements.deleted + summary.elements.initialized + + summary.connectors.added + summary.connectors.updated + summary.connectors.deleted + summary.connectors.initialized + return total > 0 ? formatTldStatLine(summary) : 'No materialized changes' +} + +function normalizeDiffs(value: WatchDiff[] | null | undefined): WatchDiff[] { + return Array.isArray(value) ? value : [] +} + +function mergeRepositoryOption(repos: WatchRepository[], repo: WatchRepository | null | undefined): WatchRepository[] { + if (!repo) return repos + const existing = repos.find((item) => item.id === repo.id) + if (existing) { + return repos.map((item) => item.id === repo.id ? { ...item, ...repo } : item) + } + return [repo, ...repos] +} + +function ResourceCountDisplay({ summary }: { summary: WatchDiffSummary }) { + const rows = [ + { label: 'Elements', stat: summary.elements }, + { label: 'Connectors', stat: summary.connectors }, + ] + const total = rows.reduce((sum, row) => ( + sum + row.stat.added + row.stat.updated + row.stat.deleted + row.stat.initialized + ), 0) + const changes = [ + { key: 'added', label: 'added', color: 'green.300' }, + { key: 'updated', label: 'updated', color: 'yellow.300' }, + { key: 'deleted', label: 'deleted', color: 'red.300' }, + { key: 'initialized', label: 'initialized', color: 'blue.300' }, + ] as const + + return ( + + + + Diagram resources + + {total} total + + + {rows.map((row) => ( + + {row.label} + + {changes.map((change) => { + const count = row.stat[change.key] + return count > 0 ? ( + + {count} {change.label} + + ) : null + })} + {row.stat.added + row.stat.updated + row.stat.deleted + row.stat.initialized === 0 && ( + none + )} + + + ))} + + + ) +} + +// ─── Themed dropdown ────────────────────────────────────────────────────────── + +interface ThemedSelectProps { + value: T | '' + options: { value: T; label: string }[] + placeholder?: string + onChange: (value: T | '') => void + isDisabled?: boolean + flex?: number +} + +function ThemedSelect({ value, options, placeholder, onChange, isDisabled, flex }: ThemedSelectProps) { + const selected = options.find((o) => o.value === value) + return ( + + } + size="sm" + variant="ghost" + isDisabled={isDisabled} + flex={flex} + minW={0} + h="32px" + px={3} + fontSize="13px" + fontWeight="500" + color={selected ? 
'gray.100' : 'gray.500'} + bg="whiteAlpha.50" + border="1px solid" + borderColor="whiteAlpha.100" + borderRadius="md" + _hover={{ bg: 'whiteAlpha.100', borderColor: 'whiteAlpha.200' }} + _active={{ bg: 'whiteAlpha.150' }} + textAlign="left" + justifyContent="flex-start" + overflow="hidden" + sx={{ '> span:first-of-type': { overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' } }} + > + {selected?.label ?? placeholder ?? '—'} + + + + {options.length === 0 && ( + No options + )} + {options.map((opt) => ( + onChange(opt.value)} + > + {opt.label} + + ))} + + + + ) +} + +// ─── Main combined panel ────────────────────────────────────────────────────── + +export default function WorkspacePanel() { + const navigate = useNavigate() + const location = useLocation() + const queryClient = useQueryClient() + + // ── Version state ───────────────────────────────────────────────────────── + const { preview, setPreview, clearPreview, requestFollow } = useWorkspaceVersionPreview() + const [versionsOpen, setVersionsOpen] = useState(false) + const [diffVisible, setDiffVisible] = useState(false) + const [repos, setRepos] = useState([]) + const [versions, setVersions] = useState([]) + const [workspaceVersions, setWorkspaceVersions] = useState([]) + const [repoId, setRepoId] = useState('') + const [versionId, setVersionId] = useState('') + const [diffs, setDiffs] = useState([]) + const [diffLocations, setDiffLocations] = useState([]) + const [activeDiffLocationKey, setActiveDiffLocationKey] = useState(null) + const [watchActive, setWatchActive] = useState(false) + const [watchPaused, setWatchPaused] = useState(false) + const [watchRepository, setWatchRepository] = useState(null) + const [watchLock, setWatchLock] = useState(null) + const [watchConnected, setWatchConnected] = useState(false) + const [watcherMode, setWatcherMode] = useState('') + const [languages, setLanguages] = useState([]) + const [watchLines, setWatchLines] = useState([]) + const [runtimeOpen, setRuntimeOpen] = useState(true) + + const repoOptions = useMemo(() => mergeRepositoryOption(repos, watchRepository), [repos, watchRepository]) + const selectedRepo = useMemo(() => { + const selected = repoOptions.find((r) => r.id === repoId) + if (selected) return selected + if (!repoId || watchRepository?.id === repoId) return watchRepository ?? null + return null + }, [repoOptions, repoId, watchRepository]) + const selectedVersion = useMemo(() => versions.find((v) => v.id === versionId) ?? null, [versions, versionId]) + + const selectLatestWatchVersion = useCallback(async (targetRepoId: number) => { + const nextVersions = await api.watch.versions(targetRepoId) + setVersions(nextVersions) + const latest = nextVersions[0] ?? null + setVersionId(latest?.id ?? 
'') + if (!latest) { + setDiffs([]) + return + } + const latestDiffs = await api.watch.diffs(latest.id).catch(() => [] as WatchDiff[]) + setDiffs(normalizeDiffs(latestDiffs)) + }, []) + + const loadVersions = useCallback(async () => { + const [nextRepos, nextWsVersions] = await Promise.all([ + api.watch.repositories().catch(() => [] as WatchRepository[]), + api.versions.list(50).catch(() => [] as WorkspaceVersion[]), + ]) + const mergedRepos = mergeRepositoryOption(nextRepos, watchRepository) + setRepos(mergedRepos) + setWorkspaceVersions(nextWsVersions) + const nextRepoId = repoId || watchRepository?.id || mergedRepos[0]?.id || '' + setRepoId(nextRepoId) + if (nextRepoId) { + const nextVersions = await api.watch.versions(nextRepoId) + setVersions(nextVersions) + setVersionId(versionId || nextVersions[0]?.id || '') + } + }, [repoId, versionId, watchRepository]) + + useEffect(() => { + if (!versionsOpen && !preview) return + void loadVersions() + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [versionsOpen]) + + useEffect(() => { + if (!repoId) { setVersions([]); setVersionId(''); return } + api.watch.versions(repoId).then((next) => { + setVersions(next) + setVersionId(next[0]?.id ?? '') + }).catch(() => { setVersions([]); setVersionId('') }) + }, [repoId]) + + useEffect(() => { + if (!versionId) { setDiffs([]); return } + api.watch.diffs(versionId).then((next) => setDiffs(normalizeDiffs(next))).catch(() => setDiffs([])) + }, [versionId]) + + useEffect(() => { + if (!diffs.length) { + setDiffLocations([]) + return + } + let cancelled = false + api.explore.load().then((data) => { + if (!cancelled) setDiffLocations(buildWatchDiffLocations(data, diffs)) + }).catch(() => { + if (!cancelled) setDiffLocations([]) + }) + return () => { cancelled = true } + }, [diffs]) + + const displayedDiffLocations = useMemo(() => diffLocations.slice(0, 24), [diffLocations]) + const navigableDiffLocations = useMemo(() => { + const elementLocations = diffLocations.filter((target) => target.resourceType === 'element') + return elementLocations.length > 0 ? elementLocations : diffLocations + }, [diffLocations]) + const activeDiffLocationIndex = useMemo(() => { + if (!activeDiffLocationKey) return -1 + const index = navigableDiffLocations.findIndex((target) => target.key === activeDiffLocationKey) + return index >= 0 ? index : -1 + }, [activeDiffLocationKey, navigableDiffLocations]) + + useEffect(() => { + if (!selectedVersion || !diffVisible) { + clearPreview() + return + } + setPreview(buildWorkspaceVersionPreview({ repository: selectedRepo, version: selectedVersion, workspaceVersions, diffs })) + }, [clearPreview, diffVisible, diffs, selectedRepo, selectedVersion, setPreview, workspaceVersions]) + + const navigateToDiffLocation = useCallback((target: WatchDiffLocation) => { + setActiveDiffLocationKey(target.key) + requestFollow({ + resourceType: target.resourceType, + resourceId: target.resourceId, + viewId: target.viewId, + changeType: target.changeType, + }) + if (location.pathname === '/dependencies' && target.resourceType === 'element' && target.resourceId) { + navigate(`/dependencies?element=${target.resourceId}`) + return + } + if (location.pathname.startsWith('/views/') && !location.pathname.startsWith('/views?')) { + const elementQuery = target.resourceType === 'element' && target.resourceId ? `?element=${target.resourceId}` : '' + navigate(`/views/${target.viewId}${elementQuery}`) + return + } + const elementQuery = target.resourceType === 'element' && target.resourceId ? 
`&element=${target.resourceId}` : '' + navigate(`/views?view=explore&focus=${target.viewId}${elementQuery}`) + }, [location.pathname, navigate, requestFollow]) + + const navigateDiffLocationByOffset = useCallback((offset: number) => { + if (navigableDiffLocations.length === 0) return + const nextIndex = activeDiffLocationIndex < 0 + ? offset > 0 ? 0 : navigableDiffLocations.length - 1 + : (activeDiffLocationIndex + offset + navigableDiffLocations.length) % navigableDiffLocations.length + navigateToDiffLocation(navigableDiffLocations[nextIndex]) + }, [activeDiffLocationIndex, navigableDiffLocations, navigateToDiffLocation]) + + const activeVersion = preview?.version ?? selectedVersion + const activeRepo = preview?.repository ?? selectedRepo + const diffSummary = useMemo(() => summarizeWatchDiffs(diffs), [diffs]) + const totalFileChanges = diffSummary.files.added + diffSummary.files.updated + diffSummary.files.deleted + diffSummary.files.initialized + const totalTldChanges = diffSummary.elements.added + diffSummary.elements.updated + diffSummary.elements.deleted + diffSummary.elements.initialized + + diffSummary.connectors.added + diffSummary.connectors.updated + diffSummary.connectors.deleted + diffSummary.connectors.initialized + const activeDiffLocation = activeDiffLocationIndex >= 0 ? navigableDiffLocations[activeDiffLocationIndex] : null + const headerAddedLines = activeDiffLocation?.addedLines ?? diffSummary.elements.addedLines + diffSummary.connectors.addedLines + const headerRemovedLines = activeDiffLocation?.removedLines ?? diffSummary.elements.removedLines + diffSummary.connectors.removedLines + + // ── Watch state ─────────────────────────────────────────────────────────── + const socketRef = useRef(null) + const reconnectTimerRef = useRef(null) + const reconnectAttemptRef = useRef(0) + const lastWatchMessageAtRef = useRef(0) + const socketHealthTimerRef = useRef(null) + const lastRepresentationHashRef = useRef('') + const addLine = useCallback((line: WatchLine | null) => { + if (!line) return + setWatchLines((current) => { + if (current[0]?.text === line.text && current[0]?.tone === line.tone) return current + return [line, ...current].slice(0, 8) + }) + }, []) + + const refreshWorkspace = useCallback((event: WatchEvent) => { + const data = event.data as Partial | undefined + const hash = data?.representation_hash ?? '' + if (hash && hash === lastRepresentationHashRef.current) return + if (hash) lastRepresentationHashRef.current = hash + void queryClient.invalidateQueries({ queryKey: ['workspace', 'views'] }) + void queryClient.invalidateQueries({ queryKey: ['elements', 'list'] }) + window.dispatchEvent(new CustomEvent(WATCH_REPRESENTATION_UPDATED_EVENT, { detail: event })) + }, [queryClient]) + + const handleEvent = useCallback((event: WatchEvent) => { + const eventLock = event.data && typeof event.data === 'object' && 'status' in event.data + ? event.data as WatchLock : null + if (event.repository_id) setWatchLock((current) => eventLock ?? 
current) + if (eventLock) setWatchPaused(eventLock.status === 'paused') + if (event.watcher_mode) setWatcherMode(event.watcher_mode) + if (event.languages?.length) setLanguages(event.languages) + if (event.type === 'watch.paused') setWatchPaused(true) + if (event.type === 'watch.heartbeat') { + setWatchActive(true) + if (eventLock) setWatchPaused(eventLock.status === 'paused') + } + if (event.type === 'watch.stopped') { setWatchActive(false); setWatchPaused(false) } + if (event.type === 'representation.updated') { + const data = event.data as Partial | undefined + if ('diffs' in (data ?? {})) setDiffs(normalizeDiffs(data?.diffs)) + refreshWorkspace(event) + } + if (event.type === 'version.created') { + const version = event.data as Partial | undefined + const targetRepoId = event.repository_id || version?.repository_id || watchLock?.repository_id || watchRepository?.id || 0 + clearPreview() + setDiffs([]) + if (targetRepoId > 0) { + setRepoId(targetRepoId) + void selectLatestWatchVersion(targetRepoId) + } + } + if (event.type !== 'watch.stopped' || watchActive) addLine(summarizeEvent(event)) + }, [watchActive, addLine, clearPreview, refreshWorkspace, selectLatestWatchVersion, watchLock?.repository_id, watchRepository?.id]) + const handleEventRef = useRef(handleEvent) + + useEffect(() => { + handleEventRef.current = handleEvent + }, [handleEvent]) + + useEffect(() => { + let cancelled = false + const poll = async () => { + const status = await api.watch.status().catch(() => null) + if (!status || cancelled) return + setWatchActive(status.active) + setWatchRepository(status.repository ?? null) + setWatchLock(status.lock ?? null) + setWatchPaused(status.lock?.status === 'paused') + if (status.repository) { + setRepos((current) => mergeRepositoryOption(current, status.repository)) + setRepoId((current) => current || status.repository?.id || '') + } + } + void poll() + const interval = window.setInterval(poll, 5000) + return () => { cancelled = true; window.clearInterval(interval) } + }, []) + + useEffect(() => { + let disposed = false + const scheduleReconnect = () => { + if (disposed) return + const delay = Math.min(5000, 1000 * 2 ** Math.min(reconnectAttemptRef.current, 3)) + reconnectAttemptRef.current += 1 + reconnectTimerRef.current = window.setTimeout(connect, delay) + } + const connect = () => { + if (disposed) return + const socket = new WebSocket(api.watch.websocketUrl()) + socketRef.current = socket + lastWatchMessageAtRef.current = Date.now() + socket.onopen = () => { + setWatchConnected(true) + reconnectAttemptRef.current = 0 + lastWatchMessageAtRef.current = Date.now() + addLine({ id: Date.now() + Math.random(), at: new Date().toISOString(), text: 'Watch stream connected', tone: 'success' }) + try { socket.send(JSON.stringify({ type: 'watch.status' })) } catch { /* ignore */ } + } + socket.onclose = () => { + setWatchConnected(false) + if (!disposed) { + addLine({ id: Date.now() + Math.random(), at: new Date().toISOString(), text: 'Watch stream reconnecting', tone: 'warning' }) + scheduleReconnect() + } + } + socket.onerror = () => socket.close() + socket.onmessage = (msg) => { + lastWatchMessageAtRef.current = Date.now() + try { handleEventRef.current(JSON.parse(msg.data) as WatchEvent) } catch { /* ignore */ } + } + } + connect() + socketHealthTimerRef.current = window.setInterval(() => { + const socket = socketRef.current + if (socket?.readyState === WebSocket.OPEN && Date.now() - lastWatchMessageAtRef.current > 10000) { + socket.close() + } + }, 5000) + return () => { + 
disposed = true + if (reconnectTimerRef.current !== null) window.clearTimeout(reconnectTimerRef.current) + if (socketHealthTimerRef.current !== null) window.clearInterval(socketHealthTimerRef.current) + socketRef.current?.close() + socketRef.current = null + } + }, [addLine]) + + const sendControl = useCallback((type: 'watch.pause' | 'watch.resume' | 'watch.stop') => { + const socket = socketRef.current + if (!socket || socket.readyState !== WebSocket.OPEN) return + socket.send(JSON.stringify({ type, repository_id: watchLock?.repository_id ?? watchRepository?.id ?? 0 })) + if (type === 'watch.pause') setWatchPaused(true) + if (type === 'watch.resume') setWatchPaused(false) + if (type === 'watch.stop') setWatchActive(false) + }, [watchLock?.repository_id, watchRepository?.id]) + + const watchStatusColor = !watchActive ? 'gray' : watchPaused ? 'yellow' : watchConnected ? 'green' : 'orange' + const watchStatusLabel = !watchActive ? 'Stopped' : watchPaused ? 'Paused' : 'Live' + const watchTitle = useMemo(() => shortPath(watchRepository?.repo_root), [watchRepository?.repo_root]) + const watchMode = [watcherMode || (watchConnected ? 'live' : 'connecting'), languages.length ? languages.join(', ') : ''].filter(Boolean).join(' · ') + + const showRuntimeSection = watchActive || watchLines.length > 0 + + // ── Render ──────────────────────────────────────────────────────────────── + return ( + + + {/* ── Versions header ── */} + + + + + value={repoId} + placeholder="Repository" + options={repoOptions.map((r) => ({ value: r.id, label: r.display_name || shortPath(r.repo_root) }))} + onChange={(v) => setRepoId(v)} + flex={1} + /> + + value={versionId} + placeholder="Branch" + options={versions.map((v) => ({ + value: v.id, + label: `${v.branch || 'detached'} (${new Date(v.created_at).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' })})`, + }))} + onChange={(v) => setVersionId(v)} + flex={1} + /> + + + {activeVersion && ( + + : } + size="sm" + variant="ghost" + color={diffVisible ? 'var(--accent)' : 'whiteAlpha.700'} + onClick={() => setDiffVisible((visible) => !visible)} + /> + + )} + + : } + size="sm" + variant="ghost" + color={versionsOpen ? 'var(--accent)' : 'whiteAlpha.700'} + onClick={() => setVersionsOpen((v) => !v)} + /> + + + + + + + + +{headerAddedLines} + + + -{headerRemovedLines} + + + {activeDiffLocation + ? `${activeDiffLocationIndex + 1} of ${navigableDiffLocations.length}: ${activeDiffLocation.label}` + : `${totalTldChanges} changed elements`} + + + + + } + size="sm" + variant="solid" + h="32px" + w="32px" + bg="whiteAlpha.200" + _hover={{ bg: 'whiteAlpha.300' }} + _active={{ bg: 'whiteAlpha.400' }} + isDisabled={navigableDiffLocations.length === 0} + onClick={() => navigateDiffLocationByOffset(-1)} + /> + + + } + size="sm" + variant="solid" + h="32px" + w="32px" + bg="whiteAlpha.200" + _hover={{ bg: 'whiteAlpha.300' }} + _active={{ bg: 'whiteAlpha.400' }} + isDisabled={navigableDiffLocations.length === 0} + onClick={() => navigateDiffLocationByOffset(1)} + /> + + + + + + {/* ── Versions body ── */} + + + + + {activeVersion ? 
versionLabel(activeVersion) : 'Repository snapshot'} + + + + {totalFileChanges} files + + +{diffSummary.files.addedLines} + -{diffSummary.files.removedLines} + {workspaceVersions.length} snapshots + + + + + + {displayedDiffLocations.length > 0 && ( + + {displayedDiffLocations.map((target) => ( + + ))} + + )} + + + + {/* ── Runtime section (collapsible) ── */} + {showRuntimeSection && ( + + setRuntimeOpen((v) => !v)} + _hover={{ bg: 'whiteAlpha.50' }} + transition="background 0.15s" + > + + + {watchStatusLabel.toUpperCase()} + + {watchTitle} + {watchMode ? {watchMode} : null} + + e.stopPropagation()}> + {watchActive && ( + <> + + : } + size="sm" + variant="ghost" + color="gray.400" + _hover={{ color: 'white', bg: 'whiteAlpha.100' }} + onClick={() => sendControl(watchPaused ? 'watch.resume' : 'watch.pause')} + /> + + + } + size="sm" + variant="ghost" + color="gray.400" + _hover={{ color: 'white', bg: 'whiteAlpha.100' }} + onClick={() => sendControl('watch.stop')} + /> + + + )} + : } + size="sm" + variant="ghost" + color="gray.400" + _hover={{ color: 'white', bg: 'whiteAlpha.100' }} + onClick={() => setRuntimeOpen((v) => !v)} + /> + + + + + + {watchLines.length === 0 ? ( + Waiting for watch output… + ) : watchLines.map((line) => ( + + + {new Date(line.at).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit', second: '2-digit' })} + + + {line.text} + + + ))} + + + + )} + + + ) +} diff --git a/frontend/src/components/ZUI/ZUICanvas.tsx b/frontend/src/components/ZUI/ZUICanvas.tsx index d427cd2..f97f33f 100644 --- a/frontend/src/components/ZUI/ZUICanvas.tsx +++ b/frontend/src/components/ZUI/ZUICanvas.tsx @@ -27,17 +27,29 @@ import { Link as RouterLink } from 'react-router-dom' import { ExternalLinkIcon } from '@chakra-ui/icons' import type { ExploreData } from '../../types' import { computeLayout } from './layout' -import { renderFrame, getExpandThresholds, setOnImageLoadCallback, setHighlightedTags as setRendererHighlightedTags, setHiddenTags as setRendererHiddenTags, setHighlightColor as setRendererHighlightColor } from './renderer' +import { renderFrame, getExpandThresholds, getCameraRebase, rawCameraView, screenToWorldX, screenToWorldY, worldToScreenX, worldToScreenY, setOnImageLoadCallback, setHighlightedTags as setRendererHighlightedTags, setHiddenTags as setRendererHiddenTags, setHighlightColor as setRendererHighlightColor, setVersionDiff as setRendererVersionDiff } from './renderer' import { useZUIInteraction } from './useZUIInteraction' import type { DiagramGroupLayout, ZUIViewState } from './types' +import { findDiagramFocusTarget, findElementFocusTarget, viewportForDiagramFocusTarget, viewportForElementFocusTarget } from './focus' import { buildWorkspaceGraphSnapshot } from '../../crossBranch/graph' import type { CrossBranchContextSettings } from '../../crossBranch/types' -import type { ZUIResolvedConnector } from '../../crossBranch/resolve' -import { buildVisibleProxyConnectors, collectVisibleNodeAnchors, drawVisibleProxyConnectors, findHoveredProxyConnector } from './proxy' +import { DEFAULT_MIN_CONNECTOR_ANCHOR_ALPHA } from '../../crossBranch/settings' +import type { WorkspaceVersionFollowTarget, WorkspaceVersionPreview } from '../../context/WorkspaceVersionContext' +import { + buildProxyConnectorSpatialIndex, + buildVisibleProxyConnectors, + collectVisibleNodeAnchors, + drawVisibleDirectProxyBadges, + drawVisibleProxyConnectors, + findHoveredProxyConnector, + type ProxyConnectorSpatialIndex, + type VisibleNodeAnchor, +} from './proxy' export interface ZUICanvasHandle { 
fitView(): void focusDiagram(viewId: number): boolean + focusElement(viewId: number, elementId: number): boolean setCameraFrame(frame: ZUICameraFrame): boolean } @@ -55,6 +67,8 @@ interface Props { highlightedTags?: string[] highlightColor?: string hiddenTags?: string[] + versionPreview?: WorkspaceVersionPreview | null + versionFollowTarget?: WorkspaceVersionFollowTarget | null crossBranchSettings: CrossBranchContextSettings hoverLocked?: boolean } @@ -71,6 +85,26 @@ interface PathItem { absH: number } +function rebaseVisibleNodeAnchors( + anchors: Map, + originX: number, + originY: number, +): Map { + const rebased = new Map() + for (const [nodeId, anchor] of anchors) { + rebased.set(nodeId, { + ...anchor, + worldX: anchor.worldX - originX, + worldY: anchor.worldY - originY, + }) + } + return rebased +} + +function anchorViewForZoom(zoom: number): ZUIViewState { + return { x: 0, y: 0, zoom: Math.max(0.0001, zoom) } +} + function getPathAt( view: ZUIViewState, groups: DiagramGroupLayout[], @@ -80,8 +114,8 @@ function getPathAt( if (canvasW === 0 || canvasH === 0) return [] // World center of the screen - const worldCenterX = (canvasW / 2 - view.x) / view.zoom - const worldCenterY = (canvasH / 2 - view.y) / view.zoom + const worldCenterX = screenToWorldX(canvasW / 2, view) + const worldCenterY = screenToWorldY(canvasH / 2, view) const thresholds = getExpandThresholds(canvasW) for (const group of groups) { @@ -177,70 +211,6 @@ function getPathAt( return [] } -function findDiagramFocusTarget(groups: DiagramGroupLayout[], viewId: number): PathItem | null { - for (const group of groups) { - if (group.diagramId === viewId) { - return { - id: `g-${group.diagramId}`, - label: group.label, - type: 'group', - absX: group.worldX, - absY: group.worldY, - absW: group.worldW, - absH: group.worldH, - } - } - - const found = findLinkedDiagramInNodes(viewId, group.nodes, 0, 0, 1, 0, 0) - if (found) return found - } - return null -} - -function findLinkedDiagramInNodes( - viewId: number, - nodes: DiagramGroupLayout['nodes'], - parentAbsX: number, - parentAbsY: number, - parentAbsScale: number, - parentChildOffsetX: number, - parentChildOffsetY: number, -): PathItem | null { - for (const node of nodes) { - const absX = parentAbsX + (node.worldX - parentChildOffsetX) * parentAbsScale - const absY = parentAbsY + (node.worldY - parentChildOffsetY) * parentAbsScale - const absW = node.worldW * parentAbsScale - const absH = node.worldH * parentAbsScale - - if (node.linkedDiagramId === viewId) { - return { - id: node.id, - label: node.linkedDiagramLabel || node.label, - type: 'node', - isCircular: node.isCircular, - absX, - absY, - absW, - absH, - } - } - - if (node.children.length > 0) { - const found = findLinkedDiagramInNodes( - viewId, - node.children, - absX, - absY, - parentAbsScale * node.childScale, - node.childOffsetX, - node.childOffsetY, - ) - if (found) return found - } - } - return null -} - function easeOutQuart(t: number): number { return 1 - Math.pow(1 - t, 4) } @@ -320,25 +290,27 @@ function findFirstExpandableNodeInTree( return null } -export const ZUICanvas = forwardRef(function ZUICanvas({ data, onReady, onZoom, onPan, initialCameraFrame, highlightedTags, highlightColor, hiddenTags, crossBranchSettings, hoverLocked = false }, ref) { +export const ZUICanvas = forwardRef(function ZUICanvas({ data, onReady, onZoom, onPan, initialCameraFrame, highlightedTags, highlightColor, hiddenTags, versionPreview, versionFollowTarget, crossBranchSettings, hoverLocked = false }, ref) { const canvasRef = 
useRef(null) const containerRef = useRef(null) const cameraTransitionRef = useRef(null) const [initialized, setInitialized] = useState(false) const [containerSize, setContainerSize] = useState({ w: 0, h: 0 }) const isMobileLayout = useBreakpointValue({ base: true, md: false }) ?? false + const debugViewport = useMemo(() => typeof window !== 'undefined' && new URLSearchParams(window.location.search).has('debugZuiCamera'), []) // ── Layout ────────────────────────────────────────────────────── const layout = useMemo(() => computeLayout(data), [data]) const workspaceSnapshot = useMemo(() => buildWorkspaceGraphSnapshot(data), [data]) - // Holds the most-recently resolved connector topology so hover detection can - // use it without re-running the expensive O(connectors) resolution on every mousemove. - const proxyConnectorsRef = useRef([]) - - const resolveHoveredProxyItem = useCallback((worldX: number, worldY: number, view: ZUIViewState, canvasW: number) => { - const freshAnchors = collectVisibleNodeAnchors(layout.groups, view, canvasW, hiddenTags) - return findHoveredProxyConnector(worldX, worldY, proxyConnectorsRef.current, freshAnchors.byNodeId, view) - }, [hiddenTags, layout.groups]) + // Holds the latest proxy hover index so mousemove can query it without + // rebuilding anchors or connector geometry. + const proxyHoverIndexRef = useRef(null) + + const resolveHoveredProxyItem = useCallback((worldX: number, worldY: number, view: ZUIViewState) => { + const index = proxyHoverIndexRef.current + if (!index) return null + return findHoveredProxyConnector(worldX, worldY, index, view) + }, []) // ── Interaction ───────────────────────────────────────────────── const { viewState, viewStateRef, setViewState, fitView, maxZoom, hoveredItem, setHoveredItem, setHoverLocked } = useZUIInteraction( @@ -352,42 +324,100 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ resolveHoveredProxyItem, ) - // Anchor positions recompute every render (fast tree traversal, view-dependent). + // Anchor positions are zoom-dependent, but not pan-dependent. Keeping pan out + // of this memo avoids re-walking the ZUI tree during drag/trackpad movement. 
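// ── Illustrative sketch (not part of this patch) ─────────────────────────────
// Why pan can stay out of the anchor memo below: anchors live in world space,
// and the pan offset only enters when a world coordinate is mapped to the
// screen. These helpers mirror the inline math this diff replaces
// ((screenX - view.x) / view.zoom and worldX * view.zoom + view.x); the names
// are local to the sketch, not the renderer's exported API.
interface SketchView { x: number; y: number; zoom: number }

const sketchScreenToWorldX = (screenX: number, view: SketchView): number =>
  (screenX - view.x) / view.zoom

const sketchWorldToScreenX = (worldX: number, view: SketchView): number =>
  worldX * view.zoom + view.x

// Panning changes only view.x/view.y, so a world-space anchor computed once per
// zoom level stays valid; only this cheap per-frame mapping depends on pan.
const sketchView: SketchView = { x: -120, y: 40, zoom: 2 }
const sketchWorldX = sketchScreenToWorldX(600, sketchView)            // (600 - (-120)) / 2 = 360
const sketchScreenX = sketchWorldToScreenX(sketchWorldX, sketchView)  // 360 * 2 + (-120) = 600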
const anchors = useMemo(() => - collectVisibleNodeAnchors(layout.groups, viewState, containerSize.w || 1, hiddenTags), - [layout.groups, viewState, containerSize.w, hiddenTags], + collectVisibleNodeAnchors(layout.groups, anchorViewForZoom(viewState.zoom), containerSize.w || 1, hiddenTags), + [layout.groups, viewState.zoom, containerSize.w, hiddenTags], ) + const viewportBounds = useMemo(() => { + const zoom = Math.max(0.0001, viewState.zoom) + const stableView = { ...viewState, zoom } + const minX = screenToWorldX(0, stableView) + const minY = screenToWorldY(0, stableView) + const maxX = screenToWorldX(containerSize.w, stableView) + const maxY = screenToWorldY(containerSize.h, stableView) + return { + minX, + minY, + maxX, + maxY, + centerX: (minX + maxX) / 2, + centerY: (minY + maxY) / 2, + } + }, [containerSize.h, containerSize.w, viewState]) + + useEffect(() => { + if (!debugViewport) return + const cameraRebase = getCameraRebase(viewState, containerSize.w, containerSize.h) + console.debug('[ZUICanvas] viewport', { + x: viewState.x, + y: viewState.y, + zoom: viewState.zoom, + width: containerSize.w, + height: containerSize.h, + minX: viewportBounds.minX, + minY: viewportBounds.minY, + maxX: viewportBounds.maxX, + maxY: viewportBounds.maxY, + centerX: viewportBounds.centerX, + centerY: viewportBounds.centerY, + renderX: cameraRebase.view.x, + renderY: cameraRebase.view.y, + renderOriginX: cameraRebase.originX, + renderOriginY: cameraRebase.originY, + }) + }, [containerSize.h, containerSize.w, debugViewport, viewState, viewportBounds]) + // A stable string key encoding which element→nodeId pairs are currently visible. // This only changes when nodes cross zoom-expansion thresholds not on every pan pixel. const visibleElementSig = useMemo(() => Array.from(anchors.visibleAnchors.entries()) .sort(([a], [b]) => a - b) - .map(([id, anchor]) => `${id}:${anchor.nodeId}`) + .map(([id, anchor]) => `${id}:${anchor.nodeId}:${anchor.renderAlpha >= (crossBranchSettings.minConnectorAnchorAlpha ?? DEFAULT_MIN_CONNECTOR_ANCHOR_ALPHA) ? 1 : 0}`) .join(','), - [anchors.visibleAnchors], + [anchors.visibleAnchors, crossBranchSettings.minConnectorAnchorAlpha], ) - - // Connector topology: expensive O(connectors) resolution only when visibility set changes. + const proxySettingsSig = [ + crossBranchSettings.enabled, + crossBranchSettings.depth, + crossBranchSettings.connectorBudget, + crossBranchSettings.connectorPriority, + crossBranchSettings.minConnectorAnchorAlpha ?? '', + crossBranchSettings.maxProxyConnectorGroups ?? '', + ].join(':') + + // Connector topology follows visible anchor identity, not camera position. + // Continuous pan/zoom renders reuse the previous topology until zoom changes + // which elements have visible/eligible anchors. 
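// ── Illustrative sketch (not part of this patch) ─────────────────────────────
// The signature above reduces "which anchors are currently visible (and above
// the alpha cutoff)" to a cheap string key, so the O(connectors) resolution
// below only reruns when that set changes, not on every pan frame. The shapes
// here are simplified stand-ins, not the component's real anchor type.
function sketchVisibilitySignature(
  visible: Map<number, { nodeId: string; renderAlpha: number }>,
  minAlpha: number,
): string {
  return Array.from(visible.entries())
    .sort(([a], [b]) => a - b)
    .map(([elementId, anchor]) => `${elementId}:${anchor.nodeId}:${anchor.renderAlpha >= minAlpha ? 1 : 0}`)
    .join(',')
}

// Two cameras that pan but keep the same visible anchors yield the same key,
// so a memo keyed on the signature reuses the previously resolved topology.
const sketchSig = sketchVisibilitySignature(
  new Map([
    [7, { nodeId: 'n-1-7', renderAlpha: 0.9 }],
    [12, { nodeId: 'n-1-12', renderAlpha: 0.2 }],
  ]),
  0.35,
)
// sketchSig === '7:n-1-7:1,12:n-1-12:0'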
const proxyConnectors = useMemo(() => { const resolved = buildVisibleProxyConnectors(workspaceSnapshot, anchors.visibleAnchors, crossBranchSettings) - proxyConnectorsRef.current = resolved return resolved // eslint-disable-next-line react-hooks/exhaustive-deps - }, [workspaceSnapshot, visibleElementSig, crossBranchSettings]) + }, [workspaceSnapshot, visibleElementSig, proxySettingsSig]) + + const proxyHoverIndex = useMemo(() => ( + buildProxyConnectorSpatialIndex(proxyConnectors.connectors, anchors.byNodeId) + ), [proxyConnectors.connectors, anchors.byNodeId]) + proxyHoverIndexRef.current = proxyHoverIndex const visibleProxyState = useMemo(() => ({ ...anchors, - proxyConnectors, + proxyConnectors: proxyConnectors.connectors, + hiddenProxyBadges: proxyConnectors.hiddenBadges, }), [anchors, proxyConnectors]) const visibleProxyStateRef = useRef(visibleProxyState) visibleProxyStateRef.current = visibleProxyState const labelBgRef = useRef('#171923') + const accentRef = useRef('#63b3ed') useEffect(() => { const update = () => { - labelBgRef.current = getComputedStyle(document.documentElement).getPropertyValue('--chakra-colors-gray-900').trim() || '#171923' + const styles = getComputedStyle(document.documentElement) + labelBgRef.current = styles.getPropertyValue('--chakra-colors-gray-900').trim() || '#171923' + accentRef.current = styles.getPropertyValue('--accent').trim() || '#63b3ed' needsRedrawRef.current = true } update() @@ -415,8 +445,8 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ absY = g.worldY + g.diagramY - absH } - const sx = absX * viewState.zoom + viewState.x - const sy = absY * viewState.zoom + viewState.y + const sx = worldToScreenX(absX, viewState) + const sy = worldToScreenY(absY, viewState) const sw = absW * viewState.zoom const sh = absH * viewState.zoom @@ -473,36 +503,13 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ setViewState({ x, y, zoom }) }, [containerSize, maxZoom, setViewState, setHoveredItem]) - const focusDiagram = useCallback((viewId: number) => { - const el = containerRef.current - const target = findDiagramFocusTarget(layout.groups, viewId) - if (!el || !target) return false - - const canvasW = el.offsetWidth - const canvasH = el.offsetHeight - if (canvasW === 0 || canvasH === 0) return false - - setHoveredItem(null, true) - - const padding = isMobileLayout ? 
0.18 : 0.16 - const bboxW = Math.max(1, target.absW) - const bboxH = Math.max(1, target.absH) - const zoom = Math.min( - (canvasW * (1 - padding * 2)) / bboxW, - (canvasH * (1 - padding * 2)) / bboxH, - maxZoom, - ) - - const x = (canvasW - bboxW * zoom) / 2 - target.absX * zoom - const y = (canvasH - bboxH * zoom) / 2 - target.absY * zoom - + const animateToViewport = useCallback((to: ZUIViewState) => { if (cameraTransitionRef.current !== null) { cancelAnimationFrame(cameraTransitionRef.current) cameraTransitionRef.current = null } - const from = viewStateRef.current - const to = { x, y, zoom } + const from = rawCameraView(viewStateRef.current) const duration = 520 const startedAt = performance.now() @@ -524,8 +531,41 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ } cameraTransitionRef.current = requestAnimationFrame(step) + }, [setViewState, viewStateRef]) + + const focusDiagram = useCallback((viewId: number) => { + const el = containerRef.current + const target = findDiagramFocusTarget(layout.groups, viewId) + if (!el || !target) return false + + const canvasW = el.offsetWidth + const canvasH = el.offsetHeight + if (canvasW === 0 || canvasH === 0) return false + + const to = viewportForDiagramFocusTarget(target, canvasW, canvasH, maxZoom, isMobileLayout) + if (!to) return false + + setHoveredItem(null, true) + animateToViewport(to) + return true + }, [animateToViewport, isMobileLayout, layout.groups, maxZoom, setHoveredItem]) + + const focusElement = useCallback((viewId: number, elementId: number) => { + const el = containerRef.current + const target = findElementFocusTarget(layout.groups, viewId, elementId) + if (!el || !target) return false + + const canvasW = el.offsetWidth + const canvasH = el.offsetHeight + if (canvasW === 0 || canvasH === 0) return false + + const to = viewportForElementFocusTarget(target, canvasW, canvasH, maxZoom, isMobileLayout) + if (!to) return false + + setHoveredItem(null, true) + animateToViewport(to) return true - }, [isMobileLayout, layout.groups, maxZoom, setHoveredItem, setViewState, viewStateRef]) + }, [animateToViewport, isMobileLayout, layout.groups, maxZoom, setHoveredItem]) const setCameraFrame = useCallback((frame: ZUICameraFrame) => { if (frame.profile !== 'detail-to-overview') return false @@ -625,9 +665,10 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ fitView(el.offsetWidth, el.offsetHeight, layout.bbox) }, focusDiagram, + focusElement, setCameraFrame, }), - [fitView, focusDiagram, layout.bbox, setCameraFrame, setHoveredItem], + [fitView, focusDiagram, focusElement, layout.bbox, setCameraFrame, setHoveredItem], ) // ── RAF render loop ────────────────────────────────────────────── @@ -719,14 +760,29 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ ctx.save() ctx.setTransform(dpr, 0, 0, dpr, 0, 0) const occupiedLabelRects = renderFrame(ctx, layout.groups, currentView, w, h) + const cameraRebase = getCameraRebase(currentView, w, h) + const rebasedProxyAnchors = rebaseVisibleNodeAnchors( + visibleProxyStateRef.current.byNodeId, + cameraRebase.originX, + cameraRebase.originY, + ) ctx.save() - ctx.translate(currentView.x, currentView.y) - ctx.scale(currentView.zoom, currentView.zoom) + ctx.translate(cameraRebase.view.x, cameraRebase.view.y) + ctx.scale(cameraRebase.view.zoom, cameraRebase.view.zoom) drawVisibleProxyConnectors( ctx, visibleProxyStateRef.current.proxyConnectors, - visibleProxyStateRef.current.byNodeId, - currentView.zoom, + rebasedProxyAnchors, + cameraRebase.view.zoom, + labelBgRef.current, + 
accentRef.current, + occupiedLabelRects, + ) + drawVisibleDirectProxyBadges( + ctx, + visibleProxyStateRef.current.hiddenProxyBadges, + rebasedProxyAnchors, + cameraRebase.view.zoom, labelBgRef.current, occupiedLabelRects, ) @@ -761,6 +817,30 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ needsRedrawRef.current = true }, [hiddenTags]) + useEffect(() => { + const pulsedElementChanges = new Map() + const pulsedElementLineDeltas = new Map() + if (versionFollowTarget?.resourceType === 'element' && versionFollowTarget.resourceId) { + const change = versionFollowTarget.changeType ?? versionPreview?.elementChanges.get(versionFollowTarget.resourceId) + if (change) pulsedElementChanges.set(versionFollowTarget.resourceId, change) + } + setRendererVersionDiff( + pulsedElementChanges, + versionPreview?.connectorChanges ?? new Map(), + versionPreview?.elementLineDeltas ?? pulsedElementLineDeltas, + ) + needsRedrawRef.current = true + }, [versionPreview, versionFollowTarget]) + + useEffect(() => { + if (!initialized || !versionFollowTarget?.viewId) return + if (versionFollowTarget.resourceType === 'element' && versionFollowTarget.resourceId) { + focusElement(versionFollowTarget.viewId, versionFollowTarget.resourceId) + return + } + focusDiagram(versionFollowTarget.viewId) + }, [focusDiagram, focusElement, initialized, versionFollowTarget?.resourceId, versionFollowTarget?.resourceType, versionFollowTarget?.token, versionFollowTarget?.viewId]) + useEffect(() => { setHoverLocked(hoverLocked) }, [hoverLocked, setHoverLocked]) @@ -771,6 +851,7 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ setRendererHighlightedTags(new Set()) setRendererHighlightColor('') setRendererHiddenTags(new Set()) + setRendererVersionDiff(new Map(), new Map()) } }, []) @@ -957,6 +1038,19 @@ export const ZUICanvas = forwardRef(function ZUICanvas({ {hoveredItem.data.details.label} + + UNDERLYING PATHS + {hoveredItem.data.details.connectors.slice(0, 4).map((leaf, index) => ( + + {leaf.source.actualElementName} → {leaf.target.actualElementName} + + ))} + {hoveredItem.data.details.connectors.length > 4 && ( + + +{hoveredItem.data.details.connectors.length - 4} more + + )} + {hoveredItem.data.details.ownerViewIds.map((ownerViewId, index) => ( diff --git a/frontend/src/components/ZUI/focus.test.ts b/frontend/src/components/ZUI/focus.test.ts new file mode 100644 index 0000000..ccc1f40 --- /dev/null +++ b/frontend/src/components/ZUI/focus.test.ts @@ -0,0 +1,534 @@ +import { describe, expect, it } from 'vitest' +import { computeLayout } from './layout' +import { findDiagramFocusTarget, findElementFocusTarget, viewportForDiagramFocusTarget, viewportForElementFocusTarget, viewportForFocusTarget, type ZUIFocusTarget } from './focus' +import { calculateMaxZoom, constrainViewState } from './useZUIInteraction' +import { buildCameraTransitionRebase, findFocusedFlattenedLayerForTest, getCameraRebase, getExpandThresholds, rawCameraView, worldToScreenX, worldToScreenY } from './renderer' +import type { ExploreData, PlacedElement, ViewConnector, ViewTreeNode } from '../../types' +import type { DiagramGroupLayout, LayoutNode, ZUIViewState } from './types' + +function treeNode(id: number, name: string, ownerElementId: number | null, parentViewId: number | null, children: ViewTreeNode[] = []): ViewTreeNode { + return { + id, + owner_element_id: ownerElementId, + name, + description: null, + level_label: null, + level: 0, + depth: parentViewId == null ? 
0 : 1, + created_at: '2024-01-01', + updated_at: '2024-01-01', + parent_view_id: parentViewId, + children, + } +} + +function placed(viewId: number, elementId: number, x: number, y: number, hasView = false): PlacedElement { + return { + id: viewId * 1000 + elementId, + view_id: viewId, + element_id: elementId, + position_x: x, + position_y: y, + name: `Element ${elementId}`, + description: null, + kind: 'service', + technology: null, + url: null, + logo_url: null, + technology_connectors: [], + tags: [], + has_view: hasView, + view_label: null, + } +} + +function testNode(id: string, children: LayoutNode[] = [], childScale = 0.8): LayoutNode { + return { + id, + elementId: Number(id.replace(/\D/g, '')) || 1, + diagramId: 1, + worldX: 0, + worldY: 0, + worldW: 180, + worldH: 85, + label: id, + type: 'service', + logoUrl: null, + description: null, + technology: null, + tags: [], + ancestorElementIds: [], + pathElementIds: [], + children, + childScale, + childOffsetX: 0, + childOffsetY: 0, + edgesOut: [], + } +} + +function testGroup(nodes: LayoutNode[]): DiagramGroupLayout { + return { + diagramId: 1, + label: 'Root', + description: null, + level: 0, + levelLabel: null, + worldX: 0, + worldY: 0, + worldW: 180, + worldH: 85, + diagramW: 180, + diagramH: 85, + diagramX: 0, + diagramY: 0, + nodes, + edges: [], + } +} + +function navigation(fromViewId: number, elementId: number, toViewId: number): ViewConnector { + return { + id: toViewId, + element_id: elementId, + from_view_id: fromViewId, + to_view_id: toViewId, + to_view_name: `View ${toViewId}`, + relation_type: 'child', + } +} + +function nestedExploreData(): ExploreData { + return { + tree: [ + treeNode(1, 'Root', null, null, [ + treeNode(2, 'Second', 101, 1, [ + treeNode(3, 'Third', 201, 2, [ + treeNode(4, 'Fourth', 301, 3), + ]), + ]), + ]), + ], + views: { + 1: { placements: [placed(1, 101, 120, 100, true)], connectors: [] }, + 2: { placements: [placed(2, 201, 200, 160, true)], connectors: [] }, + 3: { placements: [placed(3, 301, 300, 220, true)], connectors: [] }, + 4: { + placements: [ + placed(4, 401, 40, 60), + placed(4, 499, 10_000, 8_000), + ], + connectors: [], + }, + }, + navigations: [ + navigation(1, 101, 2), + navigation(2, 201, 3), + navigation(3, 301, 4), + ], + } +} + +function deepSingleChainExploreData(depth: number): ExploreData { + const treeById = new Map() + for (let viewId = depth; viewId >= 1; viewId -= 1) { + treeById.set( + viewId, + treeNode( + viewId, + `View ${viewId}`, + viewId === 1 ? null : 1000 + viewId - 1, + viewId === 1 ? null : viewId - 1, + viewId < depth ? [treeById.get(viewId + 1)!] : [], + ), + ) + } + + const views: ExploreData['views'] = {} + const navigations: ViewConnector[] = [] + for (let viewId = 1; viewId <= depth; viewId += 1) { + const elementId = viewId === depth ? 9001 : 1000 + viewId + views[viewId] = { + placements: [placed(viewId, elementId, viewId * 15, viewId * 10, viewId < depth)], + connectors: [], + } + if (viewId < depth) { + navigations.push(navigation(viewId, elementId, viewId + 1)) + } + } + + return { + tree: [treeById.get(1)!], + views, + navigations, + } +} + +function focusMatrixExploreData(depth: number): ExploreData { + const treeById = new Map() + for (let viewId = depth; viewId >= 1; viewId -= 1) { + treeById.set( + viewId, + treeNode( + viewId, + `Matrix View ${viewId}`, + viewId === 1 ? null : 10_000 + viewId - 1, + viewId === 1 ? null : viewId - 1, + viewId < depth ? [treeById.get(viewId + 1)!] 
: [], + ), + ) + } + + const views: ExploreData['views'] = {} + const navigations: ViewConnector[] = [] + for (let viewId = 1; viewId <= depth; viewId += 1) { + const childOwnerId = 10_000 + viewId + const childBearing = viewId < depth + ? [placed(viewId, childOwnerId, viewId % 2 === 0 ? 1600 : -1500, viewId % 3 === 0 ? -1250 : 1350, true)] + : [] + views[viewId] = { + placements: [ + ...childBearing, + placed(viewId, viewId * 100 + 1, -2600 + viewId * 37, 1800 - viewId * 53), + placed(viewId, viewId * 100 + 2, 2400 - viewId * 41, -2100 + viewId * 47), + placed(viewId, viewId * 100 + 3, viewId * 180, -viewId * 140), + ], + connectors: [], + } + if (viewId < depth) { + navigations.push(navigation(viewId, childOwnerId, viewId + 1)) + } + } + + return { + tree: [treeById.get(1)!], + views, + navigations, + } +} + +function placementsIn(data: ExploreData): Array<{ viewId: number; elementId: number }> { + return Object.entries(data.views).flatMap(([viewIdText, content]) => { + const viewId = Number(viewIdText) + return (content.placements ?? []).map((placement) => ({ viewId, elementId: placement.element_id })) + }) +} + +function viewsIn(data: ExploreData): number[] { + return Object.keys(data.views).map(Number).filter(Number.isFinite) +} + +function screenRect(target: ZUIFocusTarget, viewport: ZUIViewState) { + return { + left: worldToScreenX(target.absX, viewport), + top: worldToScreenY(target.absY, viewport), + right: worldToScreenX(target.absX + target.absW, viewport), + bottom: worldToScreenY(target.absY + target.absH, viewport), + width: target.absW * viewport.zoom, + height: target.absH * viewport.zoom, + } +} + +function worldScreenRect(rect: { x: number; y: number; w: number; h: number }, viewport: ZUIViewState) { + return { + left: worldToScreenX(rect.x, viewport), + top: worldToScreenY(rect.y, viewport), + right: worldToScreenX(rect.x + rect.w, viewport), + bottom: worldToScreenY(rect.y + rect.h, viewport), + width: rect.w * viewport.zoom, + height: rect.h * viewport.zoom, + } +} + +function interpolateViewState(from: ZUIViewState, to: ZUIViewState, t: number): ZUIViewState { + return { + x: from.x + (to.x - from.x) * t, + y: from.y + (to.y - from.y) * t, + zoom: from.zoom + (to.zoom - from.zoom) * t, + } +} + +function completeFocusNavigationFromCurrent( + current: ZUIViewState, + destination: ZUIViewState, + canvasW: number, + canvasH: number, + bbox: { minX: number; minY: number; maxX: number; maxY: number }, +): ZUIViewState { + ;[0.15, 0.5, 0.85].forEach((t) => { + constrainViewState(interpolateViewState(current, destination, t), canvasW, canvasH, bbox) + }) + return constrainViewState(destination, canvasW, canvasH, bbox) +} + +function expectFiniteViewport(viewport: ZUIViewState, context: string) { + expect(Number.isFinite(viewport.x), `${context} x`).toBe(true) + expect(Number.isFinite(viewport.y), `${context} y`).toBe(true) + expect(Number.isFinite(viewport.zoom), `${context} zoom`).toBe(true) + expect(viewport.zoom, `${context} zoom`).toBeGreaterThan(0) +} + +function expectScreenRectVisible( + rect: ReturnType, + canvasW: number, + canvasH: number, + context: string, +) { + const epsilon = 0.75 + expect(rect.left, `${context} left`).toBeGreaterThanOrEqual(-epsilon) + expect(rect.top, `${context} top`).toBeGreaterThanOrEqual(-epsilon) + expect(rect.right, `${context} right`).toBeLessThanOrEqual(canvasW + epsilon) + expect(rect.bottom, `${context} bottom`).toBeLessThanOrEqual(canvasH + epsilon) + expect(rect.width, `${context} width`).toBeGreaterThan(0) + expect(rect.height, 
`${context} height`).toBeGreaterThan(0) +} + +describe('ZUI focus targets', () => { + it('rebases a high-zoom camera to a small centered render transform', () => { + const rebase = getCameraRebase( + { x: -147_317_059.10654327, y: -184_315_493.52577353, zoom: 906_732.1382976775 }, + 997, + 975, + ) + + expect(rebase.originX).toBeCloseTo(162.47086805935993, 10) + expect(rebase.originY).toBeCloseTo(203.27500619070713, 10) + expect(rebase.view).toEqual({ + x: 498.5, + y: 487.5, + zoom: 906_732.1382976775, + }) + }) + + it('flattens the focused deepest layer at extreme zoom', () => { + const layout = computeLayout(deepSingleChainExploreData(8)) + const elementTarget = findElementFocusTarget(layout.groups, 8, 9001) + expect(elementTarget).not.toBeNull() + const constrained = { + x: 498.5, + y: 487.5, + zoom: 13_610_091, + originX: elementTarget!.absX + elementTarget!.absW / 2, + originY: elementTarget!.absY + elementTarget!.absH / 2, + } + const rebase = getCameraRebase(constrained, 997, 975) + const layer = findFocusedFlattenedLayerForTest( + layout.groups, + constrained, + 997, + 975, + getExpandThresholds(997), + rebase, + ) + + expect(layer?.nodes.length).toBeGreaterThan(0) + const target = layer!.nodes.find((node) => node.elementId === 9001) + expect(target).toBeTruthy() + const left = worldToScreenX(target!.worldX, layer!.view) + const right = worldToScreenX(target!.worldX + target!.worldW, layer!.view) + expect(Number.isFinite(left)).toBe(true) + expect(Number.isFinite(right)).toBe(true) + expect(right - left).toBeGreaterThan(0) + }) + + it('rebases stacked camera-center transitions without forcing ancestor expansion', () => { + const grandchild = testNode('node-3', []) + const child = testNode('node-2', [grandchild]) + const parent = testNode('node-1', [child]) + const groups = [testGroup([parent])] + const thresholds = getExpandThresholds(1200) + + const rebase = buildCameraTransitionRebase( + groups, + { x: 420, y: 315, zoom: 2.2 }, + 1200, + 800, + thresholds, + ) + + expect(rebase.preserveChildAlphaNodeIds.has('node-1')).toBe(true) + expect(rebase.preserveChildAlphaNodeIds.has('node-2')).toBe(false) + }) + + it('does not rebase when only one camera-center transition is active', () => { + const child = testNode('node-2', []) + const parent = testNode('node-1', [child]) + const groups = [testGroup([parent])] + const thresholds = getExpandThresholds(1200) + + const rebase = buildCameraTransitionRebase( + groups, + { x: 420, y: 315, zoom: 1.9 }, + 1200, + 800, + thresholds, + ) + + expect(rebase.preserveChildAlphaNodeIds.size).toBe(0) + }) + + it('finds and centers an element inside a deeply nested view', () => { + const layout = computeLayout(nestedExploreData()) + const target = findElementFocusTarget(layout.groups, 4, 401) + expect(target).not.toBeNull() + + const viewport = viewportForFocusTarget(target!, 1200, 800, 100_000, 0.18, { + minTargetScreenW: 320, + keepParentVisible: true, + }) + expect(viewport).not.toBeNull() + + const constrained = constrainViewState(viewport!, 1200, 800, layout.bbox) + const rect = screenRect(target!, constrained) + expect(rect.left).toBeGreaterThanOrEqual(0) + expect(rect.top).toBeGreaterThanOrEqual(0) + expect(rect.right).toBeLessThanOrEqual(1200) + expect(rect.bottom).toBeLessThanOrEqual(800) + expect(rect.width).toBeGreaterThanOrEqual(320) + }) + + it('zooms nested view navigation far enough for the selected view contents to render', () => { + const layout = computeLayout(nestedExploreData()) + const viewTarget = 
findDiagramFocusTarget(layout.groups, 4) + const elementTarget = findElementFocusTarget(layout.groups, 4, 401) + expect(viewTarget?.contentRect).toBeTruthy() + expect(elementTarget).not.toBeNull() + + const viewport = viewportForFocusTarget(viewTarget!, 1200, 800, 100_000, 0.16, { + preferContent: true, + minTargetScreenW: 260, + minChildScreenW: 104, + }) + expect(viewport).not.toBeNull() + + const constrained = constrainViewState(viewport!, 1200, 800, layout.bbox) + const rect = screenRect(elementTarget!, constrained) + expect(rect.width).toBeGreaterThanOrEqual(104) + }) + + it('does not inflate sub-pixel nested content bounds when centering a deep view', () => { + const layout = computeLayout(deepSingleChainExploreData(8)) + const viewTarget = findDiagramFocusTarget(layout.groups, 8) + const elementTarget = findElementFocusTarget(layout.groups, 8, 9001) + expect(viewTarget?.contentRect).toBeTruthy() + expect(elementTarget).not.toBeNull() + + const viewport = viewportForFocusTarget(viewTarget!, 1200, 800, 1_000_000, 0.16, { + preferContent: true, + minTargetScreenW: 260, + minChildScreenW: 104, + }) + expect(viewport).not.toBeNull() + + const rect = screenRect(elementTarget!, constrainViewState(viewport!, 1200, 800, layout.bbox)) + expect(rect.left).toBeGreaterThanOrEqual(0) + expect(rect.top).toBeGreaterThanOrEqual(0) + expect(rect.right).toBeLessThanOrEqual(1200) + expect(rect.bottom).toBeLessThanOrEqual(800) + expect(rect.width).toBeGreaterThanOrEqual(104) + }) + + it('keeps a sub-pixel expandable element visible when capping child zoom', () => { + const layout = computeLayout(deepSingleChainExploreData(8)) + const target = findElementFocusTarget(layout.groups, 7, 1007) + expect(target?.node?.children.length).toBe(1) + + const viewport = viewportForFocusTarget(target!, 1200, 800, 1_000_000, 0.18, { + minTargetScreenW: 320, + keepParentVisible: true, + }) + expect(viewport).not.toBeNull() + + const rect = screenRect(target!, constrainViewState(viewport!, 1200, 800, layout.bbox)) + expect(rect.left).toBeGreaterThanOrEqual(0) + expect(rect.top).toBeGreaterThanOrEqual(0) + expect(rect.right).toBeLessThanOrEqual(1200) + expect(rect.bottom).toBeLessThanOrEqual(800) + expect(rect.width).toBeGreaterThanOrEqual(320) + }) + + it('can navigate and zoom to every placed element across viewport sizes, levels, and current cameras', () => { + const data = focusMatrixExploreData(6) + const layout = computeLayout(data) + const canvasCases = [ + { name: 'desktop', w: 1200, h: 800, isMobile: false, leafMinWidth: 320 }, + { name: 'mobile', w: 390, h: 720, isMobile: true, leafMinWidth: 220 }, + { name: 'ultrawide', w: 1800, h: 900, isMobile: false, leafMinWidth: 320 }, + ] + const currentViewports: ZUIViewState[] = [ + { x: 0, y: 0, zoom: 0.4 }, + { x: -25_000, y: 18_000, zoom: 0.8 }, + { x: 40_000, y: -35_000, zoom: 24 }, + ] + + for (const { name, w, h, isMobile, leafMinWidth } of canvasCases) { + const maxZoom = calculateMaxZoom(layout.groups, w) + const thresholds = getExpandThresholds(w) + for (const { viewId, elementId } of placementsIn(data)) { + const target = findElementFocusTarget(layout.groups, viewId, elementId) + expect(target, `${name} view ${viewId} element ${elementId} target`).not.toBeNull() + const viewport = viewportForElementFocusTarget(target!, w, h, maxZoom, isMobile) + expect(viewport, `${name} view ${viewId} element ${elementId} viewport`).not.toBeNull() + expectFiniteViewport(viewport!, `${name} view ${viewId} element ${elementId}`) + + for (const current of currentViewports) { + 
const finalViewport = completeFocusNavigationFromCurrent(current, viewport!, w, h, layout.bbox) + const rect = screenRect(target!, finalViewport) + const context = `${name} from ${current.x}/${current.y}/${current.zoom} to view ${viewId} element ${elementId}` + expectScreenRectVisible(rect, w, h, context) + + const expectedMinWidth = target!.node?.children.length ? Math.min(leafMinWidth, thresholds.start) : leafMinWidth + expect(rect.width, `${context} usable width`).toBeGreaterThanOrEqual(expectedMinWidth - 0.75) + } + } + } + }) + + it('can navigate and zoom to every linked view target without losing the content center', () => { + const data = focusMatrixExploreData(6) + const layout = computeLayout(data) + const canvasW = 1200 + const canvasH = 800 + const maxZoom = calculateMaxZoom(layout.groups, canvasW) + + for (const viewId of viewsIn(data)) { + const target = findDiagramFocusTarget(layout.groups, viewId) + expect(target, `view ${viewId} target`).not.toBeNull() + const viewport = viewportForDiagramFocusTarget(target!, canvasW, canvasH, maxZoom, false) + expect(viewport, `view ${viewId} viewport`).not.toBeNull() + expectFiniteViewport(viewport!, `view ${viewId}`) + + const finalViewport = constrainViewState(viewport!, canvasW, canvasH, layout.bbox) + const rect = screenRect(target!, finalViewport) + expect(rect.width, `view ${viewId} target width`).toBeGreaterThan(0) + expect(rect.height, `view ${viewId} target height`).toBeGreaterThan(0) + expect((rect.left + rect.right) / 2, `view ${viewId} target center x`).toBeGreaterThanOrEqual(0) + expect((rect.left + rect.right) / 2, `view ${viewId} target center x`).toBeLessThanOrEqual(canvasW) + expect((rect.top + rect.bottom) / 2, `view ${viewId} target center y`).toBeGreaterThanOrEqual(0) + expect((rect.top + rect.bottom) / 2, `view ${viewId} target center y`).toBeLessThanOrEqual(canvasH) + + if (target!.contentRect) { + const contentRect = worldScreenRect(target!.contentRect, finalViewport) + expect(contentRect.width, `view ${viewId} content width`).toBeGreaterThan(0) + expect(contentRect.height, `view ${viewId} content height`).toBeGreaterThan(0) + expect((contentRect.left + contentRect.right) / 2, `view ${viewId} content center x`).toBeGreaterThanOrEqual(0) + expect((contentRect.left + contentRect.right) / 2, `view ${viewId} content center x`).toBeLessThanOrEqual(canvasW) + expect((contentRect.top + contentRect.bottom) / 2, `view ${viewId} content center y`).toBeGreaterThanOrEqual(0) + expect((contentRect.top + contentRect.bottom) / 2, `view ${viewId} content center y`).toBeLessThanOrEqual(canvasH) + } + } + }) + + it('keeps focus centering available when the canvas is smaller than the old fixed padding', () => { + const targetView = { x: 400, y: 300, zoom: 1 } + const constrained = constrainViewState(targetView, 1000, 800, { + minX: 0, + minY: 0, + maxX: 1600, + maxY: 1200, + }) + + expect(rawCameraView(constrained).x).toBeCloseTo(targetView.x) + expect(rawCameraView(constrained).y).toBeCloseTo(targetView.y) + }) +}) diff --git a/frontend/src/components/ZUI/focus.ts b/frontend/src/components/ZUI/focus.ts new file mode 100644 index 0000000..5f4b9da --- /dev/null +++ b/frontend/src/components/ZUI/focus.ts @@ -0,0 +1,293 @@ +import type { DiagramGroupLayout, LayoutNode, ZUIViewState } from './types' +import { getExpandThresholds } from './renderer' + +interface Rect { + x: number + y: number + w: number + h: number +} + +export interface ZUIFocusTarget { + id: string + label: string + type: 'group' | 'node' + isCircular?: boolean + absX: 
number + absY: number + absW: number + absH: number + absScale: number + node?: LayoutNode + contentRect?: Rect +} + +export interface ZUIFocusViewportOptions { + preferContent?: boolean + minTargetScreenW?: number + minChildScreenW?: number + keepParentVisible?: boolean +} + +function boundsForRects(rects: Rect[]): Rect | null { + if (rects.length === 0) return null + + let minX = Infinity + let minY = Infinity + let maxX = -Infinity + let maxY = -Infinity + + for (const rect of rects) { + minX = Math.min(minX, rect.x) + minY = Math.min(minY, rect.y) + maxX = Math.max(maxX, rect.x + rect.w) + maxY = Math.max(maxY, rect.y + rect.h) + } + + if (!Number.isFinite(minX) || !Number.isFinite(minY) || !Number.isFinite(maxX) || !Number.isFinite(maxY)) { + return null + } + + return { x: minX, y: minY, w: positiveSize(maxX - minX), h: positiveSize(maxY - minY) } +} + +function positiveSize(value: number): number { + return Number.isFinite(value) && value > 0 ? value : 0.0001 +} + +function childContentRect(node: LayoutNode, absX: number, absY: number, absScale: number): Rect | null { + if (node.children.length === 0) return null + + const childAbsScale = absScale * node.childScale + return boundsForRects(node.children.map((child) => ({ + x: absX + (child.worldX - node.childOffsetX) * childAbsScale, + y: absY + (child.worldY - node.childOffsetY) * childAbsScale, + w: child.worldW * childAbsScale, + h: child.worldH * childAbsScale, + }))) +} + +function nodeTarget( + node: LayoutNode, + parentAbsX: number, + parentAbsY: number, + parentAbsScale: number, + parentChildOffsetX: number, + parentChildOffsetY: number, +): ZUIFocusTarget { + const absX = parentAbsX + (node.worldX - parentChildOffsetX) * parentAbsScale + const absY = parentAbsY + (node.worldY - parentChildOffsetY) * parentAbsScale + const absW = node.worldW * parentAbsScale + const absH = node.worldH * parentAbsScale + + return { + id: node.id, + label: node.linkedDiagramLabel || node.label, + type: 'node', + isCircular: node.isCircular, + absX, + absY, + absW, + absH, + absScale: parentAbsScale, + node, + contentRect: childContentRect(node, absX, absY, parentAbsScale) ?? 
undefined, + } +} + +function findLinkedDiagramInNodes( + viewId: number, + nodes: DiagramGroupLayout['nodes'], + parentAbsX: number, + parentAbsY: number, + parentAbsScale: number, + parentChildOffsetX: number, + parentChildOffsetY: number, +): ZUIFocusTarget | null { + for (const node of nodes) { + const target = nodeTarget(node, parentAbsX, parentAbsY, parentAbsScale, parentChildOffsetX, parentChildOffsetY) + + if (node.linkedDiagramId === viewId) { + return target + } + + if (node.children.length > 0) { + const found = findLinkedDiagramInNodes( + viewId, + node.children, + target.absX, + target.absY, + parentAbsScale * node.childScale, + node.childOffsetX, + node.childOffsetY, + ) + if (found) return found + } + } + + return null +} + +function findElementInNodes( + viewId: number, + elementId: number, + nodes: DiagramGroupLayout['nodes'], + parentAbsX: number, + parentAbsY: number, + parentAbsScale: number, + parentChildOffsetX: number, + parentChildOffsetY: number, +): ZUIFocusTarget | null { + for (const node of nodes) { + const target = nodeTarget(node, parentAbsX, parentAbsY, parentAbsScale, parentChildOffsetX, parentChildOffsetY) + + if (node.diagramId === viewId && node.elementId === elementId) { + return target + } + + if (node.children.length > 0) { + const found = findElementInNodes( + viewId, + elementId, + node.children, + target.absX, + target.absY, + parentAbsScale * node.childScale, + node.childOffsetX, + node.childOffsetY, + ) + if (found) return found + } + } + + return null +} + +export function findDiagramFocusTarget(groups: DiagramGroupLayout[], viewId: number): ZUIFocusTarget | null { + for (const group of groups) { + if (group.diagramId === viewId) { + return { + id: `g-${group.diagramId}`, + label: group.label, + type: 'group', + absX: group.worldX, + absY: group.worldY, + absW: group.worldW, + absH: group.worldH, + absScale: 1, + } + } + + const found = findLinkedDiagramInNodes(viewId, group.nodes, 0, 0, 1, 0, 0) + if (found) return found + } + + return null +} + +export function findElementFocusTarget(groups: DiagramGroupLayout[], viewId: number, elementId: number): ZUIFocusTarget | null { + for (const group of groups) { + const found = findElementInNodes(viewId, elementId, group.nodes, 0, 0, 1, 0, 0) + if (found) return found + } + + return null +} + +export function viewportForFocusTarget( + target: ZUIFocusTarget, + canvasW: number, + canvasH: number, + maxZoom: number, + padding: number, + options: ZUIFocusViewportOptions = {}, +): ZUIViewState | null { + const rect = options.preferContent && target.contentRect ? 
target.contentRect : { + x: target.absX, + y: target.absY, + w: target.absW, + h: target.absH, + } + + const bboxW = positiveSize(rect.w) + const bboxH = positiveSize(rect.h) + const fitZoom = Math.min( + (canvasW * (1 - padding * 2)) / bboxW, + (canvasH * (1 - padding * 2)) / bboxH, + ) + const minZooms: number[] = [] + + if (options.minTargetScreenW && target.absW > 0) { + minZooms.push(options.minTargetScreenW / target.absW) + } + + if (target.node?.children.length && options.minChildScreenW && target.node.childScale > 0) { + const childAbsW = target.node.worldW * target.absScale * target.node.childScale + if (childAbsW > 0) { + minZooms.push(options.minChildScreenW / childAbsW) + } + } + + const finiteMinZooms = minZooms.filter((value) => Number.isFinite(value) && value > 0) + const zoomLimit = Math.max(maxZoom, ...finiteMinZooms) + let zoom = Math.min(fitZoom, zoomLimit) + + for (const minZoom of finiteMinZooms) { + zoom = Math.max(zoom, Math.min(minZoom, zoomLimit)) + } + + if (target.node?.children.length && options.keepParentVisible) { + const thresholds = getExpandThresholds(canvasW) + const maxParentScreenW = thresholds.start + (thresholds.end - thresholds.start) * 0.78 + zoom = Math.min(zoom, maxParentScreenW / positiveSize(target.absW)) + } + + if (!Number.isFinite(zoom) || zoom <= 0) return null + + return { + x: (canvasW - bboxW * zoom) / 2 - rect.x * zoom, + y: (canvasH - bboxH * zoom) / 2 - rect.y * zoom, + zoom, + } +} + +export function viewportForDiagramFocusTarget( + target: ZUIFocusTarget, + canvasW: number, + canvasH: number, + maxZoom: number, + isMobileLayout: boolean, +): ZUIViewState | null { + return viewportForFocusTarget( + target, + canvasW, + canvasH, + maxZoom, + isMobileLayout ? 0.18 : 0.16, + { + preferContent: true, + minTargetScreenW: isMobileLayout ? 180 : 260, + minChildScreenW: isMobileLayout ? 76 : 104, + }, + ) +} + +export function viewportForElementFocusTarget( + target: ZUIFocusTarget, + canvasW: number, + canvasH: number, + maxZoom: number, + isMobileLayout: boolean, +): ZUIViewState | null { + return viewportForFocusTarget( + target, + canvasW, + canvasH, + maxZoom, + isMobileLayout ? 0.2 : 0.18, + { + minTargetScreenW: isMobileLayout ? 220 : 320, + keepParentVisible: true, + }, + ) +} diff --git a/frontend/src/components/ZUI/layout.ts b/frontend/src/components/ZUI/layout.ts index ea093ef..ac65445 100644 --- a/frontend/src/components/ZUI/layout.ts +++ b/frontend/src/components/ZUI/layout.ts @@ -10,15 +10,15 @@ import type { ExploreData, ViewConnector, } from '../../types' -import { resolveIconPath } from '../../utils/url' +import { resolveElementIconUrl } from '../../utils/elementIcon' // ── Constants ────────────────────────────────────────────────────── -export const NODE_W = 200 -export const NODE_H = 100 +export const NODE_W = 180 +export const NODE_H = 85 const GROUP_PAD = 80 // padding inside a diagram group box const GROUP_SPACING = 400 // horizontal gap between root diagrams -const CHILD_PAD = 1 // padding inside a node when rendering children +const CHILD_PAD = 4 // padding inside a node when rendering children // ── Helpers ──────────────────────────────────────────────────────── @@ -159,6 +159,7 @@ function buildNodes( const edgesOut = (views[String(diagramId)]?.connectors ?? []) .filter((e) => e.source_element_id === obj.element_id) .map((e) => ({ + id: e.id, targetId: nodeId(diagramId, e.target_element_id), label: e.label ?? '', direction: e.direction ?? 
'forward', @@ -167,12 +168,6 @@ function buildNodes( type: e.style || 'bezier', })) - const derivedPrimaryIconPath = (() => { - const selected = obj.technology_connectors?.find((link) => link.type === 'catalog' && !!(link.is_primary_icon ?? (link as any).isPrimaryIcon) && !!link.slug) - if (!selected?.slug) return null - return resolveIconPath(`/icons/${selected.slug}.png`) - })() - return { id: nodeId(diagramId, obj.element_id), elementId: obj.element_id, @@ -183,7 +178,7 @@ function buildNodes( worldH: NODE_H, label: obj.name, type: obj.kind ?? 'system', - logoUrl: obj.logo_url ? resolveIconPath(obj.logo_url) : derivedPrimaryIconPath, + logoUrl: resolveElementIconUrl(obj.logo_url, obj.technology_connectors), description: obj.description ?? null, technology: obj.technology ?? null, tags: obj.tags ?? [], @@ -253,6 +248,7 @@ export function computeLayout(data: ExploreData): ZUILayout { // Edges within the same diagram (world-level, not children) const edges = (diagData.connectors ?? []).map((e) => ({ + id: e.id, sourceId: nodeId(diag.id, e.source_element_id), targetId: nodeId(diag.id, e.target_element_id), label: e.label ?? '', diff --git a/frontend/src/components/ZUI/proxy.ts b/frontend/src/components/ZUI/proxy.ts index dfb32a1..5b06f74 100644 --- a/frontend/src/components/ZUI/proxy.ts +++ b/frontend/src/components/ZUI/proxy.ts @@ -1,8 +1,15 @@ -import { resolveZUIProxyConnectors, type ZUIResolvedConnector } from '../../crossBranch/resolve' +import { + resolveZUIProxyConnectors, + type ZUIHiddenProxyBadge, + type ZUIViewportBounds, + type ZUIProxyResolution, + type ZUIResolvedConnector, +} from '../../crossBranch/resolve' import type { WorkspaceGraphSnapshot } from '../../crossBranch/types' import type { LayoutNode, ZUIViewState, HoveredItem } from './types' import { getExpandThresholds, pickEdgeLabelPosition, type ScreenRect } from './renderer' import type { CrossBranchContextSettings } from '../../crossBranch/types' +import { DEFAULT_MIN_CONNECTOR_ANCHOR_ALPHA } from '../../crossBranch/settings' export interface VisibleNodeAnchor { nodeId: string @@ -20,12 +27,58 @@ function clamp(value: number, min: number, max: number) { return value < min ? min : value > max ? 
max : value } +function connectorAlpha(alpha: number): number { + return clamp(alpha * 1.1, 0.35, 0.95) +} + function transitionT(screenW: number, start: number, end: number): number { return clamp((screenW - start) / (end - start), 0, 1) } -function collectVisibleAnchorsInNodes( - nodes: LayoutNode[], +function visualRectForNode( + absX: number, + absY: number, + absW: number, + absH: number, + hasChildren: boolean, + screenW: number, + thresholds: { start: number; end: number }, +) { + if (!hasChildren && screenW > thresholds.end) { + const scale = thresholds.end / screenW + const visualW = absW * scale + const visualH = absH * scale + return { + worldX: absX + (absW - visualW) / 2, + worldY: absY + (absH - visualH) / 2, + worldW: visualW, + worldH: visualH, + } + } + + return { + worldX: absX, + worldY: absY, + worldW: absW, + worldH: absH, + } +} + +function registerVisibleAnchor( + node: LayoutNode, + visibleAnchors: Map, + byNodeId: Map, + anchor: VisibleNodeAnchor, +) { + const existing = visibleAnchors.get(node.elementId) + if (!existing || existing.pathDepth < anchor.pathDepth || existing.renderAlpha < anchor.renderAlpha) { + visibleAnchors.set(node.elementId, anchor) + } + byNodeId.set(node.id, anchor) +} + +function collectVisibleAnchorForNode( + node: LayoutNode, view: ZUIViewState, thresholds: { start: number; end: number }, hiddenTags: Set, @@ -38,42 +91,42 @@ function collectVisibleAnchorsInNodes( parentChildOffsetX: number, parentChildOffsetY: number, ) { - for (const node of nodes) { - if (hiddenTags.size > 0 && node.tags.some((tag) => hiddenTags.has(tag))) continue - - const absX = parentAbsX + (node.worldX - parentChildOffsetX) * parentAbsScale - const absY = parentAbsY + (node.worldY - parentChildOffsetY) * parentAbsScale - const absScale = parentAbsScale - const absW = node.worldW * absScale - const absH = node.worldH * absScale - const screenW = absW * view.zoom - if (screenW < 2) continue - - const hasChildren = node.children && node.children.length > 0 - const t = hasChildren ? transitionT(screenW, thresholds.start, thresholds.end) : 0 - const parentAlpha = inheritedAlpha * (1 - t) - const childAlpha = inheritedAlpha * t - - if (!hasChildren || t <= 0.95) { - const anchor: VisibleNodeAnchor = { - nodeId: node.id, - elementId: node.elementId, - label: node.label, - worldX: absX, - worldY: absY, - worldW: absW, - worldH: absH, - pathDepth: node.pathElementIds.length, - renderAlpha: hasChildren ? parentAlpha : inheritedAlpha, - } - const existing = visibleAnchors.get(node.elementId) - if (!existing || existing.pathDepth < anchor.pathDepth) visibleAnchors.set(node.elementId, anchor) - byNodeId.set(node.id, anchor) - } + if (hiddenTags.size > 0 && node.tags.some((tag) => hiddenTags.has(tag))) return { selfDrawn: false } + + const absX = parentAbsX + (node.worldX - parentChildOffsetX) * parentAbsScale + const absY = parentAbsY + (node.worldY - parentChildOffsetY) * parentAbsScale + const absScale = parentAbsScale + const absW = node.worldW * absScale + const absH = node.worldH * absScale + const screenW = absW * view.zoom + if (screenW < 2) return { selfDrawn: false } + + const hasChildren = node.children && node.children.length > 0 + const t = hasChildren ? 
transitionT(screenW, thresholds.start, thresholds.end) : 0 + const parentAlpha = inheritedAlpha * (1 - t) + const childAlpha = inheritedAlpha * t + const selfDrawn = !hasChildren || t <= 0.95 + const visualRect = visualRectForNode(absX, absY, absW, absH, hasChildren, screenW, thresholds) + + if (selfDrawn) { + registerVisibleAnchor(node, visibleAnchors, byNodeId, { + nodeId: node.id, + elementId: node.elementId, + label: node.label, + worldX: visualRect.worldX, + worldY: visualRect.worldY, + worldW: visualRect.worldW, + worldH: visualRect.worldH, + pathDepth: node.pathElementIds.length, + renderAlpha: hasChildren ? parentAlpha : inheritedAlpha, + }) + } - if (hasChildren && t > 0.05) { - collectVisibleAnchorsInNodes( - node.children, + let hasDirectChildDrawn = false + if (hasChildren && t > 0.05) { + for (const child of node.children) { + const childResult = collectVisibleAnchorForNode( + child, view, thresholds, hiddenTags, @@ -86,8 +139,57 @@ function collectVisibleAnchorsInNodes( node.childOffsetX, node.childOffsetY, ) + hasDirectChildDrawn = hasDirectChildDrawn || childResult.selfDrawn } } + + if (!selfDrawn && hasDirectChildDrawn) { + registerVisibleAnchor(node, visibleAnchors, byNodeId, { + nodeId: node.id, + elementId: node.elementId, + label: node.label, + worldX: visualRect.worldX, + worldY: visualRect.worldY, + worldW: visualRect.worldW, + worldH: visualRect.worldH, + pathDepth: node.pathElementIds.length, + renderAlpha: Math.max(0.12, inheritedAlpha * 0.28), + }) + } + + return { selfDrawn } +} + +function collectVisibleAnchorsInNodes( + nodes: LayoutNode[], + view: ZUIViewState, + thresholds: { start: number; end: number }, + hiddenTags: Set, + visibleAnchors: Map, + byNodeId: Map, + inheritedAlpha: number, + parentAbsX: number, + parentAbsY: number, + parentAbsScale: number, + parentChildOffsetX: number, + parentChildOffsetY: number, +) { + for (const node of nodes) { + collectVisibleAnchorForNode( + node, + view, + thresholds, + hiddenTags, + visibleAnchors, + byNodeId, + inheritedAlpha, + parentAbsX, + parentAbsY, + parentAbsScale, + parentChildOffsetX, + parentChildOffsetY, + ) + } } export function collectVisibleNodeAnchors( @@ -121,13 +223,23 @@ export function collectVisibleNodeAnchors( return { visibleAnchors, byNodeId } } -function getDirectAnchorPoint(anchor: VisibleNodeAnchor, towards: VisibleNodeAnchor) { +function getAnchorCenter(anchor: VisibleNodeAnchor) { + return { + x: anchor.worldX + anchor.worldW / 2, + y: anchor.worldY + anchor.worldH / 2, + } +} + +function containsPoint(anchor: VisibleNodeAnchor, x: number, y: number) { + return x >= anchor.worldX && + x <= anchor.worldX + anchor.worldW && + y >= anchor.worldY && + y <= anchor.worldY + anchor.worldH +} + +function getRectBoundaryPoint(anchor: VisibleNodeAnchor, dx: number, dy: number) { const cx = anchor.worldX + anchor.worldW / 2 const cy = anchor.worldY + anchor.worldH / 2 - const tx = towards.worldX + towards.worldW / 2 - const ty = towards.worldY + towards.worldH / 2 - const dx = tx - cx - const dy = ty - cy const hw = anchor.worldW / 2 const hh = anchor.worldH / 2 @@ -148,15 +260,191 @@ function getDirectAnchorPoint(anchor: VisibleNodeAnchor, towards: VisibleNodeAnc } } +function getDirectAnchorPoint(anchor: VisibleNodeAnchor, towards: VisibleNodeAnchor) { + const anchorCenter = getAnchorCenter(anchor) + const towardsCenter = getAnchorCenter(towards) + + // Nested anchors represent parent/child nodes. 
Aim the child endpoint away + // from the parent center so proxy lines attach to the nearer child edge. + if (containsPoint(towards, anchorCenter.x, anchorCenter.y)) { + return getRectBoundaryPoint( + anchor, + anchorCenter.x - towardsCenter.x, + anchorCenter.y - towardsCenter.y, + ) + } + + return getRectBoundaryPoint( + anchor, + towardsCenter.x - anchorCenter.x, + towardsCenter.y - anchorCenter.y, + ) +} + +function getDirectAnchorPoints(source: VisibleNodeAnchor, target: VisibleNodeAnchor) { + const sourcePoint = getDirectAnchorPoint(source, target) + const targetPoint = getDirectAnchorPoint(target, source) + return { sourcePoint, targetPoint } +} + +function getDevicePixelRatio(): number { + return typeof window !== 'undefined' ? window.devicePixelRatio || 1 : 1 +} + +function roundRectPath(ctx: CanvasRenderingContext2D, x: number, y: number, w: number, h: number, r: number) { + ctx.beginPath() + ctx.roundRect(x, y, w, h, r) +} + +function drawFixedScreenProxyBadge( + ctx: CanvasRenderingContext2D, + label: string, + labelPos: { x: number; y: number }, + badgeCssW: number, + badgeCssH: number, + labelBg: string, + strokeStyle: string, + lineDashCss: number[], + fontWeight = 600, +) { + const matrix = ctx.getTransform() + const dpr = getDevicePixelRatio() + const centerX = matrix.a * labelPos.x + matrix.c * labelPos.y + matrix.e + const centerY = matrix.b * labelPos.x + matrix.d * labelPos.y + matrix.f + const badgeW = badgeCssW * dpr + const badgeH = badgeCssH * dpr + const radius = badgeH / 2 + + ctx.save() + ctx.setTransform(1, 0, 0, 1, 0, 0) + ctx.fillStyle = labelBg + roundRectPath(ctx, centerX - badgeW / 2, centerY - badgeH / 2, badgeW, badgeH, radius) + ctx.fill() + ctx.strokeStyle = strokeStyle + ctx.lineWidth = dpr + ctx.setLineDash(lineDashCss.map((value) => value * dpr)) + ctx.stroke() + ctx.setLineDash([]) + ctx.fillStyle = 'white' + ctx.font = `${fontWeight} ${11 * dpr}px Inter, system-ui, sans-serif` + ctx.textAlign = 'center' + ctx.textBaseline = 'middle' + ctx.fillText(label, centerX, centerY) + ctx.restore() +} + +function measureProxyBadge(ctx: CanvasRenderingContext2D, label: string, zoom: number, fontWeight = 600) { + ctx.save() + ctx.setTransform(1, 0, 0, 1, 0, 0) + ctx.font = `${fontWeight} 11px Inter, system-ui, sans-serif` + const textW = ctx.measureText(label).width + ctx.restore() + + const badgeCssW = Math.max(24, textW + 14) + const badgeCssH = 24 + return { + badgeCssW, + badgeCssH, + worldW: badgeCssW / zoom, + worldH: badgeCssH / zoom, + } +} + +interface IndexedProxyConnector { + connector: ZUIResolvedConnector + x1: number + y1: number + x2: number + y2: number + midX: number + midY: number +} + +export interface ProxyConnectorSpatialIndex { + cellSize: number + cells: Map +} + +const PROXY_CONNECTOR_INDEX_CELL_SIZE = 360 + +function proxyCellKey(cx: number, cy: number): string { + return `${cx},${cy}` +} + +function addProxyConnectorToSpatialIndex(index: ProxyConnectorSpatialIndex, connector: IndexedProxyConnector): void { + const minX = Math.min(connector.x1, connector.x2) + const maxX = Math.max(connector.x1, connector.x2) + const minY = Math.min(connector.y1, connector.y2) + const maxY = Math.max(connector.y1, connector.y2) + const startX = Math.floor(minX / index.cellSize) + const endX = Math.floor(maxX / index.cellSize) + const startY = Math.floor(minY / index.cellSize) + const endY = Math.floor(maxY / index.cellSize) + + for (let cx = startX; cx <= endX; cx++) { + for (let cy = startY; cy <= endY; cy++) { + const key = proxyCellKey(cx, cy) + let 
bucket = index.cells.get(key) + if (!bucket) { + bucket = [] + index.cells.set(key, bucket) + } + bucket.push(connector) + } + } +} + +export function buildProxyConnectorSpatialIndex( + connectors: ZUIResolvedConnector[], + visibleAnchorsByNodeId: Map, +): ProxyConnectorSpatialIndex { + const index: ProxyConnectorSpatialIndex = { + cellSize: PROXY_CONNECTOR_INDEX_CELL_SIZE, + cells: new Map(), + } + + for (const connector of connectors) { + const source = visibleAnchorsByNodeId.get(connector.sourceNodeId) + const target = visibleAnchorsByNodeId.get(connector.targetNodeId) + if (!source || !target) continue + + const { sourcePoint, targetPoint } = getDirectAnchorPoints(source, target) + addProxyConnectorToSpatialIndex(index, { + connector, + x1: sourcePoint.x, + y1: sourcePoint.y, + x2: targetPoint.x, + y2: targetPoint.y, + midX: (sourcePoint.x + targetPoint.x) / 2, + midY: (sourcePoint.y + targetPoint.y) / 2, + }) + } + + return index +} + export function buildVisibleProxyConnectors( snapshot: WorkspaceGraphSnapshot | null, visibleAnchors: Map, settings: CrossBranchContextSettings, -): ZUIResolvedConnector[] { + viewport?: ZUIViewportBounds | null, +): ZUIProxyResolution { + const minAlpha = settings.minConnectorAnchorAlpha ?? DEFAULT_MIN_CONNECTOR_ANCHOR_ALPHA + const eligibleAnchors = Array.from(visibleAnchors.entries()) + .filter(([, anchor]) => anchor.renderAlpha >= minAlpha) + const connectorAnchors = new Map(eligibleAnchors.map(([elementId, anchor]) => [elementId, anchor.nodeId])) + const anchorsByElementId = new Map(eligibleAnchors.map(([elementId, anchor]) => [elementId, { + nodeId: anchor.nodeId, + worldX: anchor.worldX, + worldY: anchor.worldY, + worldW: anchor.worldW, + worldH: anchor.worldH, + }])) return resolveZUIProxyConnectors( snapshot, - new Map(Array.from(visibleAnchors.entries()).map(([elementId, anchor]) => [elementId, anchor.nodeId])), + connectorAnchors, settings, + { viewport, anchorsByElementId }, ) } @@ -166,8 +454,27 @@ export function drawVisibleProxyConnectors( visibleAnchorsByNodeId: Map, zoom: number, labelBg: string, + accent: string, occupiedLabelRects: ScreenRect[], ) { + const connectorsByActualPair = new Map() + for (const connector of connectors) { + const pairKey = `${Math.min(connector.sourceElementId, connector.targetElementId)}::${Math.max(connector.sourceElementId, connector.targetElementId)}` + const family = connectorsByActualPair.get(pairKey) + if (family) family.push(connector) + else connectorsByActualPair.set(pairKey, [connector]) + } + + const provenanceKeys = new Set() + for (const family of connectorsByActualPair.values()) { + if (family.length < 2) continue + const sorted = [...family].sort((left, right) => { + if (left.maxDepth !== right.maxDepth) return left.maxDepth - right.maxDepth + return (left.sourceDepth + left.targetDepth) - (right.sourceDepth + right.targetDepth) + }) + for (const connector of sorted.slice(1)) provenanceKeys.add(connector.key) + } + for (const connector of connectors) { const source = visibleAnchorsByNodeId.get(connector.sourceNodeId) const target = visibleAnchorsByNodeId.get(connector.targetNodeId) @@ -175,104 +482,153 @@ export function drawVisibleProxyConnectors( const alpha = Math.min(source.renderAlpha, target.renderAlpha) if (alpha < 0.01) continue - const sourcePoint = getDirectAnchorPoint(source, target) - const targetPoint = getDirectAnchorPoint(target, source) + const { sourcePoint, targetPoint } = getDirectAnchorPoints(source, target) const midX = (sourcePoint.x + targetPoint.x) / 2 const midY = 
(sourcePoint.y + targetPoint.y) / 2 const label = String(connector.details.count) ctx.save() - ctx.globalAlpha = alpha - ctx.strokeStyle = 'rgba(255, 255, 255, 0.2)' - ctx.lineWidth = 2 / zoom + const isProvenanceStub = provenanceKeys.has(connector.key) + if (isProvenanceStub) { + ctx.restore() + continue + } + + ctx.globalAlpha = connectorAlpha(alpha) * 0.8 + ctx.strokeStyle = accent + ctx.lineWidth = 1 / zoom + ctx.lineCap = 'round' + ctx.setLineDash([1 / zoom, 4 / zoom]) ctx.beginPath() ctx.moveTo(sourcePoint.x, sourcePoint.y) ctx.lineTo(targetPoint.x, targetPoint.y) ctx.stroke() - const fontSize = 11 / zoom - ctx.font = `${fontSize}px Inter, system-ui, sans-serif` - const textMetrics = ctx.measureText(label) - const textW = textMetrics.width - const textH = fontSize + ctx.setLineDash([]) + const badge = measureProxyBadge(ctx, label, zoom) const labelPos = pickEdgeLabelPosition( ctx.getTransform(), midX, midY, - textW, - textH, + badge.worldW, + badge.worldH, targetPoint.x - sourcePoint.x, targetPoint.y - sourcePoint.y, occupiedLabelRects, ) - const px = 6 / zoom - const py = 4 / zoom - const badgeW = textW + px * 2 - const badgeH = textH + py * 2 - const badgeRadius = badgeH / 2 - ctx.fillStyle = labelBg - ctx.beginPath() - ctx.roundRect( - labelPos.x - badgeW / 2, - labelPos.y - badgeH / 2, - badgeW, - badgeH, - badgeRadius, - ) - ctx.fill() - ctx.strokeStyle = 'rgba(255, 255, 255, 0.2)' - ctx.lineWidth = 1 / zoom - ctx.stroke() - ctx.fillStyle = 'white' - ctx.textAlign = 'center' - ctx.textBaseline = 'middle' - ctx.fillText(label, labelPos.x, labelPos.y) + drawFixedScreenProxyBadge(ctx, label, labelPos, badge.badgeCssW, badge.badgeCssH, labelBg, accent, [1, 4]) ctx.restore() } } +export function drawVisibleDirectProxyBadges( + ctx: CanvasRenderingContext2D, + badges: ZUIHiddenProxyBadge[], + visibleAnchorsByNodeId: Map, + zoom: number, + labelBg: string, + occupiedLabelRects: ScreenRect[], +) { + for (const badge of badges) { + const source = visibleAnchorsByNodeId.get(badge.sourceNodeId) + const target = visibleAnchorsByNodeId.get(badge.targetNodeId) + if (!source || !target) continue + const alpha = Math.min(source.renderAlpha, target.renderAlpha) + if (alpha < 0.01) continue + + const { sourcePoint, targetPoint } = getDirectAnchorPoints(source, target) + const midX = (sourcePoint.x + targetPoint.x) / 2 + const midY = (sourcePoint.y + targetPoint.y) / 2 + const label = `+${badge.count}` + + ctx.save() + ctx.globalAlpha = alpha + const badgeMetrics = measureProxyBadge(ctx, label, zoom) + const labelPos = pickEdgeLabelPosition( + ctx.getTransform(), + midX, + midY, + badgeMetrics.worldW, + badgeMetrics.worldH, + targetPoint.x - sourcePoint.x, + targetPoint.y - sourcePoint.y, + occupiedLabelRects, + ) + drawFixedScreenProxyBadge( + ctx, + label, + labelPos, + badgeMetrics.badgeCssW, + badgeMetrics.badgeCssH, + labelBg, + 'rgba(255, 255, 255, 0.5)', + [4, 3], + ) + ctx.restore() + } +} + export function findHoveredProxyConnector( worldX: number, worldY: number, - connectors: ZUIResolvedConnector[], - visibleAnchorsByNodeId: Map, + index: ProxyConnectorSpatialIndex, view: ZUIViewState, ): HoveredItem | null { const threshold = 18 / view.zoom - for (const connector of connectors) { - const source = visibleAnchorsByNodeId.get(connector.sourceNodeId) - const target = visibleAnchorsByNodeId.get(connector.targetNodeId) - if (!source || !target) continue - const x1 = source.worldX + source.worldW / 2 - const y1 = source.worldY + source.worldH / 2 - const x2 = target.worldX + target.worldW / 2 
- const y2 = target.worldY + target.worldH / 2 - const dx = x2 - x1 - const dy = y2 - y1 - const l2 = dx * dx + dy * dy - if (l2 === 0) continue - let t = ((worldX - x1) * dx + (worldY - y1) * dy) / l2 - t = Math.max(0, Math.min(1, t)) - const nearestX = x1 + t * dx - const nearestY = y1 + t * dy - const dist = Math.sqrt((worldX - nearestX) ** 2 + (worldY - nearestY) ** 2) - if (dist > threshold) continue + const startX = Math.floor((worldX - threshold) / index.cellSize) + const endX = Math.floor((worldX + threshold) / index.cellSize) + const startY = Math.floor((worldY - threshold) / index.cellSize) + const endY = Math.floor((worldY + threshold) / index.cellSize) + const thresholdSquared = threshold * threshold + const seen = new Set() + let bestConnector: IndexedProxyConnector | null = null + let bestDistSquared = thresholdSquared - return { - type: 'edge', - data: { - sourceId: connector.details.sourceAnchorName, - targetId: connector.details.targetAnchorName, - label: connector.details.label || 'Cross-branch connector', - diagramId: connector.details.ownerViewIds[0] ?? 0, - sourceObjId: connector.sourceAnchorElementId, - targetObjId: connector.targetAnchorElementId, - isProxy: true, - details: connector.details, - }, - absX: (x1 + x2) / 2, - absY: (y1 + y2) / 2, + for (let cx = startX; cx <= endX; cx++) { + for (let cy = startY; cy <= endY; cy++) { + const bucket = index.cells.get(proxyCellKey(cx, cy)) + if (!bucket) continue + + for (const indexed of bucket) { + const connector = indexed.connector + if (seen.has(connector.key)) continue + seen.add(connector.key) + const x1 = indexed.x1 + const y1 = indexed.y1 + const x2 = indexed.x2 + const y2 = indexed.y2 + const dx = x2 - x1 + const dy = y2 - y1 + const l2 = dx * dx + dy * dy + if (l2 === 0) continue + let t = ((worldX - x1) * dx + (worldY - y1) * dy) / l2 + t = Math.max(0, Math.min(1, t)) + const nearestX = x1 + t * dx + const nearestY = y1 + t * dy + const distSquared = (worldX - nearestX) ** 2 + (worldY - nearestY) ** 2 + if (distSquared > bestDistSquared) continue + bestDistSquared = distSquared + bestConnector = indexed + } } } - return null + + if (!bestConnector) return null + + const connector = bestConnector.connector + return { + type: 'edge', + data: { + sourceId: connector.details.sourceAnchorName, + targetId: connector.details.targetAnchorName, + label: connector.details.label || 'Cross-branch connector', + diagramId: connector.details.ownerViewIds[0] ?? 
0, + sourceObjId: connector.sourceAnchorElementId, + targetObjId: connector.targetAnchorElementId, + isProxy: true, + details: connector.details, + }, + absX: bestConnector.midX, + absY: bestConnector.midY, + } } diff --git a/frontend/src/components/ZUI/renderer.ts b/frontend/src/components/ZUI/renderer.ts index 692b12c..5a14b5a 100644 --- a/frontend/src/components/ZUI/renderer.ts +++ b/frontend/src/components/ZUI/renderer.ts @@ -5,6 +5,7 @@ import { DEFAULT_SOURCE_HANDLE_SIDE, DEFAULT_TARGET_HANDLE_SIDE, getHandleFlowPosition, + getHandleSlotOffsetFromId, getLogicalHandleId, getVisualHandleIdForGroup, } from '../../utils/edgeDistribution' @@ -21,15 +22,21 @@ export function getExpandThresholds(canvasW: number) { const MIN_LABEL_PX = 12 // below this screen width, skip label text const MIN_DRAW_PX = 2 // below this screen width, skip node entirely const BADGE_THRESHOLD = 100 // node width in screen pixels below which we hide type badge and zoom icon +const CONNECTOR_MIN_ALPHA = 0.32 +const CONNECTOR_MAX_ALPHA = 0.95 +const CONNECTOR_LINE_PX = 2 // ── Screen-space font limits (px) ────────────────────────────────── -const MIN_FONT_NAME = 10 -const MAX_FONT_NAME = 50 -const MIN_FONT_BADGE = 12 -const MAX_FONT_BADGE = 30 const MIN_FONT_HINT = 12 const MAX_FONT_HINT = 24 +// Match ViewEditor ElementNode: nameSize="xl" (20px) and typeSize="2xs" +// (10px), rounded="lg" (8px), on the default 85px-high node. +const VIEW_EDITOR_NODE_H = 85 +const NAME_FONT_TO_NODE_H = 20 / VIEW_EDITOR_NODE_H +const TYPE_FONT_TO_NODE_H = 10 / VIEW_EDITOR_NODE_H +const RADIUS_TO_NODE_H = 8 / VIEW_EDITOR_NODE_H + export interface ScreenRect { left: number top: number @@ -48,20 +55,6 @@ function getClampedFontSize(baseWorldSize: number, minScreenSize: number, maxScr return clamp(baseWorldSize, minScreenSize / zoom, maxScreenSize / zoom) } -// ── Chakra v2 type palette - mirrors TYPE_COLORS in src/types/index.ts ─ -// .400 variants: used for type badge text and border tint -const TYPE_COLOR_400: Record = { - person: '#38b2ac', // teal.400 - system: '#63b3ed', // blue.400 - container: '#9f7aea', // purple.400 - component: '#f6ad55', // orange.400 - database: '#4fd1c5', // cyan.400 - queue: '#f6e05e', // yellow.400 - api: '#68d391', // green.400 - service: '#f687b3', // pink.400 - external: '#a0aec0', // gray.400 -} - /** Border color: type .400 at 50% alpha - bold branded tint */ const typeBorderColorCache = new Map() function typeBorderColor(type: string, alpha = 0.5): string { @@ -69,8 +62,7 @@ function typeBorderColor(type: string, alpha = 0.5): string { const cached = typeBorderColorCache.get(cacheKey) if (cached) return cached - const color = TYPE_COLOR_400[type] - const hex = typeof color === 'string' ? color : '#a0aec0' + const hex = '#a0aec0' const r = parseInt(hex.slice(1, 3), 16) const g = parseInt(hex.slice(3, 5), 16) const b = parseInt(hex.slice(5, 7), 16) @@ -144,6 +136,19 @@ export function setHiddenTags(tags: Set): void { currentHiddenTags = tags } +let currentVersionElementChanges: Map = new Map() +let currentVersionConnectorChanges: Map = new Map() +let currentVersionElementLineDeltas: Map = new Map() +export function setVersionDiff( + elementChanges: Map, + connectorChanges: Map, + elementLineDeltas: Map = new Map(), +): void { + currentVersionElementChanges = elementChanges + currentVersionConnectorChanges = connectorChanges + currentVersionElementLineDeltas = elementLineDeltas +} + /** * Get image from cache or start loading it. * Returns the image if already loaded, null otherwise. 
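The `setVersionDiff` hook added above stores three module-level maps that `drawNode` later reads to highlight added/deleted/changed elements, dim everything else, and draw `+N -M` line-delta badges. A minimal sketch of how a caller might populate those maps follows; the generic parameters were stripped from this hunk, so the `Map` key/value shapes, the `'modified'` state, the `DiffSummaryEntry` input, and the `applyVersionDiffToRenderer` helper are assumptions made for illustration, not part of this change.

```ts
import { setVersionDiff } from './renderer'

// Assumed shapes: keys are the numeric elementId / connector id used elsewhere
// in the ZUI layout, and the third state besides 'added'/'deleted' (rendered
// yellow in drawNode) is assumed to be called 'modified'.
type ElementChange = 'added' | 'deleted' | 'modified'

interface DiffSummaryEntry {
  elementId: number
  change: ElementChange
  linesAdded: number
  linesRemoved: number
}

export function applyVersionDiffToRenderer(
  entries: DiffSummaryEntry[],
  changedConnectorIds: number[],
): void {
  const elementChanges = new Map<number, ElementChange>()
  const lineDeltas = new Map<number, { added: number; removed: number }>()

  for (const entry of entries) {
    elementChanges.set(entry.elementId, entry.change)
    lineDeltas.set(entry.elementId, { added: entry.linesAdded, removed: entry.linesRemoved })
  }

  // Connector-level entries only need to exist in the map for the renderer to
  // keep those edges at normal opacity while a version preview is active.
  const connectorChanges = new Map<number, ElementChange>(
    changedConnectorIds.map((id): [number, ElementChange] => [id, 'modified']),
  )

  setVersionDiff(elementChanges, connectorChanges, lineDeltas)
}
```

Passing empty maps clears the preview, since `drawNode` and `drawEdges` only dim unchanged items when either change map is non-empty.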
@@ -168,10 +173,125 @@ function clamp(v: number, min: number, max: number): number { return v < min ? min : v > max ? max : v } -function transitionT(screenW: number, start: number, end: number): number { +export function viewOriginX(view: ZUIViewState): number { + return view.originX ?? 0 +} + +export function viewOriginY(view: ZUIViewState): number { + return view.originY ?? 0 +} + +export function worldToScreenX(worldX: number, view: ZUIViewState): number { + return (worldX - viewOriginX(view)) * view.zoom + view.x +} + +export function worldToScreenY(worldY: number, view: ZUIViewState): number { + return (worldY - viewOriginY(view)) * view.zoom + view.y +} + +export function screenToWorldX(screenX: number, view: ZUIViewState): number { + return viewOriginX(view) + (screenX - view.x) / view.zoom +} + +export function screenToWorldY(screenY: number, view: ZUIViewState): number { + return viewOriginY(view) + (screenY - view.y) / view.zoom +} + +export function rawCameraView(view: ZUIViewState): ZUIViewState { + return { + x: view.x - viewOriginX(view) * view.zoom, + y: view.y - viewOriginY(view) * view.zoom, + zoom: view.zoom, + } +} + +function connectorAlpha(alpha: number): number { + return clamp(alpha * 1.15, CONNECTOR_MIN_ALPHA, CONNECTOR_MAX_ALPHA) +} + +function normalizeEdgeRouteType(type: string | null | undefined): 'bezier' | 'straight' | 'step' | 'smoothstep' { + if (type === 'straight' || type === 'step' || type === 'smoothstep') return type + return 'bezier' +} + +export interface ZUITransitionRebase { + preserveChildAlphaNodeIds: Set +} + +export function transitionT(screenW: number, start: number, end: number): number { return clamp((screenW - start) / (end - start), 0, 1) } +export function buildCameraTransitionRebase( + groups: DiagramGroupLayout[], + view: ZUIViewState, + canvasW: number, + canvasH: number, + thresholds: { start: number; end: number }, +): ZUITransitionRebase { + if (canvasW <= 0 || canvasH <= 0 || view.zoom <= 0) { + return { preserveChildAlphaNodeIds: new Set() } + } + + const worldCenterX = screenToWorldX(canvasW / 2, view) + const worldCenterY = screenToWorldY(canvasH / 2, view) + const path: Array<{ id: string; t: number }> = [] + + for (const group of groups) { + if ( + worldCenterX < group.worldX || + worldCenterX > group.worldX + group.worldW || + worldCenterY < group.worldY || + worldCenterY > group.worldY + group.worldH + ) { + continue + } + + let currentX = worldCenterX + let currentY = worldCenterY + let currentNodes = group.nodes + let cumulativeScale = 1 + + while (true) { + const node = currentNodes.find((candidate) => + currentX >= candidate.worldX && + currentX <= candidate.worldX + candidate.worldW && + currentY >= candidate.worldY && + currentY <= candidate.worldY + candidate.worldH + ) + + if (!node) break + + const hasChildren = node.children && node.children.length > 0 + const screenW = node.worldW * view.zoom * cumulativeScale + const t = hasChildren ? 
transitionT(screenW, thresholds.start, thresholds.end) : 0 + path.push({ id: node.id, t }) + + if (!hasChildren || t <= 0.05 || node.childScale <= 0) break + + currentX = (currentX - node.worldX) / node.childScale + node.childOffsetX + currentY = (currentY - node.worldY) / node.childScale + node.childOffsetY + currentNodes = node.children + cumulativeScale *= node.childScale + } + + break + } + + const activeTransitionIndexes = path + .map((entry, index) => ({ ...entry, index })) + .filter((entry) => entry.t > 0.05 && entry.t < 0.95) + + if (activeTransitionIndexes.length <= 1) { + return { preserveChildAlphaNodeIds: new Set() } + } + + const deepestActiveIndex = activeTransitionIndexes[activeTransitionIndexes.length - 1].index + return { + preserveChildAlphaNodeIds: new Set(path.slice(0, deepestActiveIndex).map((entry) => entry.id)), + } +} + function rectsOverlap(a: ScreenRect, b: ScreenRect): boolean { return a.left < b.right && a.right > b.left && a.top < b.bottom && a.bottom > b.top } @@ -262,8 +382,8 @@ export function isVisible( worldX: number, worldY: number, worldW: number, worldH: number, view: ZUIViewState, canvasW: number, canvasH: number, ): boolean { - const sx = worldX * view.zoom + view.x - const sy = worldY * view.zoom + view.y + const sx = worldToScreenX(worldX, view) + const sy = worldToScreenY(worldY, view) const sw = worldW * view.zoom const sh = worldH * view.zoom return sx + sw > 0 && sy + sh > 0 && sx < canvasW && sy < canvasH @@ -274,13 +394,204 @@ export function isFullyVisible( worldX: number, worldY: number, worldW: number, worldH: number, view: ZUIViewState, canvasW: number, canvasH: number, ): boolean { - const sx = worldX * view.zoom + view.x - const sy = worldY * view.zoom + view.y + const sx = worldToScreenX(worldX, view) + const sy = worldToScreenY(worldY, view) const sw = worldW * view.zoom const sh = worldH * view.zoom return sx >= 0 && sy >= 0 && sx + sw <= canvasW && sy + sh <= canvasH } +export interface ZUICameraRebase { + originX: number + originY: number + view: ZUIViewState +} + +export function getCameraRebase(view: ZUIViewState, canvasW: number, canvasH: number): ZUICameraRebase { + const zoom = Math.max(0.0001, view.zoom) + return { + originX: screenToWorldX(canvasW / 2, { ...view, zoom }), + originY: screenToWorldY(canvasH / 2, { ...view, zoom }), + view: { + x: canvasW / 2, + y: canvasH / 2, + zoom: view.zoom, + }, + } +} + +function rebaseRootNodeForRender(node: LayoutNode, rebase: ZUICameraRebase): LayoutNode { + return { + ...node, + worldX: node.worldX - rebase.originX, + worldY: node.worldY - rebase.originY, + } +} + +interface RebasedRenderGroup { + sourceNodes: LayoutNode[] + group: DiagramGroupLayout +} + +const rebasedRenderGroupCache = new WeakMap() + +function rebaseGroupForRender(group: DiagramGroupLayout, rebase: ZUICameraRebase): DiagramGroupLayout { + let cached = rebasedRenderGroupCache.get(group) + if (!cached || cached.sourceNodes !== group.nodes) { + cached = { + sourceNodes: group.nodes, + group: { + ...group, + nodes: group.nodes.map((node) => rebaseRootNodeForRender(node, rebase)), + }, + } + rebasedRenderGroupCache.set(group, cached) + } + + cached.group.worldX = group.worldX - rebase.originX + cached.group.worldY = group.worldY - rebase.originY + cached.group.worldW = group.worldW + cached.group.worldH = group.worldH + cached.group.diagramW = group.diagramW + cached.group.diagramH = group.diagramH + cached.group.diagramX = group.diagramX + cached.group.diagramY = group.diagramY + cached.group.edges = group.edges + + 
for (let index = 0; index < group.nodes.length; index += 1) { + const source = group.nodes[index] + const target = cached.group.nodes[index] + Object.assign(target, source, { + worldX: source.worldX - rebase.originX, + worldY: source.worldY - rebase.originY, + }) + } + + return cached.group +} + +interface FocusedFlattenedLayer { + nodes: LayoutNode[] + view: ZUIViewState +} + +function flattenNodeForRender( + node: LayoutNode, + absX: number, + absY: number, + layerScale: number, + rebase: ZUICameraRebase, +): LayoutNode { + return { + ...node, + worldX: (absX - rebase.originX) / layerScale, + worldY: (absY - rebase.originY) / layerScale, + worldW: node.worldW, + worldH: node.worldH, + children: [], + } +} + +function flattenSiblingLayerForRender( + nodes: LayoutNode[], + parentAbsX: number, + parentAbsY: number, + parentAbsScale: number, + parentChildOffsetX: number, + parentChildOffsetY: number, + rebase: ZUICameraRebase, +): LayoutNode[] { + return nodes.map((node) => { + const absX = parentAbsX + (node.worldX - parentChildOffsetX) * parentAbsScale + const absY = parentAbsY + (node.worldY - parentChildOffsetY) * parentAbsScale + return flattenNodeForRender(node, absX, absY, parentAbsScale, rebase) + }) +} + +export function findFocusedFlattenedLayerForTest( + groups: DiagramGroupLayout[], + view: ZUIViewState, + canvasW: number, + canvasH: number, + thresholds: { start: number; end: number }, + rebase: ZUICameraRebase, +): FocusedFlattenedLayer | null { + if (canvasW <= 0 || canvasH <= 0 || view.zoom < 1_000_000) return null + + const worldCenterX = screenToWorldX(canvasW / 2, view) + const worldCenterY = screenToWorldY(canvasH / 2, view) + + for (const group of groups) { + if ( + worldCenterX < group.worldX || + worldCenterX > group.worldX + group.worldW || + worldCenterY < group.worldY || + worldCenterY > group.worldY + group.worldH + ) { + continue + } + + let currentX = worldCenterX + let currentY = worldCenterY + let currentNodes = group.nodes + let parentAbsX = 0 + let parentAbsY = 0 + let parentAbsScale = 1 + let parentChildOffsetX = 0 + let parentChildOffsetY = 0 + let focusedLayer: FocusedFlattenedLayer | null = null + + while (true) { + const node = currentNodes.find((candidate) => + currentX >= candidate.worldX && + currentX <= candidate.worldX + candidate.worldW && + currentY >= candidate.worldY && + currentY <= candidate.worldY + candidate.worldH + ) + if (!node) break + + const absX = parentAbsX + (node.worldX - parentChildOffsetX) * parentAbsScale + const absY = parentAbsY + (node.worldY - parentChildOffsetY) * parentAbsScale + const hasChildren = node.children && node.children.length > 0 + const screenW = node.worldW * parentAbsScale * view.zoom + const t = hasChildren ? 
transitionT(screenW, thresholds.start, thresholds.end) : 0 + + if (!hasChildren || t < 0.95 || node.childScale <= 0) break + + const childAbsScale = parentAbsScale * node.childScale + focusedLayer = { + nodes: flattenSiblingLayerForRender( + node.children, + absX, + absY, + childAbsScale, + node.childOffsetX, + node.childOffsetY, + rebase, + ), + view: { + x: canvasW / 2, + y: canvasH / 2, + zoom: view.zoom * childAbsScale, + }, + } + + currentX = (currentX - node.worldX) / node.childScale + node.childOffsetX + currentY = (currentY - node.worldY) / node.childScale + node.childOffsetY + currentNodes = node.children + parentAbsX = absX + parentAbsY = absY + parentAbsScale = childAbsScale + parentChildOffsetX = node.childOffsetX + parentChildOffsetY = node.childOffsetY + } + + return focusedLayer + } + + return null +} + /** Draw the ZoomIn magnifying glass icon. */ function drawZoomInIcon(ctx: CanvasRenderingContext2D, x: number, y: number, size: number, strokeWidth: number): void { ctx.save() @@ -306,29 +617,7 @@ function drawZoomInIcon(ctx: CanvasRenderingContext2D, x: number, y: number, siz ctx.restore() } -/** Draw a portal arrow icon (↗) for portal nodes. */ -function drawPortalIcon(ctx: CanvasRenderingContext2D, x: number, y: number, size: number, strokeWidth: number, color: string): void { - ctx.save() - ctx.strokeStyle = color - ctx.lineWidth = strokeWidth - ctx.lineCap = 'round' - ctx.lineJoin = 'round' - ctx.translate(x, y) - const s = size / 16 - ctx.scale(s, s) - ctx.beginPath() - // Arrow shaft: (2,14) → (13,3) - ctx.moveTo(2, 14) - ctx.lineTo(13, 3) - // Arrow head - ctx.moveTo(5, 3) - ctx.lineTo(13, 3) - ctx.lineTo(13, 11) - ctx.stroke() - ctx.restore() -} - -/** Draw a cycle icon (↺) for circular nodes. */ +/** Draw a cycle icon (↺) for circular nodes. NOT USED CURRENTLY */ function drawCycleIcon(ctx: CanvasRenderingContext2D, x: number, y: number, size: number, strokeWidth: number, color: string): void { ctx.save() ctx.strokeStyle = color @@ -371,23 +660,26 @@ function portalTintColor(accent: string, alpha: number): string { return rgba } -/** Draw a squiggly line from (x1, y1) to (x2, y2). */ -function drawSquigglyLine(ctx: CanvasRenderingContext2D, x1: number, y1: number, x2: number, y2: number, zoom: number): void { - ctx.save() - ctx.beginPath() - ctx.moveTo(x1, y1) - ctx.lineTo(x2, y2) - const dashLen = 6 / zoom - ctx.setLineDash([dashLen, dashLen * 1.5]) - ctx.stroke() - ctx.restore() -} - /** Calculate coordinate for a named handle on a node. */ -function getHandlePos(nodeX: number, nodeY: number, nodeW: number, nodeH: number, handleId: string | null, isSource: boolean): { x: number, y: number, pos: 'top' | 'bottom' | 'left' | 'right' } { +function getHandlePos(nodeX: number, nodeY: number, nodeW: number, nodeH: number, handleId: string | null, isSource: boolean, slotScale = 1): { x: number, y: number, pos: 'top' | 'bottom' | 'left' | 'right' } { const fallback = isSource ? DEFAULT_SOURCE_HANDLE_SIDE : DEFAULT_TARGET_HANDLE_SIDE - const { x, y, side } = getHandleFlowPosition(nodeX, nodeY, nodeW, nodeH, handleId, fallback) - return { x, y, pos: side } + if (slotScale === 1) { + const { x, y, side } = getHandleFlowPosition(nodeX, nodeY, nodeW, nodeH, handleId, fallback) + return { x, y, pos: side } + } + + const side = getLogicalHandleId(handleId, fallback) ?? 
fallback + const offset = getHandleSlotOffsetFromId(handleId) * slotScale + switch (side) { + case 'top': + return { x: nodeX + nodeW / 2 + offset, y: nodeY, pos: side } + case 'bottom': + return { x: nodeX + nodeW / 2 + offset, y: nodeY + nodeH, pos: side } + case 'left': + return { x: nodeX, y: nodeY + nodeH / 2 + offset, pos: side } + case 'right': + return { x: nodeX + nodeW, y: nodeY + nodeH / 2 + offset, pos: side } + } } /** Draw a closed arrow head matching React Flow MarkerType.ArrowClosed. */ @@ -443,6 +735,7 @@ function drawNode( absY: number, absScale: number, occupiedLabelRects: ScreenRect[], + transitionRebase: ZUITransitionRebase, ): void { if (screenW < MIN_DRAW_PX || alpha < 0.01) return @@ -474,8 +767,8 @@ function drawNode( } const parentAlpha = alpha * (1 - t) - const childAlpha = alpha * t - const r = 8 / drawZoom // matches Chakra rounded="lg" (8px) + const childAlpha = transitionRebase.preserveChildAlphaNodeIds.has(node.id) ? alpha : alpha * t + const r = h * RADIUS_TO_NODE_H const borderColor = typeBorderColor(node.type) @@ -519,6 +812,19 @@ function drawNode( ctx.restore() } + // ── Shadow ─────────────────────────────────────────────────────── + // Subtler shadow for Canvas performance + if (parentAlpha > 0.5 && screenW > 40) { + ctx.save() + ctx.globalAlpha = parentAlpha * 0.4 + ctx.shadowColor = 'rgba(0, 0, 0, 0.5)' + ctx.shadowBlur = 12 / drawZoom + ctx.shadowOffsetY = 4 / drawZoom + traceShape() + ctx.fill() + ctx.restore() + } + // ── Background ─────────────────────────────────────────────────── // We draw two backgrounds: // 1. A base background (canvasBg) that remains opaque (total 'alpha'). @@ -605,10 +911,7 @@ function drawNode( // ── Label - portal shows "PORTAL" badge in accent; otherwise type badge ─ if (screenW >= MIN_LABEL_PX && parentAlpha > 0.1) { - // Dynamic minimum: don't let font be larger than a fraction of node height on screen - const minName = Math.min(MIN_FONT_NAME, screenW * 0.35) - // w=200, so 0.10w = 20px (Chakra 'xl') - const nameFontSize = getClampedFontSize(w * 0.10, minName, MAX_FONT_NAME, drawZoom) + const nameFontSize = h * NAME_FONT_TO_NODE_H const screenFontSize = nameFontSize * drawZoom if (screenFontSize >= 6) { @@ -637,13 +940,10 @@ function drawNode( // Type badge - using regular element type display if (drawScreenW > BADGE_THRESHOLD) { - const minBadge = Math.min(MIN_FONT_BADGE, screenW * 0.20) - // 0.05w = 10px (Chakra '2xs') - const badgeFontSize = getClampedFontSize(w * 0.05, minBadge, MAX_FONT_BADGE, drawZoom) + const badgeFontSize = h * TYPE_FONT_TO_NODE_H if (badgeFontSize * drawZoom >= 5) { ctx.font = `${badgeFontSize}px Inter, system-ui, sans-serif` - const badgeColor = TYPE_COLOR_400[node.type] - ctx.fillStyle = typeof badgeColor === 'string' ? badgeColor : '#a0aec0' + ctx.fillStyle = '#a0aec0' const displayType = typeof node.type === 'string' ? 
node.type.toUpperCase() : 'UNKNOWN' ctx.fillText(displayType, x + w / 2, y + h * (0.62 + baseOffset)) } @@ -663,13 +963,13 @@ function drawNode( if (t > 0.8) { // Sticky hint Y: stick to viewport bottom - const viewportBottomWorld = (canvasH - screenFontSize - view.y) / view.zoom + const viewportBottomWorld = screenToWorldY(canvasH - screenFontSize, view) hintY = Math.min(hintY, viewportBottomWorld) hintY = Math.max(hintY, y + h / 2) // avoid overlapping center // Sticky hint X: stick to viewport sides - const vwL = -view.x / view.zoom - const vwR = (canvasW - view.x) / view.zoom + const vwL = screenToWorldX(0, view) + const vwR = screenToWorldX(canvasW, view) ctx.save() ctx.font = `${hintFontSize}px Inter, system-ui, sans-serif` @@ -713,7 +1013,7 @@ function drawNode( // Recursive children's edges DRAWN FIRST (below nodes) if (childAlpha > 0.2) { - drawEdges(ctx, node.children, childAlpha * 0.5, edgeZoom, thresholds, accent, labelBg, occupiedLabelRects) + drawEdges(ctx, node.children, childAlpha * 0.8, edgeZoom, thresholds, accent, labelBg, occupiedLabelRects) } const nextAbsScale = absScale * node.childScale @@ -725,7 +1025,7 @@ function drawNode( if (!isVisible(childAbsX, childAbsY, childAbsW, childAbsH, view, canvasW, canvasH)) continue const childScreenW = child.worldW * childZoom - drawNode(ctx, child, childScreenW, thresholds, childAlpha, childZoom, nodeBg, canvasBg, view, canvasW, canvasH, accent, labelBg, childAbsX, childAbsY, nextAbsScale, occupiedLabelRects) + drawNode(ctx, child, childScreenW, thresholds, childAlpha, childZoom, nodeBg, canvasBg, view, canvasW, canvasH, accent, labelBg, childAbsX, childAbsY, nextAbsScale, occupiedLabelRects, transitionRebase) } ctx.restore() @@ -742,11 +1042,8 @@ function drawNode( ctx.strokeStyle = accent if (node.isCircular) { drawCycleIcon(ctx, x + w - iconSize - padding, y + padding, iconSize, 3.5, accent) - } else if (node.isPortal) { - // Portal: use arrow icon instead of magnifying glass - drawPortalIcon(ctx, x + w - iconSize - padding, y + padding, iconSize, 3.5, accent) } else { - drawZoomInIcon(ctx, x + w - iconSize - padding, y + padding, iconSize, 3.5) + drawZoomInIcon(ctx, x + w - iconSize - padding, y + padding, iconSize, 2.5) } ctx.restore() } @@ -777,6 +1074,58 @@ function drawNode( } } + if ((currentVersionElementChanges.size > 0 || currentVersionConnectorChanges.size > 0) && parentAlpha > 0.05) { + const change = currentVersionElementChanges.get(node.elementId) + if (!change) { + ctx.save() + ctx.globalAlpha = parentAlpha * 0.9 + ctx.fillStyle = canvasBg + traceShape() + ctx.fill() + ctx.restore() + } else { + const color = change === 'added' ? '#68d391' : change === 'deleted' ? '#fc8181' : '#f6e05e' + ctx.save() + ctx.globalAlpha = parentAlpha + ctx.shadowColor = color + ctx.shadowBlur = 8 / drawZoom + ctx.strokeStyle = color + ctx.lineWidth = 2.5 / drawZoom + traceShape() + ctx.stroke() + ctx.restore() + + } + } + + const delta = currentVersionElementLineDeltas.get(node.elementId) + if (delta && (delta.added > 0 || delta.removed > 0) && drawScreenW > 52 && parentAlpha > 0.05) { + const addText = delta.added > 0 ? `+${delta.added}` : '' + const removeText = delta.removed > 0 ? 
`-${delta.removed}` : '' + const badgeText = [addText, removeText].filter(Boolean).join(' ') + const fontSize = getClampedFontSize(12, 8, 13, drawZoom) + ctx.save() + ctx.globalAlpha = parentAlpha + ctx.font = `800 ${fontSize}px Inter, system-ui, sans-serif` + const textWidth = ctx.measureText(badgeText).width + const badgeW = textWidth + 12 / drawZoom + const badgeH = 20 / drawZoom + const badgeX = x + w - badgeW - 6 / drawZoom + const badgeY = y + h - badgeH - 6 / drawZoom + ctx.fillStyle = 'rgba(17, 24, 39, 0.9)' + ctx.strokeStyle = 'rgba(255, 255, 255, 0.22)' + ctx.lineWidth = 1 / drawZoom + ctx.beginPath() + ctx.roundRect(badgeX, badgeY, badgeW, badgeH, 5 / drawZoom) + ctx.fill() + ctx.stroke() + ctx.textAlign = 'center' + ctx.textBaseline = 'middle' + ctx.fillStyle = delta.added > 0 && delta.removed === 0 ? '#68d391' : delta.removed > 0 && delta.added === 0 ? '#fc8181' : '#e2e8f0' + ctx.fillText(badgeText, badgeX + badgeW / 2, badgeY + badgeH / 2) + ctx.restore() + } + if (!hasChildren && screenW > thresholds.end) { ctx.restore() } @@ -882,7 +1231,7 @@ function drawEdges( } const dir = edge.direction ?? 'forward' - const type = edge.type || 'bezier' + const type = normalizeEdgeRouteType(edge.type) // ── Effective visual dimensions (handles capping) ───────────── const hasSourceChildren = node.children && node.children.length > 0 @@ -922,6 +1271,7 @@ function drawEdges( effHSource, getVisualHandleIdForGroup(sourceSide, sourceGroupIndex, Math.max(srcGroup.length, 1)), true, + sSource, ) const tH = getHandlePos( effXTarget, @@ -930,12 +1280,15 @@ function drawEdges( effHTarget, getVisualHandleIdForGroup(targetSide, targetGroupIndex, Math.max(tgtGroup.length, 1)), false, + sTarget, ) ctx.save() - ctx.globalAlpha = alpha * 0.8 + const edgeChange = currentVersionConnectorChanges.get(edge.id) + const versionPreviewActive = currentVersionElementChanges.size > 0 || currentVersionConnectorChanges.size > 0 + ctx.globalAlpha = versionPreviewActive && !edgeChange ? Math.max(alpha * 0.18, 0.08) : connectorAlpha(alpha) ctx.strokeStyle = accent - ctx.lineWidth = 2 / zoom + ctx.lineWidth = CONNECTOR_LINE_PX / zoom let midX = (sH.x + tH.x) / 2 let midY = (sH.y + tH.y) / 2 @@ -1148,7 +1501,7 @@ function drawGroupLabel( const labelY = group.worldY + group.diagramY - 22 / view.zoom // Ensure label is within viewport - const screenY = labelY * view.zoom + view.y + const screenY = worldToScreenY(labelY, view) if (screenY < -20 || screenY > canvasH + 20) return ctx.save() @@ -1183,6 +1536,40 @@ function drawGroupLabel( } +/** Draw a dot grid matching React Flow default style. 
*/ +function drawGrid(ctx: CanvasRenderingContext2D, view: ZUIViewState, canvasW: number, canvasH: number): void { + const gridSize = 20 + const dotSize = 1.0 + const color = 'rgba(255, 255, 255, 0.1)' // subtle white dots on dark background + const rebase = getCameraRebase(view, canvasW, canvasH) + + const left = screenToWorldX(0, view) + const top = screenToWorldY(0, view) + const right = screenToWorldX(canvasW, view) + const bottom = screenToWorldY(canvasH, view) + + const startX = Math.floor(left / gridSize) * gridSize + const startY = Math.floor(top / gridSize) * gridSize + + ctx.save() + ctx.fillStyle = color + + // Dot grid rendering: only show if zoom is not too small + if (view.zoom > 0.2) { + for (let wx = startX; wx < right; wx += gridSize) { + for (let wy = startY; wy < bottom; wy += gridSize) { + const sx = (wx - rebase.originX) * rebase.view.zoom + rebase.view.x + const sy = (wy - rebase.originY) * rebase.view.zoom + rebase.view.y + + ctx.beginPath() + ctx.arc(sx, sy, dotSize, 0, Math.PI * 2) + ctx.fill() + } + } + } + ctx.restore() +} + // ── Public: render one frame ─────────────────────────────────────── /** @@ -1206,81 +1593,70 @@ export function renderFrame( ctx.fillStyle = canvasBg ctx.fillRect(0, 0, canvasW, canvasH) + drawGrid(ctx, view, canvasW, canvasH) + + const rebase = getCameraRebase(view, canvasW, canvasH) + const renderView = rebase.view + const renderGroups = groups.map((group) => rebaseGroupForRender(group, rebase)) // Apply world transform ctx.save() - ctx.translate(view.x, view.y) - ctx.scale(view.zoom, view.zoom) + ctx.translate(renderView.x, renderView.y) + ctx.scale(renderView.zoom, renderView.zoom) const thresholds = getExpandThresholds(canvasW) + const transitionRebase = buildCameraTransitionRebase(renderGroups, renderView, canvasW, canvasH, thresholds) const occupiedLabelRects = frameLabelRects occupiedLabelRects.length = 0 + const focusedLayer = findFocusedFlattenedLayerForTest(groups, view, canvasW, canvasH, thresholds, rebase) - for (const group of groups) { - if (!isVisible(group.worldX, group.worldY, group.worldW, group.worldH, view, canvasW, canvasH)) { + if (focusedLayer) { + ctx.restore() + ctx.save() + ctx.translate(focusedLayer.view.x, focusedLayer.view.y) + ctx.scale(focusedLayer.view.zoom, focusedLayer.view.zoom) + drawEdges(ctx, focusedLayer.nodes, 0.7, focusedLayer.view.zoom, thresholds, accent, labelBg, occupiedLabelRects) + for (const node of focusedLayer.nodes) { + if (!isVisible(node.worldX, node.worldY, node.worldW, node.worldH, focusedLayer.view, canvasW, canvasH)) { + continue + } + const screenW = node.worldW * focusedLayer.view.zoom + drawNode(ctx, node, screenW, thresholds, 1, focusedLayer.view.zoom, nodeBg, canvasBg, focusedLayer.view, canvasW, canvasH, accent, labelBg, node.worldX, node.worldY, 1, occupiedLabelRects, transitionRebase) + } + ctx.restore() + return occupiedLabelRects + } + + for (const group of renderGroups) { + if (!isVisible(group.worldX, group.worldY, group.worldW, group.worldH, renderView, canvasW, canvasH)) { continue } - drawGroupLabel(ctx, group, view, canvasW, canvasH, accent) + drawGroupLabel(ctx, group, renderView, canvasW, canvasH, accent) // ── Group box (diagram elements container) ────────────────────────── - const borderAlpha = clamp(0.5 - view.zoom * 0.05, 0.15, 0.5) + const borderAlpha = clamp(0.5 - renderView.zoom * 0.05, 0.15, 0.5) ctx.save() ctx.globalAlpha = borderAlpha ctx.strokeStyle = accent - ctx.lineWidth = 2 / view.zoom + ctx.lineWidth = 2 / renderView.zoom ctx.setLineDash([2, 2]) - 
// Only draw the border around the diagram part (not portals) + // Only draw the border around the diagram part ctx.strokeRect(group.worldX + group.diagramX, group.worldY + group.diagramY, group.diagramW, group.diagramH) ctx.setLineDash([]) ctx.restore() - // ── Squiggly edges to portal nodes ──────────────────────────────── - ctx.save() - ctx.strokeStyle = accent - ctx.setLineDash([]) - ctx.lineWidth = 2 / view.zoom - ctx.globalAlpha = 0.6 - for (const node of group.nodes) { - if (node.isPortal) { - // Draw squiggle/dash from diagram box boundary to portal box boundary - const cx = group.worldX + group.diagramX + group.diagramW / 2 - const cy = group.worldY + group.diagramY + group.diagramH / 2 - const px = node.worldX + node.worldW / 2 - const py = node.worldY + node.worldH / 2 - - const dx = px - cx - const dy = py - cy - - const getBBoxIntersection = (boxW: number, boxH: number, targetDX: number, targetDY: number) => { - const hw = boxW / 2 + 10 // pad - const hh = boxH / 2 + 10 // pad - if (Math.abs(targetDX * hh) > Math.abs(targetDY * hw)) { - return { x: Math.sign(targetDX) * hw, y: targetDY * (hw / Math.abs(targetDX)) } - } else { - return { x: targetDX * (hh / Math.abs(targetDY)), y: Math.sign(targetDY) * hh } - } - } - - const start = getBBoxIntersection(group.diagramW, group.diagramH, dx, dy) - const end = getBBoxIntersection(node.worldW, node.worldH, -dx, -dy) - - drawSquigglyLine(ctx, cx + start.x, cy + start.y, px + end.x, py + end.y, view.zoom) - } - } - ctx.restore() - // Edges in this group - drawEdges(ctx, group.nodes, 0.7, view.zoom, thresholds, accent, labelBg, occupiedLabelRects) + drawEdges(ctx, group.nodes, 0.7, renderView.zoom, thresholds, accent, labelBg, occupiedLabelRects) // Nodes in this group for (const node of group.nodes) { - if (!isVisible(node.worldX, node.worldY, node.worldW, node.worldH, view, canvasW, canvasH)) { + if (!isVisible(node.worldX, node.worldY, node.worldW, node.worldH, renderView, canvasW, canvasH)) { continue } - const screenW = node.worldW * view.zoom - drawNode(ctx, node, screenW, thresholds, 1, view.zoom, nodeBg, canvasBg, view, canvasW, canvasH, accent, labelBg, node.worldX, node.worldY, 1, occupiedLabelRects) + const screenW = node.worldW * renderView.zoom + drawNode(ctx, node, screenW, thresholds, 1, renderView.zoom, nodeBg, canvasBg, renderView, canvasW, canvasH, accent, labelBg, node.worldX, node.worldY, 1, occupiedLabelRects, transitionRebase) } } diff --git a/frontend/src/components/ZUI/types.ts b/frontend/src/components/ZUI/types.ts index 31710ed..9324d77 100644 --- a/frontend/src/components/ZUI/types.ts +++ b/frontend/src/components/ZUI/types.ts @@ -10,6 +10,10 @@ export interface ZUIViewState { y: number /** Current zoom multiplier (1 = 1 world-pixel per screen-pixel). */ zoom: number + /** World-space X rendered at the local camera origin. Keeps x screen-sized at deep zoom. */ + originX?: number + /** World-space Y rendered at the local camera origin. Keeps y screen-sized at deep zoom. */ + originY?: number } /** @@ -71,6 +75,7 @@ export interface LayoutNode { // ── Edges within the same diagram ──────────────────────────────── edgesOut: Array<{ + id: number /** LayoutNode id of the target. */ targetId: string label: string @@ -103,6 +108,7 @@ export interface DiagramGroupLayout { nodes: LayoutNode[] /** Edges whose both endpoints are in this diagram. 
*/ edges: Array<{ + id: number sourceId: string targetId: string label: string diff --git a/frontend/src/components/ZUI/useZUIInteraction.ts b/frontend/src/components/ZUI/useZUIInteraction.ts index 0fc9021..86479b7 100644 --- a/frontend/src/components/ZUI/useZUIInteraction.ts +++ b/frontend/src/components/ZUI/useZUIInteraction.ts @@ -2,23 +2,39 @@ import { useCallback, useEffect, useRef, useState, useMemo } from 'react' import type { BBox, DiagramGroupLayout, LayoutNode, ZUIViewState, HoveredItem } from './types' -import { getExpandThresholds } from './renderer' +import { getExpandThresholds, screenToWorldX, screenToWorldY, viewOriginX, viewOriginY } from './renderer' + +export function constrainViewState(view: ZUIViewState, canvasW: number, canvasH: number, bbox: BBox): ZUIViewState { + const padding = Math.min(600, canvasW * 0.45, canvasH * 0.45) + const normalized = normalizeViewState(view, canvasW, canvasH) + const halfVisibleX = Math.max(0, canvasW / 2 - padding) / normalized.zoom + const halfVisibleY = Math.max(0, canvasH / 2 - padding) / normalized.zoom + const minOriginX = bbox.minX - halfVisibleX + const maxOriginX = bbox.maxX + halfVisibleX + const minOriginY = bbox.minY - halfVisibleY + const maxOriginY = bbox.maxY + halfVisibleY -function constrainViewState(view: ZUIViewState, canvasW: number, canvasH: number, bbox: BBox): ZUIViewState { - const padding = 600 // pixels - const minX = padding - bbox.maxX * view.zoom - const maxX = canvasW - padding - bbox.minX * view.zoom - const minY = padding - bbox.maxY * view.zoom - const maxY = canvasH - padding - bbox.minY * view.zoom - - let { x, y } = view - if (maxX >= minX) x = Math.max(minX, Math.min(maxX, x)) - else x = (minX + maxX) / 2 - - if (maxY >= minY) y = Math.max(minY, Math.min(maxY, y)) - else y = (minY + maxY) / 2 + return { + ...normalized, + originX: maxOriginX >= minOriginX + ? Math.max(minOriginX, Math.min(maxOriginX, viewOriginX(normalized))) + : (minOriginX + maxOriginX) / 2, + originY: maxOriginY >= minOriginY + ? 
Math.max(minOriginY, Math.min(maxOriginY, viewOriginY(normalized))) + : (minOriginY + maxOriginY) / 2, + } +} - return { ...view, x, y } +function normalizeViewState(view: ZUIViewState, canvasW: number, canvasH: number): ZUIViewState { + const zoom = Math.max(0.0001, view.zoom) + return { + ...view, + x: canvasW / 2, + y: canvasH / 2, + zoom, + originX: screenToWorldX(canvasW / 2, { ...view, zoom }), + originY: screenToWorldY(canvasH / 2, { ...view, zoom }), + } } interface DeepestNodeResult { @@ -392,6 +408,12 @@ export function calculateMaxZoom(groups: DiagramGroupLayout[], canvasW: number): } const MIN_ZOOM = 0.4 +const ZUI_NATIVE_WHEEL_SELECTOR = '[data-zui-native-wheel="true"]' + +function shouldIgnoreCapturedWheel(e: WheelEvent): boolean { + const target = e.target + return target instanceof Element && target.closest(ZUI_NATIVE_WHEEL_SELECTOR) !== null +} function clampZoom(z: number, prevZ: number, maxZ: number): number { if (z > prevZ) { @@ -412,11 +434,16 @@ function zoomAround( maxZoom: number, ): ZUIViewState { const newZoom = clampZoom(view.zoom * factor, view.zoom, maxZoom) - const scale = newZoom / view.zoom + const worldX = screenToWorldX(focalX, view) + const worldY = screenToWorldY(focalY, view) + const originX = viewOriginX(view) + const originY = viewOriginY(view) return { + originX, + originY, zoom: newZoom, - x: focalX - (focalX - view.x) * scale, - y: focalY - (focalY - view.y) * scale, + x: focalX - (worldX - originX) * newZoom, + y: focalY - (worldY - originY) * newZoom, } } @@ -590,6 +617,17 @@ export function useZUIInteraction( if (!el) return function onWheel(e: WheelEvent) { + if (shouldIgnoreCapturedWheel(e)) return + + const rect = el!.getBoundingClientRect() + const isInsideCanvas = + e.clientX >= rect.left && + e.clientX <= rect.right && + e.clientY >= rect.top && + e.clientY <= rect.bottom + + if (!isInsideCanvas) return + // Heuristic to distinguish between trackpad and physical mouse wheel: // 1. If ctrlKey is true, it's a pinch (trackpad) or Ctrl+Wheel. We always zoom. // 2. If deltaMode !== 0, it's a physical mouse wheel (DOM_DELTA_LINE/PAGE). We zoom. 
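The rewritten zoomAround keeps the world point under the pointer fixed while the zoom changes, now expressed relative to the rebased origin. Below is a self-contained check of that invariant; the view shape and helper names are simplified sketches of the patch's ZUIViewState and screenToWorldX.

```ts
// Self-contained check that zoom-around-a-focal-point keeps the world point
// under the cursor stationary.
interface ViewSketch { x: number; y: number; zoom: number; originX?: number; originY?: number }

const originOf = (view: ViewSketch) => view.originX ?? 0
const toWorldX = (screenX: number, view: ViewSketch) => (screenX - view.x) / view.zoom + originOf(view)

function zoomAroundX(view: ViewSketch, focalX: number, newZoom: number): ViewSketch {
  const worldX = toWorldX(focalX, view)
  // Solve screen = (world - origin) * zoom + x for x so worldX keeps projecting to focalX.
  return { ...view, zoom: newZoom, x: focalX - (worldX - originOf(view)) * newZoom }
}

const before: ViewSketch = { x: 100, y: 0, zoom: 1 }        // world 400 sits at screen 500
const after = zoomAroundX(before, 500, 2)                   // zoom in to 2x around screen x = 500
console.log((400 - originOf(after)) * after.zoom + after.x) // 500, the point did not move
```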
@@ -619,7 +657,6 @@ export function useZUIInteraction( const isRealMouseWheel = e.deltaMode !== 0 || isNotchedWheel if (isPinch || isRealMouseWheel) { - const rect = el!.getBoundingClientRect() const focalX = e.clientX - rect.left const focalY = e.clientY - rect.top @@ -628,8 +665,8 @@ export function useZUIInteraction( factor = Math.max(0.85, Math.min(1.15, factor)) setViewState((prev) => { - const worldX = (focalX - prev.x) / prev.zoom - const worldY = (focalY - prev.y) / prev.zoom + const worldX = screenToWorldX(focalX, prev) + const worldY = screenToWorldY(focalY, prev) const thresholds = getExpandThresholds(rect.width) const deepest = findDeepestAt(worldX, worldY, groupsRef.current, prev, thresholds) @@ -677,8 +714,8 @@ export function useZUIInteraction( // Hover detection const view = viewStateRef.current - const worldX = (screenX - view.x) / view.zoom - const worldY = (screenY - view.y) / view.zoom + const worldX = screenToWorldX(screenX, view) + const worldY = screenToWorldY(screenY, view) const thresholds = getExpandThresholds(rect.width) const deepest = findDeepestAt(worldX, worldY, groupsRef.current, view, thresholds) @@ -731,8 +768,8 @@ export function useZUIInteraction( setHoveredItem(null, true) // Clear popover immediately on double-click zoom setViewState((prev) => { - const worldX = (focalX - prev.x) / prev.zoom - const worldY = (focalY - prev.y) / prev.zoom + const worldX = screenToWorldX(focalX, prev) + const worldY = screenToWorldY(focalY, prev) const thresholds = getExpandThresholds(rect.width) const deepest = findDeepestAt(worldX, worldY, groupsRef.current, prev, thresholds) @@ -802,8 +839,8 @@ export function useZUIInteraction( if (isFinite(factor) && factor > 0) { setViewState((prev) => { const rect = el!.getBoundingClientRect() - const worldX = (mid.x - prev.x) / prev.zoom - const worldY = (mid.y - prev.y) / prev.zoom + const worldX = screenToWorldX(mid.x, prev) + const worldY = screenToWorldY(mid.y, prev) const thresholds = getExpandThresholds(rect.width) const deepest = findDeepestAt(worldX, worldY, groupsRef.current, prev, thresholds) @@ -843,7 +880,7 @@ export function useZUIInteraction( el.style.cursor = 'grab' - el.addEventListener('wheel', onWheel, { passive: false }) + window.addEventListener('wheel', onWheel, { passive: false, capture: true }) el.addEventListener('mousedown', onMouseDown) el.addEventListener('mouseleave', onMouseOut) el.addEventListener('mouseout', onMouseOut) @@ -856,7 +893,7 @@ export function useZUIInteraction( el.addEventListener('touchcancel', onTouchEnd) return () => { - el.removeEventListener('wheel', onWheel) + window.removeEventListener('wheel', onWheel, { capture: true }) el.removeEventListener('mousedown', onMouseDown) el.removeEventListener('mouseleave', onMouseOut) el.removeEventListener('mouseout', onMouseOut) diff --git a/frontend/src/context/WorkspaceVersionContext.tsx b/frontend/src/context/WorkspaceVersionContext.tsx new file mode 100644 index 0000000..c45bed9 --- /dev/null +++ b/frontend/src/context/WorkspaceVersionContext.tsx @@ -0,0 +1,126 @@ +import { createContext, useCallback, useContext, useEffect, useMemo, useRef, useState } from 'react' +import type { WatchDiff, WatchRepository, WatchVersion, WorkspaceVersion } from '../api/client' +import { normalizeWatchChangeType } from '../utils/watchDiffSummary' + +export type VersionChangeType = 'added' | 'updated' | 'deleted' | 'initialized' + +export interface VersionLineDelta { + added: number + removed: number +} + +export interface WorkspaceVersionPreview { + 
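Wheel handling now lives on window in the capture phase, so overlays rendered above the canvas still zoom it; shouldIgnoreCapturedWheel plus the bounding-rect check carve out the exceptions. Anything that needs native scrolling opts out by carrying the data-zui-native-wheel attribute. An illustrative overlay, not part of the patch:

```tsx
import type { ReactNode } from 'react'

// PanelSketch is an illustrative component: because it is marked
// data-zui-native-wheel="true", shouldIgnoreCapturedWheel makes the
// window-level capture handler skip wheel events that originate inside it,
// so this list scrolls natively instead of zooming the canvas.
export function PanelSketch({ children }: { children: ReactNode }) {
  return (
    <div data-zui-native-wheel="true" style={{ overflowY: 'auto', maxHeight: '40vh' }}>
      {children}
    </div>
  )
}
```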
repository: WatchRepository | null + version: WatchVersion | null + workspaceVersions: WorkspaceVersion[] + diffs: WatchDiff[] + elementChanges: Map + elementLineDeltas: Map + connectorChanges: Map + summary: { + added: number + updated: number + deleted: number + initialized: number + elements: number + connectors: number + } +} + +export interface WorkspaceVersionFollowTarget { + token: number + resourceType: string + resourceId?: number + viewId?: number + changeType?: VersionChangeType +} + +interface WorkspaceVersionContextValue { + preview: WorkspaceVersionPreview | null + followToken: number + followTarget: WorkspaceVersionFollowTarget | null + setPreview: (preview: WorkspaceVersionPreview | null) => void + clearPreview: () => void + requestFollow: (target?: Omit) => void +} + +const WorkspaceVersionContext = createContext(null) + +export function buildWorkspaceVersionPreview(args: { + repository: WatchRepository | null + version: WatchVersion | null + workspaceVersions: WorkspaceVersion[] + diffs: WatchDiff[] | null | undefined +}): WorkspaceVersionPreview { + const diffs = Array.isArray(args.diffs) ? args.diffs : [] + const elementChanges = new Map() + const elementLineDeltas = new Map() + const connectorChanges = new Map() + const summary = { added: 0, updated: 0, deleted: 0, initialized: 0, elements: 0, connectors: 0 } + + diffs.forEach((diff) => { + const change = normalizeWatchChangeType(diff.change_type) + summary[change] += 1 + if (diff.resource_type === 'element' && diff.resource_id) { + elementChanges.set(diff.resource_id, change) + const added = Math.max(0, diff.added_lines ?? 0) + const removed = Math.max(0, diff.removed_lines ?? 0) + if (added > 0 || removed > 0) { + elementLineDeltas.set(diff.resource_id, { added, removed }) + } + summary.elements += 1 + } + if (diff.resource_type === 'connector' && diff.resource_id) { + connectorChanges.set(diff.resource_id, change) + summary.connectors += 1 + } + }) + + return { + repository: args.repository, + version: args.version, + workspaceVersions: args.workspaceVersions, + diffs, + elementChanges, + elementLineDeltas, + connectorChanges, + summary, + } +} + +export function WorkspaceVersionProvider({ children }: { children: React.ReactNode }) { + const [preview, setPreview] = useState(null) + const [followToken, setFollowToken] = useState(0) + const [followTarget, setFollowTarget] = useState(null) + const followClearTimerRef = useRef(null) + const clearPreview = useCallback(() => setPreview(null), []) + const requestFollow = useCallback((target?: Omit) => { + setFollowToken((value) => value + 1) + if (followClearTimerRef.current !== null) { + window.clearTimeout(followClearTimerRef.current) + followClearTimerRef.current = null + } + if (!target) { + setFollowTarget(null) + return + } + setFollowTarget({ ...target, token: Date.now() }) + followClearTimerRef.current = window.setTimeout(() => { + setFollowTarget(null) + followClearTimerRef.current = null + }, 1600) + }, []) + useEffect(() => { + return () => { + if (followClearTimerRef.current !== null) window.clearTimeout(followClearTimerRef.current) + } + }, []) + const value = useMemo(() => ({ preview, followToken, followTarget, setPreview, clearPreview, requestFollow }), [preview, followToken, followTarget, clearPreview, requestFollow]) + return {children} +} + +export function useWorkspaceVersionPreview() { + const value = useContext(WorkspaceVersionContext) + if (!value) throw new Error('useWorkspaceVersionPreview must be used within WorkspaceVersionProvider') + return 
value +} diff --git a/frontend/src/crossBranch/resolve.test.ts b/frontend/src/crossBranch/resolve.test.ts new file mode 100644 index 0000000..7673ff8 --- /dev/null +++ b/frontend/src/crossBranch/resolve.test.ts @@ -0,0 +1,342 @@ +import { describe, expect, it } from 'vitest' +import { buildWorkspaceGraphSnapshot } from './graph' +import { resolveZUIProxyConnectors } from './resolve' +import type { ResolveZUIProxyConnectorOptions, ZUIConnectorAnchorInfo } from './resolve' +import type { Connector, ExploreData, PlacedElement, ViewTreeNode } from '../types' +import type { CrossBranchContextSettings } from './types' + +function placedElement(view_id: number, element_id: number, name: string): PlacedElement { + return { + id: view_id * 100 + element_id, + view_id, + element_id, + position_x: element_id * 10, + position_y: 0, + name, + description: null, + kind: 'service', + technology: null, + url: null, + logo_url: null, + technology_connectors: [], + tags: [], + has_view: element_id === 1 || element_id === 3, + view_label: null, + } +} + +function connector(id: number, view_id: number, source_element_id: number, target_element_id: number, label: string): Connector { + return { + id, + view_id, + source_element_id, + target_element_id, + label, + description: null, + relationship: null, + direction: 'forward', + style: 'bezier', + url: null, + source_handle: null, + target_handle: null, + created_at: '2024-01-01', + updated_at: '2024-01-01', + } +} + +function zuiSettings(overrides: Partial = {}): CrossBranchContextSettings { + return { + enabled: true, + depth: 5, + connectorBudget: 50, + connectorPriority: 'external', + ...overrides, + } +} + +function anchor(nodeId: string, x: number, y = 0): ZUIConnectorAnchorInfo { + return { + nodeId, + worldX: x, + worldY: y, + worldW: 10, + worldH: 10, + } +} + +function viewportOptions(anchorsByElementId: Map): ResolveZUIProxyConnectorOptions { + return { + anchorsByElementId, + viewport: { + minX: 0, + minY: 0, + maxX: 100, + maxY: 100, + centerX: 50, + centerY: 50, + }, + } +} + +const tree: ViewTreeNode[] = [{ + id: 1, + owner_element_id: null, + name: 'Root', + description: null, + level_label: null, + level: 0, + depth: 0, + created_at: '2024-01-01', + updated_at: '2024-01-01', + parent_view_id: null, + children: [{ + id: 2, + owner_element_id: 1, + name: 'A Child', + description: null, + level_label: null, + level: 1, + depth: 1, + created_at: '2024-01-01', + updated_at: '2024-01-01', + parent_view_id: 1, + children: [{ + id: 3, + owner_element_id: 3, + name: 'AA Child', + description: null, + level_label: null, + level: 2, + depth: 2, + created_at: '2024-01-01', + updated_at: '2024-01-01', + parent_view_id: 2, + children: [], + }], + }], +}] + +function baseData(connectors: Connector[]): ExploreData { + return { + tree, + navigations: [ + { id: 1, element_id: 1, from_view_id: 1, to_view_id: 2, to_view_name: 'A Child', relation_type: 'child' }, + { id: 2, element_id: 3, from_view_id: 2, to_view_id: 3, to_view_name: 'AA Child', relation_type: 'child' }, + ], + views: { + '1': { + placements: [placedElement(1, 1, 'A'), placedElement(1, 2, 'B')], + connectors, + }, + '2': { + placements: [placedElement(2, 3, 'AA')], + connectors: [], + }, + '3': { + placements: [placedElement(3, 4, 'AAA')], + connectors: [], + }, + }, + } +} + +describe('resolveZUIProxyConnectors', () => { + it('collapses direct-child cross-branch links into a native +N badge', () => { + const snapshot = buildWorkspaceGraphSnapshot(baseData([ + connector(1, 1, 1, 2, 'A-B'), + 
connector(2, 1, 3, 2, 'AA-B'), + ])) + + const resolved = resolveZUIProxyConnectors( + snapshot, + new Map([ + [1, 'd1-o1'], + [2, 'd1-o2'], + ]), + zuiSettings(), + ) + + expect(resolved.connectors).toHaveLength(0) + expect(resolved.hiddenBadges).toHaveLength(1) + expect(resolved.hiddenBadges[0]).toMatchObject({ + sourceAnchorElementId: 1, + targetAnchorElementId: 2, + count: 1, + }) + }) + + it('fractures the badge into direct child and parent connectors when the child is visible', () => { + const snapshot = buildWorkspaceGraphSnapshot(baseData([ + connector(1, 1, 1, 2, 'A-B'), + connector(2, 1, 3, 2, 'AA-B'), + ])) + + const resolved = resolveZUIProxyConnectors( + snapshot, + new Map([ + [1, 'd1-o1'], + [2, 'd1-o2'], + [3, 'd2-o3'], + ]), + zuiSettings(), + ) + + expect(resolved.hiddenBadges).toHaveLength(0) + expect(resolved.connectors.map((item) => [item.sourceAnchorElementId, item.targetAnchorElementId])).toEqual([[2, 3]]) + }) + + it('keeps only the deepest visible connector and its parent once grandchildren are visible', () => { + const snapshot = buildWorkspaceGraphSnapshot(baseData([ + connector(1, 1, 1, 2, 'A-B'), + connector(2, 1, 4, 2, 'AAA-B'), + ])) + + const resolved = resolveZUIProxyConnectors( + snapshot, + new Map([ + [1, 'd1-o1'], + [2, 'd1-o2'], + [3, 'd2-o3'], + [4, 'd3-o4'], + ]), + zuiSettings(), + ) + + expect(resolved.hiddenBadges).toHaveLength(0) + expect(resolved.connectors.map((item) => [item.sourceAnchorElementId, item.targetAnchorElementId]).sort()).toEqual([ + [2, 3], + [2, 4], + ]) + }) + + it('budgets visible connector groups and reports the omitted leaf count', () => { + const data = baseData([ + connector(1, 1, 3, 2, 'one'), + connector(2, 1, 4, 2, 'two'), + connector(3, 1, 5, 2, 'three'), + ]) + data.views['2'].placements = [ + placedElement(2, 3, 'AA'), + placedElement(2, 4, 'AB'), + placedElement(2, 5, 'AC'), + ] + const snapshot = buildWorkspaceGraphSnapshot(data) + + const resolved = resolveZUIProxyConnectors( + snapshot, + new Map([ + [1, 'd1-o1'], + [2, 'd1-o2'], + [3, 'd2-o3'], + [4, 'd2-o4'], + [5, 'd2-o5'], + ]), + zuiSettings({ connectorBudget: 2 }), + ) + + expect(resolved.connectors).toHaveLength(2) + expect(resolved.omittedConnectorCount).toBe(3) + }) + + it('reducing the budget keeps a subset of the larger-budget result', () => { + const connectors = [ + connector(1, 1, 4, 2, 'deep-one'), + connector(2, 1, 4, 2, 'deep-two'), + connector(3, 1, 3, 2, 'shallow'), + ] + const snapshot = buildWorkspaceGraphSnapshot(baseData(connectors)) + const visibleNodes = new Map([ + [1, 'd1-o1'], + [2, 'd1-o2'], + [3, 'd2-o3'], + [4, 'd3-o4'], + ]) + + const budgetTwo = resolveZUIProxyConnectors( + snapshot, + visibleNodes, + zuiSettings({ connectorBudget: 2 }), + ) + const budgetOne = resolveZUIProxyConnectors( + snapshot, + visibleNodes, + zuiSettings({ connectorBudget: 1 }), + ) + + const budgetTwoKeys = new Set(budgetTwo.connectors.map((item) => item.key)) + expect(budgetTwo.connectors).toHaveLength(2) + expect(budgetOne.connectors).toHaveLength(1) + expect(budgetOne.connectors.every((item) => budgetTwoKeys.has(item.key))).toBe(true) + }) + + it('prioritizes one-near one-far connector groups in external mode', () => { + const data = baseData([ + connector(1, 1, 3, 2, 'near-far'), + connector(2, 1, 4, 2, 'near-near'), + ]) + data.views['2'].placements = [ + placedElement(2, 3, 'External Far'), + placedElement(2, 4, 'Internal Near'), + ] + const snapshot = buildWorkspaceGraphSnapshot(data) + const visibleNodes = new Map([ + [2, 'd1-o2'], + [3, 'd2-o3'], 
+ [4, 'd2-o4'], + ]) + const options = viewportOptions(new Map([ + [2, anchor('d1-o2', 45, 45)], + [3, anchor('d2-o3', 400, 45)], + [4, anchor('d2-o4', 60, 45)], + ])) + + const resolved = resolveZUIProxyConnectors( + snapshot, + visibleNodes, + zuiSettings({ connectorBudget: 1, connectorPriority: 'external' }), + options, + ) + + expect(resolved.connectors).toHaveLength(1) + expect(resolved.connectors[0].details.connectors[0].connector.label).toBe('near-far') + }) + + it('prioritizes both-near connector groups in internal mode', () => { + const data = baseData([ + connector(1, 1, 3, 2, 'near-far'), + connector(2, 1, 4, 2, 'near-near'), + ]) + data.views['2'].placements = [ + placedElement(2, 3, 'External Far'), + placedElement(2, 4, 'Internal Near'), + ] + const snapshot = buildWorkspaceGraphSnapshot(data) + const visibleNodes = new Map([ + [2, 'd1-o2'], + [3, 'd2-o3'], + [4, 'd2-o4'], + ]) + const options = viewportOptions(new Map([ + [2, anchor('d1-o2', 45, 45)], + [3, anchor('d2-o3', 400, 45)], + [4, anchor('d2-o4', 60, 45)], + ])) + + const resolved = resolveZUIProxyConnectors( + snapshot, + visibleNodes, + zuiSettings({ connectorBudget: 1, connectorPriority: 'internal' }), + options, + ) + + expect(resolved.connectors).toHaveLength(1) + expect(resolved.connectors[0].details.connectors[0].connector.label).toBe('near-near') + }) + + it('uses a default budget of 50 and external priority in test settings', () => { + expect(zuiSettings()).toMatchObject({ + connectorBudget: 50, + connectorPriority: 'external', + }) + }) +}) diff --git a/frontend/src/crossBranch/resolve.ts b/frontend/src/crossBranch/resolve.ts index 2059425..183af0c 100644 --- a/frontend/src/crossBranch/resolve.ts +++ b/frontend/src/crossBranch/resolve.ts @@ -1,7 +1,8 @@ import type { Connector, PlacedElement } from '../types' -import { CROSS_BRANCH_DEPTH_ALL } from './types' +import { CROSS_BRANCH_CONNECTOR_BUDGET_DEFAULT, CROSS_BRANCH_DEPTH_ALL } from './types' import type { AggregatedProxyConnector, + CrossBranchConnectorPriority, CrossBranchContextSettings, GraphPlacementRef, ProxyConnectorDetails, @@ -12,6 +13,31 @@ import type { } from './types' import { allConnectors, findLowestCommonAncestorViewId, isDescendantView, relativeOwnerElementPath, viewName } from './graph' +const connectorsBySnapshotCache = new WeakMap() +const endpointPathCacheBySnapshot = new WeakMap>() + +function connectorsForSnapshot(snapshot: WorkspaceGraphSnapshot): Connector[] { + const cached = connectorsBySnapshotCache.get(snapshot) + if (cached) return cached + + const connectors = allConnectors(snapshot) + connectorsBySnapshotCache.set(snapshot, connectors) + return connectors +} + +function endpointPathCacheForSnapshot(snapshot: WorkspaceGraphSnapshot): Map { + let cache = endpointPathCacheBySnapshot.get(snapshot) + if (!cache) { + cache = new Map() + endpointPathCacheBySnapshot.set(snapshot, cache) + } + return cache +} + +function endpointPathCacheKey(ownerViewId: number, elementId: number): string { + return `${ownerViewId}:${elementId}` +} + function firstPlacementForElement(snapshot: WorkspaceGraphSnapshot, elementId: number): GraphPlacementRef | null { return snapshot.placementsByElementId[elementId]?.[0] ?? 
null } @@ -514,95 +540,306 @@ export interface ZUIResolvedConnector { direction: string style: string label: string + sourceDepth: number + targetDepth: number + maxDepth: number + details: ProxyConnectorDetails +} + +export interface ZUIHiddenProxyBadge { + key: string + sourceAnchorElementId: number + targetAnchorElementId: number + sourceNodeId: string + targetNodeId: string + count: number details: ProxyConnectorDetails } +export interface ZUIProxyResolution { + connectors: ZUIResolvedConnector[] + hiddenBadges: ZUIHiddenProxyBadge[] + omittedConnectorCount: number +} + +export interface ZUIViewportBounds { + minX: number + minY: number + maxX: number + maxY: number + centerX: number + centerY: number +} + +export interface ZUIConnectorAnchorInfo { + nodeId: string + worldX: number + worldY: number + worldW: number + worldH: number +} + +export interface ResolveZUIProxyConnectorOptions { + viewport?: ZUIViewportBounds | null + anchorsByElementId?: Map + connectorPriority?: CrossBranchConnectorPriority +} + function endpointPathForOwnerView(snapshot: WorkspaceGraphSnapshot, ownerViewId: number, elementId: number): number[] { + const cache = endpointPathCacheForSnapshot(snapshot) + const key = endpointPathCacheKey(ownerViewId, elementId) + const cached = cache.get(key) + if (cached) return cached + const placement = chooseBestPlacement(snapshot, elementId, ownerViewId, ownerViewId) - if (!placement) return [elementId] + if (!placement) { + const path = [elementId] + cache.set(key, path) + return path + } const owners = relativeOwnerElementPath(snapshot, snapshot.ancestorsByViewId[placement.viewId]?.[0] ?? placement.viewId, placement.viewId) const path = [...owners] if (path[path.length - 1] !== elementId) path.push(elementId) - return path.length > 0 ? path : [elementId] + const resolvedPath = path.length > 0 ? path : [elementId] + cache.set(key, resolvedPath) + return resolvedPath +} + +interface ZUIEndpointCandidate { + actualElementId: number + actualElementName: string + anchorElementId: number + anchorElementName: string + anchorViewId: number | null + anchorViewName: string | null + placementViewId: number | null + placementViewName: string | null + depth: number +} + +function visibleEndpointCandidates( + snapshot: WorkspaceGraphSnapshot, + ownerViewId: number, + actualElementId: number, + visibleElements: Set, +): ZUIEndpointCandidate[] { + const path = endpointPathForOwnerView(snapshot, ownerViewId, actualElementId) + const visibleIndexes = path + .map((elementId, index) => visibleElements.has(elementId) ? index : -1) + .filter((index) => index >= 0) + + if (visibleIndexes.length === 0) return [] + + const actualElementName = firstPlacementForElement(snapshot, actualElementId)?.element.name ?? `Element ${actualElementId}` + const deepestVisibleIndex = visibleIndexes[visibleIndexes.length - 1] + const candidateIndexes = [deepestVisibleIndex] + if (visibleIndexes.length >= 2) candidateIndexes.push(visibleIndexes[visibleIndexes.length - 2]) + + return candidateIndexes.map((pathIndex) => { + const anchorElementId = path[pathIndex] + const anchorPlacement = firstPlacementForElement(snapshot, anchorElementId) + return { + actualElementId, + actualElementName, + anchorElementId, + anchorElementName: anchorPlacement?.element.name ?? `Element ${anchorElementId}`, + anchorViewId: anchorPlacement?.viewId ?? ownerViewId, + anchorViewName: anchorPlacement?.viewName ?? 
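resolve.ts now memoizes the flattened connector list and endpoint paths on the snapshot object itself, using WeakMaps so cache entries disappear together with the snapshot instead of needing explicit invalidation. A generic sketch of that pattern, with illustrative names:

```ts
// Memoize a derived value keyed by snapshot identity. Entries are released as
// soon as the snapshot object is garbage-collected.
function memoizeBySnapshot<S extends object, V>(compute: (snapshot: S) => V): (snapshot: S) => V {
  const cache = new WeakMap<S, V>()
  return (snapshot: S) => {
    const hit = cache.get(snapshot)
    if (hit !== undefined) return hit // note: a computed value of undefined would be recomputed; fine for a sketch
    const value = compute(snapshot)
    cache.set(snapshot, value)
    return value
  }
}

// Example: an expensive flattening is computed once per snapshot identity.
interface SnapshotSketch { connectorsByViewId: Record<string, number[]> }
const allConnectorIds = memoizeBySnapshot((s: SnapshotSketch) =>
  Object.values(s.connectorsByViewId).flat(),
)

const snap: SnapshotSketch = { connectorsByViewId: { '1': [1, 2], '2': [3] } }
console.log(allConnectorIds(snap))                            // [1, 2, 3]
console.log(allConnectorIds(snap) === allConnectorIds(snap))  // true (same cached array)
```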
viewName(snapshot, ownerViewId), + placementViewId: ownerViewId, + placementViewName: viewName(snapshot, ownerViewId), + depth: Math.max(0, path.length - 1 - pathIndex), + } + }) +} + +function isNativelyRenderedInZUI( + connector: Connector, + sourceAnchorElementId: number, + targetAnchorElementId: number, + visibleNodeIdsByElementId: Map, +): boolean { + return visibleNodeIdsByElementId.get(sourceAnchorElementId) === `d${connector.view_id}-o${sourceAnchorElementId}` && + visibleNodeIdsByElementId.get(targetAnchorElementId) === `d${connector.view_id}-o${targetAnchorElementId}` +} + +function visibleEndpointCandidateCacheKey(ownerViewId: number, actualElementId: number): string { + return `${ownerViewId}:${actualElementId}` +} + +function anchorCenter(anchor: ZUIConnectorAnchorInfo) { + return { + x: anchor.worldX + anchor.worldW / 2, + y: anchor.worldY + anchor.worldH / 2, + } +} + +function anchorIsInViewport(anchor: ZUIConnectorAnchorInfo, viewport: ZUIViewportBounds): boolean { + const center = anchorCenter(anchor) + return center.x >= viewport.minX && + center.x <= viewport.maxX && + center.y >= viewport.minY && + center.y <= viewport.maxY +} + +function normalizedDistanceToViewportCenter(anchor: ZUIConnectorAnchorInfo, viewport: ZUIViewportBounds): number { + const center = anchorCenter(anchor) + const dx = center.x - viewport.centerX + const dy = center.y - viewport.centerY + const diagonal = Math.max(1, Math.hypot(viewport.maxX - viewport.minX, viewport.maxY - viewport.minY)) + return Math.hypot(dx, dy) / diagonal +} + +function viewportPriorityScore( + connector: ZUIResolvedConnector, + options: ResolveZUIProxyConnectorOptions | undefined, +): number { + const viewport = options?.viewport + const anchors = options?.anchorsByElementId + const source = anchors?.get(connector.sourceAnchorElementId) + const target = anchors?.get(connector.targetAnchorElementId) + if (!viewport || !source || !target) { + return connector.maxDepth * 100 + connector.sourceDepth + connector.targetDepth + } + + const sourceDistance = normalizedDistanceToViewportCenter(source, viewport) + const targetDistance = normalizedDistanceToViewportCenter(target, viewport) + const nearDistance = Math.min(sourceDistance, targetDistance) + const farDistance = Math.max(sourceDistance, targetDistance) + const sourceInViewport = anchorIsInViewport(source, viewport) + const targetInViewport = anchorIsInViewport(target, viewport) + const inViewportCount = Number(sourceInViewport) + Number(targetInViewport) + + if (connector.details.connectors.length === 0) return Number.MAX_SAFE_INTEGER + + if (options?.connectorPriority === 'internal') { + return (sourceDistance + targetDistance) * 1000 + farDistance * 400 - inViewportCount * 250 + } + + return nearDistance * 1000 - farDistance * 320 - (inViewportCount > 0 ? 300 : 0) + (inViewportCount === 2 ? 
160 : 0) +} + +function connectorTouchesViewport( + connector: ZUIResolvedConnector, + options: ResolveZUIProxyConnectorOptions | undefined, +): boolean { + const viewport = options?.viewport + const anchors = options?.anchorsByElementId + if (!viewport || !anchors) return true + const source = anchors.get(connector.sourceAnchorElementId) + const target = anchors.get(connector.targetAnchorElementId) + if (!source || !target) return false + return anchorIsInViewport(source, viewport) || anchorIsInViewport(target, viewport) } export function resolveZUIProxyConnectors( snapshot: WorkspaceGraphSnapshot | null, visibleNodeIdsByElementId: Map, settings: CrossBranchContextSettings, -): ZUIResolvedConnector[] { - if (!snapshot || !settings.enabled || visibleNodeIdsByElementId.size === 0) return [] + options?: ResolveZUIProxyConnectorOptions, +): ZUIProxyResolution { + if (!snapshot || !settings.enabled || visibleNodeIdsByElementId.size === 0) { + return { connectors: [], hiddenBadges: [], omittedConnectorCount: 0 } + } const visibleElements = new Set(visibleNodeIdsByElementId.keys()) + const connectors = connectorsForSnapshot(snapshot) + const endpointCandidateCache = new Map() + const endpointCandidates = (ownerViewId: number, actualElementId: number): ZUIEndpointCandidate[] => { + const key = visibleEndpointCandidateCacheKey(ownerViewId, actualElementId) + const cached = endpointCandidateCache.get(key) + if (cached) return cached + + const candidates = visibleEndpointCandidates(snapshot, ownerViewId, actualElementId, visibleElements) + endpointCandidateCache.set(key, candidates) + return candidates + } const grouped = new Map() + const nativeVisiblePairs = new Set() - for (const connector of allConnectors(snapshot)) { - const sourcePath = endpointPathForOwnerView(snapshot, connector.view_id, connector.source_element_id) - const targetPath = endpointPathForOwnerView(snapshot, connector.view_id, connector.target_element_id) - - const sourceAnchorElementId = [...sourcePath].reverse().find((elementId) => visibleElements.has(elementId)) - const targetAnchorElementId = [...targetPath].reverse().find((elementId) => visibleElements.has(elementId)) - if (sourceAnchorElementId == null || targetAnchorElementId == null) continue - if (sourceAnchorElementId === targetAnchorElementId) continue - // If both real endpoints are already visible, the normal edge renderer will - // draw this connector in-place. Only keep connectors that need anchoring to - // an ancestor/summary node. - if ( - sourceAnchorElementId === connector.source_element_id && - targetAnchorElementId === connector.target_element_id - ) continue - - const sourceDepth = Math.max(0, sourcePath.length - 1 - sourcePath.indexOf(sourceAnchorElementId)) - const targetDepth = Math.max(0, targetPath.length - 1 - targetPath.indexOf(targetAnchorElementId)) - if (settings.depth < CROSS_BRANCH_DEPTH_ALL && Math.max(sourceDepth, targetDepth) > settings.depth) continue - - const sourceEndpoint: ProxyEndpoint = { - actualElementId: connector.source_element_id, - actualElementName: firstPlacementForElement(snapshot, connector.source_element_id)?.element.name ?? `Element ${connector.source_element_id}`, - anchorElementId: sourceAnchorElementId, - anchorElementName: firstPlacementForElement(snapshot, sourceAnchorElementId)?.element.name ?? `Element ${sourceAnchorElementId}`, - anchorViewId: firstPlacementForElement(snapshot, sourceAnchorElementId)?.viewId ?? 
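viewportPriorityScore turns the connector ordering into arithmetic: distances are normalized by the viewport diagonal, and a lower score is kept first. The sketch below restates the weights from the patch on bare points ('external' favours links that leave the viewport, 'internal' favours links that stay inside it); the real code scores anchor centers and falls back to depth-based scoring when no viewport is supplied.

```ts
interface Viewport { minX: number; minY: number; maxX: number; maxY: number; centerX: number; centerY: number }
interface Point { x: number; y: number }

const inViewport = (p: Point, v: Viewport) => p.x >= v.minX && p.x <= v.maxX && p.y >= v.minY && p.y <= v.maxY
const normDist = (p: Point, v: Viewport) =>
  Math.hypot(p.x - v.centerX, p.y - v.centerY) / Math.max(1, Math.hypot(v.maxX - v.minX, v.maxY - v.minY))

function score(source: Point, target: Point, v: Viewport, priority: 'external' | 'internal'): number {
  const ds = normDist(source, v)
  const dt = normDist(target, v)
  const near = Math.min(ds, dt)
  const far = Math.max(ds, dt)
  const count = Number(inViewport(source, v)) + Number(inViewport(target, v))
  if (priority === 'internal') return (ds + dt) * 1000 + far * 400 - count * 250
  return near * 1000 - far * 320 - (count > 0 ? 300 : 0) + (count === 2 ? 160 : 0)
}

const viewport: Viewport = { minX: 0, minY: 0, maxX: 100, maxY: 100, centerX: 50, centerY: 50 }
const nearFar = score({ x: 60, y: 45 }, { x: 400, y: 45 }, viewport, 'external')  // roughly -1013
const nearNear = score({ x: 60, y: 45 }, { x: 45, y: 45 }, viewport, 'external')  // roughly -115
console.log(nearFar < nearNear) // true: the link leaving the viewport is kept first in 'external' mode
```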
connector.view_id, - anchorViewName: firstPlacementForElement(snapshot, sourceAnchorElementId)?.viewName ?? viewName(snapshot, connector.view_id), - placementViewId: connector.view_id, - placementViewName: viewName(snapshot, connector.view_id), - depth: sourceDepth, - externalToView: sourceAnchorElementId !== connector.source_element_id, - currentBranchElementId: null, - commonAncestorViewId: null, - commonAncestorViewName: null, - } - const targetEndpoint: ProxyEndpoint = { - actualElementId: connector.target_element_id, - actualElementName: firstPlacementForElement(snapshot, connector.target_element_id)?.element.name ?? `Element ${connector.target_element_id}`, - anchorElementId: targetAnchorElementId, - anchorElementName: firstPlacementForElement(snapshot, targetAnchorElementId)?.element.name ?? `Element ${targetAnchorElementId}`, - anchorViewId: firstPlacementForElement(snapshot, targetAnchorElementId)?.viewId ?? connector.view_id, - anchorViewName: firstPlacementForElement(snapshot, targetAnchorElementId)?.viewName ?? viewName(snapshot, connector.view_id), - placementViewId: connector.view_id, - placementViewName: viewName(snapshot, connector.view_id), - depth: targetDepth, - externalToView: targetAnchorElementId !== connector.target_element_id, - currentBranchElementId: null, - commonAncestorViewId: null, - commonAncestorViewName: null, - } + for (const connector of connectors) { + if (!visibleElements.has(connector.source_element_id) || !visibleElements.has(connector.target_element_id)) continue + if (!isNativelyRenderedInZUI(connector, connector.source_element_id, connector.target_element_id, visibleNodeIdsByElementId)) continue + const [leftAnchorElementId, rightAnchorElementId] = canonicalPairElements(connector.source_element_id, connector.target_element_id) + nativeVisiblePairs.add([leftAnchorElementId, rightAnchorElementId].join('::')) + } - const leaf: ProxyConnectorLeaf = { - connector, - ownerViewId: connector.view_id, - ownerViewName: viewName(snapshot, connector.view_id) ?? 
`View ${connector.view_id}`, - source: sourceEndpoint, - target: targetEndpoint, + for (const connector of connectors) { + const sourceCandidates = endpointCandidates(connector.view_id, connector.source_element_id) + const targetCandidates = endpointCandidates(connector.view_id, connector.target_element_id) + const seenPairsForConnector = new Set() + + for (const sourceCandidate of sourceCandidates) { + for (const targetCandidate of targetCandidates) { + if (sourceCandidate.anchorElementId === targetCandidate.anchorElementId) continue + if ( + sourceCandidate.actualElementId === sourceCandidate.anchorElementId && + targetCandidate.actualElementId === targetCandidate.anchorElementId && + isNativelyRenderedInZUI( + connector, + sourceCandidate.anchorElementId, + targetCandidate.anchorElementId, + visibleNodeIdsByElementId, + ) + ) { + continue + } + + const sourceEndpoint: ProxyEndpoint = { + actualElementId: sourceCandidate.actualElementId, + actualElementName: sourceCandidate.actualElementName, + anchorElementId: sourceCandidate.anchorElementId, + anchorElementName: sourceCandidate.anchorElementName, + anchorViewId: sourceCandidate.anchorViewId, + anchorViewName: sourceCandidate.anchorViewName, + placementViewId: sourceCandidate.placementViewId, + placementViewName: sourceCandidate.placementViewName, + depth: sourceCandidate.depth, + externalToView: sourceCandidate.anchorElementId !== sourceCandidate.actualElementId, + currentBranchElementId: null, + commonAncestorViewId: null, + commonAncestorViewName: null, + } + const targetEndpoint: ProxyEndpoint = { + actualElementId: targetCandidate.actualElementId, + actualElementName: targetCandidate.actualElementName, + anchorElementId: targetCandidate.anchorElementId, + anchorElementName: targetCandidate.anchorElementName, + anchorViewId: targetCandidate.anchorViewId, + anchorViewName: targetCandidate.anchorViewName, + placementViewId: targetCandidate.placementViewId, + placementViewName: targetCandidate.placementViewName, + depth: targetCandidate.depth, + externalToView: targetCandidate.anchorElementId !== targetCandidate.actualElementId, + currentBranchElementId: null, + commonAncestorViewId: null, + commonAncestorViewName: null, + } + + const leaf: ProxyConnectorLeaf = { + connector, + ownerViewId: connector.view_id, + ownerViewName: viewName(snapshot, connector.view_id) ?? 
`View ${connector.view_id}`, + source: sourceEndpoint, + target: targetEndpoint, + } + + const [leftAnchorElementId, rightAnchorElementId] = canonicalPairElements( + sourceCandidate.anchorElementId, + targetCandidate.anchorElementId, + ) + const key = [leftAnchorElementId, rightAnchorElementId].join('::') + const pairKey = `${connector.id}:${key}` + if (seenPairsForConnector.has(pairKey)) continue + seenPairsForConnector.add(pairKey) + const existing = grouped.get(key) + if (existing) existing.push(leaf) + else grouped.set(key, [leaf]) + } } - - const [leftAnchorElementId, rightAnchorElementId] = canonicalPairElements(sourceAnchorElementId, targetAnchorElementId) - const key = [leftAnchorElementId, rightAnchorElementId].join('::') - const existing = grouped.get(key) - if (existing) existing.push(leaf) - else grouped.set(key, [leaf]) } const resolved: ZUIResolvedConnector[] = [] + const hiddenBadges: ZUIHiddenProxyBadge[] = [] for (const [key, leaves] of grouped) { const [first] = leaves const { ownerViewIds, ownerViewNames } = ownerViewsFromLeaves(leaves) @@ -611,6 +848,8 @@ export function resolveZUIProxyConnectors( first.target.anchorElementId, ) const canonicalFirstIsSource = canonicalSourceAnchorElementId === first.source.anchorElementId + const canonicalSourceDepth = canonicalFirstIsSource ? first.source.depth : first.target.depth + const canonicalTargetDepth = canonicalFirstIsSource ? first.target.depth : first.source.depth const details: ProxyConnectorDetails = { key, label: proxyDisplayLabel(leaves), @@ -624,6 +863,30 @@ export function resolveZUIProxyConnectors( connectors: leaves, } + const isDirectChildBadgeOnly = leaves.every((leaf) => { + if (Math.max(leaf.source.depth, leaf.target.depth) !== 1) return false + const sourceOk = leaf.source.actualElementId === leaf.source.anchorElementId || + endpointCandidates(leaf.ownerViewId, leaf.source.actualElementId)[0]?.anchorElementId === leaf.source.anchorElementId + const targetOk = leaf.target.actualElementId === leaf.target.anchorElementId || + endpointCandidates(leaf.ownerViewId, leaf.target.actualElementId)[0]?.anchorElementId === leaf.target.anchorElementId + return sourceOk && targetOk + }) + const pairHasNativeDirect = nativeVisiblePairs.has(key) + if (pairHasNativeDirect) { + if (isDirectChildBadgeOnly) { + hiddenBadges.push({ + key: `badge:${key}`, + sourceAnchorElementId: canonicalSourceAnchorElementId, + targetAnchorElementId: canonicalTargetAnchorElementId, + sourceNodeId: visibleNodeIdsByElementId.get(canonicalSourceAnchorElementId) ?? '', + targetNodeId: visibleNodeIdsByElementId.get(canonicalTargetAnchorElementId) ?? '', + count: details.count, + details, + }) + } + continue + } + resolved.push({ key, sourceElementId: canonicalFirstIsSource ? 
first.source.actualElementId : first.target.actualElementId, @@ -635,9 +898,46 @@ export function resolveZUIProxyConnectors( direction: 'merged', style: first.connector.style || 'bezier', label: details.label, + sourceDepth: canonicalSourceDepth, + targetDepth: canonicalTargetDepth, + maxDepth: Math.max(canonicalSourceDepth, canonicalTargetDepth), details, }) } - return resolved.filter((connector) => connector.sourceNodeId && connector.targetNodeId) + const visibleResolved = resolved + .filter((connector) => connector.sourceNodeId && connector.targetNodeId) + .filter((connector) => connectorTouchesViewport(connector, options)) + .sort((left, right) => { + const scoreDelta = viewportPriorityScore(left, { + ...options, + connectorPriority: settings.connectorPriority, + }) - viewportPriorityScore(right, { + ...options, + connectorPriority: settings.connectorPriority, + }) + if (scoreDelta !== 0) return scoreDelta + if (right.details.count !== left.details.count) return right.details.count - left.details.count + if (left.maxDepth !== right.maxDepth) return left.maxDepth - right.maxDepth + const depthDelta = (left.sourceDepth + left.targetDepth) - (right.sourceDepth + right.targetDepth) + if (depthDelta !== 0) return depthDelta + return left.key.localeCompare(right.key) + }) + const maxGroups = settings.connectorBudget ?? settings.maxProxyConnectorGroups ?? CROSS_BRANCH_CONNECTOR_BUDGET_DEFAULT + const budgetedResolved = maxGroups > 0 ? visibleResolved.slice(0, maxGroups) : visibleResolved + const omittedConnectorIds = new Set() + if (maxGroups > 0) { + for (const connector of visibleResolved.slice(maxGroups)) { + for (const leaf of connector.details.connectors) { + omittedConnectorIds.add(leaf.connector.id) + } + } + } + const omittedConnectorCount = omittedConnectorIds.size + + return { + connectors: budgetedResolved, + hiddenBadges: hiddenBadges.filter((badge) => badge.sourceNodeId && badge.targetNodeId), + omittedConnectorCount, + } } diff --git a/frontend/src/crossBranch/settings.ts b/frontend/src/crossBranch/settings.ts index 983e67c..80aa619 100644 --- a/frontend/src/crossBranch/settings.ts +++ b/frontend/src/crossBranch/settings.ts @@ -1,8 +1,16 @@ import { useCallback, useEffect, useMemo, useState } from 'react' -import type { CrossBranchContextSettings, CrossBranchSurface } from './types' -import { CROSS_BRANCH_DEPTH_ALL } from './types' +import type { CrossBranchConnectorPriority, CrossBranchContextSettings, CrossBranchSurface } from './types' +import { + CROSS_BRANCH_CONNECTOR_BUDGET_DEFAULT, + CROSS_BRANCH_CONNECTOR_BUDGET_MAX, + CROSS_BRANCH_CONNECTOR_BUDGET_MIN, + CROSS_BRANCH_DEPTH_ALL, +} from './types' const STORAGE_PREFIX = 'diag:cross-branch' +export const DEFAULT_MIN_CONNECTOR_ANCHOR_ALPHA = 0.35 +export const DEFAULT_MAX_PROXY_CONNECTOR_GROUPS = 32 +export const DEFAULT_CONNECTOR_PRIORITY: CrossBranchConnectorPriority = 'external' function storageKey(surface: CrossBranchSurface) { return `${STORAGE_PREFIX}:${surface}` @@ -12,9 +20,25 @@ function defaultSettings(surface: CrossBranchSurface): CrossBranchContextSetting return { enabled: surface !== 'zui-shared', depth: CROSS_BRANCH_DEPTH_ALL, + connectorBudget: CROSS_BRANCH_CONNECTOR_BUDGET_DEFAULT, + connectorPriority: DEFAULT_CONNECTOR_PRIORITY, + minConnectorAnchorAlpha: DEFAULT_MIN_CONNECTOR_ANCHOR_ALPHA, + maxProxyConnectorGroups: DEFAULT_MAX_PROXY_CONNECTOR_GROUPS, } } +function normalizeConnectorBudget(value: unknown, fallback: number): number { + if (typeof value !== 'number' || !Number.isFinite(value)) return 
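After grouping, the groups are sorted by score, cut to the connector budget, and the dropped groups report how many distinct underlying connectors were hidden. A self-contained sketch of that bookkeeping; GroupSketch is an illustrative shape, not the patch's ZUIResolvedConnector.

```ts
interface GroupSketch { key: string; score: number; leafConnectorIds: number[] }

function applyBudget(groups: GroupSketch[], budget: number) {
  const sorted = [...groups].sort((a, b) => a.score - b.score || a.key.localeCompare(b.key))
  const kept = budget > 0 ? sorted.slice(0, budget) : sorted
  const omitted = new Set<number>() // de-duplicate leaf connectors shared by several dropped groups
  if (budget > 0) {
    for (const group of sorted.slice(budget)) {
      group.leafConnectorIds.forEach((id) => omitted.add(id))
    }
  }
  return { kept, omittedConnectorCount: omitted.size }
}

const { kept, omittedConnectorCount } = applyBudget(
  [
    { key: 'a::b', score: -10, leafConnectorIds: [1, 2] },
    { key: 'a::c', score: 5, leafConnectorIds: [2, 3] },
    { key: 'b::c', score: 40, leafConnectorIds: [4] },
  ],
  1,
)
console.log(kept.map((g) => g.key)) // ['a::b']
console.log(omittedConnectorCount)  // 3: connectors 2, 3 and 4 (2 is counted once)
```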
fallback + return Math.max( + CROSS_BRANCH_CONNECTOR_BUDGET_MIN, + Math.min(CROSS_BRANCH_CONNECTOR_BUDGET_MAX, Math.round(value)), + ) +} + +function normalizeConnectorPriority(value: unknown, fallback: CrossBranchConnectorPriority): CrossBranchConnectorPriority { + return value === 'internal' || value === 'external' ? value : fallback +} + function readSettings(surface: CrossBranchSurface): CrossBranchContextSettings { const defaults = defaultSettings(surface) if (typeof window === 'undefined') return defaults @@ -25,6 +49,14 @@ function readSettings(surface: CrossBranchSurface): CrossBranchContextSettings { return { enabled: parsed.enabled ?? defaults.enabled, depth: typeof parsed.depth === 'number' ? parsed.depth : CROSS_BRANCH_DEPTH_ALL, + connectorBudget: normalizeConnectorBudget(parsed.connectorBudget, defaults.connectorBudget), + connectorPriority: normalizeConnectorPriority(parsed.connectorPriority, defaults.connectorPriority), + minConnectorAnchorAlpha: typeof parsed.minConnectorAnchorAlpha === 'number' + ? parsed.minConnectorAnchorAlpha + : defaults.minConnectorAnchorAlpha, + maxProxyConnectorGroups: typeof parsed.maxProxyConnectorGroups === 'number' + ? parsed.maxProxyConnectorGroups + : defaults.maxProxyConnectorGroups, } } catch { return defaults @@ -51,9 +83,23 @@ export function useCrossBranchContextSettings(surface: CrossBranchSurface) { setSettings((prev) => ({ ...prev, depth })) }, []) + const setConnectorBudget = useCallback((connectorBudget: number) => { + setSettings((prev) => ({ + ...prev, + connectorBudget: normalizeConnectorBudget(connectorBudget, prev.connectorBudget), + maxProxyConnectorGroups: normalizeConnectorBudget(connectorBudget, prev.connectorBudget), + })) + }, []) + + const setConnectorPriority = useCallback((connectorPriority: CrossBranchConnectorPriority) => { + setSettings((prev) => ({ ...prev, connectorPriority })) + }, []) + return useMemo(() => ({ settings, setEnabled, setDepth, - }), [settings, setEnabled, setDepth]) + setConnectorBudget, + setConnectorPriority, + }), [settings, setEnabled, setDepth, setConnectorBudget, setConnectorPriority]) } diff --git a/frontend/src/crossBranch/types.ts b/frontend/src/crossBranch/types.ts index dcb906c..483fb7d 100644 --- a/frontend/src/crossBranch/types.ts +++ b/frontend/src/crossBranch/types.ts @@ -3,12 +3,21 @@ import type { Connector, ExploreData, PlacedElement, ViewTreeNode } from '../typ export const CROSS_BRANCH_DEPTH_ALL = 5 export const CROSS_BRANCH_DEPTH_MIN = 1 export const CROSS_BRANCH_DEPTH_MAX = CROSS_BRANCH_DEPTH_ALL +export const CROSS_BRANCH_CONNECTOR_BUDGET_MIN = 10 +export const CROSS_BRANCH_CONNECTOR_BUDGET_MAX = 200 +export const CROSS_BRANCH_CONNECTOR_BUDGET_DEFAULT = 50 + +export type CrossBranchConnectorPriority = 'external' | 'internal' export type CrossBranchSurface = 'editor' | 'zui' | 'zui-shared' export interface CrossBranchContextSettings { enabled: boolean depth: number + connectorBudget: number + connectorPriority: CrossBranchConnectorPriority + minConnectorAnchorAlpha?: number + maxProxyConnectorGroups?: number } export interface GraphPlacementRef { diff --git a/frontend/src/index.css b/frontend/src/index.css index a86eca2..c100652 100644 --- a/frontend/src/index.css +++ b/frontend/src/index.css @@ -42,6 +42,17 @@ --rf-panel-bl-bottom: 82px; } +html, +body, +#root { + height: 100%; + overflow: hidden; +} + +body { + margin: 0; +} + .react-flow { background-color: var(--bg-canvas); } diff --git a/frontend/src/index.ts b/frontend/src/index.ts index 6feed4c..fe1abb3 100644 --- 
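Persisted cross-branch settings are defaulted and clamped on read, so a value outside CROSS_BRANCH_CONNECTOR_BUDGET_MIN..MAX (10..200) or an unknown priority never reaches the resolver. A hedged sketch of that round trip; the storage key shape follows the diff, everything else is simplified and the write below is illustrative.

```ts
interface SettingsSketch { enabled: boolean; depth: number; connectorBudget: number; connectorPriority: 'external' | 'internal' }

const DEFAULTS: SettingsSketch = { enabled: true, depth: 5, connectorBudget: 50, connectorPriority: 'external' }

const clampBudget = (value: unknown): number =>
  typeof value === 'number' && Number.isFinite(value)
    ? Math.max(10, Math.min(200, Math.round(value)))
    : DEFAULTS.connectorBudget

function readSettingsSketch(surface: string): SettingsSketch {
  try {
    const raw = window.localStorage.getItem(`diag:cross-branch:${surface}`)
    if (!raw) return DEFAULTS
    const parsed = JSON.parse(raw) as Partial<SettingsSketch>
    return {
      enabled: parsed.enabled ?? DEFAULTS.enabled,
      depth: typeof parsed.depth === 'number' ? parsed.depth : DEFAULTS.depth,
      connectorBudget: clampBudget(parsed.connectorBudget),
      connectorPriority: parsed.connectorPriority === 'internal' ? 'internal' : 'external',
    }
  } catch {
    return DEFAULTS
  }
}

// Example: a stale or hand-edited value of 9999 is clamped back to 200 on read.
window.localStorage.setItem('diag:cross-branch:zui', JSON.stringify({ connectorBudget: 9999 }))
console.log(readSettingsSketch('zui').connectorBudget) // 200
```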
a/frontend/src/index.ts +++ b/frontend/src/index.ts @@ -94,6 +94,13 @@ export { default as theme } from './theme' // ─── Contexts ──────────────────────────────────────────────────────────────── export { ThemeProvider, useAccentColor, useTheme } from './context/ThemeContext' export { HeaderProvider, useSetHeader, useHeader } from './components/HeaderContext' +export { + WorkspaceVersionProvider, + buildWorkspaceVersionPreview, + useWorkspaceVersionPreview, + type WorkspaceVersionFollowTarget, + type WorkspaceVersionPreview, +} from './context/WorkspaceVersionContext' // ─── Types ─────────────────────────────────────────────────────────────────── export * from './types' diff --git a/frontend/src/pages/AppearanceSettings.tsx b/frontend/src/pages/AppearanceSettings.tsx index 8b99f20..2072f6a 100644 --- a/frontend/src/pages/AppearanceSettings.tsx +++ b/frontend/src/pages/AppearanceSettings.tsx @@ -1,9 +1,12 @@ -import { Box, FormLabel, HStack, Text, Tooltip, VStack, Wrap, WrapItem } from '@chakra-ui/react' +import { Box, FormLabel, HStack, Select, Text, Tooltip, VStack, Wrap, WrapItem } from '@chakra-ui/react' import { ACCENT_OPTIONS, BACKGROUND_OPTIONS, ELEMENT_OPTIONS } from '../constants/colors' import { useTheme } from '../context/ThemeContext' +import { useSourceEditor } from '../utils/sourceEditor' +import type { SourceEditor } from '../api/client' export default function AppearanceSettings({ compact = false }: { compact?: boolean }) { const { accent, setAccent, background, setBackground, elementColor, setElementColor } = useTheme() + const { editor, setEditor } = useSourceEditor() const swatchSize = compact ? '28px' : '32px' const sectionGap = compact ? 5 : 8 @@ -19,6 +22,26 @@ export default function AppearanceSettings({ compact = false }: { compact?: bool + + + Source Editor + + + + Accent diff --git a/frontend/src/pages/Dependencies.tsx b/frontend/src/pages/Dependencies.tsx index 8c9b72c..731621a 100644 --- a/frontend/src/pages/Dependencies.tsx +++ b/frontend/src/pages/Dependencies.tsx @@ -1,4 +1,5 @@ import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useSearchParams } from 'react-router-dom' import { motion } from 'framer-motion' import { Box, @@ -13,6 +14,7 @@ import { MenuItem, MenuList, Spinner, + Badge, Tag, Text, VStack, @@ -27,6 +29,7 @@ import { ElementBody } from '../components/NodeBody' import DependenciesOnboarding from '../components/DependenciesOnboarding' import { useTheme } from '../context/ThemeContext' import { hexToRgba } from '../constants/colors' +import { useWorkspaceVersionPreview, type VersionChangeType } from '../context/WorkspaceVersionContext' // ── Data types ───────────────────────────────────────────────────────────── interface ElementWithNeighbours extends DependencyElement { @@ -39,6 +42,8 @@ interface NeighbourNode { position: 'left' | 'right' | 'top' | 'bottom' } +const PAGE_SIZE = 50 + // ── Helpers ──────────────────────────────────────────────────────────────── function computeNeighbourCounts(elements: DependencyElement[], connectors: DependencyConnector[]): ElementWithNeighbours[] { const counts = new Map>() @@ -175,11 +180,13 @@ function NeighbourCard({ onClick, setRef, compactLevel = 0, + versionChangeType, }: { node: NeighbourNode onClick: () => void setRef?: (el: HTMLDivElement | null) => void compactLevel?: number + versionChangeType?: VersionChangeType }) { const cardPadding = compactLevel >= 3 ? 1 : compactLevel >= 2 ? 1.5 : compactLevel >= 1 ? 
2 : 3 const showTech = compactLevel < 2 @@ -195,6 +202,13 @@ function NeighbourCard({ compactLevel >= 2 ? (nameLen > 20 ? '2xs' : 'xs') : compactLevel >= 1 ? (nameLen > 22 ? 'xs' : 'sm') : (nameLen > 24 ? 'xs' : 'sm') + const versionColor = versionChangeType === 'added' + ? 'green.300' + : versionChangeType === 'deleted' + ? 'red.300' + : versionChangeType + ? 'yellow.300' + : undefined return ( = { export default function Dependencies() { const setHeader = useSetHeader() const { accent, elementColor } = useTheme() + const [searchParams, setSearchParams] = useSearchParams() + const { preview: versionPreview, followTarget: versionFollowTarget } = useWorkspaceVersionPreview() + const versionPulseChangeForElement = useCallback((elementId: number): VersionChangeType | undefined => { + if (versionFollowTarget?.resourceType !== 'element' || versionFollowTarget.resourceId !== elementId) return undefined + return versionFollowTarget.changeType ?? versionPreview?.elementChanges.get(elementId) + }, [versionFollowTarget, versionPreview]) const [elements, setElements] = useState([]) const [allEdges, setAllEdges] = useState([]) const [loading, setLoading] = useState(true) + const [pageLoading, setPageLoading] = useState(false) const [search, setSearch] = useState('') const [typeFilter, setTypeFilter] = useState('') const [selectedId, setSelectedId] = useState(null) const [topRatio, setTopRatio] = useState(0.45) + const [page, setPage] = useState(0) + const [hasNextPage, setHasNextPage] = useState(false) + const [totalCount, setTotalCount] = useState(0) + const [neighbourElements, setNeighbourElements] = useState>({}) // Graph layout measurement const graphRef = useRef(null) @@ -287,57 +315,124 @@ export default function Dependencies() { useEffect(() => { applyPan(0, 0) }, [selectedId, applyPan]) + useEffect(() => { + const requestedId = searchParams.get('element') + if (requestedId) setSelectedId(requestedId) + }, [searchParams]) + + const selectElement = useCallback((id: string | null) => { + setSelectedId(id) + const next = new URLSearchParams(searchParams) + if (id) next.set('element', id) + else next.delete('element') + setSearchParams(next, { replace: true }) + }, [searchParams, setSearchParams]) + // Header useEffect(() => { setHeader({ hideMobileBar: true, - node: ( - - - {elements.length} elements - - - - {allEdges.length} connectors - - - {elements.length}E - / - {allEdges.length}C - - - ), + node: null, }) return () => setHeader(null) - }, [elements.length, allEdges.length, setHeader]) + }, [setHeader]) + + useEffect(() => { + setPage(0) + }, [search]) // Data fetch useEffect(() => { - api.dependencies - .list() - .then((resp) => { - const objs = resp.elements || [] - const edgs = resp.connectors || [] - setElements(objs) - setAllEdges(edgs) - - if (objs.length > 0) { - const withCounts = computeNeighbourCounts(objs, edgs) - const sorted = [...withCounts].sort((a, b) => b.neighbourCount - a.neighbourCount) - setSelectedId(sorted[0].id) - } + let cancelled = false + const timer = window.setTimeout(() => { + setPageLoading(true) + api.dependencies + .list({ limit: PAGE_SIZE, offset: page * PAGE_SIZE, search }) + .then((resp) => { + if (cancelled) return + const objs = resp.elements || [] + const edgs = resp.connectors || [] + const total = resp.totalCount + setElements(objs) + setAllEdges(edgs) + setTotalCount(total ?? page * PAGE_SIZE + objs.length) + setHasNextPage(total === undefined ? 
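The Dependencies page now fetches one page at a time with a short debounce and a cancellation flag, and derives hasNextPage from the server-reported total when one is available. Below is a generic sketch of that effect as a reusable hook; usePagedSearch, fetchPage, and PageResponse are illustrative stand-ins, not exports of the patch.

```ts
import { useEffect, useState } from 'react'

interface PageResponse<T> { items: T[]; totalCount?: number }

export function usePagedSearch<T>(
  fetchPage: (args: { limit: number; offset: number; search: string }) => Promise<PageResponse<T>>,
  search: string,
  page: number,
  pageSize = 50,
) {
  const [items, setItems] = useState<T[]>([])
  const [hasNextPage, setHasNextPage] = useState(false)
  const [loading, setLoading] = useState(false)

  useEffect(() => {
    let cancelled = false
    const timer = window.setTimeout(() => {
      setLoading(true)
      fetchPage({ limit: pageSize, offset: page * pageSize, search })
        .then((resp) => {
          if (cancelled) return
          setItems(resp.items)
          setHasNextPage(
            resp.totalCount === undefined
              ? resp.items.length === pageSize // unknown total: assume more pages if this one was full
              : page * pageSize + resp.items.length < resp.totalCount,
          )
        })
        .catch(() => { /* keep the previous page on error */ })
        .finally(() => { if (!cancelled) setLoading(false) })
    }, 180) // debounce rapid search/page changes
    return () => { cancelled = true; window.clearTimeout(timer) }
  }, [fetchPage, search, page, pageSize])

  return { items, hasNextPage, loading }
}
```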
objs.length === PAGE_SIZE : page * PAGE_SIZE + objs.length < total) + + setSelectedId((current) => { + if (objs.length === 0) return null + if (current && objs.some((obj) => obj.id === current)) return current + const withCounts = computeNeighbourCounts(objs, edgs) + const sorted = [...withCounts].sort((a, b) => b.neighbourCount - a.neighbourCount) + return sorted[0]?.id ?? null + }) + }) + .catch(() => { /* intentionally empty */ }) + .finally(() => { + if (!cancelled) { + setLoading(false) + setPageLoading(false) + } + }) + }, 180) + return () => { + cancelled = true + window.clearTimeout(timer) + } + }, [page, search]) + + const elementUniverse = useMemo(() => { + const byID = new Map() + elements.forEach((element) => byID.set(element.id, element)) + Object.values(neighbourElements).forEach((element) => byID.set(element.id, element)) + return Array.from(byID.values()) + }, [elements, neighbourElements]) + + useEffect(() => { + if (selectedId === null) return + const known = new Set(elementUniverse.map((element) => element.id)) + const missing = new Set() + allEdges.forEach((connector) => { + if (connector.source_element_id === selectedId && !known.has(connector.target_element_id)) { + missing.add(connector.target_element_id) + } + if (connector.target_element_id === selectedId && !known.has(connector.source_element_id)) { + missing.add(connector.source_element_id) + } + }) + if (missing.size === 0) return + let cancelled = false + Promise.all( + Array.from(missing).slice(0, 120).map((id) => + api.elements.get(Number(id)).then((element) => ({ + id: String(element.id), + name: element.name, + type: element.kind, + description: element.description, + technology: element.technology, + url: element.url, + logo_url: element.logo_url, + technology_connectors: element.technology_connectors, + tags: element.tags, + repo: element.repo, + branch: element.branch, + language: element.language, + file_path: element.file_path, + created_at: element.created_at, + updated_at: element.updated_at, + } satisfies DependencyElement)).catch(() => null), + ), + ).then((items) => { + if (cancelled) return + setNeighbourElements((prev) => { + const next = { ...prev } + items.forEach((item) => { + if (item) next[item.id] = item + }) + return next }) - .catch(() => { /* intentionally empty */ }) - .finally(() => setLoading(false)) - }, []) + }) + return () => { cancelled = true } + }, [allEdges, elementUniverse, selectedId]) // Derived data const elementsWithCounts = useMemo( @@ -364,12 +459,12 @@ export default function Dependencies() { const selectedElement = useMemo(() => { if (selectedId === null) return null - return elements.find((o) => o.id === selectedId) || null - }, [elements, selectedId]) + return elementUniverse.find((o) => o.id === selectedId) || null + }, [elementUniverse, selectedId]) const neighbourGraph = useMemo(() => { if (selectedId === null) return [] - return getNeighbourGraph(selectedId, elements, allEdges) - }, [selectedId, elements, allEdges]) + return getNeighbourGraph(selectedId, elementUniverse, allEdges) + }, [selectedId, elementUniverse, allEdges]) // Divider drag const startDrag = useCallback(() => { @@ -497,7 +592,7 @@ export default function Dependencies() { if (loading) { return ( - + ) @@ -528,9 +623,11 @@ export default function Dependencies() { const colSpacing = maxCompactLevel >= 3 ? 2 : maxCompactLevel >= 2 ? 3 : maxCompactLevel >= 1 ? 5 : 8 const nodeSpacing = maxCompactLevel >= 2 ? 1 : maxCompactLevel >= 1 ? 
2 : 3 const selectedCardShadow = `0 0 0 3px ${hexToRgba(accent, 0.38)}, 0 18px 48px ${hexToRgba(accent, 0.12)}, 0 10px 36px rgba(0,0,0,0.55), 0 3px 10px rgba(0,0,0,0.4)` + const rangeStart = elements.length > 0 ? page * PAGE_SIZE + 1 : 0 + const rangeEnd = page * PAGE_SIZE + elements.length return ( - + {/* ── Top: Listing ──────────────────────────────────────────────────── */} @@ -594,9 +691,45 @@ export default function Dependencies() { - - {filteredElements.length} element{filteredElements.length !== 1 ? 's' : ''} + + + + {totalCount} + elements + + + {allEdges.length} + connectors + + + + + + + {rangeStart}-{rangeEnd} of {totalCount} + {pageLoading && } + + + + Page {page + 1} + + + {/* Column headers */} @@ -640,6 +773,15 @@ export default function Dependencies() { const color = TYPE_COLORS[typeKey] ?? 'gray' const accentHex = TYPE_HEX[typeKey] ?? '#718096' const isSelected = selectedId === obj.id + const versionChangeType = versionPulseChangeForElement(Number(obj.id)) + const versionLineDelta = versionPreview?.elementLineDeltas.get(Number(obj.id)) + const versionColor = versionChangeType === 'added' + ? 'green.300' + : versionChangeType === 'deleted' + ? 'red.300' + : versionChangeType + ? 'yellow.300' + : undefined return ( setSelectedId(isSelected ? null : obj.id)} + onClick={() => selectElement(isSelected ? null : obj.id)} position="relative" role="row" + outline={versionColor ? '1px solid' : undefined} + outlineColor={versionColor} + outlineOffset="-1px" > {/* Left type-color accent */} - - {obj.name} - + + + {obj.name} + + {versionChangeType && ( + + {versionChangeType === 'added' ? '+' : versionChangeType === 'deleted' ? '-' : '~'} + + )} + {versionLineDelta && ( + + {versionLineDelta.added > 0 && +{versionLineDelta.added}} + {versionLineDelta.removed > 0 && -{versionLineDelta.removed}} + + )} + {/* Type badge */} @@ -800,7 +959,8 @@ export default function Dependencies() { key={n.element.id} node={n} compactLevel={maxCompactLevel} - onClick={() => setSelectedId(n.element.id)} + versionChangeType={versionPulseChangeForElement(Number(n.element.id))} + onClick={() => selectElement(n.element.id)} /> ))} @@ -823,7 +983,8 @@ export default function Dependencies() { key={n.element.id} node={n} compactLevel={leftCompactLevel} - onClick={() => setSelectedId(n.element.id)} + versionChangeType={versionPulseChangeForElement(Number(n.element.id))} + onClick={() => selectElement(n.element.id)} /> ))} @@ -844,6 +1005,9 @@ export default function Dependencies() { bg={elementColor} borderColor={accent} borderWidth="2px" + outline={selectedId && versionPulseChangeForElement(Number(selectedId)) ? '3px solid' : undefined} + outlineColor={selectedId && versionPulseChangeForElement(Number(selectedId)) === 'added' ? 'green.300' : selectedId && versionPulseChangeForElement(Number(selectedId)) === 'deleted' ? 'red.300' : selectedId && versionPulseChangeForElement(Number(selectedId)) ? 
'yellow.300' : undefined} + outlineOffset="3px" boxShadow={selectedCardShadow} > ( setSelectedId(n.element.id)} - /> + node={n} + compactLevel={rightCompactLevel} + versionChangeType={versionPulseChangeForElement(Number(n.element.id))} + onClick={() => selectElement(n.element.id)} + /> ))} ))} @@ -897,7 +1062,8 @@ export default function Dependencies() { key={n.element.id} node={n} compactLevel={maxCompactLevel} - onClick={() => setSelectedId(n.element.id)} + versionChangeType={versionPulseChangeForElement(Number(n.element.id))} + onClick={() => selectElement(n.element.id)} /> ))} diff --git a/frontend/src/pages/InfiniteZoom.tsx b/frontend/src/pages/InfiniteZoom.tsx index 7e93b0b..ba2e540 100644 --- a/frontend/src/pages/InfiniteZoom.tsx +++ b/frontend/src/pages/InfiniteZoom.tsx @@ -20,13 +20,16 @@ import { } from '@chakra-ui/react' import { api } from '../api/client' import type { ExploreData, ViewLayer } from '../types' -import { FitViewIcon as FitViewSvg, TagsIcon, EyeIcon, EyeOffIcon, FocusIcon as FocusSvg } from '../components/Icons' +import { FitViewIcon as FitViewSvg, TagsIcon, EyeIcon, EyeOffIcon } from '../components/Icons' import ExploreOnboarding from '../components/ExploreOnboarding' import ExplorePageOnboarding from '../components/ExplorePageOnboarding' import MiniZoomOnboarding from '../components/MiniZoomOnboarding' import { ZUICanvas, type ZUICameraFrame, type ZUICanvasHandle } from '../components/ZUI' import { useCrossBranchContextSettings } from '../crossBranch/settings' +import CrossBranchControls from '../components/CrossBranchControls' import { primeWorkspaceGraphSnapshot } from '../crossBranch/store' +import { WATCH_REPRESENTATION_UPDATED_EVENT } from '../components/WorkspacePanel' +import { useWorkspaceVersionPreview } from '../context/WorkspaceVersionContext' // ── Types ────────────────────────────────────────────────────────── interface Props { @@ -36,6 +39,7 @@ interface Props { export interface InfiniteZoomHandle { focusDiagram(viewId: number): boolean + focusElement(viewId: number, elementId: number): boolean setCameraFrame(frame: ZUICameraFrame): boolean } @@ -59,7 +63,13 @@ function InfiniteZoomInner({ sharedToken, shareSlot }: Props, ref?: React.Ref(null) const crossBranchSurface = sharedToken ? 'zui-shared' : 'zui' - const { settings: crossBranchSettings, setEnabled: setCrossBranchEnabled } = useCrossBranchContextSettings(crossBranchSurface) + const { + settings: crossBranchSettings, + setEnabled: setCrossBranchEnabled, + setConnectorBudget: setCrossBranchConnectorBudget, + setConnectorPriority: setCrossBranchConnectorPriority, + } = useCrossBranchContextSettings(crossBranchSurface) + const { preview: versionPreview, followTarget: versionFollowTarget } = useWorkspaceVersionPreview() const cameraProfile = useMemo(() => new URLSearchParams(location.search).get('profile'), [location.search]) const isDetailToOverviewProfile = sharedToken && cameraProfile === 'detail-to-overview' @@ -74,6 +84,9 @@ function InfiniteZoomInner({ sharedToken, shareSlot }: Props, ref?: React.Ref { + const loadExploreData = useCallback(() => { const loader = sharedToken ? 
api.explore.loadShared(sharedToken) : api.explore.load() - loader.then((d) => { + return loader.then((d) => { if (d.password_required) { setLoading(false) } else { @@ -174,6 +187,20 @@ function InfiniteZoomInner({ sharedToken, shareSlot }: Props, ref?: React.Ref setLoading(false)) }, [sharedToken]) + useEffect(() => { + void loadExploreData() + }, [loadExploreData]) + + useEffect(() => { + if (sharedToken) return + const refresh = () => { + setLoading(true) + void loadExploreData() + } + window.addEventListener(WATCH_REPRESENTATION_UPDATED_EVENT, refresh) + return () => window.removeEventListener(WATCH_REPRESENTATION_UPDATED_EVENT, refresh) + }, [loadExploreData, sharedToken]) + // Fetch tag colors and layers once data is loaded (authenticated users only). // Only fetch from root tree nodes child/nested diagrams would duplicate the same layers. useEffect(() => { @@ -275,6 +302,8 @@ function InfiniteZoomInner({ sharedToken, shareSlot }: Props, ref?: React.Ref @@ -315,21 +344,13 @@ function InfiniteZoomInner({ sharedToken, shareSlot }: Props, ref?: React.Ref - - - + {(allTags.length > 0 || layers.length > 0) && ( <> @@ -356,6 +377,7 @@ function InfiniteZoomInner({ sharedToken, shareSlot }: Props, ref?: React.Ref + {/* Sidebar (hidden on small screens) */} { + items.forEach((item) => { + out.push(item) + walk(item.children ?? []) + }) + } + walk(nodes) + return out +} + export function applyNodeChangesWithStructuralSharing(changes: NodeChange[], nodes: RFNode[]) { if (changes.length === 0) return nodes @@ -476,7 +488,14 @@ export function useCanvasInteractions({ // ── Zoom-in / zoom-out stable callbacks ─────────────────────────────────── const stableOnZoomIn = useCallback(async (elementId: number) => { const childLinks = linksMapRef.current[elementId] || [] - if (childLinks.length > 0) { navigateRef.current(`/views/${childLinks[0].to_view_id}`); return } + if (childLinks.length > 0) { + setSelectedElement(null) + setSelectedEdge(null) + closeElementPanel() + closeConnectorPanel() + navigateRef.current(`/views/${childLinks[0].to_view_id}`) + return + } const obj = viewElementsRef.current.find((o) => o.element_id === elementId) if (obj?.has_view) { @@ -491,6 +510,10 @@ export function useCanvasInteractions({ } const existingView = findInTree(treeDataRef.current) if (existingView) { + setSelectedElement(null) + setSelectedEdge(null) + closeElementPanel() + closeConnectorPanel() navigateRef.current(`/views/${existingView.id}`) return } @@ -506,9 +529,13 @@ export function useCanvasInteractions({ [elementId]: [...(prev[elementId] || []), { id: 0, element_id: elementId, from_view_id: cid, to_view_id: newView.id, to_view_name: newView.name, relation_type: 'child' as const }], })) + setSelectedElement(null) + setSelectedEdge(null) + closeElementPanel() + closeConnectorPanel() navigateRef.current(`/views/${newView.id}`) } catch { /* intentionally empty */ } - }, [canEdit, linksMapRef, viewIdRef, viewElementsRef, navigateRef, setLinksMap, treeDataRef]) + }, [canEdit, linksMapRef, viewIdRef, viewElementsRef, navigateRef, setLinksMap, treeDataRef, setSelectedElement, setSelectedEdge, closeElementPanel, closeConnectorPanel]) const stableOnZoomOut = useCallback(async (elementId: number) => { const parentLinks = parentLinksMapRef.current[elementId] || [] @@ -517,7 +544,14 @@ export function useCanvasInteractions({ // from the clicked element's ID for elements like functions/classes that // don't own a view themselves). const anyParentLink = parentLinks[0] ?? 
Object.values(parentLinksMapRef.current).flat()[0] - if (anyParentLink) { navigateRef.current(`/views/${anyParentLink.from_view_id}`); return } + if (anyParentLink) { + setSelectedElement(null) + setSelectedEdge(null) + closeElementPanel() + closeConnectorPanel() + navigateRef.current(`/views/${anyParentLink.from_view_id}`) + return + } // Final fallback: use current view's parent_view_id if available const findInTreeById = (nodes: ViewTreeNode[], id: number): ViewTreeNode | null => { @@ -530,13 +564,21 @@ export function useCanvasInteractions({ } const currentView = findInTreeById(treeDataRef.current, viewIdRef.current || -1) if (currentView?.parent_view_id) { + setSelectedElement(null) + setSelectedEdge(null) + closeElementPanel() + closeConnectorPanel() navigateRef.current(`/views/${currentView.parent_view_id}`) } - }, [parentLinksMapRef, navigateRef, treeDataRef, viewIdRef]) + }, [parentLinksMapRef, navigateRef, treeDataRef, viewIdRef, setSelectedElement, setSelectedEdge, closeElementPanel, closeConnectorPanel]) const stableOnNavigateToView = useCallback((id: number) => { + setSelectedElement(null) + setSelectedEdge(null) + closeElementPanel() + closeConnectorPanel() navigateRef.current(`/views/${id}`) - }, [navigateRef]) + }, [navigateRef, setSelectedElement, setSelectedEdge, closeElementPanel, closeConnectorPanel]) const stableOnHoverZoom = useCallback((elementId: number, type: 'in' | 'out' | null) => { const prev = hoveredZoomRef.current @@ -1012,7 +1054,7 @@ export function useCanvasInteractions({ setSelectedProxyConnectorDetails(null) setSelectedEdge(connector) openConnectorPanelRef.current() - }, [closeElementPanel, connectors, openProxyConnectorPanel, setSelectedEdge, setSelectedElement, setSelectedProxyConnectorDetails]) + }, [closeConnectorPanel, closeElementPanel, connectors, openProxyConnectorPanel, setSelectedEdge, setSelectedElement, setSelectedProxyConnectorDetails]) // ── Pane interactions ───────────────────────────────────────────────────── const onPaneClick = useCallback((e: React.MouseEvent) => { @@ -1265,7 +1307,7 @@ export function useCanvasInteractions({ const cid = viewIdRef.current if (!cid) return const incoming = incomingLinksRef.current - const tree = treeDataRef.current + const tree = flattenViewTree(treeDataRef.current) const nav = navigateRef.current const links = linksMapRef.current const treeNode = tree.find((n) => n.id === cid) diff --git a/frontend/src/pages/ViewEditor/hooks/useViewContextNeighbours.ts b/frontend/src/pages/ViewEditor/hooks/useViewContextNeighbours.ts index 7f1453a..2bf19a6 100644 --- a/frontend/src/pages/ViewEditor/hooks/useViewContextNeighbours.ts +++ b/frontend/src/pages/ViewEditor/hooks/useViewContextNeighbours.ts @@ -1,6 +1,6 @@ import { useMemo } from 'react' import { type Edge as RFEdge, type Node as RFNode } from 'reactflow' -import type { PlacedElement } from '../../../types' +import type { Connector, PlacedElement } from '../../../types' import type { CrossBranchContextSettings, ProxyConnectorDetails, WorkspaceGraphSnapshot } from '../../../crossBranch/types' import { resolveViewProxyGraph } from '../../../crossBranch/resolve' @@ -69,6 +69,61 @@ function isAncestorContextNode( (snapshot.descendantsByViewId[ownedViewId]?.includes(descendant.placementViewId) ?? false) } +function canonicalElementPairKey(leftId: number, rightId: number) { + return leftId <= rightId ? `${leftId}::${rightId}` : `${rightId}::${leftId}` +} + +function canonicalNodePairKey(leftId: string, rightId: string) { + return leftId <= rightId ? 
`${leftId}::${rightId}` : `${rightId}::${leftId}` +} + +function buildDirectConnectorPairSet(connectors: Connector[], visibleElementIds: Set) { + const pairs = new Set() + for (const connector of connectors) { + if (!visibleElementIds.has(connector.source_element_id) || !visibleElementIds.has(connector.target_element_id)) continue + pairs.add(canonicalElementPairKey(connector.source_element_id, connector.target_element_id)) + } + return pairs +} + +function mergeHiddenProxyDetails( + existing: ProxyConnectorDetails | undefined, + next: ProxyConnectorDetails, +): ProxyConnectorDetails { + if (!existing) { + return { + ...next, + ownerViewIds: [...next.ownerViewIds], + ownerViewNames: [...next.ownerViewNames], + connectors: [...next.connectors], + } + } + + const ownerViews = new Map() + existing.ownerViewIds.forEach((ownerViewId, index) => { + ownerViews.set(ownerViewId, existing.ownerViewNames[index] ?? `View ${ownerViewId}`) + }) + next.ownerViewIds.forEach((ownerViewId, index) => { + ownerViews.set(ownerViewId, next.ownerViewNames[index] ?? `View ${ownerViewId}`) + }) + + const connectors = [...existing.connectors, ...next.connectors] + const count = connectors.length + + return { + key: existing.key, + label: count === 1 ? connectors[0]?.connector.label?.trim() || connectors[0]?.connector.relationship?.trim() || 'Cross-branch' : `${count} connectors`, + count, + sourceAnchorId: existing.sourceAnchorId, + targetAnchorId: existing.targetAnchorId, + sourceAnchorName: existing.sourceAnchorName, + targetAnchorName: existing.targetAnchorName, + ownerViewIds: Array.from(ownerViews.keys()), + ownerViewNames: Array.from(ownerViews.values()), + connectors, + } +} + export function useViewContextNeighbours({ snapshot, settings, @@ -82,18 +137,38 @@ export function useViewContextNeighbours({ }: Props) { return useMemo(() => { if (!snapshot || viewId == null || !settings.enabled) { - return { contextNodes: [] as RFNode[], contextConnectors: [] as RFEdge[], proxyConnectorDetailsByKey: {} as Record } + return { + contextNodes: [] as RFNode[], + contextConnectors: [] as RFEdge[], + proxyConnectorDetailsByKey: {} as Record, + hiddenProxyCountsByPair: {} as Record, + hiddenProxyDetailsByPair: {} as Record, + } } const { proxyNodes, proxyConnectors, proxyConnectorDetailsByKey } = resolveViewProxyGraph(snapshot, viewId, viewElements, settings) if (proxyNodes.length === 0 && proxyConnectors.length === 0) { - return { contextNodes: [] as RFNode[], contextConnectors: [] as RFEdge[], proxyConnectorDetailsByKey } + return { + contextNodes: [] as RFNode[], + contextConnectors: [] as RFEdge[], + proxyConnectorDetailsByKey, + hiddenProxyCountsByPair: {} as Record, + hiddenProxyDetailsByPair: {} as Record, + } } const mainNodes = rfNodes.filter((node) => node.type === 'elementNode') if (mainNodes.length === 0) { - return { contextNodes: [] as RFNode[], contextConnectors: [] as RFEdge[], proxyConnectorDetailsByKey } + return { + contextNodes: [] as RFNode[], + contextConnectors: [] as RFEdge[], + proxyConnectorDetailsByKey, + hiddenProxyCountsByPair: {} as Record, + hiddenProxyDetailsByPair: {} as Record, + } } + const visibleElementIds = new Set(viewElements.map((element) => element.element_id)) + const directConnectorPairs = buildDirectConnectorPairSet(snapshot.connectorsByViewId[viewId] ?? 
[], visibleElementIds) let minX = Infinity let minY = Infinity @@ -460,6 +535,8 @@ export function useViewContextNeighbours({ }) const seenCollapsedPairs = new Set() + const hiddenProxyCountsByPair: Record = {} + const hiddenProxyDetailsByPair: Record = {} const contextConnectors: RFEdge[] = proxyConnectors.flatMap((connector) => { let sourceId = connector.sourceAnchorId let targetId = connector.targetAnchorId @@ -472,7 +549,20 @@ export function useViewContextNeighbours({ if (sourceId === targetId) return [] - const pairKey = `${sourceId}::${targetId}` + const pairKey = canonicalNodePairKey(sourceId, targetId) + if (directConnectorPairs.has(pairKey)) { + hiddenProxyCountsByPair[pairKey] = (hiddenProxyCountsByPair[pairKey] ?? 0) + connector.details.count + hiddenProxyDetailsByPair[pairKey] = mergeHiddenProxyDetails( + hiddenProxyDetailsByPair[pairKey], + { + ...connector.details, + key: `hidden:${pairKey}`, + sourceAnchorId: sourceId, + targetAnchorId: targetId, + }, + ) + return [] + } if (seenCollapsedPairs.has(pairKey)) return [] seenCollapsedPairs.add(pairKey) @@ -496,6 +586,12 @@ export function useViewContextNeighbours({ }] }) - return { contextNodes: [ContextBoundaryElement, ...contextNodes], contextConnectors, proxyConnectorDetailsByKey } + return { + contextNodes: [ContextBoundaryElement, ...contextNodes], + contextConnectors, + proxyConnectorDetailsByKey, + hiddenProxyCountsByPair, + hiddenProxyDetailsByPair, + } }, [snapshot, settings, viewId, viewElements, rfNodes, stableOnNavigateToView, onSelectProxyDetails, expandedAncestorGroups, onToggleAncestorGroup]) } diff --git a/frontend/src/pages/ViewEditor/hooks/useViewData.ts b/frontend/src/pages/ViewEditor/hooks/useViewData.ts index 62a0d6c..9357836 100644 --- a/frontend/src/pages/ViewEditor/hooks/useViewData.ts +++ b/frontend/src/pages/ViewEditor/hooks/useViewData.ts @@ -18,6 +18,7 @@ import { getVisualHandleSlot, } from '../../../utils/edgeDistribution' import { buildViewContentLinks, useStore } from '../../../store/useStore' +import type { WorkspaceVersionFollowTarget, WorkspaceVersionPreview } from '../../../context/WorkspaceVersionContext' interface ViewDataOptions { viewId: number | null @@ -29,6 +30,8 @@ interface ViewDataOptions { hoveredLayerTags: string[] | null hoveredLayerColor: string | null tagColors: Record + versionPreview?: WorkspaceVersionPreview | null + versionFollowTarget?: WorkspaceVersionFollowTarget | null // Node-level callbacks (stable refs from parent) stableOnZoomIn: (elementId: number) => Promise stableOnZoomOut: (elementId: number) => Promise @@ -52,6 +55,7 @@ function alphaColor(color: string, opacity: number): string { // letting structural-sharing fast-path bail out without rebuilding the node. 
const HIDDEN_STYLE: CSSProperties = { opacity: 0.1, pointerEvents: 'none' } const SOFT_FOCUS_STYLE: CSSProperties = { opacity: 0.2 } +const VERSION_DIM_STYLE: CSSProperties = { opacity: 0.1 } const EMPTY_ARRAY: readonly never[] = Object.freeze([]) const EMPTY_NODE_CONNECTION_META = Object.freeze({ key: '', @@ -158,6 +162,8 @@ export function useViewData({ hoveredLayerTags, hoveredLayerColor, tagColors, + versionPreview, + versionFollowTarget, stableOnZoomIn, stableOnZoomOut, stableOnNavigateToView, @@ -189,7 +195,6 @@ export function useViewData({ const incomingLinks = useStore((state) => state.incomingLinks) const treeData = useStore((state) => state.treeData) const allElements = useStore((state) => state.allElements) - const setAllElements = useStore((state) => state.setAllElements) const hydrateViewContent = useStore((state) => state.hydrateViewContent) const resetCanvas = useStore((state) => state.resetCanvas) const removeElementPlacement = useStore((state) => state.removeElementPlacement) @@ -216,13 +221,14 @@ export function useViewData({ // ── Fetch tree ───────────────────────────────────────────────────────────── const refreshGrid = useCallback(async () => { + if (viewId === null) return const tree = await queryClient.fetchQuery({ - queryKey: ['workspace', 'views', 'tree'], - queryFn: () => api.workspace.views.tree(), + queryKey: ['workspace', 'views', viewId, 'editor-tree'], + queryFn: () => api.workspace.views.treeAround(viewId, { ancestorLevels: 2, descendantLevels: 2 }), staleTime: 0, }).catch(() => null) if (tree) useStore.getState().setTreeData(tree) - }, [queryClient]) + }, [queryClient, viewId]) // ── Fetch view content ────────────────────────────────────────────────── const viewContentQuery = useQuery({ @@ -233,7 +239,7 @@ export function useViewData({ const [diag, content, tree] = await Promise.all([ api.workspace.views.get(viewId), api.workspace.views.content(viewId), - api.workspace.views.tree(), + api.workspace.views.treeAround(viewId, { ancestorLevels: 2, descendantLevels: 2 }), ]) const viewElements = content.placements || [] const connectors = content.connectors || [] @@ -264,16 +270,6 @@ export function useViewData({ resetCanvas() }, [resetCanvas, viewId]) - // ── Keep all-org elements for inline adder ────────────────────────────────── - const allElementsQuery = useQuery({ - queryKey: ['elements', 'list'], - queryFn: () => api.elements.list(), - }) - - useEffect(() => { - if (allElementsQuery.data) setAllElements(allElementsQuery.data) - }, [allElementsQuery.data, setAllElements]) - // ── Refresh elements ──────────────────────────────────────────────────────── const refreshElements = useCallback(async () => { if (viewId === null) return @@ -428,6 +424,9 @@ export function useViewData({ const activeSet = activeTags.length > 0 ? new Set(activeTags) : null const hoveredSet = hoveredLayerTags !== null ? 
new Set(hoveredLayerTags) : null const isClickConnectMode = clickConnectMode !== null + const versionElementChanges = versionPreview?.elementChanges + const versionElementLineDeltas = versionPreview?.elementLineDeltas + const versionActive = !!versionPreview return viewElements.map((obj) => { const nodeId = String(obj.element_id) @@ -438,13 +437,21 @@ export function useViewData({ const isInactive = isHiddenByLayer || (activeSet !== null && !objTags.some((t) => activeSet.has(t))) const isLayerHighlighted = hoveredSet !== null && objTags.some((t) => hoveredSet.has(t)) const isSoftFocused = hoveredSet !== null && !isLayerHighlighted - - const newZIndex = isLayerHighlighted ? 10 : interactionSourceId === obj.element_id ? 1000 : 0 + const versionChangeType = versionElementChanges?.get(obj.element_id) + const versionLineDelta = versionElementLineDeltas?.get(obj.element_id) + const versionPulseChangeType = versionFollowTarget?.resourceType === 'element' && versionFollowTarget.resourceId === obj.element_id + ? versionFollowTarget.changeType ?? versionChangeType + : undefined + const isDimmedByVersionPreview = versionActive && !versionChangeType + + const newZIndex = versionPulseChangeType ? 20 : isLayerHighlighted ? 10 : interactionSourceId === obj.element_id ? 1000 : 0 const newStyle = isInactive ? HIDDEN_STYLE : isSoftFocused ? SOFT_FOCUS_STYLE - : undefined + : isDimmedByVersionPreview + ? VERSION_DIM_STYLE + : undefined const layerHighlightColor = isLayerHighlighted ? (hoveredLayerColor ?? undefined) : undefined const position = existing?.dragging ? existing.position : { x: obj.position_x ?? 0, y: obj.position_y ?? 0 } const isZoomHovered = hoveredZoomRef.current?.elementId === obj.element_id ? hoveredZoomRef.current.type : null @@ -487,7 +494,9 @@ export function useViewData({ existing.data.connectedHandleIds === connectionMeta.connectedHandleIds && existing.data.selectedHandleIds === connectionMeta.selectedHandleIds && existing.data.reconnectCandidates === connectionMeta.reconnectCandidates && - existing.data.isConnectorHighlighted === connectionMeta.isConnectorHighlighted + existing.data.isConnectorHighlighted === connectionMeta.isConnectorHighlighted && + existing.data.versionChangeType === versionPulseChangeType && + existing.data.versionLineDelta === versionLineDelta ) { return existing } @@ -527,6 +536,8 @@ export function useViewData({ selectedHandleIds: connectionMeta.selectedHandleIds, reconnectCandidates: connectionMeta.reconnectCandidates, isConnectorHighlighted: connectionMeta.isConnectorHighlighted, + versionChangeType: versionPulseChangeType, + versionLineDelta, }, } }) @@ -537,7 +548,7 @@ export function useViewData({ stableOnZoomIn, stableOnZoomOut, stableOnNavigateToView, stableOnSelect, stableOnInteractionStart, stableOnConnectTo, stableOnStartHandleReconnect, stableOnRemoveElement, stableOnHoverZoom, stableOnOpenCodePreview, hoveredZoomRef, activeTags, hiddenLayerTags, hoveredLayerTags, hoveredLayerColor, tagColors, - nodeConnectionMetaByElementId, setRfNodes, + nodeConnectionMetaByElementId, setRfNodes, versionPreview, versionFollowTarget, ]) // ── Derive RF connectors ──────────────────────────────────────────────────────── @@ -545,6 +556,8 @@ export function useViewData({ const hiddenSet = hiddenLayerTags.length > 0 ? new Set(hiddenLayerTags) : null const activeSet = activeTags.length > 0 ? new Set(activeTags) : null const hoveredSet = hoveredLayerTags !== null ? 
new Set(hoveredLayerTags) : null + const versionConnectorChanges = versionPreview?.connectorChanges + const versionActive = !!versionPreview setRfEdges((prevConnectors) => { const prevEdgeMap = new Map(prevConnectors.map((e) => [e.id, e])) @@ -572,11 +585,13 @@ export function useViewData({ !srcTags.some((t) => hoveredSet.has(t)) || !tgtTags.some((t) => hoveredSet.has(t)) ) - const edgeOpacity = isInactive ? 0.1 : isSoftFocused ? 0.2 : 0.8 - const markerOpacity = isInactive ? 0.1 : isSoftFocused ? 0.2 : 1 + const versionChangeType = versionConnectorChanges?.get(e.id) + const isDimmedByVersionPreview = versionActive && !versionChangeType + const edgeOpacity = isInactive || isDimmedByVersionPreview ? 0.1 : isSoftFocused ? 0.2 : 0.8 + const markerOpacity = isInactive || isDimmedByVersionPreview ? 0.1 : isSoftFocused ? 0.2 : 1 const newZIndex = selectedEdgeId !== null && edgeId === String(selectedEdgeId) ? 1000 : 100 const pointerEvents = (isInactive || isSoftFocused) ? 'none' : 'auto' - const labelBgOpacity = isInactive ? 0.1 : isSoftFocused ? 0.2 : 0.95 + const labelBgOpacity = isInactive || isDimmedByVersionPreview ? 0.1 : isSoftFocused ? 0.2 : 0.95 // Structural sharing: when all user-visible outputs match prev exactly, reuse prev ref. // We match on the underlying connector ref plus every computed visibility/layout value. @@ -594,7 +609,8 @@ export function useViewData({ (existing.data as { sourceGroupIndex?: number }).sourceGroupIndex === layout.sourceGroupIndex && (existing.data as { targetGroupIndex?: number }).targetGroupIndex === layout.targetGroupIndex && (existing.data as { sourceGroupCount?: number }).sourceGroupCount === layout.sourceGroupCount && - (existing.data as { targetGroupCount?: number }).targetGroupCount === layout.targetGroupCount + (existing.data as { targetGroupCount?: number }).targetGroupCount === layout.targetGroupCount && + (existing.data as { versionChangeType?: string }).versionChangeType === versionChangeType ) { return existing } @@ -620,6 +636,7 @@ export function useViewData({ targetHandleSide: layout.targetHandleSide, sourceHandleSlot: layout.sourceHandleSlot, targetHandleSlot: layout.targetHandleSlot, + versionChangeType, }, style: { stroke: 'var(--accent)', strokeWidth: 2, opacity: edgeOpacity, pointerEvents }, @@ -632,7 +649,7 @@ export function useViewData({ } }) }) - }, [connectorLayouts, selectedEdgeId, activeTags, hiddenLayerTags, hoveredLayerTags, elementMap, setRfEdges]) + }, [connectorLayouts, selectedEdgeId, activeTags, hiddenLayerTags, hoveredLayerTags, elementMap, setRfEdges, versionPreview]) // ── Boost z-index of selected connector ──────────────────────────────────────── @@ -685,6 +702,5 @@ export function useViewData({ handleElementDeleted, handleElementPermanentlyDeleted, handleElementSaved, - setAllElements, } } diff --git a/frontend/src/pages/ViewEditor/index.tsx b/frontend/src/pages/ViewEditor/index.tsx index bffcd1d..3559a8f 100644 --- a/frontend/src/pages/ViewEditor/index.tsx +++ b/frontend/src/pages/ViewEditor/index.tsx @@ -41,6 +41,7 @@ import type { LibraryElement as WorkspaceElement, Connector, ViewConnector, + VisibilityOverride, Tag, } from '../../types' import ElementNode from '../../components/ElementNode' @@ -86,6 +87,8 @@ import { removeConnectorGraphSnapshot, upsertConnectorGraphSnapshot, useWorkspac import type { ProxyConnectorDetails } from '../../crossBranch/types' import { useDemoRevealViewport, type ViewEditorDemoOptions } from '../../demo/viewEditor' import { buildElementLibraryItems, useStore } from 
'../../store/useStore' +import { useWorkspaceVersionPreview } from '../../context/WorkspaceVersionContext' +import { WATCH_REPRESENTATION_UPDATED_EVENT } from '../../components/WorkspacePanel' const nodeTypes = { elementNode: ElementNode, @@ -127,6 +130,10 @@ function areTranslateExtentsEqual( left[1][1] === right[1][1] } +function canonicalNodePairKey(leftId: string, rightId: string) { + return leftId <= rightId ? `${leftId}::${rightId}` : `${rightId}::${leftId}` +} + // ───────────────────────────────────────────────────────────────────────────── @@ -160,6 +167,8 @@ function ViewEditorInner({ const setHeader = useSetHeader() const isMobileLayout = useBreakpointValue({ base: true, md: false }) ?? false + const [densityLevel, setDensityLevel] = useState(0) + const [visibilityOverrides, setVisibilityOverrides] = useState([]) const elementPanel = useDisclosure() const connectorPanel = useDisclosure() @@ -169,6 +178,24 @@ function ViewEditorInner({ const importModal = useDisclosure() const codePreview = useDisclosure() + useEffect(() => { + if (viewId == null) { + setDensityLevel(0) + setVisibilityOverrides([]) + return + } + let cancelled = false + void Promise.all([ + api.workspace.views.density.get(viewId).catch(() => 0), + api.workspace.views.visibilityOverrides.list(viewId).catch(() => []), + ]).then(([level, overrides]) => { + if (cancelled) return + setDensityLevel(level) + setVisibilityOverrides(overrides) + }) + return () => { cancelled = true } + }, [viewId]) + // ── Stable disclosure refs ────────────────────────────────────────────── const openElementPanelRef = useRef(elementPanel.onOpen) openElementPanelRef.current = elementPanel.onOpen @@ -204,6 +231,14 @@ function ViewEditorInner({ const [selectedElement, setSelectedElement] = useState(null) const [selectedEdge, setSelectedEdge] = useState(null) const [selectedProxyConnectorDetails, setSelectedProxyConnectorDetails] = useState(null) + + const [prevViewId, setPrevViewId] = useState(viewId) + if (viewId !== prevViewId) { + setPrevViewId(viewId) + setSelectedElement(null) + setSelectedEdge(null) + setSelectedProxyConnectorDetails(null) + } const [previewElement, setPreviewElement] = useState(null) const [libraryOpen, setLibraryOpen] = useState(() => { if (typeof window === 'undefined') return false @@ -226,7 +261,7 @@ function ViewEditorInner({ const setStoreSnapToGrid = useStore((state) => state.setSnapToGrid) const upsertStoreConnector = useStore((state) => state.upsertConnector) const removeStoreConnector = useStore((state) => state.removeConnector) - const refreshElementsRef = useRef<() => Promise>(async () => {}) + const refreshElementsRef = useRef<() => Promise>(async () => { }) const setSnapToGrid = useCallback((snap: boolean) => { setStoreSnapToGrid(snap) if (typeof window !== 'undefined') localStorage.setItem('diag:snapToGrid', String(snap)) @@ -257,6 +292,7 @@ function ViewEditorInner({ const [activeTags, setActiveTags] = useState([]) const activeTagsRef = useRef([]) activeTagsRef.current = activeTags + const { preview: versionPreview, followTarget: versionFollowTarget } = useWorkspaceVersionPreview() const [tagColors, setTagColors] = useState>({}) useEffect(() => { @@ -304,6 +340,7 @@ function ViewEditorInner({ const nextColor = color ?? tagColors[name]?.color ?? pickUnusedColor(Object.values(tagColors).map(t => t.color)) const nextDescription = description ?? tagColors[name]?.description ?? 
null + await api.workspace.orgs.tagColors.update(name, nextColor, nextDescription) setTagColors((prev) => ({ ...prev, [name]: { name, color: nextColor, description: nextDescription } })) }, [tagColors]) @@ -384,6 +421,8 @@ function ViewEditorInner({ hoveredLayerTags, hoveredLayerColor, tagColors, + versionPreview, + versionFollowTarget, stableOnZoomIn: useCallback(async (id: number) => { await stableOnZoomInRef.current(id) }, []), stableOnZoomOut: useCallback(async (id: number) => { await stableOnZoomOutRef.current(id) }, []), stableOnNavigateToView: useCallback((id: number) => { stableOnNavigateToViewRef.current(id) }, []), @@ -436,10 +475,82 @@ function ViewEditorInner({ treeDataRef, rfNodesRef, rfEdgesRef, viewIdRef, refreshGrid, refreshElements, handleElementDeleted, handleElementPermanentlyDeleted, handleElementSaved, - setAllElements: _setAllElements, } = data refreshElementsRef.current = refreshElements + const overrideDeltaFor = useCallback((resourceType: VisibilityOverride['resource_type'], resourceId?: number | null) => { + if (resourceId == null) return 0 + return visibilityOverrides.find((override) => override.resource_type === resourceType && override.resource_id === resourceId)?.level_delta ?? 0 + }, [visibilityOverrides]) + + const reloadVisibilityOverrides = useCallback(async () => { + if (viewId == null) return + const overrides = await api.workspace.views.visibilityOverrides.list(viewId).catch(() => []) + setVisibilityOverrides(overrides) + }, [viewId]) + + const handleDensityLevelChange = useCallback(async (level: number) => { + if (viewId == null) return + setDensityLevel(level) + try { + await api.workspace.views.density.set(viewId, level) + await refreshElements() + } catch { + toast({ status: 'error', title: 'Density was not saved' }) + } + }, [refreshElements, toast, viewId]) + + const handleVisibilityOverride = useCallback(async (resourceType: VisibilityOverride['resource_type'], resourceId: number, action: 'promote' | 'demote' | 'reset') => { + if (viewId == null) return + try { + if (action === 'promote') await api.workspace.views.visibilityOverrides.promote(viewId, resourceType, resourceId) + else if (action === 'demote') await api.workspace.views.visibilityOverrides.demote(viewId, resourceType, resourceId) + else await api.workspace.views.visibilityOverrides.reset(viewId, resourceType, resourceId) + await reloadVisibilityOverrides() + await refreshElements() + } catch { + toast({ status: 'error', title: 'Visibility override was not saved' }) + } + }, [refreshElements, reloadVisibilityOverrides, toast, viewId]) + + const resolveWatchRepositoryId = useCallback(async () => { + const status = await api.watch.status().catch(() => null) + if (status?.repository?.id) return status.repository.id + const repositories = await api.watch.repositories().catch(() => []) + return repositories[0]?.id ?? null + }, []) + + const applyWatchContextAction = useCallback(async (action: 'show' | 'hide' | 'clean', resourceType: 'element' | 'view', resourceId: number) => { + const repositoryId = await resolveWatchRepositoryId() + if (!repositoryId) { + toast({ status: 'warning', title: 'No watch repository found' }) + return + } + try { + const result = action === 'show' + ? await api.watch.showContext(repositoryId, { resource_type: resourceType, resource_id: resourceId }) + : action === 'hide' + ? 
await api.watch.hideContext(repositoryId, { resource_type: resourceType, resource_id: resourceId }) + : await api.watch.cleanContext(repositoryId, { resource_type: resourceType, resource_id: resourceId }) + await refreshGrid() + await refreshElements() + window.dispatchEvent(new CustomEvent(WATCH_REPRESENTATION_UPDATED_EVENT, { + detail: { type: 'representation.updated', repository_id: repositoryId, at: new Date().toISOString(), data: result.summary }, + })) + toast({ + status: 'success', + title: action === 'show' ? 'Context revealed' : 'Noise cleaned', + description: action === 'show' + ? `${result.elements_added + result.connectors_added + result.views_added} generated item${result.elements_added + result.connectors_added + result.views_added === 1 ? '' : 's'} added. Tier ${result.tier_after}/${result.max_tier}.` + : action === 'hide' + ? `${result.elements_removed + result.connectors_removed + result.views_removed} generated item${result.elements_removed + result.connectors_removed + result.views_removed === 1 ? '' : 's'} removed.` + : `${result.elements_removed + result.connectors_removed + result.views_removed} generated item${result.elements_removed + result.connectors_removed + result.views_removed === 1 ? '' : 's'} removed. Tier ${result.tier_after}/${result.max_tier}.`, + }) + } catch (err) { + toast({ status: 'error', title: action === 'show' ? 'Failed to show context' : 'Failed to clean noise', description: String(err) }) + } + }, [refreshElements, refreshGrid, resolveWatchRepositoryId, toast]) + const tagCounts = useMemo(() => { const counts: Record = {} viewElements.forEach(p => { @@ -510,10 +621,11 @@ function ViewEditorInner({ const availableTags = useMemo(() => { const tags = new Set() + viewElements.forEach((o) => o.tags?.forEach((t: string) => tags.add(t))) allElements.forEach((o) => o.tags?.forEach((t: string) => tags.add(t))) Object.keys(tagColors).forEach((t) => tags.add(t)) return Array.from(tags).sort((a, b) => a.localeCompare(b)) - }, [allElements, tagColors]) + }, [allElements, tagColors, viewElements]) const effectiveWorkspaceSnapshot = useMemo(() => { if (viewId == null) return workspaceGraphSnapshot @@ -648,8 +760,9 @@ function ViewEditorInner({ openProxyConnectorPanel: useCallback(() => openProxyConnectorPanelRef.current(), []), closeProxyConnectorPanel: useCallback(() => closeProxyConnectorPanelRef.current(), []), handleElementDeleted, handleElementPermanentlyDeleted, - handleConnectorDeleted: useCallback((edgeId: number) => { - if (viewId != null) removeConnectorGraphSnapshot(viewId, edgeId) + handleConnectorDeleted: useCallback((edgeId: number, ownerViewId?: number) => { + const vid = ownerViewId ?? 
viewId + if (vid != null) removeConnectorGraphSnapshot(vid, edgeId) removeStoreConnector(edgeId) void refreshElementsRef.current() }, [removeStoreConnector, viewId]), @@ -680,7 +793,7 @@ function ViewEditorInner({ }) }, []) - const { contextNodes, contextConnectors } = useViewContextNeighbours({ + const { contextNodes, contextConnectors, hiddenProxyCountsByPair, hiddenProxyDetailsByPair } = useViewContextNeighbours({ snapshot: effectiveWorkspaceSnapshot, settings: crossBranchSettings, viewId, @@ -699,6 +812,39 @@ function ViewEditorInner({ onToggleAncestorGroup: stableOnToggleAncestorGroup, }) + const rfEdgesWithProxyBadges = useMemo(() => { + if (Object.keys(hiddenProxyCountsByPair).length === 0) return rfEdges + + let changed = false + const next = rfEdges.map((edge) => { + const pairKey = canonicalNodePairKey(edge.source, edge.target) + const proxyBadgeCount = hiddenProxyCountsByPair[pairKey] ?? 0 + const currentBadgeCount = (edge.data as { proxyBadgeCount?: number } | undefined)?.proxyBadgeCount ?? 0 + const proxyBadgeDetails = hiddenProxyDetailsByPair[pairKey] ?? null + const currentBadgeDetails = (edge.data as { proxyBadgeDetails?: ProxyConnectorDetails | null } | undefined)?.proxyBadgeDetails ?? null + if (proxyBadgeCount === currentBadgeCount && proxyBadgeDetails === currentBadgeDetails) return edge + changed = true + return { + ...edge, + data: { + ...(edge.data ?? {}), + proxyBadgeCount: proxyBadgeCount > 0 ? proxyBadgeCount : undefined, + proxyBadgeDetails, + onOpenProxyBadge: (details: ProxyConnectorDetails) => { + setSelectedElement(null) + setSelectedEdge(null) + closeConnectorPanelRef.current() + closeElementPanelRef.current() + setSelectedProxyConnectorDetails(details) + openProxyConnectorPanelRef.current() + }, + }, + } + }) + + return changed ? next : rfEdges + }, [hiddenProxyCountsByPair, hiddenProxyDetailsByPair, rfEdges]) + // Keep context nodes in state so React Flow can store measured dimensions. // When computed positions change (e.g. main node drag), preserve the previously // measured width/height so nodes don't flash hidden while being re-measured. @@ -735,10 +881,10 @@ function ViewEditorInner({ } const allEdges = contextConnectors.length === 0 - ? rfEdges - : rfEdges.length === 0 + ? rfEdgesWithProxyBadges + : rfEdgesWithProxyBadges.length === 0 ? contextConnectors - : [...contextConnectors, ...rfEdges] + : [...contextConnectors, ...rfEdgesWithProxyBadges] const selectedEdgeEndPoints = new Set() let hasEdgeSel = false @@ -773,14 +919,14 @@ function ViewEditorInner({ cache.set(n, faded) return faded }) - }, [liveContextNodes, rfNodes, contextConnectors, rfEdges]) + }, [liveContextNodes, rfNodes, contextConnectors, rfEdgesWithProxyBadges]) const flowEdges = useMemo(() => { const allEdges = contextConnectors.length === 0 - ? rfEdges - : rfEdges.length === 0 + ? rfEdgesWithProxyBadges + : rfEdgesWithProxyBadges.length === 0 ? contextConnectors - : [...contextConnectors, ...rfEdges] + : [...contextConnectors, ...rfEdgesWithProxyBadges] const allNodes = liveContextNodes.length === 0 ? rfNodes : rfNodes.length === 0 @@ -815,7 +961,7 @@ function ViewEditorInner({ cache.set(e, faded) return faded }) - }, [contextConnectors, rfEdges, liveContextNodes, rfNodes]) + }, [contextConnectors, rfEdgesWithProxyBadges, liveContextNodes, rfNodes]) // Route onNodesChange: context node changes (dimensions, selection) go to // liveContextNodes state; main node changes go to the canvas handler. 
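// ── Editor's note (standalone sketch, not part of the diff above) ─────────────
// The rfEdgesWithProxyBadges memo in the hunk above combines two patterns worth
// seeing in isolation: an order-independent pair key (so A→B and B→A share one
// badge bucket, matching canonicalNodePairKey) and structural sharing (the memo
// returns the previous array reference when no edge actually changed, letting
// downstream memoized consumers bail out on reference equality). The EdgeLike
// shape and helper names below are illustrative placeholders, not the project's
// real types.
interface EdgeLike {
  id: string
  source: string
  target: string
  data?: { proxyBadgeCount?: number }
}

// Same key for (a, b) and (b, a).
function pairKey(a: string, b: string): string {
  return a <= b ? `${a}::${b}` : `${b}::${a}`
}

// Annotate edges with hidden-proxy badge counts; keep old object references (and
// the old array reference) whenever the computed count is unchanged.
function withProxyBadgeCounts(edges: EdgeLike[], countsByPair: Record<string, number>): EdgeLike[] {
  let changed = false
  const next = edges.map((edge) => {
    const count = countsByPair[pairKey(edge.source, edge.target)] ?? 0
    if ((edge.data?.proxyBadgeCount ?? 0) === count) return edge // unchanged → reuse ref
    changed = true
    return { ...edge, data: { ...(edge.data ?? {}), proxyBadgeCount: count > 0 ? count : undefined } }
  })
  return changed ? next : edges // nothing changed → reuse the original array reference
}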
@@ -885,7 +1031,7 @@ function ViewEditorInner({ return } - const ok = safeFitView({ duration: 0 }) + const ok = safeFitView({ duration: 0, padding: 400 }) if (ok) needsFitView.current = false else setTimeout(() => { if (needsFitView.current) maybeFitView() }, 50) }, [applyDemoRevealViewport, clampedRevealProgress, safeFitView, rfNodesRef]) @@ -902,7 +1048,15 @@ function ViewEditorInner({ return () => observer.disconnect() }, [maybeFitView]) - useEffect(() => { needsFitView.current = true }, [viewId]) + useEffect(() => { + setSelectedElement(null) + setSelectedEdge(null) + setSelectedProxyConnectorDetails(null) + elementPanel.onClose() + connectorPanel.onClose() + proxyConnectorPanel.onClose() + needsFitView.current = true + }, [viewId]) // ── Dynamic viewport bounds ──────────────────────────────────────────────── useEffect(() => { @@ -995,14 +1149,14 @@ function ViewEditorInner({ useEffect(() => () => setHeader(null), [setHeader]) // ── Share ────────────────────────────────────────────────────────────────── - const onShare = useCallback(() => {}, []) + const onShare = useCallback(() => { }, []) const handleExplorerHoverZoom = useCallback((elementId: number | null, type: 'in' | 'out' | null) => { setHoveredZoom(type && elementId ? { elementId, type } : null) }, []) const handleToggleExplorer = useCallback(() => setIsExplorerOpen((v) => !v), []) const handleCloseLibrary = useCallback(() => setLibraryOpen(false), []) - const handleCreateNewLibraryRef = useRef<() => void>(() => {}) + const handleCreateNewLibraryRef = useRef<() => void>(() => { }) const handleCreateNewLibrary = useCallback(() => handleCreateNewLibraryRef.current(), []) const handleFocusModeChange = useCallback((v: boolean) => setCrossBranchEnabled(!v), [setCrossBranchEnabled]) const handleOpenExport = useCallback(() => exportModal.onOpen(), [exportModal]) @@ -1010,13 +1164,15 @@ function ViewEditorInner({ upsertConnectorGraphSnapshot(updated) upsertStoreConnector(updated) }, [upsertStoreConnector]) - const handleConnectorDeleted = useCallback((edgeId: number) => { - if (viewId != null) removeConnectorGraphSnapshot(viewId, edgeId) + const handleConnectorDeleted = useCallback((edgeId: number, ownerViewId?: number) => { + const vid = ownerViewId ?? viewId + if (vid != null) removeConnectorGraphSnapshot(vid, edgeId) removeStoreConnector(edgeId) void refreshElements() }, [refreshElements, removeStoreConnector, viewId]) - const handleConnectorDeleteInPanel = useCallback((edgeId: number) => { - handleConnectorDeleted(edgeId) + + const handleConnectorDeleteInPanel = useCallback((edgeId: number, ownerViewId?: number) => { + handleConnectorDeleted(edgeId, ownerViewId) setSelectedEdge(null) }, [handleConnectorDeleted, setSelectedEdge]) const handleViewSave = useCallback((updated: ViewTreeNode) => setView(updated), [setView]) @@ -1114,10 +1270,10 @@ function ViewEditorInner({ // Render states // ───────────────────────────────────────────────────────────────────────────── if (view === undefined) { - return + return } if (view === null) { - return View not found. + return View not found. 
} return ( @@ -1125,7 +1281,7 @@ function ViewEditorInner({ viewId, canEdit, isOwner, isFreePlan, snapToGrid, setSnapToGrid, selectedElement, selectedConnector: selectedEdge }}> - + applyWatchContextAction('hide', 'element', id)} + visibilityOverrideDelta={overrideDeltaFor('element', selectedElement?.id)} + onPromoteVisibility={(id) => handleVisibilityOverride('element', id, 'promote')} + onDemoteVisibility={(id) => handleVisibilityOverride('element', id, 'demote')} + onResetVisibility={(id) => handleVisibilityOverride('element', id, 'reset')} orgId={''} links={selectedElement ? (linksMap[selectedElement.id] || EMPTY_LINKS) : EMPTY_LINKS} parentLinks={selectedElement ? (parentLinksMap[selectedElement.id] || EMPTY_LINKS) : EMPTY_LINKS} @@ -1412,14 +1575,25 @@ function ViewEditorInner({ orgId={''} onSave={handleConnectorSave} autoSave onDelete={handleConnectorDeleteInPanel} + visibilityOverrideDelta={overrideDeltaFor('connector', selectedEdge?.id)} + onPromoteVisibility={(id) => handleVisibilityOverride('connector', id, 'promote')} + onDemoteVisibility={(id) => handleVisibilityOverride('connector', id, 'demote')} + onResetVisibility={(id) => handleVisibilityOverride('connector', id, 'reset')} hasBackdrop={isMobileLayout} - connectorPanelAfterContentSlot={connectorPanelAfterContentSlot} - /> + connectorPanelAfterContentSlot={connectorPanelAfterContentSlot} + /> { + setSelectedEdge(connector) + connectorPanel.onOpen() + }} + onDelete={(edgeId, ownerViewId) => { + handleConnectorDeleteInPanel(edgeId, ownerViewId) + }} /> void } -type ViewType = 'explore' | 'hierarchy' +type ViewType = JumpViewMode const MotionBox = motion.create(Box) @@ -56,24 +65,14 @@ function HierarchyModeIcon({ size = 13 }: { size?: number }) { ) } -function flattenTree(roots: ViewTreeNode[]): ViewTreeNode[] { - const result: ViewTreeNode[] = [] - const traverse = (node: ViewTreeNode) => { - result.push(node) - node.children.forEach(traverse) - } - roots.forEach(traverse) - return result -} - interface DiagramJumpToolbarProps { view: ViewType searchTerm: string - searchResults: ViewTreeNode[] + searchResults: JumpSearchResult[] activeSearchIndex: number onSearchChange: (term: string) => void onSearchKeyDown: (e: React.KeyboardEvent) => void - onResultClick: (result: ViewTreeNode) => void + onResultClick: (result: JumpSearchResult) => void onViewChange: (view: ViewType) => void onCreateOpen: () => void } @@ -328,7 +327,7 @@ function DiagramJumpToolbar({ > {searchResults.map((result, idx) => ( - Level {result.level} • {result.level_label || 'Diagram'} + {jumpResultSubtitle(result)} {idx === activeSearchIndex && ( - {view === 'explore' ? 
'ZOOM' : 'OPEN'} + {jumpResultActionLabel(view)} @@ -384,15 +383,33 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { const [treeLoading, setTreeLoading] = useState(true) const [focusedHierarchyId, setFocusedHierarchyId] = useState(null) const [searchTerm, setSearchTerm] = useState('') - const [searchResults, setSearchResults] = useState([]) + const [searchResults, setSearchResults] = useState([]) const [activeSearchIndex, setActiveSearchIndex] = useState(-1) + const [exploreSearchData, setExploreSearchData] = useState(null) const { isOpen: isCreateOpen, onOpen: onCreateOpen, onClose: onCreateClose } = useDisclosure() const [newName, setNewName] = useState('') const [isCreating, setIsCreating] = useState(false) const exploreRef = useRef(null) + const exploreSearchLoadRef = useRef | null>(null) const flatTree = useMemo(() => flattenTree(treeData), [treeData]) + const ensureExploreSearchData = useCallback(() => { + if (exploreSearchData || exploreSearchLoadRef.current) return + exploreSearchLoadRef.current = api.explore.load() + .then((data) => { + if (!data.password_required) { + setExploreSearchData(data) + return data + } + return null + }) + .catch(() => null) + .finally(() => { + exploreSearchLoadRef.current = null + }) + }, [exploreSearchData]) + const handleViewChange = useCallback((newView: ViewType) => { setView(newView) const newParams = new URLSearchParams(searchParams) @@ -411,8 +428,43 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { } }, [searchParams]) + useEffect(() => { + const focusId = Number(searchParams.get('focus') ?? 0) + const elementId = Number(searchParams.get('element') ?? 0) + if (view !== 'explore' || !Number.isFinite(focusId) || focusId <= 0) return + let attempts = 0 + let timer: number | null = null + const focus = () => { + attempts += 1 + if (Number.isFinite(elementId) && elementId > 0) { + if (exploreRef.current?.focusElement(focusId, elementId)) return + if (attempts < 12) timer = window.setTimeout(focus, 150) + return + } + if (exploreRef.current?.focusDiagram(focusId)) return + if (attempts < 12) timer = window.setTimeout(focus, 150) + } + focus() + return () => { + if (timer !== null) window.clearTimeout(timer) + } + }, [searchParams, view]) + + useEffect(() => { + const trimmed = searchTerm.trim() + if (trimmed.length < 3) return + + const matches = buildJumpSearchResults(trimmed, flatTree, exploreSearchData) + setSearchResults(matches) + setActiveSearchIndex(matches.length > 0 ? 
0 : -1) + if (view === 'hierarchy' && matches[0]) { + setFocusedHierarchyId(matches[0].viewId) + } + }, [exploreSearchData, flatTree, searchTerm, view]) + const refreshTree = useCallback(async () => { setTreeLoading(true) + setExploreSearchData(null) const tree = await api.workspace.views.tree().catch(() => null) if (tree) { setTreeData(tree) @@ -452,17 +504,38 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { // eslint-disable-next-line react-hooks/exhaustive-deps }, []) - const commitSearchResult = useCallback((result: ViewTreeNode) => { + useEffect(() => { + const refresh = () => { + void refreshTree() + } + window.addEventListener(WATCH_REPRESENTATION_UPDATED_EVENT, refresh) + return () => window.removeEventListener(WATCH_REPRESENTATION_UPDATED_EVENT, refresh) + }, [refreshTree]) + + const commitSearchResult = useCallback((result: JumpSearchResult) => { if (view === 'explore') { - exploreRef.current?.focusDiagram(result.id) + const newParams = new URLSearchParams(searchParams) + newParams.set('view', 'explore') + newParams.set('focus', String(result.viewId)) + if (result.type === 'element') { + newParams.set('element', String(result.elementId)) + exploreRef.current?.focusElement(result.viewId, result.elementId) + } else { + newParams.delete('element') + exploreRef.current?.focusDiagram(result.viewId) + } + setSearchParams(newParams, { replace: true }) + } else if (result.type === 'element') { + setFocusedHierarchyId(result.viewId) + navigate(`/views/${result.viewId}?element=${result.elementId}`) } else { - setFocusedHierarchyId(result.id) - navigate(`/views/${result.id}`) + setFocusedHierarchyId(result.viewId) + navigate(`/views/${result.viewId}`) } setSearchResults([]) setActiveSearchIndex(-1) setSearchTerm('') - }, [navigate, view]) + }, [navigate, searchParams, setSearchParams, view]) const handleSearchChange = useCallback((term: string) => { setSearchTerm(term) @@ -471,20 +544,8 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { setActiveSearchIndex(-1) return } - - const normalized = term.trim().toLowerCase() - const matches = flatTree - .filter((n) => n.name.toLowerCase().includes(normalized)) - .slice(0, 5) - - setSearchResults(matches) - if (matches.length > 0) { - setActiveSearchIndex(0) - if (view === 'hierarchy') setFocusedHierarchyId(matches[0].id) - } else { - setActiveSearchIndex(-1) - } - }, [flatTree, view]) + ensureExploreSearchData() + }, [ensureExploreSearchData]) const handleSearchKeyDown = useCallback((e: React.KeyboardEvent) => { if (e.key === 'Escape') { @@ -498,12 +559,12 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { e.preventDefault() const nextIndex = (activeSearchIndex + 1) % searchResults.length setActiveSearchIndex(nextIndex) - if (view === 'hierarchy') setFocusedHierarchyId(searchResults[nextIndex].id) + if (view === 'hierarchy') setFocusedHierarchyId(searchResults[nextIndex].viewId) } else if (e.key === 'ArrowUp') { e.preventDefault() const nextIndex = (activeSearchIndex - 1 + searchResults.length) % searchResults.length setActiveSearchIndex(nextIndex) - if (view === 'hierarchy') setFocusedHierarchyId(searchResults[nextIndex].id) + if (view === 'hierarchy') setFocusedHierarchyId(searchResults[nextIndex].viewId) } else if (e.key === 'Enter' && activeSearchIndex >= 0) { e.preventDefault() commitSearchResult(searchResults[activeSearchIndex]) @@ -514,7 +575,17 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { if (!newName.trim()) return setIsCreating(true) try { 
- const d = await api.workspace.views.create({ name: newName.trim() }) + let d + if (treeData.length > 0) { + // Root view already exists. Create a new element in the root view to own this new diagram. + const name = newName.trim() + const element = await api.workspace.elements.create({ name }) + const root = treeData[0] + await api.workspace.views.placements.add(root.id, element.id, 100, 100) + d = await api.workspace.views.create({ name, parent_view_id: element.id }) + } else { + d = await api.workspace.views.create({ name: newName.trim() }) + } await refreshTree() navigate(`/views/${d.id}`) onCreateClose() @@ -524,7 +595,7 @@ export default function ViewsPage({ shareSlot, onShareView }: Props) { } finally { setIsCreating(false) } - }, [navigate, newName, onCreateClose, refreshTree]) + }, [navigate, newName, onCreateClose, refreshTree, treeData]) if (initializing) { return ( diff --git a/frontend/src/pages/ViewsGrid.tsx b/frontend/src/pages/ViewsGrid.tsx index 96a23a7..41ffda6 100644 --- a/frontend/src/pages/ViewsGrid.tsx +++ b/frontend/src/pages/ViewsGrid.tsx @@ -54,57 +54,59 @@ function flattenTree(roots: ViewTreeNode[]): ViewTreeNode[] { return result } +function filterTreeForGrid(nodes: ViewTreeNode[], allowedIds: Set | null): ViewTreeNode[] { + if (!allowedIds) return nodes + + const visit = (node: ViewTreeNode): ViewTreeNode | null => { + const children = node.children + .map(visit) + .filter((child): child is ViewTreeNode => child !== null) + const include = allowedIds.has(node.id) || (node.parent_view_id === null && children.length > 0) + if (!include) return null + return { ...node, children } + } + + return nodes + .map(visit) + .filter((node): node is ViewTreeNode => node !== null) +} + // ── Layout algorithm ────────────────────────────────────────────────────────── const CELL_W = 260 const CELL_H = 150 const GAP_H = 80 const GAP_V = 120 - -function subtreeWidth(node: ViewTreeNode): number { - if (node.children.length === 0) return 1 - return node.children.reduce((sum, c) => sum + subtreeWidth(c), 0) +const COMPACT_WORKSPACE_THRESHOLD = 32 +const LAYOUT_TRANSITION = 'transform 560ms cubic-bezier(0.16, 1, 0.3, 1), opacity 260ms ease, filter 260ms ease' + +interface GridDisplayNode { + id: string + kind: 'view' | 'cluster' + view: ViewTreeNode + parentId: string | null + depth: number + children: GridDisplayNode[] + collapsedCount?: number + dimmed?: boolean } -function buildDescendantSets(roots: ViewTreeNode[]): Map> { - const map = new Map>() - - function visit(node: ViewTreeNode): Set { - const set = new Set([node.id]) - node.children.forEach((child) => { - const childSet = visit(child) - childSet.forEach((id) => set.add(id)) - }) - map.set(node.id, set) - return set - } - - roots.forEach(visit) - return map +function displaySubtreeWidth(node: GridDisplayNode): number { + if (node.children.length === 0) return 1 + return node.children.reduce((sum, child) => sum + displaySubtreeWidth(child), 0) } /** - * Compute layout positions. - * - * Y-axis: node.depth (= node.level) - honours manual level overrides so a - * diagram at L2 is always rendered in the L2 row even if its parent - * is at L0. - * - * X-axis: column derived from the tree-walk (pre-order rank within each - * level band), then a de-overlap pass shifts any colliding nodes and - * their subtrees rightward so nothing overlaps on the same row. + * Compute layout positions for real cards plus collapsed cluster cards. + * Y follows view depth so manual level overrides still read as horizontal bands. 
*/ -function computeLayout(roots: ViewTreeNode[]): Map { - const positions = new Map() - const flat: ViewTreeNode[] = [] - const visit = (n: ViewTreeNode) => { flat.push(n); n.children.forEach(visit) } - roots.forEach(visit) - - if (flat.length === 0) return positions +function computeDisplayLayout(roots: GridDisplayNode[]): Map { + const positions = new Map() + if (roots.length === 0) return positions + const flat = flattenDisplayTree(roots) - // ── Step 1: initial column assignment via tree walk ───────────────────────── - function layoutNode(node: ViewTreeNode, startCol: number) { - const w = subtreeWidth(node) + function layoutNode(node: GridDisplayNode, startCol: number) { + const w = displaySubtreeWidth(node) const centerCol = startCol + (w - 1) / 2 positions.set(node.id, { x: centerCol * (CELL_W + GAP_H), @@ -113,41 +115,44 @@ function computeLayout(roots: ViewTreeNode[]): Map>() + const collectDescendants = (node: GridDisplayNode): Set => { + const set = new Set([node.id]) + node.children.forEach((child) => { + collectDescendants(child).forEach((id) => set.add(id)) + }) + descendants.set(node.id, set) + return set } - // ── Step 2: build descendant sets so we can shift whole subtrees ──────────── - const descendants = buildDescendantSets(roots) - - // ── Step 3: de-overlap pass - per Y row (top-down), fix X collisions ──────── - const STEP = CELL_W + GAP_H - const byY = new Map() - flat.forEach((n) => { - const y = n.depth * (CELL_H + GAP_V) + roots.forEach(collectDescendants) + + const byY = new Map() + flat.forEach((node) => { + const y = node.depth * (CELL_H + GAP_V) if (!byY.has(y)) byY.set(y, []) - byY.get(y)!.push(n.id) + byY.get(y)!.push(node.id) }) - // Process rows top-down (ascending Y) so parent shifts propagate downward first const sortedYRows = Array.from(byY.entries()).sort(([ya], [yb]) => ya - yb) - + const step = CELL_W + GAP_H for (const [rowY, ids] of sortedYRows) { - // Snapshot original X values before any mutations in this row - - // this prevents a just-shifted node's new position from cascading - // into the next comparison and wrongly pushing correct neighbors right. - const origX = new Map(ids.map((id) => [id, positions.get(id)?.x ?? 0])) + const origX = new Map(ids.map((id) => [id, positions.get(id)?.x ?? 0])) ids.sort((a, b) => (origX.get(a) ?? 0) - (origX.get(b) ?? 0)) - let rightmostX = origX.get(ids[0]) ?? 0 for (let i = 1; i < ids.length; i++) { const originalX = origX.get(ids[i]) ?? 0 - const placedX = Math.max(originalX, rightmostX + STEP) + const placedX = Math.max(originalX, rightmostX + step) if (placedX > originalX) { const delta = placedX - originalX @@ -167,6 +172,150 @@ function computeLayout(roots: ViewTreeNode[]): Map sum + 1 + countDescendantViews(child), 0) +} + +function flattenDisplayTree(roots: GridDisplayNode[]): GridDisplayNode[] { + const result: GridDisplayNode[] = [] + const visit = (node: GridDisplayNode) => { + result.push(node) + node.children.forEach(visit) + } + roots.forEach(visit) + return result +} + +function sumContentCounts( + nodes: ViewTreeNode[], + countsByView: Record +): { nodes: number; edges: number } { + let nodesCount = 0 + let edgesCount = 0 + + const visit = (node: ViewTreeNode) => { + const counts = countsByView[node.id] + nodesCount += counts?.nodes ?? 0 + edgesCount += counts?.edges ?? 
0 + node.children.forEach(visit) + } + + nodes.forEach(visit) + return { nodes: nodesCount, edges: edgesCount } +} + +function buildDisplayTree(roots: ViewTreeNode[], focusedId: number | null): GridDisplayNode[] { + const flat = flattenTree(roots) + if (flat.length <= COMPACT_WORKSPACE_THRESHOLD) { + const convert = (node: ViewTreeNode, parentId: string | null): GridDisplayNode => ({ + id: String(node.id), + kind: 'view', + view: node, + parentId, + depth: node.depth, + children: node.children.map((child) => convert(child, String(node.id))), + }) + return roots.map((root) => convert(root, null)) + } + + const byId = new Map(flat.map((node) => [node.id, node])) + const focusedNode = focusedId ? byId.get(focusedId) ?? null : null + const visible = new Set() + const emphasis = new Set() + + if (!focusedNode) { + flat.forEach((node) => { + if (node.parent_view_id === null || node.depth <= 1) visible.add(node.id) + }) + } else { + let cursor: ViewTreeNode | undefined = focusedNode + while (cursor) { + visible.add(cursor.id) + emphasis.add(cursor.id) + cursor = cursor.parent_view_id ? byId.get(cursor.parent_view_id) : undefined + } + + roots.forEach((root) => visible.add(root.id)) + + const parent = focusedNode.parent_view_id ? byId.get(focusedNode.parent_view_id) : null + const siblings = flat.filter((node) => node.parent_view_id === focusedNode.parent_view_id) + siblings.forEach((node) => { + visible.add(node.id) + emphasis.add(node.id) + }) + + focusedNode.children.forEach((child) => { + visible.add(child.id) + emphasis.add(child.id) + }) + + parent?.children.forEach((child) => visible.add(child.id)) + } + + const makeNode = (node: ViewTreeNode, parentId: string | null): GridDisplayNode | null => { + if (!visible.has(node.id)) return null + + const displayId = String(node.id) + const visibleChildren = node.children + .map((child) => makeNode(child, displayId)) + .filter((child): child is GridDisplayNode => child !== null) + + const hiddenChildren = node.children.filter((child) => !visible.has(child.id)) + const hiddenCount = hiddenChildren.reduce((sum, child) => sum + 1 + countDescendantViews(child), 0) + const cluster: GridDisplayNode[] = hiddenCount > 0 ? [{ + id: `${node.id}:cluster`, + kind: 'cluster', + view: node, + parentId: displayId, + depth: node.depth + 1, + children: [], + collapsedCount: hiddenCount, + dimmed: focusedNode ? !emphasis.has(node.id) : false, + }] : [] + + return { + id: displayId, + kind: 'view', + view: node, + parentId, + depth: node.depth, + children: [...visibleChildren, ...cluster], + dimmed: focusedNode ? !emphasis.has(node.id) : false, + } + } + + return roots + .map((root) => makeNode(root, null)) + .filter((node): node is GridDisplayNode => node !== null) +} + +function buildStableLayoutIds(flat: ViewTreeNode[], focusedId: number | null): Set { + const stable = new Set() + const byId = new Map(flat.map((node) => [node.id, node])) + + if (focusedId === null) { + flat.forEach((node) => { + if (node.parent_view_id === null || node.depth <= 1) stable.add(String(node.id)) + }) + return stable + } + + const focused = byId.get(focusedId) + if (!focused) return stable + + let cursor: ViewTreeNode | undefined = focused + while (cursor) { + stable.add(String(cursor.id)) + cursor = cursor.parent_view_id ? 
byId.get(cursor.parent_view_id) : undefined + } + + flat.forEach((node) => { + if (node.parent_view_id === focused.parent_view_id) stable.add(String(node.id)) + }) + + return stable +} + @@ -387,7 +536,8 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s }, [zoomIn, zoomOut]) // ── Derived tree structures ───────────────────────────────────────────────── - const roots = useMemo(() => treeData, [treeData]) + const [gridViewIds, setGridViewIds] = useState | null>(null) + const roots = useMemo(() => filterTreeForGrid(treeData, gridViewIds), [treeData, gridViewIds]) const flatTree = useMemo(() => flattenTree(roots), [roots]) // Rename @@ -422,28 +572,41 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s } }, [loading, treeData.length]) - // Fetch node/edge counts + // Fetch visible grid cards and node/edge counts in one workspace roundtrip. useEffect(() => { let cancelled = false - const ids = flatTree.map((n) => n.id) - if (ids.length === 0) { setCountsByDiagram({}); return } + if (treeData.length === 0) { + setGridViewIds(new Set()) + setCountsByDiagram({}) + return + } ; (async () => { - const next: Record = {} - await Promise.all( - ids.map(async (id) => { - try { - const [objs, edges] = await Promise.all([ - api.workspace.views.placements.list(id), - api.workspace.connectors.list(id), - ]) - next[id] = { nodes: objs.length, edges: edges.length } - } catch { /* ignore per-diagram failure */ } + try { + const workspace = await api.workspace.views.gridData() + if (cancelled) return + + const visibleIds = new Set(flattenTree(workspace.views).map((view) => view.id)) + const next: Record = {} + + visibleIds.forEach((id) => { + const content = workspace.content[id] + next[id] = { + nodes: content?.placements.length ?? 0, + edges: content?.connectors.length ?? 
0, + } }) - ) - if (!cancelled) setCountsByDiagram((prev) => ({ ...prev, ...next })) + + setGridViewIds(visibleIds) + setCountsByDiagram(next) + } catch { + if (!cancelled) { + setGridViewIds(null) + setCountsByDiagram({}) + } + } })() return () => { cancelled = true } - }, [flatTree]) + }, [treeData]) // ── Rename ────────────────────────────────────────────────────────────────── const startEdit = useCallback((id: number, name: string) => { @@ -545,7 +708,49 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s } // ── RF nodes - pure derivation, no useState/useEffect ─────────────────────── - const layoutPositions = useMemo(() => computeLayout(roots), [roots]) + const displayTree = useMemo( + () => buildDisplayTree(roots, focusedId), + [roots, focusedId] + ) + const displayFlat = useMemo(() => flattenDisplayTree(displayTree), [displayTree]) + const rawLayoutPositions = useMemo(() => computeDisplayLayout(displayTree), [displayTree]) + const stableLayoutIds = useMemo(() => buildStableLayoutIds(flatTree, focusedId), [flatTree, focusedId]) + const previousLayoutRef = useRef>(new Map()) + + const layoutPositions = useMemo(() => { + const next = new Map(rawLayoutPositions) + const previousLayout = previousLayoutRef.current + + if (focusedId !== null) { + const focusedKey = String(focusedId) + const previousFocusedPosition = previousLayout.get(focusedKey) + const nextFocusedPosition = next.get(focusedKey) + + if (previousFocusedPosition && nextFocusedPosition) { + const dx = previousFocusedPosition.x - nextFocusedPosition.x + const dy = previousFocusedPosition.y - nextFocusedPosition.y + + if (dx !== 0 || dy !== 0) { + next.forEach((position, id) => { + next.set(id, { x: position.x + dx, y: position.y + dy }) + }) + } + } + } + + stableLayoutIds.forEach((id) => { + const previousPosition = previousLayout.get(id) + if (previousPosition && next.has(id)) { + next.set(id, previousPosition) + } + }) + + return next + }, [rawLayoutPositions, focusedId, stableLayoutIds]) + + useEffect(() => { + previousLayoutRef.current = layoutPositions + }, [layoutPositions]) // Stable during drag (layoutPositions only changes after treeData refresh, never on mouse moves) const computedMinZoom = useMemo(() => { @@ -599,35 +804,50 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s }, [focusedId, flatTree]) const rfNodes = useMemo((): RFNode[] => - flatTree.map((n): RFNode => ({ - id: String(n.id), - type: 'diagramGrid', - position: layoutPositions.get(n.id) ?? { x: 0, y: 0 }, - data: { - id: n.id, - name: n.name, - level_label: n.level_label, - counts: countsByView[n.id], - focused: focusedId === n.id, - canEdit, - isEditing: editingId === n.id, - editName, - onFocus: () => onFocusChange(n.id), - onOpen: () => navigate(`/views/${n.id}`), - onStartRename: () => startEdit(n.id, n.name), - onDetails: () => handleDetailsOpen(n.id), - onDelete: () => { setDeleteTargetId(n.id); onDeleteOpen() }, - onShare: onShare ? () => onShare(n.id) : () => {}, - onEditNameChange: setEditName, - onEditCommit: commitEdit, - onEditCancel: cancelEdit, - isMobile: isMobileLayout, - wasdKey: wasdTargets[n.id], - } satisfies ViewGridNodeData, - draggable: false, - })), + displayFlat.map((displayNode): RFNode => { + const n = displayNode.view + const isCluster = displayNode.kind === 'cluster' + const hiddenChildren = isCluster + ? 
n.children.filter((child) => !displayFlat.some((visibleNode) => visibleNode.kind === 'view' && visibleNode.view.id === child.id)) + : [] + + return { + id: displayNode.id, + type: 'diagramGrid', + position: layoutPositions.get(displayNode.id) ?? { x: 0, y: 0 }, + data: { + id: n.id, + name: isCluster ? `${n.name} descendants` : n.name, + level_label: isCluster ? 'Collapsed stack' : n.level_label, + counts: isCluster ? sumContentCounts(hiddenChildren, countsByView) : countsByView[n.id], + kind: displayNode.kind, + collapsedCount: displayNode.collapsedCount, + dimmed: displayNode.dimmed, + focused: !isCluster && focusedId === n.id, + canEdit: !isCluster && canEdit, + isEditing: !isCluster && editingId === n.id, + editName, + onFocus: () => onFocusChange(n.id), + onOpen: () => isCluster ? onFocusChange(n.id) : navigate(`/views/${n.id}`), + onStartRename: () => startEdit(n.id, n.name), + onDetails: () => handleDetailsOpen(n.id), + onDelete: () => { setDeleteTargetId(n.id); onDeleteOpen() }, + onShare: onShare ? () => onShare(n.id) : () => {}, + onEditNameChange: setEditName, + onEditCommit: commitEdit, + onEditCancel: cancelEdit, + isMobile: isMobileLayout, + wasdKey: isCluster ? undefined : wasdTargets[n.id], + } satisfies ViewGridNodeData, + draggable: false, + style: { + transition: LAYOUT_TRANSITION, + zIndex: displayNode.kind === 'cluster' ? 1 : focusedId === n.id ? 3 : 2, + }, + } + }), // eslint-disable-next-line react-hooks/exhaustive-deps - [flatTree, layoutPositions, focusedId, countsByView, + [displayFlat, layoutPositions, focusedId, countsByView, editingId, editName, canEdit, navigate, startEdit, handleDetailsOpen, commitEdit, cancelEdit, onDeleteOpen, wasdTargets, levelEditingNodeId] @@ -674,17 +894,17 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s // ── RF edges ──────────────────────────────────────────────────────────────── const rfEdges = useMemo((): RFEdge[] => - flatTree - .filter((n) => n.parent_view_id) + displayFlat + .filter((n) => n.parentId) .map((n) => ({ - id: `${n.parent_view_id}-${n.id}`, - source: String(n.parent_view_id!), - target: String(n.id), + id: `${n.parentId}-${n.id}`, + source: n.parentId!, + target: n.id, type: 'floating', animated: false, - data: { color: HIERARCHY_EDGE_COLOR, dashed: false }, + data: { color: n.kind === 'cluster' ? 
hexToRgba(accent, 0.28) : HIERARCHY_EDGE_COLOR, dashed: n.kind === 'cluster' }, })), - [flatTree] + [displayFlat, accent] ) const allRfEdges = rfEdges @@ -743,7 +963,7 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s // ── Camera: pan to focused node only when it's out of view ────────────────── useEffect(() => { if (!focusedId) return - const pos = layoutPositions.get(focusedId) + const pos = layoutPositions.get(String(focusedId)) if (!pos) return const t = setTimeout(() => { const { x: vpX, y: vpY, zoom } = getViewport() @@ -846,8 +1066,7 @@ function ViewGridInner({ onShare, treeData, loading, focusedId, onFocusChange, s onFocusChange(null) }} style={{ - background: 'var(--bg-canvas)', - boxShadow: 'inset 0 0 100px rgba(0,0,0,0.6)' + background: 'var(--bg-canvas)' }} > {/* Micro dots for high precision technical feel */} diff --git a/frontend/src/pages/viewsJumpSearch.test.ts b/frontend/src/pages/viewsJumpSearch.test.ts new file mode 100644 index 0000000..2452994 --- /dev/null +++ b/frontend/src/pages/viewsJumpSearch.test.ts @@ -0,0 +1,193 @@ +import { describe, expect, it } from 'vitest' +import { + buildJumpSearchResults, + flattenTree, + jumpResultActionLabel, + jumpResultSubtitle, +} from './viewsJumpSearch' +import type { ExploreData, PlacedElement, ViewTreeNode } from '../types' + +function viewNode(id: number, name: string, children: ViewTreeNode[] = [], level = 1): ViewTreeNode { + return { + id, + owner_element_id: null, + name, + description: null, + level_label: level === 1 ? 'System' : null, + level, + depth: level - 1, + created_at: '2024-01-01', + updated_at: '2024-01-01', + parent_view_id: null, + children, + } +} + +function placed(viewId: number, elementId: number, name: string, overrides: Partial = {}): PlacedElement { + return { + id: viewId * 1000 + elementId, + view_id: viewId, + element_id: elementId, + position_x: 0, + position_y: 0, + name, + description: null, + kind: 'service', + technology: null, + url: null, + logo_url: null, + technology_connectors: [], + tags: [], + has_view: false, + view_label: null, + ...overrides, + } +} + +function exploreData(tree: ViewTreeNode[], placementsByView: Record): ExploreData { + return { + tree, + navigations: [], + views: Object.fromEntries( + Object.entries(placementsByView).map(([viewId, placements]) => [ + viewId, + { placements, connectors: [] }, + ]), + ), + } +} + +describe('views jump search', () => { + const tree = [ + viewNode(1, 'Workspace', [ + viewNode(2, 'Payments', [], 2), + viewNode(3, 'Platform Payments', [], 2), + viewNode(4, 'Identity', [], 2), + ]), + ] + const flatTree = flattenTree(tree) + + it('flattens the view tree in navigation order', () => { + expect(flatTree.map((node) => node.name)).toEqual([ + 'Workspace', + 'Payments', + 'Platform Payments', + 'Identity', + ]) + }) + + it('returns view results without requiring workspace placement data', () => { + const results = buildJumpSearchResults('payments', flatTree, null) + + expect(results).toEqual([ + { + type: 'view', + key: 'view:2', + name: 'Payments', + viewId: 2, + level: 2, + levelLabel: null, + }, + { + type: 'view', + key: 'view:3', + name: 'Platform Payments', + viewId: 3, + level: 2, + levelLabel: null, + }, + ]) + }) + + it('builds element navigation results from names and metadata fields', () => { + const data = exploreData(tree, { + 2: [ + placed(2, 201, 'Checkout API', { + kind: 'api', + technology: 'Node', + file_path: 'services/checkout/index.ts', + tags: ['critical-path'], + }), + ], + 4: [ + 
placed(4, 401, 'Session Store', { + kind: 'database', + technology: 'Redis', + tags: ['auth'], + }), + ], + }) + + const byName = buildJumpSearchResults('checkout', flatTree, data) + expect(byName).toContainEqual({ + type: 'element', + key: 'element:2:201', + name: 'Checkout API', + viewId: 2, + viewName: 'Payments', + elementId: 201, + kind: 'api', + }) + + const byTag = buildJumpSearchResults('critical', flatTree, data) + expect(byTag.map((result) => result.key)).toEqual(['element:2:201']) + + const byTechnology = buildJumpSearchResults('redis', flatTree, data) + expect(byTechnology.map((result) => result.key)).toEqual(['element:4:401']) + + const byPath = buildJumpSearchResults('services/checkout', flatTree, data) + expect(byPath.map((result) => result.key)).toEqual(['element:2:201']) + }) + + it('keeps same-element placements in different views as distinct navigation targets', () => { + const data = exploreData(tree, { + 2: [placed(2, 501, 'Shared Logger')], + 3: [placed(3, 501, 'Shared Logger')], + }) + + const results = buildJumpSearchResults('shared logger', flatTree, data) + + expect(results).toEqual([ + expect.objectContaining({ type: 'element', key: 'element:2:501', viewId: 2, elementId: 501 }), + expect.objectContaining({ type: 'element', key: 'element:3:501', viewId: 3, elementId: 501 }), + ]) + }) + + it('caps result groups while preserving view-first ordering', () => { + const manyViews = [ + viewNode(10, 'Alpha Root', [ + viewNode(11, 'Alpha Billing', [], 2), + viewNode(12, 'Alpha Catalog', [], 2), + viewNode(13, 'Alpha Checkout', [], 2), + viewNode(14, 'Alpha Delivery', [], 2), + viewNode(15, 'Alpha Events', [], 2), + ]), + ] + const manyFlatTree = flattenTree(manyViews) + const data = exploreData(manyViews, { + 11: Array.from({ length: 8 }, (_, index) => placed(11, 700 + index, `Alpha Element ${index}`)), + }) + + const results = buildJumpSearchResults('alpha', manyFlatTree, data) + + expect(results).toHaveLength(8) + expect(results.filter((result) => result.type === 'view')).toHaveLength(4) + expect(results.filter((result) => result.type === 'element')).toHaveLength(4) + expect(results.slice(0, 4).every((result) => result.type === 'view')).toBe(true) + }) + + it('formats result hints used by the toolbar', () => { + expect(buildJumpSearchResults('id', flatTree, null)).toEqual([]) + expect(jumpResultActionLabel('explore')).toBe('ZOOM') + expect(jumpResultActionLabel('hierarchy')).toBe('OPEN') + + const [viewResult] = buildJumpSearchResults('payments', flatTree, null) + expect(jumpResultSubtitle(viewResult)).toBe('Level 2 • Diagram') + + const data = exploreData(tree, { + 4: [placed(4, 401, 'Session Store', { kind: null })], + }) + const [elementResult] = buildJumpSearchResults('session', flatTree, data) + expect(jumpResultSubtitle(elementResult)).toBe('Element • Identity') + }) +}) diff --git a/frontend/src/pages/viewsJumpSearch.ts b/frontend/src/pages/viewsJumpSearch.ts new file mode 100644 index 0000000..04771e6 --- /dev/null +++ b/frontend/src/pages/viewsJumpSearch.ts @@ -0,0 +1,111 @@ +import type { ExploreData, PlacedElement, ViewTreeNode } from '../types' + +export type JumpViewMode = 'explore' | 'hierarchy' + +export type JumpSearchResult = + | { + type: 'view' + key: string + name: string + viewId: number + level: number + levelLabel: string | null + } + | { + type: 'element' + key: string + name: string + viewId: number + viewName: string + elementId: number + kind: string | null + } + +export function flattenTree(roots: ViewTreeNode[]): ViewTreeNode[] { + const 
result: ViewTreeNode[] = [] + const traverse = (node: ViewTreeNode) => { + result.push(node) + node.children.forEach(traverse) + } + roots.forEach(traverse) + return result +} + +export function jumpResultSubtitle(result: JumpSearchResult): string { + if (result.type === 'view') { + return `Level ${result.level} • ${result.levelLabel || 'Diagram'}` + } + return `${result.kind || 'Element'} • ${result.viewName}` +} + +export function jumpResultActionLabel(view: JumpViewMode): string { + if (view === 'explore') return 'ZOOM' + return 'OPEN' +} + +function searchScore(value: string, normalizedTerm: string): number { + const normalizedValue = value.toLowerCase() + if (normalizedValue === normalizedTerm) return 0 + if (normalizedValue.startsWith(normalizedTerm)) return 1 + return normalizedValue.includes(normalizedTerm) ? 2 : 3 +} + +function placementMatches(placement: PlacedElement, normalizedTerm: string): boolean { + return [ + placement.name, + placement.kind, + placement.technology, + placement.file_path, + ...(placement.tags ?? []), + ] + .filter(Boolean) + .some((value) => String(value).toLowerCase().includes(normalizedTerm)) +} + +export function buildJumpSearchResults(term: string, flatTree: ViewTreeNode[], exploreData: ExploreData | null): JumpSearchResult[] { + const normalized = term.trim().toLowerCase() + if (normalized.length < 3) return [] + + const viewById = new Map(flatTree.map((node) => [node.id, node])) + const viewResults: JumpSearchResult[] = flatTree + .filter((node) => node.name.toLowerCase().includes(normalized)) + .sort((a, b) => searchScore(a.name, normalized) - searchScore(b.name, normalized) || a.name.localeCompare(b.name)) + .slice(0, 4) + .map((node) => ({ + type: 'view', + key: `view:${node.id}`, + name: node.name, + viewId: node.id, + level: node.level, + levelLabel: node.level_label, + })) + + const elementResults: JumpSearchResult[] = [] + if (exploreData) { + Object.entries(exploreData.views ?? {}).forEach(([viewIdText, content]) => { + const viewId = Number(viewIdText) + if (!Number.isFinite(viewId)) return + const viewName = viewById.get(viewId)?.name ?? `View ${viewId}` + ;(content.placements ?? 
[]).forEach((placement) => { + if (!placementMatches(placement, normalized)) return + elementResults.push({ + type: 'element', + key: `element:${viewId}:${placement.element_id}`, + name: placement.name || `Element ${placement.element_id}`, + viewId, + viewName, + elementId: placement.element_id, + kind: placement.kind, + }) + }) + }) + } + + const dedupedElements = Array.from( + new Map(elementResults.map((result) => [result.key, result])).values(), + ) + .sort((a, b) => searchScore(a.name, normalized) - searchScore(b.name, normalized) || a.name.localeCompare(b.name)) + .slice(0, 6) + + return [...viewResults, ...dedupedElements].slice(0, 8) +} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 3e28a4d..c0cfd1e 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -3,6 +3,7 @@ export interface TechnologyConnector { slug?: string label: string is_primary_icon?: boolean + isPrimaryIcon?: boolean } export interface TechnologyCatalogItem { @@ -82,6 +83,15 @@ export interface PlacedElement { view_label: string | null } +export interface VisibilityOverride { + view_id: number + resource_type: 'element' | 'connector' + resource_id: number + level_delta: number + created_at?: string + updated_at?: string +} + export interface NavigationConnector { id: number element_id: number | null diff --git a/frontend/src/utils/elementIcon.test.ts b/frontend/src/utils/elementIcon.test.ts new file mode 100644 index 0000000..773ee5c --- /dev/null +++ b/frontend/src/utils/elementIcon.test.ts @@ -0,0 +1,22 @@ +import { describe, expect, it } from 'vitest' +import { resolveElementIconUrl } from './elementIcon' + +describe('resolveElementIconUrl', () => { + it('uses an explicit logo url before derived technology icons', () => { + expect(resolveElementIconUrl('/custom.svg', [ + { type: 'catalog', slug: 'golang', label: 'Go', is_primary_icon: true }, + ])).toBe('/custom.svg') + }) + + it('derives the selected catalog technology icon when logo_url is missing', () => { + expect(resolveElementIconUrl(null, [ + { type: 'catalog', slug: 'golang', label: 'Go', is_primary_icon: true }, + ])).toBe('/icons/golang.png') + }) + + it('preserves explicit no-icon clears instead of falling back to technology', () => { + expect(resolveElementIconUrl('', [ + { type: 'catalog', slug: 'golang', label: 'Go', is_primary_icon: true }, + ])).toBeNull() + }) +}) diff --git a/frontend/src/utils/elementIcon.ts b/frontend/src/utils/elementIcon.ts new file mode 100644 index 0000000..1cc97eb --- /dev/null +++ b/frontend/src/utils/elementIcon.ts @@ -0,0 +1,19 @@ +import type { TechnologyConnector } from '../types' +import { resolveIconPath } from './url' + +export function resolveElementIconUrl( + logoUrl: string | null | undefined, + technologyConnectors: TechnologyConnector[] | null | undefined, +): string | null { + if (logoUrl != null) { + return logoUrl === '' ? null : resolveIconPath(logoUrl) + } + + const selected = technologyConnectors?.find((link) => ( + link.type === 'catalog' && + !!(link.is_primary_icon ?? 
link.isPrimaryIcon) && + !!link.slug + )) + if (!selected?.slug) return null + return resolveIconPath(`/icons/${selected.slug}.png`) +} diff --git a/frontend/src/utils/sourceEditor.ts b/frontend/src/utils/sourceEditor.ts new file mode 100644 index 0000000..14ebabc --- /dev/null +++ b/frontend/src/utils/sourceEditor.ts @@ -0,0 +1,46 @@ +import { useEffect, useState } from 'react' +import type { SourceEditor } from '../api/client' + +const SOURCE_EDITOR_KEY = 'diag:source-editor' +const DEFAULT_SOURCE_EDITOR: SourceEditor = 'zed' + +function readSourceEditor(): SourceEditor { + if (typeof window === 'undefined') return DEFAULT_SOURCE_EDITOR + const stored = window.localStorage.getItem(SOURCE_EDITOR_KEY) + return stored === 'vscode' || stored === 'zed' ? stored : DEFAULT_SOURCE_EDITOR +} + +export function getSourceEditor(): SourceEditor { + return readSourceEditor() +} + +export function setSourceEditor(value: SourceEditor) { + window.localStorage.setItem(SOURCE_EDITOR_KEY, value) + window.dispatchEvent(new CustomEvent('diag:source-editor-change', { detail: value })) +} + +export function useSourceEditor() { + const [editor, setEditorState] = useState(() => readSourceEditor()) + + useEffect(() => { + const handleStorage = (event: StorageEvent) => { + if (event.key === SOURCE_EDITOR_KEY) { + setEditorState(readSourceEditor()) + } + } + const handleChange = () => setEditorState(readSourceEditor()) + window.addEventListener('storage', handleStorage) + window.addEventListener('diag:source-editor-change', handleChange) + return () => { + window.removeEventListener('storage', handleStorage) + window.removeEventListener('diag:source-editor-change', handleChange) + } + }, []) + + const setEditor = (value: SourceEditor) => { + setSourceEditor(value) + setEditorState(value) + } + + return { editor, setEditor } +} diff --git a/frontend/src/utils/watchDiffSummary.ts b/frontend/src/utils/watchDiffSummary.ts new file mode 100644 index 0000000..cd0bf80 --- /dev/null +++ b/frontend/src/utils/watchDiffSummary.ts @@ -0,0 +1,159 @@ +import type { Connector, ExploreData, ViewTreeNode } from '../types' +import type { WatchDiff } from '../api/client' + +export type WatchChangeType = 'added' | 'updated' | 'deleted' | 'initialized' + +export interface WatchResourceStat { + added: number + updated: number + deleted: number + initialized: number + addedLines: number + removedLines: number +} + +export interface WatchDiffLocation { + key: string + label: string + resourceType: string + resourceId?: number + changeType: WatchChangeType + summary?: string + addedLines: number + removedLines: number + viewId: number + viewName: string +} + +export interface WatchDiffSummary { + files: WatchResourceStat + elements: WatchResourceStat + connectors: WatchResourceStat +} + +export function normalizeWatchChangeType(value: string): WatchChangeType { + if (value === 'added' || value === 'updated' || value === 'deleted' || value === 'initialized') return value + return 'updated' +} + +export function emptyWatchResourceStat(): WatchResourceStat { + return { added: 0, updated: 0, deleted: 0, initialized: 0, addedLines: 0, removedLines: 0 } +} + +export function summarizeWatchDiffs(diffs: WatchDiff[] | null | undefined): WatchDiffSummary { + const summary = { + files: emptyWatchResourceStat(), + elements: emptyWatchResourceStat(), + connectors: emptyWatchResourceStat(), + } + ;(Array.isArray(diffs) ? diffs : []).forEach((diff) => { + const bucket = + diff.resource_type === 'file' || diff.owner_type === 'file' + ? 
summary.files + : diff.resource_type === 'element' + ? summary.elements + : diff.resource_type === 'connector' + ? summary.connectors + : null + if (!bucket) return + bucket[normalizeWatchChangeType(diff.change_type)] += 1 + bucket.addedLines += Math.max(0, diff.added_lines ?? 0) + bucket.removedLines += Math.max(0, diff.removed_lines ?? 0) + }) + return summary +} + +export function formatStatLine(label: string, stat: WatchResourceStat): string { + const total = stat.added + stat.updated + stat.deleted + stat.initialized + const parts = [`${total} ${label}${total === 1 ? '' : 's'} changed`] + if (stat.addedLines > 0) parts.push(`+${stat.addedLines}`) + if (stat.removedLines > 0) parts.push(`-${stat.removedLines}`) + return parts.join(', ') +} + +export function formatTldStatLine(summary: WatchDiffSummary): string { + return [ + formatStatLine('element', summary.elements), + formatStatLine('connector', summary.connectors), + ].join(' · ') +} + +function flattenViews(nodes: ViewTreeNode[], out = new Map()): Map { + nodes.forEach((node) => { + out.set(node.id, node) + flattenViews(node.children ?? [], out) + }) + return out +} + +function connectorName(connector: Connector): string { + return connector.label || connector.relationship || `connector ${connector.id}` +} + +export function buildWatchDiffLocations(data: ExploreData, diffs: WatchDiff[] | null | undefined): WatchDiffLocation[] { + const views = flattenViews(data.tree ?? []) + const elementViews = new Map() + const connectorViews = new Map() + + Object.entries(data.views ?? {}).forEach(([viewIdText, content]) => { + const viewId = Number(viewIdText) + if (!Number.isFinite(viewId)) return + const viewName = views.get(viewId)?.name ?? `View ${viewId}` + ;(content.placements ?? []).forEach((placement) => { + const list = elementViews.get(placement.element_id) ?? [] + list.push({ + key: `element:${placement.element_id}:${viewId}`, + label: placement.name || `element ${placement.element_id}`, + resourceType: 'element', + resourceId: placement.element_id, + changeType: 'updated', + addedLines: 0, + removedLines: 0, + viewId, + viewName, + }) + elementViews.set(placement.element_id, list) + }) + ;(content.connectors ?? []).forEach((connector) => { + connectorViews.set(connector.id, { + key: `connector:${connector.id}:${viewId}`, + label: connectorName(connector), + resourceType: 'connector', + resourceId: connector.id, + changeType: 'updated', + addedLines: 0, + removedLines: 0, + viewId, + viewName, + }) + }) + }) + + const locations: WatchDiffLocation[] = [] + ;(Array.isArray(diffs) ? diffs : []).forEach((diff) => { + if (!diff.resource_id) return + const base = { + changeType: normalizeWatchChangeType(diff.change_type), + summary: diff.summary, + addedLines: Math.max(0, diff.added_lines ?? 0), + removedLines: Math.max(0, diff.removed_lines ?? 0), + } + if (diff.resource_type === 'element') { + ;(elementViews.get(diff.resource_id) ?? 
[]).forEach((location) => { + locations.push({ ...location, ...base }) + }) + } + if (diff.resource_type === 'connector') { + const location = connectorViews.get(diff.resource_id) + if (location) locations.push({ ...location, ...base }) + } + }) + + const seen = new Set() + return locations.filter((location) => { + const key = `${location.resourceType}:${location.resourceId}:${location.viewId}` + if (seen.has(key)) return false + seen.add(key) + return true + }) +} diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index 9bc1f75..454016b 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -2,7 +2,11 @@ "compilerOptions": { "target": "ES2020", "useDefineForClassFields": true, - "lib": ["ES2020", "DOM", "DOM.Iterable"], + "lib": [ + "ES2020", + "DOM", + "DOM.Iterable" + ], "module": "ESNext", "skipLibCheck": true, "moduleResolution": "bundler", @@ -17,6 +21,13 @@ "noFallthroughCasesInSwitch": true, "ignoreDeprecations": "5.0", }, - "include": ["src/main.tsx", "src/**/*.d.ts"], - "references": [{ "path": "./tsconfig.node.json" }] + "include": [ + "src/main.tsx", + "src/**/*.d.ts" + ], + "references": [ + { + "path": "./tsconfig.node.json" + } + ] } diff --git a/frontend/vite.config.ts b/frontend/vite.config.ts index 7e109bd..4b2a6f0 100644 --- a/frontend/vite.config.ts +++ b/frontend/vite.config.ts @@ -39,7 +39,14 @@ function iconsAliasPlugin() { } export default defineConfig(async () => { - const plugins: Plugin[] = [react(), tsconfigPaths(), iconsAliasPlugin()]; + const plugins: Plugin[] = [ + react(), + tsconfigPaths({ + projects: [fileURLToPath(new URL("./tsconfig.json", import.meta.url))], + ignoreConfigErrors: true, + }), + iconsAliasPlugin(), + ]; return { plugins, diff --git a/frontend/vite.lib.config.ts b/frontend/vite.lib.config.ts index d2e6120..92003f9 100644 --- a/frontend/vite.lib.config.ts +++ b/frontend/vite.lib.config.ts @@ -60,8 +60,24 @@ const EXTERNAL_PACKAGES = new Set([ 'mermaid-ast', ]) +const localProtoGenDir = process.env.TLD_LOCAL_PROTO_GEN +const preserveDistOnRebuild = process.env.TLD_PRESERVE_DIST === '1' + export default defineConfig({ - plugins: [react(), tsconfigPaths()] as Plugin[], + plugins: [ + react(), + tsconfigPaths({ + projects: [resolve(__dirname, 'tsconfig.json')], + ignoreConfigErrors: true, + }), + ] as Plugin[], + resolve: localProtoGenDir + ? { + alias: { + '@buf/tldiagramcom_diagram.bufbuild_es': resolve(__dirname, localProtoGenDir), + }, + } + : undefined, build: { lib: { entry: resolve(__dirname, 'src/index.ts'), @@ -83,7 +99,7 @@ export default defineConfig({ }, // Emit a single CSS file alongside the JS bundle cssCodeSplit: false, - // Keep output clean - emptyOutDir: true, + // For watch-mode local dev, keep the previous bundle in place until the new one is written. 
+ emptyOutDir: !preserveDistOnRebuild, }, }) diff --git a/go.mod b/go.mod index 265bc81..47e8f55 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,20 @@ module github.com/mertcikla/tld go 1.26.1 require ( - buf.build/gen/go/tldiagramcom/diagram/connectrpc/go v1.19.2-20260424214817-528c19e30457.1 - buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go v1.36.11-20260424214817-528c19e30457.1 + buf.build/gen/go/tldiagramcom/diagram/connectrpc/go v1.19.2-20260503002426-45e3166b5ec1.1 + buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go v1.36.11-20260503002426-45e3166b5ec1.1 connectrpc.com/connect v1.19.2 github.com/bmatcuk/doublestar/v4 v4.6.1 + github.com/fsnotify/fsnotify v1.10.0 github.com/google/uuid v1.6.0 github.com/modelcontextprotocol/go-sdk v1.5.0 github.com/odvcencio/gotreesitter v0.15.2 + github.com/openai/openai-go v1.12.0 github.com/schollz/progressbar/v3 v3.19.0 github.com/speps/go-hashids/v2 v2.0.1 github.com/spf13/cobra v1.9.1 github.com/tetratelabs/wazero v1.11.0 + github.com/viant/sqlite-vec v0.3.0 go.lsp.dev/jsonrpc2 v0.10.0 go.lsp.dev/protocol v0.12.0 go.lsp.dev/uri v0.3.0 @@ -25,6 +28,7 @@ require ( ) require ( + github.com/compose-spec/compose-go/v2 v2.10.2 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/jsonschema-go v0.4.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -37,6 +41,11 @@ require ( github.com/segmentio/encoding v0.5.4 // indirect github.com/spf13/pflag v1.0.6 // indirect github.com/stretchr/testify v1.10.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/viant/vec v0.2.3 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.lsp.dev/pkg v0.0.0-20210717090340-384b27a52fb2 // indirect go.uber.org/atomic v1.9.0 // indirect diff --git a/go.sum b/go.sum index d7c4876..3ccb030 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -buf.build/gen/go/tldiagramcom/diagram/connectrpc/go v1.19.2-20260424214817-528c19e30457.1 h1:BPDWZZwgEG0aoW1KI/zF1obCZwVSjkXe0JKS223DIZk= -buf.build/gen/go/tldiagramcom/diagram/connectrpc/go v1.19.2-20260424214817-528c19e30457.1/go.mod h1:UvrxmE9Tap7TrdjeiJ8kv7WIV5cFMO4Anso3GU9kEjc= -buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go v1.36.11-20260424214817-528c19e30457.1 h1:rOe5hscOJW22MFv6d8qwfzopMwIhJoQFvxgGtTiuSy8= -buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go v1.36.11-20260424214817-528c19e30457.1/go.mod h1:LAKkY9RBJgkdENcD9Hq6EBmfTMH7U6xhhQOu1M6sCHM= +buf.build/gen/go/tldiagramcom/diagram/connectrpc/go v1.19.2-20260503002426-45e3166b5ec1.1 h1:Znp6fDsb1gFginIssNzEDnG0FDHgeutEP8S13zHfV00= +buf.build/gen/go/tldiagramcom/diagram/connectrpc/go v1.19.2-20260503002426-45e3166b5ec1.1/go.mod h1:j+AQJPxRjfDkp1bqTIDBpSml5jW2oA1l4/Q8EhE7hwE= +buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go v1.36.11-20260503002426-45e3166b5ec1.1 h1:s60RxXzkKC7fubKCyeW3UkrA4XJxKXP1t5iNfq03f3k= +buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go v1.36.11-20260503002426-45e3166b5ec1.1/go.mod h1:LAKkY9RBJgkdENcD9Hq6EBmfTMH7U6xhhQOu1M6sCHM= connectrpc.com/connect v1.19.2 h1:McQ83FGdzL+t60peksi0gXC7MQ/iLKgLduAnThbM0mo= connectrpc.com/connect v1.19.2/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= @@ -10,12 +10,16 @@ github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwN 
github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= github.com/chengxilo/virtualterm v1.0.4 h1:Z6IpERbRVlfB8WkOmtbHiDbBANU7cimRIof7mk9/PwM= github.com/chengxilo/virtualterm v1.0.4/go.mod h1:DyxxBZz/x1iqJjFxTFcr6/x+jSpqN0iwWCOK1q10rlY= +github.com/compose-spec/compose-go/v2 v2.10.2 h1:USa1NUbDcl/cjb8T9iwnuFsnO79H+2ho2L5SjFKz3uI= +github.com/compose-spec/compose-go/v2 v2.10.2/go.mod h1:ZU6zlcweCZKyiB7BVfCizQT9XmkEIMFE+PRZydVcsZg= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fsnotify/fsnotify v1.10.0 h1:Xx/5Ydg9CeBDX/wi4VJqStNtohYjitZhhlHt4h3St1M= +github.com/fsnotify/fsnotify v1.10.0/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -48,6 +52,8 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/odvcencio/gotreesitter v0.15.2 h1:XDNuSI1Gg738HGK62Mc06tornEfdteOTtLlpME0F5y4= github.com/odvcencio/gotreesitter v0.15.2/go.mod h1:Sx+iYJBfw5xSWkSttLSuFvguJctlH+ma1BTxZ0MPCqo= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= +github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -76,6 +82,20 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA= github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/viant/sqlite-vec v0.3.0 
h1:D0wCrkJ0KsO3sanmHV/m+58UwCZMFxU0wcK8LsYk00o= +github.com/viant/sqlite-vec v0.3.0/go.mod h1:SA89LGdU/cxpc/gsvat2MYtJYrv3bg9mtW/uHMs2nBs= +github.com/viant/vec v0.2.3 h1:NMWW1WtBXJ3Q47LHMGrXZAb6pVL3MjJfWVcEMD8V2t8= +github.com/viant/vec v0.2.3/go.mod h1:d1coA6/d5WBePJe0nDhgE7aRYkMd7CMiRWSNid3tvds= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/internal/analyzer/cpp_tree_sitter.go b/internal/analyzer/cpp_tree_sitter.go index d2531ba..4504686 100644 --- a/internal/analyzer/cpp_tree_sitter.go +++ b/internal/analyzer/cpp_tree_sitter.go @@ -20,6 +20,7 @@ func (p *cppParser) ParseFile(ctx context.Context, path string, source []byte) ( result := &Result{} root := parsed.tree.RootNode() p.walkNode(root, parsed.lang, source, path, "", result) + p.appendTopLevelFunctionDeclarations(source, path, result) return result, nil } @@ -90,15 +91,12 @@ func (p *cppParser) appendFunction(node *gotreesitter.Node, lang *gotreesitter.L } func (p *cppParser) appendMemberDeclaration(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path, parent string, result *Result) { - if parent == "" { - return - } declarator := childByFieldName(node, lang, "declarator") - if !cppHasFunctionDeclarator(declarator, lang) { - return - } name, owner := cppFunctionInfo(declarator, source) - if name == "" { + if name == "" && parent == "" { + name, owner = cppDeclarationInfo(node, source) + } + if name == "" || (!cppHasFunctionDeclarator(declarator, lang) && parent != "") { return } owner = cppResolveOwner(owner, parent) @@ -140,6 +138,195 @@ func cppFunctionInfo(declarator *gotreesitter.Node, source []byte) (string, stri return cppFunctionName(text), cppFunctionOwner(text) } +func cppDeclarationInfo(node *gotreesitter.Node, source []byte) (string, string) { + if node == nil { + return "", "" + } + text := strings.TrimSpace(nodeText(node, source)) + if !cppLooksLikeTopLevelFunctionDeclaration(text) { + return "", "" + } + return cppDeclarationInfoFromText(text) +} + +func cppLooksLikeTopLevelFunctionDeclaration(text string) bool { + text = strings.TrimSpace(text) + if text == "" || !strings.HasSuffix(text, ";") || !strings.Contains(text, "(") || strings.Contains(text, "=") || strings.Contains(text, "(*") { + return false + } + if beforeCall := cppBeforeCall(text); strings.Contains(beforeCall, "{") || strings.Contains(beforeCall, "}") { + return false + } + lower := strings.ToLower(text) + for _, prefix := range []string{"typedef ", "if ", "if(", "for ", "for(", "while ", "while(", "switch ", "switch(", "return "} { + if strings.HasPrefix(lower, prefix) { + return false + } + } + return true +} + +func (p *cppParser) appendTopLevelFunctionDeclarations(source []byte, path string, result *Result) { + seen := make(map[string]struct{}, len(result.Symbols)) + for _, sym := range result.Symbols { + seen[fmt.Sprintf("%s:%s:%d", sym.Kind, sym.Name, sym.Line)] = struct{}{} + } + + depth := 0 + inSingleComment := false + inMultiComment := false + inString := false + inChar := false + escapeNext := false + + var currentDecl strings.Builder + declLine := 0 + lineNum := 1 + + for i := 0; i < len(source); i++ { + c := source[i] + + if c == '\n' { + lineNum++ + inSingleComment = false + escapeNext = false + if depth == 0 && currentDecl.Len() > 0 && currentDecl.String()[currentDecl.Len()-1] != 
' ' { + currentDecl.WriteByte(' ') + } + continue + } + + if escapeNext { + escapeNext = false + continue + } + + if inSingleComment { + continue + } + + if inMultiComment { + if c == '*' && i+1 < len(source) && source[i+1] == '/' { + inMultiComment = false + i++ + } + continue + } + + if inString { + switch c { + case '\\': + escapeNext = true + case '"': + inString = false + } + continue + } + + if inChar { + switch c { + case '\\': + escapeNext = true + case '\'': + inChar = false + } + continue + } + + if c == '/' && i+1 < len(source) { + if source[i+1] == '/' { + inSingleComment = true + i++ + continue + } else if source[i+1] == '*' { + inMultiComment = true + i++ + continue + } + } + + if c == '"' { + inString = true + continue + } + + if c == '\'' { + inChar = true + continue + } + + if c == '{' { + depth++ + currentDecl.Reset() + continue + } + + if c == '}' { + depth-- + if depth < 0 { + depth = 0 + } + currentDecl.Reset() + continue + } + + if depth == 0 { + isSpace := c == ' ' || c == '\t' || c == '\r' + if currentDecl.Len() == 0 && !isSpace { + declLine = lineNum + } + + if isSpace { + if currentDecl.Len() > 0 && currentDecl.String()[currentDecl.Len()-1] != ' ' { + currentDecl.WriteByte(' ') + } + } else { + currentDecl.WriteByte(c) + } + + if c == ';' { + declStr := currentDecl.String() + if cppLooksLikeTopLevelFunctionDeclaration(declStr) { + name, owner := cppDeclarationInfoFromText(declStr) + if name != "" { + kind := cppFunctionKind(name, owner) + key := fmt.Sprintf("%s:%s:%d", kind, name, declLine) + if _, ok := seen[key]; !ok { + if cppHasDeclarationSymbol(result.Symbols, kind, name, owner, declLine, lineNum) { + currentDecl.Reset() + continue + } + result.Symbols = append(result.Symbols, Symbol{ + Name: name, + Kind: kind, + FilePath: path, + Line: declLine, + EndLine: lineNum, + Parent: owner, + }) + seen[key] = struct{}{} + } + } + } + currentDecl.Reset() + } + } + } +} + +func cppHasDeclarationSymbol(symbols []Symbol, kind, name, owner string, startLine, endLine int) bool { + for _, sym := range symbols { + if sym.Kind == kind && sym.Name == name && sym.Parent == owner && sym.Line >= startLine && sym.Line <= endLine { + return true + } + } + return false +} + +func cppDeclarationInfoFromText(text string) (string, string) { + return cppFunctionName(text), cppFunctionOwner(text) +} + func cppHasFunctionDeclarator(node *gotreesitter.Node, lang *gotreesitter.Language) bool { if node == nil { return false diff --git a/internal/analyzer/cpp_tree_sitter_test.go b/internal/analyzer/cpp_tree_sitter_test.go new file mode 100644 index 0000000..c1576c3 --- /dev/null +++ b/internal/analyzer/cpp_tree_sitter_test.go @@ -0,0 +1,113 @@ +package analyzer + +import ( + "context" + "testing" +) + +func TestCPPParser_TopLevelFunctionDeclarations(t *testing.T) { + parser := &cppParser{} + source := `#ifndef SERVICE_H +#define SERVICE_H + +UV_EXTERN void uv_sleep(unsigned int msec); +int helper(int value); + +#endif +` + result, err := parser.ParseFile(context.Background(), "service.h", []byte(source)) + if err != nil { + t.Fatalf("ParseFile: %v", err) + } + + symbols := map[string]Symbol{} + for _, sym := range result.Symbols { + symbols[sym.Name] = sym + } + for _, want := range []string{"uv_sleep", "helper"} { + sym, ok := symbols[want] + if !ok { + t.Fatalf("missing top-level declaration %q in symbols: %+v", want, result.Symbols) + } + if sym.Kind != "function" || sym.Parent != "" { + t.Fatalf("%s = kind %q parent %q, want top-level function", want, sym.Kind, sym.Parent) + } + } +} + 
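// Illustrative sketch (not part of this diff): a table-driven check of how
// cppLooksLikeTopLevelFunctionDeclaration, added in cpp_tree_sitter.go above,
// is expected to classify a few top-level lines. The helper name is taken from
// the diff; the sample inputs and the test name are illustrative assumptions.
func TestCPPLooksLikeTopLevelFunctionDeclaration_Sketch(t *testing.T) {
	cases := map[string]bool{
		"void uv_sleep(unsigned int msec);": true,  // prototype: has "(", ends with ";"
		"typedef void (*cb)(int);":          false, // "typedef " prefix and "(*" are both rejected
		"int x = compute();":                false, // any "=" marks it as an assignment, not a prototype
		"return helper(1);":                 false, // control-flow style prefixes are rejected
	}
	for line, want := range cases {
		if got := cppLooksLikeTopLevelFunctionDeclaration(line); got != want {
			t.Errorf("cppLooksLikeTopLevelFunctionDeclaration(%q) = %v, want %v", line, got, want)
		}
	}
}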
+func TestCPPParser_ComplexFunctionDeclarations(t *testing.T) { + parser := &cppParser{} + source := ` +// Single line +void foo(); + +/* + Multi-line +*/ +int +bar( + int x, + char* y +); + +// Declaration with comment inside +void baz( /* inline comment */ int z ); + +// String containing semicolon and braces +const char* s = "void fake(); { }"; + +// Braces in comment +/* { */ void depth_test(); /* } */ + +// Multiline string +const char* ms = "multi \ +line \ +string"; + +void after_multiline_string(); + +void overload(int x); +void overload(double x); +` + result, err := parser.ParseFile(context.Background(), "complex.h", []byte(source)) + if err != nil { + t.Fatalf("ParseFile: %v", err) + } + + symbols := map[string]Symbol{} + counts := map[string]int{} + for _, sym := range result.Symbols { + symbols[sym.Name] = sym + counts[sym.Name]++ + } + + wants := []struct { + name string + line int + }{ + {"foo", 3}, + {"bar", 8}, // Fallback scanner finds it at start of declaration + {"baz", 15}, + {"depth_test", 21}, + {"after_multiline_string", 28}, + } + + for _, want := range wants { + sym, ok := symbols[want.name] + if !ok { + t.Errorf("missing symbol %q", want.name) + continue + } + if counts[want.name] != 1 { + t.Errorf("%s count = %d, want 1", want.name, counts[want.name]) + } + // Tree-Sitter and fallback scanner might find same symbol on different lines (start of decl vs start of declarator) + // We allow both for this test as long as they are close. + if sym.Line != want.line && sym.Line != want.line+1 { + t.Errorf("%s line = %d, want %d (or %d)", want.name, sym.Line, want.line, want.line+1) + } + } + if counts["overload"] != 2 { + t.Errorf("overload count = %d, want 2", counts["overload"]) + } +} diff --git a/internal/analyzer/go_tree_sitter.go b/internal/analyzer/go_tree_sitter.go index a8ce9d3..ed373ba 100644 --- a/internal/analyzer/go_tree_sitter.go +++ b/internal/analyzer/go_tree_sitter.go @@ -55,12 +55,17 @@ func (p *goParser) appendFunction(node *gotreesitter.Node, lang *gotreesitter.La if nameNode == nil { return } + parent := "" + if kind == "method" { + parent = goReceiverTypeName(node, lang, source) + } result.Symbols = append(result.Symbols, Symbol{ Name: nodeText(nameNode, source), Kind: kind, FilePath: path, Line: int(nameNode.StartPoint().Row) + 1, EndLine: int(node.EndPoint().Row) + 1, + Parent: parent, Description: p.findComment(node, lang, source), }) } @@ -190,3 +195,50 @@ func goCallName(node *gotreesitter.Node, lang *gotreesitter.Language, source []b } return strings.TrimSpace(text) } + +func goReceiverTypeName(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte) string { + receiver := childByFieldName(node, lang, "receiver") + if receiver == nil { + return "" + } + for _, child := range namedChildren(receiver) { + if nodeKind(child, lang) != "parameter_declaration" { + continue + } + if name := goTypeName(childByFieldName(child, lang, "type"), lang, source); name != "" { + return name + } + } + return goTypeName(receiver, lang, source) +} + +func goTypeName(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte) string { + if node == nil { + return "" + } + switch nodeKind(node, lang) { + case "type_identifier", "identifier", "field_identifier": + return nodeText(node, source) + case "qualified_type", "selector_expression": + children := namedChildren(node) + for i := len(children) - 1; i >= 0; i-- { + if name := goTypeName(children[i], lang, source); name != "" { + return name + } + } + } + for _, field := range []string{"type", 
"name"} { + if name := goTypeName(childByFieldName(node, lang, field), lang, source); name != "" { + return name + } + } + for _, child := range namedChildren(node) { + if nodeKind(child, lang) == "type_arguments" { + continue + } + if name := goTypeName(child, lang, source); name != "" { + return name + } + } + return "" +} diff --git a/internal/analyzer/go_tree_sitter_test.go b/internal/analyzer/go_tree_sitter_test.go new file mode 100644 index 0000000..2a7945e --- /dev/null +++ b/internal/analyzer/go_tree_sitter_test.go @@ -0,0 +1,34 @@ +package analyzer + +import ( + "context" + "testing" +) + +func TestGoParser_MethodReceiverParent(t *testing.T) { + parser := &goParser{} + source := `package main + +type Page struct{} +type Card struct{} + +func (p *Page) Render() {} +func (c Card) Render() {} +` + result, err := parser.ParseFile(context.Background(), "view.go", []byte(source)) + if err != nil { + t.Fatalf("ParseFile: %v", err) + } + + parents := map[string]bool{} + for _, sym := range result.Symbols { + if sym.Kind == "method" && sym.Name == "Render" { + parents[sym.Parent] = true + } + } + for _, want := range []string{"Page", "Card"} { + if !parents[want] { + t.Fatalf("missing Render parent %q in symbols: %+v", want, result.Symbols) + } + } +} diff --git a/cmd/analyze/analyze_resolution.go b/internal/analyzer/lsp/resolver.go similarity index 51% rename from cmd/analyze/analyze_resolution.go rename to internal/analyzer/lsp/resolver.go index 53f7085..c9313bf 100644 --- a/cmd/analyze/analyze_resolution.go +++ b/internal/analyzer/lsp/resolver.go @@ -1,4 +1,4 @@ -package analyze +package lsp import ( "context" @@ -8,41 +8,35 @@ import ( "path/filepath" "github.com/mertcikla/tld/internal/analyzer" - analyzerlsp "github.com/mertcikla/tld/internal/analyzer/lsp" "go.lsp.dev/protocol" "go.lsp.dev/uri" ) -type analyzeDefinitionLocation struct { +type DefinitionLocation struct { FilePath string Line int } -type analyzeDefinitionResolver interface { - ResolveDefinitions(ctx context.Context, ref analyzer.Ref) ([]analyzeDefinitionLocation, error) - Close() error -} - -type analyzeLSPResolver struct { - rootDir string - sessions map[analyzer.Language]*analyzerlsp.Session +type MultiLanguageResolver struct { + RootDir string + sessions map[analyzer.Language]*Session disabled map[analyzer.Language]struct{} opened map[string]struct{} contents map[string]string } -func newAnalyzeLSPResolver(rootDir string) *analyzeLSPResolver { - return &analyzeLSPResolver{ - rootDir: rootDir, - sessions: make(map[analyzer.Language]*analyzerlsp.Session), +func NewMultiLanguageResolver(rootDir string) *MultiLanguageResolver { + return &MultiLanguageResolver{ + RootDir: rootDir, + sessions: make(map[analyzer.Language]*Session), disabled: make(map[analyzer.Language]struct{}), opened: make(map[string]struct{}), contents: make(map[string]string), } } -func (r *analyzeLSPResolver) ResolveDefinitions(ctx context.Context, ref analyzer.Ref) ([]analyzeDefinitionLocation, error) { - if r == nil || r.rootDir == "" || ref.FilePath == "" || ref.Line <= 0 { +func (r *MultiLanguageResolver) ResolveDefinitions(ctx context.Context, ref analyzer.Ref) ([]DefinitionLocation, error) { + if r == nil || r.RootDir == "" || ref.FilePath == "" || ref.Line <= 0 { return nil, nil } language, ok := analyzer.DetectLanguage(ref.FilePath) @@ -72,13 +66,13 @@ func (r *analyzeLSPResolver) ResolveDefinitions(ctx context.Context, ref analyze if err != nil { return nil, err } - resolved := make([]analyzeDefinitionLocation, 0, len(locations)) + resolved := 
make([]DefinitionLocation, 0, len(locations)) for _, location := range locations { filePath := filepath.Clean(location.URI.Filename()) if filePath == "" { continue } - resolved = append(resolved, analyzeDefinitionLocation{ + resolved = append(resolved, DefinitionLocation{ FilePath: filePath, Line: int(location.Range.Start.Line) + 1, }) @@ -86,7 +80,7 @@ func (r *analyzeLSPResolver) ResolveDefinitions(ctx context.Context, ref analyze return resolved, nil } -func (r *analyzeLSPResolver) Close() error { +func (r *MultiLanguageResolver) Close() error { if r == nil { return nil } @@ -102,7 +96,7 @@ func (r *analyzeLSPResolver) Close() error { return errors.Join(errs...) } -func (r *analyzeLSPResolver) sessionForLanguage(ctx context.Context, language analyzer.Language) (*analyzerlsp.Session, bool, error) { +func (r *MultiLanguageResolver) sessionForLanguage(ctx context.Context, language analyzer.Language) (*Session, bool, error) { if r == nil { return nil, false, nil } @@ -112,9 +106,9 @@ func (r *analyzeLSPResolver) sessionForLanguage(ctx context.Context, language an if _, disabled := r.disabled[language]; disabled { return nil, false, nil } - session, err := analyzerlsp.StartSession(ctx, analyzerlsp.SessionConfig{ + session, err := StartSession(ctx, SessionConfig{ Language: language, - RootDir: r.rootDir, + RootDir: r.RootDir, }) if err != nil { r.disabled[language] = struct{}{} @@ -129,7 +123,7 @@ func (r *analyzeLSPResolver) sessionForLanguage(ctx context.Context, language an return session, true, nil } -func (r *analyzeLSPResolver) openDocument(ctx context.Context, session *analyzerlsp.Session, filePath string) error { +func (r *MultiLanguageResolver) openDocument(ctx context.Context, session *Session, filePath string) error { cleanPath := filepath.Clean(filePath) if _, ok := r.opened[cleanPath]; ok { return nil @@ -149,43 +143,3 @@ func (r *analyzeLSPResolver) openDocument(ctx context.Context, session *analyzer r.opened[cleanPath] = struct{}{} return nil } - -func resolveAnalyzeTargetRef(ctx context.Context, resolver analyzeDefinitionResolver, ref analyzer.Ref, symbols []analyzer.Symbol, refBySymbol map[analyzeElementLookupKey]string, refsByName map[string][]string) string { - if resolver != nil { - locations, err := resolver.ResolveDefinitions(ctx, ref) - if err == nil { - for _, location := range locations { - symbol, ok := symbolByFileAndLine(location.FilePath, location.Line, symbols) - if !ok { - continue - } - if targetRef, ok := refBySymbol[analyzeSymbolLookupKey(symbol)]; ok { - return targetRef - } - } - } - } - candidates := refsByName[ref.Name] - if len(candidates) == 1 { - return candidates[0] - } - return "" -} - -func symbolByFileAndLine(filePath string, line int, symbols []analyzer.Symbol) (analyzer.Symbol, bool) { - var bestSymbol analyzer.Symbol - found := false - cleanFilePath := filepath.Clean(filePath) - for _, symbol := range symbols { - if filepath.Clean(symbol.FilePath) != cleanFilePath { - continue - } - if symbol.Line <= line && (symbol.EndLine == 0 || symbol.EndLine >= line) { - if !found || symbol.Line > bestSymbol.Line { - bestSymbol = symbol - found = true - } - } - } - return bestSymbol, found -} diff --git a/internal/analyzer/parser_registry.go b/internal/analyzer/parser_registry.go index 0956f81..53d98c7 100644 --- a/internal/analyzer/parser_registry.go +++ b/internal/analyzer/parser_registry.go @@ -16,10 +16,12 @@ type parserRegistry struct { func newDefaultParserRegistry() *parserRegistry { return &parserRegistry{ parsers: map[Language]fileParser{ + 
LanguageC: &cppParser{}, LanguageCPP: &cppParser{}, LanguageGo: &goParser{}, LanguageJava: &javaParser{}, LanguagePython: &pythonParser{}, + LanguageRust: &rustParser{}, LanguageTypeScript: &tsParser{}, LanguageJavaScript: &jsParser{}, }, diff --git a/internal/analyzer/rust_test.go b/internal/analyzer/rust_test.go new file mode 100644 index 0000000..6f26cb2 --- /dev/null +++ b/internal/analyzer/rust_test.go @@ -0,0 +1,101 @@ +package analyzer + +import ( + "context" + "testing" +) + +func TestRustParser_ParseFile(t *testing.T) { + parser := &rustParser{} + source := ` +use std::collections::HashMap; +use std::io::{self, Write}; +use std::fs as filesystem; + +mod internal { + fn secret() {} +} + +struct Point { + x: i32, + y: i32, +} + +impl Point { + fn new(x: i32, y: i32) -> Self { + Point { x, y } + } + + fn distance(&self) -> f64 { + ((self.x * self.x + self.y * self.y) as f64).sqrt() + } +} + +trait Drawable { + fn draw(&self); +} + +fn main() { + let p = Point::new(10, 20); + p.distance(); + println!("Hello"); +} +` + result, err := parser.ParseFile(context.Background(), "test.rs", []byte(source)) + if err != nil { + t.Fatalf("ParseFile: %v", err) + } + + // Verify Symbols + expectedSymbols := []Symbol{ + {Name: "internal", Kind: "module", Parent: ""}, + {Name: "secret", Kind: "method", Parent: "internal"}, + {Name: "Point", Kind: "struct", Parent: ""}, + {Name: "new", Kind: "method", Parent: "Point"}, + {Name: "distance", Kind: "method", Parent: "Point"}, + {Name: "Drawable", Kind: "trait", Parent: ""}, + {Name: "draw", Kind: "method", Parent: "Drawable"}, + {Name: "main", Kind: "function", Parent: ""}, + } + + for _, expected := range expectedSymbols { + found := false + for _, actual := range result.Symbols { + if actual.Name == expected.Name && actual.Kind == expected.Kind && actual.Parent == expected.Parent { + found = true + break + } + } + if !found { + t.Errorf("Symbol not found or mismatch: %+v", expected) + } + } + + // Verify Refs + expectedRefs := []Ref{ + {Name: "HashMap", Kind: "import", TargetPath: "std::collections::HashMap"}, + {Name: "io", Kind: "import", TargetPath: "std::io"}, + {Name: "Write", Kind: "import", TargetPath: "std::io::Write"}, + {Name: "filesystem", Kind: "import", TargetPath: "std::fs"}, + {Name: "new", Kind: "call"}, + {Name: "distance", Kind: "call"}, + {Name: "println!", Kind: "call"}, + {Name: "sqrt", Kind: "call"}, + } + + for _, expected := range expectedRefs { + found := false + for _, actual := range result.Refs { + if actual.Name == expected.Name && (expected.Kind == "" || actual.Kind == expected.Kind) { + if expected.TargetPath != "" && actual.TargetPath != expected.TargetPath { + continue + } + found = true + break + } + } + if !found { + t.Errorf("Ref not found or mismatch: %+v", expected) + } + } +} diff --git a/internal/analyzer/rust_tree_sitter.go b/internal/analyzer/rust_tree_sitter.go new file mode 100644 index 0000000..5753dfa --- /dev/null +++ b/internal/analyzer/rust_tree_sitter.go @@ -0,0 +1,348 @@ +package analyzer + +import ( + "context" + "fmt" + "strings" + + "github.com/odvcencio/gotreesitter" +) + +type rustParser struct{} + +func (p *rustParser) ParseFile(ctx context.Context, path string, source []byte) (*Result, error) { + parsed, err := parseTree(ctx, path, source) + if err != nil { + return nil, fmt.Errorf("parse rust tree-sitter source: %w", err) + } + defer parsed.Close() + + result := &Result{} + root := parsed.tree.RootNode() + p.walkNode(root, parsed.lang, source, path, "", result) + return result, nil +} + +func 
(p *rustParser) walkNode(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path, parent string, result *Result) { + if node == nil { + return + } + + nextParent := parent + kind := nodeKind(node, lang) + switch kind { + case "function_item", "function_signature_item": + nextParent = p.appendFunction(node, lang, source, path, parent, result) + case "struct_item": + nextParent = p.appendSymbol(node, lang, source, path, parent, "struct", result) + case "enum_item": + nextParent = p.appendSymbol(node, lang, source, path, parent, "enum", result) + case "trait_item": + nextParent = p.appendSymbol(node, lang, source, path, parent, "trait", result) + case "mod_item": + nextParent = p.appendSymbol(node, lang, source, path, parent, "module", result) + case "impl_item": + nextParent = p.handleImpl(node, lang, source, path, parent, result) + case "type_item": + nextParent = p.appendSymbol(node, lang, source, path, parent, "type", result) + case "use_declaration": + p.appendUse(node, lang, source, path, result) + case "call_expression": + p.appendCall(node, lang, source, path, result) + case "macro_invocation": + p.appendMacro(node, lang, source, path, result) + case "struct_expression": + p.appendStructExpr(node, lang, source, path, result) + } + + for _, child := range namedChildren(node) { + p.walkNode(child, lang, source, path, nextParent, result) + } +} + +func (p *rustParser) handleImpl(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path, parent string, result *Result) string { + // impl Point { ... } or impl Drawable for Point { ... } + // We want the type name as the parent for methods inside. + typeNode := childByFieldName(node, lang, "type") + if typeNode == nil { + // Fallback to type_identifier + for _, child := range namedChildren(node) { + if nodeKind(child, lang) == "type_identifier" { + typeNode = child + break + } + } + } + if typeNode != nil { + return nodeText(typeNode, source) + } + return parent +} + +func (p *rustParser) appendFunction(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path, parent string, result *Result) string { + nameNode := childByFieldName(node, lang, "name") + if nameNode == nil { + // Fallback to first identifier child if name field is missing + for _, child := range namedChildren(node) { + if nodeKind(child, lang) == "identifier" { + nameNode = child + break + } + } + } + if nameNode == nil { + return parent + } + + name := nodeText(nameNode, source) + kind := "function" + // If parent is an impl or trait, it's a method + if parent != "" { + kind = "method" + } + + result.Symbols = append(result.Symbols, Symbol{ + Name: name, + Kind: kind, + FilePath: path, + Line: int(nameNode.StartPoint().Row) + 1, + EndLine: int(node.EndPoint().Row) + 1, + Parent: parent, + }) + return name +} + +func (p *rustParser) appendSymbol(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path, parent, kind string, result *Result) string { + nameNode := childByFieldName(node, lang, "name") + if nameNode == nil { + // Fallback to first identifier or type_identifier + for _, child := range namedChildren(node) { + k := nodeKind(child, lang) + if k == "identifier" || k == "type_identifier" { + nameNode = child + break + } + } + } + if nameNode == nil { + return parent + } + + name := nodeText(nameNode, source) + result.Symbols = append(result.Symbols, Symbol{ + Name: name, + Kind: kind, + FilePath: path, + Line: int(nameNode.StartPoint().Row) + 1, + EndLine: int(node.EndPoint().Row) + 1, + Parent: parent, + 
}) + return name +} + +func (p *rustParser) appendUse(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path string, result *Result) { + argumentNode := childByFieldName(node, lang, "argument") + if argumentNode == nil { + // Fallback to searching children + for _, child := range namedChildren(node) { + k := nodeKind(child, lang) + if k == "scoped_identifier" || k == "identifier" || k == "use_list" || k == "scoped_use_list" || k == "use_as_clause" { + argumentNode = child + break + } + } + } + if argumentNode == nil { + return + } + + p.processUseArgument(argumentNode, lang, source, path, "", result) +} + +func (p *rustParser) processUseArgument(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path, prefix string, result *Result) { + kind := nodeKind(node, lang) + switch kind { + case "identifier", "scoped_identifier": + targetPath := nodeText(node, source) + if prefix != "" { + targetPath = prefix + "::" + targetPath + } + name := targetPath + if idx := strings.LastIndex(targetPath, "::"); idx >= 0 { + name = targetPath[idx+2:] + } + result.Refs = append(result.Refs, Ref{ + Name: name, + Kind: "import", + TargetPath: targetPath, + FilePath: path, + Line: int(node.StartPoint().Row) + 1, + Column: int(node.StartPoint().Column) + 1, + }) + case "use_list": + for _, child := range namedChildren(node) { + p.processUseArgument(child, lang, source, path, prefix, result) + } + case "scoped_use_list": + newPrefix := "" + // [scoped_identifier] [::] [use_list] + if node.ChildCount() >= 3 { + newPrefix = nodeText(node.Child(0), source) + if prefix != "" { + newPrefix = prefix + "::" + newPrefix + } + p.processUseArgument(node.Child(2), lang, source, path, newPrefix, result) + } + case "use_as_clause": + aliasNode := childByFieldName(node, lang, "alias") + pathNode := childByFieldName(node, lang, "path") + if aliasNode != nil && pathNode != nil { + targetPath := nodeText(pathNode, source) + if prefix != "" { + targetPath = prefix + "::" + targetPath + } + result.Refs = append(result.Refs, Ref{ + Name: nodeText(aliasNode, source), + Kind: "import", + TargetPath: targetPath, + FilePath: path, + Line: int(aliasNode.StartPoint().Row) + 1, + Column: int(aliasNode.StartPoint().Column) + 1, + }) + } + case "self": + if prefix != "" { + targetPath := prefix + name := targetPath + if idx := strings.LastIndex(targetPath, "::"); idx >= 0 { + name = targetPath[idx+2:] + } + result.Refs = append(result.Refs, Ref{ + Name: name, + Kind: "import", + TargetPath: targetPath, + FilePath: path, + Line: int(node.StartPoint().Row) + 1, + Column: int(node.StartPoint().Column) + 1, + }) + } + } +} + +func (p *rustParser) appendCall(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path string, result *Result) { + functionNode := childByFieldName(node, lang, "function") + if functionNode == nil { + return + } + name := rustCallName(functionNode, lang, source) + if name == "" { + return + } + result.Refs = append(result.Refs, Ref{ + Name: name, + Kind: "call", + FilePath: path, + Line: int(functionNode.StartPoint().Row) + 1, + Column: int(functionNode.StartPoint().Column) + 1, + }) +} + +func (p *rustParser) appendMacro(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path string, result *Result) { + nameNode := childByFieldName(node, lang, "macro") + if nameNode == nil { + // First child is usually the identifier for macros + for _, child := range namedChildren(node) { + if nodeKind(child, lang) == "identifier" { + nameNode = child + break + } + } 
+ } + if nameNode == nil { + return + } + name := nodeText(nameNode, source) + result.Refs = append(result.Refs, Ref{ + Name: name + "!", + Kind: "call", + FilePath: path, + Line: int(nameNode.StartPoint().Row) + 1, + Column: int(nameNode.StartPoint().Column) + 1, + }) +} + +func (p *rustParser) appendStructExpr(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte, path string, result *Result) { + nameNode := childByFieldName(node, lang, "name") + if nameNode == nil { + // Fallback + for _, child := range namedChildren(node) { + k := nodeKind(child, lang) + if k == "type_identifier" || k == "scoped_type_identifier" || k == "generic_type" { + nameNode = child + break + } + } + } + if nameNode == nil { + return + } + name := rustCallName(nameNode, lang, source) + if name == "" { + return + } + result.Refs = append(result.Refs, Ref{ + Name: name, + Kind: "call", + FilePath: path, + Line: int(nameNode.StartPoint().Row) + 1, + Column: int(nameNode.StartPoint().Column) + 1, + }) +} + +func rustCallName(node *gotreesitter.Node, lang *gotreesitter.Language, source []byte) string { + if node == nil { + return "" + } + kind := nodeKind(node, lang) + switch kind { + case "identifier", "field_identifier", "type_identifier": + return nodeText(node, source) + case "scoped_identifier", "scoped_type_identifier": + nameNode := childByFieldName(node, lang, "name") + if nameNode != nil { + return nodeText(nameNode, source) + } + // Fallback to last identifier + children := namedChildren(node) + if len(children) > 0 { + return rustCallName(children[len(children)-1], lang, source) + } + case "field_expression": + fieldNode := childByFieldName(node, lang, "field") + if fieldNode != nil { + return nodeText(fieldNode, source) + } + case "generic_type": + typeNode := childByFieldName(node, lang, "type") + if typeNode != nil { + return rustCallName(typeNode, lang, source) + } + } + text := strings.TrimSpace(nodeText(node, source)) + if text == "" { + return "" + } + if index := strings.LastIndex(text, "::"); index >= 0 { + text = text[index+2:] + } + if index := strings.LastIndex(text, "."); index >= 0 { + text = text[index+1:] + } + if index := strings.Index(text, "<"); index >= 0 { + text = text[:index] + } + if index := strings.Index(text, "("); index >= 0 { + text = text[:index] + } + return strings.TrimSpace(text) +} diff --git a/internal/app/connector.go b/internal/app/connector.go new file mode 100644 index 0000000..ece8faa --- /dev/null +++ b/internal/app/connector.go @@ -0,0 +1,117 @@ +package app + +import "context" + +func (s *Store) Connectors(ctx context.Context, viewID int64) ([]Connector, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at + FROM connectors WHERE view_id = ? 
ORDER BY id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make([]Connector, 0) + for rows.Next() { + var item Connector + if err := rows.Scan(&item.ID, &item.ViewID, &item.SourceElementID, &item.TargetElementID, &item.Label, &item.Description, &item.Relationship, &item.Direction, &item.Style, &item.URL, &item.SourceHandle, &item.TargetHandle, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, err + } + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) AllConnectors(ctx context.Context) ([]Connector, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at + FROM connectors ORDER BY view_id, id`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make([]Connector, 0) + for rows.Next() { + var item Connector + if err := rows.Scan(&item.ID, &item.ViewID, &item.SourceElementID, &item.TargetElementID, &item.Label, &item.Description, &item.Relationship, &item.Direction, &item.Style, &item.URL, &item.SourceHandle, &item.TargetHandle, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, err + } + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) CreateConnector(ctx context.Context, input Connector) (Connector, error) { + now := nowString() + res, err := s.db.ExecContext(ctx, ` + INSERT INTO connectors(view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.ViewID, input.SourceElementID, input.TargetElementID, input.Label, input.Description, input.Relationship, + normalizeDirection(&input.Direction), input.Style, input.URL, input.SourceHandle, input.TargetHandle, now, now) + if err != nil { + return Connector{}, err + } + id, _ := res.LastInsertId() + return s.ConnectorByID(ctx, id) +} + +func (s *Store) ConnectorByID(ctx context.Context, id int64) (Connector, error) { + row := s.db.QueryRowContext(ctx, `SELECT id, view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at FROM connectors WHERE id = ?`, id) + var item Connector + if err := row.Scan(&item.ID, &item.ViewID, &item.SourceElementID, &item.TargetElementID, &item.Label, &item.Description, &item.Relationship, &item.Direction, &item.Style, &item.URL, &item.SourceHandle, &item.TargetHandle, &item.CreatedAt, &item.UpdatedAt); err != nil { + return Connector{}, err + } + return item, nil +} + +func (s *Store) UpdateConnector(ctx context.Context, id int64, patch Connector) (Connector, error) { + current, err := s.ConnectorByID(ctx, id) + if err != nil { + return Connector{}, err + } + if patch.SourceElementID == 0 { + patch.SourceElementID = current.SourceElementID + } + if patch.TargetElementID == 0 { + patch.TargetElementID = current.TargetElementID + } + if patch.ViewID == 0 { + patch.ViewID = current.ViewID + } + if patch.Direction == "" { + patch.Direction = current.Direction + } + if patch.Style == "" { + patch.Style = current.Style + } + if patch.Label == nil { + patch.Label = current.Label + } + if patch.Description == nil { + patch.Description = current.Description + } + if patch.Relationship == nil { + patch.Relationship = current.Relationship + } + if patch.URL 
== nil { + patch.URL = current.URL + } + if patch.SourceHandle == nil { + patch.SourceHandle = current.SourceHandle + } + if patch.TargetHandle == nil { + patch.TargetHandle = current.TargetHandle + } + _, err = s.db.ExecContext(ctx, ` + UPDATE connectors SET source_element_id = ?, target_element_id = ?, label = ?, description = ?, relationship = ?, direction = ?, style = ?, url = ?, source_handle = ?, target_handle = ?, updated_at = ? + WHERE id = ?`, + patch.SourceElementID, patch.TargetElementID, patch.Label, patch.Description, patch.Relationship, patch.Direction, patch.Style, patch.URL, patch.SourceHandle, patch.TargetHandle, nowString(), id) + if err != nil { + return Connector{}, err + } + return s.ConnectorByID(ctx, id) +} + +func (s *Store) DeleteConnector(ctx context.Context, id int64) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM connectors WHERE id = ?`, id) + return err +} diff --git a/internal/app/element.go b/internal/app/element.go new file mode 100644 index 0000000..6935068 --- /dev/null +++ b/internal/app/element.go @@ -0,0 +1,362 @@ +package app + +import ( + "context" + "database/sql" + "errors" + "strings" +) + +type LibraryElement struct { + ID int64 `json:"id"` + Name string `json:"name"` + Kind *string `json:"kind"` + Description *string `json:"description"` + Technology *string `json:"technology"` + URL *string `json:"url"` + LogoURL *string `json:"logo_url"` + TechnologyConnectors []TechnologyConnector `json:"technology_connectors"` + Tags []string `json:"tags"` + Repo *string `json:"repo,omitempty"` + Branch *string `json:"branch,omitempty"` + FilePath *string `json:"file_path,omitempty"` + Language *string `json:"language,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + HasView bool `json:"has_view"` + ViewLabel *string `json:"view_label"` +} + +type PlacedElement struct { + ID int64 `json:"id"` + ViewID int64 `json:"view_id"` + ElementID int64 `json:"element_id"` + PositionX float64 `json:"position_x"` + PositionY float64 `json:"position_y"` + Name string `json:"name"` + Description *string `json:"description"` + Kind *string `json:"kind"` + Technology *string `json:"technology"` + URL *string `json:"url"` + LogoURL *string `json:"logo_url"` + TechnologyConnectors []TechnologyConnector `json:"technology_connectors"` + Tags []string `json:"tags"` + Repo *string `json:"repo,omitempty"` + Branch *string `json:"branch,omitempty"` + FilePath *string `json:"file_path,omitempty"` + Language *string `json:"language,omitempty"` + HasView bool `json:"has_view"` + ViewLabel *string `json:"view_label"` +} + +type ElementPlacement struct { + ID int64 `json:"id"` + ViewID int64 `json:"view_id"` + ElementID int64 `json:"element_id"` + PositionX float64 `json:"position_x"` + PositionY float64 `json:"position_y"` +} + +type DependencyElement struct { + ID string `json:"id"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Type *string `json:"type,omitempty"` + Technology *string `json:"technology,omitempty"` + URL *string `json:"url,omitempty"` + LogoURL *string `json:"logo_url,omitempty"` + TechnologyConnectors []TechnologyConnector `json:"technology_connectors"` + Tags []string `json:"tags"` + Repo *string `json:"repo,omitempty"` + Branch *string `json:"branch,omitempty"` + Language *string `json:"language,omitempty"` + FilePath *string `json:"file_path,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type PlanElement struct { + Ref string 
`json:"ref"` + Name string `json:"name"` + Kind *string `json:"kind"` + Description *string `json:"description"` + Technology *string `json:"technology"` + URL *string `json:"url"` + LogoURL *string `json:"logo_url"` + TechnologyLinks []TechnologyConnector `json:"technology_links"` + Tags []string `json:"tags"` + Repo *string `json:"repo"` + Branch *string `json:"branch"` + Language *string `json:"language"` + FilePath *string `json:"file_path"` + HasView bool `json:"has_view"` + ViewLabel *string `json:"view_label"` +} + +func (s *Store) Elements(ctx context.Context, limit, offset int, search string) ([]LibraryElement, int, error) { + type elementRow struct { + ID int64 + Name string + Kind sql.NullString + Description sql.NullString + Technology sql.NullString + URL sql.NullString + LogoURL sql.NullString + TechRaw string + TagRaw string + Repo sql.NullString + Branch sql.NullString + FilePath sql.NullString + Language sql.NullString + CreatedAt string + UpdatedAt string + } + + where := "" + args := []any{} + if strings.TrimSpace(search) != "" { + where = ` WHERE LOWER(name) LIKE LOWER(?) OR LOWER(COALESCE(description, '')) LIKE LOWER(?)` + pattern := "%" + strings.TrimSpace(search) + "%" + args = append(args, pattern, pattern) + } + var total int + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM elements`+where, args...).Scan(&total); err != nil { + return nil, 0, err + } + + query := `SELECT id, name, kind, description, technology, url, logo_url, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at FROM elements` + where + query += ` ORDER BY updated_at DESC` + if limit > 0 { + query += ` LIMIT ? OFFSET ?` + args = append(args, limit, offset) + } + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, 0, err + } + defer func() { _ = rows.Close() }() + scanned := make([]elementRow, 0) + for rows.Next() { + var row elementRow + if err := rows.Scan( + &row.ID, + &row.Name, + &row.Kind, + &row.Description, + &row.Technology, + &row.URL, + &row.LogoURL, + &row.TechRaw, + &row.TagRaw, + &row.Repo, + &row.Branch, + &row.FilePath, + &row.Language, + &row.CreatedAt, + &row.UpdatedAt, + ); err != nil { + return nil, 0, err + } + scanned = append(scanned, row) + } + if err := rows.Err(); err != nil { + return nil, 0, err + } + viewMeta, err := s.childViewMetaMap(ctx) + if err != nil { + return nil, 0, err + } + + out := make([]LibraryElement, 0, len(scanned)) + for _, row := range scanned { + elem := LibraryElement{ + ID: row.ID, + Name: row.Name, + TechnologyConnectors: parseTechnologyConnectors(row.TechRaw), + Tags: parseStrings(row.TagRaw), + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } + if row.Kind.Valid { + elem.Kind = &row.Kind.String + } + if row.Description.Valid { + elem.Description = &row.Description.String + } + if row.Technology.Valid { + elem.Technology = &row.Technology.String + } + if row.URL.Valid { + elem.URL = &row.URL.String + } + if row.LogoURL.Valid { + elem.LogoURL = &row.LogoURL.String + } + if row.Repo.Valid { + elem.Repo = &row.Repo.String + } + if row.Branch.Valid { + elem.Branch = &row.Branch.String + } + if row.FilePath.Valid { + elem.FilePath = &row.FilePath.String + } + if row.Language.Valid { + elem.Language = &row.Language.String + } + if meta, ok := viewMeta[elem.ID]; ok { + elem.HasView = meta.hasView + elem.ViewLabel = meta.label + } + out = append(out, elem) + } + return out, total, nil +} + +func (s *Store) ElementByID(ctx context.Context, id int64) (LibraryElement, error) 
{ + row := s.db.QueryRowContext(ctx, `SELECT id, name, kind, description, technology, url, logo_url, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at FROM elements WHERE id = ?`, id) + return scanElement(row, true, s, ctx) +} + +func (s *Store) CreateElement(ctx context.Context, input LibraryElement) (LibraryElement, error) { + if err := s.ensureTagColors(ctx, input.Tags); err != nil { + return LibraryElement{}, err + } + now := nowString() + res, err := s.db.ExecContext(ctx, ` + INSERT INTO elements(name, kind, description, technology, url, logo_url, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + strings.TrimSpace(input.Name), + input.Kind, + input.Description, + input.Technology, + input.URL, + input.LogoURL, + jsonString(input.TechnologyConnectors, "[]"), + jsonString(input.Tags, "[]"), + input.Repo, + input.Branch, + input.FilePath, + input.Language, + now, + now, + ) + if err != nil { + return LibraryElement{}, err + } + id, _ := res.LastInsertId() + return s.ElementByID(ctx, id) +} + +func (s *Store) UpdateElement(ctx context.Context, id int64, input LibraryElement) (LibraryElement, error) { + if input.Tags != nil { + if err := s.ensureTagColors(ctx, input.Tags); err != nil { + return LibraryElement{}, err + } + } + current, err := s.ElementByID(ctx, id) + if err != nil { + return LibraryElement{}, err + } + if input.Name == "" { + input.Name = current.Name + } + if input.Kind == nil { + input.Kind = current.Kind + } + if input.Description == nil { + input.Description = current.Description + } + if input.Technology == nil { + input.Technology = current.Technology + } + if input.URL == nil { + input.URL = current.URL + } + if input.LogoURL == nil { + input.LogoURL = current.LogoURL + } + if input.Repo == nil { + input.Repo = current.Repo + } + if input.Branch == nil { + input.Branch = current.Branch + } + if input.FilePath == nil { + input.FilePath = current.FilePath + } + if input.Language == nil { + input.Language = current.Language + } + if input.TechnologyConnectors == nil { + input.TechnologyConnectors = current.TechnologyConnectors + } + if input.Tags == nil { + input.Tags = current.Tags + } + _, err = s.db.ExecContext(ctx, ` + UPDATE elements SET name = ?, kind = ?, description = ?, technology = ?, url = ?, logo_url = ?, technology_connectors = ?, tags = ?, repo = ?, branch = ?, file_path = ?, language = ?, updated_at = ? + WHERE id = ?`, + input.Name, input.Kind, input.Description, input.Technology, input.URL, input.LogoURL, + jsonString(input.TechnologyConnectors, "[]"), jsonString(input.Tags, "[]"), + input.Repo, input.Branch, input.FilePath, input.Language, nowString(), id, + ) + if err != nil { + return LibraryElement{}, err + } + return s.ElementByID(ctx, id) +} + +func (s *Store) DeleteElement(ctx context.Context, id int64) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM elements WHERE id = ?`, id) + return err +} + +func (s *Store) ListElementPlacements(ctx context.Context, elementID int64) ([]ViewPlacement, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT p.view_id, v.name + FROM placements p + JOIN views v ON v.id = p.view_id + WHERE p.element_id = ? 
+ ORDER BY p.view_id`, elementID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make([]ViewPlacement, 0) + for rows.Next() { + var placement ViewPlacement + if err := rows.Scan(&placement.ViewID, &placement.ViewName); err != nil { + return nil, err + } + out = append(out, placement) + } + return out, rows.Err() +} + +func (s *Store) ListElementNavigations(ctx context.Context, elementID int64, fromViewID, toViewID *int64) ([]ViewConnector, error) { + row := s.db.QueryRowContext(ctx, `SELECT id, name FROM views WHERE owner_element_id = ? ORDER BY id LIMIT 1`, elementID) + var childViewID int64 + var childViewName string + if err := row.Scan(&childViewID, &childViewName); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return []ViewConnector{}, nil + } + return nil, err + } + parentID, err := s.parentViewForOwner(ctx, elementID, childViewID) + if err != nil { + return nil, err + } + out := make([]ViewConnector, 0, 1) + if fromViewID != nil && *fromViewID > 0 { + if parentID != nil && *parentID == *fromViewID { + out = append(out, ViewConnector{ID: 0, ElementID: &elementID, FromViewID: *fromViewID, ToViewID: childViewID, ToViewName: childViewName, RelationType: "child"}) + } + return out, nil + } + if toViewID != nil && *toViewID > 0 && parentID != nil && *toViewID == childViewID { + out = append(out, ViewConnector{ID: 0, ElementID: &elementID, FromViewID: *parentID, ToViewID: childViewID, ToViewName: childViewName, RelationType: "child"}) + } + return out, nil +} diff --git a/internal/app/placement.go b/internal/app/placement.go new file mode 100644 index 0000000..6232108 --- /dev/null +++ b/internal/app/placement.go @@ -0,0 +1,147 @@ +package app + +import "context" + +func (s *Store) Placements(ctx context.Context, viewID int64) ([]PlacedElement, error) { + type placementRow struct { + item PlacedElement + techRaw string + tagRaw string + } + + rows, err := s.db.QueryContext(ctx, ` + SELECT p.id, p.view_id, p.element_id, p.position_x, p.position_y, + e.name, e.kind, e.description, e.technology, e.url, e.logo_url, e.technology_connectors, e.tags, e.repo, e.branch, e.file_path, e.language, e.created_at, e.updated_at + FROM placements p + JOIN elements e ON e.id = p.element_id + WHERE p.view_id = ? 
+ ORDER BY p.id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + scanned := make([]placementRow, 0) + for rows.Next() { + var row placementRow + if err := rows.Scan(&row.item.ID, &row.item.ViewID, &row.item.ElementID, &row.item.PositionX, &row.item.PositionY, + &row.item.Name, &row.item.Kind, &row.item.Description, &row.item.Technology, &row.item.URL, &row.item.LogoURL, + &row.techRaw, &row.tagRaw, &row.item.Repo, &row.item.Branch, &row.item.FilePath, &row.item.Language, new(string), new(string)); err != nil { + return nil, err + } + scanned = append(scanned, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + viewMeta, err := s.childViewMetaMap(ctx) + if err != nil { + return nil, err + } + + out := make([]PlacedElement, 0, len(scanned)) + for _, row := range scanned { + item := row.item + item.TechnologyConnectors = parseTechnologyConnectors(row.techRaw) + item.Tags = parseStrings(row.tagRaw) + if meta, ok := viewMeta[item.ElementID]; ok { + item.HasView = meta.hasView + item.ViewLabel = meta.label + } + out = append(out, item) + } + return out, nil +} + +func (s *Store) AllPlacements(ctx context.Context) ([]PlacedElement, error) { + type placementRow struct { + item PlacedElement + techRaw string + tagRaw string + } + + rows, err := s.db.QueryContext(ctx, ` + SELECT p.id, p.view_id, p.element_id, p.position_x, p.position_y, + e.name, e.kind, e.description, e.technology, e.url, e.logo_url, e.technology_connectors, e.tags, e.repo, e.branch, e.file_path, e.language, e.created_at, e.updated_at + FROM placements p + JOIN elements e ON e.id = p.element_id + ORDER BY p.view_id, p.id`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + scanned := make([]placementRow, 0) + for rows.Next() { + var row placementRow + if err := rows.Scan(&row.item.ID, &row.item.ViewID, &row.item.ElementID, &row.item.PositionX, &row.item.PositionY, + &row.item.Name, &row.item.Kind, &row.item.Description, &row.item.Technology, &row.item.URL, &row.item.LogoURL, + &row.techRaw, &row.tagRaw, &row.item.Repo, &row.item.Branch, &row.item.FilePath, &row.item.Language, new(string), new(string)); err != nil { + return nil, err + } + scanned = append(scanned, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + viewMeta, err := s.childViewMetaMap(ctx) + if err != nil { + return nil, err + } + + out := make([]PlacedElement, 0, len(scanned)) + for _, row := range scanned { + item := row.item + item.TechnologyConnectors = parseTechnologyConnectors(row.techRaw) + item.Tags = parseStrings(row.tagRaw) + if meta, ok := viewMeta[item.ElementID]; ok { + item.HasView = meta.hasView + item.ViewLabel = meta.label + } + out = append(out, item) + } + return out, nil +} + +func (s *Store) ElementPlacements(ctx context.Context, viewID int64) ([]ElementPlacement, error) { + rows, err := s.db.QueryContext(ctx, `SELECT id, view_id, element_id, position_x, position_y FROM placements WHERE view_id = ? 
ORDER BY id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make([]ElementPlacement, 0) + for rows.Next() { + var item ElementPlacement + if err := rows.Scan(&item.ID, &item.ViewID, &item.ElementID, &item.PositionX, &item.PositionY); err != nil { + return nil, err + } + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) AddPlacement(ctx context.Context, viewID, elementID int64, x, y float64) (ElementPlacement, error) { + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO placements(view_id, element_id, position_x, position_y, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(view_id, element_id) DO UPDATE SET position_x = excluded.position_x, position_y = excluded.position_y, updated_at = excluded.updated_at`, + viewID, elementID, x, y, now, now) + if err != nil { + return ElementPlacement{}, err + } + row := s.db.QueryRowContext(ctx, `SELECT id, view_id, element_id, position_x, position_y FROM placements WHERE view_id = ? AND element_id = ?`, viewID, elementID) + var item ElementPlacement + if err := row.Scan(&item.ID, &item.ViewID, &item.ElementID, &item.PositionX, &item.PositionY); err != nil { + return ElementPlacement{}, err + } + return item, nil +} + +func (s *Store) UpdatePlacement(ctx context.Context, viewID, elementID int64, x, y float64) error { + _, err := s.db.ExecContext(ctx, `UPDATE placements SET position_x = ?, position_y = ?, updated_at = ? WHERE view_id = ? AND element_id = ?`, x, y, nowString(), viewID, elementID) + return err +} + +func (s *Store) DeletePlacement(ctx context.Context, viewID, elementID int64) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM placements WHERE view_id = ? AND element_id = ?`, viewID, elementID) + return err +} diff --git a/internal/app/store.go b/internal/app/store.go index 6072096..ff22da7 100644 --- a/internal/app/store.go +++ b/internal/app/store.go @@ -5,15 +5,14 @@ import ( "database/sql" "embed" "encoding/json" - "errors" "fmt" "io/fs" - "maps" "math" "sort" "strings" "time" + sqlitevec "github.com/viant/sqlite-vec/vec" _ "modernc.org/sqlite" ) @@ -21,6 +20,10 @@ type Store struct { db *sql.DB } +func (s *Store) DB() *sql.DB { + return s.db +} + type TechnologyConnector struct { Type string `json:"type"` Slug string `json:"slug,omitempty"` @@ -28,79 +31,6 @@ type TechnologyConnector struct { IsPrimaryIcon bool `json:"is_primary_icon,omitempty"` } -type LibraryElement struct { - ID int64 `json:"id"` - Name string `json:"name"` - Kind *string `json:"kind"` - Description *string `json:"description"` - Technology *string `json:"technology"` - URL *string `json:"url"` - LogoURL *string `json:"logo_url"` - TechnologyConnectors []TechnologyConnector `json:"technology_connectors"` - Tags []string `json:"tags"` - Repo *string `json:"repo,omitempty"` - Branch *string `json:"branch,omitempty"` - FilePath *string `json:"file_path,omitempty"` - Language *string `json:"language,omitempty"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` - HasView bool `json:"has_view"` - ViewLabel *string `json:"view_label"` -} - -type ViewTreeNode struct { - ID int64 `json:"id"` - OwnerElementID *int64 `json:"owner_element_id"` - Name string `json:"name"` - Description *string `json:"description"` - LevelLabel *string `json:"level_label"` - Level int `json:"level"` - Depth int `json:"depth"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` - ParentViewID *int64 `json:"parent_view_id"` - 
Children []ViewTreeNode `json:"children"` -} - -type ViewSummary struct { - ID int64 `json:"id"` - Name string `json:"name"` - Label *string `json:"label"` - IsRoot bool `json:"is_root"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` -} - -type PlacedElement struct { - ID int64 `json:"id"` - ViewID int64 `json:"view_id"` - ElementID int64 `json:"element_id"` - PositionX float64 `json:"position_x"` - PositionY float64 `json:"position_y"` - Name string `json:"name"` - Description *string `json:"description"` - Kind *string `json:"kind"` - Technology *string `json:"technology"` - URL *string `json:"url"` - LogoURL *string `json:"logo_url"` - TechnologyConnectors []TechnologyConnector `json:"technology_connectors"` - Tags []string `json:"tags"` - Repo *string `json:"repo,omitempty"` - Branch *string `json:"branch,omitempty"` - FilePath *string `json:"file_path,omitempty"` - Language *string `json:"language,omitempty"` - HasView bool `json:"has_view"` - ViewLabel *string `json:"view_label"` -} - -type ElementPlacement struct { - ID int64 `json:"id"` - ViewID int64 `json:"view_id"` - ElementID int64 `json:"element_id"` - PositionX float64 `json:"position_x"` - PositionY float64 `json:"position_y"` -} - type Connector struct { ID int64 `json:"id"` ViewID int64 `json:"view_id"` @@ -118,45 +48,6 @@ type Connector struct { UpdatedAt string `json:"updated_at"` } -type ViewConnector struct { - ID int64 `json:"id"` - ElementID *int64 `json:"element_id"` - FromViewID int64 `json:"from_view_id"` - ToViewID int64 `json:"to_view_id"` - ToViewName string `json:"to_view_name"` - RelationType string `json:"relation_type"` -} - -type IncomingViewConnector struct { - ID int64 `json:"id"` - ElementID int64 `json:"element_id"` - ElementName string `json:"element_name"` - FromViewID int64 `json:"from_view_id"` - FromViewName string `json:"from_view_name"` - ToViewID int64 `json:"to_view_id"` -} - -type ViewPlacement struct { - ViewID int64 `json:"view_id"` - ViewName string `json:"view_name"` -} - -type ViewLayer struct { - ID int64 `json:"id"` - DiagramID int64 `json:"diagram_id"` - Name string `json:"name"` - Tags []string `json:"tags"` - Color *string `json:"color,omitempty"` - CreatedAt string `json:"created_at,omitempty"` - UpdatedAt string `json:"updated_at,omitempty"` -} - -type Tag struct { - Name string `json:"name"` - Color string `json:"color"` - Description *string `json:"description"` -} - type ExploreViewData struct { Placements []PlacedElement `json:"placements"` Connectors []Connector `json:"connectors"` @@ -168,24 +59,6 @@ type ExploreData struct { Navigations []ViewConnector `json:"navigations"` } -type DependencyElement struct { - ID string `json:"id"` - Name string `json:"name"` - Description *string `json:"description,omitempty"` - Type *string `json:"type,omitempty"` - Technology *string `json:"technology,omitempty"` - URL *string `json:"url,omitempty"` - LogoURL *string `json:"logo_url,omitempty"` - TechnologyConnectors []TechnologyConnector `json:"technology_connectors"` - Tags []string `json:"tags"` - Repo *string `json:"repo,omitempty"` - Branch *string `json:"branch,omitempty"` - Language *string `json:"language,omitempty"` - FilePath *string `json:"file_path,omitempty"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` -} - type DependencyConnector struct { ID string `json:"id"` ViewID string `json:"view_id"` @@ -203,24 +76,6 @@ type DependencyConnector struct { UpdatedAt string `json:"updated_at"` } -type PlanElement struct { - 
Ref string `json:"ref"` - Name string `json:"name"` - Kind *string `json:"kind"` - Description *string `json:"description"` - Technology *string `json:"technology"` - URL *string `json:"url"` - LogoURL *string `json:"logo_url"` - TechnologyLinks []TechnologyConnector `json:"technology_links"` - Tags []string `json:"tags"` - Repo *string `json:"repo"` - Branch *string `json:"branch"` - Language *string `json:"language"` - FilePath *string `json:"file_path"` - HasView bool `json:"has_view"` - ViewLabel *string `json:"view_label"` -} - type PlanConnector struct { Ref string `json:"ref"` ViewRef string `json:"view_ref"` @@ -241,20 +96,51 @@ func OpenStore(dbPath string, migrations embed.FS) (*Store, error) { if err != nil { return nil, err } - db.SetMaxOpenConns(1) - if _, err := db.Exec(`PRAGMA foreign_keys = ON;`); err != nil { + if err := configureSQLiteDB(db); err != nil { + _ = db.Close() return nil, err } + if err := sqlitevec.Register(db); err != nil { + _ = db.Close() + return nil, fmt.Errorf("register sqlite-vec: %w", err) + } if err := applyMigrations(db, migrations); err != nil { + _ = db.Close() return nil, err } store := &Store{db: db} if err := store.ensureBootstrapData(context.Background()); err != nil { + _ = db.Close() return nil, err } return store, nil } +func configureSQLiteDB(db *sql.DB) error { + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + pragmas := []string{ + `PRAGMA busy_timeout = 5000;`, + `PRAGMA journal_mode = WAL;`, + `PRAGMA synchronous = NORMAL;`, + `PRAGMA foreign_keys = ON;`, + } + for _, pragma := range pragmas { + if _, err := db.Exec(pragma); err != nil { + return fmt.Errorf("configure sqlite %s: %w", pragma, err) + } + } + return nil +} + +func (s *Store) Close() error { + if s == nil || s.db == nil { + return nil + } + return s.db.Close() +} + func (s *Store) ensureBootstrapData(ctx context.Context) error { var count int if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM views`).Scan(&count); err != nil { @@ -291,6 +177,9 @@ func applyMigrations(db *sql.DB, migrations embed.FS) error { return err } if _, err := db.Exec(string(sqlBytes)); err != nil { + if strings.Contains(err.Error(), "duplicate column name") { + continue + } return fmt.Errorf("apply migration %s: %w", entry.Name(), err) } } @@ -362,923 +251,10 @@ type viewRow struct { UpdatedAt string } -func (s *Store) listViewRows(ctx context.Context) ([]viewRow, error) { - rows, err := s.db.QueryContext(ctx, `SELECT id, owner_element_id, name, description, level_label, level, created_at, updated_at FROM views ORDER BY id`) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - var out []viewRow - for rows.Next() { - var row viewRow - if err := rows.Scan(&row.ID, &row.OwnerElementID, &row.Name, &row.Description, &row.LevelLabel, &row.Level, &row.CreatedAt, &row.UpdatedAt); err != nil { - return nil, err - } - out = append(out, row) - } - return out, rows.Err() -} - -func (s *Store) parentViewForOwner(ctx context.Context, ownerElementID int64, excludeViewID int64) (*int64, error) { - row := s.db.QueryRowContext(ctx, `SELECT view_id FROM placements WHERE element_id = ? AND view_id != ? 
ORDER BY view_id LIMIT 1`, ownerElementID, excludeViewID) - var viewID int64 - if err := row.Scan(&viewID); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - return nil, err - } - return &viewID, nil -} - -func (s *Store) childViewMeta(ctx context.Context, elementID int64) (bool, *string, error) { - row := s.db.QueryRowContext(ctx, `SELECT level_label FROM views WHERE owner_element_id = ? ORDER BY id LIMIT 1`, elementID) - var label sql.NullString - if err := row.Scan(&label); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return false, nil, nil - } - return false, nil, err - } - if label.Valid { - return true, &label.String, nil - } - return true, nil, nil -} - -func viewNodeFromRow(row viewRow, parentID *int64, depth int) ViewTreeNode { - var ownerElementID *int64 - if row.OwnerElementID.Valid { - ownerElementID = new(row.OwnerElementID.Int64) - } - var description *string - if row.Description.Valid { - description = new(row.Description.String) - } - var levelLabel *string - if row.LevelLabel.Valid { - levelLabel = new(row.LevelLabel.String) - } - return ViewTreeNode{ - ID: row.ID, - OwnerElementID: ownerElementID, - Name: row.Name, - Description: description, - LevelLabel: levelLabel, - Level: row.Level, - Depth: depth, - CreatedAt: row.CreatedAt, - UpdatedAt: row.UpdatedAt, - ParentViewID: parentID, - Children: []ViewTreeNode{}, - } -} - -func (s *Store) ViewTree(ctx context.Context) ([]ViewTreeNode, error) { - rows, err := s.listViewRows(ctx) - if err != nil { - return nil, err - } - rowByID := make(map[int64]viewRow, len(rows)) - byParent := map[int64][]viewRow{} - var roots []viewRow - parentMap := map[int64]*int64{} - for _, row := range rows { - rowByID[row.ID] = row - if row.OwnerElementID.Valid { - parentID, err := s.parentViewForOwner(ctx, row.OwnerElementID.Int64, row.ID) - if err != nil { - return nil, err - } - parentMap[row.ID] = parentID - if parentID != nil { - byParent[*parentID] = append(byParent[*parentID], row) - continue - } - } - parentMap[row.ID] = nil - roots = append(roots, row) - } - visited := make(map[int64]bool, len(rows)) - var build func(row viewRow, depth int, stack map[int64]bool) ViewTreeNode - build = func(row viewRow, depth int, stack map[int64]bool) ViewTreeNode { - node := viewNodeFromRow(row, parentMap[row.ID], depth) - visited[row.ID] = true - if stack[row.ID] { - return node - } - nextStack := make(map[int64]bool, len(stack)+1) - maps.Copy(nextStack, stack) - nextStack[row.ID] = true - children := byParent[row.ID] - sort.Slice(children, func(i, j int) bool { return children[i].ID < children[j].ID }) - for _, child := range children { - if nextStack[child.ID] { - continue - } - node.Children = append(node.Children, build(child, depth+1, nextStack)) - } - return node - } - sort.Slice(roots, func(i, j int) bool { return roots[i].ID < roots[j].ID }) - out := make([]ViewTreeNode, 0, len(roots)) - for _, root := range roots { - out = append(out, build(root, 0, map[int64]bool{})) - } - if len(visited) < len(rows) { - remaining := make([]viewRow, 0, len(rows)-len(visited)) - for _, row := range rows { - if visited[row.ID] { - continue - } - remaining = append(remaining, rowByID[row.ID]) - } - sort.Slice(remaining, func(i, j int) bool { return remaining[i].ID < remaining[j].ID }) - for _, row := range remaining { - if visited[row.ID] { - continue - } - node := build(row, 0, map[int64]bool{}) - node.ParentViewID = nil - out = append(out, node) - } - } - return out, nil -} - -func flattenTree(nodes []ViewTreeNode) []ViewTreeNode 
{ - var out []ViewTreeNode - var walk func(items []ViewTreeNode) - walk = func(items []ViewTreeNode) { - for _, item := range items { - children := item.Children - item.Children = nil - out = append(out, item) - walk(children) - } - } - walk(nodes) - return out -} - -func (s *Store) Views(ctx context.Context) ([]ViewSummary, error) { - tree, err := s.ViewTree(ctx) - if err != nil { - return nil, err - } - flat := flattenTree(tree) - out := make([]ViewSummary, 0, len(flat)) - for _, node := range flat { - out = append(out, ViewSummary{ - ID: node.ID, - Name: node.Name, - Label: node.LevelLabel, - IsRoot: node.ParentViewID == nil, - CreatedAt: node.CreatedAt, - UpdatedAt: node.UpdatedAt, - }) - } - return out, nil -} - -func (s *Store) ViewByID(ctx context.Context, id int64) (ViewTreeNode, error) { - tree, err := s.ViewTree(ctx) - if err != nil { - return ViewTreeNode{}, err - } - for _, node := range flattenTree(tree) { - if node.ID == id { - return node, nil - } - } - return ViewTreeNode{}, sql.ErrNoRows -} - -func (s *Store) CreateView(ctx context.Context, name string, levelLabel *string, ownerElementID *int64) (ViewSummary, error) { - now := nowString() - level := 1 - if ownerElementID != nil { - parentID, err := s.parentViewForOwner(ctx, *ownerElementID, 0) - if err == nil && parentID != nil { - parent, err := s.ViewByID(ctx, *parentID) - if err == nil { - level = parent.Level + 1 - } - } - } - res, err := s.db.ExecContext(ctx, `INSERT INTO views(owner_element_id, name, description, level_label, level, created_at, updated_at) VALUES (?, ?, NULL, ?, ?, ?, ?)`, - ownerElementID, strings.TrimSpace(name), levelLabel, level, now, now) - if err != nil { - return ViewSummary{}, err - } - id, _ := res.LastInsertId() - view, err := s.ViewByID(ctx, id) - if err != nil { - return ViewSummary{}, err - } - return ViewSummary{ - ID: view.ID, - Name: view.Name, - Label: view.LevelLabel, - IsRoot: view.ParentViewID == nil, - CreatedAt: view.CreatedAt, - UpdatedAt: view.UpdatedAt, - }, nil -} - -func (s *Store) UpdateView(ctx context.Context, id int64, name *string, levelLabel *string) (ViewSummary, error) { - current, err := s.ViewByID(ctx, id) - if err != nil { - return ViewSummary{}, err - } - nextName := current.Name - if name != nil && strings.TrimSpace(*name) != "" { - nextName = strings.TrimSpace(*name) - } - _, err = s.db.ExecContext(ctx, `UPDATE views SET name = ?, level_label = ?, updated_at = ? WHERE id = ?`, nextName, levelLabel, nowString(), id) - if err != nil { - return ViewSummary{}, err - } - updated, err := s.ViewByID(ctx, id) - if err != nil { - return ViewSummary{}, err - } - return ViewSummary{ - ID: updated.ID, - Name: updated.Name, - Label: updated.LevelLabel, - IsRoot: updated.ParentViewID == nil, - CreatedAt: updated.CreatedAt, - UpdatedAt: updated.UpdatedAt, - }, nil -} - -func (s *Store) SetViewLevel(ctx context.Context, id int64, level int) error { - _, err := s.db.ExecContext(ctx, `UPDATE views SET level = ?, updated_at = ? 
WHERE id = ?`, level, nowString(), id) - return err -} - -func (s *Store) DeleteView(ctx context.Context, id int64) error { - _, err := s.db.ExecContext(ctx, `DELETE FROM views WHERE id = ?`, id) - return err -} - -func scanElement(row scanner, includeViewMeta bool, store *Store, ctx context.Context) (LibraryElement, error) { - var ( - elem LibraryElement - techRaw string - tagRaw string - kind sql.NullString - description sql.NullString - technology sql.NullString - url sql.NullString - logoURL sql.NullString - repo sql.NullString - branch sql.NullString - filePath sql.NullString - language sql.NullString - ) - if err := row.Scan(&elem.ID, &elem.Name, &kind, &description, &technology, &url, &logoURL, &techRaw, &tagRaw, &repo, &branch, &filePath, &language, &elem.CreatedAt, &elem.UpdatedAt); err != nil { - return LibraryElement{}, err - } - if kind.Valid { - elem.Kind = &kind.String - } - if description.Valid { - elem.Description = &description.String - } - if technology.Valid { - elem.Technology = &technology.String - } - if url.Valid { - elem.URL = &url.String - } - if logoURL.Valid { - elem.LogoURL = &logoURL.String - } - if repo.Valid { - elem.Repo = &repo.String - } - if branch.Valid { - elem.Branch = &branch.String - } - if filePath.Valid { - elem.FilePath = &filePath.String - } - if language.Valid { - elem.Language = &language.String - } - elem.TechnologyConnectors = parseTechnologyConnectors(techRaw) - elem.Tags = parseStrings(tagRaw) - if includeViewMeta { - hasView, label, err := store.childViewMeta(ctx, elem.ID) - if err != nil { - return LibraryElement{}, err - } - elem.HasView = hasView - elem.ViewLabel = label - } - return elem, nil -} - type scanner interface { Scan(dest ...any) error } -func (s *Store) Elements(ctx context.Context, limit, offset int, search string) ([]LibraryElement, error) { - type elementRow struct { - ID int64 - Name string - Kind sql.NullString - Description sql.NullString - Technology sql.NullString - URL sql.NullString - LogoURL sql.NullString - TechRaw string - TagRaw string - Repo sql.NullString - Branch sql.NullString - FilePath sql.NullString - Language sql.NullString - CreatedAt string - UpdatedAt string - } - - query := `SELECT id, name, kind, description, technology, url, logo_url, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at FROM elements` - args := []any{} - if strings.TrimSpace(search) != "" { - query += ` WHERE LOWER(name) LIKE LOWER(?) OR LOWER(COALESCE(description, '')) LIKE LOWER(?)` - pattern := "%" + strings.TrimSpace(search) + "%" - args = append(args, pattern, pattern) - } - query += ` ORDER BY updated_at DESC` - if limit > 0 { - query += ` LIMIT ? OFFSET ?` - args = append(args, limit, offset) - } - rows, err := s.db.QueryContext(ctx, query, args...) 
- if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - scanned := make([]elementRow, 0) - for rows.Next() { - var row elementRow - if err := rows.Scan( - &row.ID, - &row.Name, - &row.Kind, - &row.Description, - &row.Technology, - &row.URL, - &row.LogoURL, - &row.TechRaw, - &row.TagRaw, - &row.Repo, - &row.Branch, - &row.FilePath, - &row.Language, - &row.CreatedAt, - &row.UpdatedAt, - ); err != nil { - return nil, err - } - scanned = append(scanned, row) - } - if err := rows.Err(); err != nil { - return nil, err - } - - out := make([]LibraryElement, 0, len(scanned)) - for _, row := range scanned { - elem := LibraryElement{ - ID: row.ID, - Name: row.Name, - TechnologyConnectors: parseTechnologyConnectors(row.TechRaw), - Tags: parseStrings(row.TagRaw), - CreatedAt: row.CreatedAt, - UpdatedAt: row.UpdatedAt, - } - if row.Kind.Valid { - elem.Kind = &row.Kind.String - } - if row.Description.Valid { - elem.Description = &row.Description.String - } - if row.Technology.Valid { - elem.Technology = &row.Technology.String - } - if row.URL.Valid { - elem.URL = &row.URL.String - } - if row.LogoURL.Valid { - elem.LogoURL = &row.LogoURL.String - } - if row.Repo.Valid { - elem.Repo = &row.Repo.String - } - if row.Branch.Valid { - elem.Branch = &row.Branch.String - } - if row.FilePath.Valid { - elem.FilePath = &row.FilePath.String - } - if row.Language.Valid { - elem.Language = &row.Language.String - } - hasView, label, err := s.childViewMeta(ctx, elem.ID) - if err != nil { - return nil, err - } - elem.HasView = hasView - elem.ViewLabel = label - out = append(out, elem) - } - return out, nil -} - -func (s *Store) ElementByID(ctx context.Context, id int64) (LibraryElement, error) { - row := s.db.QueryRowContext(ctx, `SELECT id, name, kind, description, technology, url, logo_url, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at FROM elements WHERE id = ?`, id) - return scanElement(row, true, s, ctx) -} - -func (s *Store) CreateElement(ctx context.Context, input LibraryElement) (LibraryElement, error) { - if err := s.ensureTagColors(ctx, input.Tags); err != nil { - return LibraryElement{}, err - } - now := nowString() - res, err := s.db.ExecContext(ctx, ` - INSERT INTO elements(name, kind, description, technology, url, logo_url, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - strings.TrimSpace(input.Name), - input.Kind, - input.Description, - input.Technology, - input.URL, - input.LogoURL, - jsonString(input.TechnologyConnectors, "[]"), - jsonString(input.Tags, "[]"), - input.Repo, - input.Branch, - input.FilePath, - input.Language, - now, - now, - ) - if err != nil { - return LibraryElement{}, err - } - id, _ := res.LastInsertId() - return s.ElementByID(ctx, id) -} - -func (s *Store) UpdateElement(ctx context.Context, id int64, input LibraryElement) (LibraryElement, error) { - if input.Tags != nil { - if err := s.ensureTagColors(ctx, input.Tags); err != nil { - return LibraryElement{}, err - } - } - current, err := s.ElementByID(ctx, id) - if err != nil { - return LibraryElement{}, err - } - if input.Name == "" { - input.Name = current.Name - } - if input.Kind == nil { - input.Kind = current.Kind - } - if input.Description == nil { - input.Description = current.Description - } - if input.Technology == nil { - input.Technology = current.Technology - } - if input.URL == nil { - input.URL = current.URL - } - if input.LogoURL == nil { - input.LogoURL = 
current.LogoURL - } - if input.Repo == nil { - input.Repo = current.Repo - } - if input.Branch == nil { - input.Branch = current.Branch - } - if input.FilePath == nil { - input.FilePath = current.FilePath - } - if input.Language == nil { - input.Language = current.Language - } - if input.TechnologyConnectors == nil { - input.TechnologyConnectors = current.TechnologyConnectors - } - if input.Tags == nil { - input.Tags = current.Tags - } - _, err = s.db.ExecContext(ctx, ` - UPDATE elements SET name = ?, kind = ?, description = ?, technology = ?, url = ?, logo_url = ?, technology_connectors = ?, tags = ?, repo = ?, branch = ?, file_path = ?, language = ?, updated_at = ? - WHERE id = ?`, - input.Name, input.Kind, input.Description, input.Technology, input.URL, input.LogoURL, - jsonString(input.TechnologyConnectors, "[]"), jsonString(input.Tags, "[]"), - input.Repo, input.Branch, input.FilePath, input.Language, nowString(), id, - ) - if err != nil { - return LibraryElement{}, err - } - return s.ElementByID(ctx, id) -} - -func (s *Store) DeleteElement(ctx context.Context, id int64) error { - _, err := s.db.ExecContext(ctx, `DELETE FROM elements WHERE id = ?`, id) - return err -} - -func (s *Store) ListElementPlacements(ctx context.Context, elementID int64) ([]ViewPlacement, error) { - rows, err := s.db.QueryContext(ctx, ` - SELECT p.view_id, v.name - FROM placements p - JOIN views v ON v.id = p.view_id - WHERE p.element_id = ? - ORDER BY p.view_id`, elementID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - out := make([]ViewPlacement, 0) - for rows.Next() { - var placement ViewPlacement - if err := rows.Scan(&placement.ViewID, &placement.ViewName); err != nil { - return nil, err - } - out = append(out, placement) - } - return out, rows.Err() -} - -func (s *Store) Placements(ctx context.Context, viewID int64) ([]PlacedElement, error) { - type placementRow struct { - item PlacedElement - techRaw string - tagRaw string - } - - rows, err := s.db.QueryContext(ctx, ` - SELECT p.id, p.view_id, p.element_id, p.position_x, p.position_y, - e.name, e.kind, e.description, e.technology, e.url, e.logo_url, e.technology_connectors, e.tags, e.repo, e.branch, e.file_path, e.language, e.created_at, e.updated_at - FROM placements p - JOIN elements e ON e.id = p.element_id - WHERE p.view_id = ? 
- ORDER BY p.id`, viewID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - scanned := make([]placementRow, 0) - for rows.Next() { - var row placementRow - if err := rows.Scan(&row.item.ID, &row.item.ViewID, &row.item.ElementID, &row.item.PositionX, &row.item.PositionY, - &row.item.Name, &row.item.Kind, &row.item.Description, &row.item.Technology, &row.item.URL, &row.item.LogoURL, - &row.techRaw, &row.tagRaw, &row.item.Repo, &row.item.Branch, &row.item.FilePath, &row.item.Language, new(string), new(string)); err != nil { - return nil, err - } - scanned = append(scanned, row) - } - if err := rows.Err(); err != nil { - return nil, err - } - - out := make([]PlacedElement, 0, len(scanned)) - for _, row := range scanned { - item := row.item - item.TechnologyConnectors = parseTechnologyConnectors(row.techRaw) - item.Tags = parseStrings(row.tagRaw) - hasView, label, err := s.childViewMeta(ctx, item.ElementID) - if err != nil { - return nil, err - } - item.HasView = hasView - item.ViewLabel = label - out = append(out, item) - } - return out, nil -} - -func (s *Store) ElementPlacements(ctx context.Context, viewID int64) ([]ElementPlacement, error) { - rows, err := s.db.QueryContext(ctx, `SELECT id, view_id, element_id, position_x, position_y FROM placements WHERE view_id = ? ORDER BY id`, viewID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - out := make([]ElementPlacement, 0) - for rows.Next() { - var item ElementPlacement - if err := rows.Scan(&item.ID, &item.ViewID, &item.ElementID, &item.PositionX, &item.PositionY); err != nil { - return nil, err - } - out = append(out, item) - } - return out, rows.Err() -} - -func (s *Store) AddPlacement(ctx context.Context, viewID, elementID int64, x, y float64) (ElementPlacement, error) { - now := nowString() - _, err := s.db.ExecContext(ctx, ` - INSERT INTO placements(view_id, element_id, position_x, position_y, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(view_id, element_id) DO UPDATE SET position_x = excluded.position_x, position_y = excluded.position_y, updated_at = excluded.updated_at`, - viewID, elementID, x, y, now, now) - if err != nil { - return ElementPlacement{}, err - } - row := s.db.QueryRowContext(ctx, `SELECT id, view_id, element_id, position_x, position_y FROM placements WHERE view_id = ? AND element_id = ?`, viewID, elementID) - var item ElementPlacement - if err := row.Scan(&item.ID, &item.ViewID, &item.ElementID, &item.PositionX, &item.PositionY); err != nil { - return ElementPlacement{}, err - } - return item, nil -} - -func (s *Store) UpdatePlacement(ctx context.Context, viewID, elementID int64, x, y float64) error { - _, err := s.db.ExecContext(ctx, `UPDATE placements SET position_x = ?, position_y = ?, updated_at = ? WHERE view_id = ? AND element_id = ?`, x, y, nowString(), viewID, elementID) - return err -} - -func (s *Store) DeletePlacement(ctx context.Context, viewID, elementID int64) error { - _, err := s.db.ExecContext(ctx, `DELETE FROM placements WHERE view_id = ? AND element_id = ?`, viewID, elementID) - return err -} - -func (s *Store) Connectors(ctx context.Context, viewID int64) ([]Connector, error) { - rows, err := s.db.QueryContext(ctx, ` - SELECT id, view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at - FROM connectors WHERE view_id = ? 
ORDER BY id`, viewID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - out := make([]Connector, 0) - for rows.Next() { - var item Connector - if err := rows.Scan(&item.ID, &item.ViewID, &item.SourceElementID, &item.TargetElementID, &item.Label, &item.Description, &item.Relationship, &item.Direction, &item.Style, &item.URL, &item.SourceHandle, &item.TargetHandle, &item.CreatedAt, &item.UpdatedAt); err != nil { - return nil, err - } - out = append(out, item) - } - return out, rows.Err() -} - -func (s *Store) CreateConnector(ctx context.Context, input Connector) (Connector, error) { - now := nowString() - res, err := s.db.ExecContext(ctx, ` - INSERT INTO connectors(view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - input.ViewID, input.SourceElementID, input.TargetElementID, input.Label, input.Description, input.Relationship, - normalizeDirection(new(input.Direction)), input.Style, input.URL, input.SourceHandle, input.TargetHandle, now, now) - if err != nil { - return Connector{}, err - } - id, _ := res.LastInsertId() - return s.ConnectorByID(ctx, id) -} - -func (s *Store) ConnectorByID(ctx context.Context, id int64) (Connector, error) { - row := s.db.QueryRowContext(ctx, `SELECT id, view_id, source_element_id, target_element_id, label, description, relationship, direction, style, url, source_handle, target_handle, created_at, updated_at FROM connectors WHERE id = ?`, id) - var item Connector - if err := row.Scan(&item.ID, &item.ViewID, &item.SourceElementID, &item.TargetElementID, &item.Label, &item.Description, &item.Relationship, &item.Direction, &item.Style, &item.URL, &item.SourceHandle, &item.TargetHandle, &item.CreatedAt, &item.UpdatedAt); err != nil { - return Connector{}, err - } - return item, nil -} - -func (s *Store) UpdateConnector(ctx context.Context, id int64, patch Connector) (Connector, error) { - current, err := s.ConnectorByID(ctx, id) - if err != nil { - return Connector{}, err - } - if patch.SourceElementID == 0 { - patch.SourceElementID = current.SourceElementID - } - if patch.TargetElementID == 0 { - patch.TargetElementID = current.TargetElementID - } - if patch.ViewID == 0 { - patch.ViewID = current.ViewID - } - if patch.Direction == "" { - patch.Direction = current.Direction - } - if patch.Style == "" { - patch.Style = current.Style - } - if patch.Label == nil { - patch.Label = current.Label - } - if patch.Description == nil { - patch.Description = current.Description - } - if patch.Relationship == nil { - patch.Relationship = current.Relationship - } - if patch.URL == nil { - patch.URL = current.URL - } - if patch.SourceHandle == nil { - patch.SourceHandle = current.SourceHandle - } - if patch.TargetHandle == nil { - patch.TargetHandle = current.TargetHandle - } - _, err = s.db.ExecContext(ctx, ` - UPDATE connectors SET source_element_id = ?, target_element_id = ?, label = ?, description = ?, relationship = ?, direction = ?, style = ?, url = ?, source_handle = ?, target_handle = ?, updated_at = ? 
- WHERE id = ?`, - patch.SourceElementID, patch.TargetElementID, patch.Label, patch.Description, patch.Relationship, patch.Direction, patch.Style, patch.URL, patch.SourceHandle, patch.TargetHandle, nowString(), id) - if err != nil { - return Connector{}, err - } - return s.ConnectorByID(ctx, id) -} - -func (s *Store) DeleteConnector(ctx context.Context, id int64) error { - _, err := s.db.ExecContext(ctx, `DELETE FROM connectors WHERE id = ?`, id) - return err -} - -func (s *Store) Layers(ctx context.Context, viewID int64) ([]ViewLayer, error) { - rows, err := s.db.QueryContext(ctx, `SELECT id, view_id, name, tags, color, created_at, updated_at FROM view_layers WHERE view_id = ? ORDER BY id`, viewID) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - out := make([]ViewLayer, 0) - for rows.Next() { - var rawTags string - var item ViewLayer - if err := rows.Scan(&item.ID, &item.DiagramID, &item.Name, &rawTags, &item.Color, &item.CreatedAt, &item.UpdatedAt); err != nil { - return nil, err - } - item.Tags = parseStrings(rawTags) - out = append(out, item) - } - return out, rows.Err() -} - -func (s *Store) CreateLayer(ctx context.Context, viewID int64, name string, tags []string, color *string) (ViewLayer, error) { - if err := s.ensureTagColors(ctx, tags); err != nil { - return ViewLayer{}, err - } - - if color == nil || strings.TrimSpace(*color) == "" { - // User said pick unused, usually means relative to existing layers in the same view or global tags. - // Frontend uses tagColors. - tagsMap, err := s.Tags(ctx) - if err != nil { - return ViewLayer{}, err - } - var usedColors []string - for _, t := range tagsMap { - usedColors = append(usedColors, t.Color) - } - // Also consider existing layers colors - layers, err := s.Layers(ctx, viewID) - if err == nil { - for _, l := range layers { - if l.Color != nil { - usedColors = append(usedColors, *l.Color) - } - } - } - c := s.pickUnusedColor(ctx, usedColors) - color = &c - } - - now := nowString() - res, err := s.db.ExecContext(ctx, `INSERT INTO view_layers(view_id, name, tags, color, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, - viewID, name, jsonString(tags, "[]"), color, now, now) - if err != nil { - return ViewLayer{}, err - } - id, _ := res.LastInsertId() - return s.LayerByID(ctx, id) -} - -func (s *Store) LayerByID(ctx context.Context, id int64) (ViewLayer, error) { - row := s.db.QueryRowContext(ctx, `SELECT id, view_id, name, tags, color, created_at, updated_at FROM view_layers WHERE id = ?`, id) - var rawTags string - var item ViewLayer - if err := row.Scan(&item.ID, &item.DiagramID, &item.Name, &rawTags, &item.Color, &item.CreatedAt, &item.UpdatedAt); err != nil { - return ViewLayer{}, err - } - item.Tags = parseStrings(rawTags) - return item, nil -} - -func (s *Store) UpdateLayer(ctx context.Context, id int64, patch ViewLayer) (ViewLayer, error) { - current, err := s.LayerByID(ctx, id) - if err != nil { - return ViewLayer{}, err - } - if patch.Name == "" { - patch.Name = current.Name - } - if patch.Tags == nil { - patch.Tags = current.Tags - } - if patch.Color == nil { - patch.Color = current.Color - } - _, err = s.db.ExecContext(ctx, `UPDATE view_layers SET name = ?, tags = ?, color = ?, updated_at = ? 
WHERE id = ?`, patch.Name, jsonString(patch.Tags, "[]"), patch.Color, nowString(), id) - if err != nil { - return ViewLayer{}, err - } - return s.LayerByID(ctx, id) -} - -func (s *Store) DeleteLayer(ctx context.Context, id int64) error { - _, err := s.db.ExecContext(ctx, `DELETE FROM view_layers WHERE id = ?`, id) - return err -} - -func (s *Store) Tags(ctx context.Context) (map[string]Tag, error) { - rows, err := s.db.QueryContext(ctx, `SELECT name, color, description FROM tags ORDER BY name`) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - out := map[string]Tag{} - for rows.Next() { - var tag Tag - if err := rows.Scan(&tag.Name, &tag.Color, &tag.Description); err != nil { - return nil, err - } - out[tag.Name] = tag - } - return out, rows.Err() -} - -func (s *Store) UpdateTag(ctx context.Context, name, color string, description *string) error { - _, err := s.db.ExecContext(ctx, ` - INSERT INTO tags(name, color, description) VALUES (?, ?, ?) - ON CONFLICT(name) DO UPDATE SET color = excluded.color, description = excluded.description`, - name, color, description) - return err -} - -func (s *Store) ListElementNavigations(ctx context.Context, elementID int64, fromViewID, toViewID *int64) ([]ViewConnector, error) { - row := s.db.QueryRowContext(ctx, `SELECT id, name FROM views WHERE owner_element_id = ? ORDER BY id LIMIT 1`, elementID) - var childViewID int64 - var childViewName string - if err := row.Scan(&childViewID, &childViewName); err != nil { - if errors.Is(err, sql.ErrNoRows) { - return []ViewConnector{}, nil - } - return nil, err - } - parentID, err := s.parentViewForOwner(ctx, elementID, childViewID) - if err != nil { - return nil, err - } - out := make([]ViewConnector, 0, 1) - if fromViewID != nil && *fromViewID > 0 { - if parentID != nil && *parentID == *fromViewID { - out = append(out, ViewConnector{ID: 0, ElementID: &elementID, FromViewID: *fromViewID, ToViewID: childViewID, ToViewName: childViewName, RelationType: "child"}) - } - return out, nil - } - if toViewID != nil && *toViewID > 0 && parentID != nil && *toViewID == childViewID { - out = append(out, ViewConnector{ID: 0, ElementID: &elementID, FromViewID: *parentID, ToViewID: childViewID, ToViewName: childViewName, RelationType: "child"}) - } - return out, nil -} - -func (s *Store) ListIncomingNavigations(ctx context.Context, viewID int64) ([]IncomingViewConnector, error) { - view, err := s.ViewByID(ctx, viewID) - if err != nil { - return nil, err - } - if view.OwnerElementID == nil || view.ParentViewID == nil { - return []IncomingViewConnector{}, nil - } - element, err := s.ElementByID(ctx, *view.OwnerElementID) - if err != nil { - return nil, err - } - parent, err := s.ViewByID(ctx, *view.ParentViewID) - if err != nil { - return nil, err - } - return []IncomingViewConnector{{ - ID: 0, - ElementID: *view.OwnerElementID, - ElementName: element.Name, - FromViewID: parent.ID, - FromViewName: parent.Name, - ToViewID: view.ID, - }}, nil -} - func (s *Store) Explore(ctx context.Context) (ExploreData, error) { tree, err := s.ViewTree(ctx) if err != nil { @@ -1314,7 +290,7 @@ func (s *Store) Explore(ctx context.Context) (ExploreData, error) { } func (s *Store) Dependencies(ctx context.Context) (map[string]any, error) { - elements, err := s.Elements(ctx, 0, 0, "") + elements, _, err := s.Elements(ctx, 0, 0, "") if err != nil { return nil, err } @@ -1504,62 +480,3 @@ func htmlEscape(value string) string { replacer := strings.NewReplacer("&", "&amp;", "<", "&lt;", ">", "&gt;", `"`, "&quot;") return
replacer.Replace(value) } - -var SWATCH_COLORS = []string{ - "#F56565", "#ED8936", "#ECC94B", "#48BB78", "#38B2AC", - "#4299E1", "#667EEA", "#9F7AEA", "#ED64A6", "#A0AEC0", -} - -func (s *Store) pickUnusedColor(ctx context.Context, usedColors []string) string { - used := make(map[string]bool) - for _, c := range usedColors { - used[strings.ToUpper(c)] = true - } - - var pool []string - for _, c := range SWATCH_COLORS { - if !used[strings.ToUpper(c)] { - pool = append(pool, c) - } - } - - source := pool - if len(source) == 0 { - source = SWATCH_COLORS - } - - return source[time.Now().UnixNano()%int64(len(source))] -} - -func (s *Store) ensureTagColors(ctx context.Context, tags []string) error { - if len(tags) == 0 { - return nil - } - - existing, err := s.Tags(ctx) - if err != nil { - return err - } - - var usedColors []string - for _, t := range existing { - usedColors = append(usedColors, t.Color) - } - - for _, name := range tags { - name = strings.TrimSpace(name) - if name == "" { - continue - } - if _, ok := existing[name]; !ok { - color := s.pickUnusedColor(ctx, usedColors) - if err := s.UpdateTag(ctx, name, color, nil); err != nil { - return err - } - usedColors = append(usedColors, color) - // Refresh existing to avoid re-adding same tag if it appears twice in the list - existing[name] = Tag{Name: name, Color: color} - } - } - return nil -} diff --git a/internal/app/store_test.go b/internal/app/store_test.go new file mode 100644 index 0000000..3e68dd3 --- /dev/null +++ b/internal/app/store_test.go @@ -0,0 +1,259 @@ +package app + +import ( + "context" + "database/sql" + "fmt" + "path/filepath" + "strings" + "testing" + + assets "github.com/mertcikla/tld" + "github.com/mertcikla/tld/internal/tagcolors" +) + +func TestConfigureSQLiteDBEnablesBusyTimeoutAndWAL(t *testing.T) { + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), "tld.db")) + if err != nil { + t.Fatal(err) + } + defer func() { _ = db.Close() }() + + if err := configureSQLiteDB(db); err != nil { + t.Fatal(err) + } + + var busyTimeout int + if err := db.QueryRow(`PRAGMA busy_timeout;`).Scan(&busyTimeout); err != nil { + t.Fatal(err) + } + if busyTimeout != 5000 { + t.Fatalf("busy_timeout = %d, want 5000", busyTimeout) + } + + var journalMode string + if err := db.QueryRow(`PRAGMA journal_mode;`).Scan(&journalMode); err != nil { + t.Fatal(err) + } + if strings.ToLower(journalMode) != "wal" { + t.Fatalf("journal_mode = %q, want wal", journalMode) + } +} + +func TestStoreElementsSearchPaginationAndViewMetadata(t *testing.T) { + store := openAppStore(t) + ctx := context.Background() + + serviceKind := "service" + api, err := store.CreateElement(ctx, LibraryElement{Name: "API", Kind: &serviceKind, Description: new("Public runtime API"), Tags: []string{"runtime"}}) + if err != nil { + t.Fatal(err) + } + worker, err := store.CreateElement(ctx, LibraryElement{Name: "Worker", Kind: &serviceKind, Description: new("Background jobs"), Tags: []string{"runtime"}}) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateView(ctx, "API detail", new("Service"), &api.ID); err != nil { + t.Fatal(err) + } + + results, total, err := store.Elements(ctx, 1, 0, "runtime") + if err != nil { + t.Fatal(err) + } + if total != 1 || len(results) != 1 || results[0].ID != api.ID { + t.Fatalf("search results = total:%d elements:%+v, want only API", total, results) + } + if !results[0].HasView || results[0].ViewLabel == nil || *results[0].ViewLabel != "Service" { + t.Fatalf("view metadata = has:%v label:%v, want Service child view", 
results[0].HasView, results[0].ViewLabel) + } + + results, total, err = store.Elements(ctx, 1, 1, "") + if err != nil { + t.Fatal(err) + } + if total != 2 || len(results) != 1 || results[0].ID != api.ID { + t.Fatalf("paginated results = total:%d elements:%+v, want second inserted API after Worker", total, results) + } + + tags, err := store.Tags(ctx) + if err != nil { + t.Fatal(err) + } + if _, ok := tags["runtime"]; !ok { + t.Fatalf("tags = %+v, want runtime tag color created with element", tags) + } + _ = worker +} + +func TestStoreConnectorsPreserveHandlesAndPatchDefaults(t *testing.T) { + store := openAppStore(t) + ctx := context.Background() + + source, err := store.CreateElement(ctx, LibraryElement{Name: "API"}) + if err != nil { + t.Fatal(err) + } + target, err := store.CreateElement(ctx, LibraryElement{Name: "DB"}) + if err != nil { + t.Fatal(err) + } + label := "reads" + sourceHandle := "right" + targetHandle := "left" + connector, err := store.CreateConnector(ctx, Connector{ + ViewID: 1, + SourceElementID: source.ID, + TargetElementID: target.ID, + Label: &label, + Style: "bezier", + SourceHandle: &sourceHandle, + TargetHandle: &targetHandle, + }) + if err != nil { + t.Fatal(err) + } + if connector.Direction != "forward" { + t.Fatalf("direction = %q, want forward default", connector.Direction) + } + if connector.SourceHandle == nil || *connector.SourceHandle != "right" || connector.TargetHandle == nil || *connector.TargetHandle != "left" { + t.Fatalf("handles = %v/%v, want right/left", connector.SourceHandle, connector.TargetHandle) + } + + updatedLabel := "streams" + updated, err := store.UpdateConnector(ctx, connector.ID, Connector{Label: &updatedLabel}) + if err != nil { + t.Fatal(err) + } + if updated.Label == nil || *updated.Label != "streams" { + t.Fatalf("label = %v, want streams", updated.Label) + } + if updated.SourceElementID != source.ID || updated.TargetElementID != target.ID || updated.Style != "bezier" || updated.Direction != "forward" { + t.Fatalf("patched connector lost defaults or endpoints: %+v", updated) + } + if updated.SourceHandle == nil || *updated.SourceHandle != "right" || updated.TargetHandle == nil || *updated.TargetHandle != "left" { + t.Fatalf("patched handles = %v/%v, want right/left", updated.SourceHandle, updated.TargetHandle) + } +} + +func TestStoreLayersPersistTagsColorsAndUpdates(t *testing.T) { + store := openAppStore(t) + ctx := context.Background() + + layer, err := store.CreateLayer(ctx, 1, "Runtime", []string{"api", "db"}, nil) + if err != nil { + t.Fatal(err) + } + if layer.Color == nil || *layer.Color == "" { + t.Fatalf("layer color = %v, want generated color", layer.Color) + } + if strings.Join(layer.Tags, ",") != "api,db" { + t.Fatalf("layer tags = %+v, want api,db", layer.Tags) + } + + color := "#123456" + updated, err := store.UpdateLayer(ctx, layer.ID, ViewLayer{Name: "Data", Tags: []string{"db"}, Color: &color}) + if err != nil { + t.Fatal(err) + } + if updated.Name != "Data" || updated.Color == nil || *updated.Color != color || strings.Join(updated.Tags, ",") != "db" { + t.Fatalf("updated layer = %+v, want Data/db/%s", updated, color) + } + + tags, err := store.Tags(ctx) + if err != nil { + t.Fatal(err) + } + if _, ok := tags["api"]; !ok { + t.Fatalf("tags = %+v, want api tag retained", tags) + } + if _, ok := tags["db"]; !ok { + t.Fatalf("tags = %+v, want db tag retained", tags) + } + + updated, err = store.UpdateLayer(ctx, layer.ID, ViewLayer{Name: "Data", Tags: []string{"queue"}, Color: &color}) + if err != nil { + t.Fatal(err) 
+ } + tags, err = store.Tags(ctx) + if err != nil { + t.Fatal(err) + } + if tag, ok := tags["queue"]; !ok || tag.Color == "" { + t.Fatalf("tags = %+v, want queue tag with generated color after layer update", tags) + } + _ = updated +} + +func TestStoreAutoTagColorsPreserveUserMetadata(t *testing.T) { + store := openAppStore(t) + ctx := context.Background() + + description := "User chosen tag" + if err := store.UpdateTag(ctx, "runtime", "#123456", &description); err != nil { + t.Fatal(err) + } + if _, err := store.CreateElement(ctx, LibraryElement{Name: "API", Tags: []string{"runtime", "worker", "api"}}); err != nil { + t.Fatal(err) + } + + tags, err := store.Tags(ctx) + if err != nil { + t.Fatal(err) + } + runtime := tags["runtime"] + if runtime.Color != "#123456" || runtime.Description == nil || *runtime.Description != description { + t.Fatalf("runtime tag = %+v, want user metadata preserved", runtime) + } + if tags["worker"].Color == "" || tags["api"].Color == "" { + t.Fatalf("tags = %+v, want generated colors for new tags", tags) + } + if tags["worker"].Color == tags["api"].Color { + t.Fatalf("worker/api colors both %q, want unused colors preferred", tags["worker"].Color) + } +} + +func TestStoreAutoTagColorsGenerateUnusedColorsAfterSwatchesAreExhausted(t *testing.T) { + store := openAppStore(t) + ctx := context.Background() + + for i, color := range tagcolors.SwatchColors { + if err := store.UpdateTag(ctx, fmt.Sprintf("existing-%d", i), color, nil); err != nil { + t.Fatal(err) + } + } + if _, err := store.CreateElement(ctx, LibraryElement{Name: "Worker", Tags: []string{"generated-a", "generated-b", "generated-c"}}); err != nil { + t.Fatal(err) + } + + tags, err := store.Tags(ctx) + if err != nil { + t.Fatal(err) + } + seen := map[string]string{} + for _, name := range []string{"generated-a", "generated-b", "generated-c"} { + tag := tags[name] + if tag.Color == "" { + t.Fatalf("%s color is empty", name) + } + if existing := seen[tag.Color]; existing != "" { + t.Fatalf("%s and %s both use %s, want generated fallback colors to stay unused", existing, name, tag.Color) + } + seen[tag.Color] = name + for _, swatch := range tagcolors.SwatchColors { + if strings.EqualFold(tag.Color, swatch) { + t.Fatalf("%s color = %s, want non-swatch fallback after swatches exhausted", name, tag.Color) + } + } + } +} + +func openAppStore(t *testing.T) *Store { + t.Helper() + store, err := OpenStore(filepath.Join(t.TempDir(), "tld.db"), assets.FS) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = store.Close() }) + return store +} diff --git a/internal/app/tag.go b/internal/app/tag.go new file mode 100644 index 0000000..2200488 --- /dev/null +++ b/internal/app/tag.go @@ -0,0 +1,145 @@ +package app + +import ( + "context" + "strings" + + "github.com/mertcikla/tld/internal/tagcolors" +) + +type Tag struct { + Name string `json:"name"` + Color string `json:"color"` + Description *string `json:"description"` +} + +func (s *Store) Layers(ctx context.Context, viewID int64) ([]ViewLayer, error) { + rows, err := s.db.QueryContext(ctx, `SELECT id, view_id, name, tags, color, created_at, updated_at FROM view_layers WHERE view_id = ? 
ORDER BY id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make([]ViewLayer, 0) + for rows.Next() { + var rawTags string + var item ViewLayer + if err := rows.Scan(&item.ID, &item.DiagramID, &item.Name, &rawTags, &item.Color, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, err + } + item.Tags = parseStrings(rawTags) + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) CreateLayer(ctx context.Context, viewID int64, name string, tags []string, color *string) (ViewLayer, error) { + if err := s.ensureTagColors(ctx, tags); err != nil { + return ViewLayer{}, err + } + + if color == nil || strings.TrimSpace(*color) == "" { + // User said pick unused, usually means relative to existing layers in the same view or global tags. + // Frontend uses tagColors. + tagsMap, err := s.Tags(ctx) + if err != nil { + return ViewLayer{}, err + } + var usedColors []string + for _, t := range tagsMap { + usedColors = append(usedColors, t.Color) + } + // Also consider existing layers colors + layers, err := s.Layers(ctx, viewID) + if err == nil { + for _, l := range layers { + if l.Color != nil { + usedColors = append(usedColors, *l.Color) + } + } + } + c := s.pickUnusedColor(ctx, usedColors) + color = &c + } + + now := nowString() + res, err := s.db.ExecContext(ctx, `INSERT INTO view_layers(view_id, name, tags, color, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, + viewID, name, jsonString(tags, "[]"), color, now, now) + if err != nil { + return ViewLayer{}, err + } + id, _ := res.LastInsertId() + return s.LayerByID(ctx, id) +} + +func (s *Store) LayerByID(ctx context.Context, id int64) (ViewLayer, error) { + row := s.db.QueryRowContext(ctx, `SELECT id, view_id, name, tags, color, created_at, updated_at FROM view_layers WHERE id = ?`, id) + var rawTags string + var item ViewLayer + if err := row.Scan(&item.ID, &item.DiagramID, &item.Name, &rawTags, &item.Color, &item.CreatedAt, &item.UpdatedAt); err != nil { + return ViewLayer{}, err + } + item.Tags = parseStrings(rawTags) + return item, nil +} + +func (s *Store) UpdateLayer(ctx context.Context, id int64, patch ViewLayer) (ViewLayer, error) { + current, err := s.LayerByID(ctx, id) + if err != nil { + return ViewLayer{}, err + } + if patch.Name == "" { + patch.Name = current.Name + } + if patch.Tags == nil { + patch.Tags = current.Tags + } + if err := s.ensureTagColors(ctx, patch.Tags); err != nil { + return ViewLayer{}, err + } + if patch.Color == nil { + patch.Color = current.Color + } + _, err = s.db.ExecContext(ctx, `UPDATE view_layers SET name = ?, tags = ?, color = ?, updated_at = ? 
WHERE id = ?`, patch.Name, jsonString(patch.Tags, "[]"), patch.Color, nowString(), id) + if err != nil { + return ViewLayer{}, err + } + return s.LayerByID(ctx, id) +} + +func (s *Store) DeleteLayer(ctx context.Context, id int64) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM view_layers WHERE id = ?`, id) + return err +} + +func (s *Store) Tags(ctx context.Context) (map[string]Tag, error) { + rows, err := s.db.QueryContext(ctx, `SELECT name, color, description FROM tags ORDER BY name`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := map[string]Tag{} + for rows.Next() { + var tag Tag + if err := rows.Scan(&tag.Name, &tag.Color, &tag.Description); err != nil { + return nil, err + } + out[tag.Name] = tag + } + return out, rows.Err() +} + +func (s *Store) UpdateTag(ctx context.Context, name, color string, description *string) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO tags(name, color, description) VALUES (?, ?, ?) + ON CONFLICT(name) DO UPDATE SET color = excluded.color, description = excluded.description`, + name, color, description) + return err +} + +func (s *Store) pickUnusedColor(ctx context.Context, usedColors []string) string { + return tagcolors.PickUnusedColor(usedColors) +} + +func (s *Store) ensureTagColors(ctx context.Context, tags []string) error { + return tagcolors.Ensure(ctx, s.db, tags) +} diff --git a/internal/app/view.go b/internal/app/view.go new file mode 100644 index 0000000..890f4a6 --- /dev/null +++ b/internal/app/view.go @@ -0,0 +1,528 @@ +package app + +import ( + "context" + "database/sql" + "errors" + "maps" + "sort" + "strings" +) + +type ViewTreeNode struct { + ID int64 `json:"id"` + OwnerElementID *int64 `json:"owner_element_id"` + Name string `json:"name"` + Description *string `json:"description"` + LevelLabel *string `json:"level_label"` + Level int `json:"level"` + Depth int `json:"depth"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + ParentViewID *int64 `json:"parent_view_id"` + Children []ViewTreeNode `json:"children"` +} + +type ViewSummary struct { + ID int64 `json:"id"` + Name string `json:"name"` + Label *string `json:"label"` + IsRoot bool `json:"is_root"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type ViewConnector struct { + ID int64 `json:"id"` + ElementID *int64 `json:"element_id"` + FromViewID int64 `json:"from_view_id"` + ToViewID int64 `json:"to_view_id"` + ToViewName string `json:"to_view_name"` + RelationType string `json:"relation_type"` +} + +type IncomingViewConnector struct { + ID int64 `json:"id"` + ElementID int64 `json:"element_id"` + ElementName string `json:"element_name"` + FromViewID int64 `json:"from_view_id"` + FromViewName string `json:"from_view_name"` + ToViewID int64 `json:"to_view_id"` +} + +type ViewPlacement struct { + ViewID int64 `json:"view_id"` + ViewName string `json:"view_name"` +} + +type ViewLayer struct { + ID int64 `json:"id"` + DiagramID int64 `json:"diagram_id"` + Name string `json:"name"` + Tags []string `json:"tags"` + Color *string `json:"color,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +func (s *Store) listViewRows(ctx context.Context) ([]viewRow, error) { + rows, err := s.db.QueryContext(ctx, `SELECT id, owner_element_id, name, description, level_label, level, created_at, updated_at FROM views ORDER BY id`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out 
[]viewRow + for rows.Next() { + var row viewRow + if err := rows.Scan(&row.ID, &row.OwnerElementID, &row.Name, &row.Description, &row.LevelLabel, &row.Level, &row.CreatedAt, &row.UpdatedAt); err != nil { + return nil, err + } + out = append(out, row) + } + return out, rows.Err() +} + +func (s *Store) parentViewForOwner(ctx context.Context, ownerElementID int64, excludeViewID int64) (*int64, error) { + row := s.db.QueryRowContext(ctx, `SELECT view_id FROM placements WHERE element_id = ? AND view_id != ? ORDER BY view_id LIMIT 1`, ownerElementID, excludeViewID) + var viewID int64 + if err := row.Scan(&viewID); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + return &viewID, nil +} + +func (s *Store) parentViewMap(ctx context.Context, rows []viewRow) (map[int64]*int64, error) { + ownerViewIDs := make(map[int64][]int64, len(rows)) + parentMap := make(map[int64]*int64, len(rows)) + for _, row := range rows { + parentMap[row.ID] = nil + if row.OwnerElementID.Valid { + ownerViewIDs[row.OwnerElementID.Int64] = append(ownerViewIDs[row.OwnerElementID.Int64], row.ID) + } + } + if len(ownerViewIDs) == 0 { + return parentMap, nil + } + + placementRows, err := s.db.QueryContext(ctx, ` + SELECT DISTINCT p.element_id, p.view_id + FROM placements p + JOIN views v ON v.owner_element_id = p.element_id + ORDER BY p.element_id, p.view_id`) + if err != nil { + return nil, err + } + defer func() { _ = placementRows.Close() }() + for placementRows.Next() { + var elementID, parentID int64 + if err := placementRows.Scan(&elementID, &parentID); err != nil { + return nil, err + } + for _, childID := range ownerViewIDs[elementID] { + if parentID == childID || parentMap[childID] != nil { + continue + } + pid := parentID + parentMap[childID] = &pid + } + } + return parentMap, placementRows.Err() +} + +func (s *Store) childViewMeta(ctx context.Context, elementID int64) (bool, *string, error) { + row := s.db.QueryRowContext(ctx, `SELECT level_label FROM views WHERE owner_element_id = ? 
ORDER BY id LIMIT 1`, elementID) + var label sql.NullString + if err := row.Scan(&label); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil, nil + } + return false, nil, err + } + if label.Valid { + return true, &label.String, nil + } + return true, nil, nil +} + +type childViewMetaValue struct { + hasView bool + label *string +} + +func (s *Store) childViewMetaMap(ctx context.Context) (map[int64]childViewMetaValue, error) { + rows, err := s.db.QueryContext(ctx, `SELECT owner_element_id, level_label FROM views WHERE owner_element_id IS NOT NULL ORDER BY id`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := map[int64]childViewMetaValue{} + for rows.Next() { + var elementID int64 + var label sql.NullString + if err := rows.Scan(&elementID, &label); err != nil { + return nil, err + } + if _, exists := out[elementID]; exists { + continue + } + meta := childViewMetaValue{hasView: true} + if label.Valid { + labelCopy := label.String + meta.label = &labelCopy + } + out[elementID] = meta + } + return out, rows.Err() +} + +func viewNodeFromRow(row viewRow, parentID *int64, depth int) ViewTreeNode { + var ownerElementID *int64 + if row.OwnerElementID.Valid { + ownerElementID = new(row.OwnerElementID.Int64) + } + var description *string + if row.Description.Valid { + description = new(row.Description.String) + } + var levelLabel *string + if row.LevelLabel.Valid { + levelLabel = new(row.LevelLabel.String) + } + return ViewTreeNode{ + ID: row.ID, + OwnerElementID: ownerElementID, + Name: row.Name, + Description: description, + LevelLabel: levelLabel, + Level: row.Level, + Depth: depth, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + ParentViewID: parentID, + Children: []ViewTreeNode{}, + } +} + +func (s *Store) ViewTree(ctx context.Context) ([]ViewTreeNode, error) { + rows, err := s.listViewRows(ctx) + if err != nil { + return nil, err + } + parentMap, err := s.parentViewMap(ctx, rows) + if err != nil { + return nil, err + } + rowByID := make(map[int64]viewRow, len(rows)) + byParent := map[int64][]viewRow{} + var roots []viewRow + for _, row := range rows { + rowByID[row.ID] = row + if parentID := parentMap[row.ID]; parentID != nil { + byParent[*parentID] = append(byParent[*parentID], row) + continue + } + roots = append(roots, row) + } + visited := make(map[int64]bool, len(rows)) + var build func(row viewRow, depth int, stack map[int64]bool) ViewTreeNode + build = func(row viewRow, depth int, stack map[int64]bool) ViewTreeNode { + node := viewNodeFromRow(row, parentMap[row.ID], depth) + visited[row.ID] = true + if stack[row.ID] { + return node + } + nextStack := make(map[int64]bool, len(stack)+1) + maps.Copy(nextStack, stack) + nextStack[row.ID] = true + children := byParent[row.ID] + sort.Slice(children, func(i, j int) bool { return children[i].ID < children[j].ID }) + for _, child := range children { + if nextStack[child.ID] { + continue + } + node.Children = append(node.Children, build(child, depth+1, nextStack)) + } + return node + } + sort.Slice(roots, func(i, j int) bool { return roots[i].ID < roots[j].ID }) + out := make([]ViewTreeNode, 0, len(roots)) + for _, root := range roots { + out = append(out, build(root, 0, map[int64]bool{})) + } + if len(visited) < len(rows) { + remaining := make([]viewRow, 0, len(rows)-len(visited)) + for _, row := range rows { + if visited[row.ID] { + continue + } + remaining = append(remaining, rowByID[row.ID]) + } + sort.Slice(remaining, func(i, j int) bool { return remaining[i].ID < 
remaining[j].ID }) + for _, row := range remaining { + if visited[row.ID] { + continue + } + node := build(row, 0, map[int64]bool{}) + node.ParentViewID = nil + out = append(out, node) + } + } + return out, nil +} + +func flattenTree(nodes []ViewTreeNode) []ViewTreeNode { + var out []ViewTreeNode + var walk func(items []ViewTreeNode) + walk = func(items []ViewTreeNode) { + for _, item := range items { + children := item.Children + item.Children = nil + out = append(out, item) + walk(children) + } + } + walk(nodes) + return out +} + +func (s *Store) Views(ctx context.Context) ([]ViewSummary, error) { + tree, err := s.ViewTree(ctx) + if err != nil { + return nil, err + } + flat := flattenTree(tree) + out := make([]ViewSummary, 0, len(flat)) + for _, node := range flat { + out = append(out, ViewSummary{ + ID: node.ID, + Name: node.Name, + Label: node.LevelLabel, + IsRoot: node.ParentViewID == nil, + CreatedAt: node.CreatedAt, + UpdatedAt: node.UpdatedAt, + }) + } + return out, nil +} + +func (s *Store) ViewByID(ctx context.Context, id int64) (ViewTreeNode, error) { + row := s.db.QueryRowContext(ctx, `SELECT id, owner_element_id, name, description, level_label, level, created_at, updated_at FROM views WHERE id = ?`, id) + var view viewRow + if err := row.Scan(&view.ID, &view.OwnerElementID, &view.Name, &view.Description, &view.LevelLabel, &view.Level, &view.CreatedAt, &view.UpdatedAt); err != nil { + return ViewTreeNode{}, err + } + var parentID *int64 + var err error + if view.OwnerElementID.Valid { + parentID, err = s.parentViewForOwner(ctx, view.OwnerElementID.Int64, view.ID) + if err != nil { + return ViewTreeNode{}, err + } + } + return viewNodeFromRow(view, parentID, 0), nil +} + +func (s *Store) ChildViews(ctx context.Context, parentViewID int64) ([]ViewTreeNode, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT DISTINCT v.id, v.owner_element_id, v.name, v.description, v.level_label, v.level, v.created_at, v.updated_at + FROM views v + JOIN placements p ON p.element_id = v.owner_element_id + WHERE p.view_id = ? AND v.id != ? 
+ ORDER BY v.id`, parentViewID, parentViewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := []ViewTreeNode{} + for rows.Next() { + var row viewRow + if err := rows.Scan(&row.ID, &row.OwnerElementID, &row.Name, &row.Description, &row.LevelLabel, &row.Level, &row.CreatedAt, &row.UpdatedAt); err != nil { + return nil, err + } + parentID := parentViewID + out = append(out, viewNodeFromRow(row, &parentID, 0)) + } + return out, rows.Err() +} + +func (s *Store) RootViews(ctx context.Context) ([]ViewTreeNode, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT v.id, v.owner_element_id, v.name, v.description, v.level_label, v.level, v.created_at, v.updated_at + FROM views v + WHERE v.owner_element_id IS NULL + OR NOT EXISTS ( + SELECT 1 FROM placements p + WHERE p.element_id = v.owner_element_id + AND p.view_id != v.id + ) + ORDER BY v.id`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := []ViewTreeNode{} + for rows.Next() { + var row viewRow + if err := rows.Scan(&row.ID, &row.OwnerElementID, &row.Name, &row.Description, &row.LevelLabel, &row.Level, &row.CreatedAt, &row.UpdatedAt); err != nil { + return nil, err + } + out = append(out, viewNodeFromRow(row, nil, 0)) + } + return out, rows.Err() +} + +func (s *Store) CreateView(ctx context.Context, name string, levelLabel *string, ownerElementID *int64) (ViewSummary, error) { + now := nowString() + level := 1 + if ownerElementID != nil { + parentID, err := s.parentViewForOwner(ctx, *ownerElementID, 0) + if err == nil && parentID != nil { + parent, err := s.ViewByID(ctx, *parentID) + if err == nil { + level = parent.Level + 1 + } + } + } + res, err := s.db.ExecContext(ctx, `INSERT INTO views(owner_element_id, name, description, level_label, level, created_at, updated_at) VALUES (?, ?, NULL, ?, ?, ?, ?)`, + ownerElementID, strings.TrimSpace(name), levelLabel, level, now, now) + if err != nil { + return ViewSummary{}, err + } + id, _ := res.LastInsertId() + view, err := s.ViewByID(ctx, id) + if err != nil { + return ViewSummary{}, err + } + return ViewSummary{ + ID: view.ID, + Name: view.Name, + Label: view.LevelLabel, + IsRoot: view.ParentViewID == nil, + CreatedAt: view.CreatedAt, + UpdatedAt: view.UpdatedAt, + }, nil +} + +func (s *Store) UpdateView(ctx context.Context, id int64, name *string, levelLabel *string) (ViewSummary, error) { + current, err := s.ViewByID(ctx, id) + if err != nil { + return ViewSummary{}, err + } + nextName := current.Name + if name != nil && strings.TrimSpace(*name) != "" { + nextName = strings.TrimSpace(*name) + } + _, err = s.db.ExecContext(ctx, `UPDATE views SET name = ?, level_label = ?, updated_at = ? WHERE id = ?`, nextName, levelLabel, nowString(), id) + if err != nil { + return ViewSummary{}, err + } + updated, err := s.ViewByID(ctx, id) + if err != nil { + return ViewSummary{}, err + } + return ViewSummary{ + ID: updated.ID, + Name: updated.Name, + Label: updated.LevelLabel, + IsRoot: updated.ParentViewID == nil, + CreatedAt: updated.CreatedAt, + UpdatedAt: updated.UpdatedAt, + }, nil +} + +func (s *Store) SetViewLevel(ctx context.Context, id int64, level int) error { + _, err := s.db.ExecContext(ctx, `UPDATE views SET level = ?, updated_at = ? 
WHERE id = ?`, level, nowString(), id) + return err +} + +func (s *Store) DeleteView(ctx context.Context, id int64) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM views WHERE id = ?`, id) + return err +} + +func scanElement(row scanner, includeViewMeta bool, store *Store, ctx context.Context) (LibraryElement, error) { + var ( + elem LibraryElement + techRaw string + tagRaw string + kind sql.NullString + description sql.NullString + technology sql.NullString + url sql.NullString + logoURL sql.NullString + repo sql.NullString + branch sql.NullString + filePath sql.NullString + language sql.NullString + ) + if err := row.Scan(&elem.ID, &elem.Name, &kind, &description, &technology, &url, &logoURL, &techRaw, &tagRaw, &repo, &branch, &filePath, &language, &elem.CreatedAt, &elem.UpdatedAt); err != nil { + return LibraryElement{}, err + } + if kind.Valid { + elem.Kind = &kind.String + } + if description.Valid { + elem.Description = &description.String + } + if technology.Valid { + elem.Technology = &technology.String + } + if url.Valid { + elem.URL = &url.String + } + if logoURL.Valid { + elem.LogoURL = &logoURL.String + } + if repo.Valid { + elem.Repo = &repo.String + } + if branch.Valid { + elem.Branch = &branch.String + } + if filePath.Valid { + elem.FilePath = &filePath.String + } + if language.Valid { + elem.Language = &language.String + } + elem.TechnologyConnectors = parseTechnologyConnectors(techRaw) + elem.Tags = parseStrings(tagRaw) + if includeViewMeta { + hasView, label, err := store.childViewMeta(ctx, elem.ID) + if err != nil { + return LibraryElement{}, err + } + elem.HasView = hasView + elem.ViewLabel = label + } + return elem, nil +} + +func (s *Store) ListIncomingNavigations(ctx context.Context, viewID int64) ([]IncomingViewConnector, error) { + view, err := s.ViewByID(ctx, viewID) + if err != nil { + return nil, err + } + if view.OwnerElementID == nil || view.ParentViewID == nil { + return []IncomingViewConnector{}, nil + } + element, err := s.ElementByID(ctx, *view.OwnerElementID) + if err != nil { + return nil, err + } + parent, err := s.ViewByID(ctx, *view.ParentViewID) + if err != nil { + return nil, err + } + return []IncomingViewConnector{{ + ID: 0, + ElementID: *view.OwnerElementID, + ElementName: element.Name, + FromViewID: parent.ID, + FromViewName: parent.Name, + ToViewID: view.ID, + }}, nil +} diff --git a/internal/cmdutil/convert.go b/internal/cmdutil/convert.go index e0f28ea..34e6c64 100644 --- a/internal/cmdutil/convert.go +++ b/internal/cmdutil/convert.go @@ -55,6 +55,7 @@ func ConvertExportResponse(baseWS *workspace.Workspace, msg *diagv1.ExportOrgani Branch: e.GetBranch(), Language: e.GetLanguage(), FilePath: e.GetFilePath(), + Tags: cloneStrings(e.GetTags()), HasView: e.GetHasView(), ViewLabel: strings.TrimSpace(e.GetViewLabel()), } @@ -160,6 +161,13 @@ func exportedDiagramLabel(diagram *diagv1.View, elementName string) string { return "" } +func cloneStrings(values []string) []string { + if len(values) == 0 { + return nil + } + return append([]string(nil), values...) 
+} + func buildDiagramOwnerIndex(msg *diagv1.ExportOrganizationResponse, elements map[string]*workspace.Element, objectIDToRef map[int32]string) map[int32]string { owners := make(map[int32]string) usedRefs := make(map[string]struct{}) diff --git a/internal/cmdutil/convert_test.go b/internal/cmdutil/convert_test.go index c848ade..d96c74b 100644 --- a/internal/cmdutil/convert_test.go +++ b/internal/cmdutil/convert_test.go @@ -9,10 +9,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -func strPtr(s string) *string { - return &s -} - func TestConvertExportResponsePreservesRefsAndInfersOwnedViews(t *testing.T) { updated := timestamppb.New(time.Date(2026, 4, 25, 12, 0, 0, 0, time.UTC)) base := &workspace.Workspace{ @@ -30,11 +26,11 @@ func TestConvertExportResponsePreservesRefsAndInfersOwnedViews(t *testing.T) { } msg := &diagv1.ExportOrganizationResponse{ Elements: []*diagv1.Element{ - {Id: 10, Name: "API Service", Kind: strPtr("container"), HasView: true, ViewLabel: strPtr("Runtime"), Technology: strPtr("Go"), UpdatedAt: updated}, - {Id: 20, Name: "Database", Kind: strPtr("container"), UpdatedAt: updated}, + {Id: 10, Name: "API Service", Kind: new("container"), HasView: true, ViewLabel: new("Runtime"), Technology: new("Go"), UpdatedAt: updated}, + {Id: 20, Name: "Database", Kind: new("container"), UpdatedAt: updated}, }, Views: []*diagv1.View{ - {Id: 100, Name: "API Service", LevelLabel: strPtr("Runtime"), UpdatedAt: updated}, + {Id: 100, Name: "API Service", LevelLabel: new("Runtime"), UpdatedAt: updated}, {Id: 101, Name: "Landscape", UpdatedAt: updated}, }, Placements: []*diagv1.ElementPlacement{ @@ -42,7 +38,7 @@ func TestConvertExportResponsePreservesRefsAndInfersOwnedViews(t *testing.T) { {ViewId: 999, ElementId: 10, PositionX: 1, PositionY: 2}, }, Connectors: []*diagv1.Connector{ - {Id: 50, ViewId: 100, SourceElementId: 10, TargetElementId: 20, Label: strPtr("reads"), Relationship: strPtr("dependency"), Direction: "forward", UpdatedAt: updated}, + {Id: 50, ViewId: 100, SourceElementId: 10, TargetElementId: 20, Label: new("reads"), Relationship: new("dependency"), Direction: "forward", UpdatedAt: updated}, }, } diff --git a/internal/codeowners/codeowners.go b/internal/codeowners/codeowners.go new file mode 100644 index 0000000..61471eb --- /dev/null +++ b/internal/codeowners/codeowners.go @@ -0,0 +1,227 @@ +package codeowners + +import ( + "bufio" + "os" + "path/filepath" + "sort" + "strings" +) + +var candidateFiles = []string{ + "CODEOWNERS", + filepath.Join(".github", "CODEOWNERS"), + filepath.Join("docs", "CODEOWNERS"), +} + +type Matcher struct { + rules []rule +} + +type rule struct { + pattern string + owners []string +} + +func Load(repoRoot string) (*Matcher, error) { + for _, name := range candidateFiles { + path := filepath.Join(repoRoot, name) + data, err := os.ReadFile(path) + if err == nil { + return Parse(string(data)), nil + } + if !os.IsNotExist(err) { + return nil, err + } + } + return &Matcher{}, nil +} + +func Parse(data string) *Matcher { + scanner := bufio.NewScanner(strings.NewReader(data)) + var rules []rule + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + pattern := normalizePattern(fields[0]) + if pattern == "" { + continue + } + owners := parseOwners(fields[1:]) + if len(owners) == 0 { + continue + } + rules = append(rules, rule{pattern: pattern, owners: owners}) + } + return 
&Matcher{rules: rules} +} + +func (m *Matcher) TagsForPath(path string) []string { + if m == nil { + return nil + } + clean := normalizePath(path) + if clean == "" { + return nil + } + var owners []string + for _, rule := range m.rules { + if rule.matches(clean) { + owners = rule.owners + } + } + return ownerTags(owners) +} + +func parseOwners(fields []string) []string { + seen := map[string]struct{}{} + var owners []string + for _, field := range fields { + if strings.HasPrefix(field, "#") { + break + } + if !strings.Contains(field, "@") { + continue + } + if idx := strings.IndexByte(field, ':'); idx >= 0 { + field = field[:idx] + } + field = strings.TrimSpace(field) + if field == "" || !strings.HasPrefix(field, "@") { + continue + } + if _, ok := seen[field]; ok { + continue + } + seen[field] = struct{}{} + owners = append(owners, field) + } + sort.Strings(owners) + return owners +} + +func ownerTags(owners []string) []string { + if len(owners) == 0 { + return nil + } + tags := make([]string, 0, len(owners)) + for _, owner := range owners { + tags = append(tags, "owner:"+owner) + } + sort.Strings(tags) + return tags +} + +func normalizePattern(pattern string) string { + pattern = strings.TrimSpace(pattern) + if pattern == "" || pattern == "!" { + return "" + } + pattern = strings.TrimPrefix(pattern, "!") + rooted := strings.HasPrefix(pattern, "/") + dirPattern := strings.HasSuffix(pattern, "/") + pattern = strings.Trim(pattern, "/") + pattern = normalizePath(pattern) + if pattern == "" { + return "" + } + if rooted { + pattern = "/" + pattern + } + if dirPattern { + pattern += "/" + } + return pattern +} + +func normalizePath(path string) string { + path = filepath.ToSlash(filepath.Clean(filepath.FromSlash(path))) + path = strings.TrimPrefix(path, "./") + path = strings.Trim(path, "/") + if path == "." 
{ + return "" + } + return path +} + +func (r rule) matches(candidate string) bool { + pattern := r.pattern + rooted := strings.HasPrefix(pattern, "/") + pattern = strings.TrimPrefix(pattern, "/") + dirPattern := strings.HasSuffix(pattern, "/") + pattern = strings.TrimSuffix(pattern, "/") + if pattern == "" { + return false + } + if rooted { + return matchAnchored(pattern, candidate, dirPattern) + } + if !strings.Contains(pattern, "/") { + for part := range strings.SplitSeq(candidate, "/") { + if matchSegment(pattern, part) { + return true + } + } + return false + } + if matchAnchored(pattern, candidate, dirPattern) { + return true + } + parts := strings.Split(candidate, "/") + for i := 1; i < len(parts); i++ { + if matchAnchored(pattern, strings.Join(parts[i:], "/"), dirPattern) { + return true + } + } + return false +} + +func matchAnchored(pattern, candidate string, dirPattern bool) bool { + if dirPattern && (candidate == pattern || strings.HasPrefix(candidate, pattern+"/")) { + return true + } + if matchGlob(pattern, candidate) { + return true + } + if !strings.ContainsAny(pattern, "*?[") && (candidate == pattern || strings.HasPrefix(candidate, pattern+"/")) { + return true + } + return patternOwnsCandidateFolder(pattern, candidate) +} + +func patternOwnsCandidateFolder(pattern, candidate string) bool { + if strings.ContainsAny(candidate, "*?[") || candidate == "" { + return false + } + staticPrefix := pattern + if idx := strings.IndexAny(staticPrefix, "*?["); idx >= 0 { + staticPrefix = staticPrefix[:idx] + } + staticPrefix = strings.Trim(staticPrefix, "/") + return staticPrefix != "" && (staticPrefix == candidate || strings.HasPrefix(staticPrefix, candidate+"/")) +} + +func matchSegment(pattern, candidate string) bool { + ok, err := filepath.Match(pattern, candidate) + return err == nil && ok +} + +func matchGlob(pattern, candidate string) bool { + patternParts := strings.Split(pattern, "/") + candidateParts := strings.Split(candidate, "/") + if len(patternParts) != len(candidateParts) { + return false + } + for i := range patternParts { + if !matchSegment(patternParts[i], candidateParts[i]) { + return false + } + } + return true +} diff --git a/internal/codeowners/codeowners_test.go b/internal/codeowners/codeowners_test.go new file mode 100644 index 0000000..7993ccb --- /dev/null +++ b/internal/codeowners/codeowners_test.go @@ -0,0 +1,69 @@ +package codeowners + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseMatchesBasicAndExtendedOwners(t *testing.T) { + matcher := Parse(` +# comment +/path/to/code @username +/frontend/* @org/web-team:random(2) +/backend/* @org/backend:least_busy(3) # Randomly select reviewers +`) + + assertTags(t, matcher.TagsForPath("path/to/code"), []string{"owner:@username"}) + assertTags(t, matcher.TagsForPath("frontend/app.ts"), []string{"owner:@org/web-team"}) + assertTags(t, matcher.TagsForPath("backend/main.go"), []string{"owner:@org/backend"}) +} + +func TestLastMatchWinsAndOwnersAreSortedDeduped(t *testing.T) { + matcher := Parse(` +*.go @zeta @alpha @zeta @ignore:least_busy(1) +/cmd/* @cmd-owner +`) + + assertTags(t, matcher.TagsForPath("internal/app.go"), []string{"owner:@alpha", "owner:@ignore", "owner:@zeta"}) + assertTags(t, matcher.TagsForPath("cmd/main.go"), []string{"owner:@cmd-owner"}) +} + +func TestDirectoryAndFolderOwnership(t *testing.T) { + matcher := Parse(` +/frontend/ @org/web-team +/backend/* @org/backend +`) + + assertTags(t, matcher.TagsForPath("frontend"), []string{"owner:@org/web-team"}) + assertTags(t, 
matcher.TagsForPath("frontend/app.ts"), []string{"owner:@org/web-team"}) + assertTags(t, matcher.TagsForPath("backend"), []string{"owner:@org/backend"}) + assertTags(t, matcher.TagsForPath("backend/service.go"), []string{"owner:@org/backend"}) +} + +func TestLoadFindsSupportedLocations(t *testing.T) { + dir := t.TempDir() + if err := os.MkdirAll(filepath.Join(dir, ".github"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, ".github", "CODEOWNERS"), []byte("/src/* @owner\n"), 0o644); err != nil { + t.Fatal(err) + } + matcher, err := Load(dir) + if err != nil { + t.Fatal(err) + } + assertTags(t, matcher.TagsForPath("src/main.go"), []string{"owner:@owner"}) +} + +func assertTags(t *testing.T, got, want []string) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("tags = %#v, want %#v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("tags = %#v, want %#v", got, want) + } + } +} diff --git a/internal/completion/completion.go b/internal/completion/completion.go index 1408d3f..2e4dfd4 100644 --- a/internal/completion/completion.go +++ b/internal/completion/completion.go @@ -10,7 +10,6 @@ package completion import ( "context" - "os" "sort" "time" @@ -39,7 +38,11 @@ func loadWS(wdir *string) *workspace.Workspace { } func remoteEnabled() bool { - return os.Getenv(remoteEnvVar) == "1" + state, err := workspace.LoadGlobalConfigStateNoRepair() + if err != nil { + return false + } + return state.Config.Completion.Remote } // remoteElements fetches elements from the API with a short deadline. Any diff --git a/internal/completion/config_test.go b/internal/completion/config_test.go new file mode 100644 index 0000000..8153fb7 --- /dev/null +++ b/internal/completion/config_test.go @@ -0,0 +1,23 @@ +package completion + +import ( + "os" + "path/filepath" + "testing" +) + +func TestRemoteEnabledUsesGlobalConfigAndEnvOverride(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + if err := os.WriteFile(filepath.Join(configDir, "tld.yaml"), []byte("completion:\n remote: true\n"), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + if !remoteEnabled() { + t.Fatal("expected completion.remote config to enable remote completion") + } + + t.Setenv("TLD_COMPLETION_REMOTE", "0") + if remoteEnabled() { + t.Fatal("expected TLD_COMPLETION_REMOTE=0 to disable remote completion") + } +} diff --git a/internal/core/store.go b/internal/core/store.go index 1079bc0..4d7fdbc 100644 --- a/internal/core/store.go +++ b/internal/core/store.go @@ -23,7 +23,7 @@ type ViewStore interface { } type ElementStore interface { - Elements(ctx context.Context, limit, offset int, search string) ([]LibraryElement, error) + Elements(ctx context.Context, limit, offset int, search string) ([]LibraryElement, int, error) ElementByID(ctx context.Context, id int64) (LibraryElement, error) CreateElement(ctx context.Context, input LibraryElement) (LibraryElement, error) UpdateElement(ctx context.Context, id int64, input LibraryElement) (LibraryElement, error) diff --git a/internal/git/git.go b/internal/git/git.go index 03c79fa..317945c 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -5,11 +5,46 @@ package git import ( "fmt" "os/exec" + "path/filepath" "strconv" "strings" "time" ) +type Status struct { + Branch string + HeadCommit string + HeadMessage string + RemoteURL string + Staged []string + Unstaged []string + Untracked []string + Deleted []string +} + +type LineDiff struct { + Added int + Removed int +} + +type LineHunk struct { 
+ File string + OldStart int + OldLineCount int + NewStart int + NewLineCount int + AddedLines []int + RemovedLines []int +} + +type WorktreeChange string + +const ( + WorktreeAdded WorktreeChange = "added" + WorktreeUpdated WorktreeChange = "updated" + WorktreeDeleted WorktreeChange = "deleted" +) + // DetectBranch returns the current branch name for the git repo rooted at dir. func DetectBranch(dir string) (string, error) { out, err := run(dir, "branch", "--show-current") @@ -32,6 +67,56 @@ func DetectRemoteURL(dir string) (string, error) { return strings.TrimSpace(out), nil } +// DetectHeadCommit returns the current HEAD commit SHA for the git repo at dir. +func DetectHeadCommit(dir string) (string, error) { + out, err := run(dir, "rev-parse", "HEAD") + if err != nil { + return "", fmt.Errorf("detect head commit: %w", err) + } + return strings.TrimSpace(out), nil +} + +// DetectHeadMessage returns the subject line for HEAD. +func DetectHeadMessage(dir string) (string, error) { + out, err := run(dir, "log", "-1", "--format=%s") + if err != nil { + return "", fmt.Errorf("detect head message: %w", err) + } + return strings.TrimSpace(out), nil +} + +// DetectParentCommit returns the first parent commit SHA for HEAD. +func DetectParentCommit(dir string) (string, error) { + out, err := run(dir, "rev-parse", "HEAD^") + if err != nil { + return "", fmt.Errorf("detect parent commit: %w", err) + } + return strings.TrimSpace(out), nil +} + +// FileBlobHash returns the git blob hash for a tracked file at HEAD/index. +// filePath may be absolute or relative to dir. +func FileBlobHash(dir, filePath string) (string, error) { + rel := filePath + if filepath.IsAbs(filePath) { + var err error + rel, err = filepath.Rel(dir, filePath) + if err != nil { + return "", fmt.Errorf("file blob hash: %w", err) + } + } + rel = filepath.ToSlash(rel) + out, err := run(dir, "ls-files", "-s", "--", rel) + if err != nil { + return "", fmt.Errorf("file blob hash: %w", err) + } + fields := strings.Fields(out) + if len(fields) < 2 { + return "", fmt.Errorf("file blob hash: %q is not tracked", rel) + } + return fields[1], nil +} + // FileLastCommitAt returns the timestamp of the most recent commit that touched filePath // in the git repo rooted at dir. filePath may be absolute or relative to dir. func FileLastCommitAt(dir, filePath string) (time.Time, error) { @@ -50,6 +135,217 @@ func FileLastCommitAt(dir, filePath string) (time.Time, error) { return time.Unix(unix, 0).UTC(), nil } +func StatusSnapshot(dir string) (Status, error) { + status := Status{ + Branch: detectBestEffort(func() (string, error) { return DetectBranch(dir) }), + HeadCommit: detectBestEffort(func() (string, error) { return DetectHeadCommit(dir) }), + HeadMessage: detectBestEffort(func() (string, error) { return DetectHeadMessage(dir) }), + RemoteURL: detectBestEffort(func() (string, error) { return DetectRemoteURL(dir) }), + } + out, err := run(dir, "status", "--porcelain=v1", "-z") + if err != nil { + return status, fmt.Errorf("git status: %w", err) + } + entries := strings.Split(out, "\x00") + for i := 0; i < len(entries); i++ { + entry := entries[i] + if entry == "" || len(entry) < 4 { + continue + } + x, y := entry[0], entry[1] + path := strings.TrimSpace(entry[3:]) + if x == 'R' || x == 'C' { + i++ + } + if x != ' ' && x != '?' { + status.Staged = append(status.Staged, filepath.ToSlash(path)) + } + if y != ' ' && y != '?' { + status.Unstaged = append(status.Unstaged, filepath.ToSlash(path)) + } + if x == '?' && y == '?' 
{ + status.Untracked = append(status.Untracked, filepath.ToSlash(path)) + } + if x == 'D' || y == 'D' { + status.Deleted = append(status.Deleted, filepath.ToSlash(path)) + } + } + return status, nil +} + +func WorktreeChangesAgainstHead(dir string) (map[string]WorktreeChange, error) { + status, err := StatusSnapshot(dir) + if err != nil { + return nil, err + } + changes := map[string]WorktreeChange{} + if status.HeadCommit != "" { + out, err := run(dir, "diff", "--name-status", "HEAD", "--") + if err != nil { + return nil, fmt.Errorf("git diff name-status: %w", err) + } + for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") { + if strings.TrimSpace(line) == "" { + continue + } + fields := strings.Split(line, "\t") + if len(fields) < 2 { + continue + } + code := strings.TrimSpace(fields[0]) + switch { + case strings.HasPrefix(code, "R") || strings.HasPrefix(code, "C"): + if len(fields) >= 3 { + changes[filepath.ToSlash(fields[1])] = WorktreeDeleted + changes[filepath.ToSlash(fields[2])] = WorktreeAdded + } + case strings.HasPrefix(code, "A"): + changes[filepath.ToSlash(fields[1])] = WorktreeAdded + case strings.HasPrefix(code, "D"): + changes[filepath.ToSlash(fields[1])] = WorktreeDeleted + default: + changes[filepath.ToSlash(fields[1])] = WorktreeUpdated + } + } + } + for _, path := range status.Untracked { + changes[filepath.ToSlash(path)] = WorktreeAdded + } + if status.HeadCommit == "" { + for _, path := range status.Staged { + changes[filepath.ToSlash(path)] = WorktreeAdded + } + for _, path := range status.Unstaged { + if _, ok := changes[filepath.ToSlash(path)]; !ok { + changes[filepath.ToSlash(path)] = WorktreeAdded + } + } + for _, path := range status.Deleted { + changes[filepath.ToSlash(path)] = WorktreeDeleted + } + } + return changes, nil +} + +func LineDiffsAgainstHead(dir string) (map[string]LineDiff, error) { + out, err := run(dir, "diff", "--numstat", "HEAD", "--") + if err != nil { + return nil, fmt.Errorf("git diff numstat: %w", err) + } + diffs := map[string]LineDiff{} + for line := range strings.SplitSeq(strings.TrimSpace(out), "\n") { + if strings.TrimSpace(line) == "" { + continue + } + fields := strings.Split(line, "\t") + if len(fields) < 3 || fields[0] == "-" || fields[1] == "-" { + continue + } + added, err := strconv.Atoi(fields[0]) + if err != nil { + continue + } + removed, err := strconv.Atoi(fields[1]) + if err != nil { + continue + } + diffs[filepath.ToSlash(fields[2])] = LineDiff{Added: added, Removed: removed} + } + return diffs, nil +} + +func LineHunksAgainstHead(dir string) (map[string][]LineHunk, error) { + out, err := run(dir, "diff", "--unified=0", "HEAD", "--") + if err != nil { + return nil, fmt.Errorf("git diff hunks: %w", err) + } + return ParseLineHunks(out), nil +} + +func ParseLineHunks(diff string) map[string][]LineHunk { + hunks := map[string][]LineHunk{} + file := "" + var current *LineHunk + oldLine, newLine := 0, 0 + flush := func() { + if current != nil && file != "" { + current.File = file + hunks[file] = append(hunks[file], *current) + } + current = nil + } + for line := range strings.SplitSeq(diff, "\n") { + switch { + case strings.HasPrefix(line, "diff --git "): + flush() + file = parseDiffGitPath(line) + case strings.HasPrefix(line, "+++ b/"): + file = filepath.ToSlash(strings.TrimPrefix(line, "+++ b/")) + case strings.HasPrefix(line, "@@ "): + flush() + hunk, ok := parseHunkHeader(line) + if !ok { + continue + } + current = &hunk + oldLine = hunk.OldStart + newLine = hunk.NewStart + case current != nil && 
strings.HasPrefix(line, "+") && !strings.HasPrefix(line, "+++"): + current.AddedLines = append(current.AddedLines, newLine) + newLine++ + case current != nil && strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "---"): + current.RemovedLines = append(current.RemovedLines, oldLine) + oldLine++ + case current != nil && strings.HasPrefix(line, " "): + oldLine++ + newLine++ + } + } + flush() + return hunks +} + +func parseDiffGitPath(line string) string { + fields := strings.Fields(line) + if len(fields) < 4 { + return "" + } + path := strings.TrimPrefix(fields[3], "b/") + return filepath.ToSlash(path) +} + +func parseHunkHeader(line string) (LineHunk, bool) { + fields := strings.Fields(line) + if len(fields) < 3 || !strings.HasPrefix(fields[1], "-") || !strings.HasPrefix(fields[2], "+") { + return LineHunk{}, false + } + oldStart, oldCount, ok := parseHunkRange(strings.TrimPrefix(fields[1], "-")) + if !ok { + return LineHunk{}, false + } + newStart, newCount, ok := parseHunkRange(strings.TrimPrefix(fields[2], "+")) + if !ok { + return LineHunk{}, false + } + return LineHunk{OldStart: oldStart, OldLineCount: oldCount, NewStart: newStart, NewLineCount: newCount}, true +} + +func parseHunkRange(value string) (int, int, bool) { + parts := strings.SplitN(value, ",", 2) + start, err := strconv.Atoi(parts[0]) + if err != nil { + return 0, 0, false + } + count := 1 + if len(parts) == 2 { + count, err = strconv.Atoi(parts[1]) + if err != nil { + return 0, 0, false + } + } + return start, count, true +} + // RepoRoot returns the absolute path of the top-level git working tree for the // repository that contains dir. func RepoRoot(dir string) (string, error) { @@ -60,6 +356,14 @@ func RepoRoot(dir string) (string, error) { return strings.TrimSpace(out), nil } +func detectBestEffort(fn func() (string, error)) string { + value, err := fn() + if err != nil { + return "" + } + return value +} + // run executes git with the given args in dir and returns the combined stdout output. func run(dir string, args ...string) (string, error) { cmd := exec.Command("git", args...) 
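A minimal usage sketch for the hunk parser added above, not part of the patch itself. It assumes the `ParseLineHunks` signature and `LineHunk` fields introduced in `internal/git/git.go`, and that the snippet is compiled from inside this module (the package is internal); the sample diff text and the `main` wrapper are illustrative only.

```go
package main

import (
	"fmt"

	"github.com/mertcikla/tld/internal/git"
)

func main() {
	// A tiny --unified=0 style diff used purely as sample input.
	diff := `diff --git a/main.go b/main.go
--- a/main.go
+++ b/main.go
@@ -2 +2,2 @@ func A() {
- old
+ new
+ next
`
	// ParseLineHunks keys hunks by the post-image path ("+++ b/...") and records
	// the absolute line numbers that were added or removed in each hunk.
	for file, hunks := range git.ParseLineHunks(diff) {
		for _, h := range hunks {
			fmt.Printf("%s: added %v, removed %v\n", file, h.AddedLines, h.RemovedLines)
		}
	}
	// Expected output for the sample above: "main.go: added [2 3], removed [2]".
}
```

In the patch, `LineHunksAgainstHead` feeds `git diff --unified=0 HEAD --` output through this same parser, so the per-line numbers it returns line up with the worktree state rather than any staged snapshot.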
diff --git a/internal/git/git_test.go b/internal/git/git_test.go index 1dbd2d7..ea20269 100644 --- a/internal/git/git_test.go +++ b/internal/git/git_test.go @@ -4,6 +4,7 @@ import ( "os" "os/exec" "path/filepath" + "strconv" "strings" "testing" "time" @@ -165,3 +166,38 @@ func TestFilesChangedSince(t *testing.T) { t.Fatalf("unexpected files: %v", files) } } + +func TestParseLineHunks(t *testing.T) { + diff := `diff --git a/main.go b/main.go +index 1111111..2222222 100644 +--- a/main.go ++++ b/main.go +@@ -2 +2,2 @@ func A() { +- old ++ new ++ next +@@ -8,2 +9 @@ func B() { +- remove +- again ++ replace +` + hunks := ParseLineHunks(diff) + got := hunks["main.go"] + if len(got) != 2 { + t.Fatalf("expected 2 hunks, got %+v", got) + } + if strings.Join(intsToStrings(got[0].AddedLines), ",") != "2,3" || strings.Join(intsToStrings(got[0].RemovedLines), ",") != "2" { + t.Fatalf("unexpected first hunk lines: %+v", got[0]) + } + if strings.Join(intsToStrings(got[1].AddedLines), ",") != "9" || strings.Join(intsToStrings(got[1].RemovedLines), ",") != "8,9" { + t.Fatalf("unexpected second hunk lines: %+v", got[1]) + } +} + +func intsToStrings(values []int) []string { + out := make([]string, 0, len(values)) + for _, value := range values { + out = append(out, strconv.Itoa(value)) + } + return out +} diff --git a/internal/ignore/ignore.go b/internal/ignore/ignore.go index 0d43928..0d590d5 100644 --- a/internal/ignore/ignore.go +++ b/internal/ignore/ignore.go @@ -3,6 +3,8 @@ package ignore import ( + "bufio" + "os" "path/filepath" "strings" @@ -11,7 +13,13 @@ import ( // Rules holds gitignore-style exclusion patterns loaded from the workspace configuration file. type Rules struct { - Exclude []string `yaml:"exclude,omitempty"` + Exclude []string `yaml:"exclude,omitempty"` + Patterns []Pattern `yaml:"-"` +} + +type Pattern struct { + Value string + Negate bool } var implicitPathExcludes = []string{ @@ -25,6 +33,7 @@ var implicitPathExcludes = []string{ func Merge(rules ...*Rules) *Rules { merged := &Rules{} seen := make(map[string]struct{}) + seenPatterns := make(map[Pattern]struct{}) for _, ruleSet := range rules { if ruleSet == nil { continue @@ -40,20 +49,74 @@ func Merge(rules ...*Rules) *Rules { seen[pattern] = struct{}{} merged.Exclude = append(merged.Exclude, pattern) } + for _, pattern := range ruleSet.Patterns { + pattern.Value = strings.TrimSpace(pattern.Value) + if pattern.Value == "" { + continue + } + if _, ok := seenPatterns[pattern]; ok { + continue + } + seenPatterns[pattern] = struct{}{} + merged.Patterns = append(merged.Patterns, pattern) + } } - if len(merged.Exclude) == 0 { + if len(merged.Exclude) == 0 && len(merged.Patterns) == 0 { return nil } return merged } +func LoadGitIgnore(root string) (*Rules, error) { + root = filepath.Clean(root) + var patterns []Pattern + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + name := d.Name() + if name == ".git" || name == "node_modules" || name == "vendor" || name == ".venv" { + return filepath.SkipDir + } + return nil + } + if d.Name() != ".gitignore" { + return nil + } + base, err := filepath.Rel(root, filepath.Dir(path)) + if err != nil { + return err + } + if base == "." { + base = "" + } + filePatterns, err := readGitIgnoreFile(path, base) + if err != nil { + return err + } + patterns = append(patterns, filePatterns...) 
+ return nil + }) + if err != nil { + return nil, err + } + if len(patterns) == 0 { + return nil, nil + } + return &Rules{Patterns: patterns}, nil +} + // ShouldIgnorePath returns true if the given file or folder path matches any exclusion pattern. // The path can be absolute or relative; matching is performed against both the full path and base name. func (r *Rules) ShouldIgnorePath(path string) bool { if r == nil { return shouldIgnorePathWithPatterns(path, implicitPathExcludes) } - return shouldIgnorePathWithPatterns(path, append(append([]string{}, implicitPathExcludes...), r.Exclude...)) + if shouldIgnorePathWithPatterns(path, append(append([]string{}, implicitPathExcludes...), r.Exclude...)) { + return true + } + return shouldIgnorePathWithOrderedPatterns(path, r.Patterns) } func shouldIgnorePathWithPatterns(path string, patterns []string) bool { @@ -77,6 +140,100 @@ func shouldIgnorePathWithPatterns(path string, patterns []string) bool { return false } +func shouldIgnorePathWithOrderedPatterns(path string, patterns []Pattern) bool { + path = normalizePath(path) + ignored := false + for _, pattern := range patterns { + if matchPathPattern(path, pattern.Value) { + ignored = !pattern.Negate + } + } + return ignored +} + +func readGitIgnoreFile(path, base string) ([]Pattern, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer func() { _ = file.Close() }() + base = normalizePath(base) + scanner := bufio.NewScanner(file) + var patterns []Pattern + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + negate := strings.HasPrefix(line, "!") + if negate { + line = strings.TrimSpace(strings.TrimPrefix(line, "!")) + } + line = strings.TrimPrefix(line, "\\") + for _, pattern := range expandGitIgnorePattern(base, line) { + patterns = append(patterns, Pattern{Value: pattern, Negate: negate}) + } + } + return patterns, scanner.Err() +} + +func expandGitIgnorePattern(base, pattern string) []string { + pattern = normalizePattern(pattern) + if pattern == "" || pattern == "/" { + return nil + } + rooted := strings.HasPrefix(pattern, "/") + dirOnly := strings.HasSuffix(pattern, "/") + pattern = strings.Trim(pattern, "/") + if pattern == "" { + return nil + } + hasSlash := strings.Contains(pattern, "/") + var expanded []string + add := func(value string) { + value = normalizePattern(value) + value = strings.TrimPrefix(value, "/") + if value == "" { + return + } + if dirOnly && !strings.HasSuffix(value, "/") { + value += "/" + } + expanded = append(expanded, value) + } + if base != "" { + if rooted || hasSlash { + add(base + "/" + pattern) + } else { + add(base + "/" + pattern) + add(base + "/**/" + pattern) + } + return expanded + } + if rooted || hasSlash { + add(pattern) + } else { + add(pattern) + add("**/" + pattern) + } + return expanded +} + +func matchPathPattern(path, pattern string) bool { + pattern = normalizePattern(pattern) + base := filepath.Base(path) + if matchPattern(pattern, path) || matchPattern(pattern, base) { + return true + } + if before, ok := strings.CutSuffix(pattern, "/"); ok { + trimmed := before + if path == trimmed || strings.HasPrefix(path, trimmed+"/") || base == trimmed { + return true + } + } + return false +} + // ShouldIgnoreFile returns true if the given file path is excluded. 
func (r *Rules) ShouldIgnoreFile(path string) bool { return r.ShouldIgnorePath(path) diff --git a/internal/ignore/ignore_test.go b/internal/ignore/ignore_test.go index c208432..8ec4cbb 100644 --- a/internal/ignore/ignore_test.go +++ b/internal/ignore/ignore_test.go @@ -1,6 +1,10 @@ package ignore -import "testing" +import ( + "os" + "path/filepath" + "testing" +) func TestShouldIgnorePath(t *testing.T) { r := &Rules{Exclude: []string{"vendor/", "node_modules/", ".git/", "**/*.pb.go", "**/*_test.go"}} @@ -59,3 +63,41 @@ func TestNilRules(t *testing.T) { t.Error("nil rules should never ignore") } } + +func TestLoadGitIgnore(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, ".gitignore"), []byte("*.log\nignored.go\n/generated/\n!important.log\n"), 0o644); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(filepath.Join(dir, "pkg"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "pkg", ".gitignore"), []byte("local.go\nnested/\n"), 0o644); err != nil { + t.Fatal(err) + } + + rules, err := LoadGitIgnore(dir) + if err != nil { + t.Fatal(err) + } + cases := []struct { + path string + expect bool + }{ + {"ignored.go", true}, + {"sub/ignored.go", true}, + {"generated", true}, + {"generated/file.go", true}, + {"debug.log", true}, + {"important.log", false}, + {"pkg/local.go", true}, + {"pkg/sub/local.go", true}, + {"pkg/nested/file.go", true}, + {"other/local.go", false}, + } + for _, c := range cases { + if got := rules.ShouldIgnorePath(c.path); got != c.expect { + t.Errorf("ShouldIgnorePath(%q) = %v, want %v", c.path, got, c.expect) + } + } +} diff --git a/internal/layout/organic.go b/internal/layout/organic.go new file mode 100644 index 0000000..e7954b0 --- /dev/null +++ b/internal/layout/organic.go @@ -0,0 +1,206 @@ +// Package layout provides force-directed graph layout algorithms for the local +// watch materializer. The organic layout is a port of backend-wrapper/pkg/layout/organic.go +// adapted for int64 element IDs used in the SQLite-backed watch store. +package layout + +import ( + "math" + "math/rand/v2" + "time" + + "github.com/mertcikla/tld/internal/workspace" +) + +const ( + NodeWidth = 200.0 + NodeHeight = 120.0 + Iterations = 300 + AlphaStart = 1.0 + AlphaMin = 0.001 + VelocityDecay = 0.6 +) + +// Tunable layout parameters — override via environment variables. +var ( + layoutConfig = workspace.ResolveWatchLayoutConfig() + LinkDistance = layoutConfig.LinkDistance + ChargeStrength = layoutConfig.ChargeStrength + CollideRadius = layoutConfig.CollideRadius + GravityStrength = layoutConfig.GravityStrength +) + +// Node is a positioned graph node. ID matches an element_id in the placements table. +type Node struct { + ID int64 + X, Y float64 + VX, VY float64 + Degree int // number of connected edges +} + +// Edge is a directed connection between two Nodes. +type Edge struct { + Source *Node + Target *Node +} + +// OrganicLayout applies a D3-like force-directed layout to nodes and edges, +// mutating node X/Y positions in place. +func OrganicLayout(nodes []*Node, edges []*Edge) { + if len(nodes) == 0 { + return + } + + // Initialize degrees. + for _, e := range edges { + if e.Source != nil && e.Target != nil { + e.Source.Degree++ + e.Target.Degree++ + } + } + + // Initialize random generator and scatter unpositioned nodes to avoid exact overlapping. 
+ // #nosec G404 + rng := rand.New(rand.NewPCG(uint64(time.Now().UnixNano()), 0)) + for _, n := range nodes { + // D3 initialises unpositioned nodes in a phyllotaxis arrangement; we + // scatter in a 2:1 aspect ratio to match the canvas. + if n.X == 0 && n.Y == 0 { + n.X = (rng.Float64() - 0.5) * float64(len(nodes)) * 10.0 + n.Y = (rng.Float64() - 0.5) * float64(len(nodes)) * 5.0 + } + } + + alphaDecay := 1.0 - math.Pow(AlphaMin, 1.0/float64(Iterations)) + alpha := AlphaStart + + for range Iterations { + alpha += (AlphaMin - alpha) * alphaDecay + + // 0. Gravity — pull nodes toward the centre. + gravityStrength := GravityStrength * alpha + for _, n := range nodes { + n.VX -= n.X * gravityStrength + n.VY -= n.Y * gravityStrength * 2.0 // 2x in Y nudges layout toward 2:1 aspect ratio + } + + // 1. Many-Body Force (repulsion, O(n²)). + for i := range nodes { + n1 := nodes[i] + for j := i + 1; j < len(nodes); j++ { + n2 := nodes[j] + + dx := n2.X - n1.X + dy := n2.Y - n1.Y + distSq := dx*dx + dy*dy + if distSq == 0 { + dx = (rng.Float64() - 0.5) * 1e-3 + dy = (rng.Float64() - 0.5) * 1e-3 + distSq = dx*dx + dy*dy + } + dist := math.Sqrt(distSq) + if dist < 1.0 { + distSq = 1.0 + dist = 1.0 + } + + w := ChargeStrength * alpha / (distSq * dist) + fvX := dx * w + fvY := dy * w + + n2.VX += fvX + n2.VY += fvY + n1.VX -= fvX + n1.VY -= fvY + } + } + + // 2. Link Force (spring attraction along edges). + for _, e := range edges { + s := e.Source + t := e.Target + if s == nil || t == nil { + continue + } + + dx := t.X + t.VX - (s.X + s.VX) + dy := t.Y + t.VY - (s.Y + s.VY) + dist := math.Sqrt(dx*dx + dy*dy) + if dist == 0 { + dx = (rng.Float64() - 0.5) * 1e-3 + dy = (rng.Float64() - 0.5) * 1e-3 + dist = math.Sqrt(dx*dx + dy*dy) + } + + diff := (dist - LinkDistance) / dist * alpha + biasS := float64(s.Degree) / float64(s.Degree+t.Degree) + if s.Degree+t.Degree == 0 { + biasS = 0.5 + } + biasT := 1.0 - biasS + + t.VX -= dx * diff * biasS + t.VY -= dy * diff * biasS + s.VX += dx * diff * biasT + s.VY += dy * diff * biasT + } + + // 3. Collision Force (simple O(n²) — fine for <1 000 nodes). + r := CollideRadius * 2.0 + rSq := r * r + for i := range nodes { + n1 := nodes[i] + for j := i + 1; j < len(nodes); j++ { + n2 := nodes[j] + + dx := n2.X + n2.VX - (n1.X + n1.VX) + dy := n2.Y + n2.VY - (n1.Y + n1.VY) + distSq := dx*dx + dy*dy + + if distSq < rSq { + dist := math.Sqrt(distSq) + if dist == 0 { + dx = (rng.Float64() - 0.5) * 1e-3 + dy = (rng.Float64() - 0.5) * 1e-3 + dist = math.Sqrt(dx*dx + dy*dy) + } + diff := (r - dist) / dist * 0.7 + + n2.VX += dx * diff * 0.5 + n2.VY += dy * diff * 0.5 + n1.VX -= dx * diff * 0.5 + n1.VY -= dy * diff * 0.5 + } + } + } + + // 4. Velocity verlet integration. + for _, n := range nodes { + n.VX *= VelocityDecay + n.VY *= VelocityDecay + n.X += n.VX + n.Y += n.VY + } + } + + // 5. Shift centroid to origin. + var sumX, sumY float64 + for _, n := range nodes { + sumX += n.X + sumY += n.Y + } + if len(nodes) > 0 { + cx := sumX / float64(len(nodes)) + cy := sumY / float64(len(nodes)) + for _, n := range nodes { + n.X -= cx + n.Y -= cy + } + } + + // 6. 
Apply the same top-left offset as the frontend `runForce`: + // positions.set(n.id, { x: n.x - NODE_W / 2, y: n.y - NODE_H / 2 }) + for _, n := range nodes { + n.X -= NodeWidth / 2.0 + n.Y -= NodeHeight / 2.0 + } +} diff --git a/internal/planner/plan.go b/internal/planner/plan.go index c0cc34e..02cfa08 100644 --- a/internal/planner/plan.go +++ b/internal/planner/plan.go @@ -111,6 +111,9 @@ func buildFromElements(ws *workspace.Workspace, recreateIDs bool) (*Plan, error) if element.FilePath != "" { planElement.FilePath = &element.FilePath } + if len(element.Tags) > 0 { + planElement.Tags = append([]string(nil), element.Tags...) + } if element.ViewLabel != "" { planElement.ViewLabel = &element.ViewLabel } diff --git a/internal/planner/warnings.go b/internal/planner/warnings.go index 03dda87..9dd311d 100644 --- a/internal/planner/warnings.go +++ b/internal/planner/warnings.go @@ -172,18 +172,10 @@ func AnalyzePlan(ws *workspace.Workspace) []WarningGroup { } func newWarningContext(ws *workspace.Workspace) *warningContext { - level := workspace.DefaultValidationLevel - allowLowInsight := false - var includeRules []string - var excludeRules []string - if ws.Config.Validation != nil { - if ws.Config.Validation.Level > 0 { - level = ws.Config.Validation.Level - } - allowLowInsight = ws.Config.Validation.AllowLowInsight - includeRules = ws.Config.Validation.IncludeRules - excludeRules = ws.Config.Validation.ExcludeRules - } + level := ws.Config.Validation.Level + allowLowInsight := ws.Config.Validation.AllowLowInsight + includeRules := ws.Config.Validation.IncludeRules + excludeRules := ws.Config.Validation.ExcludeRules return &warningContext{ ws: ws, diff --git a/internal/planner/warnings_test.go b/internal/planner/warnings_test.go index 0bbc34b..0d1b427 100644 --- a/internal/planner/warnings_test.go +++ b/internal/planner/warnings_test.go @@ -91,7 +91,7 @@ func TestAnalyzePlan_TechnologyValidation(t *testing.T) { }, }, Config: workspace.Config{ - Validation: &workspace.ValidationConfig{ + Validation: workspace.ValidationConfig{ Level: tt.level, IncludeRules: tt.includeRules, ExcludeRules: tt.excludeRules, @@ -130,7 +130,7 @@ func TestAnalyzePlan_DeadEndDrilldownUsesOwnedViews(t *testing.T) { }, }, Config: workspace.Config{ - Validation: &workspace.ValidationConfig{Level: 1}, + Validation: workspace.ValidationConfig{Level: 1}, }, } diff --git a/internal/server/density.go b/internal/server/density.go new file mode 100644 index 0000000..08ffca4 --- /dev/null +++ b/internal/server/density.go @@ -0,0 +1,169 @@ +package server + +import ( + "database/sql" + "encoding/json" + "errors" + "net/http" + "strconv" + + "github.com/mertcikla/tld/internal/store" +) + +type densityRequest struct { + DensityLevel int `json:"density_level"` +} + +type visibilityOverrideRequest struct { + ResourceType string `json:"resource_type"` + ResourceID int64 `json:"resource_id"` + LevelDelta int `json:"level_delta"` +} + +func registerDensityHandlers(mux *http.ServeMux, sqliteStore *store.SQLiteStore) { + mux.HandleFunc("GET /api/views/{id}/projected-content", func(w http.ResponseWriter, r *http.Request) { + viewID, ok := parseViewID(w, r) + if !ok { + return + } + content, err := sqliteStore.ProjectedViewContent(r.Context(), viewID) + if err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, content) + }) + + mux.HandleFunc("GET /api/views/{id}/density", func(w http.ResponseWriter, r *http.Request) { + viewID, ok := parseViewID(w, r) + if !ok { + return + } + level, err := 
sqliteStore.ViewDensityLevel(r.Context(), viewID) + if err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, map[string]any{"view_id": viewID, "density_level": level}) + }) + + mux.HandleFunc("PUT /api/views/{id}/density", func(w http.ResponseWriter, r *http.Request) { + viewID, ok := parseViewID(w, r) + if !ok { + return + } + var req densityRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid JSON") + return + } + if err := sqliteStore.SetViewDensityLevel(r.Context(), viewID, req.DensityLevel); err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, map[string]any{"view_id": viewID, "density_level": req.DensityLevel}) + }) + + mux.HandleFunc("GET /api/views/{id}/visibility-overrides", func(w http.ResponseWriter, r *http.Request) { + viewID, ok := parseViewID(w, r) + if !ok { + return + } + overrides, err := sqliteStore.VisibilityOverrides(r.Context(), viewID) + if err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, map[string]any{"overrides": overrides}) + }) + + mux.HandleFunc("PUT /api/views/{id}/visibility-overrides", func(w http.ResponseWriter, r *http.Request) { + viewID, ok := parseViewID(w, r) + if !ok { + return + } + var req visibilityOverrideRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid JSON") + return + } + override, err := sqliteStore.SetVisibilityOverride(r.Context(), viewID, req.ResourceType, req.ResourceID, req.LevelDelta) + if err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, map[string]any{"override": override}) + }) + + mux.HandleFunc("POST /api/views/{id}/visibility-overrides/{resource_type}/{resource_id}/promote", func(w http.ResponseWriter, r *http.Request) { + adjustVisibilityOverride(w, r, sqliteStore, 1) + }) + mux.HandleFunc("POST /api/views/{id}/visibility-overrides/{resource_type}/{resource_id}/demote", func(w http.ResponseWriter, r *http.Request) { + adjustVisibilityOverride(w, r, sqliteStore, -1) + }) + mux.HandleFunc("DELETE /api/views/{id}/visibility-overrides/{resource_type}/{resource_id}", func(w http.ResponseWriter, r *http.Request) { + viewID, resourceType, resourceID, ok := parseOverridePath(w, r) + if !ok { + return + } + if err := sqliteStore.DeleteVisibilityOverride(r.Context(), viewID, resourceType, resourceID); err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, map[string]bool{"ok": true}) + }) +} + +func adjustVisibilityOverride(w http.ResponseWriter, r *http.Request, sqliteStore *store.SQLiteStore, step int) { + viewID, resourceType, resourceID, ok := parseOverridePath(w, r) + if !ok { + return + } + override, err := sqliteStore.AdjustVisibilityOverride(r.Context(), viewID, resourceType, resourceID, step) + if err != nil { + writeDensityError(w, err) + return + } + writeJSON(w, map[string]any{"override": override}) +} + +func parseViewID(w http.ResponseWriter, r *http.Request) (int64, bool) { + viewID, err := strconv.ParseInt(r.PathValue("id"), 10, 64) + if err != nil || viewID <= 0 { + writeJSONError(w, http.StatusBadRequest, "invalid view id") + return 0, false + } + return viewID, true +} + +func parseOverridePath(w http.ResponseWriter, r *http.Request) (int64, string, int64, bool) { + viewID, ok := parseViewID(w, r) + if !ok { + return 0, "", 0, false + } + resourceType := r.PathValue("resource_type") + if err := store.ValidateResourceType(resourceType); err != nil { + writeJSONError(w, http.StatusBadRequest, 
err.Error()) + return 0, "", 0, false + } + resourceID, err := strconv.ParseInt(r.PathValue("resource_id"), 10, 64) + if err != nil || resourceID <= 0 { + writeJSONError(w, http.StatusBadRequest, "invalid resource id") + return 0, "", 0, false + } + return viewID, resourceType, resourceID, true +} + +func writeDensityError(w http.ResponseWriter, err error) { + switch { + case errors.Is(err, sql.ErrNoRows): + writeJSONError(w, http.StatusNotFound, "view not found") + default: + writeJSONError(w, http.StatusBadRequest, err.Error()) + } +} + +func writeJSON(w http.ResponseWriter, payload any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(payload) +} diff --git a/internal/server/editor.go b/internal/server/editor.go new file mode 100644 index 0000000..de0d862 --- /dev/null +++ b/internal/server/editor.go @@ -0,0 +1,195 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + + "github.com/mertcikla/tld/internal/watch" +) + +type openEditorRequest struct { + Editor string `json:"editor"` + Repo string `json:"repo"` + FilePath string `json:"file_path"` + Line int `json:"line"` +} + +func registerEditorHandlers(mux *http.ServeMux, store *watch.Store) { + mux.HandleFunc("POST /api/editor/open", func(w http.ResponseWriter, r *http.Request) { + var req openEditorRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid JSON") + return + } + if err := openInEditor(r.Context(), store, req); err != nil { + writeJSONError(w, http.StatusBadRequest, err.Error()) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]bool{"ok": true}) + }) +} + +func openInEditor(ctx context.Context, store *watch.Store, req openEditorRequest) error { + editor := strings.TrimSpace(strings.ToLower(req.Editor)) + if editor != "zed" && editor != "vscode" { + return fmt.Errorf("unsupported editor %q", req.Editor) + } + if strings.TrimSpace(req.FilePath) == "" { + return errors.New("file_path is required") + } + + target, err := resolveEditorPath(ctx, store, req.Repo, req.FilePath) + if err != nil { + return err + } + if req.Line > 0 { + target = target + ":" + strconv.Itoa(req.Line) + } + + var cmdName string + var args []string + if editor == "zed" { + cmdName = "zed" + args = []string{target} + } else { + cmdName = "code" + args = []string{"-g", target} + } + + cmdPath, err := lookPath(cmdName) + if err != nil { + return fmt.Errorf("%s command not found in PATH", cmdName) + } + cmd := exec.Command(cmdPath, args...) 
+ if err := cmd.Start(); err != nil { + return err + } + if cmd.Process != nil { + return cmd.Process.Release() + } + return nil +} + +type repositoryFetcher interface { + Repositories(ctx context.Context) ([]watch.Repository, error) +} + +func resolveEditorPath(ctx context.Context, store repositoryFetcher, repoValue string, filePath string) (string, error) { + cleanFile := strings.TrimSpace(filePath) + + repos, err := store.Repositories(ctx) + if err != nil { + return "", fmt.Errorf("load watched repositories: %w", err) + } + + if filepath.IsAbs(cleanFile) { + cleanFile = filepath.Clean(cleanFile) + for _, repo := range repos { + root := filepath.Clean(repo.RepoRoot) + if cleanFile == root || strings.HasPrefix(cleanFile, root+string(filepath.Separator)) { + return cleanFile, nil + } + } + return "", errors.New("absolute file_path must reside within a watched repository") + } + + if strings.HasPrefix(cleanFile, "~") { + return "", errors.New("file_path must be absolute or relative to a watched repository") + } + + relative := filepath.Clean(filepath.FromSlash(cleanFile)) + if relative == "." || strings.HasPrefix(relative, ".."+string(filepath.Separator)) || relative == ".." { + return "", errors.New("file_path must stay inside the watched repository") + } + + if len(repos) == 0 { + return "", errors.New("no watched repositories are available to resolve this relative file path") + } + + repo, ok := matchRepository(repos, repoValue) + if !ok && len(repos) == 1 { + repo = repos[0] + ok = true + } + if !ok { + return "", errors.New("could not resolve the linked repository to a local worktree") + } + + root := filepath.Clean(repo.RepoRoot) + target := filepath.Clean(filepath.Join(root, relative)) + if target != root && !strings.HasPrefix(target, root+string(filepath.Separator)) { + return "", errors.New("resolved file path escapes the watched repository") + } + return target, nil +} + +func matchRepository(repos []watch.Repository, value string) (watch.Repository, bool) { + needle := strings.TrimSpace(value) + needleSlug := githubSlug(needle) + for _, repo := range repos { + candidates := []string{repo.RepoRoot} + if repo.RemoteURL.Valid { + candidates = append(candidates, repo.RemoteURL.String) + } + for _, candidate := range candidates { + if strings.EqualFold(strings.TrimSpace(candidate), needle) { + return repo, true + } + if needleSlug != "" && strings.EqualFold(githubSlug(candidate), needleSlug) { + return repo, true + } + } + } + return watch.Repository{}, false +} + +func githubSlug(value string) string { + cleaned := strings.TrimSpace(value) + cleaned = strings.TrimSuffix(cleaned, ".git") + if after, ok := strings.CutPrefix(cleaned, "git@github.com:"); ok { + return strings.ToLower(after) + } + cleaned = strings.TrimPrefix(cleaned, "https://") + cleaned = strings.TrimPrefix(cleaned, "http://") + cleaned = strings.TrimPrefix(cleaned, "github.com/") + cleaned = strings.TrimPrefix(cleaned, "www.github.com/") + parts := strings.Split(cleaned, "/") + if len(parts) >= 2 && !strings.Contains(parts[0], ".") { + return strings.ToLower(parts[0] + "/" + parts[1]) + } + if len(parts) >= 3 && strings.EqualFold(parts[0], "github.com") { + return strings.ToLower(parts[1] + "/" + parts[2]) + } + return "" +} + +func lookPath(name string) (string, error) { + if path, err := exec.LookPath(name); err == nil { + return path, nil + } + if runtime.GOOS == "darwin" { + for _, dir := range []string{"/opt/homebrew/bin", "/usr/local/bin", "/usr/bin", "/bin"} { + candidate := filepath.Join(dir, name) + if info, 
err := os.Stat(candidate); err == nil && !info.IsDir() && info.Mode()&0o111 != 0 { + return candidate, nil + } + } + } + return "", exec.ErrNotFound +} + +func writeJSONError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(map[string]string{"error": message}) +} diff --git a/internal/server/editor_test.go b/internal/server/editor_test.go new file mode 100644 index 0000000..ee15772 --- /dev/null +++ b/internal/server/editor_test.go @@ -0,0 +1,104 @@ +package server + +import ( + "context" + "path/filepath" + "testing" + + "github.com/mertcikla/tld/internal/watch" +) + +type mockStore struct { + repos []watch.Repository + err error +} + +func (m *mockStore) Repositories(ctx context.Context) ([]watch.Repository, error) { + return m.repos, m.err +} + +func TestResolveEditorPath(t *testing.T) { + repos := []watch.Repository{ + {RepoRoot: "/a/project1"}, + {RepoRoot: "/b/project2"}, + } + if filepath.Separator == '\\' { + repos = []watch.Repository{ + {RepoRoot: "C:\\a\\project1"}, + {RepoRoot: "C:\\b\\project2"}, + } + } + + store := &mockStore{repos: repos} + + t.Run("absolute path inside repository", func(t *testing.T) { + path := filepath.Join(repos[0].RepoRoot, "src", "main.go") + got, err := resolveEditorPath(context.Background(), store, "", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != path { + t.Errorf("got %q, want %q", got, path) + } + }) + + t.Run("absolute path matching repository root exactly", func(t *testing.T) { + path := repos[1].RepoRoot + got, err := resolveEditorPath(context.Background(), store, "", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != path { + t.Errorf("got %q, want %q", got, path) + } + }) + + t.Run("absolute path outside repositories", func(t *testing.T) { + path := "/etc/passwd" + if filepath.Separator == '\\' { + path = "C:\\Windows\\System32\\drivers\\etc\\hosts" + } + + _, err := resolveEditorPath(context.Background(), store, "", path) + if err == nil { + t.Fatal("expected error for path outside repository, got nil") + } + expectedErr := "absolute file_path must reside within a watched repository" + if err.Error() != expectedErr { + t.Errorf("got error %q, want %q", err.Error(), expectedErr) + } + }) + + t.Run("relative path with single repo", func(t *testing.T) { + singleStore := &mockStore{repos: repos[:1]} + path := "src/main.go" + got, err := resolveEditorPath(context.Background(), singleStore, "", path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := filepath.Join(repos[0].RepoRoot, "src", "main.go") + if got != expected { + t.Errorf("got %q, want %q", got, expected) + } + }) + + t.Run("relative path with multiple repos and explicit repo match", func(t *testing.T) { + path := "src/main.go" + got, err := resolveEditorPath(context.Background(), store, repos[1].RepoRoot, path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := filepath.Join(repos[1].RepoRoot, "src", "main.go") + if got != expected { + t.Errorf("got %q, want %q", got, expected) + } + }) + + t.Run("relative path escaping repository", func(t *testing.T) { + path := "../outside.go" + _, err := resolveEditorPath(context.Background(), store, "", path) + if err == nil { + t.Fatal("expected error for escaping path, got nil") + } + }) +} diff --git a/internal/server/server.go b/internal/server/server.go index 315d089..4546126 100644 --- a/internal/server/server.go +++ 
b/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "fmt" "io/fs" "net/http" "net/http/httputil" @@ -13,8 +14,11 @@ import ( "strings" "buf.build/gen/go/tldiagramcom/diagram/connectrpc/go/diag/v1/diagv1connect" + diagv1 "buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go/diag/v1" + "connectrpc.com/connect" "github.com/google/uuid" "github.com/mertcikla/tld/internal/store" + "github.com/mertcikla/tld/internal/watch" "github.com/mertcikla/tld/pkg/api" ) @@ -24,12 +28,18 @@ type Server struct { func New(sqliteStore *store.SQLiteStore, static fs.FS, workspaceID uuid.UUID) (*Server, error) { apiStore := store.NewAPIAdapter(sqliteStore) - wsSvc := &api.WorkspaceService{Store: apiStore} + watchStore := watch.NewStore(sqliteStore.DB()) + lockHooks := watchLockHooks{store: watchStore} + wsSvc := &api.WorkspaceService{Store: apiStore, Hooks: lockHooks} + orgSvc := &api.OrgService{Store: apiStore, Hooks: lockHooks} depSvc := &api.DependencyService{Store: apiStore} importSvc := &api.ImportService{Store: apiStore} - versionSvc := &api.WorkspaceVersionService{Store: apiStore} + versionSvc := &api.WorkspaceVersionService{Store: apiStore, Hooks: lockHooks} mux := http.NewServeMux() + watch.NewHandler(watchStore).Register(mux) + registerEditorHandlers(mux, watchStore) + registerDensityHandlers(mux, sqliteStore) mux.HandleFunc("GET /api/ready", func(w http.ResponseWriter, r *http.Request) { views, elements, connectors, err := apiStore.GetWorkspaceResourceCounts(r.Context(), workspaceID) @@ -73,6 +83,9 @@ func New(sqliteStore *store.SQLiteStore, static fs.FS, workspaceID uuid.UUID) (* wsPath, wsHandler := diagv1connect.NewWorkspaceServiceHandler(wsSvc) mux.Handle("/api"+wsPath, http.StripPrefix("/api", wsHandler)) + orgPath, orgHandler := diagv1connect.NewOrgServiceHandler(orgSvc) + mux.Handle("/api"+orgPath, http.StripPrefix("/api", orgHandler)) + depPath, depHandler := diagv1connect.NewDependencyServiceHandler(depSvc) mux.Handle("/api"+depPath, http.StripPrefix("/api", depHandler)) @@ -93,6 +106,26 @@ func New(sqliteStore *store.SQLiteStore, static fs.FS, workspaceID uuid.UUID) (* return &Server{handler: handler}, nil } +type watchLockHooks struct { + api.NopWorkspaceHooks + store *watch.Store +} + +func (h watchLockHooks) CheckWrite(ctx context.Context, _ uuid.UUID, resourceType string) error { + if h.store == nil { + return nil + } + applying, err := h.store.ActiveApplyLock(ctx, watch.LockHeartbeatTimeout) + if err != nil || !applying { + return err + } + return connect.NewError(connect.CodeFailedPrecondition, fmt.Errorf("workspace is being updated by tld watch; retry editing %s shortly", resourceType)) +} + +func (h watchLockHooks) CheckApplyPlan(ctx context.Context, workspaceID uuid.UUID, _ *diagv1.ApplyPlanRequest) error { + return h.CheckWrite(ctx, workspaceID, "workspace") +} + func (s *Server) Routes() http.Handler { return s.handler } diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..0127b30 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,247 @@ +package server + +import ( + "context" + "database/sql" + "encoding/json" + "io/fs" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "testing/fstest" + + diagv1connect "buf.build/gen/go/tldiagramcom/diagram/connectrpc/go/diag/v1/diagv1connect" + diagv1 "buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go/diag/v1" + "connectrpc.com/connect" + "github.com/google/uuid" + assets 
"github.com/mertcikla/tld" + localstore "github.com/mertcikla/tld/internal/store" + "github.com/mertcikla/tld/internal/watch" +) + +func TestServerReadyReportsResourceCounts(t *testing.T) { + sqliteStore, routes := newTestServer(t, uuid.MustParse("11111111-2222-3333-4444-555555555555"), nil) + if _, err := sqliteStore.DB().Exec(` + INSERT INTO elements(id, name, tags, technology_connectors, created_at, updated_at) + VALUES + (10, 'API', '[]', '[]', 'now', 'now'), + (11, 'DB', '[]', '[]', 'now', 'now'); + INSERT INTO connectors(view_id, source_element_id, target_element_id, direction, style, created_at, updated_at) + VALUES (1, 10, 11, 'forward', 'solid', 'now', 'now'); + `); err != nil { + t.Fatal(err) + } + + rec := httptest.NewRecorder() + routes.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/ready", nil)) + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, body = %s", rec.Code, rec.Body.String()) + } + var body struct { + OK bool `json:"ok"` + Resources struct { + Views int `json:"views"` + Elements int `json:"elements"` + Connectors int `json:"connectors"` + } `json:"resources"` + } + if err := json.NewDecoder(rec.Body).Decode(&body); err != nil { + t.Fatal(err) + } + if !body.OK || body.Resources.Views != 1 || body.Resources.Elements != 2 || body.Resources.Connectors != 1 { + t.Fatalf("ready body = %+v, want 1/2/1 resources", body) + } +} + +func TestServerOrgTagColorsRoundTrip(t *testing.T) { + _, routes := newTestServer(t, uuid.MustParse("11111111-2222-3333-4444-555555555555"), nil) + server := httptest.NewServer(routes) + defer server.Close() + + client := diagv1connect.NewOrgServiceClient(http.DefaultClient, server.URL+"/api") + description := "User managed color" + if _, err := client.UpdateTag(context.Background(), connect.NewRequest(&diagv1.UpdateTagRequest{ + Tag: "role:watch", + Color: "#123456", + Description: &description, + })); err != nil { + t.Fatal(err) + } + resp, err := client.ListTagColors(context.Background(), connect.NewRequest(&diagv1.ListTagColorsRequest{})) + if err != nil { + t.Fatal(err) + } + tag := resp.Msg.GetTags()["role:watch"] + if tag == nil || tag.GetColor() != "#123456" || tag.Description == nil || tag.GetDescription() != description { + t.Fatalf("tag = %+v, want persisted color and description", tag) + } +} + +func TestServerInjectsWorkspaceIDIntoConnectRPCResponses(t *testing.T) { + workspaceID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + sqliteStore, routes := newTestServer(t, workspaceID, nil) + if _, err := sqliteStore.DB().Exec(` + INSERT INTO elements(id, name, tags, technology_connectors, created_at, updated_at) + VALUES (10, 'API', '[]', '[]', 'now', 'now'); + `); err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes) + t.Cleanup(srv.Close) + + client := diagv1connect.NewWorkspaceServiceClient(srv.Client(), srv.URL+"/api") + resp, err := client.ListElements(context.Background(), connect.NewRequest(&diagv1.ListElementsRequest{})) + if err != nil { + t.Fatal(err) + } + if len(resp.Msg.GetElements()) != 1 { + t.Fatalf("elements = %+v, want one element", resp.Msg.GetElements()) + } + if got := resp.Msg.GetElements()[0].GetOrgId(); got != workspaceID.String() { + t.Fatalf("org id = %q, want %s", got, workspaceID) + } +} + +func TestWatchSessionLeaseDoesNotBlockWorkspaceWrites(t *testing.T) { + sqliteStore, routes := newTestServer(t, uuid.New(), nil) + repositoryID := insertWatchRepository(t, sqliteStore.DB()) + watchStore := watch.NewStore(sqliteStore.DB()) + if _, err := 
watchStore.AcquireLock(context.Background(), repositoryID, os.Getpid(), "session-token", watch.LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes) + t.Cleanup(srv.Close) + + client := diagv1connect.NewWorkspaceServiceClient(srv.Client(), srv.URL+"/api") + resp, err := client.CreateElement(context.Background(), connect.NewRequest(&diagv1.CreateElementRequest{Name: "Manual"})) + if err != nil { + t.Fatal(err) + } + if resp.Msg.GetElement().GetName() != "Manual" { + t.Fatalf("created element = %+v", resp.Msg.GetElement()) + } +} + +func TestWatchApplyLeaseBlocksWorkspaceWrites(t *testing.T) { + sqliteStore, routes := newTestServer(t, uuid.New(), nil) + repositoryID := insertWatchRepository(t, sqliteStore.DB()) + watchStore := watch.NewStore(sqliteStore.DB()) + if err := watchStore.AcquireApplyLock(context.Background(), repositoryID, os.Getpid(), "apply-token", watch.LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + srv := httptest.NewServer(routes) + t.Cleanup(srv.Close) + + client := diagv1connect.NewWorkspaceServiceClient(srv.Client(), srv.URL+"/api") + _, err := client.CreateElement(context.Background(), connect.NewRequest(&diagv1.CreateElementRequest{Name: "Manual"})) + if code := connect.CodeOf(err); code != connect.CodeFailedPrecondition { + t.Fatalf("code = %s, want failed_precondition: %v", code, err) + } + if err := watchStore.ReleaseApplyLock(context.Background(), repositoryID, "apply-token"); err != nil { + t.Fatal(err) + } + if _, err := client.CreateElement(context.Background(), connect.NewRequest(&diagv1.CreateElementRequest{Name: "Manual"})); err != nil { + t.Fatal(err) + } +} + +func TestServerRoutesThumbnailAndStaticFallback(t *testing.T) { + _, routes := newTestServer(t, uuid.New(), fstest.MapFS{ + "frontend/dist/index.html": {Data: []byte("app")}, + "frontend/dist/app.js": {Data: []byte("console.log('app')")}, + }) + + tests := []struct { + name string + path string + wantStatus int + wantType string + wantBodySub string + }{ + { + name: "root thumbnail", + path: "/api/views/1/thumbnail.svg", + wantStatus: http.StatusOK, + wantType: "image/svg+xml; charset=utf-8", + wantBodySub: "app")}} + } + sqliteStore, err := localstore.Open(filepath.Join(t.TempDir(), "tld.db"), assets.FS) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = sqliteStore.Legacy().Close() }) + srv, err := New(sqliteStore, static, workspaceID) + if err != nil { + t.Fatal(err) + } + return sqliteStore, srv.Routes() +} diff --git a/internal/store/apistore.go b/internal/store/apistore.go index b6eed8f..793cb40 100644 --- a/internal/store/apistore.go +++ b/internal/store/apistore.go @@ -2,8 +2,10 @@ package store import ( "context" + "database/sql" "errors" "fmt" + "strconv" "time" diagv1 "buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go/diag/v1" @@ -39,19 +41,30 @@ func (a *APIAdapter) ListViews(ctx context.Context, _ uuid.UUID) ([]*diagv1.View return out, nil } -func (a *APIAdapter) GetViews(ctx context.Context, _ uuid.UUID, ownerElementID *int32, isRoot *bool, search string, limit, offset int) ([]*diagv1.View, int, error) { - nodes, err := a.Store.ViewTree(ctx) - if err != nil { - return nil, 0, err +func (a *APIAdapter) GetViews(ctx context.Context, _ uuid.UUID, parentViewID *int32, isRoot *bool, search string, limit, offset int) ([]*diagv1.View, int, error) { + var flat []app.ViewTreeNode + switch { + case parentViewID != nil: + nodes, err := a.Store.legacy.ChildViews(ctx, int64(*parentViewID)) + if err != nil { + return nil, 0, err + } + flat 
= nodes + case isRoot != nil && *isRoot: + nodes, err := a.Store.legacy.RootViews(ctx) + if err != nil { + return nil, 0, err + } + flat = nodes + default: + nodes, err := a.Store.ViewTree(ctx) + if err != nil { + return nil, 0, err + } + flat = flattenViewTreeNodes(nodes) } - flat := flattenViewTreeNodes(nodes) filtered := make([]app.ViewTreeNode, 0, len(flat)) for _, node := range flat { - if ownerElementID != nil { - if node.OwnerElementID == nil || int32(*node.OwnerElementID) != *ownerElementID { - continue - } - } if isRoot != nil { nodeIsRoot := node.ParentViewID == nil if nodeIsRoot != *isRoot { @@ -114,17 +127,17 @@ func (a *APIAdapter) DeleteView(ctx context.Context, id int32, _ uuid.UUID) erro return a.Store.legacy.DeleteView(ctx, int64(id)) } -func (a *APIAdapter) ListElements(ctx context.Context, _ uuid.UUID, limit, offset int32, search string) ([]*diagv1.Element, error) { - elements, err := a.Store.legacy.Elements(ctx, int(limit), int(offset), search) +func (a *APIAdapter) ListElements(ctx context.Context, _ uuid.UUID, limit, offset int32, search string) ([]*diagv1.Element, int, error) { + elements, total, err := a.Store.legacy.Elements(ctx, int(limit), int(offset), search) if err != nil { - return nil, err + return nil, 0, err } out := make([]*diagv1.Element, 0, len(elements)) workspaceID := api.WorkspaceIDFromCtx(ctx) for _, element := range elements { out = append(out, elementToProto(element, workspaceID)) } - return out, nil + return out, total, nil } func (a *APIAdapter) GetElement(ctx context.Context, id int32, _ uuid.UUID) (*diagv1.Element, error) { @@ -182,33 +195,32 @@ func (a *APIAdapter) UpdateElement(ctx context.Context, id int32, _ uuid.UUID, i } func (a *APIAdapter) DeleteElement(ctx context.Context, id int32, _ uuid.UUID) error { + if err := a.Store.DeleteResourceVisibilityOverrides(ctx, "element", int64(id)); err != nil { + return err + } return a.Store.legacy.DeleteElement(ctx, int64(id)) } func (a *APIAdapter) ListPlacements(ctx context.Context, viewID int32) ([]*diagv1.PlacedElement, error) { - placements, err := a.Store.legacy.Placements(ctx, int64(viewID)) + content, err := a.Store.ProjectedViewContent(ctx, int64(viewID)) if err != nil { return nil, err } - out := make([]*diagv1.PlacedElement, 0, len(placements)) - for _, placement := range placements { + out := make([]*diagv1.PlacedElement, 0, len(content.Placements)) + for _, placement := range content.Placements { out = append(out, placedElementToProto(placement)) } return out, nil } func (a *APIAdapter) ListAllPlacements(ctx context.Context, _ uuid.UUID) ([]*diagv1.PlacedElement, error) { - nodes, err := a.Store.ViewTree(ctx) + placements, err := a.Store.legacy.AllPlacements(ctx) if err != nil { return nil, err } - var out []*diagv1.PlacedElement - for _, node := range flattenViewTreeNodes(nodes) { - items, err := a.ListPlacements(ctx, int32(node.ID)) - if err != nil { - return nil, err - } - out = append(out, items...) 
+ out := make([]*diagv1.PlacedElement, 0, len(placements)) + for _, placement := range placements { + out = append(out, placedElementToProto(placement)) } return out, nil } @@ -249,29 +261,25 @@ func (a *APIAdapter) RemovePlacement(ctx context.Context, viewID, elementID int3 } func (a *APIAdapter) ListConnectors(ctx context.Context, viewID int32, _ uuid.UUID) ([]*diagv1.Connector, error) { - connectors, err := a.Store.legacy.Connectors(ctx, int64(viewID)) + content, err := a.Store.ProjectedViewContent(ctx, int64(viewID)) if err != nil { return nil, err } - out := make([]*diagv1.Connector, 0, len(connectors)) - for _, connector := range connectors { + out := make([]*diagv1.Connector, 0, len(content.Connectors)) + for _, connector := range content.Connectors { out = append(out, connectorToProto(connector)) } return out, nil } func (a *APIAdapter) ListAllConnectors(ctx context.Context, _ uuid.UUID) ([]*diagv1.Connector, error) { - nodes, err := a.Store.ViewTree(ctx) + connectors, err := a.Store.legacy.AllConnectors(ctx) if err != nil { return nil, err } - var out []*diagv1.Connector - for _, node := range flattenViewTreeNodes(nodes) { - items, err := a.ListConnectors(ctx, int32(node.ID), uuid.Nil) - if err != nil { - return nil, err - } - out = append(out, items...) + out := make([]*diagv1.Connector, 0, len(connectors)) + for _, connector := range connectors { + out = append(out, connectorToProto(connector)) } return out, nil } @@ -326,6 +334,9 @@ func (a *APIAdapter) UpdateConnector(ctx context.Context, id int32, _ uuid.UUID, } func (a *APIAdapter) DeleteConnector(ctx context.Context, id int32, _ uuid.UUID) error { + if err := a.Store.DeleteResourceVisibilityOverrides(ctx, "connector", int64(id)); err != nil { + return err + } return a.Store.legacy.DeleteConnector(ctx, int64(id)) } @@ -421,6 +432,25 @@ func (a *APIAdapter) DeleteViewLayer(ctx context.Context, id int32) error { return a.Store.legacy.DeleteLayer(ctx, int64(id)) } +func (a *APIAdapter) Tags(ctx context.Context, _ uuid.UUID) (map[string]*diagv1.Tag, error) { + tags, err := a.Store.Tags(ctx) + if err != nil { + return nil, err + } + out := make(map[string]*diagv1.Tag, len(tags)) + for name, tag := range tags { + out[name] = &diagv1.Tag{ + Color: tag.Color, + Description: tag.Description, + } + } + return out, nil +} + +func (a *APIAdapter) UpdateTag(ctx context.Context, _ uuid.UUID, name, color string, description *string) error { + return a.Store.UpdateTag(ctx, name, color, description) +} + func (a *APIAdapter) ApplyPlan(ctx context.Context, _ uuid.UUID, req *diagv1.ApplyPlanRequest) (*diagv1.ApplyPlanResponse, error) { if req.GetDryRun() { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("dry_run is not supported by the local sqlite adapter")) @@ -611,40 +641,101 @@ func (a *APIAdapter) ApplyPlan(ctx context.Context, _ uuid.UUID, req *diagv1.App return resp, nil } -func (a *APIAdapter) ListVersions(context.Context, uuid.UUID, int) ([]*diagv1.WorkspaceVersionInfo, error) { - return nil, api.ErrUnimplemented +func (a *APIAdapter) ListVersions(ctx context.Context, workspaceID uuid.UUID, limit int) ([]*diagv1.WorkspaceVersionInfo, error) { + if limit <= 0 { + limit = 50 + } + rows, err := a.Store.DB().QueryContext(ctx, ` + SELECT id, version_id, source, parent_version_id, view_count, element_count, connector_count, description, workspace_hash, created_at + FROM workspace_versions + ORDER BY id DESC + LIMIT ?`, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out 
[]*diagv1.WorkspaceVersionInfo + for rows.Next() { + version, err := scanWorkspaceVersion(rows, workspaceID) + if err != nil { + return nil, err + } + out = append(out, version) + } + return out, rows.Err() } -func (a *APIAdapter) GetLatestVersion(context.Context, uuid.UUID) (*diagv1.WorkspaceVersionInfo, error) { - return nil, api.ErrUnimplemented +func (a *APIAdapter) GetLatestVersion(ctx context.Context, workspaceID uuid.UUID) (*diagv1.WorkspaceVersionInfo, error) { + row := a.Store.DB().QueryRowContext(ctx, ` + SELECT id, version_id, source, parent_version_id, view_count, element_count, connector_count, description, workspace_hash, created_at + FROM workspace_versions + ORDER BY id DESC + LIMIT 1`) + version, err := scanWorkspaceVersion(row, workspaceID) + if errors.Is(err, sql.ErrNoRows) { + return nil, api.ErrUnimplemented + } + return version, err } -func (a *APIAdapter) CreateVersion(context.Context, uuid.UUID, string, string, *int32, int, int, int, *string, *string) (*diagv1.WorkspaceVersionInfo, error) { - return nil, api.ErrUnimplemented +func (a *APIAdapter) CreateVersion(ctx context.Context, workspaceID uuid.UUID, versionID, source string, parentID *int32, viewCount, elementCount, connectorCount int, description, workspaceHash *string) (*diagv1.WorkspaceVersionInfo, error) { + var parent any + if parentID != nil { + parent = *parentID + } + res, err := a.Store.DB().ExecContext(ctx, ` + INSERT INTO workspace_versions(version_id, source, parent_version_id, view_count, element_count, connector_count, description, workspace_hash, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + versionID, source, parent, viewCount, elementCount, connectorCount, description, workspaceHash, time.Now().UTC().Format(time.RFC3339)) + if err != nil { + return nil, err + } + id, err := res.LastInsertId() + if err != nil { + return nil, err + } + row := a.Store.DB().QueryRowContext(ctx, ` + SELECT id, version_id, source, parent_version_id, view_count, element_count, connector_count, description, workspace_hash, created_at + FROM workspace_versions + WHERE id = ?`, id) + return scanWorkspaceVersion(row, workspaceID) } -func (a *APIAdapter) GetVersioningEnabled(context.Context, uuid.UUID) (bool, error) { - return false, api.ErrUnimplemented +func (a *APIAdapter) GetVersioningEnabled(ctx context.Context, _ uuid.UUID) (bool, error) { + var enabled int + err := a.Store.DB().QueryRowContext(ctx, `SELECT cli_versioning_enabled FROM workspace_version_settings WHERE id = 1`).Scan(&enabled) + if errors.Is(err, sql.ErrNoRows) { + return true, nil + } + return enabled != 0, err } -func (a *APIAdapter) SetVersioningEnabled(context.Context, uuid.UUID, bool) error { - return api.ErrUnimplemented +func (a *APIAdapter) SetVersioningEnabled(ctx context.Context, _ uuid.UUID, enabled bool) error { + value := 0 + if enabled { + value = 1 + } + _, err := a.Store.DB().ExecContext(ctx, ` + INSERT INTO workspace_version_settings(id, cli_versioning_enabled) + VALUES (1, ?) 
+ ON CONFLICT(id) DO UPDATE SET cli_versioning_enabled = excluded.cli_versioning_enabled`, value) + return err } func (a *APIAdapter) GetWorkspaceResourceCounts(ctx context.Context, _ uuid.UUID) (views, elements, connectors int, err error) { - allViews, err := a.Store.ViewTree(ctx) - if err != nil { - return 0, 0, 0, err - } - allElements, err := a.Store.legacy.Elements(ctx, 0, 0, "") - if err != nil { - return 0, 0, 0, err - } - allConnectors, err := a.ListAllConnectors(ctx, uuid.Nil) - if err != nil { - return 0, 0, 0, err + for _, item := range []struct { + query string + dest *int + }{ + {query: `SELECT COUNT(*) FROM views`, dest: &views}, + {query: `SELECT COUNT(*) FROM elements`, dest: &elements}, + {query: `SELECT COUNT(*) FROM connectors`, dest: &connectors}, + } { + if err := a.Store.DB().QueryRowContext(ctx, item.query).Scan(item.dest); err != nil { + return 0, 0, 0, err + } } - return len(flattenViewTreeNodes(allViews)), len(allElements), len(allConnectors), nil + return views, elements, connectors, nil } func (a *APIAdapter) ensureRootViewID(ctx context.Context) (int32, error) { @@ -660,6 +751,48 @@ func (a *APIAdapter) ensureRootViewID(ctx context.Context) (int32, error) { return 0, fmt.Errorf("root view not found") } +type sqlRowScanner interface { + Scan(dest ...any) error +} + +func scanWorkspaceVersion(row sqlRowScanner, workspaceID uuid.UUID) (*diagv1.WorkspaceVersionInfo, error) { + var ( + id, viewCount, elementCount, connectorCount int64 + versionID, source, createdAtRaw string + parentID sql.NullInt64 + description sql.NullString + workspaceHash sql.NullString + ) + if err := row.Scan(&id, &versionID, &source, &parentID, &viewCount, &elementCount, &connectorCount, &description, &workspaceHash, &createdAtRaw); err != nil { + return nil, err + } + createdAt, err := time.Parse(time.RFC3339, createdAtRaw) + if err != nil { + createdAt = time.Now().UTC() + } + info := &diagv1.WorkspaceVersionInfo{ + Id: strconv.FormatInt(id, 10), + OrgId: workspaceID.String(), + VersionId: versionID, + Source: source, + ViewCount: int32(viewCount), + ElementCount: int32(elementCount), + ConnectorCount: int32(connectorCount), + CreatedAt: timestamppb.New(createdAt), + } + if parentID.Valid { + parent := strconv.FormatInt(parentID.Int64, 10) + info.ParentVersionId = &parent + } + if description.Valid { + info.Description = &description.String + } + if workspaceHash.Valid { + info.WorkspaceHash = &workspaceHash.String + } + return info, nil +} + func (a *APIAdapter) findPlacedElement(ctx context.Context, viewID, elementID int64) (*diagv1.PlacedElement, error) { items, err := a.Store.legacy.Placements(ctx, viewID) if err != nil { diff --git a/internal/store/apistore_test.go b/internal/store/apistore_test.go new file mode 100644 index 0000000..193f3ef --- /dev/null +++ b/internal/store/apistore_test.go @@ -0,0 +1,171 @@ +package store + +import ( + "context" + "path/filepath" + "testing" + + "github.com/google/uuid" + assets "github.com/mertcikla/tld" + "github.com/mertcikla/tld/pkg/api" +) + +func openAdapterTestStore(t *testing.T) *SQLiteStore { + t.Helper() + sqliteStore, err := Open(filepath.Join(t.TempDir(), "tld.db"), assets.FS) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = sqliteStore.Legacy().Close() }) + return sqliteStore +} + +func TestGetWorkspaceResourceCountsUsesTableCounts(t *testing.T) { + sqliteStore := openAdapterTestStore(t) + + db := sqliteStore.DB() + if _, err := db.Exec(` + INSERT INTO elements(name, tags, technology_connectors, created_at, updated_at) + 
VALUES + ('A', '[]', '[]', 'now', 'now'), + ('B', '[]', '[]', 'now', 'now'); + INSERT INTO views(owner_element_id, name, description, level_label, level, created_at, updated_at) + VALUES (1, 'A view', NULL, 'Service', 2, 'now', 'now'); + INSERT INTO placements(view_id, element_id, position_x, position_y, created_at, updated_at) + VALUES (1, 1, 0, 0, 'now', 'now'), (2, 2, 10, 10, 'now', 'now'); + INSERT INTO connectors(view_id, source_element_id, target_element_id, direction, style, created_at, updated_at) + VALUES (1, 1, 2, 'forward', 'solid', 'now', 'now'); + `); err != nil { + t.Fatal(err) + } + + views, elements, connectors, err := NewAPIAdapter(sqliteStore).GetWorkspaceResourceCounts(context.Background(), uuid.Nil) + if err != nil { + t.Fatal(err) + } + if views != 2 || elements != 2 || connectors != 1 { + t.Fatalf("counts = views:%d elements:%d connectors:%d, want 2/2/1", views, elements, connectors) + } +} + +func TestGetViewsFiltersDirectChildrenByParentViewID(t *testing.T) { + sqliteStore := openAdapterTestStore(t) + + db := sqliteStore.DB() + if _, err := db.Exec(` + INSERT INTO elements(id, name, tags, technology_connectors, created_at, updated_at) + VALUES + (10, 'Service', '[]', '[]', 'now', 'now'), + (11, 'Component', '[]', '[]', 'now', 'now'); + INSERT INTO views(id, owner_element_id, name, description, level_label, level, created_at, updated_at) + VALUES + (20, 10, 'Service view', NULL, 'Service', 2, 'now', 'now'), + (21, 11, 'Component view', NULL, 'Component', 3, 'now', 'now'); + INSERT INTO placements(view_id, element_id, position_x, position_y, created_at, updated_at) + VALUES + (1, 10, 0, 0, 'now', 'now'), + (20, 11, 10, 10, 'now', 'now'); + `); err != nil { + t.Fatal(err) + } + + parentID := int32(1) + children, total, err := NewAPIAdapter(sqliteStore).GetViews(context.Background(), uuid.Nil, &parentID, nil, "", 0, 0) + if err != nil { + t.Fatal(err) + } + if total != 1 || len(children) != 1 || children[0].GetId() != 20 { + t.Fatalf("root children = total:%d views:%v, want only view 20", total, children) + } + + parentID = 20 + children, total, err = NewAPIAdapter(sqliteStore).GetViews(context.Background(), uuid.Nil, &parentID, nil, "", 0, 0) + if err != nil { + t.Fatal(err) + } + if total != 1 || len(children) != 1 || children[0].GetId() != 21 { + t.Fatalf("nested children = total:%d views:%v, want only view 21", total, children) + } +} + +func TestListElementsMapsSearchPaginationAndViewMetadata(t *testing.T) { + sqliteStore := openAdapterTestStore(t) + db := sqliteStore.DB() + if _, err := db.Exec(` + INSERT INTO elements(id, name, kind, description, tags, technology_connectors, created_at, updated_at) + VALUES + (10, 'API', 'service', 'Public runtime API', '["runtime"]', '[]', 'now', '2026-01-02T00:00:00Z'), + (11, 'Worker', 'service', 'Background jobs', '["runtime"]', '[]', 'now', '2026-01-03T00:00:00Z'); + INSERT INTO views(id, owner_element_id, name, description, level_label, level, created_at, updated_at) + VALUES (20, 10, 'API view', NULL, 'Service', 2, 'now', 'now'); + `); err != nil { + t.Fatal(err) + } + + items, total, err := NewAPIAdapter(sqliteStore).ListElements(context.Background(), uuid.Nil, 1, 0, "runtime") + if err != nil { + t.Fatal(err) + } + if total != 1 || len(items) != 1 || items[0].GetId() != 10 { + t.Fatalf("filtered elements = total:%d items:%+v, want only API", total, items) + } + if !items[0].GetHasView() || items[0].GetViewLabel() != "Service" { + t.Fatalf("view metadata = has:%v label:%q, want Service child view", items[0].GetHasView(), 
items[0].GetViewLabel()) + } + + items, total, err = NewAPIAdapter(sqliteStore).ListElements(context.Background(), uuid.Nil, 1, 1, "") + if err != nil { + t.Fatal(err) + } + if total != 2 || len(items) != 1 || items[0].GetId() != 10 { + t.Fatalf("paginated elements = total:%d items:%+v, want API as second updated item", total, items) + } +} + +func TestConnectorAdapterPreservesHandlesDefaultsAndViewFiltering(t *testing.T) { + sqliteStore := openAdapterTestStore(t) + db := sqliteStore.DB() + if _, err := db.Exec(` + INSERT INTO elements(id, name, tags, technology_connectors, created_at, updated_at) + VALUES + (10, 'API', '[]', '[]', 'now', 'now'), + (11, 'DB', '[]', '[]', 'now', 'now'); + INSERT INTO views(id, owner_element_id, name, description, level_label, level, created_at, updated_at) + VALUES (20, 10, 'API view', NULL, 'Service', 2, 'now', 'now'); + `); err != nil { + t.Fatal(err) + } + label := "reads" + sourceHandle := "right" + targetHandle := "left" + connector, err := NewAPIAdapter(sqliteStore).CreateConnector(context.Background(), uuid.Nil, api.ConnectorInput{ + ViewID: 20, + SourceID: 10, + TargetID: 11, + Label: &label, + Style: "solid", + SourceHandle: &sourceHandle, + TargetHandle: &targetHandle, + }) + if err != nil { + t.Fatal(err) + } + if connector.GetDirection() != "forward" || connector.GetStyle() != "solid" { + t.Fatalf("connector defaults = direction:%q style:%q, want forward/solid", connector.GetDirection(), connector.GetStyle()) + } + if connector.GetSourceHandle() != "right" || connector.GetTargetHandle() != "left" { + t.Fatalf("connector handles = %q/%q, want right/left", connector.GetSourceHandle(), connector.GetTargetHandle()) + } + + all, err := NewAPIAdapter(sqliteStore).ListAllConnectors(context.Background(), uuid.Nil) + if err != nil { + t.Fatal(err) + } + inView, err := NewAPIAdapter(sqliteStore).ListConnectors(context.Background(), 20, uuid.Nil) + if err != nil { + t.Fatal(err) + } + if len(all) != 1 || len(inView) != 1 || all[0].GetId() != inView[0].GetId() { + t.Fatalf("connector list mismatch: all=%+v inView=%+v", all, inView) + } +} diff --git a/internal/store/density.go b/internal/store/density.go new file mode 100644 index 0000000..3564b29 --- /dev/null +++ b/internal/store/density.go @@ -0,0 +1,527 @@ +package store + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math" + "sort" + "time" + + "github.com/mertcikla/tld/internal/app" +) + +const ( + MinDensityLevel = -2 + MaxDensityLevel = 2 + MinOverrideDelta = -4 + MaxOverrideDelta = 4 +) + +type VisibilityOverride struct { + ViewID int64 `json:"view_id"` + ResourceType string `json:"resource_type"` + ResourceID int64 `json:"resource_id"` + LevelDelta int `json:"level_delta"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type ProjectedViewContent struct { + Placements []app.PlacedElement `json:"placements"` + Connectors []app.Connector `json:"connectors"` +} + +type densitySignalKey struct { + resourceType string + resourceID int64 +} + +type densitySignals struct { + filterScore map[densitySignalKey]float64 + filterTier map[densitySignalKey]int + architectureConfidence map[densitySignalKey]float64 +} + +func ValidateDensityLevel(level int) error { + if level < MinDensityLevel || level > MaxDensityLevel { + return fmt.Errorf("density_level must be between %d and %d", MinDensityLevel, MaxDensityLevel) + } + return nil +} + +func ValidateResourceType(resourceType string) error { + if resourceType != "element" && resourceType != "connector" { + return 
errors.New("resource_type must be element or connector") + } + return nil +} + +func clampOverrideDelta(delta int) int { + return min(MaxOverrideDelta, max(MinOverrideDelta, delta)) +} + +func (s *SQLiteStore) ViewDensityLevel(ctx context.Context, viewID int64) (int, error) { + var level int + err := s.DB().QueryRowContext(ctx, `SELECT density_level FROM views WHERE id = ?`, viewID).Scan(&level) + if errors.Is(err, sql.ErrNoRows) { + return 0, err + } + return level, err +} + +func (s *SQLiteStore) SetViewDensityLevel(ctx context.Context, viewID int64, level int) error { + if err := ValidateDensityLevel(level); err != nil { + return err + } + res, err := s.DB().ExecContext(ctx, `UPDATE views SET density_level = ?, updated_at = ? WHERE id = ?`, level, nowString(), viewID) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return sql.ErrNoRows + } + return nil +} + +func (s *SQLiteStore) VisibilityOverrides(ctx context.Context, viewID int64) ([]VisibilityOverride, error) { + rows, err := s.DB().QueryContext(ctx, ` + SELECT view_id, resource_type, resource_id, level_delta, created_at, updated_at + FROM view_visibility_overrides + WHERE view_id = ? + ORDER BY resource_type, resource_id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := make([]VisibilityOverride, 0) + for rows.Next() { + var item VisibilityOverride + if err := rows.Scan(&item.ViewID, &item.ResourceType, &item.ResourceID, &item.LevelDelta, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, err + } + out = append(out, item) + } + return out, rows.Err() +} + +func (s *SQLiteStore) SetVisibilityOverride(ctx context.Context, viewID int64, resourceType string, resourceID int64, delta int) (VisibilityOverride, error) { + if err := ValidateResourceType(resourceType); err != nil { + return VisibilityOverride{}, err + } + delta = clampOverrideDelta(delta) + if delta == 0 { + if err := s.DeleteVisibilityOverride(ctx, viewID, resourceType, resourceID); err != nil { + return VisibilityOverride{}, err + } + return VisibilityOverride{ViewID: viewID, ResourceType: resourceType, ResourceID: resourceID, LevelDelta: 0}, nil + } + now := nowString() + _, err := s.DB().ExecContext(ctx, ` + INSERT INTO view_visibility_overrides(view_id, resource_type, resource_id, level_delta, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(view_id, resource_type, resource_id) DO UPDATE SET + level_delta = excluded.level_delta, + updated_at = excluded.updated_at`, + viewID, resourceType, resourceID, delta, now, now) + if err != nil { + return VisibilityOverride{}, err + } + return s.visibilityOverride(ctx, viewID, resourceType, resourceID) +} + +func (s *SQLiteStore) AdjustVisibilityOverride(ctx context.Context, viewID int64, resourceType string, resourceID int64, step int) (VisibilityOverride, error) { + if err := ValidateResourceType(resourceType); err != nil { + return VisibilityOverride{}, err + } + var current int + err := s.DB().QueryRowContext(ctx, ` + SELECT level_delta FROM view_visibility_overrides + WHERE view_id = ? AND resource_type = ? 
AND resource_id = ?`, viewID, resourceType, resourceID).Scan(&current) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return VisibilityOverride{}, err + } + return s.SetVisibilityOverride(ctx, viewID, resourceType, resourceID, current+step) +} + +func (s *SQLiteStore) DeleteVisibilityOverride(ctx context.Context, viewID int64, resourceType string, resourceID int64) error { + if err := ValidateResourceType(resourceType); err != nil { + return err + } + _, err := s.DB().ExecContext(ctx, ` + DELETE FROM view_visibility_overrides + WHERE view_id = ? AND resource_type = ? AND resource_id = ?`, viewID, resourceType, resourceID) + return err +} + +func (s *SQLiteStore) DeleteResourceVisibilityOverrides(ctx context.Context, resourceType string, resourceID int64) error { + if err := ValidateResourceType(resourceType); err != nil { + return err + } + _, err := s.DB().ExecContext(ctx, `DELETE FROM view_visibility_overrides WHERE resource_type = ? AND resource_id = ?`, resourceType, resourceID) + return err +} + +func (s *SQLiteStore) visibilityOverride(ctx context.Context, viewID int64, resourceType string, resourceID int64) (VisibilityOverride, error) { + var item VisibilityOverride + err := s.DB().QueryRowContext(ctx, ` + SELECT view_id, resource_type, resource_id, level_delta, created_at, updated_at + FROM view_visibility_overrides + WHERE view_id = ? AND resource_type = ? AND resource_id = ?`, viewID, resourceType, resourceID).Scan( + &item.ViewID, &item.ResourceType, &item.ResourceID, &item.LevelDelta, &item.CreatedAt, &item.UpdatedAt, + ) + return item, err +} + +func (s *SQLiteStore) ProjectedViewContent(ctx context.Context, viewID int64) (ProjectedViewContent, error) { + level, err := s.ViewDensityLevel(ctx, viewID) + if err != nil { + return ProjectedViewContent{}, err + } + placements, err := s.legacy.Placements(ctx, viewID) + if err != nil { + return ProjectedViewContent{}, err + } + connectors, err := s.legacy.Connectors(ctx, viewID) + if err != nil { + return ProjectedViewContent{}, err + } + if len(placements) == 0 { + return ProjectedViewContent{Placements: placements, Connectors: connectors}, nil + } + caps := capsForDensity(level) + overrides, err := s.VisibilityOverrides(ctx, viewID) + if err != nil { + return ProjectedViewContent{}, err + } + signals := emptyDensitySignals() + if !caps.full { + signals, err = s.densitySignals(ctx, placements, connectors) + if err != nil { + return ProjectedViewContent{}, err + } + } + return projectViewContent(placements, connectors, overrides, level, signals), nil +} + +func emptyDensitySignals() densitySignals { + return densitySignals{ + filterScore: map[densitySignalKey]float64{}, + filterTier: map[densitySignalKey]int{}, + architectureConfidence: map[densitySignalKey]float64{}, + } +} + +func (s *SQLiteStore) densitySignals(ctx context.Context, placements []app.PlacedElement, connectors []app.Connector) (densitySignals, error) { + signals := emptyDensitySignals() + + elementIDs := make([]int64, 0, len(placements)) + for _, placement := range placements { + elementIDs = append(elementIDs, placement.ElementID) + } + connectorIDs := make([]int64, 0, len(connectors)) + for _, connector := range connectors { + connectorIDs = append(connectorIDs, connector.ID) + } + + if err := s.loadFilterSignals(ctx, signals, "element", elementIDs); err != nil { + return densitySignals{}, err + } + if err := s.loadFilterSignals(ctx, signals, "connector", connectorIDs); err != nil { + return densitySignals{}, err + } + if err := s.loadArchitectureSignals(ctx, 
signals, "element", elementIDs); err != nil { + return densitySignals{}, err + } + if err := s.loadArchitectureSignals(ctx, signals, "connector", connectorIDs); err != nil { + return densitySignals{}, err + } + + return signals, nil +} + +func (s *SQLiteStore) loadFilterSignals(ctx context.Context, signals densitySignals, resourceType string, resourceIDs []int64) error { + return queryIDChunks(resourceIDs, 450, func(ids []int64) error { + query, args := idInQuery(` + SELECT wm.resource_type, wm.resource_id, MAX(wfd.score), MIN(wfd.tier) + FROM watch_materialization wm + JOIN watch_filter_decisions wfd + ON wfd.owner_type = wm.owner_type + AND wfd.owner_key = wm.owner_key + WHERE wm.resource_type = ? AND wm.resource_id IN (%s) + GROUP BY wm.resource_type, wm.resource_id`, resourceType, ids) + rows, err := s.DB().QueryContext(ctx, query, args...) + if err != nil { + return err + } + for rows.Next() { + var rowResourceType string + var resourceID int64 + var score sql.NullFloat64 + var tier sql.NullInt64 + if err := rows.Scan(&rowResourceType, &resourceID, &score, &tier); err != nil { + _ = rows.Close() + return err + } + key := densitySignalKey{resourceType: rowResourceType, resourceID: resourceID} + if score.Valid { + signals.filterScore[key] = score.Float64 + } + if tier.Valid { + signals.filterTier[key] = int(tier.Int64) + } + } + if err := rows.Close(); err != nil { + return err + } + return rows.Err() + }) +} + +func (s *SQLiteStore) loadArchitectureSignals(ctx context.Context, signals densitySignals, resourceType string, resourceIDs []int64) error { + return queryIDChunks(resourceIDs, 450, func(ids []int64) error { + query, args := idInQuery(` + SELECT target_resource_type, target_resource_id, MAX(confidence) + FROM watch_architecture_links + WHERE target_resource_type = ? AND target_resource_id IN (%s) + GROUP BY target_resource_type, target_resource_id`, resourceType, ids) + rows, err := s.DB().QueryContext(ctx, query, args...) 
+ if err != nil { + return err + } + for rows.Next() { + var rowResourceType string + var resourceID int64 + var confidence sql.NullFloat64 + if err := rows.Scan(&rowResourceType, &resourceID, &confidence); err != nil { + _ = rows.Close() + return err + } + if confidence.Valid { + signals.architectureConfidence[densitySignalKey{resourceType: rowResourceType, resourceID: resourceID}] = confidence.Float64 + } + } + if err := rows.Close(); err != nil { + return err + } + return rows.Err() + }) +} + +func queryIDChunks(ids []int64, size int, fn func([]int64) error) error { + if len(ids) == 0 { + return nil + } + for start := 0; start < len(ids); start += size { + end := min(start+size, len(ids)) + if err := fn(ids[start:end]); err != nil { + return err + } + } + return nil +} + +func idInQuery(template string, resourceType string, ids []int64) (string, []any) { + placeholders := make([]byte, 0, len(ids)*2-1) + args := make([]any, 0, len(ids)+1) + args = append(args, resourceType) + for i, id := range ids { + if i > 0 { + placeholders = append(placeholders, ',') + } + placeholders = append(placeholders, '?') + args = append(args, id) + } + return fmt.Sprintf(template, string(placeholders)), args +} + +type densityCaps struct { + elements int + connectors int + full bool +} + +func capsForDensity(level int) densityCaps { + switch level { + case -2: + return densityCaps{elements: 4, connectors: 8} + case -1: + return densityCaps{elements: 8, connectors: 16} + case 1: + return densityCaps{elements: 32, connectors: 64} + case 2: + return densityCaps{full: true} + default: + return densityCaps{elements: 12, connectors: 24} + } +} + +type rankedElement struct { + item app.PlacedElement + score float64 + delta int +} + +type rankedConnector struct { + item app.Connector + score float64 + delta int +} + +func projectViewContent(placements []app.PlacedElement, connectors []app.Connector, overrides []VisibilityOverride, level int, signals densitySignals) ProjectedViewContent { + caps := capsForDensity(level) + elementDeltas := make(map[int64]int) + connectorDeltas := make(map[int64]int) + for _, override := range overrides { + switch override.ResourceType { + case "element": + elementDeltas[override.ResourceID] = override.LevelDelta + case "connector": + connectorDeltas[override.ResourceID] = override.LevelDelta + } + } + + degree := make(map[int64]int) + for _, connector := range connectors { + degree[connector.SourceElementID]++ + degree[connector.TargetElementID]++ + } + + rankedElements := make([]rankedElement, 0, len(placements)) + for _, placement := range placements { + delta := elementDeltas[placement.ElementID] + rankedElements = append(rankedElements, rankedElement{ + item: placement, + score: baseElementScore(placement, degree[placement.ElementID], signals) + float64(delta)*100, + delta: delta, + }) + } + sort.SliceStable(rankedElements, func(i, j int) bool { + if rankedElements[i].score == rankedElements[j].score { + return rankedElements[i].item.ID < rankedElements[j].item.ID + } + return rankedElements[i].score > rankedElements[j].score + }) + + visibleElements := make(map[int64]struct{}) + elementLimit := caps.elements + if caps.full { + elementLimit = len(rankedElements) + } + for _, ranked := range rankedElements { + if ranked.delta <= -4 || (caps.full && ranked.delta < 0) { + continue + } + if !caps.full && len(visibleElements) >= elementLimit && ranked.delta <= 0 { + continue + } + visibleElements[ranked.item.ElementID] = struct{}{} + } + + rankedConnectors := make([]rankedConnector, 
0, len(connectors)) + for _, connector := range connectors { + delta := connectorDeltas[connector.ID] + rankedConnectors = append(rankedConnectors, rankedConnector{ + item: connector, + score: baseConnectorScore(connector, signals) + float64(delta)*100, + delta: delta, + }) + } + sort.SliceStable(rankedConnectors, func(i, j int) bool { + if rankedConnectors[i].score == rankedConnectors[j].score { + return rankedConnectors[i].item.ID < rankedConnectors[j].item.ID + } + return rankedConnectors[i].score > rankedConnectors[j].score + }) + + visibleConnectors := make(map[int64]struct{}) + connectorLimit := caps.connectors + if caps.full { + connectorLimit = len(rankedConnectors) + } + for _, ranked := range rankedConnectors { + connector := ranked.item + if ranked.delta <= -4 || (caps.full && ranked.delta < 0) { + continue + } + if ranked.delta > 0 { + visibleElements[connector.SourceElementID] = struct{}{} + visibleElements[connector.TargetElementID] = struct{}{} + } + _, sourceVisible := visibleElements[connector.SourceElementID] + _, targetVisible := visibleElements[connector.TargetElementID] + if !sourceVisible || !targetVisible { + continue + } + if !caps.full && len(visibleConnectors) >= connectorLimit && ranked.delta <= 0 { + continue + } + visibleConnectors[connector.ID] = struct{}{} + } + + outPlacements := make([]app.PlacedElement, 0, len(visibleElements)) + for _, placement := range placements { + if _, ok := visibleElements[placement.ElementID]; ok { + outPlacements = append(outPlacements, placement) + } + } + outConnectors := make([]app.Connector, 0, len(visibleConnectors)) + for _, connector := range connectors { + if _, ok := visibleConnectors[connector.ID]; ok { + outConnectors = append(outConnectors, connector) + } + } + return ProjectedViewContent{Placements: outPlacements, Connectors: outConnectors} +} + +func baseElementScore(placement app.PlacedElement, degree int, signals densitySignals) float64 { + score := float64(degree) * 12 + key := densitySignalKey{resourceType: "element", resourceID: placement.ElementID} + score += signals.filterScore[key] * 30 + if tier, ok := signals.filterTier[key]; ok { + score += float64(max(0, 10-tier)) * 5 + } + score += signals.architectureConfidence[key] * 20 + if placement.HasView { + score += 20 + } + if placement.Description != nil && *placement.Description != "" { + score += 4 + } + if len(placement.Tags) > 0 { + score += 3 + } + if placement.FilePath != nil && *placement.FilePath != "" { + score += 2 + } + return score - math.Log1p(float64(max(0, placement.ID)))*0.001 +} + +func baseConnectorScore(connector app.Connector, signals densitySignals) float64 { + score := 0.0 + key := densitySignalKey{resourceType: "connector", resourceID: connector.ID} + score += signals.filterScore[key] * 30 + if tier, ok := signals.filterTier[key]; ok { + score += float64(max(0, 10-tier)) * 5 + } + score += signals.architectureConfidence[key] * 20 + if connector.Relationship != nil && *connector.Relationship != "" { + score += 10 + } + if connector.Label != nil && *connector.Label != "" { + score += 6 + } + if connector.Description != nil && *connector.Description != "" { + score += 3 + } + return score - math.Log1p(float64(max(0, connector.ID)))*0.001 +} + +func nowString() string { + return time.Now().UTC().Format(time.RFC3339) +} diff --git a/internal/store/density_test.go b/internal/store/density_test.go new file mode 100644 index 0000000..39783e5 --- /dev/null +++ b/internal/store/density_test.go @@ -0,0 +1,166 @@ +package store + +import ( + 
"context" + "testing" + + "github.com/mertcikla/tld/internal/app" +) + +func seedDensityView(t *testing.T, sqliteStore *SQLiteStore) { + t.Helper() + if _, err := sqliteStore.DB().Exec(` + INSERT INTO elements(id, name, tags, technology_connectors, created_at, updated_at) + VALUES + (101, 'A', '[]', '[]', 'now', 'now'), + (102, 'B', '[]', '[]', 'now', 'now'), + (103, 'C', '[]', '[]', 'now', 'now'), + (104, 'D', '[]', '[]', 'now', 'now'), + (105, 'E', '[]', '[]', 'now', 'now'), + (106, 'F', '[]', '[]', 'now', 'now'); + INSERT INTO placements(view_id, element_id, position_x, position_y, created_at, updated_at) + VALUES + (1, 101, 0, 0, 'now', 'now'), + (1, 102, 10, 0, 'now', 'now'), + (1, 103, 20, 0, 'now', 'now'), + (1, 104, 30, 0, 'now', 'now'), + (1, 105, 40, 0, 'now', 'now'), + (1, 106, 50, 0, 'now', 'now'); + INSERT INTO connectors(id, view_id, source_element_id, target_element_id, label, direction, style, created_at, updated_at) + VALUES + (201, 1, 101, 102, 'important', 'forward', 'solid', 'now', 'now'), + (202, 1, 105, 106, NULL, 'forward', 'solid', 'now', 'now'), + (203, 1, 103, 104, 'important', 'forward', 'solid', 'now', 'now'); + `); err != nil { + t.Fatal(err) + } +} + +func TestDensityValidationAndOverrideClamping(t *testing.T) { + sqliteStore := openAdapterTestStore(t) + ctx := context.Background() + + if err := sqliteStore.SetViewDensityLevel(ctx, 1, -3); err == nil { + t.Fatal("expected invalid density to fail") + } + if err := sqliteStore.SetViewDensityLevel(ctx, 1, 2); err != nil { + t.Fatal(err) + } + level, err := sqliteStore.ViewDensityLevel(ctx, 1) + if err != nil { + t.Fatal(err) + } + if level != 2 { + t.Fatalf("density = %d, want 2", level) + } + + override, err := sqliteStore.SetVisibilityOverride(ctx, 1, "element", 1, 99) + if err != nil { + t.Fatal(err) + } + if override.LevelDelta != 4 { + t.Fatalf("delta = %d, want clamp to 4", override.LevelDelta) + } + override, err = sqliteStore.SetVisibilityOverride(ctx, 1, "element", 1, 0) + if err != nil { + t.Fatal(err) + } + if override.LevelDelta != 0 { + t.Fatalf("reset delta = %d, want 0", override.LevelDelta) + } + overrides, err := sqliteStore.VisibilityOverrides(ctx, 1) + if err != nil { + t.Fatal(err) + } + if len(overrides) != 0 { + t.Fatalf("overrides after reset = %v, want none", overrides) + } +} + +func TestDensityProjectionPromotedConnectorPullsEndpoints(t *testing.T) { + sqliteStore := openAdapterTestStore(t) + seedDensityView(t, sqliteStore) + ctx := context.Background() + + if err := sqliteStore.SetViewDensityLevel(ctx, 1, -2); err != nil { + t.Fatal(err) + } + content, err := sqliteStore.ProjectedViewContent(ctx, 1) + if err != nil { + t.Fatal(err) + } + if len(content.Placements) != 4 { + t.Fatalf("compact placements = %d, want soft cap 4", len(content.Placements)) + } + if containsConnector(content.Connectors, 202) { + t.Fatal("connector 202 should be outside the compact projection before override") + } + + if _, err := sqliteStore.AdjustVisibilityOverride(ctx, 1, "connector", 202, 1); err != nil { + t.Fatal(err) + } + content, err = sqliteStore.ProjectedViewContent(ctx, 1) + if err != nil { + t.Fatal(err) + } + if !containsConnector(content.Connectors, 202) || !containsPlacement(content.Placements, 105) || !containsPlacement(content.Placements, 106) { + t.Fatalf("promoted connector did not pull endpoint: placements=%v connectors=%v", placementIDs(content.Placements), connectorIDs(content.Connectors)) + } +} + +func TestFullDensityKeepsAllExceptExplicitDemotions(t *testing.T) { + sqliteStore := 
openAdapterTestStore(t) + seedDensityView(t, sqliteStore) + ctx := context.Background() + + if err := sqliteStore.SetViewDensityLevel(ctx, 1, 2); err != nil { + t.Fatal(err) + } + if _, err := sqliteStore.AdjustVisibilityOverride(ctx, 1, "element", 102, -1); err != nil { + t.Fatal(err) + } + content, err := sqliteStore.ProjectedViewContent(ctx, 1) + if err != nil { + t.Fatal(err) + } + if containsPlacement(content.Placements, 102) { + t.Fatal("demoted element should be hidden at full density") + } + if containsConnector(content.Connectors, 201) { + t.Fatal("connector incident to hidden element should be hidden") + } +} + +func containsPlacement(items []app.PlacedElement, elementID int64) bool { + for _, item := range items { + if item.ElementID == elementID { + return true + } + } + return false +} + +func containsConnector(items []app.Connector, connectorID int64) bool { + for _, item := range items { + if item.ID == connectorID { + return true + } + } + return false +} + +func placementIDs(items []app.PlacedElement) []int64 { + out := make([]int64, 0, len(items)) + for _, item := range items { + out = append(out, item.ElementID) + } + return out +} + +func connectorIDs(items []app.Connector) []int64 { + out := make([]int64, 0, len(items)) + for _, item := range items { + out = append(out, item.ID) + } + return out +} diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go index a6634e2..8edb761 100644 --- a/internal/store/sqlite.go +++ b/internal/store/sqlite.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "embed" "github.com/mertcikla/tld/internal/app" @@ -26,6 +27,10 @@ func (s *SQLiteStore) Legacy() *app.Store { return s.legacy } +func (s *SQLiteStore) DB() *sql.DB { + return s.legacy.DB() +} + func (s *SQLiteStore) ViewTree(ctx context.Context) ([]core.ViewTreeNode, error) { out, err := s.legacy.ViewTree(ctx) if err != nil { @@ -138,12 +143,12 @@ func (s *SQLiteStore) ThumbnailSVG(ctx context.Context, viewID int64) (string, e return s.legacy.ThumbnailSVG(ctx, viewID) } -func (s *SQLiteStore) Elements(ctx context.Context, limit, offset int, search string) ([]core.LibraryElement, error) { - out, err := s.legacy.Elements(ctx, limit, offset, search) +func (s *SQLiteStore) Elements(ctx context.Context, limit, offset int, search string) ([]core.LibraryElement, int, error) { + out, total, err := s.legacy.Elements(ctx, limit, offset, search) if err != nil { - return nil, err + return nil, 0, err } - return convertSlice(out, func(v app.LibraryElement) core.LibraryElement { return core.LibraryElement(v) }), nil + return convertSlice(out, func(v app.LibraryElement) core.LibraryElement { return core.LibraryElement(v) }), total, nil } func (s *SQLiteStore) ElementByID(ctx context.Context, id int64) (core.LibraryElement, error) { diff --git a/internal/symbol/grammars/src/go/main.go b/internal/symbol/grammars/src/go/main.go index 77c1dd3..0381cc5 100644 --- a/internal/symbol/grammars/src/go/main.go +++ b/internal/symbol/grammars/src/go/main.go @@ -31,6 +31,10 @@ type Ref struct { Line int `json:"line"` } +type Reffer struct { + Name string `json:"name"` + Line int `json:"line"` +} type Result struct { Symbols []Symbol `json:"symbols"` Refs []Ref `json:"refs"` diff --git a/internal/tagcolors/tagcolors.go b/internal/tagcolors/tagcolors.go new file mode 100644 index 0000000..9e665b1 --- /dev/null +++ b/internal/tagcolors/tagcolors.go @@ -0,0 +1,115 @@ +package tagcolors + +import ( + "context" + crand "crypto/rand" + "database/sql" + "fmt" + "hash/fnv" + "strings" +) + 
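+// SwatchColors is the curated palette handed out to newly created tags; PickUnusedColor draws from it first and falls back to random hex colors once every swatch is taken.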
+var SwatchColors = []string{ + "#F56565", "#ED8936", "#ECC94B", "#48BB78", "#38B2AC", + "#4299E1", "#667EEA", "#9F7AEA", "#ED64A6", "#A0AEC0", +} + +func Ensure(ctx context.Context, db *sql.DB, tags []string) error { + if len(tags) == 0 { + return nil + } + + rows, err := db.QueryContext(ctx, `SELECT name, color FROM tags ORDER BY name`) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + + existing := map[string]struct{}{} + var usedColors []string + for rows.Next() { + var name, color string + if err := rows.Scan(&name, &color); err != nil { + return err + } + existing[name] = struct{}{} + usedColors = append(usedColors, color) + } + if err := rows.Err(); err != nil { + return err + } + + for _, name := range tags { + name = strings.TrimSpace(name) + if name == "" { + continue + } + if _, ok := existing[name]; ok { + continue + } + color := PickUnusedColor(usedColors) + if _, err := db.ExecContext(ctx, `INSERT OR IGNORE INTO tags(name, color, description) VALUES (?, ?, NULL)`, name, color); err != nil { + return err + } + usedColors = append(usedColors, color) + existing[name] = struct{}{} + } + return nil +} + +func PickUnusedColor(usedColors []string) string { + used := make(map[string]bool) + for _, c := range usedColors { + used[strings.ToUpper(c)] = true + } + + var pool []string + for _, c := range SwatchColors { + if !used[strings.ToUpper(c)] { + pool = append(pool, c) + } + } + + source := pool + if len(source) == 0 { + return randomUnusedColor(used) + } + + return source[randomIndex(len(source))] +} + +func randomIndex(n int) int { + if n <= 1 { + return 0 + } + var b [1]byte + if _, err := crand.Read(b[:]); err == nil { + return int(b[0]) % n + } + return 0 +} + +func randomUnusedColor(used map[string]bool) string { + var b [3]byte + for range 32 { + if _, err := crand.Read(b[:]); err == nil { + color := fmt.Sprintf("#%02X%02X%02X", b[0], b[1], b[2]) + if !used[color] { + return color + } + } + } + return fallbackUnusedColor(used) +} + +func fallbackUnusedColor(used map[string]bool) string { + for i := 0; ; i++ { + h := fnv.New32a() + _, _ = fmt.Fprintf(h, "tld-tag-color-%d", i) + sum := h.Sum32() + color := fmt.Sprintf("#%06X", sum&0xFFFFFF) + if !used[color] { + return color + } + } +} diff --git a/internal/tech/icons.json b/internal/tech/icons.json index 80864ce..3682736 100644 --- a/internal/tech/icons.json +++ b/internal/tech/icons.json @@ -14308,5 +14308,15 @@ "name": "Project", "nameShort": "Project", "defaultSlug": "gcp-project" + }, + { + "name": "Architecture", + "nameShort": "Architecture", + "defaultSlug": "architecture" + }, + { + "name": "Structural", + "nameShort": "Structural", + "defaultSlug": "structural" } ] diff --git a/internal/tech/tech.go b/internal/tech/tech.go index 4b7aedc..2f5fbe9 100644 --- a/internal/tech/tech.go +++ b/internal/tech/tech.go @@ -52,7 +52,7 @@ func initializeCatalog() { manualAliases := []string{ "go", "postgres", "node", "ts", "js", "tailwind", "tailwindcss", "next.js", "k8s", "dockerfile", "python3", "cpp", "c#", "dotnet", - "aws", "gcp", "azure", + "aws", "gcp", "azure", "container", } for _, alias := range manualAliases { diff --git a/internal/tech/tech_test.go b/internal/tech/tech_test.go new file mode 100644 index 0000000..5e8bd71 --- /dev/null +++ b/internal/tech/tech_test.go @@ -0,0 +1,9 @@ +package tech + +import "testing" + +func TestValidateAcceptsContainerAsDockerAlias(t *testing.T) { + if missing := Validate("Container"); len(missing) != 0 { + t.Fatalf("Validate(%q) missing = %v, want none", 
"Container", missing) + } +} diff --git a/internal/term/ui.go b/internal/term/ui.go new file mode 100644 index 0000000..eb64ad3 --- /dev/null +++ b/internal/term/ui.go @@ -0,0 +1,109 @@ +package term + +import ( + "fmt" + "io" +) + +// ANSI color/style constants for richer output. +const ( + ColorCyan = "\033[36m" + ColorBold = "\033[1m" + ColorDim = "\033[2m" + StyleUnderlineGreenURL = ColorGreen + ColorUnderline +) + +// Styled icon prefixes. +const ( + iconSuccess = "✓" + iconInfo = "→" + iconWarn = "!" + iconFail = "✗" +) + +// Success prints a green "✓ " line. +func Success(w io.Writer, msg string) { + prefix := Colorize(w, ColorGreen, iconSuccess) + _, _ = fmt.Fprintf(w, "%s %s\n", prefix, msg) +} + +// Successf prints a green "✓ " line. +func Successf(w io.Writer, format string, args ...any) { + Success(w, fmt.Sprintf(format, args...)) +} + +// Info prints a cyan "→ " line. +func Info(w io.Writer, msg string) { + prefix := Colorize(w, ColorCyan, iconInfo) + _, _ = fmt.Fprintf(w, "%s %s\n", prefix, msg) +} + +// Infof prints a cyan "→ " line. +func Infof(w io.Writer, format string, args ...any) { + Info(w, fmt.Sprintf(format, args...)) +} + +// Warn prints a yellow "! " line. +func Warn(w io.Writer, msg string) { + prefix := Colorize(w, ColorYellow, iconWarn) + _, _ = fmt.Fprintf(w, "%s %s\n", prefix, msg) +} + +// Warnf prints a yellow "! " line. +func Warnf(w io.Writer, format string, args ...any) { + Warn(w, fmt.Sprintf(format, args...)) +} + +// Fail prints a red "✗ " line. +func Fail(w io.Writer, msg string) { + prefix := Colorize(w, ColorRed, iconFail) + _, _ = fmt.Fprintf(w, "%s %s\n", prefix, msg) +} + +// Failf prints a red "✗ " line. +func Failf(w io.Writer, format string, args ...any) { + Fail(w, fmt.Sprintf(format, args...)) +} + +// Label formats " : \n" with the key in cyan/bold when color is on. +// It uses a fixed-width left column (width chars, left-aligned). +func Label(w io.Writer, width int, key, value string) { + keyFmt := fmt.Sprintf("%-*s", width, key+":") + if IsColorEnabled(w) { + keyFmt = ColorCyan + ColorBold + keyFmt + ColorReset + } + _, _ = fmt.Fprintf(w, " %s %s\n", keyFmt, value) +} + +// URL returns the url styled as an underlined green clickable link (when color is on). +func URL(w io.Writer, url string) string { + if !IsColorEnabled(w) { + return url + } + return StyleUnderlineGreenURL + url + ColorReset +} + +// Path returns the path styled as blue text (when color is on). +func Path(w io.Writer, path string) string { + return Colorize(w, ColorBlue, path) +} + +// Badge returns a colored inline badge like "[watch]" used in log-style output. +func Badge(w io.Writer, color, text string) string { + return Colorize(w, color, text) +} + +// Dim returns the text styled as dim/grey (when color is on). +func Dim(w io.Writer, text string) string { + return Colorize(w, ColorDim, text) +} + +// Hint prints a dim indented hint line. +func Hint(w io.Writer, hint string) { + _, _ = fmt.Fprintf(w, " %s\n", Dim(w, hint)) +} + +// Separator prints a blank line. 
+func Separator(w io.Writer) { + _, _ = fmt.Fprintln(w) +} diff --git a/internal/watch/architecture.go b/internal/watch/architecture.go new file mode 100644 index 0000000..436a5c7 --- /dev/null +++ b/internal/watch/architecture.go @@ -0,0 +1,778 @@ +package watch + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + + "gopkg.in/yaml.v3" +) + +type architectureModel struct { + Components map[string]*architectureComponent + Connectors map[string]*architectureConnector +} + +type architectureComponent struct { + Key string + Name string + Kind string + Technology string + Description string + FilePath string + Tags []string + Evidence []architectureEvidence +} + +type architectureConnector struct { + Key string + SourceKey string + TargetKey string + Label string + Relationship string + Direction string + Description string + Confidence float64 + Evidence []architectureEvidence +} + +type architectureEvidence struct { + Kind string + Path string + Note string +} + +type architectureEndpoint struct { + ComponentKey string + Name string + Hosts []string + Ports []int + Protocol string + FilePath string +} + +type architectureEndpointRef struct { + SourceKey string + Target string + Protocol string + FilePath string + Note string +} + +func inferArchitecture(repoRoot string) architectureModel { + return inferArchitectureWithProgress(repoRoot, nil) +} + +func inferArchitectureWithProgress(repoRoot string, progress ProgressSink) architectureModel { + model := architectureModel{ + Components: map[string]*architectureComponent{}, + Connectors: map[string]*architectureConnector{}, + } + if strings.TrimSpace(repoRoot) == "" { + return model + } + collector := &architectureCollector{ + root: repoRoot, + model: model, + endpoints: map[string]architectureEndpoint{}, + } + files := collectArchitectureFiles(repoRoot) + progressStart(progress, "Inferring architecture artifacts", len(files)) + defer progressFinish(progress) + for _, file := range files { + collector.scanFile(filepath.Join(repoRoot, filepath.FromSlash(file)), file) + progressAdvance(progress, file) + } + collector.resolveEndpointRefs() + return model +} + +func collectArchitectureFiles(repoRoot string) []string { + var files []string + _ = filepath.WalkDir(repoRoot, func(absPath string, entry os.DirEntry, err error) error { + if err != nil { + return nil + } + if entry.IsDir() { + if shouldSkipArchitectureDir(entry.Name()) && absPath != repoRoot { + return filepath.SkipDir + } + return nil + } + rel := filepath.ToSlash(mustRel(repoRoot, absPath)) + if shouldSkipArchitectureFile(rel, entry.Name()) { + return nil + } + if isArchitectureArtifact(rel) { + files = append(files, rel) + } + return nil + }) + sort.Strings(files) + return files +} + +func isArchitectureArtifact(rel string) bool { + switch strings.ToLower(filepath.Ext(rel)) { + case ".yaml", ".yml", ".proto", ".tf", ".json": + return true + default: + return false + } +} + +type architectureCollector struct { + root string + model architectureModel + endpoints map[string]architectureEndpoint + endpointRefs []architectureEndpointRef +} + +func (c *architectureCollector) scanFile(absPath, rel string) { + ext := strings.ToLower(filepath.Ext(rel)) + switch ext { + case ".yaml", ".yml": + c.scanYAML(absPath, rel) + case ".proto": + c.scanProto(absPath, rel) + case ".tf": + c.scanTerraform(absPath, rel) + case ".json": + c.scanJSONSpec(absPath, rel) + } +} + +func (c *architectureCollector) 
scanYAML(absPath, rel string) { + data, err := os.ReadFile(absPath) + if err != nil || !looksLikeRuntimeYAML(data) { + return + } + dec := yaml.NewDecoder(bytes.NewReader(data)) + for { + var doc map[string]any + err := dec.Decode(&doc) + if err == io.EOF { + break + } + if err != nil { + break + } + if len(doc) == 0 { + continue + } + c.consumeYAMLDocument(doc, rel) + } +} + +func (c *architectureCollector) consumeYAMLDocument(doc map[string]any, rel string) { + kind := stringValue(doc["kind"]) + apiVersion := stringValue(doc["apiVersion"]) + switch strings.ToLower(kind) { + case "deployment", "statefulset", "daemonset", "replicaset", "job", "cronjob", "pod": + c.consumeKubernetesWorkload(kind, doc, rel) + case "service": + c.consumeKubernetesService(doc, rel) + case "ingress", "gateway", "httproute", "tcproute", "virtualservice": + c.consumeIngress(kind, doc, rel) + case "networkpolicy": + c.consumeNetworkPolicy(doc, rel) + case "serviceentry": + c.consumeExternalServiceEntry(doc, rel) + default: + if strings.Contains(strings.ToLower(apiVersion), "gateway.networking") { + c.consumeIngress(kind, doc, rel) + } + } +} + +func (c *architectureCollector) consumeKubernetesWorkload(kind string, doc map[string]any, rel string) { + name := metadataName(doc) + if name == "" { + return + } + key := architectureKey("component", name) + component := c.ensureComponent(key, name, "service", "Kubernetes", rel, architectureEvidence{Kind: "deployable", Path: rel, Note: kind}) + component.Tags = appendUnique(component.Tags, "arch:deployable", "runtime:kubernetes") + for _, container := range kubernetesContainers(doc) { + if image := stringValue(container["image"]); image != "" { + component.Technology = firstNonEmpty(component.Technology, imageTechnology(image), "Container") + component.Tags = appendUnique(component.Tags, "arch:container") + } + for _, port := range portsFromList(sliceValue(container["ports"])) { + c.addEndpoint(name, architectureEndpoint{ComponentKey: key, Name: name, Hosts: []string{name}, Ports: []int{port}, FilePath: rel}) + } + for _, ref := range endpointRefsFromEnv(sliceValue(container["env"])) { + ref.SourceKey = key + ref.FilePath = rel + c.endpointRefs = append(c.endpointRefs, ref) + } + } +} + +func (c *architectureCollector) consumeKubernetesService(doc map[string]any, rel string) { + name := metadataName(doc) + if name == "" { + return + } + key := architectureKey("component", name) + component := c.ensureComponent(key, name, "service", "Kubernetes", rel, architectureEvidence{Kind: "endpoint", Path: rel, Note: "Kubernetes Service"}) + component.Tags = appendUnique(component.Tags, "arch:endpoint", "runtime:kubernetes") + spec := mapValue(doc["spec"]) + if serviceType := stringValue(spec["type"]); strings.EqualFold(serviceType, "LoadBalancer") || strings.EqualFold(serviceType, "NodePort") { + externalKey := architectureKey("external", "external traffic") + c.ensureComponent(externalKey, "External traffic", "external", "Network", rel, architectureEvidence{Kind: "ingress", Path: rel, Note: serviceType}) + c.addConnector(externalKey, key, "routes", "ingress", 0.82, architectureEvidence{Kind: "ingress", Path: rel, Note: serviceType}) + } + ports := portsFromList(sliceValue(spec["ports"])) + hosts := []string{name} + if clusterIP := stringValue(spec["clusterIP"]); clusterIP != "" && !strings.EqualFold(clusterIP, "none") { + hosts = append(hosts, clusterIP) + } + c.addEndpoint(name, architectureEndpoint{ComponentKey: key, Name: name, Hosts: hosts, Ports: ports, FilePath: rel}) +} + 
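+// consumeIngress records an "External traffic" component and adds "routes" connectors from it to any referenced services that resolve to known endpoints.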
+func (c *architectureCollector) consumeIngress(kind string, doc map[string]any, rel string) { + name := metadataName(doc) + externalKey := architectureKey("external", "external traffic") + c.ensureComponent(externalKey, "External traffic", "external", "Network", rel, architectureEvidence{Kind: "ingress", Path: rel, Note: kind}) + for _, target := range serviceNamesFromYAML(doc) { + if targetKey := c.lookupEndpointTarget(target); targetKey != "" { + c.addConnector(externalKey, targetKey, "routes", "ingress", 0.82, architectureEvidence{Kind: "ingress", Path: rel, Note: firstNonEmpty(name, kind)}) + } + } +} + +func (c *architectureCollector) consumeNetworkPolicy(doc map[string]any, rel string) { + targetName := metadataName(doc) + targetKey := c.lookupEndpointTarget(targetName) + if targetKey == "" { + return + } + for _, peer := range serviceNamesFromYAML(doc) { + sourceKey := c.lookupEndpointTarget(peer) + if sourceKey == "" || sourceKey == targetKey { + continue + } + c.addConnector(sourceKey, targetKey, "allows", "network-policy", 0.55, architectureEvidence{Kind: "network-policy", Path: rel, Note: "ingress policy"}) + } +} + +func (c *architectureCollector) consumeExternalServiceEntry(doc map[string]any, rel string) { + for _, host := range stringList(mapValue(doc["spec"])["hosts"]) { + key := architectureKey("external", host) + c.ensureComponent(key, host, "external", "Network", rel, architectureEvidence{Kind: "external-dependency", Path: rel, Note: "service entry"}) + c.addEndpoint(host, architectureEndpoint{ComponentKey: key, Name: host, Hosts: []string{host}, FilePath: rel}) + } +} + +func (c *architectureCollector) scanProto(absPath, rel string) { + data, err := os.ReadFile(absPath) + if err != nil || isGeneratedSource(rel, data) { + return + } + re := regexp.MustCompile(`(?m)^\s*service\s+([A-Za-z_][A-Za-z0-9_]*)\s*\{`) + matches := re.FindAllStringSubmatch(string(data), -1) + for _, match := range matches { + name := match[1] + key := architectureKey("contract", name) + component := c.ensureComponent(key, name, "interface", "gRPC", rel, architectureEvidence{Kind: "service-contract", Path: rel, Note: "protobuf service"}) + component.Tags = appendUnique(component.Tags, "arch:contract", "protocol:grpc") + c.addEndpoint(name, architectureEndpoint{ComponentKey: key, Name: name, Hosts: []string{name}, Protocol: "grpc", FilePath: rel}) + } +} + +func (c *architectureCollector) scanJSONSpec(absPath, rel string) { + data, err := os.ReadFile(absPath) + if err != nil || !bytes.Contains(data, []byte(`"openapi"`)) { + return + } + name := strings.TrimSuffix(path.Base(rel), path.Ext(rel)) + key := architectureKey("contract", name) + component := c.ensureComponent(key, name, "interface", "OpenAPI", rel, architectureEvidence{Kind: "service-contract", Path: rel, Note: "OpenAPI"}) + component.Tags = appendUnique(component.Tags, "arch:contract", "protocol:http") +} + +func (c *architectureCollector) scanTerraform(absPath, rel string) { + file, err := os.Open(absPath) + if err != nil { + return + } + defer func() { _ = file.Close() }() + re := regexp.MustCompile(`^\s*resource\s+"([^"]+)"\s+"([^"]+)"`) + scanner := bufio.NewScanner(file) + for scanner.Scan() { + match := re.FindStringSubmatch(scanner.Text()) + if len(match) != 3 { + continue + } + resourceType, resourceName := match[1], match[2] + kind, tech, ok := infrastructureKind(resourceType) + if !ok { + continue + } + key := architectureKey(kind, resourceType+":"+resourceName) + component := c.ensureComponent(key, resourceName, kind, tech, 
rel, architectureEvidence{Kind: "infrastructure", Path: rel, Note: resourceType}) + component.Tags = appendUnique(component.Tags, "arch:infrastructure") + c.addEndpoint(resourceName, architectureEndpoint{ComponentKey: key, Name: resourceName, Hosts: []string{resourceName}, FilePath: rel}) + } +} + +func (c *architectureCollector) resolveEndpointRefs() { + for _, ref := range c.endpointRefs { + targetKey := c.lookupEndpointTarget(ref.Target) + if targetKey == "" || ref.SourceKey == "" || ref.SourceKey == targetKey { + continue + } + label := "uses" + if ref.Protocol != "" { + label = ref.Protocol + } + c.addConnector(ref.SourceKey, targetKey, label, "runtime-dependency", 0.78, architectureEvidence{Kind: "consumed-endpoint", Path: ref.FilePath, Note: ref.Note}) + } +} + +func (c *architectureCollector) lookupEndpointTarget(value string) string { + host := normalizeEndpointHost(value) + if host == "" { + return "" + } + if ep, ok := c.endpoints[host]; ok { + return ep.ComponentKey + } + return "" +} + +func (c *architectureCollector) addEndpoint(host string, ep architectureEndpoint) { + for _, candidate := range append(ep.Hosts, host, ep.Name) { + normalized := normalizeEndpointHost(candidate) + if normalized == "" { + continue + } + c.endpoints[normalized] = ep + short := strings.Split(normalized, ".")[0] + if short != "" { + c.endpoints[short] = ep + } + } +} + +func (c *architectureCollector) ensureComponent(key, name, kind, technology, filePath string, evidence architectureEvidence) *architectureComponent { + if existing := c.model.Components[key]; existing != nil { + existing.Evidence = append(existing.Evidence, evidence) + existing.Technology = firstNonEmpty(existing.Technology, technology) + existing.FilePath = firstNonEmpty(existing.FilePath, filePath) + return existing + } + component := &architectureComponent{ + Key: key, + Name: name, + Kind: kind, + Technology: technology, + FilePath: filePath, + Tags: []string{"arch:component"}, + Evidence: []architectureEvidence{evidence}, + } + c.model.Components[key] = component + return component +} + +func (c *architectureCollector) addConnector(sourceKey, targetKey, label, relationship string, confidence float64, evidence architectureEvidence) { + if sourceKey == "" || targetKey == "" || sourceKey == targetKey { + return + } + key := sourceKey + "->" + targetKey + ":" + relationship + ":" + label + if existing := c.model.Connectors[key]; existing != nil { + existing.Evidence = append(existing.Evidence, evidence) + if confidence > existing.Confidence { + existing.Confidence = confidence + } + return + } + c.model.Connectors[key] = &architectureConnector{ + Key: key, + SourceKey: sourceKey, + TargetKey: targetKey, + Label: label, + Relationship: relationship, + Direction: "forward", + Confidence: confidence, + Evidence: []architectureEvidence{evidence}, + } +} + +func shouldSkipArchitectureDir(name string) bool { + switch strings.ToLower(name) { + case ".git", ".tld", "node_modules", "vendor", "dist", "build", "target", ".cache", ".terraform", "coverage": + return true + default: + return false + } +} + +func shouldSkipArchitectureFile(rel, name string) bool { + lower := strings.ToLower(name) + switch filepath.Ext(lower) { + case ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".css", ".lock", ".sum", ".md": + return true + } + return strings.Contains(lower, ".test.") || strings.HasSuffix(lower, "_test.go") || strings.HasSuffix(lower, ".generated.go") +} + +func looksLikeRuntimeYAML(data []byte) bool { + lower := bytes.ToLower(data) + return 
bytes.Contains(lower, []byte("kind:")) || bytes.Contains(lower, []byte("services:")) || bytes.Contains(lower, []byte("openapi:")) +} + +func metadataName(doc map[string]any) string { + return stringValue(mapValue(doc["metadata"])["name"]) +} + +func kubernetesContainers(doc map[string]any) []map[string]any { + var out []map[string]any + spec := mapValue(doc["spec"]) + templateSpec := mapValue(mapValue(mapValue(spec["template"])["spec"])) + for _, raw := range append(sliceValue(templateSpec["initContainers"]), sliceValue(templateSpec["containers"])...) { + if container := mapValue(raw); len(container) > 0 { + out = append(out, container) + } + } + for _, raw := range sliceValue(spec["containers"]) { + if container := mapValue(raw); len(container) > 0 { + out = append(out, container) + } + } + return out +} + +func portsFromList(values []any) []int { + var out []int + for _, raw := range values { + switch v := raw.(type) { + case map[string]any: + if port := intValue(v["containerPort"]); port > 0 { + out = append(out, port) + } + if port := intValue(v["port"]); port > 0 { + out = append(out, port) + } + if port := intValue(v["targetPort"]); port > 0 { + out = append(out, port) + } + case int: + out = append(out, v) + case string: + if port, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + out = append(out, port) + } + } + } + return uniqueInts(out) +} + +func endpointRefsFromEnv(values []any) []architectureEndpointRef { + var refs []architectureEndpointRef + for _, raw := range values { + switch v := raw.(type) { + case map[string]any: + value := stringValue(v["value"]) + if value == "" { + continue + } + if target := endpointHostCandidate(value); target != "" { + refs = append(refs, architectureEndpointRef{Target: target, Protocol: protocolFromEndpointValue(value), Note: "environment endpoint value"}) + } + case string: + if _, after, ok := strings.Cut(v, "="); ok { + if target := endpointHostCandidate(after); target != "" { + refs = append(refs, architectureEndpointRef{Target: target, Protocol: protocolFromEndpointValue(v), Note: "environment endpoint value"}) + } + } + } + } + return refs +} + +func endpointHostCandidate(value string) string { + value = strings.Trim(strings.TrimSpace(value), `"'`) + if value == "" || strings.Contains(value, "{{") || strings.Contains(value, "$(") { + return "" + } + if parsed, err := url.Parse(value); err == nil && parsed.Hostname() != "" { + return parsed.Hostname() + } + if host, _, err := net.SplitHostPort(value); err == nil { + return host + } + if strings.Contains(value, ":") { + host := strings.Split(value, ":")[0] + if looksLikeHost(host) { + return host + } + } + if looksLikeHost(value) { + return value + } + return "" +} + +func protocolFromEndpointValue(value string) string { + lower := strings.ToLower(value) + switch { + case strings.HasPrefix(lower, "http://"), strings.HasPrefix(lower, "https://"): + return "http" + case strings.Contains(lower, ":"): + return "uses" + default: + return "" + } +} + +func looksLikeHost(value string) bool { + if value == "" || strings.ContainsAny(value, " /\\") { + return false + } + if strings.Contains(value, ".") { + return true + } + for _, r := range value { + if (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '-' { + return false + } + } + return true +} + +func normalizeEndpointHost(value string) string { + host := endpointHostCandidate(value) + if host == "" { + host = strings.TrimSpace(value) + } + host = strings.Trim(strings.ToLower(host), ".") + if host == "" { + 
return "" + } + return host +} + +func serviceNamesFromYAML(value any) []string { + seen := map[string]struct{}{} + var walk func(any) + walk = func(raw any) { + switch v := raw.(type) { + case map[string]any: + for key, child := range v { + lower := strings.ToLower(key) + if lower == "name" || lower == "service" || lower == "servicename" || lower == "host" { + if name := endpointHostCandidate(stringValue(child)); name != "" { + seen[name] = struct{}{} + } + } + walk(child) + } + case []any: + for _, child := range v { + walk(child) + } + } + } + walk(value) + return sortedKeys(seen) +} + +func infrastructureKind(resourceType string) (string, string, bool) { + lower := strings.ToLower(resourceType) + switch { + case strings.Contains(lower, "redis"), strings.Contains(lower, "memcache"), strings.Contains(lower, "cache"): + return "datastore", "Cache", true + case strings.Contains(lower, "sql"), strings.Contains(lower, "database"), strings.Contains(lower, "spanner"), strings.Contains(lower, "alloydb"), strings.Contains(lower, "postgres"), strings.Contains(lower, "mysql"): + return "datastore", "Database", true + case strings.Contains(lower, "queue"), strings.Contains(lower, "pubsub"), strings.Contains(lower, "topic"), strings.Contains(lower, "subscription"): + return "queue", "Messaging", true + case strings.Contains(lower, "bucket"), strings.Contains(lower, "storage"): + return "datastore", "Object Storage", true + default: + return "", "", false + } +} + +func isGeneratedSource(rel string, data []byte) bool { + lowerPath := strings.ToLower(rel) + lowerHead := strings.ToLower(string(data[:minInt(len(data), 4096)])) + return strings.Contains(lowerPath, "genproto/") || + strings.Contains(lowerPath, "_pb2") || + strings.Contains(lowerHead, "code generated") || + strings.Contains(lowerHead, "generated by") || + strings.Contains(lowerHead, "@generated") +} + +func imageTechnology(image string) string { + base := strings.ToLower(path.Base(strings.Split(image, ":")[0])) + switch { + case strings.Contains(base, "redis"): + return "Redis" + case strings.Contains(base, "postgres"): + return "PostgreSQL" + case strings.Contains(base, "mysql"): + return "MySQL" + case strings.Contains(base, "nginx"): + return "Nginx" + default: + return "Container" + } +} + +func architectureKey(kind, name string) string { + return kind + ":" + architectureSlug(name) +} + +func architectureSlug(value string) string { + value = strings.TrimSpace(strings.ToLower(value)) + var b strings.Builder + lastDash := false + for _, r := range value { + switch { + case r >= 'a' && r <= 'z', r >= '0' && r <= '9': + b.WriteRune(r) + lastDash = false + default: + if !lastDash && b.Len() > 0 { + b.WriteByte('-') + lastDash = true + } + } + } + out := strings.Trim(b.String(), "-") + if out == "" { + return "unknown" + } + return out +} + +func mapValue(raw any) map[string]any { + if value, ok := raw.(map[string]any); ok { + return value + } + return nil +} + +func sliceValue(raw any) []any { + if value, ok := raw.([]any); ok { + return value + } + return nil +} + +func stringValue(raw any) string { + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case int: + return strconv.Itoa(v) + case int64: + return strconv.FormatInt(v, 10) + case float64: + if v == float64(int64(v)) { + return strconv.FormatInt(int64(v), 10) + } + return fmt.Sprintf("%g", v) + default: + return "" + } +} + +func stringList(raw any) []string { + var out []string + for _, item := range sliceValue(raw) { + if value := stringValue(item); value 
!= "" { + out = append(out, value) + } + } + return out +} + +func intValue(raw any) int { + switch v := raw.(type) { + case int: + return v + case int64: + return int(v) + case float64: + return int(v) + case string: + i, _ := strconv.Atoi(strings.TrimSpace(v)) + return i + default: + return 0 + } +} + +func uniqueInts(values []int) []int { + seen := map[int]struct{}{} + var out []int + for _, value := range values { + if value <= 0 { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + sort.Ints(out) + return out +} + +func uniqueStrings(values []string) []string { + seen := map[string]struct{}{} + var out []string + for _, value := range values { + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +} + +func appendUnique(values []string, next ...string) []string { + return uniqueStrings(append(values, next...)) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func mustRel(root, absPath string) string { + rel, err := filepath.Rel(root, absPath) + if err != nil { + return absPath + } + return rel +} diff --git a/internal/watch/architecture_bindings.go b/internal/watch/architecture_bindings.go new file mode 100644 index 0000000..24343b8 --- /dev/null +++ b/internal/watch/architecture_bindings.go @@ -0,0 +1,343 @@ +package watch + +import ( + "path" + "sort" + "strings" +) + +const ( + minArchitectureBindingConfidence = 0.72 + maxArchitectureBindingTargets = 12 +) + +func resolveArchitectureBindings(repo Repository, architecture architectureModel, targets []ArchitectureBindingTarget) []ArchitectureBinding { + if len(architecture.Components) == 0 || len(targets) == 0 { + return nil + } + var bindings []ArchitectureBinding + for _, component := range sortedArchitectureComponents(architecture.Components) { + scored := scoreArchitectureTargets(repo, component, targets) + if len(scored) == 0 { + continue + } + if architectureTopTargetsAmbiguous(scored) { + continue + } + limit := min(maxArchitectureBindingTargets, len(scored)) + for i, candidate := range scored[:limit] { + role := "source" + if i == 0 { + role = "primary" + } + bindings = append(bindings, ArchitectureBinding{ + RepositoryID: repo.ID, + ComponentKey: component.Key, + TargetRepositoryID: candidate.Target.RepositoryID, + TargetOwnerType: candidate.Target.OwnerType, + TargetOwnerKey: candidate.Target.OwnerKey, + TargetResourceType: candidate.Target.ResourceType, + TargetResourceID: candidate.Target.ResourceID, + Role: role, + Confidence: candidate.Score, + Evidence: candidate.Evidence, + }) + } + } + sort.SliceStable(bindings, func(i, j int) bool { + if bindings[i].ComponentKey == bindings[j].ComponentKey { + if bindings[i].Role == bindings[j].Role { + return bindings[i].TargetOwnerKey < bindings[j].TargetOwnerKey + } + return architectureBindingRoleRank(bindings[i].Role) < architectureBindingRoleRank(bindings[j].Role) + } + return bindings[i].ComponentKey < bindings[j].ComponentKey + }) + return bindings +} + +type scoredArchitectureTarget struct { + Target ArchitectureBindingTarget + Score float64 + Evidence []ArchitectureBindingEvidence +} + +func scoreArchitectureTargets(repo Repository, component *architectureComponent, targets []ArchitectureBindingTarget) 
[]scoredArchitectureTarget { + var out []scoredArchitectureTarget + for _, target := range targets { + score, evidence := architectureTargetScore(repo, component, target) + if score < minArchitectureBindingConfidence { + continue + } + out = append(out, scoredArchitectureTarget{Target: target, Score: score, Evidence: evidence}) + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].Score != out[j].Score { + return out[i].Score > out[j].Score + } + leftRank := architectureTargetOwnerRank(out[i].Target.OwnerType) + rightRank := architectureTargetOwnerRank(out[j].Target.OwnerType) + if leftRank != rightRank { + return leftRank < rightRank + } + if out[i].Target.RepositoryID != out[j].Target.RepositoryID { + if out[i].Target.RepositoryID == repo.ID { + return true + } + if out[j].Target.RepositoryID == repo.ID { + return false + } + } + return out[i].Target.OwnerKey < out[j].Target.OwnerKey + }) + return out +} + +func architectureTargetScore(repo Repository, component *architectureComponent, target ArchitectureBindingTarget) (float64, []ArchitectureBindingEvidence) { + if component == nil { + return 0, nil + } + var score float64 + var evidence []ArchitectureBindingEvidence + add := func(kind, detail string, value float64) { + if value <= 0 { + return + } + score += value + evidence = append(evidence, ArchitectureBindingEvidence{Kind: kind, Detail: detail, Score: value}) + } + if target.RepositoryID == repo.ID { + add("same-repository", repo.DisplayName, 0.04) + } + componentVariants := architectureComponentVariants(component) + targetVariants := architectureTargetVariants(target) + if architectureAnyExactVariant(componentVariants, targetVariants) { + add("name-match", target.Name, 0.42) + } + if architectureAnyPathSegmentVariant(componentVariants, target.FilePath) { + add("path-token-match", target.FilePath, 0.30) + } + if architectureTokenOverlap(componentVariants, architectureTargetTokens(target)) >= 2 { + add("token-overlap", target.Name, 0.16) + } + for _, evidencePath := range architectureComponentEvidencePaths(component) { + switch { + case target.FilePath != "" && sameArchitecturePath(target.FilePath, evidencePath): + add("evidence-path", evidencePath, 0.78) + case target.OwnerType == "folder" && target.FilePath != "" && architecturePathContains(target.FilePath, evidencePath): + add("evidence-under-target", evidencePath, 0.66) + case target.FilePath != "" && sameArchitectureDir(target.FilePath, evidencePath): + add("evidence-directory", evidencePath, 0.30) + } + } + if target.OwnerType == "symbol" && architectureAnyExactVariant(componentVariants, architectureNameTokens(target.Name)) { + add("symbol-name-match", target.Name, 0.18) + } + if target.OwnerType == "fact" && architectureAnyExactVariant(componentVariants, architectureNameTokens(target.Name)) { + add("fact-name-match", target.Name, 0.20) + } + if target.OwnerType == "fact-summary" { + score -= 0.08 + } + if score > 1 { + score = 1 + } + return score, evidence +} + +func architectureTopTargetsAmbiguous(scored []scoredArchitectureTarget) bool { + if len(scored) < 2 { + return false + } + top := scored[0] + second := scored[1] + if top.Score-second.Score > 0.02 { + return false + } + return !architectureBindingHasConcreteEvidence(top.Evidence) && !architectureBindingHasConcreteEvidence(second.Evidence) +} + +func architectureBindingHasConcreteEvidence(evidence []ArchitectureBindingEvidence) bool { + for _, item := range evidence { + switch item.Kind { + case "evidence-path", "evidence-under-target", "evidence-directory": + 
return true + } + } + return false +} + +func architectureComponentVariants(component *architectureComponent) []string { + set := map[string]struct{}{} + addTokens := func(tokens []string) { + if len(tokens) == 0 { + return + } + joined := strings.Join(tokens, "-") + compact := strings.ReplaceAll(joined, "-", "") + set[joined] = struct{}{} + set[compact] = struct{}{} + if root, ok := architectureServiceRootFromTokens(tokens, true); ok { + set[root] = struct{}{} + set[strings.ReplaceAll(root, "-", "")] = struct{}{} + } + } + addTokens(architectureNameTokens(component.Name)) + for _, evidence := range component.Evidence { + addTokens(architectureNameTokens(evidence.Note)) + } + return sortedKeys(set) +} + +func architectureTargetVariants(target ArchitectureBindingTarget) []string { + set := map[string]struct{}{} + for _, value := range []string{target.Name, path.Base(target.FilePath)} { + tokens := architectureNameTokens(value) + if len(tokens) == 0 { + continue + } + joined := strings.Join(tokens, "-") + set[joined] = struct{}{} + set[strings.ReplaceAll(joined, "-", "")] = struct{}{} + } + return sortedKeys(set) +} + +func architectureTargetTokens(target ArchitectureBindingTarget) []string { + set := map[string]struct{}{} + for _, value := range []string{target.Name, target.Kind, target.Language} { + for _, token := range architectureNameTokens(value) { + set[token] = struct{}{} + } + } + for part := range strings.SplitSeq(filepathToSlash(target.FilePath), "/") { + for _, token := range architectureNameTokens(part) { + set[token] = struct{}{} + } + } + for _, tag := range target.Tags { + for _, token := range architectureNameTokens(tag) { + set[token] = struct{}{} + } + } + return sortedKeys(set) +} + +func architectureComponentEvidencePaths(component *architectureComponent) []string { + set := map[string]struct{}{} + if path := cleanArchitecturePath(component.FilePath); path != "" { + set[path] = struct{}{} + } + for _, evidence := range component.Evidence { + if path := cleanArchitecturePath(evidence.Path); path != "" { + set[path] = struct{}{} + } + } + return sortedKeys(set) +} + +func architectureAnyExactVariant(left, right []string) bool { + set := map[string]struct{}{} + for _, item := range left { + if item != "" { + set[item] = struct{}{} + } + } + for _, item := range right { + if _, ok := set[item]; ok { + return true + } + } + return false +} + +func architectureAnyPathSegmentVariant(variants []string, value string) bool { + if strings.TrimSpace(value) == "" { + return false + } + set := map[string]struct{}{} + for _, variant := range variants { + if variant != "" { + set[variant] = struct{}{} + } + } + for part := range strings.SplitSeq(filepathToSlash(value), "/") { + tokens := architectureNameTokens(part) + joined := strings.Join(tokens, "-") + if _, ok := set[joined]; ok { + return true + } + if _, ok := set[strings.ReplaceAll(joined, "-", "")]; ok { + return true + } + } + return false +} + +func architectureTokenOverlap(componentVariants []string, targetTokens []string) int { + set := map[string]struct{}{} + for _, variant := range componentVariants { + for token := range strings.SplitSeq(variant, "-") { + if token != "" && !architectureRoleToken(token) { + set[token] = struct{}{} + } + } + } + var count int + for _, token := range targetTokens { + if _, ok := set[token]; ok { + count++ + } + } + return count +} + +func cleanArchitecturePath(value string) string { + value = filepathToSlash(strings.TrimSpace(value)) + if value == "" || value == "." 
{ + return "" + } + return path.Clean(value) +} + +func sameArchitecturePath(a, b string) bool { + return cleanArchitecturePath(a) == cleanArchitecturePath(b) +} + +func sameArchitectureDir(a, b string) bool { + a, b = cleanArchitecturePath(a), cleanArchitecturePath(b) + if a == "" || b == "" { + return false + } + return path.Dir(a) == path.Dir(b) +} + +func architecturePathContains(parent, child string) bool { + parent, child = cleanArchitecturePath(parent), cleanArchitecturePath(child) + if parent == "" || child == "" || parent == "." { + return false + } + return child == parent || strings.HasPrefix(child, parent+"/") +} + +func architectureTargetOwnerRank(ownerType string) int { + switch ownerType { + case "folder": + return 0 + case "file": + return 1 + case "symbol": + return 2 + case "fact": + return 3 + default: + return 4 + } +} + +func architectureBindingRoleRank(role string) int { + if role == "primary" { + return 0 + } + return 1 +} diff --git a/internal/watch/architecture_facts.go b/internal/watch/architecture_facts.go new file mode 100644 index 0000000..e37a8d5 --- /dev/null +++ b/internal/watch/architecture_facts.go @@ -0,0 +1,325 @@ +package watch + +import ( + "encoding/json" + "path" + "regexp" + "strings" +) + +func mergeArchitectureModels(models ...architectureModel) architectureModel { + merged := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + for _, model := range models { + for key, component := range model.Components { + if component == nil { + continue + } + existing := merged.Components[key] + if existing == nil { + copyComponent := *component + copyComponent.Tags = append([]string{}, component.Tags...) + copyComponent.Evidence = append([]architectureEvidence{}, component.Evidence...) + merged.Components[key] = ©Component + continue + } + existing.Technology = firstNonEmpty(existing.Technology, component.Technology) + existing.Description = firstNonEmpty(existing.Description, component.Description) + existing.FilePath = firstNonEmpty(existing.FilePath, component.FilePath) + existing.Tags = appendUnique(existing.Tags, component.Tags...) + existing.Evidence = append(existing.Evidence, component.Evidence...) + } + for key, connector := range model.Connectors { + if connector == nil { + continue + } + existing := merged.Connectors[key] + if existing == nil { + copyConnector := *connector + copyConnector.Evidence = append([]architectureEvidence{}, connector.Evidence...) + merged.Connectors[key] = ©Connector + continue + } + if connector.Confidence > existing.Confidence { + existing.Confidence = connector.Confidence + } + existing.Evidence = append(existing.Evidence, connector.Evidence...) 
+ } + } + return merged +} + +func architectureFromFacts(facts []Fact) architectureModel { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + sourceByFile := map[string]string{} + for _, fact := range facts { + if fact.Type == enrichmentVersionType { + continue + } + if !architectureFactType(fact.Type) { + continue + } + attrs := factAttributes(fact) + source := inferredComponentFromFact(fact, attrs) + if source != "" { + sourceByFile[fact.FilePath] = source + key := architectureKey("component", source) + component := ensureFactComponent(model, key, source, "service", technologyFromLanguage(fact.FilePath), fact.FilePath, factEvidence(fact, "source-service")) + component.Tags = appendUnique(component.Tags, "arch:component", "arch:source") + } + switch fact.Type { + case "runtime.component": + name := firstNonEmpty(attrs["name"], fact.Name, fact.ObjectName) + if name == "" { + continue + } + kind := firstNonEmpty(attrs["kind"], "service") + technology := firstNonEmpty(attrs["technology"], "Runtime") + key := architectureKey(kindKey(kind), name) + filePath := fact.FilePath + evidencePath := filePath + if isComposeComponent(attrs, fact.Tags) { + evidencePath = "compose/service:" + name + } + component := ensureFactComponent(model, key, name, kind, technology, filePath, architectureEvidence{Kind: "runtime-component", Path: evidencePath, Note: fact.Name}) + component.Tags = appendUnique(component.Tags, fact.Tags...) + case "runtime.connection": + source := firstNonEmpty(attrs["source"], sourceByFile[fact.FilePath]) + target := firstNonEmpty(attrs["target"], fact.ObjectName) + addFactConnector(model, source, target, firstNonEmpty(attrs["label"], "uses"), firstNonEmpty(fact.Relationship, "runtime-dependency"), fact.Confidence, factEvidence(fact, "runtime-connection")) + case "runtime.endpoint_ref": + source := sourceByFile[fact.FilePath] + target := firstNonEmpty(attrs["target"], fact.ObjectName) + if target != "" && source != "" { + addFactConnector(model, source, target, "uses", "runtime-dependency", fact.Confidence, factEvidence(fact, "endpoint-reference")) + } + case "grpc.server", "grpc.contract": + name := firstNonEmpty(attrs["service"], fact.Name, fact.ObjectName) + if name == "" { + continue + } + kind := "interface" + if fact.Type == "grpc.server" { + kind = "service" + } + key := architectureKey(kindKey(kind), name) + component := ensureFactComponent(model, key, name, kind, "gRPC", fact.FilePath, factEvidence(fact, fact.Type)) + component.Tags = appendUnique(component.Tags, "protocol:grpc", "arch:contract") + case "grpc.client": + source := sourceByFile[fact.FilePath] + target := firstNonEmpty(attrs["service"], fact.Name, fact.ObjectName) + addFactConnector(model, source, target, "grpc", "runtime-dependency", fact.Confidence, factEvidence(fact, "grpc-client")) + case "http.client": + source := sourceByFile[fact.FilePath] + target := firstNonEmpty(attrs["target"], "external traffic") + addFactConnector(model, source, target, "http", "runtime-dependency", fact.Confidence, factEvidence(fact, "http-client")) + case "datastore.dependency", "external.dependency": + source := sourceByFile[fact.FilePath] + target := firstNonEmpty(attrs["name"], fact.Name, fact.ObjectName) + kind := "external" + if fact.Type == "datastore.dependency" { + kind = "datastore" + } + if target != "" { + component := ensureFactComponent(model, architectureKey(kind, target), target, kind, firstNonEmpty(attrs["technology"], "External"), 
fact.FilePath, factEvidence(fact, fact.Type)) + component.Tags = appendUnique(component.Tags, fact.Tags...) + } + addFactConnector(model, source, target, labelForDependency(target), "runtime-dependency", fact.Confidence, factEvidence(fact, fact.Type)) + } + } + return model +} + +func architectureFactType(factType string) bool { + switch factType { + case "runtime.component", "runtime.connection", "runtime.endpoint_ref", "grpc.server", "grpc.contract", "grpc.client", "http.client", "datastore.dependency", "external.dependency": + return true + default: + return false + } +} + +func ensureFactComponent(model architectureModel, key, name, kind, technology, filePath string, evidence architectureEvidence) *architectureComponent { + if key == "" || name == "" { + return nil + } + if existing := model.Components[key]; existing != nil { + existing.Technology = firstNonEmpty(existing.Technology, technology) + existing.FilePath = firstNonEmpty(existing.FilePath, filePath) + existing.Evidence = append(existing.Evidence, evidence) + return existing + } + component := &architectureComponent{ + Key: key, + Name: name, + Kind: kind, + Technology: technology, + FilePath: filePath, + Tags: []string{"arch:component"}, + Evidence: []architectureEvidence{evidence}, + } + model.Components[key] = component + return component +} + +func addFactConnector(model architectureModel, source, target, label, relationship string, confidence float64, evidence architectureEvidence) { + source = normalizeFactEndpoint(source) + target = normalizeFactEndpoint(target) + if source == "" || target == "" || source == target { + return + } + sourceKey := architectureKey("component", source) + targetKey := architectureKey("component", target) + if label == "redis" || strings.Contains(target, "redis") { + targetKey = architectureKey("component", target) + } + if model.Components[sourceKey] == nil { + ensureFactComponent(model, sourceKey, source, "service", "Runtime", evidence.Path, evidence) + } + if model.Components[targetKey] == nil { + ensureFactComponent(model, targetKey, target, "service", "Runtime", evidence.Path, evidence) + } + if label == "" { + label = "uses" + } + if relationship == "" { + relationship = "runtime-dependency" + } + key := sourceKey + "->" + targetKey + ":" + relationship + ":" + label + if existing := model.Connectors[key]; existing != nil { + existing.Evidence = append(existing.Evidence, evidence) + if confidence > existing.Confidence { + existing.Confidence = confidence + } + return + } + model.Connectors[key] = &architectureConnector{ + Key: key, + SourceKey: sourceKey, + TargetKey: targetKey, + Label: label, + Relationship: relationship, + Direction: "forward", + Confidence: confidence, + Evidence: []architectureEvidence{evidence}, + } +} + +func factAttributes(fact Fact) map[string]string { + var attrs map[string]string + if fact.AttributesJSON != "" { + _ = json.Unmarshal([]byte(fact.AttributesJSON), &attrs) + } + if attrs == nil { + attrs = map[string]string{} + } + return attrs +} + +func inferredComponentFromFact(fact Fact, attrs map[string]string) string { + if value := attrs["source"]; value != "" { + return value + } + return componentFromPath(fact.FilePath) +} + +func componentFromPath(rel string) string { + rel = path.Clean(strings.ReplaceAll(rel, "\\", "/")) + parts := strings.Split(rel, "/") + for i := len(parts) - 2; i >= 0; i-- { + part := strings.TrimSpace(parts[i]) + if part != "." 
&& part != "" && !architecturePathLayoutToken(part) { + return part + } + } + return "" +} + +func architecturePathLayoutToken(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "app", "apps", "cmd", "internal", "lib", "libs", "pkg", "packages", "service", "services", "source", "src": + return true + default: + return false + } +} + +func normalizeFactEndpoint(value string) string { + value = strings.Trim(strings.TrimSpace(strings.ToLower(value)), `"'`) + if value == "" || strings.Contains(value, "{{") || strings.HasPrefix(value, "$") { + return "" + } + if strings.Contains(value, "://") { + value = strings.SplitN(value, "://", 2)[1] + } + if strings.Contains(value, ":") { + value = strings.Split(value, ":")[0] + } + value = strings.Split(value, ".")[0] + value = strings.Trim(value, "/") + nonName := regexp.MustCompile(`[^a-z0-9-]+`) + value = nonName.ReplaceAllString(value, "-") + return strings.Trim(value, "-") +} + +func factEvidence(fact Fact, kind string) architectureEvidence { + note := fact.Name + if note == "" { + note = fact.Type + } + return architectureEvidence{Kind: kind, Path: fact.FilePath, Note: note} +} + +func technologyFromLanguage(filePath string) string { + switch strings.ToLower(path.Ext(filePath)) { + case ".go": + return "Go" + case ".py": + return "Python" + case ".js", ".mjs", ".cjs": + return "Javascript" + case ".ts", ".tsx": + return "Typescript" + case ".java": + return "Java" + case ".cs", ".csproj": + return ".NET" + case ".cpp", ".c", ".h": + return "C/C++" + case ".rb": + return "Ruby" + case ".php": + return "PHP" + default: + return "Runtime" + } +} + +func kindKey(kind string) string { + switch strings.ToLower(kind) { + case "datastore", "queue", "external", "contract": + return strings.ToLower(kind) + case "interface": + return "contract" + default: + return "component" + } +} + +func labelForDependency(target string) string { + if strings.Contains(target, "redis") { + return "redis" + } + if strings.Contains(target, "otel") || strings.Contains(target, "opentelemetry") { + return "observes" + } + return "uses" +} + +func isComposeComponent(attrs map[string]string, tags []string) bool { + for _, tag := range tags { + if tag == "runtime:compose" { + return true + } + } + return attrs["technology"] == "Docker Compose" +} diff --git a/internal/watch/architecture_reconcile.go b/internal/watch/architecture_reconcile.go new file mode 100644 index 0000000..0ec5aa4 --- /dev/null +++ b/internal/watch/architecture_reconcile.go @@ -0,0 +1,774 @@ +package watch + +import ( + "slices" + "sort" + "strings" +) + +func canonicalizeArchitecture(model architectureModel) architectureModel { + if len(model.Components) == 0 { + return model + } + uf := newArchitectureUnion(model.Components) + unionExactArchitectureAliases(model, uf) + unionServiceRootArchitectureAliases(model, uf) + unionInferredServiceRootArchitectureAliases(model, uf) + unionGenericArchitectureDependencies(model, uf) + return rewriteCanonicalArchitecture(model, uf) +} + +func pruneDisconnectedArchitecture(model architectureModel) architectureModel { + if len(model.Components) == 0 || len(model.Connectors) == 0 { + return architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + } + degree := map[string]int{} + connectors := map[string]*architectureConnector{} + for _, connector := range model.Connectors { + if connector == nil || connector.SourceKey == "" || connector.TargetKey == "" || connector.SourceKey == 
connector.TargetKey { + continue + } + if model.Components[connector.SourceKey] == nil || model.Components[connector.TargetKey] == nil { + continue + } + degree[connector.SourceKey]++ + degree[connector.TargetKey]++ + copyConnector := *connector + copyConnector.Evidence = append([]architectureEvidence{}, connector.Evidence...) + connectors[connector.Key] = &copyConnector + } + components := map[string]*architectureComponent{} + for key, component := range model.Components { + if component == nil || degree[key] == 0 { + continue + } + copyComponent := *component + copyComponent.Tags = append([]string{}, component.Tags...) + copyComponent.Evidence = append([]architectureEvidence{}, component.Evidence...) + components[key] = &copyComponent + } + for key, connector := range connectors { + if components[connector.SourceKey] == nil || components[connector.TargetKey] == nil { + delete(connectors, key) + } + } + return architectureModel{Components: components, Connectors: connectors} +} + +type architectureUnion struct { + parent map[string]string +} + +func newArchitectureUnion(components map[string]*architectureComponent) *architectureUnion { + parent := map[string]string{} + for key := range components { + parent[key] = key + } + return &architectureUnion{parent: parent} +} + +func (u *architectureUnion) find(key string) string { + parent, ok := u.parent[key] + if !ok { + return key + } + if parent == key { + return key + } + root := u.find(parent) + u.parent[key] = root + return root +} + +func (u *architectureUnion) union(a, b string) { + ra, rb := u.find(a), u.find(b) + if ra == rb { + return + } + if architectureCanonicalRankKey(ra) <= architectureCanonicalRankKey(rb) { + u.parent[rb] = ra + return + } + u.parent[ra] = rb +} + +func (u *architectureUnion) unionInto(parent, child string) { + rp, rc := u.find(parent), u.find(child) + if rp == rc { + return + } + u.parent[rc] = rp +} + +func unionExactArchitectureAliases(model architectureModel, uf *architectureUnion) { + byName := map[string][]string{} + for key, component := range model.Components { + name := normalizedArchitectureName(component) + if name == "" { + continue + } + byName[name] = append(byName[name], key) + } + for _, keys := range byName { + if len(keys) < 2 { + continue + } + sort.Strings(keys) + for _, key := range keys[1:] { + if architectureComponentsAliasCompatible(model.Components[keys[0]], model.Components[key]) { + uf.union(keys[0], key) + } + } + } +} + +func unionGenericArchitectureDependencies(model architectureModel, uf *architectureUnion) { + bySourceFamily := map[string]map[string]map[string]struct{}{} + for _, connector := range model.Connectors { + if connector == nil { + continue + } + source := uf.find(connector.SourceKey) + target := uf.find(connector.TargetKey) + component := model.Components[target] + if component == nil { + continue + } + family := architectureDependencyFamily(component) + if family == "" { + continue + } + sourceKey := source + "\x00" + family + if bySourceFamily[sourceKey] == nil { + bySourceFamily[sourceKey] = map[string]map[string]struct{}{"generic": {}, "concrete": {}} + } + if architectureGenericDependency(component, family) { + bySourceFamily[sourceKey]["generic"][target] = struct{}{} + continue + } + bySourceFamily[sourceKey]["concrete"][target] = struct{}{} + } + for _, group := range bySourceFamily { + if len(group["generic"]) == 0 || len(group["concrete"]) != 1 { + continue + } + var concrete string + for key := range group["concrete"] { + concrete = key + } + for generic := range 
group["generic"] { + uf.union(concrete, generic) + } + } +} + +func unionServiceRootArchitectureAliases(model architectureModel, uf *architectureUnion) { + byRoot := map[string][]string{} + for key, component := range model.Components { + root := architectureServiceRootIdentity(component) + if root == "" { + continue + } + byRoot[root] = append(byRoot[root], key) + } + for _, keys := range byRoot { + if len(keys) < 2 { + continue + } + sort.SliceStable(keys, func(i, j int) bool { + left, right := model.Components[keys[i]], model.Components[keys[j]] + leftRank, rightRank := architectureServiceAliasRank(left), architectureServiceAliasRank(right) + if leftRank != rightRank { + return leftRank < rightRank + } + return keys[i] < keys[j] + }) + canonical := keys[0] + for _, key := range keys[1:] { + if architectureServiceRootAliasCompatible(model, canonical, key) { + uf.unionInto(canonical, key) + } + } + } +} + +func unionInferredServiceRootArchitectureAliases(model architectureModel, uf *architectureUnion) { + knownRoots := knownArchitectureServiceRoots(model) + byRoot := map[string][]string{} + for key, component := range model.Components { + root := architectureInferredServiceRootIdentity(component, knownRoots) + if root == "" { + continue + } + byRoot[root] = append(byRoot[root], key) + } + for _, keys := range byRoot { + if len(keys) < 2 { + continue + } + sort.SliceStable(keys, func(i, j int) bool { + left, right := model.Components[keys[i]], model.Components[keys[j]] + leftRank, rightRank := architectureServiceAliasRank(left), architectureServiceAliasRank(right) + if leftRank != rightRank { + return leftRank < rightRank + } + return keys[i] < keys[j] + }) + canonical := keys[0] + for _, key := range keys[1:] { + if architectureInferredServiceRootAliasCompatible(model, canonical, key, knownRoots) { + uf.unionInto(canonical, key) + } + } + } +} + +func rewriteCanonicalArchitecture(model architectureModel, uf *architectureUnion) architectureModel { + out := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + for key, component := range model.Components { + if component == nil { + continue + } + root := uf.find(key) + existing := out.Components[root] + if existing == nil { + copyComponent := *component + copyComponent.Key = root + copyComponent.Tags = append([]string{}, component.Tags...) + copyComponent.Evidence = append([]architectureEvidence{}, component.Evidence...) 
+ out.Components[root] = &copyComponent + continue + } + mergeArchitectureComponent(existing, component) + } + connectorPairIndex := map[string]string{} + connectors := make([]*architectureConnector, 0, len(model.Connectors)) + for _, connector := range model.Connectors { + if connector != nil { + connectors = append(connectors, connector) + } + } + sort.SliceStable(connectors, func(i, j int) bool { + return connectors[i].Key < connectors[j].Key + }) + for _, connector := range connectors { + if connector == nil { + continue + } + source := uf.find(connector.SourceKey) + target := uf.find(connector.TargetKey) + if source == "" || target == "" || source == target { + continue + } + pairKey := architectureConnectorPairKey(source, target) + direction := normalizedArchitectureConnectorDirection(connector.Direction) + if existingKey, ok := connectorPairIndex[pairKey]; ok { + existing := out.Connectors[existingKey] + if existing == nil { + continue + } + if existing.SourceKey == target && existing.TargetKey == source { + direction = reverseArchitectureConnectorDirection(direction) + } + existing.Label = "" + if existing.Relationship != connector.Relationship { + existing.Relationship = "" + } + existing.Direction = mergeArchitectureConnectorDirections(existing.Direction, direction) + if connector.Confidence > existing.Confidence { + existing.Confidence = connector.Confidence + } + if existing.Description == "" { + existing.Description = connector.Description + } + existing.Evidence = append(existing.Evidence, connector.Evidence...) + continue + } + key := source + "->" + target + connectorPairIndex[pairKey] = key + copyConnector := *connector + copyConnector.Key = key + copyConnector.SourceKey = source + copyConnector.TargetKey = target + copyConnector.Direction = direction + copyConnector.Evidence = append([]architectureEvidence{}, connector.Evidence...)
+ out.Connectors[key] = &copyConnector + } + return out +} + +func architectureConnectorPairKey(source, target string) string { + if source <= target { + return source + "\x00" + target + } + return target + "\x00" + source +} + +func normalizedArchitectureConnectorDirection(direction string) string { + switch strings.ToLower(strings.TrimSpace(direction)) { + case "backward": + return "backward" + case "both", "bidirectional": + return "both" + case "none": + return "none" + default: + return "forward" + } +} + +func reverseArchitectureConnectorDirection(direction string) string { + switch normalizedArchitectureConnectorDirection(direction) { + case "forward": + return "backward" + case "backward": + return "forward" + default: + return normalizedArchitectureConnectorDirection(direction) + } +} + +func mergeArchitectureConnectorDirections(a, b string) string { + forward, backward, none := architectureConnectorDirectionBits(a) + bForward, bBackward, bNone := architectureConnectorDirectionBits(b) + forward = forward || bForward + backward = backward || bBackward + none = none && bNone + switch { + case forward && backward: + return "both" + case backward: + return "backward" + case forward: + return "forward" + case none: + return "none" + default: + return "forward" + } +} + +func architectureConnectorDirectionBits(direction string) (forward, backward, none bool) { + switch normalizedArchitectureConnectorDirection(direction) { + case "both": + return true, true, false + case "backward": + return false, true, false + case "none": + return false, false, true + default: + return true, false, false + } +} + +func mergeArchitectureComponent(dst, src *architectureComponent) { + if architectureComponentRank(src) < architectureComponentRank(dst) { + dst.Name = src.Name + dst.Kind = src.Kind + dst.Technology = src.Technology + dst.Description = firstNonEmpty(src.Description, dst.Description) + dst.FilePath = firstNonEmpty(src.FilePath, dst.FilePath) + } else { + dst.Technology = firstNonEmpty(dst.Technology, src.Technology) + dst.Description = firstNonEmpty(dst.Description, src.Description) + dst.FilePath = firstNonEmpty(dst.FilePath, src.FilePath) + } + dst.Tags = appendUnique(dst.Tags, src.Tags...) + dst.Evidence = append(dst.Evidence, src.Evidence...)
+} + +func architectureComponentsAliasCompatible(a, b *architectureComponent) bool { + if a == nil || b == nil { + return false + } + if normalizedArchitectureName(a) == "" || normalizedArchitectureName(a) != normalizedArchitectureName(b) { + return false + } + if architectureComponentClass(a) == "interface" || architectureComponentClass(b) == "interface" { + return architectureComponentClass(a) == architectureComponentClass(b) + } + return true +} + +func normalizedArchitectureName(component *architectureComponent) string { + if component == nil { + return "" + } + return architectureSlug(component.Name) +} + +func architectureServiceRootIdentity(component *architectureComponent) string { + if component == nil || !architectureServiceAliasCandidate(component) { + return "" + } + root, ok := architectureServiceRootFromTokens(architectureNameTokens(component.Name), false) + if !ok { + return "" + } + return root +} + +func architectureServiceRootFromTokens(tokens []string, allowShort bool) (string, bool) { + if len(tokens) == 0 { + return "", false + } + filtered := make([]string, 0, len(tokens)) + for _, token := range tokens { + if architectureRoleToken(token) { + continue + } + filtered = append(filtered, token) + } + if len(filtered) == 0 { + return "", false + } + root := strings.Join(filtered, "-") + if len(root) < 3 && !allowShort { + return "", false + } + return root, true +} + +func architectureInferredServiceRootIdentity(component *architectureComponent, knownRoots map[string]struct{}) string { + if component == nil || !architectureServiceAliasCandidate(component) { + return "" + } + tokens := architectureNameTokens(component.Name) + if root, ok := architectureServiceRootFromTokens(tokens, true); ok { + compactRoot := strings.ReplaceAll(root, "-", "") + for known := range knownRoots { + if strings.Contains(known, "-") && strings.ReplaceAll(known, "-", "") == compactRoot { + return known + } + } + if _, known := knownRoots[root]; known { + return root + } + if len(root) >= 3 { + return root + } + if architectureHasRoleToken(tokens) { + return root + } + } + compact := architectureCompactServiceRoot(component.Name) + if compact == "" { + return "" + } + for root := range knownRoots { + if strings.ReplaceAll(root, "-", "") == compact { + return root + } + } + return "" +} + +func architectureNameTokens(value string) []string { + var tokens []string + var b strings.Builder + var prevClass int + flush := func() { + if b.Len() == 0 { + return + } + tokens = append(tokens, strings.ToLower(b.String())) + b.Reset() + } + for _, r := range strings.TrimSpace(value) { + class := architectureRuneClass(r) + if class == 0 { + flush() + prevClass = 0 + continue + } + if b.Len() > 0 && class != prevClass { + if class == 2 || prevClass == 3 || class == 3 { + flush() + } + } + if r >= 'A' && r <= 'Z' { + r += 'a' - 'A' + } + b.WriteRune(r) + prevClass = class + } + flush() + return splitArchitectureAcronymSuffixes(tokens) +} + +func splitArchitectureAcronymSuffixes(tokens []string) []string { + var out []string + for _, token := range tokens { + if len(token) <= 2 { + out = append(out, token) + continue + } + matched := false + for _, suffix := range []string{"controller", "database", "service", "gateway", "adapter", "handler", "server", "client", "worker", "store", "api", "svc", "db"} { + if strings.HasSuffix(token, suffix) && len(token) > len(suffix) { + out = append(out, token[:len(token)-len(suffix)], suffix) + matched = true + break + } + } + if !matched { + out = append(out, token) + } + } + 
return out +} + +func architectureRuneClass(r rune) int { + switch { + case r >= 'a' && r <= 'z': + return 1 + case r >= 'A' && r <= 'Z': + return 2 + case r >= '0' && r <= '9': + return 3 + default: + return 0 + } +} + +func architectureRoleToken(token string) bool { + switch token { + case "service", "svc", "api", "db", "database", "store", "client", "server", "worker", "gateway", "adapter", "controller", "handler": + return true + default: + return false + } +} + +func architectureHasRoleToken(tokens []string) bool { + return slices.ContainsFunc(tokens, architectureRoleToken) +} + +func knownArchitectureServiceRoots(model architectureModel) map[string]struct{} { + known := map[string]struct{}{} + for _, component := range model.Components { + if component == nil || !architectureServiceAliasCandidate(component) { + continue + } + root, ok := architectureServiceRootFromTokens(architectureNameTokens(component.Name), true) + if !ok { + continue + } + if strings.Contains(root, "-") || architectureHasRoleToken(architectureNameTokens(component.Name)) { + known[root] = struct{}{} + } + } + return known +} + +func architectureCompactServiceRoot(name string) string { + tokens := architectureNameTokens(name) + root, ok := architectureServiceRootFromTokens(tokens, true) + if !ok { + return "" + } + return strings.ReplaceAll(root, "-", "") +} + +func architectureServiceAliasCandidate(component *architectureComponent) bool { + if component == nil { + return false + } + if architectureComponentClass(component) == "interface" { + return true + } + switch component.Kind { + case "service", "interface": + return true + default: + return false + } +} + +func architectureServiceRootAliasCompatible(model architectureModel, aKey, bKey string) bool { + a, b := model.Components[aKey], model.Components[bKey] + if !architectureServiceAliasCandidate(a) || !architectureServiceAliasCandidate(b) { + return false + } + root := architectureServiceRootIdentity(a) + if root == "" || root != architectureServiceRootIdentity(b) { + return false + } + if architectureComponentsConnected(model, aKey, bKey) { + return true + } + return architectureComponentHasStrongServiceEvidence(a) && architectureComponentHasStrongServiceEvidence(b) +} + +func architectureInferredServiceRootAliasCompatible(model architectureModel, aKey, bKey string, knownRoots map[string]struct{}) bool { + a, b := model.Components[aKey], model.Components[bKey] + if !architectureServiceAliasCandidate(a) || !architectureServiceAliasCandidate(b) { + return false + } + root := architectureInferredServiceRootIdentity(a, knownRoots) + if root == "" || root != architectureInferredServiceRootIdentity(b, knownRoots) { + return false + } + if len(root) < 3 { + return architectureShortRootAliasCompatible(a, b) + } + if architectureComponentsConnected(model, aKey, bKey) { + return true + } + return architectureComponentHasStrongServiceEvidence(a) && architectureComponentHasStrongServiceEvidence(b) +} + +func architectureShortRootAliasCompatible(a, b *architectureComponent) bool { + if !architectureComponentHasStrongServiceEvidence(a) || !architectureComponentHasStrongServiceEvidence(b) { + return false + } + return architectureHasRoleToken(architectureNameTokens(a.Name)) || architectureHasRoleToken(architectureNameTokens(b.Name)) +} + +func architectureComponentsConnected(model architectureModel, aKey, bKey string) bool { + for _, connector := range model.Connectors { + if connector == nil { + continue + } + if connector.SourceKey == aKey && connector.TargetKey == bKey 
|| connector.SourceKey == bKey && connector.TargetKey == aKey { + return true + } + } + return false +} + +func architectureComponentHasStrongServiceEvidence(component *architectureComponent) bool { + if component == nil { + return false + } + for _, ev := range component.Evidence { + switch ev.Kind { + case "deployable", "endpoint", "runtime-component", "source-service", "grpc.server", "grpc-client", "grpc.contract", "service-contract": + return true + } + } + return false +} + +func architectureServiceAliasRank(component *architectureComponent) int { + if component == nil { + return 99 + } + name := normalizedArchitectureName(component) + root := architectureServiceRootIdentity(component) + switch { + case name == root+"service": + return 0 + case component.Kind == "service" && architectureComponentHasEvidence(component, "deployable"): + return 1 + case component.Kind == "service" && architectureComponentHasEvidence(component, "runtime-component"): + return 2 + case component.Kind == "service" && name == root: + return 3 + case architectureComponentClass(component) == "interface": + return 4 + default: + return 5 + } +} + +func architectureComponentHasEvidence(component *architectureComponent, kind string) bool { + if component == nil { + return false + } + for _, ev := range component.Evidence { + if ev.Kind == kind { + return true + } + } + return false +} + +func architectureDependencyFamily(component *architectureComponent) string { + if component == nil { + return "" + } + for _, tag := range component.Tags { + tag = strings.TrimSpace(strings.ToLower(tag)) + if value, ok := strings.CutPrefix(tag, "datastore:"); ok && value != "" { + return architectureSlug(value) + } + if value, ok := strings.CutPrefix(tag, "external:"); ok && value != "" { + return architectureSlug(value) + } + } + tech := architectureSlug(component.Technology) + name := normalizedArchitectureName(component) + switch component.Kind { + case "datastore", "queue", "external": + return firstNonEmpty(tech, name) + default: + if tech != "" && tech != "runtime" && tech != "container" && tech != "kubernetes" && tech != "docker-compose" { + return tech + } + } + return "" +} + +func architectureGenericDependency(component *architectureComponent, family string) bool { + name := normalizedArchitectureName(component) + if name == "" || family == "" { + return false + } + return name == family +} + +func architectureComponentClass(component *architectureComponent) string { + if component == nil { + return "" + } + switch component.Kind { + case "datastore", "queue", "external": + return component.Kind + case "interface": + return "interface" + default: + return "component" + } +} + +func architectureCanonicalRankKey(key string) string { + return architectureRankPrefixFromKey(key) + ":" + key +} + +func architectureRankPrefixFromKey(key string) string { + kind := key + if before, _, ok := strings.Cut(key, ":"); ok { + kind = before + } + switch kind { + case "component": + return "0" + case "datastore", "queue": + return "1" + case "external": + return "2" + case "contract": + return "3" + default: + return "4" + } +} + +func architectureComponentRank(component *architectureComponent) int { + if component == nil { + return 99 + } + switch component.Kind { + case "service": + return 0 + case "datastore", "queue": + return 1 + case "external": + return 2 + case "interface": + return 3 + default: + return 4 + } +} diff --git a/internal/watch/context.go b/internal/watch/context.go new file mode 100644 index 0000000..3d0d056 --- 
/dev/null +++ b/internal/watch/context.go @@ -0,0 +1,696 @@ +package watch + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math" + "path" + "sort" + "strings" +) + +const ( + contextActionShow = "show" + contextActionHide = "hide" + contextActionClean = "clean" +) + +type contextOwner struct { + OwnerType string + OwnerKey string +} + +type contextPolicySet struct { + Show map[string]string + Hide map[string]string +} + +type contextExpansionSet struct { + Tiers map[string]int + MaxTier int +} + +type contextExpansionAdjustment struct { + TierBefore int + TierAfter int + MaxTier int + Owners int +} + +type contextRemovalStats struct { + Elements int + Connectors int + Views int +} + +func (s *Store) ApplyContextAction(ctx context.Context, repositoryID int64, action string, req ContextResourceRequest, representReq RepresentRequest) (ContextActionResult, error) { + if action != contextActionShow && action != contextActionHide && action != contextActionClean { + return ContextActionResult{}, fmt.Errorf("unsupported context action %q", action) + } + owners, err := s.contextOwnersForResource(ctx, repositoryID, action, req) + if err != nil { + return ContextActionResult{}, err + } + if len(owners) == 0 { + return ContextActionResult{}, fmt.Errorf("resource is not backed by watch materialization") + } + if action == contextActionShow { + if err := s.focusedRescanContextOwners(ctx, repositoryID, owners); err != nil { + return ContextActionResult{}, err + } + } + maxTier := maxVisibilityTier(defaultVisibilityConfig(representReq.Visibility)) + adjustment := contextExpansionAdjustment{MaxTier: maxTier} + policiesCreated, policiesUpdated, deactivated := 0, 0, 0 + switch action { + case contextActionShow: + adjustment, err = s.AdjustContextExpansion(ctx, repositoryID, req, owners, 1, maxTier) + if err != nil { + return ContextActionResult{}, err + } + case contextActionClean: + adjustment, err = s.AdjustContextExpansion(ctx, repositoryID, req, owners, -1, maxTier) + if err != nil { + return ContextActionResult{}, err + } + default: + policiesCreated, policiesUpdated, deactivated, err = s.saveContextPolicies(ctx, repositoryID, action, contextScope(req.ResourceType), owners) + if err != nil { + return ContextActionResult{}, err + } + } + before, err := s.generatedWorkspaceCounts(ctx, repositoryID) + if err != nil { + return ContextActionResult{}, err + } + representation, err := NewRepresenter(s).Represent(ctx, repositoryID, representReq) + if err != nil { + return ContextActionResult{}, err + } + after, err := s.generatedWorkspaceCounts(ctx, repositoryID) + if err != nil { + return ContextActionResult{}, err + } + summary, err := s.RepresentationSummary(ctx, repositoryID) + if err != nil { + return ContextActionResult{}, err + } + return ContextActionResult{ + RepositoryID: repositoryID, + Action: action, + PoliciesCreated: policiesCreated, + PoliciesUpdated: policiesUpdated, + PoliciesDeactivated: deactivated, + OwnersAffected: len(owners), + TierBefore: adjustment.TierBefore, + TierAfter: adjustment.TierAfter, + MaxTier: adjustment.MaxTier, + ElementsAdded: positiveDelta(after.Elements, before.Elements), + ConnectorsAdded: positiveDelta(after.Connectors, before.Connectors), + ViewsAdded: positiveDelta(after.Views, before.Views), + ElementsRemoved: positiveDelta(before.Elements, after.Elements), + ConnectorsRemoved: positiveDelta(before.Connectors, after.Connectors), + ViewsRemoved: positiveDelta(before.Views, after.Views), + Representation: representation, + Summary: summary, + }, nil 
+} + +func contextScope(resourceType string) string { + switch resourceType { + case "view": + return "view" + default: + return "element" + } +} + +func positiveDelta(before, after int) int { + if before > after { + return before - after + } + return 0 +} + +func (s *Store) generatedWorkspaceCounts(ctx context.Context, repositoryID int64) (contextRemovalStats, error) { + var out contextRemovalStats + for resourceType, dest := range map[string]*int{ + "element": &out.Elements, + "connector": &out.Connectors, + "view": &out.Views, + } { + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM watch_materialization WHERE repository_id = ? AND resource_type = ?`, repositoryID, resourceType).Scan(dest); err != nil { + return contextRemovalStats{}, err + } + } + return out, nil +} + +func maxVisibilityTier(cfg VisibilityConfig) int { + cfg = defaultVisibilityConfig(cfg) + if cfg.TierMultiplier <= 0 { + return 0 + } + maxMultiplier := cfg.MaxExpansionMultiplier + if maxMultiplier < 1 { + maxMultiplier = 1 + } + tier := int((maxMultiplier - 1) / cfg.TierMultiplier) + if tier < 0 { + return 0 + } + if tier == 0 && maxMultiplier > 1 { + return 1 + } + return tier +} + +func effectiveMaxElementsPerView(thresholds Thresholds, cfg VisibilityConfig, tier int) int { + thresholds = defaultThresholds(thresholds) + cfg = defaultVisibilityConfig(cfg) + if tier <= 0 { + return thresholds.MaxElementsPerView + } + multiplier := 1 + float64(tier)*cfg.TierMultiplier + if multiplier > cfg.MaxExpansionMultiplier { + multiplier = cfg.MaxExpansionMultiplier + } + out := int(math.Round(float64(thresholds.MaxElementsPerView) * multiplier)) + if out < thresholds.MaxElementsPerView { + return thresholds.MaxElementsPerView + } + return out +} + +func (s *Store) AdjustContextExpansion(ctx context.Context, repositoryID int64, req ContextResourceRequest, owners []contextOwner, delta, maxTier int) (contextExpansionAdjustment, error) { + owners = uniqueContextOwners(owners) + if len(owners) == 0 { + return contextExpansionAdjustment{MaxTier: maxTier}, nil + } + before, err := s.contextExpansionTier(ctx, repositoryID, req) + if err != nil { + return contextExpansionAdjustment{}, err + } + after := max(before+delta, 0) + if maxTier >= 0 && after > maxTier { + after = maxTier + } + now := nowString() + for _, owner := range owners { + if after == 0 { + if _, err := s.db.ExecContext(ctx, ` + DELETE FROM watch_context_expansions + WHERE repository_id = ? AND scope_resource_type = ? AND scope_resource_id = ? AND scope_owner_type = ? AND scope_owner_key = ?`, + repositoryID, req.ResourceType, req.ResourceID, owner.OwnerType, owner.OwnerKey); err != nil { + return contextExpansionAdjustment{}, err + } + continue + } + if _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_context_expansions(repository_id, scope_resource_type, scope_resource_id, scope_owner_type, scope_owner_key, tier, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(repository_id, scope_resource_type, scope_resource_id, scope_owner_type, scope_owner_key) + DO UPDATE SET tier = excluded.tier, updated_at = excluded.updated_at`, + repositoryID, req.ResourceType, req.ResourceID, owner.OwnerType, owner.OwnerKey, after, now, now); err != nil { + return contextExpansionAdjustment{}, err + } + } + return contextExpansionAdjustment{TierBefore: before, TierAfter: after, MaxTier: maxTier, Owners: len(owners)}, nil +} + +func (s *Store) contextExpansionTier(ctx context.Context, repositoryID int64, req ContextResourceRequest) (int, error) { + var tier sql.NullInt64 + err := s.db.QueryRowContext(ctx, ` + SELECT MAX(tier) + FROM watch_context_expansions + WHERE repository_id = ? AND scope_resource_type = ? AND scope_resource_id = ?`, + repositoryID, req.ResourceType, req.ResourceID).Scan(&tier) + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + if err != nil { + return 0, err + } + if !tier.Valid { + return 0, nil + } + return int(tier.Int64), nil +} + +func (s *Store) ActiveContextExpansionSet(ctx context.Context, repositoryID int64) (contextExpansionSet, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT scope_owner_type, scope_owner_key, tier + FROM watch_context_expansions + WHERE repository_id = ? AND tier > 0 + ORDER BY id`, repositoryID) + if err != nil { + return contextExpansionSet{}, err + } + defer func() { _ = rows.Close() }() + set := contextExpansionSet{Tiers: map[string]int{}} + for rows.Next() { + var ownerType, ownerKey string + var tier int + if err := rows.Scan(&ownerType, &ownerKey, &tier); err != nil { + return contextExpansionSet{}, err + } + key := ownerMapKey(ownerType, ownerKey) + if tier > set.Tiers[key] { + set.Tiers[key] = tier + } + if tier > set.MaxTier { + set.MaxTier = tier + } + } + return set, rows.Err() +} + +func (s *Store) contextOwnersForResource(ctx context.Context, repositoryID int64, action string, req ContextResourceRequest) ([]contextOwner, error) { + if req.ResourceID <= 0 { + return nil, fmt.Errorf("invalid resource id") + } + switch req.ResourceType { + case "element": + return s.contextOwnersForElement(ctx, repositoryID, action, req.ResourceID) + case "view": + return s.contextOwnersForView(ctx, repositoryID, action, req.ResourceID) + default: + return nil, fmt.Errorf("unsupported resource type %q", req.ResourceType) + } +} + +func (s *Store) contextOwnersForElement(ctx context.Context, repositoryID int64, action string, elementID int64) ([]contextOwner, error) { + mapping, ok, err := s.materializationByResource(ctx, repositoryID, "element", elementID) + if err != nil || !ok { + return nil, err + } + base := contextOwner{OwnerType: mapping.OwnerType, OwnerKey: mapping.OwnerKey} + if action == contextActionShow { + return []contextOwner{base}, nil + } + if action == contextActionClean { + return []contextOwner{base}, nil + } + return s.hideOwnersForElement(ctx, repositoryID, elementID, base) +} + +func (s *Store) contextOwnersForView(ctx context.Context, repositoryID int64, action string, viewID int64) ([]contextOwner, error) { + var owners []contextOwner + if mapping, ok, err := s.materializationByResource(ctx, repositoryID, "view", viewID); err != nil { + return nil, err + } else if ok { + owners = append(owners, contextOwner{OwnerType: mapping.OwnerType, OwnerKey: mapping.OwnerKey}) + } + placementOwners, err := s.materializedElementOwnersInView(ctx, repositoryID, viewID) + if err != nil { + return nil, err + } + owners = append(owners, placementOwners...) 
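+ // Connectors rendered in the view contribute owners in addition to the placed elements.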
+ connectorOwners, err := s.materializedConnectorOwnersInView(ctx, repositoryID, viewID) + if err != nil { + return nil, err + } + owners = append(owners, connectorOwners...) + owners = uniqueContextOwners(owners) + if action == contextActionShow { + return owners, nil + } + if action == contextActionClean { + return owners, nil + } + return s.rankHideOwners(ctx, repositoryID, owners) +} + +func (s *Store) hideOwnersForElement(ctx context.Context, repositoryID, elementID int64, base contextOwner) ([]contextOwner, error) { + owners := []contextOwner{base} + connectorOwners, err := s.materializedConnectorOwnersTouchingElement(ctx, repositoryID, elementID) + if err != nil { + return nil, err + } + owners = append(owners, connectorOwners...) + neighborOwners, err := s.materializedNeighborOwners(ctx, repositoryID, elementID) + if err != nil { + return nil, err + } + owners = append(owners, neighborOwners...) + return s.rankHideOwners(ctx, repositoryID, owners) +} + +func (s *Store) rankHideOwners(ctx context.Context, repositoryID int64, owners []contextOwner) ([]contextOwner, error) { + symbols, err := s.SymbolsForRepository(ctx, repositoryID) + if err != nil { + return nil, err + } + identityKeys, err := s.SymbolIdentityKeys(ctx, repositoryID) + if err != nil { + return nil, err + } + owners = expandContainerOwnersToSymbols(owners, symbols, identityKeys) + keep := map[string]struct{}{} + for _, sym := range symbols { + key := ownerMapKey("symbol", symbolOwnerKey(sym, identityKeys)) + if isExportedSymbol(sym) { + keep[key] = struct{}{} + } + } + var out []contextOwner + for _, owner := range uniqueContextOwners(owners) { + if owner.OwnerType == "file" || owner.OwnerType == "folder" { + continue + } + if _, ok := keep[ownerKey(owner)]; ok { + continue + } + out = append(out, owner) + } + if len(out) == 0 { + return uniqueContextOwners(owners), nil + } + sort.SliceStable(out, func(i, j int) bool { + return hideOwnerPriority(out[i]) < hideOwnerPriority(out[j]) + }) + return out, nil +} + +func expandContainerOwnersToSymbols(owners []contextOwner, symbols []Symbol, identityKeys map[string]string) []contextOwner { + out := append([]contextOwner{}, owners...) 
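+ // Fan file and folder owners out to the symbol owners they contain so the hide ranking below can work at symbol granularity.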
+ for _, owner := range owners { + switch owner.OwnerType { + case "file": + file := strings.TrimPrefix(owner.OwnerKey, "file:") + for _, sym := range symbols { + if sym.FilePath == file { + out = append(out, contextOwner{OwnerType: "symbol", OwnerKey: symbolOwnerKey(sym, identityKeys)}) + } + } + case "folder": + folder := strings.TrimSuffix(strings.TrimPrefix(owner.OwnerKey, "folder:"), "/") + for _, sym := range symbols { + if sym.FilePath == folder || strings.HasPrefix(sym.FilePath, folder+"/") { + out = append(out, contextOwner{OwnerType: "symbol", OwnerKey: symbolOwnerKey(sym, identityKeys)}) + } + } + } + } + return out +} + +func hideOwnerPriority(owner contextOwner) int { + switch owner.OwnerType { + case "reference", "file-reference", "folder-reference": + return 0 + case "symbol": + return 1 + case "file", "folder", "cluster": + return 2 + default: + return 3 + } +} + +func (s *Store) focusedRescanContextOwners(ctx context.Context, repositoryID int64, owners []contextOwner) error { + repo, err := s.Repository(ctx, repositoryID) + if err != nil { + return err + } + files, err := s.contextOwnerFiles(ctx, repositoryID, owners) + if err != nil { + return err + } + if len(files) == 0 { + return nil + } + scanner := NewScanner(s) + _, err = scanner.ScanFilesWithOptions(ctx, repo, files, ScanOptions{Force: true}) + return err +} + +func (s *Store) contextOwnerFiles(ctx context.Context, repositoryID int64, owners []contextOwner) ([]string, error) { + symbols, err := s.SymbolsForRepository(ctx, repositoryID) + if err != nil { + return nil, err + } + identityKeys, err := s.SymbolIdentityKeys(ctx, repositoryID) + if err != nil { + return nil, err + } + files := map[string]struct{}{} + for _, owner := range owners { + switch owner.OwnerType { + case "file": + if file := strings.TrimPrefix(owner.OwnerKey, "file:"); file != "" { + files[file] = struct{}{} + } + case "symbol": + for _, sym := range symbols { + if owner.OwnerKey == sym.StableKey || owner.OwnerKey == symbolOwnerKey(sym, identityKeys) { + files[sym.FilePath] = struct{}{} + } + } + } + } + out := make([]string, 0, len(files)) + for file := range files { + out = append(out, file) + } + sort.Strings(out) + return out, nil +} + +func (s *Store) materializationByResource(ctx context.Context, repositoryID int64, resourceType string, resourceID int64) (watchMaterializationMapping, bool, error) { + var item watchMaterializationMapping + err := s.db.QueryRowContext(ctx, ` + SELECT id, owner_type, owner_key, resource_type, resource_id, updated_at + FROM watch_materialization + WHERE repository_id = ? AND resource_type = ? AND resource_id = ? + ORDER BY id DESC + LIMIT 1`, repositoryID, resourceType, resourceID).Scan(&item.ID, &item.OwnerType, &item.OwnerKey, &item.ResourceType, &item.ResourceID, &item.UpdatedAt) + if errors.Is(err, sql.ErrNoRows) { + return watchMaterializationMapping{}, false, nil + } + return item, err == nil, err +} + +func (s *Store) materializedElementOwnersInView(ctx context.Context, repositoryID, viewID int64) ([]contextOwner, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT wm.owner_type, wm.owner_key + FROM placements p + JOIN watch_materialization wm ON wm.resource_type = 'element' AND wm.resource_id = p.element_id + WHERE wm.repository_id = ? 
AND p.view_id = ?`, repositoryID, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanContextOwners(rows) +} + +func (s *Store) materializedConnectorOwnersInView(ctx context.Context, repositoryID, viewID int64) ([]contextOwner, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT wm.owner_type, wm.owner_key + FROM connectors c + JOIN watch_materialization wm ON wm.resource_type = 'connector' AND wm.resource_id = c.id + WHERE wm.repository_id = ? AND c.view_id = ?`, repositoryID, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanContextOwners(rows) +} + +func (s *Store) materializedConnectorOwnersTouchingElement(ctx context.Context, repositoryID, elementID int64) ([]contextOwner, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT wm.owner_type, wm.owner_key + FROM connectors c + JOIN watch_materialization wm ON wm.resource_type = 'connector' AND wm.resource_id = c.id + WHERE wm.repository_id = ? AND (c.source_element_id = ? OR c.target_element_id = ?)`, repositoryID, elementID, elementID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanContextOwners(rows) +} + +func (s *Store) materializedNeighborOwners(ctx context.Context, repositoryID, elementID int64) ([]contextOwner, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT wm.owner_type, wm.owner_key + FROM connectors c + JOIN watch_materialization wm ON wm.resource_type = 'element' + AND wm.resource_id = CASE WHEN c.source_element_id = ? THEN c.target_element_id ELSE c.source_element_id END + WHERE wm.repository_id = ? AND (c.source_element_id = ? OR c.target_element_id = ?)`, elementID, repositoryID, elementID, elementID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + return scanContextOwners(rows) +} + +func scanContextOwners(rows *sql.Rows) ([]contextOwner, error) { + var owners []contextOwner + for rows.Next() { + var owner contextOwner + if err := rows.Scan(&owner.OwnerType, &owner.OwnerKey); err != nil { + return nil, err + } + owners = append(owners, owner) + } + return uniqueContextOwners(owners), rows.Err() +} + +func (s *Store) saveContextPolicies(ctx context.Context, repositoryID int64, action, scope string, owners []contextOwner) (int, int, int, error) { + now := nowString() + opposite := contextActionShow + if action == contextActionShow { + opposite = contextActionHide + } + created, updated, deactivated := 0, 0, 0 + for _, owner := range uniqueContextOwners(owners) { + res, err := s.db.ExecContext(ctx, ` + UPDATE watch_context_policies + SET active = 0, updated_at = ? + WHERE repository_id = ? AND owner_type = ? AND owner_key = ? AND action = ? AND active = 1`, + now, repositoryID, owner.OwnerType, owner.OwnerKey, opposite) + if err != nil { + return 0, 0, 0, err + } + if rows, _ := res.RowsAffected(); rows > 0 { + deactivated += int(rows) + } + var id int64 + err = s.db.QueryRowContext(ctx, ` + SELECT id FROM watch_context_policies + WHERE repository_id = ? AND owner_type = ? AND owner_key = ? AND action = ? AND active = 1 + ORDER BY id DESC LIMIT 1`, repositoryID, owner.OwnerType, owner.OwnerKey, action).Scan(&id) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return 0, 0, 0, err + } + reason := "user context " + action + if id != 0 { + if _, err := s.db.ExecContext(ctx, `UPDATE watch_context_policies SET scope = ?, reason = ?, updated_at = ? 
WHERE id = ?`, scope, reason, now, id); err != nil { + return 0, 0, 0, err + } + updated++ + continue + } + if _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_context_policies(repository_id, owner_type, owner_key, action, scope, active, reason, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, 1, ?, ?, ?)`, repositoryID, owner.OwnerType, owner.OwnerKey, action, scope, reason, now, now); err != nil { + return 0, 0, 0, err + } + created++ + } + return created, updated, deactivated, nil +} + +func (s *Store) ActiveContextPolicySet(ctx context.Context, repositoryID int64) (contextPolicySet, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT owner_type, owner_key, action + FROM watch_context_policies + WHERE repository_id = ? AND active = 1 + ORDER BY id`, repositoryID) + if err != nil { + return contextPolicySet{}, err + } + defer func() { _ = rows.Close() }() + policies := contextPolicySet{Show: map[string]string{}, Hide: map[string]string{}} + for rows.Next() { + var ownerType, ownerKey, action string + if err := rows.Scan(&ownerType, &ownerKey, &action); err != nil { + return contextPolicySet{}, err + } + key := ownerMapKey(ownerType, ownerKey) + switch action { + case contextActionShow: + policies.Show[key] = "user marked as context" + delete(policies.Hide, key) + case contextActionHide: + if _, shown := policies.Show[key]; !shown { + policies.Hide[key] = "user marked as noise" + } + } + } + return policies, rows.Err() +} + +func uniqueContextOwners(owners []contextOwner) []contextOwner { + seen := map[string]struct{}{} + var out []contextOwner + for _, owner := range owners { + owner.OwnerType = strings.TrimSpace(owner.OwnerType) + owner.OwnerKey = strings.TrimSpace(owner.OwnerKey) + if owner.OwnerType == "" || owner.OwnerKey == "" { + continue + } + key := ownerKey(owner) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, owner) + } + return out +} + +func ownerKey(owner contextOwner) string { + return ownerMapKey(owner.OwnerType, owner.OwnerKey) +} + +func ownerMapKey(ownerType, ownerKey string) string { + return ownerType + "\x00" + ownerKey +} + +func (p contextExpansionSet) ownerTier(ownerType, ownerKey string) int { + if p.Tiers == nil { + return 0 + } + return p.Tiers[ownerMapKey(ownerType, ownerKey)] +} + +func (p contextExpansionSet) symbolTier(sym Symbol, identityKeys map[string]string) int { + tier := maxInt( + p.ownerTier("symbol", symbolOwnerKey(sym, identityKeys)), + p.ownerTier("symbol", sym.StableKey), + p.ownerTier("file", "file:"+sym.FilePath), + ) + dir := path.Dir(sym.FilePath) + for dir != "." && dir != "/" && dir != "" { + tier = maxInt(tier, p.ownerTier("folder", "folder:"+dir)) + next := path.Dir(dir) + if next == dir { + break + } + dir = next + } + return tier +} + +func (p contextExpansionSet) fileTier(filePath string) int { + tier := p.ownerTier("file", "file:"+filePath) + dir := path.Dir(filePath) + for dir != "." 
&& dir != "/" && dir != "" { + tier = maxInt(tier, p.ownerTier("folder", "folder:"+dir)) + next := path.Dir(dir) + if next == dir { + break + } + dir = next + } + return tier +} + +func maxInt(values ...int) int { + out := 0 + for _, value := range values { + if value > out { + out = value + } + } + return out +} + +func referenceOwnerKey(ref Reference, symbols map[int64]Symbol, identityKeys map[string]string) string { + source := symbols[ref.SourceSymbolID] + target := symbols[ref.TargetSymbolID] + return fmt.Sprintf("symbol:%s:%s:%s", symbolOwnerKey(source, identityKeys), symbolOwnerKey(target, identityKeys), ref.Kind) +} diff --git a/internal/watch/embedding.go b/internal/watch/embedding.go new file mode 100644 index 0000000..739ac36 --- /dev/null +++ b/internal/watch/embedding.go @@ -0,0 +1,653 @@ +package watch + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" + "fmt" + "math" + "net/http" + "regexp" + "sort" + "strings" + "time" + "unicode" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" +) + +const ( + DefaultEmbeddingProvider = "openai" + DefaultOpenAIEndpoint = "http://127.0.0.1:8000/v1/embeddings" + DefaultOpenAIModel = "embeddinggemma-300m-4bit" + DefaultOpenAIAPIKey = "tldcli" + DefaultOllamaEndpoint = "http://localhost:11434" + DefaultOllamaModel = "jina/jina-embeddings-v2-base-en" + DefaultLexicalModel = "lexical-code-fingerprint-v1" + DefaultLexicalDimension = 512 + DefaultEmbeddingHealthThreshold = 0.70 + RenameEmbeddingThreshold = 0.78 +) + +type ModelID struct { + Provider string + Model string + Dimension int + ConfigHash string +} + +type EmbeddingInput struct { + OwnerType string + OwnerKey string + Text string +} + +type Vector []float32 + +type Provider interface { + ModelID() ModelID + Embed(ctx context.Context, inputs []EmbeddingInput) ([]Vector, error) +} + +type HealthResult struct { + Dimension int + Similarity float64 +} + +type HealthCheckingProvider interface { + HealthCheck(ctx context.Context) (HealthResult, error) +} + +type NoopProvider struct{} + +func (NoopProvider) ModelID() ModelID { + return ModelID{Provider: "none", Model: "", Dimension: 0, ConfigHash: stableHash(normalizeEmbeddingConfig(EmbeddingConfig{}))} +} + +func (NoopProvider) Embed(context.Context, []EmbeddingInput) ([]Vector, error) { + return []Vector{}, nil +} + +type DeterministicProvider struct { + Model string + Dimension int +} + +func (p DeterministicProvider) ModelID() ModelID { + dimension := p.Dimension + if dimension <= 0 { + dimension = 8 + } + model := p.Model + if strings.TrimSpace(model) == "" { + model = "local-deterministic-test" + } + cfg := EmbeddingConfig{Provider: "local-deterministic-test", Model: model, Dimension: dimension} + return ModelID{Provider: cfg.Provider, Model: cfg.Model, Dimension: cfg.Dimension, ConfigHash: stableHash(cfg)} +} + +func (p DeterministicProvider) Embed(_ context.Context, inputs []EmbeddingInput) ([]Vector, error) { + id := p.ModelID() + out := make([]Vector, 0, len(inputs)) + for _, input := range inputs { + vector := make(Vector, id.Dimension) + seed := []byte(input.OwnerType + "\x00" + input.OwnerKey + "\x00" + input.Text) + for i := range vector { + sum := sha256.Sum256(append(seed, byte(i))) + raw := binary.BigEndian.Uint32(sum[:4]) + vector[i] = float32(float64(raw)/float64(math.MaxUint32)*2 - 1) + } + out = append(out, vector) + } + return out, nil +} + +type LexicalProvider struct { + Model string + Dimension int +} + +func (p LexicalProvider) ModelID() 
ModelID { + dimension := p.Dimension + if dimension <= 0 { + dimension = DefaultLexicalDimension + } + model := p.Model + if strings.TrimSpace(model) == "" { + model = DefaultLexicalModel + } + cfg := EmbeddingConfig{Provider: "local-lexical", Model: model, Dimension: dimension} + return ModelID{Provider: cfg.Provider, Model: cfg.Model, Dimension: cfg.Dimension, ConfigHash: stableHash(cfg)} +} + +func (p LexicalProvider) Embed(_ context.Context, inputs []EmbeddingInput) ([]Vector, error) { + id := p.ModelID() + out := make([]Vector, 0, len(inputs)) + for _, input := range inputs { + out = append(out, lexicalVector(input.Text, id.Dimension)) + } + return out, nil +} + +var lexicalIdentifierRE = regexp.MustCompile(`[A-Za-z_][A-Za-z0-9_]*|\d+(?:\.\d+)?|"[^"\n]*"|'[^'\n]*'|` + "`[^`\n]*`" + `|[{}()[\].,;:+\-*/%=&|!<>^~?]`) + +var lexicalKeywords = map[string]struct{}{ + "break": {}, "case": {}, "catch": {}, "class": {}, "const": {}, "continue": {}, "def": {}, "defer": {}, "do": {}, "else": {}, "enum": {}, "except": {}, "finally": {}, "for": {}, "func": {}, "function": {}, "go": {}, "if": {}, "import": {}, "interface": {}, "lambda": {}, "match": {}, "method": {}, "package": {}, "private": {}, "protected": {}, "public": {}, "raise": {}, "return": {}, "select": {}, "static": {}, "struct": {}, "switch": {}, "throw": {}, "try": {}, "type": {}, "var": {}, "while": {}, "yield": {}, +} + +func lexicalVector(text string, dimension int) Vector { + if dimension <= 0 { + dimension = DefaultLexicalDimension + } + vector := make(Vector, dimension) + tokens := lexicalTokens(text) + for i, token := range tokens { + lowerToken := strings.ToLower(token) + addFeature(vector, "tok:"+lowerToken, 1.0) + if _, ok := lexicalKeywords[lowerToken]; ok { + addFeature(vector, "kw:"+lowerToken, 1.4) + } + for _, part := range splitIdentifierToken(token) { + addFeature(vector, "id:"+part, 1.2) + } + for n := 3; n <= 5; n++ { + for _, gram := range charNGrams(lowerToken, n) { + addFeature(vector, fmt.Sprintf("c%d:%s", n, gram), 0.25) + } + } + if i+1 < len(tokens) { + addFeature(vector, "bi:"+lowerToken+"\x00"+strings.ToLower(tokens[i+1]), 0.8) + } + if i+2 < len(tokens) { + addFeature(vector, "tri:"+lowerToken+"\x00"+strings.ToLower(tokens[i+1])+"\x00"+strings.ToLower(tokens[i+2]), 0.45) + } + } + for _, token := range structuralTokens(text) { + addFeature(vector, "ast:"+token, 1.0) + } + normalizeVector(vector) + return vector +} + +func lexicalTokens(text string) []string { + matches := lexicalIdentifierRE.FindAllString(text, -1) + tokens := make([]string, 0, len(matches)) + for _, match := range matches { + token := normalizeLexicalToken(match) + if token != "" { + tokens = append(tokens, token) + } + } + return tokens +} + +func normalizeLexicalToken(token string) string { + token = strings.TrimSpace(token) + if token == "" { + return "" + } + switch { + case strings.HasPrefix(token, "\"") || strings.HasPrefix(token, "'") || strings.HasPrefix(token, "`"): + return "string_lit" + case unicode.IsDigit([]rune(token)[0]): + return "number_lit" + } + return token +} + +func splitIdentifierToken(token string) []string { + if token == "string_lit" || token == "number_lit" || token == "" { + return nil + } + var parts []string + var current []rune + flush := func() { + if len(current) > 0 { + parts = append(parts, strings.ToLower(string(current))) + current = nil + } + } + for i, r := range token { + if r == '_' || r == '-' || r == '.' 
{ + flush() + continue + } + if i > 0 && unicode.IsUpper(r) { + flush() + } + if unicode.IsLetter(r) || unicode.IsDigit(r) { + current = append(current, unicode.ToLower(r)) + } + } + flush() + sort.Strings(parts) + return compactStrings(parts) +} + +func charNGrams(token string, n int) []string { + runes := []rune(token) + if len(runes) < n { + return nil + } + out := make([]string, 0, len(runes)-n+1) + for i := 0; i+n <= len(runes); i++ { + out = append(out, string(runes[i:i+n])) + } + return out +} + +func structuralTokens(text string) []string { + tokens := []string{} + for line := range strings.SplitSeq(text, "\n") { + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + for _, marker := range []string{"if", "for", "while", "switch", "match", "try", "catch", "except", "return", "yield", "throw", "raise", "defer", "go"} { + if strings.Contains(trimmed, marker) { + tokens = append(tokens, marker) + } + } + tokens = append(tokens, fmt.Sprintf("indent:%d", leadingIndentWidth(line)/4)) + } + return tokens +} + +func addFeature(vector Vector, feature string, weight float32) { + sum := sha256.Sum256([]byte(feature)) + index := int(binary.LittleEndian.Uint32(sum[:4]) % uint32(len(vector))) + sign := float32(1) + if sum[4]&1 == 1 { + sign = -1 + } + vector[index] += sign * weight +} + +func normalizeVector(vector Vector) { + var norm float64 + for _, value := range vector { + norm += float64(value * value) + } + if norm == 0 { + return + } + scale := float32(1 / math.Sqrt(norm)) + for i := range vector { + vector[i] *= scale + } +} + +func compactStrings(values []string) []string { + if len(values) == 0 { + return values + } + out := values[:0] + last := "" + for _, value := range values { + if value == "" || value == last { + continue + } + out = append(out, value) + last = value + } + return out +} + +type OllamaProvider struct { + Endpoint string + Model string + Dimension int + HealthThreshold float64 + Client *http.Client +} + +func (p *OllamaProvider) ModelID() ModelID { + cfg := normalizeEmbeddingConfig(EmbeddingConfig{ + Provider: "ollama", + Endpoint: p.Endpoint, + Model: p.Model, + Dimension: p.Dimension, + HealthThreshold: p.HealthThreshold, + }) + return ModelID{Provider: cfg.Provider, Model: cfg.Model, Dimension: cfg.Dimension, ConfigHash: stableHash(cfg)} +} + +func (p *OllamaProvider) Embed(ctx context.Context, inputs []EmbeddingInput) ([]Vector, error) { + if len(inputs) == 0 { + return []Vector{}, nil + } + texts := make([]string, 0, len(inputs)) + for _, input := range inputs { + texts = append(texts, input.Text) + } + vectors, err := p.embedTexts(ctx, texts) + if err != nil { + return nil, err + } + if len(vectors) != len(inputs) { + return nil, fmt.Errorf("ollama returned %d embeddings for %d inputs", len(vectors), len(inputs)) + } + if len(vectors) > 0 && p.Dimension <= 0 { + p.Dimension = len(vectors[0]) + } + return vectors, nil +} + +func (p *OllamaProvider) HealthCheck(ctx context.Context) (HealthResult, error) { + texts := []string{ + "Why is the sky blue?", + "What causes the sky to look blue during the day?", + } + vectors, err := p.embedTexts(ctx, texts) + if err != nil { + return HealthResult{}, err + } + if len(vectors) != 2 || len(vectors[0]) == 0 || len(vectors[1]) == 0 { + return HealthResult{}, fmt.Errorf("ollama healthcheck returned empty embeddings") + } + if len(vectors[0]) != len(vectors[1]) { + return HealthResult{}, fmt.Errorf("ollama healthcheck returned mismatched dimensions %d and %d", len(vectors[0]), len(vectors[1])) + } + sim := 
CosineSimilarity(vectors[0], vectors[1]) + threshold := p.HealthThreshold + if threshold <= 0 { + threshold = DefaultEmbeddingHealthThreshold + } + if sim < threshold { + return HealthResult{}, fmt.Errorf("ollama healthcheck similarity %.3f is below threshold %.3f", sim, threshold) + } + p.Dimension = len(vectors[0]) + return HealthResult{Dimension: len(vectors[0]), Similarity: sim}, nil +} + +func (p *OllamaProvider) embedTexts(ctx context.Context, texts []string) ([]Vector, error) { + endpoint := strings.TrimRight(p.Endpoint, "/") + if endpoint == "" { + endpoint = DefaultOllamaEndpoint + } + client := p.Client + if client == nil { + client = &http.Client{Timeout: 30 * time.Second} + } + body, _ := json.Marshal(map[string]any{ + "model": p.Model, + "input": texts, + }) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+"/api/embed", bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("ollama embed request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("ollama embed request failed: %s", resp.Status) + } + var parsed struct { + Embedding []float32 `json:"embedding"` + Embeddings [][]float32 `json:"embeddings"` + } + if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil { + return nil, fmt.Errorf("decode ollama embed response: %w", err) + } + if len(parsed.Embeddings) > 0 { + return vectorsFromFloatSlices(parsed.Embeddings), nil + } + if len(parsed.Embedding) > 0 { + return []Vector{Vector(parsed.Embedding)}, nil + } + return nil, fmt.Errorf("ollama embed response did not include embeddings") +} + +type OpenAIProvider struct { + Endpoint string + Model string + Dimension int + HealthThreshold float64 + Client *http.Client +} + +func (p *OpenAIProvider) ModelID() ModelID { + cfg := normalizeEmbeddingConfig(EmbeddingConfig{ + Provider: "openai", + Endpoint: p.Endpoint, + Model: p.Model, + Dimension: p.Dimension, + HealthThreshold: p.HealthThreshold, + }) + return ModelID{Provider: cfg.Provider, Model: cfg.Model, Dimension: cfg.Dimension, ConfigHash: stableHash(cfg)} +} + +func (p *OpenAIProvider) Embed(ctx context.Context, inputs []EmbeddingInput) ([]Vector, error) { + if len(inputs) == 0 { + return []Vector{}, nil + } + texts := make([]string, 0, len(inputs)) + for _, input := range inputs { + texts = append(texts, input.Text) + } + vectors, err := p.embedTexts(ctx, texts) + if err != nil { + return nil, err + } + if len(vectors) != len(inputs) { + return nil, fmt.Errorf("openai returned %d embeddings for %d inputs", len(vectors), len(inputs)) + } + if len(vectors) > 0 && p.Dimension <= 0 { + p.Dimension = len(vectors[0]) + } + return vectors, nil +} + +func (p *OpenAIProvider) HealthCheck(ctx context.Context) (HealthResult, error) { + texts := []string{ + "Why is the sky blue?", + "What causes the sky to look blue during the day?", + } + vectors, err := p.embedTexts(ctx, texts) + if err != nil { + return HealthResult{}, err + } + if len(vectors) != 2 || len(vectors[0]) == 0 || len(vectors[1]) == 0 { + return HealthResult{}, fmt.Errorf("openai healthcheck returned empty embeddings") + } + if len(vectors[0]) != len(vectors[1]) { + return HealthResult{}, fmt.Errorf("openai healthcheck returned mismatched dimensions %d and %d", len(vectors[0]), len(vectors[1])) + } + sim := CosineSimilarity(vectors[0], vectors[1]) + threshold := 
p.HealthThreshold + if threshold <= 0 { + threshold = DefaultEmbeddingHealthThreshold + } + if sim < threshold { + return HealthResult{}, fmt.Errorf("openai healthcheck similarity %.3f is below threshold %.3f", sim, threshold) + } + p.Dimension = len(vectors[0]) + return HealthResult{Dimension: len(vectors[0]), Similarity: sim}, nil +} + +func (p *OpenAIProvider) embedTexts(ctx context.Context, texts []string) ([]Vector, error) { + opts := []option.RequestOption{ + option.WithBaseURL(openAIBaseURL(p.Endpoint)), + option.WithAPIKey(DefaultOpenAIAPIKey), + option.WithRequestTimeout(30 * time.Second), + } + if p.Client != nil { + opts = append(opts, option.WithHTTPClient(p.Client)) + } + client := openai.NewClient(opts...) + resp, err := client.Embeddings.New(ctx, openai.EmbeddingNewParams{ + Model: openai.EmbeddingModel(p.Model), + Input: openai.EmbeddingNewParamsInputUnion{OfArrayOfStrings: texts}, + }) + if err != nil { + return nil, fmt.Errorf("openai embeddings request: %w", err) + } + vectors := make([]Vector, 0, len(resp.Data)) + for _, item := range resp.Data { + vector := make(Vector, len(item.Embedding)) + for i, value := range item.Embedding { + vector[i] = float32(value) + } + vectors = append(vectors, vector) + } + return vectors, nil +} + +func openAIBaseURL(endpoint string) string { + endpoint = strings.TrimRight(strings.TrimSpace(endpoint), "/") + if endpoint == "" { + endpoint = DefaultOpenAIEndpoint + } + if before, ok := strings.CutSuffix(endpoint, "/embeddings"); ok { + return before + } + return endpoint +} + +func NewEmbeddingProvider(cfg EmbeddingConfig) (Provider, error) { + cfg = normalizeEmbeddingConfig(cfg) + switch cfg.Provider { + case "none": + return NoopProvider{}, nil + case "openai": + return &OpenAIProvider{Endpoint: cfg.Endpoint, Model: cfg.Model, Dimension: cfg.Dimension, HealthThreshold: cfg.HealthThreshold}, nil + case "ollama": + return &OllamaProvider{Endpoint: cfg.Endpoint, Model: cfg.Model, Dimension: cfg.Dimension, HealthThreshold: cfg.HealthThreshold}, nil + case "local-lexical": + return LexicalProvider{Model: cfg.Model, Dimension: cfg.Dimension}, nil + case "local-deterministic-test": + return DeterministicProvider{Model: cfg.Model, Dimension: cfg.Dimension}, nil + default: + return nil, fmt.Errorf("unsupported embedding provider %q", cfg.Provider) + } +} + +func normalizeEmbeddingConfig(cfg EmbeddingConfig) EmbeddingConfig { + cfg.Provider = strings.TrimSpace(cfg.Provider) + cfg.Endpoint = strings.TrimRight(strings.TrimSpace(cfg.Endpoint), "/") + cfg.Model = strings.TrimSpace(cfg.Model) + if cfg.Provider == "" { + cfg.Provider = DefaultEmbeddingProvider + } + if cfg.Provider == "none" { + cfg.Endpoint = "" + cfg.Model = "" + cfg.Dimension = 0 + cfg.HealthThreshold = 0 + } + if cfg.Provider == "openai" { + if cfg.Endpoint == "" { + cfg.Endpoint = DefaultOpenAIEndpoint + } + if cfg.Model == "" { + cfg.Model = DefaultOpenAIModel + } + if cfg.HealthThreshold <= 0 { + cfg.HealthThreshold = DefaultEmbeddingHealthThreshold + } + } + if cfg.Provider == "ollama" { + if cfg.Endpoint == "" { + cfg.Endpoint = DefaultOllamaEndpoint + } + if cfg.Model == "" { + cfg.Model = DefaultOllamaModel + } + if cfg.HealthThreshold <= 0 { + cfg.HealthThreshold = DefaultEmbeddingHealthThreshold + } + } + if cfg.Provider == "local-lexical" { + cfg.Endpoint = "" + if cfg.Model == "" { + cfg.Model = DefaultLexicalModel + } + if cfg.Dimension <= 0 { + cfg.Dimension = DefaultLexicalDimension + } + cfg.HealthThreshold = 0 + } + if cfg.Provider == "local-deterministic-test" && 
cfg.Dimension <= 0 { + cfg.Dimension = 8 + } + if cfg.TimeoutSeconds <= 0 { + cfg.TimeoutSeconds = 60 + } + return cfg +} + +func NormalizeEmbeddingConfig(cfg EmbeddingConfig) EmbeddingConfig { + return normalizeEmbeddingConfig(cfg) +} + +func CheckEmbeddingHealth(ctx context.Context, cfg EmbeddingConfig) (EmbeddingConfig, HealthResult, error) { + cfg = normalizeEmbeddingConfig(cfg) + provider, err := NewEmbeddingProvider(cfg) + if err != nil { + return cfg, HealthResult{}, err + } + checker, ok := provider.(HealthCheckingProvider) + if !ok { + return cfg, HealthResult{Dimension: provider.ModelID().Dimension, Similarity: 1}, nil + } + result, err := checker.HealthCheck(ctx) + if err != nil { + return cfg, HealthResult{}, err + } + if result.Dimension > 0 { + cfg.Dimension = result.Dimension + } + return cfg, result, nil +} + +func vectorBytes(vector Vector) []byte { + out := make([]byte, len(vector)*4) + for i, value := range vector { + binary.LittleEndian.PutUint32(out[i*4:(i+1)*4], math.Float32bits(value)) + } + return out +} + +func inputHash(input EmbeddingInput) string { + return hashString(input.Text) +} + +func stableHash(value any) string { + data, _ := json.Marshal(value) + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +func vectorsFromFloatSlices(values [][]float32) []Vector { + out := make([]Vector, 0, len(values)) + for _, value := range values { + out = append(out, Vector(value)) + } + return out +} + +func CosineSimilarity(left, right Vector) float64 { + if len(left) == 0 || len(left) != len(right) { + return 0 + } + var dot, leftNorm, rightNorm float64 + for i := range left { + l := float64(left[i]) + r := float64(right[i]) + dot += l * r + leftNorm += l * l + rightNorm += r * r + } + if leftNorm == 0 || rightNorm == 0 { + return 0 + } + return dot / (math.Sqrt(leftNorm) * math.Sqrt(rightNorm)) +} diff --git a/internal/watch/enrich/defaults/catalog.go b/internal/watch/enrich/defaults/catalog.go new file mode 100644 index 0000000..1a1c9cd --- /dev/null +++ b/internal/watch/enrich/defaults/catalog.go @@ -0,0 +1,288 @@ +package defaults + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/ai" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/apispec" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/auth" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/cloud" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/compose" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/config" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/dataeng" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/datastore" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/deployment" + frontendts "github.com/mertcikla/tld/internal/watch/enrich/enrichers/frontend/typescript" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/httpclient" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/iac" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/inventory" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/iot" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/ipc" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/jobs" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/messaging" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/observability" + ormcpp "github.com/mertcikla/tld/internal/watch/enrich/enrichers/orm/cpp" + ormgo 
"github.com/mertcikla/tld/internal/watch/enrich/enrichers/orm/golang" + ormjava "github.com/mertcikla/tld/internal/watch/enrich/enrichers/orm/java" + ormpython "github.com/mertcikla/tld/internal/watch/enrich/enrichers/orm/python" + ormrust "github.com/mertcikla/tld/internal/watch/enrich/enrichers/orm/rust" + ormts "github.com/mertcikla/tld/internal/watch/enrich/enrichers/orm/typescript" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/osintegration" + cpproutes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/cpp" + goroutes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/golang" + javaroutes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/java" + pythonroutes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/python" + rustroutes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/rust" + tstypes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/typescript" + rpcclients "github.com/mertcikla/tld/internal/watch/enrich/enrichers/rpc/clients" + rpcgrpc "github.com/mertcikla/tld/internal/watch/enrich/enrichers/rpc/grpc" + runtimeenrich "github.com/mertcikla/tld/internal/watch/enrich/enrichers/runtime" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/secrets" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/storage" + pythontraffic "github.com/mertcikla/tld/internal/watch/enrich/enrichers/traffic/python" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/web3" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/workspace" +) + +// NewRegistry returns the complete built-in enricher registry. +func NewRegistry() *enrich.Registry { + return enrich.NewRegistry(DefaultEnrichers()...) +} + +// DefaultEnrichers returns the complete built-in catalog. +// +// Keep this function as composition only. Add new enrichers to the narrowest +// domain/language package so the default registry does not become an unreadable +// list of framework constructors. 
+func DefaultEnrichers() []enrich.Enricher { + return appendGroups( + InventoryEnrichers(), + ConfigEnrichers(), + HTTPClientEnrichers(), + RouteEnrichers(), + FrontendEnrichers(), + ORMEnrichers(), + RPCEnrichers(), + RuntimeEnrichers(), + IaCEnrichers(), + ComposeEnrichers(), + CloudEnrichers(), + MessagingEnrichers(), + StorageEnrichers(), + DatastoreEnrichers(), + TrafficEnrichers(), + ObservabilityEnrichers(), + AuthEnrichers(), + JobEnrichers(), + APISpecEnrichers(), + DeploymentEnrichers(), + SecretEnrichers(), + WorkspaceEnrichers(), + AIEnrichers(), + IoTEnrichers(), + IPCEnrichers(), + DataEnrichers(), + Web3Enrichers(), + OSIntegrationEnrichers(), + ) +} + +func InventoryEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + inventory.DependencyInventory(), + } +} + +func ConfigEnrichers() []enrich.Enricher { return config.All() } +func HTTPClientEnrichers() []enrich.Enricher { return httpclient.All() } + +func RouteEnrichers() []enrich.Enricher { + return appendGroups( + GoRouteEnrichers(), + TypeScriptRouteEnrichers(), + PythonRouteEnrichers(), + JavaRouteEnrichers(), + RustRouteEnrichers(), + CPPRouteEnrichers(), + ) +} + +func GoRouteEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + goroutes.GoNetHTTP(), + goroutes.GoChi(), + goroutes.GoGin(), + goroutes.GoGorillaMux(), + goroutes.GoEcho(), + goroutes.GoFiber(), + } +} + +func TypeScriptRouteEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + tstypes.Express(), + tstypes.Fastify(), + tstypes.NestJS(), + tstypes.Hono(), + } +} + +func PythonRouteEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + pythonroutes.PythonFlask(), + pythonroutes.PythonFastAPI(), + pythonroutes.PythonDjango(), + pythonroutes.PythonStarlette(), + } +} + +func JavaRouteEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + javaroutes.Spring(), + javaroutes.JAXRS(), + javaroutes.Micronaut(), + javaroutes.Quarkus(), + } +} + +func RustRouteEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rustroutes.Axum(), + rustroutes.ActixWeb(), + rustroutes.Rocket(), + rustroutes.Warp(), + } +} + +func CPPRouteEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + cpproutes.Drogon(), + cpproutes.Oatpp(), + cpproutes.Pistache(), + cpproutes.Crow(), + cpproutes.CppRestSDK(), + } +} + +func FrontendEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + frontendts.NextJS(), + frontendts.ReactRouter(), + } +} + +func ORMEnrichers() []enrich.Enricher { + return appendGroups( + ormts.All(), + ormgo.All(), + ormpython.All(), + ormjava.All(), + ormrust.All(), + ormcpp.All(), + ) +} + +func RPCEnrichers() []enrich.Enricher { + return appendGroups( + ContractEnrichers(), + GRPCEnrichers(), + RPCClientEnrichers(), + ) +} + +func RPCClientEnrichers() []enrich.Enricher { return rpcclients.All() } + +func ContractEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rpcgrpc.ProtobufContracts(), + } +} + +func GRPCEnrichers() []enrich.Enricher { + return appendGroups( + GoGRPCEnrichers(), + PythonGRPCEnrichers(), + NodeGRPCEnrichers(), + JavaGRPCEnrichers(), + DotNetGRPCEnrichers(), + ) +} + +func GoGRPCEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rpcgrpc.GoGRPC(), + } +} + +func PythonGRPCEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rpcgrpc.PythonGRPC(), + } +} + +func NodeGRPCEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rpcgrpc.NodeGRPC(), + } +} + +func JavaGRPCEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rpcgrpc.JavaGRPC(), + } +} + 
+func DotNetGRPCEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + rpcgrpc.DotNetGRPC(), + } +} + +func RuntimeEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + runtimeenrich.RuntimeManifests(), + } +} + +func IaCEnrichers() []enrich.Enricher { return iac.All() } +func ComposeEnrichers() []enrich.Enricher { return compose.All() } +func CloudEnrichers() []enrich.Enricher { return cloud.All() } +func MessagingEnrichers() []enrich.Enricher { return messaging.All() } +func StorageEnrichers() []enrich.Enricher { return storage.All() } + +func DatastoreEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + datastore.DatastoreGlue(), + } +} + +func TrafficEnrichers() []enrich.Enricher { + return []enrich.Enricher{ + pythontraffic.PythonLocust(), + } +} + +func ObservabilityEnrichers() []enrich.Enricher { return observability.All() } +func AuthEnrichers() []enrich.Enricher { return auth.All() } +func JobEnrichers() []enrich.Enricher { return jobs.All() } +func APISpecEnrichers() []enrich.Enricher { return apispec.All() } +func DeploymentEnrichers() []enrich.Enricher { return deployment.All() } +func SecretEnrichers() []enrich.Enricher { return secrets.All() } +func WorkspaceEnrichers() []enrich.Enricher { return workspace.All() } +func AIEnrichers() []enrich.Enricher { return ai.All() } +func IoTEnrichers() []enrich.Enricher { return iot.All() } +func IPCEnrichers() []enrich.Enricher { return ipc.All() } +func DataEnrichers() []enrich.Enricher { return dataeng.All() } +func Web3Enrichers() []enrich.Enricher { return web3.All() } +func OSIntegrationEnrichers() []enrich.Enricher { return osintegration.All() } + +func appendGroups(groups ...[]enrich.Enricher) []enrich.Enricher { + total := 0 + for _, group := range groups { + total += len(group) + } + out := make([]enrich.Enricher, 0, total) + for _, group := range groups { + out = append(out, group...) 
+ } + return out +} diff --git a/internal/watch/enrich/enrichers/ai/ai.go b/internal/watch/enrich/enrichers/ai/ai.go new file mode 100644 index 0000000..cdd4860 --- /dev/null +++ b/internal/watch/enrich/enrichers/ai/ai.go @@ -0,0 +1,50 @@ +package ai + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.pinecone", "TypeScript Pinecone", "typescript", "@pinecone-database/pinecone", "Pinecone", "ai.vector_index", "queries_index"), + spec("python.pinecone", "Python Pinecone", "python", "pinecone-client", "Pinecone", "ai.vector_index", "queries_index"), + spec("ts.milvus", "TypeScript Milvus", "typescript", "@zilliz/milvus2-sdk-node", "MilvusClient", "ai.vector_index", "queries_index"), + spec("python.milvus", "Python Milvus", "python", "pymilvus", "MilvusClient", "ai.vector_index", "queries_index"), + spec("ts.qdrant", "TypeScript Qdrant", "typescript", "@qdrant/js-client-rest", "QdrantClient", "ai.vector_index", "queries_index"), + spec("python.qdrant", "Python Qdrant", "python", "qdrant-client", "QdrantClient", "ai.vector_index", "queries_index"), + spec("ts.chroma", "TypeScript Chroma", "typescript", "chromadb", "ChromaClient", "ai.vector_index", "queries_index"), + spec("python.chroma", "Python Chroma", "python", "chromadb", "chromadb", "ai.vector_index", "queries_index"), + spec("ts.weaviate", "TypeScript Weaviate", "typescript", "weaviate-ts-client", "weaviate.client", "ai.vector_index", "queries_index"), + spec("python.weaviate", "Python Weaviate", "python", "weaviate-client", "weaviate.Client", "ai.vector_index", "queries_index"), + spec("python.huggingface", "Python Hugging Face", "python", "huggingface_hub", "huggingface_hub", "ai.model_id", "loads_model"), + spec("python.mlflow", "Python MLflow", "python", "mlflow", "mlflow.start_run", "ai.experiment_tracker", "tracks_metrics_to"), + spec("python.wandb", "Python Weights & Biases", "python", "wandb", "wandb.init", "ai.experiment_tracker", "tracks_metrics_to"), + spec("ts.openai", "TypeScript OpenAI SDK", "typescript", "openai", "new OpenAI", "ai.llm_endpoint", "calls_llm"), + spec("python.openai", "Python OpenAI SDK", "python", "openai", "OpenAI(", "ai.llm_endpoint", "calls_llm"), + spec("ts.anthropic", "TypeScript Anthropic SDK", "typescript", "@anthropic-ai/sdk", "Anthropic", "ai.llm_endpoint", "calls_llm"), + spec("python.anthropic", "Python Anthropic SDK", "python", "anthropic", "Anthropic(", "ai.llm_endpoint", "calls_llm"), + spec("ts.langchain", "TypeScript LangChain", "typescript", "langchain", "langchain", "ai.llm_endpoint", "calls_llm"), + spec("python.langchain", "Python LangChain", "python", "langchain", "langchain", "ai.llm_endpoint", "calls_llm"), + spec("ts.llamaindex", "TypeScript LlamaIndex", "typescript", "llamaindex", "llamaindex", "ai.llm_endpoint", "calls_llm"), + spec("python.llamaindex", "Python LlamaIndex", "python", "llama-index", "llama_index", "ai.llm_endpoint", "calls_llm"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "ai", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + 
Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"ai:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/ai/ai_test.go b/internal/watch/enrich/enrichers/ai/ai_test.go new file mode 100644 index 0000000..54f4e23 --- /dev/null +++ b/internal/watch/enrich/enrichers/ai/ai_test.go @@ -0,0 +1,33 @@ +package ai + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestAIEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/ai", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:ai", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/apispec/apispec.go b/internal/watch/enrich/enrichers/apispec/apispec.go new file mode 100644 index 0000000..f59515e --- /dev/null +++ b/internal/watch/enrich/enrichers/apispec/apispec.go @@ -0,0 +1,35 @@ +package apispec + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("apispec.openapi", "OpenAPI / Swagger", "yaml", []string{"openapi.yaml", "swagger.yaml"}, []string{"openapi: 3", "openapi: \"3\""}, "api.spec", "documents"), + spec("apispec.asyncapi", "AsyncAPI", "yaml", []string{"asyncapi.yaml"}, []string{"asyncapi: 2", "asyncapi: \"2\""}, "api.spec", "documents"), + spec("apispec.graphql_schema", "GraphQL schema", "graphql", []string{".graphql", ".gql"}, []string{"type Query"}, "api.schema", "declares"), + spec("apispec.protobuf", "Protocol Buffers", "protobuf", []string{".proto"}, []string{"service "}, "rpc.service", "exposes"), + spec("apispec.avro", "Avro", "json", []string{".avsc", ".avdl"}, []string{"\"type\":\"record\""}, "api.schema", "declares"), + spec("apispec.json_schema", "JSON Schema", "json", []string{"schema.json"}, []string{"\"$schema\""}, "api.schema", "declares"), + } +} + +func spec(id, name, language string, pathTokens, sourceTokens []string, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "api-spec", + Languages: []string{language}, + Mode: enrich.ActivationAlways, + FactType: factType, + Relationship: relationship, + SourceTokens: sourceTokens, + PathTokens: pathTokens, + Tags: []string{"api-spec:" + id}, + Attributes: map[string]string{"format": id}, + } +} diff --git a/internal/watch/enrich/enrichers/apispec/apispec_test.go b/internal/watch/enrich/enrichers/apispec/apispec_test.go new file mode 100644 index 0000000..75e04cd --- /dev/null +++ b/internal/watch/enrich/enrichers/apispec/apispec_test.go @@ -0,0 +1,32 @@ +package apispec + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + 
"github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestAPISpecEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "contracts/" + spec.PathTokens[0], + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:api-spec", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/auth/auth.go b/internal/watch/enrich/enrichers/auth/auth.go new file mode 100644 index 0000000..5d3fa20 --- /dev/null +++ b/internal/watch/enrich/enrichers/auth/auth.go @@ -0,0 +1,50 @@ +package auth + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.auth0", "TypeScript Auth0", "typescript", "@auth0/auth0-react", "@auth0/", "auth.provider", "uses_identity_provider"), + spec("ts.cognito", "TypeScript Cognito", "typescript", "@aws-sdk/client-cognito-identity-provider", "CognitoIdentityProviderClient", "auth.provider", "uses_identity_provider"), + spec("ts.firebase_auth", "TypeScript Firebase Auth", "typescript", "firebase", "firebase/auth", "auth.provider", "uses_identity_provider"), + spec("ts.clerk", "TypeScript Clerk", "typescript", "@clerk/nextjs", "@clerk/", "auth.provider", "uses_identity_provider"), + spec("ts.nextauth", "TypeScript NextAuth", "typescript", "next-auth", "next-auth", "auth.provider", "uses_identity_provider"), + spec("go.jwt", "Go JWT middleware", "go", "github.com/golang-jwt/jwt", "jwt.Parse", "auth.issuer", "trusts_issuer"), + spec("go.oidc", "Go OIDC clients", "go", "github.com/coreos/go-oidc", "oidc.NewProvider", "auth.issuer", "trusts_issuer"), + spec("go.auth0_cognito", "Go Auth0/Cognito SDK", "go", "github.com/auth0/go-auth0", "auth0", "auth.provider", "uses_identity_provider"), + spec("python.pyjwt", "Python PyJWT", "python", "PyJWT", "jwt.decode", "auth.issuer", "trusts_issuer"), + spec("python.authlib", "Python Authlib", "python", "Authlib", "authlib", "auth.provider", "uses_identity_provider"), + spec("python.django_auth", "Python Django auth", "python", "Django", "django.contrib.auth", "auth.provider", "authenticates_with"), + spec("python.fastapi_security", "Python FastAPI security", "python", "fastapi", "fastapi.security", "auth.provider", "authenticates_with"), + spec("java.spring_security", "Java Spring Security OAuth/OIDC", "java", "spring-security-oauth2-client", "@EnableWebSecurity", "auth.provider", "uses_identity_provider"), + spec("java.keycloak", "Java Keycloak", "java", "keycloak", "Keycloak", "auth.provider", "uses_identity_provider"), + spec("java.cognito", "Java Cognito", "java", "software.amazon.awssdk.services.cognitoidentityprovider", "CognitoIdentityProviderClient", "auth.provider", "uses_identity_provider"), + spec("java.oidc", "Java OIDC", "java", "spring-security-oauth2-jose", "issuer-uri", "auth.issuer", "trusts_issuer"), + spec("rust.jsonwebtoken", "Rust jsonwebtoken", "rust", "jsonwebtoken", "jsonwebtoken::decode", "auth.issuer", 
"trusts_issuer"), + spec("rust.oauth2", "Rust oauth2", "rust", "oauth2", "oauth2::", "auth.provider", "uses_identity_provider"), + spec("rust.oidc", "Rust OIDC crates", "rust", "openidconnect", "openidconnect::", "auth.issuer", "trusts_issuer"), + spec("cpp.jwt", "C++ JWT validation", "cpp", "jwt-cpp", "jwt::decode", "auth.issuer", "trusts_issuer"), + spec("cpp.oidc_jwks", "C++ OIDC/JWKS config", "cpp", "oidc", "jwks_uri", "auth.jwks_endpoint", "trusts_issuer"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "auth", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"auth:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/auth/auth_test.go b/internal/watch/enrich/enrichers/auth/auth_test.go new file mode 100644 index 0000000..f3b84b0 --- /dev/null +++ b/internal/watch/enrich/enrichers/auth/auth_test.go @@ -0,0 +1,33 @@ +package auth + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestAuthEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/auth", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:auth", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/cloud/cloud.go b/internal/watch/enrich/enrichers/cloud/cloud.go new file mode 100644 index 0000000..5700187 --- /dev/null +++ b/internal/watch/enrich/enrichers/cloud/cloud.go @@ -0,0 +1,47 @@ +package cloud + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.aws_sdk_v3", "TypeScript AWS SDK v3", "typescript", "@aws-sdk/client-s3", "S3Client", "cloud.bucket", "reads_from"), + spec("ts.google_cloud", "TypeScript Google Cloud clients", "typescript", "@google-cloud/storage", "Storage(", "cloud.bucket", "reads_from"), + spec("ts.azure_sdk", "TypeScript Azure SDK", "typescript", "@azure/storage-blob", "BlobServiceClient", "cloud.resource", "reads_from"), + spec("go.aws_sdk_v2", "Go AWS SDK v2", "go", "github.com/aws/aws-sdk-go-v2", "aws.Config", "cloud.resource", "reads_from"), + spec("go.google_cloud", "Google Cloud Go", "go", "cloud.google.com/go", "cloud.google.com/go", "cloud.resource", "reads_from"), + spec("go.azure_sdk", "Azure SDK for Go", "go", "github.com/Azure/azure-sdk-for-go", "azidentity", "cloud.resource", "reads_from"), + spec("python.boto3", 
"Python boto3", "python", "boto3", "boto3.client", "cloud.resource", "reads_from"), + spec("python.google_cloud", "Python google-cloud clients", "python", "google-cloud", "google.cloud", "cloud.resource", "reads_from"), + spec("python.azure_sdk", "Python Azure SDK", "python", "azure-", "azure.storage", "cloud.resource", "reads_from"), + spec("java.aws_sdk_v2", "Java AWS SDK v2", "java", "software.amazon.awssdk", "software.amazon.awssdk", "cloud.resource", "reads_from"), + spec("java.google_cloud", "Google Cloud Java", "java", "com.google.cloud", "com.google.cloud", "cloud.resource", "reads_from"), + spec("java.azure_sdk", "Azure SDK Java", "java", "com.azure", "com.azure", "cloud.resource", "reads_from"), + spec("rust.aws_sdk", "AWS SDK Rust", "rust", "aws-sdk", "aws_sdk_", "cloud.resource", "reads_from"), + spec("rust.google_cloud", "Google Cloud Rust clients", "rust", "google-cloud", "google_cloud", "cloud.resource", "reads_from"), + spec("rust.azure_sdk", "Azure SDK Rust", "rust", "azure_", "azure_", "cloud.resource", "reads_from"), + spec("cpp.aws_sdk", "AWS SDK C++", "cpp", "aws-sdk-cpp", "Aws::", "cloud.resource", "reads_from"), + spec("cpp.google_cloud", "Google Cloud C++", "cpp", "google-cloud-cpp", "google::cloud", "cloud.resource", "reads_from"), + spec("cpp.azure_sdk", "Azure SDK C++", "cpp", "azure-sdk-for-cpp", "Azure::", "cloud.resource", "reads_from"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "cloud", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"cloud:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/cloud/cloud_test.go b/internal/watch/enrich/enrichers/cloud/cloud_test.go new file mode 100644 index 0000000..73dcd6f --- /dev/null +++ b/internal/watch/enrich/enrichers/cloud/cloud_test.go @@ -0,0 +1,33 @@ +package cloud + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestCloudEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/cloud", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:cloud", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/compose/compose.go b/internal/watch/enrich/enrichers/compose/compose.go new file mode 100644 index 0000000..e2fb09e --- /dev/null +++ b/internal/watch/enrich/enrichers/compose/compose.go @@ -0,0 +1,332 @@ +package compose + +import ( + "context" + "fmt" + "strings" + + "github.com/compose-spec/compose-go/v2/loader" + 
"github.com/compose-spec/compose-go/v2/types" + "github.com/mertcikla/tld/internal/watch/enrich" +) + +func All() []enrich.Enricher { + return []enrich.Enricher{Compose()} +} + +func Compose() enrich.Enricher { + return enrich.NewEnricher( + enrich.Metadata{ + ID: "compose.docker_compose", + Name: "Docker Compose", + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{ + {Kind: enrich.SignalDependency, Value: "docker-compose"}, + }, + }, + func(input enrich.FileInput) bool { + if input.Language != "yaml" { + return false + } + return composePathTokens(input.RelPath) + }, + func(ctx context.Context, input enrich.FileInput, emit enrich.FactEmitter) error { + return emitComposeFacts(input, emit) + }, + ) +} + +func composePathTokens(rel string) bool { + lower := strings.ToLower(rel) + return strings.Contains(lower, "docker-compose") || strings.Contains(lower, "compose.yaml") || strings.Contains(lower, "compose.yml") +} + +func emitComposeFacts(input enrich.FileInput, emit enrich.FactEmitter) error { + project, err := parseCompose(input) + if err != nil || len(project.Services) == 0 { + return nil + } + serviceNames := project.ServiceNames() + serviceSet := make(map[string]struct{}, len(serviceNames)) + for _, name := range serviceNames { + serviceSet[name] = struct{}{} + } + lineOffsets := serviceLineOffsets(string(input.Source), serviceNames) + + for name, svc := range project.Services { + line := lineOffsets[name] + tech, kind := imageToTech(svc.Image) + if labelKind := svc.Labels["tld.kind"]; labelKind != "" { + kind = labelKind + } + buildCtx := "" + if svc.Build != nil { + buildCtx = svc.Build.Context + } + + if err := emitServiceFact(input, emit, name, tech, kind, buildCtx, svc.Image, svc.Labels, line); err != nil { + return err + } + for _, dep := range dedupKeys(svc.DependsOn) { + if err := emitDependsOnFact(input, emit, name, dep, line); err != nil { + return err + } + } + for _, ref := range envEndpointRefs(svc.Environment, serviceNames) { + if ref != name { + if err := emitEnvConnectionFact(input, emit, name, ref, line); err != nil { + return err + } + } + } + for _, port := range svc.Ports { + if err := emitPortFact(input, emit, name, port, line); err != nil { + return err + } + } + for _, vol := range svc.Volumes { + if err := emitVolumeFact(input, emit, name, vol, line); err != nil { + return err + } + } + } + return nil +} + +func parseCompose(input enrich.FileInput) (*types.Project, error) { + config := types.ConfigDetails{ + WorkingDir: input.RepoRoot, + ConfigFiles: []types.ConfigFile{{ + Filename: input.AbsPath, + Content: input.Source, + }}, + } + opts := func(o *loader.Options) { + o.SkipValidation = true + o.SkipInterpolation = true + o.SkipNormalization = true + o.SkipConsistencyCheck = true + o.SkipResolveEnvironment = true + o.SkipExtends = true + o.SkipInclude = true + } + project, err := loader.LoadWithContext(context.Background(), config, opts) + if err != nil { + return nil, err + } + return project, nil +} + +func emitServiceFact(input enrich.FileInput, emit enrich.FactEmitter, name, technology, kind, buildCtx, image string, labels types.Labels, line int) error { + attrs := map[string]string{ + "name": name, + "kind": kind, + "technology": technology, + } + if image != "" { + attrs["image"] = image + } + if buildCtx != "" { + attrs["build_context"] = buildCtx + } + for k, v := range labels { + if strings.HasPrefix(k, "tld.") { + attrs[strings.TrimPrefix(k, "tld.")] = v + } + } + tags := []string{ + "arch:component", + 
"runtime:compose", + "arch:deployable", + "technology:" + tagValue(technology), + "kind:" + kind, + } + return emit.EmitFact(enrich.Fact{ + Type: "runtime.component", + StableKey: fmt.Sprintf("runtime.component:%s:%s:%d", input.RelPath, name, line), + Subject: enrich.FileSubject(input.RelPath), + Object: enrich.SubjectRef{Kind: "runtime.component", StableKey: "runtime.component:" + name, FilePath: input.RelPath, Name: name}, + Relationship: "declares", + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.85, + Name: name, + Tags: tags, + Attributes: attrs, + VisibilityHints: map[string]float64{ + "high_signal": 0.85, + }, + }) +} + +func emitDependsOnFact(input enrich.FileInput, emit enrich.FactEmitter, source, target string, line int) error { + return emit.EmitFact(enrich.Fact{ + Type: "runtime.connection", + StableKey: fmt.Sprintf("runtime.connection:%s:%s:depends_on:%s:%d", input.RelPath, source, target, line), + Subject: enrich.FileSubject(input.RelPath), + Object: enrich.SubjectRef{Kind: "runtime.component", StableKey: "runtime.component:" + target, Name: target}, + Relationship: "depends_on", + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.75, + Name: source + " -> " + target, + Tags: []string{"arch:connection"}, + Attributes: map[string]string{"source": source, "target": target, "label": "depends on"}, + VisibilityHints: map[string]float64{ + "high_signal": 0.7, + }, + }) +} + +func emitEnvConnectionFact(input enrich.FileInput, emit enrich.FactEmitter, source, target string, line int) error { + return emit.EmitFact(enrich.Fact{ + Type: "runtime.connection", + StableKey: fmt.Sprintf("runtime.connection:%s:%s:connects_to:%s:%d", input.RelPath, source, target, line), + Subject: enrich.FileSubject(input.RelPath), + Object: enrich.SubjectRef{Kind: "runtime.component", StableKey: "runtime.component:" + target, Name: target}, + Relationship: "connects_to", + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.55, + Name: source + " -> " + target, + Tags: []string{"arch:connection", "compose:implicit"}, + Attributes: map[string]string{"source": source, "target": target, "label": "connects via env", "note": "inferred from environment variable"}, + VisibilityHints: map[string]float64{ + "high_signal": 0.4, + }, + }) +} + +func emitPortFact(input enrich.FileInput, emit enrich.FactEmitter, serviceName string, port types.ServicePortConfig, line int) error { + protocol := port.Protocol + if protocol == "" { + protocol = "tcp" + } + label := fmt.Sprintf("%d/%s", port.Target, protocol) + if port.Published != "" { + label = port.Published + ":" + label + } + return emit.EmitFact(enrich.Fact{ + Type: "runtime.endpoint", + StableKey: fmt.Sprintf("runtime.endpoint:%s:%s:%d:%s:%d", input.RelPath, serviceName, port.Target, protocol, line), + Subject: enrich.FileSubject(input.RelPath), + Object: enrich.SubjectRef{Kind: "runtime.endpoint", StableKey: fmt.Sprintf("runtime.endpoint:%s:%d", serviceName, port.Target), Name: label}, + Relationship: "exposes", + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.80, + Name: serviceName + ":" + label, + Tags: []string{"arch:endpoint"}, + Attributes: map[string]string{"service": serviceName, "port": fmt.Sprint(port.Target), "protocol": protocol}, + VisibilityHints: map[string]float64{ + "high_signal": 0.7, + }, + }) +} + +func emitVolumeFact(input enrich.FileInput, 
emit enrich.FactEmitter, serviceName string, vol types.ServiceVolumeConfig, line int) error { + source := vol.Source + if source == "" { + source = vol.Target + } + return emit.EmitFact(enrich.Fact{ + Type: "storage.volume", + StableKey: fmt.Sprintf("storage.volume:%s:%s:%s:%d", input.RelPath, serviceName, source, line), + Subject: enrich.FileSubject(input.RelPath), + Object: enrich.SubjectRef{Kind: "storage.volume", StableKey: "storage.volume:" + source, Name: source}, + Relationship: "uses", + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.70, + Name: serviceName + " -> " + source, + Tags: []string{"storage:volume"}, + Attributes: map[string]string{"service": serviceName, "source": source, "target": vol.Target}, + VisibilityHints: map[string]float64{ + "high_signal": 0.5, + }, + }) +} + +func envEndpointRefs(env types.MappingWithEquals, serviceNames []string) []string { + var refs []string + for _, value := range env { + if value == nil { + continue + } + target := extractServiceRef(*value, serviceNames) + if target == "" { + continue + } + refs = append(refs, target) + } + return refs +} + +func extractServiceRef(value string, serviceNames []string) string { + value = strings.TrimSpace(value) + lower := strings.ToLower(value) + + // Filter out non-signal values + if value == "" || lower == "true" || lower == "false" || lower == "null" || lower == "localhost" || lower == "127.0.0.1" || lower == "0.0.0.0" { + return "" + } + + // Extract hostname from URL or host:port + host := value + if strings.Contains(host, "://") { + parts := strings.SplitN(host, "://", 2) + host = parts[1] + } + if strings.Contains(host, ":") { + host = strings.SplitN(host, ":", 2)[0] + } + host = strings.Trim(host, "/") + if strings.Contains(host, ".") { + host = strings.SplitN(host, ".", 2)[0] + } + host = strings.ToLower(host) + if host == "" || strings.ContainsAny(host, " /\\") { + return "" + } + + for _, name := range serviceNames { + if strings.EqualFold(name, host) { + return name + } + } + return "" +} + +func serviceLineOffsets(source string, serviceNames []string) map[string]int { + lines := strings.Split(source, "\n") + offsets := make(map[string]int, len(serviceNames)) + for i, line := range lines { + trimmed := strings.TrimSpace(line) + for _, name := range serviceNames { + if _, ok := offsets[name]; ok { + continue + } + if strings.HasPrefix(trimmed, name+":") { + offsets[name] = i + 1 + } + } + } + for _, name := range serviceNames { + if _, ok := offsets[name]; !ok { + offsets[name] = 1 + } + } + return offsets +} + +func dedupKeys(m map[string]types.ServiceDependency) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +func tagValue(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + value = strings.NewReplacer(" / ", "-", "/", "-", " ", "-", "&", "and", ".", "", "+", "plus").Replace(value) + for strings.Contains(value, "--") { + value = strings.ReplaceAll(value, "--", "-") + } + return strings.Trim(value, "-") +} diff --git a/internal/watch/enrich/enrichers/compose/compose_test.go b/internal/watch/enrich/enrichers/compose/compose_test.go new file mode 100644 index 0000000..d94d75a --- /dev/null +++ b/internal/watch/enrich/enrichers/compose/compose_test.go @@ -0,0 +1,340 @@ +package compose + +import ( + "context" + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func 
triggerSignals() []enrich.ActivationSignal { + return []enrich.ActivationSignal{ + {Kind: enrich.SignalDependency, Value: "docker-compose"}, + } +} + +func TestComposeServiceWithImage(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose service with image", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + web: + image: nginx:latest +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.component", Tag: "runtime:compose", Name: "web", Attribute: "image", AttrValue: "nginx:latest"}, + }) +} + +func TestComposeServiceImageTechnology(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose service with inferred technology", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + db: + image: postgres:15 +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.component", Tag: "kind:database", Name: "db", Attribute: "technology", AttrValue: "PostgreSQL"}, + }) +} + +func TestComposeDependsOn(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose depends_on connection", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + web: + image: nginx:latest + depends_on: + - db + db: + image: postgres:15 +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.connection", Name: "web -> db", Attribute: "source", AttrValue: "web"}, + }) +} + +func TestComposeEnvConnection(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose env references another service", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + web: + image: nginx:latest + environment: + - DATABASE_URL=http://db:5432 + db: + image: postgres:15 +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.connection", Tag: "compose:implicit", Attribute: "source", AttrValue: "web"}, + }) +} + +func TestComposeBuildContext(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose service with build context", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + app: + build: + context: ./src +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.component", Name: "app", Attribute: "build_context", AttrValue: "src"}, + }) +} + +func TestComposePort(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose port exposes endpoint", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + web: + image: nginx:latest + ports: + - "80:80" +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.endpoint", Attribute: "service", AttrValue: "web"}, + }) +} + +func TestComposeVolume(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose volume reference", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + db: + image: postgres:15 + volumes: + - pgdata:/var/lib/postgresql/data +volumes: + pgdata: +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "storage.volume", Attribute: "source", AttrValue: "pgdata"}, + }) +} + +func 
TestComposeLabelKindOverride(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose label overrides kind", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte(`services: + worker: + image: alpine:latest + labels: + tld.kind: "worker" +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.component", Name: "worker", Attribute: "kind", AttrValue: "worker"}, + }) +} + +func TestComposeEmptyServices(t *testing.T) { + registry := enrich.NewRegistry(Compose()) + facts, _, err := registry.EnrichFile(context.Background(), enrich.FileInput{ + RelPath: "docker-compose.yml", + Language: "yaml", + Source: []byte("services:\n"), + Signals: triggerSignals(), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(facts) > 0 { + t.Fatalf("expected no facts for empty services, got %d: %+v", len(facts), facts) + } +} + +func TestComposeNonComposeYAML(t *testing.T) { + registry := enrich.NewRegistry(Compose()) + facts, _, err := registry.EnrichFile(context.Background(), enrich.FileInput{ + RelPath: "config.yaml", + Language: "yaml", + Source: []byte(`apiVersion: v1 +kind: ConfigMap +`), + Signals: triggerSignals(), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(facts) > 0 { + t.Fatalf("expected no facts for non-compose yaml, got %d: %+v", len(facts), facts) + } +} + +func TestComposeComposeYAMLPath(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "compose.yaml path matches", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "compose.yaml", + Language: "yaml", + Source: []byte(`services: + api: + image: myapp:latest +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.component", Tag: "runtime:compose", Name: "api"}, + }) +} + +func TestComposeOverrideFile(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "docker-compose.override.yml matches", + Enricher: Compose(), + Input: enrich.FileInput{ + RelPath: "docker-compose.override.yml", + Language: "yaml", + Source: []byte(`services: + web: + ports: + - "8080:80" +`), + }, + Signals: triggerSignals(), + Want: enrichertest.Fact{Type: "runtime.component", Name: "web"}, + }) +} + +func TestImageToTech(t *testing.T) { + tests := []struct { + image string + wantTech string + wantKind string + }{ + {"", "Container", "service"}, + {"postgres:15", "PostgreSQL", "database"}, + {"mysql:8", "MySQL", "database"}, + {"redis:alpine", "Redis", "cache"}, + {"nginx:latest", "Nginx", "proxy"}, + {"rabbitmq:3-management", "RabbitMQ", "queue"}, + {"kafka:latest", "Apache Kafka", "queue"}, + {"grafana/grafana:latest", "Grafana", "monitoring"}, + {"unknown-image:1.0", "Container", "service"}, + {"registry.example.com/myapp:v1", "Container", "service"}, + } + for _, tt := range tests { + t.Run(tt.image, func(t *testing.T) { + tech, kind := imageToTech(tt.image) + if tech != tt.wantTech { + t.Errorf("imageToTech(%q) tech = %q, want %q", tt.image, tech, tt.wantTech) + } + if kind != tt.wantKind { + t.Errorf("imageToTech(%q) kind = %q, want %q", tt.image, kind, tt.wantKind) + } + }) + } +} + +func TestComposePathTokens(t *testing.T) { + tests := []struct { + path string + want bool + }{ + {"docker-compose.yml", true}, + {"docker-compose.yaml", true}, + {"docker-compose.override.yml", true}, + {"docker-compose.prod.yaml", true}, + {"compose.yaml", true}, + {"compose.yml", true}, + {"deploy/docker-compose.yml", true}, + {"config.yaml", false}, + {"main.go", 
false}, + {"docker-compose", true}, + } + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := composePathTokens(tt.path) + if got != tt.want { + t.Errorf("composePathTokens(%q) = %v, want %v", tt.path, got, tt.want) + } + }) + } +} + +func TestExtractServiceRef(t *testing.T) { + names := []string{"api", "db", "redis", "web", "worker"} + tests := []struct { + value string + want string + }{ + {"http://api:8080", "api"}, + {"db:5432", "db"}, + {"redis://redis:6379", "redis"}, + {"web", "web"}, + {"worker.internal:9000", "worker"}, + {"localhost", ""}, + {"127.0.0.1", ""}, + {"true", ""}, + {"false", ""}, + {"", ""}, + } + for _, tt := range tests { + t.Run(tt.value, func(t *testing.T) { + got := extractServiceRef(tt.value, names) + if got != tt.want { + t.Errorf("extractServiceRef(%q) = %q, want %q", tt.value, got, tt.want) + } + }) + } +} + +func TestTagValue(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"PostgreSQL", "postgresql"}, + {"Apache Kafka", "apache-kafka"}, + {"My App", "my-app"}, + {" Spaces ", "spaces"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := tagValue(tt.input) + if got != tt.want { + t.Errorf("tagValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/watch/enrich/enrichers/compose/image_tech.go b/internal/watch/enrich/enrichers/compose/image_tech.go new file mode 100644 index 0000000..5e8c5bd --- /dev/null +++ b/internal/watch/enrich/enrichers/compose/image_tech.go @@ -0,0 +1,132 @@ +package compose + +import ( + "path" + "strings" +) + +type imageInfo struct { + Tech string + Kind string +} + +var imageTechRegistry = map[string]imageInfo{ + "postgres": {"PostgreSQL", "database"}, + "mysql": {"MySQL", "database"}, + "mariadb": {"MariaDB", "database"}, + "mongo": {"MongoDB", "database"}, + "mongodb": {"MongoDB", "database"}, + "couchdb": {"CouchDB", "database"}, + "couchbase": {"Couchbase", "database"}, + "cassandra": {"Cassandra", "database"}, + "neo4j": {"Neo4j", "database"}, + "arangodb": {"ArangoDB", "database"}, + "dynamodb": {"DynamoDB", "database"}, + "cockroachdb": {"CockroachDB", "database"}, + "rethinkdb": {"RethinkDB", "database"}, + "influxdb": {"InfluxDB", "database"}, + "timescaledb": {"TimescaleDB", "database"}, + "clickhouse": {"ClickHouse", "database"}, + "sqlite": {"SQLite", "database"}, + + "redis": {"Redis", "cache"}, + "memcached": {"Memcached", "cache"}, + "dragonflydb": {"DragonflyDB", "cache"}, + "keydb": {"KeyDB", "cache"}, + + "celery": {"Celery", "worker"}, + + "nginx": {"Nginx", "proxy"}, + "traefik": {"Traefik", "proxy"}, + "haproxy": {"HAProxy", "proxy"}, + "envoy": {"Envoy", "proxy"}, + "caddy": {"Caddy", "proxy"}, + "apache": {"Apache", "proxy"}, + "httpd": {"Apache", "proxy"}, + "kong": {"Kong", "proxy"}, + "apisix": {"Apache APISIX", "proxy"}, + "bff": {"BFF", "gateway"}, + "gateway": {"API Gateway", "gateway"}, + + "kafka": {"Apache Kafka", "queue"}, + "rabbitmq": {"RabbitMQ", "queue"}, + "nats": {"NATS", "queue"}, + "activemq": {"ActiveMQ", "queue"}, + "pulsar": {"Apache Pulsar", "queue"}, + "zeromq": {"ZeroMQ", "queue"}, + "redpanda": {"Redpanda", "queue"}, + "pubsub": {"Pub/Sub", "queue"}, + "kinesis": {"Kinesis", "queue"}, + "amazonmq": {"Amazon MQ", "queue"}, + + "elasticsearch": {"Elasticsearch", "search"}, + "opensearch": {"OpenSearch", "search"}, + "meilisearch": {"Meilisearch", "search"}, + "solr": {"Solr", "search"}, + "algolia": {"Algolia", "search"}, + "typesense": {"Typesense", "search"}, + + "minio": 
{"MinIO", "storage"}, + "ceph": {"Ceph", "storage"}, + "seaweedfs": {"SeaweedFS", "storage"}, + "glusterfs": {"GlusterFS", "storage"}, + + "grafana": {"Grafana", "monitoring"}, + "prometheus": {"Prometheus", "monitoring"}, + "jaeger": {"Jaeger", "observability"}, + "opentelemetry": {"OpenTelemetry", "observability"}, + "datadog": {"Datadog", "monitoring"}, + "loki": {"Loki", "monitoring"}, + "tempo": {"Tempo", "observability"}, + "mimir": {"Mimir", "monitoring"}, + + "vault": {"Vault", "security"}, + "consul": {"Consul", "service-mesh"}, + "istio": {"Istio", "service-mesh"}, + "linkerd": {"Linkerd", "service-mesh"}, + "keycloak": {"Keycloak", "auth"}, + "authentik": {"Authentik", "auth"}, + "authelia": {"Authelia", "auth"}, + + "zookeeper": {"ZooKeeper", "coordination"}, + "etcd": {"etcd", "coordination"}, + + "curlimages": {"curl", "utility"}, + "alpine": {"Alpine", "utility"}, + "busybox": {"BusyBox", "utility"}, + "ubuntu": {"Ubuntu", "utility"}, + "debian": {"Debian", "utility"}, + "golang": {"Go", "utility"}, + "python": {"Python", "utility"}, + "node": {"Node.js", "utility"}, + "ruby": {"Ruby", "utility"}, + "openjdk": {"OpenJDK", "utility"}, + "rust": {"Rust", "utility"}, +} + +func imageToTech(image string) (tech, kind string) { + if image == "" { + return "Container", "service" + } + base := strings.ToLower(path.Base(strings.Split(image, ":")[0])) + if info, ok := imageTechRegistry[base]; ok { + return info.Tech, info.Kind + } + if info, ok := imageTechRegistry[stripVersion(base)]; ok { + return info.Tech, info.Kind + } + return "Container", "service" +} + +func stripVersion(base string) string { + for i := len(base) - 1; i >= 0; i-- { + if base[i] >= '0' && base[i] <= '9' { + continue + } + if base[i] == '.' || base[i] == '-' || base[i] == '_' { + continue + } + return base[:i+1] + } + return base +} diff --git a/internal/watch/enrich/enrichers/config/config.go b/internal/watch/enrich/enrichers/config/config.go new file mode 100644 index 0000000..4d8c719 --- /dev/null +++ b/internal/watch/enrich/enrichers/config/config.go @@ -0,0 +1,63 @@ +package config + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + native("ts.process_env", "TypeScript process.env", "typescript", "process.env", "process.env"), + lib("ts.dotenv", "TypeScript dotenv", "typescript", "dotenv", "dotenv.config"), + lib("ts.next_env", "Next.js env", "typescript", "next", "NEXT_PUBLIC_"), + lib("ts.vite_env", "Vite env", "typescript", "vite", "import.meta.env"), + native("go.os_getenv", "Go os.Getenv", "go", "os.Getenv", "os"), + lib("go.viper", "Go Viper", "go", "github.com/spf13/viper", "viper.Get"), + lib("go.envconfig", "Go envconfig", "go", "github.com/kelseyhightower/envconfig", "envconfig.Process"), + native("python.os_environ", "Python os.environ", "python", "os.environ", "os"), + lib("python.pydantic_settings", "Python Pydantic Settings", "python", "pydantic-settings", "BaseSettings"), + lib("python.dynaconf", "Python Dynaconf", "python", "dynaconf", "Dynaconf"), + lib("python.django_settings", "Django settings", "python", "django", "django.conf.settings"), + native("java.system_getenv", "Java System.getenv", "java", "System.getenv", "java.lang.System"), + lib("java.spring_value", "Spring @Value", "java", "org.springframework.beans.factory.annotation.Value", "@Value"), + 
lib("java.spring_configuration_properties", "Spring @ConfigurationProperties", "java", "org.springframework.boot.context.properties.ConfigurationProperties", "@ConfigurationProperties"), + lib("java.microprofile_config", "MicroProfile Config", "java", "org.eclipse.microprofile.config", "ConfigProvider"), + native("rust.std_env_var", "Rust std::env::var", "rust", "std::env::var", "std::env"), + lib("rust.dotenvy", "Rust dotenvy", "rust", "dotenvy", "dotenvy::"), + lib("rust.config", "Rust config", "rust", "config", "Config::builder"), + lib("rust.figment", "Rust figment", "rust", "figment", "Figment::"), + native("cpp.getenv", "C++ std::getenv", "cpp", "std::getenv", "cstdlib"), + lib("cpp.yaml_cpp", "C++ yaml-cpp", "cpp", "yaml-cpp", "YAML::Load"), + lib("cpp.nlohmann_json", "C++ nlohmann_json", "cpp", "nlohmann_json", "nlohmann::json"), + lib("cpp.tomlplusplus", "C++ tomlplusplus", "cpp", "tomlplusplus", "toml::parse"), + } +} + +func native(id, name, language, token, dependency string) pattern.Spec { + spec := base(id, name, language, dependency, token) + spec.Mode = enrich.ActivationAlways + spec.Triggers = nil + return spec +} + +func lib(id, name, language, dependency, token string) pattern.Spec { + return base(id, name, language, dependency, token) +} + +func base(id, name, language, dependency, token string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "config", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "config.env", + Relationship: "reads_config", + SourceTokens: []string{token}, + Tags: []string{"config:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/config/config_test.go b/internal/watch/enrich/enrichers/config/config_test.go new file mode 100644 index 0000000..00fab4c --- /dev/null +++ b/internal/watch/enrich/enrichers/config/config_test.go @@ -0,0 +1,36 @@ +package config + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestConfigEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + tc := enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/config", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:config", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + } + if len(spec.Triggers) > 0 { + tc.Signals = []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}} + } + enrichertest.Run(t, tc) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/dataeng/dataeng.go b/internal/watch/enrich/enrichers/dataeng/dataeng.go new file mode 100644 index 0000000..ea22b19 --- /dev/null +++ b/internal/watch/enrich/enrichers/dataeng/dataeng.go @@ -0,0 +1,36 @@ +package dataeng + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func 
Specs() []pattern.Spec { + return []pattern.Spec{ + spec("python.airflow", "Python Apache Airflow", "python", "apache-airflow", "DAG(", "data.pipeline_id", "depends_on_task"), + spec("python.prefect", "Python Prefect", "python", "prefect", "@flow", "data.pipeline_id", "depends_on_task"), + spec("python.dagster", "Python Dagster", "python", "dagster", "@asset", "data.pipeline_id", "depends_on_task"), + spec("python.spark", "Python Apache Spark", "python", "pyspark", "spark.sql", "data.dataset_uri", "reads_dataset"), + spec("java.spark", "Java Apache Spark", "java", "org.apache.spark", "SparkSession", "data.dataset_uri", "reads_dataset"), + spec("python.ray", "Python Ray", "python", "ray", "ray.init", "data.pipeline_id", "depends_on_task"), + spec("java.ray", "Java Ray", "java", "io.ray", "Ray.init", "data.pipeline_id", "depends_on_task"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "data", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"data:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/dataeng/dataeng_test.go b/internal/watch/enrich/enrichers/dataeng/dataeng_test.go new file mode 100644 index 0000000..fc4135c --- /dev/null +++ b/internal/watch/enrich/enrichers/dataeng/dataeng_test.go @@ -0,0 +1,33 @@ +package dataeng + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestDataEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/pipeline", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:data", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/datastore/datastore.go b/internal/watch/enrich/enrichers/datastore/datastore.go new file mode 100644 index 0000000..2269f4d --- /dev/null +++ b/internal/watch/enrich/enrichers/datastore/datastore.go @@ -0,0 +1,77 @@ +package datastore + +import ( + "context" + "fmt" + "regexp" + "strings" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type Enricher = enrich.Enricher +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type SourceSpan = enrich.SourceSpan +type SubjectRef = enrich.SubjectRef + +const ActivationAlways = enrich.ActivationAlways + +var ( + fileSubject = enrich.FileSubject + lineForOffset = enrich.LineForOffset + matchLanguages = enrich.MatchLanguages + tokenCleanupRE = regexp.MustCompile(`(?m)(^|[^:])//.*$|#.*$|/\*[\s\S]*?\*/|`) +) + +func DatastoreGlue() Enricher { + return 
enrich.NewEnricher( + Metadata{ID: "datastore.glue", Name: "Datastore glue", Mode: ActivationAlways}, + matchLanguages("go", "python", "javascript", "typescript", "c-sharp", "xml", "go-mod", "json", "python-requirements"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + source := string(input.Source) + scannable := tokenCleanupRE.ReplaceAllString(source, "$1") + lower := strings.ToLower(scannable) + candidates := []struct { + needle string + name string + tech string + }{ + {"redis://", "redis", "Redis"}, + {"github.com/redis/go-redis", "redis", "Redis"}, + {"spanner.googleapis.com", "spanner", "Spanner"}, + {"alloydb.googleapis.com", "alloydb", "AlloyDB"}, + {"postgres://", "postgres", "PostgreSQL"}, + {"postgresql://", "postgres", "PostgreSQL"}, + {"github.com/lib/pq", "postgres", "PostgreSQL"}, + {"secretmanager.googleapis.com", "secretmanager", "Secret Manager"}, + {"go.opentelemetry.io/otel", "opentelemetry", "OpenTelemetry"}, + } + for _, candidate := range candidates { + if !strings.Contains(lower, candidate.needle) { + continue + } + idx := strings.Index(lower, candidate.needle) + line := lineForOffset(scannable, idx) + if err := emit.EmitFact(Fact{ + Type: "datastore.dependency", + StableKey: fmt.Sprintf("datastore.dependency:%s:%s", input.RelPath, candidate.name), + Subject: fileSubject(input.RelPath), + Object: SubjectRef{Kind: "datastore", StableKey: "datastore:" + candidate.name, Name: candidate.name}, + Relationship: "uses", + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.72, + Name: candidate.name, + Tags: []string{"arch:datastore", "datastore:" + candidate.name}, + Attributes: map[string]string{"name": candidate.name, "technology": candidate.tech}, + VisibilityHints: map[string]float64{"high_signal": 0.5}, + }); err != nil { + return err + } + } + return nil + }, + ) +} diff --git a/internal/watch/enrich/enrichers/datastore/datastore_test.go b/internal/watch/enrich/enrichers/datastore/datastore_test.go new file mode 100644 index 0000000..7c72daa --- /dev/null +++ b/internal/watch/enrich/enrichers/datastore/datastore_test.go @@ -0,0 +1,64 @@ +package datastore + +import ( + "context" + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestDatastoreGlue(t *testing.T) { + enrichertest.Run(t, []enrichertest.Case{ + { + Name: "datastore glue matches redis connection string", + Enricher: DatastoreGlue(), + Input: enrich.FileInput{ + RelPath: "cache.go", + Language: "go", + Source: []byte(`func connect() { _ = "redis://cache:6379" }`), + }, + Want: enrichertest.Fact{Type: "datastore.dependency", Tag: "datastore:redis", Name: "redis"}, + }, + }...) 
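+
+	// Illustrative additional case (a sketch mirroring the redis case above): a
+	// PostgreSQL URL should surface the postgres datastore via the same candidate list.
+	enrichertest.Run(t, enrichertest.Case{
+		Name:     "datastore glue matches postgres connection string",
+		Enricher: DatastoreGlue(),
+		Input: enrich.FileInput{
+			RelPath:  "db.go",
+			Language: "go",
+			Source:   []byte(`func open() { _ = "postgres://db:5432/app" }`),
+		},
+		Want: enrichertest.Fact{Type: "datastore.dependency", Tag: "datastore:postgres", Name: "postgres"},
+	})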
+} + +func TestDatastoreGlueNegatives(t *testing.T) { + cases := []struct { + name string + source string + }{ + {"ignores redis in comments", `// TODO: consider using redis://cache:6379`}, + {"ignores bare redis mention", `var x = "redis"`}, + {"ignores postgres in comments", `/* postgres://localhost */`}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + input := enrich.FileInput{ + RelPath: "test.go", + Language: "go", + Source: []byte(tc.source), + } + emitter := &factCollector{} + err := DatastoreGlue().EnrichFile(context.Background(), input, emitter) + if err != nil { + t.Fatalf("enrich: %v", err) + } + if len(emitter.facts) > 0 { + t.Fatalf("expected no facts for source %q, got %v", tc.source, emitter.facts) + } + }) + } +} + +type factCollector struct { + facts []enrich.Fact +} + +func (c *factCollector) EmitFact(f enrich.Fact) error { + c.facts = append(c.facts, f) + return nil +} + +func (c *factCollector) Warn(w enrich.Warning) {} diff --git a/internal/watch/enrich/enrichers/deployment/deployment.go b/internal/watch/enrich/enrichers/deployment/deployment.go new file mode 100644 index 0000000..b3e5950 --- /dev/null +++ b/internal/watch/enrich/enrichers/deployment/deployment.go @@ -0,0 +1,36 @@ +package deployment + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("deployment.github_actions", "GitHub Actions", []string{".github/workflows/"}, nil, "deployment.workflow", "builds"), + spec("deployment.gitlab_ci", "GitLab CI", []string{".gitlab-ci.yml"}, nil, "deployment.workflow", "builds"), + spec("deployment.circleci", "CircleCI", []string{".circleci/config.yml"}, nil, "deployment.workflow", "builds"), + spec("deployment.jenkinsfile", "Jenkinsfile", []string{"jenkinsfile"}, nil, "deployment.workflow", "builds"), + spec("deployment.buildkite", "Buildkite", []string{".buildkite/"}, nil, "deployment.workflow", "builds"), + spec("deployment.argo_cd", "Argo CD", []string{"argocd"}, []string{"argoproj.io"}, "deployment.target", "deploys_to"), + spec("deployment.flux", "Flux", nil, []string{"toolkit.fluxcd.io"}, "deployment.target", "deploys_to"), + } +} + +func spec(id, name string, pathTokens, sourceTokens []string, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "deployment", + Languages: []string{"yaml", "groovy"}, + Mode: enrich.ActivationAlways, + FactType: factType, + Relationship: relationship, + SourceTokens: sourceTokens, + PathTokens: pathTokens, + Tags: []string{"deployment:" + id}, + Attributes: map[string]string{"provider": id}, + } +} diff --git a/internal/watch/enrich/enrichers/deployment/deployment_test.go b/internal/watch/enrich/enrichers/deployment/deployment_test.go new file mode 100644 index 0000000..37c78a9 --- /dev/null +++ b/internal/watch/enrich/enrichers/deployment/deployment_test.go @@ -0,0 +1,42 @@ +package deployment + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestDeploymentEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + relPath := "deploy/pipeline.yml" + if len(spec.PathTokens) > 0 { + relPath = spec.PathTokens[0] + "pipeline.yml" + } + var source []byte + if len(spec.SourceTokens) > 0 { + source = 
[]byte(spec.SourceTokens[0]) + } else { + source = []byte("jobs:\n build:\n runs-on: ubuntu-latest\n") + } + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: relPath, + Language: "yaml", + Source: source, + }, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:deployment", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/frontend/typescript/frontend.go b/internal/watch/enrich/enrichers/frontend/typescript/frontend.go new file mode 100644 index 0000000..9c4f914 --- /dev/null +++ b/internal/watch/enrich/enrichers/frontend/typescript/frontend.go @@ -0,0 +1,129 @@ +package typescript + +import ( + "context" + "fmt" + "path" + "path/filepath" + "regexp" + "strings" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type RoutePattern = enrich.RoutePattern +type SourceSpan = enrich.SourceSpan +type SubjectRef = enrich.SubjectRef + +const ( + ActivationImportOrDependency = enrich.ActivationImportOrDependency + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +var fileSubject = enrich.FileSubject + +func NextJS() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "ts.nextjs", + Name: "Next.js routes", + Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{ + {Kind: SignalDependency, Value: "next"}, + {Kind: SignalImport, Value: "next"}, + }, + }, + func(input FileInput) bool { + route := nextRoutePath(input.RelPath) + return route != "" + }, + func(ctx context.Context, input FileInput, emit FactEmitter) error { + route := nextRoutePath(input.RelPath) + if route == "" { + return nil + } + return emit.EmitFact(Fact{ + Type: "frontend.route", + StableKey: fmt.Sprintf("frontend.route:nextjs:%s:%s", input.RelPath, route), + Subject: fileSubject(input.RelPath), + Object: SubjectRef{Kind: "frontend.route", StableKey: "frontend.route:nextjs:" + route, FilePath: input.RelPath, Name: route}, + Relationship: "declares", + Source: SourceSpan{FilePath: input.RelPath, StartLine: 1, EndLine: 1}, + Confidence: 0.95, + Name: route, + Tags: []string{"frontend:route", "framework:nextjs"}, + Attributes: map[string]string{"framework": "nextjs", "path": route}, + VisibilityHints: map[string]float64{ + "high_signal": 1, + }, + }) + }, + ) +} + +func ReactRouter() Enricher { + return enrich.RouteRegexEnricher("ts.react_router", "React Router routes", "typescript,javascript", []ActivationSignal{ + {Kind: SignalImport, Value: "react-router"}, + {Kind: SignalImport, Value: "react-router-dom"}, + {Kind: SignalDependency, Value: "react-router"}, + {Kind: SignalDependency, Value: "react-router-dom"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`<Route[^>]*\bpath\s*=\s*["'{\x60]([^"'}\x60]+)["'}\x60]`), FactType: "frontend.route", Framework: "react-router", Tags: []string{"frontend:route", "framework:react-router"}}, + {Re: regexp.MustCompile(`\bpath\s*:\s*["'\x60]([^"'\x60]+)["'\x60]`), FactType: "frontend.route", Framework: "react-router", Tags: []string{"frontend:route", "framework:react-router"}}, + }) +} + +func 
nextRoutePath(relPath string) string { + rel := filepath.ToSlash(relPath) + ext := path.Ext(rel) + if ext == "" { + return "" + } + trimmed := strings.TrimSuffix(rel, ext) + for _, prefix := range []string{"src/app/", "app/"} { + if after, ok := strings.CutPrefix(trimmed, prefix); ok { + route := after + if !strings.HasSuffix(route, "/page") && !strings.HasSuffix(route, "/route") { + return "" + } + route = strings.TrimSuffix(strings.TrimSuffix(route, "/page"), "/route") + return normalizeNextRoute(route) + } + } + for _, prefix := range []string{"src/pages/", "pages/"} { + if after, ok := strings.CutPrefix(trimmed, prefix); ok { + route := after + return normalizeNextRoute(route) + } + } + return "" +} + +func normalizeNextRoute(route string) string { + route = strings.Trim(route, "/") + if route == "" || route == "index" { + return "/" + } + parts := strings.Split(route, "/") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if part == "index" { + continue + } + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "[") && strings.HasSuffix(part, "]") { + part = ":" + strings.Trim(part, "[]") + } + if part != "" { + out = append(out, part) + } + } + return "/" + strings.Join(out, "/") +} diff --git a/internal/watch/enrich/enrichers/frontend/typescript/frontend_test.go b/internal/watch/enrich/enrichers/frontend/typescript/frontend_test.go new file mode 100644 index 0000000..736de0d --- /dev/null +++ b/internal/watch/enrich/enrichers/frontend/typescript/frontend_test.go @@ -0,0 +1,22 @@ +package typescript + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestTypeScriptFrontendEnrichers(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "nextjs file route requires activation and matches app route path", + Enricher: NextJS(), + Input: enrich.FileInput{ + RelPath: "src/app/users/[id]/page.tsx", + Language: "typescript", + Source: []byte(`export default function Page() { return null }`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "next"}}, + Want: enrichertest.Fact{Type: "frontend.route", Tag: "framework:nextjs", Name: "/users/:id"}, + }) +} diff --git a/internal/watch/enrich/enrichers/generic/glue.go b/internal/watch/enrich/enrichers/generic/glue.go new file mode 100644 index 0000000..488de0a --- /dev/null +++ b/internal/watch/enrich/enrichers/generic/glue.go @@ -0,0 +1,275 @@ +package generic + +import ( + "context" + "fmt" + "maps" + "path/filepath" + "regexp" + "strings" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type Enricher = enrich.Enricher +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput + +type detector struct { + ID string + Name string + Category string + FactType string + Relationship string + ObjectKind string + Tags []string + Tokens []string + PathTokens []string + Attrs map[string]string +} + +var tokenCleanupRE = regexp.MustCompile(`(?m)(^|[^:])//.*$|#.*$|/\*[\s\S]*?\*/|`) + +// ArchitectureGlue detects common framework/library entrypoints and integration +// glue from imports, manifests, and configuration files. 
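+//
+// A hit is reported at the first line containing a matching token (or at line 1 for a
+// path match). Illustrative usage, mirroring how the package tests drive enrichers
+// through a registry:
+//
+//	registry := enrich.NewRegistry(generic.ArchitectureGlue())
+//	facts, _, err := registry.EnrichFile(ctx, input)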
+func ArchitectureGlue() Enricher { + return enrich.NewEnricher( + enrich.Metadata{ID: "generic.architecture_glue", Name: "Generic architecture glue", Mode: enrich.ActivationAlways}, + func(input FileInput) bool { + return !ignoredPath(input.RelPath) + }, + func(ctx context.Context, input FileInput, emit FactEmitter) error { + return emitGenericFacts(input, emit) + }, + ) +} + +func emitGenericFacts(input FileInput, emit FactEmitter) error { + source := string(input.Source) + scannable := tokenCleanupRE.ReplaceAllString(source, "$1") + seen := map[string]struct{}{} + for _, det := range detectors { + line := detectorLine(input.RelPath, scannable, det) + if line == 0 { + continue + } + key := det.FactType + ":" + det.ID + ":" + input.RelPath + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + attrs := map[string]string{ + "category": det.Category, + "technology": det.Name, + "detector": det.ID, + } + maps.Copy(attrs, det.Attrs) + tags := append([]string{"arch:glue", "category:" + tagValue(det.Category), "technology:" + tagValue(det.Name)}, det.Tags...) + if err := emit.EmitFact(Fact{ + Type: det.FactType, + StableKey: key, + Subject: enrich.SubjectForLine(input, line), + Object: enrich.SubjectRef{Kind: firstNonEmpty(det.ObjectKind, det.FactType), StableKey: det.FactType + ":" + det.ID, FilePath: input.RelPath, Name: det.Name}, + Relationship: det.Relationship, + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.72, + Name: det.Name, + Tags: tags, + Attributes: attrs, + VisibilityHints: map[string]float64{ + "high_signal": 0.6, + }, + }); err != nil { + return err + } + } + return nil +} + +func detectorLine(relPath, source string, det detector) int { + rel := strings.ToLower(filepath.ToSlash(relPath)) + for _, token := range det.PathTokens { + if strings.Contains(rel, strings.ToLower(token)) { + return 1 + } + } + lower := strings.ToLower(source) + for _, token := range det.Tokens { + idx := strings.Index(lower, strings.ToLower(token)) + if idx >= 0 { + return enrich.LineForOffset(source, idx) + } + } + return 0 +} + +func ignoredPath(relPath string) bool { + parts := strings.SplitSeq(filepath.ToSlash(relPath), "/") + for part := range parts { + switch strings.ToLower(part) { + case ".git", "node_modules", "vendor", "dist", "build", "coverage", "generated", "gen": + return true + } + } + return false +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func tagValue(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + value = strings.NewReplacer(" / ", "-", "/", "-", " ", "-", "&", "and", ".", "").Replace(value) + for strings.Contains(value, "--") { + value = strings.ReplaceAll(value, "--", "-") + } + return strings.Trim(value, "-") +} + +func d(id, name, category, factType, relationship string, tokens ...string) detector { + return detector{ + ID: id, + Name: name, + Category: category, + FactType: factType, + Relationship: relationship, + Tags: []string{fmt.Sprintf("%s:%s", tagValue(category), tagValue(name))}, + Tokens: tokens, + Attrs: map[string]string{"framework": id}, + } +} + +var detectors = []detector{ + // Observability / telemetry. 
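+	// (Entries in every category below follow the same shape: the first Token or
+	// PathToken hit in a file yields one fact whose FactType, Name, and Relationship
+	// come from the entry, tagged arch:glue plus category:... and technology:... —
+	// see emitGenericFacts and detectorLine above. For example, an import of
+	// "go.opentelemetry.io/otel" matches the opentelemetry entry and yields a
+	// telemetry.project fact with relationship "reports_to".)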
+ d("opentelemetry", "OpenTelemetry", "observability", "telemetry.project", "reports_to", "@opentelemetry/", "go.opentelemetry.io/otel", "opentelemetry", "io.opentelemetry", "opentelemetry-cpp"), + d("prometheus", "Prometheus", "observability", "telemetry.metric", "emits_metric", "prom-client", "github.com/prometheus/client_golang", "prometheus_client", "micrometer-registry-prometheus", "prometheus-cpp"), + d("sentry", "Sentry", "observability", "telemetry.project", "reports_to", "@sentry/", "github.com/getsentry/sentry-go", "sentry_sdk", "io.sentry", "sentry-native"), + d("datadog", "Datadog", "observability", "telemetry.project", "reports_to", "dd-trace", "gopkg.in/DataDog/dd-trace-go", "ddtrace", "com.datadoghq", "datadog"), + d("micrometer", "Micrometer", "observability", "telemetry.metric", "emits_metric", "io.micrometer", "micrometer-registry"), + d("rust-tracing", "Rust tracing", "observability", "telemetry.span", "creates_span", "tracing::", "tracing ="), + d("rust-metrics", "Rust metrics", "observability", "telemetry.metric", "emits_metric", "metrics::", "metrics ="), + + // Auth / identity. + d("auth0", "Auth0", "auth", "auth.provider", "uses_identity_provider", "@auth0/", "auth0.com", "github.com/auth0", "com.auth0"), + d("cognito", "Cognito", "auth", "auth.provider", "uses_identity_provider", "amazon-cognito", "cognitoidentityprovider", "cognito-idp", "software.amazon.awssdk.services.cognitoidentityprovider"), + d("firebase-auth", "Firebase Auth", "auth", "auth.provider", "uses_identity_provider", "firebase/auth", "firebase_admin.auth", "firebaseauth"), + d("clerk", "Clerk", "auth", "auth.provider", "uses_identity_provider", "@clerk/", "clerk.com"), + d("nextauth", "NextAuth", "auth", "auth.provider", "uses_identity_provider", "next-auth"), + d("jwt", "JWT validation", "auth", "auth.issuer", "trusts_issuer", "github.com/golang-jwt/jwt", "jsonwebtoken", "jwt-cpp", "jwtvalidator"), + d("oidc", "OIDC", "auth", "auth.issuer", "trusts_issuer", "coreos/go-oidc", "openidconnect", "openid-client", "oidc"), + d("pyjwt", "PyJWT", "auth", "auth.issuer", "trusts_issuer", "pyjwt", "import jwt"), + d("authlib", "Authlib", "auth", "auth.provider", "uses_identity_provider", "authlib"), + d("django-auth", "Django auth", "auth", "auth.provider", "authenticates_with", "django.contrib.auth"), + d("fastapi-security", "FastAPI security", "auth", "auth.provider", "authenticates_with", "fastapi.security"), + d("spring-security", "Spring Security OAuth/OIDC", "auth", "auth.provider", "uses_identity_provider", "spring-boot-starter-oauth2-client", "spring-security-oauth2", "@enablewebsecurity"), + d("keycloak", "Keycloak", "auth", "auth.provider", "uses_identity_provider", "keycloak"), + d("rust-oauth2", "Rust oauth2", "auth", "auth.provider", "uses_identity_provider", "oauth2 ="), + + // Background jobs / schedulers. 
+ d("bullmq", "BullMQ", "jobs", "job.queue", "enqueues", "bullmq"), + d("agenda", "Agenda", "jobs", "job.schedule", "runs_on_schedule", "agenda"), + d("node-cron", "node-cron", "jobs", "job.schedule", "runs_on_schedule", "node-cron", "cron.schedule"), + d("robfig-cron", "robfig/cron", "jobs", "job.schedule", "runs_on_schedule", "github.com/robfig/cron"), + d("asynq", "asynq", "jobs", "job.queue", "consumes", "github.com/hibiken/asynq"), + d("machinery", "machinery", "jobs", "job.queue", "consumes", "github.com/RichardKnop/machinery"), + d("celery", "Celery", "jobs", "job.worker", "handles_job", "celery", "@shared_task"), + d("rq", "RQ", "jobs", "job.queue", "consumes", "from rq", "import rq"), + d("apscheduler", "APScheduler", "jobs", "job.schedule", "runs_on_schedule", "apscheduler"), + d("spring-scheduling", "Spring Scheduling", "jobs", "job.schedule", "runs_on_schedule", "@scheduled", "@enablescheduling"), + d("quartz", "Quartz", "jobs", "job.schedule", "runs_on_schedule", "org.quartz"), + d("tokio-cron-scheduler", "tokio-cron-scheduler", "jobs", "job.schedule", "runs_on_schedule", "tokio_cron_scheduler"), + d("apalis", "apalis", "jobs", "job.queue", "consumes", "apalis"), + + // API specs / schema files. + {ID: "openapi", Name: "OpenAPI / Swagger", Category: "api-spec", FactType: "api.spec", Relationship: "documents", PathTokens: []string{"openapi", "swagger"}, Tokens: []string{"openapi:", "\"openapi\"", "swagger:"}, Tags: []string{"api-spec:openapi"}, Attrs: map[string]string{"format": "openapi"}}, + {ID: "asyncapi", Name: "AsyncAPI", Category: "api-spec", FactType: "api.spec", Relationship: "documents", PathTokens: []string{"asyncapi"}, Tokens: []string{"asyncapi:", "\"asyncapi\""}, Tags: []string{"api-spec:asyncapi"}, Attrs: map[string]string{"format": "asyncapi"}}, + {ID: "graphql-schema", Name: "GraphQL schema", Category: "api-spec", FactType: "api.schema", Relationship: "declares", PathTokens: []string{".graphql", ".gql"}, Tokens: []string{"type Query", "schema {"}, Tags: []string{"api-spec:graphql"}, Attrs: map[string]string{"format": "graphql"}}, + {ID: "protobuf", Name: "Protocol Buffers", Category: "api-spec", FactType: "rpc.service", Relationship: "exposes", PathTokens: []string{".proto"}, Tokens: []string{"syntax = \"proto", "service "}, Tags: []string{"api-spec:protobuf", "protocol:grpc"}, Attrs: map[string]string{"format": "protobuf"}}, + {ID: "avro", Name: "Avro", Category: "api-spec", FactType: "api.schema", Relationship: "declares", PathTokens: []string{".avsc", ".avdl"}, Tokens: []string{"\"type\":\"record\"", "\"namespace\""}, Tags: []string{"api-spec:avro"}, Attrs: map[string]string{"format": "avro"}}, + {ID: "json-schema", Name: "JSON Schema", Category: "api-spec", FactType: "api.schema", Relationship: "declares", PathTokens: []string{"schema.json"}, Tokens: []string{"\"$schema\"", "json-schema.org"}, Tags: []string{"api-spec:json-schema"}, Attrs: map[string]string{"format": "json-schema"}}, + + // CI/CD and deployment. 
+ {ID: "github-actions", Name: "GitHub Actions", Category: "deployment", FactType: "deployment.workflow", Relationship: "builds", PathTokens: []string{".github/workflows/"}, Tokens: []string{"runs-on:", "uses: actions/"}, Tags: []string{"deployment:github-actions"}, Attrs: map[string]string{"provider": "github-actions"}}, + {ID: "gitlab-ci", Name: "GitLab CI", Category: "deployment", FactType: "deployment.workflow", Relationship: "builds", PathTokens: []string{".gitlab-ci.yml"}, Tokens: []string{"gitlab-ci", "stages:"}, Tags: []string{"deployment:gitlab-ci"}, Attrs: map[string]string{"provider": "gitlab-ci"}}, + {ID: "circleci", Name: "CircleCI", Category: "deployment", FactType: "deployment.workflow", Relationship: "builds", PathTokens: []string{".circleci/config.yml"}, Tokens: []string{"circleci", "orbs:"}, Tags: []string{"deployment:circleci"}, Attrs: map[string]string{"provider": "circleci"}}, + {ID: "jenkinsfile", Name: "Jenkinsfile", Category: "deployment", FactType: "deployment.workflow", Relationship: "builds", PathTokens: []string{"jenkinsfile"}, Tokens: []string{"pipeline {"}, Tags: []string{"deployment:jenkinsfile"}, Attrs: map[string]string{"provider": "jenkins"}}, + {ID: "buildkite", Name: "Buildkite", Category: "deployment", FactType: "deployment.workflow", Relationship: "builds", PathTokens: []string{".buildkite/"}, Tokens: []string{"buildkite", "plugins:"}, Tags: []string{"deployment:buildkite"}, Attrs: map[string]string{"provider": "buildkite"}}, + {ID: "argo-cd", Name: "Argo CD", Category: "deployment", FactType: "deployment.target", Relationship: "deploys_to", Tokens: []string{"argoproj.io", "kind: Application"}, Tags: []string{"deployment:argo-cd"}, Attrs: map[string]string{"provider": "argo-cd"}}, + {ID: "flux", Name: "Flux", Category: "deployment", FactType: "deployment.target", Relationship: "deploys_to", Tokens: []string{"toolkit.fluxcd.io", "kind: Kustomization"}, Tags: []string{"deployment:flux"}, Attrs: map[string]string{"provider": "flux"}}, + + // Secrets / credentials. + d("aws-secrets-manager", "AWS Secrets Manager", "secrets", "secret.provider", "uses_secret", "secretsmanager", "aws_secretsmanager_secret", "software.amazon.awssdk.services.secretsmanager"), + d("aws-ssm", "AWS SSM Parameter Store", "secrets", "secret.provider", "reads_config", "ssm:GetParameter", "aws_ssm_parameter", "ssm.get_parameter"), + d("gcp-secret-manager", "GCP Secret Manager", "secrets", "secret.provider", "uses_secret", "secretmanager.googleapis.com", "google.cloud.secretmanager"), + d("azure-key-vault", "Azure Key Vault", "secrets", "secret.provider", "uses_secret", "azure.keyvault", "azurerm_key_vault", "vault.azure.net"), + d("kubernetes-secrets", "Kubernetes Secrets", "secrets", "secret.provider", "uses_secret", "kind: Secret", "secretKeyRef"), + d("vault", "Vault", "secrets", "secret.provider", "uses_secret", "hashicorp/vault", "vault kv", "vault.hashicorp.com"), + d("doppler", "Doppler", "secrets", "secret.provider", "uses_secret", "doppler", "DOPPLER_TOKEN"), + d("onepassword", "1Password Secrets Automation", "secrets", "secret.provider", "uses_secret", "1password", "op://", "OP_SERVICE_ACCOUNT_TOKEN"), + + // Monorepo / package boundaries. 
+ d("nx", "Nx", "workspace", "workspace.package", "contains", "nx.json", "@nrwl/", "@nx/"), + d("turborepo", "Turborepo", "workspace", "workspace.package", "builds", "turbo.json", "turbo run"), + d("pnpm-workspaces", "pnpm workspaces", "workspace", "workspace.package", "contains", "pnpm-workspace.yaml"), + d("yarn-workspaces", "Yarn workspaces", "workspace", "workspace.package", "contains", "\"workspaces\""), + d("bazel", "Bazel", "workspace", "module.boundary", "builds", "WORKSPACE", "BUILD.bazel", "bazel_dep("), + d("gradle-multiproject", "Gradle multi-project", "workspace", "module.boundary", "contains", "settings.gradle", "include("), + d("maven-modules", "Maven modules", "workspace", "module.boundary", "contains", "", ""), + d("cargo-workspace", "Cargo workspace", "workspace", "workspace.package", "contains", "[workspace]"), + d("go-workspace", "Go workspaces", "workspace", "workspace.package", "contains", "go.work", "use ("), + + // AI / ML operations and LLMs. + d("pinecone", "Pinecone", "ai", "ai.vector_index", "queries_index", "pinecone"), + d("milvus", "Milvus", "ai", "ai.vector_index", "queries_index", "milvus"), + d("qdrant", "Qdrant", "ai", "ai.vector_index", "queries_index", "qdrant"), + d("chroma", "Chroma", "ai", "ai.vector_index", "queries_index", "chromadb", "chroma_client"), + d("weaviate", "Weaviate", "ai", "ai.vector_index", "queries_index", "weaviate"), + d("huggingface", "Hugging Face", "ai", "ai.model_id", "loads_model", "huggingface_hub", "transformers"), + d("mlflow", "MLflow", "ai", "ai.experiment_tracker", "tracks_metrics_to", "mlflow"), + d("wandb", "Weights & Biases", "ai", "ai.experiment_tracker", "tracks_metrics_to", "wandb"), + d("openai", "OpenAI SDK", "ai", "ai.llm_endpoint", "calls_llm", "openai", "@openai/"), + d("anthropic", "Anthropic SDK", "ai", "ai.llm_endpoint", "calls_llm", "anthropic", "@anthropic-ai/"), + d("langchain", "LangChain", "ai", "ai.llm_endpoint", "calls_llm", "langchain"), + d("llamaindex", "LlamaIndex", "ai", "ai.llm_endpoint", "calls_llm", "llama_index", "llamaindex"), + + // Embedded systems and IoT messaging. + d("mqtt", "MQTT", "iot", "iot.mqtt_topic", "publishes_to_device", "mqtt", "paho.mqtt", "mosquitto"), + d("coap", "CoAP", "iot", "iot.broker", "publishes_to_device", "coap"), + d("i2c", "I2C", "iot", "hardware.bus_address", "communicates_via_i2c", "i2c_init", "i2c_open", "i2c_transfer", "i2c_read", "i2c_write", "SMBus"), + d("spi", "SPI", "iot", "hardware.bus_address", "communicates_via_i2c", "spi_init", "spi_open", "spi_transfer", "spi_mode", "SPIDevice", "spidev"), + d("uart", "UART", "iot", "hardware.pin", "communicates_via_i2c", "uart_init", "uart_open", "uart_write", "uart_read", "uart_puts", "serialport"), + d("can-bus", "CAN Bus", "iot", "hardware.bus_address", "communicates_via_i2c", "canbus", "socketcan"), + + // Kernel, systems, and local IPC. 
+ d("unix-socket", "Unix Domain Sockets", "ipc", "ipc.socket_path", "connects_to_socket", "unix://", "AF_UNIX"), + d("dbus", "D-Bus", "ipc", "ipc.dbus_interface", "exposes_dbus_service", "dbus", "org.freedesktop"), + d("named-pipes", "Named Pipes", "ipc", "ipc.socket_path", "connects_to_socket", `\\.\pipe\`, "mkfifo"), + d("grpc-uds", "gRPC over UDS", "ipc", "ipc.socket_path", "connects_to_socket", "unix:", "grpc.WithContextDialer"), + d("sysfs", "sysfs / procfs", "kernel", "kernel.device_node", "reads_device", "/sys/", "/proc/"), + d("ebpf", "eBPF", "kernel", "kernel.device_node", "reads_device", "kprobe", "uprobe", "tracepoint", "libbpf", "bcc"), + + // Data engineering and orchestration. + d("airflow", "Apache Airflow", "data", "data.pipeline_id", "depends_on_task", "airflow", "DAG("), + d("prefect", "Prefect", "data", "data.pipeline_id", "depends_on_task", "prefect", "@flow"), + d("dagster", "Dagster", "data", "data.pipeline_id", "depends_on_task", "dagster", "@asset"), + d("spark", "Apache Spark", "data", "data.dataset_uri", "reads_dataset", "pyspark", "spark.sql", "org.apache.spark"), + d("ray", "Ray", "data", "data.pipeline_id", "depends_on_task", "ray.init", "import ray"), + + // Web3 / blockchain. + d("ethers-js", "ethers.js", "web3", "web3.rpc_endpoint", "connects_to_chain", "ethers"), + d("web3-js", "web3.js", "web3", "web3.rpc_endpoint", "connects_to_chain", "web3.js", "new Web3"), + d("web3-py", "web3.py", "web3", "web3.rpc_endpoint", "connects_to_chain", "from web3 import Web3", "web3.py"), + d("foundry", "Foundry", "web3", "web3.chain_id", "connects_to_chain", "foundry.toml", "forge-std"), + d("hardhat", "Hardhat", "web3", "web3.chain_id", "connects_to_chain", "hardhat.config", "hardhat"), + + // Desktop / mobile OS integration. + {ID: "uri-schemes", Name: "Custom URI schemes", Category: "os-integration", FactType: "os.uri_scheme", Relationship: "handles_deep_link", PathTokens: []string{"info.plist", "androidmanifest.xml", "electron"}, Tokens: []string{"CFBundleURLSchemes", "android.intent.action.VIEW"}, Tags: []string{"os-integration:uri-schemes"}, Attrs: map[string]string{"platform": "desktop-mobile"}}, + {ID: "android-intents", Name: "Android Intents", Category: "os-integration", FactType: "os.intent", Relationship: "broadcasts_intent", PathTokens: []string{"androidmanifest.xml"}, Tokens: []string{" 0 { + tc.Signals = []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}} + } + enrichertest.Run(t, tc) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/iac/iac.go b/internal/watch/enrich/enrichers/iac/iac.go new file mode 100644 index 0000000..4b362f8 --- /dev/null +++ b/internal/watch/enrich/enrichers/iac/iac.go @@ -0,0 +1,36 @@ +package iac + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("iac.kubernetes", "Kubernetes YAML", "yaml", []string{"kind: Deployment", "apiVersion: apps/v1", "kind: StatefulSet", "kind: DaemonSet"}, nil, "runtime.service", "deploys"), + spec("iac.helm", "Helm values", "yaml", []string{"helm.sh/chart", "{{ .Values", "{{ .Release"}, []string{"chart.yaml", "values.yaml"}, "runtime.service", "deploys"), + 
spec("iac.terraform", "Terraform", "terraform", []string{"resource \"", "module \""}, []string{".tf"}, "cloud.resource", "provisions"), + spec("iac.pulumi", "Pulumi", "typescript", []string{"new aws.", "pulumi."}, []string{"pulumi.yaml", "pulumi.yml"}, "cloud.resource", "provisions"), + spec("iac.serverless", "Serverless Framework", "yaml", nil, []string{"serverless.yml", "serverless.yaml"}, "runtime.service", "deploys"), + spec("iac.aws_cdk", "AWS CDK", "typescript", []string{"aws-cdk-lib", "new cdk.Stack"}, []string{"cdk.json"}, "cloud.resource", "provisions"), + spec("iac.github_actions_deploy", "GitHub Actions deployment configs", "yaml", nil, []string{".github/workflows/"}, "deployment.workflow", "deploys"), + } +} + +func spec(id, name, language string, tokens, pathTokens []string, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "iac", + Languages: []string{language}, + Mode: enrich.ActivationAlways, + FactType: factType, + Relationship: relationship, + SourceTokens: tokens, + PathTokens: pathTokens, + Tags: []string{"iac:" + id}, + Attributes: map[string]string{"language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/iac/iac_test.go b/internal/watch/enrich/enrichers/iac/iac_test.go new file mode 100644 index 0000000..f3411eb --- /dev/null +++ b/internal/watch/enrich/enrichers/iac/iac_test.go @@ -0,0 +1,46 @@ +package iac + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestIaCEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + relPath := "deploy/service.yaml" + language := "yaml" + if len(spec.Languages) > 0 { + language = spec.Languages[0] + } + if len(spec.PathTokens) > 0 { + relPath = spec.PathTokens[0] + } + var source []byte + if len(spec.SourceTokens) > 0 { + source = []byte(spec.SourceTokens[0]) + } else { + source = []byte("kind: Service\n") + } + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: relPath, + Language: language, + Source: source, + }, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:iac", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/inventory/dependencies.go b/internal/watch/enrich/enrichers/inventory/dependencies.go new file mode 100644 index 0000000..bb2d0b6 --- /dev/null +++ b/internal/watch/enrich/enrichers/inventory/dependencies.go @@ -0,0 +1,290 @@ +package inventory + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "path" + "regexp" + "sort" + "strings" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +var goRequireLineRE = regexp.MustCompile(`^\s*([A-Za-z0-9_./~-]+)\s+v[0-9]`) + +type Enricher = enrich.Enricher +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type SourceSpan = enrich.SourceSpan +type SubjectRef = enrich.SubjectRef + +const ActivationAlways = enrich.ActivationAlways + +var fileSubject = enrich.FileSubject + +func DependencyInventory() Enricher { + return enrich.NewEnricher( + Metadata{ID: "dependency.inventory", Name: "Dependency and import inventory", Mode: ActivationAlways}, + 
func(input FileInput) bool { + base := path.Base(input.RelPath) + return dependencyManifest(base) || input.Parsed != nil + }, + dependencyInventoryRun, + ) +} + +func dependencyInventoryRun(ctx context.Context, input FileInput, emit FactEmitter) error { + base := path.Base(input.RelPath) + switch base { + case "go.mod": + return emitGoModFacts(input, emit) + case "package.json": + return emitPackageJSONFacts(input, emit) + case "requirements.txt": + return emitLineDependencyFacts(input, emit, "python", requirementName) + case "pyproject.toml", "poetry.lock": + return emitLineDependencyFacts(input, emit, "python", tomlDependencyName) + case "Cargo.toml": + return emitLineDependencyFacts(input, emit, "cargo", cargoDependencyName) + case "pom.xml": + return emitPomFacts(input, emit) + case "build.gradle", "build.gradle.kts": + return emitLineDependencyFacts(input, emit, "gradle", gradleDependencyName) + case "CMakeLists.txt", "conanfile.txt", "conanfile.py", "vcpkg.json": + return emitLineDependencyFacts(input, emit, "cpp", cppDependencyName) + } + if input.Parsed == nil { + return nil + } + for _, ref := range input.Parsed.Refs { + if ref.Kind != "import" || strings.TrimSpace(ref.TargetPath) == "" { + continue + } + line := ref.Line + if line <= 0 { + line = 1 + } + if err := emit.EmitFact(Fact{ + Type: "dependency.import", + StableKey: fmt.Sprintf("dependency.import:%s:%s:%d", input.RelPath, ref.TargetPath, line), + Subject: fileSubject(input.RelPath), + Object: SubjectRef{Kind: "dependency.module", StableKey: "dependency.module:" + ref.TargetPath, Name: ref.TargetPath}, + Relationship: "imports", + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, StartColumn: ref.Column}, + Confidence: 1, + Name: ref.TargetPath, + Tags: []string{"dependency:import"}, + Attributes: map[string]string{"module": ref.TargetPath, "name": ref.Name}, + VisibilityHints: map[string]float64{ + "dependency": 1, + }, + }); err != nil { + return err + } + } + return nil +} + +func dependencyManifest(base string) bool { + switch base { + case "go.mod", "package.json", "requirements.txt", "pyproject.toml", "poetry.lock", "Cargo.toml", "pom.xml", "build.gradle", "build.gradle.kts", "CMakeLists.txt", "conanfile.txt", "conanfile.py", "vcpkg.json": + return true + default: + return false + } +} + +func emitGoModFacts(input FileInput, emit FactEmitter) error { + scanner := bufio.NewScanner(strings.NewReader(string(input.Source))) + line := 0 + for scanner.Scan() { + line++ + match := goRequireLineRE.FindStringSubmatch(scanner.Text()) + if len(match) != 2 { + continue + } + if err := emit.EmitFact(dependencyFact(input.RelPath, line, match[1], "go")); err != nil { + return err + } + } + return scanner.Err() +} + +func emitPackageJSONFacts(input FileInput, emit FactEmitter) error { + var pkg struct { + Dependencies map[string]string `json:"dependencies"` + DevDependencies map[string]string `json:"devDependencies"` + PeerDependencies map[string]string `json:"peerDependencies"` + OptionalDependencies map[string]string `json:"optionalDependencies"` + } + if err := json.Unmarshal(input.Source, &pkg); err != nil { + return nil + } + names := map[string]string{} + add := func(section string, values map[string]string) { + for name := range values { + names[name] = section + } + } + add("dependencies", pkg.Dependencies) + add("devDependencies", pkg.DevDependencies) + add("peerDependencies", pkg.PeerDependencies) + add("optionalDependencies", pkg.OptionalDependencies) + var sorted []string + for name := range names { + sorted 
= append(sorted, name)
+ }
+ sort.Strings(sorted)
+ for _, name := range sorted {
+ fact := dependencyFact(input.RelPath, 1, name, "npm")
+ fact.Attributes["section"] = names[name]
+ if err := emit.EmitFact(fact); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func emitLineDependencyFacts(input FileInput, emit FactEmitter, ecosystem string, parse func(string) string) error {
+ scanner := bufio.NewScanner(strings.NewReader(string(input.Source)))
+ line := 0
+ seen := map[string]struct{}{}
+ for scanner.Scan() {
+ line++
+ name := parse(scanner.Text())
+ if name == "" {
+ continue
+ }
+ if _, ok := seen[name]; ok {
+ continue
+ }
+ seen[name] = struct{}{}
+ if err := emit.EmitFact(dependencyFact(input.RelPath, line, name, ecosystem)); err != nil {
+ return err
+ }
+ }
+ return scanner.Err()
+}
+
+func emitPomFacts(input FileInput, emit FactEmitter) error {
+ source := string(input.Source)
+ re := regexp.MustCompile(`(?s)<dependency>.*?<groupId>\s*([^<\s]+)\s*</groupId>.*?<artifactId>\s*([^<\s]+)\s*</artifactId>.*?</dependency>`)
+ for _, indexes := range re.FindAllStringSubmatchIndex(source, -1) {
+ match := enrich.Submatches(source, indexes)
+ if len(match) != 3 {
+ continue
+ }
+ name := match[1] + ":" + match[2]
+ line := enrich.LineForOffset(source, indexes[0])
+ if err := emit.EmitFact(dependencyFact(input.RelPath, line, name, "maven")); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func requirementName(line string) string {
+ line = strings.TrimSpace(strings.Split(line, "#")[0])
+ if line == "" || strings.HasPrefix(line, "-") {
+ return ""
+ }
+ return dependencyPrefix(line, "=", "<", ">", "~", "!", "[", ";")
+}
+
+func tomlDependencyName(line string) string {
+ line = strings.TrimSpace(strings.Split(line, "#")[0])
+ if line == "" || strings.HasPrefix(line, "[") {
+ return ""
+ }
+ if strings.HasPrefix(line, "\"") || strings.HasPrefix(line, "'") {
+ trimmed := strings.Trim(line, " ,")
+ trimmed = strings.Trim(trimmed, `"'`)
+ if trimmed != "" && !strings.Contains(trimmed, "=") {
+ return requirementName(trimmed)
+ }
+ }
+ if idx := strings.Index(line, "="); idx > 0 {
+ name := strings.TrimSpace(line[:idx])
+ return strings.Trim(name, `"'`)
+ }
+ return ""
+}
+
+func cargoDependencyName(line string) string {
+ name := tomlDependencyName(line)
+ switch name {
+ case "package", "dependencies", "dev-dependencies", "build-dependencies", "workspace":
+ return ""
+ default:
+ return name
+ }
+}
+
+func gradleDependencyName(line string) string {
+ line = strings.TrimSpace(line)
+ for _, quote := range []string{"\"", "'"} {
+ start := strings.Index(line, quote)
+ if start < 0 {
+ continue
+ }
+ rest := line[start+1:]
+ before, _, ok := strings.Cut(rest, quote)
+ if !ok {
+ continue
+ }
+ value := before
+ if strings.Count(value, ":") >= 1 {
+ return value
+ }
+ }
+ return ""
+}
+
+func cppDependencyName(line string) string {
+ line = strings.TrimSpace(strings.Split(line, "#")[0])
+ if line == "" {
+ return ""
+ }
+ for _, prefix := range []string{"find_package(", "target_link_libraries(", "requires =", "self.requires(", "\"name\":"} {
+ if _, after, ok := strings.Cut(line, prefix); ok {
+ value := strings.TrimSpace(after)
+ value = strings.Trim(value, ` "'),[]`)
+ return dependencyPrefix(value, " ", "/", ")", ",", "\"")
+ }
+ }
+ return ""
+}
+
+func dependencyPrefix(value string, stops ...string) string {
+ value = strings.TrimSpace(value)
+ end := len(value)
+ for _, stop := range stops {
+ if idx := strings.Index(value, stop); idx >= 0 && idx < end {
+ end = idx
+ }
+ }
+ return strings.TrimSpace(value[:end])
+}
+
+func dependencyFact(relPath
string, line int, name, ecosystem string) Fact { + return Fact{ + Type: "dependency.module", + StableKey: fmt.Sprintf("dependency.module:%s:%s", relPath, name), + Subject: fileSubject(relPath), + Object: SubjectRef{Kind: "dependency.module", StableKey: "dependency.module:" + name, Name: name}, + Relationship: "declares_dependency", + Source: SourceSpan{FilePath: relPath, StartLine: line, EndLine: line}, + Confidence: 1, + Name: name, + Tags: []string{"dependency:module"}, + Attributes: map[string]string{"module": name, "ecosystem": ecosystem}, + VisibilityHints: map[string]float64{ + "dependency": 1, + }, + } +} diff --git a/internal/watch/enrich/enrichers/iot/iot.go b/internal/watch/enrich/enrichers/iot/iot.go new file mode 100644 index 0000000..4a2b142 --- /dev/null +++ b/internal/watch/enrich/enrichers/iot/iot.go @@ -0,0 +1,49 @@ +package iot + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.mqtt", "TypeScript MQTT", "typescript", "mqtt", "mqtt.connect", "iot.mqtt_topic", "publishes_to_device"), + spec("python.mqtt", "Python MQTT", "python", "paho-mqtt", "paho.mqtt", "iot.mqtt_topic", "publishes_to_device"), + spec("go.mqtt", "Go MQTT", "go", "github.com/eclipse/paho.mqtt.golang", "mqtt.NewClient", "iot.mqtt_topic", "publishes_to_device"), + spec("cpp.mqtt", "C++ MQTT", "cpp", "paho.mqtt.cpp", "mqtt::client", "iot.mqtt_topic", "publishes_to_device"), + spec("ts.coap", "TypeScript CoAP", "typescript", "coap", "coap.request", "iot.broker", "publishes_to_device"), + spec("python.coap", "Python CoAP", "python", "aiocoap", "aiocoap", "iot.broker", "publishes_to_device"), + spec("go.coap", "Go CoAP", "go", "github.com/plgd-dev/go-coap", "coap", "iot.broker", "publishes_to_device"), + spec("cpp.coap", "C++ CoAP", "cpp", "libcoap", "coap_", "iot.broker", "publishes_to_device"), + spec("python.i2c", "Python I2C", "python", "smbus", "smbus", "hardware.bus_address", "communicates_via_i2c"), + spec("rust.i2c", "Rust I2C", "rust", "embedded-hal", "i2c", "hardware.bus_address", "communicates_via_i2c"), + spec("cpp.i2c", "C++ I2C", "cpp", "i2c", "ioctl", "hardware.bus_address", "communicates_via_i2c"), + spec("python.spi", "Python SPI", "python", "spidev", "spidev", "hardware.bus_address", "communicates_via_i2c"), + spec("rust.spi", "Rust SPI", "rust", "embedded-hal", "spi", "hardware.bus_address", "communicates_via_i2c"), + spec("cpp.spi", "C++ SPI", "cpp", "spi", "SPI_IOC", "hardware.bus_address", "communicates_via_i2c"), + spec("ts.uart", "TypeScript UART", "typescript", "serialport", "SerialPort", "hardware.pin", "communicates_via_i2c"), + spec("python.uart", "Python UART", "python", "pyserial", "serial.Serial", "hardware.pin", "communicates_via_i2c"), + spec("cpp.uart", "C++ UART", "cpp", "uart", "termios", "hardware.pin", "communicates_via_i2c"), + spec("python.can_bus", "Python CAN Bus", "python", "python-can", "can.Bus", "hardware.bus_address", "communicates_via_i2c"), + spec("rust.can_bus", "Rust CAN Bus", "rust", "socketcan", "CANSocket", "hardware.bus_address", "communicates_via_i2c"), + spec("cpp.can_bus", "C++ CAN Bus", "cpp", "socketcan", "PF_CAN", "hardware.bus_address", "communicates_via_i2c"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "iot", + 
Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"iot:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/iot/iot_test.go b/internal/watch/enrich/enrichers/iot/iot_test.go new file mode 100644 index 0000000..53712e2 --- /dev/null +++ b/internal/watch/enrich/enrichers/iot/iot_test.go @@ -0,0 +1,33 @@ +package iot + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestIoTEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/device", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:iot", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/ipc/ipc.go b/internal/watch/enrich/enrichers/ipc/ipc.go new file mode 100644 index 0000000..61f24b6 --- /dev/null +++ b/internal/watch/enrich/enrichers/ipc/ipc.go @@ -0,0 +1,57 @@ +package ipc + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("go.unix_socket", "Go Unix Domain Sockets", "go", "net", "AF_UNIX", "ipc.socket_path", "connects_to_socket"), + spec("python.unix_socket", "Python Unix Domain Sockets", "python", "socket", "AF_UNIX", "ipc.socket_path", "connects_to_socket"), + spec("rust.unix_socket", "Rust Unix Domain Sockets", "rust", "tokio", "UnixStream", "ipc.socket_path", "connects_to_socket"), + spec("cpp.unix_socket", "C++ Unix Domain Sockets", "cpp", "sys/socket.h", "AF_UNIX", "ipc.socket_path", "connects_to_socket"), + spec("go.dbus", "Go D-Bus", "go", "github.com/godbus/dbus", "org.freedesktop", "ipc.dbus_interface", "exposes_dbus_service"), + spec("python.dbus", "Python D-Bus", "python", "dbus-python", "org.freedesktop", "ipc.dbus_interface", "exposes_dbus_service"), + spec("rust.dbus", "Rust D-Bus", "rust", "zbus", "org.freedesktop", "ipc.dbus_interface", "exposes_dbus_service"), + spec("cpp.dbus", "C++ D-Bus", "cpp", "sdbus-c++", "org.freedesktop", "ipc.dbus_interface", "exposes_dbus_service"), + spec("ts.named_pipes", "TypeScript Named Pipes", "typescript", "net", "\\\\.\\pipe\\", "ipc.socket_path", "connects_to_socket"), + spec("go.named_pipes", "Go Named Pipes", "go", "winio", "\\\\.\\pipe\\", "ipc.socket_path", "connects_to_socket"), + spec("python.named_pipes", "Python Named Pipes", "python", "pywin32", "\\\\.\\pipe\\", "ipc.socket_path", "connects_to_socket"), + spec("cpp.named_pipes", "C++ Named Pipes", "cpp", "windows.h", "CreateNamedPipe", "ipc.socket_path", 
"connects_to_socket"), + spec("go.grpc_uds", "Go gRPC over UDS", "go", "google.golang.org/grpc", "grpc.WithContextDialer", "ipc.socket_path", "connects_to_socket"), + spec("python.grpc_uds", "Python gRPC over UDS", "python", "grpcio", "unix:", "ipc.socket_path", "connects_to_socket"), + spec("ts.grpc_uds", "TypeScript gRPC over UDS", "typescript", "@grpc/grpc-js", "unix:", "ipc.socket_path", "connects_to_socket"), + spec("cpp.grpc_uds", "C++ gRPC over UDS", "cpp", "grpc", "unix:", "ipc.socket_path", "connects_to_socket"), + spec("go.dev_node", "Go /dev device nodes", "go", "os", "/dev/", "kernel.device_node", "reads_device"), + spec("python.dev_node", "Python /dev device nodes", "python", "os", "/dev/", "kernel.device_node", "reads_device"), + spec("rust.dev_node", "Rust /dev device nodes", "rust", "std", "/dev/", "kernel.device_node", "reads_device"), + spec("cpp.dev_node", "C++ /dev device nodes", "cpp", "fcntl.h", "/dev/", "kernel.device_node", "reads_device"), + spec("go.sysfs_procfs", "Go sysfs / procfs", "go", "os", "/proc/", "kernel.device_node", "reads_device"), + spec("python.sysfs_procfs", "Python sysfs / procfs", "python", "os", "/sys/", "kernel.device_node", "reads_device"), + spec("rust.sysfs_procfs", "Rust sysfs / procfs", "rust", "std", "/proc/", "kernel.device_node", "reads_device"), + spec("cpp.sysfs_procfs", "C++ sysfs / procfs", "cpp", "fstream", "/sys/", "kernel.device_node", "reads_device"), + spec("go.ebpf", "Go eBPF", "go", "github.com/cilium/ebpf", "kprobe", "kernel.device_node", "reads_device"), + spec("python.ebpf", "Python eBPF", "python", "bcc", "BPF(", "kernel.device_node", "reads_device"), + spec("rust.ebpf", "Rust eBPF", "rust", "aya", "tracepoint", "kernel.device_node", "reads_device"), + spec("cpp.ebpf", "C++ eBPF", "cpp", "libbpf", "uprobe", "kernel.device_node", "reads_device"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "ipc", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"ipc:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/ipc/ipc_test.go b/internal/watch/enrich/enrichers/ipc/ipc_test.go new file mode 100644 index 0000000..e986ed5 --- /dev/null +++ b/internal/watch/enrich/enrichers/ipc/ipc_test.go @@ -0,0 +1,33 @@ +package ipc + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestIPCEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/ipc", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:ipc", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = 
enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/jobs/jobs.go b/internal/watch/enrich/enrichers/jobs/jobs.go new file mode 100644 index 0000000..a4841d7 --- /dev/null +++ b/internal/watch/enrich/enrichers/jobs/jobs.go @@ -0,0 +1,44 @@ +package jobs + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.bullmq", "TypeScript BullMQ", "typescript", "bullmq", "new Queue", "job.queue", "enqueues"), + spec("ts.agenda", "TypeScript Agenda", "typescript", "agenda", "agenda.define", "job.handler", "handles_job"), + spec("ts.node_cron", "TypeScript node-cron", "typescript", "node-cron", "cron.schedule", "job.schedule", "runs_on_schedule"), + spec("go.robfig_cron", "Go robfig/cron", "go", "github.com/robfig/cron", "cron.New", "job.schedule", "runs_on_schedule"), + spec("go.asynq", "Go asynq", "go", "github.com/hibiken/asynq", "asynq.NewServer", "job.queue", "consumes"), + spec("go.machinery", "Go machinery", "go", "github.com/RichardKnop/machinery", "machinery", "job.queue", "consumes"), + spec("python.celery", "Python Celery", "python", "celery", "@shared_task", "job.handler", "handles_job"), + spec("python.rq", "Python RQ", "python", "rq", "Queue(", "job.queue", "consumes"), + spec("python.apscheduler", "Python APScheduler", "python", "apscheduler", "add_job", "job.schedule", "runs_on_schedule"), + spec("java.spring_scheduling", "Java Spring Scheduling", "java", "spring-context", "@Scheduled", "job.schedule", "runs_on_schedule"), + spec("java.quartz", "Java Quartz", "java", "org.quartz", "JobBuilder", "job.handler", "handles_job"), + spec("rust.tokio_cron_scheduler", "Rust tokio-cron-scheduler", "rust", "tokio-cron-scheduler", "JobScheduler", "job.schedule", "runs_on_schedule"), + spec("rust.apalis", "Rust apalis", "rust", "apalis", "apalis::", "job.queue", "consumes"), + spec("cpp.custom_scheduler", "C++ custom schedulers", "cpp", "cron", "schedule_every", "job.schedule", "runs_on_schedule"), + spec("cpp.queue_consumer", "C++ queue consumers", "cpp", "queue", "consume_queue", "job.queue", "consumes"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "jobs", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"jobs:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/jobs/jobs_test.go b/internal/watch/enrich/enrichers/jobs/jobs_test.go new file mode 100644 index 0000000..baf447f --- /dev/null +++ b/internal/watch/enrich/enrichers/jobs/jobs_test.go @@ -0,0 +1,33 @@ +package jobs + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestJobEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/jobs", + Language: spec.Languages[0], + Source: 
[]byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:jobs", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/messaging/messaging.go b/internal/watch/enrich/enrichers/messaging/messaging.go new file mode 100644 index 0000000..651756d --- /dev/null +++ b/internal/watch/enrich/enrichers/messaging/messaging.go @@ -0,0 +1,57 @@ +package messaging + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.kafkajs", "TypeScript KafkaJS", "typescript", "kafkajs", "Kafka(", "messaging.topic", "publishes"), + spec("ts.bullmq_messaging", "TypeScript BullMQ messaging", "typescript", "bullmq", "new Queue", "messaging.queue", "publishes"), + spec("ts.aws_sqs", "TypeScript AWS SQS SDK", "typescript", "@aws-sdk/client-sqs", "SQSClient", "messaging.queue", "publishes"), + spec("ts.amqplib", "TypeScript amqplib", "typescript", "amqplib", "amqplib", "messaging.queue", "publishes"), + spec("ts.nats", "TypeScript NATS", "typescript", "nats", "connect(", "messaging.topic", "subscribes_to"), + spec("go.kafka_go", "Go kafka-go", "go", "github.com/segmentio/kafka-go", "kafka.Writer", "messaging.topic", "publishes"), + spec("go.sarama", "Go Sarama", "go", "github.com/IBM/sarama", "sarama.New", "messaging.topic", "publishes"), + spec("go.nats", "Go NATS", "go", "github.com/nats-io/nats.go", "nats.Connect", "messaging.topic", "subscribes_to"), + spec("go.rabbitmq", "Go RabbitMQ", "go", "github.com/rabbitmq/amqp091-go", "amqp.Dial", "messaging.queue", "consumes"), + spec("go.aws_sqs", "Go AWS SQS SDK", "go", "github.com/aws/aws-sdk-go-v2/service/sqs", "sqs.NewFromConfig", "messaging.queue", "publishes"), + spec("python.celery_messaging", "Python Celery messaging", "python", "celery", "Celery(", "messaging.queue", "consumes"), + spec("python.kafka_python", "Python kafka-python", "python", "kafka-python", "KafkaProducer", "messaging.topic", "publishes"), + spec("python.confluent_kafka", "Python confluent-kafka", "python", "confluent-kafka", "confluent_kafka", "messaging.topic", "publishes"), + spec("python.pika", "Python pika", "python", "pika", "pika.", "messaging.queue", "consumes"), + spec("python.boto3_sqs", "Python boto3 SQS", "python", "boto3", "sqs", "messaging.queue", "publishes"), + spec("java.spring_kafka", "Java Spring Kafka", "java", "spring-kafka", "KafkaTemplate", "messaging.topic", "publishes"), + spec("java.kafka_clients", "Java Kafka clients", "java", "org.apache.kafka", "KafkaProducer", "messaging.topic", "publishes"), + spec("java.spring_amqp", "Java Spring AMQP", "java", "spring-amqp", "RabbitTemplate", "messaging.queue", "publishes"), + spec("java.jms", "Java JMS", "java", "jakarta.jms", "JMSContext", "messaging.queue", "publishes"), + spec("java.aws_sqs", "Java AWS SQS SDK", "java", "software.amazon.awssdk.services.sqs", "SqsClient", "messaging.queue", "publishes"), + spec("rust.rdkafka", "Rust rdkafka", "rust", "rdkafka", "rdkafka::", "messaging.topic", "publishes"), + spec("rust.lapin", 
"Rust lapin", "rust", "lapin", "lapin::", "messaging.queue", "consumes"), + spec("rust.async_nats", "Rust async-nats", "rust", "async-nats", "async_nats::", "messaging.topic", "subscribes_to"), + spec("rust.aws_sqs", "Rust AWS SQS SDK", "rust", "aws-sdk-sqs", "aws_sdk_sqs::", "messaging.queue", "publishes"), + spec("cpp.librdkafka", "C++ librdkafka", "cpp", "librdkafka", "RdKafka::", "messaging.topic", "publishes"), + spec("cpp.rabbitmq_c", "C++ rabbitmq-c", "cpp", "rabbitmq-c", "amqp_login", "messaging.queue", "consumes"), + spec("cpp.nats", "C++ nats.cpp", "cpp", "nats.cpp", "natsConnection", "messaging.topic", "subscribes_to"), + spec("cpp.aws_sqs", "C++ AWS SQS SDK", "cpp", "aws-sdk-cpp", "Aws::SQS", "messaging.queue", "publishes"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "messaging", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"messaging:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/messaging/messaging_test.go b/internal/watch/enrich/enrichers/messaging/messaging_test.go new file mode 100644 index 0000000..afb4b84 --- /dev/null +++ b/internal/watch/enrich/enrichers/messaging/messaging_test.go @@ -0,0 +1,33 @@ +package messaging + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestMessagingEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/messaging", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:messaging", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/observability/observability.go b/internal/watch/enrich/enrichers/observability/observability.go new file mode 100644 index 0000000..dd240db --- /dev/null +++ b/internal/watch/enrich/enrichers/observability/observability.go @@ -0,0 +1,53 @@ +package observability + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.opentelemetry", "TypeScript OpenTelemetry", "typescript", "@opentelemetry/api", "@opentelemetry/", "telemetry.span", "creates_span"), + spec("ts.prometheus_client", "TypeScript Prometheus client", "typescript", "prom-client", "prom-client", "telemetry.metric", "emits_metric"), + spec("ts.sentry", "TypeScript Sentry", "typescript", "@sentry/node", "@sentry/", "telemetry.project", "reports_to"), + spec("ts.datadog", 
"TypeScript Datadog", "typescript", "dd-trace", "dd-trace", "telemetry.project", "reports_to"), + spec("go.opentelemetry", "Go OpenTelemetry", "go", "go.opentelemetry.io/otel", "go.opentelemetry.io/otel", "telemetry.span", "creates_span"), + spec("go.prometheus", "Go Prometheus", "go", "github.com/prometheus/client_golang", "prometheus.New", "telemetry.metric", "emits_metric"), + spec("go.sentry", "Go Sentry", "go", "github.com/getsentry/sentry-go", "sentry.Init", "telemetry.project", "reports_to"), + spec("go.datadog", "Go Datadog", "go", "gopkg.in/DataDog/dd-trace-go", "ddtrace", "telemetry.project", "reports_to"), + spec("python.opentelemetry", "Python OpenTelemetry", "python", "opentelemetry-api", "opentelemetry", "telemetry.span", "creates_span"), + spec("python.prometheus_client", "Python Prometheus client", "python", "prometheus-client", "prometheus_client", "telemetry.metric", "emits_metric"), + spec("python.sentry_sdk", "Python Sentry SDK", "python", "sentry-sdk", "sentry_sdk", "telemetry.project", "reports_to"), + spec("python.datadog_tracing", "Python Datadog tracing", "python", "ddtrace", "ddtrace", "telemetry.project", "reports_to"), + spec("java.opentelemetry", "Java OpenTelemetry", "java", "io.opentelemetry", "io.opentelemetry", "telemetry.span", "creates_span"), + spec("java.micrometer", "Java Micrometer", "java", "io.micrometer", "MeterRegistry", "telemetry.metric", "emits_metric"), + spec("java.prometheus", "Java Prometheus", "java", "micrometer-registry-prometheus", "PrometheusMeterRegistry", "telemetry.metric", "emits_metric"), + spec("java.sentry", "Java Sentry", "java", "io.sentry", "Sentry.init", "telemetry.project", "reports_to"), + spec("java.datadog", "Java Datadog", "java", "com.datadoghq", "datadog", "telemetry.project", "reports_to"), + spec("rust.tracing", "Rust tracing", "rust", "tracing", "tracing::", "telemetry.span", "creates_span"), + spec("rust.opentelemetry", "Rust OpenTelemetry", "rust", "opentelemetry", "opentelemetry::", "telemetry.span", "creates_span"), + spec("rust.metrics", "Rust metrics", "rust", "metrics", "metrics::", "telemetry.metric", "emits_metric"), + spec("rust.sentry", "Rust Sentry", "rust", "sentry", "sentry::", "telemetry.project", "reports_to"), + spec("cpp.opentelemetry", "C++ OpenTelemetry", "cpp", "opentelemetry-cpp", "opentelemetry", "telemetry.span", "creates_span"), + spec("cpp.prometheus", "C++ Prometheus", "cpp", "prometheus-cpp", "prometheus::", "telemetry.metric", "emits_metric"), + spec("cpp.sentry_native", "C++ Sentry Native", "cpp", "sentry-native", "sentry_init", "telemetry.project", "reports_to"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "observability", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"observability:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/observability/observability_test.go b/internal/watch/enrich/enrichers/observability/observability_test.go new file mode 100644 index 0000000..0f209c8 --- /dev/null +++ b/internal/watch/enrich/enrichers/observability/observability_test.go @@ -0,0 +1,33 @@ +package observability + 
+import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestObservabilityEnrichers(t *testing.T) { + for _, spec := range Specs() { + source := spec.SourceTokens[0] + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: enrichersByID()[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/service", + Language: spec.Languages[0], + Source: []byte(source), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:observability", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/orm/cpp/orm.go b/internal/watch/enrich/enrichers/orm/cpp/orm.go new file mode 100644 index 0000000..e06c4fc --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/cpp/orm.go @@ -0,0 +1,34 @@ +package cpp + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("cpp.raw_sql", "C++ raw SQL", "sqlite3", "sqlite3_prepare", "raw-sql"), + spec("cpp.libpqxx", "C++ libpqxx", "libpqxx", "pqxx::", "libpqxx"), + spec("cpp.soci", "C++ SOCI", "soci", "soci::session", "soci"), + spec("cpp.sqlite_orm", "C++ sqlite_orm", "sqlite_orm", "make_storage", "sqlite-orm"), + spec("cpp.odb", "C++ ODB", "odb", "odb::database", "odb"), + } +} + +func spec(id, name, dependency, token, orm string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "orm", + Languages: []string{"cpp"}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "orm.query", + Relationship: "queries", + SourceTokens: []string{token}, + Tags: []string{"orm:" + orm}, + Attributes: map[string]string{"dependency": dependency, "language": "cpp", "orm": orm}, + } +} diff --git a/internal/watch/enrich/enrichers/orm/cpp/orm_test.go b/internal/watch/enrich/enrichers/orm/cpp/orm_test.go new file mode 100644 index 0000000..a619486 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/cpp/orm_test.go @@ -0,0 +1,29 @@ +package cpp + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestCPPORMEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{RelPath: "src/db.cpp", Language: "cpp", Source: []byte(spec.SourceTokens[0])}, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:orm", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git 
a/internal/watch/enrich/enrichers/orm/golang/orm.go b/internal/watch/enrich/enrichers/orm/golang/orm.go new file mode 100644 index 0000000..29917a0 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/golang/orm.go @@ -0,0 +1,33 @@ +package golang + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("go.gorm", "Go GORM", "github.com/go-gorm/gorm", "gorm.Open", "gorm"), + spec("go.sqlc", "Go sqlc", "github.com/sqlc-dev/sqlc", "sqlc", "sqlc"), + spec("go.ent", "Go ent", "entgo.io/ent", "ent.Client", "ent"), + spec("go.database_sql", "Go database/sql", "database/sql", "sql.Open", "database-sql"), + } +} + +func spec(id, name, dependency, token, orm string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "orm", + Languages: []string{"go"}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "orm.query", + Relationship: "queries", + SourceTokens: []string{token}, + Tags: []string{"orm:" + orm}, + Attributes: map[string]string{"dependency": dependency, "language": "go", "orm": orm}, + } +} diff --git a/internal/watch/enrich/enrichers/orm/golang/orm_test.go b/internal/watch/enrich/enrichers/orm/golang/orm_test.go new file mode 100644 index 0000000..c4c7b58 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/golang/orm_test.go @@ -0,0 +1,33 @@ +package golang + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestGoORMEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "db.go", + Language: "go", + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:orm", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/orm/java/orm.go b/internal/watch/enrich/enrichers/orm/java/orm.go new file mode 100644 index 0000000..72d181d --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/java/orm.go @@ -0,0 +1,34 @@ +package java + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("java.hibernate", "Java Hibernate", "org.hibernate", "org.hibernate", "hibernate"), + spec("java.jpa", "Java JPA", "jakarta.persistence", "@Entity", "jpa"), + spec("java.spring_data_jpa", "Spring Data JPA", "spring-data-jpa", "JpaRepository", "spring-data-jpa"), + spec("java.mybatis", "Java MyBatis", "mybatis", "@Mapper", "mybatis"), + spec("java.jooq", "Java jOOQ", "jooq", "DSLContext", "jooq"), + } +} + +func spec(id, name, dependency, token, orm string) pattern.Spec { + return 
pattern.Spec{ + ID: id, + Name: name, + Category: "orm", + Languages: []string{"java"}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "orm.query", + Relationship: "queries", + SourceTokens: []string{token}, + Tags: []string{"orm:" + orm}, + Attributes: map[string]string{"dependency": dependency, "language": "java", "orm": orm}, + } +} diff --git a/internal/watch/enrich/enrichers/orm/java/orm_test.go b/internal/watch/enrich/enrichers/orm/java/orm_test.go new file mode 100644 index 0000000..f7fc659 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/java/orm_test.go @@ -0,0 +1,29 @@ +package java + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestJavaORMEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{RelPath: "Model.java", Language: "java", Source: []byte(spec.SourceTokens[0])}, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:orm", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/orm/python/orm.go b/internal/watch/enrich/enrichers/orm/python/orm.go new file mode 100644 index 0000000..1a0ddc6 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/python/orm.go @@ -0,0 +1,33 @@ +package python + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("python.sqlalchemy", "Python SQLAlchemy", "sqlalchemy", "sqlalchemy", "sqlalchemy"), + spec("python.django_orm", "Django ORM", "django", "django.db.models", "django"), + spec("python.peewee", "Python Peewee", "peewee", "peewee", "peewee"), + spec("python.tortoise", "Tortoise ORM", "tortoise-orm", "tortoise", "tortoise"), + } +} + +func spec(id, name, dependency, token, orm string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "orm", + Languages: []string{"python"}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "orm.query", + Relationship: "queries", + SourceTokens: []string{token}, + Tags: []string{"orm:" + orm}, + Attributes: map[string]string{"dependency": dependency, "language": "python", "orm": orm}, + } +} diff --git a/internal/watch/enrich/enrichers/orm/python/orm_test.go b/internal/watch/enrich/enrichers/orm/python/orm_test.go new file mode 100644 index 0000000..9688cf8 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/python/orm_test.go @@ -0,0 +1,29 @@ +package python + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestPythonORMEnrichers(t *testing.T) { + byID := enrichersByID() 
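+ // Each spec is exercised with a models.py whose body is just the spec's
+ // first source token, activated by a dependency signal; the enricher must
+ // emit an orm.query fact tagged category:orm with the technology attribute
+ // set to the spec name.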
+ for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{RelPath: "models.py", Language: "python", Source: []byte(spec.SourceTokens[0])}, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:orm", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/orm/rust/orm.go b/internal/watch/enrich/enrichers/orm/rust/orm.go new file mode 100644 index 0000000..b3e14b0 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/rust/orm.go @@ -0,0 +1,33 @@ +package rust + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("rust.sqlx", "Rust sqlx", "sqlx", "sqlx::query", "sqlx"), + spec("rust.diesel", "Rust Diesel", "diesel", "diesel::", "diesel"), + spec("rust.seaorm", "Rust SeaORM", "sea-orm", "EntityTrait", "seaorm"), + spec("rust.tokio_postgres", "Rust tokio-postgres", "tokio-postgres", "tokio_postgres", "tokio-postgres"), + } +} + +func spec(id, name, dependency, token, orm string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "orm", + Languages: []string{"rust"}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "orm.query", + Relationship: "queries", + SourceTokens: []string{token}, + Tags: []string{"orm:" + orm}, + Attributes: map[string]string{"dependency": dependency, "language": "rust", "orm": orm}, + } +} diff --git a/internal/watch/enrich/enrichers/orm/rust/orm_test.go b/internal/watch/enrich/enrichers/orm/rust/orm_test.go new file mode 100644 index 0000000..e1a7787 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/rust/orm_test.go @@ -0,0 +1,29 @@ +package rust + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestRustORMEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{RelPath: "src/db.rs", Language: "rust", Source: []byte(spec.SourceTokens[0])}, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:orm", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/orm/typescript/catalog.go b/internal/watch/enrich/enrichers/orm/typescript/catalog.go new file mode 100644 index 0000000..a654125 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/typescript/catalog.go @@ -0,0 +1,36 @@ +package typescript + +import ( + 
"github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { + out := []enrich.Enricher{Prisma()} + out = append(out, pattern.FromSpecs(Specs())...) + return out +} + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.typeorm", "TypeScript TypeORM", "typeorm", "DataSource", "typeorm"), + spec("ts.sequelize", "TypeScript Sequelize", "sequelize", "Sequelize", "sequelize"), + spec("ts.drizzle", "TypeScript Drizzle", "drizzle-orm", "drizzle(", "drizzle"), + } +} + +func spec(id, name, dependency, token, orm string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "orm", + Languages: []string{"typescript", "javascript"}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: "orm.query", + Relationship: "queries", + SourceTokens: []string{token}, + Tags: []string{"orm:" + orm}, + Attributes: map[string]string{"dependency": dependency, "language": "typescript", "orm": orm}, + } +} diff --git a/internal/watch/enrich/enrichers/orm/typescript/prisma.go b/internal/watch/enrich/enrichers/orm/typescript/prisma.go new file mode 100644 index 0000000..5bd74da --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/typescript/prisma.go @@ -0,0 +1,50 @@ +package typescript + +import ( + "context" + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type RoutePattern = enrich.RoutePattern + +const ( + ActivationImportOrDependency = enrich.ActivationImportOrDependency + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +var matchLanguages = enrich.MatchLanguages + +func Prisma() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "ts.prisma", + Name: "Prisma ORM queries", + Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{ + {Kind: SignalImport, Value: "@prisma/client"}, + {Kind: SignalDependency, Value: "@prisma/client"}, + {Kind: SignalDependency, Value: "prisma"}, + }, + }, + matchLanguages("typescript", "javascript"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + return enrich.EmitMatches(input, emit, []*RoutePattern{{ + Re: regexp.MustCompile(`\bprisma\.([A-Za-z_][A-Za-z0-9_]*)\.(findMany|findUnique|findFirst|create|createMany|update|updateMany|delete|deleteMany|upsert|aggregate|count)\b`), + FactType: "orm.query", + Framework: "prisma", + Tags: []string{"orm:prisma"}, + Custom: func(match []string) (name string, attrs map[string]string, tags []string) { + return match[1] + "." 
+ match[2], map[string]string{"orm": "prisma", "model": match[1], "operation": match[2]}, []string{"orm:prisma"} + }, + }}) + }, + ) +} diff --git a/internal/watch/enrich/enrichers/orm/typescript/prisma_test.go b/internal/watch/enrich/enrichers/orm/typescript/prisma_test.go new file mode 100644 index 0000000..58d37b7 --- /dev/null +++ b/internal/watch/enrich/enrichers/orm/typescript/prisma_test.go @@ -0,0 +1,47 @@ +package typescript + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestPrismaEnricher(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "prisma query requires activation and matches model operation", + Enricher: Prisma(), + Input: enrich.FileInput{ + RelPath: "db.ts", + Language: "typescript", + Source: []byte(`await prisma.user.findMany()`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "@prisma/client"}}, + Want: enrichertest.Fact{Type: "orm.query", Tag: "orm:prisma", Name: "user.findMany", Attribute: "operation", AttrValue: "findMany"}, + }) +} + +func TestTypeScriptORMCatalogEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "db.ts", + Language: "typescript", + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:orm", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/osintegration/osintegration.go b/internal/watch/enrich/enrichers/osintegration/osintegration.go new file mode 100644 index 0000000..0c309bf --- /dev/null +++ b/internal/watch/enrich/enrichers/osintegration/osintegration.go @@ -0,0 +1,31 @@ +package osintegration + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("os.uri_schemes", "Custom URI schemes", []string{"info.plist", "androidmanifest.xml", "electron"}, []string{"CFBundleURLSchemes"}, "os.uri_scheme", "handles_deep_link"), + spec("os.android_intents", "Android Intents", []string{"androidmanifest.xml"}, []string{"`) + +func New(spec Spec) enrich.Enricher { + return enrich.NewEnricher( + enrich.Metadata{ + ID: spec.ID, + Name: spec.Name, + Mode: spec.Mode, + Triggers: spec.Triggers, + }, + func(input enrich.FileInput) bool { + if ignoredPath(input.RelPath) { + return false + } + if len(spec.Languages) == 0 { + return true + } + for _, language := range spec.Languages { + if strings.EqualFold(strings.TrimSpace(language), strings.TrimSpace(input.Language)) { + return true + } + } + return pathMatches(input.RelPath, spec.PathTokens) + }, + func(ctx context.Context, input enrich.FileInput, emit enrich.FactEmitter) error { + line := matchLine(input, spec) + if line == 0 { + return nil + } + attrs := map[string]string{ + "category": spec.Category, + "technology": spec.Name, + "framework": spec.ID, + } + maps.Copy(attrs, spec.Attributes) + 
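+ // Tag the fact with the generic "arch:glue" marker plus normalized category
+ // and technology tags, then append any spec-specific tags, so downstream
+ // consumers can filter facts by broad category or by exact technology.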
tags := []string{"arch:glue", "category:" + tagValue(spec.Category), "technology:" + tagValue(spec.Name)} + tags = append(tags, spec.Tags...) + objectKind := spec.ObjectKind + if objectKind == "" { + objectKind = spec.FactType + } + return emit.EmitFact(enrich.Fact{ + Type: spec.FactType, + StableKey: fmt.Sprintf("%s:%s:%s:%d", spec.FactType, spec.ID, input.RelPath, line), + Subject: enrich.SubjectForLine(input, line), + Object: enrich.SubjectRef{Kind: objectKind, StableKey: spec.FactType + ":" + spec.ID, FilePath: input.RelPath, Name: spec.Name}, + Relationship: spec.Relationship, + Source: enrich.SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.78, + Name: spec.Name, + Tags: tags, + Attributes: attrs, + VisibilityHints: map[string]float64{ + "high_signal": 0.6, + }, + }) + }, + ) +} + +func FromSpecs(specs []Spec) []enrich.Enricher { + out := make([]enrich.Enricher, 0, len(specs)) + for _, spec := range specs { + out = append(out, New(spec)) + } + return out +} + +func matchLine(input enrich.FileInput, spec Spec) int { + if len(spec.SourceTokens) > 0 { + return matchSourceTokens(input, spec.SourceTokens) + } + if pathMatches(input.RelPath, spec.PathTokens) { + return 1 + } + return 0 +} + +func matchSourceTokens(input enrich.FileInput, tokens []string) int { + source := commentsRE.ReplaceAllString(string(input.Source), "") + lower := strings.ToLower(source) + for _, token := range tokens { + token = strings.ToLower(strings.TrimSpace(token)) + if token == "" { + continue + } + if idx := strings.Index(lower, token); idx >= 0 { + return enrich.LineForOffset(source, idx) + } + } + return 0 +} + +func pathMatches(relPath string, tokens []string) bool { + rel := strings.ToLower(filepath.ToSlash(relPath)) + for _, token := range tokens { + token = strings.ToLower(strings.TrimSpace(token)) + if token != "" && strings.Contains(rel, token) { + return true + } + } + return false +} + +func ignoredPath(relPath string) bool { + parts := strings.SplitSeq(filepath.ToSlash(relPath), "/") + for part := range parts { + switch strings.ToLower(part) { + case ".git", "node_modules", "vendor", "dist", "build", "coverage", "generated", "gen": + return true + } + } + return false +} + +func tagValue(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + value = strings.NewReplacer(" / ", "-", "/", "-", " ", "-", "&", "and", ".", "", "+", "plus").Replace(value) + for strings.Contains(value, "--") { + value = strings.ReplaceAll(value, "--", "-") + } + return strings.Trim(value, "-") +} diff --git a/internal/watch/enrich/enrichers/pattern/pattern_test.go b/internal/watch/enrich/enrichers/pattern/pattern_test.go new file mode 100644 index 0000000..203a5bf --- /dev/null +++ b/internal/watch/enrich/enrichers/pattern/pattern_test.go @@ -0,0 +1,58 @@ +package pattern + +import ( + "context" + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +func TestPatternEnricherIgnoresCommentedMatches(t *testing.T) { + spec := Spec{ + ID: "demo.pattern", + Name: "Demo Pattern", + Category: "demo", + Languages: []string{"go"}, + Mode: enrich.ActivationAlways, + FactType: "demo.fact", + Relationship: "uses", + SourceTokens: []string{"real_call"}, + } + facts, _, err := enrich.NewRegistry(New(spec)).EnrichFile(context.Background(), enrich.FileInput{ + RelPath: "demo.go", + Language: "go", + Source: []byte(`// real_call should not match`), + }) + if err != nil { + t.Fatal(err) + } + if len(facts) != 0 { + t.Fatalf("expected no facts for commented match, got 
%+v", facts) + } +} + +func TestPatternEnricherIgnoresGeneratedAndVendorPaths(t *testing.T) { + spec := Spec{ + ID: "demo.pattern", + Name: "Demo Pattern", + Category: "demo", + Languages: []string{"go"}, + Mode: enrich.ActivationAlways, + FactType: "demo.fact", + Relationship: "uses", + SourceTokens: []string{"real_call"}, + } + for _, relPath := range []string{"vendor/pkg/demo.go", "generated/demo.go"} { + facts, _, err := enrich.NewRegistry(New(spec)).EnrichFile(context.Background(), enrich.FileInput{ + RelPath: relPath, + Language: "go", + Source: []byte(`real_call()`), + }) + if err != nil { + t.Fatal(err) + } + if len(facts) != 0 { + t.Fatalf("expected no facts for ignored path %s, got %+v", relPath, facts) + } + } +} diff --git a/internal/watch/enrich/enrichers/routes/cpp/routes.go b/internal/watch/enrich/enrichers/routes/cpp/routes.go new file mode 100644 index 0000000..6a8fda7 --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/cpp/routes.go @@ -0,0 +1,57 @@ +package cpp + +import ( + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type RoutePattern = enrich.RoutePattern + +const ( + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +func Drogon() Enricher { + return enrich.RouteRegexEnricher("cpp.drogon", "C++ Drogon routes", "cpp", []ActivationSignal{{Kind: SignalDependency, Value: "drogon"}, {Kind: SignalImport, Value: "drogon"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`METHOD_(GET|POST|PUT|DELETE|PATCH).*?ADD_METHOD_TO\([^,]+,\s*"([^"]+)"`), Framework: "drogon", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`PATH_ADD\(\s*"([^"]+)"`), Framework: "drogon", PathGroup: 1}, + }) +} + +func Oatpp() Enricher { + return enrich.RouteRegexEnricher("cpp.oatpp", "C++ oatpp routes", "cpp", []ActivationSignal{{Kind: SignalDependency, Value: "oatpp"}, {Kind: SignalImport, Value: "oatpp"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`ENDPOINT\(\s*"([A-Z]+)"\s*,\s*"([^"]+)"`), Framework: "oatpp", MethodGroup: 1, PathGroup: 2}, + }) +} + +func Pistache() Enricher { + return enrich.RouteRegexEnricher("cpp.pistache", "C++ Pistache routes", "cpp", []ActivationSignal{{Kind: SignalDependency, Value: "pistache"}, {Kind: SignalImport, Value: "pistache"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`Routes::(Get|Post|Put|Delete|Patch)\(\s*router\s*,\s*"([^"]+)"`), Framework: "pistache", MethodGroup: 1, PathGroup: 2}, + }) +} + +func Crow() Enricher { + return enrich.RouteRegexEnricher("cpp.crow", "C++ Crow routes", "cpp", []ActivationSignal{{Kind: SignalDependency, Value: "crow"}, {Kind: SignalImport, Value: "crow"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`CROW_ROUTE\([^,]+,\s*"([^"]+)"`), Framework: "crow", PathGroup: 1}, + }) +} + +func CppRestSDK() Enricher { + return enrich.RouteRegexEnricher("cpp.cpprestsdk", "C++ cpprestsdk routes", "cpp", []ActivationSignal{{Kind: SignalDependency, Value: "cpprestsdk"}, {Kind: SignalImport, Value: "cpprest"}}, []*RoutePattern{ + { + Re: regexp.MustCompile(`support\(\s*methods::(GET|POST|PUT|DEL|PATCH)\s*,`), + Framework: "cpprestsdk", + Custom: func(match []string) (string, map[string]string, []string) { + method := match[1] + if method == "DEL" { + method = "DELETE" + } + return method + " handler", map[string]string{"framework": "cpprestsdk", "method": method}, []string{"http:route", "framework:cpprestsdk"} + }, + }, + }) +} diff --git a/internal/watch/enrich/enrichers/routes/cpp/routes_test.go 
b/internal/watch/enrich/enrichers/routes/cpp/routes_test.go new file mode 100644 index 0000000..93c38b5 --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/cpp/routes_test.go @@ -0,0 +1,30 @@ +package cpp + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestCPPRouteEnrichers(t *testing.T) { + enrichertest.Run(t, + enrichertest.Case{Name: "drogon route", Enricher: Drogon(), Input: input(`METHOD_GET ADD_METHOD_TO(UserController::get, "/users")`), Signals: signal("drogon"), Want: want("framework:drogon", "GET /users")}, + enrichertest.Case{Name: "oatpp route", Enricher: Oatpp(), Input: input(`ENDPOINT("POST", "/orders", createOrder)`), Signals: signal("oatpp"), Want: want("framework:oatpp", "POST /orders")}, + enrichertest.Case{Name: "pistache route", Enricher: Pistache(), Input: input(`Routes::Get(router, "/health", handler)`), Signals: signal("pistache"), Want: want("framework:pistache", "GET /health")}, + enrichertest.Case{Name: "crow route", Enricher: Crow(), Input: input(`CROW_ROUTE(app, "/metrics")`), Signals: signal("crow"), Want: want("framework:crow", "/metrics")}, + enrichertest.Case{Name: "cpprestsdk route", Enricher: CppRestSDK(), Input: input(`listener.support(methods::GET, handler)`), Signals: signal("cpprestsdk"), Want: want("framework:cpprestsdk", "GET handler")}, + ) +} + +func input(source string) enrich.FileInput { + return enrich.FileInput{RelPath: "src/routes.cpp", Language: "cpp", Source: []byte(source)} +} + +func signal(value string) []enrich.ActivationSignal { + return []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: value}} +} + +func want(tag, name string) enrichertest.Fact { + return enrichertest.Fact{Type: "http.route", Tag: tag, Name: name} +} diff --git a/internal/watch/enrich/enrichers/routes/golang/routes.go b/internal/watch/enrich/enrichers/routes/golang/routes.go new file mode 100644 index 0000000..cd11962 --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/golang/routes.go @@ -0,0 +1,73 @@ +package golang + +import ( + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type RoutePattern = enrich.RoutePattern + +const ( + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +func GoGorillaMux() Enricher { + return enrich.RouteRegexEnricher("go.gorilla_mux", "Go gorilla/mux routes", "go", []ActivationSignal{ + {Kind: SignalImport, Value: "github.com/gorilla/mux"}, + {Kind: SignalDependency, Value: "github.com/gorilla/mux"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.HandleFunc\(\s*([^,\n]+),`), Framework: "gorilla-mux", PathGroup: 1}, + }) +} + +func GoNetHTTP() Enricher { + return enrich.RouteRegexEnricher("go.nethttp", "Go net/http routes", "go", []ActivationSignal{ + {Kind: SignalImport, Value: "net/http"}, + {Kind: SignalDependency, Value: "net/http"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\bhttp\.HandleFunc\(\s*"([^"]+)"`), Method: "", Framework: "nethttp"}, + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.HandleFunc\(\s*"([^"]+)"`), Method: "", Framework: "nethttp"}, + }) +} + +func GoChi() Enricher { + return enrich.RouteRegexEnricher("go.chi", "Go chi routes", "go", []ActivationSignal{ + {Kind: SignalImport, Value: "github.com/go-chi/chi"}, + {Kind: SignalDependency, Value: "github.com/go-chi/chi"}, + }, []*RoutePattern{ + {Re: 
regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.(Get|Post|Put|Delete|Patch)\(\s*"([^"]+)"`), Framework: "chi", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.Route\(\s*"([^"]+)"`), Framework: "chi"}, + }) +} + +func GoGin() Enricher { + return enrich.RouteRegexEnricher("go.gin", "Go gin routes", "go", []ActivationSignal{ + {Kind: SignalImport, Value: "github.com/gin-gonic/gin"}, + {Kind: SignalDependency, Value: "github.com/gin-gonic/gin"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.(GET|POST|PUT|DELETE|PATCH)\(\s*"([^"]+)"`), Framework: "gin", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.Group\(\s*"([^"]+)"`), Framework: "gin"}, + }) +} + +func GoEcho() Enricher { + return enrich.RouteRegexEnricher("go.echo", "Go Echo routes", "go", []ActivationSignal{ + {Kind: SignalImport, Value: "github.com/labstack/echo"}, + {Kind: SignalDependency, Value: "github.com/labstack/echo"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.(GET|POST|PUT|DELETE|PATCH)\(\s*"([^"]+)"`), Framework: "echo", MethodGroup: 1, PathGroup: 2}, + }) +} + +func GoFiber() Enricher { + return enrich.RouteRegexEnricher("go.fiber", "Go Fiber routes", "go", []ActivationSignal{ + {Kind: SignalImport, Value: "github.com/gofiber/fiber"}, + {Kind: SignalDependency, Value: "github.com/gofiber/fiber"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.(Get|Post|Put|Delete|Patch)\(\s*"([^"]+)"`), Framework: "fiber", MethodGroup: 1, PathGroup: 2}, + }) +} diff --git a/internal/watch/enrich/enrichers/routes/golang/routes_test.go b/internal/watch/enrich/enrichers/routes/golang/routes_test.go new file mode 100644 index 0000000..17b134a --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/golang/routes_test.go @@ -0,0 +1,46 @@ +package golang + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestGoRouteEnrichers(t *testing.T) { + enrichertest.Run(t, + enrichertest.Case{ + Name: "chi route requires activation and matches route call", + Enricher: GoChi(), + Input: enrich.FileInput{ + RelPath: "routes.go", + Language: "go", + Source: []byte(`func routes(r chi.Router) { r.Get("/users/{id}", getUser) }`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalImport, Value: "github.com/go-chi/chi/v5"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:chi", Name: "GET /users/{id}"}, + }, + enrichertest.Case{ + Name: "echo route", + Enricher: GoEcho(), + Input: enrich.FileInput{ + RelPath: "routes.go", + Language: "go", + Source: []byte(`e.POST("/orders", createOrder)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "github.com/labstack/echo/v4"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:echo", Name: "POST /orders"}, + }, + enrichertest.Case{ + Name: "fiber route", + Enricher: GoFiber(), + Input: enrich.FileInput{ + RelPath: "routes.go", + Language: "go", + Source: []byte(`app.Get("/status", status)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "github.com/gofiber/fiber/v2"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:fiber", Name: "GET /status"}, + }, + ) +} diff --git a/internal/watch/enrich/enrichers/routes/java/routes.go b/internal/watch/enrich/enrichers/routes/java/routes.go new file mode 100644 index 0000000..b784779 --- /dev/null +++ 
b/internal/watch/enrich/enrichers/routes/java/routes.go @@ -0,0 +1,55 @@ +package java + +import ( + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type RoutePattern = enrich.RoutePattern + +const ( + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +func Spring() Enricher { + return enrich.RouteRegexEnricher("java.spring_web", "Java Spring MVC/WebFlux routes", "java", []ActivationSignal{ + {Kind: SignalImport, Value: "org.springframework.web.bind.annotation"}, + {Kind: SignalDependency, Value: "spring-boot-starter-web"}, + {Kind: SignalDependency, Value: "spring-boot-starter-webflux"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`@(Get|Post|Put|Delete|Patch)Mapping\(\s*["']([^"']+)["']`), Framework: "spring", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`@RequestMapping\(\s*["']([^"']+)["']`), Framework: "spring", PathGroup: 1}, + }) +} + +func JAXRS() Enricher { + return enrich.RouteRegexEnricher("java.jax_rs", "Java JAX-RS routes", "java", []ActivationSignal{ + {Kind: SignalImport, Value: "jakarta.ws.rs"}, + {Kind: SignalDependency, Value: "jakarta.ws.rs-api"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`@Path\(\s*["']([^"']+)["']`), Framework: "jax-rs", PathGroup: 1}, + }) +} + +func Micronaut() Enricher { + return enrich.RouteRegexEnricher("java.micronaut", "Java Micronaut routes", "java", []ActivationSignal{ + {Kind: SignalImport, Value: "io.micronaut.http.annotation"}, + {Kind: SignalDependency, Value: "micronaut-http-server"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`@(Get|Post|Put|Delete|Patch)\(\s*["']([^"']+)["']`), Framework: "micronaut", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`@Controller\(\s*["']([^"']+)["']`), Framework: "micronaut", PathGroup: 1}, + }) +} + +func Quarkus() Enricher { + return enrich.RouteRegexEnricher("java.quarkus", "Java Quarkus routes", "java", []ActivationSignal{ + {Kind: SignalImport, Value: "io.quarkus"}, + {Kind: SignalDependency, Value: "quarkus-resteasy-reactive"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`@Path\(\s*["']([^"']+)["']`), Framework: "quarkus", PathGroup: 1}, + }) +} diff --git a/internal/watch/enrich/enrichers/routes/java/routes_test.go b/internal/watch/enrich/enrichers/routes/java/routes_test.go new file mode 100644 index 0000000..92f907a --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/java/routes_test.go @@ -0,0 +1,29 @@ +package java + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestJavaRouteEnrichers(t *testing.T) { + enrichertest.Run(t, + enrichertest.Case{Name: "spring route", Enricher: Spring(), Input: input(`@GetMapping("/users")`), Signals: signal("spring-boot-starter-web"), Want: want("framework:spring", "GET /users")}, + enrichertest.Case{Name: "jax-rs route", Enricher: JAXRS(), Input: input(`@Path("/users")`), Signals: signal("jakarta.ws.rs-api"), Want: want("framework:jax-rs", "/users")}, + enrichertest.Case{Name: "micronaut route", Enricher: Micronaut(), Input: input(`@Post("/orders")`), Signals: signal("micronaut-http-server"), Want: want("framework:micronaut", "POST /orders")}, + enrichertest.Case{Name: "quarkus route", Enricher: Quarkus(), Input: input(`@Path("/health")`), Signals: signal("quarkus-resteasy-reactive"), Want: want("framework:quarkus", "/health")}, + ) +} + +func input(source string) enrich.FileInput 
{ + return enrich.FileInput{RelPath: "Controller.java", Language: "java", Source: []byte(source)} +} + +func signal(value string) []enrich.ActivationSignal { + return []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: value}} +} + +func want(tag, name string) enrichertest.Fact { + return enrichertest.Fact{Type: "http.route", Tag: tag, Name: name} +} diff --git a/internal/watch/enrich/enrichers/routes/python/routes.go b/internal/watch/enrich/enrichers/routes/python/routes.go new file mode 100644 index 0000000..8f71d2e --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/python/routes.go @@ -0,0 +1,53 @@ +package python + +import ( + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type RoutePattern = enrich.RoutePattern + +const ( + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +func PythonFlask() Enricher { + return enrich.RouteRegexEnricher("python.flask", "Python Flask routes", "python", []ActivationSignal{ + {Kind: SignalImport, Value: "flask"}, + {Kind: SignalDependency, Value: "flask"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`@(?:[A-Za-z_][A-Za-z0-9_]*\.)?route\(\s*["']([^"']+)["']`), FactType: "http.route", Framework: "flask", Tags: []string{"http:route", "framework:flask"}}, + }) +} + +func PythonFastAPI() Enricher { + return enrich.RouteRegexEnricher("python.fastapi", "Python FastAPI routes", "python", []ActivationSignal{ + {Kind: SignalImport, Value: "fastapi"}, + {Kind: SignalDependency, Value: "fastapi"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`@(?:[A-Za-z_][A-Za-z0-9_]*\.)?(get|post|put|delete|patch)\(\s*["']([^"']+)["']`), FactType: "http.route", Framework: "fastapi", MethodGroup: 1, PathGroup: 2, Tags: []string{"http:route", "framework:fastapi"}}, + }) +} + +func PythonDjango() Enricher { + return enrich.RouteRegexEnricher("python.django", "Python Django routes", "python", []ActivationSignal{ + {Kind: SignalImport, Value: "django"}, + {Kind: SignalDependency, Value: "django"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\bpath\(\s*["']([^"']+)["']`), FactType: "http.route", Framework: "django", Tags: []string{"http:route", "framework:django"}}, + {Re: regexp.MustCompile(`\bre_path\(\s*["']([^"']+)["']`), FactType: "http.route", Framework: "django", Tags: []string{"http:route", "framework:django"}}, + }) +} + +func PythonStarlette() Enricher { + return enrich.RouteRegexEnricher("python.starlette", "Python Starlette routes", "python", []ActivationSignal{ + {Kind: SignalImport, Value: "starlette"}, + {Kind: SignalDependency, Value: "starlette"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\bRoute\(\s*["']([^"']+)["']`), FactType: "http.route", Framework: "starlette", Tags: []string{"http:route", "framework:starlette"}}, + }) +} diff --git a/internal/watch/enrich/enrichers/routes/python/routes_test.go b/internal/watch/enrich/enrichers/routes/python/routes_test.go new file mode 100644 index 0000000..f12bb76 --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/python/routes_test.go @@ -0,0 +1,57 @@ +package python + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestPythonRouteEnrichers(t *testing.T) { + enrichertest.Run(t, + enrichertest.Case{ + Name: "flask route requires activation and matches route decorator", + Enricher: PythonFlask(), + Input: enrich.FileInput{ + RelPath: "app.py", + 
Language: "python", + Source: []byte(`@app.route("/users/")`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "flask"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:flask", Name: "/users/"}, + }, + enrichertest.Case{ + Name: "fastapi route", + Enricher: PythonFastAPI(), + Input: enrich.FileInput{ + RelPath: "app.py", + Language: "python", + Source: []byte(`@app.post("/orders")`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "fastapi"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:fastapi", Name: "POST /orders"}, + }, + enrichertest.Case{ + Name: "django route", + Enricher: PythonDjango(), + Input: enrich.FileInput{ + RelPath: "urls.py", + Language: "python", + Source: []byte(`path("users//", view)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "django"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:django", Name: "users//"}, + }, + enrichertest.Case{ + Name: "starlette route", + Enricher: PythonStarlette(), + Input: enrich.FileInput{ + RelPath: "routes.py", + Language: "python", + Source: []byte(`Route("/health", health)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "starlette"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:starlette", Name: "/health"}, + }, + ) +} diff --git a/internal/watch/enrich/enrichers/routes/rust/routes.go b/internal/watch/enrich/enrichers/routes/rust/routes.go new file mode 100644 index 0000000..b168346 --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/rust/routes.go @@ -0,0 +1,41 @@ +package rust + +import ( + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type RoutePattern = enrich.RoutePattern + +const ( + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +func Axum() Enricher { + return enrich.RouteRegexEnricher("rust.axum", "Rust axum routes", "rust", []ActivationSignal{{Kind: SignalDependency, Value: "axum"}, {Kind: SignalImport, Value: "axum"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`\broute\(\s*"([^"]+)"\s*,\s*(get|post|put|delete|patch)\(`), Framework: "axum", PathGroup: 1, MethodGroup: 2}, + }) +} + +func ActixWeb() Enricher { + return enrich.RouteRegexEnricher("rust.actix_web", "Rust actix-web routes", "rust", []ActivationSignal{{Kind: SignalDependency, Value: "actix-web"}, {Kind: SignalImport, Value: "actix_web"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`#\[(get|post|put|delete|patch)\("([^"]+)"\)\]`), Framework: "actix-web", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`\.route\(\s*"([^"]+)"\s*,\s*web::(get|post|put|delete|patch)\(`), Framework: "actix-web", PathGroup: 1, MethodGroup: 2}, + }) +} + +func Rocket() Enricher { + return enrich.RouteRegexEnricher("rust.rocket", "Rust Rocket routes", "rust", []ActivationSignal{{Kind: SignalDependency, Value: "rocket"}, {Kind: SignalImport, Value: "rocket"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`#\[(get|post|put|delete|patch)\("([^"]+)"\)\]`), Framework: "rocket", MethodGroup: 1, PathGroup: 2}, + }) +} + +func Warp() Enricher { + return enrich.RouteRegexEnricher("rust.warp", "Rust warp routes", "rust", []ActivationSignal{{Kind: SignalDependency, Value: "warp"}, {Kind: SignalImport, Value: "warp"}}, []*RoutePattern{ + {Re: regexp.MustCompile(`warp::path!\(\s*"([^"]+)"`), Framework: "warp", PathGroup: 1}, + }) +} diff --git 
a/internal/watch/enrich/enrichers/routes/rust/routes_test.go b/internal/watch/enrich/enrichers/routes/rust/routes_test.go new file mode 100644 index 0000000..3c6f0ca --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/rust/routes_test.go @@ -0,0 +1,29 @@ +package rust + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestRustRouteEnrichers(t *testing.T) { + enrichertest.Run(t, + enrichertest.Case{Name: "axum route", Enricher: Axum(), Input: input(`route("/users", get(handler))`), Signals: signal("axum"), Want: want("framework:axum", "GET /users")}, + enrichertest.Case{Name: "actix route", Enricher: ActixWeb(), Input: input(`#[post("/orders")]`), Signals: signal("actix-web"), Want: want("framework:actix-web", "POST /orders")}, + enrichertest.Case{Name: "rocket route", Enricher: Rocket(), Input: input(`#[get("/health")]`), Signals: signal("rocket"), Want: want("framework:rocket", "GET /health")}, + enrichertest.Case{Name: "warp route", Enricher: Warp(), Input: input(`warp::path!("metrics")`), Signals: signal("warp"), Want: want("framework:warp", "metrics")}, + ) +} + +func input(source string) enrich.FileInput { + return enrich.FileInput{RelPath: "src/routes.rs", Language: "rust", Source: []byte(source)} +} + +func signal(value string) []enrich.ActivationSignal { + return []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: value}} +} + +func want(tag, name string) enrichertest.Fact { + return enrichertest.Fact{Type: "http.route", Tag: tag, Name: name} +} diff --git a/internal/watch/enrich/enrichers/routes/typescript/routes.go b/internal/watch/enrich/enrichers/routes/typescript/routes.go new file mode 100644 index 0000000..842876d --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/typescript/routes.go @@ -0,0 +1,54 @@ +package typescript + +import ( + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type RoutePattern = enrich.RoutePattern + +const ( + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +func Express() Enricher { + return enrich.RouteRegexEnricher("ts.express", "Express routes", "typescript,javascript", []ActivationSignal{ + {Kind: SignalImport, Value: "express"}, + {Kind: SignalDependency, Value: "express"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b(?:app|router)\.(get|post|put|delete|patch)\(\s*["'\x60]([^"'\x60]+)["'\x60]`), Framework: "express", MethodGroup: 1, PathGroup: 2}, + }) +} + +func Fastify() Enricher { + return enrich.RouteRegexEnricher("ts.fastify", "Fastify routes", "typescript,javascript", []ActivationSignal{ + {Kind: SignalImport, Value: "fastify"}, + {Kind: SignalDependency, Value: "fastify"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.(get|post|put|delete|patch)\(\s*["'\x60]([^"'\x60]+)["'\x60]`), Framework: "fastify", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.route\(\s*\{[^}]*method:\s*["'\x60]([A-Z]+)["'\x60][^}]*url:\s*["'\x60]([^"'\x60]+)["'\x60]`), Framework: "fastify", MethodGroup: 1, PathGroup: 2}, + }) +} + +func NestJS() Enricher { + return enrich.RouteRegexEnricher("ts.nestjs", "NestJS routes", "typescript,javascript", []ActivationSignal{ + {Kind: SignalImport, Value: "@nestjs/common"}, + {Kind: SignalDependency, Value: "@nestjs/common"}, + }, []*RoutePattern{ + {Re: 
regexp.MustCompile(`@(Get|Post|Put|Delete|Patch)\(\s*["'\x60]?([^"'\x60)]*)["'\x60]?\s*\)`), Framework: "nestjs", MethodGroup: 1, PathGroup: 2}, + {Re: regexp.MustCompile(`@Controller\(\s*["'\x60]([^"'\x60]+)["'\x60]`), Framework: "nestjs", PathGroup: 1}, + }) +} + +func Hono() Enricher { + return enrich.RouteRegexEnricher("ts.hono", "Hono routes", "typescript,javascript", []ActivationSignal{ + {Kind: SignalImport, Value: "hono"}, + {Kind: SignalDependency, Value: "hono"}, + }, []*RoutePattern{ + {Re: regexp.MustCompile(`\b[A-Za-z_][A-Za-z0-9_]*\.(get|post|put|delete|patch)\(\s*["'\x60]([^"'\x60]+)["'\x60]`), Framework: "hono", MethodGroup: 1, PathGroup: 2}, + }) +} diff --git a/internal/watch/enrich/enrichers/routes/typescript/routes_test.go b/internal/watch/enrich/enrichers/routes/typescript/routes_test.go new file mode 100644 index 0000000..9a8c810 --- /dev/null +++ b/internal/watch/enrich/enrichers/routes/typescript/routes_test.go @@ -0,0 +1,57 @@ +package typescript + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestTypeScriptRouteEnrichers(t *testing.T) { + enrichertest.Run(t, + enrichertest.Case{ + Name: "express route requires activation and matches router call", + Enricher: Express(), + Input: enrich.FileInput{ + RelPath: "server.ts", + Language: "typescript", + Source: []byte(`router.post("/api/users", createUser)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "express"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:express", Name: "POST /api/users"}, + }, + enrichertest.Case{ + Name: "fastify route", + Enricher: Fastify(), + Input: enrich.FileInput{ + RelPath: "server.ts", + Language: "typescript", + Source: []byte(`fastify.get("/health", handler)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "fastify"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:fastify", Name: "GET /health"}, + }, + enrichertest.Case{ + Name: "nestjs route", + Enricher: NestJS(), + Input: enrich.FileInput{ + RelPath: "users.controller.ts", + Language: "typescript", + Source: []byte(`@Get("users/:id")`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "@nestjs/common"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:nestjs", Name: "GET users/:id"}, + }, + enrichertest.Case{ + Name: "hono route", + Enricher: Hono(), + Input: enrich.FileInput{ + RelPath: "server.ts", + Language: "typescript", + Source: []byte(`app.post("/events", handler)`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "hono"}}, + Want: enrichertest.Fact{Type: "http.route", Tag: "framework:hono", Name: "POST /events"}, + }, + ) +} diff --git a/internal/watch/enrich/enrichers/rpc/clients/clients.go b/internal/watch/enrich/enrichers/rpc/clients/clients.go new file mode 100644 index 0000000..e915439 --- /dev/null +++ b/internal/watch/enrich/enrichers/rpc/clients/clients.go @@ -0,0 +1,47 @@ +package clients + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.connectrpc", "TypeScript ConnectRPC", "typescript", "@connectrpc/connect", "createClient", "rpc.client", "calls"), + spec("ts.openapi_client", "TypeScript generated OpenAPI 
client", "typescript", "openapi", "new DefaultApi", "integration.client", "calls"), + spec("ts.graphql_client", "TypeScript GraphQL client", "typescript", "graphql-request", "GraphQLClient", "rpc.client", "calls"), + spec("go.connectrpc", "Go connect-go", "go", "connectrpc.com/connect", "connect.New", "rpc.client", "calls"), + spec("go.twirp", "Go Twirp", "go", "github.com/twitchtv/twirp", "twirp", "rpc.client", "calls"), + spec("go.openapi_client", "Go generated OpenAPI client", "go", "openapi", "NewAPIClient", "integration.client", "calls"), + spec("python.openapi_client", "Python generated OpenAPI client", "python", "openapi", "ApiClient", "integration.client", "calls"), + spec("python.gql", "Python gql", "python", "gql", "Client(", "rpc.client", "calls"), + spec("java.openfeign", "Java OpenFeign", "java", "spring-cloud-starter-openfeign", "@FeignClient", "rpc.client", "calls"), + spec("java.retrofit_rpc", "Java Retrofit RPC", "java", "retrofit2", "Retrofit.Builder", "rpc.client", "calls"), + spec("java.openapi_client", "Java generated OpenAPI client", "java", "openapi", "ApiClient", "integration.client", "calls"), + spec("java.graphql_client", "Java GraphQL client", "java", "graphql-java", "GraphQL", "rpc.client", "calls"), + spec("rust.tonic", "Rust tonic", "rust", "tonic", "tonic::transport", "rpc.client", "calls"), + spec("rust.openapi_client", "Rust generated OpenAPI client", "rust", "openapi", "apis::", "integration.client", "calls"), + spec("rust.graphql_client", "Rust graphql_client", "rust", "graphql_client", "GraphQLQuery", "rpc.client", "calls"), + spec("cpp.grpc_cpp", "C++ grpc-cpp", "cpp", "grpc++", "grpc::CreateChannel", "rpc.client", "calls"), + spec("cpp.openapi_client", "C++ generated OpenAPI client", "cpp", "openapi", "ApiClient", "integration.client", "calls"), + spec("cpp.thrift", "C++ Thrift", "cpp", "thrift", "apache::thrift", "rpc.client", "calls"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "rpc", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"rpc:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/rpc/clients/clients_test.go b/internal/watch/enrich/enrichers/rpc/clients/clients_test.go new file mode 100644 index 0000000..607cda5 --- /dev/null +++ b/internal/watch/enrich/enrichers/rpc/clients/clients_test.go @@ -0,0 +1,33 @@ +package clients + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestRPCClientEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/rpc", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:rpc", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + 
out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/rpc/grpc/grpc.go b/internal/watch/enrich/enrichers/rpc/grpc/grpc.go new file mode 100644 index 0000000..4d679e7 --- /dev/null +++ b/internal/watch/enrich/enrichers/rpc/grpc/grpc.go @@ -0,0 +1,296 @@ +package grpc + +import ( + "context" + "fmt" + "path" + "regexp" + "strings" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type SourceSpan = enrich.SourceSpan +type SubjectRef = enrich.SubjectRef + +const ( + ActivationAlways = enrich.ActivationAlways + ActivationImportOrDependency = enrich.ActivationImportOrDependency + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +var ( + fileSubject = enrich.FileSubject + lineForOffset = enrich.LineForOffset + matchLanguages = enrich.MatchLanguages + subjectForLine = enrich.SubjectForLine + submatches = enrich.Submatches +) + +func ProtobufContracts() Enricher { + return enrich.NewEnricher( + Metadata{ID: "protobuf.contracts", Name: "Protocol Buffer service contracts", Mode: ActivationAlways}, + matchLanguages("protobuf"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + if generatedLike(input.RelPath, input.Source) { + return nil + } + return emitServiceMatches(input, emit, "grpc.contract", "protobuf", regexp.MustCompile(`(?m)^\s*service\s+([A-Za-z_][A-Za-z0-9_]*)\s*\{`), []string{"protocol:grpc", "arch:contract"}) + }, + ) +} + +func GoGRPC() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "go.grpc", Name: "Go gRPC glue", Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{{Kind: SignalImport, Value: "google.golang.org/grpc"}, {Kind: SignalDependency, Value: "google.golang.org/grpc"}}, + }, + matchLanguages("go"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + if err := emitServiceMatches(input, emit, "grpc.server", "go", regexp.MustCompile(`\b(?:[A-Za-z_][A-Za-z0-9_]*\.)?Register([A-Za-z_][A-Za-z0-9_]*)Server\s*\(`), []string{"protocol:grpc", "grpc:server", "framework:go-grpc"}); err != nil { + return err + } + if err := emitServiceMatches(input, emit, "grpc.client", "go", regexp.MustCompile(`\b(?:[A-Za-z_][A-Za-z0-9_]*\.)?New([A-Za-z_][A-Za-z0-9_]*)Client\s*\(`), []string{"protocol:grpc", "grpc:client", "framework:go-grpc"}); err != nil { + return err + } + return emitEndpointReads(input, emit, "go") + }, + ) +} + +func PythonGRPC() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "python.grpc", Name: "Python grpcio glue", Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{{Kind: SignalImport, Value: "grpc"}, {Kind: SignalDependency, Value: "grpcio"}}, + }, + matchLanguages("python"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + if generatedLike(input.RelPath, input.Source) { + return nil + } + if err := emitServiceMatches(input, emit, "grpc.server", "python", regexp.MustCompile(`\badd_([A-Za-z_][A-Za-z0-9_]*)Servicer_to_server\s*\(`), []string{"protocol:grpc", "grpc:server", "framework:python-grpc"}); err != nil { + return err + } + if err := emitServiceMatches(input, emit, "grpc.client", "python", regexp.MustCompile(`\b([A-Za-z_][A-Za-z0-9_]*)Stub\s*\(`), []string{"protocol:grpc", 
"grpc:client", "framework:python-grpc"}); err != nil { + return err + } + return emitEndpointReads(input, emit, "python") + }, + ) +} + +func NodeGRPC() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "node.grpc", Name: "Node gRPC glue", Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{{Kind: SignalImport, Value: "@grpc/grpc-js"}, {Kind: SignalDependency, Value: "@grpc/grpc-js"}, {Kind: SignalDependency, Value: "grpc"}}, + }, + matchLanguages("javascript", "typescript"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + if err := emitServiceMatches(input, emit, "grpc.server", "node", regexp.MustCompile(`\.addService\(\s*[^,\n]*\.([A-Za-z_][A-Za-z0-9_]*)\.service`), []string{"protocol:grpc", "grpc:server", "framework:node-grpc"}); err != nil { + return err + } + return emitEndpointReads(input, emit, "node") + }, + ) +} + +func JavaGRPC() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "java.grpc", Name: "Java gRPC glue", Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{{Kind: SignalImport, Value: "io.grpc"}, {Kind: SignalDependency, Value: "io.grpc"}}, + }, + func(input FileInput) bool { return matchLanguages("java", "gradle")(input) }, + func(ctx context.Context, input FileInput, emit FactEmitter) error { + if input.Language == "gradle" { + return emitBuildDependencyFact(input, emit, "io.grpc", "grpc", "java-grpc") + } + if err := emitServiceMatches(input, emit, "grpc.server", "java", regexp.MustCompile(`\b([A-Za-z_][A-Za-z0-9_]*)Grpc\.([A-Za-z_][A-Za-z0-9_]*)ImplBase\b`), []string{"protocol:grpc", "grpc:server", "framework:java-grpc"}); err != nil { + return err + } + return emitServiceMatches(input, emit, "grpc.server", "java", regexp.MustCompile(`\bServerBuilder\.forPort\s*\(`), []string{"protocol:grpc", "grpc:server", "framework:java-grpc"}) + }, + ) +} + +func DotNetGRPC() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "dotnet.grpc", Name: ".NET gRPC glue", Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{{Kind: SignalDependency, Value: "Grpc.AspNetCore"}, {Kind: SignalImport, Value: "Grpc.Core"}}, + }, + matchLanguages("c-sharp", "xml"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + if err := emitServiceMatches(input, emit, "grpc.server", "dotnet", regexp.MustCompile(`\bMapGrpcService<([A-Za-z_][A-Za-z0-9_.]*)>\s*\(`), []string{"protocol:grpc", "grpc:server", "framework:dotnet-grpc"}); err != nil { + return err + } + if err := emitServiceMatches(input, emit, "grpc.contract", "dotnet", regexp.MustCompile(`]*GrpcServices=["']([^"']+)["']`), []string{"protocol:grpc", "arch:contract", "framework:dotnet-grpc"}); err != nil { + return err + } + return emitEndpointReads(input, emit, "dotnet") + }, + ) +} + +func emitServiceMatches(input FileInput, emit FactEmitter, factType, framework string, re *regexp.Regexp, tags []string) error { + source := string(input.Source) + for _, indexes := range re.FindAllStringSubmatchIndex(source, -1) { + match := submatches(source, indexes) + line := lineForOffset(source, indexes[0]) + name := "" + if len(match) > 1 { + name = normalizeServiceName(match[1]) + } + if name == "" { + name = inferredServiceNameFromPath(input.RelPath) + } + if name == "" { + continue + } + relationship := "declares" + if strings.HasSuffix(factType, ".client") { + relationship = "calls" + } + if err := emit.EmitFact(Fact{ + Type: factType, + StableKey: fmt.Sprintf("%s:%s:%s:%s:%d", factType, framework, input.RelPath, name, line), + Subject: 
subjectForLine(input, line), + Object: SubjectRef{Kind: factType, StableKey: factType + ":" + framework + ":" + name, FilePath: input.RelPath, Name: name}, + Relationship: relationship, + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.86, + Name: name, + Tags: tags, + Attributes: map[string]string{"framework": framework, "service": name}, + VisibilityHints: map[string]float64{"high_signal": 1}, + }); err != nil { + return err + } + } + return nil +} + +func emitEndpointReads(input FileInput, emit FactEmitter, framework string) error { + source := string(input.Source) + patterns := []*regexp.Regexp{ + regexp.MustCompile(`\b(?:os\.Getenv|os\.LookupEnv|os\.environ\.get|process\.env(?:\[[^\]]+\]|\.[A-Za-z_][A-Za-z0-9_]*)|Configuration\[[^\]]+\])\s*\(?\s*["']?([A-Z0-9_]*(?:ADDR|HOST|URL|PORT|REDIS|SPANNER|ALLOYDB|COLLECTOR)[A-Z0-9_]*)["']?`), + regexp.MustCompile(`\bmustMapEnv\([^,\n]+,\s*"([A-Z0-9_]+)"\s*\)`), + } + for _, re := range patterns { + for _, indexes := range re.FindAllStringSubmatchIndex(source, -1) { + match := submatches(source, indexes) + if len(match) < 2 { + continue + } + env := strings.Trim(match[1], `"'[]`) + target := "" + line := lineForOffset(source, indexes[0]) + if err := emit.EmitFact(Fact{ + Type: "runtime.endpoint_ref", + StableKey: fmt.Sprintf("runtime.endpoint_ref:%s:%s:%d", input.RelPath, env, line), + Subject: subjectForLine(input, line), + Object: SubjectRef{Kind: "runtime.endpoint", StableKey: "runtime.endpoint:" + target, Name: target}, + Relationship: "uses", + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.62, + Name: env, + Tags: []string{"arch:endpoint-ref", "framework:" + framework}, + Attributes: map[string]string{"env": env, "target": target, "framework": framework}, + VisibilityHints: map[string]float64{"high_signal": 0.5}, + }); err != nil { + return err + } + } + } + return nil +} + +func emitBuildDependencyFact(input FileInput, emit FactEmitter, needle, name, framework string) error { + if !strings.Contains(string(input.Source), needle) { + return nil + } + return emit.EmitFact(Fact{ + Type: "dependency.module", + StableKey: fmt.Sprintf("dependency.module:%s:%s", input.RelPath, name), + Subject: fileSubject(input.RelPath), + Object: SubjectRef{Kind: "dependency.module", StableKey: "dependency.module:" + name, Name: name}, + Relationship: "declares_dependency", + Source: SourceSpan{FilePath: input.RelPath, StartLine: 1, EndLine: 1}, + Confidence: 1, + Name: name, + Tags: []string{"dependency:module", "framework:" + framework}, + Attributes: map[string]string{"module": name, "ecosystem": framework}, + VisibilityHints: map[string]float64{"dependency": 1}, + }) +} + +func normalizeServiceName(value string) string { + value = strings.TrimSpace(value) + value = strings.TrimSuffix(value, "Servicer") + value = strings.TrimSuffix(value, "Service") + if value == "Health" || value == "grpc" || value == "" { + return "" + } + return lowerCamelToService(value) +} + +func lowerCamelToService(value string) string { + var b strings.Builder + for i, r := range value { + if r >= 'A' && r <= 'Z' { + if i > 0 { + prev := rune(value[i-1]) + if prev >= 'a' && prev <= 'z' { + b.WriteByte('-') + } + } + r += 'a' - 'A' + } + if r == '_' || r == '.' 
{ + b.WriteByte('-') + continue + } + b.WriteRune(r) + } + out := strings.Trim(b.String(), "-") + if out == "" { + return "" + } + if !strings.HasSuffix(out, "service") && strings.Contains(value, "Service") { + out += "service" + } + return out +} + +func inferredServiceNameFromPath(rel string) string { + parts := strings.Split(path.Clean(filepathSlash(rel)), "/") + for i, part := range parts { + if part == "src" && i+1 < len(parts) { + return parts[i+1] + } + } + if len(parts) > 1 { + return parts[len(parts)-2] + } + return "" +} + +func filepathSlash(value string) string { + return strings.ReplaceAll(value, "\\", "/") +} + +func generatedLike(rel string, data []byte) bool { + lowerPath := strings.ToLower(rel) + head := strings.ToLower(string(data[:min(len(data), 4096)])) + return strings.Contains(lowerPath, "genproto/") || strings.Contains(lowerPath, "_pb2") || strings.Contains(head, "code generated") || strings.Contains(head, "generated by") +} diff --git a/internal/watch/enrich/enrichers/rpc/grpc/grpc_test.go b/internal/watch/enrich/enrichers/rpc/grpc/grpc_test.go new file mode 100644 index 0000000..55d9c10 --- /dev/null +++ b/internal/watch/enrich/enrichers/rpc/grpc/grpc_test.go @@ -0,0 +1,22 @@ +package grpc + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestGRPCEnrichers(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "go grpc client requires activation and matches generated client constructor", + Enricher: GoGRPC(), + Input: enrich.FileInput{ + RelPath: "src/frontend/rpc.go", + Language: "go", + Source: []byte(`func f() { _ = pb.NewCartServiceClient(conn) }`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalImport, Value: "google.golang.org/grpc"}}, + Want: enrichertest.Fact{Type: "grpc.client", Tag: "grpc:client", Name: "cart"}, + }) +} diff --git a/internal/watch/enrich/enrichers/runtime/runtime.go b/internal/watch/enrich/enrichers/runtime/runtime.go new file mode 100644 index 0000000..d6a6d41 --- /dev/null +++ b/internal/watch/enrich/enrichers/runtime/runtime.go @@ -0,0 +1,248 @@ +package runtimeenrich + +import ( + "context" + "fmt" + "path" + "regexp" + "strings" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type Enricher = enrich.Enricher +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type SourceSpan = enrich.SourceSpan +type SubjectRef = enrich.SubjectRef + +const ActivationAlways = enrich.ActivationAlways + +var ( + fileSubject = enrich.FileSubject + lineForOffset = enrich.LineForOffset + matchLanguages = enrich.MatchLanguages + submatches = enrich.Submatches +) + +func RuntimeManifests() Enricher { + return enrich.NewEnricher( + Metadata{ID: "runtime.manifests", Name: "Runtime manifests", Mode: ActivationAlways}, + matchLanguages("yaml", "terraform", "json"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + switch input.Language { + case "yaml": + return emitRuntimeYAMLFacts(input, emit) + case "terraform": + return emitTerraformFacts(input, emit) + case "json": + if path.Base(input.RelPath) == "package.json" || path.Base(input.RelPath) == "package-lock.json" { + return nil + } + return emitOpenAPIFact(input, emit) + default: + return nil + } + }, + ) +} + +func emitRuntimeYAMLFacts(input FileInput, emit FactEmitter) error { + source := string(input.Source) + if 
!strings.Contains(strings.ToLower(source), "kind:") && !strings.Contains(strings.ToLower(source), "services:") { + return nil + } + componentRE := regexp.MustCompile(`(?m)^\s*(?:kind:\s*(Deployment|StatefulSet|DaemonSet|Job|CronJob|Pod|Service)|name:\s*([A-Za-z0-9_.-]+)|image:\s*([^#\n]+)|value:\s*["']?([^"'\n]+)["']?)`) + lines := strings.Split(source, "\n") + var lastKind, lastName, lastEnvVar string + for i, line := range lines { + match := componentRE.FindStringSubmatch(line) + if len(match) == 0 { + continue + } + if match[1] != "" { + if runtimeWorkloadKind(match[1]) { + lastKind = match[1] + lastName = "" + } + continue + } + if match[2] != "" { + if lastKind != "" && lastName == "" { + lastName = match[2] + if err := emitComponentFact(input, emit, lastName, "service", "Kubernetes", i+1, []string{"runtime:kubernetes", "arch:deployable"}, map[string]string{"runtime": "kubernetes", "kind": lastKind}); err != nil { + return err + } + } else { + lastEnvVar = match[2] + } + continue + } + if match[4] != "" && lastName != "" { + target := endpointName(match[4], lastEnvVar) + if target != "" && target != lastName { + if err := emitConnectorFact(input, emit, lastName, target, protocolFromValue(match[4]), "runtime-dependency", i+1, "runtime manifest env endpoint", 0.78); err != nil { + return err + } + } + } + } + return nil +} + +func emitTerraformFacts(input FileInput, emit FactEmitter) error { + re := regexp.MustCompile(`(?m)^\s*resource\s+"([^"]+)"\s+"([^"]+)"`) + for _, indexes := range re.FindAllStringSubmatchIndex(string(input.Source), -1) { + match := submatches(string(input.Source), indexes) + if len(match) < 3 { + continue + } + line := lineForOffset(string(input.Source), indexes[0]) + kind, tech := infrastructureKind(match[1]) + if kind == "" { + continue + } + if err := emitComponentFact(input, emit, match[2], kind, tech, line, []string{"arch:infrastructure"}, map[string]string{"resource_type": match[1]}); err != nil { + return err + } + } + return nil +} + +func emitOpenAPIFact(input FileInput, emit FactEmitter) error { + if !strings.Contains(strings.ToLower(string(input.Source)), `"openapi"`) { + return nil + } + name := strings.TrimSuffix(path.Base(input.RelPath), path.Ext(input.RelPath)) + return emitComponentFact(input, emit, name, "interface", "OpenAPI", 1, []string{"arch:contract", "protocol:http"}, map[string]string{"protocol": "http"}) +} + +func emitComponentFact(input FileInput, emit FactEmitter, name, kind, technology string, line int, tags []string, attrs map[string]string) error { + if attrs == nil { + attrs = map[string]string{} + } + attrs["name"] = name + attrs["kind"] = kind + attrs["technology"] = technology + return emit.EmitFact(Fact{ + Type: "runtime.component", + StableKey: fmt.Sprintf("runtime.component:%s:%s:%d", input.RelPath, name, line), + Subject: fileSubject(input.RelPath), + Object: SubjectRef{Kind: "runtime.component", StableKey: "runtime.component:" + name, FilePath: input.RelPath, Name: name}, + Relationship: "declares", + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.8, + Name: name, + Tags: append(tags, "arch:component"), + Attributes: attrs, + VisibilityHints: map[string]float64{"high_signal": 0.8}, + }) +} + +func emitConnectorFact(input FileInput, emit FactEmitter, source, target, label, relationship string, line int, note string, confidence float64) error { + if label == "" { + label = "uses" + } + return emit.EmitFact(Fact{ + Type: "runtime.connection", + StableKey: 
fmt.Sprintf("runtime.connection:%s:%s:%s:%s:%d", input.RelPath, source, target, relationship, line), + Subject: fileSubject(input.RelPath), + Object: SubjectRef{Kind: "runtime.component", StableKey: "runtime.component:" + target, Name: target}, + Relationship: relationship, + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: confidence, + Name: source + " -> " + target, + Tags: []string{"arch:connection"}, + Attributes: map[string]string{"source": source, "target": target, "label": label, "note": note}, + VisibilityHints: map[string]float64{"high_signal": 1}, + }) +} + +func endpointName(value, envName string) string { + value = strings.Trim(strings.TrimSpace(value), `"'`) + if value == "" || strings.Contains(value, "{{") { + return "" + } + lower := strings.ToLower(value) + if lower == "true" || lower == "false" || lower == "null" { + return "" + } + if regexp.MustCompile(`^\d+$`).MatchString(value) { + return "" + } + + // Tighten: Require a protocol scheme OR a high-signal environment variable name. + hasScheme := strings.Contains(value, "://") + lowerEnv := strings.ToLower(envName) + highSignalName := strings.HasSuffix(lowerEnv, "_url") || strings.HasSuffix(lowerEnv, "_host") || + strings.HasSuffix(lowerEnv, "_uri") || strings.HasSuffix(lowerEnv, "_endpoint") || + strings.HasSuffix(lowerEnv, "_address") || strings.HasSuffix(lowerEnv, "_addr") + + if !hasScheme && !highSignalName { + return "" + } + + if hasScheme { + parts := strings.SplitN(value, "://", 2) + value = parts[1] + } + if strings.Contains(value, ":") { + value = strings.Split(value, ":")[0] + } + value = strings.Trim(value, "/") + if strings.Contains(value, ".") { + value = strings.Split(value, ".")[0] + } + value = strings.ToLower(value) + if strings.HasPrefix(value, "$") || strings.ContainsAny(value, " /\\_=") { + return "" + } + if strings.HasSuffix(value, "-addr") || strings.HasSuffix(value, "-url") || strings.HasSuffix(value, "-host") || strings.HasSuffix(value, "-port") { + return "" + } + if !regexp.MustCompile(`[a-z]`).MatchString(value) { + return "" + } + return strings.Trim(value, "-_") +} + +func runtimeWorkloadKind(kind string) bool { + switch strings.ToLower(kind) { + case "deployment", "statefulset", "daemonset", "job", "cronjob", "pod", "service": + return true + default: + return false + } +} + +func protocolFromValue(value string) string { + lower := strings.ToLower(value) + switch { + case strings.Contains(lower, "redis"): + return "redis" + case strings.HasPrefix(lower, "http://"), strings.HasPrefix(lower, "https://"): + return "http" + case strings.Contains(lower, ":"): + return "grpc" + default: + return "uses" + } +} + +func infrastructureKind(resourceType string) (string, string) { + lower := strings.ToLower(resourceType) + switch { + case strings.Contains(lower, "redis"), strings.Contains(lower, "memcache"), strings.Contains(lower, "cache"): + return "datastore", "Cache" + case strings.Contains(lower, "sql"), strings.Contains(lower, "database"), strings.Contains(lower, "spanner"), strings.Contains(lower, "alloydb"), strings.Contains(lower, "postgres"), strings.Contains(lower, "mysql"): + return "datastore", "Database" + case strings.Contains(lower, "queue"), strings.Contains(lower, "pubsub"), strings.Contains(lower, "topic"), strings.Contains(lower, "subscription"): + return "queue", "Messaging" + case strings.Contains(lower, "bucket"), strings.Contains(lower, "storage"): + return "datastore", "Object Storage" + default: + return "", "" + } +} diff --git 
a/internal/watch/enrich/enrichers/runtime/runtime_test.go b/internal/watch/enrich/enrichers/runtime/runtime_test.go new file mode 100644 index 0000000..ed74103 --- /dev/null +++ b/internal/watch/enrich/enrichers/runtime/runtime_test.go @@ -0,0 +1,25 @@ +package runtimeenrich + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestRuntimeManifests(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "kubernetes manifest matches workload name", + Enricher: RuntimeManifests(), + Input: enrich.FileInput{ + RelPath: "k8s/frontend.yaml", + Language: "yaml", + Source: []byte(`apiVersion: apps/v1 +kind: Deployment +metadata: + name: frontend +`), + }, + Want: enrichertest.Fact{Type: "runtime.component", Tag: "runtime:kubernetes", Name: "frontend"}, + }) +} diff --git a/internal/watch/enrich/enrichers/secrets/secrets.go b/internal/watch/enrich/enrichers/secrets/secrets.go new file mode 100644 index 0000000..fb5c779 --- /dev/null +++ b/internal/watch/enrich/enrichers/secrets/secrets.go @@ -0,0 +1,49 @@ +package secrets + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + providers := []struct { + id string + name string + token string + }{ + {"aws_secrets_manager", "AWS Secrets Manager", "secretsmanager"}, + {"aws_ssm", "AWS SSM Parameter Store", "ssm.get_parameter"}, + {"gcp_secret_manager", "GCP Secret Manager", "google.cloud.secretmanager"}, + {"azure_key_vault", "Azure Key Vault", "vault.azure.net"}, + {"kubernetes_secrets", "Kubernetes Secrets", "secretKeyRef"}, + {"vault", "Vault", "vault.hashicorp.com"}, + {"doppler", "Doppler", "DOPPLER_TOKEN"}, + {"onepassword", "1Password Secrets Automation", "OP_SERVICE_ACCOUNT_TOKEN"}, + } + var specs []pattern.Spec + for _, provider := range providers { + specs = append(specs, + spec("secrets.code."+provider.id, provider.name+" code reference", "go", provider.token), + spec("secrets.config."+provider.id, provider.name+" config reference", "yaml", provider.token), + spec("secrets.iac."+provider.id, provider.name+" IaC reference", "hcl", provider.token), + ) + } + return specs +} + +func spec(id, name, language, token string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "secrets", + Languages: []string{language}, + Mode: enrich.ActivationAlways, + FactType: "secret.provider", + Relationship: "uses_secret", + SourceTokens: []string{token}, + Tags: []string{"secrets:" + id}, + Attributes: map[string]string{"surface": language}, + } +} diff --git a/internal/watch/enrich/enrichers/secrets/secrets_test.go b/internal/watch/enrich/enrichers/secrets/secrets_test.go new file mode 100644 index 0000000..0cedaa1 --- /dev/null +++ b/internal/watch/enrich/enrichers/secrets/secrets_test.go @@ -0,0 +1,32 @@ +package secrets + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestSecretEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "deploy/secrets", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:secrets", 
Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/storage/storage.go b/internal/watch/enrich/enrichers/storage/storage.go new file mode 100644 index 0000000..7440a81 --- /dev/null +++ b/internal/watch/enrich/enrichers/storage/storage.go @@ -0,0 +1,53 @@ +package storage + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.redis", "TypeScript Redis/ioredis", "typescript", "ioredis", "Redis(", "cache.key", "caches"), + spec("ts.mongodb", "TypeScript MongoDB driver", "typescript", "mongodb", "MongoClient", "storage.collection", "reads_from"), + spec("ts.elasticsearch", "TypeScript Elasticsearch client", "typescript", "@elastic/elasticsearch", "Client(", "storage.index", "indexes"), + spec("ts.opensearch", "TypeScript OpenSearch client", "typescript", "@opensearch-project/opensearch", "OpenSearch", "storage.index", "indexes"), + spec("go.redis", "Go go-redis", "go", "github.com/redis/go-redis", "redis.NewClient", "cache.key", "caches"), + spec("go.mongodb", "MongoDB Go driver", "go", "go.mongodb.org/mongo-driver", "mongo.Connect", "storage.collection", "reads_from"), + spec("go.elasticsearch", "Go Elasticsearch client", "go", "github.com/elastic/go-elasticsearch", "elasticsearch.NewClient", "storage.index", "indexes"), + spec("go.opensearch", "Go OpenSearch client", "go", "github.com/opensearch-project/opensearch-go", "opensearch.NewClient", "storage.index", "indexes"), + spec("python.redis", "Python redis-py", "python", "redis", "redis.Redis", "cache.key", "caches"), + spec("python.pymongo", "Python PyMongo", "python", "pymongo", "MongoClient", "storage.collection", "reads_from"), + spec("python.elasticsearch", "Python Elasticsearch", "python", "elasticsearch", "Elasticsearch(", "storage.index", "indexes"), + spec("python.opensearch", "Python opensearch-py", "python", "opensearch-py", "OpenSearch(", "storage.index", "indexes"), + spec("java.lettuce", "Java Lettuce", "java", "io.lettuce", "RedisClient", "cache.key", "caches"), + spec("java.jedis", "Java Jedis", "java", "redis.clients", "Jedis", "cache.key", "caches"), + spec("java.mongodb", "MongoDB Java driver", "java", "mongodb-driver", "MongoClient", "storage.collection", "reads_from"), + spec("java.elasticsearch", "Elasticsearch Java client", "java", "co.elastic.clients", "ElasticsearchClient", "storage.index", "indexes"), + spec("java.opensearch", "OpenSearch Java client", "java", "org.opensearch.client", "OpenSearchClient", "storage.index", "indexes"), + spec("rust.redis", "Rust redis", "rust", "redis", "redis::", "cache.key", "caches"), + spec("rust.mongodb", "Rust mongodb", "rust", "mongodb", "mongodb::", "storage.collection", "reads_from"), + spec("rust.elasticsearch", "Rust elasticsearch", "rust", "elasticsearch", "elasticsearch::", "storage.index", "indexes"), + spec("rust.opensearch", "Rust opensearch", "rust", "opensearch", "opensearch::", "storage.index", "indexes"), + spec("cpp.redis_plus_plus", "C++ redis-plus-plus", "cpp", "redis-plus-plus", "sw::redis", "cache.key", "caches"), + spec("cpp.mongodb", "MongoDB C++ driver", "cpp", "mongo-cxx-driver", 
"mongocxx::", "storage.collection", "reads_from"), + spec("cpp.elasticsearch_http", "C++ Elasticsearch/OpenSearch HTTP", "cpp", "elasticsearch", "elasticsearch", "storage.index", "indexes"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "storage", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + Tags: []string{"storage:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/storage/storage_test.go b/internal/watch/enrich/enrichers/storage/storage_test.go new file mode 100644 index 0000000..c940111 --- /dev/null +++ b/internal/watch/enrich/enrichers/storage/storage_test.go @@ -0,0 +1,33 @@ +package storage + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestStorageEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/storage", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:storage", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/traffic/python/traffic.go b/internal/watch/enrich/enrichers/traffic/python/traffic.go new file mode 100644 index 0000000..374fe7a --- /dev/null +++ b/internal/watch/enrich/enrichers/traffic/python/traffic.go @@ -0,0 +1,43 @@ +package python + +import ( + "context" + "regexp" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type ActivationSignal = enrich.ActivationSignal +type Enricher = enrich.Enricher +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata +type RoutePattern = enrich.RoutePattern + +const ( + ActivationImportOrDependency = enrich.ActivationImportOrDependency + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +var matchLanguages = enrich.MatchLanguages + +func PythonLocust() Enricher { + return enrich.NewEnricher( + Metadata{ + ID: "python.locust", Name: "Locust HTTP traffic", Mode: ActivationImportOrDependency, + Triggers: []ActivationSignal{{Kind: SignalImport, Value: "locust"}, {Kind: SignalDependency, Value: "locust"}}, + }, + matchLanguages("python"), + func(ctx context.Context, input FileInput, emit FactEmitter) error { + return enrich.EmitMatches(input, emit, []*RoutePattern{{ + Re: regexp.MustCompile(`\bclient\.(get|post|put|delete|patch)\(\s*["']([^"']+)["']`), + FactType: "http.client", + Framework: "locust", + MethodGroup: 1, + PathGroup: 2, + Tags: []string{"http:client", "framework:locust"}, + }}) + }, + ) +} diff --git a/internal/watch/enrich/enrichers/traffic/python/traffic_test.go 
b/internal/watch/enrich/enrichers/traffic/python/traffic_test.go new file mode 100644 index 0000000..69262d2 --- /dev/null +++ b/internal/watch/enrich/enrichers/traffic/python/traffic_test.go @@ -0,0 +1,22 @@ +package python + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestPythonTrafficEnrichers(t *testing.T) { + enrichertest.Run(t, enrichertest.Case{ + Name: "locust client call requires activation and matches request", + Enricher: PythonLocust(), + Input: enrich.FileInput{ + RelPath: "load_test.py", + Language: "python", + Source: []byte(`self.client.get("/checkout")`), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: "locust"}}, + Want: enrichertest.Fact{Type: "http.client", Tag: "framework:locust", Name: "GET /checkout"}, + }) +} diff --git a/internal/watch/enrich/enrichers/web3/web3.go b/internal/watch/enrich/enrichers/web3/web3.go new file mode 100644 index 0000000..80a4ea5 --- /dev/null +++ b/internal/watch/enrich/enrichers/web3/web3.go @@ -0,0 +1,35 @@ +package web3 + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("ts.ethers", "TypeScript ethers.js", "typescript", "ethers", "JsonRpcProvider", "web3.rpc_endpoint", "connects_to_chain"), + spec("ts.web3js", "TypeScript web3.js", "typescript", "web3", "new Web3", "web3.rpc_endpoint", "connects_to_chain"), + spec("python.web3py", "Python web3.py", "python", "web3", "Web3.HTTPProvider", "web3.rpc_endpoint", "connects_to_chain"), + spec("solidity.foundry", "Foundry", "toml", "forge-std", "foundry.toml", "web3.chain_id", "connects_to_chain"), + spec("ts.hardhat", "Hardhat", "typescript", "hardhat", "hardhat.config", "web3.chain_id", "connects_to_chain"), + } +} + +func spec(id, name, language, dependency, token, factType, relationship string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "web3", + Languages: []string{language}, + Mode: enrich.ActivationImportOrDependency, + Triggers: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: dependency}, {Kind: enrich.SignalImport, Value: dependency}}, + FactType: factType, + Relationship: relationship, + SourceTokens: []string{token}, + PathTokens: []string{token}, + Tags: []string{"web3:" + id}, + Attributes: map[string]string{"dependency": dependency, "language": language}, + } +} diff --git a/internal/watch/enrich/enrichers/web3/web3_test.go b/internal/watch/enrich/enrichers/web3/web3_test.go new file mode 100644 index 0000000..cd309f8 --- /dev/null +++ b/internal/watch/enrich/enrichers/web3/web3_test.go @@ -0,0 +1,33 @@ +package web3 + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestWeb3Enrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: "src/chain", + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Signals: []enrich.ActivationSignal{{Kind: enrich.SignalDependency, Value: spec.Triggers[0].Value}}, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:web3", Name: spec.Name, Attribute: "technology", 
AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichers/workspace/workspace.go b/internal/watch/enrich/enrichers/workspace/workspace.go new file mode 100644 index 0000000..ec5e54e --- /dev/null +++ b/internal/watch/enrich/enrichers/workspace/workspace.go @@ -0,0 +1,38 @@ +package workspace + +import ( + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichers/pattern" +) + +func All() []enrich.Enricher { return pattern.FromSpecs(Specs()) } + +func Specs() []pattern.Spec { + return []pattern.Spec{ + spec("workspace.nx", "Nx", "json", []string{"nx.json"}, []string{"\"projects\""}), + spec("workspace.turborepo", "Turborepo", "json", []string{"turbo.json"}, []string{"\"pipeline\""}), + spec("workspace.pnpm", "pnpm workspaces", "yaml", []string{"pnpm-workspace.yaml"}, []string{"packages:"}), + spec("workspace.yarn", "Yarn workspaces", "json", nil, []string{"\"workspaces\""}), + spec("workspace.bazel", "Bazel", "bazel", []string{"WORKSPACE", "BUILD.bazel"}, []string{"bazel_dep("}), + spec("workspace.gradle", "Gradle multi-project", "gradle", []string{"settings.gradle"}, []string{"include("}), + spec("workspace.maven", "Maven modules", "xml", []string{"pom.xml"}, []string{""}), + spec("workspace.cargo", "Cargo workspace", "toml", []string{"Cargo.toml"}, []string{"[workspace]"}), + spec("workspace.go", "Go workspaces", "go-work", []string{"go.work"}, []string{"use ("}), + } +} + +func spec(id, name, language string, pathTokens, sourceTokens []string) pattern.Spec { + return pattern.Spec{ + ID: id, + Name: name, + Category: "workspace", + Languages: []string{language}, + Mode: enrich.ActivationAlways, + FactType: "workspace.package", + Relationship: "contains", + SourceTokens: sourceTokens, + PathTokens: pathTokens, + Tags: []string{"workspace:" + id}, + Attributes: map[string]string{"tool": id}, + } +} diff --git a/internal/watch/enrich/enrichers/workspace/workspace_test.go b/internal/watch/enrich/enrichers/workspace/workspace_test.go new file mode 100644 index 0000000..ef525f6 --- /dev/null +++ b/internal/watch/enrich/enrichers/workspace/workspace_test.go @@ -0,0 +1,36 @@ +package workspace + +import ( + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/enrichertest" +) + +func TestWorkspaceEnrichers(t *testing.T) { + byID := enrichersByID() + for _, spec := range Specs() { + relPath := "workspace.config" + if len(spec.PathTokens) > 0 { + relPath = spec.PathTokens[0] + } + enrichertest.Run(t, enrichertest.Case{ + Name: spec.ID, + Enricher: byID[spec.ID], + Input: enrich.FileInput{ + RelPath: relPath, + Language: spec.Languages[0], + Source: []byte(spec.SourceTokens[0]), + }, + Want: enrichertest.Fact{Type: spec.FactType, Tag: "category:workspace", Name: spec.Name, Attribute: "technology", AttrValue: spec.Name}, + }) + } +} + +func enrichersByID() map[string]enrich.Enricher { + out := map[string]enrich.Enricher{} + for _, enricher := range All() { + out[enricher.Metadata().ID] = enricher + } + return out +} diff --git a/internal/watch/enrich/enrichertest/enrichertest.go b/internal/watch/enrich/enrichertest/enrichertest.go new file mode 100644 index 0000000..5935892 --- /dev/null +++ b/internal/watch/enrich/enrichertest/enrichertest.go @@ -0,0 +1,87 @@ +package enrichertest + 
+import ( + "context" + "slices" + "strings" + "testing" + + "github.com/mertcikla/tld/internal/watch/enrich" +) + +type Case struct { + Name string + Enricher enrich.Enricher + Input enrich.FileInput + Signals []enrich.ActivationSignal + Want Fact +} + +type Fact struct { + Type string + Tag string + Name string + Attribute string + AttrValue string + StablePart string +} + +func Run(t *testing.T, cases ...Case) { + t.Helper() + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + t.Helper() + input := tc.Input + if input.RelPath == "" { + input.RelPath = "snippet" + } + input.Signals = nil + + registry := enrich.NewRegistry(tc.Enricher) + if tc.Enricher.Metadata().Mode == enrich.ActivationImportOrDependency && len(tc.Signals) > 0 { + facts, _, err := registry.EnrichFile(context.Background(), input) + if err != nil { + t.Fatalf("enrich without signals: %v", err) + } + if len(facts) != 0 { + t.Fatalf("expected enricher to stay inactive without activation signals, got %+v", facts) + } + } + + input.Signals = tc.Signals + facts, _, err := registry.EnrichFile(context.Background(), input) + if err != nil { + t.Fatalf("enrich with signals: %v", err) + } + if !hasFact(facts, tc.Want) { + t.Fatalf("missing expected fact %+v in %+v", tc.Want, facts) + } + }) + } +} + +func hasFact(facts []enrich.Fact, want Fact) bool { + for _, fact := range facts { + if want.Type != "" && fact.Type != want.Type { + continue + } + if want.Tag != "" && !contains(fact.Tags, want.Tag) { + continue + } + if want.Name != "" && fact.Name != want.Name { + continue + } + if want.Attribute != "" && fact.Attributes[want.Attribute] != want.AttrValue { + continue + } + if want.StablePart != "" && !strings.Contains(fact.StableKey, want.StablePart) { + continue + } + return true + } + return false +} + +func contains(values []string, want string) bool { + return slices.Contains(values, want) +} diff --git a/internal/watch/enrich/helpers.go b/internal/watch/enrich/helpers.go new file mode 100644 index 0000000..7d619af --- /dev/null +++ b/internal/watch/enrich/helpers.go @@ -0,0 +1,220 @@ +package enrich + +import ( + "context" + "fmt" + "path" + "regexp" + "strings" + + "github.com/mertcikla/tld/internal/analyzer" +) + +type enricherFunc struct { + meta Metadata + match func(FileInput) bool + run func(context.Context, FileInput, FactEmitter) error +} + +func NewEnricher(meta Metadata, match func(FileInput) bool, run func(context.Context, FileInput, FactEmitter) error) Enricher { + return enricherFunc{meta: meta, match: match, run: run} +} + +func (e enricherFunc) Metadata() Metadata { return e.meta } +func (e enricherFunc) MatchFile(input FileInput) bool { + if e.match == nil { + return true + } + return e.match(input) +} +func (e enricherFunc) EnrichFile(ctx context.Context, input FileInput, emit FactEmitter) error { + if e.run == nil { + return nil + } + return e.run(ctx, input, emit) +} + +type RoutePattern struct { + Re *regexp.Regexp + FactType string + Method string + Framework string + MethodGroup int + PathGroup int + Tags []string + Custom func([]string) (string, map[string]string, []string) +} + +func RouteRegexEnricher(id, name, languages string, triggers []ActivationSignal, patterns []*RoutePattern) Enricher { + return enricherFunc{ + meta: Metadata{ID: id, Name: name, Mode: ActivationImportOrDependency, Triggers: triggers}, + match: func(input FileInput) bool { + allowed := strings.Split(languages, ",") + return matchLanguages(allowed...)(input) + }, + run: func(ctx context.Context, input FileInput, 
emit FactEmitter) error { + return emitMatches(input, emit, patterns) + }, + } +} + +func EmitMatches(input FileInput, emit FactEmitter, patterns []*RoutePattern) error { + source := string(input.Source) + for _, pattern := range patterns { + matches := pattern.Re.FindAllStringSubmatchIndex(source, -1) + for _, indexes := range matches { + match := submatches(source, indexes) + line := lineForOffset(source, indexes[0]) + factType := pattern.FactType + if factType == "" { + factType = "http.route" + } + name, attrs, tags := routeFactValues(pattern, match) + if name == "" { + continue + } + subject := subjectForLine(input, line) + key := fmt.Sprintf("%s:%s:%s:%s:%d", factType, pattern.Framework, input.RelPath, name, line) + if err := emit.EmitFact(Fact{ + Type: factType, + StableKey: key, + Subject: subject, + Object: SubjectRef{Kind: factType, StableKey: factType + ":" + pattern.Framework + ":" + name, FilePath: input.RelPath, Name: name}, + Relationship: "declares", + Source: SourceSpan{FilePath: input.RelPath, StartLine: line, EndLine: line}, + Confidence: 0.90, + Name: name, + Tags: tags, + Attributes: attrs, + VisibilityHints: map[string]float64{ + "high_signal": 1, + }, + }); err != nil { + return err + } + } + } + return nil +} + +func emitMatches(input FileInput, emit FactEmitter, patterns []*RoutePattern) error { + return EmitMatches(input, emit, patterns) +} + +func routeFactValues(pattern *RoutePattern, match []string) (string, map[string]string, []string) { + if pattern.Custom != nil { + return pattern.Custom(match) + } + method := strings.ToUpper(pattern.Method) + routePath := "" + if pattern.PathGroup > 0 && pattern.PathGroup < len(match) { + routePath = match[pattern.PathGroup] + } else if len(match) > 1 { + routePath = match[1] + } + if pattern.MethodGroup > 0 && pattern.MethodGroup < len(match) { + method = strings.ToUpper(match[pattern.MethodGroup]) + } + attrs := map[string]string{"framework": pattern.Framework, "path": routePath} + name := routePath + if method != "" { + attrs["method"] = method + name = method + " " + routePath + } + tags := append([]string{}, pattern.Tags...) + if len(tags) == 0 { + tags = []string{"http:route"} + } + if pattern.Framework != "" { + tags = append(tags, "framework:"+pattern.Framework) + } + return name, attrs, tags +} + +func matchLanguages(languages ...string) func(FileInput) bool { + return MatchLanguages(languages...) 
+} + +func MatchLanguages(languages ...string) func(FileInput) bool { + allowed := map[string]struct{}{} + for _, language := range languages { + language = strings.TrimSpace(strings.ToLower(language)) + if language != "" { + allowed[language] = struct{}{} + } + } + return func(input FileInput) bool { + _, ok := allowed[strings.ToLower(input.Language)] + return ok + } +} + +func subjectForLine(input FileInput, line int) SubjectRef { + return SubjectForLine(input, line) +} + +func SubjectForLine(input FileInput, line int) SubjectRef { + if input.Parsed != nil { + for _, sym := range input.Parsed.Symbols { + end := sym.EndLine + if end <= 0 { + end = sym.Line + } + if sym.Line <= line && end >= line { + return SubjectRef{ + Kind: "symbol", + StableKey: symbolStableKey(input.Language, input.RelPath, sym), + FilePath: input.RelPath, + Name: symbolQualifiedName(sym), + } + } + } + } + return fileSubject(input.RelPath) +} + +func fileSubject(relPath string) SubjectRef { + return FileSubject(relPath) +} + +func FileSubject(relPath string) SubjectRef { + return SubjectRef{Kind: "file", StableKey: "file:" + relPath, FilePath: relPath, Name: path.Base(relPath)} +} + +func symbolStableKey(language, relPath string, sym analyzer.Symbol) string { + return fmt.Sprintf("%s:%s:%s:%s", language, relPath, sym.Kind, symbolQualifiedName(sym)) +} + +func symbolQualifiedName(sym analyzer.Symbol) string { + if sym.Parent == "" { + return sym.Name + } + return sym.Parent + "." + sym.Name +} + +func submatches(source string, indexes []int) []string { + return Submatches(source, indexes) +} + +func Submatches(source string, indexes []int) []string { + out := make([]string, 0, len(indexes)/2) + for i := 0; i < len(indexes); i += 2 { + if indexes[i] < 0 || indexes[i+1] < 0 { + out = append(out, "") + continue + } + out = append(out, source[indexes[i]:indexes[i+1]]) + } + return out +} + +func lineForOffset(source string, offset int) int { + return LineForOffset(source, offset) +} + +func LineForOffset(source string, offset int) int { + if offset < 0 { + return 1 + } + return strings.Count(source[:offset], "\n") + 1 +} diff --git a/internal/watch/enrich/registry.go b/internal/watch/enrich/registry.go new file mode 100644 index 0000000..d574dd7 --- /dev/null +++ b/internal/watch/enrich/registry.go @@ -0,0 +1,147 @@ +package enrich + +import ( + "context" + "fmt" + "sort" + "strings" +) + +type Registry struct { + enrichers []Enricher +} + +func NewRegistry(enrichers ...Enricher) *Registry { + r := &Registry{} + for _, enricher := range enrichers { + r.Register(enricher) + } + return r +} + +func (r *Registry) Register(enricher Enricher) { + if r == nil || enricher == nil { + return + } + r.enrichers = append(r.enrichers, enricher) + sort.SliceStable(r.enrichers, func(i, j int) bool { + return r.enrichers[i].Metadata().ID < r.enrichers[j].Metadata().ID + }) +} + +func (r *Registry) EnrichFile(ctx context.Context, input FileInput) ([]Fact, []Warning, error) { + if r == nil { + return nil, nil, nil + } + var facts []Fact + var warnings []Warning + for _, enricher := range r.enrichers { + meta := enricher.Metadata() + if strings.TrimSpace(meta.ID) == "" { + continue + } + if !r.active(meta, input.Signals) || !enricher.MatchFile(input) { + continue + } + emitter := &collector{enricher: meta.ID} + if err := enricher.EnrichFile(ctx, input, emitter); err != nil { + return nil, nil, fmt.Errorf("%s enrich %s: %w", meta.ID, input.RelPath, err) + } + facts = append(facts, emitter.facts...) 
+ warnings = append(warnings, emitter.warnings...) + } + sort.SliceStable(facts, func(i, j int) bool { + if facts[i].Enricher == facts[j].Enricher { + return facts[i].StableKey < facts[j].StableKey + } + return facts[i].Enricher < facts[j].Enricher + }) + return facts, warnings, nil +} + +func (r *Registry) active(meta Metadata, signals []ActivationSignal) bool { + switch meta.Mode { + case "", ActivationAlways: + return true + case ActivationImportOrDependency: + for _, trigger := range meta.Triggers { + for _, signal := range signals { + if signalMatches(trigger, signal) { + return true + } + } + } + } + return false +} + +func signalMatches(trigger, signal ActivationSignal) bool { + if trigger.Kind != "" && trigger.Kind != signal.Kind { + return false + } + triggerValue := strings.TrimSpace(trigger.Value) + signalValue := strings.TrimSpace(signal.Value) + if triggerValue == "" || signalValue == "" { + return false + } + return signalValue == triggerValue || strings.HasPrefix(signalValue, triggerValue+"/") +} + +type collector struct { + enricher string + facts []Fact + warnings []Warning +} + +func (c *collector) EmitFact(fact Fact) error { + fact.Enricher = strings.TrimSpace(fact.Enricher) + if fact.Enricher == "" { + fact.Enricher = c.enricher + } + fact.Type = strings.TrimSpace(fact.Type) + fact.StableKey = strings.TrimSpace(fact.StableKey) + if fact.Type == "" { + return fmt.Errorf("fact type is required") + } + if fact.StableKey == "" { + return fmt.Errorf("fact stable key is required") + } + if fact.Confidence <= 0 { + fact.Confidence = 1 + } + if fact.Attributes == nil { + fact.Attributes = map[string]string{} + } + fact.Relationship = strings.TrimSpace(fact.Relationship) + if fact.VisibilityHints == nil { + fact.VisibilityHints = map[string]float64{} + } + fact.Tags = normalizeTags(fact.Tags) + c.facts = append(c.facts, fact) + return nil +} + +func (c *collector) Warn(warning Warning) { + if warning.Enricher == "" { + warning.Enricher = c.enricher + } + c.warnings = append(c.warnings, warning) +} + +func normalizeTags(tags []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(tags)) + for _, tag := range tags { + tag = strings.ToLower(strings.TrimSpace(tag)) + if tag == "" { + continue + } + if _, ok := seen[tag]; ok { + continue + } + seen[tag] = struct{}{} + out = append(out, tag) + } + sort.Strings(out) + return out +} diff --git a/internal/watch/enrich/registry_test.go b/internal/watch/enrich/registry_test.go new file mode 100644 index 0000000..90db3fa --- /dev/null +++ b/internal/watch/enrich/registry_test.go @@ -0,0 +1,363 @@ +package enrich_test + +import ( + "context" + "os" + "path/filepath" + "slices" + "strings" + "testing" + + "github.com/mertcikla/tld/internal/analyzer" + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/defaults" + goroutes "github.com/mertcikla/tld/internal/watch/enrich/enrichers/routes/golang" +) + +type ActivationSignal = enrich.ActivationSignal +type Fact = enrich.Fact +type FactEmitter = enrich.FactEmitter +type FileInput = enrich.FileInput +type Metadata = enrich.Metadata + +const ( + ActivationAlways = enrich.ActivationAlways + SignalDependency = enrich.SignalDependency + SignalImport = enrich.SignalImport +) + +var ( + DefaultEnrichers = defaults.DefaultEnrichers + GoChi = goroutes.GoChi + ImportSignals = enrich.ImportSignals + DiscoverSignals = enrich.DiscoverRepositorySignalsFromFiles + NewDefaultRegistry = defaults.NewRegistry + NewRegistry = 
enrich.NewRegistry +) + +func TestRegistryActivatesImportGatedEnrichers(t *testing.T) { + source := []byte(`package main + +import "github.com/go-chi/chi/v5" + +func routes(r chi.Router) { + r.Get("/users/{id}", getUser) +} + +func getUser() {} +`) + input := FileInput{ + RelPath: "routes.go", + Language: "go", + Source: source, + Parsed: &analyzer.Result{Refs: []analyzer.Ref{{ + Kind: "import", + TargetPath: "github.com/go-chi/chi/v5", + FilePath: "routes.go", + Line: 3, + }}}, + } + + withoutSignals, _, err := NewRegistry(GoChi()).EnrichFile(context.Background(), input) + if err != nil { + t.Fatalf("enrich without signals: %v", err) + } + if len(withoutSignals) != 0 { + t.Fatalf("expected inactive chi enricher without signals, got %+v", withoutSignals) + } + + input.Signals = ImportSignals(input.Parsed.Refs) + withSignals, _, err := NewRegistry(GoChi()).EnrichFile(context.Background(), input) + if err != nil { + t.Fatalf("enrich with signals: %v", err) + } + if len(withSignals) != 1 || withSignals[0].Type != "http.route" || !containsTag(withSignals[0].Tags, "framework:chi") { + t.Fatalf("expected chi route fact, got %+v", withSignals) + } +} + +func TestRegistryRejectsInvalidFacts(t *testing.T) { + bad := enrich.NewEnricher( + Metadata{ID: "bad", Mode: ActivationAlways}, + nil, + func(ctx context.Context, input FileInput, emit FactEmitter) error { + return emit.EmitFact(Fact{Type: "demo.fact"}) + }, + ) + _, _, err := NewRegistry(bad).EnrichFile(context.Background(), FileInput{RelPath: "demo.go"}) + if err == nil || !strings.Contains(err.Error(), "stable key") { + t.Fatalf("expected stable key validation error, got %v", err) + } +} + +func TestDefaultEnrichersHaveUniqueIDs(t *testing.T) { + seen := map[string]struct{}{} + for _, enricher := range DefaultEnrichers() { + meta := enricher.Metadata() + if strings.TrimSpace(meta.ID) == "" { + t.Fatalf("default enricher has empty ID: %+v", meta) + } + if _, ok := seen[meta.ID]; ok { + t.Fatalf("default enricher ID registered more than once: %s", meta.ID) + } + seen[meta.ID] = struct{}{} + } +} + +func TestDefaultEnrichersIncludeExpandedCatalog(t *testing.T) { + enrichers := DefaultEnrichers() + if len(enrichers) < 360 || len(enrichers) > 430 { + t.Fatalf("expected expanded default catalog, got %d", len(enrichers)) + } + want := []string{ + "ts.process_env", + "python.httpx", + "java.spring_web", + "rust.axum", + "cpp.drogon", + "python.sqlalchemy", + "rust.tonic", + "ts.kafkajs", + "go.aws_sdk_v2", + "java.opensearch", + "iac.terraform", + "ts.opentelemetry", + "go.jwt", + "ts.bullmq", + "apispec.openapi", + "deployment.github_actions", + "secrets.code.aws_secrets_manager", + "workspace.nx", + "python.openai", + "go.mqtt", + "go.unix_socket", + "python.airflow", + "ts.ethers", + "os.uri_schemes", + } + seen := map[string]struct{}{} + for _, enricher := range enrichers { + seen[enricher.Metadata().ID] = struct{}{} + } + for _, id := range want { + if _, ok := seen[id]; !ok { + t.Fatalf("default catalog missing enricher %s", id) + } + } + if _, ok := seen["generic.architecture_glue"]; ok { + t.Fatalf("generic architecture glue should not be registered alongside categorized enrichers") + } +} + +func TestDiscoverRepositorySignalsFromExpandedManifests(t *testing.T) { + root := t.TempDir() + files := map[string]string{ + "requirements.txt": "fastapi==0.110.0\nhttpx>=0.27.0\n", + "pyproject.toml": "[project]\ndependencies = [\"sqlalchemy>=2\"]\n[tool.poetry.dependencies]\ndjango = \"^5\"\n", + "Cargo.toml": "[dependencies]\naxum = \"0.7\"\ntonic = 
\"0.11\"\n", + "pom.xml": `org.springframework.bootspring-boot-starter-web`, + "build.gradle": `implementation "org.springframework.kafka:spring-kafka:3.1.0"`, + "CMakeLists.txt": "find_package(Drogon REQUIRED)\n", + "conanfile.txt": "requires = cpprestsdk/2.10.18\n", + "vcpkg.json": `{"dependencies":[{"name":"boost-beast"}]}`, + } + var paths []string + for rel, data := range files { + path := filepath.Join(root, rel) + if err := os.WriteFile(path, []byte(data), 0o644); err != nil { + t.Fatal(err) + } + paths = append(paths, path) + } + signals := DiscoverSignals(root, paths) + for _, want := range []string{"fastapi", "httpx", "django", "axum", "tonic", "spring-boot-starter-web", "spring-kafka", "Drogon", "cpprestsdk", "boost-beast"} { + if !hasSignal(signals, want) { + t.Fatalf("missing dependency signal %q in %+v", want, signals) + } + } +} + +func TestDefaultRegistryEmitsDemoFacts(t *testing.T) { + tests := []struct { + name string + input FileInput + signals []ActivationSignal + wantType string + wantTag string + }{ + { + name: "express route", + input: FileInput{ + RelPath: "server.ts", + Language: "typescript", + Source: []byte(`router.get("/api/users", listUsers)`), + }, + signals: []ActivationSignal{{Kind: SignalDependency, Value: "express"}}, + wantType: "http.route", + wantTag: "framework:express", + }, + { + name: "next page route", + input: FileInput{ + RelPath: "src/app/users/[id]/page.tsx", + Language: "typescript", + Source: []byte(`export default function Page() { return null }`), + }, + signals: []ActivationSignal{{Kind: SignalDependency, Value: "next"}}, + wantType: "frontend.route", + wantTag: "framework:nextjs", + }, + { + name: "prisma query", + input: FileInput{ + RelPath: "db.ts", + Language: "typescript", + Source: []byte(`await prisma.user.findMany()`), + }, + signals: []ActivationSignal{{Kind: SignalDependency, Value: "@prisma/client"}}, + wantType: "orm.query", + wantTag: "orm:prisma", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.input.Signals = tt.signals + facts, _, err := NewDefaultRegistry().EnrichFile(context.Background(), tt.input) + if err != nil { + t.Fatal(err) + } + if !hasFact(facts, tt.wantType, tt.wantTag) { + t.Fatalf("missing %s/%s in %+v", tt.wantType, tt.wantTag, facts) + } + }) + } +} + +func TestDefaultRegistryEmitsArchitectureGlueFacts(t *testing.T) { + tests := []struct { + name string + input FileInput + signals []ActivationSignal + wantType string + wantTag string + }{ + { + name: "go grpc client", + input: FileInput{ + RelPath: "src/frontend/rpc.go", + Language: "go", + Source: []byte(`package main +func f() { _ = pb.NewCartServiceClient(conn).GetCart(ctx, req) }`), + }, + signals: []ActivationSignal{{Kind: SignalImport, Value: "google.golang.org/grpc"}}, + wantType: "grpc.client", + wantTag: "grpc:client", + }, + { + name: "python grpc server", + input: FileInput{ + RelPath: "src/emailservice/email_server.py", + Language: "python", + Source: []byte(`demo_pb2_grpc.add_EmailServiceServicer_to_server(service, server)`)}, + signals: []ActivationSignal{{Kind: SignalImport, Value: "grpc"}}, + wantType: "grpc.server", + wantTag: "grpc:server", + }, + { + name: "node grpc server", + input: FileInput{ + RelPath: "src/paymentservice/server.js", + Language: "javascript", + Source: []byte(`this.server.addService(hipsterShopPackage.PaymentService.service, { charge })`)}, + signals: []ActivationSignal{{Kind: SignalDependency, Value: "@grpc/grpc-js"}}, + wantType: "grpc.server", + wantTag: "grpc:server", + }, + { + 
name: "java grpc server", + input: FileInput{ + RelPath: "src/adservice/src/main/java/hipstershop/AdService.java", + Language: "java", + Source: []byte(`class AdServiceImpl extends hipstershop.AdServiceGrpc.AdServiceImplBase {}`)}, + signals: []ActivationSignal{{Kind: SignalImport, Value: "io.grpc"}}, + wantType: "grpc.server", + wantTag: "grpc:server", + }, + { + name: "dotnet grpc server", + input: FileInput{ + RelPath: "src/cartservice/src/Startup.cs", + Language: "c-sharp", + Source: []byte(`endpoints.MapGrpcService();`)}, + signals: []ActivationSignal{{Kind: SignalDependency, Value: "Grpc.AspNetCore"}}, + wantType: "grpc.server", + wantTag: "grpc:server", + }, + { + name: "protobuf contract", + input: FileInput{ + RelPath: "protos/demo.proto", + Language: "protobuf", + Source: []byte(`service CheckoutService { rpc PlaceOrder(PlaceOrderRequest) returns (PlaceOrderResponse); }`)}, + wantType: "grpc.contract", + wantTag: "arch:contract", + }, + { + name: "runtime manifest component", + input: FileInput{ + RelPath: "kubernetes-manifests/frontend.yaml", + Language: "yaml", + Source: []byte(`apiVersion: apps/v1 +kind: Deployment +metadata: + name: frontend +spec: + template: + spec: + containers: + - image: frontend + env: + - name: CART_SERVICE_ADDR + value: cartservice:7070 +`)}, + wantType: "runtime.component", + wantTag: "runtime:kubernetes", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.input.Signals = tt.signals + facts, _, err := NewDefaultRegistry().EnrichFile(context.Background(), tt.input) + if err != nil { + t.Fatal(err) + } + if !hasFact(facts, tt.wantType, tt.wantTag) { + t.Fatalf("missing %s/%s in %+v", tt.wantType, tt.wantTag, facts) + } + }) + } +} + +func hasFact(facts []Fact, factType, tag string) bool { + for _, fact := range facts { + if fact.Type == factType && containsTag(fact.Tags, tag) { + return true + } + } + return false +} + +func containsTag(tags []string, tag string) bool { + return slices.Contains(tags, tag) +} + +func hasSignal(signals []ActivationSignal, value string) bool { + for _, signal := range signals { + if signal.Kind == SignalDependency && signal.Value == value { + return true + } + } + return false +} diff --git a/internal/watch/enrich/signals.go b/internal/watch/enrich/signals.go new file mode 100644 index 0000000..e05166c --- /dev/null +++ b/internal/watch/enrich/signals.go @@ -0,0 +1,320 @@ +package enrich + +import ( + "encoding/json" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/mertcikla/tld/internal/analyzer" +) + +var goRequireLineRE = regexp.MustCompile(`^\s*([A-Za-z0-9_./~-]+)\s+v[0-9]`) + +func DiscoverRepositorySignals(repoRoot string) []ActivationSignal { + var signals []ActivationSignal + signals = append(signals, discoverGoModSignals(filepath.Join(repoRoot, "go.mod"))...) + signals = append(signals, discoverPackageJSONSignals(repoRoot)...) + signals = append(signals, discoverComposeSignals(repoRoot)...) + return uniqueSignals(signals) +} + +func DiscoverRepositorySignalsFromFiles(repoRoot string, files []string) []ActivationSignal { + var signals []ActivationSignal + for _, file := range files { + rel, err := filepath.Rel(repoRoot, file) + if err != nil { + rel = file + } + rel = filepath.ToSlash(rel) + switch filepath.Base(file) { + case "go.mod": + signals = append(signals, discoverGoModSignals(file)...) + case "package.json": + signals = append(signals, packageJSONSignals(file, rel)...) 
+ case "requirements.txt": + signals = append(signals, lineDependencySignals(file, rel, requirementSignalName)...) + case "pyproject.toml", "poetry.lock": + signals = append(signals, lineDependencySignals(file, rel, tomlSignalName)...) + case "Cargo.toml": + signals = append(signals, lineDependencySignals(file, rel, cargoSignalName)...) + case "pom.xml": + signals = append(signals, pomSignals(file, rel)...) + case "build.gradle", "build.gradle.kts": + signals = append(signals, lineDependencySignals(file, rel, gradleSignalName)...) + case "CMakeLists.txt", "conanfile.txt", "conanfile.py", "vcpkg.json": + signals = append(signals, lineDependencySignals(file, rel, cppSignalName)...) + } + if isComposeFile(rel) { + signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: "docker-compose", Source: rel}) + } + } + return uniqueSignals(signals) +} + +func ImportSignals(refs []analyzer.Ref) []ActivationSignal { + signals := make([]ActivationSignal, 0, len(refs)) + for _, ref := range refs { + if ref.Kind != "import" || strings.TrimSpace(ref.TargetPath) == "" { + continue + } + signals = append(signals, ActivationSignal{Kind: SignalImport, Value: ref.TargetPath, Source: ref.FilePath}) + } + return uniqueSignals(signals) +} + +func discoverGoModSignals(path string) []ActivationSignal { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + var signals []ActivationSignal + for line := range strings.SplitSeq(string(data), "\n") { + match := goRequireLineRE.FindStringSubmatch(line) + if len(match) != 2 { + continue + } + signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: match[1], Source: "go.mod"}) + } + return signals +} + +func discoverPackageJSONSignals(repoRoot string) []ActivationSignal { + var signals []ActivationSignal + _ = filepath.WalkDir(repoRoot, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + if d.IsDir() { + if isSignalScanIgnoredDir(d.Name()) { + return filepath.SkipDir + } + return nil + } + if d.Name() != "package.json" { + return nil + } + rel, relErr := filepath.Rel(repoRoot, path) + if relErr != nil { + rel = path + } + signals = append(signals, packageJSONSignals(path, filepath.ToSlash(rel))...) 
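+		// keep walking: every package.json found in the repository contributes dependency signals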
+		return nil
+	})
+	return signals
+}
+
+func discoverComposeSignals(repoRoot string) []ActivationSignal {
+	var signals []ActivationSignal
+	_ = filepath.WalkDir(repoRoot, func(path string, d os.DirEntry, err error) error {
+		if err != nil {
+			return nil
+		}
+		if d.IsDir() {
+			if isSignalScanIgnoredDir(d.Name()) {
+				return filepath.SkipDir
+			}
+			return nil
+		}
+		rel, relErr := filepath.Rel(repoRoot, path)
+		if relErr != nil {
+			rel = path
+		}
+		if isComposeFile(filepath.ToSlash(rel)) {
+			signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: "docker-compose", Source: rel})
+		}
+		return nil
+	})
+	return signals
+}
+
+func isComposeFile(rel string) bool {
+	base := strings.ToLower(filepath.Base(rel))
+	switch base {
+	case "docker-compose.yml", "docker-compose.yaml", "compose.yaml", "compose.yml":
+		return true
+	default:
+		return strings.HasPrefix(base, "docker-compose.") && (strings.HasSuffix(base, ".yml") || strings.HasSuffix(base, ".yaml"))
+	}
+}
+
+func isSignalScanIgnoredDir(name string) bool {
+	switch strings.ToLower(name) {
+	case ".git", ".hg", ".svn", "node_modules", "dist", "build", ".next", ".turbo", "coverage", "vendor":
+		return true
+	default:
+		return strings.HasPrefix(name, ".")
+	}
+}
+
+func packageJSONSignals(path, rel string) []ActivationSignal {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil
+	}
+	var pkg struct {
+		Dependencies         map[string]string `json:"dependencies"`
+		DevDependencies      map[string]string `json:"devDependencies"`
+		PeerDependencies     map[string]string `json:"peerDependencies"`
+		OptionalDependencies map[string]string `json:"optionalDependencies"`
+	}
+	if err := json.Unmarshal(data, &pkg); err != nil {
+		return nil
+	}
+	var signals []ActivationSignal
+	add := func(values map[string]string) {
+		for name := range values {
+			signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: name, Source: rel})
+		}
+	}
+	add(pkg.Dependencies)
+	add(pkg.DevDependencies)
+	add(pkg.PeerDependencies)
+	add(pkg.OptionalDependencies)
+	return signals
+}
+
+func lineDependencySignals(path, rel string, parse func(string) string) []ActivationSignal {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil
+	}
+	var signals []ActivationSignal
+	for line := range strings.SplitSeq(string(data), "\n") {
+		name := parse(line)
+		if name == "" {
+			continue
+		}
+		signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: name, Source: rel})
+	}
+	return signals
+}
+
+func pomSignals(path, rel string) []ActivationSignal {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil
+	}
+	re := regexp.MustCompile(`(?s)<dependency>.*?<groupId>\s*([^<\s]+)\s*</groupId>.*?<artifactId>\s*([^<\s]+)\s*</artifactId>.*?</dependency>`)
+	var signals []ActivationSignal
+	for _, match := range re.FindAllStringSubmatch(string(data), -1) {
+		if len(match) == 3 {
+			signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: match[1] + ":" + match[2], Source: rel})
+			signals = append(signals, ActivationSignal{Kind: SignalDependency, Value: match[2], Source: rel})
+		}
+	}
+	return signals
+}
+
+func requirementSignalName(line string) string {
+	line = strings.TrimSpace(strings.Split(line, "#")[0])
+	if line == "" || strings.HasPrefix(line, "-") {
+		return ""
+	}
+	return signalPrefix(line, "=", "<", ">", "~", "!", "[", ";")
+}
+
+func tomlSignalName(line string) string {
+	line = strings.TrimSpace(strings.Split(line, "#")[0])
+	if line == "" || strings.HasPrefix(line, "[") {
+		return ""
+	}
+	if strings.HasPrefix(line, "\"") || strings.HasPrefix(line, "'") {
+		trimmed :=
strings.Trim(line, " ,") + trimmed = strings.Trim(trimmed, `"'`) + if trimmed != "" && !strings.Contains(trimmed, "=") { + return requirementSignalName(trimmed) + } + } + if idx := strings.Index(line, "="); idx > 0 { + return strings.Trim(strings.TrimSpace(line[:idx]), `"'`) + } + return "" +} + +func cargoSignalName(line string) string { + name := tomlSignalName(line) + switch name { + case "package", "dependencies", "dev-dependencies", "build-dependencies", "workspace": + return "" + default: + return name + } +} + +func gradleSignalName(line string) string { + line = strings.TrimSpace(line) + for _, quote := range []string{"\"", "'"} { + start := strings.Index(line, quote) + if start < 0 { + continue + } + rest := line[start+1:] + before, _, ok := strings.Cut(rest, quote) + if !ok { + continue + } + value := before + if strings.Count(value, ":") >= 1 { + parts := strings.Split(value, ":") + return parts[len(parts)-2] + } + } + return "" +} + +func cppSignalName(line string) string { + line = strings.TrimSpace(strings.Split(line, "#")[0]) + if line == "" { + return "" + } + for _, prefix := range []string{"find_package(", "target_link_libraries(", "requires =", "self.requires(", "\"name\":"} { + if _, after, ok := strings.Cut(line, prefix); ok { + value := strings.TrimSpace(after) + value = strings.Trim(value, ` "'),[]`) + return signalPrefix(value, " ", "/", ")", ",", "\"") + } + } + return "" +} + +func signalPrefix(value string, stops ...string) string { + value = strings.TrimSpace(value) + end := len(value) + for _, stop := range stops { + if idx := strings.Index(value, stop); idx >= 0 && idx < end { + end = idx + } + } + return strings.TrimSpace(value[:end]) +} + +func uniqueSignals(signals []ActivationSignal) []ActivationSignal { + seen := map[string]struct{}{} + out := make([]ActivationSignal, 0, len(signals)) + for _, signal := range signals { + signal.Kind = strings.TrimSpace(signal.Kind) + signal.Value = strings.TrimSpace(signal.Value) + if signal.Kind == "" || signal.Value == "" { + continue + } + key := signal.Kind + "\x00" + signal.Value + "\x00" + signal.Source + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, signal) + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].Kind == out[j].Kind { + if out[i].Value == out[j].Value { + return out[i].Source < out[j].Source + } + return out[i].Value < out[j].Value + } + return out[i].Kind < out[j].Kind + }) + return out +} diff --git a/internal/watch/enrich/types.go b/internal/watch/enrich/types.go new file mode 100644 index 0000000..212cddc --- /dev/null +++ b/internal/watch/enrich/types.go @@ -0,0 +1,89 @@ +package enrich + +import ( + "context" + + "github.com/mertcikla/tld/internal/analyzer" +) + +type ActivationMode string + +const ( + ActivationAlways ActivationMode = "always" + ActivationImportOrDependency ActivationMode = "import_or_dependency" +) + +const ( + SignalDependency = "dependency" + SignalImport = "import" +) + +type ActivationSignal struct { + Kind string + Value string + Source string +} + +type Metadata struct { + ID string + Name string + Mode ActivationMode + Triggers []ActivationSignal +} + +type SourceSpan struct { + FilePath string + StartLine int + EndLine int + StartColumn int + EndColumn int +} + +type SubjectRef struct { + Kind string + StableKey string + FilePath string + Name string +} + +type Fact struct { + Type string + StableKey string + Enricher string + Subject SubjectRef + Object SubjectRef + Relationship string + Source SourceSpan + Confidence 
float64 + Name string + Tags []string + Attributes map[string]string + VisibilityHints map[string]float64 +} + +type FactEmitter interface { + EmitFact(Fact) error + Warn(Warning) +} + +type Warning struct { + Enricher string + FilePath string + Message string +} + +type FileInput struct { + RepoRoot string + AbsPath string + RelPath string + Language string + Source []byte + Parsed *analyzer.Result + Signals []ActivationSignal +} + +type Enricher interface { + Metadata() Metadata + MatchFile(FileInput) bool + EnrichFile(context.Context, FileInput, FactEmitter) error +} diff --git a/internal/watch/exportyaml/exportyaml.go b/internal/watch/exportyaml/exportyaml.go new file mode 100644 index 0000000..1e2e47e --- /dev/null +++ b/internal/watch/exportyaml/exportyaml.go @@ -0,0 +1,421 @@ +package exportyaml + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + diagv1 "buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go/diag/v1" + "github.com/google/uuid" + "github.com/mertcikla/tld/internal/store" + watchpkg "github.com/mertcikla/tld/internal/watch" + "github.com/mertcikla/tld/internal/workspace" +) + +type Result struct { + ElementsWritten int `json:"elements_written"` + ConnectorsWritten int `json:"connectors_written"` + ViewsWritten int `json:"views_written"` +} + +func Export(ctx context.Context, sqliteStore *store.SQLiteStore, watchStore *watchpkg.Store, base *workspace.Workspace, repositoryID int64) (*workspace.Workspace, Result, error) { + return ExportWithProgress(ctx, sqliteStore, watchStore, base, repositoryID, nil) +} + +func ExportWithProgress(ctx context.Context, sqliteStore *store.SQLiteStore, watchStore *watchpkg.Store, base *workspace.Workspace, repositoryID int64, progress watchpkg.ProgressSink) (*workspace.Workspace, Result, error) { + if sqliteStore == nil || watchStore == nil { + return nil, Result{}, fmt.Errorf("export yaml requires sqlite and watch stores") + } + if base == nil { + return nil, Result{}, fmt.Errorf("export yaml requires a workspace") + } + progressStart(progress, "Exporting workspace YAML", 8) + defer progressFinish(progress) + mappings, err := watchStore.Materialization(ctx, repositoryID) + if err != nil { + return nil, Result{}, err + } + progressAdvance(progress, "Materialization loaded") + index := buildMappingIndex(mappings) + api := store.NewAPIAdapter(sqliteStore) + views, err := api.ListViews(ctx, uuid.Nil) + if err != nil { + return nil, Result{}, err + } + progressAdvance(progress, "Views loaded") + elements, _, err := api.ListElements(ctx, uuid.Nil, 0, 0, "") + if err != nil { + return nil, Result{}, err + } + progressAdvance(progress, "Elements loaded") + placements, err := api.ListAllPlacements(ctx, uuid.Nil) + if err != nil { + return nil, Result{}, err + } + progressAdvance(progress, "Placements loaded") + connectors, err := api.ListAllConnectors(ctx, uuid.Nil) + if err != nil { + return nil, Result{}, err + } + progressAdvance(progress, "Connectors loaded") + + out := cloneWorkspace(base) + removeGenerated(out, index) + progressAdvance(progress, "Previous generated YAML removed") + + elementRefByID := existingRefsByID(metaElements(base)) + viewRefByID := existingRefsByID(metaViews(base)) + connectorRefByID := existingRefsByID(metaConnectors(base)) + usedRefs := map[string]struct{}{"root": {}} + for ref := range out.Elements { + usedRefs[ref] = struct{}{} + } + + elementByID := elementsByID(elements) + for _, mapping := range sortedMappings(mappings, "element") { + elem := elementByID[int32(mapping.ResourceID)] + 
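+		// mappings can outlive their elements; skip any whose element is gone from the store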
if elem == nil { + continue + } + ref := elementRefByID[int32(mapping.ResourceID)] + if ref == "" || out.Elements[ref] != nil { + ref = uniqueRef(generatedRef(mapping), usedRefs) + } + usedRefs[ref] = struct{}{} + elementRefByID[int32(mapping.ResourceID)] = ref + out.Elements[ref] = &workspace.Element{ + Name: elem.GetName(), + Kind: defaultString(elem.GetKind(), "element"), + Description: elem.GetDescription(), + Technology: elem.GetTechnology(), + URL: elem.GetUrl(), + LogoURL: elem.GetLogoUrl(), + Repo: elem.GetRepo(), + Branch: elem.GetBranch(), + Language: elem.GetLanguage(), + FilePath: elem.GetFilePath(), + Tags: exportedAnalyzeTags(elem.GetTags()), + HasView: false, + ViewLabel: strings.TrimSpace(elem.GetViewLabel()), + } + out.Meta.Elements[ref] = &workspace.ResourceMetadata{ID: workspace.ResourceID(elem.Id), UpdatedAt: timestampTime(elem.GetUpdatedAt())} + } + progressAdvance(progress, "Elements merged") + + viewByID := viewsByID(views) + for _, mapping := range sortedMappings(mappings, "view") { + view := viewByID[int32(mapping.ResourceID)] + if view == nil || view.OwnerElementId == nil { + continue + } + ownerRef := elementRefByID[*view.OwnerElementId] + if ownerRef == "" || out.Elements[ownerRef] == nil { + continue + } + viewRefByID[int32(mapping.ResourceID)] = ownerRef + out.Elements[ownerRef].HasView = true + if out.Elements[ownerRef].ViewLabel == "" && strings.TrimSpace(view.GetLevelLabel()) != "" { + out.Elements[ownerRef].ViewLabel = strings.TrimSpace(view.GetLevelLabel()) + } + out.Meta.Views[ownerRef] = &workspace.ResourceMetadata{ID: workspace.ResourceID(view.Id), UpdatedAt: timestampTime(view.GetUpdatedAt())} + } + + for _, placement := range placements { + elemRef := elementRefByID[placement.ElementId] + if elemRef == "" || out.Elements[elemRef] == nil { + continue + } + parentRef := viewRefByID[placement.ViewId] + if parentRef == "" { + parentRef = "root" + } + out.Elements[elemRef].Placements = append(out.Elements[elemRef].Placements, workspace.ViewPlacement{ + ParentRef: parentRef, + PositionX: placement.PositionX, + PositionY: placement.PositionY, + }) + } + + connectorByID := connectorsByID(connectors) + for _, mapping := range sortedMappings(mappings, "connector") { + conn := connectorByID[int32(mapping.ResourceID)] + if conn == nil { + continue + } + viewRef := viewRefByID[conn.ViewId] + if viewRef == "" { + viewRef = "root" + } + sourceRef := elementRefByID[conn.SourceElementId] + targetRef := elementRefByID[conn.TargetElementId] + if sourceRef == "" || targetRef == "" { + continue + } + ref := connectorRefByID[int32(mapping.ResourceID)] + spec := &workspace.Connector{ + View: viewRef, + Source: sourceRef, + Target: targetRef, + Label: conn.GetLabel(), + Description: conn.GetDescription(), + Relationship: conn.GetRelationship(), + Direction: conn.GetDirection(), + Style: conn.GetStyle(), + URL: conn.GetUrl(), + SourceHandle: conn.GetSourceHandle(), + TargetHandle: conn.GetTargetHandle(), + } + if ref == "" || out.Connectors[ref] != nil { + ref = workspace.ConnectorKey(spec) + } + out.Connectors[ref] = spec + out.Meta.Connectors[ref] = &workspace.ResourceMetadata{ID: workspace.ResourceID(conn.Id), UpdatedAt: timestampTime(conn.GetUpdatedAt())} + } + progressAdvance(progress, "Connectors merged") + + return out, Result{ElementsWritten: len(index.elementIDs), ConnectorsWritten: len(index.connectorIDs), ViewsWritten: len(index.viewIDs)}, nil +} + +func progressStart(progress watchpkg.ProgressSink, label string, total int) { + if progress != nil { + 
progress.Start(label, total) + } +} + +func progressAdvance(progress watchpkg.ProgressSink, label string) { + if progress != nil { + progress.Advance(label) + } +} + +func progressFinish(progress watchpkg.ProgressSink) { + if progress != nil { + progress.Finish() + } +} + +type mappingIndex struct { + elementIDs map[int32]struct{} + viewIDs map[int32]struct{} + connectorIDs map[int32]struct{} +} + +func buildMappingIndex(mappings []watchpkg.MaterializationMapping) mappingIndex { + index := mappingIndex{elementIDs: map[int32]struct{}{}, viewIDs: map[int32]struct{}{}, connectorIDs: map[int32]struct{}{}} + for _, mapping := range mappings { + switch mapping.ResourceType { + case "element": + index.elementIDs[int32(mapping.ResourceID)] = struct{}{} + case "view": + index.viewIDs[int32(mapping.ResourceID)] = struct{}{} + case "connector": + index.connectorIDs[int32(mapping.ResourceID)] = struct{}{} + } + } + return index +} + +func removeGenerated(ws *workspace.Workspace, index mappingIndex) { + for ref, meta := range ws.Meta.Elements { + if meta != nil { + if _, ok := index.elementIDs[int32(meta.ID)]; ok { + delete(ws.Elements, ref) + delete(ws.Meta.Elements, ref) + delete(ws.Meta.Views, ref) + } + } + } + for ref, meta := range ws.Meta.Connectors { + if meta != nil { + if _, ok := index.connectorIDs[int32(meta.ID)]; ok { + delete(ws.Connectors, ref) + delete(ws.Meta.Connectors, ref) + } + } + } +} + +func cloneWorkspace(ws *workspace.Workspace) *workspace.Workspace { + out := &workspace.Workspace{ + Dir: ws.Dir, + Config: ws.Config, + WorkspaceConfig: ws.WorkspaceConfig, + Elements: map[string]*workspace.Element{}, + Connectors: map[string]*workspace.Connector{}, + Meta: ensureMeta(ws.Meta), + IgnoreRules: ws.IgnoreRules, + ActiveRepo: ws.ActiveRepo, + } + for ref, element := range ws.Elements { + copyElement := *element + copyElement.Tags = cloneStrings(element.Tags) + copyElement.Placements = append([]workspace.ViewPlacement(nil), element.Placements...) 
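+		// store the copied element so later merge steps never mutate the base workspace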
+		out.Elements[ref] = &copyElement
+	}
+	for ref, connector := range ws.Connectors {
+		copyConnector := *connector
+		out.Connectors[ref] = &copyConnector
+	}
+	return out
+}
+
+func ensureMeta(meta *workspace.Meta) *workspace.Meta {
+	out := &workspace.Meta{Elements: map[string]*workspace.ResourceMetadata{}, Views: map[string]*workspace.ResourceMetadata{}, Connectors: map[string]*workspace.ResourceMetadata{}}
+	if meta == nil {
+		return out
+	}
+	for ref, value := range meta.Elements {
+		if value == nil {
+			continue
+		}
+		copyValue := *value
+		out.Elements[ref] = &copyValue
+	}
+	for ref, value := range meta.Views {
+		if value == nil {
+			continue
+		}
+		copyValue := *value
+		out.Views[ref] = &copyValue
+	}
+	for ref, value := range meta.Connectors {
+		if value == nil {
+			continue
+		}
+		copyValue := *value
+		out.Connectors[ref] = &copyValue
+	}
+	return out
+}
+
+func existingRefsByID(meta map[string]*workspace.ResourceMetadata) map[int32]string {
+	out := map[int32]string{}
+	for ref, item := range meta {
+		if item != nil && item.ID != 0 {
+			out[int32(item.ID)] = ref
+		}
+	}
+	return out
+}
+
+func sortedMappings(mappings []watchpkg.MaterializationMapping, resourceType string) []watchpkg.MaterializationMapping {
+	var out []watchpkg.MaterializationMapping
+	for _, mapping := range mappings {
+		if mapping.ResourceType == resourceType {
+			out = append(out, mapping)
+		}
+	}
+	sort.Slice(out, func(i, j int) bool {
+		if out[i].OwnerType == out[j].OwnerType {
+			return out[i].OwnerKey < out[j].OwnerKey
+		}
+		return out[i].OwnerType < out[j].OwnerType
+	})
+	return out
+}
+
+func generatedRef(mapping watchpkg.MaterializationMapping) string {
+	base := workspace.Slugify(mapping.OwnerType + "-" + mapping.OwnerKey)
+	if base == "" {
+		base = "watch-" + strconv.FormatInt(mapping.ResourceID, 10)
+	}
+	return base
+}
+
+func uniqueRef(base string, used map[string]struct{}) string {
+	if _, ok := used[base]; !ok {
+		return base
+	}
+	for i := 2; ; i++ {
+		candidate := fmt.Sprintf("%s-%d", base, i)
+		if _, ok := used[candidate]; !ok {
+			return candidate
+		}
+	}
+}
+
+func defaultString(value, fallback string) string {
+	if strings.TrimSpace(value) == "" {
+		return fallback
+	}
+	return value
+}
+
+func cloneStrings(values []string) []string {
+	if len(values) == 0 {
+		return nil
+	}
+	return append([]string(nil), values...)
+} + +func exportedAnalyzeTags(values []string) []string { + out := make([]string, 0, len(values)) + for _, value := range values { + if strings.HasPrefix(strings.TrimSpace(value), "role:") { + out = append(out, value) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func elementsByID(items []*diagv1.Element) map[int32]*diagv1.Element { + out := map[int32]*diagv1.Element{} + for _, item := range items { + out[item.Id] = item + } + return out +} + +func viewsByID(items []*diagv1.View) map[int32]*diagv1.View { + out := map[int32]*diagv1.View{} + for _, item := range items { + out[item.Id] = item + } + return out +} + +func connectorsByID(items []*diagv1.Connector) map[int32]*diagv1.Connector { + out := map[int32]*diagv1.Connector{} + for _, item := range items { + out[item.Id] = item + } + return out +} + +func metaElements(ws *workspace.Workspace) map[string]*workspace.ResourceMetadata { + if ws.Meta == nil || ws.Meta.Elements == nil { + return nil + } + return ws.Meta.Elements +} + +func metaViews(ws *workspace.Workspace) map[string]*workspace.ResourceMetadata { + if ws.Meta == nil || ws.Meta.Views == nil { + return nil + } + return ws.Meta.Views +} + +func metaConnectors(ws *workspace.Workspace) map[string]*workspace.ResourceMetadata { + if ws.Meta == nil || ws.Meta.Connectors == nil { + return nil + } + return ws.Meta.Connectors +} + +type protoTimestamp interface { + AsTime() time.Time +} + +func timestampTime(ts protoTimestamp) time.Time { + if ts == nil { + return time.Time{} + } + return ts.AsTime() +} diff --git a/internal/watch/filter.go b/internal/watch/filter.go new file mode 100644 index 0000000..1f1484a --- /dev/null +++ b/internal/watch/filter.go @@ -0,0 +1,509 @@ +package watch + +import ( + "context" + "encoding/json" + "path" + "sort" + "strconv" + "strings" + "unicode" +) + +type filterResult struct { + RunID int64 + SettingsHash string + RawGraphHash string + VisibleSymbols map[int64]Symbol + VisibleReferences []Reference + VisibleFacts []Fact + VisibleFiles map[string]struct{} + Incoming map[int64]int + Outgoing map[int64]int + ChangedFiles map[string]struct{} + ContextPolicies contextPolicySet + ContextExpansions contextExpansionSet + Visibility VisibilityConfig +} + +type visibilitySignal struct { + Name string `json:"name"` + Weight float64 `json:"weight"` + Reason string `json:"reason"` +} + +type visibilityScore struct { + Score float64 + Signals []visibilitySignal + Forced bool + Tier int +} + +func (s *visibilityScore) add(name string, weight float64, reason string) { + if weight == 0 { + return + } + s.Score += weight + s.Signals = append(s.Signals, visibilitySignal{Name: name, Weight: weight, Reason: reason}) +} + +func (s visibilityScore) reason(fallback string) string { + var reasons []string + for _, signal := range s.Signals { + if strings.TrimSpace(signal.Reason) != "" { + reasons = append(reasons, signal.Reason) + } + } + if len(reasons) == 0 { + return fallback + } + return strings.Join(reasons, "; ") +} + +func (s visibilityScore) signalsJSON() string { + data, err := json.Marshal(s.Signals) + if err != nil { + return "[]" + } + return string(data) +} + +func defaultThresholds(thresholds Thresholds) Thresholds { + if thresholds.MaxElementsPerView <= 0 { + thresholds.MaxElementsPerView = 50 + } + if thresholds.MaxConnectorsPerView <= 0 { + thresholds.MaxConnectorsPerView = 100 + } + if thresholds.MaxIncomingPerElement <= 0 { + thresholds.MaxIncomingPerElement = 25 + } + if thresholds.MaxOutgoingPerElement <= 0 { + 
thresholds.MaxOutgoingPerElement = 40 + } + if thresholds.MaxExpandedConnectorsPerGroup <= 0 { + thresholds.MaxExpandedConnectorsPerGroup = 24 + } + return thresholds +} + +func settingsHash(req RepresentRequest) string { + req.Embedding = normalizeEmbeddingConfig(req.Embedding) + req.Thresholds = defaultThresholds(req.Thresholds) + req.Visibility = defaultVisibilityConfig(req.Visibility) + return stableHash(req) +} + +func runFilter(ctx context.Context, store *Store, repositoryID int64, thresholds Thresholds, visibilityCfg VisibilityConfig, rawGraphHash, settingsHash string, embeddings map[int64]Vector, forcedVisibleSymbols map[int64]string, contextPolicies contextPolicySet, identityKeys map[string]string) (filterResult, error) { + visibilityCfg = defaultVisibilityConfig(visibilityCfg) + symbols, err := store.SymbolsForRepository(ctx, repositoryID) + if err != nil { + return filterResult{}, err + } + refs, err := store.QueryReferences(ctx, repositoryID, ReferenceQuery{Limit: -1}) + if err != nil { + return filterResult{}, err + } + facts, err := store.FactsForRepository(ctx, repositoryID) + if err != nil { + return filterResult{}, err + } + expansions, err := store.ActiveContextExpansionSet(ctx, repositoryID) + if err != nil { + return filterResult{}, err + } + if err := ctx.Err(); err != nil { + return filterResult{}, err + } + incoming := map[int64]int{} + outgoing := map[int64]int{} + for _, ref := range refs { + outgoing[ref.SourceSymbolID]++ + incoming[ref.TargetSymbolID]++ + } + + visible := map[int64]Symbol{} + scores := map[int64]*visibilityScore{} + symbolsByID := map[int64]Symbol{} + for _, sym := range symbols { + symbolsByID[sym.ID] = sym + scores[sym.ID] = &visibilityScore{} + } + factsBySubject := map[string][]Fact{} + for _, fact := range facts { + key := ownerMapKey(fact.SubjectKind, fact.SubjectStableKey) + factsBySubject[key] = append(factsBySubject[key], fact) + } + for _, sym := range symbols { + score := scores[sym.ID] + if isExportedSymbol(sym) { + score.add("entrypoint.exported", 1.2, "exported/public symbol") + } + if outgoing[sym.ID] > 0 { + score.add("graph.outgoing", 1, "has resolved outgoing reference") + } + for _, fact := range factsBySubject[ownerMapKey("symbol", sym.StableKey)] { + if highSignalFact(fact) { + score.add("fact.high_signal", visibilityCfg.Weights.HighSignalFact*fact.Confidence, "has high-signal fact "+fact.Type) + } else if dependencyFact(fact) { + score.add("fact.dependency", visibilityCfg.Weights.DependencyFact*fact.Confidence, "has dependency fact") + } + } + if reason, ok := forcedVisibleSymbols[sym.ID]; ok { + if strings.TrimSpace(reason) == "" { + reason = "changed since latest watch version" + } + score.add("change.changed", visibilityCfg.Weights.Changed, reason) + score.Forced = true + } + if reason, ok := contextPolicies.showSymbol(sym, identityKeys); ok { + if strings.TrimSpace(reason) == "" { + reason = "user marked as context" + } + score.add("policy.show", visibilityCfg.Weights.UserShow, reason) + score.Forced = true + } + if tier := expansions.symbolTier(sym, identityKeys); tier > 0 { + score.Tier = tier + score.add("context.expansion", visibilityCfg.Weights.Selected, "selected context expansion tier "+strconv.Itoa(tier)) + score.Forced = true + } + if outgoing[sym.ID] > thresholds.MaxOutgoingPerElement || incoming[sym.ID] > thresholds.MaxIncomingPerElement { + score.add("noise.high_degree", visibilityCfg.Weights.HighDegreeNoise, "high-degree non-entrypoint collapsed") + } + if looksLikeTinyUtility(sym) && 
outgoing[sym.ID]+incoming[sym.ID] > 8 { + score.add("noise.utility", visibilityCfg.Weights.UtilityNoise, "utility noise collapsed") + } + } + for _, sym := range symbols { + score := scores[sym.ID] + if score.Forced || !visibilityCfg.CoreThresholdEnabled || score.Score >= visibilityCfg.CoreThreshold { + visible[sym.ID] = sym + } + } + changed := true + for changed { + changed = false + for _, ref := range refs { + if _, ok := visible[ref.SourceSymbolID]; !ok { + continue + } + if _, ok := visible[ref.TargetSymbolID]; ok { + continue + } + if target, ok := symbolsByID[ref.TargetSymbolID]; ok { + score := scores[target.ID] + score.add("graph.proximity", visibilityCfg.Weights.RelationshipProximity, "referenced by visible symbol") + if score.Forced || !visibilityCfg.CoreThresholdEnabled || score.Score >= visibilityCfg.CoreThreshold { + visible[target.ID] = target + changed = true + } + } + } + } + for _, ref := range refs { + if _, ok := forcedVisibleSymbols[ref.SourceSymbolID]; ok { + if target, ok := symbolsByID[ref.TargetSymbolID]; ok { + score := scores[target.ID] + score.add("change.endpoint", visibilityCfg.Weights.Changed, "endpoint of changed symbol") + score.Forced = true + visible[target.ID] = target + } + continue + } + if _, ok := forcedVisibleSymbols[ref.TargetSymbolID]; ok { + if source, ok := symbolsByID[ref.SourceSymbolID]; ok { + score := scores[source.ID] + score.add("change.endpoint", visibilityCfg.Weights.Changed, "endpoint of changed symbol") + score.Forced = true + visible[source.ID] = source + } + } + } + if len(embeddings) > 0 { + rescueRelatedSymbolsScored(symbols, refs, visible, scores, embeddings, visibilityCfg.Weights.RelationshipProximity) + } + for _, sym := range symbols { + if _, ok := contextPolicies.hideSymbol(sym, identityKeys); ok { + score := scores[sym.ID] + score.add("policy.hide", visibilityCfg.Weights.UserHide, "user marked as noise") + if !score.Forced { + delete(visible, sym.ID) + } + } + } + if err := ctx.Err(); err != nil { + return filterResult{}, err + } + + runID, err := store.BeginFilterRun(ctx, repositoryID, settingsHash, rawGraphHash) + if err != nil { + return filterResult{}, err + } + visibleSymbols := 0 + hiddenSymbols := 0 + for _, sym := range symbols { + score := scores[sym.ID] + scoreValue := score.Score + ownerKey := symbolOwnerKey(sym, identityKeys) + if _, ok := visible[sym.ID]; ok { + visibleSymbols++ + if err := store.SaveFilterDecision(ctx, runID, "symbol", sym.ID, ownerKey, "visible", score.reason("visible by graph context"), &scoreValue, score.Tier, score.signalsJSON()); err != nil { + return filterResult{}, err + } + continue + } + hiddenSymbols++ + if err := store.SaveFilterDecision(ctx, runID, "symbol", sym.ID, ownerKey, "hidden", score.reason("leaf private symbol without useful outgoing references"), &scoreValue, score.Tier, score.signalsJSON()); err != nil { + return filterResult{}, err + } + } + + var visibleRefs []Reference + hiddenRefs := 0 + for _, ref := range refs { + _, sourceOK := visible[ref.SourceSymbolID] + _, targetOK := visible[ref.TargetSymbolID] + refOwnerKey := referenceOwnerKey(ref, symbolsByID, identityKeys) + if _, hidden := contextPolicies.Hide[ownerMapKey("reference", refOwnerKey)]; hidden { + hiddenRefs++ + scoreValue := visibilityCfg.Weights.UserHide + if err := store.SaveFilterDecision(ctx, runID, "reference", ref.ID, refOwnerKey, "hidden", "user marked as noise", &scoreValue, 0, `[]`); err != nil { + return filterResult{}, err + } + } else if sourceOK && targetOK { + visibleRefs = append(visibleRefs, 
ref) + scoreValue := 1.0 + if err := store.SaveFilterDecision(ctx, runID, "reference", ref.ID, refOwnerKey, "visible", "connects visible symbols", &scoreValue, 0, `[]`); err != nil { + return filterResult{}, err + } + } else { + hiddenRefs++ + scoreValue := 0.0 + if err := store.SaveFilterDecision(ctx, runID, "reference", ref.ID, refOwnerKey, "hidden", "unresolved or hidden endpoint", &scoreValue, 0, `[]`); err != nil { + return filterResult{}, err + } + } + } + visibleFiles := filesForSymbols(visible) + for file := range forcedVisibleSymbolsByFile(symbols, forcedVisibleSymbols) { + visibleFiles[file] = struct{}{} + } + for file := range expansionsFiles(symbols, expansions, identityKeys) { + visibleFiles[file] = struct{}{} + } + visibleFacts, err := scoreFacts(ctx, store, runID, facts, visible, visibleFiles, visibilityCfg, expansions) + if err != nil { + return filterResult{}, err + } + for _, fact := range visibleFacts { + if strings.TrimSpace(fact.FilePath) != "" { + visibleFiles[fact.FilePath] = struct{}{} + } + } + if err := store.FinishFilterRun(ctx, runID, "completed", visibleSymbols, hiddenSymbols, len(visibleRefs), hiddenRefs); err != nil { + return filterResult{}, err + } + return filterResult{ + RunID: runID, + SettingsHash: settingsHash, + RawGraphHash: rawGraphHash, + VisibleSymbols: visible, + VisibleReferences: visibleRefs, + VisibleFacts: visibleFacts, + VisibleFiles: visibleFiles, + Incoming: incoming, + Outgoing: outgoing, + ContextPolicies: contextPolicies, + ContextExpansions: expansions, + Visibility: visibilityCfg, + }, nil +} + +func (p contextPolicySet) showSymbol(sym Symbol, identityKeys map[string]string) (string, bool) { + return p.symbolPolicy(p.Show, sym, identityKeys) +} + +func (p contextPolicySet) hideSymbol(sym Symbol, identityKeys map[string]string) (string, bool) { + return p.symbolPolicy(p.Hide, sym, identityKeys) +} + +func (p contextPolicySet) symbolPolicy(policies map[string]string, sym Symbol, identityKeys map[string]string) (string, bool) { + for _, key := range []string{ + ownerMapKey("symbol", symbolOwnerKey(sym, identityKeys)), + ownerMapKey("symbol", sym.StableKey), + ownerMapKey("file", "file:"+sym.FilePath), + } { + if reason, ok := policies[key]; ok { + return reason, true + } + } + dir := path.Dir(sym.FilePath) + for dir != "." 
&& dir != "/" && dir != "" { + if reason, ok := policies[ownerMapKey("folder", "folder:"+dir)]; ok { + return reason, true + } + next := path.Dir(dir) + if next == dir { + break + } + dir = next + } + return "", false +} + +func forcedVisibleSymbolsByFile(symbols []Symbol, forced map[int64]string) map[string]struct{} { + out := map[string]struct{}{} + if len(forced) == 0 { + return out + } + for _, sym := range symbols { + if _, ok := forced[sym.ID]; ok && strings.TrimSpace(sym.FilePath) != "" { + out[sym.FilePath] = struct{}{} + } + } + return out +} + +func expansionsFiles(symbols []Symbol, expansions contextExpansionSet, identityKeys map[string]string) map[string]struct{} { + out := map[string]struct{}{} + for _, sym := range symbols { + if expansions.symbolTier(sym, identityKeys) > 0 && strings.TrimSpace(sym.FilePath) != "" { + out[sym.FilePath] = struct{}{} + } + } + return out +} + +func scoreFacts(ctx context.Context, store *Store, runID int64, facts []Fact, visibleSymbols map[int64]Symbol, visibleFiles map[string]struct{}, cfg VisibilityConfig, expansions contextExpansionSet) ([]Fact, error) { + var visible []Fact + visibleSymbolStable := map[string]struct{}{} + for _, sym := range visibleSymbols { + visibleSymbolStable[sym.StableKey] = struct{}{} + } + for _, fact := range facts { + if fact.Type == enrichmentVersionType { + continue + } + score := visibilityScore{} + if highSignalFact(fact) { + score.add("fact.high_signal", cfg.Weights.HighSignalFact*fact.Confidence, "high-signal fact "+fact.Type) + } else if dependencyFact(fact) { + score.add("fact.dependency", cfg.Weights.DependencyFact*fact.Confidence, "dependency fact") + } + if fact.SubjectKind == "symbol" { + if _, ok := visibleSymbolStable[fact.SubjectStableKey]; ok { + score.add("fact.subject_visible", cfg.Weights.RelationshipProximity, "subject symbol is visible") + } + } + if _, ok := visibleFiles[fact.FilePath]; ok { + score.add("fact.file_visible", cfg.Weights.RelationshipProximity, "source file is visible") + } + if tier := expansions.fileTier(fact.FilePath); tier > 0 { + score.Tier = tier + score.add("context.expansion", cfg.Weights.Selected, "selected context expansion tier "+strconv.Itoa(tier)) + score.Forced = true + } + if hintWeight := factVisibilityHint(fact, "score"); hintWeight != 0 { + score.add("fact.hint", hintWeight, "enricher visibility hint") + } + decision := "hidden" + if (highSignalFact(fact) || dependencyFact(fact)) && (score.Forced || !cfg.CoreThresholdEnabled || score.Score >= cfg.CoreThreshold) { + decision = "visible" + visible = append(visible, fact) + } + scoreValue := score.Score + if err := store.SaveFilterDecision(ctx, runID, "fact", fact.ID, factOwnerKey(fact), decision, score.reason("fact below visibility threshold"), &scoreValue, score.Tier, score.signalsJSON()); err != nil { + return nil, err + } + } + return visible, nil +} + +func factVisibilityHint(fact Fact, key string) float64 { + if strings.TrimSpace(fact.VisibilityHintsJSON) == "" { + return 0 + } + values := map[string]float64{} + if err := json.Unmarshal([]byte(fact.VisibilityHintsJSON), &values); err != nil { + return 0 + } + return values[key] +} + +func highSignalFact(fact Fact) bool { + if factVisibilityHint(fact, "high_signal") > 0 { + return true + } + switch fact.Type { + case "http.route", "frontend.route", "orm.query": + return true + default: + return false + } +} + +func dependencyFact(fact Fact) bool { + return strings.HasPrefix(fact.Type, "dependency.") +} + +func factOwnerKey(fact Fact) string { + return "fact:" + 
fact.Enricher + ":" + fact.StableKey +} + +func rescueRelatedSymbolsScored(symbols []Symbol, refs []Reference, visible map[int64]Symbol, scores map[int64]*visibilityScore, embeddings map[int64]Vector, weight float64) { + byID := map[int64]Symbol{} + for _, sym := range symbols { + byID[sym.ID] = sym + } + for _, ref := range refs { + sourceVisible := visible[ref.SourceSymbolID] + targetVisible := visible[ref.TargetSymbolID] + switch { + case sourceVisible.ID != 0 && targetVisible.ID == 0: + if target, ok := byID[ref.TargetSymbolID]; ok && embeddingSimilar(sourceVisible.ID, target.ID, embeddings, 0.82) { + visible[target.ID] = target + scores[target.ID].add("embedding.neighbor", weight, "embedding-similar graph neighbor") + } + case targetVisible.ID != 0 && sourceVisible.ID == 0: + if source, ok := byID[ref.SourceSymbolID]; ok && embeddingSimilar(targetVisible.ID, source.ID, embeddings, 0.82) { + visible[source.ID] = source + scores[source.ID].add("embedding.neighbor", weight, "embedding-similar graph neighbor") + } + } + } +} + +func embeddingSimilar(leftID, rightID int64, embeddings map[int64]Vector, threshold float64) bool { + left, leftOK := embeddings[leftID] + right, rightOK := embeddings[rightID] + if !leftOK || !rightOK { + return false + } + return CosineSimilarity(left, right) >= threshold +} + +func isExportedSymbol(sym Symbol) bool { + if sym.Name == "" { + return false + } + first := []rune(sym.Name)[0] + return unicode.IsUpper(first) +} + +func looksLikeTinyUtility(sym Symbol) bool { + name := strings.ToLower(sym.Name) + file := strings.ToLower(path.Base(sym.FilePath)) + for _, marker := range []string{"log", "logger", "metric", "trace", "debug", "helper", "util"} { + if strings.Contains(name, marker) || strings.Contains(file, marker) { + return true + } + } + return false +} + +func stableClusterKey(repositoryID int64, parentScope, settingsHash string, memberKeys []string) string { + keys := append([]string(nil), memberKeys...) 
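+ // Sort the copied member keys so the derived cluster key is independent of member ordering.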
+ sort.Strings(keys) + return "cluster:" + strconv.FormatInt(repositoryID, 10) + ":" + parentScope + ":" + settingsHash + ":" + stableHash(keys) +} diff --git a/internal/watch/http.go b/internal/watch/http.go new file mode 100644 index 0000000..b70abea --- /dev/null +++ b/internal/watch/http.go @@ -0,0 +1,324 @@ +package watch + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "strconv" +) + +type Handler struct { + Store *Store + Representer *Representer +} + +func NewHandler(store *Store) *Handler { + return &Handler{Store: store, Representer: NewRepresenter(store)} +} + +func (h *Handler) Register(mux *http.ServeMux) { + mux.HandleFunc("GET /api/watch/ws", h.watchWebSocket) + mux.HandleFunc("GET /api/watch/status", h.status) + mux.HandleFunc("GET /api/watch/repositories", h.listRepositories) + mux.HandleFunc("GET /api/watch/repositories/{id}/raw-graph/summary", h.rawGraphSummary) + mux.HandleFunc("GET /api/watch/repositories/{id}/raw-graph/symbols", h.rawGraphSymbols) + mux.HandleFunc("GET /api/watch/repositories/{id}/raw-graph/references", h.rawGraphReferences) + mux.HandleFunc("POST /api/watch/repositories/{id}/reassociate", h.reassociateRepository) + mux.HandleFunc("POST /api/watch/repositories/{id}/represent", h.representRepository) + mux.HandleFunc("POST /api/watch/repositories/{id}/context/show", h.showContext) + mux.HandleFunc("POST /api/watch/repositories/{id}/context/clean", h.cleanContext) + mux.HandleFunc("POST /api/watch/repositories/{id}/context/hide", h.hideContext) + mux.HandleFunc("GET /api/watch/repositories/{id}/representation/summary", h.representationSummary) + mux.HandleFunc("GET /api/watch/repositories/{id}/filter-decisions", h.filterDecisions) + mux.HandleFunc("GET /api/watch/repositories/{id}/clusters", h.clusters) + mux.HandleFunc("GET /api/watch/repositories/{id}/materialization", h.materialization) + mux.HandleFunc("GET /api/watch/repositories/{id}/versions", h.versions) + mux.HandleFunc("GET /api/watch/versions/{id}/diffs", h.versionDiffs) +} + +func (h *Handler) showContext(w http.ResponseWriter, r *http.Request) { + h.contextAction(w, r, contextActionShow) +} + +func (h *Handler) hideContext(w http.ResponseWriter, r *http.Request) { + h.contextAction(w, r, contextActionHide) +} + +func (h *Handler) cleanContext(w http.ResponseWriter, r *http.Request) { + h.contextAction(w, r, contextActionClean) +} + +func (h *Handler) contextAction(w http.ResponseWriter, r *http.Request, action string) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + var body struct { + ContextResourceRequest + Represent RepresentRequest `json:"represent"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + representReq := body.Represent + if representReq.Embedding.Provider == "" { + representReq.Embedding.Provider = "none" + } + result, err := h.Store.ApplyContextAction(r.Context(), repositoryID, action, body.ContextResourceRequest, representReq) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + writeJSON(w, http.StatusOK, result) +} + +func (h *Handler) status(w http.ResponseWriter, r *http.Request) { + lock, live, err := h.Store.ActiveLiveLock(r.Context(), LockHeartbeatTimeout) + if err != nil { + writeError(w, http.StatusInternalServerError, "load watch status") + return + } + if !live { + writeJSON(w, http.StatusOK, map[string]any{ + "active": false, + "connected_clients": WatchWebSocketClientCount(), + }) + return + } + repo, 
err := h.Store.Repository(r.Context(), lock.RepositoryID) + if err != nil { + writeError(w, http.StatusInternalServerError, "load watch repository") + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "active": true, + "repository": repo.JSON(), + "lock": lock, + "connected_clients": WatchWebSocketClientCount(), + }) +} + +func (h *Handler) listRepositories(w http.ResponseWriter, r *http.Request) { + repos, err := h.Store.Repositories(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, "list repositories") + return + } + out := make([]RepositoryJSON, 0, len(repos)) + for _, repo := range repos { + out = append(out, repo.JSON()) + } + writeJSON(w, http.StatusOK, out) +} + +func (h *Handler) rawGraphSummary(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + summary, err := h.Store.Summary(r.Context(), repositoryID) + if err != nil { + writeError(w, http.StatusInternalServerError, "load raw graph summary") + return + } + writeJSON(w, http.StatusOK, summary) +} + +func (h *Handler) rawGraphSymbols(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + query := r.URL.Query() + symbols, err := h.Store.QuerySymbols(r.Context(), repositoryID, SymbolQuery{ + Search: query.Get("search"), + File: query.Get("file"), + Kind: query.Get("kind"), + Limit: parseInt(query.Get("limit"), 100), + Offset: parseInt(query.Get("offset"), 0), + }) + if err != nil { + writeError(w, http.StatusInternalServerError, "list raw graph symbols") + return + } + writeJSON(w, http.StatusOK, symbols) +} + +func (h *Handler) rawGraphReferences(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + query := r.URL.Query() + refs, err := h.Store.QueryReferences(r.Context(), repositoryID, ReferenceQuery{ + SymbolID: int64(parseInt(query.Get("symbol_id"), 0)), + Limit: parseInt(query.Get("limit"), 100), + Offset: parseInt(query.Get("offset"), 0), + }) + if err != nil { + writeError(w, http.StatusInternalServerError, "list raw graph references") + return + } + writeJSON(w, http.StatusOK, refs) +} + +func (h *Handler) reassociateRepository(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + var body struct { + RemoteURL string `json:"remote_url"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + repo, err := h.Store.ReassociateRepository(r.Context(), repositoryID, body.RemoteURL) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + writeJSON(w, http.StatusOK, repo.JSON()) +} + +func (h *Handler) representRepository(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + var body RepresentRequest + if r.Body != nil { + if err := json.NewDecoder(r.Body).Decode(&body); err != nil && !errors.Is(err, io.EOF) { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + } + representer := h.Representer + if representer == nil { + representer = NewRepresenter(h.Store) + } + result, err := representer.Represent(r.Context(), repositoryID, body) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + writeJSON(w, http.StatusOK, result) +} + +func (h *Handler) representationSummary(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := 
parseIDPath(w, r, "id") + if !ok { + return + } + summary, err := h.Store.RepresentationSummary(r.Context(), repositoryID) + if err != nil { + writeError(w, http.StatusInternalServerError, "load representation summary") + return + } + writeJSON(w, http.StatusOK, summary) +} + +func (h *Handler) filterDecisions(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + query := r.URL.Query() + decisions, err := h.Store.FilterDecisions(r.Context(), repositoryID, FilterDecisionQuery{ + OwnerType: query.Get("owner_type"), + Decision: query.Get("decision"), + Limit: parseInt(query.Get("limit"), 100), + Offset: parseInt(query.Get("offset"), 0), + }) + if err != nil { + writeError(w, http.StatusInternalServerError, "list filter decisions") + return + } + writeJSON(w, http.StatusOK, decisions) +} + +func (h *Handler) clusters(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + clusters, err := h.Store.Clusters(r.Context(), repositoryID) + if err != nil { + writeError(w, http.StatusInternalServerError, "list clusters") + return + } + writeJSON(w, http.StatusOK, clusters) +} + +func (h *Handler) materialization(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + mappings, err := h.Store.Materialization(r.Context(), repositoryID) + if err != nil { + writeError(w, http.StatusInternalServerError, "list materialization mappings") + return + } + writeJSON(w, http.StatusOK, mappings) +} + +func (h *Handler) versions(w http.ResponseWriter, r *http.Request) { + repositoryID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + versions, err := h.Store.WatchVersions(r.Context(), repositoryID, parseInt(r.URL.Query().Get("limit"), 100)) + if err != nil { + writeError(w, http.StatusInternalServerError, "list watch versions") + return + } + writeJSON(w, http.StatusOK, versions) +} + +func (h *Handler) versionDiffs(w http.ResponseWriter, r *http.Request) { + versionID, ok := parseIDPath(w, r, "id") + if !ok { + return + } + query := r.URL.Query() + diffs, err := h.Store.WatchDiffs(r.Context(), versionID, query.Get("owner_type"), query.Get("change_type"), query.Get("resource_type"), query.Get("language"), parseInt(query.Get("limit"), 200)) + if err != nil { + writeError(w, http.StatusInternalServerError, "list watch diffs") + return + } + writeJSON(w, http.StatusOK, diffs) +} + +func parseIDPath(w http.ResponseWriter, r *http.Request, name string) (int64, bool) { + id, err := strconv.ParseInt(r.PathValue(name), 10, 64) + if err != nil || id <= 0 { + writeError(w, http.StatusBadRequest, "invalid id") + return 0, false + } + return id, true +} + +func parseInt(value string, fallback int) int { + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil || parsed < 0 { + return fallback + } + return parsed +} + +func writeJSON(w http.ResponseWriter, status int, value any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(value) +} + +func writeError(w http.ResponseWriter, status int, message string) { + writeJSON(w, status, map[string]string{"error": message}) +} diff --git a/internal/watch/models.go b/internal/watch/models.go new file mode 100644 index 0000000..6779c5f --- /dev/null +++ b/internal/watch/models.go @@ -0,0 +1,455 @@ +package watch + +import ( + "database/sql" + "time" +) + +const SettingsHash = "" + +type Repository struct { + ID 
int64 `json:"id"` + RemoteURL sql.NullString `json:"-"` + RepoRoot string `json:"repo_root"` + DisplayName string `json:"display_name"` + Branch sql.NullString `json:"-"` + HeadCommit sql.NullString `json:"-"` + IdentityStatus string `json:"identity_status"` + SettingsHash string `json:"settings_hash"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type RepositoryJSON struct { + ID int64 `json:"id"` + RemoteURL *string `json:"remote_url"` + RepoRoot string `json:"repo_root"` + DisplayName string `json:"display_name"` + Branch *string `json:"branch"` + HeadCommit *string `json:"head_commit"` + IdentityStatus string `json:"identity_status"` + SettingsHash string `json:"settings_hash"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type File struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + Path string `json:"path"` + Language string `json:"language"` + GitBlobHash sql.NullString `json:"-"` + WorktreeHash string `json:"worktree_hash"` + SizeBytes int64 `json:"size_bytes"` + MtimeUnix int64 `json:"mtime_unix"` + ScanStatus string `json:"scan_status"` + ScanError sql.NullString `json:"-"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type Symbol struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + FileID int64 `json:"file_id"` + FilePath string `json:"file_path,omitempty"` + StableKey string `json:"stable_key"` + Name string `json:"name"` + QualifiedName string `json:"qualified_name"` + Kind string `json:"kind"` + StartLine int `json:"start_line"` + EndLine *int `json:"end_line"` + SignatureHash string `json:"signature_hash"` + ContentHash string `json:"content_hash"` + RawJSON string `json:"raw_json"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type Reference struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + SourceSymbolID int64 `json:"source_symbol_id"` + TargetSymbolID int64 `json:"target_symbol_id"` + SourceFileID int64 `json:"source_file_id"` + Kind string `json:"kind"` + Line int `json:"line"` + Column int `json:"column"` + EvidenceHash string `json:"evidence_hash"` + RawJSON string `json:"raw_json"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type Fact struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + FileID int64 `json:"file_id"` + FilePath string `json:"file_path"` + StableKey string `json:"stable_key"` + Type string `json:"type"` + Enricher string `json:"enricher"` + SubjectKind string `json:"subject_kind"` + SubjectStableKey string `json:"subject_stable_key"` + ObjectKind string `json:"object_kind,omitempty"` + ObjectStableKey string `json:"object_stable_key,omitempty"` + ObjectFilePath string `json:"object_file_path,omitempty"` + ObjectName string `json:"object_name,omitempty"` + Relationship string `json:"relationship,omitempty"` + StartLine int `json:"start_line"` + EndLine *int `json:"end_line,omitempty"` + Confidence float64 `json:"confidence"` + Name string `json:"name"` + Tags []string `json:"tags"` + AttributesJSON string `json:"attributes_json"` + VisibilityHintsJSON string `json:"visibility_hints_json"` + FactHash string `json:"fact_hash"` + RawJSON string `json:"raw_json"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type Summary struct { + RepositoryID int64 `json:"repository_id"` + Files int `json:"files"` + Symbols int `json:"symbols"` 
+ References int `json:"references"` + LastScanStatus string `json:"last_scan_status,omitempty"` + LastScanStarted string `json:"last_scan_started_at,omitempty"` + LastScanFinished string `json:"last_scan_finished_at,omitempty"` +} + +type ScanResult struct { + RepositoryID int64 `json:"repository_id"` + ScanRunID int64 `json:"scan_run_id"` + FilesSeen int `json:"files_seen"` + FilesParsed int `json:"files_parsed"` + FilesSkipped int `json:"files_skipped"` + SymbolsSeen int `json:"symbols_seen"` + ReferencesSeen int `json:"references_seen"` + Warning string `json:"warning,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +type EmbeddingConfig struct { + Provider string `json:"provider" yaml:"provider"` + Endpoint string `json:"endpoint,omitempty" yaml:"endpoint"` + Model string `json:"model" yaml:"model"` + Dimension int `json:"dimension" yaml:"dimension"` + HealthThreshold float64 `json:"health_threshold,omitempty" yaml:"health_threshold"` + TimeoutSeconds int `json:"timeout_seconds,omitempty" yaml:"timeout_seconds"` +} + +type Thresholds struct { + MaxElementsPerView int `json:"max_elements_per_view"` + MaxConnectorsPerView int `json:"max_connectors_per_view"` + MaxIncomingPerElement int `json:"max_incoming_per_element"` + MaxOutgoingPerElement int `json:"max_outgoing_per_element"` + MaxExpandedConnectorsPerGroup int `json:"max_expanded_connectors_per_group"` +} + +type VisibilityWeights struct { + Changed float64 `json:"changed" yaml:"changed"` + Selected float64 `json:"selected" yaml:"selected"` + UserShow float64 `json:"user_show" yaml:"user_show"` + UserHide float64 `json:"user_hide" yaml:"user_hide"` + HighSignalFact float64 `json:"high_signal_fact" yaml:"high_signal_fact"` + RelationshipProximity float64 `json:"relationship_proximity" yaml:"relationship_proximity"` + DependencyFact float64 `json:"dependency_fact" yaml:"dependency_fact"` + UtilityNoise float64 `json:"utility_noise" yaml:"utility_noise"` + HighDegreeNoise float64 `json:"high_degree_noise" yaml:"high_degree_noise"` +} + +type VisibilityConfig struct { + CoreThresholdEnabled bool `json:"core_threshold_enabled" yaml:"core_threshold_enabled"` + CoreThreshold float64 `json:"core_threshold" yaml:"core_threshold"` + TierMultiplier float64 `json:"tier_multiplier" yaml:"tier_multiplier"` + MaxExpansionMultiplier float64 `json:"max_expansion_multiplier" yaml:"max_expansion_multiplier"` + Weights VisibilityWeights `json:"weights" yaml:"weights"` + CoreThresholdSet bool `json:"-" yaml:"-"` + WeightsSet bool `json:"-" yaml:"-"` +} + +type Settings struct { + Languages []string `json:"languages"` + Watcher string `json:"watcher"` + PollInterval time.Duration `json:"poll_interval"` + Debounce time.Duration `json:"debounce"` + Thresholds Thresholds `json:"thresholds"` + Visibility VisibilityConfig `json:"visibility"` +} + +type RepresentRequest struct { + Embedding EmbeddingConfig `json:"embedding"` + Thresholds Thresholds `json:"thresholds"` + Visibility VisibilityConfig `json:"visibility"` + Progress ProgressSink `json:"-"` +} + +type RepresentResult struct { + RepositoryID int64 `json:"repository_id"` + RepresentationRun int64 `json:"representation_run_id"` + FilterRunID int64 `json:"filter_run_id"` + RawGraphHash string `json:"raw_graph_hash"` + SettingsHash string `json:"settings_hash"` + RepresentationHash string `json:"representation_hash"` + ElementsCreated int `json:"elements_created"` + ElementsUpdated int `json:"elements_updated"` + ConnectorsCreated int `json:"connectors_created"` + ConnectorsUpdated int 
`json:"connectors_updated"` + ViewsCreated int `json:"views_created"` + ElementsPreserved int `json:"elements_preserved"` + ConnectorsPreserved int `json:"connectors_preserved"` + ViewsPreserved int `json:"views_preserved"` + DeletesPreserved int `json:"deletes_preserved"` + EmbeddingCacheHits int `json:"embedding_cache_hits"` + EmbeddingsCreated int `json:"embeddings_created"` +} + +type ProgressSink interface { + Start(label string, total int) + Advance(label string) + Finish() +} + +type RepresentationSummary struct { + RepositoryID int64 `json:"repository_id"` + RawGraphHash string `json:"raw_graph_hash,omitempty"` + SettingsHash string `json:"filter_settings_hash,omitempty"` + RepresentationHash string `json:"representation_hash,omitempty"` + LastStatus string `json:"last_status,omitempty"` + LastStartedAt string `json:"last_started_at,omitempty"` + LastFinishedAt *string `json:"last_finished_at,omitempty"` + ElementsCreated int `json:"elements_created"` + ElementsUpdated int `json:"elements_updated"` + ConnectorsCreated int `json:"connectors_created"` + ConnectorsUpdated int `json:"connectors_updated"` + ViewsCreated int `json:"views_created"` + VisibleSymbols int `json:"visible_symbols"` + HiddenSymbols int `json:"hidden_symbols"` + VisibleReferences int `json:"visible_references"` + HiddenReferences int `json:"hidden_references"` + Diffs []RepresentationDiff `json:"diffs,omitempty"` +} + +type FilterDecision struct { + ID int64 `json:"id"` + FilterRunID int64 `json:"filter_run_id"` + OwnerType string `json:"owner_type"` + OwnerID int64 `json:"owner_id"` + OwnerKey string `json:"owner_key,omitempty"` + Decision string `json:"decision"` + Reason string `json:"reason"` + Score *float64 `json:"score,omitempty"` + Tier int `json:"tier,omitempty"` + SignalsJSON string `json:"signals_json,omitempty"` +} + +type Cluster struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + StableKey string `json:"stable_key"` + ParentClusterID *int64 `json:"parent_cluster_id,omitempty"` + Name string `json:"name"` + Kind string `json:"kind"` + Algorithm string `json:"algorithm"` + SettingsHash string `json:"settings_hash"` + MemberCount int `json:"member_count"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type MaterializationMapping struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + OwnerType string `json:"owner_type"` + OwnerKey string `json:"owner_key"` + ResourceType string `json:"resource_type"` + ResourceID int64 `json:"resource_id"` + LastWatchHash *string `json:"last_watch_hash,omitempty"` + Dirty bool `json:"dirty"` + DirtyDetectedAt *string `json:"dirty_detected_at,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type ArchitectureBinding struct { + ID int64 `json:"id,omitempty"` + RepositoryID int64 `json:"repository_id"` + ComponentKey string `json:"component_key"` + TargetRepositoryID int64 `json:"target_repository_id"` + TargetOwnerType string `json:"target_owner_type"` + TargetOwnerKey string `json:"target_owner_key"` + TargetResourceType string `json:"target_resource_type"` + TargetResourceID int64 `json:"target_resource_id"` + Role string `json:"role"` + Confidence float64 `json:"confidence"` + Evidence []ArchitectureBindingEvidence `json:"evidence"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type ArchitectureBindingEvidence struct { + Kind string `json:"kind"` + Detail string `json:"detail"` + 
Score float64 `json:"score"` +} + +type ArchitectureBindingTarget struct { + RepositoryID int64 `json:"repository_id"` + OwnerType string `json:"owner_type"` + OwnerKey string `json:"owner_key"` + ResourceType string `json:"resource_type"` + ResourceID int64 `json:"resource_id"` + ViewID int64 `json:"view_id,omitempty"` + Name string `json:"name"` + Kind string `json:"kind"` + FilePath string `json:"file_path,omitempty"` + Language string `json:"language,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +type ContextPolicy struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + OwnerType string `json:"owner_type"` + OwnerKey string `json:"owner_key"` + Action string `json:"action"` + Scope string `json:"scope"` + Active bool `json:"active"` + Reason string `json:"reason"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type ContextResourceRequest struct { + ResourceType string `json:"resource_type"` + ResourceID int64 `json:"resource_id"` +} + +type ContextActionResult struct { + RepositoryID int64 `json:"repository_id"` + Action string `json:"action"` + PoliciesCreated int `json:"policies_created"` + PoliciesUpdated int `json:"policies_updated"` + PoliciesDeactivated int `json:"policies_deactivated"` + OwnersAffected int `json:"owners_affected"` + TierBefore int `json:"tier_before"` + TierAfter int `json:"tier_after"` + MaxTier int `json:"max_tier"` + ElementsAdded int `json:"elements_added"` + ConnectorsAdded int `json:"connectors_added"` + ViewsAdded int `json:"views_added"` + ElementsRemoved int `json:"elements_removed"` + ConnectorsRemoved int `json:"connectors_removed"` + ViewsRemoved int `json:"views_removed"` + Representation RepresentResult `json:"representation"` + Summary RepresentationSummary `json:"summary"` +} + +type Lock struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + PID int `json:"pid"` + Token string `json:"token,omitempty"` + StartedAt string `json:"started_at"` + HeartbeatAt string `json:"heartbeat_at"` + Status string `json:"status"` +} + +type GitStatus struct { + Branch string `json:"branch"` + HeadCommit string `json:"head_commit"` + HeadMessage string `json:"head_message,omitempty"` + RemoteURL string `json:"remote_url"` + Staged []string `json:"staged"` + Unstaged []string `json:"unstaged"` + Untracked []string `json:"untracked"` + Deleted []string `json:"deleted"` +} + +type GitTagUpdateResult struct { + TagsAdded int `json:"tags_added"` + TagsRemoved int `json:"tags_removed"` +} + +type SourceFileChange struct { + Path string `json:"path"` + ChangeType string `json:"change_type"` + Language string `json:"language,omitempty"` +} + +type SourceFileChangeResult struct { + Change SourceFileChange `json:"change"` + RepresentationChanged bool `json:"representation_changed"` + Representation RepresentResult `json:"representation"` + GitTags GitTagUpdateResult `json:"git_tags"` +} + +type ChangeCounter struct { + TotalChangesProcessed int `json:"total_changes_processed"` + IntervalChangesProcessed int `json:"interval_changes_processed"` +} + +type Event struct { + Type string `json:"type"` + RepositoryID int64 `json:"repository_id,omitempty"` + Message string `json:"message,omitempty"` + At string `json:"at"` + Data any `json:"data,omitempty"` + Phase string `json:"phase,omitempty"` + WatcherMode string `json:"watcher_mode,omitempty"` + Languages []string `json:"languages,omitempty"` + ChangedFiles int `json:"changed_files,omitempty"` + Warnings []string 
`json:"warnings,omitempty"` +} + +type Version struct { + ID int64 `json:"id"` + RepositoryID int64 `json:"repository_id"` + CommitHash string `json:"commit_hash"` + CommitMessage string `json:"commit_message,omitempty"` + ParentCommitHash string `json:"parent_commit_hash,omitempty"` + Branch string `json:"branch,omitempty"` + RepresentationHash string `json:"representation_hash"` + WorkspaceVersionID *int64 `json:"workspace_version_id,omitempty"` + CreatedAt string `json:"created_at"` +} + +type RepresentationDiff struct { + ID int64 `json:"id"` + VersionID int64 `json:"version_id"` + OwnerType string `json:"owner_type"` + OwnerKey string `json:"owner_key"` + ChangeType string `json:"change_type"` + BeforeHash *string `json:"before_hash,omitempty"` + AfterHash *string `json:"after_hash,omitempty"` + ResourceType *string `json:"resource_type,omitempty"` + ResourceID *int64 `json:"resource_id,omitempty"` + Language *string `json:"language,omitempty"` + Summary *string `json:"summary,omitempty"` + AddedLines int `json:"added_lines,omitempty"` + RemovedLines int `json:"removed_lines,omitempty"` +} + +func (r Repository) JSON() RepositoryJSON { + return RepositoryJSON{ + ID: r.ID, + RemoteURL: nullStringPtr(r.RemoteURL), + RepoRoot: r.RepoRoot, + DisplayName: r.DisplayName, + Branch: nullStringPtr(r.Branch), + HeadCommit: nullStringPtr(r.HeadCommit), + IdentityStatus: r.IdentityStatus, + SettingsHash: r.SettingsHash, + CreatedAt: r.CreatedAt, + UpdatedAt: r.UpdatedAt, + } +} + +func nullStringPtr(value sql.NullString) *string { + if !value.Valid { + return nil + } + return &value.String +} diff --git a/internal/watch/once.go b/internal/watch/once.go new file mode 100644 index 0000000..50d510f --- /dev/null +++ b/internal/watch/once.go @@ -0,0 +1,82 @@ +package watch + +import ( + "context" + "fmt" + "path/filepath" + + tldgit "github.com/mertcikla/tld/internal/git" +) + +type OneShotOptions struct { + Path string + Rescan bool + Embedding EmbeddingConfig + Settings Settings + Progress ProgressSink +} + +type OneShotResult struct { + Repository Repository `json:"repository"` + Scan ScanResult `json:"scan"` + Representation RepresentResult `json:"representation"` + GitStatus GitStatus `json:"git_status"` + Diffs []RepresentationDiff `json:"diffs,omitempty"` +} + +func (r *Runner) RunOnce(ctx context.Context, opts OneShotOptions) (OneShotResult, error) { + if r == nil || r.Store == nil { + return OneShotResult{}, fmt.Errorf("watch runner requires a store") + } + if r.Scanner == nil { + r.Scanner = NewScanner(r.Store) + } + if r.Representer == nil { + r.Representer = NewRepresenter(r.Store) + } + if opts.Path == "" { + opts.Path = "." 
+ } + settings := NormalizeSettings(opts.Settings) + r.Scanner.Settings = settings + r.Scanner.Progress = opts.Progress + + progressStart(opts.Progress, "Preparing repository", 3) + absPath, err := filepath.Abs(opts.Path) + if err != nil { + progressFinish(opts.Progress) + return OneShotResult{}, err + } + progressAdvance(opts.Progress, "Resolved repository path") + repoRoot, err := tldgit.RepoRoot(absPath) + if err != nil { + progressFinish(opts.Progress) + return OneShotResult{}, fmt.Errorf("%s is not inside a git repository: %w", opts.Path, err) + } + progressAdvance(opts.Progress, "Detected git repository") + gitStatus, _ := gitStatusSnapshot(repoRoot) + progressAdvance(opts.Progress, "Captured git status") + progressFinish(opts.Progress) + + scan, err := r.Scanner.ScanWithOptions(ctx, repoRoot, ScanOptions{Force: opts.Rescan}) + if err != nil { + return OneShotResult{}, err + } + repo, err := r.Store.Repository(ctx, scan.RepositoryID) + if err != nil { + return OneShotResult{}, err + } + rep, err := r.Representer.Represent(ctx, repo.ID, RepresentRequest{Embedding: opts.Embedding, Thresholds: settings.Thresholds, Visibility: settings.Visibility, Progress: opts.Progress}) + if err != nil { + return OneShotResult{}, err + } + progressStart(opts.Progress, "Computing representation diffs", 1) + diffs, err := r.Store.BuildWatchDiffs(ctx, repo.ID, rep.RepresentationHash) + if err != nil { + progressFinish(opts.Progress) + return OneShotResult{}, err + } + progressAdvance(opts.Progress, "Representation diffs computed") + progressFinish(opts.Progress) + return OneShotResult{Repository: repo, Scan: scan, Representation: rep, GitStatus: gitStatus, Diffs: diffs}, nil +} diff --git a/internal/watch/process_unix.go b/internal/watch/process_unix.go new file mode 100644 index 0000000..36d23a2 --- /dev/null +++ b/internal/watch/process_unix.go @@ -0,0 +1,21 @@ +//go:build !windows + +package watch + +import ( + "os" + "syscall" +) + +var watchProcessIsRunning = processExists + +func processExists(pid int) bool { + if pid <= 0 { + return false + } + proc, err := os.FindProcess(pid) + if err != nil { + return false + } + return proc.Signal(syscall.Signal(0)) == nil +} diff --git a/internal/watch/process_windows.go b/internal/watch/process_windows.go new file mode 100644 index 0000000..350c46e --- /dev/null +++ b/internal/watch/process_windows.go @@ -0,0 +1,29 @@ +//go:build windows + +package watch + +import "syscall" + +const ( + processQueryLimitedInformation = 0x1000 + stillActive = 259 +) + +var watchProcessIsRunning = processExists + +func processExists(pid int) bool { + if pid <= 0 { + return false + } + handle, err := syscall.OpenProcess(processQueryLimitedInformation, false, uint32(pid)) + if err != nil { + return false + } + defer syscall.CloseHandle(handle) + + var exitCode uint32 + if err := syscall.GetExitCodeProcess(handle, &exitCode); err != nil { + return false + } + return exitCode == stillActive +} diff --git a/internal/watch/represent.go b/internal/watch/represent.go new file mode 100644 index 0000000..8d76883 --- /dev/null +++ b/internal/watch/represent.go @@ -0,0 +1,2778 @@ +package watch + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "math" + "os" + "path" + "path/filepath" + "slices" + "sort" + "strings" + "time" + + "github.com/mertcikla/tld/internal/codeowners" + "github.com/mertcikla/tld/internal/layout" + "github.com/mertcikla/tld/internal/tagcolors" +) + +const ( + defaultEmbeddingBatchSize = 256 + maxEmbeddingInputApproxTokens = 8000 + 
maxEmbeddingInputChars = maxEmbeddingInputApproxTokens * 4 +) + +var ( + maxEmbeddingSymbolsPerRun = 5000 + maxDetailedSymbolElements = 5000 +) + +type Representer struct { + Store *Store +} + +func NewRepresenter(store *Store) *Representer { + return &Representer{Store: store} +} + +func (r *Representer) Represent(ctx context.Context, repositoryID int64, req RepresentRequest) (RepresentResult, error) { + if r == nil || r.Store == nil { + return RepresentResult{}, fmt.Errorf("watch representer requires a store") + } + req.Embedding = normalizeEmbeddingConfig(req.Embedding) + req.Thresholds = defaultThresholds(req.Thresholds) + req.Visibility = defaultVisibilityConfig(req.Visibility) + settingsHash := settingsHash(req) + progressStart(req.Progress, "Preparing representation graph", 8) + rawGraphHash, err := r.Store.RawGraphHash(ctx, repositoryID) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Raw graph hashed") + repo, err := r.Store.Repository(ctx, repositoryID) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Repository loaded") + + provider, err := NewEmbeddingProvider(req.Embedding) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Embedding provider configured") + model := provider.ModelID() + modelID, err := r.Store.EnsureEmbeddingModel(ctx, EmbeddingConfig{Provider: model.Provider, Model: model.Model, Dimension: model.Dimension}, model.ConfigHash) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Embedding model registered") + modelIDPtr := &modelID + if model.Provider == "none" { + modelIDPtr = nil + } + + identityKeys, err := r.Store.SymbolIdentityKeys(ctx, repositoryID) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Symbol identities loaded") + contextPolicies, err := r.Store.ActiveContextPolicySet(ctx, repositoryID) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Context policies loaded") + changedRaw, err := r.Store.ChangedRawResourcesSinceLatest(ctx, repositoryID) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + progressAdvance(req.Progress, "Changed resources loaded") + filtered, err := runFilter(ctx, r.Store, repositoryID, req.Thresholds, req.Visibility, rawGraphHash, settingsHash, nil, changedRaw.Symbols, contextPolicies, identityKeys) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + filtered.ChangedFiles = changedRaw.Files + progressAdvance(req.Progress, "Architecture view filtered") + progressFinish(req.Progress) + + result := RepresentResult{} + if model.Provider != "none" { + embeddingSymbols := embeddingCandidateSymbols(filtered.VisibleSymbols, maxEmbeddingSymbolsPerRun) + stats, vectors, err := r.cacheEmbeddings(ctx, modelID, provider, repo.RepoRoot, embeddingSymbols, identityKeys, req.Progress, time.Duration(req.Embedding.TimeoutSeconds)*time.Second) + if err != nil { + return RepresentResult{}, err + } + result.EmbeddingCacheHits = stats.CacheHits + result.EmbeddingsCreated = stats.Created + if len(embeddingSymbols) == len(filtered.VisibleSymbols) { + progressStart(req.Progress, "Refreshing semantic filter", 1) + filtered, err = runFilter(ctx, r.Store, repositoryID, req.Thresholds, 
req.Visibility, rawGraphHash, settingsHash, vectors, changedRaw.Symbols, contextPolicies, identityKeys) + if err != nil { + progressFinish(req.Progress) + return RepresentResult{}, err + } + filtered.ChangedFiles = changedRaw.Files + progressAdvance(req.Progress, "Semantic filter refreshed") + progressFinish(req.Progress) + } + } + + representationHash := representationHash(filtered, req) + result = RepresentResult{ + RepositoryID: repositoryID, + FilterRunID: filtered.RunID, + RawGraphHash: rawGraphHash, + SettingsHash: settingsHash, + RepresentationHash: representationHash, + EmbeddingCacheHits: result.EmbeddingCacheHits, + EmbeddingsCreated: result.EmbeddingsCreated, + } + runID, err := r.Store.BeginRepresentationRun(ctx, repositoryID, rawGraphHash, settingsHash, modelIDPtr, representationHash) + if err != nil { + return RepresentResult{}, err + } + result.RepresentationRun = runID + status := "completed" + var runErr error + defer func() { + if runErr != nil { + status = "failed" + } + _ = r.Store.FinishRepresentationRun(context.Background(), runID, status, result, runErr) + }() + + progressStart(req.Progress, "Materializing representation", 3) + ownerMatcher, err := codeowners.Load(repo.RepoRoot) + if err != nil { + progressFinish(req.Progress) + runErr = err + return result, err + } + progressAdvance(req.Progress, "Ownership metadata loaded") + applyToken := randomToken() + if err := r.Store.AcquireApplyLock(ctx, repositoryID, os.Getpid(), applyToken, LockHeartbeatTimeout); err != nil { + progressFinish(req.Progress) + runErr = err + return result, err + } + progressAdvance(req.Progress, "Apply lock acquired") + defer func() { + _ = r.Store.ReleaseApplyLock(context.Background(), repositoryID, applyToken) + }() + stats, err := r.materialize(ctx, repo, filtered, req.Thresholds, settingsHash, identityKeys, ownerMatcher) + if err != nil { + progressFinish(req.Progress) + runErr = err + return result, err + } + progressAdvance(req.Progress, "Resources materialized") + progressFinish(req.Progress) + result.ElementsCreated = stats.ElementsCreated + result.ElementsUpdated = stats.ElementsUpdated + result.ConnectorsCreated = stats.ConnectorsCreated + result.ConnectorsUpdated = stats.ConnectorsUpdated + result.ViewsCreated = stats.ViewsCreated + result.ElementsPreserved = stats.ElementsPreserved + result.ConnectorsPreserved = stats.ConnectorsPreserved + result.ViewsPreserved = stats.ViewsPreserved + result.DeletesPreserved = stats.DeletesPreserved + return result, nil +} + +func (r *Representer) RepresentArchitecture(ctx context.Context, repo Repository, architecture architectureModel, thresholds Thresholds, progress ProgressSink) (RepresentResult, error) { + if r == nil || r.Store == nil { + return RepresentResult{}, fmt.Errorf("watch representer requires a store") + } + thresholds = defaultThresholds(thresholds) + rawGraphHash := stableHash(architecture) + settingsHash := stableHash(thresholds) + representationHash := stableHash([]any{rawGraphHash, settingsHash, "architecture"}) + result := RepresentResult{ + RepositoryID: repo.ID, + RawGraphHash: rawGraphHash, + SettingsHash: settingsHash, + RepresentationHash: representationHash, + } + runID, err := r.Store.BeginRepresentationRun(ctx, repo.ID, rawGraphHash, settingsHash, nil, representationHash) + if err != nil { + return RepresentResult{}, err + } + result.RepresentationRun = runID + status := "completed" + var runErr error + defer func() { + if runErr != nil { + status = "failed" + } + _ = 
r.Store.FinishRepresentationRun(context.Background(), runID, status, result, runErr) + }() + + progressStart(progress, "Materializing architecture view", 7) + applyToken := randomToken() + if err := r.Store.AcquireApplyLock(ctx, repo.ID, os.Getpid(), applyToken, LockHeartbeatTimeout); err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Apply lock acquired") + defer func() { + _ = r.Store.ReleaseApplyLock(context.Background(), repo.ID, applyToken) + }() + + initialLayout, err := r.Store.RepositoryMaterializationCount(ctx, repo.ID) + if err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Existing materialization inspected") + m := &materializer{ + store: r.Store, + repo: repo, + thresholds: thresholds, + settingsHash: settingsHash, + identityKeys: map[string]string{}, + tagPlan: semanticTagPlan{approved: map[string]struct{}{}, byOwner: map[string][]string{}}, + initialLayout: initialLayout == 0, + runMarker: time.Now().UTC().Format(time.RFC3339Nano), + newPlacements: map[int64]map[int64]struct{}{}, + } + rootViewID, err := m.workspaceRootViewID(ctx) + if err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Workspace root loaded") + repoElem, err := m.upsertElement(ctx, "repository", fmt.Sprintf("repository:%d", repo.ID), elementInput{ + Name: repo.DisplayName, + Kind: "repository", + Technology: "Runtime", + Repo: repoIdentity(repo), + Branch: nullStringValue(repo.Branch), + Tags: []string{"view:architecture"}, + }) + if err != nil { + progressFinish(progress) + runErr = err + return result, err + } + if err := m.upsertPlacement(ctx, rootViewID, repoElem, 0, 0); err != nil { + progressFinish(progress) + runErr = err + return result, err + } + repoView, err := m.upsertView(ctx, "repository", fmt.Sprintf("repository:%d", repo.ID), repoElem, repo.DisplayName, "Architecture") + if err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Repository view materialized") + if err := m.materializeArchitecture(ctx, architecture, repoView); err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Architecture resources materialized") + if err := m.pruneStaleResources(ctx); err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Stale generated resources pruned") + if err := m.layoutPlacements(ctx); err != nil { + progressFinish(progress) + runErr = err + return result, err + } + progressAdvance(progress, "Layout updated") + progressFinish(progress) + result.ElementsCreated = m.stats.ElementsCreated + result.ElementsUpdated = m.stats.ElementsUpdated + result.ConnectorsCreated = m.stats.ConnectorsCreated + result.ConnectorsUpdated = m.stats.ConnectorsUpdated + result.ViewsCreated = m.stats.ViewsCreated + result.ElementsPreserved = m.stats.ElementsPreserved + result.ConnectorsPreserved = m.stats.ConnectorsPreserved + result.ViewsPreserved = m.stats.ViewsPreserved + result.DeletesPreserved = m.stats.DeletesPreserved + return result, nil +} + +type embeddingCacheStats struct { + CacheHits int + Created int +} + +func progressStart(progress ProgressSink, label string, total int) { + if progress != nil { + progress.Start(label, total) + } +} + +func progressAdvance(progress ProgressSink, label string) { + if progress != nil { + progress.Advance(label) + } +} + +func 
progressFinish(progress ProgressSink) { + if progress != nil { + progress.Finish() + } +} + +func (r *Representer) cacheEmbeddings(ctx context.Context, modelID int64, provider Provider, repoRoot string, symbols []Symbol, identityKeys map[string]string, progress ProgressSink, timeout time.Duration) (embeddingCacheStats, map[int64]Vector, error) { + stats := embeddingCacheStats{} + vectorsBySymbol := map[int64]Vector{} + model := provider.ModelID() + if model.Provider == "none" { + return stats, vectorsBySymbol, nil + } + inputs := make([]EmbeddingInput, 0, len(symbols)) + missingSymbols := make([]Symbol, 0, len(symbols)) + progressStart(progress, "Preparing symbol embeddings", len(symbols)) + for _, sym := range symbols { + ownerKey := symbolOwnerKey(sym, identityKeys) + input := EmbeddingInput{OwnerType: "symbol", OwnerKey: ownerKey, Text: symbolEmbeddingText(repoRoot, sym)} + if data, ok, err := r.Store.Embedding(ctx, modelID, input.OwnerType, input.OwnerKey, inputHash(input)); err != nil { + progressFinish(progress) + return stats, vectorsBySymbol, err + } else if !ok { + inputs = append(inputs, input) + missingSymbols = append(missingSymbols, sym) + } else { + stats.CacheHits++ + vectorsBySymbol[sym.ID] = bytesToVector(data) + } + progressAdvance(progress, sym.QualifiedName) + } + progressFinish(progress) + if len(inputs) == 0 { + return stats, vectorsBySymbol, nil + } + + if timeout <= 0 { + timeout = 60 * time.Second + } + vectors := make([]Vector, 0, len(inputs)) + progressStart(progress, "Embedding symbols", len(inputs)) + for start := 0; start < len(inputs); start += defaultEmbeddingBatchSize { + if err := ctx.Err(); err != nil { + progressFinish(progress) + return stats, vectorsBySymbol, err + } + end := min(start+defaultEmbeddingBatchSize, len(inputs)) + chunk := inputs[start:end] + embedCtx, cancel := context.WithTimeout(ctx, timeout) + chunkVectors, err := provider.Embed(embedCtx, chunk) + cancel() + if err != nil { + progressFinish(progress) + return stats, vectorsBySymbol, err + } + if len(chunkVectors) != len(chunk) { + progressFinish(progress) + return stats, vectorsBySymbol, fmt.Errorf("embedding provider returned %d vectors for %d inputs", len(chunkVectors), len(chunk)) + } + vectors = append(vectors, chunkVectors...) + for _, input := range chunk { + progressAdvance(progress, input.OwnerKey) + } + } + for i, input := range inputs { + if err := r.Store.SaveEmbedding(ctx, modelID, input.OwnerType, input.OwnerKey, inputHash(input), vectorBytes(vectors[i])); err != nil { + progressFinish(progress) + return stats, vectorsBySymbol, err + } + stats.Created++ + vectorsBySymbol[missingSymbols[i].ID] = vectors[i] + } + progressFinish(progress) + return stats, vectorsBySymbol, nil +} + +func embeddingCandidateSymbols(symbols map[int64]Symbol, limit int) []Symbol { + out := sortedSymbols(symbols) + if limit > 0 && len(out) > limit { + out = out[:limit] + } + return out +} + +func symbolEmbeddingText(repoRoot string, sym Symbol) string { + body := symbolCodeBody(repoRoot, sym) + if strings.TrimSpace(body) == "" { + body = sym.QualifiedName + "\n" + sym.Kind + "\n" + sym.FilePath + } + return shrinkEmbeddingText(outdentCode(body)) +} + +func symbolCodeBody(repoRoot string, sym Symbol) string { + if strings.TrimSpace(repoRoot) == "" || strings.TrimSpace(sym.FilePath) == "" { + return "" + } + cleanRel := filepath.Clean(filepath.FromSlash(sym.FilePath)) + if filepath.IsAbs(cleanRel) || cleanRel == "." || cleanRel == ".." 
|| strings.HasPrefix(cleanRel, ".."+string(filepath.Separator)) { + return "" + } + data, err := os.ReadFile(filepath.Join(repoRoot, cleanRel)) + if err != nil { + return "" + } + end := sym.StartLine + if sym.EndLine != nil { + end = *sym.EndLine + } + return lineRange(strings.Split(string(data), "\n"), sym.StartLine, end) +} + +func outdentCode(code string) string { + code = strings.ReplaceAll(code, "\r\n", "\n") + code = strings.ReplaceAll(code, "\r", "\n") + lines := strings.Split(code, "\n") + minIndent := -1 + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + indent := leadingIndentWidth(line) + if minIndent == -1 || indent < minIndent { + minIndent = indent + } + } + if minIndent <= 0 { + return strings.TrimSpace(code) + } + for i, line := range lines { + lines[i] = trimIndentWidth(line, minIndent) + } + return strings.TrimSpace(strings.Join(lines, "\n")) +} + +func leadingIndentWidth(line string) int { + width := 0 + for _, r := range line { + switch r { + case ' ': + width++ + case '\t': + width += 4 + default: + return width + } + } + return width +} + +func trimIndentWidth(line string, maxWidth int) string { + width := 0 + for i, r := range line { + switch r { + case ' ': + width++ + case '\t': + width += 4 + default: + return line[i:] + } + if width >= maxWidth { + return line[i+len(string(r)):] + } + } + return "" +} + +func shrinkEmbeddingText(text string) string { + text = strings.TrimSpace(text) + if approximateTokenCount(text) <= maxEmbeddingInputApproxTokens { + return text + } + text = dropLowSignalCodeLines(text) + if approximateTokenCount(text) <= maxEmbeddingInputApproxTokens { + return text + } + if len(text) <= maxEmbeddingInputChars { + return text + } + marker := "\n\n/* ... middle omitted for embedding context ... 
*/\n\n" + keep := maxEmbeddingInputChars - len(marker) + if keep <= 0 { + return text[:maxEmbeddingInputChars] + } + head := keep * 2 / 3 + tail := keep - head + return strings.TrimSpace(text[:head]) + marker + strings.TrimSpace(text[len(text)-tail:]) +} + +func approximateTokenCount(text string) int { + if text == "" { + return 0 + } + fields := strings.Fields(text) + byChars := (len(text) + 3) / 4 + if byChars > len(fields) { + return byChars + } + return len(fields) +} + +func dropLowSignalCodeLines(text string) string { + lines := strings.Split(text, "\n") + kept := make([]string, 0, len(lines)) + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "//") || strings.HasPrefix(trimmed, "/*") || strings.HasPrefix(trimmed, "*") { + continue + } + kept = append(kept, line) + } + if len(kept) == 0 { + return text + } + return strings.TrimSpace(strings.Join(kept, "\n")) +} + +type materializeStats struct { + ElementsCreated int + ElementsUpdated int + ConnectorsCreated int + ConnectorsUpdated int + ViewsCreated int + ElementsPreserved int + ConnectorsPreserved int + ViewsPreserved int + DeletesPreserved int +} + +const ( + minSemanticTagCoverage = 2 + maxSemanticTagsPerElement = 5 + maxUsefulSemanticTagRatio = 0.70 + semanticTagOwnerKeyJoinChar = "\x00" +) + +type semanticTagPlan struct { + approved map[string]struct{} + byOwner map[string][]string +} + +func buildSemanticTagPlan(repo Repository, filtered filterResult, thresholds Thresholds, settingsHash string, identityKeys map[string]string, ownerMatcher *codeowners.Matcher, facts []Fact) semanticTagPlan { + candidates := map[string][]string{} + add := func(ownerType, ownerKey string, tags ...string) { + key := semanticTagOwnerKey(ownerType, ownerKey) + candidates[key] = roleSemanticTags(uniqueSemanticTags(append(candidates[key], tags...))) + } + + repoLanguage := dominantLanguage(filtered.VisibleSymbols) + add("repository", fmt.Sprintf("repository:%d", repo.ID), semanticLanguageTag(repoLanguage)) + + visibleFiles := filesForSymbols(filtered.VisibleSymbols) + for _, folder := range folderSet(visibleFiles) { + add("folder", "folder:"+folder, append(semanticPathTags(folder, repoLanguage), ownerMatcher.TagsForPath(folder)...)...) + } + for file := range visibleFiles { + add("file", "file:"+file, append(semanticPathTags(file, languageForFile(file, filtered.VisibleSymbols)), ownerMatcher.TagsForPath(file)...)...) + } + + for file, symbols := range symbolsByFile(filtered.VisibleSymbols) { + chunks := chunkSymbols(symbols, thresholds.MaxElementsPerView) + for _, chunk := range chunks { + if len(chunks) <= 1 || len(chunk) == 0 { + continue + } + keys := make([]string, 0, len(chunk)) + for _, sym := range chunk { + keys = append(keys, sym.StableKey) + } + clusterKey := stableClusterKey(repo.ID, file, settingsHash, keys) + add("cluster", clusterKey, semanticPathTags(file, languageFromStableKey(chunk[0].StableKey))...) + } + } + + for _, sym := range sortedSymbols(filtered.VisibleSymbols) { + tags := semanticPathTags(sym.FilePath, languageFromStableKey(sym.StableKey)) + tags = append(tags, semanticKindTag(sym.Kind)) + tags = append(tags, semanticSymbolRoleTags(sym, filtered.Incoming[sym.ID], filtered.Outgoing[sym.ID])...) + tags = append(tags, ownerMatcher.TagsForPath(sym.FilePath)...) + add("symbol", symbolOwnerKey(sym, identityKeys), tags...) 
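+		// Path/area, kind, graph-shape (entrypoint, fan-in, fan-out) and CODEOWNERS tags are all offered as candidates for each visible symbol; the coverage filtering below decides which of them survive.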
+ } + + addFactSemanticTags(facts, filtered.VisibleSymbols, identityKeys, add) + + counts := map[string]int{} + forced := map[string]struct{}{} + for _, tags := range candidates { + for _, tag := range tags { + counts[tag]++ + if strings.HasPrefix(tag, "owner:") || forceFactSemanticTag(tag) { + forced[tag] = struct{}{} + } + } + } + total := len(candidates) + maxCoverage := int(math.Floor(float64(total) * maxUsefulSemanticTagRatio)) + if maxCoverage < minSemanticTagCoverage { + maxCoverage = total - 1 + } + approved := map[string]struct{}{} + for tag, count := range counts { + if _, ok := forced[tag]; ok { + approved[tag] = struct{}{} + continue + } + if count < minSemanticTagCoverage { + continue + } + if total > 1 && count > maxCoverage { + continue + } + approved[tag] = struct{}{} + } + + byOwner := map[string][]string{} + for key, tags := range candidates { + var kept []string + for _, tag := range tags { + if _, ok := approved[tag]; ok { + kept = append(kept, tag) + } + } + sort.SliceStable(kept, func(i, j int) bool { + left, right := semanticTagPriority(kept[i]), semanticTagPriority(kept[j]) + if left == right { + return kept[i] < kept[j] + } + return left < right + }) + byOwner[key] = limitSemanticTags(kept) + } + return semanticTagPlan{approved: approved, byOwner: byOwner} +} + +func addFactSemanticTags(facts []Fact, symbols map[int64]Symbol, identityKeys map[string]string, add func(ownerType, ownerKey string, tags ...string)) { + symbolOwners := map[string]string{} + for _, sym := range symbols { + symbolOwners[sym.StableKey] = symbolOwnerKey(sym, identityKeys) + } + for _, fact := range facts { + tags := factSemanticTags(fact) + if len(tags) == 0 { + continue + } + add("fact", factOwnerKey(fact), tags...) + if fact.SubjectKind == "symbol" { + if owner, ok := symbolOwners[fact.SubjectStableKey]; ok { + add("symbol", owner, tags...) + continue + } + } + if strings.TrimSpace(fact.FilePath) != "" { + add("file", "file:"+fact.FilePath, tags...) + } + } +} + +func factSemanticTags(fact Fact) []string { + tags := append([]string{}, fact.Tags...) + switch fact.Type { + case "http.route": + tags = append(tags, "http:route") + case "frontend.route": + tags = append(tags, "frontend:route") + case "orm.query": + if !hasStringPrefix(tags, "orm:") { + tags = append(tags, "orm:query") + } + } + return uniqueSemanticTags(tags) +} + +func roleSemanticTags(tags []string) []string { + out := make([]string, 0, len(tags)) + for _, tag := range tags { + if strings.HasPrefix(tag, "role:") { + out = append(out, tag) + } + } + return out +} + +func forceFactSemanticTag(tag string) bool { + return strings.HasPrefix(tag, "framework:") || + strings.HasPrefix(tag, "orm:") || + strings.HasPrefix(tag, "technology:") || + tag == "http:route" || + tag == "frontend:route" +} + +func hasStringPrefix(values []string, prefix string) bool { + for _, value := range values { + if strings.HasPrefix(value, prefix) { + return true + } + } + return false +} + +func limitSemanticTags(tags []string) []string { + if len(tags) <= maxSemanticTagsPerElement { + return tags + } + var forced []string + var regular []string + for _, tag := range tags { + if strings.HasPrefix(tag, "owner:") { + forced = append(forced, tag) + continue + } + regular = append(regular, tag) + } + limit := max(maxSemanticTagsPerElement-len(forced), 0) + if len(regular) > limit { + regular = regular[:limit] + } + out := append(regular, forced...) 
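+	// Re-sort the merged slice so the forced owner: tags return to priority order instead of trailing the regular tags they were appended after.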
+ sort.SliceStable(out, func(i, j int) bool { + left, right := semanticTagPriority(out[i]), semanticTagPriority(out[j]) + if left == right { + return out[i] < out[j] + } + return left < right + }) + return out +} + +func (p semanticTagPlan) tagsFor(ownerType, ownerKey string) []string { + tags := p.byOwner[semanticTagOwnerKey(ownerType, ownerKey)] + return append([]string(nil), tags...) +} + +func (p semanticTagPlan) approvedTags() []string { + tags := make([]string, 0, len(p.approved)) + for tag := range p.approved { + tags = append(tags, tag) + } + sort.Strings(tags) + return tags +} + +func semanticTagOwnerKey(ownerType, ownerKey string) string { + return ownerType + semanticTagOwnerKeyJoinChar + ownerKey +} + +func uniqueSemanticTags(tags []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(tags)) + for _, tag := range tags { + tag = strings.TrimSpace(tag) + if !strings.HasPrefix(tag, "owner:") { + tag = strings.ToLower(tag) + } + if tag == "" { + continue + } + if _, ok := seen[tag]; ok { + continue + } + seen[tag] = struct{}{} + out = append(out, tag) + } + return out +} + +func semanticPathTags(filePath, language string) []string { + var tags []string + if area := semanticAreaTag(filePath); area != "" { + tags = append(tags, area) + } + tags = append(tags, semanticRoleTags(filePath)...) + if tag := semanticLanguageTag(language); tag != "" { + tags = append(tags, tag) + } + return tags +} + +func semanticAreaTag(filePath string) string { + parts := strings.Split(strings.Trim(filePath, "/"), "/") + if len(parts) == 0 || parts[0] == "" || len(parts) == 1 { + return "" + } + return "area:" + semanticTagSlug(parts[0]) +} + +func semanticLanguageTag(language string) string { + language = strings.TrimSpace(strings.ToLower(language)) + if language == "" || language == "source" { + return "" + } + return "lang:" + semanticTagSlug(language) +} + +func semanticKindTag(kind string) string { + kind = strings.TrimSpace(strings.ToLower(kind)) + if kind == "" { + return "" + } + return "kind:" + semanticTagSlug(kind) +} + +func semanticSymbolRoleTags(sym Symbol, incoming, outgoing int) []string { + var tags []string + if isExportedSymbol(sym) { + tags = append(tags, "graph:entrypoint") + } + if incoming >= 3 { + tags = append(tags, "graph:fan-in") + } + if outgoing >= 3 { + tags = append(tags, "graph:fan-out") + } + nameText := strings.ToLower(sym.Name + " " + sym.QualifiedName + " " + sym.Kind) + tags = append(tags, semanticRoleTags(nameText)...) 
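+	// Illustrative example: an exported HTTP handler with three or more callers would pick up graph:entrypoint, graph:fan-in and role:api here.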
+ return tags +} + +func semanticRoleTags(text string) []string { + lower := strings.ToLower(text) + rules := []struct { + tag string + keywords []string + }{ + {"role:watch", []string{"watch", "watcher", "scan", "scanner", "represent", "materializ", "embedding"}}, + {"role:cli", []string{"cmd/", "/cmd/", "cli", "command", "cobra"}}, + {"role:api", []string{"api", "http", "server", "handler", "route", "rpc", "websocket"}}, + {"role:persistence", []string{"store", "db", "database", "sqlite", "migration", "schema", "repository"}}, + {"role:ui", []string{"frontend", "component", "view", "react", "canvas", "zui"}}, + {"role:analysis", []string{"analyzer", "symbol", "parser", "importer", "planner", "dependency"}}, + {"role:versioning", []string{"git", "version", "history", "commit", "branch"}}, + {"role:config", []string{"config", "setting", "option", "threshold"}}, + {"role:test", []string{"test", "_test.go", "fixture", "mock"}}, + } + var tags []string + for _, rule := range rules { + for _, keyword := range rule.keywords { + if strings.Contains(lower, keyword) { + tags = append(tags, rule.tag) + break + } + } + } + return tags +} + +func semanticTagSlug(value string) string { + value = strings.TrimSpace(strings.ToLower(value)) + var b strings.Builder + lastDash := false + for _, r := range value { + switch { + case r >= 'a' && r <= 'z', r >= '0' && r <= '9': + b.WriteRune(r) + lastDash = false + default: + if !lastDash && b.Len() > 0 { + b.WriteByte('-') + lastDash = true + } + } + } + return strings.Trim(b.String(), "-") +} + +func semanticTagPriority(tag string) int { + switch { + case strings.HasPrefix(tag, "role:"): + return 0 + case strings.HasPrefix(tag, "framework:"): + return 1 + case strings.HasPrefix(tag, "technology:"): + return 1 + case tag == "http:route" || tag == "frontend:route" || strings.HasPrefix(tag, "orm:"): + return 1 + case strings.HasPrefix(tag, "area:"): + return 2 + case strings.HasPrefix(tag, "kind:"): + return 3 + case strings.HasPrefix(tag, "graph:"): + return 4 + case strings.HasPrefix(tag, "lang:"): + return 5 + case strings.HasPrefix(tag, "owner:"): + return 6 + default: + return 7 + } +} + +func (r *Representer) materialize(ctx context.Context, repo Repository, filtered filterResult, thresholds Thresholds, settingsHash string, identityKeys map[string]string, ownerMatcher *codeowners.Matcher) (materializeStats, error) { + initialLayout, err := r.Store.RepositoryMaterializationCount(ctx, repo.ID) + if err != nil { + return materializeStats{}, err + } + facts, err := r.Store.FactsForRepository(ctx, repo.ID) + if err != nil { + return materializeStats{}, err + } + tagPlan := buildSemanticTagPlan(repo, filtered, thresholds, settingsHash, identityKeys, ownerMatcher, facts) + m := &materializer{store: r.Store, repo: repo, thresholds: thresholds, settingsHash: settingsHash, identityKeys: identityKeys, contextPolicies: filtered.ContextPolicies, tagPlan: tagPlan, initialLayout: initialLayout == 0, runMarker: time.Now().UTC().Format(time.RFC3339Nano), newPlacements: map[int64]map[int64]struct{}{}} + if err := m.ensureTags(ctx); err != nil { + return m.stats, err + } + rootViewID, err := m.workspaceRootViewID(ctx) + if err != nil { + return m.stats, err + } + repoLanguage := dominantLanguage(filtered.VisibleSymbols) + repoElem, err := m.upsertElement(ctx, "repository", fmt.Sprintf("repository:%d", repo.ID), elementInput{ + Name: repo.DisplayName, + Kind: "repository", + Technology: technologyLabel(repoLanguage), + Repo: repoIdentity(repo), + Branch: 
nullStringValue(repo.Branch), + Language: repoLanguage, + Tags: tagPlan.tagsFor("repository", fmt.Sprintf("repository:%d", repo.ID)), + }) + if err != nil { + return m.stats, err + } + if err := m.upsertPlacement(ctx, rootViewID, repoElem, 0, 0); err != nil { + return m.stats, err + } + repoView, err := m.upsertView(ctx, "repository", fmt.Sprintf("repository:%d", repo.ID), repoElem, repo.DisplayName, "Repository") + if err != nil { + return m.stats, err + } + + architectureView, structuralView, err := m.materializeRepositorySections(ctx, repoView, repoLanguage) + if err != nil { + return m.stats, err + } + + architecture := pruneDisconnectedArchitecture(canonicalizeArchitecture(mergeArchitectureModels(inferArchitecture(repo.RepoRoot), architectureFromFacts(facts)))) + + visibleFiles := filesForSymbols(filtered.VisibleSymbols) + for file := range filtered.VisibleFiles { + visibleFiles[file] = struct{}{} + } + for file := range filtered.ChangedFiles { + visibleFiles[file] = struct{}{} + } + folders := folderSet(visibleFiles) + folderElements := map[string]int64{} + folderViews := map[string]int64{} + for _, folder := range folders { + parentView := structuralView + if parent := path.Dir(folder); parent != "." && parent != "/" { + if id, ok := folderViews[parent]; ok { + parentView = id + } + } + elem, err := m.upsertElement(ctx, "folder", "folder:"+folder, elementInput{ + Name: path.Base(folder), + Kind: "folder", + Technology: technologyLabel(repoLanguage), + Repo: repoIdentity(repo), + Branch: nullStringValue(repo.Branch), + FilePath: folder, + Language: repoLanguage, + Tags: tagPlan.tagsFor("folder", "folder:"+folder), + }) + if err != nil { + return m.stats, err + } + x, y := gridPosition(len(folderViews)) + if err := m.upsertPlacement(ctx, parentView, elem, x, y); err != nil { + return m.stats, err + } + view, err := m.upsertView(ctx, "folder", "folder:"+folder, elem, folder, "Folder") + if err != nil { + return m.stats, err + } + folderElements[folder] = elem + folderViews[folder] = view + } + + fileElements := map[string]int64{} + fileViews := map[string]int64{} + fileLanguages, err := m.store.FileLanguages(ctx, repo.ID) + if err != nil { + return m.stats, err + } + for i, file := range sortedKeys(visibleFiles) { + fileLanguage := languageForFile(file, filtered.VisibleSymbols) + if language := strings.TrimSpace(fileLanguages[file]); language != "" { + fileLanguage = language + } + parentView := structuralView + if dir := path.Dir(file); dir != "." 
{ + if id, ok := folderViews[dir]; ok { + parentView = id + } + } + elem, err := m.upsertElement(ctx, "file", "file:"+file, elementInput{ + Name: path.Base(file), + Kind: "file", + Technology: technologyLabel(fileLanguage), + Repo: repoIdentity(repo), + Branch: nullStringValue(repo.Branch), + FilePath: file, + Language: fileLanguage, + Tags: tagPlan.tagsFor("file", "file:"+file), + }) + if err != nil { + return m.stats, err + } + x, y := gridPosition(i) + if err := m.upsertPlacement(ctx, parentView, elem, x, y); err != nil { + return m.stats, err + } + view, err := m.upsertView(ctx, "file", "file:"+file, elem, file, "File") + if err != nil { + return m.stats, err + } + fileElements[file] = elem + fileViews[file] = view + } + + symbolElements := map[int64]int64{} + symbolViews := map[int64]int64{} + symbolPositions := map[int64]layoutPoint{} + occupied := map[int64]map[string]struct{}{} + detailedSymbols := len(filtered.VisibleSymbols) <= maxDetailedSymbolElements + for file, symbols := range symbolsByFile(filtered.VisibleSymbols) { + fileView := fileViews[file] + if fileView == 0 { + continue + } + chunks := chunkSymbols(symbols, effectiveMaxElementsPerView(thresholds, filtered.Visibility, filtered.ContextExpansions.fileTier(file))) + for chunkIndex, chunk := range chunks { + targetView := fileView + if len(chunks) > 1 { + keys := make([]string, 0, len(chunk)) + ids := make([]int64, 0, len(chunk)) + for _, sym := range chunk { + keys = append(keys, sym.StableKey) + ids = append(ids, sym.ID) + } + clusterKey := stableClusterKey(repo.ID, file, settingsHash, keys) + cluster, err := m.store.UpsertCluster(ctx, repo.ID, clusterKey, nil, fmt.Sprintf("%s cluster %d", path.Base(file), chunkIndex+1), "structural", "deterministic-chunk", settingsHash, ids) + if err != nil { + return m.stats, err + } + clusterElem, err := m.upsertElement(ctx, "cluster", clusterKey, elementInput{ + Name: cluster.Name, + Kind: "cluster", + Technology: technologyLabel(languageFromStableKey(chunk[0].StableKey)), + Repo: repoIdentity(repo), + Branch: nullStringValue(repo.Branch), + FilePath: file, + Language: languageFromStableKey(chunk[0].StableKey), + Tags: tagPlan.tagsFor("cluster", clusterKey), + }) + if err != nil { + return m.stats, err + } + x, y := gridPosition(chunkIndex) + if err := m.upsertPlacement(ctx, fileView, clusterElem, x, y); err != nil { + return m.stats, err + } + markOccupied(occupied, fileView, layoutPoint{X: x, Y: y}) + targetView, err = m.upsertView(ctx, "cluster", clusterKey, clusterElem, cluster.Name, "Cluster") + if err != nil { + return m.stats, err + } + } + if !detailedSymbols { + continue + } + for i, sym := range chunk { + language := languageFromStableKey(sym.StableKey) + elem, err := m.upsertElement(ctx, "symbol", symbolOwnerKey(sym, m.identityKeys), elementInput{ + Name: sym.QualifiedName, + Kind: sym.Kind, + Description: fmt.Sprintf("%s:%d", sym.FilePath, sym.StartLine), + Technology: technologyLabel(language), + Repo: repoIdentity(repo), + Branch: nullStringValue(repo.Branch), + FilePath: sym.FilePath, + Language: language, + Tags: tagPlan.tagsFor("symbol", symbolOwnerKey(sym, m.identityKeys)), + }) + if err != nil { + return m.stats, err + } + x, y := gridPosition(i) + if err := m.upsertPlacement(ctx, targetView, elem, x, y); err != nil { + return m.stats, err + } + point := layoutPoint{X: x, Y: y} + markOccupied(occupied, targetView, point) + symbolElements[sym.ID] = elem + symbolViews[sym.ID] = targetView + symbolPositions[sym.ID] = point + } + } + } + if err := 
m.materializeFacts(ctx, filtered.VisibleFacts, filtered.VisibleSymbols, fileViews, symbolElements, symbolViews, symbolPositions, occupied, filtered); err != nil { + return m.stats, err + } + + if err := m.materializeConnectors(ctx, filtered.VisibleReferences, filtered.VisibleSymbols, folderElements, fileElements, symbolElements, symbolViews, structuralView); err != nil { + return m.stats, err + } + if len(architecture.Components) > 0 { + if err := m.materializeArchitecture(ctx, architecture, architectureView); err != nil { + return m.stats, err + } + } + if err := m.pruneStaleResources(ctx); err != nil { + return m.stats, err + } + if err := m.layoutPlacements(ctx); err != nil { + return m.stats, err + } + return m.stats, nil +} + +func (m *materializer) materializeArchitecture(ctx context.Context, architecture architectureModel, repoView int64) error { + componentElements := map[string]int64{} + for i, component := range sortedArchitectureComponents(architecture.Components) { + tags := appendUnique(component.Tags, "view:architecture") + elem, err := m.upsertElement(ctx, "architecture-component", component.Key, elementInput{ + Name: component.Name, + Kind: component.Kind, + Description: component.Description, + Technology: firstNonEmpty(component.Technology, "Runtime"), + Repo: repoIdentity(m.repo), + Branch: nullStringValue(m.repo.Branch), + FilePath: component.FilePath, + Tags: tags, + }) + if err != nil { + return err + } + x, y := gridPosition(i) + if err := m.upsertPlacement(ctx, repoView, elem, x, y); err != nil { + return err + } + componentElements[component.Key] = elem + } + + for _, connector := range sortedArchitectureConnectors(architecture.Connectors) { + sourceID := componentElements[connector.SourceKey] + targetID := componentElements[connector.TargetKey] + if sourceID == 0 || targetID == 0 { + continue + } + if err := m.upsertConnectorDetailedWithDirection(ctx, "architecture-connector", connector.Key, repoView, sourceID, targetID, connector.Label, connector.Relationship, connector.Direction, ""); err != nil { + return err + } + } + return nil +} + +func (m *materializer) materializeRepositorySections(ctx context.Context, repoView int64, repoLanguage string) (int64, int64, error) { + architectureElem, err := m.upsertElement(ctx, "repository-section", fmt.Sprintf("repository-architecture:%d", m.repo.ID), elementInput{ + Name: "Architecture", + Kind: "view", + Description: "Generated architecture view", + Technology: "Architecture", + Repo: repoIdentity(m.repo), + Branch: nullStringValue(m.repo.Branch), + Language: repoLanguage, + Tags: []string{"view:architecture"}, + }) + if err != nil { + return 0, 0, err + } + if err := m.upsertPlacement(ctx, repoView, architectureElem, 0, 0); err != nil { + return 0, 0, err + } + architectureView, err := m.upsertView(ctx, "repository-section", fmt.Sprintf("repository-architecture:%d", m.repo.ID), architectureElem, m.repo.DisplayName+" Architecture", "Architecture") + if err != nil { + return 0, 0, err + } + + structuralElem, err := m.upsertElement(ctx, "repository-section", fmt.Sprintf("repository-structural:%d", m.repo.ID), elementInput{ + Name: "Structural", + Kind: "view", + Description: "Generated structural code view", + Technology: "Structural", + Repo: repoIdentity(m.repo), + Branch: nullStringValue(m.repo.Branch), + Language: repoLanguage, + Tags: []string{"view:structural"}, + }) + if err != nil { + return 0, 0, err + } + if err := m.upsertPlacement(ctx, repoView, structuralElem, 260, 0); err != nil { + return 0, 0, err + } + 
structuralView, err := m.upsertView(ctx, "repository-section", fmt.Sprintf("repository-structural:%d", m.repo.ID), structuralElem, m.repo.DisplayName+" Structural", "Structural") + if err != nil { + return 0, 0, err + } + return architectureView, structuralView, nil +} + +func sortedArchitectureComponents(values map[string]*architectureComponent) []*architectureComponent { + out := make([]*architectureComponent, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + sort.SliceStable(out, func(i, j int) bool { + if out[i].Kind == out[j].Kind { + return out[i].Name < out[j].Name + } + return architectureKindRank(out[i].Kind) < architectureKindRank(out[j].Kind) + }) + return out +} + +func sortedArchitectureConnectors(values map[string]*architectureConnector) []*architectureConnector { + out := make([]*architectureConnector, 0, len(values)) + for _, value := range values { + out = append(out, value) + } + sort.SliceStable(out, func(i, j int) bool { return out[i].Key < out[j].Key }) + return out +} + +func architectureKindRank(kind string) int { + switch kind { + case "external": + return 0 + case "service": + return 1 + case "interface": + return 2 + case "datastore": + return 3 + case "queue": + return 4 + default: + return 5 + } +} + +type materializer struct { + store *Store + repo Repository + thresholds Thresholds + settingsHash string + identityKeys map[string]string + contextPolicies contextPolicySet + tagPlan semanticTagPlan + initialLayout bool + runMarker string + newPlacements map[int64]map[int64]struct{} + stats materializeStats +} + +type layoutPoint struct { + X float64 + Y float64 +} + +type factPlacement struct { + Point layoutPoint + SourceHandle string + TargetHandle string +} + +type elementInput struct { + Name string + Kind string + Description string + Technology string + Repo string + Branch string + FilePath string + Language string + Tags []string +} + +type materializedTechnologyLink struct { + Type string `json:"type"` + Slug string `json:"slug,omitempty"` + Label string `json:"label"` + IsPrimaryIcon bool `json:"is_primary_icon,omitempty"` +} + +func (m *materializer) ensureTags(ctx context.Context) error { + return tagcolors.Ensure(ctx, m.store.db, m.tagPlan.approvedTags()) +} + +func (m *materializer) workspaceRootViewID(ctx context.Context) (int64, error) { + var id int64 + err := m.store.db.QueryRowContext(ctx, `SELECT id FROM views WHERE owner_element_id IS NULL ORDER BY id LIMIT 1`).Scan(&id) + return id, err +} + +func (m *materializer) upsertElement(ctx context.Context, ownerType, ownerKey string, input elementInput) (int64, error) { + input.Technology, input.Tags = extractTechnologyFromTags(input.Technology, input.Tags) + if state, ok, err := m.store.MappingState(ctx, m.repo.ID, ownerType, ownerKey, "element"); err != nil { + return 0, err + } else if ok && elementExists(ctx, m.store.db, state.ResourceID) { + dirty, err := m.mappingDirty(ctx, ownerType, ownerKey, "element", state) + if err != nil { + return 0, err + } + if dirty { + m.stats.ElementsPreserved++ + return state.ResourceID, m.saveMapping(ctx, ownerType, ownerKey, "element", state.ResourceID) + } + tags, _ := json.Marshal(input.Tags) + techLinks, _ := json.Marshal(technologyLinksForElement(input.Technology, input.Language)) + _, err = m.store.db.ExecContext(ctx, ` + UPDATE elements + SET name = ?, kind = ?, description = ?, technology = ?, technology_connectors = ?, tags = ?, repo = ?, branch = ?, file_path = ?, language = ?, updated_at = ? 
+ WHERE id = ?`, + input.Name, nullString(input.Kind), nullString(input.Description), nullString(input.Technology), string(techLinks), string(tags), + nullString(input.Repo), nullString(input.Branch), nullString(input.FilePath), nullString(input.Language), nowString(), state.ResourceID) + if err != nil { + return 0, err + } + if err := m.saveMappingWithCurrentHash(ctx, ownerType, ownerKey, "element", state.ResourceID); err != nil { + return 0, err + } + m.stats.ElementsUpdated++ + return state.ResourceID, nil + } + now := nowString() + tags, _ := json.Marshal(input.Tags) + techLinks, _ := json.Marshal(technologyLinksForElement(input.Technology, input.Language)) + res, err := m.store.db.ExecContext(ctx, ` + INSERT INTO elements(name, kind, description, technology, technology_connectors, tags, repo, branch, file_path, language, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + input.Name, nullString(input.Kind), nullString(input.Description), nullString(input.Technology), string(techLinks), string(tags), + nullString(input.Repo), nullString(input.Branch), nullString(input.FilePath), nullString(input.Language), now, now) + if err != nil { + return 0, err + } + id, err := res.LastInsertId() + if err != nil { + return 0, err + } + if err := m.saveMappingWithCurrentHash(ctx, ownerType, ownerKey, "element", id); err != nil { + return 0, err + } + m.stats.ElementsCreated++ + return id, nil +} + +func (m *materializer) upsertView(ctx context.Context, ownerType, ownerKey string, ownerElementID int64, name, label string) (int64, error) { + if state, ok, err := m.store.MappingState(ctx, m.repo.ID, ownerType, ownerKey, "view"); err != nil { + return 0, err + } else if ok && viewExists(ctx, m.store.db, state.ResourceID) { + dirty, err := m.mappingDirty(ctx, ownerType, ownerKey, "view", state) + if err != nil { + return 0, err + } + if dirty { + m.stats.ViewsPreserved++ + return state.ResourceID, m.saveMapping(ctx, ownerType, ownerKey, "view", state.ResourceID) + } + if _, err := m.store.db.ExecContext(ctx, `UPDATE views SET owner_element_id = ?, name = ?, level_label = ?, updated_at = ? WHERE id = ?`, ownerElementID, name, label, nowString(), state.ResourceID); err != nil { + return 0, err + } + return state.ResourceID, m.saveMappingWithCurrentHash(ctx, ownerType, ownerKey, "view", state.ResourceID) + } + now := nowString() + res, err := m.store.db.ExecContext(ctx, `INSERT INTO views(owner_element_id, name, level_label, level, created_at, updated_at) VALUES (?, ?, ?, 1, ?, ?)`, ownerElementID, name, label, now, now) + if err != nil { + return 0, err + } + id, err := res.LastInsertId() + if err != nil { + return 0, err + } + if err := m.saveMappingWithCurrentHash(ctx, ownerType, ownerKey, "view", id); err != nil { + return 0, err + } + m.stats.ViewsCreated++ + return id, nil +} + +func (m *materializer) upsertPlacement(ctx context.Context, viewID, elementID int64, x, y float64) error { + var existingID int64 + err := m.store.db.QueryRowContext(ctx, `SELECT id FROM placements WHERE view_id = ? AND element_id = ?`, viewID, elementID).Scan(&existingID) + if err == nil { + return nil + } + if !errors.Is(err, sql.ErrNoRows) { + return err + } + err = m.store.db.QueryRowContext(ctx, ` + SELECT p.id + FROM placements p + JOIN watch_materialization wm + ON wm.repository_id = ? AND wm.resource_type = 'element' AND wm.resource_id = p.element_id + WHERE p.element_id = ? 
+ ORDER BY p.id + LIMIT 1`, m.repo.ID, elementID).Scan(&existingID) + if err == nil { + _, err = m.store.db.ExecContext(ctx, `UPDATE placements SET view_id = ?, position_x = ?, position_y = ?, updated_at = ? WHERE id = ?`, viewID, x, y, nowString(), existingID) + if err == nil { + m.markNewPlacement(viewID, elementID) + } + return err + } + if !errors.Is(err, sql.ErrNoRows) { + return err + } + now := nowString() + _, err = m.store.db.ExecContext(ctx, ` + INSERT INTO placements(view_id, element_id, position_x, position_y, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)`, + viewID, elementID, x, y, now, now) + if err == nil { + m.markNewPlacement(viewID, elementID) + } + return err +} + +func (m *materializer) markNewPlacement(viewID, elementID int64) { + if m.newPlacements == nil { + m.newPlacements = map[int64]map[int64]struct{}{} + } + if m.newPlacements[viewID] == nil { + m.newPlacements[viewID] = map[int64]struct{}{} + } + m.newPlacements[viewID][elementID] = struct{}{} +} + +const ( + watchLayoutNodeWidth = 140.0 + watchLayoutNodeHeight = 80.0 + watchLayoutGapX = 260.0 + watchLayoutGapY = 170.0 + watchLayoutMaxRowsPerColumn = 6 +) + +type watchPlacementNode struct { + ElementID int64 + X float64 + Y float64 +} + +type watchLayoutConnector struct { + Source int64 + Target int64 +} + +func (m *materializer) layoutPlacements(ctx context.Context) error { + targets := m.newPlacements + if m.initialLayout { + var err error + targets, err = m.generatedPlacementsByView(ctx) + if err != nil { + return err + } + } + for viewID, elementIDs := range targets { + if len(elementIDs) == 0 { + continue + } + if err := m.layoutView(ctx, viewID, elementIDs, m.initialLayout); err != nil { + return err + } + } + return nil +} + +func (m *materializer) generatedPlacementsByView(ctx context.Context) (map[int64]map[int64]struct{}, error) { + rows, err := m.store.db.QueryContext(ctx, ` + SELECT p.view_id, p.element_id + FROM placements p + JOIN watch_materialization wm + ON wm.repository_id = ? AND wm.resource_type = 'element' AND wm.resource_id = p.element_id + ORDER BY p.view_id, p.id`, m.repo.ID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := map[int64]map[int64]struct{}{} + for rows.Next() { + var viewID, elementID int64 + if err := rows.Scan(&viewID, &elementID); err != nil { + return nil, err + } + if out[viewID] == nil { + out[viewID] = map[int64]struct{}{} + } + out[viewID][elementID] = struct{}{} + } + return out, rows.Err() +} + +func (m *materializer) layoutView(ctx context.Context, viewID int64, targets map[int64]struct{}, force bool) error { + placements, err := m.viewPlacementNodes(ctx, viewID) + if err != nil { + return err + } + connectors, err := m.viewLayoutConnectors(ctx, viewID) + if err != nil { + return err + } + if force || hasNoPreservedPlacements(placements, targets) { + next := organicWatchLayout(targets, connectors) + for _, elementID := range sortedInt64Set(targets) { + pos := next[elementID] + if _, err := m.store.db.ExecContext(ctx, `UPDATE placements SET position_x = ?, position_y = ?, updated_at = ? WHERE view_id = ? 
AND element_id = ?`, pos.X, pos.Y, nowString(), viewID, elementID); err != nil { + return err + } + } + _ = placements // already committed; kept for potential future collision pass + return nil + } + + positioned := map[int64]watchPlacementNode{} + for _, p := range placements { + if _, isNew := targets[p.ElementID]; !isNew { + positioned[p.ElementID] = p + } + } + occupied := occupiedWatchCells(placements, targets) + for _, elementID := range sortedInt64Set(targets) { + x, y := bestIncrementalWatchPosition(elementID, positioned, occupied, connectors) + occupied[watchCellKey(x, y)] = struct{}{} + positioned[elementID] = watchPlacementNode{ElementID: elementID, X: x, Y: y} + if _, err := m.store.db.ExecContext(ctx, `UPDATE placements SET position_x = ?, position_y = ?, updated_at = ? WHERE view_id = ? AND element_id = ?`, x, y, nowString(), viewID, elementID); err != nil { + return err + } + } + return nil +} + +func hasNoPreservedPlacements(placements []watchPlacementNode, targets map[int64]struct{}) bool { + if len(targets) == 0 { + return false + } + for _, p := range placements { + if _, isTarget := targets[p.ElementID]; !isTarget { + return false + } + } + return true +} + +func (m *materializer) viewPlacementNodes(ctx context.Context, viewID int64) ([]watchPlacementNode, error) { + rows, err := m.store.db.QueryContext(ctx, `SELECT element_id, position_x, position_y FROM placements WHERE view_id = ? ORDER BY id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []watchPlacementNode + for rows.Next() { + var p watchPlacementNode + if err := rows.Scan(&p.ElementID, &p.X, &p.Y); err != nil { + return nil, err + } + out = append(out, p) + } + return out, rows.Err() +} + +func (m *materializer) viewLayoutConnectors(ctx context.Context, viewID int64) ([]watchLayoutConnector, error) { + rows, err := m.store.db.QueryContext(ctx, `SELECT source_element_id, target_element_id FROM connectors WHERE view_id = ? ORDER BY id`, viewID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []watchLayoutConnector + for rows.Next() { + var c watchLayoutConnector + if err := rows.Scan(&c.Source, &c.Target); err != nil { + return nil, err + } + out = append(out, c) + } + return out, rows.Err() +} + +// organicWatchLayout runs the force-directed OrganicLayout on the target element +// set, using only the connectors that exist between those targets. +// It returns a position map keyed by element ID. +func organicWatchLayout(targets map[int64]struct{}, connectors []watchLayoutConnector) map[int64]watchPlacementNode { + // Build layout nodes. + nodeByID := make(map[int64]*layout.Node, len(targets)) + nodes := make([]*layout.Node, 0, len(targets)) + for id := range targets { + n := &layout.Node{ID: id} + nodeByID[id] = n + nodes = append(nodes, n) + } + + // Build layout edges (only between targets). 
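+	// Connectors that reference an element outside the target set (for example a manually placed element) are skipped, so only edges between generated placements influence the force-directed pass; directed levels are applied afterwards to spread callers and callees into columns.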
+ var edges []*layout.Edge + for _, c := range connectors { + src, srcOK := nodeByID[c.Source] + tgt, tgtOK := nodeByID[c.Target] + if srcOK && tgtOK { + edges = append(edges, &layout.Edge{Source: src, Target: tgt}) + } + } + + layout.OrganicLayout(nodes, edges) + applyDirectedWatchLevels(nodes, connectors, targets) + + out := make(map[int64]watchPlacementNode, len(nodes)) + for _, n := range nodes { + out[n.ID] = watchPlacementNode{ElementID: n.ID, X: n.X, Y: n.Y} + } + return out +} + +func applyDirectedWatchLevels(nodes []*layout.Node, connectors []watchLayoutConnector, targets map[int64]struct{}) { + if len(nodes) == 0 { + return + } + level := directedWatchLevels(targets, connectors) + maxLevel := 0 + for _, value := range level { + if value > maxLevel { + maxLevel = value + } + } + if maxLevel == 0 { + sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) + for i, n := range nodes { + n.X = float64(i/watchLayoutMaxRowsPerColumn) * watchLayoutGapX + n.Y = float64(i%watchLayoutMaxRowsPerColumn) * watchLayoutGapY + } + return + } + nodesByLevel := map[int][]*layout.Node{} + for _, n := range nodes { + nodesByLevel[level[n.ID]] = append(nodesByLevel[level[n.ID]], n) + } + nextCol := 0 + for _, col := range sortedLayoutNodeLevels(nodesByLevel) { + group := nodesByLevel[col] + sort.Slice(group, func(i, j int) bool { + if group[i].Y == group[j].Y { + return group[i].ID < group[j].ID + } + return group[i].Y < group[j].Y + }) + for row, n := range group { + n.X = float64(nextCol+row/watchLayoutMaxRowsPerColumn) * watchLayoutGapX + n.Y = float64(row%watchLayoutMaxRowsPerColumn) * watchLayoutGapY + } + nextCol += max(1, (len(group)+watchLayoutMaxRowsPerColumn-1)/watchLayoutMaxRowsPerColumn) + } +} + +func directedWatchLevels(targets map[int64]struct{}, connectors []watchLayoutConnector) map[int64]int { + level := map[int64]int{} + for id := range targets { + level[id] = 0 + } + for i := 0; i < len(targets); i++ { + changed := false + for _, c := range connectors { + if _, ok := targets[c.Source]; !ok { + continue + } + if _, ok := targets[c.Target]; !ok { + continue + } + if level[c.Source] >= len(targets)-1 { + continue + } + next := level[c.Source] + 1 + if level[c.Target] < next { + level[c.Target] = next + changed = true + } + } + if !changed { + break + } + } + for id, value := range level { + if value >= len(targets) { + level[id] = 0 + } + } + return level +} + +func sortedLayoutNodeLevels(values map[int][]*layout.Node) []int { + out := make([]int, 0, len(values)) + for value := range values { + out = append(out, value) + } + sort.Ints(out) + return out +} + +func bestIncrementalWatchPosition(elementID int64, positioned map[int64]watchPlacementNode, occupied map[string]struct{}, connectors []watchLayoutConnector) (float64, float64) { + candidates := watchLayoutCandidates(positioned) + bestX, bestY := 0.0, 0.0 + bestScore := math.Inf(1) + for _, candidate := range candidates { + if _, blocked := occupied[watchCellKey(candidate.X, candidate.Y)]; blocked { + continue + } + score := incrementalWatchScore(elementID, candidate, positioned, connectors) + if score < bestScore { + bestScore = score + bestX, bestY = candidate.X, candidate.Y + } + } + if math.IsInf(bestScore, 1) { + return nearestFreeWatchCell(0, 0, occupied) + } + return bestX, bestY +} + +func incrementalWatchScore(elementID int64, candidate watchPlacementNode, positioned map[int64]watchPlacementNode, connectors []watchLayoutConnector) float64 { + score := math.Abs(candidate.X)*0.01 + 
math.Abs(candidate.Y)*0.01 + candidateEdges := [][2]watchPlacementNode{} + existingEdges := [][2]watchPlacementNode{} + for _, c := range connectors { + source, sourceOK := positioned[c.Source] + target, targetOK := positioned[c.Target] + if c.Source == elementID { + source, sourceOK = candidate, true + } + if c.Target == elementID { + target, targetOK = candidate, true + } + if sourceOK && targetOK { + edge := [2]watchPlacementNode{source, target} + if c.Source == elementID || c.Target == elementID { + candidateEdges = append(candidateEdges, edge) + score += watchDistance(source, target) + } else { + existingEdges = append(existingEdges, edge) + } + } + } + if len(candidateEdges) == 0 { + return score + nearestWatchNeighborDistance(candidate, positioned) + } + for _, candidateEdge := range candidateEdges { + for _, existingEdge := range existingEdges { + if candidateEdge[0].ElementID == existingEdge[0].ElementID || candidateEdge[0].ElementID == existingEdge[1].ElementID || + candidateEdge[1].ElementID == existingEdge[0].ElementID || candidateEdge[1].ElementID == existingEdge[1].ElementID { + continue + } + if watchSegmentsIntersect(candidateEdge[0], candidateEdge[1], existingEdge[0], existingEdge[1]) { + score += 10000 + } + } + } + return score +} + +func watchLayoutCandidates(positioned map[int64]watchPlacementNode) []watchPlacementNode { + minCol, maxCol, minRow, maxRow := 0, 4, 0, 3 + if len(positioned) > 0 { + minCol, maxCol, minRow, maxRow = math.MaxInt, math.MinInt, math.MaxInt, math.MinInt + for _, p := range positioned { + col := int(math.Round(p.X / watchLayoutGapX)) + row := int(math.Round(p.Y / watchLayoutGapY)) + if col < minCol { + minCol = col + } + if col > maxCol { + maxCol = col + } + if row < minRow { + minRow = row + } + if row > maxRow { + maxRow = row + } + } + minCol-- + maxCol += 2 + minRow-- + maxRow += 2 + } + out := make([]watchPlacementNode, 0, (maxCol-minCol+1)*(maxRow-minRow+1)) + for col := minCol; col <= maxCol; col++ { + for row := minRow; row <= maxRow; row++ { + out = append(out, watchPlacementNode{X: float64(col) * watchLayoutGapX, Y: float64(row) * watchLayoutGapY}) + } + } + return out +} + +func occupiedWatchCells(placements []watchPlacementNode, ignored map[int64]struct{}) map[string]struct{} { + occupied := map[string]struct{}{} + for _, p := range placements { + if _, ok := ignored[p.ElementID]; ok { + continue + } + occupied[watchCellKey(p.X, p.Y)] = struct{}{} + } + return occupied +} + +func nearestFreeWatchCell(x, y float64, occupied map[string]struct{}) (float64, float64) { + baseCol := int(math.Round(x / watchLayoutGapX)) + baseRow := int(math.Round(y / watchLayoutGapY)) + for radius := range 200 { + for col := baseCol - radius; col <= baseCol+radius; col++ { + for row := baseRow - radius; row <= baseRow+radius; row++ { + if watchAbsInt(col-baseCol) != radius && watchAbsInt(row-baseRow) != radius { + continue + } + nx, ny := float64(col)*watchLayoutGapX, float64(row)*watchLayoutGapY + if _, ok := occupied[watchCellKey(nx, ny)]; !ok { + return nx, ny + } + } + } + } + return x, y +} + +func watchCellKey(x, y float64) string { + return fmt.Sprintf("%d:%d", int(math.Round(x/watchLayoutGapX)), int(math.Round(y/watchLayoutGapY))) +} + +func watchDistance(a, b watchPlacementNode) float64 { + return math.Hypot(a.X-b.X, a.Y-b.Y) +} + +func nearestWatchNeighborDistance(candidate watchPlacementNode, positioned map[int64]watchPlacementNode) float64 { + if len(positioned) == 0 { + return 0 + } + best := math.Inf(1) + for _, p := range positioned { + if 
d := watchDistance(candidate, p); d < best { + best = d + } + } + return best +} + +func watchCenter(p watchPlacementNode) (float64, float64) { + return p.X + watchLayoutNodeWidth/2, p.Y + watchLayoutNodeHeight/2 +} + +func watchSegmentsIntersect(a, b, c, d watchPlacementNode) bool { + ax, ay := watchCenter(a) + bx, by := watchCenter(b) + cx, cy := watchCenter(c) + dx, dy := watchCenter(d) + return segmentOrientation(ax, ay, cx, cy, dx, dy) != segmentOrientation(bx, by, cx, cy, dx, dy) && + segmentOrientation(ax, ay, bx, by, cx, cy) != segmentOrientation(ax, ay, bx, by, dx, dy) +} + +func segmentOrientation(ax, ay, bx, by, cx, cy float64) int { + value := (by-ay)*(cx-bx) - (bx-ax)*(cy-by) + if math.Abs(value) < 0.000001 { + return 0 + } + if value > 0 { + return 1 + } + return -1 +} + +func sortedInt64Set(values map[int64]struct{}) []int64 { + out := make([]int64, 0, len(values)) + for value := range values { + out = append(out, value) + } + slices.Sort(out) + return out +} + +func watchAbsInt(value int) int { + if value < 0 { + return -value + } + return value +} + +type filePairReference struct { + Key string + Ref Reference + Count int +} + +func (m *materializer) materializeFacts(ctx context.Context, facts []Fact, symbols map[int64]Symbol, fileViews map[string]int64, symbolElements map[int64]int64, symbolViews map[int64]int64, symbolPositions map[int64]layoutPoint, occupied map[int64]map[string]struct{}, filtered filterResult) error { + if len(facts) == 0 { + return nil + } + symbolIDByStable := map[string]int64{} + for id, sym := range symbols { + symbolIDByStable[sym.StableKey] = id + } + nodeFactsByFile := map[string][]Fact{} + summaryFactsByFile := map[string][]Fact{} + for _, fact := range facts { + if strings.TrimSpace(fact.FilePath) == "" || fileViews[fact.FilePath] == 0 { + continue + } + if highSignalFact(fact) { + nodeFactsByFile[fact.FilePath] = append(nodeFactsByFile[fact.FilePath], fact) + } else { + summaryFactsByFile[fact.FilePath] = append(summaryFactsByFile[fact.FilePath], fact) + } + } + fileSet := map[string]struct{}{} + for file := range nodeFactsByFile { + fileSet[file] = struct{}{} + } + for file := range summaryFactsByFile { + fileSet[file] = struct{}{} + } + for _, file := range sortedKeys(fileSet) { + items := nodeFactsByFile[file] + sort.SliceStable(items, func(i, j int) bool { + if items[i].Type == items[j].Type { + return factOwnerKey(items[i]) < factOwnerKey(items[j]) + } + return items[i].Type < items[j].Type + }) + limit := min(factNodeLimitForFile(m.thresholds, filtered.Visibility, filtered.ContextExpansions.fileTier(file)), len(items)) + subjectFactCounts := map[int64]int{} + for i, fact := range items[:limit] { + elem, err := m.upsertElement(ctx, "fact", factOwnerKey(fact), elementInput{ + Name: factNodeName(fact), + Kind: factNodeKind(fact), + Description: factNodeDescription(fact), + Technology: factTechnology(fact), + Repo: repoIdentity(m.repo), + Branch: nullStringValue(m.repo.Branch), + FilePath: fact.FilePath, + Language: languageForFile(fact.FilePath, symbols), + Tags: m.tagPlan.tagsFor("fact", factOwnerKey(fact)), + }) + if err != nil { + return err + } + viewID := fileViews[file] + var subjectID int64 + if fact.SubjectKind == "symbol" { + subjectID = symbolIDByStable[fact.SubjectStableKey] + } + placement := nextFactPlacement(viewID, subjectID, subjectFactCounts[subjectID], symbolViews, symbolPositions, occupied, i) + if subjectID != 0 { + subjectFactCounts[subjectID]++ + } + if err := m.upsertPlacement(ctx, viewID, elem, placement.Point.X, 
placement.Point.Y); err != nil { + return err + } + markOccupied(occupied, viewID, placement.Point) + if fact.SubjectKind == "symbol" { + if symID := subjectID; symID != 0 && symbolElements[symID] != 0 && symbolViews[symID] == viewID { + ownerKey := factOwnerKey(fact) + ":subject" + label := firstNonEmpty(fact.Relationship, "declares") + if err := m.upsertConnectorDetailed(ctx, "fact-reference", ownerKey, viewID, symbolElements[symID], elem, label, label, ""); err != nil { + return err + } + } + } + } + summaryFacts := append([]Fact(nil), summaryFactsByFile[file]...) + if limit < len(items) { + summaryFacts = append(summaryFacts, items[limit:]...) + } + if len(summaryFacts) > 0 { + if err := m.materializeFactSummaries(ctx, file, fileViews[file], summaryFacts, occupied); err != nil { + return err + } + } + } + return nil +} + +func factNodeLimitForFile(thresholds Thresholds, visibility VisibilityConfig, tier int) int { + limit := effectiveMaxElementsPerView(thresholds, visibility, tier) / 3 + if limit < 3 { + return 3 + } + return limit +} + +func (m *materializer) materializeFactSummaries(ctx context.Context, file string, viewID int64, facts []Fact, occupied map[int64]map[string]struct{}) error { + byType := map[string][]Fact{} + for _, fact := range facts { + byType[fact.Type] = append(byType[fact.Type], fact) + } + i := 0 + for _, factType := range sortedKeysFromFactSummaryGroups(byType) { + items := byType[factType] + keys := make([]string, 0, len(items)) + for _, fact := range items { + keys = append(keys, factOwnerKey(fact)) + } + sort.Strings(keys) + ownerKey := "fact-summary:" + file + ":" + factType + ":" + stableHash(keys) + elem, err := m.upsertElement(ctx, "fact-summary", ownerKey, elementInput{ + Name: fmt.Sprintf("%d %s", len(items), factSummaryLabel(factType, len(items))), + Kind: "summary", + Description: fmt.Sprintf("%d omitted %s facts in %s", len(items), factType, file), + Technology: "Runtime", + Repo: repoIdentity(m.repo), + Branch: nullStringValue(m.repo.Branch), + FilePath: file, + Tags: summaryTagsForFacts(items), + }) + if err != nil { + return err + } + point := nextOpenGridPoint(viewID, occupied, 1000+i) + if err := m.upsertPlacement(ctx, viewID, elem, point.X, point.Y); err != nil { + return err + } + markOccupied(occupied, viewID, point) + i++ + } + return nil +} + +func nextFactPlacement(viewID, subjectID int64, subjectIndex int, symbolViews map[int64]int64, symbolPositions map[int64]layoutPoint, occupied map[int64]map[string]struct{}, fallbackIndex int) factPlacement { + if subjectID == 0 || symbolViews[subjectID] != viewID { + point := nextOpenGridPoint(viewID, occupied, fallbackIndex) + return factPlacement{Point: point, SourceHandle: "right", TargetHandle: "left"} + } + origin := symbolPositions[subjectID] + candidates := factPlacementCandidates(origin, subjectIndex) + for _, candidate := range candidates { + if !isOccupied(occupied, viewID, candidate.Point) { + return candidate + } + } + point := nextOpenGridPoint(viewID, occupied, fallbackIndex) + return factPlacement{Point: point, SourceHandle: "right", TargetHandle: "left"} +} + +func factPlacementCandidates(origin layoutPoint, subjectIndex int) []factPlacement { + ring := subjectIndex/8 + 1 + spread := float64((subjectIndex%3)-1) * 90 + dx := float64(ring) * watchLayoutGapX + dy := float64(ring) * watchLayoutGapY + return []factPlacement{ + {Point: layoutPoint{X: origin.X + dx, Y: origin.Y + spread}, SourceHandle: "right", TargetHandle: "left"}, + {Point: layoutPoint{X: origin.X, Y: origin.Y + dy + 
spread}, SourceHandle: "bottom", TargetHandle: "top"}, + {Point: layoutPoint{X: origin.X - dx, Y: origin.Y + spread}, SourceHandle: "left", TargetHandle: "right"}, + {Point: layoutPoint{X: origin.X, Y: origin.Y - dy + spread}, SourceHandle: "top", TargetHandle: "bottom"}, + {Point: layoutPoint{X: origin.X + dx, Y: origin.Y + dy}, SourceHandle: "right", TargetHandle: "left"}, + {Point: layoutPoint{X: origin.X - dx, Y: origin.Y + dy}, SourceHandle: "left", TargetHandle: "right"}, + {Point: layoutPoint{X: origin.X + dx, Y: origin.Y - dy}, SourceHandle: "right", TargetHandle: "left"}, + {Point: layoutPoint{X: origin.X - dx, Y: origin.Y - dy}, SourceHandle: "left", TargetHandle: "right"}, + } +} + +func nextOpenGridPoint(viewID int64, occupied map[int64]map[string]struct{}, startIndex int) layoutPoint { + for i := startIndex; ; i++ { + x, y := gridPosition(i) + point := layoutPoint{X: x, Y: y} + if !isOccupied(occupied, viewID, point) { + return point + } + } +} + +func markOccupied(occupied map[int64]map[string]struct{}, viewID int64, point layoutPoint) { + if occupied[viewID] == nil { + occupied[viewID] = map[string]struct{}{} + } + occupied[viewID][layoutPointKey(point)] = struct{}{} +} + +func isOccupied(occupied map[int64]map[string]struct{}, viewID int64, point layoutPoint) bool { + if occupied[viewID] == nil { + return false + } + _, ok := occupied[viewID][layoutPointKey(point)] + return ok +} + +func layoutPointKey(point layoutPoint) string { + return fmt.Sprintf("%.0f:%.0f", point.X, point.Y) +} + +func sortedKeysFromFactSummaryGroups(groups map[string][]Fact) []string { + keys := make([]string, 0, len(groups)) + for key := range groups { + keys = append(keys, key) + } + sort.Strings(keys) + return keys +} + +func factSummaryLabel(factType string, count int) string { + label := factType + switch factType { + case "http.route", "frontend.route": + label = "routes" + case "orm.query": + label = "data access facts" + } + if count == 1 { + return strings.TrimSuffix(label, "s") + } + return label +} + +func summaryTagsForFacts(facts []Fact) []string { + set := map[string]struct{}{} + for _, fact := range facts { + for _, tag := range fact.Tags { + tag = strings.TrimSpace(tag) + if strings.HasPrefix(tag, "role:") { + set[tag] = struct{}{} + } + } + } + return sortedKeys(set) +} + +func factNodeName(fact Fact) string { + return firstNonEmpty(fact.Name, fact.ObjectName, fact.Type) +} + +func factNodeKind(fact Fact) string { + switch fact.Type { + case "http.route", "frontend.route": + return "route" + case "orm.query": + return "data-access" + case "runtime.component": + attrs := map[string]string{} + _ = json.Unmarshal([]byte(fact.AttributesJSON), &attrs) + if kind := strings.TrimSpace(attrs["kind"]); kind != "" { + return kind + } + return "service" + case "runtime.connection": + return "connection" + case "storage.volume": + return "volume" + case "runtime.endpoint": + return "endpoint" + default: + return "fact" + } +} + +func factTechnology(fact Fact) string { + attrs := map[string]string{} + _ = json.Unmarshal([]byte(fact.AttributesJSON), &attrs) + if framework := strings.TrimSpace(attrs["framework"]); framework != "" { + return framework + } + if orm := strings.TrimSpace(attrs["orm"]); orm != "" { + return orm + } + if technology := strings.TrimSpace(attrs["technology"]); technology != "" { + return technology + } + return "Runtime" +} + +func extractTechnologyFromTags(currentTechnology string, tags []string) (string, []string) { + var filtered []string + var extracted string + for _, 
tag := range tags { + if extracted == "" && strings.HasPrefix(tag, "technology:") { + extracted = strings.TrimPrefix(tag, "technology:") + continue + } + filtered = append(filtered, tag) + } + if extracted == "" { + return currentTechnology, filtered + } + if currentTechnology == "" || currentTechnology == "Runtime" || currentTechnology == "Source" { + return extracted, filtered + } + return currentTechnology, filtered +} + +func factNodeDescription(fact Fact) string { + parts := []string{fact.Type} + if fact.Relationship != "" { + parts = append(parts, fact.Relationship) + } + if fact.FilePath != "" && fact.StartLine > 0 { + parts = append(parts, fmt.Sprintf("%s:%d", fact.FilePath, fact.StartLine)) + } + return strings.Join(parts, " - ") +} + +func (m *materializer) materializeConnectors(ctx context.Context, refs []Reference, symbols map[int64]Symbol, folderElements map[string]int64, fileElements map[string]int64, symbolElements map[int64]int64, symbolViews map[int64]int64, repoView int64) error { + filePairs := map[string]filePairReference{} + symbolConnectorCount := map[int64]int{} + for _, ref := range refs { + source := symbols[ref.SourceSymbolID] + target := symbols[ref.TargetSymbolID] + if source.FilePath != "" && target.FilePath != "" && source.FilePath != target.FilePath { + key := source.FilePath + "->" + target.FilePath + pair := filePairs[key] + if pair.Count == 0 { + pair = filePairReference{Key: key, Ref: ref} + } + pair.Count++ + filePairs[key] = pair + continue + } + viewID := symbolViews[ref.SourceSymbolID] + if viewID == 0 || viewID != symbolViews[ref.TargetSymbolID] || symbolConnectorCount[viewID] >= m.thresholds.MaxConnectorsPerView { + continue + } + sourceKey := symbolOwnerKey(source, m.identityKeys) + targetKey := symbolOwnerKey(target, m.identityKeys) + ownerKey := fmt.Sprintf("symbol:%s:%s:%s", sourceKey, targetKey, ref.Kind) + if m.contextPolicyHidden("reference", ownerKey) { + continue + } + if err := m.upsertConnector(ctx, "reference", ownerKey, viewID, symbolElements[ref.SourceSymbolID], symbolElements[ref.TargetSymbolID], "calls"); err != nil { + return err + } + symbolConnectorCount[viewID]++ + } + + fileGroups := map[string][]filePairReference{} + for _, key := range sortedKeys(filePairs) { + pair := filePairs[key] + source := symbols[pair.Ref.SourceSymbolID] + target := symbols[pair.Ref.TargetSymbolID] + sourceGroup := connectorGroupFolder(source.FilePath) + targetGroup := connectorGroupFolder(target.FilePath) + if sourceGroup == "" || targetGroup == "" || sourceGroup == targetGroup || folderElements[sourceGroup] == 0 || folderElements[targetGroup] == 0 { + fileGroups["file:"+key] = append(fileGroups["file:"+key], pair) + continue + } + groupKey := "folder:" + sourceGroup + "->" + targetGroup + fileGroups[groupKey] = append(fileGroups[groupKey], pair) + } + + fileConnectorCount := 0 + for _, groupKey := range sortedFileGroupKeys(fileGroups) { + if fileConnectorCount >= m.thresholds.MaxConnectorsPerView { + break + } + group := fileGroups[groupKey] + if len(group) == 0 { + continue + } + rawReferenceCount := filePairReferenceCount(group) + if strings.HasPrefix(groupKey, "folder:") && rawReferenceCount > m.thresholds.MaxExpandedConnectorsPerGroup { + if m.contextPolicyHidden("folder-reference", groupKey) { + continue + } + first := group[0].Ref + source := symbols[first.SourceSymbolID] + target := symbols[first.TargetSymbolID] + sourceGroup := connectorGroupFolder(source.FilePath) + targetGroup := connectorGroupFolder(target.FilePath) + if err := 
m.upsertConnector(ctx, "folder-reference", groupKey, repoView, folderElements[sourceGroup], folderElements[targetGroup], fmt.Sprintf("%d references", rawReferenceCount)); err != nil { + return err + } + fileConnectorCount++ + continue + } + for _, item := range group { + if fileConnectorCount >= m.thresholds.MaxConnectorsPerView { + break + } + ref := item.Ref + source := symbols[ref.SourceSymbolID] + target := symbols[ref.TargetSymbolID] + if fileElements[source.FilePath] == 0 || fileElements[target.FilePath] == 0 { + continue + } + ownerKey := "file:" + item.Key + if m.contextPolicyHidden("file-reference", ownerKey) { + continue + } + if err := m.upsertConnector(ctx, "file-reference", ownerKey, repoView, fileElements[source.FilePath], fileElements[target.FilePath], "references"); err != nil { + return err + } + fileConnectorCount++ + } + } + return nil +} + +func (m *materializer) contextPolicyHidden(ownerType, ownerKey string) bool { + _, hidden := m.contextPolicies.Hide[ownerMapKey(ownerType, ownerKey)] + return hidden +} + +func filePairReferenceCount(group []filePairReference) int { + count := 0 + for _, item := range group { + count += item.Count + } + return count +} + +func sortedFileGroupKeys(groups map[string][]filePairReference) []string { + keys := sortedKeys(groups) + sort.SliceStable(keys, func(i, j int) bool { + left := keys[i] + right := keys[j] + leftCross := strings.HasPrefix(left, "folder:") + rightCross := strings.HasPrefix(right, "folder:") + if leftCross != rightCross { + return leftCross + } + leftCount := filePairReferenceCount(groups[left]) + rightCount := filePairReferenceCount(groups[right]) + if leftCount != rightCount { + return leftCount > rightCount + } + return left < right + }) + return keys +} + +func connectorGroupFolder(filePath string) string { + dir := path.Dir(filePath) + if dir == "." 
|| dir == "/" || dir == "" { + return "" + } + if before, _, ok := strings.Cut(dir, "/"); ok { + return before + } + return dir +} + +func (m *materializer) upsertConnector(ctx context.Context, ownerType, ownerKey string, viewID, sourceElementID, targetElementID int64, label string) error { + return m.upsertConnectorDetailed(ctx, ownerType, ownerKey, viewID, sourceElementID, targetElementID, label, label, "") +} + +func (m *materializer) upsertConnectorDetailed(ctx context.Context, ownerType, ownerKey string, viewID, sourceElementID, targetElementID int64, label, relationship, description string) error { + return m.upsertConnectorDetailedWithDirection(ctx, ownerType, ownerKey, viewID, sourceElementID, targetElementID, label, relationship, "forward", description) +} + +func (m *materializer) upsertConnectorDetailedWithDirection(ctx context.Context, ownerType, ownerKey string, viewID, sourceElementID, targetElementID int64, label, relationship, direction, description string) error { + if sourceElementID == 0 || targetElementID == 0 || sourceElementID == targetElementID { + return nil + } + if strings.TrimSpace(relationship) == "" { + relationship = label + } + direction = normalizedArchitectureConnectorDirection(direction) + sourceHandle, targetHandle := "", "" + if ownerType == "fact-reference" { + sourceHandle = "right" + targetHandle = "left" + } + if state, ok, err := m.store.MappingState(ctx, m.repo.ID, ownerType, ownerKey, "connector"); err != nil { + return err + } else if ok && connectorExists(ctx, m.store.db, state.ResourceID) { + dirty, err := m.mappingDirty(ctx, ownerType, ownerKey, "connector", state) + if err != nil { + return err + } + if dirty { + m.stats.ConnectorsPreserved++ + return m.saveMapping(ctx, ownerType, ownerKey, "connector", state.ResourceID) + } + _, err = m.store.db.ExecContext(ctx, ` + UPDATE connectors + SET view_id = ?, source_element_id = ?, target_element_id = ?, label = ?, description = ?, relationship = ?, direction = ?, style = 'solid', source_handle = ?, target_handle = ?, updated_at = ? 
+ WHERE id = ?`, viewID, sourceElementID, targetElementID, label, nullString(description), relationship, direction, nullString(sourceHandle), nullString(targetHandle), nowString(), state.ResourceID) + if err != nil { + return err + } + if err := m.saveMappingWithCurrentHash(ctx, ownerType, ownerKey, "connector", state.ResourceID); err != nil { + return err + } + m.stats.ConnectorsUpdated++ + return nil + } + now := nowString() + res, err := m.store.db.ExecContext(ctx, ` + INSERT INTO connectors(view_id, source_element_id, target_element_id, label, description, relationship, direction, style, source_handle, target_handle, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, 'solid', ?, ?, ?, ?)`, viewID, sourceElementID, targetElementID, label, nullString(description), relationship, direction, nullString(sourceHandle), nullString(targetHandle), now, now) + if err != nil { + return err + } + id, err := res.LastInsertId() + if err != nil { + return err + } + if err := m.saveMappingWithCurrentHash(ctx, ownerType, ownerKey, "connector", id); err != nil { + return err + } + m.stats.ConnectorsCreated++ + return nil +} + +func (m *materializer) saveMapping(ctx context.Context, ownerType, ownerKey, resourceType string, resourceID int64) error { + return m.store.SaveMappingAt(ctx, m.repo.ID, ownerType, ownerKey, resourceType, resourceID, m.runMarker) +} + +func (m *materializer) saveMappingWithCurrentHash(ctx context.Context, ownerType, ownerKey, resourceType string, resourceID int64) error { + resourceHash, exists, err := m.store.WatchResourceHash(ctx, resourceType, resourceID) + if err != nil { + return err + } + if !exists { + return m.saveMapping(ctx, ownerType, ownerKey, resourceType, resourceID) + } + return m.store.SaveMappingHashAt(ctx, m.repo.ID, ownerType, ownerKey, resourceType, resourceID, resourceHash, m.runMarker) +} + +func (m *materializer) mappingDirty(ctx context.Context, ownerType, ownerKey, resourceType string, state materializationState) (bool, error) { + if state.Dirty { + return true, nil + } + if state.LastWatchHash == nil { + return false, nil + } + currentHash, exists, err := m.store.WatchResourceHash(ctx, resourceType, state.ResourceID) + if err != nil { + return false, err + } + if !exists || currentHash == *state.LastWatchHash { + return false, nil + } + if err := m.store.MarkMappingDirty(ctx, m.repo.ID, ownerType, ownerKey, resourceType, state.ResourceID); err != nil { + return false, err + } + return true, nil +} + +func (m *materializer) pruneStaleResources(ctx context.Context) error { + preserved, err := m.store.PruneStaleMaterializedResources(ctx, m.repo.ID, m.runMarker) + if err != nil { + return err + } + m.stats.DeletesPreserved += preserved + return nil +} + +func elementExists(ctx context.Context, db *sql.DB, id int64) bool { + return rowExists(ctx, db, `SELECT 1 FROM elements WHERE id = ?`, id) +} + +func viewExists(ctx context.Context, db *sql.DB, id int64) bool { + return rowExists(ctx, db, `SELECT 1 FROM views WHERE id = ?`, id) +} + +func connectorExists(ctx context.Context, db *sql.DB, id int64) bool { + return rowExists(ctx, db, `SELECT 1 FROM connectors WHERE id = ?`, id) +} + +func rowExists(ctx context.Context, db *sql.DB, query string, id int64) bool { + var one int + err := db.QueryRowContext(ctx, query, id).Scan(&one) + return err == nil +} + +func filesForSymbols(symbols map[int64]Symbol) map[string]struct{} { + out := map[string]struct{}{} + for _, sym := range symbols { + if sym.FilePath != "" { + out[sym.FilePath] = struct{}{} + } + } + 
return out +} + +func symbolOwnerKey(sym Symbol, identityKeys map[string]string) string { + if identityKeys != nil { + if key := strings.TrimSpace(identityKeys[sym.StableKey]); key != "" { + return key + } + } + return sym.StableKey +} + +func folderSet(files map[string]struct{}) []string { + set := map[string]struct{}{} + for file := range files { + dir := path.Dir(file) + for dir != "." && dir != "/" { + set[dir] = struct{}{} + next := path.Dir(dir) + if next == dir { + break + } + dir = next + } + } + out := sortedKeys(set) + sort.SliceStable(out, func(i, j int) bool { + di := strings.Count(out[i], "/") + dj := strings.Count(out[j], "/") + if di == dj { + return out[i] < out[j] + } + return di < dj + }) + return out +} + +func dominantLanguage(symbols map[int64]Symbol) string { + counts := map[string]int{} + for _, sym := range symbols { + language := languageFromStableKey(sym.StableKey) + if language != "" { + counts[language]++ + } + } + best := "source" + bestCount := 0 + for language, count := range counts { + if count > bestCount || (count == bestCount && language < best) { + best = language + bestCount = count + } + } + return best +} + +func languageForFile(file string, symbols map[int64]Symbol) string { + counts := map[string]int{} + for _, sym := range symbols { + if sym.FilePath != file { + continue + } + language := languageFromStableKey(sym.StableKey) + if language != "" { + counts[language]++ + } + } + if len(counts) == 0 { + return "" + } + best := dominantLanguage(symbols) + bestCount := 0 + for language, count := range counts { + if count > bestCount || (count == bestCount && language < best) { + best = language + bestCount = count + } + } + return best +} + +func languageFromStableKey(stableKey string) string { + if idx := strings.Index(stableKey, ":"); idx > 0 { + return stableKey[:idx] + } + return "source" +} + +func technologyLabel(language string) string { + switch language { + case "go": + return "Go" + case "typescript": + return "TypeScript" + case "javascript": + return "JavaScript" + case "python": + return "Python" + case "java": + return "Java" + case "cpp": + return "C++" + case "c": + return "C" + default: + return "" + } +} + +func technologyLinksForLanguage(language string) []materializedTechnologyLink { + label := technologyLabel(language) + slug := technologyCatalogSlug(language) + if slug == "" { + if label == "" { + return []materializedTechnologyLink{} + } + return []materializedTechnologyLink{{Type: "custom", Label: label}} + } + return []materializedTechnologyLink{{ + Type: "catalog", + Slug: slug, + Label: label, + IsPrimaryIcon: true, + }} +} + +func technologyLinksForElement(technology, language string) []materializedTechnologyLink { + if slug := technologyCatalogSlugForLabel(technology); slug != "" { + label := strings.TrimSpace(technology) + return []materializedTechnologyLink{{ + Type: "catalog", + Slug: slug, + Label: label, + IsPrimaryIcon: true, + }} + } + return technologyLinksForLanguage(language) +} + +func technologyCatalogSlugForLabel(label string) string { + switch strings.ToLower(strings.TrimSpace(label)) { + case "architecture": + return "architecture" + case "structural": + return "structural" + case "container": + return "docker" + default: + return "" + } +} + +func technologyCatalogSlug(language string) string { + switch language { + case "go": + return "golang" + case "typescript": + return "typescript" + case "javascript": + return "javascript" + case "python": + return "python" + case "java": + return "java" + case "cpp": + 
return "c-plusplus" + case "c": + return "c" + case "json": + return "json-javascript-object-notation" + default: + return "" + } +} + +func symbolsByFile(symbols map[int64]Symbol) map[string][]Symbol { + out := map[string][]Symbol{} + for _, sym := range sortedSymbols(symbols) { + out[sym.FilePath] = append(out[sym.FilePath], sym) + } + return out +} + +func sortedSymbols(symbols map[int64]Symbol) []Symbol { + out := make([]Symbol, 0, len(symbols)) + for _, sym := range symbols { + out = append(out, sym) + } + sort.Slice(out, func(i, j int) bool { + if out[i].FilePath == out[j].FilePath { + if out[i].StartLine == out[j].StartLine { + return out[i].StableKey < out[j].StableKey + } + return out[i].StartLine < out[j].StartLine + } + return out[i].FilePath < out[j].FilePath + }) + return out +} + +func sortedKeys[T any](m map[string]T) []string { + out := make([]string, 0, len(m)) + for key := range m { + out = append(out, key) + } + sort.Strings(out) + return out +} + +func chunkSymbols(symbols []Symbol, size int) [][]Symbol { + if size <= 0 || len(symbols) <= size { + return [][]Symbol{symbols} + } + var chunks [][]Symbol + for start := 0; start < len(symbols); start += size { + end := min(start+size, len(symbols)) + chunks = append(chunks, symbols[start:end]) + } + return chunks +} + +func gridPosition(index int) (float64, float64) { + col := index % 5 + row := index / 5 + return float64(col * 260), float64(row * 160) +} + +func nullStringValue(value sql.NullString) string { + if value.Valid { + return value.String + } + return "" +} + +func repoIdentity(repo Repository) string { + if repo.RemoteURL.Valid && strings.TrimSpace(repo.RemoteURL.String) != "" { + return repo.RemoteURL.String + } + return repo.RepoRoot +} + +func representationHash(filtered filterResult, req RepresentRequest) string { + parts := []string{filtered.RawGraphHash, filtered.SettingsHash, stableHash(req)} + for _, file := range sortedKeys(filtered.ChangedFiles) { + parts = append(parts, "f:"+file) + } + for _, sym := range sortedSymbols(filtered.VisibleSymbols) { + parts = append(parts, "s:"+sym.StableKey) + } + facts := append([]Fact(nil), filtered.VisibleFacts...) + sort.SliceStable(facts, func(i, j int) bool { + if facts[i].Enricher == facts[j].Enricher { + return facts[i].StableKey < facts[j].StableKey + } + return facts[i].Enricher < facts[j].Enricher + }) + for _, fact := range facts { + parts = append(parts, "fact:"+fact.Enricher+":"+fact.StableKey+":"+fact.FactHash) + } + var expansionKeys []string + for key := range filtered.ContextExpansions.Tiers { + expansionKeys = append(expansionKeys, key) + } + sort.Strings(expansionKeys) + for _, key := range expansionKeys { + parts = append(parts, fmt.Sprintf("x:%s:%d", key, filtered.ContextExpansions.Tiers[key])) + } + refs := append([]Reference(nil), filtered.VisibleReferences...) 
+ sort.Slice(refs, func(i, j int) bool { + leftSource := filtered.VisibleSymbols[refs[i].SourceSymbolID].StableKey + rightSource := filtered.VisibleSymbols[refs[j].SourceSymbolID].StableKey + leftTarget := filtered.VisibleSymbols[refs[i].TargetSymbolID].StableKey + rightTarget := filtered.VisibleSymbols[refs[j].TargetSymbolID].StableKey + if leftSource == rightSource { + if leftTarget == rightTarget { + return refs[i].EvidenceHash < refs[j].EvidenceHash + } + return leftTarget < rightTarget + } + return leftSource < rightSource + }) + for _, ref := range refs { + source := filtered.VisibleSymbols[ref.SourceSymbolID].StableKey + target := filtered.VisibleSymbols[ref.TargetSymbolID].StableKey + parts = append(parts, fmt.Sprintf("r:%s:%s:%s:%s", source, target, ref.Kind, ref.EvidenceHash)) + } + return stableHash(parts) +} diff --git a/internal/watch/runner.go b/internal/watch/runner.go new file mode 100644 index 0000000..0b35515 --- /dev/null +++ b/internal/watch/runner.go @@ -0,0 +1,505 @@ +package watch + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" + + tldgit "github.com/mertcikla/tld/internal/git" + "github.com/mertcikla/tld/internal/ignore" +) + +type RunnerOptions struct { + Path string + Rescan bool + Verbose bool + PollInterval time.Duration + Debounce time.Duration + HeartbeatInterval time.Duration + SummaryInterval time.Duration + Embedding EmbeddingConfig + Settings Settings + Progress ProgressSink + Events chan<- Event + Ready chan<- RunnerResult +} + +type RunnerResult struct { + Repository Repository + InitialScan ScanResult + InitialRep RepresentResult + GitStatus GitStatus + Token string +} + +type Runner struct { + Store *Store + Scanner *Scanner + Representer *Representer +} + +func NewRunner(store *Store) *Runner { + return &Runner{ + Store: store, + Scanner: NewScanner(store), + Representer: NewRepresenter(store), + } +} + +func (r *Runner) Run(ctx context.Context, opts RunnerOptions) (RunnerResult, error) { + if r == nil || r.Store == nil { + return RunnerResult{}, fmt.Errorf("watch runner requires a store") + } + if r.Scanner == nil { + r.Scanner = NewScanner(r.Store) + } + r.Scanner.Progress = opts.Progress + if r.Representer == nil { + r.Representer = NewRepresenter(r.Store) + } + if opts.Path == "" { + opts.Path = "." 
+ } + settings := NormalizeSettings(opts.Settings) + if opts.PollInterval <= 0 { + opts.PollInterval = settings.PollInterval + } + if opts.Debounce <= 0 { + opts.Debounce = settings.Debounce + } + if opts.HeartbeatInterval <= 0 { + opts.HeartbeatInterval = 2 * time.Second + } + if opts.SummaryInterval <= 0 { + opts.SummaryInterval = time.Minute + } + absPath, err := filepath.Abs(opts.Path) + if err != nil { + return RunnerResult{}, err + } + repoRoot, err := tldgit.RepoRoot(absPath) + if err != nil { + return RunnerResult{}, fmt.Errorf("%s is not inside a git repository: %w", opts.Path, err) + } + + gitStatus, _ := gitStatusSnapshot(repoRoot) + emit(opts.Events, Event{Type: "scan.started", At: nowString(), Phase: "scan", WatcherMode: settings.Watcher, Languages: settings.Languages}) + once, err := r.RunOnce(ctx, OneShotOptions{Path: repoRoot, Rescan: opts.Rescan, Embedding: opts.Embedding, Settings: settings, Progress: opts.Progress}) + if err != nil { + return RunnerResult{}, err + } + scan := once.Scan + emit(opts.Events, Event{Type: "scan.completed", RepositoryID: scan.RepositoryID, At: nowString(), Data: scan, Phase: "scan", WatcherMode: settings.Watcher, Languages: settings.Languages, Warnings: scan.Warnings}) + emit(opts.Events, Event{Type: "representation.started", RepositoryID: scan.RepositoryID, At: nowString(), Phase: "represent", WatcherMode: settings.Watcher, Languages: settings.Languages, Warnings: scan.Warnings}) + repo := once.Repository + token := randomToken() + lock, err := r.Store.AcquireLock(ctx, repo.ID, os.Getpid(), token, LockHeartbeatTimeout) + if err != nil { + return RunnerResult{}, err + } + _ = lock + sourceWatcher := newSourceWatcher(ctx, repoRoot, settings, r.Scanner.EffectiveRules) + watcherMode := sourceWatcher.Mode + warnings := append([]string{}, sourceWatcher.Warnings...) 
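+ // Announce the watch session; the deferred block a few lines below releases the repository lock and emits the matching lock.disabled/watch.stopped events on exit.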
+ emit(opts.Events, Event{Type: "watch.started", RepositoryID: repo.ID, At: nowString(), Data: repo.JSON(), Phase: "watch", WatcherMode: watcherMode, Languages: settings.Languages, Warnings: warnings}) + emit(opts.Events, Event{Type: "lock.enabled", RepositoryID: repo.ID, At: nowString()}) + defer func() { + _ = r.Store.ReleaseLock(context.Background(), repo.ID, token) + emit(opts.Events, Event{Type: "lock.disabled", RepositoryID: repo.ID, At: nowString()}) + emit(opts.Events, Event{Type: "watch.stopped", RepositoryID: repo.ID, At: nowString()}) + }() + + rep := once.Representation + emit(opts.Events, Event{Type: "representation.updated", RepositoryID: repo.ID, At: nowString(), Data: rep, Phase: "represent", WatcherMode: watcherMode, Languages: settings.Languages, Warnings: warnings}) + _, _ = r.Store.ApplyGitTags(ctx, repo.ID, gitStatus) + if gitStatus.HeadCommit != "" { + _ = r.createVersionForHead(ctx, repo.ID, gitStatus, rep.RepresentationHash, false) + } + + result := RunnerResult{Repository: repo, InitialScan: scan, InitialRep: rep, GitStatus: gitStatus, Token: token} + if opts.Ready != nil { + select { + case opts.Ready <- result: + default: + } + } + lastSourceSnapshot := sourceFileSnapshot(repoRoot, settings, r.Scanner.Rules) + lastFingerprint := sourceFileFingerprint(lastSourceSnapshot) + lastHead := gitStatus.HeadCommit + lastGitFingerprint := gitStatusFingerprint(gitStatus) + heartbeat := time.NewTicker(opts.HeartbeatInterval) + poll := time.NewTicker(opts.PollInterval) + summary := time.NewTicker(opts.SummaryInterval) + defer heartbeat.Stop() + defer poll.Stop() + defer summary.Stop() + totalChangesProcessed := 0 + intervalChangesProcessed := 0 + + for { + select { + case <-ctx.Done(): + return result, nil + case <-summary.C: + emit(opts.Events, Event{ + Type: "watch.changeCounter", + RepositoryID: repo.ID, + At: nowString(), + WatcherMode: watcherMode, + Languages: settings.Languages, + Data: ChangeCounter{ + TotalChangesProcessed: totalChangesProcessed, + IntervalChangesProcessed: intervalChangesProcessed, + }, + }) + intervalChangesProcessed = 0 + case <-heartbeat.C: + if _, err := r.Store.HeartbeatLock(ctx, repo.ID, token); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return result, nil + } + return result, err + } + status, err := r.Store.LockStatus(ctx, repo.ID, token) + if errors.Is(err, sql.ErrNoRows) { + return result, nil + } + if err == nil && status == "stopping" { + return result, nil + } + if err == nil && status == "paused" { + emit(opts.Events, Event{Type: "watch.paused", RepositoryID: repo.ID, At: nowString()}) + } + emit(opts.Events, Event{Type: "watch.heartbeat", RepositoryID: repo.ID, At: nowString(), Phase: "watch", WatcherMode: watcherMode, Languages: settings.Languages, Warnings: warnings}) + case _, ok := <-sourceWatcher.Events: + if ok { + poll.Reset(time.Millisecond) + } + case <-poll.C: + status, err := r.Store.LockStatus(ctx, repo.ID, token) + if errors.Is(err, sql.ErrNoRows) { + return result, nil + } + if err == nil && status == "paused" { + continue + } + if err == nil && status == "stopping" { + return result, nil + } + nextSourceSnapshot := sourceFileSnapshot(repoRoot, settings, r.Scanner.Rules) + nextFingerprint := sourceFileFingerprint(nextSourceSnapshot) + nextGit, _ := gitStatusSnapshot(repoRoot) + nextGitFingerprint := gitStatusFingerprint(nextGit) + if nextFingerprint == lastFingerprint && nextGit.HeadCommit == lastHead && nextGitFingerprint == lastGitFingerprint { + continue + } + time.Sleep(opts.Debounce) + stableSourceSnapshot 
:= sourceFileSnapshot(repoRoot, settings, r.Scanner.Rules) + sourceChanged := sourceFileFingerprint(stableSourceSnapshot) != lastFingerprint + nextGit, _ = gitStatusSnapshot(repoRoot) + nextGitFingerprint = gitStatusFingerprint(nextGit) + sourceChanges := diffSourceFileSnapshots(lastSourceSnapshot, stableSourceSnapshot) + emit(opts.Events, Event{Type: "scan.started", RepositoryID: repo.ID, At: nowString(), Phase: "scan", WatcherMode: watcherMode, Languages: settings.Languages, ChangedFiles: len(sourceChanges), Warnings: warnings}) + once, err := r.RunOnce(ctx, OneShotOptions{Path: repoRoot, Embedding: opts.Embedding, Settings: settings, Progress: opts.Progress}) + if err != nil { + emit(opts.Events, Event{Type: "watch.error", RepositoryID: repo.ID, At: nowString(), Message: err.Error()}) + continue + } + scan := once.Scan + eventWarnings := append(append([]string{}, warnings...), scan.Warnings...) + emit(opts.Events, Event{Type: "scan.completed", RepositoryID: repo.ID, At: nowString(), Data: scan, Phase: "scan", WatcherMode: watcherMode, Languages: settings.Languages, ChangedFiles: len(sourceChanges), Warnings: eventWarnings}) + emit(opts.Events, Event{Type: "representation.started", RepositoryID: repo.ID, At: nowString(), Phase: "represent", WatcherMode: watcherMode, Languages: settings.Languages, ChangedFiles: len(sourceChanges), Warnings: eventWarnings}) + rep := once.Representation + emit(opts.Events, Event{Type: "representation.updated", RepositoryID: repo.ID, At: nowString(), Data: rep, Phase: "represent", WatcherMode: watcherMode, Languages: settings.Languages, ChangedFiles: len(sourceChanges), Warnings: eventWarnings}) + tagResult, _ := r.Store.ApplyGitTags(ctx, repo.ID, nextGit) + diffs, diffErr := r.Store.BuildWatchDiffs(ctx, repo.ID, rep.RepresentationHash) + if diffErr != nil { + emit(opts.Events, Event{Type: "watch.error", RepositoryID: repo.ID, At: nowString(), Message: diffErr.Error()}) + } + for _, change := range sourceChanges { + emit(opts.Events, Event{ + Type: "source.changed", + RepositoryID: repo.ID, + At: nowString(), + Phase: "watch", + WatcherMode: watcherMode, + Languages: settings.Languages, + ChangedFiles: len(sourceChanges), + Warnings: eventWarnings, + Data: SourceFileChangeResult{ + Change: change, + RepresentationChanged: sourceChangeRepresentationChanged(change, diffs), + Representation: rep, + GitTags: tagResult, + }, + }) + } + processed := len(sourceChanges) + if processed == 0 { + processed = 1 + } + totalChangesProcessed += processed + intervalChangesProcessed += processed + result.InitialRep = rep + emit(opts.Events, Event{Type: "git.statusChanged", RepositoryID: repo.ID, At: nowString(), Data: nextGit}) + if nextGit.HeadCommit != "" && nextGit.HeadCommit != lastHead { + if err := r.createVersionForHead(ctx, repo.ID, nextGit, rep.RepresentationHash, !sourceChanged); err != nil { + emit(opts.Events, Event{Type: "watch.error", RepositoryID: repo.ID, At: nowString(), Message: err.Error()}) + } else { + emit(opts.Events, Event{Type: "version.created", RepositoryID: repo.ID, At: nowString(), Data: map[string]string{"commit_hash": nextGit.HeadCommit}}) + } + lastHead = nextGit.HeadCommit + } + lastSourceSnapshot = stableSourceSnapshot + lastFingerprint = sourceFileFingerprint(stableSourceSnapshot) + lastGitFingerprint = nextGitFingerprint + } + } +} + +func sourceChangeRepresentationChanged(change SourceFileChange, diffs []RepresentationDiff) bool { + path := strings.TrimSpace(filepathToSlash(change.Path)) + if path == "" { + return false + } + for _, diff 
:= range diffs { + if diff.OwnerType == "repository" { + continue + } + for _, candidate := range representationDiffSourcePaths(diff) { + if candidate == path || strings.HasPrefix(candidate, path+"/") || strings.HasPrefix(path, candidate+"/") { + return true + } + } + } + return false +} + +func representationDiffSourcePaths(diff RepresentationDiff) []string { + seen := map[string]struct{}{} + var out []string + add := func(value string) { + value = strings.TrimSpace(filepathToSlash(value)) + value = strings.TrimPrefix(value, "file:") + value = strings.TrimPrefix(value, "folder:") + if value == "" || value == "." { + return + } + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + switch diff.OwnerType { + case "file", "folder": + add(diff.OwnerKey) + case "symbol": + if path, ok := filePathFromStableKey(diff.OwnerKey); ok { + add(path) + } + default: + if strings.HasPrefix(diff.OwnerKey, "file:") || strings.HasPrefix(diff.OwnerKey, "folder:") { + add(diff.OwnerKey) + } + } + return out +} + +func (r *Runner) createVersionForHead(ctx context.Context, repositoryID int64, status GitStatus, representationHash string, baselineOnly bool) error { + if gitStatusClean(status) { + baselineOnly = true + if err := r.Store.PruneDeletedMaterializedResources(ctx, repositoryID); err != nil { + return err + } + } + latest, found, err := r.Store.LatestWatchVersion(ctx, repositoryID) + if err != nil { + return err + } + if found && latest.CommitHash == status.HeadCommit && latest.RepresentationHash == representationHash { + return nil + } + views, elements, connectors, err := r.Store.WorkspaceResourceCounts(ctx) + if err != nil { + return err + } + description := strings.TrimSpace(status.HeadMessage) + if description == "" { + description = "tld watch " + shortHash(status.HeadCommit) + } + workspaceVersionID, err := r.Store.CreateWorkspaceVersion(ctx, status.HeadCommit, "watch", nil, views, elements, connectors, &description, &representationHash) + if err != nil && !strings.Contains(err.Error(), "constraint failed") { + return err + } + var workspaceID *int64 + if err == nil { + workspaceID = &workspaceVersionID + } + parent := "" + if repo, err := r.Store.Repository(ctx, repositoryID); err == nil { + parent, _ = tldgit.DetectParentCommit(repo.RepoRoot) + } + if parent == "" && found { + parent = latest.CommitHash + } + var diffs []RepresentationDiff + if !baselineOnly { + diffs, err = r.Store.BuildWatchDiffs(ctx, repositoryID, representationHash) + if err != nil { + return err + } + } + _, err = r.Store.CreateWatchVersion(ctx, repositoryID, status.HeadCommit, strings.TrimSpace(status.HeadMessage), parent, status.Branch, representationHash, workspaceID, diffs) + return err +} + +func gitStatusSnapshot(repoRoot string) (GitStatus, error) { + status, err := tldgit.StatusSnapshot(repoRoot) + return GitStatus{ + Branch: status.Branch, + HeadCommit: status.HeadCommit, + HeadMessage: status.HeadMessage, + RemoteURL: status.RemoteURL, + Staged: status.Staged, + Unstaged: status.Unstaged, + Untracked: status.Untracked, + Deleted: status.Deleted, + }, err +} + +func gitStatusClean(status GitStatus) bool { + return len(status.Staged) == 0 && len(status.Unstaged) == 0 && len(status.Untracked) == 0 && len(status.Deleted) == 0 +} + +func gitStatusFingerprint(status GitStatus) string { + parts := []string{status.Branch, status.HeadCommit, status.HeadMessage, status.RemoteURL} + appendPaths := func(kind string, paths []string) { + sorted := append([]string(nil), paths...) 
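+ // Work on a sorted copy of each path list so the fingerprint is independent of the order git reports status entries in.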
+ sort.Strings(sorted) + for _, path := range sorted { + parts = append(parts, kind+":"+path) + } + } + appendPaths("staged", status.Staged) + appendPaths("unstaged", status.Unstaged) + appendPaths("untracked", status.Untracked) + appendPaths("deleted", status.Deleted) + return hashString(strings.Join(parts, "\n")) +} + +func sourceFileSnapshot(repoRoot string, settings Settings, rules *ignore.Rules) map[string]string { + files := map[string]string{} + settings = NormalizeSettings(settings) + allowed := map[string]struct{}{} + for _, language := range settings.Languages { + allowed[language] = struct{}{} + } + if rules == nil { + rules = &ignore.Rules{} + } + _ = filepath.WalkDir(repoRoot, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil + } + rel, _ := filepath.Rel(repoRoot, path) + rel = filepath.ToSlash(rel) + if d.IsDir() { + if rel != "." && (rules.ShouldIgnorePath(rel) || isHiddenBuildOutput(d.Name())) { + return filepath.SkipDir + } + return nil + } + language, parseable, ok := watchedFileLanguage(path) + if !ok || (parseable && !languageAllowed(language, allowed)) || rules.ShouldIgnorePath(rel) { + return nil + } + info, err := d.Info() + if err != nil { + return nil + } + files[rel] = language + ":" + info.ModTime().UTC().Format(time.RFC3339Nano) + ":" + fmt.Sprint(info.Size()) + return nil + }) + return files +} + +func sourceFileFingerprint(files map[string]string) string { + h := hashString("") + paths := make([]string, 0, len(files)) + for path := range files { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + h = hashString(h + path + files[path]) + } + return h +} + +func diffSourceFileSnapshots(before, after map[string]string) []SourceFileChange { + seen := map[string]struct{}{} + var changes []SourceFileChange + for path, next := range after { + seen[path] = struct{}{} + prev, ok := before[path] + switch { + case !ok: + changes = append(changes, SourceFileChange{Path: path, ChangeType: "added", Language: sourceSnapshotLanguage(next)}) + case prev != next: + changes = append(changes, SourceFileChange{Path: path, ChangeType: "modified", Language: sourceSnapshotLanguage(next)}) + } + } + for path := range before { + if _, ok := seen[path]; !ok { + changes = append(changes, SourceFileChange{Path: path, ChangeType: "deleted", Language: sourceSnapshotLanguage(before[path])}) + } + } + sort.Slice(changes, func(i, j int) bool { + if changes[i].Path == changes[j].Path { + return changes[i].ChangeType < changes[j].ChangeType + } + return changes[i].Path < changes[j].Path + }) + return changes +} + +func sourceSnapshotLanguage(value string) string { + if idx := strings.Index(value, ":"); idx > 0 { + return value[:idx] + } + return "" +} + +func randomToken() string { + var buf [16]byte + if _, err := rand.Read(buf[:]); err != nil { + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(buf[:]) +} + +func emit(ch chan<- Event, event Event) { + if ch == nil { + return + } + select { + case ch <- event: + default: + } +} + +func shortHash(hash string) string { + if len(hash) > 7 { + return hash[:7] + } + return hash +} diff --git a/internal/watch/scan.go b/internal/watch/scan.go new file mode 100644 index 0000000..cfa0364 --- /dev/null +++ b/internal/watch/scan.go @@ -0,0 +1,1056 @@ +package watch + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "sync" + + 
"github.com/mertcikla/tld/internal/analyzer" + analyzerlsp "github.com/mertcikla/tld/internal/analyzer/lsp" + tldgit "github.com/mertcikla/tld/internal/git" + "github.com/mertcikla/tld/internal/ignore" + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/defaults" +) + +const ( + enrichmentVersion = "watch-enrich-v2" + enrichmentVersionEnricher = "watch.enrichment" + enrichmentVersionType = "watch.enrichment.version" +) + +type Scanner struct { + Store *Store + Analyzer analyzer.Service + Enrichers *enrich.Registry + Rules *ignore.Rules + EffectiveRules *ignore.Rules + Progress ProgressSink + Settings Settings +} + +type synchronizedProgress struct { + sink ProgressSink + mu sync.Mutex +} + +func (p *synchronizedProgress) Start(label string, total int) { + if p.sink == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.sink.Start(label, total) +} + +func (p *synchronizedProgress) Advance(label string) { + if p.sink == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.sink.Advance(label) +} + +func (p *synchronizedProgress) Finish() { + if p.sink == nil { + return + } + p.mu.Lock() + defer p.mu.Unlock() + p.sink.Finish() +} + +func NewScanner(store *Store) *Scanner { + return &Scanner{ + Store: store, + Analyzer: analyzer.NewService(), + Enrichers: defaults.NewRegistry(), + Rules: &ignore.Rules{}, + } +} + +func (s *Scanner) Scan(ctx context.Context, path string) (ScanResult, error) { + return s.ScanWithOptions(ctx, path, ScanOptions{}) +} + +type ScanOptions struct { + Force bool +} + +func (s *Scanner) ScanFilesWithOptions(ctx context.Context, repo Repository, relFiles []string, opts ScanOptions) (ScanResult, error) { + if s == nil || s.Store == nil { + return ScanResult{}, fmt.Errorf("watch scanner requires a store") + } + if s.Analyzer == nil { + s.Analyzer = analyzer.NewService() + } + if s.Enrichers == nil { + s.Enrichers = defaults.NewRegistry() + } + repoRoot := filepath.Clean(repo.RepoRoot) + gitignoreRules, err := ignore.LoadGitIgnore(repoRoot) + if err != nil { + return ScanResult{}, fmt.Errorf("load .gitignore rules: %w", err) + } + effectiveRules := ignore.Merge(s.Rules, gitignoreRules) + if effectiveRules == nil { + effectiveRules = &ignore.Rules{} + } + s.EffectiveRules = effectiveRules + settings := NormalizeSettings(s.Settings) + allowed := map[string]struct{}{} + for _, language := range settings.Languages { + allowed[language] = struct{}{} + } + + files := make([]string, 0, len(relFiles)) + seenRel := map[string]struct{}{} + for _, rel := range relFiles { + rel = filepath.ToSlash(filepath.Clean(filepath.FromSlash(rel))) + if rel == "." || rel == ".." 
|| filepath.IsAbs(rel) || strings.HasPrefix(rel, "../") { + continue + } + absFile := filepath.Join(repoRoot, filepath.FromSlash(rel)) + language, parseable, ok := watchedFileLanguage(absFile) + if !ok || (parseable && !languageAllowed(language, allowed)) || effectiveRules.ShouldIgnorePath(rel) { + continue + } + if _, ok := seenRel[rel]; ok { + continue + } + seenRel[rel] = struct{}{} + files = append(files, absFile) + } + sort.Strings(files) + repoSignals := enrich.DiscoverRepositorySignalsFromFiles(repoRoot, files) + result := ScanResult{RepositoryID: repo.ID, FilesSeen: len(files)} + mode := "focused" + if opts.Force { + mode = "focused-force" + } + runID, err := s.Store.BeginScanRun(ctx, repo.ID, mode) + if err != nil { + return ScanResult{}, err + } + result.ScanRunID = runID + status := "completed" + var scanErr error + defer func() { + if scanErr != nil { + status = "failed" + } + _ = s.Store.FinishScanRun(context.Background(), runID, status, result, scanErr) + }() + if len(files) == 0 { + return result, nil + } + workers := runtime.NumCPU() + progress := &synchronizedProgress{sink: s.Progress} + progressStart(progress, "Scanning context files", len(files)) + defer progressFinish(progress) + fileResults, err := s.scanFiles(ctx, repo.ID, repoRoot, files, workers, progress, opts.Force, effectiveRules, repoSignals) + if err != nil { + scanErr = err + return result, err + } + var parsedFiles []parsedFile + var parsedFileIDs []int64 + for _, fileResult := range fileResults { + if fileResult.Skipped { + result.FilesSkipped++ + } + if fileResult.Parsed { + result.FilesParsed++ + result.SymbolsSeen += fileResult.SymbolsSeen + parsedFiles = append(parsedFiles, parsedFile{File: fileResult.File, Refs: fileResult.Refs}) + parsedFileIDs = append(parsedFileIDs, fileResult.File.ID) + } + result.Warnings = append(result.Warnings, fileResult.Warnings...) 
+ } + if len(parsedFileIDs) == 0 { + return result, nil + } + progressFinish(progress) + progressStart(progress, "Resolving code references", len(parsedFiles)) + refs, warning, err := s.resolveReferences(ctx, repoRoot, repo.ID, parsedFiles, progress) + progressFinish(progress) + if err != nil { + scanErr = err + return result, err + } + result.Warning = warning + if warning != "" { + result.Warnings = append(result.Warnings, warning) + } + if err := s.Store.ReplaceReferencesForFiles(ctx, repo.ID, parsedFileIDs, refs); err != nil { + scanErr = err + return result, err + } + result.ReferencesSeen = len(refs) + return result, nil +} + +func (s *Scanner) ScanWithOptions(ctx context.Context, path string, opts ScanOptions) (ScanResult, error) { + if s == nil || s.Store == nil { + return ScanResult{}, fmt.Errorf("watch scanner requires a store") + } + if s.Analyzer == nil { + s.Analyzer = analyzer.NewService() + } + if s.Enrichers == nil { + s.Enrichers = defaults.NewRegistry() + } + absPath, err := filepath.Abs(path) + if err != nil { + return ScanResult{}, err + } + repoRoot, err := tldgit.RepoRoot(absPath) + if err != nil { + return ScanResult{}, fmt.Errorf("%s is not inside a git repository: %w", path, err) + } + repoRoot = filepath.Clean(repoRoot) + gitignoreRules, err := ignore.LoadGitIgnore(repoRoot) + if err != nil { + return ScanResult{}, fmt.Errorf("load .gitignore rules: %w", err) + } + effectiveRules := ignore.Merge(s.Rules, gitignoreRules) + if effectiveRules == nil { + effectiveRules = &ignore.Rules{} + } + s.EffectiveRules = effectiveRules + settings := NormalizeSettings(s.Settings) + + repoInput := RepositoryInput{ + RemoteURL: detectString(func() (string, error) { return tldgit.DetectRemoteURL(repoRoot) }), + RepoRoot: repoRoot, + DisplayName: filepath.Base(repoRoot), + Branch: detectString(func() (string, error) { return tldgit.DetectBranch(repoRoot) }), + HeadCommit: detectString(func() (string, error) { return tldgit.DetectHeadCommit(repoRoot) }), + SettingsHash: stableHash(settings), + } + repo, err := s.Store.EnsureRepository(ctx, repoInput) + if err != nil { + return ScanResult{}, err + } + result := ScanResult{RepositoryID: repo.ID} + + mode := "incremental" + if opts.Force { + mode = "full" + } + runID, err := s.Store.BeginScanRun(ctx, repo.ID, mode) + if err != nil { + return ScanResult{}, err + } + result.ScanRunID = runID + status := "completed" + var scanErr error + defer func() { + if scanErr != nil { + status = "failed" + } + _ = s.Store.FinishScanRun(context.Background(), runID, status, result, scanErr) + }() + + workers := runtime.NumCPU() + progress := &synchronizedProgress{sink: s.Progress} + files, err := s.collectSourceFiles(repoRoot, workers, settings.Languages, effectiveRules, progress) + progressFinish(progress) + if err != nil { + scanErr = err + return result, err + } + if err := ctx.Err(); err != nil { + scanErr = err + return result, err + } + repoSignals := enrich.DiscoverRepositorySignalsFromFiles(repoRoot, files) + result.FilesSeen = len(files) + progressStart(progress, "Scanning source files", len(files)) + defer progressFinish(progress) + seen := make(map[string]struct{}, len(files)) + var parsedFiles []parsedFile + var parsedFileIDs []int64 + + fileResults, err := s.scanFiles(ctx, repo.ID, repoRoot, files, workers, progress, opts.Force, effectiveRules, repoSignals) + if err != nil { + scanErr = err + return result, err + } + for _, fileResult := range fileResults { + seen[fileResult.RelPath] = struct{}{} + if fileResult.Skipped { + 
result.FilesSkipped++ + } + if fileResult.Parsed { + result.FilesParsed++ + result.SymbolsSeen += fileResult.SymbolsSeen + parsedFiles = append(parsedFiles, parsedFile{File: fileResult.File, Refs: fileResult.Refs}) + parsedFileIDs = append(parsedFileIDs, fileResult.File.ID) + } + result.Warnings = append(result.Warnings, fileResult.Warnings...) + } + + if err := ctx.Err(); err != nil { + scanErr = err + return result, err + } + if err := s.Store.DeleteMissingFiles(ctx, repo.ID, seen); err != nil { + scanErr = err + return result, err + } + if len(parsedFileIDs) == 0 { + if summary, err := s.Store.Summary(ctx, repo.ID); err == nil { + result.SymbolsSeen = summary.Symbols + result.ReferencesSeen = summary.References + } + return result, nil + } + + progressFinish(progress) + progressStart(progress, "Resolving code references", len(parsedFiles)) + refs, warning, err := s.resolveReferences(ctx, repoRoot, repo.ID, parsedFiles, progress) + progressFinish(progress) + if err != nil { + scanErr = err + return result, err + } + result.Warning = warning + if warning != "" { + result.Warnings = append(result.Warnings, warning) + } + if err := ctx.Err(); err != nil { + scanErr = err + return result, err + } + if err := s.Store.ReplaceReferencesForFiles(ctx, repo.ID, parsedFileIDs, refs); err != nil { + scanErr = err + return result, err + } + result.ReferencesSeen = len(refs) + return result, nil +} + +type parsedFile struct { + File File + Refs []analyzer.Ref +} + +type scanFileResult struct { + RelPath string + File File + Refs []analyzer.Ref + Parsed bool + Skipped bool + SymbolsSeen int + Warnings []string +} + +func (s *Scanner) scanFiles(ctx context.Context, repositoryID int64, repoRoot string, files []string, workers int, progress ProgressSink, force bool, rules *ignore.Rules, repoSignals []enrich.ActivationSignal) ([]scanFileResult, error) { + if workers <= 0 { + workers = 1 + } + if workers > len(files) && len(files) > 0 { + workers = len(files) + } + jobs := make(chan string) + results := make(chan scanFileResult, len(files)) + errs := make(chan error, 1) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Go(func() { + workerAnalyzer := analyzer.NewService() + for absFile := range jobs { + fileResult, err := s.scanFile(ctx, workerAnalyzer, repositoryID, repoRoot, absFile, progress, force, rules, repoSignals) + if err != nil { + select { + case errs <- err: + default: + } + continue + } + results <- fileResult + } + }) + } + for _, file := range files { + select { + case jobs <- file: + case err := <-errs: + close(jobs) + wg.Wait() + close(results) + return nil, err + } + } + close(jobs) + wg.Wait() + close(results) + select { + case err := <-errs: + return nil, err + default: + } + out := make([]scanFileResult, 0, len(files)) + for result := range results { + out = append(out, result) + } + sort.Slice(out, func(i, j int) bool { return out[i].RelPath < out[j].RelPath }) + return out, nil +} + +func (s *Scanner) scanFile(ctx context.Context, workerAnalyzer analyzer.Service, repositoryID int64, repoRoot, absFile string, progress ProgressSink, force bool, rules *ignore.Rules, repoSignals []enrich.ActivationSignal) (scanFileResult, error) { + rel, err := filepath.Rel(repoRoot, absFile) + if err != nil { + return scanFileResult{}, err + } + rel = filepath.ToSlash(rel) + defer progressAdvance(progress, rel) + result := scanFileResult{RelPath: rel} + languageName, parseable, ok := watchedFileLanguage(absFile) + if !ok { + result.Skipped = true + return result, nil + } + info, err := 
os.Stat(absFile) + if err != nil { + file, _, upsertErr := s.Store.UpsertFile(ctx, repositoryID, rel, languageName, "", "", 0, 0, "error", err) + if upsertErr != nil { + return result, upsertErr + } + result.File = file + return result, nil + } + if cached, ok, err := s.Store.CachedFileByPath(ctx, repositoryID, rel); err != nil { + return result, err + } else if !force && ok && cached.SizeBytes == info.Size() && cached.MtimeUnix == info.ModTime().UnixNano() && cached.WorktreeHash != "" && cached.ScanStatus != "error" { + file, _, err := s.Store.UpsertFile(ctx, repositoryID, rel, languageName, nullStringValue(cached.GitBlobHash), cached.WorktreeHash, info.Size(), info.ModTime().UnixNano(), "parsed", nil) + if err != nil { + return result, err + } + result.File = file + if err := s.backfillFactsForCachedFile(ctx, workerAnalyzer, repositoryID, repoRoot, rel, absFile, languageName, parseable, rules, repoSignals, file, nil, &result); err != nil { + return result, err + } + result.Skipped = true + return result, nil + } + data, err := os.ReadFile(absFile) + if err != nil { + _, _, upsertErr := s.Store.UpsertFile(ctx, repositoryID, rel, languageName, "", "", info.Size(), info.ModTime().UnixNano(), "error", err) + return result, upsertErr + } + worktreeHash := hashBytes(data) + blobHash := detectString(func() (string, error) { return tldgit.FileBlobHash(repoRoot, rel) }) + file, skipped, err := s.Store.UpsertFile(ctx, repositoryID, rel, languageName, blobHash, worktreeHash, info.Size(), info.ModTime().UnixNano(), "parsed", nil) + if err != nil { + return result, err + } + result.File = file + if !force && skipped { + if err := s.backfillFactsForCachedFile(ctx, workerAnalyzer, repositoryID, repoRoot, rel, absFile, languageName, parseable, rules, repoSignals, file, data, &result); err != nil { + return result, err + } + result.Skipped = true + return result, nil + } + if !parseable { + if err := s.enrichFile(ctx, repositoryID, file.ID, repoRoot, rel, absFile, languageName, data, nil, repoSignals, &result); err != nil { + return result, err + } + return result, nil + } + extracted, err := workerAnalyzer.ExtractPath(ctx, absFile, rules, nil) + if err != nil { + _, _, upsertErr := s.Store.UpsertFile(ctx, repositoryID, rel, languageName, blobHash, worktreeHash, info.Size(), info.ModTime().UnixNano(), "error", err) + return result, upsertErr + } + symbols := watchSymbolsFromAnalyzer(repositoryID, file.ID, rel, languageName, data, extracted.Symbols) + if err := s.Store.ReplaceFileSymbols(ctx, repositoryID, file.ID, symbols); err != nil { + return result, err + } + if err := s.enrichFile(ctx, repositoryID, file.ID, repoRoot, rel, absFile, languageName, data, extracted, repoSignals, &result); err != nil { + return result, err + } + result.Parsed = true + result.SymbolsSeen = len(symbols) + result.Refs = extracted.Refs + return result, nil +} + +func (s *Scanner) backfillFactsForCachedFile(ctx context.Context, workerAnalyzer analyzer.Service, repositoryID int64, repoRoot, rel, absFile, language string, parseable bool, rules *ignore.Rules, repoSignals []enrich.ActivationSignal, file File, data []byte, result *scanFileResult) error { + version, err := s.Store.FactVersionForFile(ctx, repositoryID, file.ID, enrichmentVersionEnricher, enrichmentVersionStableKey(rel)) + if err != nil { + return err + } + if version == enrichmentVersion { + return nil + } + if data == nil { + data, err = os.ReadFile(absFile) + if err != nil { + return err + } + } + var extracted *analyzer.Result + if parseable { + extracted, err = 
workerAnalyzer.ExtractPath(ctx, absFile, rules, nil) + if err != nil { + return err + } + } + return s.enrichFile(ctx, repositoryID, file.ID, repoRoot, rel, absFile, language, data, extracted, repoSignals, result) +} + +func (s *Scanner) enrichFile(ctx context.Context, repositoryID, fileID int64, repoRoot, rel, absFile, language string, data []byte, extracted *analyzer.Result, repoSignals []enrich.ActivationSignal, result *scanFileResult) error { + if s.Enrichers == nil { + s.Enrichers = defaults.NewRegistry() + } + signals := append([]enrich.ActivationSignal{}, repoSignals...) + if extracted != nil { + signals = append(signals, enrich.ImportSignals(extracted.Refs)...) + } + facts, warnings, err := s.Enrichers.EnrichFile(ctx, enrich.FileInput{ + RepoRoot: repoRoot, + AbsPath: absFile, + RelPath: rel, + Language: language, + Source: data, + Parsed: extracted, + Signals: signals, + }) + if err != nil { + return err + } + for _, warning := range warnings { + if warning.Message != "" { + result.Warnings = append(result.Warnings, warning.Enricher+": "+warning.Message) + } + } + watchFacts := watchFactsFromEnrich(repositoryID, fileID, rel, facts) + watchFacts = append(watchFacts, enrichmentVersionFact(repositoryID, fileID, rel)) + return s.Store.ReplaceFactsForFile(ctx, repositoryID, fileID, watchFacts) +} + +func watchedFileLanguage(path string) (language string, parseable bool, ok bool) { + if language, ok := analyzer.DetectLanguage(path); ok { + return string(language), true, true + } + switch strings.ToLower(filepath.Base(path)) { + case "go.mod": + return "go-mod", false, true + case "package.json", "package-lock.json": + return "json", false, true + case "requirements.txt", "requirements.in": + return "python-requirements", false, true + case "build.gradle", "settings.gradle": + return "gradle", false, true + case "cartservice.csproj": + return "xml", false, true + default: + switch strings.ToLower(filepath.Ext(path)) { + case ".cs": + return "c-sharp", false, true + case ".yaml", ".yml": + return "yaml", false, true + case ".proto": + return "protobuf", false, true + case ".tf": + return "terraform", false, true + case ".csproj": + return "xml", false, true + default: + return "", false, false + } + } +} + +func watchFactsFromEnrich(repositoryID, fileID int64, relPath string, facts []enrich.Fact) []Fact { + out := make([]Fact, 0, len(facts)) + for _, fact := range facts { + filePath := strings.TrimSpace(fact.Source.FilePath) + if filePath == "" { + filePath = relPath + } + subjectKind := strings.TrimSpace(fact.Subject.Kind) + if subjectKind == "" { + subjectKind = "file" + } + subjectKey := strings.TrimSpace(fact.Subject.StableKey) + if subjectKey == "" { + subjectKey = "file:" + relPath + } + endLine := fact.Source.EndLine + var endPtr *int + if endLine > 0 { + endPtr = &endLine + } + attrs, _ := json.Marshal(fact.Attributes) + hints, _ := json.Marshal(fact.VisibilityHints) + raw, _ := json.Marshal(fact) + watchFact := Fact{ + RepositoryID: repositoryID, + FileID: fileID, + FilePath: filePath, + StableKey: fact.StableKey, + Type: fact.Type, + Enricher: fact.Enricher, + SubjectKind: subjectKind, + SubjectStableKey: subjectKey, + ObjectKind: strings.TrimSpace(fact.Object.Kind), + ObjectStableKey: strings.TrimSpace(fact.Object.StableKey), + ObjectFilePath: strings.TrimSpace(fact.Object.FilePath), + ObjectName: strings.TrimSpace(fact.Object.Name), + Relationship: strings.TrimSpace(fact.Relationship), + StartLine: fact.Source.StartLine, + EndLine: endPtr, + Confidence: fact.Confidence, + 
Name: fact.Name, + Tags: append([]string{}, fact.Tags...), + AttributesJSON: string(attrs), + VisibilityHintsJSON: string(hints), + RawJSON: string(raw), + } + watchFact.FactHash = stableHash(struct { + Type string `json:"type"` + StableKey string `json:"stable_key"` + Enricher string `json:"enricher"` + Subject string `json:"subject"` + ObjectKind string `json:"object_kind,omitempty"` + ObjectStableKey string `json:"object_stable_key,omitempty"` + ObjectFilePath string `json:"object_file_path,omitempty"` + ObjectName string `json:"object_name,omitempty"` + Relationship string `json:"relationship,omitempty"` + FilePath string `json:"file_path"` + StartLine int `json:"start_line"` + EndLine *int `json:"end_line,omitempty"` + Confidence float64 `json:"confidence"` + Name string `json:"name"` + Tags []string `json:"tags"` + Attributes map[string]string `json:"attributes"` + VisibilityHints map[string]float64 `json:"visibility_hints,omitempty"` + }{ + Type: watchFact.Type, + StableKey: watchFact.StableKey, + Enricher: watchFact.Enricher, + Subject: watchFact.SubjectKind + ":" + watchFact.SubjectStableKey, + ObjectKind: watchFact.ObjectKind, + ObjectStableKey: watchFact.ObjectStableKey, + ObjectFilePath: watchFact.ObjectFilePath, + ObjectName: watchFact.ObjectName, + Relationship: watchFact.Relationship, + FilePath: watchFact.FilePath, + StartLine: watchFact.StartLine, + EndLine: watchFact.EndLine, + Confidence: watchFact.Confidence, + Name: watchFact.Name, + Tags: watchFact.Tags, + Attributes: fact.Attributes, + VisibilityHints: fact.VisibilityHints, + }) + out = append(out, watchFact) + } + return out +} + +func enrichmentVersionStableKey(relPath string) string { + return "watch.enrichment.version:" + relPath +} + +func enrichmentVersionFact(repositoryID, fileID int64, relPath string) Fact { + fact := Fact{ + RepositoryID: repositoryID, + FileID: fileID, + FilePath: relPath, + StableKey: enrichmentVersionStableKey(relPath), + Type: enrichmentVersionType, + Enricher: enrichmentVersionEnricher, + SubjectKind: "file", + SubjectStableKey: "file:" + relPath, + StartLine: 1, + Confidence: 1, + Name: enrichmentVersion, + AttributesJSON: `{"version":"` + enrichmentVersion + `"}`, + VisibilityHintsJSON: `{}`, + RawJSON: `{"version":"` + enrichmentVersion + `"}`, + } + fact.FactHash = stableHash(struct { + Type string `json:"type"` + StableKey string `json:"stable_key"` + Enricher string `json:"enricher"` + Version string `json:"version"` + }{ + Type: fact.Type, + StableKey: fact.StableKey, + Enricher: fact.Enricher, + Version: enrichmentVersion, + }) + return fact +} + +func (s *Scanner) collectSourceFiles(root string, workers int, languages []string, rules *ignore.Rules, progress ProgressSink) ([]string, error) { + var files []string + if rules == nil { + rules = &ignore.Rules{} + } + entries, err := os.ReadDir(root) + if err != nil { + return nil, err + } + if workers <= 0 { + workers = 1 + } + if workers > len(entries) && len(entries) > 0 { + workers = len(entries) + } + progressStart(progress, "Discovering source files", len(entries)) + jobs := make(chan string) + results := make(chan []string, len(entries)) + errs := make(chan error, 1) + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Go(func() { + for entryPath := range jobs { + found, err := s.collectSourceFilesUnder(root, entryPath, rules, languages) + progressAdvance(progress, filepath.ToSlash(mustRel(root, entryPath))) + if err != nil { + select { + case errs <- err: + default: + } + continue + } + results <- found + } + }) + } 
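+ // Feed one top-level entry per job; if a worker has already reported an error, stop queuing, wait for the pool to drain, and return that error.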
+ for _, entry := range entries { + select { + case jobs <- filepath.Join(root, entry.Name()): + case err := <-errs: + close(jobs) + wg.Wait() + close(results) + return nil, err + } + } + close(jobs) + wg.Wait() + close(results) + select { + case err := <-errs: + return nil, err + default: + } + for result := range results { + files = append(files, result...) + } + sort.Strings(files) + return files, nil +} + +func (s *Scanner) collectSourceFilesUnder(root, start string, rules *ignore.Rules, languages []string) ([]string, error) { + var files []string + allowed := map[string]struct{}{} + for _, language := range NormalizeSettings(Settings{Languages: languages}).Languages { + allowed[language] = struct{}{} + } + err := filepath.WalkDir(start, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + rel, _ := filepath.Rel(root, path) + rel = filepath.ToSlash(rel) + if d.IsDir() { + if rules.ShouldIgnorePath(rel) || isHiddenBuildOutput(d.Name()) { + return filepath.SkipDir + } + return nil + } + language, parseable, ok := watchedFileLanguage(path) + if !ok || (parseable && !languageAllowed(language, allowed)) { + return nil + } + if rules.ShouldIgnorePath(rel) { + return nil + } + if parseable && language == string(analyzer.LanguageGo) { + generated, err := isGeneratedGoFile(path) + if err != nil { + return nil + } + if generated { + return nil + } + } + files = append(files, path) + return nil + }) + return files, err +} + +func watchSymbolsFromAnalyzer(repositoryID, fileID int64, relPath, language string, source []byte, symbols []analyzer.Symbol) []Symbol { + out := make([]Symbol, 0, len(symbols)) + lines := strings.Split(string(source), "\n") + baseKeyCounts := make(map[string]int, len(symbols)) + for _, sym := range symbols { + baseKeyCounts[watchSymbolStableKey(language, relPath, sym)]++ + } + baseKeySeen := make(map[string]int, len(baseKeyCounts)) + for _, sym := range symbols { + qualified := watchSymbolQualifiedName(sym) + stableKey := watchSymbolStableKey(language, relPath, sym) + if baseKeyCounts[stableKey] > 1 { + baseKeySeen[stableKey]++ + stableKey = fmt.Sprintf("%s:line:%d:ordinal:%d", stableKey, sym.Line, baseKeySeen[stableKey]) + } + endLine := sym.EndLine + if endLine <= 0 { + endLine = sym.Line + } + raw, _ := json.Marshal(sym) + endPtr := endLine + body := lineRange(lines, sym.Line, endLine) + out = append(out, Symbol{ + RepositoryID: repositoryID, + FileID: fileID, + FilePath: relPath, + StableKey: stableKey, + Name: sym.Name, + QualifiedName: qualified, + Kind: sym.Kind, + StartLine: sym.Line, + EndLine: &endPtr, + SignatureHash: hashString(fmt.Sprintf("%s:%s:%d", sym.Kind, qualified, sym.Line)), + ContentHash: hashString(normalizeSymbolContent(body, sym.Name, qualified)), + RawJSON: string(raw), + }) + } + return out +} + +func watchSymbolQualifiedName(sym analyzer.Symbol) string { + if sym.Parent == "" { + return sym.Name + } + return sym.Parent + "." 
+ sym.Name +} + +func watchSymbolStableKey(language, relPath string, sym analyzer.Symbol) string { + return fmt.Sprintf("%s:%s:%s:%s", language, relPath, sym.Kind, watchSymbolQualifiedName(sym)) +} + +func normalizeSymbolContent(body, name, qualified string) string { + body = strings.TrimSpace(outdentCode(body)) + replacements := []string{name} + if leaf := pathBaseQualifier(qualified); leaf != "" && leaf != name { + replacements = append(replacements, leaf) + } + for _, replacement := range replacements { + if replacement == "" { + continue + } + body = strings.ReplaceAll(body, replacement, "__symbol__") + } + return body +} + +func (s *Scanner) resolveReferences(ctx context.Context, repoRoot string, repositoryID int64, files []parsedFile, progress ProgressSink) ([]Reference, string, error) { + symbols, err := s.Store.SymbolsForRepository(ctx, repositoryID) + if err != nil { + return nil, "", err + } + byName := make(map[string][]Symbol) + byFile := make(map[int64][]Symbol) + for _, sym := range symbols { + byName[sym.Name] = append(byName[sym.Name], sym) + byFile[sym.FileID] = append(byFile[sym.FileID], sym) + } + for fileID := range byFile { + sort.Slice(byFile[fileID], func(i, j int) bool { + return byFile[fileID][i].StartLine > byFile[fileID][j].StartLine + }) + } + + resolver := analyzerlsp.NewMultiLanguageResolver(repoRoot) + defer func() { _ = resolver.Close() }() + + var refs []Reference + for _, file := range files { + progressAdvance(progress, file.File.Path) + for _, parsedRef := range file.Refs { + if parsedRef.Kind != "" && parsedRef.Kind != "call" { + continue + } + target, ok := resolveTargetSymbol(ctx, resolver, repoRoot, parsedRef, byName, symbols) + if !ok { + continue + } + source, ok := enclosingSymbol(byFile[file.File.ID], parsedRef.Line) + if !ok || source.ID == target.ID { + continue + } + raw, _ := json.Marshal(parsedRef) + kind := parsedRef.Kind + if kind == "" { + kind = "call" + } + refs = append(refs, Reference{ + RepositoryID: repositoryID, + SourceSymbolID: source.ID, + TargetSymbolID: target.ID, + SourceFileID: file.File.ID, + Kind: kind, + Line: parsedRef.Line, + Column: parsedRef.Column, + EvidenceHash: hashString(fmt.Sprintf("%d:%d:%s:%s", parsedRef.Line, parsedRef.Column, kind, parsedRef.Name)), + RawJSON: string(raw), + }) + } + } + return refs, "", nil +} + +func resolveTargetSymbol(ctx context.Context, resolver *analyzerlsp.MultiLanguageResolver, repoRoot string, ref analyzer.Ref, byName map[string][]Symbol, symbols []Symbol) (Symbol, bool) { + if resolver != nil { + locations, err := resolver.ResolveDefinitions(ctx, ref) + if err == nil { + for _, location := range locations { + if sym, ok := symbolAtLocation(repoRoot, symbols, definitionLocation{FilePath: location.FilePath, Line: location.Line}); ok { + return sym, true + } + } + } + } + targets := byName[ref.Name] + if len(targets) != 1 { + return Symbol{}, false + } + return targets[0], true +} + +type definitionLocation struct { + FilePath string + Line int +} + +func symbolAtLocation(repoRoot string, symbols []Symbol, location definitionLocation) (Symbol, bool) { + rel, err := filepath.Rel(repoRoot, location.FilePath) + if err != nil { + return Symbol{}, false + } + rel = filepath.ToSlash(rel) + var best Symbol + found := false + for _, sym := range symbols { + if sym.FilePath != rel { + continue + } + end := sym.StartLine + if sym.EndLine != nil { + end = *sym.EndLine + } + if sym.StartLine <= location.Line && end >= location.Line { + if !found || sym.StartLine > best.StartLine { + best = sym + 
found = true + } + } + } + return best, found +} + +func enclosingSymbol(symbols []Symbol, line int) (Symbol, bool) { + for _, sym := range symbols { + end := sym.StartLine + if sym.EndLine != nil { + end = *sym.EndLine + } + if sym.StartLine <= line && end >= line { + return sym, true + } + } + return Symbol{}, false +} + +func detectString(fn func() (string, error)) string { + value, err := fn() + if err != nil { + return "" + } + return strings.TrimSpace(value) +} + +func hashBytes(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +func hashString(value string) string { + return hashBytes([]byte(value)) +} + +func lineRange(lines []string, start, end int) string { + if start <= 0 { + start = 1 + } + if end < start { + end = start + } + if start > len(lines) { + return "" + } + if end > len(lines) { + end = len(lines) + } + return strings.Join(lines[start-1:end], "\n") +} + +func isGeneratedGoFile(path string) (bool, error) { + f, err := os.Open(path) + if err != nil { + return false, err + } + defer func() { _ = f.Close() }() + buf := make([]byte, 8192) + n, err := f.Read(buf) + if err != nil && err != io.EOF { + return false, err + } + lines := strings.SplitN(string(buf[:n]), "\n", 21) + for _, line := range lines { + if strings.Contains(line, "Code generated") && strings.Contains(line, "DO NOT EDIT") { + return true, nil + } + } + return false, nil +} + +func isHiddenBuildOutput(name string) bool { + if name == "" || name == "." { + return false + } + if strings.HasPrefix(name, ".") { + switch name { + case ".git", ".cache", ".next", ".tld", ".turbo": + return true + } + } + switch name { + case "dist", "build", "out", "tmp": + return true + default: + return false + } +} diff --git a/internal/watch/settings.go b/internal/watch/settings.go new file mode 100644 index 0000000..6d0a1b0 --- /dev/null +++ b/internal/watch/settings.go @@ -0,0 +1,144 @@ +package watch + +import ( + "sort" + "strings" + "time" + + "github.com/mertcikla/tld/internal/analyzer" +) + +const ( + WatcherAuto = "auto" + WatcherFSNotify = "fsnotify" + WatcherPoll = "poll" +) + +func DefaultSettings() Settings { + langs := make([]string, 0, len(analyzer.SupportedLanguages())) + for _, spec := range analyzer.SupportedLanguages() { + langs = append(langs, string(spec.Language)) + } + sort.Strings(langs) + return Settings{ + Languages: langs, + Watcher: WatcherAuto, + PollInterval: time.Second, + Debounce: 500 * time.Millisecond, + Thresholds: defaultThresholds(Thresholds{}), + Visibility: defaultVisibilityConfig(VisibilityConfig{}), + } +} + +func NormalizeSettings(settings Settings) Settings { + defaults := DefaultSettings() + if len(settings.Languages) == 0 { + settings.Languages = defaults.Languages + } else { + settings.Languages = normalizeLanguages(settings.Languages) + } + switch strings.ToLower(strings.TrimSpace(settings.Watcher)) { + case WatcherFSNotify: + settings.Watcher = WatcherFSNotify + case WatcherPoll: + settings.Watcher = WatcherPoll + default: + settings.Watcher = WatcherAuto + } + if settings.PollInterval <= 0 { + settings.PollInterval = defaults.PollInterval + } + if settings.Debounce <= 0 { + settings.Debounce = defaults.Debounce + } + settings.Thresholds = defaultThresholds(settings.Thresholds) + settings.Visibility = defaultVisibilityConfig(settings.Visibility) + return settings +} + +func defaultVisibilityConfig(cfg VisibilityConfig) VisibilityConfig { + if !cfg.CoreThresholdSet && !cfg.CoreThresholdEnabled { + cfg.CoreThresholdEnabled = true + } + if 
cfg.CoreThreshold <= 0 { + cfg.CoreThreshold = 1 + } + if cfg.TierMultiplier <= 0 { + cfg.TierMultiplier = 0.5 + } + if cfg.MaxExpansionMultiplier <= 0 { + cfg.MaxExpansionMultiplier = 2 + } + defaults := VisibilityWeights{ + Changed: 100, + Selected: 100, + UserShow: 100, + UserHide: -100, + HighSignalFact: 1.5, + RelationshipProximity: 1, + DependencyFact: 0.2, + UtilityNoise: -0.8, + HighDegreeNoise: -1.5, + } + if !cfg.WeightsSet { + if cfg.Weights.Changed == 0 { + cfg.Weights.Changed = defaults.Changed + } + if cfg.Weights.Selected == 0 { + cfg.Weights.Selected = defaults.Selected + } + if cfg.Weights.UserShow == 0 { + cfg.Weights.UserShow = defaults.UserShow + } + if cfg.Weights.UserHide == 0 { + cfg.Weights.UserHide = defaults.UserHide + } + if cfg.Weights.HighSignalFact == 0 { + cfg.Weights.HighSignalFact = defaults.HighSignalFact + } + if cfg.Weights.RelationshipProximity == 0 { + cfg.Weights.RelationshipProximity = defaults.RelationshipProximity + } + if cfg.Weights.DependencyFact == 0 { + cfg.Weights.DependencyFact = defaults.DependencyFact + } + if cfg.Weights.UtilityNoise == 0 { + cfg.Weights.UtilityNoise = defaults.UtilityNoise + } + if cfg.Weights.HighDegreeNoise == 0 { + cfg.Weights.HighDegreeNoise = defaults.HighDegreeNoise + } + } + return cfg +} + +func normalizeLanguages(values []string) []string { + seen := map[string]struct{}{} + for _, value := range values { + lang := strings.ToLower(strings.TrimSpace(value)) + if lang == "" { + continue + } + if _, ok := analyzer.LanguageSpecFor(analyzer.Language(lang)); !ok { + continue + } + seen[lang] = struct{}{} + } + if len(seen) == 0 { + return DefaultSettings().Languages + } + out := make([]string, 0, len(seen)) + for lang := range seen { + out = append(out, lang) + } + sort.Strings(out) + return out +} + +func languageAllowed(language string, allowed map[string]struct{}) bool { + if len(allowed) == 0 { + return true + } + _, ok := allowed[strings.ToLower(language)] + return ok +} diff --git a/internal/watch/store.go b/internal/watch/store.go new file mode 100644 index 0000000..2a00002 --- /dev/null +++ b/internal/watch/store.go @@ -0,0 +1,3425 @@ +package watch + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "maps" + "math" + "os" + "path/filepath" + "sort" + "strings" + "time" + + tldgit "github.com/mertcikla/tld/internal/git" + "github.com/mertcikla/tld/internal/tagcolors" + "github.com/viant/sqlite-vec/vector" +) + +const LockHeartbeatTimeout = 30 * time.Second + +type Store struct { + db *sql.DB +} + +func NewStore(db *sql.DB) *Store { + return &Store{db: db} +} + +type RepositoryInput struct { + RemoteURL string + RepoRoot string + DisplayName string + Branch string + HeadCommit string + IdentityStatus string + SettingsHash string +} + +func (s *Store) EnsureRepository(ctx context.Context, input RepositoryInput) (Repository, error) { + input.RemoteURL = strings.TrimSpace(input.RemoteURL) + input.RepoRoot = strings.TrimSpace(input.RepoRoot) + input.DisplayName = strings.TrimSpace(input.DisplayName) + if input.DisplayName == "" { + input.DisplayName = input.RepoRoot + } + if input.IdentityStatus == "" { + input.IdentityStatus = "known" + } + if input.RemoteURL == "" { + input.IdentityStatus = "local_only" + } + now := nowString() + + var existingID int64 + var err error + if input.RemoteURL != "" { + err = s.db.QueryRowContext(ctx, `SELECT id FROM watch_repositories WHERE remote_url = ?`, input.RemoteURL).Scan(&existingID) + } else { 
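+ // No remote URL available: fall back to matching an existing local-only repository by its filesystem root.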
+ err = s.db.QueryRowContext(ctx, `SELECT id FROM watch_repositories WHERE repo_root = ? AND identity_status = 'local_only'`, input.RepoRoot).Scan(&existingID) + } + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return Repository{}, err + } + if existingID > 0 { + _, err = s.db.ExecContext(ctx, ` + UPDATE watch_repositories + SET repo_root = ?, display_name = ?, branch = ?, head_commit = ?, identity_status = ?, settings_hash = ?, updated_at = ? + WHERE id = ?`, + input.RepoRoot, + input.DisplayName, + nullString(input.Branch), + nullString(input.HeadCommit), + input.IdentityStatus, + input.SettingsHash, + now, + existingID, + ) + if err != nil { + return Repository{}, err + } + return s.Repository(ctx, existingID) + } + + res, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_repositories(remote_url, repo_root, display_name, branch, head_commit, identity_status, settings_hash, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + nullString(input.RemoteURL), + input.RepoRoot, + input.DisplayName, + nullString(input.Branch), + nullString(input.HeadCommit), + input.IdentityStatus, + input.SettingsHash, + now, + now, + ) + if err != nil { + return Repository{}, err + } + id, err := res.LastInsertId() + if err != nil { + return Repository{}, err + } + return s.Repository(ctx, id) +} + +func (s *Store) Repository(ctx context.Context, id int64) (Repository, error) { + var repo Repository + err := s.db.QueryRowContext(ctx, ` + SELECT id, remote_url, repo_root, display_name, branch, head_commit, identity_status, settings_hash, created_at, updated_at + FROM watch_repositories + WHERE id = ?`, id).Scan( + &repo.ID, + &repo.RemoteURL, + &repo.RepoRoot, + &repo.DisplayName, + &repo.Branch, + &repo.HeadCommit, + &repo.IdentityStatus, + &repo.SettingsHash, + &repo.CreatedAt, + &repo.UpdatedAt, + ) + return repo, err +} + +func (s *Store) Repositories(ctx context.Context) ([]Repository, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, remote_url, repo_root, display_name, branch, head_commit, identity_status, settings_hash, created_at, updated_at + FROM watch_repositories + ORDER BY display_name, id`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var repos []Repository + for rows.Next() { + var repo Repository + if err := rows.Scan(&repo.ID, &repo.RemoteURL, &repo.RepoRoot, &repo.DisplayName, &repo.Branch, &repo.HeadCommit, &repo.IdentityStatus, &repo.SettingsHash, &repo.CreatedAt, &repo.UpdatedAt); err != nil { + return nil, err + } + repos = append(repos, repo) + } + return repos, rows.Err() +} + +func (s *Store) ReassociateRepository(ctx context.Context, id int64, remoteURL string) (Repository, error) { + remoteURL = strings.TrimSpace(remoteURL) + if remoteURL == "" { + return Repository{}, fmt.Errorf("remote_url is required") + } + _, err := s.db.ExecContext(ctx, ` + UPDATE watch_repositories + SET remote_url = ?, identity_status = 'known', updated_at = ? 
+ WHERE id = ?`, remoteURL, nowString(), id) + if err != nil { + return Repository{}, err + } + return s.Repository(ctx, id) +} + +func (s *Store) BeginScanRun(ctx context.Context, repositoryID int64, mode string) (int64, error) { + res, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_scan_runs(repository_id, mode, started_at, status) + VALUES (?, ?, ?, 'running')`, repositoryID, mode, nowString()) + if err != nil { + return 0, err + } + return res.LastInsertId() +} + +func (s *Store) FinishScanRun(ctx context.Context, id int64, status string, result ScanResult, runErr error) error { + var errText any + if runErr != nil { + errText = runErr.Error() + } else if result.Warning != "" { + errText = result.Warning + } + _, err := s.db.ExecContext(ctx, ` + UPDATE watch_scan_runs + SET finished_at = ?, status = ?, files_seen = ?, files_parsed = ?, files_skipped = ?, symbols_seen = ?, references_seen = ?, error = ? + WHERE id = ?`, + nowString(), + status, + result.FilesSeen, + result.FilesParsed, + result.FilesSkipped, + result.SymbolsSeen, + result.ReferencesSeen, + errText, + id, + ) + return err +} + +func (s *Store) UpsertFile(ctx context.Context, repositoryID int64, path, language, gitBlobHash, worktreeHash string, sizeBytes, mtimeUnix int64, status string, scanErr error) (File, bool, error) { + existing, found, err := s.fileByPath(ctx, repositoryID, path) + if err != nil { + return File{}, false, err + } + unchanged := found && existing.WorktreeHash == worktreeHash && existing.ScanStatus != "error" + if unchanged { + _, err := s.db.ExecContext(ctx, ` + UPDATE watch_files + SET git_blob_hash = ?, size_bytes = ?, mtime_unix = ?, scan_status = 'skipped', scan_error = NULL, updated_at = ? + WHERE id = ?`, nullString(gitBlobHash), sizeBytes, mtimeUnix, nowString(), existing.ID) + if err != nil { + return File{}, false, err + } + file, err := s.file(ctx, existing.ID) + return file, true, err + } + + errText := "" + if scanErr != nil { + errText = scanErr.Error() + } + now := nowString() + _, err = s.db.ExecContext(ctx, ` + INSERT INTO watch_files(repository_id, path, language, git_blob_hash, worktree_hash, size_bytes, mtime_unix, scan_status, scan_error, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(repository_id, path) DO UPDATE SET + language = excluded.language, + git_blob_hash = excluded.git_blob_hash, + worktree_hash = excluded.worktree_hash, + size_bytes = excluded.size_bytes, + mtime_unix = excluded.mtime_unix, + scan_status = excluded.scan_status, + scan_error = excluded.scan_error, + updated_at = excluded.updated_at`, + repositoryID, path, language, nullString(gitBlobHash), worktreeHash, sizeBytes, mtimeUnix, status, nullString(errText), now, now) + if err != nil { + return File{}, false, err + } + file, err := s.fileByPathMust(ctx, repositoryID, path) + return file, false, err +} + +func (s *Store) DeleteMissingFiles(ctx context.Context, repositoryID int64, seen map[string]struct{}) error { + rows, err := s.db.QueryContext(ctx, `SELECT id, path FROM watch_files WHERE repository_id = ?`, repositoryID) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + var ids []int64 + for rows.Next() { + var id int64 + var path string + if err := rows.Scan(&id, &path); err != nil { + return err + } + if _, ok := seen[path]; !ok { + ids = append(ids, id) + } + } + if err := rows.Err(); err != nil { + return err + } + for _, id := range ids { + if _, err := s.db.ExecContext(ctx, `DELETE FROM watch_files WHERE id = ?`, id); err != nil { + return err + } + } + return nil +} + +func (s *Store) ReplaceFileSymbols(ctx context.Context, repositoryID, fileID int64, symbols []Symbol) error { + existingIdentities, err := s.replacementIdentityCandidates(ctx, repositoryID, fileID) + if err != nil { + return err + } + usedIdentities := map[string]struct{}{} + keep := make(map[string]struct{}, len(symbols)) + for _, sym := range symbols { + keep[sym.StableKey] = struct{}{} + identityKey := s.matchSymbolIdentity(sym, existingIdentities, usedIdentities) + usedIdentities[identityKey] = struct{}{} + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_symbols(repository_id, file_id, stable_key, name, qualified_name, kind, start_line, end_line, signature_hash, content_hash, raw_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repository_id, stable_key) DO UPDATE SET + file_id = excluded.file_id, + name = excluded.name, + qualified_name = excluded.qualified_name, + kind = excluded.kind, + start_line = excluded.start_line, + end_line = excluded.end_line, + signature_hash = excluded.signature_hash, + content_hash = excluded.content_hash, + raw_json = excluded.raw_json, + updated_at = excluded.updated_at`, + repositoryID, fileID, sym.StableKey, sym.Name, sym.QualifiedName, sym.Kind, sym.StartLine, sym.EndLine, sym.SignatureHash, sym.ContentHash, sym.RawJSON, now, now) + if err != nil { + return err + } + if err := s.UpsertSymbolIdentity(ctx, repositoryID, identityKey, sym); err != nil { + return err + } + } + rows, err := s.db.QueryContext(ctx, `SELECT id, stable_key FROM watch_symbols WHERE repository_id = ? 
AND file_id = ?`, repositoryID, fileID) + if err != nil { + return err + } + defer func() { _ = rows.Close() }() + var deleteIDs []int64 + for rows.Next() { + var id int64 + var stableKey string + if err := rows.Scan(&id, &stableKey); err != nil { + return err + } + if _, ok := keep[stableKey]; !ok { + deleteIDs = append(deleteIDs, id) + } + } + if err := rows.Err(); err != nil { + return err + } + for _, id := range deleteIDs { + if _, err := s.db.ExecContext(ctx, `DELETE FROM watch_symbols WHERE id = ?`, id); err != nil { + return err + } + } + return nil +} + +func (s *Store) CachedFileByPath(ctx context.Context, repositoryID int64, path string) (File, bool, error) { + return s.fileByPath(ctx, repositoryID, path) +} + +type storedSymbolIdentity struct { + IdentityKey string + StableKey string + FilePath string + Kind string + Name string + QualifiedName string + StartLine int + ContentHash string + MissingFile bool +} + +func (s *Store) symbolIdentitiesForFile(ctx context.Context, repositoryID, fileID int64) ([]storedSymbolIdentity, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT COALESCE(i.identity_key, ws.stable_key), ws.stable_key, f.path, ws.kind, ws.name, ws.qualified_name, ws.start_line, ws.content_hash + FROM watch_symbols ws + JOIN watch_files f ON f.id = ws.file_id + LEFT JOIN watch_symbol_identities i ON i.repository_id = ws.repository_id AND i.current_stable_key = ws.stable_key + WHERE ws.repository_id = ? AND ws.file_id = ?`, repositoryID, fileID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []storedSymbolIdentity + for rows.Next() { + var identity storedSymbolIdentity + if err := rows.Scan(&identity.IdentityKey, &identity.StableKey, &identity.FilePath, &identity.Kind, &identity.Name, &identity.QualifiedName, &identity.StartLine, &identity.ContentHash); err != nil { + return nil, err + } + out = append(out, identity) + } + return out, rows.Err() +} + +func (s *Store) symbolIdentitiesForRepository(ctx context.Context, repositoryID int64) ([]storedSymbolIdentity, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT COALESCE(i.identity_key, ws.stable_key), ws.stable_key, f.path, ws.kind, ws.name, ws.qualified_name, ws.start_line, ws.content_hash + FROM watch_symbols ws + JOIN watch_files f ON f.id = ws.file_id + LEFT JOIN watch_symbol_identities i ON i.repository_id = ws.repository_id AND i.current_stable_key = ws.stable_key + WHERE ws.repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []storedSymbolIdentity + for rows.Next() { + var identity storedSymbolIdentity + if err := rows.Scan(&identity.IdentityKey, &identity.StableKey, &identity.FilePath, &identity.Kind, &identity.Name, &identity.QualifiedName, &identity.StartLine, &identity.ContentHash); err != nil { + return nil, err + } + out = append(out, identity) + } + return out, rows.Err() +} + +func (s *Store) replacementIdentityCandidates(ctx context.Context, repositoryID, fileID int64) ([]storedSymbolIdentity, error) { + currentFile, err := s.symbolIdentitiesForFile(ctx, repositoryID, fileID) + if err != nil { + return nil, err + } + repo, err := s.Repository(ctx, repositoryID) + if err != nil || strings.TrimSpace(repo.RepoRoot) == "" { + return currentFile, err + } + all, err := s.symbolIdentitiesForRepository(ctx, repositoryID) + if err != nil { + return nil, err + } + seen := map[string]struct{}{} + out := make([]storedSymbolIdentity, 0, len(currentFile)) + for _, identity := range currentFile { 
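+ // Identities already recorded for this file always remain re-matching candidates.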
+ seen[identity.IdentityKey] = struct{}{} + out = append(out, identity) + } + for _, identity := range all { + if _, ok := seen[identity.IdentityKey]; ok { + continue + } + if identity.FilePath == "" || !sourcePathMissing(repo.RepoRoot, identity.FilePath) { + continue + } + identity.MissingFile = true + out = append(out, identity) + seen[identity.IdentityKey] = struct{}{} + } + return out, nil +} + +func sourcePathMissing(repoRoot, relPath string) bool { + cleanRel := filepath.Clean(filepath.FromSlash(relPath)) + if filepath.IsAbs(cleanRel) || cleanRel == "." || cleanRel == ".." || strings.HasPrefix(cleanRel, ".."+string(filepath.Separator)) { + return false + } + _, err := os.Stat(filepath.Join(repoRoot, cleanRel)) + return errors.Is(err, os.ErrNotExist) +} + +func filePathFromStableKey(stableKey string) (string, bool) { + parts := strings.SplitN(stableKey, ":", 4) + if len(parts) < 4 || strings.TrimSpace(parts[1]) == "" { + return "", false + } + return filepathToSlash(parts[1]), true +} + +func (s *Store) materializedOwnerFilePath(ctx context.Context, repositoryID int64, ownerType, ownerKey string) (string, bool, error) { + switch ownerType { + case "file": + path := strings.TrimPrefix(ownerKey, "file:") + if strings.TrimSpace(path) == "" { + return "", false, nil + } + return filepathToSlash(path), true, nil + case "symbol": + var path string + err := s.db.QueryRowContext(ctx, ` + SELECT file_path + FROM watch_symbol_identities + WHERE repository_id = ? AND (identity_key = ? OR current_stable_key = ?) + ORDER BY updated_at DESC + LIMIT 1`, repositoryID, ownerKey, ownerKey).Scan(&path) + if err == nil && strings.TrimSpace(path) != "" { + return filepathToSlash(path), true, nil + } + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return "", false, err + } + if path, ok := filePathFromStableKey(ownerKey); ok { + return path, true, nil + } + return "", false, nil + default: + return "", false, nil + } +} + +func (s *Store) materializedOwnerMissing(ctx context.Context, repositoryID int64, ownerType, ownerKey string) (bool, error) { + repo, err := s.Repository(ctx, repositoryID) + if err != nil { + return false, err + } + path, ok, err := s.materializedOwnerFilePath(ctx, repositoryID, ownerType, ownerKey) + if err != nil || !ok { + return false, err + } + return sourcePathMissing(repo.RepoRoot, path), nil +} + +func (s *Store) deletedMaterializedElementIDs(ctx context.Context, repositoryID int64) (map[int64]struct{}, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT resource_id, owner_type, owner_key + FROM watch_materialization + WHERE repository_id = ? 
AND resource_type = 'element' AND owner_type IN ('file', 'symbol')`, repositoryID) + if err != nil { + return nil, err + } + type elementOwner struct { + id int64 + ownerType string + ownerKey string + } + var owners []elementOwner + for rows.Next() { + var id int64 + var ownerType, ownerKey string + if err := rows.Scan(&id, &ownerType, &ownerKey); err != nil { + _ = rows.Close() + return nil, err + } + owners = append(owners, elementOwner{id: id, ownerType: ownerType, ownerKey: ownerKey}) + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return nil, err + } + if err := rows.Close(); err != nil { + return nil, err + } + out := map[int64]struct{}{} + for _, owner := range owners { + missing, err := s.materializedOwnerMissing(ctx, repositoryID, owner.ownerType, owner.ownerKey) + if err != nil { + return nil, err + } + if missing { + out[owner.id] = struct{}{} + } + } + return out, nil +} + +func (s *Store) connectorTouchesElements(ctx context.Context, connectorID int64, elementIDs map[int64]struct{}) (bool, error) { + if len(elementIDs) == 0 { + return false, nil + } + var sourceID, targetID int64 + err := s.db.QueryRowContext(ctx, `SELECT source_element_id, target_element_id FROM connectors WHERE id = ?`, connectorID).Scan(&sourceID, &targetID) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + _, sourceDeleted := elementIDs[sourceID] + _, targetDeleted := elementIDs[targetID] + return sourceDeleted || targetDeleted, nil +} + +func (s *Store) materializationMappingTombstoned(ctx context.Context, repositoryID int64, mapping watchMaterializationMapping, deletedElementIDs map[int64]struct{}) (bool, error) { + switch mapping.ResourceType { + case "element", "view": + return s.materializedOwnerMissing(ctx, repositoryID, mapping.OwnerType, mapping.OwnerKey) + case "connector": + return s.connectorTouchesElements(ctx, mapping.ResourceID, deletedElementIDs) + default: + return false, nil + } +} + +func (s *Store) matchSymbolIdentity(sym Symbol, existing []storedSymbolIdentity, used map[string]struct{}) string { + for _, identity := range existing { + if identity.StableKey == sym.StableKey { + return identity.IdentityKey + } + } + bestScore := 0.0 + bestKey := "" + for _, identity := range existing { + if _, ok := used[identity.IdentityKey]; ok { + continue + } + if identity.FilePath != sym.FilePath || identity.Kind != sym.Kind { + continue + } + lineDelta := absInt(identity.StartLine - sym.StartLine) + if lineDelta > 3 { + continue + } + score := 0.35 + if lineDelta == 0 { + score += 0.35 + } else { + score += 0.2 + } + if identity.ContentHash == sym.ContentHash { + score += 0.2 + } + if sameQualifierParent(identity.QualifiedName, sym.QualifiedName) { + score += 0.1 + } + if score > bestScore { + bestScore = score + bestKey = identity.IdentityKey + } + } + for _, identity := range existing { + if _, ok := used[identity.IdentityKey]; ok { + continue + } + if !identity.MissingFile || identity.Kind != sym.Kind || identity.ContentHash == "" || identity.ContentHash != sym.ContentHash { + continue + } + score := 0.80 + if sameQualifierParent(identity.QualifiedName, sym.QualifiedName) { + score += 0.10 + } + if nameTokenSimilarity(identity.QualifiedName, sym.QualifiedName) >= 0.50 { + score += 0.05 + } + lineDelta := absInt(identity.StartLine - sym.StartLine) + if lineDelta <= 5 { + score += 0.05 + } + if score > bestScore { + bestScore = score + bestKey = identity.IdentityKey + } + } + if bestScore >= 0.70 && bestKey != "" { + return 
bestKey + } + return sym.StableKey +} + +func (s *Store) UpsertSymbolIdentity(ctx context.Context, repositoryID int64, identityKey string, sym Symbol) error { + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_symbol_identities(repository_id, identity_key, current_stable_key, file_path, kind, name, qualified_name, start_line, content_hash, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repository_id, identity_key) DO UPDATE SET + current_stable_key = excluded.current_stable_key, + file_path = excluded.file_path, + kind = excluded.kind, + name = excluded.name, + qualified_name = excluded.qualified_name, + start_line = excluded.start_line, + content_hash = excluded.content_hash, + updated_at = excluded.updated_at`, + repositoryID, identityKey, sym.StableKey, sym.FilePath, sym.Kind, sym.Name, sym.QualifiedName, sym.StartLine, sym.ContentHash, now, now) + return err +} + +func (s *Store) SymbolIdentityKeys(ctx context.Context, repositoryID int64) (map[string]string, error) { + rows, err := s.db.QueryContext(ctx, `SELECT current_stable_key, identity_key FROM watch_symbol_identities WHERE repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := map[string]string{} + for rows.Next() { + var stableKey, identityKey string + if err := rows.Scan(&stableKey, &identityKey); err != nil { + return nil, err + } + out[stableKey] = identityKey + } + return out, rows.Err() +} + +func (s *Store) ReplaceReferencesForFiles(ctx context.Context, repositoryID int64, fileIDs []int64, refs []Reference) error { + for _, fileID := range fileIDs { + if _, err := s.db.ExecContext(ctx, `DELETE FROM watch_references WHERE repository_id = ? AND source_file_id = ?`, repositoryID, fileID); err != nil { + return err + } + } + for _, ref := range refs { + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_references(repository_id, source_symbol_id, target_symbol_id, source_file_id, kind, line, column, evidence_hash, raw_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repository_id, source_symbol_id, target_symbol_id, kind, evidence_hash) DO UPDATE SET + source_file_id = excluded.source_file_id, + line = excluded.line, + column = excluded.column, + raw_json = excluded.raw_json, + updated_at = excluded.updated_at`, + repositoryID, ref.SourceSymbolID, ref.TargetSymbolID, ref.SourceFileID, ref.Kind, ref.Line, ref.Column, ref.EvidenceHash, ref.RawJSON, now, now) + if err != nil { + return err + } + } + return nil +} + +func (s *Store) ReplaceFactsForFile(ctx context.Context, repositoryID, fileID int64, facts []Fact) error { + if _, err := s.db.ExecContext(ctx, `DELETE FROM watch_facts WHERE repository_id = ? 
AND file_id = ?`, repositoryID, fileID); err != nil { + return err + } + for _, fact := range facts { + now := nowString() + tags, _ := json.Marshal(fact.Tags) + if fact.AttributesJSON == "" { + fact.AttributesJSON = "{}" + } + if fact.VisibilityHintsJSON == "" { + fact.VisibilityHintsJSON = "{}" + } + if fact.RawJSON == "" { + fact.RawJSON = "{}" + } + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_facts(repository_id, file_id, stable_key, type, enricher, subject_kind, subject_stable_key, object_kind, object_stable_key, object_file_path, object_name, relationship, file_path, start_line, end_line, confidence, name, tags, attributes_json, visibility_hints_json, fact_hash, raw_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repository_id, enricher, stable_key) DO UPDATE SET + file_id = excluded.file_id, + type = excluded.type, + subject_kind = excluded.subject_kind, + subject_stable_key = excluded.subject_stable_key, + object_kind = excluded.object_kind, + object_stable_key = excluded.object_stable_key, + object_file_path = excluded.object_file_path, + object_name = excluded.object_name, + relationship = excluded.relationship, + file_path = excluded.file_path, + start_line = excluded.start_line, + end_line = excluded.end_line, + confidence = excluded.confidence, + name = excluded.name, + tags = excluded.tags, + attributes_json = excluded.attributes_json, + visibility_hints_json = excluded.visibility_hints_json, + fact_hash = excluded.fact_hash, + raw_json = excluded.raw_json, + updated_at = excluded.updated_at`, + repositoryID, fileID, fact.StableKey, fact.Type, fact.Enricher, fact.SubjectKind, fact.SubjectStableKey, fact.ObjectKind, fact.ObjectStableKey, fact.ObjectFilePath, fact.ObjectName, fact.Relationship, fact.FilePath, fact.StartLine, fact.EndLine, fact.Confidence, fact.Name, string(tags), fact.AttributesJSON, fact.VisibilityHintsJSON, fact.FactHash, fact.RawJSON, now, now) + if err != nil { + return err + } + } + return nil +} + +func (s *Store) FactVersionForFile(ctx context.Context, repositoryID, fileID int64, enricher, stableKey string) (string, error) { + var version string + err := s.db.QueryRowContext(ctx, ` + SELECT name + FROM watch_facts + WHERE repository_id = ? AND file_id = ? AND enricher = ? AND stable_key = ? + LIMIT 1`, repositoryID, fileID, enricher, stableKey).Scan(&version) + if errors.Is(err, sql.ErrNoRows) { + return "", nil + } + if err != nil { + return "", err + } + return version, nil +} + +func (s *Store) FactsForRepository(ctx context.Context, repositoryID int64) ([]Fact, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, repository_id, file_id, file_path, stable_key, type, enricher, subject_kind, subject_stable_key, object_kind, object_stable_key, object_file_path, object_name, relationship, start_line, end_line, confidence, name, tags, attributes_json, visibility_hints_json, fact_hash, raw_json, created_at, updated_at + FROM watch_facts + WHERE repository_id = ? 
+ ORDER BY file_path, type, enricher, stable_key`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var facts []Fact + for rows.Next() { + fact, err := scanFact(rows) + if err != nil { + return nil, err + } + facts = append(facts, fact) + } + return facts, rows.Err() +} + +type factScanner interface { + Scan(dest ...any) error +} + +func scanFact(row factScanner) (Fact, error) { + var fact Fact + var endLine sql.NullInt64 + var rawTags string + if err := row.Scan(&fact.ID, &fact.RepositoryID, &fact.FileID, &fact.FilePath, &fact.StableKey, &fact.Type, &fact.Enricher, &fact.SubjectKind, &fact.SubjectStableKey, &fact.ObjectKind, &fact.ObjectStableKey, &fact.ObjectFilePath, &fact.ObjectName, &fact.Relationship, &fact.StartLine, &endLine, &fact.Confidence, &fact.Name, &rawTags, &fact.AttributesJSON, &fact.VisibilityHintsJSON, &fact.FactHash, &fact.RawJSON, &fact.CreatedAt, &fact.UpdatedAt); err != nil { + return Fact{}, err + } + if endLine.Valid { + value := int(endLine.Int64) + fact.EndLine = &value + } + _ = json.Unmarshal([]byte(rawTags), &fact.Tags) + return fact, nil +} + +func (s *Store) SymbolsForRepository(ctx context.Context, repositoryID int64) ([]Symbol, error) { + return s.QuerySymbols(ctx, repositoryID, SymbolQuery{Limit: -1}) +} + +type SymbolQuery struct { + Search string + File string + Kind string + Limit int + Offset int +} + +func (s *Store) QuerySymbols(ctx context.Context, repositoryID int64, q SymbolQuery) ([]Symbol, error) { + query := ` + SELECT s.id, s.repository_id, s.file_id, f.path, s.stable_key, s.name, s.qualified_name, s.kind, s.start_line, s.end_line, s.signature_hash, s.content_hash, s.raw_json, s.created_at, s.updated_at + FROM watch_symbols s + JOIN watch_files f ON f.id = s.file_id + WHERE s.repository_id = ?` + args := []any{repositoryID} + if q.Search != "" { + query += ` AND (s.name LIKE ? OR s.qualified_name LIKE ?)` + needle := "%" + q.Search + "%" + args = append(args, needle, needle) + } + if q.File != "" { + query += ` AND f.path = ?` + args = append(args, q.File) + } + if q.Kind != "" { + query += ` AND s.kind = ?` + args = append(args, q.Kind) + } + query += ` ORDER BY f.path, s.start_line, s.name` + if q.Limit >= 0 { + if q.Limit == 0 { + q.Limit = 100 + } + if q.Limit > 0 { + query += ` LIMIT ? OFFSET ?` + args = append(args, q.Limit, q.Offset) + } + } + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []Symbol + for rows.Next() { + var sym Symbol + var endLine sql.NullInt64 + if err := rows.Scan(&sym.ID, &sym.RepositoryID, &sym.FileID, &sym.FilePath, &sym.StableKey, &sym.Name, &sym.QualifiedName, &sym.Kind, &sym.StartLine, &endLine, &sym.SignatureHash, &sym.ContentHash, &sym.RawJSON, &sym.CreatedAt, &sym.UpdatedAt); err != nil { + return nil, err + } + if endLine.Valid { + value := int(endLine.Int64) + sym.EndLine = &value + } + out = append(out, sym) + } + return out, rows.Err() +} + +type ReferenceQuery struct { + SymbolID int64 + Limit int + Offset int +} + +func (s *Store) QueryReferences(ctx context.Context, repositoryID int64, q ReferenceQuery) ([]Reference, error) { + query := ` + SELECT id, repository_id, source_symbol_id, target_symbol_id, source_file_id, kind, line, column, evidence_hash, raw_json, created_at, updated_at + FROM watch_references + WHERE repository_id = ?` + args := []any{repositoryID} + if q.SymbolID > 0 { + query += ` AND (source_symbol_id = ? 
OR target_symbol_id = ?)` + args = append(args, q.SymbolID, q.SymbolID) + } + query += ` ORDER BY source_file_id, line, column` + if q.Limit == 0 { + q.Limit = 100 + } + query += ` LIMIT ? OFFSET ?` + args = append(args, q.Limit, q.Offset) + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []Reference + for rows.Next() { + var ref Reference + if err := rows.Scan(&ref.ID, &ref.RepositoryID, &ref.SourceSymbolID, &ref.TargetSymbolID, &ref.SourceFileID, &ref.Kind, &ref.Line, &ref.Column, &ref.EvidenceHash, &ref.RawJSON, &ref.CreatedAt, &ref.UpdatedAt); err != nil { + return nil, err + } + out = append(out, ref) + } + return out, rows.Err() +} + +func (s *Store) Summary(ctx context.Context, repositoryID int64) (Summary, error) { + summary := Summary{RepositoryID: repositoryID} + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM watch_files WHERE repository_id = ?`, repositoryID).Scan(&summary.Files); err != nil { + return Summary{}, err + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM watch_symbols WHERE repository_id = ?`, repositoryID).Scan(&summary.Symbols); err != nil { + return Summary{}, err + } + if err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM watch_references WHERE repository_id = ?`, repositoryID).Scan(&summary.References); err != nil { + return Summary{}, err + } + var finished sql.NullString + err := s.db.QueryRowContext(ctx, ` + SELECT status, started_at, finished_at + FROM watch_scan_runs + WHERE repository_id = ? + ORDER BY id DESC + LIMIT 1`, repositoryID).Scan(&summary.LastScanStatus, &summary.LastScanStarted, &finished) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return Summary{}, err + } + if finished.Valid { + summary.LastScanFinished = finished.String + } + return summary, nil +} + +func (s *Store) EnsureEmbeddingModel(ctx context.Context, cfg EmbeddingConfig, configHash string) (int64, error) { + cfg = normalizeEmbeddingConfig(cfg) + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_embedding_models(provider, model, dimension, config_hash, created_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(provider, model, dimension, config_hash) DO NOTHING`, + cfg.Provider, cfg.Model, cfg.Dimension, configHash, now) + if err != nil { + return 0, err + } + var id int64 + err = s.db.QueryRowContext(ctx, ` + SELECT id FROM watch_embedding_models + WHERE provider = ? AND model = ? AND dimension = ? AND config_hash = ?`, + cfg.Provider, cfg.Model, cfg.Dimension, configHash).Scan(&id) + return id, err +} + +func (s *Store) Embedding(ctx context.Context, modelID int64, ownerType, ownerKey, inputHash string) ([]byte, bool, error) { + var vector []byte + err := s.db.QueryRowContext(ctx, ` + SELECT vector FROM watch_embeddings + WHERE model_id = ? AND owner_type = ? AND owner_key = ? AND input_hash = ?`, + modelID, ownerType, ownerKey, inputHash).Scan(&vector) + if errors.Is(err, sql.ErrNoRows) { + return nil, false, nil + } + return vector, err == nil, err +} + +func (s *Store) SaveEmbedding(ctx context.Context, modelID int64, ownerType, ownerKey, inputHash string, vectorData []byte) error { + if err := s.EnsureEmbeddingVectorSchema(ctx); err != nil { + return err + } + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_embeddings(model_id, owner_type, owner_key, input_hash, vector, created_at) + VALUES (?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(model_id, owner_type, owner_key, input_hash) DO NOTHING`, + modelID, ownerType, ownerKey, inputHash, vectorData, nowString()) + if err != nil { + return err + } + var embeddingID int64 + if err := s.db.QueryRowContext(ctx, ` + SELECT id FROM watch_embeddings + WHERE model_id = ? AND owner_type = ? AND owner_key = ? AND input_hash = ?`, + modelID, ownerType, ownerKey, inputHash).Scan(&embeddingID); err != nil { + return err + } + encoded, err := vector.EncodeEmbedding(bytesToVector(vectorData)) + if err != nil { + return err + } + _, err = s.db.ExecContext(ctx, ` + INSERT INTO _vec_watch_embedding_vec(dataset_id, id, content, meta, embedding) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(dataset_id, id) DO UPDATE SET + content = excluded.content, + meta = excluded.meta, + embedding = excluded.embedding`, + embeddingDataset(modelID), fmt.Sprintf("%d", embeddingID), ownerKey, ownerType, encoded) + return err +} + +func (s *Store) SimilarEmbeddings(ctx context.Context, modelID int64, query Vector, limit int) ([]int64, error) { + if limit <= 0 { + limit = 10 + } + if err := s.EnsureEmbeddingVectorSchema(ctx); err != nil { + return nil, err + } + return s.similarEmbeddingsFallback(ctx, modelID, query, limit) +} + +func (s *Store) similarEmbeddingsFallback(ctx context.Context, modelID int64, query Vector, limit int) ([]int64, error) { + rows, err := s.db.QueryContext(ctx, `SELECT id, vector FROM watch_embeddings WHERE model_id = ?`, modelID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + type scored struct { + ID int64 + Score float64 + } + var scoredRows []scored + for rows.Next() { + var id int64 + var data []byte + if err := rows.Scan(&id, &data); err != nil { + return nil, err + } + scoredRows = append(scoredRows, scored{ID: id, Score: CosineSimilarity(query, bytesToVector(data))}) + } + if err := rows.Err(); err != nil { + return nil, err + } + sort.Slice(scoredRows, func(i, j int) bool { return scoredRows[i].Score > scoredRows[j].Score }) + if len(scoredRows) > limit { + scoredRows = scoredRows[:limit] + } + out := make([]int64, 0, len(scoredRows)) + for _, row := range scoredRows { + out = append(out, row.ID) + } + return out, nil +} + +func (s *Store) EnsureEmbeddingVectorSchema(ctx context.Context) error { + if _, err := s.db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS _vec_watch_embedding_vec ( + dataset_id TEXT NOT NULL, + id TEXT NOT NULL, + content TEXT, + meta TEXT, + embedding BLOB, + PRIMARY KEY(dataset_id, id) + )`); err != nil { + return err + } + return nil +} + +func (s *Store) BeginFilterRun(ctx context.Context, repositoryID int64, settingsHash, rawGraphHash string) (int64, error) { + res, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_filter_runs(repository_id, settings_hash, raw_graph_hash, started_at, status) + VALUES (?, ?, ?, ?, 'running')`, repositoryID, settingsHash, rawGraphHash, nowString()) + if err != nil { + return 0, err + } + return res.LastInsertId() +} + +func (s *Store) SaveFilterDecision(ctx context.Context, filterRunID int64, ownerType string, ownerID int64, ownerKey string, decision, reason string, score *float64, tier int, signalsJSON string) error { + if signalsJSON == "" { + signalsJSON = "[]" + } + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_filter_decisions(filter_run_id, owner_type, owner_id, owner_key, decision, reason, score, tier, signals_json) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, filterRunID, ownerType, ownerID, ownerKey, decision, reason, score, tier, signalsJSON) + return err +} + 
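+// FinishFilterRun records the final status and the visible/hidden symbol and reference counts for a completed filter run.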
+func (s *Store) FinishFilterRun(ctx context.Context, id int64, status string, visibleSymbols, hiddenSymbols, visibleReferences, hiddenReferences int) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE watch_filter_runs + SET finished_at = ?, status = ?, visible_symbols = ?, hidden_symbols = ?, visible_references = ?, hidden_references = ? + WHERE id = ?`, + nowString(), status, visibleSymbols, hiddenSymbols, visibleReferences, hiddenReferences, id) + return err +} + +func (s *Store) UpsertCluster(ctx context.Context, repositoryID int64, stableKey string, parentClusterID *int64, name, kind, algorithm, settingsHash string, memberIDs []int64) (Cluster, error) { + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_clusters(repository_id, stable_key, parent_cluster_id, name, kind, algorithm, settings_hash, member_count, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repository_id, stable_key) DO UPDATE SET + parent_cluster_id = excluded.parent_cluster_id, + name = excluded.name, + kind = excluded.kind, + algorithm = excluded.algorithm, + settings_hash = excluded.settings_hash, + member_count = excluded.member_count, + updated_at = excluded.updated_at`, + repositoryID, stableKey, parentClusterID, name, kind, algorithm, settingsHash, len(memberIDs), now, now) + if err != nil { + return Cluster{}, err + } + cluster, err := s.clusterByStableKey(ctx, repositoryID, stableKey) + if err != nil { + return Cluster{}, err + } + if _, err := s.db.ExecContext(ctx, `DELETE FROM watch_cluster_members WHERE cluster_id = ?`, cluster.ID); err != nil { + return Cluster{}, err + } + for _, memberID := range memberIDs { + if _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_cluster_members(cluster_id, owner_type, owner_id) + VALUES (?, 'symbol', ?)`, cluster.ID, memberID); err != nil { + return Cluster{}, err + } + } + return cluster, nil +} + +func (s *Store) Clusters(ctx context.Context, repositoryID int64) ([]Cluster, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, repository_id, stable_key, parent_cluster_id, name, kind, algorithm, settings_hash, member_count, created_at, updated_at + FROM watch_clusters + WHERE repository_id = ? + ORDER BY stable_key`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []Cluster + for rows.Next() { + cluster, err := scanCluster(rows) + if err != nil { + return nil, err + } + out = append(out, cluster) + } + return out, rows.Err() +} + +func (s *Store) BeginRepresentationRun(ctx context.Context, repositoryID int64, rawGraphHash, settingsHash string, embeddingModelID *int64, representationHash string) (int64, error) { + res, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_representation_runs(repository_id, raw_graph_hash, filter_settings_hash, embedding_model_id, representation_hash, started_at, status) + VALUES (?, ?, ?, ?, ?, ?, 'running')`, + repositoryID, rawGraphHash, settingsHash, embeddingModelID, representationHash, nowString()) + if err != nil { + return 0, err + } + return res.LastInsertId() +} + +func (s *Store) FinishRepresentationRun(ctx context.Context, id int64, status string, result RepresentResult, runErr error) error { + var errText any + if runErr != nil { + errText = runErr.Error() + } + _, err := s.db.ExecContext(ctx, ` + UPDATE watch_representation_runs + SET finished_at = ?, status = ?, elements_created = ?, elements_updated = ?, connectors_created = ?, connectors_updated = ?, views_created = ?, error = ? 
+ WHERE id = ?`, + nowString(), status, result.ElementsCreated, result.ElementsUpdated, result.ConnectorsCreated, result.ConnectorsUpdated, result.ViewsCreated, errText, id) + return err +} + +func (s *Store) MappingResourceID(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string) (int64, bool, error) { + var id int64 + err := s.db.QueryRowContext(ctx, ` + SELECT resource_id FROM watch_materialization + WHERE repository_id = ? AND owner_type = ? AND owner_key = ? AND resource_type = ?`, + repositoryID, ownerType, ownerKey, resourceType).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return 0, false, nil + } + return id, err == nil, err +} + +type materializationState struct { + ResourceID int64 + LastWatchHash *string + Dirty bool + DirtyDetectedAt *string +} + +func (s *Store) MappingState(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string) (materializationState, bool, error) { + var state materializationState + var lastHash sql.NullString + var dirtyAt sql.NullString + var dirty int + err := s.db.QueryRowContext(ctx, ` + SELECT resource_id, last_watch_hash, dirty, dirty_detected_at FROM watch_materialization + WHERE repository_id = ? AND owner_type = ? AND owner_key = ? AND resource_type = ?`, + repositoryID, ownerType, ownerKey, resourceType).Scan(&state.ResourceID, &lastHash, &dirty, &dirtyAt) + if errors.Is(err, sql.ErrNoRows) { + return materializationState{}, false, nil + } + if err != nil { + return materializationState{}, false, err + } + if lastHash.Valid { + state.LastWatchHash = &lastHash.String + } + if dirtyAt.Valid { + state.DirtyDetectedAt = &dirtyAt.String + } + state.Dirty = dirty != 0 + return state, true, nil +} + +func (s *Store) SaveMapping(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string, resourceID int64) error { + return s.SaveMappingAt(ctx, repositoryID, ownerType, ownerKey, resourceType, resourceID, nowString()) +} + +func (s *Store) SaveMappingAt(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string, resourceID int64, updatedAt string) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_materialization(repository_id, owner_type, owner_key, resource_type, resource_id, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(repository_id, owner_type, owner_key, resource_type) DO UPDATE SET + resource_id = excluded.resource_id, + updated_at = excluded.updated_at`, + repositoryID, ownerType, ownerKey, resourceType, resourceID, updatedAt, updatedAt) + return err +} + +func (s *Store) SaveMappingHashAt(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string, resourceID int64, resourceHash string, updatedAt string) error { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_materialization(repository_id, owner_type, owner_key, resource_type, resource_id, last_watch_hash, dirty, dirty_detected_at, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, 0, NULL, ?, ?) 
+ ON CONFLICT(repository_id, owner_type, owner_key, resource_type) DO UPDATE SET + resource_id = excluded.resource_id, + last_watch_hash = excluded.last_watch_hash, + dirty = 0, + dirty_detected_at = NULL, + updated_at = excluded.updated_at`, + repositoryID, ownerType, ownerKey, resourceType, resourceID, resourceHash, updatedAt, updatedAt) + return err +} + +func (s *Store) MarkMappingDirty(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string, resourceID int64) error { + _, err := s.db.ExecContext(ctx, ` + UPDATE watch_materialization + SET dirty = 1, + dirty_detected_at = COALESCE(dirty_detected_at, ?), + updated_at = ? + WHERE repository_id = ? AND owner_type = ? AND owner_key = ? AND resource_type = ? AND resource_id = ?`, + nowString(), nowString(), repositoryID, ownerType, ownerKey, resourceType, resourceID) + return err +} + +func (s *Store) WatchResourceHash(ctx context.Context, resourceType string, resourceID int64) (string, bool, error) { + switch resourceType { + case "element": + return s.watchElementHash(ctx, resourceID) + case "connector": + return s.watchConnectorHash(ctx, resourceID) + case "view": + return s.watchViewHash(ctx, resourceID) + default: + return "", false, fmt.Errorf("unsupported watch resource type %q", resourceType) + } +} + +func (s *Store) watchElementHash(ctx context.Context, id int64) (string, bool, error) { + var kind, description, technology, repo, branch, filePath, language sql.NullString + var name, techLinks, tags string + err := s.db.QueryRowContext(ctx, ` + SELECT name, kind, description, technology, technology_connectors, tags, repo, branch, file_path, language + FROM elements WHERE id = ?`, id).Scan(&name, &kind, &description, &technology, &techLinks, &tags, &repo, &branch, &filePath, &language) + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + if err != nil { + return "", false, err + } + payload := map[string]any{ + "name": name, + "kind": nullableString(kind), + "description": nullableString(description), + "technology": nullableString(technology), + "technology_connectors": normalizedJSONValue(techLinks), + "tags": normalizedJSONValue(tags), + "repo": nullableString(repo), + "branch": nullableString(branch), + "file_path": nullableString(filePath), + "language": nullableString(language), + } + return hashCanonicalPayload(payload), true, nil +} + +func (s *Store) watchConnectorHash(ctx context.Context, id int64) (string, bool, error) { + var label, relationship sql.NullString + var viewID, sourceID, targetID int64 + var direction, style string + err := s.db.QueryRowContext(ctx, ` + SELECT view_id, source_element_id, target_element_id, label, relationship, direction, style + FROM connectors WHERE id = ?`, id).Scan(&viewID, &sourceID, &targetID, &label, &relationship, &direction, &style) + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + if err != nil { + return "", false, err + } + payload := map[string]any{ + "view_id": viewID, + "source_element_id": sourceID, + "target_element_id": targetID, + "label": nullableString(label), + "relationship": nullableString(relationship), + "direction": direction, + "style": style, + } + return hashCanonicalPayload(payload), true, nil +} + +func (s *Store) watchViewHash(ctx context.Context, id int64) (string, bool, error) { + var ownerID sql.NullInt64 + var levelLabel sql.NullString + var name string + err := s.db.QueryRowContext(ctx, ` + SELECT owner_element_id, name, level_label + FROM views WHERE id = ?`, id).Scan(&ownerID, &name, &levelLabel) + 
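+ // A missing row means the view no longer exists; report not-found rather than an error.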
if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + if err != nil { + return "", false, err + } + var owner any + if ownerID.Valid { + owner = ownerID.Int64 + } + payload := map[string]any{ + "owner_element_id": owner, + "name": name, + "level_label": nullableString(levelLabel), + } + return hashCanonicalPayload(payload), true, nil +} + +func hashCanonicalPayload(payload any) string { + data, _ := json.Marshal(payload) + return hashBytes(data) +} + +func nullableString(value sql.NullString) any { + if !value.Valid { + return nil + } + return value.String +} + +func normalizedJSONValue(raw string) any { + var value any + if err := json.Unmarshal([]byte(raw), &value); err != nil { + return raw + } + return value +} + +func (s *Store) Materialization(ctx context.Context, repositoryID int64) ([]MaterializationMapping, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, repository_id, owner_type, owner_key, resource_type, resource_id, last_watch_hash, dirty, dirty_detected_at, created_at, updated_at + FROM watch_materialization + WHERE repository_id = ? + ORDER BY owner_type, owner_key, resource_type`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []MaterializationMapping + for rows.Next() { + var item MaterializationMapping + var lastHash sql.NullString + var dirtyAt sql.NullString + var dirty int + if err := rows.Scan(&item.ID, &item.RepositoryID, &item.OwnerType, &item.OwnerKey, &item.ResourceType, &item.ResourceID, &lastHash, &dirty, &dirtyAt, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, err + } + if lastHash.Valid { + item.LastWatchHash = &lastHash.String + } + item.Dirty = dirty != 0 + if dirtyAt.Valid { + item.DirtyDetectedAt = &dirtyAt.String + } + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) ReplaceArchitectureBindings(ctx context.Context, repositoryID int64, bindings []ArchitectureBinding) error { + if err := s.ensureArchitectureLinksTable(ctx); err != nil { + return err + } + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + if _, err := tx.ExecContext(ctx, `DELETE FROM watch_architecture_links WHERE repository_id = ?`, repositoryID); err != nil { + return err + } + now := nowString() + for _, binding := range bindings { + evidence, _ := json.Marshal(binding.Evidence) + if _, err := tx.ExecContext(ctx, ` + INSERT INTO watch_architecture_links( + repository_id, component_key, target_repository_id, target_owner_type, target_owner_key, + target_resource_type, target_resource_id, role, confidence, evidence_json, created_at, updated_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + repositoryID, + binding.ComponentKey, + binding.TargetRepositoryID, + binding.TargetOwnerType, + binding.TargetOwnerKey, + binding.TargetResourceType, + binding.TargetResourceID, + binding.Role, + binding.Confidence, + string(evidence), + now, + now, + ); err != nil { + return err + } + } + return tx.Commit() +} + +func (s *Store) ArchitectureBindings(ctx context.Context, repositoryID int64) ([]ArchitectureBinding, error) { + if err := s.ensureArchitectureLinksTable(ctx); err != nil { + return nil, err + } + rows, err := s.db.QueryContext(ctx, ` + SELECT id, repository_id, component_key, target_repository_id, target_owner_type, target_owner_key, + target_resource_type, target_resource_id, role, confidence, evidence_json, created_at, updated_at + FROM watch_architecture_links + WHERE repository_id = ? 
+ ORDER BY component_key, role, confidence DESC, target_owner_type, target_owner_key`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []ArchitectureBinding + for rows.Next() { + var item ArchitectureBinding + var evidenceJSON string + if err := rows.Scan( + &item.ID, + &item.RepositoryID, + &item.ComponentKey, + &item.TargetRepositoryID, + &item.TargetOwnerType, + &item.TargetOwnerKey, + &item.TargetResourceType, + &item.TargetResourceID, + &item.Role, + &item.Confidence, + &evidenceJSON, + &item.CreatedAt, + &item.UpdatedAt, + ); err != nil { + return nil, err + } + _ = json.Unmarshal([]byte(evidenceJSON), &item.Evidence) + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) ArchitectureBindingTargets(ctx context.Context) ([]ArchitectureBindingTarget, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT wm.repository_id, wm.owner_type, wm.owner_key, wm.resource_type, wm.resource_id, + COALESCE(v.id, 0), e.name, COALESCE(e.kind, ''), COALESCE(e.file_path, ''), + COALESCE(e.language, ''), COALESCE(e.tags, '[]') + FROM watch_materialization wm + JOIN elements e ON e.id = wm.resource_id + LEFT JOIN views v ON v.owner_element_id = e.id + WHERE wm.resource_type = 'element' + AND wm.owner_type IN ('folder', 'file', 'symbol', 'cluster', 'fact', 'fact-summary') + ORDER BY wm.repository_id, wm.owner_type, wm.owner_key`) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []ArchitectureBindingTarget + for rows.Next() { + var item ArchitectureBindingTarget + var tagsJSON string + if err := rows.Scan( + &item.RepositoryID, + &item.OwnerType, + &item.OwnerKey, + &item.ResourceType, + &item.ResourceID, + &item.ViewID, + &item.Name, + &item.Kind, + &item.FilePath, + &item.Language, + &tagsJSON, + ); err != nil { + return nil, err + } + _ = json.Unmarshal([]byte(tagsJSON), &item.Tags) + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) ensureArchitectureLinksTable(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS watch_architecture_links ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + component_key TEXT NOT NULL, + target_repository_id INTEGER NOT NULL, + target_owner_type TEXT NOT NULL, + target_owner_key TEXT NOT NULL, + target_resource_type TEXT NOT NULL, + target_resource_id INTEGER NOT NULL, + role TEXT NOT NULL, + confidence REAL NOT NULL, + evidence_json TEXT NOT NULL DEFAULT '[]', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, component_key, target_repository_id, target_owner_type, target_owner_key, target_resource_type, role), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (target_repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_watch_architecture_links_repository_id + ON watch_architecture_links(repository_id); + CREATE INDEX IF NOT EXISTS idx_watch_architecture_links_target + ON watch_architecture_links(target_repository_id, target_owner_type, target_owner_key); + `) + return err +} + +type watchMaterializationMapping struct { + ID int64 + OwnerType string + OwnerKey string + ResourceType string + ResourceID int64 + LastWatchHash *string + Dirty bool + DirtyDetectedAt *string + UpdatedAt string +} + +func (s *Store) staleMaterializationMappings(ctx context.Context, repositoryID int64, runMarker string) ([]watchMaterializationMapping, error) { + 
rows, err := s.db.QueryContext(ctx, ` + SELECT id, owner_type, owner_key, resource_type, resource_id, last_watch_hash, dirty, dirty_detected_at, updated_at + FROM watch_materialization + WHERE repository_id = ? AND updated_at != ?`, repositoryID, runMarker) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []watchMaterializationMapping + for rows.Next() { + var item watchMaterializationMapping + var lastHash sql.NullString + var dirtyAt sql.NullString + var dirty int + if err := rows.Scan(&item.ID, &item.OwnerType, &item.OwnerKey, &item.ResourceType, &item.ResourceID, &lastHash, &dirty, &dirtyAt, &item.UpdatedAt); err != nil { + return nil, err + } + if lastHash.Valid { + item.LastWatchHash = &lastHash.String + } + item.Dirty = dirty != 0 + if dirtyAt.Valid { + item.DirtyDetectedAt = &dirtyAt.String + } + out = append(out, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + sortMaterializationMappingsForDelete(out) + return out, nil +} + +func (s *Store) allMaterializationMappings(ctx context.Context, repositoryID int64) ([]watchMaterializationMapping, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, owner_type, owner_key, resource_type, resource_id, last_watch_hash, dirty, dirty_detected_at, updated_at + FROM watch_materialization + WHERE repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []watchMaterializationMapping + for rows.Next() { + var item watchMaterializationMapping + var lastHash sql.NullString + var dirtyAt sql.NullString + var dirty int + if err := rows.Scan(&item.ID, &item.OwnerType, &item.OwnerKey, &item.ResourceType, &item.ResourceID, &lastHash, &dirty, &dirtyAt, &item.UpdatedAt); err != nil { + return nil, err + } + if lastHash.Valid { + item.LastWatchHash = &lastHash.String + } + item.Dirty = dirty != 0 + if dirtyAt.Valid { + item.DirtyDetectedAt = &dirtyAt.String + } + out = append(out, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + sortMaterializationMappingsForDelete(out) + return out, nil +} + +func sortMaterializationMappingsForDelete(items []watchMaterializationMapping) { + order := map[string]int{"connector": 0, "view": 1, "element": 2} + sort.Slice(items, func(i, j int) bool { + left, right := order[items[i].ResourceType], order[items[j].ResourceType] + if left == right { + return items[i].ID < items[j].ID + } + return left < right + }) +} + +func (s *Store) deleteMaterializationMapping(ctx context.Context, mapping watchMaterializationMapping) error { + var query string + switch mapping.ResourceType { + case "connector": + query = `DELETE FROM connectors WHERE id = ?` + case "view": + query = `DELETE FROM views WHERE id = ?` + case "element": + query = `DELETE FROM elements WHERE id = ?` + default: + query = "" + } + if query != "" { + if _, err := s.db.ExecContext(ctx, query, mapping.ResourceID); err != nil { + return err + } + } + _, err := s.db.ExecContext(ctx, `DELETE FROM watch_materialization WHERE id = ?`, mapping.ID) + return err +} + +func (s *Store) PruneStaleMaterializedResources(ctx context.Context, repositoryID int64, runMarker string) (int, error) { + if runMarker == "" { + return 0, nil + } + mappings, err := s.staleMaterializationMappings(ctx, repositoryID, runMarker) + if err != nil { + return 0, err + } + deletedElementIDs, err := s.deletedMaterializedElementIDs(ctx, repositoryID) + if err != nil { + return 0, err + } + preserved := 0 + for _, mapping := range mappings { + 
tombstoned, err := s.materializationMappingTombstoned(ctx, repositoryID, mapping, deletedElementIDs) + if err != nil { + return preserved, err + } + if tombstoned { + continue + } + dirty, err := s.mappingResourceDirty(ctx, repositoryID, mapping) + if err != nil { + return preserved, err + } + if dirty { + preserved++ + continue + } + if err := s.deleteMaterializationMapping(ctx, mapping); err != nil { + return preserved, err + } + } + return preserved, nil +} + +func (s *Store) PruneDeletedMaterializedResources(ctx context.Context, repositoryID int64) error { + mappings, err := s.allMaterializationMappings(ctx, repositoryID) + if err != nil { + return err + } + deletedElementIDs, err := s.deletedMaterializedElementIDs(ctx, repositoryID) + if err != nil { + return err + } + for _, mapping := range mappings { + tombstoned, err := s.materializationMappingTombstoned(ctx, repositoryID, mapping, deletedElementIDs) + if err != nil { + return err + } + if !tombstoned { + continue + } + dirty, err := s.mappingResourceDirty(ctx, repositoryID, mapping) + if err != nil { + return err + } + if dirty { + continue + } + if err := s.deleteMaterializationMapping(ctx, mapping); err != nil { + return err + } + } + return nil +} + +func (s *Store) mappingResourceDirty(ctx context.Context, repositoryID int64, mapping watchMaterializationMapping) (bool, error) { + if mapping.Dirty { + return true, nil + } + if mapping.LastWatchHash == nil { + return false, nil + } + currentHash, exists, err := s.WatchResourceHash(ctx, mapping.ResourceType, mapping.ResourceID) + if err != nil { + return false, err + } + if !exists { + return false, nil + } + if currentHash == *mapping.LastWatchHash { + return false, nil + } + if err := s.MarkMappingDirty(ctx, repositoryID, mapping.OwnerType, mapping.OwnerKey, mapping.ResourceType, mapping.ResourceID); err != nil { + return false, err + } + return true, nil +} + +func (s *Store) RepositoryMaterializationCount(ctx context.Context, repositoryID int64) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM watch_materialization WHERE repository_id = ?`, repositoryID).Scan(&count) + return count, err +} + +type FilterDecisionQuery struct { + OwnerType string + Decision string + Limit int + Offset int +} + +func (s *Store) FilterDecisions(ctx context.Context, repositoryID int64, q FilterDecisionQuery) ([]FilterDecision, error) { + runID, err := s.latestFilterRunID(ctx, repositoryID) + if err != nil { + return nil, err + } + if runID == 0 { + return []FilterDecision{}, nil + } + query := ` + SELECT id, filter_run_id, owner_type, owner_id, owner_key, decision, reason, score, tier, signals_json + FROM watch_filter_decisions + WHERE filter_run_id = ?` + args := []any{runID} + if q.OwnerType != "" { + query += ` AND owner_type = ?` + args = append(args, q.OwnerType) + } + if q.Decision != "" { + query += ` AND decision = ?` + args = append(args, q.Decision) + } + query += ` ORDER BY id` + if q.Limit == 0 { + q.Limit = 100 + } + query += ` LIMIT ? OFFSET ?` + args = append(args, q.Limit, q.Offset) + rows, err := s.db.QueryContext(ctx, query, args...) 
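+	// The query above is scoped to the latest filter run, with optional owner_type/decision
+	// filters and LIMIT/OFFSET paging (default page size 100).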
+ if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []FilterDecision + for rows.Next() { + var item FilterDecision + var score sql.NullFloat64 + if err := rows.Scan(&item.ID, &item.FilterRunID, &item.OwnerType, &item.OwnerID, &item.OwnerKey, &item.Decision, &item.Reason, &score, &item.Tier, &item.SignalsJSON); err != nil { + return nil, err + } + if score.Valid { + item.Score = &score.Float64 + } + out = append(out, item) + } + return out, rows.Err() +} + +func (s *Store) RepresentationSummary(ctx context.Context, repositoryID int64) (RepresentationSummary, error) { + summary := RepresentationSummary{RepositoryID: repositoryID} + var finished sql.NullString + err := s.db.QueryRowContext(ctx, ` + SELECT raw_graph_hash, filter_settings_hash, representation_hash, status, started_at, finished_at, + elements_created, elements_updated, connectors_created, connectors_updated, views_created + FROM watch_representation_runs + WHERE repository_id = ? + ORDER BY id DESC + LIMIT 1`, repositoryID).Scan( + &summary.RawGraphHash, &summary.SettingsHash, &summary.RepresentationHash, &summary.LastStatus, &summary.LastStartedAt, &finished, + &summary.ElementsCreated, &summary.ElementsUpdated, &summary.ConnectorsCreated, &summary.ConnectorsUpdated, &summary.ViewsCreated, + ) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return RepresentationSummary{}, err + } + if finished.Valid { + summary.LastFinishedAt = &finished.String + } + var filterFinished sql.NullString + err = s.db.QueryRowContext(ctx, ` + SELECT visible_symbols, hidden_symbols, visible_references, hidden_references, finished_at + FROM watch_filter_runs + WHERE repository_id = ? + ORDER BY id DESC + LIMIT 1`, repositoryID).Scan(&summary.VisibleSymbols, &summary.HiddenSymbols, &summary.VisibleReferences, &summary.HiddenReferences, &filterFinished) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return RepresentationSummary{}, err + } + return summary, nil +} + +func (s *Store) RawGraphHash(ctx context.Context, repositoryID int64) (string, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT stable_key, signature_hash, content_hash + FROM watch_symbols + WHERE repository_id = ? + ORDER BY stable_key`, repositoryID) + if err != nil { + return "", err + } + defer func() { _ = rows.Close() }() + h := sha256.New() + for rows.Next() { + var stableKey, signatureHash, contentHash string + if err := rows.Scan(&stableKey, &signatureHash, &contentHash); err != nil { + return "", err + } + _, _ = h.Write([]byte("s:" + stableKey + ":" + signatureHash + ":" + contentHash + "\n")) + } + if err := rows.Err(); err != nil { + return "", err + } + refRows, err := s.db.QueryContext(ctx, ` + SELECT source.stable_key, target.stable_key, r.kind, r.evidence_hash + FROM watch_references r + JOIN watch_symbols source ON source.id = r.source_symbol_id + JOIN watch_symbols target ON target.id = r.target_symbol_id + WHERE r.repository_id = ? 
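+	-- deterministic ordering keeps the raw graph hash stable across runs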
+ ORDER BY source.stable_key, target.stable_key, r.kind, r.evidence_hash`, repositoryID) + if err != nil { + return "", err + } + defer func() { _ = refRows.Close() }() + for refRows.Next() { + var sourceKey, targetKey string + var kind, evidenceHash string + if err := refRows.Scan(&sourceKey, &targetKey, &kind, &evidenceHash); err != nil { + return "", err + } + _, _ = fmt.Fprintf(h, "r:%s:%s:%s:%s\n", sourceKey, targetKey, kind, evidenceHash) + } + if err := refRows.Err(); err != nil { + return "", err + } + factRows, err := s.db.QueryContext(ctx, ` + SELECT enricher, stable_key, type, fact_hash + FROM watch_facts + WHERE repository_id = ? + ORDER BY enricher, stable_key, type`, repositoryID) + if err != nil { + return "", err + } + defer func() { _ = factRows.Close() }() + for factRows.Next() { + var enricher, stableKey, factType, factHash string + if err := factRows.Scan(&enricher, &stableKey, &factType, &factHash); err != nil { + return "", err + } + _, _ = fmt.Fprintf(h, "f:%s:%s:%s:%s\n", enricher, stableKey, factType, factHash) + } + if err := factRows.Err(); err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +func (s *Store) AcquireLock(ctx context.Context, repositoryID int64, pid int, token string, staleAfter time.Duration) (Lock, error) { + if staleAfter <= 0 { + staleAfter = LockHeartbeatTimeout + } + now := nowString() + cutoff := time.Now().UTC().Add(-staleAfter).Format(time.RFC3339) + if err := s.markStaleLocks(ctx, cutoff); err != nil { + return Lock{}, err + } + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_locks(id, repository_id, pid, token, started_at, heartbeat_at, status) + VALUES (1, ?, ?, ?, ?, ?, 'active') + ON CONFLICT(id) DO UPDATE SET + repository_id = excluded.repository_id, + pid = excluded.pid, + token = excluded.token, + started_at = excluded.started_at, + heartbeat_at = excluded.heartbeat_at, + status = 'active' + WHERE watch_locks.status NOT IN ('active', 'paused', 'stopping') OR watch_locks.heartbeat_at < ?`, + repositoryID, pid, token, now, now, cutoff) + if err != nil { + return Lock{}, err + } + lock, err := s.ActiveLock(ctx) + if err != nil { + return Lock{}, err + } + if lock.RepositoryID != repositoryID || lock.Token != token { + return Lock{}, fmt.Errorf("repository is already watched by pid %d", lock.PID) + } + return lock, nil +} + +func (s *Store) markStaleLocks(ctx context.Context, cutoff string) error { + if _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'stale' WHERE status IN ('active', 'paused', 'stopping') AND heartbeat_at < ?`, cutoff); err != nil { + return err + } + rows, err := s.db.QueryContext(ctx, ` + SELECT id, pid + FROM watch_locks + WHERE status IN ('active', 'paused', 'stopping')`) + if err != nil { + return err + } + var staleIDs []int64 + for rows.Next() { + var id int64 + var pid int + if err := rows.Scan(&id, &pid); err != nil { + return err + } + if !watchProcessIsRunning(pid) { + staleIDs = append(staleIDs, id) + } + } + if err := rows.Err(); err != nil { + _ = rows.Close() + return err + } + if err := rows.Close(); err != nil { + return err + } + for _, id := range staleIDs { + if _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'stale' WHERE id = ? 
AND status IN ('active', 'paused', 'stopping')`, id); err != nil { + return err + } + } + return nil +} + +func (s *Store) ActiveLock(ctx context.Context) (Lock, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, pid, token, started_at, heartbeat_at, status + FROM watch_locks + WHERE status IN ('active', 'paused', 'stopping') + ORDER BY id + LIMIT 1`) + return scanLock(row) +} + +func (s *Store) lockByRepositoryToken(ctx context.Context, repositoryID int64, token string) (Lock, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, pid, token, started_at, heartbeat_at, status + FROM watch_locks + WHERE repository_id = ? AND token = ? + LIMIT 1`, repositoryID, token) + return scanLock(row) +} + +func (s *Store) ActiveLiveLock(ctx context.Context, staleAfter time.Duration) (Lock, bool, error) { + if staleAfter <= 0 { + staleAfter = LockHeartbeatTimeout + } + lock, err := s.ActiveLock(ctx) + if errors.Is(err, sql.ErrNoRows) { + return Lock{}, false, nil + } + if err != nil { + return Lock{}, false, err + } + heartbeat, err := time.Parse(time.RFC3339, lock.HeartbeatAt) + if err != nil || time.Since(heartbeat) > staleAfter || !watchProcessIsRunning(lock.PID) || lock.Status == "stale" || lock.Status == "released" { + _, _ = s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'stale' WHERE id = ? AND status IN ('active', 'paused', 'stopping')`, lock.ID) + return lock, false, nil + } + return lock, true, nil +} + +func (s *Store) HeartbeatLock(ctx context.Context, repositoryID int64, token string) (Lock, error) { + res, err := s.db.ExecContext(ctx, ` + UPDATE watch_locks + SET heartbeat_at = ? + WHERE repository_id = ? AND token = ? AND status IN ('active', 'paused')`, + nowString(), repositoryID, token) + if err != nil { + return Lock{}, err + } + if rows, err := res.RowsAffected(); err == nil && rows == 0 { + return Lock{}, sql.ErrNoRows + } + return s.lockByRepositoryToken(ctx, repositoryID, token) +} + +func (s *Store) RequestStop(ctx context.Context, repositoryID int64) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'stopping', heartbeat_at = ? WHERE repository_id = ? AND status IN ('active', 'paused')`, nowString(), repositoryID) + return err +} + +func (s *Store) RequestStopActive(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'stopping', heartbeat_at = ? WHERE status IN ('active', 'paused')`, nowString()) + return err +} + +func (s *Store) RequestPause(ctx context.Context, repositoryID int64) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'paused', heartbeat_at = ? WHERE repository_id = ? AND status = 'active'`, nowString(), repositoryID) + return err +} + +func (s *Store) RequestPauseActive(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'paused', heartbeat_at = ? WHERE status = 'active'`, nowString()) + return err +} + +func (s *Store) RequestResume(ctx context.Context, repositoryID int64) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'active', heartbeat_at = ? WHERE repository_id = ? AND status = 'paused'`, nowString(), repositoryID) + return err +} + +func (s *Store) RequestResumeActive(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'active', heartbeat_at = ? 
WHERE status = 'paused'`, nowString()) + return err +} + +func (s *Store) LockStatus(ctx context.Context, repositoryID int64, token string) (string, error) { + var status string + err := s.db.QueryRowContext(ctx, `SELECT status FROM watch_locks WHERE repository_id = ? AND token = ?`, repositoryID, token).Scan(&status) + return status, err +} + +func (s *Store) ReleaseLock(ctx context.Context, repositoryID int64, token string) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_locks SET status = 'released', heartbeat_at = ? WHERE repository_id = ? AND token = ?`, nowString(), repositoryID, token) + return err +} + +func (s *Store) AcquireApplyLock(ctx context.Context, repositoryID int64, pid int, token string, staleAfter time.Duration) error { + if staleAfter <= 0 { + staleAfter = LockHeartbeatTimeout + } + now := nowString() + cutoff := time.Now().UTC().Add(-staleAfter).Format(time.RFC3339) + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_apply_locks(id, repository_id, pid, token, started_at, heartbeat_at, status) + VALUES (1, ?, ?, ?, ?, ?, 'active') + ON CONFLICT(id) DO UPDATE SET + repository_id = excluded.repository_id, + pid = excluded.pid, + token = excluded.token, + started_at = excluded.started_at, + heartbeat_at = excluded.heartbeat_at, + status = 'active' + WHERE watch_apply_locks.status != 'active' OR watch_apply_locks.heartbeat_at < ?`, + repositoryID, pid, token, now, now, cutoff) + if err != nil { + return err + } + live, err := s.ActiveApplyLock(ctx, staleAfter) + if err != nil || !live { + return err + } + var got string + err = s.db.QueryRowContext(ctx, `SELECT token FROM watch_apply_locks WHERE id = 1 AND status = 'active'`).Scan(&got) + if err != nil { + return err + } + if got != token { + return fmt.Errorf("watch apply is already active") + } + return nil +} + +func (s *Store) ActiveApplyLock(ctx context.Context, staleAfter time.Duration) (bool, error) { + if staleAfter <= 0 { + staleAfter = LockHeartbeatTimeout + } + var id int64 + var pid int + var heartbeatAt, status string + err := s.db.QueryRowContext(ctx, ` + SELECT id, pid, heartbeat_at, status + FROM watch_apply_locks + WHERE id = 1 AND status = 'active'`).Scan(&id, &pid, &heartbeatAt, &status) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + heartbeat, err := time.Parse(time.RFC3339, heartbeatAt) + if err != nil || time.Since(heartbeat) > staleAfter || !watchProcessIsRunning(pid) || status != "active" { + _, _ = s.db.ExecContext(ctx, `UPDATE watch_apply_locks SET status = 'stale' WHERE id = ? AND status = 'active'`, id) + return false, nil + } + return true, nil +} + +func (s *Store) ReleaseApplyLock(ctx context.Context, repositoryID int64, token string) error { + _, err := s.db.ExecContext(ctx, `UPDATE watch_apply_locks SET status = 'released', heartbeat_at = ? WHERE repository_id = ? 
AND token = ?`, nowString(), repositoryID, token) + return err +} + +func (s *Store) EnsureGitTags(ctx context.Context) error { + return tagcolors.Ensure(ctx, s.db, managedGitTags()) +} + +func (s *Store) ApplyGitTags(ctx context.Context, repositoryID int64, status GitStatus) (GitTagUpdateResult, error) { + if err := s.EnsureGitTags(ctx); err != nil { + return GitTagUpdateResult{}, err + } + files := map[string][]string{} + addTags := func(paths []string, tag string) { + for _, p := range paths { + files[filepathToSlash(p)] = append(files[filepathToSlash(p)], tag) + } + } + addTags(status.Staged, "git:staged") + addTags(status.Unstaged, "git:unstaged") + addTags(status.Untracked, "git:untracked") + addTags(status.Deleted, "watch:deleted") + rows, err := s.db.QueryContext(ctx, ` + SELECT resource_id, owner_type, owner_key + FROM watch_materialization + WHERE repository_id = ? AND resource_type = 'element' AND owner_type IN ('file', 'symbol')`, repositoryID) + if err != nil { + return GitTagUpdateResult{}, err + } + defer func() { _ = rows.Close() }() + type update struct { + id int64 + tags []string + } + var updates []update + var allElementIDs []int64 + type elementOwner struct { + id int64 + ownerType string + ownerKey string + } + var owners []elementOwner + for rows.Next() { + var id int64 + var ownerType, ownerKey string + if err := rows.Scan(&id, &ownerType, &ownerKey); err != nil { + return GitTagUpdateResult{}, err + } + allElementIDs = append(allElementIDs, id) + owners = append(owners, elementOwner{id: id, ownerType: ownerType, ownerKey: ownerKey}) + } + if err := rows.Err(); err != nil { + return GitTagUpdateResult{}, err + } + if err := rows.Close(); err != nil { + return GitTagUpdateResult{}, err + } + for _, owner := range owners { + file, ok, err := s.materializedOwnerFilePath(ctx, repositoryID, owner.ownerType, owner.ownerKey) + if err != nil { + return GitTagUpdateResult{}, err + } + if !ok { + continue + } + if tags := files[file]; len(tags) > 0 { + updates = append(updates, update{id: owner.id, tags: tags}) + } + } + var result GitTagUpdateResult + for _, id := range allElementIDs { + removed, err := s.removeElementTags(ctx, id, managedGitTags()) + if err != nil { + return GitTagUpdateResult{}, err + } + result.TagsRemoved += removed + } + for _, item := range updates { + added, err := s.addElementTags(ctx, item.id, item.tags) + if err != nil { + return GitTagUpdateResult{}, err + } + result.TagsAdded += added + } + return result, nil +} + +func (s *Store) CreateWatchVersion(ctx context.Context, repositoryID int64, commitHash, commitMessage, parentCommitHash, branch, representationHash string, workspaceVersionID *int64, diffs []RepresentationDiff) (Version, error) { + now := nowString() + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_versions(repository_id, commit_hash, commit_message, parent_commit_hash, branch, representation_hash, workspace_version_id, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
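+	-- re-recording the same commit/representation pair is a no-op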
+ ON CONFLICT(repository_id, commit_hash, representation_hash) DO NOTHING`, + repositoryID, commitHash, nullString(commitMessage), nullString(parentCommitHash), nullString(branch), representationHash, workspaceVersionID, now) + if err != nil { + return Version{}, err + } + version, err := s.WatchVersion(ctx, repositoryID, commitHash, representationHash) + if err != nil { + return Version{}, err + } + for _, diff := range diffs { + _, err := s.db.ExecContext(ctx, ` + INSERT INTO watch_representation_diffs(version_id, owner_type, owner_key, change_type, before_hash, after_hash, resource_type, resource_id, summary, added_lines, removed_lines) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + version.ID, diff.OwnerType, diff.OwnerKey, diff.ChangeType, diff.BeforeHash, diff.AfterHash, diff.ResourceType, diff.ResourceID, diff.Summary, diff.AddedLines, diff.RemovedLines) + if err != nil { + return Version{}, err + } + } + if err := s.SaveWatchVersionResources(ctx, version.ID, repositoryID); err != nil { + return Version{}, err + } + if err := s.PruneWatchVersions(ctx, repositoryID, 5); err != nil { + return Version{}, err + } + return version, nil +} + +func (s *Store) PruneWatchVersions(ctx context.Context, repositoryID int64, keep int) error { + if keep <= 0 { + keep = 5 + } + _, err := s.db.ExecContext(ctx, ` + DELETE FROM watch_versions + WHERE repository_id = ? + AND id NOT IN ( + SELECT id + FROM watch_versions + WHERE repository_id = ? + ORDER BY id DESC + LIMIT ? + )`, repositoryID, repositoryID, keep) + return err +} + +func (s *Store) WatchVersion(ctx context.Context, repositoryID int64, commitHash, representationHash string) (Version, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, commit_hash, commit_message, parent_commit_hash, branch, representation_hash, workspace_version_id, created_at + FROM watch_versions + WHERE repository_id = ? AND commit_hash = ? AND representation_hash = ?`, repositoryID, commitHash, representationHash) + return scanVersion(row) +} + +func (s *Store) WatchVersions(ctx context.Context, repositoryID int64, limit int) ([]Version, error) { + if limit <= 0 { + limit = 100 + } + rows, err := s.db.QueryContext(ctx, ` + SELECT id, repository_id, commit_hash, commit_message, parent_commit_hash, branch, representation_hash, workspace_version_id, created_at + FROM watch_versions + WHERE repository_id = ? + ORDER BY id DESC + LIMIT ?`, repositoryID, limit) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []Version + for rows.Next() { + version, err := scanVersion(rows) + if err != nil { + return nil, err + } + out = append(out, version) + } + return out, rows.Err() +} + +func (s *Store) LatestWatchVersion(ctx context.Context, repositoryID int64) (Version, bool, error) { + row := s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, commit_hash, commit_message, parent_commit_hash, branch, representation_hash, workspace_version_id, created_at + FROM watch_versions + WHERE repository_id = ? 
+ ORDER BY id DESC + LIMIT 1`, repositoryID) + version, err := scanVersion(row) + if errors.Is(err, sql.ErrNoRows) { + return Version{}, false, nil + } + return version, err == nil, err +} + +func (s *Store) WorkspaceResourceCounts(ctx context.Context) (views, elements, connectors int, err error) { + for query, dest := range map[string]*int{ + `SELECT COUNT(*) FROM views`: &views, + `SELECT COUNT(*) FROM elements`: &elements, + `SELECT COUNT(*) FROM connectors`: &connectors, + } { + if scanErr := s.db.QueryRowContext(ctx, query).Scan(dest); scanErr != nil { + return 0, 0, 0, scanErr + } + } + return views, elements, connectors, nil +} + +func (s *Store) CreateWorkspaceVersion(ctx context.Context, versionID, source string, parentID *int64, viewCount, elementCount, connectorCount int, description, workspaceHash *string) (int64, error) { + res, err := s.db.ExecContext(ctx, ` + INSERT INTO workspace_versions(version_id, source, parent_version_id, view_count, element_count, connector_count, description, workspace_hash, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + versionID, source, parentID, viewCount, elementCount, connectorCount, description, workspaceHash, nowString()) + if err != nil { + return 0, err + } + return res.LastInsertId() +} + +func (s *Store) WatchDiffs(ctx context.Context, versionID int64, ownerType, changeType, resourceType, language string, limit int) ([]RepresentationDiff, error) { + if limit <= 0 { + limit = 200 + } + query := ` + SELECT id, version_id, owner_type, owner_key, change_type, before_hash, after_hash, resource_type, resource_id, summary, added_lines, removed_lines + FROM watch_representation_diffs + WHERE version_id = ?` + args := []any{versionID} + if ownerType != "" { + query += ` AND owner_type = ?` + args = append(args, ownerType) + } + if changeType != "" { + query += ` AND change_type = ?` + args = append(args, changeType) + } + if resourceType != "" { + query += ` AND resource_type = ?` + args = append(args, resourceType) + } + query += ` ORDER BY id LIMIT ?` + args = append(args, limit) + rows, err := s.db.QueryContext(ctx, query, args...) 
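+	// Note: the language filter is applied in Go while scanning rows (below), after the SQL
+	// LIMIT, so a language-restricted page may contain fewer than `limit` diffs.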
+ if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var out []RepresentationDiff + for rows.Next() { + var diff RepresentationDiff + var before, after, resourceType, summary sql.NullString + var resourceID sql.NullInt64 + if err := rows.Scan(&diff.ID, &diff.VersionID, &diff.OwnerType, &diff.OwnerKey, &diff.ChangeType, &before, &after, &resourceType, &resourceID, &summary, &diff.AddedLines, &diff.RemovedLines); err != nil { + return nil, err + } + diff.BeforeHash = nullStringPtr(before) + diff.AfterHash = nullStringPtr(after) + diff.ResourceType = nullStringPtr(resourceType) + if resourceID.Valid { + diff.ResourceID = &resourceID.Int64 + } + if lang := diffLanguage(diff); lang != "" { + diff.Language = &lang + } + diff.Summary = nullStringPtr(summary) + if language == "" || (diff.Language != nil && *diff.Language == language) { + out = append(out, diff) + } + } + return out, rows.Err() +} + +type watchResourceSnapshot struct { + OwnerType string + OwnerKey string + ResourceType string + ResourceID *int64 + Language string + Hash string + Summary string + LineCount int + FilePath string + StartLine int + EndLine int +} + +type changedRawResources struct { + Files map[string]struct{} + Symbols map[int64]string +} + +func (s *Store) ChangedRawResourcesSinceLatest(ctx context.Context, repositoryID int64) (changedRawResources, error) { + changed := changedRawResources{Files: map[string]struct{}{}, Symbols: map[int64]string{}} + latest, found, err := s.LatestWatchVersion(ctx, repositoryID) + if err != nil || !found { + return changed, err + } + previous, err := s.WatchVersionResourceSnapshots(ctx, latest.ID) + if err != nil { + return changed, err + } + current, err := s.CurrentWatchResourceSnapshots(ctx, repositoryID) + if err != nil { + return changed, err + } + for key, next := range current { + if next.OwnerType != next.ResourceType { + continue + } + if next.ResourceType != "file" && next.ResourceType != "symbol" { + continue + } + prev, ok := previous[key] + if ok && prev.Hash == next.Hash { + continue + } + switch next.ResourceType { + case "file": + changed.Files[next.OwnerKey] = struct{}{} + case "symbol": + if next.ResourceID != nil { + reason := "changed since latest watch version" + if !ok { + reason = "added since latest watch version" + } + changed.Symbols[*next.ResourceID] = reason + } + } + } + for key, prev := range previous { + if prev.OwnerType != prev.ResourceType { + continue + } + if prev.ResourceType != "file" { + continue + } + if _, ok := current[key]; !ok { + changed.Files[prev.OwnerKey] = struct{}{} + } + } + return changed, nil +} + +func (s *Store) FileLanguages(ctx context.Context, repositoryID int64) (map[string]string, error) { + rows, err := s.db.QueryContext(ctx, `SELECT path, language FROM watch_files WHERE repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + out := map[string]string{} + for rows.Next() { + var path, language string + if err := rows.Scan(&path, &language); err != nil { + return nil, err + } + out[path] = language + } + return out, rows.Err() +} + +func (s *Store) BuildWatchDiffs(ctx context.Context, repositoryID int64, representationHash string) ([]RepresentationDiff, error) { + current, err := s.CurrentWatchResourceSnapshots(ctx, repositoryID) + if err != nil { + return nil, err + } + latest, found, err := s.LatestWatchVersion(ctx, repositoryID) + if err != nil { + return nil, err + } + previous := map[string]watchResourceSnapshot{} + if found { + previous, 
err = s.WatchVersionResourceSnapshots(ctx, latest.ID) + if err != nil { + return nil, err + } + } + previousBaseline := cloneWatchResourceSnapshots(previous) + lineDiffs := s.gitLineDiffsAgainstHead(ctx, repositoryID) + lineHunks := s.gitLineHunksAgainstHead(ctx, repositoryID) + worktreeChanges := s.gitWorktreeChangesAgainstHead(ctx, repositoryID) + var diffs []RepresentationDiff + repoKey := fmt.Sprintf("%d", repositoryID) + repoSummary := "Representation initialized" + change := "initialized" + if found { + change = "updated" + repoSummary = "Representation updated" + } else if len(worktreeChanges) > 0 { + repoSummary = "Representation initialized from dirty worktree" + } + diffs = append(diffs, RepresentationDiff{OwnerType: "repository", OwnerKey: repoKey, ChangeType: change, BeforeHash: stringPtrIf(found, latest.RepresentationHash), AfterHash: &representationHash, Summary: &repoSummary}) + for key, next := range current { + prev, ok := previous[key] + if !ok { + if prevRaw, rawOK := previousRawSnapshotForMaterialized(previousBaseline, current, next); rawOK { + before, after := prevRaw.Hash, next.Hash + diff := snapshotDiff(next, "updated", &before, &after, &prevRaw) + applyGitLineDiff(&diff, next, &prevRaw, lineDiffs, lineHunks) + diffs = append(diffs, diff) + continue + } + changeType := "added" + if !found { + var emit bool + changeType, emit = shouldEmitInitialSnapshotDiff(next, worktreeChanges) + if !emit { + continue + } + } + diff := snapshotDiff(next, changeType, nil, &next.Hash, nil) + applyGitLineDiff(&diff, next, nil, lineDiffs, lineHunks) + diffs = append(diffs, diff) + continue + } + if prev.Hash != next.Hash || ptrInt64Value(prev.ResourceID) != ptrInt64Value(next.ResourceID) { + before, after := prev.Hash, next.Hash + diff := snapshotDiff(next, "updated", &before, &after, &prev) + applyGitLineDiff(&diff, next, &prev, lineDiffs, lineHunks) + diffs = append(diffs, diff) + } + delete(previous, key) + } + for _, prev := range previous { + before := prev.Hash + diff := snapshotDiff(prev, "deleted", &before, nil, nil) + applyGitLineDiff(&diff, prev, nil, lineDiffs, lineHunks) + diffs = append(diffs, diff) + } + sort.Slice(diffs, func(i, j int) bool { + if diffs[i].OwnerType == diffs[j].OwnerType { + return diffs[i].OwnerKey < diffs[j].OwnerKey + } + return diffs[i].OwnerType < diffs[j].OwnerType + }) + return diffs, nil +} + +func shouldEmitInitialSnapshotDiff(snapshot watchResourceSnapshot, changes map[string]tldgit.WorktreeChange) (string, bool) { + changeType := initialSnapshotChangeType(snapshot, changes) + if len(changes) > 0 && changeType == "initialized" { + return changeType, false + } + return changeType, true +} + +func initialSnapshotChangeType(snapshot watchResourceSnapshot, changes map[string]tldgit.WorktreeChange) string { + if len(changes) == 0 { + return "initialized" + } + paths := snapshotDiffFilePaths(snapshot) + if len(paths) == 0 { + return "initialized" + } + hasUpdated := false + for _, path := range paths { + switch changes[path] { + case tldgit.WorktreeAdded: + return "added" + case tldgit.WorktreeUpdated: + hasUpdated = true + } + } + if hasUpdated { + return "updated" + } + return "initialized" +} + +func cloneWatchResourceSnapshots(in map[string]watchResourceSnapshot) map[string]watchResourceSnapshot { + out := make(map[string]watchResourceSnapshot, len(in)) + maps.Copy(out, in) + return out +} + +func (s *Store) CurrentWatchResourceSnapshots(ctx context.Context, repositoryID int64) (map[string]watchResourceSnapshot, error) { + out := 
map[string]watchResourceSnapshot{} + fileRows, err := s.db.QueryContext(ctx, `SELECT id, path, language, worktree_hash FROM watch_files WHERE repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + for fileRows.Next() { + var id int64 + var path, language, hash string + if err := fileRows.Scan(&id, &path, &language, &hash); err != nil { + _ = fileRows.Close() + return nil, err + } + out[resourceSnapshotKey("file", path, "file")] = watchResourceSnapshot{OwnerType: "file", OwnerKey: path, ResourceType: "file", ResourceID: &id, Language: language, Hash: hash, Summary: path} + } + if err := fileRows.Close(); err != nil { + return nil, err + } + symRows, err := s.db.QueryContext(ctx, ` + SELECT s.id, COALESCE(i.identity_key, s.stable_key), s.stable_key, f.path, s.content_hash, s.signature_hash, s.qualified_name, s.start_line, s.end_line + FROM watch_symbols s + JOIN watch_files f ON f.id = s.file_id + LEFT JOIN watch_symbol_identities i ON i.repository_id = s.repository_id AND i.current_stable_key = s.stable_key + WHERE s.repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + for symRows.Next() { + var id int64 + var key, stableKey, filePath, contentHash, signatureHash, name string + var startLine int + var endLine sql.NullInt64 + if err := symRows.Scan(&id, &key, &stableKey, &filePath, &contentHash, &signatureHash, &name, &startLine, &endLine); err != nil { + _ = symRows.Close() + return nil, err + } + hash := hashString(contentHash + ":" + signatureHash) + end := normalizedEndLine(startLine, endLine) + out[resourceSnapshotKey("symbol", key, "symbol")] = watchResourceSnapshot{OwnerType: "symbol", OwnerKey: key, ResourceType: "symbol", ResourceID: &id, Language: languageFromStableKey(stableKey), Hash: hash, Summary: name, LineCount: lineCountFromRange(startLine, endLine), FilePath: filepathToSlash(filePath), StartLine: startLine, EndLine: end} + } + if err := symRows.Close(); err != nil { + return nil, err + } + deletedElementIDs, err := s.deletedMaterializedElementIDs(ctx, repositoryID) + if err != nil { + return nil, err + } + mapRows, err := s.db.QueryContext(ctx, ` + SELECT id, owner_type, owner_key, resource_type, resource_id, updated_at + FROM watch_materialization + WHERE repository_id = ?`, repositoryID) + if err != nil { + return nil, err + } + var mappings []watchMaterializationMapping + for mapRows.Next() { + var mapping watchMaterializationMapping + if err := mapRows.Scan(&mapping.ID, &mapping.OwnerType, &mapping.OwnerKey, &mapping.ResourceType, &mapping.ResourceID, &mapping.UpdatedAt); err != nil { + _ = mapRows.Close() + return nil, err + } + mappings = append(mappings, mapping) + } + if err := mapRows.Close(); err != nil { + return nil, err + } + for _, mapping := range mappings { + tombstoned, err := s.materializationMappingTombstoned(ctx, repositoryID, mapping, deletedElementIDs) + if err != nil { + return nil, err + } + if tombstoned { + continue + } + hash, summary, language, lineCount, err := s.materializedResourceHash(ctx, repositoryID, mapping.OwnerType, mapping.OwnerKey, mapping.ResourceType, mapping.ResourceID) + if err != nil { + continue + } + id := mapping.ResourceID + filePath, startLine, endLine := materializedSourceRange(ctx, s.db, repositoryID, mapping.OwnerType, mapping.OwnerKey, "") + out[resourceSnapshotKey(mapping.OwnerType, mapping.OwnerKey, mapping.ResourceType)] = watchResourceSnapshot{OwnerType: mapping.OwnerType, OwnerKey: mapping.OwnerKey, ResourceType: mapping.ResourceType, ResourceID: &id, Language: language, 
Hash: hash, Summary: summary, LineCount: lineCount, FilePath: filePath, StartLine: startLine, EndLine: endLine} + } + return out, nil +} + +func (s *Store) materializedResourceHash(ctx context.Context, repositoryID int64, ownerType, ownerKey, resourceType string, resourceID int64) (string, string, string, int, error) { + switch resourceType { + case "element": + var name, kind, description, repo, branch, filePath, language sql.NullString + err := s.db.QueryRowContext(ctx, `SELECT name, kind, description, repo, branch, file_path, language FROM elements WHERE id = ?`, resourceID).Scan(&name, &kind, &description, &repo, &branch, &filePath, &language) + if err != nil { + return "", "", "", 0, err + } + raw := strings.Join([]string{name.String, kind.String, description.String, repo.String, branch.String, filePath.String, language.String}, "\n") + if ownerType == "symbol" { + raw += "\n" + symbolSnapshotHash(ctx, s.db, repositoryID, ownerKey) + } + return hashString(raw), name.String, language.String, materializedLineCount(ctx, s.db, repositoryID, ownerType, ownerKey, filePath.String), nil + case "view": + var name, label sql.NullString + err := s.db.QueryRowContext(ctx, `SELECT name, level_label FROM views WHERE id = ?`, resourceID).Scan(&name, &label) + if err != nil { + return "", "", "", 0, err + } + return hashString(name.String + "\n" + label.String), name.String, "", 0, nil + case "connector": + var viewID, sourceID, targetID int64 + var label, relationship, direction sql.NullString + err := s.db.QueryRowContext(ctx, `SELECT view_id, source_element_id, target_element_id, label, relationship, direction FROM connectors WHERE id = ?`, resourceID).Scan(&viewID, &sourceID, &targetID, &label, &relationship, &direction) + if err != nil { + return "", "", "", 0, err + } + raw := fmt.Sprintf("%d:%d:%d:%s:%s:%s", viewID, sourceID, targetID, label.String, relationship.String, direction.String) + return hashString(raw), s.connectorSummary(ctx, sourceID, targetID, direction.String), "", 0, nil + default: + return "", "", "", 0, fmt.Errorf("unsupported resource type %q", resourceType) + } +} + +func (s *Store) connectorSummary(ctx context.Context, sourceID, targetID int64, direction string) string { + sourceName := elementName(ctx, s.db, sourceID) + targetName := elementName(ctx, s.db, targetID) + if sourceName == "" { + sourceName = fmt.Sprintf("element %d", sourceID) + } + if targetName == "" { + targetName = fmt.Sprintf("element %d", targetID) + } + switch strings.ToLower(strings.TrimSpace(direction)) { + case "both", "bidirectional": + return sourceName + "<->" + targetName + case "backward": + return targetName + "->" + sourceName + case "none": + return sourceName + "--" + targetName + default: + return sourceName + "->" + targetName + } +} + +func elementName(ctx context.Context, db *sql.DB, id int64) string { + var name sql.NullString + if err := db.QueryRowContext(ctx, `SELECT name FROM elements WHERE id = ?`, id).Scan(&name); err != nil || !name.Valid { + return "" + } + return strings.TrimSpace(name.String) +} + +func (s *Store) WatchVersionResourceSnapshots(ctx context.Context, versionID int64) (map[string]watchResourceSnapshot, error) { + if err := s.ensureWatchVersionResourceRangeColumns(ctx); err != nil { + return nil, err + } + rows, err := s.db.QueryContext(ctx, ` + SELECT owner_type, owner_key, resource_type, resource_id, language, resource_hash, summary, line_count, file_path, start_line, end_line + FROM watch_version_resources + WHERE version_id = ?`, versionID) + if err != nil { + 
return nil, err + } + defer func() { _ = rows.Close() }() + out := map[string]watchResourceSnapshot{} + for rows.Next() { + var item watchResourceSnapshot + var resourceID sql.NullInt64 + var language, summary, filePath sql.NullString + if err := rows.Scan(&item.OwnerType, &item.OwnerKey, &item.ResourceType, &resourceID, &language, &item.Hash, &summary, &item.LineCount, &filePath, &item.StartLine, &item.EndLine); err != nil { + return nil, err + } + if resourceID.Valid { + item.ResourceID = &resourceID.Int64 + } + item.Language = language.String + item.Summary = summary.String + item.FilePath = filepathToSlash(filePath.String) + out[resourceSnapshotKey(item.OwnerType, item.OwnerKey, item.ResourceType)] = item + } + return out, rows.Err() +} + +func (s *Store) SaveWatchVersionResources(ctx context.Context, versionID, repositoryID int64) error { + if err := s.ensureWatchVersionResourceRangeColumns(ctx); err != nil { + return err + } + snapshots, err := s.CurrentWatchResourceSnapshots(ctx, repositoryID) + if err != nil { + return err + } + for _, item := range snapshots { + _, err := s.db.ExecContext(ctx, ` + INSERT OR REPLACE INTO watch_version_resources(version_id, owner_type, owner_key, resource_type, resource_id, language, resource_hash, summary, line_count, file_path, start_line, end_line) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + versionID, item.OwnerType, item.OwnerKey, item.ResourceType, item.ResourceID, nullString(item.Language), item.Hash, nullString(item.Summary), item.LineCount, nullString(item.FilePath), item.StartLine, item.EndLine) + if err != nil { + return err + } + } + return nil +} + +func (s *Store) ensureWatchVersionResourceRangeColumns(ctx context.Context) error { + for _, stmt := range []string{ + `ALTER TABLE watch_version_resources ADD COLUMN file_path TEXT NULL`, + `ALTER TABLE watch_version_resources ADD COLUMN start_line INTEGER NOT NULL DEFAULT 0`, + `ALTER TABLE watch_version_resources ADD COLUMN end_line INTEGER NOT NULL DEFAULT 0`, + } { + if _, err := s.db.ExecContext(ctx, stmt); err != nil && !strings.Contains(err.Error(), "duplicate column name") { + return err + } + } + return nil +} + +func snapshotDiff(snapshot watchResourceSnapshot, changeType string, beforeHash, afterHash *string, previous *watchResourceSnapshot) RepresentationDiff { + resourceType := snapshot.ResourceType + summary := snapshot.Summary + language := snapshot.Language + addedLines, removedLines := lineDelta(changeType, snapshot.LineCount, previous) + return RepresentationDiff{OwnerType: snapshot.OwnerType, OwnerKey: snapshot.OwnerKey, ChangeType: changeType, BeforeHash: beforeHash, AfterHash: afterHash, ResourceType: &resourceType, ResourceID: snapshot.ResourceID, Language: &language, Summary: &summary, AddedLines: addedLines, RemovedLines: removedLines} +} + +func previousRawSnapshotForMaterialized(previous, current map[string]watchResourceSnapshot, snapshot watchResourceSnapshot) (watchResourceSnapshot, bool) { + if snapshot.ResourceType != "element" { + return watchResourceSnapshot{}, false + } + rawType := "" + rawOwnerKey := snapshot.OwnerKey + switch snapshot.OwnerType { + case "file": + rawType = "file" + rawOwnerKey = strings.TrimPrefix(snapshot.OwnerKey, "file:") + case "symbol": + rawType = "symbol" + default: + return watchResourceSnapshot{}, false + } + rawKey := resourceSnapshotKey(snapshot.OwnerType, rawOwnerKey, rawType) + prev, ok := previous[rawKey] + if !ok { + return watchResourceSnapshot{}, false + } + next, ok := current[rawKey] + if !ok || prev.Hash == 
next.Hash { + return watchResourceSnapshot{}, false + } + return prev, true +} + +func (s *Store) gitLineDiffsAgainstHead(ctx context.Context, repositoryID int64) map[string]tldgit.LineDiff { + repo, err := s.Repository(ctx, repositoryID) + if err != nil || strings.TrimSpace(repo.RepoRoot) == "" { + return nil + } + diffs, err := tldgit.LineDiffsAgainstHead(repo.RepoRoot) + if err != nil { + return nil + } + return diffs +} + +func (s *Store) gitLineHunksAgainstHead(ctx context.Context, repositoryID int64) map[string][]tldgit.LineHunk { + repo, err := s.Repository(ctx, repositoryID) + if err != nil || strings.TrimSpace(repo.RepoRoot) == "" { + return nil + } + hunks, err := tldgit.LineHunksAgainstHead(repo.RepoRoot) + if err != nil { + return nil + } + return hunks +} + +func (s *Store) gitWorktreeChangesAgainstHead(ctx context.Context, repositoryID int64) map[string]tldgit.WorktreeChange { + repo, err := s.Repository(ctx, repositoryID) + if err != nil || strings.TrimSpace(repo.RepoRoot) == "" { + return nil + } + changes, err := tldgit.WorktreeChangesAgainstHead(repo.RepoRoot) + if err != nil { + return nil + } + return changes +} + +func applyGitLineDiff(diff *RepresentationDiff, snapshot watchResourceSnapshot, previous *watchResourceSnapshot, lineDiffs map[string]tldgit.LineDiff, lineHunks map[string][]tldgit.LineHunk) { + if diff == nil || len(lineDiffs) == 0 || diff.ChangeType != "updated" { + return + } + file := snapshotDiffFilePath(snapshot) + if file == "" { + return + } + if symbolLineAttributionCandidate(snapshot) { + if added, removed, ok := symbolLineDiff(snapshot, previous, lineHunks[file]); ok { + diff.AddedLines = added + diff.RemovedLines = removed + return + } + } + lineDiff, ok := lineDiffs[file] + if !ok { + return + } + diff.AddedLines = lineDiff.Added + diff.RemovedLines = lineDiff.Removed +} + +func symbolLineAttributionCandidate(snapshot watchResourceSnapshot) bool { + return snapshot.OwnerType == "symbol" || (snapshot.ResourceType == "element" && snapshot.OwnerType == "symbol") +} + +func symbolLineDiff(snapshot watchResourceSnapshot, previous *watchResourceSnapshot, hunks []tldgit.LineHunk) (int, int, bool) { + if len(hunks) == 0 || snapshot.StartLine <= 0 || snapshot.EndLine <= 0 { + return 0, 0, false + } + oldStart, oldEnd := snapshot.StartLine, snapshot.EndLine + if previous != nil && previous.StartLine > 0 && previous.EndLine > 0 { + oldStart, oldEnd = previous.StartLine, previous.EndLine + } + added, removed := 0, 0 + for _, hunk := range hunks { + added += countLinesInRange(hunk.AddedLines, snapshot.StartLine, snapshot.EndLine) + removed += countLinesInRange(hunk.RemovedLines, oldStart, oldEnd) + } + return added, removed, true +} + +func countLinesInRange(lines []int, start, end int) int { + if start <= 0 || end < start { + return 0 + } + count := 0 + for _, line := range lines { + if line >= start && line <= end { + count++ + } + } + return count +} + +func snapshotDiffFilePath(snapshot watchResourceSnapshot) string { + paths := snapshotDiffFilePaths(snapshot) + if len(paths) == 0 { + return "" + } + return paths[0] +} + +func snapshotDiffFilePaths(snapshot watchResourceSnapshot) []string { + if path := filepathToSlash(snapshot.FilePath); path != "" { + return []string{path} + } + switch snapshot.OwnerType { + case "file": + if path := strings.TrimPrefix(snapshot.OwnerKey, "file:"); strings.TrimSpace(path) != "" { + return []string{filepathToSlash(path)} + } + case "symbol": + if file, ok := filePathFromStableKey(snapshot.OwnerKey); ok { + return 
[]string{file} + } + case "file-reference": + return filePairPaths(strings.TrimPrefix(snapshot.OwnerKey, "file:")) + case "reference": + return referenceOwnerPaths(snapshot.OwnerKey) + } + return nil +} + +func filePairPaths(value string) []string { + parts := strings.Split(value, "->") + if len(parts) != 2 { + return nil + } + var out []string + for _, part := range parts { + if path := filepathToSlash(strings.TrimSpace(part)); path != "" { + out = append(out, path) + } + } + return out +} + +func referenceOwnerPaths(ownerKey string) []string { + ownerKey = strings.TrimPrefix(ownerKey, "symbol:") + parts := strings.Split(ownerKey, ":") + seen := map[string]struct{}{} + var out []string + for i := 0; i+3 < len(parts); i++ { + candidate := strings.Join(parts[i:i+4], ":") + path, ok := filePathFromStableKey(candidate) + if !ok { + continue + } + if _, exists := seen[path]; exists { + continue + } + seen[path] = struct{}{} + out = append(out, path) + } + return out +} + +func materializedLineCount(ctx context.Context, db *sql.DB, repositoryID int64, ownerType, ownerKey, filePath string) int { + switch ownerType { + case "symbol": + return symbolLineCount(ctx, db, repositoryID, ownerKey) + case "file": + return fileLineCount(ctx, db, repositoryID, strings.TrimPrefix(ownerKey, "file:")) + } + if count := sourceAnchorLineCount(filePath); count > 0 { + return count + } + return 0 +} + +func symbolLineCount(ctx context.Context, db *sql.DB, repositoryID int64, ownerKey string) int { + var startLine int + var endLine sql.NullInt64 + err := db.QueryRowContext(ctx, ` + SELECT s.start_line, s.end_line + FROM watch_symbols s + LEFT JOIN watch_symbol_identities i ON i.repository_id = s.repository_id AND i.current_stable_key = s.stable_key + WHERE s.repository_id = ? AND COALESCE(i.identity_key, s.stable_key) = ? + ORDER BY s.id + LIMIT 1`, repositoryID, ownerKey).Scan(&startLine, &endLine) + if err != nil { + return 0 + } + return lineCountFromRange(startLine, endLine) +} + +func materializedSourceRange(ctx context.Context, db *sql.DB, repositoryID int64, ownerType, ownerKey, fallbackFilePath string) (string, int, int) { + switch ownerType { + case "symbol": + return symbolLineRange(ctx, db, repositoryID, ownerKey) + case "file": + path := strings.TrimPrefix(ownerKey, "file:") + if strings.TrimSpace(path) == "" { + path = fallbackFilePath + } + return filepathToSlash(path), 0, 0 + default: + if path := sourceAnchorFilePath(fallbackFilePath); path != "" { + start, end := sourceAnchorRange(fallbackFilePath) + return path, start, end + } + } + return filepathToSlash(fallbackFilePath), 0, 0 +} + +func symbolLineRange(ctx context.Context, db *sql.DB, repositoryID int64, ownerKey string) (string, int, int) { + var filePath string + var startLine int + var endLine sql.NullInt64 + err := db.QueryRowContext(ctx, ` + SELECT f.path, s.start_line, s.end_line + FROM watch_symbols s + JOIN watch_files f ON f.id = s.file_id + LEFT JOIN watch_symbol_identities i ON i.repository_id = s.repository_id AND i.current_stable_key = s.stable_key + WHERE s.repository_id = ? AND COALESCE(i.identity_key, s.stable_key) = ? 
+ ORDER BY s.id + LIMIT 1`, repositoryID, ownerKey).Scan(&filePath, &startLine, &endLine) + if err != nil { + return "", 0, 0 + } + return filepathToSlash(filePath), startLine, normalizedEndLine(startLine, endLine) +} + +func symbolSnapshotHash(ctx context.Context, db *sql.DB, repositoryID int64, ownerKey string) string { + var contentHash, signatureHash string + var startLine int + var endLine sql.NullInt64 + err := db.QueryRowContext(ctx, ` + SELECT s.content_hash, s.signature_hash, s.start_line, s.end_line + FROM watch_symbols s + LEFT JOIN watch_symbol_identities i ON i.repository_id = s.repository_id AND i.current_stable_key = s.stable_key + WHERE s.repository_id = ? AND COALESCE(i.identity_key, s.stable_key) = ? + ORDER BY s.id + LIMIT 1`, repositoryID, ownerKey).Scan(&contentHash, &signatureHash, &startLine, &endLine) + if err != nil { + return "" + } + return fmt.Sprintf("%s:%s:%d", contentHash, signatureHash, lineCountFromRange(startLine, endLine)) +} + +func normalizedEndLine(startLine int, endLine sql.NullInt64) int { + if startLine <= 0 { + return 0 + } + if endLine.Valid && int(endLine.Int64) >= startLine { + return int(endLine.Int64) + } + return startLine +} + +func lineCountFromRange(startLine int, endLine sql.NullInt64) int { + if startLine <= 0 { + return 0 + } + end := startLine + if endLine.Valid { + end = int(endLine.Int64) + } + if end < startLine { + return 0 + } + return end - startLine + 1 +} + +func fileLineCount(ctx context.Context, db *sql.DB, repositoryID int64, filePath string) int { + if strings.TrimSpace(filePath) == "" { + return 0 + } + var maxEnd sql.NullInt64 + err := db.QueryRowContext(ctx, ` + SELECT MAX(COALESCE(s.end_line, s.start_line)) + FROM watch_symbols s + JOIN watch_files f ON f.id = s.file_id + WHERE s.repository_id = ? 
AND f.path = ?`, repositoryID, filePath).Scan(&maxEnd) + if err != nil || !maxEnd.Valid { + return 0 + } + return int(maxEnd.Int64) +} + +func sourceAnchorLineCount(filePath string) int { + start, end := sourceAnchorRange(filePath) + if start <= 0 || end < start { + return 0 + } + return end - start + 1 +} + +func sourceAnchorRange(filePath string) (int, int) { + hash := strings.IndexByte(filePath, '#') + if hash < 0 || hash == len(filePath)-1 { + return 0, 0 + } + var anchor struct { + StartLine int `json:"startLine"` + EndLine int `json:"endLine"` + } + if err := json.Unmarshal([]byte(filePath[hash+1:]), &anchor); err != nil { + return 0, 0 + } + if anchor.StartLine <= 0 { + return 0, 0 + } + if anchor.EndLine <= 0 { + anchor.EndLine = anchor.StartLine + } + if anchor.EndLine < anchor.StartLine { + return 0, 0 + } + return anchor.StartLine, anchor.EndLine +} + +func sourceAnchorFilePath(filePath string) string { + before, _, ok := strings.Cut(filePath, "#") + if !ok { + return filepathToSlash(filePath) + } + return filepathToSlash(before) +} + +func lineDelta(changeType string, lineCount int, previous *watchResourceSnapshot) (int, int) { + if lineCount < 0 { + lineCount = 0 + } + switch changeType { + case "added": + return lineCount, 0 + case "deleted": + return 0, lineCount + case "updated": + if previous == nil || previous.LineCount <= 0 || lineCount <= 0 { + return 0, 0 + } + delta := lineCount - previous.LineCount + if delta > 0 { + return delta, 0 + } + if delta < 0 { + return 0, -delta + } + } + return 0, 0 +} + +func resourceSnapshotKey(ownerType, ownerKey, resourceType string) string { + return ownerType + "\x00" + ownerKey + "\x00" + resourceType +} + +func ptrInt64Value(value *int64) int64 { + if value == nil { + return 0 + } + return *value +} + +func stringPtrIf(ok bool, value string) *string { + if !ok { + return nil + } + return &value +} + +func diffLanguage(diff RepresentationDiff) string { + if diff.Language != nil { + return *diff.Language + } + if diff.OwnerType == "symbol" || diff.ResourceType != nil && *diff.ResourceType == "symbol" { + return languageFromStableKey(diff.OwnerKey) + } + return "" +} + +func (s *Store) fileByPath(ctx context.Context, repositoryID int64, path string) (File, bool, error) { + file, err := s.fileByPathMust(ctx, repositoryID, path) + if errors.Is(err, sql.ErrNoRows) { + return File{}, false, nil + } + return file, err == nil, err +} + +func scanLock(row rowScanner) (Lock, error) { + var lock Lock + if err := row.Scan(&lock.ID, &lock.RepositoryID, &lock.PID, &lock.Token, &lock.StartedAt, &lock.HeartbeatAt, &lock.Status); err != nil { + return Lock{}, err + } + return lock, nil +} + +func scanVersion(row rowScanner) (Version, error) { + var version Version + var message sql.NullString + var parent sql.NullString + var branch sql.NullString + var workspaceVersionID sql.NullInt64 + if err := row.Scan(&version.ID, &version.RepositoryID, &version.CommitHash, &message, &parent, &branch, &version.RepresentationHash, &workspaceVersionID, &version.CreatedAt); err != nil { + return Version{}, err + } + if message.Valid { + version.CommitMessage = message.String + } + if parent.Valid { + version.ParentCommitHash = parent.String + } + if branch.Valid { + version.Branch = branch.String + } + if workspaceVersionID.Valid { + version.WorkspaceVersionID = &workspaceVersionID.Int64 + } + return version, nil +} + +func (s *Store) addElementTags(ctx context.Context, elementID int64, add []string) (int, error) { + var raw string + if err := 
s.db.QueryRowContext(ctx, `SELECT tags FROM elements WHERE id = ?`, elementID).Scan(&raw); err != nil { + return 0, err + } + var tags []string + _ = json.Unmarshal([]byte(raw), &tags) + seen := make(map[string]struct{}, len(tags)+len(add)) + next := make([]string, 0, len(tags)+len(add)) + added := 0 + for _, tag := range tags { + if _, ok := seen[tag]; ok { + continue + } + seen[tag] = struct{}{} + next = append(next, tag) + } + for _, tag := range add { + if _, ok := seen[tag]; ok { + continue + } + seen[tag] = struct{}{} + next = append(next, tag) + added++ + } + if added == 0 { + return 0, nil + } + data, _ := json.Marshal(next) + _, err := s.db.ExecContext(ctx, `UPDATE elements SET tags = ?, updated_at = ? WHERE id = ?`, string(data), nowString(), elementID) + return added, err +} + +func (s *Store) removeElementTags(ctx context.Context, elementID int64, remove []string) (int, error) { + var raw string + if err := s.db.QueryRowContext(ctx, `SELECT tags FROM elements WHERE id = ?`, elementID).Scan(&raw); err != nil { + return 0, err + } + var tags []string + _ = json.Unmarshal([]byte(raw), &tags) + removeSet := make(map[string]struct{}, len(remove)) + for _, tag := range remove { + removeSet[tag] = struct{}{} + } + next := make([]string, 0, len(tags)) + removed := 0 + for _, tag := range tags { + if _, ok := removeSet[tag]; ok { + removed++ + continue + } + next = append(next, tag) + } + if removed == 0 { + return 0, nil + } + data, _ := json.Marshal(next) + _, err := s.db.ExecContext(ctx, `UPDATE elements SET tags = ?, updated_at = ? WHERE id = ?`, string(data), nowString(), elementID) + return removed, err +} + +func managedGitTags() []string { + return []string{"git:staged", "git:unstaged", "git:untracked", "watch:deleted"} +} + +func filepathToSlash(path string) string { + return strings.ReplaceAll(path, "\\", "/") +} + +func absInt(value int) int { + if value < 0 { + return -value + } + return value +} + +func sameQualifierParent(left, right string) bool { + leftParent := qualifierParent(left) + rightParent := qualifierParent(right) + return leftParent != "" && leftParent == rightParent +} + +func qualifierParent(value string) string { + if idx := strings.LastIndex(value, "."); idx > 0 { + return value[:idx] + } + return "" +} + +func nameTokenSimilarity(left, right string) float64 { + leftTokens := splitIdentifierToken(pathBaseQualifier(left)) + rightTokens := splitIdentifierToken(pathBaseQualifier(right)) + if len(leftTokens) == 0 || len(rightTokens) == 0 { + return 0 + } + leftSet := make(map[string]struct{}, len(leftTokens)) + for _, token := range leftTokens { + leftSet[token] = struct{}{} + } + intersection := 0 + union := len(leftSet) + for _, token := range rightTokens { + if _, ok := leftSet[token]; ok { + intersection++ + continue + } + union++ + } + if union == 0 { + return 0 + } + return float64(intersection) / float64(union) +} + +func pathBaseQualifier(value string) string { + if idx := strings.LastIndex(value, "."); idx >= 0 && idx+1 < len(value) { + return value[idx+1:] + } + return value +} + +func embeddingDataset(modelID int64) string { + return fmt.Sprintf("model:%d", modelID) +} + +func bytesToVector(data []byte) Vector { + if len(data)%4 != 0 { + return nil + } + vector := make(Vector, len(data)/4) + for i := range vector { + vector[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[i*4:])) + } + return vector +} + +func (s *Store) clusterByStableKey(ctx context.Context, repositoryID int64, stableKey string) (Cluster, error) { + row := 
s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, stable_key, parent_cluster_id, name, kind, algorithm, settings_hash, member_count, created_at, updated_at + FROM watch_clusters + WHERE repository_id = ? AND stable_key = ?`, repositoryID, stableKey) + return scanCluster(row) +} + +type rowScanner interface { + Scan(dest ...any) error +} + +func scanCluster(row rowScanner) (Cluster, error) { + var cluster Cluster + var parent sql.NullInt64 + if err := row.Scan(&cluster.ID, &cluster.RepositoryID, &cluster.StableKey, &parent, &cluster.Name, &cluster.Kind, &cluster.Algorithm, &cluster.SettingsHash, &cluster.MemberCount, &cluster.CreatedAt, &cluster.UpdatedAt); err != nil { + return Cluster{}, err + } + if parent.Valid { + cluster.ParentClusterID = &parent.Int64 + } + return cluster, nil +} + +func (s *Store) latestFilterRunID(ctx context.Context, repositoryID int64) (int64, error) { + var id int64 + err := s.db.QueryRowContext(ctx, ` + SELECT id FROM watch_filter_runs + WHERE repository_id = ? + ORDER BY id DESC + LIMIT 1`, repositoryID).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return id, err +} + +func (s *Store) fileByPathMust(ctx context.Context, repositoryID int64, path string) (File, error) { + var file File + err := s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, path, language, git_blob_hash, worktree_hash, size_bytes, mtime_unix, scan_status, scan_error, created_at, updated_at + FROM watch_files + WHERE repository_id = ? AND path = ?`, repositoryID, path).Scan(&file.ID, &file.RepositoryID, &file.Path, &file.Language, &file.GitBlobHash, &file.WorktreeHash, &file.SizeBytes, &file.MtimeUnix, &file.ScanStatus, &file.ScanError, &file.CreatedAt, &file.UpdatedAt) + return file, err +} + +func (s *Store) file(ctx context.Context, id int64) (File, error) { + var file File + err := s.db.QueryRowContext(ctx, ` + SELECT id, repository_id, path, language, git_blob_hash, worktree_hash, size_bytes, mtime_unix, scan_status, scan_error, created_at, updated_at + FROM watch_files + WHERE id = ?`, id).Scan(&file.ID, &file.RepositoryID, &file.Path, &file.Language, &file.GitBlobHash, &file.WorktreeHash, &file.SizeBytes, &file.MtimeUnix, &file.ScanStatus, &file.ScanError, &file.CreatedAt, &file.UpdatedAt) + return file, err +} + +func nullString(value string) any { + if strings.TrimSpace(value) == "" { + return nil + } + return value +} + +func nowString() string { + return time.Now().UTC().Format(time.RFC3339) +} diff --git a/internal/watch/watch_test.go b/internal/watch/watch_test.go new file mode 100644 index 0000000..db09b32 --- /dev/null +++ b/internal/watch/watch_test.go @@ -0,0 +1,4771 @@ +package watch + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "reflect" + "slices" + "strings" + "testing" + "time" + + "github.com/mertcikla/tld/internal/analyzer" + tldgit "github.com/mertcikla/tld/internal/git" + "github.com/mertcikla/tld/internal/watch/enrich" + "github.com/mertcikla/tld/internal/watch/enrich/defaults" + sqlitevec "github.com/viant/sqlite-vec/vec" + _ "modernc.org/sqlite" +) + +func TestMigrationCreatesWatchTablesAndIndexes(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + + for _, table := range []string{"watch_repositories", "watch_files", "watch_symbols", "watch_references", "watch_facts", "watch_scan_runs", "watch_embedding_models", "watch_embeddings", "watch_filter_runs", "watch_filter_decisions", "watch_clusters", 
"watch_cluster_members", "watch_materialization", "watch_architecture_links", "watch_context_policies", "watch_context_expansions", "watch_representation_runs", "watch_locks", "watch_apply_locks", "watch_versions", "watch_representation_diffs", "watch_version_resources", "workspace_versions"} { + var name string + if err := db.QueryRow(`SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?`, table).Scan(&name); err != nil { + t.Fatalf("missing table %s: %v", table, err) + } + } + for _, index := range []string{"idx_watch_repositories_remote_url", "idx_watch_repositories_repo_root", "idx_watch_facts_subject", "idx_watch_facts_object", "idx_watch_filter_decisions_owner_key", "idx_watch_context_expansions_scope", "idx_watch_context_expansions_owner"} { + var name string + if err := db.QueryRow(`SELECT name FROM sqlite_master WHERE type = 'index' AND name = ?`, index).Scan(&name); err != nil { + t.Fatalf("missing index %s: %v", index, err) + } + } +} + +func TestReplaceFactsForFileIsIdempotentAndAffectsRawGraphHash(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + store := NewStore(db) + repo, err := store.EnsureRepository(context.Background(), RepositoryInput{RepoRoot: t.TempDir(), DisplayName: "repo"}) + if err != nil { + t.Fatal(err) + } + file, _, err := store.UpsertFile(context.Background(), repo.ID, "main.go", "go", "", "hash", 10, 1, "parsed", nil) + if err != nil { + t.Fatal(err) + } + first := Fact{ + StableKey: "http.route:main", + Type: "http.route", + Enricher: "go.nethttp", + SubjectKind: "file", + SubjectStableKey: "file:main.go", + ObjectKind: "symbol", + ObjectStableKey: "go:main.go:function:Main", + ObjectFilePath: "main.go", + ObjectName: "Main", + Relationship: "declares", + FilePath: "main.go", + StartLine: 3, + Confidence: 1, + Name: "GET /users", + Tags: []string{"http:route", "framework:nethttp"}, + AttributesJSON: `{"path":"/users"}`, + VisibilityHintsJSON: `{"high_signal":1}`, + FactHash: "fact-hash-1", + RawJSON: `{}`, + } + if err := store.ReplaceFactsForFile(context.Background(), repo.ID, file.ID, []Fact{first}); err != nil { + t.Fatal(err) + } + hash1, err := store.RawGraphHash(context.Background(), repo.ID) + if err != nil { + t.Fatal(err) + } + if err := store.ReplaceFactsForFile(context.Background(), repo.ID, file.ID, []Fact{first}); err != nil { + t.Fatal(err) + } + hash2, err := store.RawGraphHash(context.Background(), repo.ID) + if err != nil { + t.Fatal(err) + } + if hash1 != hash2 { + t.Fatalf("idempotent fact replacement changed raw graph hash: %s != %s", hash1, hash2) + } + second := first + second.StableKey = "http.route:admin" + second.Name = "GET /admin" + second.FactHash = "fact-hash-2" + if err := store.ReplaceFactsForFile(context.Background(), repo.ID, file.ID, []Fact{second}); err != nil { + t.Fatal(err) + } + facts, err := store.FactsForRepository(context.Background(), repo.ID) + if err != nil { + t.Fatal(err) + } + if len(facts) != 1 || facts[0].StableKey != second.StableKey { + t.Fatalf("expected stale fact replacement, got %+v", facts) + } + if facts[0].ObjectKind != "symbol" || facts[0].ObjectStableKey != "go:main.go:function:Main" || facts[0].Relationship != "declares" || facts[0].VisibilityHintsJSON == "" { + t.Fatalf("expected fact object and visibility fields to round-trip, got %+v", facts[0]) + } + hash3, err := store.RawGraphHash(context.Background(), repo.ID) + if err != nil { + t.Fatal(err) + } + if hash3 == hash1 { + t.Fatalf("raw graph hash did not change after fact change: %s", hash3) + } +} + 
+func TestVisibilityConfigPreservesExplicitConfigValues(t *testing.T) { + defaults := defaultVisibilityConfig(VisibilityConfig{}) + if !defaults.CoreThresholdEnabled || defaults.Weights.HighSignalFact == 0 { + t.Fatalf("expected zero-value visibility config to receive defaults, got %+v", defaults) + } + cfg := defaultVisibilityConfig(VisibilityConfig{ + CoreThresholdEnabled: false, + CoreThresholdSet: true, + WeightsSet: true, + Weights: VisibilityWeights{ + HighSignalFact: 0, + UserHide: 0, + }, + }) + if cfg.CoreThresholdEnabled { + t.Fatalf("expected explicit disabled core threshold to be preserved, got %+v", cfg) + } + if cfg.Weights.HighSignalFact != 0 || cfg.Weights.UserHide != 0 { + t.Fatalf("expected explicit zero weights to be preserved, got %+v", cfg.Weights) + } +} + +func TestContextShowAndHideRoundTripGeneratedSymbol(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} + +func quietHelper() string { + return "quiet" +} + +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}} + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + if elementNameExists(t, db, "quietHelper") { + t.Fatal("quiet helper should start hidden") + } + fileElementID := elementIDByName(t, db, "main.go") + show, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionShow, ContextResourceRequest{ResourceType: "element", ResourceID: fileElementID}, req) + if err != nil { + t.Fatal(err) + } + if show.PoliciesCreated != 0 || show.TierBefore != 0 || show.TierAfter != 1 || show.OwnersAffected == 0 { + t.Fatalf("expected show to create a tier-1 expansion without durable policies, got %+v", show) + } + if !elementNameExists(t, db, "quietHelper") { + t.Fatal("quiet helper should be materialized after show context") + } + manualID := insertManualElement(t, db, "Manual note") + clean, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionClean, ContextResourceRequest{ResourceType: "element", ResourceID: fileElementID}, req) + if err != nil { + t.Fatal(err) + } + if clean.TierBefore != 1 || clean.TierAfter != 0 || clean.ElementsRemoved == 0 { + t.Fatalf("expected clean to collapse the expansion and remove generated detail, got %+v", clean) + } + if elementNameExists(t, db, "quietHelper") { + t.Fatal("quiet helper should be pruned after clean noise") + } + var manualName string + if err := db.QueryRow(`SELECT name FROM elements WHERE id = ?`, manualID).Scan(&manualName); err != nil { + t.Fatalf("manual element was removed: %v", err) + } +} + +func TestRunnerRunOnceScansAndRepresentsRepository(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + result, err := NewRunner(store).RunOnce(context.Background(), OneShotOptions{ + Path: repo, + Embedding: EmbeddingConfig{Provider: "none"}, + }) + if err != nil { + t.Fatal(err) + } + if result.Repository.ID == 0 || result.Scan.RepositoryID == 0 { + t.Fatalf("missing repository/scan ids: %+v", result) + } + if result.Representation.RepresentationHash == "" { + t.Fatalf("missing representation hash: %+v", 
result.Representation) + } + if result.Scan.FilesParsed == 0 || result.Representation.ElementsCreated == 0 { + t.Fatalf("expected parsed files and materialized elements: scan=%+v representation=%+v", result.Scan, result.Representation) + } +} + +func TestScanAndRepresentMaterializesEnricherFactsWithoutNoisyTags(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "go.mod", `module example.com/enriched + +go 1.23 + +require github.com/go-chi/chi/v5 v5.0.12 +`) + writeFile(t, repo, "main.go", `package main + +import "github.com/go-chi/chi/v5" + +func Routes(r chi.Router) { + r.Get("/users/{id}", GetUser) +} + +func GetUser() {} +`) + writeFile(t, repo, "package.json", `{ + "dependencies": { + "next": "14.0.0", + "@prisma/client": "5.0.0" + } +}`) + writeFile(t, repo, "src/app/users/[id]/page.tsx", `export default function Page() { + return null +}`) + writeFile(t, repo, "db.ts", `import { PrismaClient } from "@prisma/client" + +const prisma = new PrismaClient() + +export async function Users() { + return prisma.user.findMany() +} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + facts, err := store.FactsForRepository(context.Background(), scanResult.RepositoryID) + if err != nil { + t.Fatal(err) + } + for _, want := range []struct { + factType string + tag string + }{ + {"dependency.module", "dependency:module"}, + {"dependency.import", "dependency:import"}, + {"http.route", "framework:chi"}, + {"frontend.route", "framework:nextjs"}, + {"orm.query", "orm:prisma"}, + } { + if !factsContain(facts, want.factType, want.tag) { + t.Fatalf("missing fact %s/%s in %+v", want.factType, want.tag, facts) + } + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + for _, tag := range []string{"http:route", "framework:chi", "frontend:route", "framework:nextjs", "orm:prisma"} { + if count := countElementTag(t, db, tag); count != 0 { + t.Fatalf("expected representation to omit noisy generated tag %q, found on %d elements", tag, count) + } + } + if routes := elementKindCount(t, db, "route"); routes == 0 { + t.Fatal("expected high-signal route facts to materialize as generated route nodes") + } + if deps := countElementTag(t, db, "dependency:import"); deps != 0 { + t.Fatalf("expected dependency/import facts not to surface as tags, found on %d elements", deps) + } + if deps := elementKindCount(t, db, "dependency"); deps != 0 { + t.Fatalf("dependency/import facts should not materialize as one dependency node per import, found %d dependency nodes", deps) + } +} + +func TestFactNodesUseSubjectAwarePlacementAndHandles(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Routes() { + GetUser() + CreateUser() +} + +func GetUser() {} +func CreateUser() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + file, ok, err := store.CachedFileByPath(context.Background(), scanResult.RepositoryID, "main.go") + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatal("missing cached main.go") + } + symbols, err := store.SymbolsForRepository(context.Background(), scanResult.RepositoryID) + if err != nil { + t.Fatal(err) + } + var 
routesStableKey string + for _, sym := range symbols { + if sym.Name == "Routes" { + routesStableKey = sym.StableKey + break + } + } + if routesStableKey == "" { + t.Fatal("missing Routes symbol") + } + facts := []Fact{ + { + RepositoryID: scanResult.RepositoryID, + FileID: file.ID, + FilePath: "main.go", + StableKey: "test.route.users-id", + Type: "http.route", + Enricher: "test", + SubjectKind: "symbol", + SubjectStableKey: routesStableKey, + ObjectKind: "http.route", + ObjectStableKey: "http.route:/users/{id}", + ObjectName: "/users/{id}", + Relationship: "declares", + StartLine: 4, + Confidence: 1, + Name: "/users/{id}", + Tags: []string{"http:route"}, + AttributesJSON: `{"framework":"test"}`, + VisibilityHintsJSON: `{"high_signal":1}`, + RawJSON: `{}`, + }, + { + RepositoryID: scanResult.RepositoryID, + FileID: file.ID, + FilePath: "main.go", + StableKey: "test.route.users", + Type: "http.route", + Enricher: "test", + SubjectKind: "symbol", + SubjectStableKey: routesStableKey, + ObjectKind: "http.route", + ObjectStableKey: "http.route:/users", + ObjectName: "/users", + Relationship: "declares", + StartLine: 5, + Confidence: 1, + Name: "/users", + Tags: []string{"http:route"}, + AttributesJSON: `{"framework":"test"}`, + VisibilityHintsJSON: `{"high_signal":1}`, + RawJSON: `{}`, + }, + } + for i := range facts { + facts[i].FactHash = stableHash(facts[i]) + } + if err := store.ReplaceFactsForFile(context.Background(), scanResult.RepositoryID, file.ID, facts); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + routesPlacement := functionPlacement(t, db, "Routes") + userRoutePlacement := elementPlacementByName(t, db, "/users/{id}") + if routesPlacement.x == userRoutePlacement.x && routesPlacement.y == userRoutePlacement.y { + t.Fatalf("fact route overlaps subject placement: subject=%+v route=%+v", routesPlacement, userRoutePlacement) + } + var sourceHandle, targetHandle sql.NullString + if err := db.QueryRow(` + SELECT c.source_handle, c.target_handle + FROM connectors c + JOIN elements s ON s.id = c.source_element_id + JOIN elements target ON target.id = c.target_element_id + WHERE s.name = ? AND target.name = ? 
+ ORDER BY c.id + LIMIT 1`, "Routes", "/users/{id}").Scan(&sourceHandle, &targetHandle); err != nil { + t.Fatalf("route fact connector: %v", err) + } + if !sourceHandle.Valid || sourceHandle.String == "" || !targetHandle.Valid || targetHandle.String == "" { + t.Fatalf("expected route fact connector to store handle sides, got source=%q target=%q", sourceHandle.String, targetHandle.String) + } +} + +func TestContextHTTPHandlers(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} + +func quietHelper() string { + return "quiet" +} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + fileElementID := elementIDByName(t, db, "main.go") + mux := http.NewServeMux() + NewHandler(store).Register(mux) + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/watch/repositories/%d/context/show", scanResult.RepositoryID), strings.NewReader(fmt.Sprintf(`{"resource_type":"element","resource_id":%d}`, fileElementID))) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("show context status = %d body %s", rec.Code, rec.Body.String()) + } + if !elementNameExists(t, db, "quietHelper") { + t.Fatal("quiet helper should be materialized by HTTP show context") + } + req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/watch/repositories/%d/context/clean", scanResult.RepositoryID), strings.NewReader(fmt.Sprintf(`{"resource_type":"element","resource_id":%d}`, fileElementID))) + rec = httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("clean context status = %d body %s", rec.Code, rec.Body.String()) + } + var cleanResponse ContextActionResult + if err := json.Unmarshal(rec.Body.Bytes(), &cleanResponse); err != nil { + t.Fatal(err) + } + if cleanResponse.TierBefore != 1 || cleanResponse.TierAfter != 0 { + t.Fatalf("expected HTTP clean to decrement context tier, got %+v", cleanResponse) + } + if elementNameExists(t, db, "quietHelper") { + t.Fatal("quiet helper should be collapsed by HTTP clean context") + } + + manualID := insertManualElement(t, db, "Manual only") + req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/api/watch/repositories/%d/context/hide", scanResult.RepositoryID), strings.NewReader(fmt.Sprintf(`{"resource_type":"element","resource_id":%d}`, manualID))) + rec = httptest.NewRecorder() + mux.ServeHTTP(rec, req) + if rec.Code != http.StatusBadRequest { + t.Fatalf("manual-only hide status = %d, want 400", rec.Code) + } +} + +func TestContextShowFocusedRescanRevealsNewSymbols(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}} + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + fileElementID := elementIDByName(t, db, "main.go") + + writeFile(t, repo, "main.go", `package main + +func Main() {} + +func 
newPrivateContext() string { + return "new" +} +`) + show, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionShow, ContextResourceRequest{ResourceType: "element", ResourceID: fileElementID}, req) + if err != nil { + t.Fatal(err) + } + if show.OwnersAffected == 0 { + t.Fatalf("expected focused show to affect owners, got %+v", show) + } + if !elementNameExists(t, db, "newPrivateContext") { + t.Fatal("focused show context should rescan the file and reveal newly added private symbols") + } + decisions, err := store.FilterDecisions(context.Background(), scanResult.RepositoryID, FilterDecisionQuery{Decision: "visible"}) + if err != nil { + t.Fatal(err) + } + sym, err := symbolsByName(context.Background(), store, scanResult.RepositoryID, "newPrivateContext") + if err != nil { + t.Fatal(err) + } + if !filterDecisionHasReason(decisions, sym.ID, "selected context expansion tier 1") { + t.Fatalf("expected new private symbol to be forced visible by context expansion, got %+v", decisions) + } +} + +func TestContextHideElementCleansImmediateGeneratedNeighborsOnly(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}} + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + mainID := symbolElementID(t, db, "Main") + if !elementNameExists(t, db, "helper") || !connectorExistsBetween(t, db, "Main", "helper") { + t.Fatal("expected generated helper neighbor and connector before hide context") + } + manualID := insertManualElement(t, db, "Manual neighbor note") + + hide, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionHide, ContextResourceRequest{ResourceType: "element", ResourceID: mainID}, req) + if err != nil { + t.Fatal(err) + } + if hide.ConnectorsRemoved == 0 || hide.ElementsRemoved == 0 { + t.Fatalf("expected hide to remove generated neighbor and connector, got %+v", hide) + } + if !elementNameExists(t, db, "Main") { + t.Fatal("selected exported element should remain") + } + if elementNameExists(t, db, "helper") || connectorExistsBetween(t, db, "Main", "helper") { + t.Fatal("immediate generated neighbor context should be cleaned") + } + var manualName string + if err := db.QueryRow(`SELECT name FROM elements WHERE id = ?`, manualID).Scan(&manualName); err != nil { + t.Fatalf("manual element was removed: %v", err) + } +} + +func TestContextViewCleanupAndShowHidePrecedence(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} + +func quietHelper() string { + return "quiet" +} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}} + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + fileViewID := materializedResourceID(t, db, scanResult.RepositoryID, "file", "file:main.go", "view") + show, err := store.ApplyContextAction(context.Background(), 
scanResult.RepositoryID, contextActionShow, ContextResourceRequest{ResourceType: "view", ResourceID: fileViewID}, req) + if err != nil { + t.Fatal(err) + } + if show.PoliciesCreated != 0 || show.TierAfter != 1 { + t.Fatalf("expected view show to create a tiered expansion without durable policies, got %+v", show) + } + if !elementNameExists(t, db, "quietHelper") { + t.Fatal("view show should reveal private generated detail") + } + clean, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionClean, ContextResourceRequest{ResourceType: "view", ResourceID: fileViewID}, req) + if err != nil { + t.Fatal(err) + } + if clean.TierBefore != 1 || clean.TierAfter != 0 || clean.ElementsRemoved == 0 { + t.Fatalf("expected view clean to remove generated symbol noise one tier at a time, got %+v", clean) + } + if elementNameExists(t, db, "quietHelper") { + t.Fatal("view cleanup should remove private generated symbol noise") + } + if !elementNameExists(t, db, "Main") { + t.Fatal("view cleanup should preserve exported entrypoint") + } + if activePolicyCount(t, db, scanResult.RepositoryID, contextActionHide, "symbol") != 0 { + t.Fatal("clean noise should not create durable hide policies") + } + + if _, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionShow, ContextResourceRequest{ResourceType: "view", ResourceID: fileViewID}, req); err != nil { + t.Fatal(err) + } + quietID := symbolElementID(t, db, "quietHelper") + hide, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionHide, ContextResourceRequest{ResourceType: "element", ResourceID: quietID}, req) + if err != nil { + t.Fatal(err) + } + if hide.PoliciesCreated == 0 || activePolicyCount(t, db, scanResult.RepositoryID, contextActionHide, "symbol") == 0 { + t.Fatalf("expected explicit hide to create a durable policy, got %+v", hide) + } + if !elementNameExists(t, db, "quietHelper") { + t.Fatal("active expansion should keep selected context visible even when durable hide is recorded") + } + if _, err := store.ApplyContextAction(context.Background(), scanResult.RepositoryID, contextActionClean, ContextResourceRequest{ResourceType: "view", ResourceID: fileViewID}, req); err != nil { + t.Fatal(err) + } + if elementNameExists(t, db, "quietHelper") { + t.Fatal("durable hide should apply once the forcing expansion is cleaned") + } +} + +func TestRepresentMaterializesWorkspaceIdempotently(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "cmd/app/main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + first, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if first.ElementsCreated == 0 || first.ViewsCreated == 0 { + t.Fatalf("expected materialized resources, got %+v", first) + } + if first.ConnectorsCreated == 0 { + t.Fatalf("expected symbol connector, got %+v", first) + } + countsAfterFirst := workspaceCounts(t, db) + + second, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if second.RepresentationHash != first.RepresentationHash { + 
t.Fatalf("representation hash changed: %s != %s", second.RepresentationHash, first.RepresentationHash) + } + if second.ElementsCreated != 0 || second.ViewsCreated != 0 || second.ConnectorsCreated != 0 { + t.Fatalf("rerun should reuse resources, got %+v", second) + } + if counts := workspaceCounts(t, db); counts != countsAfterFirst { + t.Fatalf("rerun duplicated resources: before %+v after %+v", countsAfterFirst, counts) + } + + summary, err := store.RepresentationSummary(context.Background(), scanResult.RepositoryID) + if err != nil { + t.Fatal(err) + } + if summary.VisibleSymbols != 2 || summary.VisibleReferences != 1 { + t.Fatalf("unexpected representation summary: %+v", summary) + } + decisions, err := store.FilterDecisions(context.Background(), scanResult.RepositoryID, FilterDecisionQuery{Decision: "visible"}) + if err != nil { + t.Fatal(err) + } + if len(decisions) < 3 { + t.Fatalf("expected symbol and reference decisions, got %+v", decisions) + } +} + +func TestRepresentPreservesDirtyElementButAddsNewConnector(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + helperID := elementIDByName(t, db, "helper") + if _, err := db.Exec(`UPDATE elements SET name = 'User Helper', description = 'manual edit', updated_at = 'user' WHERE id = ?`, helperID); err != nil { + t.Fatal(err) + } + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func extra() { + helper() +} + +func helper() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if next.ElementsPreserved == 0 { + t.Fatalf("expected dirty element to be preserved, got %+v", next) + } + var name, description string + if err := db.QueryRow(`SELECT name, description FROM elements WHERE id = ?`, helperID).Scan(&name, &description); err != nil { + t.Fatal(err) + } + if name != "User Helper" || description != "manual edit" { + t.Fatalf("dirty element was overwritten: name=%q description=%q", name, description) + } + if !connectorExistsBetween(t, db, "extra", "User Helper") { + t.Fatal("expected watch to add a new connector to the dirty endpoint") + } +} + +func TestRepresentDoesNotTreatUpdatedAtOnlyAsDirty(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + mainID := elementIDByName(t, db, "Main") + if _, err := db.Exec(`UPDATE elements SET updated_at = 'user' WHERE id = ?`, mainID); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), 
scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if next.ElementsPreserved != 0 { + t.Fatalf("updated_at-only change should not be dirty, got %+v", next) + } + var dirty int + if err := db.QueryRow(`SELECT dirty FROM watch_materialization WHERE resource_type = 'element' AND resource_id = ?`, mainID).Scan(&dirty); err != nil { + t.Fatal(err) + } + if dirty != 0 { + t.Fatalf("updated_at-only change marked dirty = %d", dirty) + } +} + +func TestPruneDeletedSourcePreservesDirtyElement(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + mainID := elementIDByName(t, db, "Main") + if _, err := db.Exec(`UPDATE elements SET name = 'User Main', updated_at = 'user' WHERE id = ?`, mainID); err != nil { + t.Fatal(err) + } + if err := os.Remove(filepath.Join(repo, "main.go")); err != nil { + t.Fatal(err) + } + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + if !elementNameExists(t, db, "User Main") { + t.Fatal("dirty element should remain after backing source is deleted") + } + var mappingCount int + if err := db.QueryRow(`SELECT COUNT(*) FROM watch_materialization WHERE repository_id = ? 
AND resource_type = 'element' AND resource_id = ?`, scan.RepositoryID, mainID).Scan(&mappingCount); err != nil { + t.Fatal(err) + } + if mappingCount == 0 { + t.Fatal("dirty element mapping should remain after backing source is deleted") + } +} + +func TestRepresentPreservesDirtyConnector(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + connectorID := connectorIDBetween(t, db, "Main", "helper") + if _, err := db.Exec(`UPDATE connectors SET label = 'manual label', relationship = 'manual relationship', style = 'dashed', updated_at = 'user' WHERE id = ?`, connectorID); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if next.ConnectorsPreserved == 0 || next.ConnectorsUpdated != 0 { + t.Fatalf("expected dirty connector to be preserved without update, got %+v", next) + } + var label, relationship, style string + if err := db.QueryRow(`SELECT label, relationship, style FROM connectors WHERE id = ?`, connectorID).Scan(&label, &relationship, &style); err != nil { + t.Fatal(err) + } + if label != "manual label" || relationship != "manual relationship" || style != "dashed" { + t.Fatalf("dirty connector was overwritten: label=%q relationship=%q style=%q", label, relationship, style) + } +} + +func TestRepresentPreservesDirtyView(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + viewID := materializedResourceID(t, db, scan.RepositoryID, "file", "file:main.go", "view") + if _, err := db.Exec(`UPDATE views SET name = 'User View', level_label = 'Manual', updated_at = 'user' WHERE id = ?`, viewID); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if next.ViewsPreserved == 0 { + t.Fatalf("expected dirty view to be preserved, got %+v", next) + } + var name, label string + if err := db.QueryRow(`SELECT name, level_label FROM views WHERE id = ?`, viewID).Scan(&name, &label); err != nil { + t.Fatal(err) + } + if name != "User View" || label != "Manual" { + t.Fatalf("dirty view was overwritten: name=%q label=%q", name, label) + } +} + +func TestRepresentMaterializesCatalogPrimaryIconFromLanguage(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "cmd/app/main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) 
+ if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + var raw string + if err := db.QueryRow(`SELECT technology_connectors FROM elements WHERE name = 'main.go'`).Scan(&raw); err != nil { + t.Fatal(err) + } + var links []materializedTechnologyLink + if err := json.Unmarshal([]byte(raw), &links); err != nil { + t.Fatal(err) + } + want := materializedTechnologyLink{Type: "catalog", Slug: "golang", Label: "Go", IsPrimaryIcon: true} + if len(links) != 1 || links[0] != want { + t.Fatalf("technology links for main.go = %+v, want %+v", links, want) + } + + for _, tt := range []struct { + name string + slug string + }{ + {name: "Architecture", slug: "architecture"}, + {name: "Structural", slug: "structural"}, + } { + t.Run(tt.name+" section icon", func(t *testing.T) { + var technology, raw string + if err := db.QueryRow(`SELECT technology, technology_connectors FROM elements WHERE name = ? AND kind = 'view'`, tt.name).Scan(&technology, &raw); err != nil { + t.Fatal(err) + } + var sectionLinks []materializedTechnologyLink + if err := json.Unmarshal([]byte(raw), &sectionLinks); err != nil { + t.Fatal(err) + } + want := materializedTechnologyLink{Type: "catalog", Slug: tt.slug, Label: tt.name, IsPrimaryIcon: true} + if technology != tt.name || len(sectionLinks) != 1 || sectionLinks[0] != want { + t.Fatalf("%s section technology=%q links=%+v, want technology=%q links=%+v", tt.name, technology, sectionLinks, tt.name, want) + } + }) + } + + var fileElementID int64 + if err := db.QueryRow(`SELECT id FROM elements WHERE name = 'main.go'`).Scan(&fileElementID); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`UPDATE elements SET logo_url = '' WHERE id = ?`, fileElementID); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + var logoURL sql.NullString + if err := db.QueryRow(`SELECT logo_url FROM elements WHERE id = ?`, fileElementID).Scan(&logoURL); err != nil { + t.Fatal(err) + } + if !logoURL.Valid || logoURL.String != "" { + t.Fatalf("watch rerun should preserve explicit no-icon logo_url, got valid=%v value=%q", logoURL.Valid, logoURL.String) + } +} + +func TestTechnologyLinksForLanguage(t *testing.T) { + tests := []struct { + name string + language string + want []materializedTechnologyLink + }{ + { + name: "go", + language: "go", + want: []materializedTechnologyLink{{ + Type: "catalog", + Slug: "golang", + Label: "Go", + IsPrimaryIcon: true, + }}, + }, + { + name: "typescript", + language: "typescript", + want: []materializedTechnologyLink{{ + Type: "catalog", + Slug: "typescript", + Label: "TypeScript", + IsPrimaryIcon: true, + }}, + }, + { + name: "unknown returns empty", + language: "ruby", + want: []materializedTechnologyLink{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := technologyLinksForLanguage(tt.language) + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("technologyLinksForLanguage(%q) = %+v, want %+v", tt.language, got, tt.want) + } + }) + } +} + +func TestTechnologyLinksForElementUsesSectionCatalogIcon(t *testing.T) { + tests := []struct { + name string + technology string + language string + want []materializedTechnologyLink + }{ + { + name: "architecture", + technology: "Architecture", + language: 
"go", + want: []materializedTechnologyLink{{ + Type: "catalog", + Slug: "architecture", + Label: "Architecture", + IsPrimaryIcon: true, + }}, + }, + { + name: "structural", + technology: "Structural", + language: "go", + want: []materializedTechnologyLink{{ + Type: "catalog", + Slug: "structural", + Label: "Structural", + IsPrimaryIcon: true, + }}, + }, + { + name: "container maps to docker", + technology: "Container", + language: "go", + want: []materializedTechnologyLink{{ + Type: "catalog", + Slug: "docker", + Label: "Container", + IsPrimaryIcon: true, + }}, + }, + { + name: "falls back to language", + technology: "Go", + language: "go", + want: []materializedTechnologyLink{{ + Type: "catalog", + Slug: "golang", + Label: "Go", + IsPrimaryIcon: true, + }}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := technologyLinksForElement(tt.technology, tt.language) + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("technologyLinksForElement(%q, %q) = %+v, want %+v", tt.technology, tt.language, got, tt.want) + } + }) + } +} + +func TestRepresentCollapsesHighRawReferenceFolderPairs(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "internal/pkg/lib.go", `package pkg + +func Target0() {} +func Target1() {} +func Target2() {} +func Target3() {} +func Target4() {} +func Target5() {} +func Target6() {} +func Target7() {} +func Target8() {} +func Target9() {} +`) + writeFile(t, repo, "cmd/app/main.go", `package main + +import "example.com/test/internal/pkg" + +func Main() { + pkg.Target0() + pkg.Target1() + pkg.Target2() + pkg.Target3() + pkg.Target4() + pkg.Target5() + pkg.Target6() + pkg.Target7() + pkg.Target8() + pkg.Target9() +} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{ + Embedding: EmbeddingConfig{Provider: "none"}, + Thresholds: Thresholds{ + MaxExpandedConnectorsPerGroup: 4, + }, + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + + var label string + err = db.QueryRow(` + SELECT c.label + FROM connectors c + JOIN elements s ON s.id = c.source_element_id + JOIN elements t ON t.id = c.target_element_id + WHERE s.name = 'cmd' AND t.name = 'internal'`).Scan(&label) + if err != nil { + t.Fatalf("expected collapsed cmd -> internal connector: %v", err) + } + if label != "10 references" { + t.Fatalf("expected raw reference count label, got %q", label) + } +} + +func TestRepresentPrioritizesCrossFolderAggregatesOverFilePairs(t *testing.T) { + groups := map[string][]filePairReference{ + "file:cmd/a.go->cmd/b.go": { + {Key: "cmd/a.go->cmd/b.go", Count: 500}, + }, + "folder:cmd->internal": { + {Key: "cmd/a.go->internal/b.go", Count: 20}, + }, + "file:assets.go->internal/a.go": { + {Key: "assets.go->internal/a.go", Count: 200}, + }, + } + + keys := sortedFileGroupKeys(groups) + if len(keys) < 3 { + t.Fatalf("expected sorted keys, got %+v", keys) + } + if keys[0] != "folder:cmd->internal" { + t.Fatalf("expected cross-folder aggregate to be materialized first, got %+v", keys) + } +} + +func TestScanCollectsConfiguredLanguages(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + writeFile(t, repo, "src/app.ts", "export function render() { return helper() }\nfunction 
helper() { return 1 }\n") + + store := NewStore(db) + scanner := NewScanner(store) + scanner.Settings = Settings{Languages: []string{"go", "typescript"}} + result, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if result.FilesSeen != 2 || result.FilesParsed != 2 { + t.Fatalf("expected two parsed source files, got %+v", result) + } + symbols, err := store.SymbolsForRepository(context.Background(), result.RepositoryID) + if err != nil { + t.Fatal(err) + } + seenLanguages := map[string]bool{} + for _, sym := range symbols { + seenLanguages[languageFromStableKey(sym.StableKey)] = true + } + if !seenLanguages["go"] || !seenLanguages["typescript"] { + t.Fatalf("expected go and typescript stable keys, got %#v", seenLanguages) + } +} + +func TestScanRespectsGitIgnore(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, ".gitignore", "ignored.go\nnested/\n") + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + writeFile(t, repo, "ignored.go", "package main\nfunc Ignored() {}\n") + writeFile(t, repo, "nested/ignored.go", "package nested\nfunc NestedIgnored() {}\n") + + store := NewStore(db) + result, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if result.FilesSeen != 1 || result.FilesParsed != 1 { + t.Fatalf("expected only non-ignored source file to be scanned, got %+v", result) + } + symbols, err := store.SymbolsForRepository(context.Background(), result.RepositoryID) + if err != nil { + t.Fatal(err) + } + if len(symbols) != 1 || symbols[0].Name != "Main" { + t.Fatalf("expected only Main symbol, got %+v", symbols) + } +} + +func TestScanBackfillsFactsForCachedFiles(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "go.mod", `module example.com/enriched + +go 1.23 + +require github.com/go-chi/chi/v5 v5.0.12 +`) + writeFile(t, repo, "main.go", `package main + +import "github.com/go-chi/chi/v5" + +func Routes(r chi.Router) { + r.Get("/users/{id}", GetUser) +} + +func GetUser() {} +`) + + store := NewStore(db) + scanner := NewScanner(store) + scanner.Enrichers = enrich.NewRegistry() + first, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if first.FilesSeen != 2 || first.FilesParsed != 1 { + t.Fatalf("expected initial scan to see go.mod and parse main.go, got %+v", first) + } + if _, err := db.Exec(`DELETE FROM watch_facts WHERE repository_id = ?`, first.RepositoryID); err != nil { + t.Fatal(err) + } + + scanner.Enrichers = defaults.NewRegistry() + second, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if second.FilesSkipped != 2 || second.FilesParsed != 0 { + t.Fatalf("expected warm scan to skip while backfilling facts, got %+v", second) + } + facts, err := store.FactsForRepository(context.Background(), first.RepositoryID) + if err != nil { + t.Fatal(err) + } + if !factsContain(facts, "http.route", "framework:chi") { + t.Fatalf("expected cached-file backfill to persist chi route fact, got %+v", facts) + } + version, err := store.FactVersionForFile(context.Background(), first.RepositoryID, facts[0].FileID, enrichmentVersionEnricher, enrichmentVersionStableKey(facts[0].FilePath)) + if err != nil { + t.Fatal(err) + } + if version == "" { + t.Fatalf("expected cached-file backfill to persist enrichment version sentinel, got %+v", facts) + } +} + +func 
TestScanIgnoresPackageJSONSignalsFromIgnoredPaths(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, ".gitignore", "ignored/\n") + writeFile(t, repo, "ignored/package.json", `{ + "dependencies": { + "express": "4.18.0" + } +}`) + writeFile(t, repo, "src/server.ts", `router.get("/api/users", listUsers) + +function listUsers() { + return [] +} +`) + + store := NewStore(db) + scanner := NewScanner(store) + scanner.Settings = Settings{Languages: []string{"typescript"}} + result, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + facts, err := store.FactsForRepository(context.Background(), result.RepositoryID) + if err != nil { + t.Fatal(err) + } + if factsContain(facts, "http.route", "framework:express") { + t.Fatalf("ignored package.json activated express enricher: %+v", facts) + } +} + +func TestScanForceRescanReparsesCachedFiles(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + + store := NewStore(db) + scanner := NewScanner(store) + first, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + second, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + forced, err := scanner.ScanWithOptions(context.Background(), repo, ScanOptions{Force: true}) + if err != nil { + t.Fatal(err) + } + if first.FilesParsed != 1 || second.FilesSkipped != 1 || forced.FilesParsed != 1 { + t.Fatalf("unexpected scan cache behavior: first=%+v second=%+v forced=%+v", first, second, forced) + } +} + +func TestNormalizeSettingsFiltersLanguagesAndDefaultsDurations(t *testing.T) { + settings := NormalizeSettings(Settings{ + Languages: []string{"TypeScript", "go", "rust", "bogus", "go", ""}, + Watcher: "unknown", + Thresholds: Thresholds{ + MaxElementsPerView: 4, + }, + }) + if strings.Join(settings.Languages, ",") != "go,rust,typescript" { + t.Fatalf("unexpected normalized languages: %#v", settings.Languages) + } + if settings.Watcher != WatcherAuto { + t.Fatalf("unknown watcher should normalize to auto, got %q", settings.Watcher) + } + if settings.PollInterval <= 0 || settings.Debounce <= 0 { + t.Fatalf("expected default durations, got poll=%s debounce=%s", settings.PollInterval, settings.Debounce) + } + if settings.Thresholds.MaxElementsPerView != 4 || settings.Thresholds.MaxConnectorsPerView <= 0 { + t.Fatalf("expected provided threshold plus defaults, got %+v", settings.Thresholds) + } + + fallback := NormalizeSettings(Settings{Languages: []string{"bogus"}}) + if len(fallback.Languages) == 0 || !languageAllowed("go", languageSet(fallback.Languages)) { + t.Fatalf("invalid-only language list should fall back to defaults, got %#v", fallback.Languages) + } +} + +func TestSourceSnapshotsRespectLanguagesAndReportChangeLanguage(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + writeFile(t, repo, "web/app.ts", "export function render() { return 1 }\n") + writeFile(t, repo, "README.md", "# ignored\n") + + settings := Settings{Languages: []string{"typescript"}} + snapshot := sourceFileSnapshot(repo, settings, nil) + if len(snapshot) != 1 || snapshot["web/app.ts"] == "" { + t.Fatalf("expected only TypeScript source file, got %#v", snapshot) + } + + changes := diffSourceFileSnapshots( + map[string]string{"old.py": "python:1:1", "same.ts": "typescript:1:1", 
"changed.go": "go:1:1"}, + map[string]string{"same.ts": "typescript:1:1", "changed.go": "go:2:1", "new.cpp": "cpp:1:1"}, + ) + if got := changeSummary(changes); got != "changed.go:modified:go,new.cpp:added:cpp,old.py:deleted:python" { + t.Fatalf("unexpected source changes: %s (%+v)", got, changes) + } +} + +func TestSourceWatcherFiltersRelevantEvents(t *testing.T) { + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + writeFile(t, repo, "web/app.ts", "export function render() { return 1 }\n") + writeFile(t, repo, "README.md", "# ignored\n") + allowed := languageSet([]string{"typescript"}) + + if sourceEventRelevant(repo, filepath.Join(repo, "main.go"), allowed, nil) { + t.Fatal("Go event should be ignored when only TypeScript is allowed") + } + if !sourceEventRelevant(repo, filepath.Join(repo, "web", "app.ts"), allowed, nil) { + t.Fatal("TypeScript event should be relevant") + } + if sourceEventRelevant(repo, filepath.Join(repo, "README.md"), allowed, nil) { + t.Fatal("non-source event should be ignored") + } + + ctx := t.Context() + watcher := newSourceWatcher(ctx, repo, Settings{Watcher: WatcherPoll}, nil) + if watcher.Mode != WatcherPoll || watcher.Events != nil { + t.Fatalf("poll watcher should not create fs event channel, got %+v", watcher) + } +} + +func TestWatchDiffsCaptureWorkspaceResourceChanges(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + firstDiffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if connector := findDiff(firstDiffs, "connector", "added"); connector == nil || connector.Summary == nil || !strings.Contains(*connector.Summary, "->") { + t.Fatalf("expected connector diff summary to include endpoint arrow, got %+v", connector) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "first commit", "", "main", rep.RepresentationHash, nil, firstDiffs); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() + other() +} + +func helper() {} +func other() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, next.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if !hasDiff(diffs, "symbol", "added") || !hasDiff(diffs, "file", "updated") || !hasDiff(diffs, "element", "added") { + t.Fatalf("expected symbol/file/element diffs, got %+v", diffs) + } +} + +func TestInitialWatchDiffsUseInitializedForCleanHeadResources(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") 
+ + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if findDiffByOwner(diffs, "file", "main.go", "file", "initialized") == nil { + t.Fatalf("expected clean HEAD file to be initialized, got %+v", diffs) + } + if hasDiff(diffs, "file", "added") || hasDiff(diffs, "symbol", "added") || hasDiff(diffs, "element", "added") { + t.Fatalf("clean HEAD initial diff should not mark resources added, got %+v", diffs) + } +} + +func TestInitialWatchDiffsClassifyWorktreeChangesAgainstHead(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + writeFile(t, repo, "untouched.go", `package main + +func Untouched() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + writeFile(t, repo, "main.go", `package main + +func Main() {} +func Dirty() {} +`) + writeFile(t, repo, "new.go", `package main + +func NewFile() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if findDiffByOwner(diffs, "file", "main.go", "file", "updated") == nil { + t.Fatalf("expected modified tracked file to be updated, got %+v", diffs) + } + if findDiffByOwner(diffs, "file", "new.go", "file", "added") == nil { + t.Fatalf("expected untracked file to be added, got %+v", diffs) + } + if findDiffByOwner(diffs, "file", "untouched.go", "file", "initialized") != nil { + t.Fatalf("expected untouched tracked file to be suppressed, got %+v", diffs) + } + for _, diff := range diffs { + if diff.ChangeType == "initialized" && diff.OwnerType != "repository" { + t.Fatalf("expected only repository initialized diff during dirty initial scan, got %+v", diffs) + } + } +} + +func TestInitialDirtyWatchDiffsOnlyEmitDirtyAttributedConnectors(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "alpha.go", `package main + +func Alpha() { + alphaHelper() +} + +func alphaHelper() {} +`) + writeFile(t, repo, "beta.go", `package main + +func Beta() {} +func betaHelper() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + writeFile(t, repo, "beta.go", `package main + +func Beta() { + betaHelper() +} + +func betaHelper() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + var connectorDiffs 
[]RepresentationDiff + for _, diff := range diffs { + if diff.ResourceType != nil && *diff.ResourceType == "connector" { + connectorDiffs = append(connectorDiffs, diff) + if strings.Contains(diff.OwnerKey, "alpha.go") { + t.Fatalf("expected clean alpha connector to be suppressed, got %+v in %+v", diff, diffs) + } + } + } + if len(connectorDiffs) == 0 { + t.Fatalf("expected dirty beta connector diff, got %+v", diffs) + } + if findDiffByOwner(diffs, "file", "alpha.go", "file", "initialized") != nil { + t.Fatalf("expected clean alpha file to be suppressed, got %+v", diffs) + } + if findDiffByOwner(diffs, "file", "beta.go", "file", "updated") == nil { + t.Fatalf("expected dirty beta file update, got %+v", diffs) + } +} + +func TestWatchDiffsIncludeElementLineDeltas(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + firstDiffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "first commit", "", "main", rep.RepresentationHash, nil, firstDiffs); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() + helper() +} + +func helper() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, next.RepresentationHash) + if err != nil { + t.Fatal(err) + } + for _, diff := range diffs { + if diff.ResourceType != nil && *diff.ResourceType == "element" && diff.ChangeType == "updated" && diff.AddedLines == 1 { + return + } + } + t.Fatalf("expected updated element diff with +1 line, got %+v", diffs) +} + +func TestWatchDiffsMaterializeChangedPackageManifests(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "frontend/package.json", `{ + "name": "@tldiagram/core-ui", + "dependencies": { + "@buf/tldiagramcom_diagram.bufbuild_es": "^2.11.0" + } +} +`) + writeFile(t, repo, "frontend/package-lock.json", `{ + "name": "@tldiagram/core-ui", + "packages": { + "": { + "dependencies": { + "@buf/tldiagramcom_diagram.bufbuild_es": "^2.11.0" + } + } + } +} +`) + writeFile(t, repo, "frontend/src/App.tsx", `export function App() { + return null +} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial frontend") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "initial 
frontend", "", "main", rep.RepresentationHash, nil, nil); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "frontend/package.json", `{ + "name": "@tldiagram/core-ui", + "dependencies": { + "@buf/tldiagramcom_diagram.bufbuild_es": "^2.12.0" + } +} +`) + writeFile(t, repo, "frontend/package-lock.json", `{ + "name": "@tldiagram/core-ui", + "packages": { + "": { + "dependencies": { + "@buf/tldiagramcom_diagram.bufbuild_es": "^2.12.0" + } + } + } +} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, next.RepresentationHash) + if err != nil { + t.Fatal(err) + } + for _, path := range []string{"frontend/package.json", "frontend/package-lock.json"} { + diff := findDiffByOwner(diffs, "file", "file:"+path, "element", "updated") + if diff == nil { + t.Fatalf("expected changed manifest %s to produce updated file element diff, got %+v", path, diffs) + } + if diff.AddedLines == 0 || diff.RemovedLines == 0 { + t.Fatalf("expected accurate line delta for %s, got %+v", path, diff) + } + } +} + +func TestWatchDiffsMaterializeChangedHiddenSymbolAsUpdated(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "frontend/src/pages/ViewsGrid.tsx", `function viewGridInner() { + return 'grid' +} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial hidden symbol") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "initial hidden symbol", "", "main", rep.RepresentationHash, nil, nil); err != nil { + t.Fatal(err) + } + if count := materializationOwnerTypeCount(t, db, "symbol"); count != 0 { + t.Fatalf("expected hidden symbol to be omitted from the baseline materialization, got %d symbol mappings", count) + } + + writeFile(t, repo, "frontend/src/pages/ViewsGrid.tsx", `function viewGridInner() { + const mode = 'grid' + return mode +} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, next.RepresentationHash) + if err != nil { + t.Fatal(err) + } + ownerKey := "typescript:frontend/src/pages/ViewsGrid.tsx:function:viewGridInner" + diff := findDiffByOwner(diffs, "symbol", ownerKey, "element", "updated") + if diff == nil { + t.Fatalf("expected changed hidden symbol to produce updated symbol element diff, got %+v", diffs) + } + if diff.AddedLines != 2 || diff.RemovedLines != 1 { + t.Fatalf("expected changed hidden symbol to report exact line diff, got %+v", diff) + } +} + +func TestRepresentForcesChangedSymbolReferenceEndpoints(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + 
writeFile(t, repo, "main.go", `package main + +func callerOne() { + sharedEndpoint() +} + +func callerTwo() { + sharedEndpoint() +} + +func changedHidden() string { + return "quiet" +} + +func sharedEndpoint() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial context") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{ + Embedding: EmbeddingConfig{Provider: "none"}, + Thresholds: Thresholds{ + MaxIncomingPerElement: 1, + }, + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, req) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "initial context", "", "main", rep.RepresentationHash, nil, nil); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "main.go", `package main + +func callerOne() { + sharedEndpoint() +} + +func callerTwo() { + sharedEndpoint() +} + +func changedHidden() string { + sharedEndpoint() + return "changed" +} + +func sharedEndpoint() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, req); err != nil { + t.Fatal(err) + } + + shared, err := symbolsByName(context.Background(), store, scan.RepositoryID, "sharedEndpoint") + if err != nil { + t.Fatal(err) + } + decisions, err := store.FilterDecisions(context.Background(), scan.RepositoryID, FilterDecisionQuery{Decision: "visible"}) + if err != nil { + t.Fatal(err) + } + if !filterDecisionHasReason(decisions, shared.ID, "endpoint of changed symbol") { + t.Fatalf("expected shared endpoint to be forced visible, got decisions %+v", decisions) + } + var connectorID int64 + err = db.QueryRow(` + SELECT c.id + FROM connectors c + JOIN elements s ON s.id = c.source_element_id + JOIN elements t ON t.id = c.target_element_id + WHERE s.name = 'changedHidden' AND t.name = 'sharedEndpoint'`).Scan(&connectorID) + if err != nil { + t.Fatalf("expected connector from changed symbol to forced endpoint: %v", err) + } +} + +func TestAddedRawSymbolIsForcedVisibleSinceLatestVersion(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}} + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, req) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "initial", "", "main", rep.RepresentationHash, nil, nil); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "internal/quiet.go", `package internal + +func quietAdded() string { + return "new" +} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, req); err != nil { + t.Fatal(err) + } + + added, err := symbolsByName(context.Background(), store, scan.RepositoryID, "quietAdded") + if err != nil { + t.Fatal(err) + } + decisions, err := store.FilterDecisions(context.Background(), scan.RepositoryID, FilterDecisionQuery{Decision: "visible"}) + if err != nil { 
+ t.Fatal(err) + } + if !filterDecisionHasReason(decisions, added.ID, "added since latest watch version") { + t.Fatalf("expected added private symbol to be forced visible, got decisions %+v", decisions) + } +} + +func TestSourceChangeRepresentationChangedIsPerFile(t *testing.T) { + element := "element" + diffs := []RepresentationDiff{ + {OwnerType: "repository", OwnerKey: "1", ChangeType: "updated"}, + {OwnerType: "symbol", OwnerKey: "go:changed.go:function:Changed", ChangeType: "updated", ResourceType: &element}, + } + if !sourceChangeRepresentationChanged(SourceFileChange{Path: "changed.go", ChangeType: "updated"}, diffs) { + t.Fatalf("expected changed.go to be attributed to its symbol diff") + } + if sourceChangeRepresentationChanged(SourceFileChange{Path: "unchanged.go", ChangeType: "updated"}, diffs) { + t.Fatalf("unchanged.go should not inherit another file's representation diff") + } +} + +func TestWatchDiffsAttributeLineDiffsToSymbolRanges(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Alpha() string { + return "alpha" +} + +func Beta() string { + return "beta" +} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial symbols") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "initial symbols", "", "main", rep.RepresentationHash, nil, nil); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "main.go", `package main + +func Alpha() string { + value := "alpha" + return value +} + +func Beta() string { + first := "be" + second := "ta" + return first + second +} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, next.RepresentationHash) + if err != nil { + t.Fatal(err) + } + + alpha := findDiffByOwner(diffs, "symbol", "go:main.go:function:Alpha", "symbol", "updated") + if alpha == nil { + t.Fatalf("expected Alpha symbol diff, got %+v", diffs) + } + if alpha.AddedLines != 2 || alpha.RemovedLines != 1 { + t.Fatalf("expected Alpha to receive only its hunk lines, got %+v", alpha) + } + beta := findDiffByOwner(diffs, "symbol", "go:main.go:function:Beta", "symbol", "updated") + if beta == nil { + t.Fatalf("expected Beta symbol diff, got %+v", diffs) + } + if beta.AddedLines != 3 || beta.RemovedLines != 1 { + t.Fatalf("expected Beta to receive only its hunk lines, got %+v", beta) + } +} + +func TestCreateVersionForHeadCanBaselineAlreadyRepresentedCommit(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := 
NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err := gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + runner := &Runner{Store: store} + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, rep.RepresentationHash, false); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "main.go", `package main + +func Main() {} +func Other() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + pendingDiffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, next.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if !hasDiff(pendingDiffs, "element", "added") { + t.Fatalf("expected uncommitted representation to have pending element diff, got %+v", pendingDiffs) + } + + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "add other") + status, err = gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, next.RepresentationHash, false); err != nil { + t.Fatal(err) + } + latest, found, err := store.LatestWatchVersion(context.Background(), scan.RepositoryID) + if err != nil { + t.Fatal(err) + } + if !found || latest.CommitHash != status.HeadCommit { + t.Fatalf("expected latest version for committed head, got found=%v version=%+v status=%+v", found, latest, status) + } + committedDiffs, err := store.WatchDiffs(context.Background(), latest.ID, "", "", "", "", 200) + if err != nil { + t.Fatal(err) + } + if len(committedDiffs) != 0 { + t.Fatalf("expected committed baseline version to have no pending diffs, got %+v", committedDiffs) + } + if latest.CommitMessage != "add other" { + t.Fatalf("expected commit message to be stored, got %q", latest.CommitMessage) + } +} + +func TestCreateVersionForHeadStoresDirtyHeadDiffsAndMetadata(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err := gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + runner := &Runner{Store: store} + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, rep.RepresentationHash, false); err != nil { + t.Fatal(err) + } + firstHead := status.HeadCommit + + writeFile(t, repo, "main.go", `package main + +func Main() {} +func Committed() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "add committed") + intermediateHead, err := tldgit.DetectHeadCommit(repo) + if err != nil { + t.Fatal(err) + } + writeFile(t, repo, "main.go", `package main + +func Main() {} +func Committed() {} +func SecondCommitted() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "add second committed") + writeFile(t, repo, 
"main.go", `package main + +func Main() {} +func Committed() {} +func SecondCommitted() {} +func Dirty() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + dirtyRep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err = gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + if gitStatusClean(status) { + t.Fatalf("test setup should have a dirty worktree: %+v", status) + } + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, dirtyRep.RepresentationHash, false); err != nil { + t.Fatal(err) + } + + latest, found, err := store.LatestWatchVersion(context.Background(), scan.RepositoryID) + if err != nil { + t.Fatal(err) + } + if !found || latest.CommitHash != status.HeadCommit || latest.CommitMessage != "add second committed" || latest.ParentCommitHash != intermediateHead || latest.Branch == "" || latest.WorkspaceVersionID == nil { + t.Fatalf("dirty head version metadata was not stored correctly: found=%v latest=%+v status=%+v first=%s intermediate=%s", found, latest, status, firstHead, intermediateHead) + } + diffs, err := store.WatchDiffs(context.Background(), latest.ID, "", "", "", "", 200) + if err != nil { + t.Fatal(err) + } + if !hasDiff(diffs, "element", "added") { + t.Fatalf("expected dirty head snapshot to retain pending representation diffs, got %+v", diffs) + } +} + +func TestCreateWatchVersionRetainsOnlyFiveRecentSnapshots(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + resourceType := "element" + resourceID := int64(1) + for i := 1; i <= 6; i++ { + after := fmt.Sprintf("after-%d", i) + summary := fmt.Sprintf("snapshot %d", i) + diffs := []RepresentationDiff{{ + OwnerType: "symbol", + OwnerKey: fmt.Sprintf("go:main.go:function:Main%d", i), + ChangeType: "added", + AfterHash: &after, + ResourceType: &resourceType, + ResourceID: &resourceID, + Summary: &summary, + AddedLines: 1, + }} + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, fmt.Sprintf("commit-%d", i), fmt.Sprintf("commit %d", i), "", "main", fmt.Sprintf("%s-%d", rep.RepresentationHash, i), nil, diffs); err != nil { + t.Fatal(err) + } + } + + versions, err := store.WatchVersions(context.Background(), scan.RepositoryID, 100) + if err != nil { + t.Fatal(err) + } + if len(versions) != 5 { + t.Fatalf("expected only five retained watch versions, got %d: %+v", len(versions), versions) + } + for i, version := range versions { + expected := fmt.Sprintf("commit-%d", 6-i) + if version.CommitHash != expected { + t.Fatalf("expected retained version %d to be %s, got %+v", i, expected, version) + } + } + var oldestDiffs int + if err := db.QueryRow(` + SELECT COUNT(*) + FROM watch_representation_diffs d + JOIN watch_versions v ON v.id = d.version_id + WHERE v.repository_id = ? 
AND v.commit_hash = 'commit-1'`, scan.RepositoryID).Scan(&oldestDiffs); err != nil { + t.Fatal(err) + } + if oldestDiffs != 0 { + t.Fatalf("expected oldest snapshot diffs to be pruned, found %d", oldestDiffs) + } + var resources int + if err := db.QueryRow(`SELECT COUNT(*) FROM watch_version_resources`).Scan(&resources); err != nil { + t.Fatal(err) + } + if resources == 0 { + t.Fatal("expected retained snapshots to keep version resources") + } + var materializedElements int + if err := db.QueryRow(`SELECT COUNT(*) FROM elements`).Scan(&materializedElements); err != nil { + t.Fatal(err) + } + if materializedElements == 0 { + t.Fatal("snapshot pruning should not delete current materialized workspace resources") + } +} + +func TestDeletedFileTombstonesMaterializedResourcesAndDiffs(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + helper() +} + +func helper() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + if _, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "initial", "", "main", rep.RepresentationHash, nil, nil); err != nil { + t.Fatal(err) + } + if err := os.Remove(filepath.Join(repo, "main.go")); err != nil { + t.Fatal(err) + } + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + status, err := gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.ApplyGitTags(context.Background(), scan.RepositoryID, status); err != nil { + t.Fatal(err) + } + + summary, err := store.Summary(context.Background(), scan.RepositoryID) + if err != nil { + t.Fatal(err) + } + if summary.Files != 0 || summary.Symbols != 0 { + t.Fatalf("raw graph should remove deleted file and symbols, got %+v", summary) + } + if tagged := countElementTag(t, db, "watch:deleted"); tagged == 0 { + t.Fatal("expected tombstoned materialized resources to receive watch:deleted") + } + if count := materializationOwnerTypeCount(t, db, "file"); count == 0 { + t.Fatal("expected deleted file materialization mapping to be retained as tombstone") + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + if !hasDiff(diffs, "file", "deleted") || !hasDiff(diffs, "symbol", "deleted") || !hasDiff(diffs, "element", "deleted") { + t.Fatalf("expected deleted raw and materialized diffs, got %+v", diffs) + } +} + +func TestRestoredDeletedFileRemovesTombstoneTagsAndReusesResources(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + source := `package main + +func Main() {} +` + writeFile(t, repo, "main.go", source) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := 
NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + fileElementID, ok, err := store.MappingResourceID(context.Background(), scan.RepositoryID, "file", "file:main.go", "element") + if err != nil || !ok { + t.Fatalf("expected file element mapping, ok=%v err=%v", ok, err) + } + if err := os.Remove(filepath.Join(repo, "main.go")); err != nil { + t.Fatal(err) + } + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + status, err := gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.ApplyGitTags(context.Background(), scan.RepositoryID, status); err != nil { + t.Fatal(err) + } + if tagged := countElementTag(t, db, "watch:deleted"); tagged == 0 { + t.Fatal("expected deletion to create tombstone tag") + } + + writeFile(t, repo, "main.go", source) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + status, err = gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.ApplyGitTags(context.Background(), scan.RepositoryID, status); err != nil { + t.Fatal(err) + } + nextFileElementID, ok, err := store.MappingResourceID(context.Background(), scan.RepositoryID, "file", "file:main.go", "element") + if err != nil || !ok { + t.Fatalf("expected restored file element mapping, ok=%v err=%v", ok, err) + } + if nextFileElementID != fileElementID { + t.Fatalf("expected restored file to reuse element %d, got %d", fileElementID, nextFileElementID) + } + if tagged := countElementTag(t, db, "watch:deleted"); tagged != 0 { + t.Fatalf("expected restore to remove watch:deleted, found %d tagged elements", tagged) + } +} + +func TestCleanHeadPrunesDeletedMaterializedTombstones(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err := gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + runner := &Runner{Store: store} + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, rep.RepresentationHash, false); err != nil { + t.Fatal(err) + } + + if err := os.Remove(filepath.Join(repo, "main.go")); err != nil { + t.Fatal(err) + } + runGit(t, repo, "add", "-u") + runGit(t, repo, "commit", "-m", "delete main") + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err = gitStatusSnapshot(repo) + if err != 
nil { + t.Fatal(err) + } + if !gitStatusClean(status) { + t.Fatalf("test setup should have clean status after deletion commit: %+v", status) + } + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, next.RepresentationHash, false); err != nil { + t.Fatal(err) + } + + if count := materializationOwnerTypeCount(t, db, "file"); count != 0 { + t.Fatalf("expected clean baseline to prune deleted file mappings, got %d", count) + } + if count := materializationOwnerTypeCount(t, db, "symbol"); count != 0 { + t.Fatalf("expected clean baseline to prune deleted symbol mappings, got %d", count) + } + if tagged := countElementTag(t, db, "watch:deleted"); tagged != 0 { + t.Fatalf("expected clean baseline cleanup to remove tombstone tags with resources, found %d", tagged) + } + latest, found, err := store.LatestWatchVersion(context.Background(), scan.RepositoryID) + if err != nil { + t.Fatal(err) + } + if !found || latest.CommitHash != status.HeadCommit { + t.Fatalf("expected clean deletion baseline version, found=%v latest=%+v status=%+v", found, latest, status) + } + diffs, err := store.WatchDiffs(context.Background(), latest.ID, "", "", "", "", 200) + if err != nil { + t.Fatal(err) + } + if len(diffs) != 0 { + t.Fatalf("expected clean baseline to store no pending diffs, got %+v", diffs) + } +} + +func TestDirtyHeadRetainsDeletedMaterializedTombstonesAndDiffs(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "gone.go", `package main + +func Gone() {} +`) + writeFile(t, repo, "keep.go", `package main + +func Keep() {} +`) + runGit(t, repo, "add", ".") + runGit(t, repo, "commit", "-m", "initial") + + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err := gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + runner := &Runner{Store: store} + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, rep.RepresentationHash, false); err != nil { + t.Fatal(err) + } + + if err := os.Remove(filepath.Join(repo, "gone.go")); err != nil { + t.Fatal(err) + } + writeFile(t, repo, "keep.go", `package main + +func Keep() {} +func Added() {} +`) + runGit(t, repo, "add", "keep.go") + runGit(t, repo, "commit", "-m", "add keep symbol") + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + next, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + status, err = gitStatusSnapshot(repo) + if err != nil { + t.Fatal(err) + } + if gitStatusClean(status) || len(status.Deleted) == 0 { + t.Fatalf("test setup should retain dirty deleted file after HEAD change: %+v", status) + } + if _, err := store.ApplyGitTags(context.Background(), scan.RepositoryID, status); err != nil { + t.Fatal(err) + } + if err := runner.createVersionForHead(context.Background(), scan.RepositoryID, status, next.RepresentationHash, false); err != nil { + t.Fatal(err) + } + + if tagged := countElementTag(t, db, "watch:deleted"); tagged == 0 { + t.Fatal("expected dirty head to retain deleted tombstones") + } + if count := materializationOwnerTypeCount(t, db, "file"); count 
== 0 { + t.Fatal("expected dirty head to retain deleted file mapping") + } + latest, found, err := store.LatestWatchVersion(context.Background(), scan.RepositoryID) + if err != nil { + t.Fatal(err) + } + if !found || latest.CommitHash != status.HeadCommit { + t.Fatalf("expected dirty head snapshot, found=%v latest=%+v status=%+v", found, latest, status) + } + diffs, err := store.WatchDiffs(context.Background(), latest.ID, "", "", "", "", 200) + if err != nil { + t.Fatal(err) + } + if !hasDiff(diffs, "file", "deleted") || !hasDiff(diffs, "element", "deleted") { + t.Fatalf("expected dirty head snapshot to retain deleted diffs, got %+v", diffs) + } +} + +func TestWatchDiffsFilterByResourceTypeAndLanguage(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() {} +`) + store := NewStore(db) + scan, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + rep, err := NewRepresenter(store).Represent(context.Background(), scan.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}) + if err != nil { + t.Fatal(err) + } + diffs, err := store.BuildWatchDiffs(context.Background(), scan.RepositoryID, rep.RepresentationHash) + if err != nil { + t.Fatal(err) + } + version, err := store.CreateWatchVersion(context.Background(), scan.RepositoryID, "commit1", "first commit", "", "main", rep.RepresentationHash, nil, diffs) + if err != nil { + t.Fatal(err) + } + + symbolDiffs, err := store.WatchDiffs(context.Background(), version.ID, "", "added", "symbol", "go", 200) + if err != nil { + t.Fatal(err) + } + if len(symbolDiffs) == 0 { + t.Fatalf("expected Go symbol diffs, got none from %+v", diffs) + } + for _, diff := range symbolDiffs { + if diff.ResourceType == nil || *diff.ResourceType != "symbol" || diff.ChangeType != "added" || diff.Language == nil || *diff.Language != "go" { + t.Fatalf("diff did not satisfy filters: %+v", diff) + } + } + + none, err := store.WatchDiffs(context.Background(), version.ID, "", "", "symbol", "python", 200) + if err != nil { + t.Fatal(err) + } + if len(none) != 0 { + t.Fatalf("expected no Python symbol diffs, got %+v", none) + } +} + +func TestRepresentInitialLayoutFollowsConnectors(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func A() { + B() + C() +} + +func B() {} +func C() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + a := functionPlacement(t, db, "A") + b := functionPlacement(t, db, "B") + c := functionPlacement(t, db, "C") + if b.x <= a.x || c.x <= a.x { + t.Fatalf("initial layout should place callees to the right of caller: A=%+v B=%+v C=%+v", a, b, c) + } + if b.x == c.x && b.y == c.y { + t.Fatalf("initial layout overlapped connected callees: B=%+v C=%+v", b, c) + } +} + +func TestOrganicWatchLayoutCapsRowsPerColumnWithinDirectedLevels(t *testing.T) { + targets := map[int64]struct{}{} + for id := int64(1); id <= int64(watchLayoutMaxRowsPerColumn+5); id++ { + targets[id] = struct{}{} + } + positions := organicWatchLayout(targets, []watchLayoutConnector{{Source: 1, Target: int64(watchLayoutMaxRowsPerColumn 
+ 5)}}) + + rowsByColumn := map[int]int{} + for _, position := range positions { + column := int(position.X / watchLayoutGapX) + rowsByColumn[column]++ + } + for column, rows := range rowsByColumn { + if rows > watchLayoutMaxRowsPerColumn { + t.Fatalf("column %d has %d rows, want at most %d: %+v", column, rows, watchLayoutMaxRowsPerColumn, positions) + } + } + if positions[int64(watchLayoutMaxRowsPerColumn+5)].X <= positions[1].X { + t.Fatalf("directed target should remain to the right of source: source=%+v target=%+v", positions[1], positions[int64(watchLayoutMaxRowsPerColumn+5)]) + } +} + +func TestRepresentRelayoutsFreshPlacementsWithExistingMappings(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func A() { + B() + C() +} + +func B() {} +func C() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`DELETE FROM placements`); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + a := functionPlacement(t, db, "A") + b := functionPlacement(t, db, "B") + c := functionPlacement(t, db, "C") + if b.x <= a.x || c.x <= a.x { + t.Fatalf("fresh placements with existing mappings should use full layout: A=%+v B=%+v C=%+v", a, b, c) + } + if b.x == c.x && b.y == c.y { + t.Fatalf("fresh placements with existing mappings overlapped connected callees: B=%+v C=%+v", b, c) + } +} + +func TestRepresentIncrementalLayoutPreservesExistingPlacements(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func A() { + B() +} + +func B() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + b := functionPlacement(t, db, "B") + if _, err := db.Exec(`UPDATE placements SET position_x = 780, position_y = 510 WHERE id = ?`, b.placementID); err != nil { + t.Fatal(err) + } + + writeFile(t, repo, "main.go", `package main + +func A() { + B() + C() +} + +func B() {} +func C() {} +`) + if _, err := NewScanner(store).Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + b = functionPlacement(t, db, "B") + c := functionPlacement(t, db, "C") + if b.x != 780 || b.y != 510 { + t.Fatalf("incremental layout moved existing placement B: %+v", b) + } + if c.x == b.x && c.y == b.y { + t.Fatalf("incremental layout placed new function on occupied B cell: B=%+v C=%+v", b, c) + } +} + +func TestRepresentDoesNotTouchManualResources(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc 
Main() {}\n") + res, err := db.Exec(`INSERT INTO elements(name, tags, technology_connectors, created_at, updated_at) VALUES ('Manual', '[]', '[]', 'now', 'now')`) + if err != nil { + t.Fatal(err) + } + manualID, _ := res.LastInsertId() + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + var name string + if err := db.QueryRow(`SELECT name FROM elements WHERE id = ?`, manualID).Scan(&name); err != nil { + t.Fatal(err) + } + if name != "Manual" { + t.Fatalf("manual element was changed to %q", name) + } +} + +func TestRepresentAssignsUsefulSemanticTags(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "internal/watch/scan.go", `package watch + +func ScanRepository() { + RepresentRepository() +} + +func RepresentRepository() {} +`) + writeFile(t, repo, "internal/server/http.go", `package server + +func ServeAPI() {} +`) + writeFile(t, repo, "cmd/tld/main.go", `package main + +func ExecuteCLI() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + for _, tag := range []string{"tld:watch", "watch:generated", "watch:go", "lang:go"} { + if count := countElementTag(t, db, tag); count != 0 { + t.Fatalf("expected unhelpful tag %q to be omitted, found on %d elements", tag, count) + } + } + for _, tag := range []string{"role:watch", "area:internal", "kind:function", "graph:entrypoint"} { + count := countElementTag(t, db, tag) + if strings.HasPrefix(tag, "role:") { + if count < 2 { + t.Fatalf("expected useful role tag %q on multiple elements, found %d", tag, count) + } + continue + } + if count != 0 { + t.Fatalf("expected non-role generated tag %q to be omitted, found on %d elements", tag, count) + } + } + + tags := elementTagsByName(t, db, "ScanRepository") + for _, tag := range []string{"role:watch"} { + if !stringSliceContains(tags, tag) { + t.Fatalf("expected ScanRepository to include %q, got %v", tag, tags) + } + } + for _, tag := range []string{"area:internal", "kind:function", "graph:entrypoint"} { + if stringSliceContains(tags, tag) { + t.Fatalf("expected ScanRepository to omit non-role generated tag %q, got %v", tag, tags) + } + } +} + +func TestRepresentDoesNotOverwriteUserTagMetadata(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "internal/watch/scan.go", `package watch + +func ScanRepository() {} + +func RepresentRepository() {} +`) + writeFile(t, repo, "internal/watch/runner.go", `package watch + +func RunWatch() {} +`) + writeFile(t, repo, "internal/server/http.go", `package server + +func ServeAPI() {} +`) + writeFile(t, repo, "cmd/tld/main.go", `package main + +func ExecuteCLI() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}} + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, 
req); err != nil { + t.Fatal(err) + } + if count := countElementTag(t, db, "role:watch"); count == 0 { + t.Fatal("expected role:watch tag to be generated") + } + userDescription := "User picked this color" + if _, err := db.Exec(`UPDATE tags SET color = ?, description = ? WHERE name = ?`, "#123456", userDescription, "role:watch"); err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + + color, description := tagMetadataByName(t, db, "role:watch") + if color != "#123456" || description == nil || *description != userDescription { + t.Fatalf("role:watch metadata = color:%q description:%v, want user metadata preserved", color, description) + } +} + +func TestRepresentAssignsCodeownersTags(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "CODEOWNERS", ` +/frontend/* @org/web-team:random(2) +/backend/* @backend @org/backend:least_busy(3) +`) + writeFile(t, repo, "frontend/app.go", `package frontend + +func Render() {} +`) + writeFile(t, repo, "backend/server.go", `package backend + +func Serve() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + for _, name := range []string{"frontend", "app.go", "Render"} { + tags := elementTagsByName(t, db, name) + if stringSliceContains(tags, "owner:@org/web-team") { + t.Fatalf("expected %s to omit non-role CODEOWNERS tag, got %v", name, tags) + } + if stringSliceContains(tags, "owner:@org/web-team:random(2)") { + t.Fatalf("expected %s extended assignment suffix to be stripped, got %v", name, tags) + } + } + backendTags := elementTagsByName(t, db, "Serve") + for _, tag := range []string{"owner:@backend", "owner:@org/backend"} { + if stringSliceContains(backendTags, tag) { + t.Fatalf("expected backend symbol to omit non-role CODEOWNERS tag %q, got %v", tag, backendTags) + } + } + if count := countElementTag(t, db, "owner:@org/web-team"); count != 0 { + t.Fatalf("expected non-role CODEOWNERS tag to be omitted, found on %d elements", count) + } +} + +func TestLargeRepresentationPrunesDetailedSymbolElements(t *testing.T) { + previousLimit := maxDetailedSymbolElements + maxDetailedSymbolElements = 100 + defer func() { maxDetailedSymbolElements = previousLimit }() + + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "pkg/busy.go", `package pkg + +func Func0() {} +func Func1() {} +func Func2() {} +func Func3() {} +func Func4() {} +`) + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + req := RepresentRequest{ + Embedding: EmbeddingConfig{Provider: "none"}, + Thresholds: Thresholds{ + MaxElementsPerView: 2, + MaxConnectorsPerView: 2, + }, + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); err != nil { + t.Fatal(err) + } + if count := elementKindCount(t, db, "function"); count != 5 { + t.Fatalf("expected detailed symbol elements before large-mode pruning, got %d", count) + } + + maxDetailedSymbolElements = 3 + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, req); 
err != nil { + t.Fatal(err) + } + if count := elementKindCount(t, db, "function"); count != 0 { + t.Fatalf("expected large-mode rerun to prune detailed symbol elements, got %d", count) + } + if count := materializationOwnerTypeCount(t, db, "symbol"); count != 0 { + t.Fatalf("expected stale symbol materialization mappings to be pruned, got %d", count) + } + if count := elementKindCount(t, db, "cluster"); count == 0 { + t.Fatalf("expected cluster elements to summarize the large file") + } +} + +func TestEmbeddingCandidateSymbolsAreCappedDeterministically(t *testing.T) { + symbols := map[int64]Symbol{ + 3: {ID: 3, StableKey: "go:b.go:function:C", FilePath: "b.go", StartLine: 1}, + 1: {ID: 1, StableKey: "go:a.go:function:A", FilePath: "a.go", StartLine: 10}, + 2: {ID: 2, StableKey: "go:a.go:function:B", FilePath: "a.go", StartLine: 2}, + } + candidates := embeddingCandidateSymbols(symbols, 2) + if len(candidates) != 2 { + t.Fatalf("expected capped candidates, got %d", len(candidates)) + } + if candidates[0].ID != 2 || candidates[1].ID != 1 { + t.Fatalf("unexpected candidate order: %+v", candidates) + } +} + +func TestApplyGitTagsReportsAddedAndRemovedTags(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc Main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + + first, err := store.ApplyGitTags(context.Background(), scanResult.RepositoryID, GitStatus{Untracked: []string{"main.go"}}) + if err != nil { + t.Fatal(err) + } + if first.TagsAdded == 0 || first.TagsRemoved != 0 { + t.Fatalf("expected untracked tags to be added only, got %+v", first) + } + if tagged := countElementTag(t, db, "git:untracked"); tagged == 0 { + t.Fatalf("expected git:untracked on generated elements") + } + + second, err := store.ApplyGitTags(context.Background(), scanResult.RepositoryID, GitStatus{}) + if err != nil { + t.Fatal(err) + } + if second.TagsAdded != 0 || second.TagsRemoved != first.TagsAdded { + t.Fatalf("expected stale git tags to be removed, first=%+v second=%+v", first, second) + } + if tagged := countElementTag(t, db, "git:untracked"); tagged != 0 { + t.Fatalf("expected git:untracked to be removed, found %d tagged elements", tagged) + } +} + +func TestEmbeddingCacheAvoidsProviderCalls(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + store := NewStore(db) + provider := &countingProvider{} + model := provider.ModelID() + modelID, err := store.EnsureEmbeddingModel(context.Background(), EmbeddingConfig{Provider: model.Provider, Model: model.Model, Dimension: model.Dimension}, model.ConfigHash) + if err != nil { + t.Fatal(err) + } + symbols := map[int64]Symbol{ + 1: {ID: 1, StableKey: "go:a.go:function:A", QualifiedName: "A", Kind: "function", FilePath: "a.go"}, + 2: {ID: 2, StableKey: "go:b.go:function:B", QualifiedName: "B", Kind: "function", FilePath: "b.go"}, + } + representer := NewRepresenter(store) + stats, _, err := representer.cacheEmbeddings(context.Background(), modelID, provider, "", []Symbol{ + symbols[1], + symbols[2], + }, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + if stats.Created != 2 { + t.Fatalf("expected two embeddings created, got %+v", stats) + } + if provider.calls 
!= 1 || provider.inputs != 2 { + t.Fatalf("expected one batched provider call for two inputs, got calls=%d inputs=%d", provider.calls, provider.inputs) + } + stats, _, err = representer.cacheEmbeddings(context.Background(), modelID, provider, "", []Symbol{ + symbols[1], + symbols[2], + }, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + if stats.CacheHits != 2 { + t.Fatalf("expected two embedding cache hits, got %+v", stats) + } + if provider.calls != 1 { + t.Fatalf("cache miss recomputed embeddings, calls=%d", provider.calls) + } +} + +func TestEmbeddingCacheChunksProviderCallsAndReportsProgress(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + store := NewStore(db) + provider := &countingProvider{} + model := provider.ModelID() + modelID, err := store.EnsureEmbeddingModel(context.Background(), EmbeddingConfig{Provider: model.Provider, Model: model.Model, Dimension: model.Dimension}, model.ConfigHash) + if err != nil { + t.Fatal(err) + } + symbols := make([]Symbol, 0, defaultEmbeddingBatchSize*2+1) + for i := range defaultEmbeddingBatchSize*2 + 1 { + name := fmt.Sprintf("Symbol%d", i) + symbols = append(symbols, Symbol{ID: int64(i + 1), StableKey: "go:a.go:function:" + name, QualifiedName: name, Kind: "function", FilePath: "a.go"}) + } + progress := &recordingProgress{} + + stats, _, err := NewRepresenter(store).cacheEmbeddings(context.Background(), modelID, provider, "", symbols, nil, progress, 0) + if err != nil { + t.Fatal(err) + } + if stats.Created != len(symbols) { + t.Fatalf("expected %d embeddings created, got %+v", len(symbols), stats) + } + expectedBatchSizes := fmt.Sprintf("%d,%d,1", defaultEmbeddingBatchSize, defaultEmbeddingBatchSize) + if provider.calls != 3 || strings.Join(provider.batchSizes, ",") != expectedBatchSizes { + t.Fatalf("expected chunked provider calls %s, got calls=%d batchSizes=%v", expectedBatchSizes, provider.calls, provider.batchSizes) + } + expectedProgressTotal := fmt.Sprintf("%d", defaultEmbeddingBatchSize*2+1) + if len(progress.starts) != 2 || progress.starts[0] != "Preparing symbol embeddings:"+expectedProgressTotal || progress.starts[1] != "Embedding symbols:"+expectedProgressTotal { + t.Fatalf("unexpected progress starts: %v", progress.starts) + } + if progress.advances != len(symbols)*2 { + t.Fatalf("expected prepare and embed progress advances, got %d", progress.advances) + } +} + +func TestSymbolEmbeddingTextUsesOutdentedCodeBody(t *testing.T) { + repo := t.TempDir() + writeFile(t, repo, "a.go", `package main + +func Outer() { + if true { + fmt.Println("body") + } +} +`) + end := 6 + text := symbolEmbeddingText(repo, Symbol{ + QualifiedName: "Outer", + Kind: "function", + FilePath: "a.go", + StartLine: 3, + EndLine: &end, + }) + + if !strings.Contains(text, `fmt.Println("body")`) { + t.Fatalf("expected embedding text to include code body, got:\n%s", text) + } + if strings.Contains(text, "Outer\nfunction\na.go") { + t.Fatalf("embedding text fell back to metadata instead of source body:\n%s", text) + } +} + +func TestShrinkEmbeddingTextFitsApproximateTokenBudget(t *testing.T) { + text := shrinkEmbeddingText(strings.Repeat("// comment that should be removed\n", 600) + strings.Repeat("statement := value + otherValue\n", 700)) + if approximateTokenCount(text) > maxEmbeddingInputApproxTokens { + t.Fatalf("expected text within token budget, got %d", approximateTokenCount(text)) + } + if strings.Contains(text, "// comment") { + t.Fatalf("expected low-signal comment lines to be dropped") + } +} + +func 
TestLocalLexicalProviderKeepsRenamedCodeSimilar(t *testing.T) { + provider := LexicalProvider{} + vectors, err := provider.Embed(context.Background(), []EmbeddingInput{ + {Text: `func FetchUser(id string) (*User, error) { + cacheKey := "user:" + id + if cached, ok := cache.Get(cacheKey); ok { + return cached, nil + } + return client.Load(id) +}`}, + {Text: `func LoadAccount(accountID string) (*Account, error) { + cacheKey := "user:" + accountID + if cached, ok := cache.Get(cacheKey); ok { + return cached, nil + } + return client.Load(accountID) +}`}, + {Text: `func WriteAudit(event Event) error { + data, err := json.Marshal(event) + if err != nil { + return err + } + return os.WriteFile("audit.json", data, 0600) +}`}, + }) + if err != nil { + t.Fatal(err) + } + renamed := CosineSimilarity(vectors[0], vectors[1]) + unrelated := CosineSimilarity(vectors[0], vectors[2]) + if renamed < 0.70 { + t.Fatalf("expected renamed implementation to stay similar, got %.3f", renamed) + } + if unrelated >= renamed { + t.Fatalf("expected unrelated implementation below renamed similarity, renamed=%.3f unrelated=%.3f", renamed, unrelated) + } +} + +func TestDefaultEmbeddingConfigUsesLocalOpenAIEndpoint(t *testing.T) { + cfg := NormalizeEmbeddingConfig(EmbeddingConfig{}) + if cfg.Provider != "openai" || cfg.Endpoint != DefaultOpenAIEndpoint || cfg.Model != DefaultOpenAIModel { + t.Fatalf("unexpected default embedding config: %+v", cfg) + } +} + +func TestOpenAIHealthCheckUsesCompatibleEmbeddingsEndpoint(t *testing.T) { + var requestBody struct { + Model string `json:"model"` + Input []string `json:"input"` + } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/embeddings" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + if auth := r.Header.Get("Authorization"); auth == "" { + t.Fatalf("expected authorization header for OpenAI-compatible request") + } + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + t.Fatal(err) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"object":"list","model":"text-embedding-embeddinggemma-300m-qat","data":[{"object":"embedding","index":0,"embedding":[1,0,0]},{"object":"embedding","index":1,"embedding":[0.95,0.05,0]}],"usage":{"prompt_tokens":1,"total_tokens":1}}`)) + })) + defer server.Close() + + cfg, result, err := CheckEmbeddingHealth(context.Background(), EmbeddingConfig{ + Provider: "openai", + Endpoint: server.URL + "/v1/embeddings", + Model: "text-embedding-embeddinggemma-300m-qat", + }) + if err != nil { + t.Fatal(err) + } + if requestBody.Model != "text-embedding-embeddinggemma-300m-qat" || len(requestBody.Input) != 2 { + t.Fatalf("unexpected embeddings request body: %+v", requestBody) + } + if cfg.Dimension != 3 || result.Dimension != 3 || result.Similarity < DefaultEmbeddingHealthThreshold { + t.Fatalf("unexpected health result cfg=%+v result=%+v", cfg, result) + } +} + +func TestOllamaHealthCheckParsesEmbedResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/embed" { + t.Fatalf("unexpected path %s", r.URL.Path) + } + _, _ = w.Write([]byte(`{"embeddings":[[1,0,0],[0.95,0.05,0]]}`)) + })) + defer server.Close() + + cfg, result, err := CheckEmbeddingHealth(context.Background(), EmbeddingConfig{ + Provider: "ollama", + Endpoint: server.URL, + Model: "jina/jina-embeddings-v2-base-en", + }) + if err != nil { + t.Fatal(err) + } + if cfg.Dimension != 3 || 
result.Dimension != 3 || result.Similarity < DefaultEmbeddingHealthThreshold { + t.Fatalf("unexpected health result cfg=%+v result=%+v", cfg, result) + } +} + +func TestSQLiteVecStoresAndQueriesEmbeddings(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + store := NewStore(db) + modelID, err := store.EnsureEmbeddingModel(context.Background(), EmbeddingConfig{Provider: "local-deterministic-test", Model: "vec", Dimension: 3}, "vec") + if err != nil { + t.Fatal(err) + } + if err := store.SaveEmbedding(context.Background(), modelID, "symbol", "a", "a", vectorBytes(Vector{1, 0, 0})); err != nil { + t.Fatal(err) + } + if err := store.SaveEmbedding(context.Background(), modelID, "symbol", "b", "b", vectorBytes(Vector{0, 1, 0})); err != nil { + t.Fatal(err) + } + var shadowRows int + if err := db.QueryRow(`SELECT COUNT(*) FROM _vec_watch_embedding_vec`).Scan(&shadowRows); err != nil { + t.Fatal(err) + } + if shadowRows != 2 { + t.Fatalf("expected sqlite-vec shadow rows, got %d", shadowRows) + } + ids, err := store.SimilarEmbeddings(context.Background(), modelID, Vector{1, 0, 0}, 1) + if err != nil { + t.Fatal(err) + } + if len(ids) != 1 { + t.Fatalf("expected one sqlite-vec match, got %v", ids) + } +} + +func TestRenamePreservesGeneratedSymbolElementAndConnector(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + FetchUser() +} + +func FetchUser() {} +`) + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + beforeElement := symbolElementID(t, db, "FetchUser") + beforeConnectors := connectorCount(t, db) + + writeFile(t, repo, "main.go", `package main + +func Main() { + LoadUser() +} + +func LoadUser() {} +`) + scanResult, err = NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + afterElement := symbolElementID(t, db, "LoadUser") + if afterElement != beforeElement { + t.Fatalf("rename created a new generated element: before=%d after=%d", beforeElement, afterElement) + } + if afterConnectors := connectorCount(t, db); afterConnectors != beforeConnectors { + t.Fatalf("rename changed connector count: before=%d after=%d", beforeConnectors, afterConnectors) + } +} + +func TestMoveRenamePreservesGeneratedSymbolElement(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func Main() { + FetchUser() +} + +func FetchUser() int { + value := 41 + return value + 1 +} +`) + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + beforeElement := symbolElementID(t, db, "FetchUser") + + if err := os.Remove(filepath.Join(repo, "main.go")); err != nil { + t.Fatal(err) + } + writeFile(t, repo, "pkg/users.go", `package 
pkg + +func Main() { + LoadAccount() +} + +func LoadAccount() int { + value := 41 + return value + 1 +} +`) + scanResult, err = NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := NewRepresenter(store).Represent(context.Background(), scanResult.RepositoryID, RepresentRequest{Embedding: EmbeddingConfig{Provider: "none"}}); err != nil { + t.Fatal(err) + } + afterElement := symbolElementID(t, db, "LoadAccount") + if afterElement != beforeElement { + t.Fatalf("move+rename created a new generated element: before=%d after=%d", beforeElement, afterElement) + } +} + +func TestClusterStableKeyIsDeterministic(t *testing.T) { + left := stableClusterKey(42, "pkg", "settings", []string{"c", "a", "b"}) + right := stableClusterKey(42, "pkg", "settings", []string{"b", "c", "a"}) + if left != right { + t.Fatalf("stable cluster key changed with member order: %s != %s", left, right) + } +} + +func TestWatchSymbolsFromAnalyzerKeepsSameNameMethodsDistinct(t *testing.T) { + source := []byte(`package main + +func (p *Page) Render() {} +func (c *Card) Render() {} +`) + symbols := watchSymbolsFromAnalyzer(1, 2, "view.go", "go", source, []analyzer.Symbol{ + {Name: "Render", Kind: "method", Parent: "Page", Line: 3, EndLine: 3}, + {Name: "Render", Kind: "method", Parent: "Card", Line: 4, EndLine: 4}, + }) + if len(symbols) != 2 { + t.Fatalf("symbols = %d, want 2", len(symbols)) + } + if symbols[0].StableKey == symbols[1].StableKey { + t.Fatalf("stable keys collided: %+v", symbols) + } + if symbols[0].QualifiedName != "Page.Render" || symbols[1].QualifiedName != "Card.Render" { + t.Fatalf("qualified names = %q, %q", symbols[0].QualifiedName, symbols[1].QualifiedName) + } +} + +func TestWatchSymbolsFromAnalyzerDisambiguatesDuplicateKeys(t *testing.T) { + source := []byte("void render();\nvoid render(int value);\n") + symbols := watchSymbolsFromAnalyzer(1, 2, "view.h", "cpp", source, []analyzer.Symbol{ + {Name: "render", Kind: "function", Line: 1, EndLine: 1}, + {Name: "render", Kind: "function", Line: 2, EndLine: 2}, + }) + if len(symbols) != 2 { + t.Fatalf("symbols = %d, want 2", len(symbols)) + } + if symbols[0].StableKey == symbols[1].StableKey { + t.Fatalf("stable keys collided: %+v", symbols) + } + for _, sym := range symbols { + if sym.QualifiedName != "render" { + t.Fatalf("qualified name = %q, want render", sym.QualifiedName) + } + } +} + +func TestScanLocalOnlyRepositoryIsIdempotent(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", `package main + +func main() { + helper() +} + +func helper() {} +`) + + scanner := NewScanner(NewStore(db)) + first, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatalf("first scan: %v", err) + } + if first.FilesSeen != 1 || first.FilesParsed != 1 || first.FilesSkipped != 0 || first.SymbolsSeen != 2 || first.ReferencesSeen != 1 { + t.Fatalf("unexpected first scan counts: %+v", first) + } + + second, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatalf("second scan: %v", err) + } + if second.FilesSeen != 1 || second.FilesParsed != 0 || second.FilesSkipped != 1 { + t.Fatalf("unexpected second scan counts: %+v", second) + } + third, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatalf("third scan: %v", err) + } + if third.FilesSeen != 1 || third.FilesParsed != 0 || third.FilesSkipped != 1 { + t.Fatalf("unexpected third scan counts after prior skipped status: %+v", third) + } 
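+	// Repeated scans should leave exactly one local_only repository row and an unchanged file/symbol/reference summary.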
+ + store := NewStore(db) + repos, err := store.Repositories(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(repos) != 1 || repos[0].IdentityStatus != "local_only" { + t.Fatalf("expected one local_only repo, got %+v", repos) + } + summary, err := store.Summary(context.Background(), first.RepositoryID) + if err != nil { + t.Fatal(err) + } + if summary.Files != 1 || summary.Symbols != 2 || summary.References != 1 { + t.Fatalf("unexpected summary after idempotent scan: %+v", summary) + } +} + +func TestScanUsesRemoteURLIdentity(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + runGit(t, repo, "remote", "add", "origin", "git@github.com:owner/repo.git") + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + result, err := NewScanner(NewStore(db)).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + stored, err := NewStore(db).Repository(context.Background(), result.RepositoryID) + if err != nil { + t.Fatal(err) + } + if !stored.RemoteURL.Valid || stored.RemoteURL.String != "git@github.com:owner/repo.git" || stored.IdentityStatus != "known" { + t.Fatalf("unexpected repository identity: %+v", stored) + } +} + +func TestScanRemovesDeletedFiles(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "one.go", "package main\nfunc one() {}\n") + writeFile(t, repo, "two.go", "package main\nfunc two() {}\n") + + scanner := NewScanner(NewStore(db)) + result, err := scanner.Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if err := os.Remove(filepath.Join(repo, "two.go")); err != nil { + t.Fatal(err) + } + if _, err := scanner.Scan(context.Background(), repo); err != nil { + t.Fatal(err) + } + summary, err := NewStore(db).Summary(context.Background(), result.RepositoryID) + if err != nil { + t.Fatal(err) + } + if summary.Files != 1 || summary.Symbols != 1 { + t.Fatalf("deleted file was not reconciled: %+v", summary) + } +} + +func TestScanFailsClearlyOutsideGitRepository(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + _, err := NewScanner(NewStore(db)).Scan(context.Background(), t.TempDir()) + if err == nil || !strings.Contains(err.Error(), "not inside a git repository") { + t.Fatalf("expected git repository error, got %v", err) + } +} + +func TestStatusEndpointReportsActiveWatch(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, os.Getpid(), "token", LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + + mux := http.NewServeMux() + NewHandler(store).Register(mux) + rec := httptest.NewRecorder() + mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/watch/status", nil)) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status code %d: %s", rec.Code, rec.Body.String()) + } + var body struct { + Active bool `json:"active"` + Repository RepositoryJSON `json:"repository"` + Lock Lock `json:"lock"` + } + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatal(err) + } + if !body.Active || body.Repository.ID != scanResult.RepositoryID || body.Lock.RepositoryID != scanResult.RepositoryID { + t.Fatalf("unexpected 
status body: %+v", body) + } +} + +func TestAcquireLockReplacesDeadProcessLock(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + originalProcessCheck := watchProcessIsRunning + t.Cleanup(func() { watchProcessIsRunning = originalProcessCheck }) + watchProcessIsRunning = func(pid int) bool { return pid == os.Getpid() } + + if _, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, 999999, "dead-token", LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + lock, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, os.Getpid(), "live-token", LockHeartbeatTimeout) + if err != nil { + t.Fatalf("expected dead process lock to be replaced: %v", err) + } + if lock.Token != "live-token" || lock.PID != os.Getpid() { + t.Fatalf("unexpected replacement lock: %+v", lock) + } +} + +func TestActiveLiveLockTreatsDeadProcessAsStale(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + originalProcessCheck := watchProcessIsRunning + t.Cleanup(func() { watchProcessIsRunning = originalProcessCheck }) + watchProcessIsRunning = func(pid int) bool { return false } + + if _, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, 999999, "dead-token", LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + lock, live, err := store.ActiveLiveLock(context.Background(), LockHeartbeatTimeout) + if err != nil { + t.Fatal(err) + } + if live || lock.Token != "dead-token" { + t.Fatalf("expected dead process lock to be non-live: live=%v lock=%+v", live, lock) + } + status, err := store.LockStatus(context.Background(), scanResult.RepositoryID, "dead-token") + if err != nil { + t.Fatal(err) + } + if status != "stale" { + t.Fatalf("expected stale lock, got %q", status) + } +} + +func TestRequestStopActiveStopsCurrentLock(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, os.Getpid(), "token", LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + if err := store.RequestStopActive(context.Background()); err != nil { + t.Fatal(err) + } + status, err := store.LockStatus(context.Background(), scanResult.RepositoryID, "token") + if err != nil { + t.Fatal(err) + } + if status != "stopping" { + t.Fatalf("expected stopping lock, got %q", status) + } +} + +func TestPauseResumeActiveLock(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, os.Getpid(), "token", LockHeartbeatTimeout); err != nil { + t.Fatal(err) + 
} + if err := store.RequestPauseActive(context.Background()); err != nil { + t.Fatal(err) + } + status, err := store.LockStatus(context.Background(), scanResult.RepositoryID, "token") + if err != nil { + t.Fatal(err) + } + if status != "paused" { + t.Fatalf("expected paused lock, got %q", status) + } + if _, err := store.HeartbeatLock(context.Background(), scanResult.RepositoryID, "token"); err != nil { + t.Fatal(err) + } + if err := store.RequestResumeActive(context.Background()); err != nil { + t.Fatal(err) + } + status, err = store.LockStatus(context.Background(), scanResult.RepositoryID, "token") + if err != nil { + t.Fatal(err) + } + if status != "active" { + t.Fatalf("expected active lock, got %q", status) + } +} + +func TestHeartbeatLockReportsMissingOwnLock(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + store := NewStore(db) + scanResult, err := NewScanner(store).Scan(context.Background(), repo) + if err != nil { + t.Fatal(err) + } + if _, err := store.AcquireLock(context.Background(), scanResult.RepositoryID, os.Getpid(), "token", LockHeartbeatTimeout); err != nil { + t.Fatal(err) + } + if err := store.ReleaseLock(context.Background(), scanResult.RepositoryID, "token"); err != nil { + t.Fatal(err) + } + if _, err := store.HeartbeatLock(context.Background(), scanResult.RepositoryID, "token"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected missing lock error, got %v", err) + } +} + +func TestRunnerStopsCleanlyWhenOwnLockIsReleased(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + ctx := t.Context() + events := make(chan Event, 32) + ready := make(chan RunnerResult, 1) + done := make(chan error, 1) + store := NewStore(db) + go func() { + _, err := NewRunner(store).Run(ctx, RunnerOptions{ + Path: repo, + PollInterval: time.Hour, + HeartbeatInterval: 10 * time.Millisecond, + SummaryInterval: time.Hour, + Embedding: EmbeddingConfig{Provider: "none"}, + Events: events, + Ready: ready, + }) + done <- err + close(events) + }() + + result := waitForRunnerReady(t, ready, done, "released-lock runner") + if err := store.ReleaseLock(context.Background(), result.Repository.ID, result.Token); err != nil { + t.Fatal(err) + } + waitForRunnerDone(t, done, "released-lock runner") +} + +func TestRunnerEmitsChangeCounter(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "main.go", "package main\nfunc main() {}\n") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + events := make(chan Event, 32) + ready := make(chan RunnerResult, 1) + done := make(chan error, 1) + go func() { + _, err := NewRunner(NewStore(db)).Run(ctx, RunnerOptions{ + Path: repo, + PollInterval: time.Hour, + HeartbeatInterval: time.Hour, + SummaryInterval: 10 * time.Millisecond, + Embedding: EmbeddingConfig{Provider: "none"}, + Events: events, + Ready: ready, + }) + done <- err + close(events) + }() + + waitForRunnerReady(t, ready, done, "change counter runner") + + event := waitForRunnerEvent(t, events, done, "watch.changeCounter", func(event Event) bool { + return event.Type == "watch.changeCounter" + }) + counter, ok := event.Data.(ChangeCounter) + if !ok { + t.Fatalf("unexpected counter payload: %#v", event.Data) + } + if counter.TotalChangesProcessed != 0 || 
counter.IntervalChangesProcessed != 0 { + t.Fatalf("unexpected idle counter: %+v", counter) + } + cancel() + waitForRunnerDone(t, done, "change counter runner") +} + +func TestRunnerResolvesSubdirectoryToRepositoryRootBeforeReady(t *testing.T) { + db := openTestDB(t) + defer func() { _ = db.Close() }() + repo := initGitRepoNoCommit(t) + writeFile(t, repo, "cmd/app/main.go", `package main + +func Main() {} +`) + subdir := filepath.Join(repo, "cmd", "app") + + ctx, cancel := context.WithCancel(context.Background()) + events := make(chan Event, 32) + ready := make(chan RunnerResult, 1) + done := make(chan error, 1) + go func() { + _, err := NewRunner(NewStore(db)).Run(ctx, RunnerOptions{ + Path: subdir, + PollInterval: time.Hour, + HeartbeatInterval: time.Hour, + SummaryInterval: time.Hour, + Embedding: EmbeddingConfig{Provider: "none"}, + Events: events, + Ready: ready, + }) + done <- err + close(events) + }() + + result := waitForRunnerReady(t, ready, done, "subdirectory runner") + cancel() + waitForRunnerDone(t, done, "subdirectory runner") + expectedRoot, err := filepath.EvalSymlinks(repo) + if err != nil { + t.Fatal(err) + } + actualRoot, err := filepath.EvalSymlinks(result.Repository.RepoRoot) + if err != nil { + t.Fatal(err) + } + if actualRoot != expectedRoot { + t.Fatalf("expected runner repository root %q, got %q", expectedRoot, actualRoot) + } + if result.InitialScan.RepositoryID == 0 || result.InitialRep.RepositoryID == 0 { + t.Fatalf("expected initial scan and representation before ready, got %+v", result) + } +} + +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + db, err := sql.Open("sqlite", filepath.Join(t.TempDir(), "tld.db")) + if err != nil { + t.Fatal(err) + } + db.SetMaxOpenConns(1) + if err := sqlitevec.Register(db); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`PRAGMA foreign_keys = ON;`); err != nil { + t.Fatal(err) + } + for _, migration := range []string{"001_init.sql", "002_watch_raw_code_graph.sql"} { + data, err := os.ReadFile(filepath.Join("..", "..", "migrations", migration)) + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(string(data)); err != nil { + t.Fatalf("apply %s: %v", migration, err) + } + } + if _, err := db.Exec(`INSERT INTO views(owner_element_id, name, description, level_label, level, created_at, updated_at) VALUES (NULL, 'Workspace', 'Local offline workspace', 'Root', 1, 'now', 'now')`); err != nil { + t.Fatal(err) + } + return db +} + +func waitForRunnerReady(t *testing.T, ready <-chan RunnerResult, done <-chan error, label string) RunnerResult { + t.Helper() + select { + case result := <-ready: + return result + case err := <-done: + t.Fatalf("%s exited before ready: %v", label, err) + case <-time.After(2 * time.Second): + t.Fatalf("%s did not become ready", label) + } + return RunnerResult{} +} + +func waitForRunnerDone(t *testing.T, done <-chan error, label string) { + t.Helper() + select { + case err := <-done: + if err != nil { + t.Fatal(err) + } + case <-time.After(2 * time.Second): + t.Fatalf("%s did not stop", label) + } +} + +func waitForRunnerEvent(t *testing.T, events <-chan Event, done <-chan error, label string, matches func(Event) bool) Event { + t.Helper() + deadline := time.After(2 * time.Second) + for { + select { + case event, ok := <-events: + if !ok { + t.Fatalf("%s events channel closed before event", label) + } + if matches(event) { + return event + } + case err := <-done: + t.Fatalf("runner exited before %s: %v", label, err) + case <-deadline: + t.Fatalf("runner did not emit %s", label) + } + } +} 
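+// The query helpers below read the SQLite fixture directly (elements, connectors, placements, and watch_* tables) so tests can assert on generated rows without going through the Store API.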
+ +func symbolElementID(t *testing.T, db *sql.DB, name string) int64 { + t.Helper() + var id int64 + if err := db.QueryRow(` + SELECT id FROM elements + WHERE name = ? AND kind = 'function'`, name).Scan(&id); err != nil { + t.Fatalf("find symbol element %s: %v", name, err) + } + return id +} + +func elementIDByName(t *testing.T, db *sql.DB, name string) int64 { + t.Helper() + var id int64 + if err := db.QueryRow(`SELECT id FROM elements WHERE name = ? ORDER BY id LIMIT 1`, name).Scan(&id); err != nil { + t.Fatalf("find element %s: %v", name, err) + } + return id +} + +func elementNameExists(t *testing.T, db *sql.DB, name string) bool { + t.Helper() + var id int64 + err := db.QueryRow(`SELECT id FROM elements WHERE name = ? ORDER BY id LIMIT 1`, name).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return false + } + if err != nil { + t.Fatal(err) + } + return true +} + +func insertManualElement(t *testing.T, db *sql.DB, name string) int64 { + t.Helper() + res, err := db.Exec(` + INSERT INTO elements(name, kind, description, technology_connectors, tags, created_at, updated_at) + VALUES (?, 'note', '', '[]', '[]', 'now', 'now')`, name) + if err != nil { + t.Fatal(err) + } + id, err := res.LastInsertId() + if err != nil { + t.Fatal(err) + } + return id +} + +func connectorExistsBetween(t *testing.T, db *sql.DB, sourceName, targetName string) bool { + t.Helper() + var id int64 + err := db.QueryRow(` + SELECT c.id + FROM connectors c + JOIN elements s ON s.id = c.source_element_id + JOIN elements target ON target.id = c.target_element_id + WHERE s.name = ? AND target.name = ? + ORDER BY c.id + LIMIT 1`, sourceName, targetName).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return false + } + if err != nil { + t.Fatal(err) + } + return true +} + +func connectorIDBetween(t *testing.T, db *sql.DB, sourceName, targetName string) int64 { + t.Helper() + var id int64 + if err := db.QueryRow(` + SELECT c.id + FROM connectors c + JOIN elements s ON s.id = c.source_element_id + JOIN elements target ON target.id = c.target_element_id + WHERE s.name = ? AND target.name = ? + ORDER BY c.id + LIMIT 1`, sourceName, targetName).Scan(&id); err != nil { + t.Fatalf("connector %s -> %s: %v", sourceName, targetName, err) + } + return id +} + +func activePolicyCount(t *testing.T, db *sql.DB, repositoryID int64, action, ownerType string) int { + t.Helper() + var count int + if err := db.QueryRow(` + SELECT COUNT(*) + FROM watch_context_policies + WHERE repository_id = ? AND action = ? AND owner_type = ? AND active = 1`, repositoryID, action, ownerType).Scan(&count); err != nil { + t.Fatal(err) + } + return count +} + +func materializedResourceID(t *testing.T, db *sql.DB, repositoryID int64, ownerType, ownerKey, resourceType string) int64 { + t.Helper() + var id int64 + if err := db.QueryRow(` + SELECT resource_id + FROM watch_materialization + WHERE repository_id = ? AND owner_type = ? AND owner_key = ? 
AND resource_type = ?`, repositoryID, ownerType, ownerKey, resourceType).Scan(&id); err != nil { + t.Fatalf("materialized resource %s/%s/%s: %v", ownerType, ownerKey, resourceType, err) + } + return id +} + +func symbolsByName(ctx context.Context, store *Store, repositoryID int64, name string) (Symbol, error) { + symbols, err := store.QuerySymbols(ctx, repositoryID, SymbolQuery{Search: name, Limit: -1}) + if err != nil { + return Symbol{}, err + } + for _, sym := range symbols { + if sym.Name == name || sym.QualifiedName == name { + return sym, nil + } + } + return Symbol{}, fmt.Errorf("symbol %q not found", name) +} + +func filterDecisionHasReason(decisions []FilterDecision, ownerID int64, reason string) bool { + for _, decision := range decisions { + if decision.OwnerType == "symbol" && decision.OwnerID == ownerID && strings.Contains(decision.Reason, reason) { + return true + } + } + return false +} + +func connectorCount(t *testing.T, db *sql.DB) int { + t.Helper() + var count int + if err := db.QueryRow(`SELECT COUNT(*) FROM connectors`).Scan(&count); err != nil { + t.Fatal(err) + } + if count == 0 { + t.Fatal("expected at least one generated connector") + } + return count +} + +func elementKindCount(t *testing.T, db *sql.DB, kind string) int { + t.Helper() + var count int + if err := db.QueryRow(`SELECT COUNT(*) FROM elements WHERE kind = ?`, kind).Scan(&count); err != nil { + t.Fatal(err) + } + return count +} + +func materializationOwnerTypeCount(t *testing.T, db *sql.DB, ownerType string) int { + t.Helper() + var count int + if err := db.QueryRow(`SELECT COUNT(*) FROM watch_materialization WHERE owner_type = ?`, ownerType).Scan(&count); err != nil { + t.Fatal(err) + } + return count +} + +type workspaceCount struct { + Views int + Elements int + Placements int + Connectors int +} + +func workspaceCounts(t *testing.T, db *sql.DB) workspaceCount { + t.Helper() + var count workspaceCount + for query, dest := range map[string]*int{ + `SELECT COUNT(*) FROM views`: &count.Views, + `SELECT COUNT(*) FROM elements`: &count.Elements, + `SELECT COUNT(*) FROM placements`: &count.Placements, + `SELECT COUNT(*) FROM connectors`: &count.Connectors, + } { + if err := db.QueryRow(query).Scan(dest); err != nil { + t.Fatal(err) + } + } + return count +} + +func countElementTag(t *testing.T, db *sql.DB, tag string) int { + t.Helper() + rows, err := db.Query(`SELECT tags FROM elements`) + if err != nil { + t.Fatal(err) + } + defer func() { _ = rows.Close() }() + count := 0 + for rows.Next() { + var raw string + if err := rows.Scan(&raw); err != nil { + t.Fatal(err) + } + var tags []string + _ = json.Unmarshal([]byte(raw), &tags) + for _, item := range tags { + if item == tag { + count++ + } + } + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + return count +} + +func elementTagsByName(t *testing.T, db *sql.DB, name string) []string { + t.Helper() + var raw string + if err := db.QueryRow(`SELECT tags FROM elements WHERE name = ? 
ORDER BY id LIMIT 1`, name).Scan(&raw); err != nil { + t.Fatal(err) + } + var tags []string + if err := json.Unmarshal([]byte(raw), &tags); err != nil { + t.Fatal(err) + } + return tags +} + +func tagMetadataByName(t *testing.T, db *sql.DB, name string) (string, *string) { + t.Helper() + var color string + var description sql.NullString + if err := db.QueryRow(`SELECT color, description FROM tags WHERE name = ?`, name).Scan(&color, &description); err != nil { + t.Fatal(err) + } + if description.Valid { + return color, &description.String + } + return color, nil +} + +func factsContain(facts []Fact, factType, tag string) bool { + for _, fact := range facts { + if fact.Type == factType && stringSliceContains(fact.Tags, tag) { + return true + } + } + return false +} + +func stringSliceContains(values []string, needle string) bool { + return slices.Contains(values, needle) +} + +func hasDiff(diffs []RepresentationDiff, resourceType, changeType string) bool { + return findDiff(diffs, resourceType, changeType) != nil +} + +func findDiff(diffs []RepresentationDiff, resourceType, changeType string) *RepresentationDiff { + for _, diff := range diffs { + if diff.ResourceType != nil && *diff.ResourceType == resourceType && diff.ChangeType == changeType { + return &diff + } + } + return nil +} + +func findDiffByOwner(diffs []RepresentationDiff, ownerType, ownerKey, resourceType, changeType string) *RepresentationDiff { + for _, diff := range diffs { + if diff.OwnerType == ownerType && diff.OwnerKey == ownerKey && diff.ResourceType != nil && *diff.ResourceType == resourceType && diff.ChangeType == changeType { + return &diff + } + } + return nil +} + +func languageSet(languages []string) map[string]struct{} { + out := make(map[string]struct{}, len(languages)) + for _, language := range languages { + out[language] = struct{}{} + } + return out +} + +func changeSummary(changes []SourceFileChange) string { + parts := make([]string, 0, len(changes)) + for _, change := range changes { + parts = append(parts, change.Path+":"+change.ChangeType+":"+change.Language) + } + return strings.Join(parts, ",") +} + +type testPlacement struct { + placementID int64 + elementID int64 + x float64 + y float64 +} + +func functionPlacement(t *testing.T, db *sql.DB, name string) testPlacement { + t.Helper() + row := db.QueryRow(` + SELECT p.id, p.element_id, p.position_x, p.position_y + FROM placements p + JOIN elements e ON e.id = p.element_id + WHERE e.kind = 'function' AND (e.name = ? OR e.name LIKE ?) + ORDER BY p.id + LIMIT 1`, name, "%."+name) + var p testPlacement + if err := row.Scan(&p.placementID, &p.elementID, &p.x, &p.y); err != nil { + t.Fatalf("function placement %q: %v", name, err) + } + return p +} + +func elementPlacementByName(t *testing.T, db *sql.DB, name string) testPlacement { + t.Helper() + row := db.QueryRow(` + SELECT p.id, p.element_id, p.position_x, p.position_y + FROM placements p + JOIN elements e ON e.id = p.element_id + WHERE e.name = ? 
+ ORDER BY p.id + LIMIT 1`, name) + var p testPlacement + if err := row.Scan(&p.placementID, &p.elementID, &p.x, &p.y); err != nil { + t.Fatalf("element placement %q: %v", name, err) + } + return p +} + +type countingProvider struct { + calls int + inputs int + batchSizes []string + texts []string +} + +func (p *countingProvider) ModelID() ModelID { + return ModelID{Provider: "local-deterministic-test", Model: "counting", Dimension: 2, ConfigHash: "counting"} +} + +func (p *countingProvider) Embed(_ context.Context, inputs []EmbeddingInput) ([]Vector, error) { + p.calls++ + p.inputs += len(inputs) + p.batchSizes = append(p.batchSizes, fmt.Sprint(len(inputs))) + out := make([]Vector, 0, len(inputs)) + for _, input := range inputs { + p.texts = append(p.texts, input.Text) + out = append(out, Vector{1, 2}) + } + return out, nil +} + +type recordingProgress struct { + starts []string + advances int +} + +func (p *recordingProgress) Start(label string, total int) { + p.starts = append(p.starts, fmt.Sprintf("%s:%d", label, total)) +} + +func (p *recordingProgress) Advance(string) { + p.advances++ +} + +func (p *recordingProgress) Finish() {} + +func TestInferArchitectureSkipsMalformedRuntimeYAML(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "chart/templates/workload.yaml", "{{ if .Values.enabled }}\nkind: Deployment\n{{ end }}\n") + writeFile(t, dir, "runtime/topology.yaml", ` +apiVersion: apps/v1 +kind: Deployment +metadata: + name: api +spec: + template: + spec: + containers: + - name: api + image: example/api:latest + ports: + - containerPort: 8080 +`) + + progress := &recordingProgress{} + model := inferArchitectureWithProgress(dir, progress) + if model.Components[architectureKey("component", "api")] == nil { + t.Fatalf("expected api deployment component, got %#v", model.Components) + } + if len(progress.starts) == 0 || progress.advances == 0 { + t.Fatalf("expected architecture progress, starts=%v advances=%d", progress.starts, progress.advances) + } +} + +func TestArchitectureFromFactsPromotesRuntimeAndGRPCGlue(t *testing.T) { + facts := []Fact{ + { + FilePath: "src/frontend/rpc.go", + Type: "grpc.client", + Name: "cartservice", + Relationship: "calls", + Confidence: 0.9, + AttributesJSON: `{"service":"cartservice"}`, + }, + { + FilePath: "src/cartservice/src/Startup.cs", + Type: "datastore.dependency", + Name: "redis-cart", + Relationship: "uses", + Confidence: 0.8, + AttributesJSON: `{"name":"redis-cart","technology":"Redis"}`, + }, + } + + model := architectureFromFacts(facts) + if model.Components[architectureKey("component", "frontend")] == nil { + t.Fatalf("expected frontend component, got %#v", model.Components) + } + if model.Components[architectureKey("component", "cartservice")] == nil { + t.Fatalf("expected cartservice component, got %#v", model.Components) + } + if model.Components[architectureKey("component", "redis-cart")] == nil { + t.Fatalf("expected redis-cart component, got %#v", model.Components) + } + if len(model.Connectors) < 2 { + t.Fatalf("expected grpc and datastore connectors, got %#v", model.Connectors) + } +} + +func TestCanonicalizeArchitectureFoldsGenericDependencyIntoConcreteAlias(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + frontend := architectureKey("component", "frontend") + redisGeneric := architectureKey("datastore", "redis") + redisConcrete := architectureKey("component", "redis-cart") + model.Components[frontend] = &architectureComponent{Key: 
frontend, Name: "frontend", Kind: "service", Technology: "Go"} + model.Components[redisGeneric] = &architectureComponent{Key: redisGeneric, Name: "redis", Kind: "datastore", Technology: "Redis", Tags: []string{"datastore:redis"}} + model.Components[redisConcrete] = &architectureComponent{Key: redisConcrete, Name: "redis-cart", Kind: "service", Technology: "Redis", Tags: []string{"runtime:kubernetes"}} + model.Connectors["generic"] = &architectureConnector{Key: "generic", SourceKey: frontend, TargetKey: redisGeneric, Label: "redis", Relationship: "runtime-dependency", Confidence: 0.72, Evidence: []architectureEvidence{{Kind: "datastore.dependency"}}} + model.Connectors["concrete"] = &architectureConnector{Key: "concrete", SourceKey: frontend, TargetKey: redisConcrete, Label: "redis", Relationship: "runtime-dependency", Confidence: 0.78, Evidence: []architectureEvidence{{Kind: "runtime-connection"}}} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if got.Components[redisGeneric] != nil { + t.Fatalf("generic redis component should fold into concrete alias: %#v", got.Components) + } + if got.Components[redisConcrete] == nil { + t.Fatalf("concrete redis component should remain: %#v", got.Components) + } + if len(got.Connectors) != 1 { + t.Fatalf("duplicate redis connectors should merge, got %#v", got.Connectors) + } + for _, connector := range got.Connectors { + if connector.TargetKey != redisConcrete { + t.Fatalf("connector should target concrete redis alias, got %#v", connector) + } + if len(connector.Evidence) != 2 { + t.Fatalf("merged connector should preserve evidence, got %#v", connector.Evidence) + } + } +} + +func TestCanonicalizeArchitectureMergesParallelConnectorsAfterAliasFold(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + loadgenerator := architectureKey("component", "loadgenerator") + frontend := architectureKey("component", "frontend") + frontendService := architectureKey("component", "frontendservice") + model.Components[loadgenerator] = &architectureComponent{Key: loadgenerator, Name: "loadgenerator", Kind: "service", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[frontend] = &architectureComponent{Key: frontend, Name: "frontend", Kind: "service", Evidence: []architectureEvidence{{Kind: "grpc.server"}}} + model.Components[frontendService] = &architectureComponent{Key: frontendService, Name: "frontendservice", Kind: "service", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Connectors["a"] = &architectureConnector{Key: "a", SourceKey: loadgenerator, TargetKey: frontend, Label: "grpc", Relationship: "runtime-dependency", Direction: "forward", Evidence: []architectureEvidence{{Kind: "grpc.client"}}} + model.Connectors["b"] = &architectureConnector{Key: "b", SourceKey: loadgenerator, TargetKey: frontendService, Label: "uses", Relationship: "runtime-dependency", Direction: "forward", Evidence: []architectureEvidence{{Kind: "runtime.connection"}}} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if len(got.Connectors) != 1 { + t.Fatalf("folded aliases should have one merged connector, got %#v", got.Connectors) + } + for _, connector := range got.Connectors { + if connector.Label != "" { + t.Fatalf("merged connector should use an empty label, got %#v", connector) + } + if connector.Direction != "forward" { + t.Fatalf("same-direction merged connector should stay forward, got %#v", connector) + 
} + if len(connector.Evidence) != 2 { + t.Fatalf("merged connector should preserve evidence, got %#v", connector.Evidence) + } + } +} + +func TestCanonicalizeArchitectureMergesOppositeConnectorDirections(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + client := architectureKey("component", "client") + server := architectureKey("component", "server") + model.Components[client] = &architectureComponent{Key: client, Name: "client", Kind: "service"} + model.Components[server] = &architectureComponent{Key: server, Name: "server", Kind: "service"} + model.Connectors["a"] = &architectureConnector{Key: "a", SourceKey: client, TargetKey: server, Label: "grpc", Relationship: "runtime-dependency", Direction: "forward"} + model.Connectors["b"] = &architectureConnector{Key: "b", SourceKey: server, TargetKey: client, Label: "uses", Relationship: "runtime-dependency", Direction: "forward"} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if len(got.Connectors) != 1 { + t.Fatalf("opposite connectors should merge, got %#v", got.Connectors) + } + for _, connector := range got.Connectors { + if connector.Label != "" { + t.Fatalf("merged connector should use an empty label, got %#v", connector) + } + if connector.Direction != "both" { + t.Fatalf("opposite directions should merge to both, got %#v", connector) + } + } +} + +func TestCanonicalizeArchitectureDoesNotMergeDistinctConcreteDependencies(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + cart := architectureKey("component", "cartservice") + catalog := architectureKey("component", "productcatalogservice") + redisCart := architectureKey("component", "redis-cart") + redisData := architectureKey("component", "redis-data") + model.Components[cart] = &architectureComponent{Key: cart, Name: "cartservice", Kind: "service", Technology: "Go"} + model.Components[catalog] = &architectureComponent{Key: catalog, Name: "productcatalogservice", Kind: "service", Technology: "Go"} + model.Components[redisCart] = &architectureComponent{Key: redisCart, Name: "redis-cart", Kind: "service", Technology: "Redis"} + model.Components[redisData] = &architectureComponent{Key: redisData, Name: "redis-data", Kind: "service", Technology: "Redis"} + model.Connectors["cart"] = &architectureConnector{Key: "cart", SourceKey: cart, TargetKey: redisCart, Label: "redis", Relationship: "runtime-dependency"} + model.Connectors["catalog"] = &architectureConnector{Key: "catalog", SourceKey: catalog, TargetKey: redisData, Label: "redis", Relationship: "runtime-dependency"} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if got.Components[redisCart] == nil || got.Components[redisData] == nil { + t.Fatalf("distinct concrete redis dependencies should remain separate: %#v", got.Components) + } + if len(got.Connectors) != 2 { + t.Fatalf("expected separate concrete dependency connectors, got %#v", got.Connectors) + } +} + +func TestCanonicalizeArchitectureFoldsServiceRoleNameVariants(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + frontend := architectureKey("component", "frontend") + payment := architectureKey("component", "payment") + paymentService := architectureKey("component", "paymentservice") + paymentContract := architectureKey("contract", 
"PaymentService") + paymentAPI := architectureKey("component", "paymentAPI") + paymentDB := architectureKey("component", "paymentDB") + model.Components[frontend] = &architectureComponent{Key: frontend, Name: "frontend", Kind: "service", Technology: "Go", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[payment] = &architectureComponent{Key: payment, Name: "payment", Kind: "service", Technology: "gRPC", Evidence: []architectureEvidence{{Kind: "grpc.server"}}} + model.Components[paymentService] = &architectureComponent{Key: paymentService, Name: "paymentservice", Kind: "service", Technology: "Kubernetes", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[paymentContract] = &architectureComponent{Key: paymentContract, Name: "PaymentService", Kind: "interface", Technology: "gRPC", Evidence: []architectureEvidence{{Kind: "service-contract"}}} + model.Components[paymentAPI] = &architectureComponent{Key: paymentAPI, Name: "paymentAPI", Kind: "service", Technology: "OpenAPI", Evidence: []architectureEvidence{{Kind: "runtime-component"}}} + model.Components[paymentDB] = &architectureComponent{Key: paymentDB, Name: "paymentDB", Kind: "service", Technology: "PostgreSQL", Evidence: []architectureEvidence{{Kind: "runtime-component"}}} + model.Connectors["frontend-payment"] = &architectureConnector{Key: "frontend-payment", SourceKey: frontend, TargetKey: payment, Label: "grpc", Relationship: "runtime-dependency"} + model.Connectors["frontend-paymentservice"] = &architectureConnector{Key: "frontend-paymentservice", SourceKey: frontend, TargetKey: paymentService, Label: "grpc", Relationship: "runtime-dependency"} + model.Connectors["frontend-paymentapi"] = &architectureConnector{Key: "frontend-paymentapi", SourceKey: frontend, TargetKey: paymentAPI, Label: "http", Relationship: "runtime-dependency"} + model.Connectors["payment-paymentdb"] = &architectureConnector{Key: "payment-paymentdb", SourceKey: payment, TargetKey: paymentDB, Label: "uses", Relationship: "runtime-dependency"} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if got.Components[paymentService] == nil { + t.Fatalf("paymentservice should be canonical service alias, got %#v", got.Components) + } + for _, folded := range []string{payment, paymentContract, paymentAPI, paymentDB} { + if got.Components[folded] != nil { + t.Fatalf("%s should fold into paymentservice, got %#v", folded, got.Components) + } + } + for _, connector := range got.Connectors { + if connector.SourceKey != frontend && connector.TargetKey != paymentService { + t.Fatalf("connectors should be rewritten to paymentservice alias or pruned, got %#v", connector) + } + } +} + +func TestCanonicalizeArchitectureDoesNotFoldEmbeddedServiceRootNames(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + frontend := architectureKey("component", "frontend") + payment := architectureKey("component", "paymentservice") + proxy := architectureKey("component", "shipping-payment-proxy") + model.Components[frontend] = &architectureComponent{Key: frontend, Name: "frontend", Kind: "service", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[payment] = &architectureComponent{Key: payment, Name: "paymentservice", Kind: "service", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[proxy] = &architectureComponent{Key: proxy, Name: "shipping-payment-proxy", Kind: "service", Evidence: 
[]architectureEvidence{{Kind: "deployable"}}} + model.Connectors["frontend-payment"] = &architectureConnector{Key: "frontend-payment", SourceKey: frontend, TargetKey: payment, Label: "grpc", Relationship: "runtime-dependency"} + model.Connectors["frontend-proxy"] = &architectureConnector{Key: "frontend-proxy", SourceKey: frontend, TargetKey: proxy, Label: "http", Relationship: "runtime-dependency"} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if got.Components[payment] == nil || got.Components[proxy] == nil { + t.Fatalf("embedded root names should not fold unrelated services, got %#v", got.Components) + } + if len(got.Connectors) != 2 { + t.Fatalf("expected separate connectors, got %#v", got.Connectors) + } +} + +func TestCanonicalizeArchitectureFoldsMultiTokenCompactServiceVariants(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + frontend := architectureKey("component", "frontend") + productCatalog := architectureKey("component", "product-catalog") + productCatalogService := architectureKey("component", "productcatalogservice") + productCatalogContract := architectureKey("contract", "ProductCatalogService") + model.Components[frontend] = &architectureComponent{Key: frontend, Name: "frontend", Kind: "service", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[productCatalog] = &architectureComponent{Key: productCatalog, Name: "product-catalog", Kind: "service", Technology: "gRPC", Evidence: []architectureEvidence{{Kind: "grpc.server"}}} + model.Components[productCatalogService] = &architectureComponent{Key: productCatalogService, Name: "productcatalogservice", Kind: "service", Technology: "Kubernetes", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[productCatalogContract] = &architectureComponent{Key: productCatalogContract, Name: "ProductCatalogService", Kind: "interface", Technology: "gRPC", Evidence: []architectureEvidence{{Kind: "service-contract"}}} + model.Connectors["frontend-product"] = &architectureConnector{Key: "frontend-product", SourceKey: frontend, TargetKey: productCatalog, Label: "grpc", Relationship: "runtime-dependency"} + model.Connectors["frontend-productservice"] = &architectureConnector{Key: "frontend-productservice", SourceKey: frontend, TargetKey: productCatalogService, Label: "grpc", Relationship: "runtime-dependency"} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if got.Components[productCatalogService] == nil { + t.Fatalf("productcatalogservice should be canonical service alias, got %#v", got.Components) + } + if got.Components[productCatalog] != nil || got.Components[productCatalogContract] != nil { + t.Fatalf("product catalog variants should fold into productcatalogservice, got %#v", got.Components) + } +} + +func TestCanonicalizeArchitectureFoldsShortServiceRootsWithRoleEvidence(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{}, Connectors: map[string]*architectureConnector{}} + frontend := architectureKey("component", "frontend") + ad := architectureKey("component", "ad") + adService := architectureKey("component", "adservice") + adContract := architectureKey("contract", "AdService") + model.Components[frontend] = &architectureComponent{Key: frontend, Name: "frontend", Kind: "service", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[ad] = &architectureComponent{Key: ad, Name: "ad", 
Kind: "service", Technology: "gRPC", Evidence: []architectureEvidence{{Kind: "grpc.server"}}} + model.Components[adService] = &architectureComponent{Key: adService, Name: "adservice", Kind: "service", Technology: "Kubernetes", Evidence: []architectureEvidence{{Kind: "deployable"}}} + model.Components[adContract] = &architectureComponent{Key: adContract, Name: "AdService", Kind: "interface", Technology: "gRPC", Evidence: []architectureEvidence{{Kind: "service-contract"}}} + model.Connectors["frontend-ad"] = &architectureConnector{Key: "frontend-ad", SourceKey: frontend, TargetKey: ad, Label: "grpc", Relationship: "runtime-dependency"} + model.Connectors["frontend-adservice"] = &architectureConnector{Key: "frontend-adservice", SourceKey: frontend, TargetKey: adService, Label: "grpc", Relationship: "runtime-dependency"} + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + if got.Components[adService] == nil { + t.Fatalf("adservice should be canonical service alias, got %#v", got.Components) + } + if got.Components[ad] != nil || got.Components[adContract] != nil { + t.Fatalf("ad variants should fold into adservice, got %#v", got.Components) + } +} + +func TestCanonicalizeArchitectureFoldsShortGRPCClientTargetIntoDeployableService(t *testing.T) { + model := mergeArchitectureModels( + architectureFromFacts([]Fact{ + { + FilePath: "src/frontend/rpc.go", + Type: "grpc.client", + Name: "ad", + Relationship: "calls", + Confidence: 0.9, + AttributesJSON: `{"service":"ad"}`, + }, + { + FilePath: "src/adservice/deploy.yaml", + Type: "runtime.component", + Name: "adservice", + Relationship: "deploys", + Confidence: 0.9, + AttributesJSON: `{"name":"adservice","kind":"service","technology":"Kubernetes"}`, + }, + }), + ) + + got := pruneDisconnectedArchitecture(canonicalizeArchitecture(model)) + ad := architectureKey("component", "ad") + adService := architectureKey("component", "adservice") + if got.Components[adService] == nil { + t.Fatalf("adservice should remain canonical, got %#v", got.Components) + } + if got.Components[ad] != nil { + t.Fatalf("grpc client target ad should fold into adservice, got %#v", got.Components) + } + for _, connector := range got.Connectors { + if connector.TargetKey != adService { + t.Fatalf("connector should target folded adservice alias, got %#v", connector) + } + } +} + +func TestResolveArchitectureBindingsUsesGenericSignals(t *testing.T) { + repo := Repository{ID: 1, DisplayName: "demo"} + tests := []struct { + name string + component *architectureComponent + targets []ArchitectureBindingTarget + wantTarget string + }{ + { + name: "service folder under services", + component: &architectureComponent{ + Key: architectureKey("component", "billingservice"), + Name: "billingservice", + FilePath: "deploy/billing.yaml", + Evidence: []architectureEvidence{ + {Kind: "deployable", Path: "deploy/billing.yaml", Note: "Deployment"}, + {Kind: "grpc.server", Path: "services/billing/main.go", Note: "billing"}, + }, + }, + targets: []ArchitectureBindingTarget{architectureBindingTestTarget(1, "folder", "folder:services/billing", "billing", "folder", "services/billing")}, + wantTarget: "folder:services/billing", + }, + { + name: "service folder under apps", + component: &architectureComponent{ + Key: architectureKey("component", "checkout"), + Name: "checkout", + FilePath: "ops/checkout.yaml", + Evidence: []architectureEvidence{ + {Kind: "runtime-component", Path: "apps/checkout/server.go", Note: "checkout"}, + }, + }, + targets: 
[]ArchitectureBindingTarget{architectureBindingTestTarget(1, "folder", "folder:apps/checkout", "checkout", "folder", "apps/checkout")}, + wantTarget: "folder:apps/checkout", + }, + { + name: "language package layout", + component: &architectureComponent{ + Key: architectureKey("component", "catalog"), + Name: "catalog", + FilePath: "manifests/catalog.yaml", + Evidence: []architectureEvidence{ + {Kind: "grpc.server", Path: "cmd/catalog/main.go", Note: "catalog"}, + }, + }, + targets: []ArchitectureBindingTarget{architectureBindingTestTarget(1, "file", "file:cmd/catalog/main.go", "main.go", "file", "cmd/catalog/main.go")}, + wantTarget: "file:cmd/catalog/main.go", + }, + { + name: "external stays unbound", + component: &architectureComponent{ + Key: architectureKey("external", "stripe"), + Name: "stripe", + Kind: "external", + }, + targets: []ArchitectureBindingTarget{architectureBindingTestTarget(1, "folder", "folder:payments", "payments", "folder", "payments")}, + wantTarget: "", + }, + { + name: "ambiguous exact names stay unbound", + component: &architectureComponent{ + Key: architectureKey("component", "payment"), + Name: "payment", + Kind: "service", + }, + targets: []ArchitectureBindingTarget{ + architectureBindingTestTarget(1, "folder", "folder:apps/payment", "payment", "folder", "apps/payment"), + architectureBindingTestTarget(1, "folder", "folder:services/payment", "payment", "folder", "services/payment"), + }, + wantTarget: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model := architectureModel{Components: map[string]*architectureComponent{tt.component.Key: tt.component}} + bindings := resolveArchitectureBindings(repo, model, tt.targets) + if tt.wantTarget == "" { + if len(bindings) != 0 { + t.Fatalf("expected no bindings, got %+v", bindings) + } + return + } + if len(bindings) == 0 || bindings[0].TargetOwnerKey != tt.wantTarget { + t.Fatalf("expected primary target %q, got %+v", tt.wantTarget, bindings) + } + }) + } +} + +func architectureBindingTestTarget(repoID int64, ownerType, ownerKey, name, kind, filePath string) ArchitectureBindingTarget { + return ArchitectureBindingTarget{ + RepositoryID: repoID, + OwnerType: ownerType, + OwnerKey: ownerKey, + ResourceType: "element", + ResourceID: int64(len(ownerKey)), + Name: name, + Kind: kind, + FilePath: filePath, + } +} + +func initGitRepoNoCommit(t *testing.T) string { + t.Helper() + dir := t.TempDir() + runGit(t, dir, "init") + runGit(t, dir, "config", "user.email", "test@example.com") + runGit(t, dir, "config", "user.name", "Test User") + return dir +} + +func runGit(t *testing.T, dir string, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) 
+ cmd.Dir = dir + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v failed: %v\n%s", args, err, out) + } +} + +func writeFile(t *testing.T, root, name, content string) { + t.Helper() + path := filepath.Join(root, name) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} diff --git a/internal/watch/watcher.go b/internal/watch/watcher.go new file mode 100644 index 0000000..0ff5a25 --- /dev/null +++ b/internal/watch/watcher.go @@ -0,0 +1,108 @@ +package watch + +import ( + "context" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/fsnotify/fsnotify" + "github.com/mertcikla/tld/internal/ignore" +) + +type sourceWatcher struct { + Mode string + Events <-chan struct{} + Warnings []string + Close func() error +} + +func newSourceWatcher(ctx context.Context, root string, settings Settings, rules *ignore.Rules) sourceWatcher { + settings = NormalizeSettings(settings) + if settings.Watcher == WatcherPoll { + return sourceWatcher{Mode: WatcherPoll} + } + watcher, err := fsnotify.NewWatcher() + if err != nil { + if settings.Watcher == WatcherFSNotify { + return sourceWatcher{Mode: WatcherPoll, Warnings: []string{"fsnotify unavailable: " + err.Error()}} + } + return sourceWatcher{Mode: WatcherPoll, Warnings: []string{"fsnotify unavailable; using poll fallback"}} + } + ch := make(chan struct{}, 1) + allowed := map[string]struct{}{} + for _, lang := range settings.Languages { + allowed[lang] = struct{}{} + } + if rules == nil { + rules = &ignore.Rules{} + } + walkErr := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if !d.IsDir() { + return nil + } + rel, _ := filepath.Rel(root, path) + rel = filepath.ToSlash(rel) + if rel != "." 
&& (rules.ShouldIgnorePath(rel) || isHiddenBuildOutput(d.Name())) { + return filepath.SkipDir + } + _ = watcher.Add(path) + return nil + }) + warnings := []string{} + if walkErr != nil { + warnings = append(warnings, "fsnotify setup warning: "+walkErr.Error()) + } + go func() { + defer close(ch) + defer func() { _ = watcher.Close() }() + for { + select { + case <-ctx.Done(): + return + case event, ok := <-watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Create) { + if info, err := filepathAbsStat(event.Name); err == nil && info.IsDir() { + _ = watcher.Add(event.Name) + } + } + if sourceEventRelevant(root, event.Name, allowed, rules) { + select { + case ch <- struct{}{}: + default: + } + } + case <-watcher.Errors: + select { + case ch <- struct{}{}: + default: + } + } + } + }() + return sourceWatcher{Mode: WatcherFSNotify, Events: ch, Warnings: warnings, Close: watcher.Close} +} + +func sourceEventRelevant(root, eventPath string, allowed map[string]struct{}, rules *ignore.Rules) bool { + rel, err := filepath.Rel(root, eventPath) + if err != nil || strings.HasPrefix(rel, "..") { + return false + } + rel = filepath.ToSlash(rel) + if rules != nil && rules.ShouldIgnorePath(rel) { + return false + } + language, _, ok := watchedFileLanguage(eventPath) + return ok && languageAllowed(language, allowed) +} + +func filepathAbsStat(path string) (fs.FileInfo, error) { + return os.Stat(path) +} diff --git a/internal/watch/websocket.go b/internal/watch/websocket.go new file mode 100644 index 0000000..0aa2658 --- /dev/null +++ b/internal/watch/websocket.go @@ -0,0 +1,293 @@ +package watch + +import ( + "bufio" + "context" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "strings" + "sync/atomic" + "time" +) + +var watchWebSocketClients atomic.Int64 + +func WatchWebSocketClientCount() int { + return int(watchWebSocketClients.Load()) +} + +func (h *Handler) watchWebSocket(w http.ResponseWriter, r *http.Request) { + if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") { + writeError(w, http.StatusBadRequest, "websocket upgrade required") + return + } + conn, rw, err := upgradeWebSocket(w, r) + if err != nil { + return + } + clients := watchWebSocketClients.Add(1) + defer func() { _ = conn.Close() }() + defer watchWebSocketClients.Add(-1) + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + controlEvents := make(chan Event, 4) + go h.watchWebSocketReads(ctx, rw, controlEvents, cancel) + + if err := writeWebSocketJSON(rw, Event{Type: "watch.connected", At: nowString(), Data: map[string]int64{"clients": clients}}); err != nil { + return + } + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + var lastRepresentationHash string + var lastVersionID int64 + for { + lock, live, err := h.Store.ActiveLiveLock(ctx, LockHeartbeatTimeout) + if err != nil { + _ = writeWebSocketJSON(rw, Event{Type: "watch.error", At: nowString(), Message: err.Error()}) + return + } + if live && lock.Status == "stopping" { + _ = writeWebSocketJSON(rw, Event{Type: "lock.disabled", RepositoryID: lock.RepositoryID, At: nowString(), Data: lock}) + _ = writeWebSocketJSON(rw, Event{Type: "watch.stopped", RepositoryID: lock.RepositoryID, At: nowString(), Data: lock}) + return + } + eventType := "watch.stopped" + if live { + eventType = "watch.heartbeat" + if lock.Status == "paused" { + eventType = "watch.paused" + } + } + if err := writeWebSocketJSON(rw, Event{Type: eventType, RepositoryID: lock.RepositoryID, At: nowString(), Data: 
lock}); err != nil { + return + } + if live { + summary, err := h.Store.RepresentationSummary(ctx, lock.RepositoryID) + if err == nil && summary.RepresentationHash != "" && summary.RepresentationHash != lastRepresentationHash { + if lastRepresentationHash != "" { + if diffs, diffErr := h.Store.BuildWatchDiffs(ctx, lock.RepositoryID, summary.RepresentationHash); diffErr == nil { + summary.Diffs = diffs + } + } + lastRepresentationHash = summary.RepresentationHash + if err := writeWebSocketJSON(rw, Event{Type: "representation.updated", RepositoryID: lock.RepositoryID, At: nowString(), Data: summary}); err != nil { + return + } + } + version, found, err := h.Store.LatestWatchVersion(ctx, lock.RepositoryID) + if err == nil && found && version.ID != lastVersionID { + lastVersionID = version.ID + if err := writeWebSocketJSON(rw, Event{Type: "version.created", RepositoryID: lock.RepositoryID, At: nowString(), Data: version}); err != nil { + return + } + } + } + select { + case <-ctx.Done(): + return + case event := <-controlEvents: + for { + if err := writeWebSocketJSON(rw, event); err != nil { + return + } + if event.Type == "watch.stopped" { + return + } + select { + case event = <-controlEvents: + default: + goto next + } + } + case <-ticker.C: + } + next: + } +} + +func (h *Handler) watchWebSocketReads(ctx context.Context, reader io.Reader, controlEvents chan<- Event, cancel context.CancelFunc) { + defer cancel() + for { + msg, err := readWebSocketMessage(reader) + if err != nil { + return + } + var body struct { + Type string `json:"type"` + RepositoryID int64 `json:"repository_id"` + RemoteURL string `json:"remote_url"` + } + if err := json.Unmarshal(msg, &body); err != nil { + continue + } + switch body.Type { + case "watch.pause": + if body.RepositoryID > 0 { + _ = h.Store.RequestPause(ctx, body.RepositoryID) + emitControlEvent(controlEvents, Event{Type: "watch.paused", RepositoryID: body.RepositoryID, At: nowString()}) + } else { + lock, live, _ := h.Store.ActiveLiveLock(ctx, LockHeartbeatTimeout) + _ = h.Store.RequestPauseActive(ctx) + if live { + emitControlEvent(controlEvents, Event{Type: "watch.paused", RepositoryID: lock.RepositoryID, At: nowString(), Data: lock}) + } + } + case "watch.resume": + if body.RepositoryID > 0 { + _ = h.Store.RequestResume(ctx, body.RepositoryID) + emitControlEvent(controlEvents, Event{Type: "watch.heartbeat", RepositoryID: body.RepositoryID, At: nowString()}) + } else { + lock, live, _ := h.Store.ActiveLiveLock(ctx, LockHeartbeatTimeout) + _ = h.Store.RequestResumeActive(ctx) + if live { + emitControlEvent(controlEvents, Event{Type: "watch.heartbeat", RepositoryID: lock.RepositoryID, At: nowString(), Data: lock}) + } + } + case "watch.stop": + if body.RepositoryID > 0 { + _ = h.Store.RequestStop(ctx, body.RepositoryID) + emitControlEvent(controlEvents, Event{Type: "lock.disabled", RepositoryID: body.RepositoryID, At: nowString()}) + emitControlEvent(controlEvents, Event{Type: "watch.stopped", RepositoryID: body.RepositoryID, At: nowString()}) + } else { + lock, live, _ := h.Store.ActiveLiveLock(ctx, LockHeartbeatTimeout) + _ = h.Store.RequestStopActive(ctx) + if live { + emitControlEvent(controlEvents, Event{Type: "lock.disabled", RepositoryID: lock.RepositoryID, At: nowString(), Data: lock}) + emitControlEvent(controlEvents, Event{Type: "watch.stopped", RepositoryID: lock.RepositoryID, At: nowString(), Data: lock}) + } else { + emitControlEvent(controlEvents, Event{Type: "watch.stopped", At: nowString()}) + } + } + case "watch.reassociateRepo": + if 
body.RepositoryID > 0 && strings.TrimSpace(body.RemoteURL) != "" { + _, _ = h.Store.ReassociateRepository(ctx, body.RepositoryID, body.RemoteURL) + } + case "watch.status", "watch.rescan": + } + } +} + +func emitControlEvent(ch chan<- Event, event Event) { + select { + case ch <- event: + default: + } +} + +func upgradeWebSocket(w http.ResponseWriter, r *http.Request) (net.Conn, *bufio.ReadWriter, error) { + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "websocket unsupported", http.StatusInternalServerError) + return nil, nil, errors.New("hijack unsupported") + } + key := strings.TrimSpace(r.Header.Get("Sec-WebSocket-Key")) + if key == "" { + http.Error(w, "missing Sec-WebSocket-Key", http.StatusBadRequest) + return nil, nil, errors.New("missing websocket key") + } + conn, rw, err := hj.Hijack() + if err != nil { + return nil, nil, err + } + accept := websocketAccept(key) + _, err = rw.WriteString("HTTP/1.1 101 Switching Protocols\r\n" + + "Upgrade: websocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Accept: " + accept + "\r\n\r\n") + if err != nil { + _ = conn.Close() + return nil, nil, err + } + if err := rw.Flush(); err != nil { + _ = conn.Close() + return nil, nil, err + } + return conn, rw, nil +} + +func websocketAccept(key string) string { + sum := sha1.Sum([]byte(key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + return base64.StdEncoding.EncodeToString(sum[:]) +} + +func writeWebSocketJSON(rw *bufio.ReadWriter, value any) error { + data, err := json.Marshal(value) + if err != nil { + return err + } + if err := writeWebSocketFrame(rw, data); err != nil { + return err + } + return rw.Flush() +} + +func writeWebSocketFrame(w io.Writer, payload []byte) error { + header := []byte{0x81} + switch { + case len(payload) < 126: + header = append(header, byte(len(payload))) + case len(payload) <= 65535: + header = append(header, 126, byte(len(payload)>>8), byte(len(payload))) + default: + header = append(header, 127) + var size [8]byte + binary.BigEndian.PutUint64(size[:], uint64(len(payload))) + header = append(header, size[:]...) + } + if _, err := w.Write(header); err != nil { + return err + } + _, err := w.Write(payload) + return err +} + +func readWebSocketMessage(r io.Reader) ([]byte, error) { + var hdr [2]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, err + } + opcode := hdr[0] & 0x0f + if opcode == 0x8 { + return nil, io.EOF + } + masked := hdr[1]&0x80 != 0 + length := int(hdr[1] & 0x7f) + switch length { + case 126: + var ext [2]byte + if _, err := io.ReadFull(r, ext[:]); err != nil { + return nil, err + } + length = int(binary.BigEndian.Uint16(ext[:])) + case 127: + var ext [8]byte + if _, err := io.ReadFull(r, ext[:]); err != nil { + return nil, err + } + length = int(binary.BigEndian.Uint64(ext[:])) + } + var mask [4]byte + if masked { + if _, err := io.ReadFull(r, mask[:]); err != nil { + return nil, err + } + } + payload := make([]byte, length) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, err + } + if masked { + for i := range payload { + payload[i] ^= mask[i%4] + } + } + return payload, nil +} diff --git a/internal/workspace/config.go b/internal/workspace/config.go index a4b2633..32972d1 100644 --- a/internal/workspace/config.go +++ b/internal/workspace/config.go @@ -5,8 +5,6 @@ import ( "os" "path/filepath" "strings" - - "gopkg.in/yaml.v3" ) // ConfigDir returns the path to the global configuration directory. 
@@ -54,6 +52,26 @@ func WorkspaceConfigPath(dir string) string { return filepath.Join(dir, ".tld.yaml") } +// Config holds all global tld configuration, merging server settings, +// watch behaviors, and authentication. +type Config struct { + ServerURL string `yaml:"server_url"` + APIKey string `yaml:"api_key"` + WorkspaceID string `yaml:"org_id"` + Validation ValidationConfig `yaml:"validation"` + Serve ServeConfig `yaml:"serve"` + Watch WatchConfig `yaml:"watch"` + Completion CompletionConfig `yaml:"completion"` +} + +// ValidationConfig represents workspace validation settings. +type ValidationConfig struct { + Level int `yaml:"level"` + AllowLowInsight bool `yaml:"allow_low_insight"` + IncludeRules []string `yaml:"include_rules,omitempty"` + ExcludeRules []string `yaml:"exclude_rules,omitempty"` +} + // ServeConfig holds serve-specific settings from the global config file. type ServeConfig struct { Host string `yaml:"host"` @@ -61,55 +79,145 @@ type ServeConfig struct { DataDir string `yaml:"data_dir"` } -// GlobalConfig represents the global tld.yaml configuration file. -type GlobalConfig struct { - Serve ServeConfig `yaml:"serve"` +type WatchEmbeddingConfig struct { + Provider string `yaml:"provider"` + Endpoint string `yaml:"endpoint"` + Model string `yaml:"model"` + Dimension int `yaml:"dimension"` + HealthThreshold float64 `yaml:"health_threshold"` } -// LoadGlobalConfig reads the global config file. Missing file is not an error. -func LoadGlobalConfig() (*GlobalConfig, error) { - cfgPath, err := ConfigPath() - if err != nil { - return &GlobalConfig{}, nil - } - data, err := os.ReadFile(cfgPath) - if err != nil { - return &GlobalConfig{}, nil - } - var cfg GlobalConfig - if err := yaml.Unmarshal(data, &cfg); err != nil { - return &GlobalConfig{}, nil +type WatchThresholdConfig struct { + MaxElementsPerView int `yaml:"max_elements_per_view"` + MaxConnectorsPerView int `yaml:"max_connectors_per_view"` + MaxIncomingPerElement int `yaml:"max_incoming_per_element"` + MaxOutgoingPerElement int `yaml:"max_outgoing_per_element"` + MaxExpandedConnectorsPerGroup int `yaml:"max_expanded_connectors_per_group"` +} + +type WatchVisibilityWeightsConfig struct { + Changed float64 `yaml:"changed"` + Selected float64 `yaml:"selected"` + UserShow float64 `yaml:"user_show"` + UserHide float64 `yaml:"user_hide"` + HighSignalFact float64 `yaml:"high_signal_fact"` + RelationshipProximity float64 `yaml:"relationship_proximity"` + DependencyFact float64 `yaml:"dependency_fact"` + UtilityNoise float64 `yaml:"utility_noise"` + HighDegreeNoise float64 `yaml:"high_degree_noise"` +} + +type WatchVisibilityConfig struct { + CoreThresholdEnabled bool `yaml:"core_threshold_enabled"` + CoreThreshold float64 `yaml:"core_threshold"` + TierMultiplier float64 `yaml:"tier_multiplier"` + MaxExpansionMultiplier float64 `yaml:"max_expansion_multiplier"` + Weights WatchVisibilityWeightsConfig `yaml:"weights"` +} + +type WatchLayoutConfig struct { + LinkDistance float64 `yaml:"link_distance"` + ChargeStrength float64 `yaml:"charge_strength"` + CollideRadius float64 `yaml:"collide_radius"` + GravityStrength float64 `yaml:"gravity_strength"` +} + +type WatchConfig struct { + Languages []string `yaml:"languages"` + Watcher string `yaml:"watcher"` + PollInterval string `yaml:"poll_interval"` + Debounce string `yaml:"debounce"` + Thresholds WatchThresholdConfig `yaml:"thresholds"` + Visibility WatchVisibilityConfig `yaml:"visibility"` + Embedding WatchEmbeddingConfig `yaml:"embedding"` + Layout WatchLayoutConfig `yaml:"layout"` 
+} + +type CompletionConfig struct { + Remote bool `yaml:"remote"` +} + +const DefaultValidationLevel = 2 + +// DefaultConfig returns a Config struct populated with system defaults. +func DefaultConfig() *Config { + return &Config{ + ServerURL: "https://tldiagram.com", + Validation: ValidationConfig{ + Level: DefaultValidationLevel, + }, + Serve: ServeConfig{ + Host: "127.0.0.1", + Port: "8060", + }, + Watch: WatchConfig{ + Languages: []string{"go", "python", "typescript", "javascript", "java", "c", "cpp", "rust"}, + Watcher: "auto", + PollInterval: "1s", + Debounce: "500ms", + Thresholds: WatchThresholdConfig{ + MaxElementsPerView: 100, + MaxConnectorsPerView: 200, + MaxIncomingPerElement: 20, + MaxOutgoingPerElement: 20, + MaxExpandedConnectorsPerGroup: 24, + }, + Visibility: WatchVisibilityConfig{ + CoreThresholdEnabled: true, + CoreThreshold: 1, + TierMultiplier: 0.5, + MaxExpansionMultiplier: 2, + Weights: WatchVisibilityWeightsConfig{ + Changed: 100, + Selected: 100, + UserShow: 100, + UserHide: -100, + HighSignalFact: 1.5, + RelationshipProximity: 1, + DependencyFact: 0.2, + UtilityNoise: -0.8, + HighDegreeNoise: -1.5, + }, + }, + Embedding: WatchEmbeddingConfig{ + Provider: "local-lexical", + Endpoint: "http://127.0.0.1:8000/v1/embeddings", + Model: "embeddinggemma-300m-4bit", + HealthThreshold: 0.70, + }, + Layout: WatchLayoutConfig{ + LinkDistance: 100, + ChargeStrength: -400, + CollideRadius: 180, + GravityStrength: 0.05, + }, + }, } - return &cfg, nil } -// EnsureGlobalConfig ensures the global config file exists. -// If it doesn't, it writes a default one with commented instructions. -func EnsureGlobalConfig() error { - path, err := ConfigPath() +// LoadGlobalConfig reads the global config file, applies defaults to missing fields, +// handles environment variable overrides, and persists any added defaults back to YAML. +func LoadGlobalConfig() (*Config, error) { + state, err := LoadGlobalConfigState() if err != nil { - return err - } - if _, err := os.Stat(path); err == nil { - return nil // Already exists + return nil, err } + return state.Config, nil +} - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } +// SaveGlobalConfig writes the config back to the global configuration file. +func SaveGlobalConfig(cfg *Config) error { + return SaveGlobalConfigPreservingUnknown(cfg, nil) +} - defaultConfig := `# tlDiagram global configuration -serve: - host: 127.0.0.1 - port: 8060 - # data_dir: ~/.local/share/tldiagram -` - return os.WriteFile(path, []byte(defaultConfig), 0o644) +// EnsureGlobalConfig ensures the global config file exists with full defaults. +func EnsureGlobalConfig() error { + return SaveGlobalConfig(DefaultConfig()) } // ResolveDataDir returns the absolute path to the data directory, applying // resolution priority: flag > env (TLD_DATA_DIR) > config > default. -func ResolveDataDir(cfg *GlobalConfig, flagDir string) (string, error) { +func ResolveDataDir(cfg *Config, flagDir string) (string, error) { // 1. 
Flag if flagDir != "" { return filepath.Abs(flagDir) diff --git a/internal/workspace/config_registry.go b/internal/workspace/config_registry.go new file mode 100644 index 0000000..3a7571f --- /dev/null +++ b/internal/workspace/config_registry.go @@ -0,0 +1,1072 @@ +package workspace + +import ( + "encoding/json" + "fmt" + "net" + "net/url" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "time" + + "github.com/mertcikla/tld/internal/analyzer" + "gopkg.in/yaml.v3" +) + +type ConfigSource string + +const ( + ConfigSourceDefault ConfigSource = "default" + ConfigSourceFile ConfigSource = "file" + ConfigSourceEnv ConfigSource = "env" +) + +type ConfigDefinition struct { + Key string `json:"key"` + Env []string `json:"env,omitempty"` + Description string `json:"description"` + Secret bool `json:"secret,omitempty"` +} + +type ConfigValue struct { + Key string `json:"key"` + Value string `json:"value"` + Source ConfigSource `json:"source"` + Env string `json:"env,omitempty"` + Description string `json:"description"` + Secret bool `json:"secret,omitempty"` +} + +type ConfigValidationError struct { + Key string `json:"key"` + Message string `json:"message"` +} + +func (e ConfigValidationError) Error() string { + if e.Key == "" { + return e.Message + } + return e.Key + ": " + e.Message +} + +type ConfigValidationErrors []ConfigValidationError + +func (e ConfigValidationErrors) Error() string { + if len(e) == 0 { + return "" + } + if len(e) == 1 { + return e[0].Error() + } + return fmt.Sprintf("%s (+%d more)", e[0].Error(), len(e)-1) +} + +type GlobalConfigState struct { + Path string + Config *Config + File *Config + Values []ConfigValue + FileRoot *yaml.Node +} + +func ConfigDefinitions() []ConfigDefinition { + return append([]ConfigDefinition(nil), configDefinitions...) 
+} + +func ConfigDefinitionForKey(key string) (ConfigDefinition, bool) { + key = normalizeConfigKey(key) + for _, def := range configDefinitions { + if def.Key == key { + return def, true + } + } + return ConfigDefinition{}, false +} + +func LoadGlobalConfigState() (*GlobalConfigState, error) { + return loadGlobalConfigState(true) +} + +func LoadGlobalConfigStateNoRepair() (*GlobalConfigState, error) { + return loadGlobalConfigState(false) +} + +func SetGlobalConfigValue(key, value string) error { + key = normalizeConfigKey(key) + if _, ok := ConfigDefinitionForKey(key); !ok { + return fmt.Errorf("unknown global config key %q", key) + } + path, err := ConfigPath() + if err != nil { + return err + } + cfg := DefaultConfig() + existingRoot, data, err := readConfigNode(path) + if err != nil { + if os.IsNotExist(err) { + existingRoot = emptyConfigNode() + } else { + return err + } + } + if len(data) > 0 { + if err := yaml.Unmarshal(data, cfg); err != nil { + return fmt.Errorf("parse global config: %w", err) + } + } + if err := setConfigValue(cfg, key, value); err != nil { + return err + } + if errs := ValidateGlobalConfig(cfg); len(errs) > 0 { + return errs + } + return SaveGlobalConfigPreservingUnknown(cfg, existingRoot) +} + +func SaveGlobalConfigPreservingUnknown(cfg *Config, existingRoot *yaml.Node) error { + path, err := ConfigPath() + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + if existingRoot == nil { + existingRoot, _, _ = readConfigNode(path) + } + root := configToYAMLNode(cfg, existingRoot) + data, err := yaml.Marshal(root) + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func ValidateGlobalConfig(cfg *Config) ConfigValidationErrors { + var errs ConfigValidationErrors + add := func(key, msg string) { + errs = append(errs, ConfigValidationError{Key: key, Message: msg}) + } + + if strings.TrimSpace(cfg.ServerURL) != "" && !validHTTPURL(cfg.ServerURL) { + add("server_url", "must be a valid URL") + } + if cfg.Validation.Level < 1 || cfg.Validation.Level > 3 { + add("validation.level", "must be 1, 2, or 3") + } + if strings.TrimSpace(cfg.Serve.Host) == "" { + add("serve.host", "must be non-empty") + } + if !validPort(cfg.Serve.Port) { + add("serve.port", "must be an integer between 1 and 65535") + } + if strings.TrimSpace(cfg.Serve.DataDir) != "" { + if _, err := expandConfigPath(cfg.Serve.DataDir); err != nil { + add("serve.data_dir", err.Error()) + } + } + + for _, item := range []struct { + key string + value string + }{ + {"watch.poll_interval", cfg.Watch.PollInterval}, + {"watch.debounce", cfg.Watch.Debounce}, + } { + d, err := time.ParseDuration(strings.TrimSpace(item.value)) + if err != nil || d <= 0 { + add(item.key, "must be a positive duration such as 500ms or 1s") + } + } + switch strings.ToLower(strings.TrimSpace(cfg.Watch.Watcher)) { + case "auto", "fsnotify", "poll": + default: + add("watch.watcher", "must be auto, fsnotify, or poll") + } + if len(normalizeConfigLanguages(cfg.Watch.Languages)) == 0 { + add("watch.languages", "must include at least one supported language") + } + for _, item := range []struct { + key string + value int + }{ + {"watch.thresholds.max_elements_per_view", cfg.Watch.Thresholds.MaxElementsPerView}, + {"watch.thresholds.max_connectors_per_view", cfg.Watch.Thresholds.MaxConnectorsPerView}, + {"watch.thresholds.max_incoming_per_element", cfg.Watch.Thresholds.MaxIncomingPerElement}, + {"watch.thresholds.max_outgoing_per_element", 
cfg.Watch.Thresholds.MaxOutgoingPerElement}, + {"watch.thresholds.max_expanded_connectors_per_group", cfg.Watch.Thresholds.MaxExpandedConnectorsPerGroup}, + } { + if item.value <= 0 { + add(item.key, "must be positive") + } + } + for _, item := range []struct { + key string + value float64 + }{ + {"watch.visibility.core_threshold", cfg.Watch.Visibility.CoreThreshold}, + {"watch.visibility.tier_multiplier", cfg.Watch.Visibility.TierMultiplier}, + {"watch.visibility.max_expansion_multiplier", cfg.Watch.Visibility.MaxExpansionMultiplier}, + {"watch.layout.link_distance", cfg.Watch.Layout.LinkDistance}, + {"watch.layout.collide_radius", cfg.Watch.Layout.CollideRadius}, + {"watch.layout.gravity_strength", cfg.Watch.Layout.GravityStrength}, + } { + if item.value <= 0 { + add(item.key, "must be positive") + } + } + if cfg.Watch.Layout.ChargeStrength == 0 { + add("watch.layout.charge_strength", "must be non-zero") + } + + provider := strings.TrimSpace(cfg.Watch.Embedding.Provider) + switch provider { + case "none", "openai", "ollama", "local-lexical", "local-deterministic-test": + default: + add("watch.embedding.provider", "must be none, openai, ollama, local-lexical, or local-deterministic-test") + } + if cfg.Watch.Embedding.Dimension < 0 { + add("watch.embedding.dimension", "must be non-negative") + } + if provider == "openai" || provider == "ollama" { + if strings.TrimSpace(cfg.Watch.Embedding.Endpoint) == "" || !validHTTPURL(cfg.Watch.Embedding.Endpoint) { + add("watch.embedding.endpoint", "must be a valid URL for the selected provider") + } + if strings.TrimSpace(cfg.Watch.Embedding.Model) == "" { + add("watch.embedding.model", "must be non-empty for the selected provider") + } + if cfg.Watch.Embedding.HealthThreshold <= 0 || cfg.Watch.Embedding.HealthThreshold > 1 { + add("watch.embedding.health_threshold", "must be greater than 0 and at most 1") + } + } + return errs +} + +func ResolveServeOptions(cfg *Config, flagHost, flagPort string) ServeConfig { + if cfg == nil { + cfg = DefaultConfig() + } + out := cfg.Serve + if flagHost != "" { + out.Host = flagHost + } + if flagPort != "" { + out.Port = flagPort + } + return out +} + +func ResolveCompletionRemote() bool { + cfg, err := LoadGlobalConfig() + if err != nil { + return false + } + return cfg.Completion.Remote +} + +func ResolveWatchLayoutConfig() WatchLayoutConfig { + cfg, err := LoadGlobalConfig() + if err != nil { + return DefaultConfig().Watch.Layout + } + return cfg.Watch.Layout +} + +func FormatConfigValue(value any) string { + switch v := value.(type) { + case []string: + return strings.Join(v, ",") + case bool: + return strconv.FormatBool(v) + case int: + return strconv.Itoa(v) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case string: + return v + default: + data, _ := json.Marshal(v) + return string(data) + } +} + +var configDefinitions = []ConfigDefinition{ + {Key: "server_url", Env: []string{"TLD_SERVER_URL"}, Description: "tlDiagram cloud/server URL used by sync commands."}, + {Key: "api_key", Env: []string{"TLD_API_KEY"}, Description: "API key used to authenticate with tlDiagram.", Secret: true}, + {Key: "org_id", Env: []string{"TLD_ORG_ID"}, Description: "Default tlDiagram organization/workspace identifier."}, + {Key: "validation.level", Description: "Architectural warning strictness: 1 minimal, 2 standard, 3 strict."}, + {Key: "validation.allow_low_insight", Description: "Allow low-insight generated warning groups."}, + {Key: "validation.include_rules", Description: "Additional architectural warning rule 
codes to include."}, + {Key: "validation.exclude_rules", Description: "Architectural warning rule codes to suppress."}, + {Key: "serve.host", Env: []string{"TLD_HOST", "TLD_ADDR"}, Description: "Host address for the local web server."}, + {Key: "serve.port", Env: []string{"PORT", "TLD_ADDR"}, Description: "Port for the local web server."}, + {Key: "serve.data_dir", Env: []string{"TLD_DATA_DIR"}, Description: "Directory for local database, logs, and pid files."}, + {Key: "watch.languages", Env: []string{"TLD_WATCH_LANGUAGES"}, Description: "Comma-separated source languages watched by analyze/watch."}, + {Key: "watch.watcher", Env: []string{"TLD_WATCH_WATCHER"}, Description: "File watcher backend: auto, fsnotify, or poll."}, + {Key: "watch.poll_interval", Env: []string{"TLD_WATCH_POLL_INTERVAL"}, Description: "Polling interval used by the poll watcher."}, + {Key: "watch.debounce", Env: []string{"TLD_WATCH_DEBOUNCE"}, Description: "Delay used to batch file changes before rescanning."}, + {Key: "watch.thresholds.max_elements_per_view", Description: "Maximum generated elements in a watch-created view."}, + {Key: "watch.thresholds.max_connectors_per_view", Description: "Maximum generated connectors in a watch-created view."}, + {Key: "watch.thresholds.max_incoming_per_element", Description: "Incoming reference limit before collapsing context."}, + {Key: "watch.thresholds.max_outgoing_per_element", Description: "Outgoing reference limit before collapsing context."}, + {Key: "watch.thresholds.max_expanded_connectors_per_group", Description: "File-pair connector expansion limit before folder-level collapse."}, + {Key: "watch.visibility.core_threshold_enabled", Description: "Enable score thresholding for watch visibility decisions."}, + {Key: "watch.visibility.core_threshold", Description: "Minimum score for core watch visibility."}, + {Key: "watch.visibility.tier_multiplier", Description: "Density multiplier added by each Show Context tier."}, + {Key: "watch.visibility.max_expansion_multiplier", Description: "Maximum density multiplier allowed by Show Context."}, + {Key: "watch.visibility.weights.changed", Description: "Visibility score weight for changed resources."}, + {Key: "watch.visibility.weights.selected", Description: "Visibility score weight for selected context expansion resources."}, + {Key: "watch.visibility.weights.user_show", Description: "Visibility score weight for durable show policies."}, + {Key: "watch.visibility.weights.user_hide", Description: "Visibility score weight for durable hide policies."}, + {Key: "watch.visibility.weights.high_signal_fact", Description: "Visibility score weight for high-signal facts."}, + {Key: "watch.visibility.weights.relationship_proximity", Description: "Visibility score weight for graph/fact neighborhood proximity."}, + {Key: "watch.visibility.weights.dependency_fact", Description: "Visibility score weight for dependency facts."}, + {Key: "watch.visibility.weights.utility_noise", Description: "Visibility score penalty for utility-like noise."}, + {Key: "watch.visibility.weights.high_degree_noise", Description: "Visibility score penalty for high-degree noise."}, + {Key: "watch.embedding.provider", Env: []string{"TLD_EMBEDDING_PROVIDER"}, Description: "Embedding provider for watch identity and similarity."}, + {Key: "watch.embedding.endpoint", Env: []string{"TLD_EMBEDDING_ENDPOINT"}, Description: "Embedding provider endpoint when the provider uses HTTP."}, + {Key: "watch.embedding.model", Env: []string{"TLD_EMBEDDING_MODEL"}, Description: "Embedding 
model name."}, + {Key: "watch.embedding.dimension", Env: []string{"TLD_EMBEDDING_DIMENSION"}, Description: "Embedding vector dimension, or 0 to infer when supported."}, + {Key: "watch.embedding.health_threshold", Description: "Similarity threshold required by embedding health checks."}, + {Key: "watch.layout.link_distance", Env: []string{"LAYOUT_LINK_DISTANCE"}, Description: "Organic layout target link distance for generated watch views."}, + {Key: "watch.layout.charge_strength", Env: []string{"LAYOUT_CHARGE_STRENGTH"}, Description: "Organic layout node charge strength for generated watch views."}, + {Key: "watch.layout.collide_radius", Env: []string{"LAYOUT_COLLIDE_RADIUS"}, Description: "Organic layout collision radius for generated watch views."}, + {Key: "watch.layout.gravity_strength", Env: []string{"LAYOUT_GRAVITY_STRENGTH"}, Description: "Organic layout gravity strength for generated watch views."}, + {Key: "completion.remote", Env: []string{"TLD_COMPLETION_REMOTE"}, Description: "Allow shell completion to query remote resources."}, +} + +func loadGlobalConfigState(repair bool) (*GlobalConfigState, error) { + path, err := ConfigPath() + if err != nil { + return &GlobalConfigState{Config: DefaultConfig(), File: DefaultConfig(), Values: buildConfigValues(DefaultConfig(), nil, nil)}, nil + } + cfg := DefaultConfig() + fileCfg := DefaultConfig() + root, data, err := readConfigNode(path) + if err != nil { + if os.IsNotExist(err) { + if repair { + _ = SaveGlobalConfig(cfg) + } + values, applyErr := applyEnvOverridesDetailed(cfg, root) + return &GlobalConfigState{Path: path, Config: cfg, File: fileCfg, Values: values, FileRoot: root}, applyErr + } + return nil, err + } + if len(data) > 0 { + if err := yaml.Unmarshal(data, fileCfg); err != nil { + return nil, fmt.Errorf("parse global config: %w", err) + } + *cfg = *fileCfg + } + if repair && shouldSaveConfig(root) { + _ = SaveGlobalConfigPreservingUnknown(fileCfg, root) + } + values, err := applyEnvOverridesDetailed(cfg, root) + if err != nil { + return nil, err + } + return &GlobalConfigState{Path: path, Config: cfg, File: fileCfg, Values: values, FileRoot: root}, nil +} + +func readConfigNode(path string) (*yaml.Node, []byte, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, nil, err + } + var root yaml.Node + if err := yaml.Unmarshal(data, &root); err != nil { + return nil, data, fmt.Errorf("parse global config: %w", err) + } + if root.Kind == 0 { + root = yaml.Node{Kind: yaml.DocumentNode, Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}}} + } + if len(root.Content) == 0 || root.Content[0].Kind != yaml.MappingNode { + return nil, data, fmt.Errorf("parse global config: expected mapping document") + } + return &root, data, nil +} + +func emptyConfigNode() *yaml.Node { + return &yaml.Node{Kind: yaml.DocumentNode, Content: []*yaml.Node{{Kind: yaml.MappingNode, Tag: "!!map"}}} +} + +func applyEnvOverridesDetailed(cfg *Config, root *yaml.Node) ([]ConfigValue, error) { + sources := map[string]ConfigSource{} + envSources := map[string]string{} + for _, def := range configDefinitions { + if hasYAMLPath(root, def.Key) { + sources[def.Key] = ConfigSourceFile + } else { + sources[def.Key] = ConfigSourceDefault + } + } + apply := func(key, env, value string) error { + if value == "" { + return nil + } + if err := setConfigValue(cfg, key, value); err != nil { + return fmt.Errorf("%s from %s: %w", key, env, err) + } + sources[key] = ConfigSourceEnv + envSources[key] = env + return nil + } + for _, item := range 
[]struct { + key string + env string + }{ + {"server_url", "TLD_SERVER_URL"}, + {"api_key", "TLD_API_KEY"}, + {"org_id", "TLD_ORG_ID"}, + {"serve.host", "TLD_HOST"}, + {"serve.port", "PORT"}, + {"serve.data_dir", "TLD_DATA_DIR"}, + {"watch.languages", "TLD_WATCH_LANGUAGES"}, + {"watch.watcher", "TLD_WATCH_WATCHER"}, + {"watch.poll_interval", "TLD_WATCH_POLL_INTERVAL"}, + {"watch.debounce", "TLD_WATCH_DEBOUNCE"}, + {"watch.embedding.provider", "TLD_EMBEDDING_PROVIDER"}, + {"watch.embedding.endpoint", "TLD_EMBEDDING_ENDPOINT"}, + {"watch.embedding.model", "TLD_EMBEDDING_MODEL"}, + {"watch.embedding.dimension", "TLD_EMBEDDING_DIMENSION"}, + {"watch.layout.link_distance", "LAYOUT_LINK_DISTANCE"}, + {"watch.layout.charge_strength", "LAYOUT_CHARGE_STRENGTH"}, + {"watch.layout.collide_radius", "LAYOUT_COLLIDE_RADIUS"}, + {"watch.layout.gravity_strength", "LAYOUT_GRAVITY_STRENGTH"}, + } { + if err := apply(item.key, item.env, os.Getenv(item.env)); err != nil { + return nil, err + } + } + if v := os.Getenv("TLD_COMPLETION_REMOTE"); v != "" { + if err := apply("completion.remote", "TLD_COMPLETION_REMOTE", v); err != nil { + return nil, err + } + } + if addr := strings.TrimSpace(os.Getenv("TLD_ADDR")); addr != "" { + host, port, err := splitAddrOverride(addr) + if err != nil { + return nil, fmt.Errorf("serve.host/serve.port from TLD_ADDR: %w", err) + } + if host != "" { + if err := setConfigValue(cfg, "serve.host", host); err != nil { + return nil, err + } + sources["serve.host"] = ConfigSourceEnv + envSources["serve.host"] = "TLD_ADDR" + } + if port != "" { + if err := setConfigValue(cfg, "serve.port", port); err != nil { + return nil, err + } + sources["serve.port"] = ConfigSourceEnv + envSources["serve.port"] = "TLD_ADDR" + } + } + if errs := ValidateGlobalConfig(cfg); len(errs) > 0 { + return nil, errs + } + return buildConfigValues(cfg, sources, envSources), nil +} + +func buildConfigValues(cfg *Config, sources map[string]ConfigSource, envSources map[string]string) []ConfigValue { + values := make([]ConfigValue, 0, len(configDefinitions)) + for _, def := range configDefinitions { + source := sources[def.Key] + if source == "" { + source = ConfigSourceDefault + } + values = append(values, ConfigValue{ + Key: def.Key, + Value: FormatConfigValue(getConfigValue(cfg, def.Key)), + Source: source, + Env: configValueEnv(def, envSources[def.Key]), + Description: def.Description, + Secret: def.Secret, + }) + } + return values +} + +func configValueEnv(def ConfigDefinition, active string) string { + if active != "" { + return active + } + return strings.Join(def.Env, ",") +} + +func setConfigValue(cfg *Config, key, value string) error { + key = normalizeConfigKey(key) + switch key { + case "server_url": + cfg.ServerURL = strings.TrimSpace(value) + case "api_key": + cfg.APIKey = value + case "org_id": + cfg.WorkspaceID = strings.TrimSpace(value) + case "validation.level": + v, err := parseInt(value) + if err != nil { + return err + } + cfg.Validation.Level = v + case "validation.allow_low_insight": + v, err := parseBool(value) + if err != nil { + return err + } + cfg.Validation.AllowLowInsight = v + case "validation.include_rules": + cfg.Validation.IncludeRules = parseStringList(value) + case "validation.exclude_rules": + cfg.Validation.ExcludeRules = parseStringList(value) + case "serve.host": + cfg.Serve.Host = strings.TrimSpace(value) + case "serve.port": + cfg.Serve.Port = strings.TrimSpace(value) + case "serve.data_dir": + cfg.Serve.DataDir = strings.TrimSpace(value) + case "watch.languages": + 
cfg.Watch.Languages = parseStringList(value) + case "watch.watcher": + cfg.Watch.Watcher = strings.ToLower(strings.TrimSpace(value)) + case "watch.poll_interval": + cfg.Watch.PollInterval = strings.TrimSpace(value) + case "watch.debounce": + cfg.Watch.Debounce = strings.TrimSpace(value) + case "watch.thresholds.max_elements_per_view": + v, err := parseInt(value) + if err != nil { + return err + } + cfg.Watch.Thresholds.MaxElementsPerView = v + case "watch.thresholds.max_connectors_per_view": + v, err := parseInt(value) + if err != nil { + return err + } + cfg.Watch.Thresholds.MaxConnectorsPerView = v + case "watch.thresholds.max_incoming_per_element": + v, err := parseInt(value) + if err != nil { + return err + } + cfg.Watch.Thresholds.MaxIncomingPerElement = v + case "watch.thresholds.max_outgoing_per_element": + v, err := parseInt(value) + if err != nil { + return err + } + cfg.Watch.Thresholds.MaxOutgoingPerElement = v + case "watch.thresholds.max_expanded_connectors_per_group": + v, err := parseInt(value) + if err != nil { + return err + } + cfg.Watch.Thresholds.MaxExpandedConnectorsPerGroup = v + case "watch.visibility.core_threshold_enabled": + v, err := parseBool(value) + if err != nil { + return err + } + cfg.Watch.Visibility.CoreThresholdEnabled = v + case "watch.visibility.core_threshold": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.CoreThreshold = v + case "watch.visibility.tier_multiplier": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.TierMultiplier = v + case "watch.visibility.max_expansion_multiplier": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.MaxExpansionMultiplier = v + case "watch.visibility.weights.changed": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.Changed = v + case "watch.visibility.weights.selected": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.Selected = v + case "watch.visibility.weights.user_show": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.UserShow = v + case "watch.visibility.weights.user_hide": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.UserHide = v + case "watch.visibility.weights.high_signal_fact": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.HighSignalFact = v + case "watch.visibility.weights.relationship_proximity": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.RelationshipProximity = v + case "watch.visibility.weights.dependency_fact": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.DependencyFact = v + case "watch.visibility.weights.utility_noise": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.UtilityNoise = v + case "watch.visibility.weights.high_degree_noise": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Visibility.Weights.HighDegreeNoise = v + case "watch.embedding.provider": + cfg.Watch.Embedding.Provider = strings.TrimSpace(value) + case "watch.embedding.endpoint": + cfg.Watch.Embedding.Endpoint = strings.TrimSpace(value) + case "watch.embedding.model": + cfg.Watch.Embedding.Model = strings.TrimSpace(value) + case "watch.embedding.dimension": + v, err := parseInt(value) + if err != nil { 
+ return err + } + cfg.Watch.Embedding.Dimension = v + case "watch.embedding.health_threshold": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Embedding.HealthThreshold = v + case "watch.layout.link_distance": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Layout.LinkDistance = v + case "watch.layout.charge_strength": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Layout.ChargeStrength = v + case "watch.layout.collide_radius": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Layout.CollideRadius = v + case "watch.layout.gravity_strength": + v, err := parseFloat(value) + if err != nil { + return err + } + cfg.Watch.Layout.GravityStrength = v + case "completion.remote": + v, err := parseBool(value) + if err != nil { + return err + } + cfg.Completion.Remote = v + default: + return fmt.Errorf("unknown global config key %q", key) + } + return nil +} + +func getConfigValue(cfg *Config, key string) any { + switch normalizeConfigKey(key) { + case "server_url": + return cfg.ServerURL + case "api_key": + return cfg.APIKey + case "org_id": + return cfg.WorkspaceID + case "validation.level": + return cfg.Validation.Level + case "validation.allow_low_insight": + return cfg.Validation.AllowLowInsight + case "validation.include_rules": + return cfg.Validation.IncludeRules + case "validation.exclude_rules": + return cfg.Validation.ExcludeRules + case "serve.host": + return cfg.Serve.Host + case "serve.port": + return cfg.Serve.Port + case "serve.data_dir": + return cfg.Serve.DataDir + case "watch.languages": + return cfg.Watch.Languages + case "watch.watcher": + return cfg.Watch.Watcher + case "watch.poll_interval": + return cfg.Watch.PollInterval + case "watch.debounce": + return cfg.Watch.Debounce + case "watch.thresholds.max_elements_per_view": + return cfg.Watch.Thresholds.MaxElementsPerView + case "watch.thresholds.max_connectors_per_view": + return cfg.Watch.Thresholds.MaxConnectorsPerView + case "watch.thresholds.max_incoming_per_element": + return cfg.Watch.Thresholds.MaxIncomingPerElement + case "watch.thresholds.max_outgoing_per_element": + return cfg.Watch.Thresholds.MaxOutgoingPerElement + case "watch.thresholds.max_expanded_connectors_per_group": + return cfg.Watch.Thresholds.MaxExpandedConnectorsPerGroup + case "watch.visibility.core_threshold_enabled": + return cfg.Watch.Visibility.CoreThresholdEnabled + case "watch.visibility.core_threshold": + return cfg.Watch.Visibility.CoreThreshold + case "watch.visibility.tier_multiplier": + return cfg.Watch.Visibility.TierMultiplier + case "watch.visibility.max_expansion_multiplier": + return cfg.Watch.Visibility.MaxExpansionMultiplier + case "watch.visibility.weights.changed": + return cfg.Watch.Visibility.Weights.Changed + case "watch.visibility.weights.selected": + return cfg.Watch.Visibility.Weights.Selected + case "watch.visibility.weights.user_show": + return cfg.Watch.Visibility.Weights.UserShow + case "watch.visibility.weights.user_hide": + return cfg.Watch.Visibility.Weights.UserHide + case "watch.visibility.weights.high_signal_fact": + return cfg.Watch.Visibility.Weights.HighSignalFact + case "watch.visibility.weights.relationship_proximity": + return cfg.Watch.Visibility.Weights.RelationshipProximity + case "watch.visibility.weights.dependency_fact": + return cfg.Watch.Visibility.Weights.DependencyFact + case "watch.visibility.weights.utility_noise": + return cfg.Watch.Visibility.Weights.UtilityNoise + case 
"watch.visibility.weights.high_degree_noise": + return cfg.Watch.Visibility.Weights.HighDegreeNoise + case "watch.embedding.provider": + return cfg.Watch.Embedding.Provider + case "watch.embedding.endpoint": + return cfg.Watch.Embedding.Endpoint + case "watch.embedding.model": + return cfg.Watch.Embedding.Model + case "watch.embedding.dimension": + return cfg.Watch.Embedding.Dimension + case "watch.embedding.health_threshold": + return cfg.Watch.Embedding.HealthThreshold + case "watch.layout.link_distance": + return cfg.Watch.Layout.LinkDistance + case "watch.layout.charge_strength": + return cfg.Watch.Layout.ChargeStrength + case "watch.layout.collide_radius": + return cfg.Watch.Layout.CollideRadius + case "watch.layout.gravity_strength": + return cfg.Watch.Layout.GravityStrength + case "completion.remote": + return cfg.Completion.Remote + default: + return "" + } +} + +func configToYAMLNode(cfg *Config, existingRoot *yaml.Node) *yaml.Node { + var existing *yaml.Node + if existingRoot != nil && len(existingRoot.Content) > 0 { + existing = existingRoot.Content[0] + } + root := &yaml.Node{Kind: yaml.DocumentNode} + mapping := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + root.Content = []*yaml.Node{mapping} + + addScalar(mapping, "server_url", cfg.ServerURL, desc("server_url")) + addScalar(mapping, "api_key", cfg.APIKey, desc("api_key")) + addScalar(mapping, "org_id", cfg.WorkspaceID, desc("org_id")) + + validation := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(validation, "level", cfg.Validation.Level, desc("validation.level")) + addScalar(validation, "allow_low_insight", cfg.Validation.AllowLowInsight, desc("validation.allow_low_insight")) + addStringSeq(validation, "include_rules", cfg.Validation.IncludeRules, desc("validation.include_rules")) + addStringSeq(validation, "exclude_rules", cfg.Validation.ExcludeRules, desc("validation.exclude_rules")) + appendUnknownEntries(validation, mappingValueNode(existing, "validation"), setOf("level", "allow_low_insight", "include_rules", "exclude_rules")) + addMap(mapping, "validation", validation, "Workspace validation and architectural warning settings.") + + serve := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(serve, "host", cfg.Serve.Host, desc("serve.host")) + addScalar(serve, "port", cfg.Serve.Port, desc("serve.port")) + addScalar(serve, "data_dir", cfg.Serve.DataDir, desc("serve.data_dir")) + appendUnknownEntries(serve, mappingValueNode(existing, "serve"), setOf("host", "port", "data_dir")) + addMap(mapping, "serve", serve, "Local web server settings.") + + watchNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addStringSeq(watchNode, "languages", cfg.Watch.Languages, desc("watch.languages")) + addScalar(watchNode, "watcher", cfg.Watch.Watcher, desc("watch.watcher")) + addScalar(watchNode, "poll_interval", cfg.Watch.PollInterval, desc("watch.poll_interval")) + addScalar(watchNode, "debounce", cfg.Watch.Debounce, desc("watch.debounce")) + + thresholds := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(thresholds, "max_elements_per_view", cfg.Watch.Thresholds.MaxElementsPerView, desc("watch.thresholds.max_elements_per_view")) + addScalar(thresholds, "max_connectors_per_view", cfg.Watch.Thresholds.MaxConnectorsPerView, desc("watch.thresholds.max_connectors_per_view")) + addScalar(thresholds, "max_incoming_per_element", cfg.Watch.Thresholds.MaxIncomingPerElement, desc("watch.thresholds.max_incoming_per_element")) + addScalar(thresholds, "max_outgoing_per_element", 
cfg.Watch.Thresholds.MaxOutgoingPerElement, desc("watch.thresholds.max_outgoing_per_element")) + addScalar(thresholds, "max_expanded_connectors_per_group", cfg.Watch.Thresholds.MaxExpandedConnectorsPerGroup, desc("watch.thresholds.max_expanded_connectors_per_group")) + appendUnknownEntries(thresholds, mappingValueNode(mappingValueNode(existing, "watch"), "thresholds"), setOf("max_elements_per_view", "max_connectors_per_view", "max_incoming_per_element", "max_outgoing_per_element", "max_expanded_connectors_per_group")) + addMap(watchNode, "thresholds", thresholds, "Limits used while materializing generated watch views.") + + visibility := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(visibility, "core_threshold_enabled", cfg.Watch.Visibility.CoreThresholdEnabled, desc("watch.visibility.core_threshold_enabled")) + addScalar(visibility, "core_threshold", cfg.Watch.Visibility.CoreThreshold, desc("watch.visibility.core_threshold")) + addScalar(visibility, "tier_multiplier", cfg.Watch.Visibility.TierMultiplier, desc("watch.visibility.tier_multiplier")) + addScalar(visibility, "max_expansion_multiplier", cfg.Watch.Visibility.MaxExpansionMultiplier, desc("watch.visibility.max_expansion_multiplier")) + weights := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(weights, "changed", cfg.Watch.Visibility.Weights.Changed, desc("watch.visibility.weights.changed")) + addScalar(weights, "selected", cfg.Watch.Visibility.Weights.Selected, desc("watch.visibility.weights.selected")) + addScalar(weights, "user_show", cfg.Watch.Visibility.Weights.UserShow, desc("watch.visibility.weights.user_show")) + addScalar(weights, "user_hide", cfg.Watch.Visibility.Weights.UserHide, desc("watch.visibility.weights.user_hide")) + addScalar(weights, "high_signal_fact", cfg.Watch.Visibility.Weights.HighSignalFact, desc("watch.visibility.weights.high_signal_fact")) + addScalar(weights, "relationship_proximity", cfg.Watch.Visibility.Weights.RelationshipProximity, desc("watch.visibility.weights.relationship_proximity")) + addScalar(weights, "dependency_fact", cfg.Watch.Visibility.Weights.DependencyFact, desc("watch.visibility.weights.dependency_fact")) + addScalar(weights, "utility_noise", cfg.Watch.Visibility.Weights.UtilityNoise, desc("watch.visibility.weights.utility_noise")) + addScalar(weights, "high_degree_noise", cfg.Watch.Visibility.Weights.HighDegreeNoise, desc("watch.visibility.weights.high_degree_noise")) + appendUnknownEntries(weights, mappingValueNode(mappingValueNode(mappingValueNode(existing, "watch"), "visibility"), "weights"), setOf("changed", "selected", "user_show", "user_hide", "high_signal_fact", "relationship_proximity", "dependency_fact", "utility_noise", "high_degree_noise")) + addMap(visibility, "weights", weights, "Visibility scoring weights.") + appendUnknownEntries(visibility, mappingValueNode(mappingValueNode(existing, "watch"), "visibility"), setOf("core_threshold_enabled", "core_threshold", "tier_multiplier", "max_expansion_multiplier", "weights")) + addMap(watchNode, "visibility", visibility, "Scoring and density-tier settings for watch context.") + + embedding := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(embedding, "provider", cfg.Watch.Embedding.Provider, desc("watch.embedding.provider")) + addScalar(embedding, "endpoint", cfg.Watch.Embedding.Endpoint, desc("watch.embedding.endpoint")) + addScalar(embedding, "model", cfg.Watch.Embedding.Model, desc("watch.embedding.model")) + addScalar(embedding, "dimension", cfg.Watch.Embedding.Dimension, 
desc("watch.embedding.dimension")) + addScalar(embedding, "health_threshold", cfg.Watch.Embedding.HealthThreshold, desc("watch.embedding.health_threshold")) + appendUnknownEntries(embedding, mappingValueNode(mappingValueNode(existing, "watch"), "embedding"), setOf("provider", "endpoint", "model", "dimension", "health_threshold")) + addMap(watchNode, "embedding", embedding, "Embedding settings used by watch/analyze identity matching.") + + layout := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(layout, "link_distance", cfg.Watch.Layout.LinkDistance, desc("watch.layout.link_distance")) + addScalar(layout, "charge_strength", cfg.Watch.Layout.ChargeStrength, desc("watch.layout.charge_strength")) + addScalar(layout, "collide_radius", cfg.Watch.Layout.CollideRadius, desc("watch.layout.collide_radius")) + addScalar(layout, "gravity_strength", cfg.Watch.Layout.GravityStrength, desc("watch.layout.gravity_strength")) + appendUnknownEntries(layout, mappingValueNode(mappingValueNode(existing, "watch"), "layout"), setOf("link_distance", "charge_strength", "collide_radius", "gravity_strength")) + addMap(watchNode, "layout", layout, "Organic layout tuning for generated watch views.") + + appendUnknownEntries(watchNode, mappingValueNode(existing, "watch"), setOf("languages", "watcher", "poll_interval", "debounce", "thresholds", "visibility", "embedding", "layout")) + addMap(mapping, "watch", watchNode, "Source watch/analyze pipeline settings.") + + completion := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} + addScalar(completion, "remote", cfg.Completion.Remote, desc("completion.remote")) + appendUnknownEntries(completion, mappingValueNode(existing, "completion"), setOf("remote")) + addMap(mapping, "completion", completion, "Shell completion settings.") + + appendUnknownEntries(mapping, existing, setOf("server_url", "api_key", "org_id", "validation", "serve", "watch", "completion")) + return root +} + +func addMap(mapping *yaml.Node, key string, value *yaml.Node, comment string) { + mapping.Content = append(mapping.Content, &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key, HeadComment: comment}, value) +} + +func addStringSeq(mapping *yaml.Node, key string, values []string, comment string) { + seq := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"} + for _, value := range values { + seq.Content = append(seq.Content, &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: value}) + } + mapping.Content = append(mapping.Content, &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key, HeadComment: comment}, seq) +} + +func addScalar(mapping *yaml.Node, key string, value any, comment string) { + mapping.Content = append(mapping.Content, &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key, HeadComment: comment}, scalarNode(value)) +} + +func scalarNode(value any) *yaml.Node { + switch v := value.(type) { + case bool: + return &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: strconv.FormatBool(v)} + case int: + return &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!int", Value: strconv.Itoa(v)} + case float64: + return &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!float", Value: strconv.FormatFloat(v, 'f', -1, 64)} + default: + return &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: fmt.Sprint(v)} + } +} + +func appendUnknownEntries(dst, src *yaml.Node, known map[string]struct{}) { + if src == nil || src.Kind != yaml.MappingNode { + return + } + for i := 0; i+1 < len(src.Content); i += 2 { + key := src.Content[i].Value + if _, ok := known[key]; ok { + continue + } + 
dst.Content = append(dst.Content, cloneYAMLNode(src.Content[i]), cloneYAMLNode(src.Content[i+1])) + } +} + +func shouldSaveConfig(root *yaml.Node) bool { + if root == nil { + return true + } + for _, def := range configDefinitions { + if !hasYAMLPath(root, def.Key) { + return true + } + } + return false +} + +func hasYAMLPath(root *yaml.Node, dotted string) bool { + if root == nil || len(root.Content) == 0 { + return false + } + node := root.Content[0] + for part := range strings.SplitSeq(dotted, ".") { + if node == nil || node.Kind != yaml.MappingNode { + return false + } + node = mappingValueNode(node, part) + if node == nil { + return false + } + } + return true +} + +func desc(key string) string { + if def, ok := ConfigDefinitionForKey(key); ok { + return def.Description + } + return "" +} + +func normalizeConfigKey(key string) string { + return strings.ToLower(strings.TrimSpace(key)) +} + +func parseStringList(value string) []string { + if strings.TrimSpace(value) == "" { + return nil + } + parts := strings.Split(value, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + out = append(out, part) + } + } + return out +} + +func parseBool(value string) (bool, error) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "true", "yes", "on": + return true, nil + case "0", "false", "no", "off": + return false, nil + default: + return false, fmt.Errorf("must be a boolean") + } +} + +func parseInt(value string) (int, error) { + v, err := strconv.Atoi(strings.TrimSpace(value)) + if err != nil { + return 0, fmt.Errorf("must be an integer") + } + return v, nil +} + +func parseFloat(value string) (float64, error) { + v, err := strconv.ParseFloat(strings.TrimSpace(value), 64) + if err != nil { + return 0, fmt.Errorf("must be a number") + } + return v, nil +} + +func validHTTPURL(value string) bool { + parsed, err := url.Parse(strings.TrimSpace(value)) + return err == nil && parsed.Scheme != "" && parsed.Host != "" +} + +func validPort(value string) bool { + port, err := strconv.Atoi(strings.TrimSpace(value)) + return err == nil && port >= 1 && port <= 65535 +} + +func expandConfigPath(path string) (string, error) { + path = strings.TrimSpace(path) + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("user home dir: %w", err) + } + path = filepath.Join(home, path[2:]) + } + return filepath.Abs(path) +} + +func splitAddrOverride(addr string) (host string, port string, err error) { + if strings.Count(addr, ":") == 1 { + parts := strings.Split(addr, ":") + return parts[0], parts[1], nil + } + if strings.Contains(addr, ":") { + h, p, err := net.SplitHostPort(addr) + if err != nil { + return "", "", err + } + return h, p, nil + } + return addr, "", nil +} + +func normalizeConfigLanguages(values []string) []string { + seen := map[string]struct{}{} + for _, value := range values { + lang := strings.ToLower(strings.TrimSpace(value)) + if lang == "" { + continue + } + if _, ok := analyzer.LanguageSpecFor(analyzer.Language(lang)); ok { + seen[lang] = struct{}{} + } + } + out := make([]string, 0, len(seen)) + for lang := range seen { + out = append(out, lang) + } + sort.Strings(out) + return out +} + +func setOf(values ...string) map[string]struct{} { + out := make(map[string]struct{}, len(values)) + for _, value := range values { + out[value] = struct{}{} + } + return out +} diff --git a/internal/workspace/config_registry_test.go 
b/internal/workspace/config_registry_test.go new file mode 100644 index 0000000..5d1e429 --- /dev/null +++ b/internal/workspace/config_registry_test.go @@ -0,0 +1,95 @@ +package workspace_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/mertcikla/tld/internal/workspace" +) + +func TestLoadGlobalConfigStateReportsEnvSourcesAndDoesNotPersistOverrides(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + t.Setenv("TLD_API_KEY", "env-secret") + + configPath := filepath.Join(configDir, "tld.yaml") + writeFile(t, configPath, "server_url: http://file.example\nunknown_root: keep-me\n") + + state, err := workspace.LoadGlobalConfigState() + if err != nil { + t.Fatalf("LoadGlobalConfigState: %v", err) + } + if state.Config.APIKey != "env-secret" { + t.Fatalf("APIKey = %q, want env-secret", state.Config.APIKey) + } + + var apiKey workspace.ConfigValue + for _, value := range state.Values { + if value.Key == "api_key" { + apiKey = value + break + } + } + if apiKey.Source != workspace.ConfigSourceEnv || apiKey.Env != "TLD_API_KEY" { + t.Fatalf("api_key source = %q env = %q, want env/TLD_API_KEY", apiKey.Source, apiKey.Env) + } + + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("read config: %v", err) + } + content := string(data) + if strings.Contains(content, "env-secret") { + t.Fatalf("env override was persisted:\n%s", content) + } + if !strings.Contains(content, "unknown_root: keep-me") { + t.Fatalf("unknown key was not preserved:\n%s", content) + } + if !strings.Contains(content, "tlDiagram cloud/server URL") || !strings.Contains(content, "Shell completion settings") { + t.Fatalf("generated comments missing from config:\n%s", content) + } +} + +func TestSetGlobalConfigValuePreservesUnknownAndValidates(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + configPath := filepath.Join(configDir, "tld.yaml") + writeFile(t, configPath, "server_url: https://tldiagram.com\nunknown_root: keep-me\nwatch:\n unknown_watch: still-here\n") + + if err := workspace.SetGlobalConfigValue("serve.port", "9000"); err != nil { + t.Fatalf("SetGlobalConfigValue: %v", err) + } + cfg, err := workspace.LoadGlobalConfig() + if err != nil { + t.Fatalf("LoadGlobalConfig: %v", err) + } + if cfg.Serve.Port != "9000" { + t.Fatalf("Serve.Port = %q, want 9000", cfg.Serve.Port) + } + data, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("read config: %v", err) + } + content := string(data) + if !strings.Contains(content, "unknown_root: keep-me") || !strings.Contains(content, "unknown_watch: still-here") { + t.Fatalf("unknown keys were not preserved:\n%s", content) + } + + if err := workspace.SetGlobalConfigValue("watch.watcher", "bogus"); err == nil { + t.Fatal("expected invalid watcher to fail") + } +} + +func TestResolveWatchLayoutConfigUsesEnvOverride(t *testing.T) { + configDir := t.TempDir() + t.Setenv("TLD_CONFIG_DIR", configDir) + t.Setenv("LAYOUT_LINK_DISTANCE", "222") + writeFile(t, filepath.Join(configDir, "tld.yaml"), "watch:\n layout:\n link_distance: 111\n") + + got := workspace.ResolveWatchLayoutConfig() + if got.LinkDistance != 222 { + t.Fatalf("LinkDistance = %v, want env override 222", got.LinkDistance) + } +} diff --git a/internal/workspace/config_test.go b/internal/workspace/config_test.go index ec1d484..8f1675f 100644 --- a/internal/workspace/config_test.go +++ b/internal/workspace/config_test.go @@ -13,7 +13,7 @@ func TestResolveDataDirDefaultsToAppDataDir(t *testing.T) { xdgData := 
filepath.Join(t.TempDir(), "xdg-data")
 	t.Setenv("XDG_DATA_HOME", xdgData)
 
-	got, err := workspace.ResolveDataDir(&workspace.GlobalConfig{}, "")
+	got, err := workspace.ResolveDataDir(&workspace.Config{}, "")
 	if err != nil {
 		t.Fatalf("ResolveDataDir: %v", err)
 	}
diff --git a/internal/workspace/loader.go b/internal/workspace/loader.go
index 37e5621..fab127b 100644
--- a/internal/workspace/loader.go
+++ b/internal/workspace/loader.go
@@ -19,21 +19,11 @@ func Load(dir string) (*Workspace, error) {
 	}
 
 	// Load config
-	cfgPath, err := ConfigPath()
+	cfg, err := LoadGlobalConfig()
 	if err != nil {
-		return nil, fmt.Errorf("get config path: %w", err)
-	}
-	cfgData, err := os.ReadFile(cfgPath)
-	if err != nil {
-		return nil, fmt.Errorf("read tld.yaml: %w", err)
-	}
-	if err := yaml.Unmarshal(cfgData, &ws.Config); err != nil {
-		return nil, fmt.Errorf("parse tld.yaml: %w", err)
-	}
-	// Fallback: TLD_API_KEY env var
-	if ws.Config.APIKey == "" {
-		ws.Config.APIKey = os.Getenv("TLD_API_KEY")
+		return nil, fmt.Errorf("load global config: %w", err)
 	}
+	ws.Config = *cfg
 
 	// Load workspace-local configuration from .tld.yaml if present.
 	workspaceConfigPath := WorkspaceConfigPath(dir)
diff --git a/internal/workspace/loader_test.go b/internal/workspace/loader_test.go
index 75bf501..3428761 100644
--- a/internal/workspace/loader_test.go
+++ b/internal/workspace/loader_test.go
@@ -46,16 +46,16 @@ func TestLoad_MinimalWorkspace(t *testing.T) {
 	}
 }
 
-func TestLoad_MissingConfigFile(t *testing.T) {
+func TestLoad_MissingConfigFileReturnsDefaults(t *testing.T) {
 	dir := t.TempDir()
-	setupConfig(t)
+	setupConfig(t) // config file doesn't exist yet
 
-	_, err := workspace.Load(dir)
-	if err == nil {
-		t.Fatal("expected error, got nil")
+	ws, err := workspace.Load(dir)
+	if err != nil {
+		t.Fatalf("Load: %v", err)
 	}
-	if !strings.Contains(err.Error(), "read tld.yaml") {
-		t.Fatalf("error %q does not contain 'read tld.yaml'", err.Error())
+	if ws.Config.ServerURL != "https://tldiagram.com" {
+		t.Fatalf("expected default ServerURL, got %q", ws.Config.ServerURL)
 	}
 }
 
@@ -67,16 +67,15 @@ func TestLoad_MalformedConfigYAML(t *testing.T) {
 	if err == nil {
 		t.Fatal("expected error")
 	}
-	if !strings.Contains(err.Error(), "parse tld.yaml") {
-		t.Fatalf("error %q does not contain 'parse tld.yaml'", err.Error())
+	if !strings.Contains(err.Error(), "parse global config") {
+		t.Fatalf("error %q does not contain 'parse global config'", err.Error())
 	}
 }
 
 func TestLoad_APIKeyFromEnv(t *testing.T) {
 	dir := t.TempDir()
 	writeFile(t, setupConfig(t), minimalConfig())
-	_ = os.Setenv("TLD_API_KEY", "env-test-key")
-	t.Cleanup(func() { _ = os.Unsetenv("TLD_API_KEY") })
+	t.Setenv("TLD_API_KEY", "env-test-key")
 
 	ws, err := workspace.Load(dir)
 	if err != nil {
@@ -87,18 +86,17 @@ func TestLoad_APIKeyFromEnv(t *testing.T) {
 	}
 }
 
-func TestLoad_APIKeyFileOverridesEnv(t *testing.T) {
+func TestLoad_EnvOverridesAPIKeyFile(t *testing.T) {
 	dir := t.TempDir()
 	writeFile(t, setupConfig(t), "server_url: http://localhost\napi_key: file-key\norg_id: \"\"\n")
-	_ = os.Setenv("TLD_API_KEY", "env-key")
-	t.Cleanup(func() { _ = os.Unsetenv("TLD_API_KEY") })
+	t.Setenv("TLD_API_KEY", "env-key")
 
 	ws, err := workspace.Load(dir)
 	if err != nil {
 		t.Fatalf("Load: %v", err)
 	}
-	if ws.Config.APIKey != "file-key" {
-		t.Fatalf("APIKey = %q, want file-key", ws.Config.APIKey)
+	if ws.Config.APIKey != "env-key" {
+		t.Fatalf("APIKey = %q, want env-key", ws.Config.APIKey)
 	}
 }
 
diff --git a/internal/workspace/types.go b/internal/workspace/types.go
index c7b536c..0279de4 100644
---
a/internal/workspace/types.go +++ b/internal/workspace/types.go @@ -8,14 +8,6 @@ import ( "github.com/mertcikla/tld/internal/ignore" ) -// Config is parsed from the user's global tld.yaml. -type Config struct { - ServerURL string `yaml:"server_url"` - APIKey string `yaml:"api_key"` - WorkspaceID string `yaml:"org_id"` - Validation *ValidationConfig `yaml:"validation,omitempty"` -} - // WorkspaceConfig is parsed from the workspace-local .tld.yaml. type WorkspaceConfig struct { ProjectName string `yaml:"project_name,omitempty"` @@ -37,16 +29,6 @@ type Repository struct { Exclude []string `yaml:"exclude,omitempty"` } -// ValidationConfig represents workspace validation settings. -const DefaultValidationLevel = 2 - -type ValidationConfig struct { - Level int `yaml:"level"` - AllowLowInsight bool `yaml:"allow_low_insight"` - IncludeRules []string `yaml:"include_rules,omitempty"` - ExcludeRules []string `yaml:"exclude_rules,omitempty"` -} - // ViewPlacement is an element placement within another element's internal view. // Parent "root" means the synthetic workspace root. type ViewPlacement struct { @@ -70,6 +52,7 @@ type Element struct { Language string `yaml:"language,omitempty"` FilePath string `yaml:"file_path,omitempty"` Symbol string `yaml:"symbol,omitempty"` // Named code symbol within FilePath (e.g. "MyFunc") + Tags []string `yaml:"tags,omitempty"` HasView bool `yaml:"has_view,omitempty"` ViewLabel string `yaml:"view_label,omitempty"` Placements []ViewPlacement `yaml:"placements,omitempty"` diff --git a/migrations/002_watch_raw_code_graph.sql b/migrations/002_watch_raw_code_graph.sql new file mode 100644 index 0000000..87cacf5 --- /dev/null +++ b/migrations/002_watch_raw_code_graph.sql @@ -0,0 +1,515 @@ +PRAGMA foreign_keys = ON; + +CREATE TABLE IF NOT EXISTS watch_repositories ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + remote_url TEXT NULL, + repo_root TEXT NOT NULL, + display_name TEXT NOT NULL, + branch TEXT NULL, + head_commit TEXT NULL, + identity_status TEXT NOT NULL DEFAULT 'known', + settings_hash TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_watch_repositories_remote_url + ON watch_repositories(remote_url) + WHERE remote_url IS NOT NULL AND remote_url <> ''; + +CREATE INDEX IF NOT EXISTS idx_watch_repositories_repo_root + ON watch_repositories(repo_root); + +CREATE TABLE IF NOT EXISTS watch_files ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + path TEXT NOT NULL, + language TEXT NOT NULL, + git_blob_hash TEXT NULL, + worktree_hash TEXT NOT NULL, + size_bytes INTEGER NOT NULL DEFAULT 0, + mtime_unix INTEGER NOT NULL DEFAULT 0, + scan_status TEXT NOT NULL, + scan_error TEXT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, path), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_files_repository_id + ON watch_files(repository_id); + +CREATE TABLE IF NOT EXISTS watch_symbols ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + file_id INTEGER NOT NULL, + stable_key TEXT NOT NULL, + name TEXT NOT NULL, + qualified_name TEXT NOT NULL, + kind TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NULL, + signature_hash TEXT NOT NULL, + content_hash TEXT NOT NULL, + raw_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, stable_key), + FOREIGN KEY (repository_id) 
REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (file_id) REFERENCES watch_files(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_symbols_repository_id + ON watch_symbols(repository_id); + +CREATE INDEX IF NOT EXISTS idx_watch_symbols_file_id + ON watch_symbols(file_id); + +CREATE INDEX IF NOT EXISTS idx_watch_symbols_search + ON watch_symbols(repository_id, name, qualified_name, kind); + +CREATE TABLE IF NOT EXISTS watch_references ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + source_symbol_id INTEGER NOT NULL, + target_symbol_id INTEGER NOT NULL, + source_file_id INTEGER NOT NULL, + kind TEXT NOT NULL, + line INTEGER NOT NULL, + column INTEGER NOT NULL DEFAULT 0, + evidence_hash TEXT NOT NULL, + raw_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, source_symbol_id, target_symbol_id, kind, evidence_hash), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (source_symbol_id) REFERENCES watch_symbols(id) ON DELETE CASCADE, + FOREIGN KEY (target_symbol_id) REFERENCES watch_symbols(id) ON DELETE CASCADE, + FOREIGN KEY (source_file_id) REFERENCES watch_files(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_references_repository_id + ON watch_references(repository_id); + +CREATE INDEX IF NOT EXISTS idx_watch_references_source_symbol_id + ON watch_references(source_symbol_id); + +CREATE INDEX IF NOT EXISTS idx_watch_references_target_symbol_id + ON watch_references(target_symbol_id); + +CREATE TABLE IF NOT EXISTS watch_facts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + file_id INTEGER NOT NULL, + stable_key TEXT NOT NULL, + type TEXT NOT NULL, + enricher TEXT NOT NULL, + subject_kind TEXT NOT NULL, + subject_stable_key TEXT NOT NULL, + object_kind TEXT NOT NULL DEFAULT '', + object_stable_key TEXT NOT NULL DEFAULT '', + object_file_path TEXT NOT NULL DEFAULT '', + object_name TEXT NOT NULL DEFAULT '', + relationship TEXT NOT NULL DEFAULT '', + file_path TEXT NOT NULL, + start_line INTEGER NOT NULL DEFAULT 0, + end_line INTEGER NULL, + confidence REAL NOT NULL DEFAULT 1.0, + name TEXT NOT NULL DEFAULT '', + tags TEXT NOT NULL DEFAULT '[]', + attributes_json TEXT NOT NULL DEFAULT '{}', + visibility_hints_json TEXT NOT NULL DEFAULT '{}', + fact_hash TEXT NOT NULL, + raw_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, enricher, stable_key), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (file_id) REFERENCES watch_files(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_facts_repository_id + ON watch_facts(repository_id); + +CREATE INDEX IF NOT EXISTS idx_watch_facts_file_id + ON watch_facts(file_id); + +CREATE INDEX IF NOT EXISTS idx_watch_facts_subject + ON watch_facts(repository_id, subject_kind, subject_stable_key); + +CREATE INDEX IF NOT EXISTS idx_watch_facts_object + ON watch_facts(repository_id, object_kind, object_stable_key); + +CREATE INDEX IF NOT EXISTS idx_watch_facts_type + ON watch_facts(repository_id, type); + +CREATE TABLE IF NOT EXISTS watch_scan_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + mode TEXT NOT NULL, + started_at TEXT NOT NULL, + finished_at TEXT NULL, + status TEXT NOT NULL, + files_seen INTEGER NOT NULL DEFAULT 0, + files_parsed INTEGER NOT NULL DEFAULT 0, + files_skipped INTEGER 
NOT NULL DEFAULT 0, + symbols_seen INTEGER NOT NULL DEFAULT 0, + references_seen INTEGER NOT NULL DEFAULT 0, + error TEXT NULL, + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_scan_runs_repository_id + ON watch_scan_runs(repository_id); + +CREATE TABLE IF NOT EXISTS watch_embedding_models ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + provider TEXT NOT NULL, + model TEXT NOT NULL, + dimension INTEGER NOT NULL, + config_hash TEXT NOT NULL, + created_at TEXT NOT NULL, + UNIQUE(provider, model, dimension, config_hash) +); + +CREATE TABLE IF NOT EXISTS watch_embeddings ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + owner_key TEXT NOT NULL, + input_hash TEXT NOT NULL, + vector BLOB NOT NULL, + created_at TEXT NOT NULL, + UNIQUE(model_id, owner_type, owner_key, input_hash), + FOREIGN KEY (model_id) REFERENCES watch_embedding_models(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS watch_filter_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + settings_hash TEXT NOT NULL, + raw_graph_hash TEXT NOT NULL, + started_at TEXT NOT NULL, + finished_at TEXT NULL, + status TEXT NOT NULL, + visible_symbols INTEGER NOT NULL DEFAULT 0, + hidden_symbols INTEGER NOT NULL DEFAULT 0, + visible_references INTEGER NOT NULL DEFAULT 0, + hidden_references INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_filter_runs_repository_id + ON watch_filter_runs(repository_id); + +CREATE TABLE IF NOT EXISTS watch_filter_decisions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + filter_run_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + owner_id INTEGER NOT NULL, + owner_key TEXT NOT NULL DEFAULT '', + decision TEXT NOT NULL, + reason TEXT NOT NULL, + score REAL NULL, + tier INTEGER NOT NULL DEFAULT 0, + signals_json TEXT NOT NULL DEFAULT '[]', + FOREIGN KEY (filter_run_id) REFERENCES watch_filter_runs(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_filter_decisions_filter_run_id + ON watch_filter_decisions(filter_run_id); + +CREATE INDEX IF NOT EXISTS idx_watch_filter_decisions_owner + ON watch_filter_decisions(owner_type, owner_id); + +CREATE INDEX IF NOT EXISTS idx_watch_filter_decisions_owner_key + ON watch_filter_decisions(owner_type, owner_key); + +CREATE TABLE IF NOT EXISTS watch_clusters ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + stable_key TEXT NOT NULL, + parent_cluster_id INTEGER NULL, + name TEXT NOT NULL, + kind TEXT NOT NULL, + algorithm TEXT NOT NULL, + settings_hash TEXT NOT NULL, + member_count INTEGER NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, stable_key), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (parent_cluster_id) REFERENCES watch_clusters(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_clusters_repository_id + ON watch_clusters(repository_id); + +CREATE TABLE IF NOT EXISTS watch_cluster_members ( + cluster_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + owner_id INTEGER NOT NULL, + PRIMARY KEY (cluster_id, owner_type, owner_id), + FOREIGN KEY (cluster_id) REFERENCES watch_clusters(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS watch_materialization ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + 
owner_key TEXT NOT NULL, + resource_type TEXT NOT NULL, + resource_id INTEGER NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + last_watch_hash TEXT NULL, + dirty INTEGER NOT NULL DEFAULT 0, + dirty_detected_at TEXT NULL, + UNIQUE(repository_id, owner_type, owner_key, resource_type), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_materialization_repository_id + ON watch_materialization(repository_id); + +CREATE TABLE IF NOT EXISTS watch_architecture_links ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + component_key TEXT NOT NULL, + target_repository_id INTEGER NOT NULL, + target_owner_type TEXT NOT NULL, + target_owner_key TEXT NOT NULL, + target_resource_type TEXT NOT NULL, + target_resource_id INTEGER NOT NULL, + role TEXT NOT NULL, + confidence REAL NOT NULL, + evidence_json TEXT NOT NULL DEFAULT '[]', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, component_key, target_repository_id, target_owner_type, target_owner_key, target_resource_type, role), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (target_repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_architecture_links_repository_id + ON watch_architecture_links(repository_id); + +CREATE INDEX IF NOT EXISTS idx_watch_architecture_links_target + ON watch_architecture_links(target_repository_id, target_owner_type, target_owner_key); + +CREATE TABLE IF NOT EXISTS watch_apply_locks ( + id INTEGER PRIMARY KEY, + repository_id INTEGER NOT NULL, + pid INTEGER NOT NULL, + token TEXT NOT NULL, + started_at TEXT NOT NULL, + heartbeat_at TEXT NOT NULL, + status TEXT NOT NULL, + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS watch_context_policies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + owner_key TEXT NOT NULL, + action TEXT NOT NULL, + scope TEXT NOT NULL, + active INTEGER NOT NULL DEFAULT 1, + reason TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_context_policies_repository_active + ON watch_context_policies(repository_id, active); + +CREATE INDEX IF NOT EXISTS idx_watch_context_policies_owner + ON watch_context_policies(repository_id, owner_type, owner_key); + +CREATE TABLE IF NOT EXISTS watch_context_expansions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + scope_resource_type TEXT NOT NULL, + scope_resource_id INTEGER NOT NULL, + scope_owner_type TEXT NOT NULL, + scope_owner_key TEXT NOT NULL, + tier INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, scope_resource_type, scope_resource_id, scope_owner_type, scope_owner_key), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_context_expansions_repository + ON watch_context_expansions(repository_id); + +CREATE INDEX IF NOT EXISTS idx_watch_context_expansions_scope + ON watch_context_expansions(repository_id, scope_resource_type, scope_resource_id); + +CREATE INDEX IF NOT EXISTS idx_watch_context_expansions_owner + ON watch_context_expansions(repository_id, scope_owner_type, 
scope_owner_key); + +CREATE TABLE IF NOT EXISTS watch_representation_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + raw_graph_hash TEXT NOT NULL, + filter_settings_hash TEXT NOT NULL, + embedding_model_id INTEGER NULL, + representation_hash TEXT NOT NULL, + started_at TEXT NOT NULL, + finished_at TEXT NULL, + status TEXT NOT NULL, + elements_created INTEGER NOT NULL DEFAULT 0, + elements_updated INTEGER NOT NULL DEFAULT 0, + connectors_created INTEGER NOT NULL DEFAULT 0, + connectors_updated INTEGER NOT NULL DEFAULT 0, + views_created INTEGER NOT NULL DEFAULT 0, + error TEXT NULL, + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE, + FOREIGN KEY (embedding_model_id) REFERENCES watch_embedding_models(id) ON DELETE SET NULL +); + +CREATE INDEX IF NOT EXISTS idx_watch_representation_runs_repository_id + ON watch_representation_runs(repository_id); + +CREATE TABLE IF NOT EXISTS watch_locks ( + id INTEGER PRIMARY KEY, + repository_id INTEGER NOT NULL, + pid INTEGER NOT NULL, + token TEXT NOT NULL, + started_at TEXT NOT NULL, + heartbeat_at TEXT NOT NULL, + status TEXT NOT NULL, + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_watch_locks_repository_active + ON watch_locks(repository_id) + WHERE status IN ('active', 'stopping'); + +CREATE TABLE IF NOT EXISTS watch_versions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + commit_hash TEXT NOT NULL, + parent_commit_hash TEXT NULL, + branch TEXT NULL, + representation_hash TEXT NOT NULL, + workspace_version_id INTEGER NULL, + created_at TEXT NOT NULL, + commit_message TEXT NULL, + UNIQUE(repository_id, commit_hash, representation_hash), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_versions_repository_id + ON watch_versions(repository_id); + +CREATE TABLE IF NOT EXISTS watch_representation_diffs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + owner_key TEXT NOT NULL, + change_type TEXT NOT NULL, + before_hash TEXT NULL, + after_hash TEXT NULL, + resource_type TEXT NULL, + resource_id INTEGER NULL, + summary TEXT NULL, + added_lines INTEGER NOT NULL DEFAULT 0, + removed_lines INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (version_id) REFERENCES watch_versions(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_representation_diffs_version_id + ON watch_representation_diffs(version_id); + +CREATE TABLE IF NOT EXISTS workspace_versions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version_id TEXT NOT NULL UNIQUE, + source TEXT NOT NULL, + parent_version_id INTEGER NULL, + view_count INTEGER NOT NULL DEFAULT 0, + element_count INTEGER NOT NULL DEFAULT 0, + connector_count INTEGER NOT NULL DEFAULT 0, + description TEXT NULL, + workspace_hash TEXT NULL, + created_at TEXT NOT NULL, + FOREIGN KEY (parent_version_id) REFERENCES workspace_versions(id) ON DELETE SET NULL +); + +CREATE TABLE IF NOT EXISTS workspace_version_settings ( + id INTEGER PRIMARY KEY CHECK (id = 1), + cli_versioning_enabled INTEGER NOT NULL DEFAULT 1 +); + +INSERT INTO workspace_version_settings(id, cli_versioning_enabled) +VALUES (1, 1) +ON CONFLICT(id) DO NOTHING; + +CREATE TABLE IF NOT EXISTS watch_symbol_identities ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + repository_id INTEGER NOT NULL, + identity_key TEXT NOT NULL, + current_stable_key TEXT NOT NULL, + file_path TEXT 
NOT NULL, + kind TEXT NOT NULL, + name TEXT NOT NULL, + qualified_name TEXT NOT NULL, + start_line INTEGER NOT NULL, + content_hash TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + UNIQUE(repository_id, identity_key), + UNIQUE(repository_id, current_stable_key), + FOREIGN KEY (repository_id) REFERENCES watch_repositories(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_symbol_identities_current_key + ON watch_symbol_identities(repository_id, current_stable_key); + +CREATE TABLE IF NOT EXISTS _vec_watch_embedding_vec ( + dataset_id TEXT NOT NULL, + id TEXT NOT NULL, + content TEXT, + meta TEXT, + embedding BLOB, + PRIMARY KEY(dataset_id, id) +); + +CREATE INDEX IF NOT EXISTS idx_views_owner_element_id + ON views(owner_element_id); + +CREATE INDEX IF NOT EXISTS idx_placements_element_id_view_id + ON placements(element_id, view_id); + +CREATE INDEX IF NOT EXISTS idx_placements_view_id_id + ON placements(view_id, id); + +CREATE INDEX IF NOT EXISTS idx_connectors_view_id_id + ON connectors(view_id, id); + +CREATE INDEX IF NOT EXISTS idx_elements_updated_at_id + ON elements(updated_at DESC, id DESC); + +CREATE TABLE IF NOT EXISTS watch_version_resources ( + version_id INTEGER NOT NULL, + owner_type TEXT NOT NULL, + owner_key TEXT NOT NULL, + resource_type TEXT NOT NULL, + resource_id INTEGER NULL, + language TEXT NULL, + resource_hash TEXT NOT NULL, + summary TEXT NULL, + line_count INTEGER NOT NULL DEFAULT 0, + file_path TEXT NULL, + start_line INTEGER NOT NULL DEFAULT 0, + end_line INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY(version_id, owner_type, owner_key, resource_type), + FOREIGN KEY (version_id) REFERENCES watch_versions(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_watch_version_resources_version_id + ON watch_version_resources(version_id); diff --git a/migrations/003_view_density_visibility_overrides.sql b/migrations/003_view_density_visibility_overrides.sql new file mode 100644 index 0000000..4e95078 --- /dev/null +++ b/migrations/003_view_density_visibility_overrides.sql @@ -0,0 +1,17 @@ +PRAGMA foreign_keys = ON; + +ALTER TABLE views ADD COLUMN density_level INTEGER NOT NULL DEFAULT 0; + +CREATE TABLE IF NOT EXISTS view_visibility_overrides ( + view_id INTEGER NOT NULL, + resource_type TEXT NOT NULL CHECK(resource_type IN ('element', 'connector')), + resource_id INTEGER NOT NULL, + level_delta INTEGER NOT NULL DEFAULT 0 CHECK(level_delta BETWEEN -4 AND 4), + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + PRIMARY KEY(view_id, resource_type, resource_id), + FOREIGN KEY(view_id) REFERENCES views(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_view_visibility_overrides_resource + ON view_visibility_overrides(resource_type, resource_id); diff --git a/pkg/api/dependency_service.go b/pkg/api/dependency_service.go index 3a8949a..ee3bb65 100644 --- a/pkg/api/dependency_service.go +++ b/pkg/api/dependency_service.go @@ -33,7 +33,7 @@ func (s *DependencyService) ListDependencies(ctx context.Context, req *connect.R return nil, err } - elements, err := s.Store.ListElements(ctx, workspaceID, 0, 0, "") + elements, _, err := s.Store.ListElements(ctx, workspaceID, 0, 0, "") if err != nil { return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("list elements: %w", err)) } diff --git a/pkg/api/org_service.go b/pkg/api/org_service.go new file mode 100644 index 0000000..928463f --- /dev/null +++ b/pkg/api/org_service.go @@ -0,0 +1,58 @@ +package api + +import ( + "context" + "fmt" + "strings" + + 
"buf.build/gen/go/tldiagramcom/diagram/connectrpc/go/diag/v1/diagv1connect" + diagv1 "buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go/diag/v1" + "connectrpc.com/connect" +) + +type OrgService struct { + diagv1connect.UnimplementedOrgServiceHandler + Store Store + Hooks WorkspaceHooks +} + +func (s *OrgService) hooks() WorkspaceHooks { + if s.Hooks == nil { + return NopWorkspaceHooks{} + } + return s.Hooks +} + +func (s *OrgService) ListTagColors(ctx context.Context, _ *connect.Request[diagv1.ListTagColorsRequest]) (*connect.Response[diagv1.ListTagColorsResponse], error) { + workspaceID := WorkspaceIDFromCtx(ctx) + if err := s.hooks().CheckRead(ctx, workspaceID); err != nil { + return nil, err + } + tags, err := s.Store.Tags(ctx, workspaceID) + if err != nil { + return nil, err + } + return connect.NewResponse(&diagv1.ListTagColorsResponse{Tags: tags}), nil +} + +func (s *OrgService) UpdateTag(ctx context.Context, req *connect.Request[diagv1.UpdateTagRequest]) (*connect.Response[diagv1.UpdateTagResponse], error) { + m := req.Msg + workspaceID := WorkspaceIDFromCtx(ctx) + if err := s.hooks().CheckWrite(ctx, workspaceID, "tags"); err != nil { + return nil, err + } + name := strings.TrimSpace(m.GetTag()) + color := strings.TrimSpace(m.GetColor()) + if name == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("tag is required")) + } + if color == "" { + return nil, connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("color is required")) + } + if err := s.Store.UpdateTag(ctx, workspaceID, name, color, m.Description); err != nil { + return nil, err + } + resp := &diagv1.UpdateTagResponse{} + s.hooks().AfterWrite(ctx, workspaceID, "update", "tag", name, map[string]any{"color": color}, resp) + return connect.NewResponse(resp), nil +} diff --git a/pkg/api/store.go b/pkg/api/store.go index 4cdae2d..7eea819 100644 --- a/pkg/api/store.go +++ b/pkg/api/store.go @@ -72,14 +72,14 @@ type ConnectorInput struct { type Store interface { // Views ListViews(ctx context.Context, workspaceID uuid.UUID) ([]*diagv1.View, error) - GetViews(ctx context.Context, workspaceID uuid.UUID, ownerElementID *int32, isRoot *bool, search string, limit, offset int) ([]*diagv1.View, int, error) + GetViews(ctx context.Context, workspaceID uuid.UUID, parentViewID *int32, isRoot *bool, search string, limit, offset int) ([]*diagv1.View, int, error) GetView(ctx context.Context, id int32, workspaceID uuid.UUID) (*diagv1.View, error) CreateView(ctx context.Context, workspaceID uuid.UUID, ownerElementID *int32, name string, label *string, isRoot bool) (*diagv1.View, error) UpdateView(ctx context.Context, id int32, workspaceID uuid.UUID, name string, label *string) (*diagv1.View, error) DeleteView(ctx context.Context, id int32, workspaceID uuid.UUID) error // Elements - ListElements(ctx context.Context, workspaceID uuid.UUID, limit, offset int32, search string) ([]*diagv1.Element, error) + ListElements(ctx context.Context, workspaceID uuid.UUID, limit, offset int32, search string) ([]*diagv1.Element, int, error) GetElement(ctx context.Context, id int32, workspaceID uuid.UUID) (*diagv1.Element, error) CreateElement(ctx context.Context, workspaceID uuid.UUID, input ElementInput) (*diagv1.Element, error) UpdateElement(ctx context.Context, id int32, workspaceID uuid.UUID, input ElementInput) (*diagv1.Element, error) @@ -113,6 +113,10 @@ type Store interface { UpdateViewLayer(ctx context.Context, id int32, name *string, tags []string, color *string) (*diagv1.ViewLayer, error) DeleteViewLayer(ctx 
context.Context, id int32) error + // Tags + Tags(ctx context.Context, workspaceID uuid.UUID) (map[string]*diagv1.Tag, error) + UpdateTag(ctx context.Context, workspaceID uuid.UUID, name, color string, description *string) error + // ApplyPlan atomically applies a CLI workspace plan (create/update elements, views, connectors). ApplyPlan(ctx context.Context, workspaceID uuid.UUID, req *diagv1.ApplyPlanRequest) (*diagv1.ApplyPlanResponse, error) diff --git a/pkg/api/workspace_service.go b/pkg/api/workspace_service.go index b93963f..b52d738 100644 --- a/pkg/api/workspace_service.go +++ b/pkg/api/workspace_service.go @@ -33,6 +33,17 @@ func (s *WorkspaceService) hooks() WorkspaceHooks { return s.Hooks } +func intToInt32(n int) int32 { + switch { + case n > math.MaxInt32: + return math.MaxInt32 + case n < math.MinInt32: + return math.MinInt32 + default: + return int32(n) //nolint:gosec // clamped above + } +} + // ─── CLI RPCs ───────────────────────────────────────────────────────────────── func (s *WorkspaceService) CreateView( @@ -268,7 +279,11 @@ func (s *WorkspaceService) ExportWorkspace( g, gctx := errgroup.WithContext(ctx) g.Go(func() error { var e error; views, e = s.Store.ListViews(gctx, workspaceID); return e }) - g.Go(func() error { var e error; elements, e = s.Store.ListElements(gctx, workspaceID, 0, 0, ""); return e }) + g.Go(func() error { + var e error + elements, _, e = s.Store.ListElements(gctx, workspaceID, 0, 0, "") + return e + }) g.Go(func() error { var e error; placements, e = s.Store.ListAllPlacements(gctx, workspaceID); return e }) g.Go(func() error { var e error; connectors, e = s.Store.ListAllConnectors(gctx, workspaceID); return e }) g.Go(func() error { var e error; layers, e = s.Store.ListAllViewLayers(gctx, workspaceID); return e }) @@ -329,31 +344,22 @@ func (s *WorkspaceService) GetWorkspace( } m := req.Msg - var ownerElementID *int32 + var parentViewID *int32 if m.GetParentId() != 0 { pid := m.GetParentId() - ownerElementID = &pid + parentViewID = &pid } var isRoot *bool if m.Level != nil { ir := m.GetLevel() == 0 isRoot = &ir } - - views, totalCount, err := s.Store.GetViews(ctx, workspaceID, ownerElementID, isRoot, m.GetSearch(), int(m.GetLimit()), int(m.GetOffset())) + views, totalCount, err := s.Store.GetViews(ctx, workspaceID, parentViewID, isRoot, m.GetSearch(), int(m.GetLimit()), int(m.GetOffset())) if err != nil { return nil, storeErr("get views", err) } - var tc int32 - switch { - case totalCount > math.MaxInt32: - tc = math.MaxInt32 - case totalCount < math.MinInt32: - tc = math.MinInt32 - default: - tc = int32(totalCount) //nolint:gosec // clamped above - } + tc := intToInt32(totalCount) viewMap := make(map[int32]*diagv1.View) for _, v := range views { @@ -397,7 +403,14 @@ func (s *WorkspaceService) GetWorkspace( } resp.Content = make(map[int32]*diagv1.ViewContent) + viewIDs := make(map[int32]struct{}, len(views)) + for _, v := range views { + viewIDs[v.Id] = struct{}{} + } for _, p := range allPlacements { + if _, ok := viewIDs[p.ViewId]; !ok { + continue + } if _, ok := resp.Content[p.ViewId]; !ok { resp.Content[p.ViewId] = &diagv1.ViewContent{} } @@ -411,6 +424,9 @@ func (s *WorkspaceService) GetWorkspace( }) } for _, c := range allConnectors { + if _, ok := viewIDs[c.ViewId]; !ok { + continue + } if _, ok := resp.Content[c.ViewId]; !ok { resp.Content[c.ViewId] = &diagv1.ViewContent{} } @@ -425,6 +441,9 @@ func (s *WorkspaceService) GetWorkspace( } } for _, p := range allPlacements { + if _, ok := viewIDs[p.ViewId]; !ok { + continue + } childView, ok 
:= elementToChildView[p.ElementId] if !ok { continue @@ -536,14 +555,19 @@ func (s *WorkspaceService) ListElements( if err := s.hooks().CheckRead(ctx, workspaceID); err != nil { return nil, err } - elements, err := s.Store.ListElements(ctx, workspaceID, req.Msg.Limit, req.Msg.Offset, req.Msg.Search) + elements, totalCount, err := s.Store.ListElements(ctx, workspaceID, req.Msg.Limit, req.Msg.Offset, req.Msg.Search) if err != nil { return nil, storeErr("list elements", err) } if elements == nil { elements = []*diagv1.Element{} } - return connect.NewResponse(&diagv1.ListElementsResponse{Elements: elements}), nil + return connect.NewResponse(&diagv1.ListElementsResponse{ + Elements: elements, + Pagination: &diagv1.Pagination{ + TotalCount: intToInt32(totalCount), + }, + }), nil } func (s *WorkspaceService) GetElement( diff --git a/pkg/api/workspace_service_contract_test.go b/pkg/api/workspace_service_contract_test.go new file mode 100644 index 0000000..39c2ad6 --- /dev/null +++ b/pkg/api/workspace_service_contract_test.go @@ -0,0 +1,386 @@ +package api + +import ( + "context" + "errors" + "strings" + "testing" + + diagv1 "buf.build/gen/go/tldiagramcom/diagram/protocolbuffers/go/diag/v1" + "connectrpc.com/connect" + "github.com/google/uuid" +) + +func TestWorkspaceService_ListElementsReturnsPaginationAndChecksRead(t *testing.T) { + workspaceID := uuid.MustParse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + store := &contractStore{ + listElements: func(ctx context.Context, id uuid.UUID, limit, offset int32, search string) ([]*diagv1.Element, int, error) { + if id != workspaceID { + t.Fatalf("workspace id = %s, want %s", id, workspaceID) + } + if limit != 2 || offset != 4 || search != "api" { + t.Fatalf("query = limit:%d offset:%d search:%q, want 2/4/api", limit, offset, search) + } + return []*diagv1.Element{{Id: 10, Name: "API"}}, 7, nil + }, + } + hooks := &recordingHooks{} + service := &WorkspaceService{Store: store, Hooks: hooks} + + resp, err := service.ListElements(WithWorkspaceID(context.Background(), workspaceID), connect.NewRequest(&diagv1.ListElementsRequest{ + Limit: 2, + Offset: 4, + Search: "api", + })) + if err != nil { + t.Fatal(err) + } + if got := resp.Msg.GetPagination().GetTotalCount(); got != 7 { + t.Fatalf("total count = %d, want 7", got) + } + if len(resp.Msg.GetElements()) != 1 || resp.Msg.GetElements()[0].GetId() != 10 { + t.Fatalf("elements = %+v, want API element", resp.Msg.GetElements()) + } + if strings.Join(hooks.events, ",") != "read" { + t.Fatalf("hook events = %v, want read", hooks.events) + } +} + +func TestWorkspaceService_CreateConnectorDefaultsValidatesAndAudits(t *testing.T) { + store := &contractStore{ + createConnector: func(_ context.Context, _ uuid.UUID, input ConnectorInput) (*diagv1.Connector, error) { + if input.ViewID != 3 || input.SourceID != 4 || input.TargetID != 5 { + t.Fatalf("connector ids = %+v, want view/source/target 3/4/5", input) + } + if input.Direction != "forward" || input.Style != "bezier" { + t.Fatalf("connector defaults = direction:%q style:%q, want forward/bezier", input.Direction, input.Style) + } + if input.Label == nil || *input.Label != "uses" { + t.Fatalf("label = %v, want uses", input.Label) + } + return &diagv1.Connector{Id: 99, ViewId: 3, SourceElementId: 4, TargetElementId: 5, Direction: input.Direction, Style: input.Style, Label: input.Label}, nil + }, + } + hooks := &recordingHooks{} + service := &WorkspaceService{Store: store, Hooks: hooks} + + resp, err := service.CreateConnector(context.Background(), 
connect.NewRequest(&diagv1.CreateConnectorRequest{ + ViewId: 3, + SourceElementId: 4, + TargetElementId: 5, + Label: new("uses"), + })) + if err != nil { + t.Fatal(err) + } + if resp.Msg.GetConnector().GetId() != 99 { + t.Fatalf("connector id = %d, want 99", resp.Msg.GetConnector().GetId()) + } + if got := strings.Join(hooks.events, ","); got != "write:connectors,after:create:connector:99" { + t.Fatalf("hook events = %s", got) + } +} + +func TestWorkspaceService_CreateConnectorRejectsInvalidStyleBeforeStoreWrite(t *testing.T) { + store := &contractStore{ + createConnector: func(context.Context, uuid.UUID, ConnectorInput) (*diagv1.Connector, error) { + t.Fatal("store should not be called for invalid style") + return nil, nil + }, + } + service := &WorkspaceService{Store: store, Hooks: &recordingHooks{}} + + _, err := service.CreateConnector(context.Background(), connect.NewRequest(&diagv1.CreateConnectorRequest{ + ViewId: 3, + SourceElementId: 4, + TargetElementId: 5, + Style: "zigzag", + })) + if code := connect.CodeOf(err); code != connect.CodeInvalidArgument { + t.Fatalf("code = %s, want invalid_argument: %v", code, err) + } +} + +func TestWorkspaceService_UpdateElementClearsLogoWhenNoPrimaryIcon(t *testing.T) { + var update ElementInput + store := &contractStore{ + getElement: func(context.Context, int32, uuid.UUID) (*diagv1.Element, error) { + return &diagv1.Element{ + Id: 42, + Name: "API", + LogoUrl: new("https://example.com/logo.svg"), + TechnologyLinks: []*diagv1.TechnologyLink{{ + Type: "catalog", + Label: "Go", + Slug: new("go"), + IsPrimaryIcon: true, + }}, + }, nil + }, + updateElement: func(_ context.Context, id int32, _ uuid.UUID, input ElementInput) (*diagv1.Element, error) { + if id != 42 { + t.Fatalf("id = %d, want 42", id) + } + update = input + return &diagv1.Element{Id: id, Name: input.Name, LogoUrl: input.LogoURL, TechnologyLinks: input.TechLinks}, nil + }, + } + service := &WorkspaceService{Store: store, Hooks: &recordingHooks{}} + + resp, err := service.UpdateElement(context.Background(), connect.NewRequest(&diagv1.UpdateElementRequest{ + ElementId: 42, + Name: "API", + TechnologyLinks: []*diagv1.TechnologyLink{{ + Type: "catalog", + Label: "Kafka", + Slug: new("kafka"), + }}, + LogoUrl: new("https://example.com/kafka.svg"), + })) + if err != nil { + t.Fatal(err) + } + if update.LogoURL == nil || *update.LogoURL != "" { + t.Fatalf("update logo url = %v, want explicit empty string", update.LogoURL) + } + if got := resp.Msg.GetElement().GetLogoUrl(); got != "" { + t.Fatalf("response logo url = %q, want cleared", got) + } +} + +func TestWorkspaceService_UpdateElementPreservesExistingTechnologyLinksWhenOmitted(t *testing.T) { + existingLinks := []*diagv1.TechnologyLink{{ + Type: "catalog", + Label: "Go", + Slug: new("go"), + IsPrimaryIcon: true, + }} + var update ElementInput + store := &contractStore{ + getElement: func(context.Context, int32, uuid.UUID) (*diagv1.Element, error) { + return &diagv1.Element{Id: 42, Name: "API", LogoUrl: new("go.svg"), TechnologyLinks: existingLinks}, nil + }, + updateElement: func(_ context.Context, id int32, _ uuid.UUID, input ElementInput) (*diagv1.Element, error) { + update = input + return &diagv1.Element{Id: id, Name: input.Name, LogoUrl: input.LogoURL, TechnologyLinks: input.TechLinks}, nil + }, + } + service := &WorkspaceService{Store: store, Hooks: &recordingHooks{}} + + _, err := service.UpdateElement(context.Background(), connect.NewRequest(&diagv1.UpdateElementRequest{ + ElementId: 42, + Name: "API", + })) + if err != nil { + 
t.Fatal(err) + } + if update.TechLinks != nil { + t.Fatalf("tech link patch = %+v, want nil so store preserves existing links", update.TechLinks) + } + if update.LogoURL != nil { + t.Fatalf("logo patch = %v, want nil so store preserves existing logo", update.LogoURL) + } +} + +func TestWorkspaceService_CreateViewLayerValidatesViewAndName(t *testing.T) { + tests := []struct { + name string + req *diagv1.CreateViewLayerRequest + store *contractStore + wantErr connect.Code + }{ + { + name: "missing view id", + req: &diagv1.CreateViewLayerRequest{Name: "Runtime"}, + store: &contractStore{ + getView: func(context.Context, int32, uuid.UUID) (*diagv1.View, error) { + t.Fatal("store should not be called without a view id") + return nil, nil + }, + }, + wantErr: connect.CodeInvalidArgument, + }, + { + name: "unknown view", + req: &diagv1.CreateViewLayerRequest{ViewId: 7, Name: "Runtime"}, + store: &contractStore{ + getView: func(context.Context, int32, uuid.UUID) (*diagv1.View, error) { + return nil, errors.New("missing") + }, + }, + wantErr: connect.CodeNotFound, + }, + { + name: "blank name", + req: &diagv1.CreateViewLayerRequest{ViewId: 7, Name: " "}, + store: &contractStore{ + getView: func(context.Context, int32, uuid.UUID) (*diagv1.View, error) { + return &diagv1.View{Id: 7}, nil + }, + }, + wantErr: connect.CodeInvalidArgument, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + service := &WorkspaceService{Store: tt.store, Hooks: &recordingHooks{}} + _, err := service.CreateViewLayer(context.Background(), connect.NewRequest(tt.req)) + if code := connect.CodeOf(err); code != tt.wantErr { + t.Fatalf("code = %s, want %s: %v", code, tt.wantErr, err) + } + }) + } +} + +type recordingHooks struct { + NopWorkspaceHooks + events []string +} + +func (h *recordingHooks) CheckRead(context.Context, uuid.UUID) error { + h.events = append(h.events, "read") + return nil +} + +func (h *recordingHooks) CheckWrite(_ context.Context, _ uuid.UUID, resourceType string) error { + h.events = append(h.events, "write:"+resourceType) + return nil +} + +func (h *recordingHooks) AfterWrite(_ context.Context, _ uuid.UUID, action string, resourceType string, resourceID string, _ map[string]any, _ any) { + h.events = append(h.events, "after:"+action+":"+resourceType+":"+resourceID) +} + +type contractStore struct { + listElements func(context.Context, uuid.UUID, int32, int32, string) ([]*diagv1.Element, int, error) + getElement func(context.Context, int32, uuid.UUID) (*diagv1.Element, error) + updateElement func(context.Context, int32, uuid.UUID, ElementInput) (*diagv1.Element, error) + getView func(context.Context, int32, uuid.UUID) (*diagv1.View, error) + createConnector func(context.Context, uuid.UUID, ConnectorInput) (*diagv1.Connector, error) +} + +var _ Store = (*contractStore)(nil) + +func (s *contractStore) ListViews(context.Context, uuid.UUID) ([]*diagv1.View, error) { + return nil, nil +} +func (s *contractStore) GetViews(context.Context, uuid.UUID, *int32, *bool, string, int, int) ([]*diagv1.View, int, error) { + return nil, 0, nil +} +func (s *contractStore) GetView(ctx context.Context, id int32, workspaceID uuid.UUID) (*diagv1.View, error) { + if s.getView != nil { + return s.getView(ctx, id, workspaceID) + } + return &diagv1.View{Id: id}, nil +} +func (s *contractStore) CreateView(context.Context, uuid.UUID, *int32, string, *string, bool) (*diagv1.View, error) { + return nil, nil +} +func (s *contractStore) UpdateView(context.Context, int32, uuid.UUID, string, *string) 
(*diagv1.View, error) { + return nil, nil +} +func (s *contractStore) DeleteView(context.Context, int32, uuid.UUID) error { return nil } +func (s *contractStore) ListElements(ctx context.Context, workspaceID uuid.UUID, limit, offset int32, search string) ([]*diagv1.Element, int, error) { + if s.listElements != nil { + return s.listElements(ctx, workspaceID, limit, offset, search) + } + return nil, 0, nil +} +func (s *contractStore) GetElement(ctx context.Context, id int32, workspaceID uuid.UUID) (*diagv1.Element, error) { + if s.getElement != nil { + return s.getElement(ctx, id, workspaceID) + } + return nil, errors.New("element not found") +} +func (s *contractStore) CreateElement(context.Context, uuid.UUID, ElementInput) (*diagv1.Element, error) { + return nil, nil +} +func (s *contractStore) UpdateElement(ctx context.Context, id int32, workspaceID uuid.UUID, input ElementInput) (*diagv1.Element, error) { + if s.updateElement != nil { + return s.updateElement(ctx, id, workspaceID, input) + } + return nil, nil +} +func (s *contractStore) DeleteElement(context.Context, int32, uuid.UUID) error { return nil } +func (s *contractStore) ListPlacements(context.Context, int32) ([]*diagv1.PlacedElement, error) { + return nil, nil +} +func (s *contractStore) ListAllPlacements(context.Context, uuid.UUID) ([]*diagv1.PlacedElement, error) { + return nil, nil +} +func (s *contractStore) ListElementPlacements(context.Context, int32, uuid.UUID) ([]*diagv1.ViewPlacementInfo, error) { + return nil, nil +} +func (s *contractStore) AddPlacement(context.Context, int32, int32, float64, float64) (*diagv1.PlacedElement, error) { + return nil, nil +} +func (s *contractStore) UpdatePlacementPosition(context.Context, int32, int32, float64, float64) error { + return nil +} +func (s *contractStore) RemovePlacement(context.Context, int32, int32) error { return nil } +func (s *contractStore) ListConnectors(context.Context, int32, uuid.UUID) ([]*diagv1.Connector, error) { + return nil, nil +} +func (s *contractStore) ListAllConnectors(context.Context, uuid.UUID) ([]*diagv1.Connector, error) { + return nil, nil +} +func (s *contractStore) GetConnector(context.Context, int32, uuid.UUID) (*diagv1.Connector, error) { + return nil, nil +} +func (s *contractStore) CreateConnector(ctx context.Context, workspaceID uuid.UUID, input ConnectorInput) (*diagv1.Connector, error) { + if s.createConnector != nil { + return s.createConnector(ctx, workspaceID, input) + } + return nil, nil +} +func (s *contractStore) UpdateConnector(context.Context, int32, uuid.UUID, ConnectorInput) (*diagv1.Connector, error) { + return nil, nil +} +func (s *contractStore) DeleteConnector(context.Context, int32, uuid.UUID) error { return nil } +func (s *contractStore) ListElementNavigations(context.Context, uuid.UUID, int32) ([]*diagv1.ElementNavigationInfo, error) { + return nil, nil +} +func (s *contractStore) ListIncomingElementNavigations(context.Context, int32) ([]*diagv1.IncomingElementNavigationInfo, error) { + return nil, nil +} +func (s *contractStore) ListViewLayers(context.Context, int32) ([]*diagv1.ViewLayer, error) { + return nil, nil +} +func (s *contractStore) ListAllViewLayers(context.Context, uuid.UUID) ([]*diagv1.ViewLayer, error) { + return nil, nil +} +func (s *contractStore) GetViewLayer(context.Context, int32) (*diagv1.ViewLayer, error) { + return nil, nil +} +func (s *contractStore) CreateViewLayer(context.Context, int32, string, []string, string) (*diagv1.ViewLayer, error) { + return nil, nil +} +func (s *contractStore) 
UpdateViewLayer(context.Context, int32, *string, []string, *string) (*diagv1.ViewLayer, error) { + return nil, nil +} +func (s *contractStore) DeleteViewLayer(context.Context, int32) error { return nil } +func (s *contractStore) Tags(context.Context, uuid.UUID) (map[string]*diagv1.Tag, error) { + return nil, nil +} +func (s *contractStore) UpdateTag(context.Context, uuid.UUID, string, string, *string) error { + return nil +} +func (s *contractStore) ApplyPlan(context.Context, uuid.UUID, *diagv1.ApplyPlanRequest) (*diagv1.ApplyPlanResponse, error) { + return nil, nil +} +func (s *contractStore) ListVersions(context.Context, uuid.UUID, int) ([]*diagv1.WorkspaceVersionInfo, error) { + return nil, nil +} +func (s *contractStore) GetLatestVersion(context.Context, uuid.UUID) (*diagv1.WorkspaceVersionInfo, error) { + return nil, nil +} +func (s *contractStore) CreateVersion(context.Context, uuid.UUID, string, string, *int32, int, int, int, *string, *string) (*diagv1.WorkspaceVersionInfo, error) { + return nil, nil +} +func (s *contractStore) GetVersioningEnabled(context.Context, uuid.UUID) (bool, error) { + return false, nil +} +func (s *contractStore) SetVersioningEnabled(context.Context, uuid.UUID, bool) error { return nil } +func (s *contractStore) GetWorkspaceResourceCounts(context.Context, uuid.UUID) (int, int, int, error) { + return 0, 0, 0, nil +} diff --git a/scripts/benchmark_embeddings.go b/scripts/benchmark_embeddings.go new file mode 100644 index 0000000..8b70c17 --- /dev/null +++ b/scripts/benchmark_embeddings.go @@ -0,0 +1,143 @@ +//go:build ignore + +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" +) + +type embeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +type embeddingResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + } `json:"data"` +} + +func main() { + endpoint := flag.String("endpoint", envDefault("TLD_EMBEDDING_ENDPOINT", "http://127.0.0.1:8000/v1/embeddings"), "OpenAI-compatible embeddings endpoint") + model := flag.String("model", envDefault("TLD_EMBEDDING_MODEL", "embeddinggemma-300m-4bit"), "embedding model") + repeats := flag.Int("repeats", intEnvDefault("TLD_EMBEDDING_REPEATS", 3), "measured requests per batch size") + warmup := flag.Int("warmup", intEnvDefault("TLD_EMBEDDING_WARMUP", 1), "warmup requests per batch size") + flag.Parse() + + client := &http.Client{Timeout: 5 * time.Minute} + ctx := context.Background() + fmt.Printf("endpoint=%s\nmodel=%s\nrepeats=%d warmup=%d\n\n", *endpoint, *model, *repeats, *warmup) + fmt.Printf("%6s %10s %10s %10s %10s %10s\n", "batch", "avg_ms", "min_ms", "max_ms", "items/s", "dim") + + for batch := 1; batch <= 512; batch *= 2 { + for i := 0; i < *warmup; i++ { + if _, _, err := runOnce(ctx, client, *endpoint, *model, batch); err != nil { + fail(batch, err) + } + } + + var total, min, max time.Duration + dim := 0 + for i := 0; i < *repeats; i++ { + elapsed, nextDim, err := runOnce(ctx, client, *endpoint, *model, batch) + if err != nil { + fail(batch, err) + } + if i == 0 || elapsed < min { + min = elapsed + } + if elapsed > max { + max = elapsed + } + total += elapsed + dim = nextDim + } + avg := total / time.Duration(*repeats) + itemsPerSecond := float64(batch) / avg.Seconds() + fmt.Printf("%6d %10.1f %10.1f %10.1f %10.1f %10d\n", batch, ms(avg), ms(min), ms(max), itemsPerSecond, dim) + } +} + +func runOnce(ctx context.Context, client *http.Client, endpoint, model string, batch 
int) (time.Duration, int, error) { + body, err := json.Marshal(embeddingRequest{Model: model, Input: inputs(batch)}) + if err != nil { + return 0, 0, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return 0, 0, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer tldcli") + + start := time.Now() + resp, err := client.Do(req) + elapsed := time.Since(start) + if err != nil { + return 0, 0, err + } + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + if err != nil { + return 0, 0, err + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return 0, 0, fmt.Errorf("%s: %s", resp.Status, strings.TrimSpace(string(data))) + } + var parsed embeddingResponse + if err := json.Unmarshal(data, &parsed); err != nil { + return 0, 0, err + } + if len(parsed.Data) != batch { + return 0, 0, fmt.Errorf("expected %d embeddings, got %d", batch, len(parsed.Data)) + } + dim := 0 + if len(parsed.Data) > 0 { + dim = len(parsed.Data[0].Embedding) + } + return elapsed, dim, nil +} + +func inputs(batch int) []string { + out := make([]string, batch) + base := "package main\n\nfunc FetchUserProfile(ctx context.Context, userID string) (*User, error) {\n\treturn repository.LoadUser(ctx, userID)\n}\n" + for i := range out { + out[i] = fmt.Sprintf("%s\n// benchmark sample %d: code symbol embedding context with repository, service, handler, and tests.code symbol embedding context with repository, service, handler, and tests.code symbol embedding context with repository, service, handler, and tests.code symbol embedding context with repository, service, handler, and tests.", base, i+1) + } + return out +} + +func ms(d time.Duration) float64 { + return float64(d.Microseconds()) / 1000 +} + +func fail(batch int, err error) { + fmt.Fprintf(os.Stderr, "batch %d failed: %v\n", batch, err) + os.Exit(1) +} + +func envDefault(name, fallback string) string { + if value := os.Getenv(name); value != "" { + return value + } + return fallback +} + +func intEnvDefault(name string, fallback int) int { + if value := os.Getenv(name); value != "" { + var parsed int + if _, err := fmt.Sscanf(value, "%d", &parsed); err == nil && parsed > 0 { + return parsed + } + } + return fallback +} diff --git a/scripts/dev/add_technology_icon.go b/scripts/dev/add_technology_icon.go new file mode 100644 index 0000000..1f3cc26 --- /dev/null +++ b/scripts/dev/add_technology_icon.go @@ -0,0 +1,328 @@ +package main + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strings" + "time" +) + +type catalogItem struct { + IconURL string `json:"iconUrl,omitempty"` + Name string `json:"name"` + Provider string `json:"provider,omitempty"` + DocsURL string `json:"docsUrl,omitempty"` + Description string `json:"description,omitempty"` + WebsiteURL string `json:"websiteUrl,omitempty"` + NameShort string `json:"nameShort"` + DefaultSlug string `json:"defaultSlug"` +} + +type validationCatalogItem struct { + Name string `json:"name"` + NameShort string `json:"nameShort"` + DefaultSlug string `json:"defaultSlug"` +} + +type archiveEntry struct { + name string + body []byte + mode int64 + typeflag byte +} + +var slugPartRE = regexp.MustCompile(`[^a-z0-9]+`) + +func main() { + if err := run(); err != nil { + fmt.Fprintf(os.Stderr, "add technology icon: %v\n", err) + os.Exit(1) + } +} + +func run() error { + var ( + name = flag.String("name", 
"", "technology name") + nameShort = flag.String("short", "", "short display name; defaults to -name") + slug = flag.String("slug", "", "icon/catalog slug; defaults to a slugified -name") + iconPath = flag.String("icon", "", "path to a PNG icon") + provider = flag.String("provider", "", "optional provider, for example aws, azure, gcp") + docsURL = flag.String("docs-url", "", "optional documentation URL") + websiteURL = flag.String("website-url", "", "optional website URL") + description = flag.String("description", "", "optional catalog description") + archivePath = flag.String("archive", filepath.Join("build-assets", "icons.tar.gz"), "path to icons tar.gz") + techCatalog = flag.String("tech-catalog", filepath.Join("internal", "tech", "icons.json"), "path to backend validation catalog") + replace = flag.Bool("replace", false, "replace an existing catalog/icon entry with the same slug") + ) + flag.Parse() + + itemName := strings.TrimSpace(*name) + if itemName == "" { + return errors.New("-name is required") + } + if strings.TrimSpace(*iconPath) == "" { + return errors.New("-icon is required") + } + + itemSlug := strings.TrimSpace(*slug) + if itemSlug == "" { + itemSlug = slugify(itemName) + } else { + itemSlug = slugify(itemSlug) + } + if itemSlug == "" { + return fmt.Errorf("could not derive a slug from %q", itemName) + } + + iconBody, err := os.ReadFile(*iconPath) + if err != nil { + return err + } + if !isPNG(iconBody) { + return fmt.Errorf("%s is not a PNG; catalog icons must be PNG files", *iconPath) + } + + entries, err := readArchive(*archivePath) + if err != nil { + return err + } + + catalog, err := readCatalog(entries) + if err != nil { + return err + } + + short := strings.TrimSpace(*nameShort) + if short == "" { + short = itemName + } + item := catalogItem{ + IconURL: "/icons/" + itemSlug + ".png", + Name: itemName, + Provider: strings.TrimSpace(*provider), + DocsURL: strings.TrimSpace(*docsURL), + Description: strings.TrimSpace(*description), + WebsiteURL: strings.TrimSpace(*websiteURL), + NameShort: short, + DefaultSlug: itemSlug, + } + + updatedCatalog, err := upsertCatalogItem(catalog, item, *replace) + if err != nil { + return err + } + + catalogJSON, err := marshalJSON(updatedCatalog) + if err != nil { + return err + } + + iconName := "icons/" + itemSlug + ".png" + updatedEntries := upsertArchiveEntry(entries, archiveEntry{name: iconName, body: iconBody, mode: 0o644, typeflag: tar.TypeReg}, *replace) + updatedEntries = upsertArchiveEntry(updatedEntries, archiveEntry{name: "icons.json", body: catalogJSON, mode: 0o644, typeflag: tar.TypeReg}, true) + + if err := writeArchive(*archivePath, updatedEntries); err != nil { + return err + } + if err := writeValidationCatalog(*techCatalog, updatedCatalog); err != nil { + return err + } + + fmt.Printf("Added %s (%s) to %s\n", item.Name, item.DefaultSlug, *archivePath) + return nil +} + +func slugify(value string) string { + slug := strings.ToLower(strings.TrimSpace(value)) + slug = slugPartRE.ReplaceAllString(slug, "-") + return strings.Trim(slug, "-") +} + +func isPNG(body []byte) bool { + return len(body) >= 8 && + body[0] == 0x89 && + body[1] == 'P' && + body[2] == 'N' && + body[3] == 'G' && + body[4] == '\r' && + body[5] == '\n' && + body[6] == 0x1a && + body[7] == '\n' +} + +func readArchive(path string) ([]archiveEntry, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer func() { _ = f.Close() }() + + gzr, err := gzip.NewReader(f) + if err != nil { + return nil, err + } + defer func() { _ = 
gzr.Close() }() + + var entries []archiveEntry + tr := tar.NewReader(gzr) + for { + hdr, err := tr.Next() + if err == io.EOF { + return entries, nil + } + if err != nil { + return nil, err + } + + entry := archiveEntry{name: filepath.ToSlash(filepath.Clean(hdr.Name)), mode: hdr.Mode, typeflag: hdr.Typeflag} + if hdr.Typeflag == tar.TypeReg { + entry.body, err = io.ReadAll(tr) + if err != nil { + return nil, err + } + } + entries = append(entries, entry) + } +} + +func readCatalog(entries []archiveEntry) ([]catalogItem, error) { + for _, entry := range entries { + if entry.name != "icons.json" { + continue + } + var items []catalogItem + if err := json.Unmarshal(entry.body, &items); err != nil { + return nil, err + } + return items, nil + } + return nil, errors.New("icons.json not found in archive") +} + +func upsertCatalogItem(items []catalogItem, item catalogItem, replace bool) ([]catalogItem, error) { + out := make([]catalogItem, 0, len(items)+1) + replaced := false + for _, existing := range items { + if existing.DefaultSlug != item.DefaultSlug { + out = append(out, existing) + continue + } + if !replace { + return nil, fmt.Errorf("catalog entry %q already exists; rerun with -replace to overwrite it", item.DefaultSlug) + } + out = append(out, item) + replaced = true + } + if !replaced { + out = append(out, item) + } + return out, nil +} + +func upsertArchiveEntry(entries []archiveEntry, item archiveEntry, replace bool) []archiveEntry { + out := make([]archiveEntry, 0, len(entries)+1) + replaced := false + for _, entry := range entries { + if entry.name != item.name { + out = append(out, entry) + continue + } + if replace { + out = append(out, item) + } else { + out = append(out, entry) + } + replaced = true + } + if !replaced { + out = append(out, item) + } + return out +} + +func writeArchive(path string, entries []archiveEntry) error { + var buf bytes.Buffer + gzw, err := gzip.NewWriterLevel(&buf, gzip.BestCompression) + if err != nil { + return err + } + gzw.Name = filepath.Base(path) + gzw.ModTime = time.Unix(0, 0) + + tw := tar.NewWriter(gzw) + for _, entry := range entries { + mode := entry.mode + if mode == 0 { + mode = 0o644 + } + typeflag := entry.typeflag + if typeflag == 0 { + typeflag = tar.TypeReg + } + hdr := &tar.Header{ + Name: entry.name, + Mode: mode, + Typeflag: typeflag, + ModTime: time.Unix(0, 0), + } + if typeflag == tar.TypeReg { + hdr.Size = int64(len(entry.body)) + } + if err := tw.WriteHeader(hdr); err != nil { + return err + } + if typeflag == tar.TypeReg { + if _, err := tw.Write(entry.body); err != nil { + return err + } + } + } + if err := tw.Close(); err != nil { + return err + } + if err := gzw.Close(); err != nil { + return err + } + + tmp := path + ".tmp" + if err := os.WriteFile(tmp, buf.Bytes(), 0o644); err != nil { + return err + } + return os.Rename(tmp, path) +} + +func writeValidationCatalog(path string, items []catalogItem) error { + validationItems := make([]validationCatalogItem, 0, len(items)) + for _, item := range items { + validationItems = append(validationItems, validationCatalogItem{ + Name: item.Name, + NameShort: item.NameShort, + DefaultSlug: item.DefaultSlug, + }) + } + body, err := marshalJSON(validationItems) + if err != nil { + return err + } + return os.WriteFile(path, body, 0o644) +} + +func marshalJSON(v any) ([]byte, error) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + if err := enc.Encode(v); err != nil { + return nil, err + } + return buf.Bytes(), nil +}