-
Notifications
You must be signed in to change notification settings - Fork 546
/
ModelRefitter.hpp
108 lines (88 loc) · 4.26 KB
/
ModelRefitter.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include "NvInferRuntime.h"
#include "Status.hpp"
#include "WeightsContext.hpp"
#include "errorHelpers.hpp"
#include <cstddef>
#include <cstdint>
#include <onnx/onnx_pb.h>
#include <set>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
// Logging macros.
//
// LOG_REFITTER(msg, severity) streams `msg` into a std::ostringstream and forwards the resulting text to
// `mLogger->log(...)`. `msg` is inserted verbatim after `ss <<`, so it may be a `<<`-chained expression
// (e.g. `"weight " << name << " not found"`). The macro refers to `mLogger` by name, so it can only be used where
// such a pointer is in scope (e.g. inside ModelRefitter member functions).
//
// Messages whose severity compares <= kWARNING (with TensorRT's ILogger::Severity ordering: warnings and errors)
// are prefixed with "<file>:<line>: ". __FILENAME__ is expected to be provided elsewhere (NOTE(review): presumably
// by errorHelpers.hpp or the build; confirm).
//
// `severity` is evaluated exactly once and parenthesized, so arbitrary expressions are safe to pass.
#define LOG_REFITTER(msg, severity) \
    do \
    { \
        auto const logRefitterSeverity_ = (severity); \
        std::ostringstream ss{}; \
        if (logRefitterSeverity_ <= nvinfer1::ILogger::Severity::kWARNING) \
        { \
            ss << __FILENAME__ << ":" << __LINE__ << ": "; \
        } \
        ss << msg; \
        mLogger->log(logRefitterSeverity_, ss.str().c_str()); \
    } while (0)
// Convenience wrapper: log `msg` at kWARNING severity.
#define LOG_REFITTER_WARNING(msg) LOG_REFITTER(msg, nvinfer1::ILogger::Severity::kWARNING)
namespace onnx2trt
{
class ModelRefitter : public nvonnxparser::IParserRefitter
{
private:
nvinfer1::IRefitter* mRefitter;
nvinfer1::ILogger* mLogger;
//! WeightsContext object to hold ownership of ONNX weights and any temporary weights created by the refitter.
WeightsContext mWeightsContext;
//! ONNX ModelProto object to hold ownership of ONNX weights whenever a data type conversion is not needed.
::ONNX_NAMESPACE::ModelProto onnx_model;
//! Counter to limit the recursion depth to a set amount for nodes containing subgraphs.
size_t nestedDepth{0};
//! Set to keep track of how many times a batch norm weight name shows up, to avoid duplicate naming in TRT.
std::set<std::string> mBatchNormWeightNames;
//! An increasing suffix counter used to uniquify batch norm weight names.
int64_t mBatchNormWeightSuffixCounter{0};
size_t successfullyRefittedWeights{};
std::unordered_set<std::string> refittableWeights;
std::unordered_set<std::string> refittedWeights;
mutable std::vector<Status> mErrors;
std::unordered_set<std::string> getRefittableWeights();
//! T is the working type.
//! TConvertFunc is a functor for converting ShapedWeights to an array of type T.
//! It should return a T*.
template <typename T, typename TConvertFunc>
size_t batchnormWeightRefitter(
::ONNX_NAMESPACE::NodeProto const& node, std::vector<ShapedWeights>& inputs, TConvertFunc&& f);
void refitOnnxWeights(::ONNX_NAMESPACE::ModelProto const& onnx_model);
void refitOnnxGraph(::ONNX_NAMESPACE::GraphProto const& graph);
void refitOnnxNode(::ONNX_NAMESPACE::NodeProto const& node, ::ONNX_NAMESPACE::GraphProto const& graph);
void refitOnnxConstantNode(::ONNX_NAMESPACE::NodeProto const& node, std::string const& graphName);
void refitOnnxBatchNormNode(::ONNX_NAMESPACE::NodeProto const& node, ::ONNX_NAMESPACE::GraphProto const& graph);
void refitOnnxIfNode(::ONNX_NAMESPACE::NodeProto const& node);
void refitOnnxLoopNode(::ONNX_NAMESPACE::NodeProto const& node);
void refitOnnxScanNode(::ONNX_NAMESPACE::NodeProto const& node);
public:
ModelRefitter(nvinfer1::IRefitter* refitter, nvinfer1::ILogger* logger)
: mRefitter{refitter}
, mLogger{logger}
, mWeightsContext{WeightsContext{logger}}
{
}
bool refitFromBytes(void const* serializedOnnxModel, size_t serializedOnnxModelSize,
char const* modelPath = nullptr) noexcept override;
bool refitFromFile(char const* onnxModelFile) noexcept override;
int32_t getNbErrors() const noexcept override
{
return mErrors.size();
}
nvonnxparser::IParserError const* getError(int32_t index) const noexcept override
{
ONNXTRT_TRY
{
return &mErrors.at(index);
}
ONNXTRT_CATCH_LOG(mLogger)
return nullptr;
}
void clearErrors() noexcept override
{
mErrors.clear();
}
};
} // namespace onnx2trt