torch.cross' dim default changed to c10::optional instead of int=-1 (#17582)
authorIgor Fedan <ifedan@fb.com>
Tue, 2 Apr 2019 20:18:20 +0000 (13:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 20:27:00 +0000 (13:27 -0700)
Summary:
Argument dim=-1 doesn't work for torch.cross. The signature of the torch.cross has been changed to c10::optional<int64_t> dim instead of int64_t. So based on document "If dim is not given, it defaults to the first dimension found with the size 3." and if dim is specified (even negative) it will use the correspondent dim.

Fixes #17229
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17582

Differential Revision: D14483063

Pulled By: ifedan

fbshipit-source-id: f9699093ec401cb185fd33ca4563c8a46cdcd746

16 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorMethods.h
aten/src/ATen/core/Type.h
aten/src/ATen/native/Cross.cpp [new file with mode: 0644]
aten/src/ATen/native/Cross.h [new file with mode: 0644]
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/cpu/CrossKernel.cpp [new file with mode: 0644]
aten/src/ATen/native/cuda/CrossKernel.cu [new file with mode: 0644]
aten/src/ATen/native/native_functions.yaml
aten/src/TH/generic/THTensorMath.h
aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/THC/generic/THCTensorMathPointwise.cu
aten/src/THC/generic/THCTensorMathPointwise.h
test/test_torch.py
tools/autograd/derivatives.yaml

index 8b8ffd3..b59bacf 100644 (file)
     - arg: THTensor* tensor
 ]]
 [[
-  name: _th_cross
-  cname: cross
+  name: _th_cross_kernel
+  cname: crossKernel
   variants:
     - function
+  backends:
+    - CUDA
   return: argument 0
   arguments:
     - arg: THTensor* result
       output: True
     - THTensor* self
     - THTensor* other
-    - arg: long dim
-      default: -1
+    - arg: int64_t dim
 ]]
 [[
   name: _th_diag
index ea3f2b5..add127c 100644 (file)
@@ -663,7 +663,7 @@ class CAFFE2_API Tensor {
   Tensor & exponential_(double lambd=1, Generator * generator=nullptr);
   Tensor & geometric_(double p, Generator * generator=nullptr);
   Tensor diag(int64_t diagonal=0) const;
-  Tensor cross(const Tensor & other, int64_t dim=-1) const;
+  Tensor cross(const Tensor & other, c10::optional<int64_t> dim=c10::nullopt) const;
   Tensor triu(int64_t diagonal=0) const;
   Tensor tril(int64_t diagonal=0) const;
   Tensor trace() const;
index 2a05ce7..1605aff 100644 (file)
@@ -1060,7 +1060,7 @@ inline Tensor & Tensor::geometric_(double p, Generator * generator) {
 inline Tensor Tensor::diag(int64_t diagonal) const {
     return type().diag(*this, diagonal);
 }
-inline Tensor Tensor::cross(const Tensor & other, int64_t dim) const {
+inline Tensor Tensor::cross(const Tensor & other, c10::optional<int64_t> dim) const {
     return type().cross(*this, other, dim);
 }
 inline Tensor Tensor::triu(int64_t diagonal) const {
index dcdd533..e50af83 100644 (file)
@@ -541,7 +541,7 @@ struct CAFFE2_API Type {
   virtual Tensor & exponential_(Tensor & self, double lambd, Generator * generator) const = 0;
   virtual Tensor & geometric_(Tensor & self, double p, Generator * generator) const = 0;
   virtual Tensor diag(const Tensor & self, int64_t diagonal) const = 0;
-  virtual Tensor cross(const Tensor & self, const Tensor & other, int64_t dim) const = 0;
+  virtual Tensor cross(const Tensor & self, const Tensor & other, c10::optional<int64_t> dim) const = 0;
   virtual Tensor triu(const Tensor & self, int64_t diagonal) const = 0;
   virtual Tensor tril(const Tensor & self, int64_t diagonal) const = 0;
   virtual Tensor trace(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp
new file mode 100644 (file)
index 0000000..8788969
--- /dev/null
@@ -0,0 +1,54 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/NativeFunctions.h>
+
+#include <ATen/native/Cross.h>
+
+namespace at { namespace native {
+
+DEFINE_DISPATCH(cross_stub);
+
+Tensor cross(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
+  Tensor out = at::empty_like(input);
+  native::cross_out(out, input, other, dimension);
+  return out;
+}
+
+Tensor & cross_out(Tensor & out, const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
+  auto device_res = input.type().device_type();
+  AT_CHECK(device_res == kCPU || device_res == kCUDA, "cross only supports CPU and CUDA devices, out got: ", device_res);
+  auto device1 = input.type().device_type();
+  AT_CHECK(device1 == kCPU || device1 == kCUDA, "cross only supports CPU and CUDA devices, input got: ", device1);
+  auto device2 = other.type().device_type();
+  AT_CHECK(device2 == kCPU || device2 == kCUDA, "cross only supports CPU and CUDA devices, other got: ", device2);
+  AT_CHECK(device_res == device1, "out and input must have the same device type. out: ", device_res, " input: ", device1);
+  AT_CHECK(device1 == device2, "input and other must have the same device type. input: ", device1, " other: ", device2);
+  AT_CHECK(!out.is_cuda() || out.get_device() == input.get_device(), "device of out (", input.get_device(), ") must match device of input (", other.get_device(), ")");
+  AT_CHECK(!input.is_cuda() || input.get_device() == other.get_device(), "device of input (", input.get_device(), ") must match device of other (", other.get_device(), ")");
+  AT_CHECK(input.dim() == other.dim(), "inconsistent tensors dimensions input: ", input.dim(), " other: ", other.dim());
+  AT_CHECK(input.sizes() == other.sizes(), "inconsistent tensors sizes input: ", input.sizes(), " other: ", other.sizes());
+
+  int64_t dim = -1;
+  if(!dimension.has_value()) {
+    for(int64_t i = 0; i < input.dim(); i++) {
+      if(input.size(i) == 3) {
+        dim = i;
+        break;
+      }
+    }
+    AT_CHECK(dim >= 0, "no dimension of size 3 in input");
+  } else {
+    dim = maybe_wrap_dim(dimension.value(), input.dim());
+    AT_CHECK(input.size(dim) == 3, "dimension ", dimension.value(), " does not have size 3");
+  }
+
+  if (out.sizes() != input.sizes()) {
+    out.resize_as_(input);
+  }
+
+  cross_stub(device1, out, input, other, dim);
+  return out;
+}
+
+}} // namespace at::native
+
diff --git a/aten/src/ATen/native/Cross.h b/aten/src/ATen/native/Cross.h
new file mode 100644 (file)
index 0000000..35f9886
--- /dev/null
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at { namespace native {
+
+using cross_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const int64_t d);
+
+DECLARE_DISPATCH(cross_fn, cross_stub);
+
+}} // namespace at::native
+
index 79ee5aa..6315d39 100644 (file)
@@ -272,14 +272,6 @@ Tensor diag(const Tensor & self, int64_t diagonal) {
   return at::legacy::th::_th_diag(self, diagonal);
 }
 
-Tensor & cross_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim) {
-  return at::legacy::th::_th_cross_out(result, self, other, dim);
-}
-
-Tensor cross(const Tensor & self, const Tensor & other, int64_t dim) {
-  return at::legacy::th::_th_cross(self, other, dim);
-}
-
 Tensor trace(const Tensor & self) {
   return at::legacy::th::_th_trace(self);
 }
diff --git a/aten/src/ATen/native/cpu/CrossKernel.cpp b/aten/src/ATen/native/cpu/CrossKernel.cpp
new file mode 100644 (file)
index 0000000..9d51fc6
--- /dev/null
@@ -0,0 +1,78 @@
+#include <ATen/native/Cross.h>
+
+#include <numeric>
+#include <iterator>
+#include <algorithm>
+#include <vector>
+
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/cpu/vml.h>
+namespace at { namespace native { namespace {
+
+template<typename scalar_t>
+static void apply_cross(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
+  int64_t total = a.numel() / 3;
+  int64_t a_stride = a.stride(dim);
+  int64_t b_stride = b.stride(dim);
+  int64_t r_stride = result.stride(dim);
+
+  scalar_t *a_ptr = a.data<scalar_t>();
+  scalar_t *b_ptr = b.data<scalar_t>();
+  scalar_t *r_ptr = result.data<scalar_t>();
+
+  parallel_for(0, total, internal::GRAIN_SIZE, [&](int64_t s, int64_t e) {
+    const int64_t a_dim = a.dim();
+    std::vector<int64_t> position_in_dims(a_dim);
+    int64_t index_in_curr_dim = s;
+    int64_t a_start = 0;
+    int64_t b_start = 0;
+    int64_t r_start = 0;
+    for (int64_t i = 0; i < a.dim(); i++) {
+      if (i == dim) continue;
+      position_in_dims[i] = index_in_curr_dim % a.size(i);
+      a_start += (index_in_curr_dim % a.size(i)) * a.stride(i);
+      b_start += (index_in_curr_dim % b.size(i)) * b.stride(i);
+      r_start += (index_in_curr_dim % result.size(i)) * result.stride(i);
+      index_in_curr_dim = index_in_curr_dim / a.size(i);
+    }
+
+    while (s < e) {
+      r_ptr[r_start+0*r_stride] = a_ptr[a_start+1*a_stride]*b_ptr[b_start+2*b_stride] - a_ptr[a_start+2*a_stride]*b_ptr[b_start+1*b_stride];
+      r_ptr[r_start+1*r_stride] = a_ptr[a_start+2*a_stride]*b_ptr[b_start+0*b_stride] - a_ptr[a_start+0*a_stride]*b_ptr[b_start+2*b_stride];
+      r_ptr[r_start+2*r_stride] = a_ptr[a_start+0*a_stride]*b_ptr[b_start+1*b_stride] - a_ptr[a_start+1*a_stride]*b_ptr[b_start+0*b_stride];
+      s++;
+
+      for (int i = 0; i < a.dim(); i++) {
+        if (i == dim) {
+          continue;
+        }
+        position_in_dims[i]++;
+        a_start += a.stride(i);
+        b_start += b.stride(i);
+        r_start += result.stride(i);
+        if (position_in_dims[i] == a.size(i) && i != a.dim()-1) {
+            a_start -= position_in_dims[i] * a.stride(i);
+            b_start -= position_in_dims[i] * b.stride(i);
+            r_start -= position_in_dims[i] * result.stride(i);
+            position_in_dims[i] = 0;
+        } else {
+          break;
+        }
+      }
+    }
+  });
+}
+
+static void cross_kernel_impl(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
+  AT_DISPATCH_ALL_TYPES(result.scalar_type(), "cross", [&]() {
+    apply_cross<scalar_t>(result, a, b, dim);
+  });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(cross_stub, &cross_kernel_impl);
+
+}} // namespace at::native
+
diff --git a/aten/src/ATen/native/cuda/CrossKernel.cu b/aten/src/ATen/native/cuda/CrossKernel.cu
new file mode 100644 (file)
index 0000000..abac086
--- /dev/null
@@ -0,0 +1,15 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/LegacyTHFunctions.h>
+#include <ATen/native/Cross.h>
+
+namespace at { namespace native {
+
+void cross_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const int64_t dim) {
+  at::legacy::th::_th_cross_kernel_out(result, x1, x2, dim);
+}
+
+REGISTER_DISPATCH(cross_stub, &cross_kernel_impl);
+
+}}
+
index dc6153b..3821597 100644 (file)
   matches_jit_signature: True
   variants: method, function
 
-- func: cross(Tensor self, Tensor other, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)
+- func: cross(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: cross(Tensor self, Tensor other, int dim=-1) -> Tensor
+- func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
   matches_jit_signature: True
   variants: method, function
 
index 27083a0..3ab999b 100644 (file)
@@ -80,7 +80,6 @@ TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
 TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
 TH_API void THTensor_(sign)(THTensor *r_, THTensor *t);
 TH_API accreal THTensor_(trace)(THTensor *t);
-TH_API void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension);
 
 TH_API void THTensor_(cmax)(THTensor *r, THTensor *t, THTensor *src);
 TH_API void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src);
index 4661ef2..8461192 100644 (file)
@@ -390,53 +390,6 @@ accreal THTensor_(trace)(THTensor *t)
   return sum;
 }
 
-void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension)
-{
-  int i;
-
-  if(THTensor_(nDimensionLegacyNoScalars)(a) != THTensor_(nDimensionLegacyNoScalars)(b))
-    THError("inconsistent tensor dimension %dD, %dD",
-        THTensor_(nDimensionLegacyNoScalars)(a), THTensor_(nDimensionLegacyNoScalars)(b));
-
-  for(i = 0; i < a->dim(); i++)
-  {
-    if(THTensor_(size)(a, i) != THTensor_(size)(b, i)) {
-        THDescBuff ba = THTensor_(sizeDesc)(a);
-        THDescBuff bb = THTensor_(sizeDesc)(b);
-        THError("inconsistent tensor sizes %s, %s", ba.str, bb.str);
-    }
-  }
-
-  if(dimension < 0)
-  {
-    for(i = 0; i < THTensor_(nDimensionLegacyNoScalars)(a); i++)
-    {
-      if(THTensor_sizeLegacyNoScalars(a, i) == 3)
-      {
-        dimension = i;
-        break;
-      }
-    }
-    if(dimension < 0) {
-      THDescBuff ba = THTensor_(sizeDesc)(a);
-      THError("no dimension of size 3 in a: %s", ba.str);
-    }
-  }
-
-  THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(a), 3, "dimension %d out of range",
-      dimension);
-  THArgCheck(THTensor_sizeLegacyNoScalars(a, dimension) == 3, 3, "dimension %d does not have size 3",
-      dimension);
-
-  THTensor_(resizeAs)(r_, a);
-
-  TH_TENSOR_DIM_APPLY3(scalar_t, a, scalar_t, b, scalar_t, r_, dimension,
-                       TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
-                       r__data[0*r__stride] = a_data[1*a_stride]*b_data[2*b_stride] - a_data[2*a_stride]*b_data[1*b_stride];
-                       r__data[1*r__stride] = a_data[2*a_stride]*b_data[0*b_stride] - a_data[0*a_stride]*b_data[2*b_stride];
-                       r__data[2*r__stride] = a_data[0*a_stride]*b_data[1*b_stride] - a_data[1*a_stride]*b_data[0*b_stride];);
-}
-
 void THTensor_(cmax)(THTensor *r, THTensor *t, THTensor *src) {
   THTensor_(resizeAs)(r, t);
   TH_TENSOR_APPLY3(scalar_t, r, scalar_t, t, scalar_t, src,
@@ -1047,7 +1000,7 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb)
 }
 
 #define TENSOR_IMPLEMENT_LOGICAL(NAME,OP)                              \
-  void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value)   \
+  void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
   {                                                                    \
     THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL);         \
     TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t,                   \
index 0c128a5..6a2ee33 100644 (file)
@@ -111,26 +111,10 @@ void THCTensor_(clamp)(THCState *state, THCTensor *self_, THCTensor *src, scalar
   THCudaCheck(cudaGetLastError());
 }
 
-void THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y, int dimension)
+void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y, int dimension)
 {
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self, x, y));
 
