From a0a54b800581be257b89521074f86b40a3496564 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 1 Nov 2019 15:18:36 -0700 Subject: [PATCH] reflect/protoreflect: remove nullability from repeated extension fields Remove repeated extension fields from the set of nullable fields, so that Has reports false and Range does not visit a a zero-length repeated extension field. This corrects a fuzzer-detected case where unmarshaling and remarshaling a wire-format message could result in a semantic change. For a repeated extension field in non-packed encoding, unmarshaling a packed representation of the field would result in a message which Has the extension. Remarshaling it would discard the the field. Fixes golang.org/protobuf#975 Change-Id: Ie836559c93d218db5b5201742a3b8ebbaacf54ed Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/204897 Reviewed-by: Joe Tsai Reviewed-by: Joe Tsai --- .../26ff72cf93341dd2ec7f4f3e4138f1158d554916 | Bin 0 -> 3 bytes internal/impl/message_reflect.go | 25 ++++++++-- proto/decode_test.go | 47 ++++++++++++++++++ reflect/protoreflect/value.go | 4 +- testing/prototest/prototest.go | 12 ++--- types/dynamicpb/dynamic.go | 12 ++--- 6 files changed, 81 insertions(+), 19 deletions(-) create mode 100644 internal/fuzz/wire/corpus/26ff72cf93341dd2ec7f4f3e4138f1158d554916 diff --git a/internal/fuzz/wire/corpus/26ff72cf93341dd2ec7f4f3e4138f1158d554916 b/internal/fuzz/wire/corpus/26ff72cf93341dd2ec7f4f3e4138f1158d554916 new file mode 100644 index 0000000000000000000000000000000000000000..560240fc3553850bc0be9f5c51d4de1754ed44a4 GIT binary patch literal 3 KcmZ3*!~g&S1puu8 literal 0 HcmV?d00001 diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go index bc503030..f5f7f2bf 100644 --- a/internal/impl/message_reflect.go +++ b/internal/impl/message_reflect.go @@ -145,18 +145,33 @@ type extensionMap map[int32]ExtensionField func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) { if m != nil { for _, x := range *m { - xt := x.Type() - if !f(xt.TypeDescriptor(), x.Value()) { + xd := x.Type().TypeDescriptor() + v := x.Value() + if xd.IsList() && v.List().Len() == 0 { + continue + } + if !f(xd, v) { return } } } } func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) { - if m != nil { - _, ok = (*m)[int32(xt.TypeDescriptor().Number())] + if m == nil { + return false } - return ok + xd := xt.TypeDescriptor() + x, ok := (*m)[int32(xd.Number())] + if !ok { + return false + } + switch { + case xd.IsList(): + return x.Value().List().Len() > 0 + case xd.IsMap(): + return x.Value().Map().Len() > 0 + } + return true } func (m *extensionMap) Clear(xt pref.ExtensionType) { delete(*m, int32(xt.TypeDescriptor().Number())) diff --git a/proto/decode_test.go b/proto/decode_test.go index 10b3f951..a01da16f 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -614,6 +614,30 @@ var testProtos = []testProto{ }, }.Marshal(), }, + { + desc: "basic repeated types (zero-length packed encoding)", + decodeTo: []proto.Message{ + &testpb.TestAllTypes{}, + &test3pb.TestAllTypes{}, + &testpb.TestAllExtensions{}, + }, + wire: pack.Message{ + pack.Tag{31, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{32, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{33, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{34, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{35, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{36, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{37, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{38, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{39, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{40, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{41, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{42, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{43, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{51, pack.BytesType}, pack.LengthPrefix{}, + }.Marshal(), + }, { desc: "packed repeated types", decodeTo: []proto.Message{&testpb.TestPackedTypes{ @@ -700,6 +724,29 @@ var testProtos = []testProto{ }, }.Marshal(), }, + { + desc: "packed repeated types (zero length)", + decodeTo: []proto.Message{ + &testpb.TestPackedTypes{}, + &testpb.TestPackedExtensions{}, + }, + wire: pack.Message{ + pack.Tag{90, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{91, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{92, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{93, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{94, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{95, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{96, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{97, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{98, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{99, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{100, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{101, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{102, pack.BytesType}, pack.LengthPrefix{}, + pack.Tag{103, pack.BytesType}, pack.LengthPrefix{}, + }.Marshal(), + }, { desc: "repeated messages", decodeTo: []proto.Message{&testpb.TestAllTypes{ diff --git a/reflect/protoreflect/value.go b/reflect/protoreflect/value.go index 1f100acd..e1d6487c 100644 --- a/reflect/protoreflect/value.go +++ b/reflect/protoreflect/value.go @@ -66,8 +66,8 @@ type Message interface { // Some fields have the property of nullability where it is possible to // distinguish between the default value of a field and whether the field // was explicitly populated with the default value. Singular message fields, - // member fields of a oneof, proto2 scalar fields, and extension fields - // are nullable. Such fields are populated only if explicitly set. + // member fields of a oneof, and proto2 scalar fields are nullable. Such + // fields are populated only if explicitly set. // // In other cases (aside from the nullable cases above), // a proto3 scalar field is populated if it contains a non-zero value, and diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go index 437c216d..79c792d0 100644 --- a/testing/prototest/prototest.go +++ b/testing/prototest/prototest.go @@ -118,12 +118,12 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) { if fd.Syntax() == pref.Proto3 && fd.Message() == nil { wantHas = false } - if fd.Cardinality() == pref.Repeated { - wantHas = false - } if fd.IsExtension() { wantHas = true } + if fd.Cardinality() == pref.Repeated { + wantHas = false + } if fd.ContainingOneof() != nil { wantHas = true } @@ -176,7 +176,7 @@ func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) { switch { case fd.IsList() || fd.IsMap(): m.Set(fd, m.Get(fd)) - if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want { + if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != pref.Repeated) || fd.ContainingOneof() != nil; got != want { t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want) } case fd.Message() == nil: @@ -300,7 +300,7 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) { // Append values. var want pref.List = &testList{} for i, n := range []seed{1, 0, minVal, maxVal} { - if got, want := m.Has(fd), i > 0 || fd.IsExtension(); got != want { + if got, want := m.Has(fd), i > 0; got != want { t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want) } v := newListElement(fd, list, n, nil) @@ -327,7 +327,7 @@ func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) { n := want.Len() - 1 want.Truncate(n) list.Truncate(n) - if got, want := m.Has(fd), want.Len() > 0 || fd.IsExtension(); got != want { + if got, want := m.Has(fd), want.Len() > 0; got != want { t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want) } if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) { diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go index 036a92cc..6c0bcbea 100644 --- a/types/dynamicpb/dynamic.go +++ b/types/dynamicpb/dynamic.go @@ -83,9 +83,9 @@ func (m *Message) Range(f func(pref.FieldDescriptor, pref.Value) bool) { fd := m.ext[num] if fd == nil { fd = m.Descriptor().Fields().ByNumber(num) - if !isSet(fd, v) { - continue - } + } + if !isSet(fd, v) { + continue } if !f(fd, v) { return @@ -97,8 +97,8 @@ func (m *Message) Range(f func(pref.FieldDescriptor, pref.Value) bool) { // See protoreflect.Message for details. func (m *Message) Has(fd pref.FieldDescriptor) bool { m.checkField(fd) - if fd.IsExtension() { - return m.ext[fd.Number()] == fd + if fd.IsExtension() && m.ext[fd.Number()] != fd { + return false } v, ok := m.known[fd.Number()] if !ok { @@ -371,7 +371,7 @@ func isSet(fd pref.FieldDescriptor, v pref.Value) bool { return v.List().Len() > 0 case fd.ContainingOneof() != nil: return true - case fd.Syntax() == pref.Proto3: + case fd.Syntax() == pref.Proto3 && !fd.IsExtension(): switch fd.Kind() { case pref.BoolKind: return v.Bool()