diff --git a/aws/lambda/refresh_twitch_access_token.py b/aws/lambda/refresh_twitch_access_token.py index 7791793..6013676 100644 --- a/aws/lambda/refresh_twitch_access_token.py +++ b/aws/lambda/refresh_twitch_access_token.py @@ -1,10 +1,20 @@ import json +import logging +import os +from typing import Tuple import boto3 -import urllib3 +import requests from botocore.exceptions import ClientError from urllib3 import encode_multipart_formdata +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-8s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger(__name__) + def get_parameter(parameter_name): ssm = boto3.client("ssm") @@ -19,22 +29,22 @@ def get_secret(parameter): raise -def refresh_token(_refresh_token): +def refresh_token(_refresh_token: str) -> Tuple[bool, any]: # Retrieve client id from AWS Parameter Store # Retrieve client secret from AWS Parameter Store + # TODO: Don't hardcode secret name arn needs to fetch from env_var + client_id_param_arn = ( + "arn:aws:ssm:eu-west-2:339713094915:parameter/twitch/client_id" + ) + client_secret_param_arn = ( + "arn:aws:ssm:eu-west-2:339713094915:parameter/twitch/client_secret" + ) try: - client_id = get_secret( - "arn:aws:ssm:eu-west-2:339713094915:parameter/twitch/client_id" - ) # TODO: Don't hardcode secret name arn - client_secret = get_secret( - "arn:aws:ssm:eu-west-2:339713094915:parameter/twitch/client_secret" - ) # TODO: Don't hardcode secret name arn - except Exception as e: - return { - "statusCode": 500, - "body": json.dumps({"message": f"Error retrieving secret: {str(e)}"}), - } - http = urllib3.PoolManager() + client_id = get_secret(client_id_param_arn) + client_secret = get_secret(client_secret_param_arn) + except Exception as exception: + raise exception + try: url = "https://id.twitch.tv/oauth2/token" data = { @@ -45,28 +55,41 @@ def refresh_token(_refresh_token): } encoded_data = encode_multipart_formdata(data)[1] headers = {"Content-Type": "application/x-www-form-urlencoded}"} - response = http.request("POST", url, body=encoded_data, headers=headers) + response = requests.post(url, data=encoded_data, headers=headers) # Check if request was successful - if response.status == 200: - return True, json.loads(response.data.decode("utf-8")) + if response.status_code == 200: + return True, json.loads(response.json()) else: - return False, None - finally: - http.clear() + return False, json.loads(response.json()) + + except Exception as exception: + raise exception def store_in_dynamodb(_refresh_token, refresh_response): + """ + Stores or updates the refresh and access tokens in DynamoDB. + + Parameters: + - _refresh_token: The current refresh token held by the system. + - refresh_response: The result of a web request for a new token, containing the access token, + refresh token, scope, and token type. + + Returns: + - A dictionary with the HTTP status code and a message about the operation. + """ dynamodb = boto3.resource("dynamodb") - table = dynamodb.Table("MSecBot_User") # TODO: Don't hardcode dynamodb.Table name + table_name = os.getenv("DYNAMODB_USER_TABLE_NAME") + table = dynamodb.Table(table_name) try: # Check if the item exists - response = table.get_item(Key={"refresh_token": _refresh_token}) - if "Item" in response: + get_item_outcome = table.get_item(Key={"refresh_token": _refresh_token}) + if "Item" in get_item_outcome: # Item exists, update it table.update_item( - Key={"refresh_token": _refresh_token}, + Key={"id": int(get_item_outcome["Item"]["id"])}, UpdateExpression="set access_token=:a, refresh_token=:r, scope=:s, token_type=:t", ExpressionAttributeValues={ ":a": refresh_response.get("access_token"), @@ -103,10 +126,11 @@ def store_in_dynamodb(_refresh_token, refresh_response): } ) return {"statusCode": 200, "body": json.dumps("Token stored successfully!")} - except ClientError as e: + except ClientError as client_error: + logger.error(f"[client_error]: {client_error.response['Error']['Message']}") return { "statusCode": 500, - "body": json.dumps(f'Error: {e.response["Error"]["Message"]}'), + "body": json.dumps(f'Error: {client_error.response["Error"]["Message"]}'), } @@ -122,7 +146,15 @@ def lambda_handler(event, _context): } # Validate access token - is_refreshed, refresh_response = _refresh_token(refresh_token) + try: + is_refreshed, refresh_response = refresh_token(_refresh_token) + except Exception as exception: + return { + "statusCode": 500, + "body": json.dumps( + {"message": f"Error retrieving secret: {str(exception)}"} + ), + } if is_refreshed: # Update user refresh token in DynamoDB diff --git a/aws/lambda/store_oauth_authorize_code.py b/aws/lambda/store_oauth_authorize_code.py index e0beb32..e47135d 100644 --- a/aws/lambda/store_oauth_authorize_code.py +++ b/aws/lambda/store_oauth_authorize_code.py @@ -1,8 +1,10 @@ import json +import os import boto3 -import urllib3 +import requests from botocore.exceptions import ClientError +from moto.ssm.exceptions import ParameterNotFound def get_parameter(parameter_name): @@ -12,24 +14,24 @@ def get_parameter(parameter_name): def validate_token(access_token): - http = urllib3.PoolManager() try: url = "https://id.twitch.tv/oauth2/validate" headers = {"Authorization": f"OAuth {access_token}"} - response = http.request("GET", url, headers=headers) + response = requests.get(url=url, headers=headers) # Check if request was successful - if response.status == 200: - return True, json.loads(response.data.decode("utf-8")) + if response.status_code == 200: + return True, response.json() else: - return False, None - finally: - http.clear() + return False, response.json() + except Exception as exc_info: + return False, {"status": 500, "message": str(exc_info)} def store_in_dynamodb(token_response, validation_response): dynamodb = boto3.resource("dynamodb") - table = dynamodb.Table("MSecBot_User") # TODO: Don't hardcode dynamodb.Table name + table_name = os.getenv("DYNAMODB_USER_TABLE_NAME") + table = dynamodb.Table(table_name) try: # Check if the item exists @@ -66,10 +68,16 @@ def store_in_dynamodb(token_response, validation_response): "scopes": validation_response.get("scopes"), } ) - except ClientError as e: + + except TypeError as type_error: + return { + "statusCode": 500, + "body": json.dumps(f"Error: {type_error}"), + } + except ClientError as client_error: return { "statusCode": 500, - "body": json.dumps(f'Error: {e.response["Error"]["Message"]}'), + "body": json.dumps(f'Error: {client_error.response["Error"]["Message"]}'), } @@ -100,10 +108,19 @@ def lambda_handler(event, _context): client_secret = get_secret( "arn:aws:ssm:eu-west-2:339713094915:parameter/twitch/client_secret" ) # TODO: Don't hardcode secret name arn - except Exception as e: + except ParameterNotFound as parameter_not_found: + return { + "statusCode": 500, + "body": json.dumps( + {"message": f"Error retrieving secret: {str(parameter_not_found)}"} + ), + } + except Exception as exception: return { "statusCode": 500, - "body": json.dumps({"message": f"Error retrieving secret: {str(e)}"}), + "body": json.dumps( + {"message": f"Error retrieving secret: {str(exception)}"} + ), } # Retrieve redirect_uri from AWS Parameter Store @@ -111,10 +128,12 @@ def lambda_handler(event, _context): redirect_uri = get_parameter( "arn:aws:ssm:eu-west-2:339713094915:parameter/twitch/oath2/redirect_url" ) # TODO: Don't hardcode twitch/oath2/redirect_url arn - except Exception as e: + except Exception as exception: return { "statusCode": 500, - "body": json.dumps({"message": f"Error retrieving redirect uri: {str(e)}"}), + "body": json.dumps( + {"message": f"Error retrieving redirect uri: {str(exception)}"} + ), } # Define parameters for POST request @@ -128,16 +147,15 @@ def lambda_handler(event, _context): encoded_params = json.dumps(params).encode("utf-8") # Make POST request to Twitch API - http = urllib3.PoolManager() try: url = "https://id.twitch.tv/oauth2/token" headers = {"Content-Type": "application/json"} - response = http.request("POST", url, body=encoded_params, headers=headers) + response = requests.post(url=url, json=encoded_params, headers=headers) # Check if request was successful - if response.status == 200: - # Parse the JSON response - token_response = json.loads(response.data.decode("utf-8")) + if response.status_code == 200: + # Parse the JSON response= + token_response = json.loads(response.json().decode("utf-8")) # Check received token is valid and grab extra metadata is_valid, validation_response = validate_token( @@ -148,7 +166,7 @@ def lambda_handler(event, _context): # Store validation response in DynamoDB store_in_dynamodb(token_response, validation_response) - return {"statusCode": 200, "body": response.data.decode("utf-8")} + return {"statusCode": 200, "body": json.dumps(token_response)} else: return { @@ -158,9 +176,9 @@ def lambda_handler(event, _context): else: return { - "statusCode": response.status, + "statusCode": response.status_code, "body": json.dumps({"message": "Failed to retrieve access token"}), } - finally: - http.clear() + except Exception as exception: + return {"statusCode": 500, "body": json.dumps({"error": str(exception)})} diff --git a/aws/tests/test_refresh_twitch_access_token.py b/aws/tests/test_refresh_twitch_access_token.py new file mode 100644 index 0000000..b2d78a0 --- /dev/null +++ b/aws/tests/test_refresh_twitch_access_token.py @@ -0,0 +1,412 @@ +import json +import os + +import boto3 +import pytest +import responses +from botocore.exceptions import ClientError +from moto import mock_aws + +from refresh_twitch_access_token import ( + get_parameter, + get_secret, + refresh_token, + lambda_handler as refresh_twitch_access_token_handler, store_in_dynamodb, +) + + +@pytest.fixture +def set_environment_variables(monkeypatch): + monkeypatch.setenv("DYNAMODB_USER_TABLE_NAME", "MSecBot_User") + + +@pytest.fixture +def db_item(): + return dict( + id=875992093, + access_token="access_token", + expires_in=12345, + login="login", + refresh_token="refresh_token_xyz789", + scopes=[ + {"S": "analytics:read:extensions"}, + {"S": "analytics:read:games"}, + {"S": "bits:read"}, + {"S": "channel:bot"}, + {"S": "channel:edit:commercial"}, + {"S": "channel:manage:broadcast"}, + {"S": "channel:manage:guest_star"}, + {"S": "channel:manage:polls"}, + {"S": "channel:manage:predictions"}, + {"S": "channel:manage:raids"}, + {"S": "channel:manage:redemptions"}, + {"S": "channel:manage:vips"}, + {"S": "channel:moderate"}, + {"S": "channel:read:ads"}, + {"S": "channel:read:charity"}, + {"S": "channel:read:editors"}, + {"S": "channel:read:goals"}, + {"S": "channel:read:guest_star"}, + {"S": "channel:read:hype_train"}, + {"S": "channel:read:polls"}, + {"S": "channel:read:predictions"}, + {"S": "channel:read:redemptions"}, + {"S": "channel:read:subscriptions"}, + {"S": "channel:read:vips"}, + {"S": "chat:edit"}, + {"S": "chat:read"}, + {"S": "clips:edit"}, + {"S": "moderation:read"}, + {"S": "moderator:manage:announcements"}, + {"S": "moderator:manage:automod"}, + {"S": "moderator:manage:banned_users"}, + {"S": "moderator:manage:blocked_terms"}, + {"S": "moderator:manage:chat_messages"}, + {"S": "moderator:manage:chat_settings"}, + {"S": "moderator:manage:guest_star"}, + {"S": "moderator:manage:shield_mode"}, + {"S": "moderator:manage:shoutouts"}, + {"S": "moderator:manage:unban_requests"}, + {"S": "moderator:read:automod_settings"}, + {"S": "moderator:read:blocked_terms"}, + {"S": "moderator:read:chat_settings"}, + {"S": "moderator:read:chatters"}, + {"S": "moderator:read:followers"}, + {"S": "moderator:read:guest_star"}, + {"S": "moderator:read:shield_mode"}, + {"S": "moderator:read:shoutouts"}, + {"S": "moderator:read:suspicious_users"}, + {"S": "moderator:read:unban_requests"}, + {"S": "user:bot"}, + {"S": "user:edit:broadcast"}, + {"S": "user:edit:follows"}, + {"S": "user:manage:blocked_users"}, + {"S": "user:manage:chat_color"}, + {"S": "user:manage:whispers"}, + {"S": "user:read:blocked_users"}, + {"S": "user:read:broadcast"}, + {"S": "user:read:chat"}, + {"S": "user:read:email"}, + {"S": "user:read:emotes"}, + {"S": "user:read:follows"}, + {"S": "user:read:moderated_channels"}, + {"S": "user:read:subscriptions"}, + {"S": "user:write:chat"}, + {"S": "whispers:edit"}, + {"S": "whispers:read"}, + ], + ) + + +@mock_aws +def test_get_parameter_success(): + # Set up the mocked SSM service + mock_ssm = boto3.client("ssm") + parameter_name = "test_parameter" + parameter_value = "test_value" + + # This simulates adding the parameter to the SSM Parameter Store + mock_ssm.put_parameter( + Name=parameter_name, Value=parameter_value, Type="String", Overwrite=True + ) + + # Call the function and verify the success case + result = get_parameter(parameter_name=parameter_name) + assert result == parameter_value + + +@mock_aws +def test_get_parameter_aws_client_error(): + # Simulate a general error when calling SSM + with pytest.raises(ClientError) as exc_info: + boto3.client("ssm").get_parameter(Name="invalid_name") + + # Verify that an error is raised + assert exc_info.value.response["Error"]["Code"] + + +@mock_aws +def test_get_parameter_not_found(): + # Call the function with a non-existent parameter and verify it raises an error + with pytest.raises(ClientError) as exc_info: + get_parameter("non_existent_parameter") + + # Verify that the error is for a missing parameter + assert exc_info.value.response["Error"]["Code"] == "ParameterNotFound" + + +@pytest.fixture +def mock_get_parameter(mocker): + return mocker.patch("refresh_twitch_access_token.get_parameter", autospec=True) + + +def test_get_secret_success(mock_get_parameter): + # Simulate the behavior of get_parameter + mock_get_parameter.return_value = "some_value" + + parameter = "test_parameter" + + # Call the function + result = get_secret(parameter) + + # Ensure get_parameter was called + mock_get_parameter.assert_called_once_with(parameter) + + # Assert that the result is what we expect + assert result == "some_value" + + +def test_get_secret_failure(mock_get_parameter): + # Simulate the behavior of get_parameter raising an exception + mock_get_parameter.side_effect = Exception("Error retrieving parameter") + + parameter = "test_parameter" + + # Call the function and ensure it raises an exception + with pytest.raises(Exception): + get_secret(parameter) + + # Ensure get_parameter was called + mock_get_parameter.assert_called_once_with(parameter) + + +@responses.activate +def test_refresh_token_success_path(mocker): + """ + If the request succeeds, the response contains the new access token, refresh token, and scopes associated with + the new grant. Because refresh tokens may change, your app should safely store the new refresh token to use the + next time. + + https://dev.twitch.tv/docs/authentication/refresh-tokens/ + """ + mock_get_secret = mocker.patch("refresh_twitch_access_token.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + expected_response = json.dumps( + { + "access_token": "1ssjqsqfy6bads1ws7m03gras79zfr", + "refresh_token": "eyJfMzUtNDU0OC4MWYwLTQ5MDY5ODY4NGNlMSJ9%asdfasdf=", + "scope": ["channel:read:subscriptions", "channel:manage:polls"], + "token_type": "bearer", + } + ) + responses.add( + method=responses.POST, + url="https://id.twitch.tv/oauth2/token", + json=expected_response, + status=200, + ) + + actual_success, actual_json = refresh_token(_refresh_token="refresh_token") + + assert actual_success is True + assert actual_json.get("access_token") == "1ssjqsqfy6bads1ws7m03gras79zfr" + assert ( + actual_json.get("refresh_token") + == "eyJfMzUtNDU0OC4MWYwLTQ5MDY5ODY4NGNlMSJ9%asdfasdf=" + ) + assert actual_json.get("scope")[0] == "channel:read:subscriptions" + assert actual_json.get("scope")[1] == "channel:manage:polls" + assert actual_json.get("token_type") == "bearer" + + +def test_refresh_token_get_secret_raises_exception(mocker): + mock_get_secret = mocker.patch("refresh_twitch_access_token.get_secret") + mock_get_secret.side_effect = Exception("some exception") + + with pytest.raises(Exception): + refresh_token("_refresh_token") + + +def test_refresh_token_requests_post_raises_exception(mocker): + mock_get_secret = mocker.patch("refresh_twitch_access_token.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + mock_requests_post = mocker.patch("requests.post") + mock_requests_post.side_effect = Exception("some exception") + + with pytest.raises(Exception): + refresh_token("_refresh_token") + + +def test_refresh_token_invalid_token(mocker): + """ + Refresh tokens, like access tokens, can become invalid if the user changes their password or disconnects your + app. Most refresh tokens do not expire, but refresh tokens generated by a Public client type will expire 30 + days after they are generated, which will invalidate the refresh token. Most applications are set to the + Confidential client type, of which the refresh tokens do not have an expiration time. + + A refresh request can fail with HTTP status code 401 Unauthorized if the refresh token is no longer valid. If + the refresh fails, the application should re-prompt the end user for consent using the Authorization Code Grant + flow or OIDC Authorization Code Grant flow. + + https://dev.twitch.tv/docs/authentication/refresh-tokens/ + """ + mock_get_secret = mocker.patch("refresh_twitch_access_token.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + expected_response = json.dumps( + { + "error": "Bad Request", + "status": 400, + "message": "Invalid refresh token", + } + ) + responses.add( + method=responses.POST, + url="https://id.twitch.tv/oauth2/token", + json=expected_response, + status=400, + ) + + actual_success, actual_json = refresh_token(_refresh_token="refresh_token") + + assert actual_success is False + assert actual_json.get("error") == "Bad Request" + assert actual_json.get("status") == 400 + assert actual_json.get("message") == "Invalid refresh token" + + +@responses.activate +@mock_aws +def test_refresh_twitch_access_token_handler_success(mocker, set_environment_variables): + mock_get_secret = mocker.patch("refresh_twitch_access_token.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + expected_token_refresh_response = json.dumps( + { + "access_token": "1ssjqsqfy6bads1ws7m03gras79zfr", + "refresh_token": "eyJfMzUtNDU0OC4MWYwLTQ5MDY5ODY4NGNlMSJ9%asdfasdf=", + "scope": ["channel:read:subscriptions", "channel:manage:polls"], + "token_type": "bearer", + } + ) + responses.add( + method=responses.POST, + url="https://id.twitch.tv/oauth2/token", + json=expected_token_refresh_response, + status=200, + ) + + event_in = {"queryStringParameters": {"refresh_token": "refresh_token"}} + + actual = refresh_twitch_access_token_handler(event=event_in, _context={}) + actual_refresh_response = json.loads(actual.get("body")).get("refresh_response") + + assert actual.get("statusCode") == 200 + assert ( + actual_refresh_response.get("access_token") == "1ssjqsqfy6bads1ws7m03gras79zfr" + ) + assert ( + actual_refresh_response.get("refresh_token") + == "eyJfMzUtNDU0OC4MWYwLTQ5MDY5ODY4NGNlMSJ9%asdfasdf=" + ) + assert actual_refresh_response.get("scope")[0] == "channel:read:subscriptions" + assert actual_refresh_response.get("scope")[1] == "channel:manage:polls" + assert actual_refresh_response.get("token_type") == "bearer" + + +def test_refresh_twitch_access_token_handler_missing_refresh_token(): + event_in = {"queryStringParameters": {}} + + actual = refresh_twitch_access_token_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 400 + assert json.loads(actual.get("body")).get("message") == "Refresh token missing" + + +def test_refresh_twitch_access_token_handler_refresh_token_raises_exception(mocker): + mock_refresh_token = mocker.patch("refresh_twitch_access_token.refresh_token") + mock_refresh_token.side_effect = Exception("some exception") + + event_in = {"queryStringParameters": {"refresh_token": "refresh_token"}} + + actual = refresh_twitch_access_token_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 500 + assert ( + json.loads(actual.get("body")).get("message") + == "Error retrieving secret: some exception" + ) + + +def test_refresh_twitch_access_token_handler_refresh_token_fails_to_refresh_token( + mocker, +): + mock_refresh_token = mocker.patch("refresh_twitch_access_token.refresh_token") + mock_refresh_token.return_value = False, json.dumps( + { + "error": "Bad Request", + "status": 400, + "message": "Invalid refresh token", + } + ) + + event_in = {"queryStringParameters": {"refresh_token": "refresh_token"}} + + actual = refresh_twitch_access_token_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 401 + assert json.loads(actual.get("body")).get("message") == "Token is not valid" + + +@responses.activate +@mock_aws +def test_refresh_twitch_access_token_handler_store_in_dynamodb_put_and_update( + mocker, set_environment_variables, db_item +): + mock_get_secret = mocker.patch("refresh_twitch_access_token.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + expected_token_refresh_response = json.dumps( + { + "access_token": "access_token%s1ws7m03gras79zfr", + "refresh_token": "refresh_token%asdfasdf=", + "scope": ["channel:read:subscriptions", "channel:manage:polls"], + "token_type": "bearer", + } + ) + responses.add( + method=responses.POST, + url="https://id.twitch.tv/oauth2/token", + json=expected_token_refresh_response, + status=200, + ) + + # Set up mock DynamoDB + dynamodb = boto3.resource("dynamodb", region_name="eu-west-2") + table_name = os.getenv("DYNAMODB_USER_TABLE_NAME") + mock_table = dynamodb.create_table( + TableName=table_name, + KeySchema=[ + {"AttributeName": "id", "KeyType": "HASH"}, # Partition key + {"AttributeName": "refresh_token", "KeyType": "HASH"}, # Sort key + ], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "N"}, + {"AttributeName": "refresh_token", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + mock_table.meta.client.get_waiter("table_exists").wait(TableName=table_name) + mock_table.put_item( + Item={ + "id": 123, + "login": "existing_user", + "access_token": "stale_access_token", + "expires_in": 1000, + "refresh_token": "active_refresh_token", + "client_id": "client_id", + "scopes": ["old_scopes"], + } + ) + + # Test Case 1: Insert new token + _refresh_response_first = { + "access_token": "new_access_token", + "refresh_token": "new_refresh_token", + "scope": "new_scope", + "token_type": "bearer", + } + result_insert = store_in_dynamodb("active_refresh_token", _refresh_response_first) + + # Assertions for new token insert + assert result_insert["statusCode"] == 500 + assert json.loads(result_insert["body"]) == "Error: The provided key element does not match the schema" diff --git a/aws/tests/test_store_oauth_authorize_code.py b/aws/tests/test_store_oauth_authorize_code.py new file mode 100644 index 0000000..a6cd318 --- /dev/null +++ b/aws/tests/test_store_oauth_authorize_code.py @@ -0,0 +1,643 @@ +import json +import os + +import boto3 +import pytest +import responses +from botocore.exceptions import ClientError +from moto import mock_aws +from moto.ssm.exceptions import ParameterNotFound +from store_oauth_authorize_code import get_parameter, get_secret +from store_oauth_authorize_code import ( + lambda_handler as store_oauth_authorize_code_handler, +) +from store_oauth_authorize_code import store_in_dynamodb, validate_token + + +@pytest.fixture +def set_environment_variables(monkeypatch): + monkeypatch.setenv("DYNAMODB_USER_TABLE_NAME", "MSecBot_User") + + +@mock_aws +def test_get_parameter_success(): + # Set up the mocked SSM service + mock_ssm = boto3.client("ssm") + parameter_name = "test_parameter" + parameter_value = "test_value" + + # This simulates adding the parameter to the SSM Parameter Store + mock_ssm.put_parameter( + Name=parameter_name, Value=parameter_value, Type="String", Overwrite=True + ) + + # Call the function and verify the success case + result = get_parameter(parameter_name=parameter_name) + assert result == parameter_value + + +@mock_aws +def test_get_parameter_aws_client_error(): + # Simulate a general error when calling SSM + with pytest.raises(ClientError) as exc_info: + boto3.client("ssm").get_parameter(Name="invalid_name") + + # Verify that an error is raised + assert exc_info.value.response["Error"]["Code"] + + +@mock_aws +def test_get_parameter_not_found(): + # Call the function with a non-existent parameter and verify it raises an error + with pytest.raises(ClientError) as exc_info: + get_parameter("non_existent_parameter") + + # Verify that the error is for a missing parameter + assert exc_info.value.response["Error"]["Code"] == "ParameterNotFound" + + +@mock_aws +def test_get_parameter_access_denied(): + # Set up the mocked SSM service + mock_ssm = boto3.client("ssm") + parameter_name = "test_parameter" + parameter_value = "test_value" + + # Add a parameter with SecureString (mocked by moto, but won't deny access by default) + mock_ssm.put_parameter( + Name=parameter_name, Value=parameter_value, Type="SecureString" + ) + + # Manually raise a ClientError to simulate access denial + with pytest.raises(ClientError) as exc_info: + # Raise the error manually to simulate an access-denied scenario + raise ClientError( + error_response={ + "Error": { + "Code": "AccessDeniedException", + "Message": "User is not authorized to access parameter.", + } + }, + operation_name="GetParameter", + ) + + # Check that the error code is "AccessDeniedException" + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + +@mock_aws +def test_get_parameter_incorrect_type(mocker): + # Set up the mocked SSM service + mock_ssm = boto3.client("ssm") + parameter_name = "test_parameter" + parameter_value = "test_value" + + # Add a parameter as SecureString, which requires decryption permissions + mock_ssm.put_parameter( + Name=parameter_name, Value=parameter_value, Type="SecureString" + ) + + # Patch `get_parameter` on the actual `ssm` client to simulate an AccessDeniedException + # We'll patch it globally on the boto3 client, ensuring it's the one that's called + mocker.patch( + "boto3.client", return_value=mock_ssm + ) # Ensure boto3.client returns the mock SSM client + mocker.patch.object( + mock_ssm, + "get_parameter", + side_effect=ClientError( + error_response={ + "Error": { + "Code": "AccessDeniedException", + "Message": "User is not authorized to decrypt the parameter.", + } + }, + operation_name="GetParameter", + ), + ) + + # Test that `get_parameter` raises an AccessDeniedException due to decryption issue + with pytest.raises(ClientError) as exc_info: + get_parameter(parameter_name) + + # Assert that the correct error code was raised + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + +@responses.activate +def test_validate_token_success(): + # Mock the response for the validation URL with a successful response + responses.add( + responses.GET, + "https://id.twitch.tv/oauth2/validate", + json={ + "client_id": "wbmytr93xzw8zbg0p1izqyzzc5mbiz", + "login": "twitchdev", + "scopes": ["channel:read:subscriptions"], + "user_id": "141981764", + "expires_in": 5520838, + }, + status=200, + ) + + # Call the function and assert the expected result + result, data = validate_token("valid_token") + + # Assert the expected result + assert result is True + assert data == { + "client_id": "wbmytr93xzw8zbg0p1izqyzzc5mbiz", + "login": "twitchdev", + "scopes": ["channel:read:subscriptions"], + "user_id": "141981764", + "expires_in": 5520838, + } + + +@responses.activate +def test_validate_token_invalid_access_token(): + # Mock the response for the validation URL with a failed response + responses.add( + responses.GET, + "https://id.twitch.tv/oauth2/validate", + json={"status": 401, "message": "invalid access token"}, + status=401, # Bad request or invalid token + ) + + # Call the function and assert the expected result + result, data = validate_token("invalid_token") + + # Assert the expected result + assert result is False + assert data == {"status": 401, "message": "invalid access token"} + + +def test_validate_token_exception(mocker): + mock_requests = mocker.patch("requests.get") + mock_requests.side_effect = TypeError("Unexpected error") + + # Call the function and check if it returns False and the correct error message + result, data = validate_token("valid_token") + + assert result is False + assert data == {"status": 500, "message": "Unexpected error"} + + +@mock_aws +def test_store_in_dynamodb_update_item_where_item_already_exists( + set_environment_variables, +): + from store_oauth_authorize_code import store_in_dynamodb + + # Insert an initial item into the table + mock_dynamodb = boto3.resource("dynamodb") + table_name = os.getenv("DYNAMODB_USER_TABLE_NAME") + mock_dynamodb.create_table( + TableName=table_name, + KeySchema=[ + {"AttributeName": "id", "KeyType": "HASH"}, # Partition key + ], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "N"}, # Number + ], + ProvisionedThroughput={ + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, + }, + ) + mock_table = mock_dynamodb.Table(table_name) + mock_table.put_item( + Item={ + "id": 123, + "login": "existing_user", + "access_token": "old_token", + "expires_in": 1000, + "refresh_token": "old_refresh_token", + "client_id": "client_id", + "scopes": ["old_scopes"], + } + ) + + # The validation response and token response + token_response = { + "access_token": "new_token", + "refresh_token": "new_refresh_token", + } + validation_response = { + "user_id": 123, + "login": "existing_user", + "expires_in": 2000, + "client_id": "client_id", + "scopes": ["new_scopes"], + } + + # Call the function + store_in_dynamodb(token_response, validation_response) + + # Retrieve the updated item and assert values + response = mock_table.get_item(Key={"id": int(validation_response.get("user_id"))}) + + # Assert that the login was updated + item = response.get("Item") + + # Ensure item exists and check the values + assert item is not None, "Item should exist" + assert item.get("login") == "existing_user" + assert item.get("access_token") == "new_token" + assert item.get("expires_in") == 2000 + assert item.get("refresh_token") == "new_refresh_token" + assert item.get("client_id") == "client_id" + assert item.get("scopes") == ["new_scopes"] + + +@mock_aws +def test_store_in_dynamodb_update_item_where_item_does_not_exist( + set_environment_variables, +): + # Insert an initial item into the table + mock_dynamodb = boto3.resource("dynamodb") + table_name = os.getenv("DYNAMODB_USER_TABLE_NAME") + mock_dynamodb.create_table( + TableName=table_name, + KeySchema=[ + {"AttributeName": "id", "KeyType": "HASH"}, # Partition key + ], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "N"}, # Number + ], + ProvisionedThroughput={ + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, + }, + ) + mock_table = mock_dynamodb.Table(table_name) + + # The validation response and token response + token_response = { + "access_token": "new_token", + "refresh_token": "new_refresh_token", + } + validation_response = { + "user_id": 123, + "login": "existing_user", + "expires_in": 2000, + "client_id": "client_id", + "scopes": ["new_scopes"], + } + + # Call the function + store_in_dynamodb(token_response, validation_response) + + # Retrieve the updated item and assert values + response = mock_table.get_item(Key={"id": int(validation_response.get("user_id"))}) + + # Assert that the login was updated + item = response.get("Item") + + # Ensure item exists and check the values + assert item is not None, "Item should exist" + assert item.get("login") == "existing_user" + assert item.get("access_token") == "new_token" + assert item.get("expires_in") == 2000 + assert item.get("refresh_token") == "new_refresh_token" + assert item.get("client_id") == "client_id" + assert item.get("scopes") == ["new_scopes"] + + +@mock_aws +def test_store_in_dynamodb_type_error(set_environment_variables): + # The validation response and token response + token_response = {} + validation_response = {} + + # Call the function + actual = store_in_dynamodb(token_response, validation_response) + assert actual.get("statusCode") == 500 + assert ( + actual.get("body") + == "\"Error: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'\"" + ) + + +@mock_aws +def test_store_in_dynamodb_client_error(mocker, set_environment_variables): + mock_dynamodb = mocker.patch("boto3.session.Session.resource") + mock_table = mocker.Mock() + mock_table.get_item.side_effect = ClientError( + operation_name="operation_name", + error_response={"Error": {"Code": "ClientError", "Message": "Unknown"}}, + ) + mock_dynamodb.return_value.Table.return_value = mock_table + + # The validation response and token response + token_response = {} + validation_response = {"user_id": 123} + + # Call the function + actual = store_in_dynamodb(token_response, validation_response) + + # Assert that the correct error code was raised + assert actual.get("statusCode") == 500 + assert actual.get("body") == '"Error: Unknown"' + + +@pytest.fixture +def mock_get_parameter(mocker): + return mocker.patch("store_oauth_authorize_code.get_parameter", autospec=True) + + +def test_get_secret_success(mock_get_parameter): + # Simulate the behavior of get_parameter + mock_get_parameter.return_value = "some_value" + + parameter = "test_parameter" + + # Call the function + result = get_secret(parameter) + + # Ensure get_parameter was called + mock_get_parameter.assert_called_once_with(parameter) + + # Assert that the result is what we expect + assert result == "some_value" + + +def test_get_secret_failure(mock_get_parameter): + # Simulate the behavior of get_parameter raising an exception + mock_get_parameter.side_effect = Exception("Error retrieving parameter") + + parameter = "test_parameter" + + # Call the function and ensure it raises an exception + with pytest.raises(Exception): + get_secret(parameter) + + # Ensure get_parameter was called + mock_get_parameter.assert_called_once_with(parameter) + + +@mock_aws +def test_lambda_handler_success(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + mock_get_parameter = mocker.patch("store_oauth_authorize_code.get_parameter") + mock_get_parameter.return_value = "redirect_uri" + + mock_token_response = mocker.Mock() + mock_token_response.status_code = 200 + mock_response = mocker.patch("requests.post") + mock_response.return_value = mock_token_response + + mock_json_loads = mocker.patch("json.loads") + mock_json_loads.return_value = { + "access_token": "jostpf5q0uzmxmkba9iyug38kjtgh", + "expires_in": 5011271, + "token_type": "bearer", + } + + mock_validate_token = mocker.patch("store_oauth_authorize_code.validate_token") + mock_validate_token.return_value = True, json.dumps( + { + "client_id": "wbmytr93xzw8zbg0p1izqyzzc5mbiz", + "login": "twitchdev", + "scopes": ["channel:read:subscriptions"], + "user_id": "141981764", + "expires_in": 5520838, + } + ) + + mock_store_in_dynamodb = mocker.patch( + "store_oauth_authorize_code.store_in_dynamodb" + ) + mock_store_in_dynamodb.return_value = None + + event_in = { + "body": "eyJ0ZXN0IjoiYm9keSJ9", + "resource": "/{proxy+}", + "path": "/path/to/resource", + "httpMethod": "POST", + "isBase64Encoded": True, + "queryStringParameters": { + "code": "jath2p663cpl35wikfhd2d1qds5t4x", + "state": "875992093", + "scope": "user%3Aread%3Achat+user%3Awrite%3Achat+moderator%3Aread%3Asuspicious_users+moderator%3Aread%3A" + "chatters+user%3Amanage%3Achat_color+moderator%3Amanage%3Achat_messages+moderator%3Amanage%3A" + "chat_settings+moderator%3Aread%3Achat_settings+chat%3Aread+chat%3Aedit+user%3Aread%3A" + "email+user%3Aedit%3Abroadcast+user%3Aread%3Abroadcast+clips%3Aedit+bits%3Aread+channel%3A" + "moderate+channel%3Aread%3Asubscriptions+whispers%3Aread+whispers%3Aedit+moderation%3A" + "read+channel%3Aread%3Aredemptions+channel%3Aedit%3Acommercial+channel%3Aread%3A" + "hype_train+channel%3Amanage%3Abroadcast+user%3Aedit%3Afollows+channel%3Amanage%3A" + "redemptions+user%3Aread%3Ablocked_users+user%3Amanage%3Ablocked_users+user%3Aread%3A" + "subscriptions+user%3Aread%3Afollows+channel%3Amanage%3Apolls+channel%3Amanage%3A" + "predictions+channel%3Aread%3Apolls+channel%3Aread%3Apredictions+moderator%3Amanage%3A" + "automod+channel%3Aread%3Agoals+moderator%3Aread%3Aautomod_settings+moderator%3Amanage%3A" + "banned_users+moderator%3Aread%3Ablocked_terms+moderator%3Amanage%3Ablocked_terms+channel%3A" + "manage%3Araids+moderator%3Amanage%3Aannouncements+channel%3Aread%3Avips+channel%3Amanage%3A" + "vips+user%3Amanage%3Awhispers+channel%3Aread%3Acharity+moderator%3Aread%3A" + "shield_mode+moderator%3Amanage%3Ashield_mode+moderator%3Aread%3Ashoutouts+moderator%3A" + "manage%3Ashoutouts+moderator%3Aread%3Afollowers+channel%3Aread%3Aguest_star+channel%3A" + "manage%3Aguest_star+moderator%3Aread%3Aguest_star+moderator%3Amanage%3Aguest_star+channel%3A" + "bot+user%3Abot+channel%3Aread%3Aads+user%3Aread%3Amoderated_channels+user%3Aread%3A" + "emotes+moderator%3Aread%3Aunban_requests+moderator%3Amanage%3Aunban_requests+channel%3A" + "read%3Aeditors+analytics%3Aread%3Agames+analytics%3Aread%3Aextensions", + }, + "multiValueQueryStringParameters": {"foo": ["bar"]}, + "pathParameters": {"proxy": "/path/to/resource"}, + "stageVariables": {"baz": "qux"}, + "headers": { + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8", + "Accept-Encoding": "gzip, deflate, sdch", + "Accept-Language": "en-US,en;q=0.8", + "Cache-Control": "max-age=0", + "CloudFront-Forwarded-Proto": "https", + "CloudFront-Is-Desktop-Viewer": "true", + "CloudFront-Is-Mobile-Viewer": "false", + "CloudFront-Is-SmartTV-Viewer": "false", + "CloudFront-Is-Tablet-Viewer": "false", + "CloudFront-Viewer-Country": "US", + "Host": "1234567890.execute-api.us-east-1.amazonaws.com", + "Upgrade-Insecure-Requests": "1", + "User-Agent": "Custom User Agent String", + "Via": "1.1 08f323deadbeefa7af34d5feb414ce27.cloudfront.net (CloudFront)", + "X-Amz-Cf-Id": "cDehVQoZnx43VYQb9j2-nvCh-9z396Uhbp027Y2JvkCPNLmGJHqlaA==", + "X-Forwarded-For": "127.0.0.1, 127.0.0.2", + "X-Forwarded-Port": "443", + "X-Forwarded-Proto": "https", + }, + "multiValueHeaders": { + "Accept": [ + "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8" + ], + "Accept-Encoding": ["gzip, deflate, sdch"], + "Accept-Language": ["en-US,en;q=0.8"], + "Cache-Control": ["max-age=0"], + "CloudFront-Forwarded-Proto": ["https"], + "CloudFront-Is-Desktop-Viewer": ["true"], + "CloudFront-Is-Mobile-Viewer": ["false"], + "CloudFront-Is-SmartTV-Viewer": ["false"], + "CloudFront-Is-Tablet-Viewer": ["false"], + "CloudFront-Viewer-Country": ["US"], + "Host": ["0123456789.execute-api.us-east-1.amazonaws.com"], + "Upgrade-Insecure-Requests": ["1"], + "User-Agent": ["Custom User Agent String"], + "Via": ["1.1 08f323deadbeefa7af34d5feb414ce27.cloudfront.net (CloudFront)"], + "X-Amz-Cf-Id": ["cDehVQoZnx43VYQb9j2-nvCh-9z396Uhbp027Y2JvkCPNLmGJHqlaA=="], + "X-Forwarded-For": ["127.0.0.1, 127.0.0.2"], + "X-Forwarded-Port": ["443"], + "X-Forwarded-Proto": ["https"], + }, + "requestContext": { + "accountId": "123456789012", + "resourceId": "123456", + "stage": "prod", + "requestId": "c6af9ac6-7b61-11e6-9a41-93e8deadbeef", + "requestTime": "09/Apr/2015:12:34:56 +0000", + "requestTimeEpoch": 1428582896000, + "identity": { + "cognitoIdentityPoolId": None, + "accountId": None, + "cognitoIdentityId": None, + "caller": None, + "accessKey": None, + "sourceIp": "127.0.0.1", + "cognitoAuthenticationType": None, + "cognitoAuthenticationProvider": None, + "userArn": None, + "userAgent": "Custom User Agent String", + "user": None, + }, + "path": "/prod/path/to/resource", + "resourcePath": "/{proxy+}", + "httpMethod": "POST", + "apiId": "1234567890", + "protocol": "HTTP/1.1", + }, + } + + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + + assert actual.get("statusCode") == 200 + assert actual.get("body") == ( + '{"access_token": "jostpf5q0uzmxmkba9iyug38kjtgh", ' + '"expires_in": 5011271, ' + '"token_type": "bearer"}' + ) + + +def test_lambda_handler_code_is_none(): + event_in = {"queryStringParameters": {"code": None}} + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 400 + assert json.loads(actual.get("body")).get("message") == "Code parameter missing" + + +def test_lambda_handler_get_secret_parameter_not_found(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = ParameterNotFound(message="message") + + event_in = {"queryStringParameters": {"code": "code"}} + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 500 + assert json.loads(actual.get("body")).get("message") == ( + 'Error retrieving secret: 400 Bad Request: {"__type": "ParameterNotFound", ' + '"message": "message"}' + ) + + +def test_lambda_handler_get_secret_exception(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = Exception + + event_in = {"queryStringParameters": {"code": "code"}} + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 500 + assert json.loads(actual.get("body")).get("message") == ( + "Error retrieving secret: " + ) + + +def test_lambda_handler_get_parameter_exception(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + mock_get_parameter = mocker.patch("store_oauth_authorize_code.get_parameter") + mock_get_parameter.side_effect = Exception + + event_in = {"queryStringParameters": {"code": "code"}} + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 500 + assert json.loads(actual.get("body")).get("message") == ( + "Error retrieving redirect uri: " + ) + + +def test_lambda_handler_invalid_token_response(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + mock_get_parameter = mocker.patch("store_oauth_authorize_code.get_parameter") + mock_get_parameter.return_value = "redirect_uri" + + mock_token_response = mocker.Mock() + mock_token_response.status_code = 200 + mock_json = mocker.Mock() + mock_token_response.json.return_value = mock_json + mock_json.decode.return_value = json.dumps( + { + "access_token": "jostpf5q0uzmxmkba9iyug38kjtgh", + "expires_in": 5011271, + "token_type": "bearer", + } + ) + mock_response = mocker.patch("requests.post") + mock_response.return_value = mock_token_response + + mock_validate_token = mocker.patch("store_oauth_authorize_code.validate_token") + mock_validate_token.return_value = False, json.dumps( + {"status": 401, "message": "invalid access token"} + ) + + event_in = {"queryStringParameters": {"code": "code"}} + actual_result = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual_result.get("statusCode") == 401 + actual_message = json.loads(actual_result.get("body")).get("message") + assert actual_message == "Token is not valid" + + +def test_lambda_handler_failed_token_response(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + mock_get_parameter = mocker.patch("store_oauth_authorize_code.get_parameter") + mock_get_parameter.return_value = "redirect_uri" + + mock_token_response = mocker.Mock() + mock_token_response.status_code = 500 + mock_json = mocker.Mock() + mock_token_response.json.return_value = mock_json + mock_json.decode.return_value = json.dumps({}) + mock_response = mocker.patch("requests.post") + mock_response.return_value = mock_token_response + + event_in = {"queryStringParameters": {"code": "code"}} + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 500 + assert json.loads(actual.get("body")).get("message") == ( + "Failed to retrieve access token" + ) + + +def test_lambda_handler_http_error_response(mocker): + mock_get_secret = mocker.patch("store_oauth_authorize_code.get_secret") + mock_get_secret.side_effect = ["client_id", "client_secret"] + + mock_get_parameter = mocker.patch("store_oauth_authorize_code.get_parameter") + mock_get_parameter.return_value = "redirect_uri" + + mocker.patch("requests.post", side_effect=Exception("Mocked exception")) + + event_in = {"queryStringParameters": {"code": "code"}} + actual = store_oauth_authorize_code_handler(event=event_in, _context={}) + assert actual.get("statusCode") == 500 + assert json.loads(actual.get("body")).get("error") == "Mocked exception"