[nnc] Fix half2float conversion and re-enable float16 (#64199)
authorBert Maher <bertrand@fb.com>
Tue, 31 Aug 2021 01:36:33 +0000 (18:36 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 01:37:55 +0000 (18:37 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64199

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30643865

Pulled By: bertmaher

fbshipit-source-id: 9de6adca53bd08839328cbaf6364f7de9550264b

test/test_jit_fuser_te.py
test/test_tensorexpr.py
torch/csrc/jit/passes/tensorexpr_fuser.cpp
torch/csrc/jit/tensorexpr/half_support.h
torch/csrc/jit/tensorexpr/ir_verifier.cpp

index 6d2432a..918cc70 100644 (file)
@@ -94,8 +94,7 @@ class TestTEFuser(JitTestCase):
             torch.bool,
         ]
         self.fp_dtypes = [
-            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
-            # torch.float16,
+            torch.float16,
             torch.float32,
             torch.float64,
         ]
@@ -1130,8 +1129,7 @@ class TestTEFuser(JitTestCase):
         dtypes = [
             torch.bool,
             torch.int,
-            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
-            # torch.float16,
+            torch.float16,
             torch.float32,
             torch.float64,
         ]
@@ -1146,6 +1144,9 @@ class TestTEFuser(JitTestCase):
 
         bad_dtypes = []
         for dtype, output_dtype, device, size in product(dtypes, dtypes, self.devices, sizes):
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype == torch.float16 and device == "cpu":
+                continue
             if dtype == output_dtype:
                 continue
 
@@ -1201,18 +1202,16 @@ class TestTEFuser(JitTestCase):
             torch.int16,
             torch.int32,
             torch.int64,
-            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
-            # torch.float16,
+            torch.float16,
             torch.float32,
             torch.float64,
             torch.bool,
         ]
 
         for inp, device, dtype in product(inputs, self.devices, dtypes):
-            # TODO
-            if dtype == torch.float16 and not LLVM_ENABLED:
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype == torch.float16 and device == "cpu":
                 continue
-
             inp = inp.to(device=device, dtype=dtype)
             try:
                 f = torch.jit.trace(lambda x: x.isnan(), (inp,))
@@ -1272,6 +1271,9 @@ class TestTEFuser(JitTestCase):
         gpu_only = {torch.erf, torch.erfc}
         sizes = [(1,), (2,), (4, 4)]
         for dtype, op, device, size in product(self.dtypes, unary_ops, self.devices, sizes):
+            # TODO: Add back when https://github.com/pytorch/pytorch/issues/55905 is closed
+            if dtype == torch.float16 and device == "cpu":
+                continue
             if op in gpu_only and device == "cpu":
                 continue
             try:
@@ -1323,6 +1325,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, binary_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 y = self.data_for(dtype, device)
@@ -1373,6 +1377,8 @@ class TestTEFuser(JitTestCase):
                                      "[[10, 3, 4], [4, 5]]",
                                      ]
         for dtype, size, device in product(self.dtypes, sizes, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 size_x, size_y = size
                 x = self.data_for(dtype, device, size=size_x)
@@ -1417,6 +1423,8 @@ class TestTEFuser(JitTestCase):
         # only using  scalar values relevant to particular ops
         scalars = [1.5, 3, 0, -2.0, -1]
         for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 fn = apply_with_scalar(op, scalar)
@@ -1449,6 +1457,8 @@ class TestTEFuser(JitTestCase):
         # only using  scalar values relevant to particular ops
         scalars = [1.5, 3, -2.0, -1]  # skip 0
         for dtype, op, device, scalar in product(self.dtypes, binary_ops, devices, scalars):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 fn = apply_with_scalar(op, scalar)
@@ -1484,6 +1494,8 @@ class TestTEFuser(JitTestCase):
         # only using  scalar values relevant to particular ops
         scalars = [1.5, 3, 0, -2.0, -1]
         for dtype, op, device, scalar in product(dtypes, binary_ops, self.devices, scalars):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 fn = apply_with_scalar(op, scalar)
@@ -1512,6 +1524,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, ternary_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device)
                 y = self.data_for(dtype, device)
@@ -1541,6 +1555,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, ternary_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device, size=[5, 3, 128, 128])
                 y = self.data_for(dtype, device, size=[3])
