Revert D15003387: Remove 'BoolTensor', 'IndexTensor' from frontend specifications.
authorGregory Chanan <gchanan@fb.com>
Fri, 19 Apr 2019 18:23:52 +0000 (11:23 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 19 Apr 2019 18:27:10 +0000 (11:27 -0700)
Differential Revision:
D15003387

Original commit changeset: e518e8ce3228

fbshipit-source-id: af5b107239446ea8d6f229a427d5b157fcafd224

aten/src/ATen/function_wrapper.py
aten/src/ATen/native/README.md
tools/autograd/gen_variable_type.py

index 21943f0..bb6003a 100644 (file)
@@ -969,6 +969,8 @@ def create_generic(top_env, declarations):
                 # type: (bool) -> Dict[str, str]
                 return {
                     'Tensor': 'const Tensor &' if const else 'Tensor &',
+                    'BoolTensor': 'const Tensor &' if const else 'Tensor &',
+                    'IndexTensor': 'const Tensor &' if const else 'Tensor &',
                     'Type': 'const Type &' if const else 'Type &',
                     'TensorOptions': 'const TensorOptions &' if const else 'TensorOptions &',
                     'TensorList': 'TensorList',
index 3e5b736..43e48ef 100644 (file)
@@ -57,6 +57,12 @@ signature.
   `Tensor!` - shorthand for Tensor(fresh\_identifier!)
   `Tensor(a! -> a|b)` - Tensor is in set `a`, written to, and after the write is in set `a` AND `b`.
   For more details on when and why this needs to happen, please see the section on annotations.
+- Tensors of specific types.  At the moment, valid type names are:
+    - `IntegerTensor` (a.k.a. `LongTensor`)
+    - `BoolTensor` (a.k.a. `ByteTensor`)
+    - `IndexTensor` (a.k.a. `IntTensor`)
+  These type names were inherited from TH, and may be renamed soon, so
+  don't commit them to memory.
 - `Tensor[]`.  A `Tensor[]` argument translates into a C++ argument of type `ArrayRef<Tensor>`
   (a.k.a. `TensorList`)
 - `int[]`.  `int[]` accepts an optional length specifier, e.g., `int[2]`, which
index 818cffa..c1ebb46 100644 (file)
@@ -478,16 +478,18 @@ def emit_body(declaration):
     base_name = name[:-1] if inplace else name[:-4] if is_out_fn else name
     view_info = VIEW_FUNCTIONS.get(base_name, None)
 
+    # These exclude things like BoolTensor, int64_t, and Scalar
     def is_differentiable(arg):
         if 'TensorOptions' in arg['type']:
             return False
         if 'Tensor' not in arg['type']:
             return False
         if arg['dynamic_type'] in {'IndexTensor', 'BoolTensor'}:
+            # TODO: Enable this after native_functions.yaml schema unification.
             # These are necessary for legacy code and should be
             # used by legacy code only!
-            assert declaration['mode'] == 'TH' or declaration['mode'] == 'NN', \
-                "IndexTensor and BoolTensor are restricted to legacy TH/THNN functions only."
+            # assert name.startswith('_th_'), \
+            # "IndexTensor and BoolTensor are restricted to legacy _th_ functions only.
             return False
         return True