diff --git a/cape/cape_main.py b/cape/cape_main.py index de96558..9100064 100644 --- a/cape/cape_main.py +++ b/cape/cape_main.py @@ -234,8 +234,11 @@ def start(self) -> None: # noinspection PyTypeChecker def execute(self, request: ServiceRequest) -> None: if not len(self.hosts): - raise CapeHostsUnavailable( - "All hosts are unavailable at the moment, as determined by a previous execution.") + # If all hosts were removed from a previous execution, reset hosts + for host in self.config["remote_host_details"]["hosts"]: + host["auth_header"] = {'Authorization': f"{self.config.get('token_key', DEFAULT_TOKEN_KEY)} {host['token']}"} + del host["token"] + self.hosts = self.config["remote_host_details"]["hosts"] self.request = request self.session = requests.Session() @@ -500,7 +503,6 @@ def submit(self, file_content: bytes, cape_task: CapeTask, parent_section: Resul "This is usually due to an issue on CAPE's machinery end." " Contact the CAPE administrator for details.") parent_section.add_subsection(task_timeout_sec) - cape_task.id = None raise AnalysisTimeoutExceeded() elif status == TASK_MISSING: err_msg = f"Task {cape_task.id} went missing while waiting for CAPE to analyze file." diff --git a/tests/test_cape_main.py b/tests/test_cape_main.py index 9f54f8b..de14ed9 100644 --- a/tests/test_cape_main.py +++ b/tests/test_cape_main.py @@ -506,7 +506,7 @@ def test_execute(sample, cape_class_instance, mocker): from assemblyline_v4_service.common.task import Task from assemblyline.odm.messages.task import Task as ServiceTask from assemblyline_v4_service.common.request import ServiceRequest - from cape.cape_main import CAPE, CapeHostsUnavailable + from cape.cape_main import CAPE mocker.patch('cape.cape_main.generate_random_words', return_value="blah") mocker.patch.object(CAPE, "_decode_mime_encoded_file_name", return_value=None) @@ -525,8 +525,11 @@ def test_execute(sample, cape_class_instance, mocker): # Coverage test mocker.patch.object(CAPE, "_assign_file_extension", return_value=None) - with pytest.raises(CapeHostsUnavailable): - cape_class_instance.execute(service_request) + + cape_class_instance.config["remote_host_details"]["hosts"] = [{"token": "blah"}] + cape_class_instance.execute(service_request) + assert cape_class_instance.hosts == [{"auth_header": {"Authorization": "Token blah"}}] + cape_class_instance.hosts = [{"ip": "1.1.1.1"}] cape_class_instance.execute(service_request) assert True @@ -683,7 +686,7 @@ def test_submit(task_id, poll_started_status, poll_report_status, cape_class_ins f"an issue on CAPE's machinery end. Contact the CAPE " f"administrator for details.") check_section_equality(parent_section.subsections[0], correct_sec) - assert cape_task.id is None + assert cape_task.id == 1 elif (poll_started_status == TASK_MISSING and poll_report_status is None) or (poll_started_status == TASK_STARTED and poll_report_status == TASK_MISSING): with pytest.raises(RecoverableError): cape_class_instance.submit(file_content, cape_task, parent_section)