Skip to content

Commit

Permalink
refactor logic to use PanoptesObject and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tooyosi committed Sep 10, 2024
1 parent 5f596f0 commit 51ac389
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 181 deletions.
1 change: 0 additions & 1 deletion panoptes_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@
from panoptes_client.subject_workflow_status import SubjectWorkflowStatus
from panoptes_client.caesar import Caesar
from panoptes_client.inaturalist import Inaturalist
from panoptes_client.batch_aggregation import BatchAggregation
13 changes: 13 additions & 0 deletions panoptes_client/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from panoptes_client.panoptes import PanoptesObject

class Aggregation(PanoptesObject):
_api_slug = 'aggregations'
_link_slug = 'aggregations'
_edit_attributes = (
{
'links': (
'workflow',
'user',
)
},
)
44 changes: 0 additions & 44 deletions panoptes_client/batch_aggregation.py

This file was deleted.

72 changes: 0 additions & 72 deletions panoptes_client/tests/test_batch_aggregation.py

This file was deleted.

100 changes: 60 additions & 40 deletions panoptes_client/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from panoptes_client.panoptes import PanoptesAPIException
from panoptes_client.workflow import Workflow
from panoptes_client.caesar import Caesar
from panoptes_client.batch_aggregation import BatchAggregation
from panoptes_client.aggregation import Aggregation

if sys.version_info <= (3, 0):
from mock import patch
from mock import patch, MagicMock
else:
from unittest.mock import patch
from unittest.mock import patch, MagicMock


class TestWorkflow(unittest.TestCase):
Expand All @@ -25,14 +25,6 @@ def setUp(self):
self.addCleanup(caesar_get_patch.stop)
self.addCleanup(caesar_put_patch.stop)

batch_agg_run_aggregation_patch = patch.object(BatchAggregation, 'run_aggregation')
batch_agg_get_aggregations_patch = patch.object(BatchAggregation, 'get_aggregations')
self.batch_agg_run_aggregation_mock = batch_agg_run_aggregation_patch.start()
self.batch_agg_get_aggregations_mock = batch_agg_get_aggregations_patch.start()

self.addCleanup(batch_agg_run_aggregation_patch.stop)
self.addCleanup(batch_agg_get_aggregations_patch.stop)

