diff --git a/.ci-scripts/create_user.py b/.ci-scripts/create_user.py deleted file mode 100755 index 1b263b8b..00000000 --- a/.ci-scripts/create_user.py +++ /dev/null @@ -1,11 +0,0 @@ -import django - -django.setup() - -from scram.users.models import User # noqa:E402 - -u, created = User.objects.get_or_create(username="admin") -u.set_password("password") -u.is_staff = True -u.is_superuser = True -u.save() diff --git a/.github/workflows/behave.yml b/.github/workflows/behave.yml new file mode 100644 index 00000000..34fb68f0 --- /dev/null +++ b/.github/workflows/behave.yml @@ -0,0 +1,91 @@ +--- +name: Run behave + +on: + push: + branches: + - "**" + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + behave: + name: Run Behave + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ["3.11", "3.12"] + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Docker + uses: docker/setup-buildx-action@v3 + + - name: Install Docker Compose + run: | + sudo apt-get update + sudo apt-get install -y docker-compose make + + - name: Check Docker state (pre-build) + run: docker ps + + - name: Build Docker images + run: make build + env: + PYTHON_IMAGE_VER: "${{ matrix.python-version }}" + + - name: Migrate Database + run: make migrate + + - name: Run Application + run: make run + + - name: Check Docker state (pre-test) + run: docker ps + + - name: Run pytest + behave with Coverage + env: + POSTGRES_USER: scram + POSTGRES_DB: test_scram + run: make coverage.xml + + - name: Dump docker logs on failure + if: failure() + uses: jwalton/gh-docker-logs@v2 + + - name: Upload Coverage to Coveralls + if: matrix.python-version == '3.12' + uses: coverallsapp/github-action@v2 + + - name: Upload Coverage to GitHub + if: matrix.python-version == '3.12' + uses: actions/upload-artifact@v4 + with: + name: coverage-report + path: coverage.xml + + - 
name: Display Coverage Metrics + if: matrix.python-version == '3.12' + uses: 5monkeys/cobertura-action@v14 + with: + minimum_coverage: "50" + + - name: Check Docker state (post-test) + if: always() + run: docker ps + + - name: Stop Services + if: always() + run: make stop + + - name: Clean Up + if: always() + run: make clean diff --git a/.github/workflows/behave_next_python.yml b/.github/workflows/behave_next_python.yml new file mode 100644 index 00000000..1c4b1a3c --- /dev/null +++ b/.github/workflows/behave_next_python.yml @@ -0,0 +1,77 @@ +--- +name: Run behave with unsupported Python versions + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + behave_next_python: + name: Run Behave + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ['3.13'] + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Docker + uses: docker/setup-buildx-action@v3 + + - name: Install Docker Compose + run: | + sudo apt-get update + sudo apt-get install -y docker-compose make + + - name: Check Docker state (pre-build) + run: docker ps + + - name: Build Docker images + run: make build + env: + PYTHON_IMAGE_VER: "${{ matrix.python-version }}" + + - name: Migrate Database + run: | + make migrate || echo "::warning:: migrate failed on future Python version ${{ matrix.python-version }}." + + - name: Run Application + run: | + make run || echo "::warning:: run failed on future Python version ${{ matrix.python-version }}." + + - name: Check Docker state (pre-test) + run: docker ps + + - name: Run pytest + behave with Coverage + env: + POSTGRES_USER: scram + POSTGRES_DB: test_scram + run: | + make coverage.xml || echo "::warning:: pytest + behave failed on future Python version ${{ matrix.python-version }}." 
+ + - name: Dump docker logs on failure + if: failure() + uses: jwalton/gh-docker-logs@v2 + + - name: Check Docker state (post-test) + if: always() + run: docker ps + + - name: Stop Services + if: always() + run: make stop + + - name: Clean Up + if: always() + run: make clean diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d3fbb051..223f58d5 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -45,7 +45,7 @@ jobs: uses: actions/upload-pages-artifact@v3 with: # Upload entire repository - path: 'docs/_build/html' + path: 'site/' - name: Deploy to GitHub Pages if: github.ref == 'refs/heads/main' diff --git a/.github/workflows/flake8.yml b/.github/workflows/flake8.yml deleted file mode 100644 index 887ebc50..00000000 --- a/.github/workflows/flake8.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Run flake8 - -on: - push: - branches: - - '**' - pull_request: - branches: - - main - - develop - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -jobs: - flake8: - runs-on: ubuntu-latest - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Cache Docker images. 
- uses: ScribeMD/docker-cache@0.3.7 - with: - key: docker-${{ runner.os }}-${{ hashFiles('docker-compose.yaml') }} - - - name: Install dependencies - run: | - pip install flake8 flake8-docstrings - - - name: Run flake8 - run: | - flake8 scram translator diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index e389d697..5d4004d2 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -17,6 +17,10 @@ jobs: pytest: name: Run Pytest runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ['3.11', '3.12'] services: postgres: @@ -24,7 +28,7 @@ jobs: env: POSTGRES_USER: scram POSTGRES_PASSWORD: '' - POSTGRES_DB: test_scram + POSTGRES_DB: test_scram_${{ matrix.python-version }} POSTGRES_HOST_AUTH_METHOD: trust ports: - 5432:5432 @@ -34,63 +38,45 @@ jobs: --health-timeout 5s --health-retries 5 + redis: + image: redis:5.0 + ports: + - 6379:6379 + steps: - name: Check out the code uses: actions/checkout@v4 - - name: Cache Docker images. 
- uses: ScribeMD/docker-cache@0.3.7 + - uses: actions/setup-python@v5 with: - key: docker-${{ runner.os }}-${{ hashFiles('docker-compose.yaml') }} + python-version: ${{ matrix.python-version }} - - name: Cache Docker layers - uses: actions/cache@v4 - with: - path: /tmp/.buildx-cache - key: ${{ runner.os }}-single-buildx-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-single-buildx - - - name: Set up Docker - uses: docker/setup-buildx-action@v3 - - name: Install Docker Compose + - name: Install dependencies run: | - sudo apt-get update - sudo apt-get install -y docker-compose make - - - name: Build Docker images - run: make build + python -m pip install --upgrade pip + pip install -r requirements/local.txt + # https://github.com/pytest-dev/pytest-github-actions-annotate-failures/pull/68 isn't yet in a release + pip install git+https://github.com/pytest-dev/pytest-github-actions-annotate-failures.git@6e66cd895fe05cd09be8bad58f5d79110a20385f - - name: Migrate Database - run: make migrate - - - name: Run Application - run: make run - - - name: Run Pytest with Coverage + - name: Apply migrations env: - POSTGRES_USER: scram - POSTGRES_DB: test_scram - run: make coverage.xml - - - name: Upload Coverage to Coveralls - uses: coverallsapp/github-action@v2 - - - name: Upload Coverage to GitHub - uses: actions/upload-artifact@v4 - with: - name: coverage-report - path: coverage.xml - - - name: Display Coverage Metrics - uses: 5monkeys/cobertura-action@v14 - with: - minimum_coverage: '50' + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram_${{ matrix.python-version }}" + run: | + python manage.py makemigrations --noinput + python manage.py migrate --noinput - - name: Stop Services - if: always() - run: make stop + - name: Check for duplicate migrations + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram_${{ matrix.python-version }}" + run: | + if python manage.py makemigrations --dry-run | grep "No changes detected"; then + echo "No duplicate 
migrations detected." + else + echo "::warning:: Potential duplicate migrations detected. Please review." + fi - - name: Clean Up - if: always() - run: make clean + - name: Run Pytest + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram_${{ matrix.python-version }}" + REDIS_HOST: "localhost" + run: pytest diff --git a/.github/workflows/pytest_next_python.yml b/.github/workflows/pytest_next_python.yml new file mode 100644 index 00000000..dbc9d4e1 --- /dev/null +++ b/.github/workflows/pytest_next_python.yml @@ -0,0 +1,88 @@ +--- +name: Run pytest with unsupported Python versions + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + pytest_next_python: + name: Run Pytest + runs-on: ubuntu-latest + strategy: + max-parallel: 4 + matrix: + python-version: ['3.13'] + + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: scram + POSTGRES_PASSWORD: '' + POSTGRES_DB: test_scram + POSTGRES_HOST_AUTH_METHOD: trust + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U scram" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + redis: + image: redis:5.0 + ports: + - 6379:6379 + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/local.txt + # https://github.com/pytest-dev/pytest-github-actions-annotate-failures/pull/68 isn't yet in a release + pip install git+https://github.com/pytest-dev/pytest-github-actions-annotate-failures.git@6e66cd895fe05cd09be8bad58f5d79110a20385f + + - name: Apply unapplied migrations + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram" + run: | + python manage.py makemigrations --noinput || true + UNAPPLIED_MIGRATIONS=$(python 
manage.py showmigrations --plan | grep '\[ \]' | awk '{print $2}') + if [ -n "$UNAPPLIED_MIGRATIONS" ]; then + for migration in $UNAPPLIED_MIGRATIONS; do + python manage.py migrate $migration --fake-initial --noinput + done + else + echo "No unapplied migrations." + fi + + - name: Check for duplicate migrations + run: | + if python manage.py makemigrations --dry-run | grep "No changes detected"; then + echo "No duplicate migrations detected." + else + echo "Warning: Potential duplicate migrations detected. Please review." + fi + + - name: Run Pytest + env: + DATABASE_URL: "postgres://scram:@localhost:5432/test_scram" + REDIS_HOST: "localhost" + run: | + pytest || echo "::warning:: Failed on future Python version ${{ matrix.python-version }}." diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 00000000..31647732 --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,29 @@ +name: Run ruff + +on: + push: + branches: + - '**' + pull_request: + branches: + - main + - develop + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: pip install ruff + + - name: Fail if we have any style errors + run: ruff check --output-format=github + + - name: Fail if the code is not formatted correctly (like Black) + run: ruff format --diff diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 48d2a7b2..08d8da08 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: 'docs|node_modules|migrations|.git|.tox' +exclude: 'docs|migrations|.git|.tox' default_stages: [commit] fail_fast: true @@ -9,17 +9,11 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - - repo: https://github.com/psf/black - rev: 23.9.1 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.0 hooks: - - id: black - - - repo: 
https://github.com/timothycrosley/isort - rev: 5.12.0 - hooks: - - id: isort - - - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 - hooks: - - id: flake8 + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/Makefile b/Makefile index 26b6f3a0..347eb58c 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,7 @@ +# It'd be nice to keep these in sync with the defaults of the Dockerfiles +PYTHON_IMAGE_VER ?= 3.12 +POSTGRES_IMAGE_VER ?= 12.3 + .DEFAULT_GOAL := help ## toggle-prod: configure make to use the production stack @@ -37,7 +41,8 @@ behave-translator: compose.override.yml ## build: rebuilds all your containers or a single one if CONTAINER is specified .Phony: build build: compose.override.yml - @docker compose up -d --no-deps --build $(CONTAINER) + @docker compose build --build-arg PYTHON_IMAGE_VER=$(PYTHON_IMAGE_VER) --build-arg POSTGRES_IMAGE_VER=$(POSTGRES_IMAGE_VER) $(CONTAINER) + @docker compose up -d --no-deps $(CONTAINER) @docker compose restart $(CONTAINER) ## coverage.xml: generate coverage from test runs @@ -138,3 +143,13 @@ tail-log: compose.override.yml .Phony: type-check type-check: compose.override.yml @docker compose run --rm django mypy scram + +## docs-build: build the documentation +.Phony: docs-build +docs-build: + @docker compose run --rm docs mkdocs build + +## docs-serve: build and run a server with the documentation +.Phony: docs-serve +docs-serve: + @docker compose run --rm docs mkdocs serve -a 0.0.0.0:8888 diff --git a/compose.override.local.yml b/compose.override.local.yml index c2075761..9332114a 100644 --- a/compose.override.local.yml +++ b/compose.override.local.yml @@ -44,12 +44,10 @@ services: networks: default: {} volumes: - - $CI_PROJECT_DIR/docs:/docs:z - - $CI_PROJECT_DIR/config:/app/config:z - - $CI_PROJECT_DIR/scram:/app/scram:z + - $CI_PROJECT_DIR:/app:z ports: - - "7000" - command: /start-docs + - "${DOCS_PORT:-8888}" + command: "mkdocs serve -a 0.0.0.0:${DOCS_PORT:-8888}" 
redis: ports: diff --git a/compose/local/django/Dockerfile b/compose/local/django/Dockerfile index 64bd4a6b..aaba2ef9 100644 --- a/compose/local/django/Dockerfile +++ b/compose/local/django/Dockerfile @@ -1,5 +1,8 @@ -FROM python:3.12-slim-bookworm +ARG PYTHON_IMAGE_VER=3.12 +FROM python:${PYTHON_IMAGE_VER}-slim-bookworm + +ENV PIP_ROOT_USER_ACTION ignore ENV PYTHONUNBUFFERED 1 ENV PYTHONDONTWRITEBYTECODE 1 diff --git a/compose/local/django/start b/compose/local/django/start index 2d910375..95f91dff 100644 --- a/compose/local/django/start +++ b/compose/local/django/start @@ -4,5 +4,6 @@ set -o errexit set -o pipefail set -o nounset +mkdir -p /app/staticfiles python manage.py migrate uvicorn config.asgi:application --host 0.0.0.0 --reload diff --git a/compose/local/docs/Dockerfile b/compose/local/docs/Dockerfile index 9ab25f44..49965c75 100644 --- a/compose/local/docs/Dockerfile +++ b/compose/local/docs/Dockerfile @@ -1,5 +1,8 @@ -FROM python:3.8-slim-buster +ARG PYTHON_IMAGE_VER=3.12 +FROM python:${PYTHON_IMAGE_VER}-slim-bookworm + +ENV PIP_ROOT_USER_ACTION ignore ENV PYTHONUNBUFFERED 1 ENV PYTHONDONTWRITEBYTECODE 1 @@ -19,13 +22,11 @@ RUN apt-get update \ && apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false \ && rm -rf /var/lib/apt/lists/* -# Requirements are installed here to ensure they will be cached. -COPY ./requirements /requirements -# All imports needed for autodoc. -RUN pip install -r /requirements/local.txt -r /requirements/production.txt -COPY ./compose/local/docs/start /start-docs -RUN sed -i 's/\r$//g' /start-docs -RUN chmod +x /start-docs +# Only re-run the pip install if these files have changed +COPY requirements/base.txt requirements/local.txt requirements/production.txt /app/requirements/ +RUN pip install -r /app/requirements/local.txt -r /app/requirements/production.txt + +COPY . 
/app/ -WORKDIR /docs +WORKDIR /app diff --git a/compose/local/docs/start b/compose/local/docs/start deleted file mode 100644 index c562c13a..00000000 --- a/compose/local/docs/start +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -set -o errexit -set -o pipefail -set -o nounset - -make apidocs -make livehtml diff --git a/compose/production/django/Dockerfile b/compose/production/django/Dockerfile index cbb5aeb5..84748126 100644 --- a/compose/production/django/Dockerfile +++ b/compose/production/django/Dockerfile @@ -1,5 +1,6 @@ +ARG PYTHON_IMAGE_VER=3.12 -FROM python:3.12-slim-bookwork +FROM python:${PYTHON_IMAGE_VER}-slim-bookworm ENV PYTHONUNBUFFERED 1 diff --git a/compose/production/postgres/Dockerfile b/compose/production/postgres/Dockerfile index c4160f1e..88f86a97 100644 --- a/compose/production/postgres/Dockerfile +++ b/compose/production/postgres/Dockerfile @@ -1,4 +1,6 @@ -FROM postgres:12.3 +ARG POSTGRES_IMAGE_VER=12.3 + +FROM postgres:${POSTGRES_IMAGE_VER} COPY ./compose/production/postgres/maintenance /usr/local/bin/maintenance RUN chmod +x /usr/local/bin/maintenance/* diff --git a/config/__init__.py b/config/__init__.py index e69de29b..69052b7a 100644 --- a/config/__init__.py +++ b/config/__init__.py @@ -0,0 +1 @@ +"""Holds the settings and entrypoints.""" diff --git a/config/api_router.py b/config/api_router.py index a5b77f57..a57a12ba 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -1,3 +1,5 @@ +"""Map the API routes to the views.""" + from rest_framework.routers import DefaultRouter from scram.route_manager.api.views import ActionTypeViewSet, ClientViewSet, EntryViewSet, IgnoreEntryViewSet diff --git a/config/asgi.py b/config/asgi.py index 98a16ce2..5e11d1fa 100644 --- a/config/asgi.py +++ b/config/asgi.py @@ -1,5 +1,4 @@ -""" -ASGI config for SCRAM project. +"""ASGI config for SCRAM project. It exposes the ASGI callable as a module-level variable named ``application``. 
@@ -7,6 +6,7 @@ https://docs.djangoproject.com/en/dev/howto/deployment/asgi/ """ + import logging import os import sys @@ -18,28 +18,30 @@ # TODO: from channels.security.websocket import AllowedHostsOriginValidator from django.core.asgi import get_asgi_application +logger = logging.getLogger(__name__) + # Here we setup a debugger if this is desired. This obviously should not be run in production. debug = os.environ.get("DEBUG") if debug: - logging.info(f"Django is set to use a debugger. Provided debug mode: {debug}") + logger.info("Django is set to use a debugger. Provided debug mode: %s", debug) if debug == "pycharm-pydevd": - logging.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") + logger.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") import pydevd_pycharm pydevd_pycharm.settrace("host.docker.internal", port=56783, stdoutToServer=True, stderrToServer=True) - logging.info("Debugger started.") + logger.info("Debugger started.") elif debug == "debugpy": - logging.info("Entering debug mode for debugpy (VSCode)") + logger.info("Entering debug mode for debugpy (VSCode)") import debugpy - debugpy.listen(("0.0.0.0", 56780)) + debugpy.listen(("0.0.0.0", 56780)) # noqa S104 (doesn't like binding to all interfaces) - logging.info("Debugger listening on port 56780.") + logger.info("Debugger listening on port 56780.") else: - logging.warning(f"Invalid debug mode given: {debug}. Debugger not started") + logger.warning("Invalid debug mode given: %s. Debugger not started", debug) # This allows easy placement of apps within the interior # scram directory. 
@@ -62,5 +64,5 @@ { "http": django_application, "websocket": ws_application, - } + }, ) diff --git a/config/consumers.py b/config/consumers.py index 11cca2ef..cd163bb7 100644 --- a/config/consumers.py +++ b/config/consumers.py @@ -1,14 +1,22 @@ +"""Define logic for the WebSocket consumers.""" + import logging +from functools import partial from asgiref.sync import sync_to_async from channels.generic.websocket import AsyncJsonWebsocketConsumer from scram.route_manager.models import Entry, WebSocketSequenceElement +logger = logging.getLogger(__name__) + class TranslatorConsumer(AsyncJsonWebsocketConsumer): + """Handle messages from the Translator(s).""" + async def connect(self): - logging.info("Translator connected") + """Handle the initial connection with adding to the right group.""" + logger.info("Translator connected") self.actiontype = self.scope["url_route"]["kwargs"]["actiontype"] self.translator_group = f"translator_{self.actiontype}" @@ -17,10 +25,10 @@ async def connect(self): # Filter WebSocketSequenceElements by actiontype elements = await sync_to_async(list)( - WebSocketSequenceElement.objects.filter(action_type__name=self.actiontype).order_by("order_num") + WebSocketSequenceElement.objects.filter(action_type__name=self.actiontype).order_by("order_num"), ) if not elements: - logging.warning(f"No elements found for actiontype={self.actiontype}.") + logger.warning("No elements found for actiontype=%s.", self.actiontype) return # Avoid lazy evaluation @@ -28,16 +36,17 @@ for route in routes: for element in elements: - msg = await sync_to_async(lambda: element.websocketmessage)() + msg = await sync_to_async(partial(getattr, element, "websocketmessage"))() msg.msg_data[msg.msg_data_route_field] = str(route) await self.send_json({"type": msg.msg_type, "message": msg.msg_data}) async def disconnect(self, close_code): - logging.info(f"Disconnect received: {close_code}") + """Discard any remaining messages on disconnect.""" + logger.info("Disconnect 
received: %s", close_code) await self.channel_layer.group_discard(self.translator_group, self.channel_name) async def receive_json(self, content): - """Received a WebSocket message""" + """Handle a WebSocket message.""" if content["type"] == "translator_check_resp": # We received a check response from a translator, forward to web UI. channel = content.pop("channel") @@ -58,14 +67,17 @@ async def _send_event(self, event): class WebUIConsumer(AsyncJsonWebsocketConsumer): + """Handle messages from the Web UI.""" + async def connect(self): + """Handle the initial connection with adding to the right group.""" self.actiontype = self.scope["url_route"]["kwargs"]["actiontype"] self.translator_group = f"translator_{self.actiontype}" await self.accept() - # Receive message from WebSocket async def receive_json(self, content): + """Receive message from WebSocket.""" if content["type"] == "wui_check_req": # Web UI asks us to check; forward to translator(s) await self.channel_layer.group_send( diff --git a/config/routing.py b/config/routing.py index bae0298a..10a63105 100644 --- a/config/routing.py +++ b/config/routing.py @@ -1,3 +1,5 @@ +"""Define URLs for the WebSocket consumers.""" + from django.urls import re_path from . import consumers diff --git a/config/settings/__init__.py b/config/settings/__init__.py index e69de29b..dd91545c 100644 --- a/config/settings/__init__.py +++ b/config/settings/__init__.py @@ -0,0 +1 @@ +"""Define Django settings. Everyone gets base, and then we can override in local or production.""" diff --git a/config/settings/base.py b/config/settings/base.py index 766d785e..be90d71e 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -1,6 +1,4 @@ -""" -Base settings to build other settings files upon. 
-""" +"""Base settings to build other settings files upon.""" import logging import os @@ -8,6 +6,8 @@ import environ +logger = logging.getLogger(__name__) + ROOT_DIR = Path(__file__).resolve(strict=True).parent.parent.parent # scram/ APPS_DIR = ROOT_DIR / "scram" @@ -188,7 +188,7 @@ "scram.route_manager.context_processors.login_logout", ], }, - } + }, ] # https://docs.djangoproject.com/en/dev/ref/settings/#form-renderer @@ -238,14 +238,14 @@ "version": 1, "disable_existing_loggers": False, "formatters": { - "verbose": {"format": "%(levelname)s %(asctime)s %(module)s " "%(process)d %(thread)d %(message)s"} + "verbose": {"format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"}, }, "handlers": { "console": { "level": "DEBUG", "class": "logging.StreamHandler", "formatter": "verbose", - } + }, }, "root": {"level": "INFO", "handlers": ["console"]}, } @@ -256,7 +256,7 @@ "default": { "BACKEND": "channels_redis.core.RedisChannelLayer", "CONFIG": { - "hosts": [("redis", 6379)], + "hosts": [(os.environ.get("REDIS_HOST", "redis"), 6379)], }, }, } @@ -264,6 +264,8 @@ # django-rest-framework # ------------------------------------------------------------------------------- # django-rest-framework - https://www.django-rest-framework.org/api-guide/settings/ +# Swagger related tooling +INSTALLED_APPS += ["drf_spectacular"] REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ( "rest_framework.authentication.SessionAuthentication", @@ -271,6 +273,7 @@ ), "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "TEST_REQUEST_DEFAULT_FORMAT": "json", + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", } # django-cors-headers - https://github.com/adamchainz/django-cors-headers#setup @@ -307,13 +310,13 @@ ) OIDC_RP_SIGN_ALGO = "RS256" -logging.info(f"Using AUTH METHOD = {AUTH_METHOD}") +logger.info("Using AUTH METHOD=%s", AUTH_METHOD) if AUTH_METHOD == "oidc": # Extend middleware to add OIDC middleware - MIDDLEWARE += 
["mozilla_django_oidc.middleware.SessionRefresh"] # noqa F405 + MIDDLEWARE += ["mozilla_django_oidc.middleware.SessionRefresh"] # Extend middleware to add OIDC auth backend - AUTHENTICATION_BACKENDS += ["scram.route_manager.authentication_backends.ESnetAuthBackend"] # noqa F405 + AUTHENTICATION_BACKENDS += ["scram.route_manager.authentication_backends.ESnetAuthBackend"] # https://docs.djangoproject.com/en/dev/ref/settings/#login-url LOGIN_URL = "oidc_authentication_init" @@ -331,7 +334,8 @@ # https://docs.djangoproject.com/en/dev/ref/settings/#logout-url LOGOUT_URL = "local_auth:logout" else: - raise ValueError(f"Invalid authentication method: {AUTH_METHOD}. Please choose 'local' or 'oidc'") + msg = f"Invalid authentication method: {AUTH_METHOD}. Please choose 'local' or 'oidc'" + raise ValueError(msg) # Should we create an admin user for you diff --git a/config/settings/local.py b/config/settings/local.py index 83b31dd8..56029930 100644 --- a/config/settings/local.py +++ b/config/settings/local.py @@ -1,5 +1,5 @@ from .base import * # noqa -from .base import AUTH_METHOD, env +from .base import AUTH_METHOD, REST_FRAMEWORK, env # GENERAL # ------------------------------------------------------------------------------ @@ -11,7 +11,7 @@ default="BmZnn8FeNFdaeCod8ky6eBNpTiwO45NzlFyA6kk1xo0g4Mc263gAyscHFCMCeJAi", ) # https://docs.djangoproject.com/en/dev/ref/settings/#allowed-hosts -ALLOWED_HOSTS = ["localhost", "0.0.0.0", "127.0.0.1", "django"] +ALLOWED_HOSTS = ["*"] # CACHES # ------------------------------------------------------------------------------ @@ -20,7 +20,7 @@ "default": { "BACKEND": "django.core.cache.backends.locmem.LocMemCache", "LOCATION": "", - } + }, } # EMAIL @@ -36,12 +36,12 @@ # django-coverage-plugin # ------------------------------------------------------------------------------ # https://github.com/nedbat/django_coverage_plugin?tab=readme-ov-file#django-template-coveragepy-plugin -TEMPLATES[0]["OPTIONS"]['debug'] = True # noqa F405 
+TEMPLATES[0]["OPTIONS"]["debug"] = True # noqa F405 # django-debug-toolbar # ------------------------------------------------------------------------------ # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#prerequisites -INSTALLED_APPS += ["debug_toolbar"] # noqa F405 +INSTALLED_APPS += ["debug_toolbar"] # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#middleware MIDDLEWARE += ["debug_toolbar.middleware.DebugToolbarMiddleware"] # noqa F405 # https://django-debug-toolbar.readthedocs.io/en/latest/configuration.html#debug-toolbar-config @@ -51,7 +51,7 @@ } # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#internal-ips INTERNAL_IPS = ["127.0.0.1", "10.0.2.2"] -if env("USE_DOCKER") == "yes": +if env("USE_DOCKER", default="no") == "yes": import socket hostname, _, ips = socket.gethostbyname_ex(socket.gethostname()) @@ -60,23 +60,21 @@ # django-extensions # ------------------------------------------------------------------------------ # https://django-extensions.readthedocs.io/en/latest/installation_instructions.html#configuration -INSTALLED_APPS += ["django_extensions"] # noqa F405 +INSTALLED_APPS += ["django_extensions"] # Your stuff... 
# ------------------------------------------------------------------------------ - -REST_FRAMEWORK = { - "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAdminUser",), -} +REST_FRAMEWORK["DEFAULT_PERMISSION_CLASSES"] = ("rest_framework.permissions.IsAdminUser",) # Behave Django testing framework -INSTALLED_APPS += ["behave_django"] # noqa F405 +INSTALLED_APPS += ["behave_django"] # AUTHENTICATION # ------------------------------------------------------------------------------ # We shouldn't be using OIDC in local dev mode as of now, but might be worth pursuing later if AUTH_METHOD == "oidc": - raise NotImplementedError("oidc is not yet implemented") + msg = "oidc is not yet implemented" + raise NotImplementedError(msg) # https://docs.djangoproject.com/en/dev/ref/settings/#login-url LOGIN_URL = "admin:login" diff --git a/config/settings/production.py b/config/settings/production.py index 8f62b5d4..cdb213f0 100644 --- a/config/settings/production.py +++ b/config/settings/production.py @@ -28,7 +28,7 @@ # https://github.com/jazzband/django-redis#memcached-exceptions-behavior "IGNORE_EXCEPTIONS": True, }, - } + }, } # SECURITY @@ -68,7 +68,7 @@ "django.template.loaders.filesystem.Loader", "django.template.loaders.app_directories.Loader", ], - ) + ), ] # EMAIL @@ -108,7 +108,7 @@ "disable_existing_loggers": False, "filters": {"require_debug_false": {"()": "django.utils.log.RequireDebugFalse"}}, "formatters": { - "verbose": {"format": "%(levelname)s %(asctime)s %(module)s " "%(process)d %(thread)d %(message)s"} + "verbose": {"format": "%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s"}, }, "handlers": { "mail_admins": { diff --git a/config/settings/test.py b/config/settings/test.py index a4951b2c..606cee75 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -1,6 +1,4 @@ -""" -With these settings, tests run faster. 
-""" +"""With these settings, tests run faster.""" from .base import * # noqa from .base import env @@ -29,8 +27,9 @@ "django.template.loaders.filesystem.Loader", "django.template.loaders.app_directories.Loader", ], - ) + ), ] +TEMPLATES[0]["OPTIONS"]["debug"] = True # noqa F405 # EMAIL # ------------------------------------------------------------------------------ diff --git a/config/urls.py b/config/urls.py index ec9e5b7b..41a45d12 100644 --- a/config/urls.py +++ b/config/urls.py @@ -1,9 +1,12 @@ +"""Define the non-WebSocket URLs for Django.""" + from django.conf import settings from django.conf.urls.static import static from django.contrib import admin from django.contrib.staticfiles.urls import staticfiles_urlpatterns from django.urls import include, path from django.views import defaults as default_views +from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView from rest_framework.authtoken.views import obtain_auth_token from .api_router import app_name @@ -16,7 +19,8 @@ # User management path("users/", include("scram.users.urls", namespace="users")), # Your stuff: custom urls includes go here -] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) + *static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT), +] if settings.DEBUG: # Static file serving when using Gunicorn + Uvicorn for local web socket development urlpatterns += staticfiles_urlpatterns() @@ -44,6 +48,13 @@ path("auth-token/", obtain_auth_token), ] +# Swagger OpenAPI URLs +urlpatterns += [ + path("schema/", SpectacularAPIView.as_view(), name="schema"), + path("schema/swagger-ui/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"), + path("schema/redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"), +] + if settings.DEBUG: # This allows the error pages to be debugged during development, just visit # these url in browser to see how these error pages look like. 
@@ -68,4 +79,4 @@ if "debug_toolbar" in settings.INSTALLED_APPS: import debug_toolbar - urlpatterns = [path("__debug__/", include(debug_toolbar.urls))] + urlpatterns + urlpatterns = [path("__debug__/", include(debug_toolbar.urls)), *urlpatterns] diff --git a/config/websocket.py b/config/websocket.py index 81adfbc6..b48dfb8b 100644 --- a/config/websocket.py +++ b/config/websocket.py @@ -1,4 +1,8 @@ +"""TODO: Find out if this is used.""" + + async def websocket_application(scope, receive, send): + """Handle WebSocket messages. I guess.""" while True: event = await receive() diff --git a/config/wsgi.py b/config/wsgi.py index 98830366..73ed363b 100644 --- a/config/wsgi.py +++ b/config/wsgi.py @@ -1,5 +1,4 @@ -""" -WSGI config for SCRAM project. +"""WSGI config for SCRAM project. This module contains the WSGI application used by Django's development server and any production WSGI deployments. It should expose a module-level variable @@ -13,6 +12,7 @@ framework. """ + import os import sys from pathlib import Path @@ -26,13 +26,10 @@ # We defer to a DJANGO_SETTINGS_MODULE already in the environment. This breaks # if running multiple sites in the same mod_wsgi process. To fix this, use # mod_wsgi daemon mode with each site in its own daemon process, or use -# os.environ["DJANGO_SETTINGS_MODULE"] = "config.settings.production" +# os.environ["DJANGO_SETTINGS_MODULE"] = "config.settings.production" # noqa ERA001 os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.production") # This application object is used by any WSGI server configured to use this # file. This includes Django's development server, if the WSGI_APPLICATION # setting points here. application = get_wsgi_application() -# Apply WSGI middleware here. 
-# from helloworld.wsgi import HelloWorldApplication -# application = HelloWorldApplication(application) diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index f7bac529..00000000 --- a/docs/Makefile +++ /dev/null @@ -1,29 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -c . -SOURCEDIR = . -BUILDDIR = ./_build -APP = /app - -.PHONY: help livehtml apidocs Makefile - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -# Build, watch and serve docs with live reload -livehtml: - sphinx-autobuild -b html --host 0.0.0.0 --port 7000 --watch $(APP) -c . $(SOURCEDIR) $(BUILDDIR)/html - -# Outputs rst files from django application code -apidocs: - sphinx-apidoc -o $(SOURCEDIR)/api /app - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -b $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/README.rst b/docs/README.md similarity index 82% rename from README.rst rename to docs/README.md index 876815b3..53f356c5 100644 --- a/README.rst +++ b/docs/README.md @@ -1,30 +1,19 @@ -SCRAM -===== +# SCRAM Security Catch and Release Automation Manager -.. image:: https://coveralls.io/repos/github/esnet-security/SCRAM/badge.svg - :target: https://coveralls.io/github/esnet-security/SCRAM - :alt: Coveralls Code Coverage Stats -.. image:: https://img.shields.io/badge/built%20with-Cookiecutter%20Django-ff69b4.svg?logo=cookiecutter - :target: https://github.com/pydanny/cookiecutter-django/ - :alt: Built with Cookiecutter Django -.. 
image:: https://img.shields.io/badge/code%20style-black-000000.svg - :target: https://github.com/ambv/black - :alt: Black code style +[]() +[]() +[]() -:License: BSD - -==== - -Overview -==== +## Overview SCRAM is a web based service to assist in automation of security data. There is a web interface as well as a REST API available. + The idea is to create actiontypes which allow you to take actions on the IPs/cidr networks you provide. -Components -==== +## Components + SCRAM utilizes ``docker compose`` to run the following stack in production: - nginx (as a webserver and static asset server) @@ -37,8 +26,7 @@ SCRAM utilizes ``docker compose`` to run the following stack in production: A predefined actiontype of "block" exists which utilizes bgp nullrouting to effectivley block any traffic you want to apply. You can add any other actiontypes via the admin page of the web interface dynamically, but keep in mind translator support would need to be added as well. -Installation -==== +## Installation To get a basic implementation up and running locally: @@ -61,8 +49,7 @@ To get a basic implementation up and running locally: - ``make django-open`` -*** Copyright Notice *** -==== +## Copyright Notice Security Catch and Release Automation Manager (SCRAM) Copyright (c) 2022, The Regents of the University of California, through Lawrence Berkeley diff --git a/docs/__init__.py b/docs/__init__.py index 8772c827..d34efb05 100644 --- a/docs/__init__.py +++ b/docs/__init__.py @@ -1 +1 @@ -# Included so that Django's startproject comment runs against the docs directory +"""Included so that Django's startproject comment runs against the docs directory.""" diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index d4d4bad7..00000000 --- a/docs/conf.py +++ /dev/null @@ -1,63 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. 
For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. - -import os -import sys - -import django - -if os.getenv("READTHEDOCS", default=False) == "True": - sys.path.insert(0, os.path.abspath("..")) - os.environ["DJANGO_READ_DOT_ENV_FILE"] = "True" - os.environ["USE_DOCKER"] = "no" -else: - sys.path.insert(0, os.path.abspath("/app")) -os.environ["DATABASE_URL"] = "sqlite:///readthedocs.db" -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") -django.setup() - -# -- Project information ----------------------------------------------------- - -project = "SCRAM" -copyright = """2021, Sam Oehlert""" -author = "Sam Oehlert" - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.napoleon", -] - -# Add any paths that contain templates here, relative to this directory. -# templates_path = ["_templates"] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. 
-# -html_theme = "alabaster" - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ["_static"] diff --git a/docs/cookiecutter.md b/docs/cookiecutter.md deleted file mode 100644 index 17445af4..00000000 --- a/docs/cookiecutter.md +++ /dev/null @@ -1,71 +0,0 @@ -This documentation was initially provided in the README.rst from the cookiecutter used to generate the project - -Settings --------- - -Moved to settings_. - -.. _settings: http://cookiecutter-django.readthedocs.io/en/latest/settings.html - -Basic Commands --------------- - -Setting Up Your Users -^^^^^^^^^^^^^^^^^^^^^ - -* To create a **normal user account**, just go to Sign Up and fill out the form. Once you submit it, you'll see a "Verify Your E-mail Address" page. Go to your console to see a simulated email verification message. Copy the link into your browser. Now the user's email should be verified and ready to go. - -* To create an **superuser account**, use this command:: - - $ python manage.py createsuperuser - -For convenience, you can keep your normal user logged in on Chrome and your superuser logged in on Firefox (or similar), so that you can see how the site behaves for both kinds of users. - -Type checks -^^^^^^^^^^^ - -Running type checks with mypy: - -:: - - $ mypy scram - -Test coverage -^^^^^^^^^^^^^ - -To run the tests, check your test coverage, and generate an HTML coverage report:: - - $ coverage run -m pytest - $ coverage html - $ open htmlcov/index.html - -Running tests with py.test -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -:: - - $ pytest - -Live reloading and Sass CSS compilation -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Moved to `Live reloading and SASS compilation`_. - -.. 
_`Live reloading and SASS compilation`: http://cookiecutter-django.readthedocs.io/en/latest/live-reloading-and-sass-compilation.html - - - -Deployment ----------- - -The following details how to deploy this application. - - - -Docker -^^^^^^ - -See detailed `cookiecutter-django Docker documentation`_. - -.. _`cookiecutter-django Docker documentation`: http://cookiecutter-django.readthedocs.io/en/latest/deployment-with-docker.html - diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index a22d92b7..00000000 --- a/docs/make.bat +++ /dev/null @@ -1,46 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -c . -) -set SOURCEDIR=_source -set BUILDDIR=_build -set APP=..\scram - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.Install sphinx-autobuild for live serving. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -b %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:livehtml -sphinx-autobuild -b html --open-browser -p 7000 --watch %APP% -c . 
%SOURCEDIR% %BUILDDIR%/html -GOTO :EOF - -:apidocs -sphinx-apidoc -o %SOURCEDIR%/api %APP% -GOTO :EOF - -:help -%SPHINXBUILD% -b help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/reference/django.md b/docs/reference/django.md new file mode 100644 index 00000000..be734554 --- /dev/null +++ b/docs/reference/django.md @@ -0,0 +1,3 @@ +# Route Manager + +::: scram.route_manager diff --git a/docs/reference/index.md b/docs/reference/index.md new file mode 100644 index 00000000..fc771185 --- /dev/null +++ b/docs/reference/index.md @@ -0,0 +1,7 @@ +# Module Reference + +The SCRAM ecosystem consists of two parts: + +A Django app, [route_manager](/reference/django) + +A translator service, [translator](/reference/translator) diff --git a/docs/reference/translator.md b/docs/reference/translator.md new file mode 100644 index 00000000..897d84e2 --- /dev/null +++ b/docs/reference/translator.md @@ -0,0 +1,3 @@ +# Translator + +::: translator diff --git a/manage.py b/manage.py index 37181fe1..cbc80d8d 100755 --- a/manage.py +++ b/manage.py @@ -1,27 +1,25 @@ #!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" + import os import sys from pathlib import Path -if __name__ == "__main__": - os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") +def main(): + """Run administrative tasks.""" + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") try: - from django.core.management import execute_from_command_line - except ImportError: - # The above import may fail for some other reason. Ensure that the - # issue is really that Django is missing to avoid masking other - # exceptions on Python 2. - try: - import django # noqa - except ImportError: - raise ImportError( - "Couldn't import Django. Are you sure it's installed and " - "available on your PYTHONPATH environment variable? Did you " - "forget to activate a virtual environment?" 
- ) - - raise + from django.core.management import execute_from_command_line # noqa: PLC0415 + except ImportError as exc: + msg = ( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) + raise ImportError( + msg, + ) from exc # This allows easy placement of apps within the interior # scram directory. @@ -29,3 +27,7 @@ sys.path.append(str(current_path / "scram")) execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/merge_production_dotenvs_in_dotenv.py b/merge_production_dotenvs_in_dotenv.py deleted file mode 100644 index b66558c3..00000000 --- a/merge_production_dotenvs_in_dotenv.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -from pathlib import Path -from typing import Sequence - -import pytest - -ROOT_DIR_PATH = Path(__file__).parent.resolve() -PRODUCTION_DOTENVS_DIR_PATH = ROOT_DIR_PATH / ".envs" / ".production" -PRODUCTION_DOTENV_FILE_PATHS = [ - PRODUCTION_DOTENVS_DIR_PATH / ".django", - PRODUCTION_DOTENVS_DIR_PATH / ".postgres", -] -DOTENV_FILE_PATH = ROOT_DIR_PATH / ".env" - - -def merge(output_file_path: str, merged_file_paths: Sequence[str], append_linesep: bool = True) -> None: - with open(output_file_path, "w") as output_file: - for merged_file_path in merged_file_paths: - with open(merged_file_path, "r") as merged_file: - merged_file_content = merged_file.read() - output_file.write(merged_file_content) - if append_linesep: - output_file.write(os.linesep) - - -def main(): - merge(DOTENV_FILE_PATH, PRODUCTION_DOTENV_FILE_PATHS) - - -@pytest.mark.parametrize("merged_file_count", range(3)) -@pytest.mark.parametrize("append_linesep", [True, False]) -def test_merge(tmpdir_factory, merged_file_count: int, append_linesep: bool): - tmp_dir_path = Path(str(tmpdir_factory.getbasetemp())) - - output_file_path = tmp_dir_path / ".env" - - expected_output_file_content = "" - merged_file_paths = [] - for i in 
range(merged_file_count): - merged_file_ord = i + 1 - - merged_filename = ".service{}".format(merged_file_ord) - merged_file_path = tmp_dir_path / merged_filename - - merged_file_content = merged_filename * merged_file_ord - - with open(merged_file_path, "w+") as file: - file.write(merged_file_content) - - expected_output_file_content += merged_file_content - if append_linesep: - expected_output_file_content += os.linesep - - merged_file_paths.append(merged_file_path) - - merge(output_file_path, merged_file_paths, append_linesep) - - with open(output_file_path, "r") as output_file: - actual_output_file_content = output_file.read() - - assert actual_output_file_content == expected_output_file_content - - -if __name__ == "__main__": - main() diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..704db80c --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,16 @@ +site_name: Security Catch and Release Automation Manager + +theme: + name: "material" + +plugins: + - mkdocstrings: + handlers: + python: + import: + - url: https://docs.djangoproject.com/en/4.2/_objects/ + base_url: https://docs.djangoproject.com/en/4.2/ + options: + show_submodules: true + - search + - section-index diff --git a/pyproject.toml b/pyproject.toml index 910571d6..5c332104 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,23 +25,91 @@ exclude_also = [ "if __name__ == .__main__.:", ] -# ==== black ==== -[tool.black] +# ===== ruff ==== +[tool.ruff] +exclude = [ + "migrations", +] + line-length = 119 -target-version = ['py311'] +target-version = 'py312' +preview = true + +[tool.ruff.lint] +select = [ + "A", # builtins + "ASYNC", # async + "B", # bugbear + "BLE", # blind-except + "C4", # comprehensions + "C90", # complexity + "COM", # commas + "D", # pydocstyle + "DJ", # django + "DOC", # pydoclint + "DTZ", # datetimez + "E", # pycodestyle + "EM", # errmsg + "ERA", # eradicate + "F", # pyflakes + "FBT", # boolean-trap + "FLY", # flynt + "G", # logging-format + "I", # isort + "ICN", # 
import-conventions + "ISC", # implicit-str-concat + "LOG", # logging + "N", # pep8-naming + "PERF", # perflint + "PIE", # pie + "PL", # pylint + "PTH", # use-pathlib + "Q", # quotes + "RET", # return + "RSE", # raise + "RUF", # Ruff + "S", # bandit + "SIM", # simplify + "SLF", # self + "SLOT", # slots + "T20", # print + "TRY", # tryceratops + "UP", # pyupgrade +] +ignore = [ + "COM812", # handled by the formatter + "DOC501", # add possible exceptions to the docstring (TODO) + "ISC001", # handled by the formatter + "RUF012", # need more widespread typing + "SIM102", # Use a single `if` statement instead of nested `if` statements + "SIM108", # Use ternary operator instead of `if`-`else`-block +] +[tool.ruff.lint.mccabe] +max-complexity = 7 # our current code adheres to this without too much effort -# ==== isort ==== -[tool.isort] -profile = "black" -line_length = 119 -known_first_party = [ - "scram", - "config", +[tool.ruff.lint.per-file-ignores] +"**/{tests}/*" = [ + "DOC201", # documenting return values + "DOC402", # documenting yield values + "PLR6301", # could be a static method + "S101", # use of assert + "S106", # hardcoded password +] +"test.py" = [ + "S105", # hardcoded password as argument +] +"scram/users/**" = [ + "DOC201", # documenting return values + "FBT001", # minimal issue; don't need to mess with in the User app + "PLR2004", # magic values when checking HTTP status codes +] +"**/views.py" = [ + "DOC201", # documenting return values; it's fairly obvious in a View ] -skip = ["venv/"] -skip_glob = ["**/migrations/*.py"] +[tool.ruff.lint.pydocstyle] +convention = "google" # ==== mypy ==== [tool.mypy] @@ -63,53 +131,3 @@ ignore_errors = true [tool.django-stubs] django_settings_module = "config.settings.test" - - -# ==== PyLint ==== -[tool.pylint.MASTER] -load-plugins = [ - "pylint_django", -] -django-settings-module = "config.settings.local" - -[tool.pylint.FORMAT] -max-line-length = 119 - -[tool.pylint."MESSAGES CONTROL"] -disable = [ - 
"missing-docstring", - "invalid-name", -] - -[tool.pylint.DESIGN] -max-parents = 13 - -[tool.pylint.TYPECHECK] -generated-members = [ - "REQUEST", - "acl_users", - "aq_parent", - "[a-zA-Z]+_set{1,2}", - "save", - "delete", -] - - -# ==== djLint ==== -[tool.djlint] -blank_line_after_tag = "load,extends" -close_void_tags = true -format_css = true -format_js = true -# TODO: remove T002 when fixed https://github.com/Riverside-Healthcare/djLint/issues/687 -ignore = "H006,H021,H023,H025,H029,H030,H031,T002,T003" -include = "H017,H035" -indent = 2 -max_line_length = 119 -profile = "django" - -[tool.djlint.css] -indent_size = 2 - -[tool.djlint.js] -indent_size = 2 diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index c2b3a233..00000000 --- a/pytest.ini +++ /dev/null @@ -1,3 +0,0 @@ -[pytest] -addopts = --ds=config.settings.test --reuse-db -python_files = tests.py test_*.py diff --git a/requirements/base.txt b/requirements/base.txt index 7577c200..b43b582a 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -22,8 +22,7 @@ django-simple-history~=3.1.1 # Django REST Framework djangorestframework~=3.15 # https://github.com/encode/django-rest-framework django-cors-headers==3.13.0 # https://github.com/adamchainz/django-cors-headers -# DRF-spectacular for api documentation -drf-spectacular # https://github.com/tfranzel/drf-spectacular +drf-spectacular # https://github.com/tfranzel/drf-spectacular # OIDC # ------------------------------------------------------------------------------ diff --git a/requirements/local.txt b/requirements/local.txt index 2a6c95ae..134a2213 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -2,7 +2,7 @@ Werkzeug[watchdog]==2.0.3 # https://github.com/pallets/werkzeug ipdb==0.13.9 # https://github.com/gotcha/ipdb -psycopg2-binary==2.9.3 # https://github.com/psycopg/psycopg2 +psycopg2-binary==2.9.10 # https://github.com/psycopg/psycopg2 watchgod==0.8.2 # https://github.com/samuelcolvin/watchgod # Testing @@ 
-13,8 +13,16 @@ behave-django==1.4.0 # https://github.com/behave/behave-django # Documentation # ------------------------------------------------------------------------------ -sphinx==5.0.1 # https://github.com/sphinx-doc/sphinx -sphinx-autobuild==2021.3.14 # https://github.com/GaretJax/sphinx-autobuild +mkdocs==1.6.1 +mkdocs-autorefs==1.2.0 +mkdocs-gen-files==0.5.0 +mkdocs-get-deps==0.2.0 +mkdocs-literate-nav==0.6.1 +mkdocs-material==9.5.44 +mkdocs-material-extensions==1.3.1 +mkdocs-section-index==0.3.9 +mkdocstrings==0.27.0 +mkdocstrings-python==1.12.2 # Code quality # ------------------------------------------------------------------------------ diff --git a/scram/__init__.py b/scram/__init__.py index d5feb05a..8f5a0220 100644 --- a/scram/__init__.py +++ b/scram/__init__.py @@ -1,4 +1,4 @@ """The Django project for Security Catch and Release Automation Manager (SCRAM).""" __version__ = "1.1.1" -__version_info__ = tuple([int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split(".")]) +__version_info__ = tuple(int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split(".")) diff --git a/scram/contrib/__init__.py b/scram/contrib/__init__.py index 88d92bd3..adb76b31 100644 --- a/scram/contrib/__init__.py +++ b/scram/contrib/__init__.py @@ -1,5 +1,4 @@ -""" -To understand why this file is here, please read the CookieCutter documentation. +"""To understand why this file is here, please read the CookieCutter documentation. http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django """ diff --git a/scram/contrib/sites/__init__.py b/scram/contrib/sites/__init__.py index 88d92bd3..adb76b31 100644 --- a/scram/contrib/sites/__init__.py +++ b/scram/contrib/sites/__init__.py @@ -1,5 +1,4 @@ -""" -To understand why this file is here, please read the CookieCutter documentation. 
+"""To understand why this file is here, please read the CookieCutter documentation. http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django """ diff --git a/scram/route_manager/api/exceptions.py b/scram/route_manager/api/exceptions.py index 74d6f2e2..0445dbe4 100644 --- a/scram/route_manager/api/exceptions.py +++ b/scram/route_manager/api/exceptions.py @@ -10,7 +10,7 @@ class PrefixTooLarge(APIException): v4_min_prefix = getattr(settings, "V4_MINPREFIX", 0) v6_min_prefix = getattr(settings, "V6_MINPREFIX", 0) status_code = 400 - default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: 501 + default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: E501 default_code = "prefix_too_large" diff --git a/scram/route_manager/api/serializers.py b/scram/route_manager/api/serializers.py index 6baf7d95..f1bc1907 100644 --- a/scram/route_manager/api/serializers.py +++ b/scram/route_manager/api/serializers.py @@ -2,6 +2,7 @@ import logging +from drf_spectacular.utils import extend_schema_field from netfields import rest_framework from rest_framework import serializers from rest_framework.fields import CurrentUserDefault @@ -12,8 +13,13 @@ logger = logging.getLogger(__name__) +@extend_schema_field(field={"type": "string", "format": "cidr"}) +class CustomCidrAddressField(rest_framework.CidrAddressField): + """Define a wrapper field so swagger can properly handle the inherited field.""" + + class ActionTypeSerializer(serializers.ModelSerializer): - """This serializer defines no new fields.""" + """Map the serializer to the model via Meta.""" class Meta: """Maps to the ActionType model, and specifies the fields exposed by the API.""" @@ -25,7 +31,7 @@ class Meta: class RouteSerializer(serializers.ModelSerializer): """Exposes route 
as a CIDR field.""" - route = rest_framework.CidrAddressField() + route = CustomCidrAddressField() class Meta: """Maps to the Route model, and specifies the fields exposed by the API.""" @@ -37,7 +43,7 @@ class Meta: class ClientSerializer(serializers.ModelSerializer): - """This serializer defines no new fields.""" + """Map the serializer to the model via Meta.""" class Meta: """Maps to the Client model, and specifies the fields exposed by the API.""" @@ -50,9 +56,11 @@ class EntrySerializer(serializers.HyperlinkedModelSerializer): """Due to the use of ForeignKeys, this follows some relationships to make sense via the API.""" url = serializers.HyperlinkedIdentityField( - view_name="api:v1:entry-detail", lookup_url_kwarg="pk", lookup_field="route" + view_name="api:v1:entry-detail", + lookup_url_kwarg="pk", + lookup_field="route", ) - route = rest_framework.CidrAddressField() + route = CustomCidrAddressField() actiontype = serializers.CharField(default="block") if CurrentUserDefault(): # This is set if we are calling this serializer from WUI @@ -67,28 +75,36 @@ class Meta: model = Entry fields = ["route", "actiontype", "url", "comment", "who"] - def get_comment(self, obj): - """Provide a nicer name for change reason.""" + @staticmethod + def get_comment(obj): + """Provide a nicer name for change reason. + + Returns: + string: The change reason that modified the Entry. 
+ """ return obj.get_change_reason() - def create(self, validated_data): - """Implement custom logic and validates creating a new route.""" + @staticmethod + def create(validated_data): + """Implement custom logic and validates creating a new route.""" # noqa: DOC201 valid_route = validated_data.pop("route") actiontype = validated_data.pop("actiontype") comment = validated_data.pop("comment") - route_instance, created = Route.objects.get_or_create(route=valid_route) + route_instance, _ = Route.objects.get_or_create(route=valid_route) actiontype_instance = ActionType.objects.get(name=actiontype) - entry_instance, created = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance) + entry_instance, _ = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance) - logger.debug(f"{comment}") + logger.debug("Created entry with comment: %s", comment) update_change_reason(entry_instance, comment) return entry_instance class IgnoreEntrySerializer(serializers.ModelSerializer): - """This serializer defines no new fields.""" + """Map the route to the right field type.""" + + route = CustomCidrAddressField() class Meta: """Maps to the IgnoreEntry model, and specifies the fields exposed by the API.""" diff --git a/scram/route_manager/api/views.py b/scram/route_manager/api/views.py index ad26127e..70d30b34 100644 --- a/scram/route_manager/api/views.py +++ b/scram/route_manager/api/views.py @@ -10,6 +10,7 @@ from django.db.models import Q from django.http import Http404 from django.utils.dateparse import parse_datetime +from drf_spectacular.utils import extend_schema from rest_framework import status, viewsets from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.response import Response @@ -19,8 +20,13 @@ from .serializers import ActionTypeSerializer, ClientSerializer, EntrySerializer, IgnoreEntrySerializer channel_layer = get_channel_layer() +logger = logging.getLogger(__name__) +@extend_schema( + 
description="API endpoint for actiontypes", + responses={200: ActionTypeSerializer}, +) class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet): """Lookup ActionTypes by name when authenticated, and bind to the serializer.""" @@ -30,6 +36,10 @@ class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet): lookup_field = "name" +@extend_schema( + description="API endpoint for ignore entries", + responses={200: IgnoreEntrySerializer}, +) class IgnoreEntryViewSet(viewsets.ModelViewSet): """Lookup IgnoreEntries by route when authenticated, and bind to the serializer.""" @@ -39,6 +49,10 @@ class IgnoreEntryViewSet(viewsets.ModelViewSet): lookup_field = "route" +@extend_schema( + description="API endpoint for clients", + responses={200: ClientSerializer}, +) class ClientViewSet(viewsets.ModelViewSet): """Lookup Client by hostname on POSTs regardless of authentication, and bind to the serializer.""" @@ -50,6 +64,10 @@ class ClientViewSet(viewsets.ModelViewSet): http_method_names = ["post"] +@extend_schema( + description="API endpoint for entries", + responses={200: EntrySerializer}, +) class EntryViewSet(viewsets.ModelViewSet): """Lookup Entry when authenticated, and bind to the serializer.""" @@ -60,8 +78,7 @@ class EntryViewSet(viewsets.ModelViewSet): http_method_names = ["get", "post", "head", "delete"] def get_permissions(self): - """ - Override the permissions classes for POST method since we want to accept Entry creates from any client. + """Override the permissions classes for POST method since we want to accept Entry creates from any client. Note: We make authorization decisions on whether to actually create the object in the perform_create method later. 
@@ -70,6 +87,31 @@ def get_permissions(self): return [AllowAny()] return super().get_permissions() + def check_client_authorization(self, actiontype): + """Ensure that a given client is authorized to use a given actiontype.""" + uuid = self.request.data.get("uuid") + if uuid: + authorized_actiontypes = Client.objects.filter(uuid=uuid).values_list( + "authorized_actiontypes__name", + flat=True, + ) + authorized_client = Client.objects.filter(uuid=uuid).values("is_authorized") + if not authorized_client or actiontype not in authorized_actiontypes: + logger.debug("Client: %s, actiontypes: %s", uuid, authorized_actiontypes) + logger.info("%s is not allowed to add an entry to the %s list.", uuid, actiontype) + raise ActiontypeNotAllowed + elif not self.request.user.has_perm("route_manager.can_add_entry"): + raise PermissionDenied + + @staticmethod + def check_ignore_list(route): + """Ensure that we're not trying to block something from the ignore list.""" + overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route) + if overlapping_ignore.count(): + ignore_entries = [str(ignore_entry["route"]) for ignore_entry in overlapping_ignore.values()] + logger.info("Cannot proceed adding %s. The ignore list contains %s.", route, ignore_entries) + raise IgnoredRoute + def perform_create(self, serializer): """Create a new Entry, causing that route to receive the actiontype (i.e. 
block).""" actiontype = serializer.validated_data["actiontype"] @@ -84,60 +126,41 @@ def perform_create(self, serializer): tmp_exp = self.request.data.get("expiration", "") try: - expiration = parse_datetime(tmp_exp) # noqa: F841 + expiration = parse_datetime(tmp_exp) except ValueError: - logging.info(f"Could not parse expiration DateTime string {tmp_exp!r}.") + logger.warning("Could not parse expiration DateTime string: %s", tmp_exp) # Make sure we put in an acceptable sized prefix min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0) if route.prefixlen < min_prefix: - raise PrefixTooLarge() - - # Make sure this client is authorized to add this entry with this actiontype - if self.request.data.get("uuid"): - client_uuid = self.request.data["uuid"] - authorized_actiontypes = Client.objects.filter(uuid=client_uuid).values_list( - "authorized_actiontypes__name", flat=True + raise PrefixTooLarge + + self.check_client_authorization(actiontype) + self.check_ignore_list(route) + + elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num") + if not elements: + logger.warning("No elements found for actiontype: %s", actiontype) + + for element in elements: + msg = element.websocketmessage + msg.msg_data[msg.msg_data_route_field] = str(route) + # Must match a channel name defined in asgi.py + async_to_sync(channel_layer.group_send)( + f"translator_{actiontype}", + {"type": msg.msg_type, "message": msg.msg_data}, ) - authorized_client = Client.objects.filter(uuid=client_uuid).values("is_authorized") - if not authorized_client or actiontype not in authorized_actiontypes: - logging.debug(f"Client {client_uuid} actiontypes: {authorized_actiontypes}") - logging.info(f"{client_uuid} is not allowed to add an entry to the {actiontype} list") - raise ActiontypeNotAllowed() - elif not self.request.user.has_perm("route_manager.can_add_entry"): - raise PermissionDenied() - # Don't process if we have the entry in the ignorelist - 
overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route) - if overlapping_ignore.count(): - ignore_entries = [] - for ignore_entry in overlapping_ignore.values(): - ignore_entries.append(str(ignore_entry["route"])) - logging.info(f"Cannot proceed adding {route}. The ignore list contains {ignore_entries}") - raise IgnoredRoute - else: - elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num") - if not elements: - logging.warning(f"No elements found for actiontype={actiontype}.") - - for element in elements: - msg = element.websocketmessage - msg.msg_data[msg.msg_data_route_field] = str(route) - # Must match a channel name defined in asgi.py - async_to_sync(channel_layer.group_send)( - f"translator_{actiontype}", {"type": msg.msg_type, "message": msg.msg_data} - ) - - serializer.save() - - entry = Entry.objects.get(route__route=route, actiontype__name=actiontype) - if expiration: - entry.expiration = expiration - entry.who = who - entry.is_active = True - entry.comment = comment - logging.info(f"{entry}") - entry.save() + serializer.save() + + entry = Entry.objects.get(route__route=route, actiontype__name=actiontype) + if expiration: + entry.expiration = expiration + entry.who = who + entry.is_active = True + entry.comment = comment + logger.info("Created entry: %s", entry) + entry.save() @staticmethod def find_entries(arg, active_filter=None): @@ -149,13 +172,13 @@ def find_entries(arg, active_filter=None): try: pk = int(arg) query = Q(pk=pk) - except ValueError: + except ValueError as exc: # Maybe a CIDR? We want the ValueError at this point, if not. 
cidr = ipaddress.ip_network(arg, strict=False) min_prefix = getattr(settings, f"V{cidr.version}_MINPREFIX", 0) if cidr.prefixlen < min_prefix: - raise PrefixTooLarge() + raise PrefixTooLarge from exc query = Q(route__route__net_overlaps=cidr) diff --git a/scram/route_manager/authentication_backends.py b/scram/route_manager/authentication_backends.py index 30288520..1831d788 100644 --- a/scram/route_manager/authentication_backends.py +++ b/scram/route_manager/authentication_backends.py @@ -5,48 +5,45 @@ from mozilla_django_oidc.auth import OIDCAuthenticationBackend +def groups_overlap(a, b): + """Helper function to see if a and b have any overlap. + + Returns: + bool: True if there's any overlap between a and b. + """ + return not set(a).isdisjoint(b) + + class ESnetAuthBackend(OIDCAuthenticationBackend): """Extend the OIDC backend with a custom permission model.""" - def update_groups(self, user, claims): + @staticmethod + def update_groups(user, claims): """Set the user's group(s) to whatever is in the claims.""" - claimed_groups = claims.get("groups", []) - effective_groups = [] - is_admin = False - - ro_group = Group.objects.get(name="readonly") - rw_group = Group.objects.get(name="readwrite") - - for g in claimed_groups: - # If any of the user's groups are in DENIED_GROUPS, deny them and stop processing immediately - if g in settings.SCRAM_DENIED_GROUPS: - effective_groups = [] - is_admin = False - break - - if g in settings.SCRAM_ADMIN_GROUPS: - is_admin = True - - if g in settings.SCRAM_READONLY_GROUPS: - if ro_group not in effective_groups: - effective_groups.append(ro_group) + claimed_groups = claims.get("groups", []) - if g in settings.SCRAM_READWRITE_GROUPS: - if rw_group not in effective_groups: - effective_groups.append(rw_group) + if groups_overlap(claimed_groups, settings.SCRAM_DENIED_GROUPS): + is_admin = False + # Don't even look at anything else if they're denied + else: + is_admin = groups_overlap(claimed_groups, settings.SCRAM_ADMIN_GROUPS) + if 
groups_overlap(claimed_groups, settings.SCRAM_READWRITE_GROUPS): + effective_groups.append(Group.objects.get(name="readwrite")) + if groups_overlap(claimed_groups, settings.SCRAM_READONLY_GROUPS): + effective_groups.append(Group.objects.get(name="readonly")) user.groups.set(effective_groups) user.is_staff = user.is_superuser = is_admin user.save() def create_user(self, claims): - """Wrap the superclass's user creation.""" - user = super(ESnetAuthBackend, self).create_user(claims) + """Wrap the superclass's user creation.""" # noqa: DOC201 + user = super().create_user(claims) return self.update_user(user, claims) def update_user(self, user, claims): - """Determine the user name from the claims and update said user's groups.""" + """Determine the user name from the claims and update said user's groups.""" # noqa: DOC201 user.name = claims.get("given_name", "") + " " + claims.get("family_name", "") user.username = claims.get("preferred_username", "") if claims.get("groups", False): diff --git a/scram/route_manager/context_processors.py b/scram/route_manager/context_processors.py index 61abde75..0daf06c9 100644 --- a/scram/route_manager/context_processors.py +++ b/scram/route_manager/context_processors.py @@ -5,7 +5,11 @@ def login_logout(request): - """Pass through the relevant URLs from the settings.""" + """Pass through the relevant URLs from the settings. 
+ + Returns: + dict: login and logout URLs + """ login_url = reverse(settings.LOGIN_URL) logout_url = reverse(settings.LOGOUT_URL) return {"login": login_url, "logout": logout_url} diff --git a/scram/route_manager/models.py b/scram/route_manager/models.py index 3bbf67e3..aa72f7ba 100644 --- a/scram/route_manager/models.py +++ b/scram/route_manager/models.py @@ -10,6 +10,8 @@ from netfields import CidrAddressField from simple_history.models import HistoricalRecords +logger = logging.getLogger(__name__) + class Route(models.Model): """Define a route as a CIDR route and a UUID.""" @@ -17,14 +19,15 @@ class Route(models.Model): route = CidrAddressField(unique=True) uuid = models.UUIDField(db_index=True, default=uuid_lib.uuid4, editable=False) - def get_absolute_url(self): - """Ensure we use UUID on the API side instead.""" - return reverse("") - def __str__(self): - """Don't display the UUID, only the route.""" + """Don't display the UUID, only the route.""" # noqa: DOC201 return str(self.route) + @staticmethod + def get_absolute_url(): + """Ensure we use UUID on the API side instead.""" # noqa: DOC201 + return reverse("") + class ActionType(models.Model): """Define a type of action that can be done with a given route. e.g. 
Block, shunt, redirect, etc.""" @@ -34,7 +37,7 @@ class ActionType(models.Model): history = HistoricalRecords() def __str__(self): - """Display clearly whether the action is currently available.""" + """Display clearly whether the action is currently available.""" # noqa: DOC201 if not self.available: return f"{self.name} (Inactive)" return self.name @@ -52,7 +55,7 @@ class WebSocketMessage(models.Model): ) def __str__(self): - """Display clearly what the fields are used for.""" + """Display clearly what the fields are used for.""" # noqa: DOC201 return f"{self.msg_type}: {self.msg_data} with the route in key {self.msg_data_route_field}" @@ -62,7 +65,7 @@ class WebSocketSequenceElement(models.Model): websocketmessage = models.ForeignKey("WebSocketMessage", on_delete=models.CASCADE) order_num = models.SmallIntegerField( "Sequences are sent from the smallest order_num to the highest. " - + "Messages with the same order_num could be sent in any order", + "Messages with the same order_num could be sent in any order", default=0, ) @@ -76,10 +79,10 @@ class WebSocketSequenceElement(models.Model): action_type = models.ForeignKey("ActionType", on_delete=models.CASCADE) def __str__(self): - """Summarize the fields into something short and readable.""" + """Summarize the fields into something short and readable.""" # noqa: DOC201 return ( f"{self.websocketmessage} as order={self.order_num} for " - + f"{self.verb} actions on actiontype={self.action_type}" + f"{self.verb} actions on actiontype={self.action_type}" ) @@ -88,7 +91,7 @@ class Entry(models.Model): route = models.ForeignKey("Route", on_delete=models.PROTECT) actiontype = models.ForeignKey("ActionType", on_delete=models.PROTECT) - comment = models.TextField(blank=True, null=True) + comment = models.TextField(blank=True, default="") is_active = models.BooleanField(default=True) # TODO: fix name if this works history = HistoricalRecords() @@ -98,30 +101,10 @@ class Entry(models.Model): expiration_reason = 
models.CharField( help_text="Optional reason for the expiration", max_length=200, - null=True, blank=True, + default="", ) - def delete(self, *args, **kwargs): - """Set inactive instead of deleting, as we want to ensure a history of entries.""" - if not self.is_active: - # We've already expired this route, don't send another message - return - else: - # We don't actually delete records; we set them to inactive and then tell the translator to remove them - logging.info(f"Deactivating {self.route}") - self.is_active = False - self.save() - - # Unblock it - async_to_sync(channel_layer.group_send)( - f"translator_{self.actiontype}", - { - "type": "translator_remove", - "message": {"route": str(self.route)}, - }, - ) - class Meta: """Ensure that multiple routes can be added as long as they have different action types.""" @@ -129,14 +112,37 @@ class Meta: verbose_name_plural = "Entries" def __str__(self): - """Summarize the most important fields to something easily readable.""" + """Summarize the most important fields to something easily readable.""" # noqa: DOC201 desc = f"{self.route} ({self.actiontype})" if not self.is_active: desc += " (inactive)" return desc + def delete(self, *args, **kwargs): + """Set inactive instead of deleting, as we want to ensure a history of entries.""" + if not self.is_active: + # We've already expired this route, don't send another message + return + # We don't actually delete records; we set them to inactive and then tell the translator to remove them + logger.info("Deactivating %s", self.route) + self.is_active = False + self.save() + + # Unblock it + async_to_sync(channel_layer.group_send)( + f"translator_{self.actiontype}", + { + "type": "translator_remove", + "message": {"route": str(self.route)}, + }, + ) + def get_change_reason(self): - """Traverse come complex relationships to determine the most recent change reason.""" + """Traverse some complex relationships to determine the most recent change reason. 
+ + Returns: + str: The most recent change reason + """ hist_mgr = getattr(self, self._meta.simple_history_manager_attribute) return hist_mgr.order_by("-history_date").first().history_change_reason @@ -154,7 +160,7 @@ class Meta: verbose_name_plural = "Ignored Entries" def __str__(self): - """Only display the route.""" + """Only display the route.""" # noqa: DOC201 return str(self.route) @@ -168,7 +174,7 @@ class Client(models.Model): authorized_actiontypes = models.ManyToManyField(ActionType) def __str__(self): - """Only display the hostname.""" + """Only display the hostname.""" # noqa: DOC201 return str(self.hostname) diff --git a/scram/route_manager/tests/acceptance/steps/common.py b/scram/route_manager/tests/acceptance/steps/common.py index c7c03f8c..f688bd64 100644 --- a/scram/route_manager/tests/acceptance/steps/common.py +++ b/scram/route_manager/tests/acceptance/steps/common.py @@ -3,34 +3,35 @@ import datetime import time -import django.conf as conf from asgiref.sync import async_to_sync from behave import given, step, then, when from channels.layers import get_channel_layer +from django import conf from django.urls import reverse from scram.route_manager.models import ActionType, Client, WebSocketMessage, WebSocketSequenceElement @given("a {name} actiontype is defined") -def step_impl(context, name): +def create_actiontype(context, name): """Create an actiontype of that name.""" context.channel_layer = get_channel_layer() async_to_sync(context.channel_layer.group_send)( - f"translator_{name}", {"type": "translator_remove_all", "message": {}} + f"translator_{name}", + {"type": "translator_remove_all", "message": {}}, ) - at, created = ActionType.objects.get_or_create(name=name) - wsm, created = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route") + at, _ = ActionType.objects.get_or_create(name=name) + wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route") 
wsm.save() - wsse, created = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at) + wsse, _ = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at) wsse.save() @given("a client with {name} authorization") -def step_impl(context, name): +def create_authed_client(context, name): """Create a client and authorize it for that action type.""" - at, created = ActionType.objects.get_or_create(name=name) + at, _ = ActionType.objects.get_or_create(name=name) authorized_client = Client.objects.create( hostname="authorized_client.es.net", uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3", @@ -40,7 +41,7 @@ def step_impl(context, name): @given("a client without {name} authorization") -def step_impl(context, name): +def create_unauthed_client(context, name): """Create a client that has no authorized action types.""" unauthorized_client = Client.objects.create( hostname="unauthorized_client.es.net", @@ -50,26 +51,26 @@ def step_impl(context, name): @when("we're logged in") -def step_impl(context): +def login(context): """Login.""" context.test.client.login(username="user", password="password") @when("the CIDR prefix limits are {v4_minprefix:d} and {v6_minprefix:d}") -def step_impl(context, v4_minprefix, v6_minprefix): +def set_cidr_limit(context, v4_minprefix, v6_minprefix): """Override our settings with the provided values.""" conf.settings.V4_MINPREFIX = v4_minprefix conf.settings.V6_MINPREFIX = v6_minprefix @then("we get a {status_code:d} status code") -def step_impl(context, status_code): +def check_status_code(context, status_code): """Ensure the status code response matches the expected value.""" context.test.assertEqual(context.response.status_code, status_code) @when("we add the entry {value:S}") -def step_impl(context, value): +def add_entry(context, value): """Block the provided route.""" context.response = context.test.client.post( reverse("api:v1:entry-list"), @@ -86,7 +87,7 @@ def 
step_impl(context, value): @when("we add the entry {value:S} with comment {comment}") -def step_impl(context, value, comment): +def add_entry_with_comment(context, value, comment): """Block the provided route and add a comment.""" context.response = context.test.client.post( reverse("api:v1:entry-list"), @@ -102,7 +103,7 @@ def step_impl(context, value, comment): @when("we add the entry {value:S} with expiration {exp:S}") -def step_impl(context, value, exp): +def add_entry_with_absolute_expiration(context, value, exp): """Block the provided route and add an absolute expiration datetime.""" context.response = context.test.client.post( reverse("api:v1:entry-list"), @@ -119,10 +120,10 @@ def step_impl(context, value, exp): @when("we add the entry {value:S} with expiration in {secs:d} seconds") -def step_impl(context, value, secs): +def add_entry_with_relative_expiration(context, value, secs): """Block the provided route and add a relative expiration.""" td = datetime.timedelta(seconds=secs) - expiration = datetime.datetime.now() + td + expiration = datetime.datetime.now(tz=datetime.UTC) + td context.response = context.test.client.post( reverse("api:v1:entry-list"), @@ -139,39 +140,40 @@ def step_impl(context, value, secs): @step("we wait {secs:d} seconds") -def step_impl(context, secs): +def wait(context, secs): """Wait to allow messages to propagate.""" time.sleep(secs) @then("we remove expired entries") -def step_impl(context): +def remove_expired(context): """Call the function that removes expired entries.""" context.response = context.test.client.get(reverse("route_manager:process-expired")) @when("we add the ignore entry {value:S}") -def step_impl(context, value): +def add_ignore_entry(context, value): """Add an IgnoreEntry with the specified route.""" context.response = context.test.client.post( - reverse("api:v1:ignoreentry-list"), {"route": value, "comment": "test api"} + reverse("api:v1:ignoreentry-list"), + {"route": value, "comment": "test api"}, ) 
@when("we remove the {model} {value}") -def step_impl(context, model, value): +def remove_an_object(context, model, value): """Remove any model object with the matching value.""" context.response = context.test.client.delete(reverse(f"api:v1:{model.lower()}-detail", args=[value])) @when("we list the {model}s") -def step_impl(context, model): +def list_objects(context, model): """List all objects of an arbitrary model.""" context.response = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) @when("we update the {model} {value_from} to {value_to}") -def step_impl(context, model, value_from, value_to): +def update_object(context, model, value_from, value_to): """Modify any model object with the matching value to the new value instead.""" context.response = context.test.client.patch( reverse(f"api:v1:{model.lower()}-detail", args=[value_from]), @@ -180,7 +182,7 @@ def step_impl(context, model, value_from, value_to): @then("the number of {model}s is {num:d}") -def step_impl(context, model, num): +def count_objects(context, model, num): """Count the number of objects of an arbitrary model.""" objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) context.test.assertEqual(len(objs.json()), num) @@ -190,7 +192,7 @@ def step_impl(context, model, num): @then("{value} is one of our list of {model}s") -def step_impl(context, value, model): +def check_object(context, value, model): """Ensure that the arbitrary model has an object with the specified value.""" objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) @@ -206,7 +208,7 @@ def step_impl(context, value, model): @when("we register a client named {hostname} with the uuid of {uuid}") -def step_impl(context, hostname, uuid): +def add_client(context, hostname, uuid): """Create a client with a specific UUID.""" context.response = context.test.client.post( reverse("api:v1:client-list"), diff --git a/scram/route_manager/tests/acceptance/steps/ip.py 
b/scram/route_manager/tests/acceptance/steps/ip.py index 8154f04b..34536a86 100644 --- a/scram/route_manager/tests/acceptance/steps/ip.py +++ b/scram/route_manager/tests/acceptance/steps/ip.py @@ -7,7 +7,7 @@ @then("{route} is contained in our list of {model}s") -def step_impl(context, route, model): +def check_route(context, route, model): """Perform a CIDR match on the matching object.""" objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list")) ip_target = ipaddress.ip_address(route) @@ -23,7 +23,7 @@ def step_impl(context, route, model): @when("we query for {ip}") -def step_impl(context, ip): +def check_ip(context, ip): """Find an Entry for the specified IP.""" try: context.response = context.test.client.get(reverse("api:v1:entry-detail", args=[ip])) @@ -34,13 +34,13 @@ def step_impl(context, ip): @then("we get a ValueError") -def step_impl(context): +def check_error(context): """Ensure we received a ValueError exception.""" assert isinstance(context.queryException, ValueError) @then("the change entry for {value:S} is {comment}") -def step_impl(context, value, comment): +def check_comment(context, value, comment): """Verify the comment for the Entry.""" try: objs = context.test.client.get(reverse("api:v1:entry-detail", args=[value])) diff --git a/scram/route_manager/tests/acceptance/steps/translator.py b/scram/route_manager/tests/acceptance/steps/translator.py index c6c2ff49..ff478a9a 100644 --- a/scram/route_manager/tests/acceptance/steps/translator.py +++ b/scram/route_manager/tests/acceptance/steps/translator.py @@ -7,10 +7,10 @@ from config.asgi import ws_application -async def query_translator(route, actiontype, is_announced=True): +async def query_translator(route, actiontype, is_announced): """Ensure the specified route is currently either blocked or unblocked.""" communicator = WebsocketCommunicator(ws_application, f"/ws/route_manager/webui_{actiontype}/") - connected, subprotocol = await communicator.connect() + connected, _ = await 
communicator.connect() assert connected await communicator.send_json_to({"type": "wui_check_req", "message": {"route": route}}) @@ -23,13 +23,13 @@ async def query_translator(route, actiontype, is_announced=True): @then("{route} is announced by {actiontype} translators") @async_run_until_complete -async def step_impl(context, route, actiontype): +async def check_blocked(context, route, actiontype): """Ensure the specified route is currently blocked.""" - await query_translator(route, actiontype) + await query_translator(route, actiontype, is_announced=True) @then("{route} is not announced by {actiontype} translators") @async_run_until_complete -async def step_impl(context, route, actiontype): +async def check_unblocked(context, route, actiontype): """Ensure the specified route is currently unblocked.""" await query_translator(route, actiontype, is_announced=False) diff --git a/scram/route_manager/tests/functional_tests.py b/scram/route_manager/tests/functional_tests.py index c4515b13..a87ce21e 100644 --- a/scram/route_manager/tests/functional_tests.py +++ b/scram/route_manager/tests/functional_tests.py @@ -6,8 +6,6 @@ class HomePageTest(unittest.TestCase): """Ensure the home page works.""" - pass - if __name__ == "__main__": unittest.main(warnings="ignore") diff --git a/scram/route_manager/tests/test_authorization.py b/scram/route_manager/tests/test_authorization.py index 17de2b6d..9a0d4bff 100644 --- a/scram/route_manager/tests/test_authorization.py +++ b/scram/route_manager/tests/test_authorization.py @@ -236,7 +236,7 @@ def test_authorized_removal(self): def test_disabled(self): """Pass all the groups, user should be disabled as it takes precedence.""" claims = dict(self.claims) - claims["groups"] = [settings.SCRAM_GROUPS] + claims["groups"] = settings.SCRAM_GROUPS user = ESnetAuthBackend().create_user(claims) self.assertFalse(user.is_staff) diff --git a/scram/route_manager/tests/test_autocreate_admin.py b/scram/route_manager/tests/test_autocreate_admin.py index 
b9d2e432..88270f4e 100644 --- a/scram/route_manager/tests/test_autocreate_admin.py +++ b/scram/route_manager/tests/test_autocreate_admin.py @@ -1,13 +1,15 @@ -"""This file contains tests for the auto-creation of an admin user.""" +"""Test the auto-creation of an admin user.""" import pytest -from django.contrib.auth.models import User from django.contrib.messages import get_messages from django.test import Client from django.urls import reverse from scram.users.models import User +LEVEL_SUCCESS = 25 +LEVEL_INFO = 20 + @pytest.mark.django_db def test_autocreate_admin(settings): @@ -15,15 +17,15 @@ def test_autocreate_admin(settings): settings.AUTOCREATE_ADMIN = True client = Client() response = client.get(reverse("route_manager:home")) - assert response.status_code == 200 + assert response.status_code == 200 # noqa: PLR2004 assert User.objects.count() == 1 user = User.objects.get(username="admin") assert user.is_superuser assert user.email == "admin@example.com" messages = list(get_messages(response.wsgi_request)) - assert len(messages) == 2 - assert messages[0].level == 25 # SUCCESS - assert messages[1].level == 20 # INFO + assert len(messages) == 2 # noqa: PLR2004 + assert messages[0].level == LEVEL_SUCCESS + assert messages[1].level == LEVEL_INFO @pytest.mark.django_db @@ -32,7 +34,7 @@ def test_autocreate_admin_disabled(settings): settings.AUTOCREATE_ADMIN = False client = Client() response = client.get(reverse("route_manager:home")) - assert response.status_code == 200 + assert response.status_code == 200 # noqa: PLR2004 assert User.objects.count() == 0 @@ -43,6 +45,6 @@ def test_autocreate_admin_existing_user(settings): User.objects.create_user("testuser", "test@example.com", "password") client = Client() response = client.get(reverse("route_manager:home")) - assert response.status_code == 200 + assert response.status_code == 200 # noqa: PLR2004 assert User.objects.count() == 1 assert not User.objects.filter(username="admin").exists() diff --git 
a/scram/route_manager/tests/test_history.py b/scram/route_manager/tests/test_history.py index b234e621..d2848e83 100644 --- a/scram/route_manager/tests/test_history.py +++ b/scram/route_manager/tests/test_history.py @@ -16,7 +16,7 @@ def setUp(self): def test_comments(self): """Ensure we can go back and set a reason.""" self.atype.name = "Nullroute" - self.atype._change_reason = "Use more descriptive name" + self.atype._change_reason = "Use more descriptive name" # noqa: SLF001 self.atype.save() self.assertIsNotNone(get_change_reason_from_object(self.atype)) @@ -47,7 +47,7 @@ def test_comments(self): e.route = Route.objects.create(route=route_new) change_reason = "I meant 32, not 16." - e._change_reason = change_reason + e._change_reason = change_reason # noqa: SLF001 e.save() self.assertEqual(len(e.history.all()), 2) diff --git a/scram/route_manager/tests/test_swagger.py b/scram/route_manager/tests/test_swagger.py new file mode 100644 index 00000000..ce955db9 --- /dev/null +++ b/scram/route_manager/tests/test_swagger.py @@ -0,0 +1,30 @@ +"""Test the swagger API endpoints.""" + +import pytest +from django.urls import reverse + + +@pytest.mark.django_db +def test_swagger_api(client): + """Test that the Swagger API endpoint returns a successful response.""" + url = reverse("swagger-ui") + response = client.get(url) + assert response.status_code == 200 # noqa: PLR2004 + + +@pytest.mark.django_db +def test_redoc_api(client): + """Test that the Redoc API endpoint returns a successful response.""" + url = reverse("redoc") + response = client.get(url) + assert response.status_code == 200 # noqa: PLR2004 + + +@pytest.mark.django_db +def test_schema_api(client): + """Test that the Schema API endpoint returns a successful response.""" + url = reverse("schema") + response = client.get(url) + assert response.status_code == 200 # noqa: PLR2004 + expected_strings = [b"/entries/", b"/actiontypes/", b"/ignore_entries/", b"/users/"] + assert all(string in response.content for string
in expected_strings) diff --git a/scram/route_manager/tests/test_views.py b/scram/route_manager/tests/test_views.py index bcc3a4ad..1433135a 100644 --- a/scram/route_manager/tests/test_views.py +++ b/scram/route_manager/tests/test_views.py @@ -1,7 +1,8 @@ """Define simple tests for the template-based Views.""" +from django.conf import settings from django.test import TestCase -from django.urls import resolve +from django.urls import resolve, reverse from scram.route_manager.views import home_page @@ -13,3 +14,45 @@ def test_root_url_resolves_to_home_page_view(self): """Ensure we can find the home page.""" found = resolve("/") self.assertEqual(found.func, home_page) + + +class HomePageFirstVisitTest(TestCase): + """Test how the home page renders the first time we view it.""" + + def setUp(self): + """Get the home page.""" + self.response = self.client.get(reverse("route_manager:home")) + + def test_first_homepage_view_has_userinfo(self): + """The first time we view the home page, a user was created for us.""" + self.assertContains(self.response, b"An admin user was created for you.") + + def test_first_homepage_view_is_logged_in(self): + """The first time we view the home page, we're logged in.""" + self.assertContains(self.response, b'type="submit">Logout') + + +class HomePageLogoutTest(TestCase): + """Verify that once logged out, we can't view anything.""" + + def test_homepage_logout_links_missing(self): + """After logout, we can't see anything.""" + response = self.client.get(reverse("route_manager:home")) + response = self.client.post(reverse(settings.LOGOUT_URL), follow=True) + self.assertEqual(response.status_code, 200) + response = self.client.get(reverse("route_manager:home")) + + self.assertNotContains(response, b"An admin user was created for you.") + self.assertNotContains(response, b'type="submit">Logout') + self.assertNotContains(response, b">Admin") + + +class NotFoundTest(TestCase): + """Verify that our custom 404 page is being served up.""" + + def 
test_404(self): + """Grab a bad URL.""" + response = self.client.get("/foobarbaz") + self.assertContains( + response, b'
The page you are looking for was not found.
', status_code=404 + ) diff --git a/scram/route_manager/tests/test_websockets.py b/scram/route_manager/tests/test_websockets.py index 306e772b..ca9cd528 100644 --- a/scram/route_manager/tests/test_websockets.py +++ b/scram/route_manager/tests/test_websockets.py @@ -26,15 +26,13 @@ async def get_communicators(actiontypes, should_match, *args, **kwds): Returns a list of (communicator, should_match bool) pairs. """ - assert len(actiontypes) == len(should_match) - router = URLRouter(websocket_urlpatterns) communicators = [ WebsocketCommunicator(router, f"/ws/route_manager/translator_{actiontype}/") for actiontype in actiontypes ] - response = zip(communicators, should_match) + response = zip(communicators, should_match, strict=True) - for communicator, should_match in response: + for communicator, _ in response: connected, _ = await communicator.connect() assert connected @@ -42,7 +40,7 @@ async def get_communicators(actiontypes, should_match, *args, **kwds): yield response finally: - for communicator, should_match in response: + for communicator, _ in response: await communicator.disconnect() @@ -71,7 +69,9 @@ def setUp(self): wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route") _, _ = WebSocketSequenceElement.objects.get_or_create( - websocketmessage=wsm, verb="A", action_type=self.actiontype + websocketmessage=wsm, + verb="A", + action_type=self.actiontype, ) # Set some defaults; some child classes override this @@ -80,9 +80,9 @@ def setUp(self): self.generate_add_msgs = [lambda ip, mask: {"type": "translator_add", "message": {"route": f"{ip}/{mask}"}}] # Now we run any local setup actions by the child classes - self.local_setUp() + self.local_setup() - def local_setUp(self): + def local_setup(self): """Allow child classes to override this if desired.""" return @@ -151,9 +151,9 @@ async def test_add_v6(self): class TranslatorDontCrossTheStreamsTestCase(TestTranslatorBaseCase): - """Two translators in the same 
group, two in another group, single IP, ensure we get only the messages we expect.""" + """Two translators in one group, two in another group, single IP, ensure we get only the messages we expect.""" - def local_setUp(self): + def local_setup(self): """Define the actions and what we expect.""" self.actiontypes = ["block", "block", "noop", "noop"] self.should_match = [True, True, False, False] @@ -162,15 +162,21 @@ def local_setUp(self): class TranslatorSequenceTestCase(TestTranslatorBaseCase): """Test a sequence of WebSocket messages.""" - def local_setUp(self): + def local_setup(self): """Define the messages we want to send.""" wsm2 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="foo") _ = WebSocketSequenceElement.objects.create( - websocketmessage=wsm2, verb="A", action_type=self.actiontype, order_num=20 + websocketmessage=wsm2, + verb="A", + action_type=self.actiontype, + order_num=20, ) wsm3 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="bar") _ = WebSocketSequenceElement.objects.create( - websocketmessage=wsm3, verb="A", action_type=self.actiontype, order_num=2 + websocketmessage=wsm3, + verb="A", + action_type=self.actiontype, + order_num=2, ) self.generate_add_msgs = [ @@ -183,7 +189,7 @@ def local_setUp(self): class TranslatorParametersTestCase(TestTranslatorBaseCase): """Additional parameters in the JSONField.""" - def local_setUp(self): + def local_setup(self): """Define the message we want to send.""" wsm = WebSocketMessage.objects.get(msg_type="translator_add", msg_data_route_field="route") wsm.msg_data = {"asn": 65550, "community": 100, "route": "Ensure this gets overwritten."} diff --git a/scram/route_manager/views.py b/scram/route_manager/views.py index 99cc2c52..fa8b7116 100644 --- a/scram/route_manager/views.py +++ b/scram/route_manager/views.py @@ -24,8 +24,10 @@ channel_layer = get_channel_layer() -def home_page(request, prefilter=Entry.objects.all()): +def home_page(request, 
prefilter=None): """Return the home page, autocreating a user if none exists.""" + if not prefilter: + prefilter = Entry.objects.all() num_entries = settings.RECENT_LIMIT if request.user.has_perms(("route_manager.view_entry", "route_manager.add_entry")): readwrite = True @@ -103,29 +105,27 @@ def add_entry(request): with transaction.atomic(): res = add_entry_api(request) - if res.status_code == 201: + if res.status_code == 201: # noqa: PLR2004 messages.add_message( request, messages.SUCCESS, "Entry successfully added", ) - elif res.status_code == 400: + elif res.status_code == 400: # noqa: PLR2004 errors = [] if isinstance(res.data, rest_framework.utils.serializer_helpers.ReturnDict): for k, v in res.data.items(): - for error in v: - errors.append(f"'{k}' error: {str(error)}") + errors.extend(f"'{k}' error: {error!s}" for error in v) else: - for k, v in res.data.items(): - errors.append(f"error: {str(v)}") + errors.extend(f"error: {v!s}" for v in res.data.values()) messages.add_message(request, messages.ERROR, "
".join(errors)) - elif res.status_code == 403: + elif res.status_code == 403: # noqa: PLR2004 messages.add_message(request, messages.ERROR, "Permission Denied") else: messages.add_message(request, messages.WARNING, f"Something went wrong: {res.status_code}") with transaction.atomic(): home = home_page(request) - return home + return home # noqa RET504 def process_expired(request): @@ -142,7 +142,7 @@ def process_expired(request): { "entries_deleted": entries_start - entries_end, "active_entries": entries_end, - } + }, ), content_type="application/json", ) @@ -154,7 +154,8 @@ class EntryListView(ListView): model = Entry template_name = "route_manager/entry_list.html" - def get_context_data(self, **kwargs): + @staticmethod + def get_context_data(**kwargs): """Group entries by action type.""" context = {"entries": {}} for at in ActionType.objects.all(): diff --git a/scram/users/api/serializers.py b/scram/users/api/serializers.py index a49520d0..f8d22c88 100644 --- a/scram/users/api/serializers.py +++ b/scram/users/api/serializers.py @@ -7,10 +7,10 @@ class UserSerializer(serializers.ModelSerializer): - """This serializer defines no new fields.""" + """Map to the User model.""" class Meta: - """Maps to the User model, and specifies the fields exposed by the API.""" + """Specify the fields exposed by the API.""" model = User fields = ["username", "name", "url"] diff --git a/scram/users/api/views.py b/scram/users/api/views.py index ce70a713..0e17cb54 100644 --- a/scram/users/api/views.py +++ b/scram/users/api/views.py @@ -23,8 +23,9 @@ def get_queryset(self, *args, **kwargs): """Query on User ID.""" return self.queryset.filter(id=self.request.user.id) + @staticmethod @action(detail=False, methods=["GET"]) - def me(self, request): + def me(request): """Return the current user.""" serializer = UserSerializer(request.user, context={"request": request}) return Response(status=status.HTTP_200_OK, data=serializer.data) diff --git a/scram/users/apps.py b/scram/users/apps.py 
index 449ebe39..e6cc9c6b 100644 --- a/scram/users/apps.py +++ b/scram/users/apps.py @@ -14,7 +14,8 @@ class UsersConfig(AppConfig): name = "scram.users" verbose_name = _("Users") - def ready(self): + @staticmethod + def ready(): """Check if signals are registered for User events.""" try: import scram.users.signals # noqa F401 diff --git a/scram/users/tests/factories.py b/scram/users/tests/factories.py index 1eca95c0..10331878 100644 --- a/scram/users/tests/factories.py +++ b/scram/users/tests/factories.py @@ -1,6 +1,7 @@ """Define Factory tests for the Users application.""" -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from django.contrib.auth import get_user_model from factory import Faker, post_generation diff --git a/scram/users/tests/test_forms.py b/scram/users/tests/test_forms.py index 4ff89f37..3fe65da4 100644 --- a/scram/users/tests/test_forms.py +++ b/scram/users/tests/test_forms.py @@ -13,12 +13,11 @@ class TestUserCreationForm: """Test class for all tests related to the UserCreationForm.""" def test_username_validation_error_msg(self, user: User): - """ - Tests UserCreation Form's unique validator functions correctly by testing 3 things. + """Tests UserCreation Form's unique validator functions correctly by testing 3 things. - 1) A new user with an existing username cannot be added. - 2) Only 1 error is raised by the UserCreation Form - 3) The desired error message is raised + 1) A new user with an existing username cannot be added. + 2) Only 1 error is raised by the UserCreation Form + 3) The desired error message is raised """ # The user already exists, # hence cannot be created. 
@@ -27,7 +26,7 @@ def test_username_validation_error_msg(self, user: User): "username": user.username, "password1": user.password, "password2": user.password, - } + }, ) assert not form.is_valid() diff --git a/scram/users/tests/test_views.py b/scram/users/tests/test_views.py index 63ad97bc..2df1929a 100644 --- a/scram/users/tests/test_views.py +++ b/scram/users/tests/test_views.py @@ -18,14 +18,14 @@ class TestUserUpdateView: - """ - Define tests related to the Update View. + """Define tests related to the Update View. - TODO: + Todo: extracting view initialization code as class-scoped fixture would be great if only pytest-django supported non-function-scoped fixture db access -- this is a work-in-progress for now: https://github.com/pytest-dev/pytest-django/pull/258 + """ def test_get_success_url(self, user: User, rf: RequestFactory): diff --git a/scram/utils/context_processors.py b/scram/utils/context_processors.py index 2a86dfa4..0bbf8786 100644 --- a/scram/utils/context_processors.py +++ b/scram/utils/context_processors.py @@ -4,7 +4,11 @@ def settings_context(_request): - """Define settings available by default to the templates context.""" + """Define settings available by default to the templates context. 
+ + Returns: + dict: Whether or not we have DEBUG on + """ # Note: we intentionally do NOT expose the entire settings # to prevent accidental leaking of sensitive information return {"DEBUG": settings.DEBUG} diff --git a/translator/exceptions.py b/translator/exceptions.py index a999ec8c..b6d2cfd6 100644 --- a/translator/exceptions.py +++ b/translator/exceptions.py @@ -1,4 +1,4 @@ -"""This module holds all of the exceptions we want to raise in our translators.""" +"""Define all of the exceptions we want to raise in our translators.""" class ASNError(TypeError): diff --git a/translator/gobgp.py b/translator/gobgp.py index 6c10e159..f2b8dee9 100644 --- a/translator/gobgp.py +++ b/translator/gobgp.py @@ -15,11 +15,15 @@ DEFAULT_COMMUNITY = 666 DEFAULT_V4_NEXTHOP = "192.0.2.199" DEFAULT_V6_NEXTHOP = "100::1" +MAX_SMALL_ASN = 2**16 +MAX_SMALL_COMM = 2**16 +IPV6 = 6 logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) -class GoBGP(object): +class GoBGP: """Represents a GoBGP instance.""" def __init__(self, url): @@ -27,14 +31,16 @@ def __init__(self, url): channel = grpc.insecure_channel(url) self.stub = gobgp_pb2_grpc.GobgpApiStub(channel) - def _get_family_AFI(self, ip_version): - if ip_version == 6: + @staticmethod + def _get_family_afi(ip_version): + if ip_version == IPV6: return gobgp_pb2.Family.AFI_IP6 - else: - return gobgp_pb2.Family.AFI_IP + return gobgp_pb2.Family.AFI_IP - def _build_path(self, ip, event_data={}): + def _build_path(self, ip, event_data=None): # noqa: PLR0914 # Grab ASN and Community from our event_data, or use the defaults + if not event_data: + event_data = {} asn = event_data.get("asn", DEFAULT_ASN) community = event_data.get("community", DEFAULT_COMMUNITY) ip_version = ip.ip.version @@ -47,7 +53,7 @@ def _build_path(self, ip, event_data={}): origin.Pack( attribute_pb2.OriginAttribute( origin=2, - ) + ), ) # IP prefix and its associated length @@ -56,27 +62,27 @@ def _build_path(self, ip, event_data={}): 
attribute_pb2.IPAddressPrefix( prefix_len=ip.network.prefixlen, prefix=str(ip.ip), - ) + ), ) # Set the next hop to the correct value depending on IP family next_hop = Any() - family_afi = self._get_family_AFI(ip_version) - if ip_version == 6: + family_afi = self._get_family_afi(ip_version) + if ip_version == IPV6: next_hops = event_data.get("next_hop", DEFAULT_V6_NEXTHOP) next_hop.Pack( attribute_pb2.MpReachNLRIAttribute( family=gobgp_pb2.Family(afi=family_afi, safi=gobgp_pb2.Family.SAFI_UNICAST), next_hops=[next_hops], nlris=[nlri], - ) + ), ) else: next_hops = event_data.get("next_hop", DEFAULT_V4_NEXTHOP) next_hop.Pack( attribute_pb2.NextHopAttribute( next_hop=next_hops, - ) + ), ) # Set our AS Path @@ -94,13 +100,13 @@ def _build_path(self, ip, event_data={}): communities = Any() # Standard community # Since we pack both into the community string we need to make sure they will both fit - if asn < 65536 and community < 65536: + if asn < MAX_SMALL_ASN and community < MAX_SMALL_COMM: # We bitshift ASN left by 16 so that there is room to add the community on the end of it. This is because # GoBGP wants the community sent as a single integer. comm_id = (asn << 16) + community communities.Pack(attribute_pb2.CommunitiesAttribute(communities=[comm_id])) else: - logging.info(f"LargeCommunity Used - ASN:{asn} Community: {community}") + logger.info("LargeCommunity Used - ASN: %s. 
Community: %s", asn, community) global_admin = asn local_data1 = community # set to 0 because there's no use case for it, but we need a local_data2 for gobgp to read any of it @@ -122,7 +128,7 @@ def _build_path(self, ip, event_data={}): def add_path(self, ip, event_data): """Announce a single route.""" - logging.info(f"Blocking {ip}") + logger.info("Blocking %s", ip) try: path = self._build_path(ip, event_data) @@ -131,17 +137,17 @@ def add_path(self, ip, event_data): _TIMEOUT_SECONDS, ) except ASNError as e: - logging.warning(f"ASN assertion failed with error: {e}") + logger.warning("ASN assertion failed with error: %s", e) def del_all_paths(self): """Remove all routes from being announced.""" - logging.warning("Withdrawing ALL routes") + logger.warning("Withdrawing ALL routes") self.stub.DeletePath(gobgp_pb2.DeletePathRequest(table_type=gobgp_pb2.GLOBAL), _TIMEOUT_SECONDS) def del_path(self, ip, event_data): """Remove a single route from being announced.""" - logging.info(f"Unblocking {ip}") + logger.info("Unblocking %s", ip) try: path = self._build_path(ip, event_data) self.stub.DeletePath( @@ -149,12 +155,16 @@ def del_path(self, ip, event_data): _TIMEOUT_SECONDS, ) except ASNError as e: - logging.warning(f"ASN assertion failed with error: {e}") + logger.warning("ASN assertion failed with error: %s", e) def get_prefixes(self, ip): - """Retrieve the routes that match a prefix and are announced.""" + """Retrieve the routes that match a prefix and are announced. + + Returns: + list: The routes that overlap with the prefix and are currently announced. 
+ """ prefixes = [gobgp_pb2.TableLookupPrefix(prefix=str(ip.ip))] - family_afi = self._get_family_AFI(ip.ip.version) + family_afi = self._get_family_afi(ip.ip.version) result = self.stub.ListPath( gobgp_pb2.ListPathRequest( table_type=gobgp_pb2.GLOBAL, diff --git a/translator/shared.py b/translator/shared.py index 27e89f97..78555b2d 100644 --- a/translator/shared.py +++ b/translator/shared.py @@ -1,11 +1,12 @@ -"""This module provides a location for code that we want to share between all translators.""" +"""Provide a location for code that we want to share between all translators.""" from exceptions import ASNError +MAX_ASN_VAL = 2**32 - 1 + def asn_is_valid(asn: int) -> bool: - """ - asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN. + """asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN. Args: asn (int): The Autonomous System Number that we want to validate @@ -15,11 +16,14 @@ def asn_is_valid(asn: int) -> bool: Returns: bool: _description_ + """ if not isinstance(asn, int): - raise ASNError(f"ASN {asn} is not an Integer, has type {type(asn)}") - if not 0 < asn < 4294967295: + msg = f"ASN {asn} is not an Integer, has type {type(asn)}" + raise ASNError(msg) + if not 0 < asn < MAX_ASN_VAL: # This is the max as stated in rfc6996 - raise ASNError(f"ASN {asn} is out of range. Must be between 0 and 4294967295") + msg = f"ASN {asn} is out of range. Must be between 0 and 4294967295" + raise ASNError(msg) return True diff --git a/translator/tests/acceptance/steps/actions.py b/translator/tests/acceptance/steps/actions.py index 1e03e8e1..fdb2a95f 100644 --- a/translator/tests/acceptance/steps/actions.py +++ b/translator/tests/acceptance/steps/actions.py @@ -27,17 +27,17 @@ def del_block(context, route, asn, community): def get_block_status(context, ip): - """Check if the IP is currently blocked.""" + """Check if the IP is currently blocked. + + Returns: + bool: The return value. 
True if the IP is currently blocked, False otherwise. + """ # Allow our add/delete requests to settle time.sleep(1) ip_obj = ipaddress.ip_interface(ip) - for path in context.gobgp.get_prefixes(ip_obj): - if ip_obj in ipaddress.ip_network(path.destination.prefix): - return True - - return False + return any(ip_obj in ipaddress.ip_network(path.destination.prefix) for path in context.gobgp.get_prefixes(ip_obj)) @capture diff --git a/translator/translator.py b/translator/translator.py index b5a7dd02..d2c34018 100644 --- a/translator/translator.py +++ b/translator/translator.py @@ -11,33 +11,41 @@ import websockets from gobgp import GoBGP +logger = logging.getLogger(__name__) + +KNOWN_MESSAGES = { + "translator_add", + "translator_remove", + "translator_remove_all", + "translator_check", +} + # Here we setup a debugger if this is desired. This obviously should not be run in production. debug_mode = os.environ.get("DEBUG") if debug_mode: def install_deps(): - """ - Install necessary dependencies for debuggers. + """Install necessary dependencies for debuggers. Because of how we build translator currently, we don't have a great way to selectively install things at build, so we just do it here! Right now this also includes base.txt, which is unecessary, but in the future when we build a little better, it'll already be setup. """ - logging.info("Installing dependencies for debuggers") + logger.info("Installing dependencies for debuggers") - import subprocess - import sys + import subprocess # noqa: S404, PLC0415 + import sys # noqa: PLC0415 - subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"]) + subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"]) # noqa: S603 TODO: add this to the container build - logging.info("Done installing dependencies for debuggers") + logger.info("Done installing dependencies for debuggers") - logging.info(f"Translator is set to use a debugger. 
Provided debug mode: {debug_mode}") + logger.info("Translator is set to use a debugger. Provided debug mode: %s", debug_mode) # We have to setup the debugger appropriately for various IDEs. It'd be nice if they all used the same thing but # sadly, we live in a fallen world. if debug_mode == "pycharm-pydevd": - logging.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") + logger.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!") install_deps() @@ -45,67 +53,67 @@ def install_deps(): pydevd_pycharm.settrace("host.docker.internal", port=56782, stdoutToServer=True, stderrToServer=True) - logging.info("Debugger started.") + logger.info("Debugger started.") elif debug_mode == "debugpy": - logging.info("Entering debug mode for debugpy (VSCode)") + logger.info("Entering debug mode for debugpy (VSCode)") install_deps() import debugpy - debugpy.listen(("0.0.0.0", 56781)) + debugpy.listen(("0.0.0.0", 56781)) # noqa S104 (doesn't like binding to all interfaces) - logging.info("Debugger listening on port 56781.") + logger.info("Debugger listening on port 56781.") else: - logging.warning(f"Invalid debug mode given: {debug_mode}. Debugger not started") + logger.warning("Invalid debug mode given: %s. Debugger not started", debug_mode) # Must match the URL in asgi.py, and needs a trailing slash hostname = os.environ.get("SCRAM_HOSTNAME", "scram_hostname_not_set") url = os.environ.get("SCRAM_EVENTS_URL", "ws://django:8000/ws/route_manager/translator_block/") +async def process(message, websocket, g): + """Take a single message from the websocket and hand it off to the appropriate function.""" + json_message = json.loads(message) + event_type = json_message.get("type") + event_data = json_message.get("message") + if event_type not in KNOWN_MESSAGES: + logger.error("Unknown event type received: %s", event_type) + # TODO: Maybe only allow this in testing?
+ elif event_type == "translator_remove_all": + g.del_all_paths() + else: + try: + ip = ipaddress.ip_interface(event_data["route"]) + except: # noqa E722 + logger.exception("Error parsing message: %s", message) + return + + if event_type == "translator_add": + g.add_path(ip, event_data) + elif event_type == "translator_remove": + g.del_path(ip, event_data) + elif event_type == "translator_check": + json_message["type"] = "translator_check_resp" + json_message["message"]["is_blocked"] = g.is_blocked(ip) + json_message["message"]["translator_name"] = hostname + await websocket.send(json.dumps(json_message)) + + async def main(): """Connect to the websocket and start listening for messages.""" g = GoBGP("gobgp:50051") async for websocket in websockets.connect(url): try: async for message in websocket: - json_message = json.loads(message) - event_type = json_message.get("type") - event_data = json_message.get("message") - if event_type not in [ - "translator_add", - "translator_remove", - "translator_remove_all", - "translator_check", - ]: - logging.error(f"Unknown event type received: {event_type!r}") - # TODO: Maybe only allow this in testing? 
- elif event_type == "translator_remove_all": - g.del_all_paths() - else: - try: - ip = ipaddress.ip_interface(event_data["route"]) - except: # noqa E722 - logging.error(f"Error parsing message: {message!r}") - continue - - if event_type == "translator_add": - g.add_path(ip, event_data) - elif event_type == "translator_remove": - g.del_path(ip, event_data) - elif event_type == "translator_check": - json_message["type"] = "translator_check_resp" - json_message["message"]["is_blocked"] = g.is_blocked(ip) - json_message["message"]["translator_name"] = hostname - await websocket.send(json.dumps(json_message)) + await process(message, websocket, g) except websockets.ConnectionClosed: continue if __name__ == "__main__": - logging.info("translator started") + logger.info("translator started") loop = asyncio.get_event_loop() loop.run_until_complete(main()) loop.close()