#! /usr/bin/env python

from __future__ import annotations

from pathlib import Path
from typing import Dict, Tuple, List, Literal, TypedDict, Union

import numpy as np
import soundfile as sf
from scipy.signal import lfilter
from scipy.signal.windows import kaiser
# --------------------------------
# Public Type aliases (Level 1)
# --------------------------------

FrameType = Literal["OLS", "LSS", "ESH", "LPS"]
"""
Frame type codes:
- "OLS": ONLY_LONG_SEQUENCE
- "LSS": LONG_START_SEQUENCE
- "ESH": EIGHT_SHORT_SEQUENCE
- "LPS": LONG_STOP_SEQUENCE
"""

WinType = Literal["KBD", "SIN"]
"""
Window type codes:
- "KBD": Kaiser-Bessel-Derived
- "SIN": sinusoid
"""

FrameT = np.ndarray
"""
Time-domain frame.
Expected shape: (2048, 2) for stereo (two channels).
dtype: float (e.g., float32/float64).
"""

FrameChannelT = np.ndarray
"""
Time-domain single channel frame.
Expected shape: (2048,).
dtype: float (e.g., float32/float64).
"""

FrameF = np.ndarray
"""
Frequency-domain frame (MDCT coefficients).
As per spec (Level 1):
- If frame_type in {"OLS","LSS","LPS"}: shape (1024, 2)
- If frame_type == "ESH": shape (128, 16) where 8 subframes x 2 channels
  are placed in columns according to the subframe order (i.e., each subframe is (128,2)).
"""

# Dict keys for the two stereo channels: "chl" holds channel 0, "chr" channel 1
# (see the packing in aac_coder_1()).
ChannelKey = Literal["chl", "chr"]
class AACChannelFrameF(TypedDict):
    """Channel payload for aac_seq_1[i]["chl"] or ["chr"] (Level 1)."""

    frame_F: np.ndarray
    # frame_F for one channel:
    # - ESH: shape (128, 8)  (one column per short subframe)
    # - else: shape (1024, 1)
class AACSeq1Frame(TypedDict):
    """One frame dictionary of aac_seq_1 (Level 1)."""

    frame_type: FrameType  # common frame type after the stereo merge (Table 1)
    win_type: WinType  # window family used for this frame ("KBD" or "SIN")
    chl: AACChannelFrameF  # channel-0 MDCT payload
    chr: AACChannelFrameF  # channel-1 MDCT payload
AACSeq1 = List[AACSeq1Frame]
"""AAC sequence for Level 1:
List of length K (K = number of frames).
Each element is a dict with keys:
- "frame_type", "win_type", "chl", "chr"
"""

# Global Options
# -----------------------------------------------------------------------------

# Window type applied globally by aac_coder_1() to every frame.
# Options: "SIN", "KBD"
WIN_TYPE: WinType = "SIN"

