Skip to content

Commit

Permalink
Fix position + velocity control in ros2_control (#415)
Browse files Browse the repository at this point in the history
* Fix position + velocity control in ros2_control

* Fix camera FoV

* fix

* Remove timer

* Add acceleration

* change camera info qos

* fix rate

* fix
  • Loading branch information
lukicdarkoo authored Mar 22, 2022
1 parent 3a71e88 commit 2ad205c
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ namespace webots_ros2_control
webots_ros2_driver::WebotsNode *mNode;
std::shared_ptr<pluginlib::ClassLoader<Ros2ControlSystemInterface>> mHardwareLoader;
std::shared_ptr<controller_manager::ControllerManager> mControllerManager;
double mControlPeriodMs;
double mLastControlUpdateMs;
int mControlPeriodMs;
int mLastControlUpdateMs;

std::thread mThreadExecutor;
rclcpp::executors::MultiThreadedExecutor::SharedPtr mExecutor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace webots_ros2_control
double velocityCommand;
double velocity;
double effortCommand;
double effort;
double acceleration;
bool controlPosition;
bool controlVelocity;
bool controlEffort;
Expand Down
14 changes: 10 additions & 4 deletions webots_ros2_control/src/Ros2Control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ namespace webots_ros2_control

void Ros2Control::step()
{
const double nowMs = mNode->get_clock()->now().seconds() * 1000.0;
if (nowMs >= mLastControlUpdateMs + mControlPeriodMs)
const int nowMs = mNode->robot()->getTime() * 1000.0;
const int periodMs = nowMs - mLastControlUpdateMs;
if (periodMs >= mControlPeriodMs)
{
mControllerManager->read();
#if FOXY
mControllerManager->update();
#else
rclcpp::Duration dt = rclcpp::Duration::from_nanoseconds(RCL_MS_TO_NS(mNode->robot()->getBasicTimeStep()));
const rclcpp::Duration dt = rclcpp::Duration::from_seconds(mControlPeriodMs / 1000.0);
mControllerManager->update(mNode->get_clock()->now(), dt);
#endif
mLastControlUpdateMs = nowMs;
Expand Down Expand Up @@ -100,7 +101,12 @@ namespace webots_ros2_control
// Update rate
const int updateRate = mControllerManager->get_parameter("update_rate").as_int();
mControlPeriodMs = (1.0 / updateRate) * 1000.0;
if (abs(mControlPeriodMs - mNode->robot()->getBasicTimeStep()) > CONTROLLER_MANAGER_ALLOWED_SAMPLE_ERROR_MS)

int controlPeriodProductMs = mNode->robot()->getBasicTimeStep();
while (controlPeriodProductMs < mControlPeriodMs)
controlPeriodProductMs += mNode->robot()->getBasicTimeStep();

if (abs(controlPeriodProductMs - mControlPeriodMs) > CONTROLLER_MANAGER_ALLOWED_SAMPLE_ERROR_MS)
RCLCPP_WARN_STREAM(node->get_logger(), "Desired controller update period (" << mControlPeriodMs << "ms / " << updateRate << "Hz) is different from the Webots timestep (" << mNode->robot()->getBasicTimeStep() << "ms). Please adjust the `update_rate` parameter in the `controller_manager` or the `basicTimeStep` parameter in the Webots `WorldInfo` node.");

// Spin
Expand Down
119 changes: 71 additions & 48 deletions webots_ros2_control/src/Ros2ControlSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,62 +43,69 @@ namespace webots_ros2_control
if (!joint.sensor && !joint.motor)
throw std::runtime_error("Cannot find a Motor or PositionSensor with name " + joint.name);

// Initialize the state
joint.controlPosition = false;
joint.controlVelocity = false;
joint.controlEffort = false;
joint.positionCommand = NAN;
joint.velocityCommand = NAN;
joint.effortCommand = NAN;
joint.position = NAN;
joint.velocity = NAN;
joint.acceleration = NAN;

// Configure the command interface
for (hardware_interface::InterfaceInfo commandInterface : component.command_interfaces)
{
if (commandInterface.name == "position")
joint.controlPosition = true;
else if (commandInterface.name == "velocity")
{
joint.controlVelocity = true;
if (joint.motor)
{
joint.motor->setPosition(INFINITY);
joint.motor->setVelocity(0.0);
}
}
else if (commandInterface.name == "effort")
joint.controlEffort = true;
else
throw std::runtime_error("Invalid hardware info name `" + commandInterface.name + "`");
}
if (joint.motor && joint.controlVelocity && !joint.controlPosition)
{
joint.motor->setPosition(INFINITY);
joint.motor->setVelocity(0.0);
}

mJoints.push_back(joint);
}
}

#if FOXY
hardware_interface::return_type Ros2ControlSystem::configure(const hardware_interface::HardwareInfo &info)
#if FOXY
hardware_interface::return_type Ros2ControlSystem::configure(const hardware_interface::HardwareInfo &info)
{
if (configure_default(info) != hardware_interface::return_type::OK)
{
if (configure_default(info) != hardware_interface::return_type::OK)
{
return hardware_interface::return_type::ERROR;
}
status_ = hardware_interface::status::CONFIGURED;
return hardware_interface::return_type::OK;
return hardware_interface::return_type::ERROR;
}
#else
CallbackReturn Ros2ControlSystem::on_init(const hardware_interface::HardwareInfo &info)
status_ = hardware_interface::status::CONFIGURED;
return hardware_interface::return_type::OK;
}
#else
CallbackReturn Ros2ControlSystem::on_init(const hardware_interface::HardwareInfo &info)
{
if (hardware_interface::SystemInterface::on_init(info) != CallbackReturn::SUCCESS)
{
if (hardware_interface::SystemInterface::on_init(info) != CallbackReturn::SUCCESS)
{
return CallbackReturn::ERROR;
}
return CallbackReturn::SUCCESS;
return CallbackReturn::ERROR;
}
#endif
return CallbackReturn::SUCCESS;
}
#endif

std::vector<hardware_interface::StateInterface> Ros2ControlSystem::export_state_interfaces()
{
std::vector<hardware_interface::StateInterface> interfaces;
for (Joint &joint : mJoints)
if (joint.sensor)
if (joint.sensor) {
interfaces.emplace_back(hardware_interface::StateInterface(joint.name, hardware_interface::HW_IF_POSITION, &(joint.position)));
interfaces.emplace_back(hardware_interface::StateInterface(joint.name, hardware_interface::HW_IF_VELOCITY, &(joint.velocity)));
interfaces.emplace_back(hardware_interface::StateInterface(joint.name, hardware_interface::HW_IF_ACCELERATION, &(joint.acceleration)));
}

return interfaces;
}
Expand All @@ -119,36 +126,48 @@ namespace webots_ros2_control
return interfaces;
}

