-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_processing.py
37 lines (27 loc) · 1.1 KB
/
data_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# data_processing.py
import csv
import pretty_midi
import numpy as np
def extract_midi_paths_from_metadata(csv_path="./maestro-v3.0.0/maestro-v3.0.0.csv"):
    """Return the ``midi_filename`` column of a metadata CSV file.

    Args:
        csv_path: Location of the metadata CSV. Defaults to the path this
            module has always used, so existing zero-argument calls behave
            exactly as before.

    Returns:
        A list with one entry per data row, in file order, holding the
        value of that row's ``midi_filename`` field exactly as stored in
        the CSV (no path joining or normalisation is applied).

    Raises:
        OSError: If ``csv_path`` cannot be opened.
        KeyError: If the CSV has data rows but no ``midi_filename`` column.
    """
    # newline='' is what the csv module asks for so that quoted fields
    # containing line breaks are parsed correctly.
    with open(csv_path, newline='', encoding="utf-8") as csvfile:
        return [row['midi_filename'] for row in csv.DictReader(csvfile)]
def extract_features(midi_file_path, n=32):
    """Build sliding-window training examples from one MIDI file.

    Only the first instrument in the file is read. Every run of ``n``
    consecutive notes becomes one input window, paired with a label
    taken from the note that immediately follows the window.

    Args:
        midi_file_path: Path handed to ``pretty_midi.PrettyMIDI``.
        n: Number of notes in each input window.

    Returns:
        A list of ``(window, label)`` pairs, where ``window`` is a list of
        ``n`` ``(pitch, velocity, end - start)`` tuples and ``label`` is the
        ``(end - start, velocity)`` tuple of the following note. The list
        is empty when the file has no instruments or when the first
        instrument has no more than ``n`` notes.
    """
    midi = pretty_midi.PrettyMIDI(midi_file_path)
    if not midi.instruments:
        return []

    track = midi.instruments[0].notes
    # NOTE(review): notes are consumed in whatever order pretty_midi stores
    # them; confirm that order is the intended "next note" order.
    examples = []
    for start in range(len(track) - n):
        stop = start + n
        window = [
            (note.pitch, note.velocity, note.end - note.start)
            for note in track[start:stop]
        ]
        target = track[stop]
        label = (target.end - target.start, target.velocity)
        examples.append((window, label))
    return examples