From cc398def9922f18325b5a6d521c5a44418b070f9 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Tue, 13 Feb 2024 14:59:34 +0100 Subject: [PATCH 1/4] implement TraceProvider component --- .../instrumentation/instrumentation.py | 5 +++- rasa_sdk/tracing/trace_provider.py | 23 +++++++++++++++++++ rasa_sdk/utils.py | 22 ++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 rasa_sdk/tracing/trace_provider.py diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index 9e84a4c5d..5adcaa18a 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -17,6 +17,7 @@ from rasa_sdk.executor import ActionExecutor from rasa_sdk.forms import ValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors +from rasa_sdk.tracing.trace_provider import TraceProvider # The `TypeVar` representing the return type for a function to be wrapped. S = TypeVar("S") @@ -132,13 +133,15 @@ def instrument( if action_executor_class is not None and not class_is_instrumented( action_executor_class ): + tracer = tracer_provider.get_tracer(action_executor_class.__module__) _instrument_method( - tracer_provider.get_tracer(action_executor_class.__module__), + tracer, action_executor_class, "run", attribute_extractors.extract_attrs_for_action_executor, ) mark_class_as_instrumented(action_executor_class) + TraceProvider().register_tracer(tracer) if validation_action_class is not None and not class_is_instrumented( validation_action_class diff --git a/rasa_sdk/tracing/trace_provider.py b/rasa_sdk/tracing/trace_provider.py new file mode 100644 index 000000000..c5a599bab --- /dev/null +++ b/rasa_sdk/tracing/trace_provider.py @@ -0,0 +1,23 @@ +from typing import Optional +from rasa_sdk.utils import Singleton +from opentelemetry.trace import Tracer + + +class TraceProvider(metaclass=Singleton): + """Represents a provider for tracer.""" + + tracer: Optional[Tracer] = None + + def register_tracer(self, tracer: Tracer) -> None: + """Register a tracer. + Args: + trace: The tracer to register. + """ + self.tracer = tracer + + def get_tracer(self) -> Optional[Tracer]: + """Get the tracer. + Returns: + The tracer. + """ + return self.tracer diff --git a/rasa_sdk/utils.py b/rasa_sdk/utils.py index 37a66ae27..a9fb35cc9 100644 --- a/rasa_sdk/utils.py +++ b/rasa_sdk/utils.py @@ -46,6 +46,28 @@ class Button(dict): pass +class Singleton(type): + """Singleton metaclass.""" + + _instances: Dict[Any, Any] = {} + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + """Call the class. + + Args: + *args: Arguments. + **kwargs: Keyword arguments. + """ + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + + return cls._instances[cls] + + def clear(cls) -> None: + """Clear the class.""" + cls._instances = {} + + def all_subclasses(cls: Any) -> List[Any]: """Returns all known (imported) subclasses of a class.""" return cls.__subclasses__() + [ From fc7c4764eb00309398f6444911f64ae337d8ea17 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Tue, 13 Feb 2024 15:12:36 +0100 Subject: [PATCH 2/4] add tests --- .../instrumentation/test_action_executor.py | 35 +++++++++++++++++-- tests/tracing/test_trace_provider.py | 19 ++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 tests/tracing/test_trace_provider.py diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index f3777f200..04006596b 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -1,13 +1,17 @@ -from typing import Any, Dict, Sequence, Text, Optional - import pytest + +from typing import Any, Dict, Sequence, Text, Optional +from unittest.mock import Mock +from pytest import MonkeyPatch from opentelemetry.sdk.trace import ReadableSpan, TracerProvider from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from opentelemetry import trace from rasa_sdk.tracing.instrumentation import instrumentation from tests.tracing.instrumentation.conftest import MockActionExecutor from rasa_sdk.types import ActionCall from rasa_sdk import Tracker +from rasa_sdk.tracing.trace_provider import TraceProvider @pytest.mark.parametrize( @@ -56,3 +60,30 @@ async def test_tracing_action_executor_run( assert captured_span.name == "MockActionExecutor.run" assert captured_span.attributes == expected + + +@pytest.mark.asyncio +async def test_instrument_action_executor_run_registers_tracer( + tracer_provider: TracerProvider, monkeypatch: MonkeyPatch +) -> None: + component_class = MockActionExecutor + + mock_tracer = trace.get_tracer(__name__) + + register_tracer_mock = Mock() + get_tracer_mock = Mock(return_value=mock_tracer) + + monkeypatch.setattr(TraceProvider, "register_tracer", register_tracer_mock()) + monkeypatch.setattr(TraceProvider, "get_tracer", get_tracer_mock) + + instrumentation.instrument( + tracer_provider, + action_executor_class=component_class, + ) + register_tracer_mock.assert_called_once() + + provider = TraceProvider() + tracer = provider.get_tracer() + + assert tracer is not None + assert tracer == mock_tracer diff --git a/tests/tracing/test_trace_provider.py b/tests/tracing/test_trace_provider.py new file mode 100644 index 000000000..8400f0a51 --- /dev/null +++ b/tests/tracing/test_trace_provider.py @@ -0,0 +1,19 @@ +from rasa_sdk.tracing.trace_provider import TraceProvider +from opentelemetry import trace + + +def test_anonymization_pipeline_provider_is_singleton() -> None: + trace_provider_1 = TraceProvider() + trace_provider_2 = TraceProvider() + + assert trace_provider_1 is trace_provider_2 + assert trace_provider_1.tracer is trace_provider_2.tracer + + +def test_trace_provider() -> None: + trace_provider = TraceProvider() + tracer = trace.get_tracer(__name__) + trace_provider.register_tracer(tracer) + + assert trace_provider.tracer == tracer + assert trace_provider.get_tracer() == tracer From 40275d15764c89ba7d5af638317d16ba770ee531 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Tue, 13 Feb 2024 15:15:00 +0100 Subject: [PATCH 3/4] add changelog entry --- changelog/1075.improvement.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/1075.improvement.md diff --git a/changelog/1075.improvement.md b/changelog/1075.improvement.md new file mode 100644 index 000000000..c254fe74c --- /dev/null +++ b/changelog/1075.improvement.md @@ -0,0 +1 @@ +Implement functionality that enables creating additional spans within custom actions. \ No newline at end of file From c28a929462f14b7d1c47cf444fd20fb385bcbee2 Mon Sep 17 00:00:00 2001 From: Tawakalt Date: Tue, 13 Feb 2024 17:37:29 +0100 Subject: [PATCH 4/4] implement PR feedback --- .../instrumentation/instrumentation.py | 4 ++-- .../{trace_provider.py => tracer_register.py} | 8 +++---- .../instrumentation/test_action_executor.py | 13 ++++++------ tests/tracing/test_trace_provider.py | 19 ----------------- tests/tracing/test_tracer_register.py | 21 +++++++++++++++++++ 5 files changed, 34 insertions(+), 31 deletions(-) rename rasa_sdk/tracing/{trace_provider.py => tracer_register.py} (67%) delete mode 100644 tests/tracing/test_trace_provider.py create mode 100644 tests/tracing/test_tracer_register.py diff --git a/rasa_sdk/tracing/instrumentation/instrumentation.py b/rasa_sdk/tracing/instrumentation/instrumentation.py index 5adcaa18a..f28fd2b40 100644 --- a/rasa_sdk/tracing/instrumentation/instrumentation.py +++ b/rasa_sdk/tracing/instrumentation/instrumentation.py @@ -17,7 +17,7 @@ from rasa_sdk.executor import ActionExecutor from rasa_sdk.forms import ValidationAction from rasa_sdk.tracing.instrumentation import attribute_extractors -from rasa_sdk.tracing.trace_provider import TraceProvider +from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister # The `TypeVar` representing the return type for a function to be wrapped. S = TypeVar("S") @@ -141,7 +141,7 @@ def instrument( attribute_extractors.extract_attrs_for_action_executor, ) mark_class_as_instrumented(action_executor_class) - TraceProvider().register_tracer(tracer) + ActionExecutorTracerRegister().register_tracer(tracer) if validation_action_class is not None and not class_is_instrumented( validation_action_class diff --git a/rasa_sdk/tracing/trace_provider.py b/rasa_sdk/tracing/tracer_register.py similarity index 67% rename from rasa_sdk/tracing/trace_provider.py rename to rasa_sdk/tracing/tracer_register.py index c5a599bab..2985d5e8e 100644 --- a/rasa_sdk/tracing/trace_provider.py +++ b/rasa_sdk/tracing/tracer_register.py @@ -3,20 +3,20 @@ from opentelemetry.trace import Tracer -class TraceProvider(metaclass=Singleton): - """Represents a provider for tracer.""" +class ActionExecutorTracerRegister(metaclass=Singleton): + """Represents a provider for ActionExecutor tracer.""" tracer: Optional[Tracer] = None def register_tracer(self, tracer: Tracer) -> None: - """Register a tracer. + """Register an ActionExecutor tracer. Args: trace: The tracer to register. """ self.tracer = tracer def get_tracer(self) -> Optional[Tracer]: - """Get the tracer. + """Get the ActionExecutor tracer. Returns: The tracer. """ diff --git a/tests/tracing/instrumentation/test_action_executor.py b/tests/tracing/instrumentation/test_action_executor.py index 04006596b..05ab9f44c 100644 --- a/tests/tracing/instrumentation/test_action_executor.py +++ b/tests/tracing/instrumentation/test_action_executor.py @@ -11,7 +11,7 @@ from tests.tracing.instrumentation.conftest import MockActionExecutor from rasa_sdk.types import ActionCall from rasa_sdk import Tracker -from rasa_sdk.tracing.trace_provider import TraceProvider +from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister @pytest.mark.parametrize( @@ -62,8 +62,7 @@ async def test_tracing_action_executor_run( assert captured_span.attributes == expected -@pytest.mark.asyncio -async def test_instrument_action_executor_run_registers_tracer( +def test_instrument_action_executor_run_registers_tracer( tracer_provider: TracerProvider, monkeypatch: MonkeyPatch ) -> None: component_class = MockActionExecutor @@ -73,8 +72,10 @@ async def test_instrument_action_executor_run_registers_tracer( register_tracer_mock = Mock() get_tracer_mock = Mock(return_value=mock_tracer) - monkeypatch.setattr(TraceProvider, "register_tracer", register_tracer_mock()) - monkeypatch.setattr(TraceProvider, "get_tracer", get_tracer_mock) + monkeypatch.setattr( + ActionExecutorTracerRegister, "register_tracer", register_tracer_mock() + ) + monkeypatch.setattr(ActionExecutorTracerRegister, "get_tracer", get_tracer_mock) instrumentation.instrument( tracer_provider, @@ -82,7 +83,7 @@ async def test_instrument_action_executor_run_registers_tracer( ) register_tracer_mock.assert_called_once() - provider = TraceProvider() + provider = ActionExecutorTracerRegister() tracer = provider.get_tracer() assert tracer is not None diff --git a/tests/tracing/test_trace_provider.py b/tests/tracing/test_trace_provider.py deleted file mode 100644 index 8400f0a51..000000000 --- a/tests/tracing/test_trace_provider.py +++ /dev/null @@ -1,19 +0,0 @@ -from rasa_sdk.tracing.trace_provider import TraceProvider -from opentelemetry import trace - - -def test_anonymization_pipeline_provider_is_singleton() -> None: - trace_provider_1 = TraceProvider() - trace_provider_2 = TraceProvider() - - assert trace_provider_1 is trace_provider_2 - assert trace_provider_1.tracer is trace_provider_2.tracer - - -def test_trace_provider() -> None: - trace_provider = TraceProvider() - tracer = trace.get_tracer(__name__) - trace_provider.register_tracer(tracer) - - assert trace_provider.tracer == tracer - assert trace_provider.get_tracer() == tracer diff --git a/tests/tracing/test_tracer_register.py b/tests/tracing/test_tracer_register.py new file mode 100644 index 000000000..b547e144c --- /dev/null +++ b/tests/tracing/test_tracer_register.py @@ -0,0 +1,21 @@ +from rasa_sdk.tracing.tracer_register import ActionExecutorTracerRegister +from opentelemetry import trace + + +def test_tracer_register_is_singleton() -> None: + tracer_register_1 = ActionExecutorTracerRegister() + tracer_register_2 = ActionExecutorTracerRegister() + + assert tracer_register_1 is tracer_register_2 + assert tracer_register_1.tracer is tracer_register_2.tracer + + +def test_trace_register() -> None: + tracer_register = ActionExecutorTracerRegister() + assert tracer_register.get_tracer() is None + + tracer = trace.get_tracer(__name__) + tracer_register.register_tracer(tracer) + + assert tracer_register.tracer == tracer + assert tracer_register.get_tracer() == tracer