diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index ff02e13eb..0b6832e20 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -68,6 +68,11 @@ jobs: ports: - 6379 + valkey: + image: valkey/valkey-bundle:unstable + ports: + - 6380 + rustfs: image: rustfs/rustfs:alpha ports: @@ -133,6 +138,8 @@ jobs: CLICKHOUSE_URI: clickhouse://localhost:${{ job.services.clickhouse.ports[8123] }}/ # Redis REDIS_URI: redis://localhost:${{ job.services.redis.ports[6379] }}/ + # Valkey + VALKEY_URI: valkey://localhost:${{ job.services.valkey.ports[6380] }}/ # S3 S3_ENDPOINT: localhost:${{ job.services.rustfs.ports[9000] }} S3_ACCESS_KEY_ID: rustfsadmin @@ -181,7 +188,7 @@ jobs: go-version-file: ./go.mod - name: Test - run: go test -timeout 20m -v ./... -skip "TestPostgres|TestMySQL|TestMongo|TestRedis|TestClickHouse|TestMilvus|TestQdrant|TestWeaviate" + run: go test -timeout 20m -v ./... -skip "TestPostgres|TestMySQL|TestMongo|TestRedis|TestClickHouse|TestMilvus|TestQdrant|TestWeaviate|TestValkey" unit_test_windows: strategy: @@ -215,7 +222,7 @@ jobs: go-version-file: ./go.mod - name: Test - run: go test -timeout 20m -v ./... -skip "TestPostgres|TestMySQL|TestMongo|TestRedis|TestClickHouse|TestMilvus|TestQdrant|TestWeaviate" + run: go test -timeout 20m -v ./... 
-skip "TestPostgres|TestMySQL|TestMongo|TestRedis|TestClickHouse|TestMilvus|TestQdrant|TestWeaviate|TestValkey" integrate_test: name: integrate tests diff --git a/config/config.go b/config/config.go index 8a2ea6512..e148398cb 100644 --- a/config/config.go +++ b/config/config.go @@ -831,6 +831,10 @@ func (config *Config) Validate() error { storage.RedissPrefix, storage.RedisClusterPrefix, storage.RedissClusterPrefix, + storage.ValkeyPrefix, + storage.ValkeysPrefix, + storage.ValkeyClusterPrefix, + storage.ValkeysClusterPrefix, storage.MongoPrefix, storage.MongoSrvPrefix, storage.MySQLPrefix, diff --git a/config/config.toml b/config/config.toml index 37bf2ca3d..4201b71a6 100644 --- a/config/config.toml +++ b/config/config.toml @@ -1,10 +1,14 @@ [database] -# The database for caching, support Redis, MySQL, Postgres and MongoDB: +# The database for caching, support Redis, Valkey, MySQL, Postgres and MongoDB: # redis://:@:/ # rediss://:@:/ # redis+cluster://:@:[?addr=:&addr=:] # rediss+cluster://:@:[?addr=:&addr=:] +# valkey://:@:/ +# valkeys://:@:/ +# valkey+cluster://:@:[?addr=:&addr=:] +# valkeys+cluster://:@:[?addr=:&addr=:] # mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] # postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full # postgresql://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full diff --git a/config/config_test.go b/config/config_test.go index 6fa626298..1a8ea00f1 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -503,4 +503,20 @@ func (s *ValidateTestSuite) TestCacheStore() { // Test that rediss+cluster:// prefix is accepted for cache_store s.Database.CacheStore = "rediss+cluster://:password@192.168.1.11:6379?addr=192.168.0.5:6379" s.NoError(s.Validate()) + + // Test that valkey:// prefix is accepted for cache_store + s.Database.CacheStore = "valkey://localhost:6379/0" + s.NoError(s.Validate()) + + // Test that valkeys:// prefix is accepted for cache_store + s.Database.CacheStore = 
"valkeys://localhost:6379/0" + s.NoError(s.Validate()) + + // Test that valkey+cluster:// prefix is accepted for cache_store + s.Database.CacheStore = "valkey+cluster://:password@192.168.1.11:6379?addr=192.168.0.5:6379&addr=192.168.0.7:6379" + s.NoError(s.Validate()) + + // Test that valkeys+cluster:// prefix is accepted for cache_store + s.Database.CacheStore = "valkeys+cluster://:password@192.168.1.11:6379?addr=192.168.0.5:6379" + s.NoError(s.Validate()) } diff --git a/go.mod b/go.mod index e302eeefd..c9066db8e 100644 --- a/go.mod +++ b/go.mod @@ -63,6 +63,7 @@ require ( github.com/stretchr/testify v1.11.1 github.com/swaggest/swgui v1.8.5 github.com/tiktoken-go/tokenizer v0.7.0 + github.com/valkey-io/valkey-go v1.0.74 github.com/weaviate/weaviate v1.27.0 github.com/weaviate/weaviate-go-client/v4 v4.16.1 github.com/yuin/goldmark v1.7.16 diff --git a/go.sum b/go.sum index 47ea55925..0480e2fbc 100644 --- a/go.sum +++ b/go.sum @@ -740,8 +740,8 @@ github.com/onsi/ginkgo v1.10.3/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+ github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= -github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y= -github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= +github.com/onsi/gomega v1.38.3 h1:eTX+W6dobAYfFeGC2PV6RwXRu/MyT+cQguijutvkpSM= +github.com/onsi/gomega v1.38.3/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/openzipkin/zipkin-go v0.4.3 h1:9EGwpqkgnwdEIJ+Od7QVSEIH+ocmm5nPat0G7sjsSdg= github.com/openzipkin/zipkin-go v0.4.3/go.mod h1:M9wCJZFWCo2RiY+o1eBCEMe0Dp2S5LDHcMZmk3RmK7c= @@ -921,6 +921,8 @@ github.com/ugorji/go v1.1.7/go.mod 
h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVM github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= +github.com/valkey-io/valkey-go v1.0.74 h1:NqtBHzjybz+is+c71hsyZP7hoE5lwCHQX026me0Vb08= +github.com/valkey-io/valkey-go v1.0.74/go.mod h1:VGhZ6fs68Qrn2+OhH+6waZH27bjpgQOiLyUQyXuYK5k= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.6.0/go.mod h1:FstJa9V+Pj9vQ7OJie2qMHdwemEDaDiSdBnvPM1Su9w= github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= diff --git a/storage/cache/database_test.go b/storage/cache/database_test.go index 65c67c5d6..ded932bf9 100644 --- a/storage/cache/database_test.go +++ b/storage/cache/database_test.go @@ -601,6 +601,21 @@ func benchmark(b *testing.B, database Database) { b.Run("UpdateScores", func(b *testing.B) { benchmarkUpdateDocuments(b, database) }) + b.Run("TSIngestSingle", func(b *testing.B) { + benchmarkTimeSeriesIngestSingle(b, database) + }) + b.Run("TSIngestBatch", func(b *testing.B) { + benchmarkTimeSeriesIngestBatch(b, database) + }) + b.Run("TSQuerySmallRange", func(b *testing.B) { + benchmarkTimeSeriesQuerySmallRange(b, database) + }) + b.Run("TSQueryLargeRange", func(b *testing.B) { + benchmarkTimeSeriesQueryLargeRange(b, database) + }) + b.Run("TSQueryWideRange", func(b *testing.B) { + benchmarkTimeSeriesQueryWideRange(b, database) + }) } func benchmarkAddDocuments(b *testing.B, database Database) { @@ -658,3 +673,137 @@ func benchmarkUpdateDocuments(b *testing.B, database Database) { assert.NoError(b, err) } } + +// benchmarkTimeSeriesIngestSingle measures single-point ingestion throughput. 
+func benchmarkTimeSeriesIngestSingle(b *testing.B, database Database) { + ctx := b.Context() + base := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := database.AddTimeSeriesPoints(ctx, []TimeSeriesPoint{ + {Name: "bench_single", Value: float64(i), Timestamp: base.Add(time.Duration(i) * time.Second)}, + }) + assert.NoError(b, err) + } +} + +// benchmarkTimeSeriesIngestBatch measures batch ingestion throughput (100 points per call). +func benchmarkTimeSeriesIngestBatch(b *testing.B, database Database) { + ctx := b.Context() + base := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + const batchSize = 100 + b.ResetTimer() + for i := 0; i < b.N; i++ { + points := make([]TimeSeriesPoint, batchSize) + for j := range batchSize { + points[j] = TimeSeriesPoint{ + Name: "bench_batch", + Value: float64(i*batchSize + j), + Timestamp: base.Add(time.Duration(i*batchSize+j) * time.Second), + } + } + err := database.AddTimeSeriesPoints(ctx, points) + assert.NoError(b, err) + } +} + +// benchmarkTimeSeriesQuerySmallRange measures query latency for a small time range (60s window, 1s buckets). +// Pre-loads 10,000 points at 1-second intervals. 
+func benchmarkTimeSeriesQuerySmallRange(b *testing.B, database Database) { + ctx := b.Context() + base := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + const totalPoints = 10000 + // pre-load data in batches of 500 + for offset := 0; offset < totalPoints; offset += 500 { + end := offset + 500 + if end > totalPoints { + end = totalPoints + } + points := make([]TimeSeriesPoint, 0, end-offset) + for i := offset; i < end; i++ { + points = append(points, TimeSeriesPoint{ + Name: "bench_query_small", + Value: float64(i), + Timestamp: base.Add(time.Duration(i) * time.Second), + }) + } + err := database.AddTimeSeriesPoints(ctx, points) + assert.NoError(b, err) + } + // query a 60-second window in the middle, 1-second buckets (~60 points returned) + queryBegin := base.Add(5000 * time.Second) + queryEnd := base.Add(5060 * time.Second) + b.ResetTimer() + for i := 0; i < b.N; i++ { + points, err := database.GetTimeSeriesPoints(ctx, "bench_query_small", queryBegin, queryEnd, time.Second) + assert.NoError(b, err) + assert.NotEmpty(b, points) + } +} + +// benchmarkTimeSeriesQueryLargeRange measures query latency with heavy aggregation. +// Pre-loads 10,000 points at 1-second intervals, queries full range with 60-second buckets (~167 buckets). 
+func benchmarkTimeSeriesQueryLargeRange(b *testing.B, database Database) { + ctx := b.Context() + base := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + const totalPoints = 10000 + // pre-load data in batches of 500 + for offset := 0; offset < totalPoints; offset += 500 { + end := offset + 500 + if end > totalPoints { + end = totalPoints + } + points := make([]TimeSeriesPoint, 0, end-offset) + for i := offset; i < end; i++ { + points = append(points, TimeSeriesPoint{ + Name: "bench_query_large", + Value: float64(i), + Timestamp: base.Add(time.Duration(i) * time.Second), + }) + } + err := database.AddTimeSeriesPoints(ctx, points) + assert.NoError(b, err) + } + // query full range with 60-second buckets + queryBegin := base + queryEnd := base.Add(time.Duration(totalPoints) * time.Second) + b.ResetTimer() + for i := 0; i < b.N; i++ { + points, err := database.GetTimeSeriesPoints(ctx, "bench_query_large", queryBegin, queryEnd, 60*time.Second) + assert.NoError(b, err) + assert.NotEmpty(b, points) + } +} + +// benchmarkTimeSeriesQueryWideRange measures worst-case query: 100K points aggregated into ~28 hourly buckets. 
+func benchmarkTimeSeriesQueryWideRange(b *testing.B, database Database) { + ctx := b.Context() + base := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC) + const totalPoints = 100000 + // pre-load data in batches of 1000 + for offset := 0; offset < totalPoints; offset += 1000 { + end := offset + 1000 + if end > totalPoints { + end = totalPoints + } + points := make([]TimeSeriesPoint, 0, end-offset) + for i := offset; i < end; i++ { + points = append(points, TimeSeriesPoint{ + Name: "bench_query_wide", + Value: float64(i), + Timestamp: base.Add(time.Duration(i) * time.Second), + }) + } + err := database.AddTimeSeriesPoints(ctx, points) + assert.NoError(b, err) + } + // query full range with 1-hour buckets + queryBegin := base + queryEnd := base.Add(time.Duration(totalPoints) * time.Second) + b.ResetTimer() + for i := 0; i < b.N; i++ { + points, err := database.GetTimeSeriesPoints(ctx, "bench_query_wide", queryBegin, queryEnd, time.Hour) + assert.NoError(b, err) + assert.NotEmpty(b, points) + } +} diff --git a/storage/cache/valkey.go b/storage/cache/valkey.go new file mode 100644 index 000000000..16db1b469 --- /dev/null +++ b/storage/cache/valkey.go @@ -0,0 +1,839 @@ +// Copyright 2025 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cache + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net/url" + "sort" + "strconv" + "strings" + "time" + + "github.com/gorse-io/gorse/common/log" + "github.com/gorse-io/gorse/storage" + "github.com/juju/errors" + "github.com/samber/lo" + "github.com/valkey-io/valkey-go" + "go.uber.org/zap" +) + +func init() { + Register([]string{storage.ValkeyPrefix, storage.ValkeysPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { + host, port, username, password, db, useTLS, err := parseValkeyURL(path) + if err != nil { + return nil, errors.Trace(err) + } + option := valkey.ClientOption{ + InitAddress: []string{fmt.Sprintf("%s:%d", host, port)}, + SelectDB: db, + } + if username != "" || password != "" { + option.Username = username + option.Password = password + } + if useTLS { + option.TLSConfig = &tls.Config{} + } + client, err := valkey.NewClient(option) + if err != nil { + return nil, errors.Trace(err) + } + database := &Valkey{ + client: client, + isCluster: false, + maxSearchResults: storage.NewOptions(opts...).MaxSearchResults, + } + database.TablePrefix = storage.TablePrefix(tablePrefix) + return database, nil + }) + Register([]string{storage.ValkeyClusterPrefix, storage.ValkeysClusterPrefix}, func(path, tablePrefix string, opts ...storage.Option) (Database, error) { + addresses, username, password, useTLS, err := parseValkeyClusterURL(path) + if err != nil { + return nil, errors.Trace(err) + } + option := valkey.ClientOption{ + InitAddress: addresses, + } + if username != "" || password != "" { + option.Username = username + option.Password = password + } + if useTLS { + option.TLSConfig = &tls.Config{} + } + client, err := valkey.NewClient(option) + if err != nil { + return nil, errors.Trace(err) + } + database := &Valkey{ + client: client, + isCluster: true, + maxSearchResults: storage.NewOptions(opts...).MaxSearchResults, + } + database.TablePrefix = storage.TablePrefix(tablePrefix) + return database, nil + }) +} 
+ +// parseValkeyURL parses a valkey:// or valkeys:// URL into connection parameters. +func parseValkeyURL(rawURL string) (host string, port int, username, password string, db int, useTLS bool, err error) { + parsed, err := url.Parse(rawURL) + if err != nil { + return "", 0, "", "", 0, false, errors.Trace(err) + } + host = parsed.Hostname() + if host == "" { + host = "localhost" + } + port = 6379 + if parsed.Port() != "" { + port, err = strconv.Atoi(parsed.Port()) + if err != nil { + return "", 0, "", "", 0, false, errors.Trace(err) + } + } + if parsed.User != nil { + username = parsed.User.Username() + password, _ = parsed.User.Password() + } + if parsed.Path != "" && parsed.Path != "/" { + dbStr := strings.TrimPrefix(parsed.Path, "/") + if dbStr != "" { + db, err = strconv.Atoi(dbStr) + if err != nil { + return "", 0, "", "", 0, false, errors.Errorf("invalid database number: %s", dbStr) + } + } + } + useTLS = parsed.Scheme == "valkeys" + return host, port, username, password, db, useTLS, nil +} + +// parseValkeyClusterURL parses a valkey+cluster:// or valkeys+cluster:// URL. +func parseValkeyClusterURL(rawURL string) (addresses []string, username, password string, useTLS bool, err error) { + // Replace the cluster prefix with a standard scheme for URL parsing. 
+ var newURL string + if strings.HasPrefix(rawURL, storage.ValkeyClusterPrefix) { + newURL = strings.Replace(rawURL, storage.ValkeyClusterPrefix, storage.ValkeyPrefix, 1) + useTLS = false + } else if strings.HasPrefix(rawURL, storage.ValkeysClusterPrefix) { + newURL = strings.Replace(rawURL, storage.ValkeysClusterPrefix, storage.ValkeysPrefix, 1) + useTLS = true + } + parsed, err := url.Parse(newURL) + if err != nil { + return nil, "", "", false, errors.Trace(err) + } + host := parsed.Hostname() + if host == "" { + host = "localhost" + } + port := 6379 + if parsed.Port() != "" { + port, err = strconv.Atoi(parsed.Port()) + if err != nil { + return nil, "", "", false, errors.Trace(err) + } + } + addresses = append(addresses, fmt.Sprintf("%s:%d", host, port)) + if parsed.User != nil { + username = parsed.User.Username() + password, _ = parsed.User.Password() + } + // Parse additional addresses from query params (addr=host:port). + for _, addrStr := range parsed.Query()["addr"] { + if !strings.Contains(addrStr, ":") { + addrStr = addrStr + ":6379" + } + addresses = append(addresses, addrStr) + } + return addresses, username, password, useTLS, nil +} + +// Valkey cache storage using valkey-go client. +type Valkey struct { + storage.TablePrefix + client valkey.Client + isCluster bool + maxSearchResults int +} + +// Close the valkey connection. +func (v *Valkey) Close() error { + v.client.Close() + return nil +} + +// Ping the valkey server. +func (v *Valkey) Ping() error { + ctx := context.Background() + return v.client.Do(ctx, v.client.B().Ping().Build()).Error() +} + +// Init creates the valkey-search index for document storage. +func (v *Valkey) Init() error { + ctx := context.Background() + // List existing indices via FT._LIST. 
+ result, err := v.client.Do(ctx, v.client.B().Arbitrary("FT._LIST").Build()).ToArray() + if err != nil { + return errors.Trace(err) + } + indices := make([]string, 0, len(result)) + for _, r := range result { + if s, err := r.ToString(); err == nil { + indices = append(indices, s) + } + } + if lo.Contains(indices, v.DocumentTable()) { + return nil + } + // Create the index. + err = v.client.Do(ctx, v.client.B().Arbitrary("FT.CREATE"). + Keys(v.DocumentTable()). + Args( + "ON", "HASH", + "PREFIX", "1", v.DocumentTable()+":", + "SCHEMA", + "collection", "TAG", + "subset", "TAG", + "id", "TAG", + "score", "NUMERIC", + "is_hidden", "NUMERIC", + "categories", "TAG", "SEPARATOR", ";", + "timestamp", "NUMERIC", + ).Build()).Error() + if err != nil { + return errors.Trace(err) + } + return nil +} + +// Scan iterates over all keys with the table prefix. +func (v *Valkey) Scan(work func(string) error) error { + ctx := context.Background() + var cursor uint64 + for { + entry, err := v.client.Do(ctx, v.client.B().Scan().Cursor(cursor).Match(string(v.TablePrefix)+"*").Count(100).Build()).AsScanEntry() + if err != nil { + return errors.Trace(err) + } + for _, key := range entry.Elements { + if err = work(key[len(v.TablePrefix):]); err != nil { + return errors.Trace(err) + } + } + cursor = entry.Cursor + if cursor == 0 { + return nil + } + } +} + +// Purge deletes all keys with the table prefix. 
+func (v *Valkey) Purge() error { + ctx := context.Background() + var cursor uint64 + for { + entry, err := v.client.Do(ctx, v.client.B().Scan().Cursor(cursor).Match(string(v.TablePrefix)+"*").Count(100).Build()).AsScanEntry() + if err != nil { + return errors.Trace(err) + } + if len(entry.Elements) > 0 { + if v.isCluster { + for _, key := range entry.Elements { + if err = v.client.Do(ctx, v.client.B().Del().Key(key).Build()).Error(); err != nil { + return errors.Trace(err) + } + } + } else { + if err = v.client.Do(ctx, v.client.B().Del().Key(entry.Elements...).Build()).Error(); err != nil { + return errors.Trace(err) + } + } + } + cursor = entry.Cursor + if cursor == 0 { + return nil + } + } +} + +// Set stores values in Valkey. +func (v *Valkey) Set(ctx context.Context, values ...Value) error { + if len(values) == 0 { + return nil + } + cmds := make(valkey.Commands, 0, len(values)) + for _, val := range values { + cmds = append(cmds, v.client.B().Set().Key(v.Key(val.name)).Value(val.value).Build()) + } + for _, resp := range v.client.DoMulti(ctx, cmds...) { + if err := resp.Error(); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// Get returns a value from Valkey. +func (v *Valkey) Get(ctx context.Context, key string) *ReturnValue { + result, err := v.client.Do(ctx, v.client.B().Get().Key(v.Key(key)).Build()).ToString() + if err != nil { + if valkey.IsValkeyNil(err) { + return &ReturnValue{value: "", exists: false} + } + return &ReturnValue{err: err, exists: false} + } + return &ReturnValue{value: result, exists: true} +} + +// Delete removes a key from Valkey. +func (v *Valkey) Delete(ctx context.Context, key string) error { + return v.client.Do(ctx, v.client.B().Del().Key(v.Key(key)).Build()).Error() +} + +// Push adds a message to a sorted set queue with timestamp as score. 
+func (v *Valkey) Push(ctx context.Context, name, message string) error { + return v.client.Do(ctx, v.client.B().Zadd().Key(v.Key(name)).ScoreMember().ScoreMember(float64(time.Now().UnixNano()), message).Build()).Error() +} + +// Pop removes and returns the message with the lowest score from the queue. +func (v *Valkey) Pop(ctx context.Context, name string) (string, error) { + result, err := v.client.Do(ctx, v.client.B().Zpopmin().Key(v.Key(name)).Count(1).Build()).AsZScores() + if err != nil { + return "", errors.Trace(err) + } + if len(result) == 0 { + return "", io.EOF + } + return result[0].Member, nil +} + +// Remain returns the number of messages in the queue. +func (v *Valkey) Remain(ctx context.Context, name string) (int64, error) { + return v.client.Do(ctx, v.client.B().Zcard().Key(v.Key(name)).Build()).AsInt64() +} + +func (v *Valkey) documentKey(collection, subset, value string) string { + return v.DocumentTable() + ":" + collection + ":" + subset + ":" + value +} + +// AddScores adds score documents to Valkey using hash storage. +func (v *Valkey) AddScores(ctx context.Context, collection, subset string, documents []Score) error { + if len(documents) == 0 { + return nil + } + cmds := make(valkey.Commands, 0, len(documents)) + for _, document := range documents { + key := v.documentKey(collection, subset, document.Id) + cmds = append(cmds, v.client.B().Hset().Key(key).FieldValue(). + FieldValue("collection", collection). + FieldValue("subset", subset). + FieldValue("id", document.Id). + FieldValue("score", strconv.FormatFloat(document.Score, 'g', -1, 64)). + FieldValue("is_hidden", formatBool(document.IsHidden)). + FieldValue("categories", encodeCategories(document.Categories)). + FieldValue("timestamp", strconv.FormatInt(document.Timestamp.UnixMicro(), 10)). + Build()) + } + for _, resp := range v.client.DoMulti(ctx, cmds...) 
{ + if err := resp.Error(); err != nil { + return errors.Trace(err) + } + } + return nil +} + +func formatBool(b bool) string { + if b { + return "1" + } + return "0" +} + +// SearchScores searches for score documents using FT.SEARCH. +func (v *Valkey) SearchScores(ctx context.Context, collection, subset string, query []string, begin, end int) ([]Score, error) { + var builder strings.Builder + fmt.Fprintf(&builder, "@collection:{ %s } @is_hidden:[0 0]", escapeTag(collection)) + if subset != "" { + fmt.Fprintf(&builder, " @subset:{ %s }", escapeTag(subset)) + } + for _, q := range query { + fmt.Fprintf(&builder, " @categories:{ %s }", escapeTag(encodeCategory(q))) + } + // Use server-side LIMIT offset/count to fetch only the needed slice. + fetchOffset := begin + fetchCount := 10000 + if end != -1 { + fetchCount = end - begin + } + cmd := v.client.B().Arbitrary("FT.SEARCH"). + Keys(v.DocumentTable()). + Args(builder.String(), "SORTBY", "score", "DESC", "LIMIT", strconv.Itoa(fetchOffset), strconv.Itoa(fetchCount)). + Build() + result, err := v.client.Do(ctx, cmd).ToArray() + if err != nil { + return nil, errors.Trace(err) + } + documents, err := parseFTSearchResult(result) + if err != nil { + return nil, errors.Trace(err) + } + return documents, nil +} + +// UpdateScores updates score documents matching the query. 
+func (v *Valkey) UpdateScores(ctx context.Context, collections []string, subset *string, id string, patch ScorePatch) error { + if len(collections) == 0 { + return nil + } + if patch.Score == nil && patch.IsHidden == nil && patch.Categories == nil { + return nil + } + var builder strings.Builder + escapedCollections := make([]string, len(collections)) + for i, c := range collections { + escapedCollections[i] = escapeTag(c) + } + fmt.Fprintf(&builder, "@collection:{ %s }", strings.Join(escapedCollections, " | ")) + fmt.Fprintf(&builder, " @id:{ %s }", escapeTag(id)) + if subset != nil { + fmt.Fprintf(&builder, " @subset:{ %s }", escapeTag(*subset)) + } + limit := v.maxSearchResults + if limit <= 0 { + limit = 10000 + } + + // Two-phase update: collect all matching keys first, then mutate. + // On the first fetch, if total results exceed the page limit, re-fetch all + // in a single request to avoid pagination drift when scores change. + keys := make([]string, 0) + keySet := make(map[string]struct{}) + offset := 0 + for { + cmd := v.client.B().Arbitrary("FT.SEARCH"). + Keys(v.DocumentTable()). + Args(builder.String(), "SORTBY", "score", "DESC", "LIMIT", strconv.Itoa(offset), strconv.Itoa(limit)). + Build() + result, err := v.client.Do(ctx, cmd).ToArray() + if err != nil { + return errors.Trace(err) + } + if offset == 0 { + total := parseFTSearchTotal(result) + if total > limit { + // Fetch all results in one shot to avoid pagination issues. + cmd = v.client.B().Arbitrary("FT.SEARCH"). + Keys(v.DocumentTable()). + Args(builder.String(), "SORTBY", "score", "DESC", "LIMIT", "0", strconv.Itoa(total)). 
+ Build() + result, err = v.client.Do(ctx, cmd).ToArray() + if err != nil { + return errors.Trace(err) + } + } + } + docKeys := parseFTSearchKeys(result) + if len(docKeys) == 0 { + break + } + newKeys := 0 + for _, k := range docKeys { + if _, exists := keySet[k]; !exists { + keySet[k] = struct{}{} + keys = append(keys, k) + newKeys++ + } + } + offset += len(docKeys) + if len(docKeys) < limit || newKeys == 0 { + break + } + } + + for _, key := range keys { + cmd := v.client.B().Hset().Key(key).FieldValue() + if patch.Score != nil { + cmd = cmd.FieldValue("score", strconv.FormatFloat(*patch.Score, 'g', -1, 64)) + } + if patch.IsHidden != nil { + cmd = cmd.FieldValue("is_hidden", formatBool(*patch.IsHidden)) + } + if patch.Categories != nil { + cmd = cmd.FieldValue("categories", encodeCategories(patch.Categories)) + } + if err := v.client.Do(ctx, cmd.Build()).Error(); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// DeleteScores deletes score documents matching the condition. +func (v *Valkey) DeleteScores(ctx context.Context, collections []string, condition ScoreCondition) error { + if err := condition.Check(); err != nil { + return errors.Trace(err) + } + var builder strings.Builder + escapedCollections := make([]string, len(collections)) + for i, c := range collections { + escapedCollections[i] = escapeTag(c) + } + fmt.Fprintf(&builder, "@collection:{ %s }", strings.Join(escapedCollections, " | ")) + if condition.Subset != nil { + fmt.Fprintf(&builder, " @subset:{ %s }", escapeTag(*condition.Subset)) + } + if condition.Id != nil { + fmt.Fprintf(&builder, " @id:{ %s }", escapeTag(*condition.Id)) + } + if condition.Before != nil { + fmt.Fprintf(&builder, " @timestamp:[-inf (%d]", condition.Before.UnixMicro()) + } + const maxDeleteIterations = 100 + for iteration := 0; iteration < maxDeleteIterations; iteration++ { + cmd := v.client.B().Arbitrary("FT.SEARCH"). + Keys(v.DocumentTable()). 
+ Args(builder.String(), "SORTBY", "score", "DESC", "LIMIT", "0", "10000"). + Build() + result, err := v.client.Do(ctx, cmd).ToArray() + if err != nil { + return errors.Trace(err) + } + docKeys := parseFTSearchKeys(result) + total := parseFTSearchTotal(result) + if len(docKeys) == 0 { + break + } + if v.isCluster { + for _, key := range docKeys { + if err = v.client.Do(ctx, v.client.B().Del().Key(key).Build()).Error(); err != nil { + return errors.Trace(err) + } + } + } else { + if err = v.client.Do(ctx, v.client.B().Del().Key(docKeys...).Build()).Error(); err != nil { + return errors.Trace(err) + } + } + if total == len(docKeys) { + break + } + } + return nil +} + +// ScanScores iterates over all score documents. +func (v *Valkey) ScanScores(ctx context.Context, callback func(collection string, id string, subset string, timestamp time.Time) error) error { + var cursor uint64 + for { + entry, err := v.client.Do(ctx, v.client.B().Scan().Cursor(cursor).Match(v.DocumentTable()+":*").Count(100).Build()).AsScanEntry() + if err != nil { + return errors.Trace(err) + } + for _, key := range entry.Elements { + row, err := v.client.Do(ctx, v.client.B().Hgetall().Key(key).Build()).AsStrMap() + if err != nil { + return errors.Trace(err) + } + usec, err := strconv.ParseInt(row["timestamp"], 10, 64) + if err != nil { + return errors.Trace(err) + } + if err = callback(row["collection"], row["id"], row["subset"], time.UnixMicro(usec).In(time.UTC)); err != nil { + return errors.Trace(err) + } + } + cursor = entry.Cursor + if cursor == 0 { + return nil + } + } +} + +// AddTimeSeriesPoints adds time series points using sorted set + hash. 
+func (v *Valkey) AddTimeSeriesPoints(ctx context.Context, points []TimeSeriesPoint) error { + grouped := groupTimeSeriesPoints(points) + cmds := make(valkey.Commands, 0, len(grouped)*2) + for name, sd := range grouped { + indexKey := v.PointsTable() + ":ts_index:" + name + dataKey := v.PointsTable() + ":ts_data:" + name + // ZADD + zaddCmd := v.client.B().Zadd().Key(indexKey).ScoreMember() + for member, score := range sd.zaddMembers { + zaddCmd = zaddCmd.ScoreMember(score, member) + } + cmds = append(cmds, zaddCmd.Build()) + // HSET + hsetCmd := v.client.B().Hset().Key(dataKey).FieldValue() + for field, value := range sd.hsetFields { + hsetCmd = hsetCmd.FieldValue(field, value) + } + cmds = append(cmds, hsetCmd.Build()) + } + for _, resp := range v.client.DoMulti(ctx, cmds...) { + if err := resp.Error(); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// timeSeriesGroup holds grouped ZADD members and HSET fields for a single time series. +type timeSeriesGroup struct { + zaddMembers map[string]float64 + hsetFields map[string]string +} + +// groupTimeSeriesPoints groups time series points by name for batched writes. +func groupTimeSeriesPoints(points []TimeSeriesPoint) map[string]*timeSeriesGroup { + grouped := make(map[string]*timeSeriesGroup) + for _, point := range points { + tsMs := point.Timestamp.UnixMilli() + tsMsStr := strconv.FormatInt(tsMs, 10) + sd, ok := grouped[point.Name] + if !ok { + sd = &timeSeriesGroup{ + zaddMembers: make(map[string]float64), + hsetFields: make(map[string]string), + } + grouped[point.Name] = sd + } + sd.zaddMembers[tsMsStr] = float64(tsMs) + sd.hsetFields[tsMsStr] = strconv.FormatFloat(point.Value, 'g', -1, 64) + } + return grouped +} + +// GetTimeSeriesPoints retrieves time series points with Go-side bucket aggregation. 
+func (v *Valkey) GetTimeSeriesPoints(ctx context.Context, name string, begin, end time.Time, duration time.Duration) ([]TimeSeriesPoint, error) { + indexKey := v.PointsTable() + ":ts_index:" + name + dataKey := v.PointsTable() + ":ts_data:" + name + beginMs := begin.UnixMilli() + endMs := end.UnixMilli() + + // Fetch all timestamps in range from sorted set. + members, err := v.client.Do(ctx, v.client.B().Zrangebyscore().Key(indexKey). + Min(strconv.FormatInt(beginMs, 10)). + Max(strconv.FormatInt(endMs, 10)). + Withscores().Build()).AsZScores() + if err != nil { + return nil, errors.Trace(err) + } + if len(members) == 0 { + return nil, nil + } + + // Fetch corresponding values from hash. + fields := make([]string, len(members)) + for i, m := range members { + fields[i] = m.Member + } + hmgetResults, err := v.client.Do(ctx, v.client.B().Hmget().Key(dataKey).Field(fields...).Build()).ToArray() + if err != nil { + return nil, errors.Trace(err) + } + + // Build timestamp-value pairs. + type tsValue struct { + timestampMs int64 + value float64 + } + tsValues := make([]tsValue, 0, len(members)) + for i, m := range members { + valStr, err := hmgetResults[i].ToString() + if err != nil { + continue // nil result + } + tsMs, err := strconv.ParseInt(m.Member, 10, 64) + if err != nil { + log.Logger().Warn("failed to parse timestamp", zap.String("member", m.Member), zap.Error(err)) + continue + } + val, err := strconv.ParseFloat(valStr, 64) + if err != nil { + log.Logger().Warn("failed to parse value", zap.String("value", valStr), zap.Error(err)) + continue + } + tsValues = append(tsValues, tsValue{timestampMs: tsMs, value: val}) + } + + // Go-side bucket aggregation. 
+ durationMs := duration.Milliseconds() + if durationMs <= 0 { + durationMs = 1 + } + type bucket struct { + bucketMs int64 + lastTsMs int64 + lastValue float64 + } + buckets := make(map[int64]*bucket) + for _, tv := range tsValues { + bk := (tv.timestampMs / durationMs) * durationMs + existing, ok := buckets[bk] + if !ok || tv.timestampMs > existing.lastTsMs { + buckets[bk] = &bucket{bucketMs: bk, lastTsMs: tv.timestampMs, lastValue: tv.value} + } + } + + sortedBuckets := make([]*bucket, 0, len(buckets)) + for _, b := range buckets { + sortedBuckets = append(sortedBuckets, b) + } + sort.Slice(sortedBuckets, func(i, j int) bool { + return sortedBuckets[i].bucketMs < sortedBuckets[j].bucketMs + }) + + points := make([]TimeSeriesPoint, 0, len(sortedBuckets)) + for _, b := range sortedBuckets { + points = append(points, TimeSeriesPoint{ + Name: name, + Timestamp: time.UnixMilli(b.bucketMs).UTC(), + Value: b.lastValue, + }) + } + return points, nil +} + +// --- FT.SEARCH result parsing helpers --- + +// parseFTSearchTotal extracts the total count from an FT.SEARCH result array. +// valkey-go returns: [total_int64, key1, [field1, val1, ...], key2, [field2, val2, ...], ...] +func parseFTSearchTotal(result []valkey.ValkeyMessage) int { + if len(result) == 0 { + return 0 + } + total, err := result[0].AsInt64() + if err != nil { + return 0 + } + return int(total) +} + +// parseFTSearchKeys extracts document keys from an FT.SEARCH result array. +func parseFTSearchKeys(result []valkey.ValkeyMessage) []string { + if len(result) < 2 { + return nil + } + var keys []string + for i := 1; i < len(result); i += 2 { + key, err := result[i].ToString() + if err == nil { + keys = append(keys, key) + } + } + return keys +} + +// parseFTSearchResult parses an FT.SEARCH result into Score documents. 
+func parseFTSearchResult(result []valkey.ValkeyMessage) ([]Score, error) { + if len(result) < 2 { + return nil, nil + } + documents := make([]Score, 0) + for i := 1; i < len(result); i += 2 { + if i+1 >= len(result) { + break + } + fields, err := parseFieldArray(result[i+1]) + if err != nil { + continue + } + doc, err := scoreFromFieldMap(fields) + if err != nil { + return nil, err + } + documents = append(documents, doc) + } + // Sort by score descending to match the FT.SEARCH SORTBY score DESC. + sort.Slice(documents, func(i, j int) bool { + return documents[i].Score > documents[j].Score + }) + return documents, nil +} + +// parseFieldArray converts a ValkeyMessage field array [field1, val1, field2, val2, ...] into a map. +func parseFieldArray(msg valkey.ValkeyMessage) (map[string]string, error) { + arr, err := msg.ToArray() + if err != nil { + return nil, err + } + m := make(map[string]string) + for j := 0; j+1 < len(arr); j += 2 { + key, err1 := arr[j].ToString() + val, err2 := arr[j+1].ToString() + if err1 == nil && err2 == nil { + m[key] = val + } + } + return m, nil +} + +// scoreFromFieldMap converts a field map into a Score struct. +func scoreFromFieldMap(fields map[string]string) (Score, error) { + var doc Score + doc.Id = fields["id"] + score, err := strconv.ParseFloat(fields["score"], 64) + if err != nil { + return doc, errors.Trace(err) + } + doc.Score = score + isHidden, err := strconv.ParseInt(fields["is_hidden"], 10, 64) + if err != nil { + return doc, errors.Trace(err) + } + doc.IsHidden = isHidden != 0 + categories, err := decodeCategories(fields["categories"]) + if err != nil { + return doc, errors.Trace(err) + } + doc.Categories = categories + timestamp, err := strconv.ParseInt(fields["timestamp"], 10, 64) + if err != nil { + return doc, errors.Trace(err) + } + doc.Timestamp = time.UnixMicro(timestamp).In(time.UTC) + return doc, nil +} + +// valkeyTagEscaper escapes all TAG special characters for Valkey Search queries. 
var valkeyTagEscaper = newTagEscaper()

// newTagEscaper builds a strings.Replacer that prefixes every Valkey Search
// TAG special character with a backslash.
func newTagEscaper() *strings.Replacer {
	// All characters with special meaning inside a TAG query.
	const specials = `\{}|*()~@"'-:./+`
	oldnew := make([]string, 0, 2*len(specials))
	for _, r := range specials {
		oldnew = append(oldnew, string(r), `\`+string(r))
	}
	return strings.NewReplacer(oldnew...)
}

// escapeTag escapes a value for use in Valkey Search TAG queries.
func escapeTag(s string) string {
	return valkeyTagEscaper.Replace(s)
}

// Copyright 2025 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cache

import (
	"context"
	"fmt"
	"math"
	"os"
	"testing"
	"time"

	"github.com/gorse-io/gorse/common/log"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/suite"
)

var (
	// valkeyDSN is the Valkey connection string used by the tests; it is
	// overridable via the VALKEY_URI environment variable (see init).
	valkeyDSN string
)

func init() {
	// env returns the environment variable's value, or defaultValue when unset.
	env := func(key, defaultValue string) string {
		if value := os.Getenv(key); value != "" {
			return value
		}
		return defaultValue
	}
	valkeyDSN = env("VALKEY_URI", "valkey://127.0.0.1:6380/")
}

// ValkeyTestSuite runs the shared cache-store test suite against a Valkey server.
type ValkeyTestSuite struct {
	baseTestSuite
}

// SetupSuite opens the Valkey database, flushes it, and initializes the schema.
func (suite *ValkeyTestSuite) SetupSuite() {
	var err error
	suite.Database, err = Open(valkeyDSN, "gorse_")
	suite.Require().NoError(err)
	// flush db
	valkeyClient, ok := suite.Database.(*Valkey)
	suite.Require().True(ok)
	err = valkeyClient.client.Do(context.Background(), valkeyClient.client.B().Flushdb().Build()).Error()
	suite.Require().NoError(err)
	// create schema
	err = suite.Database.Init()
	suite.Require().NoError(err)
}

// TestEscapeCharacters verifies that TAG special characters in collection,
// subset, and id values survive add/search/update/delete round trips.
func (suite *ValkeyTestSuite) TestEscapeCharacters() {
	ts := time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)
	ctx := suite.T().Context()
	for _, c := range []string{"-", ":", ".", "/"} {
		suite.Run(c, func() {
			collection := fmt.Sprintf("a%s1", c)
			subset := fmt.Sprintf("b%s2", c)
			id := fmt.Sprintf("c%s3", c)
			err := suite.AddScores(ctx, collection, subset, []Score{{
				Id:         id,
				Score:      math.MaxFloat64,
				Categories: []string{"a", "b"},
				Timestamp:  ts,
			}})
			suite.NoError(err)
			documents, err := suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1)
			suite.NoError(err)
			suite.Equal([]Score{{Id: id, Score: math.MaxFloat64, Categories: []string{"a", "b"}, Timestamp: ts}}, documents)

			// NOTE(review): new(expr) — here and new(subset)/new(id) below —
			// requires a Go language version that supports the accepted
			// new-with-expression change; confirm go.mod's toolchain.
			err = suite.UpdateScores(ctx, []string{collection}, nil, id, ScorePatch{Score: new(float64(1))})
			suite.NoError(err)
			documents, err = suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1)
			suite.NoError(err)
			suite.Equal([]Score{{Id: id, Score: 1, Categories: []string{"a", "b"}, Timestamp: ts}}, documents)

			err = suite.DeleteScores(ctx, []string{collection}, ScoreCondition{
				Subset: new(subset),
				Id:     new(id),
			})
			suite.NoError(err)
			documents, err = suite.SearchScores(ctx, collection, subset, []string{"b"}, 0, -1)
			suite.NoError(err)
			suite.Empty(documents)
		})
	}
}

// TestUpdateScoresWithPagination verifies UpdateScores visits every matching
// document when results span multiple search pages (maxSearchResults = 2).
func (suite *ValkeyTestSuite) TestUpdateScoresWithPagination() {
	ctx := suite.T().Context()
	db, ok := suite.Database.(*Valkey)
	suite.True(ok)
	// Shrink the page size so five matches force pagination; restore afterwards.
	limit := db.maxSearchResults
	db.maxSearchResults = 2
	defer func() {
		db.maxSearchResults = limit
	}()

	for i := range 5 {
		subset := fmt.Sprintf("subset-%d", i)
		err := suite.AddScores(ctx, "collection-a", subset, []Score{{
			Id:         "shared-item",
			Score:      float64(i),
			Categories: []string{"old"},
			Timestamp:  time.Now().UTC(),
		}})
		suite.NoError(err)
	}

	err := suite.UpdateScores(ctx, []string{"collection-a"}, nil, "shared-item", ScorePatch{
		Categories: []string{"new"},
	})
	suite.NoError(err)

	for i := range 5 {
		subset := fmt.Sprintf("subset-%d", i)
		docs, err := suite.SearchScores(ctx, "collection-a", subset, []string{"new"}, 0, -1)
		suite.NoError(err)
		suite.Require().Len(docs, 1)
		suite.Equal("shared-item", docs[0].Id)
	}
}

// TestUpdateScoresWithPaginationAndScorePatch verifies a Score patch reaches
// every page even though patching the sort key mutates result ordering
// mid-iteration (maxSearchResults = 1).
func (suite *ValkeyTestSuite) TestUpdateScoresWithPaginationAndScorePatch() {
	ctx := suite.T().Context()
	db, ok := suite.Database.(*Valkey)
	suite.True(ok)
	limit := db.maxSearchResults
	db.maxSearchResults = 1
	defer func() {
		db.maxSearchResults = limit
	}()

	initialScores := []float64{3, 2, 1}
	for i, score := range initialScores {
		subset := fmt.Sprintf("score-subset-%d", i)
		err := suite.AddScores(ctx, "collection-b", subset, []Score{{
			Id:         "shared-item",
			Score:      score,
			Categories: []string{"score-old"},
			Timestamp:  time.Now().UTC(),
		}})
		suite.NoError(err)
	}

	targetScore := float64(0)
	err := suite.UpdateScores(ctx, []string{"collection-b"}, nil, "shared-item", ScorePatch{
		Score: &targetScore,
	})
	suite.NoError(err)

	for i := range initialScores {
		subset := fmt.Sprintf("score-subset-%d", i)
		docs, err := suite.SearchScores(ctx, "collection-b", subset, nil, 0, -1)
		suite.NoError(err)
		suite.Require().Len(docs, 1)
		suite.Equal(targetScore, docs[0].Score)
	}
}

// TestUpdateScoresWithPaginationAndTiedScores verifies pagination makes
// progress when every matching document has the same score (maxSearchResults = 2).
func (suite *ValkeyTestSuite) TestUpdateScoresWithPaginationAndTiedScores() {
	ctx := suite.T().Context()
	db, ok := suite.Database.(*Valkey)
	suite.True(ok)
	limit := db.maxSearchResults
	db.maxSearchResults = 2
	defer func() {
		db.maxSearchResults = limit
	}()

	for i := range 5 {
		subset := fmt.Sprintf("tie-subset-%d", i)
		err := suite.AddScores(ctx, "collection-c", subset, []Score{{
			Id:         "shared-item",
			Score:      1,
			Categories: []string{"tie-old"},
			Timestamp:  time.Now().UTC(),
		}})
		suite.NoError(err)
	}

	err := suite.UpdateScores(ctx, []string{"collection-c"}, nil, "shared-item", ScorePatch{
		Categories: []string{"tie-new"},
	})
	suite.NoError(err)

	for i := range 5 {
		subset := fmt.Sprintf("tie-subset-%d", i)
		docs, err := suite.SearchScores(ctx, "collection-c", subset, []string{"tie-new"}, 0, -1)
		suite.NoError(err)
		suite.Require().Len(docs, 1)
		suite.Equal("shared-item", docs[0].Id)
	}
}

// TestValkey runs the full cache-store suite against a live Valkey server.
func TestValkey(t *testing.T) {
	suite.Run(t, new(ValkeyTestSuite))
}

// TestParseValkeyURL covers parsing of standalone valkey:// and valkeys:// URLs.
func TestParseValkeyURL(t *testing.T) {
	// Basic URL
	host, port, username, password, db, useTLS, err := parseValkeyURL("valkey://127.0.0.1:6380/")
	assert.NoError(t, err)
	assert.Equal(t, "127.0.0.1", host)
	assert.Equal(t, 6380, port)
	assert.Equal(t, "", username)
	assert.Equal(t, "", password)
	assert.Equal(t, 0, db)
	assert.False(t, useTLS)

	// URL with password only
	host, _, _, password, db, useTLS, err = parseValkeyURL("valkey://:secret@localhost:6379/2")
	assert.NoError(t, err)
	assert.Equal(t, "localhost", host)
	assert.Equal(t, "secret", password)
	assert.Equal(t, 2, db)
	assert.False(t, useTLS)

	// URL with username and password
	host, port, username, password, db, useTLS, err = parseValkeyURL("valkey://myuser:mypass@host.example.com:6380/3")
	assert.NoError(t, err)
	assert.Equal(t, "host.example.com", host)
	assert.Equal(t, 6380, port)
	assert.Equal(t, "myuser", username)
	assert.Equal(t, "mypass", password)
	assert.Equal(t, 3, db)
	assert.False(t, useTLS)

	// TLS URL
	_, _, _, _, _, useTLS, err = parseValkeyURL("valkeys://localhost:6379/0")
	assert.NoError(t, err)
	assert.True(t, useTLS)

	// Default port
	_, port, _, _, _, _, err = parseValkeyURL("valkey://localhost/")
	assert.NoError(t, err)
	assert.Equal(t, 6379, port)
}

// TestParseValkeyClusterURL covers parsing of valkey+cluster:// and
// valkeys+cluster:// URLs, including extra addr query parameters.
func TestParseValkeyClusterURL(t *testing.T) {
	// Basic cluster URL
	addresses, username, password, useTLS, err := parseValkeyClusterURL("valkey+cluster://:password@192.168.1.11:6379?addr=192.168.0.5:6379&addr=192.168.0.7:6379")
	assert.NoError(t, err)
	assert.Len(t, addresses, 3)
	assert.Equal(t, "192.168.1.11:6379", addresses[0])
	assert.Equal(t, "192.168.0.5:6379", addresses[1])
	assert.Equal(t, "192.168.0.7:6379", addresses[2])
	assert.Equal(t, "", username)
	assert.Equal(t, "password", password)
	assert.False(t, useTLS)

	// Cluster URL with username
	addresses, username, password, useTLS, err = parseValkeyClusterURL("valkeys+cluster://admin:secret@node1:6380?addr=node2:6380")
	assert.NoError(t, err)
	assert.Len(t, addresses, 2)
	assert.Equal(t, "node1:6380", addresses[0])
	assert.Equal(t, "node2:6380", addresses[1])
	assert.Equal(t, "admin", username)
	assert.Equal(t, "secret", password)
	assert.True(t, useTLS)
}

// BenchmarkValkey benchmarks the cache store against a live Valkey server.
func BenchmarkValkey(b *testing.B) {
	log.CloseLogger()
	database, err := Open(valkeyDSN, "gorse_")
	assert.NoError(b, err)
	// flush db
	valkeyClient := database.(*Valkey)
	err = valkeyClient.client.Do(context.Background(), valkeyClient.client.B().Flushdb().Build()).Error()
	assert.NoError(b, err)
	// create schema
	err = database.Init()
	assert.NoError(b, err)
	// benchmark
	benchmark(b, database)
}
b/storage/docker-compose.yml index 97b98331c..3382766f5 100644 --- a/storage/docker-compose.yml +++ b/storage/docker-compose.yml @@ -5,6 +5,14 @@ services: ports: - 6379:6379 + # Valkey with valkey-search 1.2.0+ (full-text search support). + # Bundle 'unstable' is the only published tag with search >= 1.2. + # Switch to a stable tag (e.g. 9.1) once it is released. + valkey: + image: valkey/valkey-bundle:unstable + ports: + - 6380:6379 + mysql: image: mysql:8.0 ports: diff --git a/storage/scheme.go b/storage/scheme.go index 138176127..c1f7e856d 100644 --- a/storage/scheme.go +++ b/storage/scheme.go @@ -70,23 +70,27 @@ func init() { } const ( - MySQLPrefix = "mysql://" - MongoPrefix = "mongodb://" - MongoSrvPrefix = "mongodb+srv://" - PostgresPrefix = "postgres://" - PostgreSQLPrefix = "postgresql://" - ClickhousePrefix = "clickhouse://" - CHHTTPPrefix = "chhttp://" - CHHTTPSPrefix = "chhttps://" - SQLitePrefix = "sqlite://" - RedisPrefix = "redis://" - RedissPrefix = "rediss://" - RedisClusterPrefix = "redis+cluster://" - RedissClusterPrefix = "rediss+cluster://" - QdrantPrefix = "qdrant://" - WeaviatePrefix = "weaviate://" - WeaviatesPrefix = "weaviates://" - MilvusPrefix = "milvus://" + MySQLPrefix = "mysql://" + MongoPrefix = "mongodb://" + MongoSrvPrefix = "mongodb+srv://" + PostgresPrefix = "postgres://" + PostgreSQLPrefix = "postgresql://" + ClickhousePrefix = "clickhouse://" + CHHTTPPrefix = "chhttp://" + CHHTTPSPrefix = "chhttps://" + SQLitePrefix = "sqlite://" + RedisPrefix = "redis://" + RedissPrefix = "rediss://" + RedisClusterPrefix = "redis+cluster://" + RedissClusterPrefix = "rediss+cluster://" + ValkeyPrefix = "valkey://" + ValkeysPrefix = "valkeys://" + ValkeyClusterPrefix = "valkey+cluster://" + ValkeysClusterPrefix = "valkeys+cluster://" + QdrantPrefix = "qdrant://" + WeaviatePrefix = "weaviate://" + WeaviatesPrefix = "weaviates://" + MilvusPrefix = "milvus://" ) func AppendURLParams(rawURL string, params []lo.Tuple2[string, string]) (string, 
error) {