From 773f4d8081634884ec421236165bd2ae19417503 Mon Sep 17 00:00:00 2001 From: Michael Antonov Date: Tue, 4 Dec 2018 11:42:43 -0800 Subject: [PATCH] Implements Gather operator for arbitrary axis, sharing the code with BatchGather. (#13756) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13756 This implements general Gather operator for arbitrary axis, sharing the code with BatchGather. - CPU gather & batch gather logic is now shared through caffe2::gather_helper, for any axis. - Shared CUDA kernel moved to gather_op.cuh, for any axis. - Gradients of axis > 0 delegate to BatchGatherGradientOp which now has axis argument. - BatchGatherOp doc strings updated to have correct rank (q + (r -1)) and output. - Added tests for axis == 2. GatherOp supports index wrapping for axis == 0 by default, which was earlier for ONNX. This diff also extends it to work in Cuda kernel. Added "wrap_indices" argument which specifies wheather this wrapping should be done; set it to true if you'd like wrapping for any axis. TBD: Update gradients to support negative indices (separate diff). TBD: Once we have operator versioning, we'd like to update GatherOp to NOT support axis 0 wrapping by default, but rather do it only if wrap_indices is set. Reviewed By: dzhulgakov Differential Revision: D12983815 fbshipit-source-id: 8add9d67b47fe8c5ba7a335f581ca0530b205cd7 --- caffe2/core/operator.h | 3 + caffe2/operators/batch_gather_ops.cc | 21 +-- caffe2/operators/batch_gather_ops.cu | 110 +++++------- caffe2/operators/batch_gather_ops.h | 132 ++++++--------- caffe2/operators/gather_op.cc | 79 ++++++--- caffe2/operators/gather_op.cu | 53 +----- caffe2/operators/gather_op.cuh | 112 ++++++++++++ caffe2/operators/gather_op.h | 206 ++++++++++++++++++----- caffe2/python/operator_test/gather_ops_test.py | 72 +++++--- tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py | 1 + 10 files changed, 484 insertions(+), 305 deletions(-) create mode 100644 caffe2/operators/gather_op.cuh diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 40d58b1..3db643a 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -664,6 +664,9 @@ class Operator : public OperatorBase { const Context* getContext() const { return &context_; } + Context* getContext() { + return &context_; + } protected: void RecordEvent(const char* err_msg = nullptr) final { diff --git a/caffe2/operators/batch_gather_ops.cc b/caffe2/operators/batch_gather_ops.cc index 1d57534..2bd941a 100644 --- a/caffe2/operators/batch_gather_ops.cc +++ b/caffe2/operators/batch_gather_ops.cc @@ -12,24 +12,20 @@ OPERATOR_SCHEMA(BatchGather) const vector& in) { vector out(1); ArgumentHelper helper(def); - - vector output_dims; const auto& data_dims = GetDimsVector(in[0]); const auto& indices_dims = GetDimsVector(in[1]); - output_dims.push_back(data_dims[0]); - output_dims.insert( - output_dims.end(), indices_dims.begin(), indices_dims.end()); - output_dims.insert( - output_dims.end(), data_dims.begin() + 2, data_dims.end()); + vector output_dims = + caffe2::gather_helper::calc_output_shape_vector( + data_dims, indices_dims, 1); out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT); return out; }) .SetDoc(R"DOC( Batch gather operation, first dimension in DATA is the batch size. Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather -entries of the outer-most dimension of DATA indexed by INDICES, and concatenate -them in an output tensor of rank (q - 1) + (r - 1). 
+entries of the second outer dimension (axis == 1) of DATA indexed by INDICES, +and concatenate them in an output tensor of rank q + (r - 1). Example: DATA = [ @@ -37,9 +33,8 @@ Example: [2.3, 3.4, 3.6, 2.3], [4.5, 5.7, 1.2, 4.5], ] - INDICES = [ - [0, 2], - ] + INDICES = [0, 2] + OUTPUT = [ [1.0, 2.4], [2.3, 3.6], @@ -48,7 +43,7 @@ Example: )DOC") .Input(0, "DATA", "Tensor of rank r >= 2.") .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.") - .Output(0, "OUTPUT", "Tensor of rank (q - 1) + (r - 1).") + .Output(0, "OUTPUT", "Tensor of rank q + (r - 1).") .InheritOnnxSchema(); OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1); diff --git a/caffe2/operators/batch_gather_ops.cu b/caffe2/operators/batch_gather_ops.cu index d1559dc..4347a60 100644 --- a/caffe2/operators/batch_gather_ops.cu +++ b/caffe2/operators/batch_gather_ops.cu @@ -2,32 +2,11 @@ #include "caffe2/core/common_gpu.h" #include "caffe2/core/context_gpu.h" #include "caffe2/operators/batch_gather_ops.h" +// Shared batch kernel +#include "caffe2/operators/gather_op.cuh" namespace caffe2 { -template -__global__ void BatchGatherKernel( - const TData* src_base, - TData* out, - const T_INDEX* indices, - const int M, - const int N, - const int data_batch_size, - const int gathered_batch_size, - const int block_size) { - const int begin_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int num_items = M * N * block_size; - for (int s = begin_idx; s < num_items; s += blockDim.x * gridDim.x) { - const int k = s % block_size; - const int j = s / block_size % N; - const int i = s / block_size / N; - const T_INDEX idx = indices[j]; - const float* src_offset = src_base + i * data_batch_size + idx * block_size; - float* dst_offset = out + i * gathered_batch_size + j * block_size; - dst_offset[k] = src_offset[k]; - } -} - template <> bool BatchGatherOp::RunOnDevice() { return DispatchHelper>::call( @@ -37,39 +16,9 @@ bool BatchGatherOp::RunOnDevice() { template <> template bool BatchGatherOp::DoRunWithType() { - auto& data = Input(DATA); - auto& indices = Input(INDICES); - auto* output = Output(0); - - vector shape; - shape.push_back(data.dim(0)); - shape.insert(shape.end(), indices.dims().begin(), indices.dims().end()); - shape.insert(shape.end(), data.dims().begin() + 2, data.dims().end()); - output->Resize(shape); - - const int block_size = data.size_from_dim(2); - const int N = indices.size(); - const auto data_batch_size = data.size_from_dim(1); - const auto gathered_batch_size = N * data.size_from_dim(2); - const TInd* idxs = indices.template data(); - auto src_base = static_cast(data.raw_data()); - auto out = static_cast(output->raw_mutable_data(data.meta())); - const int M = data.dim32(0); - - BatchGatherKernel<<< - std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS), - std::min(N * block_size, CAFFE_CUDA_NUM_THREADS), - 0, - context_.cuda_stream()>>>( - src_base, - out, - idxs, - M, - N, - data_batch_size, - gathered_batch_size, - block_size); - return true; + // BatchGather is a special-case of Gather with Axis = 1, wrap = false. 
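For reference, the updated schema example above maps directly onto numpy's take() along axis 1, which is also what the reworked Python tests in this patch use as their reference; a minimal numpy-only sketch (not the operator itself):

import numpy as np

# DATA and INDICES from the BatchGather docstring example above.
data = np.array([[1.0, 1.2, 2.4, 4.5],
                 [2.3, 3.4, 3.6, 2.3],
                 [4.5, 5.7, 1.2, 4.5]], dtype=np.float32)
indices = np.array([0, 2], dtype=np.int64)

# BatchGather gathers along axis 1; output rank is q + (r - 1) = 1 + (2 - 1) = 2.
output = data.take(indices, axis=1)
assert output.shape == (3, 2)
# output == [[1.0, 2.4], [2.3, 3.6], [4.5, 1.2]]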
+ return gather_helper::gather_impl_cuda( + this, DATA, INDICES, 0, 1, false); } template @@ -77,18 +26,24 @@ __global__ void BatchGatherGradientKernel( const TData* grad_data, TData* out, const T_INDEX* indices, - const int M, + const int outer_dims_product, const int N, const int data_batch_size, const int gathered_batch_size, - const int block_size) { + const int block_size, + const int src_indexing_axis_dim, + const bool wrap_indices) { int begin_idx = blockIdx.x * blockDim.x + threadIdx.x; - int num_items = M * N * block_size; + int num_items = outer_dims_product * N * block_size; + for (int s = begin_idx; s < num_items; s += blockDim.x * gridDim.x) { const int k = s % block_size; const int j = s / block_size % N; const int i = s / block_size / N; - const T_INDEX idx = indices[j]; + T_INDEX idx = indices[j]; + if (wrap_indices && idx < 0) { + idx = idx + src_indexing_axis_dim; + } const float* src_offset = grad_data + i * gathered_batch_size + j * block_size; float* dst_offset = out + i * data_batch_size + idx * block_size; @@ -118,34 +73,51 @@ bool BatchGatherGradientOp::DoRunWithType2() { auto& grad = Input(GRAD); auto* output = Output(0); - CAFFE_ENFORCE_EQ(data.dim(0), grad.dim(0), "batch sizes should be the same"); + // ONNX allows negative axis to index from the back, valid range: [-r, r]. + int axis = axis_; + if (axis < 0) { + axis = data.dim() + axis; + } + // Outer dimensions of input data and gradient should be the same + // because they are preserved for gathers with axis > 0. + for (int acheck = 0; acheck < axis; acheck++) { + CAFFE_ENFORCE_EQ( + data.size(acheck), grad.size(acheck), "batch sizes should be the same"); + } output->ResizeLike(data); auto* out_data = output->template mutable_data(); math::Set(output->size(), 0, out_data, &context_); const auto* grad_data = grad.template data(); + const TInd* idxs = indices.template data(); + + // Treat all outer dimensions as a unit as they contribute to larger batch. + const int outer_dims_product = grad.size_to_dim(axis); + const int block_size = data.size_from_dim(axis + 1); - const int M = grad.dim32(0); - const int block_size = data.size_from_dim(2); const int N = indices.size(); - const auto data_batch_size = data.size_from_dim(1); - const auto gathered_batch_size = N * data.size_from_dim(2); - const TInd* idxs = indices.template data(); + const auto data_batch_size = data.size_from_dim(axis); + const auto gathered_batch_size = N * block_size; + const int src_indexing_axis_dim = data.dim(axis); + // Assign each thread index its own 'float' in block_size * N (kernel will + // loop if there is more data than fits NUM_BLOCKS * NUM_THREADS limit). BatchGatherGradientKernel<<< - std::min(M, CAFFE_MAXIMUM_NUM_BLOCKS), + std::min(outer_dims_product, CAFFE_MAXIMUM_NUM_BLOCKS), std::min(N * block_size, CAFFE_CUDA_NUM_THREADS), 0, context_.cuda_stream()>>>( grad_data, out_data, idxs, - M, + outer_dims_product, N, data_batch_size, gathered_batch_size, - block_size); + block_size, + src_indexing_axis_dim, + false); // TBD: Add proper index wrapping support to Gather gradients. return true; } diff --git a/caffe2/operators/batch_gather_ops.h b/caffe2/operators/batch_gather_ops.h index 325abce..46ab990 100644 --- a/caffe2/operators/batch_gather_ops.h +++ b/caffe2/operators/batch_gather_ops.h @@ -4,6 +4,8 @@ #include "caffe2/core/context.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" +// Reuse helper logic from GatherOp since BatchGather is the same with axis=1. 
+#include "caffe2/operators/gather_op.h" namespace caffe2 { @@ -20,66 +22,9 @@ class BatchGatherOp final : public Operator { template bool DoRunWithType() { - auto& data = Input(DATA); - auto& indices = Input(INDICES); - - CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D"); - - vector shape; - shape.push_back(data.size(0)); - shape.insert(shape.end(), indices.sizes().begin(), indices.sizes().end()); - shape.insert(shape.end(), data.sizes().begin() + 2, data.sizes().end()); - auto* output = Output(0, shape, at::dtype(data.dtype())); - - auto block_size = data.size_from_dim(2); - auto block_bytesize = block_size * data.dtype().itemsize(); - auto N = indices.numel(); - auto data_batch_size = data.size_from_dim(1); - auto gathered_batch_size = N * data.size_from_dim(2); - auto data_batch_bytesize = data_batch_size * data.dtype().itemsize(); - auto gathered_batch_bytesize = - gathered_batch_size * data.dtype().itemsize(); - const TInd* idxs = indices.template data(); - auto src_base = static_cast(data.raw_data()); - auto out = static_cast(output->raw_mutable_data(data.dtype())); - - for (auto i = 0; i < N; ++i) { - auto idx = idxs[i]; - CAFFE_ENFORCE( - 0 <= idx && idx < data.size(1), - "INDICES element is out of DATA bounds, id=", - idx, - " data_dim=", - data.size(1)); - } - - if (data.template IsType() && block_size == 1) { - auto src = data.template data(); - auto dst = output->template mutable_data(); - - for (auto batch = 0; batch < data.size(0); ++batch) { - auto src_batch_base = src + batch * data_batch_size; - auto out_batch_base = dst + batch * gathered_batch_size; - - for (auto i = 0; i < N; ++i) { - auto idx = idxs[i]; - out_batch_base[i] = src_batch_base[idx]; - } - } - } else { - for (auto batch = 0; batch < data.size(0); ++batch) { - auto src_batch_base = src_base + batch * data_batch_bytesize; - auto out_batch_base = out + batch * gathered_batch_bytesize; - - for (auto i = 0; i < N; ++i) { - auto idx = idxs[i]; - auto src = src_batch_base + idx * block_bytesize; - auto dst = out_batch_base + i * block_bytesize; - context_.CopyItemsSameDevice(data.dtype(), block_size, src, dst); - } - } - } - return true; + // BatchGather is a special-case of Gather with Axis = 1. + return gather_helper::gather_impl( + this, DATA, INDICES, 0, 1, false); } INPUT_TAGS(DATA, INDICES); }; @@ -88,7 +33,13 @@ template class BatchGatherGradientOp final : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - USE_SIMPLE_CTOR_DTOR(BatchGatherGradientOp); + + // Constructor to recieve axis in case it was passed for GatherOp gradient, + // use default of 1 for batch gather otherwise. + BatchGatherGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + OP_SINGLE_ARG(int, "axis", axis_, 1) { } + virtual ~BatchGatherGradientOp() noexcept {} bool RunOnDevice() override { return DispatchHelper>::call( @@ -109,49 +60,62 @@ class BatchGatherGradientOp final : public Operator { auto& grad = Input(GRAD); auto* output = Output(0); + // ONNX allows negative axis to index from the back, valid range: [-r, r]. + int axis = axis_; + if (axis < 0) { + axis = data.dim() + axis; + } + CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D"); - CAFFE_ENFORCE_EQ( - data.size(0), grad.size(0), "batch sizes should be the same"); + // Outer dimensions of input data and gradient should be the same + // because they are preserved for gathers with axis > 0. 
+ for (int acheck = 0; acheck < axis; acheck++) { + CAFFE_ENFORCE_EQ( + data.size(acheck), + grad.size(acheck), + "batch gather outer dimensions should match"); + } output->ResizeLike(data); TData* out_data = output->template mutable_data(); if (data.numel() <= 0) { return true; } - memset(out_data, 0, output->nbytes()); const TData* grad_data = grad.template data(); + const TInd* idxs = indices.template data(); - auto block_size = data.size_from_dim(2); + auto outer_dims_product = data.size_to_dim(axis); + auto batch_size = data.size_from_dim(axis); + auto block_size = data.size_from_dim(axis + 1); auto N = indices.numel(); - auto data_batch_size = data.size_from_dim(1); - auto gathered_batch_size = N * data.size_from_dim(2); - const TInd* idxs = indices.template data(); + auto gathered_grad_batch_size = N * block_size; - for (auto i = 0; i < N; ++i) { - auto idx = idxs[i]; - CAFFE_ENFORCE( - 0 <= idx && idx < data.size(1), - "INDICES element is out of DATA bounds, id=", - idx, - " data_dim=", - data.size(1)); - } + // Check indexing bounds. + auto src_indexing_axis_dim = data.dim(axis); + gather_helper::check_indexarray_range( + idxs, + N, + src_indexing_axis_dim, + false); - for (auto batch = 0; batch < grad.size(0); ++batch) { - auto src_batch_base = grad_data + batch * gathered_batch_size; - auto out_batch_base = out_data + batch * data_batch_size; + for (auto batch = 0; batch < outer_dims_product; ++batch) { + auto grad_batch_base = grad_data + batch * gathered_grad_batch_size; + auto out_batch_base = out_data + batch * batch_size; for (auto i = 0; i < N; ++i) { auto idx = idxs[i]; + if (idx < 0) { + idx = idx + src_indexing_axis_dim; + } if (block_size == 1) { - out_batch_base[idx * block_size] += src_batch_base[i * block_size]; + out_batch_base[idx] += grad_batch_base[i]; } else { math::Add( block_size, out_batch_base + idx * block_size, - src_batch_base + i * block_size, + grad_batch_base + i * block_size, out_batch_base + idx * block_size, &context_); } @@ -164,12 +128,14 @@ class BatchGatherGradientOp final : public Operator { bool DoRunWithOtherType2() { CAFFE_THROW( "BatchGatherGradient is not implemented on tensor of type ", - Input(DATA).dtype().name(), + Input(DATA).meta().name(), "Consider adding it a type in the list DispatchHelper or implementing " "a generic version (which won't work for duplicated indices though)"); } INPUT_TAGS(DATA, INDICES, GRAD); +protected: + int axis_; }; } // namespace caffe2 diff --git a/caffe2/operators/gather_op.cc b/caffe2/operators/gather_op.cc index 8eec9e4..805bd24 100644 --- a/caffe2/operators/gather_op.cc +++ b/caffe2/operators/gather_op.cc @@ -75,50 +75,73 @@ OUTPUT: "INDICES", "Input indices tensor of rank $q$. 
This tensor must contain integers.") .Output(0, "OUTPUT", "Output tensor of rank $q+(r-1)$") - .TensorInferenceFunction([](const OperatorDef& /* unused */, + .TensorInferenceFunction([](const OperatorDef& def, const vector& in) { + ArgumentHelper helper(def); + const int axis = helper.GetSingleArgument("axis", 0); + const auto& data_dims = GetDimsVector(in[0]); + const auto& indices_dims = GetDimsVector(in[1]); + + vector output_dims = + caffe2::gather_helper::calc_output_shape_vector( + data_dims, indices_dims, axis); vector out(1); - if (in[0].dims(0) == 0) { - for (int i = 0; i < in[0].dims_size(); ++i) { - out[0].add_dims(in[0].dims(i)); - } - } else { - for (auto d : in[1].dims()) { - out[0].add_dims(d); - } - for (int i = 1; i < in[0].dims_size(); ++i) { - out[0].add_dims(in[0].dims(i)); - } - } - out[0].set_data_type(in[0].data_type()); + out[0] = CreateTensorShape(output_dims, in[0].data_type()); return out; }) .InheritOnnxSchema(); class GetGatherGradient : public GradientMakerBase { using GradientMakerBase::GradientMakerBase; + vector GetGradientDefs() override { ArgumentHelper argsHelper(def_); const bool dense_gradient = argsHelper.GetSingleArgument("dense_gradient", false); + const int axis = argsHelper.GetSingleArgument("axis", 0); + + // TBD: While it hasn't been used yet, we need to add wrap_indices support + // to gradients next. + // if (argsHelper.HasArgument("wrap_indices_")) { + // } using Op = GatherOp; - if (dense_gradient) { - return vector{CreateOperatorDef( - "SparseToDense", - "", - vector{I(Op::INDICES), GO(0), I(Op::DATA)}, - vector{GI(Op::DATA)})}; - } else { - // For now we don't do any reshaping as the consumer of this op would - // probably be ScatterUpdate which is intenionally ignores shapes. We - // might need to revisit it in the future for correctness purposes. The - // right shape for the output woild be to flatten INDICES and collapse - // first X dims of GRAD - SetSparse(Op::DATA, I(Op::INDICES), GO(0)); - return vector(); + if (axis == 0) { + if (dense_gradient) { + return vector{CreateOperatorDef( + "SparseToDense", + "", + vector{I(Op::INDICES), GO(0), I(Op::DATA)}, + vector{GI(Op::DATA)})}; + } else { + // For now we don't do any reshaping as the consumer of this op would + // probably be ScatterUpdate which is intenionally ignores shapes. We + // might need to revisit it in the future for correctness purposes. The + // right shape for the output woild be to flatten INDICES and collapse + // first X dims of GRAD + SetSparse(Op::DATA, I(Op::INDICES), GO(0)); + return vector(); + } + } + + // TBD: This is misleading to use dense_gradient by default for axis 0 + // and not othewise.... + if (argsHelper.HasArgument("dense_gradient")) { + CAFFE_ENFORCE( + dense_gradient == true, + "Gather with axis > 0 must use dense_gradient"); } + + Argument axisArg = MakeArgument("axis", axis); + return SingleGradientDef( + "BatchGatherGradient", + "", + // This is the order as expected by BatchGatherGradient indices, + // different from SpartseToDense above. 
+ vector{I(Op::DATA), I(Op::INDICES), GO(0)}, + vector{GI(0)}, + std::vector{axisArg}); } }; REGISTER_GRADIENT(Gather, GetGatherGradient); diff --git a/caffe2/operators/gather_op.cu b/caffe2/operators/gather_op.cu index ba704f2..395b7d5 100644 --- a/caffe2/operators/gather_op.cu +++ b/caffe2/operators/gather_op.cu @@ -1,25 +1,9 @@ #include "caffe2/core/context_gpu.h" #include "caffe2/operators/gather_op.h" +#include "caffe2/operators/gather_op.cuh" namespace caffe2 { -template -__global__ void GatherKernel( - const float* X, - float* Y, - const T_INDEX* indices, - const int N, - const int block_size) { - for (int i = blockIdx.x; i < N; i += gridDim.x) { - T_INDEX idx = indices[i]; - const float* src_offset = X + idx * block_size; - float* dst_offset = Y + i * block_size; - for (int j = threadIdx.x; j < block_size; j += blockDim.x) { - dst_offset[j] = src_offset[j]; - } - } -} - template <> bool GatherOp::RunOnDevice() { return DispatchHelper>::call( @@ -29,38 +13,9 @@ bool GatherOp::RunOnDevice() { template <> template bool GatherOp::DoRunWithType() { - auto& data = Input(DATA); - auto& indices = Input(INDICES); - auto* output = Output(0); - - CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D"); - auto shape = indices.dims().vec(); - shape.insert(shape.end(), data.dims().begin() + 1, data.dims().end()); - output->Resize(shape); - - int block_size = data.size() / data.dim(0); - auto block_bytesize = data.size_from_dim(1) * data.meta().itemsize(); - CAFFE_ENFORCE( - block_bytesize == data.nbytes() / data.dim(0), - "block_bytesize should be consistent with data dim"); - int N = indices.size(); - - auto src_base = static_cast(data.raw_data()); - const Index* idxs = indices.template data(); - auto out = static_cast(output->raw_mutable_data(data.meta())); - - // return early when the input is empty, since CUDA kernel will fail for - // empty input. - if (N <= 0) { - return true; - } - - GatherKernel<<< - std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS), - CAFFE_CUDA_NUM_THREADS, - 0, - context_.cuda_stream()>>>(src_base, out, idxs, N, block_size); - return true; + // Use shared implementation with BatchGather + return gather_helper::gather_impl_cuda( + this, DATA, INDICES, 0, axis_, wrap_indices_); } REGISTER_CUDA_OPERATOR(Gather, GatherOp); diff --git a/caffe2/operators/gather_op.cuh b/caffe2/operators/gather_op.cuh new file mode 100644 index 0000000..ed7c67c --- /dev/null +++ b/caffe2/operators/gather_op.cuh @@ -0,0 +1,112 @@ +#include "caffe2/core/common_gpu.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/gather_op.h" + +namespace caffe2 { + +// This maintains kernels and index-mapping functions shared +// by Gather and BatchGather ops. 
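In caffe2's Python API terms, the gradient definition emitted above for axis > 0 corresponds roughly to the operator below (blob names are placeholders; the real inputs are taken from the forward op):

from caffe2.python import core

# Hypothetical blob names, shown for illustration only.
grad_op = core.CreateOperator(
    'BatchGatherGradient',
    ['data', 'indices', 'output_grad'],   # I(DATA), I(INDICES), GO(0)
    ['data_grad'],                        # GI(0)
    axis=2)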
+namespace gather_helper { + +template +__global__ void BatchGatherKernel( + const TData* src_base, + TData* out, + const T_INDEX* indices, + const int M, + const int N, + const int data_batch_size, + const int gathered_batch_size, + const int block_size, + const int indexing_axis_dim, + const bool wrap_indices) { + const int begin_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int num_items = M * N * block_size; + for (int s = begin_idx; s < num_items; s += blockDim.x * gridDim.x) { + const int k = s % block_size; + const int j = s / block_size % N; + const int i = s / block_size / N; + T_INDEX idx = indices[j]; + if (wrap_indices && (idx < 0)) { + idx = idx + (T_INDEX) indexing_axis_dim; + } + const float* src_offset = src_base + i * data_batch_size + idx * block_size; + float* dst_offset = out + i * gathered_batch_size + j * block_size; + dst_offset[k] = src_offset[k]; + } +} + +// Actual gather implementation - resizes output and copies indexed data. +template +static bool gather_impl_cuda( + Operator* op, + int dataIdx, + int indicesIdx, + int outputIdx, + int axis, + bool wrap_indices) { + const Tensor& data = op->Input(dataIdx); + const Tensor& indices = op->Input(indicesIdx); + const TypeMeta dataType = data.dtype(); + size_t item_bytesize = dataType.itemsize(); + + // ONNX allows negative axis to index from the back, valid range: [-r, r]. + if (axis < 0) { + axis = data.dim() + axis; + } + CAFFE_ENFORCE_GE( + data.ndim(), axis + 1, "DATA should be at least [axis+1]-D"); + CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative"); + CAFFE_ENFORCE_LT(axis, data.ndim(), "Axis out of range"); + + // New shape: + // [data dims before axis] + [indices dims] + [data dims after axis] + vector shape = + calc_output_shape_vector(data.dims(), indices.dims(), axis); + Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType)); + float* out = static_cast(output->raw_mutable_data(dataType)); + + // Succeed if size of output is zero, which can happen for empty batch which + // would have data dimension size of 0. + // This *must* be done AFTER output->raw_mutable_data() above as that has + // important allocation side effect that we must see. + if (output->size() == 0) { + return true; + } + + const Index* idxs = indices.template data(); + const float* src_base = static_cast(data.raw_data()); + + const int outer_dims_product = data.size_to_dim(axis); + const int block_size = data.size_from_dim(axis + 1); + + const int src_indexing_axis_dim = data.size(axis); + // Treat indices as a single block even if they have multiple dimensions. + // The "gathered batch" is a cumulative result combining indexed blocks. + const int N = indices.size(); + auto gathered_batch_size = N * block_size; + const auto src_batch_size = data.size_from_dim(axis); + + // Only run kernel if input is not empty. + if (N > 0) { + BatchGatherKernel<<< + std::min(outer_dims_product, CAFFE_MAXIMUM_NUM_BLOCKS), + std::min(N * block_size, CAFFE_CUDA_NUM_THREADS), + 0, + op->getContext()->cuda_stream()>>>( + src_base, + out, + idxs, + outer_dims_product, + N, + src_batch_size, + gathered_batch_size, + block_size, + src_indexing_axis_dim, + wrap_indices); + } + return true; +} + +} // namespace gather_helper +} // namespace caffe2 diff --git a/caffe2/operators/gather_op.h b/caffe2/operators/gather_op.h index cc25505..4afc5c5 100644 --- a/caffe2/operators/gather_op.h +++ b/caffe2/operators/gather_op.h @@ -6,11 +6,167 @@ namespace caffe2 { +// This maintains index-mapping functions shared by Gather and BatchGather ops. 
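The kernel above assigns one output element per flat position s and decodes s into (outer batch i, gathered index j, in-block offset k); a small Python sketch of that index arithmetic, for illustration only:

import numpy as np

def gather_flat_ref(src, indices, outer_dims_product, block_size, axis_dim):
    # Mirrors BatchGatherKernel's indexing: for each flat position s,
    #   k = s % block_size            (offset inside one contiguous block)
    #   j = (s // block_size) % N     (which index is being gathered)
    #   i = s // block_size // N      (which outer batch)
    N = len(indices)
    data_batch_size = axis_dim * block_size
    out = np.empty(outer_dims_product * N * block_size, dtype=src.dtype)
    for s in range(out.size):
        k = s % block_size
        j = (s // block_size) % N
        i = s // block_size // N
        idx = indices[j]
        out[s] = src[i * data_batch_size + idx * block_size + k]
    return out

data = np.arange(2 * 4 * 3, dtype=np.float32)   # (outer=2, axis_dim=4, block=3), flattened
ref = gather_flat_ref(data, [1, 3], 2, 3, 4)
expected = data.reshape(2, 4, 3).take([1, 3], axis=1).reshape(-1)
assert np.array_equal(ref, expected)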
+namespace gather_helper { + +// New shape is concatenation: +// [data dims before axis] + [indices dims] + [data dims after axis] +template +static vector calc_output_shape_vector( + const DataDimsVec& data_dims, + const IndexDimsVec& indices_dims, + int axis) { + vector shape; + // If the dimension we are indexing is empty, just use data_dims as shape. + // This replicates behavior in (https://github.com/pytorch/pytorch/pull/13781) + // needed to allow workflows with empty batch to succeed. + if (data_dims[axis] == 0) { + shape.insert(shape.end(), data_dims.begin(), data_dims.end()); + } else { + shape.insert(shape.end(), data_dims.begin(), data_dims.begin() + axis); + shape.insert(shape.end(), indices_dims.begin(), indices_dims.end()); + shape.insert(shape.end(), data_dims.begin() + axis + 1, data_dims.end()); + } + return shape; +} + +// Check that indices fall within dimension array size with CAFFE_ENFORCE. +template +static void check_indexarray_range( + const IndexType* indices, + int64_t n, + IndexType indexing_axis_dim, + bool wrap_indices) { + // + for (auto i = 0; i < n; ++i) { + auto idx = indices[i]; + if (wrap_indices && idx < 0) { + idx = idx + indexing_axis_dim; + } + CAFFE_ENFORCE( + 0 <= idx && idx < indexing_axis_dim, + "INDICES element is out of DATA bounds, id=", + idx, + " axis_dim=", + indexing_axis_dim); + } +} + +// Actual gather implementation - resizes output and copies indexed data. +template +static bool gather_impl( + Operator* op, + int dataIdx, + int indicesIdx, + int outputIdx, + int axis, + bool wrap_indices) { + // If we endup using it on GPU doing O(N) memcpy is probably not best :) + // TODO: implement prefetching if it starts mattering (TF does it) + + const Tensor& data = op->Input(dataIdx); + const Tensor& indices = op->Input(indicesIdx); + const TypeMeta dataType = data.dtype(); + size_t item_bytesize = dataType.itemsize(); + + // ONNX allows negative axis to index from the back, valid range: [-r, r]. + if (axis < 0) { + axis = data.dim() + axis; + } + CAFFE_ENFORCE_GE(data.dim(), axis + 1, "DATA should be at least [axis+1]-D"); + CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative"); + CAFFE_ENFORCE_LT(axis, data.dim(), "Axis out of range"); + + // New shape: + // [data dims before axis] + [indices dims] + [data dims after axis] + vector shape = + calc_output_shape_vector(data.dims(), indices.dims(), axis); + Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType)); + auto out = static_cast(output->raw_mutable_data(dataType)); + + // Succeed if size of output is zero, which can happen for empty batch which + // would have data dimension size of 0. + // This *must* be done AFTER output->raw_mutable_data() above as that has + // important allocation side effect that we must see. + if (output->size() == 0) { + return true; + } + + const Index* idxs = indices.template data(); + auto src_base = static_cast(data.raw_data()); + + auto outer_dims_product = data.size_to_dim(axis); + auto block_size = data.size_from_dim(axis + 1); + auto block_bytesize = block_size * item_bytesize; + + auto src_indexing_axis_dim = data.size(axis); + auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize; + // Treat indices as a single block even if they have multiple dimensions. + // The "gathered batch" is a cumulative result combining indexed blocks. 
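The shape and bounds rules implemented by calc_output_shape_vector and check_indexarray_range above boil down to a few lines; a Python sketch, assuming a normalized non-negative axis:

def calc_output_shape(data_dims, indices_dims, axis):
    # [data dims before axis] + [indices dims] + [data dims after axis];
    # if the indexed dimension is empty, keep data_dims (empty-batch case).
    if data_dims[axis] == 0:
        return list(data_dims)
    return list(data_dims[:axis]) + list(indices_dims) + list(data_dims[axis + 1:])

def check_index_range(indices, axis_dim, wrap_indices):
    for idx in indices:
        if wrap_indices and idx < 0:
            idx += axis_dim          # e.g. -1 -> axis_dim - 1, like Python indexing
        assert 0 <= idx < axis_dim, "INDICES element is out of DATA bounds"

# rank q + (r - 1): data (2, 3, 5), indices (4, 2), axis=2 -> (2, 3, 4, 2)
assert calc_output_shape((2, 3, 5), (4, 2), 2) == [2, 3, 4, 2]
check_index_range([0, -1, 4], axis_dim=5, wrap_indices=True)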
+ auto N = indices.size(); + auto gathered_batch_bytesize = N * block_size * item_bytesize; + + check_indexarray_range(idxs, N, src_indexing_axis_dim, wrap_indices); + + // Special-case single-float copy for efficiency + if (data.template IsType() && block_size == 1) { + for (auto batch = 0; batch < outer_dims_product; ++batch) { + const float* src_floats = + (const float*)(src_base + batch * src_batch_bytesize); + float* dst_floats = (float*)(out + batch * gathered_batch_bytesize); + + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (wrap_indices && idx < 0) { + idx = idx + src_indexing_axis_dim; + } + dst_floats[i] = src_floats[idx]; + } + } + } else { + // outer_dims_product specifies how many times we repeat inner dimensions, + // so we just iterate over it to cover all outer dimensions. + for (auto batch = 0; batch < outer_dims_product; ++batch) { + for (auto i = 0; i < N; ++i) { + auto idx = idxs[i]; + if (wrap_indices && idx < 0) { + idx = idx + src_indexing_axis_dim; + } + + auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize; + auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize; + op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst); + } + } + } + return true; +} + +} // namespace gather_helper + template class GatherOp : public Operator { public: USE_OPERATOR_CONTEXT_FUNCTIONS; - USE_SIMPLE_CTOR_DTOR(GatherOp); + + GatherOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + OP_SINGLE_ARG(int, "axis", axis_, 0) { + // TBD: We may want to fix the old index wrap behaviour once we have + // operator versioning, to only apply it when needed as otherwise its likely + // an error. + // Right now, we apply index wrapping by default only to axis == 0, + // since we have ONNX conversion code that uses it. For other ops it + // needs to be speified explicitly with argument or you don't get it. + if (OperatorBase::HasArgument("wrap_indices")) { + wrap_indices_ = Operator::template GetSingleArgument( + "wrap_indices", (false)); + } else { + wrap_indices_ = (axis_ == 0) ? 
true : false; + } + } + + virtual ~GatherOp() noexcept {} bool RunOnDevice() override { return DispatchHelper>::call( @@ -19,50 +175,16 @@ class GatherOp : public Operator { template bool DoRunWithType() { - // If we endup using it on GPU doing O(N) memcpy is probably not best :) - // TODO: implement prefetching if it starts mattering (TF does it) - auto& data = Input(DATA); - auto& indices = Input(INDICES); - auto* output = Output(0); - - CAFFE_ENFORCE_GE(data.dim(), 1, "DATA should be at least 1-D"); - auto shape = data.sizes().vec(); - if (data.size(0) > 0) { - shape = indices.sizes().vec(); - shape.insert(shape.end(), data.sizes().begin() + 1, data.sizes().end()); - } - output->Resize(shape); - - int block_size = data.size_from_dim(1); - auto block_bytesize = data.size_from_dim(1) * data.dtype().itemsize(); - int N = indices.numel(); - - auto src_base = static_cast(data.raw_data()); - const Index* idxs = indices.template data(); - auto out = static_cast(output->raw_mutable_data(data.dtype())); - - if (output->numel() == 0) { - return true; - } - for (int i = 0; i < N; ++i) { - auto idx = idxs[i]; - if (idx < 0) { - idx = idx + data.size(0); - } - CAFFE_ENFORCE( - 0 <= idx && idx < data.size(0), - "INDICES element is out of DATA bounds, id=", - idx, - " data_dim=", - data.size(0)); - auto src = src_base + idx * block_bytesize; - context_.template CopyItems( - data.dtype(), block_size, src, out + block_bytesize * i); - } - return true; + return gather_helper::gather_impl( + this, DATA, INDICES, 0, axis_, wrap_indices_); } INPUT_TAGS(DATA, INDICES); + + protected: + int axis_; + bool wrap_indices_; }; + } // namespace caffe2 #endif // GATHER_OP_H_ diff --git a/caffe2/python/operator_test/gather_ops_test.py b/caffe2/python/operator_test/gather_ops_test.py index 4da2186..6c9ca91 100644 --- a/caffe2/python/operator_test/gather_ops_test.py +++ b/caffe2/python/operator_test/gather_ops_test.py @@ -11,9 +11,32 @@ import caffe2.python.serialized_test.serialized_test_util as serial import hypothesis.strategies as st import hypothesis.extra.numpy as hnp +# Basic implementation of gather for axis == 0, shich is lookup of indices +# in the outer dimention. Keeping it for reference here, although is similar +# to more general funciton below. +def ref_gather_axis0(): + def inner(data, ind): + if ind.size == 0 or data.shape[0] == 0: + return [np.zeros((0, 10, 20)).astype(np.float32)] + output = [data[i] for i in ind] + return [output] + return inner + +# Returns axis-based lookup. We just use numpy take() which handles different +# axis values as we want. 
+def ref_gather(axis): + def inner(data, ind): + if ind.size == 0 or data.shape[axis] == 0: + shape = list(data.shape) + shape[0] = 0 + return [np.zeros(tuple(shape)).astype(np.float32)] + # np.take() does axis lookup same as gather + output = data.take(ind, axis).astype(np.float32) + return [output] + return inner class TestGatherOps(serial.SerializedTestCase): - @serial.given(rows_num=st.integers(0, 10000), + @given(rows_num=st.integers(0, 10000), index_num=st.integers(0, 5000), **hu.gcs) def test_gather_ops(self, rows_num, index_num, gc, dc): @@ -28,22 +51,37 @@ class TestGatherOps(serial.SerializedTestCase): ['data', 'ind'], ['output']) - def ref_gather(data, ind): - if ind.size == 0 or rows_num == 0: - return [np.zeros((0, 10, 20)).astype(np.float32)] - - output = [data[i] for i in ind] - return [output] + self.assertReferenceChecks(gc, op, [data, ind], ref_gather_axis0()) + self.assertDeviceChecks(dc, op, [data, ind], [0]) + return + + # Test axis == 2, this keeps outer dimension but will replace data + # within axis by lookup of index array (repeated for each outer entry) + @given(batch_num=st.integers(1, 4000), + rows_num=st.integers(1, 6), + index_num=st.integers(1, 20), + **hu.gcs) + def test_gather_ops_axis2(self, batch_num, rows_num, index_num, gc, dc): + data = np.random.random((batch_num, rows_num, 5)).astype(np.float32) + ind = np.random.randint(5, size=(index_num, )).astype('int32') + op = core.CreateOperator( + 'Gather', + ['data', 'ind'], + ['output'], + axis=2) - self.assertReferenceChecks(gc, op, [data, ind], ref_gather) + self.assertReferenceChecks(gc, op, [data, ind], ref_gather(axis=2)) + self.assertDeviceChecks(dc, op, [data, ind], [0]) + return +# Generates data arrays of max dims 10x100x2 and indexing array up to rows_num @st.composite def _inputs(draw): - rows_num = draw(st.integers(1, 100)) - index_num = draw(st.integers(1, 10)) batch_size = draw(st.integers(2, 10)) + rows_num = draw(st.integers(1, 100)) block_size = draw(st.integers(1, 2)) + index_num = draw(st.integers(1, 10)) return ( draw(hnp.arrays( np.float32, @@ -57,9 +95,8 @@ def _inputs(draw): )), ) - -class TestBatchGatherOps(serial.SerializedTestCase): - @serial.given(inputs=_inputs(), +class TestBatchGatherOps(hu.HypothesisTestCase): + @given(inputs=_inputs(), **hu.gcs) def test_batch_gather_ops(self, inputs, gc, dc): data, ind = inputs @@ -67,14 +104,7 @@ class TestBatchGatherOps(serial.SerializedTestCase): 'BatchGather', ['data', 'ind'], ['output']) - - def ref_batch_gather(data, ind): - output = [] - for b in range(data.shape[0]): - output.append([data[b][i] for i in ind]) - return [output] - - self.assertReferenceChecks(gc, op, [data, ind], ref_batch_gather) + self.assertReferenceChecks(gc, op, [data, ind], ref_gather(axis=1)) self.assertGradientChecks(gc, op, [data, ind], 0, [0]) diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py index 5096974..730b53b 100644 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py @@ -2229,6 +2229,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([ ("/GpuDefs", ("/hip/GpuDefs", API_CAFFE2)), ("/GpuScanUtils", ("/hip/GpuScanUtils", API_CAFFE2)), ("/GpuBitonicSort", ("/hip/GpuBitonicSort", API_CAFFE2)), + ("/gather_op.cuh", ("/hip/gather_op.cuh", API_CAFFE2)), ("caffe2/core/common_cudnn.h", ("caffe2/core/hip/common_miopen.h", API_CAFFE2)), ("REGISTER_CUDA_OPERATOR" , ("REGISTER_HIP_OPERATOR", API_CAFFE2)), ("CUDA_1D_KERNEL_LOOP" , 
("HIP_1D_KERNEL_LOOP", API_CAFFE2)), -- 2.7.4