238 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - Level 3 Wrappers + Demo
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# Level 3 wrapper module.
#
# This file provides:
#   - Thin wrappers for Level 3 API functions (encode/decode) that delegate
#     to the corresponding core implementations.
#   - A demo function that runs end-to-end and computes:
#       * SNR
#       * bitrate (coded)
#       * compression ratio
#   - A small CLI entrypoint for convenience.
# ------------------------------------------------------------
from __future__ import annotations
from pathlib import Path
from typing import Optional, Tuple, Union
import os
import soundfile as sf
from core.aac_types import AACSeq3, StereoSignal
from core.aac_coder import aac_coder_3 as core_aac_coder_3
from core.aac_coder import aac_read_wav_stereo_48k
from core.aac_decoder import aac_decoder_3 as core_aac_decoder_3
from core.aac_utils import snr_db
# -----------------------------------------------------------------------------
# Helpers (Level 3 metrics)
# -----------------------------------------------------------------------------
def _wav_duration_seconds(wav_path: Path) -> float:
    """
    Compute the length of a WAV file in seconds from soundfile metadata.

    Parameters
    ----------
    wav_path : Path
        Location of the WAV file to inspect.

    Returns
    -------
    float
        Reported frame count divided by the reported sample rate.

    Raises
    ------
    ValueError
        If the reported sample rate is not positive, or the reported frame
        count is negative.
    """
    meta = sf.info(str(wav_path))

    rate = meta.samplerate
    if rate <= 0:
        raise ValueError("Invalid samplerate in WAV header.")

    n_frames = meta.frames
    if n_frames < 0:
        raise ValueError("Invalid frame count in WAV header.")

    return float(n_frames) / float(rate)
def _bitrate_before_from_file(wav_path: Path) -> float:
    """
    Estimate the input bitrate (bits/s) as on-disk size over duration.

    Note:
        The size is that of the whole file, so the WAV header is counted
        too. This is accepted here because the value only feeds a simple
        compression-ratio metric.

    Parameters
    ----------
    wav_path : Path
        Location of the input WAV file.

    Returns
    -------
    float
        File size in bits divided by the file duration in seconds.

    Raises
    ------
    ValueError
        If the duration obtained for the file is not positive.
    """
    seconds = _wav_duration_seconds(wav_path)
    if seconds <= 0.0:
        raise ValueError("Non-positive WAV duration.")

    size_bytes = os.path.getsize(wav_path)
    size_bits = float(size_bytes) * 8.0
    return size_bits / seconds
def _bitrate_after_from_aacseq(aac_seq_3: AACSeq3, duration_sec: float) -> float:
"""
Compute coded bitrate (bits/s) from Huffman streams stored in AACSeq3.
We count bits from:
- scalefactor Huffman bitstream ("sfc")
- MDCT symbols Huffman bitstream ("stream")
for both channels and all frames.
Note:
We intentionally ignore side-info overhead (frame_type, G, T, TNS coeffs,
codebook ids, etc.). This matches a common simplified metric in demos.
"""
if duration_sec <= 0.0:
raise ValueError("Non-positive duration for bitrate computation.")
total_bits = 0
for fr in aac_seq_3:
total_bits += len(fr["chl"]["sfc"])
total_bits += len(fr["chl"]["stream"])
total_bits += len(fr["chr"]["sfc"])
total_bits += len(fr["chr"]["stream"])
return float(total_bits) / float(duration_sec)
# -----------------------------------------------------------------------------
# Public Level 3 API (wrappers)
# -----------------------------------------------------------------------------
def aac_coder_3(
    filename_in: Union[str, Path],
    filename_aac_coded: Optional[Union[str, Path]] = None,
) -> AACSeq3:
    """
    Encode a WAV file with the Level-3 AAC coder.

    This is a thin wrapper: both arguments are forwarded unchanged to the
    core encoder, which is always invoked with verbose=True.

    Parameters
    ----------
    filename_in : Union[str, Path]
        Input WAV filename.
        Assumption: stereo audio, sampling rate 48 kHz.
    filename_aac_coded : Optional[Union[str, Path]]
        Optional filename to store the encoded AAC sequence (e.g., .mat).

    Returns
    -------
    AACSeq3
        List of encoded frames (Level 3 schema), as returned by the core
        encoder.
    """
    encoded_frames = core_aac_coder_3(filename_in, filename_aac_coded, verbose=True)
    return encoded_frames
