Skip to content
This repository has been archived by the owner on May 5, 2023. It is now read-only.

Commit

Permalink
update test inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoine Hoorelbeke committed Jun 11, 2020
1 parent 3339517 commit efeaa12
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 185 deletions.
29 changes: 20 additions & 9 deletions src/mot/serving/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,16 @@ def process_video():
return images_folder

def process_zip():
images_folder = "{}_split".format(filename)
zipfile.ZipFile(full_filepath).extractall(images_folder)
images_folder = os.path.join(upload_folder, "{}_split".format(filename))
with zipfile.ZipFile(full_filepath, 'r') as zip_obj:
file_names = zip_obj.namelist()
for file_name in file_names:
zip_obj.extract(file_name, images_folder)
if os.path.basename(file_name):
shutil.move(
os.path.join(images_folder, file_name),
os.path.join(images_folder, os.path.basename(file_name))
)
return images_folder

if file.mimetype == "":
Expand All @@ -103,13 +111,15 @@ def process_zip():
logger.info("{} images to analyze on {} CPUs.".format(len(image_paths), CPU_COUNT))
try:
inference_outputs = []
with multiprocessing.Pool(CPU_COUNT) as p:
inference_outputs = list(
tqdm(
p.imap(_process_image, image_paths),
total=len(image_paths),
)
)
for image_path in image_paths:
inference_outputs.append(_process_image(image_path))
# with multiprocessing.Pool(CPU_COUNT) as p:
# inference_outputs = list(
# tqdm(
# p.imap(_process_image, image_paths),
# total=len(image_paths),
# )
# )
except ValueError as e:
return {"error": str(e)}
logger.info("Object detection on video {} finished.".format(full_filepath))
Expand Down Expand Up @@ -223,4 +233,5 @@ def predict_and_format_image(
"score": score,
}
detected_trash.append(trash_json)

return detected_trash
20 changes: 6 additions & 14 deletions tests/tests_mot/serving/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_app_post_tracking(mock_server_result, tmpdir):
zip_path += ".zip"

app_folder = os.path.join(tmpdir, "app_folder")
os.makedirs(os.path.join(tmpdir, "app_folder"))
os.makedirs(app_folder)
app.config["UPLOAD_FOLDER"] = app_folder
with app.test_client() as c:
response = c.post("/tracking", data={"file": (open(zip_path, "rb"), "toto.zip")})
Expand Down Expand Up @@ -111,19 +111,11 @@ def test_app_post_image(mock_server_result, tmpdir):
output = response.get_json()
assert response.status_code == 200
expected_output = {
"detected_trash":
[
{
"box": [0.0, 0.0, 0.1, 0.05],
"label": "bottles",
"score": 0.7
}, {
"box": [0.0, 0.0, 0.1, 0.1],
"label": "fragments",
"score": 0.6
}
],
"image_path": os.path.join(tmpdir, TMP_IMAGE_NAME)
"detected_trash": [{
"box": [0.0, 0.0, 0.1, 0.05],
"label": "bottles",
"score": 0.7
}]
}
assert output == expected_output

Expand Down
Loading

0 comments on commit efeaa12

Please sign in to comment.