Skip to content

Commit

Permalink
Comment algorithm reset in the unittest pipeline.
Browse files Browse the repository at this point in the history
  • Loading branch information
xaviliz committed Jun 19, 2024
1 parent 15eb2c1 commit cb8edc4
Showing 1 changed file with 94 additions and 54 deletions.
148 changes: 94 additions & 54 deletions test/src/unittests/all_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# You should have received a copy of the Affero GNU General Public License
# version 3 along with this program. If not, see http://www.gnu.org/licenses/

from __future__ import absolute_import # For Python 2 compatibility
from __future__ import absolute_import # For Python 2 compatibility

from os.path import join, sep
import os
Expand All @@ -28,14 +28,14 @@
import essentia.streaming

try:
from importlib import reload # Python3
from importlib import reload # Python3
except:
pass

# we don't want to get too chatty when running all the tests
essentia.log.info = False
#essentia.log.debug += essentia.EAll
#essentia.log.debug -= essentia.EConnectors
# essentia.log.debug += essentia.EAll
# essentia.log.debug -= essentia.EConnectors

tests_dir = os.path.dirname(__file__)
if tests_dir:
Expand All @@ -48,69 +48,105 @@


# import the test from the subdirectories which filename match the pattern 'test_*.py'
listAllTests = [ filename.split(sep+'test_') for filename in glob.glob(join('*', 'test_*.py')) ]
listAllTests = [
filename.split(sep + "test_") for filename in glob.glob(join("*", "test_*.py"))
]
for testfile in listAllTests:
testfile[1] = testfile[1][:-3]



def importTest(fullname, strategy = 'import'):
'''Imports or reloads test given its fullname.'''
def importTest(fullname, strategy="import"):
"""Imports or reloads test given its fullname."""
folder, name = fullname
if strategy == 'import':
cmd = 'import unittests.%s.test_%s; setattr(sys.modules[__name__], \'%s\', unittests.%s.test_%s.suite)' % (folder, name, name, folder, name)
elif strategy == 'reload':
cmd1 = 'reload(sys.modules[\'unittests.%s.test_%s\']); ' % (folder, name)
cmd2 = 'setattr(sys.modules[__name__], \'%s\', sys.modules[\'unittests.%s.test_%s\'].suite)' % (name, folder, name)
if strategy == "import":
cmd = (
"import unittests.%s.test_%s; setattr(sys.modules[__name__], '%s', unittests.%s.test_%s.suite)"
% (folder, name, name, folder, name)
)
elif strategy == "reload":
cmd1 = "reload(sys.modules['unittests.%s.test_%s']); " % (folder, name)
cmd2 = (
"setattr(sys.modules[__name__], '%s', sys.modules['unittests.%s.test_%s'].suite)"
% (name, folder, name)
)
cmd = cmd1 + cmd2
else:
raise ValueError('When importing a test, the only strategies allowed are \'import\' and \'reload\'')
raise ValueError(
"When importing a test, the only strategies allowed are 'import' and 'reload'"
)

exec(cmd)


def getTests(names=None, exclude=None, strategy='import'):
allNames = [ name for _, name in listAllTests ]
def getTests(names=None, exclude=None, strategy="import"):
allNames = [name for _, name in listAllTests]
names = names or allNames
tests = [ (folder, name) for folder, name in listAllTests
if name in names and name not in exclude ]
tests = [
(folder, name)
for folder, name in listAllTests
if name in names and name not in exclude
]

for name in names:
if name not in allNames:
print('WARNING: did not find test %s' % name)
for name in (exclude or []):
print("WARNING: did not find test %s" % name)
for name in exclude or []:
if name not in allNames:
print('WARNING: did not find test to exclude %s' % name)
print("WARNING: did not find test to exclude %s" % name)

print('Running tests:')
print("Running tests:")
print(sorted(name for _, name in tests))

if not tests:
raise RuntimeError('No test to execute!')
raise RuntimeError("No test to execute!")

for test in tests:
importTest(test, strategy)

testObjectsList = [ getattr(sys.modules[__name__], testName) for folder, testName in tests ]
testObjectsList = [
getattr(sys.modules[__name__], testName) for folder, testName in tests
]

return unittest.TestSuite(testObjectsList)



def traceCompute(algo, *args, **kwargs):
print('computing algo %s' % algo.name())
print("computing algo %s" % algo.name())
return algo.normalCompute(*args, **kwargs)


