Skip to content

Commit

Permalink
random state fixed for DDPM
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien Roussel authored and Julien Roussel committed Jun 13, 2024
1 parent 210e2f4 commit a901097
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 27 deletions.
5 changes: 3 additions & 2 deletions qolmat/analysis/holes_characterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ class LittleTest(McarTest):
imputer : Optional[ImputerEM]
Imputer based on the EM algorithm. The 'model' attribute must be equal to 'multinormal'.
If None, the default ImputerEM is taken.
random_state : Union[None, int, np.random.RandomState], optional
Controls the randomness of the fit_transform, by default None
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
"""

def __init__(
Expand Down
35 changes: 21 additions & 14 deletions qolmat/benchmark/missing_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ class _HoleGenerator:
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float]
Ratio of values to mask, by default 0.05.
random_state : Optional[int]
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -150,8 +151,9 @@ class UniformHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
sample_proportional: bool, optional
If True, generates holes in target columns with equal frequency.
If False, reproduces the empirical proportions between the variables.
Expand Down Expand Up @@ -215,8 +217,9 @@ class _SamplerHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -321,8 +324,9 @@ class GeometricHoleGenerator(_SamplerHoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Union[None, int, np.random.RandomState], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -390,8 +394,9 @@ class EmpiricalHoleGenerator(_SamplerHoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05.
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -485,8 +490,9 @@ class MultiMarkovHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups: Tuple[str, ...]
Column names used to group the data
"""
Expand Down Expand Up @@ -634,8 +640,9 @@ class GroupedHoleGenerator(_HoleGenerator):
Names of the columns for which holes must be created, by default None
ratio_masked : Optional[float], optional
Ratio of masked values to add, by default 0.05
random_state : Optional[int], optional
The seed used by the random number generator, by default 42.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
groups : Tuple[str, ...]
Names of the columns forming the groups, by default []
"""
Expand Down
25 changes: 14 additions & 11 deletions qolmat/imputations/diffusions/ddpms.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import Dict, List, Callable, Tuple, Union
from typing_extensions import Self
import math
import sys
import numpy as np
import pandas as pd
import time
from datetime import timedelta
from tqdm import tqdm
import gc

import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn import preprocessing
from sklearn import utils as sku


from qolmat.imputations.diffusions.base import AutoEncoder, ResidualBlock, ResidualBlockTS
from qolmat.imputations.diffusions.utils import get_num_params
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(
p_dropout: float = 0.0,
num_sampling: int = 1,
is_clip: bool = True,
random_state: Union[None, int] = None,
random_state: Union[None, int, np.random.RandomState] = None,
):
"""Diffusion model for tabular data based on
Denoising Diffusion Probabilistic Models (DDPM) of
Expand Down Expand Up @@ -69,8 +70,9 @@ def __init__(
Dropout probability, by default 0.0
num_sampling : int, optional
Number of samples generated for each cell, by default 1
random_state : int, optional
The seed of the pseudo random number generator to use, for reproductibility.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
"""
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Expand Down Expand Up @@ -111,8 +113,9 @@ def __init__(
self.is_clip = is_clip

self.normalizer_x = preprocessing.StandardScaler()
if random_state is not None:
torch.manual_seed(random_state)
self.random_state = sku.check_random_state(random_state)
seed_torch = self.random_state.randint(sys.maxsize)
torch.manual_seed(seed_torch)

def _q_sample(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Section 3.2, algorithm 1 formula implementation. Forward process, defined by `q`.
Expand Down Expand Up @@ -350,7 +353,6 @@ def fit(
round: int = 10,
cols_imputed: Tuple[str, ...] = (),
) -> Self:

"""Fit data
Parameters
Expand Down Expand Up @@ -542,7 +544,7 @@ def __init__(
p_dropout: float = 0.0,
num_sampling: int = 1,
is_rolling: bool = False,
random_state: Union[None, int] = None,
random_state: Union[None, int, np.random.RandomState] = None,
):
"""Diffusion model for time-series data based on the works of
Ho et al., 2020 (https://arxiv.org/abs/2006.11239),
Expand Down Expand Up @@ -581,8 +583,9 @@ def __init__(
Number of samples generated for each cell, by default 1
is_rolling : bool, optional
Use pandas.DataFrame.rolling for preprocessing data, by default False
random_state : int, optional
The seed of the pseudo random number generator to use, for reproductibility.
random_state : int, RandomState instance or None, default=None
Controls the randomness.
Pass an int for reproducible output across multiple function calls.
"""
super().__init__(
num_noise_steps,
Expand Down

0 comments on commit a901097

Please sign in to comment.