diff --git a/clarity/evaluator/msbg/cochlea.py b/clarity/evaluator/msbg/cochlea.py
index 5d47bfbb..10ff5cf6 100644
--- a/clarity/evaluator/msbg/cochlea.py
+++ b/clarity/evaluator/msbg/cochlea.py
@@ -224,7 +224,11 @@ class Cochlea:
     """

     def __init__(
-        self, audiogram: Audiogram, catch_up_level: float = 105.0, fs: float = 44100.0
+        self,
+        audiogram: Audiogram,
+        catch_up_level: float = 105.0,
+        fs: float = 44100.0,
+        verbose: bool = True,
     ) -> None:
         """Cochlea constructor.

@@ -233,6 +237,7 @@ def __init__(
             catch_up_level (float, optional): loudness catch-up level in dB
                 Default is 105 dB
             fs (float, optional): sampling frequency
+            verbose (bool, optional): verbose mode. Default is True
         """
         self.fs = fs

@@ -254,7 +259,8 @@ def __init__(
             r_lower, r_upper = HL_PARAMS[severity_level]["smear_params"]
             self.smearer = Smearer(r_lower, r_upper, fs)

-        logging.info("Severity level - %s", severity_level)
+        if verbose:
+            logging.info("Severity level - %s", severity_level)

     def simulate(self, coch_sig: ndarray, equiv_0dB_file_SPL: float) -> ndarray:
         """Pass a signal through the cochlea.

diff --git a/clarity/evaluator/msbg/msbg.py b/clarity/evaluator/msbg/msbg.py
index 204144f0..fabd943f 100644
--- a/clarity/evaluator/msbg/msbg.py
+++ b/clarity/evaluator/msbg/msbg.py
@@ -40,6 +40,7 @@ def __init__(
         sample_rate: float = 44100.0,
         equiv_0db_spl: float = 100.0,
         ahr: float = 20.0,
+        verbose: bool = True,
     ) -> None:
         """
         Constructor for the Ear class.
@@ -48,7 +49,9 @@ def __init__(
             sample_rate (float): sample frequency.
             equiv_0db_spl (): ???
             ahr (): ???
+            verbose (bool, optional): print verbose output. Default is True.
         """
+        self.verbose = verbose
         self.sample_rate = sample_rate
         self.src_correction = self.get_src_correction(src_pos)
         self.equiv_0db_spl = equiv_0db_spl
@@ -62,7 +65,7 @@ def set_audiogram(self, audiogram: Audiogram) -> None:
                 "Impairment too severe: Suggest you limit audiogram max to"
                 "80-90 dB HL, otherwise things go wrong/unrealistic."
             )
-        self.cochlea = Cochlea(audiogram=audiogram)
+        self.cochlea = Cochlea(audiogram=audiogram, verbose=self.verbose)

     @staticmethod
     def get_src_correction(src_pos: str) -> ndarray:
@@ -92,6 +95,7 @@ def src_to_cochlea_filt(
         src_correction: ndarray,
         sample_rate: float,
         backward: bool = False,
+        verbose: bool = True,
     ) -> ndarray:
         """Simulate middle and outer ear transfer functions.

@@ -109,12 +113,14 @@ def src_to_cochlea_filt(
                 or ITU
             sample_rate (int): sampling frequency
             backward (bool, optional): if true then cochlea to src (default: False)
+            verbose (bool, optional): print verbose output (default: True)

         Returns:
             np.ndarray: the processed signal

         """
-        logging.info("performing outer/middle ear corrections")
+        if verbose:
+            logging.info("performing outer/middle ear corrections")

         # make sure that response goes only up to sample_frequency/2
         nyquist = int(sample_rate / 2.0)
@@ -204,7 +210,8 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra
             )
             raise ValueError("Invalid sampling frequency, valid value is 44100")

-        logging.info("Processing {len(chans)} samples")
+        if self.verbose:
+            logging.info(f"Processing {len(chans)} samples")

         # Need to know file RMS, and then call that a certain level in SPL:
         # needs some form of pre-measuring.
@@ -219,7 +226,7 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra

         # Measure RMS where 3rd arg is dB_rel_rms (how far below)
         calculated_rms, idx, _rel_db_thresh, _active = measure_rms(
-            signal[0], sample_rate, -12
+            signal[0], sample_rate, -12, verbose=self.verbose
         )

         # Rescale input data and check level after rescaling
@@ -229,11 +236,11 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra
         new_rms_db = equiv_0db_spl + 10 * np.log10(
             np.mean(np.power(signal[0][idx], 2.0))
         )
-        logging.info(
-            "Rescaling: "
-            f"leveldBSPL was {level_db_spl:3.1f} dB SPL, now {new_rms_db:3.1f} dB SPL. "
-            f" Target SPL is {target_spl:3.1f} dB SPL."
-        )
+        if self.verbose:
+            logging.info(
+                f"Rescaling: leveldBSPL was {level_db_spl:3.1f} dB SPL, now"
+                f" {new_rms_db:3.1f} dB SPL. Target SPL is {target_spl:3.1f} dB SPL."
+            )

         # Add calibration signal at target SPL dB
         if add_calibration is True:
@@ -247,11 +254,17 @@ def process(self, signal: ndarray, add_calibration: bool = False) -> list[ndarra
             signal = np.concatenate((pre_calibration, signal, post_calibration), axis=1)

         # Transform from src pos to cochlea, simulate cochlea, transform back to src pos
-        signal = Ear.src_to_cochlea_filt(signal, self.src_correction, sample_rate)
+        signal = Ear.src_to_cochlea_filt(
+            signal, self.src_correction, sample_rate, verbose=self.verbose
+        )
         if self.cochlea is not None:
             signal = np.array([self.cochlea.simulate(x, equiv_0db_spl) for x in signal])
         signal = Ear.src_to_cochlea_filt(
-            signal, self.src_correction, sample_rate, backward=True
+            signal,
+            self.src_correction,
+            sample_rate,
+            backward=True,
+            verbose=self.verbose,
         )

         # Implement low-pass filter at top end of audio range: flat to Cutoff freq,

diff --git a/clarity/evaluator/msbg/msbg_utils.py b/clarity/evaluator/msbg/msbg_utils.py
index 09704d79..f5c387b3 100644
--- a/clarity/evaluator/msbg/msbg_utils.py
+++ b/clarity/evaluator/msbg/msbg_utils.py
@@ -358,6 +358,7 @@ def generate_key_percent(
     threshold_db: float,
     window_length: int,
     percent_to_track: float | None = None,
+    verbose: bool = False,
 ) -> tuple[ndarray, float]:
     """Generate key percent.
     Locates frames above some energy threshold or tracks a certain percentage
@@ -370,6 +371,7 @@ def generate_key_percent(
         window_length (int): length of window in samples.
         percent_to_track (float, optional): Track a percentage of frames.
             Default is None
+        verbose (bool, optional): Print verbose output. Default is False.

     Raises:
         ValueError: percent_to_track is set too high.
@@ -393,10 +395,11 @@ def generate_key_percent(
     expected = threshold_db

     # new Dec 2003. Possibly track percentage of frames rather than fixed threshold
-    if percent_to_track is not None:
-        logging.info("tracking %s percentage of frames", percent_to_track)
-    else:
-        logging.info("tracking fixed threshold")
+    if verbose:
+        if percent_to_track is not None:
+            logging.info("tracking %s percentage of frames", percent_to_track)
+        else:
+            logging.info("tracking fixed threshold")

     # put floor into histogram distribution
     non_zero = np.power(10, (expected - 30) / 10)
@@ -466,6 +469,7 @@ def measure_rms(
     sample_rate: float,
     db_rel_rms: float,
     percent_to_track: float | None = None,
+    verbose: bool = False,
 ) -> tuple[float, ndarray, float, float]:
     """Measure Root Mean Square.

@@ -481,6 +485,7 @@ def measure_rms(
         db_rel_rms (float): threshold for frames to track.
         percent_to_track (float, optional): track percentage of frames,
             rather than threshold (default: {None})
+        verbose (bool, optional): Print verbose output. Default is False.

     Returns:
         (tuple): tuple containing
             - rms (float): overall calculated rms (linear)
@@ -500,6 +505,7 @@ def measure_rms(
         key_thr_db,
         round(WIN_SECS * sample_rate),
         percent_to_track=percent_to_track,
+        verbose=verbose,
     )
     idx = key.astype(int)  # move into generate_key_percent

diff --git a/recipes/cad2/task1/baseline/evaluate.py b/recipes/cad2/task1/baseline/evaluate.py
index 4bee312a..3749e181 100644
--- a/recipes/cad2/task1/baseline/evaluate.py
+++ b/recipes/cad2/task1/baseline/evaluate.py
@@ -2,6 +2,7 @@

 from __future__ import annotations

+import hashlib
 import json
 import logging
 from pathlib import Path
@@ -11,7 +12,7 @@
 import pyloudnorm as pyln
 import torch.nn
 import whisper
-from jiwer import compute_measures
+from alt_eval import compute_metrics
 from omegaconf import DictConfig

 from clarity.enhancer.multiband_compressor import MultibandCompressor
@@ -25,6 +26,17 @@
 logger = logging.getLogger(__name__)


+def set_song_seed(song: str) -> None:
+    """Set a seed that is unique for the given song"""
+    song_encoded = hashlib.md5(song.encode("utf-8")).hexdigest()
+    song_md5 = int(song_encoded, 16) % (10**8)
+    np.random.seed(song_md5)
+
+    torch.manual_seed(song_md5)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(song_md5)
+
+
 def make_scene_listener_list(scenes_listeners: dict, small_test: bool = False) -> list:
     """Make the list of scene-listener pairing to process
@@ -90,6 +102,7 @@ def compute_intelligibility(
     ear = Ear(
         equiv_0db_spl=equiv_0db_spl,
         sample_rate=sample_rate,
+        verbose=False,
     )

     reference = segment_metadata["text"]
@@ -100,30 +113,38 @@ def compute_intelligibility(
     enhanced_left = ear.process(enhanced_signal[:, 0])[0]
     left_path = Path(f"{path_intermediate.as_posix()}_left.flac")
     save_flac_signal(
-        enhanced_signal,
+        enhanced_left,
         left_path,
         44100,
         sample_rate,
     )

-    hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False)["text"]
+    hypothesis = scorer.transcribe(left_path.as_posix(), fp16=False, temperature=0.0)[
+        "text"
+    ]
     lyrics["hypothesis_left"] = hypothesis

-    left_results = compute_measures(reference, hypothesis)
+    left_results = compute_metrics(
+        [reference], [hypothesis], languages="en", include_other=False
+    )

     # Compute right ear
     ear.set_audiogram(listener.audiogram_right)
     enhanced_right = ear.process(enhanced_signal[:, 1])[0]
     right_path = Path(f"{path_intermediate.as_posix()}_right.flac")
     save_flac_signal(
-        enhanced_signal,
+        enhanced_right,
         right_path,
         44100,
         sample_rate,
     )
-    hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False)["text"]
+    hypothesis = scorer.transcribe(right_path.as_posix(), fp16=False, temperature=0.0)[
+        "text"
+    ]
     lyrics["hypothesis_right"] = hypothesis

-    right_results = compute_measures(reference, hypothesis)
+    right_results = compute_metrics(
+        [reference], [hypothesis], languages="en", include_other=False
+    )

     # Compute the average score for both ears
     total_words = (
@@ -155,9 +176,23 @@ def compute_quality(
     reference_signal: np.ndarray,
     enhanced_signal: np.ndarray,
     listener: Listener,
-    config: DictConfig,
+    reference_sample_rate: int,
+    enhanced_sample_rate: int,
+    HAAQI_sample_rate: int,
 ) -> tuple[float, float]:
-    """Compute the HAAQI score for the left and right channels"""
+    """Compute the HAAQI score for the left and right channels
+
+    Args:
+        reference_signal: The reference signal
+        enhanced_signal: The enhanced signal
+        listener: The listener
+        reference_sample_rate: The sample rate of the reference signal
+        enhanced_sample_rate: The sample rate of the enhanced signal
+        HAAQI_sample_rate: The sample rate for the HAAQI computation
+
+    Returns:
+        The HAAQI score for the left and right channels
+    """

     scores = []
     for channel in range(2):
@@ -167,16 +202,16 @@ def compute_quality(
         s = compute_haaqi(
             processed_signal=resample(
                 enhanced_signal[:, channel],
-                config.remix_sample_rate,
-                config.HAAQI_sample_rate,
+                enhanced_sample_rate,
+                HAAQI_sample_rate,
             ),
             reference_signal=resample(
                 reference_signal[:, channel],
-                config.input_sample_rate,
-                config.HAAQI_sample_rate,
+                reference_sample_rate,
+                HAAQI_sample_rate,
             ),
-            processed_sample_rate=config.HAAQI_sample_rate,
-            reference_sample_rate=config.HAAQI_sample_rate,
+            processed_sample_rate=HAAQI_sample_rate,
+            reference_sample_rate=HAAQI_sample_rate,
             audiogram=audiogram,
             equalisation=2,
             level1=65 - 20 * np.log10(compute_rms(reference_signal[:, channel])),
@@ -304,6 +339,11 @@ def run_compute_scores(config: DictConfig) -> None:
         sample_rate=config.input_sample_rate,
     )

+    # Configure backend for determinism
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
+
+    # Load the Whisper model
     intelligibility_scorer = whisper.load_model(config.evaluate.whisper_version)

     # Loop over the scene-listener pairs
@@ -317,6 +357,10 @@ def run_compute_scores(config: DictConfig) -> None:

         scene_id, listener_id = scene_listener_ids

+        # Set the random seed for the scene
+        if config.evaluate.set_random_seed:
+            set_song_seed(scene_id)
+
         # Load scene details
         scene = scenes[scene_id]
         listener = listener_dict[listener_id]
@@ -377,7 +421,15 @@ def run_compute_scores(config: DictConfig) -> None:

         # COMPUTE SCORES
         # Compute the HAAQI and Whisper scores
-        haaqi_scores = compute_quality(reference, enhanced_signal, listener, config)
+        haaqi_scores = compute_quality(
+            reference_signal=reference,
+            enhanced_signal=enhanced_signal,
+            listener=listener,
+            reference_sample_rate=config.input_sample_rate,
+            enhanced_sample_rate=config.remix_sample_rate,
+            HAAQI_sample_rate=config.HAAQI_sample_rate,
+        )
+
         whisper_left, whisper_right, lyrics_text = compute_intelligibility(
             enhanced_signal=enhanced_signal,
             segment_metadata=songs[scene["segment_id"]],

diff --git a/recipes/cad2/task1/requirements.txt b/recipes/cad2/task1/requirements.txt
index 12ce2cb5..14359e1f 100644
--- a/recipes/cad2/task1/requirements.txt
+++ b/recipes/cad2/task1/requirements.txt
@@ -1,4 +1,4 @@
+alt-eval
 huggingface-hub
-jiwer
 openai-whisper
 safetensors
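
Usage note: a minimal sketch of the quieter evaluation path introduced above, assuming the clarity package is installed; `listener` and `signal` stand in for objects loaded elsewhere and are illustrative only.

    # Hypothetical driver: silence the MSBG ear model during batch scoring.
    from clarity.evaluator.msbg.msbg import Ear

    ear = Ear(equiv_0db_spl=100.0, sample_rate=44100.0, verbose=False)
    ear.set_audiogram(listener.audiogram_left)  # Cochlea is created with verbose=False
    processed = ear.process(signal[:, 0])[0]    # runs without per-call logging.info output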