diff --git a/internal/ledger/payload.go b/internal/ledger/payload.go new file mode 100644 index 0000000..9f09156 --- /dev/null +++ b/internal/ledger/payload.go @@ -0,0 +1,26 @@ +package ledger + +import "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite" + +const ( + SchemaVersion = "authorization-ledger-v1" + DefaultEndpoint = "/api/v1/authorization-ledger/batches" +) + +type Payload struct { + SchemaVersion string `json:"schema_version"` + OrganizationID string `json:"organization_id"` + InstallationID string `json:"installation_id"` + BatchID string `json:"batch_id"` + SentAt string `json:"sent_at"` + Device *Device `json:"device,omitempty"` + Sessions []sqlite.LedgerRecord `json:"agent_sessions"` + Actions []sqlite.LedgerRecord `json:"authorization_actions"` + Receipts []sqlite.LedgerRecord `json:"authorization_receipts"` + ReceiptChainAnchor *sqlite.LedgerReceiptChainAnchor `json:"receipt_chain_anchor,omitempty"` +} + +type Device struct { + Label string `json:"label,omitempty"` + DeploymentVersion string `json:"deployment_version,omitempty"` +} diff --git a/internal/ledger/payload_test.go b/internal/ledger/payload_test.go new file mode 100644 index 0000000..e284625 --- /dev/null +++ b/internal/ledger/payload_test.go @@ -0,0 +1,68 @@ +package ledger + +import ( + "encoding/json" + "testing" + + "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite" +) + +func TestPayloadJSONShape(t *testing.T) { + payload := Payload{ + SchemaVersion: SchemaVersion, + OrganizationID: "org_123", + InstallationID: "ins_123", + BatchID: "batch_abc", + SentAt: "2026-05-31T10:00:00Z", + Device: &Device{Label: "test-mac"}, + Sessions: []sqlite.LedgerRecord{}, + Actions: []sqlite.LedgerRecord{{"session_id": "claude-session"}}, + Receipts: []sqlite.LedgerRecord{}, + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + for _, key := range []string{ + "schema_version", + "organization_id", + "installation_id", + "batch_id", + "sent_at", + "agent_sessions", + "authorization_actions", + "authorization_receipts", + "device", + } { + if _, ok := got[key]; !ok { + t.Fatalf("missing JSON key %q in %s", key, string(data)) + } + } + if _, ok := got["receipt_chain_anchor"]; ok { + t.Fatalf("receipt_chain_anchor was present for nil anchor: %s", string(data)) + } + + device, ok := got["device"].(map[string]any) + if !ok { + t.Fatalf("device = %#v, want object: %s", got["device"], string(data)) + } + if device["label"] != "test-mac" { + t.Fatalf("device.label = %#v, want %q", device["label"], "test-mac") + } + if _, ok := device["deployment_version"]; ok { + t.Fatalf("device.deployment_version was present for empty value: %s", string(data)) + } + + for _, key := range []string{"agent_sessions", "authorization_actions", "authorization_receipts"} { + if _, ok := got[key].([]any); !ok { + t.Fatalf("%s = %#v, want array: %s", key, got[key], string(data)) + } + } +} diff --git a/internal/managedobserve/daemon_test.go b/internal/managedobserve/daemon_test.go index 25a9a6f..5e2ed9f 100644 --- a/internal/managedobserve/daemon_test.go +++ b/internal/managedobserve/daemon_test.go @@ -14,6 +14,7 @@ import ( "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite" "github.com/kontext-security/kontext-cli/internal/hook" + "github.com/kontext-security/kontext-cli/internal/ledger" "github.com/kontext-security/kontext-cli/internal/localruntime" ) @@ -91,26 +92,15 @@ func TestDaemonSessionEndClosesHookSessionID(t *testing.T) { } func TestDaemonStreamsLedgerBatches(t *testing.T) { - type ledgerBatchRequest struct { - OrganizationID string `json:"organization_id"` - InstallationID string `json:"installation_id"` - Device *struct { - Label string `json:"label"` - } `json:"device,omitempty"` - Actions []struct { - SessionID string `json:"session_id"` - } `json:"authorization_actions"` - } - - requests := make(chan ledgerBatchRequest, 1) + requests := make(chan ledger.Payload, 1) server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/api/v1/authorization-ledger/batches" { + if r.URL.Path != ledger.DefaultEndpoint { t.Fatalf("path = %q", r.URL.Path) } if got := r.Header.Get("Authorization"); got != "Bearer test-install-token" { t.Fatalf("Authorization = %q, want bearer install token", got) } - var body ledgerBatchRequest + var body ledger.Payload if err := json.NewDecoder(r.Body).Decode(&body); err != nil { t.Fatalf("Decode() error = %v", err) } @@ -180,7 +170,8 @@ func TestDaemonStreamsLedgerBatches(t *testing.T) { } found := false for _, action := range body.Actions { - if action.SessionID == "claude-stream-session" { + sessionID, _ := action["session_id"].(string) + if sessionID == "claude-stream-session" { found = true } } diff --git a/internal/managedobserve/lifecycle.go b/internal/managedobserve/lifecycle.go index d7c5bdc..ce92d5c 100644 --- a/internal/managedobserve/lifecycle.go +++ b/internal/managedobserve/lifecycle.go @@ -101,7 +101,9 @@ func (l Lifecycle) probe(ctx context.Context) bool { if err != nil { return false } - _ = conn.Close() + if err := conn.Close(); err != nil { + l.Diagnostic.Printf("managed observe probe close: %v\n", err) + } return true } diff --git a/internal/managedstream/stream.go b/internal/managedstream/stream.go index 940519b..cc4cc57 100644 --- a/internal/managedstream/stream.go +++ b/internal/managedstream/stream.go @@ -17,12 +17,10 @@ import ( "github.com/kontext-security/kontext-cli/internal/diagnostic" "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite" + "github.com/kontext-security/kontext-cli/internal/ledger" ) const ( - SchemaVersion = "authorization-ledger-v1" - DefaultEndpoint = "/api/v1/authorization-ledger/batches" - DefaultBatchLimit = 500 DefaultInterval = 10 * time.Second @@ -45,27 +43,9 @@ type Options struct { Diagnostic diagnostic.Logger } -type Payload struct { - SchemaVersion string `json:"schema_version"` - OrganizationID string `json:"organization_id"` - InstallationID string `json:"installation_id"` - BatchID string `json:"batch_id"` - SentAt string `json:"sent_at"` - Device *Device `json:"device,omitempty"` - Sessions []sqlite.LedgerRecord `json:"agent_sessions"` - Actions []sqlite.LedgerRecord `json:"authorization_actions"` - Receipts []sqlite.LedgerRecord `json:"authorization_receipts"` - ReceiptChainAnchor *sqlite.LedgerReceiptChainAnchor `json:"receipt_chain_anchor,omitempty"` -} - -type Device struct { - Label string `json:"label,omitempty"` - DeploymentVersion string `json:"deployment_version,omitempty"` -} - type State struct { - UpdatedAfter string `json:"updated_after,omitempty"` - ActionID string `json:"action_id,omitempty"` + UpdatedAfter *time.Time + ActionID string } func Run(ctx context.Context, opts Options) error { @@ -114,21 +94,12 @@ func Flush(ctx context.Context, opts Options) error { return err } - var updatedAfter *time.Time - if state.UpdatedAfter != "" { - parsed, err := time.Parse(time.RFC3339Nano, state.UpdatedAfter) - if err != nil { - return fmt.Errorf("parse managed stream state: %w", err) - } - updatedAfter = &parsed - } - limit := opts.BatchLimit if limit <= 0 { limit = DefaultBatchLimit } batch, err := store.LedgerBatch(ctx, sqlite.LedgerExportOptions{ - UpdatedAfter: updatedAfter, + UpdatedAfter: state.UpdatedAfter, UpdatedAfterID: state.ActionID, Limit: limit, }) @@ -139,8 +110,8 @@ func Flush(ctx context.Context, opts Options) error { return nil } - payload := Payload{ - SchemaVersion: SchemaVersion, + payload := ledger.Payload{ + SchemaVersion: ledger.SchemaVersion, OrganizationID: opts.OrganizationID, InstallationID: opts.InstallationID, BatchID: "batch_" + uuid.NewString(), @@ -158,22 +129,23 @@ func Flush(ctx context.Context, opts Options) error { deploymentVersion = strings.TrimSpace(opts.DeploymentVersion()) } if label != "" || deploymentVersion != "" { - payload.Device = &Device{Label: label, DeploymentVersion: deploymentVersion} + payload.Device = &ledger.Device{Label: label, DeploymentVersion: deploymentVersion} } if err := post(ctx, opts, payload); err != nil { return err } if batch.Cursor != nil { + updatedAfter := batch.Cursor.UpdatedAt.UTC() return SaveState(statePath, State{ - UpdatedAfter: batch.Cursor.UpdatedAt.UTC().Format(time.RFC3339Nano), + UpdatedAfter: &updatedAfter, ActionID: batch.Cursor.ActionID, }) } return nil } -func post(ctx context.Context, opts Options, payload Payload) error { +func post(ctx context.Context, opts Options, payload ledger.Payload) error { body, err := json.Marshal(payload) if err != nil { return err @@ -209,7 +181,7 @@ func endpointURL(cloudURL string) (string, error) { if err != nil { return "", err } - parsed.Path = DefaultEndpoint + parsed.Path = ledger.DefaultEndpoint parsed.RawQuery = "" parsed.Fragment = "" return parsed.String(), nil @@ -252,20 +224,53 @@ func LoadState(path string) (State, error) { } return State{}, err } - var state State + + type diskState struct { + UpdatedAfter string `json:"updated_after,omitempty"` + ActionID string `json:"action_id,omitempty"` + } + + var state diskState if err := json.Unmarshal(data, &state); err != nil { return State{}, err } - state.UpdatedAfter = strings.TrimSpace(state.UpdatedAfter) - state.ActionID = strings.TrimSpace(state.ActionID) - return state, nil + + updatedAfter := strings.TrimSpace(state.UpdatedAfter) + actionID := strings.TrimSpace(state.ActionID) + + var parsedUpdatedAfter *time.Time + if updatedAfter != "" { + parsed, err := time.Parse(time.RFC3339Nano, updatedAfter) + if err != nil { + return State{}, fmt.Errorf("parse managed stream state updated_after: %w", err) + } + parsedUpdatedAfter = &parsed + } + + return State{ + UpdatedAfter: parsedUpdatedAfter, + ActionID: actionID, + }, nil } -func SaveState(path string, state State) error { +func SaveState(path string, state State) (err error) { if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { return err } - data, err := json.MarshalIndent(state, "", " ") + + type diskState struct { + UpdatedAfter string `json:"updated_after,omitempty"` + ActionID string `json:"action_id,omitempty"` + } + + updatedAfter := "" + if state.UpdatedAfter != nil { + updatedAfter = state.UpdatedAfter.UTC().Format(time.RFC3339Nano) + } + data, err := json.MarshalIndent(diskState{ + UpdatedAfter: updatedAfter, + ActionID: strings.TrimSpace(state.ActionID), + }, "", " ") if err != nil { return err } @@ -275,22 +280,35 @@ func SaveState(path string, state State) error { return err } tempPath := temp.Name() + closed := false cleanup := true defer func() { if cleanup { - _ = os.Remove(tempPath) + var cleanupErr error + if !closed { + cleanupErr = errors.Join(cleanupErr, temp.Close()) + } + cleanupErr = errors.Join(cleanupErr, os.Remove(tempPath)) + if cleanupErr == nil { + return + } + if err == nil { + err = cleanupErr + return + } + err = errors.Join(err, cleanupErr) } }() if err := temp.Chmod(0o600); err != nil { - _ = temp.Close() return err } if _, err := temp.Write(data); err != nil { - _ = temp.Close() return err } - if err := temp.Close(); err != nil { - return err + closeErr := temp.Close() + closed = true + if closeErr != nil { + return closeErr } if err := os.Rename(tempPath, path); err != nil { return err diff --git a/internal/managedstream/stream_test.go b/internal/managedstream/stream_test.go index 77ee7ab..a8b2193 100644 --- a/internal/managedstream/stream_test.go +++ b/internal/managedstream/stream_test.go @@ -12,13 +12,14 @@ import ( "github.com/kontext-security/kontext-cli/internal/guard/risk" "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite" + "github.com/kontext-security/kontext-cli/internal/ledger" ) func TestFlushPostsLedgerBatchWithInstallationIdentity(t *testing.T) { store, dbPath := testStore(t) saveTestDecision(t, store, "session-1", "toolu_1") - var got Payload + var got ledger.Payload server := capturePayloadServer(t, &got) t.Cleanup(server.Close) @@ -36,8 +37,8 @@ func TestFlushPostsLedgerBatchWithInstallationIdentity(t *testing.T) { t.Fatalf("Flush() error = %v", err) } - if got.SchemaVersion != SchemaVersion { - t.Fatalf("schema_version = %q, want %q", got.SchemaVersion, SchemaVersion) + if got.SchemaVersion != ledger.SchemaVersion { + t.Fatalf("schema_version = %q, want %q", got.SchemaVersion, ledger.SchemaVersion) } if got.OrganizationID != "org_123" { t.Fatalf("organization_id = %q", got.OrganizationID) @@ -61,7 +62,7 @@ func TestFlushPostsLedgerBatchWithInstallationIdentity(t *testing.T) { if err != nil { t.Fatalf("LoadState() error = %v", err) } - if state.UpdatedAfter == "" { + if state.UpdatedAfter == nil { t.Fatal("updated_after was not persisted") } } @@ -70,7 +71,7 @@ func TestFlushOmitsBlankDeviceLabel(t *testing.T) { store, dbPath := testStore(t) saveTestDecision(t, store, "session-1", "toolu_1") - var got Payload + var got ledger.Payload server := capturePayloadServer(t, &got) t.Cleanup(server.Close) @@ -96,7 +97,7 @@ func TestFlushResolvesDeploymentVersionPerFlush(t *testing.T) { store, dbPath := testStore(t) saveTestDecision(t, store, "session-1", "toolu_1") - var got Payload + var got ledger.Payload server := capturePayloadServer(t, &got) t.Cleanup(server.Close) @@ -185,7 +186,7 @@ func TestFlushDefaultsStatePathBesideLedgerDB(t *testing.T) { if err != nil { t.Fatalf("LoadState() error = %v", err) } - if state.UpdatedAfter == "" { + if state.UpdatedAfter == nil { t.Fatal("updated_after was not persisted") } } @@ -194,8 +195,9 @@ func TestFlushUsesUpdatedAfterCursor(t *testing.T) { store, dbPath := testStore(t) saveTestDecision(t, store, "session-1", "toolu_1") + updatedAfter := time.Now().Add(time.Hour).UTC() statePath := filepath.Join(t.TempDir(), "stream-state.json") - if err := SaveState(statePath, State{UpdatedAfter: time.Now().Add(time.Hour).UTC().Format(time.RFC3339Nano)}); err != nil { + if err := SaveState(statePath, State{UpdatedAfter: &updatedAfter}); err != nil { t.Fatalf("SaveState() error = %v", err) } @@ -222,6 +224,54 @@ func TestFlushUsesUpdatedAfterCursor(t *testing.T) { } } +func TestLoadStateParsesAndTrimsFields(t *testing.T) { + statePath := filepath.Join(t.TempDir(), "stream-state.json") + if err := os.WriteFile(statePath, []byte(`{ + "updated_after": " 2026-05-31T10:11:12.123456789Z ", + "action_id": " act_123 " +} +`), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + state, err := LoadState(statePath) + if err != nil { + t.Fatalf("LoadState() error = %v", err) + } + if state.UpdatedAfter == nil || state.UpdatedAfter.UTC().Format(time.RFC3339Nano) != "2026-05-31T10:11:12.123456789Z" { + t.Fatalf("UpdatedAfter = %+v, want parsed timestamp", state.UpdatedAfter) + } + if state.ActionID != "act_123" { + t.Fatalf("ActionID = %q, want %q", state.ActionID, "act_123") + } +} + +func TestLoadStateTreatsBlankTimestampAsUnset(t *testing.T) { + statePath := filepath.Join(t.TempDir(), "stream-state.json") + if err := os.WriteFile(statePath, []byte(`{"updated_after":" \t\n ","action_id":" act_123 "}`+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + state, err := LoadState(statePath) + if err != nil { + t.Fatalf("LoadState() error = %v", err) + } + if state.UpdatedAfter != nil { + t.Fatalf("UpdatedAfter = %+v, want unset", state.UpdatedAfter) + } + if state.ActionID != "act_123" { + t.Fatalf("ActionID = %q, want %q", state.ActionID, "act_123") + } +} + +func TestLoadStateRejectsInvalidTimestamp(t *testing.T) { + statePath := filepath.Join(t.TempDir(), "stream-state.json") + if err := os.WriteFile(statePath, []byte(`{"updated_after":"not-a-time"}`+"\n"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if _, err := LoadState(statePath); err == nil { + t.Fatal("LoadState() error = nil, want invalid timestamp failure") + } +} + func testStore(t *testing.T) (*sqlite.Store, string) { t.Helper() dbPath := filepath.Join(t.TempDir(), "guard.db") @@ -233,11 +283,11 @@ func testStore(t *testing.T) (*sqlite.Store, string) { return store, dbPath } -func capturePayloadServer(t *testing.T, got *Payload) *httptest.Server { +func capturePayloadServer(t *testing.T, got *ledger.Payload) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != DefaultEndpoint { - t.Fatalf("path = %q, want %q", r.URL.Path, DefaultEndpoint) + if r.URL.Path != ledger.DefaultEndpoint { + t.Fatalf("path = %q, want %q", r.URL.Path, ledger.DefaultEndpoint) } if got := r.Header.Get("Authorization"); got != "Bearer test-install-token" { t.Fatalf("Authorization = %q, want bearer install token", got)