-
Notifications
You must be signed in to change notification settings - Fork 10
/
util_MDN.py
44 lines (38 loc) · 1.3 KB
/
util_MDN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
Created on Thu Jun 9 12:12:57 2016
@author: rob
"""
import numpy as np
import tensorflow as tf
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
# Extracted from the implementation at https://github.com/hardmaru/write-rnn-tensorflow
def tf_2d_normal(x1, x2, mu1, mu2, s1, s2, rho):
    """ 2D (bivariate) normal density
    Implements eq # 24 and 25 of http://arxiv.org/abs/1308.0850
    input
    - x1,x2: coordinates at which the density is evaluated
    - mu1,mu2: means over x1 and x2
    - s1,s2: standard deviations over x1 and x2
    - rho: correlation coefficient in x1-x2 plane
    returns
    - px1x2: the density, combined elementwise from the inputs

    NOTE(review): nothing here guards against s1/s2 == 0 or |rho| == 1,
    both of which divide by zero; presumably the caller constrains these
    parameters -- confirm.
    """
    # tf.divide is used instead of tf.div: tf.div is deprecated in TF 1.x and
    # is not part of the TF 2.x top-level API. For floating-point inputs the
    # two produce the same result.
    dx1 = tf.subtract(x1, mu1)
    dx2 = tf.subtract(x2, mu2)
    s1s2 = tf.multiply(s1, s2)
    # Quadratic form Z of eq 25.
    z = (tf.square(tf.divide(dx1, s1))
         + tf.square(tf.divide(dx2, s2))
         - 2.0 * tf.divide(tf.multiply(rho, tf.multiply(dx1, dx2)), s1s2))
    # 1 - rho^2 appears both in the exponent and in the normalisation term.
    one_minus_rho_sq = 1 - tf.square(rho)
    numerator = tf.exp(tf.divide(-1.0 * z, 2.0 * one_minus_rho_sq))
    # Normalisation constant of eq 24: 2 * pi * s1 * s2 * sqrt(1 - rho^2).
    denom = 2 * np.pi * tf.multiply(s1s2, tf.sqrt(one_minus_rho_sq))
    px1x2 = tf.divide(numerator, denom)
    return px1x2
def tf_1d_normal(x3, mu3, s3):
    """ 1D normal density over the third coordinate x3
    Under assumption that x3 is uncorrelated with x1 and x2
    input
    - x3: [batch_size, 1, seqlen]
    - mu3,s3: [batch_size, mixtures, seqlen]
    returns
    - px3: probability density in the x3 dimension

    NOTE(review): s3 == 0 divides by zero; presumably the caller keeps s3
    strictly positive -- confirm.
    """
    # tf.divide is used instead of tf.div: tf.div is deprecated in TF 1.x and
    # is not part of the TF 2.x top-level API. For floating-point inputs the
    # two produce the same result.
    dx3 = tf.subtract(x3, mu3)
    # Squared standardised distance ((x3 - mu3) / s3)^2.
    z = tf.square(tf.divide(dx3, s3))
    numerator = tf.exp(tf.divide(-z, 2))
    # Gaussian normalisation constant sqrt(2 * pi) * s3.
    denom = np.sqrt(2.0 * np.pi) * s3
    px3 = tf.divide(numerator, denom)  # probability in x3 dimension
    return px3