// Copyright 2019 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package protocmp import ( "bytes" "fmt" "reflect" "sort" "strconv" "strings" "google.golang.org/protobuf/internal/detrand" "google.golang.org/protobuf/internal/encoding/wire" "google.golang.org/protobuf/reflect/protoreflect" ) // This implements a custom text marshaler similar to the prototext format. // We don't use the prototext marshaler so that we can: // • have finer grain control over the ordering of fields // • marshal maps with a more aesthetically pleasant output // // TODO: If the prototext format gains a map-specific syntax, consider just // using the prototext marshaler instead. func appendValue(b []byte, v interface{}) []byte { switch v := v.(type) { case bool, int32, int64, uint32, uint64, float32, float64: return append(b, fmt.Sprint(v)...) case string: return append(b, strconv.Quote(string(v))...) case []byte: return append(b, strconv.Quote(string(v))...) case Enum: return append(b, v.String()...) case Message: return appendMessage(b, v) case protoreflect.RawFields: return appendValue(b, transformRawFields(v)) default: switch v := reflect.ValueOf(v); v.Kind() { case reflect.Slice: return appendList(b, v) case reflect.Map: return appendMap(b, v) default: panic(fmt.Sprintf("invalid type: %v", v.Type())) } } } func appendMessage(b []byte, m Message) []byte { var knownKeys, extensionKeys, unknownKeys []string for k := range m { switch { case protoreflect.Name(k).IsValid(): knownKeys = append(knownKeys, k) case strings.HasPrefix(k, "[") && strings.HasSuffix(k, "]"): extensionKeys = append(extensionKeys, k) case len(strings.Trim(k, "0123456789")) == 0: unknownKeys = append(unknownKeys, k) } } sort.Slice(knownKeys, func(i, j int) bool { fdi := m.Descriptor().Fields().ByName(protoreflect.Name(knownKeys[i])) fdj := m.Descriptor().Fields().ByName(protoreflect.Name(knownKeys[j])) return fdi.Index() < fdj.Index() }) sort.Slice(extensionKeys, func(i, j int) bool { return extensionKeys[i] < extensionKeys[j] }) sort.Slice(unknownKeys, func(i, j int) bool { ni, _ := strconv.Atoi(unknownKeys[i]) nj, _ := strconv.Atoi(unknownKeys[j]) return ni < nj }) ks := append(append(append([]string(nil), knownKeys...), extensionKeys...), unknownKeys...) b = append(b, '{') for _, k := range ks { b = append(b, k...) b = append(b, ':') b = appendValue(b, m[k]) b = append(b, delim()...) } b = bytes.TrimRight(b, delim()) b = append(b, '}') return b } func appendList(b []byte, v reflect.Value) []byte { b = append(b, '[') for i := 0; i < v.Len(); i++ { b = appendValue(b, v.Index(i).Interface()) b = append(b, delim()...) } b = bytes.TrimRight(b, delim()) b = append(b, ']') return b } func appendMap(b []byte, v reflect.Value) []byte { ks := v.MapKeys() sort.Slice(ks, func(i, j int) bool { ki, kj := ks[i], ks[j] switch ki.Kind() { case reflect.Bool: return !ki.Bool() && kj.Bool() case reflect.Int32, reflect.Int64: return ki.Int() < kj.Int() case reflect.Uint32, reflect.Uint64: return ki.Uint() < kj.Uint() case reflect.String: return ki.String() < kj.String() default: panic(fmt.Sprintf("invalid kind: %v", ki.Kind())) } }) b = append(b, '{') for _, k := range ks { b = appendValue(b, k.Interface()) b = append(b, ':') b = appendValue(b, v.MapIndex(k).Interface()) b = append(b, delim()...) } b = bytes.TrimRight(b, delim()) b = append(b, '}') return b } func transformRawFields(b protoreflect.RawFields) interface{} { var vs []interface{} for len(b) > 0 { num, typ, n := wire.ConsumeTag(b) m := wire.ConsumeFieldValue(num, typ, b[n:]) vs = append(vs, transformRawField(typ, b[n:][:m])) b = b[n+m:] } if len(vs) == 1 { return vs[0] } return vs } func transformRawField(typ wire.Type, b protoreflect.RawFields) interface{} { switch typ { case wire.VarintType: v, _ := wire.ConsumeVarint(b) return v case wire.Fixed32Type: v, _ := wire.ConsumeFixed32(b) return v case wire.Fixed64Type: v, _ := wire.ConsumeFixed64(b) return v case wire.BytesType: v, _ := wire.ConsumeBytes(b) return v case wire.StartGroupType: v := Message{} for { num2, typ2, n := wire.ConsumeTag(b) if typ2 == wire.EndGroupType { return v } m := wire.ConsumeFieldValue(num2, typ2, b[n:]) s := strconv.Itoa(int(num2)) b2, _ := v[s].(protoreflect.RawFields) v[s] = append(b2, b[:n+m]...) b = b[n+m:] } default: panic(fmt.Sprintf("invalid type: %v", typ)) } } func delim() string { // Deliberately introduce instability into the message string to // discourage users from depending on it. if detrand.Bool() { return " " } return ", " }