From d2e150948ad3d7a919cd5906b70844b508fab1f5 Mon Sep 17 00:00:00 2001 From: chenzhuoyu Date: Mon, 6 Sep 2021 16:25:38 +0800 Subject: [PATCH] fix: unmarshalers are always addressable --- decoder/assembler_amd64.go | 10 +-- decoder/assembler_test.go | 15 ++-- decoder/compiler.go | 156 +++++++++++++++++-------------------- decoder/compiler_test.go | 2 +- decoder/pools.go | 4 +- encoder/assembler_amd64.go | 10 +-- encoder/assembler_test.go | 20 ++--- encoder/compiler.go | 45 ++++++----- encoder/pools.go | 2 +- issue82_test.go | 41 ++++++++++ 10 files changed, 164 insertions(+), 141 deletions(-) create mode 100644 issue82_test.go diff --git a/decoder/assembler_amd64.go b/decoder/assembler_amd64.go index b40b0fd..301768e 100644 --- a/decoder/assembler_amd64.go +++ b/decoder/assembler_amd64.go @@ -182,10 +182,10 @@ var ( type _Assembler struct { jit.BaseAssembler - p *_Program + p _Program } -func newAssembler(p *_Program) *_Assembler { +func newAssembler(p _Program) *_Assembler { return new(_Assembler).Init(p) } @@ -195,7 +195,7 @@ func (self *_Assembler) Load() _Decoder { return ptodec(self.BaseAssembler.Load("json_decoder", _FP_size, _FP_args)) } -func (self *_Assembler) Init(p *_Program) *_Assembler { +func (self *_Assembler) Init(p _Program) *_Assembler { self.p = p self.BaseAssembler.Init(self.compile) return self @@ -286,14 +286,14 @@ func (self *_Assembler) instr(v *_Instr) { } func (self *_Assembler) instrs() { - for i, v := range self.p.ins { + for i, v := range self.p { self.Mark(i) self.instr(&v) } } func (self *_Assembler) epilogue() { - self.Mark(len(self.p.ins)) + self.Mark(len(self.p)) self.Emit("XORL", _ET, _ET) // XORL ET, ET self.Emit("XORL", _EP, _EP) // XORL EP, EP self.Link(_LB_error) // _error: diff --git a/decoder/assembler_test.go b/decoder/assembler_test.go index 1bf6e68..dc228c8 100644 --- a/decoder/assembler_test.go +++ b/decoder/assembler_test.go @@ -32,8 +32,7 @@ import ( ) func TestAssembler_PrologueAndEpilogue(t *testing.T) { - p := new(_Program) - a := newAssembler(p) + a := newAssembler(nil) _, e := a.Load()("", 0, nil, nil, 0) assert.Nil(t, e) } @@ -103,7 +102,7 @@ func init() { type testOps struct { key string - ins []_Instr + ins _Program src string pos int opt uint64 @@ -114,7 +113,7 @@ type testOps struct { } func testOpCode(t *testing.T, ops *testOps) { - p := &_Program{ins: ops.ins} + p := ops.ins k := new(_Stack) a := newAssembler(p) f := a.Load() @@ -671,7 +670,7 @@ type JsonStruct struct { func TestAssembler_DecodeStruct(t *testing.T) { var v JsonStruct s := `{"A": 123, "B": "asdf", "C": {"qwer": 4567}, "D": [1, 2, 3, 4, 5]}` - p, err := newCompiler().compile(reflect.TypeOf(v)) + p, err := make(_Compiler).compile(reflect.TypeOf(v)) require.NoError(t, err) k := new(_Stack) a := newAssembler(p) @@ -694,7 +693,7 @@ type Tx struct { func TestAssembler_DecodeStruct_SinglePrivateField(t *testing.T) { var v Tx s := `{"x": 1}` - p, err := newCompiler().compile(reflect.TypeOf(v)) + p, err := make(_Compiler).compile(reflect.TypeOf(v)) require.NoError(t, err) k := new(_Stack) a := newAssembler(p) @@ -708,7 +707,7 @@ func TestAssembler_DecodeStruct_SinglePrivateField(t *testing.T) { func TestAssembler_DecodeByteSlice_Bin(t *testing.T) { var v []byte s := `"aGVsbG8sIHdvcmxk"` - p, err := newCompiler().compile(reflect.TypeOf(v)) + p, err := make(_Compiler).compile(reflect.TypeOf(v)) require.NoError(t, err) k := new(_Stack) a := newAssembler(p) @@ -722,7 +721,7 @@ func TestAssembler_DecodeByteSlice_Bin(t *testing.T) { func TestAssembler_DecodeByteSlice_List(t *testing.T) { var v []byte s := `[104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100]` - p, err := newCompiler().compile(reflect.TypeOf(v)) + p, err := make(_Compiler).compile(reflect.TypeOf(v)) require.NoError(t, err) k := new(_Stack) a := newAssembler(p) diff --git a/decoder/compiler.go b/decoder/compiler.go index 23b433e..263ec50 100644 --- a/decoder/compiler.go +++ b/decoder/compiler.go @@ -388,63 +388,63 @@ func (self _Instr) formatStructFields() string { return strings.Join(r, ", ") } -type _Program struct { - ins []_Instr +type ( + _Program []_Instr +) + +func (self _Program) pc() int { + return len(self) } -func (self *_Program) pc() int { - return len(self.ins) -} - -func (self *_Program) tag(n int) { +func (self _Program) tag(n int) { if n >= _MaxStack { panic("type nesting too deep") } } -func (self *_Program) pin(i int) { - v := &self.ins[i] +func (self _Program) pin(i int) { + v := &self[i] v.u &= 0xffff000000000000 v.u |= rt.PackInt(self.pc()) } -func (self *_Program) rel(v []int) { +func (self _Program) rel(v []int) { for _, i := range v { self.pin(i) } } func (self *_Program) add(op _Op) { - self.ins = append(self.ins, newInsOp(op)) + *self = append(*self, newInsOp(op)) } func (self *_Program) int(op _Op, vi int) { - self.ins = append(self.ins, newInsVi(op, vi)) + *self = append(*self, newInsVi(op, vi)) } func (self *_Program) chr(op _Op, vb byte) { - self.ins = append(self.ins, newInsVb(op, vb)) + *self = append(*self, newInsVb(op, vb)) } func (self *_Program) tab(op _Op, vs []int) { - self.ins = append(self.ins, newInsVs(op, vs)) + *self = append(*self, newInsVs(op, vs)) } func (self *_Program) rtt(op _Op, vt reflect.Type) { - self.ins = append(self.ins, newInsVt(op, vt)) + *self = append(*self, newInsVt(op, vt)) } func (self *_Program) fmv(op _Op, vf *caching.FieldMap) { - self.ins = append(self.ins, newInsVf(op, vf)) + *self = append(*self, newInsVf(op, vf)) } -func (self *_Program) disassemble() string { - nb := len(self.ins) +func (self _Program) disassemble() string { + nb := len(self) tab := make([]bool, nb + 1) ret := make([]string, 0, nb + 1) /* prescan to get all the labels */ - for _, ins := range self.ins { + for _, ins := range self { if ins.isBranch() { if ins.op() != _OP_switch { tab[ins.vi()] = true @@ -457,7 +457,7 @@ func (self *_Program) disassemble() string { } /* disassemble each instruction */ - for i, ins := range self.ins { + for i, ins := range self { if !tab[i] { ret = append(ret, "\t" + ins.disassemble()) } else { @@ -474,18 +474,11 @@ func (self *_Program) disassemble() string { return strings.Join(append(ret, "\tend"), "\n") } -type _Compiler struct { - pv bool - tab map[reflect.Type]bool -} +type ( + _Compiler map[reflect.Type]bool +) -func newCompiler() *_Compiler { - return &_Compiler { - tab: map[reflect.Type]bool{}, - } -} - -func (self *_Compiler) rescue(ep *error) { +func (self _Compiler) rescue(ep *error) { if val := recover(); val != nil { if err, ok := val.(error); ok { *ep = err @@ -495,27 +488,24 @@ func (self *_Compiler) rescue(ep *error) { } } -func (self *_Compiler) compile(vt reflect.Type) (ret *_Program, err error) { - ret = &_Program{} +func (self _Compiler) compile(vt reflect.Type) (ret _Program, err error) { defer self.rescue(&err) - self.compileOne(ret, 0, vt, true) + self.compileOne(&ret, 0, vt) return } -func (self *_Compiler) compileOne(p *_Program, sp int, vt reflect.Type, pv bool) { - if self.tab[vt] { - p.rtt(_OP_recurse, vt) - } else { - self.compileRec(p, sp, vt, pv) - } -} - -func (self *_Compiler) compileRec(p *_Program, sp int, vt reflect.Type, pv bool) { - pr := self.pv +func (self _Compiler) compileOne(p *_Program, sp int, vt reflect.Type) { + ok := self[vt] pt := reflect.PtrTo(vt) - /* check for addressable `json.Unmarshaler` with pointer receiver */ - if pv && pt.Implements(jsonUnmarshalerType) { + /* check for recursive nesting */ + if ok { + p.rtt(_OP_recurse, vt) + return + } + + /* check for `json.Unmarshaler` with pointer receiver */ + if pt.Implements(jsonUnmarshalerType) { p.rtt(_OP_unmarshal_p, pt) return } @@ -527,8 +517,8 @@ func (self *_Compiler) compileRec(p *_Program, sp int, vt reflect.Type, pv bool) return } - /* check for addressable `encoding.TextMarshaler` with pointer receiver */ - if pv && pt.Implements(encodingTextUnmarshalerType) { + /* check for `encoding.TextMarshaler` with pointer receiver */ + if pt.Implements(encodingTextUnmarshalerType) { p.add(_OP_lspace) self.compileUnmarshalTextPtr(p, pt) return @@ -542,19 +532,13 @@ func (self *_Compiler) compileRec(p *_Program, sp int, vt reflect.Type, pv bool) } /* enter the recursion */ - self.pv = pv - self.tab[vt] = true - - /* compile the type */ p.add(_OP_lspace) + self[vt] = true self.compileOps(p, sp, vt) - - /* exit the recursion */ - self.pv = pr - delete(self.tab, vt) + delete(self, vt) } -func (self *_Compiler) compileOps(p *_Program, sp int, vt reflect.Type) { +func (self _Compiler) compileOps(p *_Program, sp int, vt reflect.Type) { switch vt.Kind() { case reflect.Bool : self.compilePrimitive (p, _OP_bool) case reflect.Int : self.compilePrimitive (p, _OP_int()) @@ -581,7 +565,7 @@ func (self *_Compiler) compileOps(p *_Program, sp int, vt reflect.Type) { } } -func (self *_Compiler) compileMap(p *_Program, sp int, vt reflect.Type) { +func (self _Compiler) compileMap(p *_Program, sp int, vt reflect.Type) { if reflect.PtrTo(vt.Key()).Implements(encodingTextUnmarshalerType) { self.compileMapOp(p, sp, vt, _OP_map_key_utext_p) } else if vt.Key().Implements(encodingTextUnmarshalerType) { @@ -591,7 +575,7 @@ func (self *_Compiler) compileMap(p *_Program, sp int, vt reflect.Type) { } } -func (self *_Compiler) compileMapUt(p *_Program, sp int, vt reflect.Type) { +func (self _Compiler) compileMapUt(p *_Program, sp int, vt reflect.Type) { switch vt.Key().Kind() { case reflect.Int : self.compileMapOp(p, sp, vt, _OP_map_key_int()) case reflect.Int8 : self.compileMapOp(p, sp, vt, _OP_map_key_i8) @@ -611,7 +595,7 @@ func (self *_Compiler) compileMapUt(p *_Program, sp int, vt reflect.Type) { } } -func (self *_Compiler) compileMapOp(p *_Program, sp int, vt reflect.Type, op _Op) { +func (self _Compiler) compileMapOp(p *_Program, sp int, vt reflect.Type, op _Op) { i := p.pc() p.add(_OP_is_null) p.tag(sp + 1) @@ -633,7 +617,7 @@ func (self *_Compiler) compileMapOp(p *_Program, sp int, vt reflect.Type, op _Op /* match the value separator */ p.add(_OP_lspace) p.chr(_OP_match_char, ':') - self.compileOne(p, sp + 2, vt.Elem(), false) + self.compileOne(p, sp + 2, vt.Elem()) p.add(_OP_load) k0 := p.pc() p.add(_OP_lspace) @@ -652,7 +636,7 @@ func (self *_Compiler) compileMapOp(p *_Program, sp int, vt reflect.Type, op _Op /* match the value separator */ p.add(_OP_lspace) p.chr(_OP_match_char, ':') - self.compileOne(p, sp + 2, vt.Elem(), false) + self.compileOne(p, sp + 2, vt.Elem()) p.add(_OP_load) p.int(_OP_goto, k0) p.pin(j) @@ -665,7 +649,7 @@ func (self *_Compiler) compileMapOp(p *_Program, sp int, vt reflect.Type, op _Op p.pin(x) } -func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) { +func (self _Compiler) compilePtr(p *_Program, sp int, et reflect.Type) { i := p.pc() p.add(_OP_is_null) @@ -676,7 +660,7 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) { } /* compile the element type */ - self.compileOne(p, sp + 1, et, true) + self.compileOne(p, sp + 1, et) j := p.pc() p.add(_OP_goto) p.pin(i) @@ -684,7 +668,7 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) { p.pin(j) } -func (self *_Compiler) compileArray(p *_Program, sp int, vt reflect.Type) { +func (self _Compiler) compileArray(p *_Program, sp int, vt reflect.Type) { x := p.pc() p.add(_OP_is_null) p.tag(sp) @@ -696,7 +680,7 @@ func (self *_Compiler) compileArray(p *_Program, sp int, vt reflect.Type) { /* decode every item */ for i := 1; i <= vt.Len(); i++ { - self.compileOne(p, sp + 1, vt.Elem(), self.pv) + self.compileOne(p, sp + 1, vt.Elem()) p.add(_OP_load) p.int(_OP_index, i * int(vt.Elem().Size())) p.add(_OP_lspace) @@ -724,7 +708,7 @@ func (self *_Compiler) compileArray(p *_Program, sp int, vt reflect.Type) { p.pin(x) } -func (self *_Compiler) compileSlice(p *_Program, sp int, et reflect.Type) { +func (self _Compiler) compileSlice(p *_Program, sp int, et reflect.Type) { if et.Kind() == byteType.Kind() { self.compileSliceBin(p, sp, et) } else { @@ -732,7 +716,7 @@ func (self *_Compiler) compileSlice(p *_Program, sp int, et reflect.Type) { } } -func (self *_Compiler) compileSliceBin(p *_Program, sp int, et reflect.Type) { +func (self _Compiler) compileSliceBin(p *_Program, sp int, et reflect.Type) { i := p.pc() p.add(_OP_is_null) j := p.pc() @@ -754,7 +738,7 @@ func (self *_Compiler) compileSliceBin(p *_Program, sp int, et reflect.Type) { p.pin(y) } -func (self *_Compiler) compileSliceList(p *_Program, sp int, et reflect.Type) { +func (self _Compiler) compileSliceList(p *_Program, sp int, et reflect.Type) { i := p.pc() p.add(_OP_is_null) p.tag(sp) @@ -767,14 +751,14 @@ func (self *_Compiler) compileSliceList(p *_Program, sp int, et reflect.Type) { p.pin(x) } -func (self *_Compiler) compileSliceBody(p *_Program, sp int, et reflect.Type) { +func (self _Compiler) compileSliceBody(p *_Program, sp int, et reflect.Type) { p.rtt(_OP_slice_init, et) p.add(_OP_save) p.add(_OP_lspace) j := p.pc() p.chr(_OP_check_char, ']') p.rtt(_OP_slice_append, et) - self.compileOne(p, sp + 1, et, true) + self.compileOne(p, sp + 1, et) p.add(_OP_load) k0 := p.pc() p.add(_OP_lspace) @@ -782,7 +766,7 @@ func (self *_Compiler) compileSliceBody(p *_Program, sp int, et reflect.Type) { p.chr(_OP_check_char, ']') p.chr(_OP_match_char, ',') p.rtt(_OP_slice_append, et) - self.compileOne(p, sp + 1, et, true) + self.compileOne(p, sp + 1, et) p.add(_OP_load) p.int(_OP_goto, k0) p.pin(j) @@ -790,7 +774,7 @@ func (self *_Compiler) compileSliceBody(p *_Program, sp int, et reflect.Type) { p.add(_OP_drop) } -func (self *_Compiler) compileString(p *_Program, vt reflect.Type) { +func (self _Compiler) compileString(p *_Program, vt reflect.Type) { if vt == jsonNumberType { self.compilePrimitive(p, _OP_num) } else { @@ -798,7 +782,7 @@ func (self *_Compiler) compileString(p *_Program, vt reflect.Type) { } } -func (self *_Compiler) compileStringBody(p *_Program) { +func (self _Compiler) compileStringBody(p *_Program) { i := p.pc() p.add(_OP_is_null) p.chr(_OP_match_char, '"') @@ -806,15 +790,15 @@ func (self *_Compiler) compileStringBody(p *_Program) { p.pin(i) } -func (self *_Compiler) compileStruct(p *_Program, sp int, vt reflect.Type) { - if sp >= _MAX_STACK || len(p.ins) >= _MAX_ILBUF { +func (self _Compiler) compileStruct(p *_Program, sp int, vt reflect.Type) { + if sp >= _MAX_STACK || p.pc() >= _MAX_ILBUF { p.rtt(_OP_recurse, vt) } else { self.compileStructBody(p, sp, vt) } } -func (self *_Compiler) compileStructBody(p *_Program, sp int, vt reflect.Type) { +func (self _Compiler) compileStructBody(p *_Program, sp int, vt reflect.Type) { fv := resolver.ResolveStruct(vt) fm, sw := caching.CreateFieldMap(len(fv)), make([]int, len(fv)) @@ -869,7 +853,7 @@ func (self *_Compiler) compileStructBody(p *_Program, sp int, vt reflect.Type) { /* check for "stringnize" option */ if (f.Opts & resolver.F_stringize) == 0 { - self.compileOne(p, sp + 1, f.Type, self.pv) + self.compileOne(p, sp + 1, f.Type) } else { self.compileStructFieldStr(p, sp + 1, f.Type) } @@ -886,7 +870,7 @@ end_of_object: p.pin(n) } -func (self *_Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Type) { +func (self _Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Type) { n1 := -1 ft := vt sv := false @@ -917,7 +901,7 @@ func (self *_Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Typ /* if it's not, ignore the "string" and follow the regular path */ if !sv { - self.compileOne(p, sp, vt, self.pv) + self.compileOne(p, sp, vt) return } @@ -993,7 +977,7 @@ func (self *_Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Typ p.pin(pc) } -func (self *_Compiler) compileInterface(p *_Program, vt reflect.Type) { +func (self _Compiler) compileInterface(p *_Program, vt reflect.Type) { i := p.pc() p.add(_OP_is_null) @@ -1012,14 +996,14 @@ func (self *_Compiler) compileInterface(p *_Program, vt reflect.Type) { p.pin(j) } -func (self *_Compiler) compilePrimitive(p *_Program, op _Op) { +func (self _Compiler) compilePrimitive(p *_Program, op _Op) { i := p.pc() p.add(_OP_is_null) p.add(op) p.pin(i) } -func (self *_Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int) { +func (self _Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int) { j := p.pc() k := vt.Kind() @@ -1036,7 +1020,7 @@ func (self *_Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int) p.pin(j) } -func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) { +func (self _Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) { i := p.pc() v := _OP_unmarshal p.add(_OP_is_null) @@ -1051,7 +1035,7 @@ func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) { self.compileUnmarshalEnd(p, vt, i) } -func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) { +func (self _Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) { i := p.pc() v := _OP_unmarshal_text p.add(_OP_is_null) @@ -1068,7 +1052,7 @@ func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) { self.compileUnmarshalEnd(p, vt, i) } -func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type) { +func (self _Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type) { i := p.pc() p.add(_OP_is_null) p.chr(_OP_match_char, '"') diff --git a/decoder/compiler_test.go b/decoder/compiler_test.go index 1fb20c8..2c168cb 100644 --- a/decoder/compiler_test.go +++ b/decoder/compiler_test.go @@ -24,7 +24,7 @@ import ( ) func TestCompiler_Compile(t *testing.T) { - prg, err := newCompiler().compile(reflect.TypeOf(TwitterStruct{})) + prg, err := make(_Compiler).compile(reflect.TypeOf(TwitterStruct{})) assert.Nil(t, err) println(prg.disassemble()) } diff --git a/decoder/pools.go b/decoder/pools.go index c53ad40..e0c68fc 100644 --- a/decoder/pools.go +++ b/decoder/pools.go @@ -89,7 +89,7 @@ func referenceFields(v *caching.FieldMap) int64 { func findOrCompile(vt *rt.GoType) (_Decoder, error) { var ex error var fn _Decoder - var pp *_Program + var pp _Program var fv interface{} /* fast path: the program is in the cache */ @@ -98,7 +98,7 @@ func findOrCompile(vt *rt.GoType) (_Decoder, error) { } /* slow path: not found, compile the type on the fly */ - if pp, ex = newCompiler().compile(vt.Pack()); ex != nil { + if pp, ex = make(_Compiler).compile(vt.Pack()); ex != nil { return nil, ex } diff --git a/encoder/assembler_amd64.go b/encoder/assembler_amd64.go index e839800..d25c60b 100644 --- a/encoder/assembler_amd64.go +++ b/encoder/assembler_amd64.go @@ -168,11 +168,11 @@ var ( type _Assembler struct { jit.BaseAssembler - p *_Program + p _Program x int } -func newAssembler(p *_Program) *_Assembler { +func newAssembler(p _Program) *_Assembler { return new(_Assembler).Init(p) } @@ -182,7 +182,7 @@ func (self *_Assembler) Load() _Encoder { return ptoenc(self.BaseAssembler.Load("json_encoder", _FP_size, _FP_args)) } -func (self *_Assembler) Init(p *_Program) *_Assembler { +func (self *_Assembler) Init(p _Program) *_Assembler { self.p = p self.BaseAssembler.Init(self.compile) return self @@ -259,7 +259,7 @@ func (self *_Assembler) instr(v *_Instr) { } func (self *_Assembler) instrs() { - for i, v := range self.p.ins { + for i, v := range self.p { self.Mark(i) self.instr(&v) } @@ -273,7 +273,7 @@ func (self *_Assembler) builtins() { } func (self *_Assembler) epilogue() { - self.Mark(len(self.p.ins)) + self.Mark(len(self.p)) self.Emit("XORL", _ET, _ET) self.Emit("XORL", _EP, _EP) self.Link(_LB_error) diff --git a/encoder/assembler_test.go b/encoder/assembler_test.go index ad54c63..79a592b 100644 --- a/encoder/assembler_test.go +++ b/encoder/assembler_test.go @@ -66,14 +66,14 @@ func TestAssembler_CompileAndLoad(t *testing.T) { type testOps struct { key string - ins []_Instr + ins _Program exp string err error val interface{} } -func testOpCode(t *testing.T, v interface{}, ex string, err error, ins []_Instr) { - p := &_Program{ins: ins} +func testOpCode(t *testing.T, v interface{}, ex string, err error, ins _Program) { + p := ins m := []byte(nil) s := new(_Stack) a := newAssembler(p) @@ -105,7 +105,7 @@ type RecursiveValue struct { Z int `json:"z"` } -func mustCompile(t interface{}) *_Program { +func mustCompile(t interface{}) _Program { p, err := newCompiler().compile(reflect.TypeOf(t)) if err != nil { panic(err) @@ -297,12 +297,12 @@ func TestAssembler_OpCode(t *testing.T) { val: nil, }, { key: "_OP_map_[iter,next,value]", - ins: mustCompile(map[string]map[int64]int{}).ins, + ins: mustCompile(map[string]map[int64]int{}), exp: `{"asdf":{"-9223372036854775808":1234}}`, val: &map[string]map[int64]int{"asdf": {math.MinInt64: 1234}}, }, { key: "_OP_slice_[len,next]", - ins: mustCompile([][]int{}).ins, + ins: mustCompile([][]int{}), exp: `[[1,2,3],[4,5,6]]`, val: &[][]int{{1, 2, 3}, {4, 5, 6}}, }, { @@ -327,7 +327,7 @@ func TestAssembler_OpCode(t *testing.T) { val: &jifp, }, { key: "_OP_recurse", - ins: mustCompile(rec).ins, + ins: mustCompile(rec), exp: `{"a":123,"p":{"a":789,"p":{"a":777,"q":[{"a":999,"q":null,"r":{"` + `xxx":{"a":333,"q":null,"r":null,"z":0}},"z":222}],"r":null,"z":8` + `88},"q":null,"r":null,"z":666},"q":null,"r":null,"z":456}`, @@ -341,7 +341,7 @@ func TestAssembler_OpCode(t *testing.T) { } func TestAssembler_StringMoreSpace(t *testing.T) { - p := &_Program{ins: []_Instr{newInsOp(_OP_str)}} + p := _Program{newInsOp(_OP_str)} m := make([]byte, 0, 8) s := new(_Stack) a := newAssembler(p) @@ -353,7 +353,7 @@ func TestAssembler_StringMoreSpace(t *testing.T) { } func TestAssembler_TwitterJSON_Generic(t *testing.T) { - p := &_Program{ins: mustCompile(&_GenericValue).ins} + p := mustCompile(&_GenericValue) m := []byte(nil) s := new(_Stack) a := newAssembler(p) @@ -365,7 +365,7 @@ func TestAssembler_TwitterJSON_Generic(t *testing.T) { } func TestAssembler_TwitterJSON_Structure(t *testing.T) { - p := &_Program{ins: mustCompile(_BindingValue).ins} + p := mustCompile(_BindingValue) m := []byte(nil) s := new(_Stack) a := newAssembler(p) diff --git a/encoder/compiler.go b/encoder/compiler.go index 775e98e..89d9430 100644 --- a/encoder/compiler.go +++ b/encoder/compiler.go @@ -298,38 +298,38 @@ func (self _Instr) disassemble() string { } } -type _Program struct { - ins []_Instr +type ( + _Program []_Instr +) + +func (self _Program) pc() int { + return len(self) } -func (self *_Program) pc() int { - return len(self.ins) -} - -func (self *_Program) tag(n int) { +func (self _Program) tag(n int) { if n >= _MaxStack { panic("type nesting too deep") } } -func (self *_Program) pin(i int) { - v := &self.ins[i] +func (self _Program) pin(i int) { + v := &self[i] v.u &= 0xffff000000000000 v.u |= rt.PackInt(self.pc()) } -func (self *_Program) rel(v []int) { +func (self _Program) rel(v []int) { for _, i := range v { self.pin(i) } } func (self *_Program) add(op _Op) { - self.ins = append(self.ins, newInsOp(op)) + *self = append(*self, newInsOp(op)) } func (self *_Program) key(op _Op) { - self.ins = append(self.ins, + *self = append(*self, newInsVi(_OP_byte, '"'), newInsOp(op), newInsVi(_OP_byte, '"'), @@ -337,31 +337,31 @@ func (self *_Program) key(op _Op) { } func (self *_Program) int(op _Op, vi int) { - self.ins = append(self.ins, newInsVi(op, vi)) + *self = append(*self, newInsVi(op, vi)) } func (self *_Program) str(op _Op, vs string) { - self.ins = append(self.ins, newInsVs(op, vs)) + *self = append(*self, newInsVs(op, vs)) } func (self *_Program) rtt(op _Op, vt reflect.Type) { - self.ins = append(self.ins, newInsVt(op, vt)) + *self = append(*self, newInsVt(op, vt)) } -func (self *_Program) disassemble() string { - nb := len(self.ins) +func (self _Program) disassemble() string { + nb := len(self) tab := make([]bool, nb + 1) ret := make([]string, 0, nb + 1) /* prescan to get all the labels */ - for _, ins := range self.ins { + for _, ins := range self { if ins.isBranch() { tab[ins.vi()] = true } } /* disassemble each instruction */ - for i, ins := range self.ins { + for i, ins := range self { if !tab[i] { ret = append(ret, "\t" + ins.disassemble()) } else { @@ -399,10 +399,9 @@ func (self *_Compiler) rescue(ep *error) { } } -func (self *_Compiler) compile(vt reflect.Type) (ret *_Program, err error) { - ret = &_Program{} +func (self *_Compiler) compile(vt reflect.Type) (ret _Program, err error) { defer self.rescue(&err) - self.compileOne(ret, 0, vt, false) + self.compileOne(&ret, 0, vt, false) return } @@ -652,7 +651,7 @@ func (self *_Compiler) compileString(p *_Program, vt reflect.Type) { } func (self *_Compiler) compileStruct(p *_Program, sp int, vt reflect.Type) { - if sp >= _MAX_STACK || len(p.ins) >= _MAX_ILBUF { + if sp >= _MAX_STACK || p.pc() >= _MAX_ILBUF { p.rtt(_OP_recurse, vt) } else { self.compileStructBody(p, sp, vt) diff --git a/encoder/pools.go b/encoder/pools.go index f5b854a..fbaf751 100644 --- a/encoder/pools.go +++ b/encoder/pools.go @@ -98,7 +98,7 @@ func freeBuffer(p *bytes.Buffer) { func findOrCompile(vt *rt.GoType) (_Encoder, error) { var ex error var fn _Encoder - var pp *_Program + var pp _Program var fv interface{} /* fast path: the program is in the cache */ diff --git a/issue82_test.go b/issue82_test.go new file mode 100644 index 0000000..7920d71 --- /dev/null +++ b/issue82_test.go @@ -0,0 +1,41 @@ +/* + * Copyright 2021 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package sonic + +import ( + `testing` + + `github.com/bytedance/sonic/decoder` + `github.com/stretchr/testify/require` +) + +type Issue82String string + +func (s *Issue82String) UnmarshalJSON(b []byte) error { + *s = Issue82String(b) + return nil +} + +func TestIssue82_MapValueIsStringUnmarshaler(t *testing.T) { + var v map[string]Issue82String + err := Unmarshal([]byte(`{"a":123}`), &v) + if err != nil { + println(err.(decoder.SyntaxError).Description()) + require.NoError(t, err) + } + require.Equal(t, map[string]Issue82String{"a": "123"}, v) +}