testing/protopack: add Message.UnmarshalAbductive

The protobuf wire format is insufficiently self-describing, such that
it is impossible to know for sure whether an unknown bytes value
is a sub-message or not. However, protopack is primarily used for debugging,
where a best-effort guess is still very useful.

The Message.UnmarshalAbductive method unmarshals an unknown bytes value as a
message if it is syntactically well-formed. Otherwise, it is left as is.

Change-Id: I5e2b4b995e2b5eb60942a242558bf4cea1da9891
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/309669
Trust: Joe Tsai <joetsai@digital-static.net>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Joe Tsai 2021-04-12 23:08:47 -07:00
parent fc9592f7ac
commit fb30439f55
3 changed files with 113 additions and 19 deletions

View File

@ -128,7 +128,7 @@ func main() {
// Parse and print message structure.
defer log.Printf("fatal input: %q", buf) // debug printout if panic occurs
var m protopack.Message
m.UnmarshalDescriptor(buf, desc)
m.UnmarshalAbductive(buf, desc)
if *printSource {
fmt.Printf("%#v\n", m)
} else {

View File

@ -270,7 +270,7 @@ func (m Message) Marshal() []byte {
//
// Unmarshal is useful for debugging the protobuf wire format.
func (m *Message) Unmarshal(in []byte) {
m.UnmarshalDescriptor(in, nil)
m.unmarshal(in, nil, false)
}
// UnmarshalDescriptor parses the input protobuf wire data as a syntax tree
@ -289,22 +289,40 @@ func (m *Message) Unmarshal(in []byte) {
// Known sub-messages are parsed as a Message and packed repeated fields are
// parsed as a LengthPrefix.
func (m *Message) UnmarshalDescriptor(in []byte, desc protoreflect.MessageDescriptor) {
m.unmarshal(in, desc, false)
}
// UnmarshalAbductive is like UnmarshalDescriptor, but infers abductively
// whether any unknown bytes value is a message based on whether it is
// a syntactically well-formed message.
//
// Note that the protobuf wire format is not fully self-describing,
// so abductive inference may attempt to expand a bytes value as a message
// that is not actually a message. It is a best-effort guess.
func (m *Message) UnmarshalAbductive(in []byte, desc protoreflect.MessageDescriptor) {
m.unmarshal(in, desc, true)
}
// unmarshal is the shared implementation behind Unmarshal,
// UnmarshalDescriptor, and UnmarshalAbductive.
//
// It parses in using desc (which may be nil) and appends the resulting tokens
// to those already present in m. The inferMessage flag is forwarded to the
// parser, where it enables the best-effort expansion of unknown bytes values
// as sub-messages.
func (m *Message) unmarshal(in []byte, desc protoreflect.MessageDescriptor, inferMessage bool) {
p := parser{in: in, out: *m}
p.parseMessage(desc, false)
p.parseMessage(desc, false, inferMessage)
*m = p.out
}
// parser holds the state used while decoding wire data into tokens.
type parser struct {
// in is the input that has not been consumed yet.
in []byte
// out accumulates the tokens produced so far.
out []Token
// invalid is set when the parser hits malformed input (for example an
// unparsable tag or value, or a group whose end number does not match its
// start number). Abductive inference checks it to decide whether a bytes
// value may be presented as a sub-message.
invalid bool
}
func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group bool) {
func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group, inferMessage bool) {
for len(p.in) > 0 {
v, n := protowire.ConsumeVarint(p.in)
num, typ := protowire.DecodeTag(v)
if n < 0 || num < 0 || v > math.MaxUint32 {
if n < 0 || num <= 0 || v > math.MaxUint32 {
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
return
}
if typ == EndGroupType && group {
@ -341,13 +359,14 @@ func (p *parser) parseMessage(msgDesc protoreflect.MessageDescriptor, group bool
case Fixed64Type:
p.parseFixed64(kind)
case BytesType:
p.parseBytes(isPacked, kind, subDesc)
p.parseBytes(isPacked, kind, subDesc, inferMessage)
case StartGroupType:
p.parseGroup(subDesc)
p.parseGroup(num, subDesc, inferMessage)
case EndGroupType:
// Handled above.
// Handled by p.parseGroup.
default:
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
}
}
}
@ -356,6 +375,7 @@ func (p *parser) parseVarint(kind protoreflect.Kind) {
v, n := protowire.ConsumeVarint(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
return
}
switch kind {
@ -384,6 +404,7 @@ func (p *parser) parseFixed32(kind protoreflect.Kind) {
v, n := protowire.ConsumeFixed32(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
return
}
switch kind {
@ -400,6 +421,7 @@ func (p *parser) parseFixed64(kind protoreflect.Kind) {
v, n := protowire.ConsumeFixed64(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
return
}
switch kind {
@ -412,10 +434,11 @@ func (p *parser) parseFixed64(kind protoreflect.Kind) {
}
}
func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor) {
func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoreflect.MessageDescriptor, inferMessage bool) {
v, n := protowire.ConsumeVarint(p.in)
if n < 0 {
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
return
}
p.out, p.in = append(p.out, Uvarint(v)), p.in[n:]
@ -424,6 +447,7 @@ func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoref
}
if v > uint64(len(p.in)) {
p.out, p.in = append(p.out, Raw(p.in)), nil
p.invalid = true
return
}
p.out = p.out[:len(p.out)-1] // subsequent tokens contain prefix-length
@ -434,11 +458,22 @@ func (p *parser) parseBytes(isPacked bool, kind protoreflect.Kind, desc protoref
switch kind {
case protoreflect.MessageKind:
p2 := parser{in: p.in[:v]}
p2.parseMessage(desc, false)
p2.parseMessage(desc, false, inferMessage)
p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
case protoreflect.StringKind:
p.out, p.in = append(p.out, String(p.in[:v])), p.in[v:]
case protoreflect.BytesKind:
p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
default:
if inferMessage {
// Check whether this is a syntactically valid message.
p2 := parser{in: p.in[:v]}
p2.parseMessage(nil, false, inferMessage)
if !p2.invalid {
p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[v:]
break
}
}
p.out, p.in = append(p.out, Bytes(p.in[:v])), p.in[v:]
}
}
@ -466,9 +501,9 @@ func (p *parser) parsePacked(n int, kind protoreflect.Kind) {
p.out, p.in = append(p.out, LengthPrefix(p2.out)), p.in[n:]
}
func (p *parser) parseGroup(desc protoreflect.MessageDescriptor) {
func (p *parser) parseGroup(startNum protowire.Number, desc protoreflect.MessageDescriptor, inferMessage bool) {
p2 := parser{in: p.in}
p2.parseMessage(desc, true)
p2.parseMessage(desc, true, inferMessage)
if len(p2.out) > 0 {
p.out = append(p.out, Message(p2.out))
}
@ -476,8 +511,11 @@ func (p *parser) parseGroup(desc protoreflect.MessageDescriptor) {
// Append the trailing end group.
v, n := protowire.ConsumeVarint(p.in)
if num, typ := protowire.DecodeTag(v); typ == EndGroupType {
p.out, p.in = append(p.out, Tag{num, typ}), p.in[n:]
if endNum, typ := protowire.DecodeTag(v); typ == EndGroupType {
if startNum != endNum {
p.invalid = true
}
p.out, p.in = append(p.out, Tag{endNum, typ}), p.in[n:]
if m := n - protowire.SizeVarint(v); m > 0 {
p.out[len(p.out)-1] = Denormalized{uint(m), p.out[len(p.out)-1]}
}

View File

@ -15,6 +15,7 @@ import (
"google.golang.org/protobuf/encoding/prototext"
pdesc "google.golang.org/protobuf/reflect/protodesc"
"google.golang.org/protobuf/reflect/protoreflect"
pref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"
@ -69,6 +70,8 @@ func TestPack(t *testing.T) {
tests := []struct {
raw []byte
msg Message
msgDesc protoreflect.MessageDescriptor
inferMsg bool
wantOutCompact string
wantOutMulti string
@ -81,12 +84,22 @@ func TestPack(t *testing.T) {
Tag{1, VarintType}, Denormalized{5, Uvarint(2)},
Tag{1, BytesType}, LengthPrefix{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}},
},
msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{1, protopack.VarintType}, protopack.Bool(false),
protopack.Denormalized{+5, protopack.Tag{1, protopack.VarintType}}, protopack.Uvarint(2),
protopack.Tag{1, protopack.VarintType}, protopack.Denormalized{+5, protopack.Uvarint(2)},
protopack.Tag{1, protopack.BytesType}, protopack.LengthPrefix{protopack.Bool(true), protopack.Bool(false), protopack.Uvarint(2), protopack.Denormalized{+5, protopack.Uvarint(2)}},
}`,
}, {
raw: dhex("080088808080800002088280808080000a09010002828080808000"),
msg: Message{
Tag{1, VarintType}, Uvarint(0),
Denormalized{5, Tag{1, VarintType}}, Uvarint(2),
Tag{1, VarintType}, Denormalized{5, Uvarint(2)},
Tag{1, BytesType}, Bytes(Message{Bool(true), Bool(false), Uvarint(2), Denormalized{5, Uvarint(2)}}.Marshal()),
},
inferMsg: true,
}, {
raw: dhex("100010828080808000121980808080808080808001ffffffffffffffff7f828080808000"),
msg: Message{
@ -94,6 +107,7 @@ func TestPack(t *testing.T) {
Tag{2, VarintType}, Denormalized{5, Varint(2)},
Tag{2, BytesType}, LengthPrefix{Varint(math.MinInt64), Varint(math.MaxInt64), Denormalized{5, Varint(2)}},
},
msgDesc: msgDesc,
wantOutCompact: `Message{Tag{2, Varint}, Varint(0), Tag{2, Varint}, Denormalized{+5, Varint(2)}, Tag{2, Bytes}, LengthPrefix{Varint(-9223372036854775808), Varint(9223372036854775807), Denormalized{+5, Varint(2)}}}`,
}, {
raw: dhex("1801188180808080001a1affffffffffffffffff01feffffffffffffffff01818080808000"),
@ -102,6 +116,7 @@ func TestPack(t *testing.T) {
Tag{3, VarintType}, Denormalized{5, Svarint(-1)},
Tag{3, BytesType}, LengthPrefix{Svarint(math.MinInt64), Svarint(math.MaxInt64), Denormalized{5, Svarint(-1)}},
},
msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{3, Varint}, Svarint(-1),
Tag{3, Varint}, Denormalized{+5, Svarint(-1)},
@ -114,6 +129,7 @@ func TestPack(t *testing.T) {
Tag{4, VarintType}, Denormalized{5, Uvarint(+1)},
Tag{4, BytesType}, LengthPrefix{Uvarint(0), Uvarint(math.MaxUint64), Denormalized{5, Uvarint(+1)}},
},
msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{4, protopack.VarintType}, protopack.Uvarint(1),
protopack.Tag{4, protopack.VarintType}, protopack.Denormalized{+5, protopack.Uvarint(1)},
@ -125,6 +141,7 @@ func TestPack(t *testing.T) {
Tag{5, Fixed32Type}, Uint32(+1),
Tag{5, BytesType}, LengthPrefix{Uint32(0), Uint32(math.MaxUint32)},
},
msgDesc: msgDesc,
wantOutCompact: `Message{Tag{5, Fixed32}, Uint32(1), Tag{5, Bytes}, LengthPrefix{Uint32(0), Uint32(4294967295)}}`,
}, {
raw: dhex("35ffffffff320800000080ffffff7f"),
@ -132,6 +149,7 @@ func TestPack(t *testing.T) {
Tag{6, Fixed32Type}, Int32(-1),
Tag{6, BytesType}, LengthPrefix{Int32(math.MinInt32), Int32(math.MaxInt32)},
},
msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{6, Fixed32}, Int32(-1),
Tag{6, Bytes}, LengthPrefix{Int32(-2147483648), Int32(2147483647)},
@ -142,6 +160,7 @@ func TestPack(t *testing.T) {
Tag{7, Fixed32Type}, Float32(math.Pi),
Tag{7, BytesType}, LengthPrefix{Float32(math.SmallestNonzeroFloat32), Float32(math.MaxFloat32), Float32(math.Inf(+1)), Float32(math.Inf(-1))},
},
msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(3.1415927),
protopack.Tag{7, protopack.BytesType}, protopack.LengthPrefix{protopack.Float32(1e-45), protopack.Float32(3.4028235e+38), protopack.Float32(math.Inf(+1)), protopack.Float32(math.Inf(-1))},
@ -152,6 +171,7 @@ func TestPack(t *testing.T) {
Tag{8, Fixed64Type}, Uint64(+1),
Tag{8, BytesType}, LengthPrefix{Uint64(0), Uint64(math.MaxUint64)},
},
msgDesc: msgDesc,
wantOutCompact: `Message{Tag{8, Fixed64}, Uint64(1), Tag{8, Bytes}, LengthPrefix{Uint64(0), Uint64(18446744073709551615)}}`,
}, {
raw: dhex("49ffffffffffffffff4a100000000000000080ffffffffffffff7f"),
@ -159,6 +179,7 @@ func TestPack(t *testing.T) {
Tag{9, Fixed64Type}, Int64(-1),
Tag{9, BytesType}, LengthPrefix{Int64(math.MinInt64), Int64(math.MaxInt64)},
},
msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{9, Fixed64}, Int64(-1),
Tag{9, Bytes}, LengthPrefix{Int64(-9223372036854775808), Int64(9223372036854775807)},
@ -169,6 +190,7 @@ func TestPack(t *testing.T) {
Tag{10, Fixed64Type}, Float64(math.Pi),
Tag{10, BytesType}, LengthPrefix{Float64(math.SmallestNonzeroFloat64), Float64(math.MaxFloat64), Float64(math.Inf(+1)), Float64(math.Inf(-1))},
},
msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{10, Fixed64}, Float64(3.141592653589793),
Tag{10, Bytes}, LengthPrefix{Float64(5e-324), Float64(1.7976931348623157e+308), Float64(+Inf), Float64(-Inf)},
@ -179,6 +201,7 @@ func TestPack(t *testing.T) {
Tag{11, BytesType}, String("string"),
Tag{11, BytesType}, Denormalized{+5, String("string")},
},
msgDesc: msgDesc,
wantOutCompact: `Message{Tag{11, Bytes}, String("string"), Tag{11, Bytes}, Denormalized{+5, String("string")}}`,
}, {
raw: dhex("62056279746573628580808080006279746573"),
@ -186,6 +209,7 @@ func TestPack(t *testing.T) {
Tag{12, BytesType}, Bytes("bytes"),
Tag{12, BytesType}, Denormalized{+5, Bytes("bytes")},
},
msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{12, Bytes}, Bytes("bytes"),
Tag{12, Bytes}, Denormalized{+5, Bytes("bytes")},
@ -201,6 +225,7 @@ func TestPack(t *testing.T) {
Tag{100, StartGroupType}, Tag{100, EndGroupType},
}),
},
msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{13, protopack.BytesType}, protopack.LengthPrefix(protopack.Message{
protopack.Tag{100, protopack.VarintType}, protopack.Uvarint(18446744073709551615),
@ -211,6 +236,30 @@ func TestPack(t *testing.T) {
protopack.Tag{100, protopack.EndGroupType},
}),
}`,
}, {
raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"),
msg: Message{
Tag{13, BytesType}, LengthPrefix(Message{
Tag{100, VarintType}, Uvarint(math.MaxUint64),
Tag{100, Fixed32Type}, Uint32(math.MaxUint32),
Tag{100, Fixed64Type}, Uint64(math.MaxUint64),
Tag{100, BytesType}, Bytes("bytes"),
Tag{100, StartGroupType}, Tag{100, EndGroupType},
}),
},
inferMsg: true,
}, {
raw: dhex("6a28a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306ac06"),
msg: Message{
Tag{13, BytesType}, Bytes(Message{
Tag{100, VarintType}, Uvarint(math.MaxUint64),
Tag{100, Fixed32Type}, Uint32(math.MaxUint32),
Tag{100, Fixed64Type}, Uint64(math.MaxUint64),
Tag{100, BytesType}, Bytes("bytes"),
Tag{100, StartGroupType}, Tag{101, EndGroupType},
}.Marshal()),
},
inferMsg: true,
}, {
raw: dhex("6aa88080808000a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a406"),
msg: Message{
@ -222,6 +271,7 @@ func TestPack(t *testing.T) {
Tag{100, StartGroupType}, Tag{100, EndGroupType},
})},
},
msgDesc: msgDesc,
wantOutCompact: `Message{Tag{13, Bytes}, Denormalized{+5, LengthPrefix(Message{Tag{100, Varint}, Uvarint(18446744073709551615), Tag{100, Fixed32}, Uint32(4294967295), Tag{100, Fixed64}, Uint64(18446744073709551615), Tag{100, Bytes}, Bytes("bytes"), Tag{100, StartGroup}, Tag{100, EndGroup}})}}`,
}, {
raw: dhex("73a006ffffffffffffffffff01a506ffffffffa106ffffffffffffffffa206056279746573a306a40674"),
@ -235,6 +285,7 @@ func TestPack(t *testing.T) {
},
Tag{14, EndGroupType},
},
msgDesc: msgDesc,
wantOutMulti: `Message{
Tag{14, StartGroup},
Message{
@ -261,6 +312,7 @@ func TestPack(t *testing.T) {
Tag{1706, Type(7)},
Raw("\x1an\x98\x11\xc8Z*\xb3"),
},
msgDesc: msgDesc,
}, {
raw: dhex("3d08d0e57f"),
msg: Message{
@ -269,6 +321,7 @@ func TestPack(t *testing.T) {
func() uint32 { return 0x7fe5d008 }(),
)),
},
msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{7, protopack.Fixed32Type}, protopack.Float32(math.Float32frombits(0x7fe5d008)),
}`,
@ -277,6 +330,7 @@ func TestPack(t *testing.T) {
msg: Message{
Tag{10, Fixed64Type}, Float64(math.Float64frombits(0x7ff91b771051d6a8)),
},
msgDesc: msgDesc,
wantOutSource: `protopack.Message{
protopack.Tag{10, protopack.Fixed64Type}, protopack.Float64(math.Float64frombits(0x7ff91b771051d6a8)),
}`,
@ -302,6 +356,7 @@ func TestPack(t *testing.T) {
Tag{28856, BytesType},
Raw("\xbb"),
},
msgDesc: msgDesc,
}, {
raw: dhex("29baa4ac1c1e0a20183393bac434b8d3559337ec940050038770eaa9937f98e4"),
msg: Message{
@ -318,6 +373,7 @@ func TestPack(t *testing.T) {
Raw("꩓\u007f\x98\xe4"),
},
},
msgDesc: msgDesc,
}}
equateFloatBits := cmp.Options{
@ -332,13 +388,13 @@ func TestPack(t *testing.T) {
t.Run("", func(t *testing.T) {
var msg Message
raw := tt.msg.Marshal()
msg.UnmarshalDescriptor(tt.raw, msgDesc)
msg.unmarshal(tt.raw, tt.msgDesc, tt.inferMsg)
if !bytes.Equal(raw, tt.raw) {
t.Errorf("Marshal() mismatch:\ngot %x\nwant %x", raw, tt.raw)
}
if !cmp.Equal(msg, tt.msg, equateFloatBits) {
t.Errorf("Unmarshal() mismatch:\ngot %+v\nwant %+v", msg, tt.msg)
if diff := cmp.Diff(tt.msg, msg, equateFloatBits); diff != "" {
t.Errorf("Unmarshal() mismatch (-want +got):\n%s", diff)
}
if got, want := tt.msg.Size(), len(tt.raw); got != want {
t.Errorf("Size() = %v, want %v", got, want)