diff --git a/egs2/TEMPLATE/enh1/enh.sh b/egs2/TEMPLATE/enh1/enh.sh
index 39820a3328a..39cca774c9b 100755
--- a/egs2/TEMPLATE/enh1/enh.sh
+++ b/egs2/TEMPLATE/enh1/enh.sh
@@ -78,7 +78,7 @@ extra_wav_list= # Extra list of scp files for wav formatting
 init_param=
 
 # Enhancement related
-inference_args="--normalize_output_wav true"
+inference_args="--normalize_output_wav true --output_format wav"
 inference_model=valid.loss.ave.pth
 download_model=
 
@@ -172,6 +172,7 @@ Options:
    --inference_args       # Arguments for enhancement in the inference stage (default="${inference_args}")
    --inference_model      # Enhancement model path for inference (default="${inference_model}").
    --inference_enh_config # Configuration file for overwriting some model attributes during SE inference. (default="${inference_enh_config}")
+   --download_model       # Download a model from Model Zoo and use it for inference (default="${download_model}").
 
    # Evaluation related
    --scoring_protocol # Metrics to be used for scoring (default="${scoring_protocol}")
@@ -222,7 +223,7 @@ fi
 [ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };
 
 # Extra files for enhancement process
-utt_extra_files="utt2category"
+utt_extra_files="utt2category utt2fs"
 
 data_feats=${dumpdir}/raw
 
@@ -800,6 +801,26 @@ else
     fi
 fi
 
+
+if [ -n "${download_model}" ]; then
+    log "Use ${download_model} for inference and scoring"
+    enh_exp="${expdir}/${download_model}"
+    mkdir -p "${enh_exp}"
+
+    # If the model already exists, you can skip downloading
+    espnet_model_zoo_download --unpack true "${download_model}" > "${enh_exp}/config.txt"
+
+    # Get the path of each file
+    _enh_model_file=$(<"${enh_exp}/config.txt" sed -e "s/.*'enh_model_file': '\([^']*\)'.*$/\1/")
+    _enh_train_config=$(<"${enh_exp}/config.txt" sed -e "s/.*'enh_train_config': '\([^']*\)'.*$/\1/")
+
+    # Create symbolic links
+    ln -sf "${_enh_model_file}" "${enh_exp}"
+    ln -sf "${_enh_train_config}" "${enh_exp}"
+    inference_model=$(basename "${_enh_model_file}")
+fi
+
+
 if ! "${skip_eval}"; then
     if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
         log "Stage 7: Enhance Speech: training_dir=${enh_exp}"

diff --git a/egs2/TEMPLATE/enh_asr1/enh_asr.sh b/egs2/TEMPLATE/enh_asr1/enh_asr.sh
index 6edeb263cb2..06bad9a06d5 100755
--- a/egs2/TEMPLATE/enh_asr1/enh_asr.sh
+++ b/egs2/TEMPLATE/enh_asr1/enh_asr.sh
@@ -105,7 +105,7 @@ inference_tag=    # Suffix to the result dir for decoding.
 inference_config= # Config for decoding.
 asr_inference_args=   # Arguments for decoding, e.g., "--lm_weight 0.1".
                       # Note that it will overwrite args in inference config.
-enh_inference_args="--normalize_output_wav true"
+enh_inference_args="--normalize_output_wav true --output_format wav"
 inference_lm=valid.loss.ave.pth # Language model path for decoding.
 inference_ngram=${ngram_num}gram.bin
 inference_enh_asr_model=valid.acc.ave.pth # ASR model path for decoding.
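The new `download_model` stage above shells out to `espnet_model_zoo_download` and then scrapes the `enh_model_file` / `enh_train_config` paths out of the printed dict with `sed`. For reference, a rough Python equivalent, as a sketch only (it assumes the `espnet_model_zoo` package is installed; the model name is a placeholder):

```python
from espnet_model_zoo.downloader import ModelDownloader

d = ModelDownloader()
# download_and_unpack() returns a dict of local file paths; the shell stage
# extracts the same 'enh_model_file' and 'enh_train_config' entries from the
# dict it redirects into config.txt.
paths = d.download_and_unpack("someuser/some_enh_model")  # placeholder name
print(paths["enh_model_file"])
print(paths["enh_train_config"])
```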
diff --git a/espnet2/bin/enh_inference.py b/espnet2/bin/enh_inference.py
index 5b2a140c320..f9b6d54cdae 100755
--- a/espnet2/bin/enh_inference.py
+++ b/espnet2/bin/enh_inference.py
@@ -479,6 +479,7 @@ def inference(
     normalize_segment_scale: bool,
     show_progressbar: bool,
     ref_channel: Optional[int],
+    output_format: str,
     normalize_output_wav: bool,
     enh_s2t_task: bool,
 ):
@@ -542,7 +543,11 @@ def inference(
     writers = []
     for i in range(separate_speech.num_spk):
         writers.append(
-            SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
+            SoundScpWriter(
+                f"{output_dir}/wavs/{i + 1}",
+                f"{output_dir}/spk{i + 1}.scp",
+                format=output_format,
+            )
         )
 
     import tqdm
@@ -623,6 +628,12 @@ def get_parser():
     group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
 
     group = parser.add_argument_group("Output data related")
+    group.add_argument(
+        "--output_format",
+        type=str,
+        default="wav",
+        help="Output format for the separated speech",
+    )
     group.add_argument(
         "--normalize_output_wav",
         type=str2bool,

diff --git a/espnet2/bin/enh_inference_streaming.py b/espnet2/bin/enh_inference_streaming.py
index 85b6d75ba64..e902717c122 100755
--- a/espnet2/bin/enh_inference_streaming.py
+++ b/espnet2/bin/enh_inference_streaming.py
@@ -233,6 +233,7 @@ def inference(
     inference_config: Optional[str],
     allow_variable_data_keys: bool,
     ref_channel: Optional[int],
+    output_format: str,
     enh_s2t_task: bool,
 ):
     if batch_size > 1:
@@ -290,7 +291,11 @@ def inference(
     writers = []
     for i in range(separate_speech.num_spk):
         writers.append(
-            SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
+            SoundScpWriter(
+                f"{output_dir}/wavs/{i + 1}",
+                f"{output_dir}/spk{i + 1}.scp",
+                format=output_format,
+            )
         )
 
     import tqdm
@@ -389,6 +394,12 @@ def get_parser():
     group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
 
     group = parser.add_argument_group("Output data related")
+    group.add_argument(
+        "--output_format",
+        type=str,
+        default="wav",
+        help="Output format for the separated speech",
+    )
 
     group = parser.add_argument_group("The model configuration related")
     group.add_argument(

diff --git a/espnet2/bin/enh_tse_inference.py b/espnet2/bin/enh_tse_inference.py
index 8964cf91e20..153102f4597 100755
--- a/espnet2/bin/enh_tse_inference.py
+++ b/espnet2/bin/enh_tse_inference.py
@@ -437,6 +437,7 @@ def inference(
     normalize_segment_scale: bool,
     show_progressbar: bool,
     ref_channel: Optional[int],
+    output_format: str,
     normalize_output_wav: bool,
 ):
     if batch_size > 1:
@@ -513,7 +514,11 @@ def inference(
     writers = []
     for i in range(separate_speech.num_spk):
         writers.append(
-            SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
+            SoundScpWriter(
+                f"{output_dir}/wavs/{i + 1}",
+                f"{output_dir}/spk{i + 1}.scp",
+                format=output_format,
+            )
         )
 
     for i, (keys, batch) in enumerate(loader):
@@ -590,6 +595,12 @@ def get_parser():
         default=False,
         help="Whether to normalize the predicted wav to [-1~1]",
     )
+    group.add_argument(
+        "--output_format",
+        type=str,
+        default="wav",
+        help="Output format for the separated speech",
+    )
 
     group = parser.add_argument_group("The model configuration related")
     group.add_argument(
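All three inference entry points now thread the new `--output_format` option through to `SoundScpWriter`, so separated audio can be stored as, e.g., flac instead of wav. A minimal sketch of the writer usage (the paths and the "flac" value are just example choices; the scripts create one such writer per speaker exactly as in the hunks above):

```python
import numpy as np

from espnet2.fileio.sound_scp import SoundScpWriter

# Each utterance is saved as out/wavs/1/<key>.flac and indexed in out/spk1.scp.
writer = SoundScpWriter("out/wavs/1", "out/spk1.scp", format="flac")
writer["utt1"] = 16000, np.zeros(16000, dtype=np.float32)  # (fs, waveform)
writer.close()
```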
diff --git a/espnet2/enh/layers/bsrnn.py b/espnet2/enh/layers/bsrnn.py
index c5f120da233..bd4b736efa6 100644
--- a/espnet2/enh/layers/bsrnn.py
+++ b/espnet2/enh/layers/bsrnn.py
@@ -3,11 +3,22 @@
 import torch
 import torch.nn as nn
 
+from espnet2.enh.layers.tcn import choose_norm as choose_norm1d
+
+EPS = torch.finfo(torch.get_default_dtype()).eps
+
 
 class BSRNN(nn.Module):
     # ported from https://github.com/sungwon23/BSRNN
     def __init__(
-        self, input_dim=481, num_channel=16, num_layer=6, target_fs=48000, causal=True
+        self,
+        input_dim=481,
+        num_channel=16,
+        num_layer=6,
+        target_fs=48000,
+        causal=True,
+        num_spk=1,
+        norm_type="GN",
     ):
         """Band-Split RNN (BSRNN).
 
@@ -27,14 +38,18 @@ def __init__(
             target_fs (int): maximum sampling frequency supported by the model
             causal (bool): Whether or not to adopt causal processing
                 if True, LSTM will be used instead of BLSTM for time modeling
+            num_spk (int): number of outputs to be generated
+            norm_type (str): type of normalization layer (cfLN / cLN / BN / GN)
         """
         super().__init__()
+        norm1d_type = norm_type if norm_type != "cfLN" else "cLN"
         self.num_layer = num_layer
         self.band_split = BandSplit(
-            input_dim, target_fs=target_fs, channels=num_channel
+            input_dim, target_fs=target_fs, channels=num_channel, norm_type=norm1d_type
        )
         self.target_fs = target_fs
         self.causal = causal
+        self.num_spk = num_spk
 
         self.norm_time = nn.ModuleList()
         self.rnn_time = nn.ModuleList()
@@ -44,7 +59,7 @@ def __init__(
         self.fc_freq = nn.ModuleList()
         hdim = 2 * num_channel
         for i in range(self.num_layer):
-            self.norm_time.append(nn.GroupNorm(1, num_channel))
+            self.norm_time.append(choose_norm(norm_type, num_channel))
             self.rnn_time.append(
                 nn.LSTM(
                     num_channel,
@@ -54,14 +69,18 @@ def __init__(
                 )
             )
             self.fc_time.append(nn.Linear(hdim if causal else hdim * 2, num_channel))
-            self.norm_freq.append(nn.GroupNorm(1, num_channel))
+            self.norm_freq.append(choose_norm(norm_type, num_channel))
             self.rnn_freq.append(
                 nn.LSTM(num_channel, hdim, batch_first=True, bidirectional=True)
             )
             self.fc_freq.append(nn.Linear(4 * num_channel, num_channel))
 
         self.mask_decoder = MaskDecoder(
-            input_dim, self.band_split.subbands, channels=num_channel
+            input_dim,
+            self.band_split.subbands,
+            channels=num_channel,
+            num_spk=num_spk,
+            norm_type=norm1d_type,
         )
 
     def forward(self, x, fs=None):
@@ -75,7 +94,7 @@ def forward(self, x, fs=None):
                 if None, the input signal is assumed to be already truncated
                 to only contain effective frequency subbands.
         Returns:
-            out (torch.Tensor): output tensor of shape (B, T, F, 2)
+            out (torch.Tensor): output tensor of shape (B, num_spk, T, F, 2)
         """
         z = self.band_split(x, fs=fs)
         B, N, T, K = z.shape
@@ -101,11 +120,12 @@ def forward(self, x, fs=None):
         x = torch.view_as_complex(x)
         m = m[..., : x.size(-1)]
         r = r[..., : x.size(-1)]
-        return torch.view_as_real(m * x + r)
+        ret = torch.view_as_real(m * x.unsqueeze(1) + r)
+        return ret
 
 
 class BandSplit(nn.Module):
-    def __init__(self, input_dim, target_fs=48000, channels=128):
+    def __init__(self, input_dim, target_fs=48000, channels=128, norm_type="GN"):
         super().__init__()
         assert input_dim % 2 == 1, input_dim
         n_fft = (input_dim - 1) * 2
@@ -123,12 +143,13 @@ def __init__(self, input_dim, target_fs=48000, channels=128):
                 f"Please define your own subbands for input_dim={input_dim} and "
                 f"target_fs={target_fs}"
             )
+        assert sum(self.subbands) == input_dim, (self.subbands, input_dim)
         self.subband_freqs = freqs[[idx - 1 for idx in accumulate(self.subbands)]]
 
         self.norm = nn.ModuleList()
         self.fc = nn.ModuleList()
         for i in range(len(self.subbands)):
-            self.norm.append(nn.GroupNorm(1, int(self.subbands[i] * 2)))
+            self.norm.append(choose_norm1d(norm_type, int(self.subbands[i] * 2)))
             self.fc.append(nn.Conv1d(int(self.subbands[i] * 2), channels, 1))
 
     def forward(self, x, fs=None):
@@ -145,14 +166,8 @@ def forward(self, x, fs=None):
             z (torch.Tensor): output tensor of shape (B, N, T, K')
                 K' might be smaller than len(self.subbands) if fs < self.target_fs.
         """
-        if fs is not None:
-            assert x.size(2) == sum(self.subbands), (x.size(2), sum(self.subbands))
         hz_band = 0
         for i, subband in enumerate(self.subbands):
-            if fs is not None and self.subband_freqs[i] > fs / 2:
-                break
-            if fs is None and hz_band >= x.size(2):
-                break
             x_band = x[:, :, hz_band : hz_band + int(subband), :]
             if int(subband) > x_band.size(2):
                 x_band = nn.functional.pad(
@@ -168,33 +183,38 @@
             else:
                 z = torch.cat((z, out.unsqueeze(-1)), dim=-1)
             hz_band = hz_band + int(subband)
+            if hz_band >= x.size(2):
+                break
+            if fs is not None and self.subband_freqs[i] >= fs / 2:
+                break
         return z
 
 
 class MaskDecoder(nn.Module):
-    def __init__(self, freq_dim, subbands, channels=128):
+    def __init__(self, freq_dim, subbands, channels=128, num_spk=1, norm_type="GN"):
         super().__init__()
         assert freq_dim == sum(subbands), (freq_dim, subbands)
         self.subbands = subbands
         self.freq_dim = freq_dim
+        self.num_spk = num_spk
         self.mlp_mask = nn.ModuleList()
         self.mlp_residual = nn.ModuleList()
         for subband in self.subbands:
             self.mlp_mask.append(
                 nn.Sequential(
-                    nn.GroupNorm(1, channels),
+                    choose_norm1d(norm_type, channels),
                     nn.Conv1d(channels, 4 * channels, 1),
                     nn.Tanh(),
-                    nn.Conv1d(4 * channels, int(subband * 4), 1),
+                    nn.Conv1d(4 * channels, int(subband * 4 * num_spk), 1),
                     nn.GLU(dim=1),
                 )
             )
             self.mlp_residual.append(
                 nn.Sequential(
-                    nn.GroupNorm(1, channels),
+                    choose_norm1d(norm_type, channels),
                     nn.Conv1d(channels, 4 * channels, 1),
                     nn.Tanh(),
-                    nn.Conv1d(4 * channels, int(subband * 4), 1),
+                    nn.Conv1d(4 * channels, int(subband * 4 * num_spk), 1),
                     nn.GLU(dim=1),
                 )
             )
@@ -205,27 +225,126 @@ def forward(self, x):
 
         Args:
             x (torch.Tensor): input tensor of shape (B, N, T, K)
         Returns:
-            m (torch.Tensor): output mask of shape (B, T, F, 2)
-            r (torch.Tensor): output residual of shape (B, T, F, 2)
+            m (torch.Tensor): output mask of shape (B, num_spk, T, F, 2)
+            r (torch.Tensor): output residual of shape (B, num_spk, T, F, 2)
         """
         for i in range(len(self.subbands)):
             if i >= x.size(-1):
                 break
             x_band = x[:, :, :, i]
             out = self.mlp_mask[i](x_band).transpose(1, 2).contiguous()
-            out = out.reshape(out.size(0), out.size(1), -1, 2)
+            # (B, T, num_spk, subband, 2)
+            out = out.reshape(out.size(0), out.size(1), self.num_spk, -1, 2)
             if i == 0:
                 m = out
             else:
-                m = torch.cat((m, out), dim=2)
+                m = torch.cat((m, out), dim=3)
             res = self.mlp_residual[i](x_band).transpose(1, 2).contiguous()
-            res = res.reshape(res.size(0), res.size(1), -1, 2)
+            # (B, T, num_spk, subband, 2)
+            res = res.reshape(res.size(0), res.size(1), self.num_spk, -1, 2)
             if i == 0:
                 r = res
             else:
-                r = torch.cat((r, res), dim=2)
-        # Pad zeros in addition to efffective subbands to cover the full frequency range
+                r = torch.cat((r, res), dim=3)
+        # Pad zeros in addition to effective subbands to cover the full frequency range
         m = nn.functional.pad(m, (0, 0, 0, int(self.freq_dim - m.size(-2))))
         r = nn.functional.pad(r, (0, 0, 0, int(self.freq_dim - r.size(-2))))
-        return m, r
+        return m.moveaxis(1, 2), r.moveaxis(1, 2)
+
+
+def choose_norm(norm_type, channel_size, shape="BDTF"):
+    """The input of normalization will be (M, C, T, K), where M is batch size.
+
+    C is channel size, and T and K are lengths.
+    """
+    if norm_type == "cfLN":
+        return ChannelFreqwiseLayerNorm(channel_size, shape=shape)
+    elif norm_type == "cLN":
+        return ChannelwiseLayerNorm(channel_size, shape=shape)
+    elif norm_type == "BN":
+        # Given input (M, C, T, K), nn.BatchNorm2d(C) will accumulate statistics
+        # along M, T, and K, so this BN usage is right.
+        return nn.BatchNorm2d(channel_size)
+    elif norm_type == "GN":
+        return nn.GroupNorm(1, channel_size)
+    else:
+        raise ValueError("Unsupported normalization type")
+
+
+class ChannelwiseLayerNorm(nn.Module):
+    """Channel-wise Layer Normalization (cLN)."""
+
+    def __init__(self, channel_size, shape="BDTF"):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1, 1))  # [1, N, 1, 1]
+        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1, 1))  # [1, N, 1, 1]
+        self.reset_parameters()
+        assert shape in ["BDTF", "BTFD"], shape
+        self.shape = shape
+
+    def reset_parameters(self):
+        self.gamma.data.fill_(1)
+        self.beta.data.zero_()
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def forward(self, y):
+        """Forward.
+
+        Args:
+            y: [M, N, T, K], M is batch size, N is channel size, T and K are lengths
+
+        Returns:
+            cLN_y: [M, N, T, K]
+        """
+
+        assert y.dim() == 4
+
+        if self.shape == "BTFD":
+            y = y.moveaxis(-1, 1)
+
+        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, T, K]
+        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, T, K]
+        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
+
+        if self.shape == "BTFD":
+            cLN_y = cLN_y.moveaxis(1, -1)
+
+        return cLN_y
+
+
+class ChannelFreqwiseLayerNorm(nn.Module):
+    """Channel-and-Frequency-wise Layer Normalization (cfLN)."""
+
+    def __init__(self, channel_size, shape="BDTF"):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1, 1))  # [1, N, 1, 1]
+        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1, 1))  # [1, N, 1, 1]
+        self.reset_parameters()
+        assert shape in ["BDTF", "BTFD"], shape
+        self.shape = shape
+
+    def reset_parameters(self):
+        self.gamma.data.fill_(1)
+        self.beta.data.zero_()
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def forward(self, y):
+        """Forward.
+
+        Args:
+            y: [M, N, T, K], M is batch size, N is channel size, T and K are lengths
+
+        Returns:
+            gLN_y: [M, N, T, K]
+        """
+        if self.shape == "BTFD":
+            y = y.moveaxis(-1, 1)
+
+        mean = y.mean(dim=(1, 3), keepdim=True)  # [M, 1, T, 1]
+        var = (torch.pow(y - mean, 2)).mean(dim=(1, 3), keepdim=True)
+        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
+
+        if self.shape == "BTFD":
+            gLN_y = gLN_y.moveaxis(1, -1)
+        return gLN_y
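The new `choose_norm` helper and the two 4-D layer-norm variants above can be sanity-checked with a quick shape test. A sketch, assuming this patch is applied (layout "BDTF", i.e. (batch, channel, time, subband), per the docstrings):

```python
import torch

from espnet2.enh.layers.bsrnn import choose_norm

# cfLN normalizes over the channel and subband axes; the shape is preserved.
norm = choose_norm("cfLN", channel_size=16)
y = torch.randn(2, 16, 50, 30)  # (batch, channel, time, subband)
assert norm(y).shape == y.shape
```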
diff --git a/espnet2/enh/separator/bsrnn_separator.py b/espnet2/enh/separator/bsrnn_separator.py
index e1b441ff903..cdca9f346d3 100644
--- a/espnet2/enh/separator/bsrnn_separator.py
+++ b/espnet2/enh/separator/bsrnn_separator.py
@@ -18,6 +18,7 @@ def __init__(
         num_layers: int = 6,
         target_fs: int = 48000,
         causal: bool = True,
+        norm_type: str = "GN",
         ref_channel: Optional[int] = None,
     ):
         """Band-split RNN (BSRNN) separator.
@@ -39,6 +40,7 @@ def __init__(
             target_fs: (int) max sampling frequency that the model can handle.
             causal (bool): whether or not to apply causal modeling.
                 if True, LSTM will be used instead of BLSTM for time modeling
+            norm_type (str): type of the normalization layer (cfLN / cLN / BN / GN).
             ref_channel: (int) reference channel. not used for now.
         """
         super().__init__()
@@ -46,13 +48,14 @@ def __init__(
         self._num_spk = num_spk
         self.ref_channel = ref_channel
 
-        assert num_spk == 1, num_spk
         self.bsrnn = BSRNN(
             input_dim=input_dim,
             num_channel=num_channels,
             num_layer=num_layers,
             target_fs=target_fs,
             causal=causal,
+            num_spk=num_spk,
+            norm_type=norm_type,
         )
 
     def forward(
@@ -86,7 +89,10 @@ def forward(
         assert input.size(-1) == 2, input.shape
         feature = input
 
-        masked = self.bsrnn(feature).unsqueeze(1)
+        opt = {}
+        if additional is not None and "fs" in additional:
+            opt["fs"] = additional["fs"]
+        masked = self.bsrnn(feature, **opt)  # (B, num_spk, T, F, 2)
 
         if not is_complex(input):
             masked = list(ComplexTensor(masked[..., 0], masked[..., 1]).unbind(1))
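With the `num_spk == 1` assertion removed, the separator now returns one masked spectrum per speaker, and a sampling rate can be threaded through `additional`. A small usage sketch mirroring the new test below:

```python
import torch

from espnet2.enh.separator.bsrnn_separator import BSRNNSeparator

separator = BSRNNSeparator(
    input_dim=481,  # full-band frequency bins for target_fs=48000
    num_spk=2,
    num_channels=16,
    num_layers=3,
    target_fs=48000,
    causal=True,
)
ilens = torch.tensor([10, 8], dtype=torch.long)
x = torch.randn(2, 10, 161, 2)  # 161 = int(16000 * 0.01) + 1 effective bins
masked, olens, others = separator(x, ilens=ilens, additional={"fs": 16000})
assert len(masked) == 2  # one output per speaker
```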
diff --git a/test/espnet2/enh/separator/test_bsrnn_separator.py b/test/espnet2/enh/separator/test_bsrnn_separator.py
index 8ed11915192..782a9729485 100644
--- a/test/espnet2/enh/separator/test_bsrnn_separator.py
+++ b/test/espnet2/enh/separator/test_bsrnn_separator.py
@@ -1,17 +1,17 @@
 import pytest
 import torch
-from torch import Tensor
 from torch_complex import ComplexTensor
 
 from espnet2.enh.separator.bsrnn_separator import BSRNNSeparator
 
 
 @pytest.mark.parametrize("input_dim", [481])
-@pytest.mark.parametrize("num_spk", [1])
+@pytest.mark.parametrize("num_spk", [1, 2])
 @pytest.mark.parametrize("num_channels", [16])
 @pytest.mark.parametrize("num_layers", [3])
 @pytest.mark.parametrize("target_fs", [48000])
 @pytest.mark.parametrize("causal", [True, False])
+@pytest.mark.parametrize("norm_type", ["cfLN", "cLN", "BN", "GN"])
 def test_bsrnn_separator_forward_backward_complex(
     input_dim,
     num_spk,
@@ -19,6 +19,7 @@ def test_bsrnn_separator_forward_backward_complex(
     num_layers,
     target_fs,
     causal,
+    norm_type,
 ):
     model = BSRNNSeparator(
         input_dim=input_dim,
@@ -27,6 +28,7 @@ def test_bsrnn_separator_forward_backward_complex(
         num_layers=num_layers,
         target_fs=target_fs,
         causal=causal,
+        norm_type=norm_type,
     )
     model.train()
 
@@ -44,7 +46,7 @@
 
 
 @pytest.mark.parametrize("input_dim", [481])
-@pytest.mark.parametrize("num_spk", [1])
+@pytest.mark.parametrize("num_spk", [1, 2])
 @pytest.mark.parametrize("num_channels", [16])
 @pytest.mark.parametrize("num_layers", [3])
 @pytest.mark.parametrize("target_fs", [48000])
@@ -94,3 +96,26 @@ def test_bsrnn_separator_with_different_sf():
         f = int(sf * 0.01) + 1
         x = torch.randn(2, 10, f, 2)
         model(x, ilens=x_lens)
+
+
+@pytest.mark.parametrize("fs", [8000, 16000, 24000, 32000, 44100, 48000])
+def test_bsrnn_separator_with_fs_arg(fs):
+    x_lens = torch.tensor([10, 8], dtype=torch.long)
+
+    model = BSRNNSeparator(
+        input_dim=481,
+        num_spk=1,
+        num_channels=10,
+        num_layers=3,
+        target_fs=48000,
+        causal=True,
+    )
+    model.eval()
+
+    f = int(fs * 0.01) + 1
+    x = torch.randn(2, 10, f, 2)
+    y1 = model(x, ilens=x_lens)[0]
+    y2 = model(x, ilens=x_lens, additional={"fs": fs})[0]
+    for yy1, yy2 in zip(y1, y2):
+        torch.testing.assert_close(yy1.real, yy2.real)
+        torch.testing.assert_close(yy1.imag, yy2.imag)
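The widened parametrizations (num_spk, norm_type, and the new `fs` path) can be exercised locally; a one-liner sketch, assuming the repository root as the working directory:

```python
import pytest

# Equivalent to: pytest -q test/espnet2/enh/separator/test_bsrnn_separator.py
pytest.main(["-q", "test/espnet2/enh/separator/test_bsrnn_separator.py"])
```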