forked from malikwang/FFTViolin
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FFT.py
177 lines (163 loc) · 5.29 KB
/
FFT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python
# -*-coding:utf8-*-#
import os
import librosa
import copy
import math
import numpy as np
from numpy.fft import fft,ifft
# --- Analysis configuration -------------------------------------------------
# Number of samples per analysed frame (also the FFT length).
window_size = 4096
# Sample rate, in Hz, that every clip is resampled to before analysis.
target_sr = 16000
# Hamming taper applied to a frame before the FFT.
window = np.hamming(window_size)

# --- Lookup tables, filled in by init_list() and init_dic() -----------------
# Note names and their frequencies in Hz, index-aligned with each other.
pitch_list, hz_list = [], []
# Maps a note name to the range of FFT bin indices attributed to that note.
pitch_sample_dic = {}

# --- Running tallies updated while the dataset is processed -----------------
total_count = correct_count = 0
def init_list():
    """Fill the module-level pitch_list and hz_list tables.

    The notes to be recognised are G3..E7 (MIDI 55..100).  One extra semitone
    is included on each side (MIDI 54 and 101) so that every recognised note
    has a neighbour on both sides when the band edges are computed later.
    Entries are appended, so calling this twice duplicates the tables.
    """
    # NOTE(review): the spelling of accidentals returned by midi_to_note
    # (ASCII '#' vs. Unicode sharp) depends on the installed librosa version;
    # confirm it matches the note names used as dataset labels.
    for midi_number in range(54, 102):
        note_name = librosa.midi_to_note(midi_number)
        frequency = librosa.note_to_hz(note_name)
        pitch_list.append(note_name)
        hz_list.append(frequency)
def init_dic():
    """Fill pitch_sample_dic with the FFT bins that belong to each note.

    Only positions 1..46 of hz_list / pitch_list are processed: these are the
    notes to be recognised, while positions 0 and 47 are the padding notes
    that exist solely to provide a neighbour.  The band of a note runs from
    the midpoint with its lower neighbour to the midpoint with its upper
    neighbour; both edges are converted to FFT bin indices and the stored
    range is (low_bin + 1) .. up_bin inclusive.
    """
    def to_bin(frequency):
        # Bin index of a frequency for a window_size-point FFT at target_sr.
        return int(frequency * window_size / target_sr)

    for position in range(1, 47):
        lower_edge = (float(hz_list[position - 1]) + float(hz_list[position])) / 2
        upper_edge = (float(hz_list[position]) + float(hz_list[position + 1])) / 2
        first_bin = to_bin(lower_edge) + 1
        stop_bin = to_bin(upper_edge) + 1
        pitch_sample_dic[pitch_list[position]] = range(first_bin, stop_bin)
# Drop every entry that is not an audio (.wav) file.
def filter(file_list):
    """Return the names in file_list that end with the '.wav' extension.

    The extension is matched as a suffix, case-insensitively, so 'A4.WAV' is
    kept while names that merely contain '.wav' somewhere (for example
    'take.wav.asd') are dropped.  Input order is preserved and an empty input
    yields an empty list.

    NOTE: this function shadows the builtin filter(); the name is kept
    because other code in this file calls it.
    """
    return [name for name in file_list if name.lower().endswith('.wav')]
def sort_note(note):
    """Return note with its '_'-separated parts in sorted order.

    Lets two labels that name the same notes in a different order
    (e.g. 'B4_A4' and 'A4_B4') compare equal.
    """
    return '_'.join(sorted(note.split('_')))
# Semitone offsets of the 2nd, 3rd, 4th and 5th harmonics above a fundamental
# (octave, octave + fifth, two octaves, two octaves + major third).
_HARMONIC_OFFSETS = (12, 19, 24, 28)


def _add_harmonic_energy(energy_list):
    """Add each note's harmonic energy to its own entry, in place.

    Each harmonic contributes at most five times the note's own energy, and
    all contributions are taken from a snapshot made before any update.
    """
    snapshot = list(energy_list)
    note_count = len(snapshot)
    for index in range(note_count):
        for offset in _HARMONIC_OFFSETS:
            if index + offset >= note_count:
                # Offsets ascend, so every later harmonic is out of range too.
                break
            energy_list[index] += min(5 * snapshot[index],
                                      snapshot[index + offset])


def _damp_note(energy_list, index):
    """Scale a note and its in-range harmonics by 0.01, in place."""
    energy_list[index] *= 0.01
    for offset in _HARMONIC_OFFSETS:
        if index + offset >= len(energy_list):
            break
        energy_list[index + offset] *= 0.01


def _pick_pitch_index(energy_list):
    """Return the pitch_list index of the strongest note in energy_list.

    If the note one octave below the maximum holds more than a quarter of
    the maximum's energy, that lower note is chosen instead.
    """
    best = int(np.argmax(energy_list))
    lower = best - 12
    # A negative index would silently wrap to the top of the list, so the
    # octave below is only considered when it really exists.
    if lower >= 0 and 4 * energy_list[lower] > energy_list[best]:
        best = lower
    # energy_list[i] describes pitch_list[i + 1].
    return best + 1


