Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Parameter Validation #80

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
b2ada30
Initial version working
AstroPatty Oct 27, 2023
3db5362
add docstrings
AstroPatty Oct 27, 2023
f265226
Finish documentation
AstroPatty Oct 27, 2023
366de8c
Complete GaussianMixtureModel model
AstroPatty Oct 27, 2023
dd03ca2
Add decorator as top-level import in Util module
AstroPatty Oct 27, 2023
84ea259
Add to GaussianMixtureModel
AstroPatty Oct 27, 2023
405e19b
Move to using pydantic.PositiveFloat
AstroPatty Oct 27, 2023
dc2202d
GMM: Enforce sum(weights) == 1
AstroPatty Oct 27, 2023
4f4361d
Spelling
AstroPatty Oct 27, 2023
fa6d791
Remove unnecessary import
AstroPatty Oct 27, 2023
cb6983d
One more spelling mistake
AstroPatty Oct 27, 2023
c6d902f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2023
5e5f7cc
Add pydantic to requirements
AstroPatty Oct 27, 2023
8cb7521
Fix linting issue
AstroPatty Oct 27, 2023
8903241
Several changes
AstroPatty Nov 6, 2023
0905108
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2023
a730c7d
Working with all object methods
AstroPatty Nov 6, 2023
5fea8ec
Working with tests
AstroPatty Nov 6, 2023
eadf85e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2023
2e37905
Passing linters
AstroPatty Nov 6, 2023
3ab59d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 6, 2023
5287a3d
Tests written
AstroPatty Nov 7, 2023
067e642
Bump minimum version to 3.9
AstroPatty Nov 7, 2023
b5d21dd
partial fix for typing
AstroPatty Nov 7, 2023
cfec7a6
Fix to run in 3.10
AstroPatty Nov 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: ['3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ scipy
speclite
pyyaml
matplotlib
pydantic>=2
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
setup(
author="DESC/SLSC",
author_email="[email protected]",
python_requires=">=3.6",
python_requires=">=3.9",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
Expand Down
5 changes: 5 additions & 0 deletions slsim/Deflectors/_params/elliptical_lens_galaxies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pydantic import BaseModel, PositiveFloat


class vel_disp_from_m_star(BaseModel):
m_star: PositiveFloat
11 changes: 11 additions & 0 deletions slsim/Deflectors/_params/velocity_dispersion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pydantic import BaseModel, PositiveFloat
from astropy.cosmology import Cosmology


class vel_disp_composite_model(BaseModel, arbitrary_types_allowed=True):
r: PositiveFloat
m_star: PositiveFloat
rs_star: PositiveFloat
m_halo: PositiveFloat
c_halo: float
cosmo: Cosmology
2 changes: 2 additions & 0 deletions slsim/Deflectors/elliptical_lens_galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from slsim.Deflectors.velocity_dispersion import vel_disp_sdss
from slsim.Util import param_util
from slsim.Deflectors.deflector_base import DeflectorBase
from slsim.Util import check_params


class EllipticalLensGalaxies(DeflectorBase):
Expand Down Expand Up @@ -120,6 +121,7 @@ def elliptical_projected_eccentricity(ellipticity, **kwargs):
return e1_light, e2_light, e1_mass, e2_mass


@check_params
def vel_disp_from_m_star(m_star):
"""Function to calculate the velocity dispersion from the staller mass using
empirical relation for elliptical galaxies.
Expand Down
2 changes: 2 additions & 0 deletions slsim/Deflectors/velocity_dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import scipy
from skypy.galaxies.redshift import redshifts_from_comoving_density
from skypy.utils.random import schechter
from slsim.Util import check_params

"""
This module provides functions to compute velocity dispersion using schechter function.
Expand All @@ -11,6 +12,7 @@
# from skypy.galaxies.velocity_dispersion import schechter_vdf


@check_params
def vel_disp_composite_model(r, m_star, rs_star, m_halo, c_halo, cosmo):
"""Computes the luminosity weighted velocity dispersion for a deflector with a
stellar Hernquist profile and a NFW halo profile, assuming isotropic anisotropy.
Expand Down
Empty file.
35 changes: 35 additions & 0 deletions slsim/ParamDistributions/_params/gaussian_mixture_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from pydantic import (
BaseModel,
PositiveFloat,
PositiveInt,
model_validator,
field_validator,
)
import numpy as np


class GaussianMixtureModel(BaseModel):
means: list[float] = [0.00330796, -0.07635054, 0.11829008]
stds: list[PositiveFloat] = [
np.sqrt(0.00283885),
np.sqrt(0.01066668),
np.sqrt(0.0097978),
]
weights: list[PositiveFloat] = [0.62703102, 0.23732313, 0.13564585]

@field_validator("weights")
@classmethod
def check_weights(cls, weight_values):
if sum(weight_values) != 1:
raise ValueError("The sum of the weights must be 1")
return weight_values

@model_validator(mode="after")
def check_lengths(self):
if len(self.means) != len(self.stds) or len(self.means) != len(self.weights):
raise ValueError("The lengths of means, stds and weights must be equal")

Check warning on line 30 in slsim/ParamDistributions/_params/gaussian_mixture_model.py

View check run for this annotation

Codecov / codecov/patch

slsim/ParamDistributions/_params/gaussian_mixture_model.py#L30

Added line #L30 was not covered by tests
return self


class GaussianMixtureModel_rvs(BaseModel):
size: PositiveInt
19 changes: 7 additions & 12 deletions slsim/ParamDistributions/gaussian_mixture_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from slsim.Util import check_params


class GaussianMixtureModel:
Expand All @@ -8,11 +9,13 @@ class GaussianMixtureModel:
is defined by its mean, standard deviation and weight.
"""

def __init__(self, means=None, stds=None, weights=None):
@check_params
def __init__(self, means: list[float], stds: list[float], weights: list[float]):
"""
The constructor for GaussianMixtureModel class. The default values are the
means, standard deviations, and weights of the fits to the data in the table
2 of https://doi.org/10.1093/mnras/stac2235 and others.
2 of https://doi.org/10.1093/mnras/stac2235 and others. See "_params" for
defaults and validation logic.

:param means: the mean values of the Gaussian components.
:type means: list of float
Expand All @@ -21,20 +24,12 @@ def __init__(self, means=None, stds=None, weights=None):
:param weights: The weights of the Gaussian components in the mixture.
:type weights: list of float
"""
if means is None:
means = [0.00330796, -0.07635054, 0.11829008]
if stds is None:
stds = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)]
if weights is None:
weights = [0.62703102, 0.23732313, 0.13564585]
assert (
len(means) == len(stds) == len(weights)
), "Lengths of means, standard deviations, and weights must be equal."
self.means = means
self.stds = stds
self.weights = weights

