442 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - Psychoacoustic Model
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# Psychoacoustic model for ONE channel, based on the assignment notes (Section 2.4).
#
# Public API:
# SMR = aac_psycho(frame_T, frame_type, frame_T_prev_1, frame_T_prev_2)
#
# Output:
# - For long frames ("OLS", "LSS", "LPS"): SMR has shape (69,)
# - For short frames ("ESH"): SMR has shape (42, 8) (one column per subframe)
#
# Notes:
# - Uses Bark band tables from material/TableB219.mat:
# * B219a for long windows (69 bands, N=2048 FFT, N/2=1024 bins)
# * B219b for short windows (42 bands, N=256 FFT, N/2=128 bins)
# - Applies a Hann window in time domain before FFT magnitude/phase extraction.
# - Implements:
# spreading function -> band spreading -> tonality index -> masking thresholds -> SMR.
# ------------------------------------------------------------
from __future__ import annotations
import numpy as np
from core.aac_utils import band_limits, get_table
from core.aac_configuration import NMT_DB, TMN_DB
from core.aac_types import *
# -----------------------------------------------------------------------------
# Spreading function
# -----------------------------------------------------------------------------
def _spreading_matrix(bval: BandValueArray) -> FloatArray:
    """
    Compute the spreading function matrix between psychoacoustic bands.

    The spreading function describes how energy in one band (the masker)
    spills over and masks nearby bands (the maskees). For masker index ``bb``
    and maskee index ``b`` it follows the assignment pseudo-code
    (ISO/IEC 13818-7 psychoacoustic model):

        if bb >= b:  tmpx = 3.0 * (bval[b] - bval[bb])
        else:        tmpx = 1.5 * (bval[b] - bval[bb])
        tmpz = 8 * min((tmpx - 0.5)^2 - 2 * (tmpx - 0.5), 0)
        tmpy = 15.811389 + 7.5 * (tmpx + 0.474)
               - 17.5 * sqrt(1 + (tmpx + 0.474)^2)
        S[bb, b] = 0                         if tmpy < -100
                   10^((tmpz + tmpy) / 10)   otherwise

    The distance is measured from the masker to the maskee. The gentle slope
    of ``tmpy`` (positive ``tmpx``) therefore points toward higher bands, so
    masking spreads further upward in frequency than downward.

    Parameters
    ----------
    bval : BandValueArray
        Bark value per band, shape (B,).

    Returns
    -------
    FloatArray
        Spreading matrix S of shape (B, B), where S[bb, b] is the contribution
        of band bb (masker) to the masking of band b (maskee). The diagonal is
        approximately 1 (``tmpx == 0`` gives ``tmpy`` of about 0 dB).
    """
    bval = np.asarray(bval, dtype=np.float64).reshape(-1)
    B = int(bval.shape[0])
    idx = np.arange(B)

    # Rows index the masker bb, columns the maskee b.
    masker_ge_maskee = idx[:, None] >= idx[None, :]
    # Bark distance from masker to maskee: bval[b] - bval[bb].
    dist = bval[None, :] - bval[:, None]
    # Asymmetric scaling: steep factor 3.0 below the masker, 1.5 above it.
    tmpx = np.where(masker_ge_maskee, 3.0, 1.5) * dist

    u = tmpx - 0.5
    # Non-positive dip, only active for 0.5 < tmpx < 2.5.
    tmpz = 8.0 * np.minimum(u * u - 2.0 * u, 0.0)

    v = tmpx + 0.474
    tmpy = 15.811389 + 7.5 * v - 17.5 * np.sqrt(1.0 + v * v)

    # Contributions below -100 dB are treated as exactly zero.
    spread = np.where(tmpy < -100.0, 0.0, 10.0 ** ((tmpz + tmpy) / 10.0))
    return spread.astype(np.float64, copy=False)
# -----------------------------------------------------------------------------
# Windowing + FFT feature extraction
# -----------------------------------------------------------------------------
def _hann_window(N: int) -> FloatArray:
    """
    Build the length-N Hann window applied before the analysis FFT.

    The window is sampled at half-sample offsets, as in the notes:

        w[n] = 0.5 - 0.5 * cos(2*pi*(n + 0.5) / N),   n = 0 .. N-1

    Parameters
    ----------
    N : int
        Number of window samples.

    Returns
    -------
    FloatArray
        Window coefficients, shape (N,), dtype float64.
    """
    omega = 2.0 * np.pi / N
    sample_centres = np.arange(N, dtype=np.float64) + 0.5
    return 0.5 - 0.5 * np.cos(omega * sample_centres)
