Skip to content

Commit

Permalink
Update retrieval agent
Browse files Browse the repository at this point in the history
  • Loading branch information
w5688414 committed Jan 10, 2024
1 parent 3735fce commit e463c18
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
12 changes: 8 additions & 4 deletions erniebot-agent/src/erniebot_agent/agents/retrieval_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ async def _run(self, prompt: str, files: Optional[Sequence[File]] = None) -> Age
steps_input = HumanMessage(
content=self.query_transform.format(query=prompt, documents=few_shots)
)
steps_taken.append(AgentStep(info={'query':prompt, 'name': "few shot retriever"}, result=few_shots))
steps_taken.append(
AgentStep(info={"query": prompt, "name": "few shot retriever"}, result=few_shots)
)
elif self.context_retriever:
res = self.context_retriever.search(prompt, 3)

context = [item["content"] for item in res]
steps_input = HumanMessage(
content=self.context_planning.format(query=prompt, context="\n".join(context))
)
steps_taken.append(AgentStep(info={'query':prompt, 'name': "context retriever"}, result=res))
steps_taken.append(AgentStep(info={"query": prompt, "name": "context retriever"}, result=res))
else:
steps_input = HumanMessage(content=self.query_transform.format(query=prompt))
# Query planning
Expand Down Expand Up @@ -167,15 +169,17 @@ async def execute(self, sub_queries, steps_taken: List[AgentStep]):
compressed_data["content"] = output_message.content
retrieval_results.append(compressed_data)
steps_taken.append(
AgentStep(info={'query':query, 'name': f"sub query compressor {idx}"}, result=compressed_data)
AgentStep(
info={"query": query, "name": f"sub query compressor {idx}"}, result=compressed_data
)
)
else:
duplicates = set()
for idx, query in enumerate(sub_queries):
documents = await self.knowledge_base(query, top_k=self.top_k, filters=None)
docs = [item for item in documents["documents"]]
steps_taken.append(
AgentStep(info={'query':query, 'name': f"sub query results {idx}"}, result=documents)
AgentStep(info={"query": query, "name": f"sub query results {idx}"}, result=documents)
)
for doc in docs:
if doc["content"] not in duplicates:
Expand Down
1 change: 0 additions & 1 deletion erniebot-agent/src/erniebot_agent/agents/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ class AgentStep(Generic[_IT, _RT]):
result: _RT



@dataclass
class AgentStepWithFiles(AgentStep[_IT, _RT]):
"""A step taken by an agent involving file input and output."""
Expand Down
18 changes: 9 additions & 9 deletions erniebot-agent/tests/unit_tests/agents/test_retrieval_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ async def test_retrieval_agent_run_few_shot():
"检索语句: Hello, world!\n请根据以上检索结果回答检索语句的问题"
)
assert response.chat_history[1].content == "Text response"
assert response.steps[0].info == {'query': 'Hello, world!', 'name': 'few shot retriever'}
assert response.steps[1].info == {'query': '具体子问题1', 'name': 'sub query results 0'}
assert response.steps[2].info== {'query': '具体子问题2', 'name': 'sub query results 1'}
assert response.steps[0].info == {"query": "Hello, world!", "name": "few shot retriever"}
assert response.steps[1].info == {"query": "具体子问题1", "name": "sub query results 0"}
assert response.steps[2].info == {"query": "具体子问题2", "name": "sub query results 1"}


@pytest.mark.asyncio
Expand Down Expand Up @@ -126,10 +126,10 @@ async def test_retrieval_agent_run_context_planning():
"检索语句: Hello, world!\n请根据以上检索结果回答检索语句的问题"
)
assert response.chat_history[1].content == "Text response"
assert response.steps[0].info == {'query': 'Hello, world!', 'name': 'context retriever'}
assert response.steps[1].info == {'query': '具体子问题1', 'name': 'sub query results 0'}
assert response.steps[2].info == {'query': '具体子问题2', 'name': 'sub query results 1'}

assert response.steps[0].info == {"query": "Hello, world!", "name": "context retriever"}
assert response.steps[1].info == {"query": "具体子问题1", "name": "sub query results 0"}
assert response.steps[2].info == {"query": "具体子问题2", "name": "sub query results 1"}


@pytest.mark.asyncio
Expand Down Expand Up @@ -173,5 +173,5 @@ async def test_retrieval_agent_run_compressor():
"第2个段落: Sub query compress 2\n\n检索语句: Hello, world!\n请根据以上检索结果回答检索语句的问题"
)
assert response.chat_history[1].content == "Text response"
assert response.steps[0].info == {'query': '具体子问题1', 'name': 'sub query compressor 0'}
assert response.steps[1].info == {'query': '具体子问题2', 'name': 'sub query compressor 1'}
assert response.steps[0].info == {"query": "具体子问题1", "name": "sub query compressor 0"}
assert response.steps[1].info == {"query": "具体子问题2", "name": "sub query compressor 1"}

0 comments on commit e463c18

Please sign in to comment.