Skip to content

Commit

Permalink
Apply code formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
odri committed Aug 5, 2022
1 parent 87769c2 commit d1898b4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 34 deletions.
61 changes: 35 additions & 26 deletions cpuMLP/Interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
///
/// \brief This is the header for Interface class
///
/// \details C++ interface between the control loop and the low-level neural network code
/// \details C++ interface between the control loop and the low-level neural
/// network code
///
//////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -35,7 +36,8 @@ class Interface {
/// \param[in] q_init Initial joint configuration of the robot
///
////////////////////////////////////////////////////////////////////////////////////////////////
void initialize(std::string polDirName, std::string estFileName, Vector12 q_init);
void initialize(std::string polDirName, std::string estFileName,
Vector12 q_init);

////////////////////////////////////////////////////////////////////////////////////////////////
///
Expand All @@ -54,7 +56,8 @@ class Interface {
/// \param[in] gyro Base angular velocities
///
////////////////////////////////////////////////////////////////////////////////////////////////
void update_observation(Vector12 pos, Vector12 vel, Vector3 ori, Vector3 gyro);
void update_observation(Vector12 pos, Vector12 vel, Vector3 ori,
Vector3 gyro);

////////////////////////////////////////////////////////////////////////////////////////////////
///
Expand All @@ -68,7 +71,8 @@ class Interface {

////////////////////////////////////////////////////////////////////////////////////////////////
///
/// \brief Convert roll, pitch and yaw angles into the corresponding rotation matrix
/// \brief Convert roll, pitch and yaw angles into the corresponding rotation
/// matrix
///
/// \param[in] r Roll angle
/// \param[in] p Pitch angle
Expand All @@ -86,10 +90,15 @@ class Interface {

////////////////////////////////////////////////////////////////////////////////////////////////
///
/// \brief Return the computation time to run the observation and control networks
/// \brief Return the computation time to run the observation and control
/// networks
///
////////////////////////////////////////////////////////////////////////////////////////////////
float get_computation_time() { return static_cast<float>(std::chrono::duration_cast<std::chrono::microseconds>(t_end_ - t_start_).count()); }
float get_computation_time() {
return static_cast<float>(
std::chrono::duration_cast<std::chrono::microseconds>(t_end_ - t_start_)
.count());
}

// Control policy
MLP_3<132, 12> policy_;
Expand Down Expand Up @@ -121,7 +130,6 @@ class Interface {
Vector12 bound_pi_;
std::chrono::time_point<std::chrono::steady_clock> t_start_;
std::chrono::time_point<std::chrono::steady_clock> t_end_;

};

Interface::Interface()
Expand All @@ -144,8 +152,8 @@ Interface::Interface()
// Empty
}

void Interface::initialize(std::string polDirName, std::string estFileName, Vector12 q_init) {

void Interface::initialize(std::string polDirName, std::string estFileName,
Vector12 q_init) {
// Control policy
policy_.updateParamFromTxt(polDirName + "full_2000.txt");

Expand All @@ -157,22 +165,22 @@ void Interface::initialize(std::string polDirName, std::string estFileName, Vect
std::ifstream obsMean_file, obsVariance_file;
obsMean_file.open(polDirName + "mean2000.csv");
obsVariance_file.open(polDirName + "var2000.csv");
if(obsMean_file.is_open()) {
for(int i = 0; i < obs_mean_.size(); i++){
if (obsMean_file.is_open()) {
for (int i = 0; i < obs_mean_.size(); i++) {
std::getline(obsMean_file, in_line);
obs_mean_(i) = std::stof(in_line);
}
}
obsMean_file.close();
} else {
throw std::runtime_error("Failed to open obsMean file.");
}

if(obsVariance_file.is_open()) {
for(int i = 0; i < obs_var_.size(); i++){
if (obsVariance_file.is_open()) {
for (int i = 0; i < obs_var_.size(); i++) {
std::getline(obsVariance_file, in_line);
obs_var_(i) = std::stof(in_line);
}
obsVariance_file.close();
obsVariance_file.close();
} else {
throw std::runtime_error("Failed to open obsVariance file.");
}
Expand All @@ -192,23 +200,27 @@ void Interface::initialize(std::string polDirName, std::string estFileName, Vect
// Initial times
t_start_ = std::chrono::steady_clock::now();
t_end_ = std::chrono::steady_clock::now();

}

Vector12 Interface::forward() {

obs_normalized_ = ((obs_ - obs_mean_).array() / (obs_var_ + .1E-8f * Vector132::Ones()).cwiseSqrt().array()).matrix();
obs_normalized_ =
((obs_ - obs_mean_).array() /
(obs_var_ + .1E-8f * Vector132::Ones()).cwiseSqrt().array())
.matrix();
obs_normalized_ = obs_normalized_.cwiseMax(-bound_).cwiseMin(bound_);

pTarget12_ = q_init_ + 0.3f * policy_.forward(obs_normalized_).cwiseMax(-bound_pi_).cwiseMin(bound_pi_);
pTarget12_ = q_init_ + 0.3f * policy_.forward(obs_normalized_)
.cwiseMax(-bound_pi_)
.cwiseMin(bound_pi_);

// Log time
t_end_ = std::chrono::steady_clock::now();

return pTarget12_;
}

void Interface::update_observation(Vector12 pos, Vector12 vel, Vector3 ori, Vector3 gyro) {
void Interface::update_observation(Vector12 pos, Vector12 vel, Vector3 ori,
Vector3 gyro) {
// Log time
t_start_ = std::chrono::steady_clock::now();

Expand Down Expand Up @@ -242,7 +254,6 @@ void Interface::update_observation(Vector12 pos, Vector12 vel, Vector3 ori, Vect
obs_.segment<12>(12 + 12 * 7) = qd_hist_.row(2);
obs_.segment<12>(12 + 12 * 8) = q_pos_error_hist_.row(4);
obs_.segment<12>(12 + 12 * 9) = qd_hist_.row(4);

}

void Interface::update_history(Vector12 pos, Vector12 vel) {
Expand All @@ -265,13 +276,11 @@ void Interface::update_history(Vector12 pos, Vector12 vel) {
// Remember previous actions
preprevious_action_ = previous_action_;
previous_action_ = pTarget12_;

}

Matrix3 Interface::rpyToMatrix(float r, float p, float y) {
typedef Eigen::AngleAxis<float> AngleAxis;
return (AngleAxis(y, Vector3::UnitZ())
* AngleAxis(p, Vector3::UnitY())
* AngleAxis(r, Vector3::UnitX())
).toRotationMatrix();
return (AngleAxis(y, Vector3::UnitZ()) * AngleAxis(p, Vector3::UnitY()) *
AngleAxis(r, Vector3::UnitX()))
.toRotationMatrix();
}
2 changes: 1 addition & 1 deletion cpuMLP/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <cmath>
#include <fstream>
#include <iostream>
#include <string>
#include <memory>
#include <string>

using Vector1 = Eigen::Matrix<float, 1, 1>;
using Vector2 = Eigen::Matrix<float, 2, 1>;
Expand Down
14 changes: 7 additions & 7 deletions cpuMLP/main.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
#include <iostream>
#include <Eigen/Core>
#include <iostream>
#include "Interface.hpp"

int main() {

Eigen::VectorXf input;
input.setRandom(5);

MLP_3<132, 12> net = MLP_3<132, 12>();
//net.load_checkpoint("/home/maractin/Workspace/quadruped-replay/checkpoints/vel_3d/flat/p2/jit_20000.pt");
// net.load_checkpoint("/home/maractin/Workspace/quadruped-replay/checkpoints/vel_3d/flat/p2/jit_20000.pt");
Eigen::Matrix<float, 132, 1> input1 = Eigen::Matrix<float, 132, 1>::Zero();
std::cout<<net.forward(input1)<<std::endl;
std::cout << net.forward(input1) << std::endl;

MLP_2<123, 11> net2 = MLP_2<123, 11>();

std::cout << "----" << std::endl;
Interface test = Interface();
std::string polDirName = "../../tmp_checkpoints/sym_pose/energy/6cm/policy-08-03-01-20-47/";
std::string estDirName = "../../tmp_checkpoints/state_estimation/symmetric_state_estimator.txt";
std::string polDirName =
"../../tmp_checkpoints/sym_pose/energy/6cm/policy-08-03-01-20-47/";
std::string estDirName =
"../../tmp_checkpoints/state_estimation/symmetric_state_estimator.txt";
test.initialize(polDirName, estDirName, Vector12::Ones());

}

0 comments on commit d1898b4

Please sign in to comment.