diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a28302bc3..ea6658ebf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ # How to contribute -The ARMI framework project strongly encourages developers to help contribute to and build the code. +The ARMI framework project strongly encourages developers to help contribute to the codebase. The ARMI framework code is open source, and your contributions will become open source. Although fewer laws apply to open source materials because they are publicly-available, you still @@ -13,16 +13,16 @@ There are a lot of things we need help with right off the bat, to get your feet * Many more type annotations are desired. Type issues cause lots of bugs. * Fewer Pylint warnings * Better documentation -* Additional relevance to thermal reactors * Better test coverage * Targeted speedups (e.g. informed by a profiler) +* Additional relevance to thermal reactors Naturally, we encourage other kinds of contributions as well. ## Testing Any contribution must pass all included unit tests. The tests are built and run with the -pytest system. Please add new tests if you add new functionality. You can generally just run +`pytest` system. Please add new tests if you add new functionality. You can generally just run `tox` to build the testing environment and execute all the tests and pylint checks. ## Submitting changes diff --git a/armi/apps.py b/armi/apps.py index 7fb89e07d..10fa63973 100644 --- a/armi/apps.py +++ b/armi/apps.py @@ -32,9 +32,9 @@ import collections from armi import plugins, pluginManager, meta, settings +from armi.reactor import parameters from armi.settings import Setting from armi.settings import fwSettings -from armi.reactor import parameters class App: diff --git a/armi/bookkeeping/report/reportingUtils.py b/armi/bookkeeping/report/reportingUtils.py index 90bec0714..aa4af6e53 100644 --- a/armi/bookkeeping/report/reportingUtils.py +++ b/armi/bookkeeping/report/reportingUtils.py @@ -30,11 +30,12 @@ import armi from armi import runLog -from armi import utils +from armi.utils import getFileSHA1Hash from armi.utils import iterables from armi.utils import units from armi.utils import textProcessors from armi.utils import plotting +from armi.utils.mathematics import findClosest from armi import interfaces from armi.bookkeeping import report from armi.reactor.flags import Flags @@ -131,7 +132,7 @@ def _listInputFiles(cs): shaHash = ( "MISSING" if (not fName or not os.path.exists(fName)) - else utils.getFileSHA1Hash(fName, digits=10) + else getFileSHA1Hash(fName, digits=10) ) inputInfo.append((label, fName, shaHash)) @@ -649,9 +650,8 @@ def summarizeZones(core, cs): peakAssem = highPow[peakIndex] avgPFrac = sum(pFracList) / len(pFracList) # true mean power fraction - _avgAssemPFrac, avgIndex = utils.findClosest( - pFracList, avgPFrac, indx=True - ) # the closest-to-average pfrac in the list + # the closest-to-average pfrac in the list + _avgAssemPFrac, avgIndex = findClosest(pFracList, avgPFrac, indx=True) avgAssem = highPow[avgIndex] # the actual average assembly # ok, now need counts, and peak and avg. flow and power in high power region. diff --git a/armi/conftest.py b/armi/conftest.py index 5e06a5ece..b3de51a33 100644 --- a/armi/conftest.py +++ b/armi/conftest.py @@ -34,7 +34,6 @@ def pytest_sessionstart(session): - print("Initializing generic ARMI Framework application") configure(apps.App()) bootstrapArmiTestEnv() diff --git a/armi/materials/lithium.py b/armi/materials/lithium.py index 66ec33a77..b8d3c3326 100644 --- a/armi/materials/lithium.py +++ b/armi/materials/lithium.py @@ -20,9 +20,9 @@ """ from armi import runLog -from armi import utils from armi.materials import material from armi.nucDirectory import nuclideBases as nb +from armi.utils.mathematics import getFloat class Lithium(material.Fluid): @@ -46,7 +46,7 @@ def applyInputParams(self, LI_wt_frac=None, LI6_wt_frac=None, *args, **kwargs): LI6_wt_frac = LI6_wt_frac or LI_wt_frac - enrich = utils.getFloat(LI6_wt_frac) + enrich = getFloat(LI6_wt_frac) # allow 0.0 to pass in! if enrich is not None: self.adjustMassEnrichment(LI6_wt_frac) diff --git a/armi/materials/sulfur.py b/armi/materials/sulfur.py index 568b34081..9784e4841 100644 --- a/armi/materials/sulfur.py +++ b/armi/materials/sulfur.py @@ -17,7 +17,7 @@ """ from armi import runLog -from armi import utils +from armi.utils.mathematics import linearInterpolation from armi.utils.units import getTk from armi.materials import material @@ -67,6 +67,4 @@ def volumetricExpansion(self, Tk=None, Tc=None): Tk = getTk(Tc, Tk) self.checkTempRange(334, 430, Tk, "volumetric expansion") - return utils.linearInterpolation( - x0=334, y0=5.28e-4, x1=430, y1=5.56e-4, targetX=Tk - ) + return linearInterpolation(x0=334, y0=5.28e-4, x1=430, y1=5.56e-4, targetX=Tk) diff --git a/armi/nucDirectory/nuclideBases.py b/armi/nucDirectory/nuclideBases.py index 7f628ac24..201ba7372 100644 --- a/armi/nucDirectory/nuclideBases.py +++ b/armi/nucDirectory/nuclideBases.py @@ -309,6 +309,7 @@ def __readMc2Nuclides(): """ with open(os.path.join(armi.context.RES, "mc2Nuclides.yaml"), "r") as mc2Nucs: mc2Nuclides = yaml.load(mc2Nucs, Loader=yaml.FullLoader) + # now add the mc2 specific nuclideBases, and correct the mc2Ids when a > 0 and state = 0 for name, data in mc2Nuclides.items(): z = data["z"] diff --git a/armi/operators/operator.py b/armi/operators/operator.py index 9ea61a878..181434af7 100644 --- a/armi/operators/operator.py +++ b/armi/operators/operator.py @@ -31,7 +31,7 @@ from armi import context from armi import runLog from armi.bookkeeping import memoryProfiler -from armi import utils +from armi.utils.mathematics import expandRepeatedFloats from armi.utils import codeTiming from armi.utils import pathTools from armi import settings @@ -142,19 +142,19 @@ def _initFastPath(): def _getCycleLengths(self): """Return the cycle length for each cycle of the system as a list.""" - return utils.expandRepeatedFloats(self.cs["cycleLengths"]) or ( + return expandRepeatedFloats(self.cs["cycleLengths"]) or ( [self.cs["cycleLength"]] * self.cs["nCycles"] ) def _getAvailabilityFactors(self): """Return the availability factors (capacity factor) for each cycle of the system as a list.""" - return utils.expandRepeatedFloats(self.cs["availabilityFactors"]) or ( + return expandRepeatedFloats(self.cs["availabilityFactors"]) or ( [self.cs["availabilityFactor"]] * self.cs["nCycles"] ) def _getPowerFractions(self): """Return the power fractions for each cycle of the system as a list.""" - return utils.expandRepeatedFloats(self.cs["powerFractions"]) or ( + return expandRepeatedFloats(self.cs["powerFractions"]) or ( [1.0 for _cl in self.cycleLengths] ) @@ -919,7 +919,7 @@ def snapshotRequest(self, cycle, node): newFolder = "snapShot{0}_{1}".format(cycle, node) if os.path.exists(newFolder): runLog.important("Deleting existing snapshot data in {0}".format(newFolder)) - utils.pathTools.cleanPath(newFolder) # careful with cleanPath! + pathTools.cleanPath(newFolder) # careful with cleanPath! # give it a minute. time.sleep(1) diff --git a/armi/operators/settingsValidation.py b/armi/operators/settingsValidation.py index 83e50cb66..3d4b32ec8 100644 --- a/armi/operators/settingsValidation.py +++ b/armi/operators/settingsValidation.py @@ -26,6 +26,7 @@ import armi from armi import runLog, settings, utils from armi.utils import pathTools +from armi.utils.mathematics import expandRepeatedFloats from armi.reactor import geometry from armi.reactor import systemLayoutInput from armi.physics import neutronics @@ -463,7 +464,7 @@ def _willBeCopiedFrom(fName): def _factorsAreValid(factors, maxVal=1.0): try: - expandedList = utils.expandRepeatedFloats(factors) + expandedList = expandRepeatedFloats(factors) except (ValueError, IndexError): return False return ( @@ -527,8 +528,8 @@ def _correctCycles(): def decayCyclesHaveInputThatWillBeIgnored(): """Check if there is any decay-related input that will be ignored.""" try: - powerFracs = utils.expandRepeatedFloats(self.cs["powerFractions"]) - availabilities = utils.expandRepeatedFloats( + powerFracs = expandRepeatedFloats(self.cs["powerFractions"]) + availabilities = expandRepeatedFloats( self.cs["availabilityFactors"] ) or ([self.cs["availabilityFactor"]] * self.cs["nCycles"]) except: # pylint: disable=bare-except diff --git a/armi/physics/fuelCycle/fuelHandlers.py b/armi/physics/fuelCycle/fuelHandlers.py index dd952eaf6..9a730b49d 100644 --- a/armi/physics/fuelCycle/fuelHandlers.py +++ b/armi/physics/fuelCycle/fuelHandlers.py @@ -38,9 +38,8 @@ from armi.reactor.flags import Flags from armi.operators import RunTypes from armi.utils import directoryChangers, pathTools -from armi import utils from armi.utils import plotting -from armi.utils.mathematics import resampleStepwise +from armi.utils.mathematics import findClosest, resampleStepwise runLog = logging.getLogger(__name__) @@ -1249,17 +1248,16 @@ def buildRingSchedule( # build widths widths = [] for i, ring in enumerate(baseRings[:-1]): - widths.append( - abs(baseRings[i + 1] - ring) - 1 - ) # 0 is the most restrictive, meaning don't even look in other rings. + # 0 is the most restrictive, meaning don't even look in other rings. + widths.append(abs(baseRings[i + 1] - ring) - 1) widths.append(0) # add the last ring with width 0. # step 2: locate which rings should be reversed to give the jump-ring effect. if jumpRingFrom is not None: - _closestRingFrom, jumpRingFromIndex = utils.findClosest( + _closestRingFrom, jumpRingFromIndex = findClosest( baseRings, jumpRingFrom, indx=True ) - _closestRingTo, jumpRingToIndex = utils.findClosest( + _closestRingTo, jumpRingToIndex = findClosest( baseRings, jumpRingTo, indx=True ) else: diff --git a/armi/physics/neutronics/energyGroups.py b/armi/physics/neutronics/energyGroups.py index 6a3d55d79..d6bd4b451 100644 --- a/armi/physics/neutronics/energyGroups.py +++ b/armi/physics/neutronics/energyGroups.py @@ -15,14 +15,15 @@ Energy group structures for multigroup neutronics calculations. """ -import itertools import copy +import itertools import math import numpy from armi import utils from armi import runLog +from armi.utils.mathematics import findNearestValue from .const import ( FAST_FLUX_THRESHOLD_EV, MAXIMUM_XS_LIBRARY_ENERGY, @@ -32,10 +33,7 @@ def getFastFluxGroupCutoff(eGrpStruc): - """ - Given a constant "fast" energy threshold, return which ARMI energy group index contains this threshold. - """ - + """Given a constant "fast" energy threshold, return which ARMI energy group index contains this threshold.""" gThres = -1 for g, eV in enumerate(eGrpStruc): if eV < FAST_FLUX_THRESHOLD_EV: @@ -302,9 +300,7 @@ def _create_multigroup_structures_on_finegroup_energies( modifiedEnergyBounds = set() modifiedEnergyBounds.add(max(finegroup_energy_bounds)) for energyBound in multigroup_energy_bounds[1:]: - modifiedEnergyBounds.add( - utils.findNearestValue(finegroup_energy_bounds, energyBound) - ) + modifiedEnergyBounds.add(findNearestValue(finegroup_energy_bounds, energyBound)) return sorted(modifiedEnergyBounds, reverse=True) diff --git a/armi/reactor/converters/uniformMesh.py b/armi/reactor/converters/uniformMesh.py index f6f63df7c..3e8468947 100644 --- a/armi/reactor/converters/uniformMesh.py +++ b/armi/reactor/converters/uniformMesh.py @@ -58,7 +58,7 @@ import numpy from armi import runLog -from armi import utils +from armi.utils.mathematics import average1DWithinTolerance from armi.utils import iterables from armi.utils import plotting from armi.reactor import grids @@ -124,7 +124,7 @@ def _computeAverageAxialMesh(self): aMesh = src.core.findAllAxialMeshPoints([a], applySubMesh=True)[1:] if len(aMesh) == refNumPoints: allMeshes.append(aMesh) - self._uniformMesh = utils.average1DWithinTolerance(numpy.array(allMeshes)) + self._uniformMesh = average1DWithinTolerance(numpy.array(allMeshes)) @staticmethod def _createNewAssembly(sourceAssembly): diff --git a/armi/reactor/reactors.py b/armi/reactor/reactors.py index eea1cfb29..e8a7972d8 100644 --- a/armi/reactor/reactors.py +++ b/armi/reactor/reactors.py @@ -32,7 +32,7 @@ import numpy -from armi import getPluginManagerOrFail, materials, nuclearDataIO, settings, utils +from armi import getPluginManagerOrFail, materials, nuclearDataIO, settings from armi.reactor import assemblies from armi.reactor import assemblyLists from armi.reactor import composites @@ -42,9 +42,10 @@ from armi.reactor import parameters from armi.reactor import zones from armi.reactor import reactorParameters -from armi.utils import units +from armi.utils import createFormattedStrWithDelimiter, units from armi.utils.iterables import Sequence from armi.utils import directoryChangers +from armi.utils.mathematics import average1DWithinTolerance from armi.reactor.flags import Flags from armi.settings.fwSettings.globalSettings import CONF_MATERIAL_NAMESPACE_ORDER from armi.nuclearDataIO import xsLibraries @@ -1305,19 +1306,19 @@ def summarizeNuclideCategories(self): [ ( "Fuel", - utils.createFormattedStrWithDelimiter( + createFormattedStrWithDelimiter( self._nuclideCategories["fuel"] ), ), ( "Coolant", - utils.createFormattedStrWithDelimiter( + createFormattedStrWithDelimiter( self._nuclideCategories["coolant"] ), ), ( "Structure", - utils.createFormattedStrWithDelimiter( + createFormattedStrWithDelimiter( self._nuclideCategories["structure"] ), ), @@ -1874,7 +1875,7 @@ def updateAxialMesh(self): # depending on what makes the most sense refAssem = self.refAssem refMesh = self.findAllAxialMeshPoints([refAssem]) - avgHeight = utils.average1DWithinTolerance( + avgHeight = average1DWithinTolerance( numpy.array( [ [ diff --git a/armi/utils/__init__.py b/armi/utils/__init__.py index 065154413..df0b728e9 100644 --- a/armi/utils/__init__.py +++ b/armi/utils/__init__.py @@ -13,85 +13,33 @@ # limitations under the License. """Generic ARMI utilities""" +import collections +import datetime +import getpass +import hashlib +import importlib +import math import os -import sys -import time import pickle -import re import pkgutil -import importlib -import traceback -import getpass -import math -import datetime -import tempfile +import re import shutil -import threading import subprocess -import collections - -import hashlib - -import numpy -import scipy.optimize as sciopt +import sys +import tempfile +import threading +import time +import traceback import armi from armi import runLog from armi.utils import iterables from armi.utils.flags import Flag +from armi.utils.mathematics import * # for backwards compatibility # Read in file 1 MB at a time to reduce memory burden of reading entire file at once _HASH_BUFFER_SIZE = 1024 * 1024 -# special pattern to deal with FORTRAN-produced scipats without E, like 3.2234-234 -SCIPAT_SPECIAL = re.compile(r"([+-]?\d*\.\d+)[eEdD]?([+-]\d+)") - - -def coverageReportHelper(config, dataPaths): - """ - Small utility function to generate coverage reports. - - This was created to side-step the difficulties in submitting multi-line python - commands on-the-fly. - - This combines data paths and then makes html and xml reports for the - fully-combined result. - """ - from coverage import Coverage - import coverage - - try: - cov = Coverage(config_file=config) - if dataPaths: - # fun fact: if you combine when there's only one file, it gets deleted. - cov.combine(data_paths=dataPaths) - cov.save() - else: - cov.load() - cov.html_report() - cov.xml_report() - except PermissionError as e: - # Some file systems have some issues with filenames that start with a '.', such as the - # .coverage files. If a permissions error is raised, it likely has something to - # do with that. We changed the COVERAGE_RESULTS_FILE in cases.py for this reason. - runLog.error( - f"There was an issue in generating coverage reports due " - f"to the following permissions error: {e}" - ) - # disabled until we figure out the problem. - # raise - except coverage.misc.CoverageException as e: - # This is happening when forming the unit test coverage report. This may be - # caused by the TestFixture coverage report gobbling up all of the coverage - # files before the UnitTests.cov_report task gets a chance to see them. It may - # simply be that we dont want a coverage report generated for the TestFixture. - # Something to think about. Either way, we do not want to fail the job just - # because of this - runLog.error( - "There was an issue generating coverage reports " - "({}):\n{}".format(type(e), e.args) - ) - def getFileSHA1Hash(filePath, digits=40): """ @@ -114,80 +62,6 @@ def getFileSHA1Hash(filePath, digits=40): return sha1.hexdigest()[:digits] -def efmt(a: str) -> str: - r"""Converts string exponential number to another string with just 2 digits in the exponent.""" - # this assumes that none of our numbers will be more than 1e100 or less than 1e-100... - if len(a.split("E")) != 2: - two = a.split("e") - else: - two = a.split("E") - # print two - exp = two[1] # this is '+002' or '+02' or something - - if len(exp) == 4: # it has 3 digits of exponent - exp = exp[0] + exp[2:] # gets rid of the hundred's place digit - - return two[0] + "E" + exp - - -def fixThreeDigitExp(strToFloat: str) -> float: - """ - Convert FORTRAN numbers that cannot be converted into floats. - - Notes - ----- - Converts a number line "9.03231714805651-101" (no e or E) to "9.03231714805651e-101". - Some external depletion kernels currently need this fix. From contact with developer: - The notation like 1.0-101 is a FORTRAN thing, with history going back to the 60's. - They will only put E before an exponent 99 and below. Fortran will also read these guys - just fine, and they are valid floating point numbers. It would not be a useful effort, - in terms of time, trying to get FORTRAN to behave differently. - The approach has been to write a routine in the reading code which will interpret these. - - This helps when the scientific number exponent does not fit. - """ - match = SCIPAT_SPECIAL.match(strToFloat) - return float("{}E{}".format(*match.groups())) - - -def findClosest(listToSearch, val, indx=False): - r""" - find closest item in a list. - - Parameters - ---------- - listToSearch : list - The list to search through - - val : float - The target value that is being searched for in the list - - indx : bool, optional - If true, returns minVal and minIndex, otherwise, just the value - - Returns - ------- - minVal : float - The item in the listToSearch that is closest to val - minI : int - The index of the item in listToSearch that is closest to val. Returned if indx=True. - - """ - d = float("inf") - minVal = None - minI = None - for i, item in enumerate(listToSearch): - if abs(item - val) < d: - d = abs(item - val) - minVal = item - minI = i - if indx: - return minVal, minI - else: - # backwards compatibility - return minVal - - def copyWithoutBlocking(src, dest): """ Copy a file in a separate thread to avoid blocking while IO completes. @@ -201,158 +75,6 @@ def copyWithoutBlocking(src, dest): return t -def linearInterpolation(x0, y0, x1, y1, targetX=None, targetY=None): - r""" - does a linear interpolation (or extrapolation) for y=f(x) - - Parameters - ---------- - x0,y0,x1,y1 : float - Coordinates of two points to interpolate between - - targetX : float, optional - X value to evaluate the line at - - targetY : float, optional - Y value we want to find the x value for (inverse interpolation) - - Returns - ------- - interpY : float - The value of y(targetX), if targetX is not None - - interpX : float - The value of x where y(x) = targetY (if targetY is not None) - - y = m(x-x0) + b - - x = (y-b)/m - - """ - if x1 == x0: - raise ZeroDivisionError("The x-values are identical. Cannot interpolate.") - - m = (y1 - y0) / (x1 - x0) - b = -m * x0 + y0 - - if targetX is not None: - return m * targetX + b - else: - return (targetY - b) / m - - -def parabolaFromPoints(p1, p2, p3): - r""" - find the parabola that passes through three points - - We solve a simultaneous equation with three points. - - A = x1**2 x1 1 - x2**2 x2 1 - x3**2 x3 1 - - b = y1 - y2 - y3 - - find coefficients Ax=b - - Parameters - ---------- - p1 : tuple - first point (x,y) coordinates - p2,p3: tuple, second and third points. - - Returns - ------- - a,b,c coefficients of y=ax^2+bx+c - - """ - - A = numpy.array( - [[p1[0] ** 2, p1[0], 1], [p2[0] ** 2, p2[0], 1], [p3[0] ** 2, p3[0], 1]] - ) - - b = numpy.array([[p1[1]], [p2[1]], [p3[1]]]) - try: - x = numpy.linalg.solve(A, b) - except: - print("Error in parabola {} {}".format(A, b)) - raise - - return float(x[0]), float(x[1]), float(x[2]) - - -def parabolicInterpolation(ap, bp, cp, targetY): - r""" - Given parabola coefficients, this interpolates the time - that would give k=targetK. - - keff = at^2+bt+c - We want to solve a*t^2+bt+c-targetK = 0.0 for time. - if there are real roots, we should probably take the smallest one - because the larger one might be at very high burnup. - If there are no real roots, just take the point where the deriv ==0, or - 2at+b=0, so t = -b/2a - The slope of the curve is the solution to 2at+b at whatever t has been determined - - Parameters - ---------- - ap, bp,cp : floats - coefficients of a parabola y = ap*x^2 + bp*x + cp - - targetK : float - The keff to find the cycle length of - - Returns - ------- - realRoots : list of tuples - (root, slope) - The best guess of the cycle length that will give k=targetK - If no positive root was found, this is the maximum of the curve. In that case, - it will be a negative number. If there are two positive roots, there will be two entries. - - slope : float - The slope of the keff vs. time curve at t=newTime - - """ - roots = numpy.roots([ap, bp, cp - targetY]) - realRoots = [] - for r in roots: - if r.imag == 0 and r.real > 0: - realRoots.append((r.real, 2.0 * ap * r.real + bp)) - - if not realRoots: - # no positive real roots. Take maximum and give up for this cyclic. - newTime = -bp / (2 * ap) - if newTime < 0: - raise RuntimeError("No positive roots or maxima.") - slope = 2.0 * ap * newTime + bp - newTime = ( - -newTime - ) # return a negative newTime to signal that it is not expected to be critical. - realRoots = [(newTime, slope)] - - return realRoots - - -def getFloat(val): - r"""returns float version of val, or None if it's impossible. Useful for converting - user-input into floats when '' might be possible.""" - try: - newVal = float(val) - return newVal - except: - return None - - -def relErr(v1: float, v2: float) -> float: - if v1: - return (v2 - v1) / v1 - else: - return -1e99 - - def getTimeStepNum(cycleNumber, subcycleNumber, cs): """Return the timestep associated with cycle and tn. @@ -433,12 +155,6 @@ def tryPickleOnAllContents(obj, ignore=None, path=None, verbose=False): print( "{0} in {1} cannot be pickled. It is: {2}. ".format(name, obj, ob) ) - # traceback.print_exc(limit=0,file=sys.stdout) - - -def tryPickleOnAllContents2(*args, **kwargs): - # helper - print(doTestPickleOnAllContents2(*args, **kwargs)) def doTestPickleOnAllContents2(obj, ignore=None, path=None, verbose=False): @@ -475,7 +191,7 @@ def doTestPickleOnAllContents2(obj, ignore=None, path=None, verbose=False): class MyPickler(pickle.Pickler): r""" - The big guns. This will find your pickle errors if all else fails. + This will find your pickle errors if all else fails. Use with tryPickleOnAllContents3. """ @@ -499,7 +215,6 @@ def tryPickleOnAllContents3(obj, ignore=None, path=None, verbose=False): to make it work like the other testPickle functions and handle errors, you could. But usually you just have to find one unpickleable SOB. """ - with tempfile.TemporaryFile() as output: try: MyPickler(output).dump(obj) @@ -508,9 +223,7 @@ def tryPickleOnAllContents3(obj, ignore=None, path=None, verbose=False): def classesInHierarchy(obj, classCounts, visited=None): - """ - Count the number of instances of each class contained in an objects heirarchy. - """ + """Count the number of instances of each class contained in an objects heirarchy.""" if not isinstance(classCounts, collections.defaultdict): raise TypeError( "Need to pass in a default dict for classCounts (it's an out param)" @@ -550,121 +263,6 @@ def slantSplit(val, ratio, nodes, order="low first"): return X -def newtonsMethod( - func, goal, guess, maxIterations=None, cs=None, positiveGuesses=False -): - r""" - Solves a Newton's method with the given function, goal value, and first guess. - - Parameters - ---------- - func : function - The function that guess will be changed to try to make it return the goal value. - - goal : float - The function will be changed until it's return equals this value. - - guess : float - The first guess value to do Newton's method on the func. - - maxIterations : int - The maximum number of iterations that the Newton's method will be allowed to perform. - - - Returns - ------- - ans : float - The guess that when input to the func returns the goal. - - """ - - def goalFunc(guess, func, positiveGuesses): - if positiveGuesses is True: - guess = abs(guess) - funcVal = func(guess) - val = abs(goal - funcVal) - return val - - if (maxIterations is None) and (cs is not None): - maxIterations = cs["maxNewtonsIterations"] - - # try: - ans = float( - sciopt.newton( - goalFunc, - guess, - args=(func, positiveGuesses), - tol=1.0e-3, - maxiter=maxIterations, - ) - ) - - if positiveGuesses is True: - ans = abs(ans) - - return ans - - -def minimizeScalarFunc( - func, - goal, - guess, - maxIterations=None, - cs=None, - positiveGuesses=False, - method=None, - tol=1.0e-3, -): - r""" - Use scipy minimize with the given function, goal value, and first guess. - - Parameters - ---------- - func : function - The function that guess will be changed to try to make it return the goal value. - - goal : float - The function will be changed until it's return equals this value. - - guess : float - The first guess value to do Newton's method on the func. - - maxIterations : int - The maximum number of iterations that the Newton's method will be allowed to perform. - - - Returns - ------- - ans : float - The guess that when input to the func returns the goal. - - """ - - def goalFunc(guess, func, positiveGuesses): - if positiveGuesses is True: - guess = abs(guess) - funcVal = func(guess) - val = abs(goal - funcVal) - return val - - if (maxIterations is None) and (cs is not None): - maxIterations = cs["maxNewtonsIterations"] - - X = sciopt.minimize( - goalFunc, - guess, - args=(func, positiveGuesses), - method=method, - tol=tol, - options={"maxiter": maxIterations}, - ) - ans = float(X["x"]) - if positiveGuesses is True: - ans = abs(ans) - - return ans - - def runFunctionFromAllModules(funcName, *args, **kwargs): r""" Runs funcName on all modules of ARMI, if it exists. @@ -701,34 +299,6 @@ def runFunctionFromAllModules(funcName, *args, **kwargs): traceback.print_exc() -# TODO: move to pathTools (and reference it here for convenience) -def mkdir(dirname): - r""" - Keeps trying to make a directory, outputting whatever errors it encounters, - until it is successful. - - Parameters - ---------- - dirname : str - Path to the directory to create. - What you would normally pass to os.mkdir. - """ - numTimesTried = 0 - while numTimesTried < 1000: - try: - os.mkdir(dirname) - break - except FileExistsError: - break - except Exception as err: - numTimesTried += 1 - # Only ouput err every 10 times. - if numTimesTried % 10 == 0: - print(err) - # Wait 0.5 seconds, try again. - time.sleep(0.5) - - def prependToList(originalList, listToPrepend): """ Add a new list to the beginnning of an original list. @@ -866,142 +436,6 @@ def createFormattedStrWithDelimiter( return formattedString -def rotateXY(x, y, degreesCounterclockwise=None, radiansCounterclockwise=None): - """ - Rotates x, y coordinates - - Parameters - ---------- - x, y : array_like - coordinates - - degreesCounterclockwise : float - Degrees to rotate in the CCW direction - - radiansCounterclockwise : float - Radians to rotate in the CCW direction - - Returns - ------- - xr, yr : array_like - the rotated coordinates. - """ - - if radiansCounterclockwise is None: - radiansCounterclockwise = degreesCounterclockwise * math.pi / 180.0 - - sinT = math.sin(radiansCounterclockwise) - cosT = math.cos(radiansCounterclockwise) - rotationMatrix = numpy.array([[cosT, -sinT], [sinT, cosT]]) - xr, yr = rotationMatrix.dot(numpy.vstack((x, y))) - if len(xr) > 1: - ## Convert to lists because everyone prefers lists for some reason - return xr.tolist(), yr.tolist() - else: - ## Convert to scalar for consistency with old implementation - return xr[0], yr[0] - - -def convertToSlice(x, increment=False): - """ - Convert a int, float, list of ints or floats, None, or slice - to a slice. Also optionally increments that slice to make it easy to line - up lists that don't start with 0. - - Use this with numpy.array (numpy.ndarray) types to easily get selections of it's elements. - - Parameters - ---------- - x : multiple types allowed. - int: select one index. - list of int: select these index numbers. - None: select all indices. - slice: select this slice - - Returns - ------- - slice : slice - Returns a slice object that can be used in an array - like a[x] to select from its members. - Also, the slice has its index numbers decremented by 1. - It can also return a numpy array, which can be used - to slice other numpy arrays in the same way as a slice. - - Examples - -------- - a = numpy.array([10, 11, 12, 13]) - - >>> convertToSlice(2) - slice(2, 3, None) - >>> a[convertToSlice(2)] - array([12]) - - >>> convertToSlice(2, increment=-1) - slice(1, 2, None) - >>> a[convertToSlice(2, increment=-1)] - array([11]) - - >>> a[convertToSlice(None)] - array([10, 11, 12, 13]) - - - >>> a[utils.convertToSlice([1, 3])] - array([11, 13]) - - - >>> a[utils.convertToSlice([1, 3], increment=-1)] - array([10, 12]) - - >>> a[utils.convertToSlice(slice(2, 3, None), increment=-1)] - array([11]) - - """ - if increment is False: - increment = 0 - - if not isinstance(increment, int): - raise Exception("increment must be False or an integer in utils.convertToSlice") - - if x is None: - x = numpy.s_[:] - - if isinstance(x, list): - x = numpy.array(x) - - if isinstance(x, (int, numpy.integer)) or isinstance(x, (float, numpy.floating)): - x = slice(int(x), int(x) + 1, None) - - # Correct the slice indices to be group instead of index based. - # The energy groups are 1..x and the indices are 0..x-1. - if isinstance(x, slice): - if x.start is not None: - jstart = x.start + increment - else: - jstart = None - - if x.stop is not None: - if isinstance(x.stop, list): - jstop = [x + increment for x in x.stop] - else: - jstop = x.stop + increment - else: - jstop = None - - jstep = x.step - - return numpy.s_[jstart:jstop:jstep] - - elif isinstance(x, numpy.ndarray): - return numpy.array([i + increment for i in x]) - - else: - raise Exception( - ( - "It is not known how to handle x type: " "{0} in utils.convertToSlice" - ).format(type(x)) - ) - - def plotMatrix( matrix, fName, @@ -1016,9 +450,7 @@ def plotMatrix( cmap=None, figsize=None, ): - """ - Plots a matrix - """ + """Plots a matrix""" import matplotlib import matplotlib.pyplot as plt @@ -1026,8 +458,10 @@ def plotMatrix( plt.figure(figsize=figsize) # dpi=300) else: plt.figure() + if cmap is None: cmap = plt.cm.jet # @UndefinedVariable #pylint: disable=no-member + cmap.set_bad("w") try: matrix = matrix.todense() @@ -1041,9 +475,9 @@ def plotMatrix( if title is None: title = fName - plt.imshow( - matrix, cmap=cmap, norm=norm, interpolation="nearest" - ) # or bicubic or nearest#,vmin=0, vmax=300) + + # or bicubic or nearest#,vmin=0, vmax=300) + plt.imshow(matrix, cmap=cmap, norm=norm, interpolation="nearest") plt.colorbar() plt.title(title) plt.xlabel(xlabel) @@ -1074,90 +508,6 @@ def userName() -> str: return re.sub("^[a-zA-Z]-", "", getpass.getuser()) -def expandRepeatedFloats(repeatedList): - """ - Return an expanded repeat list. - - Notes - ----- - R char is valid for showing the number of repeats in MCNP. For examples the list: - [150, 200, '9R'] - indicates a 150 day cycle followed by 10 200 day cycles. - """ - nonRepeatList = [] - for val in repeatedList: - isRepeat = False - if isinstance(val, str): - val = val.upper() - if val.count("R") > 1: - raise ValueError("List had strings that were not repeats") - elif "R" in val: - val = val.replace("R", "") - isRepeat = True - if isRepeat: - nonRepeatList += [nonRepeatList[-1]] * int(val) - else: - nonRepeatList.append(float(val)) - return nonRepeatList - - -def getStepsFromValues(values, prevValue=0.0): - """Convert list of floats to list of steps between each float.""" - steps = [] - for val in values: - currentVal = float(val) - steps.append(currentVal - prevValue) - prevValue = currentVal - return steps - - -def average1DWithinTolerance(vals, tolerance=0.2): - """ - Compute the average of a series of arrays with a tolerance. - - Tuned for averaging assembly meshes or block heights. - - Parameters - ---------- - vals : 2D numpy.array - could be assembly x axial mesh tops or heights - """ - vals = numpy.array(vals) - - filterOut = numpy.array([False]) # this gets discarded - while not filterOut.all(): # 20% difference is the default tolerance - avg = vals.mean(axis=0) # average over all columns - diff = abs(vals - avg) / avg # no nans, because all vals are non-zero - filterOut = (diff > tolerance).sum( - axis=1 - ) == 0 # True = 1, sum across axis means any height in assem is off - vals = vals[filterOut] # filter anything that is skewing - - if vals.size == 0: - raise ValueError("Nothing was near the mean, there are no acceptable values!") - - if (avg <= 0.0).any(): - raise ValueError( - "A non-physical value (<=0) was computed, but this is not possible.\n" - "Values: {}\navg: {}".format(vals, avg) - ) - - return avg - - -def findNearestValue(searchList, searchValue): - """Search a given list for the value that is closest to the given search value.""" - return findNearestValueAndIndex(searchList, searchValue)[0] - - -def findNearestValueAndIndex(searchList, searchValue): - """Search a given list for the value that is closest to the given search value. Return a tuple - containing the value and its index in the list.""" - searchArray = numpy.array(searchList) - closestValueIndex = (numpy.abs(searchArray - searchValue)).argmin() - return searchArray[closestValueIndex], closestValueIndex - - class MergeableDict(dict): """ Overrides python dictionary and implements a merge method. @@ -1172,9 +522,6 @@ def merge(self, *otherDictionaries) -> None: self.update(dictionary) -shutil_copy = shutil.copy - - def safeCopy(src: str, dst: str) -> None: """This copy overwrites ``shutil.copy`` and checks that copy operation is truly completed before continuing.""" waitTime = 0.01 # 10 ms @@ -1191,4 +538,6 @@ def safeCopy(src: str, dst: str) -> None: runLog.extra("Copied {} -> {}".format(src, dst)) +# Allow us to check the copy operation is complete before continuing +shutil_copy = shutil.copy shutil.copy = safeCopy diff --git a/armi/utils/mathematics.py b/armi/utils/mathematics.py index 746af9603..b71cfdc73 100644 --- a/armi/utils/mathematics.py +++ b/armi/utils/mathematics.py @@ -13,7 +13,539 @@ # limitations under the License. """Various math utilities""" +import math +import re + import numpy as np +import scipy.optimize as sciopt + +# special pattern to deal with FORTRAN-produced scipats without E, like 3.2234-234 +SCIPAT_SPECIAL = re.compile(r"([+-]?\d*\.\d+)[eEdD]?([+-]\d+)") + + +def average1DWithinTolerance(vals, tolerance=0.2): + """ + Compute the average of a series of arrays with a tolerance. + + Tuned for averaging assembly meshes or block heights. + + Parameters + ---------- + vals : 2D np.array + could be assembly x axial mesh tops or heights + """ + vals = np.array(vals) + + filterOut = np.array([False]) # this gets discarded + while not filterOut.all(): # 20% difference is the default tolerance + avg = vals.mean(axis=0) # average over all columns + diff = abs(vals - avg) / avg # no nans, because all vals are non-zero + # True = 1, sum across axis means any height in assem is off + filterOut = (diff > tolerance).sum(axis=1) == 0 + vals = vals[filterOut] # filter anything that is skewing + + if vals.size == 0: + raise ValueError("Nothing was near the mean, there are no acceptable values!") + + if (avg <= 0.0).any(): + raise ValueError( + "A non-physical value (<=0) was computed, but this is not possible.\n" + "Values: {}\navg: {}".format(vals, avg) + ) + + return avg + + +def convertToSlice(x, increment=False): + """ + Convert a int, float, list of ints or floats, None, or slice + to a slice. Also optionally increments that slice to make it easy to line + up lists that don't start with 0. + + Use this with np.array (np.ndarray) types to easily get selections of it's elements. + + Parameters + ---------- + x : multiple types allowed. + int: select one index. + list of int: select these index numbers. + None: select all indices. + slice: select this slice + + Returns + ------- + slice : slice + Returns a slice object that can be used in an array + like a[x] to select from its members. + Also, the slice has its index numbers decremented by 1. + It can also return a numpy array, which can be used + to slice other numpy arrays in the same way as a slice. + + Examples + -------- + a = np.array([10, 11, 12, 13]) + + >>> convertToSlice(2) + slice(2, 3, None) + >>> a[convertToSlice(2)] + array([12]) + + >>> convertToSlice(2, increment=-1) + slice(1, 2, None) + >>> a[convertToSlice(2, increment=-1)] + array([11]) + + >>> a[convertToSlice(None)] + array([10, 11, 12, 13]) + + + >>> a[utils.convertToSlice([1, 3])] + array([11, 13]) + + + >>> a[utils.convertToSlice([1, 3], increment=-1)] + array([10, 12]) + + >>> a[utils.convertToSlice(slice(2, 3, None), increment=-1)] + array([11]) + + """ + if increment is False: + increment = 0 + + if not isinstance(increment, int): + raise Exception("increment must be False or an integer in utils.convertToSlice") + + if x is None: + x = np.s_[:] + + if isinstance(x, list): + x = np.array(x) + + if isinstance(x, (int, np.integer)) or isinstance(x, (float, np.floating)): + x = slice(int(x), int(x) + 1, None) + + # Correct the slice indices to be group instead of index based. + # The energy groups are 1..x and the indices are 0..x-1. + if isinstance(x, slice): + if x.start is not None: + jstart = x.start + increment + else: + jstart = None + + if x.stop is not None: + if isinstance(x.stop, list): + jstop = [x + increment for x in x.stop] + else: + jstop = x.stop + increment + else: + jstop = None + + jstep = x.step + + return np.s_[jstart:jstop:jstep] + + elif isinstance(x, np.ndarray): + return np.array([i + increment for i in x]) + + else: + raise Exception( + ( + "It is not known how to handle x type: " "{0} in utils.convertToSlice" + ).format(type(x)) + ) + + +def efmt(a: str) -> str: + r"""Converts string exponential number to another string with just 2 digits in the exponent.""" + # this assumes that none of our numbers will be more than 1e100 or less than 1e-100... + if len(a.split("E")) != 2: + two = a.split("e") + else: + two = a.split("E") + # print two + exp = two[1] # this is '+002' or '+02' or something + + if len(exp) == 4: # it has 3 digits of exponent + exp = exp[0] + exp[2:] # gets rid of the hundred's place digit + + return two[0] + "E" + exp + + +def expandRepeatedFloats(repeatedList): + """ + Return an expanded repeat list. + + Notes + ----- + R char is valid for showing the number of repeats in MCNP. For examples the list: + [150, 200, '9R'] + indicates a 150 day cycle followed by 10 200 day cycles. + """ + nonRepeatList = [] + for val in repeatedList: + isRepeat = False + if isinstance(val, str): + val = val.upper() + if val.count("R") > 1: + raise ValueError("List had strings that were not repeats") + elif "R" in val: + val = val.replace("R", "") + isRepeat = True + if isRepeat: + nonRepeatList += [nonRepeatList[-1]] * int(val) + else: + nonRepeatList.append(float(val)) + return nonRepeatList + + +def findClosest(listToSearch, val, indx=False): + r""" + find closest item in a list. + + Parameters + ---------- + listToSearch : list + The list to search through + + val : float + The target value that is being searched for in the list + + indx : bool, optional + If true, returns minVal and minIndex, otherwise, just the value + + Returns + ------- + minVal : float + The item in the listToSearch that is closest to val + minI : int + The index of the item in listToSearch that is closest to val. Returned if indx=True. + + """ + d = float("inf") + minVal = None + minI = None + for i, item in enumerate(listToSearch): + if abs(item - val) < d: + d = abs(item - val) + minVal = item + minI = i + if indx: + return minVal, minI + else: + # backwards compatibility + return minVal + + +def findNearestValue(searchList, searchValue): + """Search a given list for the value that is closest to the given search value.""" + return findNearestValueAndIndex(searchList, searchValue)[0] + + +def findNearestValueAndIndex(searchList, searchValue): + """Search a given list for the value that is closest to the given search value. Return a tuple + containing the value and its index in the list.""" + searchArray = np.array(searchList) + closestValueIndex = (np.abs(searchArray - searchValue)).argmin() + return searchArray[closestValueIndex], closestValueIndex + + +def fixThreeDigitExp(strToFloat: str) -> float: + """ + Convert FORTRAN numbers that cannot be converted into floats. + + Notes + ----- + Converts a number line "9.03231714805651-101" (no e or E) to "9.03231714805651e-101". + Some external depletion kernels currently need this fix. From contact with developer: + The notation like 1.0-101 is a FORTRAN thing, with history going back to the 60's. + They will only put E before an exponent 99 and below. Fortran will also read these guys + just fine, and they are valid floating point numbers. It would not be a useful effort, + in terms of time, trying to get FORTRAN to behave differently. + The approach has been to write a routine in the reading code which will interpret these. + + This helps when the scientific number exponent does not fit. + """ + match = SCIPAT_SPECIAL.match(strToFloat) + return float("{}E{}".format(*match.groups())) + + +def getFloat(val): + r"""returns float version of val, or None if it's impossible. Useful for converting + user-input into floats when '' might be possible.""" + try: + newVal = float(val) + return newVal + except: + return None + + +def getStepsFromValues(values, prevValue=0.0): + """Convert list of floats to list of steps between each float.""" + steps = [] + for val in values: + currentVal = float(val) + steps.append(currentVal - prevValue) + prevValue = currentVal + return steps + + +def linearInterpolation(x0, y0, x1, y1, targetX=None, targetY=None): + r""" + does a linear interpolation (or extrapolation) for y=f(x) + + Parameters + ---------- + x0,y0,x1,y1 : float + Coordinates of two points to interpolate between + + targetX : float, optional + X value to evaluate the line at + + targetY : float, optional + Y value we want to find the x value for (inverse interpolation) + + Returns + ------- + interpY : float + The value of y(targetX), if targetX is not None + + interpX : float + The value of x where y(x) = targetY (if targetY is not None) + + y = m(x-x0) + b + + x = (y-b)/m + + """ + if x1 == x0: + raise ZeroDivisionError("The x-values are identical. Cannot interpolate.") + + m = (y1 - y0) / (x1 - x0) + b = -m * x0 + y0 + + if targetX is not None: + return m * targetX + b + else: + return (targetY - b) / m + + +def minimizeScalarFunc( + func, + goal, + guess, + maxIterations=None, + cs=None, + positiveGuesses=False, + method=None, + tol=1.0e-3, +): + r""" + Use scipy minimize with the given function, goal value, and first guess. + + Parameters + ---------- + func : function + The function that guess will be changed to try to make it return the goal value. + + goal : float + The function will be changed until it's return equals this value. + + guess : float + The first guess value to do Newton's method on the func. + + maxIterations : int + The maximum number of iterations that the Newton's method will be allowed to perform. + + + Returns + ------- + ans : float + The guess that when input to the func returns the goal. + + """ + + def goalFunc(guess, func, positiveGuesses): + if positiveGuesses is True: + guess = abs(guess) + funcVal = func(guess) + val = abs(goal - funcVal) + return val + + if (maxIterations is None) and (cs is not None): + maxIterations = cs["maxNewtonsIterations"] + + X = sciopt.minimize( + goalFunc, + guess, + args=(func, positiveGuesses), + method=method, + tol=tol, + options={"maxiter": maxIterations}, + ) + ans = float(X["x"]) + if positiveGuesses is True: + ans = abs(ans) + + return ans + + +def newtonsMethod( + func, goal, guess, maxIterations=None, cs=None, positiveGuesses=False +): + r""" + Solves a Newton's method with the given function, goal value, and first guess. + + Parameters + ---------- + func : function + The function that guess will be changed to try to make it return the goal value. + + goal : float + The function will be changed until it's return equals this value. + + guess : float + The first guess value to do Newton's method on the func. + + maxIterations : int + The maximum number of iterations that the Newton's method will be allowed to perform. + + + Returns + ------- + ans : float + The guess that when input to the func returns the goal. + + """ + + def goalFunc(guess, func, positiveGuesses): + if positiveGuesses is True: + guess = abs(guess) + funcVal = func(guess) + val = abs(goal - funcVal) + return val + + if (maxIterations is None) and (cs is not None): + maxIterations = cs["maxNewtonsIterations"] + + # try: + ans = float( + sciopt.newton( + goalFunc, + guess, + args=(func, positiveGuesses), + tol=1.0e-3, + maxiter=maxIterations, + ) + ) + + if positiveGuesses is True: + ans = abs(ans) + + return ans + + +def parabolaFromPoints(p1, p2, p3): + r""" + find the parabola that passes through three points + + We solve a simultaneous equation with three points. + + A = x1**2 x1 1 + x2**2 x2 1 + x3**2 x3 1 + + b = y1 + y2 + y3 + + find coefficients Ax=b + + Parameters + ---------- + p1 : tuple + first point (x,y) coordinates + p2,p3: tuple, second and third points. + + Returns + ------- + a,b,c coefficients of y=ax^2+bx+c + + """ + A = np.array( + [[p1[0] ** 2, p1[0], 1], [p2[0] ** 2, p2[0], 1], [p3[0] ** 2, p3[0], 1]] + ) + + b = np.array([[p1[1]], [p2[1]], [p3[1]]]) + + try: + x = np.linalg.solve(A, b) + except: + print("Error in parabola {} {}".format(A, b)) + raise + + return float(x[0]), float(x[1]), float(x[2]) + + +def parabolicInterpolation(ap, bp, cp, targetY): + r""" + Given parabola coefficients, this interpolates the time + that would give k=targetK. + + keff = at^2+bt+c + We want to solve a*t^2+bt+c-targetK = 0.0 for time. + if there are real roots, we should probably take the smallest one + because the larger one might be at very high burnup. + If there are no real roots, just take the point where the deriv ==0, or + 2at+b=0, so t = -b/2a + The slope of the curve is the solution to 2at+b at whatever t has been determined + + Parameters + ---------- + ap, bp,cp : floats + coefficients of a parabola y = ap*x^2 + bp*x + cp + + targetK : float + The keff to find the cycle length of + + Returns + ------- + realRoots : list of tuples + (root, slope) + The best guess of the cycle length that will give k=targetK + If no positive root was found, this is the maximum of the curve. In that case, + it will be a negative number. If there are two positive roots, there will be two entries. + + slope : float + The slope of the keff vs. time curve at t=newTime + + """ + roots = np.roots([ap, bp, cp - targetY]) + realRoots = [] + for r in roots: + if r.imag == 0 and r.real > 0: + realRoots.append((r.real, 2.0 * ap * r.real + bp)) + + if not realRoots: + # no positive real roots. Take maximum and give up for this cyclic. + newTime = -bp / (2 * ap) + if newTime < 0: + raise RuntimeError("No positive roots or maxima.") + slope = 2.0 * ap * newTime + bp + newTime = ( + -newTime + ) # return a negative newTime to signal that it is not expected to be critical. + realRoots = [(newTime, slope)] + + return realRoots + + +def relErr(v1: float, v2: float) -> float: + """find the relative error between to numbers""" + if v1: + return (v2 - v1) / v1 + else: + return -1e99 def resampleStepwise(xin, yin, xout, avg=True): @@ -90,3 +622,38 @@ def resampleStepwise(xin, yin, xout, avg=True): yout.append(sum(chunk)) return yout + + +def rotateXY(x, y, degreesCounterclockwise=None, radiansCounterclockwise=None): + """ + Rotates x, y coordinates + + Parameters + ---------- + x, y : array_like + coordinates + + degreesCounterclockwise : float + Degrees to rotate in the CCW direction + + radiansCounterclockwise : float + Radians to rotate in the CCW direction + + Returns + ------- + xr, yr : array_like + the rotated coordinates. + """ + if radiansCounterclockwise is None: + radiansCounterclockwise = degreesCounterclockwise * math.pi / 180.0 + + sinT = math.sin(radiansCounterclockwise) + cosT = math.cos(radiansCounterclockwise) + rotationMatrix = np.array([[cosT, -sinT], [sinT, cosT]]) + xr, yr = rotationMatrix.dot(np.vstack((x, y))) + if len(xr) > 1: + # Convert to lists because everyone prefers lists for some reason + return xr.tolist(), yr.tolist() + else: + # Convert to scalar for consistency with old implementation + return xr[0], yr[0] diff --git a/armi/utils/tests/test_mathematics.py b/armi/utils/tests/test_mathematics.py index dd9022559..e5bfec5ee 100644 --- a/armi/utils/tests/test_mathematics.py +++ b/armi/utils/tests/test_mathematics.py @@ -14,14 +14,170 @@ r""" Testing mathematics utilities """ # pylint: disable=missing-function-docstring,missing-class-docstring,abstract-method,protected-access,no-member,disallowed-name,invalid-name +from math import sqrt import unittest -from armi.utils.mathematics import resampleStepwise +import numpy as np + +from armi.utils.mathematics import ( + average1DWithinTolerance, + convertToSlice, + efmt, + expandRepeatedFloats, + findClosest, + findNearestValue, + fixThreeDigitExp, + getFloat, + getStepsFromValues, + linearInterpolation, + minimizeScalarFunc, + newtonsMethod, + parabolaFromPoints, + parabolicInterpolation, + relErr, + resampleStepwise, + rotateXY, +) class TestMath(unittest.TestCase): """Tests for various math utilities""" + def test_average1DWithinTolerance(self): + vals = np.array([np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])]) + result = average1DWithinTolerance(vals, 0.1) + self.assertEqual(len(result), 3) + self.assertEqual(result[0], 4.0) + self.assertEqual(result[1], 5.0) + self.assertEqual(result[2], 6.0) + + def test_average1DWithinToleranceInvalid(self): + vals = np.array( + [np.array([1, -2, 3]), np.array([4, -5, 6]), np.array([7, -8, 9])] + ) + with self.assertRaises(ValueError): + average1DWithinTolerance(vals, 0.1) + + def test_convertToSlice(self): + slice1 = convertToSlice(2) + self.assertEqual(slice1, slice(2, 3, None)) + slice1 = convertToSlice(2.0, increment=-1) + self.assertEqual(slice1, slice(1, 2, None)) + slice1 = convertToSlice(None) + self.assertEqual(slice1, slice(None, None, None)) + slice1 = convertToSlice([1, 2, 3]) + self.assertTrue(np.allclose(slice1, np.array([1, 2, 3]))) + slice1 = convertToSlice(slice(2, 3, None)) + self.assertEqual(slice1, slice(2, 3, None)) + slice1 = convertToSlice(np.array([1, 2, 3])) + self.assertTrue(np.allclose(slice1, np.array([1, 2, 3]))) + with self.assertRaises(Exception): + slice1 = convertToSlice("slice") + + def test_efmt(self): + self.assertAlmostEqual(efmt("1.0e+001"), "1.0E+01") + self.assertAlmostEqual(efmt("1.0E+01"), "1.0E+01") + + def test_expandRepeatedFloats(self): + repeatedFloats = ["150", "2R", 200.0, 175, "4r", 180.0, "0R"] + expectedFloats = [150] * 3 + [200] + [175] * 5 + [180] + self.assertEqual(expandRepeatedFloats(repeatedFloats), expectedFloats) + + def test_findClosest(self): + l1 = range(10) + self.assertEqual(findClosest(l1, 5.6), 6) + self.assertEqual(findClosest(l1, 10.1), 9) + self.assertEqual(findClosest(l1, -200), 0) + + # with index + self.assertEqual(findClosest(l1, 5.6, indx=True), (6, 6)) + + def test_findNearestValue(self): + searchList = [0.1, 0.2, 0.25, 0.35, 0.4] + searchValue = 0.225 + self.assertEqual(findNearestValue(searchList, searchValue), 0.2) + searchValue = 0.226 + self.assertEqual(findNearestValue(searchList, searchValue), 0.25) + searchValue = 0.0 + self.assertEqual(findNearestValue(searchList, searchValue), 0.1) + searchValue = 10 + self.assertEqual(findNearestValue(searchList, searchValue), 0.4) + + def test_fixThreeDigitExp(self): + fixed = fixThreeDigitExp("-9.03231714805651E+101") + self.assertEqual(-9.03231714805651e101, fixed) + fixed = fixThreeDigitExp("9.03231714805651-101") + self.assertEqual(9.03231714805651e-101, fixed) + fixed = fixThreeDigitExp("-2.4594981981654+101") + self.assertEqual(-2.4594981981654e101, fixed) + fixed = fixThreeDigitExp("-2.4594981981654-101") + self.assertEqual(-2.4594981981654e-101, fixed) + + def test_getFloat(self): + self.assertEqual(getFloat(1.0), 1.0) + self.assertEqual(getFloat("1.0"), 1.0) + self.assertIsNone(getFloat("word")) + + def test_getStepsFromValues(self): + steps = getStepsFromValues([1.0, 3.0, 6.0, 10.0], prevValue=0.0) + self.assertListEqual(steps, [1.0, 2.0, 3.0, 4.0]) + + def test_linearInterpolation(self): + y = linearInterpolation(1.0, 2.0, 3.0, 4.0, targetX=20.0) + x = linearInterpolation(1.0, 2.0, 3.0, 4.0, targetY=y) + + x2 = linearInterpolation(1.0, 1.0, 2.0, 2.0, targetY=50) + + self.assertEqual(x, 20.0) + self.assertEqual(x2, 50.0) + + with self.assertRaises(ZeroDivisionError): + _ = linearInterpolation(1.0, 1.0, 1.0, 2.0) + + def test_minimizeScalarFunc(self): + f = lambda x: (x + 1) ** 2 + minimum = minimizeScalarFunc(f, -3.0, 10.0, maxIterations=10) + self.assertAlmostEqual(minimum, -1.0, places=3) + minimum = minimizeScalarFunc( + f, -3.0, 10.0, maxIterations=10, positiveGuesses=True + ) + self.assertAlmostEqual(minimum, 0.0, places=3) + + def test_newtonsMethod(self): + f = lambda x: (x + 2) * (x - 1) + root = newtonsMethod(f, 0.0, 5.0, maxIterations=10, positiveGuesses=True) + self.assertAlmostEqual(root, 1.0, places=3) + root = newtonsMethod(f, 0.0, -10.0, maxIterations=10) + self.assertAlmostEqual(root, -2.0, places=3) + + def test_parabola(self): + # test the parabola function + a, b, c = parabolaFromPoints((0, 1), (1, 2), (-1, 2)) + self.assertEqual(a, 1.0) + self.assertEqual(b, 0.0) + self.assertEqual(c, 1.0) + + with self.assertRaises(Exception): + a, b, c = parabolaFromPoints((0, 1), (0, 1), (-1, 2)) + + def test_parabolicInterpolation(self): + realRoots = parabolicInterpolation(2.0e-6, -5.0e-4, 1.02, 1.0) + self.assertAlmostEqual(realRoots[0][0], 200.0) + self.assertAlmostEqual(realRoots[0][1], 3.0e-4) + self.assertAlmostEqual(realRoots[1][0], 50.0) + self.assertAlmostEqual(realRoots[1][1], -3.0e-4) + noRoots = parabolicInterpolation(2.0e-6, -4.0e-4, 1.03, 1.0) + self.assertAlmostEqual(noRoots[0][0], -100.0) + self.assertAlmostEqual(noRoots[0][1], 0.0) + # 3. run time error + with self.assertRaises(RuntimeError): + _ = parabolicInterpolation(2.0e-6, 4.0e-4, 1.02, 1.0) + + def test_relErr(self): + self.assertAlmostEqual(relErr(1.00, 1.01), 0.01) + self.assertAlmostEqual(relErr(100.0, 97.0), -0.03) + self.assertAlmostEqual(relErr(0.00, 1.00), -1e99) + def test_resampleStepwiseAvg0(self): """Test resampleStepwise() averaging when in and out bins match""" xin = [0, 1, 2, 13.3] @@ -310,6 +466,26 @@ def test_resampleStepwiseAvgComplicatedNone(self): self.assertIsNone(yout[4]) self.assertEqual(yout[5], 38.5) + def test_rotateXY(self): + x = [1.0, -1.0] + y = [1.0, 1.0] + + # test operation on scalar + xr, yr = rotateXY(x[0], y[0], 45.0) + self.assertAlmostEqual(xr, 0.0) + self.assertAlmostEqual(yr, sqrt(2)) + + xr, yr = rotateXY(x[1], y[1], 45.0) + self.assertAlmostEqual(xr, -sqrt(2)) + self.assertAlmostEqual(yr, 0.0) + + # test operation on list + xr, yr = rotateXY(x, y, 45.0) + self.assertAlmostEqual(xr[0], 0.0) + self.assertAlmostEqual(yr[0], sqrt(2)) + self.assertAlmostEqual(xr[1], -sqrt(2)) + self.assertAlmostEqual(yr[1], 0.0) + if __name__ == "__main__": unittest.main() diff --git a/armi/utils/tests/test_utils.py b/armi/utils/tests/test_utils.py index f109d7396..77fc46d7f 100644 --- a/armi/utils/tests/test_utils.py +++ b/armi/utils/tests/test_utils.py @@ -24,87 +24,7 @@ from armi.utils import directoryChangers -class Utils_TestCase(unittest.TestCase): - def test_parabola(self): - # test the parabola function - a, b, c = utils.parabolaFromPoints((0, 1), (1, 2), (-1, 2)) - self.assertEqual(a, 1.0) - self.assertEqual(b, 0.0) - self.assertEqual(c, 1.0) - - with self.assertRaises(Exception): - a, b, c = utils.parabolaFromPoints((0, 1), (0, 1), (-1, 2)) - - def test_findClosest(self): - l1 = range(10) - self.assertEqual(utils.findClosest(l1, 5.6), 6) - self.assertEqual(utils.findClosest(l1, 10.1), 9) - self.assertEqual(utils.findClosest(l1, -200), 0) - - # with index - self.assertEqual(utils.findClosest(l1, 5.6, indx=True), (6, 6)) - - def test_linearInterpolation(self): - y = utils.linearInterpolation(1.0, 2.0, 3.0, 4.0, targetX=20.0) - x = utils.linearInterpolation(1.0, 2.0, 3.0, 4.0, targetY=y) - - x2 = utils.linearInterpolation(1.0, 1.0, 2.0, 2.0, targetY=50) - - self.assertEqual(x, 20.0) - self.assertEqual(x2, 50.0) - - with self.assertRaises(ZeroDivisionError): - _ = utils.linearInterpolation(1.0, 1.0, 1.0, 2.0) - - def test_parabolicInterpolation(self): - realRoots = utils.parabolicInterpolation(2.0e-6, -5.0e-4, 1.02, 1.0) - self.assertAlmostEqual(realRoots[0][0], 200.0) - self.assertAlmostEqual(realRoots[0][1], 3.0e-4) - self.assertAlmostEqual(realRoots[1][0], 50.0) - self.assertAlmostEqual(realRoots[1][1], -3.0e-4) - noRoots = utils.parabolicInterpolation(2.0e-6, -4.0e-4, 1.03, 1.0) - self.assertAlmostEqual(noRoots[0][0], -100.0) - self.assertAlmostEqual(noRoots[0][1], 0.0) - # 3. run time error - with self.assertRaises(RuntimeError): - _ = utils.parabolicInterpolation(2.0e-6, 4.0e-4, 1.02, 1.0) - - def test_rotateXY(self): - x = [1.0, -1.0] - y = [1.0, 1.0] - - # test operation on scalar - xr, yr = utils.rotateXY(x[0], y[0], 45.0) - self.assertAlmostEqual(xr, 0.0) - self.assertAlmostEqual(yr, math.sqrt(2)) - - xr, yr = utils.rotateXY(x[1], y[1], 45.0) - self.assertAlmostEqual(xr, -math.sqrt(2)) - self.assertAlmostEqual(yr, 0.0) - - # test operation on list - xr, yr = utils.rotateXY(x, y, 45.0) - self.assertAlmostEqual(xr[0], 0.0) - self.assertAlmostEqual(yr[0], math.sqrt(2)) - self.assertAlmostEqual(xr[1], -math.sqrt(2)) - self.assertAlmostEqual(yr[1], 0.0) - - def test_findNearestValue(self): - searchList = [0.1, 0.2, 0.25, 0.35, 0.4] - searchValue = 0.225 - self.assertEqual(utils.findNearestValue(searchList, searchValue), 0.2) - searchValue = 0.226 - self.assertEqual(utils.findNearestValue(searchList, searchValue), 0.25) - searchValue = 0.0 - self.assertEqual(utils.findNearestValue(searchList, searchValue), 0.1) - searchValue = 10 - self.assertEqual(utils.findNearestValue(searchList, searchValue), 0.4) - - def test_expandRepeatedFloats(self): - repeatedFloats = ["150", "2R", 200.0, 175, "4r", 180.0, "0R"] - expectedFloats = [150] * 3 + [200] + [175] * 5 + [180] - self.assertEqual(utils.expandRepeatedFloats(repeatedFloats), expectedFloats) - +class TestGeneralUtils(unittest.TestCase): def test_mergeableDictionary(self): mergeableDict = utils.MergeableDict() normalDict = {"luna": "thehusky", "isbegging": "fortreats", "right": "now"} @@ -141,16 +61,6 @@ def test_createFormattedStrWithDelimiter(self): ) self.assertEqual(outputStr, "") - def test_fixThreeDigitExp(self): - fixed = utils.fixThreeDigitExp("-9.03231714805651E+101") - self.assertEqual(-9.03231714805651e101, fixed) - fixed = utils.fixThreeDigitExp("9.03231714805651-101") - self.assertEqual(9.03231714805651e-101, fixed) - fixed = utils.fixThreeDigitExp("-2.4594981981654+101") - self.assertEqual(-2.4594981981654e101, fixed) - fixed = utils.fixThreeDigitExp("-2.4594981981654-101") - self.assertEqual(-2.4594981981654e-101, fixed) - def test_capStrLen(self): # Test with strings str1 = utils.capStrLen("sodium", 5) @@ -185,64 +95,18 @@ def test_list2str(self): str2 = utils.list2str(list2, 4, list1, 5) self.assertEqual(str1, str2) - def test_getFloat(self): - self.assertEqual(utils.getFloat(1.0), 1.0) - self.assertEqual(utils.getFloat("1.0"), 1.0) - self.assertIsNone(utils.getFloat("word")) - - def test_relErr(self): - self.assertAlmostEqual(utils.relErr(1.00, 1.01), 0.01) - self.assertAlmostEqual(utils.relErr(100.0, 97.0), -0.03) - self.assertAlmostEqual(utils.relErr(0.00, 1.00), -1e99) - - def test_efmt(self): - self.assertAlmostEqual(utils.efmt("1.0e+001"), "1.0E+01") - self.assertAlmostEqual(utils.efmt("1.0E+01"), "1.0E+01") - def test_slantSplit(self): x1 = utils.slantSplit(10.0, 4.0, 4) x2 = utils.slantSplit(10.0, 4.0, 4, order="high first") self.assertListEqual(x1, [1.0, 2.0, 3.0, 4.0]) self.assertListEqual(x2, [4.0, 3.0, 2.0, 1.0]) - def test_newtonsMethod(self): - f = lambda x: (x + 2) * (x - 1) - root = utils.newtonsMethod(f, 0.0, 5.0, maxIterations=10, positiveGuesses=True) - self.assertAlmostEqual(root, 1.0, places=3) - root = utils.newtonsMethod(f, 0.0, -10.0, maxIterations=10) - self.assertAlmostEqual(root, -2.0, places=3) - - def test_minimizeScalarFunc(self): - f = lambda x: (x + 1) ** 2 - minimum = utils.minimizeScalarFunc(f, -3.0, 10.0, maxIterations=10) - self.assertAlmostEqual(minimum, -1.0, places=3) - minimum = utils.minimizeScalarFunc( - f, -3.0, 10.0, maxIterations=10, positiveGuesses=True - ) - self.assertAlmostEqual(minimum, 0.0, places=3) - def test_prependToList(self): a = ["hello", "world"] b = [1, 2, 3] utils.prependToList(a, b) self.assertListEqual(a, [1, 2, 3, "hello", "world"]) - def test_convertToSlice(self): - slice1 = utils.convertToSlice(2) - self.assertEqual(slice1, slice(2, 3, None)) - slice1 = utils.convertToSlice(2.0, increment=-1) - self.assertEqual(slice1, slice(1, 2, None)) - slice1 = utils.convertToSlice(None) - self.assertEqual(slice1, slice(None, None, None)) - slice1 = utils.convertToSlice([1, 2, 3]) - self.assertTrue(np.allclose(slice1, np.array([1, 2, 3]))) - slice1 = utils.convertToSlice(slice(2, 3, None)) - self.assertEqual(slice1, slice(2, 3, None)) - slice1 = utils.convertToSlice(np.array([1, 2, 3])) - self.assertTrue(np.allclose(slice1, np.array([1, 2, 3]))) - with self.assertRaises(Exception): - slice1 = utils.convertToSlice("slice") - def test_plotMatrix(self): matrix = np.zeros([2, 2], dtype=float) matrix[0, 0] = 1 @@ -257,10 +121,6 @@ def test_plotMatrix(self): utils.plotMatrix(matrix, fname, minV=0, maxV=5, figsize=[3, 4]) utils.plotMatrix(matrix, fname, xticks=xtick, yticks=ytick) - def test_getStepsFromValues(self): - steps = utils.getStepsFromValues([1.0, 3.0, 6.0, 10.0], prevValue=0.0) - self.assertListEqual(steps, [1.0, 2.0, 3.0, 4.0]) - if __name__ == "__main__": unittest.main() diff --git a/doc/conf.py b/doc/conf.py index f2eff6c59..578aabf3c 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -26,11 +26,11 @@ # All configuration values have a default; values that are commented out # serve to show the default. +import os import pathlib import re -import warnings import sys -import os +import warnings import sphinx_rtd_theme from sphinx.domains.python import PythonDomain @@ -41,9 +41,9 @@ # Also add to os.environ which will be used by the nbsphinx extension environment os.environ["PYTHONPATH"] = PYTHONPATH import armi -from armi.context import RES from armi import apps from armi.bookkeeping import tests as bookkeepingTests +from armi.context import RES from armi.utils.dochelpers import * # Configure the baseline framework "App" for framework doc building @@ -61,7 +61,6 @@ APIDOC_REL = ".apidocs" SOURCE_DIR = os.path.join("..", "armi") -APIDOC_DIR = APIDOC_REL _TUTORIAL_FILES = [ pathlib.Path(SOURCE_DIR) / "tests" / "tutorials" / fName for fName in bookkeepingTests.TUTORIAL_FILES @@ -405,7 +404,7 @@ def setup(app): "default_thumb_file": os.path.join(RES, "images", "TerraPowerLogo.png"), } -suppress_warnings: ["autoapi.python_import_resolution"] +suppress_warnings = ["autoapi.python_import_resolution"] # filter out this warning which shows up in sphinx-gallery builds. # this is suggested in the sphinx-gallery example but doesn't actually work? diff --git a/doc/release/0.2.rst b/doc/release/0.2.rst index 0b404907a..831c82f36 100644 --- a/doc/release/0.2.rst +++ b/doc/release/0.2.rst @@ -11,6 +11,7 @@ What's new in ARMI ------------------ #. Added neutronics settings: ``inners`` and ``outers`` for downstream support. #. Removed unused Thermal Hydraulics settings. +#. Minor code re-org, moving math utilities into their own module. #. TBD Bug fixes