def _r_phi_from_time(x: FrameChannelT, N: int) -> tuple[FloatArray, FloatArray]:
    """
    Return the FFT magnitude r(w) and phase phi(w) of one analysis window.

    The samples are multiplied by the Hann window, transformed with an N-point
    FFT, and only the first N/2 bins (w = 0 .. N/2-1) are kept.

    Parameters
    ----------
    x : FrameChannelT
        Time-domain samples; flattened internally and must hold N values.
    N : int
        FFT size (2048 for long windows, 256 for short windows).

    Returns
    -------
    r : FloatArray
        Magnitude spectrum, shape (N/2,).
    phi : FloatArray
        Phase spectrum in radians (as returned by ``np.angle``), shape (N/2,).

    Raises
    ------
    ValueError
        If ``x`` does not contain exactly N samples.
    """
    samples = np.asarray(x, dtype=np.float64).reshape(-1)
    n_samples = samples.shape[0]
    if n_samples != N:
        raise ValueError(f"Expected time vector of length {N}, got {n_samples}.")

    half_spectrum = np.fft.fft(samples * _hann_window(N), n=N)[: N // 2]
    magnitude = np.abs(half_spectrum).astype(np.float64, copy=False)
    phase = np.angle(half_spectrum).astype(np.float64, copy=False)
    return magnitude, phase
def _predictability(
    r: FloatArray,
    phi: FloatArray,
    r_m1: FloatArray,
    phi_m1: FloatArray,
    r_m2: FloatArray,
    phi_m2: FloatArray,
) -> FloatArray:
    """
    Measure how poorly each spectral bin is predicted from its two predecessors.

    A linear extrapolation of magnitude and phase from the two previous
    windows gives the predicted spectrum:

        r_pred(w)   = 2*r_{-1}(w)   - r_{-2}(w)
        phi_pred(w) = 2*phi_{-1}(w) - phi_{-2}(w)

    The unpredictability is the normalised Euclidean distance between the
    actual and predicted complex values, both written in polar form:

        c(w) = |X(w) - X_pred(w)| / (r(w) + |r_pred(w)|)

    Parameters
    ----------
    r, phi : FloatArray
        Current magnitude and phase, shape (N/2,).
    r_m1, phi_m1 : FloatArray
        Magnitude and phase of the previous window, shape (N/2,).
    r_m2, phi_m2 : FloatArray
        Magnitude and phase of the window before that, shape (N/2,).

    Returns
    -------
    FloatArray
        Per-bin predictability measure c(w), shape (N/2,).
    """
    mag_hat = 2.0 * r_m1 - r_m2
    ang_hat = 2.0 * phi_m1 - phi_m2

    # Real and imaginary parts of X(w) - X_pred(w).
    delta_re = r * np.cos(phi) - mag_hat * np.cos(ang_hat)
    delta_im = r * np.sin(phi) - mag_hat * np.sin(ang_hat)
    distance = np.sqrt(delta_re ** 2 + delta_im ** 2)

    # A tiny epsilon keeps 0/0 on silent bins finite (it yields 0). Its effect
    # on bins with real energy is negligible.
    normaliser = r + np.abs(mag_hat) + 1e-12
    return (distance / normaliser).astype(np.float64, copy=False)
# -----------------------------------------------------------------------------
# Band-domain aggregation
# -----------------------------------------------------------------------------
def _band_energy_and_pred(
    r: FloatArray,
    c: FloatArray,
    wlow: BandIndexArray,
    whigh: BandIndexArray,
) -> tuple[FloatArray, FloatArray]:
    """
    Collapse per-bin energy and predictability into per-band sums.

    For each band b with inclusive bin limits [wlow(b), whigh(b)]:

        e(b)     = sum_w r(w)^2
        c_num(b) = sum_w c(w) * r(w)^2

    ``c_num`` is the energy-weighted predictability numerator. The caller
    divides its spread version by the spread energy to get cb(b) = ct(b) / ecb(b).

    Parameters
    ----------
    r : FloatArray
        Magnitude spectrum, shape (N/2,).
    c : FloatArray
        Per-bin predictability, shape (N/2,).
    wlow, whigh : BandIndexArray
        Inclusive first/last bin index of every band, shape (B,).

    Returns
    -------
    e_b : FloatArray
        Band energies e(b), shape (B,).
    c_num_b : FloatArray
        Energy-weighted predictability numerators c_num(b), shape (B,).
    """
    power = (r * r).astype(np.float64, copy=False)

    # Inclusive [wlow, whigh] limits become half-open Python slices.
    band_slices = [
        slice(int(wlow[k]), int(whigh[k]) + 1)
        for k in range(int(wlow.shape[0]))
    ]

    e_b = np.array(
        [float(np.sum(power[s])) for s in band_slices],
        dtype=np.float64,
    )
    c_num_b = np.array(
        [float(np.sum(c[s] * power[s])) for s in band_slices],
        dtype=np.float64,
    )
    return e_b, c_num_b
def _psycho_window(
    time_x: FrameChannelT,
    prev1_x: FrameChannelT,
    prev2_x: FrameChannelT,
    *,
    N: int,
    table: BarkTable,
) -> FloatArray:
    """
    Run the psychoacoustic pipeline on a single FFT analysis window.

    Steps, in order:
      1. Hann-windowed FFT of the current window and its two predecessors.
      2. Per-bin predictability c(w).
      3. Per-band energy e(b) and weighted predictability c_num(b).
      4. Spreading across bands: ecb(b), ct(b), normalised energy en(b).
      5. Tonality index tb(b), clamped to [0, 1].
      6. Required SNR (interpolated between TMN_DB and NMT_DB), then the noise
         threshold nb(b).
      7. Threshold in quiet; final threshold np(b) = max(nb(b), qthr(b)).
      8. SMR(b) = e(b) / np(b), as a linear power ratio.

    Parameters
    ----------
    time_x : FrameChannelT
        Samples of the window being analysed, length N.
    prev1_x : FrameChannelT
        Samples of the previous window, length N.
    prev2_x : FrameChannelT
        Samples of the window before ``prev1_x``, length N.
    N : int
        FFT size (2048 or 256).
    table : BarkTable
        Band table passed to ``band_limits`` (B219a or B219b).

    Returns
    -------
    FloatArray
        Signal-to-mask ratio per band, shape (B,).
    """
    wlow, whigh, bval, qthr_db = band_limits(table)
    spread = _spreading_matrix(bval)

    # Magnitude/phase for the current window and the two history windows.
    (r, phi), (r_m1, phi_m1), (r_m2, phi_m2) = [
        _r_phi_from_time(signal, N) for signal in (time_x, prev1_x, prev2_x)
    ]

    c_w = _predictability(r, phi, r_m1, phi_m1, r_m2, phi_m2)
    e_b, c_num_b = _band_energy_and_pred(r, c_w, wlow, whigh)

    # ecb(b) = sum_bb e(bb) S(bb, b);  ct(b) = sum_bb c_num(bb) S(bb, b)
    spread_t = spread.T
    ecb = spread_t @ e_b
    ct = spread_t @ c_num_b

    cb = ct / (ecb + 1e-12)
    # Normalise by the total spreading weight reaching each band.
    en = ecb / (spread.sum(axis=0) + 1e-12)

    tb = np.clip(-0.299 - 0.43 * np.log(np.maximum(cb, 1e-12)), 0.0, 1.0)

    # Tonal bands need TMN_DB of SNR, noise-like bands NMT_DB.
    snr_db = tb * TMN_DB + (1.0 - tb) * NMT_DB
    nb = en * 10.0 ** (-snr_db / 10.0)

    # Threshold in quiet, mapped from dB to the FFT power scale.
    qthr_power = np.finfo('float').eps * (N / 2.0) * (10.0 ** (qthr_db / 10.0))
    npart = np.maximum(nb, qthr_power)

    smr = e_b / (npart + 1e-12)
    return smr.astype(np.float64, copy=False)
# -----------------------------------------------------------------------------
# ESH window slicing (match filterbank conventions)
# -----------------------------------------------------------------------------
def _esh_subframes(x_2048: FrameChannelT) -> list[FrameChannelT]:
    """
    Slice the 8 half-overlapping 256-sample short windows of an ESH frame.

    Window j (j = 0..7) starts at sample 448 + 128*j and spans 256 samples.
    Together the windows cover samples [448, 1600) of the 2048-sample frame,
    matching the filterbank's short-window layout.

    Parameters
    ----------
    x_2048 : FrameChannelT
        Time-domain channel frame; flattened internally and must hold 2048
        samples.

    Returns
    -------
    list[FrameChannelT]
        Eight slices of the (float64) frame, each of shape (256,).

    Raises
    ------
    ValueError
        If the frame does not contain exactly 2048 samples.
    """
    x_2048 = np.asarray(x_2048, dtype=np.float64).reshape(-1)
    if x_2048.shape[0] != 2048:
        raise ValueError("ESH requires 2048-sample input frames.")

    offset, hop, width = 448, 128, 256
    return [
        x_2048[offset + hop * j : offset + hop * j + width]
        for j in range(8)
    ]
# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------
def aac_psycho(
    frame_T: FrameChannelT,
    frame_type: FrameType,
    frame_T_prev_1: FrameChannelT,
    frame_T_prev_2: FrameChannelT,
) -> FloatArray:
    """
    Psychoacoustic model for ONE channel.

    Parameters
    ----------
    frame_T : FrameChannelT
        Current time-domain channel frame, 2048 samples.
        For "ESH", the 8 short windows are derived internally.
    frame_type : FrameType
        AAC frame type ("OLS", "LSS", "ESH", "LPS").
    frame_T_prev_1 : FrameChannelT
        Previous time-domain channel frame, 2048 samples.
    frame_T_prev_2 : FrameChannelT
        Pre-previous time-domain channel frame, 2048 samples. It is validated
        but only used for long frame types; ESH history comes from
        ``frame_T_prev_1`` and the current frame.

    Returns
    -------
    FloatArray
        Signal-to-Mask Ratio (SMR) per psychoacoustic band, as a linear ratio.
        - If frame_type == "ESH": shape (B, 8), one column per short window
          (B = number of rows in the short table, expected 42).
        - Otherwise: shape (B,) for the long table (expected 69).

    Raises
    ------
    ValueError
        If any of the three input frames does not contain 2048 samples.
    """
    frame_T = np.asarray(frame_T, dtype=np.float64).reshape(-1)
    frame_T_prev_1 = np.asarray(frame_T_prev_1, dtype=np.float64).reshape(-1)
    frame_T_prev_2 = np.asarray(frame_T_prev_2, dtype=np.float64).reshape(-1)
    if any(f.shape[0] != 2048 for f in (frame_T, frame_T_prev_1, frame_T_prev_2)):
        raise ValueError("aac_psycho expects 2048-sample frames for current/prev1/prev2.")

    # The table and FFT size come from one place for both long and short paths.
    table, N = get_table(frame_type)

    # Long frame types: a single SMR vector from one 2048-sample window.
    if frame_type != "ESH":
        return _psycho_window(frame_T, frame_T_prev_1, frame_T_prev_2, N=N, table=table)

    # ESH: one SMR vector per short window. The two-window predictability
    # history runs across the frame boundary, so the windows are chained as
    #   [prev[6], prev[7], cur[0], ..., cur[7]]
    # and window j uses entries (j+2, j+1, j) as (current, prev1, prev2):
    #   j=0 -> history prev[7], prev[6]
    #   j=1 -> history cur[0],  prev[7]
    #   j>=2 -> history cur[j-1], cur[j-2]
    # N must match the 256-sample short windows. Otherwise _r_phi_from_time
    # raises ValueError.
    prev1_subs = _esh_subframes(frame_T_prev_1)
    chain = [prev1_subs[6], prev1_subs[7], *_esh_subframes(frame_T)]

    B = int(table.shape[0])  # expected 42 for the short table
    smr_out = np.zeros((B, 8), dtype=np.float64)
    for j in range(8):
        smr_out[:, j] = _psycho_window(
            chain[j + 2], chain[j + 1], chain[j], N=N, table=table
        )
    return smr_out