Merge pull request #176 from X-LANCE/yxdu
Yxdu
ddlBoJack authored Nov 24, 2024
2 parents f42716d + 33b84ed commit 6c26585
Showing 20 changed files with 125,016 additions and 109,284 deletions.
70 changes: 50 additions & 20 deletions examples/st_covost2/README.md
100755 → 100644
@@ -2,13 +2,49 @@


## Model Structure
<img src="image/framework.jpg" alt="示例图片" style="width:75%;">
<img src="image/framework.jpg" alt="Photo" style="width:75%;">


## Multitask
<img src="image/prompt.png" alt="示例图片" style="width:50%;">
<img src="image/prompt.png" alt="Photo" style="width:50%;">


## Installation
```
conda create -n cotst python=3.10
conda activate cotst
git clone https://github.com/ddlBoJack/SLAM-LLM.git
cd SLAM-LLM
pip install -e .
sudo apt install ffmpeg
pip install -U openai-whisper
pip install wandb
pip install soundfile
pip install evaluate
pip install transformers
pip install datasets
pip install sacrebleu
pip install jiwer
pip install librosa
pip install torch==2.4.0
pip install torchaudio==2.4.0
pip install torchvision==0.19.0
```
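
As an optional sanity check after installation (a minimal sketch that only assumes the packages installed above), you can confirm that the pinned torch build and Whisper import cleanly:

```
# Verify torch 2.4.0 and openai-whisper are importable and whether CUDA is visible.
python -c "import torch, whisper; print(torch.__version__, torch.cuda.is_available())"
```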

## Infer Demo
It is recommended to run on a single GPU for the first execution. Afterwards, remove `CUDA_VISIBLE_DEVICES=0` and the script will automatically use all available GPUs.

This demo automatically downloads the model and dataset from Hugging Face, totaling approximately 100 GB. Each GPU requires 128 GB of system RAM and 24 GB of GPU memory.

Supported translation languages are Chinese (zh), German (de), and Japanese (ja); pass the language code as the argument to the script.


```
CUDA_VISIBLE_DEVICES=0 bash examples/st_covost2/scripts/infer_enzh.sh zh
```
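
Once the single-GPU run completes successfully, the same command without the device restriction will use every visible GPU:

```
# Multi-GPU run: omitting CUDA_VISIBLE_DEVICES=0 lets the script use all GPUs.
bash examples/st_covost2/scripts/infer_enzh.sh zh
```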


## Download Model
We train only the Q-Former projector in this recipe.
@@ -46,31 +82,25 @@ You can find the test jsonl in "test_st.jsonl"
Here we design a three-step training process in which each stage is initialized from the checkpoint produced by the previous stage.
```
 #In this step, we perform ASR pretraining to acquire speech recognition capabilities.
-bash asr_pretrain.sh
+bash examples/st_covost2/scripts/asr_pretrain.sh
-#In this phase, we conduct multimodal machine translation training to enhance the final performance.
-bash mmt.sh
-#monolingual SRT training and multitask training.
-bash srt.sh
-bash zsrt.sh
+#monolingual MMT, SRT training and multitask training.
+#You can change the task type by modifying the value of **source** in the script.
+bash examples/st_covost2/scripts/all.sh
```
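
As a hypothetical illustration of the checkpoint chaining described above (the `++model_config.ckpt_path` override and the checkpoint path are assumptions, based on the ckpt_path field in asr_config.py and the hydra-style overrides this recipe uses; check each script for how it actually passes the previous checkpoint):

```
# Hypothetical sketch: start the next stage from the previous stage's
# checkpoint via the ckpt_path field defined in asr_config.py.
bash examples/st_covost2/scripts/all.sh ++model_config.ckpt_path=/path/to/asr_pretrain/checkpoint
```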


## Infer Stage
You can try our pre-trained model.

```
bash examples/st_covost2/scripts/infer_enzh.sh
```

## Citation
You can refer to the paper for more results.
```
-@article{du2024cot,
-  title={CoT-ST: Enhancing LLM-based Speech Translation with Multimodal Chain-of-Thought},
-  author={Yexing Du, Ziyang Ma, Yifan Yang, Keqi Deng, Xie Chen, Bo Yang, Yang Xiang, Ming Liu, Bing Qin},
-  journal={arXiv preprint arXiv:2409.19510},
-  year={2024}
+@misc{du2024cotstenhancingllmbasedspeech,
+  title={CoT-ST: Enhancing LLM-based Speech Translation with Multimodal Chain-of-Thought},
+  author={Yexing Du and Ziyang Ma and Yifan Yang and Keqi Deng and Xie Chen and Bo Yang and Yang Xiang and Ming Liu and Bing Qin},
+  year={2024},
+  eprint={2409.19510},
+  archivePrefix={arXiv},
+  primaryClass={cs.CL},
+  url={https://arxiv.org/abs/2409.19510},
 }
```
8 changes: 5 additions & 3 deletions examples/st_covost2/asr_config.py
100755 → 100644
@@ -24,8 +24,10 @@ class ModelConfig:
encoder_type: str = field(default="finetune", metadata={
"help": "whether model is only pretrained or finetuned, used for models such as hubert"
})
ckpt_path: Optional[str] = None
query_len: Optional[str] = None
qformer_layers: int = 8




@dataclass
@@ -93,7 +95,7 @@ class DataConfig:
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
train_split: str = "train"
test_split:str = "validation"
test_split:str = "test"
prompt: Optional[str] = None
data_path: Optional[str] = None
max_words: Optional[int] = None
@@ -127,7 +129,7 @@ class FSDPConfig:
class LogConfig:
use_wandb: bool = False
wandb_dir: str = "test_wandb"
-wandb_entity_name: str = "SLAM"
+wandb_entity_name: str = "sdinger"
wandb_project_name: str = "project_name"
wandb_exp_name: str = "exp_name"
log_file: str = "./test.log"
34 changes: 34 additions & 0 deletions examples/st_covost2/change_dir.py
@@ -0,0 +1,34 @@
import os
import json

# Input folder path
folder_path = ""

# Keyword replacement rule
old_keyword = ""  # keyword to be replaced
new_keyword = "/code_dir"  # replacement keyword

# Walk the folder and all of its subfolders
for root, _, files in os.walk(folder_path):
    for file_name in files:
        if file_name.endswith(".jsonl"):
            file_path = os.path.join(root, file_name)

            # Read and process the JSONL file
            with open(file_path, "r", encoding="utf-8") as file:
                lines = file.readlines()

            updated_lines = []
            for line in lines:
                data = json.loads(line)
                if "audio" in data and old_keyword in data["audio"]:
                    data["audio"] = data["audio"].replace(old_keyword, new_keyword)
                updated_lines.append(json.dumps(data, ensure_ascii=False))

            # Write the updated content back to the original file
            with open(file_path, "w", encoding="utf-8") as file:
                file.write("\n".join(updated_lines))

            print(f"Keyword replacement finished; changes written back to: {file_path}")

print("All files processed.")
2 changes: 1 addition & 1 deletion examples/st_covost2/conf/prompt.yaml
100755 → 100644
@@ -1,4 +1,4 @@
dataset_config:
# we put prompt here, because the hydra override in shell script only support a small subset of chars
# prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
prompt: "<en>"
# prompt: "<en>"