218 lines
6.5 KiB
Python
218 lines
6.5 KiB
Python
# ------------------------------------------------------------
|
||
# AAC Coder/Decoder - Sequence Segmentation Control module
|
||
#
|
||
# Multimedia course at Aristotle University of
|
||
# Thessaloniki (AUTh)
|
||
#
|
||
# Author:
|
||
# Christos Choutouridis (ΑΕΜ 8997)
|
||
# cchoutou@ece.auth.gr
|
||
#
|
||
# Description:
|
||
# Sequence Segmentation Control module (SSC).
|
||
# Selects and returns the frame type based on input parameters.
|
||
# ------------------------------------------------------------
|
||
from __future__ import annotations
|
||
|
||
from typing import Dict, Tuple
|
||
from core.aac_types import FrameType, FrameT, FrameChannelT
|
||
|
||
import numpy as np
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Private helpers for SSC
|
||
# -----------------------------------------------------------------------------
|
||
|
||
# See Table 1 in mm-2025-hw-v0.1.pdf
|
||
# Stereo merge lookup: maps the pair (left-channel decision, right-channel
# decision) to the single frame type shared by both channels.
#
# The table is written as a matrix: each row is the left-channel decision and
# the columns follow the order OLS, LSS, ESH, LPS for the right-channel
# decision. It is expanded row by row, so keys are inserted in row-major order.
STEREO_MERGE_TABLE: Dict[Tuple[FrameType, FrameType], FrameType] = {
    (left, right): merged
    for left, merged_row in (
        # right:  OLS    LSS    ESH    LPS
        ("OLS", ("OLS", "LSS", "ESH", "LPS")),
        ("LSS", ("LSS", "LSS", "ESH", "ESH")),
        ("ESH", ("ESH", "ESH", "ESH", "ESH")),
        ("LPS", ("LPS", "ESH", "ESH", "LPS")),
    )
    for right, merged in zip(("OLS", "LSS", "ESH", "LPS"), merged_row)
}
|
||
|
||
|
||
def _detect_attack(next_frame_channel: FrameChannelT) -> bool:
|
||
"""
|
||
Detect whether the *next* frame (single channel) implies an attack, i.e. ESH
|
||
according to the assignment's criterion.
|
||
|
||
Parameters
|
||
----------
|
||
next_frame_channel : FrameChannelT
|
||
One channel of next_frame_T (expected shape: (2048,)).
|
||
|
||
Returns
|
||
-------
|
||
bool
|
||
True if an attack is detected (=> next frame predicted ESH), else False.
|
||
|
||
Notes
|
||
-----
|
||
The criterion is implemented as described in the spec:
|
||
|
||
1) Apply the high-pass filter:
|
||
H(z) = (1 - z^-1) / (1 - 0.5 z^-1)
|
||
implemented in the time domain as:
|
||
y[n] = x[n] - x[n-1] + 0.5*y[n-1]
|
||
|
||
2) Split y into 16 segments of length 128 and compute segment energies s[l].
|
||
|
||
3) Compute the ratio:
|
||
ds[l] = s[l] / s[l-1]
|
||
|
||
4) An attack exists if there exists l in {1..7} such that:
|
||
s[l] > 1e-3 and ds[l] > 10
|
||
"""
|
||
# Local alias; expected to be a 1-D array of length 2048.
|
||
x = next_frame_channel
|
||
|
||
# High-pass filter reference implementation (scalar recurrence).
|
||
y = np.zeros_like(x)
|
||
prev_x = 0.0
|
||
prev_y = 0.0
|
||
for n in range(x.shape[0]):
|
||
xn = float(x[n])
|
||
yn = (xn - prev_x) + 0.5 * prev_y
|
||
y[n] = yn
|
||
prev_x = xn
|
||
prev_y = yn
|
||
|
||
# Segment energies over 16 blocks of 128 samples.
|
||
s = np.empty(16, dtype=np.float64)
|
||
for l in range(16):
|
||
a = l * 128
|
||
b = (l + 1) * 128
|
||
seg = y[a:b]
|
||
s[l] = float(np.sum(seg * seg))
|
||
|
||
# ds[l] for l>=1. For l=0 not defined, keep 0.
|
||
ds = np.zeros(16, dtype=np.float64)
|
||
eps = 1e-12 # Avoid division by zero without materially changing the logic.
|
||
for l in range(1, 16):
|
||
ds[l] = s[l] / max(s[l - 1], eps)
|
||
|
||
# Spec: check l in {1..7}.
|
||
for l in range(1, 8):
|
||
if (s[l] > 1e-3) and (ds[l] > 10.0):
|
||
return True
|
||
|
||
return False
|
||
|
||
|
||
def _decide_frame_type(prev_frame_type: FrameType, attack: bool) -> FrameType:
|
||
"""
|
||
Decide the current frame type for a single channel based on the previous
|
||
frame type and whether the next frame is predicted to be ESH.
|
||
|
||
Rules (spec):
|
||
|
||
- If prev is "LSS" => current is "ESH"
|
||
- If prev is "LPS" => current is "OLS"
|
||
- If prev is "OLS" => current is "LSS" if attack else "OLS"
|
||
- If prev is "ESH" => current is "ESH" if attack else "LPS"
|
||
|
||
Parameters
|
||
----------
|
||
prev_frame_type : FrameType
|
||
Previous frame type (one of "OLS", "LSS", "ESH", "LPS").
|
||
attack : bool
|
||
True if the next frame is predicted ESH for this channel.
|
||
|
||
Returns
|
||
-------
|
||
FrameType
|
||
The per-channel decision for the current frame.
|
||
|
||
"""
|
||
if prev_frame_type == "LSS":
|
||
return "ESH"
|
||
if prev_frame_type == "LPS":
|
||
return "OLS"
|
||
if prev_frame_type == "OLS":
|
||
return "LSS" if attack else "OLS"
|
||
if prev_frame_type == "ESH":
|
||
return "ESH" if attack else "LPS"
|
||
|
||
raise ValueError(f"Invalid prev_frame_type: {prev_frame_type!r}")
|
||
|
||
|
||
def _stereo_merge(ft_l: FrameType, ft_r: FrameType) -> FrameType:
    """
    Combine the two per-channel frame type decisions into the single frame
    type shared by both channels, by looking the pair up in
    STEREO_MERGE_TABLE.

    Parameters
    ----------
    ft_l : FrameType
        Decision taken for the left channel.
    ft_r : FrameType
        Decision taken for the right channel.

    Returns
    -------
    FrameType
        The common frame type listed in the table for (ft_l, ft_r).

    Raises
    ------
    ValueError
        If the pair (ft_l, ft_r) has no entry in the table. The original
        KeyError is attached as the cause.
    """
    pair = (ft_l, ft_r)
    try:
        merged = STEREO_MERGE_TABLE[pair]
    except KeyError as lookup_error:
        raise ValueError(f"Invalid stereo merge pair: {pair}") from lookup_error
    return merged
