Skip to content

Commit

Permalink
update TrackHeedSimTool
Browse files Browse the repository at this point in the history
  • Loading branch information
wenxingfang committed Nov 8, 2023
1 parent 553b3aa commit f50cd1a
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Detector/DetDriftChamber/compact/det.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@


<define>
<constant name="tracker_region_rmax" value="1723*mm" />
<constant name="tracker_region_zmax" value="3050*mm" />
<constant name="world_size" value="2226*mm"/>
<constant name="world_x" value="world_size"/>
<constant name="world_y" value="world_size"/>
Expand Down
5 changes: 4 additions & 1 deletion Simulation/DetSimDedx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
find_package(Geant4 REQUIRED ui_all vis_all)
include(${Geant4_USE_FILE})
find_package(Garfield REQUIRED)
message(Garfield::Garfield)
message("libonnxruntime ${OnnxRuntime_LIBRARY}")
message("libonnxruntime include ${OnnxRuntime_INCLUDE_DIR}")
find_package(OnnxRuntime REQUIRED)

message("libonnxruntime ${OnnxRuntime_LIBRARY}")
Expand All @@ -20,7 +22,8 @@ gaudi_add_module(DetSimDedx
EDM4HEP::edm4hep EDM4HEP::edm4hepDict
k4FWCore::k4FWCore
Garfield::Garfield
${OnnxRuntime_LIBRARY}
OnnxRuntime
#${OnnxRuntime_LIBRARY}
#/cvmfs/sft.cern.ch/lcg/views/LCG_103/x86_64-centos7-gcc11-opt/lib/libonnxruntime.so
${CLHEP_LIBRARIES}

Expand Down
20 changes: 20 additions & 0 deletions Simulation/DetSimDedx/src/TrackHeedSimTool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,15 @@ StatusCode TrackHeedSimTool::initialize()
auto num_input_nodes = m_session->GetInputCount();
if(m_debug) std::cout << "num_input_nodes: " << num_input_nodes << std::endl;
for (size_t i = 0; i < num_input_nodes; ++i) {
#if (ORT_API_VERSION >=13)
auto name = m_session->GetInputNameAllocated(i, m_allocator);
m_inputNodeNameAllocatedStrings.push_back(std::move(name));
m_input_node_names.push_back(m_inputNodeNameAllocatedStrings.back().get());
#else
auto name = m_session->GetInputName(i, m_allocator);
m_inputNodeNameAllocatedStrings.push_back(name);
m_input_node_names.push_back(m_inputNodeNameAllocatedStrings.back());
#endif

Ort::TypeInfo type_info = m_session->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
Expand All @@ -441,23 +447,37 @@ StatusCode TrackHeedSimTool::initialize()


if(m_debug) std::cout<< "[" << i << "]"
#if (ORT_API_VERSION >=13)
<< " input_name: " << m_inputNodeNameAllocatedStrings.back().get()
#else
<< " input_name: " << m_inputNodeNameAllocatedStrings.back()
#endif
<< " ndims: " << dims.size()
<< " dims: " << dims_str(dims)
<< std::endl;
}
// prepare the output
size_t num_output_nodes = m_session->GetOutputCount();
for(std::size_t i = 0; i < num_output_nodes; i++) {
#if (ORT_API_VERSION >=13)
auto output_name = m_session->GetOutputNameAllocated(i, m_allocator);
m_outputNodeNameAllocatedStrings.push_back(std::move(output_name));
m_output_node_names.push_back(m_outputNodeNameAllocatedStrings.back().get());
#else
auto output_name = m_session->GetOutputName(i, m_allocator);
m_outputNodeNameAllocatedStrings.push_back(output_name);
m_output_node_names.push_back(m_outputNodeNameAllocatedStrings.back());
#endif
Ort::TypeInfo type_info = m_session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
m_output_node_dims = tensor_info.GetShape();
if(m_debug) std::cout << "[" << i << "]"
#if (ORT_API_VERSION >=13)
<< " output_name: " << m_outputNodeNameAllocatedStrings.back().get()
#else
<< " output_name: " << m_outputNodeNameAllocatedStrings.back()
#endif
<< " ndims: " << m_output_node_dims.size()
<< " dims: " << dims_str(m_output_node_dims)
<< std::endl;
Expand Down
7 changes: 6 additions & 1 deletion Simulation/DetSimDedx/src/TrackHeedSimTool.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include <string>

#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_c_api.h"
using namespace Garfield;

class TrackHeedSimTool: public extends<AlgTool, IDedxSimTool> {
Expand Down Expand Up @@ -135,9 +136,13 @@ class TrackHeedSimTool: public extends<AlgTool, IDedxSimTool> {
std::vector<std::vector<int64_t>> m_input_node_dims;
std::vector<const char*> m_output_node_names;
std::vector<int64_t> m_output_node_dims;
#if (ORT_API_VERSION >=13)
std::vector<Ort::AllocatedStringPtr> m_inputNodeNameAllocatedStrings;
std::vector<Ort::AllocatedStringPtr> m_outputNodeNameAllocatedStrings;

#else
std::vector<const char*> m_inputNodeNameAllocatedStrings;
std::vector<const char*> m_outputNodeNameAllocatedStrings;
#endif

Gaudi::Property<bool> m_sim_pulse { this, "sim_pulse" , true };
Gaudi::Property<std::string> m_model_file{ this, "model", "model_test.onnx"};
Expand Down

0 comments on commit f50cd1a

Please sign in to comment.