Skip to content

Commit

Permalink
Merge pull request #603 from pfnet-research/add_check_for_menoh_api
Browse files Browse the repository at this point in the history
Add check for menoh api
  • Loading branch information
shinh authored Aug 27, 2019
2 parents b8ba8b5 + bacce56 commit 67dcbd4
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions menoh/menoh_chainer_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
#include <compiler/onnx.h>
#include <compiler/passes.h>
#include <compiler/util.h>
#include <menoh/menoh_chainer_compiler_util.hpp>
#include <runtime/chainerx_util.h>
#include <runtime/chxvm.h>
#include <runtime/chxvm.pb.h>
#include <runtime/chxvm_var.h>
#include <tools/util.h>
#include <menoh/menoh_chainer_compiler_util.hpp>

namespace menoh_impl {
using fixed_array = std::array<char, MENOH_ERROR_MESSAGE_MAX_LENGTH>;
Expand Down Expand Up @@ -178,7 +178,7 @@ onnx::TensorProto::DataType menoh_dtype_to_xtensor_dtype(menoh_dtype mdtype) {
} else if (mdtype == menoh_dtype_bool) {
return onnx::TensorProto::BOOL;
} else {
assert(!"Not Implemeneted");
CHECK(false) << "Not Implemeneted menoh_dtype: " << mdtype;
}
return onnx::TensorProto::UNDEFINED;
}
Expand Down Expand Up @@ -291,6 +291,16 @@ bool has_dynamic_shape(array_profile const& a) {
size_t total_size(std::vector<int64_t> const& dims) {
return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<int64_t>());
}

size_t total_size_in_bytes(menoh_dtype dtype, std::vector<int64_t> const& dims) {
int64_t dtype_size;
menoh_dtype_size(dtype, &dtype_size);
return dtype_size * total_size(dims);
}

size_t total_size_in_bytes(array_profile const& p) {
return total_size_in_bytes(p.dtype(), p.dims());
}
} // namespace menoh_impl
struct menoh_variable_profile_table_builder {
std::unordered_map<std::string, menoh_impl::array_profile> input_profiles;
Expand Down Expand Up @@ -592,7 +602,7 @@ menoh_error_code menoh_build_model(
auto xgraph = *(builder->xgraph);

// Set initializer
assert(xgraph.initializer().empty());
CHECK(xgraph.initializer().empty());
for (onnx::TensorProto const& xtensor : model_data->xgraph.initializer()) {
*(xgraph.add_initializer()) = xtensor;
}
Expand All @@ -616,7 +626,7 @@ menoh_error_code menoh_build_model(
for (const chainer_compiler::Value* input : graph.input_values()) {
if (!input->initializer()) { // user input is input which doesn't have initializer
auto p = builder->input_profile_table.find(input->name());
assert(p != builder->input_profile_table.end());
CHECK(p != builder->input_profile_table.end()) << input->name() << " is not found in input_profile_table";
void* datap = nullptr;
auto found = builder->external_buffer_handle_table.find(input->name());
if (found != builder->external_buffer_handle_table.end()) {
Expand All @@ -642,7 +652,7 @@ menoh_error_code menoh_build_model(
datap = found->second;
} else {
auto p = builder->output_profile_table.find(output->name());
assert(p != builder->output_profile_table.end());
CHECK(p != builder->output_profile_table.end()) << output->name() << " is not found in output_profile_table";
auto data = allocate_buffer(p->second);
buffer_holder.push_back(data);
datap = data.get();
Expand Down Expand Up @@ -676,7 +686,7 @@ menoh_error_code menoh_model_get_variable_buffer_handle(const menoh_model_handle
auto found = model->outputs.find(variable_name);
if (found == model->outputs.end()) {
auto found = model->inputs.find(variable_name);
if(found == model->inputs.end()) {
if (found == model->inputs.end()) {
auto message = std::string("menoh variable not found: ") + variable_name;
menoh_impl::set_last_error_message(message.c_str());
return menoh_error_code_variable_not_found;
Expand Down Expand Up @@ -745,10 +755,14 @@ menoh_error_code menoh_model_run(menoh_model_handle model) {
auto outputs = model->chxvm->Run(model->inputs, model->chxvm_options);
for (auto output : outputs) {
auto found = model->outputs.find(output.first);
assert(found != model->outputs.end() && "output buffer not found");
CHECK(found != model->outputs.end()) << "output buffer for " << output.first << " is not found";
auto const& array = chainerx::AsContiguous(output.second->GetArray());
auto const& shape = array.shape();
auto bytesize = shape.GetTotalSize() * chainerx::GetItemSize(array.dtype());
CHECK(model->variable_profiles.find(output.first) != model->variable_profiles.end())
<< output.first << " is not found in variable_profiles";
CHECK_EQ(bytesize, menoh_impl::total_size_in_bytes(model->variable_profiles.find(output.first)->second))
<< "allocated output buffer size is not equal to cc's output buffer size";
std::copy(
static_cast<uint8_t*>(array.raw_data()),
static_cast<uint8_t*>(array.raw_data()) + bytesize,
Expand Down

0 comments on commit 67dcbd4

Please sign in to comment.