reflect/protoreflect: remove nullability from repeated extension fields

Remove repeated extension fields from the set of nullable fields,
so that Has reports false and Range does not visit a a zero-length
repeated extension field.

This corrects a fuzzer-detected case where unmarshaling and remarshaling
a wire-format message could result in a semantic change. For a repeated
extension field in non-packed encoding, unmarshaling a packed
representation of the field would result in a message which Has the
extension. Remarshaling it would discard the the field.

Fixes golang.org/protobuf#975

Change-Id: Ie836559c93d218db5b5201742a3b8ebbaacf54ed
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/204897
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Reviewed-by: Joe Tsai <joetsai@google.com>
This commit is contained in:
Damien Neil 2019-11-01 15:18:36 -07:00
parent ef19a2a994
commit a0a54b8005
6 changed files with 81 additions and 19 deletions

View File

@ -145,18 +145,33 @@ type extensionMap map[int32]ExtensionField
func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
if m != nil {
for _, x := range *m {
xt := x.Type()
if !f(xt.TypeDescriptor(), x.Value()) {
xd := x.Type().TypeDescriptor()
v := x.Value()
if xd.IsList() && v.List().Len() == 0 {
continue
}
if !f(xd, v) {
return
}
}
}
}
func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
if m != nil {
_, ok = (*m)[int32(xt.TypeDescriptor().Number())]
if m == nil {
return false
}
return ok
xd := xt.TypeDescriptor()
x, ok := (*m)[int32(xd.Number())]
if !ok {
return false
}
switch {
case xd.IsList():
return x.Value().List().Len() > 0
case xd.IsMap():
return x.Value().Map().Len() > 0
}
return true
}
func (m *extensionMap) Clear(xt pref.ExtensionType) {
delete(*m, int32(xt.TypeDescriptor().Number()))

View File

@ -614,6 +614,30 @@ var testProtos = []testProto{
},
}.Marshal(),
},
{
desc: "basic repeated types (zero-length packed encoding)",
decodeTo: []proto.Message{
&testpb.TestAllTypes{},
&test3pb.TestAllTypes{},
&testpb.TestAllExtensions{},
},
wire: pack.Message{
pack.Tag{31, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{32, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{33, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{34, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{35, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{36, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{37, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{38, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{39, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{40, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{41, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{42, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{43, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{51, pack.BytesType}, pack.LengthPrefix{},
}.Marshal(),
},
{
desc: "packed repeated types",
decodeTo: []proto.Message{&testpb.TestPackedTypes{
@ -700,6 +724,29 @@ var testProtos = []testProto{
},
}.Marshal(),
},
{
desc: "packed repeated types (zero length)",
decodeTo: []proto.Message{
&testpb.TestPackedTypes{},
&testpb.TestPackedExtensions{},
},
wire: pack.Message{
pack.Tag{90, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{91, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{92, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{93, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{94, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{95, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{96, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{97, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{98, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{99, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{100, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{101, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{102, pack.BytesType}, pack.LengthPrefix{},
pack.Tag{103, pack.BytesType}, pack.LengthPrefix{},
}.Marshal(),
},
{
desc: "repeated messages",
decodeTo: []proto.Message{&testpb.TestAllTypes{

View File

@ -66,8 +66,8 @@ type Message interface {
// Some fields have the property of nullability where it is possible to
// distinguish between the default value of a field and whether the field
// was explicitly populated with the default value. Singular message fields,
// member fields of a oneof, proto2 scalar fields, and extension fields
// are nullable. Such fields are populated only if explicitly set.
// member fields of a oneof, and proto2 scalar fields are nullable. Such
// fields are populated only if explicitly set.
//
// In other cases (aside from the nullable cases above),
// a proto3 scalar field is populated if it contains a non-zero value, and

View File

@ -118,12 +118,12 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
wantHas = false
}
if fd.Cardinality() == pref.Repeated {
wantHas = false
}
if fd.IsExtension() {
wantHas = true
}
if fd.Cardinality() == pref.Repeated {
wantHas = false
}
if fd.ContainingOneof() != nil {
wantHas = true
}
@ -176,7 +176,7 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
switch {
case fd.IsList() || fd.IsMap():
m.Set(fd, m.Get(fd))
if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != pref.Repeated) || fd.ContainingOneof() != nil; got != want {
t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
}
case fd.Message() == nil:
@ -300,7 +300,7 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
// Append values.
var want pref.List = &testList{}
for i, n := range []seed{1, 0, minVal, maxVal} {
if got, want := m.Has(fd), i > 0 || fd.IsExtension(); got != want {
if got, want := m.Has(fd), i > 0; got != want {
t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
}
v := newListElement(fd, list, n, nil)
@ -327,7 +327,7 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
n := want.Len() - 1
want.Truncate(n)
list.Truncate(n)
if got, want := m.Has(fd), want.Len() > 0 || fd.IsExtension(); got != want {
if got, want := m.Has(fd), want.Len() > 0; got != want {
t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
}
if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {

View File

@ -83,9 +83,9 @@ func (m *Message) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
fd := m.ext[num]
if fd == nil {
fd = m.Descriptor().Fields().ByNumber(num)
if !isSet(fd, v) {
continue
}
}
if !isSet(fd, v) {
continue
}
if !f(fd, v) {
return
@ -97,8 +97,8 @@ func (m *Message) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
// See protoreflect.Message for details.
func (m *Message) Has(fd pref.FieldDescriptor) bool {
m.checkField(fd)
if fd.IsExtension() {
return m.ext[fd.Number()] == fd
if fd.IsExtension() && m.ext[fd.Number()] != fd {
return false
}
v, ok := m.known[fd.Number()]
if !ok {
@ -371,7 +371,7 @@ func isSet(fd pref.FieldDescriptor, v pref.Value) bool {
return v.List().Len() > 0
case fd.ContainingOneof() != nil:
return true
case fd.Syntax() == pref.Proto3:
case fd.Syntax() == pref.Proto3 && !fd.IsExtension():
switch fd.Kind() {
case pref.BoolKind:
return v.Bool()