Skip to content

Commit

Permalink
Support static graph code-gen for bincount (#54686)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanbuphy authored Jun 26, 2023
1 parent 733eca8 commit b547c4a
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 108 deletions.
73 changes: 0 additions & 73 deletions paddle/fluid/operators/bincount_op.cc

This file was deleted.

9 changes: 9 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,5 +410,14 @@ phi::KernelKey GetConvExpectedKernelType(
return phi::KernelKey(input_data_type, ctx.GetPlace());
}

phi::KernelKey GetBincountExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto data_type = ctx.HasInput("Weights")
? op_ptr->IndicateVarDataType(ctx, "Weights")
: op_ptr->IndicateVarDataType(ctx, "X");
return phi::KernelKey(data_type, ctx.device_context().GetPlace());
}

} // namespace operators
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,9 @@ phi::KernelKey GetConvExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetBincountExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

} // namespace operators
} // namespace paddle
9 changes: 0 additions & 9 deletions paddle/phi/api/yaml/legacy_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,6 @@
view : (mean -> mean_out), (variance -> variance_out)
backward : batch_norm_grad

- op : bincount
args: (Tensor x, Tensor weights, Scalar(int) minlength = 0)
output: Tensor(out)
infer_meta:
func: BincountInferMeta
kernel:
func: bincount
optional: weights

- op : cast
args : (Tensor x, DataType dtype)
output : Tensor
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,18 @@
extra :
attrs : [bool use_mkldnn = false]

- op : bincount
inputs :
{x : X, weights : Weights}
outputs :
out : Out
scalar:
minlength:
data_type : int
support_tensor : true
get_expected_kernel_type :
bincount : GetBincountExpectedKernelType

- op : bitwise_and
inputs :
{x : X, y : Y}
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,15 @@
data_transform :
skip_transform : out_size, size_tensor, scale_tensor

- op : bincount
args: (Tensor x, Tensor weights, Scalar(int) minlength = 0)
output: Tensor(out)
infer_meta:
func: BincountInferMeta
kernel:
func: bincount
optional: weights

- op : bitwise_and
args : (Tensor x, Tensor y)
output : Tensor(out)
Expand Down
26 changes: 0 additions & 26 deletions paddle/phi/ops/compat/bincount_sig.cc

This file was deleted.

0 comments on commit b547c4a

Please sign in to comment.