@@ -1572,6 +1588,8 @@ class TestTEFuser(JitTestCase):
             torch.cat,
         ]
         for dtype, op, device in product(self.dtypes, list_ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 x = self.data_for(dtype, device, size=[5, 4, 1, 7])
                 y = self.data_for(dtype, device, size=[5, 4, 1, 7])
@@ -1603,6 +1621,8 @@ class TestTEFuser(JitTestCase):
         ]
         devices = self.devices
         for dtype, op, device in product(self.dtypes, ops, devices):
+            if dtype == torch.float16 and device == "cpu":
+                continue
             try:
                 cond = self.data_for(torch.bool, device)
                 x = self.data_for(dtype, device)
@@ -1768,7 +1788,10 @@ class TestTEFuser(JitTestCase):
         with inline_fusion_groups():
             def eager(x, y):
                 return torch.cat((x, y.type_as(x)), dim=1)
-            for dtype1, dtype2 in product(self.dtypes, self.dtypes):
+            dtypes = self.dtypes.copy()
+            # CPU fuser doesn't support float16.
+            dtypes.remove(torch.float16)
+            for dtype1, dtype2 in product(dtypes, dtypes):
                 x = torch.randint(2, (1, 13,)).to(dtype1)
                 zero = torch.tensor([[0]]).to(dtype2)
                 one = torch.tensor([[1]]).to(dtype2)
@@ -1936,7 +1959,6 @@ class TestTEFuser(JitTestCase):
             for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
                 test(fn, (i, x))
 
-
 works_list = [
     '__radd__',
     '__rdiv__',
index 47c7e68..366c262 100644 (file)
@@ -1222,7 +1222,6 @@ class TestTensorExprFuser(BaseTestClass):
             x = warmup_and_run_forward(traced, a, b)
             self.assertLastGraphAllFused()
 
-    @unittest.skip("float16 is not supported yet.")
     def test_half_bn_relu(self):
         devices = ["cuda"] if torch.cuda.is_available() else []
 
index 1d5128c..a3e3707 100644 (file)
@@ -966,7 +966,7 @@ class TensorExprFuser {
         // but on top of that Float16 has a few kinks on LLVM.  Thus, on CPU we
         // additionally disable it until we either move to a more stable version
         // or find workarounds.
-        if (*st == c10::ScalarType::Half) {
+        if (*st == c10::ScalarType::Half && *device == c10::kCPU) {
           return false;
         }
 
index eaf74d3..674af8a 100644 (file)
@@ -128,6 +128,46 @@ class HalfRewriter : public IRMutator {
     return v;
   }
 
+  template <typename T>
+  ExprPtr mutateArithmetic(T v) {
+    IRMutator::mutate(v);
+    if (v->dtype().scalar_type() == c10::kHalf) {
+      v->set_dtype(v->dtype().cloneWithScalarType(c10::kFloat));
+    }
+    return v;
+  }
+
+  ExprPtr mutate(AddPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(SubPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(MulPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(DivPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(MaxPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(MinPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(CompareSelectPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(BroadcastPtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(IfThenElsePtr v) override {
+    return mutateArithmetic(v);
+  }
+  ExprPtr mutate(IntrinsicsPtr v) override {
+    return mutateArithmetic(v);
+  }
+
  private:
   std::unordered_set<ExprPtr> inserted_half_casts_;
   std::unordered_map<VarPtr, VarPtr> var_map;
index f7adbde..f31a935 100644 (file)
@@ -119,7 +119,19 @@ void IRVerifier::visit(IfThenElsePtr v) {
 }
 
 void IRVerifier::visit(IntrinsicsPtr v) {
+  if (v->op_type() == kIsNan) {
+    if (v->dtype().scalar_type() != c10::kInt) {
+      throw malformed_ir("bad dtype in intrinsic arg");
+    }
+    IRVisitor::visit(v);
+    return;
+  }
   // TODO: add a check for OpArgCount and op_type
+  for (auto const& param : v->params()) {
+    if (param->dtype() != v->dtype()) {
+      throw malformed_ir("bad dtype in intrinsic arg");
+    }
+  }
   IRVisitor::visit(v);
 }