328 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - Level 3 Wrappers + Demo
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# Level 3 wrapper module.
#
# This file provides:
# - Thin wrappers for Level 3 API functions (encode/decode) that delegate
# to the corresponding core implementations.
# - A demo function that runs end-to-end and computes:
# * SNR
# * bitrate (coded)
# * compression ratio
# - A small CLI entrypoint for convenience.
# ------------------------------------------------------------
from __future__ import annotations
from pathlib import Path
from typing import Optional, Tuple, Union
import os
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
from core.aac_types import AACSeq3, StereoSignal
from core.aac_coder import aac_coder_3 as core_aac_coder_3
from core.aac_coder import aac_read_wav_stereo_48k
from core.aac_decoder import aac_decoder_3 as core_aac_decoder_3
from core.aac_utils import snr_db, estimate_lag_mono, match_gain
# Module-level handle to the most recent Level 3 coded sequence.
# demo_aac_3() assigns it via `global` so that callers (e.g. the CLI below,
# which uses it for per-frame plots) can reuse the coded sequence without
# changing demo_aac_3()'s return signature.
# NOTE: this line is only an annotation, not an assignment: the name is
# unbound (NameError on access) until demo_aac_3() has run at least once.
AAC_Seq_3: AACSeq3
# -----------------------------------------------------------------------------
# Helpers (Level 3 metrics)
# -----------------------------------------------------------------------------
def _wav_duration_seconds(wav_path: Path) -> float:
    """
    Duration of a WAV file in seconds, computed from its header.

    Parameters
    ----------
    wav_path : Path
        WAV file to inspect. Only metadata is read (via ``sf.info``).

    Returns
    -------
    float
        Number of frames divided by the sampling rate.

    Raises
    ------
    ValueError
        If the header reports a non-positive samplerate or a negative
        frame count.
    """
    meta = sf.info(str(wav_path))
    rate = meta.samplerate
    if rate <= 0:
        raise ValueError("Invalid samplerate in WAV header.")
    n_frames = meta.frames
    if n_frames < 0:
        raise ValueError("Invalid frame count in WAV header.")
    seconds = float(n_frames) / float(rate)
    return seconds
def _bitrate_before_from_file(wav_path: Path) -> float:
    """
    Estimate the uncompressed input bitrate (bits/s) of a WAV file.

    The estimate is the total on-disk file size in bits divided by the audio
    duration. Because the whole file is counted, the WAV header is included.
    That is acceptable for a coarse compression-ratio metric.

    Parameters
    ----------
    wav_path : Path
        WAV file to measure.

    Returns
    -------
    float
        File size in bits per second of audio.

    Raises
    ------
    ValueError
        If the duration is not strictly positive. ``_wav_duration_seconds``
        also raises ValueError for an invalid header.
    """
    seconds = _wav_duration_seconds(wav_path)
    if seconds <= 0.0:
        raise ValueError("Non-positive WAV duration.")
    size_in_bits = float(os.path.getsize(wav_path)) * 8.0
    bits_per_second = size_in_bits / seconds
    return bits_per_second
