Program Listing for File async_function_handler.hpp
↰ Return to documentation for file (include/realtime_tools/async_function_handler.hpp)
// Copyright 2024 PAL Robotics S.L.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef REALTIME_TOOLS__ASYNC_FUNCTION_HANDLER_HPP_
#define REALTIME_TOOLS__ASYNC_FUNCTION_HANDLER_HPP_
#include <atomic>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "rclcpp/clock.hpp"
#include "rclcpp/duration.hpp"
#include "rclcpp/logging.hpp"
#include "rclcpp/time.hpp"
#include "realtime_tools/realtime_helpers.hpp"
namespace realtime_tools
{
/**
 * @brief Scheduling policy of the AsyncFunctionHandler worker thread.
 *
 * Wraps the unscoped enum Value. It converts implicitly to and from that enum and also
 * to and from the strings "synchronized" and "detached". Any other string maps to UNKNOWN.
 */
class AsyncSchedulingPolicy
{
public:
  enum Value : int8_t {
    UNKNOWN = -1,
    SYNCHRONIZED,
    DETACHED,
  };

  /// A default-constructed policy is UNKNOWN.
  AsyncSchedulingPolicy() = default;

  constexpr AsyncSchedulingPolicy(Value value) : value_(value) {}  // NOLINT(runtime/explicit)

  /**
   * @brief Parse a policy name.
   * @param data_type "synchronized" or "detached"; any other string yields UNKNOWN.
   */
  explicit AsyncSchedulingPolicy(const std::string & data_type)
  {
    if (data_type == "synchronized") {
      value_ = SYNCHRONIZED;
    } else if (data_type == "detached") {
      value_ = DETACHED;
    } else {
      value_ = UNKNOWN;
    }
  }

  constexpr operator Value() const { return value_; }

  // Forbid truthiness checks such as `if (policy)`; compare against a Value instead.
  explicit operator bool() const = delete;

  constexpr bool operator==(AsyncSchedulingPolicy other) const { return value_ == other.value_; }
  constexpr bool operator!=(AsyncSchedulingPolicy other) const { return value_ != other.value_; }
  constexpr bool operator==(Value other) const { return value_ == other; }
  constexpr bool operator!=(Value other) const { return value_ != other; }

  /// @return "synchronized", "detached", or "unknown".
  std::string to_string() const
  {
    switch (value_) {
      case SYNCHRONIZED:
        return "synchronized";
      case DETACHED:
        return "detached";
      default:
        return "unknown";
    }
  }

  /**
   * @brief Factory equivalent to AsyncSchedulingPolicy(data_type).
   *
   * Static because it does not depend on an existing instance. Calling it through an
   * instance still compiles.
   */
  static AsyncSchedulingPolicy from_string(const std::string & data_type)
  {
    return AsyncSchedulingPolicy(data_type);
  }

private:
  Value value_ = UNKNOWN;
};
/**
 * @brief Configuration of an AsyncFunctionHandler.
 *
 * Fields can be set directly or loaded from node parameters with initialize().
 * validate() checks the resulting combination.
 */
struct AsyncFunctionHandlerParams
{
  /**
   * @brief Check the configuration for consistency.
   *
   * Every failure except an UNKNOWN scheduling policy is reported through `logger`.
   *
   * @return true if the configuration is usable, false otherwise.
   * @throws std::runtime_error if scheduling_policy is UNKNOWN.
   */
  bool validate() const
  {
    if (thread_priority < 0 || thread_priority > 99) {
      RCLCPP_ERROR(
        logger, "Invalid thread priority: %d. It should be between 0 and 99.", thread_priority);
      return false;
    }
    if (scheduling_policy == AsyncSchedulingPolicy::DETACHED) {
      if (!clock) {
        RCLCPP_ERROR(logger, "Clock must be set when using DETACHED scheduling policy.");
        return false;
      }
      if (exec_rate == 0u) {
        RCLCPP_ERROR(logger, "Execution rate must be set when using DETACHED scheduling policy.");
        return false;
      }
    }
    if (scheduling_policy == AsyncSchedulingPolicy::UNKNOWN) {
      throw std::runtime_error(
        "AsyncFunctionHandlerParams: scheduling policy is unknown. "
        "Please set it to either 'synchronized' or 'detached'.");
    }
    if (trigger_predicate == nullptr) {
      RCLCPP_ERROR(logger, "The parsed trigger predicate is not valid!");
      return false;
    }
    for (const int & core : cpu_affinity_cores) {
      if (core < 0) {
        RCLCPP_ERROR(logger, "Invalid CPU core id: %d. It should be a non-negative integer.", core);
        return false;
      }
    }
    return true;
  }

