From db38ddde7d12ec30fea31c1ffdd887539eb0e88f Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Tue, 7 May 2019 15:14:40 -0700 Subject: [PATCH] proto: eagerly unmarshal extensions CL/172399 switches the v1 code to eagerly unmarshal extensions. This CL does the equivalent for v2. For the test, we simply switch from protoV1.Equal to protoV2.Equal, since the v2 equal does not magically unmarshal raw extensions. Change-Id: I6f64455b0a75bbc9a9a82108558641a29bd2b982 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175838 Reviewed-by: Damien Neil --- proto/decode.go | 19 +++++++++++++++++++ proto/decode_test.go | 2 +- runtime/protoiface/methods.go | 2 ++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/proto/decode.go b/proto/decode.go index 5a867a21..fa6d4432 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -9,6 +9,7 @@ import ( "github.com/golang/protobuf/v2/internal/errors" "github.com/golang/protobuf/v2/internal/pragma" "github.com/golang/protobuf/v2/reflect/protoreflect" + "github.com/golang/protobuf/v2/reflect/protoregistry" "github.com/golang/protobuf/v2/runtime/protoiface" ) @@ -25,6 +26,10 @@ type UnmarshalOptions struct { // If DiscardUnknown is set, unknown fields are ignored. DiscardUnknown bool + // Resolver is used for looking up types when unmarshaling extension fields. + // If nil, this defaults to using protoregistry.GlobalTypes. + Resolver *protoregistry.Types + pragma.NoUnkeyedLiterals } @@ -37,6 +42,10 @@ func Unmarshal(b []byte, m Message) error { // Unmarshal parses the wire-format message in b and places the result in m. func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error { + if o.Resolver == nil { + o.Resolver = protoregistry.GlobalTypes + } + // TODO: Reset m? err := o.unmarshalMessageFast(b, m) if err == errInternalNoFast { @@ -77,6 +86,16 @@ func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) err fieldType := fieldTypes.ByNumber(num) if fieldType == nil { fieldType = knownFields.ExtensionTypes().ByNumber(num) + if fieldType == nil && messageType.ExtensionRanges().Has(num) { + extType, err := o.Resolver.FindExtensionByNumber(messageType.FullName(), num) + if err != nil && err != protoregistry.NotFound { + return err + } + if extType != nil { + knownFields.ExtensionTypes().Register(extType) + fieldType = extType + } + } } var err error var valLen int diff --git a/proto/decode_test.go b/proto/decode_test.go index 084014f0..4eb2598c 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -54,7 +54,7 @@ func TestDecode(t *testing.T) { // Equal doesn't work on messages containing invalid extension data. return } - if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) { + if !proto.Equal(got, want) { t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want)) } }) diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go index 42832de0..fe17ca78 100644 --- a/runtime/protoiface/methods.go +++ b/runtime/protoiface/methods.go @@ -7,6 +7,7 @@ package protoiface import ( "github.com/golang/protobuf/v2/internal/pragma" "github.com/golang/protobuf/v2/reflect/protoreflect" + "github.com/golang/protobuf/v2/reflect/protoregistry" ) // Methoder is an optional interface implemented by generated messages to @@ -62,6 +63,7 @@ type MarshalOptions struct { type UnmarshalOptions struct { AllowPartial bool DiscardUnknown bool + Resolver *protoregistry.Types pragma.NoUnkeyedLiterals }