diff --git a/slsim/lens_pop.py b/slsim/lens_pop.py index 846379039..b7cc2a4ab 100644 --- a/slsim/lens_pop.py +++ b/slsim/lens_pop.py @@ -119,14 +119,17 @@ def get_num_sources_tested(self, testarea=None, num_sources_tested_mean=None): num_sources_range = np.random.poisson(lam=num_sources_tested_mean) return num_sources_range - def draw_population(self, kwargs_lens_cuts, speed_factor=1): - """Return full population list of all lenses within the area # TODO: need to - implement a version of it. (improve the algorithm) + def draw_population(self, kwargs_lens_cuts, multi_source=False, speed_factor=1): + """Return full population list of all lenses within the area + + # TODO: need to implement a version of it. (improve the algorithm) :param kwargs_lens_cuts: validity test keywords + :type kwargs_lens_cuts: dict + :param multi_source: A boolean value. If True, considers multi source lensing. + If False, considers single source lensing. The default value is True. :param speed_factor: factor by which the number of deflectors is decreased to speed up the calculations. - :type kwargs_lens_cuts: dict :return: List of Lens instances with parameters of the deflectors and lens and source light. :rtype: list @@ -157,9 +160,12 @@ def draw_population(self, kwargs_lens_cuts, speed_factor=1): while n < num_sources_tested: _source = self._sources.draw_source() if n == 0: - # TODO: this is only consistent for a single source. If there are multiple sources at different redshift, this is not fully acurate - los_class = self.los_pop.draw_los(source_redshift=_source.redshift, - deflector_redshift=_deflector.redshift) + # TODO: this is only consistent for a single source. If there + # are multiple sources at different redshift, this is not fully + # acurate + los_class = self.los_pop.draw_los( + source_redshift=_source.redshift, + deflector_redshift=_deflector.redshift) lens_class = Lens( deflector_class=_deflector, source_class=_source, @@ -170,6 +176,9 @@ def draw_population(self, kwargs_lens_cuts, speed_factor=1): # Check the validity of the lens system if lens_class.validity_test(**kwargs_lens_cuts): valid_sources.append(_source) + # If multi_source is False, stop after finding the first valid source + if not multi_source: + break n += 1 if len(valid_sources) > 0: # Use a single source if only one source is valid, else use diff --git a/tests/test_lens_pop.py b/tests/test_lens_pop.py index 71f1658ce..75e45c5dd 100644 --- a/tests/test_lens_pop.py +++ b/tests/test_lens_pop.py @@ -62,8 +62,10 @@ def gg_lens_pop_instance(): def test_draw_population(gg_lens_pop_instance): lens_pop=gg_lens_pop_instance kwargs_lens_cuts = {} - lens_population = lens_pop.draw_population(kwargs_lens_cuts) + lens_population = lens_pop.draw_population(kwargs_lens_cuts, multi_source=True) + lens_population2 = lens_pop.draw_population(kwargs_lens_cuts, multi_source=False) assert len(lens_population) <= 40 + assert len(lens_population2) <= 40 def test_pes_lens_pop_instance(): cosmo = FlatLambdaCDM(H0=70, Om0=0.3)