Skip to content

Commit

Permalink
Merge branch 'dev-mzy' of github.com:ddlBoJack/SLAM-LLM into dev-mzy
Browse files Browse the repository at this point in the history
  • Loading branch information
ddlBoJack committed Jan 10, 2024
2 parents 29279d4 + 0309a32 commit 9623db5
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions scripts/inference_asr_batch.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=1
export CUDA_VISIBLE_DEVICES=0
# export CUDA_LAUNCH_BLOCKING=1

cd /root/SLAM-LLM
Expand All @@ -11,11 +11,11 @@ speech_encoder_path=/nfs/zhifu.gzf/ckpt/Whisper/large-v2.pt
# llm_path=/nfs/zhifu.gzf/ckpt/Llama-2-7b-hf
llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5

output_dir=nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106
ckpt_path=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding30-20240106/asr/2
output_dir=/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-padding0-20240107
ckpt_path=$output_dir/asr/2
# peft_ckpt=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-lora-prompt-paddinglr-20240102/asr/4
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_other_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_other_bs4_beam4_repetition_penalty1
val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_test_clean_filtered.jsonl
decode_log=$ckpt_path/decode_log_test_clean_bs8_beam4_repetition_penalty1

# -m debugpy --listen 5678 --wait-for-client
python src/llama_recipes/pipeline/inference_batch.py \
Expand All @@ -35,7 +35,7 @@ python src/llama_recipes/pipeline/inference_batch.py \
--speech_dataset.val_data_path $val_data_path \
--batching_strategy custom \
--num_epochs 1 \
--val_batch_size 4 \
--val_batch_size 8 \
--num_workers_dataloader 4 \
--output_dir $output_dir \
--ckpt_path $ckpt_path/model.pt \
Expand Down
2 changes: 1 addition & 1 deletion src/llama_recipes/datasets/speech_dataset_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __getitem__(self, index):

speech_raw = whisper.load_audio(speech_path)
# speech_raw = whisper.pad_or_trim(speech_raw)
speech_raw = np.concatenate((np.zeros(80000), speech_raw, np.zeros(80000))).astype(speech_raw.dtype)[:16000*30]
# speech_raw = np.concatenate((np.zeros(80000), speech_raw, np.zeros(80000))).astype(speech_raw.dtype)[:16000*30]
speech_mel = whisper.log_mel_spectrogram(speech_raw).permute(1, 0)

prompt = "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
Expand Down
4 changes: 2 additions & 2 deletions src/llama_recipes/pipeline/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def main(**kwargs):

# Set log
if not os.path.exists(os.path.dirname(log_config.log_file)):
os.makedirs(os.path.dirname(log_config.log_file))
os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
Expand Down Expand Up @@ -103,7 +103,7 @@ def main(**kwargs):
if not train_config.enable_fsdp or rank == 0:
if log_config.use_wandb:
if not os.path.exists(log_config.wandb_dir):
os.makedirs(log_config.wandb_dir)
os.makedirs(log_config.wandb_dir, exist_ok=True)
wandb_config={"train_config":vars(train_config), "fsdp_config":vars(fsdp_config), "model_config":vars(model_config), "log_config":vars(log_config)}
wandb.init(dir=log_config.wandb_dir, entity=log_config.wandb_entity_name, project=log_config.wandb_project_name,name=log_config.wandb_exp_name ,config=wandb_config)

Expand Down
2 changes: 1 addition & 1 deletion src/llama_recipes/pipeline/inference_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main(**kwargs):

# Set log
if not os.path.exists(os.path.dirname(log_config.log_file)):
os.makedirs(os.path.dirname(log_config.log_file))
os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
Expand Down

0 comments on commit 9623db5

Please sign in to comment.