Port legacy all(*) to ATen (#15540)
author Shen Li <shenli@fb.com>
Wed, 16 Jan 2019 17:02:44 +0000 (09:02 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 16 Jan 2019 17:06:26 +0000 (09:06 -0800)
Summary:
Questions:

1. ~This PR disables `common_dtype` computation [in `TensorIterator.cpp`](https://github.com/mrshenli/pytorch/blob/all/aten/src/ATen/native/TensorIterator.cpp#L489-L491) for `all*` operators. The reason is that [this code](https://github.com/mrshenli/pytorch/blob/all/aten/src/ATen/native/TensorIterator.cpp#L120) otherwise complains about a type mismatch, where `op.tensor` has type `Variable[CPUByteType]` while `op` is `CPUByteType`. I am not sure whether this is the right solution for this problem.~

2. Should I clean up all occurrences of `_th_all` and `_th_all_out` (and `logicalAnd`, `logicalAndAll`)?

3. Do I need to implement derivatives for `all`?
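
For reference, a minimal usage sketch of the operator being ported, mirroring the behavior exercised by the new `test_logical_all` test (illustrative only; this snippet is not part of the PR):

```python
import torch

# The ported all() only supports torch.uint8 (Byte) tensors on the CPU and CUDA backends.
x = torch.ones(2, 3, 4, dtype=torch.uint8)

print(x.all())                       # expected: tensor(1, dtype=torch.uint8)
print(x.all(1, keepdim=True).shape)  # expected: torch.Size([2, 1, 4])

x[-1, -1, -1] = 0                    # flip one element to 0
print(x.all())                       # expected: tensor(0, dtype=torch.uint8)
```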

gchanan

Benchmark:

<img width="590" alt="screen shot 2018-12-26 at 3 24 31 pm" src="https://user-images.githubusercontent.com/16999635/50456505-e9596a00-0922-11e9-844e-00c4b4aad7ca.png">

<img width="587" alt="screen shot 2018-12-26 at 3 26 10 pm" src="https://user-images.githubusercontent.com/16999635/50456509-ef4f4b00-0922-11e9-96bf-0a30c8574fe7.png">

<img width="590" alt="screen shot 2018-12-26 at 3 26 54 pm" src="https://user-images.githubusercontent.com/16999635/50456510-ef4f4b00-0922-11e9-8a63-e47988843cc8.png">

<img width="589" alt="screen shot 2018-12-26 at 3 27 16 pm" src="https://user-images.githubusercontent.com/16999635/50456511-ef4f4b00-0922-11e9-9004-2518aebcdc6e.png">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15540

Differential Revision: D13548938

Pulled By: mrshenli

fbshipit-source-id: 5a2e5eef1047decb4c79906cb9f3332034908c9c

14 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/cpu/vec256/vec256_base.h
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/ReduceOps.cpp
aten/src/ATen/native/ReduceOps.h
aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
aten/src/ATen/native/cuda/ReduceOpsKernel.cu
aten/src/TH/generic/THTensorMath.h
aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/THC/THCTensorMath.h
aten/src/THC/THCTensorMathReduce.cu
test/test_torch.py
tools/autograd/derivatives.yaml

index 30a90c7..fc04feb 100644
       default: "true"
 ]]
 [[
-  name: _th_all
-  types:
-    - Byte
-  variants:
-    - function
-  backends:
-    - CPU
-    - CUDA
-  options:
-    - cname: logicalAndAll
-      return: real
-      arguments:
-        - THTensor* self
-]]
-[[
-  name: _th_all
-  types:
-    - Byte
-  variants: function
-  backends:
-    - CPU
-    - CUDA
-  options:
-    - cname: logicalAnd
-      return: argument 0
-      scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1)
-      arguments:
-        - arg: THTensor* result
-          output: True
-        - THTensor* self
-        - arg: long dim
-          wrap_dim: self
-        - arg: bool keepdim
-          default: "false"
-]]
-[[
   name: _th_any
   types:
     - Byte
index 61399e7..8e2fbd7 100644
@@ -139,7 +139,6 @@ _(aten, _tan) \
 _(aten, _tanh) \
 _(aten, _tanh_backward) \
 _(aten, _tanh_forward) \
-_(aten, _th_all) \
 _(aten, _th_any) \
 _(aten, _th_baddbmm) \
 _(aten, _th_bmm) \
index 5217b90..f5afece 100644
@@ -257,7 +257,7 @@ public:
 #define DEFINE_COMP(binary_pred)                                              \
   Vec256<T> operator binary_pred(const Vec256<T> &other) const {              \
     Vec256<T> vec;                                                            \
-    for (int64_t i = 0; i != size(); i++) {                                     \
+    for (int64_t i = 0; i != size(); i++) {                                   \
       if (values[i] binary_pred other.values[i]) {                            \
         std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T));     \
       } else {                                                                \
@@ -273,6 +273,7 @@ public:
   DEFINE_COMP(>)
   DEFINE_COMP(<)
 #undef DEFINE_COMP
+
 };
 
 template <class T> Vec256<T> inline operator+(const Vec256<T> &a, const Vec256<T> &b) {
@@ -510,8 +511,8 @@ interleave2(const Vec256<T>& a, const Vec256<T>& b) {
 
 template <typename src_T, typename dst_T>
 void convert(const src_T *src, dst_T *dst, int64_t n) {
-#ifndef _MSC_VER  
-# pragma unroll  
+#ifndef _MSC_VER
+# pragma unroll
 #endif
   for (int64_t i = 0; i < n; i++) {
     *dst = static_cast<dst_T>(
index abd3378..e6c2787 100644
@@ -705,10 +705,6 @@ std::tuple<Tensor,Tensor> topk(const Tensor & self, int64_t k, int64_t dim, bool
   return at::legacy::th::_th_topk(self, k, dim, largest, sorted);
 }
 
-Tensor all(const Tensor & self) {
-  return at::legacy::th::_th_all(self);
-}
-
 Tensor any(const Tensor & self) {
   return at::legacy::th::_th_any(self);
 }
index c1c4a75..a2e2acf 100644
@@ -24,6 +24,7 @@ DEFINE_DISPATCH(sum_stub);
 DEFINE_DISPATCH(std_var_stub);
 DEFINE_DISPATCH(prod_stub);
 DEFINE_DISPATCH(mean_stub);
+DEFINE_DISPATCH(and_stub);
 
 static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
@@ -428,6 +429,29 @@ Tensor norm(const Tensor& self, Scalar p) {
   return at::native::_norm(self, p);
 }
 
+inline Tensor & _all(Tensor & result, std::unique_ptr<TensorIterator> & iter) {
+  if (iter->numel() == 0) {
+    result.fill_(1);
+  } else {
+    and_stub(iter->device_type(), *iter);
+  }
+
+  return result;
+}
+
+Tensor all(const Tensor& self) {
+  AT_CHECK(self.type().backend() == Backend::CPU ||
+    self.type().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
+    "backend, got: ", toString(self.type().backend()));
+  AT_CHECK(self.type().scalarType() == at::ScalarType::Byte,
+    "all only supports torch.uint8 dtype");
+
+  Tensor result = at::empty({0}, self.options());
+  auto iter = make_reduction(
+    "all", result, self, {}, false, at::ScalarType::Byte);
+  return _all(result, iter);
+}
+
 Tensor all(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor result = at::empty({0}, self.options());
   return at::native::all_out(result, self, dim, keepdim);
@@ -441,7 +465,9 @@ Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
   if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) {
     return result;
   } else {
-    return at::legacy::th::_th_all_out(result, self, dim, keepdim);
+    auto iter = make_reduction(
+      "all", result, self, dim, keepdim, at::ScalarType::Byte);
+    return _all(result, iter);
   }
 }
 
index 2d8cd4e..5fb6ac6 100644
@@ -15,6 +15,7 @@ using reduce_fn = void(*)(TensorIterator &);
 DECLARE_DISPATCH(reduce_fn, sum_stub);
 DECLARE_DISPATCH(reduce_fn, prod_stub);
 DECLARE_DISPATCH(reduce_fn, mean_stub);
+DECLARE_DISPATCH(reduce_fn, and_stub);
 
 using reduce_std_var_function =
   void (*)(TensorIterator&, bool unbiased, bool take_sqrt);
index 8b365fb..4fca2ce 100644
@@ -54,11 +54,36 @@ static void prod_kernel_impl(TensorIterator& iter) {
   });
 }
 
+static void and_kernel_impl(TensorIterator& iter) {
+  binary_kernel_reduce_vec(
+    iter,
+    [=](uint8_t a, uint8_t b) -> uint8_t { return a && b; },
+    [=](Vec256<uint8_t> a, Vec256<uint8_t> b) {
+      // Adding the implementation here instead of in vec256_base to avoid
+      // return value inconsistency. Other comparison operators in vec256_base
+      // return -1/0 (all bit 1 / all bit 0) as true/false to follow the AVX2
+      // convention. This would be convenient when combined with other
+      // vectorized operations. For example, one can use the logical operation
+      // results as a mask for a bit operation to retrieve/reset multiple
+      // elements in a vector.
+      //
+      // In this method, users would expect, e.g., all(), to return 1/0 as
+      // true/false.
+      Vec256<uint8_t> c = Vec256<uint8_t>();
+      for (int i = 0; i != Vec256<uint8_t>::size(); i++) {
+        c[i] = a[i] && b[i];
+      }
+      return c;
+    },
+    /*ident=*/true);
+}
+
 }  // anonymous namespace
 
 REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);
 REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_impl);
 REGISTER_DISPATCH(mean_stub, &mean_kernel_impl);
+REGISTER_DISPATCH(and_stub, &and_kernel_impl);
 
 }}  // namespace at::native
index 3110458..5856d37 100644
@@ -119,9 +119,17 @@ static void mean_kernel_cuda(TensorIterator& iter) {
   });
 }
 
+void and_kernel_cuda(TensorIterator& iter) {
+  gpu_reduce_kernel<uint8_t, uint8_t>(
+    iter, func_wrapper<uint8_t> ([]GPU_LAMBDA(uint8_t a, uint8_t b) -> uint8_t {
+      return a && b;
+    }), true);
+}
+
 REGISTER_DISPATCH(std_var_stub, &std_var_kernel_cuda);
 REGISTER_DISPATCH(sum_stub, &sum_kernel_cuda);
 REGISTER_DISPATCH(prod_stub, &prod_kernel_cuda);
 REGISTER_DISPATCH(mean_stub, &mean_kernel_cuda);
+REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
 
 }} // namespace at::native
index adabf5b..61520a3 100644
@@ -198,9 +198,7 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
 
 #if defined(TH_REAL_IS_BYTE)
 
-TH_API int THTensor_(logicalAndAll)(THTensor *self);
 TH_API int THTensor_(logicalAnyAll)(THTensor *self);
-TH_API void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim);
 TH_API void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim);
 
 #endif /* TH_REAL_IS_BYTE */
index ac1b323..ab6242d 100644
@@ -1533,26 +1533,6 @@ LAB_IMPLEMENT_BASIC_FUNCTION(abs,abs)
 
 #if defined(TH_REAL_IS_BYTE) /* Byte only part */
 
-int THTensor_(logicalAndAll)(THTensor *tensor)
-{
-  scalar_t prod = 1;
-  int serial_path = 0;
-#ifdef _OPENMP
-  int inOMP = omp_in_parallel();
-  if(inOMP) {
-    serial_path = 1;
-  } else {
-    TH_TENSOR_APPLY_REDUCTION_OMP(scalar_t, tensor, &&:prod, prod = prod && *tensor_data;, UNCERTAIN_TH_OMP_OVERHEAD_THRESHOLD);
-  }
-#else
-    serial_path = 1;
-#endif
-  if (serial_path) {
-    TH_TENSOR_APPLY(scalar_t, tensor, prod = prod && *tensor_data;);
-  }
-  return prod;
-}
-
 int THTensor_(logicalAnyAll)(THTensor *tensor)
 {
   scalar_t sum = 0;
@@ -1573,83 +1553,6 @@ int THTensor_(logicalAnyAll)(THTensor *tensor)
   return (bool)sum;
 }
 
-void THTensor_(logicalAnd)(THTensor *r_, THTensor *t, int dimension, int keepdim)
-{
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
-      dimension + TH_INDEX_BASE);
-
-  THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimensionLegacyAll)(t), dimension, keepdim);
-  std::vector<int64_t> dim = THTensor_sizesLegacyNoScalars(t);
-  dim[dimension] = 1;
-  THTensor_(resize)(r_, dim, {});
-
-  int serial_path = 0;
-#ifdef _OPENMP
-  int inOMP = omp_in_parallel();
-  if (inOMP) {
-    serial_path = 1;
-  } else {
-    int r_Contig = THTensor_(isContiguous)(r_);
-    scalar_t *tp = t->data<scalar_t>();
-    scalar_t *rp = r_->data<scalar_t>();
-    if(r_Contig && (tp != rp)){
-      ptrdiff_t iter = 0;
-      ptrdiff_t r_Size = THTensor_(nElement)(r_);
-      int r_Dim = THTensor_nDimensionLegacyAll(r_);
-      #pragma omp parallel for if ( r_Size > TH_OMP_OVERHEAD_THRESHOLD)
-      for (iter = 0; iter < r_Size; iter++) {
-        int j;
-        int64_t quot;
-        int64_t rem = iter;
-        ptrdiff_t tBasicIndex = 0;
-
-        for(j = 0; j < r_Dim; ++j) {
-          if(j != dimension){
-            quot = rem/r_->stride(j);
-            rem = rem%r_->stride(j);
-            tBasicIndex += quot*t->stride(j);
-          }
-        }
-        scalar_t *t_data = tp+tBasicIndex;
-        scalar_t *r__data = rp+iter;
-        *r__data = 1;
-        for(j=0; j < THTensor_sizeLegacyNoScalars(t, dimension); ++j) {
-          *r__data = *r__data && *(t_data + j*THTensor_strideLegacyNoScalars(t, dimension));
-        }
-      }
-    } else {
-      serial_path = 1;
-    }
-  }
-#else
-  serial_path = 1;
-#endif
-
-  if(serial_path) {
-    // two implementations optimized for data locality
-    if (THTensor_strideLegacyNoScalars(t, dimension) == 1) {
-      TH_TENSOR_DIM_APPLY2(scalar_t, t, scalar_t, r_, dimension,
-                           accreal prod = 1;
-                           int64_t i;
-                           for(i = 0; i < t_size; i++)
-                             prod = prod && t_data[i*t_stride];
-                           *r__data = (scalar_t)prod;);
-    } else {
-      THTensor_(fill)(r_, 1);
-      THTensor *temp_ = THTensor_(newWithTensor)(r_);
-      // r_.expand_as(t)
-      temp_->set_size(dimension,THTensor_sizeLegacyNoScalars(t, dimension));
-      temp_->set_stride(dimension, 0);
-
-      TH_TENSOR_APPLY2(scalar_t, temp_, scalar_t, t, *temp__data = *temp__data && *t_data;);
-      c10::raw::intrusive_ptr::decref(temp_);
-    }
-  }
-  if (!keepdim) {
-    THTensor_(squeeze1d)(r_, r_, dimension);
-  }
-}
-
 void THTensor_(logicalAny)(THTensor *r_, THTensor *t, int dimension, int keepdim)
 {
   THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 2, "dimension %d out of range",
index 361bacb..c3164f3 100644
 #include <THC/generic/THCTensorTopK.h>
 #include <THC/THCGenerateAllTypes.h>
 
-THC_API int THCudaByteTensor_logicalAndAll(THCState *state, THCudaByteTensor *self);
 THC_API int THCudaByteTensor_logicalAnyAll(THCState *state, THCudaByteTensor *self);
 
-THC_API void THCudaByteTensor_logicalAnd(THCState *state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim);
 THC_API void THCudaByteTensor_logicalAny(THCState *state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim);
 
 #endif
index d8e8b30..7d2342e 100644
@@ -2,20 +2,6 @@
 #include <THC/THCTensor.hpp>
 
 THC_API int
-THCudaByteTensor_logicalAndAll(THCState *state, THCudaByteTensor *self) {
-  THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 1, self));
-  unsigned char result;
-  if (!THC_reduceAll<uint8_t>(state, self,
-                              thrust::identity<unsigned char>(),
-                              LogicalAll(),
-                              (unsigned char) 1, &result, 0)) {
-    THArgCheck(false, 1, CUTORCH_DIM_WARNING);
-  }
-
-  return (int) result;
-}
-
-THC_API int
 THCudaByteTensor_logicalAnyAll(THCState *state, THCudaByteTensor *self) {
   THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 1, self));
   unsigned char result;
@@ -30,22 +16,6 @@ THCudaByteTensor_logicalAnyAll(THCState *state, THCudaByteTensor *self) {
 }
 
 THC_API void
-THCudaByteTensor_logicalAnd(THCState* state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim) {
-  THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 2, self, src));
-  if (!THC_reduceDim<uint8_t>(state, self, src,
-                              thrust::identity<unsigned char>(),
-                              LogicalAll(),
-                              thrust::identity<unsigned char>(),
-                              (unsigned char) 1,
-                              dimension,
-                              keepdim)) {
-    THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-  }
-
-  THCudaCheck(cudaGetLastError());
-}
-
-THC_API void
 THCudaByteTensor_logicalAny(THCState* state, THCudaByteTensor *self, THCudaByteTensor *src, int dimension, int keepdim) {
   THCAssertSameGPU(THCudaByteTensor_checkGPU(state, 2, self, src));
   if (!THC_reduceDim<uint8_t>(state, self, src,
index fbc2fef..4719532 100644
@@ -378,6 +378,46 @@ class _TestTorchMixin(object):
                         res2[i, j] += m1[i, k] * m2[k, j]
             self.assertEqual(res1, res2)
 
+    def test_logical_all(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device)
+
+            self.assertEqual(
+                torch.tensor(1, dtype=torch.uint8, device=device),
+                x.all())
+
+            self.assertEqual(
+                torch.ones([1, 3, 400], dtype=torch.uint8, device=device),
+                x.all(0, keepdim=True))
+
+            self.assertEqual(
+                torch.ones([2, 1, 400], dtype=torch.uint8, device=device),
+                x.all(1, keepdim=True))
+
+            self.assertEqual(
+                torch.ones([2, 3, 1], dtype=torch.uint8, device=device),
+                x.all(2, keepdim=True))
+
+            # set the last element to 0
+            x[-1][-1][-1] = 0
+
+            self.assertEqual(
+                torch.tensor(0, dtype=torch.uint8, device=device),
+                x.all())
+
+            y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device)
+            y[-1][-1][-1] = 0
+            self.assertEqual(y, x.all(0, keepdim=True))
+
+            y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device)
+            y[-1][-1][-1] = 0
+            self.assertEqual(y, x.all(1, keepdim=True))
+
+            y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device)
+            y[-1][-1][-1] = 0
+            self.assertEqual(y, x.all(2, keepdim=True))
+
     def test_allclose(self):
         x = torch.tensor([1.0, 2.0, 3.0])
         y = torch.tensor([1.01, 2.01, 3.01])
index e652b2e..9ec8f9e 100644
 - name: alias(Tensor self)
   self: grad
 
+# The two items below are necessary because TensorIterator doesn't work on
+# Variables (codegen does not unwrap the input Tensor for all(*) without this
+# line).
+- name: all(Tensor self)
+  self: not_implemented("all")
+
+- name: all(Tensor self, int64_t dim, bool keepdim)
+  self: not_implemented("all")
+
 - name: as_strided(Tensor self, IntList size, IntList stride, int64_t? storage_offset)
   self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)