From 0af4fe76a76efd142add52be9158d86f92b3b4b1 Mon Sep 17 00:00:00 2001 From: Xiao Wang Date: Fri, 1 Sep 2023 21:52:33 +0800 Subject: [PATCH] add tqdm download progress bar --- requirements.txt | 1 + visual_chatgpt.py | 24 +++++++++++++++++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/requirements.txt b/requirements.txt index b3666e86..8aaefcfb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ langchain==0.0.101 torch==1.13.1 torchvision==0.14.1 wget==3.2 +tdqm accelerate addict albumentations diff --git a/visual_chatgpt.py b/visual_chatgpt.py index 17dc4def..f1ad42ac 100644 --- a/visual_chatgpt.py +++ b/visual_chatgpt.py @@ -44,6 +44,7 @@ import numpy as np import matplotlib.pyplot as plt import wget +from tqdm import tqdm VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. @@ -138,6 +139,18 @@ os.makedirs('image', exist_ok=True) +class ProgressBar: + + def __init__(self, url): + self.progress_bar = None + print(f"Downloading checkpoints file from {url}") + + def __call__(self, current_bytes, total_bytes, width): + if self.progress_bar is None: + self.progress_bar = tqdm(total=total_bytes, unit='B', unit_scale=True, unit_divisor=1024)#tqdm(total=total_mb, desc="MB") + self.progress_bar.update(current_bytes) + + def seed_everything(seed): random.seed(seed) np.random.seed(seed) @@ -817,7 +830,7 @@ def __init__(self, device): def download_parameters(self): url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" if not os.path.exists(self.model_checkpoint_path): - wget.download(url,out=self.model_checkpoint_path) + wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url)) def show_mask(self, mask: np.ndarray,image: np.ndarray, @@ -1038,12 +1051,13 @@ def __init__(self, device): def download_parameters(self): url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth" if not os.path.exists(self.model_checkpoint_path): - wget.download(url,out=self.model_checkpoint_path) + wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url)) config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py" if not os.path.exists(self.model_config_path): - wget.download(config_url,out=self.model_config_path) - def load_image(self,image_path): - # load image + wget.download(config_url, out=self.model_config_path, bar=ProgressBar(url)) + + def load_image(self, image_path): + # load image image_pil = Image.open(image_path).convert("RGB") # load image transform = T.Compose(