Skip to content

Commit

Permalink
feat: adds pagination feature (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtyoung84 authored May 31, 2024
1 parent c489716 commit 5aab8b0
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 54 deletions.
136 changes: 109 additions & 27 deletions src/aind_codeocean_api/codeocean.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
from enum import Enum
from inspect import signature
from time import sleep
from typing import Dict, List, Optional, Union

import requests
Expand Down Expand Up @@ -200,7 +201,7 @@ def search_data_assets(

return response

def search_all_data_assets(
def _paginate_data_assets(
self,
sort_order: Optional[str] = None,
sort_field: Optional[str] = None,
Expand All @@ -209,9 +210,9 @@ def search_all_data_assets(
favorite: Optional[bool] = None,
archived: Optional[bool] = None,
query: Optional[str] = None,
) -> requests.models.Response:
):
"""
Utility method to return all the search results that match a query
Utility method to paginate through search results returned by search
Parameters
----------
sort_order : Optional[str]
Expand All @@ -232,11 +233,9 @@ def search_all_data_assets(
Returns
-------
requests.models.Response
Iterator[List[dict]]
"""

# TODO: it'd be nice to re-use the search_data_assets function, but
# it'll require passing in a requests.Session object into that method.
frame_locals = locals()
query_params = dict(
[
Expand All @@ -248,34 +247,117 @@ def search_all_data_assets(
]
)

requests_session = requests.Session()
all_results = []
def get_page(
r: requests.Session,
qp: dict,
max_retries: int = 3,
) -> dict:
"""
Get a single list of results back from Code Ocean. It will retry
a request up to the max amount of retries. It will wait
min(retry_count**2, 15) seconds.
Parameters
----------
r : requests.Session
qp : dict
query parameters
max_retries : int
Max number of retries before raising an error
Returns
-------
dict
Response from Code Ocean
"""
rsp = r.get(self.asset_url, params=qp, auth=(self.token, ""))
if rsp.status_code == 200:
return rsp.json()
else:
retry = 1
while retry <= max_retries and rsp.status_code != 200:
logging.debug(
f"Backing off and retrying: {retry}. "
f"Reason: {rsp.status_code}"
)
sleep(min(retry**2, 15))
retry += 1
rsp = r.get(
self.asset_url, params=qp, auth=(self.token, "")
)
if rsp.status_code == 200:
return rsp.json()
else:
raise ConnectionError(
f"There was an error getting data from Code Ocean: "
f"{rsp.status_code}"
)

with requests.Session() as requests_session:
has_more = True
status_code = 200
start_index = 0
limit = self._MAX_SEARCH_BATCH_REQUEST
while has_more and status_code == 200:
while has_more:
query_params[self._Fields.START.value] = start_index
query_params[self._Fields.LIMIT.value] = limit
response = requests_session.get(
self.asset_url, params=query_params, auth=(self.token, "")
)

self.logger.info(response.url)

status_code = response.status_code
if status_code == 200:
has_more = response.json()[self._Fields.HAS_MORE.value]
response_results = response.json()[
self._Fields.RESULTS.value
]
num_of_results = len(response_results)
all_results.extend(response_results)
has_more = has_more if num_of_results > 0 else False
start_index += num_of_results
else:
return response
page = get_page(requests_session, query_params)
has_more = page.get(self._Fields.HAS_MORE.value)
results = page.get("results", [])
num_of_results = len(results)
has_more = has_more if num_of_results > 0 else False
start_index += num_of_results
yield results

def search_all_data_assets(
self,
sort_order: Optional[str] = None,
sort_field: Optional[str] = None,
type: Optional[str] = None,
ownership: Optional[str] = None,
favorite: Optional[bool] = None,
archived: Optional[bool] = None,
query: Optional[str] = None,
) -> requests.models.Response:
"""
Utility method to return all the search results that match a query
Parameters
----------
sort_order : Optional[str]
Determines the result sort order.
sort_field : Optional[str]
Determines the field to sort by.
type : Optional[str]
Type of data asset: dataset or result.
Returns both if omitted.
ownership : Optional[str]
Search data asset by ownership: owner or shared.
favorite : Optional[bool]
Search only favorite data assets.
archived : Optional[bool]
Search only archived data assets.
query : Optional[str]
Determines the search query.
Returns
-------
requests.models.Response
"""

# TODO: it'd be nice to re-use the search_data_assets function, but
# it'll require passing in a requests.Session object into that method.

all_results = []

for page in self._paginate_data_assets(
sort_order=sort_order,
sort_field=sort_field,
type=type,
ownership=ownership,
favorite=favorite,
archived=archived,
query=query,
):
all_results.extend(list(page))

all_response = requests.Response()
all_response.status_code = 200
Expand Down
138 changes: 111 additions & 27 deletions tests/test_codeocean_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def request_patch_response(url: str) -> MockResponse:
"""Mock a patch response"""
success_message = map_input_to_success_message(url)
return MockResponse(
status_code=202, content=success_message, url=url
status_code=204, content=success_message, url=url
)

def request_delete_response(url: str) -> MockResponse:
Expand Down Expand Up @@ -507,16 +507,6 @@ def test_search_all_data_assets(

actual_response = response.json()

bad_response = requests.Response()
bad_response.status_code = 500
bad_response._content = json.dumps(
{"message": "Internal Server Error"}
).encode("utf-8")
mock_api_get.side_effect = [bad_response]
expected_bad_response = {"message": "Internal Server Error"}
response_bad = self.co_client.search_all_data_assets(archived=False)
actual_bad_response = response_bad.json()

mock_api_get.assert_has_calls(
[
call(
Expand All @@ -529,17 +519,111 @@ def test_search_all_data_assets(
params={"start": 2, "limit": 1000},
auth=("CODEOCEAN_API_TOKEN", ""),
),
call(
"https://acmecorp.codeocean.com/api/v1/data_assets",
params={"archived": False, "start": 0, "limit": 1000},
auth=("CODEOCEAN_API_TOKEN", ""),
),
]
)
self.assertEqual(200, response.status_code)
self.assertEqual(expected_response, actual_response)
self.assertEqual(500, response_bad.status_code)
self.assertEqual(expected_bad_response, actual_bad_response)

@mock.patch("requests.Session.get")
@mock.patch("aind_codeocean_api.codeocean.sleep", return_value=None)
def test_search_all_data_assets_bad_response_max_retry_once(
self,
mock_sleep: unittest.mock.MagicMock,
mock_api_get: unittest.mock.MagicMock,
) -> None:
"""Tests search_all_data_assets method when a bad response is
returned once and then a good response is returned."""

mocked_response1 = requests.Response()
mocked_response1.status_code = 200
mocked_response1._content = json.dumps(
{
"has_more": True,
"results": [
{"id": "abc123", "type": "dataset"},
{"id": "def456", "type": "result"},
],
}
).encode("utf-8")

mocked_response2 = requests.Response()
mocked_response2.status_code = 200
mocked_response2._content = json.dumps(
{
"has_more": False,
"results": [
{"id": "ghi789", "type": "result"},
{"id": "jkl101", "type": "result"},
],
}
).encode("utf-8")

bad_response = requests.Response()
bad_response.status_code = 500
bad_response._content = json.dumps(
{"message": "Internal Server Error"}
).encode("utf-8")

mock_api_get.side_effect = [
mocked_response1,
bad_response,
mocked_response2,
]
expected_response = {
"results": [
{"id": "abc123", "type": "dataset"},
{"id": "def456", "type": "result"},
{"id": "ghi789", "type": "result"},
{"id": "jkl101", "type": "result"},
]
}
response = self.co_client.search_all_data_assets()
actual_response = response.json()
self.assertEqual(expected_response, actual_response)
mock_sleep.assert_has_calls([call(1)])

@mock.patch("requests.Session.get")
@mock.patch("aind_codeocean_api.codeocean.sleep", return_value=None)
def test_search_all_data_assets_bad_response_max_retries(
self,
mock_sleep: unittest.mock.MagicMock,
mock_api_get: unittest.mock.MagicMock,
) -> None:
"""Tests search_all_data_assets method when a bad response is
returned."""

mocked_response1 = requests.Response()
mocked_response1.status_code = 200
mocked_response1._content = json.dumps(
{
"has_more": True,
"results": [
{"id": "abc123", "type": "dataset"},
{"id": "def456", "type": "result"},
],
}
).encode("utf-8")

bad_response = requests.Response()
bad_response.status_code = 500
bad_response._content = json.dumps(
{"message": "Internal Server Error"}
).encode("utf-8")

mock_api_get.side_effect = [
mocked_response1,
bad_response,
bad_response,
bad_response,
bad_response,
]
with self.assertRaises(ConnectionError) as e:
self.co_client.search_all_data_assets()
self.assertEqual(
"There was an error getting data from Code Ocean: 500",
e.exception.args[0],
)
mock_sleep.assert_has_calls([call(1), call(4), call(9)])

@mock.patch("requests.put")
def test_update_data_asset(
Expand Down Expand Up @@ -949,12 +1033,12 @@ def mock_success_response() -> Callable[..., MockResponse]:

def request_post_response(json: dict) -> MockResponse:
"""Mock a post response"""
return MockResponse(status_code=204, content=None, url="")
return MockResponse(status_code=204, content={}, url="")

return request_post_response

users = ([{"email": "[email protected]", "role": "viewer"}],)
groups = ([{"group": "group4", "role": "viewer"}],)
users = [{"email": "[email protected]", "role": "viewer"}]
groups = [{"group": "group4", "role": "viewer"}]
everyone = "viewer"

example_data_asset_id = "648473aa-791e-4372-bd25-205cc587ec56"
Expand Down Expand Up @@ -987,12 +1071,12 @@ def mock_success_response() -> Callable[..., MockResponse]:

def request_post_response(json: dict) -> MockResponse:
"""Mock a post response"""
return MockResponse(status_code=204, content=None, url="")
return MockResponse(status_code=204, content={}, url="")

return request_post_response

users = ([{"email": "[email protected]", "role": "viewer"}],)
groups = ([{"group": "group4", "role": "viewer"}],)
users: List[dict] = [{"email": "[email protected]", "role": "viewer"}]
groups: List[dict] = [{"group": "group4", "role": "viewer"}]

example_data_asset_id = "648473aa-791e-4372-bd25-205cc587ec56"
input_json_data = {
Expand All @@ -1017,7 +1101,7 @@ def test_archive_data_asset(

def map_to_success_message(_) -> dict:
"""Map to a success message"""
return ""
return {}

mocked_success_patch = self.mock_success_response(
map_to_success_message, req_type="patch"
Expand All @@ -1036,7 +1120,7 @@ def map_to_success_message(_) -> dict:
)

self.assertEqual(response.url, expected_url)
self.assertEqual(response.status_code, 202)
self.assertEqual(response.status_code, 204)

@mock.patch("requests.delete")
def test_delete_data_asset(
Expand All @@ -1046,7 +1130,7 @@ def test_delete_data_asset(

def map_to_success_message(_) -> dict:
"""Map to a success message"""
return ""
return {}

mocked_success_delete = self.mock_success_response(
map_to_success_message, req_type="delete"
Expand Down

0 comments on commit 5aab8b0

Please sign in to comment.