Optimize remove_trailing_zeros

This commit is contained in:
Junekey Jeon 2022-02-08 18:32:20 -08:00 committed by Victor Zverovich
parent 7b4323e1e0
commit 10642e6082

View File

@ -1835,115 +1835,86 @@ bool is_left_endpoint_integer_shorter_interval(int exponent) noexcept {
// Remove trailing zeros from n and return the number of zeros removed (float)
FMT_INLINE int remove_trailing_zeros(uint32_t& n) noexcept {
#ifdef FMT_BUILTIN_CTZ
int t = FMT_BUILTIN_CTZ(n);
#else
int t = ctz(n);
#endif
if (t > float_info<float>::max_trailing_zeros)
t = float_info<float>::max_trailing_zeros;
const uint32_t mod_inv1 = 0xcccccccd;
const uint32_t max_quotient1 = 0x33333333;
const uint32_t mod_inv2 = 0xc28f5c29;
const uint32_t max_quotient2 = 0x0a3d70a3;
FMT_ASSERT(n != 0, "");
const uint32_t mod_inv_5 = 0xcccccccd;
const uint32_t mod_inv_25 = mod_inv_5 * mod_inv_5;
int s = 0;
for (; s < t - 1; s += 2) {
if (n * mod_inv2 > max_quotient2) break;
n *= mod_inv2;
while (true) {
auto q = rotr(n * mod_inv_25, 2);
if (q <= std::numeric_limits<uint32_t>::max() / 100) {
n = q;
s += 2;
} else {
break;
}
}
if (s < t && n * mod_inv1 <= max_quotient1) {
n *= mod_inv1;
++s;
auto q = rotr(n * mod_inv_5, 1);
if (q <= std::numeric_limits<uint32_t>::max() / 10) {
n = q;
s |= 1;
}
n >>= s;
return s;
}
// Removes trailing zeros and returns the number of zeros removed (double)
FMT_INLINE int remove_trailing_zeros(uint64_t& n) noexcept {
#ifdef FMT_BUILTIN_CTZLL
int t = FMT_BUILTIN_CTZLL(n);
#else
int t = ctzll(n);
#endif
if (t > float_info<double>::max_trailing_zeros)
t = float_info<double>::max_trailing_zeros;
// Divide by 10^8 and reduce to 32-bits
// Since ret_value.significand <= (2^64 - 1) / 1000 < 10^17,
// both of the quotient and the r should fit in 32-bits
FMT_ASSERT(n != 0, "");
const uint32_t mod_inv1 = 0xcccccccd;
const uint32_t max_quotient1 = 0x33333333;
const uint64_t mod_inv8 = 0xc767074b22e90e21;
const uint64_t max_quotient8 = 0x00002af31dc46118;
// This magic number is ceil(2^90 / 10^8).
constexpr auto magic_number = uint64_t(12379400392853802749ull);
auto nm = umul128(n, magic_number);
// If the number is divisible by 1'0000'0000, work with the quotient
if (t >= 8) {
auto quotient_candidate = n * mod_inv8;
// Is n is divisible by 10^8?
if ((nm.high() & ((1ull << (90 - 64)) - 1)) == 0 && nm.low() < magic_number) {
// If yes, work with the quotient.
auto n32 = static_cast<uint32_t>(nm.high() >> (90 - 64));
if (quotient_candidate <= max_quotient8) {
auto quotient = static_cast<uint32_t>(quotient_candidate >> 8);
const uint32_t mod_inv_5 = 0xcccccccd;
const uint32_t mod_inv_25 = mod_inv_5 * mod_inv_5;
int s = 8;
for (; s < t; ++s) {
if (quotient * mod_inv1 > max_quotient1) break;
quotient *= mod_inv1;
int s = 8;
while (true) {
auto q = rotr(n32 * mod_inv_25, 2);
if (q <= std::numeric_limits<uint32_t>::max() / 100) {
n32 = q;
s += 2;
} else {
break;
}
quotient >>= (s - 8);
n = quotient;
return s;
}
auto q = rotr(n32 * mod_inv_5, 1);
if (q <= std::numeric_limits<uint32_t>::max() / 10) {
n32 = q;
s |= 1;
}
n = n32;
return s;
}
// If n is not divisible by 10^8, work with n itself.
const uint64_t mod_inv_5 = 0xcccccccc'cccccccd;
const uint64_t mod_inv_25 = mod_inv_5 * mod_inv_5;
int s = 0;
while (true) {
auto q = rotr(n * mod_inv_25, 2);
if (q <= std::numeric_limits<uint64_t>::max() / 100) {
n = q;
s += 2;
} else {
break;
}
}
// Otherwise, work with the remainder
auto quotient = static_cast<uint32_t>(n / 100000000);
auto remainder = static_cast<uint32_t>(n - 100000000 * quotient);
if (t == 0 || remainder * mod_inv1 > max_quotient1) {
return 0;
auto q = rotr(n * mod_inv_5, 1);
if (q <= std::numeric_limits<uint64_t>::max() / 10) {
n = q;
s |= 1;
}
remainder *= mod_inv1;
if (t == 1 || remainder * mod_inv1 > max_quotient1) {
n = (remainder >> 1) + quotient * 10000000ull;
return 1;
}
remainder *= mod_inv1;
if (t == 2 || remainder * mod_inv1 > max_quotient1) {
n = (remainder >> 2) + quotient * 1000000ull;
return 2;
}
remainder *= mod_inv1;
if (t == 3 || remainder * mod_inv1 > max_quotient1) {
n = (remainder >> 3) + quotient * 100000ull;
return 3;
}
remainder *= mod_inv1;
if (t == 4 || remainder * mod_inv1 > max_quotient1) {
n = (remainder >> 4) + quotient * 10000ull;
return 4;
}
remainder *= mod_inv1;
if (t == 5 || remainder * mod_inv1 > max_quotient1) {
n = (remainder >> 5) + quotient * 1000ull;
return 5;
}
remainder *= mod_inv1;
if (t == 6 || remainder * mod_inv1 > max_quotient1) {
n = (remainder >> 6) + quotient * 100ull;
return 6;
}
remainder *= mod_inv1;
n = (remainder >> 7) + quotient * 10ull;
return 7;
return s;
}
// The main algorithm for shorter interval case