diff --git a/test/src/unittests/tonal/test_audio2midi.py b/test/src/unittests/tonal/test_audio2midi.py index 748f51753..ad3f4566d 100644 --- a/test/src/unittests/tonal/test_audio2midi.py +++ b/test/src/unittests/tonal/test_audio2midi.py @@ -59,26 +59,44 @@ def assessNoteList( n_expected_notes = len(expected_notes) # estimate the onset error for each note and estimate the mean - onset_mse = mean( + local_onset_mses = array( [square(note[1] - estimated[int(note[0])][0]) for note in expected_notes] ) + global_onset_mse = mean(local_onset_mses) # estimate the onset error for each note and estimate the mean - offset_mse = mean( + local_offset_mses = array( [square(note[2] - estimated[int(note[0])][1]) for note in expected_notes] ) + global_offset_mse = mean(local_offset_mses) # estimate the midi note error for each note and estimate the mean - midi_note_mse = mean( + local_midi_note_mses = array( [square(note[-1] - estimated[int(note[0])][-1]) for note in expected_notes] ) + midi_note_mse = mean(local_midi_note_mses) - # assert outputs + # assert global outputs self.assertAlmostEqual(n_detected_notes, n_expected_notes, n_notes_tolerance) - self.assertAlmostEqual(onset_mse, 0, onset_tolerance) - self.assertAlmostEqual(offset_mse, 0, offset_tolerance) + self.assertAlmostEqual(global_onset_mse, 0, onset_tolerance) + self.assertAlmostEqual(global_offset_mse, 0, offset_tolerance) self.assertAlmostEqual(midi_note_mse, midi_note_mse, midi_note_tolerance) + # assert local outputs + message = "Some onset is larger than onset tolerance" + onsets_condition = all([mse < onset_tolerance for mse in local_onset_mses]) + self.assertTrue(onsets_condition, message) + + message = "Some offset is larger than offset tolerance" + offsets_condition = all([mse < offset_tolerance for mse in local_offset_mses]) + self.assertTrue(offsets_condition, message) + + message = "Some midi note is larger than midi note tolerance" + midi_note_condition = all( + [mse <= midi_note_tolerance for mse in local_midi_note_mses] + ) + self.assertTrue(midi_note_condition, message) + def testARealCaseWithEMajorScale(self): frame_size = 8192 sample_rate = 48000