forked from songgc/TF-recomm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ops.py
34 lines (30 loc) · 1.88 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import tensorflow as tf
def inference_svd(user_batch, item_batch, user_num, item_num, dim=5,
                  device="/cpu:0"):
    """Build the graph for a biased matrix-factorization (SVD-style) predictor.

    The prediction is the sum over axis 1 of the element-wise product of the
    looked-up user and item embeddings, plus a scalar global bias, a per-user
    bias and a per-item bias.

    Args:
        user_batch: ids used to index the user tables via
            ``tf.nn.embedding_lookup`` (presumably an integer tensor of user
            ids in ``[0, user_num)`` -- confirm against the caller).
        item_batch: ids used to index the item tables (presumably integer
            item ids in ``[0, item_num)`` -- confirm against the caller).
        user_num: number of rows in the user bias/embedding tables.
        item_num: number of rows in the item bias/embedding tables.
        dim: width of each user/item embedding row.
        device: device string for the arithmetic ops (product, sums, L2
            terms). The variables and lookups are always placed on
            ``"/cpu:0"`` regardless of this argument.

    Returns:
        A tuple ``(infer, regularizer)``: the prediction tensor (op name
        ``"svd_inference"``) and the sum of ``tf.nn.l2_loss`` over the
        looked-up user and item embeddings (op name ``"svd_regularizer"``).
    """
    # Variables and lookups are pinned to the CPU; this scope deliberately
    # ignores the `device` argument.
    with tf.device("/cpu:0"):
        # Bias terms: one scalar plus one entry per user and per item.
        global_bias = tf.get_variable("bias_global", shape=[])
        user_bias_table = tf.get_variable("embd_bias_user", shape=[user_num])
        item_bias_table = tf.get_variable("embd_bias_item", shape=[item_num])
        user_bias = tf.nn.embedding_lookup(
            user_bias_table, user_batch, name="bias_user")
        item_bias = tf.nn.embedding_lookup(
            item_bias_table, item_batch, name="bias_item")

        # Latent factor tables, each with its own truncated-normal initializer.
        user_table = tf.get_variable(
            "embd_user",
            shape=[user_num, dim],
            initializer=tf.truncated_normal_initializer(stddev=0.02),
        )
        item_table = tf.get_variable(
            "embd_item",
            shape=[item_num, dim],
            initializer=tf.truncated_normal_initializer(stddev=0.02),
        )
        user_vecs = tf.nn.embedding_lookup(
            user_table, user_batch, name="embedding_user")
        item_vecs = tf.nn.embedding_lookup(
            item_table, item_batch, name="embedding_item")

    with tf.device(device):
        # Dot product of each user/item embedding pair along axis 1.
        interaction = tf.reduce_sum(tf.multiply(user_vecs, item_vecs), 1)
        with_global = tf.add(interaction, global_bias)
        with_user = tf.add(with_global, user_bias)
        infer = tf.add(with_user, item_bias, name="svd_inference")

        # The penalty covers only the embeddings looked up for this batch;
        # the bias terms are not included.
        user_l2 = tf.nn.l2_loss(user_vecs)
        item_l2 = tf.nn.l2_loss(item_vecs)
        regularizer = tf.add(user_l2, item_l2, name="svd_regularizer")

    return infer, regularizer
def optimization(infer, regularizer, rate_batch, learning_rate=0.001, reg=0.1,
                 device="/cpu:0"):
    """Build the regularized loss and the Adam training op.

    The cost is ``tf.nn.l2_loss(infer - rate_batch)`` plus ``reg`` times
    ``regularizer``.

    Args:
        infer: prediction tensor (e.g. the first value returned by
            ``inference_svd``).
        regularizer: penalty tensor that is scaled by ``reg`` and added to
            the data term.
        rate_batch: target values subtracted from ``infer`` (presumably the
            observed ratings, shaped like ``infer`` -- confirm against the
            caller).
        learning_rate: learning rate passed to ``tf.train.AdamOptimizer``.
        reg: weight of the regularization term.
        device: device string under which the loss and training ops are
            created.

    Returns:
        A tuple ``(cost, train_op)``: the scalar cost tensor and the op that
        applies one Adam update and increments the global step.

    Raises:
        ValueError: if no global step tensor exists in the current graph.
    """
    global_step = tf.train.get_global_step()
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would silently pass global_step=None to minimize()
    # and train without ever advancing a step counter.
    if global_step is None:
        raise ValueError(
            "No global step found in the current graph; create one (e.g. "
            "with tf.train.get_or_create_global_step()) before calling "
            "optimization()."
        )

    with tf.device(device):
        # Data term: L2 loss of the prediction error.
        cost_l2 = tf.nn.l2_loss(tf.subtract(infer, rate_batch))
        # Regularization weight as a float32 scalar constant.
        penalty = tf.constant(reg, dtype=tf.float32, shape=[], name="l2")
        cost = tf.add(cost_l2, tf.multiply(regularizer, penalty))
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(
            cost, global_step=global_step)
    return cost, train_op