diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go index 66dcbcd0..3c7ec706 100644 --- a/reflect/protoregistry/registry.go +++ b/reflect/protoregistry/registry.go @@ -94,7 +94,8 @@ type Files struct { // Note that enum values are in the top-level since that are in the same // scope as the parent enum. descsByName map[protoreflect.FullName]interface{} - filesByPath map[string]protoreflect.FileDescriptor + filesByPath map[string][]protoreflect.FileDescriptor + numFiles int } type packageDescriptor struct { @@ -117,17 +118,11 @@ func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error { r.descsByName = map[protoreflect.FullName]interface{}{ "": &packageDescriptor{}, } - r.filesByPath = make(map[string]protoreflect.FileDescriptor) + r.filesByPath = make(map[string][]protoreflect.FileDescriptor) } path := file.Path() - if prev := r.filesByPath[path]; prev != nil { + if len(r.filesByPath[path]) > 0 { r.checkGenProtoConflict(path) - err := errors.New("file %q is already registered", file.Path()) - err = amendErrorWithCaller(err, prev, file) - if r == GlobalFiles && ignoreConflict(file, err) { - err = nil - } - return err } for name := file.Package(); name != ""; name = name.Parent() { @@ -168,7 +163,8 @@ func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error { rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) { r.descsByName[d.FullName()] = d }) - r.filesByPath[path] = file + r.filesByPath[path] = append(r.filesByPath[path], file) + r.numFiles++ return nil } @@ -308,6 +304,7 @@ func (s *nameSuffix) Pop() (name protoreflect.Name) { // FindFileByPath looks up a file by the path. // // This returns (nil, NotFound) if not found. +// This returns an error if multiple files have the same path. func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) { if r == nil { return nil, NotFound @@ -316,13 +313,19 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error) globalMutex.RLock() defer globalMutex.RUnlock() } - if fd, ok := r.filesByPath[path]; ok { - return fd, nil + fds := r.filesByPath[path] + switch len(fds) { + case 0: + return nil, NotFound + case 1: + return fds[0], nil + default: + return nil, errors.New("multiple files named %q", path) } - return nil, NotFound } -// NumFiles reports the number of registered files. +// NumFiles reports the number of registered files, +// including duplicate files with the same name. func (r *Files) NumFiles() int { if r == nil { return 0 @@ -331,10 +334,11 @@ func (r *Files) NumFiles() int { globalMutex.RLock() defer globalMutex.RUnlock() } - return len(r.filesByPath) + return r.numFiles } // RangeFiles iterates over all registered files while f returns true. +// If multiple files have the same name, RangeFiles iterates over all of them. // The iteration order is undefined. func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) { if r == nil { @@ -344,9 +348,11 @@ func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) { globalMutex.RLock() defer globalMutex.RUnlock() } - for _, file := range r.filesByPath { - if !f(file) { - return + for _, files := range r.filesByPath { + for _, file := range files { + if !f(file) { + return + } } } } diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go index 446e0365..bc8019c8 100644 --- a/reflect/protoregistry/registry_test.go +++ b/reflect/protoregistry/registry_test.go @@ -55,6 +55,7 @@ func TestFiles(t *testing.T) { testFindPath struct { inPath string wantFiles []file + wantErr string } ) @@ -68,7 +69,7 @@ func TestFiles(t *testing.T) { files: []testFile{ {inFile: mustMakeFile(`syntax:"proto2" name:"test1.proto" package:"foo.bar"`)}, {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"my.test"`)}, - {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`), wantErr: "already registered"}, + {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/test.proto" package:"foo.bar.baz"`)}, {inFile: mustMakeFile(`syntax:"proto2" name:"test2.proto" package:"my.test.package"`)}, {inFile: mustMakeFile(`syntax:"proto2" name:"weird" package:"foo.bar"`)}, {inFile: mustMakeFile(`syntax:"proto2" name:"foo/bar/baz/../test.proto" package:"my.test"`)}, @@ -103,17 +104,16 @@ func TestFiles(t *testing.T) { }}, findPaths: []testFindPath{{ - inPath: "nothing", + inPath: "nothing", + wantErr: "not found", }, { inPath: "weird", wantFiles: []file{ {"weird", "foo.bar"}, }, }, { - inPath: "foo/bar/test.proto", - wantFiles: []file{ - {"foo/bar/test.proto", "my.test"}, - }, + inPath: "foo/bar/test.proto", + wantErr: `multiple files named "foo/bar/test.proto"`, }}, }, { // Test when new enum conflicts with existing package. @@ -315,9 +315,13 @@ func TestFiles(t *testing.T) { for _, tc := range tt.findPaths { var gotFiles []file - if fd, err := files.FindFileByPath(tc.inPath); err == nil { + fd, gotErr := files.FindFileByPath(tc.inPath) + if gotErr == nil { gotFiles = append(gotFiles, file{fd.Path(), fd.Package()}) } + if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) { + t.Errorf("FindFileByPath(%v) = %v, want %v", tc.inPath, gotErr, tc.wantErr) + } if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" { t.Errorf("FindFileByPath(%v) mismatch (-want +got):\n%v", tc.inPath, diff) }