Support for image search engines #317

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ accelerate
 addict
 albumentations
 basicsr
+clip-retrieval
 controlnet-aux
 diffusers
 einops
61 changes: 61 additions & 0 deletions search.py
@@ -0,0 +1,61 @@
import requests

from clip_retrieval.clip_client import ClipClient, Modality
from PIL import Image

from utils import get_image_name, get_new_image_name, prompts


def download_image(img_url, img_path):
    """Download an image from img_url to img_path; return the path on success, None otherwise."""
    img_stream = requests.get(img_url, stream=True)
    if img_stream.status_code == 200:
        img = Image.open(img_stream.raw)
        img.save(img_path, format="png")
        return img_path
    return None


def download_best_available(search_result, result_img_path):
    """Try each result in ranked order, recursing to the next candidate if a download fails."""
    if search_result:
        img_path = download_image(search_result[0]["url"], result_img_path)
        return img_path if img_path else download_best_available(search_result[1:], result_img_path)
    return None


class SearchSupport:
    def __init__(self):
        # KNN search over the LAION-5B index served by the hosted clip-retrieval backend.
        self.client = ClipClient(
            url="https://knn.laion.ai/knn-service",
            indice_name="laion5B-L-14",
            modality=Modality.IMAGE,
            aesthetic_score=0,
            aesthetic_weight=0.0,
            num_images=10,
        )


class ImageSearch(SearchSupport):
    def __init__(self, *args, **kwargs):
        print("Initializing Image Search")
        super().__init__()

    @prompts(name="Search Image That Matches User Input Text",
             description="useful when you want to search an image that matches a given description. "
                         "like: find an image that contains certain objects with certain properties, "
                         "or refine a previous search with additional criteria. "
                         "The input to this tool should be a string, representing the description. ")
    def inference(self, query_text):
        search_result = self.client.query(text=query_text)
        return download_best_available(search_result, get_image_name())


class VisualSearch(SearchSupport):
    def __init__(self, *args, **kwargs):
        print("Initializing Visual Search")
        super().__init__()

    @prompts(name="Search Image Visually Similar to an Input Image",
             description="useful when you want to search an image that is visually similar to an input image. "
                         "like: find an image visually similar to a generated or modified image. "
                         "The input to this tool should be a string, representing the input image path. ")
    def inference(self, query_img_path):
        search_result = self.client.query(image=query_img_path)
        return download_best_available(search_result, get_new_image_name(query_img_path, "visual-search"))
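A minimal sketch of how these two tools could be exercised outside the agent loop, assuming the hosted knn.laion.ai backend is reachable and an image/ directory already exists (neither is guaranteed by this diff):

from search import ImageSearch, VisualSearch

text_tool = ImageSearch()
result_path = text_tool.inference("a red bicycle leaning against a brick wall")
print(result_path)  # e.g. image/1a2b3c4d.png, or None if every candidate download failed

visual_tool = VisualSearch()
similar_path = visual_tool.inference(result_path)  # same flow, but queries by image instead of text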
47 changes: 47 additions & 0 deletions utils.py
@@ -0,0 +1,47 @@
import os
import uuid


def prompts(name, description):
    """Attach a tool name and description to a function, for registration with the agent."""
    def decorator(func):
        func.name = name
        func.description = description
        return func

    return decorator


def cut_dialogue_history(history_memory, keep_last_n_words=500):
    """Drop leading paragraphs until the history is under keep_last_n_words whitespace-separated tokens."""
    if history_memory is None or len(history_memory) == 0:
        return history_memory
    tokens = history_memory.split()
    n_tokens = len(tokens)
    print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
    if n_tokens < keep_last_n_words:
        return history_memory
    paragraphs = history_memory.split('\n')
    last_n_tokens = n_tokens
    while last_n_tokens >= keep_last_n_words:
        last_n_tokens -= len(paragraphs[0].split(' '))
        paragraphs = paragraphs[1:]
    return '\n' + '\n'.join(paragraphs)


def get_new_image_name(org_img_name, func_name="update"):
    """Derive a new image path of the form {uuid}_{func_name}_{prev}_{org}.png from an existing one."""
    head_tail = os.path.split(org_img_name)
    head = head_tail[0]
    tail = head_tail[1]
    name_split = tail.split('.')[0].split('_')
    this_new_uuid = str(uuid.uuid4())[:4]
    if len(name_split) == 1:
        # A freshly uploaded or generated image: its whole stem is the original name.
        most_org_file_name = name_split[0]
    else:
        # An already-derived image: keep the original stem (the fourth chunk).
        assert len(name_split) == 4
        most_org_file_name = name_split[3]
    recent_prev_file_name = name_split[0]
    new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
    return os.path.join(head, new_file_name)


def get_image_name():
    """Return a fresh image path under image/ with a short random stem."""
    return os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
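A worked example of the naming convention these helpers enforce; the uuid chunks below are illustrative, since they are random at run time:

from utils import get_image_name, get_new_image_name

p0 = get_image_name()                          # e.g. image/49d1f6a2.png
p1 = get_new_image_name(p0, "visual-search")   # e.g. image/7c3e_visual-search_49d1f6a2_49d1f6a2.png
p2 = get_new_image_name(p1, "update")          # e.g. image/a90b_update_7c3e_49d1f6a2.png

Each derived name keeps the previous file's uuid as the third chunk and the original stem as the fourth, which is what the assert len(name_split) == 4 relies on.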
58 changes: 10 additions & 48 deletions visual_chatgpt.py
@@ -4,8 +4,7 @@
 import torch
 import cv2
 import re
-import uuid
-from PIL import Image, ImageDraw, ImageOps
+from PIL import Image, ImageOps
 import math
 import numpy as np
 import argparse
@@ -25,6 +24,9 @@
 from langchain.chains.conversation.memory import ConversationBufferMemory
 from langchain.llms.openai import OpenAI
 
+from search import ImageSearch, VisualSearch
+from utils import cut_dialogue_history, get_image_name, get_new_image_name, prompts
+
VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.

Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. When using tools to generate new image files, Visual ChatGPT is also known that the image may not be the same as the user's demand, and will use other visual question answering tools or description tools to observe the real image. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name. It will remember to provide the file name from the last tool observation, if a new image is generated.
@@ -81,15 +83,6 @@ def seed_everything(seed):
     return seed
 
 
-def prompts(name, description):
-    def decorator(func):
-        func.name = name
-        func.description = description
-        return func
-
-    return decorator
-
-
 def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100):
     new_size = new_image.size
     old_size = old_image.size
@@ -147,39 +140,6 @@ def blend_gt2pt(old_image, new_image, sigma=0.15, steps=100):
     return gaussian_img
 
 
-def cut_dialogue_history(history_memory, keep_last_n_words=500):
-    if history_memory is None or len(history_memory) == 0:
-        return history_memory
-    tokens = history_memory.split()
-    n_tokens = len(tokens)
-    print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
-    if n_tokens < keep_last_n_words:
-        return history_memory
-    paragraphs = history_memory.split('\n')
-    last_n_tokens = n_tokens
-    while last_n_tokens >= keep_last_n_words:
-        last_n_tokens -= len(paragraphs[0].split(' '))
-        paragraphs = paragraphs[1:]
-    return '\n' + '\n'.join(paragraphs)
-
-
-def get_new_image_name(org_img_name, func_name="update"):
-    head_tail = os.path.split(org_img_name)
-    head = head_tail[0]
-    tail = head_tail[1]
-    name_split = tail.split('.')[0].split('_')
-    this_new_uuid = str(uuid.uuid4())[:4]
-    if len(name_split) == 1:
-        most_org_file_name = name_split[0]
-    else:
-        assert len(name_split) == 4
-        most_org_file_name = name_split[3]
-    recent_prev_file_name = name_split[0]
-    new_file_name = f'{this_new_uuid}_{func_name}_{recent_prev_file_name}_{most_org_file_name}.png'
-    return os.path.join(head, new_file_name)
-
-
-
 class MaskFormer:
     def __init__(self, device):
         print(f"Initializing MaskFormer to {device}")
@@ -295,7 +255,7 @@ def __init__(self, device):
              "like: generate an image of an object or something, or generate an image that includes some objects. "
              "The input to this tool should be a string, representing the text used to generate image. ")
     def inference(self, text):
-        image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
+        image_filename = get_image_name()
         prompt = text + ', ' + self.a_prompt
         image = self.pipe(prompt, negative_prompt=self.n_prompt).images[0]
         image.save(image_filename)
@@ -1021,7 +981,7 @@ def run_text(self, text, state):
         return state, state
 
     def run_image(self, image, state, txt):
-        image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
+        image_filename = get_image_name()
         print("======>Auto Resize Image...")
         img = Image.open(image.name)
         width, height = img.size
@@ -1046,11 +1006,13 @@ def run_image(self, image, state, txt):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--load', type=str, default="ImageCaptioning_cuda:0,Text2Image_cuda:0")
+    parser.add_argument('--host', type=str, default="0.0.0.0")
+    parser.add_argument('--port', type=int, default=1015)
     args = parser.parse_args()
     load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
     bot = ConversationBot(load_dict=load_dict)
     with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
-        chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
+        chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT").style(height=800)
         state = gr.State([])
         with gr.Row():
             with gr.Column(scale=0.7):
@@ -1067,4 +1029,4 @@ def run_image(self, image, state, txt):
         clear.click(bot.memory.clear)
         clear.click(lambda: [], None, chatbot)
         clear.click(lambda: [], None, state)
-    demo.launch(server_name="0.0.0.0", server_port=1015)
+    demo.launch(server_name=args.host, server_port=args.port)
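With the two new flags, the bind address and port are configurable at launch instead of hard-coded. A hypothetical invocation (the tool list assumes ConversationBot instantiates the classes named in --load the same way it does for the existing tools):

python visual_chatgpt.py --load "Text2Image_cuda:0,ImageSearch_cpu,VisualSearch_cpu" --host 127.0.0.1 --port 7860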