diff --git a/.circleci/config.yml b/.circleci/config.yml index 1cef8e9..4ff942a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -81,8 +81,7 @@ jobs: enum: [ development, staging, production ] environment: SAM_CONFIG_FILE: samconfig.toml - SAM_LAMBDA_CONFIG_ENV: <> - SAM_PUBLIC_CONFIG_ENV: <>-public-access + DC_ENVIRONMENT: <> steps: - checkout - attach_workspace: @@ -98,7 +97,7 @@ jobs: command: | pipenv run sam deploy ${DASH_DASH_DEBUG} \ --config-file ~/repo/${SAM_CONFIG_FILE} \ - --config-env $SAM_LAMBDA_CONFIG_ENV \ + --config-env $DC_ENVIRONMENT \ --template-file ~/repo/.aws-sam/build/template.yaml \ --parameter-overrides "GitHash='$CIRCLE_SHA1'" - aws-cli/setup @@ -167,6 +166,7 @@ jobs: name: "Post deploy tests" command: | export FQDN=`aws ssm get-parameter --query Parameter.Value --name 'FQDN' --output text` + export DC_ENVIRONMENT=`aws ssm get-parameter --query Parameter.Value --name 'DC_ENVIRONMENT' --output text` pipenv run pytest .circleci/tests/ diff --git a/.circleci/tests/test_against_fqdn.py b/.circleci/tests/test_against_fqdn.py index e567bd0..92a4b42 100644 --- a/.circleci/tests/test_against_fqdn.py +++ b/.circleci/tests/test_against_fqdn.py @@ -9,7 +9,17 @@ def make_request(url): - req = httpx.get(url, timeout=20) + dc_env = os.environ.get("DC_ENVIRONMENT") + + if dc_env == "development" or dc_env == "staging": + req = httpx.get( + url, + timeout=20, + auth=("dc", "dc"), + ) + else: + req = httpx.get(url, timeout=20) + req.raise_for_status() return req diff --git a/postcode_lookup/app.py b/postcode_lookup/app.py index e63a4e5..1d97019 100644 --- a/postcode_lookup/app.py +++ b/postcode_lookup/app.py @@ -15,6 +15,7 @@ section_tester, ) from mangum import Mangum +from middleware import BasicAuthMiddleware from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.requests import HTTPConnection @@ -148,6 +149,11 @@ def current_language_selector(conn: HTTPConnection) -> str | None: return conn.scope["current_language"] +# This function is used to enable basic auth in development and staging environments +def enable_auth(): + return os.environ.get("DC_ENVIRONMENT") in ["development", "staging"] + + app = Starlette( debug=True, routes=routes, @@ -160,6 +166,7 @@ def current_language_selector(conn: HTTPConnection) -> str | None: selectors=[current_language_selector], ), Middleware(ForwardedForMiddleware), + Middleware(BasicAuthMiddleware, enable_auth=enable_auth()), ], ) diff --git a/postcode_lookup/middleware.py b/postcode_lookup/middleware.py new file mode 100644 index 0000000..b729e4b --- /dev/null +++ b/postcode_lookup/middleware.py @@ -0,0 +1,39 @@ +from starlette.requests import HTTPConnection +from starlette.responses import Response +from starlette.types import ASGIApp, Receive, Scope, Send + + +class BasicAuthMiddleware: + def __init__(self, app: ASGIApp, enable_auth: bool) -> None: + self.enable_auth = enable_auth + self.app = app + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + if scope["type"] not in ["http", "websocket"]: + await self.app(scope, receive, send) + return + + conn = HTTPConnection(scope) + + if not self.enable_auth: + await self.app(scope, receive, send) + return + + required_auth_header = "Basic ZGM6ZGM=" # "dc:dc" in base64 + + # Check for authorization header: + auth_header = conn.headers.get("Authorization") + if auth_header and auth_header == required_auth_header: + await self.app(scope, receive, send) + return + + # If authorization fails, return 401 Unauthorized and prompt for credentials + response = Response( + "Unauthorized", + status_code=401, + headers={"WWW-Authenticate": 'Basic realm="Restricted"'}, + ) + await response(scope, receive, send) + return diff --git a/template.yaml b/template.yaml index 738cb36..2121234 100644 --- a/template.yaml +++ b/template.yaml @@ -29,11 +29,17 @@ Parameters: Default: AppAPIKey Description: "The DC aggregator API key" Type: AWS::SSM::Parameter::Value + AppSentryDSN: Default: SENTRY_DSN Description: "The Sentry DSN" Type: AWS::SSM::Parameter::Value + DCEnvironment: + Default: DC_ENVIRONMENT + Description: "The DC_ENVIRONMENT environment variable passed to the app." + Type: AWS::SSM::Parameter::Value + Resources: ECDeployerRole: Type: AWS::IAM::Role @@ -64,6 +70,7 @@ Resources: FQDN: !Ref FQDN API_KEY: !Ref AppAPIKey SENTRY_DSN: !Ref AppSentryDSN + DC_ENVIRONMENT: !Ref DCEnvironment Events: HTTPRequests: Type: Api diff --git a/tests/test_basic_auth.py b/tests/test_basic_auth.py new file mode 100644 index 0000000..bf4aba8 --- /dev/null +++ b/tests/test_basic_auth.py @@ -0,0 +1,20 @@ +import importlib + +import app +import pytest +from starlette.testclient import TestClient + + +@pytest.mark.parametrize( + "environment,status_code", + [("development", 401), ("staging", 401), ("production", 200)], +) +def test_basic_auth(monkeypatch, environment, status_code): + monkeypatch.setenv("DC_ENVIRONMENT", environment) + # We reload the app here because the original import happens before the monkeypatched DC_ENVIRONMENT + # variable and therefore it wouldn't see the new env variable when its called in the TestClient's instantiation + importlib.reload(app) + + with TestClient(app=app.app) as client: + resp = client.get("/") + assert resp.status_code == status_code diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 0000000..ecffae7 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,50 @@ +from middleware import BasicAuthMiddleware +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse +from starlette.routing import Route +from starlette.testclient import TestClient + + +def create_app(enable_auth): + async def homepage(request): + return PlainTextResponse("Hello, world!") + + routes = [Route("/", endpoint=homepage)] + app = Starlette(routes=routes) + + app.add_middleware(BasicAuthMiddleware, enable_auth=enable_auth) + return app + + +def test_no_auth_required(): + app = create_app(enable_auth=False) + client = TestClient(app) + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_auth_required_success(): + app = create_app(enable_auth=True) + client = TestClient(app) + response = client.get("/", headers={"Authorization": "Basic ZGM6ZGM="}) + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_auth_required_failure(): + app = create_app(enable_auth=True) + client = TestClient(app) + response = client.get("/") + assert response.status_code == 401 + assert response.headers["www-authenticate"] == 'Basic realm="Restricted"' + assert response.text == "Unauthorized" + + +def test_auth_required_invalid_credentials(): + app = create_app(enable_auth=True) + client = TestClient(app) + response = client.get("/", headers={"Authorization": "Basic invalid"}) + assert response.status_code == 401 + assert response.headers["www-authenticate"] == 'Basic realm="Restricted"' + assert response.text == "Unauthorized"