protobuf-go/proto/methods_test.go
Damien Neil c600d6c086 all: do best-effort initialization check on fast path unmarshal
Add a fast check for required fields to the fast path unmarshal.
This is best-effort and will fail to detect some initialized
messages: Messages with more than 64 required fields, messages
split across multiple tags, possibly other cases.

In the cases where it works (which is most of them in practice),
this permits us to skip the IsInitialized check.

Change-Id: I6b70953a333033a5e64fb7ca37a59786cb0f75a0
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215878
Reviewed-by: Joe Tsai <joetsai@google.com>
2020-01-22 20:57:14 +00:00

158 lines
4.7 KiB
Go

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The protoreflect tag disables fast-path methods, including legacy ones.
// +build !protoreflect
package proto_test
import (
"bytes"
"errors"
"fmt"
"testing"
"google.golang.org/protobuf/internal/impl"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoiface"
legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
)
type selfMarshaler struct {
bytes []byte
err error
}
func (m selfMarshaler) Reset() {}
func (m selfMarshaler) ProtoMessage() {}
func (m selfMarshaler) String() string {
return fmt.Sprintf("selfMarshaler{bytes:%v, err:%v}", m.bytes, m.err)
}
func (m selfMarshaler) Marshal() ([]byte, error) {
return m.bytes, m.err
}
func (m *selfMarshaler) Unmarshal(b []byte) error {
m.bytes = b
return m.err
}
func TestLegacyMarshalMethod(t *testing.T) {
for _, test := range []selfMarshaler{
{bytes: []byte("marshal")},
{bytes: []byte("marshal"), err: errors.New("some error")},
} {
m := impl.Export{}.MessageOf(test).Interface()
b, err := proto.Marshal(m)
if err != test.err || !bytes.Equal(b, test.bytes) {
t.Errorf("proto.Marshal(%v) = %v, %v; want %v, %v", test, b, err, test.bytes, test.err)
}
if gotSize, wantSize := proto.Size(m), len(test.bytes); gotSize != wantSize {
t.Fatalf("proto.Size(%v) = %v, want %v", test, gotSize, wantSize)
}
prefix := []byte("prefix")
want := append(prefix, test.bytes...)
b, err = proto.MarshalOptions{}.MarshalAppend(prefix, m)
if err != test.err || !bytes.Equal(b, want) {
t.Errorf("MarshalAppend(%v, %v) = %v, %v; want %v, %v", prefix, test, b, err, test.bytes, test.err)
}
b, err = proto.MarshalOptions{
Deterministic: true,
}.MarshalAppend(nil, m)
if err != test.err || !bytes.Equal(b, test.bytes) {
t.Errorf("MarshalOptions{Deterministic:true}.MarshalAppend(nil, %v) = %v, %v; want %v, %v", test, b, err, test.bytes, test.err)
}
}
}
func TestLegacyUnmarshalMethod(t *testing.T) {
sm := &selfMarshaler{}
m := impl.Export{}.MessageOf(sm).Interface()
want := []byte("unmarshal")
if err := proto.Unmarshal(want, m); err != nil {
t.Fatalf("proto.Unmarshal(selfMarshaler{}) = %v, want nil", err)
}
if !bytes.Equal(sm.bytes, want) {
t.Fatalf("proto.Unmarshal(selfMarshaler{}): Marshal method not called")
}
}
type descPanicSelfMarshaler struct{}
const descPanicSelfMarshalerBytes = "bytes"
func (m *descPanicSelfMarshaler) Reset() {}
func (m *descPanicSelfMarshaler) ProtoMessage() {}
func (m *descPanicSelfMarshaler) Descriptor() ([]byte, []int) { panic("Descriptor method panics") }
func (m *descPanicSelfMarshaler) String() string { return "descPanicSelfMarshaler{}" }
func (m *descPanicSelfMarshaler) Marshal() ([]byte, error) {
return []byte(descPanicSelfMarshalerBytes), nil
}
func TestSelfMarshalerDescriptorPanics(t *testing.T) {
m := &descPanicSelfMarshaler{}
got, err := proto.Marshal(impl.Export{}.MessageOf(m).Interface())
want := []byte(descPanicSelfMarshalerBytes)
if err != nil || !bytes.Equal(got, want) {
t.Fatalf("proto.Marshal(%v) = %v, %v; want %v, nil", m, got, err, want)
}
}
type descSelfMarshaler struct {
someField int // some non-generated field
}
const descSelfMarshalerBytes = "bytes"
func (m *descSelfMarshaler) Reset() {}
func (m *descSelfMarshaler) ProtoMessage() {}
func (m *descSelfMarshaler) Descriptor() ([]byte, []int) {
return ((*legacypb.Legacy)(nil)).GetF1().Descriptor()
}
func (m *descSelfMarshaler) String() string {
return "descSelfMarshaler{}"
}
func (m *descSelfMarshaler) Marshal() ([]byte, error) {
return []byte(descSelfMarshalerBytes), nil
}
func TestSelfMarshalerWithDescriptor(t *testing.T) {
m := &descSelfMarshaler{}
got, err := proto.Marshal(impl.Export{}.MessageOf(m).Interface())
want := []byte(descSelfMarshalerBytes)
if err != nil || !bytes.Equal(got, want) {
t.Fatalf("proto.Marshal(%v) = %v, %v; want %v, nil", m, got, err, want)
}
}
func TestDecodeFastIsInitialized(t *testing.T) {
for _, test := range testValidMessages {
if !test.checkFastInit {
continue
}
for _, message := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, message), func(t *testing.T) {
m := message.ProtoReflect().New()
opts := proto.UnmarshalOptions{
AllowPartial: true,
}
out, err := opts.UnmarshalState(m.Interface(), protoiface.UnmarshalInput{
Buf: test.wire,
})
if err != nil {
t.Fatalf("Unmarshal error: %v", err)
}
if got, want := out.Initialized, !test.partial; got != want {
t.Errorf("out.Initialized = %v, want %v", got, want)
}
})
}
}
}