diff --git a/slsim/lens.py b/slsim/lens.py index cd4cf1d77..2ebbd3ee1 100644 --- a/slsim/lens.py +++ b/slsim/lens.py @@ -798,10 +798,16 @@ def lenstronomy_kwargs(self, band=None): lens_mass_model_list ) kwargs_model["z_lens"] = self.deflector_redshift - kwargs_model["source_redshift_list"] = self.source_redshift_list + if self.max_redshift_source_class.light_profile == "single_sersic": + kwargs_model["source_redshift_list"] = self.source_redshift_list + elif self.max_redshift_source_class.light_profile == "double_sersic": + kwargs_model["source_redshift_list"] = [ + z for z in self.source_redshift_list for _ in range(2) + ] kwargs_model["z_source_convention"] = ( self.max_redshift_source_class.redshift ) + kwargs_model["z_source"] = self.max_redshift_source_class.redshift kwargs_model["cosmo"] = self.cosmo sources, sources_kwargs = self.source_light_model_lenstronomy(band=band) @@ -893,19 +899,15 @@ def source_light_model_lenstronomy(self, band=None): ) ) # lets transform list in to required structure - if ( + """if ( self.max_redshift_source_class.light_profile == "double_sersic" and self.source_number > 1 ): source_models_list_restructure = source_models_list kwargs_source_list_restructure = kwargs_source_list - else: - source_models_list_restructure = list( - np.concatenate(source_models_list) - ) - kwargs_source_list_restructure = list( - np.concatenate(kwargs_source_list) - ) + else:""" + source_models_list_restructure = list(np.concatenate(source_models_list)) + kwargs_source_list_restructure = list(np.concatenate(kwargs_source_list)) source_models["source_light_model_list"] = source_models_list_restructure kwargs_source = kwargs_source_list_restructure else: diff --git a/tests/test_image_simulation.py b/tests/test_image_simulation.py index 5226b0adf..a33722d3a 100644 --- a/tests/test_image_simulation.py +++ b/tests/test_image_simulation.py @@ -465,42 +465,38 @@ def setup_method(self): source_class=[self.source1, self.source2], cosmo=self.cosmo, ) - path = os.path.dirname(__file__) - psf_kernel_single = np.load( - os.path.join(path, "TestData/psf_kernels_for_image_1.npy") - ) - image1 = lens_image( - lens_class=lens_class1, + + self.image1 = sharp_image( + lens_class1, band="i", mag_zero_point=27, + delta_pix=0.2, num_pix=64, - psf_kernel=psf_kernel_single, - transform_pix2angle=np.array([[0.2, 0], [0, 0.2]]), - exposure_time=30, - t_obs=10, + with_source=True, + with_deflector=True, ) - image2 = lens_image( - lens_class=lens_class2, + self.image2 = sharp_image( + lens_class2, band="i", mag_zero_point=27, + delta_pix=0.2, num_pix=64, - psf_kernel=psf_kernel_single, - transform_pix2angle=np.array([[0.2, 0], [0, 0.2]]), - exposure_time=30, - t_obs=10, + with_source=True, + with_deflector=False, ) - image3 = lens_image( - lens_class=lens_class3, + self.image3 = sharp_image( + lens_class3, band="i", mag_zero_point=27, + delta_pix=0.2, num_pix=64, - psf_kernel=psf_kernel_single, - transform_pix2angle=np.array([[0.2, 0], [0, 0.2]]), - exposure_time=30, - t_obs=10, + with_source=True, + with_deflector=True, ) - combined_image = image1 + image2 - assert image3 == combined_image + self.combined_image = self.image1 + self.image2 + + def test_image_multiple_source(self): + npt.assert_almost_equal(self.image3, self.combined_image, decimal=8) if __name__ == "__main__":