forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request espnet#5933 from pyf98/owsm-ctc-pr
Add OWSM-CTC
- Loading branch information
Showing
31 changed files
with
4,778 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
# OWSM-CTC v3.1 | ||
|
||
[OWSM-CTC](https://aclanthology.org/2024.acl-long.549/) is an encoder-only speech foundation model based on hierarchical multi-task self-conditioned CTC. | ||
This version is trained on 180k hours of public audio data for multilingual speech recognition, any-to-any speech translation, and language identification, which follows the design of the project, [Open Whisper-style Speech Model (OWSM)](https://arxiv.org/abs/2401.16658). | ||
|
||
## Data Preparation | ||
|
||
The training data follows the same format as the encoder-decoder OWSM v3.1, except that timestamps are removed from the `text` file. Please first follow the `egs2/owsm_v3.1/s2t1` recipe to prepare OWSM data, and then convert `text` into the new format by running `python local/convert_owsm_data.py` (the path to the BPE tokenizer needs to be modified to your path). | ||
|
||
### OWSM-CTC Data Format | ||
|
||
The prepared data directory contains the following files: | ||
|
||
``` | ||
dump/raw/train | ||
├── feats_type | ||
├── spk2utt | ||
├── text | ||
├── text.ctc | ||
├── text.prev | ||
├── utt2spk | ||
├── wav.scp | ||
``` | ||
|
||
`feats_type` has a single line of text, which should be automatically generated in the data preparation stage: | ||
``` | ||
raw | ||
``` | ||
|
||
`spk2utt` and `utt2spk` have the same meaning as the standard Kaldi recipes (see `asr1` recipe for example). Typically, the speaker information is not utilized. Hence, each utterance has a unique speaker ID which is simply its utterance ID. | ||
|
||
`wav.scp` also follows the standard Kaldi format. | ||
|
||
`text` contains the multitask reference (ASR or ST) with language and task tokens but without timestamps: | ||
|
||
``` | ||
AIDATATANG_200ZH_T0055G0013S0001_000000000_000003561_zho_asr <zho><asr> 今天什么日子 | ||
... | ||
GigaST_YOU0000009624_002208970_002218840_en_st_zh <eng><st_zho> 大会结束后,我们要求有兴趣进一步参与我们项目或进一步参与气候教育的学生站出来, | ||
... | ||
MLS_en_sikhreligion6_22_macauliffe_64kb_003555300_003571720_en_asr <eng><asr> it farid considered that faqiri or holiness consisted in four things namely to be blind to the faults of muhammadans to be deaf to slander to be dumb when evil speaking is suggested and to be lame when there is a desire to visit evil places | ||
... | ||
``` | ||
|
||
`text.ctc` contains the pure ASR reference: | ||
|
||
``` | ||
AIDATATANG_200ZH_T0055G0013S0001_000000000_000003561_zho_asr 今天什么日子 | ||
... | ||
CoVoST2_147d94ad8405722d5930a859295bfac7b925ccd40c587334d34f3ebd2668a70242240866e93907398f10b7f2265a4ddb82b5355eb21fe37993d04a69900df388-common_voice_en_19741894_000000000_000006270_en_st_ca He appointed military officers to most leading government positions. | ||
... | ||
``` | ||
|
||
`text.prev` contains the previous sentence that will be used as an additional prompt. If a sample does not have a prompt, then `<na>` is used. | ||
|
||
``` | ||
AIDATATANG_200ZH_T0055G0013S0001_000000000_000003561_zho_asr <na> | ||
... | ||
GigaST_YOU0000009624_002208970_002218840_en_st_zh 与员工和同事一起,这将有助于为事物创造空间,帮助为我们创造空间,一些掩护,尝试新事物, | ||
... | ||
``` | ||
|
||
## Pre-trained Model | ||
|
||
**IMPORTANT: Our model is trained on 16kHz audio with fixed duration 30s. When using the pre-trained model, please ensure the input speech is 16kHz and pad or truncate it to 30s.** | ||
|
||
The pre-trained model is available at: https://huggingface.co/espnet/owsm_ctc_v3.1_1B | ||
|
||
The model is trained with this config: [conf/train_s2t_multitask-ctc_ebf27_conv2d8_size1024.yaml](conf/train_s2t_multitask-ctc_ebf27_conv2d8_size1024.yaml) | ||
|
||
|
||
### Example script for short-form ASR/ST | ||
|
||
```python | ||
import librosa | ||
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch | ||
|
||
|
||
s2t = Speech2TextGreedySearch.from_pretrained( | ||
"espnet/owsm_ctc_v3.1_1B", | ||
device="cuda", | ||
generate_interctc_outputs=False, | ||
lang_sym='<eng>', | ||
task_sym='<asr>', | ||
) | ||
|
||
# NOTE: OWSM-CTC is trained on 16kHz audio with a fixed 30s duration. Please ensure your input has the correct sample rate; otherwise resample it to 16k before feeding it to the model | ||
speech, rate = librosa.load("xxx.wav", sr=16000) | ||
speech = librosa.util.fix_length(speech, size=(16000 * 30)) | ||
|
||
res = s2t(speech)[0] | ||
print(res) | ||
``` | ||
|
||
### Example script for long-form ASR/ST | ||
|
||
```python | ||
import soundfile as sf | ||
import torch | ||
from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch | ||
|
||
|
||
context_len_in_secs = 4 # left and right context when doing buffered inference | ||
batch_size = 32 # depends on the GPU memory | ||
s2t = Speech2TextGreedySearch.from_pretrained( | ||
"espnet/owsm_ctc_v3.1_1B", | ||
device='cuda' if torch.cuda.is_available() else 'cpu', | ||
generate_interctc_outputs=False, | ||
lang_sym='<eng>', | ||
task_sym='<asr>', | ||
) | ||
|
||
speech, rate = sf.read( | ||
"xxx.wav" | ||
) | ||
|
||
text = s2t.decode_long_batched_buffered( | ||
speech, | ||
batch_size=batch_size, | ||
context_len_in_secs=context_len_in_secs, | ||
) | ||
print(text) | ||
``` | ||
|
||
### Example for CTC forced alignment using `ctc-segmentation` | ||
|
||
CTC segmentation can be efficiently applied to audio of an arbitrary length. | ||
|
||
```python | ||
import soundfile as sf | ||
from espnet2.bin.s2t_ctc_align import CTCSegmentation | ||
from espnet_model_zoo.downloader import ModelDownloader | ||
|
||
|
||
## Please download model first | ||
d = ModelDownloader() | ||
downloaded = d.download_and_unpack("espnet/owsm_ctc_v3.1_1B") | ||
|
||
aligner = CTCSegmentation( | ||
**downloaded, | ||
fs=16000, | ||
ngpu=1, | ||
batch_size=16, # batched parallel decoding; reduce it if your GPU memory is smaller | ||
kaldi_style_text=True, | ||
time_stamps="fixed", | ||
lang_sym="<eng>", | ||
task_sym="<asr>", | ||
context_len_in_secs=2, # left and right context in buffered decoding | ||
) | ||
|
||
speech, rate = sf.read( | ||
"./test_utils/ctc_align_test.wav" | ||
) | ||
print(f"speech duration: {len(speech) / rate : .2f} seconds") | ||
text = """ | ||
utt1 THE SALE OF THE HOTELS | ||
utt2 IS PART OF HOLIDAY'S STRATEGY | ||
utt3 TO SELL OFF ASSETS | ||
utt4 AND CONCENTRATE ON PROPERTY MANAGEMENT | ||
""" | ||
|
||
segments = aligner(speech, text) | ||
print(segments) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ====== | ||
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...> | ||
# e.g. | ||
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB | ||
# | ||
# Options: | ||
# --time <time>: Limit the maximum time to execute. | ||
# --mem <mem>: Limit the maximum memory usage. | ||
# -–max-jobs-run <njob>: Limit the number parallel jobs. This is ignored for non-array jobs. | ||
# --num-threads <ngpu>: Specify the number of CPU core. | ||
# --gpu <ngpu>: Specify the number of GPU devices. | ||
# --config: Change the configuration file from default. | ||
# | ||
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs. | ||
# The left string of "=", i.e. "JOB", is replaced by <N>(Nth job) in the command and the log file name, | ||
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively. | ||
# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example. | ||
# | ||
# run.pl, queue.pl, slurm.pl, and ssh.pl have unified interface, not depending on its backend. | ||
# These options are mapping to specific options for each backend and | ||
# it is configured by "conf/queue.conf" and "conf/slurm.conf" by default. | ||
# If jobs failed, your configuration might be wrong for your environment. | ||
# | ||
# | ||
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl: | ||
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html | ||
# =========================================================~ | ||
|
||
|
||
# Select the backend used by run.sh from "local", "stdout", "sge", "slurm", or "ssh" | ||
cmd_backend='local' | ||
|
||
# Local machine, without any Job scheduling system | ||
if [ "${cmd_backend}" = local ]; then | ||
|
||
# The other usage | ||
export train_cmd="run.pl" | ||
# Used for "*_train.py": "--gpu" is appended optionally by run.sh | ||
export cuda_cmd="run.pl" | ||
# Used for "*_recog.py" | ||
export decode_cmd="run.pl" | ||
|
||
# Local machine logging to stdout and log file, without any Job scheduling system | ||
elif [ "${cmd_backend}" = stdout ]; then | ||
|
||
# The other usage | ||
export train_cmd="stdout.pl" | ||
# Used for "*_train.py": "--gpu" is appended optionally by run.sh | ||
export cuda_cmd="stdout.pl" | ||
# Used for "*_recog.py" | ||
export decode_cmd="stdout.pl" | ||
|
||
|
||
# "qsub" (Sun Grid Engine, or derivation of it) | ||
elif [ "${cmd_backend}" = sge ]; then | ||
# The default setting is written in conf/queue.conf. | ||
# You must change "-q g.q" for the "queue" for your environment. | ||
# To know the "queue" names, type "qhost -q" | ||
# Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler. | ||
|
||
export train_cmd="queue.pl" | ||
export cuda_cmd="queue.pl" | ||
export decode_cmd="queue.pl" | ||
|
||
|
||
# "qsub" (Torque/PBS.) | ||
elif [ "${cmd_backend}" = pbs ]; then | ||
# The default setting is written in conf/pbs.conf. | ||
|
||
export train_cmd="pbs.pl" | ||
export cuda_cmd="pbs.pl" | ||
export decode_cmd="pbs.pl" | ||
|
||
|
||
# "sbatch" (Slurm) | ||
elif [ "${cmd_backend}" = slurm ]; then | ||
# The default setting is written in conf/slurm.conf. | ||
# You must change "-p cpu" and "-p gpu" for the "partition" for your environment. | ||
# To know the "partion" names, type "sinfo". | ||
# You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*" | ||
# The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}". | ||
|
||
export train_cmd="slurm.pl" | ||
export cuda_cmd="slurm.pl" | ||
export decode_cmd="slurm.pl" | ||
|
||
elif [ "${cmd_backend}" = ssh ]; then | ||
# You have to create ".queue/machines" to specify the host to execute jobs. | ||
# e.g. .queue/machines | ||
# host1 | ||
# host2 | ||
# host3 | ||
# Assuming you can login them without any password, i.e. You have to set ssh keys. | ||
|
||
export train_cmd="ssh.pl" | ||
export cuda_cmd="ssh.pl" | ||
export decode_cmd="ssh.pl" | ||
|
||
# This is an example of specifying several unique options in the JHU CLSP cluster setup. | ||
# Users can modify/add their own command options according to their cluster environments. | ||
elif [ "${cmd_backend}" = jhu ]; then | ||
|
||
export train_cmd="queue.pl --mem 2G" | ||
export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/queue.conf" | ||
export decode_cmd="queue.pl --mem 4G" | ||
|
||
else | ||
echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2 | ||
return 1 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
beam_size: 1 | ||
penalty: 0.0 | ||
maxlenratio: 0.0 | ||
minlenratio: 0.0 | ||
lm_weight: 0.0 | ||
lang_sym: <eng> | ||
task_sym: <asr> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
--sample-frequency=16000 | ||
--num-mel-bins=80 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Default configuration | ||
command qsub -V -v PATH -S /bin/bash | ||
option name=* -N $0 | ||
option mem=* -l mem=$0 | ||
option mem=0 # Do not add anything to qsub_opts | ||
option num_threads=* -l ncpus=$0 | ||
option num_threads=1 # Do not add anything to qsub_opts | ||
option num_nodes=* -l nodes=$0:ppn=1 | ||
default gpu=0 | ||
option gpu=0 | ||
option gpu=* -l ngpus=$0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
--sample-frequency=16000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Default configuration | ||
command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* | ||
option name=* -N $0 | ||
option mem=* -l mem_free=$0,ram_free=$0 | ||
option mem=0 # Do not add anything to qsub_opts | ||
option num_threads=* -pe smp $0 | ||
option num_threads=1 # Do not add anything to qsub_opts | ||
option max_jobs_run=* -tc $0 | ||
option num_nodes=* -pe mpi $0 # You must set this PE as allocation_rule=1 | ||
default gpu=0 | ||
option gpu=0 | ||
option gpu=* -l gpu=$0 -q g.q |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Default configuration | ||
command sbatch --export=PATH | ||
option name=* --job-name $0 | ||
option time=* --time $0 | ||
option mem=* --mem-per-cpu $0 | ||
option mem=0 | ||
option num_threads=* --cpus-per-task $0 | ||
option num_threads=1 --cpus-per-task 1 | ||
option num_nodes=* --nodes $0 | ||
default gpu=0 | ||
option gpu=0 -p cpu | ||
option gpu=* -p gpu --gres=gpu:$0 -c $0 # Recommend allocating more CPU than, or equal to the number of GPU | ||
# note: the --max-jobs-run option is supported as a special case | ||
# by slurm.pl and you don't have to handle it in the config file. |
Oops, something went wrong.