From 809d30e0e46b558f2f21788cefe0761ecacdde6f Mon Sep 17 00:00:00 2001 From: jake Date: Tue, 7 Jan 2025 10:51:27 -0800 Subject: [PATCH] Refactored Account class to a dataclass, other minor cleanup --- internetarchive/account.py | 89 ++++++++++++++----------------- internetarchive/cli/ia_account.py | 2 +- 2 files changed, 40 insertions(+), 51 deletions(-) diff --git a/internetarchive/account.py b/internetarchive/account.py index 308e3ecb..bc643be2 100644 --- a/internetarchive/account.py +++ b/internetarchive/account.py @@ -1,49 +1,40 @@ import json +from dataclasses import dataclass, field from typing import ClassVar, Dict, List, Optional import requests from internetarchive import get_session from internetarchive.exceptions import AccountAPIError +from internetarchive.session import ArchiveSession +@dataclass class Account: + locked: bool + verified: bool + email: str + canonical_email: str + itemname: str + screenname: str + notifications: List[str] + has_disability_access: bool + session: ArchiveSession = field(default_factory=get_session) + API_BASE_URL: str = '/services/xauthn/' API_INFO_PARAMS: ClassVar[Dict[str, str]] = {'op': 'info'} API_LOCK_UNLOCK_PARAMS: ClassVar[Dict[str, str]] = {'op': 'lock_unlock'} - def __init__( - self, - locked: bool, - verified: bool, - email: str, - canonical_email: str, - itemname: str, - screenname: str, - notifications: List[str], - has_disability_access: bool, - session: Optional[requests.Session] = None - ): - self.locked = locked - self.verified = verified - self.email = email - self.canonical_email = canonical_email - self.itemname = itemname - self.screenname = screenname - self.notifications = notifications - self.has_disability_access = has_disability_access - self.session = session or get_session() - def _get_api_base_url(self) -> str: """Dynamically construct the API base URL using the session's host.""" return f'https://{self.session.host}{self.API_BASE_URL}' # type: ignore[attr-defined] - def _make_api_request( + def _post_api_request( self, endpoint: str, params: Dict[str, str], data: Dict[str, str], - session: Optional[requests.Session] = None + session: Optional[ArchiveSession] = None ) -> requests.Response: """ Helper method to make API requests. @@ -71,7 +62,7 @@ def from_account_lookup( cls, identifier_type: str, identifier: str, - session: Optional[requests.Session] = None + session: Optional[ArchiveSession] = None ) -> "Account": """ Factory method to initialize an Account using an identifier type and value. @@ -92,7 +83,7 @@ def _fetch_account_data_from_api( cls, identifier_type: str, identifier: str, - session: Optional[requests.Session] = None + session: Optional[ArchiveSession] = None ) -> Dict: """ Fetches account data from the API using an identifier type and value. @@ -123,13 +114,14 @@ def _fetch_account_data_from_api( raise AccountAPIError(j.get("error", "Unknown error"), error_data=j) return j["values"] except requests.exceptions.RequestException as e: - raise ValueError(f"Failed to fetch account data: {e}") + raise AccountAPIError(f"Failed to fetch account data: {e}") + @classmethod def from_json( cls, json_data: Dict, - session: Optional[requests.Session] = None + session: Optional[ArchiveSession] = None ) -> "Account": """ Factory method to initialize an Account using JSON data. @@ -154,9 +146,15 @@ def from_json( "screenname", "verified" ] - for field in required_fields: - if field not in json_data: - raise ValueError(f"Missing required field in JSON data: {field}") + for requried_field in required_fields: + if requried_field not in json_data: + raise ValueError(f"Missing required requried_field in JSON data: {requried_field}") + + # Ensure session is of type ArchiveSession + if session is None: + session = get_session() # Default to ArchiveSession + elif not isinstance(session, ArchiveSession): + raise TypeError(f"Expected session to be of type ArchiveSession, got {type(session)}") return cls( locked=json_data["locked"], @@ -172,7 +170,7 @@ def from_json( def lock(self, comment: Optional[str] = None, - session: Optional[requests.Session] = None) -> requests.Response: + session: Optional[ArchiveSession] = None) -> requests.Response: """ Lock the account. @@ -183,10 +181,10 @@ def lock(self, Returns: The response from the API. """ - data = {'itemname': self.itemname, 'is_lock': "1"} + data = {'itemname': self.itemname, 'is_lock': '1'} if comment: data['comments'] = comment - return self._make_api_request( + return self._post_api_request( self.API_BASE_URL, params=self.API_LOCK_UNLOCK_PARAMS, data=data, @@ -195,7 +193,7 @@ def lock(self, def unlock(self, comment: Optional[str] = None, - session: Optional[requests.Session] = None) -> requests.Response: + session: Optional[ArchiveSession] = None) -> requests.Response: """ Unlock the account. @@ -206,24 +204,24 @@ def unlock(self, Returns: The response from the API. """ - data = {'itemname': self.itemname, 'is_lock': "0"} + data = {'itemname': self.itemname, 'is_lock': '0'} if comment: data['comments'] = comment - return self._make_api_request( + return self._post_api_request( self.API_BASE_URL, params=self.API_LOCK_UNLOCK_PARAMS, data=data, session=session ) - def __iter__(self): + def to_dict(self) -> Dict: """ - Allows the Account instance to be converted to a dictionary using dict(Account). + Converts the Account instance to a dictionary. Returns: A dictionary representation of the Account instance. """ - return iter({ + return { "locked": self.locked, "verified": self.verified, "email": self.email, @@ -232,13 +230,4 @@ def __iter__(self): "screenname": self.screenname, "notifications": self.notifications, "has_disability_access": self.has_disability_access, - }.items()) - - def __repr__(self) -> str: - return ( - f"Account(locked={self.locked}, verified={self.verified}, " - f"email={self.email}, canonical_email={self.canonical_email}, " - f"itemname={self.itemname}, screenname={self.screenname}, " - f"notifications={self.notifications}, " - f"has_disability_access={self.has_disability_access})" - ) + } diff --git a/internetarchive/cli/ia_account.py b/internetarchive/cli/ia_account.py index e2a3fcac..fa40d64d 100644 --- a/internetarchive/cli/ia_account.py +++ b/internetarchive/cli/ia_account.py @@ -101,5 +101,5 @@ def main(args: argparse.Namespace) -> None: r = account.unlock("test unlock", session=args.session) print(r.text) else: - account_data = dict(account) + account_data = account.to_dict() print(json.dumps(account_data))