diff --git a/runtime/ops/roi.cc b/runtime/ops/roi.cc index 7c76315a..a73bbc2a 100644 --- a/runtime/ops/roi.cc +++ b/runtime/ops/roi.cc @@ -429,26 +429,22 @@ class ReduceByAverage { chainerx::Array ROIMaxPool2DOp::RunImpl( ChxVMState* st, const chainerx::Array& x, const chainerx::Array& rois, const chainerx::Array& roi_indices) { - CHECK(!IsCudaDevice(&x.device())) << "Not implemented"; - return ROIPool2D(x, rois, roi_indices, output_shape, spatial_scale, chainerx::AMax); + return ROIPool2D(x.ToNative(), rois, roi_indices, output_shape, spatial_scale, chainerx::AMax).ToDevice(x.device()); } chainerx::Array ROIAveragePool2DOp::RunImpl( ChxVMState* st, const chainerx::Array& x, const chainerx::Array& rois, const chainerx::Array& roi_indices) { - CHECK(!IsCudaDevice(&x.device())) << "Not implemented"; - return ROIPool2D(x, rois, roi_indices, output_shape, spatial_scale, chainerx::Mean); + return ROIPool2D(x.ToNative(), rois, roi_indices, output_shape, spatial_scale, chainerx::Mean).ToDevice(x.device()); } chainerx::Array ROIMaxAlign2DOp::RunImpl( ChxVMState* st, const chainerx::Array& x, const chainerx::Array& rois, const chainerx::Array& roi_indices) { - CHECK(!IsCudaDevice(&x.device())) << "Not implemented"; - return ROIAlign2D(x, rois, roi_indices, output_shape, spatial_scale, sampling_ratio); + return ROIAlign2D(x.ToNative(), rois, roi_indices, output_shape, spatial_scale, sampling_ratio).ToDevice(x.device()); } chainerx::Array ROIAverageAlign2DOp::RunImpl( ChxVMState* st, const chainerx::Array& x, const chainerx::Array& rois, const chainerx::Array& roi_indices) { - CHECK(!IsCudaDevice(&x.device())) << "Not implemented"; - return ROIAlign2D(x, rois, roi_indices, output_shape, spatial_scale, sampling_ratio); + return ROIAlign2D(x.ToNative(), rois, roi_indices, output_shape, spatial_scale, sampling_ratio).ToDevice(x.device()); } } // namespace runtime