diff --git a/utils/classification.py b/utils/classification.py index c2f29ca..8192619 100644 --- a/utils/classification.py +++ b/utils/classification.py @@ -93,7 +93,11 @@ def _fit( test_pos, test_layer ) train_positive_str_labels, train_negative_str_labels = train_data.get_positive_negative_labels() - test_positive_str_labels, test_negative_str_labels = test_data.get_positive_negative_labels() + if test_data is None: + test_positive_str_labels = train_positive_str_labels + test_negative_str_labels = train_negative_str_labels + else: + test_positive_str_labels, test_negative_str_labels = test_data.get_positive_negative_labels() kmeans = KMeans(n_clusters=n_clusters, n_init=n_init, random_state=random_state) if method == ClassificationMethod.KMEANS or method == ClassificationMethod.MEAN_DIFF: kmeans.fit(train_embeddings)