Skip to content

Commit

Permalink
removed commented out ROS 1 code
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjord committed Feb 11, 2024
1 parent 06b2ce4 commit 0bd7607
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 55 deletions.
4 changes: 0 additions & 4 deletions src/rktl_autonomy/nodes/rocket_league_agent
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,13 @@ License:
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
from stable_baselines3 import PPO
from os.path import expanduser
# import rospy
import rclpy
from rclpy.exceptions import ROSInterruptException



# create interface (and init ROS)
env = RocketLeagueInterface(eval=True)

# load the model
# weights = expanduser(rospy.get_param('~weights'))
weights = expanduser(env.node.get_parameter('~weights').get_parameter_value().string_value)
model = PPO.load(weights)

Expand Down
4 changes: 0 additions & 4 deletions src/rktl_autonomy/rktl_autonomy/_ros_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,9 @@ def __init__(self, node_name='gym_interface', eval=False, launch_file=None, laun
else:
port = None
self.__LOG_ID = f'{run_id}:{port}'
# self.__log_pub = rospy.Publisher('~log', DiagnosticStatus, queue_size=1)
self.__log_pub = self.node.create_publisher(DiagnosticStatus, '~log', qos_profile=1)
self.__episode = 0
self.__net_reward = 0
# self.__start_time = rospy.Time.now()
self.__start_time = self.node.get_clock().now()

def step(self, action):
Expand Down Expand Up @@ -166,7 +164,6 @@ def reset(self):
info = {
'episode' : self.__episode,
'net_reward' : self.__net_reward,
# 'duration' : (rospy.Time.now() - self.__start_time).to_sec()
'duration' : self.node.get_clock().now() - self.__start_time
}
info.update(self._get_state()[3])
Expand Down Expand Up @@ -198,7 +195,6 @@ def __step_time_and_wait_for_state(self, max_retries=1):
retries = 0
while not self.__wait_once_for_state():
if retries >= max_retries:
# rospy.logerr("Failed to get new state.")
self.node.get_logger().warn('Failed to get new state.')
raise SimTimeException
else:
Expand Down
51 changes: 4 additions & 47 deletions src/rktl_autonomy/rktl_autonomy/rocket_league_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,101 +49,66 @@ def __init__(self, eval=False, launch_file=('rktl_autonomy', 'rocket_league_trai
# Constants
self.env_number = env_number
# Actions
# self._MIN_VELOCITY = -rospy.get_param('/cars/throttle/max_speed')
self._MIN_VELOCITY = -self.node.get_parameter('/cars/throttle/max_speed').get_parameter_value().double_value
# self._MAX_VELOCITY = rospy.get_param('/cars/throttle/max_speed')
self._MAX_VELOCITY = self.node.get_parameter('/cars/throttle/max_speed').get_parameter_value().double_value
# self._MIN_CURVATURE = -tan(rospy.get_param('/cars/steering/max_throw')) / rospy.get_param('cars/length')
self._MIN_CURVATURE = -tan(self.node.get_parameter('/cars/steering/max_throw').get_parameter_value().double_value) / self.node.get_parameter('cars/length').get_parameter_value().double_value
# self._MAX_CURVATURE = tan(rospy.get_param('/cars/steering/max_throw')) / rospy.get_param('cars/length')
self._MAX_CURVATURE = tan(self.node.get_parameter('/cars/steering/max_throw').get_parameter_value().double_value) / self.node.get_parameter('cars/length').get_parameter_value().double_value


# Action space overrides
# if rospy.has_param('~action_space/velocity/min'):
if self.node.has_parameter('~action_space/velocity/min'):
# min_velocity = rospy.get_param('~action_space/velocity/min')
min_velocity = self.node.get_parameter('~action_space/velocity/min').get_parameter_value().double_value
assert min_velocity > self._MIN_VELOCITY
self._MIN_VELOCITY = min_velocity
# if rospy.has_param('~action_space/velocity/max'):

if self.node.has_parameter('~action_space/velocity/max'):
# max_velocity = rospy.get_param('~action_space/velocity/max')
max_velocity = self.node.get_parameter('~action_space/velocity/max').get_parameter_value().double_value
assert max_velocity < self._MAX_VELOCITY
self._MAX_VELOCITY = max_velocity
# if rospy.has_param('~action_space/curvature/min'):

if self.node.has_parameter('~action_space/curvature/min'):
# min_curvature = rospy.get_param('~action_space/curvature/min')
min_curvature = self.node.get_parameter('~action_space/curvature/min').get_parameter_value().double_value
assert min_curvature > self._MIN_CURVATURE
self._MIN_CURVATURE = min_curvature
# if rospy.has_param('~action_space/curvature/max'):

if self.node.has_parameter('~action_space/curvature/max'):
# max_curvature = rospy.get_param('~action_space/curvature/max')
max_curvature = self.node.get_parameter('~action_space/curvature/max').get_parameter_value().double_value
assert max_curvature < self._MAX_CURVATURE
self._MAX_CURVATURE = max_curvature

# Observations
# self._FIELD_WIDTH = rospy.get_param('/field/width')
self._FIELD_WIDTH = self.node.get_parameter('/field/width').get_parameter_value().double_value
# self._FIELD_LENGTH = rospy.get_param('/field/length')
self._FIELD_LENGTH = self.node.get_parameter('/field/length').get_parameter_value().double_value
# self._GOAL_DEPTH = rospy.get_param('~observation/goal_depth', 0.075)
self._GOAL_DEPTH = self.node.get_parameter_or('~observation/goal_depth', Parameter(0.075)).get_parameter_value().double_value
# self._MAX_OBS_VEL = rospy.get_param('~observation/velocity/max_abs', 3.0)
self._MAX_OBS_VEL = self.node.get_parameter_or('~observation/velocity/max_abs', Parameter(3.0)).get_parameter_value().double_value
# self._MAX_OBS_ANG_VEL = rospy.get_param('~observation/angular_velocity/max_abs', 2*pi)
self._MAX_OBS_ANG_VEL = self.node.get_parameter_or('~observation/angular_velocity/max_abs', Parameter(2*pi)).get_parameter_value().double_value

# Learning
# self._MAX_TIME = rospy.get_param('~max_episode_time', 30.0)
self._MAX_TIME = self.node.get_parameter_or('~max_episode_time', Parameter(30.0)).get_parameter_value().double_value
# self._CONSTANT_REWARD = rospy.get_param('~reward/constant', 0.0)
self._CONSTANT_REWARD = self.node.get_parameter_or('~reward/constant', Parameter(0.0)).get_parameter_value().double_value
# self._BALL_DISTANCE_REWARD = rospy.get_param('~reward/ball_dist_sq', 0.0)
self._BALL_DISTANCE_REWARD = self.node.get_parameter_or('~reward/ball_dist_sq', Parameter(0.0)).get_parameter_value().double_value
# self._GOAL_DISTANCE_REWARD = rospy.get_param('~reward/goal_dist_sq', 0.0)
self._GOAL_DISTANCE_REWARD = self.node.get_parameter_or('~reward/goal_dist_sq', Parameter(0.0)).get_parameter_value().double_value
# self._DIRECTION_CHANGE_REWARD = rospy.get_param('~reward/direction_change', 0.0)
self._DIRECTION_CHANGE_REWARD = self.node.get_parameter_or('~reward/direction_change', Parameter(0.0)).get_parameter_value().double_value
# if isinstance(rospy.get_param('~reward/win', [100.0]), int):

if isinstance(self.node.get_parameter_or('~reward/win', Parameter([100.0])).get_parameter_value().double_array_value, int):
# self._WIN_REWARD = rospy.get_param('~reward/win', [100.0])
self._WIN_REWARD = self.node.get_parameter_or('~reward/win', Parameter([100.0])).get_parameter_value().double_array_value
else:
# if len(rospy.get_param('~reward/win', [100.0])) >= self.env_number:
if len(self.node.get_parameter_or('~reward/win', Parameter([100.0])).get_parameter_value().double_array_value) >= self.env_number:
# self._WIN_REWARD = rospy.get_param('~reward/win', [100.0])[0]
self._WIN_REWARD = self.node.get_parameter_or('~reward/win', Parameter([100.0])).get_parameter_value().double_array_value[0]
else:
# self._WIN_REWARD = rospy.get_param('~reward/win', [100.0])[self.env_number]
self._WIN_REWARD = self.node.get_parameter_or('~reward/win', Parameter([100.0])).get_parameter_value().double_array_value[self.env_number]
# if isinstance(rospy.get_param('~reward/loss', [100.0]), int):
if isinstance(self.node.get_parameter_or('~reward/loss', Parameter([100.0])).get_parameter_value().double_array_value, int):
# self._LOSS_REWARD = rospy.get_param('~reward/loss', [100.0])
self._LOSS_REWARD = self.node.get_parameter_or('~reward/loss', Parameter([100.0])).get_parameter_value().double_array_value
else:
# if len(rospy.get_param('~reward/loss', [100.0])) >= self.env_number:
if len(self.node.get_parameter_or('~reward/loss', Parameter([100.0])).get_parameter_value().double_array_value) >= self.env_number:
# self._LOSS_REWARD = rospy.get_param('~reward/loss', [100.0])[0]
self._LOSS_REWARD = self.node.get_parameter_or('~reward/loss', Parameter([100.0], type_=8)).get_parameter_value().double_array_value[0]
else:
# self._LOSS_REWARD = rospy.get_param('~reward/loss', [100.0])[self.env_number]
self._LOSS_REWARD = self.node.get_parameter_or('~reward/loss', Parameter([100.0])).get_parameter_value().double_array_value[self.env_number]
# self._REVERSE_REWARD = rospy.get_param('~reward/reverse', 0.0)
self._REVERSE_REWARD = self.node.get_parameter_or('~reward/reverse', Parameter(0.0)).get_parameter_value().double_value
# self._WALL_REWARD = rospy.get_param('~reward/walls/value', 0.0)
self._WALL_REWARD = self.node.get_parameter_or('~reward/walls/value', Parameter(0.0)).get_parameter_value().double_value
# self._WALL_THRESHOLD = rospy.get_param('~reward/walls/threshold', 0.0)
self._WALL_THRESHOLD = self.node.get_parameter_or('~reward/walls/threshold', Parameter(0.0)).get_parameter_value().double_value

self.node = Node('rocket_league_interface')
# Publishers
# self._command_pub = rospy.Publisher('cars/car0/command', ControlCommand, queue_size=1)
self._command_pub = self.node.create_publisher(ControlCommand, 'cars/car0/command', 1)
# self._reset_srv = rospy.ServiceProxy('sim_reset', Empty)
self._reset_srv = self.node.create_client(Empty, 'sim_reset')

# State variables
Expand All @@ -154,16 +119,11 @@ def __init__(self, eval=False, launch_file=('rktl_autonomy', 'rocket_league_trai
self._prev_vel = None

# Subscribers
# rospy.Subscriber('cars/car0/odom', Odometry, self._car_odom_cb)
self.node.create_subscription(Odometry, 'cars/car0/odom', self._car_odom_cb, qos_profile=1)
# rospy.Subscriber('ball/odom', Odometry, self._ball_odom_cb)
self.node.create_subscription(Odometry, 'ball/odom', self._ball_odom_cb, qos_profile=1)
# rospy.Subscriber('match_status', MatchStatus, self._score_cb)
self.node.create_subscription(MatchStatus, 'match_status', self._score_cb, qos_profile=1)

# block until environment is ready
# if not eval:
# rospy.wait_for_service('sim_reset')
if not eval:
while not self._reset_srv.wait_for_service(timeout_sec=1.0):
self.node.get_logger().info('service not available, waiting again...')
Expand Down Expand Up @@ -240,9 +200,7 @@ def _get_state(self):

# check if time exceeded
if self._start_time is None:
# self._start_time = rospy.Time.now()
self._start_time = self.node.get_clock().now()
# done = (rospy.Time.now() - self._start_time).to_sec() >= self._MAX_TIME
done = (self.node.get_clock().now() - self._start_time) >= self._MAX_TIME


Expand Down Expand Up @@ -287,7 +245,6 @@ def _publish_action(self, action):
assert self.action_space.contains(action)

msg = ControlCommand()
# msg.header.stamp = rospy.Time.now()
msg.header.stamp = self.node.get_clock().now()

if ( action == CarActions.FWD or
Expand Down

0 comments on commit 0bd7607

Please sign in to comment.