218 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ------------------------------------------------------------
# AAC Coder/Decoder - Sequence Segmentation Control module
#
# Multimedia course at Aristotle University of
# Thessaloniki (AUTh)
#
# Author:
# Christos Choutouridis (ΑΕΜ 8997)
# cchoutou@ece.auth.gr
#
# Description:
# Sequence Segmentation Control module (SSC).
# Selects and returns the frame type based on input parameters.
# ------------------------------------------------------------
from __future__ import annotations
from typing import Dict, Tuple
from core.aac_types import FrameType, FrameT, FrameChannelT
import numpy as np
# -----------------------------------------------------------------------------
# Private helpers for SSC
# -----------------------------------------------------------------------------
# See Table 1 in mm-2025-hw-v0.1.pdf
STEREO_MERGE_TABLE: Dict[Tuple[FrameType, FrameType], FrameType] = {
    # Key: (left-channel decision, right-channel decision) -> common frame type.
    # One row per left-channel decision; columns follow the right-channel
    # decision in the order OLS, LSS, ESH, LPS.
    # Left channel: OLS
    ("OLS", "OLS"): "OLS", ("OLS", "LSS"): "LSS", ("OLS", "ESH"): "ESH", ("OLS", "LPS"): "LPS",
    # Left channel: LSS
    ("LSS", "OLS"): "LSS", ("LSS", "LSS"): "LSS", ("LSS", "ESH"): "ESH", ("LSS", "LPS"): "ESH",
    # Left channel: ESH (always merges to ESH)
    ("ESH", "OLS"): "ESH", ("ESH", "LSS"): "ESH", ("ESH", "ESH"): "ESH", ("ESH", "LPS"): "ESH",
    # Left channel: LPS
    ("LPS", "OLS"): "LPS", ("LPS", "LSS"): "ESH", ("LPS", "ESH"): "ESH", ("LPS", "LPS"): "LPS",
}
def _detect_attack(next_frame_channel: FrameChannelT) -> bool:
    """
    Detect whether the *next* frame (single channel) implies an attack, i.e.
    whether the next frame is predicted to be ESH.

    Parameters
    ----------
    next_frame_channel : FrameChannelT
        One channel of next_frame_T (expected shape: (2048,)). Any real
        numeric dtype is accepted; the samples are converted to float64
        before filtering.

    Returns
    -------
    bool
        True if an attack is detected (=> next frame predicted ESH), else False.

    Notes
    -----
    The criterion is evaluated as follows:
    1) Apply the high-pass filter
           H(z) = (1 - z^-1) / (1 - 0.5 z^-1)
       in the time domain, starting from a zero state:
           y[n] = x[n] - x[n-1] + 0.5*y[n-1]
    2) Compute the energy s[l] of consecutive 128-sample segments of y.
    3) Compute the ratio:
           ds[l] = s[l] / s[l-1]
    4) An attack exists if there exists l in {1..7} such that:
           s[l] > 1e-3 and ds[l] > 10

    Only segments 0..7 can influence the decision, so the energies of later
    segments are not computed.
    """
    segment_len = 128
    last_segment = 7  # The decision inspects l in {1..7} only.
    eps = 1e-12  # Avoid division by zero without materially changing the logic.

    # Always filter in float64. Allocating the output with the input's dtype
    # would truncate every filtered sample for integer PCM input (and lose
    # precision for float32), which distorts the energy ratios below.
    x = np.asarray(next_frame_channel, dtype=np.float64)

    # High-pass filter reference implementation (scalar recurrence). Iterating
    # over native Python floats avoids per-sample NumPy scalar indexing.
    y = np.empty(x.shape[0], dtype=np.float64)
    prev_x = 0.0
    prev_y = 0.0
    for n, xn in enumerate(x.tolist()):
        yn = (xn - prev_x) + 0.5 * prev_y
        y[n] = yn
        prev_x = xn
        prev_y = yn

    # Energies of segments 0..7 (segment 0 only serves as a denominator).
    energies = []
    for k in range(last_segment + 1):
        seg = y[k * segment_len:(k + 1) * segment_len]
        energies.append(float(np.sum(seg * seg)))

    # An attack needs both enough absolute energy and a sharp rise relative
    # to the preceding segment.
    for k in range(1, last_segment + 1):
        ratio = energies[k] / max(energies[k - 1], eps)
        if energies[k] > 1e-3 and ratio > 10.0:
            return True
    return False
