diff --git a/diarization/word_based_diarization.py b/diarization/word_based_diarization.py
index 689c09f..62dd4d6 100644
--- a/diarization/word_based_diarization.py
+++ b/diarization/word_based_diarization.py
@@ -14,7 +14,7 @@
 from diarization.diarization import DiarizationCfg
 from diarization.diarization_common import prepare_diarized_data_frame, DiarizationCfg
 from utils.logging_def import get_logger
-
+from torch.nn.utils.rnn import pad_sequence
 _LOG = get_logger('word_based_diarization')


@@ -50,7 +50,12 @@ def run_clustering(raw_affinity_mat: np.array, max_num_speakers: int=8, max_rp_t
     return cluster_label


-def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3):
+def batch_generator(data, batch_size):
+    for i in range(0, len(data), batch_size):
+        yield data[i:i + batch_size]
+
+
+def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3, batch_size=32):
     """
     For each word, use its word boundary information to extract multi-scale speaker embedding vectors.
     """
@@ -68,52 +73,58 @@ def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_em
         # get the unmixed channel id for current segment
         channel_id = seg.wav_file_name_ind

-        for word in seg["word_timing"]:
-            start_time = word[1]
-            end_time = word[2]
-            center_time = (start_time + end_time) / 2
-            word_duration = end_time - start_time
-
-            # extract multi-scale speaker embedding for the word
+        for words_batch in batch_generator(seg["word_timing"], batch_size):
             word_embedding = []
-            for min_window_size in min_embedding_windows:
-                if word_duration < min_window_size:
-                    # if the word duration is shorter than the window size, use a window centered at the word.
-                    # The window may cover other neighboring words
-                    start_time2 = np.maximum(0, center_time - min_window_size/2)
-                    end_time2 = np.minimum(wav_duration, center_time + min_window_size/2)
-                    start_sample = int(start_time2*sr)
-                    end_sample = int(end_time2*sr)
-                else:
-                    start_sample = int(start_time*sr)
-                    end_sample = int(end_time*sr)
-
-                ### TO DO
-                ### Use batching to increase speed
-                word_wav = wavs[channel_id][start_sample:end_sample]
-                word_wav = torch.tensor(word_wav[np.newaxis], dtype=torch.float32).to(spk_model.device)
-                word_lens = torch.tensor([word_wav.shape[1]], dtype=torch.int).to(spk_model.device)
-                with autocast(), torch.no_grad():
-                    _, tmp_embedding = spk_model.forward(input_signal=word_wav, input_signal_length=word_lens)
-                word_embedding.append(tmp_embedding.cpu().detach())
-
-            words_processed += 1
-
-            if words_processed > n_words:
-                # This is a dummy word added for DDP. Skip it.
-                continue
-
-            if word_duration > max_allowed_word_duration:
-                # Very long word duration is very suspicious and may harm diarization. Ignore them for now.
-                # Note that these words will disappear in the final result.
-                # To do: find a better way to deal with these words.
-                _LOG.info(f"word '{word[0]}' has unreasonablly long duration ({start_time}s, {end_time}s). Skip it in diarization")
-                too_long_words.append(word)
-                continue
-
-            # append only the real words (do not append dummy words)
-            all_words.append(word+[channel_id])
-            all_word_embeddings.append(torch.vstack(word_embedding))
+            word_wavs = []
+            word_lens = []
+            for word in words_batch:
+                start_time = word[1]
+                end_time = word[2]
+                center_time = (start_time + end_time) / 2
+                word_duration = end_time - start_time
+
+                # extract multi-scale speaker embedding for the word
+                for min_window_size in min_embedding_windows:
+                    if word_duration < min_window_size:
+                        # if the word duration is shorter than the window size, use a window centered at the word.
+                        # The window may cover other neighboring words
+                        start_time2 = np.maximum(0, center_time - min_window_size/2)
+                        end_time2 = np.minimum(wav_duration, center_time + min_window_size/2)
+                        start_sample = int(start_time2*sr)
+                        end_sample = int(end_time2*sr)
+                    else:
+                        start_sample = int(start_time*sr)
+                        end_sample = int(end_time*sr)
+
+                    word_wav = wavs[channel_id][start_sample:end_sample]
+                    word_wavs.append(torch.tensor(word_wav, dtype=torch.float32).to(spk_model.device))
+                    word_lens.append(torch.tensor(word_wav.shape[0], dtype=torch.int).to(spk_model.device))
+            with autocast(), torch.no_grad():
+                word_wavs = pad_sequence(word_wavs, batch_first=True, padding_value=0)
+                word_lens = torch.stack(word_lens)
+                _, tmp_embedding = spk_model.forward(input_signal=word_wavs, input_signal_length=word_lens)
+                word_embedding.extend(tmp_embedding.cpu().detach())
+
+            for word, word_embedding_local in zip(words_batch, batch_generator(word_embedding, len(min_embedding_windows))):
+                words_processed += 1
+                start_time = word[1]
+                end_time = word[2]
+                word_duration = end_time - start_time
+                if words_processed > n_words:
+                    # This is a dummy word added for DDP. Skip it.
+                    continue
+
+                if word_duration > max_allowed_word_duration:
+                    # Very long word duration is very suspicious and may harm diarization. Ignore them for now.
+                    # Note that these words will disappear in the final result.
+                    # To do: find a better way to deal with these words.
+                    _LOG.info(f"word '{word[0]}' has unreasonably long duration ({start_time}s, {end_time}s). Skip it in diarization")
+                    too_long_words.append(word)
+                    continue
+
+                # append only the real words (do not append dummy words)
+                all_words.append(word+[channel_id])
+                all_word_embeddings.append(torch.vstack(word_embedding_local))

     print(f'Done extracting embeddings. {words_processed=}, {len(all_words)=}, {n_words=}', flush=True)
     n_real_words = n_words - len(too_long_words)
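
For reference, the batching pattern introduced above can be exercised in isolation. The sketch below is illustrative only: toy_encoder is a hypothetical stand-in for spk_model.forward (it just returns one fixed-size vector per padded clip), and the random clips stand in for the per-scale word windows; only the chunking with batch_generator, the padding with pad_sequence, and the regrouping of the flat output back into per-word multi-scale embeddings mirror the patch.

# Illustrative sketch only: toy_encoder and the random clips are hypothetical
# stand-ins; only the batch/pad/regroup flow mirrors the patch above.
import torch
from torch.nn.utils.rnn import pad_sequence


def batch_generator(data, batch_size):
    # Yield consecutive slices of `data` with at most `batch_size` items each.
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]


def toy_encoder(input_signal, input_signal_length):
    # Hypothetical stand-in for spk_model.forward, returning (logits, embeddings).
    # It ignores input_signal_length and emits an 8-dim vector per clip.
    return None, input_signal.mean(dim=1, keepdim=True).repeat(1, 8)


min_embedding_windows = [0.5, 1.0, 1.5]   # window sizes in seconds, one clip per scale
words = [("hello", 0.0, 0.4), ("world", 0.5, 0.9), ("again", 1.0, 1.6)]
sr = 16000

for words_batch in batch_generator(words, batch_size=2):
    # One variable-length clip per (word, scale) pair.
    clips = [torch.randn(int(w * sr)) for _ in words_batch for w in min_embedding_windows]
    lens = torch.tensor([c.shape[0] for c in clips], dtype=torch.int)
    padded = pad_sequence(clips, batch_first=True, padding_value=0)   # (n_clips, max_len)
    with torch.no_grad():
        _, embeddings = toy_encoder(input_signal=padded, input_signal_length=lens)
    # Slice the flat (n_clips, emb_dim) output back into per-word groups of n_scales.
    for word, word_embs in zip(words_batch,
                               batch_generator(list(embeddings), len(min_embedding_windows))):
        multi_scale = torch.vstack(word_embs)    # (n_scales, emb_dim)
        print(word[0], tuple(multi_scale.shape))

In the actual change, the clips come from wavs[channel_id] using the word boundaries, and the single forward pass per batch runs under autocast on spk_model.device.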