diff --git a/json.go b/json.go index 5f22956..ba20a7e 100644 --- a/json.go +++ b/json.go @@ -457,6 +457,10 @@ type JSONArrayExpression struct { column string keys []string equalsValue interface{} + // Add a new field to indicate whether to perform a length comparison + length bool + // Add a new field to store the expected length value + lengthValue int } // Contains checks if column[keys] contains the value given. The keys parameter is only supported for MySQL and SQLite. @@ -475,6 +479,13 @@ func (json *JSONArrayExpression) In(value interface{}, keys ...string) *JSONArra return json } +// Length checks if the length of the JSON array matches the given value. +func (json *JSONArrayExpression) Length(value int) *JSONArrayExpression { + json.length = true + json.lengthValue = value + return json +} + // Build implements clause.Expression func (json *JSONArrayExpression) Build(builder clause.Builder) { if stmt, ok := builder.(*gorm.Statement); ok { @@ -504,6 +515,10 @@ func (json *JSONArrayExpression) Build(builder clause.Builder) { builder.WriteByte(')') } builder.WriteByte(')') + // Add new logic to handle length comparison + case json.length: + builder.WriteString("JSON_LENGTH(" + stmt.Quote(json.column) + ") = ") + builder.AddVar(stmt, json.lengthValue) } case "sqlite": switch { @@ -551,6 +566,9 @@ func (json *JSONArrayExpression) Build(builder clause.Builder) { builder.WriteString(" IN ") builder.AddVar(stmt, json.equalsValue) builder.WriteString(" END") + case json.length: + builder.WriteString("json_array_length(" + stmt.Quote(json.column) + ") = ") + builder.AddVar(stmt, json.lengthValue) } case "postgres": switch { @@ -558,6 +576,9 @@ func (json *JSONArrayExpression) Build(builder clause.Builder) { builder.WriteString(stmt.Quote(json.column)) builder.WriteString(" ? ") builder.AddVar(stmt, json.equalsValue) + case json.length: + builder.WriteString("array_length(" + stmt.Quote(json.column) + "::jsonb::text[], 1) = ") + builder.AddVar(stmt, json.lengthValue) } } } diff --git a/json_test.go b/json_test.go index a5badb1..1a190e4 100644 --- a/json_test.go +++ b/json_test.go @@ -523,5 +523,20 @@ func TestJSONArrayQuery(t *testing.T) { t.Fatalf("failed to find params with json value and keys, got error %v", err) } AssertEqual(t, len(retMultiple), 1) + + // 新增的长度测试用例 + var retLength []Param + if err := DB.Where(datatypes.JSONArrayQuery("config").Length(2)).Find(&retLength).Error; err != nil { + t.Fatalf("failed to find params with json array length, got error %v", err) + } + AssertEqual(t, len(retLength), 1) + AssertEqual(t, retLength[0].DisplayName, cmp1.DisplayName) + + // 测试嵌套数组的长度 + if err := DB.Where(datatypes.JSONArrayQuery("config").Length(2).Contains("a", "test")).Find(&retLength).Error; err != nil { + t.Fatalf("failed to find params with json array length and contains, got error %v", err) + } + AssertEqual(t, len(retLength), 1) + AssertEqual(t, retLength[0].DisplayName, cmp3.DisplayName) } }