-  int i;
-  int nd = x->dim();
-  ptrdiff_t nelem = THCTensor_(nElement)(state, x);
-  THArgCheck(nd == y->dim(), 1, "tensors must have same number of dimensions");
-  for (i = 0; i < nd; i++) {
-    THArgCheck(THCTensor_(size)(state, x, i) == THCTensor_(size)(state, y, i), 1, "dimension %i of x and y does not match", i);
-    if (dimension < 0 && THCTensor_(size)(state, x, i) == 3) {
-      dimension = i;
-    }
-  }
-
-  THArgCheck(dimension >= 0 && dimension < nd, 3, "dimension %d out of range", dimension+1);
-  THArgCheck(THCTensor_(size)(state, x, dimension) == 3, 3,
-      "dimension %d does not have size 3", dimension+1);
-  THCTensor_(resizeAs)(state, self, x);
-
   int64_t sx = THCTensor_(stride)(state, x, dimension);
   int64_t sy = THCTensor_(stride)(state, y, dimension);
   int64_t so = THCTensor_(stride)(state, self, dimension);
index 78559f5..5539e8e 100644 (file)
@@ -47,7 +47,7 @@ THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
 THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, scalar_t min_value, scalar_t max_value);
-THC_API void THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension);
+THC_API void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension);
 
 THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, scalar_t value, THCTensor *src2);
 THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, scalar_t value, THCTensor *src2);
