Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 7, 2024
1 parent 987ed6d commit 8c397b0
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import pytest
from typing import Text

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall


@pytest.fixture(scope="session")
def tracer_provider() -> TracerProvider:
Expand All @@ -21,3 +25,22 @@ def span_exporter(tracer_provider: TracerProvider) -> InMemorySpanExporter:
def previous_num_captured_spans(span_exporter: InMemorySpanExporter) -> int:
captured_spans = span_exporter.get_finished_spans() # type: ignore
return len(captured_spans)


class MockActionExecutor(ActionExecutor):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def run(self, action_call: ActionCall) -> None:
pass
53 changes: 53 additions & 0 deletions tests/tracing/instrumentation/test_action_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Sequence

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk import Tracker


@pytest.mark.asyncio
async def test_tracing_action_executor_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
) -> None:
component_class = MockActionExecutor

instrumentation.instrument(
tracer_provider,
action_executor_class=component_class,
)

mock_action_executor = component_class()
action_call = ActionCall(
{
"next_action": "check_balance",
"sender_id": "test",
"tracker": Tracker("test", {}, {}, [], False, None, {}, ""),
"version": "1.0.0",
"domain": {},
}
)
await mock_action_executor.run(action_call)

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]

assert captured_span.name == "MockActionExecutor.run"

expected_attributes = {
"next_action": "check_balance",
"sender_id": "test",
}
assert captured_span.attributes == expected_attributes

0 comments on commit 8c397b0

Please sign in to comment.