diff --git a/builtin_functions.go b/builtin_functions.go index c3ca27bc..27a063f9 100644 --- a/builtin_functions.go +++ b/builtin_functions.go @@ -9,7 +9,7 @@ import ( func funcCeil(ctx *Context, this *VMValue, params []*VMValue) *VMValue { v, ok := params[0].ReadFloat() if ok { - return VMValueNewInt(int64(math.Ceil(v))) + return VMValueNewInt(IntType(math.Ceil(v))) } else { ctx.Error = errors.New("类型错误: 只能是float") } @@ -19,7 +19,7 @@ func funcCeil(ctx *Context, this *VMValue, params []*VMValue) *VMValue { func funcRound(ctx *Context, this *VMValue, params []*VMValue) *VMValue { v, ok := params[0].ReadFloat() if ok { - return VMValueNewInt(int64(math.Round(v))) + return VMValueNewInt(IntType(math.Round(v))) } else { ctx.Error = errors.New("类型错误: 只能是float") } @@ -29,7 +29,7 @@ func funcRound(ctx *Context, this *VMValue, params []*VMValue) *VMValue { func funcFloor(ctx *Context, this *VMValue, params []*VMValue) *VMValue { v, ok := params[0].ReadFloat() if ok { - return VMValueNewInt(int64(math.Floor(v))) + return VMValueNewInt(IntType(math.Floor(v))) } else { ctx.Error = errors.New("类型错误: 只能是float") } @@ -63,12 +63,12 @@ func funcInt(ctx *Context, this *VMValue, params []*VMValue) *VMValue { return params[0] case VMTypeFloat: v, _ := params[0].ReadFloat() - return VMValueNewInt(int64(v)) + return VMValueNewInt(IntType(v)) case VMTypeString: s, _ := params[0].ReadString() val, err := strconv.ParseInt(s, 10, 64) if err == nil { - return VMValueNewInt(val) + return VMValueNewInt(IntType(val)) } else { ctx.Error = errors.New("值错误: 无法进行 int() 转换: " + s) } diff --git a/bytecode.go b/bytecode.go index ef700abb..792a4630 100644 --- a/bytecode.go +++ b/bytecode.go @@ -114,7 +114,7 @@ const ( func (code *ByteCode) CodeString() string { switch code.T { case TypePushIntNumber: - return "push.int " + strconv.FormatInt(code.Value.(int64), 10) + return "push.int " + strconv.FormatInt(int64(code.Value.(IntType)), 10) case TypePushFloatNumber: return "push.flt " + strconv.FormatFloat(code.Value.(float64), 'f', 2, 64) case TypePushString: @@ -122,9 +122,9 @@ func (code *ByteCode) CodeString() string { case TypePushRange: return "push.range" case TypePushArray: - return "push.arr " + strconv.FormatInt(code.Value.(int64), 10) + return "push.arr " + strconv.FormatInt(int64(code.Value.(IntType)), 10) case TypePushDict: - return "push.dict " + strconv.FormatInt(code.Value.(int64), 10) + return "push.dict " + strconv.FormatInt(int64(code.Value.(IntType)), 10) case TypePushComputed: computed, _ := code.Value.(*VMValue).ReadComputed() return "push.computed " + computed.Expr @@ -141,7 +141,7 @@ func (code *ByteCode) CodeString() string { return "push.func " + computed.Name case TypeInvoke: - return "invoke " + strconv.FormatInt(code.Value.(int64), 10) + return "invoke " + strconv.FormatInt(int64(code.Value.(IntType)), 10) case TypeInvokeSelf: return "invoke.self " + code.Value.(string) diff --git a/cmd/main.go b/cmd/main.go index b98fc182..da6fc5a0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,13 +2,14 @@ package main import ( "fmt" - "github.com/peterh/liner" - dice "github.com/sealdice/dicescript" "os" "path/filepath" "regexp" "strconv" "strings" + + "github.com/peterh/liner" + ds "github.com/sealdice/dicescript" ) var ( @@ -29,17 +30,17 @@ func main() { _ = f.Close() } - attrs := map[string]*dice.VMValue{} + attrs := map[string]*ds.VMValue{} fmt.Println("DiceScript Shell v0.0.1") ccTimes := 0 - vm := dice.NewVM() + vm := ds.NewVM() vm.Config.EnableDiceWoD = true vm.Config.EnableDiceCoC = true vm.Config.EnableDiceFate = true vm.Config.EnableDiceDoubleCross = true vm.Config.PrintBytecode = true - vm.Config.CallbackSt = func(_type string, name string, val *dice.VMValue, extra *dice.VMValue, op string, detail string) { + vm.Config.CallbackSt = func(_type string, name string, val *ds.VMValue, extra *ds.VMValue, op string, detail string) { fmt.Println("st:", _type, name, val.ToString(), extra.ToString(), op, detail) } @@ -47,7 +48,7 @@ func main() { vm.Config.DefaultDiceSideExpr = "面数 ?? 50" vm.Config.OpCountLimit = 30000 - vm.Config.CallbackLoadVar = func(name string) (string, *dice.VMValue) { + vm.Config.CallbackLoadVar = func(name string) (string, *ds.VMValue) { re := regexp.MustCompile(`^(困难|极难|大成功|常规|失败|困難|極難|常規|失敗)?([^\d]+)(\d+)?$`) m := re.FindStringSubmatch(name) var cocFlagVarPrefix string @@ -62,7 +63,7 @@ func main() { if m[3] != "" { v, _ := strconv.ParseInt(m[3], 10, 64) fmt.Println("COC值:", name, cocFlagVarPrefix) - return name, dice.VMValueNewInt(v) + return name, ds.VMValueNewInt(ds.IntType(v)) } } @@ -70,8 +71,8 @@ func main() { return name, nil } - _ = vm.RegCustomDice(`E(\d+)`, func(ctx *dice.Context, groups []string) *dice.VMValue { - return dice.VMValueNewInt(2) + _ = vm.RegCustomDice(`E(\d+)`, func(ctx *ds.Context, groups []string) *ds.VMValue { + return ds.VMValueNewInt(2) }) //vm.ValueStoreNameFunc = func(name string, v *dice.VMValue) { @@ -80,12 +81,12 @@ func main() { re := regexp.MustCompile(`^(\D+)(\d+)$`) - vm.GlobalValueLoadFunc = func(name string) *dice.VMValue { + vm.GlobalValueLoadFunc = func(name string) *ds.VMValue { m := re.FindStringSubmatch(name) if len(m) > 1 { //val, _ := strconv.ParseInt(m[2], 10, 64) //return dice.VMValueNewInt(val) - return dice.VMValueNewInt(0) + return ds.VMValueNewInt(0) } if val, ok := attrs[name]; ok { diff --git a/jsport/main.go b/jsport/main.go index ef8e5b74..bb045d7a 100644 --- a/jsport/main.go +++ b/jsport/main.go @@ -30,7 +30,7 @@ func newVM(name string) *js.Object { m := re.FindStringSubmatch(name) if len(m) > 1 { val, _ := strconv.ParseInt(m[2], 10, 64) - return ds.VMValueNewInt(val) + return ds.VMValueNewInt(ds.IntType(val)) } if v, exists := player.Load(name); exists { @@ -59,7 +59,7 @@ func main() { "newValueMap": func() *js.Object { return js.MakeFullWrapper(&ds.ValueMap{}) }, - "vmNewInt": func(i int64) *js.Object { + "vmNewInt": func(i ds.IntType) *js.Object { return js.MakeFullWrapper(ds.VMValueNewInt(i)) }, "vmNewFloat": func(i float64) *js.Object { diff --git a/parser.go b/parser.go index c6676786..ad85a26b 100644 --- a/parser.go +++ b/parser.go @@ -6,16 +6,16 @@ import ( ) type ParserData struct { - counterStack []int64 // f-string 嵌套计数,在解析时中起作用 - varnameStack []string // 另一个解析用栈 - jmpStack []int64 - breakStack []int64 // break,用时创建 - continueStack []int64 // continue用,用时创建 + counterStack []IntType // f-string 嵌套计数,在解析时中起作用 + varnameStack []string // 另一个解析用栈 + jmpStack []IntType + breakStack []IntType // break,用时创建 + continueStack []IntType // continue用,用时创建 loopInfo []struct { continueIndex int breakIndex int } - loopLayer int64 // 当前loop层数 + loopLayer int // 当前loop层数 codeStack []struct { code []ByteCode index int @@ -23,16 +23,16 @@ type ParserData struct { } type BufferSpan struct { - begin int64 - end int64 + begin IntType + end IntType ret *VMValue text string } func (pd *ParserData) init() { - pd.counterStack = []int64{} + pd.counterStack = []IntType{} pd.varnameStack = []string{} - pd.jmpStack = []int64{} // 不复用counterStack的原因是在 ?: 算符中两个都有用到 + pd.jmpStack = []IntType{} // 不复用counterStack的原因是在 ?: 算符中两个都有用到 pd.codeStack = []struct { code []ByteCode index int @@ -84,14 +84,14 @@ func (e *Parser) WriteCode(T CodeType, value interface{}) { e.codeIndex += 1 } -func (p *Parser) AddDiceDetail(begin int64, end int64) { +func (p *Parser) AddDiceDetail(begin IntType, end IntType) { p.WriteCode(TypeDetailMark, BufferSpan{begin: begin, end: end}) } func (e *Parser) AddOp(operator CodeType) { var val interface{} = nil if operator == TypeJne || operator == TypeJmp { - val = int64(0) + val = IntType(0) } e.WriteCode(operator, val) } @@ -102,18 +102,18 @@ func (e *Parser) AddLoadName(value string) { func (e *Parser) PushIntNumber(value string) { val, _ := strconv.ParseInt(value, 10, 64) - e.WriteCode(TypePushIntNumber, int64(val)) + e.WriteCode(TypePushIntNumber, IntType(val)) } func (e *Parser) PushStr(value string) { e.WriteCode(TypePushString, value) } -func (e *Parser) PushArray(value int64) { +func (e *Parser) PushArray(value IntType) { e.WriteCode(TypePushArray, value) } -func (e *Parser) PushDict(value int64) { +func (e *Parser) PushDict(value IntType) { e.WriteCode(TypePushDict, value) } @@ -129,7 +129,7 @@ func (e *Parser) PushGlobal() { e.WriteCode(TypePushGlobal, nil) } -func (e *Parser) AddFormatString(value string, num int64) { +func (e *Parser) AddFormatString(value string, num IntType) { //e.PushStr(value) e.WriteCode(TypeLoadFormatString, num) // num } @@ -176,16 +176,16 @@ func (e *Parser) NamePop() string { } func (e *Parser) OffsetPush() { - e.jmpStack = append(e.jmpStack, int64(e.codeIndex)-1) + e.jmpStack = append(e.jmpStack, IntType(e.codeIndex)-1) } func (p *Parser) ContinuePush() { if p.loopLayer > 0 { if p.continueStack == nil { - p.continueStack = []int64{} + p.continueStack = []IntType{} } p.AddOp(TypeJmp) - p.continueStack = append(p.continueStack, int64(p.codeIndex)-1) + p.continueStack = append(p.continueStack, IntType(p.codeIndex)-1) } else { p.Error = errors.New("循环外不能放置continue") } @@ -198,7 +198,7 @@ func (p *Parser) ContinueSet(offsetB int) { lastB := len(p.jmpStack) - 1 - offsetB jmpIndex := p.jmpStack[lastB] // 试出来的,这个是对的,那么也许while那个是错的??还是说因为while最后多push了一个jmp呢? - p.code[codeIndex].Value = -(int64(codeIndex) - jmpIndex) + p.code[codeIndex].Value = -(IntType(codeIndex) - jmpIndex) } } } @@ -207,7 +207,7 @@ func (p *Parser) BreakSet() { if p.breakStack != nil { info := p.loopInfo[len(p.loopInfo)-1] for _, codeIndex := range p.breakStack[info.breakIndex:] { - p.code[codeIndex].Value = int64(p.codeIndex) - codeIndex - 1 + p.code[codeIndex].Value = IntType(p.codeIndex) - codeIndex - 1 } } } @@ -215,10 +215,10 @@ func (p *Parser) BreakSet() { func (p *Parser) BreakPush() { if p.loopLayer > 0 { if p.breakStack == nil { - p.breakStack = []int64{} + p.breakStack = []IntType{} } p.AddOp(TypeJmp) - p.breakStack = append(p.breakStack, int64(p.codeIndex)-1) + p.breakStack = append(p.breakStack, IntType(p.codeIndex)-1) } else { p.Error = errors.New("循环外不能放置break") } @@ -228,7 +228,7 @@ func (e *Parser) OffsetPopAndSet() { last := len(e.jmpStack) - 1 codeIndex := e.jmpStack[last] e.jmpStack = e.jmpStack[:last] - e.code[codeIndex].Value = int64(int64(e.codeIndex) - codeIndex - 1) + e.code[codeIndex].Value = IntType(IntType(e.codeIndex) - codeIndex - 1) //fmt.Println("XXXX", e.Code[codeIndex], "|", e.Top, codeIndex) } @@ -245,9 +245,9 @@ func (e *Parser) OffsetJmpSetX(offsetA int, offsetB int, rev bool) { jmpIndex := e.jmpStack[lastB] if rev { - e.code[codeIndex].Value = -(int64(e.codeIndex) - jmpIndex - 1) + e.code[codeIndex].Value = -(IntType(e.codeIndex) - jmpIndex - 1) } else { - e.code[codeIndex].Value = int64(e.codeIndex) - jmpIndex - 1 + e.code[codeIndex].Value = IntType(e.codeIndex) - jmpIndex - 1 } } @@ -255,14 +255,14 @@ func (e *Parser) CounterPush() { e.counterStack = append(e.counterStack, 0) } -func (e *Parser) CounterAdd(offset int64) { +func (e *Parser) CounterAdd(offset IntType) { last := len(e.counterStack) - 1 if last != -1 { e.counterStack[last] += offset } } -func (e *Parser) CounterPop() int64 { +func (e *Parser) CounterPop() IntType { last := len(e.counterStack) - 1 num := e.counterStack[last] e.counterStack = e.counterStack[:last] @@ -279,12 +279,12 @@ func (e *Parser) FlagsPop() { e.flagsStack = e.flagsStack[:last] } -func (e *Parser) AddInvokeMethod(name string, paramsNum int64) { +func (e *Parser) AddInvokeMethod(name string, paramsNum IntType) { e.WriteCode(TypePushIntNumber, paramsNum) e.WriteCode(TypeInvokeSelf, name) } -func (e *Parser) AddInvoke(paramsNum int64) { +func (e *Parser) AddInvoke(paramsNum IntType) { //e.WriteCode(TypePushIntNumber, paramsNum) e.WriteCode(TypeInvoke, paramsNum) } diff --git a/roll.peg b/roll.peg index 911564d8..0e10632f 100644 --- a/roll.peg +++ b/roll.peg @@ -80,7 +80,7 @@ func_def_params <- '(' sp ')' sp { p.CounterPush() } / '(' sp { p.CounterPush(); p.CounterAdd(1) } identifier sp { p.NamePush(text) } (',' sp identifier sp { p.NamePush(text) } {p.CounterAdd(1)} )* ')' sp stmtFunc <- 'func' sp1 identifier sp { p.NamePush(text) } func_def_params '{' sp { p.CodePush() } < stmtRoot? > '}' sp - { num := p.CounterPop(); arr := []string{}; for i:=int64(0); i '}' sp - { num := p.CounterPop(); arr := []string{}; for i:=int64(0); i */ nil, nil, - /* 144 Action18 <- <{ num := p.CounterPop(); arr := []string{}; for i:=int64(0); i */ + /* 144 Action18 <- <{ num := p.CounterPop(); arr := []string{}; for i:=IntType(0); i */ nil, /* 145 Action19 <- <{ p.NamePush(text) }> */ nil, @@ -9508,7 +9508,7 @@ func (p *Parser) Init(options ...func(*Parser) error) error { nil, /* 171 Action45 <- <{p.CounterAdd(1)}> */ nil, - /* 172 Action46 <- <{ p.PushStr(""); limit:=p.CounterPop()+1; for i:=int64(0); i */ + /* 172 Action46 <- <{ p.PushStr(""); limit:=p.CounterPop()+1; for i:=IntType(0); i */ nil, /* 173 Action47 <- <{p.AddOp(TypeJeDup); p.OffsetPush()}> */ nil, @@ -9576,9 +9576,9 @@ func (p *Parser) Init(options ...func(*Parser) error) error { nil, /* 205 Action79 <- <{ p.PushIntNumber("2"); p.AddOp(TypeDiceSetTimes); p.PushIntNumber("1"); p.AddOp(TypeDiceSetKeepLowNum) }> */ nil, - /* 206 Action80 <- <{ p.CounterPush(); p.CounterAdd(int64(token.begin)) }> */ + /* 206 Action80 <- <{ p.CounterPush(); p.CounterAdd(IntType(token.begin)) }> */ nil, - /* 207 Action81 <- <{ p.AddDiceDetail(p.CounterPop(), int64(token.end)) }> */ + /* 207 Action81 <- <{ p.AddDiceDetail(p.CounterPop(), IntType(token.end)) }> */ nil, /* 208 Action82 <- <{ p.AddOp(TypeDiceInit); p.AddOp(TypeDiceSetTimes); }> */ nil, @@ -9688,7 +9688,7 @@ func (p *Parser) Init(options ...func(*Parser) error) error { nil, /* 261 Action135 <- <{ p.CodePush() }> */ nil, - /* 262 Action136 <- <{ num := p.CounterPop(); arr := []string{}; for i:=int64(0); i */ + /* 262 Action136 <- <{ num := p.CounterPop(); arr := []string{}; for i:=IntType(0); i */ nil, /* 263 Action137 <- <{ p.PushIntNumber(string(text)) }> */ nil, diff --git a/roll_func.go b/roll_func.go index c1c668da..1a7f2a63 100644 --- a/roll_func.go +++ b/roll_func.go @@ -9,15 +9,15 @@ import ( "strings" ) -func Roll(dicePoints int64) int64 { +func Roll(dicePoints IntType) IntType { if dicePoints == 0 { return 0 } - val := rand.Int63()%dicePoints + 1 + val := IntType(rand.Int63())%dicePoints + 1 return val } -func wodCheck(e *Context, addLine int64, pool int64, points int64, threshold int64) bool { +func wodCheck(e *Context, addLine IntType, pool IntType, points IntType, threshold IntType) bool { //makeE6 := func() { // e.Error = errors.New("E6: 类型错误") //} @@ -46,19 +46,19 @@ func wodCheck(e *Context, addLine int64, pool int64, points int64, threshold int } // RollWoD 返回: 成功数,总骰数,轮数,细节 -func RollWoD(addLine int64, pool int64, points int64, threshold int64, isGE bool) (int64, int64, int64, string) { +func RollWoD(addLine IntType, pool IntType, points IntType, threshold IntType, isGE bool) (IntType, IntType, IntType, string) { var details []string addTimes := 1 isShowDetails := pool < 15 allRollCount := pool - successCount := int64(0) + successCount := IntType(0) for times := 0; times < addTimes; times++ { - addCount := int64(0) + addCount := IntType(0) detailsOne := []string{} - for i := int64(0); i < pool; i++ { + for i := IntType(0); i < pool; i++ { var reachSuccess bool var reachAddRound bool one := Roll(points) @@ -81,7 +81,7 @@ func RollWoD(addLine int64, pool int64, points int64, threshold int64, isGE bool } if isShowDetails { - baseText := strconv.FormatInt(one, 10) + baseText := strconv.FormatInt(int64(one), 10) if reachSuccess { baseText += "*" } @@ -123,10 +123,10 @@ func RollWoD(addLine int64, pool int64, points int64, threshold int64, isGE bool detailText = fmt.Sprintf("成功%d/%d%s%s", successCount, allRollCount, roundsText, detailText) // 成功数,总骰数,轮数,细节 - return successCount, allRollCount, int64(addTimes), detailText + return successCount, allRollCount, IntType(addTimes), detailText } -func doubleCrossCheck(ctx *Context, addLine, pool, points int64) bool { +func doubleCrossCheck(ctx *Context, addLine, pool, points IntType) bool { if pool < 1 || pool > 20000 { ctx.Error = errors.New("E7: 非法数值, 骰池范围是1到20000") return false @@ -145,20 +145,20 @@ func doubleCrossCheck(ctx *Context, addLine, pool, points int64) bool { return true } -func RollDoubleCross(addLine int64, pool int64, points int64) (int64, int64, int64, string) { +func RollDoubleCross(addLine IntType, pool IntType, points IntType) (IntType, IntType, IntType, string) { var details []string addTimes := 1 isShowDetails := pool < 15 allRollCount := pool - resultDice := int64(0) + resultDice := IntType(0) for times := 0; times < addTimes; times++ { - addCount := int64(0) + addCount := IntType(0) detailsOne := []string{} - maxDice := int64(0) + maxDice := IntType(0) - for i := int64(0); i < pool; i++ { + for i := IntType(0); i < pool; i++ { one := Roll(points) if one > maxDice { maxDice = one @@ -171,7 +171,7 @@ func RollDoubleCross(addLine int64, pool int64, points int64) (int64, int64, int } if isShowDetails { - baseText := strconv.FormatInt(one, 10) + baseText := strconv.FormatInt(int64(one), 10) if reachAddRound { baseText = "<" + baseText + ">" } @@ -218,13 +218,13 @@ func RollDoubleCross(addLine int64, pool int64, points int64) (int64, int64, int } // 成功数,总骰数,轮数,细节 - return resultDice, allRollCount, int64(addTimes), lastDetail + return resultDice, allRollCount, IntType(addTimes), lastDetail } // RollCommon (times)d(dicePoints)kl(lowNum) 或 (times)d(dicePoints)kh(highNum) -func RollCommon(times, dicePoints int64, diceMin, diceMax *int64, isKeepLH, lowNum, highNum int64) (int64, string) { - var nums []int64 - for i := int64(0); i < times; i += 1 { +func RollCommon(times, dicePoints IntType, diceMin, diceMax *IntType, isKeepLH, lowNum, highNum IntType) (IntType, string) { + var nums []IntType + for i := IntType(0); i < times; i += 1 { die := Roll(dicePoints) if diceMax != nil { if die > *diceMax { @@ -270,10 +270,10 @@ func RollCommon(times, dicePoints int64, diceMin, diceMax *int64, isKeepLH, lowN } } - num := int64(0) - for i := int64(0); i < pickNum; i++ { + num := IntType(0) + for i := IntType(0); i < pickNum; i++ { // 当取数大于上限 跳过 - if i >= int64(len(nums)) { + if i >= IntType(len(nums)) { continue } num += nums[i] @@ -292,7 +292,7 @@ func RollCommon(times, dicePoints int64, diceMin, diceMax *int64, isKeepLH, lowN } } else { text = "{" - for i := int64(0); i < int64(len(nums)); i++ { + for i := IntType(0); i < IntType(len(nums)); i++ { if i == pickNum { text += "| " } @@ -307,17 +307,17 @@ func RollCommon(times, dicePoints int64, diceMin, diceMax *int64, isKeepLH, lowN return num, text } -func RollCoC(isBonus bool, diceNum int64) (int64, string) { +func RollCoC(isBonus bool, diceNum IntType) (IntType, string) { diceResult := Roll(100) diceTens := diceResult / 10 diceUnits := diceResult % 10 - nums := []string{} + var nums []string diceMin := diceTens diceMax := diceTens num10Exists := false - for i := int64(0); i < diceNum; i++ { + for i := IntType(0); i < diceNum; i++ { n := Roll(10) if n == 10 { @@ -325,7 +325,7 @@ func RollCoC(isBonus bool, diceNum int64) (int64, string) { nums = append(nums, "0") continue } else { - nums = append(nums, strconv.FormatInt(n, 10)) + nums = append(nums, strconv.FormatInt(int64(n), 10)) } if n < diceMin { @@ -357,11 +357,11 @@ func RollCoC(isBonus bool, diceNum int64) (int64, string) { } } -func RollFate() (int64, string) { +func RollFate() (IntType, string) { detail := "" - sum := int64(0) + sum := IntType(0) for i := 0; i < 4; i++ { - n := rand.Int63()%3 - 1 + n := IntType(rand.Int63())%3 - 1 sum += n switch n { case -1: diff --git a/roll_func_test.go b/roll_func_test.go index ef1d0fef..262adc5d 100644 --- a/roll_func_test.go +++ b/roll_func_test.go @@ -7,7 +7,7 @@ import ( func TestRollCommon(t *testing.T) { ret, _ := RollCommon(5, 1, nil, nil, 0, 0, 0) - assert.Equal(t, ret, int64(5)) + assert.Equal(t, ret, IntType(5)) } func TestRollDoubleCross(t *testing.T) { @@ -17,5 +17,5 @@ func TestRollDoubleCross(t *testing.T) { func TestRollWoD(t *testing.T) { ret, _, _, _ := RollWoD(11, 8, 10, 1, true) // 8a11m10k1 - assert.Equal(t, int64(8), ret) + assert.Equal(t, IntType(8), ret) } diff --git a/rollvm.go b/rollvm.go index 0717e427..0d62f42a 100644 --- a/rollvm.go +++ b/rollvm.go @@ -127,7 +127,7 @@ func (e *Parser) Evaluate() { ctx := &e.Context var details []BufferSpan - numOpCountAdd := func(count int64) bool { + numOpCountAdd := func(count IntType) bool { e.NumOpCount += count if ctx.Config.OpCountLimit > 0 && e.NumOpCount > ctx.Config.OpCountLimit { ctx.Error = errors.New("允许算力上限") @@ -138,23 +138,23 @@ func (e *Parser) Evaluate() { diceStateIndex := -1 var diceStates []struct { - times int64 // 次数,如 2d10,times为2 - isKeepLH int64 // 为1对应取低个数,为2对应取高个数,3为丢弃低个数,4为丢弃高个数 - lowNum int64 - highNum int64 - min *int64 - max *int64 + times IntType // 次数,如 2d10,times为2 + isKeepLH IntType // 为1对应取低个数,为2对应取高个数,3为丢弃低个数,4为丢弃高个数 + lowNum IntType + highNum IntType + min *IntType + max *IntType } diceInit := func() { diceStateIndex += 1 data := struct { - times int64 // 次数,如 2d10,times为2 - isKeepLH int64 // 为1对应取低个数,为2对应取高个数 - lowNum int64 - highNum int64 - min *int64 - max *int64 + times IntType // 次数,如 2d10,times为2 + isKeepLH IntType // 为1对应取低个数,为2对应取高个数 + lowNum IntType + highNum IntType + min *IntType + max *IntType }{ times: 1, } @@ -168,9 +168,9 @@ func (e *Parser) Evaluate() { } var wodState struct { - pool int64 - points int64 - threshold int64 + pool IntType + points IntType + threshold IntType isGE bool } @@ -182,8 +182,8 @@ func (e *Parser) Evaluate() { } var dcState struct { - pool int64 - points int64 + pool IntType + points IntType } dcInit := func() { @@ -196,12 +196,12 @@ func (e *Parser) Evaluate() { return } var m []struct { - begin int64 - end int64 + begin IntType + end IntType spans []BufferSpan } - curPoint := int64(-1) // nolint - lastEnd := int64(-1) // nolint + curPoint := IntType(-1) // nolint + lastEnd := IntType(-1) // nolint sort.Sort(spanByBegin(details)) for _, i := range details { @@ -209,8 +209,8 @@ func (e *Parser) Evaluate() { if i.begin > lastEnd { curPoint = i.begin m = append(m, struct { - begin int64 - end int64 + begin IntType + end IntType spans []BufferSpan }{begin: curPoint, end: i.end, spans: []BufferSpan{i}}) } else { @@ -279,9 +279,9 @@ func (e *Parser) Evaluate() { return v1, v2 } - stackPopN := func(num int64) []*VMValue { + stackPopN := func(num IntType) []*VMValue { var data []*VMValue - for i := int64(0); i < num; i++ { + for i := IntType(0); i < num; i++ { data = append(data, stackPop().Clone()) // 复制一遍规避栈问题 } for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { @@ -339,10 +339,10 @@ func (e *Parser) Evaluate() { stack[e.top].Value = unquote e.top++ case TypePushArray: - num := code.Value.(int64) + num := code.Value.(IntType) stackPush(VMValueNewArray(stackPopN(num)...)) case TypePushDict: - num := code.Value.(int64) + num := code.Value.(IntType) items := stackPopN(num * 2) dict, err := VMValueNewDictWithArray(items...) if err != nil { @@ -369,7 +369,7 @@ func (e *Parser) Evaluate() { return } - step := int64(1) + step := IntType(1) length := _b - _a if length < 0 { step = -1 @@ -442,7 +442,7 @@ func (e *Parser) Evaluate() { } case TypeInvoke: - paramsNum := code.Value.(int64) + paramsNum := code.Value.(IntType) arr := stackPopN(paramsNum) funcObj := stackPop() @@ -477,7 +477,7 @@ func (e *Parser) Evaluate() { val := stackPop() // 右值 itemIndex := stackPop() // 下标 obj := stackPop() // 数组 / 对象 - obj.ItemSet(ctx, itemIndex, val) + obj.ItemSet(ctx, itemIndex, val.Clone()) if ctx.Error != nil { return } @@ -485,7 +485,7 @@ func (e *Parser) Evaluate() { attrVal, obj := stackPop2() attrName := code.Value.(string) - ret := obj.AttrSet(ctx, attrName, attrVal) + ret := obj.AttrSet(ctx, attrName, attrVal.Clone()) if ctx.Error == nil && ret == nil { ctx.Error = errors.New("不支持的类型:当前变量无法用.来设置属性") } @@ -543,7 +543,7 @@ func (e *Parser) Evaluate() { return case TypeLoadFormatString: - num := int(code.Value.(int64)) + num := int(code.Value.(IntType)) outStr := "" for index := 0; index < num; index++ { @@ -600,7 +600,7 @@ func (e *Parser) Evaluate() { case TypeJe, TypeJeDup: v := stackPop() if v.AsBool() { - opIndex += int(code.Value.(int64)) + opIndex += int(code.Value.(IntType)) if code.T == TypeJeDup { stackPush(v) } @@ -608,14 +608,14 @@ func (e *Parser) Evaluate() { case TypeJne: t := stackPop() if !t.AsBool() { - opIndex += int(code.Value.(int64)) + opIndex += int(code.Value.(IntType)) } case TypeJmp: - opIndex += int(code.Value.(int64)) + opIndex += int(code.Value.(IntType)) case TypePop: stackPop() case TypePopN: - stackPopN(code.Value.(int64)) + stackPopN(code.Value.(IntType)) case TypeAdd, TypeSubtract, TypeMultiply, TypeDivide, TypeModulus, TypeExponentiation, TypeNullCoalescing, TypeCompLT, TypeCompLE, TypeCompEQ, TypeCompNE, TypeCompGE, TypeCompGT, diff --git a/rollvm_test.go b/rollvm_test.go index e37eca99..75508e86 100644 --- a/rollvm_test.go +++ b/rollvm_test.go @@ -2,11 +2,12 @@ package dicescript import ( "fmt" - "github.com/stretchr/testify/assert" "regexp" "strconv" "strings" "testing" + + "github.com/stretchr/testify/assert" ) func vmValueEqual(vm *Context, aKey string, bValue *VMValue) bool { @@ -147,7 +148,7 @@ func TestUnsupportedOperandType(t *testing.T) { err := vm.Run("2 % 3.1") if assert.Error(t, err) { // VM Error: 这两种类型无法使用 mod 算符连接: int64, float64 - assert.Equal(t, err.Error(), "这两种类型无法使用 mod 算符连接: int64, float64") + assert.Equal(t, err.Error(), "这两种类型无法使用 mod 算符连接: int, float") } } @@ -727,7 +728,7 @@ fib(11) func TestBytecodeToString(t *testing.T) { ops := []ByteCode{ - {TypePushIntNumber, int64(1)}, + {TypePushIntNumber, IntType(1)}, {TypePushFloatNumber, float64(1.2)}, {TypePushString, "abc"}, @@ -763,9 +764,9 @@ func TestBytecodeToString(t *testing.T) { {TypeDiceSetMin, nil}, {TypeDiceSetMax, nil}, - {TypeJmp, int64(0)}, - {TypeJe, int64(0)}, - {TypeJne, int64(0)}, + {TypeJmp, IntType(0)}, + {TypeJe, IntType(0)}, + {TypeJne, IntType(0)}, } for _, i := range ops { @@ -1391,3 +1392,39 @@ func TestLogicOrBug(t *testing.T) { assert.True(t, valueEqual(vm.Ret, ni(2))) } } + +func TestAttrSetBug(t *testing.T) { + // 这个问题是最后一个函数调用的第一个参数成了他自己 + // 例如下面这个例子中,str(a.x)中,str拿到的参数是 &{2 nfunction str} + // 原因是 attr_set 在设置时未进行值复制,而是拿到了vm栈地址,栈被覆盖后问题就出现了 + // 2024/04/23 绑定时发现 + // ItemSet 也有同样问题 + vm := NewVM() + err := vm.Run(`a = {}; a.x = 10; str(a.x)`) + if assert.NoError(t, err) { + assert.True(t, valueEqual(vm.Ret, ns("10"))) + } +} + +func TestItemSetBug2(t *testing.T) { + vm := NewVM() + err := vm.Run(`a = {}; a[1] = 10; str(a[1])`) + if assert.NoError(t, err) { + assert.True(t, valueEqual(vm.Ret, ns("10"))) + } +} + +func TestStackTop(t *testing.T) { + vm := NewVM() + _ = vm.Run(`1;2;3`) + assert.Equal(t, vm.StackTop(), 3) // 暂时的设计是只在语句块弃栈 + + _ = vm.Run(`4`) + assert.Equal(t, vm.StackTop(), 1) // 二次运行清空栈 + + _ = vm.Run(`while (i<10) { i=i+1; 1;2;3 }`) + assert.Equal(t, vm.StackTop(), 0) // 语句块后空栈 + + _ = vm.Run(`1;2; while (i<10) { i=i+1; 1;2;3 }`) + assert.Equal(t, vm.StackTop(), 2) // 语句块弃栈不影响上级 +} diff --git a/types.go b/types.go index 705b7128..8f3cc741 100644 --- a/types.go +++ b/types.go @@ -27,6 +27,7 @@ import ( ) type VMValueType int +type IntType int // :IntType const ( VMTypeInt VMValueType = 0 @@ -79,7 +80,7 @@ type RollConfig struct { CallbackLoadVar func(name string) (string, *VMValue) // 加载变量回调,返回值会成为新变量名 CallbackSt func(_type string, name string, val *VMValue, extra *VMValue, op string, detail string) // st回调 - OpCountLimit int64 // 算力限制,超过这个值会报错,0为无限,建议值30000 + OpCountLimit IntType // 算力限制,超过这个值会报错,0为无限,建议值30000 DefaultDiceSideExpr string // 默认骰子面数 defaultDiceSideExprCacheFunc *VMValue // expr的缓存函数 @@ -109,7 +110,7 @@ type Context struct { stack []VMValue top int - NumOpCount int64 // 算力计数 + NumOpCount IntType // 算力计数 //CocFlagVarPrefix string // 解析过程中出现,当VarNumber开启时有效,可以是困难极难常规大成功 Config RollConfig // 标记 @@ -366,7 +367,7 @@ func (v *VMValue) Clone() *VMValue { func (v *VMValue) AsBool() bool { switch v.TypeId { case VMTypeInt: - return v.Value != int64(0) + return v.Value != IntType(0) case VMTypeString: return v.Value != "" case VMTypeNull, VMTypeUndefined: @@ -394,7 +395,7 @@ func (v *VMValue) toStringRaw(ri *recursionInfo) string { } switch v.TypeId { case VMTypeInt: - return strconv.FormatInt(v.Value.(int64), 10) + return strconv.FormatInt(int64(v.Value.(IntType)), 10) case VMTypeFloat: return strconv.FormatFloat(v.Value.(float64), 'f', -1, 64) case VMTypeString: @@ -481,9 +482,9 @@ func (v *VMValue) ToRepr() string { return v.toReprRaw(ri) } -func (v *VMValue) ReadInt() (int64, bool) { +func (v *VMValue) ReadInt() (IntType, bool) { if v.TypeId == VMTypeInt { - return v.Value.(int64), true + return v.Value.(IntType), true } return 0, false } @@ -537,7 +538,7 @@ func (v *VMValue) MustReadArray() *ArrayData { panic("错误: 不正确的类型") } -func (v *VMValue) MustReadInt() int64 { +func (v *VMValue) MustReadInt() IntType { val, ok := v.ReadInt() if ok { return val @@ -579,16 +580,16 @@ func (v *VMValue) OpAdd(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - val := v.Value.(int64) + v2.Value.(int64) + val := v.Value.(IntType) + v2.Value.(IntType) return VMValueNewInt(val) case VMTypeFloat: - val := float64(v.Value.(int64)) + v2.Value.(float64) + val := float64(v.Value.(IntType)) + v2.Value.(float64) return VMValueNewFloat(val) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - val := v.Value.(float64) + float64(v2.Value.(int64)) + val := v.Value.(float64) + float64(v2.Value.(IntType)) return VMValueNewFloat(val) case VMTypeFloat: val := v.Value.(float64) + v2.Value.(float64) @@ -629,16 +630,16 @@ func (v *VMValue) OpSub(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - val := v.Value.(int64) - v2.Value.(int64) + val := v.Value.(IntType) - v2.Value.(IntType) return VMValueNewInt(val) case VMTypeFloat: - val := float64(v.Value.(int64)) - v2.Value.(float64) + val := float64(v.Value.(IntType)) - v2.Value.(float64) return VMValueNewFloat(val) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - val := v.Value.(float64) - float64(v2.Value.(int64)) + val := v.Value.(float64) - float64(v2.Value.(IntType)) return VMValueNewFloat(val) case VMTypeFloat: val := v.Value.(float64) - v2.Value.(float64) @@ -655,10 +656,10 @@ func (v *VMValue) OpMultiply(ctx *Context, v2 *VMValue) *VMValue { switch v2.TypeId { case VMTypeInt: // TODO: 溢出,均未考虑溢出 - val := v.Value.(int64) * v2.Value.(int64) + val := v.Value.(IntType) * v2.Value.(IntType) return VMValueNewInt(val) case VMTypeFloat: - val := float64(v.Value.(int64)) * v2.Value.(float64) + val := float64(v.Value.(IntType)) * v2.Value.(float64) return VMValueNewFloat(val) case VMTypeArray: return v2.ArrayRepeatTimesEx(ctx, v) @@ -666,7 +667,7 @@ func (v *VMValue) OpMultiply(ctx *Context, v2 *VMValue) *VMValue { case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - val := v.Value.(float64) * float64(v2.Value.(int64)) + val := v.Value.(float64) * float64(v2.Value.(IntType)) return VMValueNewFloat(val) case VMTypeFloat: val := v.Value.(float64) * v2.Value.(float64) @@ -692,25 +693,25 @@ func (v *VMValue) OpDivide(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - if v2.Value.(int64) == 0 { + if v2.Value.(IntType) == 0 { return setDivideZero() } - val := v.Value.(int64) / v2.Value.(int64) + val := v.Value.(IntType) / v2.Value.(IntType) return VMValueNewInt(val) case VMTypeFloat: if v2.Value.(float64) == 0 { return setDivideZero() } - val := float64(v.Value.(int64)) / v2.Value.(float64) + val := float64(v.Value.(IntType)) / v2.Value.(float64) return VMValueNewFloat(val) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - if v2.Value.(int64) == 0 { + if v2.Value.(IntType) == 0 { return setDivideZero() } - val := v.Value.(float64) / float64(v2.Value.(int64)) + val := v.Value.(float64) / float64(v2.Value.(IntType)) return VMValueNewFloat(val) case VMTypeFloat: if v2.Value.(float64) == 0 { @@ -733,11 +734,11 @@ func (v *VMValue) OpModulus(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - if v2.Value.(int64) == 0 { + if v2.Value.(IntType) == 0 { setDivideZero() return nil } - val := v.Value.(int64) % v2.Value.(int64) + val := v.Value.(IntType) % v2.Value.(IntType) return VMValueNewInt(val) } } @@ -750,16 +751,16 @@ func (v *VMValue) OpPower(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - val := int64(math.Pow(float64(v.Value.(int64)), float64(v2.Value.(int64)))) + val := IntType(math.Pow(float64(v.Value.(IntType)), float64(v2.Value.(IntType)))) return VMValueNewInt(val) case VMTypeFloat: - val := math.Pow(float64(v.Value.(int64)), v2.Value.(float64)) + val := math.Pow(float64(v.Value.(IntType)), v2.Value.(float64)) return VMValueNewFloat(val) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - val := math.Pow(v.Value.(float64), float64(v2.Value.(int64))) + val := math.Pow(v.Value.(float64), float64(v2.Value.(IntType))) return VMValueNewFloat(val) case VMTypeFloat: val := math.Pow(v.Value.(float64), v2.Value.(float64)) @@ -779,7 +780,7 @@ func (v *VMValue) OpNullCoalescing(ctx *Context, v2 *VMValue) *VMValue { } func boolToVMValue(v bool) *VMValue { - var val int64 + var val IntType if v { val = 1 } @@ -791,14 +792,14 @@ func (v *VMValue) OpCompLT(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(int64) < v2.Value.(int64)) + return boolToVMValue(v.Value.(IntType) < v2.Value.(IntType)) case VMTypeFloat: - return boolToVMValue(float64(v.Value.(int64)) < v2.Value.(float64)) + return boolToVMValue(float64(v.Value.(IntType)) < v2.Value.(float64)) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(float64) < float64(v2.Value.(int64))) + return boolToVMValue(v.Value.(float64) < float64(v2.Value.(IntType))) case VMTypeFloat: return boolToVMValue(v.Value.(float64) < v2.Value.(float64)) } @@ -812,14 +813,14 @@ func (v *VMValue) OpCompLE(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(int64) <= v2.Value.(int64)) + return boolToVMValue(v.Value.(IntType) <= v2.Value.(IntType)) case VMTypeFloat: - return boolToVMValue(float64(v.Value.(int64)) <= v2.Value.(float64)) + return boolToVMValue(float64(v.Value.(IntType)) <= v2.Value.(float64)) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(float64) <= float64(v2.Value.(int64))) + return boolToVMValue(v.Value.(float64) <= float64(v2.Value.(IntType))) case VMTypeFloat: return boolToVMValue(v.Value.(float64) <= v2.Value.(float64)) } @@ -842,14 +843,14 @@ func (v *VMValue) OpCompGE(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(int64) >= v2.Value.(int64)) + return boolToVMValue(v.Value.(IntType) >= v2.Value.(IntType)) case VMTypeFloat: - return boolToVMValue(float64(v.Value.(int64)) >= v2.Value.(float64)) + return boolToVMValue(float64(v.Value.(IntType)) >= v2.Value.(float64)) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(float64) >= float64(v2.Value.(int64))) + return boolToVMValue(v.Value.(float64) >= float64(v2.Value.(IntType))) case VMTypeFloat: return boolToVMValue(v.Value.(float64) >= v2.Value.(float64)) } @@ -863,14 +864,14 @@ func (v *VMValue) OpCompGT(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(int64) > v2.Value.(int64)) + return boolToVMValue(v.Value.(IntType) > v2.Value.(IntType)) case VMTypeFloat: - return boolToVMValue(float64(v.Value.(int64)) > v2.Value.(float64)) + return boolToVMValue(float64(v.Value.(IntType)) > v2.Value.(float64)) } case VMTypeFloat: switch v2.TypeId { case VMTypeInt: - return boolToVMValue(v.Value.(float64) > float64(v2.Value.(int64))) + return boolToVMValue(v.Value.(float64) > float64(v2.Value.(IntType))) case VMTypeFloat: return boolToVMValue(v.Value.(float64) > v2.Value.(float64)) } @@ -884,7 +885,7 @@ func (v *VMValue) OpBitwiseAnd(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - return VMValueNewInt(v.Value.(int64) & v2.Value.(int64)) + return VMValueNewInt(v.Value.(IntType) & v2.Value.(IntType)) } } return nil @@ -895,7 +896,7 @@ func (v *VMValue) OpBitwiseOr(ctx *Context, v2 *VMValue) *VMValue { case VMTypeInt: switch v2.TypeId { case VMTypeInt: - return VMValueNewInt(v.Value.(int64) | v2.Value.(int64)) + return VMValueNewInt(v.Value.(IntType) | v2.Value.(IntType)) } } return nil @@ -904,7 +905,7 @@ func (v *VMValue) OpBitwiseOr(ctx *Context, v2 *VMValue) *VMValue { func (v *VMValue) OpPositive() *VMValue { switch v.TypeId { case VMTypeInt: - return VMValueNewInt(v.Value.(int64)) + return VMValueNewInt(v.Value.(IntType)) case VMTypeFloat: return VMValueNewFloat(v.Value.(float64)) } @@ -914,7 +915,7 @@ func (v *VMValue) OpPositive() *VMValue { func (v *VMValue) OpNegation() *VMValue { switch v.TypeId { case VMTypeInt: - return VMValueNewInt(-v.Value.(int64)) + return VMValueNewInt(-v.Value.(IntType)) case VMTypeFloat: return VMValueNewFloat(-v.Value.(float64)) } @@ -1047,7 +1048,7 @@ func (v *VMValue) ItemGet(ctx *Context, index *VMValue) *VMValue { rstr := []rune(str) rIndex := index.MustReadInt() - _index := getClampRealIndex(ctx, rIndex, int64(len(rstr))) + _index := getClampRealIndex(ctx, rIndex, IntType(len(rstr))) newArr := string(rstr[_index : _index+1]) return VMValueNewStr(newArr) @@ -1093,7 +1094,7 @@ func (v *VMValue) ItemSet(ctx *Context, index *VMValue, val *VMValue) bool { return false } -func getRealIndex(ctx *Context, index int64, length int64) int64 { +func getRealIndex(ctx *Context, index IntType, length IntType) IntType { if index < 0 { // 负数下标支持 index = length + index @@ -1104,7 +1105,7 @@ func getRealIndex(ctx *Context, index int64, length int64) int64 { return index } -func getClampRealIndex(ctx *Context, index int64, length int64) int64 { +func getClampRealIndex(ctx *Context, index IntType, length IntType) IntType { if index < 0 { // 负数下标支持 index = length + index @@ -1119,7 +1120,7 @@ func getClampRealIndex(ctx *Context, index int64, length int64) int64 { return index } -func (v *VMValue) GetSlice(ctx *Context, a int64, b int64, step int64) *VMValue { +func (v *VMValue) GetSlice(ctx *Context, a IntType, b IntType, step IntType) *VMValue { length := v.Length(ctx) if ctx.Error != nil { return nil @@ -1147,16 +1148,16 @@ func (v *VMValue) GetSlice(ctx *Context, a int64, b int64, step int64) *VMValue } } -func (v *VMValue) Length(ctx *Context) int64 { - var length int64 +func (v *VMValue) Length(ctx *Context) IntType { + var length IntType switch v.TypeId { case VMTypeArray: arr, _ := v.ReadArray() - length = int64(len(arr.List)) + length = IntType(len(arr.List)) case VMTypeString: str, _ := v.ReadString() - length = int64(len([]rune(str))) + length = IntType(len([]rune(str))) default: ctx.Error = errors.New("这个类型无法取得分片") return 0 @@ -1194,7 +1195,7 @@ func (v *VMValue) GetSliceEx(ctx *Context, a *VMValue, b *VMValue) *VMValue { return v.GetSlice(ctx, valA, valB, 1) } -func (v *VMValue) SetSlice(ctx *Context, a int64, b int64, step int64, val *VMValue) bool { +func (v *VMValue) SetSlice(ctx *Context, a, b, step IntType, val *VMValue) bool { arr, ok := v.ReadArray() if !ok { ctx.Error = errors.New("这个类型无法赋值分片") @@ -1205,7 +1206,7 @@ func (v *VMValue) SetSlice(ctx *Context, a int64, b int64, step int64, val *VMVa ctx.Error = errors.New("val 的类型必须是一个列表") return false } - length := int64(len(arr.List)) + length := IntType(len(arr.List)) _a := getClampRealIndex(ctx, a, length) _b := getClampRealIndex(ctx, b, length) @@ -1216,7 +1217,7 @@ func (v *VMValue) SetSlice(ctx *Context, a int64, b int64, step int64, val *VMVa offset := len(arr2.List) - int(_b-_a) newArr := make([]*VMValue, len(arr.List)+offset) - for i := int64(0); i < _a; i++ { + for i := IntType(0); i < _a; i++ { newArr[i] = arr.List[i] } @@ -1244,7 +1245,7 @@ func (v *VMValue) SetSliceEx(ctx *Context, a *VMValue, b *VMValue, val *VMValue) } if b.TypeId == VMTypeUndefined { - b = VMValueNewInt(int64(len(arr.List))) + b = VMValueNewInt(IntType(len(arr.List))) } valA, ok := a.ReadInt() @@ -1267,7 +1268,7 @@ func (v *VMValue) ArrayRepeatTimesEx(ctx *Context, times *VMValue) *VMValue { case VMTypeInt: times, _ := times.ReadInt() ad, _ := v.ReadArray() - length := int64(len(ad.List)) * times + length := IntType(len(ad.List)) * times if length > 512 { ctx.Error = errors.New("不能一次性创建过长的数组") @@ -1276,7 +1277,7 @@ func (v *VMValue) ArrayRepeatTimesEx(ctx *Context, times *VMValue) *VMValue { arr := make([]*VMValue, length) - for i := int64(0); i < length; i++ { + for i := IntType(0); i < length; i++ { arr[i] = ad.List[int(i)%len(ad.List)].Clone() } return VMValueNewArray(arr...) @@ -1287,9 +1288,9 @@ func (v *VMValue) ArrayRepeatTimesEx(ctx *Context, times *VMValue) *VMValue { func (v *VMValue) GetTypeName() string { switch v.TypeId { case VMTypeInt: - return "int64" + return "int" case VMTypeFloat: - return "float64" + return "float" case VMTypeString: return "str" case VMTypeUndefined: @@ -1508,12 +1509,12 @@ func ValueEqual(a *VMValue, b *VMValue, autoConvert bool) bool { case VMTypeInt: switch b.TypeId { case VMTypeFloat: - return float64(a.Value.(int64)) == b.Value.(float64) + return float64(a.Value.(IntType)) == b.Value.(float64) } case VMTypeFloat: switch b.TypeId { case VMTypeInt: - return a.Value.(float64) == float64(b.Value.(int64)) + return a.Value.(float64) == float64(b.Value.(IntType)) } } } @@ -1521,7 +1522,7 @@ func ValueEqual(a *VMValue, b *VMValue, autoConvert bool) bool { return false } -func VMValueNewInt(i int64) *VMValue { +func VMValueNewInt(i IntType) *VMValue { // TODO: 小整数可以处理为不可变对象,且一直停留在内存中,就像python那样。这可以避免很多内存申请 return &VMValue{TypeId: VMTypeInt, Value: i} } diff --git a/types_functions.go b/types_functions.go index 2b74cfd3..64f32458 100644 --- a/types_functions.go +++ b/types_functions.go @@ -33,10 +33,10 @@ func (d *VMDictValue) ToString() string { return d.V().ToString() } -func (v *VMValue) ArrayItemGet(ctx *Context, index int64) *VMValue { +func (v *VMValue) ArrayItemGet(ctx *Context, index IntType) *VMValue { if v.TypeId == VMTypeArray { arr, _ := v.ReadArray() - index = getRealIndex(ctx, index, int64(len(arr.List))) + index = getRealIndex(ctx, index, IntType(len(arr.List))) if ctx.Error != nil { return nil } @@ -46,10 +46,10 @@ func (v *VMValue) ArrayItemGet(ctx *Context, index int64) *VMValue { return nil } -func (v *VMValue) ArrayItemSet(ctx *Context, index int64, val *VMValue) bool { +func (v *VMValue) ArrayItemSet(ctx *Context, index IntType, val *VMValue) bool { if v.TypeId == VMTypeArray { arr, _ := v.ReadArray() - index = getRealIndex(ctx, index, int64(len(arr.List))) + index = getRealIndex(ctx, index, IntType(len(arr.List))) if ctx.Error != nil { return false } @@ -60,7 +60,7 @@ func (v *VMValue) ArrayItemSet(ctx *Context, index int64, val *VMValue) bool { return false } -func (v *VMValue) ArrayFuncKeepBase(ctx *Context, pickNum int64, orderType int) (isAllInt bool, ret float64) { +func (v *VMValue) ArrayFuncKeepBase(ctx *Context, pickNum IntType, orderType int) (isAllInt bool, ret float64) { arr, _ := v.ReadArray() var nums []float64 @@ -82,9 +82,9 @@ func (v *VMValue) ArrayFuncKeepBase(ctx *Context, pickNum int64, orderType int) } num := float64(0) - for i := int64(0); i < pickNum; i++ { + for i := IntType(0); i < pickNum; i++ { // 当取数大于上限 跳过 - if i >= int64(len(nums)) { + if i >= IntType(len(nums)) { continue } num += nums[i] @@ -93,10 +93,10 @@ func (v *VMValue) ArrayFuncKeepBase(ctx *Context, pickNum int64, orderType int) return isAllInt, num } -func (v *VMValue) ArrayFuncKeepHigh(ctx *Context, pickNum int64) (isAllInt bool, ret float64) { +func (v *VMValue) ArrayFuncKeepHigh(ctx *Context, pickNum IntType) (isAllInt bool, ret float64) { return v.ArrayFuncKeepBase(ctx, pickNum, 0) } -func (v *VMValue) ArrayFuncKeepLow(ctx *Context, pickNum int64) (isAllInt bool, ret float64) { +func (v *VMValue) ArrayFuncKeepLow(ctx *Context, pickNum IntType) (isAllInt bool, ret float64) { return v.ArrayFuncKeepBase(ctx, pickNum, 1) } diff --git a/types_functions_test.go b/types_functions_test.go index d528e596..f85b76dc 100644 --- a/types_functions_test.go +++ b/types_functions_test.go @@ -40,7 +40,7 @@ func TestTypesFuncArray(t *testing.T) { vm = NewVM() arr = na(ni(1), ni(2), ni(3)) arr.ArrayItemSet(vm, 1, ni(4)) - assert.Equal(t, arr.MustReadArray().List[1].MustReadInt(), int64(4)) + assert.Equal(t, arr.MustReadArray().List[1].MustReadInt(), IntType(4)) vm = NewVM() arr = na(ni(1), ni(2), ni(3)) diff --git a/types_methods.go b/types_methods.go index 9bf5ae0b..6c3c7910 100644 --- a/types_methods.go +++ b/types_methods.go @@ -8,7 +8,7 @@ import ( func funcArrayKeepLow(ctx *Context, this *VMValue, params []*VMValue) *VMValue { isAllInt, ret := this.ArrayFuncKeepLow(ctx, params[0].MustReadInt()) if isAllInt { - return VMValueNewInt(int64(ret)) + return VMValueNewInt(IntType(ret)) } else { return VMValueNewFloat(ret) } @@ -17,7 +17,7 @@ func funcArrayKeepLow(ctx *Context, this *VMValue, params []*VMValue) *VMValue { func funcArrayKeepHigh(ctx *Context, this *VMValue, params []*VMValue) *VMValue { isAllInt, ret := this.ArrayFuncKeepHigh(ctx, params[0].MustReadInt()) if isAllInt { - return VMValueNewInt(int64(ret)) + return VMValueNewInt(IntType(ret)) } else { return VMValueNewFloat(ret) } @@ -39,7 +39,7 @@ func funcArraySum(ctx *Context, this *VMValue, params []*VMValue) *VMValue { } if isAllInt { - return VMValueNewInt(int64(sumNum)) + return VMValueNewInt(IntType(sumNum)) } else { return VMValueNewFloat(sumNum) } @@ -47,7 +47,7 @@ func funcArraySum(ctx *Context, this *VMValue, params []*VMValue) *VMValue { func funcArrayLen(ctx *Context, this *VMValue, params []*VMValue) *VMValue { arr, _ := this.ReadArray() - return VMValueNewInt(int64(len(arr.List))) + return VMValueNewInt(IntType(len(arr.List))) } func funcArrayShuttle(ctx *Context, this *VMValue, params []*VMValue) *VMValue { @@ -139,7 +139,7 @@ func funcDictItems(ctx *Context, this *VMValue, params []*VMValue) *VMValue { func funcDictLen(ctx *Context, this *VMValue, params []*VMValue) *VMValue { d := this.MustReadDictData() - var size int64 + var size IntType d.Dict.Range(func(key string, value *VMValue) bool { size++ return true diff --git a/types_serialization.go b/types_serialization.go index 08959f0b..765761c1 100644 --- a/types_serialization.go +++ b/types_serialization.go @@ -169,7 +169,7 @@ func (v *VMValue) UnmarshalJSON(input []byte) error { switch v0.TypeId { case VMTypeInt: var v1 struct { - Value int64 `json:"value"` + Value IntType `json:"value"` } err := json.Unmarshal(input, &v1) if err == nil { diff --git a/types_serialization_test.go b/types_serialization_test.go index e46f5cb7..2f3d99e1 100644 --- a/types_serialization_test.go +++ b/types_serialization_test.go @@ -75,7 +75,7 @@ func TestLoads(t *testing.T) { v, err = VMValueFromJSON([]byte(`{"typeId":0,"value":123}`)) if assert.NoError(t, err) { assert.Equal(t, v.TypeId, VMTypeInt) - assert.Equal(t, int64(123), v.Value) + assert.Equal(t, IntType(123), v.Value) } v, err = VMValueFromJSON([]byte(`{"typeId":1,"value":3.2}`))