diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index 9730c84..406ba80 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -6,6 +6,7 @@ import logging from importlib.metadata import version, PackageNotFoundError import numpy as np +from enum import Enum from omegaconf import OmegaConf import json from typing import Literal @@ -34,6 +35,9 @@ class ServiceClient: Singleton class for handling communication with the server. It encapsulates all the API calls to the server. """ + class Status(Enum): + OKAY = 0 + USER_NOT_VERIFIED = 1 def __init__(self): self.server_config = SERVER_CONFIG @@ -231,10 +235,10 @@ def try_authenticate(self, access_token) -> bool: ) self._validate_response(response, "try_authenticate", only_version_check=True) - if response.status_code == 200: is_authenticated = True - + elif response.status_code == 403: + is_authenticated = (False, self.Status.USER_NOT_VERIFIED) return is_authenticated def validate_email(self, email: str) -> tuple[bool, str]: @@ -398,16 +402,16 @@ def retrieve_greeting_messages(self) -> list[str]: greeting_messages = response.json()["messages"] return greeting_messages - def get_user_email_verification_status(self, email: str) -> tuple[bool, str]: + # bool optional parameter is accesstoken required + def get_user_email_verification_status(self, email: str, access_token_required: bool) -> tuple[bool, str]: """ Check if the user's email is verified. """ response = self.httpx_client.post( self.server_endpoints.get_user_verification_status_via_email.path, - params={"email": email}, + params={"email": email, "access_token_required": access_token_required}, ) - - return response.json() or False + return response.json() def get_data_summary(self) -> {}: """ diff --git a/tabpfn_client/config.py b/tabpfn_client/config.py index 6c1d128..46f7111 100644 --- a/tabpfn_client/config.py +++ b/tabpfn_client/config.py @@ -36,8 +36,12 @@ def init(use_server=True): is_valid_token_set = user_auth_handler.try_reuse_existing_token() - if is_valid_token_set: + if isinstance(is_valid_token_set, bool) and is_valid_token_set: PromptAgent.prompt_reusing_existing_token() + elif isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None: + print("Access token is valid but email is not verified...") + PromptAgent.reverify_email(user_auth_handler) + return init(use_server) else: if not PromptAgent.prompt_terms_and_cond(): raise RuntimeError( diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 799e09e..6f873e3 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, asdict import numpy as np +from tabpfn_client import init from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin from sklearn.utils.validation import check_is_fitted @@ -182,7 +183,7 @@ def __init__( self.subsample_samples = subsample_samples # check if user is verified - if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email): + if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email)[0]: raise RuntimeError( "Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!" ) @@ -190,9 +191,7 @@ def __init__( def fit(self, X, y): # assert init() is called if not config.g_tabpfn_config.is_initialized: - raise RuntimeError( - "tabpfn_client.init() must be called before using TabPFNClassifier" - ) + init() if config.g_tabpfn_config.use_server: try: @@ -318,7 +317,7 @@ def __init__( self.subsample_samples = subsample_samples # check if user is verified - if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email): + if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email)[0]: raise RuntimeError( "Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!" ) @@ -326,9 +325,7 @@ def __init__( def fit(self, X, y): # assert init() is called if not config.g_tabpfn_config.is_initialized: - raise RuntimeError( - "tabpfn_client.init() must be called before using TabPFNRegressor" - ) + init() if config.g_tabpfn_config.use_server: try: diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index d877d73..c8254e5 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -209,6 +209,29 @@ def prompt_reusing_existing_token(cls): print(cls.indent(prompt)) + @classmethod + def reverify_email(cls, user_auth_handler: "UserAuthenticationClient"): + prompt = "\n".join( + [ + "Please check your inbox for the verification email.", + "Note: The email might be in your spam folder or could have expired.", + ] + ) + print(cls.indent(prompt)) + retry_verification = "\n".join( + [ + "Do you want to resend email verification link? (y/n): ", + ] + ) + choice = cls._choice_with_retries(retry_verification, ["y", "n"]) + if choice == "y": + email = input(cls.indent("Please enter your email: ")) + password = getpass.getpass(cls.indent("Please enter your password: ")) + + user_auth_handler.set_token_by_login(email, password) + print(cls.indent("A verification email has been sent, provided the details are correct!") + "\n") + return + @classmethod def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]): for message in greeting_messages: diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index e5d2ef2..7443be9 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -39,6 +39,11 @@ endpoints: methods: [ "POST" ] description: "Send reset password email" + get_user_verification_status_via_email: + path: "/auth/get_user_verification_status_via_email/" + methods: ["POST"] + description: "Get user verification status via email" + retrieve_greeting_messages: path: "/retrieve_greeting_messages/" methods: [ "GET" ] diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index d915f23..bb446ef 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -47,6 +47,9 @@ def set_token_by_registration( is_created, message = self.service_client.register( email, password, password_confirm, validation_link, additional_info ) + is_verified, access_token = self.get_user_email_verification_status(email, access_token_required=True) + if not is_verified: + self.set_token(access_token) return is_created, message def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: @@ -58,10 +61,10 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: self.set_token(access_token) return True, message - def get_user_email_verification_status(self, email: str) -> tuple[bool, str]: - return self.service_client.get_user_email_verification_status(email) + def get_user_email_verification_status(self, email: str, access_token_required: bool = False) -> tuple[bool, str]: + return self.service_client.get_user_email_verification_status(email, access_token_required) - def try_reuse_existing_token(self) -> bool: + def try_reuse_existing_token(self) -> bool or (bool, str): if self.service_client.access_token is None: if not self.CACHED_TOKEN_FILE.exists(): return False @@ -75,6 +78,8 @@ def try_reuse_existing_token(self) -> bool: if not is_valid: self._reset_token() return False + elif isinstance(is_valid, tuple) and not is_valid[0] and is_valid[1] == self.service_client.Status.USER_NOT_VERIFIED: + return False, access_token logger.debug(f"Reusing existing access token? {is_valid}") self.set_token(access_token)