Skip to content

Commit

Permalink
Check if none before test_data.get_positive_negative_labels()
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Oct 8, 2023
1 parent 1bf7e99 commit bcc1b98
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion utils/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def _fit(
test_pos, test_layer
)
train_positive_str_labels, train_negative_str_labels = train_data.get_positive_negative_labels()
test_positive_str_labels, test_negative_str_labels = test_data.get_positive_negative_labels()
if test_data is None:
test_positive_str_labels = train_positive_str_labels
test_negative_str_labels = train_negative_str_labels
else:
test_positive_str_labels, test_negative_str_labels = test_data.get_positive_negative_labels()
kmeans = KMeans(n_clusters=n_clusters, n_init=n_init, random_state=random_state)
if method == ClassificationMethod.KMEANS or method == ClassificationMethod.MEAN_DIFF:
kmeans.fit(train_embeddings)
Expand Down

0 comments on commit bcc1b98

Please sign in to comment.