605 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - Quantizer / iQuantizer (Level 3)
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# Implements AAC quantizer and inverse quantizer for one channel.
# Based on assignment section 2.6 (Eq. 12-15).
#
# Notes:
# - Bit reservoir is not implemented (assignment simplification).
# - Scalefactor bands are assumed equal to psychoacoustic bands
# (Table B.2.1.9a / B.2.1.9b from TableB219.mat).
# ------------------------------------------------------------
from __future__ import annotations
import numpy as np
from core.aac_utils import get_table, band_limits
from core.aac_types import *
# -----------------------------------------------------------------------------
# Constants (assignment)
# -----------------------------------------------------------------------------
# Rounding offset added before truncation in the quantizer (Eq. 12).
MAGIC_NUMBER: float = 0.4054
# Divisor inside the log2 of the initial-alpha estimate (Eq. 14).
MQ: int = 8191
# SMR values at or below this floor produce a zero band threshold.
EPS: float = 1e-12
# Largest allowed |alpha(b) - alpha(b-1)| between adjacent bands.
MAX_SFC_DIFF: int = 60
# Safeguard: prevents infinite loops in pathological cases
# (upper bound on the number of alpha increments tried per band).
MAX_ALPHA_ITERS: int = 2000
# -----------------------------------------------------------------------------
# Helpers: ESH packing/unpacking (128x8 <-> 1024x1)
# -----------------------------------------------------------------------------
def _esh_pack_to_1024(x_128x8: FloatArray) -> FloatArray:
"""
Pack ESH coefficients (128 x 8) into a single long vector (1024 x 1).
Packing order:
Columns are concatenated in subframe order (0..7), column-major.
Parameters
----------
x_128x8 : FloatArray
ESH coefficients, shape (128, 8).
Returns
-------
FloatArray
Packed coefficients, shape (1024, 1).
"""
x_128x8 = np.asarray(x_128x8, dtype=np.float64)
if x_128x8.shape != (128, 8):
raise ValueError("ESH pack expects shape (128, 8).")
return x_128x8.reshape(1024, 1, order="F")
def _esh_unpack_from_1024(x_1024x1: FloatArray) -> FloatArray:
"""
Unpack a packed ESH vector (1024 elements) back to shape (128, 8).
Parameters
----------
x_1024x1 : FloatArray
Packed ESH vector, shape (1024,) or (1024, 1) after flattening.
Returns
-------
FloatArray
Unpacked ESH coefficients, shape (128, 8).
"""
x_1024x1 = np.asarray(x_1024x1, dtype=np.float64).reshape(-1)
if x_1024x1.shape[0] != 1024:
raise ValueError("ESH unpack expects 1024 elements.")
return x_1024x1.reshape(128, 8, order="F")
# -----------------------------------------------------------------------------
# Core quantizer formulas (Eq. 12, Eq. 13)
# -----------------------------------------------------------------------------
def _quantize_symbol(x: FloatArray, alpha: float) -> QuantizedSymbols:
    """
    Map MDCT coefficients to signed integer symbols S(k).

    Implements Eq. (12):
        S(k) = sgn(X(k)) * int( (|X(k)| * 2^(-alpha/4))^(3/4) + MAGIC_NUMBER )

    Parameters
    ----------
    x : FloatArray
        MDCT coefficients of a contiguous run of spectral lines.
    alpha : float
        Scalefactor gain of the scalefactor band the lines belong to.

    Returns
    -------
    QuantizedSymbols
        Quantized symbols S(k) as int64, same shape as ``x``.
    """
    coeffs = np.asarray(x, dtype=np.float64)
    gain = 2.0 ** (-0.25 * float(alpha))
    compressed = np.power(np.abs(coeffs) * gain, 0.75, dtype=np.float64)
    # The operand is non-negative, so floor() is the truncation that the
    # assignment's "int" denotes.
    magnitudes = np.floor(compressed + MAGIC_NUMBER).astype(np.int64)
    signs = np.sign(coeffs).astype(np.int64)
    return (signs * magnitudes).astype(np.int64)
