encoding/textpb: add string fields UTF-8 validation

Change-Id: I15aec2b90efae9366eb496dc221b9e8cacd9d8e6
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171122
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Herbie Ong 2019-04-08 17:32:44 -07:00
parent b132ae09f0
commit 21a3974ed6
4 changed files with 206 additions and 7 deletions

View File

@ -7,6 +7,7 @@ package textpb
import ( import (
"fmt" "fmt"
"strings" "strings"
"unicode/utf8"
"github.com/golang/protobuf/v2/internal/encoding/text" "github.com/golang/protobuf/v2/internal/encoding/text"
"github.com/golang/protobuf/v2/internal/errors" "github.com/golang/protobuf/v2/internal/errors"
@ -293,7 +294,13 @@ func unmarshalScalar(input text.Value, fd pref.FieldDescriptor) (pref.Value, err
} }
case pref.StringKind: case pref.StringKind:
if input.Type() == text.String { if input.Type() == text.String {
return pref.ValueOf(string(input.String())), nil s := input.String()
if utf8.ValidString(s) {
return pref.ValueOf(s), nil
}
var nerr errors.NonFatal
nerr.AppendInvalidUTF8(string(fd.FullName()))
return pref.ValueOf(s), nerr.E
} }
case pref.BytesKind: case pref.BytesKind:
if input.Type() == text.String { if input.Type() == text.String {
@ -421,11 +428,12 @@ func unmarshalMapKey(input text.Value, fd pref.FieldDescriptor) (pref.MapKey, er
return fd.Default().MapKey(), nil return fd.Default().MapKey(), nil
} }
var nerr errors.NonFatal
val, err := unmarshalScalar(input, fd) val, err := unmarshalScalar(input, fd)
if err != nil { if !nerr.Merge(err) {
return pref.MapKey{}, errors.New("%v contains invalid key: %v", fd.FullName(), input) return pref.MapKey{}, errors.New("%v contains invalid key: %v", fd.FullName(), input)
} }
return val.MapKey(), nil return val.MapKey(), nerr.E
} }
// unmarshalMapMessageValue unmarshals given message-type text.Value into a protoreflect.Map for // unmarshalMapMessageValue unmarshals given message-type text.Value into a protoreflect.Map for
@ -447,18 +455,19 @@ func (o UnmarshalOptions) unmarshalMapMessageValue(input text.Value, pkey pref.M
// unmarshalMapScalarValue unmarshals given scalar-type text.Value into a protoreflect.Map // unmarshalMapScalarValue unmarshals given scalar-type text.Value into a protoreflect.Map
// for the given MapKey. // for the given MapKey.
func unmarshalMapScalarValue(input text.Value, pkey pref.MapKey, fd pref.FieldDescriptor, mmap pref.Map) error { func unmarshalMapScalarValue(input text.Value, pkey pref.MapKey, fd pref.FieldDescriptor, mmap pref.Map) error {
var nerr errors.NonFatal
var val pref.Value var val pref.Value
if input.Type() == 0 { if input.Type() == 0 {
val = fd.Default() val = fd.Default()
} else { } else {
var err error var err error
val, err = unmarshalScalar(input, fd) val, err = unmarshalScalar(input, fd)
if err != nil { if !nerr.Merge(err) {
return err return err
} }
} }
mmap.Set(pkey, val) mmap.Set(pkey, val)
return nil return nerr.E
} }
// isExpandedAny returns true if given [][2]text.Value may be an expanded Any that contains only one // isExpandedAny returns true if given [][2]text.Value may be an expanded Any that contains only one

View File

@ -10,6 +10,7 @@ import (
protoV1 "github.com/golang/protobuf/proto" protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/encoding/textpb" "github.com/golang/protobuf/v2/encoding/textpb"
"github.com/golang/protobuf/v2/internal/errors"
"github.com/golang/protobuf/v2/internal/legacy" "github.com/golang/protobuf/v2/internal/legacy"
"github.com/golang/protobuf/v2/internal/scalar" "github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto" "github.com/golang/protobuf/v2/proto"
@ -182,6 +183,14 @@ s_string: "谷歌"
SBytes: []byte("\xe8\xb0\xb7\xe6\xad\x8c"), SBytes: []byte("\xe8\xb0\xb7\xe6\xad\x8c"),
SString: "谷歌", SString: "谷歌",
}, },
}, {
desc: "string with invalid UTF-8",
inputMessage: &pb3.Scalars{},
inputText: `s_string: "abc\xff"`,
wantMessage: &pb3.Scalars{
SString: "abc\xff",
},
wantErr: true,
}, { }, {
desc: "proto2 message contains unknown field", desc: "proto2 message contains unknown field",
inputMessage: &pb2.Scalars{}, inputMessage: &pb2.Scalars{},
@ -473,6 +482,19 @@ s_nested: {
}, },
}, },
}, },
}, {
desc: "proto3 nested message contains invalid UTF-8",
inputMessage: &pb3.Nests{},
inputText: `s_nested: {
s_string: "abc\xff"
}
`,
wantMessage: &pb3.Nests{
SNested: &pb3.Nested{
SString: "abc\xff",
},
},
wantErr: true,
}, { }, {
desc: "oneof set to empty string", desc: "oneof set to empty string",
inputMessage: &pb3.Oneofs{}, inputMessage: &pb3.Oneofs{},
@ -560,6 +582,14 @@ rpt_string: "b"
RptString: []string{"a", "x", "y", "b"}, RptString: []string{"a", "x", "y", "b"},
RptBool: []bool{true, false, true}, RptBool: []bool{true, false, true},
}, },
}, {
desc: "repeated contains invalid UTF-8",
inputMessage: &pb2.Repeats{},
inputText: `rpt_string: "abc\xff"`,
wantMessage: &pb2.Repeats{
RptString: []string{"abc\xff"},
},
wantErr: true,
}, { }, {
desc: "repeated enums", desc: "repeated enums",
inputMessage: &pb2.Enums{}, inputMessage: &pb2.Enums{},
@ -870,6 +900,34 @@ int32_to_str: {}
0: "", 0: "",
}, },
}, },
}, {
desc: "map field value contains invalid UTF-8",
inputMessage: &pb3.Maps{},
inputText: `int32_to_str: {
key: 101
value: "abc\xff"
}
`,
wantMessage: &pb3.Maps{
Int32ToStr: map[int32]string{
101: "abc\xff",
},
},
wantErr: true,
}, {
desc: "map field key contains invalid UTF-8",
inputMessage: &pb3.Maps{},
inputText: `str_to_nested: {
key: "abc\xff"
value: {}
}
`,
wantMessage: &pb3.Maps{
StrToNested: map[string]*pb3.Nested{
"abc\xff": {},
},
},
wantErr: true,
}, { }, {
desc: "map contains unknown field", desc: "map contains unknown field",
inputMessage: &pb3.Maps{}, inputMessage: &pb3.Maps{},
@ -1164,6 +1222,16 @@ opt_int32: 42
}) })
return m return m
}(), }(),
}, {
desc: "extension field contains invalid UTF-8",
inputMessage: &pb2.Extensions{},
inputText: `[pb2.opt_ext_string]: "abc\xff"`,
wantMessage: func() proto.Message {
m := &pb2.Extensions{}
setExtension(m, pb2.E_OptExtString, "abc\xff")
return m
}(),
wantErr: true,
}, { }, {
desc: "extensions of repeated fields", desc: "extensions of repeated fields",
inputMessage: &pb2.Extensions{}, inputMessage: &pb2.Extensions{},
@ -1418,6 +1486,32 @@ value: "some bytes"
} }
}(), }(),
wantErr: true, wantErr: true,
}, {
desc: "Any with invalid UTF-8",
umo: textpb.UnmarshalOptions{
Resolver: preg.NewTypes((&pb3.Nested{}).ProtoReflect().Type()),
},
inputMessage: &knownpb.Any{},
inputText: `
[pb3.Nested]: {
s_string: "abc\xff"
}
`,
wantMessage: func() proto.Message {
m := &pb3.Nested{
SString: "abc\xff",
}
var nerr errors.NonFatal
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
if !nerr.Merge(err) {
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
}
return &knownpb.Any{
TypeUrl: string(m.ProtoReflect().Type().FullName()),
Value: b,
}
}(),
wantErr: true,
}, { }, {
desc: "Any expanded with unregistered type", desc: "Any expanded with unregistered type",
umo: textpb.UnmarshalOptions{Resolver: preg.NewTypes()}, umo: textpb.UnmarshalOptions{Resolver: preg.NewTypes()},
@ -1459,7 +1553,6 @@ type_url: "pb2.Nested"
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.desc, func(t *testing.T) { t.Run(tt.desc, func(t *testing.T) {
t.Parallel()
err := tt.umo.Unmarshal(tt.inputMessage, []byte(tt.inputText)) err := tt.umo.Unmarshal(tt.inputMessage, []byte(tt.inputText))
if err != nil && !tt.wantErr { if err != nil && !tt.wantErr {
t.Errorf("Unmarshal() returned error: %v\n\n", err) t.Errorf("Unmarshal() returned error: %v\n\n", err)

View File

@ -7,6 +7,7 @@ package textpb
import ( import (
"fmt" "fmt"
"sort" "sort"
"unicode/utf8"
"github.com/golang/protobuf/v2/internal/encoding/text" "github.com/golang/protobuf/v2/internal/encoding/text"
"github.com/golang/protobuf/v2/internal/encoding/wire" "github.com/golang/protobuf/v2/internal/encoding/wire"
@ -174,9 +175,18 @@ func (o MarshalOptions) marshalSingular(val pref.Value, fd pref.FieldDescriptor)
pref.Sfixed32Kind, pref.Fixed32Kind, pref.Sfixed32Kind, pref.Fixed32Kind,
pref.Sfixed64Kind, pref.Fixed64Kind, pref.Sfixed64Kind, pref.Fixed64Kind,
pref.FloatKind, pref.DoubleKind, pref.FloatKind, pref.DoubleKind,
pref.StringKind, pref.BytesKind: pref.BytesKind:
return text.ValueOf(val.Interface()), nil return text.ValueOf(val.Interface()), nil
case pref.StringKind:
s := val.String()
if utf8.ValidString(s) {
return text.ValueOf(s), nil
}
var nerr errors.NonFatal
nerr.AppendInvalidUTF8(string(fd.FullName()))
return text.ValueOf(s), nerr.E
case pref.EnumKind: case pref.EnumKind:
num := val.Enum() num := val.Enum()
if desc := fd.EnumType().Values().ByNumber(num); desc != nil { if desc := fd.EnumType().Values().ByNumber(num); desc != nil {

View File

@ -169,6 +169,14 @@ opt_double: 1.0199999809265137
opt_bytes: "谷歌" opt_bytes: "谷歌"
opt_string: "谷歌" opt_string: "谷歌"
`, `,
}, {
desc: "string with invalid UTF-8",
input: &pb3.Scalars{
SString: "abc\xff",
},
want: `s_string: "abc\xff"
`,
wantErr: true,
}, { }, {
desc: "float nan", desc: "float nan",
input: &pb3.Scalars{ input: &pb3.Scalars{
@ -363,6 +371,18 @@ OptGroup: {}
} }
} }
`, `,
}, {
desc: "proto3 nested message contains invalid UTF-8",
input: &pb3.Nests{
SNested: &pb3.Nested{
SString: "abc\xff",
},
},
want: `s_nested: {
s_string: "abc\xff"
}
`,
wantErr: true,
}, { }, {
desc: "oneof not set", desc: "oneof not set",
input: &pb3.Oneofs{}, input: &pb3.Oneofs{},
@ -472,6 +492,14 @@ rpt_string: "世界"
rpt_bytes: "hello" rpt_bytes: "hello"
rpt_bytes: "世界" rpt_bytes: "世界"
`, `,
}, {
desc: "repeated contains invalid UTF-8",
input: &pb2.Repeats{
RptString: []string{"abc\xff"},
},
want: `rpt_string: "abc\xff"
`,
wantErr: true,
}, { }, {
desc: "repeated enums", desc: "repeated enums",
input: &pb2.Enums{ input: &pb2.Enums{
@ -670,6 +698,32 @@ str_to_oneofs: {
} }
} }
`, `,
}, {
desc: "map field value contains invalid UTF-8",
input: &pb3.Maps{
Int32ToStr: map[int32]string{
101: "abc\xff",
},
},
want: `int32_to_str: {
key: 101
value: "abc\xff"
}
`,
wantErr: true,
}, {
desc: "map field key contains invalid UTF-8",
input: &pb3.Maps{
StrToNested: map[string]*pb3.Nested{
"abc\xff": {},
},
},
want: `str_to_nested: {
key: "abc\xff"
value: {}
}
`,
wantErr: true,
}, { }, {
desc: "map field contains nil value", desc: "map field contains nil value",
input: &pb3.Maps{ input: &pb3.Maps{
@ -918,6 +972,16 @@ opt_int32: 42
} }
[pb2.opt_ext_string]: "extension field" [pb2.opt_ext_string]: "extension field"
`, `,
}, {
desc: "extension field contains invalid UTF-8",
input: func() proto.Message {
m := &pb2.Extensions{}
setExtension(m, pb2.E_OptExtString, "abc\xff")
return m
}(),
want: `[pb2.opt_ext_string]: "abc\xff"
`,
wantErr: true,
}, { }, {
desc: "extension partial returns error", desc: "extension partial returns error",
input: func() proto.Message { input: func() proto.Message {
@ -1175,6 +1239,29 @@ value: "\n\x13embedded inside Any\x12\x0b\n\tinception"
want: `[pb2.PartialRequired]: { want: `[pb2.PartialRequired]: {
opt_string: "embedded inside Any" opt_string: "embedded inside Any"
} }
`,
wantErr: true,
}, {
desc: "Any with invalid UTF-8",
mo: textpb.MarshalOptions{
Resolver: preg.NewTypes((&pb3.Nested{}).ProtoReflect().Type()),
},
input: func() proto.Message {
m := &pb3.Nested{
SString: "abc\xff",
}
b, err := proto.MarshalOptions{Deterministic: true}.Marshal(m)
if err != nil {
t.Fatalf("error in binary marshaling message for Any.value: %v", err)
}
return &knownpb.Any{
TypeUrl: string(m.ProtoReflect().Type().FullName()),
Value: b,
}
}(),
want: `[pb3.Nested]: {
s_string: "abc\xff"
}
`, `,
wantErr: true, wantErr: true,
}, { }, {