124 lines
3.4 KiB
Python
124 lines
3.4 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import soundfile as sf
|
|
|
|
from level_1.level_1 import aac_coder_1, i_aac_coder_1
|
|
|
|
# Helpers and fixtures for testing aac_coder_1 / i_aac_coder_1
|
|
# -----------------------------------------------------------------------------
|
|
|
|
def _snr_db(x_ref: np.ndarray, x_hat: np.ndarray) -> float:
|
|
"""
|
|
Compute overall SNR (dB) over all samples and channels after aligning lengths.
|
|
"""
|
|
x_ref = np.asarray(x_ref, dtype=np.float64)
|
|
x_hat = np.asarray(x_hat, dtype=np.float64)
|
|
|
|
if x_ref.ndim == 1:
|
|
x_ref = x_ref.reshape(-1, 1)
|
|
if x_hat.ndim == 1:
|
|
x_hat = x_hat.reshape(-1, 1)
|
|
|
|
n = min(x_ref.shape[0], x_hat.shape[0])
|
|
c = min(x_ref.shape[1], x_hat.shape[1])
|
|
|
|
x_ref = x_ref[:n, :c]
|
|
x_hat = x_hat[:n, :c]
|
|
|
|
err = x_ref - x_hat
|
|
ps = float(np.sum(x_ref * x_ref))
|
|
pn = float(np.sum(err * err))
|
|
|
|
if pn <= 0.0:
|
|
return float("inf")
|
|
if ps <= 0.0:
|
|
return -float("inf")
|
|
|
|
return float(10.0 * np.log10(ps / pn))
|
|
|
|
|
|
@pytest.fixture()
def tmp_stereo_wav(tmp_path: Path) -> Path:
    """
    Create a temporary 48 kHz stereo WAV with random samples.

    The samples are Gaussian noise with standard deviation 0.25, explicitly
    clipped to [-1, 1]. ``sf.write`` stores WAV files as 16-bit PCM by
    default, and float samples outside [-1, 1] are not represented faithfully
    in that format. Unit-variance noise would put roughly a third of its
    samples out of range, so the file would not contain the intended signal.

    Returns:
        Path to the written ``in.wav`` inside pytest's ``tmp_path``.
    """
    rng = np.random.default_rng(123)
    fs = 48000

    # ~1 second of audio, keep small for test speed
    n = fs
    # Scale well inside full range; the clip only guards the rare >4-sigma sample.
    x = np.clip(0.25 * rng.normal(size=(n, 2)), -1.0, 1.0).astype(np.float64)

    wav_path = tmp_path / "in.wav"
    sf.write(str(wav_path), x, fs)

    return wav_path
|
|
|
|
|
|
def test_aac_coder_seq_schema_and_shapes(tmp_stereo_wav: Path) -> None:
    """
    Module-level contract test for the sequence returned by ``aac_coder_1``.

    The result must be a non-empty list of dicts. Every frame dict carries:
      * ``frame_type`` -- one of "OLS", "LSS", "ESH", "LPS";
      * ``win_type``   -- one of "SIN", "KBD";
      * ``chl`` / ``chr`` -- per-channel dicts, each holding ``frame_F``.
    ``frame_F`` must have shape (128, 8) for "ESH" frames and (1024, 1) for
    every other frame type.
    """
    seq = aac_coder_1(tmp_stereo_wav)

    assert isinstance(seq, list)
    assert len(seq) > 0

    allowed_frame_types = ("OLS", "LSS", "ESH", "LPS")
    allowed_win_types = ("SIN", "KBD")
    channels = ("chl", "chr")

    for frame in seq:
        assert isinstance(frame, dict)

        # Required top-level keys.
        for key in ("frame_type", "win_type", *channels):
            assert key in frame

        assert frame["frame_type"] in allowed_frame_types
        assert frame["win_type"] in allowed_win_types

        # Each channel entry is a dict exposing its frame_F coefficients.
        for ch in channels:
            assert isinstance(frame[ch], dict)
        for ch in channels:
            assert "frame_F" in frame[ch]

        # Expected layout per the contract: (128, 8) for ESH, (1024, 1) otherwise.
        expected_shape = (128, 8) if frame["frame_type"] == "ESH" else (1024, 1)
        for ch in channels:
            assert np.asarray(frame[ch]["frame_F"]).shape == expected_shape
|
|
|
|
|
|
def test_end_to_end_aac_coder_decoder_high_snr(tmp_stereo_wav: Path, tmp_path: Path) -> None:
    """
    End-to-end module test:
    Encode + decode and check SNR is very high (numerical-noise only).
    Threshold is intentionally loose to avoid fragility.

    ``_snr_db`` compares only the overlapping samples and channels of its two
    inputs. This test therefore also checks that the decoded array and the
    written file are non-empty and keep the reference's channel count, so a
    decoder that drops a channel or returns nothing cannot pass on SNR alone.
    """
    x_ref, fs = sf.read(str(tmp_stereo_wav), always_2d=True)
    assert fs == 48000

    out_wav = tmp_path / "out.wav"

    aac_seq = aac_coder_1(tmp_stereo_wav)
    x_hat = i_aac_coder_1(aac_seq, out_wav)

    # Returned array: 2-D, non-empty, same channel count as the input.
    # NOTE(review): assumes i_aac_coder_1 returns (samples, channels) -- confirm.
    x_hat_arr = np.asarray(x_hat)
    assert x_hat_arr.ndim == 2
    assert x_hat_arr.shape[0] > 0
    assert x_hat_arr.shape[1] == x_ref.shape[1]

    # Output file exists, is readable, and has the expected format.
    assert out_wav.exists()
    x_hat_file, fs_hat = sf.read(str(out_wav), always_2d=True)
    assert fs_hat == 48000
    assert x_hat_file.shape[0] > 0
    assert x_hat_file.shape[1] == x_ref.shape[1]

    # SNR computed against the array returned by i_aac_coder_1 (the file may
    # differ slightly, e.g. through PCM quantization, so it is not compared).
    snr = _snr_db(x_ref, x_hat)
    assert snr > 80.0
|