Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-2102] Instrument ValidationAction.run #1074

Merged
merged 6 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1074.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Instrument `ValidationAction.run` method and extract attributes `class_name`, `sender_id`, `action_name` and `slots_to_validate`.
3 changes: 2 additions & 1 deletion rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor

from rasa_sdk.forms import ValidationAction

TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")

Expand All @@ -38,6 +38,7 @@ def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
instrumentation.instrument(
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
validation_action_class=ValidationAction,
)


Expand Down
32 changes: 30 additions & 2 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json

from typing import Any, Dict, Text
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.types import ActionCall, DomainDict
from rasa_sdk import Tracker


# This file contains all attribute extractors for tracing instrumentation.
Expand Down Expand Up @@ -28,3 +32,27 @@ def extract_attrs_for_action_executor(
attributes["action_name"] = action_name

return attributes


def extract_attrs_for_validation_action(
self: ValidationAction,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> Dict[Text, Any]:
"""Extract the attributes for `ValidationAction.run`.

:param self: The `ValidationAction` on which `run` is called.
:param dispatcher: The `CollectingDispatcher` argument.
:param tracker: The `Tracker` argument.
:param domain: The `DomainDict` argument.
:return: A dictionary containing the attributes.
"""
slots_to_validate = tracker.slots_to_validate().keys()

return {
"class_name": self.__class__.__name__,
"sender_id": tracker.sender_id,
"slots_to_validate": json.dumps(list(slots_to_validate)),
"action_name": self.name(),
}
24 changes: 21 additions & 3 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.forms import ValidationAction
from rasa_sdk.tracing.instrumentation import attribute_extractors

# The `TypeVar` representing the return type for a function to be wrapped.
Expand Down Expand Up @@ -70,9 +71,11 @@ async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
if issubclass(self.__class__, ValidationAction):
span_name = f"ValidationAction.{self.__class__.__name__}.{fn.__name__}"
ancalita marked this conversation as resolved.
Show resolved Hide resolved
else:
span_name = f"{self.__class__.__name__}.{fn.__name__}"
with tracer.start_as_current_span(span_name, attributes=attrs):
return await fn(self, *args, **kwargs)

return async_wrapper
Expand Down Expand Up @@ -109,18 +112,22 @@ def wrapper(self: T, *args: Any, **kwargs: Any) -> S:


ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)
ValidationActionType = TypeVar("ValidationActionType", bound=ValidationAction)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
validation_action_class: Optional[Type[ValidationActionType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.

:param tracer_provider: The `TracerProvider` to be used for configuring tracing
on the substituted methods.
:param action_executor_class: The `ActionExecutor` to be instrumented. If `None`
is given, no `ActionExecutor` will be instrumented.
:param validation_action_class: The `ValidationAction` to be instrumented. If `None`
is given, no `ValidationAction` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
Expand All @@ -133,6 +140,17 @@ def instrument(
)
mark_class_as_instrumented(action_executor_class)

if validation_action_class is not None and not class_is_instrumented(
validation_action_class
):
_instrument_method(
tracer_provider.get_tracer(validation_action_class.__module__),
validation_action_class,
"run",
attribute_extractors.extract_attrs_for_validation_action,
)
mark_class_as_instrumented(validation_action_class)


def _instrument_method(
tracer: Tracer,
Expand Down
13 changes: 12 additions & 1 deletion tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ def test_server_health_returns_200():
def test_server_list_actions_returns_200():
request, response = app.test_client.get("/actions")
assert response.status == 200
assert len(response.json) == 3
assert len(response.json) == 4
ancalita marked this conversation as resolved.
Show resolved Hide resolved

# ENSURE TO UPDATE AS MORE ACTIONS ARE ADDED IN OTHER TESTS
expected = [
# defined in tests/test_actions.py
{"name": "custom_async_action"},
{"name": "custom_action"},
{"name": "custom_action_exception"},
ancalita marked this conversation as resolved.
Show resolved Hide resolved
# defined in tests/tracing/instrumentation/conftest.py
{"name": "mock_validation_action"},
]
assert response.json == expected


def test_server_webhook_unknown_action_returns_404():
Expand Down
33 changes: 31 additions & 2 deletions tests/tracing/instrumentation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall
from rasa_sdk.executor import ActionExecutor, CollectingDispatcher
from rasa_sdk.forms import ValidationAction
from rasa_sdk.types import ActionCall, DomainDict
from rasa_sdk import Tracker


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -44,3 +46,30 @@ def fail_if_undefined(self, method_name: Text) -> None:

async def run(self, action_call: ActionCall) -> None:
pass


class MockValidationAction(ValidationAction):
def __init__(self) -> None:
self.fail_if_undefined("run")

def fail_if_undefined(self, method_name: Text) -> None:
if not (
hasattr(self.__class__.__base__, method_name)
and callable(getattr(self.__class__.__base__, method_name))
):
pytest.fail(
f"method '{method_name}' not found in {self.__class__.__base__}. "
f"This likely means the method was renamed, which means the "
f"instrumentation needs to be adapted!"
)

async def run(
self,
dispatcher: "CollectingDispatcher",
tracker: "Tracker",
domain: "DomainDict",
) -> None:
pass

def name(self) -> Text:
return "mock_validation_action"
63 changes: 63 additions & 0 deletions tests/tracing/instrumentation/test_validation_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import List, Sequence

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from rasa_sdk.tracing.instrumentation import instrumentation
from tests.tracing.instrumentation.conftest import MockValidationAction
from rasa_sdk import Tracker
from rasa_sdk.executor import CollectingDispatcher
from rasa_sdk.events import SlotSet, EventType


@pytest.mark.parametrize(
"events, expected_slots_to_validate",
[
([], "[]"),
(
[SlotSet("name", "Tom"), SlotSet("address", "Berlin")],
'["name", "address"]',
),
],
)
@pytest.mark.asyncio
async def test_tracing_action_executor_run(
tracer_provider: TracerProvider,
span_exporter: InMemorySpanExporter,
previous_num_captured_spans: int,
events: List[EventType],
expected_slots_to_validate: str,
) -> None:
component_class = MockValidationAction

instrumentation.instrument(
tracer_provider,
validation_action_class=component_class,
)

mock_validation_action = component_class()
dispatcher = CollectingDispatcher()
tracker = Tracker.from_dict({"sender_id": "test", "events": events})

await mock_validation_action.run(dispatcher, tracker, {})

captured_spans: Sequence[
ReadableSpan
] = span_exporter.get_finished_spans() # type: ignore

num_captured_spans = len(captured_spans) - previous_num_captured_spans
assert num_captured_spans == 1

captured_span = captured_spans[-1]

assert captured_span.name == "ValidationAction.MockValidationAction.run"

expected_attributes = {
"class_name": component_class.__name__,
"sender_id": "test",
"slots_to_validate": expected_slots_to_validate,
"action_name": "mock_validation_action",
}

assert captured_span.attributes == expected_attributes
Loading