diff --git a/examples/runner/main.go b/examples/runner/main.go index 31390cbff..c76787830 100644 --- a/examples/runner/main.go +++ b/examples/runner/main.go @@ -29,6 +29,7 @@ import ( "trpc.group/trpc-go/trpc-agent-go/model/openai" "trpc.group/trpc-go/trpc-agent-go/runner" "trpc.group/trpc-go/trpc-agent-go/session" + sessiondb "trpc.group/trpc-go/trpc-agent-go/session/database" sessioninmemory "trpc.group/trpc-go/trpc-agent-go/session/inmemory" "trpc.group/trpc-go/trpc-agent-go/session/redis" "trpc.group/trpc-go/trpc-agent-go/tool" @@ -38,7 +39,8 @@ import ( var ( modelName = flag.String("model", "deepseek-chat", "Name of the model to use") redisAddr = flag.String("redis-addr", "localhost:6379", "Redis address") - sessServiceName = flag.String("session", "inmemory", "Name of the session service to use, inmemory / redis") + databaseDSN = flag.String("database-dsn", "", "Database DSN (MySQL: user:pass@tcp(host:port)/db?charset=utf8mb4&parseTime=True&loc=Local, PostgreSQL: postgres://user:pass@host/db)") + sessServiceName = flag.String("session", "inmemory", "Name of the session service to use: inmemory / redis / database") streaming = flag.Bool("streaming", true, "Enable streaming mode for responses") enableParallel = flag.Bool("enable-parallel", false, "Enable parallel tool execution (default: false, serial execution)") ) @@ -55,8 +57,15 @@ func main() { parallelStatus = "enabled (parallel execution)" } fmt.Printf("Parallel Tools: %s\n", parallelStatus) - if *sessServiceName == "redis" { + switch *sessServiceName { + case "redis": fmt.Printf("Redis: %s\n", *redisAddr) + case "database": + if *databaseDSN != "" { + fmt.Printf("Database: DSN configured\n") + } else { + fmt.Printf("Database: Using default configuration\n") + } } fmt.Printf("Type 'exit' to end the conversation\n") fmt.Printf("Available tools: calculator, current_time\n") @@ -113,10 +122,23 @@ func (c *multiTurnChat) setup(_ context.Context) error { redisURL := fmt.Sprintf("redis://%s", *redisAddr) sessionService, err = redis.NewService(redis.WithRedisClientURL(redisURL)) if err != nil { - return fmt.Errorf("failed to create session service: %w", err) + return fmt.Errorf("failed to create redis session service: %w", err) + } + + case "database": + if *databaseDSN == "" { + return fmt.Errorf("database-dsn is required when using database session service") + } + sessionService, err = sessiondb.NewService( + sessiondb.WithDatabaseDSN(*databaseDSN), + sessiondb.WithAutoCreateTable(true), + ) + if err != nil { + return fmt.Errorf("failed to create database session service: %w", err) } + default: - return fmt.Errorf("invalid session service name: %s", *sessServiceName) + return fmt.Errorf("invalid session service name: %s (valid options: inmemory, redis, database)", *sessServiceName) } // Create tools. diff --git a/session/database/go.mod b/session/database/go.mod new file mode 100644 index 000000000..fad6684c8 --- /dev/null +++ b/session/database/go.mod @@ -0,0 +1,27 @@ +module trpc.group/trpc-go/trpc-agent-go/session/database + +go 1.22 + +replace ( + trpc.group/trpc-go/trpc-agent-go => ../../ + trpc.group/trpc-go/trpc-agent-go/storage/database => ../../storage/database +) + +require ( + github.com/google/uuid v1.6.0 + github.com/spaolacci/murmur3 v1.1.0 + gorm.io/gorm v1.25.12 + trpc.group/trpc-go/trpc-agent-go v0.0.0 + trpc.group/trpc-go/trpc-agent-go/storage/database v0.0.0 +) + +require ( + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.27.0 // indirect + golang.org/x/text v0.21.0 // indirect + gorm.io/driver/mysql v1.5.7 // indirect + trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251020094851-6ab922c9dab1 // indirect +) diff --git a/session/database/go.sum b/session/database/go.sum new file mode 100644 index 000000000..848f6bb16 --- /dev/null +++ b/session/database/go.sum @@ -0,0 +1,33 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= +trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251020094851-6ab922c9dab1 h1:P+OyPh+QCNuO8u+M2UPTYZCGKnH9YAcijC8ULokAdTw= +trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251020094851-6ab922c9dab1/go.mod h1:Gtytau9Uoc3oPo/dpHvKit+tQn9Qlk5XFG1RiZTGqfk= diff --git a/session/database/migration.go b/session/database/migration.go new file mode 100644 index 000000000..9ad8d85e4 --- /dev/null +++ b/session/database/migration.go @@ -0,0 +1,483 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +package database + +import ( + "fmt" + "strings" + + "gorm.io/gorm" + "trpc.group/trpc-go/trpc-agent-go/log" +) + +// columnInfo represents database column information +type columnInfo struct { + Field string + Type string + Null string + Key string + Default *string + Extra string +} + +// indexInfo represents database index information +type indexInfo struct { + Table string + NonUnique int + KeyName string + SeqInIndex int + ColumnName string + Collation string + Cardinality int64 + SubPart *int + Packed *string + Null string + IndexType string + Comment string +} + +// expectedColumn represents expected column definition +type expectedColumn struct { + Name string + Type string + Nullable bool +} + +// expectedIndex represents expected index definition +type expectedIndex struct { + Name string + Columns []string + Unique bool +} + +// getTableColumns retrieves column information from a table +func getTableColumns(db *gorm.DB, tableName string) (map[string]*columnInfo, error) { + // Use GORM's Migrator to get column types (works across databases) + columnTypes, err := db.Migrator().ColumnTypes(tableName) + if err != nil { + return nil, err + } + + result := make(map[string]*columnInfo) + for _, col := range columnTypes { + nullable, _ := col.Nullable() + nullStr := "NO" + if nullable { + nullStr = "YES" + } + + result[col.Name()] = &columnInfo{ + Field: col.Name(), + Type: col.DatabaseTypeName(), + Null: nullStr, + } + } + return result, nil +} + +// getTableIndexes retrieves index information from a table +func getTableIndexes(db *gorm.DB, tableName string) (map[string]*expectedIndex, error) { + dialectName := db.Dialector.Name() + + switch dialectName { + case "mysql": + return getMySQLIndexes(db, tableName) + case "postgres": + return getPostgreSQLIndexes(db, tableName) + case "sqlite": + return getSQLiteIndexes(db, tableName) + default: + // For unsupported databases, return empty map (skip index check) + log.Warnf("Index checking not supported for database type: %s", dialectName) + return make(map[string]*expectedIndex), nil + } +} + +// getMySQLIndexes retrieves index information from MySQL +func getMySQLIndexes(db *gorm.DB, tableName string) (map[string]*expectedIndex, error) { + var indexes []indexInfo + if err := db.Raw("SHOW INDEX FROM " + tableName).Scan(&indexes).Error; err != nil { + return nil, err + } + + result := make(map[string]*expectedIndex) + for _, idx := range indexes { + if idx.KeyName == "PRIMARY" { + continue // Skip primary key + } + + if existing, ok := result[idx.KeyName]; ok { + existing.Columns = append(existing.Columns, idx.ColumnName) + } else { + result[idx.KeyName] = &expectedIndex{ + Name: idx.KeyName, + Columns: []string{idx.ColumnName}, + Unique: idx.NonUnique == 0, + } + } + } + return result, nil +} + +// getPostgreSQLIndexes retrieves index information from PostgreSQL +func getPostgreSQLIndexes(db *gorm.DB, tableName string) (map[string]*expectedIndex, error) { + type pgIndex struct { + IndexName string + ColumnName string + IsUnique bool + } + + var indexes []pgIndex + query := ` + SELECT + i.relname AS index_name, + a.attname AS column_name, + ix.indisunique AS is_unique + FROM pg_class t + JOIN pg_index ix ON t.oid = ix.indrelid + JOIN pg_class i ON i.oid = ix.indexrelid + JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey) + WHERE t.relname = $1 AND i.relname NOT LIKE 'pg_%' + ORDER BY i.relname, a.attnum + ` + if err := db.Raw(query, tableName).Scan(&indexes).Error; err != nil { + return nil, err + } + + result := make(map[string]*expectedIndex) + for _, idx := range indexes { + if existing, ok := result[idx.IndexName]; ok { + existing.Columns = append(existing.Columns, idx.ColumnName) + } else { + result[idx.IndexName] = &expectedIndex{ + Name: idx.IndexName, + Columns: []string{idx.ColumnName}, + Unique: idx.IsUnique, + } + } + } + return result, nil +} + +// getSQLiteIndexes retrieves index information from SQLite +func getSQLiteIndexes(db *gorm.DB, tableName string) (map[string]*expectedIndex, error) { + type sqliteIndex struct { + Name string + Unique int + } + + var indexes []sqliteIndex + if err := db.Raw("SELECT name, `unique` FROM sqlite_master WHERE type='index' AND tbl_name=?", tableName). + Scan(&indexes).Error; err != nil { + return nil, err + } + + result := make(map[string]*expectedIndex) + for _, idx := range indexes { + // Skip auto-created indexes + if strings.HasPrefix(idx.Name, "sqlite_autoindex") { + continue + } + + // Get columns for this index + type indexColumn struct { + Name string + } + var columns []indexColumn + db.Raw("PRAGMA index_info(?)", idx.Name).Scan(&columns) + + cols := make([]string, len(columns)) + for i, col := range columns { + cols[i] = col.Name + } + + result[idx.Name] = &expectedIndex{ + Name: idx.Name, + Columns: cols, + Unique: idx.Unique == 1, + } + } + return result, nil +} + +// tableExists checks if a table exists +func tableExists(db *gorm.DB, tableName string) (bool, error) { + // Use GORM's Migrator which works across different databases + return db.Migrator().HasTable(tableName), nil +} + +// getExpectedColumns returns expected columns for each table +func getExpectedColumns() map[string][]expectedColumn { + return map[string][]expectedColumn{ + "session_states": { + {Name: "id", Type: "bigint", Nullable: false}, + {Name: "app_name", Type: "varchar", Nullable: false}, + {Name: "user_id", Type: "varchar", Nullable: false}, + {Name: "session_id", Type: "varchar", Nullable: false}, + {Name: "state", Type: "mediumblob", Nullable: true}, + {Name: "created_at", Type: "datetime", Nullable: false}, + {Name: "updated_at", Type: "datetime", Nullable: false}, + {Name: "expires_at", Type: "datetime", Nullable: true}, + }, + "session_events": { + {Name: "id", Type: "bigint", Nullable: false}, + {Name: "app_name", Type: "varchar", Nullable: false}, + {Name: "user_id", Type: "varchar", Nullable: false}, + {Name: "session_id", Type: "varchar", Nullable: false}, + {Name: "event_data", Type: "mediumblob", Nullable: false}, + {Name: "timestamp", Type: "datetime", Nullable: false}, + {Name: "created_at", Type: "datetime", Nullable: false}, + {Name: "expires_at", Type: "datetime", Nullable: true}, + }, + "session_summaries": { + {Name: "id", Type: "bigint", Nullable: false}, + {Name: "app_name", Type: "varchar", Nullable: false}, + {Name: "user_id", Type: "varchar", Nullable: false}, + {Name: "session_id", Type: "varchar", Nullable: false}, + {Name: "filter_key", Type: "varchar", Nullable: false}, + {Name: "summary", Type: "blob", Nullable: false}, + {Name: "updated_at", Type: "datetime", Nullable: false}, + {Name: "expires_at", Type: "datetime", Nullable: true}, + }, + "app_states": { + {Name: "id", Type: "bigint", Nullable: false}, + {Name: "app_name", Type: "varchar", Nullable: false}, + {Name: "state_key", Type: "varchar", Nullable: false}, + {Name: "value", Type: "mediumblob", Nullable: false}, + {Name: "updated_at", Type: "datetime", Nullable: false}, + {Name: "expires_at", Type: "datetime", Nullable: true}, + }, + "user_states": { + {Name: "id", Type: "bigint", Nullable: false}, + {Name: "app_name", Type: "varchar", Nullable: false}, + {Name: "user_id", Type: "varchar", Nullable: false}, + {Name: "state_key", Type: "varchar", Nullable: false}, + {Name: "value", Type: "mediumblob", Nullable: false}, + {Name: "updated_at", Type: "datetime", Nullable: false}, + {Name: "expires_at", Type: "datetime", Nullable: true}, + }, + } +} + +// getExpectedIndexes returns expected indexes for each table +func getExpectedIndexes() map[string][]expectedIndex { + return map[string][]expectedIndex{ + "session_states": { + {Name: "idx_app_user_session", Columns: []string{"app_name", "user_id", "session_id"}, Unique: true}, + {Name: "idx_expires_at", Columns: []string{"expires_at"}, Unique: false}, + }, + "session_events": { + {Name: "idx_app_user_session_event", Columns: []string{"app_name", "user_id", "session_id", "timestamp"}, Unique: false}, + {Name: "idx_expires_at", Columns: []string{"expires_at"}, Unique: false}, + }, + "session_summaries": { + {Name: "idx_app_user_session_filter", Columns: []string{"app_name", "user_id", "session_id", "filter_key"}, Unique: false}, + {Name: "idx_expires_at", Columns: []string{"expires_at"}, Unique: false}, + }, + "app_states": { + {Name: "idx_app_key", Columns: []string{"app_name", "state_key"}, Unique: true}, + {Name: "idx_expires_at", Columns: []string{"expires_at"}, Unique: false}, + }, + "user_states": { + {Name: "idx_app_user_key", Columns: []string{"app_name", "user_id", "state_key"}, Unique: true}, + {Name: "idx_expires_at", Columns: []string{"expires_at"}, Unique: false}, + }, + } +} + +// normalizeType normalizes database type string for comparison across different databases +func normalizeType(t string) string { + t = strings.ToLower(t) + t = strings.TrimSpace(t) + + // Remove size specifications and parentheses (e.g., "varchar(255)" -> "varchar") + if idx := strings.Index(t, "("); idx != -1 { + t = strings.TrimSpace(t[:idx]) + } + + // Integer types (MySQL: BIGINT/INT/TINYINT, PostgreSQL: BIGINT/INTEGER, SQLite: INTEGER) + if t == "bigint" || t == "bigint unsigned" || t == "integer" || + t == "int" || t == "int unsigned" || t == "tinyint" || t == "smallint" { + return "int" + } + + // String types (MySQL: VARCHAR, PostgreSQL: CHARACTER VARYING/VARCHAR, SQLite: TEXT) + if t == "varchar" || t == "character varying" || t == "char" { + return "varchar" + } + + // Text types + if t == "text" || t == "mediumtext" || t == "longtext" { + return "text" + } + + // Binary/Blob types (MySQL: BLOB/MEDIUMBLOB, PostgreSQL: BYTEA, SQLite: BLOB) + if t == "blob" || t == "mediumblob" || t == "longblob" || t == "tinyblob" || t == "bytea" { + return "blob" + } + + // Datetime types (MySQL: DATETIME, PostgreSQL: TIMESTAMP, SQLite: DATETIME) + // Handle "timestamp without time zone" and "timestamp with time zone" + if t == "datetime" || t == "timestamp" || + strings.HasPrefix(t, "timestamp without") || strings.HasPrefix(t, "timestamp with") { + return "datetime" + } + + return t +} + +// checkTableSchema checks if table schema matches expected definition +func checkTableSchema(db *gorm.DB, tableName string, expectedColumns []expectedColumn) error { + columns, err := getTableColumns(db, tableName) + if err != nil { + return fmt.Errorf("failed to get columns for table %s: %w", tableName, err) + } + + // Check if all expected columns exist with correct type + for _, expected := range expectedColumns { + actual, exists := columns[expected.Name] + if !exists { + return fmt.Errorf("table %s: missing column '%s'", tableName, expected.Name) + } + + // Normalize types for comparison + expectedType := normalizeType(expected.Type) + actualType := normalizeType(actual.Type) + + if expectedType != actualType { + return fmt.Errorf("table %s: column '%s' type mismatch (expected: %s, actual: %s)", + tableName, expected.Name, expected.Type, actual.Type) + } + + // Check nullable constraint + isNullable := actual.Null == "YES" + if expected.Nullable != isNullable { + return fmt.Errorf("table %s: column '%s' nullable mismatch (expected: %v, actual: %v)", + tableName, expected.Name, expected.Nullable, isNullable) + } + } + + return nil +} + +// checkTableIndexes checks if table indexes match expected definition +func checkTableIndexes(db *gorm.DB, tableName string, expectedIndexes []expectedIndex) { + indexes, err := getTableIndexes(db, tableName) + if err != nil { + log.Warnf("Failed to get indexes for table %s: %v", tableName, err) + return + } + + // Check for missing or mismatched indexes + for _, expected := range expectedIndexes { + actual, exists := indexes[expected.Name] + if !exists { + log.Infof("Table %s: index '%s' does not exist (expected columns: %v)", + tableName, expected.Name, expected.Columns) + continue + } + + // Check if columns match + if len(actual.Columns) != len(expected.Columns) { + log.Infof("Table %s: index '%s' column count mismatch (expected: %v, actual: %v)", + tableName, expected.Name, expected.Columns, actual.Columns) + continue + } + + for i, col := range expected.Columns { + if actual.Columns[i] != col { + log.Infof("Table %s: index '%s' column mismatch at position %d (expected: %s, actual: %s)", + tableName, expected.Name, i, col, actual.Columns[i]) + break + } + } + + // Check if unique constraint matches + if actual.Unique != expected.Unique { + log.Infof("Table %s: index '%s' unique constraint mismatch (expected: %v, actual: %v)", + tableName, expected.Name, expected.Unique, actual.Unique) + } + } + + // Check for extra indexes (informational only) + for indexName := range indexes { + found := false + for _, expected := range expectedIndexes { + if expected.Name == indexName { + found = true + break + } + } + if !found { + log.Infof("Table %s: found unexpected index '%s'", tableName, indexName) + } + } +} + +// initializeTables handles table initialization based on configuration +func initializeTables(db *gorm.DB, autoCreateTable, autoMigrate bool) error { + allModels := []interface{}{ + &sessionStateModel{}, + &sessionEventModel{}, + &sessionSummaryModel{}, + &appStateModel{}, + &userStateModel{}, + } + + tableNames := []string{ + "session_states", + "session_events", + "session_summaries", + "app_states", + "user_states", + } + + expectedColumns := getExpectedColumns() + expectedIndexes := getExpectedIndexes() + for i, tableName := range tableNames { + exists, err := tableExists(db, tableName) + if err != nil { + return fmt.Errorf("failed to check if table %s exists: %w", tableName, err) + } + + if !exists { + // Table doesn't exist + if !autoCreateTable { + return fmt.Errorf("table %s does not exist and auto-create is disabled", tableName) + } + + // Create table using GORM AutoMigrate + log.Infof("Creating table %s...", tableName) + if err := db.AutoMigrate(allModels[i]); err != nil { + return fmt.Errorf("failed to create table %s: %w", tableName, err) + } + log.Infof("Table %s created successfully", tableName) + } else { + // Table exists, check schema + log.Debugf("Checking schema for table %s...", tableName) + if err := checkTableSchema(db, tableName, expectedColumns[tableName]); err != nil { + return fmt.Errorf("table schema validation failed: %w", err) + } + + // Check indexes (informational only) + checkTableIndexes(db, tableName, expectedIndexes[tableName]) + + // If auto migrate is enabled, run GORM AutoMigrate + if autoMigrate { + log.Infof("Running auto-migration for table %s...", tableName) + if err := db.AutoMigrate(allModels[i]); err != nil { + return fmt.Errorf("failed to migrate table %s: %w", tableName, err) + } + } + } + } + + return nil +} diff --git a/session/database/models.go b/session/database/models.go new file mode 100644 index 000000000..bace2a5d4 --- /dev/null +++ b/session/database/models.go @@ -0,0 +1,99 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +package database + +import ( + "time" +) + +// sessionStateModel represents the session state table structure. +// This table stores the session metadata and state. +type sessionStateModel struct { + ID uint64 `gorm:"primaryKey;autoIncrement"` + AppName string `gorm:"type:varchar(255);not null;index:idx_app_user_session,priority:1"` + UserID string `gorm:"type:varchar(255);not null;index:idx_app_user_session,priority:2"` + SessionID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_app_user_session,priority:3"` + State []byte `gorm:"type:mediumblob"` // JSON encoded StateMap + CreatedAt time.Time `gorm:"not null"` + UpdatedAt time.Time `gorm:"not null"` + ExpiresAt time.Time `gorm:"index:idx_expires_at"` // For TTL support, nullable means no expiration +} + +// TableName specifies the table name for SessionStateModel. +func (sessionStateModel) TableName() string { + return "session_states" +} + +// sessionEventModel represents the session events table structure. +// This table stores individual events for each session. +type sessionEventModel struct { + ID uint64 `gorm:"primaryKey;autoIncrement"` + AppName string `gorm:"type:varchar(255);not null;index:idx_app_user_session_event,priority:1"` + UserID string `gorm:"type:varchar(255);not null;index:idx_app_user_session_event,priority:2"` + SessionID string `gorm:"type:varchar(255);not null;index:idx_app_user_session_event,priority:3"` + EventData []byte `gorm:"type:mediumblob;not null"` // JSON encoded Event + Timestamp time.Time `gorm:"not null;index:idx_app_user_session_event,priority:4"` + CreatedAt time.Time `gorm:"not null"` + ExpiresAt time.Time `gorm:"index:idx_expires_at"` // For TTL support +} + +// TableName specifies the table name for SessionEventModel. +func (sessionEventModel) TableName() string { + return "session_events" +} + +// sessionSummaryModel represents the session summaries table structure. +// This table stores summaries for sessions, keyed by filterKey. +type sessionSummaryModel struct { + ID uint64 `gorm:"primaryKey;autoIncrement"` + AppName string `gorm:"type:varchar(255);not null;index:idx_app_user_session_filter,priority:1"` + UserID string `gorm:"type:varchar(255);not null;index:idx_app_user_session_filter,priority:2"` + SessionID string `gorm:"type:varchar(255);not null;index:idx_app_user_session_filter,priority:3"` + FilterKey string `gorm:"type:varchar(255);not null;default:'';index:idx_app_user_session_filter,priority:4"` // Empty string for full summary + Summary []byte `gorm:"type:mediumblob;not null"` // JSON encoded Summary + UpdatedAt time.Time `gorm:"not null"` + ExpiresAt time.Time `gorm:"index:idx_expires_at"` // For TTL support +} + +// TableName specifies the table name for SessionSummaryModel. +func (sessionSummaryModel) TableName() string { + return "session_summaries" +} + +// appStateModel represents the application-level state table structure. +type appStateModel struct { + ID uint64 `gorm:"primaryKey;autoIncrement"` + AppName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_app_key"` + StateKey string `gorm:"type:varchar(255);not null;uniqueIndex:idx_app_key"` + Value []byte `gorm:"type:mediumblob;not null"` + UpdatedAt time.Time `gorm:"not null"` + ExpiresAt time.Time `gorm:"index:idx_expires_at"` // For TTL support +} + +// TableName specifies the table name for AppStateModel. +func (appStateModel) TableName() string { + return "app_states" +} + +// userStateModel represents the user-level state table structure. +type userStateModel struct { + ID uint64 `gorm:"primaryKey;autoIncrement"` + AppName string `gorm:"type:varchar(255);not null;uniqueIndex:idx_app_user_key,priority:1"` + UserID string `gorm:"type:varchar(255);not null;uniqueIndex:idx_app_user_key,priority:2"` + StateKey string `gorm:"type:varchar(255);not null;uniqueIndex:idx_app_user_key,priority:3"` + Value []byte `gorm:"type:mediumblob;not null"` + UpdatedAt time.Time `gorm:"not null"` + ExpiresAt time.Time `gorm:"index:idx_expires_at"` // For TTL support +} + +// TableName specifies the table name for UserStateModel. +func (userStateModel) TableName() string { + return "user_states" +} diff --git a/session/database/options.go b/session/database/options.go new file mode 100644 index 000000000..baa0accd0 --- /dev/null +++ b/session/database/options.go @@ -0,0 +1,191 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +package database + +import ( + "time" + + "trpc.group/trpc-go/trpc-agent-go/session/summary" + storage "trpc.group/trpc-go/trpc-agent-go/storage/database" +) + +// ServiceOpts is the options for the database session service. +type ServiceOpts struct { + sessionEventLimit int + dsn string + driverType storage.DriverType // Database driver type (mysql, postgres, sqlite) + instanceName string + extraOptions []any + sessionTTL time.Duration // TTL for session state and event list + appStateTTL time.Duration // TTL for app state + userStateTTL time.Duration // TTL for user state + autoCreateTable bool // Whether to auto create tables if not exist (default: true) + autoMigrate bool // Whether to auto migrate existing tables (default: false) + cleanupInterval time.Duration // Interval for cleanup of expired data + enableAsyncPersist bool + asyncPersisterNum int // number of worker goroutines for async persistence + // summarizer integrates LLM summarization. + summarizer summary.SessionSummarizer + // asyncSummaryNum is the number of worker goroutines for async summary. + asyncSummaryNum int + // summaryQueueSize is the size of summary job queue. + summaryQueueSize int + // summaryJobTimeout is the timeout for processing a single summary job. + summaryJobTimeout time.Duration +} + +// ServiceOpt is the option for the database session service. +type ServiceOpt func(*ServiceOpts) + +// WithSessionEventLimit sets the limit of events in a session. +func WithSessionEventLimit(limit int) ServiceOpt { + return func(opts *ServiceOpts) { + opts.sessionEventLimit = limit + } +} + +// WithDatabaseDSN creates a database client from DSN and sets it to the service. +// Supports MySQL, PostgreSQL, and other GORM-compatible databases. +// Use WithDriverType to specify the database type if not using MySQL. +func WithDatabaseDSN(dsn string) ServiceOpt { + return func(opts *ServiceOpts) { + opts.dsn = dsn + } +} + +// WithDriverType sets the database driver type. +// Supported types: storage.DriverMySQL (default), storage.DriverPostgreSQL, storage.DriverSQLite +func WithDriverType(driverType storage.DriverType) ServiceOpt { + return func(opts *ServiceOpts) { + opts.driverType = driverType + } +} + +// WithDatabaseInstance uses a database instance from storage. +// Note: WithDatabaseDSN has higher priority than WithDatabaseInstance. +// If both are specified, WithDatabaseDSN will be used. +func WithDatabaseInstance(instanceName string) ServiceOpt { + return func(opts *ServiceOpts) { + opts.instanceName = instanceName + } +} + +// WithExtraOptions sets the extra options for the database session service. +// this option mainly used for the customized database client builder, it will be passed to the builder. +func WithExtraOptions(extraOptions ...any) ServiceOpt { + return func(opts *ServiceOpts) { + opts.extraOptions = append(opts.extraOptions, extraOptions...) + } +} + +// WithSessionTTL sets the TTL for session state and event list. +// If not set, session will not expire automatically, set 0 will not expire. +func WithSessionTTL(ttl time.Duration) ServiceOpt { + return func(opts *ServiceOpts) { + opts.sessionTTL = ttl + } +} + +// WithAppStateTTL sets the TTL for app state. +// If not set, app state will not expire. +func WithAppStateTTL(ttl time.Duration) ServiceOpt { + return func(opts *ServiceOpts) { + opts.appStateTTL = ttl + } +} + +// WithUserStateTTL sets the TTL for user state. +// If not set, user state will not expire. +func WithUserStateTTL(ttl time.Duration) ServiceOpt { + return func(opts *ServiceOpts) { + opts.userStateTTL = ttl + } +} + +// WithAutoCreateTable enables automatic table creation if tables don't exist. +// Default is true. Set to false to require manual table creation. +func WithAutoCreateTable(enable bool) ServiceOpt { + return func(opts *ServiceOpts) { + opts.autoCreateTable = enable + } +} + +// WithAutoMigrate enables automatic table migration for existing tables. +// Default is false to prevent unintended schema changes in production. +// This includes adding missing columns and updating column definitions. +func WithAutoMigrate(enable bool) ServiceOpt { + return func(opts *ServiceOpts) { + opts.autoMigrate = enable + } +} + +// WithCleanupInterval sets the interval for automatic cleanup of expired data. +// Default is 5 minutes if any TTL is configured. +func WithCleanupInterval(interval time.Duration) ServiceOpt { + return func(opts *ServiceOpts) { + opts.cleanupInterval = interval + } +} + +// WithEnableAsyncPersist enables async persistence for session state and event list. +// if not set, default is false. +func WithEnableAsyncPersist(enable bool) ServiceOpt { + return func(opts *ServiceOpts) { + opts.enableAsyncPersist = enable + } +} + +// WithAsyncPersisterNum sets the number of workers for async persistence. +func WithAsyncPersisterNum(num int) ServiceOpt { + return func(opts *ServiceOpts) { + if num < 1 { + num = defaultAsyncPersisterNum + } + opts.asyncPersisterNum = num + } +} + +// WithSummarizer injects a summarizer for LLM-based summaries. +func WithSummarizer(s summary.SessionSummarizer) ServiceOpt { + return func(opts *ServiceOpts) { + opts.summarizer = s + } +} + +// WithAsyncSummaryNum sets the number of workers for async summary processing. +func WithAsyncSummaryNum(num int) ServiceOpt { + return func(opts *ServiceOpts) { + if num < 1 { + num = defaultAsyncSummaryNum + } + opts.asyncSummaryNum = num + } +} + +// WithSummaryQueueSize sets the size of the summary job queue. +func WithSummaryQueueSize(size int) ServiceOpt { + return func(opts *ServiceOpts) { + if size < 1 { + size = defaultSummaryQueueSize + } + opts.summaryQueueSize = size + } +} + +// WithSummaryJobTimeout sets the timeout for processing a single summary job. +// If not set, a sensible default will be applied. +func WithSummaryJobTimeout(timeout time.Duration) ServiceOpt { + return func(opts *ServiceOpts) { + if timeout <= 0 { + return + } + opts.summaryJobTimeout = timeout + } +} diff --git a/session/database/schema.sql b/session/database/schema.sql new file mode 100644 index 000000000..09742877f --- /dev/null +++ b/session/database/schema.sql @@ -0,0 +1,87 @@ +-- MySQL Session Service Schema +-- This file contains the schema for the MySQL session service. +-- Note: GORM can auto-migrate these tables, but this file is provided for reference and manual setup. + +-- Create database (optional) +-- CREATE DATABASE IF NOT EXISTS trpc_sessions CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +-- USE trpc_sessions; + +-- Session States Table +-- Stores session metadata and state +CREATE TABLE IF NOT EXISTS `session_states` ( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `app_name` VARCHAR(255) NOT NULL, + `user_id` VARCHAR(255) NOT NULL, + `session_id` VARCHAR(255) NOT NULL, + `state` MEDIUMBLOB, + `created_at` DATETIME NOT NULL, + `updated_at` DATETIME NOT NULL, + `expires_at` DATETIME DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE INDEX `idx_app_user_session` (`app_name`, `user_id`, `session_id`), + INDEX `idx_expires_at` (`expires_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- Session Events Table +-- Stores session events +CREATE TABLE IF NOT EXISTS `session_events` ( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `app_name` VARCHAR(255) NOT NULL, + `user_id` VARCHAR(255) NOT NULL, + `session_id` VARCHAR(255) NOT NULL, + `event_data` MEDIUMBLOB NOT NULL, + `timestamp` DATETIME NOT NULL, + `created_at` DATETIME NOT NULL, + `expires_at` DATETIME DEFAULT NULL, + PRIMARY KEY (`id`), + INDEX `idx_app_user_session_event` (`app_name`, `user_id`, `session_id`, `timestamp`), + INDEX `idx_expires_at` (`expires_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- Session Summaries Table +-- Stores session summaries (supports branch summaries) +CREATE TABLE IF NOT EXISTS `session_summaries` ( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `app_name` VARCHAR(255) NOT NULL, + `user_id` VARCHAR(255) NOT NULL, + `session_id` VARCHAR(255) NOT NULL, + `filter_key` VARCHAR(255) NOT NULL DEFAULT '', + `summary` MEDIUMBLOB NOT NULL, + `updated_at` DATETIME NOT NULL, + `expires_at` DATETIME DEFAULT NULL, + PRIMARY KEY (`id`), + INDEX `idx_app_user_session_filter` (`app_name`, `user_id`, `session_id`, `filter_key`), + INDEX `idx_expires_at` (`expires_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- App States Table +-- Stores application-level state +CREATE TABLE IF NOT EXISTS `app_states` ( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `app_name` VARCHAR(255) NOT NULL, + `state_key` VARCHAR(255) NOT NULL, + `value` MEDIUMBLOB NOT NULL, + `updated_at` DATETIME NOT NULL, + `expires_at` DATETIME DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE INDEX `idx_app_key` (`app_name`, `state_key`), + INDEX `idx_expires_at` (`expires_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- User States Table +-- Stores user-level state +CREATE TABLE IF NOT EXISTS `user_states` ( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `app_name` VARCHAR(255) NOT NULL, + `user_id` VARCHAR(255) NOT NULL, + `state_key` VARCHAR(255) NOT NULL, + `value` MEDIUMBLOB NOT NULL, + `updated_at` DATETIME NOT NULL, + `expires_at` DATETIME DEFAULT NULL, + PRIMARY KEY (`id`), + UNIQUE INDEX `idx_app_user_key` (`app_name`, `user_id`, `state_key`), + INDEX `idx_expires_at` (`expires_at`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + + diff --git a/session/database/service.go b/session/database/service.go new file mode 100644 index 000000000..e762d179a --- /dev/null +++ b/session/database/service.go @@ -0,0 +1,939 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +// Package database provides the relational database session service. +// It supports MySQL, PostgreSQL, and other GORM-compatible databases. +package database + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/spaolacci/murmur3" + "gorm.io/gorm" + "trpc.group/trpc-go/trpc-agent-go/event" + isession "trpc.group/trpc-go/trpc-agent-go/internal/session" + "trpc.group/trpc-go/trpc-agent-go/log" + "trpc.group/trpc-go/trpc-agent-go/session" + storage "trpc.group/trpc-go/trpc-agent-go/storage/database" +) + +var _ session.Service = (*Service)(nil) + +const ( + defaultSessionEventLimit = 1000 + defaultTimeout = 2 * time.Second + defaultChanBufferSize = 100 + defaultAsyncPersisterNum = 10 + defaultCleanupInterval = 5 * time.Minute + + defaultAsyncSummaryNum = 3 + defaultSummaryQueueSize = 256 + defaultSummaryJobTimeout = 30 * time.Second +) + +// Service is the database session service. +type Service struct { + opts ServiceOpts + db *gorm.DB + eventPairChans []chan *sessionEventPair // channel for session events to persistence + summaryJobChans []chan *summaryJob // channel for summary jobs to processing + cleanupTicker *time.Ticker + cleanupDone chan struct{} + cleanupOnce sync.Once + once sync.Once +} + +type sessionEventPair struct { + key session.Key + event *event.Event +} + +// summaryJob represents a summary job to be processed asynchronously. +type summaryJob struct { + sessionKey session.Key + filterKey string + force bool + session *session.Session +} + +// NewService creates a new database session service. +func NewService(options ...ServiceOpt) (*Service, error) { + opts := ServiceOpts{ + sessionEventLimit: defaultSessionEventLimit, + autoCreateTable: true, // Default: auto create tables if not exist + asyncPersisterNum: defaultAsyncPersisterNum, + asyncSummaryNum: defaultAsyncSummaryNum, + summaryQueueSize: defaultSummaryQueueSize, + summaryJobTimeout: defaultSummaryJobTimeout, + } + for _, option := range options { + option(&opts) + } + + var db *gorm.DB + var err error + builder := storage.GetClientBuilder() + + // if instance name set, and dsn not set, use instance name to create database client + if opts.dsn == "" && opts.instanceName != "" { + builderOpts, ok := storage.GetDatabaseInstance(opts.instanceName) + if !ok { + return nil, fmt.Errorf("database instance %s not found", opts.instanceName) + } + db, err = builder(builderOpts...) + if err != nil { + return nil, fmt.Errorf("create database client from instance name failed: %w", err) + } + } else { + builderOpts := []storage.ClientBuilderOpt{ + storage.WithClientBuilderDSN(opts.dsn), + } + // Add driver type if specified + if opts.driverType != "" { + builderOpts = append(builderOpts, storage.WithDriverType(opts.driverType)) + } + // Add extra options + if len(opts.extraOptions) > 0 { + builderOpts = append(builderOpts, storage.WithExtraOptions(opts.extraOptions...)) + } + db, err = builder(builderOpts...) + if err != nil { + return nil, fmt.Errorf("create database client from dsn failed: %w", err) + } + } + + // init table: check schema, create if needed, and optionally migrate + if err := initializeTables(db, opts.autoCreateTable, opts.autoMigrate); err != nil { + return nil, fmt.Errorf("initialize tables failed: %w", err) + } + + // Set default cleanup interval if any TTL is configured + if opts.cleanupInterval <= 0 { + if opts.sessionTTL > 0 || opts.appStateTTL > 0 || opts.userStateTTL > 0 { + opts.cleanupInterval = defaultCleanupInterval + } + } + + s := &Service{ + opts: opts, + db: db, + cleanupDone: make(chan struct{}), + } + if opts.enableAsyncPersist { + s.startAsyncPersistWorker() + } + if opts.cleanupInterval > 0 { + s.startCleanupRoutine() + } + // Always start async summary workers by default. + s.startAsyncSummaryWorker() + return s, nil +} + +// CreateSession creates a new session. +func (s *Service) CreateSession( + ctx context.Context, + key session.Key, + state session.StateMap, + opts ...session.Option, +) (*session.Session, error) { + if err := key.CheckUserKey(); err != nil { + return nil, err + } + if key.SessionID == "" { + key.SessionID = uuid.New().String() + } + + now := time.Now() + + // Prepare state map for storage + stateMap := make(session.StateMap) + for k, v := range state { + stateMap[k] = v + } + + // Marshal session state + stateBytes, err := json.Marshal(stateMap) + if err != nil { + return nil, fmt.Errorf("marshal session state failed: %w", err) + } + + // Calculate expiration time + var expiresAt time.Time + if s.opts.sessionTTL > 0 { + expiresAt = now.Add(s.opts.sessionTTL) + } + + // Store session state in transaction + sessionModel := &sessionStateModel{ + AppName: key.AppName, + UserID: key.UserID, + SessionID: key.SessionID, + State: stateBytes, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: expiresAt, + } + if err := s.db.WithContext(ctx).Create(sessionModel).Error; err != nil { + return nil, fmt.Errorf("create session state failed: %w", err) + } + + // Query app states (outside transaction for better concurrency) + var appStates []appStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND (expires_at IS NULL OR expires_at > ?)", key.AppName, now). + Find(&appStates).Error; err != nil { + return nil, fmt.Errorf("query app states failed: %w", err) + } + + // Query user states (outside transaction for better concurrency) + var userStates []userStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND (expires_at IS NULL OR expires_at > ?)", + key.AppName, key.UserID, now).Find(&userStates).Error; err != nil { + return nil, fmt.Errorf("query user states failed: %w", err) + } + + // Process app/user state + appState := processAppStates(appStates) + userState := processUserStates(userStates) + + sess := &session.Session{ + ID: key.SessionID, + AppName: key.AppName, + UserID: key.UserID, + State: stateMap, + Events: []event.Event{}, + UpdatedAt: now, + CreatedAt: now, + } + return mergeState(appState, userState, sess), nil +} + +// GetSession gets a session. +func (s *Service) GetSession( + ctx context.Context, + key session.Key, + opts ...session.Option, +) (*session.Session, error) { + if err := key.CheckSessionKey(); err != nil { + return nil, err + } + opt := applyOptions(opts...) + sess, err := s.getSession(ctx, key, opt.EventNum, opt.EventTime) + if err != nil { + return nil, fmt.Errorf("database session service get session state failed: %w", err) + } + return sess, nil +} + +// ListSessions lists all sessions by user scope of session key. +func (s *Service) ListSessions( + ctx context.Context, + userKey session.UserKey, + opts ...session.Option, +) ([]*session.Session, error) { + if err := userKey.CheckUserKey(); err != nil { + return nil, err + } + opt := applyOptions(opts...) + sessList, err := s.listSessions(ctx, userKey, opt.EventNum, opt.EventTime) + if err != nil { + return nil, fmt.Errorf("database session service get session list failed: %w", err) + } + return sessList, nil +} + +// DeleteSession deletes a session. +func (s *Service) DeleteSession( + ctx context.Context, + key session.Key, + opts ...session.Option, +) error { + if err := key.CheckSessionKey(); err != nil { + return err + } + if err := s.deleteSessionState(ctx, key); err != nil { + return fmt.Errorf("database session service delete session state failed: %w", err) + } + return nil +} + +// UpdateAppState updates the state by target scope and key. +func (s *Service) UpdateAppState(ctx context.Context, appName string, state session.StateMap) error { + if appName == "" { + return session.ErrAppNameRequired + } + + now := time.Now() + var expiresAt time.Time + if s.opts.appStateTTL > 0 { + expiresAt = now.Add(s.opts.appStateTTL) + } + + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + for k, v := range state { + k = strings.TrimPrefix(k, session.StateAppPrefix) + appState := &appStateModel{ + AppName: appName, + StateKey: k, + Value: v, + UpdatedAt: now, + ExpiresAt: expiresAt, + } + // Use ON DUPLICATE KEY UPDATE + if err := tx.Where("app_name = ? AND state_key = ?", appName, k). + Assign(map[string]interface{}{ + "value": v, + "updated_at": now, + "expires_at": expiresAt, + }). + FirstOrCreate(appState).Error; err != nil { + return fmt.Errorf("update app state failed: %w", err) + } + } + return nil + }) +} + +// ListAppStates gets the app states. +func (s *Service) ListAppStates(ctx context.Context, appName string) (session.StateMap, error) { + if appName == "" { + return nil, session.ErrAppNameRequired + } + + var appStates []appStateModel + now := time.Now() + if err := s.db.WithContext(ctx). + Where("app_name = ? AND (expires_at IS NULL OR expires_at > ?)", appName, now). + Find(&appStates).Error; err != nil { + return nil, fmt.Errorf("database session service list app states failed: %w", err) + } + + return processAppStates(appStates), nil +} + +// DeleteAppState deletes the state by target scope and key. +func (s *Service) DeleteAppState(ctx context.Context, appName string, key string) error { + if appName == "" { + return session.ErrAppNameRequired + } + if key == "" { + return fmt.Errorf("state key is required") + } + + if err := s.db.WithContext(ctx). + Where("app_name = ? AND state_key = ?", appName, key). + Delete(&appStateModel{}).Error; err != nil { + return fmt.Errorf("database session service delete app state failed: %w", err) + } + return nil +} + +// UpdateUserState updates the state by target scope and key. +func (s *Service) UpdateUserState(ctx context.Context, userKey session.UserKey, state session.StateMap) error { + if err := userKey.CheckUserKey(); err != nil { + return err + } + + now := time.Now() + var expiresAt time.Time + if s.opts.userStateTTL > 0 { + expiresAt = now.Add(s.opts.userStateTTL) + } + + // Validate state keys + for k := range state { + if strings.HasPrefix(k, session.StateAppPrefix) { + return fmt.Errorf("database session service update user state failed: %s is not allowed", k) + } + if strings.HasPrefix(k, session.StateTempPrefix) { + return fmt.Errorf("database session service update user state failed: %s is not allowed", k) + } + } + + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + for k, v := range state { + k = strings.TrimPrefix(k, session.StateUserPrefix) + userState := &userStateModel{ + AppName: userKey.AppName, + UserID: userKey.UserID, + StateKey: k, + Value: v, + UpdatedAt: now, + ExpiresAt: expiresAt, + } + // Use ON DUPLICATE KEY UPDATE + if err := tx.Where("app_name = ? AND user_id = ? AND state_key = ?", userKey.AppName, userKey.UserID, k). + Assign(map[string]interface{}{ + "value": v, + "updated_at": now, + "expires_at": expiresAt, + }). + FirstOrCreate(userState).Error; err != nil { + return fmt.Errorf("update user state failed: %w", err) + } + } + return nil + }) +} + +// ListUserStates lists the state by target scope and key. +func (s *Service) ListUserStates(ctx context.Context, userKey session.UserKey) (session.StateMap, error) { + if err := userKey.CheckUserKey(); err != nil { + return nil, err + } + + var userStates []userStateModel + now := time.Now() + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND (expires_at IS NULL OR expires_at > ?)", + userKey.AppName, userKey.UserID, now). + Find(&userStates).Error; err != nil { + return nil, fmt.Errorf("database session service list user states failed: %w", err) + } + + return processUserStates(userStates), nil +} + +// DeleteUserState deletes the state by target scope and key. +func (s *Service) DeleteUserState(ctx context.Context, userKey session.UserKey, key string) error { + if err := userKey.CheckUserKey(); err != nil { + return err + } + if key == "" { + return fmt.Errorf("state key is required") + } + + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND state_key = ?", userKey.AppName, userKey.UserID, key). + Delete(&userStateModel{}).Error; err != nil { + return fmt.Errorf("database session service delete user state failed: %w", err) + } + return nil +} + +// AppendEvent appends an event to a session. +func (s *Service) AppendEvent( + ctx context.Context, + sess *session.Session, + event *event.Event, + opts ...session.Option, +) error { + key := session.Key{ + AppName: sess.AppName, + UserID: sess.UserID, + SessionID: sess.ID, + } + if err := key.CheckSessionKey(); err != nil { + return err + } + // update user session with the given event + isession.UpdateUserSession(sess, event, opts...) + + // persist event to database asynchronously + if s.opts.enableAsyncPersist { + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok && err.Error() == "send on closed channel" { + log.Errorf("database session service append event failed: %v", r) + return + } + panic(r) + } + }() + + // Hash-based distribution + hKey := fmt.Sprintf("%s:%s:%s", key.AppName, key.UserID, key.SessionID) + n := len(s.eventPairChans) + index := int(murmur3.Sum32([]byte(hKey))) % n + select { + case s.eventPairChans[index] <- &sessionEventPair{key: key, event: event}: + case <-ctx.Done(): + return ctx.Err() + } + return nil + } + + if err := s.addEvent(ctx, key, event); err != nil { + return fmt.Errorf("database session service append event failed: %w", err) + } + + return nil +} + +// Close closes the service. +func (s *Service) Close() error { + s.once.Do(func() { + // Close database connection + if s.db != nil { + if sqlDB, err := s.db.DB(); err == nil { + sqlDB.Close() + } + } + + // Stop cleanup routine + s.stopCleanupRoutine() + + // Close async persist channels + for _, ch := range s.eventPairChans { + close(ch) + } + + // Close summary channels + for _, ch := range s.summaryJobChans { + close(ch) + } + }) + + return nil +} + +func (s *Service) getSession( + ctx context.Context, + key session.Key, + limit int, + afterTime time.Time, +) (*session.Session, error) { + now := time.Now() + + var sessionModel sessionStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND session_id = ? AND (expires_at IS NULL OR expires_at > ?)", + key.AppName, key.UserID, key.SessionID, now). + First(&sessionModel).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } + return nil, fmt.Errorf("get session state failed: %w", err) + } + + var stateMap session.StateMap + if err := json.Unmarshal(sessionModel.State, &stateMap); err != nil { + return nil, fmt.Errorf("unmarshal session state failed: %w", err) + } + var appStates []appStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND (expires_at IS NULL OR expires_at > ?)", key.AppName, now). + Find(&appStates).Error; err != nil { + return nil, fmt.Errorf("query app states failed: %w", err) + } + var userStates []userStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND (expires_at IS NULL OR expires_at > ?)", + key.AppName, key.UserID, now). + Find(&userStates).Error; err != nil { + return nil, fmt.Errorf("query user states failed: %w", err) + } + + // query events + events, err := s.getEventsList(ctx, []session.Key{key}, limit, afterTime) + if err != nil { + return nil, fmt.Errorf("get events failed: %w", err) + } + + if len(events) == 0 { + events = make([][]event.Event, 1) + } + + sess := &session.Session{ + ID: key.SessionID, + AppName: key.AppName, + UserID: key.UserID, + State: stateMap, + Events: events[0], + UpdatedAt: sessionModel.UpdatedAt, + CreatedAt: sessionModel.CreatedAt, + } + + // query summaries if there are events + if len(sess.Events) > 0 { + var summaryModels []sessionSummaryModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND session_id = ? AND (expires_at IS NULL OR expires_at > ?)", + key.AppName, key.UserID, key.SessionID, now). + Find(&summaryModels).Error; err == nil && len(summaryModels) > 0 { + summaries := make(map[string]*session.Summary) + for _, sm := range summaryModels { + var summary session.Summary + if err := json.Unmarshal(sm.Summary, &summary); err == nil { + summaries[sm.FilterKey] = &summary + } + } + if len(summaries) > 0 { + sess.Summaries = summaries + } + } + } + + // refresh TTL if configured + if s.opts.sessionTTL > 0 { + expiresAt := now.Add(s.opts.sessionTTL) + s.db.WithContext(ctx).Model(&sessionStateModel{}). + Where("app_name = ? AND user_id = ? AND session_id = ?", key.AppName, key.UserID, key.SessionID). + Update("expires_at", expiresAt) + } + + // filter events to ensure they start with RoleUser + isession.EnsureEventStartWithUser(sess) + appState := processAppStates(appStates) + userState := processUserStates(userStates) + return mergeState(appState, userState, sess), nil +} + +func (s *Service) listSessions( + ctx context.Context, + userKey session.UserKey, + limit int, + afterTime time.Time, +) ([]*session.Session, error) { + now := time.Now() + // query session states + var sessionModels []sessionStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND (expires_at IS NULL OR expires_at > ?)", + userKey.AppName, userKey.UserID, now). + Find(&sessionModels).Error; err != nil { + return nil, fmt.Errorf("get session states failed: %w", err) + } + + if len(sessionModels) == 0 { + return []*session.Session{}, nil + } + + // query app states + var appStates []appStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND (expires_at IS NULL OR expires_at > ?)", userKey.AppName, now). + Find(&appStates).Error; err != nil { + return nil, fmt.Errorf("query app states failed: %w", err) + } + + // query user states + var userStates []userStateModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND (expires_at IS NULL OR expires_at > ?)", + userKey.AppName, userKey.UserID, now). + Find(&userStates).Error; err != nil { + return nil, fmt.Errorf("query user states failed: %w", err) + } + + // process app and user states + appState := processAppStates(appStates) + userState := processUserStates(userStates) + + // query events list + sessionKeys := make([]session.Key, 0, len(sessionModels)) + for _, sm := range sessionModels { + sessionKeys = append(sessionKeys, session.Key{ + AppName: userKey.AppName, + UserID: userKey.UserID, + SessionID: sm.SessionID, + }) + } + events, err := s.getEventsList(ctx, sessionKeys, limit, afterTime) + if err != nil { + return nil, fmt.Errorf("get events failed: %w", err) + } + + sessList := make([]*session.Session, 0, len(sessionModels)) + for i, sm := range sessionModels { + var stateMap session.StateMap + if err := json.Unmarshal(sm.State, &stateMap); err != nil { + return nil, fmt.Errorf("unmarshal session state failed: %w", err) + } + + sess := &session.Session{ + ID: sm.SessionID, + AppName: userKey.AppName, + UserID: userKey.UserID, + State: stateMap, + Events: events[i], + UpdatedAt: sm.UpdatedAt, + CreatedAt: sm.CreatedAt, + } + + // filter events to ensure they start with role user + isession.EnsureEventStartWithUser(sess) + sessList = append(sessList, mergeState(appState, userState, sess)) + } + + return sessList, nil +} + +func (s *Service) getEventsList( + ctx context.Context, + sessionKeys []session.Key, + limit int, + afterTime time.Time, +) ([][]event.Event, error) { + sessEventsList := make([][]event.Event, len(sessionKeys)) + + for i, key := range sessionKeys { + query := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND session_id = ? AND timestamp >= ?", + key.AppName, key.UserID, key.SessionID, afterTime). + Order("timestamp DESC") + + if limit > 0 { + query = query.Limit(limit) + } + + var eventModels []sessionEventModel + if err := query.Find(&eventModels).Error; err != nil { + return nil, fmt.Errorf("get events failed: %w", err) + } + + events := make([]event.Event, 0, len(eventModels)) + for _, em := range eventModels { + var evt event.Event + if err := json.Unmarshal(em.EventData, &evt); err != nil { + return nil, fmt.Errorf("unmarshal event failed: %w", err) + } + events = append(events, evt) + } + + // reverse events to get oldest first order + if len(events) > 1 { + for j, k := 0, len(events)-1; j < k; j, k = j+1, k-1 { + events[j], events[k] = events[k], events[j] + } + } + sessEventsList[i] = events + } + + return sessEventsList, nil +} + +func (s *Service) addEvent(ctx context.Context, key session.Key, event *event.Event) error { + now := time.Now() + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // update session state + var sessionModel sessionStateModel + if err := tx.Where("app_name = ? AND user_id = ? AND session_id = ?", + key.AppName, key.UserID, key.SessionID). + First(&sessionModel).Error; err != nil { + return fmt.Errorf("get session state failed: %w", err) + } + var stateMap session.StateMap + if err := json.Unmarshal(sessionModel.State, &stateMap); err != nil { + return fmt.Errorf("unmarshal session state failed: %w", err) + } + + // apply event state delta + isession.ApplyEventStateDeltaMap(stateMap, event) + updatedStateBytes, err := json.Marshal(stateMap) + if err != nil { + return fmt.Errorf("marshal session state failed: %w", err) + } + sessionModel.State = updatedStateBytes + sessionModel.UpdatedAt = now + if s.opts.sessionTTL > 0 { + sessionModel.ExpiresAt = now.Add(s.opts.sessionTTL) + } + + if err := tx.Save(&sessionModel).Error; err != nil { + return fmt.Errorf("update session state failed: %w", err) + } + + // Add event if it has response and is not partial + if event.Response != nil && !event.IsPartial && event.IsValidContent() { + eventBytes, err := json.Marshal(event) + if err != nil { + return fmt.Errorf("marshal event failed: %w", err) + } + + var expiresAt time.Time + if s.opts.sessionTTL > 0 { + expiresAt = now.Add(s.opts.sessionTTL) + } + + eventModel := &sessionEventModel{ + AppName: key.AppName, + UserID: key.UserID, + SessionID: key.SessionID, + EventData: eventBytes, + Timestamp: event.Timestamp, + CreatedAt: now, + ExpiresAt: expiresAt, + } + + if err := tx.Create(eventModel).Error; err != nil { + return fmt.Errorf("create event failed: %w", err) + } + + // Enforce event limit + if s.opts.sessionEventLimit > 0 { + var count int64 + if err := tx.Model(&sessionEventModel{}). + Where("app_name = ? AND user_id = ? AND session_id = ?", + key.AppName, key.UserID, key.SessionID). + Count(&count).Error; err != nil { + return fmt.Errorf("count events failed: %w", err) + } + + if count > int64(s.opts.sessionEventLimit) { + // Delete oldest events + if err := tx.Where("app_name = ? AND user_id = ? AND session_id = ?", + key.AppName, key.UserID, key.SessionID). + Order("timestamp ASC"). + Limit(int(count - int64(s.opts.sessionEventLimit))). + Delete(&sessionEventModel{}).Error; err != nil { + return fmt.Errorf("delete old events failed: %w", err) + } + } + } + } + + return nil + }) +} + +func (s *Service) deleteSessionState(ctx context.Context, key session.Key) error { + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // Delete session state + if err := tx.Where("app_name = ? AND user_id = ? AND session_id = ?", + key.AppName, key.UserID, key.SessionID). + Delete(&sessionStateModel{}).Error; err != nil { + return fmt.Errorf("delete session state failed: %w", err) + } + + // Delete session events + if err := tx.Where("app_name = ? AND user_id = ? AND session_id = ?", + key.AppName, key.UserID, key.SessionID). + Delete(&sessionEventModel{}).Error; err != nil { + return fmt.Errorf("delete session events failed: %w", err) + } + + // Delete session summaries + if err := tx.Where("app_name = ? AND user_id = ? AND session_id = ?", + key.AppName, key.UserID, key.SessionID). + Delete(&sessionSummaryModel{}).Error; err != nil { + return fmt.Errorf("delete session summaries failed: %w", err) + } + + return nil + }) +} + +func (s *Service) startAsyncPersistWorker() { + persisterNum := s.opts.asyncPersisterNum + // init event pair chan + s.eventPairChans = make([]chan *sessionEventPair, persisterNum) + for i := 0; i < persisterNum; i++ { + s.eventPairChans[i] = make(chan *sessionEventPair, defaultChanBufferSize) + } + + for _, eventPairChan := range s.eventPairChans { + go func(eventPairChan chan *sessionEventPair) { + for eventPair := range eventPairChan { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + log.Debugf("Session persistence queue monitoring: channel capacity: %d, current length: %d, session key:%s:%s:%s", + cap(eventPairChan), len(eventPairChan), eventPair.key.AppName, eventPair.key.UserID, eventPair.key.SessionID) + if err := s.addEvent(ctx, eventPair.key, eventPair.event); err != nil { + log.Errorf("database session service persistence event failed: %w", err) + } + cancel() + } + }(eventPairChan) + } +} + +// startCleanupRoutine starts the background cleanup routine. +func (s *Service) startCleanupRoutine() { + s.cleanupTicker = time.NewTicker(s.opts.cleanupInterval) + ticker := s.cleanupTicker + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.cleanupExpired() + case <-s.cleanupDone: + return + } + } + }() +} + +// stopCleanupRoutine stops the background cleanup routine. +func (s *Service) stopCleanupRoutine() { + s.cleanupOnce.Do(func() { + if s.cleanupTicker != nil { + close(s.cleanupDone) + s.cleanupTicker = nil + } + }) +} + +// cleanupExpired removes all expired sessions and states. +func (s *Service) cleanupExpired() { + ctx := context.Background() + now := time.Now() + + // Clean expired session states + s.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at <= ?", now). + Delete(&sessionStateModel{}) + + // Clean expired session events + s.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at <= ?", now). + Delete(&sessionEventModel{}) + + // Clean expired session summaries + s.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at <= ?", now). + Delete(&sessionSummaryModel{}) + + // Clean expired app states + s.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at <= ?", now). + Delete(&appStateModel{}) + + // Clean expired user states + s.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at <= ?", now). + Delete(&userStateModel{}) +} + +func processAppStates(appStates []appStateModel) session.StateMap { + stateMap := make(session.StateMap) + for _, as := range appStates { + stateMap[as.StateKey] = as.Value + } + return stateMap +} + +func processUserStates(userStates []userStateModel) session.StateMap { + stateMap := make(session.StateMap) + for _, us := range userStates { + stateMap[us.StateKey] = us.Value + } + return stateMap +} + +func mergeState(appState, userState session.StateMap, sess *session.Session) *session.Session { + for k, v := range appState { + sess.State[session.StateAppPrefix+k] = v + } + for k, v := range userState { + sess.State[session.StateUserPrefix+k] = v + } + return sess +} + +func applyOptions(opts ...session.Option) *session.Options { + opt := &session.Options{} + for _, o := range opts { + o(opt) + } + return opt +} diff --git a/session/database/service_test.go b/session/database/service_test.go new file mode 100644 index 000000000..c633bc9e7 --- /dev/null +++ b/session/database/service_test.go @@ -0,0 +1,318 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +package database + +import ( + "context" + "testing" + "time" + + "trpc.group/trpc-go/trpc-agent-go/event" + "trpc.group/trpc-go/trpc-agent-go/model" + "trpc.group/trpc-go/trpc-agent-go/session" +) + +// TestNewService tests the service creation. +func TestNewService(t *testing.T) { + // This test requires a running Database instance + // Skip if DSN is not provided + dsn := "root:password@tcp(127.0.0.1:3306)/test_session?charset=utf8mb4&parseTime=True&loc=Local" + + service, err := NewService( + WithDatabaseDSN(dsn), + WithAutoCreateTable(true), + WithSessionEventLimit(100), + ) + if err != nil { + t.Skipf("Skip test due to Database connection error: %v", err) + return + } + defer service.Close() + + if service == nil { + t.Fatal("Expected service to be created") + } +} + +// TestCreateAndGetSession tests creating and retrieving a session. +func TestCreateAndGetSession(t *testing.T) { + dsn := "root:password@tcp(127.0.0.1:3306)/test_session?charset=utf8mb4&parseTime=True&loc=Local" + + service, err := NewService( + WithDatabaseDSN(dsn), + WithAutoCreateTable(true), + ) + if err != nil { + t.Skipf("Skip test due to Database connection error: %v", err) + return + } + defer service.Close() + + ctx := context.Background() + key := session.Key{ + AppName: "testapp", + UserID: "user1", + } + + // Create session + state := session.StateMap{ + "key1": []byte("value1"), + } + sess, err := service.CreateSession(ctx, key, state) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + if sess.ID == "" { + t.Fatal("Expected session ID to be generated") + } + if sess.AppName != key.AppName { + t.Errorf("Expected AppName %s, got %s", key.AppName, sess.AppName) + } + if sess.UserID != key.UserID { + t.Errorf("Expected UserID %s, got %s", key.UserID, sess.UserID) + } + + // Get session + key.SessionID = sess.ID + retrievedSess, err := service.GetSession(ctx, key) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + if retrievedSess.ID != sess.ID { + t.Errorf("Expected session ID %s, got %s", sess.ID, retrievedSess.ID) + } + + // Clean up + service.DeleteSession(ctx, key) +} + +// TestAppendEvent tests appending an event to a session. +func TestAppendEvent(t *testing.T) { + dsn := "root:password@tcp(127.0.0.1:3306)/test_session?charset=utf8mb4&parseTime=True&loc=Local" + + service, err := NewService( + WithDatabaseDSN(dsn), + WithAutoCreateTable(true), + ) + if err != nil { + t.Skipf("Skip test due to Database connection error: %v", err) + return + } + defer service.Close() + + ctx := context.Background() + key := session.Key{ + AppName: "testapp", + UserID: "user1", + } + + // Create session + sess, err := service.CreateSession(ctx, key, nil) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Append event + evt := event.New("test-invocation", "test-author") + evt.Timestamp = time.Now() + evt.Response = &model.Response{ + Done: true, + Choices: []model.Choice{ + { + Message: model.Message{ + Role: model.RoleUser, + Content: "Hello, this is a test message", + }, + }, + }, + } + + key.SessionID = sess.ID + err = service.AppendEvent(ctx, sess, evt) + if err != nil { + t.Fatalf("Failed to append event: %v", err) + } + + // Verify event was stored + retrievedSess, err := service.GetSession(ctx, key) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + + if len(retrievedSess.Events) != 1 { + t.Errorf("Expected 1 event, got %d", len(retrievedSess.Events)) + } + + // Clean up + service.DeleteSession(ctx, key) +} + +// TestAppState tests app state operations. +func TestAppState(t *testing.T) { + dsn := "root:password@tcp(127.0.0.1:3306)/test_session?charset=utf8mb4&parseTime=True&loc=Local" + + service, err := NewService( + WithDatabaseDSN(dsn), + WithAutoCreateTable(true), + ) + if err != nil { + t.Skipf("Skip test due to Database connection error: %v", err) + return + } + defer service.Close() + + ctx := context.Background() + appName := "testapp" + + // Update app state + state := session.StateMap{ + "config": []byte("value1"), + } + err = service.UpdateAppState(ctx, appName, state) + if err != nil { + t.Fatalf("Failed to update app state: %v", err) + } + + // List app states + retrievedState, err := service.ListAppStates(ctx, appName) + if err != nil { + t.Fatalf("Failed to list app states: %v", err) + } + + if string(retrievedState["config"]) != "value1" { + t.Errorf("Expected config value1, got %s", retrievedState["config"]) + } + + // Delete app state + err = service.DeleteAppState(ctx, appName, "config") + if err != nil { + t.Fatalf("Failed to delete app state: %v", err) + } + + // Verify deletion + retrievedState, err = service.ListAppStates(ctx, appName) + if err != nil { + t.Fatalf("Failed to list app states: %v", err) + } + + if _, exists := retrievedState["config"]; exists { + t.Error("Expected config to be deleted") + } +} + +// TestUserState tests user state operations. +func TestUserState(t *testing.T) { + dsn := "root:password@tcp(127.0.0.1:3306)/test_session?charset=utf8mb4&parseTime=True&loc=Local" + + service, err := NewService( + WithDatabaseDSN(dsn), + WithAutoCreateTable(true), + ) + if err != nil { + t.Skipf("Skip test due to Database connection error: %v", err) + return + } + defer service.Close() + + ctx := context.Background() + userKey := session.UserKey{ + AppName: "testapp", + UserID: "user1", + } + + // Update user state + state := session.StateMap{ + "preference": []byte("dark_mode"), + } + err = service.UpdateUserState(ctx, userKey, state) + if err != nil { + t.Fatalf("Failed to update user state: %v", err) + } + + // List user states + retrievedState, err := service.ListUserStates(ctx, userKey) + if err != nil { + t.Fatalf("Failed to list user states: %v", err) + } + + if string(retrievedState["preference"]) != "dark_mode" { + t.Errorf("Expected preference dark_mode, got %s", retrievedState["preference"]) + } + + // Delete user state + err = service.DeleteUserState(ctx, userKey, "preference") + if err != nil { + t.Fatalf("Failed to delete user state: %v", err) + } + + // Verify deletion + retrievedState, err = service.ListUserStates(ctx, userKey) + if err != nil { + t.Fatalf("Failed to list user states: %v", err) + } + + if _, exists := retrievedState["preference"]; exists { + t.Error("Expected preference to be deleted") + } +} + +// TestSessionTTL tests session TTL functionality. +func TestSessionTTL(t *testing.T) { + dsn := "root:password@tcp(127.0.0.1:3306)/test_session?charset=utf8mb4&parseTime=True&loc=Local" + + service, err := NewService( + WithDatabaseDSN(dsn), + WithAutoCreateTable(true), + WithSessionTTL(2*time.Second), + WithCleanupInterval(1*time.Second), + ) + if err != nil { + t.Skipf("Skip test due to Database connection error: %v", err) + return + } + defer service.Close() + + ctx := context.Background() + key := session.Key{ + AppName: "testapp", + UserID: "user1", + } + + // Create session + sess, err := service.CreateSession(ctx, key, nil) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + key.SessionID = sess.ID + + // Session should exist + retrievedSess, err := service.GetSession(ctx, key) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if retrievedSess == nil { + t.Fatal("Expected session to exist") + } + + // Wait for TTL to expire and cleanup to run + time.Sleep(4 * time.Second) + + // Session should be deleted + retrievedSess, err = service.GetSession(ctx, key) + if err != nil { + t.Fatalf("Failed to get session: %v", err) + } + if retrievedSess != nil { + t.Error("Expected session to be deleted after TTL expiration") + } +} diff --git a/session/database/summary.go b/session/database/summary.go new file mode 100644 index 000000000..641e4a4d7 --- /dev/null +++ b/session/database/summary.go @@ -0,0 +1,333 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +package database + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/spaolacci/murmur3" + "gorm.io/gorm" + "trpc.group/trpc-go/trpc-agent-go/log" + "trpc.group/trpc-go/trpc-agent-go/session" + isession "trpc.group/trpc-go/trpc-agent-go/session/internal/session" +) + +// CreateSessionSummary generates a summary for the session (async-ready). +// It performs per-filterKey delta summarization; when filterKey=="", it means full-session summary. +func (s *Service) CreateSessionSummary(ctx context.Context, sess *session.Session, filterKey string, force bool) error { + if s.opts.summarizer == nil { + return nil + } + + if sess == nil { + return errors.New("nil session") + } + key := session.Key{AppName: sess.AppName, UserID: sess.UserID, SessionID: sess.ID} + if err := key.CheckSessionKey(); err != nil { + return fmt.Errorf("check session key failed: %w", err) + } + + updated, err := isession.SummarizeSession(ctx, s.opts.summarizer, sess, filterKey, force) + if err != nil { + return fmt.Errorf("summarize and persist failed: %w", err) + } + if !updated { + return nil + } + + // Persist only the updated filterKey summary with atomic set-if-newer to avoid late-write override. + sess.SummariesMu.RLock() + sum := sess.Summaries[filterKey] + sess.SummariesMu.RUnlock() + + payload, err := json.Marshal(sum) + if err != nil { + return fmt.Errorf("marshal summary failed: %w", err) + } + + // Store summary with atomic set-if-newer logic + now := time.Now() + var expiresAt time.Time + if s.opts.sessionTTL > 0 { + expiresAt = now.Add(s.opts.sessionTTL) + } + + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // Check if summary exists + var existingSummary sessionSummaryModel + err := tx.Where("app_name = ? AND user_id = ? AND session_id = ? AND filter_key = ?", + key.AppName, key.UserID, key.SessionID, filterKey). + First(&existingSummary).Error + + if err == gorm.ErrRecordNotFound { + // Create new summary + summaryModel := &sessionSummaryModel{ + AppName: key.AppName, + UserID: key.UserID, + SessionID: key.SessionID, + FilterKey: filterKey, + Summary: payload, + UpdatedAt: sum.UpdatedAt, + ExpiresAt: expiresAt, + } + return tx.Create(summaryModel).Error + } + + if err != nil { + return fmt.Errorf("query existing summary failed: %w", err) + } + + // Compare timestamps to decide whether to update + if !existingSummary.UpdatedAt.Before(sum.UpdatedAt) { + // Existing summary is newer or equal, skip update + return nil + } + + // Update existing summary + return tx.Model(&existingSummary).Updates(map[string]interface{}{ + "summary": payload, + "updated_at": sum.UpdatedAt, + "expires_at": expiresAt, + }).Error + }) +} + +// GetSessionSummaryText returns the latest summary text from the session state if present. +func (s *Service) GetSessionSummaryText(ctx context.Context, sess *session.Session) (string, bool) { + if sess == nil { + return "", false + } + key := session.Key{AppName: sess.AppName, UserID: sess.UserID, SessionID: sess.ID} + if err := key.CheckSessionKey(); err != nil { + return "", false + } + + // Prefer local in-memory session summaries when available. + if len(sess.Summaries) > 0 { + if text, ok := pickSummaryText(sess.Summaries); ok { + return text, true + } + } + + // Query from database + now := time.Now() + var summaryModels []sessionSummaryModel + if err := s.db.WithContext(ctx). + Where("app_name = ? AND user_id = ? AND session_id = ? AND (expires_at IS NULL OR expires_at > ?)", + key.AppName, key.UserID, key.SessionID, now). + Find(&summaryModels).Error; err == nil && len(summaryModels) > 0 { + summaries := make(map[string]*session.Summary) + for _, sm := range summaryModels { + var summary session.Summary + if err := json.Unmarshal(sm.Summary, &summary); err == nil { + summaries[sm.FilterKey] = &summary + } + } + if len(summaries) > 0 { + return pickSummaryText(summaries) + } + } + + return "", false +} + +// pickSummaryText picks a non-empty summary string with preference for the +// all-contents key "" (empty filterKey). No special handling for "root". +func pickSummaryText(summaries map[string]*session.Summary) (string, bool) { + if summaries == nil { + return "", false + } + // Prefer full-summary stored under empty filterKey. + if sum, ok := summaries[session.SummaryFilterKeyAllContents]; ok && sum != nil && sum.Summary != "" { + return sum.Summary, true + } + for _, s := range summaries { + if s != nil && s.Summary != "" { + return s.Summary, true + } + } + return "", false +} + +// EnqueueSummaryJob enqueues a summary job for asynchronous processing. +func (s *Service) EnqueueSummaryJob(ctx context.Context, sess *session.Session, filterKey string, force bool) error { + if s.opts.summarizer == nil { + return nil + } + + if sess == nil { + return errors.New("nil session") + } + key := session.Key{AppName: sess.AppName, UserID: sess.UserID, SessionID: sess.ID} + if err := key.CheckSessionKey(); err != nil { + return fmt.Errorf("check session key failed: %w", err) + } + + // If async workers are not initialized, fall back to synchronous processing. + if len(s.summaryJobChans) == 0 { + return s.CreateSessionSummary(ctx, sess, filterKey, force) + } + + // Create summary job. + job := &summaryJob{ + sessionKey: key, + filterKey: filterKey, + force: force, + session: sess, + } + + // Try to enqueue the job asynchronously. + if s.tryEnqueueJob(ctx, job) { + return nil // Successfully enqueued. + } + + // If async enqueue failed, fall back to synchronous processing. + return s.CreateSessionSummary(ctx, sess, filterKey, force) +} + +// tryEnqueueJob attempts to enqueue a summary job to the appropriate channel. +// Returns true if successful, false if the job should be processed synchronously. +func (s *Service) tryEnqueueJob(ctx context.Context, job *summaryJob) bool { + // Select a channel using hash distribution. + keyStr := fmt.Sprintf("%s:%s:%s", job.sessionKey.AppName, job.sessionKey.UserID, job.sessionKey.SessionID) + index := int(murmur3.Sum32([]byte(keyStr))) % len(s.summaryJobChans) + + // Use a defer-recover pattern to handle potential panic from sending to closed channel. + defer func() { + if r := recover(); r != nil { + log.Warnf("summary job channel may be closed, falling back to synchronous processing: %v", r) + } + }() + + select { + case s.summaryJobChans[index] <- job: + return true // Successfully enqueued. + case <-ctx.Done(): + log.Debugf("summary job channel context cancelled, falling back to synchronous processing, error: %v", ctx.Err()) + return false // Context cancelled. + default: + // Queue is full, fall back to synchronous processing. + log.Warnf("summary job queue is full, falling back to synchronous processing") + return false + } +} + +func (s *Service) startAsyncSummaryWorker() { + summaryNum := s.opts.asyncSummaryNum + // Init summary job chan. + s.summaryJobChans = make([]chan *summaryJob, summaryNum) + for i := 0; i < summaryNum; i++ { + s.summaryJobChans[i] = make(chan *summaryJob, s.opts.summaryQueueSize) + } + + for _, summaryJobChan := range s.summaryJobChans { + go func(summaryJobChan chan *summaryJob) { + for job := range summaryJobChan { + s.processSummaryJob(job) + // After branch summary, cascade a full-session summary by + // reusing the same processing path to keep logic unified. + if job.filterKey != session.SummaryFilterKeyAllContents { + job.filterKey = session.SummaryFilterKeyAllContents + s.processSummaryJob(job) + } + } + }(summaryJobChan) + } +} + +func (s *Service) processSummaryJob(job *summaryJob) { + defer func() { + if r := recover(); r != nil { + log.Errorf("panic in summary worker: %v", r) + } + }() + + // Create a fresh context with timeout for this job. + ctx := context.Background() + if s.opts.summaryJobTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, s.opts.summaryJobTimeout) + defer cancel() + } + + // Perform the actual summary generation for the requested filterKey. + updated, err := isession.SummarizeSession(ctx, s.opts.summarizer, job.session, job.filterKey, job.force) + if err != nil { + log.Errorf("summary worker failed to generate summary: %v", err) + return + } + if !updated { + return + } + + // Persist to database. + job.session.SummariesMu.RLock() + sum := job.session.Summaries[job.filterKey] + job.session.SummariesMu.RUnlock() + + payload, err := json.Marshal(sum) + if err != nil { + log.Errorf("summary worker failed to marshal summary: %v", err) + return + } + + now := time.Now() + var expiresAt time.Time + if s.opts.sessionTTL > 0 { + expiresAt = now.Add(s.opts.sessionTTL) + } + + err = s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // Check if summary exists + var existingSummary sessionSummaryModel + err := tx.Where("app_name = ? AND user_id = ? AND session_id = ? AND filter_key = ?", + job.sessionKey.AppName, job.sessionKey.UserID, job.sessionKey.SessionID, job.filterKey). + First(&existingSummary).Error + + if err == gorm.ErrRecordNotFound { + // Create new summary + summaryModel := &sessionSummaryModel{ + AppName: job.sessionKey.AppName, + UserID: job.sessionKey.UserID, + SessionID: job.sessionKey.SessionID, + FilterKey: job.filterKey, + Summary: payload, + UpdatedAt: sum.UpdatedAt, + ExpiresAt: expiresAt, + } + return tx.Create(summaryModel).Error + } + + if err != nil { + return fmt.Errorf("query existing summary failed: %w", err) + } + + // Compare timestamps to decide whether to update + if !existingSummary.UpdatedAt.Before(sum.UpdatedAt) { + // Existing summary is newer or equal, skip update + return nil + } + + // Update existing summary + return tx.Model(&existingSummary).Updates(map[string]interface{}{ + "summary": payload, + "updated_at": sum.UpdatedAt, + "expires_at": expiresAt, + }).Error + }) + + if err != nil { + log.Errorf("summary worker failed to store summary: %v", err) + } +} diff --git a/storage/database/database.go b/storage/database/database.go new file mode 100644 index 000000000..896e0d6f5 --- /dev/null +++ b/storage/database/database.go @@ -0,0 +1,204 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +// Package database provides the Database instance info management. +package database + +import ( + "errors" + "fmt" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func init() { + databaseRegistry = make(map[string][]ClientBuilderOpt) +} + +var databaseRegistry map[string][]ClientBuilderOpt + +type clientBuilder func(builderOpts ...ClientBuilderOpt) (*gorm.DB, error) + +var globalBuilder clientBuilder = defaultClientBuilder + +// SetClientBuilder sets the Database client builder. +func SetClientBuilder(builder clientBuilder) { + globalBuilder = builder +} + +// GetClientBuilder gets the Database client builder. +func GetClientBuilder() clientBuilder { + return globalBuilder +} + +// defaultClientBuilder is the default Database client builder. +func defaultClientBuilder(builderOpts ...ClientBuilderOpt) (*gorm.DB, error) { + o := &ClientBuilderOpts{} + for _, opt := range builderOpts { + opt(o) + } + + if o.DSN == "" { + return nil, errors.New("database: DSN is empty") + } + + // Default to MySQL if driver type not specified + if o.DriverType == "" { + o.DriverType = DriverMySQL + } + + // Set default GORM config if not provided + if o.Config == nil { + o.Config = &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + } + } + + // Select appropriate driver based on driver type + var dialector gorm.Dialector + switch o.DriverType { + case DriverMySQL: + dialector = mysql.Open(o.DSN) + case DriverPostgreSQL: + // Note: requires "gorm.io/driver/postgres" to be imported + return nil, fmt.Errorf("database: PostgreSQL driver not imported, please use custom builder with postgres.Open()") + case DriverSQLite: + // Note: requires "gorm.io/driver/sqlite" to be imported + return nil, fmt.Errorf("database: SQLite driver not imported, please use custom builder with sqlite.Open()") + default: + return nil, fmt.Errorf("database: unsupported driver type: %s", o.DriverType) + } + + db, err := gorm.Open(dialector, o.Config) + if err != nil { + return nil, fmt.Errorf("database: open connection: %w", err) + } + + // Get underlying sql.DB to configure connection pool + sqlDB, err := db.DB() + if err != nil { + return nil, fmt.Errorf("database: get underlying sql.DB: %w", err) + } + + // Set connection pool parameters + if o.MaxIdleConns > 0 { + sqlDB.SetMaxIdleConns(o.MaxIdleConns) + } + if o.MaxOpenConns > 0 { + sqlDB.SetMaxOpenConns(o.MaxOpenConns) + } + if o.ConnMaxLifetime > 0 { + sqlDB.SetConnMaxLifetime(o.ConnMaxLifetime) + } + if o.ConnMaxIdleTime > 0 { + sqlDB.SetConnMaxIdleTime(o.ConnMaxIdleTime) + } + + return db, nil +} + +// ClientBuilderOpt is the option for the Database client. +type ClientBuilderOpt func(*ClientBuilderOpts) + +// DriverType represents the database driver type +type DriverType string + +const ( + // DriverMySQL represents MySQL database + DriverMySQL DriverType = "mysql" + // DriverPostgreSQL represents PostgreSQL database + DriverPostgreSQL DriverType = "postgres" + // DriverSQLite represents SQLite database (for testing) + DriverSQLite DriverType = "sqlite" +) + +// ClientBuilderOpts is the options for the Database client. +type ClientBuilderOpts struct { + // DSN is the Database data source name. + // MySQL format: username:password@tcp(host:port)/dbname?charset=utf8mb4&parseTime=True&loc=Local + // PostgreSQL format: host=localhost user=postgres password=pass dbname=db port=5432 sslmode=disable + // SQLite format: file:test.db?cache=shared or :memory: + DSN string + + // DriverType specifies which database driver to use (mysql, postgres, sqlite) + // If not specified, defaults to mysql for backward compatibility + DriverType DriverType + + // Config is the GORM configuration. + Config *gorm.Config + + // Connection pool settings + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration + ConnMaxIdleTime time.Duration + + // ExtraOptions is the extra options for the Database client. + ExtraOptions []any +} + +// WithClientBuilderDSN sets the Database DSN for clientBuilder. +func WithClientBuilderDSN(dsn string) ClientBuilderOpt { + return func(opts *ClientBuilderOpts) { + opts.DSN = dsn + } +} + +// WithDriverType sets the database driver type (mysql, postgres, sqlite). +// If not set, defaults to mysql for backward compatibility. +func WithDriverType(driverType DriverType) ClientBuilderOpt { + return func(opts *ClientBuilderOpts) { + opts.DriverType = driverType + } +} + +// WithGormConfig sets the GORM configuration. +func WithGormConfig(config *gorm.Config) ClientBuilderOpt { + return func(opts *ClientBuilderOpts) { + opts.Config = config + } +} + +// WithMaxIdleConns sets the maximum number of idle connections in the pool. +func WithMaxIdleConns(n int) ClientBuilderOpt { + return func(opts *ClientBuilderOpts) { + opts.MaxIdleConns = n + } +} + +// WithMaxOpenConns sets the maximum number of open connections to the database. +func WithMaxOpenConns(n int) ClientBuilderOpt { + return func(opts *ClientBuilderOpts) { + opts.MaxOpenConns = n + } +} + +// WithExtraOptions sets the Database client extra options for clientBuilder. +// this option mainly used for the customized Database client builder, it will be passed to the builder. +func WithExtraOptions(extraOptions ...any) ClientBuilderOpt { + return func(opts *ClientBuilderOpts) { + opts.ExtraOptions = append(opts.ExtraOptions, extraOptions...) + } +} + +// RegisterDatabaseInstance registers a database instance options. +func RegisterDatabaseInstance(name string, opts ...ClientBuilderOpt) { + databaseRegistry[name] = append(databaseRegistry[name], opts...) +} + +// GetDatabaseInstance gets the database instance options. +func GetDatabaseInstance(name string) ([]ClientBuilderOpt, bool) { + if _, ok := databaseRegistry[name]; !ok { + return nil, false + } + return databaseRegistry[name], true +} diff --git a/storage/database/go.mod b/storage/database/go.mod new file mode 100644 index 000000000..ae5595983 --- /dev/null +++ b/storage/database/go.mod @@ -0,0 +1,15 @@ +module trpc.group/trpc-go/trpc-agent-go/storage/database + +go 1.22 + +require ( + gorm.io/driver/mysql v1.5.7 + gorm.io/gorm v1.25.12 +) + +require ( + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + golang.org/x/text v0.14.0 // indirect +) diff --git a/storage/database/go.sum b/storage/database/go.sum new file mode 100644 index 000000000..7559d885c --- /dev/null +++ b/storage/database/go.sum @@ -0,0 +1,13 @@ +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= +gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=