From 5ace0943350e5d84211bbc5f17a0d188065427e8 Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Fri, 28 Jun 2024 14:16:53 +0800 Subject: [PATCH 1/8] Update BSRNN to support speech separation (num_spk > 1); Minor update of data preparation in egs2/urgent24/enh1; Add an argument to specify the output audio format in espnet2/bin/enh_inference.py and espnet2/bin/enh_inference_streaming.py --- egs2/urgent24/enh1/local/data.sh | 6 ++++ espnet2/bin/enh_inference.py | 13 +++++++- espnet2/bin/enh_inference_streaming.py | 13 +++++++- espnet2/enh/layers/bsrnn.py | 42 +++++++++++++++--------- espnet2/enh/separator/bsrnn_separator.py | 4 +-- 5 files changed, 59 insertions(+), 19 deletions(-) diff --git a/egs2/urgent24/enh1/local/data.sh b/egs2/urgent24/enh1/local/data.sh index b61e9c551e9..79b19747680 100755 --- a/egs2/urgent24/enh1/local/data.sh +++ b/egs2/urgent24/enh1/local/data.sh @@ -101,6 +101,12 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ./prepare_espnet_data.sh cd "${cwd}" cp -r "${odir}"/urgent2024_challenge/data "${PWD}/data" + if [ ! -d "${PWD}/simulation_validation" ]; then + ln -s "${odir}/urgent2024_challenge/simulation_validation" "${PWD}/" + fi + if [ ! -d "${PWD}/simulation_train" ]; then + ln -s "${odir}/urgent2024_challenge/simulation_train" "${PWD}/" + fi fi log "Successfully finished. [elapsed=${SECONDS}s]" diff --git a/espnet2/bin/enh_inference.py b/espnet2/bin/enh_inference.py index 5b2a140c320..f9b6d54cdae 100755 --- a/espnet2/bin/enh_inference.py +++ b/espnet2/bin/enh_inference.py @@ -479,6 +479,7 @@ def inference( normalize_segment_scale: bool, show_progressbar: bool, ref_channel: Optional[int], + output_format: str, normalize_output_wav: bool, enh_s2t_task: bool, ): @@ -542,7 +543,11 @@ def inference( writers = [] for i in range(separate_speech.num_spk): writers.append( - SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp") + SoundScpWriter( + f"{output_dir}/wavs/{i + 1}", + f"{output_dir}/spk{i + 1}.scp", + format=output_format, + ) ) import tqdm @@ -623,6 +628,12 @@ def get_parser(): group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) group = parser.add_argument_group("Output data related") + group.add_argument( + "--output_format", + type=str, + default="wav", + help="Output format for the separated speech", + ) group.add_argument( "--normalize_output_wav", type=str2bool, diff --git a/espnet2/bin/enh_inference_streaming.py b/espnet2/bin/enh_inference_streaming.py index 85b6d75ba64..e902717c122 100755 --- a/espnet2/bin/enh_inference_streaming.py +++ b/espnet2/bin/enh_inference_streaming.py @@ -233,6 +233,7 @@ def inference( inference_config: Optional[str], allow_variable_data_keys: bool, ref_channel: Optional[int], + output_format: str, enh_s2t_task: bool, ): if batch_size > 1: @@ -290,7 +291,11 @@ def inference( writers = [] for i in range(separate_speech.num_spk): writers.append( - SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp") + SoundScpWriter( + f"{output_dir}/wavs/{i + 1}", + f"{output_dir}/spk{i + 1}.scp", + format=output_format, + ) ) import tqdm @@ -389,6 +394,12 @@ def get_parser(): group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) group = parser.add_argument_group("Output data related") + group.add_argument( + "--output_format", + type=str, + default="wav", + help="Output format for the separated speech", + ) group = parser.add_argument_group("The model configuration related") group.add_argument( diff --git a/espnet2/enh/layers/bsrnn.py
b/espnet2/enh/layers/bsrnn.py index c5f120da233..d4fec805920 100644 --- a/espnet2/enh/layers/bsrnn.py +++ b/espnet2/enh/layers/bsrnn.py @@ -7,7 +7,13 @@ class BSRNN(nn.Module): # ported from https://github.com/sungwon23/BSRNN def __init__( - self, input_dim=481, num_channel=16, num_layer=6, target_fs=48000, causal=True + self, + input_dim=481, + num_channel=16, + num_layer=6, + target_fs=48000, + causal=True, + num_spk=1, ): """Band-Split RNN (BSRNN). @@ -27,6 +33,7 @@ def __init__( target_fs (int): maximum sampling frequency supported by the model causal (bool): Whether or not to adopt causal processing if True, LSTM will be used instead of BLSTM for time modeling + num_spk (int): number of outputs to be generated """ super().__init__() self.num_layer = num_layer @@ -35,6 +42,7 @@ def __init__( ) self.target_fs = target_fs self.causal = causal + self.num_spk = num_spk self.norm_time = nn.ModuleList() self.rnn_time = nn.ModuleList() @@ -61,7 +69,7 @@ def __init__( self.fc_freq.append(nn.Linear(4 * num_channel, num_channel)) self.mask_decoder = MaskDecoder( - input_dim, self.band_split.subbands, channels=num_channel + input_dim, self.band_split.subbands, channels=num_channel, num_spk=num_spk ) def forward(self, x, fs=None): @@ -75,7 +83,7 @@ def forward(self, x, fs=None): if None, the input signal is assumed to be already truncated to only contain effective frequency subbands. Returns: - out (torch.Tensor): output tensor of shape (B, T, F, 2) + out (torch.Tensor): output tensor of shape (B, num_spk, T, F, 2) """ z = self.band_split(x, fs=fs) B, N, T, K = z.shape @@ -101,7 +109,8 @@ def forward(self, x, fs=None): x = torch.view_as_complex(x) m = m[..., : x.size(-1)] r = r[..., : x.size(-1)] - return torch.view_as_real(m * x + r) + ret = torch.view_as_real(m * x.unsqueeze(1) + r) + return ret class BandSplit(nn.Module): @@ -172,11 +181,12 @@ def forward(self, x, fs=None): class MaskDecoder(nn.Module): - def __init__(self, freq_dim, subbands, channels=128): + def __init__(self, freq_dim, subbands, channels=128, num_spk=1): super().__init__() assert freq_dim == sum(subbands), (freq_dim, subbands) self.subbands = subbands self.freq_dim = freq_dim + self.num_spk = num_spk self.mlp_mask = nn.ModuleList() self.mlp_residual = nn.ModuleList() for subband in self.subbands: @@ -185,7 +195,7 @@ def __init__(self, freq_dim, subbands, channels=128): nn.GroupNorm(1, channels), nn.Conv1d(channels, 4 * channels, 1), nn.Tanh(), - nn.Conv1d(4 * channels, int(subband * 4), 1), + nn.Conv1d(4 * channels, int(subband * 4 * num_spk), 1), nn.GLU(dim=1), ) ) @@ -194,7 +204,7 @@ def __init__(self, freq_dim, subbands, channels=128): nn.GroupNorm(1, channels), nn.Conv1d(channels, 4 * channels, 1), nn.Tanh(), - nn.Conv1d(4 * channels, int(subband * 4), 1), + nn.Conv1d(4 * channels, int(subband * 4 * num_spk), 1), nn.GLU(dim=1), ) ) @@ -205,27 +215,29 @@ def forward(self, x): Args: x (torch.Tensor): input tensor of shape (B, N, T, K) Returns: - m (torch.Tensor): output mask of shape (B, T, F, 2) - r (torch.Tensor): output residual of shape (B, T, F, 2) + m (torch.Tensor): output mask of shape (B, num_spk, T, F, 2) + r (torch.Tensor): output residual of shape (B, num_spk, T, F, 2) """ for i in range(len(self.subbands)): if i >= x.size(-1): break x_band = x[:, :, :, i] out = self.mlp_mask[i](x_band).transpose(1, 2).contiguous() - out = out.reshape(out.size(0), out.size(1), -1, 2) + # (B, T, num_spk, subband, 2) + out = out.reshape(out.size(0), out.size(1), self.num_spk, -1, 2) if i == 0: m = out else: - m = torch.cat((m, 
out), dim=2) + m = torch.cat((m, out), dim=3) res = self.mlp_residual[i](x_band).transpose(1, 2).contiguous() - res = res.reshape(res.size(0), res.size(1), -1, 2) + # (B, T, num_spk, subband, 2) + res = res.reshape(res.size(0), res.size(1), self.num_spk, -1, 2) if i == 0: r = res else: - r = torch.cat((r, res), dim=2) - # Pad zeros in addition to efffective subbands to cover the full frequency range + r = torch.cat((r, res), dim=3) + # Pad zeros in addition to effective subbands to cover the full frequency range m = nn.functional.pad(m, (0, 0, 0, int(self.freq_dim - m.size(-2)))) r = nn.functional.pad(r, (0, 0, 0, int(self.freq_dim - r.size(-2)))) - return m, r + return m.moveaxis(1, 2), r.moveaxis(1, 2) diff --git a/espnet2/enh/separator/bsrnn_separator.py b/espnet2/enh/separator/bsrnn_separator.py index e1b441ff903..f6c6abae788 100644 --- a/espnet2/enh/separator/bsrnn_separator.py +++ b/espnet2/enh/separator/bsrnn_separator.py @@ -46,13 +46,13 @@ def __init__( self._num_spk = num_spk self.ref_channel = ref_channel - assert num_spk == 1, num_spk self.bsrnn = BSRNN( input_dim=input_dim, num_channel=num_channels, num_layer=num_layers, target_fs=target_fs, causal=causal, + num_spk=num_spk, ) def forward( @@ -86,7 +86,7 @@ def forward( assert input.size(-1) == 2, input.shape feature = input - masked = self.bsrnn(feature).unsqueeze(1) + masked = self.bsrnn(feature) # B, num_spk, T, F if not is_complex(input): masked = list(ComplexTensor(masked[..., 0], masked[..., 1]).unbind(1)) From d3ed4aa7a69dd8a63b192b91441fd8cc3a62a48a Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Fri, 28 Jun 2024 14:24:08 +0800 Subject: [PATCH 2/8] Update the unit test for BSRNN --- egs2/urgent24/enh1/local/data.sh | 6 ------ test/espnet2/enh/separator/test_bsrnn_separator.py | 5 ++--- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/egs2/urgent24/enh1/local/data.sh b/egs2/urgent24/enh1/local/data.sh index 79b19747680..b61e9c551e9 100755 --- a/egs2/urgent24/enh1/local/data.sh +++ b/egs2/urgent24/enh1/local/data.sh @@ -101,12 +101,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then ./prepare_espnet_data.sh cd "${cwd}" cp -r "${odir}"/urgent2024_challenge/data "${PWD}/data" - if [ ! -d "${PWD}/simulation_validation" ]; then - ln -s "${odir}/urgent2024_challenge/simulation_validation" "${PWD}/" - fi - if [ ! -d "${PWD}/simulation_train" ]; then - ln -s "${odir}/urgent2024_challenge/simulation_train" "${PWD}/" - fi fi log "Successfully finished. 
[elapsed=${SECONDS}s]" diff --git a/test/espnet2/enh/separator/test_bsrnn_separator.py b/test/espnet2/enh/separator/test_bsrnn_separator.py index 8ed11915192..66975b57d1a 100644 --- a/test/espnet2/enh/separator/test_bsrnn_separator.py +++ b/test/espnet2/enh/separator/test_bsrnn_separator.py @@ -1,13 +1,12 @@ import pytest import torch -from torch import Tensor from torch_complex import ComplexTensor from espnet2.enh.separator.bsrnn_separator import BSRNNSeparator @pytest.mark.parametrize("input_dim", [481]) -@pytest.mark.parametrize("num_spk", [1]) +@pytest.mark.parametrize("num_spk", [1, 2]) @pytest.mark.parametrize("num_channels", [16]) @pytest.mark.parametrize("num_layers", [3]) @pytest.mark.parametrize("target_fs", [48000]) @@ -44,7 +43,7 @@ def test_bsrnn_separator_forward_backward_complex( @pytest.mark.parametrize("input_dim", [481]) -@pytest.mark.parametrize("num_spk", [1]) +@pytest.mark.parametrize("num_spk", [1, 2]) @pytest.mark.parametrize("num_channels", [16]) @pytest.mark.parametrize("num_layers", [3]) @pytest.mark.parametrize("target_fs", [48000]) From a3bcf2bd285cceca11a67e39ee3b5f845ce51be7 Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Fri, 28 Jun 2024 14:29:16 +0800 Subject: [PATCH 3/8] Also update enh.sh, enh_asr.sh, and espnet2/bin/enh_tse_inference.py to support --output_format --- egs2/TEMPLATE/enh1/enh.sh | 4 ++-- egs2/TEMPLATE/enh_asr1/enh_asr.sh | 2 +- espnet2/bin/enh_tse_inference.py | 13 ++++++++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/egs2/TEMPLATE/enh1/enh.sh b/egs2/TEMPLATE/enh1/enh.sh index 39820a3328a..ea0faf65fd4 100755 --- a/egs2/TEMPLATE/enh1/enh.sh +++ b/egs2/TEMPLATE/enh1/enh.sh @@ -78,7 +78,7 @@ extra_wav_list= # Extra list of scp files for wav formatting init_param= # Enhancement related -inference_args="--normalize_output_wav true" +inference_args="--normalize_output_wav true --output_format wav" inference_model=valid.loss.ave.pth download_model= @@ -222,7 +222,7 @@ fi [ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; }; # Extra files for enhancement process -utt_extra_files="utt2category" +utt_extra_files="utt2category utt2fs" data_feats=${dumpdir}/raw diff --git a/egs2/TEMPLATE/enh_asr1/enh_asr.sh b/egs2/TEMPLATE/enh_asr1/enh_asr.sh index 6edeb263cb2..06bad9a06d5 100755 --- a/egs2/TEMPLATE/enh_asr1/enh_asr.sh +++ b/egs2/TEMPLATE/enh_asr1/enh_asr.sh @@ -105,7 +105,7 @@ inference_tag= # Suffix to the result dir for decoding. inference_config= # Config for decoding. asr_inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1". # Note that it will overwrite args in inference config. -enh_inference_args="--normalize_output_wav true" +enh_inference_args="--normalize_output_wav true --output_format wav" inference_lm=valid.loss.ave.pth # Language model path for decoding. inference_ngram=${ngram_num}gram.bin inference_enh_asr_model=valid.acc.ave.pth # ASR model path for decoding. 
diff --git a/espnet2/bin/enh_tse_inference.py index 8964cf91e20..153102f4597 100755 --- a/espnet2/bin/enh_tse_inference.py +++ b/espnet2/bin/enh_tse_inference.py @@ -437,6 +437,7 @@ def inference( normalize_segment_scale: bool, show_progressbar: bool, ref_channel: Optional[int], + output_format: str, normalize_output_wav: bool, ): if batch_size > 1: @@ -513,7 +514,11 @@ def inference( writers = [] for i in range(separate_speech.num_spk): writers.append( - SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp") + SoundScpWriter( + f"{output_dir}/wavs/{i + 1}", + f"{output_dir}/spk{i + 1}.scp", + format=output_format, + ) ) for i, (keys, batch) in enumerate(loader): @@ -590,6 +595,12 @@ def get_parser(): default=False, help="Whether to normalize the predicted wav to [-1~1]", ) + group.add_argument( + "--output_format", + type=str, + default="wav", + help="Output format for the separated speech", + ) group = parser.add_argument_group("The model configuration related") group.add_argument( From 4084a02f475a2543d28fee85533b1a1ccfe3a236 Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Mon, 1 Jul 2024 14:05:47 +0800 Subject: [PATCH 4/8] Unify BSRNN interfaces with and without the fs argument --- espnet2/enh/layers/bsrnn.py | 12 +++++----- espnet2/enh/separator/bsrnn_separator.py | 5 +++- .../enh/separator/test_bsrnn_separator.py | 23 +++++++++++++++++++ 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/espnet2/enh/layers/bsrnn.py b/espnet2/enh/layers/bsrnn.py index d4fec805920..078fd723403 100644 --- a/espnet2/enh/layers/bsrnn.py +++ b/espnet2/enh/layers/bsrnn.py @@ -132,6 +132,7 @@ def __init__(self, input_dim, target_fs=48000, channels=128): f"Please define your own subbands for input_dim={input_dim} and " f"target_fs={target_fs}" ) + assert sum(self.subbands) == input_dim, (self.subbands, input_dim) self.subband_freqs = freqs[[idx - 1 for idx in accumulate(self.subbands)]] self.norm = nn.ModuleList() @@ -154,14 +155,8 @@ def forward(self, x, fs=None): z (torch.Tensor): output tensor of shape (B, N, T, K') K' might be smaller than len(self.subbands) if fs < self.target_fs.
""" - if fs is not None: - assert x.size(2) == sum(self.subbands), (x.size(2), sum(self.subbands)) hz_band = 0 for i, subband in enumerate(self.subbands): - if fs is not None and self.subband_freqs[i] > fs / 2: - break - if fs is None and hz_band >= x.size(2): - break x_band = x[:, :, hz_band : hz_band + int(subband), :] if int(subband) > x_band.size(2): x_band = nn.functional.pad( @@ -177,6 +172,11 @@ def forward(self, x, fs=None): else: z = torch.cat((z, out.unsqueeze(-1)), dim=-1) hz_band = hz_band + int(subband) + print(i, subband, hz_band, self.subband_freqs[i]) + if hz_band >= x.size(2): + break + if fs is not None and self.subband_freqs[i] >= fs / 2: + break return z diff --git a/espnet2/enh/separator/bsrnn_separator.py b/espnet2/enh/separator/bsrnn_separator.py index f6c6abae788..3edbc78238f 100644 --- a/espnet2/enh/separator/bsrnn_separator.py +++ b/espnet2/enh/separator/bsrnn_separator.py @@ -86,7 +86,10 @@ def forward( assert input.size(-1) == 2, input.shape feature = input - masked = self.bsrnn(feature) + opt = {} + if additional is not None and "fs" in additional: + opt["fs"] = additional["fs"] + masked = self.bsrnn(feature, **opt) # B, num_spk, T, F if not is_complex(input): masked = list(ComplexTensor(masked[..., 0], masked[..., 1]).unbind(1)) diff --git a/test/espnet2/enh/separator/test_bsrnn_separator.py b/test/espnet2/enh/separator/test_bsrnn_separator.py index 66975b57d1a..e1ffe7856e5 100644 --- a/test/espnet2/enh/separator/test_bsrnn_separator.py +++ b/test/espnet2/enh/separator/test_bsrnn_separator.py @@ -93,3 +93,26 @@ def test_bsrnn_separator_with_different_sf(): f = int(sf * 0.01) + 1 x = torch.randn(2, 10, f, 2) model(x, ilens=x_lens) + + +@pytest.mark.parametrize("fs", [8000, 16000, 24000, 32000, 44100, 48000]) +def test_bsrnn_separator_with_fs_arg(fs): + x_lens = torch.tensor([10, 8], dtype=torch.long) + + model = BSRNNSeparator( + input_dim=481, + num_spk=1, + num_channels=10, + num_layers=3, + target_fs=48000, + causal=True, + ) + model.eval() + + f = int(fs * 0.01) + 1 + x = torch.randn(2, 10, f, 2) + y1 = model(x, ilens=x_lens)[0] + y2 = model(x, ilens=x_lens, additional={"fs": fs})[0] + for yy1, yy2 in zip(y1, y2): + torch.testing.assert_close(yy1.real, yy2.real) + torch.testing.assert_close(yy1.imag, yy2.imag) From b879ea2b52b312a32b0426cc4103fe4bb45e1683 Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Wed, 10 Jul 2024 20:10:03 +0800 Subject: [PATCH 5/8] Fix a bug related to --download_model in enh.sh --- egs2/TEMPLATE/enh1/enh.sh | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/egs2/TEMPLATE/enh1/enh.sh b/egs2/TEMPLATE/enh1/enh.sh index ea0faf65fd4..cc942a8b76a 100755 --- a/egs2/TEMPLATE/enh1/enh.sh +++ b/egs2/TEMPLATE/enh1/enh.sh @@ -89,7 +89,7 @@ ref_channel=0 inference_tag= # Prefix to the result dir for ENH inference. inference_enh_config= # Config for enhancement. score_with_asr=false -asr_exp="" # asr model for scoring WER +enh_exp="" # asr model for scoring WER lm_exp="" # lm model for scoring WER inference_asr_model=valid.acc.best.pth # ASR model path for decoding. inference_lm=valid.loss.best.pth # Language model path for decoding. @@ -172,6 +172,7 @@ Options: --inference_args # Arguments for enhancement in the inference stage (default="${inference_args}") --inference_model # Enhancement model path for inference (default="${inference_model}"). --inference_enh_config # Configuration file for overwriting some model attributes during SE inference. 
(default="${inference_enh_config}") + --download_model # Download a model from Model Zoo and use it for inference (default="${download_model}"). # Evaluation related --scoring_protocol # Metrics to be used for scoring (default="${scoring_protocol}") @@ -182,7 +183,7 @@ Options: # ASR evaluation related --score_with_asr # Enable ASR evaluation (default="${score_with_asr}") - --asr_exp # asr model for scoring WER (default="${asr_exp}") + --enh_exp # asr model for scoring WER (default="${enh_exp}") --lm_exp # lm model for scoring WER (default="${lm_exp}") --nlsyms_txt # Non-linguistic symbol list if existing. (default="${nlsyms_txt}") --inference_asr_model # ASR model path for decoding. (default="${inference_asr_model}") @@ -800,6 +801,26 @@ else fi + +if [ -n "${download_model}" ]; then + log "Use ${download_model} for inference and scoring" + enh_exp="${expdir}/${download_model}" + mkdir -p "${enh_exp}" + + # If the model already exists, you can skip downloading + espnet_model_zoo_download --unpack true "${download_model}" > "${enh_exp}/config.txt" + + # Get the path of each file + _enh_model_file=$(<"${enh_exp}/config.txt" sed -e "s/.*'enh_model_file': '\([^']*\)'.*$/\1/") + _enh_train_config=$(<"${enh_exp}/config.txt" sed -e "s/.*'enh_train_config': '\([^']*\)'.*$/\1/") + + # Create symbolic links + ln -sf "${_enh_model_file}" "${enh_exp}" + ln -sf "${_enh_train_config}" "${enh_exp}" + inference_model=$(basename "${_enh_model_file}") +fi + + if ! "${skip_eval}"; then if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then log "Stage 7: Enhance Speech: training_dir=${enh_exp}" @@ -1101,8 +1122,8 @@ if "${score_with_asr}"; then --ngpu "${_ngpu}" \ --data_path_and_name_and_type "${_ddir}/wav.scp,speech,${_type}" \ --key_file "${_logdir}"/keys.JOB.scp \ - --asr_train_config "${asr_exp}"/config.yaml \ - --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ + --asr_train_config "${enh_exp}"/config.yaml \ + --asr_model_file "${enh_exp}"/"${inference_asr_model}" \ --output_dir "${_logdir}"/output.JOB \ ${_opts} ${inference_asr_args} From fdfe377e7a9e782e671bfbf185e4e5545d3f4923 Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Wed, 10 Jul 2024 20:12:26 +0800 Subject: [PATCH 6/8] Fix a bug related to --download_model in enh.sh --- egs2/TEMPLATE/enh1/enh.sh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/egs2/TEMPLATE/enh1/enh.sh b/egs2/TEMPLATE/enh1/enh.sh index cc942a8b76a..39cca774c9b 100755 --- a/egs2/TEMPLATE/enh1/enh.sh +++ b/egs2/TEMPLATE/enh1/enh.sh @@ -89,7 +89,7 @@ ref_channel=0 inference_tag= # Prefix to the result dir for ENH inference. inference_enh_config= # Config for enhancement. score_with_asr=false -enh_exp="" # asr model for scoring WER +asr_exp="" # asr model for scoring WER lm_exp="" # lm model for scoring WER inference_asr_model=valid.acc.best.pth # ASR model path for decoding. inference_lm=valid.loss.best.pth # Language model path for decoding. @@ -183,7 +183,7 @@ Options: # ASR evaluation related --score_with_asr # Enable ASR evaluation (default="${score_with_asr}") - --enh_exp # asr model for scoring WER (default="${enh_exp}") + --asr_exp # asr model for scoring WER (default="${asr_exp}") --lm_exp # lm model for scoring WER (default="${lm_exp}") --nlsyms_txt # Non-linguistic symbol list if existing. (default="${nlsyms_txt}") --inference_asr_model # ASR model path for decoding. 
(default="${inference_asr_model}") @@ -1122,8 +1122,8 @@ if "${score_with_asr}"; then --ngpu "${_ngpu}" \ --data_path_and_name_and_type "${_ddir}/wav.scp,speech,${_type}" \ --key_file "${_logdir}"/keys.JOB.scp \ - --asr_train_config "${enh_exp}"/config.yaml \ - --asr_model_file "${enh_exp}"/"${inference_asr_model}" \ + --asr_train_config "${asr_exp}"/config.yaml \ + --asr_model_file "${asr_exp}"/"${inference_asr_model}" \ --output_dir "${_logdir}"/output.JOB \ ${_opts} ${inference_asr_args} From 22148c51f9fdd5157f2fed37697fab0ca5ce420c Mon Sep 17 00:00:00 2001 From: Wangyou Zhang Date: Thu, 11 Jul 2024 19:55:48 +0800 Subject: [PATCH 7/8] Update BSRNN to support different norm types --- espnet2/enh/layers/bsrnn.py | 128 ++++++++++++++++-- espnet2/enh/separator/bsrnn_separator.py | 3 + .../enh/separator/test_bsrnn_separator.py | 3 + 3 files changed, 124 insertions(+), 10 deletions(-) diff --git a/espnet2/enh/layers/bsrnn.py b/espnet2/enh/layers/bsrnn.py index 078fd723403..241281e4f65 100644 --- a/espnet2/enh/layers/bsrnn.py +++ b/espnet2/enh/layers/bsrnn.py @@ -3,6 +3,11 @@ import torch import torch.nn as nn +from espnet2.enh.layers.tcn import choose_norm as choose_norm1d + + +EPS = torch.finfo(torch.get_default_dtype()).eps + class BSRNN(nn.Module): # ported from https://github.com/sungwon23/BSRNN @@ -14,6 +19,7 @@ def __init__( target_fs=48000, causal=True, num_spk=1, + norm_type="GN", ): """Band-Split RNN (BSRNN). @@ -34,11 +40,13 @@ def __init__( causal (bool): Whether or not to adopt causal processing if True, LSTM will be used instead of BLSTM for time modeling num_spk (int): number of outputs to be generated + norm_type (str): type of normalization layer (cfLN / cLN / BN / GN) """ super().__init__() + norm1d_type = norm_type if norm_type != "cfLN" else "cLN" self.num_layer = num_layer self.band_split = BandSplit( - input_dim, target_fs=target_fs, channels=num_channel + input_dim, target_fs=target_fs, channels=num_channel, norm_type=norm1d_type ) self.target_fs = target_fs self.causal = causal @@ -52,7 +60,7 @@ def __init__( self.fc_freq = nn.ModuleList() hdim = 2 * num_channel for i in range(self.num_layer): - self.norm_time.append(nn.GroupNorm(1, num_channel)) + self.norm_time.append(choose_norm(norm_type, num_channel)) self.rnn_time.append( nn.LSTM( num_channel, @@ -62,14 +70,18 @@ def __init__( ) ) self.fc_time.append(nn.Linear(hdim if causal else hdim * 2, num_channel)) - self.norm_freq.append(nn.GroupNorm(1, num_channel)) + self.norm_freq.append(choose_norm(norm_type, num_channel)) self.rnn_freq.append( nn.LSTM(num_channel, hdim, batch_first=True, bidirectional=True) ) self.fc_freq.append(nn.Linear(4 * num_channel, num_channel)) self.mask_decoder = MaskDecoder( - input_dim, self.band_split.subbands, channels=num_channel, num_spk=num_spk + input_dim, + self.band_split.subbands, + channels=num_channel, + num_spk=num_spk, + norm_type=norm1d_type, ) def forward(self, x, fs=None): @@ -114,7 +126,7 @@ def forward(self, x, fs=None): class BandSplit(nn.Module): - def __init__(self, input_dim, target_fs=48000, channels=128): + def __init__(self, input_dim, target_fs=48000, channels=128, norm_type="GN"): super().__init__() assert input_dim % 2 == 1, input_dim n_fft = (input_dim - 1) * 2 @@ -138,7 +150,7 @@ def __init__(self, input_dim, target_fs=48000, channels=128): self.norm = nn.ModuleList() self.fc = nn.ModuleList() for i in range(len(self.subbands)): - self.norm.append(nn.GroupNorm(1, int(self.subbands[i] * 2))) + self.norm.append(choose_norm1d(norm_type, int(self.subbands[i] * 2))) 
self.fc.append(nn.Conv1d(int(self.subbands[i] * 2), channels, 1)) def forward(self, x, fs=None): @@ -172,7 +184,6 @@ def forward(self, x, fs=None): else: z = torch.cat((z, out.unsqueeze(-1)), dim=-1) hz_band = hz_band + int(subband) - print(i, subband, hz_band, self.subband_freqs[i]) if hz_band >= x.size(2): break if fs is not None and self.subband_freqs[i] >= fs / 2: @@ -181,7 +192,7 @@ def forward(self, x, fs=None): class MaskDecoder(nn.Module): - def __init__(self, freq_dim, subbands, channels=128, num_spk=1): + def __init__(self, freq_dim, subbands, channels=128, num_spk=1, norm_type="GN"): super().__init__() assert freq_dim == sum(subbands), (freq_dim, subbands) self.subbands = subbands self.freq_dim = freq_dim @@ -192,7 +203,7 @@ def __init__(self, freq_dim, subbands, channels=128, num_spk=1): for subband in self.subbands: self.mlp_mask.append( nn.Sequential( - nn.GroupNorm(1, channels), + choose_norm1d(norm_type, channels), nn.Conv1d(channels, 4 * channels, 1), nn.Tanh(), nn.Conv1d(4 * channels, int(subband * 4 * num_spk), 1), @@ -201,7 +212,7 @@ def __init__(self, freq_dim, subbands, channels=128, num_spk=1): ) self.mlp_residual.append( nn.Sequential( - nn.GroupNorm(1, channels), + choose_norm1d(norm_type, channels), nn.Conv1d(channels, 4 * channels, 1), nn.Tanh(), nn.Conv1d(4 * channels, int(subband * 4 * num_spk), 1), @@ -241,3 +252,100 @@ def forward(self, x): m = nn.functional.pad(m, (0, 0, 0, int(self.freq_dim - m.size(-2)))) r = nn.functional.pad(r, (0, 0, 0, int(self.freq_dim - r.size(-2)))) return m.moveaxis(1, 2), r.moveaxis(1, 2) + + +def choose_norm(norm_type, channel_size, shape="BDTF"): + """The input of normalization will be (M, C, T, K), where M is batch size. + + C is channel size, T is the number of frames, and K is the number of subbands. + """ + if norm_type == "cfLN": + return ChannelFreqwiseLayerNorm(channel_size, shape=shape) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size, shape=shape) + elif norm_type == "BN": + # Given input (M, C, T, K), nn.BatchNorm2d(C) will accumulate statistics + # along M, T, and K, so this BN usage is right. + return nn.BatchNorm2d(channel_size) + elif norm_type == "GN": + return nn.GroupNorm(1, channel_size) + else: + raise ValueError("Unsupported normalization type") + + +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN).""" + + def __init__(self, channel_size, shape="BDTF"): + super().__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1, 1)) # [1, N, 1, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1, 1)) # [1, N, 1, 1] + self.reset_parameters() + assert shape in ["BDTF", "BTFD"], shape + self.shape = shape + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, y): + """Forward.
+ + Args: + y: [M, N, T, K], M is batch size, N is channel size, T and K are lengths + + Returns: + cLN_y: [M, N, T, K] + """ + + assert y.dim() == 4 + + if self.shape == "BTFD": + y = y.moveaxis(-1, 1) + + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, T, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, T, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + + if self.shape == "BTFD": + cLN_y = cLN_y.moveaxis(1, -1) + + return cLN_y + + +class ChannelFreqwiseLayerNorm(nn.Module): + """Channel-and-Frequency-wise Layer Normalization (cfLN).""" + + def __init__(self, channel_size, shape="BDTF"): + super().__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1, 1)) # [1, N, 1, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1, 1)) # [1, N, 1, 1] + self.reset_parameters() + assert shape in ["BDTF", "BTFD"], shape + self.shape = shape + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, y): + """Forward. + + Args: + y: [M, N, T, K], M is batch size, N is channel size, T and K are lengths + + Returns: + cfLN_y: [M, N, T, K] + """ + if self.shape == "BTFD": + y = y.moveaxis(-1, 1) + + mean = y.mean(dim=(1, 3), keepdim=True) # [M, 1, T, 1] + var = (torch.pow(y - mean, 2)).mean(dim=(1, 3), keepdim=True) + cfLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + + if self.shape == "BTFD": + cfLN_y = cfLN_y.moveaxis(1, -1) + return cfLN_y diff --git a/espnet2/enh/separator/bsrnn_separator.py index 3edbc78238f..cdca9f346d3 100644 --- a/espnet2/enh/separator/bsrnn_separator.py +++ b/espnet2/enh/separator/bsrnn_separator.py @@ -18,6 +18,7 @@ def __init__( num_layers: int = 6, target_fs: int = 48000, causal: bool = True, + norm_type: str = "GN", ref_channel: Optional[int] = None, ): """Band-split RNN (BSRNN) separator. @@ -39,6 +40,7 @@ def __init__( target_fs: (int) max sampling frequency that the model can handle. causal (bool): whether or not to apply causal modeling. if True, LSTM will be used instead of BLSTM for time modeling + norm_type (str): type of the normalization layer (cfLN / cLN / BN / GN). ref_channel: (int) reference channel. not used for now.
""" super().__init__() @@ -53,6 +55,7 @@ def __init__( target_fs=target_fs, causal=causal, num_spk=num_spk, + norm_type=norm_type, ) def forward( diff --git a/test/espnet2/enh/separator/test_bsrnn_separator.py b/test/espnet2/enh/separator/test_bsrnn_separator.py index e1ffe7856e5..782a9729485 100644 --- a/test/espnet2/enh/separator/test_bsrnn_separator.py +++ b/test/espnet2/enh/separator/test_bsrnn_separator.py @@ -11,6 +11,7 @@ @pytest.mark.parametrize("num_layers", [3]) @pytest.mark.parametrize("target_fs", [48000]) @pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("norm_type", ["cfLN", "cLN", "BN", "GN"]) def test_bsrnn_separator_forward_backward_complex( input_dim, num_spk, @@ -18,6 +19,7 @@ def test_bsrnn_separator_forward_backward_complex( num_layers, target_fs, causal, + norm_type, ): model = BSRNNSeparator( input_dim=input_dim, @@ -26,6 +28,7 @@ def test_bsrnn_separator_forward_backward_complex( num_layers=num_layers, target_fs=target_fs, causal=causal, + norm_type=norm_type, ) model.train() From b1bb697c540109b4d0035a3ca1201b730514cd65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Jul 2024 11:58:07 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- espnet2/enh/layers/bsrnn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/espnet2/enh/layers/bsrnn.py b/espnet2/enh/layers/bsrnn.py index 241281e4f65..bd4b736efa6 100644 --- a/espnet2/enh/layers/bsrnn.py +++ b/espnet2/enh/layers/bsrnn.py @@ -5,7 +5,6 @@ from espnet2.enh.layers.tcn import choose_norm as choose_norm1d - EPS = torch.finfo(torch.get_default_dtype()).eps