Skip to content

Commit

Permalink
Added checkDefaultOptions (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
ewu63 authored Jan 11, 2021
1 parent 4727e70 commit e36c1fe
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 35 deletions.
119 changes: 86 additions & 33 deletions baseclasses/BaseSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,67 @@ class BaseSolver(object):
"""

def __init__(
self, name, category, defaultOptions={}, options={}, immutableOptions=set(), deprecatedOptions={}, comm=None, informs={}
self,
name,
category,
defaultOptions={},
options={},
immutableOptions=set(),
deprecatedOptions={},
comm=None,
informs={},
checkDefaultOptions=True,
caseSensitiveOptions=False,
):
"""
Solver Class Initialization
Parameters
----------
name : str
The name of the solver
category : dict
The category of the solver
defaultOptions : dict, optional
The default options dictionary
options : dict, optional
The user-supplied options dictionary
immutableOptions : set of strings, optional
A set of immutable option names, which cannot be modified after solver creation.
deprecatedOptions : dict, optional
A dictionary containing deprecated option names, and a message to display if they were used.
comm : MPI Communicator, optional
The comm object to be used. If none, serial execution is assumed.
informs : dict, optional
A dictionary of exit code: exit message mappings.
checkDefaultOptions : bool, optional
A flag to specify whether the default options should be used for error checking.
This is used in cases where the default options are not the complete set, which is common for external solvers.
In such cases, no error checking is done when calling ``setOption``, but the default options are still set as options
upon solver creation.
caseSensitiveOptions : bool, optional
A flag to specify whether the option names are case sensitive or insensitive.
"""

self.name = name
self.category = category
self.options = CaseInsensitiveDict()
self.defaultOptions = CaseInsensitiveDict(defaultOptions)
self.immutableOptions = CaseInsensitiveSet(immutableOptions)
self.deprecatedOptions = CaseInsensitiveDict(deprecatedOptions)
if not caseSensitiveOptions:
self.options = CaseInsensitiveDict()
self.defaultOptions = CaseInsensitiveDict(defaultOptions)
self.immutableOptions = CaseInsensitiveSet(immutableOptions)
self.deprecatedOptions = CaseInsensitiveDict(deprecatedOptions)
else:
self.options = {}
self.defaultOptions = defaultOptions
self.immutableOptions = immutableOptions
self.deprecatedOptions = deprecatedOptions
self.comm = comm
self.informs = informs
self.solverCreated = False
self.checkDefaultOptions = checkDefaultOptions

# Initialize Options
for key, (optionType, optionValue) in self.defaultOptions.items():

# Check if the default is given in a list of possible values
if isinstance(optionValue, list) and optionType is not list:
# Default is the first element of the list
Expand Down Expand Up @@ -67,39 +109,44 @@ def setOption(self, name, value):
"""
# Check if the option exists
try:
defaultType, defaultValue = self.defaultOptions[name]
except KeyError:
if name in self.deprecatedOptions:
raise Error(f"Option {name} is deprecated. {self.deprecatedOptions[name]}")
else:
raise Error(f"Option {name} is not a valid {self.name} option.")
if self.checkDefaultOptions:
try:
defaultType, defaultValue = self.defaultOptions[name]
except KeyError:
if name in self.deprecatedOptions:
raise Error(f"Option {name} is deprecated. {self.deprecatedOptions[name]}")
else:
raise Error(f"Option {name} is not a valid {self.name} option.")

# Make sure we are not trying to change an immutable option if
# we are not allowed to.
if self.solverCreated and name in self.immutableOptions:
raise Error(f"Option {name} cannot be modified after the solver is created.")

# If the default provides a list of acceptable values, check whether the value is valid
if isinstance(defaultValue, list) and defaultType is not list:
if value in defaultValue:
self.options[name] = value
if self.checkDefaultOptions:
# If the default provides a list of acceptable values, check whether the value is valid
if isinstance(defaultValue, list) and defaultType is not list:
if value in defaultValue:
self.options[name] = value
else:
raise Error(
f"Value for option {name} is not valid. "
+ f"Value must be one of {defaultValue} with data type {defaultType}. "
+ f"Received value is {value} with data type {type(value)}."
)
else:
raise Error(
f"Value for option {name} is not valid. "
+ f"Value must be one of {defaultValue} with data type {defaultType}. "
+ f"Received value is {value} with data type {type(value)}."
)
# If a list is not provided, check just the type
if isinstance(value, defaultType):
self.options[name] = value
else:
raise Error(
f"Datatype for option {name} is not valid. "
+ f"Expected data type {defaultType}. "
+ f"Received data type is {type(value)}."
)
else:
# If a list is not provided, check just the type
if isinstance(value, defaultType):
self.options[name] = value
else:
raise Error(
f"Datatype for option {name} is not valid. "
+ f"Expected data type {defaultType}. "
+ f"Received data type is {type(value)}."
)
# no error checking
self.options[name] = value

def getOption(self, name):
"""
Expand All @@ -116,8 +163,14 @@ def getOption(self, name):
Return the current value of the option.
"""

if name in self.defaultOptions:
return self.options[name]
if name in self.defaultOptions or not self.checkDefaultOptions:
if name in self.options:
return self.options[name]
else:
raise Error(
f"Option {name} was not found. "
+ "Because options checking has been disabled, make sure the option has been set first."
)
else:
raise Error(f"{name} is not a valid option name.")

Expand Down
2 changes: 1 addition & 1 deletion baseclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.3.0"
__version__ = "1.3.1"

from .pyAero_problem import AeroProblem
from .pyTransi_problem import TransiProblem
Expand Down
18 changes: 17 additions & 1 deletion tests/test_BaseSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class SOLVER(BaseSolver):
def __init__(self, name, options={}, comm=None):
def __init__(self, name, options={}, comm=None, checkDefaultOptions=True, caseSensitiveOptions=False):

"""Create an artificial class for testing"""

Expand Down Expand Up @@ -43,6 +43,8 @@ def __init__(self, name, options={}, comm=None):
deprecatedOptions=deprecatedOptions,
comm=comm,
informs=informs,
checkDefaultOptions=checkDefaultOptions,
caseSensitiveOptions=caseSensitiveOptions,
)


Expand Down Expand Up @@ -104,6 +106,20 @@ def test_options(self):
with self.assertRaises(Error):
solver.setOption("oldoption", 4) # test deprecatedOptions

def test_checkDefaultOptions(self):
# initialize solver
solver = SOLVER("test", checkDefaultOptions=False)
solver.setOption("newOption", 1)
self.assertEqual(solver.getOption("newOption"), 1)
with self.assertRaises(Error):
solver.getOption("nonexistant option") # test that this name should be rejected

def test_caseSensitive(self):
# initialize solver
solver = SOLVER("test", caseSensitiveOptions=True)
with self.assertRaises(Error):
solver.getOption("booloption") # test that this name should be rejected


class TestComm(unittest.TestCase):

Expand Down

0 comments on commit e36c1fe

Please sign in to comment.