241 lines
5.5 KiB
Python
Raw Normal View History

2022-06-05 18:06:03 +02:00
#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
class Bitstream:
    """State shared by the bitstream reader and writer.

    One byte buffer is addressed from both ends. Plain bits use a backward
    cursor (`bp_bw`, `mask_bw`) that starts at the least significant bit of
    the last byte. The arithmetic coder uses a forward byte index (`bp`)
    that starts at the first byte, together with `low` and `range`.
    """

    def __init__(self, data):
        """Attach to `data` and reset both cursors.

        Args:
            data: Indexable byte buffer; it is stored as is, not copied.
        """
        self.bytes = data

        # Backward cursor: index of the current byte, and mask of the next
        # bit inside that byte.
        self.bp_bw = len(data) - 1
        self.mask_bw = 1

        # Arithmetic coder state.
        self.bp = 0
        self.low = 0
        self.range = 0xffffff

    def dump(self, width=20):
        """Print the buffer as two-digit hex values.

        Every value, including the last one of a line, is followed by a
        single space.

        Args:
            width: Number of bytes printed per line.
        """
        buf = self.bytes

        for start in range(0, len(buf), width):
            # Slicing clamps at the end of the buffer on its own.
            row = buf[start:start + width]
            print(''.join('{:02x} '.format(x) for x in row))
class BitstreamReader(Bitstream):
    """Reads plain bits backward and arithmetic-coded symbols forward."""

    def __init__(self, data):
        """Attach to `data` and preload the arithmetic decoder.

        The first three bytes of `data` are loaded, most significant byte
        first, into `low`; the forward index then points at the fourth byte.

        Args:
            data: Indexable byte buffer holding at least three bytes.

        Raises:
            IndexError: if `data` holds fewer than three bytes.
        """
        super().__init__(data)

        self.low = ((self.bytes[0] << 16) |
                    (self.bytes[1] << 8) |
                    self.bytes[2])
        self.bp = 3

    def read_bit(self):
        """Read the next plain bit and step the backward cursor.

        Returns:
            The bit, as a bool.

        Raises:
            IndexError: if every bit of the buffer has already been read.
        """
        if self.bp_bw < 0:
            # Without this check the negative index would wrap around and
            # silently return bits of the last byte again.
            raise IndexError('Read beyond the start of the bitstream')

        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)

        # Move to the next more significant bit; once the byte is consumed,
        # continue with the preceding byte.
        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

        return bit

    def read_uint(self, nbits):
        """Read `nbits` plain bits, least significant bit first.

        Returns:
            The unsigned integer assembled from the bits read.
        """
        val = 0
        for k in range(nbits):
            val |= int(self.read_bit()) << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one arithmetic-coded symbol.

        Args:
            cum_freqs: Cumulative frequencies, indexed by symbol.
                NOTE(review): the downward scan below presumably relies on
                `cum_freqs[0]` being 0 to stop -- confirm against the tables.
            sym_freqs: Symbol frequencies, indexed like `cum_freqs`.

        Returns:
            The index of the decoded symbol.

        Raises:
            ValueError: if `low` lies outside the current coding interval.
            IndexError: if renormalization needs bytes past the end of the
                buffer.
        """
        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Scan downward from the last symbol to the first one whose scaled
        # cumulative frequency does not exceed `low`.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            val -= 1

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]

        # Renormalize: shift in one byte at a time until `range` reaches
        # 0x10000 again, keeping `low` within 24 bits.
        while self.range < 0x10000:
            self.range <<= 8

            self.low <<= 8
            self.low &= 0xffffff
            self.low += self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of bits not yet consumed from either end."""
        nbits = 8 * len(self.bytes)

        # Plain bits consumed from the end of the buffer.
        nbits_bw = nbits - \
            (8*self.bp_bw + 8 - int(math.log2(self.mask_bw)))

        # Bits consumed by the arithmetic decoder: whole bytes read after
        # the three preloaded ones, plus a term derived from `range`.
        nbits_ac = 8 * (self.bp - 3) + \
            (25 - int(math.floor(math.log2(self.range))))

        return nbits - (nbits_bw + nbits_ac)
