From be5348c905867acf1dc5eb65c6571a9c3f8a8ca2 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Tue, 23 Oct 2018 18:31:18 -0700 Subject: [PATCH] internal/impl: setup scaffolding for unknown and extension fields Setup scaffolding for implementing unknown and extension fields. Add functions to MessageType to produce a protoreflect.KnownFields or protoreflect.UnknownFields from a message pointer. Within the implementation of known fields, delegate the logic to the underlying extension fields (which also implements protoreflect.KnownFields) if the field number is not found in the set of defined fields. Change-Id: I2c35f4cdf1c7b58727ce6a582861ef18b8d69a61 Reviewed-on: https://go-review.googlesource.com/c/144280 Reviewed-by: Damien Neil --- internal/impl/message.go | 91 ++++++++++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/internal/impl/message.go b/internal/impl/message.go index 1150449d..7552157b 100644 --- a/internal/impl/message.go +++ b/internal/impl/message.go @@ -34,6 +34,9 @@ type MessageType struct { // TODO: Split fields into dense and sparse maps similar to the current // table-driven implementation in v1? fields map[pref.FieldNumber]*fieldInfo + + unknownFields func(*messageDataType) pref.UnknownFields + extensionFields func(*messageDataType) pref.KnownFields } // init lazily initializes the MessageType upon first use and @@ -73,7 +76,9 @@ func (mi *MessageType) init(p interface{}) { }) } - mi.generateFieldFuncs(t.Elem(), md) + mi.generateKnownFieldFuncs(t.Elem(), md) + mi.generateUnknownFieldFuncs(t.Elem(), md) + mi.generateExtensionFieldFuncs(t.Elem(), md) }) // TODO: Remove this check? This API is primarily used by generated code, @@ -85,14 +90,14 @@ func (mi *MessageType) init(p interface{}) { } } -// generateFieldFuncs generates per-field functions for all common operations +// generateKnownFieldFuncs generates per-field functions for all operations // to be performed on each field. It takes in a reflect.Type representing the // Go struct, and a protoreflect.MessageDescriptor to match with the fields // in the struct. // // This code assumes that the struct is well-formed and panics if there are // any discrepancies. -func (mi *MessageType) generateFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { +func (mi *MessageType) generateKnownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { // Generate a mapping of field numbers and names to Go struct field or type. fields := map[pref.FieldNumber]reflect.StructField{} oneofs := map[pref.Name]reflect.StructField{} @@ -157,6 +162,20 @@ fieldLoop: } } +func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { + // TODO + mi.unknownFields = func(*messageDataType) pref.UnknownFields { + return emptyUnknownFields{} + } +} + +func (mi *MessageType) generateExtensionFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { + // TODO + mi.extensionFields = func(*messageDataType) pref.KnownFields { + return emptyExtensionFields{} + } +} + func (mi *MessageType) MessageOf(p interface{}) pref.Message { mi.init(p) if m, ok := p.(pref.ProtoMessage); ok { @@ -174,7 +193,7 @@ func (mi *MessageType) KnownFieldsOf(p interface{}) pref.KnownFields { func (mi *MessageType) UnknownFieldsOf(p interface{}) pref.UnknownFields { mi.init(p) - return (*unknownFields)(mi.dataTypeOf(p)) + return mi.unknownFields(mi.dataTypeOf(p)) } func (mi *MessageType) dataTypeOf(p interface{}) *messageDataType { @@ -213,7 +232,7 @@ func (m *message) KnownFields() pref.KnownFields { return (*knownFields)(m) } func (m *message) UnknownFields() pref.UnknownFields { - return (*unknownFields)(m) + return m.mi.unknownFields((*messageDataType)(m)) } func (m *message) Unwrap() interface{} { // TODO: unexport? return m.p.asType(m.mi.goType.Elem()).Interface() @@ -234,45 +253,39 @@ func (fs *knownFields) Len() (cnt int) { cnt++ } } - // TODO: Handle extension fields. - return cnt + return cnt + fs.extensionFields().Len() } func (fs *knownFields) Has(n pref.FieldNumber) bool { if fi := fs.mi.fields[n]; fi != nil { return fi.has(fs.p) } - // TODO: Handle extension fields. - return false + return fs.extensionFields().Has(n) } func (fs *knownFields) Get(n pref.FieldNumber) pref.Value { if fi := fs.mi.fields[n]; fi != nil { return fi.get(fs.p) } - // TODO: Handle extension fields. - return pref.Value{} + return fs.extensionFields().Get(n) } func (fs *knownFields) Set(n pref.FieldNumber, v pref.Value) { if fi := fs.mi.fields[n]; fi != nil { fi.set(fs.p, v) return } - // TODO: Handle extension fields. - panic(fmt.Sprintf("invalid field: %d", n)) + fs.extensionFields().Set(n, v) } func (fs *knownFields) Clear(n pref.FieldNumber) { if fi := fs.mi.fields[n]; fi != nil { fi.clear(fs.p) return } - // TODO: Handle extension fields. - panic(fmt.Sprintf("invalid field: %d", n)) + fs.extensionFields().Clear(n) } func (fs *knownFields) Mutable(n pref.FieldNumber) pref.Mutable { if fi := fs.mi.fields[n]; fi != nil { return fi.mutable(fs.p) } - // TODO: Handle extension fields. - panic(fmt.Sprintf("invalid field: %d", n)) + return fs.extensionFields().Mutable(n) } func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) { for n, fi := range fs.mi.fields { @@ -282,25 +295,39 @@ func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) { } } } - // TODO: Handle extension fields. + fs.extensionFields().Range(f) } func (fs *knownFields) ExtensionTypes() pref.ExtensionFieldTypes { - return (*extensionFieldTypes)(fs) + return fs.extensionFields().ExtensionTypes() +} +func (fs *knownFields) extensionFields() pref.KnownFields { + return fs.mi.extensionFields((*messageDataType)(fs)) } -type extensionFieldTypes messageDataType // TODO +type emptyUnknownFields struct{} -func (fs *extensionFieldTypes) Len() int { return 0 } -func (fs *extensionFieldTypes) Register(pref.ExtensionType) { return } -func (fs *extensionFieldTypes) Remove(pref.ExtensionType) { return } -func (fs *extensionFieldTypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil } -func (fs *extensionFieldTypes) ByName(pref.FullName) pref.ExtensionType { return nil } -func (fs *extensionFieldTypes) Range(f func(pref.ExtensionType) bool) { return } +func (emptyUnknownFields) Len() int { return 0 } +func (emptyUnknownFields) Get(pref.FieldNumber) pref.RawFields { return nil } +func (emptyUnknownFields) Set(pref.FieldNumber, pref.RawFields) { /* noop */ } +func (emptyUnknownFields) Range(func(pref.FieldNumber, pref.RawFields) bool) {} +func (emptyUnknownFields) IsSupported() bool { return false } -type unknownFields messageDataType // TODO +type emptyExtensionFields struct{} -func (fs *unknownFields) Len() int { return 0 } -func (fs *unknownFields) Get(n pref.FieldNumber) pref.RawFields { return nil } -func (fs *unknownFields) Set(n pref.FieldNumber, b pref.RawFields) { return } -func (fs *unknownFields) Range(f func(pref.FieldNumber, pref.RawFields) bool) { return } -func (fs *unknownFields) IsSupported() bool { return false } +func (emptyExtensionFields) Len() int { return 0 } +func (emptyExtensionFields) Has(pref.FieldNumber) bool { return false } +func (emptyExtensionFields) Get(pref.FieldNumber) pref.Value { return pref.Value{} } +func (emptyExtensionFields) Set(pref.FieldNumber, pref.Value) { panic("invalid field") } +func (emptyExtensionFields) Clear(pref.FieldNumber) { panic("invalid field") } +func (emptyExtensionFields) Mutable(pref.FieldNumber) pref.Mutable { panic("invalid field") } +func (emptyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) {} +func (emptyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes { return emptyExtensionTypes{} } + +type emptyExtensionTypes struct{} + +func (emptyExtensionTypes) Len() int { return 0 } +func (emptyExtensionTypes) Register(pref.ExtensionType) { panic("extensions not supported") } +func (emptyExtensionTypes) Remove(pref.ExtensionType) { panic("extensions not supported") } +func (emptyExtensionTypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil } +func (emptyExtensionTypes) ByName(pref.FullName) pref.ExtensionType { return nil } +func (emptyExtensionTypes) Range(func(pref.ExtensionType) bool) {}