Skip to content

Commit

Permalink
Merge branch 'master' into spdhg_sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
MargaretDuff authored Nov 25, 2024
2 parents c8ee858 + 42e5595 commit a346c1e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 76 deletions.
2 changes: 2 additions & 0 deletions Wrappers/Python/cil/utilities/dataexample.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def download_data(cls, data_dir, prompt=True):
with ZipFile(os.path.join(data_dir, cls.ZIP_FILE), 'r') as zip_ref:
zip_ref.extractall(os.path.join(data_dir, cls.FOLDER))
os.remove(os.path.join(data_dir, cls.ZIP_FILE))
if os.path.exists(os.path.join(data_dir, 'md5sums.txt')):
os.remove(os.path.join(data_dir, 'md5sums.txt'))
return True

class BOAT(CILDATA):
Expand Down
166 changes: 90 additions & 76 deletions Wrappers/Python/test/test_out_in_place.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,36 +87,38 @@ def setUp(self):
b_ig = ig.allocate('random')
c = numpy.float64(0.3)
bg = BlockGeometry(ig, ig)

# [(function, geometry, test_proximal, test_proximal_conjugate, test_gradient), ...]
self.func_geom_test_list = [
(IndicatorBox(), ag),
(KullbackLeibler(b=b, backend='numba'), ag),
(KullbackLeibler(b=b, backend='numpy'), ag),
(L1Norm(), ag),
(L1Norm(), ig),
(L1Norm(b=b), ag),
(L1Norm(b=b, weight=b), ag),
(TranslateFunction(L1Norm(), b), ag),
(TranslateFunction(L2NormSquared(), b), ag),
(L2NormSquared(), ag),
(scalar * L2NormSquared(), ag),
(SumFunction(L2NormSquared(), scalar * L2NormSquared()), ag),
(SumScalarFunction(L2NormSquared(), 3), ag),
(ConstantFunction(3), ag),
(ZeroFunction(), ag),
(L2NormSquared(b=b), ag),
(L2NormSquared(), ag),
(LeastSquares(A, b_ig, c, weight_ls), ig),
(LeastSquares(A, b_ig, c), ig),
(WeightedL2NormSquared(weight=b_ig), ig),
(TotalVariation(backend='c', warm_start=False, max_iteration=100), ig),
(TotalVariation(backend='numpy', warm_start=False, max_iteration=100), ig),
(OperatorCompositionFunction(L2NormSquared(), A), ig),
(MixedL21Norm(), bg),
(SmoothMixedL21Norm(epsilon=0.3), bg),
(MixedL11Norm(), bg),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg),
(L1Sparsity(WaveletOperator(ig)), ig)
(IndicatorBox(), ag, True, True, False),
(KullbackLeibler(b=b, backend='numba'), ag, True, True, True),
(KullbackLeibler(b=b, backend='numpy'), ag, True, True, True),
(L1Norm(), ag, True, True, False),
(L1Norm(), ig, True, True, False),
(L1Norm(b=b), ag, True, True, False),
(L1Norm(b=b, weight=b), ag, True, True, False),
(TranslateFunction(L1Norm(), b), ag, True, True, False),
(TranslateFunction(L2NormSquared(), b), ag, True, True, True),
(L2NormSquared(), ag, True, True, True),
(scalar * L2NormSquared(), ag, True, True, True),
(SumFunction(L2NormSquared(), scalar * L2NormSquared()), ag, False, False, True),
(SumScalarFunction(L2NormSquared(), 3), ag, True, True, True),
(ConstantFunction(3), ag, True, True, True),
(ZeroFunction(), ag, True, True, True),
(L2NormSquared(b=b), ag, True, True, True),
(L2NormSquared(), ag, True, True, True),
(LeastSquares(A, b_ig, c, weight_ls), ig, False, False, True),
(LeastSquares(A, b_ig, c), ig, False, False, True),
(WeightedL2NormSquared(weight=b_ig), ig, True, True, True),
(TotalVariation(backend='c', warm_start=False, max_iteration=100), ig, True, True, False),
(TotalVariation(backend='numpy', warm_start=False, max_iteration=100), ig, True, True, False),
(OperatorCompositionFunction(L2NormSquared(), A), ig, False, False, True),
(MixedL21Norm(), bg, True, True, False),
(SmoothMixedL21Norm(epsilon=0.3), bg, False, False, True),
(MixedL11Norm(), bg, True, True, False),
(BlockFunction(L1Norm(),L2NormSquared()), bg, True, True, False),
(BlockFunction(L2NormSquared(),L2NormSquared()), bg, True, True, True),
(L1Sparsity(WaveletOperator(ig)), ig, True, True, False)


]

Expand All @@ -135,62 +137,71 @@ def get_result(self, function, method, x, *args):
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case func."+method+'(data, *args) where func is ' + function.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
return out
except NotImplementedError:
return None
raise NotImplementedError(function.__class__.__name__+" raises a NotImplementedError for "+method)


def in_place_test(self,desired_result, function, method, x, *args, ):
out3 = x.copy()
try:
if method == 'proximal':
function.proximal(out3, *args, out=out3)
elif method == 'proximal_conjugate':
function.proximal_conjugate(out3, *args, out=out3)
elif method == 'gradient':
function.gradient(out3, *args, out=out3)
self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for func."+method+'(data, *args, out=data) where func is ' + function.__class__.__name__+ '. ')

