From: Lu Fang Date: Fri, 15 Mar 2019 18:41:31 +0000 (-0700) Subject: 1/2 Add Tracing support for C2 Ops (#17899) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~787 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b420f8ff70a5892edf99420e84b649954392798e;p=platform%2Fupstream%2Fpytorch.git 1/2 Add Tracing support for C2 Ops (#17899) Summary: The C10 ops are not registered as custom ops in PyTorch. So we have to add the explicit support for it, too. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17899 Reviewed By: dzhulgakov Differential Revision: D14436999 Pulled By: houseroad fbshipit-source-id: a31fdf13a5c84f9b156a7288e0ffa57deb23b83f --- diff --git a/test/test_jit.py b/test/test_jit.py index baab4d5..e470bd9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1404,6 +1404,33 @@ class TestJit(JitTestCase): self.assertEqual(len(list(trace.graph().inputs())), 2) FileCheck().check("mul").check("add").run(str(trace)) + def test_trace_c10_ops(self): + class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + + def forward(self, scores, bbox_deltas, im_info, anchors): + a, b = torch.ops._caffe2.GenerateProposals( + (scores), (bbox_deltas), (im_info), (anchors), + 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, + ) + return a, b + model = MyModel() + A = 4 + H = 10 + W = 8 + img_count = 3 + scores = torch.ones(img_count, A, H, W, dtype=torch.float32) + bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, + dtype=torch.float32) + bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) + im_info = torch.ones(img_count, 3, dtype=torch.float32) + anchors = torch.ones(A, 4, dtype=torch.float32) + inputs = (scores, bbox_deltas, im_info, anchors) + traced_model = torch.jit.trace(model, inputs) + self.assertEqual(traced_model(*inputs), model(*inputs)) + self.assertExportImport(traced_model.graph, (scores, bbox_deltas, im_info, anchors)) + def test_nested_inplace(self): x = torch.randn(2, 2) trace, outputs, inputs = torch.jit.get_trace_graph( diff --git a/torch/csrc/jit/register_c10_ops.cpp b/torch/csrc/jit/register_c10_ops.cpp index 6849e4c..c5ff0bf 100644 --- a/torch/csrc/jit/register_c10_ops.cpp +++ b/torch/csrc/jit/register_c10_ops.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace torch { namespace jit { @@ -19,6 +20,8 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) { const auto input_size = op.schema().arguments().size(); const auto output_size = op.schema().returns().size(); + Node* node = nullptr; + // unwrap tensor inputs from variable for (auto iter = stack.end() - input_size; iter != stack.end(); ++iter) { // TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this) @@ -31,6 +34,83 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) { } } + if (jit::tracer::isTracing()) { + auto symbol = Symbol::fromQualString(op.schema().name()); + const auto& graph = tracer::getTracingState()->graph; + node = graph->create(symbol, 0); + const auto& args = op.schema().arguments(); + int i = 0; + for (auto iter = stack.end() - input_size; iter != stack.end(); + ++iter, ++i) { + // TODO we need to refactor graph APIs (e.g., addInputs) + // appropriately; after that, we can get rid of the giant if-else + // block we will clean this tech debt together in the following PRs + auto type = args[i].type(); + if (type->kind() == TypeKind::OptionalType) { + if (iter->isNone()) { + Value* none = + graph + ->insertNode(graph->createNone( + reinterpret_cast(args[i].type().get()) + ->getElementType())) + ->output(); + node->addInput(none); + continue; + } else { + type = + reinterpret_cast(type.get())->getElementType(); + } + } + if (type->isSubclass(TypeKind::TensorType)) { + AT_ASSERT(iter->isTensor()); + tracer::addInputs(node, args[i].name().c_str(), iter->toTensor()); + } else if (type->kind() == TypeKind::FloatType) { + AT_ASSERT(iter->isDouble()); + tracer::addInputs(node, args[i].name().c_str(), iter->toDouble()); + } else if (type->kind() == TypeKind::IntType) { + AT_ASSERT(iter->isInt()); + tracer::addInputs(node, args[i].name().c_str(), iter->toInt()); + } else if (type->kind() == TypeKind::BoolType) { + AT_ASSERT(iter->isBool()); + tracer::addInputs(node, args[i].name().c_str(), iter->toBool()); + } else if (type->kind() == TypeKind::StringType) { + AT_ASSERT(iter->isString()); + tracer::addInputs( + node, args[i].name().c_str(), iter->toStringRef()); + } else if (type->kind() == TypeKind::ListType) { + const auto& elem_type = + reinterpret_cast(type.get())->getElementType(); + if (elem_type->isSubclass(TypeKind::TensorType)) { + AT_ASSERT(iter->isTensorList()); + tracer::addInputs( + node, + args[i].name().c_str(), + iter->toTensorList()->elements()); + } else if (elem_type->kind() == TypeKind::FloatType) { + AT_ASSERT(iter->isDoubleList()); + tracer::addInputs( + node, + args[i].name().c_str(), + iter->toDoubleList()->elements()); + } else if (elem_type->kind() == TypeKind::IntType) { + AT_ASSERT(iter->isIntList()); + tracer::addInputs( + node, args[i].name().c_str(), iter->toIntList()->elements()); + } else if (elem_type->kind() == TypeKind::BoolType) { + AT_ASSERT(iter->isBoolList()); + tracer::addInputs( + node, args[i].name().c_str(), iter->toBoolList()->elements()); + } else { + throw std::runtime_error( + "unsupported input list type: " + elem_type->str()); + } + } else { + throw std::runtime_error("unsupported input type: " + type->str()); + } + } + graph->insertNode(node); + } + c10::Dispatcher::singleton().lookup(op, &stack).call(&stack); // wrap tensor outputs as variable @@ -40,6 +120,30 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) { } } + if (jit::tracer::isTracing()) { + int i = 0; + for (auto iter = stack.end() - output_size; iter != stack.end(); + ++iter, ++i) { + const auto& type = op.schema().returns()[i].type(); + if (type->isSubclass(TypeKind::TensorType)) { + AT_ASSERT(iter->isTensor()); + tracer::addOutput(node, iter->toTensor()); + } else if (type->kind() == TypeKind::ListType) { + const auto& elem_type = + reinterpret_cast(type.get())->getElementType(); + if (elem_type->isSubclass(TypeKind::TensorType)) { + AT_ASSERT(iter->isTensorList()); + tracer::addOutput(node, iter->toTensorList()->elements()); + } else { + throw std::runtime_error( + "unsupported ouptut list type: " + elem_type->str()); + } + } else { + throw std::runtime_error("unsupported output type: " + type->str()); + } + } + } + return 0; }); }