254 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - Inverse AAC Coder (Core)
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# - Level 1 AAC decoder orchestration (inverse of aac_coder_1()).
# - Level 2 AAC decoder orchestration (inverse of aac_coder_2()).
#
# ------------------------------------------------------------
from __future__ import annotations
from pathlib import Path
from typing import Union
import soundfile as sf
from core.aac_filterbank import aac_i_filter_bank
from core.aac_tns import aac_i_tns
from core.aac_types import *
# -----------------------------------------------------------------------------
# Public helpers
# -----------------------------------------------------------------------------
def aac_unpack_seq_channels_to_frame_f(frame_type: FrameType, chl_f: FrameChannelF, chr_f: FrameChannelF) -> FrameF:
    """
    Merge the left/right per-channel spectra of one frame into a single
    stereo coefficient array.

    Parameters
    ----------
    frame_type : FrameType
        Frame type tag. Only the value "ESH" selects the short-window layout;
        every other value is handled as a long frame.
    chl_f : FrameChannelF
        Left channel coefficients:
        - ESH: (128, 8)
        - else: (1024, 1)
    chr_f : FrameChannelF
        Right channel coefficients, same shape rules as chl_f.

    Returns
    -------
    FrameF
        Newly allocated float64 array:
        - ESH: (128, 16), columns interleaved as [L0 R0 L1 R1 ... L7 R7]
        - else: (1024, 2), column 0 = left, column 1 = right

    Raises
    ------
    ValueError
        If either channel does not have the shape required by frame_type.
    """
    if frame_type != "ESH":
        # Long frame: one (1024, 1) column per channel.
        long_shape = (1024, 1)
        if not (chl_f.shape == long_shape and chr_f.shape == long_shape):
            raise ValueError("Non-ESH channel frame_F must have shape (1024, 1).")

        stereo = np.empty((1024, 2), dtype=np.float64)
        stereo[:, 0] = chl_f[:, 0]
        stereo[:, 1] = chr_f[:, 0]
        return stereo

    # Short frame: eight 128-coefficient sub-frames per channel.
    short_shape = (128, 8)
    if not (chl_f.shape == short_shape and chr_f.shape == short_shape):
        raise ValueError("ESH channel frame_F must have shape (128, 8).")

    interleaved = np.empty((128, 16), dtype=np.float64)
    interleaved[:, 0::2] = chl_f  # even columns: L0 .. L7
    interleaved[:, 1::2] = chr_f  # odd columns:  R0 .. R7
    return interleaved
def aac_remove_padding(y_pad: StereoSignal, hop: int = 1024) -> StereoSignal:
    """
    Remove the boundary padding that the Level-1 encoder adds:
    hop samples at start and hop samples at end.

    Parameters
    ----------
    y_pad : StereoSignal (np.ndarray)
        Reconstructed padded stream, shape (N_pad, 2).
    hop : int
        Hop size in samples (default 1024). Must be non-negative;
        hop == 0 returns the whole stream.

    Returns
    -------
    StereoSignal (np.ndarray)
        Unpadded reconstructed stream, shape (N_pad - 2*hop, 2).
        This is a slice of y_pad, not a copy.

    Raises
    ------
    ValueError
        If hop is negative, or if y_pad is too short to unpad.
    """
    if hop < 0:
        raise ValueError("hop must be non-negative.")

    n_pad = y_pad.shape[0]
    if n_pad < 2 * hop:
        raise ValueError("Decoded stream too short to unpad.")

    # Use an explicit end index: y_pad[hop:-hop] is empty when hop == 0
    # because -0 == 0.
    return y_pad[hop:n_pad - hop, :]
# -----------------------------------------------------------------------------
# Level 1 decoder
# -----------------------------------------------------------------------------
def aac_decoder_1(aac_seq_1: AACSeq1, filename_out: Union[str, Path]) -> StereoSignal:
    """
    Level-1 AAC decoder (inverse of aac_coder_1()).

    Pipeline:
    - Re-pack each frame's per-channel spectra into a stereo frame_F
    - Synthesize each frame with aac_i_filter_bank()
    - Overlap-add the K synthesized frames into the full padded stream
    - Remove hop padding at the beginning and hop padding at the end
    - Write the reconstructed stereo WAV file (48 kHz)
    - Return the reconstructed stereo samples

    Parameters
    ----------
    aac_seq_1 : AACSeq1
        Encoded sequence as produced by aac_coder_1(). Each frame is read
        through the keys "frame_type", "win_type", and
        "chl"/"chr" -> "frame_F".
    filename_out : Union[str, Path]
        Output WAV filename. Assumption: 48 kHz, stereo.

    Returns
    -------
    StereoSignal
        Decoded audio samples (time-domain), stereo, shape (N, 2), dtype float64.

    Raises
    ------
    ValueError
        If aac_seq_1 is empty, or if a frame's channel spectra do not have
        the shape required by its frame type.
    """
    filename_out = Path(filename_out)

    hop = 1024
    win = 2048
    K = len(aac_seq_1)

    # Fail early with a clear message, as aac_decoder_2() does, instead of
    # letting the unpadding step reject a frame-less buffer.
    if K <= 0:
        raise ValueError("aac_seq_1 must contain at least one frame.")

    # Output includes the encoder padding region, so we reconstruct the full padded stream.
    # For K frames: last frame starts at (K-1)*hop and spans win,
    # so total length = (K-1)*hop + win.
    n_pad = (K - 1) * hop + win
    y_pad: StereoSignal = np.zeros((n_pad, 2), dtype=np.float64)

    for i, fr in enumerate(aac_seq_1):
        frame_type: FrameType = fr["frame_type"]
        win_type: WinType = fr["win_type"]

        chl_f = np.asarray(fr["chl"]["frame_F"], dtype=np.float64)
        chr_f = np.asarray(fr["chr"]["frame_F"], dtype=np.float64)

        frame_f: FrameF = aac_unpack_seq_channels_to_frame_f(frame_type, chl_f, chr_f)
        # Expected to be (win, 2) so that it matches the overlap-add slice.
        frame_t_hat: FrameT = aac_i_filter_bank(frame_f, frame_type, win_type)

        # Consecutive frames overlap by win - hop samples.
        start = i * hop
        y_pad[start:start + win, :] += frame_t_hat

    y: StereoSignal = aac_remove_padding(y_pad, hop=hop)

    # Level 1 assumption: 48 kHz output.
    sf.write(str(filename_out), y, 48000)
    return y