  /**
   * @brief Load the fields from the parameters of `node`.
   *
   * Reads `<prefix>thread_priority`, `<prefix>cpu_affinity`, `<prefix>scheduling_policy`,
   * `<prefix>execution_rate` (only when the resulting policy is DETACHED),
   * `<prefix>wait_until_initial_trigger` and `<prefix>print_warnings`.
   * If a parameter is not declared, the corresponding field is left unchanged.
   * A declared `cpu_affinity` replaces the current cores; it does not append to them.
   *
   * @tparam NodeT Handle accessed through `->` (presumably a node shared pointer).
   * @param node Node whose parameters are read.
   * @param prefix Prepended verbatim to every parameter name.
   * @throws std::runtime_error if `<prefix>execution_rate` is read and is not positive.
   */
  template <typename NodeT>
  void initialize(NodeT & node, const std::string & prefix)
  {
    // Integer parameters arrive as 64-bit values. Saturating them into `int` ensures an
    // out-of-range value is rejected by validate() or by the checks below. A plain cast
    // could instead wrap around into a plausible-looking number.
    const auto saturate_to_int = [](auto value) -> int {
      if (value > std::numeric_limits<int>::max()) {
        return std::numeric_limits<int>::max();
      }
      if (value < std::numeric_limits<int>::min()) {
        return std::numeric_limits<int>::min();
      }
      return static_cast<int>(value);
    };

    if (node->has_parameter(prefix + "thread_priority")) {
      thread_priority = saturate_to_int(node->get_parameter(prefix + "thread_priority").as_int());
    }
    if (node->has_parameter(prefix + "cpu_affinity")) {
      const auto cpu_affinity_param =
        node->get_parameter(prefix + "cpu_affinity").as_integer_array();
      // Replace rather than append, so repeated initialize() calls do not accumulate duplicates.
      std::vector<int> cores;
      cores.reserve(cpu_affinity_param.size());
      for (const auto & core : cpu_affinity_param) {
        cores.push_back(saturate_to_int(core));
      }
      cpu_affinity_cores = std::move(cores);
    }
    if (node->has_parameter(prefix + "scheduling_policy")) {
      scheduling_policy =
        AsyncSchedulingPolicy(node->get_parameter(prefix + "scheduling_policy").as_string());
    }
    if (
      scheduling_policy == AsyncSchedulingPolicy::DETACHED &&
      node->has_parameter(prefix + "execution_rate")) {
      const int execution_rate =
        saturate_to_int(node->get_parameter(prefix + "execution_rate").as_int());
      if (execution_rate <= 0) {
        throw std::runtime_error(
          "AsyncFunctionHandler: execution_rate parameter must be positive.");
      }
      exec_rate = static_cast<unsigned int>(execution_rate);
    }
    if (node->has_parameter(prefix + "wait_until_initial_trigger")) {
      wait_until_initial_trigger =
        node->get_parameter(prefix + "wait_until_initial_trigger").as_bool();
    }
    if (node->has_parameter(prefix + "print_warnings")) {
      print_warnings = node->get_parameter(prefix + "print_warnings").as_bool();
    }
  }

