Skip to content

Commit

Permalink
Merge pull request #77 from KIT-MRT/fix_cost_arbitrator_called_twice
Browse files Browse the repository at this point in the history
Fix cost arbitrator called twice #patch
  • Loading branch information
ll-nick authored Nov 19, 2024
2 parents 091d4f5 + dd6a017 commit 53d764b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 11 deletions.
10 changes: 9 additions & 1 deletion include/arbitration_graphs/arbitrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,16 @@ class Arbitrator : public Behavior<CommandT> {

typename Behavior<SubCommandT>::Ptr behavior_;
FlagsT flags_;
mutable util_caching::Cache<Time, SubCommandT> command_;
mutable util_caching::Cache<Time, VerificationResultT> verificationResult_;

SubCommandT getCommand(const Time& time) const {
if (!command_.cached(time)) {
command_.cache(time, behavior_->getCommand(time));
}
return command_.cached(time).value();
}

bool hasFlag(const FlagsT& flag_to_check) const {
return flags_ & flag_to_check;
}
Expand Down Expand Up @@ -250,4 +258,4 @@ class Arbitrator : public Behavior<CommandT> {
} // namespace arbitration_graphs

#include "internal/arbitrator_impl.hpp"
#include "internal/arbitrator_io.hpp"
#include "internal/arbitrator_io.hpp"
8 changes: 4 additions & 4 deletions include/arbitration_graphs/cost_arbitrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class CostArbitrator : public Arbitrator<CommandT, SubCommandT, VerifierT, Verif


CostArbitrator(const std::string& name = "CostArbitrator", const VerifierT& verifier = VerifierT())
: ArbitratorBase(name, verifier){};
: ArbitratorBase(name, verifier) {};


void addOption(const typename Behavior<SubCommandT>::Ptr& behavior,
Expand Down Expand Up @@ -118,10 +118,10 @@ class CostArbitrator : public Arbitrator<CommandT, SubCommandT, VerifierT, Verif

double cost;
if (isActive) {
cost = option->costEstimator_->estimateCost(option->behavior_->getCommand(time), isActive);
cost = option->costEstimator_->estimateCost(option->getCommand(time), isActive);
} else {
option->behavior_->gainControl(time);
cost = option->costEstimator_->estimateCost(option->behavior_->getCommand(time), isActive);
cost = option->costEstimator_->estimateCost(option->getCommand(time), isActive);
option->behavior_->loseControl(time);
}
option->last_estimated_cost_ = cost;
Expand All @@ -139,4 +139,4 @@ class CostArbitrator : public Arbitrator<CommandT, SubCommandT, VerifierT, Verif
};
} // namespace arbitration_graphs

#include "internal/cost_arbitrator_io.hpp"
#include "internal/cost_arbitrator_io.hpp"
4 changes: 2 additions & 2 deletions include/arbitration_graphs/internal/arbitrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ template <typename CommandT, typename SubCommandT, typename VerifierT, typename
std::optional<SubCommandT> Arbitrator<CommandT, SubCommandT, VerifierT, VerificationResultT>::getAndVerifyCommand(
const typename Option::Ptr& option, const Time& time) const {
try {
const SubCommandT command = option->behavior_->getCommand(time);
const SubCommandT command = option->getCommand(time);

const VerificationResultT verificationResult = verifier_.analyze(time, command);
option->verificationResult_.cache(time, verificationResult);
Expand Down Expand Up @@ -127,4 +127,4 @@ SubCommandT Arbitrator<CommandT, SubCommandT, VerifierT, VerificationResultT>::g
" applicable options passed the verification step!");
}

} // namespace arbitration_graphs
} // namespace arbitration_graphs
30 changes: 30 additions & 0 deletions test/cost_arbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,36 @@ TEST_F(CostArbitratorTest, BasicFunctionality) {
EXPECT_EQ("high_cost", testCostArbitrator.getCommand(time));
}

TEST_F(CostArbitratorTest, CommandCaching) {
testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator);
testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator);
testCostArbitrator.addOption(testBehaviorHighCost, OptionFlags::NO_FLAGS, cost_estimator);
testCostArbitrator.addOption(testBehaviorMidCost, OptionFlags::NO_FLAGS, cost_estimator);

EXPECT_TRUE(testCostArbitrator.checkInvocationCondition(time));
EXPECT_FALSE(testCostArbitrator.checkCommitmentCondition(time));
EXPECT_EQ(0, testBehaviorMidCost->getCommandCounter_);

testCostArbitrator.gainControl(time);

// Even though the cost arbitrator needs to compute the command to estimate the costs, the behaviors getCommand
// should only be called once since the result is cached
EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time));
EXPECT_EQ(1, testBehaviorMidCost->getCommandCounter_);
EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time));
// For a second call to getCommand, we can still use the cached command
EXPECT_EQ(1, testBehaviorMidCost->getCommandCounter_);

time = time + Duration(1);

// The cached command should be invalidated after the time has passed
// Therefore the behavior should be called again once for the new time
EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time));
EXPECT_EQ(2, testBehaviorMidCost->getCommandCounter_);
EXPECT_EQ("mid_cost", testCostArbitrator.getCommand(time));
EXPECT_EQ(2, testBehaviorMidCost->getCommandCounter_);
}

TEST_F(CostArbitratorTest, Printout) {
testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator);
testCostArbitrator.addOption(testBehaviorLowCost, OptionFlags::NO_FLAGS, cost_estimator);
Expand Down
9 changes: 5 additions & 4 deletions test/dummy_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ class DummyBehavior : public Behavior<DummyCommand> {
using Ptr = std::shared_ptr<DummyBehavior>;

DummyBehavior(const bool invocation, const bool commitment, const std::string& name = "DummyBehavior")
: Behavior(name), invocationCondition_{invocation}, commitmentCondition_{commitment}, loseControlCounter_{
0} {};
: Behavior(name), invocationCondition_{invocation}, commitmentCondition_{commitment} {};

DummyCommand getCommand(const Time& time) override {
getCommandCounter_++;
return name_;
}
bool checkInvocationCondition(const Time& time) const override {
Expand All @@ -59,7 +59,8 @@ class DummyBehavior : public Behavior<DummyCommand> {

bool invocationCondition_;
bool commitmentCondition_;
int loseControlCounter_;
int getCommandCounter_{0};
int loseControlCounter_{0};
};

struct DummyResult {
Expand All @@ -76,4 +77,4 @@ struct DummyResult {
inline std::ostream& operator<<(std::ostream& out, const arbitration_graphs_tests::DummyResult& result) {
out << (result.isOk() ? "is okay" : "is not okay");
return out;
}
}

0 comments on commit 53d764b

Please sign in to comment.