diff --git a/examples/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml b/examples/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml index 3a2231f4d..7d41c640e 100644 --- a/examples/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml +++ b/examples/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml @@ -99,7 +99,10 @@ dataset_conf: max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, buffer_size: 1024 shuffle: True - num_workers: 0 + num_workers: 4 + preprocessor_speech: SpeechPreprocessSpeedPerturb + preprocessor_speech_conf: + speed_perturb: [0.9, 1.0, 1.1] tokenizer: CharTokenizer tokenizer_conf: diff --git a/examples/aishell/paraformer/run.sh b/examples/aishell/paraformer/run.sh index fd51de2a8..994513218 100755 --- a/examples/aishell/paraformer/run.sh +++ b/examples/aishell/paraformer/run.sh @@ -1,13 +1,8 @@ #!/usr/bin/env bash -workspace=`pwd` -# machines configuration + CUDA_VISIBLE_DEVICES="0,1" -gpu_num=2 -gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding -# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob -njob=1 # general configuration feats_dir="../DATA" #feature output dictionary @@ -18,7 +13,11 @@ stage=0 stop_stage=5 # feature configuration -nj=64 +nj=32 + +inference_device="cuda" #"cpu" +inference_checkpoint="model.pt" +inference_scp="wav.scp" # data raw_data=../raw_data @@ -26,6 +25,7 @@ data_url=www.openslr.org/resources/33 # exp tag tag="exp1" +workspace=`pwd` . utils/parse_options.sh || exit 1; @@ -42,11 +42,6 @@ test_sets="dev test" config=train_asr_paraformer_conformer_12e_6d_2048_256.yaml model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}" -inference_device="cuda" #"cpu" -inference_checkpoint="model.pt" -inference_scp="wav.scp" - - if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then echo "stage -1: Data Download" @@ -112,6 +107,8 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then mkdir -p ${exp_dir}/exp/${model_dir} log_file="${exp_dir}/exp/${model_dir}/train.log.txt" echo "log_file: ${log_file}" + + gpu_num=$(echo CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') torchrun \ --nnodes 1 \ --nproc_per_node ${gpu_num} \ diff --git a/funasr/datasets/audio_datasets/preprocessor.py b/funasr/datasets/audio_datasets/preprocessor.py index 6c21fbf0e..c2e27bfc4 100644 --- a/funasr/datasets/audio_datasets/preprocessor.py +++ b/funasr/datasets/audio_datasets/preprocessor.py @@ -41,43 +41,9 @@ def __init__(self, seg_dict: str = None, **kwargs): super().__init__() - self.seg_dict = None - if seg_dict is not None: - self.seg_dict = {} - with open(seg_dict, "r", encoding="utf8") as f: - lines = f.readlines() - for line in lines: - s = line.strip().split() - key = s[0] - value = s[1:] - self.seg_dict[key] = " ".join(value) self.text_cleaner = TextCleaner(text_cleaner) - self.split_with_space = split_with_space def forward(self, text, **kwargs): - if self.seg_dict is not None: - text = self.text_cleaner(text) - if self.split_with_space: - tokens = text.strip().split(" ") - if self.seg_dict is not None: - text = seg_tokenize(tokens, self.seg_dict) - + text = self.text_cleaner(text) + return text - -def seg_tokenize(txt, seg_dict): - pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$') - out_txt = "" - for word in txt: - word = word.lower() - if word in seg_dict: - out_txt += seg_dict[word] + " " - else: - if pattern.match(word): - for char in word: - if char in seg_dict: - out_txt += seg_dict[char] + " " - else: - out_txt += "" + " " - else: - out_txt += "" + " " - return out_txt.strip().split() \ No newline at end of file diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index 71cf77a07..c6e03e86e 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -32,7 +32,6 @@ def load_cmvn(cmvn_file): rescale_line = line_item[3:(len(line_item) - 1)] vars_list = list(rescale_line) continue - import pdb;pdb.set_trace() means = np.array(means_list).astype(np.float32) vars = np.array(vars_list).astype(np.float32) cmvn = np.array([means, vars]) diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py index 0635fd70c..0f40b5e63 100644 --- a/funasr/tokenizer/char_tokenizer.py +++ b/funasr/tokenizer/char_tokenizer.py @@ -3,60 +3,105 @@ from typing import List from typing import Union import warnings +import re from funasr.tokenizer.abs_tokenizer import BaseTokenizer from funasr.register import tables @tables.register("tokenizer_classes", "CharTokenizer") class CharTokenizer(BaseTokenizer): - def __init__( - self, - non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, - space_symbol: str = "", - remove_non_linguistic_symbols: bool = False, - **kwargs, - ): - super().__init__(**kwargs) - self.space_symbol = space_symbol - if non_linguistic_symbols is None: - self.non_linguistic_symbols = set() - elif isinstance(non_linguistic_symbols, (Path, str)): - non_linguistic_symbols = Path(non_linguistic_symbols) - try: - with non_linguistic_symbols.open("r", encoding="utf-8") as f: - self.non_linguistic_symbols = set(line.rstrip() for line in f) - except FileNotFoundError: - warnings.warn(f"{non_linguistic_symbols} doesn't exist.") - self.non_linguistic_symbols = set() - else: - self.non_linguistic_symbols = set(non_linguistic_symbols) - self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + def __init__( + self, + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + space_symbol: str = "", + remove_non_linguistic_symbols: bool = False, + split_with_space: bool = False, + seg_dict: str = None, + **kwargs, + ): + super().__init__(**kwargs) + self.space_symbol = space_symbol + if non_linguistic_symbols is None: + self.non_linguistic_symbols = set() + elif isinstance(non_linguistic_symbols, (Path, str)): + non_linguistic_symbols = Path(non_linguistic_symbols) + try: + with non_linguistic_symbols.open("r", encoding="utf-8") as f: + self.non_linguistic_symbols = set(line.rstrip() for line in f) + except FileNotFoundError: + warnings.warn(f"{non_linguistic_symbols} doesn't exist.") + self.non_linguistic_symbols = set() + else: + self.non_linguistic_symbols = set(non_linguistic_symbols) + self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + self.split_with_space = split_with_space + self.seg_dict = None + if seg_dict is not None: + self.seg_dict = load_seg_dict(seg_dict) + + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f'space_symbol="{self.space_symbol}"' + f'non_linguistic_symbols="{self.non_linguistic_symbols}"' + f")" + ) + + def text2tokens(self, line: Union[str, list]) -> List[str]: + + if self.split_with_space: + tokens = line.strip().split(" ") + if self.seg_dict is not None: + tokens = seg_tokenize(tokens, self.seg_dict) + else: + tokens = [] + while len(line) != 0: + for w in self.non_linguistic_symbols: + if line.startswith(w): + if not self.remove_non_linguistic_symbols: + tokens.append(line[: len(w)]) + line = line[len(w) :] + break + else: + t = line[0] + if t == " ": + t = "" + tokens.append(t) + line = line[1:] + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + tokens = [t if t != self.space_symbol else " " for t in tokens] + return "".join(tokens) - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f'space_symbol="{self.space_symbol}"' - f'non_linguistic_symbols="{self.non_linguistic_symbols}"' - f")" - ) - def text2tokens(self, line: Union[str, list]) -> List[str]: - tokens = [] - while len(line) != 0: - for w in self.non_linguistic_symbols: - if line.startswith(w): - if not self.remove_non_linguistic_symbols: - tokens.append(line[: len(w)]) - line = line[len(w) :] - break - else: - t = line[0] - if t == " ": - t = "" - tokens.append(t) - line = line[1:] - return tokens +def load_seg_dict(seg_dict_file): + seg_dict = {} + assert isinstance(seg_dict_file, str) + with open(seg_dict_file, "r", encoding="utf8") as f: + lines = f.readlines() + for line in lines: + s = line.strip().split() + key = s[0] + value = s[1:] + seg_dict[key] = " ".join(value) + return seg_dict - def tokens2text(self, tokens: Iterable[str]) -> str: - tokens = [t if t != self.space_symbol else " " for t in tokens] - return "".join(tokens) \ No newline at end of file +def seg_tokenize(txt, seg_dict): + pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$') + out_txt = "" + for word in txt: + word = word.lower() + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + if pattern.match(word): + for char in word: + if char in seg_dict: + out_txt += seg_dict[char] + " " + else: + out_txt += "" + " " + else: + out_txt += "" + " " + return out_txt.strip().split() \ No newline at end of file