auto new_outputs = ConvertSequenceDependencies(node, opset_version);
// Copy type of block output to node output.
- for (size_t i = 0; i < node->outputs().size(); ++i) {
- node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type());
- }
+ FixupONNXControlflowNodeOutputs(node);
TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
return new_outputs;
}
graph, else_block, else_block_output, then_block_output);
if_node->outputs()[i]->setType(else_block->outputs()[i]->type());
}
- auto then_tensor_type =
- then_block->outputs().at(i)->type()->castRaw<TensorType>();
- auto else_tensor_type =
- else_block->outputs().at(i)->type()->castRaw<TensorType>();
- if (then_tensor_type && else_tensor_type) {
- const auto& then_shape = then_tensor_type->symbolic_sizes();
- const auto& else_shape = else_tensor_type->symbolic_sizes();
- std::vector<::c10::ShapeSymbol> dims;
- if (then_shape.rank() && else_shape.rank() &&
- then_shape.rank() == else_shape.rank()) {
- for (const auto j : c10::irange(then_shape.rank().value())) {
- if (then_shape[j] == else_shape[j]) {
- dims.emplace_back(then_shape[j]);
- } else {
- dims.emplace_back(::c10::ShapeSymbol::newSymbol());
- }
+ }
+}
+
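+// Infers the output shapes of an onnx::If node from its two subblocks:
+// for each output, only the shape information on which the then- and
+// else-branch agree is kept; mismatched dimensions become fresh symbols.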
+void ONNXMergeIfBlockOutputShapes(Node* node) {
+ TORCH_INTERNAL_ASSERT(node->kind() == ::c10::onnx::If);
+ Block* then_block = node->blocks().at(0);
+ Block* else_block = node->blocks().at(1);
+
+ TORCH_INTERNAL_ASSERT(
+      then_block->outputs().size() == else_block->outputs().size());
+
+ auto findCommonShape =
+ [](const ::c10::SymbolicShape& a,
+ const ::c10::SymbolicShape& b) -> ::c10::SymbolicShape {
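+    // Keep dimensions on which both shapes agree and allocate a fresh
+    // symbol for the rest, e.g. [3, 4] vs. [3, 5] merges to [3, s0].
+    // If only one side has a known (nonzero) rank, that shape wins;
+    // otherwise the result carries no rank at all.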
+ std::vector<::c10::ShapeSymbol> dims;
+ if (a.rank() && b.rank() && a.rank() == b.rank()) {
+ for (const auto j : c10::irange(a.rank().value())) {
+ if (a[j] == b[j]) {
+ dims.emplace_back(a[j]);
+ } else {
+ dims.emplace_back(::c10::ShapeSymbol::newSymbol());
}
- if_node->output(i)->setType(
- then_tensor_type->withSymbolicShapes(::c10::SymbolicShape(dims)));
+ }
+ return ::c10::SymbolicShape(dims);
+ }
+ if (a.rank() && a.rank().value() > 0) {
+ return a;
+ }
+ if (b.rank() && b.rank().value() > 0) {
+ return b;
+ }
+
+ return ::c10::SymbolicShape();
+ };
+
+ auto mergeTensorType =
+ [&findCommonShape](TensorTypePtr a, TensorTypePtr b) -> TensorTypePtr {
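+    // When both branches carry tensor type info, keep the dtype/device of
+    // the first argument and restrict the symbolic shape to the common
+    // shape computed above; when only one branch has info, use it as-is.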
+ if (a && b) {
+ const auto& a_shape = a->symbolic_sizes();
+ const auto& b_shape = b->symbolic_sizes();
+ auto commonShape = findCommonShape(a_shape, b_shape);
+ return a->withSymbolicShapes(commonShape);
+ } else if (a) {
+ return a;
+ } else if (b) {
+ return b;
+ }
+ return nullptr;
+ };
+
+ auto mergeListType = [&mergeTensorType](
+ ListTypePtr a, ListTypePtr b) -> ListTypePtr {
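+    // Lists of tensors are merged elementwise via mergeTensorType; lists
+    // whose element type is not a tensor keep the then-branch type.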
+ if (a && b) {
+ auto a_tensor_type = a->getElementType()->cast<TensorType>();
+ auto b_tensor_type = b->getElementType()->cast<TensorType>();
+ auto tensor_type = mergeTensorType(a_tensor_type, b_tensor_type);
+ if (tensor_type) {
+ return a->withContained({tensor_type})->cast<ListType>();
+ }
+      // Neither branch carries a tensor element type; fall back to the
+      // then-branch ListType as-is.
+ return a;
+ } else if (a) {
+ return a;
+ } else if (b) {
+ return b;
+ }
+ return nullptr;
+ };
+
+ for (const auto i : c10::irange(else_block->outputs().size())) {
+ auto then_type = then_block->outputs().at(i)->type();
+ auto else_type = else_block->outputs().at(i)->type();
+ auto then_tensor_type = then_type->cast<TensorType>();
+ auto else_tensor_type = else_type->cast<TensorType>();
+ auto then_list_type = then_type->cast<ListType>();
+ auto else_list_type = else_type->cast<ListType>();
+ if (then_tensor_type || else_tensor_type) {
+ if (auto tensor_type =
+ mergeTensorType(then_tensor_type, else_tensor_type)) {
+ node->output(i)->setType(tensor_type);
+ }
+ } else if (then_list_type || else_list_type) {
+ if (auto list_type = mergeListType(then_list_type, else_list_type)) {
+ node->output(i)->setType(list_type);
}
}
}
return node->outputs().vec();
}
GRAPH_DUMP("Graph before fixing controlflow: ", node->owningGraph());
- auto* if_node = node;
FixupONNXSubblockOutputs(node);
- ONNXFixupUninitializedOutput(if_node);
+ ONNXFixupUninitializedOutput(node);
-  // Copy type of block output to node output.
+  // Merge the output shapes of both subblocks onto the node outputs.
- for (size_t i = 0; i < node->outputs().size(); ++i) {
- node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type());
- }
+ ONNXMergeIfBlockOutputShapes(node);
GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
- return if_node->outputs().vec();
+ return node->outputs().vec();
}
std::vector<Value*> FixupONNXControlflowNode(Node* n, int opset_version) {
}
}
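+// Refreshes the output types of an ONNX Loop or If node from the outputs
+// of its subblocks, once those subblocks have been fixed up or inferred.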
+void FixupONNXControlflowNodeOutputs(Node* n) {
+ switch (n->kind()) {
+ case ::c10::onnx::Loop: {
+ auto loop_carried_output_size = n->blocks().at(0)->inputs().size() - 2;
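+      // ONNX Loop block outputs are ordered [cond, loop-carried...,
+      // scan-outputs...], hence the i + 1 offset below. Scan outputs are
+      // stacked across iterations, so they gain a leading trip-count
+      // dimension, e.g. a per-iteration [3, 4] output becomes [s, 3, 4].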
+ for (auto i : c10::irange(n->outputs().size())) {
+ auto type = n->blocks().at(0)->outputs().at(i + 1)->type();
+ if (i < loop_carried_output_size) {
+ n->output(i)->setType(type);
+ } else {
+ if (auto t_type = type->cast<TensorType>()) {
+ auto sizes = t_type->symbolic_sizes().sizes();
+ if (sizes.has_value()) {
+ sizes.value().emplace(
+ sizes.value().begin(), c10::ShapeSymbol::newSymbol());
+ type = t_type->withSymbolicShapes(sizes.value());
+ }
+ }
+ n->output(i)->setType(type);
+ }
+ }
+ break;
+ }
+ case ::c10::onnx::If: {
+ ONNXMergeIfBlockOutputShapes(n);
+ break;
+ }
+ default:
+ break;
+ }
+}
+
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/constant_fold.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
+#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
#include <torch/csrc/jit/passes/onnx/fold_if_node.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
}
}
-bool IsSupportedNode(const Node* n) {
+bool IsValidONNXControlflowNode(const Node* n) {
+  // Skip when the block size is zero. This happens while the node is being
+  // created and its subblocks are not attached yet. Shape inference for such
+  // nodes runs later, once the subgraphs have completed shape inference.
+ auto node_kind = n->kind();
+ if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
+ if (n->blocks().size() == 0) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool IsValidONNXNode(const Node* n) {
auto node_kind = n->kind();
if (!node_kind.is_onnx()) {
return false;
}
- // Skip when block size is zero. This is when the node is first created,
- // doesn't have subblocks attached yet. Run shape inference for these nodes
- // when the subgraph has already completed shape inferencing.
- if (node_kind == ::c10::onnx::Loop || node_kind == ::c10::onnx::If) {
- if (n->blocks().size() == 0) {
- return false;
- }
- for (auto b : n->blocks()) {
- for (auto b_n : b->nodes()) {
- if (!IsSupportedNode(b_n)) {
- return false;
- }
+ if (!IsValidONNXControlflowNode(n)) {
+ return false;
+ }
+
+ for (auto b : n->blocks()) {
+ for (auto b_n : b->nodes()) {
+ if (!IsValidONNXNode(b_n)) {
+ return false;
}
}
}
}
break;
}
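+    // Loop/If nodes are skipped by shape inference while their subblocks
+    // are still empty; once the subgraphs have been inferred, re-derive
+    // the node output types from the block outputs.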
+    case ::c10::onnx::If:
+    case ::c10::onnx::Loop: {
+      if (IsValidONNXControlflowNode(n)) {
+        FixupONNXControlflowNodeOutputs(n);
+      }
+      break;
+    }
}
}
int opset_version) {
GRAPH_UPDATE(
"Running ONNX shape inference for node: ", n->kind().toDisplayString());
- if (!IsSupportedNode(n)) {
- return;
- }
- // Create a Graph containing only the single node n.
- // This graph is later converted to ONNX to run shape inference.
- auto n_graph = std::make_shared<Graph>();
- auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
- n_graph->insertNode(clone_node);
+ if (IsValidONNXNode(n)) {
+ // Create a Graph containing only the single node n.
+ // This graph is later converted to ONNX to run shape inference.
+ auto n_graph = std::make_shared<Graph>();
+ auto clone_node = CloneNodeToGraph(n, n_graph, params_dict, opset_version);
+ n_graph->insertNode(clone_node);
- // Register all node outputs as graph outputs.
- for (auto output : clone_node->outputs()) {
- n_graph->registerOutput(output);
- }
+ // Register all node outputs as graph outputs.
+ for (auto output : clone_node->outputs()) {
+ n_graph->registerOutput(output);
+ }
- // Use scalar_type_analysis without low precision cast
- ScalarTypeAnalysisForONNX(n_graph, false, opset_version);
+ // Use scalar_type_analysis without low precision cast
+ ScalarTypeAnalysisForONNX(n_graph, false, opset_version);
- GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString());
- GRAPH_DEBUG(
- "Cloned torch graph to run shape inference: ", n_graph->toString());
-
- if (IsGraphValidForInference(n_graph)) {
- // TODO: Some ops have conversion happen at Peephole pass.
- // The conversion here is incomplete for these ops.
- // e.g: ListConstruct, ListUnpack, etc.
- std::shared_ptr<onnx::ModelProto> model_proto;
- SymbolDimMap symbol_map;
- ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version);
+ GRAPH_DEBUG("Original torch graph: ", n->owningGraph()->toString());
GRAPH_DEBUG(
- "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
-
- // infer shape
- try {
- onnx::shape_inference::InferShapes(*model_proto);
- UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map);
- } catch (std::runtime_error& ex) {
- // TODO: include this as warning once we have a more consolidated warning
- // system.
+ "Cloned torch graph to run shape inference: ", n_graph->toString());
+
+ if (IsGraphValidForInference(n_graph)) {
+      // TODO: Some ops are converted during the Peephole pass instead,
+      // so the conversion here is incomplete for them,
+      // e.g.: ListConstruct, ListUnpack, etc.
+ std::shared_ptr<onnx::ModelProto> model_proto;
+ SymbolDimMap symbol_map;
+ ConvertGraphToONNXProto(n_graph, model_proto, symbol_map, opset_version);
GRAPH_DEBUG(
- "ONNX shape inference fails with: ",
- ex.what(),
- " on graph: ",
- n_graph->toString());
- // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
- const char shape_err[] = "ShapeInferenceError";
- // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
- const char type_err[] = "TypeInferenceError";
- if ((strstr(ex.what(), shape_err) == nullptr) &&
- (strstr(ex.what(), type_err) == nullptr)) {
- throw;
+ "ONNX graph to run shape inference: ", prettyPrint(*model_proto));
+
+      // Run ONNX shape inference on the cloned single-node graph.
+ try {
+ onnx::shape_inference::InferShapes(*model_proto);
+ UpdateOutputTypeByONNXProto(n, clone_node, *model_proto, symbol_map);
+ } catch (std::runtime_error& ex) {
+ // TODO: include this as warning once we have a more consolidated
+ // warning system.
+ GRAPH_DEBUG(
+ "ONNX shape inference fails with: ",
+ ex.what(),
+ " on graph: ",
+ n_graph->toString());
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+ const char shape_err[] = "ShapeInferenceError";
+ // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
+ const char type_err[] = "TypeInferenceError";
+        if ((strstr(ex.what(), shape_err) == nullptr) &&
+            (strstr(ex.what(), type_err) == nullptr)) {
+ throw;
+ }
}
+ GRAPH_DEBUG(
+ "ONNX graph after shape inference: ", prettyPrint(*model_proto));
}
- GRAPH_DEBUG(
- "ONNX graph after shape inference: ", prettyPrint(*model_proto));
}
SpecialPostProcess(n);
- ProcessConstantValueMap(n, opset_version);
+ if (IsValidONNXNode(n)) {
+ ProcessConstantValueMap(n, opset_version);
+ }
GRAPH_DEBUG(
"Torch graph after shape inference:", n->owningGraph()->toString());
}