Add better support for bools in the graph fuser (#15057)
authorRichard Zou <zou3519@gmail.com>
Wed, 12 Dec 2018 17:37:10 +0000 (09:37 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 17:39:44 +0000 (09:39 -0800)
Summary:
Fixes #15038.

aten::_cast_Float(tensor, non_blocking) support was added in #14336.
Its second argument is a bool, but because we don't support generating values
of type bool in the fuser codegen, the codegen errored out.

aten::_cast_Float in the fuser never actually uses its non_blocking
argument, so another way to fix this would be to have a special op for a
fused cast but I thought that we might have fusible ops that do take
bool arguments in the future so this would be good to have.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15057

Differential Revision: D13432091

Pulled By: zou3519

fbshipit-source-id: 455fe574f5f080aca9a112e346b841a2534a8dc3

test/test_jit.py
torch/csrc/jit/fuser/codegen.cpp

index 30f319e..e6696cb 100644 (file)
@@ -9612,6 +9612,31 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(scaleshift, inputs)
         self.assertExpectedGraph(ge.graph_for(*inputs))
 
+    @staticmethod
+    def _test_cast_Float(self, device):
+        def f(x, y):
+            z = x.float()
+            return z + y
+
+        inputs = [
+            torch.randn(4, 4, dtype=torch.double, device=device),
+            torch.randn(4, 4, dtype=torch.float, device=device),
+        ]
+
+        ge = self.checkScript(f, inputs)
+        self.assertAllFused(ge.graph_for(*inputs))
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_cast_Float(self):
+        return self._test_cast_Float(self, 'cpu')
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(not RUN_CUDA, "No CUDA")
+    @skipIfRocm
+    def test_cast_Float_cuda(self):
+        return self._test_cast_Float(self, 'cuda')
+
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
index 8120f45..ebfee5c 100644 (file)
@@ -94,6 +94,8 @@ static std::string variableType(const std::shared_ptr<c10::Type> t) {
     return "int";
   } else if (t->kind() == TypeKind::FloatType) {
     return "float";
+  } else if (t->kind() == TypeKind::BoolType) {
+    return "bool";
   } else if (t->kind() == TypeKind::TensorType) {
     auto const tt = t->cast<TensorType>();
     return calcScalarTypeName(tt->scalarType());
@@ -103,7 +105,7 @@ static std::string variableType(const std::shared_ptr<c10::Type> t) {
 }
 
 static std::string typeCastedValueName(const std::shared_ptr<c10::Type> t, const at::ScalarType outtype, const std::string& vn) {
-  if (t->kind() == TypeKind::IntType) {
+  if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
     if (! isIntegralType(outtype)) {
       return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
     }