diff --git a/format.cc b/format.cc index 88801fc0..08064ac1 100644 --- a/format.cc +++ b/format.cc @@ -354,6 +354,11 @@ class ArgConverter : public fmt::internal::ArgVisitor, void> { ArgConverter(fmt::internal::Arg &arg, wchar_t type) : arg_(arg), type_(type) {} + void visit_bool(bool value) { + if (type_ != 's') + visit_any_int(value); + } + template void visit_any_int(U value) { bool is_signed = type_ == 'd' || type_ == 'i'; @@ -418,7 +423,13 @@ class BasicArgFormatter : public ArgVisitor { protected: BasicWriter &writer() { return writer_; } - const FormatSpec &spec() const { return spec_; } + FormatSpec &spec() { return spec_; } + + void write_bool(bool value) { + const char *str_value = value ? "true" : "false"; + Arg::StringValue str = { str_value, strlen(str_value) }; + writer_.write_str(str, spec_); + } public: BasicArgFormatter(BasicWriter &w, FormatSpec &s) @@ -431,13 +442,9 @@ class BasicArgFormatter : public ArgVisitor { void visit_any_double(T value) { writer_.write_double(value, spec_); } void visit_bool(bool value) { - if (spec_.type_) { - writer_.write_int(value, spec_); - return; - } - const char *str_value = value ? "true" : "false"; - Arg::StringValue str = { str_value, strlen(str_value) }; - writer_.write_str(str, spec_); + if (spec_.type_) + return visit_any_int(value); + write_bool(value); } void visit_char(int value) { @@ -470,10 +477,8 @@ class BasicArgFormatter : public ArgVisitor { } void visit_cstring(const char *value) { - if (spec_.type_ == 'p') { - write_pointer(value); - return; - } + if (spec_.type_ == 'p') + return write_pointer(value); Arg::StringValue str = {value, 0}; writer_.write_str(str, spec_); } @@ -524,6 +529,14 @@ class PrintfArgFormatter : PrintfArgFormatter(BasicWriter &w, FormatSpec &s) : BasicArgFormatter, Char>(w, s) {} + void visit_bool(bool value) { + FormatSpec &fmt_spec = this->spec(); + if (fmt_spec.type_ != 's') + return this->visit_any_int(value); + fmt_spec.type_ = 0; + this->write_bool(value); + } + void visit_char(int value) { const FormatSpec &fmt_spec = this->spec(); BasicWriter &w = this->writer(); diff --git a/test/printf-test.cc b/test/printf-test.cc index c07384be..c985d026 100644 --- a/test/printf-test.cc +++ b/test/printf-test.cc @@ -373,6 +373,11 @@ TEST(PrintfTest, Length) { EXPECT_PRINTF(fmt::format("{}", max), "%Lg", max); } +TEST(PrintfTest, Bool) { + EXPECT_PRINTF("1", "%d", true); + EXPECT_PRINTF("true", "%s", true); +} + TEST(PrintfTest, Int) { EXPECT_PRINTF("-42", "%d", -42); EXPECT_PRINTF("-42", "%i", -42);