200 lines
6.0 KiB
Python
200 lines
6.0 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
# Adjust the import based on package/module layout.
|
|
from level_1.level_1 import SSC
|
|
|
|
# Helper "fixtures" for SSC
|
|
# -----------------------------------------------------------------------------
|
|
|
|
def _next_frame_no_attack() -> np.ndarray:
    """Return a silent (all-zero) 2048x2 next frame.

    Every sample is exactly zero, so every segment energy s2l is zero and the
    first half of the ESH criterion (s2l > 1e-3) can never be satisfied.
    """
    silent_shape = (2048, 2)
    return np.zeros(silent_shape, dtype=np.float64)
|
|
|
|
|
|
def _next_frame_strong_attack(
    *,
    attack_left: bool,
    attack_right: bool,
    segment_l: int = 4,
    baseline: float = 1e-6,
    burst_amp: float = 1.0,
) -> np.ndarray:
    """Return a 2048x2 next frame carrying a loud burst on the chosen channels.

    Target criterion: the frame is ESH when some segment l in 1..7 has
    s2l > 1e-3 and ds2l > 10. The frame is built as:

    * a constant ``baseline`` on every sample of both channels. This is
      intended to keep earlier segment energies non-zero so ds2l stays
      finite; whether it does so depends on how SSC computes energies
      (TODO confirm, e.g. if SSC high-pass filters first).
    * ``burst_amp`` added over samples ``[128*segment_l, 128*(segment_l+1))``
      of every channel whose flag is set.

    Raises:
        AssertionError: if ``segment_l`` is not in 1..7.
    """
    assert 1 <= segment_l <= 7

    frame = np.full((2048, 2), baseline, dtype=np.float64)
    burst = slice(128 * segment_l, 128 * (segment_l + 1))

    # Channel 0 is left, channel 1 is right.
    for channel, selected in enumerate((attack_left, attack_right)):
        if selected:
            frame[burst, channel] += burst_amp

    return frame
|
|
|
|
|
|
def _next_frame_below_s2l_threshold(
    *,
    left: bool,
    right: bool,
    segment_l: int = 4,
    impulse_amp: float = 0.01,
) -> np.ndarray:
    """Return a 2048x2 next frame whose segment energy stays below 1e-3.

    A single impulse of height ``impulse_amp`` is written at sample
    ``128*segment_l + 10`` (inside segment ``segment_l``) on each selected
    channel, and every other sample is zero. For the default amplitude the
    segment energy is on the order of ``impulse_amp**2 = 1e-4 < 1e-3``. The
    frame must therefore not be classified as ESH, even though the energy
    ratio ds2l against the silent earlier segments could be large.

    Raises:
        AssertionError: if ``segment_l`` is not in 1..7.
    """
    assert 1 <= segment_l <= 7

    frame = np.zeros((2048, 2), dtype=np.float64)
    sample = 128 * segment_l + 10  # 10 samples into the chosen segment

    # Channel 0 is left, channel 1 is right.
    for channel, selected in enumerate((left, right)):
        if selected:
            frame[sample, channel] = impulse_amp

    return frame
