✨ update 'Learn to optimize' to batch size mode #118

Open
Yonv1943 opened this issue May 5, 2023 · 1 comment
Labels
enhancement New feature or request

Comments

@Yonv1943
Collaborator

Yonv1943 commented May 5, 2023

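We denote the solution the problem asks for as theta. For example, in the MISO problem, theta.shape == (2, 8, 8), where 2 covers the real and imaginary parts of the complex values, and (8, 8) are the number of users and the number of base-station antennas, respectively.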

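This was the previous approach:

  • Use an LSTM, which needs two saved hidden states: the hidden state and the cell state.
  • When solving for theta, the LSTM takes a single solution theta as input, with inp.shape == (2*8*8, 1), and outputs the gradient for that solution, with grad out.shape == inp.shape.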

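In this approach, the different features of the solution theta share the same LSTM model parameters.

For the MISO problem this is appropriate, because within the theta matrix any user, any antenna, and the real and imaginary parts are all interchangeable. The entries are therefore flattened and placed on the batch size (parallel) dimension, so that every feature of theta is processed with the same LSTM parameters.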

https://github.com/AI4Finance-Foundation/ElegantRL_Solver/blob/a7cd35b66a99600386efe1b642dc9b1453ed10f7/rlsolver/rlsolver_learn2opt/tensor_train/L2O_H_term.py#L52-L68
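
The old layout described above can be illustrated with a minimal sketch. It is not the code at the linked lines: `hid_dim`, the single `nn.LSTMCell`, and the `head` linear layer are assumptions made only to show the shapes, with the 2*8*8 entries of theta sitting on the batch dimension and an input width of 1.

```python
import torch as th
import torch.nn as nn

num_users, num_antennas = 8, 8
theta = th.randn(2, num_users, num_antennas)  # (real/imag, users, antennas)

# Old layout: each scalar entry of theta is one "sample" with a single feature,
# so all 128 entries go through the same LSTM weights in parallel.
inp = theta.reshape(-1, 1)  # shape (2*8*8, 1)

hid_dim = 16  # hypothetical hidden size, not taken from the repository
lstm = nn.LSTMCell(input_size=1, hidden_size=hid_dim)
head = nn.Linear(hid_dim, 1)  # hypothetical output layer

# The LSTM needs two saved states: hidden state and cell state.
hidden = th.zeros(inp.shape[0], hid_dim)
cell = th.zeros(inp.shape[0], hid_dim)

hidden, cell = lstm(inp, (hidden, cell))
grad = head(hidden)  # same shape as inp: (2*8*8, 1)
```

Because the input width is 1, the parameter count does not depend on the size of theta, which is the weight sharing the post describes.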

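To generalize 'Learn to optimize' to other problems, the feature dimension of theta has to be moved off the batch size dimension and onto inp_dim or out_dim. After this change training becomes slower, but the highest score reached after training does not change.

The goal is to make the adjusted code usable for the tensor contraction task (TNCO) @spicywei and for the graph max-cut task (Graph max cut) @shixun404.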

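The adjusted code is as follows:

  • Use a GRU, which needs only one saved state, the hidden state.
  • When solving for theta, the GRU takes a batch of solutions theta as input, with inp.shape == (batch_size, 2*8*8), and outputs the gradient for those solutions, with grad out.shape == inp.shape.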
import torch as th
import torch.nn as nn


class OptimizerOpti(nn.Module):
    def __init__(self, opt_dim, hid_dim):
        super().__init__()
        self.opt_dim = opt_dim  # number of features in one flattened theta
        self.hid_dim = hid_dim  # width of each GRU hidden state
        self.num_rec = 4  # not used in this excerpt

        self.activation = nn.Tanh()
        self.recurs1 = nn.GRUCell(opt_dim, hid_dim)
        self.recurs2 = nn.GRUCell(hid_dim, hid_dim)
        self.output = nn.Linear(hid_dim * 2, opt_dim)

    def forward(self, inp0, hid0):
        # inp0: (batch_size, opt_dim); hid0: pair of (batch_size, hid_dim) tensors
        hid1 = self.activation(self.recurs1(inp0, hid0[0]))
        hid2 = self.activation(self.recurs2(hid1, hid0[1]))
        hid = th.cat((hid1, hid2), dim=1)  # (batch_size, hid_dim * 2)
        return self.output(hid), (hid1, hid2)
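
A minimal usage sketch for the batched layout follows. It repeats a compact copy of `OptimizerOpti` so that it runs on its own; `batch_size`, `hid_dim`, `num_steps`, the zero initial states, and the update rule `inp - 0.1 * out` are illustrative assumptions, not values taken from the pull request.

```python
import torch as th
import torch.nn as nn


class OptimizerOpti(nn.Module):  # compact copy of the class in this issue
    def __init__(self, opt_dim, hid_dim):
        super().__init__()
        self.activation = nn.Tanh()
        self.recurs1 = nn.GRUCell(opt_dim, hid_dim)
        self.recurs2 = nn.GRUCell(hid_dim, hid_dim)
        self.output = nn.Linear(hid_dim * 2, opt_dim)

    def forward(self, inp0, hid0):
        hid1 = self.activation(self.recurs1(inp0, hid0[0]))
        hid2 = self.activation(self.recurs2(hid1, hid0[1]))
        return self.output(th.cat((hid1, hid2), dim=1)), (hid1, hid2)


batch_size, opt_dim, hid_dim = 4, 2 * 8 * 8, 32  # batch_size and hid_dim are hypothetical
opti = OptimizerOpti(opt_dim, hid_dim)

# New layout: the batch dimension holds independent solutions,
# and the flattened theta entries live on the feature dimension.
theta = th.randn(batch_size, 2, 8, 8)
inp = theta.reshape(batch_size, -1)  # (batch_size, 2*8*8)

# One hidden state per GRUCell, assumed here to start at zero.
hid = (th.zeros(batch_size, hid_dim), th.zeros(batch_size, hid_dim))

num_steps = 3  # hypothetical unroll length
for _ in range(num_steps):
    out, hid = opti(inp, hid)  # out.shape == inp.shape
    inp = inp - 0.1 * out      # hypothetical update rule with an assumed step size
```

Unlike the old layout, the first `GRUCell` and the output layer now have weights sized by `opt_dim`, so entries of theta no longer share parameters.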

Full code: https://github.com/AI4Finance-Foundation/ElegantRL_Solver/pull/119/files#diff-e03802a5a83ef6f88ad30c077b0f4cec4b4f6cc21f3cbac771087ea2824618ba


Compare

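On the MISO problem, the old method (which builds in more human prior knowledge) is definitely faster than the new one (time to reach the best result is about 1:3), but the highest score the two can reach is the same.

Below are the results of the old method: an LSTM in which the batch size (parallel) dimension is used as the theta feature dimension.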

start training
    MMSE     5.598    15.900    31.134

     L2O     6.078    11.540    12.724    TimeUsed         9
     L2O     6.491    18.491    35.106    TimeUsed       160
     L2O     6.108    18.211    35.091    TimeUsed       311
     L2O     6.459    18.308    33.260    TimeUsed       465
     L2O     6.454    18.421    34.214    TimeUsed       615
     L2O     6.459    18.347    33.878    TimeUsed       754
     L2O     6.430    18.030    34.527    TimeUsed       892
     L2O     6.440    18.366    34.233    TimeUsed      1033

Below are the results of the new method: a GRU that keeps the batch size (parallel) dimension separate from the theta feature dimension.

    MMSE     5.598    15.900    31.134

training start
     L2O     1.088     2.968     4.889    TimeUsed         7
     L2O     5.991    16.585    21.792    TimeUsed       313
     L2O     6.250    17.582    26.524    TimeUsed       613
     L2O     6.063    16.946    28.046    TimeUsed       927
     L2O     6.313    18.009    30.573    TimeUsed      1234
     L2O     6.189    17.932    32.204    TimeUsed      1551
     L2O     6.098    17.880    33.794    TimeUsed      1860
     L2O     6.299    17.909    34.490    TimeUsed      2165
     L2O     6.233    17.952    34.699    TimeUsed      2471
     L2O     6.285    18.033    34.079    TimeUsed      2785
@YangletLiu YangletLiu added the enhancement New feature or request label May 6, 2023
@Yonv1943
Collaborator Author

Yonv1943 commented May 6, 2023

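Update: on the small sycamoreN12M14 example, the modified code broke the record previously set by the old code.

  • Old code: 5.5792907356870209 (took 2 days in total on 4 GPUs)
  • New code: 5.5765728354277142 (took half a day on 2 GPUs)
  • Others' SOTA: 5.83 (5.53 after subtracting log10(2)); our result is now closer to the optimum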
EdgeSortStrH2OSycamoreN12M14 = """
[22 96 78 74 84 92 87 72 45 57 94 91 89 24 26  8  2 73  0 23 35 76 97  7 36 14 59 19 75 21  4 63 12 66 10  3 30 20 80 34
 90  6 52 11 77 79 53 31 16 49 54 25 70 62 56 33 83 95 15 18 47 43 93  9 55 71 32 29 64 46 67 17 44 38 86 61 13 40 51 58
 65 85 37 41 5 81 28 27 68 39 48 42  1 82 60 88 69 98 50]
"""  # 5.5792907356870209

EdgeSortStrH2OSycamoreN12M14 = """
[90 88  1 74 23 94 98 93 18 14  8 84 78 82 62 49  2  6  7 86 21 73 72 22 45  9 96 91 28 80 87 83 34 36 95 89  0 71 92
 35 32 75 77 81 97 40 76 33 31 24 16 19 20 63 13  4 15 79 70 85 60 27 56 66 51 26 47 54 44 50 58 38 25 11 10 69 64 42
 67  3 17 53 41 65 59 12 39 61 55 57 52 46 43  5 48 30 37 68 29]
"""  # 5.5765728354277142
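
The arithmetic behind the SOTA comparison in the list above can be checked with a short sketch. It assumes that lower scores are better, as the comparison implies; the variable names are illustrative.

```python
import math

sota_reported = 5.83
sota_adjusted = sota_reported - math.log10(2)  # about 5.53, as quoted above

old_score = 5.5792907356870209
new_score = 5.5765728354277142

# Remaining gap to the adjusted SOTA value (assuming lower is better).
gap_old = old_score - sota_adjusted
gap_new = new_score - sota_adjusted

print(f"adjusted SOTA: {sota_adjusted:.4f}")
print(f"gap, old code: {gap_old:.4f}")
print(f"gap, new code: {gap_new:.4f}")
```

Under these assumptions both results are still above the adjusted SOTA, and the new code narrows the gap slightly.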
