From 62e39704ded24c6e7e520b182195af7cbc038a10 Mon Sep 17 00:00:00 2001 From: Tushar Kusnur Date: Thu, 2 May 2024 13:48:10 -0400 Subject: [PATCH] Expose op name to sampler dict --- .../ground_truth_models/spot_env/nsrts.py | 29 ++++++++++--------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/predicators/ground_truth_models/spot_env/nsrts.py b/predicators/ground_truth_models/spot_env/nsrts.py index 1cfb9b70e0..ee44f26a59 100644 --- a/predicators/ground_truth_models/spot_env/nsrts.py +++ b/predicators/ground_truth_models/spot_env/nsrts.py @@ -167,6 +167,20 @@ def _prepare_sweeping_sampler(state: State, goal: Set[GroundAtom], return np.array([-0.8, -0.4, home_pose.angle]) +_OPERATOR_NAME_TO_SAMPLER: Dict[str, NSRTSampler] = { + "MoveToHandViewObject": _move_to_hand_view_object_sampler, + "MoveToBodyViewObject": _move_to_body_view_object_sampler, + "MoveToReachObject": _move_to_reach_object_sampler, + "PickObjectFromTop": _pick_object_from_top_sampler, + "PlaceObjectOnTop": _place_object_on_top_sampler, + "DropObjectInside": _drop_object_inside_sampler, + "DropObjectInsideContainerOnTop": _drop_object_inside_sampler, + "DragToUnblockObject": _drag_to_unblock_object_sampler, + "SweepIntoContainer": _sweep_into_container_sampler, + "PrepareContainerForSweeping": _prepare_sweeping_sampler, +} + + class SpotCubeEnvGroundTruthNSRTFactory(GroundTruthNSRTFactory): """Ground-truth NSRTs for the Spot Env.""" @@ -192,21 +206,8 @@ def get_nsrts(env_name: str, types: Dict[str, Type], nsrts = set() - operator_name_to_sampler: Dict[str, NSRTSampler] = { - "MoveToHandViewObject": _move_to_hand_view_object_sampler, - "MoveToBodyViewObject": _move_to_body_view_object_sampler, - "MoveToReachObject": _move_to_reach_object_sampler, - "PickObjectFromTop": _pick_object_from_top_sampler, - "PlaceObjectOnTop": _place_object_on_top_sampler, - "DropObjectInside": _drop_object_inside_sampler, - "DropObjectInsideContainerOnTop": _drop_object_inside_sampler, - "DragToUnblockObject": _drag_to_unblock_object_sampler, - "SweepIntoContainer": _sweep_into_container_sampler, - "PrepareContainerForSweeping": _prepare_sweeping_sampler, - } - for strips_op in env.strips_operators: - sampler = operator_name_to_sampler[strips_op.name] + sampler = _OPERATOR_NAME_TO_SAMPLER[strips_op.name] option = options[strips_op.name] nsrt = strips_op.make_nsrt( option=option,