Update H2O_MISO.py #137

Merged
merged 1 commit into from
Jun 10, 2023
@Yonv1943 (Collaborator) commented Jun 8, 2023

#118

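We discussed this earlier in the issue linked above:

The different features of the solution theta share the same LSTM model parameters ... ... To generalize 'Learn to optimize' to other problems, the feature dimension of theta has to be moved from the batch size dimension to inp_dim or out_dim.

It turns out these two choices do not conflict. We can follow the idea of DuelingDQN and have the network learn both at once:

  • the mean of the Q values over the discrete actions
  • the residual of each discrete action's Q value relative to that mean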

https://github.com/AI4Finance-Foundation/ElegantRL/blob/0c019eec035391dbe7aca1464ed6a0067e5a130f/elegantrl/agents/net.py#L51-L67

class QNetDuel(QNetBase):  # Dueling DQN
    def __init__(self, dims: [int], state_dim: int, action_dim: int):
        super().__init__(state_dim=state_dim, action_dim=action_dim)
        self.net_state = build_mlp(dims=[state_dim, *dims])
        self.net_adv = build_mlp(dims=[dims[-1], 1])  # advantage value
        self.net_val = build_mlp(dims=[dims[-1], action_dim])  # Q value

    def forward(self, state):
        ...
        q_val = self.net_val(s_enc)  # q value
        q_adv = self.net_adv(s_enc)  # advantage value
        value = q_val - q_val.mean(dim=1, keepdim=True) + q_adv  # dueling Q value
        ...
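The mean-plus-residual combination used in the last line of `forward` can be checked by hand. The following is a minimal sketch in plain Python, not ElegantRL code: the function name `dueling_combine` and the sample numbers are assumptions made for illustration. It follows the textbook naming, where the scalar is the state value and the per-action terms are advantages; in the quoted snippet the scalar head is `net_adv` and the per-action head is `net_val`.

```python
from statistics import fmean


def dueling_combine(state_value, advantages):
    """Combine one scalar with per-action terms, dueling style.

    Computes Q(s, a) = V(s) + A(s, a) - mean_a A(s, a), so the mean of
    the returned Q values over the actions equals state_value.
    """
    adv_mean = fmean(advantages)
    return [state_value + adv - adv_mean for adv in advantages]


# Hypothetical numbers: one state, three discrete actions.
q_values = dueling_combine(2.0, [1.0, -1.0, 3.0])
q_mean = fmean(q_values)
```

Because the per-action terms are mean-centered before the scalar is added, the scalar head alone determines the average Q value and the other head only has to model differences between actions.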

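Applied to NP-hard optimization problems, we can likewise have the network learn both at once:

  • the mean, across features, of the gradient of the solution theta
  • the residual, for each feature, of the gradient of the solution theta
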
class OptimizerOpti(nn.Module):
    def __init__(self, inp_dim: int, hid_dim: int):
        ...
        self.output0 = nn.Linear(hid_dim * self.num_rnn, 1)
        self.output1 = nn.Linear(hid_dim * self.num_rnn, inp_dim)

    def forward(self, inp0, hid_):
        ...
        hid = th.cat((hid1, hid2), dim=1)
        out_avg = self.output0(hid)
        out_res = self.output1(hid)
        out = out_avg + out_res
        return out, (hid1, hid2)
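To make the shapes in the two output heads concrete, here is a self-contained sketch, assuming PyTorch is installed. The class name `TwoHeadOutput`, the `feat_dim` argument (standing in for `hid_dim * self.num_rnn`), and the `center` flag are hypothetical and not part of the PR. With `center=False` it mirrors `out = out_avg + out_res` from the snippet above; `center=True` is an assumed variant that subtracts the residual's mean, as `QNetDuel` does, and is not what the PR implements.

```python
import torch as th
import torch.nn as nn


class TwoHeadOutput(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, feat_dim: int, inp_dim: int):
        super().__init__()
        self.output0 = nn.Linear(feat_dim, 1)  # one value shared by all features
        self.output1 = nn.Linear(feat_dim, inp_dim)  # one value per feature

    def forward(self, hid: th.Tensor, center: bool = False) -> th.Tensor:
        out_avg = self.output0(hid)  # shape (batch, 1)
        out_res = self.output1(hid)  # shape (batch, inp_dim)
        if center:
            # Assumed variant: force the residual to have zero mean per row.
            out_res = out_res - out_res.mean(dim=1, keepdim=True)
        return out_avg + out_res  # broadcasts (batch, 1) over (batch, inp_dim)


th.manual_seed(0)
head = TwoHeadOutput(feat_dim=8, inp_dim=5)
hid = th.randn(4, 8)  # stand-in for th.cat((hid1, hid2), dim=1)
out = head(hid)
out_centered = head(hid, center=True)
```

In the centered variant the row mean of the output equals `output0(hid)`, so the two heads have separate roles. In the uncentered form used in the PR, the split between the shared value and the residual is left for training to determine.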

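Changing only 3 lines sped up our Graph MaxCut task. Nothing else needs to be modified beyond these few lines.

This can be tested on the TNCO problem directly. @ZhangAIPI @spicywei, please test it when you have time.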

@Yonv1943 requested review from ZhangAIPI and spicywei June 8, 2023 10:13
@zhumingpassional merged commit 3d1df16 into main Jun 10, 2023
@zhumingpassional deleted the Yonv1943-duelingH2O branch June 10, 2023 07:07
@zhumingpassional (Collaborator)

thanks
