diff --git a/examples/contextual_asr/README.md b/examples/contextual_asr/README.md
index 62173cd0..b1dabe83 100644
--- a/examples/contextual_asr/README.md
+++ b/examples/contextual_asr/README.md
@@ -26,6 +26,7 @@ They categorize the 5,000 most frequent words in the Librispeech training corpus
 words, with the remainder classified as rare words. The biasing list generated for the test set consists of two segments:
 rare words in the transcriptions, and distractors sampled from the 209.2K rare words vocabulary.
 Biasing lists of varying lengths are generated by incorporating N = {100, 500, 1000, 2000} distractors into the lists.
+The Viterbi decoding results of our CTC fine-tuned WavLM-Large are available for [test-clean](https://drive.google.com/file/d/1kMzPx8oRK3aOsxNaMGski3zH8z5Otvek/view?usp=drive_link) and [test-other](https://drive.google.com/file/d/12KHaatVg5O0MIBTcf8e_rNjV_i9WLBFR/view?usp=drive_link); set the downloaded file as ``ctc_file`` in contextual_asr_config.py.
 
 ## Decoding with checkpoints
 LLM-based ASR Inference script.
diff --git a/examples/contextual_asr/contextual_asr_config.py b/examples/contextual_asr/contextual_asr_config.py
index 6bfc9e4b..392ce8b4 100644
--- a/examples/contextual_asr/contextual_asr_config.py
+++ b/examples/contextual_asr/contextual_asr_config.py
@@ -105,8 +105,6 @@ class DataConfig:
     infer_type: str = "bias"
     infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
     ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
-    filter_type: str = "char"
-    phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
     common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
     probability_threshold: float = 0.9
     word_num: int = 15
diff --git a/examples/contextual_asr/dataset/hotwordsinfer_dataset.py b/examples/contextual_asr/dataset/hotwordsinfer_dataset.py
index b9a72f51..6932e58c 100644
--- a/examples/contextual_asr/dataset/hotwordsinfer_dataset.py
+++ b/examples/contextual_asr/dataset/hotwordsinfer_dataset.py
@@ -36,25 +36,6 @@ def find_candidate_names(sentence, ngram_index, n=2):
         candidates.update(ngram_index.get(ngram, []))
     return candidates
 
-def build_ngram_index_phn(names, n=2):
-    """构建N-Gram倒排索引"""
-    index = {}
-    for name in names:
-        phonemes = name.split()
-        for i in range(len(phonemes) - n + 1):
-            ngram = ' '.join(phonemes[i:i+n])
-            index.setdefault(ngram, set()).add(name)
-    return index
-
-def find_candidate_names_phn(phonemes, ngram_index, n=2):
-    """通过N-Gram倒排索引找到候选人名"""
-    candidates = set()
-    phonemes = phonemes.split()
-    for i in range(len(phonemes) - n + 1):
-        ngram = ' '.join(phonemes[i:i+n])
-        candidates.update(ngram_index.get(ngram, []))
-    return candidates
-
 @lru_cache(maxsize=100000)
 def similarity(name, sentence):
     return Levenshtein.ratio(name, sentence)
@@ -139,11 +120,6 @@ def __init__(
         # analyze
         self.hotwords_num=0
         self.miss_words_num=0
-        self.filter_type=dataset_config.filter_type
-        if self.filter_type=="phn":
-            with open(dataset_config.phn_to_name_dict, 'r') as file:
-                self.phn_to_name_dict = json.load(file)
-
         self.probability_threshold = dataset_config.get("probability_threshold", 0.95)
         self.word_num = dataset_config.get("word_num", 15)
         self.prompt_word_num = 0
@@ -202,22 +178,14 @@ def __getitem__(self, index):
             ocr = ocr.upper()
         elif self.infer_type=="filter":
             gt = eval(self.hotwords_list[index])
-            if self.filter_type == "char":
-                infer_sentence = self.infer_list[index].lower()
-            else:
-                infer_sentence = self.infer_list[index]
-
+            infer_sentence = self.infer_list[index].lower()
             words_list = infer_sentence.split()
             filtered_words = [word for word in words_list if word not in self.common_words_5k]
             infer_sentence = ' '.join(filtered_words)
 
             biaswords=eval(self.biaswords_list[index])
-            if self.filter_type=="char":
-                ngram_index = build_ngram_index(biaswords)
-                candidates = find_candidate_names(infer_sentence, ngram_index)
-            elif self.filter_type=="phn":
-                ngram_index = build_ngram_index_phn(biaswords)
-                candidates = find_candidate_names_phn(infer_sentence, ngram_index)
+            ngram_index = build_ngram_index(biaswords)
+            candidates = find_candidate_names(infer_sentence, ngram_index)
             if not self.filter_infer_sentence_few:
                 scores = score_candidates(candidates, infer_sentence)
                 sorted_dict = sorted(scores.items(), key=lambda item: item[1], reverse=True)
@@ -246,10 +214,6 @@ def __getitem__(self, index):
                 logger.info("infer sentence: %s",infer_sentence)
                 logger.info("target sentence: %s", target)
                 logger.info("gt: %s, keys_list: %s", gt, keys_list)
-                # ===============================
-                if self.filter_type=="phn":
-                    keys_list = [self.phn_to_name_dict[phn] for phn in keys_list]
-                    keys_list = [item for sublist in keys_list for item in sublist]
                 ocr = " ".join(keys_list).upper()
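As context for the README hunk above: the biasing lists it describes pair each utterance's rare words with N sampled distractors. A minimal sketch of that construction follows, assuming hypothetical names (`make_biasing_list`, `rare_vocab`, `utt_rare_words`) that are not part of this patch or of the fbai-speech data tools.

```python
import random

def make_biasing_list(utt_rare_words, rare_vocab, n_distractors, seed=0):
    """Hypothetical sketch: an utterance's rare words plus N distractors
    drawn from the rare-word vocabulary (N in {100, 500, 1000, 2000})."""
    rng = random.Random(seed)
    pool = [w for w in rare_vocab if w not in set(utt_rare_words)]
    distractors = rng.sample(pool, min(n_distractors, len(pool)))
    return list(utt_rare_words) + distractors
```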
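The net effect of the hotwordsinfer_dataset.py changes is that only the character-level "filter" path remains: the CTC hypothesis is lowercased, common words are dropped, candidate biasing words are retrieved through an n-gram inverted index, and candidates are ranked by Levenshtein similarity before being used as the hotword prompt. The sketch below illustrates that flow; the helper bodies (character bigrams, whole-sentence Levenshtein ratio) and the sample data are simplified assumptions, not the exact implementations kept in the dataset file.

```python
import Levenshtein  # pip install python-Levenshtein

def build_ngram_index(names, n=2):
    """Map each character n-gram to the biasing words that contain it."""
    index = {}
    for name in names:
        for i in range(len(name) - n + 1):
            index.setdefault(name[i:i + n], set()).add(name)
    return index

def find_candidate_names(sentence, ngram_index, n=2):
    """Collect biasing words sharing at least one character n-gram with the sentence."""
    candidates = set()
    for i in range(len(sentence) - n + 1):
        candidates.update(ngram_index.get(sentence[i:i + n], set()))
    return candidates

def score_candidates(candidates, sentence):
    """Rank candidates by Levenshtein ratio against the filtered sentence."""
    return {name: Levenshtein.ratio(name, sentence) for name in candidates}

# Made-up inputs mirroring the "filter" branch of __getitem__.
common_words_5k = {"he", "played", "the"}            # stand-in for common_words_5k.txt
biaswords = ["fortepiano", "czerny", "beethoven"]    # stand-in for a biasing list
infer_sentence = "HE PLAYED THE FORTEPIANO".lower()  # e.g. a Viterbi/CTC hypothesis
infer_sentence = ' '.join(w for w in infer_sentence.split() if w not in common_words_5k)

ngram_index = build_ngram_index(biaswords)
candidates = find_candidate_names(infer_sentence, ngram_index)
scores = score_candidates(candidates, infer_sentence)
keys_list = [w for w, _ in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)][:15]  # word_num
print(" ".join(keys_list).upper())                   # -> "FORTEPIANO"
```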