converse_agent.py
import json
import re

import boto3


class ConverseAgent:
    def __init__(self, model_id, region='us-west-2', system_prompt='You are a helpful assistant.'):
        self.model_id = model_id
        self.region = region
        self.client = boto3.client('bedrock-runtime', region_name=self.region)
        self.system_prompt = system_prompt
        self.messages = []
        # Must be set to a tool manager (providing get_tools() and execute_tool())
        # before invoking the agent.
        self.tools = None
        self.response_output_tags = []  # e.g. ['<response>', '</response>']

    def invoke_with_prompt(self, prompt):
        content = [
            {
                'text': prompt
            }
        ]
        return self.invoke(content)

    def invoke(self, content):
        self.messages.append(
            {
                "role": "user",
                "content": content
            }
        )
        response = self._get_converse_response()
        return self._handle_response(response)

    def _get_converse_response(self):
        """
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html
        """
        # print(f"Invoking with messages: {json.dumps(self.messages, indent=2)}")
        response = self.client.converse(
            modelId=self.model_id,
            messages=self.messages,
            system=[
                {
                    "text": self.system_prompt
                }
            ],
            inferenceConfig={
                "maxTokens": 1024,
                "temperature": 0.7,
            },
            toolConfig=self.tools.get_tools()
        )
        return response

    def _handle_response(self, response):
        # Add the response to the conversation history
        self.messages.append(response['output']['message'])

        # Decide what to do next based on why the model stopped
        stop_reason = response['stopReason']

        if stop_reason in ['end_turn', 'stop_sequence']:
            # Safely extract the text from the nested response structure
            try:
                message = response.get('output', {}).get('message', {})
                content = message.get('content', [])
                text = content[0].get('text', '')
                if hasattr(self, 'response_output_tags') and len(self.response_output_tags) == 2:
                    pattern = f"(?s).*{re.escape(self.response_output_tags[0])}(.*?){re.escape(self.response_output_tags[1])}"
                    match = re.search(pattern, text)
                    if match:
                        return match.group(1)
                return text
            except (KeyError, IndexError):
                return ''

        elif stop_reason == 'tool_use':
            try:
                # Extract each tool use request from the response and execute it
                tool_response = []
                for content_item in response['output']['message']['content']:
                    if 'toolUse' in content_item:
                        tool_request = {
                            "toolUseId": content_item['toolUse']['toolUseId'],
                            "name": content_item['toolUse']['name'],
                            "input": content_item['toolUse']['input']
                        }
                        tool_result = self.tools.execute_tool(tool_request)
                        tool_response.append({'toolResult': tool_result})
                # Send the tool results back to the model as the next user turn
                return self.invoke(tool_response)
            except KeyError as e:
                raise ValueError(f"Missing required tool use field: {e}")
            except Exception as e:
                raise ValueError(f"Failed to execute tool: {e}")

        elif stop_reason == 'max_tokens':
            # Hit the token limit; asking the model to continue is one way to handle it
            return self.invoke_with_prompt('Please continue.')

        else:
            raise ValueError(f"Unknown stop reason: {stop_reason}")