-
Notifications
You must be signed in to change notification settings - Fork 5
/
model.py
151 lines (119 loc) · 5.19 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Diff-WTS model. Mainly adapted from https://github.com/acids-ircam/ddsp_pytorch.
"""
from core import harmonic_synth
from wavetable_synth import WavetableSynth
import torch
import torch.nn as nn
from core import mlp, gru, scale_function, remove_above_nyquist, upsample
from core import amp_to_impulse_response, fft_convolve
import math
from torchvision.transforms import Resize
class Reverb(nn.Module):
    """Trainable reverb built from an exponentially decaying noise impulse.

    The impulse response is ``noise * envelope * sigmoid(wet)`` with its first
    tap forced to 1, and is applied to the input through ``fft_convolve``.

    Args:
        length: Number of samples in the learnable impulse response.
        sampling_rate: Divisor used to turn sample indices into the time axis
            ``t`` (seconds, assuming the rate is given in Hz).
        initial_wet: Starting value of the ``wet`` parameter (pre-sigmoid).
        initial_decay: Starting value of the ``decay`` parameter
            (pre-softplus).
    """

    def __init__(self, length, sampling_rate, initial_wet=0, initial_decay=5):
        super().__init__()
        self.length = length
        self.sampling_rate = sampling_rate

        # Learnable noise in [-1, 1), stored as a (length, 1) column.
        uniform_noise = torch.rand(length) * 2 - 1
        self.noise = nn.Parameter(uniform_noise.unsqueeze(-1))
        self.decay = nn.Parameter(torch.tensor(float(initial_decay)))
        self.wet = nn.Parameter(torch.tensor(float(initial_wet)))

        # Fixed time axis of shape (1, length, 1); a buffer so that it follows
        # the module across devices without being trained.
        time_axis = torch.arange(self.length) / self.sampling_rate
        self.register_buffer("t", time_axis.reshape(1, -1, 1))

    def build_impulse(self):
        """Return the current impulse response, shape ``(1, length, 1)``."""
        # softplus keeps the decay rate positive, so the exponent is <= 0.
        decay_rate = -nn.functional.softplus(-self.decay)
        envelope = torch.exp(decay_rate * self.t * 500)
        shaped_noise = self.noise * envelope
        impulse = shaped_noise * torch.sigmoid(self.wet)
        # The first tap is pinned to 1 regardless of the learned parameters.
        impulse[:, 0] = 1
        return impulse

    def forward(self, x):
        """Convolve ``x`` with the learned impulse response.

        Args:
            x: Input whose dim 1 is the sample axis and whose trailing
                dimension is squeezed before convolution (presumably
                ``(batch, samples, 1)`` -- confirm against callers).

        Returns:
            The result of ``fft_convolve`` with the trailing dimension
            restored.
        """
        n_samples = x.shape[1]
        # Pad along dim 1 so the impulse matches the input length.
        tail = n_samples - self.length
        impulse = nn.functional.pad(self.build_impulse(), (0, 0, 0, tail))
        convolved = fft_convolve(x.squeeze(-1), impulse.squeeze(-1))
        return convolved.unsqueeze(-1)
class WTS(nn.Module):
    """Differentiable synthesizer driven by MFCC, pitch and loudness.

    An MFCC encoder (LayerNorm -> GRU -> Linear) and per-feature MLPs feed a
    GRU decoder. The decoder output drives a harmonic source (either the
    ``WavetableSynth`` or the DDSP-style ``harmonic_synth``) plus a filtered
    noise source; the two are summed.

    Args:
        hidden_size: Width of the MLP / GRU decoder layers.
        n_harmonic: Number of harmonic amplitudes projected for the additive
            synthesizer (one extra output carries the total amplitude).
        n_bands: Number of noise filter bands projected from the decoder.
        sampling_rate: Audio sampling rate; also used as the reverb length.
        block_size: Upsampling factor from frame rate to sample rate.
        n_wavetables: Number of wavetables handed to ``WavetableSynth``.
        mode: ``"wavetable"`` selects ``WavetableSynth``; any other value
            selects ``harmonic_synth``.
        duration_secs: Clip duration forwarded to ``WavetableSynth``.
    """

    def __init__(self, hidden_size, n_harmonic, n_bands, sampling_rate,
                 block_size, n_wavetables, mode="wavetable", duration_secs=3):
        super().__init__()

        # Stored as buffers so they are saved with the module state.
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))

        # NOTE: `encoder` is not used by forward(); it is kept so existing
        # checkpoints (state_dict keys) keep loading.
        self.encoder = mlp(30, hidden_size, 3)

        # MFCC encoder: 30 coefficients -> 16-dim latent per frame.
        self.layer_norm = nn.LayerNorm(30)
        self.gru_mfcc = nn.GRU(30, 512, batch_first=True)
        self.mlp_mfcc = nn.Linear(512, 16)

        # One input MLP each for pitch, loudness and the MFCC latent.
        self.in_mlps = nn.ModuleList([mlp(1, hidden_size, 3),
                                      mlp(1, hidden_size, 3),
                                      mlp(16, hidden_size, 3)])
        self.gru = gru(3, hidden_size)
        # Decoder input is the GRU output concatenated with the 3 MLP outputs.
        self.out_mlp = mlp(hidden_size * 4, hidden_size, 3)

        # Maps loudness to a (0, 1) gain used by the wavetable synthesizer.
        self.loudness_mlp = nn.Sequential(
            nn.Linear(1, 1),
            nn.Sigmoid()
        )

        # [0]: total amplitude + harmonic amplitudes, [1]: noise filter bands.
        self.proj_matrices = nn.ModuleList([
            nn.Linear(hidden_size, n_harmonic + 1),
            nn.Linear(hidden_size, n_bands),
        ])

        # Impulse of `sampling_rate` samples; not applied in forward().
        self.reverb = Reverb(sampling_rate, sampling_rate)
        self.wts = WavetableSynth(n_wavetables=n_wavetables,
                                  sr=sampling_rate,
                                  duration_secs=duration_secs,
                                  block_size=block_size)

        # NOTE: unused by forward(); kept for state_dict compatibility.
        self.register_buffer("cache_gru", torch.zeros(1, 1, hidden_size))
        self.register_buffer("phase", torch.zeros(1))

        self.mode = mode
        self.duration_secs = duration_secs

    def forward(self, mfcc, pitch, loudness):
        """Synthesize audio from frame-level features.

        Args:
            mfcc: MFCC tensor with 30 coefficients on dim 1 (it is transposed
                to put them last before ``LayerNorm(30)``), i.e.
                ``(batch, 30, mfcc_frames)``.
            pitch: Frame-level pitch, ``(batch, frames, 1)``
                (assumed from the 1-input MLP -- confirm against callers).
            loudness: Frame-level loudness, same layout as ``pitch``.

        Returns:
            ``harmonic + noise``, where noise has shape
            ``(batch, samples, 1)``.
        """
        # --- MFCC encoder ---------------------------------------------------
        # Layer norm instead of a trainable norm; not much difference found.
        mfcc = self.layer_norm(torch.transpose(mfcc, 1, 2))
        mfcc = self.gru_mfcc(mfcc)[0]
        mfcc = self.mlp_mfcc(mfcc)

        # Align the MFCC latent with the control frames using an image resize
        # (DDSP does the same). The target length is taken from `pitch` rather
        # than a hard-coded `duration_secs * 100`: the concatenation below
        # requires matching frame counts anyway, and this removes the implicit
        # 100-frames-per-second / integer-duration assumption.
        n_frames = pitch.shape[1]
        mfcc = Resize(size=(n_frames, mfcc.shape[-1]))(mfcc)

        # --- Decoder --------------------------------------------------------
        hidden = torch.cat([
            self.in_mlps[0](pitch),
            self.in_mlps[1](loudness),
            self.in_mlps[2](mfcc)
        ], -1)
        hidden = torch.cat([self.gru(hidden)[0], hidden], -1)
        hidden = self.out_mlp(hidden)

        # --- Harmonic part --------------------------------------------------
        frame_pitch = pitch
        pitch = upsample(frame_pitch, self.block_size)

        if self.mode == "wavetable":
            # diff-wave-synth synthesizer.
            # TODO: wts can't backprop when driven by the projected total
            # amplitude (proj_matrices[0]), not sure why; the loudness-derived
            # gain is used instead.
            gain = upsample(self.loudness_mlp(loudness), self.block_size)
            harmonic = self.wts(pitch, gain)
        else:
            # ddsp synthesizer.
            param = scale_function(self.proj_matrices[0](hidden))
            total_amp = param[..., :1]
            amplitudes = remove_above_nyquist(
                param[..., 1:],
                frame_pitch,
                self.sampling_rate,
            )
            # Normalise the harmonic distribution, then apply the overall
            # amplitude. Out-of-place ops keep autograd free of in-place edits.
            amplitudes = amplitudes / amplitudes.sum(-1, keepdim=True)
            amplitudes = amplitudes * total_amp
            amplitudes = upsample(amplitudes, self.block_size)
            harmonic = harmonic_synth(pitch, amplitudes, self.sampling_rate)

        # --- Noise part -----------------------------------------------------
        param = scale_function(self.proj_matrices[1](hidden) - 5)
        impulse = amp_to_impulse_response(param, self.block_size)
        # Uniform noise in [-1, 1), one block per frame, on impulse's
        # device/dtype.
        noise = torch.rand(
            impulse.shape[0],
            impulse.shape[1],
            self.block_size,
        ).to(impulse) * 2 - 1
        noise = fft_convolve(noise, impulse).contiguous()
        noise = noise.reshape(noise.shape[0], -1, 1)

        signal = harmonic + noise

        # NOTE: the reverb stage is currently disabled; enable it with
        # `signal = self.reverb(signal)`.
        return signal