diff --git a/mel2wav/dataset.py b/mel2wav/dataset.py index 0982af8..2447034 100644 --- a/mel2wav/dataset.py +++ b/mel2wav/dataset.py @@ -36,10 +36,16 @@ def __init__(self, training_files, segment_length, sampling_rate, augment=True): random.shuffle(self.audio_files) self.augment = augment + # Load all audio files into memory + self.audio_data = [] + for audio_file in self.audio_files: + audio, _ = self.load_wav_to_torch(audio_file) + self.audio_data.append(audio) + def __getitem__(self, index): - # Read audio - filename = self.audio_files[index] - audio, sampling_rate = self.load_wav_to_torch(filename) + # Get audio from memory + audio = self.audio_data[index] + # Take segment if audio.size(0) >= self.segment_length: max_audio_start = audio.size(0) - self.segment_length