diff --git a/rapyuta_io_sdk_v2/__init__.py b/rapyuta_io_sdk_v2/__init__.py index 9daf4c6..52dc80e 100644 --- a/rapyuta_io_sdk_v2/__init__.py +++ b/rapyuta_io_sdk_v2/__init__.py @@ -1,5 +1,7 @@ # ruff: noqa from rapyuta_io_sdk_v2.config import Configuration from rapyuta_io_sdk_v2.client import Client +from rapyuta_io_sdk_v2.utils import walk_pages +from rapyuta_io_sdk_v2.async_client import AsyncClient __version__ = "0.0.1" diff --git a/rapyuta_io_sdk_v2/async_client.py b/rapyuta_io_sdk_v2/async_client.py index ef8f529..4e31bb9 100644 --- a/rapyuta_io_sdk_v2/async_client.py +++ b/rapyuta_io_sdk_v2/async_client.py @@ -18,9 +18,8 @@ import platform from rapyuta_io_sdk_v2.config import Configuration from rapyuta_io_sdk_v2.utils import ( - handle_auth_token, handle_and_munchify_response, - walk_pages, + handle_server_errors, ) @@ -48,25 +47,8 @@ def __init__(self, config=None, **kwargs): ) }, ) - self.rip_host = self.config.hosts.get("rip_host") - self.v2api_host = self.config.hosts.get("v2api_host") - - @handle_auth_token - def login( - self, email: str = None, password: str = None, environment: str = "ga" - ) -> str: - """Get the authentication token for the user. - - Args: - email (str, optional): async defaults to None. - password (str, optional): async defaults to None. - environment (str, optional): async defaults to "ga". - - - Returns: - str: auth token - """ - sync_client = httpx.Client( + self.sync_client = httpx.Client( + timeout=timeout, limits=httpx.Limits( max_keepalive_connections=5, max_connections=5, @@ -83,35 +65,79 @@ def login( ) }, ) - if email is None and password is None and self.config is None: - raise ValueError("email and password are required") + self.rip_host = self.config.hosts.get("rip_host") + self.v2api_host = self.config.hosts.get("v2api_host") + + def get_auth_token(self, email: str, password: str) -> str: + """Get the authentication token for the user. - if self.config is None: - self.config = Configuration( - email=email, password=password, environment=environment - ) + Args: + email (str) + password (str) - return sync_client.post( + Returns: + str: authentication token + """ + response = self.sync_client.post( url=f"{self.rip_host}/user/login", headers={"Content-Type": "application/json"}, json={ - "email": email or self.config.email, - "password": password or self.config.password, + "email": email, + "password": password, }, ) + handle_server_errors(response) + return response.json()["data"].get("token") + + def login( + self, + email: str, + password: str, + environment: str = "ga", + ) -> None: + """Get the authentication token for the user. + + Args: + email (str) + password (str) + environment (str) - def logout(self, token=None): - pass + Returns: + str: authentication token + """ + + token = self.get_auth_token(email, password) + self.config.auth_token = token + + @handle_and_munchify_response + def logout(self, token: str = None) -> None: + """Expire the authentication token. + + Args: + token (str): The token to expire. + """ + + if token is None: + token = self.config.auth_token + + return self.sync_client.post( + url=f"{self.rip_host}/user/logout", + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + }, + ) async def refresh_token(self, token: str = None) -> str: if token is None: token = self.config.auth_token - return await self.c.post( + response = await self.c.post( url=f"{self.rip_host}/refreshtoken", headers={"Content-Type": "application/json"}, json={"token": token}, ) + return response.json()["data"].get("token") def set_organization(self, organization_guid: str) -> None: """Set the organization GUID. @@ -130,7 +156,7 @@ def set_project(self, project_guid: str) -> None: self.config.set_project(project_guid) # ----------------- Projects ----------------- - @walk_pages + @handle_and_munchify_response async def list_projects(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all projects. @@ -243,7 +269,7 @@ async def update_project_owner( ) # -------------------Package------------------- - @walk_pages + @handle_and_munchify_response async def list_packages(self, cont: int = 0, limit: int = 10, **kwargs) -> Munch: """List all packages in a project. @@ -304,7 +330,7 @@ async def delete_package(self, name: str, **kwargs) -> Munch: ) # -------------------Deployment------------------- - @walk_pages + @handle_and_munchify_response async def list_deployments(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all deployments in a project. @@ -380,7 +406,7 @@ async def delete_deployment(self, name: str, **kwargs) -> Munch: ) # -------------------Disks------------------- - @walk_pages + @handle_and_munchify_response async def list_disks(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all disks in a project. @@ -441,7 +467,7 @@ async def delete_disk(self, name: str, **kwargs) -> Munch: ) # -------------------Static Routes------------------- - @walk_pages + @handle_and_munchify_response async def list_staticroutes(self, cont: int = 0, limit: int = 0, **kwargs) -> Munch: """List all static routes in a project. @@ -520,7 +546,7 @@ async def delete_staticroute(self, name: str, **kwargs) -> Munch: ) # -------------------Networks------------------- - @walk_pages + @handle_and_munchify_response async def list_networks(self, cont: int = 0, limit: int = 0, **kwargs) -> Munch: """List all networks in a project. @@ -581,7 +607,7 @@ async def delete_network(self, name: str, **kwargs) -> Munch: ) # -------------------Secrets------------------- - @walk_pages + @handle_and_munchify_response async def list_secrets(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all secrets in a project. @@ -660,7 +686,7 @@ async def delete_secret(self, name: str, **kwargs) -> Munch: ) # -------------------Config Trees------------------- - @walk_pages + @handle_and_munchify_response async def list_configtrees(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all config trees in a project. @@ -762,7 +788,7 @@ async def delete_configtree(self, name: str, **kwargs) -> Munch: headers=self.config.get_headers(**kwargs), ) - @walk_pages + @handle_and_munchify_response async def list_revisions( self, name: str, diff --git a/rapyuta_io_sdk_v2/client.py b/rapyuta_io_sdk_v2/client.py index 40cdad0..6acb1b6 100644 --- a/rapyuta_io_sdk_v2/client.py +++ b/rapyuta_io_sdk_v2/client.py @@ -18,9 +18,8 @@ import platform from rapyuta_io_sdk_v2.config import Configuration from rapyuta_io_sdk_v2.utils import ( + handle_server_errors, handle_and_munchify_response, - handle_auth_token, - walk_pages, ) @@ -56,39 +55,46 @@ def __init__(self, config: Configuration = None, **kwargs): self.v2api_host = self.config.hosts.get("v2api_host") self.rip_host = self.config.hosts.get("rip_host") - @handle_auth_token - def login( - self, - email: str = None, - password: str = None, - environment: str = "ga", - ) -> str: + def get_auth_token(self, email: str, password: str) -> str: """Get the authentication token for the user. Args: email (str) password (str) - environment (str) Returns: str: authentication token """ - if email is None and password is None and self.config is None: - raise ValueError("email and password are required") - - if self.config is None: - self.config = Configuration( - email=email, password=password, environment=environment - ) - - return self.c.post( + response = self.c.post( url=f"{self.rip_host}/user/login", headers={"Content-Type": "application/json"}, json={ - "email": email or self.config.email, - "password": password or self.config.password, + "email": email, + "password": password, }, ) + handle_server_errors(response) + return response.json()["data"].get("token") + + def login( + self, + email: str, + password: str, + environment: str = "ga", + ) -> None: + """Get the authentication token for the user. + + Args: + email (str) + password (str) + environment (str) + + Returns: + str: authentication token + """ + + token = self.get_auth_token(email, password) + self.config.auth_token = token @handle_and_munchify_response def logout(self, token: str = None) -> None: @@ -103,11 +109,12 @@ def logout(self, token: str = None) -> None: return self.c.post( url=f"{self.rip_host}/user/logout", - headers={"Content-Type": "application/json"}, - json={"token": token}, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {token}", + }, ) - @handle_auth_token def refresh_token(self, token: str = None) -> str: """Refresh the authentication token. @@ -121,11 +128,13 @@ def refresh_token(self, token: str = None) -> str: if token is None: token = self.config.auth_token - return self.c.post( + response = self.c.post( url=f"{self.rip_host}/refreshtoken", headers={"Content-Type": "application/json"}, json={"token": token}, ) + handle_server_errors(response) + return response.json()["data"].get("token") def set_organization(self, organization_guid: str) -> None: """Set the organization GUID. @@ -173,8 +182,10 @@ def get_project(self, project_guid: str = None, **kwargs) -> Munch: ) # @handle_and_munchify_response - @walk_pages - def list_projects(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: + @handle_and_munchify_response + def list_projects( + self, cont: int = 0, limit: int = 50, status: list[str] = None, **kwargs + ) -> Munch: """List all projects. Returns: @@ -184,7 +195,7 @@ def list_projects(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: return self.c.get( url=f"{self.v2api_host}/v2/projects/", headers=self.config.get_headers(with_project=False, **kwargs), - params={"continue": cont, "limit": limit}, + params={"continue": cont, "limit": limit, "status": status}, ) @handle_and_munchify_response @@ -252,7 +263,7 @@ def update_project_owner( ) # -------------------Package------------------- - @walk_pages + @handle_and_munchify_response def list_packages(self, cont: int = 0, limit: int = 10, **kwargs) -> Munch: """List all packages in a project. @@ -313,7 +324,7 @@ def delete_package(self, name: str, **kwargs) -> Munch: ) # -------------------Deployment------------------- - @walk_pages + @handle_and_munchify_response def list_deployments(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all deployments in a project. @@ -389,7 +400,7 @@ def delete_deployment(self, name: str, **kwargs) -> Munch: ) # -------------------Disks------------------- - @walk_pages + @handle_and_munchify_response def list_disks(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all disks in a project. @@ -450,7 +461,7 @@ def delete_disk(self, name: str, **kwargs) -> Munch: ) # -------------------Static Routes------------------- - @walk_pages + @handle_and_munchify_response def list_staticroutes(self, cont: int = 0, limit: int = 0, **kwargs) -> Munch: """List all static routes in a project. @@ -529,7 +540,7 @@ def delete_staticroute(self, name: str, **kwargs) -> Munch: ) # -------------------Networks------------------- - @walk_pages + @handle_and_munchify_response def list_networks(self, cont: int = 0, limit: int = 0, **kwargs) -> Munch: """List all networks in a project. @@ -590,7 +601,7 @@ def delete_network(self, name: str, **kwargs) -> Munch: ) # -------------------Secrets------------------- - @walk_pages + @handle_and_munchify_response def list_secrets(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all secrets in a project. @@ -669,7 +680,7 @@ def delete_secret(self, name: str, **kwargs) -> Munch: ) # -------------------Config Trees------------------- - @walk_pages + @handle_and_munchify_response def list_configtrees(self, cont: int = 0, limit: int = 50, **kwargs) -> Munch: """List all config trees in a project. @@ -771,7 +782,7 @@ def delete_configtree(self, name: str, **kwargs) -> Munch: headers=self.config.get_headers(**kwargs), ) - @walk_pages + @handle_and_munchify_response def list_revisions( self, name: str, diff --git a/rapyuta_io_sdk_v2/utils.py b/rapyuta_io_sdk_v2/utils.py index 7c9f522..c832bd8 100644 --- a/rapyuta_io_sdk_v2/utils.py +++ b/rapyuta_io_sdk_v2/utils.py @@ -14,6 +14,7 @@ # limitations under the License. # from rapyuta_io_sdk_v2.config import Configuration import asyncio +import functools import json import os import sys @@ -21,7 +22,7 @@ import httpx import rapyuta_io_sdk_v2.exceptions as exceptions -from munch import munchify +from munch import munchify, Munch def handle_server_errors(response: httpx.Response): @@ -90,12 +91,14 @@ def get_default_app_dir(app_name: str) -> str: # Decorator to handle server errors and munchify response def handle_and_munchify_response(func): - async def async_wrapper(*args, **kwargs): + @functools.wraps(func) + async def async_wrapper(*args, **kwargs) -> Munch: response = await func(*args, **kwargs) handle_server_errors(response) return munchify(response.json()) - def sync_wrapper(*args, **kwargs): + @functools.wraps(func) + def sync_wrapper(*args, **kwargs) -> Munch: response = func(*args, **kwargs) handle_server_errors(response) return munchify(response.json()) @@ -106,55 +109,32 @@ def sync_wrapper(*args, **kwargs): return sync_wrapper -def handle_auth_token(func): - async def async_wrapper(self, *args, **kwargs): - response = await func(self, *args, **kwargs) - handle_server_errors(response) - self.config.auth_token = response.json()["data"].get("token") - return self.config.auth_token - - def sync_wrapper(self, *args, **kwargs): - response = func(self, *args, **kwargs) - handle_server_errors(response) - self.config.auth_token = response.json()["data"].get("token") - return self.config.auth_token - - if asyncio.iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper - - -def walk_pages(func): - def wrapper(self, *args, **kwargs): - result = {"items": []} - - limit = kwargs.pop("limit", 50) - limit = int(limit) if limit else 50 - - cont = kwargs.pop("cont", 0) - cont = int(cont) if cont else 0 - - while True: - response = func(self, cont, limit, **kwargs) - handle_server_errors(response) - - data = response.json() - items = data.get("items", []) - if not items: - break +def walk_pages(func, *args, limit=50, cont=0, **kwargs): + """ + A generator function to paginate through API results. - cont = data.get("metadata", {}).get("continue", None) - if cont is None: - break + Args: + func (callable): The API function to call, must accept `cont` and `limit` as arguments. + *args: Positional arguments to pass to the API function. + limit (int, optional): Maximum number of items to return. Defaults to 50. + cont (int, optional): Initial continuation token. Defaults to 0. + **kwargs: Additional keyword arguments to pass to the API function. - result["items"].extend(items) + Yields: + Munch: Each item from the API response. + """ + while True: + response = func(cont, limit, *args, **kwargs) - # Stop if we reach the limit - if limit is not None and len(result["items"]) >= limit: - result["items"] = result["items"][:limit] - return munchify(result) + data = response + items = data.get("items", []) + if not items: + break - return munchify(result) + for item in items: + yield munchify(item) - return wrapper + # Update `cont` for the next page + cont = data.get("metadata", {}).get("continue", None) + if cont is None: + break