Skip to content

Commit

Permalink
fix(autoware_lidar_transfusion): set tensor names by matching with pr…
Browse files Browse the repository at this point in the history
…edefined values. (#9057)

* set tensor order using api

Signed-off-by: Samrat Thapa <[email protected]>

* style(pre-commit): autofix

Signed-off-by: Samrat Thapa <[email protected]>

* fix tensor order

Signed-off-by: Samrat Thapa <[email protected]>

* style(pre-commit): autofix

Signed-off-by: Samrat Thapa <[email protected]>

* style fix

Signed-off-by: Samrat Thapa <[email protected]>

* style(pre-commit): autofix

---------

Signed-off-by: Samrat Thapa <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
SamratThapa120 and pre-commit-ci[bot] authored Oct 10, 2024
1 parent 146be20 commit 23bb4a0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct Box3D
float yaw;
};

enum NetworkIO { voxels = 0, num_points, coors, cls_score, dir_pred, bbox_pred, ENUM_SIZE };
enum NetworkIO { voxels = 0, num_points, coors, cls_score, bbox_pred, dir_pred, ENUM_SIZE };

// cspell: ignore divup
template <typename T1, typename T2>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
namespace autoware::lidar_transfusion
{

inline NetworkIO nameToNetworkIO(const char * name)
{
static const std::unordered_map<std::string_view, NetworkIO> name_to_enum = {
{"voxels", NetworkIO::voxels}, {"num_points", NetworkIO::num_points},
{"coors", NetworkIO::coors}, {"cls_score0", NetworkIO::cls_score},
{"bbox_pred0", NetworkIO::bbox_pred}, {"dir_cls_pred0", NetworkIO::dir_pred}};

auto it = name_to_enum.find(name);
if (it != name_to_enum.end()) {
return it->second;
}
throw std::runtime_error("Invalid input name: " + std::string(name));
}

std::ostream & operator<<(std::ostream & os, const ProfileDimension & profile)
{
std::string delim = "";
Expand Down Expand Up @@ -253,8 +267,14 @@ bool NetworkTRT::validateNetworkIO()
<< ". Actual size: " << engine->getNbIOTensors() << "." << std::endl;
throw std::runtime_error("Failed to initialize TRT network.");
}

// Initialize tensors_names_ with null values
tensors_names_.resize(NetworkIO::ENUM_SIZE, nullptr);

// Loop over the tensor names and place them in the correct positions
for (int i = 0; i < NetworkIO::ENUM_SIZE; ++i) {
tensors_names_.push_back(engine->getIOTensorName(i));
const char * name = engine->getIOTensorName(i);
tensors_names_[nameToNetworkIO(name)] = name;
}

// Log the network IO
Expand Down

0 comments on commit 23bb4a0

Please sign in to comment.