Skip to content

Commit

Permalink
Merge pull request pfnet-research#159 from durswd/format
Browse files Browse the repository at this point in the history
apply format
  • Loading branch information
durswd authored Apr 13, 2019
2 parents 601f4d8 + 32185db commit dd2bc00
Show file tree
Hide file tree
Showing 15 changed files with 426 additions and 235 deletions.
12 changes: 8 additions & 4 deletions elichika/elichika/chainer2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
import elichika.layers_buikdin as lb
import elichika.functions_buildin as fb


class ONNXModel:
def __init__(self):
self.model = None
self.inputs = []
self.outputs = []


def compile_model(model, inputs) -> 'ONNXModel':

oc.chainer_f_converter.clear()
Expand All @@ -56,21 +58,23 @@ def compile_model(model, inputs) -> 'ONNXModel':
oc.preprocess(graph_, True)

generator = oc.ONNXGenerator()
model = generator.generate_model(graph_.input_values, graph_.output_values, graph_, model)
model = generator.generate_model(
graph_.input_values, graph_.output_values, graph_, model)

# check inputs


onnx_model = ONNXModel()
onnx_model.model = model
onnx_model.inputs = graph_.input_values
onnx_model.outputs = graph_.output_values
return onnx_model

def save_model(path : 'str', model : 'ModelProto'):

def save_model(path: 'str', model: 'ModelProto'):
with open(path, "wb") as f:
f.write(model.SerializeToString())

def save_model_as_text(path : 'str', model : 'ModelProto'):

def save_model_as_text(path: 'str', model: 'ModelProto'):
with open(path, "w") as f:
print(model, file=f)
14 changes: 9 additions & 5 deletions elichika/elichika/functions_buildin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,22 @@

import elichika.onnx_converters as oc


def convert_relu(onnx_graph, node):
onnx_graph.add_node('Relu',
[node.inputs[0]],
[node.outputs[0]],
name = str(node.lineprop))
onnx_graph.add_node('Relu',
[node.inputs[0]],
[node.outputs[0]],
name=str(node.lineprop))


def convert_softmax(onnx_graph, node):
onnx_graph.add_node(
"Softmax",
[node.inputs[0]],
[node.outputs[0]],
str(node.lineprop),
axis = oc.try_get_attribute(node.inputs[1]))
axis=oc.try_get_attribute(node.inputs[1]))


def convert_pad_sequence(onnx_graph, node):
kwargs = {}
Expand All @@ -55,6 +58,7 @@ def convert_pad_sequence(onnx_graph, node):
str(node.lineprop),
**kwargs)


def convert_softmax_cross_entropy(onnx_graph, node):

onnx_graph.add_node(
Expand Down
16 changes: 10 additions & 6 deletions elichika/elichika/layers_buikdin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@

import elichika.onnx_converters as oc

def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):
chainer_inst = node.func.owner.inst # type: chainer.links.Linear

def convert_onnx_chainer_linear(onnx_graph: 'ONNXGraph', node: 'nodes.Node'):
chainer_inst = node.func.owner.inst # type: chainer.links.Linear
onnx_name = oc.node2onnx_parameter[node].onnx_name

x = oc.ONNXValue(onnx_graph, node.inputs[0])
Expand All @@ -42,7 +43,8 @@ def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):

(batch_size_1,) = onnx_graph.add_node(
'Gather',
[x_shape, oc.ONNXValue(onnx_graph, np.array(0, dtype=np.int64), [onnx_name, '/Zero'])],
[x_shape, oc.ONNXValue(onnx_graph, np.array(
0, dtype=np.int64), [onnx_name, '/Zero'])],
[None],
str(node.lineprop))

Expand All @@ -55,7 +57,8 @@ def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):

(mat_shape,) = onnx_graph.add_node(
'Concat',
[batch_size_2, oc.ONNXValue(onnx_graph, np.array([-1], dtype=np.int64), [onnx_name, '/Minus1'])],
[batch_size_2, oc.ONNXValue(onnx_graph, np.array(
[-1], dtype=np.int64), [onnx_name, '/Minus1'])],
[None],
str(node.lineprop),
axis=0)
Expand Down Expand Up @@ -92,8 +95,9 @@ def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):
[o],
str(node.lineprop))

def convert_onnx_chainer_convolution2d(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):
chainer_inst = node.func.owner.inst # type: chainer.links.Convolution2D

def convert_onnx_chainer_convolution2d(onnx_graph: 'ONNXGraph', node: 'nodes.Node'):
chainer_inst = node.func.owner.inst # type: chainer.links.Convolution2D
onnx_name = oc.node2onnx_parameter[node].onnx_name

ksize = oc.size2d(chainer_inst.ksize)
Expand Down
Loading

0 comments on commit dd2bc00

Please sign in to comment.