Merge pull request #181 from X-LANCE/ygr_pr2
upload ctc_file and remove irrelevant code
ddlBoJack authored Nov 30, 2024
2 parents 8d2dc88 + 12539e4 commit 87e1449
Showing 3 changed files with 4 additions and 41 deletions.
1 change: 1 addition & 0 deletions examples/contextual_asr/README.md
@@ -26,6 +26,7 @@ They categorize the 5,000 most frequent words in the Librispeech training corpus
words, with the remainder classified as rare words. The biasing list generated for the test set consists of two segments: rare words in the transcriptions, and distractors sampled from the 209.2K rare words vocabulary. Biasing lists of varying lengths are generated by incorporating N = {100, 500, 1000, 2000} distractors into the lists.
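To make that construction concrete, here is a small editorial sketch (not code from this repo; `build_biasing_list`, the toy vocabulary, and the sampling details are illustrative assumptions):

```python
import random

def build_biasing_list(rare_in_transcript, rare_vocab, num_distractors, seed=0):
    """Combine an utterance's rare words with N distractors sampled from the rare-word vocabulary."""
    rng = random.Random(seed)
    pool = [w for w in rare_vocab if w not in set(rare_in_transcript)]
    return list(rare_in_transcript) + rng.sample(pool, num_distractors)

# Biasing lists of growing size, mirroring the N = {100, 500, 1000, 2000} setting above.
rare_vocab = ["SYNECDOCHE", "QUOTIDIAN", "OBSTREPEROUS", "PERSPICACIOUS"]  # toy stand-in for the 209.2K vocabulary
biasing_list = build_biasing_list(["SYNECDOCHE"], rare_vocab, num_distractors=2)
```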


The Viterbi decoding results of our CTC fine-tuned WavLM-Large: [test-clean](https://drive.google.com/file/d/1kMzPx8oRK3aOsxNaMGski3zH8z5Otvek/view?usp=drive_link), [test-other](https://drive.google.com/file/d/12KHaatVg5O0MIBTcf8e_rNjV_i9WLBFR/view?usp=drive_link) (``ctc_file`` in contextual_asr_config.py)
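If you download those files, the dataset config points at them through the ``ctc_file`` field shown in the diff below; a hedged excerpt (the local path is only a placeholder):

```python
# contextual_asr_config.py, DataConfig (excerpt) -- replace the path with wherever you saved the download
ctc_file: str = "/path/to/wavlm_large_libri_test_clean_char.txt"
```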

## Decoding with checkpoints
LLM-based ASR Inference script.
2 changes: 0 additions & 2 deletions examples/contextual_asr/contextual_asr_config.py
@@ -105,8 +105,6 @@ class DataConfig:
    infer_type: str = "bias"
    infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
    ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
    filter_type: str = "char"
    phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
    common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
    probability_threshold: float = 0.9
    word_num: int = 15
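A sketch of overriding the fields that survive this commit from Python; the import path, paths, and values are assumptions, and in practice these are usually set through the training/decoding scripts:

```python
from examples.contextual_asr.contextual_asr_config import DataConfig  # import path is an assumption

cfg = DataConfig(
    infer_type="filter",                                         # use the CTC-hypothesis filtering branch
    infer_file="/path/to/test-clean.biasing_100.tsv",            # biasing-list TSV (placeholder path)
    ctc_file="/path/to/wavlm_large_libri_test_clean_char.txt",   # Viterbi decode of the CTC model (placeholder path)
    probability_threshold=0.9,                                   # presumed score cutoff for candidate hotwords
    word_num=15,                                                 # presumed max number of hotwords kept for the prompt
)
```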
42 changes: 3 additions & 39 deletions examples/contextual_asr/dataset/hotwordsinfer_dataset.py
@@ -36,25 +36,6 @@ def find_candidate_names(sentence, ngram_index, n=2):
        candidates.update(ngram_index.get(ngram, []))
    return candidates

def build_ngram_index_phn(names, n=2):
    """Build an N-gram inverted index."""
    index = {}
    for name in names:
        phonemes = name.split()
        for i in range(len(phonemes) - n + 1):
            ngram = ' '.join(phonemes[i:i+n])
            index.setdefault(ngram, set()).add(name)
    return index

def find_candidate_names_phn(phonemes, ngram_index, n=2):
    """Find candidate names via the N-gram inverted index."""
    candidates = set()
    phonemes = phonemes.split()
    for i in range(len(phonemes) - n + 1):
        ngram = ' '.join(phonemes[i:i+n])
        candidates.update(ngram_index.get(ngram, []))
    return candidates

@lru_cache(maxsize=100000)
def similarity(name, sentence):
    return Levenshtein.ratio(name, sentence)
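For orientation, the character-level retrieval path the commit keeps (``build_ngram_index`` / ``find_candidate_names`` / ``similarity``) looks roughly like the self-contained sketch below. The two index helpers are reconstructed by analogy with the deleted phoneme variants above, so treat their bodies as approximations rather than the repo's exact code:

```python
from functools import lru_cache
import Levenshtein  # pip install python-Levenshtein

def build_ngram_index(names, n=2):
    """Map every character n-gram of a bias word to the words containing it (approximation)."""
    index = {}
    for name in names:
        for i in range(len(name) - n + 1):
            index.setdefault(name[i:i + n], set()).add(name)
    return index

def find_candidate_names(sentence, ngram_index, n=2):
    """Collect every bias word sharing at least one character n-gram with the hypothesis (approximation)."""
    candidates = set()
    for i in range(len(sentence) - n + 1):
        candidates.update(ngram_index.get(sentence[i:i + n], []))
    return candidates

@lru_cache(maxsize=100000)
def similarity(name, sentence):
    # Same scoring as the repo's similarity helper above.
    return Levenshtein.ratio(name, sentence)

# Toy example: a misspelled hypothesis still retrieves the right bias word.
index = build_ngram_index(["quilter", "hawthorne"])
cands = find_candidate_names("the quillter spoke", index)
scores = {c: similarity(c, "the quillter spoke") for c in cands}
```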
@@ -139,11 +120,6 @@ def __init__(
        # analyze
        self.hotwords_num=0
        self.miss_words_num=0
        self.filter_type=dataset_config.filter_type
        if self.filter_type=="phn":
            with open(dataset_config.phn_to_name_dict, 'r') as file:
                self.phn_to_name_dict = json.load(file)

        self.probability_threshold = dataset_config.get("probability_threshold", 0.95)
        self.word_num = dataset_config.get("word_num", 15)
        self.prompt_word_num = 0
@@ -202,22 +178,14 @@ def __getitem__(self, index):
            ocr = ocr.upper()
        elif self.infer_type=="filter":
            gt = eval(self.hotwords_list[index])
            if self.filter_type == "char":
                infer_sentence = self.infer_list[index].lower()
            else:
                infer_sentence = self.infer_list[index]

            infer_sentence = self.infer_list[index].lower()
            words_list = infer_sentence.split()
            filtered_words = [word for word in words_list if word not in self.common_words_5k]
            infer_sentence = ' '.join(filtered_words)

            biaswords=eval(self.biaswords_list[index])
            if self.filter_type=="char":
                ngram_index = build_ngram_index(biaswords)
                candidates = find_candidate_names(infer_sentence, ngram_index)
            elif self.filter_type=="phn":
                ngram_index = build_ngram_index_phn(biaswords)
                candidates = find_candidate_names_phn(infer_sentence, ngram_index)
            ngram_index = build_ngram_index(biaswords)
            candidates = find_candidate_names(infer_sentence, ngram_index)
            if not self.filter_infer_sentence_few:
                scores = score_candidates(candidates, infer_sentence)
                sorted_dict = sorted(scores.items(), key=lambda item: item[1], reverse=True)
@@ -246,10 +214,6 @@ def __getitem__(self, index):
logger.info("infer sentence: %s",infer_sentence)
logger.info("target sentence: %s", target)
logger.info("gt: %s, keys_list: %s", gt, keys_list)
# ===============================
if self.filter_type=="phn":
keys_list = [self.phn_to_name_dict[phn] for phn in keys_list]
keys_list = [item for sublist in keys_list for item in sublist]

ocr = " ".join(keys_list).upper()
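Putting the retained ``infer_type == "filter"`` branch together, the post-commit flow is roughly the sketch below, reusing the helpers from the previous sketch. ``score_candidates`` and the final selection step are collapsed in this diff, so the scoring and thresholding details here are assumptions:

```python
def select_hotwords(infer_sentence, biaswords, common_words_5k,
                    probability_threshold=0.9, word_num=15):
    """Approximate end-to-end flow of the char-level filter branch after this commit."""
    # 1. Lowercase the CTC/Viterbi hypothesis and drop the 5k most common words.
    words = [w for w in infer_sentence.lower().split() if w not in common_words_5k]
    sentence = ' '.join(words)
    # 2. Retrieve candidate bias words through the n-gram inverted index.
    ngram_index = build_ngram_index(biaswords)
    candidates = find_candidate_names(sentence, ngram_index)
    # 3. Score candidates against the hypothesis and keep the most plausible ones
    #    (scoring/thresholding details are an assumption based on the config fields).
    scores = {c: similarity(c, sentence) for c in candidates}
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    keys_list = [k for k, s in ranked[:word_num] if s >= probability_threshold]
    # 4. The selected hotwords are upper-cased and joined into the biasing prompt (ocr).
    return ' '.join(keys_list).upper()
```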

