diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py index 4041ada..f0ab758 100644 --- a/modules/vad/silero_vad.py +++ b/modules/vad/silero_vad.py @@ -218,13 +218,6 @@ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks]) - def get_chunk_index(self, time: float) -> int: - sample = int(time * self.sampling_rate) - return min( - bisect.bisect(self.chunk_end_sample, sample), - len(self.chunk_end_sample) - 1, - ) - @staticmethod def format_timestamp( seconds: float, @@ -260,8 +253,23 @@ def restore_speech_timestamps( ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) for segment in segments: - segment.start = ts_map.get_original_time(segment.start) - segment.end = ts_map.get_original_time(segment.end) + if segment.words: + words = [] + for word in segment.words: + # Ensure the word start and end times are resolved to the same chunk. + middle = (word.start + word.end) / 2 + chunk_index = ts_map.get_chunk_index(middle) + word.start = ts_map.get_original_time(word.start, chunk_index) + word.end = ts_map.get_original_time(word.end, chunk_index) + words.append(word) + + segment.start = words[0].start + segment.end = words[-1].end + segment.words = words + + else: + segment.start = ts_map.get_original_time(segment.start) + segment.end = ts_map.get_original_time(segment.end) return segments