From ebc0aacf83a0446ed798a96059c05da815c73d3d Mon Sep 17 00:00:00 2001
From: Bert Maher
Date: Mon, 30 Aug 2021 18:36:33 -0700
Subject: [PATCH] [nnc] Fix half2float conversion and re-enable float16
 (#64199)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64199

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30643865

Pulled By: bertmaher

fbshipit-source-id: 9de6adca53bd08839328cbaf6364f7de9550264b
---
 test/test_jit_fuser_te.py                  | 44 ++++++++++++++++++++++--------
 test/test_tensorexpr.py                    |  1 -
 torch/csrc/jit/passes/tensorexpr_fuser.cpp |  2 +-
 torch/csrc/jit/tensorexpr/half_support.h   | 40 +++++++++++++++++++++++++++
 torch/csrc/jit/tensorexpr/ir_verifier.cpp  | 12 ++++++++
 5 files changed, 86 insertions(+), 13 deletions(-)

diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index 6d2432a..918cc70 100644
--- a/test/test_jit_fuser_te.py
+++ b/test/test_jit_fuser_te.py
@@ -94,8 +94,7 @@ class TestTEFuser(JitTestCase):
             torch.bool,
         ]
         self.fp_dtypes = [
-            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
-            # torch.float16,
+            torch.float16,
             torch.float32,
             torch.float64,
         ]
@@ -1130,8 +1129,7 @@ class TestTEFuser(JitTestCase):
         dtypes = [
             torch.bool,
             torch.int,
-            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
-            # torch.float16,
+            torch.float16,
             torch.float32,
             torch.float64,
         ]
@@ -1146,6 +1144,9 @@ class TestTEFuser(JitTestCase):
         bad_dtypes = []
         for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes):
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype == torch.float16 and device == "cpu":
+                continue
             if dtype == output_dtype:
                 continue
@@ -1201,18 +1202,16 @@ class TestTEFuser(JitTestCase):
             torch.int16,
             torch.int32,
             torch.int64,
-            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
-            # torch.float16,
+            torch.float16,
             torch.float32,
             torch.float64,
             torch.bool,
         ]
         for inp, device, dtype in product(inputs, self.devices, dtypes):
-            # TODO
-            if dtype == torch.float16 and not LLVM_ENABLED:
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype == torch.float16 and device == "cpu":
                 continue
-
             inp = inp.to(device=device, dtype=dtype)
             try:
                 f = torch.jit.trace(lambda x: x.isnan(), (inp,))
@@ -1272,6 +1271,9 @@ class TestTEFuser(JitTestCase):
         gpu_only = {torch.erf, torch.erfc}
         sizes = [(1,), (2,), (4, 4)]
         for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype == torch.float16 and device == "cpu":
+                continue
             if op in gpu_only and device == "cpu":
                 continue
             try:
@@ -1323,6 +1325,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, binary_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 y = self.data_for(dtype, device)
@@ -1373,6 +1377,8 @@ class TestTEFuser(JitTestCase):
             "[[10, 3, 4], [4, 5]]",
         ]
         for dtype, size, device in product(self.dtypes, sizes, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 size_x, size_y = size
                 x = self.data_for(dtype, device, size=size_x)
@@ -1417,6 +1423,8 @@ class TestTEFuser(JitTestCase):
         # only using scalar values relevant to particular ops
         scalars = [1.5, 3, 0, -2.0, -1]
         for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 fn = apply_with_scalar(op, scalar)
@@ -1449,6 +1457,8 @@ class TestTEFuser(JitTestCase):
         # only using scalar values relevant to particular ops
         scalars = [1.5, 3, -2.0, -1]  # skip 0
         for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 fn = apply_with_scalar(op, scalar)
@@ -1484,6 +1494,8 @@ class TestTEFuser(JitTestCase):
         # only using scalar values relevant to particular ops
         scalars = [1.5, 3, 0, -2.0, -1]
         for dtype, op, device, scalar in product(dtypes, binary_ops, self.devices, scalars):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 fn = apply_with_scalar(op, scalar)
@@ -1512,6 +1524,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, ternary_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 y = self.data_for(dtype, device)
@@ -1541,6 +1555,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, ternary_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device, size=[5, 3, 128, 128])
                 y = self.data_for(dtype, device, size=[3])
@@ -1572,6 +1588,8 @@ class TestTEFuser(JitTestCase):
             torch.cat,
         ]
         for dtype, op, device in product(self.dtypes, list_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device, size=[5, 4, 1, 7])
                 y = self.data_for(dtype, device, size=[5, 4, 1, 7])
@@ -1603,6 +1621,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 cond = self.data_for(torch.bool, device)
                 x = self.data_for(dtype, device)
@@ -1768,7 +1788,10 @@ class TestTEFuser(JitTestCase):
         with inline_fusion_groups():
             def eager(x, y):
                 return torch.cat((x, y.type_as(x)), dim=1)
-            for dtype1, dtype2 in product(self.dtypes, self.dtypes):
+            dtypes = self.dtypes.copy()
+            # CPU fuser doesn't support float16.
+            dtypes.remove(torch.float16)
+            for dtype1, dtype2 in product(dtypes, dtypes):
                 x = torch.randint(2, (1, 13,)).to(dtype1)
                 zero = torch.tensor([[0]]).to(dtype2)
                 one = torch.tensor([[1]]).to(dtype2)
@@ -1936,7 +1959,6 @@ class TestTEFuser(JitTestCase):
         for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
             test(fn, (i, x))
 
-
 works_list = [
     '__radd__',
     '__rdiv__',
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index 47c7e68..366c262 100644
--- a/test/test_tensorexpr.py
+++ b/test/test_tensorexpr.py
@@ -1222,7 +1222,6 @@ class TestTensorExprFuser(BaseTestClass):
         x = warmup_and_run_forward(traced, a, b)
         self.assertLastGraphAllFused()
 
-    @unittest.skip("float16 is not supported yet.")
     def test_half_bn_relu(self):
         devices = ["cuda"] if torch.cuda.is_available() else []
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 1d5128c..a3e3707 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -966,7 +966,7 @@ class TensorExprFuser {
       // but on top of that Float16 has a few kinks on LLVM. Thus, on CPU we
       // additionally disable it until we either move to a more stable version
       // or find workarounds.
-      if (*st == c10::ScalarType::Half) {
+      if (*st == c10::ScalarType::Half && *device == c10::kCPU) {
         return false;
       }
diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h
index eaf74d3..674af8a 100644
--- a/torch/csrc/jit/tensorexpr/half_support.h
+++ b/torch/csrc/jit/tensorexpr/half_support.h
@@ -128,6 +128,46 @@ class HalfRewriter : public IRMutator {
     return v;
   }
 
+  template <typename T>
+  ExprPtr mutateArithmetic(T v) {
+    IRMutator::mutate(v);
+    if (v->dtype().scalar_type() == c10::kHalf) {
+      v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat));
+    }
+    return v;
+  }
+
+  ExprPtr mutate(AddPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(SubPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(MulPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(DivPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(MaxPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(MinPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(CompareSelectPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(BroadcastPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(IfThenElsePtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(IntrinsicsPtr v) override {
+    return mutateArithmetic(v);
+  }
+
  private:
   std::unordered_set<ExprPtr> inserted_half_casts_;
   std::unordered_map<VarPtr, VarPtr> var_map;
diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.cpp b/torch/csrc/jit/tensorexpr/ir_verifier.cpp
index f7adbde..f31a935 100644
--- a/torch/csrc/jit/tensorexpr/ir_verifier.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_verifier.cpp
@@ -119,7 +119,19 @@ void IRVerifier::visit(IfThenElsePtr v) {
 }
 
 void IRVerifier::visit(IntrinsicsPtr v) {
+  if (v->op_type() == kIsNan) {
+    if (v->dtype().scalar_type() != c10::kInt) {
+      throw malformed_ir("bad dtype in intrinsic arg");
+    }
+    IRVisitor::visit(v);
+    return;
+  }
   // TODO: add a check for OpArgCount and op_type
+  for (auto const& param : v->params()) {
+    if (param->dtype() != v->dtype()) {
+      throw malformed_ir("bad dtype in intrinsic arg");
+    }
+  }
   IRVisitor::visit(v);
 }
 
-- 
2.7.4
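
Note (illustrative, not part of the patch): the snippet below is a minimal Python sketch of the kind of half-precision pointwise graph that the re-enabled float16 tests exercise, assuming a CUDA build, since the patch keeps float16 fusion disabled for the CPU fuser. The function name bn_relu, the shapes, and the warm-up loop count are arbitrary choices for illustration, not code from this change.

import torch

@torch.jit.script
def bn_relu(x, mean, var):
    # A normalize-then-relu chain of pointwise ops; once the profiling
    # executor has specialized the graph for float16 inputs, NNC may fuse
    # these into a single kernel.
    return ((x - mean) / torch.sqrt(var + 1e-5)).relu()

if torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda", dtype=torch.float16)
    mean = torch.zeros(16, device="cuda", dtype=torch.float16)
    var = torch.ones(16, device="cuda", dtype=torch.float16)
    for _ in range(3):  # warm up so the fuser can kick in
        out = bn_relu(x, mean, var)
    print(out.dtype)  # torch.float16

Internally, the HalfRewriter change above promotes such half-valued arithmetic to float for computation and only keeps half at loads and stores, which is what "half2float conversion" refers to in the title.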