|
|
|
|
|
|
# ---------------------------------------------------------------------
|
|
# 1) Fixed/mandatory cases (prev frame type forces current type)
|
|
# ---------------------------------------------------------------------
|
|
|
|
def test_ssc_fixed_cases_prev_lss_and_lps() -> None:
    """
    Spec: if prev was:
      - LSS => current MUST be ESH
      - LPS => current MUST be OLS
    independent of the next-frame ESH check.

    Independence is exercised by feeding two next frames: one that should be
    predicted ESH (strong attack on both channels) and one that should not
    (silence). The forced decision must be the same for both.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)

    next_frames = {
        "strong attack": _next_frame_strong_attack(attack_left=True, attack_right=True),
        "no attack": _next_frame_no_attack(),
    }

    for label, next_t in next_frames.items():
        # LSS must force ESH whatever the next frame looks like.
        assert SSC(frame_t, next_t, "LSS") == "ESH", label
        # LPS must force OLS whatever the next frame looks like.
        assert SSC(frame_t, next_t, "LPS") == "OLS", label
|
|
|
|
|
|
# ---------------------------------------------------------------------
|
|
# 2) Cases requiring next-frame ESH prediction (energy/attack computation)
|
|
# ---------------------------------------------------------------------
|
|
|
|
def test_prev_ols_next_not_esh_returns_ols() -> None:
    """prev=OLS with a silent next frame must stay OLS.

    Spec: after OLS the current frame is OLS or LSS. LSS is chosen only when
    frame i+1 is predicted ESH, and a silent next frame is not.
    """
    current = np.zeros((2048, 2), dtype=np.float64)
    assert SSC(current, _next_frame_no_attack(), "OLS") == "OLS"
|
|
|
|
|
|
def test_prev_ols_next_esh_both_channels_returns_lss() -> None:
    """prev=OLS with an attack on both channels of frame i+1 must give LSS.

    Each channel independently decides LSS, and merging LSS with LSS keeps
    LSS.
    """
    current = np.zeros((2048, 2), dtype=np.float64)
    attack_both = _next_frame_strong_attack(attack_left=True, attack_right=True)
    decided = SSC(current, attack_both, "OLS")
    assert decided == "LSS"
|
|
|
|
|
|
def test_prev_ols_next_esh_one_channel_returns_lss() -> None:
    """prev=OLS with an attack on only one channel of frame i+1 must give LSS.

    The attacked channel decides LSS and the quiet one decides OLS. The
    merge table maps OLS + LSS to LSS. The left-only case is checked first,
    then the right-only case.
    """
    current = np.zeros((2048, 2), dtype=np.float64)

    for left_hit, right_hit in ((True, False), (False, True)):
        lopsided = _next_frame_strong_attack(attack_left=left_hit, attack_right=right_hit)
        assert SSC(current, lopsided, "OLS") == "LSS"
|
|
|
|
|
|
def test_prev_esh_next_esh_both_channels_returns_esh() -> None:
    """prev=ESH with an attack on both channels of frame i+1 must stay ESH.

    Each channel decides ESH, and merging ESH with ESH keeps ESH.
    """
    current = np.zeros((2048, 2), dtype=np.float64)
    attack_both = _next_frame_strong_attack(attack_left=True, attack_right=True)
    decided = SSC(current, attack_both, "ESH")
    assert decided == "ESH"
|
|
|
|
|
|
def test_prev_esh_next_not_esh_both_channels_returns_lps() -> None:
    """prev=ESH with a silent frame i+1 must switch to LPS.

    Neither channel predicts ESH, so each decides LPS, and merging LPS with
    LPS keeps LPS.
    """
    current = np.zeros((2048, 2), dtype=np.float64)
    assert SSC(current, _next_frame_no_attack(), "ESH") == "LPS"
|
|
|
|
|
|
def test_prev_esh_next_esh_one_channel_merged_is_esh() -> None:
    """
    prev=ESH with an attack on only one channel of frame i+1:
      - the attacked channel predicts ESH => ESH
      - the quiet channel predicts not ESH => LPS
    Merge table: ESH + LPS => ESH.

    Both orientations are checked (attack on the left channel only, then on
    the right channel only), so the merge is exercised symmetrically.
    """
    frame_t = np.zeros((2048, 2), dtype=np.float64)

    # Left channel attacked, right channel quiet.
    next1_t = _next_frame_strong_attack(attack_left=True, attack_right=False)
    out1 = SSC(frame_t, next1_t, "ESH")
    assert out1 == "ESH"

    # Mirrored case: right channel attacked, left channel quiet. This checks
    # that the merge does not depend on which channel carries the ESH decision.
    next2_t = _next_frame_strong_attack(attack_left=False, attack_right=True)
    out2 = SSC(frame_t, next2_t, "ESH")
    assert out2 == "ESH"
|
|
|
|
def test_threshold_s2l_must_exceed_1e_3() -> None:
    """A next frame whose s2l stays below 1e-3 must not be predicted ESH.

    Spec: frame i+1 is ESH only if, for some l in 1..7, s2l > 1e-3 AND
    ds2l > 10. Here a single 0.01 impulse per channel keeps the segment
    energy around 1e-4. With prev=OLS the decision must therefore remain
    OLS (not LSS).
    """
    current = np.zeros((2048, 2), dtype=np.float64)
    faint = _next_frame_below_s2l_threshold(left=True, right=True, impulse_amp=0.01)
    assert SSC(current, faint, "OLS") == "OLS"
|