mirror of
https://github.com/bluekitchen/btstack.git
synced 2025-02-15 03:40:42 +00:00
241 lines
5.5 KiB
Python
241 lines
5.5 KiB
Python
|
#
|
||
|
# Copyright 2022 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
#
|
||
|
|
||
|
import math
|
||
|
|
||
|
class Bitstream:
    """State shared by the bitstream reader and writer.

    One byte buffer carries two independent cursors:

    * ``bp_bw`` / ``mask_bw`` -- byte index and bit mask of the backward
      cursor, which starts on the least significant bit of the last byte;
    * ``bp`` -- byte index of the forward cursor, which starts at 0 and
      goes together with the ``low`` / ``range`` pair (initially 0 and
      0xffffff).
    """

    def __init__(self, data):
        """Attach to ``data`` and put both cursors in their start position.

        ``data`` is stored as-is (not copied); it must support ``len()``.
        """
        self.bytes = data

        # Backward cursor: last byte, lowest bit.
        last_index = len(data) - 1
        self.bp_bw = last_index
        self.mask_bw = 1

        # Forward cursor and its 24-bit interval.
        self.bp = 0
        self.low = 0
        self.range = 0xffffff

    def dump(self):
        """Print the buffer in hexadecimal, 20 bytes per row.

        Each byte is printed as two lowercase hex digits followed by a
        space, so every row ends with a trailing space.
        """
        row_width = 20
        buf = self.bytes
        size = len(buf)

        for start in range(0, size, row_width):
            stop = min(start + row_width, size)
            row = ''.join('{:02x} '.format(octet) for octet in buf[start:stop])
            print(row)
