Skip to content

Commit

Permalink
Support termination condition combination. Closes #4325
Browse files Browse the repository at this point in the history
  • Loading branch information
victordibia committed Nov 23, 2024
1 parent 6b054cf commit c5345c8
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination, StopMessageTermination
import yaml
import logging
from packaging import version
from ..utils.utils import Version

from ..datamodel import (
TeamConfig, AgentConfig, ModelConfig, ToolConfig,
Expand Down Expand Up @@ -174,18 +174,44 @@ def _dict_to_config(self, config_dict: dict) -> ComponentConfig:
async def load_termination(self, config: TerminationConfig) -> TerminationComponent:
"""Create termination condition instance from configuration."""
try:
if config.termination_type == TerminationTypes.MAX_MESSAGES:
if config.termination_type == TerminationTypes.COMBINATION:
if not config.conditions or len(config.conditions) < 2:
raise ValueError(
"Combination termination requires at least 2 conditions")
if not config.operator:
raise ValueError(
"Combination termination requires an operator (and/or)")

# Load first two conditions
conditions = [await self.load_termination(cond) for cond in config.conditions[:2]]
result = conditions[0] & conditions[1] if config.operator == "and" else conditions[0] | conditions[1]

# Process remaining conditions if any
for condition in config.conditions[2:]:
next_condition = await self.load_termination(condition)
result = result & next_condition if config.operator == "and" else result | next_condition

return result

elif config.termination_type == TerminationTypes.MAX_MESSAGES:
if config.max_messages is None:
raise ValueError(
"max_messages parameter required for MaxMessageTermination")
return MaxMessageTermination(max_messages=config.max_messages)

elif config.termination_type == TerminationTypes.STOP_MESSAGE:
return StopMessageTermination()

elif config.termination_type == TerminationTypes.TEXT_MENTION:
if not config.text:
raise ValueError(
"text parameter required for TextMentionTermination")
return TextMentionTermination(text=config.text)

else:
raise ValueError(
f"Unsupported termination type: {config.termination_type}")

except Exception as e:
logger.error(f"Failed to create termination condition: {str(e)}")
raise ValueError(
Expand Down Expand Up @@ -367,9 +393,11 @@ def _func_from_string(self, content: str) -> callable:
def _is_version_supported(self, component_type: ComponentType, ver: str) -> bool:
"""Check if version is supported for component type."""
try:
v = version.parse(ver)
return ver in self.SUPPORTED_VERSIONS[component_type]
except version.InvalidVersion:
version = Version(ver)
supported = [Version(v)
for v in self.SUPPORTED_VERSIONS[component_type]]
return any(version == v for v in supported)
except ValueError:
return False

async def cleanup(self) -> None:
Expand Down
95 changes: 95 additions & 0 deletions python/packages/autogen-studio/tests/test_component_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,101 @@ async def test_load_termination(component_factory: ComponentFactory):
assert isinstance(termination, TextMentionTermination)
assert termination._text == "DONE"

# Test AND combination
and_combo_config = TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
operator="and",
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentType.TERMINATION,
version="1.0.0"
),
TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentType.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentType.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(and_combo_config)
assert termination is not None

# Test OR combination
or_combo_config = TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
operator="or",
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentType.TERMINATION,
version="1.0.0"
),
TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentType.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentType.TERMINATION,
version="1.0.0"
)
termination = await component_factory.load_termination(or_combo_config)
assert termination is not None

# Test invalid combinations
with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
conditions=[], # Empty conditions
component_type=ComponentType.TERMINATION,
version="1.0.0"
))

with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
operator="invalid", # type: ignore
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentType.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentType.TERMINATION,
version="1.0.0"
))

# Test missing operator
with pytest.raises(ValueError):
await component_factory.load_termination(TerminationConfig(
termination_type=TerminationTypes.COMBINATION,
conditions=[
TerminationConfig(
termination_type=TerminationTypes.MAX_MESSAGES,
max_messages=5,
component_type=ComponentType.TERMINATION,
version="1.0.0"
),
TerminationConfig(
termination_type=TerminationTypes.TEXT_MENTION,
text="DONE",
component_type=ComponentType.TERMINATION,
version="1.0.0"
)
],
component_type=ComponentType.TERMINATION,
version="1.0.0"
))


@pytest.mark.asyncio
async def test_load_team(component_factory: ComponentFactory, sample_team_config: TeamConfig, sample_model_config: ModelConfig):
Expand Down

0 comments on commit c5345c8

Please sign in to comment.