Speed-up "advanced" indexing operations (#13420)
authorSam Gross <sgross@fb.com>
Tue, 27 Nov 2018 23:18:39 +0000 (15:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 27 Nov 2018 23:23:59 +0000 (15:23 -0800)
Summary:
This speeds-up "advanced" indexing (indexing a tensor by a tensor)
on CPU and GPU. There's still a bunch of work to do, including
speeding up indexing by a byte (boolean) mask and speeding up the derivative
calculation for advanced indexing.

Here's some speed comparisons to indexing on master using a little [benchmark script](https://gist.github.com/colesbury/c369db72aad594e5e032c8fda557d909) with 16 OpenMP threads and on a P100. The test cases are listed as (input shape -> output shape).

| Test case             | CPU (old vs. new)   | CUDA (old vs. new)     |
|-----------------------|---------------------|------------------------|
| 1024x1024 -> 512x1024 | 225 us vs. **57 us**  | 297 us vs. **47 us** |
| 1024x1024 -> 1024x512 | 208 us vs. **153 us** | 335 us vs. **54 us** |
| 50x50 -> 20000x50     | 617 us vs. **77 us**  | 239 us vs. **54 us** |
| 50x50 -> 50x20000     | 575 us vs. **236 us** | 262 us vs. **58 us** |
| 2x5x10 -> 10          | 65 us  vs. **18 us**  | 612 us vs. **93 us** |

See #11647
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13420

Reviewed By: soumith

Differential Revision: D13088936

Pulled By: colesbury

fbshipit-source-id: 0a5c2ee9aa54e15f96d06692d1694c3b24b924e2

21 files changed:
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/cuda/detail/OffsetCalculator.cuh
aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/Indexing.h [new file with mode: 0644]
aten/src/ATen/native/TensorIterator.cpp
aten/src/ATen/native/TensorIterator.h
aten/src/ATen/native/cpu/IndexKernel.cpp [new file with mode: 0644]
aten/src/ATen/native/cuda/IndexKernel.cu [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
test/expect/TestScript.test_index_put_trace_with_view.expect
test/expect/TestScript.test_index_put_trace_without_view.expect
test/run_test.py
test/test_cuda.py
test/test_indexing.py
test/test_indexing_cuda.py [new file with mode: 0644]
tools/autograd/derivatives.yaml
torch/csrc/autograd/VariableTypeManual.cpp
torch/csrc/autograd/VariableTypeUtils.h
torch/onnx/symbolic.py

index c48afa0..16c2699 100644 (file)
@@ -356,8 +356,8 @@ public:
   Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) const;
   Tensor index(TensorList indices) const;
   Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
-  Tensor index_put(TensorList indices, const Tensor & values) const;
-  Tensor & index_put_(TensorList indices, const Tensor & values);
+  Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
+  Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
   Tensor inverse() const;
   Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
   bool is_distributed() const;
index 114f067..270b9c4 100644 (file)
@@ -318,11 +318,11 @@ inline Tensor Tensor::index(TensorList indices) const {
 inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
     return type().index_copy_(*this, dim, index, source);
 }
-inline Tensor Tensor::index_put(TensorList indices, const Tensor & values) const {
-    return type().index_put(*this, indices, values);
+inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
+    return type().index_put(*this, indices, values, accumulate);
 }
-inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values) {
-    return type().index_put_(*this, indices, values);
+inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
+    return type().index_put_(*this, indices, values, accumulate);
 }
 inline Tensor Tensor::inverse() const {
     return type().inverse(*this);
index 2c7d22e..a937b05 100644 (file)
@@ -264,8 +264,8 @@ struct CAFFE2_API Type {
   virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntList signal_sizes) const = 0;
   virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
   virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
-  virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values) const = 0;
-  virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values) const = 0;
+  virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
+  virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
   virtual Tensor inverse(const Tensor & self) const = 0;
   virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
   virtual bool is_distributed(const Tensor & self) const = 0;
