Update Jax version #76

Merged: 5 commits, Mar 27, 2024
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "regmod"
version = "0.2.1"
version = "0.2.2"
description = "General regression models"
readme = "README.rst"
requires-python = ">=3.10"
@@ -18,8 +18,8 @@ dependencies = [
"pandas",
"xspline==0.0.7",
"msca",
"jax[cpu]==0.4.5",
"jaxlib==0.4.4",
"jax",
"jaxlib",
]

[project.optional-dependencies]
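With the exact pins on jax[cpu]==0.4.5 and jaxlib==0.4.4 dropped, dependency resolution now picks whatever recent jax/jaxlib is available. A minimal post-install sanity check sketch (the check itself is not part of this PR, and treating jax.Array as the thing to verify is an assumption based on the new regmod._typing import below):

```python
# Hypothetical sanity check after installing the unpinned dependencies.
import jax

print("jax", jax.__version__)

# regmod._typing does `from jax import Array as JaxArray`, so the resolved
# jax must be recent enough (0.4+) to expose the unified jax.Array type.
assert hasattr(jax, "Array"), "installed jax does not provide jax.Array"
```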
7 changes: 7 additions & 0 deletions src/regmod/_typing.py
@@ -0,0 +1,7 @@
+from collections.abc import Iterable
+from typing import Any, Callable
+
+from jax import Array as JaxArray
+from msca.linalg.matrix import Matrix
+from numpy.typing import ArrayLike, NDArray
+from pandas import DataFrame
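The new regmod._typing module gathers the type aliases (Callable, DataFrame, NDArray, JaxArray, Matrix, and so on) that the rest of the package can import instead of reaching into typing, numpy, pandas, and jax directly. A small illustrative sketch of that usage pattern (the function and column name below are hypothetical, not code from this PR):

```python
# Illustrative only: shows how modules can annotate with regmod._typing aliases.
from regmod._typing import Callable, DataFrame, NDArray


def transformed_offset(df: DataFrame, fun: Callable[[float], float]) -> NDArray:
    """Apply `fun` to a hypothetical "offset" column and return a numpy array."""
    return df["offset"].map(fun).to_numpy()
```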
74 changes: 34 additions & 40 deletions src/regmod/composite_models/base.py
@@ -1,6 +1,7 @@
"""
Base Model
"""

import logging
from copy import deepcopy
from typing import Dict, List, Optional
@@ -19,15 +20,9 @@
logger = logging.getLogger(__name__)

link_funs = {
"gaussian": fun_dict[
GaussianModel.default_param_specs["mu"]["inv_link"]
].inv_fun,
"poisson": fun_dict[
PoissonModel.default_param_specs["lam"]["inv_link"]
].inv_fun,
"binomial": fun_dict[
BinomialModel.default_param_specs["p"]["inv_link"]
].inv_fun,
"gaussian": fun_dict[GaussianModel.default_param_specs["mu"]["inv_link"]].inv_fun,
"poisson": fun_dict[PoissonModel.default_param_specs["lam"]["inv_link"]].inv_fun,
"binomial": fun_dict[BinomialModel.default_param_specs["p"]["inv_link"]].inv_fun,
}

model_constructors = {
@@ -86,19 +81,23 @@ class BaseModel(NodeModel):
Overwrite the append function in NodeModel.
"""

-    def __init__(self,
-                 name: str,
-                 y: str,
-                 variables: List[Variable],
-                 df: Optional[pd.DataFrame] = None,
-                 weights: str = "weights",
-                 mtype: str = "gaussian",
-                 prior_mask: Optional[Dict] = None,
-                 **param_specs):
+    def __init__(
+        self,
+        name: str,
+        y: str,
+        variables: List[Variable],
+        df: Optional[DataFrame] = None,
+        weights: str = "weights",
+        mtype: str = "gaussian",
+        prior_mask: Optional[Dict] = None,
+        **param_specs,
+    ):

super().__init__(name)
-        if any(mtype not in model_config
-               for model_config in (link_funs, model_constructors)):
+        if any(
+            mtype not in model_config
+            for model_config in (link_funs, model_constructors)
+        ):
raise ValueError(f"Not supported model type {mtype}")
data = deepcopy(data)
variables = list(deepcopy(variables))
@@ -108,9 +107,7 @@ def __init__(self,
self.df = df
self.weights = weights
self.variables = {v.name: v for v in variables}
self.param_specs = {"variables": variables,
"use_offset": True,
**param_specs}
self.param_specs = {"variables": variables, "use_offset": True, **param_specs}
self.model = None
self.prior_mask = {} if prior_mask is None else prior_mask

@@ -149,10 +146,7 @@ def fit(self, **fit_options):
if self.model is None:
model_constructor = model_constructors[self.mtype]
self.model = model_constructor(
-                self.y,
-                df=self.df,
-                weights=self.weights,
-                param_specs=self.param_specs
+                self.y, df=self.df, weights=self.weights, param_specs=self.param_specs
)
self.model.fit(**fit_options)
message = f"fit_node;finish;{self.level};{self.name};"
@@ -179,17 +173,19 @@ def get_draws(self, df: DataFrame = None, size: int = 1000) -> DataFrame:
pred_data = self.model.df.copy()
pred_data.attach_df(df)

-        coefs_draws = np.random.multivariate_normal(self.model.opt_coefs,
-                                                    self.model.opt_vcov,
-                                                    size=size)
-        draws = np.vstack([
-            self.model.params[0].get_param(coefs_draw, pred_data)
-            for coefs_draw in coefs_draws
-        ])
-        df_draws = pd.DataFrame(
+        coefs_draws = np.random.multivariate_normal(
+            self.model.opt_coefs, self.model.opt_vcov, size=size
+        )
+        draws = np.vstack(
+            [
+                self.model.params[0].get_param(coefs_draw, pred_data)
+                for coefs_draw in coefs_draws
+            ]
+        )
+        df_draws = DataFrame(
draws.T,
columns=[f"{self.col_value}_{i}" for i in range(size)],
-            index=df.index
+            index=df.index,
)

return pd.concat([df, df_draws], axis=1)
@@ -216,8 +212,7 @@ def get_posterior(self) -> Dict:
# use minimum standard deviation of the posterior distribution
sd = np.maximum(0.1, np.sqrt(np.diag(self.model.opt_vcov)))
vnames = [v.name for v in self.param_specs["variables"]]
-        slices = sizes_to_slices([self.variables[name].size
-                                  for name in vnames])
+        slices = sizes_to_slices([self.variables[name].size for name in vnames])
return {
name: GaussianPrior(mean=mean[slices[i]], sd=sd[slices[i]])
for i, name in enumerate(vnames)
@@ -240,8 +235,7 @@ def append(self, node: NodeModel, rank: int = 0):
primary children.
"""
if rank >= 1:
raise ValueError(f"{type(self).__name__} can only have primary "
"link.")
raise ValueError(f"{type(self).__name__} can only have primary " "link.")
super().append(node, rank=rank)

def __repr__(self) -> str:
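The reformatting in get_draws leaves the sampling logic unchanged: coefficient draws are sampled from a multivariate normal centered at self.model.opt_coefs with covariance self.model.opt_vcov, and each draw is mapped through the first parameter to produce prediction draws. A standalone sketch of that step with plain numpy arrays standing in for the model objects (all names and values here are placeholders, not the regmod API):

```python
import numpy as np

# Placeholders for self.model.opt_coefs and self.model.opt_vcov.
opt_coefs = np.array([0.5, -1.2])
opt_vcov = np.array([[0.04, 0.00], [0.00, 0.09]])

size = 1000
coefs_draws = np.random.multivariate_normal(opt_coefs, opt_vcov, size=size)

# BaseModel.get_draws maps each draw through
# self.model.params[0].get_param(coefs_draw, pred_data); a fixed design
# matrix is used here as a stand-in for that transformation.
design = np.array([[1.0, 0.3], [1.0, 1.7]])
draws = np.vstack([design @ coefs_draw for coefs_draw in coefs_draws])
print(draws.shape)  # (1000, 2): one row of predictions per coefficient draw
```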
84 changes: 43 additions & 41 deletions src/regmod/function.py
@@ -1,11 +1,13 @@
"""
Function module
"""

from dataclasses import dataclass, field
from typing import Callable

import numpy as np

+from regmod._typing import Callable


@dataclass
class SmoothFunction:
@@ -25,6 +27,7 @@ class SmoothFunction:
        Derivative of the function.
    d2fun : Callable
        Second derivative of the function.
+
"""

name: str
@@ -60,8 +63,8 @@ def exp_d2fun(x):

def expit_fun(x):
neg_indices = x < 0
-    z = np.exp(-np.sqrt(x*x))
-    y = 1/(1 + z)
+    z = np.exp(-np.sqrt(x * x))
+    y = 1 / (1 + z)
if np.isscalar(x):
if neg_indices:
y = 1 - y
@@ -71,15 +74,15 @@ def expit_fun(x):


def expit_dfun(x):
-    z = np.exp(-np.sqrt(x*x))
-    y = z/(1 + z)**2
+    z = np.exp(-np.sqrt(x * x))
+    y = z / (1 + z) ** 2
return y


def expit_d2fun(x):
neg_indices = x < 0
-    z = np.exp(-np.sqrt(x*x))
-    y = z*(z - 1)/(z + 1)**3
+    z = np.exp(-np.sqrt(x * x))
+    y = z * (z - 1) / (z + 1) ** 3
if np.isscalar(x):
if neg_indices:
y = -y
@@ -93,55 +96,54 @@ def log_fun(x):


def log_dfun(x):
-    return 1/x
+    return 1 / x


def log_d2fun(x):
-    return -1/x**2
+    return -1 / x**2


def logit_fun(x):
-    return np.log(x/(1 - x))
+    return np.log(x / (1 - x))


def logit_dfun(x):
-    return 1/(x*(1 - x))
+    return 1 / (x * (1 - x))


def logit_d2fun(x):
-    return (2*x - 1)/(x*(1 - x))**2
+    return (2 * x - 1) / (x * (1 - x)) ** 2


fun_list = [
SmoothFunction(name="identity",
fun=identity_fun,
inv_fun=identity_fun,
dfun=identity_dfun,
d2fun=identity_d2fun),
SmoothFunction(name="exp",
fun=exp_fun,
inv_fun=log_fun,
dfun=exp_dfun,
d2fun=exp_d2fun),
SmoothFunction(name="expit",
fun=expit_fun,
inv_fun=logit_fun,
dfun=expit_dfun,
d2fun=expit_d2fun),
SmoothFunction(name="log",
fun=log_fun,
inv_fun=exp_fun,
dfun=log_dfun,
d2fun=log_d2fun),
SmoothFunction(name="logit",
fun=logit_fun,
inv_fun=expit_fun,
dfun=logit_dfun,
d2fun=logit_d2fun),
SmoothFunction(
name="identity",
fun=identity_fun,
inv_fun=identity_fun,
dfun=identity_dfun,
d2fun=identity_d2fun,
),
SmoothFunction(
name="exp", fun=exp_fun, inv_fun=log_fun, dfun=exp_dfun, d2fun=exp_d2fun
),
SmoothFunction(
name="expit",
fun=expit_fun,
inv_fun=logit_fun,
dfun=expit_dfun,
d2fun=expit_d2fun,
),
SmoothFunction(
name="log", fun=log_fun, inv_fun=exp_fun, dfun=log_dfun, d2fun=log_d2fun
),
SmoothFunction(
name="logit",
fun=logit_fun,
inv_fun=expit_fun,
dfun=logit_dfun,
d2fun=logit_d2fun,
),
]


-fun_dict = {
-    fun.name: fun
-    for fun in fun_list
-}
+fun_dict = {fun.name: fun for fun in fun_list}
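The function.py changes are formatting only, so every SmoothFunction keeps its behaviour. A quick hedged check that the expit entry of fun_dict still round-trips with its inverse and that its derivative matches a finite difference (the tolerances are arbitrary illustration choices, not part of the package's tests):

```python
import numpy as np

from regmod.function import fun_dict

expit = fun_dict["expit"]
x = np.linspace(-4.0, 4.0, 9)

# fun and inv_fun should invert each other on the open interval (0, 1).
assert np.allclose(expit.inv_fun(expit.fun(x)), x)

# dfun should agree with a central finite difference of fun.
eps = 1e-6
numeric = (expit.fun(x + eps) - expit.fun(x - eps)) / (2 * eps)
assert np.allclose(expit.dfun(x), numeric, atol=1e-5)
```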