Skip to content

Commit

Permalink
ENH: Enable loading and saving of central correction routines.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwboth committed Apr 12, 2024
1 parent 513dc41 commit deed14e
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 18 deletions.
49 changes: 46 additions & 3 deletions src/darsia/corrections/color/colorcorrection.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,23 @@ def __init__(
"""

# Define config
assert config is not None, "provide config at least with 'roi' key"
self.config: dict = copy.deepcopy(config)
"""Config dictionary for initialization of color correction."""
if config is not None:
self.config: dict = copy.deepcopy(config)
"""Config dictionary for initialization of color correction."""
self._init_from_config(base)
else:
self.config = {}

def _init_from_config(
self, base: Optional[Union[darsia.Image, ColorChecker]]
) -> None:
"""Auxiliary function for initialization from config.
Args:
base (Image or ColorChecker, optional): reference defining a color checker; if
None provided, use CustomColorChecker.
"""
self.active: bool = self.config.get("active", True)
"""Flag controlling whether correction is active"""

Expand Down Expand Up @@ -265,6 +278,8 @@ def __init__(
"""Flag controlling whether values outside the feasible range [0., 1.] are clipped"""

# Construct color checker
if base is None:
base = self.config.get("colorchecker", None)
self._setup_colorchecker(base)

def correct_array(
Expand Down Expand Up @@ -386,6 +401,34 @@ def write_config_to_file(self, path: Union[Path, str]) -> None:
with open(Path(path), "w") as outfile:
json.dump(self.config, outfile, indent=4)

def save(self, path: Path) -> None:
"""Save the color correction to a file.
Args:
path (Path): path to the file
"""
# Make sure that the path exists
path.parents[0].mkdir(parents=True, exist_ok=True)

# Save the color correction
np.savez(path, config=self.config)
print(f"Color correction saved to {path}.")

def load(self, path: Path) -> None:
"""Load the color correction from a file.
Args:
path (Path): path to the file
"""
# Make sure the file exists
assert path.exists(), f"File {path} does not exist."

# Load the color correction
self.config = np.load(path, allow_pickle=True)["config"].item()
self._init_from_config(base=None)

# ! ---- Auxiliary files

def _setup_colorchecker(
Expand Down
41 changes: 40 additions & 1 deletion src/darsia/corrections/color/illuminationcorrection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module containing illumination correction functionality."""

from typing import Literal, Union
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -14,7 +15,7 @@
class IlluminationCorrection(darsia.BaseCorrection):
"""Class for illumination correction."""

def __init__(
def setup(
self,
base: Union[darsia.Image, list[darsia.Image]],
samples: list[tuple[slice, ...]],
Expand Down Expand Up @@ -226,3 +227,41 @@ def correct_array(self, img: np.ndarray) -> np.ndarray:
self.local_scaling[i if self.colorspace == "rgb" else 0].img,
)
return img_wb

def save(self, path: Path) -> None:
"""Save the illumination correction to a file.
Args:
path (Path): path to the file
"""
# Make sure the parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)

# Store color space and local scaling images as npz files
np.savez(
path,
config={
"colorspace": self.colorspace,
"local_scaling": self.local_scaling,
},
)
print(f"Illumination correction saved to {path}.")

def load(self, path: Path) -> None:
"""Load the illumination correction from a file.
Args:
path (Path): path to the file
"""
# Make sure the file exists
if not path.is_file():
raise FileNotFoundError(f"File {path} not found.")

# Load color space and local scaling images from npz file
data = np.load(path, allow_pickle=True)["config"].item()
if "colorspace" not in data or "local_scaling" not in data:
raise ValueError("Invalid file format.")
self.colorspace = data["colorspace"]
self.local_scaling = data["local_scaling"]
40 changes: 35 additions & 5 deletions src/darsia/corrections/shape/curvature.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,7 @@ def __init__(
self.height = kwargs.get("height", 1.0)

else:
if config is None:
raise Exception(
"Please provide either an image as 'image' \
or a config file as 'config'."
)
warn("No image provided. Please provide an image or a config file.")

# The internally stored config file is tailored to when resize_factor is equal to 1.
# For other values, it has to be adapted.
Expand Down Expand Up @@ -152,6 +148,40 @@ def read_config_from_file(self, path: Path) -> None:
with open(str(path), "r") as openfile:
self.config = json.load(openfile)

def save(self, path: Path) -> None:
"""Save the curvature correction to a file.
Arguments:
path (Path): path to the file
"""
# Make sure the parent directory exists
path.parent.mkdir(parents=True, exist_ok=True)

# Store color space and local scaling images as npz files
np.savez(
path,
config=self.config,
)
print(f"Curvature correction saved to {path}.")

def load(self, path: Path) -> None:
"""Load the curvature correction from a file.
Arguments:
path (Path): path to the file
"""
# Make sure the file exists
if not path.is_file():
raise FileNotFoundError(f"File {path} not found.")

# Load color space and local scaling images from npz file
data = np.load(path, allow_pickle=True)
if "config" not in data:
raise ValueError("Invalid file format.")
self.config = data["config"].item()

def return_image(self) -> darsia.Image:
"""
Returns the current image as a darsia image width provided width and height.
Expand Down
48 changes: 39 additions & 9 deletions src/darsia/corrections/shape/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class DriftCorrection(darsia.BaseCorrection):

def __init__(
self,
base: Union[np.ndarray, darsia.Image],
base: Optional[Union[np.ndarray, darsia.Image]],
config: Optional[dict] = None,
) -> None:
"""
Expand All @@ -37,6 +37,10 @@ def __init__(
applied or not, default is True.
"""
# Read baseline from config if not provided
if base is None:
base = config.get("base")
assert base is not None, "Baseline image not provided."

# Read baseline image
if isinstance(base, darsia.Image):
Expand All @@ -55,27 +59,42 @@ def __init__(
# Establish config
if config is None:
config = {}
self._init_from_config(config)

self.translation_estimator = darsia.TranslationEstimator()
"""Detection of effective translation based on feature detection."""

def _init_from_config(self, config: dict) -> None:

self.active = config.get("active", True)
"""Flag controlling whether correction is active."""

relative_padding: float = config.get("padding", 0.0)
self.relative_padding: float = config.get("padding", 0.0)
"""Allow for extra padding around the provided roi (relative sense)."""

roi: Optional[Union[list, np.ndarray]] = config.get("roi")
roi: Optional[Union[list, np.ndarray, tuple[slice]]] = config.get("roi")
self.roi: Optional[tuple[slice, ...]] = (
None
if roi is None
else darsia.bounding_box(
np.array(roi),
padding=round(relative_padding * np.min(self.base.shape[:2])),
max_size=self.base.shape[:2],
else (
roi
if isinstance(roi, tuple)
else darsia.bounding_box(
np.array(roi),
padding=round(self.relative_padding * np.min(self.base.shape[:2])),
max_size=self.base.shape[:2],
)
)
)
"""ROI for feature detection."""

self.translation_estimator = darsia.TranslationEstimator()
"""Detection of effective translation based on feature detection."""
def return_config(self) -> dict:
"""Return config file for the drift correction."""
return {
"active": self.active,
"padding": self.relative_padding,
"roi": self.roi,
}

# ! ---- Main correction routines

Expand All @@ -102,3 +121,14 @@ def correct_array(
)
else:
return img

# ! ---- I/O ---- ! #
def load(self, path) -> None:
"""Load the drift correction from a file."""
config = np.load(path, allow_pickle=True)["config"].item()
self._init_from_config(config)

def save(self, path) -> None:
"""Save the drift correction to a file."""
np.savez(path, config=self.return_config())
print(f"Drift correction saved to {path}.")

0 comments on commit deed14e

Please sign in to comment.