def _dequantize_symbol(S: QuantizedSymbols, alpha: float) -> FloatArray:
"""
Inverse quantizer (dequantization of symbols).
Implements Eq. (13):
Xhat(k) = sgn(S(k)) * |S(k)|^(4/3) * 2^(alpha/4)
Parameters
----------
S : QuantizedSymbols
Quantized symbols S(k), int64, shape (N,).
alpha : float
Scalefactor gain for the corresponding scalefactor band.
Returns
-------
FloatArray
Reconstructed MDCT coefficients Xhat(k), float64, shape (N,).
"""
S = np.asarray(S, dtype=np.int64)
scale = 2.0 ** (0.25 * float(alpha))
aS = np.abs(S).astype(np.float64)
y = np.power(aS, 4.0 / 3.0, dtype=np.float64)
return (np.sign(S).astype(np.float64) * y * scale).astype(np.float64)
# -----------------------------------------------------------------------------
# Alpha initialization (Eq. 14)
# -----------------------------------------------------------------------------
def _alpha_initial_from_frame_max(x_all: FloatArray) -> int:
"""
Compute the initial scalefactor gain alpha_hat for a frame.
Implements Eq. (14):
alpha_hat = (16/3) * log2( max_k(|X(k)|^(3/4)) / MQ )
The same initial value is used for all bands before the per-band refinement.
Parameters
----------
x_all : FloatArray
All MDCT coefficients of a frame (one channel), flattened.
Returns
-------
int
Initial alpha value (integer).
"""
x_all = np.asarray(x_all, dtype=np.float64).reshape(-1)
if x_all.size == 0:
return 0
max_abs = float(np.max(np.abs(x_all)))
if max_abs <= 0.0:
return 0
val = (max_abs ** 0.75) / float(MQ)
if val <= 0.0:
return 0
alpha_hat = (16.0 / 3.0) * np.log2(val)
return int(np.floor(alpha_hat))
# -----------------------------------------------------------------------------
# Band utilities
# -----------------------------------------------------------------------------
def _band_slices(frame_type: FrameType) -> list[tuple[int, int]]:
    """
    List the scalefactor band index ranges for a frame type.

    The ranges come from the band table returned by ``get_table`` and are
    used directly as MDCT indices by the rest of this module.

    Parameters
    ----------
    frame_type : FrameType
        Frame type ("OLS", "LSS", "ESH", "LPS").

    Returns
    -------
    list[tuple[int, int]]
        One (lo, hi) pair per band; both ends are inclusive.
    """
    table, _ = get_table(frame_type)
    low_edges, high_edges, _, _ = band_limits(table)
    return [(int(first), int(last)) for first, last in zip(low_edges, high_edges)]