  /// Priority of the worker thread; validate() requires it to be in [0, 99].
  int thread_priority = 50;
  /// CPU cores the worker thread is pinned to. If empty, no affinity is set.
  std::vector<int> cpu_affinity_cores = {};
  /// SYNCHRONIZED (run on trigger) or DETACHED (run periodically).
  AsyncSchedulingPolicy scheduling_policy = AsyncSchedulingPolicy::SYNCHRONIZED;
  /// Loop rate in Hz for the DETACHED policy; must be non-zero for that policy.
  unsigned int exec_rate = 0u;
  /// Clock used by the DETACHED policy; must be set for that policy.
  rclcpp::Clock::SharedPtr clock = nullptr;
  /// Logger used for diagnostics.
  rclcpp::Logger logger = rclcpp::get_logger("AsyncFunctionHandler");
  /// In SYNCHRONIZED mode, a trigger only starts a cycle if this returns true.
  std::function<bool()> trigger_predicate = []() { return true; };
  /// Initial paused state. In DETACHED mode, the worker waits for the first trigger.
  bool wait_until_initial_trigger = true;
  /// Enables the DETACHED-loop overrun warnings.
  bool print_warnings = true;
};
/**
 * @brief Runs a user callback on a dedicated worker thread.
 *
 * - SYNCHRONIZED: each successful trigger_async_callback() wakes the worker. The worker
 *   then runs the callback once, with the time and period passed to that trigger.
 * - DETACHED: once resumed, the worker runs the callback periodically at
 *   AsyncFunctionHandlerParams::exec_rate Hz. Each run is timestamped with
 *   AsyncFunctionHandlerParams::clock.
 *
 * Exceptions thrown by the callback are captured on the worker thread. They are rethrown
 * by the next trigger_async_callback() call.
 *
 * @tparam T Return type of the callback. It is stored in a std::atomic<T>, so it must be
 *           trivially copyable.
 */
template <typename T>
class AsyncFunctionHandler
{
public:
  AsyncFunctionHandler() = default;

  /// Stops and joins the worker thread if it is still running.
  ~AsyncFunctionHandler() { stop_thread(); }

  /**
   * @brief Set the callback and the worker thread priority.
   * @param callback Invoked on the worker thread with the callback time and period.
   * @param thread_priority Priority passed to configure_sched_fifo() when the thread starts.
   * @throws std::runtime_error if `callback` is empty or the worker thread is running.
   */
  void init(
    std::function<T(const rclcpp::Time &, const rclcpp::Duration &)> callback,
    int thread_priority = 50)
  {
    if (callback == nullptr) {
      throw std::runtime_error(
        "AsyncFunctionHandler: parsed function to call asynchronously is not valid!");
    }
    if (thread_.joinable()) {
      throw std::runtime_error(
        "AsyncFunctionHandler: Cannot reinitialize while the thread is "
        "running. Please stop the async callback first!");
    }
    async_function_ = std::move(callback);
    thread_priority_ = thread_priority;
  }

  /**
   * @brief Like init(callback, thread_priority), and also sets the trigger predicate.
   * @param trigger_predicate In SYNCHRONIZED mode, a trigger only starts a cycle if this
   *        returns true.
   * @throws std::runtime_error if `trigger_predicate` or `callback` is empty, or if the
   *         worker thread is running.
   */
  void init(
    std::function<T(const rclcpp::Time &, const rclcpp::Duration &)> callback,
    std::function<bool()> trigger_predicate, int thread_priority = 50)
  {
    if (trigger_predicate == nullptr) {
      throw std::runtime_error("AsyncFunctionHandler: parsed trigger predicate is not valid!");
    }
    init(std::move(callback), thread_priority);
    trigger_predicate_ = std::move(trigger_predicate);
  }

