Batched extraction of speaker embeddings for Nemo word-level baseline #36

Merged · 1 commit · Apr 16, 2024
105 changes: 58 additions & 47 deletions diarization/word_based_diarization.py
@@ -14,7 +14,7 @@
 from diarization.diarization import DiarizationCfg
 from diarization.diarization_common import prepare_diarized_data_frame, DiarizationCfg
 from utils.logging_def import get_logger
-
+from torch.nn.utils.rnn import pad_sequence
 _LOG = get_logger('word_based_diarization')
 

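For context on the new import: `torch.nn.utils.rnn.pad_sequence` right-pads a list of variable-length 1-D tensors into one batch tensor, which is what lets the word crops below be stacked for a single forward pass. A minimal sketch with made-up lengths (values are illustrative, not from the PR):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three "waveforms" of different lengths (illustrative values only).
wavs = [torch.ones(4), torch.ones(2), torch.ones(3)]
batch = pad_sequence(wavs, batch_first=True, padding_value=0)
print(batch.shape)  # torch.Size([3, 4]) -- padded up to the longest sequence
# The true lengths are passed alongside so the model can ignore the padding.
lens = torch.tensor([w.shape[0] for w in wavs])
```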
@@ -50,7 +50,12 @@ def run_clustering(raw_affinity_mat: np.array, max_num_speakers: int=8, max_rp_t
     return cluster_label


-def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3):
+def batch_generator(data, batch_size):
+    for i in range(0, len(data), batch_size):
+        yield data[i:i + batch_size]
+
+
+def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3, batch_size=32):
     """
     For each word, use its word boundary information to extract multi-scale speaker embedding vectors.
     """
@@ -68,52 +73,58 @@ def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_em
         # get the unmixed channel id for current segment
         channel_id = seg.wav_file_name_ind
 
-        for word in seg["word_timing"]:
-            start_time = word[1]
-            end_time = word[2]
-            center_time = (start_time + end_time) / 2
-            word_duration = end_time - start_time
-
-            # extract multi-scale speaker embedding for the word
+        for words_batch in batch_generator(seg["word_timing"], batch_size):
             word_embedding = []
-            for min_window_size in min_embedding_windows:
-                if word_duration < min_window_size:
-                    # if the word duration is shorter than the window size, use a window centered at the word.
-                    # The window may cover other neighboring words
-                    start_time2 = np.maximum(0, center_time - min_window_size/2)
-                    end_time2 = np.minimum(wav_duration, center_time + min_window_size/2)
-                    start_sample = int(start_time2*sr)
-                    end_sample = int(end_time2*sr)
-                else:
-                    start_sample = int(start_time*sr)
-                    end_sample = int(end_time*sr)
-
-                ### TO DO
-                ### Use batching to increase speed
-                word_wav = wavs[channel_id][start_sample:end_sample]
-                word_wav = torch.tensor(word_wav[np.newaxis], dtype=torch.float32).to(spk_model.device)
-                word_lens = torch.tensor([word_wav.shape[1]], dtype=torch.int).to(spk_model.device)
-                with autocast(), torch.no_grad():
-                    _, tmp_embedding = spk_model.forward(input_signal=word_wav, input_signal_length=word_lens)
-                word_embedding.append(tmp_embedding.cpu().detach())
-
-            words_processed += 1
-
-            if words_processed > n_words:
-                # This is a dummy word added for DDP. Skip it.
-                continue
-
-            if word_duration > max_allowed_word_duration:
-                # A very long word duration is suspicious and may harm diarization. Ignore such words for now.
-                # Note that these words will disappear in the final result.
-                # To do: find a better way to deal with these words.
-                _LOG.info(f"word '{word[0]}' has unreasonably long duration ({start_time}s, {end_time}s). Skip it in diarization")
-                too_long_words.append(word)
-                continue
-
-            # append only the real words (do not append dummy words)
-            all_words.append(word+[channel_id])
-            all_word_embeddings.append(torch.vstack(word_embedding))
+            word_wavs = []
+            word_lens = []
+            for word in words_batch:
+                start_time = word[1]
+                end_time = word[2]
+                center_time = (start_time + end_time) / 2
+                word_duration = end_time - start_time
+
+                # extract multi-scale speaker embedding for the word
+                for min_window_size in min_embedding_windows:
+                    if word_duration < min_window_size:
+                        # if the word duration is shorter than the window size, use a window centered at the word.
+                        # The window may cover other neighboring words
+                        start_time2 = np.maximum(0, center_time - min_window_size/2)
+                        end_time2 = np.minimum(wav_duration, center_time + min_window_size/2)
+                        start_sample = int(start_time2*sr)
+                        end_sample = int(end_time2*sr)
+                    else:
+                        start_sample = int(start_time*sr)
+                        end_sample = int(end_time*sr)
+
+                    word_wav = wavs[channel_id][start_sample:end_sample]
+                    word_wavs.append(torch.tensor(word_wav, dtype=torch.float32).to(spk_model.device))
+                    word_lens.append(torch.tensor(word_wav.shape[0], dtype=torch.int).to(spk_model.device))
+            with autocast(), torch.no_grad():
+                word_wavs = pad_sequence(word_wavs, batch_first=True, padding_value=0)
+                word_lens = torch.stack(word_lens)
+                _, tmp_embedding = spk_model.forward(input_signal=word_wavs, input_signal_length=word_lens)
+                word_embedding.extend(tmp_embedding.cpu().detach())
+
+            for word, word_embedding_local in zip(words_batch, batch_generator(word_embedding, len(min_embedding_windows))):
+                words_processed += 1
+                start_time = word[1]
+                end_time = word[2]
+                word_duration = end_time - start_time
+                if words_processed > n_words:
+                    # This is a dummy word added for DDP. Skip it.
+                    continue
+
+                if word_duration > max_allowed_word_duration:
+                    # A very long word duration is suspicious and may harm diarization. Ignore such words for now.
+                    # Note that these words will disappear in the final result.
+                    # To do: find a better way to deal with these words.
+                    _LOG.info(f"word '{word[0]}' has unreasonably long duration ({start_time}s, {end_time}s). Skip it in diarization")
+                    too_long_words.append(word)
+                    continue
+
+                # append only the real words (do not append dummy words)
+                all_words.append(word+[channel_id])
+                all_word_embeddings.append(torch.vstack(word_embedding_local))
 
     print(f'Done extracting embeddings. {words_processed=}, {len(all_words)=}, {n_words=}', flush=True)
     n_real_words = n_words - len(too_long_words)
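One subtlety worth noting: the batched forward returns one embedding per (word, window-size) pair in word-major order, because the window loop is nested inside the word loop; reusing `batch_generator` with chunk size `len(min_embedding_windows)` then regroups the flat list into per-word multi-scale stacks. A hedged sketch of that layout with dummy embeddings (sizes are illustrative, not from the PR):

```python
import torch

def batch_generator(data, batch_size):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

n_scales, n_words, emb_dim = 3, 4, 192  # dummy sizes
# Flat list as produced by the batched forward: word-major
# (all scales of word 0 first, then all scales of word 1, etc.).
flat = [torch.randn(emb_dim) for _ in range(n_words * n_scales)]

per_word = [torch.vstack(chunk) for chunk in batch_generator(flat, n_scales)]
assert len(per_word) == n_words
assert per_word[0].shape == (n_scales, emb_dim)
```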