Skip to content

Commit

Permalink
Refator Init for new functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
anshulg954 committed Jun 17, 2024
1 parent c678739 commit 26087de
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 2 deletions.
11 changes: 11 additions & 0 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,17 @@ def retrieve_greeting_messages(self) -> list[str]:

greeting_messages = response.json()["messages"]
return greeting_messages

def get_user_email_verification_status(self, email: str) -> tuple[bool, str]:
"""
Check if the user's email is verified.
"""
response = self.httpx_client.post(
self.server_endpoints.get_user_verification_status_via_email.path,
params={"email": email},
)

return response.json() or False

def get_data_summary(self) -> {}:
"""
Expand Down
3 changes: 2 additions & 1 deletion tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class TabPFNConfig:
is_initialized = None
user_email = None
use_server = None
user_auth_handler = None
inference_handler = None
Expand Down Expand Up @@ -44,7 +45,7 @@ def init(use_server=True):
)

# prompt for login / register
PromptAgent.prompt_and_set_token(user_auth_handler)
g_tabpfn_config.user_email = PromptAgent.prompt_and_set_token(user_auth_handler)

# Print new greeting messages. If there are no new messages, nothing will be printed.
PromptAgent.prompt_retrieved_greeting_messages(
Expand Down
12 changes: 12 additions & 0 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def __init__(
self.add_fingerprint_features = add_fingerprint_features
self.subsample_samples = subsample_samples

# check if user is verified
if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email):
raise RuntimeError(
"Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!"
)

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
Expand Down Expand Up @@ -311,6 +317,12 @@ def __init__(
self.super_bar_dist_averaging = super_bar_dist_averaging
self.subsample_samples = subsample_samples

# check if user is verified
if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email):
raise RuntimeError(
"Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!"
)

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
Expand Down
2 changes: 2 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
]
)
choice = cls._choice_with_retries(prompt, ["1", "2"])
email = ""

# Registration
if choice == "1":
Expand Down Expand Up @@ -161,6 +162,7 @@ def prompt_and_set_token(cls, user_auth_handler: "UserAuthenticationClient"):
)

print(cls.indent("Login successful!") + "\n")
return email

@classmethod
def prompt_terms_and_cond(cls) -> bool:
Expand Down
5 changes: 4 additions & 1 deletion tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
self.set_token(access_token)
return True, message

def get_user_email_verification_status(self, email: str) -> tuple[bool, str]:
return self.service_client.get_user_email_verification_status(email)

def try_reuse_existing_token(self) -> bool:
if self.service_client.access_token is None:
if not self.CACHED_TOKEN_FILE.exists():
Expand Down Expand Up @@ -175,7 +178,7 @@ def __init__(self, service_client=ServiceClient()):
def fit(self, X, y) -> None:
if not self.service_client.is_initialized:
raise RuntimeError(
"Either email is not verified or Service client is not initialized. Please Verify your email and try again!"
"Service client is not initialized!"
)

self.last_train_set_uid = self.service_client.upload_train_set(X, y)
Expand Down

0 comments on commit 26087de

Please sign in to comment.