Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Working version of aiohttp upgrade #140

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,9 @@
"-p",
"*test*.py"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.unittestEnabled": true,
"editor.defaultFormatter": "charliermarsh.ruff",
"[python]": {
"editor.formatOnSave": true,
},
"python.testing.pytestArgs": [
"tests"
]
}
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ requires-python = ">=3.7"
dependencies = [
"requests",
"tqdm",
"packaging"
"packaging",
"aiohttp[speedups]",

]

[project.urls]
Expand Down
22 changes: 13 additions & 9 deletions src/kagglehub/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from kagglehub.config import get_kaggle_credentials, set_kaggle_credentials
from kagglehub.exceptions import UnauthenticatedError

from aiohttp import ClientSession

_logger = logging.getLogger(__name__)

INVALID_CREDENTIALS_ERROR = 401
Expand Down Expand Up @@ -60,7 +62,7 @@ def _is_in_notebook() -> bool:
return False # Probably standard Python interpreter


def _notebook_login(validate_credentials: bool) -> None: # noqa: FBT001
async def _notebook_login(validate_credentials: bool) -> None: # noqa: FBT001
"""Prompt the user for their Kaggle token and save it in a widget (Jupyter or Colab)."""
library_error = "You need the `ipywidgets` module: `pip install ipywidgets`."
try:
Expand All @@ -87,7 +89,7 @@ def _notebook_login(validate_credentials: bool) -> None: # noqa: FBT001
)
display(login_token_widget)

def on_click_login_button(_: str) -> None:
async def on_click_login_button(_: str) -> None:
username = username_widget.value
token = token_widget.value
# Erase token and clear value to make sure it's not saved in the notebook.
Expand All @@ -102,7 +104,8 @@ def on_click_login_button(_: str) -> None:

# Validate credentials if necessary
if validate_credentials is True:
_validate_credentials_helper()
async with ClientSession() as session:
await _validate_credentials_helper(session)
message = captured.getvalue()
except Exception as error:
message = str(error)
Expand All @@ -112,9 +115,9 @@ def on_click_login_button(_: str) -> None:
login_button.on_click(on_click_login_button)


def _validate_credentials_helper() -> None:
api_client = KaggleApiV1Client()
response = api_client.get("/hello")
async def _validate_credentials_helper(session: ClientSession) -> None:
api_client = KaggleApiV1Client(session)
response = await api_client.get("/hello")
if "code" not in response:
_logger.info("Kaggle credentials successfully validated.")
elif response["code"] == INVALID_CREDENTIALS_ERROR:
Expand All @@ -125,11 +128,11 @@ def _validate_credentials_helper() -> None:
_logger.warning("Unable to validate Kaggle credentials at this time.")


def login(validate_credentials: bool = True) -> None: # noqa: FBT002, FBT001
async def login(validate_credentials: bool = True) -> None: # noqa: FBT002, FBT001
"""Prompt the user for their Kaggle username and API key and save them globally."""

if _is_in_notebook():
_notebook_login(validate_credentials)
await _notebook_login(validate_credentials)
return
else:
username = input("Enter your Kaggle username: ")
Expand All @@ -140,7 +143,8 @@ def login(validate_credentials: bool = True) -> None: # noqa: FBT002, FBT001
if not validate_credentials:
return

_validate_credentials_helper()
with ClientSession() as session:
await _validate_credentials_helper(session)


def whoami() -> dict:
Expand Down
59 changes: 30 additions & 29 deletions src/kagglehub/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from urllib.parse import urljoin

import requests
from aiohttp import ClientResponse, ClientSession
from packaging.version import parse
from requests.auth import HTTPBasicAuth
from tqdm import tqdm
Expand Down Expand Up @@ -81,11 +82,12 @@ def get_user_agent() -> str:
class KaggleApiV1Client:
BASE_PATH = "api/v1"

def __init__(self) -> None:
def __init__(self, session: ClientSession) -> None:
self.credentials = get_kaggle_credentials()
self.endpoint = get_kaggle_api_endpoint()
self.session = session

