Skip to content

Commit

Permalink
add get_raw_memory and get_all_memories_as_text functions on the memo…
Browse files Browse the repository at this point in the history
…ry component and use them in parochial_universalization_agent.

PiperOrigin-RevId: 693371638
Change-Id: I4ad81639effd7e6a18c6dc5f3b4cc2564f7b5ffc
  • Loading branch information
jzleibo authored and copybara-github committed Nov 5, 2024
1 parent 2d51735 commit 12f0e56
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 14 deletions.
12 changes: 12 additions & 0 deletions concordia/associative_memory/associative_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,15 @@ def set_num_to_retrieve_to_contextualize_importance(
importance.
"""
self._num_to_retrieve_to_contextualize_importance = num_to_retrieve

def get_all_memories_as_text(
self,
add_time: bool = True,
sort_by_time: bool = True,
) -> Sequence[str]:
"""Returns all memories in the memory bank as a sequence of strings."""
memories_data_frame = self.get_data_frame()
texts = self._pd_to_text(memories_data_frame,
add_time=add_time,
sort_by_time=sort_by_time)
return texts
21 changes: 21 additions & 0 deletions concordia/components/agent/memory_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from concordia.typing import entity_component
from concordia.typing import memory as memory_lib
import pandas as pd


DEFAULT_MEMORY_COMPONENT_NAME = '__memory__'

Expand Down Expand Up @@ -110,3 +112,22 @@ def update(
for mem in self._buffer:
self._memory.add(mem['text'], mem['metadata'])
self._buffer = []

def get_raw_memory(self) -> pd.DataFrame:
"""Returns the raw memory as a pandas dataframe."""
self._check_phase()
with self._lock:
return self._memory.get_data_frame()

def get_all_memories_as_text(
self,
add_time: bool = True,
sort_by_time: bool = True,
) -> Sequence[str]:
"""Returns all memories in the memory bank as a sequence of strings."""
self._check_phase()
with self._lock:
texts = self._memory.get_all_memories_as_text(
add_time=add_time,
sort_by_time=sort_by_time)
return texts
25 changes: 11 additions & 14 deletions concordia/factory/agent/parochial_universalization_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import json
import types

from absl import logging as absl_logging
from concordia.agents import entity_agent_with_logging
from concordia.associative_memory import associative_memory
from concordia.associative_memory import formative_memories
Expand Down Expand Up @@ -59,13 +60,8 @@ def _get_all_memories(
sort_by_time: whether to sort by time
constant_score: assign this score value to each memory
"""
# pylint: disable=protected-access
raw_memory = memory_component_._memory._memory
memories_data_frame = raw_memory.get_data_frame()
texts = raw_memory._pd_to_text(memories_data_frame,
add_time=add_time,
sort_by_time=sort_by_time)
# pylint: enable=protected-access
texts = memory_component_.get_all_memories_as_text(add_time=add_time,
sort_by_time=sort_by_time)
return [memory_lib.MemoryResult(text=t, score=constant_score) for t in texts]


Expand All @@ -77,13 +73,14 @@ def _get_earliest_timepoint(
Args:
memory_component_: The memory component to retrieve memories from.
"""
# pylint: disable=protected-access
raw_memory = memory_component_._memory._memory
memories_data_frame = raw_memory.get_data_frame()
sorted_memories_data_frame = memories_data_frame.sort_values(
'time', ascending=True)
# pylint: enable=protected-access
return sorted_memories_data_frame.iloc[0].time
memories_data_frame = memory_component_.get_raw_memory()
if not memories_data_frame.empty:
sorted_memories_data_frame = memories_data_frame.sort_values(
'time', ascending=True)
return sorted_memories_data_frame['time'][0]
else:
absl_logging.warn('No memories found in memory bank.')
return datetime.datetime.now()


class AvailableOptionsPerception(
Expand Down
14 changes: 14 additions & 0 deletions concordia/memory_bank/legacy_associative_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from concordia.associative_memory import associative_memory
from concordia.typing import entity_component
from concordia.typing import memory as memory_lib
import pandas as pd
from typing_extensions import override


Expand Down Expand Up @@ -105,6 +106,19 @@ def __init__(self, memory: associative_memory.AssociativeMemory):
def add(self, text: str, metadata: Mapping[str, Any]) -> None:
self._memory.add(text, **metadata)

def get_data_frame(self) -> pd.DataFrame:
"""Returns the memory bank as a pandas dataframe."""
return self._memory.get_data_frame()

def get_all_memories_as_text(
self,
add_time: bool = True,
sort_by_time: bool = True,
) -> Sequence[str]:
"""Returns all memories in the memory bank as a sequence of strings."""
return self._memory.get_all_memories_as_text(add_time=add_time,
sort_by_time=sort_by_time)

@override
def get_state(self) -> entity_component.ComponentState:
"""Returns the state of the memory bank.
Expand Down
20 changes: 20 additions & 0 deletions concordia/typing/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from collections.abc import Mapping, Sequence
import dataclasses
from typing import Any, Protocol

from concordia.typing import entity_component
import pandas as pd


class MemoryScorer(Protocol):
Expand Down Expand Up @@ -75,6 +77,24 @@ def extend(self, texts: Sequence[str], metadata: Mapping[str, Any]) -> None:
for text in texts:
self.add(text, metadata)

@abc.abstractmethod
def get_data_frame(self) -> pd.DataFrame:
"""Returns the memory bank as a pandas dataframe."""
raise NotImplementedError()

@abc.abstractmethod
def get_all_memories_as_text(
self,
add_time: bool,
sort_by_time: bool) -> Sequence[str]:
"""Returns the memory bank as a sequence of strings.
Args:
add_time: Whether to add the time stamp to the memory.
sort_by_time: Whether to sort the memories by time.
"""
raise NotImplementedError()

@abc.abstractmethod
def get_state(self) -> entity_component.ComponentState:
"""Returns the state of the memory bank.
Expand Down

0 comments on commit 12f0e56

Please sign in to comment.