#if FOXY
hardware_interface::return_type Ros2ControlSystem::start()
{
status_ = hardware_interface::status::STARTED;
return hardware_interface::return_type::OK;
}
#if FOXY
hardware_interface::return_type Ros2ControlSystem::start()
{
status_ = hardware_interface::status::STARTED;
return hardware_interface::return_type::OK;
}

hardware_interface::return_type Ros2ControlSystem::stop()
{
status_ = hardware_interface::status::STOPPED;
return hardware_interface::return_type::OK;
}
#else
CallbackReturn Ros2ControlSystem::on_activate(const rclcpp_lifecycle::State & /*previous_state*/)
{
return CallbackReturn::SUCCESS;
}
hardware_interface::return_type Ros2ControlSystem::stop()
{
status_ = hardware_interface::status::STOPPED;
return hardware_interface::return_type::OK;
}
#else
CallbackReturn Ros2ControlSystem::on_activate(const rclcpp_lifecycle::State & /*previous_state*/)
{
return CallbackReturn::SUCCESS;
}

CallbackReturn Ros2ControlSystem::on_deactivate(const rclcpp_lifecycle::State & /*previous_state*/)
{
return CallbackReturn::SUCCESS;
}
#endif
CallbackReturn Ros2ControlSystem::on_deactivate(const rclcpp_lifecycle::State & /*previous_state*/)
{
return CallbackReturn::SUCCESS;
}
#endif