def _check_for_version_update(self, response: requests.Response) -> None:
def _check_for_version_update(self, response: ClientResponse) -> None:
latest_version_str = response.headers.get("X-Kaggle-HubVersion")
if latest_version_str:
current_version = parse(kagglehub.__version__)
Expand All @@ -96,26 +98,26 @@ def _check_for_version_update(self, response: requests.Response) -> None:
f"version, please consider updating (latest version: {latest_version})"
)

def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> dict:
async def get(self, path: str, resource_handle: Optional[ResourceHandle] = None) -> dict:
url = self._build_url(path)
with requests.get(
async with self.session.get(
url,
headers={"User-Agent": get_user_agent()},
auth=self._get_http_basic_auth(),
timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
) as response:
kaggle_api_raise_for_status(response, resource_handle)
self._check_for_version_update(response)
return response.json()

def post(self, path: str, data: dict) -> dict:
read_timeout=DEFAULT_READ_TIMEOUT,
conn_timeout=DEFAULT_CONNECT_TIMEOUT,
) as resp:
kaggle_api_raise_for_status(resp, resource_handle)
self._check_for_version_update(resp)
return resp.json()

async def post(self, path: str, data: dict) -> dict:
url = self._build_url(path)
with requests.post(
async with self.session.post(
url,
headers={"User-Agent": get_user_agent()},
read_timeout=DEFAULT_READ_TIMEOUT,
conn_timeout=DEFAULT_CONNECT_TIMEOUT,
json=data,
auth=self._get_http_basic_auth(),
timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
) as response:
response.raise_for_status()
response_dict = response.json()
Expand Down Expand Up @@ -210,7 +212,7 @@ def _download_file(
class KaggleJwtClient:
BASE_PATH = "/kaggle-jwt-handler/"

def __init__(self) -> None:
def __init__(self, session: ClientSession) -> None:
self.endpoint = os.getenv(KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME)
if self.endpoint is None:
msg = f"The {KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME} should be set."
Expand All @@ -236,20 +238,16 @@ def __init__(self) -> None:
"X-Kaggle-Authorization": f"Bearer {jwt_token}",
"X-KAGGLE-PROXY-DATA": data_proxy_token,
}
self.session = session

def post(
async def post(
self,
request_name: str,
data: dict,
timeout: Tuple[float, float] = (DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
) -> dict:
url = f"{self.endpoint}{KaggleJwtClient.BASE_PATH}{request_name}"
with requests.post(
url,
headers=self.headers,
data=bytes(json.dumps(data), "utf-8"),
timeout=timeout,
) as response:
async with self.session.post(url, headers=self.headers, json=data, timeout=timeout) as response:
response.raise_for_status()
json_response = response.json()
if "wasSuccessful" not in json_response:
Expand All @@ -271,25 +269,28 @@ class ColabClient:
# of ModelColabCacheResolver.
TBE_RUNTIME_ADDR_ENV_VAR_NAME = "TBE_RUNTIME_ADDR"

def __init__(self) -> None:
def __init__(self, session: ClientSession) -> None:
self.endpoint = os.getenv(self.TBE_RUNTIME_ADDR_ENV_VAR_NAME)
if self.endpoint is None:
msg = f"The {self.TBE_RUNTIME_ADDR_ENV_VAR_NAME} should be set."
raise ColabEnvironmentError(msg)

self.credentials = get_kaggle_credentials()
self.headers = {"Content-type": "application/json"}
self.session = session

def post(self, data: dict, handle_path: str, resource_handle: Optional[ResourceHandle] = None) -> Optional[dict]:
async def post(
self, data: dict, handle_path: str, resource_handle: Optional[ResourceHandle] = None
) -> Optional[dict]:
url = f"http://{self.endpoint}{handle_path}"
with requests.post(
with self.session.post(
url,
data=json.dumps(data),
auth=self._get_http_basic_auth(),
headers=self.headers,
timeout=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_READ_TIMEOUT),
read_timeout=DEFAULT_READ_TIMEOUT,
connect_timeout=DEFAULT_CONNECT_TIMEOUT,
) as response:
if response.status_code == HTTP_STATUS_404:
if response.staus == HTTP_STATUS_404:
raise NotFoundError()
colab_raise_for_status(response, resource_handle)
if response.text:
Expand Down
112 changes: 59 additions & 53 deletions src/kagglehub/colab_cache_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,41 @@
from kagglehub.handle import ModelHandle
from kagglehub.resolver import Resolver

from aiohttp import ClientSession

COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "COLAB_CACHE_MOUNT_FOLDER"
DEFAULT_COLAB_CACHE_MOUNT_FOLDER = "/kaggle/input"

logger = logging.getLogger(__name__)


class ModelColabCacheResolver(Resolver[ModelHandle]):
def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002, ANN003
async def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002, ANN003
if ColabClient.TBE_RUNTIME_ADDR_ENV_VAR_NAME not in os.environ or is_colab_cache_disabled():
return False

api_client = ColabClient()
data = {
"owner": handle.owner,
"model": handle.model,
"framework": handle.framework,
"variation": handle.variation,
}

if handle.is_versioned():
# Colab treats version as int in the request
data["version"] = handle.version # type: ignore

try:
api_client.post(data, ColabClient.IS_SUPPORTED_PATH, handle)
except NotFoundError:
return False
return True
with ClientSession() as session:
api_client = ColabClient(session)
data = {
"owner": handle.owner,
"model": handle.model,
"framework": handle.framework,
"variation": handle.variation,
}

if handle.is_versioned():
# Colab treats version as int in the request
data["version"] = handle.version # type: ignore

def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
try:
await api_client.post(data, ColabClient.IS_SUPPORTED_PATH, handle)
except NotFoundError:
return False
return True

async def __call__(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
) -> str:
if force_download:
logger.warning("Ignoring invalid input: force_download flag cannot be used in a Colab notebook")

Expand All @@ -46,37 +51,38 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
else:
logger.info(f"Attaching model '{h}' to your Colab notebook...")

api_client = ColabClient()
data = {
"owner": h.owner,
"model": h.model,
"framework": h.framework,
"variation": h.variation,
}
if h.is_versioned():
# Colab treats version as int in the request
data["version"] = h.version # type: ignore

response = api_client.post(data, ColabClient.MOUNT_PATH, h)

if response is not None:
if "slug" not in response:
msg = "'slug' field missing from response"
raise BackendError(msg)

base_mount_path = os.getenv(COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_COLAB_CACHE_MOUNT_FOLDER)
cached_path = f"{base_mount_path}/{response['slug']}"

if path:
cached_filepath = f"{cached_path}/{path}"
if not os.path.exists(cached_filepath):
msg = (
f"'{path}' is not present in the model files. "
f"You can access the other files of the attached model at '{cached_path}'"
)
raise ValueError(msg)
return cached_filepath
return cached_path
else:
no_response = "No response received or response was empty."
raise ValueError(no_response)
with ClientSession() as session:
api_client = ColabClient(session)
data = {
"owner": h.owner,
"model": h.model,
"framework": h.framework,
"variation": h.variation,
}
if h.is_versioned():
# Colab treats version as int in the request
data["version"] = h.version # type: ignore

response = await api_client.post(data, ColabClient.MOUNT_PATH, h)

if response is not None:
if "slug" not in response:
msg = "'slug' field missing from response"
raise BackendError(msg)

base_mount_path = os.getenv(COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_COLAB_CACHE_MOUNT_FOLDER)
cached_path = f"{base_mount_path}/{response['slug']}"

if path:
cached_filepath = f"{cached_path}/{path}"
if not os.path.exists(cached_filepath):
msg = (
f"'{path}' is not present in the model files. "
f"You can access the other files of the attached model at '{cached_path}'"
)
raise ValueError(msg)
return cached_filepath
return cached_path
else:
no_response = "No response received or response was empty."
raise ValueError(no_response)
4 changes: 2 additions & 2 deletions src/kagglehub/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logger = logging.getLogger(__name__)


def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
async def dataset_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
"""Download dataset files
Args:
handle: (string) the dataset handle
Expand All @@ -20,4 +20,4 @@ def dataset_download(handle: str, path: Optional[str] = None, *, force_download:

h = parse_dataset_handle(handle)
logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
return registry.dataset_resolver(h, path, force_download=force_download)
return await registry.dataset_resolver(h, path, force_download=force_download)
Loading