diff --git a/egs2/librispeech/asr1/README.md b/egs2/librispeech/asr1/README.md index 0ce590752d8..32fd2e86022 100644 --- a/egs2/librispeech/asr1/README.md +++ b/egs2/librispeech/asr1/README.md @@ -206,6 +206,61 @@ |decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|65818|97.7|1.6|0.7|0.4|2.7|25.7| |decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|65101|94.5|3.9|1.5|1.0|6.4|45.1| +# Multiconvformer +- Params: 147.41 M +- ASR config: [conf/tuning/train_asr_multiconvformer_conv_fusion.yaml](conf/tuning/train_asr_multiconvformer_conv_fusion.yaml) +- LM config: [conf/tuning/train_lm_transformer2.yaml](conf/tuning/train_lm_transformer2.yaml) +- Model link: [https://huggingface.co/Darshan7575/librispeech_960_multiconvformer_ctcatt_conv_fusion](https://huggingface.co/Darshan7575/librispeech_960_multiconvformer_ctcatt_conv_fusion) + +# RESULTS +## Environments +- date: `Fri Mar 1 15:40:42 UTC 2024` +- python version: `3.9.16 (main, May 15 2023, 23:46:34) [GCC 11.2.0]` +- espnet version: `espnet 202402` +- pytorch version: `pytorch 2.1.2+cu118` +- Git hash: `a50d6a0c8c31b4ef775473a657de031a40be30c1` + - Commit date: `Mon Feb 19 07:37:52 2024 -0500` + +## exp/asr_train_asr_multiconvformer_conv_fusion_raw_en_bpe5000_sp +### WER + +|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| +|---|---|---|---|---|---|---|---|---| +|decode_asr_asr_model_valid.acc.ave/dev_clean|2703|54402|98.2|1.6|0.2|0.2|2.0|25.8| +|decode_asr_asr_model_valid.acc.ave/dev_other|2864|50948|95.7|3.9|0.3|0.5|4.7|41.4| +|decode_asr_asr_model_valid.acc.ave/test_clean|2620|52576|98.1|1.7|0.2|0.3|2.2|26.9| +|decode_asr_asr_model_valid.acc.ave/test_other|2939|52343|95.9|3.8|0.3|0.6|4.7|42.6| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_clean|2703|54402|98.4|1.4|0.2|0.2|1.7|23.2| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_other|2864|50948|96.7|2.9|0.3|0.3|3.6|34.3| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|52576|98.3|1.5|0.2|0.2|1.9|23.4| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|52343|96.5|3.1|0.5|0.4|3.9|38.0| + +### CER + +|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| +|---|---|---|---|---|---|---|---|---| +|decode_asr_asr_model_valid.acc.ave/dev_clean|2703|288456|99.5|0.3|0.2|0.2|0.7|25.8| +|decode_asr_asr_model_valid.acc.ave/dev_other|2864|265951|98.5|0.9|0.6|0.5|2.0|41.4| +|decode_asr_asr_model_valid.acc.ave/test_clean|2620|281530|99.5|0.2|0.2|0.2|0.7|26.9| +|decode_asr_asr_model_valid.acc.ave/test_other|2939|272758|98.7|0.8|0.5|0.6|1.9|42.6| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_clean|2703|288456|99.5|0.2|0.2|0.2|0.6|23.2| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_other|2864|265951|98.7|0.7|0.6|0.4|1.7|34.3| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|281530|99.5|0.2|0.3|0.2|0.7|23.4| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|272758|98.7|0.7|0.6|0.4|1.7|38.0| + +### TER + +|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| +|---|---|---|---|---|---|---|---|---| +|decode_asr_asr_model_valid.acc.ave/dev_clean|2703|68010|97.7|1.7|0.6|0.3|2.6|25.8| 
+|decode_asr_asr_model_valid.acc.ave/dev_other|2864|63110|94.7|4.0|1.3|0.8|6.1|41.4| +|decode_asr_asr_model_valid.acc.ave/test_clean|2620|65818|97.6|1.7|0.7|0.3|2.7|26.9| +|decode_asr_asr_model_valid.acc.ave/test_other|2939|65101|95.0|3.6|1.4|0.7|5.7|42.6| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_clean|2703|68010|97.9|1.4|0.7|0.3|2.4|23.2| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_other|2864|63110|95.5|3.1|1.4|0.6|5.1|34.3| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|65818|97.8|1.4|0.8|0.3|2.4|23.4| +|decode_asr_lm_lm_train_lm_transformer2_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|65101|95.5|2.9|1.6|0.5|5.0|38.0| + # E-Branchformer - Params: 148.92 M - ASR config: [conf/tuning/train_asr_e_branchformer.yaml](conf/tuning/train_asr_e_branchformer.yaml) diff --git a/egs2/librispeech/asr1/conf/tuning/train_asr_multiconvformer_conv_fusion.yaml b/egs2/librispeech/asr1/conf/tuning/train_asr_multiconvformer_conv_fusion.yaml new file mode 100644 index 00000000000..d63cd7a7056 --- /dev/null +++ b/egs2/librispeech/asr1/conf/tuning/train_asr_multiconvformer_conv_fusion.yaml @@ -0,0 +1,82 @@ +# Trained with A100 (80 GB) x 2 GPUs. It takes 110 minutes per epoch. +encoder: multiconv_conformer +encoder_conf: + output_size: 512 + attention_heads: 8 + selfattention_layer_type: rel_selfattn + pos_enc_layer_type: rel_pos + rel_pos_type: latest + cgmlp_linear_units: 3072 + multicgmlp_type: concat_fusion + multicgmlp_kernel_sizes: 7,15,23,31 + use_linear_after_conv: false + gate_activation: identity + num_blocks: 18 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d + layer_drop_rate: 0.1 + linear_units: 1024 + positionwise_layer_type: linear + macaron_style: true + use_cnn_module: true + +decoder: transformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 + length_normalized_loss: false + +frontend_conf: + n_fft: 512 + hop_length: 160 + +use_amp: true +unused_parameters: true +num_workers: 4 +batch_type: numel +batch_bins: 70000000 +accum_grad: 2 +max_epoch: 80 +patience: none +init: none +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 +nbest_averaging_interval: 10 + +optim: adam +optim_conf: + lr: 0.002 + weight_decay: 0.000001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 40000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 27 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_ratio_range: + - 0. 
+ - 0.05 + num_time_mask: 10 diff --git a/egs2/librispeech_100/asr1/README.md b/egs2/librispeech_100/asr1/README.md index a9f559272e3..19005f398d1 100644 --- a/egs2/librispeech_100/asr1/README.md +++ b/egs2/librispeech_100/asr1/README.md @@ -1,3 +1,45 @@ +# Multiconvformer +- Params: 37.21 M +- ASR config: [conf/tuning/train_asr_multiconvformer_conv_fusion_linear1024.yaml](conf/tuning/train_asr_multiconvformer_conv_fusion_linear1024.yaml) +- Model link: [https://huggingface.co/Darshan7575/librispeech_100_multiconvformer_ctcatt_conv_fusion](https://huggingface.co/Darshan7575/librispeech_100_multiconvformer_ctcatt_conv_fusion) + +# RESULTS +## Environments +- date: `Sun Jan 28 23:50:53 UTC 2024` +- python version: `3.9.16 (main, Mar 8 2023, 14:00:05) [GCC 11.2.0]` +- espnet version: `espnet 202304` +- pytorch version: `pytorch 2.1.2+cu118` +- Git hash: `3651c2e67126c4544820cf148407be7f2679866c` + - Commit date: `Sat Jul 1 14:46:46 2023 +0000` + +## exp/librispeech_100_multiconvformer_conv_fusion +### WER + +|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| +|---|---|---|---|---|---|---|---|---| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_clean|2703|54402|94.8|4.8|0.3|0.7|5.9|53.8| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_other|2864|50948|85.4|13.2|1.4|2.0|16.6|78.8| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|52576|94.5|5.0|0.4|0.7|6.2|55.5| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|52343|85.0|13.6|1.5|2.0|17.0|80.5| + +### CER + +|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| +|---|---|---|---|---|---|---|---|---| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_clean|2703|288456|98.3|1.0|0.7|0.6|2.3|53.8| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_other|2864|265951|93.6|4.0|2.4|2.0|8.4|78.8| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|281530|98.3|1.0|0.7|0.6|2.4|55.5| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|272758|93.6|3.8|2.6|1.9|8.2|80.5| + +### TER + +|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err| +|---|---|---|---|---|---|---|---|---| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_clean|2703|69558|92.5|4.7|2.8|0.6|8.1|53.8| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/dev_other|2864|64524|82.0|12.9|5.0|2.4|20.4|78.8| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_clean|2620|66983|92.4|4.8|2.8|0.6|8.2|55.5| +|decode_asr_lm_lm_train_en_bpe5000_valid.loss.ave_asr_model_valid.acc.ave/test_other|2939|66650|81.6|12.9|5.5|2.2|20.6|80.5| + # E-Branchformer ## Environments - date: `Mon Dec 12 06:50:58 CST 2022` @@ -40,7 +82,6 @@ |decode_asr_asr_model_valid.acc.ave/test_clean|2620|66983|92.2|4.9|2.9|0.6|8.4|56.1| |decode_asr_asr_model_valid.acc.ave/test_other|2939|66650|81.5|13.0|5.5|2.2|20.7|80.3| - # E-Branchformer with CTC ## Environments - date: `Sun Jan 1 15:05:07 CST 2023` diff --git a/egs2/librispeech_100/asr1/conf/tuning/train_asr_multiconvformer_conv_fusion_linear1024.yaml b/egs2/librispeech_100/asr1/conf/tuning/train_asr_multiconvformer_conv_fusion_linear1024.yaml new file mode 100644 index 00000000000..6228c7cd66f --- /dev/null +++ b/egs2/librispeech_100/asr1/conf/tuning/train_asr_multiconvformer_conv_fusion_linear1024.yaml @@ -0,0 +1,84 @@ +# Trained with A100 (80 GB) x 1 
GPUs. It takes 15 minutes per epoch. +encoder: multiconv_conformer +encoder_conf: + output_size: 256 + attention_heads: 4 + selfattention_layer_type: rel_selfattn + pos_enc_layer_type: rel_pos + rel_pos_type: latest + cgmlp_linear_units: 1024 + multicgmlp_type: concat_fusion + multicgmlp_kernel_sizes: 7,15,23,31 + multicgmlp_merge_conv_kernel: 31 + use_linear_after_conv: false + gate_activation: identity + num_blocks: 12 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d + layer_drop_rate: 0.0 + linear_units: 1024 + positionwise_layer_type: linear + macaron_style: true + use_cnn_module: true + +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + layer_drop_rate: 0.0 + +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 + length_normalized_loss: false + +frontend_conf: + n_fft: 512 + win_length: 400 + hop_length: 160 + +seed: 2022 +num_workers: 4 +batch_type: numel +batch_bins: 16000000 +accum_grad: 4 +max_epoch: 70 +patience: none +init: none +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 +use_amp: true + +optim: adam +optim_conf: + lr: 0.002 + weight_decay: 0.000001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 15000 + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 27 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_ratio_range: + - 0. + - 0.05 + num_time_mask: 5 diff --git a/egs2/slurp_entity/asr1/README.md b/egs2/slurp_entity/asr1/README.md index 5834e26db7c..6d28bd32a36 100644 --- a/egs2/slurp_entity/asr1/README.md +++ b/egs2/slurp_entity/asr1/README.md @@ -1,3 +1,29 @@ +# Multiconvformer +- Params: 108.09 M +- ASR config: [conf/tuning/train_asr_multiconv_e12_mlp3072_linear2048_layerdrop.yaml](conf/tuning/train_asr_multiconv_e12_mlp3072_linear2048_layerdrop.yaml) +- Model link: [https://huggingface.co/Darshan7575/slurp_multiconvformer_conv_fusion](https://huggingface.co/Darshan7575/slurp_multiconvformer_conv_fusion) + +# RESULTS +## Environments +- date: `Wed Feb 21 01:04:03 EST 2024` +- python version: `3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0]` +- espnet version: `espnet 202310` +- pytorch version: `pytorch 2.1.2+cu118` +- Git hash: `edb6ec64bb5d4f2c68a3b81674f0c2822e2e5b58` + - Commit date: `Fri Feb 9 21:26:35 2024 +0530` + +### Intent Classification + +- Valid Intent Classification Result: 0.8882623705408516 +- Test Intent Classification Result: 0.8737574552683897 + +### Entity + +|Slu f1|Precision|Recall|F-Measure| +|:---:|:---:|:---:|:---:| +| test | 0.8076 | 0.7710 | 0.7889 | + + # E-Branchformer - ASR config: [conf/tuning/train_asr_e_branchformer_e12_mlp3072_linear1024_layerdrop.yaml](conf/tuning/train_asr_e_branchformer_e12_mlp3072_linear1024_layerdrop.yaml) diff --git a/egs2/slurp_entity/asr1/conf/tuning/train_asr_multiconv_e12_mlp3072_linear2048_layerdrop.yaml b/egs2/slurp_entity/asr1/conf/tuning/train_asr_multiconv_e12_mlp3072_linear2048_layerdrop.yaml new file mode 100644 index 00000000000..1d392cecb65 --- /dev/null +++ b/egs2/slurp_entity/asr1/conf/tuning/train_asr_multiconv_e12_mlp3072_linear2048_layerdrop.yaml @@ -0,0 +1,78 @@ +# network architecture +# encoder related +encoder: multiconv_conformer +encoder_conf: + output_size: 512 + attention_heads: 8 + selfattention_layer_type: rel_selfattn 
+ pos_enc_layer_type: rel_pos + rel_pos_type: latest + cgmlp_linear_units: 3072 + multicgmlp_type: concat_fusion + multicgmlp_kernel_sizes: 7,15,23,31 + multicgmlp_merge_conv_kernel: 31 + use_linear_after_conv: false + gate_activation: identity + num_blocks: 12 # Maybe we can increase the size by 1 to match e-branchformer + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: conv2d + layer_drop_rate: 0.1 + linear_units: 1152 + positionwise_layer_type: linear + macaron_style: true + use_cnn_module: true + + +decoder: transformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + layer_drop_rate: 0.2 + +optim: adam +optim_conf: + lr: 0.001 + weight_decay: 0.000001 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 35000 + +unused_parameters: true +batch_type: folded +batch_size: 64 +accum_grad: 1 +max_epoch: 60 +best_model_criterion: +- - valid + - acc + - max +keep_nbest_models: 10 + +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 + length_normalized_loss: false + extract_feats_in_collect_stats: false # Note: "False" means during collect stats (stage 10), generating dummy stats files rather than extract_feats by forward frontend. + +specaug: specaug +specaug_conf: + apply_time_warp: true + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + num_freq_mask: 2 + apply_time_mask: true + time_mask_width_range: + - 0 + - 40 + num_time_mask: 2 diff --git a/espnet2/asr/encoder/multiconvformer_encoder.py b/espnet2/asr/encoder/multiconvformer_encoder.py new file mode 100644 index 00000000000..cc755e79140 --- /dev/null +++ b/espnet2/asr/encoder/multiconvformer_encoder.py @@ -0,0 +1,404 @@ +# Copyright 2024 Darshan Prabhu (IIT Bombay) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multiconvformer encoder definition.""" + +import logging +from typing import List, Optional, Tuple, Union + +import torch +from typeguard import typechecked + +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.layers.multiconv_cgmlp import MultiConvolutionalGatingMLP +from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.nets_utils import get_activation, make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import ( + LegacyRelPositionMultiHeadedAttention, + MultiHeadedAttention, + RelPositionMultiHeadedAttention, +) +from espnet.nets.pytorch_backend.transformer.embedding import ( + LegacyRelPositionalEncoding, + PositionalEncoding, + RelPositionalEncoding, + ScaledPositionalEncoding, +) +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import ( + Conv1dLinear, + MultiLayeredConv1d, +) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import ( + Conv2dSubsampling, + Conv2dSubsampling1, + Conv2dSubsampling2, + Conv2dSubsampling6, + Conv2dSubsampling8, + TooShortUttError, + check_short_utt, +) + + +class MultiConvConformerEncoder(AbsEncoder): + """Multiconvformer encoder module. 
+ Link to the paper: https://arxiv.org/abs/2407.03718 + + Args: + input_size (int): Input dimension. + output_size (int): Dimension of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of encoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + cgmlp_linear_units (int): The number of units used in the CGMLP block. + multicgmlp_type (str): "sum", "weighted_sum", "concat" or "concat_fusion". + multicgmlp_kernel_sizes (str): Comma-separated list of kernel sizes. + multicgmlp_merge_conv_kernel (int): Kernel size of the depthwise + convolution used for fusion in MultiCGMLP. + use_linear_after_conv (bool): Whether to use a linear layer after MultiCGMLP. + gate_activation (str): The activation function used in CGMLP gating. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + If True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + If False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + rel_pos_type (str): Whether to use the latest relative positional encoding or + the legacy one. The legacy relative positional encoding will be deprecated + in the future. More details can be found in + https://github.com/espnet/espnet/pull/2816. + pos_enc_layer_type (str): Encoder positional encoding layer type. + selfattention_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + macaron_style (bool): Whether to use macaron style for positionwise layer. + use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernel size of convolution module. + padding_idx (int): Padding idx for input_layer=embed.
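+
+    Example:
+        A minimal instantiation sketch. The values below (80-dim features,
+        256-dim output, four kernel sizes) are illustrative only and are not
+        taken from any recipe in this change.
+
+        >>> encoder = MultiConvConformerEncoder(
+        ...     input_size=80,
+        ...     output_size=256,
+        ...     multicgmlp_type="concat_fusion",
+        ...     multicgmlp_kernel_sizes="7,15,23,31",
+        ... )
+        >>> xs = torch.randn(2, 100, 80)
+        >>> out, olens, _ = encoder(xs, torch.tensor([100, 80]))
+        >>> out.shape[-1]
+        256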
+ + """ + + @typechecked + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + cgmlp_linear_units: int = 2048, + multicgmlp_type: str = "concat_fusion", + multicgmlp_kernel_sizes: Union[int, str] = "7,15,23,31", + multicgmlp_merge_conv_kernel: int = 31, + multicgmlp_use_non_linear: int = True, + use_linear_after_conv: bool = False, + gate_activation: str = "identity", + input_layer: str = "conv2d", + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = "legacy", + pos_enc_layer_type: str = "rel_pos", + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + zero_triu: bool = False, + padding_idx: int = -1, + interctc_layer_idx: List[int] = [], + interctc_use_conditioning: bool = False, + stochastic_depth_rate: Union[float, List[float]] = 0.0, + layer_drop_rate: float = 0.0, + max_pos_emb_len: int = 5000, + ): + super().__init__() + self._output_size = output_size + + if rel_pos_type == "legacy": + if pos_enc_layer_type == "rel_pos": + pos_enc_layer_type = "legacy_rel_pos" + if selfattention_layer_type == "rel_selfattn": + selfattention_layer_type = "legacy_rel_selfattn" + elif rel_pos_type == "latest": + assert selfattention_layer_type != "legacy_rel_selfattn" + assert pos_enc_layer_type != "legacy_rel_pos" + else: + raise ValueError("unknown rel_pos_type: " + rel_pos_type) + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + assert selfattention_layer_type == "legacy_rel_selfattn" + pos_enc_class = LegacyRelPositionalEncoding + logging.warning( + "Using legacy_rel_pos and it will be deprecated in the future." 
+ ) + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d1": + self.embed = Conv2dSubsampling1( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d2": + self.embed = Conv2dSubsampling2( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + logging.warning( + "Using legacy_rel_selfattn and it will be deprecated in the future." 
+ ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + zero_triu, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + if isinstance(multicgmlp_kernel_sizes, int): + multicgmlp_kernel_sizes = str(multicgmlp_kernel_sizes) + + convolution_layer = MultiConvolutionalGatingMLP + convolution_layer_args = ( + output_size, + cgmlp_linear_units, + multicgmlp_type, + multicgmlp_kernel_sizes, + multicgmlp_merge_conv_kernel, + multicgmlp_use_non_linear, + dropout_rate, + use_linear_after_conv, + activation, + gate_activation, + ) + + if isinstance(stochastic_depth_rate, float): + stochastic_depth_rate = [stochastic_depth_rate] * num_blocks + + if len(stochastic_depth_rate) != num_blocks: + raise ValueError( + f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) " + f"should be equal to num_blocks ({num_blocks})" + ) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + stochastic_depth_rate[lnum], + ), + layer_drop_rate, + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + self.interctc_layer_idx = interctc_layer_idx + if len(interctc_layer_idx) > 0: + assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks + self.interctc_use_conditioning = interctc_use_conditioning + self.conditioning_layer = None + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ctc: CTC = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Calculate forward propagation. + + Args: + xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). + ilens (torch.Tensor): Input length (#batch). + prev_states (torch.Tensor): Not to be used now. + + Returns: + torch.Tensor: Output tensor (#batch, L, output_size). + torch.Tensor: Output length (#batch). + torch.Tensor: Not to be used now. 
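+
+        Note:
+            When interctc_layer_idx is non-empty, the first return value is
+            instead a tuple of the final encoder output and a list of
+            (layer index, intermediate output) pairs for intermediate CTC.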
+ + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling1) + or isinstance(self.embed, Conv2dSubsampling2) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + + intermediate_outs = [] + if len(self.interctc_layer_idx) == 0: + xs_pad, masks = self.encoders(xs_pad, masks) + else: + for layer_idx, encoder_layer in enumerate(self.encoders): + xs_pad, masks = encoder_layer(xs_pad, masks) + + if layer_idx + 1 in self.interctc_layer_idx: + encoder_out = xs_pad + if isinstance(encoder_out, tuple): + encoder_out = encoder_out[0] + + # intermediate outputs are also normalized + if self.normalize_before: + encoder_out = self.after_norm(encoder_out) + + intermediate_outs.append((layer_idx + 1, encoder_out)) + + if self.interctc_use_conditioning: + ctc_out = ctc.softmax(encoder_out) + + if isinstance(xs_pad, tuple): + x, pos_emb = xs_pad + x = x + self.conditioning_layer(ctc_out) + xs_pad = (x, pos_emb) + else: + xs_pad = xs_pad + self.conditioning_layer(ctc_out) + + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + olens = masks.squeeze(1).sum(1) + if len(intermediate_outs) > 0: + return (xs_pad, intermediate_outs), olens, None + return xs_pad, olens, None diff --git a/espnet2/asr/layers/multiconv_cgmlp.py b/espnet2/asr/layers/multiconv_cgmlp.py new file mode 100644 index 00000000000..91cb0e70cfc --- /dev/null +++ b/espnet2/asr/layers/multiconv_cgmlp.py @@ -0,0 +1,223 @@ +"""Extension of convolutional gating (cgMLP) definition with multiple convolutions. 
+ +References: + https://openreview.net/forum?id=RA-zVvZLYIy + https://arxiv.org/abs/2105.08050 + https://arxiv.org/abs/2407.03718 +""" + +import torch + +from espnet.nets.pytorch_backend.nets_utils import get_activation +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class MultiConvolutionalSpatialGatingUnit(torch.nn.Module): + """Multi Convolutional Spatial Gating Unit (M-CSGU).""" + + def __init__( + self, + size: int, + arch_type: str, + kernel_sizes: str, + merge_conv_kernel: int, + use_non_linear: bool, + dropout_rate: float, + use_linear_after_conv: bool, + activation, + gate_activation: str, + ): + super().__init__() + + n_channels = size // 2 # split input channels + self.norm = LayerNorm(n_channels) + + kernel_sizes = list(map(int, kernel_sizes.split(","))) + no_kernels = len(kernel_sizes) + + assert ( + n_channels % no_kernels == 0 + ), f"{n_channels} input channels cannot be divided between {no_kernels} kernels" + + self.arch_type = arch_type + if arch_type in ["sum", "weighted_sum"]: + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + n_channels, + n_channels, + kernel_size, + 1, + (kernel_size - 1) // 2, + groups=n_channels, + ) + for kernel_size in kernel_sizes + ] + ) + elif arch_type in ["concat", "concat_fusion"]: + self.convs = torch.nn.ModuleList( + [ + torch.nn.Conv1d( + n_channels, + n_channels // no_kernels, + kernel_size, + 1, + (kernel_size - 1) // 2, + groups=n_channels // no_kernels, + ) + for kernel_size in kernel_sizes + ] + ) + else: + raise NotImplementedError( + f"Unknown architecture type for MultiConvCGMLP: {arch_type}" + ) + self.use_non_linear = use_non_linear + if arch_type == "weighted_sum": + self.kernel_prob_gen = torch.nn.Sequential( + torch.nn.Linear(n_channels * no_kernels, no_kernels), + torch.nn.Softmax(dim=-1), + ) + self.depthwise_conv_fusion = None + elif arch_type == "concat_fusion": + self.kernel_prob_gen = None + self.depthwise_conv_fusion = torch.nn.Conv1d( + n_channels, + n_channels, + kernel_size=merge_conv_kernel, + stride=1, + padding=(merge_conv_kernel - 1) // 2, + groups=n_channels, + bias=True, + ) + else: + self.kernel_prob_gen = None + self.depthwise_conv_fusion = None + + if use_linear_after_conv: + self.linear = torch.nn.Linear(n_channels, n_channels) + else: + self.linear = None + + self.model_act = activation + if gate_activation == "identity": + self.act = torch.nn.Identity() + else: + self.act = get_activation(gate_activation) + + self.dropout = torch.nn.Dropout(dropout_rate) + + def espnet_initialization_fn(self): + for conv in self.convs: + torch.nn.init.normal_(conv.weight, std=1e-6) + torch.nn.init.ones_(conv.bias) + if self.depthwise_conv_fusion is not None: + torch.nn.init.normal_(self.depthwise_conv_fusion.weight, std=1e-6) + torch.nn.init.ones_(self.depthwise_conv_fusion.bias) + if self.linear is not None: + torch.nn.init.normal_(self.linear.weight, std=1e-6) + torch.nn.init.ones_(self.linear.bias) + + def forward(self, x, gate_add=None): + """Forward method + + Args: + x (torch.Tensor): (N, T, D) + gate_add (torch.Tensor): (N, T, D/2) + + Returns: + out (torch.Tensor): (N, T, D/2) + """ + x_r, x_i = x.chunk(2, dim=-1) + + x_i = self.norm(x_i).transpose(1, 2) # (N, D/2, T) + + # TODO: Parallelize this convolution computation + xs = [] + for conv in self.convs: + xi = conv(x_i).transpose(1, 2) # (N, T, D/2) + if self.arch_type == "sum" and self.use_non_linear: + xi = self.model_act(xi) + xs.append(xi) + + if self.arch_type in ["sum", "weighted_sum"]: + x = torch.stack(xs, 
dim=-2) + if self.arch_type == "weighted_sum": + prob = self.kernel_prob_gen(torch.cat(xs, dim=-1)) + x = prob.unsqueeze(-1) * x + + x_g = x.sum(dim=-2) + else: + x_concat = torch.cat(xs, dim=-1) # (N, T, D) + + if self.arch_type == "concat_fusion": + x_tmp = x_concat.transpose(1, 2) + x_tmp = self.depthwise_conv_fusion(x_tmp) + x_concat = x_concat + x_tmp.transpose(1, 2) + + x_g = x_concat + + if self.linear is not None: + x_g = self.linear(x_g) + + if gate_add is not None: + x_g = x_g + gate_add + + x_g = self.act(x_g) + out = x_r * x_g # (N, T, D/2) + out = self.dropout(out) + return out + + +class MultiConvolutionalGatingMLP(torch.nn.Module): + """Convolutional Gating MLP (cgMLP).""" + + def __init__( + self, + size: int, + linear_units: int, + arch_type: str, + kernel_sizes: str, + merge_conv_kernel: int, + use_non_linear: bool, + dropout_rate: float, + use_linear_after_conv: bool, + activation, + gate_activation: str, + ): + super().__init__() + + if arch_type not in ["sum", "weighted_sum", "concat", "concat_fusion"]: + raise NotImplementedError(f"Unknown MultiConvCGMLP type: {type}") + + self.channel_proj1 = torch.nn.Sequential( + torch.nn.Linear(size, linear_units), torch.nn.GELU() + ) + self.csgu = MultiConvolutionalSpatialGatingUnit( + size=linear_units, + arch_type=arch_type, + kernel_sizes=kernel_sizes, + merge_conv_kernel=merge_conv_kernel, + use_non_linear=use_non_linear, + dropout_rate=dropout_rate, + use_linear_after_conv=use_linear_after_conv, + activation=activation, + gate_activation=gate_activation, + ) + self.channel_proj2 = torch.nn.Linear(linear_units // 2, size) + + def forward(self, x, mask=None): + if isinstance(x, tuple): + xs_pad, pos_emb = x + else: + xs_pad, pos_emb = x, None + + xs_pad = self.channel_proj1(xs_pad) # size -> linear_units + xs_pad = self.csgu(xs_pad) # linear_units -> linear_units/2 + xs_pad = self.channel_proj2(xs_pad) # linear_units/2 -> size + + if pos_emb is not None: + out = (xs_pad, pos_emb) + else: + out = xs_pad + return out diff --git a/espnet2/tasks/asr.py b/espnet2/tasks/asr.py index ab617ad7105..5c737da0dfe 100644 --- a/espnet2/tasks/asr.py +++ b/espnet2/tasks/asr.py @@ -40,6 +40,7 @@ TorchAudioHuBERTPretrainEncoder, ) from espnet2.asr.encoder.longformer_encoder import LongformerEncoder +from espnet2.asr.encoder.multiconvformer_encoder import MultiConvConformerEncoder from espnet2.asr.encoder.rnn_encoder import RNNEncoder from espnet2.asr.encoder.transformer_encoder import TransformerEncoder from espnet2.asr.encoder.transformer_encoder_multispkr import ( @@ -157,6 +158,7 @@ whisper=OpenAIWhisperEncoder, e_branchformer=EBranchformerEncoder, avhubert=FairseqAVHubertEncoder, + multiconv_conformer=MultiConvConformerEncoder, ), type_check=AbsEncoder, default="rnn", diff --git a/test/espnet2/asr/encoder/test_multiconvformer_encoder.py b/test/espnet2/asr/encoder/test_multiconvformer_encoder.py new file mode 100644 index 00000000000..158ba6c7d32 --- /dev/null +++ b/test/espnet2/asr/encoder/test_multiconvformer_encoder.py @@ -0,0 +1,179 @@ +import pytest +import torch + +from espnet2.asr.ctc import CTC +from espnet2.asr.encoder.multiconvformer_encoder import MultiConvConformerEncoder + + +@pytest.mark.parametrize( + "input_layer", + ["linear", "conv2d", "conv2d1", "conv2d2", "conv2d6", "conv2d8", "embed"], +) +@pytest.mark.parametrize("use_linear_after_conv", [True, False]) +@pytest.mark.parametrize("positionwise_layer_type", ["conv1d", "conv1d-linear"]) +@pytest.mark.parametrize( + "rel_pos_type, pos_enc_layer_type, 
selfattention_layer_type", + [ + ("legacy", "abs_pos", "selfattn"), + ("latest", "rel_pos", "rel_selfattn"), + ("legacy", "rel_pos", "rel_selfattn"), + ("legacy", "legacy_rel_pos", "legacy_rel_selfattn"), + ], +) +@pytest.mark.parametrize( + "interctc_layer_idx, interctc_use_conditioning", + [ + ([], False), + ([1], False), + ([1], True), + ], +) +@pytest.mark.parametrize( + "multicgmlp_type, multicgmlp_kernel_sizes", + [ + ("sum", "3"), + ("sum", "3,5"), + ("sum", "3,5,7"), + ("weighted_sum", "3"), + ("weighted_sum", "3,5"), + ("weighted_sum", "3,5,7"), + ("concat", "3"), + ("concat", "3,5"), + ("concat", "3,5,7"), + ("concat_fusion", "3"), + ("concat_fusion", "3,5"), + ("concat_fusion", "3,5,7"), + ], +) +@pytest.mark.parametrize("multicgmlp_merge_conv_kernel", [3, 31]) +@pytest.mark.parametrize("stochastic_depth_rate", [0.0, 0.1, [0.1, 0.1]]) +def test_encoder_forward_backward( + input_layer, + use_linear_after_conv, + positionwise_layer_type, + rel_pos_type, + pos_enc_layer_type, + selfattention_layer_type, + interctc_layer_idx, + interctc_use_conditioning, + multicgmlp_type, + multicgmlp_kernel_sizes, + multicgmlp_merge_conv_kernel, + stochastic_depth_rate, +): + encoder = MultiConvConformerEncoder( + 20, + output_size=2, + attention_heads=2, + linear_units=4, + num_blocks=2, + input_layer=input_layer, + selfattention_layer_type=selfattention_layer_type, + pos_enc_layer_type=pos_enc_layer_type, + positionwise_layer_type=positionwise_layer_type, + rel_pos_type=rel_pos_type, + cgmlp_linear_units=36, + use_cnn_module=True, + use_linear_after_conv=use_linear_after_conv, + gate_activation="identity", + multicgmlp_type=multicgmlp_type, + multicgmlp_kernel_sizes=multicgmlp_kernel_sizes, + multicgmlp_merge_conv_kernel=multicgmlp_merge_conv_kernel, + interctc_layer_idx=interctc_layer_idx, + interctc_use_conditioning=interctc_use_conditioning, + stochastic_depth_rate=stochastic_depth_rate, + ) + if input_layer == "embed": + x = torch.randint(0, 10, [2, 32]) + else: + x = torch.randn(2, 32, 20, requires_grad=True) + x_lens = torch.LongTensor([32, 28]) + + if len(interctc_layer_idx) > 0: # intermediate CTC + ctc = None + if interctc_use_conditioning: + vocab_size = 5 + output_size = encoder.output_size() + ctc = CTC(odim=vocab_size, encoder_output_size=output_size) + encoder.conditioning_layer = torch.nn.Linear(vocab_size, output_size) + y, _, _ = encoder(x, x_lens, ctc=ctc) + y = y[0] + else: + y, _, _ = encoder(x, x_lens) + + y.sum().backward() + + +def test_encoder_invalid_layer_type(): + with pytest.raises(ValueError): + MultiConvConformerEncoder(20, input_layer="dummy") + with pytest.raises(ValueError): + MultiConvConformerEncoder(20, rel_pos_type="dummy") + with pytest.raises(ValueError): + MultiConvConformerEncoder(20, pos_enc_layer_type="dummy") + with pytest.raises(ValueError): + MultiConvConformerEncoder( + 20, pos_enc_layer_type="abc_pos", selfattention_layer_type="dummy" + ) + + +def test_encoder_invalid_rel_pos_combination(): + with pytest.raises(AssertionError): + MultiConvConformerEncoder( + 20, + rel_pos_type="latest", + pos_enc_layer_type="legacy_rel_pos", + selfattention_layer_type="legacy_rel_sselfattn", + ) + with pytest.raises(AssertionError): + MultiConvConformerEncoder( + 20, + pos_enc_layer_type="rel_pos", + selfattention_layer_type="legacy_rel_sselfattn", + ) + with pytest.raises(AssertionError): + MultiConvConformerEncoder( + 20, + pos_enc_layer_type="legacy_rel_pos", + selfattention_layer_type="rel_sselfattn", + ) + + +def test_encoder_invalid_interctc_layer_idx(): + 
with pytest.raises(AssertionError): + MultiConvConformerEncoder( + 20, + num_blocks=2, + interctc_layer_idx=[0, 1], + ) + with pytest.raises(AssertionError): + MultiConvConformerEncoder( + 20, + num_blocks=2, + interctc_layer_idx=[1, 2], + ) + + +def test_encoder_output_size(): + encoder = MultiConvConformerEncoder(20, output_size=256) + assert encoder.output_size() == 256 + + +def test_encoder_invalid_type(): + with pytest.raises(ValueError): + MultiConvConformerEncoder(20, input_layer="fff") + + +def test_encoder_invalid_stochastic_depth_rate(): + with pytest.raises(ValueError): + MultiConvConformerEncoder( + 20, + num_blocks=2, + stochastic_depth_rate=[0.1], + ) + with pytest.raises(ValueError): + MultiConvConformerEncoder( + 20, + num_blocks=2, + stochastic_depth_rate=[0.1, 0.1, 0.1], + )
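For quick reference, here is a standalone sketch of how the new MultiConvolutionalGatingMLP layer behaves shape-wise. All sizes below (256-dim features, 1024 hidden units, a batch of two 50-frame sequences) are illustrative and are not taken from the recipes above.

```python
import torch

from espnet2.asr.layers.multiconv_cgmlp import MultiConvolutionalGatingMLP
from espnet.nets.pytorch_backend.nets_utils import get_activation

# Illustrative sizes only; the YAML configs above choose these per recipe.
layer = MultiConvolutionalGatingMLP(
    size=256,                    # model dimension (encoder output_size)
    linear_units=1024,           # cgmlp_linear_units; half of this is split across the kernels
    arch_type="concat_fusion",   # multicgmlp_type: "sum", "weighted_sum", "concat" or "concat_fusion"
    kernel_sizes="7,15,23,31",   # multicgmlp_kernel_sizes
    merge_conv_kernel=31,        # kernel size of the depthwise fusion convolution
    use_non_linear=True,
    dropout_rate=0.1,
    use_linear_after_conv=False,
    activation=get_activation("swish"),
    gate_activation="identity",
)

x = torch.randn(2, 50, 256)  # (batch, time, size)
y = layer(x)                 # the layer preserves the (batch, time, size) shape
print(y.shape)               # torch.Size([2, 50, 256])
```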