// Dump an ONNX proto
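//
// Usage sketch (the actual binary name depends on how the tools are built):
//   dump [--full] <model.onnx | tensor.pb | onnx_test_dir>...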

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <chainerx/context.h>

#include <compiler/onnx.h>

#include <common/log.h>
#include <common/protoutil.h>
#include <common/strutil.h>
#include <compiler/tensor.h>
#include <compiler/util.h>
#include <tools/cmdline.h>
#include <tools/util.h>

namespace chainer_compiler {
namespace runtime {
namespace {
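
// Prints a human-readable dump of an ONNX model. Unless --full is given,
// large initializer values are truncated to keep the output short.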
void DumpONNX(const std::string& filename, const cmdline::parser& args) {
    onnx::ModelProto model(LoadLargeProto<onnx::ModelProto>(filename));
    onnx::GraphProto* graph = model.mutable_graph();
    for (int i = 0; i < graph->initializer_size(); ++i) {
        onnx::TensorProto* tensor = graph->mutable_initializer(i);
        if (!args.exist("full")) {
            StripLargeValue(tensor, 20);
        }
        MakeHumanReadableValue(tensor);
    }
    std::cout << model.DebugString();
}
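
// Prints a single ONNX TensorProto after normalizing it through
// chainer_compiler::Tensor.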
void DumpTensor(const std::string& filename) {
    onnx::TensorProto xtensor(LoadLargeProto<onnx::TensorProto>(filename));
    chainer_compiler::Tensor tensor(xtensor);
    onnx::TensorProto xtensor_normalized;
    tensor.ToONNX(&xtensor_normalized);
    std::cout << xtensor_normalized.DebugString();
}
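
// Parses command-line arguments and dumps each positional argument, which
// may be an ONNX model (.onnx), a tensor proto (.pb), or a test directory
// containing model.onnx plus test data subdirectories.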
void RunMain(int argc, char** argv) {
    cmdline::parser args;
    args.add("full", '\0', "Dump all tensor values.");
    args.parse_check(argc, argv);

    if (args.rest().empty()) {
        QFAIL() << "Usage: " << argv[0] << " <onnx>";
    }

    chainerx::Context ctx;
    chainerx::ContextScope ctx_scope(ctx);

    for (const std::string& filename : args.rest()) {
        std::cout << "=== " << filename << " ===\n";
        if (HasSuffix(filename, ".onnx")) {
            DumpONNX(filename, args);
        } else if (HasSuffix(filename, ".pb")) {
            DumpTensor(filename);
        } else {
            // TODO(hamaji): Check if this directory is a standard
            // ONNX test directory.
            DumpONNX(filename + "/model.onnx", args);

            // Collect tensor protos (e.g. input_*.pb / output_*.pb) from the
            // test data subdirectories and dump them in sorted order.
            std::vector<std::string> filenames;
            for (const std::string& test_dir_name : ListDir(filename)) {
                if (!IsDir(test_dir_name)) {
                    continue;
                }
                for (const std::string& pb_name : ListDir(test_dir_name)) {
                    if (HasSuffix(pb_name, ".pb")) {
                        filenames.push_back(pb_name);
                    }
                }
            }
            std::sort(filenames.begin(), filenames.end());

            for (const std::string& filename : filenames) {
                std::cout << "=== " << filename << " ===\n";
                DumpTensor(filename);
            }
        }
    }
}
} // namespace
} // namespace runtime
} // namespace chainer_compiler

int main(int argc, char** argv) {
    chainer_compiler::runtime::RunMain(argc, argv);
}