Skip to content

Commit

Permalink
add client.predict and documentation
Browse files (browse the repository at this point in the history)
  • Loading branch information
abrichr committed Oct 30, 2024
1 parent 9cce7d7 commit 54b8b47
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 33 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# - Git LFS for large file support.
# - Required libraries: OpenCV, Hugging Face, Gradio, OpenGL.
# - Gradio server on port 7861.

#
# 1. Build the image with CUDA support.
# Example:
# ```bash
# sudo nvidia-docker build -t omniparser .
# ```

#
# 2. Run the Docker container with GPU access and port mapping for Gradio.
# Example:
# ```bash
Expand Down
76 changes: 45 additions & 31 deletions client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module provides a command-line interface to interact with the OmniParser Gradio server.
This module provides a command-line interface and programmatic API to interact with the OmniParser Gradio server.
Usage:
Command-line usage:
python client.py "http://<server_ip>:7861" "path/to/image.jpg"
View results:
Expand All @@ -11,6 +11,10 @@
Windows: start output_image_<timestamp>.png
Linux: xdg-open output_image_<timestamp>.png
Programmatic usage:
from omniparse.client import predict
result = predict("http://<server_ip>:7861", "path/to/image.jpg")
Result data format:
{
"label_coordinates": {
Expand All @@ -33,30 +37,31 @@
import fire
from gradio_client import Client
from loguru import logger
from PIL import Image
import base64
from io import BytesIO
import os
import shutil
import json
from datetime import datetime

def predict(server_url: str, image_path: str, box_threshold: float = 0.05, iou_threshold: float = 0.1):
# Default threshold values shared by predict() and predict_and_save(); they are
# used as the box_threshold / iou_threshold arguments when a caller does not
# supply its own, and are forwarded unchanged to the server's "/process" API.
DEFAULT_BOX_THRESHOLD = 0.05
DEFAULT_IOU_THRESHOLD = 0.1

def predict(server_url: str, image_path: str, box_threshold: float = DEFAULT_BOX_THRESHOLD, iou_threshold: float = DEFAULT_IOU_THRESHOLD):
"""
Makes a prediction using the OmniParser Gradio client with the provided server URL and image.
Args:
server_url (str): The URL of the OmniParser Gradio server.
image_path (str): Path to the image file to be processed.
box_threshold (float): Box threshold value (default: 0.05).
iou_threshold (float): IOU threshold value (default: 0.1).
Returns:
dict: Parsed result data containing label coordinates and parsed content list.
"""
client = Client(server_url)

# Generate a timestamp for unique file naming
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Load and encode the image
image_path = os.path.expanduser(image_path)
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

Expand All @@ -72,47 +77,56 @@ def predict(server_url: str, image_path: str, box_threshold: float = 0.05, iou_t
}

# Make the prediction
try:
result = client.predict(
image_input, # image input as dictionary
box_threshold, # box_threshold
iou_threshold, # iou_threshold
api_name="/process"
)
result = client.predict(
image_input,
box_threshold,
iou_threshold,
api_name="/process"
)

# Process and log the results
output_image, result_json = result

logger.info("Prediction completed successfully")
# Process and return the result
output_image, result_json = result
result_data = json.loads(result_json)

# Parse the JSON string into a Python object
result_data = json.loads(result_json)
return {"output_image": output_image, "result_data": result_data}

# Extract label_coordinates and parsed_content_list
label_coordinates = result_data['label_coordinates']
parsed_content_list = result_data['parsed_content_list']

logger.info(f"{label_coordinates=}")
logger.info(f"{parsed_content_list=}")
def predict_and_save(server_url: str, image_path: str, box_threshold: float = DEFAULT_BOX_THRESHOLD, iou_threshold: float = DEFAULT_IOU_THRESHOLD):
    """
    Run a prediction and write its outputs to timestamped files.

    The parsed result data is dumped to ``result_data_<timestamp>.json`` and the
    output image path returned by ``predict`` is copied to
    ``output_image_<timestamp>.png``. Both paths are relative, so the files land
    in the current working directory. Any exception raised while predicting or
    saving is logged instead of being propagated to the caller.

    Args:
        server_url (str): The URL of the OmniParser Gradio server.
        image_path (str): Path to the image file to be processed.
        box_threshold (float): Box threshold value (default: 0.05).
        iou_threshold (float): IOU threshold value (default: 0.1).
    """
    # A single timestamp is shared by both output file names.
    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    try:
        # Delegate the actual server round-trip to predict().
        prediction = predict(server_url, image_path, box_threshold, iou_threshold)
        annotated_image = prediction["output_image"]
        parsed = prediction["result_data"]

        # Persist the parsed result data as indented JSON.
        json_target = f"result_data_{run_stamp}.json"
        with open(json_target, "w") as handle:
            json.dump(parsed, handle, indent=4)
        logger.info(f"Parsed content saved to: {json_target}")

        # The image is only copied when predict() handed back a path to an
        # existing file; anything else is reported and skipped.
        image_target = f"output_image_{run_stamp}.png"
        image_unusable = not (isinstance(annotated_image, str) and os.path.exists(annotated_image))
        if image_unusable:
            logger.warning(f"Unexpected output_image format or file not found: {annotated_image}")
        else:
            shutil.copy(annotated_image, image_target)
            logger.info(f"Output image saved to: {image_target}")

    except Exception as e:
        # Top-level CLI boundary: report the failure and its traceback, then return.
        logger.error(f"An error occurred: {str(e)}")
        logger.exception("Traceback:")

if __name__ == "__main__":
fire.Fire(predict)

if __name__ == "__main__":
fire.Fire(predict_and_save)

0 comments on commit 54b8b47

Please sign in to comment.