Optimized memory usage and speed for covar type "full" #23
base: master
Conversation
Very nice work, and thank you for putting so much thought into this! Impressive speed-ups! 👍
Left a couple of comments (apologies for the delay!). In particular, I'm curious to hear your ideas on whether we should move the optimizations for covariance_type=full somewhere else, as this could benefit readability w.r.t. the underlying EM mechanism.
@@ -0,0 +1,39 @@
+ # Benchmark
Awesome results, thanks for sharing these! Before merging with master, I would suggest removing benchmark.md.
@@ -1,9 +1,11 @@
  import torch
  import numpy as np
+ import math
gmm.py:5 already imports from math, so it'd make sense to either replace all occurrences of pi with math.pi or import ceil alongside it.
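A minimal sketch of the second option (assuming ceil is the only name the new code needs from math):

```python
# Extend the existing import on gmm.py:5 instead of adding a separate `import math`.
from math import ceil, pi
```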
  from math import pi
  from scipy.special import logsumexp
- from utils import calculate_matmul, calculate_matmul_n_times
+ from utils import calculate_matmul, calculate_matmul_n_times, find_optimal_splits
+ from tqdm import tqdm
I'd recommend removing this to keep the repository light on dependencies — users that require this functionality can always add it.
    return check_available_ram(device) >= size


def find_optimal_splits(n, get_required_memory, device="cpu", safe_mode=True):
safe_mode doesn't seem to get passed on to will_it_fit.
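A minimal sketch of how safe_mode could be threaded through to will_it_fit; the loop body and the check_available_ram stub are assumptions based only on the names and signatures shown in the diff above, not the PR's actual implementation:

```python
import math

def check_available_ram(device="cpu"):
    # Placeholder stand-in: the real helper in utils.py queries free system/GPU memory.
    return 8 * 1024 ** 3  # e.g. 8 GiB

def will_it_fit(size, device="cpu", safe_mode=True):
    # From the diff above; accepting safe_mode here is an assumption.
    return check_available_ram(device) >= size

def find_optimal_splits(n, get_required_memory, device="cpu", safe_mode=True):
    # Assumed behaviour: increase the number of splits until the estimated
    # per-chunk memory fits into the available memory.
    n_splits = 1
    while n_splits < n:
        chunk_size = math.ceil(n / n_splits)
        # The reviewer's point: forward safe_mode instead of dropping it.
        if will_it_fit(get_required_memory(chunk_size), device, safe_mode=safe_mode):
            break
        n_splits *= 2
    return n_splits
```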
@@ -188,7 +203,8 @@ def predict(self, x, probs=False):
      """
      x = self.check_size(x)

-     weighted_log_prob = self._estimate_log_prob(x) + torch.log(self.pi)
+     weighted_log_prob = self._estimate_log_prob(x)
+     weighted_log_prob.add_(torch.log(self.pi))
While carrying this out in-place preserves memory, spreading this across two lines here and in lines 369 and 466 decreases readability somewhat. Alternatively, I reckon this could be moved into _estimate_log_prob.
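A sketch of that alternative (the helper name is hypothetical, not code from the PR): do the in-place log(pi) addition in one place so each call site stays a one-liner:

```python
import torch

class GaussianMixture(torch.nn.Module):
    # ...existing module code...

    def _estimate_weighted_log_prob(self, x):
        # Hypothetical helper: keep the in-place log(pi) addition here so
        # predict, _e_step and __score stay single-line call sites.
        weighted_log_prob = self._estimate_log_prob(x)
        weighted_log_prob.add_(torch.log(self.pi))
        return weighted_log_prob
```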
      log_det = self._calculate_log_det(precision) #[K, 1]

      x_mu_T_precision_x_mu = torch.empty(N, K, 1, device=x.device, dtype=x.dtype)
Unless there are reservations/concerns, I would consider moving this into its own utility function, in the interest of preserving readability of the code (happy to take care of this once it has been merged).
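One possible shape for such a utility, as a sketch under assumptions (x is (N, 1, D), mu is (1, K, D), precision is (1, K, D, D); the chunking over N from the PR is omitted, and the name is hypothetical):

```python
import torch

def calculate_x_mu_T_precision_x_mu(x, mu, precision):
    # Hypothetical utility: computes (x - mu)^T P (x - mu) for every
    # sample/component pair as a (N, K, 1) tensor, using the
    # (K, N, D) x (K, D, D) batched matmul described in this PR.
    x_mu = (x - mu).permute(1, 0, 2)                # (K, N, D)
    x_mu_P = torch.bmm(x_mu, precision.squeeze(0))  # (K, N, D)
    out = (x_mu_P * x_mu).sum(-1)                   # (K, N)
    return out.permute(1, 0).unsqueeze(-1)          # (N, K, 1)
```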
      eps = (torch.eye(self.n_features) * self.eps).to(x.device)
-     var = torch.sum((x - mu).unsqueeze(-1).matmul((x - mu).unsqueeze(-2)) * resp.unsqueeze(-1), dim=0,
-                     keepdim=True) / torch.sum(resp, dim=0, keepdim=True).unsqueeze(-1) + eps
+     var = torch.empty(1, K, D, D, device=x.device, dtype=resp.dtype)
Nice! 👍 Same thought as before, however: given the additional complexity that's introduced here, it might make sense to define these optimizations in some other place.
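If these optimizations do get moved out, the outer-product-as-matmul rewrite described in the PR could live in a helper along these lines (a sketch under the same shape assumptions as above: x is (N, 1, D), mu is (1, K, D), resp is (N, K, 1), eps is a (D, D) diagonal jitter; chunking omitted):

```python
import torch

def weighted_covariances(x, mu, resp, eps):
    # Hypothetical helper: per-component weighted covariance
    #   var_k = sum_n resp[n, k] (x_n - mu_k)(x_n - mu_k)^T / sum_n resp[n, k] + eps
    # written as one batched matmul instead of an explicit outer product + sum.
    x_mu = (x - mu).permute(1, 0, 2)                   # (K, N, D)
    weighted = x_mu * resp.permute(1, 0, 2)            # (K, N, D), resp broadcast over D
    var = torch.bmm(weighted.transpose(-2, -1), x_mu)  # (K, D, D)
    var = var / resp.sum(dim=0).unsqueeze(-1)          # normalize per component
    return var.unsqueeze(0) + eps                      # (1, K, D, D)
```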
      covariance_type: str
      eps: float
      init_params: str
+     covariance_data_type: str or torch.dtype
Since mu is getting matched against this type, might as well go ahead and introduce this as dtype altogether, right?
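A small sketch of that suggestion; the helper name and mapping are assumptions, not code from this PR:

```python
import torch

# Hypothetical normalization: map the string argument to a torch.dtype once,
# so later comparisons and .to(...) calls use the dtype directly.
_COVAR_DTYPES = {"float": torch.float32, "double": torch.float64}

def as_covariance_dtype(covariance_data_type):
    if isinstance(covariance_data_type, torch.dtype):
        return covariance_data_type
    return _COVAR_DTYPES[covariance_data_type]
```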
@@ -15,30 +17,31 @@ class GaussianMixture(torch.nn.Module):
      probabilities are shaped (n, k, 1) if they relate to an individual sample,
      or (1, k, 1) if they assign membership probabilities to one of the mixture components.
      """
-     def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None):
+     def __init__(self, n_components, n_features, covariance_type="full", eps=1.e-6, init_params="kmeans", mu_init=None, var_init=None, covariance_data_type="double"):
Any reservations against going with "float" as the default type (matches the torch.Tensor default)?
      log_2pi = d * np.log(2. * pi)

      log_det = self._calculate_log_det(precision)
+     x = x.to(var.dtype)
Since self.covariance_data_type has been allocated, maybe use that instead?
Improved speed and memory usage with the following optimizations:

- The (N, K, 1, D) * (1, K, D, D) matmul at line 275 is replaced with an equivalent (K, N, D) * (K, D, D) matmul. (N, K, 1, D) * (1, K, D, D) is interpreted by cuBLAS as a batched matrix-vector product, while (K, N, D) * (K, D, D) is a batched matrix-matrix product, which is more efficient on GPUs (see the sketch right after this list).
- In two consecutive iterations of fit, _estimate_log_prob was being called twice with the same input, in _e_step and __score. Now weighted_log_probs is only computed once, in __score of the previous iteration, and cached to be reused in _e_step of the next iteration.
- At line 342, mu was originally obtained by element-wise multiplication and summation, which is now simplified to a matmul.
- At line 346, the batched vector outer product followed by summation is rewritten as a single batched matmul, which is more efficient on GPUs.
- Computations in _m_step and _estimate_log_prob are split into smaller "chunks" in order to prevent OOM as much as possible.
- Added an option to choose the dtype of the covariance matrix. torch.linalg.eigvals is used to compute log_det if covariance_data_type = torch.float, otherwise the Cholesky decomposition is used (a sketch of both paths is given at the end of this description).
- Replaced some of the tensor-scalar and tensor-tensor additions/multiplications with their in-place counterparts to reduce unnecessary memory allocation.
Benchmark results: see benchmark.md in this PR.
Remaining issues:

- When covariance_data_type = "float" and both n_components and n_features are large, covar contains NaN.
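As referenced in the optimization list above, a minimal sketch of the two log_det paths (eigvals for the float path, Cholesky otherwise), assuming precision is a (K, D, D) batch of symmetric positive-definite matrices; this is an illustration, not the PR's exact code:

```python
import torch

def calculate_log_det(precision, use_eigvals):
    # precision: (K, D, D), assumed symmetric positive-definite.
    if use_eigvals:
        # float path: general eigenvalues are complex; for a symmetric real
        # matrix the imaginary parts are ~0, so only the real parts are used.
        evals = torch.linalg.eigvals(precision)      # (K, D), complex
        log_det = torch.log(evals.real).sum(-1)      # (K,)
    else:
        # double path: log|P| = 2 * sum(log(diag(L))) with P = L L^T
        chol = torch.linalg.cholesky(precision)      # (K, D, D), lower triangular
        log_det = 2.0 * torch.log(torch.diagonal(chol, dim1=-2, dim2=-1)).sum(-1)
    return log_det.unsqueeze(-1)                     # (K, 1)
```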