|
||
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Public functions
|
||
# -----------------------------------------------------------------------------
|
||
|
||
def aac_SSC(frame_T: FrameT, next_frame_T: FrameT, prev_frame_type: FrameType) -> FrameType:
    """
    Sequence Segmentation Control (SSC).

    Choose the frame type of the current frame (i) from the previous frame
    type and from an attack detection performed on each channel of the next
    frame. The two per-channel decisions are then merged into one common
    frame type.

    Parameters
    ----------
    frame_T : FrameT
        Current time-domain stereo frame i, shape (2048, 2). Only its shape
        is validated here; its samples do not take part in the decision.
    next_frame_T : FrameT
        Next time-domain stereo frame (i+1), shape (2048, 2). Each column is
        examined for an attack.
    prev_frame_type : FrameType
        Frame type chosen for the previous frame (i-1).

    Returns
    -------
    FrameType
        One of: "OLS", "LSS", "ESH", "LPS".

    Raises
    ------
    ValueError
        If either frame does not have shape (2048, 2), or if the helpers
        reject prev_frame_type or the resulting pair of decisions.
    """
    expected_shape = (2048, 2)
    if frame_T.shape != expected_shape:
        raise ValueError("frame_T must have shape (2048, 2).")
    if next_frame_T.shape != expected_shape:
        raise ValueError("next_frame_T must have shape (2048, 2).")

    # Attack detection runs on the next frame, one channel (column) at a time:
    # left first, then right.
    left_attack, right_attack = [
        _detect_attack(next_frame_T[:, channel]) for channel in (0, 1)
    ]

    # Both channels share the same previous frame type; only the attack flag
    # differs between them.
    left_type, right_type = [
        _decide_frame_type(prev_frame_type, attacked)
        for attacked in (left_attack, right_attack)
    ]

    # Reduce the two per-channel decisions to the common frame type.
    return _stereo_merge(left_type, right_type)
|