Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 71 additions & 35 deletions defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"time"
Expand Down Expand Up @@ -34,7 +35,7 @@ func Set(ptr interface{}) error {

for i := 0; i < t.NumField(); i++ {
if defaultVal := t.Field(i).Tag.Get(fieldName); defaultVal != "-" {
if err := setField(v.Field(i), defaultVal); err != nil {
if err := setField(v.Field(i), t.Field(i).Name, defaultVal); err != nil {
return err
}
}
Expand All @@ -51,7 +52,11 @@ func MustSet(ptr interface{}) {
}
}

func setField(field reflect.Value, defaultVal string) error {
func setField(field reflect.Value, currFieldName, defaultVal string) error {
wrapErr := func(err error) error {
return fmt.Errorf("error on set Field:[%s], DefaultValue:[%s] Error:[%v]", currFieldName, defaultVal, err)
}

if !field.CanSet() {
return nil
}
Expand All @@ -68,63 +73,91 @@ func setField(field reflect.Value, defaultVal string) error {

switch field.Kind() {
case reflect.Bool:
if val, err := strconv.ParseBool(defaultVal); err == nil {
field.Set(reflect.ValueOf(val).Convert(field.Type()))
val, err := strconv.ParseBool(defaultVal)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(val).Convert(field.Type()))
case reflect.Int:
if val, err := strconv.ParseInt(defaultVal, 0, strconv.IntSize); err == nil {
field.Set(reflect.ValueOf(int(val)).Convert(field.Type()))
val, err := strconv.ParseInt(defaultVal, 0, strconv.IntSize)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(int(val)).Convert(field.Type()))
case reflect.Int8:
if val, err := strconv.ParseInt(defaultVal, 0, 8); err == nil {
field.Set(reflect.ValueOf(int8(val)).Convert(field.Type()))
val, err := strconv.ParseInt(defaultVal, 0, 8)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(int8(val)).Convert(field.Type()))
case reflect.Int16:
if val, err := strconv.ParseInt(defaultVal, 0, 16); err == nil {
field.Set(reflect.ValueOf(int16(val)).Convert(field.Type()))
val, err := strconv.ParseInt(defaultVal, 0, 16)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(int16(val)).Convert(field.Type()))
case reflect.Int32:
if val, err := strconv.ParseInt(defaultVal, 0, 32); err == nil {
field.Set(reflect.ValueOf(int32(val)).Convert(field.Type()))
val, err := strconv.ParseInt(defaultVal, 0, 32)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(int32(val)).Convert(field.Type()))
case reflect.Int64:
if val, err := time.ParseDuration(defaultVal); err == nil {
field.Set(reflect.ValueOf(val).Convert(field.Type()))
} else if val, err := strconv.ParseInt(defaultVal, 0, 64); err == nil {
field.Set(reflect.ValueOf(val).Convert(field.Type()))
} else {
return wrapErr(err)
}
case reflect.Uint:
if val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize); err == nil {
field.Set(reflect.ValueOf(uint(val)).Convert(field.Type()))
val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(uint(val)).Convert(field.Type()))
case reflect.Uint8:
if val, err := strconv.ParseUint(defaultVal, 0, 8); err == nil {
field.Set(reflect.ValueOf(uint8(val)).Convert(field.Type()))
val, err := strconv.ParseUint(defaultVal, 0, 8)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(uint8(val)).Convert(field.Type()))
case reflect.Uint16:
if val, err := strconv.ParseUint(defaultVal, 0, 16); err == nil {
field.Set(reflect.ValueOf(uint16(val)).Convert(field.Type()))
val, err := strconv.ParseUint(defaultVal, 0, 16)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(uint16(val)).Convert(field.Type()))
case reflect.Uint32:
if val, err := strconv.ParseUint(defaultVal, 0, 32); err == nil {
field.Set(reflect.ValueOf(uint32(val)).Convert(field.Type()))
val, err := strconv.ParseUint(defaultVal, 0, 32)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(uint32(val)).Convert(field.Type()))
case reflect.Uint64:
if val, err := strconv.ParseUint(defaultVal, 0, 64); err == nil {
field.Set(reflect.ValueOf(val).Convert(field.Type()))
val, err := strconv.ParseUint(defaultVal, 0, 64)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(val).Convert(field.Type()))
case reflect.Uintptr:
if val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize); err == nil {
field.Set(reflect.ValueOf(uintptr(val)).Convert(field.Type()))
val, err := strconv.ParseUint(defaultVal, 0, strconv.IntSize)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(uintptr(val)).Convert(field.Type()))
case reflect.Float32:
if val, err := strconv.ParseFloat(defaultVal, 32); err == nil {
field.Set(reflect.ValueOf(float32(val)).Convert(field.Type()))
val, err := strconv.ParseFloat(defaultVal, 32)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(float32(val)).Convert(field.Type()))
case reflect.Float64:
if val, err := strconv.ParseFloat(defaultVal, 64); err == nil {
field.Set(reflect.ValueOf(val).Convert(field.Type()))
val, err := strconv.ParseFloat(defaultVal, 64)
if err != nil {
return wrapErr(err)
}
field.Set(reflect.ValueOf(val).Convert(field.Type()))
case reflect.String:
field.Set(reflect.ValueOf(defaultVal).Convert(field.Type()))

Expand All @@ -133,7 +166,7 @@ func setField(field reflect.Value, defaultVal string) error {
ref.Elem().Set(reflect.MakeSlice(field.Type(), 0, 0))
if defaultVal != "" && defaultVal != "[]" {
if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
return err
return wrapErr(err)
}
}
field.Set(ref.Elem().Convert(field.Type()))
Expand All @@ -142,14 +175,14 @@ func setField(field reflect.Value, defaultVal string) error {
ref.Elem().Set(reflect.MakeMap(field.Type()))
if defaultVal != "" && defaultVal != "{}" {
if err := json.Unmarshal([]byte(defaultVal), ref.Interface()); err != nil {
return err
return wrapErr(err)
}
}
field.Set(ref.Elem().Convert(field.Type()))
case reflect.Struct:
if defaultVal != "" && defaultVal != "{}" {
if err := json.Unmarshal([]byte(defaultVal), field.Addr().Interface()); err != nil {
return err
return wrapErr(err)
}
}
case reflect.Ptr:
Expand All @@ -160,7 +193,10 @@ func setField(field reflect.Value, defaultVal string) error {
switch field.Kind() {
case reflect.Ptr:
if isInitial || field.Elem().Kind() == reflect.Struct {
setField(field.Elem(), defaultVal)
err := setField(field.Elem(), currFieldName, defaultVal)
if err != nil {
return err
}
callSetter(field.Interface())
}
case reflect.Struct:
Expand All @@ -169,7 +205,7 @@ func setField(field reflect.Value, defaultVal string) error {
}
case reflect.Slice:
for j := 0; j < field.Len(); j++ {
if err := setField(field.Index(j), ""); err != nil {
if err := setField(field.Index(j), currFieldName, ""); err != nil {
return err
}
}
Expand All @@ -181,14 +217,14 @@ func setField(field reflect.Value, defaultVal string) error {
case reflect.Ptr:
switch v.Elem().Kind() {
case reflect.Struct, reflect.Slice, reflect.Map:
if err := setField(v.Elem(), ""); err != nil {
if err := setField(v.Elem(), currFieldName, ""); err != nil {
return err
}
}
case reflect.Struct, reflect.Slice, reflect.Map:
ref := reflect.New(v.Type())
ref.Elem().Set(v)
if err := setField(ref.Elem(), ""); err != nil {
if err := setField(ref.Elem(), currFieldName, ""); err != nil {
return err
}
field.SetMapIndex(e, ref.Elem().Convert(v.Type()))
Expand Down
109 changes: 109 additions & 0 deletions defaults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,112 @@ func TestDefaultsSetter(t *testing.T) {
t.Errorf("expected 1 for MainInt, got %d", main.MainInt)
}
}

func expectErrorForDefaultSet(t *testing.T, cfg interface{}) {
if err := Set(cfg); err == nil {
t.Errorf("expected error, got nil")
}
}

type ErrBoolDefault struct {
ErrBool bool `default:"ture"`
}

type ErrIntDefault struct {
ErrInt int `default:"abc"`
}

type ErrInt8Default struct {
ErrInt8 int8 `default:"abc"`
}

type ErrInt16Default struct {
ErrInt16 int16 `default:"abc"`
}

type ErrInt32Default struct {
ErrInt32 int32 `default:"abc"`
}

type ErrInt64Default struct {
ErrInt64 int64 `default:"abc"`
}

type ErrUintDefault struct {
ErrUint uint `default:"abc"`
}

type ErrUint8Default struct {
ErrUint8 uint8 `default:"abc"`
}

type ErrUint16Default struct {
ErrUint16 uint16 `default:"abc"`
}

type ErrUint32Default struct {
ErrUint32 uint32 `default:"abc"`
}

type ErrUint64Default struct {
ErrUint64 uint64 `default:"abc"`
}

type ErrUintptrDefault struct {
ErrUintptr uintptr `default:"abc"`
}

type ErrFloat32Default struct {
ErrFloat32 float32 `default:"abc"`
}

type ErrFloat64Default struct {
ErrFloat64 float64 `default:"abc"`
}

func TestBasicTypeErrorConfig(t *testing.T) {
errBool := &ErrBoolDefault{}
expectErrorForDefaultSet(t, errBool)
errInt := &ErrIntDefault{}
expectErrorForDefaultSet(t, errInt)
errInt8 := &ErrInt8Default{}
expectErrorForDefaultSet(t, errInt8)
errInt16 := &ErrInt16Default{}
expectErrorForDefaultSet(t, errInt16)
errInt32 := &ErrInt32Default{}
expectErrorForDefaultSet(t, errInt32)
errInt64 := &ErrInt64Default{}
expectErrorForDefaultSet(t, errInt64)
errUint := &ErrUintDefault{}
expectErrorForDefaultSet(t, errUint)
errUint8 := &ErrUint8Default{}
expectErrorForDefaultSet(t, errUint8)
errUint16 := &ErrUint16Default{}
expectErrorForDefaultSet(t, errUint16)
errUint32 := &ErrUint32Default{}
expectErrorForDefaultSet(t, errUint32)
errUint64 := &ErrUint64Default{}
expectErrorForDefaultSet(t, errUint64)
errUintptr := &ErrUintptrDefault{}
expectErrorForDefaultSet(t, errUintptr)
errFloat32 := &ErrFloat32Default{}
expectErrorForDefaultSet(t, errFloat32)
errFloat64 := &ErrFloat64Default{}
expectErrorForDefaultSet(t, errFloat64)
}

type SubErrorDefault struct {
ErrSlice []string `default:"[1,2,3]"`
NormalInt int `default:"1"`
}

type ParentErrorDefault struct {
NormalInt int `default:"1"`
ErrChild *SubErrorDefault `default:"{}"`
NormalBool bool `default:"true"`
}

func TestSubFieldError(t *testing.T) {
cfg := &ParentErrorDefault{}
expectErrorForDefaultSet(t, cfg)
}