diff --git a/netbox_config_diff/compliance/base.py b/netbox_config_diff/compliance/base.py index 9478764..0d60752 100644 --- a/netbox_config_diff/compliance/base.py +++ b/netbox_config_diff/compliance/base.py @@ -8,15 +8,12 @@ from dcim.choices import DeviceStatusChoices from dcim.models import Device, DeviceRole, Site from django.conf import settings -from django.core.exceptions import ObjectDoesNotExist from django.db.models import Q from extras.scripts import MultiObjectVar, ObjectVar from jinja2.exceptions import TemplateError from netutils.config.compliance import diff_network_config from utilities.exceptions import AbortScript -from netbox_config_diff.models import ConfigCompliance - from .models import DeviceDataClass from .secrets import SecretsMixin from .utils import PLATFORM_MAPPING, CustomChoiceVar, exclude_lines, get_unified_diff @@ -112,13 +109,7 @@ def validate_data(self, data: dict) -> Iterable[Device]: def update_in_db(self, devices: list[DeviceDataClass]) -> None: for device in devices: self.log_results(device) - try: - obj = ConfigCompliance.objects.get(device_id=device.pk) - obj.snapshot() - obj.update(**device.to_db()) - obj.save() - except ObjectDoesNotExist: - ConfigCompliance.objects.create(**device.to_db()) + device.send_to_db() def log_results(self, device: DeviceDataClass) -> None: if device.error: diff --git a/netbox_config_diff/compliance/models.py b/netbox_config_diff/compliance/models.py index f16077e..1ddaa26 100644 --- a/netbox_config_diff/compliance/models.py +++ b/netbox_config_diff/compliance/models.py @@ -4,6 +4,7 @@ from scrapli import AsyncScrapli from netbox_config_diff.choices import ConfigComplianceStatusChoices +from netbox_config_diff.models import ConfigCompliance @dataclass @@ -18,21 +19,21 @@ class DeviceDataClass: exclude_regex: str | None = None rendered_config: str | None = None actual_config: str | None = None - diff: str | None = None + diff: str = "" missing: str | None = None extra: str | None = None - error: str | None = None + error: str = "" config_error: str | None = None auth_strict_key: bool = False transport: str = "asyncssh" - def __str__(self): + def __str__(self) -> str: return self.name - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def to_scrapli(self): + def to_scrapli(self) -> dict: return { "host": self.mgmt_ip, "auth_username": self.username, @@ -78,13 +79,16 @@ def to_scrapli(self): }, } - def to_db(self): + def get_status(self) -> str: if self.error: - status = ConfigComplianceStatusChoices.ERRORED + return ConfigComplianceStatusChoices.ERRORED elif self.diff: - status = ConfigComplianceStatusChoices.DIFF + return ConfigComplianceStatusChoices.DIFF else: - status = ConfigComplianceStatusChoices.COMPLIANT + return ConfigComplianceStatusChoices.COMPLIANT + + def to_db(self) -> dict: + status = self.get_status() return { "device_id": self.pk, @@ -97,7 +101,17 @@ def to_db(self): "extra": self.extra or "", } - async def get_actual_config(self): + def send_to_db(self) -> None: + try: + obj = ConfigCompliance.objects.get(device_id=self.pk) + if obj.status != self.get_status(): + obj.update(commit=True, **self.to_db()) + elif obj.diff != self.diff or obj.error != self.error: + obj.update(commit=True, **self.to_db()) + except ConfigCompliance.DoesNotExist: + ConfigCompliance.objects.create(**self.to_db()) + + async def get_actual_config(self) -> None: if self.error is not None: return try: diff --git a/netbox_config_diff/configurator/base.py b/netbox_config_diff/configurator/base.py index 1e546a7..3f3913e 100644 --- a/netbox_config_diff/configurator/base.py +++ b/netbox_config_diff/configurator/base.py @@ -19,7 +19,6 @@ from netbox_config_diff.configurator.exceptions import DeviceConfigurationError, DeviceValidationError from netbox_config_diff.configurator.utils import CustomLogger from netbox_config_diff.constants import ACCEPTABLE_DRIVERS -from netbox_config_diff.models import ConfigCompliance from .factory import AsyncScrapliCfg @@ -28,9 +27,9 @@ class Configurator(SecretsMixin): def __init__(self, devices: Iterable[Device], request: NetBoxFakeRequest) -> None: self.devices = devices self.request = request - self.unprocessed_devices = set() - self.processed_devices = set() - self.failed_devices = set() + self.unprocessed_devices: set[DeviceDataClass] = set() + self.processed_devices: set[DeviceDataClass] = set() + self.failed_devices: set[DeviceDataClass] = set() self.substitutes: dict[str, list] = {} self.logger = CustomLogger() self.connections: dict[str, AsyncScrapliCfgPlatform] = {} @@ -109,13 +108,7 @@ def collect_diffs(self) -> None: @sync_to_async def update_diffs(self) -> None: for device in self.unprocessed_devices: - try: - obj = ConfigCompliance.objects.get(device_id=device.pk) - obj.snapshot() - obj.update(**device.to_db()) - obj.save() - except ConfigCompliance.DoesNotExist: - ConfigCompliance.objects.create(**device.to_db()) + device.send_to_db() async def _collect_diffs(self) -> None: async with self.connection(): diff --git a/netbox_config_diff/models.py b/netbox_config_diff/models.py index d65dff1..0e4c4e9 100644 --- a/netbox_config_diff/models.py +++ b/netbox_config_diff/models.py @@ -54,16 +54,18 @@ class ConfigCompliance(ChangeLoggingMixin, models.Model): class Meta: ordering = ("device",) - def __str__(self): + def __str__(self) -> str: return self.device.name def get_absolute_url(self): return reverse("plugins:netbox_config_diff:configcompliance", args=[self.pk]) - def get_status_color(self): + def get_status_color(self) -> str: return ConfigComplianceStatusChoices.colors.get(self.status) - def update(self, commit=False, **kwargs): + def update(self, commit: bool = False, **kwargs) -> None: + if commit: + self.snapshot() for key, value in kwargs.items(): setattr(self, key, value) if commit: @@ -102,7 +104,7 @@ class PlatformSetting(NetBoxModel): class Meta: ordering = ("driver",) - def __str__(self): + def __str__(self) -> str: return f"{self.platform} {self.driver}" def get_absolute_url(self): @@ -156,20 +158,20 @@ class ConfigurationRequest(JobsMixin, PrimaryModel): class Meta: ordering = ("-created",) - def __str__(self): + def __str__(self) -> str: return f"CR #{self.pk}" def get_absolute_url(self): return reverse("plugins:netbox_config_diff:configurationrequest", args=[self.pk]) - def get_status_color(self): + def get_status_color(self) -> str: return ConfigurationRequestStatusChoices.colors.get(self.status) @property - def finished(self): + def finished(self) -> bool: return self.status in ConfigurationRequestStatusChoices.FINISHED_STATE_CHOICES - def delete(self, *args, **kwargs): + def delete(self, *args, **kwargs) -> None: super().delete(*args, **kwargs) queue = django_rq.get_queue(RQ_QUEUE_DEFAULT) @@ -180,7 +182,7 @@ def delete(self, *args, **kwargs): except InvalidJobOperation: pass - def enqueue_job(self, request, job_name, schedule_at=None): + def enqueue_job(self, request, job_name, schedule_at=None) -> Job: return Job.enqueue( import_string(f"netbox_config_diff.jobs.{job_name}"), name=f"{self} {job_name}", @@ -190,7 +192,7 @@ def enqueue_job(self, request, job_name, schedule_at=None): schedule_at=schedule_at, ) - def start(self, job: Job): + def start(self, job: Job) -> None: """ Record the job's start time and update its status to "running." """ @@ -201,7 +203,7 @@ def start(self, job: Job): self.status = ConfigurationRequestStatusChoices.RUNNING self.save() - def terminate(self, job: Job, status: str = ConfigurationRequestStatusChoices.COMPLETED): + def terminate(self, job: Job, status: str = ConfigurationRequestStatusChoices.COMPLETED) -> None: job.terminate(status=status) self.status = status self.completed = timezone.now() @@ -243,7 +245,7 @@ class Substitute(NetBoxModel): class Meta: ordering = ("name",) - def __str__(self): + def __str__(self) -> str: return self.name def get_absolute_url(self):