diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index bb8a9eb..9730c84 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -397,6 +397,17 @@ def retrieve_greeting_messages(self) -> list[str]: greeting_messages = response.json()["messages"] return greeting_messages + + def get_user_email_verification_status(self, email: str) -> tuple[bool, str]: + """ + Check if the user's email is verified. + """ + response = self.httpx_client.post( + self.server_endpoints.get_user_verification_status_via_email.path, + params={"email": email}, + ) + + return response.json() or False def get_data_summary(self) -> {}: """ diff --git a/tabpfn_client/config.py b/tabpfn_client/config.py index 2f05f3a..6c1d128 100644 --- a/tabpfn_client/config.py +++ b/tabpfn_client/config.py @@ -8,6 +8,7 @@ class TabPFNConfig: is_initialized = None + user_email = None use_server = None user_auth_handler = None inference_handler = None @@ -44,7 +45,7 @@ def init(use_server=True): ) # prompt for login / register - PromptAgent.prompt_and_set_token(user_auth_handler) + g_tabpfn_config.user_email = PromptAgent.prompt_and_set_token(user_auth_handler) # Print new greeting messages. If there are no new messages, nothing will be printed. PromptAgent.prompt_retrieved_greeting_messages( diff --git a/tabpfn_client/estimator.py b/tabpfn_client/estimator.py index 2265552..799e09e 100644 --- a/tabpfn_client/estimator.py +++ b/tabpfn_client/estimator.py @@ -181,6 +181,12 @@ def __init__( self.add_fingerprint_features = add_fingerprint_features self.subsample_samples = subsample_samples + # check if user is verified + if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email): + raise RuntimeError( + "Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!" + ) + def fit(self, X, y): # assert init() is called if not config.g_tabpfn_config.is_initialized: @@ -311,6 +317,12 @@ def __init__( self.super_bar_dist_averaging = super_bar_dist_averaging self.subsample_samples = subsample_samples + # check if user is verified + if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email): + raise RuntimeError( + "Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!" + ) + def fit(self, X, y): # assert init() is called if not config.g_tabpfn_config.is_initialized: diff --git a/tabpfn_client/prompt_agent.py b/tabpfn_client/prompt_agent.py index 5b9a299..d877d73 100644 --- a/tabpfn_client/prompt_agent.py +++ b/tabpfn_client/prompt_agent.py @@ -55,6 +55,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): ] ) choice = cls._choice_with_retries(prompt, ["1", "2"]) + email = "" # Registration if choice == "1": @@ -161,6 +162,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"): ) print(cls.indent("Login successful!") + "\n") + return email @classmethod def prompt_terms_and_cond(cls) -> bool: diff --git a/tabpfn_client/service_wrapper.py b/tabpfn_client/service_wrapper.py index f9f4ffc..d915f23 100644 --- a/tabpfn_client/service_wrapper.py +++ b/tabpfn_client/service_wrapper.py @@ -58,6 +58,9 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]: self.set_token(access_token) return True, message + def get_user_email_verification_status(self, email: str) -> tuple[bool, str]: + return self.service_client.get_user_email_verification_status(email) + def try_reuse_existing_token(self) -> bool: if self.service_client.access_token is None: if not self.CACHED_TOKEN_FILE.exists(): @@ -175,7 +178,7 @@ def __init__(self, service_client=ServiceClient()): def fit(self, X, y) -> None: if not self.service_client.is_initialized: raise RuntimeError( - "Either email is not verified or Service client is not initialized. Please Verify your email and try again!" + "Service client is not initialized!" ) self.last_train_set_uid = self.service_client.upload_train_set(X, y)