Actually model scalar type promotion in shape analysis (#18811)
author James Reed <jamesreed@fb.com>
Thu, 4 Apr 2019 19:53:44 +0000 (12:53 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 19:56:40 +0000 (12:56 -0700)
Summary:
Shape analysis did not model scalar type promotion for binary ops, which was causing numerical issues in the fuser (see the sketch below).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18811

Differential Revision: D14767390

Pulled By: jamesr66a

fbshipit-source-id: f1123d1aab5501abad850d2edc996f8aa8dafe04

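For context, a minimal Python sketch of the promotion behavior the pass now models (not part of this patch; assumes a recent PyTorch build where torch.promote_types is available):

    import torch

    # Tensor-Tensor binary ops take the promoted scalar type of their inputs,
    # mirroring the c10::promoteTypes call added in shape_analysis.cpp below.
    x = torch.rand(3, 4, dtype=torch.float)
    y = torch.rand(3, 4, dtype=torch.double)
    print((x + y).dtype)                                    # torch.float64
    print(torch.promote_types(torch.float, torch.double))   # torch.float64
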
test/test_jit.py
torch/csrc/jit/passes/shape_analysis.cpp

index 931f6e8..4c38dff 100644
@@ -4314,6 +4314,24 @@ a")
         a = next(graph.outputs()).type().kind()
         self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
 
+    def test_shape_prop_promotion(self):
+        @torch.jit.script
+        def fn(x, y):
+            return x + y
+
+        x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
+        graph = fn._get_method('forward').propagate_shapes((x, y), False)
+        FileCheck().check('Double(*, *) = aten::add').run(graph)
+
+    def test_shape_prop_promote_scalar_arg(self):
+        @torch.jit.script
+        def fn(x):
+            return math.pi + x
+
+        x = torch.zeros(3, 4, dtype=torch.long)
+        graph = fn._get_method('forward').propagate_shapes((x,), False)
+        FileCheck().check('Long(*, *) = aten::add').run(graph)
+
     def test_integral_shape_inference(self):
         cu = torch.jit.CompilationUnit('''
         def test_integral_shape_inference(a):
index 5f63188..fb334c2 100644
@@ -767,6 +767,42 @@ class ShapePropagator {
             "aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
             "aten::__irshift__(Tensor self, Tensor other) -> Tensor",
 
+            // Ops with Tensor-Tensor overloads only
+            "aten::atan2(Tensor self, Tensor other) -> Tensor",
+        },
+        [this](Node* node) -> type_vec_t {
+          if (auto maybe_tensor_types =
+                  gatherTensorTypes<DimensionedTensorType>(node)) {
+            AT_ASSERT(maybe_tensor_types->size() == 2);
+            auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
+            auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
+            size_t arg_for_type = 0;
+            if (c10::promoteTypes(first_scalar_type, second_scalar_type) != first_scalar_type) {
+              arg_for_type = 1;
+            }
+            return {broadcast(*maybe_tensor_types, arg_for_type)};
+          }
+          return {};
+        }};
+
+    static const register_formula_for fused_accum_binary_ops{
+        {
+            // Non-binary ops
+            "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
+            "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
+        },
+        [this](Node* node) -> type_vec_t {
+          if (auto maybe_tensor_types =
+                  gatherTensorTypes<DimensionedTensorType>(node)) {
+            return {broadcast(*maybe_tensor_types, 0)};
+          }
+          return {};
+        }};
+
+    // NB: we always take the scalar type of the Tensor
+    static const register_formula_for broadcasting_tensor_scalar_ops{
+        {
+
             // Tensor-Scalar operators
             "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
             "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
@@ -786,13 +822,6 @@ class ShapePropagator {
             "aten::__ixor__(Tensor self, Scalar other) -> Tensor",
             "aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
             "aten::__irshift__(Tensor self, Scalar other) -> Tensor",
-
-            // Ops with Tensor-Tensor overloads only
-            "aten::atan2(Tensor self, Tensor other) -> Tensor",
-
-            // Non-binary ops
-            "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
-            "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
         },
         [this](Node* node) -> type_vec_t {
           if (auto maybe_tensor_types =