self.agg_mock_value = [{
'aggregations': [{
'id': '1',
Expand Down Expand Up @@ -231,40 +223,68 @@ def test_add_caesar_rule_effect_invalid_effect(self):
self.caesar_post_mock.assert_not_called()
self.assertEqual('Invalid action for rule type', str(invalid_effect_err.exception))

def test_get_batch_aggregations_without_delete_param(self):
workflow = Workflow(1)
workflow.run_batch_aggregation(user=1)
payload = {
"aggregations": {
"links": {
"user": 1,
"workflow": workflow.id,
}
}
}
self.batch_agg_run_aggregation_mock.assert_called_with(payload, False)
class TestAggregation(unittest.TestCase):
def setUp(self):
self.instance = Workflow(1)
self.mock_user_id = 1

def test_get_batch_aggregations_with_invalid_params(self):
with self.assertRaises(TypeError):
workflow = Workflow(1)
workflow.run_batch_aggregation(user=None)
def _mock_aggregation(self):
mock_aggregation = MagicMock()
mock_aggregation.object_count = 1
mock_aggregation.next = MagicMock(return_value=MagicMock(id=1))
return mock_aggregation

self.batch_agg_run_aggregation_mock.assert_not_called()
@patch.object(Aggregation, 'where')
@patch.object(Aggregation, 'find')
def test_run_aggregation_with_user_object(self, mock_find, mock_where):
mock_where.return_value = self._mock_aggregation()

def test_check_batch_aggregation_run_status(self):
self.batch_agg_get_aggregations_mock.return_value = self.agg_mock_value
mock_current_agg = MagicMock()
mock_find.return_value = mock_current_agg

workflow = Workflow(1)
status = workflow.check_batch_aggregation_run_status()
result = self.instance.run_aggregation(self.mock_user_id, False)

self.batch_agg_get_aggregations_mock.assert_called_once()
self.assertEqual(status, 'pending')
self.assertEqual(result, mock_current_agg)

def test_get_batch_aggregation_links(self):
self.batch_agg_get_aggregations_mock.return_value = self.agg_mock_value
@patch.object(Aggregation, 'find')
@patch.object(Aggregation, 'where')
@patch.object(Aggregation, 'save')
def test_run_aggregation_with_delete_if_true(self, mock_save, mock_where, mock_find):
mock_where.return_value = self._mock_aggregation()

workflow = Workflow(1)
links = workflow.get_batch_aggregation_links()
mock_current_agg = MagicMock()
mock_current_agg.delete = MagicMock()
mock_find.return_value = mock_current_agg


mock_save_func = MagicMock()

self.batch_agg_get_aggregations_mock.assert_called_once()
self.assertEqual(links, None)
mock_save.return_value = mock_save_func()
self.instance.run_aggregation(self.mock_user_id,True)

mock_current_agg.delete.assert_called_once()

mock_save_func.assert_called_once()

@patch.object(Workflow, 'get_batch_aggregations')
def test_get_agg_property(self, mock_get_batch_aggregations):
mock_aggregation = self._mock_aggregation()
mock_aggregation.test_property = 'returned_test_value'

mock_aggregations = MagicMock()
mock_aggregations.next.return_value = mock_aggregation
mock_get_batch_aggregations.return_value = mock_aggregations


result = self.instance._get_agg_property('test_property')

self.assertEqual(result, 'returned_test_value')

@patch.object(Workflow, 'get_batch_aggregations')
def test_get_agg_property_failed(self, mock_get_batch_aggregations):
mock_aggregations = MagicMock()
mock_aggregations.next.side_effect = StopIteration
mock_get_batch_aggregations.return_value = mock_aggregations

with self.assertRaises(PanoptesAPIException):
self.instance._get_agg_property('some_property')
65 changes: 41 additions & 24 deletions panoptes_client/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
from panoptes_client.subject_workflow_status import SubjectWorkflowStatus

from panoptes_client.exportable import Exportable
from panoptes_client.panoptes import PanoptesObject, LinkResolver
from panoptes_client.panoptes import PanoptesObject, LinkResolver, PanoptesAPIException
from panoptes_client.subject import Subject
from panoptes_client.subject_set import SubjectSet
from panoptes_client.utils import batchable

from panoptes_client.caesar import Caesar
from panoptes_client.batch_aggregation import BatchAggregation
from panoptes_client.user import User
from panoptes_client.aggregation import Aggregation


class Workflow(PanoptesObject, Exportable):
Expand Down Expand Up @@ -532,14 +532,7 @@ def configure_for_alice(self):
self.add_alice_reducers()
self.add_alice_rules_and_effects()

def get_batch_aggregations(self):
"""
This method will fetch existing aggregations if any. It will return a dict with the existing aggregations.
"""

return BatchAggregation().get_aggregations(self.id)[0]

def run_batch_aggregation(self, user=None, delete_if_exists=False):
def run_aggregation(self, user=None, delete_if_exists=False):
"""
This method will start a new batch aggregation run, Will return a dict with the created aggregation if successful.
Expand All @@ -548,38 +541,62 @@ def run_batch_aggregation(self, user=None, delete_if_exists=False):
-
Examples::
Workflow(1234).run_batch_aggregation(1234)
Workflow(1234).run_batch_aggregation(user=1234, delete_if_exists=True)
Workflow(1234).run_aggregation(1234)
Workflow(1234).run_aggregation(user=1234, delete_if_exists=True)
"""

if(isinstance(user, User)):
_user_id = user.id
elif (isinstance(user, (int, str,))):
_user_id = user
else:
raise TypeError
raise TypeError('Invalid user parameter')

try:
workflow_aggs = self.get_batch_aggregations()
if workflow_aggs.object_count > 0:
agg_id = workflow_aggs.next().id
current_wf_agg = Aggregation.find(agg_id)
if delete_if_exists:
current_wf_agg.delete()
return self._create_agg(_user_id)
else:
return current_wf_agg
else:
return self._create_agg(_user_id)
except PanoptesAPIException as err:
raise err

payload = {
"aggregations": {
"links": {
"user": _user_id,
"workflow": self.id,
}
}
}
return BatchAggregation().run_aggregation(payload, delete_if_exists)
def get_batch_aggregations(self):
return Aggregation.where(workflow_id=self.id)

def _create_agg(self, user_id):
new_agg = Aggregation()
new_agg.links.workflow = self.id
new_agg.links.user = user_id
new_agg.save()
return new_agg

def _get_agg_property(self, param):
try:
aggs = self.get_batch_aggregations()
return getattr(aggs.next(), param, None)
except StopIteration:
raise PanoptesAPIException(
"Could not find Aggregations for Workflow with id='{}'".format(self.id)
)

def check_batch_aggregation_run_status(self):
"""
This method will fetch existing aggregation status if any.
"""
return self.get_batch_aggregations()['aggregations'][0]['status']
return self._get_agg_property('status')

def get_batch_aggregation_links(self):
"""
This method will fetch existing aggregation links if any.
"""
return self.get_batch_aggregations()['aggregations'][0]['uuid']
return self._get_agg_property('uuid')

@property
def versions(self):
Expand Down

0 comments on commit 51ac389

Please sign in to comment.