def _bitrate_after_from_aacseq(aac_seq_3: AACSeq3, duration_sec: float) -> float:
"""
Compute coded bitrate (bits/s) from Huffman streams stored in AACSeq3.
We count bits from:
- scalefactor Huffman bitstream ("sfc")
- MDCT symbols Huffman bitstream ("stream")
for both channels and all frames.
Note:
We intentionally ignore side-info overhead (frame_type, G, T, TNS coeffs,
codebook ids, etc.). This matches a common simplified metric in demos.
"""
if duration_sec <= 0.0:
raise ValueError("Non-positive duration for bitrate computation.")
total_bits = 0
for fr in aac_seq_3:
total_bits += len(fr["chl"]["sfc"])
total_bits += len(fr["chl"]["stream"])
total_bits += len(fr["chr"]["sfc"])
total_bits += len(fr["chr"]["stream"])
return float(total_bits) / float(duration_sec)
def _save_per_frame_plot(
    frame_indices: np.ndarray,
    values: list,
    ylabel: str,
    title: str,
    fname: Union[str, Path],
) -> None:
    """
    Draw one per-frame line plot and save it to ``fname``.

    The figure is always closed, even if saving fails, so repeated calls
    do not accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(6, 3), dpi=300)
    try:
        ax.plot(frame_indices, values)
        ax.set_xlabel("Frame index")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(str(fname))
    finally:
        plt.close(fig)


def _plot_frame_bitrate_and_compression(
    aac_seq_3: AACSeq3,
    wav_path: Union[str, Path],
    fname_bitrate: Union[str, Path],
    fname_comp: Union[str, Path],
) -> None:
    """
    Compute and plot per-frame bitrate and compression ratio
    for a Level 3 AAC sequence.

    Per-frame coded bits are the lengths of the "sfc" and "stream" Huffman
    bitstreams of both channels. Each frame is assumed to carry 1024 new
    samples (the AAC long-frame hop size).

    Parameters
    ----------
    aac_seq_3 : list
        Output of aac_coder_3 (list of frame dictionaries).
    wav_path : str or Path
        Path to the original WAV file (PCM 48 kHz stereo). It is used for
        the file-based original bitrate and for the sampling rate.
    fname_bitrate : str or Path
        Output image file for the per-frame bitrate plot.
    fname_comp : str or Path
        Output image file for the per-frame compression-ratio plot.

    Raises
    ------
    ValueError
        If the WAV header is invalid or the file has zero duration. This
        replaces a bare ZeroDivisionError.
    """
    wav_path = Path(wav_path)

    # Validated file-based estimate (file size in bits / duration). It checks
    # samplerate and duration before any division happens below.
    original_bitrate = _bitrate_before_from_file(wav_path)
    samplerate = sf.info(str(wav_path)).samplerate

    # AAC long-frame hop size is 1024 new samples per frame.
    samples_per_frame = 1024
    duration_per_frame = samples_per_frame / samplerate

    frame_bitrates = []
    frame_compression = []
    for fr in aac_seq_3:
        bits = (
            len(fr["chl"]["sfc"])
            + len(fr["chl"]["stream"])
            + len(fr["chr"]["sfc"])
            + len(fr["chr"]["stream"])
        )
        bitrate = bits / duration_per_frame
        frame_bitrates.append(bitrate)
        # A frame with zero coded bits has an unbounded compression ratio.
        frame_compression.append(
            original_bitrate / bitrate if bitrate > 0 else np.inf
        )

    frame_indices = np.arange(len(frame_bitrates))
    _save_per_frame_plot(
        frame_indices,
        frame_bitrates,
        "Bitrate (bits/s)",
        "Bitrate (per-frame)",
        fname_bitrate,
    )
    _save_per_frame_plot(
        frame_indices,
        frame_compression,
        "Compression Ratio",
        "Compression Ratio (per-frame)",
        fname_comp,
    )
# -----------------------------------------------------------------------------
# Public Level 3 API (wrappers)
# -----------------------------------------------------------------------------
def aac_coder_3(
    filename_in: Union[str, Path],
    filename_aac_coded: Optional[Union[str, Path]] = None,
    *,
    verbose: bool = True,
) -> AACSeq3:
    """
    Level-3 AAC encoder (wrapper).

    Delegates to the core implementation.

    Parameters
    ----------
    filename_in : Union[str, Path]
        Input WAV filename.
        Assumption: stereo audio, sampling rate 48 kHz.
    filename_aac_coded : Optional[Union[str, Path]]
        Optional filename to store the encoded AAC sequence (e.g., .mat).
    verbose : bool, keyword-only
        Forwarded to the core encoder's ``verbose`` keyword. Defaults to
        True. Pass False for quiet batch runs.

    Returns
    -------
    AACSeq3
        List of encoded frames (Level 3 schema).
    """
    return core_aac_coder_3(filename_in, filename_aac_coded, verbose=verbose)
def i_aac_coder_3(
    aac_seq_3: AACSeq3,
    filename_out: Union[str, Path],
    *,
    verbose: bool = True,
) -> StereoSignal:
    """
    Level-3 AAC decoder (wrapper).

    Delegates to the core implementation.

    Parameters
    ----------
    aac_seq_3 : AACSeq3
        Encoded sequence as produced by aac_coder_3().
    filename_out : Union[str, Path]
        Output WAV filename. Assumption: 48 kHz, stereo.
    verbose : bool, keyword-only
        Forwarded to the core decoder's ``verbose`` keyword. Defaults to
        True. Pass False for quiet batch runs.

    Returns
    -------
    StereoSignal
        Decoded audio samples (time-domain), stereo, shape (N, 2), dtype float64.
    """
    return core_aac_decoder_3(aac_seq_3, filename_out, verbose=verbose)
# -----------------------------------------------------------------------------
# Demo (Level 3)
# -----------------------------------------------------------------------------
def demo_aac_3(
    filename_in: Union[str, Path],
    filename_out: Union[str, Path],
    filename_aac_coded: Optional[Union[str, Path]] = None,
) -> Tuple[float, float, float]:
    """
    Demonstration for the Level-3 codec.

    Runs:
        - aac_coder_3(filename_in, filename_aac_coded)
        - i_aac_coder_3(aac_seq_3, filename_out)
    and computes:
        - total SNR between original and decoded audio
        - coded bitrate (bits/s) based on Huffman streams
        - compression ratio (bitrate_before / bitrate_after)

    Side effect: the coded sequence is stored in the module-level global
    ``AAC_Seq_3``, so callers (e.g. the CLI) can reuse it without
    re-encoding.

    Parameters
    ----------
    filename_in : Union[str, Path]
        Input WAV filename (stereo, 48 kHz).
    filename_out : Union[str, Path]
        Output WAV filename (stereo, 48 kHz).
    filename_aac_coded : Optional[Union[str, Path]]
        Optional filename to store the encoded AAC sequence (e.g., .mat).

    Returns
    -------
    Tuple[float, float, float]
        (SNR_dB, bitrate_after_bits_per_s, compression_ratio).
        The compression ratio is ``inf`` when no coded bits were produced.

    Raises
    ------
    ValueError
        If the input or decoded output sampling rate is not 48 kHz, or if
        the input WAV header/duration is invalid.
    """
    filename_in = Path(filename_in)
    filename_out = Path(filename_out)
    filename_aac_coded = Path(filename_aac_coded) if filename_aac_coded else None

    # Read original audio (reference) with the same validation as the codec.
    x_ref, fs_ref = aac_read_wav_stereo_48k(filename_in)
    if int(fs_ref) != 48000:
        raise ValueError("Input sampling rate must be 48 kHz.")

    # Encode / decode. The coded sequence is published via the module global.
    global AAC_Seq_3
    AAC_Seq_3 = aac_coder_3(filename_in, filename_aac_coded)
    x_hat = i_aac_coder_3(AAC_Seq_3, filename_out)

    # Sanity check: the output file exists, its header is readable, and it is
    # at 48 kHz. Only the samplerate is needed, so read the header instead of
    # decoding the whole file.
    fs_hat = sf.info(str(filename_out)).samplerate
    if int(fs_hat) != 48000:
        raise ValueError("Decoded output sampling rate must be 48 kHz.")

    # Quality metrics
    s = snr_db(x_ref, x_hat)
    duration = _wav_duration_seconds(filename_in)
    bitrate_before = _bitrate_before_from_file(filename_in)
    bitrate_after = _bitrate_after_from_aacseq(AAC_Seq_3, duration)
    compression = float("inf") if bitrate_after <= 0.0 else (bitrate_before / bitrate_after)
    return float(s), float(bitrate_after), float(compression)
# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Example:
    #   cd level_3
    #   python -m level_3 input.wav output.wav
    # for example:
    #   python -m level_3 material/LicorDeCalandraca.wav LicorDeCalandraca_out_l3.wav
    # or
    #   python -m level_3 material/LicorDeCalandraca.wav LicorDeCalandraca_out_l3.wav aac_seq_3.mat
    # or
    #   python -m level_3 material/LicorDeCalandraca.wav LicorDeCalandraca_out_l3.wav aac_seq_3.mat bitrate.png compression.png
    import sys

    argv = sys.argv
    if len(argv) not in (3, 4, 5, 6):
        raise SystemExit(
            "Usage: python -m level_3 <input.wav> <output.wav> [aac_seq_3.mat] [bitrate_fname] [compression_fname]"
        )
    in_wav = Path(argv[1])
    out_wav = Path(argv[2])
    # Optional positional arguments are cumulative: a later one implies that
    # every earlier one was given too. So each must be picked up whenever
    # *at least* that many arguments are present (>=, not ==).
    aac_mat = Path(argv[3]) if len(argv) >= 4 else None
    fname_bitrate = Path(argv[4]) if len(argv) >= 5 else Path("bitrate_per_frame.png")
    fname_comp = Path(argv[5]) if len(argv) >= 6 else Path("compression_per_frame.png")

    print(f"Encoding/Decoding {in_wav} to {out_wav}")
    if aac_mat is not None:
        print(f"Storing coded sequence to {aac_mat}")
    snr, bitrate, compression = demo_aac_3(in_wav, out_wav, aac_mat)

    # Plot per-frame compression / bitrate using the sequence that
    # demo_aac_3 stored in the module global.
    _plot_frame_bitrate_and_compression(AAC_Seq_3, in_wav, fname_bitrate, fname_comp)

    print(f"SNR = {snr:.3f} dB")
    print(f"Bitrate (coded) = {bitrate:.2f} bits/s")
    print(f"Compression ratio = {compression:.4f}")