Skip to content

Commit

Permalink
Merge pull request pfnet-research#155 from pfnet-research/bump-chainer
Browse files Browse the repository at this point in the history
Bump chainer
  • Loading branch information
shinh authored Apr 12, 2019
2 parents 75bc903 + ede4562 commit 453b398
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 51 deletions.
4 changes: 2 additions & 2 deletions compiler/gen_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def __init__(self, op_type, num_inputs, num_outputs, **kwargs):
# For experimental ops.
NodeDef('ChainerDoSomething', None, None, function_name=Required(str))

NodeDef('ChainerMaxPoolGrad', 2, 1)
NodeDef('ChainerAveragePoolGrad', 2, 1)
NodeDef('ChainerMaxPoolGrad', 2, 1, chainer_cover_all=False, **pool_attrs)
NodeDef('ChainerAveragePoolGrad', 2, 1, count_include_pad=False, **pool_attrs)
NodeDef('ChainerMaxPoolGradNoCtx',
3, 1, chainer_cover_all=False, **pool_attrs)
NodeDef('ChainerAveragePoolGradNoCtx',
Expand Down
16 changes: 14 additions & 2 deletions compiler/gradient_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,13 @@ void MaxPoolGradFn(GradientOpContext* gc) {
if (node->outputs().size() == 1) gc->AddNullOutput();
CHECK_EQ(2, node->outputs().size());
Value* context = gc->AddOutput(Type(Type::Kind::kOpaque));
gc->GradOp(Node::kChainerMaxPoolGrad, 0, {gc->gy(0), context});
gc->GradOp(Node::kChainerMaxPoolGrad, 0, {gc->gy(0), context})
->producer()
->set_kernel_shape(node->kernel_shape())
->set_pads(node->pads())
->set_storage_order(node->storage_order())
->set_strides(node->strides())
->set_chainer_cover_all(node->chainer_cover_all());
}
}

Expand All @@ -477,7 +483,13 @@ void AveragePoolGradFn(GradientOpContext* gc) {
Node* node = gc->node();
CHECK_EQ(1, node->outputs().size());
Value* context = gc->AddOutput(Type(Type::Kind::kOpaque));
gc->GradOp(Node::kChainerAveragePoolGrad, 0, {gc->gy(0), context});
gc->GradOp(Node::kChainerAveragePoolGrad, 0, {gc->gy(0), context})
->producer()
->set_kernel_shape(node->kernel_shape())
->set_pads(node->pads())
->set_storage_order(node->storage_order())
->set_strides(node->strides())
->set_count_include_pad(node->count_include_pad());
}
}

Expand Down
8 changes: 6 additions & 2 deletions compiler/xcvm/emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ class XCVMEmitter {
EMIT_SIMPLE_BINARY_OP(Node::kXor, Xor);

EMIT_SIMPLE_BINARY_OP(Node::kChainerReluGrad, ReluGrad);
EMIT_SIMPLE_BINARY_OP(Node::kChainerMaxPoolGrad, MaxPoolGrad);
EMIT_SIMPLE_BINARY_OP(Node::kChainerAveragePoolGrad, AveragePoolGrad);
EMIT_SIMPLE_BINARY_OP(Node::kChainerSelectItem, SelectItem);

if (node.op_type() == Node::kDropout) {
Expand Down Expand Up @@ -415,6 +413,9 @@ class XCVMEmitter {
CHECK(node.output(1)->IsNull());
}
EMIT(MaxPool, out(0), oout(2), in(0), node.kernel_shape(), strides(), pads(), node.chainer_cover_all());
} else if (node.op_type() == Node::kChainerMaxPoolGrad) {
CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for MaxPool";
EMIT(MaxPoolGrad, out(0), in(0), in(1), node.kernel_shape(), node.chainer_cover_all());
} else if (node.op_type() == Node::kChainerROIMaxPool2D) {
EMIT(ROIMaxPool2D, out(0), in(0), in(1), in(2), node.output_shape(), node.spatial_scale());
} else if (node.op_type() == Node::kChainerROIAveragePool2D) {
Expand All @@ -432,6 +433,9 @@ class XCVMEmitter {
CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for AveragePool";
CHECK_EQ(1UL, node.inputs().size());
EMIT(AveragePool, out(0), oout(1), in(0), node.kernel_shape(), strides(), pads(), node.count_include_pad());
} else if (node.op_type() == Node::kChainerAveragePoolGrad) {
CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for AveragePool";
EMIT(AveragePoolGrad, out(0), in(0), in(1), node.kernel_shape(), node.count_include_pad());
} else if (node.op_type() == Node::kChainerAveragePoolGradNoCtx) {
CHECK_EQ("NOTSET", node.auto_pad()) << "auto_pad is not supported for MaxPool";
EMIT(AveragePoolGradNoCtx, out(0), in(0), in(1), in(2), node.kernel_shape(), strides(), pads(), node.chainer_cover_all());
Expand Down
2 changes: 1 addition & 1 deletion runtime/chainerx_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ chainerx::Array PadSequence(const std::vector<chainerx::Array>& inputs, int64_t
const chainerx::Array& input = inputs[i];
indices[0] = chainerx::ArrayIndex(i);
indices[1] = chainerx::Slice(0, input.shape()[0]);
input.device().Copy(input, result.At(indices));
input.device().backend().CallOp<chainerx::CopyOp>(input, result.At(indices));
}
return result;
}
Expand Down
2 changes: 1 addition & 1 deletion runtime/ops/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ chainerx::Array ReluGradOp::RunImpl(XCVMState* st, const chainerx::Array& x, con
} else {
CHECK(false) << "TODO(hamaji): Unsupported dtype: " << x.dtype();
}
x.device().IfLessElseASSA(x, eps, chainerx::Scalar(0.0), gy, out);
x.device().backend().CallOp<chainerx::IfLessElseASSAOp>(x, eps, chainerx::Scalar(0.0), gy, out);
return out;
}

Expand Down
4 changes: 2 additions & 2 deletions runtime/ops/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ chainerx::Array ConvTransposeWithDynamicShapeOp::RunImpl(
}

chainerx::Array ConvGradWeightOp::RunImpl(XCVMState* st, const chainerx::Array& w, const chainerx::Array& x, const chainerx::Array& gy) {
return x.device().ConvGradWeight(
w.dtype(), w.shape(), x, gy, ComplementStride(strides, x), ComplementPad(pads, x), false /* cover_all */);
return x.device().backend().CallOp<chainerx::ConvGradWeightOp>(
w.dtype(), w.shape(), x, gy, ComplementStride(strides, x), ComplementPad(pads, x), false /* cover_all */, nonstd::nullopt);
}

} // namespace runtime
Expand Down
13 changes: 7 additions & 6 deletions runtime/ops/cudnn_rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ chainerx::Array PackSequence(const chainerx::Array& x, int64_t num_inputs, const
chainerx::Array src = x.At({time, chainerx::Slice(0, num_batch)});
chainerx::Array dest = packed.At({chainerx::Slice(offset, offset + num_batch)});
CHECK_EQ(src.GetTotalSize(), dest.GetTotalSize());
x.device().Copy(src, dest);
x.device().backend().CallOp<chainerx::CopyOp>(src, dest);
offset += num_batch;
}
CHECK_EQ(offset, packed.shape()[0]);
Expand All @@ -363,7 +363,7 @@ chainerx::Array UnpackSequence(
chainerx::Array src = packed.At({chainerx::Slice(offset, offset + num_batch)});
chainerx::Array dest = x.At({time, chainerx::Slice(0, num_batch)});
CHECK_EQ(src.GetTotalSize(), dest.GetTotalSize());
x.device().Copy(src, dest);
x.device().backend().CallOp<chainerx::CopyOp>(src, dest);
offset += num_batch;
}
return x;
Expand Down Expand Up @@ -428,7 +428,7 @@ bool CudnnLSTM(
chainerx::Array packed = PackSequence(x, num_inputs, num_batches);

auto& device = dynamic_cast<chainerx::cuda::CudaDevice&>(x.device());
CudnnHandle& cudnn_handle = device.cudnn_handle();
CudnnHandle& cudnn_handle = chainerx::cuda::cuda_internal::GetDeviceInternals(device).cudnn_handle();

// TODO(hamaji): Avoid unnecessary memory allocation.
CudnnTensorDescriptor x_desc(chainerx::Empty({batch_size, input_size, 1}, x.dtype(), chainerx::GetNativeBackend().GetDevice(0)));
Expand Down Expand Up @@ -472,7 +472,8 @@ bool CudnnLSTM(
int64_t param_size = src_w.GetTotalSize();
int offset = GetRNNWeightOffset(
cudnn_handle, *rnn_desc, pseudo_layer, x_desc, *w_concat_desc, w_concat, lin_layer_id, is_bias, src_w);
w_concat.device().Copy(chainerx::Reshape(src_w, {param_size}), w_concat.At({chainerx::Slice(offset, offset + param_size)}));
w_concat.device().backend().CallOp<chainerx::CopyOp>(
chainerx::Reshape(src_w, {param_size}), w_concat.At({chainerx::Slice(offset, offset + param_size)}));
offsets.push_back(offset);
}
}
Expand Down Expand Up @@ -567,7 +568,7 @@ bool CudnnLSTMGrad(
if (!dynamic_cast<const LSTMBackwardContext*>(&ctx)) return false;
auto& context = dynamic_cast<const LSTMBackwardContext&>(ctx);
auto& device = dynamic_cast<chainerx::cuda::CudaDevice&>(ogy.device());
CudnnHandle& cudnn_handle = device.cudnn_handle();
CudnnHandle& cudnn_handle = chainerx::cuda::cuda_internal::GetDeviceInternals(device).cudnn_handle();

const chainerx::Array& x = context.x();
const chainerx::Array& w_concat = context.w();
Expand Down Expand Up @@ -652,7 +653,7 @@ bool CudnnLSTMGrad(
for (int is_bias = 0; is_bias < 2; ++is_bias) {
int64_t offset = context.offsets()[offset_index++];
chainerx::Array dest = slices[is_bias][lin_layer_id];
gw_concat.device().Copy(
gw_concat.device().backend().CallOp<chainerx::CopyOp>(
chainerx::Reshape(gw_concat.At({chainerx::Slice(offset, offset + dest.GetTotalSize())}), dest.shape()), dest);
}
}
Expand Down
4 changes: 2 additions & 2 deletions runtime/ops/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ chainerx::Array DynamicSliceGradOp::RunImpl(
const nonstd::optional<chainerx::Array>& steps) {
chainerx::Array out = chainerx::Zeros(ArrayToShape(shape), gy.dtype());
std::vector<chainerx::ArrayIndex> indices = GetIndicesForDynamicSlice(out, starts, ends, axes, steps);
out.device().Copy(gy, out.At(indices));
out.device().backend().CallOp<chainerx::CopyOp>(gy, out.At(indices));
return out;
}

Expand Down Expand Up @@ -117,7 +117,7 @@ chainerx::Array GetItemGradOp::RunImpl(
XCVMState* st, const chainerx::Array& gy, const chainerx::Array& shape, const std::vector<chainerx::Array>& index_arrays) {
chainerx::Array out = chainerx::Zeros(ArrayToShape(shape), gy.dtype());
std::vector<chainerx::ArrayIndex> indices = GetIndicesForGetItem(index_arrays, slice_specs);
out.device().Copy(gy, out.At(indices));
out.device().backend().CallOp<chainerx::CopyOp>(gy, out.At(indices));
return out;
}

Expand Down
2 changes: 1 addition & 1 deletion runtime/ops/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ chainerx::Array PadOp::RunImpl(XCVMState* st, const chainerx::Array& data) {
indices2.push_back(chainerx::Slice(start2, end2));
}
chainerx::Array result = chainerx::Full(new_shape, value, data.dtype(), data.device());
result.device().Copy(data.At(indices1), result.At(indices2));
result.device().backend().CallOp<chainerx::CopyOp>(data.At(indices1), result.At(indices2));
return result;
}

Expand Down
64 changes: 51 additions & 13 deletions runtime/ops/normalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,28 @@ namespace {

class BatchNormBackwardContext : public XCVMOpaque {
public:
BatchNormBackwardContext(std::unique_ptr<chainerx::BatchNormForwardBackward>&& fb, chainerx::Shape x1_shape, chainerx::Shape x2_shape)
: fb_(std::move(fb)), x1_shape_(x1_shape), x2_shape_(x2_shape) {
BatchNormBackwardContext(
std::shared_ptr<chainerx::BatchNormGradState> state,
const chainerx::Array& x,
const chainerx::Array& gamma,
chainerx::Shape x1_shape,
chainerx::Shape x2_shape,
double epsilon,
const chainerx::Axes& sorted_axis)
: state_(state), x_(x), gamma_(gamma), x1_shape_(x1_shape), x2_shape_(x2_shape), epsilon_(epsilon), sorted_axis_(sorted_axis) {
}
virtual ~BatchNormBackwardContext() = default;

chainerx::BatchNormForwardBackward* fb() const {
return fb_.get();
const std::shared_ptr<chainerx::BatchNormGradState>& state() const {
return state_;
}

const chainerx::Array& x() const {
return x_;
}

const chainerx::Array& gamma() const {
return gamma_;
}

const chainerx::Shape& x1_shape() const {
Expand All @@ -31,10 +46,22 @@ class BatchNormBackwardContext : public XCVMOpaque {
return x2_shape_;
}

double epsilon() const {
return epsilon_;
}

const chainerx::Axes& sorted_axis() const {
return sorted_axis_;
}

private:
std::unique_ptr<chainerx::BatchNormForwardBackward> fb_;
std::shared_ptr<chainerx::BatchNormGradState> state_;
chainerx::Array x_;
chainerx::Array gamma_;
chainerx::Shape x1_shape_;
chainerx::Shape x2_shape_;
double epsilon_;
chainerx::Axes sorted_axis_;
};

// TODO(hamaji): Copied from ChainerX's code.
Expand Down Expand Up @@ -129,12 +156,13 @@ std::tuple<chainerx::Array, XCVMOpaque*, chainerx::Array, chainerx::Array, chain
} else {
result = PreprocessBatchNorm(x, s, bias, mean, var, axes);
}
std::unique_ptr<chainerx::BatchNormForwardBackward> fb =
x.device().GetBatchNormForwardBackward(result.mean, result.var, epsilon, decay, result.sorted_axis);
const Array& gamma_reshaped = result.gamma;
const Array& beta_reshaped = result.beta;
chainerx::Array out = fb->Forward(x, gamma_reshaped, beta_reshaped);
XCVMOpaque* ctx = new BatchNormBackwardContext(std::move(fb), s.shape(), bias.shape());
std::shared_ptr<chainerx::BatchNormGradState> state;
chainerx::Array out;
std::tie(out, state) = x.device().backend().CallOp<chainerx::BatchNormOp>(
x, gamma_reshaped, beta_reshaped, result.mean, result.var, epsilon, decay, result.sorted_axis, true, nonstd::nullopt);
XCVMOpaque* ctx = new BatchNormBackwardContext(state, x, gamma_reshaped, s.shape(), bias.shape(), epsilon, result.sorted_axis);
if (st->options().dump_memory_usage) {
ctx->SetRetainedArrays({x, gamma_reshaped, beta_reshaped, result.mean, result.var});
}
Expand Down Expand Up @@ -173,10 +201,20 @@ chainerx::Array FixedBatchNormalizationOp::RunImpl(
std::tuple<chainerx::Array, chainerx::Array, chainerx::Array> BatchNormalizationGradOp::RunImpl(
XCVMState* st, const chainerx::Array& gy, const XCVMOpaque& ctx) {
auto& context = dynamic_cast<const BatchNormBackwardContext&>(ctx);
std::array<chainerx::Array, 3> gxs = context.fb()->Backward(gy);
chainerx::Array gx1 = chainerx::Reshape(gxs[1], context.x1_shape());
chainerx::Array gx2 = chainerx::Reshape(gxs[2], context.x2_shape());
return std::forward_as_tuple(gxs[0], gx1, gx2);
chainerx::Array gx, ggamma, gbeta;
std::tie(gx, ggamma, gbeta) = gy.device().backend().CallOp<chainerx::BatchNormGradOp>(
context.x(),
context.gamma(),
gy,
context.epsilon(),
context.sorted_axis(),
context.state(),
nonstd::nullopt,
nonstd::nullopt,
nonstd::nullopt);
chainerx::Array gx1 = chainerx::Reshape(ggamma, context.x1_shape());
chainerx::Array gx2 = chainerx::Reshape(gbeta, context.x2_shape());
return std::forward_as_tuple(gx, gx1, gx2);
}

std::tuple<chainerx::Array, chainerx::Array> LRNOp::RunImpl(XCVMState* st, const chainerx::Array& x) {
Expand Down
56 changes: 40 additions & 16 deletions runtime/ops/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,40 @@ namespace {
template <class T>
class BackwardContext : public XCVMOpaque {
public:
explicit BackwardContext(std::unique_ptr<T>&& fb) : fb_(std::move(fb)) {
BackwardContext(std::shared_ptr<T> state, const Int64StackVector& strides, const Int64StackVector& pads)
: state_(state), strides_(strides), pads_(pads) {
}
virtual ~BackwardContext() = default;

T* fb() const {
return fb_.get();
std::shared_ptr<T> state() const {
return state_;
}

const Int64StackVector& strides() const {
return strides_;
}

const Int64StackVector& pads() const {
return pads_;
}

private:
std::unique_ptr<T> fb_;
std::shared_ptr<T> state_;
const Int64StackVector strides_;
const Int64StackVector pads_;
};

} // namespace

std::tuple<chainerx::Array, XCVMOpaque*> MaxPoolOp::RunImpl(XCVMState* st, const chainerx::Array& x) {
// TODO(hamaji): Revive CheckPoolInputs.
std::unique_ptr<chainerx::MaxPoolForwardBackward> fb(
x.device().GetMaxPoolForwardBackward(kernel_shape, ComplementStride(strides, x), ComplementPad(pads, x), cover_all));
chainerx::Array out = fb->Forward(x);
XCVMOpaque* ctx = new BackwardContext<chainerx::MaxPoolForwardBackward>(std::move(fb));
std::shared_ptr<chainerx::MaxPoolGradState> state;
chainerx::Array out;
const Int64StackVector& strides = ComplementStride(this->strides, x);
const Int64StackVector& pads = ComplementPad(this->pads, x);
std::tie(out, state) =
x.device().backend().CallOp<chainerx::MaxPoolOp>(x, kernel_shape, strides, pads, cover_all, true, nonstd::nullopt);
XCVMOpaque* ctx = new BackwardContext<chainerx::MaxPoolGradState>(std::move(state), strides, pads);
if (st->options().dump_memory_usage) {
ctx->SetRetainedArrays({x, out});
}
Expand All @@ -51,40 +65,50 @@ std::tuple<chainerx::Array, XCVMOpaque*> MaxPoolOp::RunImpl(XCVMState* st, const
std::tuple<chainerx::Array, XCVMOpaque*> AveragePoolOp::RunImpl(XCVMState* st, const chainerx::Array& x) {
// TODO(hamaji): Revive CheckPoolInputs.
chainerx::AveragePoolPadMode pad_mode = count_include_pad ? chainerx::AveragePoolPadMode::kZero : chainerx::AveragePoolPadMode::kIgnore;
std::unique_ptr<chainerx::AveragePoolForwardBackward> fb(
x.device().GetAveragePoolForwardBackward(kernel_shape, ComplementStride(strides, x), ComplementPad(pads, x), pad_mode));
chainerx::Array out = fb->Forward(x);
XCVMOpaque* ctx = new BackwardContext<chainerx::AveragePoolForwardBackward>(std::move(fb));
std::shared_ptr<chainerx::AveragePoolGradState> state;
chainerx::Array out;
std::tie(out, state) =
x.device().backend().CallOp<chainerx::AveragePoolOp>(x, kernel_shape, strides, pads, pad_mode, true, nonstd::nullopt);
XCVMOpaque* ctx = new BackwardContext<chainerx::AveragePoolGradState>(std::move(state), strides, pads);
if (st->options().dump_memory_usage) {
ctx->SetRetainedArrays({x, out});
}
return std::tie(out, ctx);
}

chainerx::Array MaxPoolGradOp::RunImpl(XCVMState* st, const chainerx::Array& gy, const XCVMOpaque& ctx) {
auto& context = dynamic_cast<const BackwardContext<chainerx::MaxPoolForwardBackward>&>(ctx);
return context.fb()->Backward(gy);
auto& context = dynamic_cast<const BackwardContext<chainerx::MaxPoolGradState>&>(ctx);
return std::get<0>(gy.device().backend().CallOp<chainerx::MaxPoolGradOp>(
gy, kernel_shape, context.strides(), context.pads(), context.state(), true, nonstd::nullopt));
}

chainerx::Array AveragePoolGradOp::RunImpl(XCVMState* st, const chainerx::Array& gy, const XCVMOpaque& ctx) {
auto& context = dynamic_cast<const BackwardContext<chainerx::AveragePoolForwardBackward>&>(ctx);
return context.fb()->Backward(gy);
chainerx::AveragePoolPadMode pad_mode = count_include_pad ? chainerx::AveragePoolPadMode::kZero : chainerx::AveragePoolPadMode::kIgnore;
auto& context = dynamic_cast<const BackwardContext<chainerx::AveragePoolGradState>&>(ctx);
return gy.device().backend().CallOp<chainerx::AveragePoolGradOp>(
gy, kernel_shape, context.strides(), context.pads(), pad_mode, context.state(), nonstd::nullopt);
}

chainerx::Array MaxPoolGradNoCtxOp::RunImpl(XCVMState* st, const chainerx::Array& x, const chainerx::Array& y, const chainerx::Array& gy) {
CHECK(false);
#if 0
std::unique_ptr<chainerx::MaxPoolForwardBackward> fb(
x.device().GetMaxPoolForwardBackward(kernel_shape, ComplementStride(strides, x), ComplementPad(pads, x), cover_all));
fb->Forward(x);
return fb->Backward(gy);
#endif
}

chainerx::Array AveragePoolGradNoCtxOp::RunImpl(
XCVMState* st, const chainerx::Array& x, const chainerx::Array& y, const chainerx::Array& gy) {
CHECK(false);
#if 0
chainerx::AveragePoolPadMode pad_mode = count_include_pad ? chainerx::AveragePoolPadMode::kZero : chainerx::AveragePoolPadMode::kIgnore;
std::unique_ptr<chainerx::AveragePoolForwardBackward> fb(
x.device().GetAveragePoolForwardBackward(kernel_shape, ComplementStride(strides, x), ComplementPad(pads, x), pad_mode));
fb->Forward(x);
return fb->Backward(gy);
#endif
}

// A faithful re-implementation of Chainer's ROI ops.
Expand Down
Loading

0 comments on commit 453b398

Please sign in to comment.