
Commit bbc6593: good progress
pabloprf committed Nov 22, 2024
1 parent e8ee02e
Showing 4 changed files with 88 additions and 144 deletions.
85 changes: 24 additions & 61 deletions src/mitim_tools/opt_tools/BOTORCHtools.py
@@ -15,19 +15,14 @@
# ----------------------------------------------------------------------------------------------------------------------------

from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils import validate_input_scaling
from botorch.utils.types import DEFAULT
from gpytorch.models.exact_gp import ExactGP
from torch import Tensor

from linear_operator.operators import CholLinearOperator, DiagLinearOperator

from typing import Iterable
from torch.nn import ModuleDict
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from gpytorch.distributions import MultitaskMultivariateNormal
from linear_operator.operators import BlockDiagLinearOperator
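These two new imports are the standard ingredients for building an independent multi-output posterior, as done in the modified posterior method later in this file. A minimal sketch (not code from this commit) of how they combine:

    import torch
    from gpytorch.distributions import MultitaskMultivariateNormal
    from linear_operator import to_linear_operator
    from linear_operator.operators import BlockDiagLinearOperator

    n, t = 5, 3                                    # n test points, t outputs
    mean = torch.zeros(n, t)                       # n x t matrix of means
    blocks = to_linear_operator(torch.eye(n).expand(t, n, n))   # t independent n x n covariances
    covar = BlockDiagLinearOperator(blocks)        # (t*n) x (t*n) block-diagonal covariance
    # interleaved=False orders entries task-major, matching the block layout above
    mvn = MultitaskMultivariateNormal(mean, covar, interleaved=False)
    print(mvn.event_shape)                         # torch.Size([5, 3])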


@@ -59,16 +54,14 @@ def __init__(
f"\t\t\t- FixedNoise: {FixedNoise} (extra noise: {learn_additional_noise}), TypeMean: {TypeMean}, TypeKernel: {TypeKernel}, ConstrainNoise: {ConstrainNoise:.1e}"
)

self.store_training(
train_X,
train_X_added,
train_Y,
train_Y_added,
train_Yvar,
train_Yvar_added,
input_transform,
outcome_transform,
)
# ** Store training data

# x, y are raw untransformed, and I want raw transformed; xa, ya are already raw transformed
#x_tr = input_transform["tf1"](train_X) if input_transform is not None else train_X
#y_tr, yv_tr = (outcome_transform["tf1"](train_X, train_Y, train_Yvar) if outcome_transform is not None else (train_Y, train_Yvar))
#self.train_X_usedToTrain = torch.cat((train_X_added, x_tr), axis=-2)
#self.train_Y_usedToTrain = torch.cat((train_Y_added, y_tr), axis=-2)
#self.train_Yvar_usedToTrain = torch.cat((train_Yvar_added, yv_tr), axis=-2)

# Grab num_outputs
self._num_outputs = train_Y.shape[-1]
@@ -91,40 +84,29 @@ def __init__(
# Added points are raw transformed, so I need to normalize them
if train_X_added.shape[0] > 0:
train_X_added = input_transform["tf2"](train_X_added)
train_Y_added, train_Yvar_added = outcome_transform["tf2"](
train_Y_added, train_Yvar_added
)
# -----

train_X_usedToTrain = torch.cat((transformed_X, train_X_added), axis=0)
train_Y_usedToTrain = torch.cat((train_Y, train_Y_added), axis=0)
train_Yvar_usedToTrain = torch.cat((train_Yvar, train_Yvar_added), axis=0)

self._input_batch_shape, self._aug_batch_shape = self.get_batch_dimensions(
train_X=train_X_usedToTrain, train_Y=train_Y_usedToTrain
)

train_Y_usedToTrain = train_Y_usedToTrain.squeeze(-1)
train_Yvar_usedToTrain = train_Yvar_usedToTrain.squeeze(-1)

#self._aug_batch_shape = train_Y.shape[:-2] #<----- New
train_Y_added = outcome_transform["tf3"].untransform(train_Y_added)[0]
train_Yvar_added = outcome_transform["tf3"].untransform(train_Yvar_added)[0]
train_Y_added, train_Yvar_added = outcome_transform["tf3"](*outcome_transform["tf2"](train_Y_added, train_Yvar_added))

# -----

train_X_usedToTrain = torch.cat((transformed_X, train_X_added), axis=-2)
train_Y_usedToTrain = torch.cat((train_Y, train_Y_added), axis=-2)
train_Yvar_usedToTrain = torch.cat((train_Yvar, train_Yvar_added), axis=-2)

# Validate again after applying the transforms
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
self._validate_tensor_args(X=train_X_usedToTrain, Y=train_Y_usedToTrain, Yvar=train_Yvar_usedToTrain)
ignore_X_dims = getattr(self, "_ignore_X_dims_scaling_check", None)
validate_input_scaling(
train_X=transformed_X,
train_Y=train_Y,
train_Yvar=train_Yvar,
train_X=train_X_usedToTrain,
train_Y=train_Y_usedToTrain,
train_Yvar=train_Yvar_usedToTrain,
ignore_X_dims=ignore_X_dims,
)
self._set_dimensions(train_X=train_X, train_Y=train_Y)
self._set_dimensions(train_X=train_X_usedToTrain, train_Y=train_Y_usedToTrain)


train_X, train_Y, train_Yvar = self._transform_tensor_args(
X=train_X, Y=train_Y, Yvar=train_Yvar
train_X_usedToTrain, train_Y_usedToTrain, train_Yvar_usedToTrain = self._transform_tensor_args(
X=train_X_usedToTrain, Y=train_Y_usedToTrain, Yvar=train_Yvar_usedToTrain
)
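The recurring change in this hunk is concatenating the file-added points along dim -2 (the points dimension) instead of dim 0. A small self-contained illustration of why, with toy shapes:

    import torch

    # Toy shapes: a batch of t=3 single-output models, n=10 training points, d=4 features.
    t, n, d = 3, 10, 4
    transformed_X = torch.rand(t, n, d)    # physics-transformed training inputs
    train_X_added = torch.rand(t, 2, d)    # extra points loaded from file

    # dim=-2 targets the points dimension for batched (t x n x d) and
    # unbatched (n x d) tensors alike; dim=0 would try to concatenate along
    # the model-batch dimension and fail on mismatched point counts.
    X_used = torch.cat((transformed_X, train_X_added), dim=-2)
    print(X_used.shape)                    # torch.Size([3, 12, 4])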

"""
@@ -271,26 +253,6 @@ def __init__(
self.input_transform = input_transform
self.to(train_X)

def store_training(self, x, xa, y, ya, yv, yva, input_transform, outcome_transform):

# x, y are raw untransformed, and I want raw transformed
if input_transform is not None:
x_tr = input_transform["tf1"](x)
else:
x_tr = x
if outcome_transform is not None:
y_tr, yv_tr = outcome_transform["tf1"](x, y, yv)
else:
y_tr, yv_tr = y, yv

# xa, ya are raw transformed
xa_tr = xa
ya_tr, yva_tr = ya, yva

self.train_X_usedToTrain = torch.cat((xa_tr, x_tr), axis=0)
self.train_Y_usedToTrain = torch.cat((ya_tr, y_tr), axis=0)
self.train_Yvar_usedToTrain = torch.cat((yva_tr, yv_tr), axis=0)

# Modify posterior call from BatchedMultiOutputGPyTorchModel to call posterior untransform with "X"
def posterior(
self,
@@ -559,6 +521,7 @@ def untransform_posterior(self, X, posterior):
if i == len(self.values())-1
else tf.untransform_posterior(posterior)
) # Only physics transformation (tf1) takes X


return posterior
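A condensed sketch of the untransform chain shown above, assuming (as in this repo's custom transforms) that only the physics transform tf1 takes an untransform_posterior(X, posterior) signature; stock BoTorch outcome transforms take only the posterior:

    def untransform_posterior_chain(transforms, X, posterior):
        # transforms: ordered dict of outcome transforms {tf1, tf2, ...};
        # undo them in reverse order, passing X only to the last one visited (tf1).
        tfs = list(transforms.values())
        for i, tf in enumerate(reversed(tfs)):
            if i == len(tfs) - 1:
                posterior = tf.untransform_posterior(X, posterior)   # physics transform takes X
            else:
                posterior = tf.untransform_posterior(posterior)
        return posterior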

6 changes: 3 additions & 3 deletions src/mitim_tools/opt_tools/STEPtools.py
@@ -101,9 +101,9 @@ def fit_step(self, avoidPoints=None, fit_output_contains=None):

time1 = datetime.datetime.now()

#self._fit_multioutput_model()
self._fit_individual_models(fit_output_contains=fit_output_contains)
self._fit_multioutput_model()
self.GP["combined_model"] = self.GP["mo_model"]
#self._fit_individual_models(fit_output_contains=fit_output_contains)

txt_time = IOtools.getTimeDifference(time1)
print(f"--> Fitting of all models took {txt_time}")
if self.fileOutputs is not None:
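The switch above trades per-output models for a single multi-output model. A hedged sketch of the two standard BoTorch constructions this toggles between (toy data; _fit_multioutput_model itself is MITIM code not shown here):

    import torch
    from botorch.models import SingleTaskGP, ModelListGP

    train_X = torch.rand(20, 3, dtype=torch.double)
    train_Y = torch.rand(20, 2, dtype=torch.double)     # two outputs

    # Option A (old path): one independent GP per output, wrapped together
    models = [SingleTaskGP(train_X, train_Y[:, [i]]) for i in range(train_Y.shape[-1])]
    model_list = ModelListGP(*models)

    # Option B (new path): one SingleTaskGP over all outputs; BoTorch batches
    # the outputs internally into a batched single-output model
    mo_model = SingleTaskGP(train_X, train_Y)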
118 changes: 51 additions & 67 deletions src/mitim_tools/opt_tools/SURROGATEtools.py
@@ -163,10 +163,17 @@ def __init__(
dimTransformedDV_x, dimTransformedDV_y = self._define_MITIM_transformations()
# ------------------------------------------------------------------------------------------------------------

self.train_X_added_full = torch.empty((0, dimTransformedDV_x_full)).to(self.dfT)
self.train_X_added = torch.empty((0, dimTransformedDV_x)).to(self.dfT)
self.train_Y_added = torch.empty((0, dimTransformedDV_y)).to(self.dfT)
self.train_Yvar_added = torch.empty((0, dimTransformedDV_y)).to(self.dfT)
x_transformed = input_transform_physics(self.train_X)
shape = list(x_transformed.shape)
shape[-2] = 0
shape[-1] = dimTransformedDV_x_full

self.train_X_added_full = torch.empty(*shape).to(self.dfT)
shape[-1] = dimTransformedDV_x
self.train_X_added = torch.empty(*shape).to(self.dfT)
shape[-1] = 1
self.train_Y_added = torch.empty(*shape).to(self.dfT)
self.train_Yvar_added = torch.empty(*shape).to(self.dfT)
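The empty buffers now inherit any leading batch dimensions from a transformed sample instead of being fixed 2-D. A toy run of the same shape logic:

    import torch

    x_transformed = torch.rand(3, 10, 7)   # e.g. t=3 output batch, n=10 points, 7 features
    dimTransformedDV_x_full = 7

    shape = list(x_transformed.shape)
    shape[-2] = 0                          # zero points, keep batch dims
    shape[-1] = dimTransformedDV_x_full
    train_X_added_full = torch.empty(*shape)
    print(train_X_added_full.shape)        # torch.Size([3, 0, 7])

    # Later concatenation along the points dimension then needs no special cases:
    print(torch.cat((x_transformed, train_X_added_full), dim=-2).shape)   # torch.Size([3, 10, 7])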

# --------------------------------------------------------------------------------------
# Make sure that very small variations are not captured
@@ -202,17 +209,7 @@ def __init__(
if (self.fileTraining is not None) and (self.train_X.shape[0] + self.train_X_added.shape[0] > 0):
self.write_datafile(input_transform_physics, outcome_transform_physics)

# -------------------------------------------------------------------------------------
# Obtain normalization constants now (although during training this is messed up, so needed later too)
# -------------------------------------------------------------------------------------

self.normalization_pass(
input_transform_physics,
input_transform_normalization,
outcome_transform_physics,
output_transformed_standardization,
)

# ------------------------------------------------------------------------------------
# Combine transformations in chain of PHYSICS + NORMALIZATION
# ------------------------------------------------------------------------------------
@@ -222,9 +219,15 @@
).to(self.dfT)

outcome_transform = BOTORCHtools.ChainedOutcomeTransform(
tf1=outcome_transform_physics, tf2=output_transformed_standardization #, tf3=BOTORCHtools.OutcomeToBatchDimension()
tf1=outcome_transform_physics, tf2=output_transformed_standardization, tf3=BOTORCHtools.OutcomeToBatchDimension()
).to(self.dfT)
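For reference, the same chaining pattern with stock BoTorch transforms; the commit uses MITIM's own ChainedInputTransform/ChainedOutcomeTransform subclasses from BOTORCHtools, which add the physics step (tf1) and the batching step (tf3=OutcomeToBatchDimension) on top of this:

    import torch
    from botorch.models.transforms.input import ChainedInputTransform, Normalize
    from botorch.models.transforms.outcome import ChainedOutcomeTransform, Standardize

    d = 3
    input_transform = ChainedInputTransform(tf2=Normalize(d=d))          # unit-cube normalization
    outcome_transform = ChainedOutcomeTransform(tf2=Standardize(m=1))    # zero-mean/unit-variance

    X = torch.rand(10, d, dtype=torch.double)
    Y = torch.rand(10, 1, dtype=torch.double)
    X_norm = input_transform(X)
    Y_std, Yvar_std = outcome_transform(Y)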

# -------------------------------------------------------------------------------------
# Obtain normalization constants now (although during training this is messed up, so needed later too)
# -------------------------------------------------------------------------------------

self.normalization_pass(input_transform, outcome_transform)

self.variables = (
self.surrogate_transformation_variables[self.outputs[0]]
if (
@@ -305,7 +308,7 @@ def _define_MITIM_transformations(self):
# Broadcast the input transformation to all outputs
# ------------------------------------------------------------------------------------

input_transformation_physics = input_transformations_physics[0] #BOTORCHtools.BatchBroadcastedInputTransform(input_transformations_physics)
input_transformation_physics = BOTORCHtools.BatchBroadcastedInputTransform(input_transformations_physics)

transformed_X = input_transformation_physics(self.train_X)
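BatchBroadcastedInputTransform is a MITIM class; conceptually it applies one physics transform per output and stacks the results along a leading batch dimension, so a single batched GP sees per-output features. A minimal stand-in for that idea (not the class's actual code):

    import torch

    def broadcast_transforms(transforms, X):
        # transforms: one callable per output; X: n x d -> t x n x d
        return torch.stack([tf(X) for tf in transforms], dim=0)

    tfs = [lambda X: X, lambda X: X**2, lambda X: torch.log1p(X)]
    X = torch.rand(10, 4)
    print(broadcast_transforms(tfs, X).shape)   # torch.Size([3, 10, 4])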

@@ -331,41 +334,41 @@
output_transformed_standardization, \
dimTransformedDV_x, dimTransformedDV_y

def normalization_pass(
self,
input_transform_physics,
input_transform_normalization,
outcome_transform_physics,
outcome_transform_normalization,
):
input_transform_normalization.training = True
outcome_transform_normalization.training = True
outcome_transform_normalization._is_trained = torch.tensor(False)
def normalization_pass(self, input_transform, outcome_transform):
'''
The goal of this is to capture NOW the normalization and standardization constants,
accounting for both the actual training data and the added data from file
'''

train_X_transformed = input_transform_physics(self.train_X)
train_Y_transformed, train_Yvar_transformed = outcome_transform_physics(self.train_X, self.train_Y, self.train_Yvar)
# Get input normalization and outcome standardization in training mode
input_transform['tf2'].training = True
outcome_transform['tf2'].training = True
outcome_transform['tf2']._is_trained = torch.tensor(False)

train_X_transformed = torch.cat(
(input_transform_physics(self.train_X), self.train_X_added), axis=0
)
y, yvar = outcome_transform_physics(self.train_X, self.train_Y, self.train_Yvar)
train_Y_transformed = torch.cat((y, self.train_Y_added), axis=0)
train_Yvar_transformed = torch.cat((yvar, self.train_Yvar_added), axis=0)
# Get the input normalization constants by physics-transforming the train_x and adding the data from file
train_X_transformed = input_transform['tf1'](self.train_X)
train_X_transformed = torch.cat((train_X_transformed, self.train_X_added), axis=-2)
_ = input_transform['tf2'](train_X_transformed)

train_X_transformed_norm = input_transform_normalization(train_X_transformed)
(
train_Y_transformed_norm,
train_Yvar_transformed_norm,
) = outcome_transform_normalization(train_Y_transformed, train_Yvar_transformed)
# Get the outcome standardization constants by physics-transforming the train_y and adding the data from file,
# with the caveat that the added points must not be batched
train_Y_transformed, train_Yvar_transformed = outcome_transform['tf1'](self.train_X, self.train_Y, self.train_Yvar)
y, yvar = outcome_transform['tf1'](self.train_X, self.train_Y, self.train_Yvar)

train_Y_transformed = torch.cat((y, outcome_transform['tf3'].untransform(self.train_Y_added)[0]), axis=-2)
train_Yvar_transformed = torch.cat((yvar, outcome_transform['tf3'].untransform(self.train_Yvar_added)[0]), axis=-2)

train_Y_transformed_norm, train_Yvar_transformed_norm = outcome_transform['tf2'](train_Y_transformed, train_Yvar_transformed)

# Make sure they are not on training mode
input_transform_normalization.training = False
outcome_transform_normalization.training = False
outcome_transform_normalization._is_trained = torch.tensor(True)
input_transform['tf2'].training = False
outcome_transform['tf2'].training = False
outcome_transform['tf2']._is_trained = torch.tensor(True)
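The flag flipping above works because BoTorch transforms only update their statistics in training mode. A self-contained illustration with a stock Standardize:

    import torch
    from botorch.models.transforms.outcome import Standardize

    standardize = Standardize(m=1)
    Y = torch.randn(12, 1, dtype=torch.double)

    standardize.training = True
    standardize._is_trained = torch.tensor(False)
    _ = standardize(Y)                   # recomputes means/stdvs from Y
    standardize.training = False
    standardize._is_trained = torch.tensor(True)

    print(standardize.means, standardize.stdvs)   # frozen constants used at prediction time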


def fit(self):
print(
f"\t- Fitting model to {self.train_X.shape[0]+self.train_X_added.shape[0]} points"
f"\t- Fitting model to {self.train_X.shape[-2]+self.train_X_added.shape[-2]} points"
)

# ---------------------------------------------------------------------------------------------------
@@ -398,8 +401,6 @@ def fit(self):
with fundamental_model_context(self):
track_fval = self.perform_model_fit(mll)

embed()

# ---------------------------------------------------------------------------------------------------
# Asses optimization
# ---------------------------------------------------------------------------------------------------
@@ -409,12 +410,7 @@
# Go back to definining the right normalizations, because the optimizer has to work on training mode...
# ---------------------------------------------------------------------------------------------------

self.normalization_pass(
self.gpmodel.input_transform["tf1"],
self.gpmodel.input_transform["tf2"],
self.gpmodel.outcome_transform["tf1"],
self.gpmodel.outcome_transform["tf2"],
)
self.normalization_pass(self.gpmodel.input_transform, self.gpmodel.outcome_transform)

def perform_model_fit(self, mll):
self.gpmodel.train()
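For context, a minimal BoTorch version of the fit step wrapped by perform_model_fit (toy data; the MITIM version also records the objective trace it returns as track_fval):

    import torch
    from botorch.models import SingleTaskGP
    from botorch.fit import fit_gpytorch_mll
    from gpytorch.mlls import ExactMarginalLogLikelihood

    train_X = torch.rand(15, 3, dtype=torch.double)
    train_Y = torch.rand(15, 1, dtype=torch.double)

    gpmodel = SingleTaskGP(train_X, train_Y)
    mll = ExactMarginalLogLikelihood(gpmodel.likelihood, gpmodel)

    gpmodel.train()         # training mode before optimizing hyperparameters
    fit_gpytorch_mll(mll)   # maximize the marginal log-likelihood
    gpmodel.eval()          # evaluation mode for posterior predictions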
@@ -903,29 +899,17 @@ def __init__(self, surrogate_model):

def __enter__(self):
# Works for individual models, not ModelList
self.surrogate_model.gpmodel.input_transform.tf1.flag_to_evaluate = False
for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = False
self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = False

return self.surrogate_model

def __exit__(self, *args):
self.surrogate_model.gpmodel.input_transform.tf1.flag_to_evaluate = True
for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = True
self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = True
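The generic shape of this context manager, for reference: flip an evaluation flag on every wrapped transform on entry and restore it on exit, exceptions included. flag_to_evaluate is the MITIM attribute; the class body here is an illustrative sketch, not the repo's code:

    class flag_context:
        def __init__(self, transforms):
            self.transforms = transforms      # objects exposing .flag_to_evaluate

        def __enter__(self):
            for tf in self.transforms:
                tf.flag_to_evaluate = False   # raw mode while fitting
            return self

        def __exit__(self, *exc):
            for tf in self.transforms:
                tf.flag_to_evaluate = True    # restore evaluation behavior
            return False                      # do not swallow exceptions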

# def __enter__(self):
# # Works for individual models, not ModelList
# embed()
# for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
# self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = False
# self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = False

# return self.surrogate_model

# def __exit__(self, *args):
# for i in range(len(self.surrogate_model.gpmodel.input_transform.tf1.transforms)):
# self.surrogate_model.gpmodel.input_transform.tf1.transforms[i].flag_to_evaluate = True
# self.surrogate_model.gpmodel.outcome_transform.tf1.flag_to_evaluate = True

def create_df_portals(x, y, yvar, x_names, output, max_x=20):

new_data = []
23 changes: 10 additions & 13 deletions src/mitim_tools/opt_tools/optimizers/BOTORCHoptim.py
@@ -36,7 +36,6 @@ def findOptima(fun, optimization_params = {}, writeTrajectory=False):
"sample_around_best": True,
"disp": 50 if read_verbose_level() == 5 else False,
"seed": fun.seed,
"maxiter": 100,
}

"""
@@ -64,18 +63,16 @@ def __call__(self, x, *args, **kwargs):
seq_message = f'({"sequential" if sequential_q else "joint"}) ' if q>1 else ''
print(f"\t\t- Optimizing using optimize_acqf: {q = } {seq_message}, {num_restarts = }, {raw_samples = }")


#with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
#with IOtools.speeder("/Users/pablorf/PROJECTS/project_2024_PORTALSdevelopment/speed/profiler_opt.prof") as s:
x_opt, _ = botorch.optim.optimize_acqf(
acq_function=fun_opt,
bounds=fun.bounds_mod,
raw_samples=raw_samples,
q=q,
sequential=sequential_q,
num_restarts=num_restarts,
options=options,
)
with IOtools.timer(name = "\n\t- Optimization", name_timer = '\t\t- Time: '):
x_opt, _ = botorch.optim.optimize_acqf(
acq_function=fun_opt,
bounds=fun.bounds_mod,
raw_samples=raw_samples,
q=q,
sequential=sequential_q,
num_restarts=num_restarts,
options=options,
)
embed()

acq_evaluated = torch.Tensor(acq_evaluated)
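A self-contained example of the optimize_acqf call pattern above, with a toy model and acquisition; only the keyword arguments mirror the diff, and the names are illustrative:

    import torch
    from botorch.models import SingleTaskGP
    from botorch.acquisition.analytic import LogExpectedImprovement
    from botorch.optim import optimize_acqf

    train_X = torch.rand(10, 2, dtype=torch.double)
    train_Y = torch.rand(10, 1, dtype=torch.double)
    model = SingleTaskGP(train_X, train_Y)
    acqf = LogExpectedImprovement(model, best_f=train_Y.max())

    bounds = torch.stack([torch.zeros(2), torch.ones(2)]).to(train_X)
    x_opt, acq_value = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=1,
        num_restarts=8,
        raw_samples=64,
        options={"sample_around_best": True, "seed": 0},
    )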
