Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Funasr1.0 #1282

Merged
merged 4 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ samples
.ipynb_checkpoints
outputs*
emotion2vec*
GPT-SoVITS*
19 changes: 10 additions & 9 deletions funasr/auto/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig

Expand Down Expand Up @@ -96,25 +97,25 @@ def __init__(self, **kwargs):
vad_kwargs = kwargs.get("vad_model_revision", None)
if vad_model is not None:
logging.info("Building VAD model.")
vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs, "device": kwargs["device"]}
vad_model, vad_kwargs = self.build_model(**vad_kwargs)

# if punc_model is not None, build punc model else None
punc_model = kwargs.get("punc_model", None)
punc_kwargs = kwargs.get("punc_model_revision", None)
if punc_model is not None:
logging.info("Building punc model.")
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs, "device": kwargs["device"]}
punc_model, punc_kwargs = self.build_model(**punc_kwargs)

# if spk_model is not None, build spk model else None
spk_model = kwargs.get("spk_model", None)
spk_kwargs = kwargs.get("spk_model_revision", None)
if spk_model is not None:
logging.info("Building SPK model.")
spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs, "device": kwargs["device"]}
spk_model, spk_kwargs = self.build_model(**spk_kwargs)
self.cb_model = ClusterBackend()
self.cb_model = ClusterBackend().to(kwargs["device"])
spk_mode = kwargs.get("spk_mode", 'punc_segment')
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
Expand Down Expand Up @@ -334,7 +335,7 @@ def inference_with_vad(self, input, input_len=None, **cfg):
for _b in range(len(speech_j)):
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
speech_j[_b]]]
np.array(speech_j[_b])]]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
speech_b = [i[2] for i in segments]
Expand Down Expand Up @@ -376,7 +377,7 @@ def inference_with_vad(self, input, input_len=None, **cfg):
result[k] = restored_data[j][k]
else:
result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
elif k == 'text':
elif k == 'raw_text':
if k not in result:
result[k] = restored_data[j][k]
else:
Expand All @@ -397,20 +398,20 @@ def inference_with_vad(self, input, input_len=None, **cfg):
if self.spk_model is not None:
all_segments = sorted(all_segments, key=lambda x: x[0])
spk_embedding = result['spk_embedding']
labels = self.cb_model(spk_embedding, oracle_num=self.preset_spk_num)
labels = self.cb_model(spk_embedding.cpu(), oracle_num=self.preset_spk_num)
del result['spk_embedding']
sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
if self.spk_mode == 'vad_segment':
sentence_list = []
for res, vadsegment in zip(restored_data, vadsegments):
sentence_list.append({"start": vadsegment[0],\
"end": vadsegment[1],
"sentence": res['text'],
"sentence": res['raw_text'],
"timestamp": res['timestamp']})
else: # punc_segment
sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
result['timestamp'], \
result['text'])
result['raw_text'])
distribute_spk(sentence_list, sv_output)
result['sentence_info'] = sentence_list

Expand Down
6 changes: 3 additions & 3 deletions funasr/models/paraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,9 +528,9 @@ def inference(self,
if tokenizer is not None:
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text = tokenizer.tokens2text(token)

text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
text_postprocessed = tokenizer.tokens2text(token)
if not hasattr(tokenizer, "bpemodel"):
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)

result_i = {"key": key[i], "text": text_postprocessed}

Expand Down
4 changes: 2 additions & 2 deletions funasr/models/seaco_paraformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,12 @@ def inference(self,
token, timestamp)

result_i = {"key": key[i], "text": text_postprocessed,
"timestamp": time_stamp_postprocessed,
"timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed)
}

if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
# ibest_writer["text"][key[i]] = text
# ibest_writer["raw_text"][key[i]] = text
ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
ibest_writer["text"][key[i]] = text_postprocessed
else:
Expand Down
28 changes: 20 additions & 8 deletions funasr/tokenizer/sentencepiece_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@

import sentencepiece as spm

from funasr.tokenizer.abs_tokenizer import AbsTokenizer


class SentencepiecesTokenizer(AbsTokenizer):
def __init__(self, model: Union[Path, str]):
self.model = str(model)
from funasr.tokenizer.abs_tokenizer import BaseTokenizer
from funasr.register import tables

@tables.register("tokenizer_classes", "SentencepiecesTokenizer")
class SentencepiecesTokenizer(BaseTokenizer):
def __init__(self, bpemodel: Union[Path, str],
**kwargs
):
super().__init__(**kwargs)
self.bpemodel = str(bpemodel)
# NOTE(kamo):
# Don't build SentencePieceProcessor in __init__()
# because it's not picklable and it may cause following error,
Expand All @@ -19,13 +23,13 @@ def __init__(self, model: Union[Path, str]):
self.sp = None

def __repr__(self):
return f'{self.__class__.__name__}(model="{self.model}")'
return f'{self.__class__.__name__}(model="{self.bpemodel}")'

def _build_sentence_piece_processor(self):
    """Lazily construct the SentencePieceProcessor on first use.

    Construction is deferred out of ``__init__`` because the processor
    object is not picklable (see the NOTE in ``__init__``).
    """
    if self.sp is not None:
        return
    processor = spm.SentencePieceProcessor()
    processor.load(self.bpemodel)
    self.sp = processor

def text2tokens(self, line: str) -> List[str]:
self._build_sentence_piece_processor()
Expand All @@ -34,3 +38,11 @@ def text2tokens(self, line: str) -> List[str]:
def tokens2text(self, tokens: Iterable[str]) -> str:
    """Join SentencePiece tokens back into a plain-text string."""
    self._build_sentence_piece_processor()
    pieces = list(tokens)
    return self.sp.DecodePieces(pieces)

def encode(self, line: str) -> List[int]:
    """Encode raw text straight to a list of SentencePiece ids."""
    self._build_sentence_piece_processor()
    ids = self.sp.EncodeAsIds(line)
    return ids

def decode(self, line: List[int]):
    """Decode a list of SentencePiece ids back into text."""
    self._build_sentence_piece_processor()
    text = self.sp.DecodeIds(line)
    return text
Loading