Skip to content

Commit

Permalink
Merge pull request #63 from punch-mission/develop
Browse files Browse the repository at this point in the history
Merge develop
  • Loading branch information
jmbhughes authored Sep 22, 2023
2 parents 93ab305 + 0d79050 commit 841c086
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 34 deletions.
7 changes: 4 additions & 3 deletions CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ We welcome and encourage contributions to this project in the form of pull reque

Contributors are expected

to be respectful and constructive; and
to enforce 1.
1. to be respectful and constructive; and
2. to enforce 1.

This code of conduct applies to all project-related communication that takes place on or at mailing lists, forums, social media, conferences, meetings, and social events.

This code of conduct is from [the HAPI-Server project](https://github.com/hapi-server/client-python/blob/master/CODE_OF_CONDUCT.md).
This code of conduct is adapted from [the HAPI-Server project](https://github.com/hapi-server/client-python/blob/master/CODE_OF_CONDUCT.md).
16 changes: 13 additions & 3 deletions regularizepsf/corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from pathlib import Path
from typing import Any, Tuple

import deepdish as dd
import dill
import h5py
import numpy as np
from numpy.fft import fft2, ifft2, ifftshift

Expand Down Expand Up @@ -264,11 +264,21 @@ def __getitem__(self, xy: Tuple[int, int]) -> np.ndarray:
raise UnevaluatedPointError(f"Model not evaluated at {xy}.")

def save(self, path: str) -> None:
dd.io.save(path, (self._evaluations, self._target_evaluation))
with h5py.File(path, 'w') as f:
eval_grp = f.create_group('evaluations')
for key, val in self._evaluations.items():
eval_grp.create_dataset(f'{key}', data=val)
f.create_dataset('target', data=self._target_evaluation)

@classmethod
def load(cls, path: str) -> ArrayCorrector:
evaluations, target_evaluation = dd.io.load(path)
with h5py.File(path, 'r') as f:
target_evaluation = f['target'][:].copy()

evaluations = dict()
for key, val in f['evaluations'].items():
parsed_key = tuple(int(val) for val in key.replace("(", "").replace(")", "").split(","))
evaluations[parsed_key] = val[:].copy()
return cls(evaluations, target_evaluation)

def simulate_observation(self, image: np.ndarray) -> np.ndarray:
Expand Down
47 changes: 43 additions & 4 deletions regularizepsf/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numbers import Real
from typing import Any, Dict, Generator, List, Optional, Tuple

import deepdish as dd
import h5py
import numpy as np
import sep
from astropy.io import fits
Expand Down Expand Up @@ -172,6 +172,7 @@ def fit(self,
"""

@abc.abstractmethod
def save(self, path: str) -> None:
"""Save the PatchCollection to a file
Expand All @@ -184,9 +185,10 @@ def save(self, path: str) -> None:
-------
None
"""
dd.io.save(path, self.patches)


@classmethod
@abc.abstractmethod
def load(cls, path: str) -> PatchCollectionABC:
"""Load a PatchCollection from a file
Expand All @@ -200,7 +202,6 @@ def load(cls, path: str) -> PatchCollectionABC:
PatchCollectionABC
the new patch collection
"""
return cls(dd.io.load(path))

def keys(self) -> List:
"""Gets identifiers for all patches"""
Expand Down Expand Up @@ -476,7 +477,7 @@ def average(self, corners: np.ndarray, patch_size: int, psf_size: int, # noqa:

for identifier, patch in self.patches.items():
# Normalize the patch
patch = patch / np.max(patch)
patch = patch / patch[psf_size//2, psf_size//2]

# Determine which average region it belongs to
center_x = identifier.x + self.size // 2
Expand Down Expand Up @@ -569,3 +570,41 @@ def to_array_corrector(self, target_evaluation: np.array) -> ArrayCorrector:

return ArrayCorrector(evaluation_dictionary, target_evaluation)

def save(self, path: str) -> None:
"""Save the CoordinatePatchCollection to a file
Parameters
----------
path : str
where to save the patch collection
Returns
-------
None
"""
with h5py.File(path, 'w') as f:
patch_grp = f.create_group('patches')
for key, val in self.patches.items():
patch_grp.create_dataset(f"({key.image_index, key.x, key.y})", data=val)

@classmethod
def load(cls, path: str) -> PatchCollectionABC:
"""Load a PatchCollection from a file
Parameters
----------
path : str
file path to load from
Returns
-------
PatchCollectionABC
the new patch collection
"""
patches = dict()
with h5py.File(path, "r") as f:
for key, val in f['patches'].items():
parsed_key = tuple(int(val) for val in key.replace("(", "").replace(")", "").split(","))
coord_id = CoordinateIdentifier(image_index=parsed_key[0], x=parsed_key[1], y=parsed_key[2])
patches[coord_id] = val[:].copy()
return cls(patches)
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
numpy==1.23.4
numpy>=1.25.2
dill==0.3.6
deepdish==0.3.7
h5py>=3.9.0
lmfit==1.2.2
cython==3.0.0
astropy=5.3.1
astropy==5.3.1
scipy>=1.10.0
scikit-image==0.19.3
sep==1.2.1
Expand Down
4 changes: 2 additions & 2 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy==1.23.4
numpy>=1.25.2
dill==0.3.6
deepdish==0.3.7
h5py>=3.9.0
lmfit==1.2.2
cython==3.0.0
astropy==5.3.1
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

setup(
name='regularizepsf',
version='0.2.1',
version='0.2.2',
description='Point spread function modeling and regularization',
long_description=long_description,
long_description_content_type='text/markdown',
Expand All @@ -21,7 +21,7 @@
author='J. Marcus Hughes',
author_email='[email protected]',
ext_modules=cythonize(ext_modules, annotate=True, compiler_directives={'language_level': 3}),
install_requires=["numpy", "dill", "deepdish", "lmfit", "sep", "cython", "astropy", "scipy", "scikit-image", "matplotlib"],
install_requires=["numpy", "dill", "h5py", "lmfit", "sep", "cython", "astropy", "scipy", "scikit-image", "matplotlib"],
package_data={"regularizepsf": ["helper.pyx"]},
setup_requires=["cython"],
extras_require={"test": ['pytest', 'coverage', 'pytest-runner', 'pytest-mpl']}
Expand Down
18 changes: 2 additions & 16 deletions tests/test_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,6 @@ def padded_100by100_image_psf_10_with_pattern():
img_padded = np.pad(img, padding_shape, mode='constant')
return img_padded

#
# @pytest.mark.parametrize("coord, value",
# [((0, 0), 2),
# ((10, 10), 1),
# ((-10, -10), 0)])
# def test_get_padded_img_section(coord, value, padded_100by100_image_psf_10_with_pattern):
# img_i = get_padded_img_section(padded_100by100_image_psf_10_with_pattern, coord[0], coord[1], 10)
# assert np.all(img_i == np.zeros((10, 10)) + value)
#
#
# def test_set_padded_img_section(padded_100by100_image_psf_10_with_pattern):
# test_img = np.pad(np.ones((100, 100)), ((20, 20), (20, 20)), mode='constant')
# for coord, value in [((0, 0), 2), ((10, 10), 1), ((-10, -10), 0)]:
# set_padded_img_section(test_img, coord[0], coord[1], 10, np.zeros((10, 10))+value)
# assert np.all(test_img == padded_100by100_image_psf_10_with_pattern)


def test_create_array_corrector():
example = ArrayCorrector({(0, 0): np.zeros((10, 10))},
Expand Down Expand Up @@ -210,6 +194,8 @@ def test_save_load_array_corrector(tmp_path):
assert os.path.isfile(fname)
loaded = example.load(fname)
assert isinstance(loaded, ArrayCorrector)
assert np.all(loaded._target_evaluation == np.ones((100, 100)))
assert np.all(loaded._evaluations[(0,0)] == np.ones((100, 100)))


def test_array_corrector_simulate_observation_with_zero_stars():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_coordinate_patch_average():
})
for patch in collection.values():
# Make the normalization of each patch a no-op
patch[-1, -1] = 1
patch[5, 5] = 1

averaged_collection = collection.average(
np.array([[0, 0]]), 10, 10, mode='median')
Expand Down

0 comments on commit 841c086

Please sign in to comment.