diff --git a/examples/contextual_asr/README.md b/examples/contextual_asr/README.md
new file mode 100644
index 00000000..62173cd0
--- /dev/null
+++ b/examples/contextual_asr/README.md
@@ -0,0 +1,61 @@
+# CTC-Assisted LLM-Based Contextual ASR
+
+## Guides
+
+[CTC-Assisted LLM-Based Contextual ASR](https://arxiv.org/abs/2411.06437) is an LLM-based contextual ASR model that first uses CTC decoding results to filter potentially relevant hotwords from a pre-defined hotword list, and then incorporates them into the LLM prompt to improve hotword recognition.
+
+## Model Architecture
+
+We use the WavLM-Large model, pre-trained on 94,000 hours of data and fine-tuned on 960 hours of LibriSpeech data with CTC loss, as our speech encoder. We use the public Vicuna 7B as our large language model decoder, and a simple linear projector, consisting of a 1-D convolution layer and two linear layers, as our adapter. Refer to our [paper](https://arxiv.org/pdf/2411.06437) for more details.
+
+![](docs/model.png)
+
+## Checkpoints
+We only train the linear projector in this recipe.
+Encoder | Projector | LLM
+|---|---|---|
+[CTC Fine-tuned WavLM-Large](https://drive.google.com/file/d/12ZmSSbDvx73W0eK1wpUgajapCLhqh5DI/view?usp=drive_link)(~315.45M) | [Linear](https://drive.google.com/file/d/1Zlbsnz1YUWtYtt-yNyoPK5OhR30kwLfS/view?usp=drive_link)(~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5)(~6.7B)
+
+## Performance
+![](docs/performance.png)
+
+
+## Data preparation
+The artificial biasing list constructed in [Contextualized streaming end-to-end speech recognition with trie-based deep biasing and shallow fusion](https://arxiv.org/pdf/2104.02194) is used for contextual ASR testing. Refer to the official [repo](https://github.com/facebookresearch/fbai-speech/tree/main/is21_deep_bias).
+They categorize the 5,000 most frequent words in the LibriSpeech training corpus as common
+words, with the remainder classified as rare words. The biasing list generated for the test set consists of two segments: rare words in the transcriptions, and distractors sampled from the 209.2K rare-word vocabulary. Biasing lists of varying lengths are generated by incorporating N = {100, 500, 1000, 2000} distractors into the lists.
+
+
+
+## Decoding with checkpoints
+LLM-based ASR inference script:
+```
+bash decode_wavlm_libri960_ft_char.sh
+```
+LLM-based contextual ASR inference script, with different biasing list sizes:
+```
+bash decode_wavlm_libri960_ft_char_hotwords.sh
+```
+
+
+## Training the model
+LLM-based ASR training script, using the CTC fine-tuned WavLM as the encoder and "Transcribe speech to text." as the prompt:
+```
+bash finetune_wavlm_libri960_ft_char.sh
+```
+LLM-based contextual ASR training script, using the CTC fine-tuned WavLM as the encoder and "Transcribe speech to text. Some hotwords might help. The hotwords are {}." as the prompt:
+```
+bash finetune_wavlm_libri960_ft_char_hotwords.sh
+```
+
+
+## Citation
+You can refer to the paper for more results.
+```
+@article{yang2024ctc,
+  title={CTC-Assisted LLM-Based Contextual ASR},
+  author={Yang, Guanrou and Ma, Ziyang and Gao, Zhifu and Zhang, Shiliang and Chen, Xie},
+  journal={Proc.
SLT}, + year={2024} +} +``` \ No newline at end of file diff --git a/examples/contextual_asr/conf/ds_config.json b/examples/contextual_asr/conf/ds_config.json new file mode 100644 index 00000000..7ea70e4a --- /dev/null +++ b/examples/contextual_asr/conf/ds_config.json @@ -0,0 +1,19 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu" + } + } +} \ No newline at end of file diff --git a/examples/contextual_asr/conf/prompt.yaml b/examples/contextual_asr/conf/prompt.yaml new file mode 100644 index 00000000..0bc65175 --- /dev/null +++ b/examples/contextual_asr/conf/prompt.yaml @@ -0,0 +1,4 @@ +dataset_config: + # we put prompt here, because the hydra override in shell script only support a small subset of chars + # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " + prompt: "Transcribe speech to text. " diff --git a/examples/contextual_asr/contextual_asr_config.py b/examples/contextual_asr/contextual_asr_config.py new file mode 100644 index 00000000..6bfc9e4b --- /dev/null +++ b/examples/contextual_asr/contextual_asr_config.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass, field +from typing import Optional, List +@dataclass +class ModelConfig: + file: str = "examples/contextual_asr/model/slam_model_contextual_asr.py:model_factory" + llm_name: str = "vicuna-13b-v1.5" + llm_path: str = "PATH/to/LLAMA/7B" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + encoder_name: Optional[str] = None + encoder_ds_rate: int = 2 + encoder_path: Optional[str] = None + encoder_dim: int = 1280 + encoder_projector: str = "linear" + encoder_projector_ds_rate: int = 5 + modal: str = "audio" + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether input is normalized, used for models such as wavlm" + }) + encoder_type: str = field(default="finetune", metadata={ + "help": "whether model is only pretrained or finetuned, used for models such as hubert" + }) + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + # target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) + target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj","k_proj","o_proj" ]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + +@dataclass +class TrainConfig: + model_name:str = "PATH/to/LLAMA/7B" + enable_ddp:bool = False + enable_deepspeed:bool = False + enable_fsdp:bool = False + low_cpu_fsdp:bool = False + run_validation:bool = True + batch_size_training:int = 4 + batching_strategy:str = field(default="packing", metadata={ + "help":"alternative: padding" + }) + context_length:int = 4096 + gradient_accumulation_steps:int = 1 + num_epochs:int = 3 + num_workers_dataloader:int = 1 + warmup_steps:int = 1000 + total_steps:int = 100000 + validation_interval:int = 1000 + lr:float = 1e-4 + weight_decay:float = 0.0 + gamma:float = 0.85 + seed:int = 42 + use_fp16:bool = False + mixed_precision:bool = True + val_batch_size:int = 1 + use_peft:bool = False + peft_config:PeftConfig = field(default_factory=PeftConfig) + output_dir:str = "PATH/to/save/PEFT/model" + freeze_layers:bool = False + num_freeze_layers:int = 1 + quantization:bool = False + 
one_gpu:bool = False
+    save_model:bool = True
+    dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP
+    dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP
+    save_optimizer:bool = False # will be used if using FSDP
+    use_fast_kernels:bool = False # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
+    run_test_during_validation:bool = False
+    run_test_during_validation_file:str = "test.wav"
+    run_test_during_validation_prompt:str = "<|ASR|>"
+    freeze_llm:bool = field(default=False, metadata={
+        "help": "whether to freeze the LLM when finetuning; should be true when using PEFT finetuning"
+    })
+    freeze_encoder:bool = False
+
+@dataclass
+class DataConfig:
+    dataset: str = "speech_dataset"
+    file: str = "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
+    train_data_path: Optional[str] = None
+    val_data_path: Optional[str] = None
+    train_split: str = "train"
+    test_split:str = "validation"
+    prompt: Optional[str] = None
+    data_path: Optional[str] = None
+    max_words: Optional[int] = None
+    max_mel: Optional[float] = None
+    fix_length_audio: int = -1
+    inference_mode:bool = False
+    input_type: str = field(default="raw", metadata={
+        "help":"Use raw when the input is wav, mel for whisper"
+    })
+    mel_size: int = field(default=80, metadata={
+        "help": "80 for whisper large v1 and v2, 128 for v3"
+    })
+    normalize: Optional[bool] = field(default=False, metadata={
+        "help": "whether input is normalized, used for models such as wavlm"
+    })
+    infer_type: str = "bias"
+    infer_file: str = "/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/test-clean.biasing_100.tsv"
+    ctc_file: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_large_libri_test_other_char.txt"
+    filter_type: str = "char"
+    phn_to_name_dict: str = "/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_phn.json"
+    common_words_5k_dir: str="/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/words/common_words_5k.txt"
+    probability_threshold: float = 0.9
+    word_num: int = 15
+    filter_infer_sentence: bool = False
+    filter_infer_sentence_few: bool = False
+    first: int = 1
+
+@dataclass
+class FSDPConfig:
+    mixed_precision: bool = True
+    use_fp16: bool = False
+    # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD
+    sharding_strategy: str = "NO_SHARD" # ShardingStrategy.NO_SHARD # MZY: set NO_SHARD when using DDP
+    checkpoint_type: str = "SHARDED_STATE_DICT" # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size.
+ fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/root/test.log" + log_interval: int = 5 diff --git a/examples/contextual_asr/dataset/hotwords_dataset.py b/examples/contextual_asr/dataset/hotwords_dataset.py new file mode 100644 index 00000000..20660bb8 --- /dev/null +++ b/examples/contextual_asr/dataset/hotwords_dataset.py @@ -0,0 +1,230 @@ +import os.path as osp +import random +import json, yaml +import copy +import numpy as np +from scipy import signal +import soundfile as sf +import torch +import torchaudio +from torch.utils.data import Dataset +import whisper +from slam_llm.utils.compute_utils import calculate_output_length_1d + + +class HotwordsDataset(torch.utils.data.Dataset): + def __init__( + self, + dataset_config, + tokenizer=None, + split='train', + ): + super().__init__() + self.dataset_config = dataset_config + self.tokenizer = tokenizer + data_parallel_size = 1 + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt = dataset_config.get("prompt", None) + self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3 + self.prompt_template = "USER: {}\n ASSISTANT:" + self.answer_template = "{}" + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + self.inference_mode = dataset_config.get("inference_mode", False) + self.normalize = dataset_config.get("normalize", False) + self.input_type = dataset_config.get("input_type", None) + assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]" + self.Pkeep = dataset_config.get("Pkeep", 0.5) + self.Norder = dataset_config.get("Norder", 4) + + self.data_list = [] + if split == "train": + with open(dataset_config.train_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + else: + with open(dataset_config.val_data_path, encoding='utf-8') as fin: + for line in fin: + data_dict = json.loads(line.strip()) + self.data_list.append(data_dict) + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + return data_dict["target_len"] if "target_len" in data_dict else 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + data_dict = self.data_list[index] + audio_path = data_dict.get("source") + target = data_dict.get("target", None) + task = data_dict.get("prompt", "ASR") + key = data_dict.get("key", None) + + audio_raw = whisper.load_audio(audio_path) + if self.input_type == "raw": + audio_raw = torch.from_numpy(audio_raw) + if self.normalize: + audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape) + audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + elif self.input_type == "mel": + audio_raw = whisper.pad_or_trim(audio_raw) + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + 
audio_pseudo = torch.full((audio_length,), -1) # placeholder + + if self.inference_mode: + return { + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_length": audio_length, + "key": key, + "target": target, + } + else: + return { + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_length": audio_length, + "target":target, + } + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, np.ndarray): + if len(sequence) < max_length: + sequence = np.concatenate( + (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + + if self.dataset_config.infer_type=="nobias": + selected_ngrams="" + else: + selected_ngrams_list = [] + for s in samples: + label = s['target'] + if random.random() < self.Pkeep: + words = label.split() + n = min(random.randint(1, self.Norder),len(words)) + if len(words) >= n: + start_index = random.randint(0,len(words)-n) + selected_ngrams = words[start_index:start_index + n] + selected_ngrams_list.append(" ".join(selected_ngrams)) + selected_ngrams = " ".join(selected_ngrams_list) + + prompt = "Transcribe speech to text. Some hotwords might help. The hotwords are \"{}\". " + prompt = prompt.format(selected_ngrams) + prompt = self.prompt_template.format(prompt) #'USER: Transcribe speech to text. Some hotwords might help. The hotwords are "ONLY OR FOUND DEMANDS YOU THREE RESPONSIVE AND COVER\'D BY TO". \n ASSISTANT:' + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + + if self.inference_mode: + for i in range(len(samples)): + audio_pseudo = torch.full((samples[i]["audio_length"],), -1) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + samples[i]["input_ids"] = example_ids + samples[i]["attention_mask"] = example_mask + else: + for i in range(len(samples)): + audio_length = samples[i]["audio_length"] + audio_pseudo = torch.full((audio_length,), -1) + answer = self.answer_template.format(samples[i]["target"]) + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 
+ example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + + samples[i]["input_ids"] = example_ids + samples[i]["labels"] = labels_ids + samples[i]["attention_mask"] = example_mask + + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) for s in samples]) + + if self.input_type == "raw": + audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) + audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) for s in samples]) + audio_mask = torch.zeros(len(samples), audio_raw_max_length) + for line, sample in enumerate(samples): + audio_mask[line, :sample['audio'].shape[0]] = 1 + elif self.input_type == "mel": + audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) + audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) for s in samples]) + audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats + for line, sample in enumerate(samples): + audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 + + modality_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask, + "keys": keys, + "targets": targets + } + else: + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) for s in samples]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask + } + +def get_speech_dataset(dataset_config, tokenizer, split): + dataset = HotwordsDataset(dataset_config, tokenizer, split) + return dataset + diff --git a/examples/contextual_asr/dataset/hotwordsinfer_dataset.py b/examples/contextual_asr/dataset/hotwordsinfer_dataset.py new file mode 100644 index 00000000..b9a72f51 --- /dev/null +++ b/examples/contextual_asr/dataset/hotwordsinfer_dataset.py @@ -0,0 +1,394 @@ +import os.path as 
osp
+import random
+import json, yaml
+import copy
+import numpy as np
+from scipy import signal
+import soundfile as sf
+import difflib
+from functools import lru_cache
+from tqdm import tqdm
+import Levenshtein
+import torch
+import torchaudio
+from torch.utils.data import Dataset
+import whisper
+from slam_llm.utils.compute_utils import calculate_output_length_1d
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+def build_ngram_index(names, n=2):
+    """Build an n-gram inverted index."""
+    index = {}
+    for name in names:
+        for i in range(len(name) - n + 1):
+            ngram = name[i:i+n].lower()
+            index.setdefault(ngram, set()).add(name)
+    return index
+
+def find_candidate_names(sentence, ngram_index, n=2):
+    """Find candidate names via the n-gram inverted index."""
+    candidates = set()
+    for i in range(len(sentence) - n + 1):
+        ngram = sentence[i:i+n].lower()
+        candidates.update(ngram_index.get(ngram, []))
+    return candidates
+
+def build_ngram_index_phn(names, n=2):
+    """Build an n-gram inverted index over phoneme sequences."""
+    index = {}
+    for name in names:
+        phonemes = name.split()
+        for i in range(len(phonemes) - n + 1):
+            ngram = ' '.join(phonemes[i:i+n])
+            index.setdefault(ngram, set()).add(name)
+    return index
+
+def find_candidate_names_phn(phonemes, ngram_index, n=2):
+    """Find candidate names via the phoneme n-gram inverted index."""
+    candidates = set()
+    phonemes = phonemes.split()
+    for i in range(len(phonemes) - n + 1):
+        ngram = ' '.join(phonemes[i:i+n])
+        candidates.update(ngram_index.get(ngram, []))
+    return candidates
+
+@lru_cache(maxsize=100000)
+def similarity(name, sentence):
+    return Levenshtein.ratio(name, sentence)
+
+def generate_ngrams(sentence, n):
+    """Generate word-level n-grams of length n."""
+    sentence = sentence.split()
+    return [' '.join(sentence[i:i+n]) for i in range(len(sentence)-n+1)]
+
+def calculate_similarity_score(name, sentence, length_tolerance=3):
+    max_similarity = 0
+    name_sentence = name.split()
+    name_length = len(name_sentence)
+    sentence_ngrams = generate_ngrams(sentence, name_length)
+
+    for ngram in sentence_ngrams:
+        if abs(len(ngram) - len(name)) <= length_tolerance:
+            sim = similarity(name.lower(), ngram.lower())
+            max_similarity = max(max_similarity, sim)
+    return max_similarity
+
+def score_candidates(candidates, sentence):
+    """Compute a similarity score for each candidate name."""
+    scores = {}
+    for candidate in candidates:
+        score = calculate_similarity_score(candidate, sentence)
+        scores[candidate] = score
+    return scores
+
+
+class HotwordsInferDataset(torch.utils.data.Dataset):
+    def __init__(
+            self,
+            dataset_config,
+            tokenizer=None,
+            split='train',
+    ):
+        super().__init__()
+        self.dataset_config = dataset_config
+        self.tokenizer = tokenizer
+        data_parallel_size = 1
+        self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
+        self.prompt = dataset_config.get("prompt", None)
+        self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3
+        self.prompt_template = "USER: {}\n ASSISTANT:"
+        self.answer_template = "{}"
+        self.fix_length_audio = dataset_config.get("fix_length_audio", -1)
+        self.inference_mode = dataset_config.get("inference_mode", False)
+        self.normalize = dataset_config.get("normalize", False)
+        self.input_type = dataset_config.get("input_type", None)
+        assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"
+
+        self.data_list = []
+        if split == "train":
+            with open(dataset_config.train_data_path, encoding='utf-8') as fin:
+                for line in fin:
+                    data_dict = json.loads(line.strip())
+                    self.data_list.append(data_dict)
+        else:
+            with open(dataset_config.val_data_path, encoding='utf-8') as fin:
+                for line in fin:
+                    data_dict =
json.loads(line.strip()) + self.data_list.append(data_dict) + + self.hotwords_list=[] + self.biaswords_list=[] + with open(dataset_config.infer_file,'r') as fref: + for line in fref: + line=line.strip().split('\t') + hotwords = line[2] + biaswords= line[3] + self.hotwords_list.append(hotwords) + self.biaswords_list.append(biaswords) + + self.infer_type=dataset_config.infer_type + if self.infer_type=="filter": + self.infer_list=[] + with open(dataset_config.ctc_file,'r') as finfer: + for line in finfer: + self.infer_list.append(line.strip()) + + # analyze + self.hotwords_num=0 + self.miss_words_num=0 + self.filter_type=dataset_config.filter_type + if self.filter_type=="phn": + with open(dataset_config.phn_to_name_dict, 'r') as file: + self.phn_to_name_dict = json.load(file) + + self.probability_threshold = dataset_config.get("probability_threshold", 0.95) + self.word_num = dataset_config.get("word_num", 15) + self.prompt_word_num = 0 + logger.info("word_num: %d", self.word_num) + logger.info("probability_threshold: %f", self.probability_threshold) + + self.filter_infer_sentence = dataset_config.get("filter_infer_sentence", False) + self.filter_infer_sentence_few = dataset_config.get("filter_infer_sentence_few", False) + if self.filter_infer_sentence: + self.common_words_5k=set() + with open(dataset_config.common_words_5k_dir) as f: + for line in f: + word = line.strip() + self.common_words_5k.add(word) + if self.filter_infer_sentence_few: + self.first = dataset_config.get("first",1) + logger.info("first: %d", self.first) + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + return data_dict["target_len"] if "target_len" in data_dict else 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + data_dict = self.data_list[index] + audio_path = data_dict.get("source") + target = data_dict.get("target", None) + task = data_dict.get("prompt", "ASR") + key = data_dict.get("key", None) + + audio_raw = whisper.load_audio(audio_path) + if self.input_type == "raw": + audio_raw = torch.from_numpy(audio_raw) + if self.normalize: + audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape) + audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + elif self.input_type == "mel": + audio_raw = whisper.pad_or_trim(audio_raw) + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + audio_pseudo = torch.full((audio_length,), -1) # placeholder + + if self.infer_type=="nobias": + ocr = "" + elif self.infer_type=="gt": + ocr = eval(self.hotwords_list[index]) + ocr = " ".join(ocr) + ocr = ocr.upper() + elif self.infer_type=="filter": + gt = eval(self.hotwords_list[index]) + if self.filter_type == "char": + infer_sentence = self.infer_list[index].lower() + else: + infer_sentence = self.infer_list[index] + + words_list = infer_sentence.split() + filtered_words = [word for word in words_list if word not in self.common_words_5k] + infer_sentence = ' '.join(filtered_words) + + biaswords=eval(self.biaswords_list[index]) + if self.filter_type=="char": + ngram_index = build_ngram_index(biaswords) + candidates = find_candidate_names(infer_sentence, ngram_index) + 
elif self.filter_type=="phn": + ngram_index = build_ngram_index_phn(biaswords) + candidates = find_candidate_names_phn(infer_sentence, ngram_index) + if not self.filter_infer_sentence_few: + scores = score_candidates(candidates, infer_sentence) + sorted_dict = sorted(scores.items(), key=lambda item: item[1], reverse=True) + high_score_items = [(k, value) for k, value in sorted_dict if value > self.probability_threshold] + if len(high_score_items) < self.word_num: + high_score_items = sorted_dict[:self.word_num] + self.prompt_word_num += len(high_score_items) + keys_list = [k for k, _ in high_score_items] + + if len(high_score_items)>self.word_num: + logger.info("longer than %d candidates, cand_num: %d", self.word_num,len(high_score_items)) + else: + keys_list = self.score_candidates_for_each_word(candidates, infer_sentence) + self.prompt_word_num += len(keys_list) + + # ======== count recall ======== + miss=False + for name in gt: + self.hotwords_num+=1 + if name not in keys_list: + logger.info("miss name: %s", name) + self.miss_words_num+=1 + miss=True + if miss: + logger.info("key: %s", key) + logger.info("infer sentence: %s",infer_sentence) + logger.info("target sentence: %s", target) + logger.info("gt: %s, keys_list: %s", gt, keys_list) + # =============================== + if self.filter_type=="phn": + keys_list = [self.phn_to_name_dict[phn] for phn in keys_list] + keys_list = [item for sublist in keys_list for item in sublist] + + ocr = " ".join(keys_list).upper() + + prompt = "Transcribe speech to text. Some hotwords might help. The hotwords are \"{}\". " + prompt = prompt.format(ocr) + prompt = self.prompt_template.format(prompt) + prompt_ids = self.tokenizer.encode(prompt) #'USER: Transcribe speech to text. Some hotwords might help. The hotwords are "anon harshly". \n ASSISTANT:' + prompt_length = len(prompt_ids) + + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "input_ids": example_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_length": audio_length, + "key": key, + "target": target, + } + + answer = self.answer_template.format(target) + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 
+ example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + + return { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_length": audio_length, + } + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, np.ndarray): + if len(sequence) < max_length: + sequence = np.concatenate( + (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + if self.input_type == "raw": + audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) + audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) + for s in samples]) + audio_mask = torch.zeros(len(samples), audio_raw_max_length) + for line, sample in enumerate(samples): + audio_mask[line, :sample['audio'].shape[0]] = 1 + elif self.input_type == "mel": + audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) + audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) + for s in samples]) + audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats + for line, sample in enumerate(samples): + audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 + + modality_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": 
modality_mask, + "keys": keys, + "targets": targets + } + + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) for s in samples]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask + } + + def score_candidates_for_each_word(self,candidates, sentence): + keys_list = [] + for word in sentence.split(): + scores = {} + for candidate in candidates: + score = similarity(word,candidate) + scores[candidate] = score + sorted_items = sorted(scores.items(), key=lambda item: item[1], reverse=True) + first_two_items = sorted_items[:self.first] + keys_list.extend([item[0] for item in first_two_items]) + return keys_list + + +def get_speech_dataset(dataset_config, tokenizer, split): + dataset = HotwordsInferDataset(dataset_config, tokenizer, split) + return dataset diff --git a/examples/contextual_asr/docs/model.png b/examples/contextual_asr/docs/model.png new file mode 100644 index 00000000..78691443 Binary files /dev/null and b/examples/contextual_asr/docs/model.png differ diff --git a/examples/contextual_asr/docs/performance.png b/examples/contextual_asr/docs/performance.png new file mode 100644 index 00000000..86d203a8 Binary files /dev/null and b/examples/contextual_asr/docs/performance.png differ diff --git a/examples/contextual_asr/finetune_contextual_asr.py b/examples/contextual_asr/finetune_contextual_asr.py new file mode 100644 index 00000000..046e2301 --- /dev/null +++ b/examples/contextual_asr/finetune_contextual_asr.py @@ -0,0 +1,50 @@ +from slam_llm.pipeline.finetune import main as train + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from contextual_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig #, PeftConfig +from typing import Optional, List + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + # peft_config: PeftConfig = field(default_factory=PeftConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + ckpt_path: Optional[str] = field( + default=None, metadata={"help": "The path to projector checkpoint"} + ) + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() 
\ No newline at end of file diff --git a/examples/contextual_asr/inference_contextual_asr_batch.py b/examples/contextual_asr/inference_contextual_asr_batch.py new file mode 100644 index 00000000..1ffc325e --- /dev/null +++ b/examples/contextual_asr/inference_contextual_asr_batch.py @@ -0,0 +1,53 @@ +from slam_llm.pipeline.inference_batch import main as inference + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional +from contextual_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/contextual_asr/model/slam_model_contextual_asr.py b/examples/contextual_asr/model/slam_model_contextual_asr.py new file mode 100644 index 00000000..0910d2ed --- /dev/null +++ b/examples/contextual_asr/model/slam_model_contextual_asr.py @@ -0,0 +1,155 @@ +import torch +import os +import logging +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model_asr( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get( + "ckpt_path", None + ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or 
train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_asr(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + + @torch.no_grad() + def inference( + self, + wav_path=None, + prompt=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=None, + assistant_model=None, + streamer=None, + negative_prompt_ids=None, + negative_prompt_attention_mask=None, + **kwargs, + ): + # inference for asr model + + device = kwargs.get("device", "cuda") + if os.path.exists(wav_path): # Audio-Text QA + import whisper + + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + + mel_size = getattr( + self.dataset_config, "mel_size", 80 + ) # 80 for large v1 and v2, 128 for large v3 + audio_mel = ( + whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) + .permute(1, 0)[None, :, :] + .to(device) + ) + + encoder_outs = self.encoder.extract_variable_length_features( + audio_mel.permute(0, 2, 1) + ) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones( + encoder_outs.size()[:-1], dtype=torch.long + ).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty( + 1, 0, self.llm.model.embed_tokens.embedding_dim + ).to(device) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat( + (encoder_outs, inputs_embeds[None, :, :]), dim=1 + ) # [audio,prompt] + + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( + inputs_embeds.device + ) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs + ) + + return model_outputs diff --git a/examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char.sh b/examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char.sh new file mode 100644 index 00000000..d700a0f0 --- /dev/null +++ b/examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char.sh @@ -0,0 +1,75 @@ +#!/bin/bash +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +cd /nfs/yangguanrou.ygr/codes/SLAM-LLM +code_dir=examples/contextual_asr + +speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/wavlm_large_ft_libri960_char/wavlm_large_ft_libri960_char.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl +val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl + 
+output_dir=/nfs/yangguanrou.ygr/experiments_librispeech/vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-$(date +"%Y%m%d")-debug + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=wavlm \ +++model_config.normalize=true \ +++dataset_config.normalize=true \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1024 \ +++model_config.encoder_projector=cov1d-linear \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.train_data_path=$train_data_path \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.input_type=raw \ +++train_config.model_name=asr \ +++train_config.num_epochs=5 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=100000 \ +++train_config.lr=1e-4 \ +++train_config.validation_interval=8000 \ +++train_config.val_batch_size=4 \ +++train_config.batch_size_training=4 \ +++train_config.num_workers_dataloader=2 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +++log_config.log_file=/$output_dir/train.log \ +++log_config.use_wandb=true \ +++log_config.wandb_dir=$output_dir \ +++log_config.wandb_entity_name=yanghaha \ +++log_config.wandb_project_name=slam-llm \ +++log_config.wandb_exp_name=vicuna-7b-v1.5-WavLM-Large-libri960-ft-char \ +++log_config.log_interval=5 \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_contextual_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 2 \ + --master_port=29503 \ + $code_dir/finetune_contextual_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + $hydra_args +fi diff --git a/examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char_hotwords.sh b/examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char_hotwords.sh new file mode 100644 index 00000000..7326a614 --- /dev/null +++ b/examples/contextual_asr/scripts/finetune/finetune_wavlm_libri960_ft_char_hotwords.sh @@ -0,0 +1,77 @@ +#!/bin/bash +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=2,3 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +cd /nfs/yangguanrou.ygr/codes/SLAM-LLM +code_dir=examples/contextual_asr + +speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/wavlm_large_ft_libri960_char/wavlm_large_ft_libri960_char.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl +val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl + +output_dir=/nfs/yangguanrou.ygr/experiments_librispeech/vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-hotwords-$(date +"%Y%m%d")-debug + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=wavlm \ +++model_config.normalize=true \ +++dataset_config.normalize=true \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$speech_encoder_path \ 
+++model_config.encoder_dim=1024 \ +++model_config.encoder_projector=cov1d-linear \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.train_data_path=$train_data_path \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.input_type=raw \ +++dataset_config.dataset=hotwords_dataset \ +++dataset_config.file=examples/contextual_asr/dataset/hotwords_dataset.py:get_speech_dataset \ +++train_config.model_name=asr \ +++train_config.num_epochs=5 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=100000 \ +++train_config.lr=1e-4 \ +++train_config.validation_interval=8000 \ +++train_config.val_batch_size=4 \ +++train_config.batch_size_training=4 \ +++train_config.num_workers_dataloader=2 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +++log_config.log_file=/$output_dir/train.log \ +++log_config.use_wandb=true \ +++log_config.wandb_dir=$output_dir \ +++log_config.wandb_entity_name=yanghaha \ +++log_config.wandb_project_name=slam-llm \ +++log_config.wandb_exp_name=vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-hotwords \ +++log_config.log_interval=5 \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_contextual_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 2 \ + --master_port=29504 \ + $code_dir/finetune_contextual_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + $hydra_args +fi diff --git a/examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char.sh b/examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char.sh new file mode 100644 index 00000000..7bdee2ef --- /dev/null +++ b/examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char.sh @@ -0,0 +1,56 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=2 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +run_dir=/nfs/yangguanrou.ygr/codes/SLAM-LLM +cd $run_dir +code_dir=examples/contextual_asr + +speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/wavlm_large_ft_libri960_char/wavlm_large_ft_libri960_char.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/experiments_librispeech/vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-20240521 +ckpt_path=$output_dir/asr_epoch_3_step_9780 +N=100 +for ref_split in test_clean test_other; do + split=librispeech_${ref_split} + val_data_path=/nfs/maziyang.mzy/data/librispeech/${split}.jsonl + decode_log=$ckpt_path/decode_${split}_beam4_debug + python $code_dir/inference_contextual_asr_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_name="vicuna-7b-v1.5" \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=4096 \ + ++model_config.encoder_name=wavlm \ + ++model_config.normalize=true \ + ++dataset_config.normalize=true \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$speech_encoder_path \ + ++model_config.encoder_dim=1024 \ + ++model_config.encoder_projector=cov1d-linear \ + ++dataset_config.dataset=speech_dataset \ + ++dataset_config.val_data_path=$val_data_path \ + ++dataset_config.input_type=raw \ + 
++dataset_config.inference_mode=true \ + ++train_config.model_name=asr \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.batching_strategy=custom \ + ++train_config.num_epochs=1 \ + ++train_config.val_batch_size=1 \ + ++train_config.num_workers_dataloader=0 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt && \ + python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc && \ + python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc && \ + python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer && \ + python /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_score.py \ + --refs /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/ref_score/${ref_split}.biasing_${N}.tsv \ + --hyps ${decode_log}_pred.proc \ + --output_file ${decode_log}.proc.wer +done diff --git a/examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char_hotwords.sh b/examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char_hotwords.sh new file mode 100644 index 00000000..c58351ad --- /dev/null +++ b/examples/contextual_asr/scripts/infer/decode_wavlm_libri960_ft_char_hotwords.sh @@ -0,0 +1,74 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +export CUDA_LAUNCH_BLOCKING=1 +export HYDRA_FULL_ERROR=1 + +run_dir=/nfs/yangguanrou.ygr/codes/SLAM-LLM +cd $run_dir +code_dir=examples/contextual_asr + + +speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/wavlm_large_ft_libri960_char/wavlm_large_ft_libri960_char.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/experiments_librispeech/vicuna-7b-v1.5-WavLM-Large-libri960-ft-char-hotwords-20240521 +ckpt_path=$output_dir/asr_epoch_3_step_25780 + +first=1 +for N in 100 500 1000 2000; do + for ref_split in test_clean test_other; do + split=librispeech_${ref_split} + val_data_path=/nfs/maziyang.mzy/data/librispeech/${split}.jsonl + decode_log=$ckpt_path/decode_${split}_beam4_filter_N${N}_first${first}_debug + python $code_dir/inference_contextual_asr_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_name="vicuna-7b-v1.5" \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=4096 \ + ++model_config.encoder_name=wavlm \ + ++model_config.normalize=true \ + ++dataset_config.normalize=true \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$speech_encoder_path \ + ++model_config.encoder_dim=1024 \ + ++model_config.encoder_projector=cov1d-linear \ + ++dataset_config.dataset=speech_dataset \ + ++dataset_config.val_data_path=$val_data_path \ + ++dataset_config.input_type=raw \ + ++dataset_config.inference_mode=true \ + ++dataset_config.infer_type=filter \ + ++dataset_config.dataset=hotwordsinfer_dataset \ + ++dataset_config.file=examples/contextual_asr/dataset/hotwordsinfer_dataset.py:get_speech_dataset \ + ++dataset_config.infer_file=/nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_ref/${ref_split}.biasing_${N}.tsv \ + ++dataset_config.ctc_file=/nfs/yangguanrou.ygr/data/librispeech_my_infer/wavlm_ft_libri960_${ref_split}_char.txt \ + ++dataset_config.probability_threshold=0.9 \ + ++dataset_config.word_num=15 \ + ++dataset_config.filter_infer_sentence=true \ + ++dataset_config.filter_infer_sentence_few=true \ + ++train_config.model_name=asr \ + 
++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.batching_strategy=custom \ + ++train_config.num_epochs=1 \ + ++train_config.val_batch_size=1 \ + ++train_config.num_workers_dataloader=0 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt && \ + + python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc && \ + python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc && \ + python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer && \ + python /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/my_score.py \ + --refs /nfs/yangguanrou.ygr/data/fbai-speech/is21_deep_bias/ref_score/${ref_split}.biasing_${N}.tsv \ + --hyps ${decode_log}_pred.proc \ + --output_file ${decode_log}.proc.wer + done +done + + +# bash /root/SLAM-LLM/examples/hotwords_librispeech/scripts_libri960/infer/3/filter_infer_sen_for_each_word/bs1_remake/decode_wavlm_libri960_ft_char_hotwords_filter_N100_first1_remake_debug.sh > /root/SLAM-LLM/examples/hotwords_librispeech/scripts_libri960/infer/3/filter_infer_sen_for_each_word/bs1_remake/decode_wavlm_libri960_ft_char_hotwords_filter_N100_first1_remake_debug.log \ No newline at end of file