Skip to content

Commit

Permalink
good progress
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloprf committed Dec 11, 2024
1 parent 042e654 commit 7156b7f
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 68 deletions.
4 changes: 2 additions & 2 deletions src/mitim_modules/portals/PORTALSmain.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def default_namelist(optimization_options, CGYROrun=False):
}

# Surrogate
optimization_options["surrogateOptions"]["selectSurrogate"] = partial(
PORTALStools.selectSurrogate, CGYROrun=CGYROrun
optimization_options["surrogateOptions"]["selectSurrogates"] = partial(
PORTALStools.selectSurrogates, CGYROrun=CGYROrun
)

optimization_options["surrogateOptions"]["ensure_within_bounds"] = True
Expand Down
54 changes: 34 additions & 20 deletions src/mitim_modules/portals/PORTALStools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,42 @@
from mitim_tools.misc_tools.LOGtools import printMsg as print
from IPython import embed

def selectSurrogate(output, surrogateOptions, CGYROrun=False):
def selectSurrogates(outputs, surrogateOptions, CGYROrun=False):
'''
This divides potentially different outputs into different
surrogates to be joined with ModelList
'''

# Find transition to Targets
for iTar, output in enumerate(outputs):
if output[2:5] == "Tar":
break

print(f'\t- Selecting surrogate options for "{output}" to be run')
# Find transition to last location of transport
for iTra, output in enumerate(outputs):
if output.split('_')[-1] == outputs[iTar-1].split('_')[-1]:
break

if output is not None:
# If it's a target, just linear
if output[2:5] == "Tar":
surrogateOptions["TypeMean"] = 1
surrogateOptions["TypeKernel"] = 2 # Constant kernel
# If it's not, stndard
else:
surrogateOptions["TypeMean"] = 2 # Linear in gradients, constant in rest
surrogateOptions["TypeKernel"] = 1 # RBF
# surrogateOptions['ExtraNoise'] = True
surrogateOptions_dict = {}

# Turbulent and Neoclassical

surrogateOptions_dict[iTra] = copy.deepcopy(surrogateOptions)
surrogateOptions_dict[iTra]["TypeMean"] = 2 # Linear in gradients, constant in rest
surrogateOptions_dict[iTra]["TypeKernel"] = 1 # RBF
# surrogateOptions_dict[len(output)]['ExtraNoise'] = True

return surrogateOptions
surrogateOptions_dict[iTar] = copy.deepcopy(surrogateOptions)
surrogateOptions_dict[iTar]["TypeMean"] = 2 # Linear in gradients, constant in rest
surrogateOptions_dict[iTar]["TypeKernel"] = 1 # RBF
# surrogateOptions_dict[len(output)]['ExtraNoise'] = True

# Targets (If it's a target, just linear)
surrogateOptions_dict[len(outputs)] = copy.deepcopy(surrogateOptions)
surrogateOptions_dict[len(outputs)]["TypeMean"] = 1
surrogateOptions_dict[len(outputs)]["TypeKernel"] = 2 # Constant kernel

return surrogateOptions_dict

def default_portals_transformation_variables(additional_params = []):
"""
Expand Down Expand Up @@ -121,18 +141,12 @@ def input_transformation_portals(Xorig, output, surrogate_parameters, surrogate_
"""

_, num = output.split("_")
index = powerstate.indexes_simulation[
int(num)
] # num=1 -> pos=1, so that it takes the second value in vectors
index = powerstate.indexes_simulation[int(num)] # num=1 -> pos=1, so that it takes the second value in vectors

xFit = torch.Tensor().to(X)
for ikey in surrogate_transformation_variables[output]:
xx = powerstate.plasma[ikey][: X.shape[0], index]
xFit = torch.cat((xFit, xx.unsqueeze(1)), dim=1).to(X)

#TO FIX
import torch.nn.functional as F
xFit = F.pad(xFit, (0, 3-xFit.shape[-1]))

parameters_combined = {"powerstate": powerstate}

Expand Down
61 changes: 37 additions & 24 deletions src/mitim_modules/portals/utils/PORTALSinit.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,33 +277,46 @@ def initializeProblem(
dictDVs[name] = [dvs_fixed[name][0], base_gradient, dvs_fixed[name][1]]

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Define output dictionaries
# Define output dictionaries (order is important, consistent with selectSurrogates)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

ofs, name_objectives = [], []
for ikey in dictCPs_base:
if ikey == "te":
var = "Qe"
elif ikey == "ti":
var = "Qi"
elif ikey == "ne":
var = "Ge"
elif ikey == "nZ":
var = "GZ"
elif ikey == "w0":
var = "Mt"

for i in range(len(portals_fun.MODELparameters["RhoLocations"])):
ofs.append(f"{var}Turb_{i+1}")
ofs.append(f"{var}Neo_{i+1}")

ofs.append(f"{var}Tar_{i+1}")

name_objectives.append(f"{var}Res_{i+1}")

if portals_fun.PORTALSparameters["surrogateForTurbExch"]:
for i in range(len(portals_fun.MODELparameters["RhoLocations"])):
ofs.append(f"PexchTurb_{i+1}")

# Turb and neo, ending with last location

for i in range(len(portals_fun.MODELparameters["RhoLocations"])):
for model in ["Turb", "Neo"]:
for ikey in dictCPs_base:
if ikey == "te": var = "Qe"
elif ikey == "ti": var = "Qi"
elif ikey == "ne": var = "Ge"
elif ikey == "nZ": var = "GZ"
elif ikey == "w0": var = "Mt"

ofs.append(f"{var}{model}_{i+1}")


if f"{var}Res_{i+1}" not in name_objectives:
name_objectives.append(f"{var}Res_{i+1}")

if portals_fun.PORTALSparameters["surrogateForTurbExch"]:
for i in range(len(portals_fun.MODELparameters["RhoLocations"])):
ofs.append(f"PexchTurb_{i+1}")

# Tar

for i in range(len(portals_fun.MODELparameters["RhoLocations"])):
model = "Tar"
for ikey in dictCPs_base:
if ikey == "te": var = "Qe"
elif ikey == "ti": var = "Qi"
elif ikey == "ne": var = "Ge"
elif ikey == "nZ": var = "GZ"
elif ikey == "w0": var = "Mt"

ofs.append(f"{var}{model}_{i+1}")



name_transformed_ofs = []
for of in ofs:
Expand Down
39 changes: 32 additions & 7 deletions src/mitim_tools/opt_tools/STEPtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ def fit_step(self, avoidPoints=None, fit_output_contains=None):

def _fit_multioutput_model(self):

surrogateOptions = self.surrogateOptions["selectSurrogate"]('AllMITIM', self.surrogateOptions)

self.GP["mo_model"] = SURROGATEtools.surrogate_model(
# Base model
self.GP["mo_model"] = SURROGATEtools.surrogate_model.simple(
self.x,
self.y,
self.yvar,
Expand All @@ -113,11 +112,37 @@ def _fit_multioutput_model(self):
outputs_transformed=self.outputs_transformed,
bounds=self.bounds,
dfT=self.dfT,
surrogateOptions=surrogateOptions,
surrogateOptions=self.surrogateOptions,
)

# -
surrogateOptions_dict = self.surrogateOptions["selectSurrogates"](self.outputs, self.surrogateOptions)

gps, j = [], 0
for i in surrogateOptions_dict:
gps.append(
SURROGATEtools.surrogate_model(
self.x,
self.y[:, j:i],
self.yvar[:, j:i],
self.surrogate_parameters,
outputs=self.outputs[j:i],
outputs_transformed=self.outputs_transformed[j:i],
bounds=self.bounds,
dfT=self.dfT,
surrogateOptions=surrogateOptions_dict[i],
)
)
j = i

# Fitting
self.GP["mo_model"].fit()
gpmodels = []
for gp in gps:
gp.fit()
gpmodels.append(gp.gpmodel)

# Joining
self.GP["mo_model"].gpmodel = BOTORCHtools.ModelListGP_MITIM(*gpmodels)

def _fit_individual_models(self, fit_output_contains=None):

Expand Down Expand Up @@ -146,8 +171,8 @@ def _fit_individual_models(self, fit_output_contains=None):
surrogateOptions = copy.deepcopy(self.surrogateOptions)

# Then, depending on application (e.g. targets in mitim are fitted differently)
if ("selectSurrogate" in surrogateOptions) and (surrogateOptions["selectSurrogate"] is not None):
surrogateOptions = surrogateOptions["selectSurrogate"](output_this, surrogateOptions)
if ("selectSurrogates" in surrogateOptions) and (surrogateOptions["selectSurrogates"] is not None):
surrogateOptions = surrogateOptions["selectSurrogates"](output_this, surrogateOptions)

# ---------------------------------------------------------------------------------------------------
# To avoid problems with fixed values (e.g. calibration terms that are fixed)
Expand Down
68 changes: 54 additions & 14 deletions src/mitim_tools/opt_tools/SURROGATEtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ class surrogate_model:
"""

def __init__(
self,
@classmethod
def simple(cls, *args, **kwargs):
# Create an instance of the class
instance = cls.__new__(cls)
# Initialize the parameters manually
instance._init_parameters(*args, **kwargs)
return instance

def _init_parameters(self,
Xor,
Yor,
Yvaror,
Expand All @@ -42,17 +49,6 @@ def __init__(
fileTraining=None,
seed = 0
):
"""
Note:
- noise is variance (square of standard deviation).
"""

torch.manual_seed(seed)

# --------------------------------------------------------------------
# Input parameters
# --------------------------------------------------------------------

self.avoidPoints = avoidPoints if avoidPoints is not None else []
self.outputs = outputs
self.outputs_transformed = outputs_transformed
Expand All @@ -75,6 +71,50 @@ def __init__(

self.losses = None


def __init__(
self,
Xor,
Yor,
Yvaror,
surrogate_parameters,
outputs=None,
outputs_transformed=None,
bounds=None,
avoidPoints=None,
dfT=None,
surrogateOptions={},
FixedValue=False,
fileTraining=None,
seed = 0
):
"""
Note:
- noise is variance (square of standard deviation).
"""

torch.manual_seed(seed)

# --------------------------------------------------------------------
# Input parameters
# --------------------------------------------------------------------

self._init_parameters(
Xor,
Yor,
Yvaror,
surrogate_parameters,
outputs=outputs,
outputs_transformed=outputs_transformed,
bounds=bounds,
avoidPoints=avoidPoints,
dfT=dfT,
surrogateOptions=surrogateOptions,
FixedValue=FixedValue,
fileTraining=fileTraining,
seed=seed
)

# Print options
print("\t- Surrogate options:")
for i in self.surrogateOptions:
Expand Down Expand Up @@ -367,7 +407,7 @@ def normalization_pass(self,input_transform, outcome_transform):
train_X_transformed = input_transform['tf1'](self.train_X)

# Concatenate the training data and the data from file
train_X_transformed = torch.cat((train_X_transformed, self.train_X_added), axis=-2)
#train_X_transformed = torch.cat((train_X_transformed, self.train_X_added), axis=-2)

# Get the normalization constants
_ = input_transform['tf2'](train_X_transformed)
Expand Down
2 changes: 1 addition & 1 deletion templates/main.namelist.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"surrogateOptions": {
"TypeKernel": 0,
"TypeMean": 0,
"selectSurrogate": null,
"selectSurrogates": null,
"FixedNoise": true,
"ExtraNoise": false,
"ConstrainNoise": -1e-3,
Expand Down

0 comments on commit 7156b7f

Please sign in to comment.