self.assertEqual(len(list(trace.graph().inputs())), 2)
FileCheck().check("mul").check("add").run(str(trace))
+ def test_trace_c10_ops(self):
+ class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+
+ def forward(self, scores, bbox_deltas, im_info, anchors):
+ a, b = torch.ops._caffe2.GenerateProposals(
+ (scores), (bbox_deltas), (im_info), (anchors),
+ 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
+ )
+ return a, b
+ model = MyModel()
+ A = 4
+ H = 10
+ W = 8
+ img_count = 3
+ scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
+ bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
+ dtype=torch.float32)
+ bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
+ im_info = torch.ones(img_count, 3, dtype=torch.float32)
+ anchors = torch.ones(A, 4, dtype=torch.float32)
+ inputs = (scores, bbox_deltas, im_info, anchors)
+ traced_model = torch.jit.trace(model, inputs)
+ self.assertEqual(traced_model(*inputs), model(*inputs))
+ self.assertExportImport(traced_model.graph, (scores, bbox_deltas, im_info, anchors))
+
def test_nested_inplace(self):
x = torch.randn(2, 2)
trace, outputs, inputs = torch.jit.get_trace_graph(
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/jit/operator.h>
+#include <torch/csrc/jit/tracer.h>
namespace torch {
namespace jit {
const auto input_size = op.schema().arguments().size();
const auto output_size = op.schema().returns().size();
+ Node* node = nullptr;
+
// unwrap tensor inputs from variable
for (auto iter = stack.end() - input_size; iter != stack.end(); ++iter) {
// TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this)
}
}
+ if (jit::tracer::isTracing()) {
+ auto symbol = Symbol::fromQualString(op.schema().name());
+ const auto& graph = tracer::getTracingState()->graph;
+ node = graph->create(symbol, 0);
+ const auto& args = op.schema().arguments();
+ int i = 0;
+ for (auto iter = stack.end() - input_size; iter != stack.end();
+ ++iter, ++i) {
+ // TODO we need to refactor graph APIs (e.g., addInputs)
+ // appropriately; after that, we can get rid of the giant if-else
+ // block we will clean this tech debt together in the following PRs
+ auto type = args[i].type();
+ if (type->kind() == TypeKind::OptionalType) {
+ if (iter->isNone()) {
+ Value* none =
+ graph
+ ->insertNode(graph->createNone(
+ reinterpret_cast<OptionalType*>(args[i].type().get())
+ ->getElementType()))
+ ->output();
+ node->addInput(none);
+ continue;
+ } else {
+ type =
+ reinterpret_cast<OptionalType*>(type.get())->getElementType();
+ }
+ }
+ if (type->isSubclass(TypeKind::TensorType)) {
+ AT_ASSERT(iter->isTensor());
+ tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
+ } else if (type->kind() == TypeKind::FloatType) {
+ AT_ASSERT(iter->isDouble());
+ tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
+ } else if (type->kind() == TypeKind::IntType) {
+ AT_ASSERT(iter->isInt());
+ tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
+ } else if (type->kind() == TypeKind::BoolType) {
+ AT_ASSERT(iter->isBool());
+ tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
+ } else if (type->kind() == TypeKind::StringType) {
+ AT_ASSERT(iter->isString());
+ tracer::addInputs(
+ node, args[i].name().c_str(), iter->toStringRef());
+ } else if (type->kind() == TypeKind::ListType) {
+ const auto& elem_type =
+ reinterpret_cast<ListType*>(type.get())->getElementType();
+ if (elem_type->isSubclass(TypeKind::TensorType)) {
+ AT_ASSERT(iter->isTensorList());
+ tracer::addInputs(
+ node,
+ args[i].name().c_str(),
+ iter->toTensorList()->elements());
+ } else if (elem_type->kind() == TypeKind::FloatType) {
+ AT_ASSERT(iter->isDoubleList());
+ tracer::addInputs(
+ node,
+ args[i].name().c_str(),
+ iter->toDoubleList()->elements());
+ } else if (elem_type->kind() == TypeKind::IntType) {
+ AT_ASSERT(iter->isIntList());
+ tracer::addInputs(
+ node, args[i].name().c_str(), iter->toIntList()->elements());
+ } else if (elem_type->kind() == TypeKind::BoolType) {
+ AT_ASSERT(iter->isBoolList());
+ tracer::addInputs(
+ node, args[i].name().c_str(), iter->toBoolList()->elements());
+ } else {
+ throw std::runtime_error(
+ "unsupported input list type: " + elem_type->str());
+ }
+ } else {
+ throw std::runtime_error("unsupported input type: " + type->str());
+ }
+ }
+ graph->insertNode(node);
+ }
+
c10::Dispatcher::singleton().lookup(op, &stack).call(&stack);
// wrap tensor outputs as variable
}
}
+ if (jit::tracer::isTracing()) {
+ int i = 0;
+ for (auto iter = stack.end() - output_size; iter != stack.end();
+ ++iter, ++i) {
+ const auto& type = op.schema().returns()[i].type();
+ if (type->isSubclass(TypeKind::TensorType)) {
+ AT_ASSERT(iter->isTensor());
+ tracer::addOutput(node, iter->toTensor());
+ } else if (type->kind() == TypeKind::ListType) {
+ const auto& elem_type =
+ reinterpret_cast<ListType*>(type.get())->getElementType();
+ if (elem_type->isSubclass(TypeKind::TensorType)) {
+ AT_ASSERT(iter->isTensorList());
+ tracer::addOutput(node, iter->toTensorList()->elements());
+ } else {
+ throw std::runtime_error(
+ "unsupported ouptut list type: " + elem_type->str());
+ }
+ } else {
+ throw std::runtime_error("unsupported output type: " + type->str());
+ }
+ }
+ }
+
return 0;
});
}