Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/generic/descriptor/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const (
UTF16 Type = 17
// BINARY Type = 18 wrong and unused
JSON Type = 19

MAX_TYPE // indicates max type, so it has the value of previous one
TYPE_UPPERBOUND = MAX_TYPE + 1
)

var typeNames = map[Type]string{
Expand Down
122 changes: 106 additions & 16 deletions pkg/generic/thrift/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,27 @@ func typeOf(sample interface{}, t *descriptor.TypeDescriptor, opt *writerOption)
case int8, byte:
switch tt {
case descriptor.I08, descriptor.I16, descriptor.I32, descriptor.I64:
return tt, writeInt8, nil
return tt, writeIntSeries[int8], nil
}
case int16:
switch tt {
case descriptor.I08, descriptor.I16, descriptor.I32, descriptor.I64:
return tt, writeInt16, nil
return tt, writeIntSeries[int16], nil
}
case int32:
switch tt {
case descriptor.I08, descriptor.I16, descriptor.I32, descriptor.I64:
return tt, writeInt32, nil
return tt, writeIntSeries[int32], nil
}
case int64:
switch tt {
case descriptor.I08, descriptor.I16, descriptor.I32, descriptor.I64:
return tt, writeInt64, nil
return tt, writeIntSeries[int64], nil
}
case int:
switch tt {
case descriptor.I08, descriptor.I16, descriptor.I32, descriptor.I64:
return tt, writeIntSeries[int], nil
}
case float64:
// maybe come from json decode
Expand Down Expand Up @@ -123,6 +128,16 @@ func typeOf(sample interface{}, t *descriptor.TypeDescriptor, opt *writerOption)
return descriptor.STRUCT, writeJSON, nil
case nil, descriptor.Void: // nil and Void
return descriptor.VOID, writeVoid, nil
case map[int8]interface{}: // these branches are uncommon, so placed in the rear of the switch clause, avoiding testing them early
return descriptor.MAP, writeVariousComparableMap[int8], nil
case map[int16]interface{}:
return descriptor.MAP, writeVariousComparableMap[int16], nil
case map[int32]interface{}:
return descriptor.MAP, writeVariousComparableMap[int32], nil
case map[int64]interface{}:
return descriptor.MAP, writeVariousComparableMap[int64], nil
case map[int]interface{}:
return descriptor.MAP, writeVariousComparableMap[int], nil
}
return 0, nil, fmt.Errorf("unsupported type:%T, expected type:%s", sample, tt)
}
Expand All @@ -141,19 +156,19 @@ func typeJSONOf(data *gjson.Result, t *descriptor.TypeDescriptor, opt *writerOpt
return
case descriptor.I08:
v = int8(data.Int())
w = writeInt8
w = writeIntSeries[int8]
return
case descriptor.I16:
v = int16(data.Int())
w = writeInt16
w = writeIntSeries[int16]
return
case descriptor.I32:
v = int32(data.Int())
w = writeInt32
w = writeIntSeries[int32]
return
case descriptor.I64:
v = data.Int()
w = writeInt64
w = writeIntSeries[int64]
return
case descriptor.DOUBLE:
v = data.Float()
Expand Down Expand Up @@ -311,6 +326,37 @@ func writeBool(ctx context.Context, val interface{}, out *thrift.BufferWriter, t
return out.WriteBool(val.(bool))
}

func writeIntSeries[T ~int | ~int8 | ~uint8 | ~int16 | ~int32 | ~int64](ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error {
// compatible with lossless conversion
i, isT := val.(T)
if !isT {
return fmt.Errorf("unsupported type: %T", val)
}
switch t.Type {
case descriptor.I08:
// uint8 will never exceed the range
if _, ok := val.(uint8); !ok && (int64(i) < math.MinInt8 || int64(i) > math.MaxInt8) {
return fmt.Errorf("value is beyond range of i8: %v", i)
}
return out.WriteByte(int8(i))
case descriptor.I16:
if int64(i) < math.MinInt16 || int64(i) > math.MaxInt16 {
return fmt.Errorf("value is beyond range of i16: %v", i)
}
return out.WriteI16(int16(i))
case descriptor.I32:
// for int on 32-bit architectures: this branch is considered dead hence removed by go compiler
// for int on 64-bit architectures: this branch functions as writeInt64() does
if int64(i) < math.MinInt32 || int64(i) > math.MaxInt32 {
return fmt.Errorf("value is beyond range of i32: %v", i)
}
return out.WriteI32(int32(i))
case descriptor.I64:
return out.WriteI64(int64(i))
}
return fmt.Errorf("need int type, but got: %s", t.Type)
}

func writeInt8(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error {
var i int8
switch val := val.(type) {
Expand Down Expand Up @@ -409,25 +455,25 @@ func writeJSONNumber(ctx context.Context, val interface{}, out *thrift.BufferWri
if err != nil {
return err
}
return writeInt8(ctx, int8(i), out, t, opt)
return writeIntSeries[int8](ctx, int8(i), out, t, opt)
case descriptor.I16:
i, err := jn.Int64()
if err != nil {
return err
}
return writeInt16(ctx, int16(i), out, t, opt)
return writeIntSeries[int16](ctx, int16(i), out, t, opt)
case descriptor.I32:
i, err := jn.Int64()
if err != nil {
return err
}
return writeInt32(ctx, int32(i), out, t, opt)
return writeIntSeries[int32](ctx, int32(i), out, t, opt)
case descriptor.I64:
i, err := jn.Int64()
if err != nil {
return err
}
return writeInt64(ctx, i, out, t, opt)
return writeIntSeries[int64](ctx, i, out, t, opt)
case descriptor.DOUBLE:
i, err := jn.Float64()
if err != nil {
Expand All @@ -442,13 +488,13 @@ func writeJSONFloat64(ctx context.Context, val interface{}, out *thrift.BufferWr
i := val.(float64)
switch t.Type {
case descriptor.I08:
return writeInt8(ctx, int8(i), out, t, opt)
return writeIntSeries[int8](ctx, int8(i), out, t, opt)
case descriptor.I16:
return writeInt16(ctx, int16(i), out, t, opt)
return writeIntSeries[int16](ctx, int16(i), out, t, opt)
case descriptor.I32:
return writeInt32(ctx, int32(i), out, t, opt)
return writeIntSeries[int32](ctx, int32(i), out, t, opt)
case descriptor.I64:
return writeInt64(ctx, int64(i), out, t, opt)
return writeIntSeries[int64](ctx, int64(i), out, t, opt)
case descriptor.DOUBLE:
return writeFloat64(ctx, i, out, t, opt)
}
Expand Down Expand Up @@ -583,6 +629,50 @@ func writeInterfaceMap(ctx context.Context, val interface{}, out *thrift.BufferW
return nil
}

// writeVariousComparableMap moved most code from writeInterfaceMap.
// writeInterfaceMap can be simplified to one line: return writeVariousComparableMap[interface{}](ctx, val, out, t, opt)
// However this causes incompatibility with go 1.18 since it does not consider interface{} to be comparable. So code was copied
func writeVariousComparableMap[K comparable](ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error {
m := val.(map[K]interface{})
length := len(m)
if err := out.WriteMapBegin(thrift.TType(t.Key.Type), thrift.TType(t.Elem.Type), length); err != nil {
return err
}
if length == 0 {
return nil
}
var (
keyWriter writer
elemWriter writer
err error
)
for key, elem := range m {
if keyWriter == nil {
if keyWriter, err = nextWriter(key, t.Key, opt); err != nil {
return err
}
}
if err := keyWriter(ctx, key, out, t.Key, opt); err != nil {
return err
}
if elem == nil {
if err = writeEmptyValue(out, t.Elem, opt); err != nil {
return err
}
} else {
if elemWriter == nil {
if elemWriter, err = nextWriter(elem, t.Elem, opt); err != nil {
return err
}
}
if err := elemWriter(ctx, elem, out, t.Elem, opt); err != nil {
return err
}
}
}
return nil
}

func writeStringMap(ctx context.Context, val interface{}, out *thrift.BufferWriter, t *descriptor.TypeDescriptor, opt *writerOption) error {
m := val.(map[string]interface{})
length := len(m)
Expand Down
117 changes: 117 additions & 0 deletions pkg/generic/thrift/write_new.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package thrift

import (
"context"
"fmt"
"math"

"github.com/cloudwego/gopkg/protocol/thrift"
"github.com/cloudwego/kitex/pkg/generic/descriptor"
)

type thriftWriter func(ctx context.Context, out *thrift.BufferWriter, val interface{}) error

func nopWriter(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
return nil
}

var (
_ thriftWriter = nopWriter
_ thriftWriter = voidWriter
)

var writeFunctions = [descriptor.TYPE_UPPERBOUND]thriftWriter{
nopWriter, // STOP = 0
voidWriter, // VOID = 1
boolWriter, // BOOL = 2
byteWriter, // BYTE/I08 = 3
doubleWriter, // DOUBLE = 4
nopWriter, // nothing, index = 5
int16Writer, // I16 = 6
nopWriter, // nothing, index = 7
int32Writer, // I32 = 8
nopWriter, // nothing, index = 9
int64Writer, // I64 = 10
stringWriter, // STRING = 11
}

func voidWriter(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
panic("todo")
}

func boolWriter(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
b, ok := val.(bool)
if !ok {
return fmt.Errorf("val must be bool, got %T", val)
}
return out.WriteBool(b)
}

func byteWriter(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
switch b := val.(type) {
case int8:
return out.WriteByte(b)
case uint8:
return out.WriteByte(int8(b))
}

// slow path
if b, ok := asInt64(val); ok {
if b <= math.MaxInt8 && b >= math.MinInt8 {
return out.WriteByte(int8(b))
}
return fmt.Errorf("val[%d] is out of range (int8)", b)
}
return fmt.Errorf("val must be int, got %T", val)
}

func doubleWriter(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
b, ok := asFloat64(val)
if ok {
return out.WriteDouble(b)
}
return fmt.Errorf("val must be float64, got %T", val)
}

func int16Writer(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
if b, ok := val.(int16); ok {
return out.WriteI16(b)
}
if b, ok := asInt64(val); ok {
if b <= math.MaxInt16 && b >= math.MinInt16 {
return out.WriteI16(int16(b))
}
return fmt.Errorf("val[%d] is out of range (int16)", b)
}
return fmt.Errorf("val must be int16, got %T", val)
}

func int32Writer(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
if b, ok := val.(int32); ok {
return out.WriteI32(b)
}
if b, ok := asInt64(val); ok {
if b <= math.MaxInt32 && b >= math.MinInt32 {
return out.WriteI32(int32(b))
}
return fmt.Errorf("val[%d] is out of range (int32)", b)
}
return fmt.Errorf("val must be int32, got %T", val)
}

func int64Writer(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
if b, ok := val.(int64); ok {
return out.WriteI64(b)
}
if b, ok := asInt64(val); ok {
return out.WriteI64(b)
}
return fmt.Errorf("val must be int64, got %T", val)
}

func stringWriter(ctx context.Context, out *thrift.BufferWriter, val interface{}) error {
if b, ok := val.(string); ok {
return out.WriteString(b)
}
return fmt.Errorf("val must be string, got %T", val)
}
57 changes: 57 additions & 0 deletions pkg/generic/thrift/write_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package thrift

import "encoding/json"

func asInt64(v interface{}) (int64, bool) {
switch v := v.(type) {
case int64:
return v, true
case int:
return int64(v), true
}
return asInt64Slow(v)
}

func asInt64Slow(v interface{}) (int64, bool) {
switch v := v.(type) {
case int8:
return int64(v), true
case uint8:
return int64(v), true
case int32:
return int64(v), true
case uint32:
return int64(v), true
case json.Number:
r, e := v.Int64()
return r, e == nil
case int16:
return int64(v), true
case uint16:
return int64(v), true
case uint:
return int64(v), (v >> 63) > 0
case uint64:
return int64(v), (v >> 63) > 0
}
return 0, false
}

func asFloat64(v interface{}) (float64, bool) {
switch v := v.(type) {
case float64:
return v, true
case json.Number:
r, e := v.Float64()
return r, e == nil
}
return asFloat64Slow(v)
}

func asFloat64Slow(v interface{}) (float64, bool) {
if b, ok := v.(float32); ok {
return float64(b), true
}
i64, ok := asInt64(v)
return float64(i64), ok
}