def rvs(self, size):
@check_params
def rvs(self, size: int):
"""Generate random variables from the GMM distribution.

:param size: The number of random variables to generate.
Expand Down
3 changes: 3 additions & 0 deletions slsim/Util/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .params import check_params

__all__ = ["check_params"]
208 changes: 208 additions & 0 deletions slsim/Util/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""Utilities for managing parameter defaults and validation in the slsim package.

Desgined to be unobtrusive to use.
"""
from functools import wraps
from importlib import import_module
from typing import Callable, Any, TypeVar
from enum import Enum
import inspect
import pydantic

"""
Set of routines for validating inputs to functions and classes. The elements of this
module should never be imported directly. Instead, @check_params can be imported
directly from the Util module.
"""


class SlSimParameterException(Exception):
pass


_defaults = {}


class _FnType(Enum):
"""Enum for the different types of functions we can have. This is used to determine
how to parse the arguments to the function.

There are three possible cases:
1. The function is a standard function, defined outside a class
2. The function is a standard object method,
taking "self" as the first parameter
3. The funtion is a class method (or staticmethod), not taking
"self" as the first parameter
"""

STANDARD = 0
METHOD = 1
CLASSMETHOD = 2


def determine_fn_type(fn: Callable) -> _FnType:
"""Determine which of the three possible cases a function falls into. Cases 0 and 2
are actually functionally identical. Things only get spicy when we have a "self"
argument.

However the tricky thing is that decorators operate on functions and methods when
they are imported, not when they are used. This means "inspect.ismethod" will always
return False, even if the function is a method.

We can get around this by checking if the parent of the function is a class. Then,
we check if the first argument of the function is "self". If both of these are true,
then the function is a method.
"""
if not inspect.isfunction(fn):
raise TypeError("decorator @check_params can only be used on functions!")

Check warning on line 57 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L57

Added line #L57 was not covered by tests
qualified_obj_name = fn.__qualname__
qualified_obj_path = qualified_obj_name.split(".")
if len(qualified_obj_path) == 1:
# If the qualified name isn't split, this is a standard function not
# attached to a class
return _FnType.STANDARD

spec = inspect.getfullargspec(fn)
if spec.args[0] == "self":
return _FnType.METHOD
else:
return _FnType.CLASSMETHOD

Check warning on line 69 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L69

Added line #L69 was not covered by tests


T_ = TypeVar("T_")


def check_params(fn: Callable[..., T_]) -> Callable[..., T_]:
"""A decorator for enforcing checking of params in __init__ methods. This decorator
will automatically load the default parameters for the class and check that the
passed parameters are valid. It expeects a "params.py" file in the same folder as
the class definition. Uses pydantic models to enforce types, sanity checks, and
defaults.

From and end user perspective, there is no difference between this and a normal
__init__ fn. Developers only need to add @check_params above their __init__ method
definition to enable this feature, then add their default parameters to the
"params.py" file.
"""
fn_type = determine_fn_type(fn)
if fn_type == _FnType.STANDARD:
new_fn = standard_fn_wrapper(fn)
elif fn_type == _FnType.METHOD:
new_fn = method_fn_wrapper(fn)
elif fn_type == _FnType.CLASSMETHOD:
new_fn = standard_fn_wrapper(fn)

