diff --git a/internal/cmd/pbdump/pbdump.go b/internal/cmd/pbdump/pbdump.go index 87d60307..befb8310 100644 --- a/internal/cmd/pbdump/pbdump.go +++ b/internal/cmd/pbdump/pbdump.go @@ -128,7 +128,7 @@ func main() { // Parse and print message structure. defer log.Printf("fatal input: %q", buf) // debug printout if panic occurs var m protopack.Message - m.UnmarshalDescriptor(buf, desc) + m.UnmarshalAbductive(buf, desc) if *printSource { fmt.Printf("%#v\n", m) } else { diff --git a/testing/protopack/pack.go b/testing/protopack/pack.go index d39593ac..683ce0bd 100644 --- a/testing/protopack/pack.go +++ b/testing/protopack/pack.go @@ -270,7 +270,7 @@ func (m Message) Marshal() []byte { // // Unmarshal is useful for debugging the protobuf wire format. func (m *Message) Unmarshal(in []byte) { - m.UnmarshalDescriptor(in, nil) + m.unmarshal(in, nil, false) } // UnmarshalDescriptor parses the input protobuf wire data as a syntax tree @@ -289,22 +289,40 @@ func (m *Message) Unmarshal(in []byte) { // Known sub-messages are parsed as a Message and packed repeated fields are // parsed as a LengthPrefix. func (m *Message) UnmarshalDescriptor(in []byte, desc protoreflect.MessageDescriptor) { + m.unmarshal(in, desc, false) +} + +// UnmarshalAbductive is like UnmarshalDescriptor, but infers abductively +// whether any unknown bytes value is a message based on whether it is +// a syntactically well-formed message. +// +// Note that the protobuf wire format is not fully self-describing, +// so abductive inference may attempt to expand a bytes value as a message +// that is not actually a message. It is a best-effort guess. 
+func (m *Message) UnmarshalAbductive(in []byte, desc protoreflect.MessageDescriptor) { + m.unmarshal(in, desc, true) +} + +func (m *Message) unmarshal(in []byte, desc protoreflect.MessageDescriptor, inferMessage bool) { p := parser{in: in, out: *m} - p.parseMessage(desc, false) + p.parseMessage(desc, false, inferMessage) *m = p.out } type parser struct { in []byte out []Token + + invalid bool } -func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group bool) { +func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group, inferMessage bool) { for len(p.in) > 0 { v, n := protowire.ConsumeVarint(p.in) num, typ := protowire.DecodeTag(v) - if n < 0 || num < 0 || v > math.MaxUint32 { + if n < 0 || num <= 0 || v > math.MaxUint32 { p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true return } if typ == EndGroupType && group { @@ -341,13 +359,14 @@ func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group bool case Fixed64Type: p.parseFixed64(kind) case BytesType: - p.parseBytes(isPacked, kind, subDesc) + p.parseBytes(isPacked, kind, subDesc, inferMessage) case StartGroupType: - p.parseGroup(subDesc) + p.parseGroup(num, subDesc, inferMessage) case EndGroupType: - // Handled above. + // Handled by p.parseGroup. 
default: p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true } } } @@ -356,6 +375,7 @@ func (p *parser) parseVarint(kind protoreflect.Kind) { v, n := protowire.ConsumeVarint(p.in) if n < 0 { p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true return } switch kind { @@ -384,6 +404,7 @@ func (p *parser) parseFixed32(kind protoreflect.Kind) { v, n := protowire.ConsumeFixed32(p.in) if n < 0 { p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true return } switch kind { @@ -400,6 +421,7 @@ func (p *parser) parseFixed64(kind protoreflect.Kind) { v, n := protowire.ConsumeFixed64(p.in) if n < 0 { p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true return } switch kind { @@ -412,10 +434,11 @@ func (p *parser) parseFixed64(kind protoreflect.Kind) { } } -func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor) { +func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor, inferMessage bool) { v, n := protowire.ConsumeVarint(p.in) if n < 0 { p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true return } p.out, p.in = append(p.out, Uvarint(v)), p.in[n:] @@ -424,6 +447,7 @@ func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoref } if v > uint64(len(p.in)) { p.out, p.in = append(p.out, Raw(p.in)), nil + p.invalid = true return } p.out = p.out[:len(p.out)-1] // subsequent tokens contain prefix-length @@ -434,11 +458,22 @@ func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoref switch kind { case protoreflect.MessageKind: p2 := parser{in: p.in[:v]} - p2.parseMessage(desc, false) + p2.parseMessage(desc, false, inferMessage) p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:] case protoreflect.StringKind: p.out, p.in = append(p.out, String(p.in[:v])), p.in[v:] + case protoreflect.BytesKind: + p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:] default: + if inferMessage { + // Check 
whether this is a syntactically valid message. + p2 := parser{in: p.in[:v]} + p2.parseMessage(nil, false, inferMessage) + if !p2.invalid { + p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:] + break + } + } p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:] } } @@ -466,9 +501,9 @@ func (p *parser) parsePacked(n int, kind protoreflect.Kind) { p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[n:] } -func (p *parser) parseGroup(desc protoreflect.MessageDescriptor) { +func (p *parser) parseGroup(startNum protowire.Number, desc protoreflect.MessageDescriptor, inferMessage bool) { p2 := parser{in: p.in} - p2.parseMessage(desc, true) + p2.parseMessage(desc, true, inferMessage) if len(p2.out) > 0 { p.out = append(p.out, Message(p2.out)) } @@ -476,8 +511,11 @@ func (p *parser) parseGroup(desc protoreflect.MessageDescriptor) { // Append the trailing end group. v, n := protowire.ConsumeVarint(p.in) - if num, typ := protowire.DecodeTag(v); typ == EndGroupType { - p.out, p.in = append(p.out, Tag{num, typ}), p.in[n:] + if endNum, typ := protowire.DecodeTag(v); typ == EndGroupType { + if startNum != endNum { + p.invalid = true + } + p.out, p.in = append(p.out, Tag{endNum, typ}), p.in[n:] if m := n - protowire.SizeVarint(v); m > 0 { p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]} } diff --git a/testing/protopack/pack_test.go b/testing/protopack/pack_test.go index 97525497..61ea336c 100644 --- a/testing/protopack/pack_test.go +++ b/testing/protopack/pack_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/protobuf/encoding/prototext" pdesc "google.golang.org/protobuf/reflect/protodesc" + "google.golang.org/protobuf/reflect/protoreflect" pref "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/descriptorpb" @@ -67,8 +68,10 @@ func dhex(s string) []byte { func TestPack(t *testing.T) { tests := []struct { - raw []byte - msg Message + raw []byte + msg Message + msgDesc protoreflect.MessageDescriptor + inferMsg bool 
wantOutCompact string wantOutMulti string @@ -81,12 +84,22 @@ func TestPack(t *testing.T) { Tag{1, VarintType}, Denormalized{5, Uvarint(2)}, Tag{1, BytesType}, LengthPrefix{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}}, }, + msgDesc: msgDesc, wantOutSource: `protopack.Message{ protopack.Tag{1, protopack.VarintType}, protopack.Bool(false), protopack.Denormalized{+5, protopack.Tag{1, protopack.VarintType}}, protopack.Uvarint(2), protopack.Tag{1, protopack.VarintType}, protopack.Denormalized{+5, protopack.Uvarint(2)}, protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix{protopack.Bool(true), protopack.Bool(false), protopack.Uvarint(2), protopack.Denormalized{+5, protopack.Uvarint(2)}}, }`, + }, { + raw: dhex("080088808080800002088280808080000a09010002828080808000"), + msg: Message{ + Tag{1, VarintType}, Uvarint(0), + Denormalized{5, Tag{1, VarintType}}, Uvarint(2), + Tag{1, VarintType}, Denormalized{5, Uvarint(2)}, + Tag{1, BytesType}, Bytes(Message{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}}.Marshal()), + }, + inferMsg: true, }, { raw: dhex("100010828080808000121980808080808080808001ffffffffffffffff7f828080808000"), msg: Message{ @@ -94,6 +107,7 @@ func TestPack(t *testing.T) { Tag{2, VarintType}, Denormalized{5, Varint(2)}, Tag{2, BytesType}, LengthPrefix{Varint(math.MinInt64), Varint(math.MaxInt64), Denormalized{5, Varint(2)}}, }, + msgDesc: msgDesc, wantOutCompact: `Message{Tag{2, Varint}, Varint(0), Tag{2, Varint}, Denormalized{+5, Varint(2)}, Tag{2, Bytes}, LengthPrefix{Varint(-9223372036854775808), Varint(9223372036854775807), Denormalized{+5, Varint(2)}}}`, }, { raw: dhex("1801188180808080001a1affffffffffffffffff01feffffffffffffffff01818080808000"), @@ -102,6 +116,7 @@ func TestPack(t *testing.T) { Tag{3, VarintType}, Denormalized{5, Svarint(-1)}, Tag{3, BytesType}, LengthPrefix{Svarint(math.MinInt64), Svarint(math.MaxInt64), Denormalized{5, Svarint(-1)}}, }, + msgDesc: msgDesc, wantOutMulti: `Message{ Tag{3, 
Varint}, Svarint(-1), Tag{3, Varint}, Denormalized{+5, Svarint(-1)}, @@ -114,6 +129,7 @@ func TestPack(t *testing.T) { Tag{4, VarintType}, Denormalized{5, Uvarint(+1)}, Tag{4, BytesType}, LengthPrefix{Uvarint(0), Uvarint(math.MaxUint64), Denormalized{5, Uvarint(+1)}}, }, + msgDesc: msgDesc, wantOutSource: `protopack.Message{ protopack.Tag{4, protopack.VarintType}, protopack.Uvarint(1), protopack.Tag{4, protopack.VarintType}, protopack.Denormalized{+5, protopack.Uvarint(1)}, @@ -125,6 +141,7 @@ func TestPack(t *testing.T) { Tag{5, Fixed32Type}, Uint32(+1), Tag{5, BytesType}, LengthPrefix{Uint32(0), Uint32(math.MaxUint32)}, }, + msgDesc: msgDesc, wantOutCompact: `Message{Tag{5, Fixed32}, Uint32(1), Tag{5, Bytes}, LengthPrefix{Uint32(0), Uint32(4294967295)}}`, }, { raw: dhex("35ffffffff320800000080ffffff7f"), @@ -132,6 +149,7 @@ func TestPack(t *testing.T) { Tag{6, Fixed32Type}, Int32(-1), Tag{6, BytesType}, LengthPrefix{Int32(math.MinInt32), Int32(math.MaxInt32)}, }, + msgDesc: msgDesc, wantOutMulti: `Message{ Tag{6, Fixed32}, Int32(-1), Tag{6, Bytes}, LengthPrefix{Int32(-2147483648), Int32(2147483647)}, @@ -142,6 +160,7 @@ func TestPack(t *testing.T) { Tag{7, Fixed32Type}, Float32(math.Pi), Tag{7, BytesType}, LengthPrefix{Float32(math.SmallestNonzeroFloat32), Float32(math.MaxFloat32), Float32(math.Inf(+1)), Float32(math.Inf(-1))}, }, + msgDesc: msgDesc, wantOutSource: `protopack.Message{ protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(3.1415927), protopack.Tag{7, protopack.BytesType}, protopack.LengthPrefix{protopack.Float32(1e-45), protopack.Float32(3.4028235e+38), protopack.Float32(math.Inf(+1)), protopack.Float32(math.Inf(-1))}, @@ -152,6 +171,7 @@ func TestPack(t *testing.T) { Tag{8, Fixed64Type}, Uint64(+1), Tag{8, BytesType}, LengthPrefix{Uint64(0), Uint64(math.MaxUint64)}, }, + msgDesc: msgDesc, wantOutCompact: `Message{Tag{8, Fixed64}, Uint64(1), Tag{8, Bytes}, LengthPrefix{Uint64(0), Uint64(18446744073709551615)}}`, }, { raw: 
dhex("49ffffffffffffffff4a100000000000000080ffffffffffffff7f"), @@ -159,6 +179,7 @@ func TestPack(t *testing.T) { Tag{9, Fixed64Type}, Int64(-1), Tag{9, BytesType}, LengthPrefix{Int64(math.MinInt64), Int64(math.MaxInt64)}, }, + msgDesc: msgDesc, wantOutMulti: `Message{ Tag{9, Fixed64}, Int64(-1), Tag{9, Bytes}, LengthPrefix{Int64(-9223372036854775808), Int64(9223372036854775807)}, @@ -169,6 +190,7 @@ func TestPack(t *testing.T) { Tag{10, Fixed64Type}, Float64(math.Pi), Tag{10, BytesType}, LengthPrefix{Float64(math.SmallestNonzeroFloat64), Float64(math.MaxFloat64), Float64(math.Inf(+1)), Float64(math.Inf(-1))}, }, + msgDesc: msgDesc, wantOutMulti: `Message{ Tag{10, Fixed64}, Float64(3.141592653589793), Tag{10, Bytes}, LengthPrefix{Float64(5e-324), Float64(1.7976931348623157e+308), Float64(+Inf), Float64(-Inf)}, @@ -179,6 +201,7 @@ func TestPack(t *testing.T) { Tag{11, BytesType}, String("string"), Tag{11, BytesType}, Denormalized{+5, String("string")}, }, + msgDesc: msgDesc, wantOutCompact: `Message{Tag{11, Bytes}, String("string"), Tag{11, Bytes}, Denormalized{+5, String("string")}}`, }, { raw: dhex("62056279746573628580808080006279746573"), @@ -186,6 +209,7 @@ func TestPack(t *testing.T) { Tag{12, BytesType}, Bytes("bytes"), Tag{12, BytesType}, Denormalized{+5, Bytes("bytes")}, }, + msgDesc: msgDesc, wantOutMulti: `Message{ Tag{12, Bytes}, Bytes("bytes"), Tag{12, Bytes}, Denormalized{+5, Bytes("bytes")}, @@ -201,6 +225,7 @@ func TestPack(t *testing.T) { Tag{100, StartGroupType}, Tag{100, EndGroupType}, }), }, + msgDesc: msgDesc, wantOutSource: `protopack.Message{ protopack.Tag{13, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{ protopack.Tag{100, protopack.VarintType}, protopack.Uvarint(18446744073709551615), @@ -211,6 +236,30 @@ func TestPack(t *testing.T) { protopack.Tag{100, protopack.EndGroupType}, }), }`, + }, { + raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"), + msg: Message{ + Tag{13, 
BytesType}, LengthPrefix(Message{ + Tag{100, VarintType}, Uvarint(math.MaxUint64), + Tag{100, Fixed32Type}, Uint32(math.MaxUint32), + Tag{100, Fixed64Type}, Uint64(math.MaxUint64), + Tag{100, BytesType}, Bytes("bytes"), + Tag{100, StartGroupType}, Tag{100, EndGroupType}, + }), + }, + inferMsg: true, + }, { + raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306ac06"), + msg: Message{ + Tag{13, BytesType}, Bytes(Message{ + Tag{100, VarintType}, Uvarint(math.MaxUint64), + Tag{100, Fixed32Type}, Uint32(math.MaxUint32), + Tag{100, Fixed64Type}, Uint64(math.MaxUint64), + Tag{100, BytesType}, Bytes("bytes"), + Tag{100, StartGroupType}, Tag{101, EndGroupType}, + }.Marshal()), + }, + inferMsg: true, }, { raw: dhex("6aa88080808000a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"), msg: Message{ @@ -222,6 +271,7 @@ func TestPack(t *testing.T) { Tag{100, StartGroupType}, Tag{100, EndGroupType}, })}, }, + msgDesc: msgDesc, wantOutCompact: `Message{Tag{13, Bytes}, Denormalized{+5, LengthPrefix(Message{Tag{100, Varint}, Uvarint(18446744073709551615), Tag{100, Fixed32}, Uint32(4294967295), Tag{100, Fixed64}, Uint64(18446744073709551615), Tag{100, Bytes}, Bytes("bytes"), Tag{100, StartGroup}, Tag{100, EndGroup}})}}`, }, { raw: dhex("73a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a40674"), @@ -235,6 +285,7 @@ func TestPack(t *testing.T) { }, Tag{14, EndGroupType}, }, + msgDesc: msgDesc, wantOutMulti: `Message{ Tag{14, StartGroup}, Message{ @@ -261,6 +312,7 @@ func TestPack(t *testing.T) { Tag{1706, Type(7)}, Raw("\x1an\x98\x11\xc8Z*\xb3"), }, + msgDesc: msgDesc, }, { raw: dhex("3d08d0e57f"), msg: Message{ @@ -269,6 +321,7 @@ func TestPack(t *testing.T) { func() uint32 { return 0x7fe5d008 }(), )), }, + msgDesc: msgDesc, wantOutSource: `protopack.Message{ protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(math.Float32frombits(0x7fe5d008)), }`, @@ -277,6 +330,7 @@ func 
TestPack(t *testing.T) { msg: Message{ Tag{10, Fixed64Type}, Float64(math.Float64frombits(0x7ff91b771051d6a8)), }, + msgDesc: msgDesc, wantOutSource: `protopack.Message{ protopack.Tag{10, protopack.Fixed64Type}, protopack.Float64(math.Float64frombits(0x7ff91b771051d6a8)), }`, @@ -302,6 +356,7 @@ func TestPack(t *testing.T) { Tag{28856, BytesType}, Raw("\xbb"), }, + msgDesc: msgDesc, }, { raw: dhex("29baa4ac1c1e0a20183393bac434b8d3559337ec940050038770eaa9937f98e4"), msg: Message{ @@ -318,6 +373,7 @@ func TestPack(t *testing.T) { Raw("꩓\u007f\x98\xe4"), }, }, + msgDesc: msgDesc, }} equateFloatBits := cmp.Options{ @@ -332,13 +388,13 @@ func TestPack(t *testing.T) { t.Run("", func(t *testing.T) { var msg Message raw := tt.msg.Marshal() - msg.UnmarshalDescriptor(tt.raw, msgDesc) + msg.unmarshal(tt.raw, tt.msgDesc, tt.inferMsg) if !bytes.Equal(raw, tt.raw) { t.Errorf("Marshal() mismatch:\ngot %x\nwant %x", raw, tt.raw) } - if !cmp.Equal(msg, tt.msg, equateFloatBits) { - t.Errorf("Unmarshal() mismatch:\ngot %+v\nwant %+v", msg, tt.msg) + if diff := cmp.Diff(tt.msg, msg, equateFloatBits); diff != "" { + t.Errorf("Unmarshal() mismatch (-want +got):\n%s", diff) } if got, want := tt.msg.Size(), len(tt.raw); got != want { t.Errorf("Size() = %v, want %v", got, want)