Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CloudWatch EMF Route RequestLifecycleHooks #105

Merged
merged 9 commits into from
Oct 16, 2023
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

Features
--------

- `[sc-26163] <https://app.shortcut.com/globus/story/26163>`
Added support for middleware registration to the flask ActionProviderBlueprint class.
Classes may be provided at Blueprint instantiation time to register before, after, or
teardown functionality to wrap all view invocation.

- `[sc-26163] <https://app.shortcut.com/globus/story/26163>`
Added a CloudWatchEMFLogger middleware class.
When attached to an ActionProviderBlueprint, it will emit request count, latency, and
response category (2xxs, 4xxs, 5xxs) count metrics through CloudWatch EMF. Metrics
are emitted both at the aggregate AP and at the individual route level.
27 changes: 27 additions & 0 deletions globus_action_provider_tools/flask/apt_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
globus_auth_client_name: t.Optional[str] = None,
additional_scopes: t.Iterable[str] = (),
action_repository: t.Optional[AbstractActionRepository] = None,
middleware: t.Optional[t.List[t.Any]] = None,
**kwarg,
):
"""Create a new ActionProviderBlueprint. All arguments not listed here are the
Expand All @@ -71,6 +72,11 @@ def __init__(
``globus_auth_scope`` value of the input provider description. Only
needed if more than one scope has been allocated for the Action
Provider's Globus Auth client_id.

:param middleware: A list of classes defining a before_request, after_request,
and/or teardown_request method. If these functions exist they will be registered
with the blueprint. Middleware classes are registered in the order they are
provided.
"""