def fft_transform(file_path, true_pitch):
    """Predict up to two notes in an audio clip and score the prediction.

    The clip is resampled to target_sr, one window_size-sample frame is
    Hamming-windowed and transformed, and the spectral energy in each note's
    band (see pitch_sample_dic) is summed.  The strongest note is picked,
    damped together with its harmonics, and a second note is picked the same
    way.

    Args:
        file_path: path of the audio file to analyse.
        true_pitch: expected label; a single note name, or two note names
            joined by '_' (in either order).

    Returns:
        None.  On a match the module-level correct_count is incremented;
        otherwise the prediction and the expected label are printed.

    Raises:
        ValueError: if the clip is too short to supply a full frame.
    """
    global correct_count
    x, sr = librosa.load(file_path, sr=None)
    # Keyword arguments: newer librosa releases reject positional rates.
    x = librosa.resample(x, orig_sr=sr, target_sr=target_sr)

    # The frame starts 10000 samples into the resampled clip.
    start = 10000
    x_sample = x[start:start + window_size]
    if len(x_sample) < window_size:
        raise ValueError(
            '%s is too short: need %d samples after resampling, got %d'
            % (file_path, start + window_size, len(x))
        )
    abs_y = abs(fft(x_sample * window))

    # Energy of every candidate fundamental (pitch_list positions 1..46).
    energy_list = []
    for index in range(1, 47):
        sample_list = pitch_sample_dic[pitch_list[index]]
        energy_list.append(sum(abs_y[i] ** 2 for i in sample_list))

    # First note.
    _add_harmonic_energy(energy_list)
    first_index = _pick_pitch_index(energy_list)
    pred_pitch1 = pitch_list[first_index]

    # Second note: damp the first note and its harmonics, then repeat.
    _damp_note(energy_list, first_index - 1)
    _add_harmonic_energy(energy_list)
    pred_pitch2 = pitch_list[_pick_pitch_index(energy_list)]

    if pred_pitch1 == pred_pitch2:
        if pred_pitch1 == true_pitch:
            correct_count += 1
        else:
            print(pred_pitch2, true_pitch)
    else:
        pred_pitch = '%s_%s' % (pred_pitch1, pred_pitch2)
        if sort_note(pred_pitch) == sort_note(true_pitch):
            correct_count += 1
        else:
            print(pred_pitch, true_pitch)
def cepstrum(file_path, true_pitch):
    """Predict a single note with a cepstrum peak search and score it.

    The clip is resampled to target_sr and one Hamming-windowed frame of
    window_size samples is transformed.  The FFT of its log-magnitude
    spectrum is searched for its largest peak between the quefrencies
    target_sr // 2100 and target_sr // 100 samples, and the peak position is
    converted back to a frequency and then to a note name.

    Args:
        file_path: path of the audio file to analyse.
        true_pitch: expected note name.

    Returns:
        None.  On a match the module-level correct_count is incremented;
        otherwise the prediction and the expected label are printed.

    Raises:
        ValueError: if the clip is too short to supply a full frame.
    """
    global correct_count
    x, sr = librosa.load(file_path, sr=None)
    # Keyword arguments: newer librosa releases reject positional rates.
    x = librosa.resample(x, orig_sr=sr, target_sr=target_sr)

    # Slice bounds must be integers, hence floor division.
    ms_a = target_sr // 2100
    ms_b = target_sr // 100

    # The frame starts 10000 samples into the resampled clip.
    start = 10000
    x_sample = x[start:start + window_size]
    if len(x_sample) < window_size:
        raise ValueError(
            '%s is too short: need %d samples after resampling, got %d'
            % (file_path, start + window_size, len(x))
        )

    y = fft(x_sample * np.hamming(window_size))
    # The epsilon keeps log() finite when a spectral bin is exactly zero.
    C = fft(np.log(np.abs(y) + np.finfo(float).eps))
    abs_C_sample = np.abs(C[ms_a:ms_b])
    max_index = int(np.argmax(abs_C_sample))
    # A peak at quefrency q samples corresponds to a period of q samples.
    # float() keeps the division exact under Python 2 as well.
    fx = float(target_sr) / (ms_a + max_index)
    pred_pitch = librosa.hz_to_note(fx)
    if pred_pitch == true_pitch:
        correct_count += 1
    else:
        print(pred_pitch, true_pitch)
# Build the note tables and the note -> FFT-bin map before any analysis.
init_list()
init_dic()

# Walk the dataset: <dataset>/<pitch label>/<split>/<clip>.wav
# Any entry whose name contains a '.' is skipped at both directory levels.
dataset_path = '/Users/lisimin/Desktop/Violin/Dataset/BUPT'
for pitch_dir in os.listdir(dataset_path):
    if '.' in pitch_dir:
        continue
    pitch_dir_path = os.path.join(dataset_path, pitch_dir)
    for split_dir in os.listdir(pitch_dir_path):
        if '.' in split_dir:
            continue
        split_dir_path = os.path.join(pitch_dir_path, split_dir)
        for wav_file in filter(os.listdir(split_dir_path)):
            wav_file_path = os.path.join(split_dir_path, wav_file)
            # The directory name is the expected label for every clip in it.
            # (An earlier variant sent labels below C4 to cepstrum() instead;
            # every clip currently goes through fft_transform().)
            fft_transform(wav_file_path, pitch_dir)
            total_count += 1

# Number of correctly recognised clips, then number of clips processed.
print(correct_count,total_count)