From 37ef691e6bb8269921374c007809d0a349dc6c14 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Wed, 25 Sep 2019 16:51:15 -0700 Subject: [PATCH] internal/impl: call Marshal/Unmarshal methods on legacy types Call the Marshal or Unmarshal method on legacy messages implementing protoV1.Marshaler or protoV2.Unmarshaler. We do this in the impl package by creating an appropriate function in the protoiface.Methods struct for legacy messages. In proto.MarshalAppend, return the bytes provided by the fast-path marshal function even when the returned error is non-nil. Fixes golang/protobuf#955 Change-Id: I36924af9ff959a946c43f2295ef3202216e81b32 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/197357 Reviewed-by: Joe Tsai --- internal/impl/codec_message.go | 17 +++++--- internal/impl/legacy_message.go | 29 +++++++++++++ proto/encode.go | 2 +- proto/methods_test.go | 74 +++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 7 deletions(-) create mode 100644 proto/methods_test.go diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go index d7235a25..4694718f 100644 --- a/internal/impl/codec_message.go +++ b/internal/impl/codec_message.go @@ -128,11 +128,16 @@ func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) { } mi.needsInitCheck = needsInitCheck(mi.Desc) - mi.methods = piface.Methods{ - Flags: piface.SupportMarshalDeterministic | piface.SupportUnmarshalDiscardUnknown, - MarshalAppend: mi.marshalAppend, - Unmarshal: mi.unmarshal, - Size: mi.size, - IsInitialized: mi.isInitialized, + if mi.methods.MarshalAppend == nil && mi.methods.Size == nil { + mi.methods.Flags |= piface.SupportMarshalDeterministic + mi.methods.MarshalAppend = mi.marshalAppend + mi.methods.Size = mi.size + } + if mi.methods.Unmarshal == nil { + mi.methods.Flags |= piface.SupportUnmarshalDiscardUnknown + mi.methods.Unmarshal = mi.unmarshal + } + if mi.methods.IsInitialized == nil { + mi.methods.IsInitialized = mi.isInitialized } } diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go index de03ee61..703e8654 100644 --- a/internal/impl/legacy_message.go +++ b/internal/impl/legacy_message.go @@ -16,6 +16,7 @@ import ( "google.golang.org/protobuf/internal/strs" "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect" + piface "google.golang.org/protobuf/runtime/protoiface" ) // legacyWrapMessage wraps v as a protoreflect.ProtoMessage, @@ -41,6 +42,34 @@ func legacyLoadMessageInfo(t reflect.Type, name pref.FullName) *MessageInfo { Desc: legacyLoadMessageDesc(t, name), GoReflectType: t, } + + v := reflect.Zero(t).Interface() + type marshaler interface { + Marshal() ([]byte, error) + } + if _, ok := v.(marshaler); ok { + mi.methods.MarshalAppend = func(b []byte, m pref.Message, _ piface.MarshalOptions) ([]byte, error) { + out, err := m.Interface().(unwrapper).protoUnwrap().(marshaler).Marshal() + if b != nil { + out = append(b, out...) + } + return out, err + } + mi.methods.Size = func(m pref.Message, _ piface.MarshalOptions) int { + // This is not at all efficient. + b, _ := m.Interface().(unwrapper).protoUnwrap().(marshaler).Marshal() + return len(b) + } + } + type unmarshaler interface { + Unmarshal([]byte) error + } + if _, ok := v.(unmarshaler); ok { + mi.methods.Unmarshal = func(b []byte, m pref.Message, _ piface.UnmarshalOptions) error { + return m.Interface().(unwrapper).protoUnwrap().(unmarshaler).Unmarshal(b) + } + } + if mi, ok := legacyMessageTypeCache.LoadOrStore(t, mi); ok { return mi.(*MessageInfo) } diff --git a/proto/encode.go b/proto/encode.go index 5511d162..a57133c7 100644 --- a/proto/encode.go +++ b/proto/encode.go @@ -88,7 +88,7 @@ func (o MarshalOptions) Marshal(m Message) ([]byte, error) { func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) { out, err := o.marshalMessage(b, m.ProtoReflect()) if err != nil { - return nil, err + return out, err } if o.AllowPartial { return out, nil diff --git a/proto/methods_test.go b/proto/methods_test.go new file mode 100644 index 00000000..436df6df --- /dev/null +++ b/proto/methods_test.go @@ -0,0 +1,74 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The protoreflect tag disables fast-path methods, including legacy ones. +// +build !protoreflect + +package proto_test + +import ( + "bytes" + "errors" + "fmt" + "testing" + + "google.golang.org/protobuf/internal/impl" + "google.golang.org/protobuf/proto" +) + +type selfMarshaler struct { + bytes []byte + err error +} + +func (m selfMarshaler) Reset() {} +func (m selfMarshaler) ProtoMessage() {} + +func (m selfMarshaler) String() string { + return fmt.Sprintf("selfMarshaler{bytes:%v, err:%v}", m.bytes, m.err) +} + +func (m selfMarshaler) Marshal() ([]byte, error) { + return m.bytes, m.err +} + +func (m *selfMarshaler) Unmarshal(b []byte) error { + m.bytes = b + return m.err +} + +func TestLegacyMarshalMethod(t *testing.T) { + for _, test := range []*selfMarshaler{ + {bytes: []byte("marshal")}, + {bytes: []byte("marshal"), err: errors.New("some error")}, + } { + m := impl.Export{}.MessageOf(test).Interface() + b, err := proto.Marshal(m) + if err != test.err || !bytes.Equal(b, test.bytes) { + t.Errorf("proto.Marshal(%v) = %v, %v; want %v, %v", test, b, err, test.bytes, test.err) + } + if gotSize, wantSize := proto.Size(m), len(test.bytes); gotSize != wantSize { + t.Fatalf("proto.Size(%v) = %v, want %v", test, gotSize, wantSize) + } + + prefix := []byte("prefix") + want := append(prefix, test.bytes...) + b, err = proto.MarshalOptions{}.MarshalAppend(prefix, m) + if err != test.err || !bytes.Equal(b, want) { + t.Errorf("MarshalAppend(%v, %v) = %v, %v; want %v, %v", prefix, test, b, err, test.bytes, test.err) + } + } +} + +func TestLegacyUnmarshalMethod(t *testing.T) { + sm := &selfMarshaler{} + m := impl.Export{}.MessageOf(sm).Interface() + want := []byte("unmarshal") + if err := proto.Unmarshal(want, m); err != nil { + t.Fatalf("proto.Unmarshal(selfMarshaler{}) = %v, want nil", err) + } + if !bytes.Equal(sm.bytes, want) { + t.Fatalf("proto.Unmarshal(selfMarshaler{}): Marshal method not called") + } +}