# (source listing metadata: 844 lines, 25 KiB, Python)

#! /usr/bin/env python
from __future__ import annotations

from pathlib import Path
from typing import Dict, Tuple, List, Literal, TypedDict, Union

import numpy as np
import soundfile as sf
from scipy.signal import lfilter
from scipy.signal.windows import kaiser
# --------------------------------
# Public Type aliases (Level 1)
# --------------------------------
FrameType = Literal["OLS", "LSS", "ESH", "LPS"]
"""
Frame type codes:
- "OLS": ONLY_LONG_SEQUENCE
- "LSS": LONG_START_SEQUENCE
- "ESH": EIGHT_SHORT_SEQUENCE
- "LPS": LONG_STOP_SEQUENCE
"""
WinType = Literal["KBD", "SIN"]
"""
Window type codes:
- "KBD": Kaiser-Bessel-Derived
- "SIN": sinusoid
"""
FrameT = np.ndarray
"""
Time-domain frame.
Expected shape: (2048, 2) for stereo (two channels).
dtype: float (e.g., float32/float64).
"""
FrameChannelT = np.ndarray
"""
Time-domain single channel frame.
Expected shape: (2048,).
dtype: float (e.g., float32/float64).
"""
FrameF = np.ndarray
"""
Frequency-domain frame (MDCT coefficients).
As per spec (Level 1):
- If frame_type in {"OLS","LSS","LPS"}: shape (1024, 2)
- If frame_type == "ESH": shape (128, 16) where 8 subframes x 2 channels
are placed in columns according to the subframe order (i.e., each subframe is (128,2)).
"""
# Keys used to index a frame's per-channel payload:
# "chl" holds channel 0 (left), "chr" holds channel 1 (right).
ChannelKey = Literal["chl", "chr"]
class AACChannelFrameF(TypedDict):
    """Channel payload for aac_seq_1[i]["chl"] or ["chr"] (Level 1)."""
    frame_F: np.ndarray
    # frame_F for one channel:
    # - ESH: shape (128, 8)
    # - else: shape (1024, 1)
class AACSeq1Frame(TypedDict):
    """One frame dictionary of aac_seq_1 (Level 1)."""
    frame_type: FrameType  # common (stereo-merged) type for this frame
    win_type: WinType      # window family used for this frame
    chl: AACChannelFrameF  # left-channel spectrum
    chr: AACChannelFrameF  # right-channel spectrum
