Revert D13540278: [pytorch][PR] Unhide unique from C++, make unique partially scriptable
authorWanchao Liang <wanchaol@fb.com>
Tue, 22 Jan 2019 20:11:23 +0000 (12:11 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 22 Jan 2019 20:22:40 +0000 (12:22 -0800)
Differential Revision:
D13540278

Original commit changeset: 3768c76a90b0

fbshipit-source-id: 7a31c239f9dca6ff467344d99820095addcae9d7

13 files changed:
aten/src/ATen/native/Unique.cpp
aten/src/ATen/native/cuda/Unique.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native/sparse/SparseTensor.cpp
test/onnx/expect/TestOperators.test_unique.expect [deleted file]
test/onnx/expect/TestOperators.test_unique_dim.expect [deleted file]
test/onnx/test_operators.py
test/test_jit.py
tools/autograd/derivatives.yaml
tools/autograd/gen_python_functions.py
torch/functional.py
torch/onnx/symbolic.py
torch/tensor.py

index 7f1b3af..8f6cfae 100644 (file)
@@ -126,28 +126,18 @@ std::tuple<Tensor, Tensor> _unique_dim_cpu_template(
 } // namespace
 
 std::tuple<Tensor, Tensor>
-_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse, optional<int64_t> dim) {
-  if (dim) {
-    return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
-      // The current implementation using `dim` always sorts due to unhashable tensors
-      return _unique_dim_cpu_template<scalar_t>(self, dim.value(), return_inverse);
-    });
-  }
+_unique_cpu(const Tensor& self, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
     return _unique_cpu_template<scalar_t>(self, sorted, return_inverse);
   });
 }
 
-std::tuple<Tensor, Tensor> unique_dim(const Tensor& self, int64_t dim, const bool sorted, const bool return_inverse) {
-  return at::unique(self, sorted, return_inverse, dim);
-}
-
-std::tuple<Tensor, Tensor> _unique(const Tensor& self, const bool sorted, const bool return_inverse) {
-  return at::unique(self, sorted, return_inverse);
-}
-
-std::tuple<Tensor, Tensor> _unique_dim(const Tensor& self, int64_t dim, const bool sorted, const bool return_inverse) {
-  return at::unique(self, sorted, return_inverse, dim);
+std::tuple<Tensor, Tensor>
+_unique_dim_cpu(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+    // The current implementation using `dim` always sorts due to unhashable tensors
+    return _unique_dim_cpu_template<scalar_t>(self, dim, return_inverse);
+  });
 }
 
 }  // namespace native
index 29efe7c..828fb48 100644 (file)
@@ -145,12 +145,7 @@ template <typename scalar_t>
 } // namespace
 
 std::tuple<Tensor, Tensor>
-_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, optional<int64_t> dim) {
-  if (dim) {
-    return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
-      return _unique_dim_cuda_template<scalar_t>(self, dim.value(), return_inverse);
-    });
-  }
+_unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse) {
   return AT_DISPATCH_ALL_TYPES(self.type(), "unique", [&] {
     // The current CUDA implementation of unique always sort due to the
     // lack of hashtable implementation in thrust
@@ -158,5 +153,12 @@ _unique_cuda(const Tensor& self, const bool sorted, const bool return_inverse, o
   });
 }
 
+std::tuple<Tensor, Tensor>
+_unique_dim_cuda(const Tensor& self, const int64_t dim, const bool sorted, const bool return_inverse) {
+  return AT_DISPATCH_ALL_TYPES(self.type(), "unique_dim", [&] {
+    return _unique_dim_cuda_template<scalar_t>(self, dim, return_inverse);
+  });
+}
+
 }  // namespace native
 }  // namespace at
index f4e51c1..9608f9b 100644 (file)
   matches_jit_signature: True
   variants: method
 
-- func: unique(Tensor self, bool sorted=true, bool return_inverse=false, int64_t? dim=None) -> (Tensor, Tensor)
+- func: _unique(Tensor self, bool sorted=true, bool return_inverse=false) -> (Tensor, Tensor)
   variants: function
   dispatch:
     CPU: _unique_cpu
     CUDA: _unique_cuda
 
-# unique_dim is not exposed to python, special cased in gen_python_functions.py
-- func: unique_dim(Tensor self, int64_t dim, bool sorted=true, bool return_inverse=false) -> (Tensor, Tensor)
-  variants: function
-
-# FIXME: for back compatibility reason, _unique and _unique_dim is still there
-# it just calls unique. These two functions should be deleted in the future
-- func: _unique(Tensor self, bool sorted=true, bool return_inverse=false) -> (Tensor, Tensor)
-  variants: function
-
 - func: _unique_dim(Tensor self, int64_t dim, bool sorted=true, bool return_inverse=false) -> (Tensor, Tensor)
   variants: function
