from common_methods_invocations import create_input, unpack_variables, \
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
from torch.testing import FileCheck
-from torch._C import TensorType, parse_ir
+from torch._C import TensorType, parse_ir, _propagate_shapes
from copy import deepcopy
import random
from typing import List, Dict, Optional, Tuple
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
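# a hypothetical sketch of the elided definitions, for context only:
# func reduces over an int dim and func2 over a one-element dim list, e.g.
#   @torch.jit.script
#   def func(x):
#       return torch.sum(x, dim=4)
#   @torch.jit.script
#   def func2(x):
#       return torch.sum(x, dim=[4])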
self.run_pass('constant_propagation', func.graph)
self.run_pass('constant_propagation', func2.graph)
- g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
- g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
+ g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
+ g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
self.assertTrue(g.findNode("aten::sum").output().type().kind()
== "DimensionedTensorType")
self.assertTrue(g2.findNode("aten::sum").output().type().kind()
                == "DimensionedTensorType")
torch.mul(x, y, out=z)
return z
- graph = test._get_method('forward').propagate_shapes(
- (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
+ graph = _propagate_shapes(test.graph,
+ (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
self.assertTrue(next(graph.outputs()).type() == TensorType.get())
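# writing through an out= argument may resize the target, so shape
# analysis conservatively gives the result the base unshaped Tensor type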
out_op_graph_input()
return after_resize_alias
self.run_pass('constant_propagation', test.graph)
- g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
+ g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
resize_node = g.findNode("aten::resize_")
# first input and output of b.resize_ is b
self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
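# resize_ can change the rank of its input in-place, so everything
# aliasing b is conservatively reset to the base unshaped Tensor type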
g = test.graph
self.run_pass('constant_propagation', g)
- g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
+ g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False)
# x doesn't alias a resized op so it shouldn't be set to base Tensor type
self.assertTrue(next(g.inputs()).type() != TensorType.get())
x = torch.randn(3, 1, 5, requires_grad=True)
fn = torch.jit.script(fn)
- graph = fn._get_method('forward').propagate_shapes((x,), False)
+ graph = _propagate_shapes(fn.graph, (x,), False)
a = next(graph.outputs()).type().kind()
self.assertTrue(a != 'TensorType')
return x + y
x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double)
- graph = fn._get_method('forward').propagate_shapes((x, y), False)
+ graph = _propagate_shapes(fn.graph, (x, y), False)
FileCheck().check('Double(*, *) = aten::add').run(graph)
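# mixed float/double operands promote to the wider dtype, hence Double(*, *)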
def test_shape_prop_promote_scalar_arg(self):
return math.pi + x
x = torch.zeros(3, 4, dtype=torch.long)
- graph = fn._get_method('forward').propagate_shapes((x,), False)
+ graph = _propagate_shapes(fn.graph, (x,), False)
FileCheck().check('Long(*, *) = aten::add').run(graph)
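# the Python scalar math.pi is cast to the tensor's dtype rather than
# promoting the tensor, so the result stays Long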
def test_integral_shape_inference(self):
return torch.cat(c)
b = torch.zeros(2, 4)
- test_list._get_method('forward').propagate_shapes((b,), False)
+ _propagate_shapes(test_list.graph, (b,), False)
def test_if_supertype(self):
@torch.jit.script
b = torch.zeros(2, 4, dtype=torch.long)
c = torch.zeros(2, 4, dtype=torch.float)
- graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
+ graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False)
if_outputs = list(graph.findNode("prim::If").outputs())
self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
self.assertTrue(if_outputs[1].type().str() == "Tensor")
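# branches that agree on dtype and rank unify to a DimensionedTensorType;
# mismatched branches unify to their common supertype, plain Tensor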
auto& param = *it;
params.emplace_back(g.insertGetAttr(self_, param.name()));
}
- auto list =
- g.insertNode(g.createTuple(params))->output();
+ auto list = g.insertNode(g.createTuple(params))->output();
return std::make_shared<ConstantParameterList>(list);
}
if (py::isinstance<py::function>(attr) ||
};
}
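+// Stamp each graph input with a DimensionedTensorType built from the
+// matching example tensor (dtype, device, and number of dimensions;
+// concrete sizes are not recorded) to seed shape propagation.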
+static void setInputTensorTypes(Graph& g, const Stack& stack) {
+ AT_ASSERT(stack.size() == g.inputs().size());
+ for (size_t i = 0; i < stack.size(); ++i) {
+ g.inputs().at(i)->setType(
+ DimensionedTensorType::create(stack.at(i).toTensor()));
+ }
+}
+
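+// Runs shape propagation over a copy of the graph so the caller's graph
+// is left untouched; with_grad is kept for signature compatibility with
+// the old Method binding but is currently unused.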
+static std::shared_ptr<Graph> _propagate_shapes(
+ Graph& graph,
+ std::vector<at::Tensor> inputs,
+ bool with_grad = false) {
+ Stack stack(inputs.begin(), inputs.end());
+ auto retval = graph.copy();
+ setInputTensorTypes(*retval, stack);
+ PropagateInputShapes(retval);
+ return retval;
+}
+
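+// As above, but additionally pins every graph input and output to a
+// CompleteTensorType built from the example tensors; the ONNX exporter
+// (see utils.py below) relies on these fully specified shapes.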
+static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
+ Graph& graph,
+ std::vector<at::Tensor> inputs,
+ std::vector<at::Tensor> outputs,
+ bool with_grad = false,
+ bool propagate = true) {
+ auto retval = graph.copy();
+ if (propagate) {
+ setInputTensorTypes(*retval, fmap<IValue>(inputs));
+ PropagateInputShapes(retval);
+ }
+ AT_ASSERT(retval->inputs().size() == inputs.size());
+ for (size_t i = 0; i < retval->inputs().size(); ++i) {
+ auto scalar_type = inputs[i].scalar_type();
+ auto sizes = inputs[i].sizes();
+ auto type =
+ torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+ retval->inputs()[i]->setType(type);
+ }
+ at::ArrayRef<Value*> output_values = retval->outputs();
+ // patch this to still work if we are returning a tuple of multiple values
+ if (output_values.at(0)->type()->kind() == TupleType::Kind) {
+ AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
+ output_values = output_values.at(0)->node()->inputs();
+ }
+ AT_ASSERT(output_values.size() == outputs.size());
+ for (size_t i = 0; i < output_values.size(); ++i) {
+ auto scalar_type = outputs[i].scalar_type();
+ auto sizes = outputs[i].sizes();
+ auto type =
+ torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+ output_values[i]->setType(type);
+ }
+ return retval;
+}
+
void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
});
py::class_<Method>(m, "ScriptMethod", py::dynamic_attr())
- .def("graph", [&](Method& self) { return self.graph(); })
.def(
"__call__",
[](py::args args, py::kwargs kwargs) {
  // see: [pybind11 varargs]; the bound Method is the first argument
  Method& method = py::cast<Method&>(args[0]);
return invokeScriptMethodFromPython(
method, tuple_slice(std::move(args), 1), std::move(kwargs));
})
- .def_property_readonly("graph", [](Method& m) { return m.graph(); })
- .def(
- "propagate_shapes",
- [](Method& m, const std::vector<at::Tensor>& inputs, bool with_grad) {
- return propagate_shapes(
- *m.graph(), inputs, m.initial_ivalues(), with_grad);
- })
- .def(
- "propagate_and_assign_input_and_output_shapes",
- [](Method& m,
- const std::vector<at::Tensor>& inputs,
- std::vector<at::Tensor> outputs,
- bool with_grad,
- bool propagate) {
- return propagate_and_assign_input_and_output_shapes(
- *m.graph(),
- inputs,
- m.initial_ivalues(),
- outputs,
- with_grad,
- propagate);
- })
+ .def_property_readonly("graph", &Method::graph)
.def(
"initial_ivalues",
[](Method& m) {
m.def("_jit_clear_class_registry", ClassType::clearRegistry);
m.def(
"_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
-
+ m.def("_propagate_shapes", _propagate_shapes);
+ m.def(
+ "_propagate_and_assign_input_and_output_shapes",
+ _propagate_and_assign_input_and_output_shapes);
py::class_<testing::FileCheck>(m, "FileCheck")
.def(py::init<>())
.def("check", &testing::FileCheck::check)
mutable std::recursive_mutex find_method_guard_;
};
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
- AT_ASSERT(stack.size() == g.inputs().size());
- for (size_t i = 0; i < stack.size(); ++i) {
- g.inputs().at(i)->setType(
- DimensionedTensorType::create(stack.at(i).toTensor()));
- }
-}
-
-inline std::shared_ptr<Graph> propagate_shapes(
- Graph& graph,
- const std::vector<at::Tensor>& inputs,
- const std::vector<Slot>& initial_ivalues,
- bool with_grad = false) {
- auto retval = graph.copy();
- Stack stack;
- stack.reserve(inputs.size() + initial_ivalues.size());
- for (const at::Tensor& i : inputs) {
- stack.emplace_back(std::move(i));
- }
- for (const Slot& inp : initial_ivalues) {
- stack.push_back(inp.value());
- }
- setInputTensorTypes(*retval, stack);
- PropagateInputShapes(retval);
- return retval;
-}
-
-inline std::shared_ptr<Graph> propagate_and_assign_input_and_output_shapes(
- Graph& graph,
- std::vector<at::Tensor> inputs,
- const std::vector<Slot>& initial_ivalues,
- std::vector<at::Tensor> outputs,
- bool with_grad = false,
- bool propagate = true) {
- auto retval = graph.copy();
- for (auto inp : initial_ivalues) {
- if (inp.value().isTensor()) {
- inputs.push_back(inp.value().toTensor());
- }
- }
- if (propagate) {
- setInputTensorTypes(*retval, fmap<IValue>(inputs));
- PropagateInputShapes(retval);
- }
- AT_ASSERT(retval->inputs().size() == inputs.size());
- for (size_t i = 0; i < retval->inputs().size(); ++i) {
- auto scalar_type = inputs[i].scalar_type();
- auto sizes = inputs[i].sizes();
- auto type =
- torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
- retval->inputs()[i]->setType(type);
- }
- at::ArrayRef<Value*> output_values = retval->outputs();
- // patch this to still work if we are returning a tuple of multiple values
- if (output_values.at(0)->type()->kind() == TupleType::Kind) {
- AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
- output_values = output_values.at(0)->node()->inputs();
- }
- AT_ASSERT(output_values.size() == outputs.size());
- for (size_t i = 0; i < retval->outputs().size(); ++i) {
- auto scalar_type = outputs[i].scalar_type();
- auto sizes = outputs[i].sizes();
- auto type =
- torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
- output_values[i]->setType(type);
- }
- return retval;
-}
-
} // namespace script
} // namespace jit
} // namespace torch
from torch._six import string_classes
from torch.jit import _unique_state_dict
from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
-from torch._C import ListType
+from torch._C import ListType, _propagate_and_assign_input_and_output_shapes
# the flag to tell the user whether it's in the middle of ONNX export or not
example_outputs = [example_outputs]
try:
method = model.__getattr__('forward')
- graph = method.propagate_and_assign_input_and_output_shapes(
- args, example_outputs, False, propagate)
- # Erase number types to bring the graph to a pre-NumberType state
params = method.initial_ivalues()
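+ # initial_ivalues (e.g. module parameters) appear as extra graph inputs,
+ # so their values are appended to the user args before propagation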
+ graph = _propagate_and_assign_input_and_output_shapes(
+ method.graph, tuple(args) + tuple(params), example_outputs, False, propagate)
+ # Erase number types to bring the graph to a pre-NumberType state
+
except AttributeError:
# TODO: just trace it
raise RuntimeError('\'forward\' method must be a script method')