From 151c339ffeced822917e85255431fcfb74f24db9 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Tue, 20 Feb 2024 14:04:09 +0800 Subject: [PATCH] support transducer model inference --- funasr/models/conformer/encoder.py | 666 ++++++++++++++++++ funasr/models/seaco_paraformer/model.py | 2 +- .../transducer/beam_search_transducer.py | 10 +- funasr/models/transducer/joint_network.py | 7 +- funasr/models/transducer/model.py | 153 ++-- funasr/models/transducer/rnn_decoder.py | 11 +- funasr/models/transducer/rnn_encoder.py | 112 --- funasr/models/transducer/rnnt_decoder.py | 13 +- funasr/models/transformer/attention.py | 213 ++++++ 9 files changed, 956 insertions(+), 231 deletions(-) delete mode 100644 funasr/models/transducer/rnn_encoder.py diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py index 1ca437da4..1d252c206 100644 --- a/funasr/models/conformer/encoder.py +++ b/funasr/models/conformer/encoder.py @@ -14,6 +14,7 @@ MultiHeadedAttention, # noqa: H301 RelPositionMultiHeadedAttention, # noqa: H301 LegacyRelPositionMultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttentionChunk, ) from funasr.models.transformer.embedding import ( PositionalEncoding, # noqa: H301 @@ -610,4 +611,669 @@ def forward( if len(intermediate_outs) > 0: return (xs_pad, intermediate_outs), olens, None return xs_pad, olens, None + +class CausalConvolution(torch.nn.Module): + """ConformerConvolution module definition. + Args: + channels: The number of channels. + kernel_size: Size of the convolving kernel. + activation: Type of activation function. + norm_args: Normalization module arguments. + causal: Whether to use causal convolution (set to True if streaming). + """ + + def __init__( + self, + channels: int, + kernel_size: int, + activation: torch.nn.Module = torch.nn.ReLU(), + norm_args: Dict = {}, + causal: bool = False, + ) -> None: + """Construct an ConformerConvolution object.""" + super().__init__() + + assert (kernel_size - 1) % 2 == 0 + + self.kernel_size = kernel_size + + self.pointwise_conv1 = torch.nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + ) + + if causal: + self.lorder = kernel_size - 1 + padding = 0 + else: + self.lorder = 0 + padding = (kernel_size - 1) // 2 + + self.depthwise_conv = torch.nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + ) + self.norm = torch.nn.BatchNorm1d(channels, **norm_args) + self.pointwise_conv2 = torch.nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + ) + + self.activation = activation + + def forward( + self, + x: torch.Tensor, + cache: Optional[torch.Tensor] = None, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x: ConformerConvolution input sequences. (B, T, D_hidden) + cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) + right_context: Number of frames in right context. + Returns: + x: ConformerConvolution output sequences. (B, T, D_hidden) + cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden) + """ + x = self.pointwise_conv1(x.transpose(1, 2)) + x = torch.nn.functional.glu(x, dim=1) + + if self.lorder > 0: + if cache is None: + x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + else: + x = torch.cat([cache, x], dim=2) + + if right_context > 0: + cache = x[:, :, -(self.lorder + right_context) : -right_context] + else: + cache = x[:, :, -self.lorder :] + + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x).transpose(1, 2) + + return x, cache + +class ChunkEncoderLayer(torch.nn.Module): + """Chunk Conformer module definition. + Args: + block_size: Input/output size. + self_att: Self-attention module instance. + feed_forward: Feed-forward module instance. + feed_forward_macaron: Feed-forward module instance for macaron network. + conv_mod: Convolution module instance. + norm_class: Normalization module class. + norm_args: Normalization module arguments. + dropout_rate: Dropout rate. + """ + + def __init__( + self, + block_size: int, + self_att: torch.nn.Module, + feed_forward: torch.nn.Module, + feed_forward_macaron: torch.nn.Module, + conv_mod: torch.nn.Module, + norm_class: torch.nn.Module = LayerNorm, + norm_args: Dict = {}, + dropout_rate: float = 0.0, + ) -> None: + """Construct a Conformer object.""" + super().__init__() + + self.self_att = self_att + + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.feed_forward_scale = 0.5 + + self.conv_mod = conv_mod + + self.norm_feed_forward = norm_class(block_size, **norm_args) + self.norm_self_att = norm_class(block_size, **norm_args) + + self.norm_macaron = norm_class(block_size, **norm_args) + self.norm_conv = norm_class(block_size, **norm_args) + self.norm_final = norm_class(block_size, **norm_args) + + self.dropout = torch.nn.Dropout(dropout_rate) + + self.block_size = block_size + self.cache = None + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset self-attention and convolution modules cache for streaming. + Args: + left_context: Number of left frames during chunk-by-chunk inference. + device: Device to use for cache tensor. + """ + self.cache = [ + torch.zeros( + (1, left_context, self.block_size), + device=device, + ), + torch.zeros( + ( + 1, + self.block_size, + self.conv_mod.kernel_size - 1, + ), + device=device, + ), + ] + + def forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: Conformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T) + chunk_mask: Chunk mask. (T_2, T_2) + Returns: + x: Conformer output sequences. (B, T, D_block) + mask: Source mask. (B, T) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + """ + residual = x + + x = self.norm_macaron(x) + x = residual + self.feed_forward_scale * self.dropout( + self.feed_forward_macaron(x) + ) + + residual = x + x = self.norm_self_att(x) + x_q = x + x = residual + self.dropout( + self.self_att( + x_q, + x, + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + ) + + residual = x + + x = self.norm_conv(x) + x, _ = self.conv_mod(x) + x = residual + self.dropout(x) + residual = x + + x = self.norm_feed_forward(x) + x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) + + x = self.norm_final(x) + return x, mask, pos_enc + + def chunk_forward( + self, + x: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_size: int = 16, + left_context: int = 0, + right_context: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode chunk of input sequence. + Args: + x: Conformer input sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + mask: Source mask. (B, T_2) + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: Conformer output sequences. (B, T, D_block) + pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) + """ + residual = x + + x = self.norm_macaron(x) + x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) + + residual = x + x = self.norm_self_att(x) + if left_context > 0: + key = torch.cat([self.cache[0], x], dim=1) + else: + key = x + val = key + + if right_context > 0: + att_cache = key[:, -(left_context + right_context) : -right_context, :] + else: + att_cache = key[:, -left_context:, :] + x = residual + self.self_att( + x, + key, + val, + pos_enc, + mask, + left_context=left_context, + ) + + residual = x + x = self.norm_conv(x) + x, conv_cache = self.conv_mod( + x, cache=self.cache[1], right_context=right_context + ) + x = residual + x + residual = x + + x = self.norm_feed_forward(x) + x = residual + self.feed_forward_scale * self.feed_forward(x) + + x = self.norm_final(x) + self.cache = [att_cache, conv_cache] + + return x, pos_enc + +@tables.register("encoder_classes", "ChunkConformerEncoder") +class ConformerChunkEncoder(torch.nn.Module): + """Encoder module definition. + Args: + input_size: Input size. + body_conf: Encoder body configuration. + input_conf: Encoder input configuration. + main_conf: Encoder main configuration. + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + embed_vgg_like: bool = False, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = "legacy", + pos_enc_layer_type: str = "rel_pos", + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + zero_triu: bool = False, + norm_type: str = "layer_norm", + cnn_module_kernel: int = 31, + conv_mod_norm_eps: float = 0.00001, + conv_mod_norm_momentum: float = 0.1, + simplified_att_score: bool = False, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.75, + short_chunk_size: int = 25, + left_chunk_size: int = 0, + time_reduction_factor: int = 1, + unified_model_training: bool = False, + default_chunk_size: int = 16, + jitter_range: int = 4, + subsampling_factor: int = 1, + ) -> None: + """Construct an Encoder object.""" + super().__init__() + + + self.embed = StreamingConvInput( + input_size=input_size, + conv_size=output_size, + subsampling_factor=subsampling_factor, + vgg_like=embed_vgg_like, + output_size=output_size, + ) + + self.pos_enc = StreamingRelPositionalEncoding( + output_size, + positional_dropout_rate, + ) + + activation = get_activation( + activation_type + ) + + pos_wise_args = ( + output_size, + linear_units, + positional_dropout_rate, + activation, + ) + + conv_mod_norm_args = { + "eps": conv_mod_norm_eps, + "momentum": conv_mod_norm_momentum, + } + + conv_mod_args = ( + output_size, + cnn_module_kernel, + activation, + conv_mod_norm_args, + dynamic_chunk_training or unified_model_training, + ) + + mult_att_args = ( + attention_heads, + output_size, + attention_dropout_rate, + simplified_att_score, + ) + + + fn_modules = [] + for _ in range(num_blocks): + module = lambda: ChunkEncoderLayer( + output_size, + RelPositionMultiHeadedAttentionChunk(*mult_att_args), + PositionwiseFeedForward(*pos_wise_args), + PositionwiseFeedForward(*pos_wise_args), + CausalConvolution(*conv_mod_args), + dropout_rate=dropout_rate, + ) + fn_modules.append(module) + + self.encoders = MultiBlocks( + [fn() for fn in fn_modules], + output_size, + ) + + self._output_size = output_size + + self.dynamic_chunk_training = dynamic_chunk_training + self.short_chunk_threshold = short_chunk_threshold + self.short_chunk_size = short_chunk_size + self.left_chunk_size = left_chunk_size + + self.unified_model_training = unified_model_training + self.default_chunk_size = default_chunk_size + self.jitter_range = jitter_range + + self.time_reduction_factor = time_reduction_factor + + def output_size(self) -> int: + return self._output_size + + def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: + """Return the corresponding number of sample for a given chunk size, in frames. + Where size is the number of features frames after applying subsampling. + Args: + size: Number of frames after subsampling. + hop_length: Frontend's hop length + Returns: + : Number of raw samples + """ + return self.embed.get_size_before_subsampling(size) * hop_length + + def get_encoder_input_size(self, size: int) -> int: + """Return the corresponding number of sample for a given chunk size, in frames. + Where size is the number of features frames after applying subsampling. + Args: + size: Number of frames after subsampling. + Returns: + : Number of raw samples + """ + return self.embed.get_size_before_subsampling(size) + + + def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: + """Initialize/Reset encoder streaming cache. + Args: + left_context: Number of frames in left context. + device: Device ID. + """ + return self.encoders.reset_streaming_cache(left_context, device) + + def forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: Encoder input features. (B, T_in, F) + x_len: Encoder input features lengths. (B,) + Returns: + x: Encoder outputs. (B, T_out, D_enc) + x_len: Encoder outputs lenghts. (B,) + """ + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len).to(x.device) + + if self.unified_model_training: + if self.training: + chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() + else: + chunk_size = self.default_chunk_size + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + x_utt = self.encoders( + x, + pos_enc, + mask, + chunk_mask=None, + ) + x_chunk = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x_utt = x_utt[:,::self.time_reduction_factor,:] + x_chunk = x_chunk[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x_utt, x_chunk, olens + + elif self.dynamic_chunk_training: + max_len = x.size(1) + if self.training: + chunk_size = torch.randint(1, max_len, (1,)).item() + + if chunk_size > (max_len * self.short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = (chunk_size % self.short_chunk_size) + 1 + else: + chunk_size = self.default_chunk_size + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + else: + x, mask = self.embed(x, mask, None) + pos_enc = self.pos_enc(x) + chunk_mask = None + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 + + return x, olens, None + + def full_utt_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode input sequences. + Args: + x: Encoder input features. (B, T_in, F) + x_len: Encoder input features lengths. (B,) + Returns: + x: Encoder outputs. (B, T_out, D_enc) + x_len: Encoder outputs lenghts. (B,) + """ + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len).to(x.device) + x, mask = self.embed(x, mask, None) + pos_enc = self.pos_enc(x) + x_utt = self.encoders( + x, + pos_enc, + mask, + chunk_mask=None, + ) + + if self.time_reduction_factor > 1: + x_utt = x_utt[:,::self.time_reduction_factor,:] + return x_utt + + def simu_chunk_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + short_status, limit_size = check_short_utt( + self.embed.subsampling_factor, x.size(1) + ) + + if short_status: + raise TooShortUttError( + f"has {x.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + x.size(1), + limit_size, + ) + + mask = make_source_mask(x_len) + + x, mask = self.embed(x, mask, chunk_size) + pos_enc = self.pos_enc(x) + chunk_mask = make_chunk_mask( + x.size(1), + chunk_size, + left_chunk_size=self.left_chunk_size, + device=x.device, + ) + + x = self.encoders( + x, + pos_enc, + mask, + chunk_mask=chunk_mask, + ) + olens = mask.eq(0).sum(1) + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + + return x + + def chunk_forward( + self, + x: torch.Tensor, + x_len: torch.Tensor, + processed_frames: torch.tensor, + chunk_size: int = 16, + left_context: int = 32, + right_context: int = 0, + ) -> torch.Tensor: + """Encode input sequences as chunks. + Args: + x: Encoder input features. (1, T_in, F) + x_len: Encoder input features lengths. (1,) + processed_frames: Number of frames already seen. + left_context: Number of frames in left context. + right_context: Number of frames in right context. + Returns: + x: Encoder outputs. (B, T_out, D_enc) + """ + mask = make_source_mask(x_len) + x, mask = self.embed(x, mask, None) + + if left_context > 0: + processed_mask = ( + torch.arange(left_context, device=x.device) + .view(1, left_context) + .flip(1) + ) + processed_mask = processed_mask >= processed_frames + mask = torch.cat([processed_mask, mask], dim=1) + pos_enc = self.pos_enc(x, left_context=left_context) + x = self.encoders.chunk_forward( + x, + pos_enc, + mask, + chunk_size=chunk_size, + left_context=left_context, + right_context=right_context, + ) + + if right_context > 0: + x = x[:, 0:-right_context, :] + + if self.time_reduction_factor > 1: + x = x[:,::self.time_reduction_factor,:] + return x diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py index 8b8e97e53..2f55e6e19 100644 --- a/funasr/models/seaco_paraformer/model.py +++ b/funasr/models/seaco_paraformer/model.py @@ -335,7 +335,7 @@ def inference(self, speech = speech.to(device=kwargs["device"]) speech_lengths = speech_lengths.to(device=kwargs["device"]) - + # hotword self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) diff --git a/funasr/models/transducer/beam_search_transducer.py b/funasr/models/transducer/beam_search_transducer.py index 04b26b3b4..f599615c0 100644 --- a/funasr/models/transducer/beam_search_transducer.py +++ b/funasr/models/transducer/beam_search_transducer.py @@ -1,11 +1,13 @@ -"""Search algorithms for Transducer models.""" +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) +import torch +import numpy as np from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np -import torch - from funasr.models.transducer.joint_network import JointNetwork diff --git a/funasr/models/transducer/joint_network.py b/funasr/models/transducer/joint_network.py index 9fca632e2..7d424dbc9 100644 --- a/funasr/models/transducer/joint_network.py +++ b/funasr/models/transducer/joint_network.py @@ -1,10 +1,15 @@ -"""Transducer joint network implementation.""" +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) import torch +from funasr.register import tables from funasr.models.transformer.utils.nets_utils import get_activation +@tables.register("joint_network_classes", "joint_network") class JointNetwork(torch.nn.Module): """Transducer joint network module. diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py index 906aa605d..fd8ad71db 100644 --- a/funasr/models/transducer/model.py +++ b/funasr/models/transducer/model.py @@ -1,42 +1,26 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import time +import torch import logging from contextlib import contextmanager +from typing import Dict, Optional, Tuple from distutils.version import LooseVersion -from typing import Dict -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union -import tempfile -import codecs -import requests -import re -import copy -import torch -import torch.nn as nn -import random -import numpy as np -import time -from funasr.losses.label_smoothing_loss import ( - LabelSmoothingLoss, # noqa: H301 -) -# from funasr.models.ctc import CTC -# from funasr.models.decoder.abs_decoder import AbsDecoder -# from funasr.models.e2e_asr_common import ErrorCalculator -# from funasr.models.encoder.abs_encoder import AbsEncoder -# from funasr.frontends.abs_frontend import AbsFrontend -# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder -from funasr.models.paraformer.cif_predictor import mae_loss -# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder -# from funasr.models.specaug.abs_specaug import AbsSpecAug -from funasr.models.transformer.utils.add_sos_eos import add_sos_eos -from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list -from funasr.metrics.compute_acc import th_accuracy + +from funasr.register import tables +from funasr.utils import postprocess_utils +from funasr.utils.datadir_writer import DatadirWriter from funasr.train_utils.device_funcs import force_gatherable -# from funasr.models.base_model import FunASRModel -# from funasr.models.paraformer.cif_predictor import CifPredictorV3 -from funasr.models.paraformer.search import Hypothesis +from funasr.models.transformer.scorers.ctc import CTCPrefixScorer +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.models.transformer.scorers.length_bonus import LengthBonus +from funasr.models.transformer.utils.nets_utils import get_transducer_task_io +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer -from funasr.models.model_class_factory import * if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): from torch.cuda.amp import autocast @@ -45,16 +29,10 @@ @contextmanager def autocast(enabled=True): yield -from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -from funasr.utils import postprocess_utils -from funasr.utils.datadir_writer import DatadirWriter -from funasr.models.transformer.utils.nets_utils import get_transducer_task_io - -class Transducer(nn.Module): - """ESPnet2ASRTransducerModel module definition.""" - +@tables.register("model_classes", "Transducer") +class Transducer(torch.nn.Module): def __init__( self, frontend: Optional[str] = None, @@ -96,28 +74,24 @@ def __init__( super().__init__() - if frontend is not None: - frontend_class = frontend_classes.get_class(frontend) - frontend = frontend_class(**frontend_conf) if specaug is not None: - specaug_class = specaug_classes.get_class(specaug) + specaug_class = tables.specaug_classes.get(specaug) specaug = specaug_class(**specaug_conf) if normalize is not None: - normalize_class = normalize_classes.get_class(normalize) + normalize_class = tables.normalize_classes.get(normalize) normalize = normalize_class(**normalize_conf) - encoder_class = encoder_classes.get_class(encoder) + encoder_class = tables.encoder_classes.get(encoder) encoder = encoder_class(input_size=input_size, **encoder_conf) encoder_output_size = encoder.output_size() - decoder_class = decoder_classes.get_class(decoder) + decoder_class = tables.decoder_classes.get(decoder) decoder = decoder_class( vocab_size=vocab_size, - encoder_output_size=encoder_output_size, **decoder_conf, ) decoder_output_size = decoder.output_size - joint_network_class = joint_network_classes.get_class(decoder) + joint_network_class = tables.joint_network_classes.get(joint_network) joint_network = joint_network_class( vocab_size, encoder_output_size, @@ -125,7 +99,6 @@ def __init__( **joint_network_conf, ) - self.criterion_transducer = None self.error_calculator = None @@ -157,23 +130,17 @@ def __init__( self.decoder = decoder self.joint_network = joint_network - - self.criterion_att = LabelSmoothingLoss( size=vocab_size, padding_idx=ignore_id, smoothing=lsm_weight, normalize_length=length_normalized_loss, ) - # - # if report_cer or report_wer: - # self.error_calculator = ErrorCalculator( - # token_list, sym_space, sym_blank, report_cer, report_wer - # ) - # self.length_normalized_loss = length_normalized_loss self.beam_search = None + self.ctc = None + self.ctc_weight = 0.0 def forward( self, @@ -190,8 +157,6 @@ def forward( text: (Batch, Length) text_lengths: (Batch,) """ - # import pdb; - # pdb.set_trace() if len(text_lengths.size()) > 1: text_lengths = text_lengths[:, 0] if len(speech_lengths.size()) > 1: @@ -283,12 +248,7 @@ def encode( # Forward encoder # feats: (Batch, Length, Dim) # -> encoder_out: (Batch, Length2, Dim2) - if self.encoder.interctc_use_conditioning: - encoder_out, encoder_out_lens, _ = self.encoder( - speech, speech_lengths, ctc=self.ctc - ) - else: - encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) + encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) intermediate_outs = None if isinstance(encoder_out, tuple): intermediate_outs = encoder_out[1] @@ -449,9 +409,6 @@ def _calc_lm_loss( def init_beam_search(self, **kwargs, ): - from funasr.models.transformer.search import BeamSearch - from funasr.models.transformer.scorers.ctc import CTCPrefixScorer - from funasr.models.transformer.scorers.length_bonus import LengthBonus # 1. Build ASR model scorers = {} @@ -466,28 +423,16 @@ def init_beam_search(self, length_bonus=LengthBonus(len(token_list)), ) - # 3. Build ngram model # ngram is not supported now ngram = None scorers["ngram"] = ngram - weights = dict( - decoder=1.0 - kwargs.get("decoding_ctc_weight"), - ctc=kwargs.get("decoding_ctc_weight", 0.0), - lm=kwargs.get("lm_weight", 0.0), - ngram=kwargs.get("ngram_weight", 0.0), - length_bonus=kwargs.get("penalty", 0.0), - ) - beam_search = BeamSearch( - beam_size=kwargs.get("beam_size", 2), - weights=weights, - scorers=scorers, - sos=self.sos, - eos=self.eos, - vocab_size=len(token_list), - token_list=token_list, - pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", + beam_search = BeamSearchTransducer( + self.decoder, + self.joint_network, + kwargs.get("beam_size", 2), + nbest=1, ) # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() # for scorer in scorers.values(): @@ -495,13 +440,13 @@ def init_beam_search(self, # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() self.beam_search = beam_search - def generate(self, - data_in: list, - data_lengths: list=None, - key: list=None, - tokenizer=None, - **kwargs, - ): + def inference(self, + data_in: list, + data_lengths: list=None, + key: list=None, + tokenizer=None, + **kwargs, + ): if kwargs.get("batch_size", 1) > 1: raise NotImplementedError("batch decoding is not implemented") @@ -509,10 +454,10 @@ def generate(self, # init beamsearch is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None - if self.beam_search is None and (is_use_lm or is_use_ctc): - logging.info("enable beam_search") - self.init_beam_search(**kwargs) - self.nbest = kwargs.get("nbest", 1) + # if self.beam_search is None and (is_use_lm or is_use_ctc): + logging.info("enable beam_search") + self.init_beam_search(**kwargs) + self.nbest = kwargs.get("nbest", 1) meta_data = {} # extract fbank feats @@ -534,13 +479,9 @@ def generate(self, encoder_out = encoder_out[0] # c. Passed the encoder result and the beam search - nbest_hyps = self.beam_search( - x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) - ) - + nbest_hyps = self.beam_search(encoder_out[0], is_final=True) nbest_hyps = nbest_hyps[: self.nbest] - results = [] b, n, d = encoder_out.size() for i in range(b): @@ -553,9 +494,9 @@ def generate(self, # remove sos/eos and get results last_pos = -1 if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] + token_int = hyp.yseq#[1:last_pos] else: - token_int = hyp.yseq[1:last_pos].tolist() + token_int = hyp.yseq#[1:last_pos].tolist() # remove blank symbol id, which is assumed to be 0 token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) diff --git a/funasr/models/transducer/rnn_decoder.py b/funasr/models/transducer/rnn_decoder.py index 204f0b1d7..b999d9c2d 100644 --- a/funasr/models/transducer/rnn_decoder.py +++ b/funasr/models/transducer/rnn_decoder.py @@ -1,10 +1,15 @@ -import random +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) -import numpy as np import torch +import random +import numpy as np import torch.nn as nn import torch.nn.functional as F +from funasr.register import tables from funasr.models.transformer.utils.nets_utils import make_pad_mask from funasr.models.transformer.utils.nets_utils import to_device from funasr.models.language_model.rnn.attentions import initial_att @@ -78,7 +83,7 @@ def build_attention_list( ) return att_list - +@tables.register("decoder_classes", "rnn_decoder") class RNNDecoder(nn.Module): def __init__( self, diff --git a/funasr/models/transducer/rnn_encoder.py b/funasr/models/transducer/rnn_encoder.py deleted file mode 100644 index 95fb4a589..000000000 --- a/funasr/models/transducer/rnn_encoder.py +++ /dev/null @@ -1,112 +0,0 @@ - -from typing import Optional -from typing import Sequence -from typing import Tuple - -import numpy as np -import torch - -from funasr.models.transformer.utils.nets_utils import make_pad_mask -from funasr.models.language_model.rnn.encoders import RNN -from funasr.models.language_model.rnn.encoders import RNNP -from funasr.models.encoder.abs_encoder import AbsEncoder - - -class RNNEncoder(AbsEncoder): - """RNNEncoder class. - Args: - input_size: The number of expected features in the input - output_size: The number of output features - hidden_size: The number of hidden features - bidirectional: If ``True`` becomes a bidirectional LSTM - use_projection: Use projection layer or not - num_layers: Number of recurrent layers - dropout: dropout probability - """ - - def __init__( - self, - input_size: int, - rnn_type: str = "lstm", - bidirectional: bool = True, - use_projection: bool = True, - num_layers: int = 4, - hidden_size: int = 320, - output_size: int = 320, - dropout: float = 0.0, - subsample: Optional[Sequence[int]] = (2, 2, 1, 1), - ): - super().__init__() - self._output_size = output_size - self.rnn_type = rnn_type - self.bidirectional = bidirectional - self.use_projection = use_projection - - if rnn_type not in {"lstm", "gru"}: - raise ValueError(f"Not supported rnn_type={rnn_type}") - - if subsample is None: - subsample = np.ones(num_layers + 1, dtype=np.int32) - else: - subsample = subsample[:num_layers] - # Append 1 at the beginning because the second or later is used - subsample = np.pad( - np.array(subsample, dtype=np.int32), - [1, num_layers - len(subsample)], - mode="constant", - constant_values=1, - ) - - rnn_type = ("b" if bidirectional else "") + rnn_type - if use_projection: - self.enc = torch.nn.ModuleList( - [ - RNNP( - input_size, - num_layers, - hidden_size, - output_size, - subsample, - dropout, - typ=rnn_type, - ) - ] - ) - - else: - self.enc = torch.nn.ModuleList( - [ - RNN( - input_size, - num_layers, - hidden_size, - output_size, - dropout, - typ=rnn_type, - ) - ] - ) - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs_pad: torch.Tensor, - ilens: torch.Tensor, - prev_states: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if prev_states is None: - prev_states = [None] * len(self.enc) - assert len(prev_states) == len(self.enc) - - current_states = [] - for module, prev_state in zip(self.enc, prev_states): - xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) - current_states.append(states) - - if self.use_projection: - xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0) - else: - xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0) - return xs_pad, ilens, current_states diff --git a/funasr/models/transducer/rnnt_decoder.py b/funasr/models/transducer/rnnt_decoder.py index 6d35b71e4..26ca1f2e2 100644 --- a/funasr/models/transducer/rnnt_decoder.py +++ b/funasr/models/transducer/rnnt_decoder.py @@ -1,12 +1,17 @@ -"""RNN decoder definition for Transducer models.""" - -from typing import List, Optional, Tuple +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) import torch +from typing import List, Optional, Tuple -from funasr.models.transducer.beam_search_transducer import Hypothesis +from funasr.register import tables from funasr.models.specaug.specaug import SpecAug +from funasr.models.transducer.beam_search_transducer import Hypothesis + +@tables.register("decoder_classes", "rnnt_decoder") class RNNTDecoder(torch.nn.Module): """RNN decoder module. diff --git a/funasr/models/transformer/attention.py b/funasr/models/transformer/attention.py index 32e1e478d..f09d6420d 100644 --- a/funasr/models/transformer/attention.py +++ b/funasr/models/transformer/attention.py @@ -312,8 +312,221 @@ def forward(self, query, key, value, pos_emb, mask): return self.forward_attention(v, scores, mask) +class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): + """RelPositionMultiHeadedAttention definition. + Args: + num_heads: Number of attention heads. + embed_size: Embedding size. + dropout_rate: Dropout rate. + """ + + def __init__( + self, + num_heads: int, + embed_size: int, + dropout_rate: float = 0.0, + simplified_attention_score: bool = False, + ) -> None: + """Construct an MultiHeadedAttention object.""" + super().__init__() + + self.d_k = embed_size // num_heads + self.num_heads = num_heads + + assert self.d_k * num_heads == embed_size, ( + "embed_size (%d) must be divisible by num_heads (%d)", + (embed_size, num_heads), + ) + + self.linear_q = torch.nn.Linear(embed_size, embed_size) + self.linear_k = torch.nn.Linear(embed_size, embed_size) + self.linear_v = torch.nn.Linear(embed_size, embed_size) + + self.linear_out = torch.nn.Linear(embed_size, embed_size) + + if simplified_attention_score: + self.linear_pos = torch.nn.Linear(embed_size, num_heads) + + self.compute_att_score = self.compute_simplified_attention_score + else: + self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) + + self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + self.compute_att_score = self.compute_attention_score + + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.attn = None + + def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: + """Compute relative positional encoding. + Args: + x: Input sequence. (B, H, T_1, 2 * T_1 - 1) + left_context: Number of frames in left context. + Returns: + x: Output sequence. (B, H, T_1, T_2) + """ + batch_size, n_heads, time1, n = x.shape + time2 = time1 + left_context + + batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() + + return x.as_strided( + (batch_size, n_heads, time1, time2), + (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), + storage_offset=(n_stride * (time1 - 1)), + ) + + def compute_simplified_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Simplified attention score computation. + Reference: https://github.com/k2-fsa/icefall/pull/458 + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + Returns: + : Attention score. (B, H, T_1, T_2) + """ + pos_enc = self.linear_pos(pos_enc) + + matrix_ac = torch.matmul(query, key.transpose(2, 3)) + matrix_bd = self.rel_shift( + pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), + left_context=left_context, + ) + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + def compute_attention_score( + self, + query: torch.Tensor, + key: torch.Tensor, + pos_enc: torch.Tensor, + left_context: int = 0, + ) -> torch.Tensor: + """Attention score computation. + Args: + query: Transformed query tensor. (B, H, T_1, d_k) + key: Transformed key tensor. (B, H, T_2, d_k) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + left_context: Number of frames in left context. + Returns: + : Attention score. (B, H, T_1, T_2) + """ + p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) + + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) + matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) + + return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) + + def forward_qkv( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transform query, key and value. + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + v: Value tensor. (B, T_2, size) + Returns: + q: Transformed query tensor. (B, H, T_1, d_k) + k: Transformed key tensor. (B, H, T_2, d_k) + v: Transformed value tensor. (B, H, T_2, d_k) + """ + n_batch = query.size(0) + + q = ( + self.linear_q(query) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + k = ( + self.linear_k(key) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + v = ( + self.linear_v(value) + .view(n_batch, -1, self.num_heads, self.d_k) + .transpose(1, 2) + ) + + return q, k, v + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute attention context vector. + Args: + value: Transformed value. (B, H, T_2, d_k) + scores: Attention score. (B, H, T_1, T_2) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + Returns: + attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k) + """ + batch_size = scores.size(0) + mask = mask.unsqueeze(1).unsqueeze(2) + if chunk_mask is not None: + mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask + scores = scores.masked_fill(mask, float("-inf")) + self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + + attn_output = self.dropout(self.attn) + attn_output = torch.matmul(attn_output, value) + + attn_output = self.linear_out( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, -1, self.num_heads * self.d_k) + ) + + return attn_output + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_enc: torch.Tensor, + mask: torch.Tensor, + chunk_mask: Optional[torch.Tensor] = None, + left_context: int = 0, + ) -> torch.Tensor: + """Compute scaled dot product attention with rel. positional encoding. + Args: + query: Query tensor. (B, T_1, size) + key: Key tensor. (B, T_2, size) + value: Value tensor. (B, T_2, size) + pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) + mask: Source mask. (B, T_2) + chunk_mask: Chunk mask. (T_1, T_1) + left_context: Number of frames in left context. + Returns: + : Output tensor. (B, T_1, H * d_k) + """ + q, k, v = self.forward_qkv(query, key, value) + scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) + return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)