Skip to content

Commit

Permalink
Assess global mses for onsets, offsets and midi notes
Browse files Browse the repository at this point in the history
  • Loading branch information
xaviliz committed Nov 8, 2024
1 parent c6be702 commit dd2f3c5
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions test/src/unittests/tonal/test_audio2midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,44 @@ def assessNoteList(
n_expected_notes = len(expected_notes)

# estimate the onset error for each note and estimate the mean
onset_mse = mean(
local_onset_mses = array(
[square(note[1] - estimated[int(note[0])][0]) for note in expected_notes]
)
global_onset_mse = mean(local_onset_mses)

# estimate the onset error for each note and estimate the mean
offset_mse = mean(
local_offset_mses = array(
[square(note[2] - estimated[int(note[0])][1]) for note in expected_notes]
)
global_offset_mse = mean(local_offset_mses)

# estimate the midi note error for each note and estimate the mean
midi_note_mse = mean(
local_midi_note_mses = array(
[square(note[-1] - estimated[int(note[0])][-1]) for note in expected_notes]
)
midi_note_mse = mean(local_midi_note_mses)

# assert outputs
# assert global outputs
self.assertAlmostEqual(n_detected_notes, n_expected_notes, n_notes_tolerance)
self.assertAlmostEqual(onset_mse, 0, onset_tolerance)
self.assertAlmostEqual(offset_mse, 0, offset_tolerance)
self.assertAlmostEqual(global_onset_mse, 0, onset_tolerance)
self.assertAlmostEqual(global_offset_mse, 0, offset_tolerance)
self.assertAlmostEqual(midi_note_mse, midi_note_mse, midi_note_tolerance)

# assert local outputs
message = "Some onset is larger than onset tolerance"
onsets_condition = all([mse < onset_tolerance for mse in local_onset_mses])
self.assertTrue(onsets_condition, message)

message = "Some offset is larger than offset tolerance"
offsets_condition = all([mse < offset_tolerance for mse in local_offset_mses])
self.assertTrue(offsets_condition, message)

message = "Some midi note is larger than midi note tolerance"
midi_note_condition = all(
[mse <= midi_note_tolerance for mse in local_midi_note_mses]
)
self.assertTrue(midi_note_condition, message)

def testARealCaseWithEMajorScale(self):
frame_size = 8192
sample_rate = 48000
Expand Down

0 comments on commit dd2f3c5

Please sign in to comment.