From 20b63aa9776c7dbe8358daf1111464cb30ce08b6 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Mon, 1 Apr 2019 15:33:35 -0700
Subject: [PATCH] Peephole Optimize Shape Ops (#18549)

Summary:
Peephole optimize ops that require only a DimensionedTensorType, which is what we specialize graphs on.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18549

Differential Revision: D14690827

Pulled By: eellison

fbshipit-source-id: 9d7439eb584f0a5b877f5aa53cf80150f00e7e5f
---
 test/test_jit.py                   | 56 ++++++++++++++++++++++++++++++++++++++
 torch/csrc/jit/passes/peephole.cpp | 30 ++++++++++++++++++++
 2 files changed, 86 insertions(+)

diff --git a/test/test_jit.py b/test/test_jit.py
index 7820a3c..89d3592 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -756,6 +756,62 @@ class TestJit(JitTestCase):
         self.run_pass('peephole', trace.graph)
         self.assertTrue(len(list(trace.graph.nodes())) == 0)
 
+    def test_peephole_optimize_shape_ops(self):
+        def test_input(func, input, result):
+            self.assertEqual(func(input), result)
+            gre = func.graph_for(input)
+            FileCheck().check_not("prim::If").run(gre)
+
+        def test_dim():
+            @torch.jit.script
+            def func(x):
+                if x.dim() == 1:
+                    return 1
+                else:
+                    return 2
+
+            test_input(func, torch.tensor([0.5]), 1)
+            test_input(func, torch.tensor([[0.5]]), 2)
+        test_dim()
+
+        def test_dtype():
+            @torch.jit.script
+            def func(x):
+                if x.dtype == torch.float32:
+                    return 1
+                else:
+                    return 2
+
+            test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
+            test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
+        test_dtype()
+
+        def test_device():
+            @torch.jit.script
+            def func_1(x):
+                if x.device == torch.device('cuda:0'):
+                    a = 0
+                else:
+                    a = 1
+                return a
+
+            @torch.jit.script
+            def func_2(x):
+                if x.is_cuda:
+                    a = 0
+                else:
+                    a = 1
+                return a
+
+            test_input(func_1, torch.tensor(0.5), 1)
+            test_input(func_2, torch.tensor(0.5), 1)
+
+            if RUN_CUDA:
+                test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
+                test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)
+
+        test_device()
+
     def test_index(self):
         x = torch.tensor([0.4], requires_grad=True)
         y = torch.tensor([0], dtype=torch.int64)
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index 5a55b97..ff2bd6d 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -208,6 +208,36 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) {
       if (input->mustNotBeNone()) {
         node->output()->replaceAllUsesWith(node->input());
       }
+    } else if (node->matches("prim::dtype(Tensor a) -> int")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output = node->owningGraph()->insertConstant(
+            static_cast<int64_t>(dim_tensor->scalarType()));
+        node->output()->replaceAllUsesWith(output);
+      }
+    } else if (node->matches("prim::device(Tensor a) -> Device")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output = node->owningGraph()->insertConstant(dim_tensor->device());
+        node->output()->replaceAllUsesWith(output);
+      }
+    } else if (node->matches("aten::dim(Tensor self) -> int")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output = node->owningGraph()->insertConstant(dim_tensor->dim());
+        node->output()->replaceAllUsesWith(output);
+      }
+    } else if (node->matches("prim::is_cuda(Tensor a) -> bool")) {
+      if (auto dim_tensor =
+              node->input()->type()->cast<DimensionedTensorType>()) {
+        WithInsertPoint guard(node);
+        auto output =
+            node->owningGraph()->insertConstant(dim_tensor->device().is_cuda());
+        node->output()->replaceAllUsesWith(output);
+      }
     }
   }
 }
-- 
2.7.4
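
Note (not part of the patch): a minimal sketch of what these new peephole rules do, as seen from Python. Once a scripted function's graph has been specialized on an example input, its tensors carry a DimensionedTensorType, so aten::dim, prim::dtype, prim::device, and prim::is_cuda fold to constants, and constant propagation then prunes the dead prim::If branch, assuming the standard optimization pipeline runs as in the tests above. The function name rank_branch is illustrative, not from the patch.

    import torch
    from torch.testing import FileCheck

    @torch.jit.script
    def rank_branch(x):
        # The branch condition depends only on the tensor's rank, which is
        # known once the graph is specialized on the example input.
        if x.dim() == 2:
            return x.sum()
        else:
            return x

    y = torch.rand(3, 4)
    rank_branch(y)  # run once so the graph is specialized for a 2-D tensor

    # After the peephole pass and constant propagation, the conditional
    # has been folded away in the optimized graph:
    FileCheck().check_not("prim::If").run(rank_branch.graph_for(y))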