index 5de9045..7d44307 100644 (file)
@@ -2367,6 +2367,35 @@ class _TestTorchMixin(object):
         torch.cross(x, y, out=res2)
         self.assertEqual(res1, res2)
 
+    def test_cross_with_and_without_dim(self):
+        x = torch.rand(100, 3)
+        y = torch.rand(100, 3)
+        res1 = torch.cross(x, y, dim=1)
+        res2 = torch.cross(x, y, dim=-1)
+        res3 = torch.cross(x, y)
+        self.assertEqual(res1, res2)
+        self.assertEqual(res1, res3)
+
+    def test_cross_validation(self):
+        self.assertRaisesRegex(
+            RuntimeError, "inconsistent tensors dimensions",
+            lambda: torch.cross(torch.rand(100, 3), torch.rand(100, 3, 10)))
+        self.assertRaisesRegex(
+            RuntimeError, "inconsistent tensors sizes",
+            lambda: torch.cross(torch.rand(5, 3), torch.rand(3, 5)))
+        self.assertRaisesRegex(
+            RuntimeError, "no dimension of size 3 in input",
+            lambda: torch.cross(torch.rand(5, 4), torch.rand(5, 4)))
+        self.assertRaisesRegex(
+            RuntimeError, "dimension 0 does not have size 3",
+            lambda: torch.cross(torch.rand(5, 4, 3), torch.rand(5, 4, 3), dim=0))
+        self.assertRaisesRegex(
+            RuntimeError, "dimension -1 does not have size 3",
+            lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-1))
+        self.assertRaisesRegex(
+            IndexError, "Dimension out of range",
+            lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-5))
+
     def test_zeros(self):
         res1 = torch.zeros(100, 100)
         res2 = torch.Tensor()
index 425a5b5..bea70b5 100644 (file)
 - name: cosh(Tensor self)
   self: grad * self.sinh()
 
-- name: cross(Tensor self, Tensor other, int64_t dim)
+- name: cross(Tensor self, Tensor other, int64_t? dim)
   self: other.cross(grad, dim)
   other: grad.cross(self, dim)