Skip to content

Commit

Permalink
fixed errors in rktl_autonomy
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjord committed Feb 10, 2024
1 parent 865418f commit 242fdf4
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 17 deletions.
10 changes: 5 additions & 5 deletions src/rktl_autonomy/nodes/plotter
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Plotter(rclpy.Node):
#rospy.Subscriber('~log', DiagnosticStatus, self.progress_cb)
self.create_subscription(DiagnosticStatus, '~log', self.progress_cb, qos_profile=10)

self.history = None
self.history = []
self.LOG_NAME = None
self.next_plot_episode = self.PLOT_FREQ
self.init_plot()
Expand Down Expand Up @@ -110,15 +110,15 @@ class Plotter(rclpy.Node):
data[item.key] = float(item.value)

if data["episode"] is not None:
if self.history is None:
if self.history is []:
self.history = [data]
else:
self.history.append(data)

if data["episode"] >= self.next_plot_episode:
self.plot()
self.next_plot_episode += self.PLOT_FREQ
self.history = None
self.history = []
else:
#rospy.logerr("Bad progress message.")
self.get_logger().warn("Bad progress message.")
Expand Down Expand Up @@ -175,8 +175,8 @@ class Plotter(rclpy.Node):

# update file
#rospy.loginfo(f"Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
self.get_logger().info("Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
plt.savefig(self.LOG_DIR + self.LOG_NAME)
self.get_logger().info(f"Saving training progress to {self.LOG_DIR}{self.LOG_NAME}")
plt.savefig(self.LOG_DIR + str(self.LOG_NAME))

if __name__ == "__main__":
Plotter()
2 changes: 1 addition & 1 deletion src/rktl_autonomy/nodes/rocket_league_agent
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ env = RocketLeagueInterface(eval=True)

# load the model
# weights = expanduser(rospy.get_param('~weights'))
weights = expanduser(env.node.get_parameter('~weights'))
weights = expanduser(env.node.get_parameter('~weights').get_parameter_value().string_value)
model = PPO.load(weights)

# evaluate in real-time
Expand Down
2 changes: 1 addition & 1 deletion src/rktl_autonomy/rktl_autonomy/_ros_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from threading import Condition
import time, uuid, socket, os

from gym import Env
from gymnasium import Env

import rclpy
from rclpy.duration import Duration
Expand Down
12 changes: 6 additions & 6 deletions src/rktl_autonomy/rktl_autonomy/rocket_league_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@

# package
from rktl_autonomy._ros_interface import ROSInterface
from gym.spaces import Box, Discrete
from gymnasium.spaces import Box, Discrete

# ROS
# import rospy
import rclpy
from rclpy import Node
from rclpy.parameter import Parameter
Expand All @@ -25,7 +24,7 @@

# System
import numpy as np
from tf.transformations import euler_from_quaternion
from transformations import euler_from_quaternion
from enum import IntEnum, unique, auto
from math import pi, tan

Expand Down Expand Up @@ -143,7 +142,7 @@ def __init__(self, eval=False, launch_file=('rktl_autonomy', 'rocket_league_trai
self.node = Node('rocket_league_interface')
# Publishers
# self._command_pub = rospy.Publisher('cars/car0/command', ControlCommand, queue_size=1)
self.node.create_publisher(ControlCommand, 'cars/car0/command', 1)
self._command_pub = self.node.create_publisher(ControlCommand, 'cars/car0/command', 1)
# self._reset_srv = rospy.ServiceProxy('sim_reset', Empty)
self._reset_srv = self.node.create_client(Empty, 'sim_reset')

Expand Down Expand Up @@ -256,14 +255,15 @@ def _get_state(self):
goal_dist_sq = np.sum(np.square(ball[0:2] - np.array([self._FIELD_LENGTH/2, 0])))
reward += self._GOAL_DISTANCE_REWARD * goal_dist_sq

if self._score != 0:
if self._score is not None and self._score != 0:
done = True
if self._score > 0:
reward += self._WIN_REWARD
else:
reward += self._LOSS_REWARD

x, y, __, v, __ = self._car_odom
if self._car_odom is not None:
x, y, __, v, __ = self._car_odom

if self._prev_vel is None:
self._prev_vel = v
Expand Down
2 changes: 1 addition & 1 deletion src/rktl_autonomy/scripts/eval_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
Expand Down
4 changes: 2 additions & 2 deletions src/rktl_autonomy/scripts/train_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
Expand Down Expand Up @@ -39,7 +39,7 @@ def train(n_envs=24, n_saves=100, n_steps=240000000, env_number=0):

# log model weights
freq = n_steps / (n_saves * n_envs)
callback = CheckpointCallback(save_freq=freq, save_path=log_dir)
callback = CheckpointCallback(save_freq=int(freq), save_path=log_dir)

# run training
steps = n_steps
Expand Down
2 changes: 1 addition & 1 deletion src/rktl_autonomy/scripts/tune_rocket_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
All rights reserved.
"""

from rktl_autonomy import RocketLeagueInterface
from rktl_autonomy.rocket_league_interface import RocketLeagueInterface
import numpy as np
from stable_baselines3 import PPO

Expand Down

0 comments on commit 242fdf4

Please sign in to comment.