def computeResetCompute(algo, *args, **kwargs):
# do skip certain algos, otherwise we'd enter in an infinite loop!!!
audioLoaders = [ 'MonoLoader', 'EqloudLoader', 'EasyLoader', 'AudioLoader' ]
filters = [ 'IIR', 'DCRemoval', 'HighPass', 'LowPass', 'BandPass', 'AllPass',
'BandReject', 'EqualLoudness', 'MovingAverage' ]
special = [ 'FrameCutter', 'OverlapAdd', 'TempoScaleBands', 'TempoTap', 'TempoTapTicks',
'Panning','OnsetDetection', 'MonoWriter', 'Flux', 'StartStopSilence',
'LogSpectrum', 'ClickDetector', 'SNR', 'SaturationDetector', 'Welch',
'FrameBuffer']
audioLoaders = ["MonoLoader", "EqloudLoader", "EasyLoader", "AudioLoader"]
filters = [
"IIR",
"DCRemoval",
"HighPass",
"LowPass",
"BandPass",
"AllPass",
"BandReject",
"EqualLoudness",
"MovingAverage",
]
special = [
"FrameCutter",
"OverlapAdd",
"TempoScaleBands",
"TempoTap",
"TempoTapTicks",
"Panning",
"OnsetDetection",
"MonoWriter",
"Flux",
"StartStopSilence",
"LogSpectrum",
"ClickDetector",
"SNR",
"SaturationDetector",
"Welch",
"FrameBuffer",
]

if algo.name() in audioLoaders + filters + special:
return algo.normalCompute(*args, **kwargs)
Expand All @@ -130,6 +166,7 @@ def algodecorator(algo):

return algodecorator


# recursive helper function that finds outputs connected to pools and calls func
def mapPools(algo, func):
# make a copy first, because func might modify the connections in the for
Expand All @@ -147,8 +184,7 @@ def mapPools(algo, func):
elif isinstance(input, essentia.streaming._StreamConnector):
mapPools(input.input_algo, func)

#else ignore nowhere connections

# else ignore nowhere connections


# For this to work for networks that are connected to a pool, we need to conduct
Expand All @@ -161,28 +197,34 @@ def runResetRun(gen, *args, **kwargs):
# little trick. In particular, we have a test for multiplexer that runs
# multiple generators...
def isValid(algo):
if isinstance(algo, essentia.streaming.VectorInput) and not list(algo.connections.values())[0]:
if (
isinstance(algo, essentia.streaming.VectorInput)
and not list(algo.connections.values())[0]
):
# non-connected VectorInput, we don't want to get too fancy here...
return False
if algo.name() == 'Multiplexer':
if algo.name() == "Multiplexer":
return False
for output, inputs in algo.connections.items():
for inp in inputs:
if isinstance(inp, essentia.streaming._StreamConnector) and not isValid(inp.input_algo):
if isinstance(inp, essentia.streaming._StreamConnector) and not isValid(
inp.input_algo
):
return False
return True

if not isValid(gen):
print('Network is not capable of doing the run/reset/run trick, doing it the normal way...')
print(
"Network is not capable of doing the run/reset/run trick, doing it the normal way..."
)
essentia.run(gen)
return


# 1. Find all the outputs in the network that are connected to pools--aka
# pool feeders and for each pool feeder, disconnect the given pool,
# store it, and connect a dummy pool in its place
def useDummy(algo, output, input):
if not hasattr(output, 'originalPools'):
if not hasattr(output, "originalPools"):
output.originalPools = []
output.dummyPools = []

Expand Down Expand Up @@ -227,33 +269,31 @@ def useOriginal(algo, output, input):
return essentia.run(gen)



def runTests(tests):
result = unittest.TextTestRunner(verbosity=2).run(tests)

# return the number of failures and errors
return len(result.errors) + len(result.failures)


if __name__ == '__main__':
testList = [ t for t in sys.argv[1:] if t[0] != '-' ]
testExclude = [ t[1:] for t in sys.argv[1:] if t[0] == '-' ]
if __name__ == "__main__":
testList = [t for t in sys.argv[1:] if t[0] != "-"]
testExclude = [t[1:] for t in sys.argv[1:] if t[0] == "-"]

print('Running tests normally')
print('-'*70)
print("Running tests normally")
print("-" * 70)
result1 = runTests(getTests(testList, exclude=testExclude))

print('\n\nRunning tests with compute/reset/compute')
print('-'*70)
print("\n\nRunning tests with compute/reset/compute")
print("-" * 70)

setattr(sys.modules['essentia.common'], 'algoDecorator', computeDecorator(computeResetCompute))
# setattr(sys.modules['essentia.common'], 'algoDecorator', computeDecorator(computeResetCompute))
essentia.standard._reloadAlgorithms()
essentia.standard._reloadAlgorithms('essentia_test')
essentia.standard._reloadAlgorithms("essentia_test")

# modify runGenerator behavior
setattr(sys.modules['essentia_test'], 'run', runResetRun)

setattr(sys.modules["essentia_test"], "run", runResetRun)

result2 = runTests(getTests(testList, exclude=testExclude, strategy='reload'))
result2 = runTests(getTests(testList, exclude=testExclude, strategy="reload"))

sys.exit(result1 + result2)

0 comments on commit cb8edc4

Please sign in to comment.