diff --git a/reflect/protodesc/desc.go b/reflect/protodesc/desc.go index 658cad48..37f254d4 100644 --- a/reflect/protodesc/desc.go +++ b/reflect/protodesc/desc.go @@ -67,6 +67,12 @@ func NewFile(fd *descriptorpb.FileDescriptorProto, r Resolver) (protoreflect.Fil return FileOptions{}.New(fd, r) } +// NewFiles creates a new protoregistry.Files from the provided +// FileDescriptorSet message. See FileOptions.NewFiles for more information. +func NewFiles(fd *descriptorpb.FileDescriptorSet) (*protoregistry.Files, error) { + return FileOptions{}.NewFiles(fd) +} + // New creates a new protoreflect.FileDescriptor from the provided // file descriptor message. The file must represent a valid proto file according // to protobuf semantics. The returned descriptor is a deep copy of the input. @@ -223,3 +229,47 @@ func (is importSet) importPublic(imps protoreflect.FileImports) { } } } + +// NewFiles creates a new protoregistry.Files from the provided +// FileDescriptorSet message. The descriptor set must include only +// valid files according to protobuf semantics. The returned descriptors +// are a deep copy of the input. +func (o FileOptions) NewFiles(fds *descriptorpb.FileDescriptorSet) (*protoregistry.Files, error) { + files := make(map[string]*descriptorpb.FileDescriptorProto) + for _, fd := range fds.File { + if _, ok := files[fd.GetName()]; ok { + return nil, errors.New("file appears multiple times: %q", fd.GetName()) + } + files[fd.GetName()] = fd + } + r := &protoregistry.Files{} + for _, fd := range files { + if err := o.addFileDeps(r, fd, files); err != nil { + return nil, err + } + } + return r, nil +} +func (o FileOptions) addFileDeps(r *protoregistry.Files, fd *descriptorpb.FileDescriptorProto, files map[string]*descriptorpb.FileDescriptorProto) error { + // Set the entry to nil while descending into a file's dependencies to detect cycles. + files[fd.GetName()] = nil + for _, dep := range fd.Dependency { + depfd, ok := files[dep] + if depfd == nil { + if ok { + return errors.New("import cycle in file: %q", dep) + } + continue + } + if err := o.addFileDeps(r, depfd, files); err != nil { + return err + } + } + // Delete the entry once dependencies are processed. + delete(files, fd.GetName()) + f, err := o.New(fd, r) + if err != nil { + return err + } + return r.RegisterFile(f) +} diff --git a/reflect/protodesc/file_test.go b/reflect/protodesc/file_test.go index 744c02ff..9b7e8dd4 100644 --- a/reflect/protodesc/file_test.go +++ b/reflect/protodesc/file_test.go @@ -12,6 +12,7 @@ import ( "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/descriptorpb" @@ -935,3 +936,61 @@ func TestNewFile(t *testing.T) { }) } } + +func TestNewFiles(t *testing.T) { + fdset := &descriptorpb.FileDescriptorSet{ + File: []*descriptorpb.FileDescriptorProto{ + mustParseFile(` + name: "test.proto" + package: "fizz" + dependency: "dep.proto" + message_type: [{ + name: "M2" + field: [{name:"F" number:1 label:LABEL_OPTIONAL type:TYPE_MESSAGE type_name:"M1"}] + }] + `), + // Inputs deliberately out of order. + mustParseFile(` + name: "dep.proto" + package: "fizz" + message_type: [{name:"M1"}] + `), + }, + } + f, err := NewFiles(fdset) + if err != nil { + t.Fatal(err) + } + m1, err := f.FindDescriptorByName("fizz.M1") + if err != nil { + t.Fatalf(`f.FindDescriptorByName("fizz.M1") = %v`, err) + } + m2, err := f.FindDescriptorByName("fizz.M2") + if err != nil { + t.Fatalf(`f.FindDescriptorByName("fizz.M2") = %v`, err) + } + if m2.(protoreflect.MessageDescriptor).Fields().ByName("F").Message() != m1 { + t.Fatalf(`m1.Fields().ByName("F").Message() != m2`) + } +} + +func TestNewFilesImportCycle(t *testing.T) { + fdset := &descriptorpb.FileDescriptorSet{ + File: []*descriptorpb.FileDescriptorProto{ + mustParseFile(` + name: "test.proto" + package: "fizz" + dependency: "dep.proto" + `), + mustParseFile(` + name: "dep.proto" + package: "fizz" + dependency: "test.proto" + `), + }, + } + _, err := NewFiles(fdset) + if err == nil { + t.Fatal("NewFiles with import cycle: success, want error") + } +}