From 0cf31136c79661a1d8421594363a2b2701e719ba Mon Sep 17 00:00:00 2001 From: John Wright Date: Wed, 15 May 2019 11:39:13 -0600 Subject: [PATCH] internal/prototype: support dynamic enum and message types in extension GoExtension now supports extensions that have enum or message type that is implemented by a Go type that can take on multiple enum or message types (i.e. the actual enum or message type cannot be determined simply from the zero value of its Go type). This is necessary to support dynamic types generated at runtime from descriptors rather than at compile-time. Change-Id: Ia0b3b4b02332fc83c0c85e992b37ded467070472 Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/177338 Reviewed-by: Joe Tsai --- internal/prototype/go_type.go | 5 +- internal/prototype/go_type_test.go | 143 +++++++++++++++++++++++++++++ internal/value/convert.go | 62 +++++++++---- 3 files changed, 193 insertions(+), 17 deletions(-) create mode 100644 internal/prototype/go_type_test.go diff --git a/internal/prototype/go_type.go b/internal/prototype/go_type.go index 255558c5..bf7977fb 100644 --- a/internal/prototype/go_type.go +++ b/internal/prototype/go_type.go @@ -257,15 +257,18 @@ func (t *goExtension) lazyInit() { } case protoreflect.Repeated: var typ reflect.Type + var c value.Converter switch t.Kind() { case protoreflect.EnumKind: typ = t.enumType.GoType() + c = value.NewEnumConverter(t.enumType) case protoreflect.MessageKind, protoreflect.GroupKind: typ = t.messageType.GoType() + c = value.NewMessageConverter(t.messageType) default: typ = goTypeForPBKind[t.Kind()] + c = value.NewConverter(typ, t.Kind()) } - c := value.NewConverter(typ, t.Kind()) t.typ = reflect.PtrTo(reflect.SliceOf(typ)) t.new = func() protoreflect.Value { v := reflect.New(t.typ.Elem()).Interface() diff --git a/internal/prototype/go_type_test.go b/internal/prototype/go_type_test.go new file mode 100644 index 00000000..7200afe7 --- /dev/null +++ b/internal/prototype/go_type_test.go @@ -0,0 +1,143 @@ +package prototype_test + +import ( + "fmt" + "reflect" + "testing" + + "google.golang.org/protobuf/internal/prototype" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" + + testpb "google.golang.org/protobuf/internal/testprotos/test" +) + +func TestGoEnum(t *testing.T) { + enumDescs := []protoreflect.EnumDescriptor{ + testpb.ForeignEnum(0).Descriptor(), + testpb.TestAllTypes_NestedEnum(0).Descriptor(), + } + for _, ed := range enumDescs { + et := prototype.GoEnum(ed, newEnum) + if gotED := et.Descriptor(); gotED != ed { + fmt.Errorf("GoEnum(ed (%v), newEnum).Descriptor() != ed", ed.FullName()) + } + e := et.New(0) + if gotED := e.Descriptor(); gotED != ed { + fmt.Errorf("GoEnum(ed (%v), newEnum).New(0).Descriptor() != ed", ed.FullName()) + } + if n := e.Number(); n != 0 { + fmt.Errorf("GoEnum(ed (%v), newEnum).New(0).Number() = %v; want 0", ed.FullName(), n) + } + if _, ok := e.(fakeEnum); !ok { + fmt.Errorf("GoEnum(ed (%v), newEnum).New(0) type is %T; want fakeEnum", ed.FullName(), e) + } + } +} + +func TestGoMessage(t *testing.T) { + msgDescs := []protoreflect.MessageDescriptor{ + ((*testpb.TestAllTypes)(nil)).ProtoReflect().Descriptor(), + ((*testpb.TestAllTypes_NestedMessage)(nil)).ProtoReflect().Descriptor(), + } + for _, md := range msgDescs { + mt := prototype.GoMessage(md, newMessage) + if gotMD := mt.Descriptor(); gotMD != md { + fmt.Errorf("GoMessage(md (%v), newMessage).Descriptor() != md", md.FullName()) + } + m := mt.New() + if gotMD := m.Descriptor(); gotMD != md { + fmt.Errorf("GoMessage(md (%v), newMessage).New().Descriptor() != md", md.FullName()) + } + if _, ok := m.(*fakeMessage); !ok { + fmt.Errorf("GoMessage(md (%v), newMessage).New() type is %T; want *fakeMessage", md.FullName(), m) + } + } +} + +func TestGoExtension(t *testing.T) { + testCases := []struct { + extName protoreflect.FullName + wantNewType reflect.Type + }{{ + extName: "goproto.proto.test.optional_int32_extension", + wantNewType: reflect.TypeOf(int32(0)), + }, { + extName: "goproto.proto.test.optional_string_extension", + wantNewType: reflect.TypeOf(""), + }, { + extName: "goproto.proto.test.repeated_int32_extension", + wantNewType: reflect.TypeOf((*[]int32)(nil)), + }, { + extName: "goproto.proto.test.repeated_string_extension", + wantNewType: reflect.TypeOf((*[]string)(nil)), + }, { + extName: "goproto.proto.test.repeated_string_extension", + wantNewType: reflect.TypeOf((*[]string)(nil)), + }, { + extName: "goproto.proto.test.optional_nested_enum_extension", + wantNewType: reflect.TypeOf((*fakeEnum)(nil)).Elem(), + }, { + extName: "goproto.proto.test.optional_nested_message_extension", + wantNewType: reflect.TypeOf((*fakeMessageImpl)(nil)), + }, { + extName: "goproto.proto.test.repeated_nested_enum_extension", + wantNewType: reflect.TypeOf((*[]fakeEnum)(nil)), + }, { + extName: "goproto.proto.test.repeated_nested_message_extension", + wantNewType: reflect.TypeOf((*[]*fakeMessageImpl)(nil)), + }} + for _, tc := range testCases { + xd, err := protoregistry.GlobalFiles.FindExtensionByName(tc.extName) + if err != nil { + t.Errorf("GlobalFiles.FindExtensionByName(%q) = _, %v; want _, ", tc.extName, err) + continue + } + var et protoreflect.EnumType + if ed := xd.Enum(); ed != nil { + et = prototype.GoEnum(ed, newEnum) + } + var mt protoreflect.MessageType + if md := xd.Message(); md != nil { + mt = prototype.GoMessage(md, newMessage) + } + xt := prototype.GoExtension(xd, et, mt) + v := xt.InterfaceOf(xt.New()) + if typ := reflect.TypeOf(v); typ != tc.wantNewType { + t.Errorf("GoExtension(xd (%v), et, mt).New() type unwraps to %v; want %v", tc.extName, typ, tc.wantNewType) + } + } +} + +type fakeMessage struct { + imp *fakeMessageImpl + protoreflect.Message +} + +func (m *fakeMessage) Type() protoreflect.MessageType { return m.imp.typ } +func (m *fakeMessage) Descriptor() protoreflect.MessageDescriptor { return m.imp.typ.Descriptor() } +func (m *fakeMessage) Interface() protoreflect.ProtoMessage { return m.imp } + +type fakeMessageImpl struct{ typ protoreflect.MessageType } + +func (m *fakeMessageImpl) ProtoReflect() protoreflect.Message { return &fakeMessage{imp: m} } + +func newMessage(typ protoreflect.MessageType) protoreflect.Message { + return (&fakeMessageImpl{typ: typ}).ProtoReflect() +} + +type fakeEnum struct { + typ protoreflect.EnumType + num protoreflect.EnumNumber +} + +func (e fakeEnum) Descriptor() protoreflect.EnumDescriptor { return e.typ.Descriptor() } +func (e fakeEnum) Type() protoreflect.EnumType { return e.typ } +func (e fakeEnum) Number() protoreflect.EnumNumber { return e.num } + +func newEnum(typ protoreflect.EnumType, num protoreflect.EnumNumber) protoreflect.Enum { + return fakeEnum{ + typ: typ, + num: num, + } +} diff --git a/internal/value/convert.go b/internal/value/convert.go index 793af226..35de822b 100644 --- a/internal/value/convert.go +++ b/internal/value/convert.go @@ -148,22 +148,7 @@ func NewLegacyConverter(t reflect.Type, k pref.Kind, w LegacyWrapper) Converter if t.Kind() == reflect.Ptr && t.Implements(messageIfaceV2) { md := reflect.Zero(t).Interface().(pref.ProtoMessage).ProtoReflect().Descriptor() mt := &messageType{md, t} - return Converter{ - PBValueOf: func(v reflect.Value) pref.Value { - if v.Type() != t { - panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t)) - } - return pref.ValueOf(v.Interface().(pref.ProtoMessage).ProtoReflect()) - }, - GoValueOf: func(v pref.Value) reflect.Value { - rv := reflect.ValueOf(v.Message().Interface()) - if rv.Type() != t { - panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t)) - } - return rv - }, - MessageType: mt, - } + return NewMessageConverter(mt) } // Handle v1 messages, which we need to wrap as a v2 message. @@ -215,6 +200,51 @@ func makeScalarConverter(goType, pbType reflect.Type) Converter { } } +// NewEnumConverter returns a converter for an EnumType, whose GoType must implement protoreflect.Enum. +func NewEnumConverter(et pref.EnumType) Converter { + t := et.GoType() + if !t.Implements(enumIfaceV2) { + panic(fmt.Sprintf("invalid type: %v does not implement %v", t, enumIfaceV2)) + } + return Converter{ + PBValueOf: func(v reflect.Value) pref.Value { + if v.Type() != t { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t)) + } + e := v.Interface().(pref.Enum) + return pref.ValueOf(e.Number()) + }, + GoValueOf: func(v pref.Value) reflect.Value { + return reflect.ValueOf(et.New(v.Enum())) + }, + EnumType: et, + } +} + +// NewMessageConverter returns a converter for a MessageType, whose GoType must implement protoreflect.ProtoMessage. +func NewMessageConverter(mt pref.MessageType) Converter { + t := mt.GoType() + if !t.Implements(messageIfaceV2) { + panic(fmt.Sprintf("invalid type: %v does not implement %v", t, messageIfaceV2)) + } + return Converter{ + PBValueOf: func(v reflect.Value) pref.Value { + if v.Type() != t { + panic(fmt.Sprintf("invalid type: got %v, want %v", v.Type(), t)) + } + return pref.ValueOf(v.Interface().(pref.ProtoMessage).ProtoReflect()) + }, + GoValueOf: func(v pref.Value) reflect.Value { + rv := reflect.ValueOf(v.Message().Interface()) + if rv.Type() != t { + panic(fmt.Sprintf("invalid type: got %v, want %v", rv.Type(), t)) + } + return rv + }, + MessageType: mt, + } +} + // Converter provides functions for converting to/from Go reflect.Value types // and protobuf protoreflect.Value types. type Converter struct {