AACSeq1 = List[AACSeq1Frame]
"""AAC sequence for Level 1:
List of length K (K = number of frames).
Each element is a dict with keys:
- "frame_type", "win_type", "chl", "chr"
"""
# Global Options
# -----------------------------------------------------------------------------
# Window type
# Options: "SIN", "KBD"
WIN_TYPE: WinType = "SIN"
# Private helpers for SSC
# -----------------------------------------------------------------------------
# See Table 1 in mm-2025-hw-v0.1.pdf
# Pattern visible in the table: identical pairs map to themselves; any pair
# containing "ESH" merges to "ESH"; the mixed pair {"LSS","LPS"} also merges
# to "ESH"; "LSS" paired with "OLS" gives "LSS"; "LPS" paired with "OLS"
# gives "LPS".
STEREO_MERGE_TABLE: Dict[Tuple[FrameType, FrameType], FrameType] = {
    ("OLS", "OLS"): "OLS",
    ("OLS", "LSS"): "LSS",
    ("OLS", "ESH"): "ESH",
    ("OLS", "LPS"): "LPS",
    ("LSS", "OLS"): "LSS",
    ("LSS", "LSS"): "LSS",
    ("LSS", "ESH"): "ESH",
    ("LSS", "LPS"): "ESH",
    ("ESH", "OLS"): "ESH",
    ("ESH", "LSS"): "ESH",
    ("ESH", "ESH"): "ESH",
    ("ESH", "LPS"): "ESH",
    ("LPS", "OLS"): "LPS",
    ("LPS", "LSS"): "ESH",
    ("LPS", "ESH"): "ESH",
    ("LPS", "LPS"): "LPS",
}
def _detect_attack(next_frame_channel: FrameChannelT) -> bool:
"""
Detect if next frame (single channel) implies ESH according to the spec's attack criterion.
Parameters
----------
next_frame_channel : FrameChannelT
One channel of next_frame_T (shape: (2048,), dtype float).
Returns
-------
attack : bool
True if an attack is detected (=> next frame predicted ESH), else False.
Notes
-----
The spec describes:
- High-pass filter applied to next_frame_channel
- Split into 16 segments of length 128
- Compute segment energies s(l)
- Compute ds(l) = s(l) / s(l-1)
- Attack exists if there exists l in {1..7} such that:
s(l) > 1e-3 and ds(l) > 10
"""
x = next_frame_channel # local alias, x assumed to be a 1-D array of length 2048
# High-pass filter H(z) = (1 - z^-1) / (1 - 0.5 z^-1)
# Implemented as: y[n] = x[n] - x[n-1] + 0.5*y[n-1]
y = np.zeros_like(x)
prev_x = 0.0
prev_y = 0.0
for n in range(x.shape[0]):
xn = float(x[n])
yn = (xn - prev_x) + 0.5 * prev_y
y[n] = yn
prev_x = xn
prev_y = yn
# Segment energies over 16 blocks of 128 samples.
s = np.empty(16, dtype=np.float64)
for l in range(16):
a = l * 128
b = (l + 1) * 128
seg = y[a:b]
s[l] = float(np.sum(seg * seg))
# ds(l) for l>=1. For l=0 not defined, keep 0.
ds = np.zeros(16, dtype=np.float64)
eps = 1e-12 # avoid division by zero without changing logic materially
for l in range(1, 16):
ds[l] = s[l] / max(s[l - 1], eps)
# Spec: check l in {1..7}
for l in range(1, 8):
if (s[l] > 1e-3) and (ds[l] > 10.0):
return True
return False
def _decide_frame_type(prev_frame_type: FrameType, attack: bool) -> FrameType:
"""
Decide current frame type for a single channel based on prev_frame_type and next-frame attack.
Parameters
----------
prev_frame_type : FrameType
Previous frame type (one of "OLS","LSS","ESH","LPS").
attack : bool
Whether next frame is predicted ESH for this channel.
Returns
-------
frame_type : FrameType
The per-channel decision for the current frame.
Rules (spec)
------------
- If prev is "LSS" => current is "ESH" (fixed)
- If prev is "LPS" => current is "OLS" (fixed)
- If prev is "OLS" => current is "LSS" if attack else "OLS"
- If prev is "ESH" => current is "ESH" if attack else "LPS"
"""
if prev_frame_type == "LSS":
return "ESH"
if prev_frame_type == "LPS":
return "OLS"
if prev_frame_type == "OLS":
return "LSS" if attack else "OLS"
if prev_frame_type == "ESH":
return "ESH" if attack else "LPS"
raise ValueError(f"Invalid prev_frame_type: {prev_frame_type!r}")
def _stereo_merge(ft_l: FrameType, ft_r: FrameType) -> FrameType:
    """
    Combine the per-channel frame-type decisions into one common type.

    Parameters
    ----------
    ft_l : FrameType
        Decision for channel 0 (left).
    ft_r : FrameType
        Decision for channel 1 (right).

    Returns
    -------
    common : FrameType
        Merged frame type per the spec's Table 1 (STEREO_MERGE_TABLE).

    Raises
    ------
    ValueError
        If the pair is not present in STEREO_MERGE_TABLE.
    """
    pair = (ft_l, ft_r)
    merged = STEREO_MERGE_TABLE.get(pair)
    if merged is None:
        raise ValueError(f"Invalid stereo merge pair: {pair}")
    return merged