# -----------------------------------------------------------------------------
# Level 2 decoder
# -----------------------------------------------------------------------------
def _aac_as_long_column(ch_f, side: str):
    """Return a non-ESH channel spectrum as a (1024, 1) column, accepting (1024,) or (1024, 1)."""
    if ch_f.shape == (1024,):
        return ch_f.reshape(1024, 1)
    if ch_f.shape == (1024, 1):
        return ch_f
    raise ValueError(f"Non-ESH {side} channel frame_F must be shape (1024,) or (1024, 1).")


def aac_decoder_2(aac_seq_2: AACSeq2, filename_out: Union[str, Path]) -> StereoSignal:
    """
    Level-2 AAC decoder (inverse of aac_coder_2).

    Behavior matches Level 1 decoder pipeline, with additional iTNS stage:
    - Per frame/channel: inverse TNS using stored coefficients
    - Re-pack to stereo frame_F (via aac_unpack_seq_channels_to_frame_f)
    - Synthesis with aac_i_filter_bank()
    - Overlap-add over frames
    - Remove Level-1 padding (hop samples start/end)
    - Write output WAV (48 kHz)

    Parameters
    ----------
    aac_seq_2 : AACSeq2
        Encoded sequence as produced by aac_coder_2(). Each frame is read
        through the keys "frame_type", "win_type", and
        "chl"/"chr" -> "frame_F", "tns_coeffs".
    filename_out : Union[str, Path]
        Output WAV filename.

    Returns
    -------
    StereoSignal
        Decoded audio samples (time-domain), stereo, shape (N, 2), dtype float64.

    Raises
    ------
    ValueError
        If aac_seq_2 is empty, or if a channel spectrum returned by
        aac_i_tns() does not have the shape required by the frame type
        (ESH: (128, 8); otherwise (1024,) or (1024, 1)).
    """
    filename_out = Path(filename_out)

    hop = 1024
    win = 2048
    K = len(aac_seq_2)
    if K <= 0:
        raise ValueError("aac_seq_2 must contain at least one frame.")

    # Last frame starts at (K-1)*hop and spans win samples.
    n_pad = (K - 1) * hop + win
    y_pad = np.zeros((n_pad, 2), dtype=np.float64)

    for i, fr in enumerate(aac_seq_2):
        frame_type: FrameType = fr["frame_type"]
        win_type: WinType = fr["win_type"]

        chl_f_tns = np.asarray(fr["chl"]["frame_F"], dtype=np.float64)
        chr_f_tns = np.asarray(fr["chr"]["frame_F"], dtype=np.float64)
        chl_coeffs = np.asarray(fr["chl"]["tns_coeffs"], dtype=np.float64)
        chr_coeffs = np.asarray(fr["chr"]["tns_coeffs"], dtype=np.float64)

        # Inverse TNS per channel
        chl_f = aac_i_tns(chl_f_tns, frame_type, chl_coeffs)
        chr_f = aac_i_tns(chr_f_tns, frame_type, chr_coeffs)

        # Long frames may come back as (1024,) or (1024, 1); normalize them to
        # the (1024, 1) layout that the shared re-pack helper requires.
        if frame_type != "ESH":
            chl_f = _aac_as_long_column(chl_f, "left")
            chr_f = _aac_as_long_column(chr_f, "right")

        # Re-pack to the stereo container expected by aac_i_filter_bank,
        # using the same helper as the Level-1 decoder.
        frame_f: FrameF = aac_unpack_seq_channels_to_frame_f(frame_type, chl_f, chr_f)

        frame_t_hat: FrameT = aac_i_filter_bank(frame_f, frame_type, win_type)

        # Consecutive frames overlap by win - hop samples.
        start = i * hop
        y_pad[start:start + win, :] += frame_t_hat

    y = aac_remove_padding(y_pad, hop=hop)

    sf.write(str(filename_out), y, 48000)
    return y