sbc: calculate cos in advance, prepare for DCT

This commit is contained in:
Milanka Ringwald 2016-05-20 15:30:04 +02:00
parent 1bd8157e8c
commit 205be8eae5
3 changed files with 118 additions and 38 deletions

View File

@ -4,10 +4,11 @@ import wave
import struct
import sys
from sbc import *
from sbc_synthesis_v1 import *
V = np.zeros(shape = (2, 10*2*8))
N = np.zeros(shape = (16,8))
total_time_ms = 0
implementation = "SIG"
def sbc_unpack_frame(fin, available_bytes, frame):
if available_bytes == 0:
@ -106,7 +107,7 @@ def sbc_reconstruct_subband_samples(frame):
def sbc_frame_synthesis_sig(frame, ch, blk, proto_table):
global V
global V, N
M = frame.nr_subbands
L = 10 * M
M2 = 2*M
@ -126,8 +127,7 @@ def sbc_frame_synthesis_sig(frame, ch, blk, proto_table):
for k in range(M2):
V[ch][k] = 0
for i in range(M):
N = np.cos((i+0.5)*(k+M/2)*np.pi/M)
V[ch][k] += N * S[i]
V[ch][k] += N[k][i] * S[i]
for i in range(5):
for j in range(M):
@ -146,9 +146,52 @@ def sbc_frame_synthesis_sig(frame, ch, blk, proto_table):
frame.pcm[ch][offset + j] = np.int16(frame.X[j])
def sbc_frame_synthesis(frame, ch, blk, proto_table):
global total_time_ms, implementation
def sbc_frame_synthesis_v1(frame, ch, blk, proto_table):
global V
N = matrix_N()
M = frame.nr_subbands
L = 10 * M
M2 = 2*M
L2 = 2*L
S = np.zeros(M)
U = np.zeros(L)
W = np.zeros(L)
frame.X = np.zeros(M)
for i in range(M):
S[i] = frame.sb_sample[blk][ch][i]
for i in range(L2-1, M2-1,-1):
V[ch][i] = V[ch][i-M2]
for k in range(M2):
V[ch][k] = 0
for i in range(M):
V[ch][k] += N[k][i] * S[i]
for i in range(5):
for j in range(M):
U[i*M2+j] = V[ch][i*2*M2+j]
U[(i*2+1)*M+j] = V[ch][(i*4+3)*M+j]
for i in range(L):
D = proto_table[i] * (-M)
W[i] = U[i]*D
offset = blk*M
for j in range(M):
for i in range(10):
frame.X[j] += W[j+M*i]
frame.pcm[ch][offset + j] = np.int16(frame.X[j])
def sbc_frame_synthesis(frame, ch, blk, proto_table, implementation = "SIG"):
global total_time_ms
t1 = time_ms()
if implementation == "SIG":
sbc_frame_synthesis_sig(frame, ch, blk, proto_table)
@ -160,9 +203,30 @@ def sbc_frame_synthesis(frame, ch, blk, proto_table):
t2 = time_ms()
total_time_ms += t2-t1
def sbc_synthesis(frame):
def sbc_init_synthesis_sig(M):
global N
M2 = M << 1
N = np.zeros(shape = (M2,M))
for k in range(M2):
for i in range(M):
N[k][i] = np.cos((i+0.5)*(k+M/2)*np.pi/M)
def sbc_init_sythesis(nr_subbands, implementation = "SIG"):
if implementation == "SIG":
sbc_init_synthesis_sig(nr_subbands)
elif implementation == "V1":
sbc_init_synthesis_v1(nr_subbands)
else:
print ("synthesis %s not implemented" % implementation)
exit(1)
def sbc_synthesis(frame, implementation = "SIG"):
if frame.nr_subbands == 4:
proto_table = Proto_4_40
elif frame.nr_subbands == 8:
@ -171,14 +235,14 @@ def sbc_synthesis(frame):
return -1
for ch in range(frame.nr_channels):
for blk in range(frame.nr_blocks):
sbc_frame_synthesis(frame, ch, blk, proto_table)
sbc_frame_synthesis(frame, ch, blk, proto_table, implementation)
return frame.nr_blocks * frame.nr_subbands
def sbc_decode(frame):
def sbc_decode(frame, implementation = "SIG"):
err = sbc_reconstruct_subband_samples(frame)
if err >= 0:
err = sbc_synthesis(frame)
err = sbc_synthesis(frame, implementation)
return err
@ -197,11 +261,12 @@ def write_wav_file(fout, frame):
value_str = ''.join(values)
fout.writeframes(value_str)
if __name__ == "__main__":
usage = '''
Usage: ./sbc_decoder.py input.sbc
Usage: ./sbc_decoder.py input.sbc implementation[default=SIG, V1]
'''
if (len(sys.argv) < 2):
@ -216,6 +281,15 @@ if __name__ == "__main__":
wavfile = infile.replace('.sbc', '-decoded.wav')
fout = False
implementation = "SIG"
if len(sys.argv) == 3:
implementation = sys.argv[2]
if implementation != "V1":
print ("synthesis %s not implemented" % implementation)
exit(1)
print ("\nSynthesis implementation: %s\n" % implementation)
with open (infile, 'rb') as fin:
try:
fin.seek(0, 2)
@ -224,34 +298,36 @@ if __name__ == "__main__":
frame_count = 0
while True:
sbc_decoder_frame = SBCFrame()
frame = SBCFrame()
if frame_count % 200 == 0:
print "== Frame %d == %d" % (frame_count, fin.tell())
err = sbc_unpack_frame(fin, file_size - fin.tell(), sbc_decoder_frame)
err = sbc_unpack_frame(fin, file_size - fin.tell(), frame)
if frame_count == 0:
print sbc_decoder_frame
sbc_init_sythesis(frame.nr_subbands, implementation)
print frame
if err:
print "error, frame_count: ", frame_count
break
sbc_decode(sbc_decoder_frame)
sbc_decode(frame, implementation)
if frame_count == 0:
fout = wave.open(wavfile, 'w')
fout.setnchannels(sbc_decoder_frame.nr_channels)
fout.setnchannels(frame.nr_channels)
fout.setsampwidth(2)
fout.setframerate(sampling_frequencies[sbc_decoder_frame.sampling_frequency])
fout.setframerate(sampling_frequencies[frame.sampling_frequency])
fout.setnframes(0)
fout.setcomptype = 'NONE'
write_wav_file(fout, sbc_decoder_frame)
write_wav_file(fout, frame)
frame_count += 1
if frame_count == 1:
break
# if frame_count == 1:
# break
except TypeError as err:
if not fout:
@ -260,7 +336,7 @@ if __name__ == "__main__":
fout.close()
if frame_count > 0:
print ("DONE, SBC file %s decoded into WAV file %s " % (infile, wavfile))
print ("Sythesis average %d ms/frame", total_time_ms/frame_count)
print ("Average sythesis time per frame: %d ms/frame" % (total_time_ms/frame_count))
else:
print ("No frame found")
exit(0)

