From 3df640160ea3f87fcd8d4dadc4abcf31ab91a725 Mon Sep 17 00:00:00 2001 From: yanghaha0908 Date: Fri, 7 Jun 2024 21:02:36 +0800 Subject: [PATCH] for pr mala-asr --- examples/mala_asr_slidespeech/README.md | 24 ++ .../mala_asr_slidespeech/conf/ds_config.json | 19 + .../mala_asr_slidespeech/conf/prompt.yaml | 4 + .../mala_asr_slidespeech/finetune_mala_asr.py | 49 +++ .../inference_mala_asr_batch.py | 53 +++ .../mala_asr_slidespeech/mala_asr_config.py | 127 +++++++ .../model/slam_model_mala_asr.py | 155 +++++++++ .../decode_MaLa-ASR_withkeywords_L95.sh | 49 +++ .../scripts/decode_old.sh | 50 +++ .../finetune_MaLa-ASR_withkeywords_L95.sh | 70 ++++ examples/mala_asr_slidespeech/scripts/old.sh | 58 ++++ src/slam_llm/datasets/slidespeech_dataset.py | 325 ++++++++++++++++++ src/slam_llm/pipeline/finetune.py | 13 - 13 files changed, 983 insertions(+), 13 deletions(-) create mode 100644 examples/mala_asr_slidespeech/README.md create mode 100644 examples/mala_asr_slidespeech/conf/ds_config.json create mode 100644 examples/mala_asr_slidespeech/conf/prompt.yaml create mode 100644 examples/mala_asr_slidespeech/finetune_mala_asr.py create mode 100644 examples/mala_asr_slidespeech/inference_mala_asr_batch.py create mode 100644 examples/mala_asr_slidespeech/mala_asr_config.py create mode 100644 examples/mala_asr_slidespeech/model/slam_model_mala_asr.py create mode 100644 examples/mala_asr_slidespeech/scripts/decode_MaLa-ASR_withkeywords_L95.sh create mode 100644 examples/mala_asr_slidespeech/scripts/decode_old.sh create mode 100644 examples/mala_asr_slidespeech/scripts/finetune_MaLa-ASR_withkeywords_L95.sh create mode 100644 examples/mala_asr_slidespeech/scripts/old.sh create mode 100644 src/slam_llm/datasets/slidespeech_dataset.py diff --git a/examples/mala_asr_slidespeech/README.md b/examples/mala_asr_slidespeech/README.md new file mode 100644 index 00000000..0656a4c1 --- /dev/null +++ b/examples/mala_asr_slidespeech/README.md @@ -0,0 +1,24 @@ +# MALA-ASR_SLIDESPEECH + +## Performance and checkpoints +We only train the linear projector in this recipe. +Encoder | Projector | LLM | dev | test +|---|---|---|---|---| +[WavLM-large](https://drive.google.com/file/d/12-cB34qCTvByWT-QtOcZaqwwO21FLSqU/view) | [Linear](https://drive.google.com/file/d/1hYS5UI3W0WVOZRVbqWxDUWIFMO9VgzHk/view?usp=drive_link)(~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 8.91 | 9.14 + + +## Data preparation +Refer to official [SLIDESPEECH CORPUS](https://slidespeech.github.io/) + +## Decode with checkpoints +``` +bash decode_MaLa-ASR_withkeywords_L95.sh +``` +Modify the path including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path` and `decode_log` in the script when you run the shell script. + +## Train a new model + +### Use self-supervised model(such as WavLM) as the encoder +``` +bash finetune_MaLa-ASR_withkeywords_L95.sh +``` \ No newline at end of file diff --git a/examples/mala_asr_slidespeech/conf/ds_config.json b/examples/mala_asr_slidespeech/conf/ds_config.json new file mode 100644 index 00000000..7ea70e4a --- /dev/null +++ b/examples/mala_asr_slidespeech/conf/ds_config.json @@ -0,0 +1,19 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu" + } + } +} \ No newline at end of file diff --git a/examples/mala_asr_slidespeech/conf/prompt.yaml b/examples/mala_asr_slidespeech/conf/prompt.yaml new file mode 100644 index 00000000..0bc65175 --- /dev/null +++ b/examples/mala_asr_slidespeech/conf/prompt.yaml @@ -0,0 +1,4 @@ +dataset_config: + # we put prompt here, because the hydra override in shell script only support a small subset of chars + # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " + prompt: "Transcribe speech to text. " diff --git a/examples/mala_asr_slidespeech/finetune_mala_asr.py b/examples/mala_asr_slidespeech/finetune_mala_asr.py new file mode 100644 index 00000000..3f0d46a0 --- /dev/null +++ b/examples/mala_asr_slidespeech/finetune_mala_asr.py @@ -0,0 +1,49 @@ +from slam_llm.pipeline.finetune import main as train +from typing import Optional + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from mala_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + ckpt_path: Optional[str] = field( + default=None, metadata={"help": "The path to projector checkpoint"} + ) + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/mala_asr_slidespeech/inference_mala_asr_batch.py b/examples/mala_asr_slidespeech/inference_mala_asr_batch.py new file mode 100644 index 00000000..e733f7b3 --- /dev/null +++ b/examples/mala_asr_slidespeech/inference_mala_asr_batch.py @@ -0,0 +1,53 @@ +from slam_llm.pipeline.inference_batch import main as inference + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional +from mala_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None, version_base=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git a/examples/mala_asr_slidespeech/mala_asr_config.py b/examples/mala_asr_slidespeech/mala_asr_config.py new file mode 100644 index 00000000..0e9b7300 --- /dev/null +++ b/examples/mala_asr_slidespeech/mala_asr_config.py @@ -0,0 +1,127 @@ +from dataclasses import dataclass, field +from typing import Optional, List +@dataclass +class ModelConfig: + file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory" + llm_name: str = "vicuna-7b-v1.5" + llm_path: str = "PATH/to/LLAMA/7B" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + encoder_name: Optional[str] = None + encoder_ds_rate: int = 2 + encoder_path: Optional[str] = None + encoder_dim: int = 1280 + encoder_projector: str = "linear" + encoder_projector_ds_rate: int = 5 + modal: str = "audio" + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether input is normalized, used for models such as wavlm" + }) + encoder_type: str = field(default="finetune", metadata={ + "help": "whether model is only pretrained or finetuned, used for models such as hubert" + }) + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + +@dataclass +class TrainConfig: + model_name:str = "PATH/to/LLAMA/7B" + enable_ddp:bool = False + enable_deepspeed:bool = False + enable_fsdp:bool = False + low_cpu_fsdp:bool = False + run_validation:bool = True + batch_size_training:int = 4 + batching_strategy:str = field(default="packing", metadata={ + "help":"alternative: padding" + }) # + context_length:int = 4096 + gradient_accumulation_steps:int = 1 + num_epochs:int = 3 + num_workers_dataloader:int = 1 + warmup_steps:int = 1000 + total_steps:int = 100000 + validation_interval:int = 1000 + lr:float = 1e-4 + weight_decay:float = 0.0 + gamma:float = 0.85 + seed:int = 42 + use_fp16:bool = False + mixed_precision:bool = True + val_batch_size:int = 1 + + use_peft:bool = False + peft_config:PeftConfig = field(default_factory=PeftConfig) + output_dir:str = "PATH/to/save/PEFT/model" + freeze_layers:bool = False + num_freeze_layers:int = 1 + quantization:bool = False + one_gpu:bool = False + save_model:bool = True + dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP + dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP + save_optimizer:bool = False # will be used if using FSDP + use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + run_test_during_validation:bool = False + run_test_during_validation_file:str = "test.wav" + run_test_during_validation_prompt:str = "<|ASR|>" + freeze_llm:bool = field(default=False, metadata={ + "help": "whether to freeze llm when finetuning, should be true when use peft finetuning" + }) + freeze_encoder:bool = False + +@dataclass +class DataConfig: + dataset: str = "slidespeech_dataset" + file: str = "src/slam_llm/datasets/slidespeech_dataset.py:get_speech_dataset" + train_scp_file_path: str = "/nfs/yangguanrou.ygr/slidespeech/train_L_95/" + dev_scp_file_path: str = "/nfs/yangguanrou.ygr/slidespeech/dev_oracle_v1/" + test_scp_file_path: str = "/nfs/yangguanrou.ygr/slidespeech/test_oracle_v1/" + train_split: str = "train" + test_split:str = "val" + prompt: Optional[str] = None + use_ocr: bool = True + inference_mode: bool = False + lower: bool = False + fix_length_audio: int = -1 + inference_mode:bool = False + input_type: str = field(default="raw", metadata={ + "help":"Use raw when input is wav, mel when for whisper" + }) + mel_size: int = field(default=80, metadata={ + "help": "80 for whisper large v1 and v2, 128 for v3" + }) + normalize: Optional[bool] = field(default=False, metadata={ + "help": "whether input is normalized, used for models such as wavlm" + }) + +@dataclass +class FSDPConfig: + mixed_precision: bool = True + use_fp16: bool = False + # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP + checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. + fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/root/test.log" + log_interval: int = 5 diff --git a/examples/mala_asr_slidespeech/model/slam_model_mala_asr.py b/examples/mala_asr_slidespeech/model/slam_model_mala_asr.py new file mode 100644 index 00000000..0910d2ed --- /dev/null +++ b/examples/mala_asr_slidespeech/model/slam_model_mala_asr.py @@ -0,0 +1,155 @@ +import torch +import os +import logging +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model_asr( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get( + "ckpt_path", None + ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_asr(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + + @torch.no_grad() + def inference( + self, + wav_path=None, + prompt=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=None, + assistant_model=None, + streamer=None, + negative_prompt_ids=None, + negative_prompt_attention_mask=None, + **kwargs, + ): + # inference for asr model + + device = kwargs.get("device", "cuda") + if os.path.exists(wav_path): # Audio-Text QA + import whisper + + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + + mel_size = getattr( + self.dataset_config, "mel_size", 80 + ) # 80 for large v1 and v2, 128 for large v3 + audio_mel = ( + whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) + .permute(1, 0)[None, :, :] + .to(device) + ) + + encoder_outs = self.encoder.extract_variable_length_features( + audio_mel.permute(0, 2, 1) + ) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones( + encoder_outs.size()[:-1], dtype=torch.long + ).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty( + 1, 0, self.llm.model.embed_tokens.embedding_dim + ).to(device) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat( + (encoder_outs, inputs_embeds[None, :, :]), dim=1 + ) # [audio,prompt] + + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( + inputs_embeds.device + ) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs + ) + + return model_outputs diff --git a/examples/mala_asr_slidespeech/scripts/decode_MaLa-ASR_withkeywords_L95.sh b/examples/mala_asr_slidespeech/scripts/decode_MaLa-ASR_withkeywords_L95.sh new file mode 100644 index 00000000..c922691b --- /dev/null +++ b/examples/mala_asr_slidespeech/scripts/decode_MaLa-ASR_withkeywords_L95.sh @@ -0,0 +1,49 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/mala_asr_slidespeech + +speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Large.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/experiments_slides_wavlm/slides-finetune-wavlm +ckpt_path=$output_dir/asr/3840 +split=test #dev +val_data_path=/nfs/yangguanrou.ygr/slidespeech/${split}_oracle_v1/ +decode_log=$ckpt_path/decode_${split}_beam4 + +# -m debugpy --listen 5678 --wait-for-client +python $code_dir/inference_mala_asr_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_name="vicuna-7b-v1.5" \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=4096 \ + ++model_config.encoder_name=wavlm \ + ++model_config.normalize=true \ + ++dataset_config.normalize=true \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$speech_encoder_path \ + ++model_config.encoder_dim=1024 \ + ++model_config.encoder_projector=cov1d-linear \ + ++dataset_config.dataset=slidespeech_dataset \ + ++dataset_config.use_ocr=true \ + ++dataset_config.dev_scp_file_path=$val_data_path \ + ++dataset_config.input_type=raw \ + ++dataset_config.inference_mode=true \ + ++train_config.model_name=mala_asr \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.batching_strategy=custom \ + ++train_config.num_epochs=1 \ + ++train_config.val_batch_size=1 \ + ++train_config.num_workers_dataloader=2 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt \ \ No newline at end of file diff --git a/examples/mala_asr_slidespeech/scripts/decode_old.sh b/examples/mala_asr_slidespeech/scripts/decode_old.sh new file mode 100644 index 00000000..b4799112 --- /dev/null +++ b/examples/mala_asr_slidespeech/scripts/decode_old.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=1 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Large.pt + +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/slides-finetune-wavlm +ckpt_path=$output_dir/asr/3840 +# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4 +val_data_path=/nfs/yangguanrou.ygr/slidespeech/dev_oracle_v1/ +decode_log=$ckpt_path/decode_log_dev_clean_beam4_repetition_penalty1 + +# -m debugpy --listen 5678 --wait-for-client +python src/llama_recipes/pipeline/inference_batch.py \ +--config-path "/root/SLAM-LLM/scripts/slides_conf" \ +--config-name "slides.yaml" \ +hydra.run.dir=$ckpt_path \ +++model_config.llm_name="vicuna-7b-v1.5" \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=wavlm \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1024 \ +++model_config.encoder_projector=cov1d-linear \ +++encoder_projector_ds_rate=5 \ +++dataset_config.dataset=slides_dataset \ +++dataset_config.use_ocr=true \ +++dataset_config.dev_scp_file_path=$val_data_path \ +++dataset_config.inference_mode=true \ +++train_config.model_name=asr \ +++train_config.batching_strategy=custom \ +++train_config.num_epochs=1 \ +++train_config.val_batch_size=1 \ +++train_config.num_workers_dataloader=1 \ +++train_config.output_dir=$output_dir \ +++ckpt_path=$ckpt_path/model.pt \ +++decode_log=$decode_log \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +# ++model_config.encoder_projector=q-former \ +# ++dataset_config.fix_length_audio=64 \ +# --peft_ckpt $peft_ckpt \ +# --use_peft --peft_method lora \ \ No newline at end of file diff --git a/examples/mala_asr_slidespeech/scripts/finetune_MaLa-ASR_withkeywords_L95.sh b/examples/mala_asr_slidespeech/scripts/finetune_MaLa-ASR_withkeywords_L95.sh new file mode 100644 index 00000000..60c871c4 --- /dev/null +++ b/examples/mala_asr_slidespeech/scripts/finetune_MaLa-ASR_withkeywords_L95.sh @@ -0,0 +1,70 @@ +#!/bin/bash +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/mala_asr_slidespeech + +speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Large.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +output_dir=/root/tmp/finetune_MaLa-ASR_withkeywords_L95-$(date +"%Y%m%d") + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=wavlm \ +++model_config.normalize=true \ +++dataset_config.normalize=true \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1024 \ +++model_config.encoder_projector=cov1d-linear \ +++dataset_config.dataset=slidespeech_dataset \ +++dataset_config.input_type=raw \ +++dataset_config.use_ocr=true \ +++train_config.model_name=mala_asr \ +++train_config.num_epochs=5 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=110000 \ +++train_config.lr=5e-5 \ +++train_config.validation_interval=2000 \ +++train_config.batch_size_training=6 \ +++train_config.val_batch_size=6 \ +++train_config.num_workers_dataloader=2 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_mala_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 4 \ + --master_port=29503 \ + $code_dir/finetune_mala_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + $hydra_args +fi diff --git a/examples/mala_asr_slidespeech/scripts/old.sh b/examples/mala_asr_slidespeech/scripts/old.sh new file mode 100644 index 00000000..d9204f78 --- /dev/null +++ b/examples/mala_asr_slidespeech/scripts/old.sh @@ -0,0 +1,58 @@ +#!/bin/bash +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 +export WANDB_API_KEY='cc633438a8688eaf73ea5889022da2c394a92df2' + +cd /root/SLAM-LLM + +speech_encoder_path=/nfs/maziyang.mzy/models/wavlm/WavLM-Large.pt + +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/slides-finetune-wavlm + + +torchrun \ +--nnodes 1 \ +--nproc_per_node 4 \ +src/llama_recipes/pipeline/finetune.py \ +--config-path "/root/SLAM-LLM/scripts/slides_conf" \ +--config-name "slides.yaml" \ +hydra.run.dir=$output_dir \ +++model_config.llm_name="vicuna-7b-v1.5" \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=wavlm \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1024 \ +++model_config.encoder_projector=cov1d-linear \ +++encoder_projector_ds_rate=5 \ +++dataset_config.dataset=slides_dataset \ +++dataset_config.use_ocr=true \ +++train_config.model_name=asr \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=110000 \ +++train_config.batch_size_training=6 \ +++train_config.val_batch_size=6 \ +++train_config.num_workers_dataloader=0 \ +++train_config.lr=5e-5 \ +++train_config.scheduler=tri \ +++train_config.validation_interval=2000 \ +++train_config.output_dir=$output_dir \ +++train_config.enable_fsdp=false \ +++train_config.enable_ddp=true \ +++train_config.use_fp16=true \ +++metric=acc \ +++log_config.log_file=/$output_dir/train.log \ +++log_config.use_wandb=true \ +++log_config.wandb_dir=$output_dir \ +++log_config.wandb_entity_name=yanghaha \ +++log_config.wandb_project_name=slam-llm-slides \ +++log_config.wandb_exp_name=slides-finetune-wavlm \ +++log_config.log_interval=10 \ diff --git a/src/slam_llm/datasets/slidespeech_dataset.py b/src/slam_llm/datasets/slidespeech_dataset.py new file mode 100644 index 00000000..b57b22b1 --- /dev/null +++ b/src/slam_llm/datasets/slidespeech_dataset.py @@ -0,0 +1,325 @@ +import torch +from torch.utils.data import Dataset +import whisper +import kaldiio +import copy +import numpy as np +from tqdm import tqdm + + +class SlidespeechDataset(Dataset): + def __init__(self, dataset_config, tokenizer=None, split='train',): + super().__init__() + self.data_list = [] + self.num_samples_list = [] + self.label_list = [] + self.ocr_list = [] + self.key_list=[] # for debug + self.asr_list=[] # not gt + + if split == "train": + with open(dataset_config.train_scp_file_path + "my_wav.scp",'r') as f: + for line in f: + line = line.strip().split() + self.data_list.append(line[1]) + self.key_list.append(line[0]) + + with open(dataset_config.train_scp_file_path + "utt2num_samples",'r') as f: + for line in f: + line = line.strip().split() + self.num_samples_list.append(int(line[1])) + + with open(dataset_config.train_scp_file_path + "text",'r') as f: + for line in f: + line = line.strip().split(' ',1) + if len(line) == 1: + self.label_list.append(None) + else: + if dataset_config.lower: + self.label_list.append(line[1].lower()) + else: + self.label_list.append(line[1]) + + with open(dataset_config.train_scp_file_path + "hot_related/ocr_1gram_top50_mmr070_hotwords_list",'r') as f: + for line in f: + line = line.strip().split() + if len(line) == 1: + self.ocr_list.append(None) + else: + line = line[1] + line = line.split('$') + line = " ".join(line) + + if dataset_config.lower: + self.ocr_list.append(line.lower()) + else: + self.ocr_list.append(line) + + + elif split == "val": + with open(dataset_config.dev_scp_file_path + "my_wav.scp",'r') as f: + for line in f: + line = line.strip().split() + self.data_list.append(line[1]) + self.key_list.append(line[0]) + + with open(dataset_config.dev_scp_file_path + "utt2num_samples",'r') as f: + for line in f: + line = line.strip().split() + self.num_samples_list.append(int(line[1])) + + with open(dataset_config.dev_scp_file_path + "text",'r') as f: + for line in f: + line = line.strip().split(' ',1) + if len(line) == 1: + self.label_list.append(None) + else: + if dataset_config.lower: + self.label_list.append(line[1].lower()) + else: + self.label_list.append(line[1]) + + with open(dataset_config.dev_scp_file_path + "hot_related/ocr_1gram_top50_mmr070_hotwords_list",'r') as f: + for line in f: + line = line.strip().split() + if len(line) == 1: + self.ocr_list.append(None) + else: + line = line[1] + line = line.split('$') + line = " ".join(line) + + if dataset_config.lower: + self.ocr_list.append(line.lower()) + else: + self.ocr_list.append(line) + + elif split == "test": # 3188 只有prev用这个 不用ground truth 用解码 可以考虑要不要删了 + with open(dataset_config.test_scp_file_path + "my_wav.scp",'r') as f: + for line in f: + line = line.strip().split() + self.data_list.append(line[1]) + self.key_list.append(line[0]) + + with open(dataset_config.test_scp_file_path + "utt2num_samples",'r') as f: + for line in f: + line = line.strip().split() + self.num_samples_list.append(int(line[1])) + + with open(dataset_config.test_scp_file_path + "text",'r') as f: + for line in f: + line = line.strip().split(' ',1) + if len(line) == 1: + self.label_list.append(None) + else: + if dataset_config.lower: + self.label_list.append(line[1].lower()) + else: + self.label_list.append(line[1]) + + with open(dataset_config.test_scp_file_path + "hot_related/ocr_1gram_top50_mmr070_hotwords_list",'r') as f: + for line in f: + line = line.strip().split() + if len(line) == 1: + self.ocr_list.append(None) + else: + line = line[1] + line = line.split('$') + line = " ".join(line) + + if dataset_config.lower: + self.ocr_list.append(line.lower()) + else: + self.ocr_list.append(line) + + + + self.dataset_config = dataset_config + self.tokenizer = tokenizer + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3 + self.prompt = dataset_config.get("prompt", None) + self.prompt_template1 = "USER: {}\n ASSISTANT:" + self.prompt_template2 = "USER: Transcribe speech to text. Use hotwords in ppt to improve speech recognition accuracy. But if the hotwords are irrelevant, just ignore them. The hotwords are \"{}\". \n ASSISTANT:" + self.answer_template = "{}" + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + self.inference_mode = dataset_config.get("inference_mode", False) + self.normalize = dataset_config.get("normalize", False) + self.input_type = dataset_config.get("input_type", None) + assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]" + + def get_source_len(self, data_dict): + return data_dict["source_len"] + + def get_target_len(self, data_dict): + + return data_dict["target_len"] if "target_len" in data_dict else 0 + + def __len__(self): + return len(self.data_list) + + def __getitem__(self, index): + ark_path = self.data_list[index] + numpy_array = kaldiio.load_mat(ark_path) + audio_raw = numpy_array[1].astype(np.float32) + num_samples = self.num_samples_list[index] + assert(audio_raw.shape[0] == num_samples) + ocr = self.ocr_list[index] + target = self.label_list[index] + key = self.key_list[index] + + + if self.input_type == "raw": + audio_raw = torch.from_numpy(audio_raw).float() + if self.normalize: + audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape) + audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + elif self.input_type == "mel": + audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0) + audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for whisper for 2x downsample from mel to feats + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + # audio_length = calculate_output_length_1d(audio_length, 5, 5, 0) # ad-hoc for 5x cov1d downsample + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + audio_pseudo = torch.full((audio_length,), -1) # placeholder + + + if self.dataset_config.use_ocr == True and ocr != None: + prompt = self.prompt_template2.format(ocr) + else: + prompt = self.prompt_template1.format(self.prompt) + # if self.dataset_config.task=="keyword_yizhi": + # if self.dataset_config.use_ocr == False or ocr == None: + # ocr="" + # prompt = self.prompt_template2.format(ocr) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + + + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "input_ids": example_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + 'audio_length': audio_length, + 'key': key, + 'target': target, + } + + answer = self.answer_template.format(target) + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. + example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + + return { + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + 'audio_length': audio_length, + } + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, np.ndarray): + if len(sequence) < max_length: + sequence = np.concatenate( + (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + + def collator(self, samples): + assert samples is not None + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) + for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) + for s in samples]) + if self.input_type == "raw": + audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) + audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) + for s in samples]) + audio_mask = torch.zeros(len(samples), audio_raw_max_length) + for line, sample in enumerate(samples): + audio_mask[line, :sample['audio'].shape[0]] = 1 + elif self.input_type == "mel": + audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) + audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) + for s in samples]) + audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats + for line, sample in enumerate(samples): + audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 + + modality_mask = torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask, + "keys": keys, + "targets": targets + } + + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) + for s in samples]) + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "audio": audio_raw if self.input_type == "raw" else None, + "audio_mask": audio_mask if self.input_type == "raw" else None, + "audio_mel": audio_mel if self.input_type == "mel" else None, + "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, + "modality_mask": modality_mask + } + + +def get_speech_dataset(dataset_config, tokenizer, split): + dataset = SlidespeechDataset(dataset_config, tokenizer, split) + return dataset + + + + \ No newline at end of file diff --git a/src/slam_llm/pipeline/finetune.py b/src/slam_llm/pipeline/finetune.py index 8f9dbbdc..f9132263 100644 --- a/src/slam_llm/pipeline/finetune.py +++ b/src/slam_llm/pipeline/finetune.py @@ -79,10 +79,6 @@ def main(kwargs: DictConfig): kwargs.dataset_config fsdp_config.use_fp16 = train_config.use_fp16 -<<<<<<< HEAD - # if model_config.encoder_name=="av_hubert": -======= ->>>>>>> origin/main OmegaConf.set_struct(kwargs,False) del kwargs["train_config"] del kwargs["fsdp_config"] @@ -90,15 +86,6 @@ def main(kwargs: DictConfig): del kwargs["log_config"] del kwargs["dataset_config"] OmegaConf.set_struct(kwargs,True) -<<<<<<< HEAD - # else: - # del kwargs.train_config - # del kwargs.fsdp_config - # del kwargs.model_config - # del kwargs.log_config - # del kwargs.dataset_config -======= ->>>>>>> origin/main # Set log if not os.path.exists(os.path.dirname(log_config.log_file)):