diff --git a/src/gallia/command/base.py b/src/gallia/command/base.py index 12ddef776..641d3a04a 100644 --- a/src/gallia/command/base.py +++ b/src/gallia/command/base.py @@ -15,13 +15,14 @@ from logging import Handler from pathlib import Path from subprocess import CalledProcessError, run +from tempfile import gettempdir from typing import cast import exitcode import msgspec -from pydantic import ConfigDict +from pydantic import ConfigDict, field_serializer -from gallia.command.config import Field, GalliaBaseModel +from gallia.command.config import Field, GalliaBaseModel, idempotent from gallia.db.handler import DBHandler from gallia.dumpcap import Dumpcap from gallia.log import add_zst_log_handler, get_logger, tz @@ -79,6 +80,15 @@ class BaseCommandConfig(GalliaBaseModel, argument_group="generic", config_sectio None, description="path to file used for a posix lock", metavar="PATH" ) db: Path | None = Field(None, description="Path to sqlite3 database") + artifacts_dir: Path | None = Field( + None, description="Folder for artifacts", metavar="DIR", config_section="gallia.scanner" + ) + artifacts_base: Path = Field( + Path(gettempdir()).joinpath("gallia"), + description="Base directory for artifacts", + metavar="DIR", + config_section="gallia.scanner", + ) class BaseCommand(ABC): @@ -388,8 +398,10 @@ def run(self) -> int: class ScannerConfig(AsyncScriptConfig, argument_group="scanner", config_section="gallia.scanner"): dumpcap: bool = Field(True, description="Enable/Disable creating a pcap file") - target: TargetURI | None = Field(description="URI that describes the target", metavar="TARGET") - power_supply: PowerSupplyURI | None = Field( + target: idempotent(TargetURI) = Field( + description="URI that describes the target", metavar="TARGET" + ) + power_supply: idempotent(PowerSupplyURI) | None = Field( None, description="URI specifying the location of the relevant opennetzteil server", metavar="URI", @@ -402,6 +414,13 @@ class ScannerConfig(AsyncScriptConfig, argument_group="scanner", config_section= 5.0, description="time to sleep after the power-cycle", metavar="SECs" ) + @field_serializer("target", "power_supply") + def serialize_target_uri(self, target_uri: TargetURI | None, _info): + if target_uri is None: + return None + + return target_uri.raw + class Scanner(AsyncScript, ABC): """Scanner is a base class for all scanning related commands. diff --git a/src/gallia/db/handler.py b/src/gallia/db/handler.py index 39ab1b365..dbfc15689 100644 --- a/src/gallia/db/handler.py +++ b/src/gallia/db/handler.py @@ -50,7 +50,7 @@ def bytes_repr(data: bytes) -> str: CREATE TABLE IF NOT EXISTS run_meta ( id integer primary key, script text not null, - config json not null check(json_valid(arguments)), + config json not null check(json_valid(config)), start_time real not null, start_timezone text not null, end_time real, @@ -205,7 +205,7 @@ async def insert_run_meta( # noqa: PLR0913 query = ( "INSERT INTO " - "run_meta(script, config, settings, start_time, start_timezone, path, exclude) " + "run_meta(script, config, start_time, start_timezone, path, exclude) " "VALUES (?, ?, ?, ?, ?, FALSE)" ) cursor = await self.connection.execute( diff --git a/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py b/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py index c8b9232f1..ee96cb3ab 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py +++ b/vendor/pydantic-argparse/pydantic_argparse/argparse/parser.py @@ -159,6 +159,8 @@ def parse_typed_args( # Call Super Class Method namespace = self.parse_args(args) + print(namespace) + try: nested_parser = _NestedArgumentParser(model=self.model, namespace=namespace) return nested_parser.validate() diff --git a/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py b/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py index 10adfa03f..c9e06d0bb 100644 --- a/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py +++ b/vendor/pydantic-argparse/pydantic_argparse/utils/pydantic.py @@ -229,7 +229,7 @@ def arg_names(self, invert: bool = False) -> tuple[str, str] | tuple[str]: name = self.info.title or self.name if isinstance(self.info, ArgFieldInfo) and self.info.positional: - return name.upper(), + return name, prefix = "--no-" if invert else "--" long_name = f"{prefix}{name.replace('_', '-')}" @@ -279,7 +279,7 @@ def metavar(self) -> Optional[str]: return self.info.metavar if self.info.positional: - return self.arg_names()[0] + return self.arg_names()[0].upper() # otherwise default to the type field_type = self.get_type()