Currently our loss functions are coded as straightforward functions working on torch inputs. Some loss functions take additional parameters that stay fixed for a given run, for example,
def entropy(cube, prior_intensity): ...
where prior_intensity is a reference cube used in the evaluation of the actual target, cube.
This works fine, but can get cumbersome, especially when we are interfacing with a bunch of loss functions all at once (as in a cross-validation loop).
Can we take a page from the way PyTorch designs its loss functions and make most if not all of our loss functions classes that inherit from torch.nn.Module? This would give us objects that can be instantiated with sensible default parameter values and then called with a uniform signature. For example, see nn.MSELoss.
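For concreteness, a minimal sketch of what such a class might look like. EntropyLoss and the relative-entropy expression inside forward are placeholders for illustration, not the repo's actual definitions:

```python
import torch
from torch import nn

class EntropyLoss(nn.Module):
    """Hypothetical class-based wrapper; name and math are illustrative."""

    def __init__(self, prior_intensity):
        super().__init__()
        # The reference cube is fixed at instantiation, much like
        # reduction= is fixed when constructing nn.MSELoss.
        self.prior_intensity = prior_intensity

    def forward(self, cube):
        # Stand-in for the existing functional entropy(cube, prior_intensity);
        # a simple relative-entropy expression is used for concreteness.
        return torch.sum(cube * torch.log(cube / self.prior_intensity))

# Usage mirrors nn.MSELoss: configure once, then call like a function.
prior = torch.ones(64, 64)
loss_fn = EntropyLoss(prior)
loss = loss_fn(torch.rand(64, 64) + 1e-6)
```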
This may have additional benefits (with reduction, say) if we think about batching and applications to multiple GPUs.
Does it also make sense to include the lambda terms as parameters of the loss object? @kadri-nizam do you have any thoughts from experience w/ your VAE architecture?
I think making versions of the loss functions as torch modules is a great idea. I'd still keep the functional definitions separate and import them when defining the modules, as that is more flexible (for developing and testing). I believe this is how PyTorch implements it: the logic for the losses lives in torch.nn.functional, which gets used in the nn.Module versions.
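For reference, PyTorch's own split looks roughly like this (simplified; the real nn.MSELoss inherits from an internal _Loss base class):

```python
import torch.nn.functional as F
from torch import nn

class MSELoss(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        self.reduction = reduction

    def forward(self, input, target):
        # The module is a thin wrapper; the actual math lives in the
        # functional API (torch.nn.functional.mse_loss).
        return F.mse_loss(input, target, reduction=self.reduction)
```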
> Does it also make sense to include the lambda terms as parameters of the loss object?
The lambda parameter doesn't change throughout the optimization run, right? If so, then I'd pass it as an argument at instantiation.
Yes, the 'lambda' parameter remains fixed during an optimization run. In a cross-validation loop you'd want to try several different lambda values, so in that situation I guess you would need to re-instantiate the loss functions.
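To illustrate, a sketch of that re-instantiation pattern. ScaledLoss and the quadratic penalty are stand-ins, not proposed names:

```python
import torch
from torch import nn

class ScaledLoss(nn.Module):
    """Stand-in for any loss module whose lambda is fixed at instantiation."""

    def __init__(self, lam):
        super().__init__()
        self.lam = lam

    def forward(self, cube):
        return self.lam * torch.sum(cube**2)  # placeholder penalty

# Each lambda in the cross-validation sweep gets a fresh instance; the
# modules hold no learned state, so re-instantiating costs essentially nothing.
for lam in (1e-3, 1e-2, 1e-1):
    loss_fn = ScaledLoss(lam)
    # ... run the optimization loop with loss_fn and record the CV score ...
```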
The purpose of having the functional static method is to allow for easier testing -- just call TV.functional instead of having to instantiate the class first.
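Something like the following, assuming a TV module with the math in a staticmethod (the exact signature in the fork may differ, and the anisotropic total-variation expression is a placeholder for the repo's actual definition):

```python
import torch
from torch import nn

class TV(nn.Module):
    def __init__(self, lam=1.0):
        super().__init__()
        self.lam = lam

    @staticmethod
    def functional(cube, lam=1.0):
        # All of the math lives here, so tests can call TV.functional(cube)
        # directly without instantiating the module.
        return lam * (
            torch.sum(torch.abs(torch.diff(cube, dim=-1)))
            + torch.sum(torch.abs(torch.diff(cube, dim=-2)))
        )

    def forward(self, cube):
        return TV.functional(cube, self.lam)
```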
I defined an abstract base class in my fork to specify requirements that a loss module in the repo must meet, but this is optional.
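An assumed sketch of such a base class; the fork's actual requirements may differ:

```python
from abc import ABC, abstractmethod
from torch import nn

class Loss(nn.Module, ABC):
    """Abstract base: every concrete loss must provide both entry points."""

    @staticmethod
    @abstractmethod
    def functional(*args, **kwargs):
        """The loss math, callable without instantiating the module."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        """The nn.Module entry point used during optimization."""
```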