index 207bbb7..c92bc05 100644 (file)
@@ -17,6 +17,9 @@ struct OffsetCalculator {
   using offset_type = at::cuda::Array<uint32_t, NARGS>;
 
   OffsetCalculator(int dims, const int64_t* sizes, const int64_t* const* strides) : dims(dims) {
+    if (dims > MAX_DIMS) {
+      throw std::runtime_error("tensor has too many (>25) dims");
+    }
     for (int i = 0; i < MAX_DIMS; ++i) {
       if (i < dims) {
         sizes_[i] = IntDivider<uint32_t>(sizes[i]);
index bde283d..163db86 100644 (file)
@@ -3,7 +3,7 @@
 // This corresponds to "advanced indexing" in NumPy. The two operations are:
 //
 //  index(Tensor self, indices) -> Tensor
-//  index_put_(Tensor self, indices, value)
+//  index_put_(Tensor self, indices, value, accumulate=false)
 //
 // The index is a TensorList containg kLong or kByte tensors or nulls. Byte
 // tensors (boolean masks) are expanded to long tensors via nonzero(). Null
 // Note 2: The behavior is more complicated when the index tensors are not all
 // adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
 // tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
+//
+// The code contains two implementations of indexing. The more efficient
+// implementation treats indexing like an elementwise operation over the
+// tensors `result`, `x`, `ind_1`, `ind_2`, etc. This implementation does
+// not work for index_put_ with accumulate=True. The other implementation
+// combines the indexed tensors into a single linear index that is used
+// with Tensor.put_. This is used for index_put_ with accumulate=True.
+//
+// The more efficient implementation takes the following steps for the
+// above operation:
+//
+// 1) Broadcast ind_1, ind_2, ind_3 together to a common shape
+// 2) Record x.stride(i) for each indexed dimension `i`
+// 3) Replace the indexed subspace of `x` with the shape of the corresponding
+//    subspace of `result` but with stride 0
+// 4) Add dimensions of size 1 to the index tensors (ind_1, ind_2, etc.) so
+//    that their shape is compatible with the result shape
+//
+// The CPU or CUDA kernel then computes element-wise over the broadcasted
+// and restrided result, x, ind_1,  ind_2, etc.:
+//
+//   result[...] = *(&x[...] +
+//                   ind_1[...] * x.stride(1) +
+//                   ind_2[...] * x.stride(2) +
+//                   ...)
+//
+// where & and * represent the C-style address-of and indirection operations.
 
+#include <ATen/native/Indexing.h>
 
-#include "ATen/ATen.h"
-#include "ATen/NativeFunctions.h"
-#include "ATen/ExpandUtils.h"
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/ExpandUtils.h>
+#include <ATen/native/TensorIterator.h>
 
 #include <algorithm>
 #include <functional>
@@ -33,6 +62,9 @@
 
 namespace at { namespace native {
 
+DEFINE_DISPATCH(index_stub);
+DEFINE_DISPATCH(index_put_stub);
+
 [[noreturn]]
 static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
   std::stringstream ss;
@@ -226,34 +258,188 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
   return std::make_tuple(self, linearIndex);
 }
 
-Tensor index(const Tensor & self, TensorList indices) {
-  AT_CHECK(indices.size() <= (size_t)self.dim(),
-           "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
+static bool all_strides_match(TensorList tensors) {
+  AT_ASSERT(tensors.size() >= 1);
+  auto strides = tensors[0].strides();
+  for (auto& tensor : tensors.slice(1)) {
+    if (!strides.equals(tensor.strides())) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static std::string shapes_as_str(TensorList tensors) {
+  std::ostringstream os;
+  bool first = true;
+  for (auto& tensor : tensors) {
+    if (tensor.defined()) {
+      if (!first) {
+        os << ", ";
+      }
+      os << tensor.sizes();
+      first = false;
+    }
+  }
+  return os.str();
+}
+
+struct AdvancedIndex {
+  AdvancedIndex(const Tensor& src, TensorList indices);
+
+  Tensor src;
+  std::vector<Tensor> indices;
+  DimVector indexed_sizes;
+  DimVector indexed_strides;
+  int64_t dims_before;
+  int64_t dims_after;
+};
+
+// Replace indexed dimensions in src with stride 0 and the size of the result tensor.
+// The offset in these dimensions is computed by the kernel using the index tensor's
+// values and the stride of src. The new shape is not meaningful. It's used to make
+// the shape compatible with the result tensor.
+static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
+                           IntList replacement_shape) {
+  auto shape = DimVector(src.sizes());
+  auto strides = DimVector(src.strides());
+  int end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  strides.erase(strides.begin() + dims_before, strides.begin() + end);
+  shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
+  strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);
+  return src.as_strided(shape, strides);
+}
+
+// Add dimensions of size 1 to an index tensor so that it's can be broadcast to the result
+// shape and iterated over element-wise like the result tensor and the restrided src.
+static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
+  auto orig_shape = index.sizes();
+  auto shape = DimVector();
+  shape.append(dims_before, 1);
+  shape.append(orig_shape.begin(), orig_shape.end());
+  shape.append(dims_after, 1);
+  return index.reshape(shape);
+}
+
+AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
+{
+  int64_t element_size_bytes = src.type().elementSizeInBytes();
+  int dims_before = 0, dims_after = 0, dims_indexed = 0;
+  IntList replacement_shape;
+  for (size_t dim = 0; dim < indices_list.size(); dim++) {
+    if (!indices_list[dim].defined()) {
+      if (dims_indexed == 0) {
+        dims_before++;
+      } else {
+        dims_after++;
+      }
+    } else {
+      dims_indexed++;
+      replacement_shape = indices_list[dim].sizes();
+      indexed_sizes.push_back(src.size(dim));
+      indexed_strides.push_back(src.stride(dim) * element_size_bytes);
+    }
+  }
+
+  this->dims_before = dims_before;
+  this->dims_after = dims_after;
+  this->src = restride_src(src, dims_before, dims_indexed, replacement_shape);
+
+  for (auto& index : indices_list) {
+    if (index.defined()) {
+      indices.push_back(reshape_indexer(index, dims_before, dims_after));
+    }
+  }
 
-  Tensor src, linearIndex;
-  std::tie(src, linearIndex) = makeLinearIndex(self, indices);
-  return src.take(linearIndex);
+  // For CUDA tensors, force all index tensors to have the same striding to
+  // simplify the CUDA kernel.
+  if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) {
+    if (!all_strides_match(indices)) {
+      for (size_t i = 0; i < indices.size(); i++) {
+        indices[i] = indices[i].contiguous();
+      }
+    }
+  }
 }
 
-Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value) {
+static AdvancedIndex make_info(Tensor self, TensorList orig) {
+  checkIndexTensorTypes(orig);
+  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
+  auto indices = expandByteTensors(self, orig);
+  // next broadcast all index tensors together
+  try {
+    indices = expand_outplace(indices);
+  } catch (std::exception& e) {
+    AT_ERROR("shape mismatch: indexing tensors could not be broadcast together"
+             " with shapes ", shapes_as_str(indices));
+  }
+  // add missing null Tensors so that it matches self.dim()
+  while (indices.size() < (size_t)self.dim()) {
+    indices.emplace_back();
+  }
+  // if the non-null indices are not all adjacent, transpose self and indices
+  // together so that they're adjacent at the front
+  if (!hasContiguousSubspace(indices)) {
+    std::tie(self, indices) = transposeToFront(self, indices);
+  }
+  return AdvancedIndex(self, indices);
+}
+
+static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
+  auto builder = TensorIterator::Builder();
+  builder.dont_compute_common_dtype();
+  builder.add_output(Tensor(), &info.src.type());
+  builder.add_input(info.src);
+  for (auto& index : info.indices) {
+    builder.add_input(index);
+  }
+  return builder.build();
+}
+
+static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
+  if (!is_expandable_to(value.sizes(), info.src.sizes())) {
+    AT_ERROR("shape mismatch: value tensor of shape ", value.sizes(),
+             " cannot be broadcast to indexing result of shape ", info.src.sizes());
+  }
+  auto builder = TensorIterator::Builder();
+  builder.dont_compute_common_dtype();
+  builder.dont_resize_outputs();
+  builder.add_output(info.src);
+  builder.add_input(value, &info.src.type());
+  for (auto& index : info.indices) {
+    builder.add_input(index);
+  }
+  return builder.build();
+}
+
+Tensor index(const Tensor & self, TensorList indices) {
   AT_CHECK(indices.size() <= (size_t)self.dim(),
            "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
 
-  Tensor src, linearIndex, expandedValue;
-  std::tie(src, linearIndex) = makeLinearIndex(self, indices);
-  std::tie(expandedValue) = expand_inplace(linearIndex, value);
-  Tensor dst = src.clone();
-  return dst.put_(linearIndex, expandedValue);
+  auto info = make_info(self, indices);
+  auto iter = make_index_iterator(info);
+  index_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides);
+  return iter->output();
 }
 
-Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) {
+Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
+  return self.clone().index_put_(indices, value, accumulate);
+}
+
+Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
   AT_CHECK(indices.size() <= (size_t)self.dim(),
            "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
-
-  Tensor src, linearIndex, expandedValue;
-  std::tie(src, linearIndex) = makeLinearIndex(self, indices);
-  std::tie(expandedValue) = expand_inplace(linearIndex, value);
-  return src.put_(linearIndex, expandedValue);
+  if (accumulate && self.type().device_type() == kCUDA) {
+    Tensor src, linearIndex, expandedValue;
+    std::tie(src, linearIndex) = makeLinearIndex(self, indices);
+    std::tie(expandedValue) = expand_inplace(linearIndex, value);
+    return src.put_(linearIndex, expandedValue, true);
+  }
+  auto info = make_info(self, indices);
+  auto iter = make_index_put_iterator(info, value);
+  index_put_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides, accumulate);
+  return self;
 }
 
 Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
diff --git a/aten/src/ATen/native/Indexing.h b/aten/src/ATen/native/Indexing.h
new file mode 100644 (file)
index 0000000..bdd5ed1
--- /dev/null
@@ -0,0 +1,20 @@
+#pragma once
+
+// Indexing tensors by by tensors
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+  struct TensorIterator;
+}
+
+namespace at { namespace native {
+
+using index_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides);
+using index_put_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides, bool accumulate);
+
+DECLARE_DISPATCH(index_fn, index_stub);
+DECLARE_DISPATCH(index_put_fn, index_put_stub);
+
+}} // namespace at::native
index 03d1072..a94d6df 100644 (file)
@@ -97,21 +97,47 @@ compute_result_type(at::ArrayRef<OperandInfo> operands, const F& predicate) {
   return std::make_tuple(result_type, backend);
 }
 
-static bool needs_cast(const Tensor& tensor, const Type& dst_type) {
-  if (!tensor.defined() || dst_type == tensor.type()) {
-    return false;
+void TensorIterator::compute_types() {
+  bool missing_dtypes = false;
+  for (auto& op : operands_) {
+    if (!op.tensor.defined() && !op.type) {
+      missing_dtypes = true;
+    }
   }
-  if (dst_type.device_type() == DeviceType::CUDA &&
-      tensor.type().device_type() == DeviceType::CPU &&
-      tensor.dim() == 0) {
-    // zero-dim CPU tensors used in CUDA operations can be used directly without
-    // casting
-    return false;
+
+  if (missing_dtypes || compute_common_dtype_) {
+    auto& type = compute_common_type();
+    for (auto& op : operands_) {
+      if (!op.type) {
+        op.type = &type;
+      } else if (compute_common_dtype_ && op.type != &type) {
+        if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
+            type.device_type() == kCUDA && op.tensor.type().device_type() == kCPU) {
+          // don't cast CPU scalars in CUDA ops that directly support them
+          op.type = &op.tensor.type();
+        } else {
+          op.type = &type;
+        }
+      }
+    }
+  }
+
+  for (auto& op : operands_) {
+    if (op.tensor.defined() && op.tensor.type() != *op.type) {
+      if (op.is_output) {
+        AT_ERROR("output with type ", op.tensor.type().toString(),
+                 " doesn't match the desired type ", type().toString());
+      } else if (op.tensor.dim() == 0) {
+        op.tensor = op.tensor.toType(*op.type);
+      } else {
+        AT_ERROR("expected type ", type().toString(), " but got ",
+            op.tensor.type().toString());
+      }
+    }
   }
-  return true;
 }
 
-void TensorIterator::compute_common_type() {
+Type& TensorIterator::compute_common_type() {
   // See [Result type computation] in TensorIterator.h
   auto result_type = ScalarType::Undefined;
   auto backend = Backend::Undefined;
@@ -132,18 +158,7 @@ void TensorIterator::compute_common_type() {
   AT_ASSERT(result_type != ScalarType::Undefined);
   AT_ASSERT(backend != Backend::Undefined);
 
-  auto& type = at::globalContext().getNonVariableType(backend, result_type);
-
-  for (auto& op : operands_) {
-    if (!op.type) {
-      op.type = &type;
-      op.needs_cast = needs_cast(op.tensor, type);
-      if (op.needs_cast && op.tensor.dim() == 0 && !op.is_output) {
-        op.tensor = op.tensor.toType(type);
-        op.needs_cast = false;
-      }
-    }
-  }
+  return at::globalContext().getNonVariableType(backend, result_type);
 }
 
 DimVector TensorIterator::compatible_stride(int element_size) const {
@@ -171,6 +186,7 @@ void TensorIterator::allocate_outputs() {
   for (int i = 0; i < num_outputs_; i++) {
     auto& op = operands_[i];
     if (!op.tensor.defined()) {
+      AT_ASSERTM(op.type, "no type for operand", i);
       int element_size = op.type->elementSizeInBytes();
       op.stride_bytes = compatible_stride(element_size);
 
@@ -405,7 +421,7 @@ bool TensorIterator::is_scalar(int arg) const {
 }
 
 bool TensorIterator::is_cpu_scalar(int arg) const {
-  return is_scalar(arg) && operands_[arg].tensor.type().backend() == at::Backend::CPU;
+  return is_scalar(arg) && operands_[arg].tensor.type().device_type() == kCPU;
 }
 
 void* TensorIterator::data_ptr(int arg) const {
@@ -450,6 +466,7 @@ std::unique_ptr<TensorIterator> TensorIterator::binary_op(Tensor& out, const Ten
   builder.add_output(out);
   builder.add_input(a);
   builder.add_input(b);
+  builder.iter_->allow_cpu_scalars_ = true;
   return builder.build();
 }
 
@@ -459,6 +476,7 @@ std::unique_ptr<TensorIterator> TensorIterator::reduce_op(Tensor& out, const Ten
   builder.add_output(out);
   builder.add_input(a);
   builder.iter_->resize_outputs_ = false;
+  builder.iter_->is_reduction_ = true;
   return builder.build();
 }
 
@@ -485,7 +503,7 @@ void TensorIterator::compute_shape() {
     // For now, don't include output tensors that are not also input tensors.
     // This preserves the legacy behavior where torch.add(..., out=dst) resizes
     // the destination tensor.
-    if (op.is_output && !op.is_read_write) continue;
+    if (resize_outputs_ && op.is_output && !op.is_read_write) continue;
 
     auto shape = op.tensor.sizes();
     if (shape_.empty()) {
@@ -501,15 +519,17 @@ void TensorIterator::compute_shape() {
   // outputs.
   for (int i = 0; i < num_outputs_; i++) {
     auto& tensor = operands_[i].tensor;
-    if (resize_outputs_ && tensor.defined() && !tensor.sizes().equals(shape_)) {
-      if (!operands_[i].is_read_write) {
+    if (tensor.defined() && !tensor.sizes().equals(shape_)) {
+      if (resize_outputs_ && !operands_[i].is_read_write) {
         // Preserve legacy resizing behavior of out=... arguments
         // TODO: issue warning
         tensor.resize_(shape_);
         continue;
       }
-      AT_ERROR("output with shape ", tensor.sizes(), " doesn't match the broadcast shape ",
-               shape_);
+      if (!is_reduction_) {
+        AT_ERROR("output with shape ", tensor.sizes(), " doesn't match the broadcast shape ",
+                 shape_);
+      }
     }
   }
 }
@@ -540,15 +560,6 @@ void TensorIterator::compute_strides() {
   }
 }
 
-void TensorIterator::check_type_conversions() {
-  for (auto& op : operands_) {
-    if (op.needs_cast) {
-      AT_ERROR("TensorIterator expected type ", type().toString(), " but got ",
-          op.tensor.type().toString());
-    }
-  }
-}
-
 bool TensorIterator::can_use_32bit_indexing() const {
   int64_t max_value = std::numeric_limits<int32_t>::max();
   if (numel() > max_value) {
@@ -612,7 +623,7 @@ std::unique_ptr<TensorIterator> TensorIterator::Builder::build() {
   // re-order dimensions to improve coalescing
   iter_->reorder_dimensions();
   // compute the result dtype and backend
-  iter_->compute_common_type();
+  iter_->compute_types();
   // allocate the output tensor if it's not provided
   iter_->allocate_outputs();
   // coalesce adjacent dimensions when possible
@@ -623,8 +634,6 @@ std::unique_ptr<TensorIterator> TensorIterator::Builder::build() {
     op.data = op.tensor.data_ptr();
   }
 
-  iter_->check_type_conversions();
-
   return std::move(iter_);
 }
 
index ad103b2..62d9a4b 100644 (file)
@@ -54,7 +54,12 @@ namespace at {
 
 struct CAFFE2_API OperandInfo {
   OperandInfo() {}
-  OperandInfo(const Tensor& t) : tensor(t) {}
+  OperandInfo(const Tensor& t, const Type* type=nullptr)
+    : tensor(t), type(const_cast<Type*>(type)) {
+      if (t.defined() && !type) {
+        this->type = &t.type();
+      }
+  }
 
   /// Stride after broadcasting. The stride is in bytes, not number of elements.
   DimVector stride_bytes;
@@ -74,9 +79,6 @@ struct CAFFE2_API OperandInfo {
   /// iterator is split.
   void* data = nullptr;
 
-  /// True if the kernel needs to handle a cast operation for this operand.
-  bool needs_cast = false;
-
   bool is_output = false;
 
   bool is_read_write = false;
@@ -210,10 +212,10 @@ protected:
   void compute_strides();
   void reorder_dimensions();
   void permute_dimensions(IntList perm);
-  void compute_common_type();
+  void compute_types();
+  Type& compute_common_type();
   void allocate_outputs();
   void coalesce_dimensions();
-  void check_type_conversions();
 
 protected:
   DimVector shape_;
@@ -223,6 +225,9 @@ protected:
   bool has_coalesced_dimensions_ = false;
   bool accumulate_ = false;
   bool resize_outputs_ = true;
+  bool is_reduction_ = false;
+  bool compute_common_dtype_ = true;
+  bool allow_cpu_scalars_ = false;
 };
 
 struct TensorIterator::Builder {
@@ -230,15 +235,21 @@ struct TensorIterator::Builder {
 
   Builder() : iter_(new TensorIterator()) {};
 
-  Builder& add_output(const Tensor& output) {
-    iter_->operands_.emplace_back(output);
+  void add_output(const Tensor& output, const Type* type=nullptr) {
+    iter_->operands_.emplace_back(output, type);
     iter_->num_outputs_++;
-    return *this;
   }
 
-  Builder& add_input(const Tensor& input) {
-    iter_->operands_.emplace_back(input);
-    return *this;
+  void add_input(const Tensor& input, const Type* type=nullptr) {
+    iter_->operands_.emplace_back(input, type);
+  }
+
+  void dont_compute_common_dtype() {
+    iter_->compute_common_dtype_ = false;
+  }
+
+  void dont_resize_outputs() {
+    iter_->resize_outputs_ = false;
   }
 
   std::unique_ptr<TensorIterator> build();
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
new file mode 100644 (file)
index 0000000..525bb9c
--- /dev/null
@@ -0,0 +1,125 @@
+#include <ATen/native/Indexing.h>
+
+#include <cmath>
+#include <iostream>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/Parallel.h>
+#include <ATen/cpu/vec256/vec256.h>
+
+namespace at { namespace native {
+namespace {
+
+using namespace vec256;
+
+struct Indexer {
+  Indexer(int64_t num_indexers, char** indexers, const int64_t* indexer_strides,
+          IntList original_sizes, IntList original_strides)
+    : num_indexers(num_indexers)
+    , indexers(indexers)
+    , indexer_strides(indexer_strides)
+    , original_strides(original_strides.data())
+    , original_sizes(original_sizes.data()) {
+    AT_ASSERT(original_strides.size() == num_indexers);
+    AT_ASSERT(original_sizes.size() == num_indexers);
+  }
+
+  int64_t num_indexers;
+  char** indexers;
+  const int64_t* indexer_strides;
+  const int64_t* original_strides;
+  const int64_t* original_sizes;
+
+  int64_t get(int64_t idx) {
+    int64_t offset = 0;
+    for (int j = 0; j < num_indexers; j++) {
+      int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
+      int64_t size = original_sizes[j];
+      if (value < -size || value >= size) {
+        AT_ERROR("index ", value, " is out of bounds for dim with size ", size);
+      }
+      if (value < 0) {
+        value += size;
+      }
+      offset += value * original_strides[j];
+    }
+    return offset;
+  }
+};
+
+static bool is_constant_index(int ntensor, const int64_t* strides) {
+  AT_ASSERT(ntensor >= 3);
+  for (int arg = 2; arg < ntensor; arg++) {
+    if (strides[arg] != 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+template <typename scalar_t, typename func_t>
+void cpu_index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride,
+                      const func_t& f, bool serial_execution=false)
+{
+  auto loop = [&](int ntensor, char** data, const int64_t* strides, int64_t n) {
+    auto indexer = Indexer(ntensor - 2, &data[2], &strides[2], index_size, index_stride);
+    char* dst = data[0];
+    char* src = data[1];
+    if (is_constant_index(ntensor, strides)) {
+      // specialization for when every element uses the same index
+      int64_t offset = indexer.get(0);
+      if (strides[0] == sizeof(scalar_t) && strides[1] == sizeof(scalar_t)) {
+        for (int64_t i = 0; i < n; i++) {
+          f(dst + strides[0] * i, src + strides[1] * i, offset);
+        }
+      } else {
+        for (int64_t i = 0; i < n; i++) {
+          f(dst + strides[0] * i, src + strides[1] * i, offset);
+        }
+      }
+    } else {
+      for (int64_t i = 0; i < n; i++) {
+        int64_t offset = indexer.get(i);
+        f(dst + strides[0] * i, src + strides[1] * i, offset);
+      }
+    }
+  };
+  if (serial_execution) {
+    iter.serial_for_each(loop, {0, iter.numel()});
+  } else {
+    iter.for_each(loop);
+  }
+}
+
+void index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride) {
+  AT_DISPATCH_ALL_TYPES(iter.type(0), "index", [&] {
+    cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
+      *(scalar_t*)dst = *(scalar_t*)(src + offset);
+    });
+  });
+}
+
+void index_put_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, bool accumulate) {
+  // NOTE: duplicate indices are only supported if accumulate is true.
+  AT_DISPATCH_ALL_TYPES(iter.type(0), "index_put", [&] {
+    if (accumulate) {
+      // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
+      // this needs to be thread-safe.
+      cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
+        *(scalar_t*)(dst + offset) += *(scalar_t*)src;
+      }, /*serial_execution=*/true);
+    } else {
+      cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
+        *(scalar_t*)(dst + offset) = *(scalar_t*)src;
+      });
+    }
+  });
+}
+
+} // anonymous namespace
+
+
+REGISTER_DISPATCH(index_stub, &index_kernel);
+REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
+
+}} // namespace at::native
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
new file mode 100644 (file)
index 0000000..a393ee4
--- /dev/null
@@ -0,0 +1,102 @@
+#include <ATen/native/Indexing.h>
+
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+#include <ATen/cuda/Array.h>
+
+namespace at { namespace native {
+
+template <int N>
+static OffsetCalculator<N> index_make_offset_calculator(const TensorIterator& iter) {
+  AT_ASSERT(N <= iter.ntensors());
+  std::array<const int64_t*, N> strides;
+  for (int i = 0; i < N; i++) {
+    strides[i] = iter.strides(i).data();
+  }
+  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data());
+}
+
+template <typename func_t>
+void gpu_index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, const func_t& f) {
+  int num_indices = index_size.size();
+  AT_ASSERT(num_indices == index_stride.size());
+  AT_ASSERT(num_indices == iter.ntensors() - 2);
+
+  if (iter.numel() == 0) {
+    return;
+  }
+
+  auto sizes = cuda::Array<int64_t, 25>(0);
+  auto strides = cuda::Array<int64_t, 25>(0);
+  auto index_ptrs = cuda::Array<char*, 25>(nullptr);
+  for (int i = 0; i < num_indices; i++) {
+    sizes[i] = index_size[i];
+    strides[i] = index_stride[i];
+    index_ptrs[i] = (char*)iter.data_ptr(i + 2);
+  }
+
+  char* out_ptr = (char*)iter.data_ptr(0);
+  char* in_ptr = (char*)iter.data_ptr(1);
+
+  auto offset_calc = index_make_offset_calculator<3>(iter);
+  launch_kernel<128, 4>(iter.numel(), [=]__device__(int idx) {
+    auto offsets = offset_calc.get(idx);
+    char* out_data = out_ptr + offsets[0];
+    char* in_data = in_ptr + offsets[1];
+
+    int64_t offset = 0;
+    #pragma unroll
+    for (int i = 0; i < num_indices; i++) {
+      int64_t index = *(int64_t*)(index_ptrs[i] + offsets[2]);
+      assert(index >= -sizes[i] && index < sizes[i] && "index out of bounds");
+      if (index < 0) {
+        index += sizes[i];
+      }
+      offset += index * strides[i];
+    }
+
+    f(out_data, in_data, offset);
+  });
+}
+
+// The kernels are templated on an opaque, self-aligned type of the correct
+// size to avoid redundant kernels for different types of the same size.
+template <int N> struct alignas(N) OpaqueType { char data[N]; };
+
+
+template <typename scalar_t>
+void index_kernel_impl(TensorIterator& iter, IntList index_size, IntList index_stride) {
+  gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
+    *(scalar_t*)out_data = *(scalar_t*)(in_data + offset);
+  });
+}
+
+template <typename scalar_t>
+void index_put_kernel_impl(TensorIterator& iter, IntList index_size, IntList index_stride) {
+  gpu_index_kernel(iter, index_size, index_stride, []C10_DEVICE(char* out_data, char* in_data, int64_t offset) {
+    *(scalar_t*)(out_data + offset) = *(scalar_t*)in_data;
+  });
+}
+
+static void index_kernel(TensorIterator& iter, IntList index_size, IntList index_stride) {
+  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index", [&] {
+    using dtype = OpaqueType<sizeof(scalar_t)>;
+    index_kernel_impl<dtype>(iter, index_size, index_stride);
+  });
+}
+
+
+static void index_put_kernel(TensorIterator& iter, IntList index_size, IntList index_stride, bool accumulate) {
+  AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
+  AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index_put", [&] {
+    using dtype = OpaqueType<sizeof(scalar_t)>;
+    index_put_kernel_impl<dtype>(iter, index_size, index_stride);
+  });
+}
+
+REGISTER_DISPATCH(index_stub, &index_kernel);
+REGISTER_DISPATCH(index_put_stub, &index_put_kernel);
+
+}} // namespace at::native
index a9a4b59..ed1a150 100644 (file)
 - func: index_copy_(Tensor self, int64_t dim, IndexTensor index, Tensor source) -> Tensor
   variants: method
 
-- func: index_put(Tensor self, TensorList indices, Tensor values) -> Tensor
+- func: index_put(Tensor self, TensorList indices, Tensor values, bool accumulate=false) -> Tensor
   variants: function, method
 
-- func: index_put_(Tensor self, TensorList indices, Tensor values) -> Tensor
+- func: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate=false) -> Tensor
   variants: function, method
 
 - func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) -> Tensor
index cfbec5e..6a747aa 100644 (file)
@@ -7,6 +7,7 @@ graph(%target : Double(100)
   %6 : bool = prim::Constant[value=0]()
   %indices : Long(4) = aten::_cast_Long(%indices.1, %6)
   %8 : Dynamic[] = prim::ListConstruct(%indices)
-  %9 : Double(100) = aten::index_put_(%target, %8, %5)
-  return (%9);
+  %9 : bool = prim::Constant[value=0]()
+  %10 : Double(100) = aten::index_put_(%target, %8, %5, %9)
+  return (%10);
 }
index f168a3f..8671b2c 100644 (file)
@@ -4,6 +4,7 @@ graph(%target : Double(100)
   %3 : bool = prim::Constant[value=0]()
   %indices : Long(4) = aten::_cast_Long(%indices.1, %3)
   %5 : Dynamic[] = prim::ListConstruct(%indices)
-  %6 : Double(100) = aten::index_put_(%target, %5, %rhs)
-  return (%6);
+  %6 : bool = prim::Constant[value=0]()
+  %7 : Double(100) = aten::index_put_(%target, %5, %rhs, %6)
+  return (%7);
 }
index 4623834..9991411 100644 (file)
@@ -28,6 +28,7 @@ TESTS = [
     'distributions',
     'expecttest',
     'indexing',
+    'indexing_cuda',
     'jit',
     'multiprocessing',
     'multiprocessing_spawn',
index 7003459..c85080f 100644 (file)
@@ -926,7 +926,7 @@ class TestCuda(TestCase):
 
             self.assertEqual(x * y, 4.5)
             self.assertEqual(y * x, 4.5)
-            with self.assertRaisesRegex(RuntimeError, 'expected type'):
+            with self.assertRaisesRegex(RuntimeError, "doesn't match the desired type"):
                 y *= x
             x *= y
             self.assertEqual(x, 4.5)
index 68ca3ff..9f55d59 100644 (file)
@@ -2,6 +2,7 @@ from common_utils import TestCase, run_tests
 import torch
 import warnings
 from torch import tensor
+import unittest
 
 
 class TestIndexing(TestCase):
@@ -448,9 +449,9 @@ class NumpyTests(TestCase):
         def f(a, v):
             a[a > -1] = tensor(v)
 
-        self.assertRaisesRegex(Exception, "expand", f, a, [])
-        self.assertRaisesRegex(Exception, 'expand', f, a, [1, 2, 3])
-        self.assertRaisesRegex(Exception, 'expand', f, a[:1], [1, 2, 3])
+        self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [])
+        self.assertRaisesRegex(Exception, 'shape mismatch', f, a, [1, 2, 3])
+        self.assertRaisesRegex(Exception, 'shape mismatch', f, a[:1], [1, 2, 3])
 
     def test_boolean_indexing_twodim(self):
         # Indexing a 2-dimensional array with
@@ -503,12 +504,14 @@ class NumpyTests(TestCase):
 
     def test_broaderrors_indexing(self):
         a = torch.zeros(5, 5)
-        self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
-        self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
+        self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
+        self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
 
     def test_trivial_fancy_out_of_bounds(self):
         a = torch.zeros(5)
         ind = torch.ones(20, dtype=torch.int64)
+        if a.is_cuda:
+            raise unittest.SkipTest('CUDA asserts instead of raising an exception')
         ind[-1] = 10
         self.assertRaises(RuntimeError, a.__getitem__, ind)
         self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
diff --git a/test/test_indexing_cuda.py b/test/test_indexing_cuda.py
new file mode 100644 (file)
index 0000000..9fccc00
--- /dev/null
@@ -0,0 +1,10 @@
+import torch
+from test_indexing import *
+
+
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+        run_tests()
+    else:
+        print("Skipping test_indexing_cuda.py")
index 5f899ef..05c5a51 100644 (file)
 - name: histc(Tensor self, int64_t bins, Scalar min, Scalar max)
   self: not_implemented("histc")
 
+- name: index(Tensor self, TensorList indices)
+  self: zeros_like(self).index_put_(indices, grad, true)
+  indices: TensorList()
+
 - name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source)
   self: grad
   source: grad.index_select(dim, index)
   self: grad.clone().index_fill_(dim, index, 0)
   value: grad.index_select(dim, index).sum()
 
+- name: index_put_(Tensor self, TensorList indices, Tensor values, bool accumulate)
+  self: grad.clone().index_put_(indices, zeros_like(values), accumulate)
+  values: grad.index(indices)
+
 - name: index_select(Tensor self, int64_t dim, Tensor index)
   self: at::zeros(self.sizes(), grad.options()).index_add_(dim, index, grad)
 
index b6a2c64..d4570dc 100644 (file)
@@ -222,8 +222,7 @@ std::vector<at::Tensor> VariableType::unpack(at::TensorList tl, const char *name
   for (size_t i = 0; i < tl.size(); ++i) {
     const auto &t = tl[i];
     if (!t.defined()) {
-      AT_ERROR("Expected a Tensor of type Variable but found an undefined Tensor at position #", i, " "
-                    "for iterable argument #", pos, " '", name, "'");
+      continue;
     }
     if (!isVariableType(t.type())) {
       AT_ERROR("Expected object of type Variable but found type ", t.type().toString(), " at position #", i, " "
index 4082c79..4a7092f 100644 (file)
@@ -135,6 +135,12 @@ inline void check_no_requires_grad(const Tensor& tensor, const char* name) {
   }
 }
 
+inline void check_no_requires_grad(TensorList tensors, const char* name) {
+  for (auto& tensor : tensors) {
+    check_no_requires_grad(tensor, name);
+  }
+}
+
 // Assumed that saved tensor lists are never inplace outputs
 inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
   return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
index f3963b3..ba79a64 100644 (file)
@@ -811,9 +811,9 @@ def index_select(g, self, dim, index):
     return g.op("Gather", self, index, axis_i=dim)
 
 
-def index_put(g, self, indices_list_value, values):
+def index_put(g, self, indices_list_value, values, accumulate):
     indices_list = _unpack_list(indices_list_value)
-    args = [self] + indices_list + [values]
+    args = [self] + indices_list + [values, accumulate]
     return g.op("ATen", *args, operator_s='index_put')