Skip to content

Commit

Permalink
Refactored Account class to a dataclass, other minor cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jjjake committed Jan 7, 2025
1 parent 20b0ff7 commit 809d30e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 51 deletions.
89 changes: 39 additions & 50 deletions internetarchive/account.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,40 @@
import json
from dataclasses import dataclass, field
from typing import ClassVar, Dict, List, Optional

import requests

from internetarchive import get_session
from internetarchive.exceptions import AccountAPIError
from internetarchive.session import ArchiveSession


@dataclass
class Account:
locked: bool
verified: bool
email: str
canonical_email: str
itemname: str
screenname: str
notifications: List[str]
has_disability_access: bool
session: ArchiveSession = field(default_factory=get_session)

API_BASE_URL: str = '/services/xauthn/'
API_INFO_PARAMS: ClassVar[Dict[str, str]] = {'op': 'info'}
API_LOCK_UNLOCK_PARAMS: ClassVar[Dict[str, str]] = {'op': 'lock_unlock'}

def __init__(
self,
locked: bool,
verified: bool,
email: str,
canonical_email: str,
itemname: str,
screenname: str,
notifications: List[str],
has_disability_access: bool,
session: Optional[requests.Session] = None
):
self.locked = locked
self.verified = verified
self.email = email
self.canonical_email = canonical_email
self.itemname = itemname
self.screenname = screenname
self.notifications = notifications
self.has_disability_access = has_disability_access
self.session = session or get_session()

def _get_api_base_url(self) -> str:
"""Dynamically construct the API base URL using the session's host."""
return f'https://{self.session.host}{self.API_BASE_URL}' # type: ignore[attr-defined]

def _make_api_request(
def _post_api_request(
self,
endpoint: str,
params: Dict[str, str],
data: Dict[str, str],
session: Optional[requests.Session] = None
session: Optional[ArchiveSession] = None
) -> requests.Response:
"""
Helper method to make API requests.
Expand Down Expand Up @@ -71,7 +62,7 @@ def from_account_lookup(
cls,
identifier_type: str,
identifier: str,
session: Optional[requests.Session] = None
session: Optional[ArchiveSession] = None
) -> "Account":
"""
Factory method to initialize an Account using an identifier type and value.
Expand All @@ -92,7 +83,7 @@ def _fetch_account_data_from_api(
cls,
identifier_type: str,
identifier: str,
session: Optional[requests.Session] = None
session: Optional[ArchiveSession] = None
) -> Dict:
"""
Fetches account data from the API using an identifier type and value.
Expand Down Expand Up @@ -123,13 +114,14 @@ def _fetch_account_data_from_api(
raise AccountAPIError(j.get("error", "Unknown error"), error_data=j)
return j["values"]
except requests.exceptions.RequestException as e:
raise ValueError(f"Failed to fetch account data: {e}")
raise AccountAPIError(f"Failed to fetch account data: {e}")


@classmethod
def from_json(
cls,
json_data: Dict,
session: Optional[requests.Session] = None
session: Optional[ArchiveSession] = None
) -> "Account":
"""
Factory method to initialize an Account using JSON data.
Expand All @@ -154,9 +146,15 @@ def from_json(
"screenname",
"verified"
]
for field in required_fields:
if field not in json_data:
raise ValueError(f"Missing required field in JSON data: {field}")
for requried_field in required_fields:
if requried_field not in json_data:
raise ValueError(f"Missing required requried_field in JSON data: {requried_field}")

# Ensure session is of type ArchiveSession
if session is None:
session = get_session() # Default to ArchiveSession
elif not isinstance(session, ArchiveSession):
raise TypeError(f"Expected session to be of type ArchiveSession, got {type(session)}")

return cls(
locked=json_data["locked"],
Expand All @@ -172,7 +170,7 @@ def from_json(

def lock(self,
comment: Optional[str] = None,
session: Optional[requests.Session] = None) -> requests.Response:
session: Optional[ArchiveSession] = None) -> requests.Response:
"""
Lock the account.
Expand All @@ -183,10 +181,10 @@ def lock(self,
Returns:
The response from the API.
"""
data = {'itemname': self.itemname, 'is_lock': "1"}
data = {'itemname': self.itemname, 'is_lock': '1'}
if comment:
data['comments'] = comment
return self._make_api_request(
return self._post_api_request(
self.API_BASE_URL,
params=self.API_LOCK_UNLOCK_PARAMS,
data=data,
Expand All @@ -195,7 +193,7 @@ def lock(self,

def unlock(self,
comment: Optional[str] = None,
session: Optional[requests.Session] = None) -> requests.Response:
session: Optional[ArchiveSession] = None) -> requests.Response:
"""
Unlock the account.
Expand All @@ -206,24 +204,24 @@ def unlock(self,
Returns:
The response from the API.
"""
data = {'itemname': self.itemname, 'is_lock': "0"}
data = {'itemname': self.itemname, 'is_lock': '0'}
if comment:
data['comments'] = comment
return self._make_api_request(
return self._post_api_request(
self.API_BASE_URL,
params=self.API_LOCK_UNLOCK_PARAMS,
data=data,
session=session
)

def __iter__(self):
def to_dict(self) -> Dict:
"""
Allows the Account instance to be converted to a dictionary using dict(Account).
Converts the Account instance to a dictionary.
Returns:
A dictionary representation of the Account instance.
"""
return iter({
return {
"locked": self.locked,
"verified": self.verified,
"email": self.email,
Expand All @@ -232,13 +230,4 @@ def __iter__(self):
"screenname": self.screenname,
"notifications": self.notifications,
"has_disability_access": self.has_disability_access,
}.items())

def __repr__(self) -> str:
return (
f"Account(locked={self.locked}, verified={self.verified}, "
f"email={self.email}, canonical_email={self.canonical_email}, "
f"itemname={self.itemname}, screenname={self.screenname}, "
f"notifications={self.notifications}, "
f"has_disability_access={self.has_disability_access})"
)
}
2 changes: 1 addition & 1 deletion internetarchive/cli/ia_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,5 @@ def main(args: argparse.Namespace) -> None:
r = account.unlock("test unlock", session=args.session)
print(r.text)
else:
account_data = dict(account)
account_data = account.to_dict()
print(json.dumps(account_data))

0 comments on commit 809d30e

Please sign in to comment.