diff --git a/elichika/elichika/chainer2onnx.py b/elichika/elichika/chainer2onnx.py
index 4989dd10..1c7974a6 100644
--- a/elichika/elichika/chainer2onnx.py
+++ b/elichika/elichika/chainer2onnx.py
@@ -24,12 +24,14 @@
 import elichika.layers_buikdin as lb
 import elichika.functions_buildin as fb
 
+
 class ONNXModel:
     def __init__(self):
         self.model = None
         self.inputs = []
         self.outputs = []
 
+
 def compile_model(model, inputs) -> 'ONNXModel':
     oc.chainer_f_converter.clear()
 
@@ -56,21 +58,23 @@ def compile_model(model, inputs) -> 'ONNXModel':
     oc.preprocess(graph_, True)
 
     generator = oc.ONNXGenerator()
-    model = generator.generate_model(graph_.input_values, graph_.output_values, graph_, model)
+    model = generator.generate_model(
+        graph_.input_values, graph_.output_values, graph_, model)
 
     # check inputs
-
     onnx_model = ONNXModel()
     onnx_model.model = model
     onnx_model.inputs = graph_.input_values
     onnx_model.outputs = graph_.output_values
     return onnx_model
 
-def save_model(path : 'str', model : 'ModelProto'):
+
+def save_model(path: 'str', model: 'ModelProto'):
     with open(path, "wb") as f:
         f.write(model.SerializeToString())
 
-def save_model_as_text(path : 'str', model : 'ModelProto'):
+
+def save_model_as_text(path: 'str', model: 'ModelProto'):
     with open(path, "w") as f:
         print(model, file=f)
diff --git a/elichika/elichika/functions_buildin.py b/elichika/elichika/functions_buildin.py
index fb6030ae..0ff4f626 100644
--- a/elichika/elichika/functions_buildin.py
+++ b/elichika/elichika/functions_buildin.py
@@ -22,11 +22,13 @@
 import elichika.onnx_converters as oc
 
+
 def convert_relu(onnx_graph, node):
-    onnx_graph.add_node('Relu',
-        [node.inputs[0]],
-        [node.outputs[0]],
-        name = str(node.lineprop))
+    onnx_graph.add_node('Relu',
+                        [node.inputs[0]],
+                        [node.outputs[0]],
+                        name=str(node.lineprop))
+
 
 def convert_softmax(onnx_graph, node):
     onnx_graph.add_node(
@@ -34,7 +36,8 @@ def convert_softmax(onnx_graph, node):
         [node.inputs[0]],
         [node.outputs[0]],
         str(node.lineprop),
-        axis = oc.try_get_attribute(node.inputs[1]))
+        axis=oc.try_get_attribute(node.inputs[1]))
+
 
 def convert_pad_sequence(onnx_graph, node):
     kwargs = {}
@@ -55,6 +58,7 @@ def convert_pad_sequence(onnx_graph, node):
         str(node.lineprop),
         **kwargs)
 
+
 def convert_softmax_cross_entropy(onnx_graph, node):
 
     onnx_graph.add_node(
diff --git a/elichika/elichika/layers_buikdin.py b/elichika/elichika/layers_buikdin.py
index 04a1c9ec..7a593e41 100644
--- a/elichika/elichika/layers_buikdin.py
+++ b/elichika/elichika/layers_buikdin.py
@@ -22,8 +22,9 @@
 import elichika.onnx_converters as oc
 
-def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):
-    chainer_inst = node.func.owner.inst # type: chainer.links.Linear
+
+def convert_onnx_chainer_linear(onnx_graph: 'ONNXGraph', node: 'nodes.Node'):
+    chainer_inst = node.func.owner.inst  # type: chainer.links.Linear
     onnx_name = oc.node2onnx_parameter[node].onnx_name
 
     x = oc.ONNXValue(onnx_graph, node.inputs[0])
@@ -42,7 +43,8 @@ def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):
 
     (batch_size_1,) = onnx_graph.add_node(
         'Gather',
-        [x_shape, oc.ONNXValue(onnx_graph, np.array(0, dtype=np.int64), [onnx_name, '/Zero'])],
+        [x_shape, oc.ONNXValue(onnx_graph, np.array(
+            0, dtype=np.int64), [onnx_name, '/Zero'])],
         [None],
         str(node.lineprop))
 
@@ -55,7 +57,8 @@ def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'):
 
     (mat_shape,) = onnx_graph.add_node(
         'Concat',
-        [batch_size_2, oc.ONNXValue(onnx_graph, np.array([-1], dtype=np.int64), [onnx_name, '/Minus1'])],
+
[batch_size_2, oc.ONNXValue(onnx_graph, np.array( + [-1], dtype=np.int64), [onnx_name, '/Minus1'])], [None], str(node.lineprop), axis=0) @@ -92,8 +95,9 @@ def convert_onnx_chainer_linear(onnx_graph : 'ONNXGraph', node : 'nodes.Node'): [o], str(node.lineprop)) -def convert_onnx_chainer_convolution2d(onnx_graph : 'ONNXGraph', node : 'nodes.Node'): - chainer_inst = node.func.owner.inst # type: chainer.links.Convolution2D + +def convert_onnx_chainer_convolution2d(onnx_graph: 'ONNXGraph', node: 'nodes.Node'): + chainer_inst = node.func.owner.inst # type: chainer.links.Convolution2D onnx_name = oc.node2onnx_parameter[node].onnx_name ksize = oc.size2d(chainer_inst.ksize) diff --git a/elichika/elichika/onnx_converters.py b/elichika/elichika/onnx_converters.py index 0eb0e3e0..614483f4 100644 --- a/elichika/elichika/onnx_converters.py +++ b/elichika/elichika/onnx_converters.py @@ -20,37 +20,44 @@ import numpy as np import collections + def size2d(x): if isinstance(x, collections.Iterable): return x return (x, x) + def get_onnx_dtype(dtype): a = np.zeros((), dtype=dtype) dt = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[a.dtype] return dt + assigned_names = [] node2onnx_parameter = {} value2onnx_parameter = {} + class NodeONNXParameter: def __init__(self, onnx_name, value): self.onnx_name = onnx_name self.original_value = value + class ValueONNXParameter: def __init__(self, onnx_name, value): self.onnx_name = onnx_name self.original_value = value + def onnx_name(value): if isinstance(value, values.Value): return value2onnx_parameter[value].onnx_name if isinstance(value, nodes.Node): return node2onnx_parameter[value].onnx_name -def generate_onnx_value_name(value : 'values.Value', none_name = ''): + +def generate_onnx_value_name(value: 'values.Value', none_name=''): base_name = '' base_name = value.name @@ -73,7 +80,8 @@ def generate_onnx_value_name(value : 'values.Value', none_name = ''): assigned_names.append(name) return name -def generate_onnx_node_name(node : 'nodes.Node'): + +def generate_onnx_node_name(node: 'nodes.Node'): base_name = str(node) ind = 0 @@ -86,7 +94,7 @@ def generate_onnx_node_name(node : 'nodes.Node'): return name -def generate_onnx_name(name : 'str'): +def generate_onnx_name(name: 'str'): base_name = str(name) ind = 0 @@ -99,11 +107,13 @@ def generate_onnx_name(name : 'str'): return name -def assign_onnx_name_to_value(value : 'values.Value', none_name = ''): +def assign_onnx_name_to_value(value: 'values.Value', none_name=''): if not value in value2onnx_parameter: - value2onnx_parameter[value] = ValueONNXParameter(generate_onnx_value_name(value, none_name), value) + value2onnx_parameter[value] = ValueONNXParameter( + generate_onnx_value_name(value, none_name), value) -def assign_onnx_name(graph : 'graphs.Graph'): + +def assign_onnx_name(graph: 'graphs.Graph'): for v in graph.input_values: assign_onnx_name_to_value(v) @@ -119,12 +129,14 @@ def assign_onnx_name(graph : 'graphs.Graph'): assign_onnx_name_to_value(output) if not node in node2onnx_parameter: - node2onnx_parameter[node] = NodeONNXParameter(generate_onnx_node_name(node), node) + node2onnx_parameter[node] = NodeONNXParameter( + generate_onnx_node_name(node), node) for subgraph in node.subgraphs: assign_onnx_name(subgraph) -def preprocess(graph : 'graphs.Graph', isMain : 'bool'): + +def preprocess(graph: 'graphs.Graph', isMain: 'bool'): replacing = {} for value in graph.output_values: @@ -151,7 +163,8 @@ def preprocess(graph : 'graphs.Graph', isMain : 'bool'): node.set_outputs([copied_value]) graph.add_node(node) - 
copied_value.name = output_values[i].name + '_cp_out_' + str(duplicates[output_values[i]]) + copied_value.name = output_values[i].name + \ + '_cp_out_' + str(duplicates[output_values[i]]) duplicates[output_values[i]] += 1 output_values[i] = copied_value else: @@ -163,10 +176,12 @@ def preprocess(graph : 'graphs.Graph', isMain : 'bool'): for subgraph in node.subgraphs: preprocess(subgraph, False) + chainer_l_converter = {} chainer_f_converter = {} -def convert_node_aug_assign(onnx_graph, node : 'nodes.NodeAugAssign'): + +def convert_node_aug_assign(onnx_graph, node: 'nodes.NodeAugAssign'): binops = {} binops[nodes.BinOpType.Add] = 'Add' binops[nodes.BinOpType.Sub] = 'Sub' @@ -175,25 +190,28 @@ def convert_node_aug_assign(onnx_graph, node : 'nodes.NodeAugAssign'): # TODO: fix for reference types if isinstance(node.target, values.ListValue) or isinstance(node.target, values.TupleValue): - assert(isinstance(node.value, values.ListValue) or isinstance(node.value, values.TupleValue)) + assert(isinstance(node.value, values.ListValue) + or isinstance(node.value, values.TupleValue)) binops[nodes.BinOpType.Add] = 'ChainerGenericAdd' target = ONNXValue(onnx_graph, node.target) value = ONNXValue(onnx_graph, node.value) seq_target = target.create_sequence() seq_value = value.create_sequence() - onnx_graph.add_node(binops[node.binop], [seq_target, seq_value], [value2onnx_parameter[node.outputs[0]].onnx_name], None) + onnx_graph.add_node(binops[node.binop], [seq_target, seq_value], [ + value2onnx_parameter[node.outputs[0]].onnx_name], None) else: onnx_node = oh.make_node( binops[node.binop], [value2onnx_parameter[node.target].onnx_name, - value2onnx_parameter[node.value].onnx_name], + value2onnx_parameter[node.value].onnx_name], [value2onnx_parameter[node.outputs[0]].onnx_name]) onnx_graph.nodes.append(onnx_node) -def convert_node_bin_op(onnx_graph, node : 'nodes.NodeBinOp'): + +def convert_node_bin_op(onnx_graph, node: 'nodes.NodeBinOp'): binops = {} binops[nodes.BinOpType.Add] = 'Add' binops[nodes.BinOpType.Sub] = 'Sub' @@ -201,20 +219,24 @@ def convert_node_bin_op(onnx_graph, node : 'nodes.NodeBinOp'): binops[nodes.BinOpType.Unknown] = 'Add' if isinstance(node.left, values.ListValue) or isinstance(node.left, values.TupleValue): - assert(isinstance(node.right, values.ListValue) or isinstance(node.right, values.TupleValue)) + assert(isinstance(node.right, values.ListValue) + or isinstance(node.right, values.TupleValue)) binops[nodes.BinOpType.Add] = 'ChainerGenericAdd' left = ONNXValue(onnx_graph, node.left) right = ONNXValue(onnx_graph, node.right) seq_left = left.create_sequence() seq_right = right.create_sequence() - onnx_graph.add_node(binops[node.binop], [seq_left, seq_right], [value2onnx_parameter[node.outputs[0]].onnx_name], None) + onnx_graph.add_node(binops[node.binop], [seq_left, seq_right], [ + value2onnx_parameter[node.outputs[0]].onnx_name], None) else: - onnx_node = oh.make_node(binops[node.binop], [value2onnx_parameter[node.left].onnx_name, value2onnx_parameter[node.right].onnx_name], [value2onnx_parameter[node.outputs[0]].onnx_name]) + onnx_node = oh.make_node(binops[node.binop], [value2onnx_parameter[node.left].onnx_name, + value2onnx_parameter[node.right].onnx_name], [value2onnx_parameter[node.outputs[0]].onnx_name]) onnx_graph.nodes.append(onnx_node) -def convert_node_call(onnx_graph, node : 'nodes.NodeCall'): + +def convert_node_call(onnx_graph, node: 'nodes.NodeCall'): if node.func.base_func is not None: chainer_f_converter[node.func.base_func](onnx_graph, node) @@ -230,7 +252,8 @@ 
def convert_node_call(onnx_graph, node : 'nodes.NodeCall'): if isinstance(node.func, functions_builtin.NDArrayShapeFunction): # shape - op_shape_temp = onnx_graph.new_empty_tensor(['TODO'], np.int32, value2onnx_parameter[node.outputs[0]].onnx_name + '/ShapeTemp') + op_shape_temp = onnx_graph.new_empty_tensor( + ['TODO'], np.int32, value2onnx_parameter[node.outputs[0]].onnx_name + '/ShapeTemp') onnx_node = oh.make_node( "Shape", @@ -251,11 +274,13 @@ def convert_node_call(onnx_graph, node : 'nodes.NodeCall'): if isinstance(node.func, values_builtin.ChainerLinkFunction): original_inst = node.func.owner.inst chainer_l_converter[type(original_inst)](onnx_graph, node) - -def convert_node_unary_op(onnx_graph, node : 'nodes.NodeUnaryOp'): + + +def convert_node_unary_op(onnx_graph, node: 'nodes.NodeUnaryOp'): if node.unaryop == nodes.UnaryOpType.UAdd: - zero_ = ONNXValue(onnx_graph, np.array(0, dtype=np.float), [node,'/Zero'], is_constant=True) + zero_ = ONNXValue(onnx_graph, np.array(0, dtype=np.float), [ + node, '/Zero'], is_constant=True) onnx_node = oh.make_node( 'Add', [zero_.name, value2onnx_parameter[node.operand].onnx_name], @@ -263,7 +288,8 @@ def convert_node_unary_op(onnx_graph, node : 'nodes.NodeUnaryOp'): onnx_graph.nodes.append(onnx_node) if node.unaryop == nodes.UnaryOpType.USub: - zero_ = ONNXValue(onnx_graph, np.array(0, dtype=np.float), [node,'/Zero'], is_constant=True) + zero_ = ONNXValue(onnx_graph, np.array(0, dtype=np.float), [ + node, '/Zero'], is_constant=True) onnx_node = oh.make_node( 'Sub', [zero_.name, value2onnx_parameter[node.operand].onnx_name], @@ -278,7 +304,6 @@ def convert_node_unary_op(onnx_graph, node : 'nodes.NodeUnaryOp'): onnx_graph.nodes.append(onnx_node) - class ONNXValue: """ A wrapper of ONNX value @@ -289,10 +314,11 @@ class ONNXValue: name : a value of name. string or array is_constant : if this value can be converted as constant, it makes constant values. 
""" - def __init__(self, onnx_graph : 'ONNXGraph', any_value = None, name = None, is_constant = True): - assert(isinstance(onnx_graph,ONNXGraph)) - self.value = None # values.Value - self.np_value = None # np.array + + def __init__(self, onnx_graph: 'ONNXGraph', any_value=None, name=None, is_constant=True): + assert(isinstance(onnx_graph, ONNXGraph)) + self.value = None # values.Value + self.np_value = None # np.array self.onnx_graph = onnx_graph self.is_constant = is_constant self.name = '' @@ -302,7 +328,7 @@ def generate_name(): if(isinstance(name, list)): for n in name: - if isinstance(n,values.Value): + if isinstance(n, values.Value): name_ += value2onnx_parameter[n].onnx_name if isinstance(n, nodes.Node): name_ += node2onnx_parameter[n].onnx_name @@ -328,7 +354,8 @@ def generate_name(): elif id(any_value) in onnx_graph.generator.param2name.keys(): self.np_value = any_value.data self.name = onnx_graph.generator.param2name[id(any_value)] - self.tensor = onnx_graph.new_tensor_with_np(self.np_value, self.name) + self.tensor = onnx_graph.new_tensor_with_np( + self.np_value, self.name) elif isinstance(any_value, np.ndarray): self.np_value = any_value @@ -336,17 +363,21 @@ def generate_name(): if self.is_constant: tensor = numpy_helper.from_array(any_value, name=self.name) - self.onnx_graph.add_node('Constant', [], [self.name], self.name, value=tensor) + self.onnx_graph.add_node( + 'Constant', [], [self.name], self.name, value=tensor) else: - self.tensor = onnx_graph.new_tensor_with_np(self.np_value, self.name) + self.tensor = onnx_graph.new_tensor_with_np( + self.np_value, self.name) elif(any_value == np.float32 or any_value == np.float64 or any_value == np.int32 or any_value == np.int64): self.name = generate_name() - self.tensor = self.onnx_graph.new_empty_tensor(['TODO'], any_value, self.name) + self.tensor = self.onnx_graph.new_empty_tensor( + ['TODO'], any_value, self.name) def create_sequence(self) -> 'ONNXValue': if(isinstance(self.value, values.ListValue)): - ret = ONNXValue(self.onnx_graph,values.ListValue(), [self.name, '/create_sequence']) + ret = ONNXValue(self.onnx_graph, values.ListValue(), [ + self.name, '/create_sequence']) self.onnx_graph.add_node( "Identity", [self.name], @@ -356,9 +387,10 @@ def create_sequence(self) -> 'ONNXValue': return ret if(isinstance(self.value, values.TupleValue)): - value = self.value # values.TupleValue + value = self.value # values.TupleValue if value.internal_value is None: - ret = ONNXValue(self.onnx_graph,values.ListValue(), [self.name, '/create_sequence']) + ret = ONNXValue(self.onnx_graph, values.ListValue(), [ + self.name, '/create_sequence']) self.onnx_graph.add_node( "Identity", [self.name], @@ -366,10 +398,12 @@ def create_sequence(self) -> 'ONNXValue': str('create_sequence')) else: # TODO adhoc code - ret = ONNXValue(self.onnx_graph,values.ListValue(), [self.name, '/create_sequence']) + ret = ONNXValue(self.onnx_graph, values.ListValue(), [ + self.name, '/create_sequence']) self.onnx_graph.add_node( "ChainerSequenceCreate", - [ONNXValue(self.onnx_graph, np.array(v.internal_value), [self.name, '/c'], is_constant=True) for v in value.internal_value], + [ONNXValue(self.onnx_graph, np.array(v.internal_value), [ + self.name, '/c'], is_constant=True) for v in value.internal_value], [ret], str('create_sequence')) @@ -378,8 +412,7 @@ def create_sequence(self) -> 'ONNXValue': assert(False) - -def try_get_attribute(value, calling_node : 'nodes.Node' = None): +def try_get_attribute(value, calling_node: 'nodes.Node' = None): if calling_node is 
None: lineinfo = 'unknown' @@ -412,6 +445,7 @@ def try_get_attribute(value, calling_node : 'nodes.Node' = None): print("Cannot convert a value into an attribute") return -1 + class ONNXInitrializer: def __init__(self): self.tensor_value = None @@ -420,8 +454,9 @@ def __init__(self): self.dt = 0 self.shape = () + class ONNXGraph: - def __init__(self, generator : 'ONNXGenerator', parent : 'ONNXGraph'): + def __init__(self, generator: 'ONNXGenerator', parent: 'ONNXGraph'): self.generator = generator self.parent = parent self.nodes = [] @@ -495,7 +530,7 @@ def new_tensor_with_np(self, ndarray_, name): ''' tensor = numpy_helper.from_array(ndarray_, name=name) dt = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(ndarray_.dtype)] - + tensor_value = oh.make_tensor_value_info(name, dt, ndarray_.shape) initializer = ONNXInitrializer() @@ -518,7 +553,7 @@ def new_tensor_with_value(self, value): it is for constant input ''' name = self.get_value_name(value) - + if isinstance(value, values.NumberValue): if value.internal_value is None: # any value @@ -539,8 +574,8 @@ def new_tensor_with_value(self, value): arr = np.array(False) return self.new_tensor_with_np(arr, name) - - print('Warning : Found uknown type {} in new_tensor_with_value. Float is stored.'.format(type(value))) + print('Warning : Found uknown type {} in new_tensor_with_value. Float is stored.'.format( + type(value))) arr = np.array(0.0, dtype=np.float32) return self.new_tensor_with_np(arr, name) @@ -593,12 +628,14 @@ def get_value_name(self, value): assert(False) def set_input(self, input): - self.input_tensor = [self.generator.onnx_tensors[value2onnx_parameter[x].onnx_name] for x in input] + self.input_tensor = [ + self.generator.onnx_tensors[value2onnx_parameter[x].onnx_name] for x in input] def set_output(self, output): - self.output_tensor = [self.generator.onnx_tensors[value2onnx_parameter[x].onnx_name] for x in output] + self.output_tensor = [ + self.generator.onnx_tensors[value2onnx_parameter[x].onnx_name] for x in output] - def generate_graph(self, name : 'str', isMain = False): + def generate_graph(self, name: 'str', isMain=False): input_tensor_and_initializer = self.input_tensor.copy() initializers = [] @@ -615,6 +652,7 @@ def generate_graph(self, name : 'str', isMain = False): return oh.make_graph(self.nodes, name, input_tensor_and_initializer, self.output_tensor, initializer=initializers) + class ONNXGenerator: def __init__(self): self.onnx_graphs = [] @@ -622,7 +660,7 @@ def __init__(self): self.onnx_tensors = {} self.param2name = {} - def generate_graph(self, inputs, outputs, graph : 'graphs.Graph', parent : 'ONNXGraph', isMain = False): + def generate_graph(self, inputs, outputs, graph: 'graphs.Graph', parent: 'ONNXGraph', isMain=False): onnx_graph = ONNXGraph(self, parent) def generate_input_tensors(inputs_): @@ -635,14 +673,16 @@ def generate_tensor_constant(input_): # TODO improve def generate_tensor_constant_constant(input_): t = onnx_graph.new_empty_tensor_with_value(input_) - tensor = numpy_helper.from_array(np.array(input_.internal_value), name=value2onnx_parameter[input_].onnx_name) - onnx_node = oh.make_node('Constant', [], [t.name], value=tensor) + tensor = numpy_helper.from_array( + np.array(input_.internal_value), name=value2onnx_parameter[input_].onnx_name) + onnx_node = oh.make_node( + 'Constant', [], [t.name], value=tensor) onnx_graph.nodes.append(onnx_node) def generate_tensor(input_): tensor = onnx_graph.new_empty_tensor_with_value(input_) - if not (value2onnx_parameter[input].onnx_name in 
self.onnx_tensors.keys()): + if not (value2onnx_parameter[input].onnx_name in self.onnx_tensors.keys()): if input.generator is None and not (input in inputs) and input.internal_value is not None and not isinstance(input, values.TupleValue)is not None and not isinstance(input, values.StrValue): generate_tensor_constant_constant(input) elif input.generator is None and not (input in inputs): @@ -662,8 +702,10 @@ def generate_tensor(output_): # TODO improve if output.generator is None: - tensor = numpy_helper.from_array(np.array(output.internal_value), name=value2onnx_parameter[output].onnx_name) - onnx_node = oh.make_node('Constant', [], [t.name], value=tensor) + tensor = numpy_helper.from_array( + np.array(output.internal_value), name=value2onnx_parameter[output].onnx_name) + onnx_node = oh.make_node( + 'Constant', [], [t.name], value=tensor) onnx_graph.nodes.append(onnx_node) generate_input_tensors(inputs) @@ -676,7 +718,7 @@ def generate_tensor(output_): for node in graph.nodes: if isinstance(node, nodes.NodeCopy): - node_ = node # type: nodes.Copy + node_ = node # type: nodes.Copy onnx_node = oh.make_node( 'Identity', [value2onnx_parameter[node_.value].onnx_name], @@ -706,7 +748,7 @@ def generate_tensor(output_): convert_node_unary_op(onnx_graph, node) if isinstance(node, nodes.NodeCompare): - node_ = node # type: nodes.NodeCompare + node_ = node # type: nodes.NodeCompare op_str = None op_not = False @@ -733,30 +775,36 @@ def generate_tensor(output_): op_not = True if op_not: - op_not_temp = onnx_graph.new_empty_tensor(['TODO'], np.bool, value2onnx_parameter[node.outputs[0]].onnx_name + '/NotTemp') - onnx_node1 = oh.make_node(op_str, [value2onnx_parameter[node_.left].onnx_name, value2onnx_parameter[node_.right].onnx_name], [op_not_temp.name]) - onnx_node2 = oh.make_node('Not', [op_not_temp.name], [value2onnx_parameter[node.outputs[0]].onnx_name]) + op_not_temp = onnx_graph.new_empty_tensor( + ['TODO'], np.bool, value2onnx_parameter[node.outputs[0]].onnx_name + '/NotTemp') + onnx_node1 = oh.make_node(op_str, [ + value2onnx_parameter[node_.left].onnx_name, value2onnx_parameter[node_.right].onnx_name], [op_not_temp.name]) + onnx_node2 = oh.make_node('Not', [op_not_temp.name], [ + value2onnx_parameter[node.outputs[0]].onnx_name]) onnx_graph.nodes.append(onnx_node1) onnx_graph.nodes.append(onnx_node2) else: - onnx_node = oh.make_node(op_str, [value2onnx_parameter[node_.left].onnx_name, value2onnx_parameter[node_.right].onnx_name], [value2onnx_parameter[node.outputs[0]].onnx_name]) + onnx_node = oh.make_node(op_str, [value2onnx_parameter[node_.left].onnx_name, value2onnx_parameter[node_.right].onnx_name], [ + value2onnx_parameter[node.outputs[0]].onnx_name]) onnx_graph.nodes.append(onnx_node) if isinstance(node, nodes.NodeGetItem): - node_ = node # type: nodes.NodeGetItem + node_ = node # type: nodes.NodeGetItem if len(node_.indexes) == 1: if isinstance(node_.target, values.ListValue) or isinstance(node_.target, values.RangeValue): onnx_node = oh.make_node( 'ChainerSequenceLookup', - [value2onnx_parameter[node_.target].onnx_name, value2onnx_parameter[node_.indexes[0]].onnx_name], + [value2onnx_parameter[node_.target].onnx_name, + value2onnx_parameter[node_.indexes[0]].onnx_name], [value2onnx_parameter[node.outputs[0]].onnx_name]) onnx_graph.nodes.append(onnx_node) else: onnx_node = oh.make_node( 'ChainerGetItem', - [value2onnx_parameter[node_.target].onnx_name, value2onnx_parameter[node_.indexes[0]].onnx_name], + [value2onnx_parameter[node_.target].onnx_name, + 
value2onnx_parameter[node_.indexes[0]].onnx_name], [value2onnx_parameter[node.outputs[0]].onnx_name], slice_specs=[1]) onnx_graph.nodes.append(onnx_node) @@ -776,7 +824,7 @@ def generate_tensor(output_): onnx_graph.nodes.append(onnx_node) if isinstance(node, nodes.NodeSlice): - node_ = node # type: nodes.NodeSlice + node_ = node # type: nodes.NodeSlice indices = [] @@ -797,19 +845,21 @@ def generate_tensor(output_): slice_specs=node_.slice_specs) onnx_graph.nodes.append(onnx_node) - if isinstance(node, nodes.NodeCall): convert_node_call(onnx_graph, node) if isinstance(node, nodes.NodeIf): - node_ = node # type: nodes.NodeIf + node_ = node # type: nodes.NodeIf - true_graph = self.generate_graph(node_.true_graph.input_values, node_.true_graph.output_values, node_.true_graph, onnx_graph) - false_graph = self.generate_graph(node_.false_graph.input_values, node_.false_graph.output_values, node_.false_graph, onnx_graph) + true_graph = self.generate_graph( + node_.true_graph.input_values, node_.true_graph.output_values, node_.true_graph, onnx_graph) + false_graph = self.generate_graph( + node_.false_graph.input_values, node_.false_graph.output_values, node_.false_graph, onnx_graph) onnx_node = oh.make_node( 'If', - [value2onnx_parameter[node_.cond].onnx_name] + [value2onnx_parameter[x].onnx_name for x in node.input_values], + [value2onnx_parameter[node_.cond].onnx_name] + + [value2onnx_parameter[x].onnx_name for x in node.input_values], [value2onnx_parameter[x].onnx_name for x in node.outputs], then_branch=true_graph, else_branch=false_graph) @@ -817,10 +867,11 @@ def generate_tensor(output_): onnx_graph.nodes.append(onnx_node) if isinstance(node, nodes.NodeFor): - node_ = node # type: nodes.NodeFor + node_ = node # type: nodes.NodeFor # get length of sequence - v_len = ONNXValue(onnx_graph, np.array(0).dtype, [value2onnx_parameter[node_.iter_value].onnx_name, '/Len']) + v_len = ONNXValue(onnx_graph, np.array(0).dtype, [ + value2onnx_parameter[node_.iter_value].onnx_name, '/Len']) onnx_node = onnx_graph.add_node( 'ChainerGenericLen', @@ -828,39 +879,44 @@ def generate_tensor(output_): [v_len], str(node.lineprop)) - body_graph = self.generate_graph(node_.body_graph.input_values, node_.body_graph.output_values, node_.body_graph, onnx_graph) + body_graph = self.generate_graph( + node_.body_graph.input_values, node_.body_graph.output_values, node_.body_graph, onnx_graph) # for onnx_node = onnx_graph.add_node( 'Loop', - [v_len] + [""] + [value2onnx_parameter[node_.iter_value].onnx_name] + [value2onnx_parameter[x].onnx_name for x in node.input_values], + [v_len] + [""] + [value2onnx_parameter[node_.iter_value].onnx_name] + + [value2onnx_parameter[x].onnx_name for x in node.input_values], [value2onnx_parameter[x].onnx_name for x in node.outputs], str(node.lineprop), body=body_graph) if isinstance(node, nodes.NodeForGenerator): - node_ = node # type: nodes.NodeForGenerator + node_ = node # type: nodes.NodeForGenerator # get value from sequence with index if isinstance(node_.iter_value, values.ListValue) or isinstance(node_.iter_value, values.RangeValue): onnx_node = oh.make_node( 'ChainerSequenceLookup', - [value2onnx_parameter[node_.iter_value].onnx_name, value2onnx_parameter[node_.counter_value].onnx_name], + [value2onnx_parameter[node_.iter_value].onnx_name, + value2onnx_parameter[node_.counter_value].onnx_name], [value2onnx_parameter[node_.outputs[0]].onnx_name]) onnx_graph.nodes.append(onnx_node) else: onnx_node = oh.make_node( 'ChainerGetItem', - [value2onnx_parameter[node_.iter_value].onnx_name, 
value2onnx_parameter[node_.counter_value].onnx_name], + [value2onnx_parameter[node_.iter_value].onnx_name, + value2onnx_parameter[node_.counter_value].onnx_name], [value2onnx_parameter[node_.outputs[0]].onnx_name], slice_specs=[1]) onnx_graph.nodes.append(onnx_node) if isinstance(node, nodes.NodeListcomp): - node_ = node # type: nodes.NodeListcomp + node_ = node # type: nodes.NodeListcomp # get length of sequence - tensor_len = ONNXValue(onnx_graph, np.array(0).dtype, [value2onnx_parameter[node_.iter_value].onnx_name, '/Len']) + tensor_len = ONNXValue(onnx_graph, np.array(0).dtype, [ + value2onnx_parameter[node_.iter_value].onnx_name, '/Len']) onnx_graph.add_node( 'ChainerGenericLen', @@ -868,21 +924,23 @@ def generate_tensor(output_): [tensor_len], str(node.lineprop)) - body_graph = self.generate_graph(node_.body_graph.input_values, node_.body_graph.output_values, node_.body_graph, onnx_graph) + body_graph = self.generate_graph( + node_.body_graph.input_values, node_.body_graph.output_values, node_.body_graph, onnx_graph) onnx_node = oh.make_node( 'Loop', - [tensor_len.name] + [""] + [value2onnx_parameter[node_.iter_value].onnx_name] + [value2onnx_parameter[x].onnx_name for x in node.input_values], + [tensor_len.name] + [""] + [value2onnx_parameter[node_.iter_value].onnx_name] + + [value2onnx_parameter[x].onnx_name for x in node.input_values], [value2onnx_parameter[x].onnx_name for x in node.outputs], body=body_graph) onnx_graph.nodes.append(onnx_node) if isinstance(node, nodes.NodeConvert): - node_ = node # type: nodes.NodeConvert + node_ = node # type: nodes.NodeConvert if node_.classtype == 'List': - if isinstance(node_.value, values.ListValue): + if isinstance(node_.value, values.ListValue): onnx_node = oh.make_node( "Identity", [value2onnx_parameter[node.inputs[0]].onnx_name], @@ -891,7 +949,7 @@ def generate_tensor(output_): onnx_graph.nodes.append(onnx_node) - else: + else: # not supported yet assert False @@ -900,7 +958,7 @@ def generate_tensor(output_): assert False if isinstance(node, nodes.NodeGenerate): - node_ = node # type: nodes.NodeGenerate + node_ = node # type: nodes.NodeGenerate if node_.classtype == 'range': onnx_node = oh.make_node( "ChainerSequenceRange", @@ -911,7 +969,8 @@ def generate_tensor(output_): onnx_graph.nodes.append(onnx_node) if node_.classtype == 'array': - dtype_value = try_get_attribute(node.fargs.get_value('dtype')) + dtype_value = try_get_attribute( + node.fargs.get_value('dtype')) if dtype_value is not None: dtype = utils.int_2_numpy_type(dtype_value) else: @@ -938,7 +997,8 @@ def generate_tensor(output_): [o], str(node.lineprop)) else: - casting_name = value2onnx_parameter[node.outputs[0]].onnx_name + '/Cast' + casting_name = value2onnx_parameter[node.outputs[0] + ].onnx_name + '/Cast' onnx_node = onnx_graph.add_node( "ChainerSequenceStack", [value], @@ -977,12 +1037,13 @@ def generate_model(self, inputs, outputs, graph, model) -> 'ModelProto': self.param2name = {id(p): 'param' + n.replace('/', '_') for n, p in model.namedparams()} - for p,n in self.param2name.items(): + for p, n in self.param2name.items(): assigned_names.append(n) # assign onnx name assign_onnx_name(graph) graph_ = self.generate_graph(inputs, outputs, graph, None, True) - onnx_model = oh.make_model(graph_, producer_name="elichika", producer_version="0.1") + onnx_model = oh.make_model( + graph_, producer_name="elichika", producer_version="0.1") return onnx_model diff --git a/elichika/elichika/parser/core.py b/elichika/elichika/parser/core.py index d63151e1..7a43bf4a 100644 --- 
a/elichika/elichika/parser/core.py +++ b/elichika/elichika/parser/core.py @@ -15,6 +15,7 @@ from elichika.parser.graphs import Graph import numpy as np + def get_module_name(target_module, parent_module): members = inspect.getmembers(parent_module) @@ -24,16 +25,17 @@ def get_module_name(target_module, parent_module): return '' -def convert_model(model : 'chainer.Chain', args = []): + +def convert_model(model: 'chainer.Chain', args=[]): # reset values values.reset_field_and_attributes() utils.reset_guid() values.instance_converters.clear() - def instance_converter(m,i): + def instance_converter(m, i): if values_builtin.is_builtin_chainer_link(i): - return values_builtin.ChainerLinkInstance(m,i) + return values_builtin.ChainerLinkInstance(m, i) return None values.instance_converters.append(instance_converter) @@ -42,18 +44,25 @@ def instance_converter(m,i): default_module = values.Module(sys.modules[model.__module__]) # chainer.functions - chainer_functions_module_name = get_module_name(F, default_module.internal_module) + chainer_functions_module_name = get_module_name( + F, default_module.internal_module) if chainer_functions_module_name != '': f_dict = values.ValueRef(values.ModuleValue()) - f_relu = values.FuncValue(functions_builtin.ChainerFunction(F.relu), None) + f_relu = values.FuncValue( + functions_builtin.ChainerFunction(F.relu), None) f_dict.get_field().get_attribute('relu').revise(values.ValueRef(f_relu)) - f_softmax = values.FuncValue(functions_builtin.ChainerFunction(F.softmax), None) + f_softmax = values.FuncValue( + functions_builtin.ChainerFunction(F.softmax), None) f_dict.get_field().get_attribute('softmax').revise(values.ValueRef(f_softmax)) - f_softmax_cross_entropy = values.FuncValue(functions_builtin.ChainerFunction(F.softmax_cross_entropy), None) - f_dict.get_field().get_attribute('softmax_cross_entropy').revise(values.ValueRef(f_softmax_cross_entropy)) - f_pad_sequence = values.FuncValue(functions_builtin.ChainerFunction(F.pad_sequence), None) - f_dict.get_field().get_attribute('pad_sequence').revise(values.ValueRef(f_pad_sequence)) + f_softmax_cross_entropy = values.FuncValue( + functions_builtin.ChainerFunction(F.softmax_cross_entropy), None) + f_dict.get_field().get_attribute('softmax_cross_entropy').revise( + values.ValueRef(f_softmax_cross_entropy)) + f_pad_sequence = values.FuncValue( + functions_builtin.ChainerFunction(F.pad_sequence), None) + f_dict.get_field().get_attribute('pad_sequence').revise( + values.ValueRef(f_pad_sequence)) default_module.set_default_value(chainer_functions_module_name, f_dict) # numpy @@ -64,8 +73,10 @@ def instance_converter(m,i): f_array = values.FuncValue(functions_builtin.NDArrayFunction(), None) f_dict.get_field().get_attribute('array').revise(values.ValueRef(f_array)) - f_dict.get_field().get_attribute('int32').revise(values.ValueRef(values.NumberValue(utils.numpy_type_2_int(np.int32)))) - f_dict.get_field().get_attribute('float32').revise(values.ValueRef(values.NumberValue(utils.numpy_type_2_int(np.float32)))) + f_dict.get_field().get_attribute('int32').revise( + values.ValueRef(values.NumberValue(utils.numpy_type_2_int(np.int32)))) + f_dict.get_field().get_attribute('float32').revise( + values.ValueRef(values.NumberValue(utils.numpy_type_2_int(np.float32)))) default_module.set_default_value(numpy_module_name, f_dict) @@ -89,10 +100,10 @@ def instance_converter(m,i): varg.get_value().name = 'in_' + str(ind) # make value unknown - #if isinstance(varg.get_value(), values.TupleValue): + # if isinstance(varg.get_value(), 
values.TupleValue): # for i in range(len(varg.get_value().internal_value)): # varg.get_value().internal_value[i] = None - #else: + # else: varg.get_value().internal_value = None finput.inputs.append(varg) @@ -101,10 +112,10 @@ def instance_converter(m,i): graph = Graph() forward_func_value = forward_func.get_value() - ret = forward_func_value.func.vcall(default_module, graph, forward_func_value.obj, finput) + ret = forward_func_value.func.vcall( + default_module, graph, forward_func_value.obj, finput) assert(ret is None or isinstance(ret, values.ValueRef)) - def try_get_value(value) -> 'values.Value': if isinstance(value, values.Value): return value diff --git a/elichika/elichika/parser/functions.py b/elichika/elichika/parser/functions.py index e3c4ba53..b1b0e64b 100644 --- a/elichika/elichika/parser/functions.py +++ b/elichika/elichika/parser/functions.py @@ -3,7 +3,8 @@ import chainer.functions as F import chainer.links as L import inspect -import ast, gast +import ast +import gast import weakref import numpy as np @@ -16,8 +17,9 @@ from elichika.parser import core from elichika.parser import config -def generate_copied_value(value : 'values.Value'): - assert(isinstance(value,values.Value)) + +def generate_copied_value(value: 'values.Value'): + assert(isinstance(value, values.Value)) if isinstance(value, values.NumberValue): copied = values.NumberValue(value.internal_value) @@ -64,15 +66,16 @@ def generate_copied_value(value : 'values.Value'): return values.Value() -def generate_tensor_value_with_undefined_shape_size(value : 'values.TensorValue'): + +def generate_tensor_value_with_undefined_shape_size(value: 'values.TensorValue'): assert(isinstance(value, values.TensorValue)) ret = values.TensorValue() ret.shape = tuple([-1 for v in value.shape]) return ret -def generate_value_with_same_type(value : 'values.Value'): - assert(isinstance(value,values.Value)) +def generate_value_with_same_type(value: 'values.Value'): + assert(isinstance(value, values.Value)) ret = None if isinstance(value, values.TensorValue): ret = values.TensorValue() @@ -113,22 +116,25 @@ def generate_value_with_same_type(value : 'values.Value'): return ret + class FunctionArgInput(): def __init__(self): self.inputs = [] self.keywords = {} + class FunctionArg(): - def __init__(self, name : 'str' = '', obj : 'values.ValueRef' = None): + def __init__(self, name: 'str' = '', obj: 'values.ValueRef' = None): self.name = name self.obj = obj + class FunctionArgCollection(): def __init__(self): - self.args = {} # Dict[str,FunctionArg] + self.args = {} # Dict[str,FunctionArg] self.args_list = [] - def add_arg(self, fa : 'FunctionArg'): + def add_arg(self, fa: 'FunctionArg'): self.args_list.append(fa) self.args[fa.name] = fa @@ -150,17 +156,17 @@ def analyze_args(self, func): fa.name = v.name fa.obj = values.parse_instance(None, v.name, v.default) self.add_arg(fa) - - def merge_inputs(self, inputs : 'FunctionArgInput') -> 'FunctionArgCollection': + + def merge_inputs(self, inputs: 'FunctionArgInput') -> 'FunctionArgCollection': ret = FunctionArgCollection() - + for fa in self.get_args(): ret.add_arg(fa) for i in range(len(inputs.inputs)): ret.args_list[i].obj = inputs.inputs[i] - - for k,v in inputs.keywords.items(): + + for k, v in inputs.keywords.items(): if k in ret.args.keys(): ret.args[k].obj = v @@ -183,6 +189,7 @@ def get_args(self) -> 'List[FunctionArg]': ret.append(FunctionArg(fa.name, fa.obj)) return ret + class FunctionBase(): def __init__(self): self.name = '' @@ -191,7 +198,7 @@ def __init__(self): self.args = 
FunctionArgCollection() self.base_func = None - + def parse_args(self, args): funcArgs = self.funcArgs.copy() @@ -230,13 +237,14 @@ def analyze_args(self, func): self.funcArgs.append(fa) def get_values(self, args): - assert(all([isinstance(arg.obj,values.ValueRef) for arg in args])) + assert(all([isinstance(arg.obj, values.ValueRef) for arg in args])) return [arg.obj.get_value() for arg in args] - def vcall(self, module : 'values.Field', graph : 'core.Graph', inst : 'values.Value', args = [], line = -1): + def vcall(self, module: 'values.Field', graph: 'core.Graph', inst: 'values.Value', args=[], line=-1): return None + class UserDefinedClassConstructorFunction(FunctionBase): def __init__(self, classinfo): super().__init__() @@ -257,8 +265,9 @@ def __init__(self, classinfo): self.ast = gast.ast_to_gast(ast.parse(code)).body[0] - def vcall(self, module : 'values.Field', graph : 'graphs.Graph', inst : 'values.ValueRef', args : 'FunctionArgInput', line = -1): - ret = values.ValueRef(values.UserDefinedInstance(module, None, self.classinfo)) + def vcall(self, module: 'values.Field', graph: 'graphs.Graph', inst: 'values.ValueRef', args: 'FunctionArgInput', line=-1): + ret = values.ValueRef(values.UserDefinedInstance( + module, None, self.classinfo)) inst = ret func_field = values.Field() @@ -279,6 +288,7 @@ def vcall(self, module : 'values.Field', graph : 'graphs.Graph', inst : 'values. return ret + class UserDefinedFunction(FunctionBase): def __init__(self, func): super().__init__() @@ -293,7 +303,7 @@ def __init__(self, func): self.ast = gast.ast_to_gast(ast.parse(code)).body[0] - def vcall(self, module : 'values.Field', graph : 'core.Graph', inst : 'values.ValueRef', args : 'FunctionArgInput', line = -1): + def vcall(self, module: 'values.Field', graph: 'core.Graph', inst: 'values.ValueRef', args: 'FunctionArgInput', line=-1): func_field = values.Field() func_field.set_module(module) diff --git a/elichika/elichika/parser/functions_builtin.py b/elichika/elichika/parser/functions_builtin.py index 04c0f78e..7510db4d 100644 --- a/elichika/elichika/parser/functions_builtin.py +++ b/elichika/elichika/parser/functions_builtin.py @@ -8,6 +8,7 @@ import chainer.functions as F import chainer.links as L + class ChainerFunction(functions.FunctionBase): def __init__(self, func): super().__init__() @@ -15,38 +16,42 @@ def __init__(self, func): self.analyze_args(func) self.base_func = func - def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', args : 'functions.FunctionArgInput', line = -1): + def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1): funcArgs = self.args.merge_inputs(args) vargs = funcArgs.get_values() node = nodes.NodeCall(self, vargs, line) graph.add_node(node) #value = functions.generate_value_with_same_type(vargs[0]) - value = values.TensorValue() + value = values.TensorValue() value.name = '@F.{}.{}'.format(line, self.name) node.set_outputs([value]) return values.ValueRef(value) + class RangeFunction(functions.FunctionBase): def __init__(self): super().__init__() self.name = 'range' - def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', args : 'functions.FunctionArgInput', line = -1): - node = nodes.NodeGenerate('range', [v.get_value() for v in args.inputs], line) + def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1): + node = nodes.NodeGenerate( + 'range', [v.get_value() for v in args.inputs], line) 
graph.add_node(node) value = values.RangeValue() value.name = '@F.{}.{}'.format(line, self.name) node.set_outputs([value]) return values.ValueRef(value) + class ListFunction(functions.FunctionBase): def __init__(self): super().__init__() self.name = 'list' - self.args.add_arg(functions.FunctionArg('value', values.ValueRef(values.NoneValue()))) + self.args.add_arg(functions.FunctionArg( + 'value', values.ValueRef(values.NoneValue()))) - def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', args : 'functions.FunctionArgInput', line = -1): + def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1): funcArgs = self.args.merge_inputs(args) vargs = funcArgs.get_values() value = values.ListValue() @@ -62,6 +67,7 @@ def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', arg node.set_outputs([value]) return values.ValueRef(value) + class AppendFunction(functions.FunctionBase): def __init__(self, owner): super().__init__() @@ -69,7 +75,7 @@ def __init__(self, owner): self.owner = owner self.args.add_arg(functions.FunctionArg('elmnt')) - def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', args : 'functions.FunctionArgInput', line = -1): + def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1): funcArgs = self.args.merge_inputs(args) vargs = funcArgs.get_values() @@ -85,6 +91,7 @@ def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', arg graph.add_node(node) return values.NoneValue() + class NDArrayFunction(functions.FunctionBase): def __init__(self): super().__init__() @@ -120,7 +127,7 @@ def __init__(self): fa.obj = values.ValueRef(values.NumberValue(0)) self.args.add_arg(fa) - def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', args : 'functions.FunctionArgInput', line = -1): + def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1): funcArgs = self.args.merge_inputs(args) vargs = funcArgs.get_values() @@ -140,6 +147,7 @@ def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', arg node.set_outputs([value]) return values.ValueRef(value) + class NDArrayShapeFunction(functions.FunctionBase): def __init__(self, owner): super().__init__() @@ -147,7 +155,7 @@ def __init__(self, owner): self.owner = owner self.is_property = True - def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', args : 'functions.FunctionArgInput', line = -1): + def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1): node = nodes.NodeCall(self, [inst.get_value()], line) value = values.ListValue() @@ -156,4 +164,4 @@ def vcall(self, module : 'Field', graph : 'Graph', inst : 'values.ValueRef', arg # TODO should make tuple graph.add_node(node) - return values.ValueRef(value) \ No newline at end of file + return values.ValueRef(value) diff --git a/elichika/elichika/parser/graphs.py b/elichika/elichika/parser/graphs.py index 43b5629e..02fddc73 100644 --- a/elichika/elichika/parser/graphs.py +++ b/elichika/elichika/parser/graphs.py @@ -1,6 +1,7 @@ from elichika.parser import nodes from elichika.parser import utils + class Graph: def __init__(self): self.name = '' @@ -18,5 +19,5 @@ def add_output_value(self, value): value = nodes.filter_tuple(value) self.output_values.append(value) - def add_node(self, node : 'nodes.Node'): + def 
add_node(self, node: 'nodes.Node'): self.nodes.append(node) diff --git a/elichika/elichika/parser/nodes.py b/elichika/elichika/parser/nodes.py index baf93c61..eb621d8b 100644 --- a/elichika/elichika/parser/nodes.py +++ b/elichika/elichika/parser/nodes.py @@ -10,18 +10,21 @@ from elichika.parser import functions from elichika.parser import utils + class BinOpType(Enum): Add = 0, Sub = 1, Mul = 2, Unknown = 255, + class UnaryOpType(Enum): UAdd = 0, USub = 1, Not = 2, Unknown = 255, + class CompareType(Enum): Eq = 0, NotEq = 1, @@ -33,6 +36,7 @@ class CompareType(Enum): IsNot = 7, unknown = 255, + class Node: def __init__(self, line): self.inputs = [] @@ -65,6 +69,7 @@ def set_outputs(self, outputs): for output in self.outputs: output.generator = self + def filter_tuple(value): if isinstance(value, list): for i in range(len(value)): @@ -84,8 +89,9 @@ def filter_tuple(value): return values.TupleValue(vs) return value + class NodeCopy(Node): - def __init__(self, value : 'values.Value', line = -1): + def __init__(self, value: 'values.Value', line=-1): super().__init__(line) value = filter_tuple(value) @@ -95,8 +101,9 @@ def __init__(self, value : 'values.Value', line = -1): def __str__(self): return 'Copy({})'.format(self.lineprop) + class NodeNonVolatileAssign(Node): - def __init__(self, target_value : 'values.Value', value : 'values.Value', line = -1): + def __init__(self, target_value: 'values.Value', value: 'values.Value', line=-1): super().__init__(line) target_value = filter_tuple(target_value) value = filter_tuple(value) @@ -109,11 +116,12 @@ def __init__(self, target_value : 'values.Value', value : 'values.Value', line = def __str__(self): return 'NodeNonVolatileAssign({})'.format(self.lineprop) + class NodeAssign(Node): - def __init__(self, attr : 'values.Attribute', obj : 'values.ValueRef', line = -1): - assert(isinstance(obj,values.ValueRef)) + def __init__(self, attr: 'values.Attribute', obj: 'values.ValueRef', line=-1): + assert(isinstance(obj, values.ValueRef)) super().__init__(line) - + self.targets = [] self.objects = [] @@ -123,8 +131,9 @@ def __init__(self, attr : 'values.Attribute', obj : 'values.ValueRef', line = -1 def __str__(self): return 'Assign({})'.format(self.lineprop) + class NodeAugAssign(Node): - def __init__(self, target : 'values.Value', value : 'values.Value', binop : 'BinOp', line = -1): + def __init__(self, target: 'values.Value', value: 'values.Value', binop: 'BinOp', line=-1): super().__init__(line) target = filter_tuple(target) @@ -140,8 +149,9 @@ def __init__(self, target : 'values.Value', value : 'values.Value', binop : 'Bin def __str__(self): return 'AugAssign({})'.format(self.lineprop) + class NodeBinOp(Node): - def __init__(self, left : 'values.Value', right : 'values.Value', binop : 'BinOp', line = -1): + def __init__(self, left: 'values.Value', right: 'values.Value', binop: 'BinOp', line=-1): super().__init__(line) left = filter_tuple(left) @@ -157,8 +167,9 @@ def __init__(self, left : 'values.Value', right : 'values.Value', binop : 'BinOp def __str__(self): return 'BinOp({},{})'.format(self.lineprop, self.binop) + class NodeUnaryOp(Node): - def __init__(self, operand : 'values.Value', unaryop : 'UnaryOpType', line = -1): + def __init__(self, operand: 'values.Value', unaryop: 'UnaryOpType', line=-1): super().__init__(line) operand = filter_tuple(operand) @@ -170,8 +181,9 @@ def __init__(self, operand : 'values.Value', unaryop : 'UnaryOpType', line = -1) def __str__(self): return 'UnaryOp({},{})'.format(self.lineprop, self.unaryop) + class 
NodeCompare(Node): - def __init__(self, left : 'values.Value', right : 'values.Value', compare : 'CompareType', line = -1): + def __init__(self, left: 'values.Value', right: 'values.Value', compare: 'CompareType', line=-1): super().__init__(line) left = filter_tuple(left) right = filter_tuple(right) @@ -186,8 +198,9 @@ def __init__(self, left : 'values.Value', right : 'values.Value', compare : 'Com def __str__(self): return 'Compare({},{})'.format(self.lineprop, self.compare) + class NodeGetItem(Node): - def __init__(self, target : "values.Value", indexes, line = -1): + def __init__(self, target: "values.Value", indexes, line=-1): super().__init__(line) target = filter_tuple(target) @@ -200,8 +213,9 @@ def __init__(self, target : "values.Value", indexes, line = -1): def __str__(self): return 'GetItem({})'.format(self.lineprop) + class NodeSlice(Node): - def __init__(self, target : "values.Value", indices, slice_specs, line = -1): + def __init__(self, target: "values.Value", indices, slice_specs, line=-1): super().__init__(line) target = filter_tuple(target) @@ -215,15 +229,16 @@ def __init__(self, target : "values.Value", indices, slice_specs, line = -1): def __str__(self): return 'Slice({})'.format(self.lineprop) + class NodeCall(Node): - def __init__(self, func : 'Function', args, line = -1): + def __init__(self, func: 'Function', args, line=-1): super().__init__(line) args = filter_tuple(args) self.func = func self.args = args self.inputs.extend(self.args) - self.fargs = None # functions.FunctionArgCollection + self.fargs = None # functions.FunctionArgCollection def __str__(self): if self.func is not None and isinstance(self.func, values.FuncValue): @@ -233,8 +248,9 @@ def __str__(self): else: return 'Call({}, {})'.format(self.lineprop, 'Unknown') + class NodeReturn(Node): - def __init__(self, value, line = -1): + def __init__(self, value, line=-1): super().__init__(line) value = filter_tuple(value) @@ -244,8 +260,9 @@ def __init__(self, value, line = -1): def __str__(self): return 'Return({})'.format(self.lineprop) + class NodeIf(Node): - def __init__(self, cond, input_values, true_graph, false_graph, line = -1): + def __init__(self, cond, input_values, true_graph, false_graph, line=-1): super().__init__(line) cond = filter_tuple(cond) input_values = filter_tuple(input_values) @@ -265,8 +282,9 @@ def __init__(self, cond, input_values, true_graph, false_graph, line = -1): def __str__(self): return 'If({})'.format(self.lineprop) + class NodeFor(Node): - def __init__(self, iter_value, input_values, body_graph, line = -1): + def __init__(self, iter_value, input_values, body_graph, line=-1): super().__init__(line) iter_value = filter_tuple(iter_value) input_values = filter_tuple(input_values) @@ -282,8 +300,9 @@ def __init__(self, iter_value, input_values, body_graph, line = -1): def __str__(self): return 'For({})'.format(self.lineprop) + class NodeForGenerator(Node): - def __init__(self, counter_value, iter_value, line = -1): + def __init__(self, counter_value, iter_value, line=-1): super().__init__(line) counter_value = filter_tuple(counter_value) iter_value = filter_tuple(iter_value) @@ -296,8 +315,9 @@ def __init__(self, counter_value, iter_value, line = -1): def __str__(self): return 'ForGen({})'.format(self.lineprop) + class NodeListcomp(Node): - def __init__(self, iter_value, input_values, body_graph, line = -1): + def __init__(self, iter_value, input_values, body_graph, line=-1): super().__init__(line) input_values = filter_tuple(input_values) iter_value = filter_tuple(iter_value) 
@@ -313,21 +333,23 @@ def __init__(self, iter_value, input_values, body_graph, line = -1):
 
     def __str__(self):
         return 'Listcomp({})'.format(self.lineprop)
 
+
 class NodeGenerate(Node):
-    def __init__(self, classtype, args, line = -1):
+    def __init__(self, classtype, args, line=-1):
         super().__init__(line)
         args = filter_tuple(args)
         self.classtype = classtype
         self.args = args
         self.extend_inputs(self.args)
-        self.fargs = None # functions.FunctionArgCollection
+        self.fargs = None  # functions.FunctionArgCollection
 
     def __str__(self):
         return 'Generate({},{})'.format(self.classtype, self.lineprop)
 
+
 class NodeConvert(Node):
-    def __init__(self, classtype, value, line = -1):
+    def __init__(self, classtype, value, line=-1):
         super().__init__(line)
 
         value = filter_tuple(value)
@@ -337,4 +359,3 @@ def __init__(self, classtype, value, line = -1):
 
     def __str__(self):
         return 'Convert({},{})'.format(self.classtype, self.lineprop)
-
diff --git a/elichika/elichika/parser/utils.py b/elichika/elichika/parser/utils.py
index 644a404a..77216c0b 100644
--- a/elichika/elichika/parser/utils.py
+++ b/elichika/elichika/parser/utils.py
@@ -5,16 +5,19 @@
 
 slice_int_max = 2 ** 31 - 1
 
+
 def get_guid():
     global current_id
     id = current_id
     current_id += 1
     return id
 
+
 def reset_guid():
     global current_id
     current_id = 0
 
+
 def numpy_type_2_int(t):
     if t == np.int32:
         return 0
@@ -22,6 +25,7 @@ def numpy_type_2_int(t):
         return 1
     assert(False)
 
+
 def int_2_numpy_type(n):
     if n == 0:
         return np.int32
@@ -29,16 +33,19 @@ def int_2_numpy_type(n):
         return np.float32
     assert(False)
 
+
 def create_obj_value_name_with_constant(value):
     return '@C_' + str(value)
 
-def create_obj_value_name_with_attribute(name : "str", pre_name : "str"):
+
+def create_obj_value_name_with_attribute(name: "str", pre_name: "str"):
     if len(pre_name) > 0 and pre_name[0] != '@':
         return pre_name
     else:
         return name
 
-def clip_head(s : 'str'):
+
+def clip_head(s: 'str'):
     s = s.split('\n')
     # print(s)
     hs = os.path.commonprefix(list(filter(lambda x: x != '', s)))
@@ -47,8 +54,9 @@ def clip_head(s : 'str'):
     s = map(lambda x: x[ls:], s)
     return '\n'.join(s)
 
+
 class LineProperty():
-    def __init__(self, lineno = -1, filename = ''):
+    def __init__(self, lineno=-1, filename=''):
         self.lineno = lineno
         self.filename = filename
 
diff --git a/elichika/elichika/parser/values.py b/elichika/elichika/parser/values.py
index f152a039..88d055b1 100644
--- a/elichika/elichika/parser/values.py
+++ b/elichika/elichika/parser/values.py
@@ -5,7 +5,8 @@
 import numpy as np
 import inspect
 
-import ast, gast
+import ast
+import gast
 import weakref
 from elichika.parser import vevaluator
 from elichika.parser import core
@@ -22,21 +23,25 @@
 
 instance_converters = []
 
+
 def reset_field_and_attributes():
     global fields
     fields = []
     histories.clear()
 
-def register_field(field : 'Field'):
+
+def register_field(field: 'Field'):
     fields.append(weakref.ref(field))
 
-def push_history(history_id : 'str'):
+
+def push_history(history_id: 'str'):
     histories.append(history_id)
     for field in fields:
         o = field()
         if o is not None:
             o.push_history(history_id)
 
+
 def pop_history():
     histories.pop()
     for field in fields:
@@ -44,23 +49,26 @@ def pop_history():
         if o is not None:
             o.pop_history()
 
+
 def get_inputs() -> 'List[FieldInput]':
-    ret = [] 
+    ret = []
     for field in fields:
         o = field()
         if o is not None:
             ret += o.get_inputs()
     return ret
 
+
 def get_outputs() -> 'List[FieldOutput]':
-    ret = [] 
+    ret = []
     for field in fields:
         o = field()
         if o is not None:
             ret += o.get_outputs()
     return ret
 
-def parse_instance(default_module, name, instance, self_instance = None, parse_shape = False) -> "ValueRef":
+
+def parse_instance(default_module, name, instance, self_instance=None, parse_shape=False) -> "ValueRef":
     for converter in instance_converters:
         ret = converter(default_module, instance)
 
@@ -141,9 +149,11 @@ def parse_instance(default_module, name, instance, self_instance = None, parse_s
     if instance is None:
         return ValueRef(NoneValue())
 
-    model_inst = UserDefinedInstance(default_module, instance, None, isinstance(instance, chainer.Link))
+    model_inst = UserDefinedInstance(
+        default_module, instance, None, isinstance(instance, chainer.Link))
     return ValueRef(model_inst)
 
+
 class FieldInput:
     def __init__(self):
         self.input_value = None
@@ -151,6 +161,7 @@ def __init__(self):
         self.name = None
         self.value = None
 
+
 class FieldOutput:
     def __init__(self):
         self.field = None
@@ -159,14 +170,15 @@ def __init__(self):
         self.old_value = None
         self.value = None
 
+
 class FieldAttributeCollection():
-    def __init__(self, id : 'str', parent: 'FieldAttributeCollection'):
+    def __init__(self, id: 'str', parent: 'FieldAttributeCollection'):
         self.id = id
         self.parent = parent
         self.attributes = {}
         self.inputs = {}
 
-    def try_get_attribute(self, key : 'str'):
+    def try_get_attribute(self, key: 'str'):
         if key in self.attributes.keys():
             return self.attributes[key]
 
@@ -188,25 +200,26 @@ def try_get_attribute(self, key : 'str'):
 
             return attribute
 
         # input
-        
+
         #value = parent_attribute.get_ref().get_value()
         #copied_value = functions.generate_copied_value(value)
-        #attribute.revise(ValueRef(copied_value))
+        # attribute.revise(ValueRef(copied_value))
         #self.attributes[key] = attribute
 
         attribute.revise(parent_attribute.get_ref())
         self.attributes[key] = attribute
-        self.inputs[attribute] = (attribute.get_ref(), attribute.get_ref().get_value(), attribute.get_ref().get_value(), attribute.get_ref().get_value())
+        self.inputs[attribute] = (attribute.get_ref(), attribute.get_ref().get_value(
+        ), attribute.get_ref().get_value(), attribute.get_ref().get_value())
 
         return attribute
 
     def pop_history(self):
         for att, input in self.inputs.items():
             input[0].revise(input[1])
-        
+
     def get_inputs(self) -> 'List[FieldInput]':
         '''
        return [(input value, copied input value)]
@@ -245,6 +258,7 @@ def get_outputs(self) -> 'List[FieldOutput]':
 
         return ret
 
+
 class Field():
     def __init__(self):
         self.collection = FieldAttributeCollection('', None)
@@ -273,10 +287,10 @@ def has_attribute(self, key) -> 'Boolean':
             if key in c.attributes.keys():
                 return True
             c = c.parent
-        
+
         return False
 
-    def get_attribute(self, key : 'str', from_module = True) -> 'Attribute':
+    def get_attribute(self, key: 'str', from_module=True) -> 'Attribute':
         attribute = self.collection.try_get_attribute(key)
 
         if attribute is not None:
@@ -294,7 +308,7 @@ def get_attribute(self, key : 'str', from_module = True) -> 'Attribute':
         self.collection.attributes[key] = attribute
         return attribute
 
-    def push_history(self, history_id : 'str'):
+    def push_history(self, history_id: 'str'):
         collection = FieldAttributeCollection(history_id, self.collection)
         self.collection = collection
 
@@ -314,7 +328,7 @@ def get_outputs(self):
     def set_default_value(self, key, value):
         attribute = self.get_attribute(key)
         attribute.revise(value)
-        
+
     def set_predefined_obj(self, key, obj):
         collections = []
         c = self.collection
@@ -339,15 +353,17 @@ def set_predefined_obj(self, key, obj):
             if isinstance(obj.get_value(), Instance) or isinstance(obj.get_value(), FuncValue) or isinstance(obj.get_value(), ModuleValue):
                 continue
 
-            collection.inputs[attribute] = (attribute.get_ref(), attribute.get_ref().get_value(), attribute.get_ref().get_value(), attribute.get_ref().get_value())
+            collection.inputs[attribute] = (attribute.get_ref(), attribute.get_ref(
+            ).get_value(), attribute.get_ref().get_value(), attribute.get_ref().get_value())
 
-            # if old_value is not None: 
+            # if old_value is not None:
             # collection.inputs[attribute] = (attribute.get_ref(), attribute.get_ref().get_value(), old_value, value)
 
         #old_value = obj.get_value()
         #value = functions.generate_copied_value(old_value)
         #obj = ValueRef(value)
 
+
 class Module(Field):
     def __init__(self, module):
         super().__init__()
@@ -389,12 +405,14 @@ def set_default_value(self, key, value):
         attribute = super().get_attribute(key)
         attribute.revise(value)
 
+
 class AttributeHistory:
-    def __init__(self, obj : 'ValueRef'):
+    def __init__(self, obj: 'ValueRef'):
         self.obj = obj
 
+
 class Attribute:
-    def __init__(self, name : 'str'):
+    def __init__(self, name: 'str'):
         self.name = name
         self.history = []
         self.parent = None
@@ -405,12 +423,14 @@ def __init__(self, name : 'str'):
         # if it is non-volatile, an object in this attribute is saved after running
         self.is_non_volatile = False
 
-    def revise(self, obj : 'ValueRef'):
+    def revise(self, obj: 'ValueRef'):
         assert(isinstance(obj, ValueRef))
 
         # assgin name to the object
-        obj.name = utils.create_obj_value_name_with_attribute(self.name, obj.name)
-        obj.get_value().name = utils.create_obj_value_name_with_attribute(self.name, obj.get_value().name)
+        obj.name = utils.create_obj_value_name_with_attribute(
+            self.name, obj.name)
+        obj.get_value().name = utils.create_obj_value_name_with_attribute(
+            self.name, obj.get_value().name)
 
         if self.initial_obj is None:
             self.initial_obj = obj
@@ -421,19 +441,21 @@ def revise(self, obj : 'ValueRef'):
 
     def has_obj(self):
         return len(self.history) > 0
 
-    def get_ref(self, inc_access = True):
+    def get_ref(self, inc_access=True):
         assert len(self.history) > 0
         return self.history[-1].obj
 
     def __str__(self):
         return self.name
 
+
 class ValueRefHistory():
     def __init__(self, value):
         self.value = value
 
+
 class ValueRef():
-    def __init__(self, value : 'Value'):
+    def __init__(self, value: 'Value'):
         self.name = ""
         self.value = value
         self.id = utils.get_guid()
@@ -449,7 +471,7 @@ def get_value(self) -> 'Value':
     def revise(self, value):
         self.value = value
 
-    def try_get_and_store_obj(self, name : 'str') -> 'ValueRef':
+    def try_get_and_store_obj(self, name: 'str') -> 'ValueRef':
         attribute = self.attributes.get_attribute(name)
 
         if attribute.has_obj():
@@ -463,6 +485,7 @@ def try_get_and_store_obj(self, name : 'str') -> 'ValueRef':
         self.attributes.set_predefined_obj(name, obj)
         return obj
 
+
 class Value():
     def __init__(self):
         self.name = ""
@@ -470,19 +493,20 @@ def __init__(self):
         self.internal_value = None
         self.id = utils.get_guid()
 
-    def apply_to_object(self, obj : 'ValueRef'):
+    def apply_to_object(self, obj: 'ValueRef'):
         '''
         register functions to an object
         this function is only called when an object is generated
         '''
         return None
 
-    def try_get_ref(self, name : 'str', inst : 'ValueRef') -> 'ValueRef':
+    def try_get_ref(self, name: 'str', inst: 'ValueRef') -> 'ValueRef':
         return None
 
     def __str__(self):
         return self.name
 
+
 class NoneValue(Value):
     def __init__(self):
         super().__init__()
@@ -490,6 +514,7 @@ def __init__(self):
     def __str__(self):
         return self.name + '({})'.format('None')
 
+
 class NumberValue(Value):
     def __init__(self, number):
         super().__init__()
@@ -501,6 +526,7 @@ def __str__(self):
             return self.name + '(N.{})'.format('Any')
         return self.name + '(N.{})'.format(self.internal_value)
 
+
 class StrValue(Value):
     def __init__(self, string):
         super().__init__()
@@ -511,6 +537,7 @@ def __str__(self):
             return self.name + '(S.{})'.format('Any')
         return self.name + '(S.{})'.format(self.internal_value)
 
+
 class BoolValue(Value):
     def __init__(self, b):
         super().__init__()
@@ -521,40 +548,49 @@ def __str__(self):
             return self.name + '(B.{})'.format('Any')
         return self.name + '(B.{})'.format(self.internal_value)
 
+
 class RangeValue(Value):
     def __init__(self):
         super().__init__()
+
     def __str__(self):
         return self.name + '(R)'
 
+
 class TupleValue(Value):
-    def __init__(self, values = None):
+    def __init__(self, values=None):
         super().__init__()
         self.internal_value = values
+
     def __str__(self):
         return self.name + '(Tp{})'
 
+
 class FuncValue(Value):
-    def __init__(self, func : 'functions.FunctionBase', obj : 'ValueRef'):
+    def __init__(self, func: 'functions.FunctionBase', obj: 'ValueRef'):
         super().__init__()
         self.func = func
         self.obj = obj
+
     def __str__(self):
         return self.name + '(F)'
 
+
 class ListValue(Value):
-    def __init__(self, values = None):
+    def __init__(self, values=None):
         super().__init__()
         self.is_any = values is None
         self.values = []
 
-    def apply_to_object(self, obj : 'ValueRef'):
-        append_func = ValueRef(FuncValue(functions_builtin.AppendFunction(self), obj))
+    def apply_to_object(self, obj: 'ValueRef'):
+        append_func = ValueRef(
+            FuncValue(functions_builtin.AppendFunction(self), obj))
         obj.attributes.get_attribute('append').revise(append_func)
 
     def __str__(self):
         return self.name + '(L)'
 
+
 class ModuleValue(Value):
     def __init__(self):
         super().__init__()
@@ -562,6 +598,7 @@ def __init__(self):
     def __str__(self):
         return self.name + '(M)'
 
+
 class DictValue(Value):
     def __init__(self):
         super().__init__()
@@ -569,6 +606,7 @@ def __init__(self):
     def __str__(self):
         return self.name + '(D)'
 
+
 class TensorValue(Value):
     def __init__(self):
         super().__init__()
@@ -576,20 +614,23 @@ def __init__(self):
         self.value = None
         self.dtype = None
 
-    def apply_to_object(self, obj : 'ValueRef'):
-        shape_func = ValueRef(FuncValue(functions_builtin.NDArrayShapeFunction(self), obj))
+    def apply_to_object(self, obj: 'ValueRef'):
+        shape_func = ValueRef(
+            FuncValue(functions_builtin.NDArrayShapeFunction(self), obj))
         obj.attributes.get_attribute('shape').revise(shape_func)
 
     def __str__(self):
         return self.name + '(T.{})'.format(self.shape)
 
+
 class Type(Value):
-    def __init__(self, name : 'str'):
+    def __init__(self, name: 'str'):
         super().__init__()
         self.name = name
 
+
 class Instance(Value):
-    def __init__(self, module : 'Field', inst, classinfo):
+    def __init__(self, module: 'Field', inst, classinfo):
         super().__init__()
         self.inst = inst
         self.callable = False
@@ -597,18 +638,19 @@ def __init__(self, module : 'Field', inst, classinfo):
         self.module = module
         self.classinfo = classinfo
 
+
 class UserDefinedInstance(Instance):
-    def __init__(self, module : 'Field', inst, classinfo, is_chainer_link = False):
+    def __init__(self, module: 'Field', inst, classinfo, is_chainer_link=False):
         super().__init__(module, inst, classinfo)
         self.is_chainer_link = is_chainer_link
         if self.is_chainer_link:
             self.callable = True
 
-    def apply_to_object(self, obj : 'ValueRef'):
+    def apply_to_object(self, obj: 'ValueRef'):
         if self.is_chainer_link:
             self.func = obj.try_get_and_store_obj('forward')
 
-    def try_get_ref(self, name : 'str', inst : 'ValueRef') -> 'ValueRef':
+    def try_get_ref(self, name: 'str', inst: 'ValueRef') -> 'ValueRef':
         obj = None
         if self.inst is not None:
             if not hasattr(self.inst, name):
@@ -628,4 +670,4 @@ def try_get_ref(self, name : 'str', inst : 'ValueRef') -> 'ValueRef':
 
             obj = parse_instance(self.module, name, members_dict[name], inst)
 
-        return obj
\ No newline at end of file
+        return obj
diff --git a/elichika/elichika/parser/values_builtin.py b/elichika/elichika/parser/values_builtin.py
index b4377369..78986eb9 100644
--- a/elichika/elichika/parser/values_builtin.py
+++ b/elichika/elichika/parser/values_builtin.py
@@ -7,32 +7,39 @@
 
 chainer_links = {}
 
+
 class ChainerLinkDefinition:
-    def __init__(self, estimate_shape = None):
+    def __init__(self, estimate_shape=None):
         self.estimate_shape = estimate_shape
 
-def estimate_linear_shape(inst : 'chainer.links.Linear', args : 'functions.FunctionArgInput'):
+
+def estimate_linear_shape(inst: 'chainer.links.Linear', args: 'functions.FunctionArgInput'):
     if isinstance(args.inputs[0].get_value(), values.TensorValue) and len(args.inputs[0].get_value().shape) >= 2:
         return (args.inputs[0].get_value().shape[0], inst.out_size)
     return ()
 
-def estimate_convolution2D_shape(inst : 'chainer.links.Convolution2D', args : 'functions.FunctionArgInput'):
+
+def estimate_convolution2D_shape(inst: 'chainer.links.Convolution2D', args: 'functions.FunctionArgInput'):
     return functions.generate_tensor_value_with_undefined_shape_size(args.inputs[0].get_value()).shape
 
-chainer_links[chainer.links.Linear] = ChainerLinkDefinition(estimate_linear_shape)
-chainer_links[chainer.links.Convolution2D] = ChainerLinkDefinition(estimate_convolution2D_shape)
+
+chainer_links[chainer.links.Linear] = ChainerLinkDefinition(
+    estimate_linear_shape)
+chainer_links[chainer.links.Convolution2D] = ChainerLinkDefinition(
+    estimate_convolution2D_shape)
+
 
 def is_builtin_chainer_link(value) -> 'bool':
     return type(value) in chainer_links.keys()
 
+
 class ChainerLinkFunction(functions.FunctionBase):
     def __init__(self, owner):
         super().__init__()
         self.name = '__call__'
         self.owner = owner
 
-    def vcall(self, module : 'values.Field', graph : 'Graph', inst : 'Object', args : 'functions.FunctionArgInput', line = -1):
+    def vcall(self, module: 'values.Field', graph: 'Graph', inst: 'Object', args: 'functions.FunctionArgInput', line=-1):
         node = nodes.NodeCall(self, [v.get_value() for v in args.inputs], line)
         graph.add_node(node)
         value = values.TensorValue()
@@ -44,11 +51,13 @@ def vcall(self, module : 'values.Field', graph : 'Graph', inst : 'Object', args
         node.set_outputs([value])
         return values.ValueRef(value)
 
+
 class ChainerLinkInstance(values.Instance):
-    def __init__(self, module : 'Field', inst):
+    def __init__(self, module: 'Field', inst):
         super().__init__(module, inst, None)
         self.callable = True
 
-    def apply_to_object(self, obj : 'values.ValueRef'):
-        self.func = values.ValueRef(values.FuncValue(ChainerLinkFunction(self), obj))
+    def apply_to_object(self, obj: 'values.ValueRef'):
+        self.func = values.ValueRef(
+            values.FuncValue(ChainerLinkFunction(self), obj))
         obj.get_field().get_attribute('forward').revise(self.func)
diff --git a/elichika/elichika/parser/veval_bin.py b/elichika/elichika/parser/veval_bin.py
index adfdc5de..fd4a99cd 100644
--- a/elichika/elichika/parser/veval_bin.py
+++ b/elichika/elichika/parser/veval_bin.py
@@ -2,7 +2,8 @@
 from elichika.parser import values
 from elichika.parser import functions
 
-def veval(op : 'nodes.BinOpType', left : 'values.Value', right : 'values.Value'):
+
+def veval(op: 'nodes.BinOpType', left: 'values.Value', right: 'values.Value'):
 
     if isinstance(left, values.NumberValue) and isinstance(right, values.NumberValue):
         return functions.generate_value_with_same_type(left)
diff --git a/elichika/elichika/parser/veval_unary.py b/elichika/elichika/parser/veval_unary.py
index 5ab75641..854cd4b0 100644
--- a/elichika/elichika/parser/veval_unary.py
+++ b/elichika/elichika/parser/veval_unary.py
@@ -2,7 +2,8 @@
 from elichika.parser import values
 from elichika.parser import functions
 
-def veval(op : 'nodes.UnaryOpType', value : 'values.Value'):
+
+def veval(op: 'nodes.UnaryOpType', value: 'values.Value'):
 
     if isinstance(value, values.NumberValue):
         if value.internal_value is not None:
diff --git a/elichika/elichika/parser/visualizer.py b/elichika/elichika/parser/visualizer.py
index 7d9526df..adff9d9c 100644
--- a/elichika/elichika/parser/visualizer.py
+++ b/elichika/elichika/parser/visualizer.py
@@ -3,9 +3,11 @@
 from elichika.parser.core import convert_model, Graph
 from elichika.parser.nodes import Node
 
+
 def get_valids(list_):
     return [l for l in list_ if l is not None]
 
+
 node_id = 0
 node2id = {}
 node_ref_count = {}
@@ -15,6 +17,7 @@ def get_valids(list_):
 
 graph_id = 0
 
+
 def reset():
     global node_id
     global node2id
@@ -34,7 +37,8 @@ def reset():
 
     graph_id = 0
 
-def assign_id(graph : 'Graph'):
+
+def assign_id(graph: 'Graph'):
     global node_id
     global node2id
     global value_id
@@ -51,7 +55,8 @@ def assign_id(graph : 'Graph'):
         for subgraph in node.subgraphs:
             assign_id(subgraph)
 
-def count_ref(graph : 'Graph'):
+
+def count_ref(graph: 'Graph'):
     global node_ref_count
 
     for node in graph.nodes:
@@ -65,7 +70,8 @@ def count_ref(graph : 'Graph'):
         for subgraph in node.subgraphs:
             count_ref(subgraph)
 
-def visit_edge(parent_dot, graph : 'Graph', is_unused_node_ignored):
+
+def visit_edge(parent_dot, graph: 'Graph', is_unused_node_ignored):
     global graph_id
 
     with parent_dot.subgraph(name='cluster_' + str(graph_id)) as dot:
@@ -76,10 +82,10 @@ def visit_edge(parent_dot, graph : 'Graph', is_unused_node_ignored):
 
             # ignore
             if is_unused_node_ignored:
-                if len(node.inputs) == 0 and not (node in node_ref_count): 
+                if len(node.inputs) == 0 and not (node in node_ref_count):
                     continue
 
-            dot.node(node2id[node],str(node))
+            dot.node(node2id[node], str(node))
 
             for input in node.inputs:
                 if str(input) != "":
@@ -94,7 +100,8 @@ def visit_edge(parent_dot, graph : 'Graph', is_unused_node_ignored):
         for subgraph in node.subgraphs:
             visit_edge(parent_dot, subgraph, is_unused_node_ignored)
 
-def visualize(path : 'str', graph : 'Graph', is_unused_node_ignored = True):
+
+def visualize(path: 'str', graph: 'Graph', is_unused_node_ignored=True):
     global node_id
     global node2id
     global value_id
@@ -120,4 +127,3 @@ def visualize(path : 'str', graph : 'Graph', is_unused_node_ignored = True):
 
     visit_edge(dot, graph, is_unused_node_ignored)
     dot.render(path)
-
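
For a quick sanity check of the elichika/elichika/parser/utils.py helpers touched above, here is a minimal usage sketch; it only relies on behaviour visible in the hunks (the dedent bookkeeping inside clip_head that falls outside the shown context is assumed, and the file name passed to LineProperty is made up for illustration).

# Illustrative sketch only -- not part of the patch above.
import numpy as np

from elichika.parser import utils

# clip_head() removes the indentation shared by all non-empty lines, which is
# how inline source snippets are normalised before being handed to the parser.
snippet = "    def forward(self, x):\n        return x"
print(utils.clip_head(snippet))  # prints the snippet with the common 4-space indent stripped

# numpy_type_2_int() maps a dtype to a small integer tag and int_2_numpy_type()
# maps it back; np.int32 <-> 0 is the pair visible in the diff.
assert utils.numpy_type_2_int(np.int32) == 0
assert utils.int_2_numpy_type(0) is np.int32

# get_guid() hands out increasing ids; reset_guid() restarts the counter.
utils.reset_guid()
assert (utils.get_guid(), utils.get_guid()) == (0, 1)

# LineProperty carries source-position information through the generated nodes.
prop = utils.LineProperty(lineno=12, filename='example.py')
print(prop.lineno, prop.filename)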