diff --git a/.github/workflows/BasicTerm_ME_PyTorch.yml b/.github/workflows/BasicTerm_ME_python.yml similarity index 77% rename from .github/workflows/BasicTerm_ME_PyTorch.yml rename to .github/workflows/BasicTerm_ME_python.yml index 9ae2bfe..3a2cf49 100644 --- a/.github/workflows/BasicTerm_ME_PyTorch.yml +++ b/.github/workflows/BasicTerm_ME_python.yml @@ -1,10 +1,10 @@ -name: basicterm_me_pytorch +name: basicterm_me_python on: workflow_dispatch: push: paths: - - 'containers/BasicTerm_ME_heavylight/**' + - 'containers/BasicTerm_ME_python/**' jobs: docker: @@ -28,6 +28,6 @@ jobs: name: Build and push uses: docker/build-push-action@v5 with: - context: "{{defaultContext}}:containers/BasicTerm_ME_heavylight" + context: "{{defaultContext}}:containers/BasicTerm_ME_python" push: true - tags: actuarial/basicterm_me_pytorch:latest \ No newline at end of file + tags: actuarial/basicterm_me_python:latest \ No newline at end of file diff --git a/README.md b/README.md index b59d40d..b89fbfd 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,10 @@ We currently only have 1 benchmark, we are working to expand the benchmarks. Ope Time measurement is currently the best of 3 runs. -| benchmark | recursive | container |A100-SXM4-40GB | H100-SXM5-80GB | +| benchmark | classification | container |A100-SXM4-40GB | H100-SXM5-80GB | |---------------|-|-|----------------|----------------| -| BasicTerm_ME 100 Million | Yes | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_pytorch/general) | 15.8194s | 7.1820s | +| BasicTerm_ME 100 Million | recursive PyTorch | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_python/general) | 15.8284s | 7.205s | +| BasicTerm_ME 100 Million | compiled iterative JAX | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_python/general) | 3.448s | 1.551s | ### Notes diff --git a/containers/BasicTerm_ME_heavylight/.devcontainer/devcontainer.json b/containers/BasicTerm_ME_python/.devcontainer/devcontainer.json similarity index 100% rename from containers/BasicTerm_ME_heavylight/.devcontainer/devcontainer.json rename to containers/BasicTerm_ME_python/.devcontainer/devcontainer.json diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/Projection/__init__.py b/containers/BasicTerm_ME_python/BasicTerm_ME/Projection/__init__.py similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/Projection/__init__.py rename to containers/BasicTerm_ME_python/BasicTerm_ME/Projection/__init__.py diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/__init__.py b/containers/BasicTerm_ME_python/BasicTerm_ME/__init__.py similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/__init__.py rename to containers/BasicTerm_ME_python/BasicTerm_ME/__init__.py diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/_data/data.pickle b/containers/BasicTerm_ME_python/BasicTerm_ME/_data/data.pickle similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/_data/data.pickle rename to containers/BasicTerm_ME_python/BasicTerm_ME/_data/data.pickle diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/_system.json b/containers/BasicTerm_ME_python/BasicTerm_ME/_system.json similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/_system.json rename to containers/BasicTerm_ME_python/BasicTerm_ME/_system.json diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/disc_rate_ann.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/disc_rate_ann.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/disc_rate_ann.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/disc_rate_ann.xlsx diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/model_point_table.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/model_point_table.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/model_point_table.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/model_point_table.xlsx diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/mort_table.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/mort_table.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/mort_table.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/mort_table.xlsx diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/premium_table.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/premium_table.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/premium_table.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/premium_table.xlsx diff --git a/containers/BasicTerm_ME_heavylight/Dockerfile b/containers/BasicTerm_ME_python/Dockerfile similarity index 83% rename from containers/BasicTerm_ME_heavylight/Dockerfile rename to containers/BasicTerm_ME_python/Dockerfile index 3f7836d..682e966 100644 --- a/containers/BasicTerm_ME_heavylight/Dockerfile +++ b/containers/BasicTerm_ME_python/Dockerfile @@ -4,9 +4,12 @@ FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime # Set the working directory in the container WORKDIR /app -RUN python -m pip install \ +RUN pip install --upgrade pip +RUN pip install --upgrade "jax[cuda12]" +RUN pip install \ pandas \ openpyxl \ + equinox \ heavylight==1.0.6 # Copy the rest of the application diff --git a/containers/BasicTerm_ME_python/main.py b/containers/BasicTerm_ME_python/main.py new file mode 100644 index 0000000..ba13c04 --- /dev/null +++ b/containers/BasicTerm_ME_python/main.py @@ -0,0 +1,22 @@ +import argparse + +def main(): + parser = argparse.ArgumentParser(description="Term ME model runner") + parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points") + # add an argument that must be either "torch" or "jax" + parser.add_argument("--model", type=str, default="jax_iterative", choices=["torch_recursive", "jax_iterative"], help="Model to run") + args = parser.parse_args() + + multiplier = args.multiplier + + if args.model == "torch_recursive": + from term_me_recursive_pytorch import time_recursive_PyTorch # having both imports at top level gave a jax error? + time_recursive_PyTorch(multiplier) + elif args.model == "jax_iterative": + from term_me_iterative_jax import time_iterative_jax + time_iterative_jax(multiplier) + else: + raise ValueError("Invalid model") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/containers/BasicTerm_ME_heavylight/notes.md b/containers/BasicTerm_ME_python/notes.md similarity index 100% rename from containers/BasicTerm_ME_heavylight/notes.md rename to containers/BasicTerm_ME_python/notes.md diff --git a/containers/BasicTerm_ME_heavylight/requirements.txt b/containers/BasicTerm_ME_python/requirements.txt similarity index 100% rename from containers/BasicTerm_ME_heavylight/requirements.txt rename to containers/BasicTerm_ME_python/requirements.txt diff --git a/containers/BasicTerm_ME_python/term_me_iterative_jax.py b/containers/BasicTerm_ME_python/term_me_iterative_jax.py new file mode 100644 index 0000000..20b9dbc --- /dev/null +++ b/containers/BasicTerm_ME_python/term_me_iterative_jax.py @@ -0,0 +1,127 @@ +import jax +import pandas as pd +import numpy as np +import timeit +import jax.numpy as jnp +import equinox as eqx +jax.config.update("jax_enable_x64", True) + +disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0) +mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0) +model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0) +premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1]) + +class ModelPointsEqx(eqx.Module): + premium_pp: jnp.ndarray + duration_mth: jnp.ndarray + age_at_entry: jnp.ndarray + sum_assured: jnp.ndarray + policy_count: jnp.ndarray + policy_term: jnp.ndarray + max_proj_len: jnp.ndarray + + def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1): + table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True) + table.sort_values(by="policy_id", inplace=True) + self.premium_pp = jnp.round(jnp.array(np.tile(table["sum_assured"].to_numpy() * table["premium_rate"].to_numpy(), size_multiplier)),decimals=2) + self.duration_mth = jnp.array(jnp.tile(table["duration_mth"].to_numpy(), size_multiplier)) + self.age_at_entry = jnp.array(jnp.tile(table["age_at_entry"].to_numpy(), size_multiplier)) + self.sum_assured = jnp.array(jnp.tile(table["sum_assured"].to_numpy(), size_multiplier)) + self.policy_count = jnp.array(jnp.tile(table["policy_count"].to_numpy(), size_multiplier)) + self.policy_term = jnp.array(jnp.tile(table["policy_term"].to_numpy(), size_multiplier)) + self.max_proj_len = jnp.max(12 * self.policy_term - self.duration_mth) + 1 + +class AssumptionsEqx(eqx.Module): + disc_rate_ann: jnp.ndarray + mort_table: jnp.ndarray + expense_acq: jnp.ndarray + expense_maint: jnp.ndarray + + def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame): + self.disc_rate_ann = jnp.array(disc_rate_ann["zero_spot"].to_numpy()) + self.mort_table = jnp.array(mort_table.to_numpy()) + self.expense_acq = jnp.array(300) + self.expense_maint = jnp.array(60) + +class LoopState(eqx.Module): + t: jnp.ndarray + tot: jnp.ndarray + pols_lapse_prev: jnp.ndarray + pols_death_prev: jnp.ndarray + pols_if_at_BEF_DECR_prev: jnp.ndarray + +class TermME(eqx.Module): + mp: ModelPointsEqx + assume: AssumptionsEqx + init_ls: LoopState + + def __init__(self, mp: ModelPointsEqx, assume: AssumptionsEqx): + self.mp = mp + self.assume = assume + self.init_ls = LoopState( + t=jnp.array(0), + tot = jnp.array(0), + pols_lapse_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64), + pols_death_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64), + pols_if_at_BEF_DECR_prev=jnp.where(self.mp.duration_mth > 0, self.mp.policy_count, 0.) + ) + + def __call__(self): + def iterative_core(ls: LoopState, _): + duration_month_t = self.mp.duration_mth + ls.t + duration_t = duration_month_t // 12 + age_t = self.mp.age_at_entry + duration_t + pols_if_init = ls.pols_if_at_BEF_DECR_prev - ls.pols_lapse_prev - ls.pols_death_prev + pols_if_at_BEF_MAT = pols_if_init + pols_maturity = (duration_month_t == self.mp.policy_term * 12) * pols_if_at_BEF_MAT + pols_if_at_BEF_NB = pols_if_at_BEF_MAT - pols_maturity + pols_new_biz = jnp.where(duration_month_t == 0, self.mp.policy_count, 0) + pols_if_at_BEF_DECR = pols_if_at_BEF_NB + pols_new_biz + mort_rate = self.assume.mort_table[age_t-18, jnp.clip(duration_t, a_max=5)] + mort_rate_mth = 1 - (1 - mort_rate) ** (1/12) + pols_death = pols_if_at_BEF_DECR * mort_rate_mth + claims = self.mp.sum_assured * pols_death + premiums = self.mp.premium_pp * pols_if_at_BEF_DECR + commissions = (duration_t == 0) * premiums + discount = (1 + self.assume.disc_rate_ann[ls.t//12]) ** (-ls.t/12) + inflation_factor = (1 + 0.01) ** (ls.t/12) + expenses = self.assume.expense_acq * pols_new_biz + pols_if_at_BEF_DECR * self.assume.expense_maint/12 * inflation_factor + lapse_rate = jnp.clip(0.1 - 0.02 * duration_t, a_min=0.02) + net_cf = premiums - claims - expenses - commissions + discounted_net_cf = jnp.sum(net_cf) * discount + nxt_ls = LoopState( + t=ls.t+1, + tot = ls.tot + discounted_net_cf, + pols_lapse_prev=(pols_if_at_BEF_DECR - pols_death) * (1 - (1 - lapse_rate) ** (1/12)), + pols_death_prev=pols_death, + pols_if_at_BEF_DECR_prev=pols_if_at_BEF_DECR + ) + return nxt_ls, None + return jax.lax.scan(iterative_core, self.init_ls, xs=None, length=277)[0].tot + + +def run_jax_term_ME(term_me: TermME): + return term_me() + +run_jax_term_ME_opt = jax.jit(run_jax_term_ME) + +def time_jax_func(mp, assume, func): + term_me = TermME(mp, assume) + result = func(term_me).block_until_ready() + start = timeit.default_timer() + result = func(term_me).block_until_ready() + end = timeit.default_timer() + elapsed_time = end - start # Time in seconds + return float(result), elapsed_time + +def time_iterative_jax(multiplier: int): + mp = ModelPointsEqx(model_point_table, premium_table, size_multiplier=multiplier) + assume = AssumptionsEqx(disc_rate_ann, mort_table) + result, time_in_seconds = time_jax_func(mp, assume, run_jax_term_ME_opt) + print("JAX iterative model") + print(f"number modelpoints={len(mp.duration_mth):,}") + print(f"{result=:,}") + print(f"{time_in_seconds=}") + +if __name__ == "__main__": + time_iterative_jax(100) \ No newline at end of file diff --git a/containers/BasicTerm_ME_heavylight/main.py b/containers/BasicTerm_ME_python/term_me_recursive_pytorch.py similarity index 91% rename from containers/BasicTerm_ME_heavylight/main.py rename to containers/BasicTerm_ME_python/term_me_recursive_pytorch.py index cc15963..e6f9861 100644 --- a/containers/BasicTerm_ME_heavylight/main.py +++ b/containers/BasicTerm_ME_python/term_me_recursive_pytorch.py @@ -5,9 +5,7 @@ import torch from heavylight import LightModel, agg import timeit -import argparse -# ensure CUDA available print(f"{torch.cuda.is_available()=}") # set 64 bit precision torch.set_default_dtype(torch.float64) @@ -167,16 +165,8 @@ def time_recursive_CPU(model: TermME): end = timeit.default_timer() return result, end - start -def main(): - parser = argparse.ArgumentParser(description="Term ME model runner") - parser.add_argument("--disable_cuda", action="store_true", help="Disable CUDA usage") - parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points") - args = parser.parse_args() - - disable_cuda = args.disable_cuda - multiplier = args.multiplier - - if not disable_cuda and torch.cuda.is_available(): +def time_recursive_PyTorch(multiplier: int): + if torch.cuda.is_available(): device = torch.device('cuda') else: device = torch.device('cpu') @@ -196,9 +186,7 @@ def main(): model.mp = mp_multiplied result, time_in_seconds = time_recursive(model) # report results - print(f"number modelpoints={len(model_point_table) * multiplier:,}") + print("PyTorch recursive model") + print(f"number modelpoints={len(mp_multiplied.duration_mth):,}") print(f"{result=:,}") - print(f"{time_in_seconds=}") - -if __name__ == "__main__": - main() \ No newline at end of file + print(f"{time_in_seconds=}") \ No newline at end of file