diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua index 7fa440267..dd77dde00 100644 --- a/SpatialSubtractiveNormalization.lua +++ b/SpatialSubtractiveNormalization.lua @@ -109,8 +109,10 @@ function SpatialSubtractiveNormalization:updateGradInput(input, gradOutput) end function SpatialSubtractiveNormalization:clearState() - if self.ones then self.ones:set() end - if self._coef then self._coef:set() end - self.meanestimator:clearState() - return parent.clearState(self) + nn.utils.clear(self, '_inpsz', 'ones', '_coef') + self.coef = torch.Tensor(1,1,1) + self.subtractor:clearState() + self.divider:clearState() + self.meanestimator:clearState() + return parent.clearState(self) end