Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SNS support. #197

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
AWS_QUEUE_URL: str = "aws.sqs.queue_url"
AWS_QUEUE_NAME: str = "aws.sqs.queue_name"
AWS_STREAM_NAME: str = "aws.kinesis.stream_name"
AWS_TOPIC_ARN: str = "aws.sns.topic_arn"
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AWS_REMOTE_SERVICE,
AWS_SPAN_KIND,
AWS_STREAM_NAME,
AWS_TOPIC_ARN,
)
from amazon.opentelemetry.distro._aws_span_processing_util import (
LOCAL_ROOT,
Expand Down Expand Up @@ -78,6 +79,7 @@
_NORMALIZED_KINESIS_SERVICE_NAME: str = "AWS::Kinesis"
_NORMALIZED_S3_SERVICE_NAME: str = "AWS::S3"
_NORMALIZED_SQS_SERVICE_NAME: str = "AWS::SQS"
_NORMALIZED_SNS_SERVICE_NAME: str = "AWS::SNS"
thpierce marked this conversation as resolved.
Show resolved Hide resolved
_DB_CONNECTION_STRING_TYPE: str = "DB::Connection"

# Special DEPENDENCY attribute value if GRAPHQL_OPERATION_TYPE attribute key is present.
Expand Down Expand Up @@ -372,6 +374,9 @@ def _set_remote_type_and_identifier(span: ReadableSpan, attributes: BoundedAttri
remote_resource_identifier = _escape_delimiters(
SqsUrlParser.get_queue_name(span.attributes.get(AWS_QUEUE_URL))
)
elif is_key_present(span, AWS_TOPIC_ARN):
remote_resource_type = _NORMALIZED_SNS_SERVICE_NAME + "::TopicArn"
thpierce marked this conversation as resolved.
Show resolved Hide resolved
remote_resource_identifier = _escape_delimiters(span.attributes.get(AWS_TOPIC_ARN))
elif is_db_span(span):
remote_resource_type = _DB_CONNECTION_STRING_TYPE
remote_resource_identifier = _get_db_connection(span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
import importlib

from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS
from opentelemetry.instrumentation.botocore.extensions.sns import _SnsExtension
from opentelemetry.instrumentation.botocore.extensions.sqs import _SqsExtension
from opentelemetry.instrumentation.botocore.extensions.types import _AttributeMapT, _AwsSdkExtension
from opentelemetry.instrumentation.botocore.extensions.types import _AttributeMapT, _AwsSdkExtension, _BotoResultT
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.span import Span


def _apply_botocore_instrumentation_patches() -> None:
"""Botocore instrumentation patches

Adds patches to provide additional support and Java parity for Kinesis, S3, and SQS.
Adds patches to provide additional support for Kinesis, S3, SQS and SNS.
"""
_apply_botocore_kinesis_patch()
_apply_botocore_s3_patch()
_apply_botocore_sqs_patch()
_apply_botocore_sns_patch()


def _apply_botocore_kinesis_patch() -> None:
Expand Down Expand Up @@ -65,6 +68,36 @@ def patch_extract_attributes(self, attributes: _AttributeMapT):
_SqsExtension.extract_attributes = patch_extract_attributes


def _apply_botocore_sns_patch() -> None:
"""Botocore instrumentation patch for SNS

This patch extends the existing upstream extension for SNS. Extensions allow for custom logic for adding
service-specific information to spans, such as attributes. Specifically, we are adding logic to add
"aws.sns.topic_arn" attributes to be used to generate AWS_REMOTE_RESOURCE_TYPE and AWS_REMOTE_RESOURCE_IDENTIFIER.
"""
old_extract_attributes = _SnsExtension.extract_attributes

def patch_extract_attributes(self, attributes: _AttributeMapT):
old_extract_attributes(self, attributes)
topic_arn = self._call_context.params.get("TopicArn")
if topic_arn:
attributes["aws.sns.topic_arn"] = topic_arn

old_on_success = _SnsExtension.on_success

def patch_on_success(self, span: Span, result: _BotoResultT):
thpierce marked this conversation as resolved.
Show resolved Hide resolved
old_on_success(self, span, result)
topic_arn = result.get("TopicArn")
if topic_arn:
span.set_attribute(
"aws.sns.topic_arn",
topic_arn,
)

_SnsExtension.extract_attributes = patch_extract_attributes
_SnsExtension.on_success = patch_on_success


# The OpenTelemetry Authors code
def _lazy_load(module, cls):
"""Clone of upstream opentelemetry.instrumentation.botocore.extensions.lazy_load
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AWS_REMOTE_SERVICE,
AWS_SPAN_KIND,
AWS_STREAM_NAME,
AWS_TOPIC_ARN,
)
from amazon.opentelemetry.distro._aws_metric_attribute_generator import _AwsMetricAttributeGenerator
from amazon.opentelemetry.distro.metric_attribute_generator import DEPENDENCY_METRIC, SERVICE_METRIC
Expand Down Expand Up @@ -821,6 +822,7 @@ def test_normalize_remote_service_name_aws_sdk(self):
self.validate_aws_sdk_service_normalization("Kinesis", "AWS::Kinesis")
self.validate_aws_sdk_service_normalization("S3", "AWS::S3")
self.validate_aws_sdk_service_normalization("SQS", "AWS::SQS")
self.validate_aws_sdk_service_normalization("SNS", "AWS::SNS")

def validate_aws_sdk_service_normalization(self, service_name: str, expected_remote_service: str):
self._mock_attribute([SpanAttributes.RPC_SYSTEM, SpanAttributes.RPC_SERVICE], ["aws-api", service_name])
Expand Down Expand Up @@ -977,6 +979,11 @@ def test_sdk_client_span_with_remote_resource_attributes(self):
self._validate_remote_resource_attributes("AWS::DynamoDB::Table", "aws_table^^name")
self._mock_attribute([SpanAttributes.AWS_DYNAMODB_TABLE_NAMES], [None])

# Validate behaviour of AWS_TOPIC_ARN attribute, then remove it
self._mock_attribute([AWS_TOPIC_ARN], ["arn:aws:sns:us-west-2:012345678901:test_topic"], keys, values)
self._validate_remote_resource_attributes("AWS::SNS::TopicArn", "arn:aws:sns:us-west-2:012345678901:test_topic")
self._mock_attribute([AWS_TOPIC_ARN], [None])

self._mock_attribute([SpanAttributes.RPC_SYSTEM], [None])

def test_client_db_span_with_remote_resource_attributes(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Dict
from typing import Any, Dict
from unittest import TestCase
from unittest.mock import MagicMock, patch

Expand All @@ -9,11 +9,13 @@
from amazon.opentelemetry.distro.patches._instrumentation_patch import apply_instrumentation_patches
from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.span import Span

_STREAM_NAME: str = "streamName"
_BUCKET_NAME: str = "bucketName"
_QUEUE_NAME: str = "queueName"
_QUEUE_URL: str = "queueUrl"
_TOPIC_ARN: str = "topicArn"


class TestInstrumentationPatch(TestCase):
Expand Down Expand Up @@ -69,10 +71,15 @@ def _validate_unpatched_botocore_instrumentation(self):

# SQS
self.assertTrue("sqs" in _KNOWN_EXTENSIONS, "Upstream has removed the SQS extension")
attributes: Dict[str, str] = _do_extract_sqs_attributes()
self.assertTrue("aws.queue_url" in attributes)
self.assertFalse("aws.sqs.queue_url" in attributes)
self.assertFalse("aws.sqs.queue_name" in attributes)
sqs_attributes: Dict[str, str] = _do_extract_sqs_attributes()
self.assertTrue("aws.queue_url" in sqs_attributes)
self.assertFalse("aws.sqs.queue_url" in sqs_attributes)
self.assertFalse("aws.sqs.queue_name" in sqs_attributes)

# SNS
self.assertTrue("sns" in _KNOWN_EXTENSIONS, "Upstream has removed the SNS extension")
sns_attributes: Dict[str, str] = _do_extract_sns_attributes()
self.assertFalse("aws.sns.topic_arn" in sns_attributes)

def _validate_patched_botocore_instrumentation(self):
# Kinesis
Expand All @@ -96,6 +103,15 @@ def _validate_patched_botocore_instrumentation(self):
self.assertTrue("aws.sqs.queue_name" in sqs_attributes)
self.assertEqual(sqs_attributes["aws.sqs.queue_name"], _QUEUE_NAME)

# SNS
self.assertTrue("sns" in _KNOWN_EXTENSIONS)
sns_attributes: Dict[str, str] = _do_extract_sns_attributes()
self.assertTrue("aws.sns.topic_arn" in sns_attributes)
self.assertEqual(sns_attributes["aws.sns.topic_arn"], _TOPIC_ARN)
sns_success_attributes: Dict[str, str] = _do_sns_on_success()
self.assertTrue("aws.sns.topic_arn" in sns_success_attributes)
self.assertEqual(sns_success_attributes["aws.sns.topic_arn"], _TOPIC_ARN)
thpierce marked this conversation as resolved.
Show resolved Hide resolved


def _do_extract_kinesis_attributes() -> Dict[str, str]:
service_name: str = "kinesis"
Expand All @@ -115,10 +131,36 @@ def _do_extract_sqs_attributes() -> Dict[str, str]:
return _do_extract_attributes(service_name, params)


def _do_extract_sns_attributes() -> Dict[str, str]:
service_name: str = "sns"
params: Dict[str, str] = {"TopicArn": _TOPIC_ARN}
return _do_extract_attributes(service_name, params)


def _do_sns_on_success() -> Dict[str, str]:
service_name: str = "sns"
result: Dict[str, Any] = {"TopicArn": _TOPIC_ARN}
return _do_on_success(service_name, result)


def _do_extract_attributes(service_name: str, params: Dict[str, str]) -> Dict[str, str]:
mock_call_context: MagicMock = MagicMock()
mock_call_context.params = params
attributes: Dict[str, str] = {}
sqs_extension = _KNOWN_EXTENSIONS[service_name]()(mock_call_context)
sqs_extension.extract_attributes(attributes)
extension = _KNOWN_EXTENSIONS[service_name]()(mock_call_context)
extension.extract_attributes(attributes)
return attributes


def _do_on_success(service_name: str, result: Dict[str, Any]) -> Dict[str, str]:
span_mock: Span = MagicMock()
span_attributes: Dict[str, str] = {}

def set_side_effect(set_key, set_value):
span_attributes[set_key] = set_value

span_mock.set_attribute.side_effect = set_side_effect
extension = _KNOWN_EXTENSIONS[service_name]()(span_mock)
extension.on_success(span_mock, result)

return span_attributes
49 changes: 48 additions & 1 deletion contract-tests/images/applications/botocore/botocore_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ def do_GET(self):
self._handle_sqs_request()
if self.in_path("kinesis"):
self._handle_kinesis_request()
if self.in_path("sns"):
self._handle_sns_request()

self._end_request(self.main_status)

# pylint: disable=invalid-name
def do_POST(self):
if self.in_path("sqserror"):
if self.in_path("sqserror") or self.in_path("snserror"):
self.send_response(self.main_status)
self.send_header("Content-type", "text/xml")
self.end_headers()
Expand Down Expand Up @@ -203,6 +205,47 @@ def _handle_kinesis_request(self) -> None:
else:
set_main_status(404)

def _handle_sns_request(self) -> None:
sns_client: BaseClient = boto3.client("sns", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
if self.in_path(_ERROR):
set_main_status(400)
try:
error_client: BaseClient = boto3.client(
"sns", endpoint_url=_ERROR_ENDPOINT + "/snserror", region_name=_AWS_REGION
)
topic_arn = "arn:aws:sns:us-west-2:000000000000:test_topic/snserror"
message = "Hello from Amazon SNS!"
subject = "Test Message"
message_attributes = {"Attribute1": {"DataType": "String", "StringValue": "Value1"}}
error_client.publish(
TopicArn=topic_arn, Message=message, Subject=subject, MessageAttributes=message_attributes
)
except Exception as exception:
print("Expected exception occurred", exception)
elif self.in_path(_FAULT):
set_main_status(500)
try:
fault_client: BaseClient = boto3.client(
"sns", endpoint_url=_FAULT_ENDPOINT, region_name=_AWS_REGION, config=_NO_RETRY_CONFIG
)
fault_client.get_topic_attributes(TopicArn="invalid_topic_arn")
except Exception as exception:
print("Expected exception occurred", exception)
elif self.in_path("gettopattributes/get-topic-attributes"):
set_main_status(200)
sns_client.get_topic_attributes(TopicArn="arn:aws:sns:us-west-2:000000000000:test_topic")
elif self.in_path("publishmessage/publish-message/some-message"):
set_main_status(200)
topic_arn = "arn:aws:sns:us-west-2:000000000000:test_topic"
message = "Hello from Amazon SNS!"
subject = "Test Message"
message_attributes = {"Attribute1": {"DataType": "String", "StringValue": "Value1"}}
sns_client.publish(
TopicArn=topic_arn, Message=message, Subject=subject, MessageAttributes=message_attributes
)
else:
set_main_status(404)

def _end_request(self, status_code: int):
self.send_response_only(status_code)
self.end_headers()
Expand Down Expand Up @@ -247,6 +290,10 @@ def prepare_aws_server() -> None:
# Set up Kinesis so tests can access a stream.
kinesis_client: BaseClient = boto3.client("kinesis", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
kinesis_client.create_stream(StreamName="test_stream", ShardCount=1)

# Set up SNS so tests can access a topic.
sns_client: BaseClient = boto3.client("sns", endpoint_url=_AWS_SDK_ENDPOINT, region_name=_AWS_REGION)
sns_client.create_topic(Name="test_topic")
except Exception as exception:
print("Unexpected exception occurred", exception)

Expand Down
Loading
Loading