properly device_guard IndexTensor and BoolTensor. (#18072)
authorGregory Chanan <gchanan@fb.com>
Sun, 17 Mar 2019 22:37:41 +0000 (15:37 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sun, 17 Mar 2019 22:40:39 +0000 (15:40 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18072
ghimport-source-id: 9653731602c72f299e095dd50e3afe6bcc8b01d6

Stack:
* **#18072 properly device_guard IndexTensor and BoolTensor.**
* #18073 Change one_hot from IndexTensor to Tensor.

Currently IndexTensor and BoolTensors do not have device_guards applied to them.
This is bad in the case where the only tensor(s) are IndexTensors or BoolTensors, because no device guard is present.

The only case this currently happens is with one_hot which ends up not mattering because of the way the implementation is written.  But I wanted to make sure we are covered here.

Reviewed By: ezyang

Differential Revision: D14485249

fbshipit-source-id: e57b28086fa1ad2fdd248bb1220e8a2e42da03e1

aten/src/ATen/function_wrapper.py

index 0fc117d..c3182ee 100644 (file)
@@ -731,14 +731,19 @@ def create_generic(top_env, declarations):
     def find_dispatch_tensor(formals):
         # type: (List[AtFormal]) -> Optional[str]
         # dispatch to self if it's a parameter
+        def is_any_tensor_type(formal):
+            return (formal['dynamic_type'] == 'Tensor' or formal['dynamic_type'] == 'BoolTensor'
+                    or formal['dynamic_type'] == 'IndexTensor')
+
         for formal in formals:
-            if formal['name'] == 'self' and formal['dynamic_type'] == 'Tensor' and not formal.get('is_nullable', False):
+            if formal['name'] == 'self' and is_any_tensor_type(formal) and not formal.get('is_nullable', False):
                 return formal['name']
         # otherwise dispatch to the first Tensor or TensorList
         for formal in formals:
-            if 'TensorList' == formal['dynamic_type'] or formal['dynamic_type'] == 'Tensor' and \
+            if 'TensorList' == formal['dynamic_type'] or is_any_tensor_type(formal) and \
                not formal.get('is_nullable', False):
                 return formal['name']
+
         return None
 
     def format_formal(f):