From 159f0464f20252715cb4ed8d260bf73fb8e6e2d3 Mon Sep 17 00:00:00 2001 From: alvinn Date: Wed, 13 Mar 2024 01:53:23 +0200 Subject: [PATCH] March 12 changeset --- README.md | 61 +++++- configs/inference/debug_inference.yaml | 4 + configs/inference/inference_v1.yaml | 4 + css/css.py | 71 ++++--- css/css_with_conformer/README.md | 3 + css/css_with_conformer/separate.py | 2 +- css/css_with_conformer/utils/mvdr_util.py | 88 ++++---- css/training/train.py | 3 + diarization/diarization.py | 46 +++- diarization/diarization_common.py | 3 +- diarization/time_based_diarization.py | 2 +- diarization/word_based_diarization.py | 116 +++++++--- inference_pipeline/inference.py | 134 ++++++------ inference_pipeline/load_meeting_data.py | 124 +++++++++-- run_inference.py | 2 +- utils/audio_utils.py | 2 +- utils/azure_storage.py | 18 +- utils/plot_utils.py | 101 ++++++++- utils/scoring.py | 244 ++++++++++++---------- utils/text_norm_whisper_like/__init__.py | 8 +- utils/text_norm_whisper_like/english.py | 15 ++ utils/torch_utils.py | 98 ++++++++- 22 files changed, 837 insertions(+), 312 deletions(-) diff --git a/README.md b/README.md index bdfb1eb..1312c51 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,45 @@ +[![Slack][slack-badge]][slack-invite] + + +[slack-badge]: https://img.shields.io/badge/slack-chat-green.svg?logo=slack +[slack-invite]: https://join.slack.com/t/chime-fey5388/shared_invite/zt-1oha0gedv-JEUr1mSztR7~iK9AxM4HOA + # Introduction Welcome to the "NOTSOFAR-1: Distant Meeting Transcription with a Single Device" Challenge. This repo contains the baseline system code for the NOTSOFAR-1 Challenge. -For more details see: -1. CHiME website: https://www.chimechallenge.org/current/task2/index -2. Preprint: https://arxiv.org/abs/2401.08887 +- For more information about NOTSOFAR, visit [CHiME's official challenge website](https://www.chimechallenge.org/current/task2/index) +- [Register](https://www.chimechallenge.org/current/task2/submission) to participate. +- [Baseline system description](https://www.chimechallenge.org/current/task2/baseline). +- Contact us: join the `chime-8-notsofar` channel on the [CHiME Slack](https://join.slack.com/t/chime-fey5388/shared_invite/zt-1oha0gedv-JEUr1mSztR7~iK9AxM4HOA), or open a [GitHub issue](https://github.com/microsoft/NOTSOFAR1-Challenge/issues). + +### 📊 Baseline Results on NOTSOFAR dev-set-1 + +Values are presented in `tcpWER / tcORC-WER (session count)` format. +
+As mentioned in the [official website](https://www.chimechallenge.org/current/task2/index#tracks),
+systems are ranked based on the speaker-attributed
+[tcpWER](https://github.com/fgnt/meeteval/blob/main/doc/tcpwer.md),
+while the speaker-agnostic [tcORC-WER](https://github.com/fgnt/meeteval) serves as a supplementary metric for analysis.
+<br>
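+Both metrics can be computed locally with [meeteval](https://github.com/fgnt/meeteval), which the
+`utils/scoring.py` module wraps. A minimal sketch, assuming you already have a reference `ref.json` and a
+hypothesis `tcp_wer_hyp.json` in SegLST format (these are the file names the baseline pipeline writes;
+adjust the paths to wherever your files actually live):
+
+```python
+from pathlib import Path
+import meeteval
+
+ref = meeteval.io.load(Path('ref.json'))          # reference transcript (SegLST)
+hyp = meeteval.io.load(Path('tcp_wer_hyp.json'))  # hypothesis transcript (SegLST)
+# collar=5 matches the baseline scoring configuration
+per_session_result = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=5)
+```
+<br>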
+We include analysis based on a selection of hashtags from our [metadata](https://www.chimechallenge.org/current/task2/data#metadata), providing insights into how different conditions affect system performance. + + + +| | Single-Channel | Multi-Channel | +|----------------------|-----------------------|-----------------------| +| All Sessions | **46.8** / 38.5 (177) | **32.4** / 26.7 (106) | +| #NaturalMeeting | 47.6 / 40.2 (30) | 32.3 / 26.2 (18) | +| #DebateOverlaps | 54.9 / 44.7 (39) | 38.0 / 31.4 (24) | +| #TurnsNoOverlap | 32.4 / 29.7 (10) | 21.2 / 18.8 (6) | +| #TransientNoise=high | 51.0 / 43.7 (10) | 33.6 / 29.1 (5) | +| #TalkNearWhiteboard | 55.4 / 43.9 (40) | 39.9 / 31.2 (22) | + + + + + # Project Setup @@ -143,7 +177,8 @@ Begin by exploring the following components: ### Training datasets For training and fine-tuning your models, NOTSOFAR offers the **simulated training set** and the training portion of the -**recorded meeting dataset**. Refer to the `download_simulated_subset` and `download_meeting_subset` functions in `utils/azure_storage.py`, +**recorded meeting dataset**. Refer to the `download_simulated_subset` and `download_meeting_subset` functions in +[utils/azure_storage.py](https://github.com/microsoft/NOTSOFAR1-Challenge/blob/main/utils/azure_storage.py#L109), or the [NOTSOFAR-1 Datasets](#notsofar-1-datasets---download-instructions) section. @@ -159,10 +194,12 @@ python run_training_css_local.py ## 2. Training on the full simulated training dataset ### Step 1: Download the simulated training dataset -You can use the `download_simulated_subset` function in `utils/azure_storage.py` to download the training dataset from blob storage. +You can use the `download_simulated_subset` function in +[utils/azure_storage.py](https://github.com/microsoft/NOTSOFAR1-Challenge/blob/main/utils/azure_storage.py) +to download the training dataset from blob storage. You have the option to download either the complete dataset, comprising almost 1000 hours, or a smaller, 200-hour subset. -For example, to download the entire 1000-hour dataset, make the following calls to download both the training and validation subsets: +Examples: ```python ver='v1.5' # this should point to the lateset and greatest version of the dataset. @@ -223,8 +260,11 @@ Alternatively, using AzCopy CLI, set these arguments and run the following comma - `version`: version to download (`240103g` / etc.). Use the latest version. - `datasets_path` - path to the directory where you want to download the benchmarking dataset (destination directory must exist).
-Currently only **dev_set** (no GT) and **train_set** are available. See timeline on the [NOTSOFAR page](https://www.chimechallenge.org/current/task2/index) for when the other sets will be released. -See doc in `download_meeting_subset` function in `utils/azure_storage.py` for latest available versions. +Train, dev, and eval sets are released for the NOTSOFAR challenge are released in stages. +See release timeline on the [NOTSOFAR page](https://www.chimechallenge.org/current/task2/index#dates). +See doc in `download_meeting_subset` function in +[utils/azure_storage.py](https://github.com/microsoft/NOTSOFAR1-Challenge/blob/main/utils/azure_storage.py#L109) +for latest available versions. ```bash azcopy copy https://notsofarsa.blob.core.windows.net/benchmark-datasets///MTG /benchmark --recursive @@ -259,7 +299,7 @@ azcopy copy https://notsofarsa.blob.core.windows.net/css-datasets// 1 and cfg.mc_mvdr: + mvdr_responses = make_mvdr(masks['spk_masks'].squeeze(0).moveaxis(2, 0).cpu().numpy(), + masks['noise_masks'].squeeze(0).moveaxis(2, 0).cpu().numpy(), + mix_stft=stft_seg_device.squeeze(0).moveaxis(2, 0).cpu().numpy(), + return_stft=True) + mvdr_responses = torch.from_numpy(np.stack(mvdr_responses, axis=-1)).unsqueeze(0).to(device) + # [B, F, T, num_spks] + seg_for_masking = mvdr_responses + else: + seg_for_masking = stft_seg_device_chref.unsqueeze(-1) # [B, F, T, 1] + + # floored mask multiplication. if mask_floor_db == 0, mask is all-ones (assuming mask in [0, 1] range) + mask_floor_db = cfg.mc_mask_floor_db if num_channels > 1 else cfg.sc_mask_floor_db + assert mask_floor_db <= 0 + mask_floor = 10. ** (mask_floor_db / 20.) # dB to amplitude + mask_clipped = torch.clip(masks['spk_masks'], min=mask_floor) + separated_seg = seg_for_masking * mask_clipped # [B, F, T, num_spks] + + # Plot for debugging + # plot_separation_methods(stft_seg_device_chref, masks, mvdr_responses, separator, cfg, + # plots=['mvdr', 'masked_mvdr', 'spk_masks', 'masked_ref_ch', 'mixture']) if cfg.normalize_segment_power: # normalize to match the input mixture power @@ -264,15 +287,15 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper, dim=(1, 2), keepdim=True) ) - assert torch.is_complex(masked_seg) + assert torch.is_complex(separated_seg) sep_energy = torch.sqrt( - torch.mean(masked_seg[:, :, :t].sum(-1).abs().pow(2), # sum over spks, squared mag + torch.mean(separated_seg[:, :, :t].sum(-1).abs().pow(2), # sum over spks, squared mag dim=(1, 2), keepdim=True ) ) - masked_seg = (mix_energy / sep_energy)[..., None] * masked_seg + separated_seg = (mix_energy / sep_energy)[..., None] * separated_seg - masked_seg_list.append(masked_seg.cpu()) # [B, F, T, num_spks] + separated_seg_list.append(separated_seg.cpu()) # [B, F, T, num_spks] spk_masks_list.append(masks['spk_masks'].cpu()) # [B, F, T, num_spks] @@ -283,7 +306,7 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper, # add first segment wg_seg = calc_segment_weight(segment_frames, m0_frames, m1_frames, is_first_seg=True) wg_stitched[:segment_frames] += wg_seg - stft_stitched[:, :, :segment_frames] += wg_seg.view(1, 1, -1, 1) * masked_seg_list[0] + stft_stitched[:, :, :segment_frames] += wg_seg.view(1, 1, -1, 1) * separated_seg_list[0] mask_stitched[:, :, :segment_frames] += wg_seg.view(1, 1, -1, 1) * spk_masks_list[0] pit = PitWrapper({'mse': mse_loss, 'l1': l1_loss}[cfg.stitching_loss]) @@ -292,9 +315,9 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper, for i in range(1, 
num_segments): if cfg.stitching_input == 'mask': left_input, right_input = spk_masks_list[i-1], spk_masks_list[i] - elif cfg.stitching_input == 'masked_mag': + elif cfg.stitching_input == 'separation_result': # masked magnitudes - left_input, right_input = masked_seg_list[i - 1].abs(), masked_seg_list[i].abs() + left_input, right_input = separated_seg_list[i - 1].abs(), separated_seg_list[i].abs() else: assert False, f'unexpected stitching_input: {cfg.stitching_input}' @@ -303,12 +326,12 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper, # Plot for debugging: # plot_left_right_stitch(separator, left_input, right_input, right_perm, - # overlap_frames, cfg, stft_seg_to_play=masked_seg_list[i][..., 0], fs=fs) + # overlap_frames, cfg, stft_seg_to_play=separated_seg_list[i][..., 0], fs=fs) # permute current segment to match with the previous one for ib in range(batch_size): spk_masks_list[i][ib] = spk_masks_list[i][ib, ..., right_perm[ib]] - masked_seg_list[i][ib] = masked_seg_list[i][ib, ..., right_perm[ib]] + separated_seg_list[i][ib] = separated_seg_list[i][ib, ..., right_perm[ib]] st = i * hop_frames en = min(st + segment_frames, mix_frames) @@ -316,8 +339,8 @@ def separate_and_stitch(speech_mix: np.ndarray, separator: ConformerCssWrapper, wg_seg = calc_segment_weight(segment_frames, m0_frames, m1_frames, is_last_seg=(i==num_segments-1)) wg_seg = wg_seg[:en-st] # last segment may be shorter wg_stitched[st:en] += wg_seg - assert torch.is_complex(masked_seg_list[i]), 'summation assumes complex representation' - stft_stitched[:, :, st:en] += wg_seg.view(1, 1, -1, 1) * masked_seg_list[i][:, :, :en-st] + assert torch.is_complex(separated_seg_list[i]), 'summation assumes complex representation' + stft_stitched[:, :, st:en] += wg_seg.view(1, 1, -1, 1) * separated_seg_list[i][:, :, :en-st] mask_stitched[:, :, st:en] += wg_seg.view(1, 1, -1, 1) * spk_masks_list[i][:, :, :en-st] assert (wg_stitched > 1e-5).all(), 'zero weights found. check hop_size, segment_size or m0, m1' diff --git a/css/css_with_conformer/README.md b/css/css_with_conformer/README.md index dec2c53..f6b50f9 100644 --- a/css/css_with_conformer/README.md +++ b/css/css_with_conformer/README.md @@ -1,3 +1,6 @@ +The code under this directory is mostly a copy of "CSS with Conformer" from the original repo at the URL below. +Some extentions were made when adopting to NOTSOFAR. + We didn't copy the README.md file from the original repo because it contains Shared Access Signatures (SAS) that are considered secrets by some version control systems. One can find the original README.md file at the URL below. 
https://github.com/Sanyuan-Chen/CSS_with_Conformer/blob/master/README.md diff --git a/css/css_with_conformer/separate.py b/css/css_with_conformer/separate.py index 35f78c8..4cc1469 100644 --- a/css/css_with_conformer/separate.py +++ b/css/css_with_conformer/separate.py @@ -100,7 +100,7 @@ def run(args): print('spks',len(spks),spks[0].shape) if args.mvdr: - res1, res2 = make_mvdr(np.asfortranarray(mixed.T), spks) + res1, res2 = make_mvdr(spks[:2], spks[2:], np.asfortranarray(mixed.T)) spks = [res1, res2] sf.write(dump_dir / f"{duration_sec}_mix.wav", egs['mix'][0].cpu().numpy(), sr) diff --git a/css/css_with_conformer/utils/mvdr_util.py b/css/css_with_conformer/utils/mvdr_util.py index 2e3ea44..0fefc8c 100644 --- a/css/css_with_conformer/utils/mvdr_util.py +++ b/css/css_with_conformer/utils/mvdr_util.py @@ -2,50 +2,68 @@ import numpy as np -def make_wta(result_mask): - noise_mask = result_mask[2] - if len(result_mask) == 4: - noise_mask += result_mask[3] - mask = np.stack((result_mask[0], result_mask[1],noise_mask)) +def make_mvdr(spk_masks, noise_masks, mix_wav = None, mix_stft = None, return_stft=False): + """ + + Args: + mix_wav: mixture waveform, [Nsamples, Mics] tensor + spk_masks: [num_spks, F, T] tensor + noise_masks: [num_noise, F, T] tensor + mix_stft: mixture STFT, [Mics, F, T] complex tensor + return_stft: if True, return the STFT of the separated signals. + Otherwise, return the separated signals in the time domain. + + Returns: + + """ + all_masks = make_wta(spk_masks, noise_masks) # [num_spks + 1_noise, F, T] + if mix_stft is None: + mix_stft=[] + for i in range(7): + st=librosa.core.stft(mix_wav[:, i], n_fft=512, hop_length=256) + mix_stft.append(st) + mix_stft=np.asarray(mix_stft) # [Mics, F, T] + + L = np.min([all_masks.shape[-1],mix_stft.shape[-1]]) + mix_stft = mix_stft[:,:,:L] + all_masks = all_masks[:,:,:L] + + scms = [get_mask_scm(mix_stft, mask) for mask in all_masks] + spk_scms = np.stack(scms[:-1]) # [num_spks, F, 7, 7] + noise_scm = scms[-1] # [F, 7, 7] + + res_per_spk = [] + for i in range(spk_scms.shape[0]): + # sum SCMs of all other speakers + other_spks_scm = spk_scms[np.arange(spk_scms.shape[0]) != i].sum(axis=0) + # add noise and compute beamforming coefficients for the current speaker + coef = calc_bfcoeffs(noise_scm + other_spks_scm, spk_scms[i]) + res = get_bf(mix_stft, coef) + res_per_spk.append(res) + + if not return_stft: + res_per_spk = [librosa.istft(res, hop_length=256) for res in res_per_spk] + + return res_per_spk + + +def make_wta(spk_masks, noise_masks): + noise_mask = noise_masks.sum(axis=0, keepdims=True) + mask = np.vstack([spk_masks, noise_mask]) mask_max = np.amax(mask, axis=0, keepdims=True) mask = np.where(mask==mask_max, mask, 1e-10) return mask -def make_mvdr(s,result): - mask=make_wta(result) - M=[] - for i in range(7): - st=librosa.core.stft(s[:,i],n_fft=512,hop_length=256) - M.append(st) - M=np.asarray(M) - - L=np.min([mask.shape[-1],M.shape[-1]]) - M=M[:,:,:L] - - mask=mask[:,:,:L] - - tgt_scm,_=get_mask_scm(M,mask[0]) - itf_scm,_=get_mask_scm(M,mask[1]) - noi_scm,_=get_mask_scm(M,mask[2]) - - coef=calc_bfcoeffs(noi_scm+itf_scm,tgt_scm) - res=get_bf(M,coef) - res1=librosa.istft(res,hop_length=256) - - coef=calc_bfcoeffs(noi_scm+tgt_scm,itf_scm) - res=get_bf(M,coef) - res2=librosa.istft(res,hop_length=256) - - return res1, res2 - - def get_mask_scm(mix,mask): - Ri = np.einsum('FT,FTM,FTm->FMm', mask, mix.transpose(1,2,0), mix.transpose(1,2,0).conj()) + """Return spatial covariance matrix of the masked signal.""" + + Ri = 
np.einsum('FT,FTM,FTm->FMm', + mask, mix.transpose(1,2,0), mix.transpose(1,2,0).conj()) t1=np.eye(7) t2=t1[np.newaxis,:,:] Ri+=1e-15*t2 - return Ri,np.sum(mask) + return Ri # ,np.sum(mask) def calc_bfcoeffs(noi_scm,tgt_scm): diff --git a/css/training/train.py b/css/training/train.py index c79c304..a7868f4 100644 --- a/css/training/train.py +++ b/css/training/train.py @@ -86,6 +86,9 @@ class TrainCfg: save_every: Optional[Tuple] = None scheduler_step_every: Optional[Tuple] = (1, 'epochs') stop_after: Optional[Tuple] = (120, 'epochs') + calc_side_info: bool = False + loss_name: Optional[str] = None + base_loss_name: Optional[str] = None def get_model(cfg: TrainCfg): diff --git a/diarization/diarization.py b/diarization/diarization.py index bc2a7e0..da988c9 100644 --- a/diarization/diarization.py +++ b/diarization/diarization.py @@ -1,4 +1,5 @@ import os +from typing import Optional import pandas as pd from pathlib import Path @@ -6,12 +7,13 @@ from diarization.time_based_diarization import time_based_diarization from diarization.word_based_diarization import word_based_clustering from utils.logging_def import get_logger +from utils.torch_utils import get_world_size _LOG = get_logger('diarization') def diarization_inference(out_dir: str, segments_df: pd.DataFrame, cfg: DiarizationCfg, - fetch_from_cache: bool) -> pd.DataFrame: + fetch_from_cache: bool, device: Optional[str] = None) -> pd.DataFrame: """ Run diarization to assign a speaker label to each ASR word. @@ -46,8 +48,10 @@ def diarization_inference(out_dir: str, segments_df: pd.DataFrame, cfg: Diarizat 'meeting_id': the meeting id. 'session_id': the session id. 'wav_file_name': the name of the wav file that the segment was transcribed from. + this is typically points to the speech separated wav file (see CSS module). cfg: diarization configuration. fetch_from_cache: If True, returns the cached results if they exist. Otherwise, runs the inference. + device: the device to use for loading the model and running inference. Returns: attributed_segments_df: a new set of segments with 'speaker_id' column added. 
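+
+    Example (illustrative call, assuming `segments_df` follows the schema above and `cfg` is a
+    populated DiarizationCfg):
+        attributed_segments_df = diarization_inference(out_dir, segments_df, cfg, fetch_from_cache=False)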
""" @@ -55,31 +59,51 @@ def diarization_inference(out_dir: str, segments_df: pd.DataFrame, cfg: Diarizat _LOG.info("Running Speaker Diarization") assert segments_df.session_id.nunique() <= 1, 'no cross-session information is permitted' - assert segments_df.wav_file_name.nunique() <= 3, 'expecting at most three separated channels' + # these two modes are for debugging and analysis if cfg.method == "skip": _LOG.info("Skipping Diarization") attributed_segments_df = segments_df.copy() attributed_segments_df['speaker_id'] = 'spk0' return attributed_segments_df + elif cfg.method == "by_wav_file_name": + attributed_segments_df = segments_df.copy() + # map each unique wav_file_name to an index + wav_file_name_ind, uniques = pd.factorize(attributed_segments_df['wav_file_name'], sort=True) + attributed_segments_df['speaker_id'] = wav_file_name_ind + attributed_segments_df['speaker_id'] = 'wav_' + attributed_segments_df['speaker_id'].astype(str) + _LOG.info(f"Diarization by wav file names: {uniques}") + return attributed_segments_df session_name = segments_df.session_id[0] + is_ct = session_name.startswith('close_talk') + assert segments_df.wav_file_name.nunique() <= 3 or is_ct, 'expecting at most three separated channels' output_dir = Path(out_dir) / "diarization" / session_name / cfg.method out_file = output_dir / "all_segments_df.pkl" - if fetch_from_cache and out_file.exists(): - attributed_segments_df = pd.read_pickle(out_file) - return attributed_segments_df + # Skip cache and writing ops if running in DDP mode, it is necessary to continue evaluate the model on each device + skip_cache_and_write = get_world_size() > 1 + + if not skip_cache_and_write: + if fetch_from_cache and out_file.exists(): + attributed_segments_df = pd.read_pickle(out_file) + return attributed_segments_df + os.makedirs(output_dir, exist_ok=True) - wav_files_sorted = sorted(segments_df.wav_file_name.unique()) - os.makedirs(output_dir, exist_ok=True) + segments_df = segments_df.copy() + # wav_file_name as category to convert to indices + segments_df['wav_file_name'] = segments_df['wav_file_name'].astype('category') + assert 'wav_file_name_ind' not in segments_df + segments_df['wav_file_name_ind'] = segments_df['wav_file_name'].cat.codes + wav_files = segments_df['wav_file_name'].cat.categories.to_list() if cfg.method == "word_nmesc": - attributed_segments_df = word_based_clustering(wav_files_sorted, segments_df, cfg) + attributed_segments_df = word_based_clustering(wav_files, segments_df, cfg, device) else: - attributed_segments_df = time_based_diarization(wav_files_sorted, segments_df, str(output_dir), cfg) + attributed_segments_df = time_based_diarization(wav_files, segments_df, str(output_dir), cfg) - attributed_segments_df.to_pickle(out_file) - _LOG.info(f'Speaker Diarization saved to {out_file}') + if not skip_cache_and_write: + attributed_segments_df.to_pickle(out_file) + _LOG.info(f'Speaker Diarization saved to {out_file}') return attributed_segments_df diff --git a/diarization/diarization_common.py b/diarization/diarization_common.py index 9ed11f7..63c10cf 100644 --- a/diarization/diarization_common.py +++ b/diarization/diarization_common.py @@ -94,9 +94,8 @@ def prepare_diarized_data_frame(all_words, segments_df, apply_deduplication): diarized_segments_df['session_id'] = segments_df['session_id'][0] # assign correct CSS file name to each diarized segment - wav_file_name_prefix = os.path.splitext(segments_df['wav_file_name'][0])[0][:-1] stream_id = [seg[0][-1] for seg in diarized_segments_df.word_timing.to_list()] - 
diarized_segments_df['wav_file_name'] = [f"{wav_file_name_prefix}{i}.wav" for i in stream_id] + diarized_segments_df['wav_file_name'] = segments_df['wav_file_name'].cat.categories[stream_id] diarized_segments_df['speaker_id'] = segments["speaker_id"] diff --git a/diarization/time_based_diarization.py b/diarization/time_based_diarization.py index 9c6c86b..7766b18 100644 --- a/diarization/time_based_diarization.py +++ b/diarization/time_based_diarization.py @@ -126,7 +126,7 @@ def assign_words_to_speakers(segments_df: pd.DataFrame, spk_vad: np.array, apply all_words = [] for _, seg in segments_df.iterrows(): # get the unmixed channel id for current segment - channel_id = int(os.path.splitext(os.path.basename(seg.wav_file_name))[0][-1]) + channel_id = seg.wav_file_name_ind for i, word in enumerate(seg["word_timing"]): start_frame = int(np.round(word[1]/vad_time_resolution)) diff --git a/diarization/word_based_diarization.py b/diarization/word_based_diarization.py index 291773e..689c09f 100644 --- a/diarization/word_based_diarization.py +++ b/diarization/word_based_diarization.py @@ -1,14 +1,16 @@ import os +from typing import Optional import pandas as pd import numpy as np -from omegaconf import OmegaConf import torch from torch.cuda.amp import autocast +from tqdm import tqdm from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel from nemo.collections.asr.parts.utils.offline_clustering import NMESC, SpectralClustering, cos_similarity, getCosAffinityMatrix, getAffinityGraphMat from utils.audio_utils import read_wav +from utils.torch_utils import is_dist_initialized from diarization.diarization import DiarizationCfg from diarization.diarization_common import prepare_diarized_data_frame, DiarizationCfg from utils.logging_def import get_logger @@ -23,7 +25,7 @@ def load_speaker_model(model_name: str, device: str): _LOG.info("Loading pretrained {} model from NGC".format(model_name)) spk_model = EncDecSpeakerLabelModel.from_pretrained(model_name=model_name, map_location=device) spk_model.eval() - + return spk_model @@ -44,9 +46,9 @@ def run_clustering(raw_affinity_mat: np.array, max_num_speakers: int=8, max_rp_t spectral_model = SpectralClustering(n_clusters=n_clusters) cluster_label = spectral_model.forward(affinity_mat) - + return cluster_label - + def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3): """ @@ -56,23 +58,22 @@ def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_em all_words = [] all_word_embeddings = [] - for _, seg in segments_df.iterrows(): + too_long_words = [] + + n_words = sum(len(seg['word_timing']) for _, seg in segments_df.iterrows()) + _segments_df, _ = _fill_dummy_words_for_ddp(segments_df) + words_processed = 0 + + for _, seg in tqdm(_segments_df.iterrows(), desc='extracting speaker embedding for segments', total=len(_segments_df)): # get the unmixed channel id for current segment - channel_id = int(os.path.splitext(os.path.basename(seg.wav_file_name))[0][-1]) + channel_id = seg.wav_file_name_ind for word in seg["word_timing"]: start_time = word[1] end_time = word[2] center_time = (start_time + end_time) / 2 word_duration = end_time - start_time - - if word_duration > max_allowed_word_duration: - # Very long word duration is very suspicious and may harm diarization. Ignore them for now. - # Note that these words will disappear in the final result. - # To do: find a better way to deal with these words. 
- _LOG.info(f"word '{word[0]}' has unreasonablly long duration ({start_time}s, {end_time}s). Skip it in diarization") - continue - + # extract multi-scale speaker embedding for the word word_embedding = [] for min_window_size in min_embedding_windows: @@ -92,17 +93,36 @@ def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_em word_wav = wavs[channel_id][start_sample:end_sample] word_wav = torch.tensor(word_wav[np.newaxis], dtype=torch.float32).to(spk_model.device) word_lens = torch.tensor([word_wav.shape[1]], dtype=torch.int).to(spk_model.device) - with autocast(): + with autocast(), torch.no_grad(): _, tmp_embedding = spk_model.forward(input_signal=word_wav, input_signal_length=word_lens) word_embedding.append(tmp_embedding.cpu().detach()) - + + words_processed += 1 + + if words_processed > n_words: + # This is a dummy word added for DDP. Skip it. + continue + + if word_duration > max_allowed_word_duration: + # Very long word duration is very suspicious and may harm diarization. Ignore them for now. + # Note that these words will disappear in the final result. + # To do: find a better way to deal with these words. + _LOG.info(f"word '{word[0]}' has unreasonablly long duration ({start_time}s, {end_time}s). Skip it in diarization") + too_long_words.append(word) + continue + + # append only the real words (do not append dummy words) all_words.append(word+[channel_id]) all_word_embeddings.append(torch.vstack(word_embedding)) - + + print(f'Done extracting embeddings. {words_processed=}, {len(all_words)=}, {n_words=}', flush=True) + n_real_words = n_words - len(too_long_words) + assert len(all_words) == n_real_words, f"Number of words {len(all_words)} != n_real_words {n_real_words}" return all_words, all_word_embeddings - -def word_based_clustering(audio_files: list, segments_df: pd.DataFrame, cfg: DiarizationCfg): + +def word_based_clustering(audio_files: list, segments_df: pd.DataFrame, cfg: DiarizationCfg, + device: Optional[str] = None): """ Treat each ASR word as a segment and run NMESC for clustering. @@ -122,16 +142,21 @@ def word_based_clustering(audio_files: list, segments_df: pd.DataFrame, cfg: Dia have the same size, i.e. NxN, where N is the number of words. So no resampling is needed. """ # load unmixed waveforms - wavs = [read_wav(audio_file, normalize=True, return_rate=True) for audio_file in audio_files] - sr = wavs[0][0] - wavs = np.vstack([wav[1] for wav in wavs]) - + srs, wavs = zip(*[read_wav(audio_file, normalize=True, return_rate=True) for audio_file in audio_files]) + sr = srs[0] + max_length = max([wav.size for wav in wavs]) + # pad to the maximum length and stack. padding is only relevant to segmented close-talk. + # CSS always returns equal-length channels. 
+ wavs = np.vstack( + [np.pad(wav, (0, max_length - wav.size), 'constant', constant_values=(0, 0)) for wav in + wavs]) + # load speaker embedding model - spk_model = load_speaker_model(cfg.embedding_model_name, device=None) - + spk_model = load_speaker_model(cfg.embedding_model_name, device=device) + # extract word-based multi-scale speaker embedding vectors - all_words, all_word_embeddings = extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, - cfg.min_embedding_windows, + all_words, all_word_embeddings = extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, + cfg.min_embedding_windows, cfg.max_allowed_word_duration) # compute affinity matrix for clustering @@ -149,5 +174,40 @@ def word_based_clustering(audio_files: list, segments_df: pd.DataFrame, cfg: Dia # prepare segment data frame all_words = [word+[f"spk{spk_idx}"] for word, spk_idx in zip(all_words, cluster_label)] diarized_segments_df = prepare_diarized_data_frame(all_words, segments_df, cfg.apply_deduplication) - + return diarized_segments_df + + +def _fill_dummy_words_for_ddp(segments_df: pd.DataFrame) -> tuple[pd.DataFrame, int]: + """ + Fill the last segment with dummy words to make the number of words the same across all processes in DDP. + + Returns: + (a COPY of segments_df with dummy words added to the last segment, number of real words, number of dummies) + """ + + if not is_dist_initialized(): + return segments_df, 0 + + n_words = sum(len(seg['word_timing']) for _, seg in segments_df.iterrows()) + max_words = get_max_value(n_words) + print(f"Number of segments: {len(segments_df)}, Number of words: {n_words}, max_words(in DDP): {max_words}") + + # find first segment with non-empty word_timing + for i in range(len(segments_df)): + if len(segments_df.iloc[i]['word_timing']) > 0: + dummy_word = segments_df.iloc[i]['word_timing'][-1].copy() + break + + # fill last segment with dummy data + _segments_df = segments_df.copy() + n_dummies = max_words - n_words + for _ in range(n_dummies): + _segments_df.iloc[-1]['word_timing'].append(dummy_word) + + n_words_with_dummies = sum([len(seg['word_timing']) for _, seg in _segments_df.iterrows()]) + assert n_words_with_dummies == max_words, \ + f"Number of words with dummies {n_words_with_dummies} != max_words {max_words}" + print(f"Number of words to process (with dummies): {n_words_with_dummies}") + + return _segments_df, n_dummies diff --git a/inference_pipeline/inference.py b/inference_pipeline/inference.py index cdde41a..0e70c7a 100644 --- a/inference_pipeline/inference.py +++ b/inference_pipeline/inference.py @@ -1,6 +1,7 @@ from dataclasses import field, dataclass +from functools import partial +from pathlib import Path from typing import Optional -import os import tqdm import pandas as pd @@ -11,7 +12,7 @@ from diarization.diarization_common import DiarizationCfg from inference_pipeline.load_meeting_data import load_data from utils.logging_def import get_logger -from utils.scoring import ScoringCfg, calc_wer, write_transcript_to_stm +from utils.scoring import ScoringCfg, calc_wer, df_to_seglst, normalize_segment, write_submission_jsons _LOG = get_logger('inference') @@ -25,6 +26,7 @@ class InferenceCfg: # Optional: Query to filter all_session_df. Useful for debugging. Must be None during full evaluation. 
session_query: Optional[str] = None + @dataclass class FetchFromCacheCfg: css: bool = False @@ -51,10 +53,10 @@ def inference_pipeline(meetings_dir: str, models_dir: str, out_dir: str, cfg: In # Load all meetings from the meetings dir _LOG.info(f'loading meetings from: {meetings_dir}') all_session_df, all_gt_utt_df, all_gt_metadata_df = load_data(meetings_dir, cfg.session_query) - + + wer_dfs, hyp_jsons = [], [] # Process each session independently. (Cross-session information is not permitted) - wer_series_list = [] - for session_name, session in tqdm.tqdm(all_session_df.iterrows(), desc='processing sessions'): + for _, session in tqdm.tqdm(all_session_df.iterrows(), desc='processing sessions'): _LOG.info(f'Processing session: {session.session_id}') # Front-end: split session into enhanced streams without overlap speech @@ -67,80 +69,88 @@ def inference_pipeline(meetings_dir: str, models_dir: str, out_dir: str, cfg: In attributed_segments_df: pd.DataFrame = ( diarization_inference(out_dir, segments_df, cfg.diarization, cache.diarization)) - # Write hypothesis transcription to: outdir / wer / {multi|single}channel /.../ *.stm - # To submit your system for evaluation, send us the contents of: outdir / wer / {multi|single}channel - tcp_wer_hyp_stm, tcorc_wer_hyp_stm = ( - write_hyp_transcripts(out_dir, session.session_id, attributed_segments_df, attributed_segments_df, - cfg.asr.text_normalizer())) + # Write hypothesis transcription to: outdir / wer / {multi|single}channel / session_id / *.json + # These will be merged into one json per track (mc/sc) for submission below. + hyp_paths: pd.Series = write_hypothesis_jsons( + out_dir, session, attributed_segments_df, cfg.asr.text_normalizer()) + hyp_jsons.append(hyp_paths) - # Calculate WER if GT is available + # Calculate session WER if GT is available if all_gt_utt_df is not None: # Rules: WER metric, arguments (collar), and text normalizer must remain unchanged - session_wer: pd.Series = calc_wer(out_dir, - tcp_wer_hyp_stm, - tcorc_wer_hyp_stm, - session.session_id, - get_session_gt(session, all_gt_utt_df), - cfg.asr.text_normalizer(), - collar=5, - save_visualizations=cfg.scoring.save_visualizations) - wer_series_list.append(session_wer) - - if wer_series_list: - # if GT is available, aggregate WER. 
- all_session_wer_df = pd.DataFrame(wer_series_list) + calc_wer_out = Path(out_dir) / 'wer' / session.session_id + session_wer: pd.DataFrame = calc_wer( + calc_wer_out, + hyp_paths.tcp_wer_hyp_json, + hyp_paths.tcorc_wer_hyp_json, + all_gt_utt_df, + cfg.asr.text_normalizer(), + collar=5, save_visualizations=cfg.scoring.save_visualizations) + wer_dfs.append(session_wer) + + # To submit results to one of the tracks, upload the tcp_wer_hyp.json and tc_orc_wer_hyp.json located in: + # outdir/wer/{singlechannel | multichannel}/ + hyp_jsons_df = pd.DataFrame(hyp_jsons) + write_submission_jsons(out_dir, hyp_jsons_df) + + if wer_dfs: # GT available + all_session_wer_df = pd.concat(wer_dfs, ignore_index=True) _LOG.info(f'Results:\n{all_session_wer_df}') _LOG.info(f'mean tcp_wer = {all_session_wer_df["tcp_wer"].mean()}') _LOG.info(f'mean tcorc_wer = {all_session_wer_df["tcorc_wer"].mean()}') # write session level results into a file - exp_id = "_".join([os.path.basename(cfg.css.checkpoint_sc), - cfg.asr.model_name, - cfg.diarization.method]) - os.makedirs(os.path.join(out_dir, "results"), exist_ok=True) - result_file = os.path.join(out_dir, "results", exp_id+".tsv") - _LOG.info(f"Results can be found in: {result_file}") + exp_id = "_".join(['css', cfg.asr.model_name, cfg.diarization.method]) + result_file = Path(out_dir) / "wer" / f"{exp_id}_results.csv" + result_file.parent.mkdir(parents=True, exist_ok=True) all_session_wer_df.to_csv(result_file, sep="\t") + _LOG.info(f"Wrote full results to: {result_file}") # TODO confidence intervals, WER per meta-data -def get_session_gt(session: pd.Series, all_gt_utt_df: pd.DataFrame): - return all_gt_utt_df[all_gt_utt_df['meeting_id'] == session.meeting_id] - - -def write_hyp_transcripts(out_dir, session_id, +def write_hypothesis_jsons(out_dir, session: pd.Series, attributed_segments_df: pd.DataFrame, - segments_df: pd.DataFrame, text_normalizer): - _LOG.info(f'Writing hypothesis transcripts for session {session_id}') - # hyp file for tcpWER, the metric used for ranking. - # MeetEval requires stream _id, which for tcpWER is the same as speaker_id. - df = attributed_segments_df.copy() - df['stream_id'] = df['speaker_id'] - tcp_wer_hyp_stm = write_transcript_to_stm(out_dir, df, text_normalizer, - session_id, 'tcp_wer_hyp.stm') - _LOG.info(f'tcpwer STM: {tcp_wer_hyp_stm}') - - # hyp file for tcORC-WER, a supplementary metric for analysis. - # MeetEval requires stream _id, which for tcORC-WER depends on the system. + """ + Write hypothesis transcripts for session, to be used for tcpwer and tcorwer metrics. + """ + + _LOG.info(f'Writing hypothesis transcripts for session {session.session_id}') + + def write_json(df, filename): + filepath = Path(out_dir) / 'wer' / session.session_id / filename + filepath.parent.mkdir(parents=True, exist_ok=True) + seglst = df_to_seglst(df) + seglst = seglst.map(partial(normalize_segment, tn=text_normalizer)) + seglst.dump(filepath) + _LOG.info(f'Wrote {filepath}') + return filepath + + # I. hyp file for tcpWER + tcp_wer_hyp_json = write_json(attributed_segments_df, 'tcp_wer_hyp.json') + + # II. hyp file for tcORC-WER, a supplementary metric for analysis. + # meeteval.wer.tcorcwer requires a stream ID, which depends on the system. + # Overlapped words should go into different streams, or appear in one stream while respecting the order + # in reference. See https://github.com/fgnt/meeteval. # In NOTSOFAR we define the streams as the outputs of CSS (continuous speech separation). 
# If your system does not have CSS you need to define the streams differently. # For example: for end-to-end multi-talker ASR you might use a single stream. - # Overlap speech should go into different streams, - # or appear in one stream but respecting the order in reference. See https://github.com/fgnt/meeteval. + # Alternatively, you could use the predicted speaker ID as the stream ID. - # Take wav_file_name from segments_df, rather than attributed_segments_df, since the latter is a result of - # diarizations, where the segments are built of words potentially coming from different channels. - # So, in the general case there is no meaningful "channel" that can be associated with a segment. - - # Use segment_df for tcorc_wer - # df = segments_df.copy() - df = attributed_segments_df.copy() + # The wav_file_name column of attributed_segments_df indicates the source CSS stream. + # Note that the diarization module ensures the words within each segment have a consistent channel. + df_tcorc = attributed_segments_df.copy() # Use factorize to map each unique wav_file_name to an index. - df['stream_id'], uniques = pd.factorize(df['wav_file_name'], sort=True) - + # meeteval.wer.tcorcwer treats speaker_id field as stream id. + df_tcorc['speaker_id'], uniques = pd.factorize(df_tcorc['wav_file_name'], sort=True) _LOG.debug(f'Found {len(uniques)} streams for tc_orc_wer_hyp.stm') - tcorc_wer_hyp_stm= write_transcript_to_stm(out_dir, df, text_normalizer, - session_id, 'tc_orc_wer_hyp.stm') - _LOG.info(f'tcorc_wer STM: {tcorc_wer_hyp_stm}') - return tcp_wer_hyp_stm, tcorc_wer_hyp_stm + tcorc_wer_hyp_json = write_json(df_tcorc, 'tc_orc_wer_hyp.json') + + return pd.Series({ + 'session_id': session.session_id, + 'tcp_wer_hyp_json': tcp_wer_hyp_json, + 'tcorc_wer_hyp_json': tcorc_wer_hyp_json, + 'is_mc': session.is_mc, + 'is_close_talk': session.is_close_talk, + }) diff --git a/inference_pipeline/load_meeting_data.py b/inference_pipeline/load_meeting_data.py index 2d94095..f205002 100644 --- a/inference_pipeline/load_meeting_data.py +++ b/inference_pipeline/load_meeting_data.py @@ -2,12 +2,17 @@ from pathlib import Path from typing import Tuple, Optional +import numpy as np import pandas as pd +import soundfile from tqdm import tqdm +from utils.audio_utils import write_wav +from utils.torch_utils import is_zero_rank, barrier + def load_data(meetings_dir: str, session_query: Optional[str] = None, - drop_close_talk: bool = True + return_close_talk: bool = False, out_dir: Optional[str] = None ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ Load all meetings from the meetings dir @@ -17,9 +22,11 @@ def load_data(meetings_dir: str, session_query: Optional[str] = None, Example: project_root/artifacts/meeting_data/dev_set/240121_dev/MTG/ session_query: a query string to filter the sessions (optional) When submitting results, this should be None so no filtering occurs. - drop_close_talk: whether to drop close-talk devices (optional) + return_close_talk: if True, return each meeting as a session with all close-talk devices as its + wav_file_names. Close-talk must not be used during inference. However, this can be used as supervision - signal during training. + signal during training or for analysis. + out_dir: directory to save outputs to. only used when return_close_talk is True. 
Returns: all_session_df (per device): Each line corresponds to a recording of a meeting captured with a single device @@ -51,6 +58,7 @@ def load_data(meetings_dir: str, session_query: Optional[str] = None, devices_file = meeting_subdir / 'devices.json' metadata_file = meeting_subdir / 'gt_meeting_metadata.json' + gt_utt_df = None if transcription_file.exists(): # we have GT transcription gt_utt_df = pd.read_json(transcription_file) @@ -65,16 +73,32 @@ def load_data(meetings_dir: str, session_query: Optional[str] = None, metadata_dfs.append(metadata_df) devices_df = pd.read_json(devices_file) - if drop_close_talk: - # drop close-talk devices - devices_df = devices_df[~devices_df['device_name'].str.startswith('CT')].copy() devices_df['meeting_id'] = meeting_subdir.name - prefix = devices_df.is_mc.map({True: 'multichannel', False: 'singlechannel'}) - devices_df['session_id'] = prefix + '/' + meeting_subdir.name + '_' + devices_df['device_name'] - # convert to a list of full paths by appending meeting_subdir to each file in wav_file_name - devices_df['wav_file_names'] = devices_df['wav_file_names'].apply( - lambda x: [str(meeting_subdir / file_name.strip()) for file_name in x.split(',')] - ) + if return_close_talk: + devices_df = devices_df[devices_df.is_close_talk].copy() + assert len(devices_df) > 0, 'no close-talk devices found' + assert gt_utt_df is not None, 'expecting GT transcription' + + new_wav_file_names = concat_speech_segments(devices_df, gt_utt_df, meeting_subdir, out_dir) + + # original close-talk: + # orig_wav_file_names = devices_df.wav_file_names.apply(lambda x: str(meeting_subdir / x)).to_list() + + devices_df = devices_df.iloc[0:1].copy() + devices_df['device_name'] = 'close_talk' + devices_df['wav_file_names'] = [new_wav_file_names] # orig_wav_file_names + devices_df['session_id'] = 'close_talk/' + meeting_subdir.name + else: + # drop close-talk devices + devices_df = devices_df[~devices_df.is_close_talk].copy() + + prefix = devices_df.is_mc.map({True: 'multichannel', False: 'singlechannel'}) + devices_df['session_id'] = prefix + '/' + meeting_subdir.name + '_' + devices_df['device_name'] + # convert to a list of full paths by appending meeting_subdir to each file in wav_file_name + devices_df['wav_file_names'] = devices_df['wav_file_names'].apply( + lambda x: [str(meeting_subdir / file_name.strip()) for file_name in x.split(',')] + ) + session_dfs.append(devices_df) @@ -96,6 +120,80 @@ def load_data(meetings_dir: str, session_query: Optional[str] = None, all_session_df.drop('MtgType', axis=1, inplace=True) if session_query: - all_session_df.query(session_query, inplace=True) + query, process_first_n = _process_query(session_query) + all_session_df.query(query, inplace=True) + if process_first_n: + all_session_df = all_session_df.head(process_first_n) return all_session_df, all_gt_utt_df, all_metadata_df + + +def _process_query(query): + """ Split query into a few parts + Query can have the following format: + 1. "query_string" + 2. "query_string ##and index Optional[str]: """ - Download a subset of the meeting dataset to the destination directory. - The subsets and versions available will be updated in: - https://www.chimechallenge.org/current/task2/index + Downloads a subset of the NOTSOFAR recorded meeting dataset. 
+ + The subsets will be released according to the timeline in: + https://www.chimechallenge.org/current/task2/index#dates Args: subset_name: name of split to download (dev_set / eval_set / train_set) @@ -122,12 +123,19 @@ def download_meeting_subset(subset_name: Literal['train_set', 'dev_set', 'eval_s (warning!: if true, will delete the entire destination_dir if it exists) - Latest available datsets: + Latest available versions: # dev_set, no GT available. submit your systems to leaderboard to measure WER. res_dir = download_meeting_subset(subset_name='dev_set', version='240208.2_dev', destination_dir=...) - # train_set, with GT for training models. + # first and second train-set batches combined, with GT for training models. + res_dir = download_meeting_subset(subset_name='train_set', version='240229.1_train', destination_dir=...) + + + + Previous versions: + + # first train-set batch, with GT for training models. res_dir = download_meeting_subset(subset_name='train_set', version='240208.2_train', destination_dir=...) diff --git a/utils/plot_utils.py b/utils/plot_utils.py index c6c28ef..d20082e 100644 --- a/utils/plot_utils.py +++ b/utils/plot_utils.py @@ -1,10 +1,11 @@ """Plot CSS inference intermediate results for debug. See usage in css.py""" - +from pathlib import Path from typing import Optional +import numpy as np import torch -from utils.audio_utils import play_wav +from utils.audio_utils import play_wav, write_wav def plot_stitched_masks(mask_stitched, activity_b, activity_final, cfg): @@ -71,4 +72,98 @@ def plot_left_right_stitch(separator, left_input, right_input, right_perm, overl plt.xlabel("Time Frames") plt.ylabel("Frequency Bins") plt.suptitle('right') - plt.show() \ No newline at end of file + plt.show() + + +def plot_separation_methods(stft_seg_device_chref, masks, mvdr_responses, separator, cfg, plots): + """Plot various masking methods for multi-channel segment, and writes them as wav files. + plots arg controls what to plot. 
+ + For full plot: + plots = ['mvdr', 'masked_mvdr', 'spk_masks', 'masked_ref_ch', 'mixture'] + """ + import matplotlib.pyplot as plt + import librosa + plots_ordered = [] + num_spks = cfg.num_spks + fig, axs = plt.subplots(num_spks, len(plots), figsize=(30, 5 * num_spks)) + masked_ref_ch = stft_seg_device_chref.unsqueeze(-1) * masks['spk_masks'] + masked_mvdr = mvdr_responses * masks['spk_masks'] # note, no floor + col_ind = -1 + if 'mvdr' in plots: + plots_ordered.append('mvdr') + col_ind += 1 + for j in range(num_spks): + ax = axs[j, col_ind] + img = librosa.display.specshow( + librosa.amplitude_to_db(mvdr_responses[0, :, :, j].abs().cpu(), ref=np.max), + y_axis='linear', x_axis='time', ax=ax, sr=16000) + ax.set_title(f'Speaker {j + 1} Spectrogram') + plt.colorbar(img, ax=ax, format="%+2.0f dB") + ax.set_xlabel("Time Frames") + ax.set_ylabel("Frequency Bins") + if 'masked_mvdr' in plots: + plots_ordered.append('masked_mvdr') + col_ind += 1 + for j in range(num_spks): + ax = axs[j, col_ind] + img = librosa.display.specshow( + librosa.amplitude_to_db(masked_mvdr[0, :, :, j].abs().cpu(), ref=np.max), + y_axis='linear', x_axis='time', ax=ax, sr=16000) + ax.set_title(f'Speaker {j + 1} Spectrogram') + plt.colorbar(img, ax=ax, format="%+2.0f dB") + ax.set_xlabel("Time Frames") + ax.set_ylabel("Frequency Bins") + if 'masked_ref_ch' in plots: + plots_ordered.append('masked_ref_ch') + col_ind += 1 + for j in range(num_spks): + ax = axs[j, col_ind] + img = librosa.display.specshow( + librosa.amplitude_to_db(masked_ref_ch[0, :, :, j].abs().cpu(), ref=np.max), + y_axis='linear', x_axis='time', ax=ax, sr=16000) + ax.set_title(f'Speaker {j + 1} Spectrogram') + plt.colorbar(img, ax=ax, format="%+2.0f dB") + ax.set_xlabel("Time Frames") + ax.set_ylabel("Frequency Bins") + if 'spk_masks' in plots: + plots_ordered.append('spk_masks') + col_ind += 1 + for j in range(num_spks): + ax = axs[j, col_ind] + img = ax.imshow(masks['spk_masks'][0, :, :, j].cpu(), aspect='auto', origin='lower', vmin=0, + vmax=1) + plt.colorbar(img, ax=ax) + ax.set_xlabel("Time Frames") + ax.set_ylabel("Frequency Bins") + if 'mixture' in plots: + plots_ordered.append('mixture') + col_ind += 1 + # plot mixture ch0 + ax = axs[0, col_ind] + img_right = librosa.display.specshow( + librosa.amplitude_to_db(stft_seg_device_chref[0, :, :].abs().cpu(), ref=np.max), + y_axis='linear', x_axis='time', ax=ax) + plt.colorbar(img_right, ax=ax, format="%+2.0f dB") + ax.set_xlabel("Time Frames") + ax.set_ylabel("Frequency Bins") + + # plot noisemask + ax = axs[1, col_ind] + img = ax.imshow(masks['noise_masks'][0, :, :, 0].cpu(), aspect='auto', origin='lower', vmin=0, vmax=1) + plt.colorbar(img, ax=ax) + ax.set_xlabel("Time Frames") + ax.set_ylabel("Frequency Bins") + + plt.suptitle(' | '.join(plots_ordered)) + plt.tight_layout() + plt.show() + + istft = lambda x: separator.istft(x).cpu().numpy()[0] + # x: [B, num_spks, Nsamples] + out_dir = Path('artifacts/analysis/separated_seg') + write_wav(out_dir / 'input_ref_ch.wav', samps=istft(stft_seg_device_chref), sr=16000) + for j in range(num_spks): + write_wav(out_dir / f'masked_ref_ch{j}.wav', samps=istft(masked_ref_ch[..., j]), sr=16000) + write_wav(out_dir / f'mvdr_{j}.wav', samps=istft(mvdr_responses[..., j]), sr=16000) + write_wav(out_dir / f'masked_mvdr_{j}.wav', samps=istft(masked_mvdr[..., j]), sr=16000) \ No newline at end of file diff --git a/utils/scoring.py b/utils/scoring.py index 216967c..ff3a75d 100644 --- a/utils/scoring.py +++ b/utils/scoring.py @@ -1,15 +1,18 @@ +import decimal +from 
functools import partial from pathlib import Path from dataclasses import dataclass - -import pandas as pd -import json +from typing import List, Dict, Callable import os -import sys +import pandas as pd import meeteval +import meeteval.io.chime7 +from meeteval.io.seglst import SegLstSegment from meeteval.viz.visualize import AlignmentVisualization from utils.logging_def import get_logger +from utils.text_norm_whisper_like import get_txt_norm _LOG = get_logger('wer') @@ -20,126 +23,153 @@ class ScoringCfg: save_visualizations: bool = False -def write_transcript_to_stm(out_dir, attributed_segments_df: pd.DataFrame, tn, session_id: str, - filename): - """ - Save a session's speaker attributed transcription into stm files. +def df_to_seglst(df): + return meeteval.io.SegLST([ + SegLstSegment( + session_id=row.session_id, + start_time=decimal.Decimal(row.start_time), + end_time=decimal.Decimal(row.end_time), + words=row.text, + speaker=row.speaker_id, + ) + for row in df.itertuples() + ]) - Args: - out_dir: the outputs per module are saved to out_dir/{module_name}/{session_id}. - attributed_segments_df: dataframe of speaker attributed transcribed segments for the given session. - tn: text normalizer. - session_id: session name - filename: the file name to save. Should be, e.g., tcpwer_hyp.stm for hypothesis - and ref.stm for reference. - Returns: - path to saved stm. - """ - if 'session_id' in attributed_segments_df: - assert attributed_segments_df.session_id.nunique() <= 1, 'no cross-session information is permitted' - - filepath = Path(out_dir) / 'wer' / session_id / filename - filepath.parent.mkdir(parents=True, exist_ok=True) - channel = 1 # ignored by MeetEval - - with filepath.open('w', encoding='utf-8') as f: - # utf-8 encoding to handle non-ascii characters that may be output by some ASRs - for entry in range(len(attributed_segments_df)): - stream_id = attributed_segments_df.iloc[entry]['stream_id'] - start_time = attributed_segments_df.iloc[entry]['start_time'] - end_time = attributed_segments_df.iloc[entry]['end_time'] - text = tn(attributed_segments_df.iloc[entry]['text']) - f.write(f'{session_id} {channel} {stream_id} {start_time} {end_time} {text}\n') +def normalize_segment(segment: SegLstSegment, tn): + words = segment["words"] + words = tn(words) + segment["words"] = words + return segment - return str(filepath) - -def calc_wer(out_dir: str, tcp_wer_hyp_stm: str, tcorc_wer_hyp_stm: str,session_id: str, - gt_utt_df: pd.DataFrame, tn, collar: float, save_visualizations: bool) -> pd.Series: +def calc_wer(out_dir: str, + tcp_wer_hyp_json: str | List[Dict], + tcorc_wer_hyp_json: str | List[Dict], + gt_utt_df: pd.DataFrame, tn: str | Callable = 'chime8', + collar: float = 5, save_visualizations: bool = False) -> pd.DataFrame: """ - Calculates tcpWER for the given session using meeteval dedicated API and saves the error + Calculates tcpWER and tcorcWER for each session in hypothesis files using meeteval, and saves the error information to .json. Text normalization is applied to both hypothesis and reference. Args: - out_dir: the directory to save intermediate files to. - tcp_wer_hyp_stm: path to hypothesis .stm file for tcpWER. - tcorc_wer_hyp_stm: path to hypothesis .stm file for tcorcWER. - session_id: session name - gt_utt_df: dataframe of ground truth utterances for the given session. + out_dir: the directory to save the ref.json reference transcript to (extracted from gt_utt_df). + tcp_wer_hyp_json: path to hypothesis .json file for tcpWER, or json structure. 
+ tcorc_wer_hyp_json: path to hypothesis .json file for tcorcWER, or json structure. + gt_utt_df: dataframe of ground truth utterances. must include the sessions in the hypothesis files. + see load_data() function. tn: text normalizer collar: tolerance of tcpWER to temporal misalignment between hypothesis and reference. save_visualizations: if True, save html visualizations of alignment between hyp and ref. Returns: - session_res: pd.Series with keys - - 'session_id' - 'hyp_file_name': absolute path to .stm files that contain hypothesis of pipeline per session. - 'ref_file_name': absolute path to .stm files that contain ground truth per session. - 'tcp_wer': tcpWER. - ... other useful tcp_wer keys (see keys below) + wer_df: pd.DataFrame with columns - + 'session_id' - same as in hypothesis files + 'tcp_wer': tcpWER + 'tcorc_wer': tcorcWER + ... intermediate tcpWER/tcorcWER fields such as insertions/deletions. see in code. """ - assert gt_utt_df.meeting_id.nunique() <= 1, 'GT should come from a single session' - - df = gt_utt_df.copy() - df['stream_id'] = df['speaker_id'] - ref_stm_path = write_transcript_to_stm(out_dir, df, tn, session_id, filename='ref.stm') - - stm_res = pd.Series( - {'session_id': session_id, 'tcp_wer_hyp_stm': tcp_wer_hyp_stm, - 'tcorc_wer_hyp_stm': tcorc_wer_hyp_stm, 'ref_stm': ref_stm_path}) - - - def save_wer_visualization(session: pd.Series): - ref = meeteval.io.load(session.ref_stm).groupby('filename') - hyp = meeteval.io.load(session.tcp_wer_hyp_stm).groupby('filename') - assert len(ref) == 1, 'Multiple meetings in a ref file?' - assert len(hyp) == 1, 'Multiple meetings in a hyp file?' + # json to SegLST structure (Segment-wise Long-form Speech Transcription annotation) + to_seglst = lambda x: meeteval.io.chime7.json_to_stm(x, None).to_seglst() if isinstance(x, list) \ + else meeteval.io.load(Path(x)) + tcp_hyp_seglst = to_seglst(tcp_wer_hyp_json) + tcorc_hyp_seglst = to_seglst(tcorc_wer_hyp_json) + + # map session_id to meetind_id and join with gt_utt_df to include GT utterances for each session. + # since every meeting contributes several sessions, a meeting's GT will be repeated for every session. 
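+    # Illustrative example (hypothetical IDs): a hypothesis session_id such as
+    # 'singlechannel/MTG_30101_<device_name>' is mapped to meeting_id 'MTG_30101' by the regex below.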
+ sess2meet_id = tcp_hyp_seglst.groupby('session_id').keys() + sess2meet_id = pd.DataFrame(sess2meet_id, columns=['session_id']) + sess2meet_id['meeting_id'] = sess2meet_id['session_id'].str.extract(r'(MTG_\d+)') + joined_df = pd.merge(sess2meet_id, gt_utt_df, on='meeting_id', how='left') + ref_seglst = df_to_seglst(joined_df) + + if isinstance(tn, str): + tn = get_txt_norm(tn) + # normalization should be idempotent so a second normalization will not change the result + tcp_hyp_seglst = tcp_hyp_seglst.map(partial(normalize_segment, tn=tn)) + tcorc_hyp_seglst = tcorc_hyp_seglst.map(partial(normalize_segment, tn=tn)) + ref_seglst = ref_seglst.map(partial(normalize_segment, tn=tn)) + + ref_file_path = Path(out_dir) / 'ref.json' + ref_file_path.parent.mkdir(parents=True, exist_ok=True) + ref_seglst.dump(ref_file_path) + + def save_wer_visualization(ref, hyp): + ref = ref.groupby('session_id') + hyp = hyp.groupby('session_id') + assert len(ref) == 1 and len(hyp) == 1, 'expecting one session for visualization' assert list(ref.keys())[0] == list(hyp.keys())[0] - - meeting_name = list(ref.keys())[0] - av = AlignmentVisualization(ref[meeting_name], hyp[meeting_name], alignment='tcp') + + meeting_name = list(ref.keys())[0] + av = AlignmentVisualization(ref[meeting_name], hyp[meeting_name], alignment='tcp') # Create standalone HTML file - av.dump(os.path.join(os.path.split(session.tcp_wer_hyp_stm)[0], 'viz.html')) - - - def calc_session_tcp_wer(session: pd.Series): - os.system(f"{sys.executable} -m meeteval.wer tcpwer " - f"-h {session.tcp_wer_hyp_stm} " - f"-r {session.ref_stm} " - f"--collar {collar}") - - with (open(os.path.splitext(session.tcp_wer_hyp_stm)[0] + '_tcpwer.json', "r") as read_file): - data = json.load(read_file) - keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', - 'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment'] - - return pd.Series({('tcp_' + key): data[key] for key in keys} - ).rename({'tcp_error_rate': 'tcp_wer'}) - - - def calc_session_tcorc_wer(session: pd.Series): - os.system(f"{sys.executable} -m meeteval.wer tcorcwer " - f"-h {session.tcorc_wer_hyp_stm} " - f"-r {session.ref_stm} " - f"--collar {collar}") - - with (open(os.path.splitext(session.tcorc_wer_hyp_stm)[0] + '_tcorcwer.json', "r") as read_file): - data = json.load(read_file) - keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', - 'assignment'] - - return pd.Series({('tcorc_'+key): data[key] for key in keys} - ).rename({'tcorc_error_rate': 'tcorc_wer'}) - - tcp_wer_res = calc_session_tcp_wer(stm_res) - tcorc_wer_res = calc_session_tcorc_wer(stm_res) + av.dump(os.path.join(out_dir, 'viz.html')) + + def calc_session_tcp_wer(ref, hyp): + res = meeteval.wer.tcpwer(reference=ref, hypothesis=hyp, collar=collar) + + res_df = pd.DataFrame.from_dict(res, orient='index').reset_index(names='session_id') + keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', + 'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment'] + return (res_df[['session_id'] + keys] + .rename(columns={k: 'tcp_' + k for k in keys}) + .rename(columns={'tcp_error_rate': 'tcp_wer'})) + + def calc_session_tcorc_wer(ref, hyp): + res = meeteval.wer.tcorcwer(reference=ref, hypothesis=hyp, collar=collar) + + res_df = pd.DataFrame.from_dict(res, orient='index').reset_index(names='session_id') + keys = ['error_rate', 'errors', 'length', 'insertions', 'deletions', 'substitutions', 'assignment'] + return (res_df[['session_id'] + keys] + 
.rename(columns={k: 'tcorc_' + k for k in keys}) + .rename(columns={'tcorc_error_rate': 'tcorc_wer'})) + + tcp_wer_res = calc_session_tcp_wer(ref_seglst, tcp_hyp_seglst) + tcorc_wer_res = calc_session_tcorc_wer(ref_seglst, tcorc_hyp_seglst) if save_visualizations: - save_wer_visualization(stm_res) + save_wer_visualization(ref_seglst, tcp_hyp_seglst) + + wer_df = pd.concat([tcp_wer_res, tcorc_wer_res.drop(columns='session_id')], axis=1) - session_res = pd.concat([stm_res, tcp_wer_res, tcorc_wer_res], axis=0) + if isinstance(tcp_wer_hyp_json, str | Path): + wer_df['tcp_wer_hyp_json'] = tcp_wer_hyp_json + if isinstance(tcorc_wer_hyp_json, str | Path): + wer_df['tcorc_wer_hyp_json'] = tcorc_wer_hyp_json - _LOG.info(f"tcp_wer = {session_res.tcp_wer:.4f}, tcorc_wer = {session_res.tcorc_wer:.4f} " - f"for session {session_res.session_id}") + _LOG.info('Done calculating WER') + _LOG.info(f"\n{wer_df[['session_id', 'tcp_wer', 'tcorc_wer']]}") + + return wer_df + + +def write_submission_jsons(out_dir: str, hyp_jsons_df: pd.DataFrame): + """ + Merges the per-session jsons in hyp_jsons_df and writes them under the appropriate track folder + in out_dir. + The resulting jsons can be used for submission. + """ + # close-talk is not supposed to be used for scoring + hyp_jsons_df = hyp_jsons_df[~hyp_jsons_df.is_close_talk] + + def write_json(files, file_name, is_mc): + seglst = [] + for f in files: + data = meeteval.io.load(f) + seglst.extend(data) + seglst = meeteval.io.SegLST(seglst) + track = 'multichannel' if is_mc else 'singlechannel' + filepath = Path(out_dir) / 'wer' / track / file_name + seglst.dump(filepath) + _LOG.info(f'Wrote hypothesis transcript for submission: {filepath}') + + mc_hyps = hyp_jsons_df[hyp_jsons_df.is_mc] + sc_hyps = hyp_jsons_df[~hyp_jsons_df.is_mc] + + if len(mc_hyps) > 0: + write_json(mc_hyps.tcp_wer_hyp_json, 'tcp_wer_hyp.json', is_mc=True) + write_json(mc_hyps.tcorc_wer_hyp_json, 'tc_orc_wer_hyp.json', is_mc=True) + + if len(sc_hyps) > 0: + write_json(sc_hyps.tcp_wer_hyp_json, 'tcp_wer_hyp.json', is_mc=False) + write_json(sc_hyps.tcorc_wer_hyp_json, 'tc_orc_wer_hyp.json', is_mc=False) - return session_res diff --git a/utils/text_norm_whisper_like/__init__.py b/utils/text_norm_whisper_like/__init__.py index b130d05..26b50c3 100644 --- a/utils/text_norm_whisper_like/__init__.py +++ b/utils/text_norm_whisper_like/__init__.py @@ -1,20 +1,18 @@ """ NOTSOFAR adopts the same text normalizer as the CHiME-8 DASR track. 
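+
+Minimal usage sketch (raw_text stands in for any transcript string; normalization is idempotent,
+so applying it a second time yields the same result):
+
+    normalizer = get_txt_norm("chime8")  # returns an EnglishTextNormalizer instance
+    normalized = normalizer(raw_text)
+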
-This code is copied from the CHiME-8 repo:
+This code is aligned with the CHiME-8 repo:
 https://github.com/chimechallenge/chime-utils/tree/main/chime_utils/text_norm
 """
 from .basic import BasicTextNormalizer as BasicTextNormalizer
 from .english import EnglishTextNormalizer as EnglishTextNormalizer
-from whisper.normalizers import EnglishTextNormalizer as OriginalEnglishTextNormalizer
 
 
 def get_txt_norm(txt_norm):
+    assert txt_norm in ["chime8", None]
     if txt_norm is None:
         return None
     elif txt_norm == "chime8":
         return EnglishTextNormalizer()
-    elif txt_norm == "whisper":
-        return OriginalEnglishTextNormalizer()
     else:
-        raise NotImplementedError()
+        raise NotImplementedError
diff --git a/utils/text_norm_whisper_like/english.py b/utils/text_norm_whisper_like/english.py
index b9bf2e1..76e2df5 100644
--- a/utils/text_norm_whisper_like/english.py
+++ b/utils/text_norm_whisper_like/english.py
@@ -540,6 +540,21 @@ def __call__(self, s: str):
 
 
 class EnglishTextNormalizer:
+    """
+    This is a modified version of the Whisper text normalizer, designed to enhance compatibility
+    across various ASR systems.
+
+    Key features:
+
+    1. Idempotency: the output is unchanged under repeated application.
+    2. The original Whisper-tailored number normalization is replaced with one that is compatible with
+    other ASR systems, mapping numerals to spelled-out numbers.
+    See EnglishReverseNumberNormalizer for details and limitations.
+    3. Filler words are removed by default, as in the original normalizer: ['hmm', 'uh', 'ah', 'eh'].
+    This is for compatibility with ASRs trained to ignore them.
+    4. Added normalization for some common words: okay -> ok, everyday -> every day, etc.
+
+    """
     def __init__(self, standardize_numbers=False, standardize_numbers_rev=True, remove_fillers=True):
         self.replacers = {
             # common non verbal sounds are mapped to the similar ones
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index a19b229..69a2ff4 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -1,10 +1,9 @@
 import os
-import pickle
-from typing import Optional, Tuple, List, Any, Dict
+from typing import Any, Dict
 
+import pandas as pd
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.distributed as dist
 
 
@@ -31,6 +30,99 @@ def is_zero_rank():
     return get_rank() == 0
 
 
+def barrier():
+    if is_dist_initialized():
+        dist.barrier()
+
+
+def get_device_name():
+    if is_dist_initialized():
+        # with a single node, get_rank() alone gives the device_id;
+        # with multiple nodes, it is the rank modulo the number of GPUs per node:
+        device_id = get_rank() % torch.cuda.device_count()
+        return f'cuda:{device_id}'
+
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+class DDPRowIterator:
+    """ Wraps a DataFrame and deals its rows out round-robin across DDP processes, so each process
+    visits roughly len(df) / world_size rows (world_size being the number of processes created by DDP).
+    The DataFrame is virtually padded to a multiple of world_size; padded rows repeat the row at dummy_row_idx.
+    This is useful for distributed inference: the data is spread across all processes so that they work
+    on different rows at the same time and no process is left idle (a DDP assumption).
+    The next() method returns a tuple of (row, row_idx, is_dummy) where is_dummy is True if the row is a padded row.
+    Each process iterates over the rows assigned to it and stops when they are exhausted.
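+
+    Illustrative usage sketch (process_row is a hypothetical per-row function). For example, with
+    world_size=2 and a 5-row DataFrame, rank 0 yields rows 0, 2, 4 while rank 1 yields rows 1, 3
+    plus one dummy row, so both ranks run the same number of iterations:
+
+        for row, row_idx, is_dummy in DDPRowIterator(df):
+            result = process_row(row)
+            if is_dummy:
+                continue  # discard results computed for padded (dummy) rows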
+ + Args: + df (pd.DataFrame): the DataFrame to iterate over + """ + + def __init__(self, df: pd.DataFrame): + self.df = df + self.world_size = get_world_size() + self.current_process_idx = get_rank() + self.rows_per_chunk = len(df) // self.world_size + self.remainder = len(df) % self.world_size + self.current_row_idx = 0 + self.dummy_row_idx = self.current_process_idx + assert self.dummy_row_idx < len(self.df), f'{self.dummy_row_idx=} must be less than {len(self.df)=}' + + @property + def _padded_df_len(self): + return len(self.df) + ((self.world_size - self.remainder) if self.remainder > 0 else 0) + + def __len__(self): + return int(self._padded_df_len / self.world_size) + + def __iter__(self): + self.current_row_idx = self.current_process_idx + return self + + def __next__(self): + if self.current_row_idx >= self._padded_df_len: + # Wait for all processes to finish processing + barrier() + raise StopIteration + + row_idx = self.current_row_idx + + if row_idx < len(self.df): + is_dummy = False + row = self.df.iloc[row_idx] + else: + # if we are here, we are padding the DataFrame (self.current_row_idx >= len(self.df)) + is_dummy = True + row = self.df.iloc[self.dummy_row_idx] + + self.current_row_idx += self.world_size + return row, row_idx, is_dummy + + +def initialize_ddp(logger): + """ Process group initialization for distributed inference """ + if is_dist_env_available(): + rank = int(os.environ['RANK']) + world_size = int(os.environ['WORLD_SIZE']) + dist.init_process_group('nccl', rank=rank, world_size=world_size) + logger.info(f'Distributed: {get_rank()=}, {get_world_size()=}') + # NOTE! must call set_device or allocations go to GPU 0 disproportionally, causing CUDA OOM. + torch.cuda.set_device(torch.device(get_device_name())) + dist.barrier() + + return get_device_name() + + +def get_max_value(value: int) -> int: + """ Returns the maximum value from all processes """ + if not is_dist_initialized(): + return value + + tensor = torch.tensor(value).cuda() + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return int(tensor.item()) + + def move_to(obj: Any, device: torch.device, numpy: bool=False) -> Any: """recursively visit a tuple/list/dict structure (can extend to more types if required)""" # pylint: disable=unidiomatic-typecheck # explicitly differentiate tuple from NamedTuple