# Private helpers for SSC
# -----------------------------------------------------------------------------
# See Table 1 in mm-2025-hw-v0.1.pdf
# Maps the per-channel frame-type decisions (left, right) to the single frame
# type shared by both channels. Note from the table: any pairing that involves
# "ESH", and the ("LSS","LPS")/("LPS","LSS") pairs, resolve to "ESH".
STEREO_MERGE_TABLE: Dict[Tuple[FrameType, FrameType], FrameType] = {
    ("OLS", "OLS"): "OLS",
    ("OLS", "LSS"): "LSS",
    ("OLS", "ESH"): "ESH",
    ("OLS", "LPS"): "LPS",
    ("LSS", "OLS"): "LSS",
    ("LSS", "LSS"): "LSS",
    ("LSS", "ESH"): "ESH",
    ("LSS", "LPS"): "ESH",
    ("ESH", "OLS"): "ESH",
    ("ESH", "LSS"): "ESH",
    ("ESH", "ESH"): "ESH",
    ("ESH", "LPS"): "ESH",
    ("LPS", "OLS"): "LPS",
    ("LPS", "LSS"): "ESH",
    ("LPS", "ESH"): "ESH",
    ("LPS", "LPS"): "LPS",
}
def _detect_attack(next_frame_channel: FrameChannelT) -> bool:
|
|
"""
|
|
Detect if next frame (single channel) implies ESH according to the spec's attack criterion.
|
|
|
|
Parameters
|
|
----------
|
|
next_frame_channel : FrameChannelT
|
|
One channel of next_frame_T (shape: (2048,), dtype float).
|
|
|
|
Returns
|
|
-------
|
|
attack : bool
|
|
True if an attack is detected (=> next frame predicted ESH), else False.
|
|
|
|
Notes
|
|
-----
|
|
The spec describes:
|
|
|
|
- High-pass filter applied to next_frame_channel
|
|
- Split into 16 segments of length 128
|
|
- Compute segment energies s(l)
|
|
- Compute ds(l) = s(l) / s(l-1)
|
|
- Attack exists if there exists l in {1..7} such that:
|
|
s(l) > 1e-3 and ds(l) > 10
|
|
"""
|
|
x = next_frame_channel # local alias, x assumed to be a 1-D array of length 2048
|
|
|
|
# High-pass filter H(z) = (1 - z^-1) / (1 - 0.5 z^-1)
|
|
# Implemented as: y[n] = x[n] - x[n-1] + 0.5*y[n-1]
|
|
y = np.zeros_like(x)
|
|
prev_x = 0.0
|
|
prev_y = 0.0
|
|
for n in range(x.shape[0]):
|
|
xn = float(x[n])
|
|
yn = (xn - prev_x) + 0.5 * prev_y
|
|
y[n] = yn
|
|
prev_x = xn
|
|
prev_y = yn
|
|
|
|
# Segment energies over 16 blocks of 128 samples.
|
|
s = np.empty(16, dtype=np.float64)
|
|
for l in range(16):
|
|
a = l * 128
|
|
b = (l + 1) * 128
|
|
seg = y[a:b]
|
|
s[l] = float(np.sum(seg * seg))
|
|
|
|
# ds(l) for l>=1. For l=0 not defined, keep 0.
|
|
ds = np.zeros(16, dtype=np.float64)
|
|
eps = 1e-12 # avoid division by zero without changing logic materially
|
|
for l in range(1, 16):
|
|
ds[l] = s[l] / max(s[l - 1], eps)
|
|
|
|
# Spec: check l in {1..7}
|
|
for l in range(1, 8):
|
|
if (s[l] > 1e-3) and (ds[l] > 10.0):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def _decide_frame_type(prev_frame_type: FrameType, attack: bool) -> FrameType:
|
|
"""
|
|
Decide current frame type for a single channel based on prev_frame_type and next-frame attack.
|
|
|
|
Parameters
|
|
----------
|
|
prev_frame_type : FrameType
|
|
Previous frame type (one of "OLS","LSS","ESH","LPS").
|
|
attack : bool
|
|
Whether next frame is predicted ESH for this channel.
|
|
|
|
Returns
|
|
-------
|
|
frame_type : FrameType
|
|
The per-channel decision for the current frame.
|
|
|
|
Rules (spec)
|
|
------------
|
|
- If prev is "LSS" => current is "ESH" (fixed)
|
|
- If prev is "LPS" => current is "OLS" (fixed)
|
|
- If prev is "OLS" => current is "LSS" if attack else "OLS"
|
|
- If prev is "ESH" => current is "ESH" if attack else "LPS"
|
|
"""
|
|
if prev_frame_type == "LSS":
|
|
return "ESH"
|
|
if prev_frame_type == "LPS":
|
|
return "OLS"
|
|
if prev_frame_type == "OLS":
|
|
return "LSS" if attack else "OLS"
|
|
if prev_frame_type == "ESH":
|
|
return "ESH" if attack else "LPS"
|
|
|
|
raise ValueError(f"Invalid prev_frame_type: {prev_frame_type!r}")
|
|
|
|
|
|
def _stereo_merge(ft_l: FrameType, ft_r: FrameType) -> FrameType:
    """
    Merge per-channel frame types into one common frame type using the spec table.

    Parameters
    ----------
    ft_l : FrameType
        Frame type decision for channel 0 (left).
    ft_r : FrameType
        Frame type decision for channel 1 (right).

    Returns
    -------
    common : FrameType
        The common final frame type.

    Raises
    ------
    ValueError
        If (ft_l, ft_r) is not a pair listed in STEREO_MERGE_TABLE.
    """
    pair = (ft_l, ft_r)
    merged = STEREO_MERGE_TABLE.get(pair)
    if merged is None:
        raise ValueError(f"Invalid stereo merge pair: {pair}")
    return merged
