From 83234d36c280ad09423634f84a1e3309ef48947e Mon Sep 17 00:00:00 2001
From: icedoom888
Date: Thu, 9 Jan 2025 18:25:33 +0100
Subject: [PATCH 1/4] Implemented copy imputer

---
 .../anemoi/models/preprocessing/imputer.py | 146 ++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/models/src/anemoi/models/preprocessing/imputer.py b/models/src/anemoi/models/preprocessing/imputer.py
index 4bd3c0ae..6b73eaa6 100644
--- a/models/src/anemoi/models/preprocessing/imputer.py
+++ b/models/src/anemoi/models/preprocessing/imputer.py
@@ -231,6 +231,108 @@ def __init__(
         self._validate_indices()
 
 
+class CopyImputer(BaseImputer):
+    """Imputes missing values by copying them from another variable.
+    ```
+    default: "none"
+    x:
+    - y
+    - q
+    ```
+    """
+
+    def __init__(
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
+    ) -> None:
+        super().__init__(config, data_indices, statistics)
+
+        self._create_imputation_indices()
+
+        self._validate_indices()
+
+    def _create_imputation_indices(
+        self,
+    ):
+        """Create the indices for imputation."""
+        name_to_index_training_input = self.data_indices.data.input.name_to_index
+        name_to_index_inference_input = self.data_indices.model.input.name_to_index
+        name_to_index_training_output = self.data_indices.data.output.name_to_index
+        name_to_index_inference_output = self.data_indices.model.output.name_to_index
+
+        self.num_training_input_vars = len(name_to_index_training_input)
+        self.num_inference_input_vars = len(name_to_index_inference_input)
+        self.num_training_output_vars = len(name_to_index_training_output)
+        self.num_inference_output_vars = len(name_to_index_inference_output)
+
+        (
+            self.index_training_input,
+            self.index_inference_input,
+            self.index_training_output,
+            self.index_inference_output,
+            self.replacement,
+        ) = ([], [], [], [], [])
+
+        # Create indices for imputation
+        for name in name_to_index_training_input:
+            key_to_copy = self.methods.get(name, self.default)
+
+            if key_to_copy == "none":
+                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
+                continue
+
+            self.index_training_input.append(name_to_index_training_input[name])
+            self.index_training_output.append(name_to_index_training_output.get(name, None))
+            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
+            self.index_inference_output.append(name_to_index_inference_output.get(name, None))
+
+            self.replacement.append(key_to_copy)
+
+            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable: {self.replacement[-1]}")
+
+    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Impute missing values in the input tensor."""
+        if not in_place:
+            x = x.clone()
+
+        # Initialize nan mask once
+        if self.nan_locations is None:
+
+            # Get NaN locations
+            self.nan_locations = self.get_nans(x)
+
+            # Initialize training loss mask to weigh imputed values with zeroes once
+            self.loss_mask_training = torch.ones(
+                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
+            )  # shape (grid, n_outputs)
+            # for all variables that are imputed and part of the model output, set the loss weight to zero
+            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
+                if idx_dst is not None:
+                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()
+
+        # Choose correct index based on number of variables
+        if x.shape[-1] == self.num_training_input_vars:
+            index = self.index_training_input
+        elif x.shape[-1] == self.num_inference_input_vars:
+            index = self.index_inference_input
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
+            )
+
+        # Replace values
+        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
+            if idx_dst is not None:
+                x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = x[..., self.data_indices[value]][
+                    self._expand_subset_mask(x, idx_src)
+                ]
+
+        return x
+
+
 class DynamicMixin:
     """Mixin to add dynamic imputation behavior."""
 
@@ -303,3 +405,47 @@ def __init__(
             "You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
                 The model will be trained to predict imputed values. This might deteriorate performances."
         )
+
+
+class DynamicCopyImputer(CopyImputer):
+    """Dynamic Copy imputation behavior."""
+
+    def get_nans(self, x: torch.Tensor) -> torch.Tensor:
+        """Override to calculate NaN locations dynamically."""
+        return torch.isnan(x)
+
+    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Impute missing values in the input tensor."""
+        if not in_place:
+            x = x.clone()
+
+        # Initialize mask every time
+        nan_locations = self.get_nans(x)
+
+        self.loss_mask_training = torch.ones(
+            (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
+        )
+
+        # Choose correct index based on number of variables
+        if x.shape[-1] == self.num_training_input_vars:
+            index = self.index_training_input
+        elif x.shape[-1] == self.num_inference_input_vars:
+            index = self.index_inference_input
+        else:
+            raise ValueError(
+                f"Input tensor ({x.shape[-1]}) does not match the training "
+                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
+            )
+
+        # Replace values
+        for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
+            if idx_dst is not None:
+                x[..., idx_dst][nan_locations[..., idx_src]] = x[..., self.data_indices[value]][
+                    nan_locations[..., idx_src]
+                ]
+
+        return x
+
+    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
+        """Impute missing values in the input tensor."""
+        return x

From 1d482a17ad4f53f56c17d08c7118359b897ad8ec Mon Sep 17 00:00:00 2001
From: icedoom888
Date: Thu, 9 Jan 2025 19:59:29 +0100
Subject: [PATCH 2/4] Fixed implementation and tested

---
 models/src/anemoi/models/preprocessing/imputer.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/models/src/anemoi/models/preprocessing/imputer.py b/models/src/anemoi/models/preprocessing/imputer.py
index 6b73eaa6..411287bd 100644
--- a/models/src/anemoi/models/preprocessing/imputer.py
+++ b/models/src/anemoi/models/preprocessing/imputer.py
@@ -326,7 +326,9 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
         # Replace values
         for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
             if idx_dst is not None:
-                x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = x[..., self.data_indices[value]][
+                assert not torch.isnan(x[..., self.data_indices.data.input.name_to_index[value]][self._expand_subset_mask(x, idx_src)]).any(), \
+                    f"NaNs found in {value}."
+                x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = x[..., self.data_indices.data.input.name_to_index[value]][
                     self._expand_subset_mask(x, idx_src)
                 ]
 
@@ -440,7 +442,10 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
         # Replace values
         for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
             if idx_dst is not None:
-                x[..., idx_dst][nan_locations[..., idx_src]] = x[..., self.data_indices[value]][
+                print(value)
+                assert not torch.isnan(x[..., self.data_indices.data.input.name_to_index[value]][nan_locations[..., idx_src]]).any(), \
+                    f"NaNs found in {value}."
+                x[..., idx_dst][nan_locations[..., idx_src]] = x[..., self.data_indices.data.input.name_to_index[value]][
                     nan_locations[..., idx_src]
                 ]
 

From 00523690a480f712552a2e7da17f965ef62d6677 Mon Sep 17 00:00:00 2001
From: icedoom888
Date: Thu, 9 Jan 2025 20:00:03 +0100
Subject: [PATCH 3/4] gpc

---
 .../anemoi/models/preprocessing/imputer.py | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/models/src/anemoi/models/preprocessing/imputer.py b/models/src/anemoi/models/preprocessing/imputer.py
index 411287bd..2bfcd46c 100644
--- a/models/src/anemoi/models/preprocessing/imputer.py
+++ b/models/src/anemoi/models/preprocessing/imputer.py
@@ -326,11 +326,12 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
         # Replace values
         for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
             if idx_dst is not None:
-                assert not torch.isnan(x[..., self.data_indices.data.input.name_to_index[value]][self._expand_subset_mask(x, idx_src)]).any(), \
-                    f"NaNs found in {value}."
-                x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = x[..., self.data_indices.data.input.name_to_index[value]][
-                    self._expand_subset_mask(x, idx_src)
-                ]
+                assert not torch.isnan(
+                    x[..., self.data_indices.data.input.name_to_index[value]][self._expand_subset_mask(x, idx_src)]
+                ).any(), f"NaNs found in {value}."
+                x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = x[
+                    ..., self.data_indices.data.input.name_to_index[value]
+                ][self._expand_subset_mask(x, idx_src)]
 
         return x
 
@@ -443,11 +444,12 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
         for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
             if idx_dst is not None:
                 print(value)
-                assert not torch.isnan(x[..., self.data_indices.data.input.name_to_index[value]][nan_locations[..., idx_src]]).any(), \
-                    f"NaNs found in {value}."
-                x[..., idx_dst][nan_locations[..., idx_src]] = x[..., self.data_indices.data.input.name_to_index[value]][
-                    nan_locations[..., idx_src]
-                ]
+                assert not torch.isnan(
+                    x[..., self.data_indices.data.input.name_to_index[value]][nan_locations[..., idx_src]]
+                ).any(), f"NaNs found in {value}."
+                x[..., idx_dst][nan_locations[..., idx_src]] = x[
+                    ..., self.data_indices.data.input.name_to_index[value]
+                ][nan_locations[..., idx_src]]
 
         return x
 

From f2916aee1989d65ac39af5b0297166b37577feb3 Mon Sep 17 00:00:00 2001
From: icedoom888
Date: Fri, 10 Jan 2025 11:15:16 +0100
Subject: [PATCH 4/4] Removed print leftover

---
 models/src/anemoi/models/preprocessing/imputer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/models/src/anemoi/models/preprocessing/imputer.py b/models/src/anemoi/models/preprocessing/imputer.py
index 2bfcd46c..7b657f6a 100644
--- a/models/src/anemoi/models/preprocessing/imputer.py
+++ b/models/src/anemoi/models/preprocessing/imputer.py
@@ -443,7 +443,6 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
         # Replace values
         for idx_src, (idx_dst, value) in zip(self.index_training_input, zip(index, self.replacement)):
             if idx_dst is not None:
-                print(value)
                 assert not torch.isnan(
                     x[..., self.data_indices.data.input.name_to_index[value]][nan_locations[..., idx_src]]
                 ).any(), f"NaNs found in {value}."
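For reference, the sketch below shows the copy-imputation step these patches implement, in isolation: NaNs in a variable to impute are filled from another variable, as in the `x: [y, q]` config example from the `CopyImputer` docstring. This is a minimal standalone illustration, not part of the patch; the variable names, the toy `name_to_index` mapping, and the example tensor are assumptions, and the real class resolves indices through the anemoi `IndexCollection` instead of a plain dict.

```
import torch

# Toy setup (illustrative only): the last dimension indexes variables,
# and the config `x: [y, q]` means "fill NaNs in y and q from x".
name_to_index = {"x": 0, "y": 1, "q": 2}
copy_from = {"y": "x", "q": "x"}  # variable to impute -> variable to copy from

data = torch.tensor(
    [
        [1.0, float("nan"), 2.0],
        [3.0, 4.0, float("nan")],
    ]
)  # shape (grid points, variables)

for target, source in copy_from.items():
    idx_dst = name_to_index[target]
    idx_src = name_to_index[source]
    # Mask of missing entries in the variable being imputed.
    nan_mask = torch.isnan(data[..., idx_dst])
    # Same replacement as in the patch: write the source variable's values
    # into the NaN locations of the target variable, in place.
    data[..., idx_dst][nan_mask] = data[..., idx_src][nan_mask]

print(data)
# tensor([[1., 1., 2.],
#         [3., 4., 3.]])
```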