Skip to content

Commit

Permalink
add tqdm download progress bar
Browse files Browse the repository at this point in the history
  • Loading branch information
truebit committed Sep 1, 2023
1 parent 4b7664f commit 0af4fe7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ langchain==0.0.101
torch==1.13.1
torchvision==0.14.1
wget==3.2
tdqm
accelerate
addict
albumentations
Expand Down
24 changes: 19 additions & 5 deletions visual_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import numpy as np
import matplotlib.pyplot as plt
import wget
from tqdm import tqdm

VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
Expand Down Expand Up @@ -138,6 +139,18 @@
os.makedirs('image', exist_ok=True)


class ProgressBar:

def __init__(self, url):
self.progress_bar = None
print(f"Downloading checkpoints file from {url}")

def __call__(self, current_bytes, total_bytes, width):
if self.progress_bar is None:
self.progress_bar = tqdm(total=total_bytes, unit='B', unit_scale=True, unit_divisor=1024)#tqdm(total=total_mb, desc="MB")
self.progress_bar.update(current_bytes)


def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -817,7 +830,7 @@ def __init__(self, device):
def download_parameters(self):
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))


def show_mask(self, mask: np.ndarray,image: np.ndarray,
Expand Down Expand Up @@ -1038,12 +1051,13 @@ def __init__(self, device):
def download_parameters(self):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
if not os.path.exists(self.model_config_path):
wget.download(config_url,out=self.model_config_path)
def load_image(self,image_path):
# load image
wget.download(config_url, out=self.model_config_path, bar=ProgressBar(url))

def load_image(self, image_path):
# load image
image_pil = Image.open(image_path).convert("RGB") # load image

transform = T.Compose(
Expand Down

0 comments on commit 0af4fe7

Please sign in to comment.