# # Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import math class Bitstream: def __init__(self, data): self.bytes = data self.bp_bw = len(data) - 1 self.mask_bw = 1 self.bp = 0 self.low = 0 self.range = 0xffffff def dump(self): b = self.bytes for i in range(0, len(b), 20): print(''.join('{:02x} '.format(x) for x in b[i:min(i+20, len(b))] )) class BitstreamReader(Bitstream): def __init__(self, data): super().__init__(data) self.low = ( (self.bytes[0] << 16) | (self.bytes[1] << 8) | (self.bytes[2] ) ) self.bp = 3 def read_bit(self): bit = bool(self.bytes[self.bp_bw] & self.mask_bw) self.mask_bw <<= 1 if self.mask_bw == 0x100: self.mask_bw = 1 self.bp_bw -= 1 return bit def read_uint(self, nbits): val = 0 for k in range(nbits): val |= self.read_bit() << k return val def ac_decode(self, cum_freqs, sym_freqs): r = self.range >> 10 if self.low >= r << 10: raise ValueError('Invalid ac bitstream') val = len(cum_freqs) - 1 while self.low < r * cum_freqs[val]: val -= 1 self.low -= r * cum_freqs[val] self.range = r * sym_freqs[val] while self.range < 0x10000: self.range <<= 8 self.low <<= 8 self.low &= 0xffffff self.low += self.bytes[self.bp] self.bp += 1 return val def get_bits_left(self): nbits = 8 * len(self.bytes) nbits_bw = nbits - \ (8*self.bp_bw + 8 - int(math.log2(self.mask_bw))) nbits_ac = 8 * (self.bp - 3) + \ (25 - int(math.floor(math.log2(self.range)))) return nbits - (nbits_bw + nbits_ac) class BitstreamWriter(Bitstream): def __init__(self, nbytes): super().__init__(bytearray(nbytes)) self.cache = -1 self.carry = 0 self.carry_count = 0 def write_bit(self, bit): mask = self.mask_bw bp = self.bp_bw if bit == 0: self.bytes[bp] &= ~mask else: self.bytes[bp] |= mask self.mask_bw <<= 1 if self.mask_bw == 0x100: self.mask_bw = 1 self.bp_bw -= 1 def write_uint(self, val, nbits): for k in range(nbits): self.write_bit(val & 1) val >>= 1 def ac_shift(self): if self.low < 0xff0000 or self.carry == 1: if self.cache >= 0: self.bytes[self.bp] = self.cache + self.carry self.bp += 1 while self.carry_count > 0: self.bytes[self.bp] = (self.carry + 0xff) & 0xff self.bp += 1 self.carry_count -= 1 self.cache = self.low >> 16 self.carry = 0 else: self.carry_count += 1 self.low <<= 8 self.low &= 0xffffff def ac_encode(self, cum_freq, sym_freq): r = self.range >> 10 self.low += r * cum_freq if (self.low >> 24) != 0: self.carry = 1 self.low &= 0xffffff self.range = r * sym_freq while self.range < 0x10000: self.range <<= 8; self.ac_shift() def get_bits_left(self): nbits = 8 * len(self.bytes) nbits_bw = nbits - \ (8*self.bp_bw + 8 - int(math.log2(self.mask_bw))) nbits_ac = 8 * self.bp + (25 - int(math.floor(math.log2(self.range)))) if self.cache >= 0: nbits_ac += 8 if self.carry_count > 0: nbits_ac += 8 * self.carry_count return nbits - (nbits_bw + nbits_ac) def terminate(self): bits = 1 while self.range >> (24 - bits) == 0: bits += 1 mask = 0xffffff >> bits; val = self.low + mask; over1 = val >> 24 val &= 0x00ffffff high = self.low + self.range over2 = high >> 24 high &= 0x00ffffff val = val & ~mask if over1 == over2: if val + mask >= high: bits += 1 mask >>= 1 val = ((self.low + mask) & 0x00ffffff) & ~mask if val < self.low: self.carry = 1 self.low = val while bits > 0: self.ac_shift() bits -= 8 bits += 8; val = self.cache if self.carry_count > 0: self.bytes[self.bp] = self.cache self.bp += 1 while self.carry_count > 1: self.bytes[self.bp] = 0xff self.bp += 1 self.carry_count -= 1 val = 0xff >> (8 - bits) mask = 0x80; for k in range(bits): if val & mask == 0: self.bytes[self.bp] &= ~mask else: self.bytes[self.bp] |= mask mask >>= 1 return self.bytes