diff --git a/cpuMLP/Interface.hpp b/cpuMLP/Interface.hpp index a859cee..945ef14 100644 --- a/cpuMLP/Interface.hpp +++ b/cpuMLP/Interface.hpp @@ -2,7 +2,8 @@ /// /// \brief This is the header for Interface class /// -/// \details C++ interface between the control loop and the low-level neural network code +/// \details C++ interface between the control loop and the low-level neural +/// network code /// ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -35,7 +36,8 @@ class Interface { /// \param[in] q_init Initial joint configuration of the robot /// //////////////////////////////////////////////////////////////////////////////////////////////// - void initialize(std::string polDirName, std::string estFileName, Vector12 q_init); + void initialize(std::string polDirName, std::string estFileName, + Vector12 q_init); //////////////////////////////////////////////////////////////////////////////////////////////// /// @@ -54,7 +56,8 @@ class Interface { /// \param[in] gyro Base angular velocities /// //////////////////////////////////////////////////////////////////////////////////////////////// - void update_observation(Vector12 pos, Vector12 vel, Vector3 ori, Vector3 gyro); + void update_observation(Vector12 pos, Vector12 vel, Vector3 ori, + Vector3 gyro); //////////////////////////////////////////////////////////////////////////////////////////////// /// @@ -68,7 +71,8 @@ class Interface { //////////////////////////////////////////////////////////////////////////////////////////////// /// - /// \brief Convert roll, pitch and yaw angles into the corresponding rotation matrix + /// \brief Convert roll, pitch and yaw angles into the corresponding rotation + /// matrix /// /// \param[in] r Roll angle /// \param[in] p Pitch angle @@ -86,10 +90,15 @@ class Interface { //////////////////////////////////////////////////////////////////////////////////////////////// /// - /// \brief Return the computation time to run the observation and control networks + /// \brief Return the computation time to run the observation and control + /// networks /// //////////////////////////////////////////////////////////////////////////////////////////////// - float get_computation_time() { return static_cast(std::chrono::duration_cast(t_end_ - t_start_).count()); } + float get_computation_time() { + return static_cast( + std::chrono::duration_cast(t_end_ - t_start_) + .count()); + } // Control policy MLP_3<132, 12> policy_; @@ -121,7 +130,6 @@ class Interface { Vector12 bound_pi_; std::chrono::time_point t_start_; std::chrono::time_point t_end_; - }; Interface::Interface() @@ -144,8 +152,8 @@ Interface::Interface() // Empty } -void Interface::initialize(std::string polDirName, std::string estFileName, Vector12 q_init) { - +void Interface::initialize(std::string polDirName, std::string estFileName, + Vector12 q_init) { // Control policy policy_.updateParamFromTxt(polDirName + "full_2000.txt"); @@ -157,22 +165,22 @@ void Interface::initialize(std::string polDirName, std::string estFileName, Vect std::ifstream obsMean_file, obsVariance_file; obsMean_file.open(polDirName + "mean2000.csv"); obsVariance_file.open(polDirName + "var2000.csv"); - if(obsMean_file.is_open()) { - for(int i = 0; i < obs_mean_.size(); i++){ + if (obsMean_file.is_open()) { + for (int i = 0; i < obs_mean_.size(); i++) { std::getline(obsMean_file, in_line); obs_mean_(i) = std::stof(in_line); - } + } obsMean_file.close(); } else { throw std::runtime_error("Failed to open obsMean file."); } - if(obsVariance_file.is_open()) { - for(int i = 0; i < obs_var_.size(); i++){ + if (obsVariance_file.is_open()) { + for (int i = 0; i < obs_var_.size(); i++) { std::getline(obsVariance_file, in_line); obs_var_(i) = std::stof(in_line); } - obsVariance_file.close(); + obsVariance_file.close(); } else { throw std::runtime_error("Failed to open obsVariance file."); } @@ -192,15 +200,18 @@ void Interface::initialize(std::string polDirName, std::string estFileName, Vect // Initial times t_start_ = std::chrono::steady_clock::now(); t_end_ = std::chrono::steady_clock::now(); - } Vector12 Interface::forward() { - - obs_normalized_ = ((obs_ - obs_mean_).array() / (obs_var_ + .1E-8f * Vector132::Ones()).cwiseSqrt().array()).matrix(); + obs_normalized_ = + ((obs_ - obs_mean_).array() / + (obs_var_ + .1E-8f * Vector132::Ones()).cwiseSqrt().array()) + .matrix(); obs_normalized_ = obs_normalized_.cwiseMax(-bound_).cwiseMin(bound_); - pTarget12_ = q_init_ + 0.3f * policy_.forward(obs_normalized_).cwiseMax(-bound_pi_).cwiseMin(bound_pi_); + pTarget12_ = q_init_ + 0.3f * policy_.forward(obs_normalized_) + .cwiseMax(-bound_pi_) + .cwiseMin(bound_pi_); // Log time t_end_ = std::chrono::steady_clock::now(); @@ -208,7 +219,8 @@ Vector12 Interface::forward() { return pTarget12_; } -void Interface::update_observation(Vector12 pos, Vector12 vel, Vector3 ori, Vector3 gyro) { +void Interface::update_observation(Vector12 pos, Vector12 vel, Vector3 ori, + Vector3 gyro) { // Log time t_start_ = std::chrono::steady_clock::now(); @@ -242,7 +254,6 @@ void Interface::update_observation(Vector12 pos, Vector12 vel, Vector3 ori, Vect obs_.segment<12>(12 + 12 * 7) = qd_hist_.row(2); obs_.segment<12>(12 + 12 * 8) = q_pos_error_hist_.row(4); obs_.segment<12>(12 + 12 * 9) = qd_hist_.row(4); - } void Interface::update_history(Vector12 pos, Vector12 vel) { @@ -265,13 +276,11 @@ void Interface::update_history(Vector12 pos, Vector12 vel) { // Remember previous actions preprevious_action_ = previous_action_; previous_action_ = pTarget12_; - } Matrix3 Interface::rpyToMatrix(float r, float p, float y) { typedef Eigen::AngleAxis AngleAxis; - return (AngleAxis(y, Vector3::UnitZ()) - * AngleAxis(p, Vector3::UnitY()) - * AngleAxis(r, Vector3::UnitX()) - ).toRotationMatrix(); + return (AngleAxis(y, Vector3::UnitZ()) * AngleAxis(p, Vector3::UnitY()) * + AngleAxis(r, Vector3::UnitX())) + .toRotationMatrix(); } \ No newline at end of file diff --git a/cpuMLP/Types.h b/cpuMLP/Types.h index 2d40f78..14b9b58 100644 --- a/cpuMLP/Types.h +++ b/cpuMLP/Types.h @@ -6,8 +6,8 @@ #include #include #include -#include #include +#include using Vector1 = Eigen::Matrix; using Vector2 = Eigen::Matrix; diff --git a/cpuMLP/main.cpp b/cpuMLP/main.cpp index 8f7b096..a389116 100644 --- a/cpuMLP/main.cpp +++ b/cpuMLP/main.cpp @@ -1,23 +1,23 @@ -#include #include +#include #include "Interface.hpp" int main() { - Eigen::VectorXf input; input.setRandom(5); MLP_3<132, 12> net = MLP_3<132, 12>(); - //net.load_checkpoint("/home/maractin/Workspace/quadruped-replay/checkpoints/vel_3d/flat/p2/jit_20000.pt"); + // net.load_checkpoint("/home/maractin/Workspace/quadruped-replay/checkpoints/vel_3d/flat/p2/jit_20000.pt"); Eigen::Matrix input1 = Eigen::Matrix::Zero(); - std::cout< net2 = MLP_2<123, 11>(); std::cout << "----" << std::endl; Interface test = Interface(); - std::string polDirName = "../../tmp_checkpoints/sym_pose/energy/6cm/policy-08-03-01-20-47/"; - std::string estDirName = "../../tmp_checkpoints/state_estimation/symmetric_state_estimator.txt"; + std::string polDirName = + "../../tmp_checkpoints/sym_pose/energy/6cm/policy-08-03-01-20-47/"; + std::string estDirName = + "../../tmp_checkpoints/state_estimation/symmetric_state_estimator.txt"; test.initialize(polDirName, estDirName, Vector12::Ones()); - }