diff --git a/pyproject.toml b/pyproject.toml
index 47701fb..22112b1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "tabpfn-client"
-version = "0.0.15"
+version = "0.0.18"
 requires-python = ">=3.10"
 dependencies = [
     "httpx>=0.24.1",
diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py
index 4fe1c4c..d6a6a63 100644
--- a/tabpfn_client/client.py
+++ b/tabpfn_client/client.py
@@ -170,11 +170,11 @@ def _validate_response(response, method_name, only_version_check=False):
         try:
             load = response.json()
         except json.JSONDecodeError as e:
-            logging.error(f"Failed to parse JSON from response in {method_name}: {e}")
+            logging.info(f"Failed to parse JSON from response in {method_name}: {e}")
 
         # Check if the server requires a newer client version.
         if response.status_code == 426:
-            logger.error(
+            logger.info(
                 f"Fail to call {method_name}, response status: {response.status_code}"
             )
             raise RuntimeError(load.get("detail"))
@@ -194,9 +194,10 @@ def _validate_response(response, method_name, only_version_check=False):
                 )
                 > 1
             ):
-                raise RuntimeError(
-                    f"Fail to call {method_name} with error: {reponse_split_up[1]}"
-                )
+                relevant_reponse_test = reponse_split_up[1].split("debug_error_string")[
+                    0
+                ]
+                raise RuntimeError(relevant_reponse_test)
             raise RuntimeError(
                 f"Fail to call {method_name} with error: {response.status_code} and reason: "
                 f"{response.reason_phrase}"
diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py
index 7d043b7..65f061f 100644
--- a/tabpfn_client/estimator.py
+++ b/tabpfn_client/estimator.py
@@ -182,10 +182,44 @@ def __init__(
         self.add_fingerprint_features = add_fingerprint_features
         self.subsample_samples = subsample_samples
 
+    def _validate_targets_and_classes(self, y) -> np.ndarray:
+        """Validate classification targets and record the class labels.
+
+        Flattens ``y`` to 1-D, checks that it is a valid classification
+        target, and stores the sorted unique non-NaN labels in
+        ``self.classes_``.
+
+        Args:
+            y: Array-like of target labels (1-D, or a single column).
+
+        Returns:
+            The flattened targets as a 1-D ``np.ndarray``.
+
+        Raises:
+            ValueError: If ``y`` is not a valid classification target.
+        """
+        from sklearn.utils import column_or_1d
+        from sklearn.utils.multiclass import check_classification_targets
+
+        y_ = column_or_1d(y, warn=True)
+        check_classification_targets(y_)
+
+        # Get classes and encode before type conversion to guarantee correct class labels.
+        # np.isnan only supports numeric dtypes, so the NaN filter is applied
+        # to floating-point targets only; other labels are kept as they are.
+        if np.issubdtype(y_.dtype, np.floating):
+            labels = y_[~np.isnan(y_)]
+        else:
+            labels = y_
+        self.classes_ = np.unique(labels)
+        return y_
+
     def fit(self, X, y):
         # assert init() is called
         init()
 
+        self._validate_targets_and_classes(y)
+
         if config.g_tabpfn_config.use_server:
             try:
                 assert (
@@ -203,7 +237,9 @@ def fit(self, X, y):
 
     def predict(self, X):
         probas = self.predict_proba(X)
-        return np.argmax(probas, axis=1)
+        y = np.argmax(probas, axis=1)
+        y = self.classes_.take(np.asarray(y, dtype=int))
+        return y
 
     def predict_proba(self, X):
         check_is_fitted(self)