diff --git a/mmaudio/ext/autoencoder/vae.py b/mmaudio/ext/autoencoder/vae.py index cac769e..c5c9ba3 100644 --- a/mmaudio/ext/autoencoder/vae.py +++ b/mmaudio/ext/autoencoder/vae.py @@ -74,18 +74,12 @@ def __init__( ): super().__init__() - device = "cpu" - if torch.backends.mps.is_available(): - device = "mps" - elif torch.cuda.is_available(): - device = "cuda" - if data_dim == 80: - self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32).to(device)) - self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32).to(device)) + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32)) elif data_dim == 128: - self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32).to(device)) - self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32).to(device)) + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32)) self.data_mean = self.data_mean.view(1, -1, 1) self.data_std = self.data_std.view(1, -1, 1)