Skip to content

Commit

Permalink
update documentation and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
oliviaweng committed May 9, 2024
1 parent da8a0a4 commit e7baea1
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions fkeras/metrics/stat_fi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,27 @@ def get_supported_layer_indices(self):
if self.model.layers[i].__class__.__name__ in SUPPORTED_LAYERS:
supported_indices.append(running_idx)
running_idx += self.model.layers[i].trainable_variables.__len__()
return supported_indices
return supported_indices

def get_params_and_quantizers(self):
"""
Compute the Hessian vector product of Hv, where
H is the Hessian of the loss function with respect to the model parameters
v is a vector of the same size as the model parameters
Based on: https://github.com/tensorflow/tensorflow/blob/47f0e99c1918f68daa84bd4cac1b6011b2942dac/tensorflow/python/eager/benchmarks/resnet50/hvp_test.py#L62
Return tuple (list of parameters layer-wise, list of quantizers layer-wise),
e.g., ([param1, param2, ...], [quantizer1, quantizer2, ...]) for layers
1 and 2.
"""

# Compute the gradients of the loss function with respect to the model parameters
params = [
v.numpy()
for i in self.layer_indices if self.model.layers[i].__class__.__name__ in SUPPORTED_LAYERS
for i in self.layer_indices
if self.model.layers[i].__class__.__name__ in SUPPORTED_LAYERS
for v in self.model.layers[i].trainable_variables
]

quantizers = [
self.model.layers[i].kernel_quantizer_internal
for i in self.layer_indices if self.model.layers[i].__class__.__name__ in SUPPORTED_LAYERS
for i in self.layer_indices
if self.model.layers[i].__class__.__name__ in SUPPORTED_LAYERS
]

return np.array(params, dtype='object'), np.array(quantizers, dtype='object')
return np.array(params, dtype="object"), np.array(quantizers, dtype="object")

0 comments on commit e7baea1

Please sign in to comment.