Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions internal/ledger/payload.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package ledger

import "github.com/kontext-security/kontext-cli/internal/guard/store/sqlite"

const (
SchemaVersion = "authorization-ledger-v1"
DefaultEndpoint = "/api/v1/authorization-ledger/batches"
)

type Payload struct {
SchemaVersion string `json:"schema_version"`
OrganizationID string `json:"organization_id"`
InstallationID string `json:"installation_id"`
BatchID string `json:"batch_id"`
SentAt string `json:"sent_at"`
Device *Device `json:"device,omitempty"`
Sessions []sqlite.LedgerRecord `json:"agent_sessions"`
Actions []sqlite.LedgerRecord `json:"authorization_actions"`
Receipts []sqlite.LedgerRecord `json:"authorization_receipts"`
ReceiptChainAnchor *sqlite.LedgerReceiptChainAnchor `json:"receipt_chain_anchor,omitempty"`
}

type Device struct {
Label string `json:"label,omitempty"`
DeploymentVersion string `json:"deployment_version,omitempty"`
}
68 changes: 68 additions & 0 deletions internal/ledger/payload_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package ledger

import (
"encoding/json"
"testing"

"github.com/kontext-security/kontext-cli/internal/guard/store/sqlite"
)

func TestPayloadJSONShape(t *testing.T) {
payload := Payload{
SchemaVersion: SchemaVersion,
OrganizationID: "org_123",
InstallationID: "ins_123",
BatchID: "batch_abc",
SentAt: "2026-05-31T10:00:00Z",
Device: &Device{Label: "test-mac"},
Sessions: []sqlite.LedgerRecord{},
Actions: []sqlite.LedgerRecord{{"session_id": "claude-session"}},
Receipts: []sqlite.LedgerRecord{},
}

data, err := json.Marshal(payload)
if err != nil {
t.Fatalf("Marshal() error = %v", err)
}

var got map[string]any
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}

for _, key := range []string{
"schema_version",
"organization_id",
"installation_id",
"batch_id",
"sent_at",
"agent_sessions",
"authorization_actions",
"authorization_receipts",
"device",
} {
if _, ok := got[key]; !ok {
t.Fatalf("missing JSON key %q in %s", key, string(data))
}
}
if _, ok := got["receipt_chain_anchor"]; ok {
t.Fatalf("receipt_chain_anchor was present for nil anchor: %s", string(data))
}

device, ok := got["device"].(map[string]any)
if !ok {
t.Fatalf("device = %#v, want object: %s", got["device"], string(data))
}
if device["label"] != "test-mac" {
t.Fatalf("device.label = %#v, want %q", device["label"], "test-mac")
}
if _, ok := device["deployment_version"]; ok {
t.Fatalf("device.deployment_version was present for empty value: %s", string(data))
}

for _, key := range []string{"agent_sessions", "authorization_actions", "authorization_receipts"} {
if _, ok := got[key].([]any); !ok {
t.Fatalf("%s = %#v, want array: %s", key, got[key], string(data))
}
}
}
21 changes: 6 additions & 15 deletions internal/managedobserve/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/kontext-security/kontext-cli/internal/guard/store/sqlite"
"github.com/kontext-security/kontext-cli/internal/hook"
"github.com/kontext-security/kontext-cli/internal/ledger"
"github.com/kontext-security/kontext-cli/internal/localruntime"
)

Expand Down Expand Up @@ -91,26 +92,15 @@ func TestDaemonSessionEndClosesHookSessionID(t *testing.T) {
}

