diff --git a/simpunch/tests/test_starfield.py b/simpunch/tests/test_starfield.py index 9ef307b..0a89c4f 100644 --- a/simpunch/tests/test_starfield.py +++ b/simpunch/tests/test_starfield.py @@ -34,21 +34,26 @@ def _sample_ndcube(shape: tuple, code: str = "PM1", level: str = "0") -> NDCube: @pytest.fixture -def sample_ndcollection(shape: tuple) -> NDCollection: - wcs = WCS(naxis=2) - wcs.wcs.ctype = "HPLN-ARC", "HPLT-ARC" - wcs.wcs.cunit = "deg", "deg" - wcs.wcs.cdelt = 0.1, 0.1 - wcs.wcs.crpix = 0, 0 - wcs.wcs.crval = 1, 1 - wcs.wcs.cname = "HPC lon", "HPC lat" - - input_data = NDCube(np.random.random(shape).astype(np.float32), wcs=wcs) - return NDCollection( - [("-60.0 deg", input_data), - ("0.0 deg", input_data), - ("60.0 deg", input_data)], - aligned_axes="all") +def sample_ndcollection(): + def _create_sample_ndcollection(shape: tuple) -> NDCollection: + wcs = WCS(naxis=2) + wcs.wcs.ctype = ["HPLN-ARC", "HPLT-ARC"] + wcs.wcs.cunit = ["deg", "deg"] + wcs.wcs.cdelt = [0.1, 0.1] + wcs.wcs.crpix = [shape[1] // 2, shape[0] // 2] # Center pixel for the WCS + wcs.wcs.crval = [1, 1] + wcs.wcs.cname = ["HPC lon", "HPC lat"] + + input_data = NDCube(np.random.random(shape).astype(np.float32), wcs=wcs) + + return NDCollection( + [("-60.0 deg", input_data), + ("0.0 deg", input_data), + ("60.0 deg", input_data)], + aligned_axes="all" + ) + return _create_sample_ndcollection + def test_starfield(sample_ndcube: NDCube) -> None: @@ -62,8 +67,10 @@ def test_starfield(sample_ndcube: NDCube) -> None: def test_polarized_starfield(sample_ndcollection: NDCollection) -> None: """Test polarized starfield generation.""" - input_data = sample_ndcollection((2048, 2048)) + shape = (2048, 2048) + input_data = sample_ndcollection(shape) output_data = add_starfield_polarized(input_data) assert isinstance(output_data, NDCollection) + assert len(output_data) == 3