diff --git a/flepimop/gempyor_pkg/src/gempyor/batch.py b/flepimop/gempyor_pkg/src/gempyor/batch.py index b2d56c99f..bacebf26f 100644 --- a/flepimop/gempyor_pkg/src/gempyor/batch.py +++ b/flepimop/gempyor_pkg/src/gempyor/batch.py @@ -27,7 +27,7 @@ from ._jinja import _render_template_to_file, _render_template_to_temp_file from .file_paths import run_id -from .info import Cluster, get_cluster_info +from .info import get_cluster_info from .logging import get_script_logger from .utils import _format_cli_options, _git_checkout, _git_head, _shutil_which, config from .shared_cli import ( @@ -764,6 +764,7 @@ def _submit_scenario_job( "job_name": job_name, "jobs": job_size.jobs, "outcome_modifiers_scenario": outcome_modifiers_scenario, + "nsamples": kwargs.get("samples", job_size.simulations), "nslots": job_size.simulations, # aka nwalkers "prefix": prefix, "project_path": kwargs["project_path"].absolute(), @@ -844,6 +845,12 @@ def _submit_scenario_job( type=click.IntRange(min=1), help="The number of sequential blocks to run per a job.", ), + click.Option( + param_decls=["--samples", "samples"], + default=None, + type=click.IntRange(min=1), + help="The number of samples to produce. Defaults to all samples.", + ), click.Option( param_decls=["--batch-system", "batch_system"], default=None, diff --git a/flepimop/gempyor_pkg/src/gempyor/templates/inference_command.bash.j2 b/flepimop/gempyor_pkg/src/gempyor/templates/inference_command.bash.j2 index 9a348f259..8763df704 100644 --- a/flepimop/gempyor_pkg/src/gempyor/templates/inference_command.bash.j2 +++ b/flepimop/gempyor_pkg/src/gempyor/templates/inference_command.bash.j2 @@ -4,6 +4,7 @@ flepimop-calibrate --config $CONFIG_PATH \ --project_path $PROJECT_PATH \ --nslots {{ nslots }} \ --niterations {{ simulations }} \ + --nsamples {{ nsamples }} \ --jobs {{ jobs }} \ --id {{ run_id }} \ --prefix {{ prefix}} \