Skip to content

Commit

Permalink
Refactor ContextProcessorComponent to have one hook per phase.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 651805326
Change-Id: I0ddf3e5a7545554d03a6bc9c5626a29d376ddb91
  • Loading branch information
jagapiou authored and copybara-github committed Jul 12, 2024
1 parent daf5309 commit f1cd7b3
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 24 deletions.
7 changes: 4 additions & 3 deletions concordia/agents/entity_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,14 @@ def act(
) -> str:
self._phase = component_v2.Phase.PRE_ACT
contexts = self._parallel_call_('pre_act', action_spec)
self._context_processor.pre_act(types.MappingProxyType(contexts))
action_attempt = self._act_component.get_action_attempt(
contexts, action_spec
)

self._phase = component_v2.Phase.POST_ACT
contexts = self._parallel_call_('post_act', action_attempt)
self._context_processor.process(contexts)
self._context_processor.post_act(contexts)

self._phase = component_v2.Phase.UPDATE
self._parallel_call_('update')
Expand All @@ -142,11 +143,11 @@ def act(
def observe(self, observation: str) -> None:
self._phase = component_v2.Phase.PRE_OBSERVE
contexts = self._parallel_call_('pre_observe', observation)
self._context_processor.process(contexts)
self._context_processor.pre_observe(contexts)

self._phase = component_v2.Phase.POST_OBSERVE
contexts = self._parallel_call_('post_observe')
self._context_processor.process(contexts)
self._context_processor.post_observe(contexts)

self._phase = component_v2.Phase.UPDATE
self._parallel_call_('update')
Expand Down
9 changes: 0 additions & 9 deletions concordia/components/agent/v2/no_op_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@

from concordia.typing import component_v2

from typing_extensions import override


class NoOpContextProcessor(component_v2.ContextProcessorComponent):
"""A context processor component that does nothing."""

@override
def process(
self,
contexts: component_v2.ComponentContextMapping,
) -> None:
del contexts
40 changes: 28 additions & 12 deletions concordia/typing/component_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,20 +219,36 @@ def get_action_attempt(


class ContextProcessorComponent(BaseComponent, metaclass=abc.ABCMeta):
"""A component that processes context from components."""
"""A component that processes context from ContextComponents."""

@abc.abstractmethod
def process(
self,
contexts: ComponentContextMapping,
) -> None:
"""Processes the context from ContextComponents.
def pre_act(self, contexts: ComponentContextMapping) -> None:
"""Processes the pre_act contexts returned by the ContextComponents.
Args:
contexts: A mapping from ComponentName to ComponentContext.
"""
del contexts

This function will be called by the entity with the context from other
components. The component should process the context and possibly update its
internal state or access other components.
def post_act(self, contexts: ComponentContextMapping) -> None:
"""Processes the post_act contexts returned by the ContextComponents.
Args:
contexts: The context from ContextComponents.
contexts: A mapping from ComponentName to ComponentContext.
"""
raise NotImplementedError()
del contexts

def pre_observe(self, contexts: ComponentContextMapping) -> None:
"""Processes the pre_observe contexts returned by the ContextComponents.
Args:
contexts: A mapping from ComponentName to ComponentContext.
"""
del contexts

def post_observe(self, contexts: ComponentContextMapping) -> None:
"""Processes the post_observe contexts returned by the ContextComponents.
Args:
contexts: The context from other components.
"""
del contexts

0 comments on commit f1cd7b3

Please sign in to comment.