)
+[]()
+[]()
-:License: BSD
-
-====
-
-Overview
-====
+## Overview
SCRAM is a web based service to assist in automation of security data. There is a web interface as well as a REST API available.
+
The idea is to create actiontypes which allow you to take actions on the IPs/cidr networks you provide.
-Components
-====
+## Components
+
SCRAM utilizes ``docker compose`` to run the following stack in production:
- nginx (as a webserver and static asset server)
@@ -37,8 +26,7 @@ SCRAM utilizes ``docker compose`` to run the following stack in production:
A predefined actiontype of "block" exists which utilizes bgp nullrouting to effectivley block any traffic you want to apply.
You can add any other actiontypes via the admin page of the web interface dynamically, but keep in mind translator support would need to be added as well.
-Installation
-====
+## Installation
To get a basic implementation up and running locally:
@@ -61,8 +49,7 @@ To get a basic implementation up and running locally:
- ``make django-open``
-*** Copyright Notice ***
-====
+## Copyright Notice
Security Catch and Release Automation Manager (SCRAM) Copyright (c) 2022,
The Regents of the University of California, through Lawrence Berkeley
diff --git a/docs/__init__.py b/docs/__init__.py
index 8772c827..d34efb05 100644
--- a/docs/__init__.py
+++ b/docs/__init__.py
@@ -1 +1 @@
-# Included so that Django's startproject comment runs against the docs directory
+"""Included so that Django's startproject comment runs against the docs directory."""
diff --git a/docs/conf.py b/docs/conf.py
deleted file mode 100644
index d4d4bad7..00000000
--- a/docs/conf.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-
-import os
-import sys
-
-import django
-
-if os.getenv("READTHEDOCS", default=False) == "True":
- sys.path.insert(0, os.path.abspath(".."))
- os.environ["DJANGO_READ_DOT_ENV_FILE"] = "True"
- os.environ["USE_DOCKER"] = "no"
-else:
- sys.path.insert(0, os.path.abspath("/app"))
-os.environ["DATABASE_URL"] = "sqlite:///readthedocs.db"
-os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")
-django.setup()
-
-# -- Project information -----------------------------------------------------
-
-project = "SCRAM"
-copyright = """2021, Sam Oehlert"""
-author = "Sam Oehlert"
-
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- "sphinx.ext.autodoc",
- "sphinx.ext.napoleon",
-]
-
-# Add any paths that contain templates here, relative to this directory.
-# templates_path = ["_templates"]
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = "alabaster"
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-# html_static_path = ["_static"]
diff --git a/docs/cookiecutter.md b/docs/cookiecutter.md
deleted file mode 100644
index 17445af4..00000000
--- a/docs/cookiecutter.md
+++ /dev/null
@@ -1,71 +0,0 @@
-This documentation was initially provided in the README.rst from the cookiecutter used to generate the project
-
-Settings
---------
-
-Moved to settings_.
-
-.. _settings: http://cookiecutter-django.readthedocs.io/en/latest/settings.html
-
-Basic Commands
---------------
-
-Setting Up Your Users
-^^^^^^^^^^^^^^^^^^^^^
-
-* To create a **normal user account**, just go to Sign Up and fill out the form. Once you submit it, you'll see a "Verify Your E-mail Address" page. Go to your console to see a simulated email verification message. Copy the link into your browser. Now the user's email should be verified and ready to go.
-
-* To create an **superuser account**, use this command::
-
- $ python manage.py createsuperuser
-
-For convenience, you can keep your normal user logged in on Chrome and your superuser logged in on Firefox (or similar), so that you can see how the site behaves for both kinds of users.
-
-Type checks
-^^^^^^^^^^^
-
-Running type checks with mypy:
-
-::
-
- $ mypy scram
-
-Test coverage
-^^^^^^^^^^^^^
-
-To run the tests, check your test coverage, and generate an HTML coverage report::
-
- $ coverage run -m pytest
- $ coverage html
- $ open htmlcov/index.html
-
-Running tests with py.test
-~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-::
-
- $ pytest
-
-Live reloading and Sass CSS compilation
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Moved to `Live reloading and SASS compilation`_.
-
-.. _`Live reloading and SASS compilation`: http://cookiecutter-django.readthedocs.io/en/latest/live-reloading-and-sass-compilation.html
-
-
-
-Deployment
-----------
-
-The following details how to deploy this application.
-
-
-
-Docker
-^^^^^^
-
-See detailed `cookiecutter-django Docker documentation`_.
-
-.. _`cookiecutter-django Docker documentation`: http://cookiecutter-django.readthedocs.io/en/latest/deployment-with-docker.html
-
diff --git a/docs/make.bat b/docs/make.bat
deleted file mode 100644
index a22d92b7..00000000
--- a/docs/make.bat
+++ /dev/null
@@ -1,46 +0,0 @@
-@ECHO OFF
-
-pushd %~dp0
-
-REM Command file for Sphinx documentation
-
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build -c .
-)
-set SOURCEDIR=_source
-set BUILDDIR=_build
-set APP=..\scram
-
-if "%1" == "" goto help
-
-%SPHINXBUILD% >NUL 2>NUL
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.Install sphinx-autobuild for live serving.
- echo.If you don't have Sphinx installed, grab it from
- echo.http://sphinx-doc.org/
- exit /b 1
-)
-
-%SPHINXBUILD% -b %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-goto end
-
-:livehtml
-sphinx-autobuild -b html --open-browser -p 7000 --watch %APP% -c . %SOURCEDIR% %BUILDDIR%/html
-GOTO :EOF
-
-:apidocs
-sphinx-apidoc -o %SOURCEDIR%/api %APP%
-GOTO :EOF
-
-:help
-%SPHINXBUILD% -b help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-
-:end
-popd
diff --git a/docs/reference/django.md b/docs/reference/django.md
new file mode 100644
index 00000000..be734554
--- /dev/null
+++ b/docs/reference/django.md
@@ -0,0 +1,3 @@
+# Route Manager
+
+::: scram.route_manager
diff --git a/docs/reference/index.md b/docs/reference/index.md
new file mode 100644
index 00000000..fc771185
--- /dev/null
+++ b/docs/reference/index.md
@@ -0,0 +1,7 @@
+# Module Reference
+
+The SCRAM ecosystem consists of two parts:
+
+A Django app, [route_manager](/reference/django)
+
+A translator service, [translator](/reference/translator)
diff --git a/docs/reference/translator.md b/docs/reference/translator.md
new file mode 100644
index 00000000..897d84e2
--- /dev/null
+++ b/docs/reference/translator.md
@@ -0,0 +1,3 @@
+# Translator
+
+::: translator
diff --git a/manage.py b/manage.py
index 37181fe1..cbc80d8d 100755
--- a/manage.py
+++ b/manage.py
@@ -1,27 +1,25 @@
#!/usr/bin/env python
+"""Django's command-line utility for administrative tasks."""
+
import os
import sys
from pathlib import Path
-if __name__ == "__main__":
- os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")
+def main():
+ """Run administrative tasks."""
+ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")
try:
- from django.core.management import execute_from_command_line
- except ImportError:
- # The above import may fail for some other reason. Ensure that the
- # issue is really that Django is missing to avoid masking other
- # exceptions on Python 2.
- try:
- import django # noqa
- except ImportError:
- raise ImportError(
- "Couldn't import Django. Are you sure it's installed and "
- "available on your PYTHONPATH environment variable? Did you "
- "forget to activate a virtual environment?"
- )
-
- raise
+ from django.core.management import execute_from_command_line # noqa: PLC0415
+ except ImportError as exc:
+ msg = (
+ "Couldn't import Django. Are you sure it's installed and "
+ "available on your PYTHONPATH environment variable? Did you "
+ "forget to activate a virtual environment?"
+ )
+ raise ImportError(
+ msg,
+ ) from exc
# This allows easy placement of apps within the interior
# scram directory.
@@ -29,3 +27,7 @@
sys.path.append(str(current_path / "scram"))
execute_from_command_line(sys.argv)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/merge_production_dotenvs_in_dotenv.py b/merge_production_dotenvs_in_dotenv.py
deleted file mode 100644
index b66558c3..00000000
--- a/merge_production_dotenvs_in_dotenv.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import os
-from pathlib import Path
-from typing import Sequence
-
-import pytest
-
-ROOT_DIR_PATH = Path(__file__).parent.resolve()
-PRODUCTION_DOTENVS_DIR_PATH = ROOT_DIR_PATH / ".envs" / ".production"
-PRODUCTION_DOTENV_FILE_PATHS = [
- PRODUCTION_DOTENVS_DIR_PATH / ".django",
- PRODUCTION_DOTENVS_DIR_PATH / ".postgres",
-]
-DOTENV_FILE_PATH = ROOT_DIR_PATH / ".env"
-
-
-def merge(output_file_path: str, merged_file_paths: Sequence[str], append_linesep: bool = True) -> None:
- with open(output_file_path, "w") as output_file:
- for merged_file_path in merged_file_paths:
- with open(merged_file_path, "r") as merged_file:
- merged_file_content = merged_file.read()
- output_file.write(merged_file_content)
- if append_linesep:
- output_file.write(os.linesep)
-
-
-def main():
- merge(DOTENV_FILE_PATH, PRODUCTION_DOTENV_FILE_PATHS)
-
-
-@pytest.mark.parametrize("merged_file_count", range(3))
-@pytest.mark.parametrize("append_linesep", [True, False])
-def test_merge(tmpdir_factory, merged_file_count: int, append_linesep: bool):
- tmp_dir_path = Path(str(tmpdir_factory.getbasetemp()))
-
- output_file_path = tmp_dir_path / ".env"
-
- expected_output_file_content = ""
- merged_file_paths = []
- for i in range(merged_file_count):
- merged_file_ord = i + 1
-
- merged_filename = ".service{}".format(merged_file_ord)
- merged_file_path = tmp_dir_path / merged_filename
-
- merged_file_content = merged_filename * merged_file_ord
-
- with open(merged_file_path, "w+") as file:
- file.write(merged_file_content)
-
- expected_output_file_content += merged_file_content
- if append_linesep:
- expected_output_file_content += os.linesep
-
- merged_file_paths.append(merged_file_path)
-
- merge(output_file_path, merged_file_paths, append_linesep)
-
- with open(output_file_path, "r") as output_file:
- actual_output_file_content = output_file.read()
-
- assert actual_output_file_content == expected_output_file_content
-
-
-if __name__ == "__main__":
- main()
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 00000000..704db80c
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,16 @@
+site_name: Security Catch and Release Automation Manager
+
+theme:
+ name: "material"
+
+plugins:
+ - mkdocstrings:
+ handlers:
+ python:
+ import:
+ - url: https://docs.djangoproject.com/en/4.2/_objects/
+ base_url: https://docs.djangoproject.com/en/4.2/
+ options:
+ show_submodules: true
+ - search
+ - section-index
diff --git a/pyproject.toml b/pyproject.toml
index 910571d6..5c332104 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,23 +25,91 @@ exclude_also = [
"if __name__ == .__main__.:",
]
-# ==== black ====
-[tool.black]
+# ===== ruff ====
+[tool.ruff]
+exclude = [
+ "migrations",
+]
+
line-length = 119
-target-version = ['py311']
+target-version = 'py312'
+preview = true
+
+[tool.ruff.lint]
+select = [
+ "A", # builtins
+ "ASYNC", # async
+ "B", # bugbear
+ "BLE", # blind-except
+ "C4", # comprehensions
+ "C90", # complexity
+ "COM", # commas
+ "D", # pydocstyle
+ "DJ", # django
+ "DOC", # pydoclint
+ "DTZ", # datetimez
+ "E", # pycodestyle
+ "EM", # errmsg
+ "ERA", # eradicate
+ "F", # pyflakes
+ "FBT", # boolean-trap
+ "FLY", # flynt
+ "G", # logging-format
+ "I", # isort
+ "ICN", # import-conventions
+ "ISC", # implicit-str-concat
+ "LOG", # logging
+ "N", # pep8-naming
+ "PERF", # perflint
+ "PIE", # pie
+ "PL", # pylint
+ "PTH", # use-pathlib
+ "Q", # quotes
+ "RET", # return
+ "RSE", # raise
+ "RUF", # Ruff
+ "S", # bandit
+ "SIM", # simplify
+ "SLF", # self
+ "SLOT", # slots
+ "T20", # print
+ "TRY", # tryceratops
+ "UP", # pyupgrade
+]
+ignore = [
+ "COM812", # handled by the formatter
+ "DOC501", # add possible exceptions to the docstring (TODO)
+ "ISC001", # handled by the formatter
+ "RUF012", # need more widespread typing
+ "SIM102", # Use a single `if` statement instead of nested `if` statements
+ "SIM108", # Use ternary operator instead of `if`-`else`-block
+]
+[tool.ruff.lint.mccabe]
+max-complexity = 7 # our current code adheres to this without too much effort
-# ==== isort ====
-[tool.isort]
-profile = "black"
-line_length = 119
-known_first_party = [
- "scram",
- "config",
+[tool.ruff.lint.per-file-ignores]
+"**/{tests}/*" = [
+ "DOC201", # documenting return values
+ "DOC402", # documenting yield values
+ "PLR6301", # could be a static method
+ "S101", # use of assert
+ "S106", # hardcoded password
+]
+"test.py" = [
+ "S105", # hardcoded password as argument
+]
+"scram/users/**" = [
+ "DOC201", # documenting return values
+ "FBT001", # minimal issue; don't need to mess with in the User app
+ "PLR2004", # magic values when checking HTTP status codes
+]
+"**/views.py" = [
+ "DOC201", # documenting return values; it's fairly obvious in a View
]
-skip = ["venv/"]
-skip_glob = ["**/migrations/*.py"]
+[tool.ruff.lint.pydocstyle]
+convention = "google"
# ==== mypy ====
[tool.mypy]
@@ -63,53 +131,3 @@ ignore_errors = true
[tool.django-stubs]
django_settings_module = "config.settings.test"
-
-
-# ==== PyLint ====
-[tool.pylint.MASTER]
-load-plugins = [
- "pylint_django",
-]
-django-settings-module = "config.settings.local"
-
-[tool.pylint.FORMAT]
-max-line-length = 119
-
-[tool.pylint."MESSAGES CONTROL"]
-disable = [
- "missing-docstring",
- "invalid-name",
-]
-
-[tool.pylint.DESIGN]
-max-parents = 13
-
-[tool.pylint.TYPECHECK]
-generated-members = [
- "REQUEST",
- "acl_users",
- "aq_parent",
- "[a-zA-Z]+_set{1,2}",
- "save",
- "delete",
-]
-
-
-# ==== djLint ====
-[tool.djlint]
-blank_line_after_tag = "load,extends"
-close_void_tags = true
-format_css = true
-format_js = true
-# TODO: remove T002 when fixed https://github.com/Riverside-Healthcare/djLint/issues/687
-ignore = "H006,H021,H023,H025,H029,H030,H031,T002,T003"
-include = "H017,H035"
-indent = 2
-max_line_length = 119
-profile = "django"
-
-[tool.djlint.css]
-indent_size = 2
-
-[tool.djlint.js]
-indent_size = 2
diff --git a/pytest.ini b/pytest.ini
deleted file mode 100644
index c2b3a233..00000000
--- a/pytest.ini
+++ /dev/null
@@ -1,3 +0,0 @@
-[pytest]
-addopts = --ds=config.settings.test --reuse-db
-python_files = tests.py test_*.py
diff --git a/requirements/base.txt b/requirements/base.txt
index 7577c200..b43b582a 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -22,8 +22,7 @@ django-simple-history~=3.1.1
# Django REST Framework
djangorestframework~=3.15 # https://github.com/encode/django-rest-framework
django-cors-headers==3.13.0 # https://github.com/adamchainz/django-cors-headers
-# DRF-spectacular for api documentation
-drf-spectacular # https://github.com/tfranzel/drf-spectacular
+drf-spectacular # https://github.com/tfranzel/drf-spectacular
# OIDC
# ------------------------------------------------------------------------------
diff --git a/requirements/local.txt b/requirements/local.txt
index 2a6c95ae..134a2213 100644
--- a/requirements/local.txt
+++ b/requirements/local.txt
@@ -2,7 +2,7 @@
Werkzeug[watchdog]==2.0.3 # https://github.com/pallets/werkzeug
ipdb==0.13.9 # https://github.com/gotcha/ipdb
-psycopg2-binary==2.9.3 # https://github.com/psycopg/psycopg2
+psycopg2-binary==2.9.10 # https://github.com/psycopg/psycopg2
watchgod==0.8.2 # https://github.com/samuelcolvin/watchgod
# Testing
@@ -13,8 +13,16 @@ behave-django==1.4.0 # https://github.com/behave/behave-django
# Documentation
# ------------------------------------------------------------------------------
-sphinx==5.0.1 # https://github.com/sphinx-doc/sphinx
-sphinx-autobuild==2021.3.14 # https://github.com/GaretJax/sphinx-autobuild
+mkdocs==1.6.1
+mkdocs-autorefs==1.2.0
+mkdocs-gen-files==0.5.0
+mkdocs-get-deps==0.2.0
+mkdocs-literate-nav==0.6.1
+mkdocs-material==9.5.44
+mkdocs-material-extensions==1.3.1
+mkdocs-section-index==0.3.9
+mkdocstrings==0.27.0
+mkdocstrings-python==1.12.2
# Code quality
# ------------------------------------------------------------------------------
diff --git a/scram/__init__.py b/scram/__init__.py
index d5feb05a..8f5a0220 100644
--- a/scram/__init__.py
+++ b/scram/__init__.py
@@ -1,4 +1,4 @@
"""The Django project for Security Catch and Release Automation Manager (SCRAM)."""
__version__ = "1.1.1"
-__version_info__ = tuple([int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split(".")])
+__version_info__ = tuple(int(num) if num.isdigit() else num for num in __version__.replace("-", ".", 1).split("."))
diff --git a/scram/contrib/__init__.py b/scram/contrib/__init__.py
index 88d92bd3..adb76b31 100644
--- a/scram/contrib/__init__.py
+++ b/scram/contrib/__init__.py
@@ -1,5 +1,4 @@
-"""
-To understand why this file is here, please read the CookieCutter documentation.
+"""To understand why this file is here, please read the CookieCutter documentation.
http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django
"""
diff --git a/scram/contrib/sites/__init__.py b/scram/contrib/sites/__init__.py
index 88d92bd3..adb76b31 100644
--- a/scram/contrib/sites/__init__.py
+++ b/scram/contrib/sites/__init__.py
@@ -1,5 +1,4 @@
-"""
-To understand why this file is here, please read the CookieCutter documentation.
+"""To understand why this file is here, please read the CookieCutter documentation.
http://cookiecutter-django.readthedocs.io/en/latest/faq.html#why-is-there-a-django-contrib-sites-directory-in-cookiecutter-django
"""
diff --git a/scram/route_manager/api/exceptions.py b/scram/route_manager/api/exceptions.py
index 74d6f2e2..0445dbe4 100644
--- a/scram/route_manager/api/exceptions.py
+++ b/scram/route_manager/api/exceptions.py
@@ -10,7 +10,7 @@ class PrefixTooLarge(APIException):
v4_min_prefix = getattr(settings, "V4_MINPREFIX", 0)
v6_min_prefix = getattr(settings, "V6_MINPREFIX", 0)
status_code = 400
- default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: 501
+ default_detail = f"You've supplied too large of a network. settings.V4_MINPREFIX = {v4_min_prefix} settings.V6_MINPREFIX = {v6_min_prefix}" # noqa: E501
default_code = "prefix_too_large"
diff --git a/scram/route_manager/api/serializers.py b/scram/route_manager/api/serializers.py
index 6baf7d95..f1bc1907 100644
--- a/scram/route_manager/api/serializers.py
+++ b/scram/route_manager/api/serializers.py
@@ -2,6 +2,7 @@
import logging
+from drf_spectacular.utils import extend_schema_field
from netfields import rest_framework
from rest_framework import serializers
from rest_framework.fields import CurrentUserDefault
@@ -12,8 +13,13 @@
logger = logging.getLogger(__name__)
+@extend_schema_field(field={"type": "string", "format": "cidr"})
+class CustomCidrAddressField(rest_framework.CidrAddressField):
+ """Define a wrapper field so swagger can properly handle the inherited field."""
+
+
class ActionTypeSerializer(serializers.ModelSerializer):
- """This serializer defines no new fields."""
+ """Map the serializer to the model via Meta."""
class Meta:
"""Maps to the ActionType model, and specifies the fields exposed by the API."""
@@ -25,7 +31,7 @@ class Meta:
class RouteSerializer(serializers.ModelSerializer):
"""Exposes route as a CIDR field."""
- route = rest_framework.CidrAddressField()
+ route = CustomCidrAddressField()
class Meta:
"""Maps to the Route model, and specifies the fields exposed by the API."""
@@ -37,7 +43,7 @@ class Meta:
class ClientSerializer(serializers.ModelSerializer):
- """This serializer defines no new fields."""
+ """Map the serializer to the model via Meta."""
class Meta:
"""Maps to the Client model, and specifies the fields exposed by the API."""
@@ -50,9 +56,11 @@ class EntrySerializer(serializers.HyperlinkedModelSerializer):
"""Due to the use of ForeignKeys, this follows some relationships to make sense via the API."""
url = serializers.HyperlinkedIdentityField(
- view_name="api:v1:entry-detail", lookup_url_kwarg="pk", lookup_field="route"
+ view_name="api:v1:entry-detail",
+ lookup_url_kwarg="pk",
+ lookup_field="route",
)
- route = rest_framework.CidrAddressField()
+ route = CustomCidrAddressField()
actiontype = serializers.CharField(default="block")
if CurrentUserDefault():
# This is set if we are calling this serializer from WUI
@@ -67,28 +75,36 @@ class Meta:
model = Entry
fields = ["route", "actiontype", "url", "comment", "who"]
- def get_comment(self, obj):
- """Provide a nicer name for change reason."""
+ @staticmethod
+ def get_comment(obj):
+ """Provide a nicer name for change reason.
+
+ Returns:
+ string: The change reason that modified the Entry.
+ """
return obj.get_change_reason()
- def create(self, validated_data):
- """Implement custom logic and validates creating a new route."""
+ @staticmethod
+ def create(validated_data):
+ """Implement custom logic and validates creating a new route.""" # noqa: DOC201
valid_route = validated_data.pop("route")
actiontype = validated_data.pop("actiontype")
comment = validated_data.pop("comment")
- route_instance, created = Route.objects.get_or_create(route=valid_route)
+ route_instance, _ = Route.objects.get_or_create(route=valid_route)
actiontype_instance = ActionType.objects.get(name=actiontype)
- entry_instance, created = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance)
+ entry_instance, _ = Entry.objects.get_or_create(route=route_instance, actiontype=actiontype_instance)
- logger.debug(f"{comment}")
+ logger.debug("Created entry with comment: %s", comment)
update_change_reason(entry_instance, comment)
return entry_instance
class IgnoreEntrySerializer(serializers.ModelSerializer):
- """This serializer defines no new fields."""
+ """Map the route to the right field type."""
+
+ route = CustomCidrAddressField()
class Meta:
"""Maps to the IgnoreEntry model, and specifies the fields exposed by the API."""
diff --git a/scram/route_manager/api/views.py b/scram/route_manager/api/views.py
index ad26127e..70d30b34 100644
--- a/scram/route_manager/api/views.py
+++ b/scram/route_manager/api/views.py
@@ -10,6 +10,7 @@
from django.db.models import Q
from django.http import Http404
from django.utils.dateparse import parse_datetime
+from drf_spectacular.utils import extend_schema
from rest_framework import status, viewsets
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
@@ -19,8 +20,13 @@
from .serializers import ActionTypeSerializer, ClientSerializer, EntrySerializer, IgnoreEntrySerializer
channel_layer = get_channel_layer()
+logger = logging.getLogger(__name__)
+@extend_schema(
+ description="API endpoint for actiontypes",
+ responses={200: ActionTypeSerializer},
+)
class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet):
"""Lookup ActionTypes by name when authenticated, and bind to the serializer."""
@@ -30,6 +36,10 @@ class ActionTypeViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = "name"
+@extend_schema(
+ description="API endpoint for ignore entries",
+ responses={200: IgnoreEntrySerializer},
+)
class IgnoreEntryViewSet(viewsets.ModelViewSet):
"""Lookup IgnoreEntries by route when authenticated, and bind to the serializer."""
@@ -39,6 +49,10 @@ class IgnoreEntryViewSet(viewsets.ModelViewSet):
lookup_field = "route"
+@extend_schema(
+ description="API endpoint for clients",
+ responses={200: ClientSerializer},
+)
class ClientViewSet(viewsets.ModelViewSet):
"""Lookup Client by hostname on POSTs regardless of authentication, and bind to the serializer."""
@@ -50,6 +64,10 @@ class ClientViewSet(viewsets.ModelViewSet):
http_method_names = ["post"]
+@extend_schema(
+ description="API endpoint for entries",
+ responses={200: EntrySerializer},
+)
class EntryViewSet(viewsets.ModelViewSet):
"""Lookup Entry when authenticated, and bind to the serializer."""
@@ -60,8 +78,7 @@ class EntryViewSet(viewsets.ModelViewSet):
http_method_names = ["get", "post", "head", "delete"]
def get_permissions(self):
- """
- Override the permissions classes for POST method since we want to accept Entry creates from any client.
+ """Override the permissions classes for POST method since we want to accept Entry creates from any client.
Note: We make authorization decisions on whether to actually create the object in the perform_create method
later.
@@ -70,6 +87,31 @@ def get_permissions(self):
return [AllowAny()]
return super().get_permissions()
+ def check_client_authorization(self, actiontype):
+ """Ensure that a given client is authorized to use a given actiontype."""
+ uuid = self.request.data.get("uuid")
+ if uuid:
+ authorized_actiontypes = Client.objects.filter(uuid=uuid).values_list(
+ "authorized_actiontypes__name",
+ flat=True,
+ )
+ authorized_client = Client.objects.filter(uuid=uuid).values("is_authorized")
+ if not authorized_client or actiontype not in authorized_actiontypes:
+ logger.debug("Client: %s, actiontypes: %s", uuid, authorized_actiontypes)
+ logger.info("%s is not allowed to add an entry to the %s list.", uuid, actiontype)
+ raise ActiontypeNotAllowed
+ elif not self.request.user.has_perm("route_manager.can_add_entry"):
+ raise PermissionDenied
+
+ @staticmethod
+ def check_ignore_list(route):
+ """Ensure that we're not trying to block something from the ignore list."""
+ overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route)
+ if overlapping_ignore.count():
+ ignore_entries = [str(ignore_entry["route"]) for ignore_entry in overlapping_ignore.values()]
+ logger.info("Cannot proceed adding %s. The ignore list contains %s.", route, ignore_entries)
+ raise IgnoredRoute
+
def perform_create(self, serializer):
"""Create a new Entry, causing that route to receive the actiontype (i.e. block)."""
actiontype = serializer.validated_data["actiontype"]
@@ -84,60 +126,41 @@ def perform_create(self, serializer):
tmp_exp = self.request.data.get("expiration", "")
try:
- expiration = parse_datetime(tmp_exp) # noqa: F841
+ expiration = parse_datetime(tmp_exp)
except ValueError:
- logging.info(f"Could not parse expiration DateTime string {tmp_exp!r}.")
+ logger.warning("Could not parse expiration DateTime string: %s", tmp_exp)
# Make sure we put in an acceptable sized prefix
min_prefix = getattr(settings, f"V{route.version}_MINPREFIX", 0)
if route.prefixlen < min_prefix:
- raise PrefixTooLarge()
-
- # Make sure this client is authorized to add this entry with this actiontype
- if self.request.data.get("uuid"):
- client_uuid = self.request.data["uuid"]
- authorized_actiontypes = Client.objects.filter(uuid=client_uuid).values_list(
- "authorized_actiontypes__name", flat=True
+ raise PrefixTooLarge
+
+ self.check_client_authorization(actiontype)
+ self.check_ignore_list(route)
+
+ elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num")
+ if not elements:
+ logger.warning("No elements found for actiontype: %s", actiontype)
+
+ for element in elements:
+ msg = element.websocketmessage
+ msg.msg_data[msg.msg_data_route_field] = str(route)
+ # Must match a channel name defined in asgi.py
+ async_to_sync(channel_layer.group_send)(
+ f"translator_{actiontype}",
+ {"type": msg.msg_type, "message": msg.msg_data},
)
- authorized_client = Client.objects.filter(uuid=client_uuid).values("is_authorized")
- if not authorized_client or actiontype not in authorized_actiontypes:
- logging.debug(f"Client {client_uuid} actiontypes: {authorized_actiontypes}")
- logging.info(f"{client_uuid} is not allowed to add an entry to the {actiontype} list")
- raise ActiontypeNotAllowed()
- elif not self.request.user.has_perm("route_manager.can_add_entry"):
- raise PermissionDenied()
- # Don't process if we have the entry in the ignorelist
- overlapping_ignore = IgnoreEntry.objects.filter(route__net_overlaps=route)
- if overlapping_ignore.count():
- ignore_entries = []
- for ignore_entry in overlapping_ignore.values():
- ignore_entries.append(str(ignore_entry["route"]))
- logging.info(f"Cannot proceed adding {route}. The ignore list contains {ignore_entries}")
- raise IgnoredRoute
- else:
- elements = WebSocketSequenceElement.objects.filter(action_type__name=actiontype).order_by("order_num")
- if not elements:
- logging.warning(f"No elements found for actiontype={actiontype}.")
-
- for element in elements:
- msg = element.websocketmessage
- msg.msg_data[msg.msg_data_route_field] = str(route)
- # Must match a channel name defined in asgi.py
- async_to_sync(channel_layer.group_send)(
- f"translator_{actiontype}", {"type": msg.msg_type, "message": msg.msg_data}
- )
-
- serializer.save()
-
- entry = Entry.objects.get(route__route=route, actiontype__name=actiontype)
- if expiration:
- entry.expiration = expiration
- entry.who = who
- entry.is_active = True
- entry.comment = comment
- logging.info(f"{entry}")
- entry.save()
+ serializer.save()
+
+ entry = Entry.objects.get(route__route=route, actiontype__name=actiontype)
+ if expiration:
+ entry.expiration = expiration
+ entry.who = who
+ entry.is_active = True
+ entry.comment = comment
+ logger.info("Created entry: %s", entry)
+ entry.save()
@staticmethod
def find_entries(arg, active_filter=None):
@@ -149,13 +172,13 @@ def find_entries(arg, active_filter=None):
try:
pk = int(arg)
query = Q(pk=pk)
- except ValueError:
+ except ValueError as exc:
# Maybe a CIDR? We want the ValueError at this point, if not.
cidr = ipaddress.ip_network(arg, strict=False)
min_prefix = getattr(settings, f"V{cidr.version}_MINPREFIX", 0)
if cidr.prefixlen < min_prefix:
- raise PrefixTooLarge()
+ raise PrefixTooLarge from exc
query = Q(route__route__net_overlaps=cidr)
diff --git a/scram/route_manager/authentication_backends.py b/scram/route_manager/authentication_backends.py
index 30288520..1831d788 100644
--- a/scram/route_manager/authentication_backends.py
+++ b/scram/route_manager/authentication_backends.py
@@ -5,48 +5,45 @@
from mozilla_django_oidc.auth import OIDCAuthenticationBackend
+def groups_overlap(a, b):
+ """Helper function to see if a and b have any overlap.
+
+ Returns:
+ bool: True if there's any overlap between a and b.
+ """
+ return not set(a).isdisjoint(b)
+
+
class ESnetAuthBackend(OIDCAuthenticationBackend):
"""Extend the OIDC backend with a custom permission model."""
- def update_groups(self, user, claims):
+ @staticmethod
+ def update_groups(user, claims):
"""Set the user's group(s) to whatever is in the claims."""
- claimed_groups = claims.get("groups", [])
-
effective_groups = []
- is_admin = False
-
- ro_group = Group.objects.get(name="readonly")
- rw_group = Group.objects.get(name="readwrite")
-
- for g in claimed_groups:
- # If any of the user's groups are in DENIED_GROUPS, deny them and stop processing immediately
- if g in settings.SCRAM_DENIED_GROUPS:
- effective_groups = []
- is_admin = False
- break
-
- if g in settings.SCRAM_ADMIN_GROUPS:
- is_admin = True
-
- if g in settings.SCRAM_READONLY_GROUPS:
- if ro_group not in effective_groups:
- effective_groups.append(ro_group)
+ claimed_groups = claims.get("groups", [])
- if g in settings.SCRAM_READWRITE_GROUPS:
- if rw_group not in effective_groups:
- effective_groups.append(rw_group)
+ if groups_overlap(claimed_groups, settings.SCRAM_DENIED_GROUPS):
+ is_admin = False
+ # Don't even look at anything else if they're denied
+ else:
+ is_admin = groups_overlap(claimed_groups, settings.SCRAM_ADMIN_GROUPS)
+ if groups_overlap(claimed_groups, settings.SCRAM_READWRITE_GROUPS):
+ effective_groups.append(Group.objects.get(name="readwrite"))
+ if groups_overlap(claimed_groups, settings.SCRAM_READONLY_GROUPS):
+ effective_groups.append(Group.objects.get(name="readonly"))
user.groups.set(effective_groups)
user.is_staff = user.is_superuser = is_admin
user.save()
def create_user(self, claims):
- """Wrap the superclass's user creation."""
- user = super(ESnetAuthBackend, self).create_user(claims)
+ """Wrap the superclass's user creation.""" # noqa: DOC201
+ user = super().create_user(claims)
return self.update_user(user, claims)
def update_user(self, user, claims):
- """Determine the user name from the claims and update said user's groups."""
+ """Determine the user name from the claims and update said user's groups.""" # noqa: DOC201
user.name = claims.get("given_name", "") + " " + claims.get("family_name", "")
user.username = claims.get("preferred_username", "")
if claims.get("groups", False):
diff --git a/scram/route_manager/context_processors.py b/scram/route_manager/context_processors.py
index 61abde75..0daf06c9 100644
--- a/scram/route_manager/context_processors.py
+++ b/scram/route_manager/context_processors.py
@@ -5,7 +5,11 @@
def login_logout(request):
- """Pass through the relevant URLs from the settings."""
+ """Pass through the relevant URLs from the settings.
+
+ Returns:
+ dict: login and logout URLs
+ """
login_url = reverse(settings.LOGIN_URL)
logout_url = reverse(settings.LOGOUT_URL)
return {"login": login_url, "logout": logout_url}
diff --git a/scram/route_manager/models.py b/scram/route_manager/models.py
index 3bbf67e3..aa72f7ba 100644
--- a/scram/route_manager/models.py
+++ b/scram/route_manager/models.py
@@ -10,6 +10,8 @@
from netfields import CidrAddressField
from simple_history.models import HistoricalRecords
+logger = logging.getLogger(__name__)
+
class Route(models.Model):
"""Define a route as a CIDR route and a UUID."""
@@ -17,14 +19,15 @@ class Route(models.Model):
route = CidrAddressField(unique=True)
uuid = models.UUIDField(db_index=True, default=uuid_lib.uuid4, editable=False)
- def get_absolute_url(self):
- """Ensure we use UUID on the API side instead."""
- return reverse("")
-
def __str__(self):
- """Don't display the UUID, only the route."""
+ """Don't display the UUID, only the route.""" # noqa: DOC201
return str(self.route)
+ @staticmethod
+ def get_absolute_url():
+ """Ensure we use UUID on the API side instead.""" # noqa: DOC201
+ return reverse("")
+
class ActionType(models.Model):
"""Define a type of action that can be done with a given route. e.g. Block, shunt, redirect, etc."""
@@ -34,7 +37,7 @@ class ActionType(models.Model):
history = HistoricalRecords()
def __str__(self):
- """Display clearly whether the action is currently available."""
+ """Display clearly whether the action is currently available.""" # noqa: DOC201
if not self.available:
return f"{self.name} (Inactive)"
return self.name
@@ -52,7 +55,7 @@ class WebSocketMessage(models.Model):
)
def __str__(self):
- """Display clearly what the fields are used for."""
+ """Display clearly what the fields are used for.""" # noqa: DOC201
return f"{self.msg_type}: {self.msg_data} with the route in key {self.msg_data_route_field}"
@@ -62,7 +65,7 @@ class WebSocketSequenceElement(models.Model):
websocketmessage = models.ForeignKey("WebSocketMessage", on_delete=models.CASCADE)
order_num = models.SmallIntegerField(
"Sequences are sent from the smallest order_num to the highest. "
- + "Messages with the same order_num could be sent in any order",
+ "Messages with the same order_num could be sent in any order",
default=0,
)
@@ -76,10 +79,10 @@ class WebSocketSequenceElement(models.Model):
action_type = models.ForeignKey("ActionType", on_delete=models.CASCADE)
def __str__(self):
- """Summarize the fields into something short and readable."""
+ """Summarize the fields into something short and readable.""" # noqa: DOC201
return (
f"{self.websocketmessage} as order={self.order_num} for "
- + f"{self.verb} actions on actiontype={self.action_type}"
+ f"{self.verb} actions on actiontype={self.action_type}"
)
@@ -88,7 +91,7 @@ class Entry(models.Model):
route = models.ForeignKey("Route", on_delete=models.PROTECT)
actiontype = models.ForeignKey("ActionType", on_delete=models.PROTECT)
- comment = models.TextField(blank=True, null=True)
+ comment = models.TextField(blank=True, default="")
is_active = models.BooleanField(default=True)
# TODO: fix name if this works
history = HistoricalRecords()
@@ -98,30 +101,10 @@ class Entry(models.Model):
expiration_reason = models.CharField(
help_text="Optional reason for the expiration",
max_length=200,
- null=True,
blank=True,
+ default="",
)
- def delete(self, *args, **kwargs):
- """Set inactive instead of deleting, as we want to ensure a history of entries."""
- if not self.is_active:
- # We've already expired this route, don't send another message
- return
- else:
- # We don't actually delete records; we set them to inactive and then tell the translator to remove them
- logging.info(f"Deactivating {self.route}")
- self.is_active = False
- self.save()
-
- # Unblock it
- async_to_sync(channel_layer.group_send)(
- f"translator_{self.actiontype}",
- {
- "type": "translator_remove",
- "message": {"route": str(self.route)},
- },
- )
-
class Meta:
"""Ensure that multiple routes can be added as long as they have different action types."""
@@ -129,14 +112,37 @@ class Meta:
verbose_name_plural = "Entries"
def __str__(self):
- """Summarize the most important fields to something easily readable."""
+ """Summarize the most important fields to something easily readable.""" # noqa: DOC201
desc = f"{self.route} ({self.actiontype})"
if not self.is_active:
desc += " (inactive)"
return desc
+ def delete(self, *args, **kwargs):
+ """Set inactive instead of deleting, as we want to ensure a history of entries."""
+ if not self.is_active:
+ # We've already expired this route, don't send another message
+ return
+ # We don't actually delete records; we set them to inactive and then tell the translator to remove them
+ logger.info("Deactivating %s", self.route)
+ self.is_active = False
+ self.save()
+
+ # Unblock it
+ async_to_sync(channel_layer.group_send)(
+ f"translator_{self.actiontype}",
+ {
+ "type": "translator_remove",
+ "message": {"route": str(self.route)},
+ },
+ )
+
def get_change_reason(self):
- """Traverse come complex relationships to determine the most recent change reason."""
+ """Traverse some complex relationships to determine the most recent change reason.
+
+ Returns:
+ str: The most recent change reason
+ """
hist_mgr = getattr(self, self._meta.simple_history_manager_attribute)
return hist_mgr.order_by("-history_date").first().history_change_reason
@@ -154,7 +160,7 @@ class Meta:
verbose_name_plural = "Ignored Entries"
def __str__(self):
- """Only display the route."""
+ """Only display the route.""" # noqa: DOC201
return str(self.route)
@@ -168,7 +174,7 @@ class Client(models.Model):
authorized_actiontypes = models.ManyToManyField(ActionType)
def __str__(self):
- """Only display the hostname."""
+ """Only display the hostname.""" # noqa: DOC201
return str(self.hostname)
diff --git a/scram/route_manager/tests/acceptance/steps/common.py b/scram/route_manager/tests/acceptance/steps/common.py
index c7c03f8c..f688bd64 100644
--- a/scram/route_manager/tests/acceptance/steps/common.py
+++ b/scram/route_manager/tests/acceptance/steps/common.py
@@ -3,34 +3,35 @@
import datetime
import time
-import django.conf as conf
from asgiref.sync import async_to_sync
from behave import given, step, then, when
from channels.layers import get_channel_layer
+from django import conf
from django.urls import reverse
from scram.route_manager.models import ActionType, Client, WebSocketMessage, WebSocketSequenceElement
@given("a {name} actiontype is defined")
-def step_impl(context, name):
+def create_actiontype(context, name):
"""Create an actiontype of that name."""
context.channel_layer = get_channel_layer()
async_to_sync(context.channel_layer.group_send)(
- f"translator_{name}", {"type": "translator_remove_all", "message": {}}
+ f"translator_{name}",
+ {"type": "translator_remove_all", "message": {}},
)
- at, created = ActionType.objects.get_or_create(name=name)
- wsm, created = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
+ at, _ = ActionType.objects.get_or_create(name=name)
+ wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
wsm.save()
- wsse, created = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at)
+ wsse, _ = WebSocketSequenceElement.objects.get_or_create(websocketmessage=wsm, verb="A", action_type=at)
wsse.save()
@given("a client with {name} authorization")
-def step_impl(context, name):
+def create_authed_client(context, name):
"""Create a client and authorize it for that action type."""
- at, created = ActionType.objects.get_or_create(name=name)
+ at, _ = ActionType.objects.get_or_create(name=name)
authorized_client = Client.objects.create(
hostname="authorized_client.es.net",
uuid="0e7e1cbd-7d73-4968-bc4b-ce3265dc2fd3",
@@ -40,7 +41,7 @@ def step_impl(context, name):
@given("a client without {name} authorization")
-def step_impl(context, name):
+def create_unauthed_client(context, name):
"""Create a client that has no authorized action types."""
unauthorized_client = Client.objects.create(
hostname="unauthorized_client.es.net",
@@ -50,26 +51,26 @@ def step_impl(context, name):
@when("we're logged in")
-def step_impl(context):
+def login(context):
"""Login."""
context.test.client.login(username="user", password="password")
@when("the CIDR prefix limits are {v4_minprefix:d} and {v6_minprefix:d}")
-def step_impl(context, v4_minprefix, v6_minprefix):
+def set_cidr_limit(context, v4_minprefix, v6_minprefix):
"""Override our settings with the provided values."""
conf.settings.V4_MINPREFIX = v4_minprefix
conf.settings.V6_MINPREFIX = v6_minprefix
@then("we get a {status_code:d} status code")
-def step_impl(context, status_code):
+def check_status_code(context, status_code):
"""Ensure the status code response matches the expected value."""
context.test.assertEqual(context.response.status_code, status_code)
@when("we add the entry {value:S}")
-def step_impl(context, value):
+def add_entry(context, value):
"""Block the provided route."""
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
@@ -86,7 +87,7 @@ def step_impl(context, value):
@when("we add the entry {value:S} with comment {comment}")
-def step_impl(context, value, comment):
+def add_entry_with_comment(context, value, comment):
"""Block the provided route and add a comment."""
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
@@ -102,7 +103,7 @@ def step_impl(context, value, comment):
@when("we add the entry {value:S} with expiration {exp:S}")
-def step_impl(context, value, exp):
+def add_entry_with_absolute_expiration(context, value, exp):
"""Block the provided route and add an absolute expiration datetime."""
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
@@ -119,10 +120,10 @@ def step_impl(context, value, exp):
@when("we add the entry {value:S} with expiration in {secs:d} seconds")
-def step_impl(context, value, secs):
+def add_entry_with_relative_expiration(context, value, secs):
"""Block the provided route and add a relative expiration."""
td = datetime.timedelta(seconds=secs)
- expiration = datetime.datetime.now() + td
+ expiration = datetime.datetime.now(tz=datetime.UTC) + td
context.response = context.test.client.post(
reverse("api:v1:entry-list"),
@@ -139,39 +140,40 @@ def step_impl(context, value, secs):
@step("we wait {secs:d} seconds")
-def step_impl(context, secs):
+def wait(context, secs):
"""Wait to allow messages to propagate."""
time.sleep(secs)
@then("we remove expired entries")
-def step_impl(context):
+def remove_expired(context):
"""Call the function that removes expired entries."""
context.response = context.test.client.get(reverse("route_manager:process-expired"))
@when("we add the ignore entry {value:S}")
-def step_impl(context, value):
+def add_ignore_entry(context, value):
"""Add an IgnoreEntry with the specified route."""
context.response = context.test.client.post(
- reverse("api:v1:ignoreentry-list"), {"route": value, "comment": "test api"}
+ reverse("api:v1:ignoreentry-list"),
+ {"route": value, "comment": "test api"},
)
@when("we remove the {model} {value}")
-def step_impl(context, model, value):
+def remove_an_object(context, model, value):
"""Remove any model object with the matching value."""
context.response = context.test.client.delete(reverse(f"api:v1:{model.lower()}-detail", args=[value]))
@when("we list the {model}s")
-def step_impl(context, model):
+def list_objects(context, model):
"""List all objects of an arbitrary model."""
context.response = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
@when("we update the {model} {value_from} to {value_to}")
-def step_impl(context, model, value_from, value_to):
+def update_object(context, model, value_from, value_to):
"""Modify any model object with the matching value to the new value instead."""
context.response = context.test.client.patch(
reverse(f"api:v1:{model.lower()}-detail", args=[value_from]),
@@ -180,7 +182,7 @@ def step_impl(context, model, value_from, value_to):
@then("the number of {model}s is {num:d}")
-def step_impl(context, model, num):
+def count_objects(context, model, num):
"""Count the number of objects of an arbitrary model."""
objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
context.test.assertEqual(len(objs.json()), num)
@@ -190,7 +192,7 @@ def step_impl(context, model, num):
@then("{value} is one of our list of {model}s")
-def step_impl(context, value, model):
+def check_object(context, value, model):
"""Ensure that the arbitrary model has an object with the specified value."""
objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
@@ -206,7 +208,7 @@ def step_impl(context, value, model):
@when("we register a client named {hostname} with the uuid of {uuid}")
-def step_impl(context, hostname, uuid):
+def add_client(context, hostname, uuid):
"""Create a client with a specific UUID."""
context.response = context.test.client.post(
reverse("api:v1:client-list"),
diff --git a/scram/route_manager/tests/acceptance/steps/ip.py b/scram/route_manager/tests/acceptance/steps/ip.py
index 8154f04b..34536a86 100644
--- a/scram/route_manager/tests/acceptance/steps/ip.py
+++ b/scram/route_manager/tests/acceptance/steps/ip.py
@@ -7,7 +7,7 @@
@then("{route} is contained in our list of {model}s")
-def step_impl(context, route, model):
+def check_route(context, route, model):
"""Perform a CIDR match on the matching object."""
objs = context.test.client.get(reverse(f"api:v1:{model.lower()}-list"))
ip_target = ipaddress.ip_address(route)
@@ -23,7 +23,7 @@ def step_impl(context, route, model):
@when("we query for {ip}")
-def step_impl(context, ip):
+def check_ip(context, ip):
"""Find an Entry for the specified IP."""
try:
context.response = context.test.client.get(reverse("api:v1:entry-detail", args=[ip]))
@@ -34,13 +34,13 @@ def step_impl(context, ip):
@then("we get a ValueError")
-def step_impl(context):
+def check_error(context):
"""Ensure we received a ValueError exception."""
assert isinstance(context.queryException, ValueError)
@then("the change entry for {value:S} is {comment}")
-def step_impl(context, value, comment):
+def check_comment(context, value, comment):
"""Verify the comment for the Entry."""
try:
objs = context.test.client.get(reverse("api:v1:entry-detail", args=[value]))
diff --git a/scram/route_manager/tests/acceptance/steps/translator.py b/scram/route_manager/tests/acceptance/steps/translator.py
index c6c2ff49..ff478a9a 100644
--- a/scram/route_manager/tests/acceptance/steps/translator.py
+++ b/scram/route_manager/tests/acceptance/steps/translator.py
@@ -7,10 +7,10 @@
from config.asgi import ws_application
-async def query_translator(route, actiontype, is_announced=True):
+async def query_translator(route, actiontype, is_announced):
"""Ensure the specified route is currently either blocked or unblocked."""
communicator = WebsocketCommunicator(ws_application, f"/ws/route_manager/webui_{actiontype}/")
- connected, subprotocol = await communicator.connect()
+ connected, _ = await communicator.connect()
assert connected
await communicator.send_json_to({"type": "wui_check_req", "message": {"route": route}})
@@ -23,13 +23,13 @@ async def query_translator(route, actiontype, is_announced=True):
@then("{route} is announced by {actiontype} translators")
@async_run_until_complete
-async def step_impl(context, route, actiontype):
+async def check_blocked(context, route, actiontype):
"""Ensure the specified route is currently blocked."""
- await query_translator(route, actiontype)
+ await query_translator(route, actiontype, is_announced=True)
@then("{route} is not announced by {actiontype} translators")
@async_run_until_complete
-async def step_impl(context, route, actiontype):
+async def check_unblocked(context, route, actiontype):
"""Ensure the specified route is currently unblocked."""
await query_translator(route, actiontype, is_announced=False)
diff --git a/scram/route_manager/tests/functional_tests.py b/scram/route_manager/tests/functional_tests.py
index c4515b13..a87ce21e 100644
--- a/scram/route_manager/tests/functional_tests.py
+++ b/scram/route_manager/tests/functional_tests.py
@@ -6,8 +6,6 @@
class HomePageTest(unittest.TestCase):
"""Ensure the home page works."""
- pass
-
if __name__ == "__main__":
unittest.main(warnings="ignore")
diff --git a/scram/route_manager/tests/test_authorization.py b/scram/route_manager/tests/test_authorization.py
index 17de2b6d..9a0d4bff 100644
--- a/scram/route_manager/tests/test_authorization.py
+++ b/scram/route_manager/tests/test_authorization.py
@@ -236,7 +236,7 @@ def test_authorized_removal(self):
def test_disabled(self):
"""Pass all the groups, user should be disabled as it takes precedence."""
claims = dict(self.claims)
- claims["groups"] = [settings.SCRAM_GROUPS]
+ claims["groups"] = settings.SCRAM_GROUPS
user = ESnetAuthBackend().create_user(claims)
self.assertFalse(user.is_staff)
diff --git a/scram/route_manager/tests/test_autocreate_admin.py b/scram/route_manager/tests/test_autocreate_admin.py
index b9d2e432..88270f4e 100644
--- a/scram/route_manager/tests/test_autocreate_admin.py
+++ b/scram/route_manager/tests/test_autocreate_admin.py
@@ -1,13 +1,15 @@
-"""This file contains tests for the auto-creation of an admin user."""
+"""Test the auto-creation of an admin user."""
import pytest
-from django.contrib.auth.models import User
from django.contrib.messages import get_messages
from django.test import Client
from django.urls import reverse
from scram.users.models import User
+LEVEL_SUCCESS = 25
+LEVEL_INFO = 20
+
@pytest.mark.django_db
def test_autocreate_admin(settings):
@@ -15,15 +17,15 @@ def test_autocreate_admin(settings):
settings.AUTOCREATE_ADMIN = True
client = Client()
response = client.get(reverse("route_manager:home"))
- assert response.status_code == 200
+ assert response.status_code == 200 # noqa: PLR2004
assert User.objects.count() == 1
user = User.objects.get(username="admin")
assert user.is_superuser
assert user.email == "admin@example.com"
messages = list(get_messages(response.wsgi_request))
- assert len(messages) == 2
- assert messages[0].level == 25 # SUCCESS
- assert messages[1].level == 20 # INFO
+ assert len(messages) == 2 # noqa: PLR2004
+ assert messages[0].level == LEVEL_SUCCESS
+ assert messages[1].level == LEVEL_INFO
@pytest.mark.django_db
@@ -32,7 +34,7 @@ def test_autocreate_admin_disabled(settings):
settings.AUTOCREATE_ADMIN = False
client = Client()
response = client.get(reverse("route_manager:home"))
- assert response.status_code == 200
+ assert response.status_code == 200 # noqa: PLR2004
assert User.objects.count() == 0
@@ -43,6 +45,6 @@ def test_autocreate_admin_existing_user(settings):
User.objects.create_user("testuser", "test@example.com", "password")
client = Client()
response = client.get(reverse("route_manager:home"))
- assert response.status_code == 200
+ assert response.status_code == 200 # noqa: PLR2004
assert User.objects.count() == 1
assert not User.objects.filter(username="admin").exists()
diff --git a/scram/route_manager/tests/test_history.py b/scram/route_manager/tests/test_history.py
index b234e621..d2848e83 100644
--- a/scram/route_manager/tests/test_history.py
+++ b/scram/route_manager/tests/test_history.py
@@ -16,7 +16,7 @@ def setUp(self):
def test_comments(self):
"""Ensure we can go back and set a reason."""
self.atype.name = "Nullroute"
- self.atype._change_reason = "Use more descriptive name"
+ self.atype._change_reason = "Use more descriptive name" # noqa SLF001
self.atype.save()
self.assertIsNotNone(get_change_reason_from_object(self.atype))
@@ -47,7 +47,7 @@ def test_comments(self):
e.route = Route.objects.create(route=route_new)
change_reason = "I meant 32, not 16."
- e._change_reason = change_reason
+ e._change_reason = change_reason # noqa SLF001
e.save()
self.assertEqual(len(e.history.all()), 2)
diff --git a/scram/route_manager/tests/test_swagger.py b/scram/route_manager/tests/test_swagger.py
new file mode 100644
index 00000000..ce955db9
--- /dev/null
+++ b/scram/route_manager/tests/test_swagger.py
@@ -0,0 +1,30 @@
+"""Test the swagger API endpoints."""
+
+import pytest
+from django.urls import reverse
+
+
+@pytest.mark.django_db
+def test_swagger_api(client):
+ """Test that the Swagger API endpoint returns a successful response."""
+ url = reverse("swagger-ui")
+ response = client.get(url)
+ assert response.status_code == 200 # noqa: PLR2004
+
+
+@pytest.mark.django_db
+def test_redoc_api(client):
+ """Test that the Redoc API endpoint returns a successful response."""
+ url = reverse("redoc")
+ response = client.get(url)
+ assert response.status_code == 200 # noqa: PLR2004
+
+
+@pytest.mark.django_db
+def test_schema_api(client):
+ """Test that the Schema API endpoint returns a successful response."""
+ url = reverse("schema")
+ response = client.get(url)
+ assert response.status_code == 200 # noqa: PLR2004
+ expected_strings = [b"/entries/", b"/actiontypes/", b"/ignore_entries/", b"/users/"]
+ assert all(string in response.content for string in expected_strings)
diff --git a/scram/route_manager/tests/test_views.py b/scram/route_manager/tests/test_views.py
index bcc3a4ad..1433135a 100644
--- a/scram/route_manager/tests/test_views.py
+++ b/scram/route_manager/tests/test_views.py
@@ -1,7 +1,8 @@
"""Define simple tests for the template-based Views."""
+from django.conf import settings
from django.test import TestCase
-from django.urls import resolve
+from django.urls import resolve, reverse
from scram.route_manager.views import home_page
@@ -13,3 +14,45 @@ def test_root_url_resolves_to_home_page_view(self):
"""Ensure we can find the home page."""
found = resolve("/")
self.assertEqual(found.func, home_page)
+
+
+class HomePageFirstVisitTest(TestCase):
+ """Test how the home page renders the first time we view it."""
+
+ def setUp(self):
+ """Get the home page."""
+ self.response = self.client.get(reverse("route_manager:home"))
+
+ def test_first_homepage_view_has_userinfo(self):
+ """The first time we view the home page, a user was created for us."""
+ self.assertContains(self.response, b"An admin user was created for you.")
+
+ def test_first_homepage_view_is_logged_in(self):
+ """The first time we view the home page, we're logged in."""
+ self.assertContains(self.response, b'type="submit">Logout')
+
+
+class HomePageLogoutTest(TestCase):
+ """Verify that once logged out, we can't view anything."""
+
+ def test_homepage_logout_links_missing(self):
+ """After logout, we can't see anything."""
+ response = self.client.get(reverse("route_manager:home"))
+ response = self.client.post(reverse(settings.LOGOUT_URL), follow=True)
+ self.assertEqual(response.status_code, 200)
+ response = self.client.get(reverse("route_manager:home"))
+
+ self.assertNotContains(response, b"An admin user was created for you.")
+ self.assertNotContains(response, b'type="submit">Logout')
+ self.assertNotContains(response, b">Admin")
+
+
+class NotFoundTest(TestCase):
+ """Verify that our custom 404 page is being served up."""
+
+ def test_404(self):
+ """Grab a bad URL."""
+ response = self.client.get("/foobarbaz")
+ self.assertContains(
+ response, b'The page you are looking for was not found.
', status_code=404
+ )
diff --git a/scram/route_manager/tests/test_websockets.py b/scram/route_manager/tests/test_websockets.py
index 306e772b..ca9cd528 100644
--- a/scram/route_manager/tests/test_websockets.py
+++ b/scram/route_manager/tests/test_websockets.py
@@ -26,15 +26,13 @@ async def get_communicators(actiontypes, should_match, *args, **kwds):
Returns a list of (communicator, should_match bool) pairs.
"""
- assert len(actiontypes) == len(should_match)
-
router = URLRouter(websocket_urlpatterns)
communicators = [
WebsocketCommunicator(router, f"/ws/route_manager/translator_{actiontype}/") for actiontype in actiontypes
]
- response = zip(communicators, should_match)
+ response = zip(communicators, should_match, strict=True)
- for communicator, should_match in response:
+ for communicator, _ in response:
connected, _ = await communicator.connect()
assert connected
@@ -42,7 +40,7 @@ async def get_communicators(actiontypes, should_match, *args, **kwds):
yield response
finally:
- for communicator, should_match in response:
+ for communicator, _ in response:
await communicator.disconnect()
@@ -71,7 +69,9 @@ def setUp(self):
wsm, _ = WebSocketMessage.objects.get_or_create(msg_type="translator_add", msg_data_route_field="route")
_, _ = WebSocketSequenceElement.objects.get_or_create(
- websocketmessage=wsm, verb="A", action_type=self.actiontype
+ websocketmessage=wsm,
+ verb="A",
+ action_type=self.actiontype,
)
# Set some defaults; some child classes override this
@@ -80,9 +80,9 @@ def setUp(self):
self.generate_add_msgs = [lambda ip, mask: {"type": "translator_add", "message": {"route": f"{ip}/{mask}"}}]
# Now we run any local setup actions by the child classes
- self.local_setUp()
+ self.local_setup()
- def local_setUp(self):
+ def local_setup(self):
"""Allow child classes to override this if desired."""
return
@@ -151,9 +151,9 @@ async def test_add_v6(self):
class TranslatorDontCrossTheStreamsTestCase(TestTranslatorBaseCase):
- """Two translators in the same group, two in another group, single IP, ensure we get only the messages we expect."""
+ """Two translators in one group, two in another group, single IP, ensure we get only the messages we expect."""
- def local_setUp(self):
+ def local_setup(self):
"""Define the actions and what we expect."""
self.actiontypes = ["block", "block", "noop", "noop"]
self.should_match = [True, True, False, False]
@@ -162,15 +162,21 @@ def local_setUp(self):
class TranslatorSequenceTestCase(TestTranslatorBaseCase):
"""Test a sequence of WebSocket messages."""
- def local_setUp(self):
+ def local_setup(self):
"""Define the messages we want to send."""
wsm2 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="foo")
_ = WebSocketSequenceElement.objects.create(
- websocketmessage=wsm2, verb="A", action_type=self.actiontype, order_num=20
+ websocketmessage=wsm2,
+ verb="A",
+ action_type=self.actiontype,
+ order_num=20,
)
wsm3 = WebSocketMessage.objects.create(msg_type="translator_add", msg_data_route_field="bar")
_ = WebSocketSequenceElement.objects.create(
- websocketmessage=wsm3, verb="A", action_type=self.actiontype, order_num=2
+ websocketmessage=wsm3,
+ verb="A",
+ action_type=self.actiontype,
+ order_num=2,
)
self.generate_add_msgs = [
@@ -183,7 +189,7 @@ def local_setUp(self):
class TranslatorParametersTestCase(TestTranslatorBaseCase):
"""Additional parameters in the JSONField."""
- def local_setUp(self):
+ def local_setup(self):
"""Define the message we want to send."""
wsm = WebSocketMessage.objects.get(msg_type="translator_add", msg_data_route_field="route")
wsm.msg_data = {"asn": 65550, "community": 100, "route": "Ensure this gets overwritten."}
diff --git a/scram/route_manager/views.py b/scram/route_manager/views.py
index 99cc2c52..fa8b7116 100644
--- a/scram/route_manager/views.py
+++ b/scram/route_manager/views.py
@@ -24,8 +24,10 @@
channel_layer = get_channel_layer()
-def home_page(request, prefilter=Entry.objects.all()):
+def home_page(request, prefilter=None):
"""Return the home page, autocreating a user if none exists."""
+ if not prefilter:
+ prefilter = Entry.objects.all()
num_entries = settings.RECENT_LIMIT
if request.user.has_perms(("route_manager.view_entry", "route_manager.add_entry")):
readwrite = True
@@ -103,29 +105,27 @@ def add_entry(request):
with transaction.atomic():
res = add_entry_api(request)
- if res.status_code == 201:
+ if res.status_code == 201: # noqa: PLR2004
messages.add_message(
request,
messages.SUCCESS,
"Entry successfully added",
)
- elif res.status_code == 400:
+ elif res.status_code == 400: # noqa: PLR2004
errors = []
if isinstance(res.data, rest_framework.utils.serializer_helpers.ReturnDict):
for k, v in res.data.items():
- for error in v:
- errors.append(f"'{k}' error: {str(error)}")
+ errors.extend(f"'{k}' error: {error!s}" for error in v)
else:
- for k, v in res.data.items():
- errors.append(f"error: {str(v)}")
+ errors.extend(f"error: {v!s}" for v in res.data.values())
messages.add_message(request, messages.ERROR, "
".join(errors))
- elif res.status_code == 403:
+ elif res.status_code == 403: # noqa: PLR2004
messages.add_message(request, messages.ERROR, "Permission Denied")
else:
messages.add_message(request, messages.WARNING, f"Something went wrong: {res.status_code}")
with transaction.atomic():
home = home_page(request)
- return home
+ return home # noqa RET504
def process_expired(request):
@@ -142,7 +142,7 @@ def process_expired(request):
{
"entries_deleted": entries_start - entries_end,
"active_entries": entries_end,
- }
+ },
),
content_type="application/json",
)
@@ -154,7 +154,8 @@ class EntryListView(ListView):
model = Entry
template_name = "route_manager/entry_list.html"
- def get_context_data(self, **kwargs):
+ @staticmethod
+ def get_context_data(**kwargs):
"""Group entries by action type."""
context = {"entries": {}}
for at in ActionType.objects.all():
diff --git a/scram/users/api/serializers.py b/scram/users/api/serializers.py
index a49520d0..f8d22c88 100644
--- a/scram/users/api/serializers.py
+++ b/scram/users/api/serializers.py
@@ -7,10 +7,10 @@
class UserSerializer(serializers.ModelSerializer):
- """This serializer defines no new fields."""
+ """Map to the User model."""
class Meta:
- """Maps to the User model, and specifies the fields exposed by the API."""
+ """Specify the fields exposed by the API."""
model = User
fields = ["username", "name", "url"]
diff --git a/scram/users/api/views.py b/scram/users/api/views.py
index ce70a713..0e17cb54 100644
--- a/scram/users/api/views.py
+++ b/scram/users/api/views.py
@@ -23,8 +23,9 @@ def get_queryset(self, *args, **kwargs):
"""Query on User ID."""
return self.queryset.filter(id=self.request.user.id)
+ @staticmethod
@action(detail=False, methods=["GET"])
- def me(self, request):
+ def me(request):
"""Return the current user."""
serializer = UserSerializer(request.user, context={"request": request})
return Response(status=status.HTTP_200_OK, data=serializer.data)
diff --git a/scram/users/apps.py b/scram/users/apps.py
index 449ebe39..e6cc9c6b 100644
--- a/scram/users/apps.py
+++ b/scram/users/apps.py
@@ -14,7 +14,8 @@ class UsersConfig(AppConfig):
name = "scram.users"
verbose_name = _("Users")
- def ready(self):
+ @staticmethod
+ def ready():
"""Check if signals are registered for User events."""
try:
import scram.users.signals # noqa F401
diff --git a/scram/users/tests/factories.py b/scram/users/tests/factories.py
index 1eca95c0..10331878 100644
--- a/scram/users/tests/factories.py
+++ b/scram/users/tests/factories.py
@@ -1,6 +1,7 @@
"""Define Factory tests for the Users application."""
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
from django.contrib.auth import get_user_model
from factory import Faker, post_generation
diff --git a/scram/users/tests/test_forms.py b/scram/users/tests/test_forms.py
index 4ff89f37..3fe65da4 100644
--- a/scram/users/tests/test_forms.py
+++ b/scram/users/tests/test_forms.py
@@ -13,12 +13,11 @@ class TestUserCreationForm:
"""Test class for all tests related to the UserCreationForm."""
def test_username_validation_error_msg(self, user: User):
- """
- Tests UserCreation Form's unique validator functions correctly by testing 3 things.
+ """Tests UserCreation Form's unique validator functions correctly by testing 3 things.
- 1) A new user with an existing username cannot be added.
- 2) Only 1 error is raised by the UserCreation Form
- 3) The desired error message is raised
+ 1) A new user with an existing username cannot be added.
+ 2) Only 1 error is raised by the UserCreation Form
+ 3) The desired error message is raised
"""
# The user already exists,
# hence cannot be created.
@@ -27,7 +26,7 @@ def test_username_validation_error_msg(self, user: User):
"username": user.username,
"password1": user.password,
"password2": user.password,
- }
+ },
)
assert not form.is_valid()
diff --git a/scram/users/tests/test_views.py b/scram/users/tests/test_views.py
index 63ad97bc..2df1929a 100644
--- a/scram/users/tests/test_views.py
+++ b/scram/users/tests/test_views.py
@@ -18,14 +18,14 @@
class TestUserUpdateView:
- """
- Define tests related to the Update View.
+ """Define tests related to the Update View.
- TODO:
+ Todo:
extracting view initialization code as class-scoped fixture
would be great if only pytest-django supported non-function-scoped
fixture db access -- this is a work-in-progress for now:
https://github.com/pytest-dev/pytest-django/pull/258
+
"""
def test_get_success_url(self, user: User, rf: RequestFactory):
diff --git a/scram/utils/context_processors.py b/scram/utils/context_processors.py
index 2a86dfa4..0bbf8786 100644
--- a/scram/utils/context_processors.py
+++ b/scram/utils/context_processors.py
@@ -4,7 +4,11 @@
def settings_context(_request):
- """Define settings available by default to the templates context."""
+ """Define settings available by default to the templates context.
+
+ Returns:
+ dict: Whether or not we have DEBUG on
+ """
# Note: we intentionally do NOT expose the entire settings
# to prevent accidental leaking of sensitive information
return {"DEBUG": settings.DEBUG}
diff --git a/translator/exceptions.py b/translator/exceptions.py
index a999ec8c..b6d2cfd6 100644
--- a/translator/exceptions.py
+++ b/translator/exceptions.py
@@ -1,4 +1,4 @@
-"""This module holds all of the exceptions we want to raise in our translators."""
+"""Define all of the exceptions we want to raise in our translators."""
class ASNError(TypeError):
diff --git a/translator/gobgp.py b/translator/gobgp.py
index 6c10e159..f2b8dee9 100644
--- a/translator/gobgp.py
+++ b/translator/gobgp.py
@@ -15,11 +15,15 @@
DEFAULT_COMMUNITY = 666
DEFAULT_V4_NEXTHOP = "192.0.2.199"
DEFAULT_V6_NEXTHOP = "100::1"
+MAX_SMALL_ASN = 2**16
+MAX_SMALL_COMM = 2**16
+IPV6 = 6
logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
-class GoBGP(object):
+class GoBGP:
"""Represents a GoBGP instance."""
def __init__(self, url):
@@ -27,14 +31,16 @@ def __init__(self, url):
channel = grpc.insecure_channel(url)
self.stub = gobgp_pb2_grpc.GobgpApiStub(channel)
- def _get_family_AFI(self, ip_version):
- if ip_version == 6:
+ @staticmethod
+ def _get_family_afi(ip_version):
+ if ip_version == IPV6:
return gobgp_pb2.Family.AFI_IP6
- else:
- return gobgp_pb2.Family.AFI_IP
+ return gobgp_pb2.Family.AFI_IP
- def _build_path(self, ip, event_data={}):
+ def _build_path(self, ip, event_data=None): # noqa: PLR0914
# Grab ASN and Community from our event_data, or use the defaults
+ if not event_data:
+ event_data = {}
asn = event_data.get("asn", DEFAULT_ASN)
community = event_data.get("community", DEFAULT_COMMUNITY)
ip_version = ip.ip.version
@@ -47,7 +53,7 @@ def _build_path(self, ip, event_data={}):
origin.Pack(
attribute_pb2.OriginAttribute(
origin=2,
- )
+ ),
)
# IP prefix and its associated length
@@ -56,27 +62,27 @@ def _build_path(self, ip, event_data={}):
attribute_pb2.IPAddressPrefix(
prefix_len=ip.network.prefixlen,
prefix=str(ip.ip),
- )
+ ),
)
# Set the next hop to the correct value depending on IP family
next_hop = Any()
- family_afi = self._get_family_AFI(ip_version)
- if ip_version == 6:
+ family_afi = self._get_family_afi(ip_version)
+ if ip_version == IPV6:
next_hops = event_data.get("next_hop", DEFAULT_V6_NEXTHOP)
next_hop.Pack(
attribute_pb2.MpReachNLRIAttribute(
family=gobgp_pb2.Family(afi=family_afi, safi=gobgp_pb2.Family.SAFI_UNICAST),
next_hops=[next_hops],
nlris=[nlri],
- )
+ ),
)
else:
next_hops = event_data.get("next_hop", DEFAULT_V4_NEXTHOP)
next_hop.Pack(
attribute_pb2.NextHopAttribute(
next_hop=next_hops,
- )
+ ),
)
# Set our AS Path
@@ -94,13 +100,13 @@ def _build_path(self, ip, event_data={}):
communities = Any()
# Standard community
# Since we pack both into the community string we need to make sure they will both fit
- if asn < 65536 and community < 65536:
+ if asn < MAX_SMALL_ASN and community < MAX_SMALL_COMM:
# We bitshift ASN left by 16 so that there is room to add the community on the end of it. This is because
# GoBGP wants the community sent as a single integer.
comm_id = (asn << 16) + community
communities.Pack(attribute_pb2.CommunitiesAttribute(communities=[comm_id]))
else:
- logging.info(f"LargeCommunity Used - ASN:{asn} Community: {community}")
+ logger.info("LargeCommunity Used - ASN: %s. Community: %s", asn, community)
global_admin = asn
local_data1 = community
# set to 0 because there's no use case for it, but we need a local_data2 for gobgp to read any of it
@@ -122,7 +128,7 @@ def _build_path(self, ip, event_data={}):
def add_path(self, ip, event_data):
"""Announce a single route."""
- logging.info(f"Blocking {ip}")
+ logger.info("Blocking %s", ip)
try:
path = self._build_path(ip, event_data)
@@ -131,17 +137,17 @@ def add_path(self, ip, event_data):
_TIMEOUT_SECONDS,
)
except ASNError as e:
- logging.warning(f"ASN assertion failed with error: {e}")
+ logger.warning("ASN assertion failed with error: %s", e)
def del_all_paths(self):
"""Remove all routes from being announced."""
- logging.warning("Withdrawing ALL routes")
+ logger.warning("Withdrawing ALL routes")
self.stub.DeletePath(gobgp_pb2.DeletePathRequest(table_type=gobgp_pb2.GLOBAL), _TIMEOUT_SECONDS)
def del_path(self, ip, event_data):
"""Remove a single route from being announced."""
- logging.info(f"Unblocking {ip}")
+ logger.info("Unblocking %s", ip)
try:
path = self._build_path(ip, event_data)
self.stub.DeletePath(
@@ -149,12 +155,16 @@ def del_path(self, ip, event_data):
_TIMEOUT_SECONDS,
)
except ASNError as e:
- logging.warning(f"ASN assertion failed with error: {e}")
+ logger.warning("ASN assertion failed with error: %s", e)
def get_prefixes(self, ip):
- """Retrieve the routes that match a prefix and are announced."""
+ """Retrieve the routes that match a prefix and are announced.
+
+ Returns:
+ list: The routes that overlap with the prefix and are currently announced.
+ """
prefixes = [gobgp_pb2.TableLookupPrefix(prefix=str(ip.ip))]
- family_afi = self._get_family_AFI(ip.ip.version)
+ family_afi = self._get_family_afi(ip.ip.version)
result = self.stub.ListPath(
gobgp_pb2.ListPathRequest(
table_type=gobgp_pb2.GLOBAL,
diff --git a/translator/shared.py b/translator/shared.py
index 27e89f97..78555b2d 100644
--- a/translator/shared.py
+++ b/translator/shared.py
@@ -1,11 +1,12 @@
-"""This module provides a location for code that we want to share between all translators."""
+"""Provide a location for code that we want to share between all translators."""
from exceptions import ASNError
+MAX_ASN_VAL = 2**32 - 1
+
def asn_is_valid(asn: int) -> bool:
- """
- asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN.
+ """asn_is_valid makes sure that an ASN passed in is a valid 2 or 4 Byte ASN.
Args:
asn (int): The Autonomous System Number that we want to validate
@@ -15,11 +16,14 @@ def asn_is_valid(asn: int) -> bool:
Returns:
bool: _description_
+
"""
if not isinstance(asn, int):
- raise ASNError(f"ASN {asn} is not an Integer, has type {type(asn)}")
- if not 0 < asn < 4294967295:
+ msg = f"ASN {asn} is not an Integer, has type {type(asn)}"
+ raise ASNError(msg)
+ if not 0 < asn < MAX_ASN_VAL:
# This is the max as stated in rfc6996
- raise ASNError(f"ASN {asn} is out of range. Must be between 0 and 4294967295")
+ msg = f"ASN {asn} is out of range. Must be between 0 and 4294967295"
+ raise ASNError(msg)
return True
diff --git a/translator/tests/acceptance/steps/actions.py b/translator/tests/acceptance/steps/actions.py
index 1e03e8e1..fdb2a95f 100644
--- a/translator/tests/acceptance/steps/actions.py
+++ b/translator/tests/acceptance/steps/actions.py
@@ -27,17 +27,17 @@ def del_block(context, route, asn, community):
def get_block_status(context, ip):
- """Check if the IP is currently blocked."""
+ """Check if the IP is currently blocked.
+
+ Returns:
+ bool: The return value. True if the IP is currently blocked, False otherwise.
+ """
# Allow our add/delete requests to settle
time.sleep(1)
ip_obj = ipaddress.ip_interface(ip)
- for path in context.gobgp.get_prefixes(ip_obj):
- if ip_obj in ipaddress.ip_network(path.destination.prefix):
- return True
-
- return False
+ return any(ip_obj in ipaddress.ip_network(path.destination.prefix) for path in context.gobgp.get_prefixes(ip_obj))
@capture
diff --git a/translator/translator.py b/translator/translator.py
index b5a7dd02..d2c34018 100644
--- a/translator/translator.py
+++ b/translator/translator.py
@@ -11,33 +11,41 @@
import websockets
from gobgp import GoBGP
+logger = logging.getLogger(__name__)
+
+KNOWN_MESSAGES = {
+ "translator_add",
+ "translator_remove",
+ "translator_remove_all",
+ "translator_check",
+}
+
# Here we setup a debugger if this is desired. This obviously should not be run in production.
debug_mode = os.environ.get("DEBUG")
if debug_mode:
def install_deps():
- """
- Install necessary dependencies for debuggers.
+ """Install necessary dependencies for debuggers.
Because of how we build translator currently, we don't have a great way to selectively
install things at build, so we just do it here! Right now this also includes base.txt,
which is unecessary, but in the future when we build a little better, it'll already be
setup.
"""
- logging.info("Installing dependencies for debuggers")
+ logger.info("Installing dependencies for debuggers")
- import subprocess
- import sys
+ import subprocess # noqa: S404, PLC0415
+ import sys # noqa: PLC0415
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"])
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "/requirements/local.txt"]) # noqa: S603 TODO: add this to the container build
- logging.info("Done installing dependencies for debuggers")
+ logger.info("Done installing dependencies for debuggers")
- logging.info(f"Translator is set to use a debugger. Provided debug mode: {debug_mode}")
+ logger.info("Translator is set to use a debugger. Provided debug mode: %s", debug_mode)
# We have to setup the debugger appropriately for various IDEs. It'd be nice if they all used the same thing but
# sadly, we live in a fallen world.
if debug_mode == "pycharm-pydevd":
- logging.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!")
+ logger.info("Entering debug mode for pycharm, make sure the debug server is running in PyCharm!")
install_deps()
@@ -45,67 +53,67 @@ def install_deps():
pydevd_pycharm.settrace("host.docker.internal", port=56782, stdoutToServer=True, stderrToServer=True)
- logging.info("Debugger started.")
+ logger.info("Debugger started.")
elif debug_mode == "debugpy":
- logging.info("Entering debug mode for debugpy (VSCode)")
+ logger.info("Entering debug mode for debugpy (VSCode)")
install_deps()
import debugpy
- debugpy.listen(("0.0.0.0", 56781))
+ debugpy.listen(("0.0.0.0", 56781)) # noqa S104 (doesn't like binding to all interfaces)
- logging.info("Debugger listening on port 56781.")
+ logger.info("Debugger listening on port 56781.")
else:
- logging.warning(f"Invalid debug mode given: {debug_mode}. Debugger not started")
+ logger.warning("Invalid debug mode given: %s. Debugger not started", debug_mode)
# Must match the URL in asgi.py, and needs a trailing slash
hostname = os.environ.get("SCRAM_HOSTNAME", "scram_hostname_not_set")
url = os.environ.get("SCRAM_EVENTS_URL", "ws://django:8000/ws/route_manager/translator_block/")
+async def process(message, websocket, g):
+ """Take a single message form the websocket and hand it off to the appropriate function."""
+ json_message = json.loads(message)
+ event_type = json_message.get("type")
+ event_data = json_message.get("message")
+ if event_type not in KNOWN_MESSAGES:
+ logger.error("Unknown event type received: %s", event_type)
+ # TODO: Maybe only allow this in testing?
+ elif event_type == "translator_remove_all":
+ g.del_all_paths()
+ else:
+ try:
+ ip = ipaddress.ip_interface(event_data["route"])
+ except: # noqa E722
+ logger.exception("Error parsing message: %s", message)
+ return
+
+ if event_type == "translator_add":
+ g.add_path(ip, event_data)
+ elif event_type == "translator_remove":
+ g.del_path(ip, event_data)
+ elif event_type == "translator_check":
+ json_message["type"] = "translator_check_resp"
+ json_message["message"]["is_blocked"] = g.is_blocked(ip)
+ json_message["message"]["translator_name"] = hostname
+ await websocket.send(json.dumps(json_message))
+
+
async def main():
"""Connect to the websocket and start listening for messages."""
g = GoBGP("gobgp:50051")
async for websocket in websockets.connect(url):
try:
async for message in websocket:
- json_message = json.loads(message)
- event_type = json_message.get("type")
- event_data = json_message.get("message")
- if event_type not in [
- "translator_add",
- "translator_remove",
- "translator_remove_all",
- "translator_check",
- ]:
- logging.error(f"Unknown event type received: {event_type!r}")
- # TODO: Maybe only allow this in testing?
- elif event_type == "translator_remove_all":
- g.del_all_paths()
- else:
- try:
- ip = ipaddress.ip_interface(event_data["route"])
- except: # noqa E722
- logging.error(f"Error parsing message: {message!r}")
- continue
-
- if event_type == "translator_add":
- g.add_path(ip, event_data)
- elif event_type == "translator_remove":
- g.del_path(ip, event_data)
- elif event_type == "translator_check":
- json_message["type"] = "translator_check_resp"
- json_message["message"]["is_blocked"] = g.is_blocked(ip)
- json_message["message"]["translator_name"] = hostname
- await websocket.send(json.dumps(json_message))
+ await process(message, websocket, g)
except websockets.ConnectionClosed:
continue
if __name__ == "__main__":
- logging.info("translator started")
+ logger.info("translator started")
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
loop.close()