1/2 Add Tracing support for C2 Ops (#17899)
author Lu Fang <lufang@fb.com>
Fri, 15 Mar 2019 18:41:31 +0000 (11:41 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Mar 2019 18:48:34 +0000 (11:48 -0700)
Summary:
The C10 ops are not registered as custom ops in PyTorch, so we have to add explicit tracing support for them as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17899

Reviewed By: dzhulgakov

Differential Revision: D14436999

Pulled By: houseroad

fbshipit-source-id: a31fdf13a5c84f9b156a7288e0ffa57deb23b83f
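
As a usage illustration (not part of this commit): with this change, tracing a module that calls a _caffe2 op records the op as a node in the traced graph instead of failing. A minimal sketch, mirroring the new test below and assuming this PyTorch build includes the Caffe2 operators:

    import torch

    class Proposals(torch.nn.Module):
        def forward(self, scores, bbox_deltas, im_info, anchors):
            # _caffe2::GenerateProposals goes through the C10 dispatcher;
            # with this patch the tracer records it as a graph node.
            rois, probs = torch.ops._caffe2.GenerateProposals(
                scores, bbox_deltas, im_info, anchors,
                2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
            )
            return rois, probs

    A, H, W, N = 4, 10, 8, 3
    scores = torch.ones(N, A, H, W)
    bbox_deltas = torch.linspace(0, 10, steps=N * 4 * A * H * W).view(N, 4 * A, H, W)
    im_info = torch.ones(N, 3)
    anchors = torch.ones(A, 4)

    traced = torch.jit.trace(Proposals(), (scores, bbox_deltas, im_info, anchors))
    print(traced.graph)  # contains a _caffe2::GenerateProposals node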

test/test_jit.py
torch/csrc/jit/register_c10_ops.cpp

index baab4d5..e470bd9 100644 (file)
@@ -1404,6 +1404,33 @@ class TestJit(JitTestCase):
         self.assertEqual(len(list(trace.graph().inputs())), 2)
         FileCheck().check("mul").check("add").run(str(trace))
 
+    def test_trace_c10_ops(self):
+        class MyModel(torch.nn.Module):
+            def __init__(self):
+                super(MyModel, self).__init__()
+
+            def forward(self, scores, bbox_deltas, im_info, anchors):
+                a, b = torch.ops._caffe2.GenerateProposals(
+                    (scores), (bbox_deltas), (im_info), (anchors),
+                    2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
+                )
+                return a, b
+        model = MyModel()
+        A = 4
+        H = 10
+        W = 8
+        img_count = 3
+        scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
+        bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
+                                     dtype=torch.float32)
+        bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
+        im_info = torch.ones(img_count, 3, dtype=torch.float32)
+        anchors = torch.ones(A, 4, dtype=torch.float32)
+        inputs = (scores, bbox_deltas, im_info, anchors)
+        traced_model = torch.jit.trace(model, inputs)
+        self.assertEqual(traced_model(*inputs), model(*inputs))
+        self.assertExportImport(traced_model.graph, (scores, bbox_deltas, im_info, anchors))
+
     def test_nested_inplace(self):
         x = torch.randn(2, 2)
         trace, outputs, inputs = torch.jit.get_trace_graph(
index 6849e4c..c5ff0bf 100644 (file)
@@ -1,5 +1,6 @@
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/tracer.h>
 
 namespace torch {
 namespace jit {
@@ -19,6 +20,8 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) {
       const auto input_size = op.schema().arguments().size();
       const auto output_size = op.schema().returns().size();
 
+      Node* node = nullptr;
+
       // unwrap tensor inputs from variable
       for (auto iter = stack.end() - input_size; iter != stack.end(); ++iter) {
         // TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this)
@@ -31,6 +34,83 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) {
         }
       }
 
+      if (jit::tracer::isTracing()) {
+        auto symbol = Symbol::fromQualString(op.schema().name());
+        const auto& graph = tracer::getTracingState()->graph;
+        node = graph->create(symbol, 0);
+        const auto& args = op.schema().arguments();
+        int i = 0;
+        for (auto iter = stack.end() - input_size; iter != stack.end();
+             ++iter, ++i) {
+          // TODO: we need to refactor the graph APIs (e.g., addInputs)
+          // appropriately; after that, we can get rid of this giant if-else
+          // block. We will clean up this tech debt in the following PRs.
+          auto type = args[i].type();
+          if (type->kind() == TypeKind::OptionalType) {
+            if (iter->isNone()) {
+              Value* none =
+                  graph
+                      ->insertNode(graph->createNone(
+                          reinterpret_cast<OptionalType*>(args[i].type().get())
+                              ->getElementType()))
+                      ->output();
+              node->addInput(none);
+              continue;
+            } else {
+              type =
+                  reinterpret_cast<OptionalType*>(type.get())->getElementType();
+            }
+          }
+          if (type->isSubclass(TypeKind::TensorType)) {
+            AT_ASSERT(iter->isTensor());
+            tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
+          } else if (type->kind() == TypeKind::FloatType) {
+            AT_ASSERT(iter->isDouble());
+            tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
+          } else if (type->kind() == TypeKind::IntType) {
+            AT_ASSERT(iter->isInt());
+            tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
+          } else if (type->kind() == TypeKind::BoolType) {
+            AT_ASSERT(iter->isBool());
+            tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
+          } else if (type->kind() == TypeKind::StringType) {
+            AT_ASSERT(iter->isString());
+            tracer::addInputs(
+                node, args[i].name().c_str(), iter->toStringRef());
+          } else if (type->kind() == TypeKind::ListType) {
+            const auto& elem_type =
+                reinterpret_cast<ListType*>(type.get())->getElementType();
+            if (elem_type->isSubclass(TypeKind::TensorType)) {
+              AT_ASSERT(iter->isTensorList());
+              tracer::addInputs(
+                  node,
+                  args[i].name().c_str(),
+                  iter->toTensorList()->elements());
+            } else if (elem_type->kind() == TypeKind::FloatType) {
+              AT_ASSERT(iter->isDoubleList());
+              tracer::addInputs(
+                  node,
+                  args[i].name().c_str(),
+                  iter->toDoubleList()->elements());
+            } else if (elem_type->kind() == TypeKind::IntType) {
+              AT_ASSERT(iter->isIntList());
+              tracer::addInputs(
+                  node, args[i].name().c_str(), iter->toIntList()->elements());
+            } else if (elem_type->kind() == TypeKind::BoolType) {
+              AT_ASSERT(iter->isBoolList());
+              tracer::addInputs(
+                  node, args[i].name().c_str(), iter->toBoolList()->elements());
+            } else {
+              throw std::runtime_error(
+                  "unsupported input list type: " + elem_type->str());
+            }
+          } else {
+            throw std::runtime_error("unsupported input type: " + type->str());
+          }
+        }
+        graph->insertNode(node);
+      }
+
       c10::Dispatcher::singleton().lookup(op, &stack).call(&stack);
 
       // wrap tensor outputs as variable
@@ -40,6 +120,30 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) {
         }
       }
 
+      if (jit::tracer::isTracing()) {
+        int i = 0;
+        for (auto iter = stack.end() - output_size; iter != stack.end();
+             ++iter, ++i) {
+          const auto& type = op.schema().returns()[i].type();
+          if (type->isSubclass(TypeKind::TensorType)) {
+            AT_ASSERT(iter->isTensor());
+            tracer::addOutput(node, iter->toTensor());
+          } else if (type->kind() == TypeKind::ListType) {
+            const auto& elem_type =
+                reinterpret_cast<ListType*>(type.get())->getElementType();
+            if (elem_type->isSubclass(TypeKind::TensorType)) {
+              AT_ASSERT(iter->isTensorList());
+              tracer::addOutput(node, iter->toTensorList()->elements());
+            } else {
+              throw std::runtime_error(
+                  "unsupported output list type: " + elem_type->str());
+            }
+          } else {
+            throw std::runtime_error("unsupported output type: " + type->str());
+          }
+        }
+      }
+
       return 0;
   });
 }