|
||
|
|
||
|
class BitstreamReader(Bitstream):
    """Read side of the bitstream.

    Plain bits are taken from the end of the buffer moving towards the
    front; arithmetic-coded symbols are taken from the front of the buffer.
    """

    def __init__(self, data):
        """Attach to ``data`` and prime the arithmetic decoder.

        The first three bytes are loaded, most significant first, into
        ``low`` and the forward cursor is placed right after them, so
        ``data`` needs at least three bytes to be indexable here.
        """
        super().__init__(data)

        buf = self.bytes
        self.low = (buf[0] << 16) | (buf[1] << 8) | buf[2]
        self.bp = 3

    def read_bit(self):
        """Return the bit under the backward cursor as a ``bool``.

        The cursor then moves to the next more significant bit, or to the
        lowest bit of the preceding byte once bit 7 has been read.
        """
        bit = (self.bytes[self.bp_bw] & self.mask_bw) != 0

        if self.mask_bw == 0x80:
            # Byte exhausted: step back one byte, restart at bit 0.
            self.mask_bw = 1
            self.bp_bw -= 1
        else:
            self.mask_bw <<= 1

        return bit

    def read_uint(self, nbits):
        """Read ``nbits`` bits and return them as an unsigned integer.

        The first bit read becomes the least significant bit of the result.
        """
        value = 0
        for shift in range(nbits):
            if self.read_bit():
                value |= 1 << shift

        return value

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one arithmetic-coded symbol and return its index.

        ``cum_freqs[i]`` and ``sym_freqs[i]`` are the cumulative and
        individual frequency of symbol ``i``; the range is scaled down by
        ``2**10`` before being multiplied by them.

        Raises ``ValueError`` when ``low`` lies outside the scaled range.
        """
        scale = self.range >> 10
        if self.low >= scale << 10:
            raise ValueError('Invalid ac bitstream')

        # Walk down from the last symbol to the first one whose cumulative
        # bound does not exceed the current code value.
        symbol = len(cum_freqs) - 1
        while self.low < scale * cum_freqs[symbol]:
            symbol -= 1

        self.low -= scale * cum_freqs[symbol]
        self.range = scale * sym_freqs[symbol]

        # Renormalise: pull in one byte per step until the range is back
        # to at least 16 significant bits.
        while self.range < 0x10000:
            self.range <<= 8

            self.low = (self.low << 8) & 0xffffff
            self.low += self.bytes[self.bp]
            self.bp += 1

        return symbol

    def get_bits_left(self):
        """Return the number of bits not yet consumed from either end.

        The backward share is the count of plain bits already read; the
        forward share is 8 bits per byte read past the three priming bytes
        plus ``25 - floor(log2(range))``.
        """
        total_bits = 8 * len(self.bytes)

        remaining_bw = 8 * self.bp_bw + 8 - int(math.log2(self.mask_bw))
        consumed_bw = total_bits - remaining_bw

        range_bits = int(math.floor(math.log2(self.range)))
        consumed_ac = 8 * (self.bp - 3) + (25 - range_bits)

        return total_bits - (consumed_bw + consumed_ac)
|
||
|
|
||
|
class BitstreamWriter(Bitstream):
    """Write side of the bitstream.

    Plain bits are stored from the end of the buffer moving towards the
    front; arithmetic-coded bytes are stored from the front of the buffer.
    """

    def __init__(self, nbytes):
        """Create a zero-filled buffer of ``nbytes`` bytes.

        ``cache`` holds the last byte produced by the arithmetic coder but
        not yet stored (-1 while there is none), ``carry`` is the pending
        carry flag and ``carry_count`` counts shifts that were deferred
        instead of being stored.
        """
        super().__init__(bytearray(nbytes))

        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Store one bit under the backward cursor and advance the cursor.

        A ``bit`` equal to 0 clears the target bit, any other value sets it.
        """
        index = self.bp_bw
        bit_mask = self.mask_bw

        if bit == 0:
            self.bytes[index] &= ~bit_mask
        else:
            self.bytes[index] |= bit_mask

        if self.mask_bw == 0x80:
            # Byte complete: step back one byte, restart at bit 0.
            self.mask_bw = 1
            self.bp_bw -= 1
        else:
            self.mask_bw <<= 1

    def write_uint(self, val, nbits):
        """Write the ``nbits`` low bits of ``val``, least significant first."""
        for shift in range(nbits):
            self.write_bit((val >> shift) & 1)

    def ac_shift(self):
        """Shift the top byte out of ``low``.

        When the top byte is 0xff and no carry is pending, the byte is not
        stored yet and only ``carry_count`` is increased. Otherwise the
        cached byte (plus carry) and every deferred byte are stored, and the
        top byte of ``low`` becomes the new cached byte.
        """
        if self.low >= 0xff0000 and self.carry != 1:
            self.carry_count += 1

        else:
            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Deferred bytes are 0xff without carry, 0x00 with carry.
            deferred = (self.carry + 0xff) & 0xff
            while self.carry_count > 0:
                self.bytes[self.bp] = deferred
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        self.low = (self.low << 8) & 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol given its cumulative and individual frequency.

        The range is scaled down by ``2**10`` before being multiplied by
        the frequencies; an overflow of ``low`` past 24 bits is recorded in
        ``carry``.
        """
        scale = self.range >> 10

        total = self.low + scale * cum_freq
        if total >> 24:
            self.carry = 1
        self.low = total & 0xffffff

        self.range = scale * sym_freq

        # Renormalise until the range has at least 16 significant bits.
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of bits not yet used from either end.

        The forward share counts 8 bits per stored byte, 8 more for a
        cached byte, 8 per deferred shift, plus
        ``25 - floor(log2(range))``.
        """
        total_bits = 8 * len(self.bytes)

        remaining_bw = 8 * self.bp_bw + 8 - int(math.log2(self.mask_bw))
        used_bw = total_bits - remaining_bw

        range_bits = int(math.floor(math.log2(self.range)))
        used_ac = 8 * self.bp + (25 - range_bits)

        pending_bytes = 1 if self.cache >= 0 else 0
        if self.carry_count > 0:
            pending_bytes += self.carry_count
        used_ac += 8 * pending_bytes

        return total_bits - (used_bw + used_ac)

    def terminate(self):
        """Flush the arithmetic coder and return the buffer.

        The final code value is ``low`` rounded up to a boundary of ``bits``
        significant bits; one extra bit is used when the rounded value's
        span would reach ``low + range``. The value is then shifted out
        byte by byte and the last, possibly partial, byte is merged into
        the buffer from its most significant bit down.
        """
        # Count leading zero bits of the 24-bit range, plus one.
        bits = 1
        while self.range >> (24 - bits) == 0:
            bits += 1

        low = self.low
        mask = 0xffffff >> bits
        rounded = low + mask
        upper = low + self.range

        val = (rounded & 0x00ffffff) & ~mask

        # Only compare when both sums overflowed 24 bits alike.
        if (rounded >> 24) == (upper >> 24):

            if val + mask >= (upper & 0x00ffffff):
                bits += 1
                mask >>= 1
                val = ((low + mask) & 0x00ffffff) & ~mask

            if val < low:
                self.carry = 1

        self.low = val
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8  # Number of bits (1..8) belonging to the last byte.

        tail = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            tail = 0xff >> (8 - bits)

        # Merge the top `bits` bits of `tail` into the current byte,
        # leaving its remaining low bits as they are.
        for offset in range(bits):
            bit_mask = 0x80 >> offset

            if tail & bit_mask:
                self.bytes[self.bp] |= bit_mask
            else:
                self.bytes[self.bp] &= ~bit_mask

        return self.bytes
|