307 lines
8.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - AAC Utilities
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# Shared utility functions used across AAC encoder/decoder levels.
#
# This module currently provides:
# - MDCT / IMDCT conversions
# - Signal-to-Noise Ratio (SNR) computation in dB
# - Lag estimation and least-squares gain matching between signals
# - Loading and access helpers for psychoacoustic band tables
# (TableB219.mat, Tables B.2.1.9a / B.2.1.9b of the AAC specification)
# ------------------------------------------------------------
from __future__ import annotations

from functools import lru_cache
from pathlib import Path

import numpy as np
from scipy.io import loadmat

from core.aac_types import *
# -----------------------------------------------------------------------------
# Global cached data
# -----------------------------------------------------------------------------
# Cached contents of TableB219.mat to avoid repeated disk I/O.
# It stays None until the first call to load_b219_tables(), which fills it
# and returns the same dict on every later call.
# Keys:
# - "B219a": long-window psychoacoustic bands (69 bands, FFT size 2048)
# - "B219b": short-window psychoacoustic bands (42 bands, FFT size 256)
B219_CACHE: dict[str, BarkTable] | None = None
# -----------------------------------------------------------------------------
# MDCT / IMDCT
# -----------------------------------------------------------------------------
@lru_cache(maxsize=None)
def _mdct_basis(N: int) -> np.ndarray:
    """
    Return the shared MDCT / IMDCT cosine kernel for frame length N.

    C[n, k] = cos((2*pi/N) * (n + n0) * (k + 1/2)),  n0 = (N/2 + 1)/2,
    for n = 0..N-1 and k = 0..N/2-1, i.e. a (N, N/2) float64 matrix.

    The kernel depends only on N, so it is built once per N and reused.
    The callers validate N before calling, so at most two entries
    (N = 256 and N = 2048) are ever cached. The array is marked read-only
    because the same object is shared by every call.
    """
    n0 = (N / 2.0 + 1.0) / 2.0
    n = np.arange(N, dtype=np.float64) + n0
    k = np.arange(N // 2, dtype=np.float64) + 0.5
    C = np.cos((2.0 * np.pi / N) * np.outer(n, k))  # (N, N/2)
    C.flags.writeable = False
    return C


def mdct(s: TimeSignal) -> MdctCoeffs:
    """
    MDCT (direct form) as specified in the assignment.

    Parameters
    ----------
    s : TimeSignal
        Windowed time samples. The input is flattened to a 1-D array of
        length N (N = 2048 or 256).

    Returns
    -------
    MdctCoeffs
        MDCT coefficients, 1-D float64 array of length N/2.

    Raises
    ------
    ValueError
        If N is not 2048 or 256.

    Definition
    ----------
    X[k] = 2 * sum_{n=0..N-1} s[n] * cos((2*pi/N) * (n + n0) * (k + 1/2)),
    where n0 = (N/2 + 1)/2.
    """
    s = np.asarray(s, dtype=np.float64).reshape(-1)
    N = int(s.shape[0])
    if N not in (2048, 256):
        raise ValueError("MDCT input length must be 2048 or 256.")
    # The (N, N/2) kernel comes from a per-N cache instead of being rebuilt
    # (N * N/2 cosine evaluations) for every frame.
    return 2.0 * (s @ _mdct_basis(N))  # (N/2,)


def imdct(X: MdctCoeffs) -> TimeSignal:
    """
    IMDCT (direct form) as specified in the assignment.

    Parameters
    ----------
    X : MdctCoeffs
        MDCT coefficients. The input is flattened to a 1-D array of
        length K (K = 1024 or 128).

    Returns
    -------
    TimeSignal
        Reconstructed time samples, 1-D float64 array of length N = 2K.

    Raises
    ------
    ValueError
        If K is not 1024 or 128.

    Definition
    ----------
    s[n] = (2/N) * sum_{k=0..N/2-1} X[k] * cos((2*pi/N) * (n + n0) * (k + 1/2)),
    where n0 = (N/2 + 1)/2.
    """
    X = np.asarray(X, dtype=np.float64).reshape(-1)
    K = int(X.shape[0])
    if K not in (1024, 128):
        raise ValueError("IMDCT input length must be 1024 or 128.")
    N = 2 * K
    # Same (N, K) kernel as the forward transform, shared through the cache.
    return (2.0 / N) * (_mdct_basis(N) @ X)  # (N,)
# -----------------------------------------------------------------------------
# Signal quality metrics
# -----------------------------------------------------------------------------
def snr_db(x_ref: StereoSignal, x_hat: StereoSignal) -> float:
    """
    Overall Signal-to-Noise Ratio (SNR), in dB, between two signals.

    Both inputs are viewed as (samples, channels) arrays (1-D inputs become
    a single column) and cropped to their common number of samples and
    channels before comparing. The powers are summed over every sample of
    every remaining channel.

    Parameters
    ----------
    x_ref : StereoSignal
        Reference (original) signal, typically of shape (N, 2).
    x_hat : StereoSignal
        Reconstructed / processed signal, typically of shape (M, 2).

    Returns
    -------
    float
        10 * log10(signal_power / noise_power), with the special cases:
        - +inf when the noise power is zero (perfect reconstruction);
        - -inf when the reference power is zero (and the noise is not).
    """

    def _as_columns(sig) -> np.ndarray:
        # Promote mono 1-D input to a single-column 2-D array.
        arr = np.asarray(sig, dtype=np.float64)
        return arr.reshape(-1, 1) if arr.ndim == 1 else arr

    ref = _as_columns(x_ref)
    est = _as_columns(x_hat)

    rows = min(ref.shape[0], est.shape[0])
    cols = min(ref.shape[1], est.shape[1])
    ref = ref[:rows, :cols]
    est = est[:rows, :cols]

    noise_power = float(np.sum(np.square(ref - est)))
    signal_power = float(np.sum(np.square(ref)))

    if noise_power <= 0.0:
        return float("inf")
    if signal_power <= 0.0:
        return float("-inf")
    return float(10.0 * np.log10(signal_power / noise_power))
def estimate_lag_mono(x_ref: TimeSignal, x_hat: TimeSignal, max_lag=4096):
    """
    Estimate the integer time lag between two mono signals.

    Both signals are truncated to their common length n, and the lag k in
    [-max_lag, max_lag] that maximizes the cross-correlation

        c[k] = sum_m x_ref[m + k] * x_hat[m]

    is returned. Ties resolve to the smallest (most negative) lag.

    Sign convention: a positive lag means x_hat is *advanced* relative to
    x_ref (x_hat[m] ~ x_ref[m + k]). If x_hat is a copy of x_ref delayed
    by D samples, the result is -D.
    NOTE: an earlier docstring stated the opposite sign. The numerical
    result is unchanged; only the description was corrected.

    Parameters
    ----------
    x_ref : TimeSignal
        Reference mono signal (1-D).
    x_hat : TimeSignal
        Mono signal compared against the reference (1-D).
    max_lag : int, optional
        Largest absolute lag, in samples, that is searched. Negative
        values are treated as 0. Default 4096.

    Returns
    -------
    int
        Estimated lag in samples. Returns 0 when the common length is zero.

    Raises
    ------
    ValueError
        If either input is not 1-D.
    """
    # Local import: only this helper needs scipy.signal.
    from scipy.signal import correlate

    ref = np.asarray(x_ref, dtype=np.float64)
    est = np.asarray(x_hat, dtype=np.float64)
    if ref.ndim != 1 or est.ndim != 1:
        raise ValueError("estimate_lag_mono expects 1-D (mono) signals.")

    n = min(ref.shape[0], est.shape[0])
    if n == 0:
        return 0  # nothing to correlate
    ref = ref[:n]
    est = est[:n]
    max_lag = max(0, int(max_lag))

    # Same 'full' layout as np.correlate (index i <-> lag i - (n - 1)), but
    # method="auto" switches to FFT for long inputs instead of the O(n^2)
    # direct sum, most of which would be discarded below anyway.
    corr = correlate(ref, est, mode="full", method="auto")
    center = n - 1  # index of lag 0
    lo = max(0, center - max_lag)
    hi = min(corr.shape[0], center + max_lag + 1)
    best = lo + int(np.argmax(corr[lo:hi]))
    return best - center
def match_gain(x_ref: StereoSignal, x_hat: StereoSignal) -> float:
    """
    Least-squares gain g that best maps x_hat -> x_ref.

    The inputs are aligned exactly as in snr_db: 1-D inputs become a single
    column, and both arrays are cropped to their common number of samples
    and channels. With r and h the flattened aligned signals, the gain
    minimizing ||r - g * h||^2 is

        g = <r, h> / <h, h>.

    Parameters
    ----------
    x_ref : StereoSignal
        Reference signal, typically (N, 2); 1-D mono input is accepted.
    x_hat : StereoSignal
        Signal to be scaled, typically (M, 2); 1-D mono input is accepted.

    Returns
    -------
    float
        The least-squares gain, or 1.0 when x_hat has zero energy over the
        aligned region (the gain is undefined there).
    """
    ref = np.asarray(x_ref, dtype=np.float64)
    est = np.asarray(x_hat, dtype=np.float64)
    # Accept mono 1-D signals (and lists) like snr_db does.
    if ref.ndim == 1:
        ref = ref.reshape(-1, 1)
    if est.ndim == 1:
        est = est.reshape(-1, 1)

    n = min(ref.shape[0], est.shape[0])
    c = min(ref.shape[1], est.shape[1])
    r = ref[:n, :c].reshape(-1)
    h = est[:n, :c].reshape(-1)

    denom = float(np.dot(h, h))
    if denom <= 0.0:
        return 1.0
    return float(np.dot(r, h) / denom)
# -----------------------------------------------------------------------------
# Psychoacoustic band tables (TableB219.mat)
# -----------------------------------------------------------------------------
def load_b219_tables() -> dict[str, BarkTable]:
    """
    Return the psychoacoustic band tables from TableB219.mat, loading once.

    The file is looked up at ``material/TableB219.mat`` relative to the
    current working directory, which is where the project expects a
    'material' directory when running:
    - tests
    - level_1 / level_2 / level_3 entrypoints
    The first successful call reads the file and stores the result in the
    module-level cache; later calls return that same cached dict.

    Returns
    -------
    dict[str, BarkTable]
        Dictionary with the following entries (float64 arrays):
        - "B219a": long-window psychoacoustic table
          (69 bands, FFT size 2048 / 1024 spectral lines)
        - "B219b": short-window psychoacoustic table
          (42 bands, FFT size 256 / 128 spectral lines)

    Raises
    ------
    FileNotFoundError
        If material/TableB219.mat does not exist.
    ValueError
        If the file lacks the 'B219a' or 'B219b' variable.
    """
    global B219_CACHE

    if B219_CACHE is None:
        table_file = Path("material", "TableB219.mat")
        if not table_file.exists():
            raise FileNotFoundError(
                "Could not locate material/TableB219.mat in the current working directory."
            )

        contents = loadmat(str(table_file))
        required = ("B219a", "B219b")
        if not all(name in contents for name in required):
            raise ValueError(
                "TableB219.mat missing required variables 'B219a' and/or 'B219b'."
            )

        # Only assign the cache once everything validated, so a failed load
        # is retried on the next call.
        B219_CACHE = {
            name: np.asarray(contents[name], dtype=np.float64) for name in required
        }

    return B219_CACHE
def get_table(frame_type: FrameType) -> tuple[BarkTable, int]:
    """
    Pick the psychoacoustic band table and FFT size for an AAC frame type.

    Only "ESH" (eight short sequence) selects the short-window table; every
    other frame type ("OLS", "LSS", "LPS") uses the long-window table.

    Parameters
    ----------
    frame_type : FrameType
        AAC frame type ("OLS", "LSS", "ESH", "LPS").

    Returns
    -------
    table : BarkTable
        B219b for ESH short subframes, B219a otherwise.
    N : int
        FFT size matching the table: 256 for ESH, 2048 otherwise.
    """
    # Load (or fetch from cache) before inspecting the frame type, so a
    # missing table file is reported for every frame type.
    tables = load_b219_tables()
    is_short = frame_type == "ESH"
    name, fft_size = ("B219b", 256) if is_short else ("B219a", 2048)
    return tables[name], fft_size
def band_limits(
    table: BarkTable,
) -> tuple[BandIndexArray, BandIndexArray, BandValueArray, BandValueArray]:
    """
    Split a TableB2.1.9 psychoacoustic table into per-band column arrays.

    Columns are read by position, following the layout of the provided
    TableB219.mat file and AAC specification tables B.2.1.9a / B.2.1.9b.
    Every returned array is a fresh copy, so modifying it does not alter
    ``table``.

    Parameters
    ----------
    table : BarkTable
        Psychoacoustic band table (B219a or B219b), one row per band.

    Returns
    -------
    wlow : BandIndexArray
        Lower FFT bin index (inclusive) per band, from column 1, as int.
    whigh : BandIndexArray
        Upper FFT bin index (inclusive) per band, from column 2, as int.
    bval : BandValueArray
        Bark-scale (or equivalent) band position values, from column 4;
        used in the spreading function.
    qthr_db : BandValueArray
        Threshold in quiet per band in dB, from column 5.
    """
    col_wlow, col_whigh, col_bval, col_qthr = 1, 2, 4, 5

    lower_bins = table[:, col_wlow].astype(int)
    upper_bins = table[:, col_whigh].astype(int)
    bark_positions = table[:, col_bval].astype(np.float64)
    quiet_threshold_db = table[:, col_qthr].astype(np.float64)

    return lower_bins, upper_bins, bark_positions, quiet_threshold_db