Skip to content

Commit

Permalink
implemet batch aggregation on the client (#318)
Browse files Browse the repository at this point in the history
* implemet batch aggregation calls on workflow. With tests

* hound fixes

* fix failing test

* refactor logic to use PanoptesObject and add tests

* clean hound sniffs

* remove set_attr usage

* remove unanted mock result

* remove declared next variable
  • Loading branch information
Tooyosi authored Sep 20, 2024
1 parent b5439dc commit 1962b46
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 4 deletions.
14 changes: 14 additions & 0 deletions panoptes_client/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from panoptes_client.panoptes import PanoptesObject


class Aggregation(PanoptesObject):
_api_slug = 'aggregations'
_link_slug = 'aggregations'
_edit_attributes = (
{
'links': (
'workflow',
'user',
)
},
)
66 changes: 64 additions & 2 deletions panoptes_client/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from panoptes_client.panoptes import PanoptesAPIException
from panoptes_client.workflow import Workflow
from panoptes_client.caesar import Caesar
from panoptes_client.aggregation import Aggregation

if sys.version_info <= (3, 0):
from mock import patch
from mock import patch, MagicMock
else:
from unittest.mock import patch
from unittest.mock import patch, MagicMock


class TestWorkflow(unittest.TestCase):
Expand Down Expand Up @@ -208,3 +209,64 @@ def test_add_caesar_rule_effect_invalid_effect(self):

self.caesar_post_mock.assert_not_called()
self.assertEqual('Invalid action for rule type', str(invalid_effect_err.exception))


class TestAggregation(unittest.TestCase):
def setUp(self):
self.instance = Workflow(1)
self.mock_user_id = 1

def _mock_aggregation(self):
mock_aggregation = MagicMock()
mock_aggregation.object_count = 1
mock_aggregation.next = MagicMock(return_value=MagicMock(id=1))
return mock_aggregation

@patch.object(Aggregation, 'where')
@patch.object(Aggregation, 'find')
def test_run_aggregation_with_user_object(self, mock_find, mock_where):
mock_where.return_value = self._mock_aggregation()

mock_current_agg = MagicMock()
mock_find.return_value = mock_current_agg

result = self.instance.run_aggregation(self.mock_user_id, False)

self.assertEqual(result, mock_current_agg)

@patch.object(Aggregation, 'find')
@patch.object(Aggregation, 'where')
@patch.object(Aggregation, 'save')
def test_run_aggregation_with_delete_if_true(self, mock_save, mock_where, mock_find):
mock_where.return_value = self._mock_aggregation()

mock_current_agg = MagicMock()
mock_current_agg.delete = MagicMock()
mock_find.return_value = mock_current_agg

mock_save_func = MagicMock()

mock_save.return_value = mock_save_func()
self.instance.run_aggregation(self.mock_user_id, True)

mock_current_agg.delete.assert_called_once()

mock_save_func.assert_called_once()

@patch.object(Workflow, 'get_batch_aggregations')
def test_get_agg_property(self, mock_get_batch_aggregations):
mock_aggregation = MagicMock()
mock_aggregation.test_property = 'returned_test_value'

mock_get_batch_aggregations.return_value = iter([mock_aggregation])

result = self.instance._get_agg_property('test_property')

self.assertEqual(result, 'returned_test_value')

@patch.object(Workflow, 'get_batch_aggregations')
def test_get_agg_property_failed(self, mock_get_batch_aggregations):
mock_get_batch_aggregations.return_value = iter([])

with self.assertRaises(PanoptesAPIException):
self.instance._get_agg_property('test_property')
72 changes: 70 additions & 2 deletions panoptes_client/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from panoptes_client.subject_workflow_status import SubjectWorkflowStatus

from panoptes_client.exportable import Exportable
from panoptes_client.panoptes import PanoptesObject, LinkResolver
from panoptes_client.panoptes import PanoptesObject, LinkResolver, PanoptesAPIException
from panoptes_client.subject import Subject
from panoptes_client.subject_set import SubjectSet
from panoptes_client.utils import batchable

from panoptes_client.caesar import Caesar

from panoptes_client.user import User
from panoptes_client.aggregation import Aggregation
import six

class Workflow(PanoptesObject, Exportable):
_api_slug = 'workflows'
Expand Down Expand Up @@ -530,6 +532,72 @@ def configure_for_alice(self):
self.add_alice_reducers()
self.add_alice_rules_and_effects()

def run_aggregation(self, user=None, delete_if_exists=False):
"""
This method will start a new batch aggregation run, Will return a dict with the created aggregation if successful.
- **user** can be either a :py:class:`.User` or an ID.
- **delete_if_exists** parameter is optional if true, deletes any previous instance
-
Examples::
Workflow(1234).run_aggregation(1234)
Workflow(1234).run_aggregation(user=1234, delete_if_exists=True)
"""

if(isinstance(user, User)):
_user_id = user.id
elif (isinstance(user, (int, str,))):
_user_id = user
else:
raise TypeError('Invalid user parameter')

try:
workflow_aggs = self.get_batch_aggregations()
if workflow_aggs.object_count > 0:
agg_id = workflow_aggs.next().id
current_wf_agg = Aggregation.find(agg_id)
if delete_if_exists:
current_wf_agg.delete()
return self._create_agg(_user_id)
else:
return current_wf_agg
else:
return self._create_agg(_user_id)
except PanoptesAPIException as err:
raise err

def get_batch_aggregations(self):
return Aggregation.where(workflow_id=self.id)

def _create_agg(self, user_id):
new_agg = Aggregation()
new_agg.links.workflow = self.id
new_agg.links.user = user_id
new_agg.save()
return new_agg

def _get_agg_property(self, param):
try:
aggs = self.get_batch_aggregations()
return getattr(six.next(aggs), param, None)
except StopIteration:
raise PanoptesAPIException(
"Could not find Aggregations for Workflow with id='{}'".format(self.id)
)

def check_batch_aggregation_run_status(self):
"""
This method will fetch existing aggregation status if any.
"""
return self._get_agg_property('status')

def get_batch_aggregation_links(self):
"""
This method will fetch existing aggregation links if any.
"""
return self._get_agg_property('uuid')

@property
def versions(self):
"""
Expand Down

0 comments on commit 1962b46

Please sign in to comment.