diff --git a/waymax/visualization/viz.py b/waymax/visualization/viz.py index edc6871..c567aab 100644 --- a/waymax/visualization/viz.py +++ b/waymax/visualization/viz.py @@ -282,9 +282,10 @@ def plot_simulator_state( current_xy = traj.xy[:, state.timestep, :] if viz_config.center_agent_idx == -1: xy = current_xy[state.object_metadata.is_sdc] + origin_x, origin_y = xy[0, :2] else: xy = current_xy[viz_config.center_agent_idx] - origin_x, origin_y = xy[0, :2] + origin_x, origin_y = xy[:2] ax.axis(( origin_x - viz_config.back_x, origin_x + viz_config.front_x,