Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Darshan7575 authored Jul 21, 2024
2 parents 4e244bc + f2431f4 commit ab1d0a6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
4 changes: 2 additions & 2 deletions espnet2/gan_svs/vits/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

import math
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -110,7 +110,7 @@ def __init__(
use_phoneme_predictor: bool = False,
expand_f0_method: str = "repeat",
# hubert
hubert_channels: int = 0,
hubert_channels: Union[int, None] = 0,
):
"""Initialize VITS generator module.
Expand Down
20 changes: 17 additions & 3 deletions espnet2/tasks/abs_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import sys
import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -1196,10 +1197,23 @@ def main(

# The following block is copied from:
# https://github.com/pytorch/pytorch/blob/master/torch/multiprocessing/spawn.py
error_queues = []
error_files = []
processes = []
mp = torch.multiprocessing.get_context("spawn")
for i in range(args.ngpu):

# Each process is assigned a file to write tracebacks to. We
# use the file being non-empty to indicate an exception
# occurred (vs an expected shutdown). Note: this previously
# used a multiprocessing.Queue but that can be prone to
# deadlocks, so we went with a simpler solution for a one-shot
# message between processes.
tf = tempfile.NamedTemporaryFile(
prefix="pytorch-errorfile-", suffix=".pickle", delete=False
)
tf.close()
os.unlink(tf.name)

# Copy args
local_args = argparse.Namespace(**vars(args))

Expand All @@ -1214,9 +1228,9 @@ def main(
)
process.start()
processes.append(process)
error_queues.append(mp.SimpleQueue())
error_files.append(tf.name)
# Loop on join until it returns True or raises an exception.
while not ProcessContext(processes, error_queues).join():
while not ProcessContext(processes, error_files).join():
pass

@classmethod
Expand Down

0 comments on commit ab1d0a6

Please sign in to comment.