self.run_pass('peephole', trace.graph)
self.assertTrue(len(list(trace.graph.nodes())) == 0)
+ def test_peephole_optimize_shape_ops(self):
+ def test_input(func, input, result):
+ self.assertEqual(func(input), result)
+ gre = func.graph_for(input)
+ FileCheck().check_not("prim::If").run(gre)
+
+ def test_dim():
+ @torch.jit.script
+ def func(x):
+ if x.dim() == 1:
+ return 1
+ else:
+ return 2
+
+ test_input(func, torch.tensor([0.5]), 1)
+ test_input(func, torch.tensor([[0.5]]), 2)
+ test_dim()
+
+ def test_dtype():
+ @torch.jit.script
+ def func(x):
+ if x.dtype == torch.float32:
+ return 1
+ else:
+ return 2
+
+ test_input(func, torch.tensor(0.5, dtype=torch.float32), 1)
+ test_input(func, torch.tensor(0.5, dtype=torch.int64), 2)
+ test_dtype()
+
+ def test_device():
+ @torch.jit.script
+ def func_1(x):
+ if x.device == torch.device('cuda:0'):
+ a = 0
+ else:
+ a = 1
+ return a
+
+ @torch.jit.script
+ def func_2(x):
+ if x.is_cuda:
+ a = 0
+ else:
+ a = 1
+ return a
+
+ test_input(func_1, torch.tensor(0.5), 1)
+ test_input(func_2, torch.tensor(0.5), 1)
+
+ if RUN_CUDA:
+ test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0)
+ test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0)
+
+ test_device()
+
def test_index(self):
x = torch.tensor([0.4], requires_grad=True)
y = torch.tensor([0], dtype=torch.int64)
if (input->mustNotBeNone()) {
node->output()->replaceAllUsesWith(node->input());
}
+ } else if (node->matches("prim::dtype(Tensor a) -> int")) {
+ if (auto dim_tensor =
+ node->input()->type()->cast<DimensionedTensorType>()) {
+ WithInsertPoint guard(node);
+ auto output = node->owningGraph()->insertConstant(
+ static_cast<int64_t>(dim_tensor->scalarType()));
+ node->output()->replaceAllUsesWith(output);
+ }
+ } else if (node->matches("prim::device(Tensor a) -> Device")) {
+ if (auto dim_tensor =
+ node->input()->type()->cast<DimensionedTensorType>()) {
+ WithInsertPoint guard(node);
+ auto output = node->owningGraph()->insertConstant(dim_tensor->device());
+ node->output()->replaceAllUsesWith(output);
+ }
+ } else if (node->matches("aten::dim(Tensor self) -> int")) {
+ if (auto dim_tensor =
+ node->input()->type()->cast<DimensionedTensorType>()) {
+ WithInsertPoint guard(node);
+ auto output = node->owningGraph()->insertConstant(dim_tensor->dim());
+ node->output()->replaceAllUsesWith(output);
+ }
+ } else if (node->matches("prim::is_cuda(Tensor a) -> bool")) {
+ if (auto dim_tensor =
+ node->input()->type()->cast<DimensionedTensorType>()) {
+ WithInsertPoint guard(node);
+ auto output =
+ node->owningGraph()->insertConstant(dim_tensor->device().is_cuda());
+ node->output()->replaceAllUsesWith(output);
+ }
}
}
}