From 0ad39ec5c17864fc7aec2605c82fe82fba71e3c2 Mon Sep 17 00:00:00 2001
From: Richard Zou
Date: Wed, 12 Dec 2018 09:37:10 -0800
Subject: [PATCH] Add better support for bools in the graph fuser (#15057)

Summary:
Fixes #15038.

aten::_cast_Float(tensor, non_blocking) support was added in #14336. Its second
argument is a bool, but because we don't support generating values of type bool
in the fuser codegen, the codegen errored out.

aten::_cast_Float in the fuser never actually uses its non_blocking argument, so
another way to fix this would be to have a special op for a fused cast, but I
thought that we might have fusible ops that do take bool arguments in the future,
so this would be good to have.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15057

Differential Revision: D13432091

Pulled By: zou3519

fbshipit-source-id: 455fe574f5f080aca9a112e346b841a2534a8dc3
---
 test/test_jit.py                 | 25 +++++++++++++++++++++++++
 torch/csrc/jit/fuser/codegen.cpp |  4 +++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/test/test_jit.py b/test/test_jit.py
index 30f319e..e6696cb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -9612,6 +9612,31 @@ class TestFuser(JitTestCase):
         ge = self.checkTrace(scaleshift, inputs)
         self.assertExpectedGraph(ge.graph_for(*inputs))
 
+    @staticmethod
+    def _test_cast_Float(self, device):
+        def f(x, y):
+            z = x.float()
+            return z + y
+
+        inputs = [
+            torch.randn(4, 4, dtype=torch.double, device=device),
+            torch.randn(4, 4, dtype=torch.float, device=device),
+        ]
+
+        ge = self.checkScript(f, inputs)
+        self.assertAllFused(ge.graph_for(*inputs))
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @enable_cpu_fuser
+    def test_cast_Float(self):
+        return self._test_cast_Float(self, 'cpu')
+
+    @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
+    @unittest.skipIf(not RUN_CUDA, "No CUDA")
+    @skipIfRocm
+    def test_cast_Float_cuda(self):
+        return self._test_cast_Float(self, 'cuda')
+
     @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
diff --git a/torch/csrc/jit/fuser/codegen.cpp b/torch/csrc/jit/fuser/codegen.cpp
index 8120f45..ebfee5c 100644
--- a/torch/csrc/jit/fuser/codegen.cpp
+++ b/torch/csrc/jit/fuser/codegen.cpp
@@ -94,6 +94,8 @@ static std::string variableType(const std::shared_ptr t) {
     return "int";
   } else if (t->kind() == TypeKind::FloatType) {
     return "float";
+  } else if (t->kind() == TypeKind::BoolType) {
+    return "bool";
   } else if (t->kind() == TypeKind::TensorType) {
     auto const tt = t->cast();
     return calcScalarTypeName(tt->scalarType());
@@ -103,7 +105,7 @@ static std::string typeCastedValueName(const std::shared_ptr t, const at::ScalarType outtype, const std::string& vn) {
-  if (t->kind() == TypeKind::IntType) {
+  if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
     if (! isIntegralType(outtype)) {
       return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
     }
-- 
2.7.4
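
For context, a minimal standalone repro in the spirit of the new test might look
like the sketch below. It is not part of the patch; it assumes a CUDA build of
PyTorch from around this time with the JIT fuser enabled, and the function name f,
shapes, and dtypes are illustrative.

    import torch

    @torch.jit.script
    def f(x, y):
        # x.float() lowers to aten::_cast_Float(x, non_blocking=False); before this
        # patch, the constant bool argument made fuser code generation error out.
        return x.float() + y

    x = torch.randn(4, 4, dtype=torch.double, device='cuda')
    y = torch.randn(4, 4, dtype=torch.float, device='cuda')
    print(f(x, y))
    # With the patch, the cast and the add should land in a single FusionGroup.
    print(f.graph_for(x, y))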