hardware_interface::return_type Ros2ControlSystem::read()
{
static double lastReadTime = 0;

const double deltaTime = mNode->robot()->getTime() - lastReadTime;
lastReadTime = mNode->robot()->getTime();

for (Joint &joint : mJoints)
{
if (joint.sensor)
joint.position = joint.sensor->getValue();
const double position = joint.sensor->getValue();
const double velocity = std::isnan(joint.position) ? NAN : (position - joint.position) / deltaTime;

if (joint.sensor) {
if (!std::isnan(joint.velocity))
joint.acceleration = (joint.velocity - velocity) / deltaTime;
joint.velocity = velocity;
joint.position = position;
}
}

return hardware_interface::return_type::OK;
Expand All @@ -163,7 +182,11 @@ namespace webots_ros2_control
if (joint.controlPosition && !std::isnan(joint.positionCommand))
joint.motor->setPosition(joint.positionCommand);
if (joint.controlVelocity && !std::isnan(joint.velocityCommand))
joint.motor->setVelocity(joint.velocityCommand);
{
// In the position control mode the velocity cannot be negative.
const double velocityCommand = joint.controlPosition ? abs(joint.velocityCommand) : joint.velocityCommand;
joint.motor->setVelocity(velocityCommand);
}
if (joint.controlEffort && !std::isnan(joint.effortCommand))
joint.motor->setTorque(joint.effortCommand);
}
Expand Down
3 changes: 1 addition & 2 deletions webots_ros2_driver/include/webots_ros2_driver/WebotsNode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ namespace webots_ros2_driver
public:
WebotsNode(std::string name, webots::Supervisor *robot);
void init();
int step();
webots::Supervisor *robot() { return mRobot; }
std::string urdf() const { return mRobotDescription; };

private:
void timerCallback();
std::unordered_map<std::string, std::string> getDeviceRosProperties(const std::string &name) const;
std::unordered_map<std::string, std::string> getPluginProperties(tinyxml2::XMLElement *pluginElement) const;
void setAnotherNodeParameter(std::string anotherNodeName, std::string parameterName, std::string parameterValue);
Expand All @@ -54,7 +54,6 @@ namespace webots_ros2_driver
std::string mRobotDescription;
bool mSetRobotStatePublisher;

rclcpp::TimerBase::SharedPtr mTimer;
int mStep;
webots::Supervisor *mRobot;
std::vector<std::shared_ptr<PluginInterface>> mPlugins;
Expand Down
7 changes: 5 additions & 2 deletions webots_ros2_driver/src/Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ int main(int argc, char **argv)

std::shared_ptr<webots_ros2_driver::WebotsNode> node = std::make_shared<webots_ros2_driver::WebotsNode>(robotName, robot);
node->init();

rclcpp::spin(node);
while (true) {
if (node->step())
break;
rclcpp::spin_some(node);
}
delete robot;
rclcpp::shutdown();
return 0;
Expand Down
15 changes: 7 additions & 8 deletions webots_ros2_driver/src/WebotsNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ namespace webots_ros2_driver
setAnotherNodeParameter("robot_state_publisher", "robot_description", mRobot->getUrdf());

mStep = mRobot->getBasicTimeStep();
mTimer = this->create_wall_timer(std::chrono::milliseconds(1), std::bind(&WebotsNode::timerCallback, this));

// Load static plugins
// Static plugins are automatically configured based on the robot model.
Expand Down Expand Up @@ -214,19 +213,19 @@ namespace webots_ros2_driver
return plugin;
}

void WebotsNode::timerCallback()
int WebotsNode::step()
{
if (mRobot->step(mStep) == -1) {
mTimer->cancel();
exit(0);
return;
}
const int result = mRobot->step(mStep);
if (result == -1)
return result;
for (std::shared_ptr<PluginInterface> plugin : mPlugins)
plugin->step();

mClockMessage.clock = rclcpp::Time(mRobot->getTime() * 1e9);
mClockPublisher->publish(mClockMessage);
}

return result;
}

