diff --git a/tests/client/test_dataset_version_export.py b/tests/client/test_dataset_version_export.py
index 7424362a..55c5e96f 100644
--- a/tests/client/test_dataset_version_export.py
+++ b/tests/client/test_dataset_version_export.py
@@ -1,4 +1,5 @@
 import io
+import json
 import requests
 import uuid
 import zipfile
@@ -87,13 +88,22 @@ def test_export_dataset_version(channel):
                     resources_pb2.DatasetVersionExport(
                         format=resources_pb2.CLARIFAI_DATA_PROTOBUF,
                     ),
+                    resources_pb2.DatasetVersionExport(
+                        format=resources_pb2.CLARIFAI_DATA_JSON,
+                    ),
                 ],
             ),
             metadata=metadata(),
         )
         raise_on_failure(put_dataset_version_exports_response)
 
-        wait_for_dataset_version_export_success(stub, metadata(), dataset_id, dataset_version_id)
+        wait_for_dataset_version_export_success(
+            stub,
+            metadata(),
+            dataset_id,
+            dataset_version_id,
+            ["clarifai_data_protobuf", "clarifai_data_json"],
+        )
 
         get_dataset_version_response = stub.GetDatasetVersion(
             service_pb2.GetDatasetVersionRequest(
@@ -104,22 +114,29 @@ def test_export_dataset_version(channel):
         )
         raise_on_failure(get_dataset_version_response)
 
-        export = get_dataset_version_response.dataset_version.export_info.clarifai_data_protobuf
-        assert export.format == resources_pb2.CLARIFAI_DATA_PROTOBUF
-        assert export.size > 0
+        export_info = get_dataset_version_response.dataset_version.export_info
 
-        get_export_url_response = requests.get(export.url)
-        assert get_export_url_response.status_code == 200
+        def check_protobuf(batch_str):
+            input_batch = resources_pb2.InputBatch().FromString(batch_str)
+            assert len(input_batch.inputs) == len(input_ids)
 
-        with zipfile.ZipFile(io.BytesIO(get_export_url_response.content)) as zip_file:
-            assert zip_file.read("mimetype") == b"application/x.clarifai-data+protobuf"
+        _check_export(
+            export_info.clarifai_data_protobuf,
+            resources_pb2.CLARIFAI_DATA_PROTOBUF,
+            "application/x.clarifai-data+protobuf",
+            check_protobuf,
+        )
 
-            namelist = zip_file.namelist()
-            namelist.remove("mimetype")
-            assert len(namelist) == 1  # All inputs in a single batch.
+        def check_json(batch_str):
+            input_batch = json.loads(batch_str)
+            assert len(input_batch["inputs"]) == len(input_ids)
 
-            input_batch = resources_pb2.InputBatch().FromString(zip_file.read(namelist[0]))
-            assert len(input_batch.inputs) == len(input_ids)
+        _check_export(
+            export_info.clarifai_data_json,
+            resources_pb2.CLARIFAI_DATA_JSON,
+            "application/x.clarifai-data+json",
+            check_json,
+        )
     finally:
         if dataset_version_id:
             delete_dataset_versions_response = stub.DeleteDatasetVersions(
@@ -146,3 +163,20 @@ def test_export_dataset_version(channel):
         if input_ids:
            raise_on_failure(delete_inputs_response)
        raise_on_failure(delete_datasets_response)
+
+
+def _check_export(export, expected_format, expected_mimetype, check_fn):
+    assert export.format == expected_format
+    assert export.size > 0
+
+    get_export_url_response = requests.get(export.url)
+    assert get_export_url_response.status_code == 200
+
+    with zipfile.ZipFile(io.BytesIO(get_export_url_response.content)) as zip_file:
+        assert zip_file.read("mimetype") == expected_mimetype.encode("ascii")
+
+        namelist = zip_file.namelist()
+        namelist.remove("mimetype")
+        assert len(namelist) == 1  # All inputs in a single batch.
+
+        check_fn(zip_file.read(namelist[0]))
diff --git a/tests/common.py b/tests/common.py
index 6e960d03..60bc4647 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import Tuple
+from typing import List, Tuple
 
 from grpc._channel import _Rendezvous
 
@@ -160,7 +160,9 @@ def wait_for_dataset_version_ready(stub, metadata, dataset_id, dataset_version_i
     # At this point, the dataset version is ready.
 
 
-def wait_for_dataset_version_export_success(stub, metadata, dataset_id, dataset_version_id):
+def wait_for_dataset_version_export_success(
+    stub, metadata, dataset_id, dataset_version_id, export_info_fields: List[str]
+):
     while True:
         response = stub.GetDatasetVersion(
             service_pb2.GetDatasetVersionRequest(
@@ -170,19 +172,28 @@ def wait_for_dataset_version_export_success(stub, metadata, dataset_id, dataset_
             metadata=metadata,
         )
         raise_on_failure(response)
-        export = response.dataset_version.export_info.clarifai_data_protobuf
-        if export.status.code == status_code_pb2.DATASET_VERSION_EXPORT_SUCCESS:
-            break
-        elif export.status.code in (
-            status_code_pb2.DATASET_VERSION_EXPORT_PENDING,
-            status_code_pb2.DATASET_VERSION_EXPORT_IN_PROGRESS,
-        ):
-            time.sleep(1)
+
+        for field in export_info_fields:
+            if not response.dataset_version.export_info.HasField(field):
+                raise Exception(
+                    f"Missing expected dataset version export info field '{field}'. Full response: {response}"
+                )
+            export = getattr(response.dataset_version.export_info, field)
+            if export.status.code == status_code_pb2.DATASET_VERSION_EXPORT_SUCCESS:
+                continue
+            elif export.status.code in (
+                status_code_pb2.DATASET_VERSION_EXPORT_PENDING,
+                status_code_pb2.DATASET_VERSION_EXPORT_IN_PROGRESS,
+            ):
+                time.sleep(1)
+                break
+            else:
+                error_message = get_status_message(export.status)
+                raise Exception(
+                    f"Expected dataset version to export, but got {error_message}. Full response: {response}"
+                )
         else:
-            error_message = get_status_message(export.status)
-            raise Exception(
-                f"Expected dataset version to export, but got {error_message}. Full response: {response}"
-            )
+            break  # all exports succeeded; break out of the "while True" loop
 
     # At this point, the dataset version has successfully finished exporting.
 