diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go new file mode 100644 index 00000000..b2295814 --- /dev/null +++ b/internal/impl/legacy_extension.go @@ -0,0 +1,126 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package impl + +import ( + "reflect" + "sync" + "unsafe" + + protoV1 "github.com/golang/protobuf/proto" + pref "github.com/golang/protobuf/v2/reflect/protoreflect" +) + +// TODO: The logic below this is a hack since v1 currently exposes no +// exported functionality for interacting with these data structures. +// Eventually make changes to v1 such that v2 can access the necessary +// fields without relying on unsafe. + +var ( + extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil)) + extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{}) +) + +type legacyExtensionIface interface { + Len() int + Get(pref.FieldNumber) legacyExtensionEntry + Set(pref.FieldNumber, legacyExtensionEntry) + Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) +} + +func makeLegacyExtensionMapFunc(t reflect.Type) func(*messageDataType) legacyExtensionIface { + fx1, _ := t.FieldByName("XXX_extensions") + fx2, _ := t.FieldByName("XXX_InternalExtensions") + switch { + case fx1.Type == extTypeA: + return func(p *messageDataType) legacyExtensionIface { + rv := p.p.asType(t).Elem() + return (*legacyExtensionMap)(unsafe.Pointer(rv.UnsafeAddr() + fx1.Offset)) + } + case fx2.Type == extTypeB: + return func(p *messageDataType) legacyExtensionIface { + rv := p.p.asType(t).Elem() + return (*legacyExtensionSyncMap)(unsafe.Pointer(rv.UnsafeAddr() + fx2.Offset)) + } + default: + return nil + } +} + +// legacyExtensionSyncMap is identical to protoV1.XXX_InternalExtensions. +// It implements legacyExtensionIface. +type legacyExtensionSyncMap struct { + p *struct { + mu sync.Mutex + m legacyExtensionMap + } +} + +func (m legacyExtensionSyncMap) Len() int { + if m.p == nil { + return 0 + } + m.p.mu.Lock() + defer m.p.mu.Unlock() + return m.p.m.Len() +} +func (m legacyExtensionSyncMap) Get(n pref.FieldNumber) legacyExtensionEntry { + if m.p == nil { + return legacyExtensionEntry{} + } + m.p.mu.Lock() + defer m.p.mu.Unlock() + return m.p.m.Get(n) +} +func (m *legacyExtensionSyncMap) Set(n pref.FieldNumber, x legacyExtensionEntry) { + if m.p == nil { + m.p = new(struct { + mu sync.Mutex + m legacyExtensionMap + }) + } + m.p.mu.Lock() + defer m.p.mu.Unlock() + m.p.m.Set(n, x) +} +func (m legacyExtensionSyncMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) { + if m.p == nil { + return + } + m.p.mu.Lock() + defer m.p.mu.Unlock() + m.p.m.Range(f) +} + +// legacyExtensionMap is identical to map[int32]protoV1.Extension. +// It implements legacyExtensionIface. +type legacyExtensionMap map[pref.FieldNumber]legacyExtensionEntry + +func (m legacyExtensionMap) Len() int { + return len(m) +} +func (m legacyExtensionMap) Get(n pref.FieldNumber) legacyExtensionEntry { + return m[n] +} +func (m *legacyExtensionMap) Set(n pref.FieldNumber, x legacyExtensionEntry) { + if *m == nil { + *m = make(map[pref.FieldNumber]legacyExtensionEntry) + } + (*m)[n] = x +} +func (m legacyExtensionMap) Range(f func(pref.FieldNumber, legacyExtensionEntry) bool) { + for n, x := range m { + if !f(n, x) { + return + } + } +} + +// legacyExtensionEntry is identical to protoV1.Extension. +type legacyExtensionEntry struct { + desc *protoV1.ExtensionDesc + val interface{} + raw []byte +} diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go index d2e71d40..caa00434 100644 --- a/internal/impl/legacy_test.go +++ b/internal/impl/legacy_test.go @@ -10,6 +10,7 @@ import ( "reflect" "testing" + protoV1 "github.com/golang/protobuf/proto" "github.com/golang/protobuf/v2/internal/encoding/pack" "github.com/golang/protobuf/v2/internal/pragma" pref "github.com/golang/protobuf/v2/reflect/protoreflect" @@ -137,6 +138,15 @@ func TestLegacyDescriptor(t *testing.T) { } } +type legacyUnknownMessage struct { + XXX_unrecognized []byte + protoV1.XXX_InternalExtensions +} + +func (*legacyUnknownMessage) ExtensionRangeArray() []protoV1.ExtensionRange { + return []protoV1.ExtensionRange{{Start: 10, End: 20}, {Start: 40, End: 80}} +} + func TestLegacyUnknown(t *testing.T) { rawOf := func(toks ...pack.Token) pref.RawFields { return pref.RawFields(pack.Message(toks).Marshal()) @@ -149,6 +159,17 @@ func TestLegacyUnknown(t *testing.T) { raw3a := rawOf(pack.Tag{3, pack.StartGroupType}, pack.Tag{3, pack.EndGroupType}) // 1b1c raw3b := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("\xde\xad\xbe\xef")) // 1a04deadbeef + raw1 := rawOf(pack.Tag{1, pack.BytesType}, pack.Bytes("1")) // 0a0131 + raw3 := rawOf(pack.Tag{3, pack.BytesType}, pack.Bytes("3")) // 1a0133 + raw10 := rawOf(pack.Tag{10, pack.BytesType}, pack.Bytes("10")) // 52023130 - extension + raw15 := rawOf(pack.Tag{15, pack.BytesType}, pack.Bytes("15")) // 7a023135 - extension + raw26 := rawOf(pack.Tag{26, pack.BytesType}, pack.Bytes("26")) // d201023236 + raw32 := rawOf(pack.Tag{32, pack.BytesType}, pack.Bytes("32")) // 8202023332 + raw45 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("45")) // ea02023435 - extension + raw46 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("46")) // ea02023436 - extension + raw47 := rawOf(pack.Tag{45, pack.BytesType}, pack.Bytes("47")) // ea02023437 - extension + raw99 := rawOf(pack.Tag{99, pack.BytesType}, pack.Bytes("99")) // 9a06023939 + joinRaw := func(bs ...pref.RawFields) (out []byte) { for _, b := range bs { out = append(out, b...) @@ -156,11 +177,13 @@ func TestLegacyUnknown(t *testing.T) { return out } - var fs legacyUnknownBytes + m := new(legacyUnknownMessage) + fs := new(MessageType).MessageOf(m).UnknownFields() + if got, want := fs.Len(), 0; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } @@ -170,7 +193,7 @@ func TestLegacyUnknown(t *testing.T) { if got, want := fs.Len(), 1; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } @@ -178,7 +201,7 @@ func TestLegacyUnknown(t *testing.T) { if got, want := fs.Len(), 2; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(raw1a, raw1b, raw1c, raw2a); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } @@ -196,12 +219,12 @@ func TestLegacyUnknown(t *testing.T) { if got, want := fs.Len(), 1; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(raw2a); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(raw2a); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } // Simulate manual appending of raw field data. - fs = append(fs, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...) + m.XXX_unrecognized = append(m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw2b, raw3b, raw1c)...) if got, want := fs.Len(), 3; got != want { t.Errorf("Len() = %d, want %d", got, want) } @@ -232,14 +255,14 @@ func TestLegacyUnknown(t *testing.T) { if got, want := fs.Len(), 3; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw1a, raw1b, raw3b, raw1c, raw2a, raw2b); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } fs.Set(1, nil) // remove field 1 if got, want := fs.Len(), 2; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(raw3a, raw3b, raw2a, raw2b); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } @@ -251,7 +274,102 @@ func TestLegacyUnknown(t *testing.T) { if got, want := fs.Len(), 0; got != want { t.Errorf("Len() = %d, want %d", got, want) } - if got, want := []byte(fs), joinRaw(); !bytes.Equal(got, want) { + if got, want := m.XXX_unrecognized, joinRaw(); !bytes.Equal(got, want) { t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) } + + fs.Set(1, raw1) + if got, want := fs.Len(), 1; got != want { + t.Errorf("Len() = %d, want %d", got, want) + } + if got, want := m.XXX_unrecognized, joinRaw(raw1); !bytes.Equal(got, want) { + t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) + } + + fs.Set(45, raw45) + fs.Set(10, raw10) // extension + fs.Set(32, raw32) + fs.Set(1, nil) // deletion + fs.Set(26, raw26) + fs.Set(47, raw47) // extension + fs.Set(46, raw46) // extension + if got, want := fs.Len(), 6; got != want { + t.Errorf("Len() = %d, want %d", got, want) + } + if got, want := m.XXX_unrecognized, joinRaw(raw32, raw26); !bytes.Equal(got, want) { + t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) + } + + // Verify iteration order. + i = 0 + want = []struct { + num pref.FieldNumber + raw pref.RawFields + }{ + {32, raw32}, + {26, raw26}, + {10, raw10}, // extension + {45, raw45}, // extension + {46, raw46}, // extension + {47, raw47}, // extension + } + fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool { + if i < len(want) { + if num != want[i].num || !bytes.Equal(raw, want[i].raw) { + t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw) + } + } else { + t.Errorf("unexpected Range iteration: %d", i) + } + i++ + return true + }) + + // Perform partial deletion while iterating. + i = 0 + fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool { + if i%2 == 0 { + fs.Set(num, nil) + } + i++ + return true + }) + + if got, want := fs.Len(), 3; got != want { + t.Errorf("Len() = %d, want %d", got, want) + } + if got, want := m.XXX_unrecognized, joinRaw(raw26); !bytes.Equal(got, want) { + t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) + } + + fs.Set(15, raw15) // extension + fs.Set(3, raw3) + fs.Set(99, raw99) + if got, want := fs.Len(), 6; got != want { + t.Errorf("Len() = %d, want %d", got, want) + } + if got, want := m.XXX_unrecognized, joinRaw(raw26, raw3, raw99); !bytes.Equal(got, want) { + t.Errorf("data mismatch:\ngot: %x\nwant: %x", got, want) + } + + // Perform partial iteration. + i = 0 + want = []struct { + num pref.FieldNumber + raw pref.RawFields + }{ + {26, raw26}, + {3, raw3}, + } + fs.Range(func(num pref.FieldNumber, raw pref.RawFields) bool { + if i < len(want) { + if num != want[i].num || !bytes.Equal(raw, want[i].raw) { + t.Errorf("Range(%d) = (%d, %x), want (%d, %x)", i, num, raw, want[i].num, want[i].raw) + } + } else { + t.Errorf("unexpected Range iteration: %d", i) + } + i++ + return i < 2 + }) } diff --git a/internal/impl/legacy_unknown.go b/internal/impl/legacy_unknown.go index a319f056..9ab617bd 100644 --- a/internal/impl/legacy_unknown.go +++ b/internal/impl/legacy_unknown.go @@ -7,36 +7,104 @@ package impl import ( "container/list" "reflect" + "sort" - protoV1 "github.com/golang/protobuf/proto" "github.com/golang/protobuf/v2/internal/encoding/wire" pref "github.com/golang/protobuf/v2/reflect/protoreflect" ) -var ( - extTypeA = reflect.TypeOf(map[int32]protoV1.Extension(nil)) - extTypeB = reflect.TypeOf(protoV1.XXX_InternalExtensions{}) -) - -func generateLegacyUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) func(p *messageDataType) pref.UnknownFields { +func makeLegacyUnknownFieldsFunc(t reflect.Type) func(p *messageDataType) pref.UnknownFields { fu, ok := t.FieldByName("XXX_unrecognized") if !ok || fu.Type != bytesType { return nil } - fx1, _ := t.FieldByName("XXX_extensions") - fx2, _ := t.FieldByName("XXX_InternalExtensions") - if fx1.Type == extTypeA || fx2.Type == extTypeB { - // TODO: In proto v1, the unknown fields are split between both - // XXX_unrecognized and XXX_InternalExtensions. If the message supports - // extensions, then we will need to create a wrapper data structure - // that presents unknown fields in both lists as a single ordered list. - panic("not implemented") - } fieldOffset := offsetOf(fu) - return func(p *messageDataType) pref.UnknownFields { + unkFunc := func(p *messageDataType) pref.UnknownFields { rv := p.p.apply(fieldOffset).asType(bytesType) return (*legacyUnknownBytes)(rv.Interface().(*[]byte)) } + extFunc := makeLegacyExtensionMapFunc(t) + if extFunc != nil { + return func(p *messageDataType) pref.UnknownFields { + return &legacyUnknownBytesAndExtensionMap{ + unkFunc(p), extFunc(p), p.mi.Desc.ExtensionRanges(), + } + } + } + return unkFunc +} + +// legacyUnknownBytesAndExtensionMap is a wrapper around both XXX_unrecognized +// and also the extension field map. +type legacyUnknownBytesAndExtensionMap struct { + u pref.UnknownFields + x legacyExtensionIface + r pref.FieldRanges +} + +func (fs *legacyUnknownBytesAndExtensionMap) Len() int { + n := fs.u.Len() + fs.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool { + if len(x.raw) > 0 { + n++ + } + return true + }) + return n +} + +func (fs *legacyUnknownBytesAndExtensionMap) Get(num pref.FieldNumber) (raw pref.RawFields) { + if fs.r.Has(num) { + return fs.x.Get(num).raw + } + return fs.u.Get(num) +} + +func (fs *legacyUnknownBytesAndExtensionMap) Set(num pref.FieldNumber, raw pref.RawFields) { + if fs.r.Has(num) { + x := fs.x.Get(num) + x.raw = raw + fs.x.Set(num, x) + return + } + fs.u.Set(num, raw) +} + +func (fs *legacyUnknownBytesAndExtensionMap) Range(f func(pref.FieldNumber, pref.RawFields) bool) { + // Range over unknown fields not in the extension range. + // Create a closure around f to capture whether iteration terminated early. + var stop bool + fs.u.Range(func(n pref.FieldNumber, b pref.RawFields) bool { + stop = stop || !f(n, b) + return !stop + }) + if stop { + return + } + + // Range over unknown fields in the extension range in ascending order + // to ensure protoreflect.UnknownFields.Range remains deterministic. + type entry struct { + num pref.FieldNumber + raw pref.RawFields + } + var xs []entry + fs.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool { + if len(x.raw) > 0 { + xs = append(xs, entry{n, x.raw}) + } + return true + }) + sort.Slice(xs, func(i, j int) bool { return xs[i].num < xs[j].num }) + for _, x := range xs { + if !f(x.num, x.raw) { + return + } + } +} + +func (fs *legacyUnknownBytesAndExtensionMap) IsSupported() bool { + return true } // legacyUnknownBytes is a wrapper around XXX_unrecognized that implements diff --git a/internal/impl/message.go b/internal/impl/message.go index 2b0a9ca3..a3e4b9d7 100644 --- a/internal/impl/message.go +++ b/internal/impl/message.go @@ -53,9 +53,8 @@ func (mi *MessageType) init(p interface{}) { mi.goType = t // Derive the message descriptor if unspecified. - md := mi.Desc - if md == nil { - // TODO: derive the message type from the Go struct type + if mi.Desc == nil { + mi.Desc = loadMessageDesc(t) } // Initialize the Go message type wrapper if the Go type does not @@ -68,7 +67,7 @@ func (mi *MessageType) init(p interface{}) { // Generated code ensures that this property holds. if _, ok := p.(pref.ProtoMessage); !ok { mi.pbType = ptype.NewGoMessage(&ptype.GoMessage{ - MessageDescriptor: md, + MessageDescriptor: mi.Desc, New: func(pref.MessageType) pref.ProtoMessage { p := reflect.New(t.Elem()).Interface() return (*message)(mi.dataTypeOf(p)) @@ -76,9 +75,9 @@ func (mi *MessageType) init(p interface{}) { }) } - mi.generateKnownFieldFuncs(t.Elem(), md) - mi.generateUnknownFieldFuncs(t.Elem(), md) - mi.generateExtensionFieldFuncs(t.Elem(), md) + mi.makeKnownFieldsFunc(t.Elem()) + mi.makeUnknownFieldsFunc(t.Elem()) + mi.makeExtensionFieldsFunc(t.Elem()) }) // TODO: Remove this check? This API is primarily used by generated code, @@ -90,14 +89,14 @@ func (mi *MessageType) init(p interface{}) { } } -// generateKnownFieldFuncs generates per-field functions for all operations +// makeKnownFieldsFunc generates per-field functions for all operations // to be performed on each field. It takes in a reflect.Type representing the // Go struct, and a protoreflect.MessageDescriptor to match with the fields // in the struct. // // This code assumes that the struct is well-formed and panics if there are // any discrepancies. -func (mi *MessageType) generateKnownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { +func (mi *MessageType) makeKnownFieldsFunc(t reflect.Type) { // Generate a mapping of field numbers and names to Go struct field or type. fields := map[pref.FieldNumber]reflect.StructField{} oneofs := map[pref.Name]reflect.StructField{} @@ -140,8 +139,8 @@ fieldLoop: } mi.fields = map[pref.FieldNumber]*fieldInfo{} - for i := 0; i < md.Fields().Len(); i++ { - fd := md.Fields().Get(i) + for i := 0; i < mi.Desc.Fields().Len(); i++ { + fd := mi.Desc.Fields().Get(i) fs := fields[fd.Number()] var fi fieldInfo switch { @@ -162,8 +161,8 @@ fieldLoop: } } -func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { - if f := generateLegacyUnknownFieldFuncs(t, md); f != nil { +func (mi *MessageType) makeUnknownFieldsFunc(t reflect.Type) { + if f := makeLegacyUnknownFieldsFunc(t); f != nil { mi.unknownFields = f return } @@ -172,7 +171,7 @@ func (mi *MessageType) generateUnknownFieldFuncs(t reflect.Type, md pref.Message } } -func (mi *MessageType) generateExtensionFieldFuncs(t reflect.Type, md pref.MessageDescriptor) { +func (mi *MessageType) makeExtensionFieldsFunc(t reflect.Type) { // TODO mi.extensionFields = func(*messageDataType) pref.KnownFields { return emptyExtensionFields{}