+  dispatch:
+    CPU: _unique_dim_cpu
+    CUDA: _unique_dim_cuda
 
 - func: _unsafe_view(Tensor self, IntList size) -> Tensor
 
index 6cb55e9..ce69e5c 100644 (file)
@@ -288,7 +288,7 @@ SparseTensor dense_to_sparse(const Tensor& self){
 SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
   int64_t dims = self.dim();
   AT_CHECK(sparse_dim > 0, "sparse_dim must be >0");
-  AT_CHECK(sparse_dim <= dims,
+  AT_CHECK(sparse_dim <= dims, 
     "sparse_dim must be less than or equal to self.dim()");
   at::TensorOptions sparse_options = self.options().layout(kSparse);
   std::vector<int64_t> sizes = self.sizes().vec();
@@ -302,7 +302,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
     indices = nz.clone();
   } else {
     Tensor i = nz.narrow(0, 0, sparse_dim);
-    std::tie(indices, std::ignore) = at::unique_dim(i, 1);
+    std::tie(indices, std::ignore) = _unique_dim(i, 1);
     indices = indices.contiguous();  // many sparse CUDA kernels require contiguity, see issue #12633
   }
 
diff --git a/test/onnx/expect/TestOperators.test_unique.expect b/test/onnx/expect/TestOperators.test_unique.expect
deleted file mode 100644 (file)
index caec02c..0000000
+++ /dev/null
@@ -1,56 +0,0 @@
-ir_version: 3
-producer_name: "pytorch"
-producer_version: "0.4"
-graph {
-  node {
-    input: "x"
-    output: "1"
-    output: "2"
-    op_type: "ATen"
-    attribute {
-      name: "operator"
-      s: "unique"
-      type: STRING
-    }
-    attribute {
-      name: "return_inverse"
-      i: 0
-      type: INT
-    }
-    attribute {
-      name: "sorted"
-      i: 1
-      type: INT
-    }
-  }
-  name: "torch-jit-export"
-  input {
-    name: "x"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 10
-          }
-        }
-      }
-    }
-  }
-  output {
-    name: "1"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 10
-          }
-        }
-      }
-    }
-  }
-}
-opset_import {
-  version: 9
-}
diff --git a/test/onnx/expect/TestOperators.test_unique_dim.expect b/test/onnx/expect/TestOperators.test_unique_dim.expect
deleted file mode 100644 (file)
index e781709..0000000
+++ /dev/null
@@ -1,67 +0,0 @@
-ir_version: 3
-producer_name: "pytorch"
-producer_version: "0.4"
-graph {
-  node {
-    input: "x"
-    output: "1"
-    output: "2"
-    op_type: "ATen"
-    attribute {
-      name: "dim"
-      i: 1
-      type: INT
-    }
-    attribute {
-      name: "operator"
-      s: "unique"
-      type: STRING
-    }
-    attribute {
-      name: "return_inverse"
-      i: 0
-      type: INT
-    }
-    attribute {
-      name: "sorted"
-      i: 1
-      type: INT
-    }
-  }
-  name: "torch-jit-export"
-  input {
-    name: "x"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 10
-          }
-          dim {
-            dim_value: 10
-          }
-        }
-      }
-    }
-  }
-  output {
-    name: "1"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 10
-          }
-          dim {
-            dim_value: 10
-          }
-        }
-      }
-    }
-  }
-}
-opset_import {
-  version: 9
-}
index 4afa19b..ddb98fa 100644 (file)
@@ -290,14 +290,6 @@ class TestOperators(TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.assertONNX(lambda x: x.clamp(max=0.1), x)
 
-    def test_unique(self):
-        x = torch.randn(10, requires_grad=True)
-        self.assertONNX(lambda x: torch.unique(x), x)
-
-    def test_unique_dim(self):
-        x = torch.randn(10, 10, requires_grad=True)
-        self.assertONNX(lambda x: torch.unique(x, dim=1), x)
-
     def test_hardtanh(self):
         x = torch.randn(3, 4, requires_grad=True)
         self.assertONNX(lambda x: torch.nn.Hardtanh(-0.5, 0.5)(x), x)
index 36803b7..b26521f 100644 (file)
@@ -4410,35 +4410,6 @@ a")
         self.checkScript(test_script_clamp_min_none, input, optimize=True)
         self.checkScript(test_script_clamp_min, input, optimize=True)
 
-    def test_script_unique_none(self):
-        def test_unique_inverse(a):
-            b, c = torch.unique(a, return_inverse=True)
-            return b + 1
-
-        def test_unique_inverse_nonedim(a):
-            b, c = torch.unique(a, return_inverse=True, dim=None)
-            return b + 1
-
-        def test_unique_noinverse(a):
-            b = torch.unique(a)
-            return b + 1
-
-        def test_unique_noinverse_nonedim(a):
-            b = torch.unique(a, dim=None)
-            return b + 1
-
-        a = torch.rand(5, 6, 7)
-
-        self.checkTrace(test_unique_inverse, [a], inputs_require_grads=False)
-        self.checkTrace(test_unique_inverse_nonedim, [a], inputs_require_grads=False)
-        self.checkTrace(test_unique_noinverse, [a], inputs_require_grads=False)
-        self.checkTrace(test_unique_noinverse_nonedim, [a], inputs_require_grads=False)
-        self.checkScript(test_unique_inverse, [a])
-        self.checkScript(test_unique_inverse_nonedim, [a])
-        # TODO: scripting unique when return_inverse = False is not supported yet
-        # self.checkScript(test_unique_noinverse, [a])
-        # self.checkScript(test_unique_noinverse_nonedim, [a])
-
     def test_script_bool_constant(self):
         script = '''
         def test_script_bool_constant():
index 621496d..5d2bda0 100644 (file)
 - name: uniform_(Tensor self, double from, double to, Generator generator)
   self: zeros_like(grad)
 
-- name: unique(Tensor self, bool sorted, bool return_inverse, int64_t? dim)
-  self: not_implemented("unique")
+- name: _unique(Tensor self, bool sorted, bool return_inverse)
+  self: not_implemented("_unique")
 
 - name: _unsafe_view(Tensor self, IntList size)
   self: grad.reshape(self.sizes())
index d837dd7..ffacdf8 100644 (file)
@@ -28,7 +28,7 @@ SKIP_PYTHON_BINDINGS = [
     '_th_.*', '_thnn_.*',
     'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
     '_potrs.*', '_cholesky.*',
-    'slice', 'randint(_out)?', 'unique_dim', '_unique', '_unique_dim',
+    'slice', 'randint(_out)?',
     'item', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
     'copy_sparse_to_sparse_',
index 2e1f0bb..9847f34 100644 (file)
@@ -433,8 +433,19 @@ def unique(input, sorted=True, return_inverse=False, dim=None):
                 [ 1,  2]])
 
     """
-    output, inverse_indices = torch._C._VariableFunctions.unique(
-        input, sorted=sorted, return_inverse=return_inverse, dim=dim)
+    if dim is not None:
+        output, inverse_indices = torch._unique_dim(
+            input,
+            dim,
+            sorted=sorted,
+            return_inverse=return_inverse
+        )
+    else:
+        output, inverse_indices = torch._unique(
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+        )
     if return_inverse:
         return output, inverse_indices
     else:
index 88ba735..231cf56 100644 (file)
@@ -1048,16 +1048,10 @@ def conv_tbc(g, input, weight, bias, pad):
     return g.op("ATen", input, weight, bias, operator_s="conv_tbc", pad_i=pad)
 
 
-def unique(g, input, sorted, return_inverse, dim):
-    sorted = _parse_arg(sorted, 'i')
-    return_inverse = _parse_arg(return_inverse, 'i')
-    if dim.node().kind() == "prim::None":
-        return g.op("ATen", input, operator_s="unique", sorted_i=sorted,
-                    return_inverse_i=return_inverse, outputs=2)
-    else:
-        dim = _parse_arg(dim, 'i')
-        return g.op("ATen", input, operator_s="unique", sorted_i=sorted,
-                    return_inverse_i=return_inverse, dim_i=dim, outputs=2)
+@parse_args('v', 'i', 'i')
+def _unique(g, input, sorted, return_inverse):
+    return g.op("ATen", input, operator_s="_unique", sorted_i=sorted,
+                return_inverse_i=return_inverse, outputs=2)
 
 
 # Metaprogram symbolics for each ATen native specialized cast operator.
index a69db66..be936bf 100644 (file)
@@ -343,7 +343,23 @@ class Tensor(torch._C._TensorBase):
 
         See :func:`torch.unique`
         """
-        return torch.unique(self, sorted, return_inverse, dim)
+        if dim is not None:
+            output, inverse_indices = torch._unique_dim(
+                self,
+                sorted=sorted,
+                return_inverse=return_inverse,
+                dim=dim
+            )
+        else:
+            output, inverse_indices = torch._unique(
+                self,
+                sorted=sorted,
+                return_inverse=return_inverse
+            )
+        if return_inverse:
+            return output, inverse_indices
+        else:
+            return output
 
     def __rsub__(self, other):
         return _C._VariableFunctions.rsub(self, other)