Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Copy Imputer #72

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions models/src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,111 @@ def __init__(
self._validate_indices()


class CopyImputer(BaseImputer):
    """Imputes missing values by copying them from another variable.

    Config example:
    ```
    default: "none"
    x:
    - y
    - q
    ```
    i.e. NaNs in variables ``y`` and ``q`` are replaced by the value of ``x``
    at the same location.
    """

    def __init__(
        self,
        config=None,
        data_indices: Optional[IndexCollection] = None,
        statistics: Optional[dict] = None,
    ) -> None:
        super().__init__(config, data_indices, statistics)

        self._create_imputation_indices()

        self._validate_indices()

    def _create_imputation_indices(
        self,
    ):
        """Create the indices for imputation.

        Builds parallel lists, one entry per imputed variable: its position in
        the training/inference input/output spaces, the name of the variable to
        copy from (``self.replacement``), and the source variable's position in
        both input spaces (``None`` where absent from a space).
        """
        name_to_index_training_input = self.data_indices.data.input.name_to_index
        name_to_index_inference_input = self.data_indices.model.input.name_to_index
        name_to_index_training_output = self.data_indices.data.output.name_to_index
        name_to_index_inference_output = self.data_indices.model.output.name_to_index

        self.num_training_input_vars = len(name_to_index_training_input)
        self.num_inference_input_vars = len(name_to_index_inference_input)
        self.num_training_output_vars = len(name_to_index_training_output)
        self.num_inference_output_vars = len(name_to_index_inference_output)

        self.index_training_input = []
        self.index_inference_input = []
        self.index_training_output = []
        self.index_inference_output = []
        self.replacement = []
        # Position of the *source* variable in each input space, aligned with
        # the index lists above, so transform() never has to look the source up
        # in the wrong index space.
        self.index_training_source = []
        self.index_inference_source = []

        # Create indices for imputation
        for name in name_to_index_training_input:
            key_to_copy = self.methods.get(name, self.default)

            if key_to_copy == "none":
                LOGGER.debug(f"Imputer: skipping {name} as no imputation method is specified")
                continue

            if key_to_copy not in name_to_index_training_input:
                # Fail at construction time instead of with a KeyError mid-training.
                raise ValueError(f"Imputer: source variable {key_to_copy} for {name} not found in the input data.")

            self.index_training_input.append(name_to_index_training_input[name])
            self.index_training_output.append(name_to_index_training_output.get(name, None))
            self.index_inference_input.append(name_to_index_inference_input.get(name, None))
            self.index_inference_output.append(name_to_index_inference_output.get(name, None))

            self.replacement.append(key_to_copy)
            self.index_training_source.append(name_to_index_training_input[key_to_copy])
            self.index_inference_source.append(name_to_index_inference_input.get(key_to_copy, None))

            LOGGER.debug(f"Imputer: replacing NaNs in {name} with value coming from variable :{self.replacement[-1]}")

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input with either the training or the inference number of variables
            in the last dimension.
        in_place : bool
            If False, operate on a clone of ``x``.

        Returns
        -------
        torch.Tensor
            ``x`` with NaNs of imputed variables replaced by the source
            variable's values.
        """
        if not in_place:
            x = x.clone()

        # Initialize nan mask once; it is cached and reused for every later call.
        if self.nan_locations is None:

            # Get NaN locations
            self.nan_locations = self.get_nans(x)

            # Initialize training loss mask to weigh imputed values with zeroes once
            self.loss_mask_training = torch.ones(
                (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
            )  # shape (grid, n_outputs)
            # for all variables that are imputed and part of the model output, set the loss weight to zero
            # NOTE(review): assumes get_nans returns a 2D (grid, variables) mask here — confirm.
            for idx_src, idx_dst in zip(self.index_training_input, self.index_inference_output):
                if idx_dst is not None:
                    self.loss_mask_training[:, idx_dst] = (~self.nan_locations[:, idx_src]).int()

        # Choose the index lists matching the tensor's variable dimension.
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
            source_index = self.index_training_source
        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
            source_index = self.index_inference_source
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # Replace values: for each imputed variable present in this index space,
        # copy the source variable's values into the cached NaN locations.
        for idx_src, idx_dst, idx_from, value in zip(
            self.index_training_input, index, source_index, self.replacement
        ):
            if idx_dst is not None and idx_from is not None:
                # Hoisted: the original evaluated this mask three times per variable.
                mask = self._expand_subset_mask(x, idx_src)
                assert not torch.isnan(x[..., idx_from][mask]).any(), f"NaNs found in {value}."
                x[..., idx_dst][mask] = x[..., idx_from][mask]

        return x


class DynamicMixin:
"""Mixin to add dynamic imputation behavior."""

Expand Down Expand Up @@ -303,3 +408,50 @@ def __init__(
"You are using a dynamic Imputer: NaN values will not be present in the model predictions. \
The model will be trained to predict imputed values. This might deteriorate performances."
)


class DynamicCopyImputer(CopyImputer):
    """Dynamic Copy imputation behavior.

    Unlike ``CopyImputer``, the NaN mask is recomputed on every call, so the
    set of imputed points may change between batches, and the training loss
    mask is left all-ones (imputed values are trained on).
    """

    def get_nans(self, x: torch.Tensor) -> torch.Tensor:
        """Override to calculate NaN locations dynamically."""
        return torch.isnan(x)

    def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Impute missing values in the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input with either the training or the inference number of variables
            in the last dimension.
        in_place : bool
            If False, operate on a clone of ``x``.

        Returns
        -------
        torch.Tensor
            ``x`` with NaNs of imputed variables replaced by the source
            variable's values.
        """
        if not in_place:
            x = x.clone()

        # Initialize mask every time: this is what makes the imputer "dynamic".
        nan_locations = self.get_nans(x)

        # Imputed values are not down-weighted in the loss for the dynamic case.
        self.loss_mask_training = torch.ones(
            (x.shape[-2], len(self.data_indices.model.output.name_to_index)), device=x.device
        )

        # Choose destination indices and the source-variable lookup table that
        # match the tensor's variable dimension (training vs inference space).
        if x.shape[-1] == self.num_training_input_vars:
            index = self.index_training_input
            name_to_index = self.data_indices.data.input.name_to_index
        elif x.shape[-1] == self.num_inference_input_vars:
            index = self.index_inference_input
            name_to_index = self.data_indices.model.input.name_to_index
        else:
            raise ValueError(
                f"Input tensor ({x.shape[-1]}) does not match the training "
                f"({self.num_training_input_vars}) or inference shape ({self.num_inference_input_vars})",
            )

        # Replace values
        for idx_dst, value in zip(index, self.replacement):
            # Skip variables (or sources) absent from this index space.
            if idx_dst is None or value not in name_to_index:
                continue
            idx_from = name_to_index[value]
            # Mask the destination variable's own NaNs, indexed in x's space
            # (equals the original idx_src column in the training case).
            mask = nan_locations[..., idx_dst]
            assert not torch.isnan(x[..., idx_from][mask]).any(), f"NaNs found in {value}."
            x[..., idx_dst][mask] = x[..., idx_from][mask]

        return x

    def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
        """Return ``x`` unchanged: NaNs are not reinstated, the model is trained
        to predict the imputed values."""
        return x
Loading