diff --git a/neptune/experiments.py b/neptune/experiments.py index b265bf063..ebe05245f 100644 --- a/neptune/experiments.py +++ b/neptune/experiments.py @@ -168,11 +168,18 @@ def get_system_properties(self): def get_tags(self): return self._client.get_experiment(self._internal_id).tags - def append_tag(self, tag): + def append_tag(self, tag, *tags): + if isinstance(tag, list): + tags_list = tag + else: + tags_list = [tag] + list(tags) self._client.update_tags(experiment=self, - tags_to_add=[tag], + tags_to_add=tags_list, tags_to_delete=[]) + def append_tags(self, tag, *tags): + self.append_tag(tag, *tags) + def remove_tag(self, tag): self._client.update_tags(experiment=self, tags_to_add=[], diff --git a/tests/neptune/test_experiment.py b/tests/neptune/test_experiment.py index 7557a95d6..b1e37cf33 100644 --- a/tests/neptune/test_experiment.py +++ b/tests/neptune/test_experiment.py @@ -21,7 +21,7 @@ import mock import pandas as pd -from mock import MagicMock +from mock import call, MagicMock from munch import Munch from pandas.util.testing import assert_frame_equal @@ -101,6 +101,37 @@ def test_send_image(self, ChannelsValuesSender, content): # then channels_values_sender.send.assert_called_with('errors', ChannelType.IMAGE.value, channel_value) + def test_append_tags(self): + # given + client = mock.MagicMock() + experiment = Experiment(client, an_experiment_id(), a_uuid_string(), a_project_qualified_name()) + + #and + def build_call(tags_list): + return call( + experiment=experiment, + tags_to_add=tags_list, + tags_to_delete=[] + ) + + # when + experiment.append_tag('tag') + experiment.append_tag(['tag1', 'tag2', 'tag3']) + experiment.append_tag('tag1', 'tag2', 'tag3') + experiment.append_tags('tag') + experiment.append_tags(['tag1', 'tag2', 'tag3']) + experiment.append_tags('tag1', 'tag2', 'tag3') + + # then + client.update_tags.assert_has_calls([ + build_call(['tag']), + build_call(['tag1', 'tag2', 'tag3']), + build_call(['tag1', 'tag2', 'tag3']), + build_call(['tag']), + build_call(['tag1', 'tag2', 'tag3']), + build_call(['tag1', 'tag2', 'tag3']) + ]) + def test_get_numeric_channels_values(self): # when client = MagicMock()