|
|
# Private helpers for Filterbank
|
|
# -----------------------------------------------------------------------------
|
|
|
|
def _sin_window(N: int) -> np.ndarray:
|
|
"""
|
|
Sine window (full length N).
|
|
w[n] = sin(pi/N * (n + 0.5)), 0 <= n < N
|
|
"""
|
|
n = np.arange(N, dtype=np.float64)
|
|
return np.sin((np.pi / N) * (n + 0.5))
|
|
|
|
|
|
def _kbd_window(N: int, alpha: float) -> np.ndarray:
|
|
"""
|
|
Kaiser-Bessel-Derived (KBD) window (full length N).
|
|
|
|
This follows the standard KBD construction:
|
|
- Build Kaiser kernel of length N/2 + 1
|
|
- Use cumulative sum and sqrt normalization to form left and right halves
|
|
"""
|
|
half = N // 2
|
|
|
|
# Kaiser kernel length: half + 1 samples (0 .. half)
|
|
# beta = pi * alpha per the usual correspondence with the ISO definition
|
|
kernel = kaiser(half + 1, beta=np.pi * alpha).astype(np.float64)
|
|
|
|
csum = np.cumsum(kernel)
|
|
denom = csum[-1]
|
|
|
|
w_left = np.sqrt(csum[:-1] / denom) # length half, n = 0 .. half-1
|
|
w_right = w_left[::-1] # mirror for second half
|
|
|
|
return np.concatenate([w_left, w_right])
|
|
|
|
|
|
def _long_window(win_type: WinType) -> np.ndarray:
    """
    Long window (length 2048) for the selected win_type.

    Raises ValueError for any code other than "SIN" or "KBD".
    """
    if win_type == "KBD":
        # Assignment-specific alpha value for the long KBD window.
        return _kbd_window(2048, alpha=6.0)
    if win_type == "SIN":
        return _sin_window(2048)
    raise ValueError(f"Invalid win_type: {win_type!r}")
def _short_window(win_type: WinType) -> np.ndarray:
    """
    Short window (length 256) for the selected win_type.

    Raises ValueError for any code other than "SIN" or "KBD".
    """
    if win_type == "KBD":
        # Assignment-specific alpha value for the short KBD window.
        return _kbd_window(256, alpha=4.0)
    if win_type == "SIN":
        return _sin_window(256)
    raise ValueError(f"Invalid win_type: {win_type!r}")
def _window_sequence(frame_type: FrameType, win_type: WinType) -> np.ndarray:
    """
    Build the 2048-sample window sequence for OLS/LSS/LPS.

    We follow the simplified assumption:
    - The same window shape (KBD or SIN) is used globally (no mixed halves).
    - Therefore, the left and right halves are drawn from the same family.

    ESH is not handled here; it uses eight short windows directly
    (see _filter_bank_esh_channel), so passing "ESH" raises ValueError.
    """
    long_win = _long_window(win_type)  # length 2048
    short_win = _short_window(win_type)  # length 256

    if frame_type == "OLS":
        return long_win

    if frame_type == "LSS":
        # Layout: rising long half | flat top | falling short half | silence
        return np.concatenate((
            long_win[:1024],  # 0..1023: left half of long window
            np.ones(448),  # 1024..1471: ones
            short_win[128:],  # 1472..1599: right half of short window
            np.zeros(448),  # 1600..2047: zeros
        ))

    if frame_type == "LPS":
        # Layout: silence | rising short half | flat top | falling long half
        return np.concatenate((
            np.zeros(448),  # 0..447: zeros
            short_win[:128],  # 448..575: left half of short window
            np.ones(448),  # 576..1023: ones
            long_win[1024:],  # 1024..2047: right half of long window
        ))

    raise ValueError(f"Invalid frame_type for long window sequence: {frame_type!r}")
