Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic auth middleware #103

Merged
merged 6 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:
enum: [ development, staging, production ]
environment:
SAM_CONFIG_FILE: samconfig.toml
SAM_LAMBDA_CONFIG_ENV: <<parameters.dc-environment>>
DC_ENVIRONMENT: <<parameters.dc-environment>>
SAM_PUBLIC_CONFIG_ENV: <<parameters.dc-environment>>-public-access
awdem marked this conversation as resolved.
Show resolved Hide resolved
steps:
- checkout
Expand All @@ -98,7 +98,7 @@ jobs:
command: |
pipenv run sam deploy ${DASH_DASH_DEBUG} \
--config-file ~/repo/${SAM_CONFIG_FILE} \
--config-env $SAM_LAMBDA_CONFIG_ENV \
--config-env $DC_ENVIRONMENT \
--template-file ~/repo/.aws-sam/build/template.yaml \
--parameter-overrides "GitHash='$CIRCLE_SHA1'"
- aws-cli/setup
Expand Down Expand Up @@ -167,6 +167,7 @@ jobs:
name: "Post deploy tests"
command: |
export FQDN=`aws ssm get-parameter --query Parameter.Value --name 'FQDN' --output text`
export DC_ENVIRONMENT=`aws ssm get-parameter --query Parameter.Value --name 'DC_ENVIRONMENT' --output text`
pipenv run pytest .circleci/tests/


Expand Down
12 changes: 11 additions & 1 deletion .circleci/tests/test_against_fqdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@


def make_request(url):
req = httpx.get(url, timeout=20)
dc_env = os.environ.get("DC_ENVIRONMENT")

if dc_env == "development" or dc_env == "staging":
req = httpx.get(
url,
timeout=20,
auth=("dc", "dc"),
)
else:
req = httpx.get(url, timeout=20)

req.raise_for_status()
return req

Expand Down
7 changes: 7 additions & 0 deletions postcode_lookup/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
section_tester,
)
from mangum import Mangum
from middleware import BasicAuthMiddleware
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import HTTPConnection
Expand Down Expand Up @@ -148,6 +149,11 @@ def current_language_selector(conn: HTTPConnection) -> str | None:
return conn.scope["current_language"]


# This function is used to enable basic auth in development and staging environments
def enable_auth():
return os.environ.get("DC_ENVIRONMENT") in ["development", "staging"]


app = Starlette(
debug=True,
routes=routes,
Expand All @@ -160,6 +166,7 @@ def current_language_selector(conn: HTTPConnection) -> str | None:
selectors=[current_language_selector],
),
Middleware(ForwardedForMiddleware),
Middleware(BasicAuthMiddleware, enable_auth=enable_auth),
],
)

Expand Down
39 changes: 39 additions & 0 deletions postcode_lookup/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import ASGIApp, Receive, Scope, Send


class BasicAuthMiddleware:
def __init__(self, app: ASGIApp, enable_auth: callable) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason for passing enable_auth() in as a function here?

Seems like we could either

Make enable_auth a boolean param and call it when we add the middleware i.e: Middleware(BasicAuthMiddleware, enable_auth=enable_auth())

or

Remove the enable_auth param and just inspect the env vars in BasicAuthMiddleware.__init__ i.e:

import os

class BasicAuthMiddleware:
    def __init__(self, app: ASGIApp) -> None:
        self.enable_auth = os.environ.get("DC_ENVIRONMENT") in ["development", "staging"]
        self.app = app

which is basically what we're doing in the equivalent middleware in dc_django_utils
https://github.com/DemocracyClub/dc_django_utils/blob/78b20cf5e955c2994f49fd3d5db3fdad8ee5beba/dc_utils/middleware.py#L8-L15

Personally I'd find one or other of those two options clearer. The callable seems like a layer of indirection we don't really need.

The first one (boolean param) is probably easier as it would require fewer changes to the tests in tests/test_middleware.py as you can just replace mock_auth_enabled() and mock_auth_disabled() with True and False.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had originally written in the first way you've proposed as a boolean param, but after discussion with Sym I changed it to callable. I'm not sure if I'll be able to articulate the reasoning clearly, but I'll try.

