Skip to content

Commit

Permalink
Added test data none check at start of train_classifying_direction()
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Oct 8, 2023
1 parent 5b88da8 commit 9ff7674
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions utils/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _fit(


def _fit_logistic_regression(
train_data: ResidualStreamDataset, train_pos: str, train_layer: int,
train_data: ResidualStreamDataset, train_pos: Union[str, None], train_layer: int,
test_data: Optional[ResidualStreamDataset],
test_pos: Optional[str],
test_layer: Optional[int],
Expand Down Expand Up @@ -283,6 +283,12 @@ def train_classifying_direction(
"""
Main entrypoint for training a direction using classification methods
"""
if test_data is None:
test_data = train_data
if test_pos is None:
test_pos = train_pos
if test_layer is None:
test_layer = train_layer
if method in (ClassificationMethod.PCA, ClassificationMethod.SVD):
assert 'n_components' in kwargs, "Must specify n_components for PCA/SVD"
model = train_data.model
Expand All @@ -298,14 +304,11 @@ def train_classifying_direction(
test_data, test_pos, test_layer,
**kwargs,
)
if test_data is None:
test_line = train_line
else:
test_line, _, _, _ = fitting_method(
test_data, test_pos, test_layer,
test_data, test_pos, test_layer,
**kwargs,
)
test_line, _, _, _ = fitting_method(
test_data, test_pos, test_layer,
test_data, test_pos, test_layer,
**kwargs,
)
except ConvergenceWarning:
print(
f"Convergence warning for {method.value}; "
Expand Down

0 comments on commit 9ff7674

Please sign in to comment.