Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weight example #165

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions scripts/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from src.preprocess.admin_boundaries import KenyaAdminPreprocessor

from scripts.utils import get_data_path
from typing import Optional


def process_vci_2018():
Expand Down Expand Up @@ -55,19 +56,16 @@ def process_era5POS_2018():
)


def process_era5_land(variable: str):
if Path(".").absolute().as_posix().split("/")[-1] == "ml_drought":
data_path = Path("data")
else:
data_path = Path("../data")
regrid_path = data_path / "interim/chirps_preprocessed/chirps_kenya.nc"
def process_era5_land(variable: Optional[str] = None):
data_path = get_data_path()
regrid_path = data_path / "interim/VCI_preprocessed/data_kenya.nc"
assert regrid_path.exists(), f"{regrid_path} not available"

processor = ERA5LandPreprocessor(data_path)

processor.preprocess(
subset_str="kenya",
regrid=None,
regrid=regrid_path,
resample_time="M",
upsampling=False,
variable=variable,
Expand Down Expand Up @@ -194,5 +192,6 @@ def preprocess_boku_ndvi():
# preprocess_kenya_boundaries(selection="level_2")
# preprocess_kenya_boundaries(selection="level_3")
# preprocess_era5_hourly()
preprocess_boku_ndvi()
# preprocess_boku_ndvi()
# preprocess_asal_mask()
process_era5_land()
16 changes: 15 additions & 1 deletion src/models/neural_networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ def _make_analysis_folder(self) -> Path:
def _initialize_model(self, x_ref: Tuple[torch.Tensor, ...]) -> torch.nn.Module:
raise NotImplementedError

def _make_weights(self, target: torch.Tensor) -> torch.Tensor:

lower_threshold = -1 if self.normalize_y else 15
weight = 10
weights = torch.ones_like(target)
weights[target <= lower_threshold] = weight
return weights

def train(
self,
num_epochs: int = 1,
Expand Down Expand Up @@ -147,11 +155,17 @@ def train(
"somewhere in the data"
)

loss = F.smooth_l1_loss(pred, y_batch)
weights = self._make_weights(y_batch)

loss = F.smooth_l1_loss(pred, y_batch, reduce=False)
loss = (weights * loss).mean()

loss.backward()
optimizer.step()

with torch.no_grad():
# could optionally weight the rmse, but it makes things
# less interpretable and doesn't help with model learning
rmse = F.mse_loss(pred, y_batch)
train_rmse.append(math.sqrt(rmse.cpu().item()))

Expand Down