void WebotsNode::setAnotherNodeParameter(std::string anotherNodeName, std::string parameterName, std::string parameterValue)
{
Expand Down
16 changes: 9 additions & 7 deletions webots_ros2_driver/src/plugins/static/Ros2Camera.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,18 @@ namespace webots_ros2_driver
mImageMessage.encoding = sensor_msgs::image_encodings::BGRA8;

// CameraInfo publisher
rclcpp::QoS cameraInfoQos(1);
cameraInfoQos.reliable();
cameraInfoQos.transient_local();
cameraInfoQos.keep_last(1);
mCameraInfoPublisher = mNode->create_publisher<sensor_msgs::msg::CameraInfo>(mTopicName + "/camera_info", cameraInfoQos);
mCameraInfoPublisher = mNode->create_publisher<sensor_msgs::msg::CameraInfo>(mTopicName + "/camera_info", rclcpp::SensorDataQoS().reliable());
mCameraInfoMessage.header.stamp = mNode->get_clock()->now();
mCameraInfoMessage.header.frame_id = mFrameName;
mCameraInfoMessage.height = mCamera->getHeight();
mCameraInfoMessage.width = mCamera->getWidth();
mCameraInfoMessage.distortion_model = "plumb_bob";
const double focalLength = (mCamera->getFocalLength() == 0) ? 570.34 : mCamera->getFocalLength();

// Convert FoV to focal length.
// Reference: https://en.wikipedia.org/wiki/Focal_length#In_photography
const double diagonal = sqrt(pow(mCamera->getWidth(), 2) + pow(mCamera->getHeight(), 2));
const double focalLength = 0.5 * diagonal * (cos(0.5 * mCamera->getFov()) / sin(0.5 * mCamera->getFov()));

mCameraInfoMessage.d = {0.0, 0.0, 0.0, 0.0, 0.0};
mCameraInfoMessage.r = {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0};
mCameraInfoMessage.k = {
Expand All @@ -57,7 +58,6 @@ namespace webots_ros2_driver
focalLength, 0.0, (double)mCamera->getWidth() / 2, 0.0,
0.0, focalLength, (double)mCamera->getHeight() / 2, 0.0,
0.0, 0.0, 1.0, 0.0};
mCameraInfoPublisher->publish(mCameraInfoMessage);

// Recognition publisher
if (mCamera->hasRecognition())
Expand Down Expand Up @@ -109,6 +109,8 @@ namespace webots_ros2_driver
publishImage();
if (recognitionSubscriptionsExist)
publishRecognition();
if (mCameraInfoPublisher->get_subscription_count() > 0)
mCameraInfoPublisher->publish(mCameraInfoMessage);
}

void Ros2Camera::publishImage()
Expand Down
10 changes: 4 additions & 6 deletions webots_ros2_driver/src/plugins/static/Ros2RangeFinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ namespace webots_ros2_driver
mImageMessage.encoding = sensor_msgs::image_encodings::TYPE_32FC1;

// CameraInfo publisher
rclcpp::QoS cameraInfoQos(1);
cameraInfoQos.reliable();
cameraInfoQos.transient_local();
cameraInfoQos.keep_last(1);
mCameraInfoPublisher = mNode->create_publisher<sensor_msgs::msg::CameraInfo>(mTopicName + "/camera_info", cameraInfoQos);
mCameraInfoPublisher = mNode->create_publisher<sensor_msgs::msg::CameraInfo>(mTopicName + "/camera_info", rclcpp::SensorDataQoS().reliable());
mCameraInfoMessage.header.stamp = mNode->get_clock()->now();
mCameraInfoMessage.header.frame_id = mFrameName;
mCameraInfoMessage.height = height;
Expand All @@ -63,7 +59,6 @@ namespace webots_ros2_driver
focalLengthX, 0.0, (double)width / 2, 0.0,
0.0, focalLengthY, (double)height / 2, 0.0,
0.0, 0.0, 1.0, 0.0};
mCameraInfoPublisher->publish(mCameraInfoMessage);

// Point cloud publisher
mPointCloudPublisher = mNode->create_publisher<sensor_msgs::msg::PointCloud2>(mTopicName + "/point_cloud", rclcpp::SensorDataQoS().reliable());
Expand Down Expand Up @@ -104,6 +99,9 @@ namespace webots_ros2_driver
publishPointCloud();
}

if (mCameraInfoPublisher->get_subscription_count() > 0)
mCameraInfoPublisher->publish(mCameraInfoMessage);

if (mAlwaysOn)
return;

Expand Down

0 comments on commit 2ad205c

Please sign in to comment.