protobuf-go/reflect/protoregistry/registry_test.go
Joe Tsai 15076350e8 reflect/protodesc: enforce strict validation
Hyrum's Law dictates that if we do not prevent naughty behavior,
people will rely on it. If we do not validate that the provided
file descriptor is correct today, it will be near impossible
to add proper validation checks later on.

The logic added validates that the provided file descriptor is
correct according to the same semantics as protoc,
which was reverse-engineered to derive the set of rules implemented here.
The rules are unfortunately complicated because protobuf is a language
full of many non-orthogonal features. While our logic is complicated,
it is still 1/7th the size of the equivalent C++ code!

Change-Id: I6acc5dc3bd2e4c6bea6cd9e81214f8104402602a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/184837
Reviewed-by: Damien Neil <dneil@google.com>
2019-07-03 20:46:51 +00:00

611 lines
17 KiB
Go

// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protoregistry_test
import (
"fmt"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/prototext"
pimpl "google.golang.org/protobuf/internal/impl"
pdesc "google.golang.org/protobuf/reflect/protodesc"
pref "google.golang.org/protobuf/reflect/protoreflect"
preg "google.golang.org/protobuf/reflect/protoregistry"
testpb "google.golang.org/protobuf/reflect/protoregistry/testprotos"
"google.golang.org/protobuf/types/descriptorpb"
)
// mustMakeFile decodes s as a text-format FileDescriptorProto and converts it
// into a file descriptor via pdesc.NewFile (called with a nil second
// argument). It panics if either the decoding or the conversion fails, which
// makes it usable directly inside test-table initializers.
func mustMakeFile(s string) pref.FileDescriptor {
	fdProto := &descriptorpb.FileDescriptorProto{}
	err := prototext.Unmarshal([]byte(s), fdProto)
	if err != nil {
		panic(err)
	}

	desc, err := pdesc.NewFile(fdProto, nil)
	if err != nil {
		panic(err)
	}
	return desc
}
// TestFiles is a table-driven test of preg.Files. For every table entry it
// registers a sequence of file descriptors into a fresh registry, checking
// the error (if any) returned by each Register call, and then checks the
// results of FindDescriptorByName, RangeFilesByPackage and FindFileByPath
// against the expectations listed in the entry.
func TestFiles(t *testing.T) {
type (
// file is the (path, package) pair extracted from a file descriptor;
// it is the unit compared between wanted and observed results.
file struct {
Path string
Pkg pref.FullName
}
// testFile is one Register call: the descriptor to register and a
// substring expected in the returned error. An empty wantErr means
// the registration is expected to succeed.
testFile struct {
inFile pref.FileDescriptor
wantErr string
}
// testFindDesc is one FindDescriptorByName query and whether a
// non-nil descriptor is expected back.
testFindDesc struct {
inName pref.FullName
wantFound bool
}
// testRangePkg is one RangeFilesByPackage query and the files the
// callback is expected to be invoked with (nil means none).
testRangePkg struct {
inPkg pref.FullName
wantFiles []file
}
// testFindPath is one FindFileByPath query and the file expected
// back (nil means the lookup is expected to return an error).
testFindPath struct {
inPath string
wantFiles []file
}
)
tests := []struct {
files []testFile
findDescs []testFindDesc
rangePkgs []testRangePkg
findPaths []testFindPath
}{{
// Test that overlapping packages and files are permitted.
// The third entry reuses the path "foo/bar/test.proto" and is the only
// one expected to be rejected; "foo/bar/baz/../test.proto" is expected
// to register successfully as its own, distinct path.
files: []testFile{
{inFile: mustMakeFile(`syntax:"proto2" name:"test1.proto" package:"foo.bar"`)},
{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"my.test"`)},
{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`), wantErr: "already registered"},
{inFile: mustMakeFile(`syntax:"proto2" name:"test2.proto" package:"my.test.package"`)},
{inFile: mustMakeFile(`syntax:"proto2" name:"weird" package:"foo.bar"`)},
{inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/baz/../test.proto" package:"my.test"`)},
},
// Only exact package names are expected to yield files: prefixes
// ("foo", "fo"), the empty name, and names with stray dots all expect
// no callback invocations.
rangePkgs: []testRangePkg{{
inPkg: "nothing",
}, {
inPkg: "",
}, {
inPkg: ".",
}, {
inPkg: "foo",
}, {
inPkg: "foo.",
}, {
inPkg: "foo..",
}, {
inPkg: "foo.bar",
wantFiles: []file{
{"test1.proto", "foo.bar"},
{"weird", "foo.bar"},
},
}, {
inPkg: "my.test",
wantFiles: []file{
{"foo/bar/baz/../test.proto", "my.test"},
{"foo/bar/test.proto", "my.test"},
},
}, {
inPkg: "fo",
}},
// "foo/bar/test.proto" must resolve to the first (successful)
// registration in package "my.test", not the rejected duplicate.
findPaths: []testFindPath{{
inPath: "nothing",
}, {
inPath: "weird",
wantFiles: []file{
{"weird", "foo.bar"},
},
}, {
inPath: "foo/bar/test.proto",
wantFiles: []file{
{"foo/bar/test.proto", "my.test"},
},
}},
}, {
// Test when new enum conflicts with existing package.
// The second file has no package and declares a top-level enum "foo",
// which collides with the first component of package "foo.bar.baz".
files: []testFile{{
inFile: mustMakeFile(`syntax:"proto2" name:"test1a.proto" package:"foo.bar.baz"`),
}, {
inFile: mustMakeFile(`syntax:"proto2" name:"test1b.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
wantErr: `file "test1b.proto" has a name conflict over foo`,
}},
}, {
// Test when new package conflicts with existing enum.
// Same collision as the previous case, registered in the opposite order.
files: []testFile{{
inFile: mustMakeFile(`syntax:"proto2" name:"test2a.proto" enum_type:[{name:"foo" value:[{name:"VALUE" number:0}]}]`),
}, {
inFile: mustMakeFile(`syntax:"proto2" name:"test2b.proto" package:"foo.bar.baz"`),
wantErr: `file "test2b.proto" has a name conflict over foo`,
}},
}, {
// Test when new enum conflicts with existing enum in same package.
files: []testFile{{
inFile: mustMakeFile(`syntax:"proto2" name:"test3a.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE" number:0}]}]`),
}, {
inFile: mustMakeFile(`syntax:"proto2" name:"test3b.proto" package:"foo" enum_type:[{name:"BAR" value:[{name:"VALUE2" number:0}]}]`),
wantErr: `file "test3b.proto" has a name conflict over foo.BAR`,
}},
}, {
// Test descriptor lookup by full name across four successfully
// registered files. The first file declares a message (with a field,
// oneof, nested enum, nested message and nested extension), a top-level
// enum, a top-level extension and a service with one method.
files: []testFile{{
inFile: mustMakeFile(`
syntax: "proto2"
name: "test1.proto"
package: "fizz.buzz"
message_type: [{
name: "Message"
field: [
{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING oneof_index:0}
]
oneof_decl: [{name:"Oneof"}]
extension_range: [{start:1000 end:2000}]
enum_type: [
{name:"Enum" value:[{name:"EnumValue" number:0}]}
]
nested_type: [
{name:"Message" field:[{name:"Field" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}]}
]
extension: [
{name:"Extension" number:1001 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
]
}]
enum_type: [{
name: "Enum"
value: [{name:"EnumValue" number:0}]
}]
extension: [
{name:"Extension" number:1000 label:LABEL_OPTIONAL type:TYPE_STRING extendee:".fizz.buzz.Message"}
]
service: [{
name: "Service"
method: [{
name: "Method"
input_type: ".fizz.buzz.Message"
output_type: ".fizz.buzz.Message"
client_streaming: true
server_streaming: true
}]
}]
`),
}, {
// A file in the sub-package "fizz.buzz.gazz" of the first file's package.
inFile: mustMakeFile(`
syntax: "proto2"
name: "test2.proto"
package: "fizz.buzz.gazz"
enum_type: [{
name: "Enum"
value: [{name:"EnumValue" number:0}]
}]
`),
}, {
// A second file in package "fizz.buzz" that adds two more enums.
inFile: mustMakeFile(`
syntax: "proto2"
name: "test3.proto"
package: "fizz.buzz"
enum_type: [{
name: "Enum1"
value: [{name:"EnumValue1" number:0}]
}, {
name: "Enum2"
value: [{name:"EnumValue2" number:0}]
}]
`),
}, {
// Make sure we can register without package name.
inFile: mustMakeFile(`
name: "weird"
syntax: "proto2"
message_type: [{
name: "Message"
nested_type: [{
name: "Message"
nested_type: [{
name: "Message"
}]
}]
}]
`),
}},
// Expectations encoded below: lookups are case-sensitive; package names
// themselves ("fizz.buzz", "fizz.buzz.gazz") are not found; enum values
// are found as siblings of their enum (e.g. "fizz.buzz.EnumValue") and
// not beneath it ("fizz.buzz.Enum.EnumValue"); fields, oneofs and
// methods are found only under their parent message or service.
findDescs: []testFindDesc{
{inName: "fizz.buzz.message", wantFound: false},
{inName: "fizz.buzz.Message", wantFound: true},
{inName: "fizz.buzz.Message.X", wantFound: false},
{inName: "fizz.buzz.Field", wantFound: false},
{inName: "fizz.buzz.Oneof", wantFound: false},
{inName: "fizz.buzz.Message.Field", wantFound: true},
{inName: "fizz.buzz.Message.Field.X", wantFound: false},
{inName: "fizz.buzz.Message.Oneof", wantFound: true},
{inName: "fizz.buzz.Message.Oneof.X", wantFound: false},
{inName: "fizz.buzz.Message.Message", wantFound: true},
{inName: "fizz.buzz.Message.Message.X", wantFound: false},
{inName: "fizz.buzz.Message.Enum", wantFound: true},
{inName: "fizz.buzz.Message.Enum.X", wantFound: false},
{inName: "fizz.buzz.Message.EnumValue", wantFound: true},
{inName: "fizz.buzz.Message.EnumValue.X", wantFound: false},
{inName: "fizz.buzz.Message.Extension", wantFound: true},
{inName: "fizz.buzz.Message.Extension.X", wantFound: false},
{inName: "fizz.buzz.enum", wantFound: false},
{inName: "fizz.buzz.Enum", wantFound: true},
{inName: "fizz.buzz.Enum.X", wantFound: false},
{inName: "fizz.buzz.EnumValue", wantFound: true},
{inName: "fizz.buzz.EnumValue.X", wantFound: false},
{inName: "fizz.buzz.Enum.EnumValue", wantFound: false},
{inName: "fizz.buzz.Extension", wantFound: true},
{inName: "fizz.buzz.Extension.X", wantFound: false},
{inName: "fizz.buzz.service", wantFound: false},
{inName: "fizz.buzz.Service", wantFound: true},
{inName: "fizz.buzz.Service.X", wantFound: false},
{inName: "fizz.buzz.Method", wantFound: false},
{inName: "fizz.buzz.Service.Method", wantFound: true},
{inName: "fizz.buzz.Service.Method.X", wantFound: false},
{inName: "fizz.buzz.gazz", wantFound: false},
{inName: "fizz.buzz.gazz.Enum", wantFound: true},
{inName: "fizz.buzz.gazz.EnumValue", wantFound: true},
{inName: "fizz.buzz.gazz.Enum.EnumValue", wantFound: false},
{inName: "fizz.buzz", wantFound: false},
{inName: "fizz.buzz.Enum1", wantFound: true},
{inName: "fizz.buzz.EnumValue1", wantFound: true},
{inName: "fizz.buzz.Enum1.EnumValue1", wantFound: false},
{inName: "fizz.buzz.Enum2", wantFound: true},
{inName: "fizz.buzz.EnumValue2", wantFound: true},
{inName: "fizz.buzz.Enum2.EnumValue2", wantFound: false},
{inName: "fizz.buzz.Enum3", wantFound: false},
{inName: "", wantFound: false},
{inName: "Message", wantFound: true},
{inName: "Message.Message", wantFound: true},
{inName: "Message.Message.Message", wantFound: true},
{inName: "Message.Message.Message.Message", wantFound: false},
},
}}
// sortFiles orders []file by Path, then by Pkg, so that the cmp.Diff
// comparisons below do not depend on the order in which files are visited.
sortFiles := cmpopts.SortSlices(func(x, y file) bool {
return x.Path < y.Path || (x.Path == y.Path && x.Pkg < y.Pkg)
})
for _, tt := range tests {
t.Run("", func(t *testing.T) {
// Each table entry starts from an empty, zero-value registry.
var files preg.Files
for i, tc := range tt.files {
gotErr := files.Register(tc.inFile)
// Fail if the presence of an error disagrees with wantErr, or if the
// error text does not contain wantErr. When wantErr is empty the
// substring check is trivially satisfied.
if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
}
}
for _, tc := range tt.findDescs {
// The returned error is deliberately ignored; only whether a
// descriptor came back is compared.
d, _ := files.FindDescriptorByName(tc.inName)
gotFound := d != nil
if gotFound != tc.wantFound {
t.Errorf("FindDescriptorByName(%v) find mismatch: got %v, want %v", tc.inName, gotFound, tc.wantFound)
}
}
for _, tc := range tt.rangePkgs {
var gotFiles []file
files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
// NOTE(review): returning true presumably tells the registry to keep
// iterating — confirm against the RangeFilesByPackage contract.
return true
})
if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
}
}
for _, tc := range tt.findPaths {
// On a lookup error gotFiles stays nil, which is what entries without
// wantFiles expect.
var gotFiles []file
if fd, err := files.FindFileByPath(tc.inPath); err == nil {
gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
}
if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff)
}
}
})
}
}
// TestTypes registers one message type, one enum type and two extension types
// in a preg.Types registry, then checks every Find* and Range* method of the
// registry against that fixed set.
//
// In the Find* subtests each case states whether an error is expected
// (wantErr) and, if so, whether that error must be exactly preg.NotFound
// (wantNotFound). Cases that look up a name through the wrong Find method
// (for example an enum name via FindMessageByName) only require some error.
func TestTypes(t *testing.T) {
	var (
		msgT    = pimpl.Export{}.MessageTypeOf(&testpb.Message1{})
		enumT   = pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE)
		strExtT = testpb.E_StringField.Type
		msgExtT = testpb.E_Message4_MessageField.Type
	)

	reg := &preg.Types{}
	if err := reg.Register(msgT, enumT, strExtT, msgExtT); err != nil {
		t.Fatalf("registry.Register() returns unexpected error: %v", err)
	}

	t.Run("FindMessageByName", func(t *testing.T) {
		cases := []struct {
			name         string
			messageType  pref.MessageType
			wantErr      bool
			wantNotFound bool
		}{
			{name: "testprotos.Message1", messageType: msgT},
			{name: "testprotos.NoSuchMessage", wantErr: true, wantNotFound: true},
			{name: "testprotos.Enum1", wantErr: true},
			{name: "testprotos.Enum2", wantErr: true},
			{name: "testprotos.Enum3", wantErr: true},
		}
		for _, c := range cases {
			got, err := reg.FindMessageByName(pref.FullName(c.name))
			// At most one failure is reported per case, in this priority order.
			switch {
			case (err != nil) != c.wantErr:
				t.Errorf("FindMessageByName(%v) = (_, %v), want error? %t", c.name, err, c.wantErr)
			case c.wantNotFound && err != preg.NotFound:
				t.Errorf("FindMessageByName(%v) got error: %v, want NotFound error", c.name, err)
			case got != c.messageType:
				t.Errorf("FindMessageByName(%v) got wrong value: %v", c.name, got)
			}
		}
	})

	t.Run("FindMessageByURL", func(t *testing.T) {
		cases := []struct {
			name         string
			messageType  pref.MessageType
			wantErr      bool
			wantNotFound bool
		}{
			{name: "testprotos.Message1", messageType: msgT},
			{name: "type.googleapis.com/testprotos.Nada", wantErr: true, wantNotFound: true},
			{name: "testprotos.Enum1", wantErr: true},
		}
		for _, c := range cases {
			got, err := reg.FindMessageByURL(c.name)
			switch {
			case (err != nil) != c.wantErr:
				t.Errorf("FindMessageByURL(%v) = (_, %v), want error? %t", c.name, err, c.wantErr)
			case c.wantNotFound && err != preg.NotFound:
				t.Errorf("FindMessageByURL(%v) got error: %v, want NotFound error", c.name, err)
			case got != c.messageType:
				t.Errorf("FindMessageByURL(%v) got wrong value: %v", c.name, got)
			}
		}
	})

	t.Run("FindEnumByName", func(t *testing.T) {
		cases := []struct {
			name         string
			enumType     pref.EnumType
			wantErr      bool
			wantNotFound bool
		}{
			{name: "testprotos.Enum1", enumType: enumT},
			{name: "testprotos.None", wantErr: true, wantNotFound: true},
			{name: "testprotos.Message1", wantErr: true},
		}
		for _, c := range cases {
			got, err := reg.FindEnumByName(pref.FullName(c.name))
			switch {
			case (err != nil) != c.wantErr:
				t.Errorf("FindEnumByName(%v) = (_, %v), want error? %t", c.name, err, c.wantErr)
			case c.wantNotFound && err != preg.NotFound:
				t.Errorf("FindEnumByName(%v) got error: %v, want NotFound error", c.name, err)
			case got != c.enumType:
				t.Errorf("FindEnumByName(%v) got wrong value: %v", c.name, got)
			}
		}
	})

	t.Run("FindExtensionByName", func(t *testing.T) {
		cases := []struct {
			name          string
			extensionType pref.ExtensionType
			wantErr       bool
			wantNotFound  bool
		}{
			{name: "testprotos.string_field", extensionType: strExtT},
			{name: "testprotos.Message4.message_field", extensionType: msgExtT},
			{name: "testprotos.None", wantErr: true, wantNotFound: true},
			{name: "testprotos.Message1", wantErr: true},
		}
		for _, c := range cases {
			got, err := reg.FindExtensionByName(pref.FullName(c.name))
			switch {
			case (err != nil) != c.wantErr:
				t.Errorf("FindExtensionByName(%v) = (_, %v), want error? %t", c.name, err, c.wantErr)
			case c.wantNotFound && err != preg.NotFound:
				t.Errorf("FindExtensionByName(%v) got error: %v, want NotFound error", c.name, err)
			case got != c.extensionType:
				t.Errorf("FindExtensionByName(%v) got wrong value: %v", c.name, got)
			}
		}
	})

	t.Run("FindExtensionByNumber", func(t *testing.T) {
		cases := []struct {
			parent        string
			number        int32
			extensionType pref.ExtensionType
			wantErr       bool
			wantNotFound  bool
		}{
			{parent: "testprotos.Message1", number: 11, extensionType: strExtT},
			{parent: "testprotos.Message1", number: 13, wantErr: true, wantNotFound: true},
			{parent: "testprotos.Message1", number: 21, extensionType: msgExtT},
			{parent: "testprotos.Message1", number: 23, wantErr: true, wantNotFound: true},
			{parent: "testprotos.NoSuchMessage", number: 11, wantErr: true, wantNotFound: true},
			{parent: "testprotos.Message1", number: 30, wantErr: true, wantNotFound: true},
			{parent: "testprotos.Message1", number: 99, wantErr: true, wantNotFound: true},
		}
		for _, c := range cases {
			got, err := reg.FindExtensionByNumber(pref.FullName(c.parent), pref.FieldNumber(c.number))
			switch {
			case (err != nil) != c.wantErr:
				t.Errorf("FindExtensionByNumber(%v, %d) = (_, %v), want error? %t", c.parent, c.number, err, c.wantErr)
			case c.wantNotFound && err != preg.NotFound:
				t.Errorf("FindExtensionByNumber(%v, %d) got error %v, want NotFound error", c.parent, c.number, err)
			case got != c.extensionType:
				t.Errorf("FindExtensionByNumber(%v, %d) got wrong value: %v", c.parent, c.number, got)
			}
		}
	})

	// nameOf returns the full name of the descriptor behind a registered
	// type; it panics for anything that is not an enum, message or
	// extension type.
	nameOf := func(typ preg.Type) pref.FullName {
		switch v := typ.(type) {
		case pref.EnumType:
			return v.Descriptor().FullName()
		case pref.MessageType:
			return v.Descriptor().FullName()
		case pref.ExtensionType:
			return v.Descriptor().FullName()
		}
		panic("invalid type")
	}
	// byName sorts by full name so the Range* comparisons are independent of
	// iteration order; identity compares registered types with ==.
	byName := cmpopts.SortSlices(func(a, b preg.Type) bool {
		return nameOf(a) < nameOf(b)
	})
	identity := cmp.Comparer(func(a, b preg.Type) bool {
		return a == b
	})

	t.Run("RangeMessages", func(t *testing.T) {
		var seen []preg.Type
		reg.RangeMessages(func(mt pref.MessageType) bool {
			seen = append(seen, mt)
			return true
		})
		expected := []preg.Type{msgT}
		if diff := cmp.Diff(expected, seen, byName, identity); diff != "" {
			t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
		}
	})

	t.Run("RangeEnums", func(t *testing.T) {
		var seen []preg.Type
		reg.RangeEnums(func(et pref.EnumType) bool {
			seen = append(seen, et)
			return true
		})
		expected := []preg.Type{enumT}
		if diff := cmp.Diff(expected, seen, byName, identity); diff != "" {
			t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
		}
	})

	t.Run("RangeExtensions", func(t *testing.T) {
		var seen []preg.Type
		reg.RangeExtensions(func(xt pref.ExtensionType) bool {
			seen = append(seen, xt)
			return true
		})
		expected := []preg.Type{strExtT, msgExtT}
		if diff := cmp.Diff(expected, seen, byName, identity); diff != "" {
			t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
		}
	})

	t.Run("RangeExtensionsByMessage", func(t *testing.T) {
		var seen []preg.Type
		reg.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool {
			seen = append(seen, xt)
			return true
		})
		expected := []preg.Type{strExtT, msgExtT}
		if diff := cmp.Diff(expected, seen, byName, identity); diff != "" {
			t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
		}
	})
}