We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
import torch as th
import numpy as np
from torch import Tensor


class GraphMaxCutEnv:
    """Batched relaxed max-cut objective for a graph read from a text file.

    File format (as parsed by ``__init__``): the first non-blank line is
    ``num_nodes num_edges``; every following non-blank line is
    ``n0 n1 dist`` with 1-based node indices.

    Notation used throughout:
        n0: index of node0
        n1: index of node1
        dt: distance (weight) between node0 and node1
        p0: probability that node0 is in the set, (1 - p0): it is in the other set
        p1: probability that node1 is in the set, (1 - p1): it is in the other set

    For every listed edge the objective adds ``p0 + p1 - 2 * p0 * p1``, which
    equals ``p0 * (1 - p1) + p1 * (1 - p0)``. Edge weights are parsed and kept
    in ``n0_to_dts`` but the objective methods do not use them, i.e. every
    edge counts with weight 1.
    """

    def __init__(self, num_envs=8, device=th.device('cpu'), txt_path="./graph_set_G14.txt"):
        """Load the graph and build per-node neighbour index tensors.

        Args:
            num_envs: number of probability vectors evaluated per call.
            device: device on which the index tensors are created.
            txt_path: path of the graph file; defaults to the G14 graph file
                that was previously hard-coded.

        Raises:
            ValueError: if the file contains no non-blank line.
            AssertionError: if the number of parsed edges differs from the
                count declared in the header line.
        """
        with open(txt_path, 'r') as file:
            # Blank lines (e.g. a trailing newline) carry no data; skipping
            # them avoids an IndexError when unpacking an empty row.
            rows = [[int(tok) for tok in line.split()] for line in file if line.strip()]
        if not rows:
            raise ValueError(f"graph file {txt_path!r} contains no data")

        num_nodes, num_edges = rows[0]
        # Convert the 1-based node ids of the file to 0-based indices.
        edge_to_n0_n1_dist = [(row[0] - 1, row[1] - 1, row[2]) for row in rows[1:]]

        n0_to_n1s = [[] for _ in range(num_nodes)]  # maps node0 id -> list of node1 ids
        n0_to_dts = [[] for _ in range(num_nodes)]  # maps node0 id -> distances to those node1 ids
        for n0, n1, dist in edge_to_n0_n1_dist:
            n0_to_n1s[n0].append(n1)
            n0_to_dts[n0].append(dist)
        n0_to_n1s = [th.tensor(n1s, dtype=th.long, device=device) for n1s in n0_to_n1s]
        n0_to_dts = [th.tensor(dts, dtype=th.long, device=device) for dts in n0_to_dts]

        parsed_edges = sum(len(n1s) for n1s in n0_to_n1s)
        assert num_edges == parsed_edges, \
            f"header declares {num_edges} edges but {parsed_edges} were parsed"

        self.num_envs = num_envs
        self.num_nodes = len(n0_to_n1s)
        self.num_edges = parsed_edges
        self.n0_to_n1s = n0_to_n1s
        self.n0_to_dts = n0_to_dts
        self.device = device

        # For faster evaluation, keep a second view that drops the nodes whose
        # neighbour list is empty (they contribute nothing to the objective).
        v2_ids = [i for i, n1 in enumerate(n0_to_n1s) if n1.shape[0] > 0]
        self.v2_ids = v2_ids
        self.v2_n0_to_n1s = [n0_to_n1s[idx] for idx in v2_ids]
        self.v2_num_nodes = len(v2_ids)

    def get_objective(self, p0s: Tensor) -> Tensor:
        """Reference implementation: loop over environments and nodes.

        Args:
            p0s: tensor of shape ``(num_envs, num_nodes)``.

        Returns:
            1-D tensor with one objective value per environment.
        """
        assert p0s.shape == (self.num_envs, self.num_nodes)

        sum_dts = []
        for p0 in p0s:  # one probability vector per environment
            per_node = []
            for _p0, n1 in zip(p0, self.n0_to_n1s):
                _p1 = p0[n1]
                # Equivalent to: _p0 * (1 - _p1) + _p1 * (1 - _p0)
                dt = _p0 + _p1 - 2 * _p0 * _p1
                per_node.append(dt.sum(dim=0))
            sum_dts.append(th.stack(per_node).sum(dim=-1))
        return th.hstack(sum_dts)

    def get_objectives_v1(self, p0s: Tensor) -> Tensor:  # version 1
        """Evaluate all environments at once, looping over every node.

        Args:
            p0s: tensor whose rows are environments and columns are nodes.

        Returns:
            1-D tensor with one objective value per environment.
        """
        device = p0s.device
        sum_dts = th.zeros((self.num_envs, self.num_nodes), dtype=th.float32, device=device)
        for node_i, n1 in enumerate(self.n0_to_n1s):
            _p0 = p0s[:, node_i].unsqueeze(1)
            # Column gather for all environments; an empty neighbour list
            # yields a (num_envs, 0) tensor whose sum is 0.
            _p1 = p0s[:, n1]
            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = dt.sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_objectives(self, p0s: Tensor) -> Tensor:  # version 2
        """Evaluate all environments at once, skipping nodes without neighbours.

        Args:
            p0s: tensor whose rows are environments and columns are nodes.

        Returns:
            1-D tensor with one objective value per environment.
        """
        device = p0s.device
        v2_num_nodes = len(self.v2_ids)
        v2_p0s = p0s[:, self.v2_ids]

        sum_dts = th.zeros((self.num_envs, v2_num_nodes), dtype=th.float32, device=device)
        for node_i, n1 in enumerate(self.v2_n0_to_n1s):
            _p0 = v2_p0s[:, node_i].unsqueeze(1)
            _p1 = p0s[:, n1]  # neighbour probabilities for every environment
            dt = _p0 + _p1 - 2 * _p0 * _p1
            sum_dts[:, node_i] = dt.sum(dim=-1)
        return sum_dts.sum(dim=-1)

    def get_rand_p0s(self) -> Tensor:
        """Return uniform random probabilities of shape ``(num_envs, num_nodes)``."""
        return th.rand((self.num_envs, self.num_nodes), dtype=th.float32, device=self.device)


def check_env():
    """Print the objective from all three implementations for the same input."""
    th.manual_seed(0)
    env = GraphMaxCutEnv(num_envs=6)

    p0s = env.get_rand_p0s()
    print(env.get_objective(p0s))
    print(env.get_objectives_v1(p0s))
    print(env.get_objectives(p0s))


if __name__ == "__main__":
    # Only run the self-check when executed as a script, so importing this
    # module does not read the graph file or print.
    check_env()
The text was updated successfully, but these errors were encountered:
检查结果如下:
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1797, 2337.8169, 2338.0452, 2332.6572, 2356.8047])
tensor([2345.0999, 2354.1794, 2337.8171, 2338.0452, 2332.6570, 2356.8047])
Sorry, something went wrong.
No branches or pull requests
The text was updated successfully, but these errors were encountered: