Skip to content

Commit

Permalink
[LT-3000] Add client tests for dataset version JSON exports (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiit-clarifai authored May 29, 2023
1 parent 7dc08bb commit 90aabe8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 27 deletions.
60 changes: 47 additions & 13 deletions tests/client/test_dataset_version_export.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
import json
import requests
import uuid
import zipfile
Expand Down Expand Up @@ -87,13 +88,22 @@ def test_export_dataset_version(channel):
resources_pb2.DatasetVersionExport(
format=resources_pb2.CLARIFAI_DATA_PROTOBUF,
),
resources_pb2.DatasetVersionExport(
format=resources_pb2.CLARIFAI_DATA_JSON,
),
],
),
metadata=metadata(),
)
raise_on_failure(put_dataset_version_exports_response)

wait_for_dataset_version_export_success(stub, metadata(), dataset_id, dataset_version_id)
wait_for_dataset_version_export_success(
stub,
metadata(),
dataset_id,
dataset_version_id,
["clarifai_data_protobuf", "clarifai_data_json"],
)

get_dataset_version_response = stub.GetDatasetVersion(
service_pb2.GetDatasetVersionRequest(
Expand All @@ -104,22 +114,29 @@ def test_export_dataset_version(channel):
)
raise_on_failure(get_dataset_version_response)

export = get_dataset_version_response.dataset_version.export_info.clarifai_data_protobuf
assert export.format == resources_pb2.CLARIFAI_DATA_PROTOBUF
assert export.size > 0
export_info = get_dataset_version_response.dataset_version.export_info

get_export_url_response = requests.get(export.url)
assert get_export_url_response.status_code == 200
def check_protobuf(batch_str):
input_batch = resources_pb2.InputBatch().FromString(batch_str)
assert len(input_batch.inputs) == len(input_ids)

with zipfile.ZipFile(io.BytesIO(get_export_url_response.content)) as zip_file:
assert zip_file.read("mimetype") == b"application/x.clarifai-data+protobuf"
_check_export(
export_info.clarifai_data_protobuf,
resources_pb2.CLARIFAI_DATA_PROTOBUF,
"application/x.clarifai-data+protobuf",
check_protobuf,
)

namelist = zip_file.namelist()
namelist.remove("mimetype")
assert len(namelist) == 1 # All inputs in a single batch.
def check_json(batch_str):
input_batch = json.loads(batch_str)
assert len(input_batch["inputs"]) == len(input_ids)

input_batch = resources_pb2.InputBatch().FromString(zip_file.read(namelist[0]))
assert len(input_batch.inputs) == len(input_ids)
_check_export(
export_info.clarifai_data_json,
resources_pb2.CLARIFAI_DATA_JSON,
"application/x.clarifai-data+json",
check_json,
)
finally:
if dataset_version_id:
delete_dataset_versions_response = stub.DeleteDatasetVersions(
Expand All @@ -146,3 +163,20 @@ def test_export_dataset_version(channel):
if input_ids:
raise_on_failure(delete_inputs_response)
raise_on_failure(delete_datasets_response)


def _check_export(export, expected_format, expected_mimetype, check_fn):
    """Validate a single dataset version export.

    Asserts the export metadata (format and non-zero size), downloads the
    archive from the export URL, verifies the zip's ``mimetype`` entry, and
    hands the single remaining batch payload to ``check_fn`` for
    format-specific validation.
    """
    assert export.format == expected_format
    assert export.size > 0

    download = requests.get(export.url)
    assert download.status_code == 200

    with zipfile.ZipFile(io.BytesIO(download.content)) as archive:
        assert archive.read("mimetype") == expected_mimetype.encode("ascii")

        entries = archive.namelist()
        entries.remove("mimetype")
        # Beyond the mimetype marker, every input lands in one batch file.
        assert len(entries) == 1

        check_fn(archive.read(entries[0]))
39 changes: 25 additions & 14 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import time
from typing import Tuple
from typing import List, Tuple

from grpc._channel import _Rendezvous

Expand Down Expand Up @@ -160,7 +160,9 @@ def wait_for_dataset_version_ready(stub, metadata, dataset_id, dataset_version_i
# At this point, the dataset version is ready.


def wait_for_dataset_version_export_success(stub, metadata, dataset_id, dataset_version_id):
def wait_for_dataset_version_export_success(
stub, metadata, dataset_id, dataset_version_id, export_info_fields: List[str]
):
while True:
response = stub.GetDatasetVersion(
service_pb2.GetDatasetVersionRequest(
Expand All @@ -170,19 +172,28 @@ def wait_for_dataset_version_export_success(stub, metadata, dataset_id, dataset_
metadata=metadata,
)
raise_on_failure(response)
export = response.dataset_version.export_info.clarifai_data_protobuf
if export.status.code == status_code_pb2.DATASET_VERSION_EXPORT_SUCCESS:
break
elif export.status.code in (
status_code_pb2.DATASET_VERSION_EXPORT_PENDING,
status_code_pb2.DATASET_VERSION_EXPORT_IN_PROGRESS,
):
time.sleep(1)

for field in export_info_fields:
if not response.dataset_version.export_info.HasField(field):
raise Exception(
f"Missing expected dataset version export info field '{field}'. Full response: {response}"
)
export = getattr(response.dataset_version.export_info, field)
if export.status.code == status_code_pb2.DATASET_VERSION_EXPORT_SUCCESS:
continue
elif export.status.code in (
status_code_pb2.DATASET_VERSION_EXPORT_PENDING,
status_code_pb2.DATASET_VERSION_EXPORT_IN_PROGRESS,
):
time.sleep(1)
break
else:
error_message = get_status_message(export.status)
raise Exception(
f"Expected dataset version to export, but got {error_message}. Full response: {response}"
)
else:
error_message = get_status_message(export.status)
raise Exception(
f"Expected dataset version to export, but got {error_message}. Full response: {response}"
)
break # break the while True
# At this point, the dataset version has successfully finished exporting.


Expand Down

0 comments on commit 90aabe8

Please sign in to comment.