except (InPlaceError, NotImplementedError):
pass

try:
if method == 'proximal':
function.proximal(out3, *args, out=out3)
elif method == 'proximal_conjugate':
function.proximal_conjugate(out3, *args, out=out3)
elif method == 'gradient':
function.gradient(out3, *args, out=out3)
self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for func."+method+'(data, *args, out=data) where func is ' + function.__class__.__name__+ '. ')

except InPlaceError:
pass
except NotImplementedError:
raise NotImplementedError(function.__class__.__name__+" raises a NotImplementedError for "+method)

def out_test(self, desired_result, function, method, x, *args, ):
input = x.copy()
out2=0*(x.copy())
try:
if method == 'proximal':
ret = function.proximal(input, *args, out=out2)
elif method == 'proximal_conjugate':
ret = function.proximal_conjugate(input, *args, out=out2)
elif method == 'gradient':
ret = function.gradient(input, *args, out=out2)
self.assertDataArraysInContainerAllClose(desired_result, out2, rtol=1e-5, msg= "Calculation failed using `out` in func."+method+'(x, *args, out=data) where func is ' + function.__class__.__name__+ '. ')
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case func."+method+'(data, *args, out=out) where func is ' + function.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
self.assertDataArraysInContainerAllClose(desired_result, ret, rtol=1e-5, msg= f"Calculation failed returning with `out` in ret = func.{method}(x, *args, out=data) where func is {function.__class__.__name__}")

except (InPlaceError, NotImplementedError):
pass
try:
if method == 'proximal':
ret = function.proximal(input, *args, out=out2)
elif method == 'proximal_conjugate':
ret = function.proximal_conjugate(input, *args, out=out2)
elif method == 'gradient':
ret = function.gradient(input, *args, out=out2)
self.assertDataArraysInContainerAllClose(desired_result, out2, rtol=1e-5, msg= "Calculation failed using `out` in func."+method+'(x, *args, out=data) where func is ' + function.__class__.__name__+ '. ')
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case func."+method+'(data, *args, out=out) where func is ' + function.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
self.assertDataArraysInContainerAllClose(desired_result, ret, rtol=1e-5, msg= f"Calculation failed returning with `out` in ret = func.{method}(x, *args, out=data) where func is {function.__class__.__name__}")

except InPlaceError:
pass
except NotImplementedError:
raise NotImplementedError(function.__class__.__name__+" raises a NotImplementedError for "+method)



def test_proximal_conjugate_out(self):
for func, geom in self.func_geom_test_list:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal_conjugate', data, 0.5)
self.out_test(result, func, 'proximal_conjugate', data, 0.5)
self.in_place_test(result, func, 'proximal_conjugate', data, 0.5)
for func, geom, _, test_proximal_conj, _ in self.func_geom_test_list:
if test_proximal_conj:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal_conjugate', data, 0.5)
self.out_test(result, func, 'proximal_conjugate', data, 0.5)
self.in_place_test(result, func, 'proximal_conjugate', data, 0.5)

def test_proximal_out(self):
for func, geom in self.func_geom_test_list:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal', data, 0.5)
self.out_test(result, func, 'proximal', data, 0.5)
self.in_place_test(result,func, 'proximal', data, 0.5)
for func, geom, test_proximal, _, _ in self.func_geom_test_list:
if test_proximal:
for data_array in self.data_arrays:
data=geom.allocate(None)
data.fill(data_array)
result=self.get_result(func, 'proximal', data, 0.5)
self.out_test(result, func, 'proximal', data, 0.5)
self.in_place_test(result,func, 'proximal', data, 0.5)

def test_gradient_out(self):
for func, geom in self.func_geom_test_list:
for func, geom, _, _, test_gradient in self.func_geom_test_list:
if test_gradient:
for data_array in self.data_arrays:
print(func.__class__.__name__)
data=geom.allocate(None)
Expand Down Expand Up @@ -263,20 +274,23 @@ def get_result(self, operator, method, x, *args):
self.assertDataArraysInContainerAllClose(input, x, rtol=1e-5, msg= "In case operator."+method+'(data, *args) where operator is ' + operator.__class__.__name__+ 'the input data has been incorrectly affected by the calculation. ')
return out
except NotImplementedError:
return None
raise NotImplementedError(operator.__class__.__name__+" raises a NotImplementedError for "+method)

def in_place_test(self,desired_result, operator, method, x, *args, ):
out3 = x.copy()
try:
if method == 'direct':
operator.direct(out3, *args, out=out3)
elif method == 'adjoint':
operator.adjoint(out3, *args, out=out3)
try:
if method == 'direct':
operator.direct(out3, *args, out=out3)
elif method == 'adjoint':
operator.adjoint(out3, *args, out=out3)

self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for operator."+method+'(data, *args, out=data) where operator is ' + operator.__class__.__name__+ '. ')
self.assertDataArraysInContainerAllClose(desired_result, out3, rtol=1e-5, msg= "In place calculation failed for operator."+method+'(data, *args, out=data) where operator is ' + operator.__class__.__name__+ '. ')

except (InPlaceError, NotImplementedError):
pass
except InPlaceError:
pass
except NotImplementedError:
raise NotImplementedError(operator.__class__.__name__+" raises a NotImplementedError for "+method)


def out_test(self, desired_result, operator, method, x, *args):
Expand Down

0 comments on commit a346c1e

Please sign in to comment.