diff --git a/.gitignore b/.gitignore index bb58b8d..8ec0c33 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,10 @@ weights/icon_detect/ weights/icon_detect_v1_5/ weights/icon_detect_v1_5_2/ .gradio +*.swp +.env +.env.* +venv/ +*.pem __pycache__/ -debug.ipynb +debug.ipynb \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6aa8a3f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,85 @@ +# Dockerfile for OmniParser with GPU and OpenGL support. +# +# Base: nvidia/cuda:12.3.1-devel-ubuntu22.04 +# Features: +# - Python 3.12 with Miniconda environment. +# - Git LFS for large file support. +# - Required libraries: OpenCV, Hugging Face, Gradio, OpenGL. +# - Gradio server on port 7861. +# +# 1. Build the image with CUDA support. +# ``` +# sudo docker build -t omniparser . +# ``` +# +# 2. Run the Docker container with GPU access and port mapping for Gradio. +# ```bash +# sudo docker run -d -p 7861:7861 --gpus all --name omniparser-container omniparser +# ``` +# +# Author: Richard Abrich (richard@openadapt.ai) + +FROM nvidia/cuda:12.3.1-devel-ubuntu22.04 + +# Install system dependencies with explicit OpenGL libraries +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + git-lfs \ + wget \ + libgl1 \ + libglib2.0-0 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && git lfs install + +# Install Miniconda for Python 3.12 +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p /opt/conda && \ + rm miniconda.sh +ENV PATH="/opt/conda/bin:$PATH" + +# Create and activate Conda environment with Python 3.12, and set it as the default +RUN conda create -n omni python=3.12 && \ + echo "source activate omni" > ~/.bashrc +ENV CONDA_DEFAULT_ENV=omni +ENV PATH="/opt/conda/envs/omni/bin:$PATH" + +# Set the working directory in the container +WORKDIR /usr/src/app + +# Copy project files and requirements +COPY . . +COPY requirements.txt /usr/src/app/requirements.txt + +# Initialize Git LFS and pull LFS files +RUN git lfs install && \ + git lfs pull + +# Install dependencies from requirements.txt with specific opencv-python-headless version +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + pip uninstall -y opencv-python opencv-python-headless && \ + pip install --no-cache-dir opencv-python-headless==4.8.1.78 && \ + pip install -r requirements.txt && \ + pip install huggingface_hub + +# Run download.py to fetch model weights and convert safetensors to .pt format +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + python download.py && \ + echo "Contents of weights directory:" && \ + ls -lR weights && \ + python weights/convert_safetensor_to_pt.py + +# Expose the default Gradio port +EXPOSE 7861 + +# Configure Gradio to be accessible externally +ENV GRADIO_SERVER_NAME="0.0.0.0" + +# Copy and set permissions for entrypoint script +COPY entrypoint.sh /usr/src/app/entrypoint.sh +RUN chmod +x /usr/src/app/entrypoint.sh + +# To debug, keep the container running +# CMD ["tail", "-f", "/dev/null"] + +# Set the entrypoint +ENTRYPOINT ["/usr/src/app/entrypoint.sh"] diff --git a/README.md b/README.md index 4c4c7be..2819440 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,19 @@ - [2024/10] Both Interactive Region Detection Model and Icon functional description model are released! 
[Hugginface models](https://huggingface.co/microsoft/OmniParser) - [2024/09] OmniParser achieves the best performance on [Windows Agent Arena](https://microsoft.github.io/WindowsAgentArena/)! +### :rocket: Docker Quick Start + +Prerequisites: +- CUDA-enabled GPU +- NVIDIA Container Toolkit installed (https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) +``` +# Build the image (requires CUDA) +sudo docker build -t omniparser . + +# Run the image +sudo docker run -d -p 7861:7861 --gpus all --name omniparser-container omniparser +``` + ## Install Install environment: ```python @@ -26,10 +39,17 @@ conda activate omni pip install -r requirements.txt ``` -Then download the model ckpts files in: https://huggingface.co/microsoft/OmniParser, and put them under weights/, default folder structure is: weights/icon_detect, weights/icon_caption_florence, weights/icon_caption_blip2. +Download and convert the model ckpt files from https://huggingface.co/microsoft/OmniParser: +```python +python download.py +``` + +Or, download the model ckpts files in: https://huggingface.co/microsoft/OmniParser, and put them under weights/, default folder structure is: weights/icon_detect, weights/icon_caption_florence, weights/icon_caption_blip2. +Finally, convert the safetensor to .pt file. For v1: convert the safetensor to .pt file. + ```python python weights/convert_safetensor_to_pt.py @@ -49,6 +69,13 @@ python gradio_demo.py --icon_detect_model weights/icon_detect/best.pt --icon_cap python gradio_demo.py --icon_detect_model weights/icon_detect_v1_5/model_v1_5.pt --icon_caption_model florence2 ``` +## Deploy to AWS + +To deploy OmniParser to EC2 on AWS via Github Actions: + +1. Fork this repository and clone your fork to your local machine. +2. Follow the instructions at the top of [`deploy.py`](https://github.com/microsoft/OmniParser/blob/main/deploy.py). + ## Model Weights License For the model checkpoints on huggingface model hub, please note that icon_detect model is under AGPL license since it is a license inherited from the original yolo model. And icon_caption_blip2 & icon_caption_florence is under MIT license. Please refer to the LICENSE file in the folder of each model: https://huggingface.co/microsoft/OmniParser. diff --git a/client.py b/client.py new file mode 100644 index 0000000..234dcb9 --- /dev/null +++ b/client.py @@ -0,0 +1,132 @@ +""" +This module provides a command-line interface and programmatic API to interact with the OmniParser Gradio server. + +Command-line usage: + python client.py "http://:7861" "path/to/image.jpg" + +View results: + JSON: cat result_data_.json + Image: + macOS: open output_image_.png + Windows: start output_image_.png + Linux: xdg-open output_image_.png + +Programmatic usage: + from omniparse.client import predict + result = predict("http://:7861", "path/to/image.jpg") + +Result data format: + { + "label_coordinates": { + "0": [x1, y1, width, height], // Normalized coordinates for each bounding box + "1": [x1, y1, width, height], + ... + }, + "parsed_content_list": [ + "Text Box ID 0: [content]", + "Text Box ID 1: [content]", + ..., + "Icon Box ID X: [description]", + ... + ] + } + +Note: The parsed_content_list includes both text box contents and icon descriptions. 
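+
+Example of unpacking the returned dict (an illustrative sketch; the server URL
+and image path are placeholders, and it assumes the `predict` call from
+"Programmatic usage" above succeeded):
+
+    result = predict("http://<server_ip>:7861", "path/to/image.jpg")
+    coords = result["result_data"]["label_coordinates"]
+    for box_id, (x, y, w, h) in coords.items():
+        print(f"box {box_id}: x={x}, y={y}, w={w}, h={h}")
+    for item in result["result_data"]["parsed_content_list"]:
+        print(item)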
+""" + +import fire +from gradio_client import Client +from loguru import logger +import base64 +import os +import shutil +import json +from datetime import datetime + +# Define constants for default thresholds +DEFAULT_BOX_THRESHOLD = 0.05 +DEFAULT_IOU_THRESHOLD = 0.1 + +def predict(server_url: str, image_path: str, box_threshold: float = DEFAULT_BOX_THRESHOLD, iou_threshold: float = DEFAULT_IOU_THRESHOLD): + """ + Makes a prediction using the OmniParser Gradio client with the provided server URL and image. + Args: + server_url (str): The URL of the OmniParser Gradio server. + image_path (str): Path to the image file to be processed. + box_threshold (float): Box threshold value (default: 0.05). + iou_threshold (float): IOU threshold value (default: 0.1). + Returns: + dict: Parsed result data containing label coordinates and parsed content list. + """ + client = Client(server_url) + + # Load and encode the image + image_path = os.path.expanduser(image_path) + with open(image_path, "rb") as image_file: + encoded_image = base64.b64encode(image_file.read()).decode("utf-8") + + # Prepare the image input in the format expected by the server + image_input = { + "path": None, + "url": f"data:image/png;base64,{encoded_image}", + "size": None, + "orig_name": image_path, + "mime_type": "image/png", + "is_stream": False, + "meta": {} + } + + # Make the prediction + result = client.predict( + image_input, + box_threshold, + iou_threshold, + api_name="/process" + ) + + # Process and return the result + output_image, result_json = result + result_data = json.loads(result_json) + + return {"output_image": output_image, "result_data": result_data} + + +def predict_and_save(server_url: str, image_path: str, box_threshold: float = DEFAULT_BOX_THRESHOLD, iou_threshold: float = DEFAULT_IOU_THRESHOLD): + """ + Makes a prediction and saves the results to files, including logs and image outputs. + Args: + server_url (str): The URL of the OmniParser Gradio server. + image_path (str): Path to the image file to be processed. + box_threshold (float): Box threshold value (default: 0.05). + iou_threshold (float): IOU threshold value (default: 0.1). + """ + # Generate a timestamp for unique file naming + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Call the predict function to get prediction data + try: + result = predict(server_url, image_path, box_threshold, iou_threshold) + output_image = result["output_image"] + result_data = result["result_data"] + + # Save result data to JSON file + result_data_path = f"result_data_{timestamp}.json" + with open(result_data_path, "w") as json_file: + json.dump(result_data, json_file, indent=4) + logger.info(f"Parsed content saved to: {result_data_path}") + + # Save the output image + output_image_path = f"output_image_{timestamp}.png" + if isinstance(output_image, str) and os.path.exists(output_image): + shutil.copy(output_image, output_image_path) + logger.info(f"Output image saved to: {output_image_path}") + else: + logger.warning(f"Unexpected output_image format or file not found: {output_image}") + + except Exception as e: + logger.error(f"An error occurred: {str(e)}") + logger.exception("Traceback:") + + +if __name__ == "__main__": + fire.Fire(predict_and_save) diff --git a/deploy.py b/deploy.py new file mode 100644 index 0000000..92d7cc7 --- /dev/null +++ b/deploy.py @@ -0,0 +1,818 @@ +"""Deploy OmniParse to AWS EC2 via Github action. + +Usage: + + 1. Create and populate the .env file: + + cat > .env < +AWS_SECRET_ACCESS_KEY= +AWS_REGION= +GITHUB_OWNER= # e.g. 
microsoft +GITHUB_REPO= # e.g. OmniParse +GITHUB_TOKEN= +PROJECT_NAME= # for tagging AWS resources +EOF + + 2. Create a virtual environment for deployment: + + python3.10 -m venv venv + source venv/bin/activate + pip install -r deploy_requirements.txt + + 3. Run the deployment script: + + python deploy.py start + + You may see the following error: + + botocore.exceptions.ClientError: An error occurred (OptInRequired) when + calling the RunInstances operation: In order to use this AWS + Marketplace product you need to accept terms and subscribe. To do so + please visit https://aws.amazon.com/marketplace/pp?sku=64g24n0wem7a8nuhfum3097vb + + Open the specified URL in the browser and accept the terms, then try again. + + 4. Wait for the build to succeed in Github actions (see console output for URL) + + 5. Open the gradio interface (see console output for URL) and test it out. + Note that it may take a minute for the interface to become available. + You can also interact with the server programmatically: + + python client.py "http://:7861" + + 6. Terminate the EC2 instance and stop incurring charges: + + python deploy.py stop + + Or, to shut it down without removing it: + + python deploy.py pause + + (This can later be re-started with the `start` command.) + + 7. (optional) List all tagged instances with their respective statuses: + + python deploy.py status + + 8. (optional) SSH into server: + + python deploy.py ssh + + Example commands: + - View containers: `docker ps -a` + - Tail logs: `docker logs -f ` + - Enter container shell: `docker exec -it /bin/bash` + +Troubleshooting Token Scope Error: + + If you encounter an error similar to the following when pushing changes to + GitHub Actions workflow files: + + ! [remote rejected] feat/docker -> feat/docker (refusing to allow a + Personal Access Token to create or update workflow + `.github/workflows/docker-build-ec2.yml` without `workflow` scope) + + This indicates that the Personal Access Token (PAT) being used does not + have the necessary permissions ('workflow' scope) to create or update GitHub + Actions workflows. To resolve this issue, you will need to create or update + your PAT with the appropriate scope. + + Creating or Updating a Classic PAT with 'workflow' Scope: + + 1. Go to GitHub and sign in to your account. + 2. Click on your profile picture in the top right corner, and then click 'Settings'. + 3. In the sidebar, click 'Developer settings'. + 4. Click 'Personal access tokens', then 'Classic tokens'. + 5. To update an existing token: + a. Find the token you wish to update in the list and click on it. + b. Scroll down to the 'Select scopes' section. + c. Make sure the 'workflow' scope is checked. This scope allows for + managing GitHub Actions workflows. + d. Click 'Update token' at the bottom of the page. + 6. To create a new token: + a. Click 'Generate new token'. + b. Give your token a descriptive name under 'Note'. + c. Scroll down to the 'Select scopes' section. + d. Check the 'workflow' scope to allow managing GitHub Actions workflows. + e. Optionally, select any other scopes needed for your project. + f. Click 'Generate token' at the bottom of the page. + 7. Copy the generated token. Make sure to save it securely, as you will not + be able to see it again. + + After creating or updating your PAT with the 'workflow' scope, update the + Git remote configuration to use the new token, and try pushing your changes + again. + + Note: Always keep your tokens secure and never share them publicly. 
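+
+    For example, one way to point the 'origin' remote at the updated token
+    (owner, repo, and token below are placeholders, not real values):
+
+        git remote set-url origin https://<github_owner>:<new_token>@github.com/<github_owner>/<repo>.git
+
+    (The `python deploy.py start` command described above performs an
+    equivalent update automatically, using the values from .env.)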
+ +""" + +import base64 +import json +import os +import subprocess +import time + +from botocore.exceptions import ClientError +from jinja2 import Environment, FileSystemLoader +from loguru import logger +from nacl import encoding, public +from pydantic_settings import BaseSettings +import boto3 +import fire +import git +import paramiko +import requests + +class Config(BaseSettings): + AWS_ACCESS_KEY_ID: str + AWS_SECRET_ACCESS_KEY: str + AWS_REGION: str + GITHUB_OWNER: str + GITHUB_REPO: str + GITHUB_TOKEN: str + PROJECT_NAME: str + + AWS_EC2_AMI: str = "ami-06835d15c4de57810" + AWS_EC2_DISK_SIZE: int = 128 # GB + #AWS_EC2_INSTANCE_TYPE: str = "p3.2xlarge" # (V100 16GB $3.06/hr x86_64) + AWS_EC2_INSTANCE_TYPE: str = "g4dn.xlarge" # (T4 16GB $0.526/hr x86_64) + AWS_EC2_USER: str = "ubuntu" + + # Note: changing this requires changing the hard-coded value in other files + PORT: int = 7861 + + class Config: + env_file = ".env" + env_file_encoding = 'utf-8' + + @property + def AWS_EC2_KEY_NAME(self) -> str: + return f"{self.PROJECT_NAME}-key" + + @property + def AWS_EC2_KEY_PATH(self) -> str: + return f"./{self.AWS_EC2_KEY_NAME}.pem" + + @property + def AWS_EC2_SECURITY_GROUP(self) -> str: + return f"{self.PROJECT_NAME}-SecurityGroup" + + @property + def AWS_SSM_ROLE_NAME(self) -> str: + return f"{self.PROJECT_NAME}-SSMRole" + + @property + def AWS_SSM_PROFILE_NAME(self) -> str: + return f"{self.PROJECT_NAME}-SSMInstanceProfile" + + @property + def GITHUB_PATH(self) -> str: + return f"{self.GITHUB_OWNER}/{self.GITHUB_REPO}" + +config = Config() + +def encrypt(public_key: str, secret_value: str) -> str: + """ + Encrypts a Unicode string using the provided public key. + + Args: + public_key (str): The public key for encryption, encoded in Base64. + secret_value (str): The Unicode string to be encrypted. + + Returns: + str: The encrypted value, encoded in Base64. + """ + public_key = public.PublicKey(public_key.encode("utf-8"), encoding.Base64Encoder()) + sealed_box = public.SealedBox(public_key) + encrypted = sealed_box.encrypt(secret_value.encode("utf-8")) + return base64.b64encode(encrypted).decode("utf-8") + +def set_github_secret(token: str, repo: str, secret_name: str, secret_value: str) -> None: + """ + Sets a secret in the specified GitHub repository. + + Args: + token (str): GitHub token with permissions to set secrets. + repo (str): Repository path in the format "owner/repo". + secret_name (str): The name of the secret to set. + secret_value (str): The value of the secret. + + Returns: + None + """ + secret_value = secret_value or "" + headers = { + "Authorization": f"token {token}", + "Accept": "application/vnd.github.v3+json" + } + response = requests.get(f"https://api.github.com/repos/{repo}/actions/secrets/public-key", headers=headers) + response.raise_for_status() + key = response.json()['key'] + key_id = response.json()['key_id'] + encrypted_value = encrypt(key, secret_value) + secret_url = f"https://api.github.com/repos/{repo}/actions/secrets/{secret_name}" + data = {"encrypted_value": encrypted_value, "key_id": key_id} + response = requests.put(secret_url, headers=headers, json=data) + response.raise_for_status() + logger.info(f"set {secret_name=}") + +def set_github_secrets() -> None: + """ + Sets required AWS credentials and SSH private key as GitHub Secrets. 
+ + Returns: + None + """ + # Set AWS secrets + set_github_secret(config.GITHUB_TOKEN, config.GITHUB_PATH, 'AWS_ACCESS_KEY_ID', config.AWS_ACCESS_KEY_ID) + set_github_secret(config.GITHUB_TOKEN, config.GITHUB_PATH, 'AWS_SECRET_ACCESS_KEY', config.AWS_SECRET_ACCESS_KEY) + + # Read the SSH private key from the file + try: + with open(config.AWS_EC2_KEY_PATH, 'r') as key_file: + ssh_private_key = key_file.read() + set_github_secret(config.GITHUB_TOKEN, config.GITHUB_PATH, 'SSH_PRIVATE_KEY', ssh_private_key) + except IOError as e: + logger.error(f"Error reading SSH private key file: {e}") + +def create_key_pair(key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH) -> str | None: + """ + Creates a new EC2 key pair and saves it to a file. + + Args: + key_name (str): The name of the key pair to create. Defaults to config.AWS_EC2_KEY_NAME. + key_path (str): The path where the key file should be saved. Defaults to config.AWS_EC2_KEY_PATH. + + Returns: + str | None: The name of the created key pair or None if an error occurred. + """ + ec2_client = boto3.client('ec2', region_name=config.AWS_REGION) + try: + key_pair = ec2_client.create_key_pair(KeyName=key_name) + private_key = key_pair['KeyMaterial'] + + # Save the private key to a file + with open(key_path, "w") as key_file: + key_file.write(private_key) + os.chmod(key_path, 0o400) # Set read-only permissions + + logger.info(f"Key pair {key_name} created and saved to {key_path}") + return key_name + except ClientError as e: + logger.error(f"Error creating key pair: {e}") + return None + +def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: + """ + Retrieves or creates a security group with the specified ports opened. + + Args: + ports (list[int]): A list of ports to open in the security group. Defaults to [22, 7861]. + + Returns: + str | None: The ID of the security group, or None if an error occurred. 
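+
+    Example (illustrative; the returned group ID is hypothetical):
+        >>> get_or_create_security_group_id([22, 7861])  # SSH plus the Gradio port
+        'sg-0123456789abcdef0'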
+ """ + ec2 = boto3.client('ec2', region_name=config.AWS_REGION) + + # Construct ip_permissions list + ip_permissions = [{ + 'IpProtocol': 'tcp', + 'FromPort': port, + 'ToPort': port, + 'IpRanges': [{'CidrIp': '0.0.0.0/0'}] + } for port in ports] + + try: + response = ec2.describe_security_groups(GroupNames=[config.AWS_EC2_SECURITY_GROUP]) + security_group_id = response['SecurityGroups'][0]['GroupId'] + logger.info(f"Security group '{config.AWS_EC2_SECURITY_GROUP}' already exists: {security_group_id}") + + for ip_permission in ip_permissions: + try: + ec2.authorize_security_group_ingress( + GroupId=security_group_id, + IpPermissions=[ip_permission] + ) + logger.info(f"Added inbound rule to allow TCP traffic on port {ip_permission['FromPort']} from any IP") + except ClientError as e: + if e.response['Error']['Code'] == 'InvalidPermission.Duplicate': + logger.info(f"Rule for port {ip_permission['FromPort']} already exists") + else: + logger.error(f"Error adding rule for port {ip_permission['FromPort']}: {e}") + + return security_group_id + except ClientError as e: + if e.response['Error']['Code'] == 'InvalidGroup.NotFound': + try: + # Create the security group + response = ec2.create_security_group( + GroupName=config.AWS_EC2_SECURITY_GROUP, + Description='Security group for specified port access', + TagSpecifications=[ + { + 'ResourceType': 'security-group', + 'Tags': [{'Key': 'Name', 'Value': config.PROJECT_NAME}] + } + ] + ) + security_group_id = response['GroupId'] + logger.info(f"Created security group '{config.AWS_EC2_SECURITY_GROUP}' with ID: {security_group_id}") + + # Add rules for the given ports + ec2.authorize_security_group_ingress(GroupId=security_group_id, IpPermissions=ip_permissions) + logger.info(f"Added inbound rules to allow access on {ports=}") + + return security_group_id + except ClientError as e: + logger.error(f"Error creating security group: {e}") + return None + else: + logger.error(f"Error describing security groups: {e}") + return None + +def deploy_ec2_instance( + ami: str = config.AWS_EC2_AMI, + instance_type: str = config.AWS_EC2_INSTANCE_TYPE, + project_name: str = config.PROJECT_NAME, + key_name: str = config.AWS_EC2_KEY_NAME, + disk_size: int = config.AWS_EC2_DISK_SIZE, +) -> tuple[str | None, str | None]: + """ + Deploys an EC2 instance with the specified parameters. + + Args: + ami (str): The Amazon Machine Image ID to use for the instance. Defaults to config.AWS_EC2_AMI. + instance_type (str): The type of instance to deploy. Defaults to config.AWS_EC2_INSTANCE_TYPE. + project_name (str): The project name, used for tagging the instance. Defaults to config.PROJECT_NAME. + key_name (str): The name of the key pair to use for the instance. Defaults to config.AWS_EC2_KEY_NAME. + disk_size (int): The size of the disk in GB. Defaults to config.AWS_EC2_DISK_SIZE. + + Returns: + tuple[str | None, str | None]: A tuple containing the instance ID and IP address, or None, None if deployment fails. + """ + ec2 = boto3.resource('ec2') + ec2_client = boto3.client('ec2') + + # Check if key pair exists, if not create one + try: + ec2_client.describe_key_pairs(KeyNames=[key_name]) + except ClientError as e: + create_key_pair(key_name) + + # Fetch the security group ID + security_group_id = get_or_create_security_group_id() + if not security_group_id: + logger.error("Unable to retrieve security group ID. 
Instance deployment aborted.")
+        return None, None
+
+    # Check for existing instances
+    instances = ec2.instances.filter(
+        Filters=[
+            {'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]},
+            {'Name': 'instance-state-name', 'Values': ['running', 'pending', 'stopped']}
+        ]
+    )
+
+    for instance in instances:
+        if instance.state['Name'] == 'running':
+            logger.info(f"Instance already running: ID - {instance.id}, IP - {instance.public_ip_address}")
+            return instance.id, instance.public_ip_address
+        elif instance.state['Name'] == 'stopped':
+            logger.info(f"Starting existing stopped instance: ID - {instance.id}")
+            ec2_client.start_instances(InstanceIds=[instance.id])
+            instance.wait_until_running()
+            instance.reload()
+            logger.info(f"Instance started: ID - {instance.id}, IP - {instance.public_ip_address}")
+            return instance.id, instance.public_ip_address
+        elif instance.state['Name'] == 'pending':
+            logger.info(f"Instance is pending: ID - {instance.id}. Waiting for 'running' state.")
+            try:
+                instance.wait_until_running()  # Wait for the instance to reach the 'running' state
+                instance.reload()  # Reload the instance attributes
+                logger.info(f"Instance is now running: ID - {instance.id}, IP - {instance.public_ip_address}")
+                return instance.id, instance.public_ip_address
+            except Exception as e:  # covers botocore's WaiterError without requiring an extra import
+                logger.error(f"Error waiting for instance to run: {e}")
+                return None, None
+
+    # Define EBS volume configuration
+    ebs_config = {
+        'DeviceName': '/dev/sda1',  # You may need to change this depending on the instance type and AMI
+        'Ebs': {
+            'VolumeSize': disk_size,
+            'VolumeType': 'gp3',  # Or other volume types like gp2, io1, etc.
+            'DeleteOnTermination': True  # Set to False if you want to keep the volume after instance termination
+        },
+    }
+
+    # Create a new instance if none exist
+    new_instance = ec2.create_instances(
+        ImageId=ami,
+        MinCount=1,
+        MaxCount=1,
+        InstanceType=instance_type,
+        KeyName=key_name,
+        SecurityGroupIds=[security_group_id],
+        BlockDeviceMappings=[ebs_config],
+        TagSpecifications=[
+            {
+                'ResourceType': 'instance',
+                'Tags': [{'Key': 'Name', 'Value': project_name}]
+            },
+        ]
+    )[0]
+
+    new_instance.wait_until_running()
+    new_instance.reload()
+    logger.info(f"New instance created: ID - {new_instance.id}, IP - {new_instance.public_ip_address}")
+    return new_instance.id, new_instance.public_ip_address
+
+def configure_ec2_instance(
+    instance_id: str | None = None,
+    instance_ip: str | None = None,
+    max_ssh_retries: int = 20,
+    ssh_retry_delay: int = 20,
+    max_cmd_retries: int = 20,
+    cmd_retry_delay: int = 30,
+) -> tuple[str | None, str | None]:
+    """
+    Configures the specified EC2 instance for Docker builds.
+
+    Args:
+        instance_id (str | None): The ID of the instance to configure. If None, a new instance will be deployed. Defaults to None.
+        instance_ip (str | None): The IP address of the instance. Must be provided if instance_id is manually passed. Defaults to None.
+        max_ssh_retries (int): Maximum number of SSH connection retries. Defaults to 20.
+        ssh_retry_delay (int): Delay between SSH connection retries in seconds. Defaults to 20.
+        max_cmd_retries (int): Maximum number of command execution retries. Defaults to 20.
+        cmd_retry_delay (int): Delay between command execution retries in seconds. Defaults to 30.
+
+    Returns:
+        tuple[str | None, str | None]: A tuple containing the instance ID and IP address, or None, None if configuration fails.
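+
+    Example (illustrative; the returned ID and IP are hypothetical):
+        >>> configure_ec2_instance()  # deploy or reuse an instance, then install Docker over SSH
+        ('i-0123456789abcdef0', '203.0.113.10')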
+ """ + if not instance_id: + ec2_instance_id, ec2_instance_ip = deploy_ec2_instance() + else: + ec2_instance_id = instance_id + ec2_instance_ip = instance_ip # Ensure instance IP is provided if instance_id is manually passed + + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + ssh_retries = 0 + while ssh_retries < max_ssh_retries: + try: + ssh_client.connect(hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key) + break # Successful SSH connection, break out of the loop + except Exception as e: + ssh_retries += 1 + logger.error(f"SSH connection attempt {ssh_retries} failed: {e}") + if ssh_retries < max_ssh_retries: + logger.info(f"Retrying SSH connection in {ssh_retry_delay} seconds...") + time.sleep(ssh_retry_delay) + else: + logger.error("Maximum SSH connection attempts reached. Aborting.") + return + + # Setup commands to configure Docker, clean up unused resources, and handle name conflicts + commands = [ + "sudo apt-get update", + "sudo apt-get install -y docker.io", + "sudo systemctl start docker", + "sudo systemctl enable docker", + "sudo usermod -a -G docker ${USER}", + # Install Docker Compose + "sudo curl -L \"https://github.com/docker/compose/releases/download/1.29.2/docker-compose-$(uname -s)-$(uname -m)\" -o /usr/local/bin/docker-compose", + "sudo chmod +x /usr/local/bin/docker-compose", + "sudo ln -s /usr/local/bin/docker-compose /usr/bin/docker-compose", + # Docker cleanup and conflict resolution commands + "sudo docker system prune -af --volumes", # Clean up unused Docker resources + "sudo docker rm -f omniparser-container || true", # Remove conflicting container if it exists + ] + + + for command in commands: + logger.info(f"Executing command: {command}") + cmd_retries = 0 + while cmd_retries < max_cmd_retries: + stdin, stdout, stderr = ssh_client.exec_command(command) + exit_status = stdout.channel.recv_exit_status() # Blocking call + + if exit_status == 0: + logger.info(f"Command executed successfully") + break + else: + error_message = stderr.read() + if "Could not get lock" in str(error_message): + cmd_retries += 1 + logger.warning(f"dpkg is locked, retrying command in {cmd_retry_delay} seconds... Attempt {cmd_retries}/{max_cmd_retries}") + time.sleep(cmd_retry_delay) + else: + logger.error(f"Error in command: {command}, Exit Status: {exit_status}, Error: {error_message}") + break # Non-dpkg lock error, break out of the loop + + ssh_client.close() + return ec2_instance_id, ec2_instance_ip + +def generate_github_actions_workflow() -> None: + """ + Generates and writes the GitHub Actions workflow file for Docker build on EC2. 
+ + Returns: + None + """ + current_branch = get_current_git_branch() + + _, host = deploy_ec2_instance() + + # Set up Jinja2 environment + env = Environment(loader=FileSystemLoader('.')) + template = env.get_template('docker-build-ec2.yml.j2') + + # Render the template with the current branch + rendered_workflow = template.render( + branch_name=current_branch, + host=host, + username=config.AWS_EC2_USER, + project_name=config.PROJECT_NAME, + github_path=config.GITHUB_PATH, + github_repo=config.GITHUB_REPO, + ) + + # Write the rendered workflow to a file + workflows_dir = '.github/workflows' + os.makedirs(workflows_dir, exist_ok=True) + with open(os.path.join(workflows_dir, 'docker-build-ec2.yml'), 'w') as file: + file.write("# Autogenerated via deploy.py, do not edit!\n\n") + file.write(rendered_workflow) + logger.info("GitHub Actions EC2 workflow file generated successfully.") + +def get_current_git_branch() -> str: + """ + Retrieves the current active git branch name. + + Returns: + str: The name of the current git branch. + """ + repo = git.Repo(search_parent_directories=True) + branch = repo.active_branch.name + return branch + +def get_github_actions_url() -> str: + """ + Get the GitHub Actions URL for the user's repository. + + Returns: + str: The Github Actions URL + """ + url = f"https://github.com/{config.GITHUB_OWNER}/{config.GITHUB_REPO}/actions" + return url + +def get_gradio_server_url(ip_address: str) -> str: + """ + Get the Gradio server URL using the provided IP address. + + Args: + ip_address (str): The IP address of the EC2 instance running the Gradio server. + + Returns: + str: The Gradio server URL + """ + url = f"http://{ip_address}:{config.PORT}" + return url + +def git_push_set_upstream(branch_name: str) -> None: + """ + Pushes the current branch to the remote 'origin' and sets it to track the upstream branch. + If the push fails due to the branch being behind, pulls changes and retries. + + Args: + branch_name (str): The name of the current branch to push. + """ + try: + # Push the current branch and set the remote 'origin' as upstream + subprocess.run(["git", "push", "--set-upstream", "origin", branch_name], check=True) + logger.info(f"Branch '{branch_name}' pushed and set up to track 'origin/{branch_name}'.") + except subprocess.CalledProcessError as e: + if "non-fast-forward" in str(e): + logger.info("Branch is behind 'origin'. Attempting to pull and re-push...") + try: + subprocess.run(["git", "pull", "--rebase", "origin", branch_name], check=True) + subprocess.run(["git", "push", "--set-upstream", "origin", branch_name], check=True) + logger.info(f"Branch '{branch_name}' pushed after rebase.") + except subprocess.CalledProcessError as pull_push_error: + logger.error(f"Failed to push branch '{branch_name}' to 'origin' after pull: {pull_push_error}") + else: + logger.error(f"Failed to push branch '{branch_name}' to 'origin': {e}") + +def update_git_remote_with_pat(github_owner: str, repo_name: str, pat: str): + """ + Updates the git remote 'origin' to include the Personal Access Token in the URL. + + Args: + github_owner (str): GitHub repository owner. + repo_name (str): GitHub repository name. + pat (str): Personal Access Token with the necessary scopes. 
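+
+    Example (illustrative; all three values are placeholders):
+        >>> update_git_remote_with_pat("<github_owner>", "OmniParser", "<token>")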
+ + """ + new_origin_url = f"https://{github_owner}:{pat}@github.com/{github_owner}/{repo_name}.git" + try: + # Remove the existing 'origin' remote + subprocess.run(["git", "remote", "remove", "origin"], check=True) + # Add the new 'origin' with the PAT in the URL + subprocess.run(["git", "remote", "add", "origin", new_origin_url], check=True) + logger.info("Git remote 'origin' updated successfully.") + except subprocess.CalledProcessError as e: + logger.error(f"Failed to update git remote 'origin': {e}") + +class Deploy: + + @staticmethod + def start() -> None: + """ + Main method to execute the deployment process. + + Returns: + None + """ + set_github_secrets() + instance_id, instance_ip = configure_ec2_instance() + assert instance_ip, f"invalid {instance_ip=}" + generate_github_actions_workflow() + + # Update the Git remote configuration to include the PAT + update_git_remote_with_pat( + config.GITHUB_OWNER, config.GITHUB_REPO, config.GITHUB_TOKEN, + ) + + # Use the `ssh` method to connect and execute instance setup commands + Deploy.ssh(non_interactive=True) + + # Add, commit, and push the workflow file changes, setting the upstream branch + try: + # Stage the workflow file + subprocess.run(["git", "add", ".github/workflows/docker-build-ec2.yml"], check=True) + + # Check if there are any staged changes to commit + result = subprocess.run(["git", "diff", "--cached", "--exit-code"]) + if result.returncode != 0: + # Proceed with the commit only if there are changes + subprocess.run(["git", "commit", "-m", "add workflow file"], check=True) + logger.info("Changes committed successfully.") + else: + logger.info("No changes to commit.") + + current_branch = get_current_git_branch() + git_push_set_upstream(current_branch) + + except subprocess.CalledProcessError as e: + logger.error(f"Failed to commit or push changes: {e}") + + github_actions_url = get_github_actions_url() + gradio_server_url = get_gradio_server_url(instance_ip) + logger.info("Deployment process completed.") + logger.info(f"Check the GitHub Actions at {github_actions_url}.") + logger.info("Once the action is complete, run:") + logger.info(f" python client.py {gradio_server_url} //") + + @staticmethod + def pause(project_name: str = config.PROJECT_NAME) -> None: + """ + Shuts down the EC2 instance associated with the specified project name. + + Args: + project_name (str): The project name used to tag the instance. Defaults to config.PROJECT_NAME. + + Returns: + None + """ + ec2 = boto3.resource('ec2') + + instances = ec2.instances.filter( + Filters=[ + {'Name': 'tag:Name', 'Values': [project_name]}, + {'Name': 'instance-state-name', 'Values': ['running']} + ] + ) + + for instance in instances: + logger.info(f"Shutting down instance: ID - {instance.id}") + instance.stop() + + @staticmethod + def stop( + project_name: str = config.PROJECT_NAME, + security_group_name: str = config.AWS_EC2_SECURITY_GROUP, + ) -> None: + """ + Terminates the EC2 instance and deletes the associated security group. + + Args: + project_name (str): The project name used to tag the instance. Defaults to config.PROJECT_NAME. + security_group_name (str): The name of the security group to delete. Defaults to config.AWS_EC2_SECURITY_GROUP. 
+ + Returns: + None + """ + ec2_resource = boto3.resource('ec2') + ec2_client = boto3.client('ec2') + + # Terminate EC2 instances + instances = ec2_resource.instances.filter( + Filters=[ + {'Name': 'tag:Name', 'Values': [project_name]}, + {'Name': 'instance-state-name', 'Values': ['pending', 'running', 'shutting-down', 'stopped', 'stopping']} + ] + ) + + for instance in instances: + logger.info(f"Terminating instance: ID - {instance.id}") + instance.terminate() + instance.wait_until_terminated() + logger.info(f"Instance {instance.id} terminated successfully.") + + # Delete security group + try: + ec2_client.delete_security_group(GroupName=security_group_name) + logger.info(f"Deleted security group: {security_group_name}") + except ClientError as e: + if e.response['Error']['Code'] == 'InvalidGroup.NotFound': + logger.info(f"Security group {security_group_name} does not exist or already deleted.") + else: + logger.error(f"Error deleting security group: {e}") + + @staticmethod + def status() -> None: + """ + Lists all EC2 instances tagged with the project name, along with their HTTP URLs. + + Returns: + None + """ + ec2 = boto3.resource('ec2') + instances = ec2.instances.filter( + Filters=[{'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}] + ) + + for instance in instances: + public_ip = instance.public_ip_address + if public_ip: + http_url = f"http://{public_ip}:{config.PORT}" + logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, HTTP URL: {http_url}") + else: + logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, HTTP URL: Not available (no public IP)") + + @staticmethod + def ssh(project_name: str = config.PROJECT_NAME, non_interactive: bool = False) -> None: + """ + Establishes an SSH connection to the EC2 instance associated with the specified project name. + + Args: + project_name (str): The project name used to tag the instance. Defaults to config.PROJECT_NAME. + non_interactive (bool): If True, ensures a full interactive login simulation. Defaults to False. 
+ + Returns: + None + """ + ec2 = boto3.resource('ec2') + instances = ec2.instances.filter( + Filters=[ + {'Name': 'tag:Name', 'Values': [project_name]}, + {'Name': 'instance-state-name', 'Values': ['running']} + ] + ) + + for instance in instances: + logger.info(f"Attempting to SSH into instance: ID - {instance.id}, IP - {instance.public_ip_address}") + + if non_interactive: + # Simulate full login by forcing all initialization scripts + ssh_command = [ + "ssh", + "-o", "StrictHostKeyChecking=no", # Automatically accept new host keys + "-o", "UserKnownHostsFile=/dev/null", # Prevent writing to known_hosts + "-i", config.AWS_EC2_KEY_PATH, + f"{config.AWS_EC2_USER}@{instance.public_ip_address}", + "-t", # Allocate a pseudo-terminal + "-tt", # Force pseudo-terminal allocation + "bash --login -c 'exit'" # Force a full login shell and exit immediately + ] + else: + # Standard interactive SSH session + ssh_command = [ + "ssh", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-i", config.AWS_EC2_KEY_PATH, + f"{config.AWS_EC2_USER}@{instance.public_ip_address}" + ] + + # Execute the SSH command + try: + subprocess.run(ssh_command, check=True) + except subprocess.CalledProcessError as e: + logger.error(f"SSH connection failed: {e}") + +if __name__ == "__main__": + fire.Fire(Deploy) diff --git a/deploy_requirements.txt b/deploy_requirements.txt new file mode 100644 index 0000000..82ad3b2 --- /dev/null +++ b/deploy_requirements.txt @@ -0,0 +1,12 @@ +boto3==1.34.18 +botocore==1.34.18 +fire==0.5.0 +gitpython==3.1.41 +gradio_client==1.4.2 +jinja2==3.1.3 +loguru==0.7.2 +paramiko==3.5.0 +Pillow==11.0.0 +pydantic_settings==2.1.0 +pynacl==1.5.0 +requests==2.31.0 diff --git a/docker-build-ec2.yml.j2 b/docker-build-ec2.yml.j2 new file mode 100644 index 0000000..2b289ca --- /dev/null +++ b/docker-build-ec2.yml.j2 @@ -0,0 +1,40 @@ +name: Docker Build on EC2 Instance for OmniParser + +on: + push: + branches: + - {{ branch_name }} + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: SSH and Execute Build on EC2 + uses: appleboy/ssh-action@master + with: + command_timeout: "60m" + host: {{ host }} + username: {{ username }} + {% raw %} + key: ${{ secrets.SSH_PRIVATE_KEY }} + {% endraw %} + script: | + rm -rf {{ github_repo }} || true + git clone https://github.com/{{ github_path }} + cd {{ github_repo }} + git checkout {{ branch_name }} + git pull + + # Stop and remove any existing containers + sudo docker stop {{ project_name }}-container || true + sudo docker rm {{ project_name }}-container || true + + # Build the Docker image + sudo docker build -t {{ project_name }} . 
+ + # Run the Docker container on the specified port + sudo docker run -d -p 7861:7861 --gpus all --name {{ project_name }}-container {{ project_name }} + diff --git a/download.py b/download.py new file mode 100644 index 0000000..03d967f --- /dev/null +++ b/download.py @@ -0,0 +1,16 @@ +import os +from huggingface_hub import snapshot_download + +# Set the repository name +repo_id = "microsoft/OmniParser" + +# Set the local directory where you want to save the files +local_dir = "weights" + +# Create the local directory if it doesn't exist +os.makedirs(local_dir, exist_ok=True) + +# Download the entire repository +snapshot_download(repo_id, local_dir=local_dir, ignore_patterns=["*.md"]) + +print(f"All files and folders have been downloaded to {local_dir}") diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100644 index 0000000..c3d87db --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# Entry point for starting OmniParser's Gradio demo. +# Note: You can read or set environment variables here if needed. + +# Start the Gradio demo script +echo "Starting OmniParser Gradio demo..." +python ./gradio_demo.py diff --git a/gradio_demo.py b/gradio_demo.py index 0557680..92e0515 100644 --- a/gradio_demo.py +++ b/gradio_demo.py @@ -7,7 +7,7 @@ import io -import base64, os +import base64, json, os from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img import torch from PIL import Image @@ -57,9 +57,11 @@ def process( dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz, batch_size=icon_process_batch_size) image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) print('finish processing') - # parsed_content_list = '\n'.join(parsed_content_list) - parsed_content_list = '\n'.join([f'type: {x['type']}, content: {x["content"]}, interactivity: {x["interactivity"]}' for x in parsed_content_list]) - return image, str(parsed_content_list) + parsed_content = json.dumps({ + "parsed_content_list": parsed_content_list, + "label_coordinates": label_coordinates, + }, indent=2) + return image, parsed_content parser = argparse.ArgumentParser(description='Process model paths and names.')