Better Intelligibility metric for CAD2 Task1 #418

Draft · wants to merge 12 commits into main
10 changes: 8 additions & 2 deletions clarity/evaluator/msbg/cochlea.py
@@ -224,7 +224,11 @@ class Cochlea:
"""

def __init__(
self, audiogram: Audiogram, catch_up_level: float = 105.0, fs: float = 44100.0
self,
audiogram: Audiogram,
catch_up_level: float = 105.0,
fs: float = 44100.0,
verbose: bool = True,
) -> None:
"""Cochlea constructor.

@@ -233,6 +237,7 @@ def __init__(
catch_up_level (float, optional): loudness catch-up level in dB
Default is 105 dB
fs (float, optional): sampling frequency
verbose (bool, optional): verbose mode. Default is True

"""
self.fs = fs
@@ -254,7 +259,8 @@ def __init__(
r_lower, r_upper = HL_PARAMS[severity_level]["smear_params"]
self.smearer = Smearer(r_lower, r_upper, fs)

logging.info("Severity level - %s", severity_level)
if verbose:
logging.info("Severity level - %s", severity_level)

def simulate(self, coch_sig: ndarray, equiv_0dB_file_SPL: float) -> ndarray:
"""Pass a signal through the cochlea.
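The new `verbose` flag makes the severity-level log opt-in. A minimal usage sketch (the `Audiogram` construction and its import path are assumptions about pyclarity's API, not shown in this diff):

```python
# Sketch only: construct a Cochlea without the "Severity level - ..." log line.
import numpy as np

from clarity.evaluator.msbg.cochlea import Cochlea
from clarity.utils.audiogram import Audiogram  # assumed import path

audiogram = Audiogram(
    levels=np.array([20, 30, 35, 45, 50, 60]),                 # dB HL, illustrative
    frequencies=np.array([250, 500, 1000, 2000, 4000, 8000]),  # Hz
)
cochlea = Cochlea(audiogram=audiogram, verbose=False)  # silent construction
```

The default `verbose=True` preserves the existing behaviour for current callers.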
35 changes: 24 additions & 11 deletions clarity/evaluator/msbg/msbg.py
@@ -40,6 +40,7 @@ def __init__(
sample_rate: float = 44100.0,
equiv_0db_spl: float = 100.0,
ahr: float = 20.0,
verbose: bool = True,
) -> None:
"""
Constructor for the Ear class.
@@ -48,7 +49,9 @@ def __init__(
sample_rate (float): sample frequency.
equiv_0db_spl (): ???
ahr (): ???
verbose (bool, optional): if True, print logging output. Default is True.
"""
self.verbose = verbose
self.sample_rate = sample_rate
self.src_correction = self.get_src_correction(src_pos)
self.equiv_0db_spl = equiv_0db_spl
@@ -62,7 +65,7 @@ def set_audiogram(self, audiogram: Audiogram) -> None:
"Impairment too severe: Suggest you limit audiogram max to"
"80-90 dB HL, otherwise things go wrong/unrealistic."
)
self.cochlea = Cochlea(audiogram=audiogram)
self.cochlea = Cochlea(audiogram=audiogram, verbose=self.verbose)

@staticmethod
def get_src_correction(src_pos: str) -> ndarray:
@@ -92,6 +95,7 @@ def src_to_cochlea_filt(
src_correction: ndarray,
sample_rate: float,
backward: bool = False,
verbose: bool = True,
) -> ndarray:
"""Simulate middle and outer ear transfer functions.

@@ -109,12 +113,14 @@ def src_to_cochlea_filt(
or ITU
sample_rate (int): sampling frequency
backward (bool, optional): if true then cochlea to src (default: False)
verbose (bool, optional): print verbose output (default: True)

Returns:
np.ndarray: the processed signal

"""
logging.info("performing outer/middle ear corrections")
if verbose:
logging.info("performing outer/middle ear corrections")

# make sure that response goes only up to sample_frequency/2
nyquist = int(sample_rate / 2.0)
@@ -204,7 +210,8 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarray]:
)
raise ValueError("Invalid sampling frequency, valid value is 44100")

logging.info("Processing {len(chans)} samples")
if self.verbose:
logging.info("Processing {len(chans)} samples")

# Need to know file RMS, and then call that a certain level in SPL:
# needs some form of pre-measuring.
@@ -219,7 +226,7 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarray]:

# Measure RMS where 3rd arg is dB_rel_rms (how far below)
calculated_rms, idx, _rel_db_thresh, _active = measure_rms(
signal[0], sample_rate, -12
signal[0], sample_rate, -12, verbose=self.verbose
)

# Rescale input data and check level after rescaling
@@ -229,11 +236,11 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarray]:
new_rms_db = equiv_0db_spl + 10 * np.log10(
np.mean(np.power(signal[0][idx], 2.0))
)
logging.info(
"Rescaling: "
f"leveldBSPL was {level_db_spl:3.1f} dB SPL, now {new_rms_db:3.1f} dB SPL. "
f" Target SPL is {target_spl:3.1f} dB SPL."
)
if self.verbose:
logging.info(
f"Rescaling: leveldBSPL was {level_db_spl:3.1f} dB SPL, now"
f" {new_rms_db:3.1f} dB SPL. Target SPL is {target_spl:3.1f} dB SPL."
)

# Add calibration signal at target SPL dB
if add_calibration is True:
@@ -247,11 +254,17 @@
signal = np.concatenate((pre_calibration, signal, post_calibration), axis=1)

# Transform from src pos to cochlea, simulate cochlea, transform back to src pos
signal = Ear.src_to_cochlea_filt(signal, self.src_correction, sample_rate)
signal = Ear.src_to_cochlea_filt(
signal, self.src_correction, sample_rate, verbose=self.verbose
)
if self.cochlea is not None:
signal = np.array([self.cochlea.simulate(x, equiv_0db_spl) for x in signal])
signal = Ear.src_to_cochlea_filt(
signal, self.src_correction, sample_rate, backward=True
signal,
self.src_correction,
sample_rate,
backward=True,
verbose=self.verbose,
)

# Implement low-pass filter at top end of audio range: flat to Cutoff freq,
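With these changes a single `verbose` flag threads from the `Ear` constructor through `set_audiogram`, `src_to_cochlea_filt`, and `process`, so batch evaluation can run without per-signal log chatter. A hedged end-to-end sketch (audiogram values are illustrative and `src_pos` is assumed to keep its default):

```python
# Sketch: a quiet MSBG pass using the new flag.
import numpy as np

from clarity.evaluator.msbg.msbg import Ear
from clarity.utils.audiogram import Audiogram  # assumed import path

audiogram = Audiogram(
    levels=np.array([20, 30, 35, 45, 50, 60]),
    frequencies=np.array([250, 500, 1000, 2000, 4000, 8000]),
)

ear = Ear(sample_rate=44100.0, equiv_0db_spl=100.0, verbose=False)
ear.set_audiogram(audiogram)  # the Cochlea is now built with verbose=False

rng = np.random.default_rng(0)
signal = 0.1 * rng.standard_normal(44100)  # one illustrative second at 44.1 kHz
processed = ear.process(signal)[0]         # no "Processing"/"Rescaling" logs
```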
14 changes: 10 additions & 4 deletions clarity/evaluator/msbg/msbg_utils.py
@@ -358,6 +358,7 @@ def generate_key_percent(
threshold_db: float,
window_length: int,
percent_to_track: float | None = None,
verbose: bool = False,
) -> tuple[ndarray, float]:
"""Generate key percent.
Locates frames above some energy threshold or tracks a certain percentage
@@ -370,6 +371,7 @@
window_length (int): length of window in samples.
percent_to_track (float, optional): Track a percentage of frames.
Default is None
verbose (bool, optional): Print verbose output. Default is False.

Raises:
ValueError: percent_to_track is set too high.
@@ -393,10 +395,11 @@

expected = threshold_db
# new Dec 2003. Possibly track percentage of frames rather than fixed threshold
if percent_to_track is not None:
logging.info("tracking %s percentage of frames", percent_to_track)
else:
logging.info("tracking fixed threshold")
if verbose:
if percent_to_track is not None:
logging.info("tracking %s percentage of frames", percent_to_track)
else:
logging.info("tracking fixed threshold")

# put floor into histogram distribution
non_zero = np.power(10, (expected - 30) / 10)
@@ -466,6 +469,7 @@ def measure_rms(
sample_rate: float,
db_rel_rms: float,
percent_to_track: float | None = None,
verbose: bool = False,
) -> tuple[float, ndarray, float, float]:
"""Measure Root Mean Square.

@@ -481,6 +485,7 @@
db_rel_rms (float): threshold for frames to track.
percent_to_track (float, optional): track percentage of frames,
rather than threshold (default: {None})
verbose (bool, optional): Print verbose output. Default is False.
Returns:
(tuple): tuple containing
- rms (float): overall calculated rms (linear)
@@ -500,6 +505,7 @@
key_thr_db,
round(WIN_SECS * sample_rate),
percent_to_track=percent_to_track,
verbose=verbose,
)

idx = key.astype(int) # move into generate_key_percent
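`measure_rms` now forwards `verbose` to `generate_key_percent`, and both default to `False`, so the frame-tracking messages are silent unless requested. A small sketch (the signal is illustrative):

```python
# Sketch: verbose=True restores the old "tracking ..." log output.
import numpy as np

from clarity.evaluator.msbg.msbg_utils import measure_rms

rng = np.random.default_rng(0)
signal = 0.1 * rng.standard_normal(44100)  # one second of noise at 44.1 kHz

rms, key_idx, _rel_db_thresh, _active = measure_rms(
    signal, 44100.0, -12, verbose=True
)
print(f"overall RMS (linear): {rms:.4f}")
```

Note that these two functions previously logged unconditionally; `Ear.process` passes its own flag explicitly, so the MSBG path keeps its old output when `verbose=True`.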
84 changes: 68 additions & 16 deletions recipes/cad2/task1/baseline/evaluate.py
@@ -2,6 +2,7 @@

from __future__ import annotations

import hashlib
import json
import logging
from pathlib import Path
@@ -11,7 +12,7 @@
import pyloudnorm as pyln
import torch.nn
import whisper
from jiwer import compute_measures
from alt_eval import compute_metrics
from omegaconf import DictConfig

from clarity.enhancer.multiband_compressor import MultibandCompressor
@@ -25,6 +26,17 @@
logger = logging.getLogger(__name__)


def set_song_seed(song: str) -> None:
"""Set a seed that is unique for the given song"""
song_encoded = hashlib.md5(song.encode("utf-8")).hexdigest()
song_md5 = int(song_encoded, 16) % (10**8)
np.random.seed(song_md5)

torch.manual_seed(song_md5)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(song_md5)


def make_scene_listener_list(scenes_listeners: dict, small_test: bool = False) -> list:
"""Make the list of scene-listener pairing to process

@@ -90,6 +102,7 @@
ear = Ear(
equiv_0db_spl=equiv_0db_spl,
sample_rate=sample_rate,
verbose=False,
)

reference = segment_metadata["text"]
@@ -100,30 +113,38 @@
enhanced_left = ear.process(enhanced_signal[:, 0])[0]
left_path = Path(f"{path_intermediate.as_posix()}_left.flac")
save_flac_signal(
enhanced_signal,
enhanced_left,
left_path,
44100,
sample_rate,
)
hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"]
hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False, temperature=0.0)[
"text"
]
lyrics["hypothesis_left"] = hypothesis

left_results = compute_measures(reference, hypothesis)
left_results = compute_metrics(
[reference], [hypothesis], languages="en", include_other=False
)

# Compute right ear
ear.set_audiogram(listener.audiogram_right)
enhanced_right = ear.process(enhanced_signal[:, 1])[0]
right_path = Path(f"{path_intermediate.as_posix()}_right.flac")
save_flac_signal(
enhanced_signal,
enhanced_right,
right_path,
44100,
sample_rate,
)
hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"]
hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False, temperature=0.0)[
"text"
]
lyrics["hypothesis_right"] = hypothesis

right_results = compute_measures(reference, hypothesis)
right_results = compute_metrics(
[reference], [hypothesis], languages="en", include_other=False
)

# Compute the average score for both ears
total_words = (
@@ -155,9 +176,23 @@
reference_signal: np.ndarray,
enhanced_signal: np.ndarray,
listener: Listener,
config: DictConfig,
reference_sample_rate: int,
enhanced_sample_rate: int,
HAAQI_sample_rate: int,

[Codacy notice, recipes/cad2/task1/baseline/evaluate.py#L181: Argument name "HAAQI_sample_rate" doesn't conform to snake_case naming style]
) -> tuple[float, float]:
"""Compute the HAAQI score for the left and right channels"""
"""Compute the HAAQI score for the left and right channels

Args:
reference_signal: The reference signal
enhanced_signal: The enhanced signal
listener: The listener
reference_sample_rate: The sample rate of the reference signal
enhanced_sample_rate: The sample rate of the enhanced signal
HAAQI_sample_rate: The sample rate for the HAAQI computation

Returns:
The HAAQI score for the left and right channels
"""
scores = []

for channel in range(2):
@@ -167,16 +202,16 @@
s = compute_haaqi(
processed_signal=resample(
enhanced_signal[:, channel],
config.remix_sample_rate,
config.HAAQI_sample_rate,
enhanced_sample_rate,
HAAQI_sample_rate,
),
reference_signal=resample(
reference_signal[:, channel],
config.input_sample_rate,
config.HAAQI_sample_rate,
reference_sample_rate,
HAAQI_sample_rate,
),
processed_sample_rate=config.HAAQI_sample_rate,
reference_sample_rate=config.HAAQI_sample_rate,
processed_sample_rate=HAAQI_sample_rate,
reference_sample_rate=HAAQI_sample_rate,
audiogram=audiogram,
equalisation=2,
level1=65 - 20 * np.log10(compute_rms(reference_signal[:, channel])),
@@ -304,6 +339,11 @@
sample_rate=config.input_sample_rate,
)

# Configure backend for determinism
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Load the Whisper model
intelligibility_scorer = whisper.load_model(config.evaluate.whisper_version)

# Loop over the scene-listener pairs
@@ -317,6 +357,10 @@

scene_id, listener_id = scene_listener_ids

# Set the random seed for the scene
if config.evaluate.set_random_seed:
set_song_seed(scene_id)

# Load scene details
scene = scenes[scene_id]
listener = listener_dict[listener_id]
@@ -377,7 +421,15 @@
# COMPUTE SCORES

# Compute the HAAQI and Whisper scores
haaqi_scores = compute_quality(reference, enhanced_signal, listener, config)
haaqi_scores = compute_quality(
reference_signal=reference,
enhanced_signal=enhanced_signal,
listener=listener,
reference_sample_rate=config.input_sample_rate,
enhanced_sample_rate=config.remix_sample_rate,
HAAQI_sample_rate=config.HAAQI_sample_rate,
)

whisper_left, whisper_right, lyrics_text = compute_intelligibility(
enhanced_signal=enhanced_signal,
segment_metadata=songs[scene["segment_id"]],
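Two behavioural changes are worth calling out: scoring switches from jiwer's `compute_measures` to alt-eval's `compute_metrics` (an evaluator designed for lyrics transcription), and each scene seeds numpy/torch from an md5 hash of its id, which, together with `temperature=0.0` in the `transcribe` calls and the deterministic cuDNN settings, should make repeated evaluation runs reproducible. A hedged sketch of the new scoring path; the `"WER"` key is an assumption based on alt-eval's documentation, not something this diff shows:

```python
# Sketch of the new, seeded scoring path; values are illustrative.
import hashlib

import numpy as np
import torch
from alt_eval import compute_metrics

scene_id = "S50001"  # hypothetical scene id

# Per-scene seeding, mirroring set_song_seed() in the diff above.
song_md5 = int(hashlib.md5(scene_id.encode("utf-8")).hexdigest(), 16) % (10**8)
np.random.seed(song_md5)
torch.manual_seed(song_md5)

metrics = compute_metrics(
    ["you are the dancing queen"],  # reference lyrics
    ["you are the dancing green"],  # Whisper hypothesis
    languages="en",
    include_other=False,
)
print(metrics["WER"])  # assumed key name
```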
2 changes: 1 addition & 1 deletion recipes/cad2/task1/requirements.txt
@@ -1,4 +1,4 @@
alt-eval

huggingface-hub
jiwer
openai-whisper
safetensors