class CopyImputer(BaseImputer):
    """Imputes missing values copying them from another variable.

    A variable's configured method is interpreted as the *name of the variable
    to copy values from*: at every grid point where the target variable is NaN,
    the value of the source variable at the same point is substituted.

    Configuration example:
    ```
    default: "none"
    x:
    - y
    - q
    ```
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        super().__init__(config, data_indices, statistics)

        self._create_imputation_indices()

        self._validate_indices()

    def _create_imputation_indices(
        self,
    ):
        """Create the indices for imputation.

        Builds parallel lists: for every training-input variable with a
        configured (non-"none") method, the variable's index in each of the
        four index spaces (None where the variable is absent from a space) and
        the name of the source variable to copy from.
        """
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        # Variable counts per space, used by transform() to detect which
        # variable set an incoming tensor carries.
        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        (
            self.index_training_input,
            self.index_inference_input,
            self.index_training_output,
            self.index_inference_output,
            self.replacement,
        ) = ([], [], [], [], [])

        # Create indices for imputation
        for name in name_to_index_training_input:
            key_to_copy = self.methods.get(name, self.default)

            if key_to_copy == "none":
                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
                continue

            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_output.append(name_to_index_training_output.get(name, None))
            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
            self.index_inference_output.append(name_to_index_inference_output.get(name, None))

            self.replacement.append(key_to_copy)

            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable :{self.replacement[-1]}")

    def _resolve_input_index(self, x: torch.Tensor):
        """Return (target index list, source name_to_index) matching ``x``'s variable set.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` matches neither the training nor
            the inference variable count.
        """
        if x.shape[-1] == self.num_training_input_vars:
            return self.index_training_input, self.data_indices.data.input.name_to_index
        if x.shape[-1] == self.num_inference_input_vars:
            return self.index_inference_input, self.data_indices.model.input.name_to_index
        raise ValueError(
            f"Input tensor ({x.shape[-1]}) does not match the training "
            f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
        )

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input with variables on the last dimension (training or inference
            variable set).
        in_place : bool
            If False, a clone of ``x`` is modified and returned.

        Returns
        -------
        torch.Tensor
            Tensor where NaNs of configured variables are replaced by the
            value of their source variable.
        """
        if not in_place:
            x = x.clone()

        # Initialize nan mask once (first, training-shaped batch).
        if self.nan_locations is None:

            # Get NaN locations
            self.nan_locations = self.get_nans(x)

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose correct index based on number of variables.
        # BUGFIX: the source column must also be looked up in the name_to_index
        # of the *matching* variable set — previously the training mapping was
        # used unconditionally, reading the wrong channel at inference.
        index, src_name_to_index = self._resolve_input_index(x)

        # Replace values. idx_src always indexes the stored (training-shaped)
        # nan mask; idx_dst indexes x's own variable set.
        for idx_src, idx_dst, value in zip(self.index_training_input, index, self.replacement):
            if idx_dst is None:
                continue
            src_idx = src_name_to_index.get(value)
            if src_idx is None:
                raise ValueError(f"Copy-source variable {value} is not available in the input tensor.")
            mask = self._expand_subset_mask(x, idx_src)  # hoisted: used three times below
            source = x[..., src_idx]
            # Explicit raise instead of assert: asserts are stripped under -O.
            if torch.isnan(source[mask]).any():
                raise ValueError(f"NaNs found in {value}.")
            x[..., idx_dst][mask] = source[mask]

        return x


class DynamicCopyImputer(CopyImputer):
    """Dynamic Copy imputation behavior.

    Recomputes the NaN mask on every call instead of caching the mask from the
    first batch. The training loss mask is left all-ones — presumably the model
    is trained to predict imputed values (see DynamicImputer's warning); confirm
    against the dynamic-imputer documentation.
    """

    def get_nans(self, x: torch.Tensor) -> torch.Tensor:
        """Override to calculate NaN locations dynamically."""
        return torch.isnan(x)

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        Same contract as ``CopyImputer.transform`` but the NaN mask is
        recomputed from ``x`` on every call.
        """
        if not in_place:
            x = x.clone()

        # Initialize mask every time
        nan_locations = self.get_nans(x)

        # All-ones: imputed values are not down-weighted in the loss.
        self.loss_mask_training = torch.ones(
            (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
        )

        # Choose correct index based on number of variables; also yields the
        # name_to_index of the matching variable set for the source lookup
        # (BUGFIX: training mapping was previously used unconditionally).
        index, src_name_to_index = self._resolve_input_index(x)

        # Replace values
        for idx_dst, value in zip(index, self.replacement):
            if idx_dst is None:
                continue
            src_idx = src_name_to_index.get(value)
            if src_idx is None:
                raise ValueError(f"Copy-source variable {value} is not available in the input tensor.")
            # BUGFIX: nan_locations has x's own shape, so the target's mask is
            # at idx_dst (was idx_src, a training index — identical in
            # training, wrong for inference-shaped tensors).
            mask = nan_locations[..., idx_dst]
            source = x[..., src_idx]
            # Explicit raise instead of assert: asserts are stripped under -O.
            if torch.isnan(source[mask]).any():
                raise ValueError(f"NaNs found in {value}.")
            x[..., idx_dst][mask] = source[mask]

        return x

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """No-op: imputed values are not restored to NaN on the way out."""
        return x