  /**
   * @brief Initialize from a full parameter set.
   * @throws std::runtime_error if `params` fails validation (validate() logs the reason),
   *         or for the same reasons as the other overloads.
   */
  void init(
    std::function<T(const rclcpp::Time &, const rclcpp::Duration &)> callback,
    const AsyncFunctionHandlerParams & params)
  {
    // Reject invalid configurations up front. For example, DETACHED without a clock or
    // without an exec_rate would otherwise dereference a null clock or divide by zero later.
    if (!params.validate()) {
      throw std::runtime_error("AsyncFunctionHandler: parsed parameters are not valid!");
    }
    init(std::move(callback), params.trigger_predicate, params.thread_priority);
    params_ = params;
    pause_thread_ = params.wait_until_initial_trigger;
  }

  /**
   * @brief Trigger one cycle (SYNCHRONIZED) or resume a paused worker (DETACHED).
   *
   * SYNCHRONIZED: does not wait for the mutex. The cycle starts only if three conditions
   * hold: the mutex is free, no cycle is in progress, and the trigger predicate returns true.
   * DETACHED: if the worker is paused, it is resumed. Otherwise this call only reports the
   * last return value.
   *
   * @param time Passed to the callback (SYNCHRONIZED only).
   * @param period Passed to the callback (SYNCHRONIZED only).
   * @return {triggered, last callback return value}. In DETACHED mode the flag is always true.
   * @throws std::runtime_error if not initialized. In SYNCHRONIZED mode, also throws if the
   *         thread is not running.
   * @throws Rethrows an exception captured from the callback on the worker thread.
   */
  std::pair<bool, T> trigger_async_callback(
    const rclcpp::Time & time, const rclcpp::Duration & period)
  {
    if (!is_initialized()) {
      throw std::runtime_error("AsyncFunctionHandler: need to be initialized first!");
    }
    // NOTE(review): the worker writes async_exception_ptr_ while holding async_mtx_, but it
    // is read here without the lock. This is a data race if the worker writes concurrently
    // (possible in DETACHED mode). Consider publishing the exception through an atomic flag.
    if (async_exception_ptr_) {
      RCLCPP_ERROR(
        params_.logger, "AsyncFunctionHandler: Exception caught in the async callback thread!");
      std::rethrow_exception(async_exception_ptr_);
    }
    if (params_.scheduling_policy == AsyncSchedulingPolicy::DETACHED) {
      RCLCPP_WARN_ONCE(
        params_.logger,
        "AsyncFunctionHandler is configured with DETACHED scheduling policy. "
        "This means that the async callback may not be synchronized with the main thread. ");
      if (pause_thread_.load(std::memory_order_relaxed)) {
        {
          std::unique_lock<std::mutex> lock(async_mtx_);
          pause_thread_ = false;
          RCLCPP_INFO(params_.logger, "AsyncFunctionHandler: Resuming the async callback thread.");
          async_callback_return_ = T();
          // Back-date previous_time_ by one nominal period. Otherwise the first measured
          // period after resuming would include the time spent paused.
          auto const sync_period = std::chrono::nanoseconds(1'000'000'000 / params_.exec_rate);
          previous_time_ = params_.clock->now() - rclcpp::Duration(sync_period);
        }
        async_callback_condition_.notify_one();
      }
      return std::make_pair(true, async_callback_return_.load(std::memory_order_relaxed));
    }
    if (!is_running()) {
      throw std::runtime_error(
        "AsyncFunctionHandler: need to start the async callback thread first before triggering!");
    }
    // Never wait for the worker: if it currently holds the mutex, this trigger is skipped.
    std::unique_lock<std::mutex> lock(async_mtx_, std::try_to_lock);
    bool trigger_status = false;
    if (lock.owns_lock() && !trigger_in_progress_ && trigger_predicate_()) {
      {
        // Take ownership of the lock so that it is released before the worker is notified.
        std::unique_lock<std::mutex> scoped_lock(std::move(lock));
        trigger_in_progress_ = true;
        current_callback_time_ = time;
        current_callback_period_ = period;
      }
      async_callback_condition_.notify_one();
      trigger_status = true;
    }
    const T return_value = async_callback_return_;
    return std::make_pair(trigger_status, return_value);
  }

  /// @return The value most recently returned by the callback (T() after a reset).
  T get_last_return_value() const { return async_callback_return_; }

  /// @return Time passed to the current or last callback invocation.
  const rclcpp::Time & get_current_callback_time() const { return current_callback_time_; }

  /// @return Period passed to the current or last callback invocation.
  const rclcpp::Duration & get_current_callback_period() const { return current_callback_period_; }

  /// Reset the per-run state. start_thread() calls this before spawning the worker.
  void reset_variables()
  {
    std::unique_lock<std::mutex> lock(async_mtx_);
    stop_async_callback_ = false;
    trigger_in_progress_ = false;
    current_callback_time_ = rclcpp::Time(0, 0, RCL_CLOCK_UNINITIALIZED);
    current_callback_period_ = rclcpp::Duration(0, 0);
    last_execution_time_ = std::chrono::nanoseconds(0);
    async_callback_return_ = T();
    async_exception_ptr_ = nullptr;
  }

  /**
   * @brief Block until no triggered cycle is in progress.
   * @return false if the worker thread is not running (in that case it does not wait).
   */
  bool wait_for_trigger_cycle_to_finish()
  {
    if (is_running()) {
      std::unique_lock<std::mutex> lock(async_mtx_);
      cycle_end_condition_.wait(lock, [this] { return !trigger_in_progress_; });
      return true;
    }
    return false;
  }

  /**
   * @brief Mark the handler as paused and wait until the callback is not executing.
   *
   * In DETACHED mode the worker then blocks until trigger_async_callback() resumes it.
   *
   * @return SYNCHRONIZED: whether the worker thread is running. DETACHED: true if the worker
   *         is running, otherwise the current paused state.
   */
  bool pause_execution()
  {
    RCLCPP_INFO_EXPRESSION(
      params_.logger, !pause_thread_, "AsyncFunctionHandler: Pausing the async callback thread.");
    if (params_.scheduling_policy == AsyncSchedulingPolicy::SYNCHRONIZED) {
      pause_thread_ = true;
      return wait_for_trigger_cycle_to_finish();
    } else {
      if (is_running()) {
        pause_thread_.store(true, std::memory_order_relaxed);
        // The worker holds async_mtx_ while running the callback. Acquiring the mutex
        // therefore waits for an in-flight cycle to complete.
        std::unique_lock<std::mutex> lock(async_mtx_);
        return true;
      }
    }
    return pause_thread_.load(std::memory_order_relaxed);
  }

  /// @return true once both a callback and a trigger predicate are set.
  bool is_initialized() const { return async_function_ && trigger_predicate_; }

  /// Join the worker thread if it is joinable. Does not request it to stop.
  void join_async_callback_thread()
  {
    if (is_running()) {
      thread_.join();
    }
  }

  bool is_running() const { return thread_.joinable(); }
  bool is_stopped() const { return stop_async_callback_.load(std::memory_order_relaxed); }
  bool is_paused() const { return pause_thread_.load(std::memory_order_relaxed); }
  std::thread & get_thread() { return thread_; }
  const std::thread & get_thread() const { return thread_; }
  const AsyncFunctionHandlerParams & get_params() const { return params_; }
  bool is_trigger_cycle_in_progress() const { return trigger_in_progress_; }

  /// Request the worker to stop, wake it, and join it. Does nothing if it is not running.
  void stop_thread()
  {
    if (is_running()) {
      {
        stop_async_callback_.store(true, std::memory_order_relaxed);
        // Taking the mutex after setting the flag prevents a lost wake-up. The worker is
        // either not yet waiting, or waiting and about to be notified.
        std::unique_lock<std::mutex> lock(async_mtx_);
      }
      async_callback_condition_.notify_one();
      thread_.join();
    }
  }

  /// @return Wall-clock (steady_clock) duration of the last callback execution.
  std::chrono::nanoseconds get_last_execution_time() const
  {
    return last_execution_time_.load(std::memory_order_relaxed);
  }

  /**
   * @brief Spawn the worker thread if it is not already running.
   *
   * The worker first tries FIFO scheduling and the configured CPU affinity; failures are
   * only logged. It then runs the loop for the configured scheduling policy.
   *
   * @throws std::runtime_error if the handler is not initialized.
   */
  void start_thread()
  {
    if (!is_initialized()) {
      throw std::runtime_error("AsyncFunctionHandler: need to be initialized first!");
    }
    if (!thread_.joinable()) {
      reset_variables();
      thread_ = std::thread([this]() -> void {
        if (!realtime_tools::configure_sched_fifo(thread_priority_)) {
          RCLCPP_WARN(
            params_.logger,
            "Could not enable FIFO RT scheduling policy. Consider setting up your user to do FIFO "
            "RT scheduling. See "
            "[https://control.ros.org/master/doc/ros2_control/controller_manager/doc/userdoc.html] "
            "for details.");
        }
        if (!params_.cpu_affinity_cores.empty()) {
          const auto affinity_result =
            realtime_tools::set_current_thread_affinity(params_.cpu_affinity_cores);
          RCLCPP_WARN_EXPRESSION(
            params_.logger, !affinity_result.first,
            "Could not set CPU affinity for the async worker thread. Error: %s",
            affinity_result.second.c_str());
          // Success is informational, not a warning.
          RCLCPP_INFO_EXPRESSION(
            params_.logger, affinity_result.first,
            "Async worker thread is successfully pinned to the requested CPU cores!");
        }
        if (params_.scheduling_policy == AsyncSchedulingPolicy::SYNCHRONIZED) {
          execute_synchronized_callback();
        } else {
          execute_detached_callback();
        }
      });
    }
  }

private:
  /// Worker loop for SYNCHRONIZED: run the callback once per trigger until stopped.
  void execute_synchronized_callback()
  {
    while (!stop_async_callback_.load(std::memory_order_relaxed)) {
      {
        std::unique_lock<std::mutex> lock(async_mtx_);
        async_callback_condition_.wait(
          lock, [this] { return trigger_in_progress_ || stop_async_callback_; });
        if (!stop_async_callback_) {
          const auto start_time = std::chrono::steady_clock::now();
          try {
            async_callback_return_ =
              async_function_(current_callback_time_, current_callback_period_);
          } catch (...) {
            // Rethrown to the caller by the next trigger_async_callback().
            async_exception_ptr_ = std::current_exception();
          }
          const auto end_time = std::chrono::steady_clock::now();
          last_execution_time_ =
            std::chrono::duration_cast<std::chrono::nanoseconds>(end_time - start_time);
        }
        trigger_in_progress_ = false;
      }
      cycle_end_condition_.notify_all();
    }
  }

  /// Worker loop for DETACHED: run the callback at exec_rate Hz while not paused.
  void execute_detached_callback()
  {
    // Defensive checks: init(callback, params) already rejects these configurations.
    if (!params_.clock) {
      throw std::runtime_error(
        "AsyncFunctionHandler: Clock must be set when using DETACHED scheduling policy.");
    }
    if (params_.exec_rate == 0u) {
      throw std::runtime_error(
        "AsyncFunctionHandler: Execution rate must be set when using DETACHED scheduling policy.");
    }
    auto const period = std::chrono::nanoseconds(1'000'000'000 / params_.exec_rate);
    if (pause_thread_) {
      std::unique_lock<std::mutex> lock(async_mtx_);
      async_callback_condition_.wait(
        lock, [this] { return !pause_thread_ || stop_async_callback_; });
    }
    // for calculating the measured period of the loop
    previous_time_ = params_.clock->now();
    std::this_thread::sleep_for(period);
    std::chrono::steady_clock::time_point next_iteration_time{std::chrono::steady_clock::now()};
    while (!stop_async_callback_.load(std::memory_order_relaxed)) {
      {
        std::unique_lock<std::mutex> lock(async_mtx_);
        // Restart the schedule after waking up from a pause. Otherwise the first cycle after
        // resuming would report the paused time as an overrun.
        const bool was_paused = pause_thread_.load(std::memory_order_relaxed);
        async_callback_condition_.wait(
          lock, [this] { return !pause_thread_ || stop_async_callback_; });
        if (was_paused) {
          next_iteration_time = std::chrono::steady_clock::now();
        }
        if (!stop_async_callback_) {
          // calculate measured period
          auto const current_time = params_.clock->now();
          auto const measured_period = current_time - previous_time_;
          previous_time_ = current_time;
          current_callback_time_ = current_time;
          current_callback_period_ = measured_period;
          const auto start_time = std::chrono::steady_clock::now();
          try {
            async_callback_return_ = async_function_(current_time, measured_period);
          } catch (...) {
            // Rethrown to the caller by the next trigger_async_callback().
            async_exception_ptr_ = std::current_exception();
          }
          last_execution_time_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
            std::chrono::steady_clock::now() - start_time);
          next_iteration_time += period;
          const auto time_now = std::chrono::steady_clock::now();
          if (next_iteration_time < time_now) {
            // Skip the missed cycles so that the loop realigns with its nominal schedule.
            const double time_diff =
              std::chrono::duration<double, std::milli>(time_now - next_iteration_time).count();
            const double cm_period = 1.e3 / static_cast<double>(params_.exec_rate);
            const int overrun_count = static_cast<int>(std::ceil(time_diff / cm_period));
            if (params_.print_warnings) {
              RCLCPP_WARN_THROTTLE(
                params_.logger, *params_.clock, 1000,
                "Overrun detected! The async callback missed its desired rate of %u Hz. The loop "
                "took %f ms (missed cycles : %d).",
                params_.exec_rate, time_diff + cm_period, overrun_count + 1);
            }
            next_iteration_time += (overrun_count * period);
          }
        }
        trigger_in_progress_ = false;
      }
      cycle_end_condition_.notify_all();
      // Sleep without holding async_mtx_. Otherwise pause_execution(), stop_thread() and
      // trigger_async_callback() would be blocked for the rest of the period.
      if (!stop_async_callback_.load(std::memory_order_relaxed)) {
        std::this_thread::sleep_until(next_iteration_time);
      }
    }
  }

  rclcpp::Time current_callback_time_ = rclcpp::Time(0, 0, RCL_CLOCK_UNINITIALIZED);
  rclcpp::Duration current_callback_period_{0, 0};
  std::function<T(const rclcpp::Time &, const rclcpp::Duration &)> async_function_;
  std::function<bool()> trigger_predicate_ = []() { return true; };

  // Async related variables
  std::thread thread_;
  AsyncFunctionHandlerParams params_;
  rclcpp::Time previous_time_{0, 0, RCL_CLOCK_UNINITIALIZED};
  // Placeholder only: init() always sets this before start_thread() can run, because
  // start_thread() requires is_initialized().
  int thread_priority_ = 0;
  std::atomic_bool stop_async_callback_{false};
  std::atomic_bool trigger_in_progress_{false};
  std::atomic_bool pause_thread_{false};
  // Explicitly initialized: before C++20, a default-constructed std::atomic holds an
  // indeterminate value. The getters can be called before start_thread() resets these.
  std::atomic<T> async_callback_return_{T()};
  std::condition_variable async_callback_condition_;
  std::condition_variable cycle_end_condition_;
  std::mutex async_mtx_;
  std::atomic<std::chrono::nanoseconds> last_execution_time_{std::chrono::nanoseconds(0)};
  std::exception_ptr async_exception_ptr_;
};
} // namespace realtime_tools
#endif // REALTIME_TOOLS__ASYNC_FUNCTION_HANDLER_HPP_