Check warning on line 93 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L92-L93

Added lines #L92 - L93 were not covered by tests

return new_fn


def standard_fn_wrapper(fn: Callable[..., T_]) -> Callable[..., T_]:
"""A wrapper for standard functions.

This is used to parse the arguments to the function and check that they are valid.
"""

@wraps(fn)
def new_fn(*args, **kwargs) -> T_:
# Get function argument names
pargs = {}
if args:
largs = list(inspect.signature(fn).parameters.keys())
for i in range(len(args)):
arg_value = args[i]
if arg_value is not None:
pargs[largs[i]] = args[i]
# Doing it this way ensures we still catch duplicate arguments
defaults = get_defaults(fn)
parsed_args = defaults(**pargs, **kwargs)
return fn(**dict(parsed_args))

return new_fn


def method_fn_wrapper(fn: Callable[..., T_]) -> Callable[..., T_]:
@wraps(fn)
def new_fn(obj: Any, *args, **kwargs) -> T_:
# Get function argument names
parsed_args = {}
if args:
largs = list(inspect.signature(fn).parameters.keys())

for i in range(len(args)):
arg_value = args[i]
if arg_value is not None:
parsed_args[largs[i + 1]] = arg_value
# Doing it this way ensures we still catch duplicate arguments
parsed_kwargs = {k: v for k, v in kwargs.items() if v is not None}
defaults = get_defaults(fn)
parsed_args = defaults(**parsed_args, **parsed_kwargs)
return fn(obj, **dict(parsed_args))

return new_fn


def get_defaults(fn: Callable) -> pydantic.BaseModel:
module_trace = inspect.getmodule(fn).__name__.split(".")
file_name = module_trace[-1]
parent_trace = module_trace[:-1]
parent_path = ".".join(parent_trace)
param_path = ".".join([parent_path, "_params"])
fn_qualname = fn.__qualname__
cache_name = parent_path + "." + fn_qualname
if cache_name in _defaults:
return _defaults[cache_name]

try:
_ = import_module(param_path)
except ModuleNotFoundError:
raise SlSimParameterException(
f'No default parameters found in module {".".join(parent_trace)},'
" but something in that module is trying to use the @check_params decorator"
)
try:
param_model_file = import_module(f"{param_path}.{file_name}")
except AttributeError:
raise SlSimParameterException(

Check warning on line 164 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L163-L164

Added lines #L163 - L164 were not covered by tests
f'No default parameters found for file "{file_name}" in module '
f'{".".join(parent_trace)}, but something in that module is trying to use '
"the @check_params decorator"
)

if fn.__name__ == "__init__":
expected_model_name = "_".join(fn_qualname.split(".")[:-1])
else:
expected_model_name = "_".join(fn_qualname.split("."))

try:
parameter_model = getattr(param_model_file, expected_model_name)
except AttributeError:
raise SlSimParameterException(

Check warning on line 178 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L177-L178

Added lines #L177 - L178 were not covered by tests
"No default parameters found for function " f'"{fn_qualname}"'
)
if not issubclass(parameter_model, pydantic.BaseModel):
raise SlSimParameterException(

Check warning on line 182 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L182

Added line #L182 was not covered by tests
f'Defaults for "{fn_qualname}" are not in a pydantic model!'
)
_defaults[cache_name] = parameter_model
return _defaults[cache_name]


def load_parameters(modpath: str, obj_name: str) -> pydantic.BaseModel:
"""Loads parameters from the "params.py" file which should be in the same folder as
the class definition."""
try:
defaults = import_module(modpath)

Check warning on line 193 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L192-L193

Added lines #L192 - L193 were not covered by tests
except ModuleNotFoundError:
raise SlSimParameterException(
"No default parameters found in module " f'"{modpath[-2]}"'
)
try:
obj_defaults = getattr(defaults, obj_name)
except AttributeError:
raise SlSimParameterException(

Check warning on line 201 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L198-L201

Added lines #L198 - L201 were not covered by tests
f"No default parameters found for class " f'"{obj_name}"'
)
if not issubclass(obj_defaults, pydantic.BaseModel):
raise SlSimParameterException(

Check warning on line 205 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L204-L205

Added lines #L204 - L205 were not covered by tests
f'Defaults for "{obj_name}" are not in a ' "pydantic model!"
)
return obj_defaults

Check warning on line 208 in slsim/Util/params.py

View check run for this annotation

Codecov / codecov/patch

slsim/Util/params.py#L208

Added line #L208 was not covered by tests
Empty file added tests/test_Params/__init__.py
Empty file.
Loading
Loading