Commit 06fa8a6 (1 parent: 342499d)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 16 changed files with 2,511 additions and 39 deletions.
PythonExample/hmp_minimal_modules/ALGORITHM/hete_league_onenet_fix/ccategorical.py (101 additions, 0 deletions)
@@ -0,0 +1,101 @@
from torch.distributions.categorical import Categorical
import torch
from .foundation import AlgorithmConfig
from UTIL.tensor_ops import repeat_at, _2tensor
from torch.distributions import kl_divergence
EPS = 1e-9
# yita = p_hit = 0.14


def random_process(probs, rsn_flag):
    yita = AlgorithmConfig.yita
    with torch.no_grad():
        max_place = probs.argmax(-1, keepdims=True)
        mask_max = torch.zeros_like(probs).scatter_(-1, max_place, 1).bool()
        pmax = probs[mask_max]
        if rsn_flag:
            assert max_place.shape[-1] == 1
            return max_place.squeeze(-1)
        else:
            # forbid the max-probability action from being chosen; pmax = probs.max(axis=-1)
            p_hat = pmax + (pmax-1)/(1/yita-1)
            k = 1/(1-yita)
            #!!! write
            probs *= k
            #!!! write
            probs[mask_max] = p_hat
            # print(probs)
            dist = Categorical(probs=probs)
            samp = dist.sample()
            assert samp.shape[-1] != 1
            return samp


def random_process_allow_big_yita(probs, rsn_flag):
    yita = AlgorithmConfig.yita
    with torch.no_grad():
        max_place = probs.argmax(-1, keepdims=True)
        mask_max = torch.zeros_like(probs).scatter_(-1, max_place, 1).bool()
        pmax = probs[mask_max].reshape(max_place.shape)  # probs[max_place].clone()
        if rsn_flag:
            assert max_place.shape[-1] == 1
            return max_place.squeeze(-1)
        else:
            # forbid the max-probability action from being chosen
            # pmax = probs.max(axis=-1)  # probs[max_place].clone()
            yita_arr = torch.ones_like(pmax)*yita
            yita_arr_clip = torch.minimum(pmax, yita_arr)
            # p_hat = pmax + (pmax-1) / (1/yita_arr_clip-1) + 1e-10
            p_hat = (pmax-yita_arr_clip)/(1-yita_arr_clip)
            k = 1/(1-yita_arr_clip)
            probs *= k
            probs[mask_max] = p_hat.reshape(-1)

            # print(probs)
            dist = Categorical(probs=probs)
            samp = dist.sample()
            assert samp.shape[-1] != 1
            return samp  # .squeeze(-1)


def random_process_with_clamp3(probs, yita, yita_min_prob, rsn_flag):

    with torch.no_grad():
        max_place = probs.argmax(-1, keepdims=True)
        mask_max = torch.zeros_like(probs).scatter_(dim=-1, index=max_place, value=1).bool()
        pmax = probs[mask_max].reshape(max_place.shape)
        # act max
        assert max_place.shape[-1] == 1
        act_max = max_place.squeeze(-1)
        # act samp
        yita_arr = torch.ones_like(pmax)*yita
        # p_hat = pmax + (pmax-1) / (1/yita_arr_clip-1) + 1e-10
        p_hat = (pmax-yita_arr)/((1-yita_arr)+EPS)
        p_hat = p_hat.clamp(min=yita_min_prob)
        k = (1-p_hat)/((1-pmax)+EPS)
        probs *= k
        probs[mask_max] = p_hat.reshape(-1)
        dist = Categorical(probs=probs)
        act_samp = dist.sample()
        # assert act_samp.shape[-1] != 1
        hit_e = _2tensor(rsn_flag)
        return torch.where(hit_e, act_max, act_samp)


class CCategorical():
    def __init__(self, planner):
        self.planner = planner

        pass

    def sample(self, dist, eprsn):
        probs = dist.probs.clone()
        return random_process_with_clamp3(probs, self.planner.yita, self.planner.yita_min_prob, eprsn)

    def register_rsn(self, rsn_flag):
        self.rsn_flag = rsn_flag

    def feed_logits(self, logits):
        try:
            return Categorical(logits=logits)
        except:
            print('error')
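The samplers above all follow the same recipe with different clamping; the clamp3 variant used by CCategorical.sample can be illustrated on a single probability vector. The sketch below is self-contained and its values (4 actions, yita = 0.15, yita_min_prob = 0.02) are illustrative assumptions, not taken from this commit.

import torch

EPS = 1e-9
probs = torch.tensor([[0.70, 0.15, 0.10, 0.05]])  # one agent, 4 actions (assumed values)
yita, yita_min_prob = 0.15, 0.02                  # assumed hyperparameters

max_place = probs.argmax(-1, keepdim=True)
mask_max = torch.zeros_like(probs).scatter_(-1, max_place, 1).bool()
pmax = probs[mask_max].reshape(max_place.shape)

# shrink the argmax probability toward (pmax - yita) / (1 - yita), floor it at
# yita_min_prob, then rescale the other entries so the distribution still sums to one
p_hat = ((pmax - yita) / ((1 - yita) + EPS)).clamp(min=yita_min_prob)
k = (1 - p_hat) / ((1 - pmax) + EPS)
new_probs = probs * k
new_probs[mask_max] = p_hat.reshape(-1)

print(new_probs, new_probs.sum(-1))  # argmax prob drops from 0.70 to about 0.647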
PythonExample/hmp_minimal_modules/ALGORITHM/hete_league_onenet_fix/cython_func.pyx (47 additions, 0 deletions)
@@ -0,0 +1,47 @@
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange
np.import_array()
ctypedef fused DTYPE_t:
    np.float32_t
    np.float64_t

ctypedef fused DTYPE_intlong_t:
    np.int64_t
    np.int32_t  # int32 to stay compatible with Windows

ctypedef np.uint8_t DTYPE_bool_t


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def roll_hisory( DTYPE_t[:,:,:,:] obs_feed_new,
                 DTYPE_t[:,:,:,:] prev_obs_feed,
                 DTYPE_bool_t[:,:,:] valid_mask,
                 DTYPE_intlong_t[:,:] N_valid,
                 DTYPE_t[:,:,:,:] next_his_pool):
    # how many threads
    cdef Py_ssize_t vmax = N_valid.shape[0]
    # how many agents
    cdef Py_ssize_t wmax = N_valid.shape[1]
    # how many entity subjects (including self @0)
    cdef Py_ssize_t max_obs_entity = obs_feed_new.shape[2]
    cdef int n_v, th, a, t, k, pointer
    for th in prange(vmax, nogil=True):
        # for each thread range -> prange
        for a in prange(wmax):
            # for each agent
            pointer = 0
            # step 1: fill next_his_pool[0 ~ (nv-1)] with obs_feed_new[0 ~ max_obs_entity-1]
            for k in range(max_obs_entity):
                if valid_mask[th,a,k]:
                    next_his_pool[th, a, pointer] = obs_feed_new[th,a,k]
                    pointer = pointer + 1

            # step 2: fill next_his_pool[nv ~ (max_obs_entity-1)] with prev_obs_feed[0 ~ (max_obs_entity-1-nv)]
            n_v = N_valid[th,a]
            for k in range(n_v, max_obs_entity):
                next_his_pool[th,a,k] = prev_obs_feed[th,a,k-n_v]
    return np.asarray(next_his_pool)
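For reference, a rough pure-NumPy sketch of what roll_hisory computes is given below; the helper name, shapes, and driver values are assumed for illustration and are not part of this commit. Per thread and agent, the currently valid entity observations are packed at the front of the next history buffer, and the remaining slots are back-filled from the previous history shifted by N_valid.

import numpy as np

def roll_history_ref(obs_feed_new, prev_obs_feed, valid_mask, N_valid, next_his_pool):
    # same layout as the Cython version: (n_thread, n_agent, max_obs_entity, feature_dim)
    n_thread, n_agent, max_obs_entity, _ = obs_feed_new.shape
    for th in range(n_thread):
        for a in range(n_agent):
            pointer = 0
            # step 1: front-fill with the currently valid entity observations
            for k in range(max_obs_entity):
                if valid_mask[th, a, k]:
                    next_his_pool[th, a, pointer] = obs_feed_new[th, a, k]
                    pointer += 1
            # step 2: back-fill the remaining slots with the previous history, shifted by n_v
            n_v = N_valid[th, a]
            for k in range(n_v, max_obs_entity):
                next_his_pool[th, a, k] = prev_obs_feed[th, a, k - n_v]
    return next_his_pool

# tiny driver with assumed shapes: 2 threads, 3 agents, 4 entities, 5 features
obs_new  = np.random.rand(2, 3, 4, 5).astype(np.float32)
obs_prev = np.random.rand(2, 3, 4, 5).astype(np.float32)
mask     = np.random.rand(2, 3, 4) > 0.5
n_valid  = mask.sum(-1).astype(np.int64)
out = roll_history_ref(obs_new, obs_prev, mask, n_valid, np.zeros_like(obs_new))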
PythonExample/hmp_minimal_modules/ALGORITHM/hete_league_onenet_fix/div_tree.py (155 additions, 0 deletions)
@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import numpy as np
from ALGORITHM.common.mlp import LinearFinal
from UTIL.tensor_ops import add_onehot_id_at_last_dim, add_onehot_id_at_last_dim_fixlen, repeat_at, _2tensor, gather_righthand, scatter_righthand


class DivTree(nn.Module):  # merge by MLP version
    def __init__(self, input_dim, h_dim, n_action):
        super().__init__()

        # to design a division tree, I need to get the total number of agents
        from .foundation import AlgorithmConfig
        self.n_agent = AlgorithmConfig.n_agent
        self.div_tree = get_division_tree(self.n_agent)
        self.n_level = len(self.div_tree)
        self.max_level = len(self.div_tree) - 1
        self.current_level = 0
        self.init_level = AlgorithmConfig.div_tree_init_level
        if self.init_level < 0:
            self.init_level = self.max_level
        self.current_level_floating = 0.0

        get_net = lambda: nn.Sequential(
            nn.Linear(h_dim+self.n_agent, h_dim),
            nn.ReLU(inplace=True),
            LinearFinal(h_dim, n_action)
        )
        # Note: this is NOT defining a separate net for each agent.
        # Instead, all agents start from self.nets[0].
        self.nets = torch.nn.ModuleList(modules=[
            get_net() for i in range(self.n_agent)
        ])

    def set_to_init_level(self, auto_transfer=True):
        if self.init_level != self.current_level:
            for i in range(self.current_level, self.init_level):
                self.change_div_tree_level(i+1, auto_transfer)

    def change_div_tree_level(self, level, auto_transfer=True):
        print('performing div tree level change (%d -> %d/%d) \n'%(self.current_level, level, self.max_level))
        self.current_level = level
        self.current_level_floating = level
        assert len(self.div_tree) > self.current_level, ('Reach max level already!')
        if not auto_transfer: return
        transfer_list = []
        for i in range(self.n_agent):
            previous_net_index = self.div_tree[self.current_level-1, i]
            post_net_index = self.div_tree[self.current_level, i]
            if post_net_index != previous_net_index:
                transfer = (previous_net_index, post_net_index)
                if transfer not in transfer_list:
                    transfer_list.append(transfer)
        for transfer in transfer_list:
            from_which_net = transfer[0]
            to_which_net = transfer[1]
            self.nets[to_which_net].load_state_dict(self.nets[from_which_net].state_dict())
            print('transferring model parameters from %d-th net to %d-th net'%(from_which_net, to_which_net))
        return

    def forward(self, x_in, agent_ids):  # x0: shape = (?,...,?, n_agent, core_dim)
        if self.current_level == 0:
            x0 = add_onehot_id_at_last_dim_fixlen(x_in, fixlen=self.n_agent, agent_ids=agent_ids)
            x2 = self.nets[0](x0)
            return x2, None
        else:
            x0 = add_onehot_id_at_last_dim_fixlen(x_in, fixlen=self.n_agent, agent_ids=agent_ids)
            res = []
            for i in range(self.n_agent):
                use_which_net = self.div_tree[self.current_level, i]
                res.append(self.nets[use_which_net](x0[..., i, :]))
            x2 = torch.stack(res, -2)
            # x22 = self.nets[0](x1)

            return x2, None

    # def forward_try_parallel(self, x0):  # x0: shape = (?,...,?, n_agent, core_dim)
    #     x1 = self.shared_net(x0)
    #     stream = []
    #     res = []
    #     for i in range(self.n_agent):
    #         stream.append(torch.cuda.Stream())

    #     torch.cuda.synchronize()
    #     for i in range(self.n_agent):
    #         use_which_net = self.div_tree[self.current_level, i]
    #         with torch.cuda.stream(stream[i]):
    #             res.append(self.nets[use_which_net](x1[..., i, :]))
    #             print(res[i])

    #     # s1 = torch.cuda.Stream()
    #     # s2 = torch.cuda.Stream()
    #     # # Wait for the above tensors to initialise.
    #     # torch.cuda.synchronize()
    #     # with torch.cuda.stream(s1):
    #     #     C = torch.mm(A, A)
    #     # with torch.cuda.stream(s2):
    #     #     D = torch.mm(B, B)
    #     # Wait for C and D to be computed.
    #     torch.cuda.synchronize()
    #     # Do stuff with C and D.

    #     x2 = torch.stack(res, -2)

    #     return x2


def _2div(arr):
    arr_res = arr.copy()
    arr_pieces = []
    pa = 0
    st = 0
    needdivcnt = 0
    for i, a in enumerate(arr):
        if a != pa:
            arr_pieces.append([st, i])
            if (i-st) != 1: needdivcnt += 1
            pa = a
            st = i

    arr_pieces.append([st, len(arr)])
    if (len(arr)-st) != 1: needdivcnt += 1

    offset = range(len(arr_pieces), len(arr_pieces)+needdivcnt)
    p = 0
    for arr_p in arr_pieces:
        length = arr_p[1] - arr_p[0]
        if length == 1: continue
        half_len = int(np.ceil(length / 2))
        for j in range(arr_p[0]+half_len, arr_p[1]):
            try:
                arr_res[j] = offset[p]
            except:
                print('wtf')
        p += 1
    return arr_res


def get_division_tree(n_agents):
    agent2divitreeindex = np.arange(n_agents)
    np.random.shuffle(agent2divitreeindex)
    max_div = np.ceil(np.log2(n_agents)).astype(int)
    levels = np.zeros(shape=(max_div+1, n_agents), dtype=int)
    tree_of_agent = []*(max_div+1)
    for ith, level in enumerate(levels):
        if ith == 0: continue
        res = _2div(levels[ith-1,:])
        levels[ith,:] = res
    res_levels = levels.copy()
    for i, div_tree_index in enumerate(agent2divitreeindex):
        res_levels[:, i] = levels[:, div_tree_index]
    return res_levels
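A small usage sketch of get_division_tree follows; it assumes div_tree.py and its _2div helper above are importable in the session, and the seed and agent count are illustrative. For 4 agents the unshuffled levels produced by repeated _2div splits are [0,0,0,0] (all agents share net 0), then [0,0,1,1] (two groups), then [0,2,1,3] (one net per agent); get_division_tree additionally shuffles which agent is assigned to which column.

import numpy as np

np.random.seed(0)            # fix the random column shuffle for reproducibility (assumed seed)
tree = get_division_tree(4)  # get_division_tree / _2div as defined in div_tree.py above
print(tree.shape)            # (3, 4): ceil(log2(4)) + 1 levels, 4 agents
print(tree)                  # row i gives the net index used by each agent at division level i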