View File

@ -55,31 +55,36 @@ def sbc_compare_headers(frame_count, actual_frame, expected_frame):
return 0
file_size = 0
def get_actual_frame(fin):
def get_actual_frame(fin, implementation, frame_count):
global file_size
actual_frame = SBCFrame()
sbc_unpack_frame(fin, file_size - fin.tell(), actual_frame)
sbc_reconstruct_subband_samples(actual_frame)
sbc_synthesis(actual_frame)
if subband_frame_count == 0:
sbc_init_sythesis(actual_frame.nr_subbands, implementation)
print actual_frame
sbc_synthesis(actual_frame, implementation)
return actual_frame
def get_expected_frame(fin_expected, nr_blocks, nr_subbands, nr_channels, sampling_frequency, bitpool, allocation_method):
expected_frame = SBCFrame(nr_blocks, nr_subbands, nr_channels, sampling_frequency, bitpool, allocation_method)
fetch_samples_for_next_sbc_frame(fin_expected, expected_frame)
calculate_channel_mode_and_scale_factors(expected_frame)
calculate_channel_mode_and_scale_factors(expected_frame, 0)
return expected_frame
usage = '''
Usage: ./sbc_decoder_test.py decoder_input.sbc decoder_expected_output.wav
Example: ./sbc_decoder_test.py fanfare-4sb.sbc fanfare-4sb-decoded.wav
Usage: ./sbc_decoder_test.py decoder_input.sbc force_channel_mode[No=0, Stereo=2, Joint Stereo=3] implementation[SIG, V1] decoder_expected_output.wav
Example: ./sbc_decoder_test.py fanfare-4sb.sbc 0 fanfare-4sb-decoded.wav
'''
if (len(sys.argv) < 3):
if (len(sys.argv) < 5):
print(usage)
sys.exit(1)
try:
decoder_input_sbc = sys.argv[1]
decoder_expected_wav = sys.argv[2]
force_channel_mode = int(sys.argv[2])
implementation = sys.argv[3]
decoder_expected_wav = sys.argv[4]
if not decoder_input_sbc.endswith('.sbc'):
print(usage)
@ -89,6 +94,8 @@ try:
print(usage)
sys.exit(1)
fin_expected = wave.open(decoder_expected_wav, 'rb')
nr_channels, sampwidth, sampling_frequency, nr_audio_frames, comptype, compname = fin_expected.getparams()
@ -104,9 +111,8 @@ try:
print ("== Frame %d ==" % subband_frame_count)
actual_frame = get_actual_frame(fin)
actual_frame = get_actual_frame(fin, implementation, subband_frame_count)
expected_frame = get_expected_frame(fin_expected, actual_frame.nr_blocks,
actual_frame.nr_subbands, nr_channels,
actual_frame.bitpool, sampling_frequency,
@ -115,17 +121,15 @@ try:
err = sbc_compare_headers(subband_frame_count, actual_frame, expected_frame)
if err < 0:
print ("Headers differ \n%s\n%s" % (actual_frame, expected_frame))
print ("Frame %d: Headers differ \n%s\n%s" % (subband_frame_count, actual_frame, expected_frame))
sys.exit(1)
err = sbc_compare_pcm(subband_frame_count, actual_frame, expected_frame)
if err < 0:
print ("PCMs differ \n%s\n%s" % (actual_frame.pcm, expected_frame.pcm))
print ("Frame %d: PCMs differ %f \n%s\n%s" % (subband_frame_count, max_error, actual_frame.pcm, expected_frame.pcm))
sys.exit(1)
if subband_frame_count == 0:
print actual_frame, expected_frame
subband_frame_count += 1
except TypeError:

View File

@ -129,7 +129,7 @@ try:
bitpool = int(sys.argv[4])
allocation_method = int(sys.argv[5])
encoder_expected_sbc = sys.argv[6]
force_channel_mode = sys.argv[7]
force_channel_mode = int(sys.argv[7])
sampling_frequency = 44100
if not encoder_input_wav.endswith('.wav'):