Mylstm elichika (WIP) #210
```python
@@ -46,6 +46,23 @@ def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args:
        return values.ValueRef(value)


class LenFunction(functions.FunctionBase):
    def __init__(self):
        super().__init__()
        self.name = 'len'

    def vcall(self, module: 'Field', graph: 'Graph', inst: 'values.ValueRef', args: 'functions.FunctionArgInput', line=-1):
        node = nodes.NodeLen(
            args.inputs[0].get_value(),  # TODO: Check this.
            line
        )
        graph.add_node(node)
        value = values.NumberValue(None)
        value.name = '@F.{}.{}'.format(line, self.name)
        node.set_outputs([value])
        return values.ValueRef(value)


class ListFunction(functions.FunctionBase):
    def __init__(self):
        super().__init__()
```

Review comment on the `# TODO: Check this.` line: Could you elaborate, please? It would be great if you could explain what needs to be done to remove this TODO comment.

Reply: It was a non-issue. I was concerned whether a new node was necessary here. As @durswd pointed out, this can be done through NodeCall itself.
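For context, here is a minimal sketch (model and names are hypothetical, not from the PR) of the kind of user code this handler makes traceable: a `len()` call inside `forward` is dispatched to `LenFunction.vcall` and lowered to a graph node instead of being executed eagerly.

```python
import chainer
import chainer.functions as F

class UsesLen(chainer.Chain):
    def forward(self, xs):
        # This builtin len() call is what LenFunction intercepts during
        # tracing, emitting a NodeLen rather than evaluating immediately.
        n = len(xs)
        return F.pad_sequence(xs), n
```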
```python
@@ -48,7 +48,7 @@ def remove_ref(value):
    converted = {}

    ret = functions.FunctionArgValueInput()

    for v in value.inputs:
        converted_v = remove_ref(v)
        ret.inputs.append(converted_v)

@@ -58,7 +58,7 @@ def remove_ref(value):
    for k,v in value.keywords.items():
        if v in converted.keys():
            keywords_[k] = converted[v]
        else:
            keywords_[k] = remove_ref(v)
    ret.keywords = keywords_
    return ret
```
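To make the intent of the hunk above concrete, here is a self-contained toy of the same memoized unwrapping pattern. All names are illustrative, not elichika's actual API: reference wrappers are unwrapped into plain values, and already-converted inputs are memoized so shared references stay shared.

```python
class Ref:
    """Stand-in for a reference wrapper around a value."""
    def __init__(self, value):
        self.value = value

def unwrap_args(inputs, keywords):
    converted = {}  # memo: id(original) -> unwrapped value

    def unwrap(v):
        if id(v) in converted:
            return converted[id(v)]
        out = v.value if isinstance(v, Ref) else v
        converted[id(v)] = out
        return out

    return ([unwrap(v) for v in inputs],
            {k: unwrap(v) for k, v in keywords.items()})
```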
```python
@@ -401,3 +401,26 @@ def __init__(self, classtype, value, line=-1):

    def __str__(self):
        return 'Convert({},{})'.format(self.classtype, self.lineprop)


class NodeLen(Node):
    def __init__(self, iter_value, line=-1):
        super().__init__(line)
        iter_value = remove_ref(iter_value)

        self.iter_value = iter_value
        self.append_inputs(self.iter_value)

    def __str__(self):
        return 'Len({})'.format(self.lineprop)


class NodeTensorAttribute(Node):
    def __init__(self, type, value, line=-1):
        super().__init__(line)
        value = remove_ref(value)

        self.value = value
        self.type = type
        self.append_inputs(self.value)

    def __str__(self):
        return 'TensorAttribute({},{})'.format(self.type, self.lineprop)
```

Review comment on `class NodeLen(Node):`: I think this can be replaced with NodeCall.
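A rough sketch of the alternative the comment suggests: instead of one Node subclass per builtin, a generic call node can carry the callee's name, so `len` needs no dedicated class. Everything here is hypothetical and only illustrates the shape of the idea, not elichika's actual NodeCall API.

```python
class GenericCallNode:
    """Hypothetical generic call node standing in for per-builtin nodes."""
    def __init__(self, func_name, inputs, line=-1):
        self.func_name = func_name   # e.g. 'len'
        self.inputs = list(inputs)
        self.line = line

    def __str__(self):
        return 'Call({},{})'.format(self.func_name, self.line)
```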
```python
@@ -258,7 +258,8 @@ def get_outputs(self) -> 'List[FieldOutput]':
        ret = []

        for key, att in self.attributes.items():
            if att.name in ['shape', 'size']:
                continue
            # instance or func
            if isinstance(att.get_ref().get_value(), Instance) or isinstance(att.get_ref().get_value(), FuncValue) or isinstance(att.get_ref().get_value(), ModuleValue):
                continue
```

Review comment: Sawada-san, could you double-check that this is the right approach for handling these attributes here? If it is, I think we need at least a comment explaining that they are handled in other code (vevaluator.py:114?).

Reply: I think this implementation causes an error in a case like the following (a very rare one): a `class A` that defines its own `size` or `shape` attribute, since `size` and `shape` are regarded here as builtin properties. It is redundant, but it can support many cases.

Reply: I wasn't aware that there was an implementation to handle NDArrayFunctions. In that case, there is a bug in the implementation: it doesn't handle the following case.

```python
xs = np.random.rand(3, 4, 5)
[x.shape for x in xs]
```

I'm sure your implementation can be extended to fix this, though.

Reply: Thank you!
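The `class A` snippet in the comment above was lost in rendering, so it can only be guessed at; here is a hypothetical reconstruction of the rare case being described, where skipping every attribute named `shape` or `size` would drop real user state.

```python
import numpy as np

class A:
    def __init__(self):
        self.shape = (1, 2)   # ordinary user data, not an ndarray property

a = A()
xs = np.random.rand(3, 4, 5)
# For `a`, skipping 'shape' discards a real attribute; for `xs[0]`,
# 'shape' is the builtin ndarray property the check actually targets.
```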
```python
@@ -512,7 +513,7 @@ def __init__(self):

    def has_constant_value(self) -> 'bool':
        return self.internal_value is not None

    def is_all_constant_values(self, is_ref_enabled = False) -> 'bool':
        return self.internal_value is not None


@@ -539,7 +540,7 @@ def __init__(self):

    def has_constant_value(self) -> 'bool':
        return True

    def is_all_constant_values(self, is_ref_enabled = False) -> 'bool':
        return True


@@ -623,7 +624,7 @@ def is_all_constant_values(self, is_ref_enabled = False) -> 'bool':
        else:
            if not v.is_all_constant_values(is_ref_enabled):
                return False
        return True

    def __str__(self):
        return self.name + '(Tp{})'
```
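The split between the two predicates in these hunks can be summarized with a small toy (illustrative semantics only, not the library's code): `has_constant_value` is a shallow check on the value itself, while `is_all_constant_values` must also recurse into the elements of container values.

```python
class ToyValue:
    def __init__(self, internal_value=None, elements=()):
        self.internal_value = internal_value
        self.elements = list(elements)

    def has_constant_value(self):
        # Shallow: only looks at this value.
        return self.internal_value is not None

    def is_all_constant_values(self, is_ref_enabled=False):
        # Deep: this value and, recursively, every contained element.
        if not self.has_constant_value():
            return False
        return all(e.is_all_constant_values(is_ref_enabled)
                   for e in self.elements)
```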
```python
@@ -44,6 +44,17 @@ def forward(self, xs, h, c, mask):
        #h = self.initial_h
        #c = self.initial_c
        inputs = F.pad_sequence(xs)
        x = None
        input = None
        gate = None
        i = None
        o = None
        f = None
        nc = None
        nh = None
        m = None
        pmask = None
        nmask = None
        for time in range(max_len):
            x = inputs[:, time]
            input = F.concat((x, h), axis=1)
```

Review comment on the block of `None` assignments: Could you add a TODO comment explaining that this is a workaround for the bug, and that these assignments should be removed?
```diff
@@ -83,7 +94,7 @@ def main():
     num_vocabs = 10
     num_hidden = 5
 
-    model_fn = lambda: MyLSTM(num_hidden, batch_size, sequence_length)
+    model_fn = MyLSTM(num_hidden, batch_size, sequence_length)
 
     labels, lengths = sequence_utils.gen_random_sequence(
         batch_size, sequence_length, num_vocabs)
```
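For background on the one-line change above: a lambda gives the test harness a factory that builds a fresh model on every call, while the direct form hands it a single shared instance; which one is correct depends on the driver's API. An illustration only, not from the PR:

```python
model_factory = lambda: MyLSTM(num_hidden, batch_size, sequence_length)
m1, m2 = model_factory(), model_factory()   # two independent parameter sets

model = MyLSTM(num_hidden, batch_size, sequence_length)  # one shared instance
```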
Review comment on `LenFunction.vcall`: Add `assert len(args.inputs) == 1` or something, just in case?

Reply: Not necessary, I think.
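For reference, the suggested guard would sit at the top of `LenFunction.vcall`; a sketch, not code from the PR:

```python
    def vcall(self, module, graph, inst, args, line=-1):
        # len() takes exactly one positional argument; fail fast otherwise.
        assert len(args.inputs) == 1
        node = nodes.NodeLen(args.inputs[0].get_value(), line)
        ...
```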