From deed14e6c5ab357f49c4c22e2359f8fee5c500bf Mon Sep 17 00:00:00 2001 From: Jakub Both Date: Thu, 11 Apr 2024 23:13:56 +0200 Subject: [PATCH] ENH: Enable loading and saving of central correction routines. --- .../corrections/color/colorcorrection.py | 49 +++++++++++++++++-- .../color/illuminationcorrection.py | 41 +++++++++++++++- src/darsia/corrections/shape/curvature.py | 40 +++++++++++++-- src/darsia/corrections/shape/drift.py | 48 ++++++++++++++---- 4 files changed, 160 insertions(+), 18 deletions(-) diff --git a/src/darsia/corrections/color/colorcorrection.py b/src/darsia/corrections/color/colorcorrection.py index 6da2bbd7..9671002b 100644 --- a/src/darsia/corrections/color/colorcorrection.py +++ b/src/darsia/corrections/color/colorcorrection.py @@ -234,10 +234,23 @@ def __init__( """ # Define config - assert config is not None, "provide config at least with 'roi' key" - self.config: dict = copy.deepcopy(config) - """Config dictionary for initialization of color correction.""" + if config is not None: + self.config: dict = copy.deepcopy(config) + """Config dictionary for initialization of color correction.""" + self._init_from_config(base) + else: + self.config = {} + + def _init_from_config( + self, base: Optional[Union[darsia.Image, ColorChecker]] + ) -> None: + """Auxiliary function for initialization from config. + + Args: + base (Image or ColorChecker, optional): reference defining a color checker; if + None provided, use CustomColorChecker. + """ self.active: bool = self.config.get("active", True) """Flag controlling whether correction is active""" @@ -265,6 +278,8 @@ def __init__( """Flag controlling whether values outside the feasible range [0., 1.] are clipped""" # Construct color checker + if base is None: + base = self.config.get("colorchecker", None) self._setup_colorchecker(base) def correct_array( @@ -386,6 +401,34 @@ def write_config_to_file(self, path: Union[Path, str]) -> None: with open(Path(path), "w") as outfile: json.dump(self.config, outfile, indent=4) + def save(self, path: Path) -> None: + """Save the color correction to a file. + + Args: + path (Path): path to the file + + """ + # Make sure that the path exists + path.parents[0].mkdir(parents=True, exist_ok=True) + + # Save the color correction + np.savez(path, config=self.config) + print(f"Color correction saved to {path}.") + + def load(self, path: Path) -> None: + """Load the color correction from a file. + + Args: + path (Path): path to the file + + """ + # Make sure the file exists + assert path.exists(), f"File {path} does not exist." + + # Load the color correction + self.config = np.load(path, allow_pickle=True)["config"].item() + self._init_from_config(base=None) + # ! ---- Auxiliary files def _setup_colorchecker( diff --git a/src/darsia/corrections/color/illuminationcorrection.py b/src/darsia/corrections/color/illuminationcorrection.py index 520909af..ed8f1fcc 100644 --- a/src/darsia/corrections/color/illuminationcorrection.py +++ b/src/darsia/corrections/color/illuminationcorrection.py @@ -1,6 +1,7 @@ """Module containing illumination correction functionality.""" from typing import Literal, Union +from pathlib import Path import matplotlib.pyplot as plt import numpy as np @@ -14,7 +15,7 @@ class IlluminationCorrection(darsia.BaseCorrection): """Class for illumination correction.""" - def __init__( + def setup( self, base: Union[darsia.Image, list[darsia.Image]], samples: list[tuple[slice, ...]], @@ -226,3 +227,41 @@ def correct_array(self, img: np.ndarray) -> np.ndarray: self.local_scaling[i if self.colorspace == "rgb" else 0].img, ) return img_wb + + def save(self, path: Path) -> None: + """Save the illumination correction to a file. + + Args: + path (Path): path to the file + + """ + # Make sure the parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + # Store color space and local scaling images as npz files + np.savez( + path, + config={ + "colorspace": self.colorspace, + "local_scaling": self.local_scaling, + }, + ) + print(f"Illumination correction saved to {path}.") + + def load(self, path: Path) -> None: + """Load the illumination correction from a file. + + Args: + path (Path): path to the file + + """ + # Make sure the file exists + if not path.is_file(): + raise FileNotFoundError(f"File {path} not found.") + + # Load color space and local scaling images from npz file + data = np.load(path, allow_pickle=True)["config"].item() + if "colorspace" not in data or "local_scaling" not in data: + raise ValueError("Invalid file format.") + self.colorspace = data["colorspace"] + self.local_scaling = data["local_scaling"] diff --git a/src/darsia/corrections/shape/curvature.py b/src/darsia/corrections/shape/curvature.py index e5657523..5f249aba 100644 --- a/src/darsia/corrections/shape/curvature.py +++ b/src/darsia/corrections/shape/curvature.py @@ -101,11 +101,7 @@ def __init__( self.height = kwargs.get("height", 1.0) else: - if config is None: - raise Exception( - "Please provide either an image as 'image' \ - or a config file as 'config'." - ) + warn("No image provided. Please provide an image or a config file.") # The internally stored config file is tailored to when resize_factor is equal to 1. # For other values, it has to be adapted. @@ -152,6 +148,40 @@ def read_config_from_file(self, path: Path) -> None: with open(str(path), "r") as openfile: self.config = json.load(openfile) + def save(self, path: Path) -> None: + """Save the curvature correction to a file. + + Arguments: + path (Path): path to the file + + """ + # Make sure the parent directory exists + path.parent.mkdir(parents=True, exist_ok=True) + + # Store color space and local scaling images as npz files + np.savez( + path, + config=self.config, + ) + print(f"Curvature correction saved to {path}.") + + def load(self, path: Path) -> None: + """Load the curvature correction from a file. + + Arguments: + path (Path): path to the file + + """ + # Make sure the file exists + if not path.is_file(): + raise FileNotFoundError(f"File {path} not found.") + + # Load color space and local scaling images from npz file + data = np.load(path, allow_pickle=True) + if "config" not in data: + raise ValueError("Invalid file format.") + self.config = data["config"].item() + def return_image(self) -> darsia.Image: """ Returns the current image as a darsia image width provided width and height. diff --git a/src/darsia/corrections/shape/drift.py b/src/darsia/corrections/shape/drift.py index 59b7e692..53065404 100644 --- a/src/darsia/corrections/shape/drift.py +++ b/src/darsia/corrections/shape/drift.py @@ -17,7 +17,7 @@ class DriftCorrection(darsia.BaseCorrection): def __init__( self, - base: Union[np.ndarray, darsia.Image], + base: Optional[Union[np.ndarray, darsia.Image]], config: Optional[dict] = None, ) -> None: """ @@ -37,6 +37,10 @@ def __init__( applied or not, default is True. """ + # Read baseline from config if not provided + if base is None: + base = config.get("base") + assert base is not None, "Baseline image not provided." # Read baseline image if isinstance(base, darsia.Image): @@ -55,27 +59,42 @@ def __init__( # Establish config if config is None: config = {} + self._init_from_config(config) + + self.translation_estimator = darsia.TranslationEstimator() + """Detection of effective translation based on feature detection.""" + + def _init_from_config(self, config: dict) -> None: self.active = config.get("active", True) """Flag controlling whether correction is active.""" - relative_padding: float = config.get("padding", 0.0) + self.relative_padding: float = config.get("padding", 0.0) """Allow for extra padding around the provided roi (relative sense).""" - roi: Optional[Union[list, np.ndarray]] = config.get("roi") + roi: Optional[Union[list, np.ndarray, tuple[slice]]] = config.get("roi") self.roi: Optional[tuple[slice, ...]] = ( None if roi is None - else darsia.bounding_box( - np.array(roi), - padding=round(relative_padding * np.min(self.base.shape[:2])), - max_size=self.base.shape[:2], + else ( + roi + if isinstance(roi, tuple) + else darsia.bounding_box( + np.array(roi), + padding=round(self.relative_padding * np.min(self.base.shape[:2])), + max_size=self.base.shape[:2], + ) ) ) """ROI for feature detection.""" - self.translation_estimator = darsia.TranslationEstimator() - """Detection of effective translation based on feature detection.""" + def return_config(self) -> dict: + """Return config file for the drift correction.""" + return { + "active": self.active, + "padding": self.relative_padding, + "roi": self.roi, + } # ! ---- Main correction routines @@ -102,3 +121,14 @@ def correct_array( ) else: return img + + # ! ---- I/O ---- ! # + def load(self, path) -> None: + """Load the drift correction from a file.""" + config = np.load(path, allow_pickle=True)["config"].item() + self._init_from_config(config) + + def save(self, path) -> None: + """Save the drift correction to a file.""" + np.savez(path, config=self.return_config()) + print(f"Drift correction saved to {path}.")