Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

📝 update GraphMaxCutEnv to VecEnv #136

Open
Yonv1943 opened this issue Jun 4, 2023 · 1 comment
Open

📝 update GraphMaxCutEnv to VecEnv #136

Yonv1943 opened this issue Jun 4, 2023 · 1 comment
Labels
bug Something isn't working

Comments

@Yonv1943
Copy link
Collaborator

Yonv1943 commented Jun 4, 2023

import torch as th
import numpy as np
from torch import Tensor


class GraphMaxCutEnv:
    """Soft max-cut objective for a graph loaded from a text file, batched over ``num_envs``.

    Expected file format (as parsed by ``__init__``):
        first line:  ``num_nodes num_edges``
        other lines: ``node0 node1 weight`` with 1-based node ids
    Each edge is stored once, keyed by its first endpoint ``node0``.

    Naming used throughout:
        n0: index of node0
        n1: index of node1
        dt: distance (edge weight) between node0 and node1
        p0: the probability that node0 is in the set, (1 - p0): node0 is in the other set
        p1: the probability that node1 is in the set, (1 - p1): node1 is in the other set
    """

    def __init__(self, num_envs=8, device=th.device('cpu'), txt_path="./graph_set_G14.txt"):
        """Load the graph and build per-node adjacency tensors.

        Args:
            num_envs: number of parallel environments (rows of ``p0s``).
            device: torch device on which the adjacency tensors are created.
            txt_path: path of the graph file (defaults to the G14 file).

        Raises:
            AssertionError: if the header's node/edge counts disagree with the edge list.
        """
        with open(txt_path, 'r', encoding='utf-8') as file:
            # Skip blank lines (e.g. a trailing empty line): they would parse as [] and make the
            # edge-list comprehension below fail with IndexError.
            rows = [[int(token) for token in line.split()] for line in file if line.strip()]

        num_nodes, num_edges = rows[0]
        # Convert the file's 1-based node ids to 0-based indices.
        edge_to_n0_n1_dist = [(row[0] - 1, row[1] - 1, row[2]) for row in rows[1:]]

        n0_to_n1s = [[] for _ in range(num_nodes)]  # maps node0 id -> ids of its node1 neighbours
        n0_to_dts = [[] for _ in range(num_nodes)]  # maps node0 id -> weights of edges (node0, node1)
        for n0, n1, dist in edge_to_n0_n1_dist:
            n0_to_n1s[n0].append(n1)
            n0_to_dts[n0].append(dist)
        n0_to_n1s = [th.tensor(n1s, dtype=th.long, device=device) for n1s in n0_to_n1s]
        # NOTE(review): weights are parsed and counted but NOT used by the objective methods, which
        # treat every edge as weight 1 -- assumes all weights are 1 (TODO confirm for other graphs).
        n0_to_dts = [th.tensor(dts, dtype=th.long, device=device) for dts in n0_to_dts]
        assert num_nodes == len(n0_to_n1s)
        assert num_nodes == len(n0_to_dts)
        assert num_edges == sum([len(n0_to_n1) for n0_to_n1 in n0_to_n1s])
        assert num_edges == sum([len(n0_to_dt) for n0_to_dt in n0_to_dts])

        self.num_envs = num_envs
        self.num_nodes = len(n0_to_n1s)
        self.num_edges = sum([len(n0_to_n1) for n0_to_n1 in n0_to_n1s])
        self.n0_to_n1s = n0_to_n1s
        self.device = device

        # For faster computation, keep only the nodes that have at least one node1 neighbour.
        v2_ids = [i for i, n1 in enumerate(n0_to_n1s) if n1.shape[0] > 0]
        self.v2_ids = v2_ids
        self.v2_n0_to_n1s = [n0_to_n1s[idx] for idx in v2_ids]
        self.v2_num_nodes = len(v2_ids)

    def get_objective(self, p0s):
        """Reference (non-vectorised) objective that loops over environments one at a time.

        Args:
            p0s: tensor of shape (num_envs, num_nodes) with per-node set-membership probabilities.

        Returns:
            1-D tensor of length num_envs; entry i is the sum over edges of
            p0 * (1 - p1) + p1 * (1 - p0) for environment i.
        """
        assert p0s.shape == (self.num_envs, self.num_nodes)

        sum_dts = []
        for env_i in range(self.num_envs):
            p0 = p0s[env_i]
            n0_to_p1 = [p0[n1] for n1 in self.n0_to_n1s]

            sum_dt = []
            for _p0, _p1 in zip(p0, n0_to_p1):
                # dt = _p0 * (1-_p1) + _p1 * (1-_p0)  # equivalent to the line below
                dt = _p0 + _p1 - 2 * _p0 * _p1
                sum_dt.append(dt.sum(dim=0))
            sum_dt = th.stack(sum_dt).sum(dim=-1)
            sum_dts.append(sum_dt)
        sum_dts = th.hstack(sum_dts)
        return sum_dts

    def get_objectives_v1(self, p0s):  # version 1
        """Objective vectorised over environments, looping over every node.

        Args:
            p0s: tensor of shape (batch, num_nodes); ``batch`` is normally ``num_envs``.

        Returns:
            1-D tensor of length ``batch`` with the same values as ``get_objective``.
        """
        device = p0s.device
        num_envs = p0s.shape[0]
        num_nodes = self.num_nodes

        # p0s[:, n1] gathers every environment's neighbour probabilities at once. For a node with
        # no neighbours it yields an empty (batch, 0) tensor, whose sum below is 0.
        n0s_to_p1 = [p0s[:, n1] for n1 in self.n0_to_n1s]

        sum_dts = th.zeros((num_envs, num_nodes), dtype=th.float32, device=device)
        for node_i in range(num_nodes):
            _p0 = p0s[:, node_i].unsqueeze(1)
            _p1 = n0s_to_p1[node_i]

            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = dt.sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_objectives(self, p0s):  # version 2
        """Objective vectorised over environments, skipping nodes without neighbours (``v2_ids``).

        Args:
            p0s: tensor of shape (batch, num_nodes); ``batch`` is normally ``num_envs``.

        Returns:
            1-D tensor of length ``batch``. Mathematically equal to ``get_objectives_v1``, but the
            final reduction runs over fewer columns, so float results may differ slightly.
        """
        device = p0s.device
        num_envs = p0s.shape[0]
        v2_num_nodes = len(self.v2_ids)

        v2_p0s = p0s[:, self.v2_ids]
        n0s_to_p1 = [p0s[:, n1] for n1 in self.v2_n0_to_n1s]

        sum_dts = th.zeros((num_envs, v2_num_nodes), dtype=th.float32, device=device)
        for node_i in range(v2_num_nodes):
            _p0 = v2_p0s[:, node_i].unsqueeze(1)
            _p1 = n0s_to_p1[node_i]

            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = dt.sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_rand_p0s(self):
        """Return a (num_envs, num_nodes) float32 tensor of uniform [0, 1) samples on ``self.device``."""
        device = self.device
        return th.rand((self.num_envs, self.num_nodes), dtype=th.float32, device=device)


def check_env():
    """Print the three objective implementations for one seeded random batch, for comparison.

    Builds a 6-environment GraphMaxCutEnv (which reads its default graph file), samples random
    probabilities with seed 0, and prints, in order: the per-environment loop version,
    vectorised version 1, and vectorised version 2.
    """
    th.manual_seed(0)
    env = GraphMaxCutEnv(num_envs=6)
    probs = env.get_rand_p0s()
    for objective in (env.get_objective, env.get_objectives_v1, env.get_objectives):
        print(objective(probs))


# Run the self-check only when executed as a script, not as a side effect of importing the module.
if __name__ == "__main__":
    check_env()
@YangletLiu YangletLiu added the bug Something isn't working label Jun 6, 2023
@Yonv1943
Copy link
Collaborator Author

检查结果如下:

  • 第一行,无矢量并行的环境,在for循环里求出的6个结果(num_envs=6)
  • 第二行,矢量并行的环境,version1,求出的结果无误
  • 第三行,矢量并行的环境,version2,跳过邻居数量为0的节点的计算,结果相差在 1e-3 以内,可以接受
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1794, 2337.8171, 2338.0452, 2332.6570, 2356.8047])

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants