diff --git a/changelog/1075.improvement.md b/changelog/1075.improvement.md new file mode 100644 index 000000000..c254fe74c --- /dev/null +++ b/changelog/1075.improvement.md @@ -0,0 +1 @@ +Implement functionality that enables creating additional spans within custom actions. \ No newline at end of file diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index 9e84a4c5d..f28fd2b40 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -17,6 +17,7 @@ from rasa_sdk.executor import ActionExecutor from rasa_sdk.forms import ValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors +from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister # The `TypeVar` representing the return type for a function to be wrapped. S = TypeVar("S") @@ -132,13 +133,15 @@ def instrument( if action_executor_class is not None and not class_is_instrumented( action_executor_class ): + tracer = tracer_provider.get_tracer(action_executor_class.__module__) _instrument_method( - tracer_provider.get_tracer(action_executor_class.__module__), + tracer, action_executor_class, "run", attribute_extractors.extract_attrs_for_action_executor, ) mark_class_as_instrumented(action_executor_class) + ActionExecutorTracerRegister().register_tracer(tracer) if validation_action_class is not None and not class_is_instrumented( validation_action_class diff --git a/rasa_sdk/tracing/tracer_register.py b/rasa_sdk/tracing/tracer_register.py new file mode 100644 index 000000000..2985d5e8e --- /dev/null +++ b/rasa_sdk/tracing/tracer_register.py @@ -0,0 +1,23 @@ +from typing import Optional +from rasa_sdk.utils import Singleton +from opentelemetry.trace import Tracer + + +class ActionExecutorTracerRegister(metaclass=Singleton): + """Represents a provider for ActionExecutor tracer.""" + + tracer: Optional[Tracer] = None + + def register_tracer(self, tracer: Tracer) -> None: + """Register an ActionExecutor tracer. + Args: + trace: The tracer to register. + """ + self.tracer = tracer + + def get_tracer(self) -> Optional[Tracer]: + """Get the ActionExecutor tracer. + Returns: + The tracer. + """ + return self.tracer diff --git a/rasa_sdk/utils.py b/rasa_sdk/utils.py index 37a66ae27..a9fb35cc9 100644 --- a/rasa_sdk/utils.py +++ b/rasa_sdk/utils.py @@ -46,6 +46,28 @@ class Button(dict): pass +class Singleton(type): + """Singleton metaclass.""" + + _instances: Dict[Any, Any] = {} + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + """Call the class. + + Args: + *args: Arguments. + **kwargs: Keyword arguments. + """ + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + + return cls._instances[cls] + + def clear(cls) -> None: + """Clear the class.""" + cls._instances = {} + + def all_subclasses(cls: Any) -> List[Any]: """Returns all known (imported) subclasses of a class.""" return cls.__subclasses__() + [ diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index f3777f200..05ab9f44c 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -1,13 +1,17 @@ -from typing import Any, Dict, Sequence, Text, Optional - import pytest + +from typing import Any, Dict, Sequence, Text, Optional +from unittest.mock import Mock +from pytest import MonkeyPatch from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry import trace from rasa_sdk.tracing.instrumentation import instrumentation from tests.tracing.instrumentation.conftest import MockActionExecutor from rasa_sdk.types import ActionCall from rasa_sdk import Tracker +from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister @pytest.mark.parametrize( @@ -56,3 +60,31 @@ async def test_tracing_action_executor_run( assert captured_span.name == "MockActionExecutor.run" assert captured_span.attributes == expected + + +def test_instrument_action_executor_run_registers_tracer( + tracer_provider: TracerProvider, monkeypatch: MonkeyPatch +) -> None: + component_class = MockActionExecutor + + mock_tracer = trace.get_tracer(__name__) + + register_tracer_mock = Mock() + get_tracer_mock = Mock(return_value=mock_tracer) + + monkeypatch.setattr( + ActionExecutorTracerRegister, "register_tracer", register_tracer_mock() + ) + monkeypatch.setattr(ActionExecutorTracerRegister, "get_tracer", get_tracer_mock) + + instrumentation.instrument( + tracer_provider, + action_executor_class=component_class, + ) + register_tracer_mock.assert_called_once() + + provider = ActionExecutorTracerRegister() + tracer = provider.get_tracer() + + assert tracer is not None + assert tracer == mock_tracer diff --git a/tests/tracing/test_tracer_register.py b/tests/tracing/test_tracer_register.py new file mode 100644 index 000000000..b547e144c --- /dev/null +++ b/tests/tracing/test_tracer_register.py @@ -0,0 +1,21 @@ +from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister +from opentelemetry import trace + + +def test_tracer_register_is_singleton() -> None: + tracer_register_1 = ActionExecutorTracerRegister() + tracer_register_2 = ActionExecutorTracerRegister() + + assert tracer_register_1 is tracer_register_2 + assert tracer_register_1.tracer is tracer_register_2.tracer + + +def test_trace_register() -> None: + tracer_register = ActionExecutorTracerRegister() + assert tracer_register.get_tracer() is None + + tracer = trace.get_tracer(__name__) + tracer_register.register_tracer(tracer) + + assert tracer_register.tracer == tracer + assert tracer_register.get_tracer() == tracer