From: Zachary DeVito Date: Sat, 13 Apr 2019 15:28:11 +0000 (-0700) Subject: get propagate_shape logic out of module.h (#19137) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~241 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=dcb5fd3613dd86cb5ba794656b7ffadf00c70121;p=platform%2Fupstream%2Fpytorch.git get propagate_shape logic out of module.h (#19137) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19137 ghimport-source-id: 2394765f2d401e68ffdfa4c985bfab4cca2517f8 Reviewed By: jamesr66a Differential Revision: D14885946 Pulled By: zdevito fbshipit-source-id: daa2894ed9761107e9d273bb172840dc23ace072 --- diff --git a/test/test_jit.py b/test/test_jit.py index 015c67c..56f0fc2 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -42,7 +42,7 @@ from common_methods_invocations import method_tests as autograd_method_tests from common_methods_invocations import create_input, unpack_variables, \ exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL from torch.testing import FileCheck -from torch._C import TensorType, parse_ir +from torch._C import TensorType, parse_ir, _propagate_shapes from copy import deepcopy import random from typing import List, Dict, Optional, Tuple @@ -3733,8 +3733,8 @@ a") # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument self.run_pass('constant_propagation', func.graph) self.run_pass('constant_propagation', func2.graph) - g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False) - g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False) + g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False) + g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False) self.assertTrue(g.findNode("aten::sum").output().type().kind() == "DimensionedTensorType") self.assertTrue(g2.findNode("aten::sum").output().type().kind() @@ -4509,8 +4509,8 @@ a") torch.mul(x, y, out=z) return z - graph = test._get_method('forward').propagate_shapes( - (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False) + graph = _propagate_shapes(test.graph, + (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False) self.assertTrue(next(graph.outputs()).type() == TensorType.get()) out_op_graph_input() @@ -4529,7 +4529,7 @@ a") return after_resize_alias self.run_pass('constant_propagation', test.graph) - g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False) + g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False) resize_node = g.findNode("aten::resize_") # first input and output of b.resize_ is b self.assertTrue(next(resize_node.inputs()).type() == TensorType.get()) @@ -4553,7 +4553,7 @@ a") g = test.graph self.run_pass('constant_propagation', g) - g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False) + g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False) # x doesn't alias a resized op so it shouldn't be set to base Tensor type self.assertTrue(next(g.inputs()).type() != TensorType.get()) @@ -4608,7 +4608,7 @@ a") x = torch.randn(3, 1, 5, requires_grad=True) fn = torch.jit.script(fn) - graph = fn._get_method('forward').propagate_shapes((x,), False) + graph = _propagate_shapes(fn.graph, (x,), False) a = next(graph.outputs()).type().kind() self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType') @@ -4618,7 +4618,7 @@ a") return x + y x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double) - graph 
= fn._get_method('forward').propagate_shapes((x, y), False)
+        graph = _propagate_shapes(fn.graph, (x, y), False)
         FileCheck().check('Double(*, *) = aten::add').run(graph)
 
     def test_shape_prop_promote_scalar_arg(self):
@@ -4627,7 +4627,7 @@ a")
             return math.pi + x
 
         x = torch.zeros(3, 4, dtype=torch.long)
-        graph = fn._get_method('forward').propagate_shapes((x,), False)
+        graph = _propagate_shapes(fn.graph, (x,), False)
         FileCheck().check('Long(*, *) = aten::add').run(graph)
 
     def test_integral_shape_inference(self):
@@ -7201,7 +7201,7 @@ a")
             return torch.cat(c)
 
         b = torch.zeros(2, 4)
-        test_list._get_method('forward').propagate_shapes((b,), False)
+        _propagate_shapes(test_list.graph, (b,), False)
 
     def test_if_supertype(self):
         @torch.jit.script
@@ -7218,7 +7218,7 @@ a")
         b = torch.zeros(2, 4, dtype=torch.long)
         c = torch.zeros(2, 4, dtype=torch.float)
 
-        graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
+        graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
         if_outputs = list(graph.findNode("prim::If").outputs())
         self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
         self.assertTrue(if_outputs[1].type().str() == "Tensor")
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 668faa5..caeb618 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -387,8 +387,7 @@ struct ModuleValue : public SugaredValue {
         auto& param = *it;
         params.emplace_back(g.insertGetAttr(self_, param.name()));
       }
-      auto list =
-          g.insertNode(g.createTuple(params))->output();
+      auto list = g.insertNode(g.createTuple(params))->output();
       return std::make_shared(list);
     }
     if (py::isinstance(attr) ||
@@ -700,6 +699,61 @@ static Self moduleSelf(const std::shared_ptr<Module>& m) {
   };
 }
 
+static void setInputTensorTypes(Graph& g, const Stack& stack) {
+  AT_ASSERT(stack.size() == g.inputs().size());
+  for (size_t i = 0; i < stack.size(); ++i) {
+    g.inputs().at(i)->setType(
+        DimensionedTensorType::create(stack.at(i).toTensor()));
+  }
+}
+
+static std::shared_ptr<Graph> _propagate_shapes(
+    Graph& graph,
+    std::vector<at::Tensor> inputs,
+    bool with_grad = false) {
+  Stack stack(inputs.begin(), inputs.end());
+  auto retval = graph.copy();
+  setInputTensorTypes(*retval, stack);
+  PropagateInputShapes(retval);
+  return retval;
+}
+
+static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
+    Graph& graph,
+    std::vector<at::Tensor> inputs,
+    std::vector<at::Tensor> outputs,
+    bool with_grad = false,
+    bool propagate = true) {
+  auto retval = graph.copy();
+  if (propagate) {
+    setInputTensorTypes(*retval, fmap<IValue>(inputs));
+    PropagateInputShapes(retval);
+  }
+  AT_ASSERT(retval->inputs().size() == inputs.size());
+  for (size_t i = 0; i < retval->inputs().size(); ++i) {
+    auto scalar_type = inputs[i].scalar_type();
+    auto sizes = inputs[i].sizes();
+    auto type =
+        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+    retval->inputs()[i]->setType(type);
+  }
+  at::ArrayRef<Value*> output_values = retval->outputs();
+  // patch this to still work if we are returning a tuple of multiple values
+  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
+    output_values = output_values.at(0)->node()->inputs();
+  }
+  AT_ASSERT(output_values.size() == outputs.size());
+  for (size_t i = 0; i < retval->outputs().size(); ++i) {
+    auto scalar_type = outputs[i].scalar_type();
+    auto sizes = outputs[i].sizes();
+    auto type =
+        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+    output_values[i]->setType(type);
+  }
+  return retval;
+}
+
 void initJitScriptBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
@@ -984,7 +1038,6 @@ void initJitScriptBindings(PyObject* module) {
       });
 
   py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
-      .def("graph", [&](Method& self) { return self.graph(); })
       .def(
           "__call__",
          [](py::args args, py::kwargs kwargs) {
@@ -993,28 +1046,7 @@ void initJitScriptBindings(PyObject* module) {
            return invokeScriptMethodFromPython(
                method, tuple_slice(std::move(args), 1), std::move(kwargs));
          })
-      .def_property_readonly("graph", [](Method& m) { return m.graph(); })
-      .def(
-          "propagate_shapes",
-          [](Method& m, const std::vector<at::Tensor>& inputs, bool with_grad) {
-            return propagate_shapes(
-                *m.graph(), inputs, m.initial_ivalues(), with_grad);
-          })
-      .def(
-          "propagate_and_assign_input_and_output_shapes",
-          [](Method& m,
-             const std::vector<at::Tensor>& inputs,
-             std::vector<at::Tensor> outputs,
-             bool with_grad,
-             bool propagate) {
-            return propagate_and_assign_input_and_output_shapes(
-                *m.graph(),
-                inputs,
-                m.initial_ivalues(),
-                outputs,
-                with_grad,
-                propagate);
-          })
+      .def_property_readonly("graph", &Method::graph)
       .def(
          "initial_ivalues",
          [](Method& m) {
@@ -1131,7 +1163,10 @@ void initJitScriptBindings(PyObject* module) {
   m.def("_jit_clear_class_registry", ClassType::clearRegistry);
   m.def(
       "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
-
+  m.def("_propagate_shapes", _propagate_shapes);
+  m.def(
+      "_propagate_and_assign_input_and_output_shapes",
+      _propagate_and_assign_input_and_output_shapes);
   py::class_<testing::FileCheck>(m, "FileCheck")
       .def(py::init<>())
       .def("check", &testing::FileCheck::check)
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index 10923f0..e99436c 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -626,75 +626,6 @@ struct TORCH_API Module {
   mutable std::recursive_mutex find_method_guard_;
 };
 
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
-  AT_ASSERT(stack.size() == g.inputs().size());
-  for (size_t i = 0; i < stack.size(); ++i) {
-    g.inputs().at(i)->setType(
-        DimensionedTensorType::create(stack.at(i).toTensor()));
-  }
-}
-
-inline std::shared_ptr<Graph> propagate_shapes(
-    Graph& graph,
-    const std::vector<at::Tensor>& inputs,
-    const std::vector<Slot>& initial_ivalues,
-    bool with_grad = false) {
-  auto retval = graph.copy();
-  Stack stack;
-  stack.reserve(inputs.size() + initial_ivalues.size());
-  for (const at::Tensor& i : inputs) {
-    stack.emplace_back(std::move(i));
-  }
-  for (const Slot& inp : initial_ivalues) {
-    stack.push_back(inp.value());
-  }
-  setInputTensorTypes(*retval, stack);
-  PropagateInputShapes(retval);
-  return retval;
-}
-
-inline std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
-    Graph& graph,
-    std::vector<at::Tensor> inputs,
-    const std::vector<Slot>& initial_ivalues,
-    std::vector<at::Tensor> outputs,
-    bool with_grad = false,
-    bool propagate = true) {
-  auto retval = graph.copy();
-  for (auto inp : initial_ivalues) {
-    if (inp.value().isTensor()) {
-      inputs.push_back(inp.value().toTensor());
-    }
-  }
-  if (propagate) {
-    setInputTensorTypes(*retval, fmap<IValue>(inputs));
-    PropagateInputShapes(retval);
-  }
-  AT_ASSERT(retval->inputs().size() == inputs.size());
-  for (size_t i = 0; i < retval->inputs().size(); ++i) {
-    auto scalar_type = inputs[i].scalar_type();
-    auto sizes = inputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    retval->inputs()[i]->setType(type);
-  }
-  at::ArrayRef<Value*> output_values = retval->outputs();
-  // 
patch this to still work if we are returning a tuple of multiple values - if (output_values.at(0)->type()->kind() == TupleType::Kind) { - AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct); - output_values = output_values.at(0)->node()->inputs(); - } - AT_ASSERT(output_values.size() == outputs.size()); - for (size_t i = 0; i < retval->outputs().size(); ++i) { - auto scalar_type = outputs[i].scalar_type(); - auto sizes = outputs[i].sizes(); - auto type = - torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes); - output_values[i]->setType(type); - } - return retval; -} - } // namespace script } // namespace jit } // namespace torch diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 1910a3c..0f4d1c2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -16,7 +16,7 @@ import warnings from torch._six import string_classes from torch.jit import _unique_state_dict from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes -from torch._C import ListType +from torch._C import ListType, _propagate_and_assign_input_and_output_shapes # the flag to tell the user whether it's in the middle of ONNX export or not @@ -237,10 +237,11 @@ def _model_to_graph(model, args, f, verbose=False, training=False, example_outputs = [example_outputs] try: method = model.__getattr__('forward') - graph = method.propagate_and_assign_input_and_output_shapes( - args, example_outputs, False, propagate) - # Erase number types to bring the graph to a pre-NumberType state params = method.initial_ivalues() + graph = _propagate_and_assign_input_and_output_shapes( + method.graph, tuple(args) + tuple(params), example_outputs, False, propagate) + # Erase number types to bring the graph to a pre-NumberType state + except AttributeError: # TODO: just trace it raise RuntimeError('\'forward\' method must be a script method')
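
A minimal usage sketch (not part of the patch; the scripted function `fn` and the example tensors are illustrative) showing how the relocated shape-propagation entry point is driven from Python after this change, mirroring the updated tests in test_jit.py:

    import torch
    from torch._C import _propagate_shapes

    @torch.jit.script
    def fn(x, y):
        return x + y

    x = torch.rand(3, 4, dtype=torch.float)
    y = torch.rand(3, 4, dtype=torch.double)

    # _propagate_shapes copies the graph, stamps the graph inputs with
    # DimensionedTensorType built from the example tensors, and runs
    # PropagateInputShapes on the copy; the original graph is untouched.
    g = _propagate_shapes(fn.graph, (x, y), False)  # third argument is with_grad
    print(g)  # the add is now typed, e.g. Double(*, *) = aten::add, as checked in the tests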