diff --git a/apps/multi_roles_chat_room/app.py b/apps/multi_roles_chat_room/app.py
index dc4f497b5..2cdb96f24 100644
--- a/apps/multi_roles_chat_room/app.py
+++ b/apps/multi_roles_chat_room/app.py
@@ -1,13 +1,17 @@
+import os
import re
+from typing import Optional
import gradio as gr
import json
import modelscope_studio as mgr
+from modelscope_agent.multi_agents_utils.executors.ray import RayTaskExecutor
from role_core import (chat_progress, init_all_remote_actors,
start_chat_with_topic)
from story_holder import get_avatar_by_name, get_story_by_id, stories
chat_history = []
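+# initialize ray once at app startup (moved out of init_all_remote_actors)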
+RayTaskExecutor.init_ray()
# 发送消息的函数
@@ -54,6 +58,17 @@ def end_topic():
return '', 'topic ended。'
+def check_uuid(uid: Optional[str]) -> str:
+ """Checks whether a UUID is provided or generates a default one."""
+    if not uid:
+        if os.getenv('MODELSCOPE_ENVIRONMENT') == 'studio':
+            raise gr.Error('Please login first')
+        uid = 'local_user'
+ return uid
+
+
def format_entry_html():
base_html = '''
@@ -89,6 +104,8 @@ def format_entry_html():
with demo:
state = gr.State({})
story_state = gr.State('1') # default value
+ uuid = gr.Textbox(label='modelscope_uuid', visible=False)
+
with gr.Column(visible=True) as entry:
gr.Markdown('## 选择一个场景进入聊天吧~')
entry_btn = gr.Button(
@@ -197,10 +214,13 @@ def back():
back_btn.click(fn=back, inputs=[], outputs=[entry, content])
def start_chat(username, from_user, topic, _state, _chatbot, _input,
- _story_state, select_model):
+ _story_state, select_model, _uuid):
+    # resolve the user id and derive a per-user task id to isolate ray actors
+ _uuid = check_uuid(_uuid)
+ user_task_id = f'{_uuid[:6]}_{_story_state}'
roles = get_story_by_id(_story_state)['roles']
_state = init_all_remote_actors(roles, username, _state, _story_state,
- select_model)
+ select_model, user_task_id)
_state = start_chat_with_topic(from_user, topic, _state)
init_chat = [{
@@ -328,7 +348,7 @@ def send_message(_chatbot, _input, _state):
fn=start_chat,
inputs=[
user_select, start_topic_from, start_topic_input, state,
- user_chatbot, preview_chat_input, story_state, select_model
+ user_chatbot, preview_chat_input, story_state, select_model, uuid
],
outputs=[user_chatbot, preview_chat_input, state])
diff --git a/apps/multi_roles_chat_room/requirements.txt b/apps/multi_roles_chat_room/requirements.txt
index 143544549..51a17116e 100644
--- a/apps/multi_roles_chat_room/requirements.txt
+++ b/apps/multi_roles_chat_room/requirements.txt
@@ -1 +1,3 @@
gradio==4.8
+modelscope-agent==0.4.1
+modelscope_studio
diff --git a/apps/multi_roles_chat_room/role_core.py b/apps/multi_roles_chat_room/role_core.py
index 8c2b314a9..821d038d3 100644
--- a/apps/multi_roles_chat_room/role_core.py
+++ b/apps/multi_roles_chat_room/role_core.py
@@ -8,7 +8,6 @@
from modelscope_agent import create_component
from modelscope_agent.agent_env_util import AgentEnvMixin
from modelscope_agent.agents import MultiRolePlay
-from modelscope_agent.multi_agents_utils.executors.ray import RayTaskExecutor
from modelscope_agent.task_center import TaskCenter
from story_holder import get_story_by_id
@@ -16,22 +15,21 @@
# instruction prompt
ROLE_INSTRUCTION_PROMPT = """<|im_start|>system
-你是{role},请你根据对话情节设定,和你的对话角色设定,继续当前的对话,说话要符合你的角色设定
-
-【你的角色设定】
-{role_description}
-
-【对话情节设定】
-{story}
+你是{role},请你根据对话情节设定,当前的对话记录和你的对话角色设定,继续当前的对话,说话要符合你的角色设定,不要重复chat history中的内容
【注意事项】
1. 根据user的历史对话回复时,不要重复说历史对话里的句子,要保持生成内容的有趣,丰富度,多样性
2. 长话短说,不要说太多话,不要超过50字
+3. 符合当前对话情节设定,围绕对话情节,你是{role},不要模仿其他角色,也不要重复chat history中的内容
+
+【对话场景】
+{story}
+
+【chat history】
+chat_records
【你的角色设定】
-{role_description}
-<|im_end|>
-"""
+{role_description}<|im_end|>"""
CHATROOM_INSTRUCTION_PROMPT = """<|im_start|>system
你是一个小说作家,请你根据对话场景、人物介绍及最近的对话记录,选择继续对话的下一个角色。注意,对话历史是以群聊的形式展现,因此角色可能会@某个人表示对这个人说话。
@@ -42,17 +40,17 @@
【人物介绍】
{all_roles_info}
-【对话记录】
+【chat history】
chat_records
【最新对话】
recent_records
【注意事项】
-根据对话记录和最新对话
+根据chat history和最新对话
1. 当主角{user_role}说话中提到某个角色,你需要只让提到的角色接话。
2. 不要选【最新对话】里的角色发言
-3. 要让每个角色有平等的对话机会,多选一些【对话记录】没有出现的角色,要求情节充满戏剧性。
+3. 要让每个角色有平等的对话机会,多选一些chat history没有出现的角色,要求情节充满戏剧性。
4. 只写角色名字即可,每次最多选两个角色,尽量多的选择主角,当前对话的主角是{user_role}
【回复格式】
@@ -91,7 +89,8 @@ def generate_role_instruction(role, story_info):
return instruction
-def upsert_role(new_user, user_char, human_input_mode, story_info, llm_config):
+def upsert_role(new_user, user_char, human_input_mode, story_info, llm_config,
+ _uid):
role = create_component(
MultiRolePlay,
name=new_user,
@@ -101,7 +100,8 @@ def upsert_role(new_user, user_char, human_input_mode, story_info, llm_config):
llm=llm_config,
function_list=function_list,
instruction=generate_role_instruction(new_user, story_info),
- human_input_mode=human_input_mode)
+ human_input_mode=human_input_mode,
+ prefix_name=_uid)
return role
@@ -127,9 +127,7 @@ def change_user_role(user_role, state):
def init_all_remote_actors(_roles, user_role, _state, _story_state,
- select_model):
- RayTaskExecutor.init_ray()
-
+ select_model, _uid):
story_info = get_story_by_id(_story_state)
llm_config['model'] = select_model
@@ -139,7 +137,7 @@ def init_all_remote_actors(_roles, user_role, _state, _story_state,
return state
task_center = create_component(
- TaskCenter, name='Task_Center', remote=REMOTE_MODE)
+ TaskCenter, name='Task_Center', remote=REMOTE_MODE, prefix_name=_uid)
# init all agents and task center
role_agents = []
@@ -148,7 +146,7 @@ def init_all_remote_actors(_roles, user_role, _state, _story_state,
if role == user_role:
human_input_mode = 'ON'
role_agent = upsert_role(role, _roles[role], human_input_mode,
- story_info, llm_config)
+ story_info, llm_config, _uid)
role_agents.append(role_agent)
@@ -156,7 +154,6 @@ def init_all_remote_actors(_roles, user_role, _state, _story_state,
MultiRolePlay,
name='chat_room',
remote=True,
- role='chat_room',
llm=llm_config,
function_list=function_list,
instruction=CHATROOM_INSTRUCTION_PROMPT.format(
@@ -164,9 +161,11 @@ def init_all_remote_actors(_roles, user_role, _state, _story_state,
story=story_info['story'],
user_role=user_role),
is_watcher=True,
- use_history=False)
+ use_history=False,
+ prefix_name=_uid)
- logging.warning(msg=f'time:{time.time()} done create task center')
+ logging.warning(
+ msg=f'time:{time.time()} done create task center with uid: {_uid}')
ray.get(task_center.add_agents.remote(role_agents))
ray.get(task_center.add_agents.remote([chat_room]))
diff --git a/apps/multi_roles_chat_room/story_holder.py b/apps/multi_roles_chat_room/story_holder.py
index be90e4623..2129a21e2 100644
--- a/apps/multi_roles_chat_room/story_holder.py
+++ b/apps/multi_roles_chat_room/story_holder.py
@@ -31,8 +31,7 @@
stories = [
{
'id': '1',
- 'cover':
- '//img.alicdn.com/imgextra/i1/O1CN01UHwXNQ2780lrVHY6n_!!6000000007751-0-tps-1024-512.jpg',
+ 'cover': '//s21.ax1x.com/2024/04/16/pFxG1zj.jpg',
'title': '我被美女包围了',
'description': '用户是男主角顾易,与多位长相、性格都大相径庭的美女相识',
'roles': ROLES_1,
@@ -41,8 +40,7 @@
},
{
'id': '2',
- 'cover':
- '//img.alicdn.com/imgextra/i1/O1CN01UHwXNQ2780lrVHY6n_!!6000000007751-0-tps-1024-512.jpg',
+ 'cover': '//s21.ax1x.com/2024/04/16/pFxGgw6.png',
'title': '我是雷军,雷中有“电”,军下有“车”',
'description': '用户是男主角雷军,小米创始人,最近发布了小米SU7',
'roles': ROLES_2,
diff --git a/modelscope_agent/__init__.py b/modelscope_agent/__init__.py
index bc51dd488..ae04173cb 100644
--- a/modelscope_agent/__init__.py
+++ b/modelscope_agent/__init__.py
@@ -1,13 +1,19 @@
from .agent import Agent
-def _create_remote(cls, name, max_concurrency=1, *args, **kwargs):
+def _create_remote(cls,
+ name,
+ max_concurrency=1,
+ force_new=False,
+ *args,
+ **kwargs):
'''
Create a remote actor by ray
Args:
cls: the class to be created
name: the name of ray actor, also the role name
max_concurrency: max concurrency of the actor
+        force_new: whether to kill an existing actor with the same name and create a new one
*args:
**kwargs:
@@ -15,7 +21,16 @@ def _create_remote(cls, name, max_concurrency=1, *args, **kwargs):
'''
import ray
-
+ try:
+ # try to get an existing actor
+ ray_actor = ray.get_actor(name)
+ if force_new:
+ ray.kill(ray_actor)
+ else:
+ return ray_actor
+ except ValueError:
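+        # ray.get_actor raises ValueError when no actor with this name
+        # exists, so fall through and create a new one below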
+ pass
+ # if failed, create a new actor
return ray.remote(
name=name,
max_concurrency=max_concurrency)(cls).remote(*args, **kwargs)
@@ -39,11 +54,15 @@ def create_component(cls,
name,
remote=False,
max_concurrency=1,
+ prefix_name=None,
*args,
**kwargs):
kwargs['remote'] = remote
kwargs['role'] = name
+ kwargs['prefix_name'] = prefix_name
if remote:
+ if prefix_name is not None:
+ name = f'{prefix_name}_{name}'
return _create_remote(cls, name, max_concurrency, *args, **kwargs)
else:
return _create_local(cls, *args, **kwargs)
diff --git a/modelscope_agent/agent.py b/modelscope_agent/agent.py
index f07a12983..88ff1f58b 100644
--- a/modelscope_agent/agent.py
+++ b/modelscope_agent/agent.py
@@ -3,7 +3,7 @@
from modelscope_agent.llm import get_chat_model
from modelscope_agent.llm.base import BaseChatModel
-from modelscope_agent.tools import TOOL_REGISTRY
+from modelscope_agent.tools.base import TOOL_REGISTRY
from modelscope_agent.utils.utils import has_chinese_chars
diff --git a/modelscope_agent/agent_env_util.py b/modelscope_agent/agent_env_util.py
index e478b7452..ebce48517 100644
--- a/modelscope_agent/agent_env_util.py
+++ b/modelscope_agent/agent_env_util.py
@@ -156,9 +156,6 @@ def step(self,
history = []
if self.use_history:
history = self.memory.get_history()
- logger.info(
- f'reach here 2 {result}, self.human_input_mode {self.human_input_mode}'
- )
# run generation
for frame in self.run(
prompt,
@@ -266,10 +263,10 @@ def pull(self):
else:
return ''
- def convert_to_string(self, messages: List[Message], max_turn=5):
+ def convert_to_string(self, messages: List[Message], max_turn=15):
prompt_template = """{conversation_history}"""
conversation_history = ''
- for item in messages[:max_turn]:
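+        # walk only the most recent max_turn messages, oldest first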
+        for item in messages[-max_turn:]:
conversation_history += f'{item.sent_from}: {item.content}\n'
return prompt_template.format(
conversation_history=conversation_history.strip())
diff --git a/modelscope_agent/constants.py b/modelscope_agent/constants.py
index 77921d2bd..ca54a0235 100644
--- a/modelscope_agent/constants.py
+++ b/modelscope_agent/constants.py
@@ -4,3 +4,6 @@
DEFAULT_LOG_STORAGE_PATH = DEFAULT_AGENT_ROOT / 'log'
DEFAULT_SEND_TO = 'all'
USER_REQUIREMENT = 'user_requirement'
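+# canonical ray actor names; create_component prepends a per-user prefix to
+# these when prefix_name is given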
+ENVIRONMENT_NAME = 'env'
+AGENT_REGISTRY_NAME = 'agent_center'
+TASK_CENTER_NAME = 'task_center'
diff --git a/modelscope_agent/llm/dashscope.py b/modelscope_agent/llm/dashscope.py
index 6bf4ec535..2d22a0708 100644
--- a/modelscope_agent/llm/dashscope.py
+++ b/modelscope_agent/llm/dashscope.py
@@ -209,14 +209,52 @@ def build_raw_prompt(self, messages: list):
prompt = prompt[:-len(f'{im_end}')]
return prompt
+ def build_multi_role_raw_prompt(self, messages: list):
+ prompt = ''
+ im_start = '<|im_start|>'
+ im_end = '<|im_end|>'
+        print('build_multi_role_raw_prompt', messages)
+ if messages[0]['role'] == 'system':
+ system_prompt = messages[0]['content']
+ else:
+ system_prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
+
+        # chatroom prompt: both placeholders are present, so fill in the chat
+        # history and ask the model to select the next speaker
+ if 'recent_records' in system_prompt and 'chat_records' in system_prompt:
+ chat_records = messages[-1]['content'].strip()
+ recent_records = chat_records.split('\n')[-1]
+ prompt = f'{system_prompt.replace("chat_records", chat_records).replace("recent_records", recent_records)}<|im_start|>assistant\n' # noqa E501
+ else:
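+            # per-role prompt: recover the speaking role's name from the
+            # instruction so the reply can be attributed to that role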
+ try:
+ re_pattern_config = re.compile(pattern=r'你是([\s\S]+),请你根据对话')
+ res = re_pattern_config.search(system_prompt)
+ cur_role_name = res.group(1).strip()
+ except Exception:
+ cur_role_name = 'assistant'
+ print('cur_role_name: ', cur_role_name)
+ prompt = system_prompt
+ content = messages[-1]['content'].lstrip('\n').rstrip()
+ if 'chat_records' in prompt:
+ prompt = f'{prompt.replace("chat_records", content)}\n<|im_start|>{cur_role_name}\n'
+ else:
+ prompt = f'{prompt}user\n{content}<|im_end|>\n<|im_start|>assistant\n{cur_role_name}: '
-@register_llm('dashscope_qwen_spark')
-class QwenSparkAtDS(DashScopeLLM):
+ print('prompt: ', [prompt])
+ return prompt
def _chat_stream(self,
messages: List[Dict],
stop: Optional[List[str]] = None,
**kwargs) -> Iterator[str]:
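+        # qwen-spark-plus consumes the hand-built multi-role raw prompt,
+        # so route it to the raw-prompt streaming path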
+ if self.model == 'qwen-spark-plus':
+ return self._chat_stream_with_raw_prompt(messages, stop, **kwargs)
+ else:
+ return super()._chat_stream(messages, stop, **kwargs)
+
+ def _chat_stream_with_raw_prompt(self,
+ messages: List[Dict],
+ stop: Optional[List[str]] = None,
+ **kwargs) -> Iterator[str]:
stop = stop or []
stop.append('<|im_end|>')
generation_input = {
@@ -226,10 +264,12 @@ def _chat_stream(self,
'stop_str': word,
'mode': 'exclude'
} for word in stop],
- 'top_p': kwargs.get('top_p', 0.8),
+ 'top_p': kwargs.get('top_p', 0.95),
+ 'temperature': kwargs.get('temperature', 0.92),
'result_format': 'message',
'stream': True,
'use_raw_prompt': True,
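+            # cap generation length; the roleplay instruction asks for short replies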
+ 'max_length': 100
}
logger.query_info(
@@ -240,39 +280,5 @@ def _chat_stream(self,
generation_input['temperature'] = kwargs.get('temperature')
if kwargs.get('seed', None):
generation_input['seed'] = kwargs.get('seed')
- logger.info(f'######## input{generation_input}')
-
response = dashscope.Generation.call(**generation_input)
- logger.info(f'######## response{response}')
-
return stream_output(response, **kwargs)
-
- def build_multi_role_raw_prompt(self, messages: list):
- prompt = ''
- im_start = '<|im_start|>'
- im_end = '<|im_end|>'
- print('build_raw_prompt', messages)
- if messages[0]['role'] == 'system':
- system_prompt = messages[0]['content']
- else:
- system_prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
-
- # select user
- if 'chat_records' in system_prompt:
- chat_records = messages[-1]['content'].strip()
- recent_records = chat_records.split('\n')[-1]
- prompt = f'{system_prompt.replace("chat_records", chat_records).replace("recent_records", recent_records)}<|im_start|>assistant\n' # noqa E501
- else:
- try:
- re_pattern_config = re.compile(pattern=r'你是([\s\S]+),请你根据对话')
- res = re_pattern_config.search(system_prompt)
- cur_role_name = res.group(1).strip()
- except Exception:
- cur_role_name = 'assistant'
- print('cur_role_name: ', cur_role_name)
- prompt = system_prompt
- content = messages[-1]['content'].lstrip('\n').rstrip()
- prompt = f'{prompt}user\n{content}<|im_end|>\n<|im_start|>assistant\n{cur_role_name}: '
-
- print('prompt: ', [prompt])
- return prompt
diff --git a/modelscope_agent/multi_agents_utils/README.md b/modelscope_agent/multi_agents_utils/README.md
index 415dbbcb5..08340c5cc 100644
--- a/modelscope_agent/multi_agents_utils/README.md
+++ b/modelscope_agent/multi_agents_utils/README.md
@@ -423,6 +423,17 @@ With the above inputs, the step could be used in different scenarios of multi-ag
For example, in a three-man debate scenario, the `step` method could be used like this:
```python
+# add new role
+role_template_hillary = 'you are the former secretary of state Hillary Clinton, and you are debating with former president Donald Trump and current president Joe Biden on a couple of topics'
+
+hillary_clinton = create_component(
+ RolePlay,
+ name='hillary_clinton',
+ remote=REMOTE_MODE,
+ llm=llm_config,
+ function_list=function_list,
+ instruction=role_template_hillary)
+
# initialize the agents
task_center.add_agents([joe_biden, donald_trump, hillary_clinton])
@@ -434,9 +445,12 @@ for frame in task_center.step(send_to='hillary_clinton'):
print(frame)
# in 2nd step, allow only donald_trump to response the topic
-for frame in task_center.step(allower_roles='donald_trump'):
+for frame in task_center.step(allowed_roles='donald_trump'):
print(frame)
```
+Note that each `frame` only carries the message from one agent, in the format `<[role_name]>: [message stream]`;
+users should handle this message format themselves when consuming the output of the `step` method.
+
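+A minimal sketch of consuming these frames (the frame-splitting below is an assumption for illustration, not a fixed API):
+
+```python
+# each frame arrives as '<role_name>: message chunk'
+for frame in task_center.step():
+    role, _, chunk = frame.partition(': ')
+    print(f'{role.strip("<>")} says: {chunk}')
+```
+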
The above case show how to use `send_to` and `allowed_roles` in the `step` method to control the communication between agents in a multi-agent task.
There is another case, in a chatbot mode, the `user_response` could be used to let the user response in this step to replace the llm output, if user_agent is in this step.
@@ -456,17 +470,20 @@ user = create_component(
task_center.add_agents([joe_biden, donald_trump, hillary_clinton, user])
# let joe_biden start the topic
-task_center.send_task_request('what is the best solution to land on moon?', send_to='joe_biden')
+task_center.send_task_request('what is the best solution to land on moon, in one sentence?', send_to='joe_biden')
# in 1st step, let joe_biden send his opinion to all agents.
for frame in task_center.step():
print(frame)
# in 2nd step, allow only user to response the topic, with user_response
-for frame in task_center.step(allower_roles='user', user_response='I dont agree with you about the landing project'):
- print(frame)
-assert frame == 'I dont agree with you about the landing project'
+result = ''
+for frame in task_center.step(allowed_roles='user', user_response='I dont agree with you about the landing project'):
+ result += frame
+ print(frame)
+# nothing is streamed back here because user_response replaces the llm output
+assert result == ''
```
The user response will be used in this step to replace the llm output, because `user` is a human agent.
diff --git a/modelscope_agent/multi_agents_utils/README_CN.md b/modelscope_agent/multi_agents_utils/README_CN.md
index 40b45c6c8..fa4baf826 100644
--- a/modelscope_agent/multi_agents_utils/README_CN.md
+++ b/modelscope_agent/multi_agents_utils/README_CN.md
@@ -382,6 +382,17 @@ step方法将使每个agent在这一步骤中作出响应,响应将是一个
有了以上的参数,step方法可以在不同场景的multi-agent中使用。
例如,在一个三人辩论场景中,step方法可以像这样使用:
```python
+# add new role
+role_template_hillary = 'you are the former secretary of state Hillary Clinton, and you are debating with former president Donald Trump and current president Joe Biden on a couple of topics'
+
+hillary_clinton = create_component(
+ RolePlay,
+ name='hillary_clinton',
+ remote=REMOTE_MODE,
+ llm=llm_config,
+ function_list=function_list,
+ instruction=role_template_hillary)
+
# initialize the agents
task_center.add_agents([joe_biden, donald_trump, hillary_clinton])
@@ -393,9 +404,13 @@ for frame in task_center.step(send_to='hillary_clinton'):
print(frame)
# in 2nd step, allow only donald_trump to response the topic
-for frame in task_center.step(allower_roles='donald_trump'):
+for frame in task_center.step(allowed_roles='donald_trump'):
print(frame)
```
+*请注意*,在`frame`中只会显示来自不同agent的消息,并且格式为`<[role_name]>: [message stream]`。
+用户需要根据自己的业务需求,自行处理`step`方法输出的消息格式。
+
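+下面是消费这些消息的一个最小示例(其中对消息的拆分方式仅为示意,并非固定API):
+
+```python
+# 每个frame形如 '<role_name>: message chunk'
+for frame in task_center.step():
+    role, _, chunk = frame.partition(': ')
+    print(f'{role.strip("<>")}: {chunk}')
+```
+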
上述案例展示了如何在multi-agent任务中使用step方法中的send_to和allowed_roles来控制agent之间的通信。
在另一个情况下,在聊天机器人模式中,如果本步骤中包含user-agent,可以使用user_response让用户在这一步骤中进行输入,以取代LLM(大型语言模型)的输出。
@@ -423,9 +438,13 @@ for frame in task_center.step():
print(frame)
# in 2nd step, allow only user to response the topic, with user_response
-for frame in task_center.step(allower_roles='user', user_response='I dont agree with you about the landing project'):
- print(frame)
-assert frame == 'I dont agree with you about the landing project'
+result = ''
+for frame in task_center.step(allowed_roles='user', user_response='I dont agree with you about the landing project'):
+ result += frame
+ print(frame)
+
+# nothing is streamed back here because user_response replaces the llm output
+assert result == ''
```
可以看到,用户的响应将在这个步骤中被使用,以取代大型语言模型(LLM)的输出,因为名为`user`的agent 是一个user-agent。
diff --git a/modelscope_agent/schemas.py b/modelscope_agent/schemas.py
index 10a89543d..4d40559bc 100644
--- a/modelscope_agent/schemas.py
+++ b/modelscope_agent/schemas.py
@@ -1,5 +1,6 @@
-from typing import List, Set, Union
+from typing import List, Union
+from modelscope_agent.constants import DEFAULT_SEND_TO
from pydantic import BaseModel
@@ -10,7 +11,7 @@ class Message(BaseModel):
role: str = 'user' # user, assistant, system, tool
content: str = ''
sent_from: str = ''
- send_to: Union[str, Set[str]] = {'all'}
+ send_to: Union[str, List[str]] = DEFAULT_SEND_TO
class Document(BaseModel):
diff --git a/modelscope_agent/task_center.py b/modelscope_agent/task_center.py
index 2196b0f37..614b37efe 100644
--- a/modelscope_agent/task_center.py
+++ b/modelscope_agent/task_center.py
@@ -3,7 +3,8 @@
from modelscope_agent import create_component
from modelscope_agent.agent import Agent
from modelscope_agent.agents_registry import AgentRegistry
-from modelscope_agent.constants import DEFAULT_SEND_TO, USER_REQUIREMENT
+from modelscope_agent.constants import (AGENT_REGISTRY_NAME, DEFAULT_SEND_TO,
+ ENVIRONMENT_NAME, USER_REQUIREMENT)
from modelscope_agent.environment import Environment
from modelscope_agent.schemas import Message
from modelscope_agent.utils.logger import agent_logger as logger
@@ -11,16 +12,24 @@
class TaskCenter:
- def __init__(self, remote=False, **kwargs):
+ def __init__(self, remote=False, prefix_name=None, **kwargs):
if remote:
from modelscope_agent.multi_agents_utils.executors.ray import RayTaskExecutor
self.task_executor = RayTaskExecutor
else:
from modelscope_agent.multi_agents_utils.executors.local import LocalTaskExecutor
self.task_executor = LocalTaskExecutor
- self.env = create_component(Environment, 'env', remote)
- self.agent_registry = create_component(AgentRegistry, 'agent_center',
- remote)
+ # used to create the environment and agent registry with specific prefix
+ self.env = create_component(
+ cls=Environment,
+ name=ENVIRONMENT_NAME,
+ remote=remote,
+ prefix_name=prefix_name)
+ self.agent_registry = create_component(
+ cls=AgentRegistry,
+ name=AGENT_REGISTRY_NAME,
+ remote=remote,
+ prefix_name=prefix_name)
self.remote = remote
def add_agents(self, agents: List[Agent]):
@@ -133,6 +142,8 @@ def step(self,
if len(allowed_roles) == 0:
roles = self.task_executor.get_notified_roles(self.env)
else:
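+                # accept a single role name as shorthand for a one-element list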
+ if isinstance(allowed_roles, str):
+ allowed_roles = [allowed_roles]
roles = allowed_roles
if len(roles) == 0: