support transducer model inference
R1ckShi committed Feb 20, 2024
1 parent 94de39d commit 151c339
Showing 9 changed files with 956 additions and 231 deletions.
666 changes: 666 additions & 0 deletions funasr/models/conformer/encoder.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion funasr/models/seaco_paraformer/model.py
@@ -335,7 +335,7 @@ def inference(self,

speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])

# hotword
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)

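For context, the `hotword` keyword threaded through `inference` above is what callers use to bias SeACo-Paraformer decoding. A minimal usage sketch, assuming FunASR's `AutoModel` front end; the model name and audio path below are placeholders, not part of this commit:

from funasr import AutoModel

# Sketch only: model id and wav path are illustrative placeholders.
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
res = model.generate(input="example.wav", hotword="魔搭")
print(res[0]["text"])
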
10 changes: 6 additions & 4 deletions funasr/models/transducer/beam_search_transducer.py
@@ -1,11 +1,13 @@
"""Search algorithms for Transducer models."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
import numpy as np
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from funasr.models.transducer.joint_network import JointNetwork


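This file's changes are header and import reordering only; the search itself is wired up in `model.py` below. For orientation, a sketch of how `BeamSearchTransducer` is built and invoked there, using only the arguments visible in this diff (the `model` and `enc` names are placeholders):

from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer

# Mirrors the call sites added in funasr/models/transducer/model.py below.
beam_search = BeamSearchTransducer(
    model.decoder,        # prediction (label) network
    model.joint_network,  # joint network combining encoder/decoder states
    2,                    # beam size; init_beam_search below defaults to 2
    nbest=1,
)
# enc: (T, D) encoder output for a single utterance
nbest_hyps = beam_search(enc, is_final=True)
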
7 changes: 6 additions & 1 deletion funasr/models/transducer/joint_network.py
@@ -1,10 +1,15 @@
"""Transducer joint network implementation."""
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch

from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import get_activation


@tables.register("joint_network_classes", "joint_network")
class JointNetwork(torch.nn.Module):
"""Transducer joint network module.
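The registered `JointNetwork` combines each encoder frame with each prediction-network state and projects the result to the vocabulary. A minimal, self-contained sketch of that standard transducer computation (layer sizes and the `tanh` choice are illustrative; the real module selects its activation via `get_activation`):

import torch

# Sketch of the standard transducer joint: project encoder and decoder
# states to a shared space, sum, apply a non-linearity, map to vocab.
class TinyJoint(torch.nn.Module):
    def __init__(self, vocab_size: int, enc_size: int, dec_size: int, joint_size: int = 320):
        super().__init__()
        self.lin_enc = torch.nn.Linear(enc_size, joint_size)
        self.lin_dec = torch.nn.Linear(dec_size, joint_size)
        self.lin_out = torch.nn.Linear(joint_size, vocab_size)

    def forward(self, enc_out: torch.Tensor, dec_out: torch.Tensor) -> torch.Tensor:
        # Broadcasting enc (..., T, 1, H) against dec (..., 1, U, H) yields
        # the (T, U, vocab) lattice used by the transducer loss and search.
        return self.lin_out(torch.tanh(self.lin_enc(enc_out) + self.lin_dec(dec_out)))
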
153 changes: 47 additions & 106 deletions funasr/models/transducer/model.py
@@ -1,42 +1,26 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import time
import torch
import logging
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy

from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.scorers.length_bonus import LengthBonus
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer

from funasr.models.model_class_factory import *

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -45,16 +29,10 @@
@contextmanager
def autocast(enabled=True):
yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io


class Transducer(nn.Module):
"""ESPnet2ASRTransducerModel module definition."""


@tables.register("model_classes", "Transducer")
class Transducer(torch.nn.Module):
def __init__(
self,
frontend: Optional[str] = None,
@@ -96,36 +74,31 @@ def __init__(

super().__init__()

if frontend is not None:
frontend_class = frontend_classes.get_class(frontend)
frontend = frontend_class(**frontend_conf)
if specaug is not None:
specaug_class = specaug_classes.get_class(specaug)
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = normalize_classes.get_class(normalize)
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
encoder_class = encoder_classes.get_class(encoder)
encoder_class = tables.encoder_classes.get(encoder)
encoder = encoder_class(input_size=input_size, **encoder_conf)
encoder_output_size = encoder.output_size()

decoder_class = decoder_classes.get_class(decoder)
decoder_class = tables.decoder_classes.get(decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**decoder_conf,
)
decoder_output_size = decoder.output_size

joint_network_class = joint_network_classes.get_class(decoder)
joint_network_class = tables.joint_network_classes.get(joint_network)
joint_network = joint_network_class(
vocab_size,
encoder_output_size,
decoder_output_size,
**joint_network_conf,
)


self.criterion_transducer = None
self.error_calculator = None
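
Note the pattern in the rewritten constructor above: the old `*_classes.get_class(...)` helpers from `model_class_factory` give way to the `tables` registry, where components self-register under a name and are resolved from config strings. An illustrative round trip (the `DemoEncoder` class is hypothetical, not part of this commit):

from funasr.register import tables

@tables.register("encoder_classes", "DemoEncoder")  # hypothetical component
class DemoEncoder:
    def __init__(self, input_size: int, **conf):
        self.input_size = input_size

# Resolution by name, exactly as the rewritten __init__ does:
encoder_class = tables.encoder_classes.get("DemoEncoder")
encoder = encoder_class(input_size=560)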

@@ -157,23 +130,17 @@ def __init__(
self.decoder = decoder
self.joint_network = joint_network



self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
#
# if report_cer or report_wer:
# self.error_calculator = ErrorCalculator(
# token_list, sym_space, sym_blank, report_cer, report_wer
# )
#

self.length_normalized_loss = length_normalized_loss
self.beam_search = None
self.ctc = None
self.ctc_weight = 0.0

def forward(
self,
@@ -190,8 +157,6 @@ def forward(
text: (Batch, Length)
text_lengths: (Batch,)
"""
# import pdb;
# pdb.set_trace()
if len(text_lengths.size()) > 1:
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
@@ -283,12 +248,7 @@ def encode(
# Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if self.encoder.interctc_use_conditioning:
encoder_out, encoder_out_lens, _ = self.encoder(
speech, speech_lengths, ctc=self.ctc
)
else:
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
intermediate_outs = None
if isinstance(encoder_out, tuple):
intermediate_outs = encoder_out[1]
@@ -449,9 +409,6 @@ def _calc_lm_loss(
def init_beam_search(self,
**kwargs,
):
from funasr.models.transformer.search import BeamSearch
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus

# 1. Build ASR model
scorers = {}
@@ -466,53 +423,41 @@ def init_beam_search(self,
length_bonus=LengthBonus(len(token_list)),
)


# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram

weights = dict(
decoder=1.0 - kwargs.get("decoding_ctc_weight"),
ctc=kwargs.get("decoding_ctc_weight", 0.0),
lm=kwargs.get("lm_weight", 0.0),
ngram=kwargs.get("ngram_weight", 0.0),
length_bonus=kwargs.get("penalty", 0.0),
)
beam_search = BeamSearch(
beam_size=kwargs.get("beam_size", 2),
weights=weights,
scorers=scorers,
sos=self.sos,
eos=self.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
beam_search = BeamSearchTransducer(
self.decoder,
self.joint_network,
kwargs.get("beam_size", 2),
nbest=1,
)
# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
# for scorer in scorers.values():
# if isinstance(scorer, torch.nn.Module):
# scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
self.beam_search = beam_search

def generate(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):
def inference(self,
data_in: list,
data_lengths: list=None,
key: list=None,
tokenizer=None,
**kwargs,
):

if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")

# init beamsearch
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
# if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)

meta_data = {}
# extract fbank feats
@@ -534,13 +479,9 @@ def generate(self,
encoder_out = encoder_out[0]

# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
)

nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
nbest_hyps = nbest_hyps[: self.nbest]


results = []
b, n, d = encoder_out.size()
for i in range(b):
@@ -553,9 +494,9 @@
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
token_int = hyp.yseq#[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
token_int = hyp.yseq#[1:last_pos].tolist()

# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
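The hunk above now keeps the full `yseq` (transducer hypotheses carry no sos/eos prefix to strip) and relies on the blank/sos/eos filter alone. A hedged sketch of the steps that typically follow in this loop, assuming the tokenizer interface used elsewhere in FunASR; `ids2tokens`/`tokens2text` are an assumption here, not shown in this diff:

# Assumed continuation of the loop above; helper names follow FunASR's
# tokenizer interface and are not part of this hunk.
token = tokenizer.ids2tokens(token_int)   # ids -> subword tokens
text = tokenizer.tokens2text(token)       # tokens -> transcript string
results.append({"key": key[i], "text": text})
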
11 changes: 8 additions & 3 deletions funasr/models/transducer/rnn_decoder.py
@@ -1,10 +1,15 @@
import random
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import numpy as np
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from funasr.register import tables
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import initial_att
@@ -78,7 +83,7 @@ def build_attention_list(
)
return att_list


@tables.register("decoder_classes", "rnn_decoder")
class RNNDecoder(nn.Module):
def __init__(
self,