
Commit

Added normalized l2 loss.
HuFY-dev committed Apr 1, 2024
1 parent 2660f57 commit e1943c3
Showing 5 changed files with 45 additions and 5 deletions.
4 changes: 3 additions & 1 deletion sparse_autoencoder/autoencoder/lightning.py
@@ -46,6 +46,8 @@ class LitSparseAutoencoderConfig(SparseAutoencoderConfig):
resample_loss_dataset_size: PositiveInt = 819200

resample_threshold_is_dead_portion_fires: NonNegativeFloat = 0.0

normalize_by_input_norm: bool = False

def model_post_init(self, __context: Any) -> None: # noqa: ANN401
"""Model post init validation.
@@ -91,7 +93,7 @@ def __init__(

# Create the loss & metrics
self.loss_fn = SparseAutoencoderLoss(
num_components, config.l1_coefficient, keep_batch_dim=True
num_components, config.l1_coefficient, keep_batch_dim=True, normalize_by_input_norm=config.normalize_by_input_norm
)

self.train_metrics = MetricCollection(
26 changes: 24 additions & 2 deletions sparse_autoencoder/metrics/loss/l2_reconstruction_loss.py
@@ -17,8 +17,13 @@ class L2ReconstructionLoss(Metric):
and its corresponding decoded vector. The original paper found that models trained with some
loss functions such as cross-entropy loss generally prefer to represent features
polysemantically, whereas models trained with L2 may achieve the same loss for both
polysemantic and monosemantic representations of true features.

Optionally, the loss can normalize the source and decoded activations before the MSE is
calculated. This can be useful because the input vectors can vary in magnitude, and
normalizing them helps ensure that the loss is not dominated by the magnitude of the
activations.
Example:
>>> import torch
>>> loss = L2ReconstructionLoss(num_components=1)
@@ -53,6 +58,7 @@ class L2ReconstructionLoss(Metric):
# Settings
_num_components: int
_keep_batch_dim: bool
_normalize_by_input_norm: bool

@property
def keep_batch_dim(self) -> bool:
@@ -97,11 +103,13 @@ def __init__(
num_components: PositiveInt = 1,
*,
keep_batch_dim: bool = False,
normalize_by_input_norm: bool = False,
) -> None:
"""Initialise the L2 reconstruction loss."""
super().__init__()
self._num_components = num_components
self.keep_batch_dim = keep_batch_dim
self._normalize_by_input_norm = normalize_by_input_norm
self.add_state(
"num_activation_vectors",
default=torch.tensor(0, dtype=torch.int64),
@@ -119,6 +127,16 @@ def calculate_mse(
) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL)]:
"""Calculate the MSE."""
return (decoded_activations - source_activations).pow(2).mean(dim=-1)

@staticmethod
def normalize_input(
activations: Float[
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
],
) -> Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)]:
"""Normalize the input activations."""
activation_norm = activations.norm(dim=-1, keepdim=True)
return activations / activation_norm

def update(
self,
@@ -146,6 +164,10 @@
source_activations: The source activations from the autoencoder.
**kwargs: Ignored keyword arguments (to allow use with other metrics in a collection).
"""
if self._normalize_by_input_norm:
decoded_activations = self.normalize_input(decoded_activations)
source_activations = self.normalize_input(source_activations)

mse = self.calculate_mse(decoded_activations, source_activations)

if self.keep_batch_dim:
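Aside: a minimal standalone sketch of what the normalized path above computes, in plain
PyTorch. The helper names mirror normalize_input and calculate_mse but are illustrative, not
the library's public API:

    import torch

    def normalize(activations: torch.Tensor) -> torch.Tensor:
        # Scale each activation vector to unit L2 norm along the feature axis.
        return activations / activations.norm(dim=-1, keepdim=True)

    def l2_loss(
        decoded: torch.Tensor,
        source: torch.Tensor,
        *,
        normalize_by_input_norm: bool = False,
    ) -> torch.Tensor:
        # Optionally normalize both sides, then take the per-example mean squared error.
        if normalize_by_input_norm:
            decoded, source = normalize(decoded), normalize(source)
        return (decoded - source).pow(2).mean(dim=-1)

    source = torch.tensor([[3.0, 4.0]])   # norm 5
    decoded = torch.tensor([[6.0, 8.0]])  # same direction, norm 10
    print(l2_loss(decoded, source))                                # tensor([12.5000])
    print(l2_loss(decoded, source, normalize_by_input_norm=True))  # tensor([0.])

With normalization, two vectors pointing in the same direction incur zero loss regardless of
scale, so high-magnitude activations no longer dominate the objective. Note that the division
assumes non-zero activation vectors; an all-zero vector would produce NaNs.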
6 changes: 6 additions & 0 deletions sparse_autoencoder/metrics/loss/sae_loss.py
@@ -29,6 +29,7 @@ class SparseAutoencoderLoss(Metric):
# Settings
_num_components: int
_keep_batch_dim: bool
_normalize_by_input_norm: bool
_l1_coefficient: float

@property
@@ -88,11 +89,13 @@ def __init__(
l1_coefficient: PositiveFloat = 0.001,
*,
keep_batch_dim: bool = False,
normalize_by_input_norm: bool = False,
):
"""Initialise the metric."""
super().__init__()
self._num_components = num_components
self.keep_batch_dim = keep_batch_dim
self._normalize_by_input_norm = normalize_by_input_norm
self._l1_coefficient = l1_coefficient

# Add the state
@@ -117,6 +120,9 @@ def update(
) -> None:
"""Update the metric."""
absolute_loss = L1AbsoluteLoss.calculate_abs_sum(learned_activations)
if self._normalize_by_input_norm:
source_activations = L2ReconstructionLoss.normalize_input(source_activations)
decoded_activations = L2ReconstructionLoss.normalize_input(decoded_activations)
mse = L2ReconstructionLoss.calculate_mse(decoded_activations, source_activations)

if self.keep_batch_dim:
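Aside: the combined objective that update computes, sketched as a free function in plain
PyTorch (names and signature are illustrative, not the library's API):

    import torch

    def sae_loss(
        learned: torch.Tensor,  # latent activations, shape (batch, n_learned)
        decoded: torch.Tensor,  # reconstructions, shape (batch, n_input)
        source: torch.Tensor,   # original activations, shape (batch, n_input)
        l1_coefficient: float = 0.001,
        normalize_by_input_norm: bool = False,
    ) -> torch.Tensor:
        # Sparsity penalty: sum of absolute latent activations per example.
        l1 = learned.abs().sum(dim=-1)
        if normalize_by_input_norm:
            # Normalize both sides so the reconstruction term is scale-invariant.
            source = source / source.norm(dim=-1, keepdim=True)
            decoded = decoded / decoded.norm(dim=-1, keepdim=True)
        mse = (decoded - source).pow(2).mean(dim=-1)
        return mse + l1_coefficient * l1

    print(sae_loss(torch.rand(8, 64), torch.rand(8, 16), torch.rand(8, 16)).shape)
    # torch.Size([8]) -- one loss per example, as with keep_batch_dim=True

As in the diff above, the L1 term is computed on the raw learned activations; only the
reconstruction term is affected by normalization.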
4 changes: 2 additions & 2 deletions sparse_autoencoder/train/sweep.py
@@ -56,14 +56,14 @@ def setup_autoencoder(
"""
autoencoder_input_dim: int = hyperparameters["source_model"]["hook_dimension"]
expansion_factor = hyperparameters["autoencoder"]["expansion_factor"]
type = hyperparameters["autoencoder"]["type"]

config = LitSparseAutoencoderConfig(
n_input_features=autoencoder_input_dim,
n_learned_features=autoencoder_input_dim * expansion_factor,
n_components=len(hyperparameters["source_model"]["cache_names"]),
component_names=hyperparameters["source_model"]["cache_names"],
l1_coefficient=hyperparameters["loss"]["l1_coefficient"],
normalize_by_input_norm=hyperparameters["loss"]["normalize_by_input_norm"],
resample_interval=hyperparameters["activation_resampler"]["resample_interval"],
max_n_resamples=hyperparameters["activation_resampler"]["max_n_resamples"],
resample_dead_neurons_dataset_size=hyperparameters["activation_resampler"][
@@ -73,7 +73,7 @@
resample_threshold_is_dead_portion_fires=hyperparameters["activation_resampler"][
"threshold_is_dead_portion_fires"
],
type=type,
type=hyperparameters["autoencoder"]["type"],
)

return LitSparseAutoencoder(config)
10 changes: 10 additions & 0 deletions sparse_autoencoder/train/sweep_config.py
@@ -111,11 +111,21 @@ class LossHyperparameters(NestedParameter):
starting point for the L1 coefficient is 1e-3.
"""

normalize_by_input_norm: Parameter[bool] = field(default=Parameter(value=False))
"""Normalize by input norm.
Whether to normalize the source and decoded activations before calculating the L2 loss. This
can be useful because the input vectors can vary in magnitude, and normalizing them helps
ensure that the loss is not dominated by high-magnitude activations (often uninterpretable
activations from the <|endoftext|> token).
"""

class LossRuntimeHyperparameters(TypedDict):
"""Loss runtime hyperparameters."""

l1_coefficient: float

normalize_by_input_norm: bool


@dataclass(frozen=True)
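Aside: with the new field in place, a sweep enables the behaviour through the loss
hyperparameters. A hypothetical fragment of the dict consumed by setup_autoencoder (only the
"loss" keys come from this commit; the other sections are elided):

    hyperparameters = {
        "loss": {
            "l1_coefficient": 1e-3,
            "normalize_by_input_norm": True,  # enable the normalized L2 path
        },
        # ... "source_model", "autoencoder", "activation_resampler", etc.
    }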