super().__init__(*args, **kwarg)
Expand All @@ -86,6 +92,15 @@ def __init__(
self.register_error_handler(Exception, blueprint_error_handler)
self.record_once(self._create_token_checker)

if middleware:
for m in middleware:
if hasattr(m, "before_request"):
self.before_request(m.before_request)
if hasattr(m, "after_request"):
self.after_request(m.after_request)
if hasattr(m, "teardown_request"):
self.teardown_request(m.teardown_request)

self.add_url_rule(
"/",
"action_introspect",
Expand Down Expand Up @@ -152,6 +167,7 @@ def _action_introspect(self):
"""
Runs as an Action Provider's introspection endpoint.
"""
self._register_route_type("introspect")
if not g.auth_state.check_authorization(
self.provider_description.visible_to,
allow_public=True,
Expand All @@ -166,6 +182,7 @@ def _action_introspect(self):
return jsonify(self.provider_description), 200

def _action_enumerate(self):
self._register_route_type("enumerate")
if not g.auth_state.check_authorization(
self.provider_description.runnable_by,
allow_public=True,
Expand Down Expand Up @@ -206,6 +223,7 @@ def action_enumerate(self, func: ActionEnumerationCallback):
return func

def _action_run(self):
self._register_route_type("run")
if not g.auth_state.check_authorization(
self.provider_description.runnable_by,
allow_all_authenticated_users=True,
Expand Down Expand Up @@ -257,6 +275,7 @@ def action_run(self, func: ActionRunCallback):
return func

def _action_resume(self, action_id: str):
self._register_route_type("resume")
# Attempt to lookup the Action based on its action_id if there was an
# Action Repo defined. If an action is found, verify access to it.
action = None
Expand Down Expand Up @@ -307,6 +326,7 @@ def action_status(self, func: ActionStatusCallback):
return func

def _action_status(self, action_id: str):
self._register_route_type("status")
"""
Attempts to load an action_status via its action_id using an
action_loader. If an action is successfully loaded, view access by the
Expand Down Expand Up @@ -351,6 +371,7 @@ def action_cancel(self, func: ActionCancelCallback):
return func

def _action_cancel(self, action_id: str):
self._register_route_type("cancel")
"""
Executes a user-defined function for cancelling an Action.
"""
Expand Down Expand Up @@ -405,6 +426,7 @@ def action_release(self, func: ActionReleaseCallback):
return func

def _action_release(self, action_id: str):
self._register_route_type("release")
"""
Decorates a function to be run as an Action Provider's release endpoint.
"""
Expand Down Expand Up @@ -446,6 +468,7 @@ def action_log(self, func: ActionLogCallback):
return func

def _action_log(self, action_id: str):
self._register_route_type("log")
# Attempt to use a user-defined function to lookup the Action based
# on its action_id. If an action is found, authorize access to it
action = None
Expand All @@ -456,6 +479,10 @@ def _action_log(self, action_id: str):
status = self.action_log_callback(action_id, g.auth_state)
return jsonify(status), 200

def _register_route_type(self, route_type: str):
if not hasattr(g, "route_type"):
g.route_type = route_type

def _load_action_by_id(
self, repo: AbstractActionRepository, action_id: str
) -> ActionStatus:
Expand Down
Empty file.
237 changes: 237 additions & 0 deletions globus_action_provider_tools/flask/middleware/cloudwatch_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import json
import logging
import typing as t
from datetime import datetime, timedelta

from flask import Response, g

log = logging.getLogger("cloudwatch-emf-middleware")


class CloudWatchMetricEMFLogger:
"""
Middleware to emit CloudWatch Metrics detailing action provider usage via
the CloudWatch EMF Format.
https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch_Embedded_Metric_Format_Specification.html

Metric Structure
================

Aggregate
---------
Namespace: {supplied_namespace}
Dimensions:
ActionProvider: {supplied_action_provider_name}

Route-Specific
--------------
Namespace: {supplied_namespace}
Dimensions:
ActionProvider: {supplied_action_provider_name}
Route: "run" | "resume" | "status" | ...

Included Metrics:
* Count - The total number API requests in a given period.
* 2XXs - The number of successful responses returned in a given period.
* 4XXs - The number of client-side errors captured in a given period.
* 5XXs - The number of server-side errors captured in a given period.
* RequestLatency - The number of milliseconds between the request being received
and the response being sent.
"""

def __init__(
self, namespace: str, action_provider_name: str, log_level: int | None = None
):
"""
:param namespace: Custom CloudWatch Namespace target
:param action_provider_name: Action Provider Name to be used in metric dimension
sets
:param log_level: Optional log level to use when emitting metrics. If None,
metrics will be printed to stdout instead of logged.
"""
self._namespace = namespace
self._action_provider_name = action_provider_name
self._log_level = log_level

def before_request(self):
g.request_start_time = datetime.now()

def after_request(self, response: Response):
if hasattr(g, "route_type") and hasattr(g, "request_start_time"):
self.emit_route_metrics(
route_name=g.route_type,
request_latency=datetime.now() - g.request_start_time,
response_status=response.status_code,
)
return response

def teardown_request(self, error: BaseException | None):
# If a request errors mid-handling, after_request handlers will not be called,
# so we need to emit metrics for errors separately here
if error:
if hasattr(g, "route_type") and hasattr(g, "request_start_time"):
status_code = 500
if hasattr(error, "code"):
status_code = error.code
self.emit_route_metrics(
route_name=g.route_type,
request_latency=datetime.now() - g.request_start_time,
derek-globus marked this conversation as resolved.
Show resolved Hide resolved
response_status=status_code,
)
raise error

def emit_route_metrics(
self,
route_name: str,
request_latency: timedelta,
response_status: int,
):
request_latency_ms = request_latency.total_seconds() * 1000
emf_log = _serialize_to_emf(
namespace=self._namespace,
dimension_sets=[
{"ActionProvider": self._action_provider_name},
{"ActionProvider": self._action_provider_name, "Route": route_name},
],
metrics=[
("Count", 1, "Count"),
("2XXs", 1 if 200 <= response_status < 300 else 0, "Count"),
("4XXs", 1 if 400 <= response_status < 500 else 0, "Count"),
("5XXs", 1 if 500 <= response_status < 600 else 0, "Count"),
("RequestLatency", request_latency_ms, "Milliseconds"),
],
)
emf_log = json.dumps(emf_log)

if not self._log_level:
print(emf_log)
else:
log.log(self._log_level, emf_log)


# fmt: off
# https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_MetricDatum.html
CloudWatchUnit = t.Literal[
"Seconds", "Microseconds", "Milliseconds", "Bytes", "Kilobytes", "Megabytes",
"Gigabytes", "Terabytes", "Bits", "Kilobits", "Megabits", "Gigabits", "Terabits",
"Percent", "Count", "Bytes/Second", "Kilobytes/Second", "Megabytes/Second",
"Gigabytes/Second", "Terabytes/Second", "Bits/Second", "Kilobits/Second",
"Megabits/Second", "Gigabits/Second", "Terabits/Second", "Count/Second", "None"
]
# fmt: on


def _serialize_to_emf(
namespace: str,
dimension_sets: list[dict[str, str]],
metrics: list[tuple[str, str | int | float, CloudWatchUnit | None]],
timestamp: datetime | None = None,
) -> dict[str, t.Any]:
"""
Serializes a list of metrics into CloudWatch Embedded Metric Format
https://docs.aws.amazon.com/AmazonCloudWatch/latest/monitoring/CloudWatch_Embedded_Metric_Format_Specification.html

This results in an object like
```json
{
"_aws": {
"Timestamp": 1680634571444,
"CloudWatchMetrics": [{
"Namespace": "MyCoolNamespace",
"Dimensions": [["Foo"], ["Foo", "Bar"]],
"Metrics": [{ "Name": "MyCoolMetric", "Unit": "Milliseconds" }]
}]
},
"Foo": "a",
"Bar": "b",
"MyCoolMetric": 37,
}
```
Note how there are two additional top-level keys besides "_aws".
This is because Dimension Values & Metric Values must be referenced not passed
explicitly

:namespace str: Namespace
:metric_name str: Metric Name
:dimension_sets list[dict[str, str]]: A collection of Dimension Sets (each metric
will be emitted with each dimension set)
:metrics list[tuple[str, str | int | float, str | None]]: Metric Tuple in the format
(metric_name, value, Optional[unit])
:timestamp datetime | None: Timestamp to use for the metric. If None, the current
time will be used.
:returns: An emf formatted dict
"""
timestamp = timestamp or datetime.now()
epoch_ms = int(timestamp.timestamp() * 1000)
_verify_no_emf_root_collisions(
{metric_name for metric_name, _, _ in metrics}, dimension_sets
)

emf_obj = {}

emf_metrics = []
for metric_name, value, unit in metrics:
emf_obj[metric_name] = value
emf_metric = {"Name": metric_name}
if unit is not None:
emf_metric["Unit"] = unit
emf_metrics.append(emf_metric)

emf_dimension_sets = []
for dimension_map in dimension_sets:
for dimension_name, dimension_value in dimension_map.items():
emf_obj[dimension_name] = dimension_value
emf_dimension_sets.append(list(dimension_map.keys()))

emf_obj["_aws"] = {
"Timestamp": epoch_ms,
"CloudWatchMetrics": [
{
"Namespace": namespace,
"Dimensions": emf_dimension_sets,
"Metrics": emf_metrics,
}
],
}

return emf_obj


def _verify_no_emf_root_collisions(
metric_names: set[str], dimension_sets: list[dict[str, str]]
):
"""
Verify that there are no disallowed collisions between the root keys of the emf
object

:raises: RuntimeError if names/values collide in ways that preclude them from
being emitted via EMF
"""
# Verify that no dimension names match any metric names
dimension_names = {
dimension_name
for dimension_map in dimension_sets
for dimension_name in dimension_map.keys()
}

namespace_collisions = metric_names.intersection(dimension_names)
if namespace_collisions:
raise RuntimeError(
f"Cannot overlap dimension names and metric names ({namespace_collisions})"
)

# Verify that no dimension names in different dimension sets conflict
dimension_values = {}
for dimension_map in dimension_sets:
for dimension_name, dimension_value in dimension_map.items():
dimension_values.setdefault(dimension_name, set()).add(dimension_value)
dimension_collisions = {
dimension_name
for dimension_name, dimension_value in dimension_values.items()
if len(dimension_value) > 1
}
if dimension_collisions:
raise RuntimeError(
f"Dimension sets with the same name must have the same value "
f"({dimension_collisions})"
)
Empty file.
Loading
Loading