Peephole Optimize Shape Ops (#18549)
author Elias Ellison <eellison@fb.com>
Mon, 1 Apr 2019 22:33:35 +0000 (15:33 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 1 Apr 2019 22:39:43 +0000 (15:39 -0700)
Summary:
Peephole-optimize ops that require only DimensionedTensorType, which is what we specialize graphs on: aten::dim, prim::dtype, prim::device, and prim::is_cuda now fold to constants when the input tensor's specialized type is known.
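
A minimal sketch (not part of the commit) of the effect on a scripted function: once the graph is specialized on the input's rank, x.dim() folds to a constant, the comparison folds with it, and the prim::If is eliminated.

    import torch

    @torch.jit.script
    def fn(x):
        # With this pass, x.dim() constant-folds once the graph is
        # specialized on a 1-D input, so the branch disappears.
        if x.dim() == 1:
            return x + 1
        else:
            return x - 1

    fn(torch.ones(3))                    # first run specializes the graph
    print(fn.graph_for(torch.ones(3)))   # optimized graph has no prim::If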
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18549

Differential Revision: D14690827

Pulled By: eellison

fbshipit-source-id: 9d7439eb584f0a5b877f5aa53cf80150f00e7e5f

test/test_jit.py
torch/csrc/jit/passes/peephole.cpp

index 7820a3c..89d3592 100644
@@ -756,6 +756,62 @@ class TestJit(JitTestCase):
         self.run_pass('peephole', trace.graph)
         self.assertTrue(len(list(trace.graph.nodes())) == 0)
 
+    def test_peephole_optimize_shape_ops(self):
+        def test_input(func, input, result):
+            self.assertEqual(func(input), result)
+            gre = func.graph_for(input)
+            FileCheck().check_not("prim::If").run(gre)
+
+        def test_dim():
+            @torch.jit.script
+            def func(x):
+                if x.dim() == 1:
+                    return 1
+                else:
+                    return 2
+
+            test_input(func, torch.tensor([0.5]), 1)
+            test_input(func, torch.tensor([[0.5]]), 2)
+        test_dim()
+
+        def test_dtype():
+            @torch.jit.script
+            def func(x):
+                if x.dtype == torch.float32:
+                    return 1
+                else:
+                    return 2
+
+            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
+            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
+        test_dtype()
+
+        def test_device():
+            @torch.jit.script
+            def func_1(x):
+                if x.device == torch.device('cuda:0'):
+                    a = 0
+                else:
+                    a = 1
+                return a
+
+            @torch.jit.script
+            def func_2(x):
+                if x.is_cuda:
+                    a = 0
+                else:
+                    a = 1
+                return a
+
+            test_input(func_1, torch.tensor(0.5), 1)
+            test_input(func_2, torch.tensor(0.5), 1)
+
+            if RUN_CUDA:
+                test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
+                test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)
+
+        test_device()
+
     def test_index(self):
         x = torch.tensor([0.4], requires_grad=True)
         y = torch.tensor([0], dtype=torch.int64)
index 5a55b97..ff2bd6d 100644
@@ -208,6 +208,36 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
       if (input->mustNotBeNone()) {
         node->output()->replaceAllUsesWith(node->input());
       }
+    } else if (node->matches("prim::dtype(Tensor a) -> int")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output = node->owningGraph()->insertConstant(
+            static_cast<int64_t>(dim_tensor->scalarType()));
+        node->output()->replaceAllUsesWith(output);
+      }
+    } else if (node->matches("prim::device(Tensor a) -> Device")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output = node->owningGraph()->insertConstant(dim_tensor->device());
+        node->output()->replaceAllUsesWith(output);
+      }
+    } else if (node->matches("aten::dim(Tensor self) -> int")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output = node->owningGraph()->insertConstant(dim_tensor->dim());
+        node->output()->replaceAllUsesWith(output);
+      }
+    } else if (node->matches("prim::is_cuda(Tensor a) -> bool")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output =
+            node->owningGraph()->insertConstant(dim_tensor->device().is_cuda());
+        node->output()->replaceAllUsesWith(output);
+      }
     }
   }
 }
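
The four new branches share one fold pattern. A minimal refactoring sketch of that shared shape (illustrative only, not part of the commit; foldShapeOp is a hypothetical helper name, and the JIT declarations are assumed to be in scope as they are in peephole.cpp):

    // Hypothetical helper: if the input's DimensionedTensorType is known,
    // insert a constant computed from it at the node's insertion point and
    // rewire all uses of the node's output to that constant. This mirrors
    // the four branches added above.
    template <typename MakeValue>
    static void foldShapeOp(Node* node, MakeValue makeValue) {
      if (auto tt = node->input()->type()->cast<DimensionedTensorType>()) {
        WithInsertPoint guard(node);
        auto output = node->owningGraph()->insertConstant(makeValue(tt));
        node->output()->replaceAllUsesWith(output);
      }
    }

    // Usage, e.g. for aten::dim:
    //   foldShapeOp(node, [](const std::shared_ptr<DimensionedTensorType>& t) {
    //     return t->dim();
    //   });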