-
Notifications
You must be signed in to change notification settings - Fork 144
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Component base class for code refactoring (#983)
* Add Component base class Signed-off-by: lvliang-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add controller class Signed-off-by: lvliang-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add ut Signed-off-by: lvliang-intel <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: lvliang-intel <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
ddd372d
commit c409ef9
Showing
4 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
|
||
class OpeaComponent(ABC): | ||
"""The OpeaComponent class serves as the base class for all components in the GenAIComps. | ||
It provides a unified interface and foundational attributes that every derived component inherits and extends. | ||
Attributes: | ||
name (str): The name of the component. | ||
type (str): The type of the component (e.g., 'retriever', 'embedding', 'reranking', 'llm', etc.). | ||
description (str): A brief description of the component's functionality. | ||
config (dict): A dictionary containing configuration parameters for the component. | ||
""" | ||
|
||
def __init__(self, name: str, type: str, description: str, config: dict = None): | ||
"""Initializes an OpeaComponent instance with the provided attributes. | ||
Args: | ||
name (str): The name of the component. | ||
type (str): The type of the component. | ||
description (str): A brief description of the component. | ||
config (dict, optional): Configuration parameters for the component. Defaults to an empty dictionary. | ||
""" | ||
self.name = name | ||
self.type = type | ||
self.description = description | ||
self.config = config if config is not None else {} | ||
|
||
def get_meta(self) -> dict: | ||
"""Retrieves metadata about the component, including its name, type, description, and configuration. | ||
Returns: | ||
dict: A dictionary containing the component's metadata. | ||
""" | ||
return { | ||
"name": self.name, | ||
"type": self.type, | ||
"description": self.description, | ||
"config": self.config, | ||
} | ||
|
||
def update_config(self, key: str, value): | ||
"""Updates a configuration parameter for the component. | ||
Args: | ||
key (str): The configuration parameter's key. | ||
value: The new value for the configuration parameter. | ||
""" | ||
self.config[key] = value | ||
|
||
@abstractmethod | ||
def check_health(self) -> bool: | ||
"""Checks the health of the component. | ||
Returns: | ||
bool: True if the component is healthy, False otherwise. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def invoke(self, *args, **kwargs): | ||
"""Invoke service accessing using the component. | ||
Args: | ||
*args: Positional arguments. | ||
**kwargs: Keyword arguments. | ||
Returns: | ||
Any: The result of the service accessing. | ||
""" | ||
pass | ||
|
||
def __repr__(self): | ||
"""Provides a string representation of the component for debugging and logging purposes. | ||
Returns: | ||
str: A string representation of the OpeaComponent instance. | ||
""" | ||
return f"OpeaComponent(name={self.name}, type={self.type}, description={self.description})" | ||
|
||
|
||
class OpeaComponentController(ABC): | ||
"""The OpeaComponentController class serves as the base class for managing and orchestrating multiple | ||
instances of components of the same type. It provides a unified interface for routing tasks, | ||
registering components, and dynamically discovering available components. | ||
Attributes: | ||
components (dict): A dictionary to store registered components by their unique identifiers. | ||
""" | ||
|
||
def __init__(self): | ||
"""Initializes the OpeaComponentController instance with an empty component registry.""" | ||
self.components = {} | ||
self.active_component = None | ||
|
||
def register(self, component): | ||
"""Registers an OpeaComponent instance to the controller. | ||
Args: | ||
component (OpeaComponent): An instance of a subclass of OpeaComponent to be managed. | ||
Raises: | ||
ValueError: If the component is already registered. | ||
""" | ||
if component.name in self.components: | ||
raise ValueError(f"Component '{component.name}' is already registered.") | ||
self.components[component.name] = component | ||
|
||
def discover_and_activate(self): | ||
"""Discovers healthy components and activates one. | ||
If multiple components are healthy, it prioritizes the first registered component. | ||
""" | ||
for component in self.components.values(): | ||
if component.check_health(): | ||
self.active_component = component | ||
print(f"Activated component: {component.name}") | ||
return | ||
raise RuntimeError("No healthy components available.") | ||
|
||
def invoke(self, *args, **kwargs): | ||
"""Invokes service accessing using the active component. | ||
Args: | ||
*args: Positional arguments. | ||
**kwargs: Keyword arguments. | ||
Returns: | ||
Any: The result of the service accessing. | ||
Raises: | ||
RuntimeError: If no active component is set. | ||
""" | ||
if not self.active_component: | ||
raise RuntimeError("No active component. Call 'discover_and_activate' first.") | ||
return self.active_component.invoke(*args, **kwargs) | ||
|
||
def list_components(self): | ||
"""Lists all registered components. | ||
Returns: | ||
list: A list of component names that are currently registered. | ||
""" | ||
return self.components.keys() | ||
|
||
def __repr__(self): | ||
"""Provides a string representation of the controller and its registered components. | ||
Returns: | ||
str: A string representation of the OpeaComponentController instance. | ||
""" | ||
return f"OpeaComponentController(registered_components={self.list_components()})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import unittest | ||
from unittest.mock import MagicMock | ||
|
||
from comps import OpeaComponent, OpeaComponentController | ||
|
||
|
||
class TestOpeaComponent(unittest.TestCase): | ||
class MockOpeaComponent(OpeaComponent): | ||
def check_health(self) -> bool: | ||
return True | ||
|
||
def invoke(self, *args, **kwargs): | ||
return "Service accessed" | ||
|
||
def test_initialization(self): | ||
component = self.MockOpeaComponent("TestComponent", "embedding", "Test description") | ||
|
||
self.assertEqual(component.name, "TestComponent") | ||
self.assertEqual(component.type, "embedding") | ||
self.assertEqual(component.description, "Test description") | ||
self.assertEqual(component.config, {}) | ||
|
||
def test_get_meta(self): | ||
component = self.MockOpeaComponent("TestComponent", "embedding", "Test description", {"key": "value"}) | ||
meta = component.get_meta() | ||
|
||
self.assertEqual(meta["name"], "TestComponent") | ||
self.assertEqual(meta["type"], "embedding") | ||
self.assertEqual(meta["description"], "Test description") | ||
self.assertEqual(meta["config"], {"key": "value"}) | ||
|
||
def test_update_config(self): | ||
component = self.MockOpeaComponent("TestComponent", "embedding", "Test description") | ||
component.update_config("key", "new_value") | ||
|
||
self.assertEqual(component.config["key"], "new_value") | ||
|
||
|
||
class TestOpeaComponentController(unittest.TestCase): | ||
def test_register_component(self): | ||
controller = OpeaComponentController() | ||
component = MagicMock() | ||
component.name = "TestComponent" | ||
controller.register(component) | ||
|
||
self.assertIn("TestComponent", controller.components) | ||
|
||
with self.assertRaises(ValueError): | ||
controller.register(component) | ||
|
||
def test_discover_and_activate(self): | ||
controller = OpeaComponentController() | ||
|
||
# Mock a healthy component | ||
component1 = MagicMock() | ||
component1.name = "Component1" | ||
component1.check_health.return_value = True | ||
|
||
# Register and activate the healthy component | ||
controller.register(component1) | ||
controller.discover_and_activate() | ||
|
||
# Ensure the component is activated | ||
self.assertEqual(controller.active_component, component1) | ||
|
||
# Add another component that is unhealthy | ||
component2 = MagicMock() | ||
component2.name = "Component2" | ||
component2.check_health.return_value = False | ||
controller.register(component2) | ||
|
||
# Call discover_and_activate again; the active component should remain the same | ||
controller.discover_and_activate() | ||
self.assertEqual(controller.active_component, component1) | ||
|
||
def test_invoke_no_active_component(self): | ||
controller = OpeaComponentController() | ||
with self.assertRaises(RuntimeError): | ||
controller.invoke("arg1", key="value") | ||
|
||
def test_invoke_with_active_component(self): | ||
controller = OpeaComponentController() | ||
|
||
# Mock a component | ||
component = MagicMock() | ||
component.name = "TestComponent" | ||
component.check_health.return_value = True | ||
component.invoke = MagicMock(return_value="Service accessed") | ||
|
||
# Register and activate the component | ||
controller.register(component) | ||
controller.discover_and_activate() | ||
|
||
# Invoke using the active component | ||
result = controller.invoke("arg1", key="value") | ||
|
||
# Assert the result and method call | ||
self.assertEqual(result, "Service accessed") | ||
component.invoke.assert_called_with("arg1", key="value") | ||
|
||
def test_discover_then_invoke(self): | ||
"""Ensures that `discover_and_activate` and `invoke` work correctly when called sequentially.""" | ||
controller = OpeaComponentController() | ||
|
||
# Mock a healthy component | ||
component1 = MagicMock() | ||
component1.name = "Component1" | ||
component1.check_health.return_value = True | ||
component1.invoke = MagicMock(return_value="Result from Component1") | ||
|
||
# Register the component | ||
controller.register(component1) | ||
|
||
# Discover and activate | ||
controller.discover_and_activate() | ||
|
||
# Ensure the component is activated | ||
self.assertEqual(controller.active_component, component1) | ||
|
||
# Call invoke separately | ||
result = controller.invoke("test_input") | ||
self.assertEqual(result, "Result from Component1") | ||
component1.invoke.assert_called_once_with("test_input") | ||
|
||
def test_list_components(self): | ||
controller = OpeaComponentController() | ||
|
||
# Mock components | ||
component1 = MagicMock() | ||
component1.name = "Component1" | ||
component2 = MagicMock() | ||
component2.name = "Component2" | ||
|
||
# Register components | ||
controller.register(component1) | ||
controller.register(component2) | ||
|
||
# Assert the components list | ||
components_list = controller.list_components() | ||
self.assertIn("Component1", components_list) | ||
self.assertIn("Component2", components_list) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |