implement grad refer lego
test=develop
zhaify committed Sep 11, 2019
1 parent 00a38f8 commit 1a69c90
Showing 5 changed files with 17 additions and 89 deletions.
1 change: 0 additions & 1 deletion paddle/fluid/API.spec
@@ -288,7 +288,6 @@ paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_ta
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
-paddle.fluid.layers.fused_emb_seq (ArgSpec(args=['input', 'size', 'is_sparse', 'padding_idx', 'combiner', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, None, 'sum', None, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '88367daf9a30c9ab83adc5d7221e23ef'))
paddle.fluid.layers.double_buffer (ArgSpec(args=['reader', 'place', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '44724c493f41a124abc7531c2740e2e3'))
30 changes: 14 additions & 16 deletions paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -36,7 +36,7 @@ using DDim = framework::DDim;
constexpr int64_t kNoPadding = -1;

#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__)
template <typename T>
void prepare_csr_data(const std::vector<uint64_t> &offset,
const int64_t *ids_data, const size_t idx_width,
@@ -138,7 +138,7 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {

if (combiner_type == "sum") {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__)
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto output = output_t->mutable_data<T>(context.GetPlace());
int64_t table_height = table_var->dims()[0];
@@ -232,7 +232,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
}
} else {
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
-    !defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
+    !defined(__OSX__)
auto *ids = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
@@ -261,20 +261,18 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
csr_colmuns, csr_row_idx, padding_idx);

auto *d_output_data = d_output->data<T>();
-      const char transa = 'T';
-      const T alpha = 1.0;
-      const T beta = 0.0;
-      const char matdescra[] = {'G', 'L', 'N', 'C'};
-
-      const int m = batch_size * idx_width;
-      const int n = table_dim[1];
-      const int k = table_dim[1];
-
       auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-      blas.CSRMM(&transa, &m, &n, &k, &alpha, matdescra, (const T *)csr_vals,
-                 (const int *)csr_colmuns, (const int *)csr_row_idx,
-                 (const int *)csr_row_idx + 1, d_output_data, &n, &beta,
-                 d_table_data, &n);
+      int width = static_cast<int>(table_dim[1]);
+      int num_seq = batch_size * idx_width;
+      LOG(INFO) << "num seq = " << num_seq << " width = " << width;
+      for (int i = 0; i < num_seq; ++i) {
+        for (int j = csr_row_idx[i]; j < csr_row_idx[i + 1]; ++j) {
+          unsigned int word_idx = csr_colmuns[j];
+          T val = csr_vals[j];
+          blas.AXPY(width, val, d_output_data + i * width,
+                    d_table_data + word_idx * width);
+        }
+      }
#else
LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
#endif
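
Note: the removed CSRMM call and the new AXPY loop compute the same gradient — each row i of d_output is scatter-added into the table rows of the words that appear in sequence i, scaled by the corresponding CSR weight. A minimal NumPy sketch of that computation (the function name `emb_seq_pool_grad` is ours for illustration; the array names mirror the kernel's):

```python
import numpy as np


def emb_seq_pool_grad(d_output, csr_vals, csr_colmuns, csr_row_idx,
                      table_height):
    """NumPy sketch of the dense grad path above (hypothetical helper).

    d_output    : (num_seq, width) gradient w.r.t. the pooled output.
    csr_row_idx : (num_seq + 1,) row offsets into the two arrays below.
    csr_colmuns : word (row-of-W) index of each non-zero.
    csr_vals    : weight of each non-zero (1.0 for plain sum pooling).
    """
    num_seq, width = d_output.shape
    d_table = np.zeros((table_height, width), dtype=d_output.dtype)
    for i in range(num_seq):
        for j in range(csr_row_idx[i], csr_row_idx[i + 1]):
            # AXPY: d_table[word] += val * d_output[i]
            d_table[csr_colmuns[j]] += csr_vals[j] * d_output[i]
    return d_table
```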
28 changes: 0 additions & 28 deletions python/paddle/fluid/layers/nn.py
@@ -216,7 +216,6 @@
'var_conv_2d',
'shard_index',
'hard_swish',
-    'fused_emb_seq',
]

kIgnoreIndex = -100
@@ -13290,30 +13289,3 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
'scale': scale,
'offset': offset})
    return out
-
-
-def fused_emb_seq(input,
-                  size,
-                  is_sparse=False,
-                  padding_idx=None,
-                  combiner='sum',
-                  param_attr=None,
-                  dtype='float32'):
-
-    helper = LayerHelper('fused_emb_seq', **locals())
-    w = helper.create_parameter(
-        attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
-    out = helper.create_variable_for_type_inference(dtype)
-    padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
-        size[0] + padding_idx)
-    helper.append_op(
-        type='fused_embedding_seq_pool',
-        inputs={'Ids': input,
-                'W': w},
-        outputs={'Out': out},
-        attrs={
-            'is_sparse': is_sparse,
-            'combiner': combiner,
-            'padding_idx': padding_idx
-        })
-    return out
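
For reference, the removed layer's one-liner normalizes `padding_idx` so that `None` becomes `kNoPadding` (-1) and a negative index counts back from the end of the embedding table. A small sketch spelling out that rule (the helper name is ours):

```python
def normalize_padding_idx(padding_idx, vocab_size):
    """Hypothetical helper making the padding_idx rule above explicit."""
    if padding_idx is None:
        return -1  # kNoPadding: no row is zeroed out
    # Negative indices count back from the end of the embedding table.
    return padding_idx if padding_idx >= 0 else vocab_size + padding_idx


assert normalize_padding_idx(None, 20) == -1
assert normalize_padding_idx(3, 20) == 3
assert normalize_padding_idx(-1, 20) == 19
```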
10 changes: 3 additions & 7 deletions python/paddle/fluid/tests/unittests/test_fused_emb_seq_pool_op.py
@@ -32,7 +32,6 @@ def setUp(self):
self.table = np.random.random((17, self.emb_size)).astype("float32")
self.ids = np.array([[[4], [3]], [[4], [3]], [[2], [1]],
[[16], [1]]]).astype("int64")
-        merged_ids = np.array([4, 2, 16]).astype("int64")
ids_expand = np.expand_dims(self.ids, axis=1)
self.lod = [[3, 1]]
self.attrs = {'is_sparse': True}
@@ -49,16 +48,14 @@ def test_check_output(self):
self.check_output()

def test_check_grad(self):
-        if ver.mkl() == "ON" and not fluid.core.is_compiled_with_cuda(
-        ) and 'Linux' in platform.platform():
+        if ver.mkl() == "ON" and 'Linux' in platform.platform():
self.attrs = {'is_sparse': False}
self.check_grad(['W'], 'Out', no_grad_set=('Ids'))


class TestLookupTableOpWithPadding(TestFusedEmbeddingSeqPoolOp):
def test_check_output(self):
-        if ver.mkl() == "ON" and not fluid.core.is_compiled_with_cuda(
-        ) and 'Linux' in platform.platform():
+        if ver.mkl() == "ON" and 'Linux' in platform.platform():
ids = np.squeeze(self.ids, axis=2)
padding_idx = np.random.choice(ids.flatten(), 1)[0]
output = list()
@@ -80,8 +77,7 @@ def test_check_output(self):
self.check_output()

def test_check_grad(self):
-        if ver.mkl() == "ON" and not fluid.core.is_compiled_with_cuda(
-        ) and 'Linux' in platform.platform():
+        if ver.mkl() == "ON" and 'Linux' in platform.platform():
ids = np.squeeze(self.ids, axis=2)
padding_idx = np.random.choice(ids.flatten(), 1)[0]
self.attrs = {'padding_idx': int(padding_idx), 'is_sparse': False}
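
A note on the relaxed guards in these tests: the `is_compiled_with_cuda()` exclusion is dropped because the kernel's MKL path above no longer `#ifdef`s on `PADDLE_WITH_CUDA`, so the gate reduces to MKL-on-Linux. A sketch of the new predicate (the helper name is ours):

```python
import platform

import paddle.version as ver


def dense_grad_test_enabled():
    # Hypothetical helper: the dense path now runs whenever Paddle is
    # built with MKL on Linux, with or without CUDA in the same build.
    return ver.mkl() == "ON" and "Linux" in platform.platform()
```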
37 changes: 0 additions & 37 deletions python/paddle/fluid/tests/unittests/test_layers.py
@@ -14,7 +14,6 @@

from __future__ import print_function
import unittest
-import platform

import contextlib
import numpy as np
@@ -34,7 +33,6 @@
from test_imperative_base import new_program_scope
from paddle.fluid.dygraph import nn
from paddle.fluid.dygraph import base
-import paddle.version as ver


class LayerTest(unittest.TestCase):
@@ -2253,41 +2251,6 @@ def test_retinanet_detection_output(self):
nms_eta=1.)
return (nmsed_outs)

-    def test_fused_emb_seq(self):
-        if ver.mkl() == "ON" and not fluid.core.is_compiled_with_cuda(
-        ) and 'Linux' in platform.platform():
-            dict_size = 20
-            with self.static_graph():
-                tensor = fluid.core.LoDTensor()
-                place = fluid.core.CPUPlace()
-                tensor.set(np.array([1, 2, 3, 1, 2]).astype("int64"), place)
-                tensor.set_recursive_sequence_lengths([[4, 1]])
-                data_t = layers.data(
-                    name='word', shape=[1], dtype='int64', lod_level=1)
-                emb = layers.fused_emb_seq(
-                    input=data_t,
-                    size=[dict_size, 32],
-                    param_attr='w',
-                    is_sparse=False)
-                self.get_static_graph_result(
-                    feed={'word': tensor}, fetch_list=[emb])
-
-            with self.static_graph():
-                tensor = fluid.core.LoDTensor()
-                place = fluid.core.CPUPlace()
-                tensor.set(np.array([1, 2, 3, 1, 2]).astype("int64"), place)
-                tensor.set_recursive_sequence_lengths([[4, 1]])
-                data_t = layers.data(
-                    name='word', shape=[1], dtype='int64', lod_level=1)
-                emb = layers.fused_emb_seq(
-                    input=data_t,
-                    size=[dict_size, 32],
-                    param_attr='w',
-                    padding_idx=1,
-                    is_sparse=False)
-                self.get_static_graph_result(
-                    feed={'word': tensor}, fetch_list=[emb])
-

if __name__ == '__main__':
unittest.main()