func TestDaemonStreamsLedgerBatches(t *testing.T) {
type ledgerBatchRequest struct {
OrganizationID string `json:"organization_id"`
InstallationID string `json:"installation_id"`
Device *struct {
Label string `json:"label"`
} `json:"device,omitempty"`
Actions []struct {
SessionID string `json:"session_id"`
} `json:"authorization_actions"`
}

requests := make(chan ledgerBatchRequest, 1)
requests := make(chan ledger.Payload, 1)
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/authorization-ledger/batches" {
if r.URL.Path != ledger.DefaultEndpoint {
t.Fatalf("path = %q", r.URL.Path)
}
if got := r.Header.Get("Authorization"); got != "Bearer test-install-token" {
t.Fatalf("Authorization = %q, want bearer install token", got)
}
var body ledgerBatchRequest
var body ledger.Payload
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
t.Fatalf("Decode() error = %v", err)
}
Expand Down Expand Up @@ -180,7 +170,8 @@ func TestDaemonStreamsLedgerBatches(t *testing.T) {
}
found := false
for _, action := range body.Actions {
if action.SessionID == "claude-stream-session" {
sessionID, _ := action["session_id"].(string)
if sessionID == "claude-stream-session" {
found = true
}
}
Expand Down
4 changes: 3 additions & 1 deletion internal/managedobserve/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ func (l Lifecycle) probe(ctx context.Context) bool {
if err != nil {
return false
}
_ = conn.Close()
if err := conn.Close(); err != nil {
l.Diagnostic.Printf("managed observe probe close: %v\n", err)
}
return true
}

Expand Down
118 changes: 68 additions & 50 deletions internal/managedstream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ import (

"github.com/kontext-security/kontext-cli/internal/diagnostic"
"github.com/kontext-security/kontext-cli/internal/guard/store/sqlite"
"github.com/kontext-security/kontext-cli/internal/ledger"
)

const (
SchemaVersion = "authorization-ledger-v1"
DefaultEndpoint = "/api/v1/authorization-ledger/batches"

DefaultBatchLimit = 500
DefaultInterval = 10 * time.Second

Expand All @@ -45,27 +43,9 @@ type Options struct {
Diagnostic diagnostic.Logger
}

type Payload struct {
SchemaVersion string `json:"schema_version"`
OrganizationID string `json:"organization_id"`
InstallationID string `json:"installation_id"`
BatchID string `json:"batch_id"`
SentAt string `json:"sent_at"`
Device *Device `json:"device,omitempty"`
Sessions []sqlite.LedgerRecord `json:"agent_sessions"`
Actions []sqlite.LedgerRecord `json:"authorization_actions"`
Receipts []sqlite.LedgerRecord `json:"authorization_receipts"`
ReceiptChainAnchor *sqlite.LedgerReceiptChainAnchor `json:"receipt_chain_anchor,omitempty"`
}

type Device struct {
Label string `json:"label,omitempty"`
DeploymentVersion string `json:"deployment_version,omitempty"`
}

type State struct {
UpdatedAfter string `json:"updated_after,omitempty"`
ActionID string `json:"action_id,omitempty"`
UpdatedAfter *time.Time
ActionID string
}

func Run(ctx context.Context, opts Options) error {
Expand Down Expand Up @@ -114,21 +94,12 @@ func Flush(ctx context.Context, opts Options) error {
return err
}

var updatedAfter *time.Time
if state.UpdatedAfter != "" {
parsed, err := time.Parse(time.RFC3339Nano, state.UpdatedAfter)
if err != nil {
return fmt.Errorf("parse managed stream state: %w", err)
}
updatedAfter = &parsed
}

limit := opts.BatchLimit
if limit <= 0 {
limit = DefaultBatchLimit
}
batch, err := store.LedgerBatch(ctx, sqlite.LedgerExportOptions{
UpdatedAfter: updatedAfter,
UpdatedAfter: state.UpdatedAfter,
UpdatedAfterID: state.ActionID,
Limit: limit,
})
Expand All @@ -139,8 +110,8 @@ func Flush(ctx context.Context, opts Options) error {
return nil
}

payload := Payload{
SchemaVersion: SchemaVersion,
payload := ledger.Payload{
SchemaVersion: ledger.SchemaVersion,
OrganizationID: opts.OrganizationID,
InstallationID: opts.InstallationID,
BatchID: "batch_" + uuid.NewString(),
Expand All @@ -158,22 +129,23 @@ func Flush(ctx context.Context, opts Options) error {
deploymentVersion = strings.TrimSpace(opts.DeploymentVersion())
}
if label != "" || deploymentVersion != "" {
payload.Device = &Device{Label: label, DeploymentVersion: deploymentVersion}
payload.Device = &ledger.Device{Label: label, DeploymentVersion: deploymentVersion}
}
if err := post(ctx, opts, payload); err != nil {
return err
}

if batch.Cursor != nil {
updatedAfter := batch.Cursor.UpdatedAt.UTC()
return SaveState(statePath, State{
UpdatedAfter: batch.Cursor.UpdatedAt.UTC().Format(time.RFC3339Nano),
UpdatedAfter: &updatedAfter,
ActionID: batch.Cursor.ActionID,
})
}
return nil
}

func post(ctx context.Context, opts Options, payload Payload) error {
func post(ctx context.Context, opts Options, payload ledger.Payload) error {
body, err := json.Marshal(payload)
if err != nil {
return err
Expand Down Expand Up @@ -209,7 +181,7 @@ func endpointURL(cloudURL string) (string, error) {
if err != nil {
return "", err
}
parsed.Path = DefaultEndpoint
parsed.Path = ledger.DefaultEndpoint
parsed.RawQuery = ""
parsed.Fragment = ""
return parsed.String(), nil
Expand Down Expand Up @@ -252,20 +224,53 @@ func LoadState(path string) (State, error) {
}
return State{}, err
}
var state State

type diskState struct {
UpdatedAfter string `json:"updated_after,omitempty"`
ActionID string `json:"action_id,omitempty"`
}

var state diskState
if err := json.Unmarshal(data, &state); err != nil {
return State{}, err
}
state.UpdatedAfter = strings.TrimSpace(state.UpdatedAfter)
state.ActionID = strings.TrimSpace(state.ActionID)
return state, nil

updatedAfter := strings.TrimSpace(state.UpdatedAfter)
actionID := strings.TrimSpace(state.ActionID)

var parsedUpdatedAfter *time.Time
if updatedAfter != "" {
parsed, err := time.Parse(time.RFC3339Nano, updatedAfter)
if err != nil {
return State{}, fmt.Errorf("parse managed stream state updated_after: %w", err)
}
parsedUpdatedAfter = &parsed
}

return State{
UpdatedAfter: parsedUpdatedAfter,
ActionID: actionID,
}, nil
}

func SaveState(path string, state State) error {
func SaveState(path string, state State) (err error) {
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return err
}
data, err := json.MarshalIndent(state, "", " ")

type diskState struct {
UpdatedAfter string `json:"updated_after,omitempty"`
ActionID string `json:"action_id,omitempty"`
}

updatedAfter := ""
if state.UpdatedAfter != nil {
updatedAfter = state.UpdatedAfter.UTC().Format(time.RFC3339Nano)
}
data, err := json.MarshalIndent(diskState{
UpdatedAfter: updatedAfter,
ActionID: strings.TrimSpace(state.ActionID),
}, "", " ")
if err != nil {
return err
}
Expand All @@ -275,22 +280,35 @@ func SaveState(path string, state State) error {
return err
}
tempPath := temp.Name()
closed := false
cleanup := true
defer func() {
if cleanup {
_ = os.Remove(tempPath)
var cleanupErr error
if !closed {
cleanupErr = errors.Join(cleanupErr, temp.Close())
}
cleanupErr = errors.Join(cleanupErr, os.Remove(tempPath))
if cleanupErr == nil {
return
}
if err == nil {
err = cleanupErr
return
}
err = errors.Join(err, cleanupErr)
}
}()
if err := temp.Chmod(0o600); err != nil {
_ = temp.Close()
return err
}
if _, err := temp.Write(data); err != nil {
_ = temp.Close()
return err
}
if err := temp.Close(); err != nil {
return err
closeErr := temp.Close()
closed = true
if closeErr != nil {
return closeErr
}
if err := os.Rename(tempPath, path); err != nil {
return err
Expand Down
Loading
Loading