Merge pull request #36 from Lakoc/word_based_diarization_batched
Batched extraction of speaker embeddings for Nemo word-level baseline
nidleo authored Apr 16, 2024
2 parents 3d80ac4 + f49d6d9 commit 0993574
Showing 1 changed file with 58 additions and 47 deletions.
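The gist of the change: rather than calling the speaker-embedding model once per word clip (the old ### TO DO comment in the diff flagged exactly this), clips are now accumulated over a batch of words, zero-padded to a common length, and embedded in a single forward pass. Below is a minimal, self-contained sketch of that pattern, assuming a NeMo-style model whose forward(input_signal, input_signal_length) returns (logits, embeddings); the helper name embed_in_batches and the bare model stand-in are illustrative, not part of this repository:

import torch
from torch.nn.utils.rnn import pad_sequence

def embed_in_batches(clips, model, batch_size=32):
    # clips: list of 1-D float32 waveform tensors of varying length
    embeddings = []
    for i in range(0, len(clips), batch_size):
        batch = clips[i:i + batch_size]
        # keep the true lengths so the model can discount the zero padding
        lens = torch.tensor([c.shape[0] for c in batch], dtype=torch.int)
        padded = pad_sequence(batch, batch_first=True, padding_value=0)
        with torch.no_grad():
            _, emb = model(input_signal=padded, input_signal_length=lens)
        embeddings.extend(emb.cpu())
    return embeddings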
105 changes: 58 additions & 47 deletions diarization/word_based_diarization.py
@@ -14,7 +14,7 @@
 from diarization.diarization import DiarizationCfg
 from diarization.diarization_common import prepare_diarized_data_frame, DiarizationCfg
 from utils.logging_def import get_logger
-
+from torch.nn.utils.rnn import pad_sequence
 _LOG = get_logger('word_based_diarization')
 
 
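The newly imported pad_sequence is what makes the batched forward pass possible: it stacks a list of variable-length 1-D tensors into a single (batch, max_len) tensor, zero-filling the tails of shorter clips, while the true clip lengths are passed to the model separately so the padding can be discounted. A quick illustration:

>>> import torch
>>> from torch.nn.utils.rnn import pad_sequence
>>> pad_sequence([torch.ones(3), torch.ones(5)], batch_first=True, padding_value=0).shape
torch.Size([2, 5])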
@@ -50,7 +50,12 @@ def run_clustering(raw_affinity_mat: np.array, max_num_speakers: int=8, max_rp_t
     return cluster_label
 
 
-def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3):
+def batch_generator(data, batch_size):
+    for i in range(0, len(data), batch_size):
+        yield data[i:i + batch_size]
+
+
+def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_embedding_windows, max_allowed_word_duration=3, batch_size=32):
     """
     For each word, use its word boundary information to extract multi-scale speaker embedding vectors.
     """
@@ -68,52 +73,58 @@ def extract_speaker_embedding_for_words(segments_df, wavs, sr, spk_model, min_em
         # get the unmixed channel id for current segment
         channel_id = seg.wav_file_name_ind
 
-        for word in seg["word_timing"]:
-            start_time = word[1]
-            end_time = word[2]
-            center_time = (start_time + end_time) / 2
-            word_duration = end_time - start_time
-
-            # extract multi-scale speaker embedding for the word
+        for words_batch in batch_generator(seg["word_timing"], batch_size):
             word_embedding = []
-            for min_window_size in min_embedding_windows:
-                if word_duration < min_window_size:
-                    # if the word duration is shorter than the window size, use a window centered at the word.
-                    # The window may cover other neighboring words
-                    start_time2 = np.maximum(0, center_time - min_window_size/2)
-                    end_time2 = np.minimum(wav_duration, center_time + min_window_size/2)
-                    start_sample = int(start_time2*sr)
-                    end_sample = int(end_time2*sr)
-                else:
-                    start_sample = int(start_time*sr)
-                    end_sample = int(end_time*sr)
-
-                ### TO DO
-                ### Use batching to increase speed
-                word_wav = wavs[channel_id][start_sample:end_sample]
-                word_wav = torch.tensor(word_wav[np.newaxis], dtype=torch.float32).to(spk_model.device)
-                word_lens = torch.tensor([word_wav.shape[1]], dtype=torch.int).to(spk_model.device)
-                with autocast(), torch.no_grad():
-                    _, tmp_embedding = spk_model.forward(input_signal=word_wav, input_signal_length=word_lens)
-                word_embedding.append(tmp_embedding.cpu().detach())
-
-            words_processed += 1
-
-            if words_processed > n_words:
-                # This is a dummy word added for DDP. Skip it.
-                continue
-
-            if word_duration > max_allowed_word_duration:
-                # Very long word durations are suspicious and may harm diarization. Ignore them for now.
-                # Note that these words will disappear in the final result.
-                # To do: find a better way to deal with these words.
-                _LOG.info(f"word '{word[0]}' has an unreasonably long duration ({start_time}s, {end_time}s). Skip it in diarization")
-                too_long_words.append(word)
-                continue
-
-            # append only the real words (do not append dummy words)
-            all_words.append(word+[channel_id])
-            all_word_embeddings.append(torch.vstack(word_embedding))
+            word_wavs = []
+            word_lens = []
+            for word in words_batch:
+                start_time = word[1]
+                end_time = word[2]
+                center_time = (start_time + end_time) / 2
+                word_duration = end_time - start_time
+
+                # extract multi-scale speaker embedding for the word
+                for min_window_size in min_embedding_windows:
+                    if word_duration < min_window_size:
+                        # if the word duration is shorter than the window size, use a window centered at the word.
+                        # The window may cover other neighboring words
+                        start_time2 = np.maximum(0, center_time - min_window_size/2)
+                        end_time2 = np.minimum(wav_duration, center_time + min_window_size/2)
+                        start_sample = int(start_time2*sr)
+                        end_sample = int(end_time2*sr)
+                    else:
+                        start_sample = int(start_time*sr)
+                        end_sample = int(end_time*sr)
+
+                    word_wav = wavs[channel_id][start_sample:end_sample]
+                    word_wavs.append(torch.tensor(word_wav, dtype=torch.float32).to(spk_model.device))
+                    word_lens.append(torch.tensor(word_wav.shape[0], dtype=torch.int).to(spk_model.device))
+            with autocast(), torch.no_grad():
+                word_wavs = pad_sequence(word_wavs, batch_first=True, padding_value=0)
+                word_lens = torch.stack(word_lens)
+                _, tmp_embedding = spk_model.forward(input_signal=word_wavs, input_signal_length=word_lens)
+                word_embedding.extend(tmp_embedding.cpu().detach())
+
+            for word, word_embedding_local in zip(words_batch, batch_generator(word_embedding, len(min_embedding_windows))):
+                words_processed += 1
+                start_time = word[1]
+                end_time = word[2]
+                word_duration = end_time - start_time
+                if words_processed > n_words:
+                    # This is a dummy word added for DDP. Skip it.
+                    continue
+
+                if word_duration > max_allowed_word_duration:
+                    # Very long word durations are suspicious and may harm diarization. Ignore them for now.
+                    # Note that these words will disappear in the final result.
+                    # To do: find a better way to deal with these words.
+                    _LOG.info(f"word '{word[0]}' has an unreasonably long duration ({start_time}s, {end_time}s). Skip it in diarization")
+                    too_long_words.append(word)
+                    continue
+
+                # append only the real words (do not append dummy words)
+                all_words.append(word+[channel_id])
+                all_word_embeddings.append(torch.vstack(word_embedding_local))
 
     print(f'Done extracting embeddings. {words_processed=}, {len(all_words)=}, {n_words=}', flush=True)
     n_real_words = n_words - len(too_long_words)
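One detail the new code relies on but does not spell out: every word contributes exactly len(min_embedding_windows) clips to the flattened batch, in word order, so batch_generator can be reused after the forward pass to regroup the flat embedding list into one multi-scale group per word. A small sketch of that invariant with placeholder values (the strings stand in for embedding tensors):

def batch_generator(data, batch_size):  # same helper as in the diff
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

min_embedding_windows = [0.5, 1.0, 1.5]  # three scales per word
flat = ['w0s0', 'w0s1', 'w0s2', 'w1s0', 'w1s1', 'w1s2']  # two words x three scales
groups = list(batch_generator(flat, len(min_embedding_windows)))
assert groups == [['w0s0', 'w0s1', 'w0s2'], ['w1s0', 'w1s1', 'w1s2']]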