def _band_energy(x: FloatArray, lo: int, hi: int) -> float:
"""
Compute energy of a spectral segment x[lo:hi+1].
Parameters
----------
x : FloatArray
MDCT coefficient vector.
lo, hi : int
Inclusive index range.
Returns
-------
float
Sum of squares (energy) within the band.
"""
sec = x[lo : hi + 1]
return float(np.sum(sec * sec))
def _threshold_T_from_SMR(
X: FloatArray,
SMR_col: FloatArray,
bands: list[tuple[int, int]],
) -> FloatArray:
"""
Compute psychoacoustic thresholds T(b) per band.
Uses:
P(b) = sum_{k in band} X(k)^2
T(b) = P(b) / SMR(b)
Parameters
----------
X : FloatArray
MDCT coefficients for a frame (long) or one ESH subframe (short).
SMR_col : FloatArray
SMR values for this frame/subframe, shape (NB,).
bands : list[tuple[int, int]]
Band index ranges.
Returns
-------
FloatArray
Threshold vector T(b), shape (NB,).
"""
nb = len(bands)
T = np.zeros((nb,), dtype=np.float64)
for b, (lo, hi) in enumerate(bands):
P = _band_energy(X, lo, hi)
smr = float(SMR_col[b])
if smr <= EPS:
T[b] = 0.0
else:
T[b] = P / smr
return T
# -----------------------------------------------------------------------------
# Alpha selection per band + neighbor-difference constraint
# -----------------------------------------------------------------------------
def _best_alpha_for_band(
X: FloatArray,
lo: int,
hi: int,
T_b: float,
alpha0: int,
alpha_min: int,
alpha_max: int,
) -> int:
"""
Determine the band-wise scalefactor alpha(b) by iteratively increasing alpha.
The algorithm increases alpha until the quantization error power exceeds
the band threshold:
P_e(b) = sum_{k in band} (X(k) - Xhat(k))^2
select the largest alpha such that P_e(b) <= T(b)
Parameters
----------
X : FloatArray
Full MDCT vector (one frame or one subframe), shape (N,).
lo, hi : int
Inclusive MDCT index range defining the band.
T_b : float
Psychoacoustic threshold for this band.
alpha0 : int
Initial alpha value for the band.
alpha_min, alpha_max : int
Bounds for alpha, used as a safeguard.
Returns
-------
int
Selected integer alpha(b).
"""
if T_b <= 0.0:
return int(alpha0)
Xsec = X[lo : hi + 1]
alpha = int(alpha0)
alpha = max(alpha_min, min(alpha, alpha_max))
# Evaluate at current alpha
Ssec = _quantize_symbol(Xsec, alpha)
Xhat = _dequantize_symbol(Ssec, alpha)
Pe = float(np.sum((Xsec - Xhat) ** 2))
if Pe > T_b:
return alpha
iters = 0
while iters < MAX_ALPHA_ITERS:
iters += 1
alpha_next = alpha + 1
if alpha_next > alpha_max:
break
Ssec = _quantize_symbol(Xsec, alpha_next)
Xhat = _dequantize_symbol(Ssec, alpha_next)
Pe_next = float(np.sum((Xsec - Xhat) ** 2))
if Pe_next > T_b:
break
alpha = alpha_next
Pe = Pe_next
return alpha
def _enforce_max_diff(alpha: np.ndarray, max_diff: int = MAX_SFC_DIFF) -> np.ndarray:
    """
    Clamp alpha so that adjacent bands differ by at most ``max_diff``.

    A forward sweep clamps each entry against its left neighbour; a
    backward sweep then clamps each entry against its right neighbour.
    The input is not modified.

    Parameters
    ----------
    alpha : np.ndarray
        Alpha vector, shape (NB,).
    max_diff : int
        Maximum allowed absolute difference between adjacent bands.

    Returns
    -------
    np.ndarray
        Clamped copy of the alpha vector, int64, shape (NB,).
    """
    out = np.array(alpha, dtype=np.int64)
    count = out.shape[0]

    # (index to clamp, index of the neighbour it is clamped against)
    forward = [(b, b - 1) for b in range(1, count)]
    backward = [(b, b + 1) for b in range(count - 2, -1, -1)]

    for cur, ref in forward + backward:
        floor_val = int(out[ref] - max_diff)
        ceil_val = int(out[ref] + max_diff)
        if out[cur] < floor_val:
            out[cur] = floor_val
        elif out[cur] > ceil_val:
            out[cur] = ceil_val
    return out
