diff --git a/.github/workflows/daily.yaml b/.github/workflows/daily.yaml index 0b428a39b2..3c65b0f4fa 100644 --- a/.github/workflows/daily.yaml +++ b/.github/workflows/daily.yaml @@ -88,6 +88,8 @@ jobs: code-eval-apikey: ${{ secrets.CODE_EVAL_APIKEY }} gcs-key: ${{ secrets.GCS_KEY }} gcs-secret: ${{ secrets.GCS_SECRET }} + azure-account-name: ${{ secrets.AZURE_ACCOUNT_NAME }} + azure-account-access-key: ${{ secrets.AZURE_ACCOUNT_ACCESS_KEY }} coverage: uses: ./.github/workflows/coverage.yaml name: Coverage Results diff --git a/.github/workflows/pytest-cpu.yaml b/.github/workflows/pytest-cpu.yaml index 3f237424ba..af95f8918f 100644 --- a/.github/workflows/pytest-cpu.yaml +++ b/.github/workflows/pytest-cpu.yaml @@ -45,6 +45,10 @@ on: required: false gcs-secret: required: false + azure-account-name: + required: false + azure-account-access-key: + required: false jobs: pytest-cpu: timeout-minutes: 30 @@ -75,6 +79,8 @@ jobs: export CODE_EVAL_APIKEY='${{ secrets.code-eval-apikey }}' export GCS_KEY='${{ secrets.gcs-key }}' export GCS_SECRET='${{ secrets.gcs-secret }}' + export AZURE_ACCOUNT_NAME='${{ secrets.azure-account-name }}' + export AZURE_ACCOUNT_ACCESS_KEY='${{ secrets.azure-account-access-key }}' export S3_BUCKET='${{ inputs.pytest-s3-bucket }}' export COMMON_ARGS="-v --durations=20 -m '${{ inputs.pytest-markers }}' --s3_bucket '$S3_BUCKET' \ -o tmp_path_retention_policy=none" diff --git a/tests/utils/object_store/test_azure_object_store.py b/tests/utils/object_store/test_azure_object_store.py new file mode 100644 index 0000000000..949e2149ff --- /dev/null +++ b/tests/utils/object_store/test_azure_object_store.py @@ -0,0 +1,33 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from torch.utils.data import DataLoader + +from composer.trainer import Trainer +from tests.common import RandomClassificationDataset, SimpleModel + + +@pytest.mark.remote +def test_azure_object_store_integration(): + model = SimpleModel() + train_dataloader = DataLoader(dataset=RandomClassificationDataset()) + trainer_save = Trainer( + model=model, + train_dataloader=train_dataloader, + save_folder='azure://mosaicml-composer-tests/checkpoints/{run_name}', + save_filename='test-model.pt', + max_duration='1ba', + ) + run_name = trainer_save.state.run_name + trainer_save.fit() + trainer_save.close() + + trainer_load = Trainer( + model=model, + train_dataloader=train_dataloader, + load_path=f'azure://mosaicml-composer-tests/checkpoints/{run_name}/test-model.pt', + max_duration='2ba', + ) + trainer_load.fit() + trainer_load.close() diff --git a/tests/utils/object_store/test_integration_gs_object_store.py b/tests/utils/object_store/test_integration_gs_object_store.py deleted file mode 100644 index 1a08bb73ce..0000000000 --- a/tests/utils/object_store/test_integration_gs_object_store.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -import time -from pathlib import Path - -import pytest - -from composer.utils import GCSObjectStore - -__DUMMY_OBJ__ = '/tmp/dummy.ckpt' -__NUM_BYTES__ = 1000 -bucket_name = 'mosaicml-composer-tests' - - -@pytest.mark.remote -@pytest.fixture -def gs_object_store(): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - remote_dir = 'gs://mosaicml-composer-tests/streaming/' - yield GCSObjectStore(remote_dir) - - -@pytest.mark.remote -def test_bucket_not_found(): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - with pytest.raises(FileNotFoundError): - _ = GCSObjectStore('gs://not_a_bucket/streaming') - - -@pytest.mark.remote -def test_get_uri(gs_object_store): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - object_name = 'test-object' - expected_uri = 'gs://mosaicml-composer-tests/streaming/test-object' - assert (gs_object_store.get_uri(object_name) == expected_uri) - - -@pytest.mark.remote -def test_get_key(gs_object_store): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - object_name = 'test-object' - expected_key = 'streaming/test-object' - assert (gs_object_store.get_key(object_name) == expected_key) - - -@pytest.mark.remote -@pytest.mark.parametrize('result', ['success', 'not found']) -def test_get_object_size(gs_object_store, result: str): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - fn = Path(__DUMMY_OBJ__) - with open(fn, 'wb') as fp: - fp.write(bytes('0' * __NUM_BYTES__, 'utf-8')) - gs_object_store.upload_object(fn) - - if result == 'success': - assert (gs_object_store.get_object_size(__DUMMY_OBJ__) == __NUM_BYTES__) - else: # not found - with pytest.raises(FileNotFoundError): - gs_object_store.get_object_size(__DUMMY_OBJ__ + f'time.ctime()') - - -@pytest.mark.remote -def test_upload_object(gs_object_store): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - from google.cloud.storage import Blob - destination_blob_name = '/tmp/dummy.ckpt2' - key = gs_object_store.get_key(destination_blob_name) - stats = Blob(bucket=gs_object_store.bucket, name=key).exists(gs_object_store.client) - if not stats: - gs_object_store.upload_object(__DUMMY_OBJ__, destination_blob_name) - - -@pytest.mark.remote -def test_list_objects(gs_object_store): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - from google.cloud.storage import Blob - destination_blob_name = '/tmp/dummy.ckpt2' - key = gs_object_store.get_key(destination_blob_name) - stats = Blob(bucket=gs_object_store.bucket, name=key).exists(gs_object_store.client) - if not stats: - gs_object_store.upload_object(__DUMMY_OBJ__, destination_blob_name) - objects = gs_object_store.list_objects() - assert (key in objects) - - -@pytest.mark.remote -@pytest.mark.parametrize('result', ['success', 'file_exists', 'obj_not_found']) -def test_download_object(gs_object_store, tmp_path, result: str): - pytest.skip('Run this test suite only after GCS service account is configured on CI node.') - fn = Path(__DUMMY_OBJ__) - with open(fn, 'wb') as fp: - fp.write(bytes('0' * __NUM_BYTES__, 'utf-8')) - gs_object_store.upload_object(fn) - - object_name = __DUMMY_OBJ__ - filename = './dummy.ckpt.download' - - if result == 'success': - gs_object_store.download_object(object_name, filename, overwrite=True) - - elif result == 'file_exists': - with pytest.raises(FileExistsError): - gs_object_store.download_object(object_name, __DUMMY_OBJ__) - else: # obj_not_found - with pytest.raises(FileNotFoundError): - gs_object_store.download_object(object_name + f'{time.ctime()}', filename, overwrite=True)