Skip to content

Commit

Permalink
Refactor Login/Email Verification Functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
anshulg954 committed Jun 22, 2024
1 parent 26087de commit 7d5f8ab
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 18 deletions.
16 changes: 10 additions & 6 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
from importlib.metadata import version, PackageNotFoundError
import numpy as np
from enum import Enum
from omegaconf import OmegaConf
import json
from typing import Literal
Expand Down Expand Up @@ -34,6 +35,9 @@ class ServiceClient:
Singleton class for handling communication with the server.
It encapsulates all the API calls to the server.
"""
class Status(Enum):
OKAY = 0
USER_NOT_VERIFIED = 1

def __init__(self):
self.server_config = SERVER_CONFIG
Expand Down Expand Up @@ -231,10 +235,10 @@ def try_authenticate(self, access_token) -> bool:
)

self._validate_response(response, "try_authenticate", only_version_check=True)

if response.status_code == 200:
is_authenticated = True

elif response.status_code == 403:
is_authenticated = (False, self.Status.USER_NOT_VERIFIED)
return is_authenticated

def validate_email(self, email: str) -> tuple[bool, str]:
Expand Down Expand Up @@ -398,16 +402,16 @@ def retrieve_greeting_messages(self) -> list[str]:
greeting_messages = response.json()["messages"]
return greeting_messages

def get_user_email_verification_status(self, email: str) -> tuple[bool, str]:
# bool optional parameter is accesstoken required
def get_user_email_verification_status(self, email: str, access_token_required: bool) -> tuple[bool, str]:
"""
Check if the user's email is verified.
"""
response = self.httpx_client.post(
self.server_endpoints.get_user_verification_status_via_email.path,
params={"email": email},
params={"email": email, "access_token_required": access_token_required},
)

return response.json() or False
return response.json()

def get_data_summary(self) -> {}:
"""
Expand Down
6 changes: 5 additions & 1 deletion tabpfn_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ def init(use_server=True):

is_valid_token_set = user_auth_handler.try_reuse_existing_token()

if is_valid_token_set:
if isinstance(is_valid_token_set, bool) and is_valid_token_set:
PromptAgent.prompt_reusing_existing_token()
elif isinstance(is_valid_token_set, tuple) and is_valid_token_set[1] is not None:
print("Access token is valid but email is not verified...")
PromptAgent.reverify_email(user_auth_handler)
return init(use_server)
else:
if not PromptAgent.prompt_terms_and_cond():
raise RuntimeError(
Expand Down
13 changes: 5 additions & 8 deletions tabpfn_client/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass, asdict

import numpy as np
from tabpfn_client import init
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.utils.validation import check_is_fitted

Expand Down Expand Up @@ -182,17 +183,15 @@ def __init__(
self.subsample_samples = subsample_samples

# check if user is verified
if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email):
if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email)[0]:
raise RuntimeError(
"Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!"
)

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNClassifier"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand Down Expand Up @@ -318,17 +317,15 @@ def __init__(
self.subsample_samples = subsample_samples

# check if user is verified
if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email):
if config.g_tabpfn_config.user_email and not config.g_tabpfn_config.user_auth_handler.get_user_email_verification_status(config.g_tabpfn_config.user_email)[0]:
raise RuntimeError(
"Dear User, your email has not been verified. Please, check your mailbox, verify your email and try again!"
)

def fit(self, X, y):
# assert init() is called
if not config.g_tabpfn_config.is_initialized:
raise RuntimeError(
"tabpfn_client.init() must be called before using TabPFNRegressor"
)
init()

if config.g_tabpfn_config.use_server:
try:
Expand Down
23 changes: 23 additions & 0 deletions tabpfn_client/prompt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,29 @@ def prompt_reusing_existing_token(cls):

print(cls.indent(prompt))

@classmethod
def reverify_email(cls, user_auth_handler: "UserAuthenticationClient"):
prompt = "\n".join(
[
"Please check your inbox for the verification email.",
"Note: The email might be in your spam folder or could have expired.",
]
)
print(cls.indent(prompt))
retry_verification = "\n".join(
[
"Do you want to resend email verification link? (y/n): ",
]
)
choice = cls._choice_with_retries(retry_verification, ["y", "n"])
if choice == "y":
email = input(cls.indent("Please enter your email: "))
password = getpass.getpass(cls.indent("Please enter your password: "))

user_auth_handler.set_token_by_login(email, password)
print(cls.indent("A verification email has been sent, provided the details are correct!") + "\n")
return

@classmethod
def prompt_retrieved_greeting_messages(cls, greeting_messages: list[str]):
for message in greeting_messages:
Expand Down
5 changes: 5 additions & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ endpoints:
methods: [ "POST" ]
description: "Send reset password email"

get_user_verification_status_via_email:
path: "/auth/get_user_verification_status_via_email/"
methods: ["POST"]
description: "Get user verification status via email"

retrieve_greeting_messages:
path: "/retrieve_greeting_messages/"
methods: [ "GET" ]
Expand Down
11 changes: 8 additions & 3 deletions tabpfn_client/service_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def set_token_by_registration(
is_created, message = self.service_client.register(
email, password, password_confirm, validation_link, additional_info
)
is_verified, access_token = self.get_user_email_verification_status(email, access_token_required=True)
if not is_verified:
self.set_token(access_token)
return is_created, message

def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
Expand All @@ -58,10 +61,10 @@ def set_token_by_login(self, email: str, password: str) -> tuple[bool, str]:
self.set_token(access_token)
return True, message

def get_user_email_verification_status(self, email: str) -> tuple[bool, str]:
return self.service_client.get_user_email_verification_status(email)
def get_user_email_verification_status(self, email: str, access_token_required: bool = False) -> tuple[bool, str]:
return self.service_client.get_user_email_verification_status(email, access_token_required)

def try_reuse_existing_token(self) -> bool:
def try_reuse_existing_token(self) -> bool or (bool, str):
if self.service_client.access_token is None:
if not self.CACHED_TOKEN_FILE.exists():
return False
Expand All @@ -75,6 +78,8 @@ def try_reuse_existing_token(self) -> bool:
if not is_valid:
self._reset_token()
return False
elif isinstance(is_valid, tuple) and not is_valid[0] and is_valid[1] == self.service_client.Status.USER_NOT_VERIFIED:
return False, access_token

logger.debug(f"Reusing existing access token? {is_valid}")
self.set_token(access_token)
Expand Down

0 comments on commit 7d5f8ab

Please sign in to comment.