Skip to content

Commit

Permalink
Merge pull request #30 from Project-Resilience/jacob/lstm
Browse files Browse the repository at this point in the history
Add test for no change to ELUC
  • Loading branch information
ofrancon authored Sep 21, 2023
2 parents 828b567 + cd69c4f commit 77846a9
Showing 1 changed file with 73 additions and 7 deletions.
80 changes: 73 additions & 7 deletions model/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

# Simple LSTM model
class LSTMModel(torch.nn.Module, PyTorchModelHubMixin):
def __init__(self, layers=2, hidden_dim=128, input_dim=26, output_dim=1):
def __init__(self, layers=2, hidden_dim=128, input_dim=26, output_dim=1, **kwargs):
super().__init__()
self.lstm = torch.nn.LSTM(input_dim, hidden_dim, num_layers=layers, batch_first=True)
self.average = torch.nn.AdaptiveAvgPool1d(1)
Expand Down Expand Up @@ -149,14 +149,80 @@ def train_country_lstm(country_code):
num += 1
print(f"Test {country_code} MSE: {test_mse/num}")
print(f"Test {country_code} MAE: {test_mae/num}")
names = [f'{c}_' for c in country_code]
names = [f'_{c}' for c in country_code]
# Join the names together
name = ''.join(names)
torch.save(model, f"lstm_model_{name}.pt")
torch.save(model, f"lstm_model{name}.pt")
# Save the weights to huggingface
model.save_pretrained(f"project_resilience_lstm_model_{name}", push_to_hub=True, config=model_config)
model.save_pretrained(f"project_resilience_lstm_model{name}", push_to_hub=True, config=model_config)

# Train one joint model and then one model per individual country.
for country_codes in (
    [143, 29],  # UK and Brazil
    [143],  # UK
    [29],  # Brazil
):
    train_country_lstm(country_codes)
def test_no_change(country_code, model_id="jacobbieker/project_resilience_lstm_model_143_29_"):
    """Evaluate a pretrained LSTM on inputs where nothing changes over time.

    For every test time step and every 10x10 lat/lon patch, the features of
    a single time step (index 8 of the look-back window) are repeated 9 times
    so the model sees an input sequence with no change across the 9 steps.
    The prediction is compared against the ELUC value of that same time step.
    The mean of the per-batch MSE and MAE is printed.

    Args:
        country_code: Sequence of country codes to select from
            ``country_mask`` (e.g. ``[143, 29]``).
        model_id: Hugging Face Hub id of the pretrained weights to load.
            Defaults to the checkpoint this check was originally written
            against; note it is NOT derived from ``country_code``.

    Raises:
        ValueError: If the selected region/time range yields no batches.
    """
    patch = 10  # lat/lon patch edge length, in grid cells
    current_step = 8  # index within the look-back window treated as "now"
    repeats = 9  # length of the constant input sequence fed to the model

    if len(country_code) == 1:
        region = country_mask == country_code[0]
    else:
        # np.isin keeps the mask's shape, so no manual reshape is needed
        # (np.in1d is deprecated in recent NumPy releases).
        region = xr.DataArray(np.isin(country_mask, country_code),
                              dims=country_mask.dims, coords=country_mask.coords)
    # Only time steps after 2007 are used for evaluation.
    test_da = dataset.where(region, drop=True).where(dataset.time > 2007, drop=True).load()

    model = LSTMModel.from_pretrained(model_id).to(device)
    crit = torch.nn.MSELoss()
    stat = torch.nn.L1Loss()

    test_mse = 0.0
    test_mae = 0.0
    num = 0
    n_lon = len(test_da.lon.values)
    n_lat = len(test_da.lat.values)
    with torch.no_grad():
        model.eval()
        # Skip the first 11 time steps so every window has a full look-back.
        for end_time in test_da.time.values[11:]:
            for lon in range(0, n_lon, patch):
                for lat in range(0, n_lat, patch):
                    example = test_da.isel(
                        lat=slice(lat, lat + patch), lon=slice(lon, lon + patch)
                    ).sel(time=slice(end_time - 10, end_time))
                    # Stack the data variables into a trailing feature axis.
                    data_names = list(example.data_vars.keys())
                    example = xr.concat([example[var] for var in example.data_vars], dim='variable')
                    example = example.assign_coords(variable=data_names)
                    example = example.transpose('time', 'lat', 'lon', 'variable')

                    # Target is ELUC at the same step that is fed to the model.
                    target = example.sel(variable='ELUC').isel(time=current_step)
                    # Slice (not index) so the time axis is kept and the array stays 4D.
                    snapshot = example.isel(time=slice(current_step, current_step + 1))

                    # Repeat the single step so the input shows no change over time.
                    # nan_to_num zeroes NaN as well as +/-inf, so no separate fillna is needed.
                    frozen = np.repeat(snapshot.values, repeats, axis=0)
                    torch_example = torch.from_numpy(np.nan_to_num(frozen, posinf=0.0, neginf=0.0))
                    # Flatten to (batch_size, time, features)
                    torch_example = einops.rearrange(
                        torch_example, 'time lat lon features -> (lat lon) time features')
                    torch_target = torch.unsqueeze(
                        torch.from_numpy(np.nan_to_num(target.values, posinf=0.0, neginf=0.0)), dim=-1)
                    # Flatten to (batch_size, features)
                    torch_target = einops.rearrange(torch_target, 'lat lon features -> (lat lon) features')

                    outputs = model(torch_example.to(device)).cpu()
                    test_mse += crit(outputs, torch_target).item()
                    test_mae += stat(outputs, torch_target).item()
                    # Both losses are already means over the batch, so count
                    # batches only; also adding the cell count here would
                    # deflate the reported averages.
                    num += 1

    if num == 0:
        raise ValueError(f"No test batches available for country_code={country_code}")
    print(f"Test {country_code} MSE: {test_mse/num}")
    print(f"Test {country_code} MAE: {test_mae/num}")
#names = [f'_{c}' for c in country_code]
# Join the names together
#name = ''.join(names)
#torch.save(model, f"lstm_model{name}.pt")
# Save the weights to huggingface
#model.save_pretrained(f"project_resilience_lstm_model{name}", push_to_hub=True, config=model_config)


#train_country_lstm([143, 29]) # UK and Brazil
#train_country_lstm([143]) # UK
#train_country_lstm([29]) # Brazil
#train_country_lstm(list(range(80))) # First 80 countries
# Check ELUC when no change: run the no-change evaluation on the combined
# region of country codes 143 and 29 (labelled UK and Brazil elsewhere in this file).
test_no_change([143, 29])

0 comments on commit 77846a9

Please sign in to comment.