diff --git a/CHANGELOG.md b/CHANGELOG.md index 3332504..799c88b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Added + +- Unit tests for the `ecs.py` module. + ## [0.8.0] - 2022-08-17 ### Changed diff --git a/covalent_ecs_plugin/ecs.py b/covalent_ecs_plugin/ecs.py index 0a9c6af..4ef5244 100644 --- a/covalent_ecs_plugin/ecs.py +++ b/covalent_ecs_plugin/ecs.py @@ -148,12 +148,12 @@ def __init__( def _is_valid_subnet_id(self, subnet_id: str) -> bool: """Check if the subnet is valid.""" - return False if re.fullmatch(r"subnet-[0-9a-z]{8}", subnet_id) is None else True + return re.fullmatch(r"subnet-[0-9a-z]{8}", subnet_id) is not None def _is_valid_security_group(self, security_group: str) -> bool: """Check if the security group is valid.""" - return False if re.fullmatch(r"sg-[0-9a-z]{8}", security_group) is None else True + return re.fullmatch(r"sg-[0-9a-z]{8}", security_group) is not None def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict): pass diff --git a/tests/test_ecs.py b/tests/test_ecs.py index e879be1..9141c2d 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -54,6 +54,10 @@ def ecs_executor(mocker): def test_executor_init_default_values(mocker): """Test that the init values of the executor are set properly.""" mocker.patch("covalent_ecs_plugin.ecs.get_config", return_value="mock") + mocker.patch("covalent_ecs_plugin.ecs.ECSExecutor._is_valid_subnet_id", return_value=False) + mocker.patch( + "covalent_ecs_plugin.ecs.ECSExecutor._is_valid_security_group", return_value=False + ) ecse = ECSExecutor() assert ecse.credentials == "mock" assert ecse.profile == "mock" @@ -72,6 +76,29 @@ def test_executor_init_default_values(mocker): assert ecse.cache_dir == "mock" +def test_executor_init_validation(mocker): + """Test that subnet and security group id is validated.""" + mocker.patch("covalent_ecs_plugin.ecs.get_config", return_value="mock") + mocker.patch("covalent_ecs_plugin.ecs.ECSExecutor._is_valid_subnet_id", return_value=True) + mocker.patch( + "covalent_ecs_plugin.ecs.ECSExecutor._is_valid_security_group", return_value=False + ) + + +def test_is_valid_subnet_id(ecs_executor): + """Test the valid subnet checking method.""" + assert ecs_executor._is_valid_subnet_id("subnet-871545e1") is True + assert ecs_executor._is_valid_subnet_id("subnet-871545e") is False + assert ecs_executor._is_valid_subnet_id("jlkjlkj871545e1") is False + + +def test_is_valid_security_group(ecs_executor): + """Test the valid security group checking method.""" + assert ecs_executor._is_valid_security_group("sg-0043541a") is True + assert ecs_executor._is_valid_security_group("sg-0043541") is False + assert ecs_executor._is_valid_security_group("80980043541") is False + + def test_get_aws_account(ecs_executor, mocker): """Test the method to retrieve the aws account.""" mm = MagicMock() @@ -81,6 +108,11 @@ def test_get_aws_account(ecs_executor, mocker): mm.client().get_caller_identity.get.called_once_with("Account") +def test_execute(mocker): + """Test the execute method.""" + pass + + def test_format_exec_script(ecs_executor): """Test method that constructs the executable tasks-execution Python script.""" kwargs = { @@ -168,3 +200,45 @@ def test_package_and_upload(ecs_executor, mocker): format_exec_script_mock.assert_called_once() format_dockerfile_mock.assert_called_once() get_ecr_info_mock.assert_called_once() + + +def test_get_status(mocker, ecs_executor): + """Test the status checking method.""" + ecs_mock = MagicMock() + ecs_mock.get_paginator().paginate.return_value = [] # Case 1: no tasks found + res = ecs_executor.get_status(ecs_mock, "") + assert res == ("TASK_NOT_FOUND", -1) + + ecs_mock.get_paginator().paginate.return_value = [ + {"taskArns": ["mock_task_arn"]} + ] # Case 2 valid task found + ecs_mock.describe_tasks.return_value = { + "tasks": [ + {"taskArn": "mock_task_arn", "lastStatus": "RUNNING", "containers": [{"exitCode": 1}]} + ] + } + res = ecs_executor.get_status(ecs_mock, "mock_task_arn") + assert res == ("RUNNING", 1) + + ecs_mock.get_paginator().paginate.return_value = [ + {"taskArns": ["mock_task_arn"]} + ] # Case 3 - task found without any status + ecs_mock.describe_tasks.return_value = { + "tasks": [{"taskArn": "mock_task_arn", "lastStatus": "FAILED"}] + } + res = ecs_executor.get_status(ecs_mock, "mock_task_arn") + assert res == ("FAILED", -1) + + +def test_poll_ecs_task(mocker, ecs_executor): + """Test the method to poll the ecs task.""" + + +def test_query_result(mocker): + """Test the method to query the result.""" + pass + + +def test_cancel(mocker): + """Test the execution cancellation method.""" + pass