Init commit with level 1 python source
commit dde11ddebe
2 .gitignore vendored Normal file
@@ -0,0 +1,2 @@
# project wide excludes
21 LICENSE Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 Christos Choutouridis <cchoutou@ece.auth.gr>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
107 Readme.md Normal file
@@ -0,0 +1,107 @@
# AAC Encoder/Decoder Assignment (Multimedia)

## About

This repository contains a staged implementation of a simplified AAC-like audio coder/decoder pipeline, developed in the context of the Multimedia course at Aristotle University of Thessaloniki (AUTh).
The project is organized into incremental levels, where each level introduces additional functionality and requirements (e.g., segmentation control, filterbanks, and progressively more complete encoding/decoding stages).
The purpose of this work is to implement the specified processing chain faithfully to the assignment specification, to validate correctness with structured tests, and to maintain a clean, reproducible project structure throughout development.

## Repository Structure

The repository is organized under the `source/` directory and split into three incremental levels:

- `source/level_1/`
  Baseline implementation of the required processing chain for Level 1.

- `source/level_2/`
  Placeholder for the Level 2 implementation (to be filled in step by step).

- `source/level_3/`
  Placeholder for the Level 3 implementation (to be filled in step by step).

Each level contains:

- a module file (e.g., `level_1/level_1.py`)
- a dedicated `tests/` directory for module-level tests (pytest)

## Level Descriptions

### Level 1

**Goal:** Implement the core analysis/synthesis chain for Level 1 as defined in the assignment specification.

Implemented components (current status):

- SSC (Sequence Segmentation Control)
- Filterbank (MDCT analysis) and inverse filterbank (IMDCT synthesis)
- End-to-end encoder/decoder functions:
  - `aac_coder_1()`
  - `i_aac_coder_1()`
- Demo function:
  - `demo_aac_1()`

Tests (current status):

- Module-level tests for SSC
- Module-level tests for the filterbank and inverse filterbank (including OLA-based reconstruction checks)
- Internal consistency tests for MDCT/IMDCT
- Module-level tests for `aac_coder_1` / `i_aac_coder_1`
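The SSC component listed above decides frame types via an attack criterion (defined in `source/level_1/level_1.py`): high-pass filtering, 16 sub-block energies `s(l)`, and energy ratios `ds(l) = s(l)/s(l-1)`. A minimal self-contained sketch of that criterion follows; the standalone name `detect_attack` is illustrative (the module's internal helper is `_detect_attack`):

```python
import numpy as np

def detect_attack(x: np.ndarray) -> bool:
    # High-pass filter H(z) = (1 - z^-1) / (1 - 0.5 z^-1),
    # i.e. y[n] = x[n] - x[n-1] + 0.5 * y[n-1]
    y = np.zeros(x.shape[0], dtype=np.float64)
    prev_x = prev_y = 0.0
    for n in range(x.shape[0]):
        y[n] = (float(x[n]) - prev_x) + 0.5 * prev_y
        prev_x, prev_y = float(x[n]), float(y[n])

    # Energies of 16 sub-blocks of 128 samples each (2048 samples total)
    s = np.array([np.sum(y[l * 128:(l + 1) * 128] ** 2) for l in range(16)])
    ds = s[1:] / np.maximum(s[:-1], 1e-12)  # ds[l-1] = s(l) / s(l-1)

    # Attack iff some l in {1..7} has s(l) > 1e-3 and ds(l) > 10
    return bool(np.any((s[1:8] > 1e-3) & (ds[:7] > 10.0)))
```

A silent frame yields no attack, while a frame that jumps from silence to a full-scale step inside its first half triggers one.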
### Level 2

**Goal:** ...

### Level 3

**Goal:** ...

## How to Run

All commands below assume you are inside the `source/` directory.

### Run Level 1 Demo

Run the Level 1 demo by providing an input WAV file and an output WAV file:

```bash
python -m level_1.level_1 <input.wav> <output.wav>
```

Example:

```bash
python -m level_1.level_1 ../material/LicorDeCalandraca.wav ../material/LicorDeCalandraca_out.wav
```

The demo prints the overall SNR (in dB) between the original and reconstructed audio.
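The overall SNR is the ratio of total signal power to total error power, in dB, taken over all samples and channels. A minimal sketch of that computation; the helper name `overall_snr_db` is hypothetical, not part of the module's API:

```python
import numpy as np

def overall_snr_db(x_ref: np.ndarray, x_hat: np.ndarray) -> float:
    # SNR = 10 * log10(signal power / error power), over all samples and channels
    err = x_ref - x_hat
    return float(10.0 * np.log10(np.sum(x_ref ** 2) / np.sum(err ** 2)))
```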
### How to Run Tests

Tests are written using `pytest` and are organized per level.

From inside `source/`, run all tests:

```bash
pytest
```

To run only the Level 1 tests:

```bash
pytest level_1/tests
```

To run a specific test file:

```bash
pytest level_1/tests/test_SSC.py
```
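The OLA-based reconstruction checks mentioned above can be sketched with a standalone re-implementation of the assignment's direct-form transforms, `X[k] = 2 * sum_n s[n] * cos(2*pi/N * (n + n0) * (k + 1/2))` with `n0 = (N/2 + 1)/2` and its inverse: with a sine window and 50% overlap, overlap-add cancels the time-domain aliasing, so the fully-overlapped interior is reconstructed essentially exactly. (This sketch is generic in `N`; the module itself restricts lengths to 2048/256.)

```python
import numpy as np

def mdct(s: np.ndarray) -> np.ndarray:
    # Direct-form MDCT: X[k] = 2 * sum_n s[n] * cos(2*pi/N * (n + n0) * (k + 1/2))
    N = s.shape[0]
    n0 = (N / 2.0 + 1.0) / 2.0
    n = np.arange(N) + n0
    k = np.arange(N // 2) + 0.5
    return 2.0 * (s @ np.cos((2.0 * np.pi / N) * np.outer(n, k)))

def imdct(X: np.ndarray) -> np.ndarray:
    # Direct-form IMDCT: s[n] = (2/N) * sum_k X[k] * cos(2*pi/N * (n + n0) * (k + 1/2))
    K = X.shape[0]
    N = 2 * K
    n0 = (N / 2.0 + 1.0) / 2.0
    n = np.arange(N) + n0
    k = np.arange(K) + 0.5
    return (2.0 / N) * (np.cos((2.0 * np.pi / N) * np.outer(n, k)) @ X)

# Sine-window OLA round trip: frames of length N with hop N/2.
N, hop = 256, 128
w = np.sin((np.pi / N) * (np.arange(N) + 0.5))
rng = np.random.default_rng(0)
x = rng.standard_normal(4 * N)
y = np.zeros_like(x)
for start in range(0, x.size - N + 1, hop):
    frame = x[start:start + N] * w
    y[start:start + N] += imdct(mdct(frame)) * w
# The interior (excluding the first and last hop) matches the original.
print(np.max(np.abs(y[hop:-hop] - x[hop:-hop])))
```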
## Notes on Development Workflow

- The project is developed incrementally, level by level.
- Tests are primarily module-level and specification-driven.
- Internal helper testing is generally avoided, except for MDCT/IMDCT, which are treated as reusable “library-like” primitives.

## Disclaimer

This project was developed solely for educational purposes.
It is provided "as is", without any express or implied warranties.
The author assumes no responsibility for any misuse, data loss, security incidents, or damages resulting from the use of this software.
This implementation should not be used in production environments.

All work, modifications, and results are the sole responsibility of the author.
3 RequirementDocs/.gitignore vendored Normal file
@@ -0,0 +1,3 @@
# Data files
*zip
BIN RequirementDocs/AAC description/26401-700.pdf Normal file (binary file not shown)
BIN RequirementDocs/AAC description/26403-700.pdf Normal file (binary file not shown)
BIN RequirementDocs/AAC description/MDCT etc.pdf Normal file (binary file not shown)
BIN RequirementDocs/AAC description/ieee AAC 00883454.pdf Normal file (binary file not shown)
BIN RequirementDocs/AAC description/w2203tfa_filterbanks.pdf Normal file (binary file not shown)
BIN mm-2025-hw-v0.1.pdf Normal file (binary file not shown)
9 source/.gitignore vendored Normal file
@@ -0,0 +1,9 @@
# Python excludes
.venv/*
.pytest_cache/*
*__pycache__*

# IDEs
.idea/*
843 source/level_1/level_1.py Normal file
@@ -0,0 +1,843 @@
#! /usr/bin/env python

from __future__ import annotations

from pathlib import Path
from typing import Dict, Tuple, List, Literal, TypedDict, Union

import numpy as np
import soundfile as sf
from scipy.signal.windows import kaiser


# --------------------------------
# Public Type aliases (Level 1)
# --------------------------------

FrameType = Literal["OLS", "LSS", "ESH", "LPS"]
"""
Frame type codes:
- "OLS": ONLY_LONG_SEQUENCE
- "LSS": LONG_START_SEQUENCE
- "ESH": EIGHT_SHORT_SEQUENCE
- "LPS": LONG_STOP_SEQUENCE
"""

WinType = Literal["KBD", "SIN"]
"""
Window type codes:
- "KBD": Kaiser-Bessel-Derived
- "SIN": sine
"""

FrameT = np.ndarray
"""
Time-domain frame.
Expected shape: (2048, 2) for stereo (two channels).
dtype: float (e.g., float32/float64).
"""

FrameChannelT = np.ndarray
"""
Time-domain single-channel frame.
Expected shape: (2048,).
dtype: float (e.g., float32/float64).
"""

FrameF = np.ndarray
"""
Frequency-domain frame (MDCT coefficients).
As per spec (Level 1):
- If frame_type in {"OLS", "LSS", "LPS"}: shape (1024, 2)
- If frame_type == "ESH": shape (128, 16), where 8 subframes x 2 channels are
  placed in columns according to the subframe order (i.e., each subframe is (128, 2)).
"""

ChannelKey = Literal["chl", "chr"]


class AACChannelFrameF(TypedDict):
    """Channel payload for aac_seq_1[i]["chl"] or ["chr"] (Level 1)."""
    frame_F: np.ndarray
    # frame_F for one channel:
    # - ESH: shape (128, 8)
    # - else: shape (1024, 1)


class AACSeq1Frame(TypedDict):
    """One frame dictionary of aac_seq_1 (Level 1)."""
    frame_type: FrameType
    win_type: WinType
    chl: AACChannelFrameF
    chr: AACChannelFrameF


AACSeq1 = List[AACSeq1Frame]
"""AAC sequence for Level 1:
List of length K (K = number of frames).
Each element is a dict with keys:
- "frame_type", "win_type", "chl", "chr"
"""

# Global Options
# -----------------------------------------------------------------------------

# Window type
# Options: "SIN", "KBD"
WIN_TYPE: WinType = "SIN"
# Private helpers for SSC
# -----------------------------------------------------------------------------

# See Table 1 in mm-2025-hw-v0.1.pdf
STEREO_MERGE_TABLE: Dict[Tuple[FrameType, FrameType], FrameType] = {
    ("OLS", "OLS"): "OLS",
    ("OLS", "LSS"): "LSS",
    ("OLS", "ESH"): "ESH",
    ("OLS", "LPS"): "LPS",
    ("LSS", "OLS"): "LSS",
    ("LSS", "LSS"): "LSS",
    ("LSS", "ESH"): "ESH",
    ("LSS", "LPS"): "ESH",
    ("ESH", "OLS"): "ESH",
    ("ESH", "LSS"): "ESH",
    ("ESH", "ESH"): "ESH",
    ("ESH", "LPS"): "ESH",
    ("LPS", "OLS"): "LPS",
    ("LPS", "LSS"): "ESH",
    ("LPS", "ESH"): "ESH",
    ("LPS", "LPS"): "LPS",
}


def _detect_attack(next_frame_channel: FrameChannelT) -> bool:
    """
    Detect whether the next frame (single channel) implies ESH according to the
    spec's attack criterion.

    Parameters
    ----------
    next_frame_channel : FrameChannelT
        One channel of next_frame_T (shape: (2048,), dtype float).

    Returns
    -------
    attack : bool
        True if an attack is detected (=> next frame predicted ESH), else False.

    Notes
    -----
    The spec describes:

    - High-pass filter applied to next_frame_channel
    - Split into 16 segments of length 128
    - Compute segment energies s(l)
    - Compute ds(l) = s(l) / s(l-1)
    - An attack exists if there is an l in {1..7} such that:
      s(l) > 1e-3 and ds(l) > 10
    """
    x = next_frame_channel  # local alias; x is assumed to be a 1-D array of length 2048

    # High-pass filter H(z) = (1 - z^-1) / (1 - 0.5 z^-1)
    # Implemented as: y[n] = x[n] - x[n-1] + 0.5*y[n-1]
    y = np.zeros_like(x)
    prev_x = 0.0
    prev_y = 0.0
    for n in range(x.shape[0]):
        xn = float(x[n])
        yn = (xn - prev_x) + 0.5 * prev_y
        y[n] = yn
        prev_x = xn
        prev_y = yn

    # Segment energies over 16 blocks of 128 samples.
    s = np.empty(16, dtype=np.float64)
    for l in range(16):
        a = l * 128
        b = (l + 1) * 128
        seg = y[a:b]
        s[l] = float(np.sum(seg * seg))

    # ds(l) for l >= 1. For l = 0 it is undefined; keep 0.
    ds = np.zeros(16, dtype=np.float64)
    eps = 1e-12  # avoid division by zero without changing the logic materially
    for l in range(1, 16):
        ds[l] = s[l] / max(s[l - 1], eps)

    # Spec: check l in {1..7}
    for l in range(1, 8):
        if (s[l] > 1e-3) and (ds[l] > 10.0):
            return True

    return False
def _decide_frame_type(prev_frame_type: FrameType, attack: bool) -> FrameType:
    """
    Decide the current frame type for a single channel based on prev_frame_type
    and the next-frame attack.

    Parameters
    ----------
    prev_frame_type : FrameType
        Previous frame type (one of "OLS", "LSS", "ESH", "LPS").
    attack : bool
        Whether the next frame is predicted ESH for this channel.

    Returns
    -------
    frame_type : FrameType
        The per-channel decision for the current frame.

    Rules (spec)
    ------------
    - If prev is "LSS" => current is "ESH" (fixed)
    - If prev is "LPS" => current is "OLS" (fixed)
    - If prev is "OLS" => current is "LSS" if attack else "OLS"
    - If prev is "ESH" => current is "ESH" if attack else "LPS"
    """
    if prev_frame_type == "LSS":
        return "ESH"
    if prev_frame_type == "LPS":
        return "OLS"
    if prev_frame_type == "OLS":
        return "LSS" if attack else "OLS"
    if prev_frame_type == "ESH":
        return "ESH" if attack else "LPS"

    raise ValueError(f"Invalid prev_frame_type: {prev_frame_type!r}")


def _stereo_merge(ft_l: FrameType, ft_r: FrameType) -> FrameType:
    """
    Merge the per-channel frame types into one common frame type using the spec table.

    Parameters
    ----------
    ft_l : FrameType
        Frame type decision for channel 0 (left).
    ft_r : FrameType
        Frame type decision for channel 1 (right).

    Returns
    -------
    common : FrameType
        The common final frame type.
    """
    try:
        return STEREO_MERGE_TABLE[(ft_l, ft_r)]
    except KeyError as e:
        raise ValueError(f"Invalid stereo merge pair: {(ft_l, ft_r)}") from e
# Private helpers for Filterbank
# -----------------------------------------------------------------------------

def _sin_window(N: int) -> np.ndarray:
    """
    Sine window (full length N).
    w[n] = sin(pi/N * (n + 0.5)), 0 <= n < N
    """
    n = np.arange(N, dtype=np.float64)
    return np.sin((np.pi / N) * (n + 0.5))


def _kbd_window(N: int, alpha: float) -> np.ndarray:
    """
    Kaiser-Bessel-Derived (KBD) window (full length N).

    This follows the standard KBD construction:
    - Build a Kaiser kernel of length N/2 + 1
    - Use a cumulative sum and sqrt normalization to form the left and right halves
    """
    half = N // 2

    # Kaiser kernel length: half + 1 samples (0 .. half)
    # beta = pi * alpha, per the usual correspondence with the ISO definition
    kernel = kaiser(half + 1, beta=np.pi * alpha).astype(np.float64)

    csum = np.cumsum(kernel)
    denom = csum[-1]

    w_left = np.sqrt(csum[:-1] / denom)  # length half, n = 0 .. half-1
    w_right = w_left[::-1]               # mirror for the second half

    return np.concatenate([w_left, w_right])


def _long_window(win_type: WinType) -> np.ndarray:
    """
    Long window (length 2048) for the selected win_type.
    """
    if win_type == "SIN":
        return _sin_window(2048)
    if win_type == "KBD":
        # Assignment-specific alpha value
        return _kbd_window(2048, alpha=6.0)
    raise ValueError(f"Invalid win_type: {win_type!r}")


def _short_window(win_type: WinType) -> np.ndarray:
    """
    Short window (length 256) for the selected win_type.
    """
    if win_type == "SIN":
        return _sin_window(256)
    if win_type == "KBD":
        # Assignment-specific alpha value
        return _kbd_window(256, alpha=4.0)
    raise ValueError(f"Invalid win_type: {win_type!r}")


def _window_sequence(frame_type: FrameType, win_type: WinType) -> np.ndarray:
    """
    Build the 2048-sample window sequence for OLS/LSS/LPS.

    We follow the simplified assumption:
    - The same window shape (KBD or SIN) is used globally (no mixed halves).
    - Therefore, the left and right halves are drawn from the same family.
    """
    wL = _long_window(win_type)   # length 2048
    wS = _short_window(win_type)  # length 256

    if frame_type == "OLS":
        return wL

    if frame_type == "LSS":
        # 0..1023:    left half of the long window
        # 1024..1471: ones (448 samples)
        # 1472..1599: right half of the short window (128 samples)
        # 1600..2047: zeros (448 samples)
        out = np.zeros(2048, dtype=np.float64)
        out[0:1024] = wL[0:1024]
        out[1024:1472] = 1.0
        out[1472:1600] = wS[128:256]
        out[1600:2048] = 0.0
        return out

    if frame_type == "LPS":
        # 0..447:     zeros (448)
        # 448..575:   left half of the short window (128)
        # 576..1023:  ones (448)
        # 1024..2047: right half of the long window (1024)
        out = np.zeros(2048, dtype=np.float64)
        out[0:448] = 0.0
        out[448:576] = wS[0:128]
        out[576:1024] = 1.0
        out[1024:2048] = wL[1024:2048]
        return out

    raise ValueError(f"Invalid frame_type for long window sequence: {frame_type!r}")
def _mdct(s: np.ndarray) -> np.ndarray:
    """
    MDCT (direct form) as given in the assignment.

    Input:
        s: windowed time samples of length N (N = 2048 or 256)

    Output:
        X: MDCT coefficients of length N/2

    Definition:
        X[k] = 2 * sum_{n=0 .. N-1} s[n] * cos(2*pi/N * (n + n0) * (k + 1/2))
        where n0 = (N/2 + 1)/2
    """
    s = np.asarray(s, dtype=np.float64)
    N = int(s.shape[0])
    if N not in (2048, 256):
        raise ValueError("MDCT input length must be 2048 or 256.")

    n0 = (N / 2.0 + 1.0) / 2.0

    n = np.arange(N, dtype=np.float64) + n0
    k = np.arange(N // 2, dtype=np.float64) + 0.5

    # Cosine matrix: shape (N, N/2)
    C = np.cos((2.0 * np.pi / N) * np.outer(n, k))
    X = 2.0 * (s @ C)

    return X


def _imdct(X: np.ndarray) -> np.ndarray:
    """
    IMDCT (direct form) as given in the assignment.

    Input:
        X: MDCT coefficients of length N/2 (N = 2048 or 256)

    Output:
        s: time samples of length N

    Definition:
        s[n] = (2/N) * sum_{k=0 .. N/2-1} X[k] * cos(2*pi/N * (n + n0) * (k + 1/2))
        where n0 = (N/2 + 1)/2
    """
    X = np.asarray(X, dtype=np.float64).reshape(-1)
    K = int(X.shape[0])
    if K not in (1024, 128):
        raise ValueError("IMDCT input length must be 1024 or 128.")

    N = 2 * K
    n0 = (N / 2.0 + 1.0) / 2.0

    n = np.arange(N, dtype=np.float64) + n0
    k = np.arange(K, dtype=np.float64) + 0.5

    C = np.cos((2.0 * np.pi / N) * np.outer(n, k))  # (N, K)
    s = (2.0 / N) * (C @ X)

    return s
def _filter_bank_esh_channel(x_ch: np.ndarray, win_type: WinType) -> np.ndarray:
    """
    ESH analysis for one channel.

    Returns:
        X_esh: shape (128, 8), where each column holds the 128 MDCT coefficients
        of one short window.
    """
    wS = _short_window(win_type)
    X_esh = np.empty((128, 8), dtype=np.float64)

    # ESH subwindows are taken from the central region:
    # start positions: 448 + 128*j, j = 0..7
    for j in range(8):
        start = 448 + 128 * j
        seg = x_ch[start:start + 256] * wS
        X_esh[:, j] = _mdct(seg)

    return X_esh


def _unpack_esh(frame_F: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """
    Unpack an ESH spectrum of shape (128, 16) into per-channel arrays of shape (128, 8).

    The mapping is the inverse of the packing used in filter_bank():
        out[:, 2*j]   = left[:, j]
        out[:, 2*j+1] = right[:, j]
    """
    if frame_F.shape != (128, 16):
        raise ValueError("ESH frame_F must have shape (128, 16).")

    left = np.empty((128, 8), dtype=np.float64)
    right = np.empty((128, 8), dtype=np.float64)
    for j in range(8):
        left[:, j] = frame_F[:, 2 * j + 0]
        right[:, j] = frame_F[:, 2 * j + 1]
    return left, right


def _i_filter_bank_esh_channel(X_esh: np.ndarray, win_type: WinType) -> np.ndarray:
    """
    ESH synthesis for one channel.

    Input:
        X_esh: (128, 8) MDCT coefficients for 8 short windows

    Output:
        x_ch: (2048,) time-domain frame contribution (windowed),
        ready for OLA at the caller level.
    """
    if X_esh.shape != (128, 8):
        raise ValueError("X_esh must have shape (128, 8).")

    wS = _short_window(win_type)
    out = np.zeros(2048, dtype=np.float64)

    # Each short IMDCT returns 256 samples. Place them at:
    # start = 448 + 128*j, j = 0..7 (50% overlap)
    for j in range(8):
        seg = _imdct(X_esh[:, j]) * wS  # (256,)
        start = 448 + 128 * j
        out[start:start + 256] += seg

    return out


# -----------------------------------------------------------------------------
# Public Function prototypes (Level 1)
# -----------------------------------------------------------------------------
def SSC(frame_T: FrameT, next_frame_T: FrameT, prev_frame_type: FrameType) -> FrameType:
    """
    Sequence Segmentation Control (SSC).
    Selects and returns the frame type for the current frame (i) based on the input parameters.

    Parameters
    ----------
    frame_T : FrameT
        Current time-domain frame i, stereo, shape (2048, 2).
    next_frame_T : FrameT
        Next time-domain frame (i+1), stereo, shape (2048, 2)
        (used to decide transitions to/from ESH).
    prev_frame_type : FrameType
        Frame type chosen for the previous frame (i-1).

    Returns
    -------
    frame_type : FrameType
        - "OLS" (ONLY_LONG_SEQUENCE)
        - "LSS" (LONG_START_SEQUENCE)
        - "ESH" (EIGHT_SHORT_SEQUENCE)
        - "LPS" (LONG_STOP_SEQUENCE)
    """
    if frame_T.shape != (2048, 2):
        raise ValueError("frame_T must have shape (2048, 2).")
    if next_frame_T.shape != (2048, 2):
        raise ValueError("next_frame_T must have shape (2048, 2).")

    # Detect attacks independently per channel on the next frame.
    attack_l = _detect_attack(next_frame_T[:, 0])
    attack_r = _detect_attack(next_frame_T[:, 1])

    # Decide the per-channel type based on the shared prev_frame_type.
    ft_l = _decide_frame_type(prev_frame_type, attack_l)
    ft_r = _decide_frame_type(prev_frame_type, attack_r)

    # Stereo merge as per Table 1.
    return _stereo_merge(ft_l, ft_r)
def filter_bank(frame_T: FrameT, frame_type: FrameType, win_type: WinType) -> FrameF:
    """
    Filterbank stage (MDCT analysis).

    Parameters
    ----------
    frame_T : FrameT
        Time-domain frame, stereo, shape (2048, 2).
    frame_type : FrameType
        Type of the frame under encoding ("OLS"|"LSS"|"ESH"|"LPS").
    win_type : WinType
        Window type ("KBD" or "SIN") used for the current frame.

    Returns
    -------
    frame_F : FrameF
        Frequency-domain MDCT coefficients:
        - If frame_type in {"OLS","LSS","LPS"}: array of shape (1024, 2)
          containing the MDCT coefficients for both channels.
        - If frame_type == "ESH": contains 8 subframes, each of shape (128, 2),
          placed in columns according to the subframe order, i.e. overall shape (128, 16).
    """
    if frame_T.shape != (2048, 2):
        raise ValueError("frame_T must have shape (2048, 2).")

    xL = frame_T[:, 0].astype(np.float64, copy=False)
    xR = frame_T[:, 1].astype(np.float64, copy=False)

    if frame_type in ("OLS", "LSS", "LPS"):
        w = _window_sequence(frame_type, win_type)  # length 2048
        XL = _mdct(xL * w)  # length 1024
        XR = _mdct(xR * w)  # length 1024
        out = np.empty((1024, 2), dtype=np.float64)
        out[:, 0] = XL
        out[:, 1] = XR
        return out

    if frame_type == "ESH":
        Xl = _filter_bank_esh_channel(xL, win_type)  # (128, 8)
        Xr = _filter_bank_esh_channel(xR, win_type)  # (128, 8)

        # Pack into (128, 16): each subframe as (128, 2) placed in columns
        out = np.empty((128, 16), dtype=np.float64)
        for j in range(8):
            out[:, 2 * j + 0] = Xl[:, j]
            out[:, 2 * j + 1] = Xr[:, j]
        return out

    raise ValueError(f"Invalid frame_type: {frame_type!r}")
def i_filter_bank(frame_F: FrameF, frame_type: FrameType, win_type: WinType) -> FrameT:
    """
    Inverse filterbank (IMDCT synthesis).

    Parameters
    ----------
    frame_F : FrameF
        Frequency-domain MDCT coefficients as produced by filter_bank().
    frame_type : FrameType
        Frame type ("OLS"|"LSS"|"ESH"|"LPS").
    win_type : WinType
        Window type ("KBD" or "SIN").

    Returns
    -------
    frame_T : FrameT
        Reconstructed time-domain frame, stereo, shape (2048, 2).
    """
    if frame_type in ("OLS", "LSS", "LPS"):
        if frame_F.shape != (1024, 2):
            raise ValueError("For OLS/LSS/LPS, frame_F must have shape (1024, 2).")

        w = _window_sequence(frame_type, win_type)

        xL = _imdct(frame_F[:, 0]) * w
        xR = _imdct(frame_F[:, 1]) * w

        out = np.empty((2048, 2), dtype=np.float64)
        out[:, 0] = xL
        out[:, 1] = xR
        return out

    if frame_type == "ESH":
        if frame_F.shape != (128, 16):
            raise ValueError("For ESH, frame_F must have shape (128, 16).")

        Xl, Xr = _unpack_esh(frame_F)
        xL = _i_filter_bank_esh_channel(Xl, win_type)
        xR = _i_filter_bank_esh_channel(Xr, win_type)

        out = np.empty((2048, 2), dtype=np.float64)
        out[:, 0] = xL
        out[:, 1] = xR
        return out

    raise ValueError(f"Invalid frame_type: {frame_type!r}")
def aac_coder_1(filename_in: Union[str, Path]) -> AACSeq1:
    """
    Level-1 AAC encoder.

    Parameters
    ----------
    filename_in : str | Path
        Input WAV filename.
        Assumption: stereo audio, sampling rate 48 kHz.

    Returns
    -------
    aac_seq_1 : AACSeq1
        List of K encoded frames.
        For each i:

        - aac_seq_1[i]["frame_type"]: FrameType
        - aac_seq_1[i]["win_type"]: WinType
        - aac_seq_1[i]["chl"]["frame_F"]:
            - ESH: shape (128, 8)
            - else: shape (1024, 1)
        - aac_seq_1[i]["chr"]["frame_F"]:
            - ESH: shape (128, 8)
            - else: shape (1024, 1)
    """
    filename_in = Path(filename_in)

    x, fs = sf.read(str(filename_in), always_2d=True)
    x = np.asarray(x, dtype=np.float64)

    if x.shape[1] != 2:
        raise ValueError("Input must be stereo (2 channels).")
    if fs != 48000:
        raise ValueError("Input sampling rate must be 48 kHz.")

    hop = 1024
    win = 2048

    # Pad at the beginning to support the first overlap region.
    # Tail padding is kept minimal; the next frame is padded on the fly when needed.
    pad_pre = np.zeros((hop, 2), dtype=np.float64)
    pad_post = np.zeros((hop, 2), dtype=np.float64)
    x_pad = np.vstack([pad_pre, x, pad_post])

    # Number of frames such that the current frame fits; the next frame is padded if needed.
    K = int((x_pad.shape[0] - win) // hop + 1)
    if K <= 0:
        raise ValueError("Input too short for framing.")

    aac_seq: AACSeq1 = []
    prev_frame_type: FrameType = "OLS"

    for i in range(K):
        start = i * hop

        frame_t: FrameT = x_pad[start:start + win, :]
        if frame_t.shape != (win, 2):
            # This should not happen given the definition of K, but keep it explicit.
            raise ValueError("Internal framing error: frame_t has wrong shape.")

        next_t = x_pad[start + hop:start + hop + win, :]

        # Ensure next_t is always (2048, 2) by zero-padding at the tail.
        if next_t.shape[0] < win:
            tail = np.zeros((win - next_t.shape[0], 2), dtype=np.float64)
            next_t = np.vstack([next_t, tail])

        frame_type = SSC(frame_t, next_t, prev_frame_type)
        frame_f = filter_bank(frame_t, frame_type, WIN_TYPE)

        # Store per channel as required by the AACSeq1 schema
        if frame_type == "ESH":
            # frame_f: (128, 16) packed as [L0 R0 L1 R1 ... L7 R7]
            chl_f = np.empty((128, 8), dtype=np.float64)
            chr_f = np.empty((128, 8), dtype=np.float64)
            for j in range(8):
                chl_f[:, j] = frame_f[:, 2 * j + 0]
                chr_f[:, j] = frame_f[:, 2 * j + 1]
        else:
            # frame_f: (1024, 2)
            chl_f = frame_f[:, 0:1].astype(np.float64, copy=False)
            chr_f = frame_f[:, 1:2].astype(np.float64, copy=False)

        aac_seq.append({
            "frame_type": frame_type,
            "win_type": WIN_TYPE,
            "chl": {"frame_F": chl_f},
            "chr": {"frame_F": chr_f},
        })
        prev_frame_type = frame_type
    return aac_seq
def i_aac_coder_1(aac_seq_1: AACSeq1, filename_out: Union[str, Path]) -> np.ndarray:
|
||||
"""
|
    Level-1 AAC decoder (inverse of aac_coder_1()).

    Parameters
    ----------
    aac_seq_1 : AACSeq1
        Encoded sequence as produced by aac_coder_1().
    filename_out : str | Path
        Output WAV filename.
        Assumption: stereo audio, sampling rate 48 kHz.

    Returns
    -------
    x : np.ndarray
        Decoded audio samples (time-domain).
        Expected shape: (N, 2) for stereo (N depends on input length).
    """
    filename_out = Path(filename_out)

    hop = 1024
    win = 2048
    K = len(aac_seq_1)

    # The output includes the encoder's padding region, so reconstruct the full
    # padded stream. For K frames the last frame starts at (K-1)*hop and spans
    # win samples, so the total length is (K-1)*hop + win.
    n_pad = (K - 1) * hop + win
    y_pad = np.zeros((n_pad, 2), dtype=np.float64)

    for i, fr in enumerate(aac_seq_1):
        frame_type = fr["frame_type"]
        win_type = fr["win_type"]

        chl_f = np.asarray(fr["chl"]["frame_F"], dtype=np.float64)
        chr_f = np.asarray(fr["chr"]["frame_F"], dtype=np.float64)

        # Re-pack into the format expected by i_filter_bank()
        if frame_type == "ESH":
            if chl_f.shape != (128, 8) or chr_f.shape != (128, 8):
                raise ValueError("ESH channel frame_F must have shape (128, 8).")

            frame_f = np.empty((128, 16), dtype=np.float64)
            for j in range(8):
                frame_f[:, 2 * j + 0] = chl_f[:, j]
                frame_f[:, 2 * j + 1] = chr_f[:, j]
        else:
            if chl_f.shape != (1024, 1) or chr_f.shape != (1024, 1):
                raise ValueError("Non-ESH channel frame_F must have shape (1024, 1).")

            frame_f = np.empty((1024, 2), dtype=np.float64)
            frame_f[:, 0] = chl_f[:, 0]
            frame_f[:, 1] = chr_f[:, 0]

        frame_t_hat = i_filter_bank(frame_f, frame_type, win_type)  # (2048, 2)

        start = i * hop
        y_pad[start:start + win, :] += frame_t_hat

    # Remove the boundary padding added by the encoder: hop samples at the
    # start and hop samples at the end.
    if y_pad.shape[0] < 2 * hop:
        raise ValueError("Decoded stream too short to unpad.")

    y = y_pad[hop:-hop, :]

    sf.write(str(filename_out), y, 48000)
    return y

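As a sanity check on the overlap-add step above: the long sine window satisfies the Princen-Bradley condition, so with the window applied once at analysis and once at synthesis, the squared window must overlap-add to exactly 1 at hop = 1024. A minimal standalone sketch (independent of i_filter_bank; frame count K = 8 is arbitrary):

```python
import numpy as np

N, hop = 2048, 1024
w = np.sin(np.pi * (np.arange(N) + 0.5) / N)  # long sine window

K = 8  # number of overlapping frames
acc = np.zeros(hop * (K - 1) + N)
for i in range(K):
    # Windowed at analysis and again at synthesis => each frame contributes w**2.
    acc[i * hop:i * hop + N] += w ** 2

# In the interior, sin^2(t) + sin^2(t + pi/2) = 1, so the sum is exactly 1.
interior = acc[hop:-hop]
```

The first and last hop samples receive only one window contribution, which is why the decoder trims hop samples from each end.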

def demo_aac_1(filename_in: Union[str, Path], filename_out: Union[str, Path]) -> float:
    """
    Demonstration for the Level-1 codec.

    Runs:
      - aac_coder_1(filename_in)
      - i_aac_coder_1(aac_seq_1, filename_out)
    and computes the total SNR between the original and decoded audio.

    Parameters
    ----------
    filename_in : str | Path
        Input WAV filename (stereo, 48 kHz).
    filename_out : str | Path
        Output WAV filename (stereo, 48 kHz).

    Returns
    -------
    SNR : float
        Overall Signal-to-Noise Ratio in dB.
    """
    filename_in = Path(filename_in)
    filename_out = Path(filename_out)

    # Read original audio (reference)
    x_ref, fs_ref = sf.read(str(filename_in), always_2d=True)
    x_ref = np.asarray(x_ref, dtype=np.float64)

    # Encode / decode
    aac_seq_1 = aac_coder_1(filename_in)
    x_hat = i_aac_coder_1(aac_seq_1, filename_out)
    x_hat = np.asarray(x_hat, dtype=np.float64)

    # Ensure 2D stereo shape (N, 2)
    if x_hat.ndim == 1:
        x_hat = x_hat.reshape(-1, 1)
    if x_ref.ndim == 1:
        x_ref = x_ref.reshape(-1, 1)

    # Align lengths (use the common overlap)
    n = min(x_ref.shape[0], x_hat.shape[0])
    x_ref = x_ref[:n, :]
    x_hat = x_hat[:n, :]

    # Match channel count conservatively (common channels)
    c = min(x_ref.shape[1], x_hat.shape[1])
    x_ref = x_ref[:, :c]
    x_hat = x_hat[:, :c]

    # Compute overall SNR over all samples and channels
    err = x_ref - x_hat
    p_signal = float(np.sum(x_ref * x_ref))
    p_noise = float(np.sum(err * err))

    if p_noise <= 0.0:
        return float("inf")
    if p_signal <= 0.0:
        # Degenerate case: silent input
        return -float("inf")

    snr_db = 10.0 * np.log10(p_signal / p_noise)
    return float(snr_db)


if __name__ == "__main__":
    # Example usage:
    #   python -m level_1.level_1 input.wav output.wav
    import sys

    if len(sys.argv) != 3:
        raise SystemExit("Usage: python -m level_1.level_1 <input.wav> <output.wav>")

    in_wav = sys.argv[1]
    out_wav = sys.argv[2]

    print(f"Encoding/Decoding {in_wav} to {out_wav}")
    snr = demo_aac_1(in_wav, out_wav)
    print(f"SNR = {snr:.3f} dB")

199
source/level_1/tests/test_SSC.py
Normal file
@ -0,0 +1,199 @@
import numpy as np
import pytest

# Adjust the import based on package/module layout.
from level_1.level_1 import SSC

# Helper "fixtures" for SSC
# -----------------------------------------------------------------------------

def _next_frame_no_attack() -> np.ndarray:
    """
    Build a next_frame_T that should NOT trigger ESH detection.

    Uses exact zeros so all s2l are zero and the ESH condition (s2l > 1e-3) cannot hold.
    """
    return np.zeros((2048, 2), dtype=np.float64)


def _next_frame_strong_attack(
    *,
    attack_left: bool,
    attack_right: bool,
    segment_l: int = 4,
    baseline: float = 1e-6,
    burst_amp: float = 1.0,
) -> np.ndarray:
    """
    Build a next_frame_T (2048x2) that should trigger ESH detection on selected channels.

    Spec: ESH if there exists l in {1..7} with s2l > 1e-3 AND ds2l > 10.
    We create:
      - small baseline energy in all samples (avoids division by zero in ds2l),
      - a strong burst inside one 128-sample segment l in 1..7.
    """
    assert 1 <= segment_l <= 7
    x = np.full((2048, 2), baseline, dtype=np.float64)

    a = segment_l * 128
    b = (segment_l + 1) * 128

    if attack_left:
        x[a:b, 0] += burst_amp
    if attack_right:
        x[a:b, 1] += burst_amp

    return x


def _next_frame_below_s2l_threshold(
    *,
    left: bool,
    right: bool,
    segment_l: int = 4,
    impulse_amp: float = 0.01,
) -> np.ndarray:
    """
    Construct a next_frame_T where s2l is below 1e-3, so ESH must NOT be triggered,
    even if ds2l could be large.

    Put a single impulse of amplitude 'impulse_amp' inside a segment.
    Energy in the 128-sample segment: s2l ~= impulse_amp^2.
    With impulse_amp=0.01 => s2l ~= 1e-4 < 1e-3.
    """
    assert 1 <= segment_l <= 7
    x = np.zeros((2048, 2), dtype=np.float64)

    idx = segment_l * 128 + 10  # inside the segment
    if left:
        x[idx, 0] = impulse_amp
    if right:
        x[idx, 1] = impulse_amp

    return x


# ---------------------------------------------------------------------
# 1) Fixed/mandatory cases (prev frame type forces current type)
# ---------------------------------------------------------------------

def test_ssc_fixed_cases_prev_lss_and_lps() -> None:
    """
    Spec: if prev was:
      - LSS => current MUST be ESH
      - LPS => current MUST be OLS
    independent of the next-frame check.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)

    # Even if the next frame has a strong attack, LSS must force ESH.
    next_attack = _next_frame_strong_attack(attack_left=True, attack_right=True)
    out1 = SSC(frame_t, next_attack, "LSS")
    assert out1 == "ESH"

    # Even if the next frame has a strong attack, LPS must force OLS.
    out2 = SSC(frame_t, next_attack, "LPS")
    assert out2 == "OLS"


# ---------------------------------------------------------------------
# 2) Cases requiring next-frame ESH prediction (energy/attack computation)
# ---------------------------------------------------------------------

def test_prev_ols_next_not_esh_returns_ols() -> None:
    """
    Spec: if prev=OLS, current is OLS or LSS.
    Choose LSS iff frame (i+1) is predicted ESH, else OLS.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    next_t = _next_frame_no_attack()

    out = SSC(frame_t, next_t, "OLS")
    assert out == "OLS"


def test_prev_ols_next_esh_both_channels_returns_lss() -> None:
    """
    prev=OLS, next predicted ESH (both channels) => per-channel decisions are LSS and LSS,
    and the merge table keeps LSS.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    next_t = _next_frame_strong_attack(attack_left=True, attack_right=True)

    out = SSC(frame_t, next_t, "OLS")
    assert out == "LSS"


def test_prev_ols_next_esh_one_channel_returns_lss() -> None:
    """
    prev=OLS:
      - one channel predicts ESH => LSS
      - other channel predicts not ESH => OLS
    Merge table: OLS + LSS => LSS.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)

    next1_t = _next_frame_strong_attack(attack_left=True, attack_right=False)
    out1 = SSC(frame_t, next1_t, "OLS")
    assert out1 == "LSS"

    next2_t = _next_frame_strong_attack(attack_left=False, attack_right=True)
    out2 = SSC(frame_t, next2_t, "OLS")
    assert out2 == "LSS"


def test_prev_esh_next_esh_both_channels_returns_esh() -> None:
    """
    prev=ESH:
      - next predicted ESH => current ESH (per-channel)
    Merge table: ESH + ESH => ESH.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    next_t = _next_frame_strong_attack(attack_left=True, attack_right=True)

    out = SSC(frame_t, next_t, "ESH")
    assert out == "ESH"


def test_prev_esh_next_not_esh_both_channels_returns_lps() -> None:
    """
    prev=ESH:
      - next not predicted ESH => current LPS (per-channel)
    Merge table: LPS + LPS => LPS.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    next_t = _next_frame_no_attack()

    out = SSC(frame_t, next_t, "ESH")
    assert out == "LPS"


def test_prev_esh_next_esh_one_channel_merged_is_esh() -> None:
    """
    prev=ESH:
      - one channel predicts ESH => ESH
      - the other channel predicts not ESH => LPS
    Merge table: ESH + LPS => ESH.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)

    next1_t = _next_frame_strong_attack(attack_left=True, attack_right=False)
    out1 = SSC(frame_t, next1_t, "ESH")
    assert out1 == "ESH"

    # Mirror case: the attack is on the right channel only.
    next2_t = _next_frame_strong_attack(attack_left=False, attack_right=True)
    out2 = SSC(frame_t, next2_t, "ESH")
    assert out2 == "ESH"


def test_threshold_s2l_must_exceed_1e_3() -> None:
    """
    Spec: the next frame is ESH only if s2l > 1e-3 AND ds2l > 10 for some l in 1..7.
    This test checks the necessity of the s2l threshold:
      - Create a frame with s2l ~= 1e-4 < 1e-3 (a single impulse with amplitude 0.01).
      - Expect: not classified as ESH -> for prev=OLS return OLS.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    next_t = _next_frame_below_s2l_threshold(left=True, right=True, impulse_amp=0.01)

    out = SSC(frame_t, next_t, "OLS")
    assert out == "OLS"
123
source/level_1/tests/test_aac_level1.py
Normal file
@ -0,0 +1,123 @@
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
import soundfile as sf

from level_1.level_1 import aac_coder_1, i_aac_coder_1

# Helper "fixtures" for aac_coder_1 / i_aac_coder_1
# -----------------------------------------------------------------------------

def _snr_db(x_ref: np.ndarray, x_hat: np.ndarray) -> float:
    """
    Compute overall SNR (dB) over all samples and channels after aligning lengths.
    """
    x_ref = np.asarray(x_ref, dtype=np.float64)
    x_hat = np.asarray(x_hat, dtype=np.float64)

    if x_ref.ndim == 1:
        x_ref = x_ref.reshape(-1, 1)
    if x_hat.ndim == 1:
        x_hat = x_hat.reshape(-1, 1)

    n = min(x_ref.shape[0], x_hat.shape[0])
    c = min(x_ref.shape[1], x_hat.shape[1])

    x_ref = x_ref[:n, :c]
    x_hat = x_hat[:n, :c]

    err = x_ref - x_hat
    ps = float(np.sum(x_ref * x_ref))
    pn = float(np.sum(err * err))

    if pn <= 0.0:
        return float("inf")
    if ps <= 0.0:
        return -float("inf")

    return float(10.0 * np.log10(ps / pn))


@pytest.fixture()
def tmp_stereo_wav(tmp_path: Path) -> Path:
    """
    Create a temporary 48 kHz stereo WAV with random samples.
    """
    rng = np.random.default_rng(123)
    fs = 48000

    # ~1 second of audio, kept small for test speed
    n = fs
    x = rng.normal(size=(n, 2)).astype(np.float64)

    wav_path = tmp_path / "in.wav"
    sf.write(str(wav_path), x, fs)
    return wav_path


def test_aac_coder_seq_schema_and_shapes(tmp_stereo_wav: Path) -> None:
    """
    Module-level contract test:
    Ensure aac_seq_1 follows the expected schema and per-frame shapes.
    """
    aac_seq = aac_coder_1(tmp_stereo_wav)

    assert isinstance(aac_seq, list)
    assert len(aac_seq) > 0

    for fr in aac_seq:
        assert isinstance(fr, dict)

        # Required keys
        assert "frame_type" in fr
        assert "win_type" in fr
        assert "chl" in fr
        assert "chr" in fr

        frame_type = fr["frame_type"]
        win_type = fr["win_type"]

        assert frame_type in ("OLS", "LSS", "ESH", "LPS")
        assert win_type in ("SIN", "KBD")

        assert isinstance(fr["chl"], dict)
        assert isinstance(fr["chr"], dict)
        assert "frame_F" in fr["chl"]
        assert "frame_F" in fr["chr"]

        chl_f = np.asarray(fr["chl"]["frame_F"])
        chr_f = np.asarray(fr["chr"]["frame_F"])

        if frame_type == "ESH":
            assert chl_f.shape == (128, 8)
            assert chr_f.shape == (128, 8)
        else:
            assert chl_f.shape == (1024, 1)
            assert chr_f.shape == (1024, 1)


def test_end_to_end_aac_coder_decoder_high_snr(tmp_stereo_wav: Path, tmp_path: Path) -> None:
    """
    End-to-end module test:
    Encode + decode and check that the SNR is very high (numerical noise only).
    The threshold is intentionally loose to avoid fragility.
    """
    x_ref, fs = sf.read(str(tmp_stereo_wav), always_2d=True)
    assert fs == 48000

    out_wav = tmp_path / "out.wav"

    aac_seq = aac_coder_1(tmp_stereo_wav)
    x_hat = i_aac_coder_1(aac_seq, out_wav)

    # Basic sanity: the output file exists and is readable
    assert out_wav.exists()
    x_hat_file, fs_hat = sf.read(str(out_wav), always_2d=True)
    assert fs_hat == 48000

    # SNR computed against the array returned by i_aac_coder_1 (should match the file, but not required)
    snr = _snr_db(x_ref, x_hat)
    assert snr > 80.0
235
source/level_1/tests/test_filterbank.py
Normal file
@ -0,0 +1,235 @@
import numpy as np
import pytest

from level_1.level_1 import FrameType, WinType, filter_bank, i_filter_bank

# Helper "fixtures" for filterbank
# -----------------------------------------------------------------------------

def _ola_reconstruct(x: np.ndarray, frame_types: list[str], win_type: str) -> np.ndarray:
    """
    Analyze-synthesize each frame and overlap-add with hop=1024.

    x: shape (N, 2)
    frame_types: length K, for frames starting at i*1024
    """
    hop = 1024
    win = 2048
    K = len(frame_types)

    y = np.zeros_like(x, dtype=np.float64)

    for i in range(K):
        start = i * hop
        frame_t = x[start:start + win, :]
        frame_f = filter_bank(frame_t, frame_types[i], win_type)
        frame_t_hat = i_filter_bank(frame_f, frame_types[i], win_type)
        y[start:start + win, :] += frame_t_hat

    return y


def _snr_db(x: np.ndarray, y: np.ndarray) -> float:
    err = x - y
    ps = float(np.sum(x * x))
    pn = float(np.sum(err * err))
    if pn <= 0.0:
        return float("inf")
    return 10.0 * np.log10(ps / pn)


# ---------------------------------------------------------------------
# Forward filterbank tests
# ---------------------------------------------------------------------

@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
@pytest.mark.parametrize("frame_type", ["OLS", "LSS", "LPS"])
def test_filterbank_shapes_long_sequences(frame_type: FrameType, win_type: WinType) -> None:
    """
    Contract test:
    For OLS/LSS/LPS, filter_bank returns shape (1024, 2).
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    frame_f = filter_bank(frame_t, frame_type, win_type)
    assert frame_f.shape == (1024, 2)


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_filterbank_shapes_esh(win_type: WinType) -> None:
    """
    Contract test:
    For ESH, filter_bank returns shape (128, 16).
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    frame_f = filter_bank(frame_t, "ESH", win_type)
    assert frame_f.shape == (128, 16)


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_filterbank_channel_isolation_long_sequences(win_type: WinType) -> None:
    """
    Module behavior test:
    For OLS (representative long sequence), channels are processed independently:
      - If the right channel is zero and the left is random, the right spectrum should be near zero.
    """
    rng = np.random.default_rng(0)
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    frame_t[:, 0] = rng.normal(size=2048)

    frame_f = filter_bank(frame_t, "OLS", win_type)

    # Right channel output should be (close to) zero
    assert np.max(np.abs(frame_f[:, 1])) < 1e-9


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_filterbank_channel_isolation_esh(win_type: WinType) -> None:
    """
    Module behavior test:
    For ESH, channels are processed independently:
      - If the right channel is zero and the left is random, all odd columns (right) should be near zero.
    """
    rng = np.random.default_rng(1)
    frame_t = np.zeros((2048, 2), dtype=np.float64)
    frame_t[:, 0] = rng.normal(size=2048)

    frame_f = filter_bank(frame_t, "ESH", win_type)

    # The right channel appears in columns 1, 3, 5, ..., 15
    right_cols = frame_f[:, 1::2]
    assert np.max(np.abs(right_cols)) < 1e-9


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_filterbank_esh_ignores_outer_regions(win_type: WinType) -> None:
    """
    Spec-driven behavior test:
    ESH uses only the central 1152 samples (448 to 1599), split into 8 overlapping
    windows of length 256 with 50% overlap.

    Therefore, changing samples outside [448, 1600) must not affect the output.
    """
    rng = np.random.default_rng(2)

    frame_a = np.zeros((2048, 2), dtype=np.float64)
    frame_b = np.zeros((2048, 2), dtype=np.float64)

    # Same central region for both frames
    center = rng.normal(size=(1152, 2))
    frame_a[448:1600, :] = center
    frame_b[448:1600, :] = center

    # Modify only the outer regions of frame_b
    frame_b[0:448, :] = rng.normal(size=(448, 2))
    frame_b[1600:2048, :] = rng.normal(size=(448, 2))

    fa = filter_bank(frame_a, "ESH", win_type)
    fb = filter_bank(frame_b, "ESH", win_type)

    np.testing.assert_allclose(fa, fb, rtol=0.0, atol=0.0)


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_filterbank_output_is_finite(win_type: WinType) -> None:
    """
    Sanity test:
    The output must not contain NaN or inf for representative cases.
    """
    rng = np.random.default_rng(3)
    frame_t = rng.normal(size=(2048, 2)).astype(np.float64)

    for frame_type in ("OLS", "LSS", "ESH", "LPS"):
        frame_f = filter_bank(frame_t, frame_type, win_type)
        assert np.isfinite(frame_f).all()


# ---------------------------------------------------------------------
# Reverse i_filterbank tests
# ---------------------------------------------------------------------

@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_ifilterbank_shapes_long_sequences(win_type: str) -> None:
    frame_f = np.zeros((1024, 2), dtype=np.float64)
    for frame_type in ("OLS", "LSS", "LPS"):
        frame_t = i_filter_bank(frame_f, frame_type, win_type)
        assert frame_t.shape == (2048, 2)


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_ifilterbank_shapes_esh(win_type: str) -> None:
    frame_f = np.zeros((128, 16), dtype=np.float64)
    frame_t = i_filter_bank(frame_f, "ESH", win_type)
    assert frame_t.shape == (2048, 2)


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_roundtrip_per_frame_is_finite(win_type: str) -> None:
    rng = np.random.default_rng(0)
    frame_t = rng.normal(size=(2048, 2)).astype(np.float64)

    for frame_type in ("OLS", "LSS", "ESH", "LPS"):
        frame_f = filter_bank(frame_t, frame_type, win_type)
        frame_t_hat = i_filter_bank(frame_f, frame_type, win_type)
        assert np.isfinite(frame_t_hat).all()


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_ola_reconstruction_ols_high_snr(win_type: str) -> None:
    """
    Core module-level test:
    OLS analysis + synthesis with hop=1024 must reconstruct with high SNR
    in the steady-state region.
    """
    rng = np.random.default_rng(1)

    K = 6
    N = 1024 * (K + 1)
    x = rng.normal(size=(N, 2)).astype(np.float64)

    y = _ola_reconstruct(x, ["OLS"] * K, win_type)

    # Exclude the edges (first and last hop) where full overlap is not available
    a = 1024
    b = N - 1024
    snr = _snr_db(x[a:b, :], y[a:b, :])
    assert snr > 50.0


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_ola_reconstruction_esh_high_snr(win_type: str) -> None:
    """
    ESH analysis + synthesis with hop=1024 must reconstruct with high SNR
    in the steady-state region.
    """
    rng = np.random.default_rng(2)

    K = 6
    N = 1024 * (K + 1)
    x = rng.normal(size=(N, 2)).astype(np.float64)

    y = _ola_reconstruct(x, ["ESH"] * K, win_type)

    a = 1024
    b = N - 1024
    snr = _snr_db(x[a:b, :], y[a:b, :])
    assert snr > 45.0


@pytest.mark.parametrize("win_type", ["SIN", "KBD"])
def test_ola_reconstruction_transition_sequence(win_type: str) -> None:
    """
    Transition sequence test matching the windowing logic:
    OLS -> LSS -> ESH -> LPS -> OLS -> OLS
    """
    rng = np.random.default_rng(3)

    frame_types = ["OLS", "LSS", "ESH", "LPS", "OLS", "OLS"]
    K = len(frame_types)
    N = 1024 * (K + 1)
    x = rng.normal(size=(N, 2)).astype(np.float64)

    y = _ola_reconstruct(x, frame_types, win_type)

    a = 1024
    b = N - 1024
    snr = _snr_db(x[a:b, :], y[a:b, :])
    assert snr > 40.0
102
source/level_1/tests/test_filterbank_internal.py
Normal file
@ -0,0 +1,102 @@
import numpy as np
import pytest

from level_1.level_1 import _imdct, _mdct

# Helper "fixtures" for filterbank internals (MDCT/IMDCT)
# -----------------------------------------------------------------------------

def _assert_allclose(a: np.ndarray, b: np.ndarray, *, rtol: float, atol: float) -> None:
    # Helper for consistent tolerances across tests.
    np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)


def _estimate_gain(y: np.ndarray, x: np.ndarray) -> float:
    """
    Estimate the scalar gain g such that y ~= g*x in the least-squares sense.
    """
    denom = float(np.dot(x, x))
    if denom == 0.0:
        return 0.0
    return float(np.dot(y, x) / denom)


tolerance = 1e-10


@pytest.mark.parametrize("N", [256, 2048])
def test_mdct_imdct_mdct_identity_up_to_gain(N: int) -> None:
    """
    Consistency test in the coefficient domain:
        mdct(imdct(X)) ~= g * X

    For our chosen (non-orthonormal) scaling, g is expected to be close to 2.
    """
    rng = np.random.default_rng(0)
    K = N // 2

    X = rng.normal(size=K).astype(np.float64)
    x = _imdct(X)
    X_hat = _mdct(x)

    g = _estimate_gain(X_hat, X)
    _assert_allclose(X_hat, g * X, rtol=tolerance, atol=tolerance)
    _assert_allclose(np.array([g]), np.array([2.0]), rtol=tolerance, atol=tolerance)


@pytest.mark.parametrize("N", [256, 2048])
def test_mdct_linearity(N: int) -> None:
    """
    Linearity test:
        mdct(a*x + b*y) == a*mdct(x) + b*mdct(y)

    This should hold up to numerical error.
    """
    rng = np.random.default_rng(1)
    x = rng.normal(size=N).astype(np.float64)
    y = rng.normal(size=N).astype(np.float64)

    a = 0.37
    b = -1.12

    left = _mdct(a * x + b * y)
    right = a * _mdct(x) + b * _mdct(y)

    _assert_allclose(left, right, rtol=tolerance, atol=tolerance)


@pytest.mark.parametrize("N", [256, 2048])
def test_imdct_linearity(N: int) -> None:
    """
    Linearity test for the IMDCT:
        imdct(a*X + b*Y) == a*imdct(X) + b*imdct(Y)
    """
    rng = np.random.default_rng(2)
    K = N // 2

    X = rng.normal(size=K).astype(np.float64)
    Y = rng.normal(size=K).astype(np.float64)

    a = -0.5
    b = 2.0

    left = _imdct(a * X + b * Y)
    right = a * _imdct(X) + b * _imdct(Y)

    _assert_allclose(left, right, rtol=tolerance, atol=tolerance)


@pytest.mark.parametrize("N", [256, 2048])
def test_mdct_imdct_outputs_are_finite(N: int) -> None:
    """
    Sanity test: no NaN/inf on random inputs.
    """
    rng = np.random.default_rng(3)
    K = N // 2

    x = rng.normal(size=N).astype(np.float64)
    X = rng.normal(size=K).astype(np.float64)

    X1 = _mdct(x)
    x1 = _imdct(X)

    assert np.isfinite(X1).all()
    assert np.isfinite(x1).all()
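The "identity up to gain" test relies on the fact that the MDCT analysis vectors are mutually orthogonal with equal norm, so synthesis followed by analysis is a pure scalar gain in the coefficient domain. A small matrix sketch of that property, using the plain unnormalized cosine basis (the exact gain in `_mdct`/`_imdct` depends on their internal scaling, so only the proportionality carries over):

```python
import numpy as np

N = 64      # frame length
K = N // 2  # number of MDCT coefficients
n = np.arange(N)
k = np.arange(K)

# Unnormalized MDCT basis: A[k, n] = cos(pi/K * (n + 0.5 + K/2) * (k + 0.5))
A = np.cos(np.pi / K * np.outer(k + 0.5, n + 0.5 + K / 2))

# Gram matrix of the analysis vectors: a scaled identity (scale N/2), so
# analysis after transposed synthesis reproduces X up to one global gain.
M = A @ A.T
```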
BIN
source/material/LicorDeCalandraca.wav
Normal file
Binary file not shown.
BIN
source/material/LicorDeCalandraca_out.wav
Normal file
Binary file not shown.
BIN
source/material/TableB219.mat
Normal file
Binary file not shown.
BIN
source/material/huffCodebooks.mat
Normal file
Binary file not shown.
400
source/material/huff_utils.py
Normal file
@ -0,0 +1,400 @@
@ -0,0 +1,400 @@
|
||||
import numpy as np
|
||||
import scipy.io as sio
|
||||
import os
|
||||
|
||||
# ------------------ LOAD LUT ------------------
|
||||
|
||||
def load_LUT(mat_filename=None):
|
||||
"""
|
||||
Loads the list of Huffman Codebooks (LUTs)
|
||||
|
||||
Returns:
|
||||
huffLUT : list (index 1..11 used, index 0 unused)
|
||||
"""
|
||||
if mat_filename is None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
mat_filename = os.path.join(current_dir, "huffCodebooks.mat")
|
||||
|
||||
mat = sio.loadmat(mat_filename)
|
||||
|
||||
|
||||
huffCodebooks_raw = mat['huffCodebooks'].squeeze()
|
||||
|
||||
huffCodebooks = []
|
||||
for i in range(11):
|
||||
huffCodebooks.append(np.array(huffCodebooks_raw[i]))
|
||||
|
||||
# Build inverse VLC tables
|
||||
invTable = [None] * 11
|
||||
|
||||
for i in range(11):
|
||||
h = huffCodebooks[i][:, 2].astype(int) # column 3
|
||||
hlength = huffCodebooks[i][:, 1].astype(int) # column 2
|
||||
|
||||
hbin = []
|
||||
for j in range(len(h)):
|
||||
hbin.append(format(h[j], f'0{hlength[j]}b'))
|
||||
|
||||
invTable[i] = vlc_table(hbin)
|
||||
|
||||
# Build Huffman LUT dicts
|
||||
huffLUT = [None] * 12 # index 0 unused
|
||||
params = [
|
||||
(4, 1, True),
|
||||
(4, 1, True),
|
||||
(4, 2, False),
|
||||
(4, 2, False),
|
||||
(2, 4, True),
|
||||
(2, 4, True),
|
||||
(2, 7, False),
|
||||
(2, 7, False),
|
||||
(2, 12, False),
|
||||
(2, 12, False),
|
||||
(2, 16, False),
|
||||
]
|
||||
|
||||
for i, (nTupleSize, maxAbs, signed) in enumerate(params, start=1):
|
||||
huffLUT[i] = {
|
||||
'LUT': huffCodebooks[i-1],
|
||||
'invTable': invTable[i-1],
|
||||
'codebook': i,
|
||||
'nTupleSize': nTupleSize,
|
||||
'maxAbsCodeVal': maxAbs,
|
||||
'signedValues': signed
|
||||
}
|
||||
|
||||
return huffLUT
|
||||
|
||||
def vlc_table(code_array):
|
||||
"""
|
||||
codeArray: list of strings, each string is a Huffman codeword (e.g. '0101')
|
||||
returns:
|
||||
h : NumPy array of shape (num_nodes, 3)
|
||||
columns:
|
||||
[ next_if_0 , next_if_1 , symbol_index ]
|
||||
"""
|
||||
h = np.zeros((1, 3), dtype=int)
|
||||
|
||||
for code_index, code in enumerate(code_array, start=1):
|
||||
word = [int(bit) for bit in code]
|
||||
h_index = 0
|
||||
|
||||
for bit in word:
|
||||
k = bit
|
||||
next_node = h[h_index, k]
|
||||
if next_node == 0:
|
||||
h = np.vstack([h, [0, 0, 0]])
|
||||
new_index = h.shape[0] - 1
|
||||
h[h_index, k] = new_index
|
||||
h_index = new_index
|
||||
else:
|
||||
h_index = next_node
|
||||
|
||||
h[h_index, 2] = code_index
|
||||
|
||||
return h
|
||||
|
||||
# ------------------ ENCODE ------------------
|
||||
|
||||
def encode_huff(coeff_sec, huff_LUT_list, force_codebook = None):
|
||||
"""
|
||||
Huffman-encode a sequence of quantized coefficients.
|
||||
|
||||
This function selects the appropriate Huffman codebook based on the
|
||||
maximum absolute value of the input coefficients, encodes the coefficients
|
||||
into a binary Huffman bitstream, and returns both the bitstream and the
|
||||
selected codebook index.
|
||||
|
||||
This is the Python equivalent of the MATLAB `encodeHuff.m` function used
|
||||
in audio/image coding (e.g., scale factor band encoding). The input
|
||||
coefficient sequence is grouped into fixed-size tuples as defined by
|
||||
the chosen Huffman LUT. Zero-padding may be applied internally.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coeff_sec : array_like of int
|
||||
1-D array of quantized integer coefficients to encode.
|
||||
Typically corresponds to a "section" or scale-factor band.
|
||||
|
||||
huff_LUT_list : list
|
||||
List of Huffman lookup-table dictionaries as returned by `loadLUT()`.
|
||||
Index 1..11 correspond to valid Huffman codebooks.
|
||||
Index 0 is unused.
|
||||
|
||||
Returns
|
||||
-------
|
||||
huffSec : str
|
||||
Huffman-encoded bitstream represented as a string of '0' and '1'
|
||||
characters.
|
||||
|
||||
huffCodebook : int
|
||||
Index (1..11) of the Huffman codebook used for encoding.
|
||||
A value of 0 indicates a special all-zero section.
|
||||
"""
    if force_codebook is not None:
        return huff_LUT_code_1(huff_LUT_list[force_codebook], coeff_sec), force_codebook

    maxAbsVal = np.max(np.abs(coeff_sec))

    if maxAbsVal == 0:
        # Special all-zero section: nothing is transmitted
        return huff_LUT_code_0(), 0

    if maxAbsVal in (13, 14, 15):
        huffCodebook = 11
        return huff_LUT_code_1(huff_LUT_list[huffCodebook], coeff_sec), huffCodebook

    if maxAbsVal > 15:
        # Values above 15 require the escape mechanism of codebook 11
        huffCodebook = 11
        return huff_LUT_code_ESC(huff_LUT_list[huffCodebook], coeff_sec), huffCodebook

    # For the remaining ranges two codebooks are applicable: encode with both
    # candidates and keep the shorter bitstream
    if maxAbsVal == 1:
        candidates = (1, 2)
    elif maxAbsVal == 2:
        candidates = (3, 4)
    elif maxAbsVal in (3, 4):
        candidates = (5, 6)
    elif maxAbsVal in (5, 6, 7):
        candidates = (7, 8)
    else:  # 8 <= maxAbsVal <= 12
        candidates = (9, 10)

    huffSec1 = huff_LUT_code_1(huff_LUT_list[candidates[0]], coeff_sec)
    huffSec2 = huff_LUT_code_1(huff_LUT_list[candidates[1]], coeff_sec)
    if len(huffSec1) <= len(huffSec2):
        return huffSec1, candidates[0]
    return huffSec2, candidates[1]
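The codebook selection above can be exercised in isolation. A minimal sketch (the helper name `candidate_codebooks` is illustrative, not part of the source) that maps a section's maximum absolute quantized value to its candidate codebook indices:

```python
def candidate_codebooks(max_abs_val):
    """Candidate Huffman codebook indices for a section whose largest
    absolute quantized value is max_abs_val (mirrors encode_huff_sec)."""
    if max_abs_val == 0:
        return (0,)       # all-zero section, nothing transmitted
    if max_abs_val == 1:
        return (1, 2)
    if max_abs_val == 2:
        return (3, 4)
    if max_abs_val <= 4:
        return (5, 6)
    if max_abs_val <= 7:
        return (7, 8)
    if max_abs_val <= 12:
        return (9, 10)
    return (11,)          # codebook 11; escapes are needed when max_abs_val > 15
```

When two candidates apply, the encoder tries both and keeps the shorter bitstream.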

def huff_LUT_code_1(huff_LUT, coeff_sec):
    """Encode a section with a regular (non-escape) Huffman codebook."""
    LUT = huff_LUT['LUT']
    nTupleSize = huff_LUT['nTupleSize']
    maxAbsCodeVal = huff_LUT['maxAbsCodeVal']
    signedValues = huff_LUT['signedValues']

    numTuples = int(np.ceil(len(coeff_sec) / nTupleSize))

    # Pad the section with zeros up to a whole number of n-tuples
    coeffPad = np.zeros(numTuples * nTupleSize, dtype=int)
    coeffPad[:len(coeff_sec)] = coeff_sec

    if signedValues:
        # Shift signed values into the non-negative index range, so that the
        # zero padding also maps to the codeword for value 0
        coeffPad = coeffPad + maxAbsCodeVal
        base = 2 * maxAbsCodeVal + 1
    else:
        base = maxAbsCodeVal + 1

    huffSec = []
    powers = base ** np.arange(nTupleSize - 1, -1, -1)

    for i in range(numTuples):
        nTuple = coeffPad[i*nTupleSize:(i+1)*nTupleSize]
        # Interpret the tuple magnitudes as base-`base` digits of the LUT index
        huffIndex = int(np.abs(nTuple) @ powers)

        hexVal = LUT[huffIndex, 2]
        huffLen = LUT[huffIndex, 1]
        bits = format(int(hexVal), f'0{int(huffLen)}b')

        if signedValues:
            huffSec.append(bits)
        else:
            # Unsigned codebooks transmit one sign bit per coefficient
            signBits = ''.join('1' if v < 0 else '0' for v in nTuple)
            huffSec.append(bits + signBits)

    return ''.join(huffSec)
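The tuple-to-index mapping used above is plain base-B positional arithmetic. A small worked sketch, with parameters chosen to match a hypothetical signed codebook with `nTupleSize = 4` and `maxAbsCodeVal = 1` (as in AAC codebooks 1 and 2):

```python
import numpy as np

nTupleSize, maxAbsCodeVal = 4, 1
base = 2 * maxAbsCodeVal + 1                        # signed: values -1..1 map to digits 0..2
powers = base ** np.arange(nTupleSize - 1, -1, -1)  # [27, 9, 3, 1]

nTuple = np.array([-1, 0, 1, 0])                    # one quantized 4-tuple
shifted = nTuple + maxAbsCodeVal                    # [0, 1, 2, 1]
huffIndex = int(shifted @ powers)                   # 0*27 + 1*9 + 2*3 + 1*1 = 16
```

`huffIndex` then selects the row of the codebook's LUT holding the codeword value and its bit length.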

def huff_LUT_code_0():
    # Codebook 0: all-zero section, no bits are transmitted
    return ''

def huff_LUT_code_ESC(huff_LUT, coeff_sec):
    """Encode a section with codebook 11, using escape sequences for |v| > 15."""
    LUT = huff_LUT['LUT']
    nTupleSize = huff_LUT['nTupleSize']
    maxAbsCodeVal = huff_LUT['maxAbsCodeVal']

    numTuples = int(np.ceil(len(coeff_sec) / nTupleSize))
    base = maxAbsCodeVal + 1

    coeffPad = np.zeros(numTuples * nTupleSize, dtype=int)
    coeffPad[:len(coeff_sec)] = coeff_sec

    huffSec = []
    powers = base ** np.arange(nTupleSize - 1, -1, -1)

    for i in range(numTuples):
        nTuple = coeffPad[i*nTupleSize:(i+1)*nTupleSize]

        # Avoid log2(0); zeros never trigger an escape anyway
        lnTuple = nTuple.astype(float)
        lnTuple[lnTuple == 0] = np.finfo(float).eps

        N4 = np.maximum(0, np.floor(np.log2(np.abs(lnTuple))).astype(int))
        N = np.maximum(0, N4 - 4)
        esc = np.abs(nTuple) > 15

        # Escaped magnitudes are clamped to 16 inside the codeword itself
        nTupleESC = nTuple.copy()
        nTupleESC[esc] = np.sign(nTupleESC[esc]) * 16

        huffIndex = int(np.abs(nTupleESC) @ powers)

        hexVal = LUT[huffIndex, 2]
        huffLen = LUT[huffIndex, 1]
        bits = format(int(hexVal), f'0{int(huffLen)}b')

        # Each escaped value appends: N ones, a '0' separator, then the
        # magnitude remainder |v| - 2^N4 in N4 = N + 4 bits
        escSeq = ''
        for k in range(nTupleSize):
            if esc[k]:
                escSeq += '1' * N[k]
                escSeq += '0'
                escSeq += format(abs(nTuple[k]) - (1 << N4[k]), f'0{N4[k]}b')

        signBits = ''.join('1' if v < 0 else '0' for v in nTuple)
        huffSec.append(bits + signBits + escSeq)

    return ''.join(huffSec)
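The escape bit pattern built above can be checked by hand. A standalone sketch of the same rule (function name illustrative): N ones, a '0' separator, then the remainder `|v| - 2^N4` written in `N4 = floor(log2 |v|)` bits, where `N = N4 - 4`:

```python
import math

def escape_bits(v):
    """Escape sequence for a magnitude |v| > 15 (the sign bit is sent separately)."""
    mag = abs(v)
    N4 = int(math.floor(math.log2(mag)))   # number of remainder bits
    N = N4 - 4                             # length of the run of ones
    return '1' * N + '0' + format(mag - (1 << N4), f'0{N4}b')
```

For example, 20 has `N4 = 4`, so no leading ones are needed: separator '0' followed by the 4-bit remainder 0100.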

# ------------------ DECODE ------------------

def decode_huff(huff_sec, huff_LUT):
    """
    Decode a Huffman-encoded stream.

    Parameters
    ----------
    huff_sec : str or array-like of int
        Huffman-encoded stream given as a sequence of 0/1 values (string or
        list/array).
    huff_LUT : dict
        Huffman lookup table with keys:
        - 'invTable': inverse table (numpy array)
        - 'codebook': codebook number
        - 'nTupleSize': tuple size
        - 'maxAbsCodeVal': maximum absolute code value
        - 'signedValues': True/False

    Returns
    -------
    decCoeffs : list of int
        Decoded quantized coefficients.
    """
    h = huff_LUT['invTable']
    huffCodebook = huff_LUT['codebook']
    nTupleSize = huff_LUT['nTupleSize']
    maxAbsCodeVal = huff_LUT['maxAbsCodeVal']
    signedValues = huff_LUT['signedValues']

    # Convert a bit string to an array of ints
    if isinstance(huff_sec, str):
        huff_sec = np.array([int(b) for b in huff_sec])

    decCoeffs = []
    streamIndex = 0

    while streamIndex < len(huff_sec):
        # Walk the inverse table bit by bit until a leaf is reached
        wordbit = 0
        r = 0  # start at the root
        while True:
            b = huff_sec[streamIndex + wordbit]
            wordbit += 1
            r = h[r, b]
            if h[r, 0] == 0 and h[r, 1] == 0:
                symbolIndex = h[r, 2] - 1  # zero-based
                streamIndex += wordbit
                break

        # Recover the n-tuple from the symbol index (base-B digit expansion)
        if signedValues:
            base = 2 * maxAbsCodeVal + 1
            offset = maxAbsCodeVal
        else:
            base = maxAbsCodeVal + 1
            offset = 0
        nTupleDec = []
        tmp = symbolIndex
        for p in reversed(range(nTupleSize)):
            nTupleDec.append(tmp // (base ** p) - offset)
            tmp = tmp % (base ** p)
        nTupleDec = np.array(nTupleDec)

        # Sign bits are only transmitted by the unsigned codebooks
        if not signedValues:
            nTupleSignBits = huff_sec[streamIndex:streamIndex + nTupleSize]
            nTupleSign = np.where(nTupleSignBits == 1, -1, 1)
            streamIndex += nTupleSize
            nTupleDec = nTupleDec * nTupleSign

            # Codebook 11 marks escaped values with magnitude 16
            if huffCodebook == 11:
                for idx in np.where(np.abs(nTupleDec) == 16)[0]:
                    N = 0
                    while huff_sec[streamIndex + N]:
                        N += 1
                    streamIndex += N + 1  # skip the run of ones and the '0' separator
                    N4 = N + 4
                    escape_word = huff_sec[streamIndex:streamIndex + N4]
                    escape_value = 2 ** N4 + int(''.join(map(str, escape_word)), 2)
                    nTupleDec[idx] = nTupleSign[idx] * escape_value
                    streamIndex += N4

        decCoeffs.extend(nTupleDec.tolist())

    return decCoeffs

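The escape parsing in decode_huff can likewise be sketched as a standalone inverse of the encoder's rule (function name illustrative): count the run of ones, skip the '0' separator, then read `N + 4` remainder bits and add back `2^(N+4)`:

```python
def parse_escape(bits, pos=0):
    """Parse one escape sequence starting at bits[pos]; return (magnitude, new_pos)."""
    N = 0
    while bits[pos + N] == '1':
        N += 1
    pos += N + 1                            # skip the run of ones and the '0' separator
    N4 = N + 4
    value = (1 << N4) + int(bits[pos:pos + N4], 2)
    return value, pos + N4
```

This is the exact inverse of the escape pattern the encoder emits, so encoding a magnitude and parsing the resulting bits round-trips.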
5
source/requirements.txt
Normal file
@ -0,0 +1,5 @@
numpy
scipy
scipy-stubs
soundfile
pytest