Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rishiraj authored Aug 15, 2024
1 parent 50d2ac0 commit fe8ae8c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions spanking/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import pandas as pd

class VectorDB:
def __init__(self, model_name='BAAI/bge-base-en-v1.5'):
self.model = SentenceTransformer(model_name)
def __init__(self, model_name='dunzhang/stella_en_400M_v5'):
self.model = SentenceTransformer(model_name, trust_remote_code=True)
self.image_classifier = pipeline(task="zero-shot-image-classification", model="google/siglip-so400m-patch14-384")
self.texts = []
self.embeddings = None
Expand Down Expand Up @@ -45,7 +45,7 @@ def search(self, query, top_k=5, type='text'):
if isinstance(query, str):
query = Image.open(requests.get(query, stream=True).raw)
outputs = self.image_classifier(query, candidate_labels=self.texts)
similarities = jnp.array([output['score'] for output in outputs])
similarities = jnp.array([round(output["score"], 4) for output in outputs])
else:
raise ValueError("Invalid search type. Supported types are 'text' and 'image'.")
top_indices = jnp.argsort(similarities)[-top_k:][::-1]
Expand Down

0 comments on commit fe8ae8c

Please sign in to comment.