def _mdct(s: np.ndarray) -> np.ndarray:
|
|
"""
|
|
MDCT (direct form) as given in the assignment.
|
|
|
|
Input:
|
|
s: windowed time samples of length N (N = 2048 or 256)
|
|
|
|
Output:
|
|
X: MDCT coefficients of length N/2
|
|
|
|
Definition:
|
|
X[k] = 2 * sum_{n=0 .. N-1} s[n] * cos(2*pi/N * (n + n0) * (k + 1/2))
|
|
where n0 = (N/2 + 1)/2
|
|
"""
|
|
s = np.asarray(s, dtype=np.float64)
|
|
N = int(s.shape[0])
|
|
if N not in (2048, 256):
|
|
raise ValueError("MDCT input length must be 2048 or 256.")
|
|
|
|
n0 = (N / 2.0 + 1.0) / 2.0
|
|
|
|
n = np.arange(N, dtype=np.float64) + n0
|
|
k = np.arange(N // 2, dtype=np.float64) + 0.5
|
|
|
|
# Cosine matrix: shape (N, N/2)
|
|
C = np.cos((2.0 * np.pi / N) * np.outer(n, k))
|
|
X = 2.0 * (s @ C)
|
|
|
|
return X
|
|
|
|
def _imdct(X: np.ndarray) -> np.ndarray:
|
|
"""
|
|
IMDCT (direct form) as given in the assignment.
|
|
|
|
Input:
|
|
X: MDCT coefficients of length N/2 (N = 2048 or 256)
|
|
|
|
Output:
|
|
s: time samples of length N
|
|
|
|
Definition:
|
|
s[n] = (2/N) * sum_{k=0 .. N/2-1} X[k] * cos(2*pi/N * (n + n0) * (k + 1/2))
|
|
where n0 = (N/2 + 1)/2
|
|
"""
|
|
X = np.asarray(X, dtype=np.float64).reshape(-1)
|
|
K = int(X.shape[0])
|
|
if K not in (1024, 128):
|
|
raise ValueError("IMDCT input length must be 1024 or 128.")
|
|
|
|
N = 2 * K
|
|
n0 = (N / 2.0 + 1.0) / 2.0
|
|
|
|
n = np.arange(N, dtype=np.float64) + n0
|
|
k = np.arange(K, dtype=np.float64) + 0.5
|
|
|
|
C = np.cos((2.0 * np.pi / N) * np.outer(n, k)) # (N, K)
|
|
s = (2.0 / N) * (C @ X)
|
|
|
|
return s
|
|
|
|
|
|
def _filter_bank_esh_channel(x_ch: np.ndarray, win_type: WinType) -> np.ndarray:
    """
    ESH analysis for one channel.

    The eight 256-sample subwindows are taken from the central region of the
    2048-sample frame, starting at 448 + 128*j for j = 0..7 (50% overlap).

    Returns:
        X_esh: shape (128, 8), where each column is the 128 MDCT coeffs of one short window.
    """
    window = _short_window(win_type)

    columns = [
        _mdct(x_ch[start:start + 256] * window)
        for start in range(448, 448 + 8 * 128, 128)
    ]
    return np.stack(columns, axis=1)
def _unpack_esh(frame_F: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
|
"""
|
|
Unpack ESH spectrum from shape (128, 16) into per-channel arrays (128, 8).
|
|
|
|
Mapping is the inverse of the packing used in filter_bank():
|
|
out[:, 2*j] = left[:, j]
|
|
out[:, 2*j+1] = right[:, j]
|
|
"""
|
|
if frame_F.shape != (128, 16):
|
|
raise ValueError("ESH frame_F must have shape (128, 16).")
|
|
|
|
left = np.empty((128, 8), dtype=np.float64)
|
|
right = np.empty((128, 8), dtype=np.float64)
|
|
for j in range(8):
|
|
left[:, j] = frame_F[:, 2 * j + 0]
|
|
right[:, j] = frame_F[:, 2 * j + 1]
|
|
return left, right
|
|
|
|
|
|
def _i_filter_bank_esh_channel(X_esh: np.ndarray, win_type: WinType) -> np.ndarray:
    """
    ESH synthesis for one channel.

    Input:
        X_esh: (128, 8) MDCT coeffs for 8 short windows

    Output:
        x_ch: (2048, ) time-domain frame contribution (windowed),
              ready for OLA at the caller level.

    Each short IMDCT yields 256 samples which are windowed and overlap-added
    at start = 448 + 128*j, j = 0..7 (50% overlap).
    """
    if X_esh.shape != (128, 8):
        raise ValueError("X_esh must have shape (128, 8).")

    window = _short_window(win_type)
    frame = np.zeros(2048, dtype=np.float64)

    for j, start in enumerate(range(448, 448 + 8 * 128, 128)):
        frame[start:start + 256] += _imdct(X_esh[:, j]) * window

    return frame
# -----------------------------------------------------------------------------
|
|
# Public Function prototypes (Level 1)
|
|
# -----------------------------------------------------------------------------
|
|
|
|
def SSC(frame_T: FrameT, next_frame_T: FrameT, prev_frame_type: FrameType) -> FrameType:
    """
    Sequence Segmentation Control (SSC).

    Selects and returns the frame type for the current frame (i) based on input parameters.

    Parameters
    ----------
    frame_T : FrameT
        Current time-domain frame i, stereo, shape (2048, 2).
    next_frame_T : FrameT
        Next time-domain frame (i+1), stereo, shape (2048, 2)
        (used to decide transitions to/from ESH).
    prev_frame_type : FrameType
        Frame type chosen for the previous frame (i-1).

    Returns
    -------
    frame_type : FrameType
        - "OLS" (ONLY_LONG_SEQUENCE)
        - "LSS" (LONG_START_SEQUENCE)
        - "ESH" (EIGHT_SHORT_SEQUENCE)
        - "LPS" (LONG_STOP_SEQUENCE)
    """
    if frame_T.shape != (2048, 2):
        raise ValueError("frame_T must have shape (2048, 2).")
    if next_frame_T.shape != (2048, 2):
        raise ValueError("next_frame_T must have shape (2048, 2).")

    # Per-channel decision: attack detection on the next frame combined with
    # the shared previous frame type.
    per_channel = [
        _decide_frame_type(prev_frame_type, _detect_attack(next_frame_T[:, ch]))
        for ch in (0, 1)
    ]

    # Stereo merge as per Table 1.
    return _stereo_merge(per_channel[0], per_channel[1])
def filter_bank(frame_T: FrameT, frame_type: FrameType, win_type: WinType) -> FrameF:
    """
    Filterbank stage (MDCT analysis).

    Parameters
    ----------
    frame_T : FrameT
        Time-domain frame, stereo, shape (2048, 2).
    frame_type : FrameType
        Type of the frame under encoding ("OLS"|"LSS"|"ESH"|"LPS").
    win_type : WinType
        Window type ("KBD" or "SIN") used for the current frame.

    Returns
    -------
    frame_F : FrameF
        Frequency-domain MDCT coefficients:
        - If frame_type in {"OLS","LSS","LPS"}: array shape (1024, 2)
          containing MDCT coefficients for both channels.
        - If frame_type == "ESH": contains 8 subframes, each subframe has shape (128,2),
          placed in columns according to subframe order, i.e. overall shape (128, 16).
    """
    if frame_T.shape != (2048, 2):
        raise ValueError("frame_T must have shape (2048, 2).")

    left = frame_T[:, 0].astype(np.float64, copy=False)
    right = frame_T[:, 1].astype(np.float64, copy=False)

    if frame_type == "ESH":
        # Interleave the two (128, 8) channel spectra into (128, 16):
        # columns [L0, R0, L1, R1, ..., L7, R7].
        packed = np.empty((128, 16), dtype=np.float64)
        packed[:, 0::2] = _filter_bank_esh_channel(left, win_type)
        packed[:, 1::2] = _filter_bank_esh_channel(right, win_type)
        return packed

    if frame_type in ("OLS", "LSS", "LPS"):
        window = _window_sequence(frame_type, win_type)  # length 2048
        return np.column_stack((
            _mdct(left * window),  # (1024,)
            _mdct(right * window),  # (1024,)
        ))

    raise ValueError(f"Invalid frame_type: {frame_type!r}")
def i_filter_bank(frame_F: FrameF, frame_type: FrameType, win_type: WinType) -> FrameT:
    """
    Inverse filterbank (IMDCT synthesis).

    Parameters
    ----------
    frame_F : FrameF
        Frequency-domain MDCT coefficients as produced by filter_bank().
    frame_type : FrameType
        Frame type ("OLS"|"LSS"|"ESH"|"LPS").
    win_type : WinType
        Window type ("KBD" or "SIN").

    Returns
    -------
    frame_T : FrameT
        Reconstructed time-domain frame, stereo, shape (2048, 2).
    """
    if frame_type == "ESH":
        if frame_F.shape != (128, 16):
            raise ValueError("For ESH, frame_F must have shape (128, 16).")

        spec_l, spec_r = _unpack_esh(frame_F)
        return np.column_stack((
            _i_filter_bank_esh_channel(spec_l, win_type),
            _i_filter_bank_esh_channel(spec_r, win_type),
        ))

    if frame_type in ("OLS", "LSS", "LPS"):
        if frame_F.shape != (1024, 2):
            raise ValueError("For OLS/LSS/LPS, frame_F must have shape (1024, 2).")

        window = _window_sequence(frame_type, win_type)
        return np.column_stack((
            _imdct(frame_F[:, 0]) * window,
            _imdct(frame_F[:, 1]) * window,
        ))

    raise ValueError(f"Invalid frame_type: {frame_type!r}")
def aac_coder_1(filename_in: Union[str, Path]) -> AACSeq1:
    """
    Level-1 AAC encoder.

    Parameters
    ----------
    filename_in : str | Path
        Input WAV filename.
        Assumption: stereo audio, sampling rate 48 kHz.

    Returns
    -------
    aac_seq_1 : AACSeq1
        List of K encoded frames.
        For each i:

        - aac_seq_1[i]["frame_type"]: FrameType
        - aac_seq_1[i]["win_type"]: WinType
        - aac_seq_1[i]["chl"]["frame_F"]:
            - ESH: shape (128, 8)
            - else: shape (1024, 1)
        - aac_seq_1[i]["chr"]["frame_F"]:
            - ESH: shape (128, 8)
            - else: shape (1024, 1)

    Raises
    ------
    ValueError
        If the input is not stereo, not 48 kHz, or too short to frame.
    """
    filename_in = Path(filename_in)

    x, fs = sf.read(str(filename_in), always_2d=True)
    x = np.asarray(x, dtype=np.float64)

    if x.shape[1] != 2:
        raise ValueError("Input must be stereo (2 channels).")
    if fs != 48000:
        raise ValueError("Input sampling rate must be 48 kHz.")

    hop = 1024
    win = 2048

    # Pad at the beginning to support the first overlap region.
    # Tail padding is kept minimal; next-frame is padded on-the-fly when needed.
    pad_pre = np.zeros((hop, 2), dtype=np.float64)
    pad_post = np.zeros((hop, 2), dtype=np.float64)
    x_pad = np.vstack([pad_pre, x, pad_post])

    # Number of frames such that current frame fits; next frame will be padded if needed.
    K = int((x_pad.shape[0] - win) // hop + 1)
    if K <= 0:
        raise ValueError("Input too short for framing.")

    aac_seq: AACSeq1 = []
    prev_frame_type: FrameType = "OLS"

    for i in range(K):
        start = i * hop

        frame_t: FrameT = x_pad[start:start + win, :]
        if frame_t.shape != (win, 2):
            # This should not happen due to K definition, but we keep it explicit.
            raise ValueError("Internal framing error: frame_t has wrong shape.")

        # Next frame (look-ahead for SSC); zero-pad at the tail so it is
        # always (2048, 2).
        next_t = x_pad[start + hop:start + hop + win, :]
        if next_t.shape[0] < win:
            tail = np.zeros((win - next_t.shape[0], 2), dtype=np.float64)
            next_t = np.vstack([next_t, tail])

        frame_type = SSC(frame_t, next_t, prev_frame_type)
        frame_f = filter_bank(frame_t, frame_type, WIN_TYPE)

        # Store per-channel as required by the AACSeq1 schema.
        if frame_type == "ESH":
            # frame_f: (128, 16) packed as [L0 R0 L1 R1 ... L7 R7];
            # reuse the existing helper instead of duplicating the
            # de-interleave loop (same mapping as i_filter_bank()).
            chl_f, chr_f = _unpack_esh(frame_f)
        else:
            # frame_f: (1024, 2) -> two (1024, 1) column views.
            chl_f = frame_f[:, 0:1].astype(np.float64, copy=False)
            chr_f = frame_f[:, 1:2].astype(np.float64, copy=False)

        aac_seq.append({
            "frame_type": frame_type,
            "win_type": WIN_TYPE,
            "chl": {"frame_F": chl_f},
            "chr": {"frame_F": chr_f},
        })
        prev_frame_type = frame_type

    return aac_seq
def i_aac_coder_1(aac_seq_1: AACSeq1, filename_out: Union[str, Path]) -> np.ndarray:
    """
    Level-1 AAC decoder (inverse of aac_coder_1()).

    Parameters
    ----------
    aac_seq_1 : AACSeq1
        Encoded sequence as produced by aac_coder_1().
    filename_out : str | Path
        Output WAV filename.
        Assumption: stereo audio, sampling rate 48 kHz.

    Returns
    -------
    x : np.ndarray
        Decoded audio samples (time-domain).
        Expected shape: (N, 2) for stereo (N depends on input length).
    """
    filename_out = Path(filename_out)

    hop = 1024
    win = 2048
    n_frames = len(aac_seq_1)

    # Reconstruct the full padded stream the encoder produced: the last frame
    # starts at (K-1)*hop and spans win samples, so the padded length is
    # (K-1)*hop + win.
    padded_len = (n_frames - 1) * hop + win
    acc = np.zeros((padded_len, 2), dtype=np.float64)

    for idx, record in enumerate(aac_seq_1):
        ftype = record["frame_type"]
        wtype = record["win_type"]

        left = np.asarray(record["chl"]["frame_F"], dtype=np.float64)
        right = np.asarray(record["chr"]["frame_F"], dtype=np.float64)

        # Re-pack into the layout expected by i_filter_bank().
        if ftype == "ESH":
            if left.shape != (128, 8) or right.shape != (128, 8):
                raise ValueError("ESH channel frame_F must have shape (128, 8).")

            spec = np.empty((128, 16), dtype=np.float64)
            spec[:, 0::2] = left  # interleave: [L0 R0 L1 R1 ... L7 R7]
            spec[:, 1::2] = right
        else:
            if left.shape != (1024, 1) or right.shape != (1024, 1):
                raise ValueError("Non-ESH channel frame_F must have shape (1024, 1).")

            spec = np.hstack([left, right])  # (1024, 2)

        # Overlap-add the synthesized frame at its hop position.
        offset = idx * hop
        acc[offset:offset + win, :] += i_filter_bank(spec, ftype, wtype)

    # Strip the hop-sample boundary padding the encoder added on each side.
    if acc.shape[0] < 2 * hop:
        raise ValueError("Decoded stream too short to unpad.")
    decoded = acc[hop:-hop, :]

    sf.write(str(filename_out), decoded, 48000)
    return decoded
def demo_aac_1(filename_in: Union[str, Path], filename_out: Union[str, Path]) -> float:
    """
    Demonstration for Level-1 codec.

    Runs:
    - aac_coder_1(filename_in)
    - i_aac_coder_1(aac_seq_1, filename_out)
    and computes total SNR between original and decoded audio.

    Parameters
    ----------
    filename_in : str | Path
        Input WAV filename (stereo, 48 kHz).
    filename_out : str | Path
        Output WAV filename (stereo, 48 kHz).

    Returns
    -------
    SNR : float
        Overall Signal-to-Noise Ratio in dB
        (+inf for perfect reconstruction, -inf for a silent reference).
    """
    src = Path(filename_in)
    dst = Path(filename_out)

    # Reference: the original audio.
    reference, _fs = sf.read(str(src), always_2d=True)
    reference = np.asarray(reference, dtype=np.float64)

    # Encode then decode.
    decoded = np.asarray(i_aac_coder_1(aac_coder_1(src), dst), dtype=np.float64)

    # Normalize both to 2-D (N, channels).
    if decoded.ndim == 1:
        decoded = decoded.reshape(-1, 1)
    if reference.ndim == 1:
        reference = reference.reshape(-1, 1)

    # Compare only the common overlap in both length and channel count.
    n = min(reference.shape[0], decoded.shape[0])
    c = min(reference.shape[1], decoded.shape[1])
    reference = reference[:n, :c]
    decoded = decoded[:n, :c]

    # Overall SNR over all samples and channels.
    residual = reference - decoded
    p_signal = float(np.sum(reference * reference))
    p_noise = float(np.sum(residual * residual))

    if p_noise <= 0.0:
        return float("inf")
    if p_signal <= 0.0:
        # Degenerate case: silent input.
        return -float("inf")

    return float(10.0 * np.log10(p_signal / p_noise))
if __name__ == "__main__":
    # CLI entry point. Example:
    #   python -m level_1.level_1 input.wav output.wav
    import sys

    cli_args = sys.argv[1:]
    if len(cli_args) != 2:
        raise SystemExit("Usage: python -m level_1.level_1 <input.wav> <output.wav>")

    in_wav, out_wav = cli_args

    print(f"Encoding/Decoding {in_wav} to {out_wav}")
    snr = demo_aac_1(in_wav, out_wav)
    print(f"SNR = {snr:.3f} dB")