Commit 3df6401 (parent: 56aa511), showing 13 changed files with 983 additions and 13 deletions.

# MALA-ASR_SLIDESPEECH

## Performance and checkpoints
We only train the linear projector in this recipe.

| Encoder | Projector | LLM | dev | test |
|---|---|---|---|---|
| [WavLM-large](https://drive.google.com/file/d/12-cB34qCTvByWT-QtOcZaqwwO21FLSqU/view) | [Linear](https://drive.google.com/file/d/1hYS5UI3W0WVOZRVbqWxDUWIFMO9VgzHk/view?usp=drive_link) (~15.74M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 8.91 | 9.14 |

## Data preparation
Refer to the official [SLIDESPEECH CORPUS](https://slidespeech.github.io/) for downloading and preparing the data.

## Decode with checkpoints
```
bash decode_MaLa-ASR_withkeywords_L95.sh
```
Modify the paths in the script before running it, including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path`, and `decode_log`.
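For reference, the edits usually amount to pointing a handful of variables at local copies of the models and at an output location. The values below are placeholders, and the exact variable names inside the shell script may differ slightly from the list above.
```
# Placeholder values only -- adjust to your environment; check the shell script
# for the exact variable names it uses.
speech_encoder_path=/path/to/WavLM-Large.pt
llm_path=/path/to/vicuna-7b-v1.5
output_dir=/path/to/exp/mala_asr_slidespeech
ckpt_path=$output_dir/mala_asr/model.pt            # projector checkpoint to decode with
decode_log=$output_dir/decode_test_withkeywords    # prefix for the decode output files
```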

## Train a new model

### Use a self-supervised model (such as WavLM) as the encoder
```
bash finetune_MaLa-ASR_withkeywords_L95.sh
```
```
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
```
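For context, this is a standard DeepSpeed configuration: ZeRO stage 3 with optimizer state offloaded to CPU, fp16 training, and Adam at lr 1e-4; the micro-batch size and gradient accumulation steps match the `TrainConfig` defaults in this commit. How the JSON file itself is located by the pipeline is not visible in this diff, so the sketch below only shows the recipe-level switch that does appear in the config; the script name is an assumption.
```
# Hedged sketch: DeepSpeed is toggled through the training config. The entry
# script name is an assumption, and the mechanism that picks up the JSON file
# above is not shown in this commit.
python finetune_mala_asr.py ++train_config.enable_deepspeed=true
```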
```
dataset_config:
  # We keep the prompt here because the Hydra override in the shell script only supports a limited set of characters.
  # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
  prompt: "Transcribe speech to text. "
```
```
from slam_llm.pipeline.finetune import main as train
from typing import Optional

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from mala_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    ckpt_path: Optional[str] = field(
        default=None, metadata={"help": "The path to projector checkpoint"}
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
```
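Because `main_hydra` starts from an empty Hydra config and merges the command line on top of the `RunConfig` defaults, any field of the bundled dataclasses can be set with `++key=value` overrides, which is how the accompanying shell scripts appear to pass options (the prompt YAML's comment refers to Hydra overrides in the shell script). A minimal sketch, assuming the file above is saved as `finetune_mala_asr.py` (the name is an assumption) and using placeholder paths:
```
# Override keys are taken from RunConfig and the config dataclasses it bundles;
# all paths are placeholders.
python finetune_mala_asr.py \
    ++model_config.encoder_path=/path/to/WavLM-Large.pt \
    ++model_config.llm_path=/path/to/vicuna-7b-v1.5 \
    ++dataset_config.train_scp_file_path=/path/to/slidespeech/train_L_95/ \
    ++dataset_config.dev_scp_file_path=/path/to/slidespeech/dev_oracle_v1/ \
    ++train_config.output_dir=/path/to/exp/mala_asr \
    ++train_config.freeze_encoder=true ++train_config.freeze_llm=true
```
Freezing both the encoder and the LLM matches the README note that only the linear projector is trained in this recipe.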
```
from slam_llm.pipeline.inference_batch import main as inference

import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional
from mala_asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb

        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
```
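The decode entry point is driven the same way; its `RunConfig` adds `decode_log`, `ckpt_path`, and `peft_ckpt` at the top level, and `dataset_config.inference_mode` switches the dataset to inference. A hedged sketch with an assumed file name and placeholder paths:
```
# Keys come from the RunConfig above and from DataConfig; the script name is an assumption.
python inference_mala_asr_batch.py \
    ++ckpt_path=/path/to/exp/mala_asr/model.pt \
    ++decode_log=/path/to/exp/mala_asr/decode_test \
    ++dataset_config.test_scp_file_path=/path/to/slidespeech/test_oracle_v1/ \
    ++dataset_config.inference_mode=true
```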
```
from dataclasses import dataclass, field
from typing import Optional, List


@dataclass
class ModelConfig:
    file: str = "examples/mala_asr_slidespeech/model/slam_model_mala_asr.py:model_factory"
    llm_name: str = "vicuna-7b-v1.5"
    llm_path: str = "PATH/to/LLAMA/7B"
    llm_type: str = "decoder_only"
    llm_dim: int = 4096
    encoder_name: Optional[str] = None
    encoder_ds_rate: int = 2
    encoder_path: Optional[str] = None
    encoder_dim: int = 1280
    encoder_projector: str = "linear"
    encoder_projector_ds_rate: int = 5
    modal: str = "audio"
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })
    encoder_type: str = field(default="finetune", metadata={
        "help": "whether the model is only pretrained or finetuned, used for models such as hubert"
    })


@dataclass
class PeftConfig:
    peft_method: str = "lora"  # None, llama_adapter, prefix
    r: int = 8
    lora_alpha: int = 32
    target_modules: List = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float = 0.05
    inference_mode: bool = False


@dataclass
class TrainConfig:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_ddp: bool = False
    enable_deepspeed: bool = False
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = field(default="packing", metadata={
        "help": "alternative: padding"
    })
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    warmup_steps: int = 1000
    total_steps: int = 100000
    validation_interval: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1

    use_peft: bool = False
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
    run_test_during_validation: bool = False
    run_test_during_validation_file: str = "test.wav"
    run_test_during_validation_prompt: str = "<|ASR|>"
    freeze_llm: bool = field(default=False, metadata={
        "help": "whether to freeze the LLM when finetuning; should be true when using PEFT finetuning"
    })
    freeze_encoder: bool = False


@dataclass
class DataConfig:
    dataset: str = "slidespeech_dataset"
    file: str = "src/slam_llm/datasets/slidespeech_dataset.py:get_speech_dataset"
    train_scp_file_path: str = "/nfs/yangguanrou.ygr/slidespeech/train_L_95/"
    dev_scp_file_path: str = "/nfs/yangguanrou.ygr/slidespeech/dev_oracle_v1/"
    test_scp_file_path: str = "/nfs/yangguanrou.ygr/slidespeech/test_oracle_v1/"
    train_split: str = "train"
    test_split: str = "val"
    prompt: Optional[str] = None
    use_ocr: bool = True
    inference_mode: bool = False
    lower: bool = False
    fix_length_audio: int = -1
    input_type: str = field(default="raw", metadata={
        "help": "Use raw when the input is wav, mel for Whisper"
    })
    mel_size: int = field(default=80, metadata={
        "help": "80 for whisper large v1 and v2, 128 for v3"
    })
    normalize: Optional[bool] = field(default=False, metadata={
        "help": "whether input is normalized, used for models such as wavlm"
    })


@dataclass
class FSDPConfig:
    mixed_precision: bool = True
    use_fp16: bool = False
    # sharding_strategy = "FULL_SHARD"  # ShardingStrategy.FULL_SHARD
    sharding_strategy: str = "NO_SHARD"  # ShardingStrategy.NO_SHARD; set NO_SHARD when using DDP
    checkpoint_type: str = "SHARDED_STATE_DICT"  # saves one file per rank and allows resizing the world size
    fsdp_activation_checkpointing: bool = True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"


@dataclass
class LogConfig:
    use_wandb: bool = False
    wandb_dir: str = "/root/test_wandb"
    wandb_entity_name: str = "project_name"
    wandb_project_name: str = "project_name"
    wandb_exp_name: str = "exp_name"
    log_file: str = "/root/test.log"
    log_interval: int = 5
```
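Since these are plain dataclasses consumed through OmegaConf/Hydra, a quick way to list every available option with its default (for example, the SlideSpeech scp directories that need repointing) is to dump one of them as YAML. This assumes the command is run from the recipe directory so that `mala_asr_config` is importable; `omegaconf` is already required by the entry scripts above.
```
# Prints all DataConfig fields with their defaults as YAML.
python -c "
from omegaconf import OmegaConf
from mala_asr_config import DataConfig
print(OmegaConf.to_yaml(OmegaConf.structured(DataConfig)))
"
```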
Hi @yanghaha0908,
Thank you for your insightful work on MaLa-ASR. I am particularly interested in your observation that incorporating long-term historical context into the prompt was not effective.
I am currently working on a similar task and would like to know whether you have done any in-depth analysis of how exactly the history prompts underperform. Additionally, have you observed any hallucination issues when erroneous content is inserted into the prompt?
Looking forward to your response.
Best regards,
Alan