Skip to content

Commit

Permalink
Merge pull request espnet#5825 from Emrys365/urgent_recipe
Browse files Browse the repository at this point in the history
Update of SE functions
  • Loading branch information
sw005320 authored Jul 23, 2024
2 parents b8b31c0 + a2aa167 commit e0dd1cf
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 39 deletions.
25 changes: 23 additions & 2 deletions egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ extra_wav_list= # Extra list of scp files for wav formatting
init_param=

# Enhancement related
inference_args="--normalize_output_wav true"
inference_args="--normalize_output_wav true --output_format wav"
inference_model=valid.loss.ave.pth
download_model=

Expand Down Expand Up @@ -172,6 +172,7 @@ Options:
--inference_args # Arguments for enhancement in the inference stage (default="${inference_args}")
--inference_model # Enhancement model path for inference (default="${inference_model}").
--inference_enh_config # Configuration file for overwriting some model attributes during SE inference. (default="${inference_enh_config}")
--download_model # Download a model from Model Zoo and use it for inference (default="${download_model}").
# Evaluation related
--scoring_protocol # Metrics to be used for scoring (default="${scoring_protocol}")
Expand Down Expand Up @@ -222,7 +223,7 @@ fi
[ -z "${test_sets}" ] && { log "${help_message}"; log "Error: --test_sets is required"; exit 2; };

# Extra files for enhancement process
utt_extra_files="utt2category"
utt_extra_files="utt2category utt2fs"

data_feats=${dumpdir}/raw

Expand Down Expand Up @@ -800,6 +801,26 @@ else
fi



if [ -n "${download_model}" ]; then
log "Use ${download_model} for inference and scoring"
enh_exp="${expdir}/${download_model}"
mkdir -p "${enh_exp}"

# If the model already exists, you can skip downloading
espnet_model_zoo_download --unpack true "${download_model}" > "${enh_exp}/config.txt"

# Get the path of each file
_enh_model_file=$(<"${enh_exp}/config.txt" sed -e "s/.*'enh_model_file': '\([^']*\)'.*$/\1/")
_enh_train_config=$(<"${enh_exp}/config.txt" sed -e "s/.*'enh_train_config': '\([^']*\)'.*$/\1/")

# Create symbolic links
ln -sf "${_enh_model_file}" "${enh_exp}"
ln -sf "${_enh_train_config}" "${enh_exp}"
inference_model=$(basename "${_enh_model_file}")
fi


if ! "${skip_eval}"; then
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
log "Stage 7: Enhance Speech: training_dir=${enh_exp}"
Expand Down
2 changes: 1 addition & 1 deletion egs2/TEMPLATE/enh_asr1/enh_asr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ inference_tag= # Suffix to the result dir for decoding.
inference_config= # Config for decoding.
asr_inference_args= # Arguments for decoding, e.g., "--lm_weight 0.1".
# Note that it will overwrite args in inference config.
enh_inference_args="--normalize_output_wav true"
enh_inference_args="--normalize_output_wav true --output_format wav"
inference_lm=valid.loss.ave.pth # Language model path for decoding.
inference_ngram=${ngram_num}gram.bin
inference_enh_asr_model=valid.acc.ave.pth # ASR model path for decoding.
Expand Down
13 changes: 12 additions & 1 deletion espnet2/bin/enh_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ def inference(
normalize_segment_scale: bool,
show_progressbar: bool,
ref_channel: Optional[int],
output_format: str,
normalize_output_wav: bool,
enh_s2t_task: bool,
):
Expand Down Expand Up @@ -542,7 +543,11 @@ def inference(
writers = []
for i in range(separate_speech.num_spk):
writers.append(
SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
SoundScpWriter(
f"{output_dir}/wavs/{i + 1}",
f"{output_dir}/spk{i + 1}.scp",
format=output_format,
)
)

import tqdm
Expand Down Expand Up @@ -623,6 +628,12 @@ def get_parser():
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

group = parser.add_argument_group("Output data related")
group.add_argument(
"--output_format",
type=str,
default="wav",
help="Output format for the separated speech",
)
group.add_argument(
"--normalize_output_wav",
type=str2bool,
Expand Down
13 changes: 12 additions & 1 deletion espnet2/bin/enh_inference_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def inference(
inference_config: Optional[str],
allow_variable_data_keys: bool,
ref_channel: Optional[int],
output_format: str,
enh_s2t_task: bool,
):
if batch_size > 1:
Expand Down Expand Up @@ -290,7 +291,11 @@ def inference(
writers = []
for i in range(separate_speech.num_spk):
writers.append(
SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
SoundScpWriter(
f"{output_dir}/wavs/{i + 1}",
f"{output_dir}/spk{i + 1}.scp",
format=output_format,
)
)

import tqdm
Expand Down Expand Up @@ -389,6 +394,12 @@ def get_parser():
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

group = parser.add_argument_group("Output data related")
group.add_argument(
"--output_format",
type=str,
default="wav",
help="Output format for the separated speech",
)

group = parser.add_argument_group("The model configuration related")
group.add_argument(
Expand Down
13 changes: 12 additions & 1 deletion espnet2/bin/enh_tse_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def inference(
normalize_segment_scale: bool,
show_progressbar: bool,
ref_channel: Optional[int],
output_format: str,
normalize_output_wav: bool,
):
if batch_size > 1:
Expand Down Expand Up @@ -513,7 +514,11 @@ def inference(
writers = []
for i in range(separate_speech.num_spk):
writers.append(
SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
SoundScpWriter(
f"{output_dir}/wavs/{i + 1}",
f"{output_dir}/spk{i + 1}.scp",
format=output_format,
)
)

for i, (keys, batch) in enumerate(loader):
Expand Down Expand Up @@ -590,6 +595,12 @@ def get_parser():
default=False,
help="Whether to normalize the predicted wav to [-1~1]",
)
group.add_argument(
"--output_format",
type=str,
default="wav",
help="Output format for the separated speech",
)

group = parser.add_argument_group("The model configuration related")
group.add_argument(
Expand Down
Loading

0 comments on commit e0dd1cf

Please sign in to comment.