def _decide_frame_type(prev_frame_type: FrameType, attack: bool) -> FrameType:
    """
    Pick the frame type of the current frame for one channel.

    The choice depends on the type of the previous frame and on whether an
    attack was detected in the next frame (i.e. the next frame is predicted
    to be ESH).

    Transition rules:
      - prev "LSS" -> "ESH"   (regardless of attack)
      - prev "LPS" -> "OLS"   (regardless of attack)
      - prev "OLS" -> "LSS" on attack, otherwise "OLS"
      - prev "ESH" -> "ESH" on attack, otherwise "LPS"

    Parameters
    ----------
    prev_frame_type : FrameType
        Type of the previous frame (one of "OLS", "LSS", "ESH", "LPS").
    attack : bool
        True if the next frame is predicted ESH for this channel.

    Returns
    -------
    FrameType
        The per-channel decision for the current frame.

    Raises
    ------
    ValueError
        If prev_frame_type is not one of the four known frame types.
    """
    # Transitional window shapes have exactly one legal successor.
    forced_successors = (("LSS", "ESH"), ("LPS", "OLS"))
    for state, successor in forced_successors:
        if prev_frame_type == state:
            return successor

    # Stationary shapes branch on the attack flag: (state, on attack, otherwise).
    attack_dependent = (("OLS", "LSS", "OLS"), ("ESH", "ESH", "LPS"))
    for state, on_attack, otherwise in attack_dependent:
        if prev_frame_type == state:
            if attack:
                return on_attack
            return otherwise

    raise ValueError(f"Invalid prev_frame_type: {prev_frame_type!r}")
def _stereo_merge(ft_l: FrameType, ft_r: FrameType) -> FrameType:
    """
    Combine the two per-channel frame type decisions into the single frame
    type shared by both channels, by looking the pair up in
    STEREO_MERGE_TABLE.

    Parameters
    ----------
    ft_l : FrameType
        Frame type decision for the left channel.
    ft_r : FrameType
        Frame type decision for the right channel.

    Returns
    -------
    FrameType
        The merged common frame type.

    Raises
    ------
    ValueError
        If the (left, right) pair has no entry in STEREO_MERGE_TABLE.
    """
    pair = (ft_l, ft_r)
    try:
        merged = STEREO_MERGE_TABLE[pair]
    except KeyError as exc:
        # Report an unknown pair as a value error, keeping the lookup
        # failure attached as the cause.
        raise ValueError(f"Invalid stereo merge pair: {pair}") from exc
    return merged
# -----------------------------------------------------------------------------
# Public Function prototypes
# -----------------------------------------------------------------------------
def aac_SSC(frame_T: FrameT, next_frame_T: FrameT, prev_frame_type: FrameType) -> FrameType:
    """
    Sequence Segmentation Control (SSC).

    Choose the frame type of the current frame (i) from the attack detection
    result on the next frame and the type chosen for the previous frame.
    The current frame itself is only shape-checked here.

    Parameters
    ----------
    frame_T : FrameT
        Current time-domain frame i; must have shape (2048, 2).
    next_frame_T : FrameT
        Next time-domain frame (i+1); must have shape (2048, 2). Each of its
        two columns is run through the attack detector separately.
    prev_frame_type : FrameType
        Frame type chosen for the previous frame (i-1).

    Returns
    -------
    FrameType
        One of: "OLS", "LSS", "ESH", "LPS".

    Raises
    ------
    ValueError
        If either frame does not have shape (2048, 2), or if a helper
        rejects prev_frame_type or the resulting per-channel pair.
    """
    required_shape = (2048, 2)
    if frame_T.shape != required_shape:
        raise ValueError("frame_T must have shape (2048, 2).")
    if next_frame_T.shape != required_shape:
        raise ValueError("next_frame_T must have shape (2048, 2).")

    # Attack detection runs on each column of the next frame independently
    # (column 0 first, then column 1).
    attacks = [_detect_attack(next_frame_T[:, column]) for column in (0, 1)]

    # Both channels share the same previous frame type; only the attack flag
    # can make their decisions differ.
    left_type, right_type = [
        _decide_frame_type(prev_frame_type, flagged) for flagged in attacks
    ]

    # Reduce the two per-channel decisions to one common frame type.
    return _stereo_merge(left_type, right_type)