It had to do with when enable_auth() is resolved. When it's passed as a boolean, the function is resolving when settings.py is imported and passing its return value then. I think the issue with that is, when that import happens, the DC_ENVIRONMENT variable is not set yet so the function doesn't detect the environment properly.

Passing it as a callable means that it will be called each time the middleware is checking for auth which means it should have access to the env variable. I think I'm making a mistake, maybe @symroe can clarify?

I do think that 2nd proposal addresses the same issue and is maybe a better approach, though.

self.enable_auth = enable_auth()
self.app = app

async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
if scope["type"] not in ["http", "websocket"]:
await self.app(scope, receive, send)
return

conn = HTTPConnection(scope)

if not self.enable_auth:
await self.app(scope, receive, send)
return

required_auth_header = "Basic ZGM6ZGM=" # "dc:dc" in base64

# Check for authorization header:
auth_header = conn.headers.get("Authorization")
if auth_header and auth_header == required_auth_header:
await self.app(scope, receive, send)
return

# If authorization fails, return 401 Unauthorized and prompt for credentials
response = Response(
"Unauthorized",
status_code=401,
headers={"WWW-Authenticate": 'Basic realm="Restricted"'},
)
await response(scope, receive, send)
return
7 changes: 7 additions & 0 deletions template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ Parameters:
Default: AppAPIKey
Description: "The DC aggregator API key"
Type: AWS::SSM::Parameter::Value<String>

AppSentryDSN:
Default: SENTRY_DSN
Description: "The Sentry DSN"
Type: AWS::SSM::Parameter::Value<String>

DCEnvironment:
Default: DC_ENVIRONMENT
Description: "The DC_ENVIRONMENT environment variable passed to the app."
Type: AWS::SSM::Parameter::Value<String>

Resources:
ECDeployerRole:
Type: AWS::IAM::Role
Expand Down Expand Up @@ -64,6 +70,7 @@ Resources:
FQDN: !Ref FQDN
API_KEY: !Ref AppAPIKey
SENTRY_DSN: !Ref AppSentryDSN
DC_ENVIRONMENT: !Ref DCEnvironment
Events:
HTTPRequests:
Type: Api
Expand Down
20 changes: 20 additions & 0 deletions tests/test_basic_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import importlib

import app
import pytest
from starlette.testclient import TestClient


@pytest.mark.parametrize(
"environment,status_code",
[("development", 401), ("staging", 401), ("production", 200)],
)
def test_basic_auth(monkeypatch, environment, status_code):
monkeypatch.setenv("DC_ENVIRONMENT", environment)
# We reload the app here because the original import happens before the monkeypatched DC_ENVIRONMENT
# variable and therefore it wouldn't see the new env variable when its called in the TestClient's instantiation
importlib.reload(app)

with TestClient(app=app.app) as client:
resp = client.get("/")
assert resp.status_code == status_code
58 changes: 58 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from middleware import BasicAuthMiddleware
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse
from starlette.routing import Route
from starlette.testclient import TestClient


def mock_auth_enabled():
return True


def mock_auth_disabled():
return False


def create_app(enable_auth):
async def homepage(request):
return PlainTextResponse("Hello, world!")

routes = [Route("/", endpoint=homepage)]
app = Starlette(routes=routes)

app.add_middleware(BasicAuthMiddleware, enable_auth=enable_auth)
return app


def test_no_auth_required():
app = create_app(mock_auth_disabled)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world!"


def test_auth_required_success():
app = create_app(mock_auth_enabled)
client = TestClient(app)
response = client.get("/", headers={"Authorization": "Basic ZGM6ZGM="})
assert response.status_code == 200
assert response.text == "Hello, world!"


def test_auth_required_failure():
app = create_app(mock_auth_enabled)
client = TestClient(app)
response = client.get("/")
assert response.status_code == 401
assert response.headers["www-authenticate"] == 'Basic realm="Restricted"'
assert response.text == "Unauthorized"


def test_auth_required_invalid_credentials():
app = create_app(mock_auth_enabled)
client = TestClient(app)
response = client.get("/", headers={"Authorization": "Basic invalid"})
assert response.status_code == 401
assert response.headers["www-authenticate"] == 'Basic realm="Restricted"'
assert response.text == "Unauthorized"
Loading