Skip to content

Commit

Permalink
Fix inplace operation in DCRNN fully-connected gates (#59)
Browse files Browse the repository at this point in the history
* Remove redundant sigmoid call
* Rename internal variable inputs_and_state
* Remove inplace operation
* Remove NotImplementedError
* Simplify reset and update gate size
  • Loading branch information
klane committed Jul 16, 2021
1 parent 3e6f282 commit 6b04b6d
Showing 1 changed file with 7 additions and 17 deletions.
24 changes: 7 additions & 17 deletions torchts/nn/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,9 @@ def __init__(
self.register_buffer("_supports", supports)

num_matrices = len(supports) * self._max_diffusion_step + 1
input_size_gconv = (self._num_units + input_dim) * num_matrices

if self._use_gc_for_ru:
input_size_ru = input_size_gconv
else:
input_size_ru = self._num_units + input_dim
raise NotImplementedError(
"Fully-connected reset and update gates not yet implemented"
)
input_size_fc = self._num_units + input_dim
input_size_gconv = input_size_fc * num_matrices
input_size_ru = input_size_gconv if self._use_gc_for_ru else input_size_fc

output_size = 2 * self._num_units
self._ru_weights = nn.Parameter(torch.empty(input_size_ru, output_size))
Expand Down Expand Up @@ -85,22 +79,18 @@ def _fc(self, inputs, state, output_size, bias_start=0.0, reset=True):
shape = (batch_size * self._num_nodes, -1)
inputs = torch.reshape(inputs, shape)
state = torch.reshape(state, shape)
inputs_and_state = torch.cat([inputs, state], dim=-1)

value = torch.sigmoid(torch.matmul(inputs_and_state, self._ru_weights))
value += self._ru_biases
x = torch.cat([inputs, state], dim=-1)

return value
return torch.matmul(x, self._ru_weights) + self._ru_biases

def _gconv(self, inputs, state, output_size, bias_start=0.0, reset=False):
batch_size = inputs.shape[0]
shape = (batch_size, self._num_nodes, -1)
inputs = torch.reshape(inputs, shape)
state = torch.reshape(state, shape)
inputs_and_state = torch.cat([inputs, state], dim=2)
input_size = inputs_and_state.size(2)
x = torch.cat([inputs, state], dim=2)
input_size = x.size(2)

x = inputs_and_state
x0 = x.permute(1, 2, 0)
x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
x = torch.unsqueeze(x0, 0)
Expand Down

0 comments on commit 6b04b6d

Please sign in to comment.