From 2138136685af26369b9c336858dae2ee73189132 Mon Sep 17 00:00:00 2001 From: Yi Duan Date: Thu, 8 Sep 2022 16:31:38 +0800 Subject: [PATCH] feat: support more loose type-casting (#294) * feat: support more losing type cast * test: add loose casting tests * format * fmt: add license * fmt: add comments --- ast/compat_test.go | 63 +++++++++++++ ast/node.go | 229 +++++++++++++++++++++++++++++++++++---------- ast/node_test.go | 125 +++++++++++++++++++++---- 3 files changed, 351 insertions(+), 66 deletions(-) create mode 100644 ast/compat_test.go diff --git a/ast/compat_test.go b/ast/compat_test.go new file mode 100644 index 0000000..991f73a --- /dev/null +++ b/ast/compat_test.go @@ -0,0 +1,63 @@ +/* + * Copyright 2022 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ast + +import ( + `testing` + + jsoniter `github.com/json-iterator/go` + `github.com/stretchr/testify/require` + `github.com/tidwall/gjson` +) + +func TestNotFoud(t *testing.T) { + data := `{}` + + ia := jsoniter.Get([]byte(data), "b") + require.Error(t, ia.LastError()) + require.Equal(t, false, ia.ToBool()) + + ga := gjson.GetBytes([]byte(data), "b") + require.True(t, ga.Type == gjson.Null) + require.Equal(t, false, ga.Bool()) + + sa, err := NewSearcher(data).GetByPath("b") + require.True(t, sa.Type() == V_NONE) + require.Error(t, err) + sv, err := sa.Bool() + require.Error(t, err) + require.Equal(t, false, sv) +} + +func TestNull(t *testing.T) { + data := `{"b": null}` + + ia := jsoniter.Get([]byte(data), "b") + require.NoError(t, ia.LastError()) + require.Equal(t, false, ia.ToBool()) + + ga := gjson.GetBytes([]byte(data), "b") + require.True(t, ga.Type == gjson.Null) + require.Equal(t, false, ga.Bool()) + + sa, err := NewSearcher(data).GetByPath("b") + require.True(t, sa.Type() == V_NULL) + require.NoError(t, err) + sv, err := sa.Bool() + require.NoError(t, err) + require.Equal(t, false, sv) +} \ No newline at end of file diff --git a/ast/node.go b/ast/node.go index 652ef24..5cc91b1 100644 --- a/ast/node.go +++ b/ast/node.go @@ -19,8 +19,9 @@ package ast import ( `encoding/json` `fmt` + `strconv` `unsafe` - + `github.com/bytedance/sonic/decoder` `github.com/bytedance/sonic/internal/native/types` `github.com/bytedance/sonic/internal/rt` @@ -37,7 +38,7 @@ const ( const ( _V_NONE types.ValueType = 0 - _V_NODE_BASE types.ValueType = 1<<5 + _V_NODE_BASE types.ValueType = 1 << 5 _V_LAZY types.ValueType = 1 << 7 _V_RAW types.ValueType = 1 << 8 _V_NUMBER = _V_NODE_BASE + 1 @@ -165,11 +166,9 @@ func (self *Node) checkRaw() error { return nil } -// Bool returns bool value represented by this node -// -// If node type is not types.V_TRUE or types.V_FALSE, -// V_RAW (must be a bool json value), or V_ANY (must be a bool type) -// it will return error +// Bool returns bool value represented by this node, +// including types.V_TRUE|V_FALSE|V_NUMBER|V_STRING|V_ANY|V_NULL, +// V_NONE will return error func (self *Node) Bool() (bool, error) { if err := self.checkRaw(); err != nil { return false, err @@ -178,40 +177,97 @@ func (self *Node) Bool() (bool, error) { case types.V_TRUE : return true , nil case types.V_FALSE : return false, nil case types.V_NULL : return false, nil - case _V_ANY : - if v, ok := self.packAny().(bool); ok { - return v, nil + case _V_NUMBER : + if i, err := numberToInt64(self); err == nil { + return i != 0, nil + } else if f, err := numberToFloat64(self); err == nil { + return f != 0, nil } else { - return false, ErrUnsupportType + return false, err + } + case types.V_STRING: return strconv.ParseBool(addr2str(self.p, self.v)) + case _V_ANY : + any := self.packAny() + switch v := any.(type) { + case bool : return v, nil + case int : return v != 0, nil + case int8 : return v != 0, nil + case int16 : return v != 0, nil + case int32 : return v != 0, nil + case int64 : return v != 0, nil + case uint : return v != 0, nil + case uint8 : return v != 0, nil + case uint16 : return v != 0, nil + case uint32 : return v != 0, nil + case uint64 : return v != 0, nil + case float32: return v != 0, nil + case float64: return v != 0, nil + case string : return strconv.ParseBool(v) + case json.Number: + if i, err := v.Int64(); err == nil { + return i != 0, nil + } else if f, err := v.Float64(); err == nil { + return f != 0, nil + } else { + return false, err + } + default: return false, ErrUnsupportType } default : return false, ErrUnsupportType } } -// Int64 casts the node to int64 value, including V_NUMBER, V_TRUE, V_FALSE, V_ANY, -// V_STRING of invalid digits +// Int64 casts the node to int64 value, +// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING +// V_NONE it will return error func (self *Node) Int64() (int64, error) { if err := self.checkRaw(); err != nil { return 0, err } switch self.t { - case _V_NUMBER, types.V_STRING : return numberToInt64(self) + case _V_NUMBER, types.V_STRING : + if i, err := numberToInt64(self); err == nil { + return i, nil + } else if f, err := numberToFloat64(self); err == nil { + return int64(f), nil + } else { + return 0, err + } case types.V_TRUE : return 1, nil case types.V_FALSE : return 0, nil case types.V_NULL : return 0, nil case _V_ANY : any := self.packAny() switch v := any.(type) { - case int : return int64(v), nil - case int8 : return int64(v), nil - case int16 : return int64(v), nil - case int32 : return int64(v), nil - case int64 : return int64(v), nil - case uint : return int64(v), nil - case uint8 : return int64(v), nil - case uint16: return int64(v), nil - case uint32: return int64(v), nil - case uint64: return int64(v), nil + case bool : if v { return 1, nil } else { return 0, nil } + case int : return int64(v), nil + case int8 : return int64(v), nil + case int16 : return int64(v), nil + case int32 : return int64(v), nil + case int64 : return int64(v), nil + case uint : return int64(v), nil + case uint8 : return int64(v), nil + case uint16 : return int64(v), nil + case uint32 : return int64(v), nil + case uint64 : return int64(v), nil + case float32: return int64(v), nil + case float64: return int64(v), nil + case string : + if i, err := strconv.ParseInt(v, 10, 64); err == nil { + return i, nil + } else if f, err := strconv.ParseFloat(v, 64); err == nil { + return int64(f), nil + } else { + return 0, err + } + case json.Number: + if i, err := v.Int64(); err == nil { + return i, nil + } else if f, err := v.Float64(); err == nil { + return int64(f), nil + } else { + return 0, err + } default: return 0, ErrUnsupportType } default : return 0, ErrUnsupportType @@ -238,14 +294,29 @@ func (self *Node) StrictInt64() (int64, error) { case uint16: return int64(v), nil case uint32: return int64(v), nil case uint64: return int64(v), nil + case json.Number: + if i, err := v.Int64(); err == nil { + return i, nil + } else { + return 0, err + } default: return 0, ErrUnsupportType } default : return 0, ErrUnsupportType } } -// Number casts node to float64, including V_NUMBER, V_TRUE, V_FALSE, V_ANY of json.Number, -// V_STRING of invalid digits +func castNumber(v bool) json.Number { + if v { + return json.Number("1") + } else { + return json.Number("0") + } +} + +// Number casts node to float64, +// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING|V_NULL, +// V_NONE it will return error func (self *Node) Number() (json.Number, error) { if err := self.checkRaw(); err != nil { return json.Number(""), err @@ -264,10 +335,29 @@ func (self *Node) Number() (json.Number, error) { case types.V_FALSE : return json.Number("0"), nil case types.V_NULL : return json.Number("0"), nil case _V_ANY : - if v, ok := self.packAny().(json.Number); ok { - return v, nil - } else { - return json.Number(""), ErrUnsupportType + any := self.packAny() + switch v := any.(type) { + case bool : return castNumber(v), nil + case int : return castNumber(v != 0), nil + case int8 : return castNumber(v != 0), nil + case int16 : return castNumber(v != 0), nil + case int32 : return castNumber(v != 0), nil + case int64 : return castNumber(v != 0), nil + case uint : return castNumber(v != 0), nil + case uint8 : return castNumber(v != 0), nil + case uint16 : return castNumber(v != 0), nil + case uint32 : return castNumber(v != 0), nil + case uint64 : return castNumber(v != 0), nil + case float32: return castNumber(v != 0), nil + case float64: return castNumber(v != 0), nil + case string : + if _, err := strconv.ParseFloat(v, 64); err == nil { + return json.Number(v), nil + } else { + return json.Number(""), err + } + case json.Number: return v, nil + default: return json.Number(""), ErrUnsupportType } default : return json.Number(""), ErrUnsupportType } @@ -290,29 +380,38 @@ func (self *Node) StrictNumber() (json.Number, error) { } } -// String returns raw string value if node type is V_STRING. -// Or return the string representation of other types: -// V_NULL => "", -// V_TRUE => "true", -// V_FALSE => "false", -// V_NUMBER => "[0-9\.]*" -// V_ANY => interface{}.(string) +// String cast node to string, +// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING|V_NULL, +// V_NONE it will return error func (self *Node) String() (string, error) { if err := self.checkRaw(); err != nil { return "", err } switch self.t { - case _V_NUMBER : return toNumber(self).String(), nil case types.V_NULL : return "" , nil case types.V_TRUE : return "true" , nil case types.V_FALSE : return "false", nil - case types.V_STRING : return addr2str(self.p, self.v), nil - case _V_ANY : - if v, ok := self.packAny().(string); ok { - return v, nil - } else { - return "", ErrUnsupportType - } + case types.V_STRING, _V_NUMBER : return addr2str(self.p, self.v), nil + case _V_ANY : + any := self.packAny() + switch v := any.(type) { + case bool : return strconv.FormatBool(v), nil + case int : return strconv.Itoa(v), nil + case int8 : return strconv.Itoa(int(v)), nil + case int16 : return strconv.Itoa(int(v)), nil + case int32 : return strconv.Itoa(int(v)), nil + case int64 : return strconv.Itoa(int(v)), nil + case uint : return strconv.Itoa(int(v)), nil + case uint8 : return strconv.Itoa(int(v)), nil + case uint16 : return strconv.Itoa(int(v)), nil + case uint32 : return strconv.Itoa(int(v)), nil + case uint64 : return strconv.Itoa(int(v)), nil + case float32: return strconv.FormatFloat(float64(v), 'g', -1, 64), nil + case float64: return strconv.FormatFloat(float64(v), 'g', -1, 64), nil + case string : return v, nil + case json.Number: return v.String(), nil + default: return "", ErrUnsupportType + } default : return "" , ErrUnsupportType } } @@ -335,7 +434,9 @@ func (self *Node) StrictString() (string, error) { } } -// Float64 casts node to float64, includeing V_NUMBER, V_TRUE, V_FALSE, V_ANY +// Float64 cast node to float64, +// including V_NUMBER|V_TRUE|V_FALSE|V_ANY|V_STRING|V_NULL, +// V_NONE it will return error func (self *Node) Float64() (float64, error) { if err := self.checkRaw(); err != nil { return 0.0, err @@ -348,11 +449,39 @@ func (self *Node) Float64() (float64, error) { case _V_ANY : any := self.packAny() switch v := any.(type) { - case float32 : return float64(v), nil - case float64 : return float64(v), nil - default : return 0, ErrUnsupportType + case bool : + if v { + return 1.0, nil + } else { + return 0.0, nil + } + case int : return float64(v), nil + case int8 : return float64(v), nil + case int16 : return float64(v), nil + case int32 : return float64(v), nil + case int64 : return float64(v), nil + case uint : return float64(v), nil + case uint8 : return float64(v), nil + case uint16 : return float64(v), nil + case uint32 : return float64(v), nil + case uint64 : return float64(v), nil + case float32: return float64(v), nil + case float64: return float64(v), nil + case string : + if f, err := strconv.ParseFloat(v, 64); err == nil { + return float64(f), nil + } else { + return 0, err + } + case json.Number: + if f, err := v.Float64(); err == nil { + return float64(f), nil + } else { + return 0, err + } + default : return 0, ErrUnsupportType } - default : return 0.0, ErrUnsupportType + default : return 0.0, ErrUnsupportType } } diff --git a/ast/node_test.go b/ast/node_test.go index 1173e5e..5a46a38 100644 --- a/ast/node_test.go +++ b/ast/node_test.go @@ -280,22 +280,60 @@ func TestTypeCast(t *testing.T) { {"Bool", Node{}, false, ErrUnsupportType}, {"Bool", NewAny(true), true, nil}, {"Bool", NewAny(false), false, nil}, + {"Bool", NewAny(int(0)), false, nil}, + {"Bool", NewAny(int8(1)), true, nil}, + {"Bool", NewAny(int16(1)), true, nil}, + {"Bool", NewAny(int32(1)), true, nil}, + {"Bool", NewAny(int64(1)), true, nil}, + {"Bool", NewAny(uint(1)), true, nil}, + {"Bool", NewAny(uint16(1)), true, nil}, + {"Bool", NewAny(uint32(1)), true, nil}, + {"Bool", NewAny(uint64(1)), true, nil}, + {"Bool", NewAny(float64(0)), false, nil}, + {"Bool", NewAny(float32(1)), true, nil}, + {"Bool", NewAny(float64(1)), true, nil}, + {"Bool", NewAny(json.Number("0")), false, nil}, + {"Bool", NewAny(json.Number("1")), true, nil}, + {"Bool", NewAny(json.Number("1.1")), true, nil}, + {"Bool", NewAny(json.Number("+x1.1")), false, nonEmptyErr}, + {"Bool", NewAny(string("0")), false, nil}, + {"Bool", NewAny(string("t")), true, nil}, + {"Bool", NewAny([]byte{0}), false, nonEmptyErr}, {"Bool", NewRaw("true"), true, nil}, {"Bool", NewRaw("false"), false, nil}, {"Bool", NewRaw("null"), false, nil}, + {"Bool", NewString(`true`), true, nil}, + {"Bool", NewString(`false`), false, nil}, + {"Bool", NewString(``), false, nonEmptyErr}, + {"Bool", NewNumber("2"), true, nil}, + {"Bool", NewNumber("-2.1"), true, nil}, + {"Bool", NewNumber("-x-2.1"), false, nonEmptyErr}, {"Int64", NewRaw("true"), int64(1), nil}, {"Int64", NewRaw("false"), int64(0), nil}, {"Int64", NewRaw("\"1\""), int64(1), nil}, - {"Int64", NewRaw("\"1.0\""), int64(0), nonEmptyErr}, - {"Int64", NewAny(int(0)), int64(0), nil}, - {"Int64", NewAny(int8(0)), int64(0), nil}, - {"Int64", NewAny(int16(0)), int64(0), nil}, - {"Int64", NewAny(int32(0)), int64(0), nil}, - {"Int64", NewAny(int64(0)), int64(0), nil}, - {"Int64", NewAny(uint(0)), int64(0), nil}, - {"Int64", NewAny(uint8(0)), int64(0), nil}, - {"Int64", NewAny(uint32(0)), int64(0), nil}, - {"Int64", NewAny(uint64(0)), int64(0), nil}, + {"Int64", NewRaw("\"1.1\""), int64(1), nil}, + {"Int64", NewRaw("\"1.0\""), int64(1), nil}, + {"Int64", NewNumber("+x.0"), int64(0), nonEmptyErr}, + {"Int64", NewAny(false), int64(0), nil}, + {"Int64", NewAny(true), int64(1), nil}, + {"Int64", NewAny(int(1)), int64(1), nil}, + {"Int64", NewAny(int8(1)), int64(1), nil}, + {"Int64", NewAny(int16(1)), int64(1), nil}, + {"Int64", NewAny(int32(1)), int64(1), nil}, + {"Int64", NewAny(int64(1)), int64(1), nil}, + {"Int64", NewAny(uint(1)), int64(1), nil}, + {"Int64", NewAny(uint8(1)), int64(1), nil}, + {"Int64", NewAny(uint32(1)), int64(1), nil}, + {"Int64", NewAny(uint64(1)), int64(1), nil}, + {"Int64", NewAny(float32(1)), int64(1), nil}, + {"Int64", NewAny(float64(1)), int64(1), nil}, + {"Int64", NewAny("1"), int64(1), nil}, + {"Int64", NewAny("1.1"), int64(1), nil}, + {"Int64", NewAny("+1x.1"), int64(0), nonEmptyErr}, + {"Int64", NewAny(json.Number("1")), int64(1), nil}, + {"Int64", NewAny(json.Number("1.1")), int64(1), nil}, + {"Int64", NewAny(json.Number("+1x.1")), int64(0), nonEmptyErr}, + {"Int64", NewAny([]byte{0}), int64(0), ErrUnsupportType}, {"Int64", Node{}, int64(0), ErrUnsupportType}, {"Int64", NewRaw("0"), int64(0), nil}, {"Int64", NewRaw("null"), int64(0), nil}, @@ -318,9 +356,26 @@ func TestTypeCast(t *testing.T) { {"Float64", NewRaw("\"1.0\""), float64(1.0), nil}, {"Float64", NewRaw("\"xx\""), float64(0), nonEmptyErr}, {"Float64", Node{}, float64(0), ErrUnsupportType}, - {"Float64", NewAny(float32(0)), float64(0), nil}, - {"Float64", NewAny(float64(0)), float64(0), nil}, + {"Float64", NewAny(false), float64(0), nil}, + {"Float64", NewAny(true), float64(1), nil}, + {"Float64", NewAny(int(1)), float64(1), nil}, + {"Float64", NewAny(int8(1)), float64(1), nil}, + {"Float64", NewAny(int16(1)), float64(1), nil}, + {"Float64", NewAny(int32(1)), float64(1), nil}, + {"Float64", NewAny(int64(1)), float64(1), nil}, + {"Float64", NewAny(uint(1)), float64(1), nil}, + {"Float64", NewAny(uint8(1)), float64(1), nil}, + {"Float64", NewAny(uint32(1)), float64(1), nil}, + {"Float64", NewAny(uint64(1)), float64(1), nil}, + {"Float64", NewAny(float32(1)), float64(1), nil}, + {"Float64", NewAny(float64(1)), float64(1), nil}, + {"Float64", NewAny("1.1"), float64(1.1), nil}, + {"Float64", NewAny("+1x.1"), float64(0), nonEmptyErr}, + {"Float64", NewAny(json.Number("0")), float64(0), nil}, + {"Float64", NewAny(json.Number("x")), float64(0), nonEmptyErr}, + {"Float64", NewAny([]byte{0}), float64(0), ErrUnsupportType}, {"Float64", NewRaw("0.0"), float64(0.0), nil}, + {"Float64", NewRaw("1"), float64(1.0), nil}, {"Float64", NewRaw("null"), float64(0.0), nil}, {"StrictFloat64", NewRaw("true"), float64(0), ErrUnsupportType}, {"StrictFloat64", NewRaw("false"), float64(0), ErrUnsupportType}, @@ -330,11 +385,31 @@ func TestTypeCast(t *testing.T) { {"StrictFloat64", NewRaw("0.0"), float64(0.0), nil}, {"StrictFloat64", NewRaw("null"), float64(0.0), ErrUnsupportType}, {"Number", Node{}, json.Number(""), ErrUnsupportType}, + {"Number", NewAny(false), json.Number("0"), nil}, + {"Number", NewAny(true), json.Number("1"), nil}, + {"Number", NewAny(int(1)), json.Number("1"), nil}, + {"Number", NewAny(int8(1)), json.Number("1"), nil}, + {"Number", NewAny(int16(1)), json.Number("1"), nil}, + {"Number", NewAny(int32(1)), json.Number("1"), nil}, + {"Number", NewAny(int64(1)), json.Number("1"), nil}, + {"Number", NewAny(uint(1)), json.Number("1"), nil}, + {"Number", NewAny(uint8(1)), json.Number("1"), nil}, + {"Number", NewAny(uint32(1)), json.Number("1"), nil}, + {"Number", NewAny(uint64(1)), json.Number("1"), nil}, + {"Number", NewAny(float32(1)), json.Number("1"), nil}, + {"Number", NewAny(float64(1)), json.Number("1"), nil}, + {"Number", NewAny("1.1"), json.Number("1.1"), nil}, + {"Number", NewAny("+1x.1"), json.Number(""), nonEmptyErr}, {"Number", NewAny(json.Number("0")), json.Number("0"), nil}, + {"Number", NewAny(json.Number("x")), json.Number("x"), nil}, + {"Number", NewAny(json.Number("+1x.1")), json.Number("+1x.1"), nil}, + {"Number", NewAny([]byte{0}), json.Number(""), ErrUnsupportType}, + {"Number", NewRaw("x"), json.Number(""), nonEmptyErr}, {"Number", NewRaw("0.0"), json.Number("0.0"), nil}, {"Number", NewRaw("\"1\""), json.Number("1"), nil}, {"Number", NewRaw("\"1.1\""), json.Number("1.1"), nil}, {"Number", NewRaw("\"0.x0\""), json.Number(""), nonEmptyErr}, + {"Number", NewRaw("{]"), json.Number(""), nonEmptyErr}, {"Number", NewRaw("true"), json.Number("1"), nil}, {"Number", NewRaw("false"), json.Number("0"), nil}, {"Number", NewRaw("null"), json.Number("0"), nil}, @@ -352,6 +427,24 @@ func TestTypeCast(t *testing.T) { {"String", NewRaw(`true`), "true", nil}, {"String", NewRaw(`false`), "false", nil}, {"String", NewRaw(`null`), "", nil}, + {"String", NewAny(false), "false", nil}, + {"String", NewAny(true), "true", nil}, + {"String", NewAny(int(1)), "1", nil}, + {"String", NewAny(int8(1)), "1", nil}, + {"String", NewAny(int16(1)), "1", nil}, + {"String", NewAny(int32(1)), "1", nil}, + {"String", NewAny(int64(1)), "1", nil}, + {"String", NewAny(uint(1)), "1", nil}, + {"String", NewAny(uint8(1)), "1", nil}, + {"String", NewAny(uint32(1)), "1", nil}, + {"String", NewAny(uint64(1)), "1", nil}, + {"String", NewAny(float32(1)), "1", nil}, + {"String", NewAny(float64(1)), "1", nil}, + {"String", NewAny("1.1"), "1.1", nil}, + {"String", NewAny("+1x.1"), "+1x.1", nil}, + {"String", NewAny(json.Number("0")), ("0"), nil}, + {"String", NewAny(json.Number("x")), ("x"), nil}, + {"String", NewAny([]byte{0}), (""), ErrUnsupportType}, {"StrictString", Node{}, "", ErrUnsupportType}, {"StrictString", NewAny(`\u263a`), `\u263a`, nil}, {"StrictString", NewRaw(`"\u263a"`), `☺`, nil}, @@ -391,18 +484,18 @@ func TestTypeCast(t *testing.T) { m := rt.MethodByName(c.method) rets := m.Call([]reflect.Value{}) if len(rets) != 2 { - t.Fatal(i, rets) + t.Error(i, rets) } if !reflect.DeepEqual(rets[0].Interface(), c.exp) { - t.Fatal(i, rets[0].Interface(), c.exp) + t.Error(i, rets[0].Interface(), c.exp) } v := rets[1].Interface(); if c.err == nonEmptyErr { if reflect.ValueOf(v).IsNil() { - t.Fatal(i, v) + t.Error(i, v) } } else if v != c.err { - t.Fatal(i, v) + t.Error(i, v) } } }