get propagate_shape logic out of module.h (#19137)
author     Zachary DeVito <zdevito@fb.com>
Sat, 13 Apr 2019 15:28:11 +0000 (08:28 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 13 Apr 2019 15:42:17 +0000 (08:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19137
ghimport-source-id: 2394765f2d401e68ffdfa4c985bfab4cca2517f8

This moves the propagate_shapes and propagate_and_assign_input_and_output_shapes helpers out of torch/csrc/jit/script/module.h into torch/csrc/jit/script/init.cpp, exposes them to Python as torch._C._propagate_shapes and torch._C._propagate_and_assign_input_and_output_shapes, and updates the callers in test/test_jit.py and torch/onnx/utils.py (which previously went through the ScriptMethod bindings) accordingly.
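
For reference, a minimal sketch of the new call site (it mirrors the updated tests in the diff below; the scripted function and the input shape are placeholders):

    import torch
    from torch._C import _propagate_shapes

    @torch.jit.script
    def fn(x):
        return x + 1.0

    # _propagate_shapes copies the graph, stamps DimensionedTensorType onto the
    # copy's inputs from the example tensors, and runs PropagateInputShapes on it.
    g = _propagate_shapes(fn.graph, (torch.zeros(3, 4),), False)
    print(next(g.outputs()).type())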

Reviewed By: jamesr66a

Differential Revision: D14885946

Pulled By: zdevito

fbshipit-source-id: daa2894ed9761107e9d273bb172840dc23ace072

test/test_jit.py
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h
torch/onnx/utils.py

test/test_jit.py
index 015c67c..56f0fc2 100644
@@ -42,7 +42,7 @@ from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
     exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
 from torch.testing import FileCheck
-from torch._C import TensorType, parse_ir
+from torch._C import TensorType, parse_ir, _propagate_shapes
 from copy import deepcopy
 import random
 from typing import List, Dict, Optional, Tuple
@@ -3733,8 +3733,8 @@ a")
         # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
         self.run_pass('constant_propagation', func.graph)
         self.run_pass('constant_propagation', func2.graph)
-        g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
-        g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
+        g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
+        g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
         self.assertTrue(g.findNode("aten::sum").output().type().kind()
                         == "DimensionedTensorType")
         self.assertTrue(g2.findNode("aten::sum").output().type().kind()
@@ -4509,8 +4509,8 @@ a")
                 torch.mul(x, y, out=z)
                 return z
 
-            graph = test._get_method('forward').propagate_shapes(
-                (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
+            graph = _propagate_shapes(test.graph,
+                                     (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
             self.assertTrue(next(graph.outputs()).type() == TensorType.get())
         out_op_graph_input()
 
@@ -4529,7 +4529,7 @@ a")
                 return after_resize_alias
 
             self.run_pass('constant_propagation', test.graph)
-            g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
+            g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
             resize_node = g.findNode("aten::resize_")
             # first input and output of b.resize_ is b
             self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
@@ -4553,7 +4553,7 @@ a")
 
             g = test.graph
             self.run_pass('constant_propagation', g)
-            g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
+            g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
 
             # x doesn't alias a resized op so it shouldn't be set to base Tensor type
             self.assertTrue(next(g.inputs()).type() != TensorType.get())
@@ -4608,7 +4608,7 @@ a")
 
         x = torch.randn(3, 1, 5, requires_grad=True)
         fn = torch.jit.script(fn)
-        graph = fn._get_method('forward').propagate_shapes((x,), False)
+        graph = _propagate_shapes(fn.graph, (x,), False)
         a = next(graph.outputs()).type().kind()
         self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
 
@@ -4618,7 +4618,7 @@ a")
             return x + y
 
         x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
-        graph = fn._get_method('forward').propagate_shapes((x, y), False)
+        graph = _propagate_shapes(fn.graph, (x, y), False)
         FileCheck().check('Double(*, *) = aten::add').run(graph)
 
     def test_shape_prop_promote_scalar_arg(self):
@@ -4627,7 +4627,7 @@ a")
             return math.pi + x
 
         x = torch.zeros(3, 4, dtype=torch.long)
-        graph = fn._get_method('forward').propagate_shapes((x,), False)
+        graph = _propagate_shapes(fn.graph, (x,), False)
         FileCheck().check('Long(*, *) = aten::add').run(graph)
 
     def test_integral_shape_inference(self):
@@ -7201,7 +7201,7 @@ a")
             return torch.cat(c)
 
         b = torch.zeros(2, 4)
-        test_list._get_method('forward').propagate_shapes((b,), False)
+        _propagate_shapes(test_list.graph, (b,), False)
 
     def test_if_supertype(self):
         @torch.jit.script
@@ -7218,7 +7218,7 @@ a")
         b = torch.zeros(2, 4, dtype=torch.long)
         c = torch.zeros(2, 4, dtype=torch.float)
 
-        graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
+        graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
         if_outputs = list(graph.findNode("prim::If").outputs())
         self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
         self.assertTrue(if_outputs[1].type().str() == "Tensor")
torch/csrc/jit/script/init.cpp
index 668faa5..caeb618 100644
@@ -387,8 +387,7 @@ struct ModuleValue : public SugaredValue {
           auto& param = *it;
           params.emplace_back(g.insertGetAttr(self_, param.name()));
         }
-        auto list =
-            g.insertNode(g.createTuple(params))->output();
+        auto list = g.insertNode(g.createTuple(params))->output();
         return std::make_shared<ConstantParameterList>(list);
       }
       if (py::isinstance<py::function>(attr) ||
@@ -700,6 +699,61 @@ static Self moduleSelf(const std::shared_ptr<Module>& m) {
   };
 }
 
+static void setInputTensorTypes(Graph& g, const Stack& stack) {
+  AT_ASSERT(stack.size() == g.inputs().size());
+  for (size_t i = 0; i < stack.size(); ++i) {
+    g.inputs().at(i)->setType(
+        DimensionedTensorType::create(stack.at(i).toTensor()));
+  }
+}
+
+static std::shared_ptr<Graph> _propagate_shapes(
+    Graph& graph,
+    std::vector<at::Tensor> inputs,
+    bool with_grad = false) {
+  Stack stack(inputs.begin(), inputs.end());
+  auto retval = graph.copy();
+  setInputTensorTypes(*retval, stack);
+  PropagateInputShapes(retval);
+  return retval;
+}
+
+static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
+    Graph& graph,
+    std::vector<at::Tensor> inputs,
+    std::vector<at::Tensor> outputs,
+    bool with_grad = false,
+    bool propagate = true) {
+  auto retval = graph.copy();
+  if (propagate) {
+    setInputTensorTypes(*retval, fmap<IValue>(inputs));
+    PropagateInputShapes(retval);
+  }
+  AT_ASSERT(retval->inputs().size() == inputs.size());
+  for (size_t i = 0; i < retval->inputs().size(); ++i) {
+    auto scalar_type = inputs[i].scalar_type();
+    auto sizes = inputs[i].sizes();
+    auto type =
+        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+    retval->inputs()[i]->setType(type);
+  }
+  at::ArrayRef<Value*> output_values = retval->outputs();
+  // patch this to still work if we are returning a tuple of multiple values
+  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
+    output_values = output_values.at(0)->node()->inputs();
+  }
+  AT_ASSERT(output_values.size() == outputs.size());
+  for (size_t i = 0; i < retval->outputs().size(); ++i) {
+    auto scalar_type = outputs[i].scalar_type();
+    auto sizes = outputs[i].sizes();
+    auto type =
+        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+    output_values[i]->setType(type);
+  }
+  return retval;
+}
+
 void initJitScriptBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
@@ -984,7 +1038,6 @@ void initJitScriptBindings(PyObject* module) {
           });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
-      .def("graph", [&](Method& self) { return self.graph(); })
       .def(
           "__call__",
           [](py::args args, py::kwargs kwargs) {
@@ -993,28 +1046,7 @@ void initJitScriptBindings(PyObject* module) {
             return invokeScriptMethodFromPython(
                 method, tuple_slice(std::move(args), 1), std::move(kwargs));
           })
-      .def_property_readonly("graph", [](Method& m) { return m.graph(); })
-      .def(
-          "propagate_shapes",
-          [](Method& m, const std::vector<at::Tensor>& inputs, bool with_grad) {
-            return propagate_shapes(
-                *m.graph(), inputs, m.initial_ivalues(), with_grad);
-          })
-      .def(
-          "propagate_and_assign_input_and_output_shapes",
-          [](Method& m,
-             const std::vector<at::Tensor>& inputs,
-             std::vector<at::Tensor> outputs,
-             bool with_grad,
-             bool propagate) {
-            return propagate_and_assign_input_and_output_shapes(
-                *m.graph(),
-                inputs,
-                m.initial_ivalues(),
-                outputs,
-                with_grad,
-                propagate);
-          })
+      .def_property_readonly("graph", &Method::graph)
       .def(
           "initial_ivalues",
           [](Method& m) {
@@ -1131,7 +1163,10 @@ void initJitScriptBindings(PyObject* module) {
   m.def("_jit_clear_class_registry", ClassType::clearRegistry);
   m.def(
       "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
-
+  m.def("_propagate_shapes", _propagate_shapes);
+  m.def(
+      "_propagate_and_assign_input_and_output_shapes",
+      _propagate_and_assign_input_and_output_shapes);
   py::class_<testing::FileCheck>(m, "FileCheck")
       .def(py::init<>())
       .def("check", &testing::FileCheck::check)
torch/csrc/jit/script/module.h
index 10923f0..e99436c 100644
@@ -626,75 +626,6 @@ struct TORCH_API Module {
   mutable std::recursive_mutex find_method_guard_;
 };
 
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
-  AT_ASSERT(stack.size() == g.inputs().size());
-  for (size_t i = 0; i < stack.size(); ++i) {
-    g.inputs().at(i)->setType(
-        DimensionedTensorType::create(stack.at(i).toTensor()));
-  }
-}
-
-inline std::shared_ptr<Graph> propagate_shapes(
-    Graph& graph,
-    const std::vector<at::Tensor>& inputs,
-    const std::vector<Slot>& initial_ivalues,
-    bool with_grad = false) {
-  auto retval = graph.copy();
-  Stack stack;
-  stack.reserve(inputs.size() + initial_ivalues.size());
-  for (const at::Tensor& i : inputs) {
-    stack.emplace_back(std::move(i));
-  }
-  for (const Slot& inp : initial_ivalues) {
-    stack.push_back(inp.value());
-  }
-  setInputTensorTypes(*retval, stack);
-  PropagateInputShapes(retval);
-  return retval;
-}
-
-inline std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
-    Graph& graph,
-    std::vector<at::Tensor> inputs,
-    const std::vector<Slot>& initial_ivalues,
-    std::vector<at::Tensor> outputs,
-    bool with_grad = false,
-    bool propagate = true) {
-  auto retval = graph.copy();
-  for (auto inp : initial_ivalues) {
-    if (inp.value().isTensor()) {
-      inputs.push_back(inp.value().toTensor());
-    }
-  }
-  if (propagate) {
-    setInputTensorTypes(*retval, fmap<IValue>(inputs));
-    PropagateInputShapes(retval);
-  }
-  AT_ASSERT(retval->inputs().size() == inputs.size());
-  for (size_t i = 0; i < retval->inputs().size(); ++i) {
-    auto scalar_type = inputs[i].scalar_type();
-    auto sizes = inputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    retval->inputs()[i]->setType(type);
-  }
-  at::ArrayRef<Value*> output_values = retval->outputs();
-  // patch this to still work if we are returning a tuple of multiple values
-  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
-    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
-    output_values = output_values.at(0)->node()->inputs();
-  }
-  AT_ASSERT(output_values.size() == outputs.size());
-  for (size_t i = 0; i < retval->outputs().size(); ++i) {
-    auto scalar_type = outputs[i].scalar_type();
-    auto sizes = outputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    output_values[i]->setType(type);
-  }
-  return retval;
-}
-
 } // namespace script
 } // namespace jit
 } // namespace torch
torch/onnx/utils.py
index 1910a3c..0f4d1c2 100644
@@ -16,7 +16,7 @@ import warnings
 from torch._six import string_classes
 from torch.jit import _unique_state_dict
 from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
-from torch._C import ListType
+from torch._C import ListType, _propagate_and_assign_input_and_output_shapes
 
 
 # the flag to tell the user whether it's in the middle of ONNX export or not
@@ -237,10 +237,11 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
             example_outputs = [example_outputs]
         try:
             method = model.__getattr__('forward')
-            graph = method.propagate_and_assign_input_and_output_shapes(
-                args, example_outputs, False, propagate)
-            # Erase number types to bring the graph to a pre-NumberType state
             params = method.initial_ivalues()
+            graph = _propagate_and_assign_input_and_output_shapes(
+                method.graph, tuple(args) + tuple(params), example_outputs, False, propagate)
+            # Erase number types to bring the graph to a pre-NumberType state
+
         except AttributeError:
             # TODO: just trace it
             raise RuntimeError('\'forward\' method must be a script method')
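
A note on the torch/onnx/utils.py change above: _propagate_and_assign_input_and_output_shapes no longer receives the method's initial_ivalues, so the caller now appends them to the example inputs itself (hence tuple(args) + tuple(params)). A minimal sketch of that call, assuming a parameter-free ScriptModule so that initial_ivalues() is empty:

    import torch
    from torch._C import _propagate_and_assign_input_and_output_shapes

    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x * 2

    m = M()
    method = m.__getattr__('forward')
    params = method.initial_ivalues()   # module slots that are also graph inputs (empty here)
    args = (torch.zeros(2, 3),)
    example_outputs = [m(*args)]

    # Copies the graph, propagates shapes, then assigns CompleteTensorType to the
    # copied graph's inputs and outputs from the example tensors.
    g = _propagate_and_assign_input_and_output_shapes(
        method.graph, tuple(args) + tuple(params), example_outputs, False, True)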