# Private helpers for Filterbank
# -----------------------------------------------------------------------------
def _sin_window(N: int) -> np.ndarray:
"""
Sine window (full length N).
w[n] = sin(pi/N * (n + 0.5)), 0 <= n < N
"""
n = np.arange(N, dtype=np.float64)
return np.sin((np.pi / N) * (n + 0.5))
def _kbd_window(N: int, alpha: float) -> np.ndarray:
"""
Kaiser-Bessel-Derived (KBD) window (full length N).
This follows the standard KBD construction:
- Build Kaiser kernel of length N/2 + 1
- Use cumulative sum and sqrt normalization to form left and right halves
"""
half = N // 2
# Kaiser kernel length: half + 1 samples (0 .. half)
# beta = pi * alpha per the usual correspondence with the ISO definition
kernel = kaiser(half + 1, beta=np.pi * alpha).astype(np.float64)
csum = np.cumsum(kernel)
denom = csum[-1]
w_left = np.sqrt(csum[:-1] / denom) # length half, n = 0 .. half-1
w_right = w_left[::-1] # mirror for second half
return np.concatenate([w_left, w_right])
def _long_window(win_type: WinType) -> np.ndarray:
    """
    Return the 2048-sample long window for the selected window family.
    """
    if win_type == "KBD":
        # alpha = 6 is the assignment-specified value for the long KBD window.
        return _kbd_window(2048, alpha=6.0)
    if win_type == "SIN":
        return _sin_window(2048)
    raise ValueError(f"Invalid win_type: {win_type!r}")
def _short_window(win_type: WinType) -> np.ndarray:
    """
    Return the 256-sample short window for the selected window family.
    """
    if win_type == "KBD":
        # alpha = 4 is the assignment-specified value for the short KBD window.
        return _kbd_window(256, alpha=4.0)
    if win_type == "SIN":
        return _sin_window(256)
    raise ValueError(f"Invalid win_type: {win_type!r}")
def _window_sequence(frame_type: FrameType, win_type: WinType) -> np.ndarray:
    """
    Assemble the 2048-sample window for OLS/LSS/LPS frames.

    Simplified assumption from the assignment: a single window family
    (SIN or KBD) is used globally, so both halves come from the same family.

    Layouts (sample index ranges):
    - OLS: the full long window.
    - LSS: long rising half (0..1023) | ones (1024..1471) |
           short falling half (1472..1599) | zeros (1600..2047).
    - LPS: zeros (0..447) | short rising half (448..575) |
           ones (576..1023) | long falling half (1024..2047).
    """
    long_w = _long_window(win_type)    # 2048 samples
    if frame_type == "OLS":
        return long_w
    short_w = _short_window(win_type)  # 256 samples
    if frame_type == "LSS":
        return np.concatenate([
            long_w[:1024],     # rising half of the long window
            np.ones(448),
            short_w[128:],     # falling half of the short window
            np.zeros(448),
        ])
    if frame_type == "LPS":
        return np.concatenate([
            np.zeros(448),
            short_w[:128],     # rising half of the short window
            np.ones(448),
            long_w[1024:],     # falling half of the long window
        ])
    raise ValueError(f"Invalid frame_type for long window sequence: {frame_type!r}")
