diff --git a/utils/classification.py b/utils/classification.py index e386e38..6b9c41f 100644 --- a/utils/classification.py +++ b/utils/classification.py @@ -234,7 +234,7 @@ def _fit( def _fit_logistic_regression( - train_data: ResidualStreamDataset, train_pos: str, train_layer: int, + train_data: ResidualStreamDataset, train_pos: Union[str, None], train_layer: int, test_data: Optional[ResidualStreamDataset], test_pos: Optional[str], test_layer: Optional[int], @@ -283,6 +283,12 @@ def train_classifying_direction( """ Main entrypoint for training a direction using classification methods """ + if test_data is None: + test_data = train_data + if test_pos is None: + test_pos = train_pos + if test_layer is None: + test_layer = train_layer if method in (ClassificationMethod.PCA, ClassificationMethod.SVD): assert 'n_components' in kwargs, "Must specify n_components for PCA/SVD" model = train_data.model @@ -298,14 +304,11 @@ def train_classifying_direction( test_data, test_pos, test_layer, **kwargs, ) - if test_data is None: - test_line = train_line - else: - test_line, _, _, _ = fitting_method( - test_data, test_pos, test_layer, - test_data, test_pos, test_layer, - **kwargs, - ) + test_line, _, _, _ = fitting_method( + test_data, test_pos, test_layer, + test_data, test_pos, test_layer, + **kwargs, + ) except ConvergenceWarning: print( f"Convergence warning for {method.value}; "