Skip to content

Commit

Permalink
fix class remappings
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelGabriel committed Jul 15, 2024
1 parent b14cb11 commit 985c024
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "tabpfn-client"
version = "0.0.15"
version = "0.0.18"
requires-python = ">=3.10"
dependencies = [
"httpx>=0.24.1",
Expand Down
11 changes: 6 additions & 5 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def _validate_response(response, method_name, only_version_check=False):
try:
load = response.json()
except json.JSONDecodeError as e:
logging.error(f"Failed to parse JSON from response in {method_name}: {e}")
logging.info(f"Failed to parse JSON from response in {method_name}: {e}")

# Check if the server requires a newer client version.
if response.status_code == 426:
logger.error(
logger.info(
f"Fail to call {method_name}, response status: {response.status_code}"
)
raise RuntimeError(load.get("detail"))
Expand All @@ -194,9 +194,10 @@ def _validate_response(response, method_name, only_version_check=False):
)
> 1
):
raise RuntimeError(
f"Fail to call {method_name} with error: {reponse_split_up[1]}"
)
relevant_reponse_test = reponse_split_up[1].split("debug_error_string")[
0
]
raise RuntimeError(relevant_reponse_test)
raise RuntimeError(
f"Fail to call {method_name} with error: {response.status_code} and reason: "
f"{response.reason_phrase}"
Expand Down
17 changes: 16 additions & 1 deletion tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,23 @@ def __init__(
self.add_fingerprint_features = add_fingerprint_features
self.subsample_samples = subsample_samples

def _validate_targets_and_classes(self, y) -> np.ndarray:
from sklearn.utils import column_or_1d
from sklearn.utils.multiclass import check_classification_targets

y_ = column_or_1d(y, warn=True)
check_classification_targets(y)

# Get classes and encode before type conversion to guarantee correct class labels.
not_nan_mask = ~np.isnan(y)
self.classes_ = np.unique(y_[not_nan_mask])

def fit(self, X, y):
# assert init() is called
init()

self._validate_targets_and_classes(y)

if config.g_tabpfn_config.use_server:
try:
assert (
Expand All @@ -203,7 +216,9 @@ def fit(self, X, y):

def predict(self, X):
probas = self.predict_proba(X)
return np.argmax(probas, axis=1)
y = np.argmax(probas, axis=1)
y = self.classes_.take(np.asarray(y, dtype=int))
return y

def predict_proba(self, X):
check_is_fitted(self)
Expand Down

0 comments on commit 985c024

Please sign in to comment.