[WIP] Autonormal encoder #2849
base: dev
Conversation
```python
self._cond_indep_stacks[name] = site["cond_indep_stack"]

# add linear layer for locs and scales
param_dim = (self.n_hidden, self.amortised_plate_sites["sites"][name])
```
One of the main assumptions at the moment is that encoded variables are 2D tensors of shape `(plate subsample_size, i.e. batch size, self.amortised_plate_sites["sites"][name])` - but I guess that the shape could be guessed automatically; I just did not think that through and do not have applications where variables are more than 2D.
I think it's more important to get an initial version merged quickly than to make something fully general right away. So WDYT about simply adding assertions or `NotImplementedError("not yet supported")` checks for your current assumptions?
Also feel free to start the class docstring with EXPERIMENTAL and add a `.. warning:: Interface may change` to give you/us room to slightly change the interface later in case the fully-general version needs slight changes.
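A minimal sketch of what that could look like (the class body and the check method are hypothetical illustrations, not code from this PR; `AutoGuide` is Pyro's autoguide base class):

```python
from pyro.infer.autoguide import AutoGuide


class AutoNormalEncoder(AutoGuide):
    """EXPERIMENTAL AutoNormal guide that amortises local sites with an encoder.

    .. warning:: Interface may change.
    """

    def _check_encoded_site(self, name, site):
        # guard the current assumption: encoded variables are 2D tensors
        if site["value"].dim() != 2:
            raise NotImplementedError("not yet supported: non-2D encoded sites")
```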
Looks like I accidentally included changes to …
No worries, #2837 should merge soon. We often add "Blocked by #xxx" in the PR description to denote merge order dependencies.
```python
init_param = torch.normal(
    torch.full(size=param_dim, fill_value=0.0, device=site["value"].device),
    torch.full(
        size=param_dim,
        fill_value=(1 * self.init_param_scale) / np.sqrt(self.n_hidden),
        device=site["value"].device,
    ),
)
```
If I use `torch.normal` rather than `numpy.random.normal`, I get this warning:

```
/scvi-tools/scvi/external/cell2location/autoguide.py:218: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  init_param, device=site["value"].device, requires_grad=True
```

I also get different results after training the model.
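For reference, the warning is easy to reproduce standalone (a minimal sketch, independent of this PR's code):

```python
import torch

init_param = torch.normal(torch.zeros(3, 4), torch.ones(3, 4))

# wrapping an existing tensor in torch.tensor() copy-constructs it and warns
param = torch.tensor(init_param, requires_grad=True)  # emits the UserWarning

# warning-free equivalents
param = init_param.clone().detach().requires_grad_(True)
param = torch.as_tensor(init_param)  # returns the same tensor, no copy, no warning
```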
Numpy alternative:

```python
init_param = np.random.normal(
    np.zeros(param_dim),
    (np.ones(param_dim) * self.init_param_scale) / np.sqrt(self.n_hidden),
).astype("float32")
```
What is missing at the moment is a simple encoder NN class. @fritzo @martinjankowiak is there anything already defined in pyro, or a good example?
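For illustration, a minimal encoder sketch (a hypothetical class, not an existing Pyro module) could be a single-hidden-layer MLP producing the shared hidden representation that the per-site `hidden2locs` / `hidden2scales` layers consume:

```python
import torch.nn as nn


class SimpleEncoder(nn.Module):
    """Map data of width n_in to a shared hidden representation of width n_hidden."""

    def __init__(self, n_in, n_hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_in, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)
```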
```python
self.hidden2locs,
name,
PyroParam(
    torch.tensor(
```
I believe that `UserWarning: To copy construct ...` is actually due to this line. I believe you can fix that by using `as_tensor`:

```diff
 PyroParam(
-    torch.tensor(
+    torch.as_tensor(
         init_param, ...
```
```python
self.hidden2scales,
name,
PyroParam(
    torch.tensor(
```
ditto: `torch.tensor` -> `torch.as_tensor`
Existing code includes: … Also feel free to add something to a new file in `pyro/nn/`.
I started thinking about tests and realised that this testing class also needs a model with local variables and some data. @fritzo do you have any good example in mind (ideally one already implemented in tests)? One alternative would be to use the scVI pyro test regression model, and write simple training and posterior sampling code to test this class.

Actually, for posterior sampling and computing medians/quantiles, concatenating encoded local variables along the plate dimension is non-trivial and a subject of this scVI PR. It could be good if the AutoNormalEncoder class provided a method to merge quantiles, medians and posterior samples along the plate dimension. WDYT?

First, guess the plate dimension:

```python
def _guess_obs_plate_sites(self, args, kwargs):
    """
    Automatically guess which model sites belong to the observation/minibatch plate.

    This function requires the minibatch plate name specified in
    `self.amortised_plate_sites["name"]`.

    Parameters
    ----------
    args
        Arguments to the model.
    kwargs
        Keyword arguments to the model.

    Returns
    -------
    Dictionary with keys corresponding to site names and values to plate dimension.
    """
    plate_name = self.amortised_plate_sites["name"]
    # find plate dimension
    trace = poutine.trace(self.model).get_trace(*args, **kwargs)
    obs_plate = {
        name: site["cond_indep_stack"][0].dim
        for name, site in trace.nodes.items()
        if site["type"] == "sample"
        if any(f.name == plate_name for f in site["cond_indep_stack"])
    }
    return obs_plate
```

Then concatenate samples in that dimension:

```python
i = 0
for args, kwargs in dataloader:
    if i == 0:
        samples = guide.quantiles(0.5, *args, **kwargs)
        obs_plate_sites = guide._guess_obs_plate_sites(args, kwargs)
        obs_plate_dim = list(obs_plate_sites.values())[0]
    else:
        samples_ = guide.quantiles(0.5, *args, **kwargs)
        samples = {
            k: np.array(
                [
                    np.concatenate(
                        [samples[k][j], samples_[k][j]],
                        axis=obs_plate_dim,
                    )
                    for j in range(len(samples[k]))  # for each sample (in dimension 0)
                ]
            )
            for k in samples.keys()  # for each variable
        }
    i = i + 1
```
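A minimal sketch of such a merging helper (hypothetical method name and signature), folding the loop above into the guide; since the plate dims from `cond_indep_stack` are negative, an extra leading sample dimension is handled automatically:

```python
import numpy as np


def merge_samples(self, samples, samples_, obs_plate_dim):
    """Concatenate two dicts of posterior samples/quantiles along the plate dim."""
    return {
        k: np.concatenate(
            [np.asarray(samples[k]), np.asarray(samples_[k])],
            axis=obs_plate_dim,  # negative axis indexes from the right
        )
        for k in samples
    }
```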
I extended this class further to enable more complex architectures (see below) and a different number of hidden nodes for each model site.

Code is here for now: https://github.com/vitkl/scvi-tools/blob/pyro-cell2location/scvi/external/cell2location/autoguide.py. Still looking for good example data for tests.
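The per-site hidden sizes could be configured along these lines (a hypothetical sketch with made-up site names, not the actual code in that branch):

```python
# hypothetical: replace the single self.n_hidden with a per-site mapping
n_hidden = {"per_cell_weights": 256, "detection_noise": 32}
amortised_sites = {"per_cell_weights": 10, "detection_noise": 1}

# each site's linear layer then gets its own (hidden, width) shape
param_dims = {
    name: (n_hidden[name], width) for name, width in amortised_sites.items()
}
```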
@martinjankowiak @fritzo @eb8680 following our conversation here scverse/scvi-tools#930 (review), creating this PR to discuss adding an AutoNormal encoder class.

This class needs users to specify an encoder network class, a data transformation, and an `amortised_plate_sites` dictionary which tells which variables are amortised, which model args/kwargs need to be passed to the encoder, and which plate the variables belong to. One of the main assumptions at the moment is that encoded variables are 2D tensors - but I guess that the shape can be guessed automatically; I just did not think that through and do not have applications where variables are more than 2D.
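For illustration, such a dictionary might look like the following (only the `"name"` and `"sites"` keys are visible in this PR's code; the key for encoder inputs and all concrete values are hypothetical):

```python
amortised_plate_sites = {
    "name": "obs_plate",   # name of the observation/minibatch plate
    "input": [0],          # which model args/kwargs feed the encoder (hypothetical key)
    "sites": {             # amortised site name -> width of its 2D tensor
        "per_cell_weights": 10,
        "detection_noise": 1,
    },
}
```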