Add ability to specialize class types to ArgumentSpec (#18314)
authorZachary DeVito <zdevito@fb.com>
Wed, 3 Apr 2019 00:33:06 +0000 (17:33 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 00:35:57 +0000 (17:35 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18314
ghimport-source-id: 8cecb768d476ab19c9460f39c8f94a764e4cb052

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18314 Add ability to specialize class types to ArgumentSpec**
* #18226 Add Slot type to abstract the raw pointers being used for slots.

Differential Revision: D14574395

fbshipit-source-id: cc3af6e56e9ae52990f4a1ad56ecceaa2d493577

14 files changed:
aten/src/ATen/core/ivalue.h
aten/src/ATen/core/jit_type.h
aten/src/ATen/core/type.cpp
test/cpp/jit/test_autodiff.h
test/test_jit.py
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/argument_spec.cpp [new file with mode: 0644]
torch/csrc/jit/argument_spec.h
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/init.cpp
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/script/module.h

index 487107e..dd4ed54 100644 (file)
@@ -704,7 +704,9 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
   Symbol name() const {
     return typename_;
   }
-
+  const std::vector<IValue>& slots() const {
+    return slots_;
+  }
  private:
   const Symbol typename_;
   std::vector<IValue> slots_;
index 2acba07..be99f93 100644 (file)
@@ -1203,10 +1203,21 @@ struct CAFFE2_API ClassType : public Type {
     attributeTypes_.push_back(type);
   }
 
+  at::ArrayRef<std::string> attributeNames() const {
+    return attributeNames_;
+  }
+
   at::ArrayRef<TypePtr> containedTypes() const override {
     return attributeTypes_;
   }
 
+  // generate a refined version of this class.
+  // It has the same name but the slot Types are subtypes of
+  // the original slots. It is only valid to refine a class type in a context
+  // where it is know that there are not assignments to the objects slots
+  // that would invalidate the refinement.
+  // These variants are not registered in the global class table.
+  ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
   static const TypeKind Kind = TypeKind::ClassType;
 
  private:
index fc741e1..fe93a2c 100644 (file)
@@ -478,6 +478,16 @@ ClassTypePtr ClassType::create(
   return ptr;
 }
 
+ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
+  auto ptr = ClassTypePtr(new ClassType(typename_, module_));
+  AT_ASSERT(numAttributes() == refined_slots.size());
+  for(size_t i = 0; i < attributeNames_.size(); ++i) {
+    AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
+    ptr->addAttribute(attributeNames_[i], refined_slots[i]);
+  }
+  return ptr;
+}
+
 ClassTypePtr ClassType::get(const std::string& name) {
   return getRegistry().getType(name);
 }
index 3ec4929..00c7624 100644 (file)
@@ -208,7 +208,10 @@ void testDifferentiateWithRequiresGrad() {
       at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
   auto b_var = autograd::make_variable(
       at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
-  setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
+
+  ArgumentSpecCreator asc(*graph);
+  asc.setInputTypes(*graph, asc.create(true, {a_var, b_var}));
+
   PropagateInputShapes(graph);
   PropagateRequiresGrad(graph);
 
index 89d3592..22ce30a 100644 (file)
@@ -1893,17 +1893,18 @@ class TestJit(JitTestCase):
 
     def test_tuple_specialization(self):
         @torch.jit.script
-        def f(t):
-            # type: (Tuple[Tensor, Tensor]) -> Tensor
-            x, y = t
+        def f(t, s):
+            # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
+            x, t2 = t
+            _, y = t2
             return x + y
 
-        t = torch.randn(2, 2), torch.randn(2, 2)
-        f(t)
-        graph = f.graph_for(t)
+        t = torch.randn(2, 2), (1, torch.randn(2, 2)),
+        f(t, "hi")
+        graph = f.graph_for(t, "hi")
         input_types = list(next(graph.inputs()).type().elements())
-        for t in input_types:
-            self.assertEqual(t.kind(), 'DimensionedTensorType')
+        self.assertEqual(input_types[0].kind(), 'DimensionedTensorType')
+        self.assertEqual(input_types[1].elements()[1].kind(), 'DimensionedTensorType')
 
     def test_constant_prop_simple(self):
         @torch.jit.script
@@ -3450,13 +3451,11 @@ a")
         # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
         self.run_pass('constant_propagation', func.graph)
         self.run_pass('constant_propagation', func2.graph)
-        torch._C._jit_pass_shape_analysis(
-            func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
-        torch._C._jit_pass_shape_analysis(
-            func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
-        self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
+        g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
+        g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
+        self.assertTrue(g.findNode("aten::sum").output().type().kind()
                         == "DimensionedTensorType")
-        self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
+        self.assertTrue(g2.findNode("aten::sum").output().type().kind()
                         == "DimensionedTensorType")
 
     def test_cat(self):
@@ -4154,9 +4153,9 @@ a")
                 torch.mul(x, y, out=z)
                 return z
 
-            torch._C._jit_pass_shape_analysis(
-                test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
-            self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
+            graph = test._get_method('forward').propagate_shapes(
+                (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
+            self.assertTrue(next(graph.outputs()).type() == TensorType.get())
         out_op_graph_input()
 
         def test_resize():
@@ -4173,10 +4172,8 @@ a")
                     after_resize_alias = b.add_(1)
                 return after_resize_alias
 
-            g = test.graph
-            self.run_pass('constant_propagation', g)
-            torch._C._jit_pass_shape_analysis(
-                g, (torch.zeros(1, 1),), False)
+            self.run_pass('constant_propagation', test.graph)
+            g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
             resize_node = g.findNode("aten::resize_")
             # first input and output of b.resize_ is b
             self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
@@ -4200,8 +4197,7 @@ a")
 
             g = test.graph
             self.run_pass('constant_propagation', g)
-            torch._C._jit_pass_shape_analysis(
-                g, (torch.zeros(1, 1),), False)
+            g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
 
             # x doesn't alias a resized op so it shouldn't be set to base Tensor type
             self.assertTrue(next(g.inputs()).type() != TensorType.get())
@@ -4255,8 +4251,8 @@ a")
             return x.view(T, B, C)
 
         x = torch.randn(3, 1, 5, requires_grad=True)
-        graph = torch.jit.script(fn).graph
-        torch._C._jit_pass_shape_analysis(graph, (x,), False)
+        fn = torch.jit.script(fn)
+        graph = fn._get_method('forward').propagate_shapes((x,), False)
         a = next(graph.outputs()).type().kind()
         self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
 
@@ -6677,7 +6673,7 @@ a")
             return torch.cat(c)
 
         b = torch.zeros(2, 4)
-        test_list.graph.propagate_shapes((b,), False)
+        test_list._get_method('forward').propagate_shapes((b,), False)
 
     def test_if_supertype(self):
         @torch.jit.script
@@ -6694,8 +6690,8 @@ a")
         b = torch.zeros(2, 4, dtype=torch.long)
         c = torch.zeros(2, 4, dtype=torch.float)
 
-        tensor_unifying.graph.propagate_shapes((a, b, c), False)
-        if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
+        graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
+        if_outputs = list(graph.findNode("prim::If").outputs())
         self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
         self.assertTrue(if_outputs[1].type().str() == "Tensor")
         self.assertTrue(if_outputs[2].type().str() == "Tensor")
@@ -13303,6 +13299,30 @@ class TestClassType(JitTestCase):
         self.assertEqual(x, f2.x)
         self.assertEqual(y, f2.y)
 
+    def test_class_specialization(self):
+        @torch.jit.script  # noqa: B903
+        class Foo(object):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+        def use_foo(foo, foo2, tup):
+            # type: (Foo, Foo, Tuple[Foo, Foo]) -> Tensor
+            a, b = tup
+            return foo.x + foo2.y + a.x + b.y
+
+        # create from python
+        x = torch.ones(2, 3)
+        y = torch.zeros(2, 3)
+        f = Foo(x, y)
+        f2 = Foo(x * 2, y * 3)
+        f3 = Foo(x * 4, y * 4)
+
+        input = (f, f2, (f, f3))
+        sfoo = self.checkScript(use_foo, input)
+        graphstr = str(sfoo.graph_for(*input))
+        FileCheck().check_count("Double(*, *) = prim::GetAttr", 4).run(graphstr)
+
 
 class TestLogging(JitTestCase):
     def test_bump_numeric_counter(self):
index 2236172..89a5ed8 100644 (file)
@@ -51,6 +51,7 @@ libtorch_sources = [
     "torch/csrc/Exceptions.cpp",
     "torch/csrc/jit/autodiff.cpp",
     "torch/csrc/jit/attributes.cpp",
+    "torch/csrc/jit/argument_spec.cpp",
     "torch/csrc/jit/constants.cpp",
     "torch/csrc/jit/node_hashing.cpp",
     "torch/csrc/jit/export.cpp",
index 6cd237f..60f883a 100644 (file)
@@ -123,6 +123,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
   ${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
   ${TORCH_SRC_DIR}/csrc/jit/attributes.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp
   ${TORCH_SRC_DIR}/csrc/jit/export.cpp
   ${TORCH_SRC_DIR}/csrc/jit/pickler.cpp
   ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp
diff --git a/torch/csrc/jit/argument_spec.cpp b/torch/csrc/jit/argument_spec.cpp
new file mode 100644 (file)
index 0000000..a31a91e
--- /dev/null
@@ -0,0 +1,229 @@
+
+#include <torch/csrc/jit/argument_spec.h>
+
+namespace torch {
+namespace jit {
+
+void ArgumentSpecCreator::scan(
+    const TypePtr& typ,
+    size_t depth,
+    const WrittenSlots& written_slots) {
+  auto finishAggregate = [&](size_t pos) {
+    // it is possible after all the work we did to scan this aggregate,
+    // we found no tensors to specialize. In this case, just generate
+    // a skip for the whole aggregate.
+    bool any_spec = std::any_of(
+        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
+          return i == SPECIALIZE_TENSOR;
+        });
+    if (!any_spec) {
+      instructions_[pos] = SKIP;
+      instructions_.resize(pos + 1);
+    } else {
+      instructions_.emplace_back(LEAVE);
+    }
+  };
+  // the simple vm that scans instructions_ has a limited stack depth,
+  // this prevents going deeper than that.
+  if (depth >= DEPTH_LIMIT) {
+    instructions_.emplace_back(SKIP);
+  }
+  if (typ->isSubtypeOf(TensorType::get())) {
+    num_tensors_++;
+    instructions_.emplace_back(SPECIALIZE_TENSOR);
+  } else if (auto tup = typ->cast<TupleType>()) {
+    size_t pos = instructions_.size();
+    instructions_.emplace_back(ENTER_TUPLE);
+    for (const auto& elem : tup->containedTypes()) {
+      scan(elem, depth + 1, written_slots);
+    }
+    finishAggregate(pos);
+  } else if (auto cls = typ->cast<ClassType>()) {
+    size_t pos = instructions_.size();
+    instructions_.emplace_back(ENTER_OBJECT);
+    for (size_t i = 0; i < cls->numAttributes(); ++i) {
+      auto key = cls->name() + cls->attributeNames().at(i);
+      // it is only safe to specialize because someone might have written to it
+      if (!written_slots.count(key)) {
+        scan(cls->containedTypes().at(i), depth + 1, written_slots);
+      } else {
+        instructions_.emplace_back(SKIP);
+      }
+    }
+    finishAggregate(pos);
+  } else {
+    instructions_.emplace_back(SKIP);
+  }
+};
+
+// this is a coarse-grained guarentee that the slots of a class will not be
+// modified by the function. It works fine for things that used be read-only
+// modules, but will be overly conservative when some classes are written to.
+// Doing alias analysis and looking for writes to the class would be more
+// accurate.
+static void scanWrittenSlots(
+    Block* block,
+    ArgumentSpecCreator::WrittenSlots& written_slots) {
+  for (Node* n : block->nodes()) {
+    if (n->kind() == prim::SetAttr) {
+      if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
+        written_slots.insert(cls->name() + n->s(attr::name));
+      }
+    }
+    for (Block* subblock : n->blocks()) {
+      scanWrittenSlots(subblock, written_slots);
+    }
+    if (n->hasAttribute(attr::Subgraph)) {
+      scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
+    }
+  }
+}
+
+ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
+    : num_inputs_(graph.inputs().size()) {
+  WrittenSlots written_slots;
+  scanWrittenSlots(graph.block(), written_slots);
+  for (Value* input : graph.inputs()) {
+    scan(input->type(), 0, written_slots);
+  }
+}
+
+void ArgumentSpecCreator::dump() const {
+  for (Inst inst : instructions_) {
+    switch (inst) {
+      case LEAVE:
+        std::cout << "] ";
+        break;
+      case ENTER_TUPLE:
+        std::cout << "Tuple[";
+        break;
+      case ENTER_OBJECT:
+        std::cout << "Object[";
+        break;
+      case SKIP:
+        std::cout << "Skip ";
+        break;
+      case SPECIALIZE_TENSOR:
+        std::cout << "SpecializeTensor ";
+        break;
+    }
+  }
+  std::cout << "\n";
+}
+
+ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
+    const {
+  ArgumentSpec spec(num_tensors_);
+  const IValue* stack[DEPTH_LIMIT]; // The stack of IValue lists
+  // The stack gets initialized with the input list
+  stack[0] = last(input, num_inputs_).begin();
+  size_t stack_top = 0; // offset to the top of the stack
+  for (Inst inst : instructions_) {
+    switch (inst) {
+      case SPECIALIZE_TENSOR:
+        // consume a tensor and add to the argspec
+        spec.addTensor(*stack[stack_top]++, with_grad);
+        break;
+      case ENTER_TUPLE: {
+        // consume tuple
+        const IValue* iv = stack[stack_top]++;
+        AT_ASSERT(iv->isTuple());
+        // see [argspec refcounting]
+        auto p = *reinterpret_cast<const at::ivalue::Tuple* const*>(iv);
+        auto tup_ptr = &p->elements()[0];
+        // push list of tuple elements to the stack
+        stack[++stack_top] = tup_ptr;
+      } break;
+      case ENTER_OBJECT: {
+        // consume object
+        const IValue* iv = stack[stack_top]++;
+        AT_ASSERT(iv->isObject());
+        iv->toObject();
+        // see [argspec refcounting]
+        auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
+        auto obj_ptr = &p->slots()[0];
+        // push list of object elements to the stack
+        stack[++stack_top] = obj_ptr;
+      } break;
+      case SKIP:
+        // consume and skip an element
+        stack[stack_top]++;
+        break;
+      case LEAVE:
+        --stack_top;
+        break;
+    }
+  }
+  return spec;
+}
+
+// For every input of a given graph, returns a most detailed type that can be
+// inferred for it based on this ArgumentSpec.
+std::vector<TypePtr> ArgumentSpecCreator::getSpecializedTypes(
+    Graph& graph,
+    const ArgumentSpec& spec) const {
+  auto input_types =
+      fmap(graph.inputs(), [](Value* input) { return input->type(); });
+  std::vector<std::vector<TypePtr>> result_stack;
+  result_stack.emplace_back();
+  std::vector<const TypePtr*> input_stack = {input_types.data()};
+  std::vector<std::function<TypePtr()>> aggregate_creators;
+
+  size_t arg_spec_offset = 0; // number of specialized tensors seen so far
+
+  for (Inst inst : instructions_) {
+    switch (inst) {
+      case SPECIALIZE_TENSOR: {
+        input_stack.back()++;
+        auto& arg = spec.at(arg_spec_offset++);
+        if (!arg.defined()) {
+          result_stack.back().emplace_back(AutogradZeroTensorType::get());
+        } else {
+          result_stack.back().emplace_back(DimensionedTensorType::create(
+              arg.type(),
+              ConvertIntToCPUOrCUDA(arg.device()),
+              arg.dim(),
+              arg.requires_grad()));
+        }
+      } break;
+      case ENTER_TUPLE: {
+        auto tup = (*input_stack.back()++)->expect<TupleType>();
+        input_stack.emplace_back(tup->elements().data());
+        result_stack.emplace_back();
+        aggregate_creators.emplace_back(
+            [&] { return TupleType::create(result_stack.back()); });
+      } break;
+      case ENTER_OBJECT: {
+        auto cls = (*input_stack.back()++)->expect<ClassType>();
+        input_stack.emplace_back(cls->containedTypes().data());
+        result_stack.emplace_back();
+        aggregate_creators.emplace_back(
+            [&result_stack, cls] { return cls->refine(result_stack.back()); });
+      } break;
+      case SKIP:
+        result_stack.back().emplace_back(*input_stack.back()++);
+        break;
+      case LEAVE:
+        TypePtr result = aggregate_creators.back()();
+        result_stack.pop_back();
+        aggregate_creators.pop_back();
+        input_stack.pop_back();
+        result_stack.back().emplace_back(std::move(result));
+        break;
+    }
+  }
+  AT_ASSERT(result_stack.size() == 1);
+  return result_stack.back();
+}
+
+void ArgumentSpecCreator::setInputTypes(Graph& g, const ArgumentSpec& spec)
+    const {
+  auto input_types = getSpecializedTypes(g, spec);
+  auto inputs = g.inputs();
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    inputs[i]->setType(input_types[i]);
+  }
+}
+
+} // namespace jit
+} // namespace torch
index a345c6b..ca88836 100644 (file)
@@ -1,9 +1,9 @@
 #pragma once
 
+#include <ATen/core/jit_type.h>
+#include <ATen/core/stack.h>
 #include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/ir.h>
-#include <ATen/core/stack.h>
-#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/variable_tensor_list.h>
 #include <torch/csrc/utils/hash.h>
 #include <iostream>
@@ -22,9 +22,6 @@ struct ArgumentInfo {
   friend struct ArgumentSpec;
   using plain_data_type = uint32_t;
 
-  bool isTensor() const {
-    return is_tensor_;
-  }
   bool defined() const {
     return defined_;
   }
@@ -45,11 +42,11 @@ struct ArgumentInfo {
   operator TypePtr() const {
     if (!defined())
       return TensorType::get();
-    return DimensionedTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim());
+    return DimensionedTensorType::create(
+        type(), ConvertIntToCPUOrCUDA(device()), dim());
   }
 
  private:
-  unsigned is_tensor_ : 1;
   unsigned defined_ : 1;
   unsigned requires_grad_ : 1;
   unsigned : 5;
@@ -67,48 +64,32 @@ static_assert(
     "ArgumentInfo is expected to be a 32-bit struct");
 
 struct ArgumentSpec {
-  ArgumentSpec(
-      bool with_grad,
-      at::ArrayRef<IValue> inputs,
-      size_t num_flat_inputs) {
+  ArgumentSpec(size_t num_flat_inputs) {
     hash_code = num_flat_inputs;
-    args.resize(num_flat_inputs);
-    size_t offset = 0;
-    for (const auto& i : inputs) {
-      addInput(i, offset, with_grad);
-    }
-    AT_ASSERT(offset <= num_flat_inputs);
+    args.reserve(num_flat_inputs);
   }
 
-  void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto& arg = args.at(offset);
+  void addTensor(const IValue& input, bool with_grad) {
+    AT_ASSERT(input.isTensor());
+    args.emplace_back();
+    auto& arg = args.back();
     // Initialize all fields to 0. This is convenient, because e.g.
     // requires_grad() can be checked even on tensors AND will make
     // padding bits all 0s.
     std::memset(&arg, 0, sizeof(ArgumentInfo));
 
-    if (input.isTensor()) {
-      at::Tensor t = input.toTensor();
-      if ((arg.defined_ = t.defined())) {
-        arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
-        arg.dim_ = t.dim();
-        arg.device_ = t.is_cuda() ? t.get_device() : -1;
-        arg.type_ = static_cast<unsigned>(t.scalar_type());
-      }
-
-      arg.is_tensor_ = true;
-      combineHash(arg);
-      offset++;
-    } else if (input.isTuple()) {
-      for (const IValue& elem : input.toTuple()->elements()) {
-        addInput(elem, offset, with_grad);
-      }
-    } else {
-      // NB: no need to set is_tensor to false, because we memset the struct to
-      // 0 above
-      combineHash(arg);
-      offset++;
+    // [argspec refcounting] reinterpret the IValue to avoid having to refcount
+    // the Tensor microbenchmarks
+    // https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10
+    // show overhead in extra refcounting along this path
+    const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
+    if ((arg.defined_ = t->defined())) {
+      arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
+      arg.dim_ = t->dim();
+      arg.device_ = t->is_cuda() ? t->get_device() : -1;
+      arg.type_ = static_cast<unsigned>(t->scalar_type());
     }
+    combineHash(arg);
   }
 
   void combineHash(const ArgumentInfo& arg) {
@@ -143,38 +124,49 @@ struct ArgumentSpec {
   size_t hashCode() const {
     return hash_code;
   }
-  // For every input of a given graph, returns a most detailed type that can be
-  // inferred for it based on this ArgumentSpec.
-  std::vector<TypePtr> getTypes(Graph& graph) const {
-    size_t offset = 0;
-    return fmap(
-        graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); });
-  }
 
  private:
-  TypePtr fillType(TypePtr original, size_t& offset) const {
-    if (original->isSubtypeOf(TensorType::get())) {
-      auto& arg = args.at(offset++);
-      if (!arg.defined())
-        return AutogradZeroTensorType::get();
-      return DimensionedTensorType::create(
-          arg.type(),
-          ConvertIntToCPUOrCUDA(arg.device()),
-          arg.dim(),
-          arg.requires_grad());
-    } else if (auto tuple_type = original->cast<TupleType>()) {
-      return TupleType::create(fmap(
-          tuple_type->elements(),
-          [&](const TypePtr& subtype) { return fillType(subtype, offset); }));
-    } else {
-      offset++;
-      return original;
-    }
-  }
   size_t hash_code; // precomputed on construction
   std::vector<ArgumentInfo> args;
 };
 
+// ArgumentSpecCreator takes an initial graph and comes up with a set
+// of simple instructions to compute the ArgumentSpec given a set of
+// input tensors.
+struct ArgumentSpecCreator {
+  // instructs acts on a stack of a list of input IValues
+  // at the beginning the stack contains a single list of the inputs to the
+  // function the ENTER_ instructs descend into subobjects and push new lists
+  // onto the stack
+  enum Inst : char {
+    ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
+                 // list of its elements onto the stack as a new list
+    ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
+    LEAVE, // pop the top-most list from the stack
+    SKIP, // consume an element from the top-most list, and discard
+    SPECIALIZE_TENSOR, // consume a tensor for the top-most list, and
+                       // add it to the ArgSpec key being created
+  };
+  ArgumentSpecCreator(Graph& graph);
+  ArgumentSpec create(bool with_grad, const Stack& stack) const;
+  void setInputTypes(Graph& g, const ArgumentSpec& spec) const;
+  std::vector<TypePtr> getSpecializedTypes(
+      Graph& graph,
+      const ArgumentSpec& spec) const;
+  void dump() const;
+  using WrittenSlots = std::unordered_set<std::string>;
+
+ private:
+  static constexpr size_t DEPTH_LIMIT = 128;
+  void scan(
+      const TypePtr& typ,
+      size_t depth,
+      const WrittenSlots& written_slots);
+  size_t num_inputs_;
+  size_t num_tensors_ = 0;
+  std::vector<Inst> instructions_;
+};
+
 // CompleteArgumentSpec represents one particular specialization.
 // It is designed so that it can be created, hashed, and compared quickly
 // since it is used along the hot-path of the JIT to check if the code
@@ -398,14 +390,6 @@ inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
   return CompleteArgumentInfo(*this, i);
 }
 
-inline void setInputTypes(Graph& g, const ArgumentSpec& spec) {
-  auto input_types = spec.getTypes(g);
-  auto inputs = g.inputs();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    inputs[i]->setType(input_types[i]);
-  }
-}
-
 } // namespace jit
 } // namespace torch
 
index 1e993da..4d3f848 100644 (file)
@@ -322,28 +322,6 @@ struct GraphExecutorImpl {
     return copy;
   }
 
-  static size_t countFlatInputs(const TypePtr& ptr) {
-    if (auto optional_type = ptr->cast<OptionalType>()) {
-      return countFlatInputs(optional_type->getElementType());
-    }
-    if (auto tuple_type = ptr->cast<TupleType>()) {
-      size_t total = 0;
-      for (auto& elem : tuple_type->elements()) {
-        total += countFlatInputs(elem);
-      }
-      return total;
-    }
-    return 1;
-  }
-
-  static size_t countFlatInputs(const std::shared_ptr<Graph>& graph) {
-    size_t total = 0;
-    for (Value* input : graph->inputs()) {
-      total += countFlatInputs(input->type());
-    }
-    return total;
-  }
-
   inline bool hasMutableOperators(Block* block) {
     for (auto n : block->nodes()) {
       if (n->kind().is_aten() && n->schema().is_mutable())
@@ -362,11 +340,11 @@ struct GraphExecutorImpl {
         // disables all optimization
         optimize(optimize),
         num_inputs(this->graph->inputs().size()),
-        num_flat_inputs(countFlatInputs(graph)),
+        arg_spec_creator_(*graph),
         num_outputs(this->graph->outputs().size()) {
-    logging::getLogger()->addStatValue(
-        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
-  }
+          logging::getLogger()->addStatValue(
+              logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+        }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -391,9 +369,9 @@ struct GraphExecutorImpl {
 
   std::shared_ptr<Graph> graphFor(const Stack& stack) const {
     AT_ASSERT(stack.size() >= num_inputs);
-    auto inputs = last(stack, num_inputs);
-    ArgumentSpec spec(
-        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+
+    ArgumentSpec spec =
+        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
 
     if (!optimize) {
       AT_CHECK(fallback, "No graph found for given inputs");
@@ -441,10 +419,8 @@ struct GraphExecutorImpl {
   const ExecutionPlan& getOrCompile(const Stack& stack) {
     // outside lock guard, to minimize the time holding the lock on the fast
     // path ArgumentSpec even computes its hashCode here.
-    ArgumentSpec spec(
-        autograd::GradMode::is_enabled(),
-        last(stack, num_inputs),
-        num_flat_inputs);
+    ArgumentSpec spec =
+        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
     {
       std::lock_guard<std::mutex> lock(compile_mutex);
       auto it = plan_cache.find(spec);
@@ -463,7 +439,7 @@ struct GraphExecutorImpl {
 
   ExecutionPlan compileSpec(const ArgumentSpec& spec) {
     auto opt_graph = graph->copy();
-    setInputTypes(*opt_graph, spec);
+    arg_spec_creator_.setInputTypes(*opt_graph, spec);
 
     // Phase 1. Specialize to input definedness (this is very important for
     //          gradient graphs), and run required passes to bring the graph
@@ -562,8 +538,8 @@ struct GraphExecutorImpl {
     auto input_values = fmap(
         inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });
 
-    ArgumentSpec spec(
-        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+    ArgumentSpec spec =
+        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
     // NB: we could just run the fallback in here and call it a day, but that
     // would loose all the control flow information we have in the graph. Thus,
     // we run the fallback to get the correct output values, but we will
@@ -580,7 +556,7 @@ struct GraphExecutorImpl {
     // tracing and so we only do the type propgation if no concrete types have
     // been set.
     auto local_graph = this->graph->copy();
-    setInputTypes(*local_graph, spec);
+    arg_spec_creator_.setInputTypes(*local_graph, spec);
     PropagateInputShapes(local_graph);
     auto output_values =
         inlineCallTo(*state->graph, *local_graph, input_values);
@@ -600,8 +576,7 @@ struct GraphExecutorImpl {
   // Useful for debugging.
   const bool optimize;
   const size_t num_inputs;
-  const size_t num_flat_inputs; // Number of inputs, assuming all tuples would
-                                // be flattened.
+  ArgumentSpecCreator arg_spec_creator_;
   const size_t num_outputs;
 
   // Populated only when optimize is false (and in that case plan_cache will be
index 6ee8d9a..e693a15 100644 (file)
@@ -156,16 +156,6 @@ void initJITBindings(PyObject* module) {
           [](const std::shared_ptr<Graph>& g) { return Canonicalize(g); })
       .def("_jit_pass_lint", LintGraph)
       .def(
-          "_jit_pass_shape_analysis",
-          [](std::shared_ptr<Graph> graph,
-             std::vector<at::Tensor> inputs,
-             bool with_grad) {
-            setInputTypes(
-                *graph,
-                ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
-            PropagateInputShapes(graph);
-          })
-      .def(
           "_jit_pass_complete_shape_analysis",
           [](std::shared_ptr<Graph> graph, py::tuple inputs, bool with_grad) {
             CompleteArgumentSpec spec(
index bd162cf..5f63188 100644 (file)
@@ -528,6 +528,12 @@ class ShapePropagator {
         setUnshapedType(node);
         return;
       }
+      case prim::GetAttr: {
+        auto cls = node->input()->type()->expect<ClassType>();
+        // propagate any type specializations encoded in the type of the class
+        node->output()->setType(cls->getAttribute(node->s(attr::name)));
+        return;
+      }
       case aten::_unwrap_optional: {
         auto input_ivalue = toIValue(node->input());
         if (input_ivalue && input_ivalue->isNone()) {
@@ -997,11 +1003,9 @@ class ShapePropagator {
     };
 
     // Requirements:
-    //   dims           : 0 if dim is None, otherwise preserved if keepdim == false or 1 smaller otherwise
-    //   scalar type    : preserved
-    //   device         : preserved
-    //   tensor inputs  : 1
-    //   tensor outputs : 1
+    //   dims           : 0 if dim is None, otherwise preserved if keepdim ==
+    //   false or 1 smaller otherwise scalar type    : preserved device :
+    //   preserved tensor inputs  : 1 tensor outputs : 1
     // Additionally:
     //   - First input should be the only tensor input
     //   - Has a bool keepdim argument
@@ -1094,7 +1098,9 @@ class ShapePropagator {
         [](Node* node) -> type_vec_t {
           if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
             return multidim_reduce_with_postprocess(
-                node, /*num_reduced_dim=*/dim->size(), /*upcast_integer=*/false);
+                node,
+                /*num_reduced_dim=*/dim->size(),
+                /*upcast_integer=*/false);
           }
           return {};
         }};
index a6a900b..4ffaa94 100644 (file)
@@ -209,16 +209,6 @@ void initPythonIRBindings(PyObject* module_) {
             db.dump();
           })
       .def(
-          "propagate_shapes",
-          [](std::shared_ptr<Graph> g,
-             std::vector<at::Tensor> inputs,
-             bool with_grad) {
-            setInputTypes(
-                *g,
-                ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
-            PropagateInputShapes(g);
-          })
-      .def(
           "_export_onnx",
           [](const std::shared_ptr<Graph> g,
              const std::map<std::string, at::Tensor>& initializers,
index d4417a7..ba99b2c 100644 (file)
@@ -1,15 +1,14 @@
 #pragma once
-#include <torch/csrc/autograd/variable.h>
+#include <c10/util/Exception.h>
 #include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/jit/argument_spec.h>
-#include <c10/util/Exception.h>
 #include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/named_value.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
-#include <torch/csrc/jit/source_range.h>
 #include <torch/csrc/jit/script/slot.h>
-
+#include <torch/csrc/jit/source_range.h>
 
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/api/include/torch/ordered_dict.h>
@@ -51,8 +50,8 @@ using ExtraFilesMap = std::unordered_map<std::string, std::string>;
 
 struct Module;
 
-using ModuleLookup = std::function<std::shared_ptr<Module>(
-    const std::vector<std::string>&)>;
+using ModuleLookup =
+    std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
 
 struct Method {
   Method(
@@ -137,6 +136,14 @@ struct Method {
     return graph()->addInput()->setType(type);
   }
 
+  static void setInputTensorTypes(Graph& g, const Stack& stack) {
+    AT_ASSERT(stack.size() == g.inputs().size());
+    for (size_t i = 0; i < stack.size(); ++i) {
+      g.inputs().at(i)->setType(
+          DimensionedTensorType::create(stack.at(i).toTensor()));
+    }
+  }
+
   std::shared_ptr<Graph> propagate_shapes(
       std::vector<at::Tensor> inputs,
       bool with_grad = false) {
@@ -149,8 +156,7 @@ struct Method {
     for (const Slot& inp : initial_ivalues_) {
       stack.push_back(*inp);
     }
-    const auto size = stack.size();
-    setInputTypes(*retval, ArgumentSpec(with_grad, stack, size));
+    setInputTensorTypes(*retval, stack);
     PropagateInputShapes(retval);
     return retval;
   }
@@ -167,9 +173,7 @@ struct Method {
       }
     }
     if (propagate) {
-      setInputTypes(
-          *retval,
-          ArgumentSpec(with_grad, fmap<IValue>(inputs), inputs.size()));
+      setInputTensorTypes(*retval, fmap<IValue>(inputs));
       PropagateInputShapes(retval);
     }
     AT_ASSERT(retval->inputs().size() == inputs.size());
@@ -288,16 +292,16 @@ struct Method {
       if (pos < inputs.size()) {
         if (!isSubvalueOf(inputs[pos], argument.type())) {
           AT_ERROR(
-            "Expected value of type ",
-            *argument.type(),
-            " for argument '",
-            argument.name(),
-            "' in position ",
-            pos,
-            ", but instead got value of type ",
-            attemptToRecoverType(inputs[pos])->str(),
-            ". Declaration: ",
-            schema);
+              "Expected value of type ",
+              *argument.type(),
+              " for argument '",
+              argument.name(),
+              "' in position ",
+              pos,
+              ", but instead got value of type ",
+              attemptToRecoverType(inputs[pos])->str(),
+              ". Declaration: ",
+              schema);
         }
       } else if (argument.default_value()) {
         inputs.push_back(*argument.default_value());
@@ -375,7 +379,8 @@ struct NamedIValue {
   const TypePtr& type() const {
     return type_;
   }
-private:
+
+ private:
   const std::string name_;
   const TypePtr type_;
   std::unique_ptr<IValue> ivalue_;
@@ -497,12 +502,10 @@ struct Module {
   const torch::OrderedDict<std::string, NamedModule>& get_modules() const {
     return modules;
   }
-  const torch::OrderedDict<std::string, NamedIValue>& get_parameters()
-      const {
+  const torch::OrderedDict<std::string, NamedIValue>& get_parameters() const {
     return parameters;
   }
-  const torch::OrderedDict<std::string, NamedIValue>& get_attributes()
-      const {
+  const torch::OrderedDict<std::string, NamedIValue>& get_attributes() const {
     return attributes;
   }
   const torch::OrderedDict<std::string, std::unique_ptr<Method>>& get_methods()
@@ -630,9 +633,7 @@ struct Module {
       if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
-      curr->register_buffer(
-          kv.key(),
-          kv.value().slot()->toTensor());
+      curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
       parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
     }
     for (auto& kv : modules) {