Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Fix serializer so it can serialize PP BERT large (#4127) (#4137)
Browse files Browse the repository at this point in the history
Co-authored-by: Scott Cyphers <[email protected]>

Co-authored-by: Scott Cyphers <[email protected]>
  • Loading branch information
rkimballn1 and diyessi committed Jan 7, 2020
1 parent 33ca594 commit 916ae6e
Showing 1 changed file with 12 additions and 34 deletions.
46 changes: 12 additions & 34 deletions src/ngraph/serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ class JSONSerializer
json serialize_output(const Output<Node>& output);
json serialize_parameter_vector(const ParameterVector& parameters);
json serialize_output_vector(const OutputVector& output_vector);
json serialize_node_reference(const Node& node);
json serialize_node(const Node& node);
json serialize_axis_set(const AxisSet& axis_set);
json serialize_tensor_iterator_input_description(
Expand All @@ -127,8 +126,6 @@ class JSONSerializer
bool m_serialize_output_shapes{false};
bool m_binary_constant_data{false};
json m_json_nodes;
set<const Node*> m_nodes_serialized;
queue<const Node*> m_nodes_to_serialize;
};

class JSONDeserializer
Expand Down Expand Up @@ -444,7 +441,7 @@ json JSONSerializer::serialize_parameter_vector(const ParameterVector& parameter
json json_parameters = json::array();
for (auto param : parameters)
{
json_parameters.push_back(serialize_node_reference(*param));
json_parameters.push_back(param->get_name());
}
return json_parameters;
}
Expand All @@ -458,9 +455,16 @@ json JSONSerializer::serialize_function(const Function& f)
// TODO Functions can return multiple results
for (size_t i = 0; i < f.get_output_size(); ++i)
{
function["result"].push_back(serialize_node_reference(*f.get_output_op(i)));
function["result"].push_back(f.get_output_op(i)->get_name());
}
function["ops"] = m_json_nodes;

json nodes;
for (shared_ptr<Node> node : f.get_ordered_ops(true))
{
nodes.push_back(serialize_node(*node));
}

function["ops"] = nodes;
return function;
}

Expand Down Expand Up @@ -2996,36 +3000,11 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
return node;
}

json JSONSerializer::serialize_node_reference(const Node& n)
{
if (m_nodes_serialized.count(&n) != 1)
{
m_nodes_to_serialize.push(&n);
if (m_nodes_to_serialize.size() == 1)
{
// Nothing in the queue
stack<json> serialized_nodes;
while (!m_nodes_to_serialize.empty())
{
const Node* next_node = m_nodes_to_serialize.front();
m_nodes_to_serialize.pop();
serialized_nodes.push(serialize_node(*next_node));
}
while (serialized_nodes.size() > 0)
{
m_json_nodes.push_back(serialized_nodes.top());
serialized_nodes.pop();
}
}
}
return n.get_name();
}

json JSONSerializer::serialize_output(const Output<Node>& output)
{
json result;
auto index = output.get_index();
json json_node_reference = serialize_node_reference(*output.get_node());
json json_node_reference = output.get_node()->get_name();
if (index == 0)
{
result = json_node_reference;
Expand All @@ -3050,7 +3029,6 @@ json JSONSerializer::serialize_output_vector(const OutputVector& output_vector)

json JSONSerializer::serialize_node(const Node& n)
{
m_nodes_serialized.insert(&n);
const NodeTypeInfo& type_info = n.get_type_info();
json jtype_info;
jtype_info["name"] = type_info.name;
Expand All @@ -3077,7 +3055,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
for (auto cdep : n.get_control_dependencies())
{
control_deps.push_back(serialize_node_reference(*cdep));
control_deps.push_back(cdep->get_name());
}
for (auto& output : n.outputs())
{
Expand Down

0 comments on commit 916ae6e

Please sign in to comment.