From 9afe9bb78b4f5e9015dd310c03875443c14125ec Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Fri, 7 Feb 2020 10:06:53 -0800 Subject: [PATCH] internal/impl: validate messagesets Change-Id: Id90bb386e7481bb9dee5a07889f308f1e1810825 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218438 Reviewed-by: Joe Tsai --- internal/encoding/messageset/messageset.go | 6 ++-- internal/impl/validate.go | 41 +++++++++++++++++----- proto/messageset_test.go | 13 ------- 3 files changed, 35 insertions(+), 25 deletions(-) diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go index 77522de2..837e5c49 100644 --- a/internal/encoding/messageset/messageset.go +++ b/internal/encoding/messageset/messageset.go @@ -99,7 +99,7 @@ func Unmarshal(b []byte, wantLen bool, fn func(typeID wire.Number, value []byte) b = b[n:] continue } - typeID, value, n, err := consumeFieldValue(b, wantLen) + typeID, value, n, err := ConsumeFieldValue(b, wantLen) if err != nil { return err } @@ -114,13 +114,13 @@ func Unmarshal(b []byte, wantLen bool, fn func(typeID wire.Number, value []byte) return nil } -// consumeFieldValue parses b as a MessageSet item field value until and including +// ConsumeFieldValue parses b as a MessageSet item field value until and including // the trailing end group marker. It assumes the start group tag has already been parsed. // It returns the contents of the type_id and message subfields and the total // item length. // // If wantLen is true, the returned message value includes the length prefix. -func consumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) { +func ConsumeFieldValue(b []byte, wantLen bool) (typeid wire.Number, message []byte, n int, err error) { ilen := len(b) for { num, wtyp, n := wire.ConsumeTag(b) diff --git a/internal/impl/validate.go b/internal/impl/validate.go index 06acc788..bb00cd0d 100644 --- a/internal/impl/validate.go +++ b/internal/impl/validate.go @@ -11,6 +11,7 @@ import ( "reflect" "unicode/utf8" + "google.golang.org/protobuf/internal/encoding/messageset" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/internal/strs" @@ -93,6 +94,7 @@ const ( validationTypeFixed64 validationTypeBytes validationTypeUTF8String + validationTypeMessageSetItem ) func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo { @@ -237,11 +239,6 @@ func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOp State: for len(states) > 0 { st := &states[len(states)-1] - if st.mi != nil { - if flags.ProtoLegacy && st.mi.isMessageSet { - return out, ValidationUnknown - } - } for len(b) > 0 { // Parse the tag (field number and wire type). var tag uint64 @@ -274,8 +271,8 @@ State: return out, ValidationInvalid } var vi validationInfo - switch st.typ { - case validationTypeMap: + switch { + case st.typ == validationTypeMap: switch num { case 1: vi.typ = st.keyType @@ -284,6 +281,11 @@ State: vi.mi = st.mi vi.requiredBit = 1 } + case flags.ProtoLegacy && st.mi.isMessageSet: + switch num { + case messageset.FieldItem: + vi.typ = validationTypeMessageSetItem + } default: var f *coderFieldInfo if int(num) < len(st.mi.denseCoderFields) { @@ -483,8 +485,8 @@ State: } b = b[8:] case wire.StartGroupType: - switch vi.typ { - case validationTypeGroup: + switch { + case vi.typ == validationTypeGroup: if vi.mi == nil { return out, ValidationUnknown } @@ -495,6 +497,27 @@ State: endGroup: num, }) continue State + case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem: + typeid, v, n, err := messageset.ConsumeFieldValue(b, false) + if err != nil { + return out, ValidationInvalid + } + xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid) + switch { + case err == preg.NotFound: + b = b[n:] + case err != nil: + return out, ValidationUnknown + default: + xvi := getExtensionFieldInfo(xt).validation + states = append(states, validationState{ + typ: xvi.typ, + mi: xvi.mi, + tail: b[n:], + }) + b = v + continue State + } default: n := wire.ConsumeFieldValue(num, wtyp, b) if n < 0 { diff --git a/proto/messageset_test.go b/proto/messageset_test.go index 9e70e59a..c9018000 100644 --- a/proto/messageset_test.go +++ b/proto/messageset_test.go @@ -8,7 +8,6 @@ import ( "google.golang.org/protobuf/internal/encoding/pack" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/internal/flags" - "google.golang.org/protobuf/internal/impl" "google.golang.org/protobuf/proto" messagesetpb "google.golang.org/protobuf/internal/testprotos/messageset/messagesetpb" @@ -41,7 +40,6 @@ var messageSetTestProtos = []testProto{ pack.Tag{1, pack.EndGroupType}, }), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet type_id after message content", @@ -62,7 +60,6 @@ var messageSetTestProtos = []testProto{ pack.Tag{1, pack.EndGroupType}, }), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet does not preserve unknown field", @@ -82,7 +79,6 @@ var messageSetTestProtos = []testProto{ // Unknown field pack.Tag{4, pack.VarintType}, pack.Varint(30), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with unknown type_id", @@ -102,7 +98,6 @@ var messageSetTestProtos = []testProto{ }), pack.Tag{1, pack.EndGroupType}, }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet merges repeated message fields in item", @@ -124,7 +119,6 @@ var messageSetTestProtos = []testProto{ }), pack.Tag{1, pack.EndGroupType}, }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet merges message fields in repeated items", @@ -161,7 +155,6 @@ var messageSetTestProtos = []testProto{ }), pack.Tag{1, pack.EndGroupType}, }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with missing type_id", @@ -175,7 +168,6 @@ var messageSetTestProtos = []testProto{ }), pack.Tag{1, pack.EndGroupType}, }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with missing message", @@ -188,7 +180,6 @@ var messageSetTestProtos = []testProto{ pack.Tag{2, pack.VarintType}, pack.Varint(1000), pack.Tag{1, pack.EndGroupType}, }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with type id out of valid field number range", @@ -205,7 +196,6 @@ var messageSetTestProtos = []testProto{ pack.Tag{1, pack.EndGroupType}, }), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with unknown type id out of valid field number range", @@ -226,7 +216,6 @@ var messageSetTestProtos = []testProto{ pack.Tag{1, pack.EndGroupType}, }), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with required field set", @@ -248,7 +237,6 @@ var messageSetTestProtos = []testProto{ pack.Tag{1, pack.EndGroupType}, }), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, { desc: "MessageSet with required field unset", @@ -267,6 +255,5 @@ var messageSetTestProtos = []testProto{ pack.Tag{1, pack.EndGroupType}, }), }.Marshal(), - validationStatus: impl.ValidationUnknown, }, }