Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customizability (transform) #2

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ type visit struct {
typ reflect.Type
}

const maxDepth = 1_000

func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
if depth > maxDepth {
func (teq Teq) deepValueEqual(
v1, v2 reflect.Value,
visited map[visit]bool,
depth int,
) bool {
if depth > teq.MaxDepth {
panic("maximum depth exceeded")
}
if !v1.IsValid() || !v2.IsValid() {
Expand All @@ -32,6 +34,15 @@ func (teq Teq) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, dept
return false
}

tr, ok := teq.transforms[v1.Type()]
if ok {
t1 := tr(v1)
t2 := tr(v2)
newTeq := New()
newTeq.MaxDepth = teq.MaxDepth
return newTeq.deepValueEqual(t1, t2, visited, depth)
}

if hard(v1.Kind()) {
if v1.CanAddr() && v2.CanAddr() {
addr1 := v1.Addr().UnsafePointer()
Expand Down Expand Up @@ -134,7 +145,7 @@ func pointerEq(teq Teq, v1, v2 reflect.Value, nx next) bool {

func structEq(teq Teq, v1, v2 reflect.Value, nx next) bool {
for i, n := 0, v1.NumField(); i < n; i++ {
if !nx(v1.Field(i), v2.Field(i)) {
if !nx(field(v1, i), field(v2, i)) {
return false
}
}
Expand Down
16 changes: 16 additions & 0 deletions misc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package teq

import (
"reflect"
)

func field(v reflect.Value, idx int) reflect.Value {
f1 := v.Field(idx)
if f1.CanAddr() {
return f1
}
vc := reflect.New(v.Type()).Elem()
vc.Set(v)
rf := vc.Field(idx)
return reflect.NewAt(rf.Type(), rf.Addr().UnsafePointer()).Elem()
}
60 changes: 57 additions & 3 deletions teq.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,25 @@ import (
"reflect"
)

type Teq struct{}
type Teq struct {
MaxDepth int

transforms map[reflect.Type]func(reflect.Value) reflect.Value
}

func New() Teq {
return Teq{
MaxDepth: 1_000,

transforms: make(map[reflect.Type]func(reflect.Value) reflect.Value),
}
}

func (teq Teq) Equal(t TestingT, expected, actual any) bool {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Errorf("panic: %v", r)
t.Errorf("panic in github.com/seiyab/teq. please report issue. message: %v", r)
}
}()
ok := teq.equal(expected, actual)
Expand All @@ -20,6 +32,44 @@ func (teq Teq) Equal(t TestingT, expected, actual any) bool {
return ok
}

func (teq Teq) NotEqual(t TestingT, expected, actual any) bool {
t.Helper()
defer func() {
if r := recover(); r != nil {
t.Errorf("panic in github.com/seiyab/teq. please report issue. message: %v", r)
}
}()
ok := !teq.equal(expected, actual)
if !ok {
if reflect.DeepEqual(expected, actual) {
t.Error("reflect.DeepEqual(expected, actual) == true.")
} else {
t.Errorf("expected %v != %v", expected, actual)
t.Log("reflect.DeepEqual(expected, actual) == false. maybe transforms made them equal.")
}
}
return ok

}

func (teq *Teq) AddTransform(transform any) {
ty := reflect.TypeOf(transform)
if ty.Kind() != reflect.Func {
panic("transform must be a function")
}
if ty.NumIn() != 1 {
panic("transform must have only one argument")
}
if ty.NumOut() != 1 {
panic("transform must have only one return value")
}
trValue := reflect.ValueOf(transform)
reflectTransform := func(v reflect.Value) reflect.Value {
return trValue.Call([]reflect.Value{v})[0]
}
teq.transforms[ty.In(0)] = reflectTransform
}

func (teq Teq) equal(x, y any) bool {
if x == nil || y == nil {
return x == y
Expand All @@ -29,5 +79,9 @@ func (teq Teq) equal(x, y any) bool {
if v1.Type() != v2.Type() {
return false
}
return teq.deepValueEqual(v1, v2, make(map[visit]bool), 0)
return teq.deepValueEqual(
v1, v2,
make(map[visit]bool),
0,
)
}
55 changes: 55 additions & 0 deletions teq_customized_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package teq_test

import (
"reflect"
"testing"
"time"

"github.com/seiyab/teq"
)

func TestEqual_Customized(t *testing.T) {
t.Run("time.Time", func(t *testing.T) {
defaultTeq := teq.New()
customizedTeq := teq.New()
customizedTeq.AddTransform(utc)

secondsEastOfUTC := int((8 * time.Hour).Seconds())
beijing := time.FixedZone("Beijing Time", secondsEastOfUTC)
d1 := time.Date(2000, 2, 1, 12, 30, 0, 0, time.UTC)
d2 := time.Date(2000, 2, 1, 20, 30, 0, 0, beijing)

defaultTeq.NotEqual(t, d1, d2)
customizedTeq.Equal(t, d1, d2)
if reflect.DeepEqual(d1, d2) {
t.Error("expected d1 != d2, got d1 == d2 with reflect.DeepEqual")
}

type twoDates struct {
d1 time.Time
d2 time.Time
}
dt1 := twoDates{d1, d2}
dt2 := twoDates{d2, d1}

defaultTeq.NotEqual(t, dt1, dt2)
customizedTeq.Equal(t, dt1, dt2)

if reflect.DeepEqual(dt1, dt2) {
t.Error("expected dt1 != dt2, got dt1 == dt2 with reflect.DeepEqual")
}

t.Skip("slice is not supported yet")
ds1 := []time.Time{d1, d1, d2}
ds2 := []time.Time{d2, d1, d1}
defaultTeq.NotEqual(t, ds1, ds2)
customizedTeq.Equal(t, ds1, ds2)
if reflect.DeepEqual(ds1, ds2) {
t.Error("expected ds1 != ds2, got ds1 == ds2 with reflect.DeepEqual")
}
})
}

func utc(d time.Time) time.Time {
return d.UTC()
}
21 changes: 15 additions & 6 deletions teq_test.go → teq_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type group struct {
}

func TestEqual(t *testing.T) {
assert := teq.Teq{}
assert := teq.New()

groups := []group{
{"primitives", primitives()},
Expand All @@ -43,14 +43,23 @@ func TestEqual(t *testing.T) {
}
t.Fatalf("expected %d errors, got %d", len(test.expected), len(mt.errors))
}
if test.pendingFormat {
return

if !test.pendingFormat {
for i, e := range test.expected {
if mt.errors[i] != e {
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i)
}
}
}
for i, e := range test.expected {
if mt.errors[i] != e {
t.Errorf("expected %q, got %q at i = %d", e, mt.errors[i], i)

{
mt := &mockT{}
assert.NotEqual(mt, test.a, test.b)
if (len(mt.errors) > 0) == (len(test.expected) > 0) {
t.Errorf("expected (len(mt.errors) > 0) = %t, got %t", len(test.expected) > 0, len(mt.errors) > 0)
}
}

})
}
})
Expand Down
2 changes: 2 additions & 0 deletions testingt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import "testing"

type TestingT interface {
Helper()
Error(args ...interface{})
Errorf(format string, args ...interface{})
Log(args ...interface{})
}

var _ TestingT = &testing.T{}
6 changes: 6 additions & 0 deletions testingt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ var _ teq.TestingT = &mockT{}

func (t *mockT) Helper() {}

func (t *mockT) Error(args ...interface{}) {
t.errors = append(t.errors, fmt.Sprint(args...))
}

func (t *mockT) Errorf(format string, args ...interface{}) {
t.errors = append(t.errors, fmt.Sprintf(format, args...))
}

func (t *mockT) Log(args ...interface{}) {}
Loading