Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add tqdm download progress bar #442

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ langchain==0.0.101
torch==1.13.1
torchvision==0.14.1
wget==3.2
tdqm
accelerate
addict
albumentations
Expand Down
24 changes: 19 additions & 5 deletions visual_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import numpy as np
import matplotlib.pyplot as plt
import wget
from tqdm import tqdm

VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.

Expand Down Expand Up @@ -138,6 +139,18 @@
os.makedirs('image', exist_ok=True)


class ProgressBar:

def __init__(self, url):
self.progress_bar = None
print(f"Downloading checkpoints file from {url}")

def __call__(self, current_bytes, total_bytes, width):
if self.progress_bar is None:
self.progress_bar = tqdm(total=total_bytes, unit='B', unit_scale=True, unit_divisor=1024)#tqdm(total=total_mb, desc="MB")
self.progress_bar.update(current_bytes)


def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -817,7 +830,7 @@ def __init__(self, device):
def download_parameters(self):
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))


def show_mask(self, mask: np.ndarray,image: np.ndarray,
Expand Down Expand Up @@ -1038,12 +1051,13 @@ def __init__(self, device):
def download_parameters(self):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
if not os.path.exists(self.model_config_path):
wget.download(config_url,out=self.model_config_path)
def load_image(self,image_path):
# load image
wget.download(config_url, out=self.model_config_path, bar=ProgressBar(url))

def load_image(self, image_path):
# load image
image_pil = Image.open(image_path).convert("RGB") # load image

transform = T.Compose(
Expand Down