class BitstreamWriter(Bitstream):
    """Writes plain bits backward and arithmetic-coded symbols forward."""

    def __init__(self, nbytes):
        """Allocate a zero-filled buffer of `nbytes` bytes.

        Args:
            nbytes: Size of the output buffer, in bytes.
        """
        super().__init__(bytearray(nbytes))

        # Arithmetic coder output stage: `cache` is the last byte produced
        # but not yet stored (-1 while there is none), `carry` is a pending
        # carry, and `carry_count` counts deferred bytes (see `ac_shift`).
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Write one plain bit and step the backward cursor.

        Args:
            bit: 0 clears the bit, any other value sets it.

        Raises:
            IndexError: if every bit of the buffer has already been written.
        """
        if self.bp_bw < 0:
            # Without this check the negative index would wrap around and
            # silently overwrite bits of the last byte.
            raise IndexError('Write beyond the start of the bitstream')

        mask = self.mask_bw
        bp = self.bp_bw

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        # Move to the next more significant bit; once the byte is full,
        # continue with the preceding byte.
        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

    def write_uint(self, val, nbits):
        """Write the `nbits` low bits of `val`, least significant first."""
        for _ in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Shift the top byte out of `low`, storing output when possible."""
        if self.low < 0xff0000 or self.carry == 1:

            # Store the cached byte, with the pending carry added in.
            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Store the deferred bytes: 0xff without carry, 0x00 with it.
            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            # Top byte of `low` is 0xff and no carry is pending: only count
            # the byte for now, it is stored by a later call.
            self.carry_count += 1

        self.low <<= 8
        self.low &= 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol with the arithmetic coder.

        Args:
            cum_freq: Cumulative frequency of the symbol.
            sym_freq: Frequency of the symbol.
        """
        r = self.range >> 10

        self.low += r * cum_freq
        if (self.low >> 24) != 0:
            # `low` overflowed 24 bits: remember the carry for `ac_shift`.
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq

        # Renormalize: shift bytes out until `range` reaches 0x10000 again.
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of bits not yet used from either end."""
        nbits = 8 * len(self.bytes)

        # Plain bits written from the end of the buffer.
        nbits_bw = nbits - \
            (8*self.bp_bw + 8 - int(math.log2(self.mask_bw)))

        # Bits used by the arithmetic coder: bytes already stored, a term
        # derived from `range`, plus the cached and the deferred bytes.
        nbits_ac = 8 * self.bp + (25 - int(math.floor(math.log2(self.range))))
        if self.cache >= 0:
            nbits_ac += 8
        if self.carry_count > 0:
            nbits_ac += 8 * self.carry_count

        return nbits - (nbits_bw + nbits_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the buffer.

        Returns:
            The underlying bytearray (the buffer itself, not a copy).
        """
        # Smallest `bits` for which `range >> (24 - bits)` is non-zero.
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        mask = 0xffffff >> bits
        val = self.low + mask

        # Note whether the candidate value and the upper bound overflow
        # 24 bits, then reduce both to 24 bits.
        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            if val + mask >= high:
                # Retry with one more bit and a mask half as wide.
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            if val < self.low:
                self.carry = 1

        self.low = val

        # Shift out whole bytes; afterwards `bits` is the number of bits
        # (1 to 8) that go into the final byte.
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            # Store the cached byte and all deferred bytes but the last; the
            # final bits are then all ones.
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Copy bit positions 0x80 downward of `val` into the current byte,
        # one position per remaining bit; lower positions are untouched.
        mask = 0x80
        for _ in range(bits):

            if val & mask == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask

            mask >>= 1

        return self.bytes