def i_aac_coder_3(
    aac_seq_3: AACSeq3,
    filename_out: Union[str, Path],
) -> StereoSignal:
    """
    Decode a Level-3 AAC sequence back to audio.

    This is a thin wrapper: both arguments are forwarded unchanged to the
    core decoder, which is always invoked with verbose=True.

    Parameters
    ----------
    aac_seq_3 : AACSeq3
        Encoded sequence as produced by aac_coder_3().
    filename_out : Union[str, Path]
        Output WAV filename. Assumption: 48 kHz, stereo.

    Returns
    -------
    StereoSignal
        Decoded audio samples (time-domain), stereo, shape (N, 2),
        dtype float64, as returned by the core decoder.
    """
    decoded_audio = core_aac_decoder_3(aac_seq_3, filename_out, verbose=True)
    return decoded_audio
# -----------------------------------------------------------------------------
# Demo (Level 3)
# -----------------------------------------------------------------------------
def demo_aac_3(
    filename_in: Union[str, Path],
    filename_out: Union[str, Path],
    filename_aac_coded: Optional[Union[str, Path]] = None,
) -> Tuple[float, float, float]:
    """
    Demonstration for the Level-3 codec.

    Runs:
        - aac_coder_3(filename_in, filename_aac_coded)
        - i_aac_coder_3(aac_seq_3, filename_out)
    and computes:
        - total SNR between original and decoded audio
        - coded bitrate (bits/s) based on Huffman streams
        - compression ratio (bitrate_before / bitrate_after)

    Parameters
    ----------
    filename_in : Union[str, Path]
        Input WAV filename (stereo, 48 kHz).
    filename_out : Union[str, Path]
        Output WAV filename (stereo, 48 kHz).
    filename_aac_coded : Optional[Union[str, Path]]
        Optional filename to store the encoded AAC sequence (e.g., .mat).
        Any falsy value (None, empty string) means "do not store".

    Returns
    -------
    Tuple[float, float, float]
        (SNR_dB, bitrate_after_bits_per_s, compression_ratio).
        compression_ratio is +inf when the coded bitrate is not positive.

    Raises
    ------
    ValueError
        If the input or the decoded output is not sampled at 48 kHz, or if
        the input duration is not positive.
    """
    filename_in = Path(filename_in)
    filename_out = Path(filename_out)
    filename_aac_coded = Path(filename_aac_coded) if filename_aac_coded else None

    # Read original audio (reference) with the same reader the codec uses.
    x_ref, fs_ref = aac_read_wav_stereo_48k(filename_in)
    if int(fs_ref) != 48000:
        raise ValueError("Input sampling rate must be 48 kHz.")

    # Encode / decode
    aac_seq_3 = aac_coder_3(filename_in, filename_aac_coded)
    x_hat = i_aac_coder_3(aac_seq_3, filename_out)

    # Sanity check on the written file. Only the sample rate is needed, so
    # query the header metadata instead of decoding every sample into memory
    # and throwing the array away.
    fs_hat = sf.info(str(filename_out)).samplerate
    if int(fs_hat) != 48000:
        raise ValueError("Decoded output sampling rate must be 48 kHz.")

    # Metrics
    s = snr_db(x_ref, x_hat)
    duration = _wav_duration_seconds(filename_in)
    bitrate_before = _bitrate_before_from_file(filename_in)
    bitrate_after = _bitrate_after_from_aacseq(aac_seq_3, duration)

    # Guard the division: an empty coded stream yields a zero bitrate.
    if bitrate_after <= 0.0:
        compression = float("inf")
    else:
        compression = bitrate_before / bitrate_after

    return float(s), float(bitrate_after), float(compression)
# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: encode + decode one file and print metrics.
    #
    # Example:
    #   cd level_3
    #   python -m level_3 input.wav output.wav
    # for example:
    #   python -m level_3 material/LicorDeCalandraca.wav LicorDeCalandraca_out_l3.wav
    # or
    #   python -m level_3 material/LicorDeCalandraca.wav LicorDeCalandraca_out_l3.wav aac_seq_3.mat
    import sys

    # Positional arguments only: <input.wav> <output.wav> [aac_seq_3.mat]
    cli_args = sys.argv[1:]
    if len(cli_args) not in (2, 3):
        raise SystemExit("Usage: python -m level_3 <input.wav> <output.wav> [aac_seq_3.mat]")

    in_wav = Path(cli_args[0])
    out_wav = Path(cli_args[1])
    if len(cli_args) == 3:
        aac_mat = Path(cli_args[2])
    else:
        aac_mat = None

    print(f"Encoding/Decoding {in_wav} to {out_wav}")
    if aac_mat is not None:
        print(f"Storing coded sequence to {aac_mat}")

    snr, bitrate, compression = demo_aac_3(in_wav, out_wav, aac_mat)

    print(f"SNR = {snr:.3f} dB")
    print(f"Bitrate (coded) = {bitrate:.2f} bits/s")
    print(f"Compression ratio = {compression:.4f}")