diff --git a/rlsolver/rlsolver_learn2opt/tensor_train/H2O_MISO.py b/rlsolver/rlsolver_learn2opt/tensor_train/H2O_MISO.py index e132397..e6a34f8 100644 --- a/rlsolver/rlsolver_learn2opt/tensor_train/H2O_MISO.py +++ b/rlsolver/rlsolver_learn2opt/tensor_train/H2O_MISO.py @@ -88,7 +88,8 @@ def __init__(self, inp_dim: int, hid_dim: int): self.activation = nn.Tanh() self.recurs1 = nn.GRUCell(inp_dim, hid_dim) self.recurs2 = nn.GRUCell(hid_dim, hid_dim) - self.output0 = nn.Linear(hid_dim * self.num_rnn, inp_dim) + self.output0 = nn.Linear(hid_dim * self.num_rnn, 1) + self.output1 = nn.Linear(hid_dim * self.num_rnn, inp_dim) layer_init_with_orthogonal(self.output0, std=0.1) def forward(self, inp0, hid_): @@ -96,7 +97,9 @@ def forward(self, inp0, hid_): hid2 = self.activation(self.recurs2(hid1, hid_[1])) hid = th.cat((hid1, hid2), dim=1) - out = self.output0(hid) + out_avg = self.output0(hid) + out_res = self.output1(hid) + out = out_avg + out_res return out, (hid1, hid2)