From 51ac38928d6279c6479c7ac4cb32ab026907230d Mon Sep 17 00:00:00 2001 From: tooyosi Date: Tue, 10 Sep 2024 06:31:14 +0100 Subject: [PATCH] refactor logic to use PanoptesObject and add tests --- panoptes_client/__init__.py | 1 - panoptes_client/aggregation.py | 13 +++ panoptes_client/batch_aggregation.py | 44 -------- .../tests/test_batch_aggregation.py | 72 ------------- panoptes_client/tests/test_workflow.py | 100 +++++++++++------- panoptes_client/workflow.py | 65 +++++++----- 6 files changed, 114 insertions(+), 181 deletions(-) create mode 100644 panoptes_client/aggregation.py delete mode 100644 panoptes_client/batch_aggregation.py delete mode 100644 panoptes_client/tests/test_batch_aggregation.py diff --git a/panoptes_client/__init__.py b/panoptes_client/__init__.py index da797507..c0fe78c0 100644 --- a/panoptes_client/__init__.py +++ b/panoptes_client/__init__.py @@ -13,4 +13,3 @@ from panoptes_client.subject_workflow_status import SubjectWorkflowStatus from panoptes_client.caesar import Caesar from panoptes_client.inaturalist import Inaturalist -from panoptes_client.batch_aggregation import BatchAggregation diff --git a/panoptes_client/aggregation.py b/panoptes_client/aggregation.py new file mode 100644 index 00000000..28c1c404 --- /dev/null +++ b/panoptes_client/aggregation.py @@ -0,0 +1,13 @@ +from panoptes_client.panoptes import PanoptesObject + +class Aggregation(PanoptesObject): + _api_slug = 'aggregations' + _link_slug = 'aggregations' + _edit_attributes = ( + { + 'links': ( + 'workflow', + 'user', + ) + }, + ) \ No newline at end of file diff --git a/panoptes_client/batch_aggregation.py b/panoptes_client/batch_aggregation.py deleted file mode 100644 index 8aa4f95a..00000000 --- a/panoptes_client/batch_aggregation.py +++ /dev/null @@ -1,44 +0,0 @@ -from panoptes_client.panoptes import Panoptes, PanoptesAPIException - - -class BatchAggregation(object): - - def get_aggregations(self, workflow_id): - return Panoptes.client().get(f'/aggregations?workflow_id={workflow_id}') - - def get_aggregation(self, agg_id): - return Panoptes.client().get(f'/aggregations/{agg_id}') - - def create_aggregation(self, payload): - return Panoptes.client().post(f'/aggregations', json=payload) - - def delete_aggregation(self, agg_id, etag): - return Panoptes.client().delete(f'/aggregations/{agg_id}', etag=etag) - - def fetch_and_delete_aggregation(self, agg_id): - try: - single_agg = self.get_aggregation(agg_id) - self.delete_aggregation(agg_id, single_agg[1]) - except PanoptesAPIException as err: - raise err - - def run_aggregation(self, payload, delete_if_exists): - try: - # if exists and delete_if_exists, delete the first - response = self.get_aggregations(payload["aggregations"]["links"]["workflow"]) - if response and response[0]: - has_agg = len(response[0]['aggregations']) > 0 - if delete_if_exists and has_agg: - agg_id = response[0]['aggregations'][0]['id'] - # delete - self.fetch_and_delete_aggregation(agg_id) - # create - return self.create_aggregation(payload)[0] - elif not has_agg: - return self.create_aggregation(payload)[0] - else: - return response[0] - else: - self.create_aggregation(payload)[0] - except PanoptesAPIException as err: - raise err diff --git a/panoptes_client/tests/test_batch_aggregation.py b/panoptes_client/tests/test_batch_aggregation.py deleted file mode 100644 index 6548380f..00000000 --- a/panoptes_client/tests/test_batch_aggregation.py +++ /dev/null @@ -1,72 +0,0 @@ -import unittest -import sys -from panoptes_client.batch_aggregation import BatchAggregation - -if sys.version_info <= (3, 0): - from mock import patch -else: - from unittest.mock import patch - -class TestBatchAggregation(unittest.TestCase): - - def setUp(self): - super().setUp() - self.mock_client_patch = patch('panoptes_client.panoptes.Panoptes.client') - self.mock_client = self.mock_client_patch.start() - - self.addCleanup(self.mock_client_patch.stop) - - batch_agg_get_aggregations_patch = patch.object(BatchAggregation, 'get_aggregations') - fetch_and_delete_aggregation_patch = patch.object(BatchAggregation, 'fetch_and_delete_aggregation') - create_aggregation_patch = patch.object(BatchAggregation, 'create_aggregation') - - self.batch_agg_get_aggregations_mock = batch_agg_get_aggregations_patch.start() - self.fetch_and_delete_aggregation_mock = fetch_and_delete_aggregation_patch.start() - self.create_aggregation_mock = create_aggregation_patch.start() - - self.addCleanup(batch_agg_get_aggregations_patch.stop) - self.addCleanup(fetch_and_delete_aggregation_patch.stop) - self.addCleanup(create_aggregation_patch.stop) - - self.payload = { - "aggregations": { - "links": { - "user": 1, - "workflow": 1, - } - } - } - - self.agg_mock_value = [{ - 'aggregations': [{ - 'id': '1', - 'href': '/aggregations/1', - 'created_at': '2024-08-13T10:26:32.560Z', - 'updated_at': '2024-08-13T10:26:32.576Z', - 'uuid': None, - 'task_id': 'task_id', - 'status': 'pending', - 'links': {'project': '1', 'workflow': '1', 'user': '1'} - }] - }, 'etag'] - - def test_run_aggregation_with_delete_if_exists(self): - self.batch_agg_get_aggregations_mock.return_value = self.agg_mock_value - - BatchAggregation().run_aggregation(payload=self.payload, delete_if_exists=True) - self.fetch_and_delete_aggregation_mock.assert_called_with(self.agg_mock_value[0]['aggregations'][0]['id']) - self.create_aggregation_mock.assert_called_with(self.payload) - - def test_run_aggregation_with_previously_created_case(self): - self.batch_agg_get_aggregations_mock.return_value = self.agg_mock_value - - BatchAggregation().run_aggregation(payload=self.payload, delete_if_exists=False) - self.fetch_and_delete_aggregation_mock.assert_not_called() - self.create_aggregation_mock.assert_not_called() - - def test_run_aggregation_with_no_previously_created_case(self): - self.batch_agg_get_aggregations_mock.return_value = [{'aggregations': []}, 'etag'] - - BatchAggregation().run_aggregation(payload=self.payload, delete_if_exists=False) - self.fetch_and_delete_aggregation_mock.assert_not_called() - self.create_aggregation_mock.assert_called_with(self.payload) diff --git a/panoptes_client/tests/test_workflow.py b/panoptes_client/tests/test_workflow.py index 0ebc56e0..c8f49dcd 100644 --- a/panoptes_client/tests/test_workflow.py +++ b/panoptes_client/tests/test_workflow.py @@ -3,12 +3,12 @@ from panoptes_client.panoptes import PanoptesAPIException from panoptes_client.workflow import Workflow from panoptes_client.caesar import Caesar -from panoptes_client.batch_aggregation import BatchAggregation +from panoptes_client.aggregation import Aggregation if sys.version_info <= (3, 0): - from mock import patch + from mock import patch, MagicMock else: - from unittest.mock import patch + from unittest.mock import patch, MagicMock class TestWorkflow(unittest.TestCase): @@ -25,14 +25,6 @@ def setUp(self): self.addCleanup(caesar_get_patch.stop) self.addCleanup(caesar_put_patch.stop) - batch_agg_run_aggregation_patch = patch.object(BatchAggregation, 'run_aggregation') - batch_agg_get_aggregations_patch = patch.object(BatchAggregation, 'get_aggregations') - self.batch_agg_run_aggregation_mock = batch_agg_run_aggregation_patch.start() - self.batch_agg_get_aggregations_mock = batch_agg_get_aggregations_patch.start() - - self.addCleanup(batch_agg_run_aggregation_patch.stop) - self.addCleanup(batch_agg_get_aggregations_patch.stop) - self.agg_mock_value = [{ 'aggregations': [{ 'id': '1', @@ -231,40 +223,68 @@ def test_add_caesar_rule_effect_invalid_effect(self): self.caesar_post_mock.assert_not_called() self.assertEqual('Invalid action for rule type', str(invalid_effect_err.exception)) - def test_get_batch_aggregations_without_delete_param(self): - workflow = Workflow(1) - workflow.run_batch_aggregation(user=1) - payload = { - "aggregations": { - "links": { - "user": 1, - "workflow": workflow.id, - } - } - } - self.batch_agg_run_aggregation_mock.assert_called_with(payload, False) +class TestAggregation(unittest.TestCase): + def setUp(self): + self.instance = Workflow(1) + self.mock_user_id = 1 - def test_get_batch_aggregations_with_invalid_params(self): - with self.assertRaises(TypeError): - workflow = Workflow(1) - workflow.run_batch_aggregation(user=None) + def _mock_aggregation(self): + mock_aggregation = MagicMock() + mock_aggregation.object_count = 1 + mock_aggregation.next = MagicMock(return_value=MagicMock(id=1)) + return mock_aggregation - self.batch_agg_run_aggregation_mock.assert_not_called() + @patch.object(Aggregation, 'where') + @patch.object(Aggregation, 'find') + def test_run_aggregation_with_user_object(self, mock_find, mock_where): + mock_where.return_value = self._mock_aggregation() - def test_check_batch_aggregation_run_status(self): - self.batch_agg_get_aggregations_mock.return_value = self.agg_mock_value + mock_current_agg = MagicMock() + mock_find.return_value = mock_current_agg - workflow = Workflow(1) - status = workflow.check_batch_aggregation_run_status() + result = self.instance.run_aggregation(self.mock_user_id, False) - self.batch_agg_get_aggregations_mock.assert_called_once() - self.assertEqual(status, 'pending') + self.assertEqual(result, mock_current_agg) - def test_get_batch_aggregation_links(self): - self.batch_agg_get_aggregations_mock.return_value = self.agg_mock_value + @patch.object(Aggregation, 'find') + @patch.object(Aggregation, 'where') + @patch.object(Aggregation, 'save') + def test_run_aggregation_with_delete_if_true(self, mock_save, mock_where, mock_find): + mock_where.return_value = self._mock_aggregation() - workflow = Workflow(1) - links = workflow.get_batch_aggregation_links() + mock_current_agg = MagicMock() + mock_current_agg.delete = MagicMock() + mock_find.return_value = mock_current_agg + + + mock_save_func = MagicMock() - self.batch_agg_get_aggregations_mock.assert_called_once() - self.assertEqual(links, None) + mock_save.return_value = mock_save_func() + self.instance.run_aggregation(self.mock_user_id,True) + + mock_current_agg.delete.assert_called_once() + + mock_save_func.assert_called_once() + + @patch.object(Workflow, 'get_batch_aggregations') + def test_get_agg_property(self, mock_get_batch_aggregations): + mock_aggregation = self._mock_aggregation() + mock_aggregation.test_property = 'returned_test_value' + + mock_aggregations = MagicMock() + mock_aggregations.next.return_value = mock_aggregation + mock_get_batch_aggregations.return_value = mock_aggregations + + + result = self.instance._get_agg_property('test_property') + + self.assertEqual(result, 'returned_test_value') + + @patch.object(Workflow, 'get_batch_aggregations') + def test_get_agg_property_failed(self, mock_get_batch_aggregations): + mock_aggregations = MagicMock() + mock_aggregations.next.side_effect = StopIteration + mock_get_batch_aggregations.return_value = mock_aggregations + + with self.assertRaises(PanoptesAPIException): + self.instance._get_agg_property('some_property') diff --git a/panoptes_client/workflow.py b/panoptes_client/workflow.py index bbccb4df..50a91b90 100644 --- a/panoptes_client/workflow.py +++ b/panoptes_client/workflow.py @@ -5,14 +5,14 @@ from panoptes_client.subject_workflow_status import SubjectWorkflowStatus from panoptes_client.exportable import Exportable -from panoptes_client.panoptes import PanoptesObject, LinkResolver +from panoptes_client.panoptes import PanoptesObject, LinkResolver, PanoptesAPIException from panoptes_client.subject import Subject from panoptes_client.subject_set import SubjectSet from panoptes_client.utils import batchable from panoptes_client.caesar import Caesar -from panoptes_client.batch_aggregation import BatchAggregation from panoptes_client.user import User +from panoptes_client.aggregation import Aggregation class Workflow(PanoptesObject, Exportable): @@ -532,14 +532,7 @@ def configure_for_alice(self): self.add_alice_reducers() self.add_alice_rules_and_effects() - def get_batch_aggregations(self): - """ - This method will fetch existing aggregations if any. It will return a dict with the existing aggregations. - """ - - return BatchAggregation().get_aggregations(self.id)[0] - - def run_batch_aggregation(self, user=None, delete_if_exists=False): + def run_aggregation(self, user=None, delete_if_exists=False): """ This method will start a new batch aggregation run, Will return a dict with the created aggregation if successful. @@ -548,8 +541,8 @@ def run_batch_aggregation(self, user=None, delete_if_exists=False): - Examples:: - Workflow(1234).run_batch_aggregation(1234) - Workflow(1234).run_batch_aggregation(user=1234, delete_if_exists=True) + Workflow(1234).run_aggregation(1234) + Workflow(1234).run_aggregation(user=1234, delete_if_exists=True) """ if(isinstance(user, User)): @@ -557,29 +550,53 @@ def run_batch_aggregation(self, user=None, delete_if_exists=False): elif (isinstance(user, (int, str,))): _user_id = user else: - raise TypeError + raise TypeError('Invalid user parameter') + + try: + workflow_aggs = self.get_batch_aggregations() + if workflow_aggs.object_count > 0: + agg_id = workflow_aggs.next().id + current_wf_agg = Aggregation.find(agg_id) + if delete_if_exists: + current_wf_agg.delete() + return self._create_agg(_user_id) + else: + return current_wf_agg + else: + return self._create_agg(_user_id) + except PanoptesAPIException as err: + raise err - payload = { - "aggregations": { - "links": { - "user": _user_id, - "workflow": self.id, - } - } - } - return BatchAggregation().run_aggregation(payload, delete_if_exists) + def get_batch_aggregations(self): + return Aggregation.where(workflow_id=self.id) + + def _create_agg(self, user_id): + new_agg = Aggregation() + new_agg.links.workflow = self.id + new_agg.links.user = user_id + new_agg.save() + return new_agg + + def _get_agg_property(self, param): + try: + aggs = self.get_batch_aggregations() + return getattr(aggs.next(), param, None) + except StopIteration: + raise PanoptesAPIException( + "Could not find Aggregations for Workflow with id='{}'".format(self.id) + ) def check_batch_aggregation_run_status(self): """ This method will fetch existing aggregation status if any. """ - return self.get_batch_aggregations()['aggregations'][0]['status'] + return self._get_agg_property('status') def get_batch_aggregation_links(self): """ This method will fetch existing aggregation links if any. """ - return self.get_batch_aggregations()['aggregations'][0]['uuid'] + return self._get_agg_property('uuid') @property def versions(self):