diff --git a/baseclasses/BaseSolver.py b/baseclasses/BaseSolver.py index 8c6a9e4..97cdab1 100644 --- a/baseclasses/BaseSolver.py +++ b/baseclasses/BaseSolver.py @@ -15,25 +15,67 @@ class BaseSolver(object): """ def __init__( - self, name, category, defaultOptions={}, options={}, immutableOptions=set(), deprecatedOptions={}, comm=None, informs={} + self, + name, + category, + defaultOptions={}, + options={}, + immutableOptions=set(), + deprecatedOptions={}, + comm=None, + informs={}, + checkDefaultOptions=True, + caseSensitiveOptions=False, ): """ Solver Class Initialization + + Parameters + ---------- + name : str + The name of the solver + category : dict + The category of the solver + defaultOptions : dict, optional + The default options dictionary + options : dict, optional + The user-supplied options dictionary + immutableOptions : set of strings, optional + A set of immutable option names, which cannot be modified after solver creation. + deprecatedOptions : dict, optional + A dictionary containing deprecated option names, and a message to display if they were used. + comm : MPI Communicator, optional + The comm object to be used. If none, serial execution is assumed. + informs : dict, optional + A dictionary of exit code: exit message mappings. + checkDefaultOptions : bool, optional + A flag to specify whether the default options should be used for error checking. + This is used in cases where the default options are not the complete set, which is common for external solvers. + In such cases, no error checking is done when calling ``setOption``, but the default options are still set as options + upon solver creation. + caseSensitiveOptions : bool, optional + A flag to specify whether the option names are case sensitive or insensitive. """ self.name = name self.category = category - self.options = CaseInsensitiveDict() - self.defaultOptions = CaseInsensitiveDict(defaultOptions) - self.immutableOptions = CaseInsensitiveSet(immutableOptions) - self.deprecatedOptions = CaseInsensitiveDict(deprecatedOptions) + if not caseSensitiveOptions: + self.options = CaseInsensitiveDict() + self.defaultOptions = CaseInsensitiveDict(defaultOptions) + self.immutableOptions = CaseInsensitiveSet(immutableOptions) + self.deprecatedOptions = CaseInsensitiveDict(deprecatedOptions) + else: + self.options = {} + self.defaultOptions = defaultOptions + self.immutableOptions = immutableOptions + self.deprecatedOptions = deprecatedOptions self.comm = comm self.informs = informs self.solverCreated = False + self.checkDefaultOptions = checkDefaultOptions # Initialize Options for key, (optionType, optionValue) in self.defaultOptions.items(): - # Check if the default is given in a list of possible values if isinstance(optionValue, list) and optionType is not list: # Default is the first element of the list @@ -67,39 +109,44 @@ def setOption(self, name, value): """ # Check if the option exists - try: - defaultType, defaultValue = self.defaultOptions[name] - except KeyError: - if name in self.deprecatedOptions: - raise Error(f"Option {name} is deprecated. {self.deprecatedOptions[name]}") - else: - raise Error(f"Option {name} is not a valid {self.name} option.") + if self.checkDefaultOptions: + try: + defaultType, defaultValue = self.defaultOptions[name] + except KeyError: + if name in self.deprecatedOptions: + raise Error(f"Option {name} is deprecated. {self.deprecatedOptions[name]}") + else: + raise Error(f"Option {name} is not a valid {self.name} option.") # Make sure we are not trying to change an immutable option if # we are not allowed to. if self.solverCreated and name in self.immutableOptions: raise Error(f"Option {name} cannot be modified after the solver is created.") - # If the default provides a list of acceptable values, check whether the value is valid - if isinstance(defaultValue, list) and defaultType is not list: - if value in defaultValue: - self.options[name] = value + if self.checkDefaultOptions: + # If the default provides a list of acceptable values, check whether the value is valid + if isinstance(defaultValue, list) and defaultType is not list: + if value in defaultValue: + self.options[name] = value + else: + raise Error( + f"Value for option {name} is not valid. " + + f"Value must be one of {defaultValue} with data type {defaultType}. " + + f"Received value is {value} with data type {type(value)}." + ) else: - raise Error( - f"Value for option {name} is not valid. " - + f"Value must be one of {defaultValue} with data type {defaultType}. " - + f"Received value is {value} with data type {type(value)}." - ) + # If a list is not provided, check just the type + if isinstance(value, defaultType): + self.options[name] = value + else: + raise Error( + f"Datatype for option {name} is not valid. " + + f"Expected data type {defaultType}. " + + f"Received data type is {type(value)}." + ) else: - # If a list is not provided, check just the type - if isinstance(value, defaultType): - self.options[name] = value - else: - raise Error( - f"Datatype for option {name} is not valid. " - + f"Expected data type {defaultType}. " - + f"Received data type is {type(value)}." - ) + # no error checking + self.options[name] = value def getOption(self, name): """ @@ -116,8 +163,14 @@ def getOption(self, name): Return the current value of the option. """ - if name in self.defaultOptions: - return self.options[name] + if name in self.defaultOptions or not self.checkDefaultOptions: + if name in self.options: + return self.options[name] + else: + raise Error( + f"Option {name} was not found. " + + "Because options checking has been disabled, make sure the option has been set first." + ) else: raise Error(f"{name} is not a valid option name.") diff --git a/baseclasses/__init__.py b/baseclasses/__init__.py index 9e0b32a..be64a46 100644 --- a/baseclasses/__init__.py +++ b/baseclasses/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.0" +__version__ = "1.3.1" from .pyAero_problem import AeroProblem from .pyTransi_problem import TransiProblem diff --git a/tests/test_BaseSolver.py b/tests/test_BaseSolver.py index 3a6e67c..d23b37f 100644 --- a/tests/test_BaseSolver.py +++ b/tests/test_BaseSolver.py @@ -9,7 +9,7 @@ class SOLVER(BaseSolver): - def __init__(self, name, options={}, comm=None): + def __init__(self, name, options={}, comm=None, checkDefaultOptions=True, caseSensitiveOptions=False): """Create an artificial class for testing""" @@ -43,6 +43,8 @@ def __init__(self, name, options={}, comm=None): deprecatedOptions=deprecatedOptions, comm=comm, informs=informs, + checkDefaultOptions=checkDefaultOptions, + caseSensitiveOptions=caseSensitiveOptions, ) @@ -104,6 +106,20 @@ def test_options(self): with self.assertRaises(Error): solver.setOption("oldoption", 4) # test deprecatedOptions + def test_checkDefaultOptions(self): + # initialize solver + solver = SOLVER("test", checkDefaultOptions=False) + solver.setOption("newOption", 1) + self.assertEqual(solver.getOption("newOption"), 1) + with self.assertRaises(Error): + solver.getOption("nonexistant option") # test that this name should be rejected + + def test_caseSensitive(self): + # initialize solver + solver = SOLVER("test", caseSensitiveOptions=True) + with self.assertRaises(Error): + solver.getOption("booloption") # test that this name should be rejected + class TestComm(unittest.TestCase):