diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go index 99e1a8b3..a40d1e22 100644 --- a/encoding/protojson/decode.go +++ b/encoding/protojson/decode.go @@ -36,10 +36,13 @@ type UnmarshalOptions struct { // If DiscardUnknown is set, unknown fields are ignored. DiscardUnknown bool - // Resolver is the registry used for type lookups when unmarshaling extensions - // and processing Any. If Resolver is not set, unmarshaling will default to - // using protoregistry.GlobalTypes. - Resolver *protoregistry.Types + // Resolver is used for looking up types when unmarshaling + // google.protobuf.Any messages or extension fields. + // If nil, this defaults to using protoregistry.GlobalTypes. + Resolver interface { + protoregistry.MessageTypeResolver + protoregistry.ExtensionTypeResolver + } decoder *json.Decoder } diff --git a/encoding/protojson/encode.go b/encoding/protojson/encode.go index bd0ee3da..d789986f 100644 --- a/encoding/protojson/encode.go +++ b/encoding/protojson/encode.go @@ -36,10 +36,11 @@ type MarshalOptions struct { // composed of space or tab characters. Indent string - // Resolver is the registry used for type lookups when marshaling - // google.protobuf.Any messages. If Resolver is not set, marshaling will - // default to using protoregistry.GlobalTypes. - Resolver *protoregistry.Types + // Resolver is used for looking up types when expanding google.protobuf.Any + // messages. If nil, this defaults to using protoregistry.GlobalTypes. + Resolver interface { + protoregistry.MessageTypeResolver + } encoder *json.Encoder } diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go index efc4c7af..20bdfe6b 100644 --- a/encoding/prototext/decode.go +++ b/encoding/prototext/decode.go @@ -33,10 +33,13 @@ type UnmarshalOptions struct { // return error if there are any missing required fields. AllowPartial bool - // Resolver is the registry used for type lookups when unmarshaling extensions - // and processing Any. If Resolver is not set, unmarshaling will default to - // using protoregistry.GlobalTypes. - Resolver *protoregistry.Types + // Resolver is used for looking up types when unmarshaling + // google.protobuf.Any messages or extension fields. + // If nil, this defaults to using protoregistry.GlobalTypes. + Resolver interface { + protoregistry.MessageTypeResolver + protoregistry.ExtensionTypeResolver + } } // Unmarshal reads the given []byte and populates the given proto.Message using options in diff --git a/encoding/prototext/encode.go b/encoding/prototext/encode.go index 8f063708..d86492e6 100644 --- a/encoding/prototext/encode.go +++ b/encoding/prototext/encode.go @@ -39,11 +39,11 @@ type MarshalOptions struct { // composed of space or tab characters. Indent string - // Resolver is the registry used for type lookups when marshaling out - // google.protobuf.Any messages in expanded form. If Resolver is not set, - // marshaling will default to using protoregistry.GlobalTypes. If a type is - // not found, an Any message will be marshaled as a regular message. - Resolver *protoregistry.Types + // Resolver is used for looking up types when expanding google.protobuf.Any + // messages. If nil, this defaults to using protoregistry.GlobalTypes. + Resolver interface { + protoregistry.MessageTypeResolver + } } // Marshal writes the given proto.Message in textproto format using options in MarshalOptions object. diff --git a/proto/decode.go b/proto/decode.go index 3d3d0dad..b3766856 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -28,7 +28,9 @@ type UnmarshalOptions struct { // Resolver is used for looking up types when unmarshaling extension fields. // If nil, this defaults to using protoregistry.GlobalTypes. - Resolver *protoregistry.Types + Resolver interface { + protoregistry.ExtensionTypeResolver + } pragma.NoUnkeyedLiterals } diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go index d3ea3fc5..b778cd2e 100644 --- a/reflect/protoregistry/registry.go +++ b/reflect/protoregistry/registry.go @@ -276,38 +276,56 @@ var ( _ Type = protoreflect.ExtensionType(nil) ) +// MessageTypeResolver is an interface for looking up messages. +// +// A compliant implementation must deterministically return the same type +// if no error is encountered. +// +// The Types type implements this interface. +type MessageTypeResolver interface { + // FindMessageByName looks up a message by its full name. + // E.g., "google.protobuf.Any" + // + // This return (nil, NotFound) if not found. + FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) + + // FindMessageByURL looks up a message by a URL identifier. + // See documentation on google.protobuf.Any.type_url for the URL format. + // + // This returns (nil, NotFound) if not found. + FindMessageByURL(url string) (protoreflect.MessageType, error) +} + +// ExtensionTypeResolver is an interface for looking up extensions. +// +// A compliant implementation must deterministically return the same type +// if no error is encountered. +// +// The Types type implements this interface. +type ExtensionTypeResolver interface { + // FindExtensionByName looks up a extension field by the field's full name. + // Note that this is the full name of the field as determined by + // where the extension is declared and is unrelated to the full name of the + // message being extended. + // + // This returns (nil, NotFound) if not found. + FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) + + // FindExtensionByNumber looks up a extension field by the field number + // within some parent message, identified by full name. + // + // This returns (nil, NotFound) if not found. + FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) +} + +var ( + _ MessageTypeResolver = (*Types)(nil) + _ ExtensionTypeResolver = (*Types)(nil) +) + // Types is a registry for looking up or iterating over descriptor types. // The Find and Range methods are safe for concurrent use. type Types struct { - // Parent sets the parent registry to consult if a find operation - // could not locate the appropriate entry. - // - // Setting a parent results in each Range operation also iterating over the - // entries contained within the parent. In such a case, it is possible for - // Range to emit duplicates (since they may exist in both child and parent). - // Range iteration is guaranteed to iterate over local entries before - // iterating over parent entries. - Parent *Types - - // Resolver sets the local resolver to consult if the local registry does - // not contain an entry. The resolver takes precedence over the parent. - // - // The url is a URL where the full name of the type is the last segment - // of the path (i.e. string following the last '/' character). - // When missing a '/' character, the URL is the full name of the type. - // See documentation on the google.protobuf.Any.type_url field for details. - // - // If the resolver returns a result, it is not automatically registered - // into the local registry. Thus, a resolver function should cache results - // such that it deterministically returns the same result given the - // same URL assuming the error returned is nil or NotFound. - // - // If the resolver returns the NotFound error, the registry will consult the - // parent registry if it is set. - // - // Setting a resolver has no effect on the result of each Range operation. - Resolver func(url string) (Type, error) - // TODO: The syntax of the URL is ill-defined and the protobuf team recently // changed the documented semantics in a way that breaks prior usages. // I do not believe they can do this and need to sync up with the @@ -342,7 +360,6 @@ type ( // NewTypes returns a registry initialized with the provided set of types. // If there are conflicts, the first one takes precedence. func NewTypes(typs ...Type) *Types { - // TODO: Allow setting resolver and parent via constructor? r := new(Types) r.Register(typs...) // ignore errors; first takes precedence return r @@ -418,25 +435,17 @@ typeLoop: // // This returns (nil, NotFound) if not found. func (r *Types) FindEnumByName(enum protoreflect.FullName) (protoreflect.EnumType, error) { - r.globalCheck() if r == nil { return nil, NotFound } v, _ := r.typesByName[enum] - if v == nil && r.Resolver != nil { - var err error - v, err = r.Resolver(string(enum)) - if err != nil && err != NotFound { - return nil, err - } - } if v != nil { if et, _ := v.(protoreflect.EnumType); et != nil { return et, nil } return nil, errors.New("found wrong type: got %v, want enum", typeName(v)) } - return r.Parent.FindEnumByName(enum) + return nil, NotFound } // FindMessageByName looks up a message by its full name. @@ -449,11 +458,10 @@ func (r *Types) FindMessageByName(message protoreflect.FullName) (protoreflect.M } // FindMessageByURL looks up a message by a URL identifier. -// See Resolver for the format of the URL. +// See documentation on google.protobuf.Any.type_url for the URL format. // // This returns (nil, NotFound) if not found. func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { - r.globalCheck() if r == nil { return nil, NotFound } @@ -463,20 +471,13 @@ func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { } v, _ := r.typesByName[message] - if v == nil && r.Resolver != nil { - var err error - v, err = r.Resolver(url) - if err != nil && err != NotFound { - return nil, err - } - } if v != nil { if mt, _ := v.(protoreflect.MessageType); mt != nil { return mt, nil } return nil, errors.New("found wrong type: got %v, want message", typeName(v)) } - return r.Parent.FindMessageByURL(url) + return nil, NotFound } // FindExtensionByName looks up a extension field by the field's full name. @@ -486,25 +487,17 @@ func (r *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { // // This returns (nil, NotFound) if not found. func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { - r.globalCheck() if r == nil { return nil, NotFound } v, _ := r.typesByName[field] - if v == nil && r.Resolver != nil { - var err error - v, err = r.Resolver(string(field)) - if err != nil && err != NotFound { - return nil, err - } - } if v != nil { if xt, _ := v.(protoreflect.ExtensionType); xt != nil { return xt, nil } return nil, errors.New("found wrong type: got %v, want extension", typeName(v)) } - return r.Parent.FindExtensionByName(field) + return nil, NotFound } // FindExtensionByNumber looks up a extension field by the field number @@ -512,20 +505,18 @@ func (r *Types) FindExtensionByName(field protoreflect.FullName) (protoreflect.E // // This returns (nil, NotFound) if not found. func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { - r.globalCheck() if r == nil { return nil, NotFound } if xt, ok := r.extensionsByMessage[message][field]; ok { return xt, nil } - return r.Parent.FindExtensionByNumber(message, field) + return nil, NotFound } // RangeEnums iterates over all registered enums. // Iteration order is undefined. func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) { - r.globalCheck() if r == nil { return } @@ -536,13 +527,11 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) { } } } - r.Parent.RangeEnums(f) } // RangeMessages iterates over all registered messages. // Iteration order is undefined. func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) { - r.globalCheck() if r == nil { return } @@ -553,13 +542,11 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) { } } } - r.Parent.RangeMessages(f) } // RangeExtensions iterates over all registered extensions. // Iteration order is undefined. func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) { - r.globalCheck() if r == nil { return } @@ -570,13 +557,11 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) { } } } - r.Parent.RangeExtensions(f) } // RangeExtensionsByMessage iterates over all registered extensions filtered // by a given message type. Iteration order is undefined. func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) { - r.globalCheck() if r == nil { return } @@ -585,13 +570,6 @@ func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(p return } } - r.Parent.RangeExtensionsByMessage(message, f) -} - -func (r *Types) globalCheck() { - if r == GlobalTypes && (r.Parent != nil || r.Resolver != nil) { - panic("GlobalTypes.Parent and GlobalTypes.Resolver cannot be set") - } } func typeName(t Type) string { diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go index 5da8f3d3..055879a0 100644 --- a/reflect/protoregistry/registry_test.go +++ b/reflect/protoregistry/registry_test.go @@ -302,41 +302,12 @@ func TestFilesLookup(t *testing.T) { } func TestTypes(t *testing.T) { - // Suffix 1 in registry, 2 in parent, 3 in resolver. mt1 := pimpl.Export{}.MessageTypeOf(&testpb.Message1{}) - mt2 := pimpl.Export{}.MessageTypeOf(&testpb.Message2{}) - mt3 := pimpl.Export{}.MessageTypeOf(&testpb.Message3{}) et1 := pimpl.Export{}.EnumTypeOf(testpb.Enum1_ONE) - et2 := pimpl.Export{}.EnumTypeOf(testpb.Enum2_UNO) - et3 := pimpl.Export{}.EnumTypeOf(testpb.Enum3_YI) - // Suffix indicates field number. - xt11 := testpb.E_StringField.Type - xt12 := testpb.E_EnumField.Type - xt13 := testpb.E_MessageField.Type - xt21 := testpb.E_Message4_MessageField.Type - xt22 := testpb.E_Message4_EnumField.Type - xt23 := testpb.E_Message4_StringField.Type - parent := &preg.Types{} - if err := parent.Register(mt2, et2, xt12, xt22); err != nil { - t.Fatalf("parent.Register() returns unexpected error: %v", err) - } - registry := &preg.Types{ - Parent: parent, - Resolver: func(url string) (preg.Type, error) { - switch { - case strings.HasSuffix(url, "testprotos.Message3"): - return mt3, nil - case strings.HasSuffix(url, "testprotos.Enum3"): - return et3, nil - case strings.HasSuffix(url, "testprotos.message_field"): - return xt13, nil - case strings.HasSuffix(url, "testprotos.Message4.string_field"): - return xt23, nil - } - return nil, preg.NotFound - }, - } - if err := registry.Register(mt1, et1, xt11, xt21); err != nil { + xt1 := testpb.E_StringField.Type + xt2 := testpb.E_Message4_MessageField.Type + registry := new(preg.Types) + if err := registry.Register(mt1, et1, xt1, xt2); err != nil { t.Fatalf("registry.Register() returns unexpected error: %v", err) } @@ -349,12 +320,6 @@ func TestTypes(t *testing.T) { }{{ name: "testprotos.Message1", messageType: mt1, - }, { - name: "testprotos.Message2", - messageType: mt2, - }, { - name: "testprotos.Message3", - messageType: mt3, }, { name: "testprotos.NoSuchMessage", wantErr: true, @@ -395,12 +360,6 @@ func TestTypes(t *testing.T) { }{{ name: "testprotos.Message1", messageType: mt1, - }, { - name: "foo.com/testprotos.Message2", - messageType: mt2, - }, { - name: "/testprotos.Message3", - messageType: mt3, }, { name: "type.googleapis.com/testprotos.Nada", wantErr: true, @@ -435,12 +394,6 @@ func TestTypes(t *testing.T) { }{{ name: "testprotos.Enum1", enumType: et1, - }, { - name: "testprotos.Enum2", - enumType: et2, - }, { - name: "testprotos.Enum3", - enumType: et3, }, { name: "testprotos.None", wantErr: true, @@ -474,22 +427,10 @@ func TestTypes(t *testing.T) { wantNotFound bool }{{ name: "testprotos.string_field", - extensionType: xt11, - }, { - name: "testprotos.enum_field", - extensionType: xt12, - }, { - name: "testprotos.message_field", - extensionType: xt13, + extensionType: xt1, }, { name: "testprotos.Message4.message_field", - extensionType: xt21, - }, { - name: "testprotos.Message4.enum_field", - extensionType: xt22, - }, { - name: "testprotos.Message4.string_field", - extensionType: xt23, + extensionType: xt2, }, { name: "testprotos.None", wantErr: true, @@ -525,13 +466,8 @@ func TestTypes(t *testing.T) { }{{ parent: "testprotos.Message1", number: 11, - extensionType: xt11, + extensionType: xt1, }, { - parent: "testprotos.Message1", - number: 12, - extensionType: xt12, - }, { - // FindExtensionByNumber does not use Resolver. parent: "testprotos.Message1", number: 13, wantErr: true, @@ -539,13 +475,8 @@ func TestTypes(t *testing.T) { }, { parent: "testprotos.Message1", number: 21, - extensionType: xt21, + extensionType: xt2, }, { - parent: "testprotos.Message1", - number: 22, - extensionType: xt22, - }, { - // FindExtensionByNumber does not use Resolver. parent: "testprotos.Message1", number: 23, wantErr: true, @@ -603,8 +534,7 @@ func TestTypes(t *testing.T) { }) t.Run("RangeMessages", func(t *testing.T) { - // RangeMessages do not include messages from Resolver. - want := []preg.Type{mt1, mt2} + want := []preg.Type{mt1} var got []preg.Type registry.RangeMessages(func(mt pref.MessageType) bool { got = append(got, mt) @@ -618,8 +548,7 @@ func TestTypes(t *testing.T) { }) t.Run("RangeEnums", func(t *testing.T) { - // RangeEnums do not include enums from Resolver. - want := []preg.Type{et1, et2} + want := []preg.Type{et1} var got []preg.Type registry.RangeEnums(func(et pref.EnumType) bool { got = append(got, et) @@ -633,8 +562,7 @@ func TestTypes(t *testing.T) { }) t.Run("RangeExtensions", func(t *testing.T) { - // RangeExtensions do not include messages from Resolver. - want := []preg.Type{xt11, xt12, xt21, xt22} + want := []preg.Type{xt1, xt2} var got []preg.Type registry.RangeExtensions(func(xt pref.ExtensionType) bool { got = append(got, xt) @@ -648,8 +576,7 @@ func TestTypes(t *testing.T) { }) t.Run("RangeExtensionsByMessage", func(t *testing.T) { - // RangeExtensions do not include messages from Resolver. - want := []preg.Type{xt11, xt12, xt21, xt22} + want := []preg.Type{xt1, xt2} var got []preg.Type registry.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool { got = append(got, xt) diff --git a/runtime/protoiface/methods.go b/runtime/protoiface/methods.go index 0ba9d5c9..af86d62a 100644 --- a/runtime/protoiface/methods.go +++ b/runtime/protoiface/methods.go @@ -63,7 +63,9 @@ type MarshalOptions struct { type UnmarshalOptions struct { AllowPartial bool DiscardUnknown bool - Resolver *protoregistry.Types + Resolver interface { + protoregistry.ExtensionTypeResolver + } pragma.NoUnkeyedLiterals }