# -----------------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------------
def _quantize_spectrum(
    x_vec: FloatArray,
    smr_vec: FloatArray,
    bands: list[tuple[int, int]],
    alpha_min: int = -4096,
    alpha_max: int = 4096,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Quantize one spectrum (a long frame or a single ESH subframe).

    Parameters
    ----------
    x_vec : FloatArray
        MDCT coefficients, shape (N,).
    smr_vec : FloatArray
        SMR per band, shape (NB,).
    bands : list[tuple[int, int]]
        Inclusive (lo, hi) index range of each band.
    alpha_min, alpha_max : int
        Search bounds handed to the per-band alpha selection.

    Returns
    -------
    symbols : np.ndarray
        Quantized symbols, int64, shape (N,). Lines outside every band stay 0.
    alpha : np.ndarray
        Final per-band scalefactor gains, int64, shape (NB,).
    """
    thresholds = _threshold_T_from_SMR(x_vec, smr_vec, bands)
    alpha0 = _alpha_initial_from_frame_max(x_vec)

    alpha = np.full((len(bands),), alpha0, dtype=np.int64)
    for b, (lo, hi) in enumerate(bands):
        alpha[b] = _best_alpha_for_band(
            X=x_vec, lo=lo, hi=hi, T_b=float(thresholds[b]),
            alpha0=alpha0,
            alpha_min=alpha_min,
            alpha_max=alpha_max,
        )
    alpha = _enforce_max_diff(alpha, MAX_SFC_DIFF)

    # Quantize with the final (constraint-satisfying) gains.
    symbols = np.zeros((x_vec.shape[0],), dtype=np.int64)
    for b, (lo, hi) in enumerate(bands):
        symbols[lo : hi + 1] = _quantize_symbol(x_vec[lo : hi + 1], float(alpha[b]))
    return symbols, alpha


def _dpcm_encode_alpha(alpha: np.ndarray) -> np.ndarray:
    """Return sfc with sfc[0] = alpha[0] and sfc[b] = alpha[b] - alpha[b-1]."""
    sfc = np.empty_like(alpha)
    sfc[0] = alpha[0]
    sfc[1:] = alpha[1:] - alpha[:-1]
    return sfc


def aac_quantizer(
    frame_F: FrameChannelF,
    frame_type: FrameType,
    SMR: FloatArray,
) -> tuple[QuantizedSymbols, ScaleFactors, GlobalGain]:
    """
    AAC quantizer for one channel (Level 3).

    Quantizes MDCT coefficients using band-wise scalefactors derived from SMR.

    Parameters
    ----------
    frame_F : FrameChannelF
        MDCT coefficients after TNS, one channel.
        Shapes:
        - Long frames: (1024,) or (1024, 1)
        - ESH: (128, 8)
    frame_type : FrameType
        AAC frame type ("OLS", "LSS", "ESH", "LPS").
    SMR : FloatArray
        Signal-to-Mask Ratio per band.
        Shapes:
        - Long: (NB,) or (NB, 1)
        - ESH: (NB, 8)

    Returns
    -------
    S : QuantizedSymbols
        Quantized symbols S(k), int64, shape (1024, 1) for all frame types.
        For ESH, subframe j occupies rows 128*j .. 128*j + 127.
    sfc : ScaleFactors
        DPCM-coded scalefactors, int64: sfc(0) = alpha(0) and
        sfc(b) = alpha(b) - alpha(b-1) for b >= 1.
        Shapes:
        - Long: (NB, 1)
        - ESH: (NB, 8)
    G : GlobalGain
        Global gain G = alpha(0).
        - Long: scalar float
        - ESH: float64 array of shape (1, 8)

    Raises
    ------
    ValueError
        If ``frame_F`` or ``SMR`` does not have one of the shapes above.
    """
    bands = _band_slices(frame_type)
    NB = len(bands)
    X = np.asarray(frame_F, dtype=np.float64)
    SMR = np.asarray(SMR, dtype=np.float64)

    if frame_type == "ESH":
        if X.shape != (128, 8):
            raise ValueError("For ESH, frame_F must have shape (128, 8).")
        if SMR.shape != (NB, 8):
            raise ValueError(f"For ESH, SMR must have shape ({NB}, 8).")

        S_out: QuantizedSymbols = np.zeros((1024, 1), dtype=np.int64)
        sfc: ScaleFactors = np.zeros((NB, 8), dtype=np.int64)
        G_arr = np.zeros((1, 8), dtype=np.int64)
        for j in range(8):
            Sj, alpha = _quantize_spectrum(X[:, j], SMR[:, j], bands)
            sfc[:, j] = _dpcm_encode_alpha(alpha)
            G_arr[0, j] = int(alpha[0])
            # Write straight into the packed (column-major) position of
            # subframe j. An explicit slice never depends on reshape()
            # handing back a view instead of a copy.
            S_out[128 * j : 128 * (j + 1), 0] = Sj
        return S_out, sfc, G_arr.astype(np.float64)

    # Long frames
    if X.shape == (1024,):
        Xv = X
    elif X.shape == (1024, 1):
        Xv = X[:, 0]
    else:
        raise ValueError("For non-ESH, frame_F must have shape (1024,) or (1024, 1).")
    if SMR.shape == (NB,):
        SMRv = SMR
    elif SMR.shape == (NB, 1):
        SMRv = SMR[:, 0]
    else:
        raise ValueError(f"For non-ESH, SMR must have shape ({NB},) or ({NB}, 1).")

    S_vec, alpha = _quantize_spectrum(Xv, SMRv, bands)
    sfc = _dpcm_encode_alpha(alpha).reshape(NB, 1)
    G: float = float(alpha[0])
    return S_vec.reshape(1024, 1), sfc, G
def aac_i_quantizer(
    S: QuantizedSymbols,
    sfc: ScaleFactors,
    G: GlobalGain,
    frame_type: FrameType,
) -> FrameChannelF:
    """
    Inverse quantizer (iQuantizer) for one channel.

    Decodes the DPCM scalefactors back to per-band gains and dequantizes
    the symbols band by band.

    Parameters
    ----------
    S : QuantizedSymbols
        Quantized symbols; any array holding exactly 1024 elements,
        typically shape (1024, 1).
    sfc : ScaleFactors
        DPCM-coded scalefactors, with sfc(0) carrying alpha(0).
        Shapes:
        - Long: (NB, 1)
        - ESH: (NB, 8)
    G : GlobalGain
        Global gain. Not read by this function (alpha(0) is taken from
        sfc(0)); kept for API compatibility with the assignment.
    frame_type : FrameType
        AAC frame type.

    Returns
    -------
    FrameChannelF
        Reconstructed MDCT coefficients, float64:
        - ESH: (128, 8)
        - Long: (1024, 1)

    Raises
    ------
    ValueError
        If ``S`` does not hold 1024 symbols or ``sfc`` has the wrong shape.
    """
    bands = _band_slices(frame_type)
    NB = len(bands)

    symbols = np.asarray(S, dtype=np.int64).reshape(-1)
    if symbols.shape[0] != 1024:
        raise ValueError("S must contain 1024 symbols.")

    dpcm = np.asarray(sfc, dtype=np.int64)

    def rebuild(column_symbols: np.ndarray, column_dpcm: np.ndarray) -> np.ndarray:
        # Running sum undoes the DPCM coding: alpha(b) = sum of sfc(0..b).
        gains = np.cumsum(column_dpcm, dtype=np.int64)
        spectrum = np.zeros((column_symbols.shape[0],), dtype=np.float64)
        for gain, (lo, hi) in zip(gains, bands):
            spectrum[lo : hi + 1] = _dequantize_symbol(
                column_symbols[lo : hi + 1], float(gain)
            )
        return spectrum

    if frame_type != "ESH":
        if dpcm.shape != (NB, 1):
            raise ValueError(f"For non-ESH, sfc must have shape ({NB}, 1).")
        return rebuild(symbols, dpcm[:, 0]).reshape(1024, 1)

    if dpcm.shape != (NB, 8):
        raise ValueError(f"For ESH, sfc must have shape ({NB}, 8).")

    unpacked = _esh_unpack_from_1024(symbols)
    Xrec = np.zeros((128, 8), dtype=np.float64)
    for j in range(8):
        # The unpack helper returns floats; go back to integer symbols.
        Xrec[:, j] = rebuild(unpacked[:, j].astype(np.int64), dpcm[:, j])
    return Xrec