diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h index 691089fe56c5..f804419b9c4f 100644 --- a/src/operator/cudnn_deconvolution-inl.h +++ b/src/operator/cudnn_deconvolution-inl.h @@ -14,15 +14,16 @@ namespace mxnet { namespace op { #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 +template class CuDNNDeconvolutionOp : public Operator { public: explicit CuDNNDeconvolutionOp(DeconvolutionParam param) { this->param_ = param; // convert MB to words - param_.workspace = (param_.workspace << 20) / sizeof(real_t); + param_.workspace = (param_.workspace << 20) / sizeof(DType); init_cudnn_ = false; // TODO(xxx): fp16 - dtype_ = CUDNN_DATA_FLOAT; + dtype_ = mshadow::DataType::kCudnnFlag; } ~CuDNNDeconvolutionOp() { @@ -45,20 +46,21 @@ class CuDNNDeconvolutionOp : public Operator { CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); - Tensor data = in_data[deconv::kData].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor out = out_data[deconv::kOut].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor wmat = in_data[deconv::kWeight].get(s); + Tensor out = out_data[deconv::kOut].get(s); CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(wmat.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); if (!init_cudnn_) { Init(s, in_data, out_data); } - Tensor workspace = ctx.requested[deconv::kTempSpace].get_space( + Tensor workspace = + ctx.requested[deconv::kTempSpace].get_space_typed( mshadow::Shape1(forward_workspace_), s); for (uint32_t g = 0; g < param_.num_group; ++g) { - float alpha = 1.0f; - float beta = 0.0f; + typename DataType::ScaleType alpha = 1.0f; + typename DataType::ScaleType beta = 0.0f; #if CUDNN_MAJOR <= 4 CHECK_EQ(cudnnConvolutionBackwardData_v3(s->dnn_handle_, &alpha, @@ -90,7 +92,7 @@ class CuDNNDeconvolutionOp : public Operator { #endif if (!param_.no_bias) { beta = 1.0f; - Tensor bias = in_data[deconv::kBias].get(s); + Tensor bias = in_data[deconv::kBias].get(s); #if CUDNN_MAJOR >= 4 CHECK_EQ(cudnnAddTensor(s->dnn_handle_, &alpha, @@ -129,18 +131,19 @@ class CuDNNDeconvolutionOp : public Operator { // TODO(bing): think about how to support add to CHECK_EQ(req[deconv::kWeight], kWriteTo); Stream *s = ctx.get_stream(); - Tensor grad = out_grad[deconv::kOut].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor gwmat = in_grad[deconv::kWeight].get(s); - Tensor data = in_data[deconv::kData].get(s); - Tensor gdata = in_grad[deconv::kData].get(s); - Tensor workspace = ctx.requested[deconv::kTempSpace].get_space( + Tensor grad = out_grad[deconv::kOut].get(s); + Tensor wmat = in_data[deconv::kWeight].get(s); + Tensor gwmat = in_grad[deconv::kWeight].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor gdata = in_grad[deconv::kData].get(s); + Tensor workspace = + ctx.requested[deconv::kTempSpace].get_space_typed( mshadow::Shape1(backward_workspace_), s); for (uint32_t g = 0; g < param_.num_group; ++g) { - float alpha = 1.0f; - float beta = 0.0f; + typename DataType::ScaleType alpha = 1.0f; + typename DataType::ScaleType beta = 0.0f; if (!param_.no_bias) { - Tensor gbias = in_grad[deconv::kBias].get(s); + Tensor gbias = in_grad[deconv::kBias].get(s); CHECK_EQ(cudnnConvolutionBackwardBias(s->dnn_handle_, &alpha, out_desc_, @@ -208,11 +211,11 @@ class CuDNNDeconvolutionOp : public Operator { CHECK_EQ(out_data.size(), 1); if (!init_cudnn_) { init_cudnn_ = true; - size_t workspace_byte = static_cast(param_.workspace * sizeof(real_t)); + size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); size_t back_size = 0; size_t back_size_w = 0; - Tensor data = in_data[deconv::kData].get(s); - Tensor out = out_data[deconv::kOut].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor out = out_data[deconv::kOut].get(s); data_offset_ = data.shape_[1] / param_.num_group * data.shape_[2] * data.shape_[3]; out_offset_ = out.shape_[1] /param_.num_group * out.shape_[2] * out.shape_[3]; weight_offset_ = data.shape_[1] / param_.num_group * param_.num_filter / param_.num_group @@ -267,7 +270,7 @@ class CuDNNDeconvolutionOp : public Operator { out.shape_[3], 1), CUDNN_STATUS_SUCCESS); if (!param_.no_bias) { - Tensor bias = in_data[deconv::kBias].get(s); + Tensor bias = in_data[deconv::kBias].get(s); bias_offset_ = bias.shape_[0] / param_.num_group; CHECK_EQ(cudnnSetTensor4dDescriptor(bias_desc_, CUDNN_TENSOR_NCHW, @@ -324,8 +327,8 @@ class CuDNNDeconvolutionOp : public Operator { in_desc_, algo_, &forward_workspace_byte_), CUDNN_STATUS_SUCCESS); - forward_workspace_ = forward_workspace_byte_ / sizeof(real_t) + 1; - backward_workspace_ = backward_workspace_byte_ / sizeof(real_t) + 1; + forward_workspace_ = forward_workspace_byte_ / sizeof(DType) + 1; + backward_workspace_ = backward_workspace_byte_ / sizeof(DType) + 1; } } diff --git a/src/operator/deconvolution-inl.h b/src/operator/deconvolution-inl.h index 9a8a9607fcf7..a1590956e8c7 100644 --- a/src/operator/deconvolution-inl.h +++ b/src/operator/deconvolution-inl.h @@ -54,7 +54,7 @@ struct DeconvolutionParam : public dmlc::Parameter { } }; -template +template class DeconvolutionOp : public Operator { public: explicit DeconvolutionOp(DeconvolutionParam p) { @@ -75,29 +75,33 @@ class DeconvolutionOp : public Operator { CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); Stream *s = ctx.get_stream(); - Tensor data = in_data[deconv::kData].get(s); - Tensor out = out_data[deconv::kOut].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor out = out_data[deconv::kOut].get(s); Shape<3> wmat_shape = Shape3(param_.num_group, data.shape_[1] / param_.num_group, param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]); - Tensor wmat = in_data[deconv::kWeight].get_with_shape(wmat_shape, s); + Tensor wmat = + in_data[deconv::kWeight].get_with_shape(wmat_shape, s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif const index_t nbatch = data.size(0); - Tensor workspace = ctx.requested[deconv::kTempSpace].get_space( - Shape1(this->InitTemp(out.shape_, data.shape_)), s); + Tensor workspace = + ctx.requested[deconv::kTempSpace].get_space_typed( + Shape1(this->InitTemp(out.shape_, data.shape_)), s); for (index_t i = 0; i < nbatch; i += nstep_) { const index_t step = std::min(nstep_, nbatch - i); - Tensor temp_col = Tensor(workspace.dptr_, - Shape2(shape_colunit_[0], - shape_colunit_[1] * step), s); - Tensor temp_dst = Tensor(workspace.dptr_ + temp_col.shape_.Size(), - Shape3(shape_dstunit_[0], - shape_dstunit_[1], - shape_dstunit_[2] * step), s); + Tensor temp_col = Tensor( + workspace.dptr_, + Shape2(shape_colunit_[0], + shape_colunit_[1] * step), s); + Tensor temp_dst = Tensor( + workspace.dptr_ + temp_col.shape_.Size(), + Shape3(shape_dstunit_[0], + shape_dstunit_[1], + shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); if (param_.pad[0] == 0 && param_.pad[1] == 0) { temp_col = unpack_patch2col(out.Slice(i, i + step), @@ -117,8 +121,8 @@ class DeconvolutionOp : public Operator { } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { - mshadow::Tensor tmpc = temp_col.Slice(gstride * gid, - gstride * (gid + 1)); + mshadow::Tensor tmpc = temp_col.Slice(gstride * gid, + gstride * (gid + 1)); tmpc = dot(wmat[gid].T(), temp_dst[gid]); } if (param_.pad[0] == 0 && param_.pad[1] == 0) { @@ -143,7 +147,7 @@ class DeconvolutionOp : public Operator { } if (!param_.no_bias) { // add bias, broadcast bias to dim 1: channel - Tensor bias = in_data[deconv::kBias].get(s); + Tensor bias = in_data[deconv::kBias].get(s); out += broadcast<1>(bias, out.shape_); } } @@ -165,31 +169,36 @@ class DeconvolutionOp : public Operator { CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true); // get data Stream *s = ctx.get_stream(); - Tensor data = in_data[deconv::kData].get(s); - Tensor grad = out_grad[deconv::kOut].get(s); - Tensor gdata = in_grad[deconv::kData].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor grad = out_grad[deconv::kOut].get(s); + Tensor gdata = in_grad[deconv::kData].get(s); Shape<3> wmat_shape = Shape3(param_.num_group, data.shape_[1] / param_.num_group, param_.num_filter / param_.num_group * param_.kernel[0] * param_.kernel[1]); - Tensor wmat = in_data[deconv::kWeight].get_with_shape(wmat_shape, s); - Tensor gwmat = in_grad[deconv::kWeight].get_with_shape(wmat_shape, s); + Tensor wmat = + in_data[deconv::kWeight].get_with_shape(wmat_shape, s); + Tensor gwmat = + in_grad[deconv::kWeight].get_with_shape(wmat_shape, s); #if defined(__CUDACC__) CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif const index_t nbatch = data.size(0); - Tensor workspace = ctx.requested[deconv::kTempSpace].get_space( - Shape1(this->InitTemp(grad.shape_, data.shape_)), s); + Tensor workspace = + ctx.requested[deconv::kTempSpace].get_space_typed( + Shape1(this->InitTemp(grad.shape_, data.shape_)), s); for (index_t i = 0; i < nbatch; i += nstep_) { const index_t step = std::min(nstep_, nbatch - i); - Tensor temp_col = Tensor(workspace.dptr_, - Shape2(shape_colunit_[0], - shape_colunit_[1] * step), s); - Tensor temp_dst = Tensor(workspace.dptr_ + temp_col.shape_.Size(), - Shape3(shape_dstunit_[0], - shape_dstunit_[1], - shape_dstunit_[2] * step), s); + Tensor temp_col = Tensor( + workspace.dptr_, + Shape2(shape_colunit_[0], + shape_colunit_[1] * step), s); + Tensor temp_dst = Tensor( + workspace.dptr_ + temp_col.shape_.Size(), + Shape3(shape_dstunit_[0], + shape_dstunit_[1], + shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); if (param_.pad[0] == 0 && param_.pad[1] == 0) { temp_col = unpack_patch2col(grad.Slice(i, i + step), @@ -208,9 +217,9 @@ class DeconvolutionOp : public Operator { } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { - Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); + Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); if (i == 0) { - Tensor tmp_gwmat = gwmat[gid]; + Tensor tmp_gwmat = gwmat[gid]; Assign(tmp_gwmat, req[deconv::kWeight], dot(temp_dst[gid], tmpc.T())); } else { gwmat[gid] += dot(temp_dst[gid], tmpc.T()); @@ -218,7 +227,7 @@ class DeconvolutionOp : public Operator { } if (req[deconv::kData] == kWriteTo || req[deconv::kData] == kWriteInplace) { for (uint32_t gid = 0; gid < param_.num_group; ++gid) { - Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); + Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); temp_dst[gid] = dot(wmat[gid], tmpc); } gdata.Slice(i, i + step) = swapaxis<1, 0>(reshape(temp_dst, @@ -229,7 +238,7 @@ class DeconvolutionOp : public Operator { } } if (!param_.no_bias) { - Tensor gbias = in_grad[deconv::kBias].get(s); + Tensor gbias = in_grad[deconv::kBias].get(s); Assign(gbias, req[deconv::kBias], sumall_except_dim<1>(grad)); } } @@ -259,8 +268,8 @@ class DeconvolutionOp : public Operator { shape_dstunit_[2] * nstep_); index_t required_size = scol.Size() + sdst.Size(); CHECK_GE(param_.workspace, required_size) - << "\nMinimum workspace size: " << required_size * sizeof(real_t) << " Bytes\n" - << "Given: " << param_.workspace * sizeof(real_t); + << "\nMinimum workspace size: " << required_size * sizeof(DType) << " Bytes\n" + << "Given: " << param_.workspace * sizeof(DType); return required_size; } @@ -271,7 +280,7 @@ class DeconvolutionOp : public Operator { }; // class DeconvolutionOp template -Operator* CreateOp(DeconvolutionParam param); +Operator* CreateOp(DeconvolutionParam param, int dtype); #if DMLC_USE_CXX11 class DeconvolutionProp : public OperatorProperty { @@ -332,6 +341,26 @@ class DeconvolutionProp : public OperatorProperty { return true; } + bool InferType(std::vector *in_type, + std::vector *out_type, + std::vector *aux_type) const override { + CHECK_GE(in_type->size(), 1); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + for (index_t i = 0; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype; + } else { + CHECK_EQ((*in_type)[i], dtype) << "This layer requires uniform type. " + << "Expected " << dtype << " v.s. given " + << (*in_type)[i] << " at " << ListArguments()[i]; + } + } + out_type->clear(); + out_type->push_back(dtype); + return true; + } + OperatorProperty* Copy() const override { auto ptr = new DeconvolutionProp(); ptr->param_ = param_; @@ -359,7 +388,14 @@ class DeconvolutionProp : public OperatorProperty { return {ResourceRequest::kTempSpace}; } - Operator* CreateOperator(Context ctx) const override; + Operator* CreateOperator(Context ctx) const override { + LOG(FATAL) << "Not Implemented"; + return NULL; + } + + Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const override; + private: DeconvolutionParam param_; diff --git a/src/operator/deconvolution.cc b/src/operator/deconvolution.cc index fe5deeafc05b..61d839bae8d3 100644 --- a/src/operator/deconvolution.cc +++ b/src/operator/deconvolution.cc @@ -10,12 +10,21 @@ namespace mxnet { namespace op { template<> -Operator* CreateOp(DeconvolutionParam param) { - return new DeconvolutionOp(param); +Operator* CreateOp(DeconvolutionParam param, int dtype) { + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new DeconvolutionOp(param); + }); + return op; } -Operator* DeconvolutionProp::CreateOperator(Context ctx) const { - DO_BIND_DISPATCH(CreateOp, param_); +Operator* DeconvolutionProp::CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const { + std::vector out_shape, aux_shape; + std::vector out_type, aux_type; + CHECK(InferType(in_type, &out_type, &aux_type)); + CHECK(InferShape(in_shape, &out_shape, &aux_shape)); + DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); } DMLC_REGISTER_PARAMETER(DeconvolutionParam); diff --git a/src/operator/deconvolution.cu b/src/operator/deconvolution.cu index f9df3fb1f8d7..5fea2862554b 100644 --- a/src/operator/deconvolution.cu +++ b/src/operator/deconvolution.cu @@ -13,12 +13,16 @@ namespace mxnet { namespace op { template<> -Operator* CreateOp(DeconvolutionParam param) { +Operator* CreateOp(DeconvolutionParam param, int dtype) { + Operator *op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { #if MXNET_USE_CUDNN == 1 - return new CuDNNDeconvolutionOp(param); + op = new CuDNNDeconvolutionOp(param); #else - return new DeconvolutionOp(param); + op = new DeconvolutionOp(param); #endif // MXNET_USE_CUDNN + }); + return op; } } // namespace op diff --git a/src/operator/upsampling.cc b/src/operator/upsampling.cc index 6398189635c0..d69e7e99c040 100644 --- a/src/operator/upsampling.cc +++ b/src/operator/upsampling.cc @@ -30,7 +30,7 @@ Operator *CreateOp(UpSamplingParam param) { p.stride = TShape(shape, shape + 2); shape[0] = shape[1] = pad; p.pad = TShape(shape, shape + 2); - return new DeconvolutionOp(p); + return new DeconvolutionOp(p); } else { LOG(FATAL) << "Unknown sample type"; return NULL; diff --git a/src/operator/upsampling.cu b/src/operator/upsampling.cu index 99cf8b6aa807..526f3a91de84 100644 --- a/src/operator/upsampling.cu +++ b/src/operator/upsampling.cu @@ -30,7 +30,7 @@ Operator *CreateOp(UpSamplingParam param) { p.stride = TShape(shape, shape + 2); shape[0] = shape[1] = pad; p.pad = TShape(shape, shape + 2); - return new DeconvolutionOp(p); + return new DeconvolutionOp(p); } else { LOG(FATAL) << "Unknown sample type"; return NULL; diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index b1ad35db5b0b..7bf532d8468c 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -76,6 +76,15 @@ def test_convolution_with_type(): {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}] check_consistency(sym, ctx_list) +def test_deconvolution_with_type(): + sym = mx.sym.Deconvolution(num_filter=2, kernel=(3,3), name='conv') + ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}}, + {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}, + {'ctx': mx.gpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float16}}, + {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float64}}, + {'ctx': mx.cpu(0), 'conv_data': (2, 2, 10, 10), 'type_dict': {'conv_data': np.float32}}] + check_type_consistency(sym, ctx_list) + def test_fullyconnected_with_type(): sym = mx.sym.FullyConnected(num_hidden=3, name='inner') ctx_list = [{'ctx': mx.gpu(0), 'inner_data': (2, 10), 'type_dict': {'inner_data': np.float64}}, @@ -97,6 +106,7 @@ def test_activation_with_type(): if __name__ == '__main__': test_convolution_with_type() + test_deconvolution_with_type() test_fullyconnected_with_type() test_activation_with_type() #test_softmax_with_shape((3,4), mx.gpu())