Skip to content

Commit

Permalink
correct system message handling and add image support
Browse files Browse the repository at this point in the history
  • Loading branch information
justinh-rahb authored Jun 26, 2024
1 parent 9cf0fc4 commit d4d1828
Showing 1 changed file with 78 additions and 74 deletions.
152 changes: 78 additions & 74 deletions examples/pipelines/providers/anthropic_manifold_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
title: Anthropic Manifold Pipeline
author: justinh-rahb
author: justinh-rahb (updated)
date: 2024-06-20
version: 1.1
version: 1.2
license: MIT
description: A pipeline for generating text using the Anthropic API.
description: A pipeline for generating text and processing images using the Anthropic API.
requirements: requests, anthropic
environment_variables: ANTHROPIC_API_KEY
"""
Expand Down Expand Up @@ -35,13 +35,11 @@ def __init__(self):
self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY)

def get_anthropic_models(self):
# In the future, this could fetch models dynamically from Anthropic
return [
{"id": "claude-3-haiku-20240307", "name": "claude-3-haiku"},
{"id": "claude-3-opus-20240229", "name": "claude-3-opus"},
{"id": "claude-3-sonnet-20240229", "name": "claude-3-sonnet"},
{"id": "claude-3-5-sonnet-20240620", "name": "claude-3.5-sonnet"},
# Add other Anthropic models here as they become available
]

async def on_startup(self):
Expand All @@ -53,95 +51,101 @@ async def on_shutdown(self):
pass

async def on_valves_updated(self):
# This function is called when the valves are updated.
self.client = Anthropic(api_key=self.valves.ANTHROPIC_API_KEY)
pass

# Pipelines are the models that are available in the manifold.
# It can be a list or a function that returns a list.
def pipelines(self) -> List[dict]:
return self.get_anthropic_models()

def process_image(self, image_data):
if image_data["image_url"]["url"].startswith("data:image"):
mime_type, base64_data = image_data["image_url"]["url"].split(",", 1)
media_type = mime_type.split(":")[1].split(";")[0]
return {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data,
},
}
else:
return {
"type": "image",
"source": {"type": "url", "url": image_data["image_url"]["url"]},
}

def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
try:
if "user" in body:
del body["user"]
if "chat_id" in body:
del body["chat_id"]
if "title" in body:
del body["title"]
# Remove unnecessary keys
for key in ['user', 'chat_id', 'title']:
body.pop(key, None)

system_message, messages = pop_system_message(messages)

processed_messages = []
image_count = 0
total_image_size = 0

for message in messages:
processed_content = []
if isinstance(message.get("content"), list):
for item in message["content"]:
if item["type"] == "text":
processed_content.append({"type": "text", "text": item["text"]})
elif item["type"] == "image_url":
if image_count >= 5:
raise ValueError("Maximum of 5 images per API call exceeded")

processed_image = self.process_image(item)
processed_content.append(processed_image)

if processed_image["source"]["type"] == "base64":
image_size = len(processed_image["source"]["data"]) * 3 / 4
else:
image_size = 0

total_image_size += image_size
if total_image_size > 100 * 1024 * 1024:
raise ValueError("Total size of images exceeds 100 MB limit")

image_count += 1
else:
processed_content = [{"type": "text", "text": message.get("content", "")}]

processed_messages.append({"role": message["role"], "content": processed_content})

# Prepare the payload
payload = {
"model": model_id,
"messages": processed_messages,
"max_tokens": body.get("max_tokens", 4096),
"temperature": body.get("temperature", 0.8),
"top_k": body.get("top_k", 40),
"top_p": body.get("top_p", 0.9),
"stop_sequences": body.get("stop", []),
**({"system": str(system_message)} if system_message else {}),
"stream": body.get("stream", False),
}

if body.get("stream", False):
return self.stream_response(model_id, messages, body)
return self.stream_response(model_id, payload)
else:
return self.get_completion(model_id, messages, body)
return self.get_completion(model_id, payload)
except (RateLimitError, APIStatusError, APIConnectionError) as e:
return f"Error: {e}"

def stream_response(
self, model_id: str, messages: List[dict], body: dict
) -> Generator:
system_message, messages = pop_system_message(messages)

max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
)
temperature = (
body.get("temperature") if body.get("temperature") is not None else 0.8
)
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []

stream = self.client.messages.create(
**{
"model": model_id,
**(
{"system": system_message} if system_message else {}
), # Add system message if it exists (optional
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
"stream": True,
}
)
def stream_response(self, model_id: str, payload: dict) -> Generator:
stream = self.client.messages.create(**payload)

for chunk in stream:
if chunk.type == "content_block_start":
yield chunk.content_block.text
elif chunk.type == "content_block_delta":
yield chunk.delta.text

def get_completion(self, model_id: str, messages: List[dict], body: dict) -> str:
system_message, messages = pop_system_message(messages)

max_tokens = (
body.get("max_tokens") if body.get("max_tokens") is not None else 4096
)
temperature = (
body.get("temperature") if body.get("temperature") is not None else 0.8
)
top_k = body.get("top_k") if body.get("top_k") is not None else 40
top_p = body.get("top_p") if body.get("top_p") is not None else 0.9
stop_sequences = body.get("stop") if body.get("stop") is not None else []

response = self.client.messages.create(
**{
"model": model_id,
**(
{"system": system_message} if system_message else {}
), # Add system message if it exists (optional
"messages": messages,
"max_tokens": max_tokens,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"stop_sequences": stop_sequences,
}
)
return response.content[0].text
def get_completion(self, model_id: str, payload: dict) -> str:
response = self.client.messages.create(**payload)
return response.content[0].text

0 comments on commit d4d1828

Please sign in to comment.