[TensorExpr] Move 2 graph passes from kernel.cpp to graph_opt.cpp (#64828)
authorMikhail Zolotukhin <mvz@fb.com>
Sat, 11 Sep 2021 17:21:42 +0000 (10:21 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sat, 11 Sep 2021 17:23:15 +0000 (10:23 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64828

Also, make `removeUnusedSelfArgument` more consistent with other passes
by mutating the graph in-place rather than returning a copy.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D30870776

Pulled By: ZolotukhinM

fbshipit-source-id: 4873f01b013921143a5aa43746d655a2d8d620c9

test/test_tensorexpr_pybind.py
torch/csrc/jit/tensorexpr/graph_opt.cpp
torch/csrc/jit/tensorexpr/graph_opt.h
torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/kernel.h
torch/csrc/jit/tensorexpr/tensorexpr_init.cpp

index 6a34805..9a70838 100644 (file)
@@ -166,7 +166,7 @@ graph(%a.1 : Float(requires_grad=0, device=cpu),
 
         graph_str = """
 graph(%a : Tensor, %b : Tensor):
-  %c : Float(4, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::mul(%a, %b)
+  %c : Tensor = aten::mul(%a, %b)
   return (%c)
         """
         graph = torch._C.parse_ir(graph_str)
@@ -184,10 +184,7 @@ graph(%a : Tensor, %b : Tensor):
         # Inject shape info and try compiling again
         example_inputs = [torch.rand(4, 4), torch.rand(4, 4)]
         torch._C._te.annotate_input_shapes(graph, example_inputs)
-
-        # TODO: once we have shape propagation as well we should erase type
-        # info for %c from the input IR and run shape propagation here - it
-        # should be able to reconstruct that info
+        torch._C._jit_pass_propagate_shapes_on_graph(graph)
 
         # Now compilation should pass
         kernel = torch._C._te.TensorExprKernel(graph)
@@ -197,7 +194,6 @@ graph(%a : Tensor, %b : Tensor):
         np.testing.assert_allclose(res.numpy(), correct.numpy(), atol=1e-5)
 
     @unittest.skipIf(not LLVM_ENABLED, "LLVM backend not enabled")
-    @unittest.skip("Does not work until shape propagation is implemented")
     def test_kernel_shape_prop_module(self):
         class TestModule(torch.nn.Module):
             def forward(self, x, y):
@@ -228,14 +224,11 @@ graph(%a : Tensor, %b : Tensor):
         assert exception_thrown
 
         # Remove 'self' argument and try annotating shapes one more time
-        graph = torch._C._te.remove_unused_self_argument(graph)
+        torch._C._te.remove_unused_self_argument(graph)
 
         # Inject shape info and try compiling again
         torch._C._te.annotate_input_shapes(graph, example_inputs)
-
-        # TODO: once we have shape propagation as well we should erase type
-        # info for %c from the input IR and run shape propagation here - it
-        # should be able to reconstruct that info
+        torch._C._jit_pass_propagate_shapes_on_graph(graph)
 
         # Now compilation should pass
         kernel = torch._C._te.TensorExprKernel(graph)
index d55ea05..b264a9e 100644 (file)
@@ -178,6 +178,34 @@ bool OptimizeCat(const std::shared_ptr<Graph>& graph) {
   return false;
 }
 
+void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
+  TORCH_INTERNAL_ASSERT(
+      graph->inputs().size() == example_inputs.size(),
+      buildErrorMessage("Given inputs do not match the fuser graph inputs."));
+  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
+    if (auto t = example_inputs[idx]) {
+      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
+      graph->inputs().at(idx)->setType(concrete_tensor_type);
+    }
+  }
+}
+
+std::shared_ptr<Graph> removeUnusedSelfArgument(
+    const std::shared_ptr<Graph>& graph) {
+  if (graph->inputs().size() == 0) {
+    return graph;
+  }
+  jit::Value* self_argument = graph->inputs().at(0);
+  if (self_argument->uses().size() != 0 ||
+      !self_argument->type()->is_module()) {
+    return graph;
+  }
+  graph->eraseInput(0);
+  return graph;
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
index 5a81553..ab1cbf0 100644 (file)
@@ -58,6 +58,12 @@ namespace tensorexpr {
 
 bool OptimizeCat(const std::shared_ptr<Graph>& graph);
 
+TORCH_API void annotateInputShapes(
+    const std::shared_ptr<Graph>& graph,
+    const std::vector<c10::optional<at::Tensor>>& example_inputs);
+TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
+    const std::shared_ptr<Graph>& graph);
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
index 15ef427..b38de4f 100644 (file)
@@ -374,35 +374,6 @@ bool matmulIsSupported(const torch::jit::Node* node) {
   return true;
 }
 
-void annotateInputShapes(
-    const std::shared_ptr<Graph>& graph,
-    const std::vector<c10::optional<at::Tensor>>& example_inputs) {
-  TORCH_INTERNAL_ASSERT(
-      graph->inputs().size() == example_inputs.size(),
-      buildErrorMessage("Given inputs do not match the fuser graph inputs."));
-  for (size_t idx = 0; idx < example_inputs.size(); idx++) {
-    if (auto t = example_inputs[idx]) {
-      auto concrete_tensor_type = tensorTypeInCurrentExecutionContext(*t);
-      graph->inputs().at(idx)->setType(concrete_tensor_type);
-    }
-  }
-}
-
-std::shared_ptr<Graph> removeUnusedSelfArgument(
-    const std::shared_ptr<Graph>& graph) {
-  if (graph->inputs().size() == 0) {
-    return graph;
-  }
-  jit::Value* self_argument = graph->inputs().at(0);
-  if (self_argument->uses().size() != 0 ||
-      !self_argument->type()->is_module()) {
-    return graph;
-  }
-  std::shared_ptr<Graph> res = graph->copy();
-  res->eraseInput(0);
-  return res;
-}
-
 std::vector<ExprHandle> valueShape(const ArgValue& v) {
   if (auto b = c10::get_if<tensorexpr::BufHandle>(&v)) {
     return b->dims();
index 803dc4a..075ae6c 100644 (file)
@@ -306,12 +306,6 @@ TORCH_API bool& getOptConditionals();
 TORCH_API c10::optional<at::Device> pickDeviceType(
     const at::ArrayRef<torch::jit::Value*>& inputs);
 
-TORCH_API void annotateInputShapes(
-    const std::shared_ptr<Graph>& graph,
-    const std::vector<c10::optional<at::Tensor>>& example_inputs);
-TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
-    const std::shared_ptr<Graph>& graph);
-
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
index 7e93092..27364cf 100644 (file)
@@ -6,6 +6,7 @@
 #ifdef USE_CUDA
 #include <torch/csrc/jit/tensorexpr/cuda_codegen.h>
 #endif
+#include <torch/csrc/jit/tensorexpr/graph_opt.h>
 #include <torch/csrc/jit/tensorexpr/ir_printer.h>
 #include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
 #include <torch/csrc/jit/tensorexpr/kernel.h>