Skip to content

Commit

Permalink
no need to use device in VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
hkchengrex committed Dec 22, 2024
1 parent f451794 commit 5f0e921
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions mmaudio/ext/autoencoder/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,12 @@ def __init__(
):
super().__init__()

device = "cpu"
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"

if data_dim == 80:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32).to(device))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32).to(device))
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
elif data_dim == 128:
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32).to(device))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32).to(device))
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))

self.data_mean = self.data_mean.view(1, -1, 1)
self.data_std = self.data_std.view(1, -1, 1)
Expand Down

0 comments on commit 5f0e921

Please sign in to comment.