diff --git a/src/olmo_core/nn/lm_head.py b/src/olmo_core/nn/lm_head.py index 66cc4c3a..bc233a5a 100644 --- a/src/olmo_core/nn/lm_head.py +++ b/src/olmo_core/nn/lm_head.py @@ -80,6 +80,7 @@ def __init__( bias: bool = True, init_device: str = "cpu", ): + super().__init__() self.norm = ( None if layer_norm is None else layer_norm.build(d_model, init_device=init_device) )