-
Notifications
You must be signed in to change notification settings - Fork 181
/
utils.py
28 lines (24 loc) · 965 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import math
import torch
def create_log_gaussian(mean, log_std, t):
    """Return the log-density of ``t`` under a diagonal Gaussian.

    The density is parameterized by ``mean`` and ``log_std`` (the natural
    log of the standard deviation); per-element terms are summed over the
    last dimension.

    Args:
        mean: Tensor of means.
        log_std: Tensor of log standard deviations.
        t: Tensor of points at which to evaluate the density.

    Returns:
        Tensor of log-probabilities with the last dimension reduced.
    """
    # Standardized residual; the exponent of a Gaussian is -0.5 * z**2.
    # (The 0.5 must stay outside the square, otherwise the term becomes
    # -0.25 * z**2 and the result is no longer a Gaussian log-density.)
    standardized = (t - mean) / log_std.exp()
    quadratic = -0.5 * standardized.pow(2)

    # Normalization constant: each of the D = mean.shape[-1] dimensions
    # contributes 0.5 * log(2 * pi), in addition to its log_std.
    event_size = mean.shape[-1]
    half_log_two_pi_total = 0.5 * event_size * math.log(2 * math.pi)

    log_p = (
        quadratic.sum(dim=-1)
        - log_std.sum(dim=-1)
        - half_log_two_pi_total
    )
    return log_p
def logsumexp(inputs, dim=None, keepdim=False):
    """Compute ``log(sum(exp(inputs)))`` in a numerically stable way.

    Args:
        inputs: Tensor to reduce.
        dim: Dimension to reduce over. When ``None`` the tensor is
            flattened and reduced over all of its elements.
        keepdim: If true, the reduced dimension is kept with size 1.

    Returns:
        Tensor holding the log-sum-exp along ``dim``.
    """
    if dim is None:
        # reshape (rather than view) also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        dim = 0

    peak, _ = torch.max(inputs, dim=dim, keepdim=True)
    # Subtracting an infinite maximum would give inf - inf = NaN; shift by
    # zero in that case so all -inf inputs yield -inf and a +inf yields +inf.
    shift = torch.where(torch.isfinite(peak), peak, torch.zeros_like(peak))

    outputs = shift + (inputs - shift).exp().sum(dim=dim, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(dim)
    return outputs
def soft_update(target, source, tau):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``.
    Parameters are paired in the order yielded by ``parameters()``; pairing
    stops at the shorter of the two sequences.

    Args:
        target: Object whose ``parameters()`` are overwritten.
        source: Object whose ``parameters()`` supply the new values.
        tau: Interpolation weight given to the source parameters.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        retained = dst.data * (1.0 - tau)
        incoming = src.data * tau
        dst.data.copy_(retained + incoming)
def hard_update(target, source):
    """Copy every parameter value of ``source`` into ``target`` in place.

    Parameters are paired in the order yielded by ``parameters()``; pairing
    stops at the shorter of the two sequences.

    Args:
        target: Object whose ``parameters()`` are overwritten.
        source: Object whose ``parameters()`` supply the new values.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        values = src.data
        dst.data.copy_(values)