Skip to content

Commit

Permalink
instrument ActionExecutor.run method
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Feb 7, 2024
1 parent 8c0dea8 commit 987ed6d
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 0 deletions.
20 changes: 20 additions & 0 deletions rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config
from rasa_sdk.tracing.instrumentation import instrumentation
from rasa_sdk.executor import ActionExecutor


TRACING_SERVICE_NAME = os.environ.get("RASA_SDK_TRACING_SERVICE_NAME", "rasa_sdk")
Expand All @@ -21,6 +23,24 @@
logger = logging.getLogger(__name__)


def configure_tracing(tracer_provider: Optional[TracerProvider]) -> None:
"""Configure tracing functionality.
When a tracing backend is defined, this function will
instrument all methods that shall be traced.
If no tracing backend is defined, no tracing is configured.
:param tracer_provider: The `TracingProvider` to be used for tracing
"""
if tracer_provider is None:
return None

instrumentation.instrument(
tracer_provider=tracer_provider,
action_executor_class=ActionExecutor,
)


def get_tracer_provider(endpoints_file: Text) -> Optional[TracerProvider]:
"""Configure tracing backend.
Expand Down
27 changes: 27 additions & 0 deletions rasa_sdk/tracing/instrumentation/attribute_extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any, Dict, Text
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.types import ActionCall


# This file contains all attribute extractors for tracing instrumentation.
# These are functions that are applied to the arguments of the wrapped function to be
# traced to extract the attributes that we want to forward to our tracing backend.
# Note that we always mirror the argument lists of the wrapped functions, as our
# wrapping mechanism always passes in the original arguments unchanged for further
# processing.


def extract_attrs_for_action_executor(
self: ActionExecutor,
action_call: ActionCall,
) -> Dict[Text, Any]:
"""Extract the attributes for `ActionExecutor.run`.
:param self: The `ActionExecutor` on which `run` is called.
:param action_call: The `ActionCall` argument.
:return: A dictionary containing the attributes.
"""
return {
"next_action": action_call["next_action"],
"sender_id": action_call["sender_id"],
}
142 changes: 142 additions & 0 deletions rasa_sdk/tracing/instrumentation/instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import functools
import inspect
import logging
from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Text,
Type,
TypeVar,
)

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace import Tracer
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.tracing.instrumentation import attribute_extractors

# The `TypeVar` representing the return type for a function to be wrapped.
S = TypeVar("S")
# The `TypeVar` representing the type of the argument passed to the function to be
# wrapped.
T = TypeVar("T")

logger = logging.getLogger(__name__)
INSTRUMENTED_BOOLEAN_ATTRIBUTE_NAME = "class_has_been_instrumented"


def _check_extractor_argument_list(
fn: Callable[[T, Any, Any], S],
attr_extractor: Optional[Callable[[T, Any, Any], Dict[str, Any]]],
) -> bool:
if attr_extractor is None:
return False

fn_args = inspect.signature(fn)
attr_args = inspect.signature(attr_extractor)

are_arglists_congruent = fn_args.parameters.keys() == attr_args.parameters.keys()

if not are_arglists_congruent:
logger.warning(
f"Argument lists for {fn.__name__} and {attr_extractor.__name__}"
f" do not match up. {fn.__name__} will be traced without attributes."
)

return are_arglists_congruent


def traceable_async(
fn: Callable[[T, Any, Any], Awaitable[S]],
tracer: Tracer,
attr_extractor: Optional[Callable[[T, Any, Any], Dict[str, Any]]],
) -> Callable[[T, Any, Any], Awaitable[S]]:
"""Wrap an `async` function by tracing functionality.
:param fn: The function to be wrapped.
:param tracer: The `Tracer` that shall be used for tracing this function.
:param attr_extractor: A function that is applied to the function's instance and
the function's arguments.
:return: The wrapped function.
"""
should_extract_args = _check_extractor_argument_list(fn, attr_extractor)

@functools.wraps(fn)
async def async_wrapper(self: T, *args: Any, **kwargs: Any) -> S:
attrs = (
attr_extractor(self, *args, **kwargs)
if attr_extractor and should_extract_args
else {}
)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{fn.__name__}", attributes=attrs
):
return await fn(self, *args, **kwargs)

return async_wrapper


ActionExecutorType = TypeVar("ActionExecutorType", bound=ActionExecutor)


def instrument(
tracer_provider: TracerProvider,
action_executor_class: Optional[Type[ActionExecutorType]] = None,
) -> None:
"""Substitute methods to be traced by their traced counterparts.
:param tracer_provider: The `TracerProvider` to be used for configuring tracing
on the substituted methods.
:param action_executor_class: The `ActionExecutor` to be instrumented. If `None`
is given, no `ActionExecutor` will be instrumented.
"""
if action_executor_class is not None and not class_is_instrumented(
action_executor_class
):
_instrument_method(
tracer_provider.get_tracer(action_executor_class.__module__),
action_executor_class,
"run",
attribute_extractors.extract_attrs_for_action_executor,
)
mark_class_as_instrumented(action_executor_class)


def _instrument_method(
tracer: Tracer,
instrumented_class: Type,
method_name: Text,
attr_extractor: Optional[Callable],
) -> None:
method_to_trace = getattr(instrumented_class, method_name)
traced_method = traceable_async(method_to_trace, tracer, attr_extractor)
setattr(instrumented_class, method_name, traced_method)

logger.debug(f"Instrumented '{instrumented_class.__name__}.{method_name}'.")


def _mangled_instrumented_boolean_attribute_name(instrumented_class: Type) -> Text:
# see https://peps.python.org/pep-0008/#method-names-and-instance-variables
# and https://stackoverflow.com/a/50401073
return f"_{instrumented_class.__name__}__{INSTRUMENTED_BOOLEAN_ATTRIBUTE_NAME}"


def class_is_instrumented(instrumented_class: Type) -> bool:
"""Check if a class has already been instrumented."""
return getattr(
instrumented_class,
_mangled_instrumented_boolean_attribute_name(instrumented_class),
False,
)


def mark_class_as_instrumented(instrumented_class: Type) -> None:
"""Mark a class as instrumented if it isn't already marked."""
if not class_is_instrumented(instrumented_class):
setattr(
instrumented_class,
_mangled_instrumented_boolean_attribute_name(instrumented_class),
True,
)
1 change: 1 addition & 0 deletions rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def get_tracer_provider(

if endpoints_file is not None:
tracer_provider = config.get_tracer_provider(endpoints_file)
config.configure_tracing(tracer_provider)
return tracer_provider


Expand Down

0 comments on commit 987ed6d

Please sign in to comment.