def _mdct(s: np.ndarray) -> np.ndarray:
"""
MDCT (direct form) as given in the assignment.
Input:
s: windowed time samples of length N (N = 2048 or 256)
Output:
X: MDCT coefficients of length N/2
Definition:
X[k] = 2 * sum_{n=0 .. N-1} s[n] * cos(2*pi/N * (n + n0) * (k + 1/2))
where n0 = (N/2 + 1)/2
"""
s = np.asarray(s, dtype=np.float64)
N = int(s.shape[0])
if N not in (2048, 256):
raise ValueError("MDCT input length must be 2048 or 256.")
n0 = (N / 2.0 + 1.0) / 2.0
n = np.arange(N, dtype=np.float64) + n0
k = np.arange(N // 2, dtype=np.float64) + 0.5
# Cosine matrix: shape (N, N/2)
C = np.cos((2.0 * np.pi / N) * np.outer(n, k))
X = 2.0 * (s @ C)
return X
def _imdct(X: np.ndarray) -> np.ndarray:
"""
IMDCT (direct form) as given in the assignment.
Input:
X: MDCT coefficients of length N/2 (N = 2048 or 256)
Output:
s: time samples of length N
Definition:
s[n] = (2/N) * sum_{k=0 .. N/2-1} X[k] * cos(2*pi/N * (n + n0) * (k + 1/2))
where n0 = (N/2 + 1)/2
"""
X = np.asarray(X, dtype=np.float64).reshape(-1)
K = int(X.shape[0])
if K not in (1024, 128):
raise ValueError("IMDCT input length must be 1024 or 128.")
N = 2 * K
n0 = (N / 2.0 + 1.0) / 2.0
n = np.arange(N, dtype=np.float64) + n0
k = np.arange(K, dtype=np.float64) + 0.5
C = np.cos((2.0 * np.pi / N) * np.outer(n, k)) # (N, K)
s = (2.0 / N) * (C @ X)
return s
def _filter_bank_esh_channel(x_ch: np.ndarray, win_type: WinType) -> np.ndarray:
    """
    MDCT analysis of one channel for an EIGHT_SHORT_SEQUENCE frame.

    The 8 short sub-windows of 256 samples are taken from the central region
    of the 2048-sample frame with 50% overlap (start offsets 448 + 128*j,
    j = 0..7), windowed with the short window, and transformed.

    Returns
    -------
    X_esh : np.ndarray
        Shape (128, 8); column j holds the MDCT of short sub-window j.
    """
    short_w = _short_window(win_type)
    columns = [
        _mdct(x_ch[448 + 128 * j:448 + 128 * j + 256] * short_w)
        for j in range(8)
    ]
    return np.stack(columns, axis=1)
def _unpack_esh(frame_F: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""
Unpack ESH spectrum from shape (128, 16) into per-channel arrays (128, 8).
Mapping is the inverse of the packing used in filter_bank():
out[:, 2*j] = left[:, j]
out[:, 2*j+1] = right[:, j]
"""
if frame_F.shape != (128, 16):
raise ValueError("ESH frame_F must have shape (128, 16).")
left = np.empty((128, 8), dtype=np.float64)
right = np.empty((128, 8), dtype=np.float64)
for j in range(8):
left[:, j] = frame_F[:, 2 * j + 0]
right[:, j] = frame_F[:, 2 * j + 1]
return left, right
def _i_filter_bank_esh_channel(X_esh: np.ndarray, win_type: WinType) -> np.ndarray:
    """
    IMDCT synthesis of one channel for an EIGHT_SHORT_SEQUENCE frame.

    Each of the 8 short spectra is inverse-transformed to 256 samples,
    windowed, and overlap-added into the central region of a 2048-sample
    frame (offsets 448 + 128*j, 50% overlap). Frame-level OLA happens at
    the caller.

    Raises
    ------
    ValueError
        If X_esh does not have shape (128, 8).
    """
    if X_esh.shape != (128, 8):
        raise ValueError("X_esh must have shape (128, 8).")
    short_w = _short_window(win_type)
    frame = np.zeros(2048, dtype=np.float64)
    for j, spectrum in enumerate(X_esh.T):
        offset = 448 + 128 * j
        frame[offset:offset + 256] += _imdct(spectrum) * short_w
    return frame
# -----------------------------------------------------------------------------
# Public Function prototypes (Level 1)
# -----------------------------------------------------------------------------
def SSC(frame_T: FrameT, next_frame_T: FrameT, prev_frame_type: FrameType) -> FrameType:
    """
    Sequence Segmentation Control.

    Chooses the frame type of the current frame i from the next frame's
    attack prediction and the previous frame's type.

    Parameters
    ----------
    frame_T : FrameT
        Current stereo frame i, shape (2048, 2).
    next_frame_T : FrameT
        Next stereo frame i+1, shape (2048, 2); used to decide transitions
        to/from ESH.
    prev_frame_type : FrameType
        Frame type chosen for frame i-1.

    Returns
    -------
    frame_type : FrameType
        One of "OLS" (ONLY_LONG_SEQUENCE), "LSS" (LONG_START_SEQUENCE),
        "ESH" (EIGHT_SHORT_SEQUENCE), "LPS" (LONG_STOP_SEQUENCE).
    """
    if frame_T.shape != (2048, 2):
        raise ValueError("frame_T must have shape (2048, 2).")
    if next_frame_T.shape != (2048, 2):
        raise ValueError("next_frame_T must have shape (2048, 2).")
    # Per-channel decision: attack detection on the next frame followed by
    # the state-machine transition from prev_frame_type.
    decisions = tuple(
        _decide_frame_type(prev_frame_type, _detect_attack(next_frame_T[:, ch]))
        for ch in (0, 1)
    )
    # Collapse the two per-channel decisions via the spec's merge table.
    return _stereo_merge(*decisions)
def filter_bank(frame_T: FrameT, frame_type: FrameType, win_type: WinType) -> FrameF:
    """
    Filterbank stage (MDCT analysis).

    Parameters
    ----------
    frame_T : FrameT
        Time-domain frame, stereo, shape (2048, 2).
    frame_type : FrameType
        Type of the frame under encoding ("OLS"|"LSS"|"ESH"|"LPS").
    win_type : WinType
        Window type ("KBD" or "SIN") used for the current frame.

    Returns
    -------
    frame_F : FrameF
        - "OLS"/"LSS"/"LPS": shape (1024, 2), one MDCT spectrum per channel.
        - "ESH": shape (128, 16); subframe j occupies columns 2j (left)
          and 2j+1 (right).

    Raises
    ------
    ValueError
        On a bad input shape or an unknown frame_type.
    """
    if frame_T.shape != (2048, 2):
        raise ValueError("frame_T must have shape (2048, 2).")
    left = frame_T[:, 0].astype(np.float64, copy=False)
    right = frame_T[:, 1].astype(np.float64, copy=False)
    if frame_type == "ESH":
        out = np.empty((128, 16), dtype=np.float64)
        # Interleave the two (128, 8) channel spectra column-wise.
        out[:, 0::2] = _filter_bank_esh_channel(left, win_type)
        out[:, 1::2] = _filter_bank_esh_channel(right, win_type)
        return out
    if frame_type in ("OLS", "LSS", "LPS"):
        w = _window_sequence(frame_type, win_type)  # length 2048
        return np.column_stack((_mdct(left * w), _mdct(right * w)))
    raise ValueError(f"Invalid frame_type: {frame_type!r}")
def i_filter_bank(frame_F: FrameF, frame_type: FrameType, win_type: WinType) -> FrameT:
    """
    Inverse filterbank (IMDCT synthesis).

    Parameters
    ----------
    frame_F : FrameF
        Frequency-domain MDCT coefficients as produced by filter_bank().
    frame_type : FrameType
        Frame type ("OLS"|"LSS"|"ESH"|"LPS").
    win_type : WinType
        Window type ("KBD" or "SIN").

    Returns
    -------
    frame_T : FrameT
        Reconstructed (windowed) time-domain frame, stereo, shape (2048, 2).

    Raises
    ------
    ValueError
        On a bad input shape or an unknown frame_type.
    """
    if frame_type == "ESH":
        if frame_F.shape != (128, 16):
            raise ValueError("For ESH, frame_F must have shape (128, 16).")
        Xl, Xr = _unpack_esh(frame_F)
        return np.column_stack((
            _i_filter_bank_esh_channel(Xl, win_type),
            _i_filter_bank_esh_channel(Xr, win_type),
        ))
    if frame_type in ("OLS", "LSS", "LPS"):
        if frame_F.shape != (1024, 2):
            raise ValueError("For OLS/LSS/LPS, frame_F must have shape (1024, 2).")
        w = _window_sequence(frame_type, win_type)
        return np.column_stack((
            _imdct(frame_F[:, 0]) * w,
            _imdct(frame_F[:, 1]) * w,
        ))
    raise ValueError(f"Invalid frame_type: {frame_type!r}")
def aac_coder_1(filename_in: Union[str, Path]) -> AACSeq1:
    """
    Level-1 AAC encoder.

    Parameters
    ----------
    filename_in : str | Path
        Input WAV filename. Assumption: stereo audio sampled at 48 kHz.

    Returns
    -------
    aac_seq_1 : AACSeq1
        List of K encoded frames. For each i:
        - aac_seq_1[i]["frame_type"]: FrameType
        - aac_seq_1[i]["win_type"]: WinType
        - aac_seq_1[i]["chl"]["frame_F"]: (128, 8) for ESH, else (1024, 1)
        - aac_seq_1[i]["chr"]["frame_F"]: (128, 8) for ESH, else (1024, 1)

    Raises
    ------
    ValueError
        If the input is not stereo, not 48 kHz, or too short to frame.
    """
    audio, sample_rate = sf.read(str(Path(filename_in)), always_2d=True)
    audio = np.asarray(audio, dtype=np.float64)
    if audio.shape[1] != 2:
        raise ValueError("Input must be stereo (2 channels).")
    if sample_rate != 48000:
        raise ValueError("Input sampling rate must be 48 kHz.")
    hop, win = 1024, 2048
    # Zero-pad one hop at the head (first overlap region) and one hop at the
    # tail; any further tail shortfall of the look-ahead frame is padded
    # on the fly below.
    padding = np.zeros((hop, 2), dtype=np.float64)
    padded = np.concatenate((padding, audio, padding), axis=0)
    # Frame count such that every current frame fits entirely.
    num_frames = int((padded.shape[0] - win) // hop + 1)
    if num_frames <= 0:
        raise ValueError("Input too short for framing.")
    sequence: AACSeq1 = []
    prev_type: FrameType = "OLS"
    for idx in range(num_frames):
        begin = idx * hop
        current: FrameT = padded[begin:begin + win, :]
        if current.shape != (win, 2):
            # Unreachable given num_frames, but kept as an explicit guard.
            raise ValueError("Internal framing error: frame_t has wrong shape.")
        # Look-ahead frame for SSC, zero-padded to (2048, 2) at the tail.
        upcoming = padded[begin + hop:begin + hop + win, :]
        shortfall = win - upcoming.shape[0]
        if shortfall > 0:
            upcoming = np.concatenate(
                (upcoming, np.zeros((shortfall, 2), dtype=np.float64)), axis=0
            )
        current_type = SSC(current, upcoming, prev_type)
        spectrum = filter_bank(current, current_type, WIN_TYPE)
        # Split the packed spectrum into the per-channel schema of AACSeq1.
        if current_type == "ESH":
            # spectrum is (128, 16) interleaved as [L0 R0 L1 R1 ... L7 R7].
            left_f = spectrum[:, 0::2].astype(np.float64)
            right_f = spectrum[:, 1::2].astype(np.float64)
        else:
            # spectrum is (1024, 2); keep each channel as a (1024, 1) column.
            left_f = spectrum[:, 0:1].astype(np.float64, copy=False)
            right_f = spectrum[:, 1:2].astype(np.float64, copy=False)
        sequence.append({
            "frame_type": current_type,
            "win_type": WIN_TYPE,
            "chl": {"frame_F": left_f},
            "chr": {"frame_F": right_f},
        })
        prev_type = current_type
    return sequence
def i_aac_coder_1(aac_seq_1: AACSeq1, filename_out: Union[str, Path]) -> np.ndarray:
    """
    Level-1 AAC decoder (inverse of aac_coder_1()).

    Parameters
    ----------
    aac_seq_1 : AACSeq1
        Encoded sequence as produced by aac_coder_1().
    filename_out : str | Path
        Output WAV filename (written as stereo, 48 kHz).

    Returns
    -------
    x : np.ndarray
        Decoded time-domain audio, shape (N, 2).

    Raises
    ------
    ValueError
        On malformed channel spectra or a stream too short to unpad.
    """
    hop, win = 1024, 2048
    num_frames = len(aac_seq_1)
    # Reconstruct the full padded stream: frame i starts at i*hop and spans
    # win samples, so the padded length is (K-1)*hop + win.
    stream = np.zeros(((num_frames - 1) * hop + win, 2), dtype=np.float64)
    for idx, frame in enumerate(aac_seq_1):
        frame_type = frame["frame_type"]
        win_type = frame["win_type"]
        left_f = np.asarray(frame["chl"]["frame_F"], dtype=np.float64)
        right_f = np.asarray(frame["chr"]["frame_F"], dtype=np.float64)
        # Re-pack the two channels into the layout i_filter_bank() expects.
        if frame_type == "ESH":
            if left_f.shape != (128, 8) or right_f.shape != (128, 8):
                raise ValueError("ESH channel frame_F must have shape (128, 8).")
            spectrum = np.empty((128, 16), dtype=np.float64)
            spectrum[:, 0::2] = left_f
            spectrum[:, 1::2] = right_f
        else:
            if left_f.shape != (1024, 1) or right_f.shape != (1024, 1):
                raise ValueError("Non-ESH channel frame_F must have shape (1024, 1).")
            spectrum = np.hstack((left_f, right_f))
        begin = idx * hop
        # Overlap-add the synthesized (2048, 2) frame into the stream.
        stream[begin:begin + win, :] += i_filter_bank(spectrum, frame_type, win_type)
    # Strip the encoder's boundary padding: hop samples at each end.
    if stream.shape[0] < 2 * hop:
        raise ValueError("Decoded stream too short to unpad.")
    decoded = stream[hop:-hop, :]
    sf.write(str(Path(filename_out)), decoded, 48000)
    return decoded
def demo_aac_1(filename_in: Union[str, Path], filename_out: Union[str, Path]) -> float:
    """
    Demonstration for the Level-1 codec.

    Runs aac_coder_1(filename_in), then i_aac_coder_1(aac_seq_1, filename_out),
    and computes the overall SNR between original and decoded audio.

    Parameters
    ----------
    filename_in : str | Path
        Input WAV filename (stereo, 48 kHz).
    filename_out : str | Path
        Output WAV filename (stereo, 48 kHz).

    Returns
    -------
    SNR : float
        Overall Signal-to-Noise Ratio in dB (+inf for zero error,
        -inf for silent input).
    """
    in_path = Path(filename_in)
    out_path = Path(filename_out)
    # Reference signal.
    reference, _fs = sf.read(str(in_path), always_2d=True)
    reference = np.asarray(reference, dtype=np.float64)
    # Round trip through the codec.
    decoded = np.asarray(i_aac_coder_1(aac_coder_1(in_path), out_path),
                         dtype=np.float64)
    # Normalize both signals to 2-D (N, C).
    if decoded.ndim == 1:
        decoded = decoded.reshape(-1, 1)
    if reference.ndim == 1:
        reference = reference.reshape(-1, 1)
    # Compare only the common samples and channels.
    n = min(reference.shape[0], decoded.shape[0])
    c = min(reference.shape[1], decoded.shape[1])
    reference = reference[:n, :c]
    decoded = decoded[:n, :c]
    # Overall SNR across all samples and channels.
    noise = reference - decoded
    p_signal = float(np.sum(reference * reference))
    p_noise = float(np.sum(noise * noise))
    if p_noise <= 0.0:
        return float("inf")
    if p_signal <= 0.0:
        # Degenerate case: silent input.
        return -float("inf")
    return float(10.0 * np.log10(p_signal / p_noise))
if __name__ == "__main__":
    # Example usage:
    #   python -m level_1.level_1 input.wav output.wav
    import sys
    if len(sys.argv) != 3:
        raise SystemExit("Usage: python -m level_1.level_1 <input.wav> <output.wav>")
    in_wav, out_wav = sys.argv[1], sys.argv[2]
    print(f"Encoding/Decoding {in_wav} to {out_wav}")
    snr = demo_aac_1(in_wav, out_wav)
    print(f"SNR = {snr:.3f} dB")