From: Zachary DeVito
Date: Wed, 3 Apr 2019 00:33:06 +0000 (-0700)
Subject: Add ability to specialize class types to ArgumentSpec (#18314)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~466
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2d07993bcbed8aec20d09c9192d30c4eefeb1507;p=platform%2Fupstream%2Fpytorch.git

Add ability to specialize class types to ArgumentSpec (#18314)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18314
ghimport-source-id: 8cecb768d476ab19c9460f39c8f94a764e4cb052

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18314 Add ability to specialize class types to ArgumentSpec**
* #18226 Add Slot type to abstract the raw pointers being used for slots.

Differential Revision: D14574395

fbshipit-source-id: cc3af6e56e9ae52990f4a1ad56ecceaa2d493577
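
The new call pattern, condensed from the call-site updates below (test/cpp/jit/test_autodiff.h
and graph_executor.cpp); `graph` (a std::shared_ptr<Graph>) and `stack` are assumed to already
be in scope:

    // Scan the graph's input types once, building the instruction list that
    // drives ArgumentSpec construction.
    ArgumentSpecCreator arg_spec_creator(*graph);
    // Walk those instructions over the runtime inputs to build the cache key.
    ArgumentSpec spec =
        arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
    // Write the specialized types (including refined class types) back onto
    // the graph inputs before running shape propagation.
    arg_spec_creator.setInputTypes(*graph, spec);
    PropagateInputShapes(graph);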
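
To illustrate the instruction stream, here is a hand-traced (so approximate) scan of the new
test_tuple_specialization signature, (Tuple[Tensor, Tuple[int, Tensor]], str):

    ENTER_TUPLE          // descend into the outer tuple
      SPECIALIZE_TENSOR  // first element: Tensor
      ENTER_TUPLE        // second element: nested tuple
        SKIP             // int is not specialized
        SPECIALIZE_TENSOR
      LEAVE
    LEAVE
    SKIP                 // the trailing str argument

which dump() would print roughly as "Tuple[SpecializeTensor Tuple[Skip SpecializeTensor ] ] Skip".
Class-typed inputs go through ENTER_OBJECT in the same way, except that slots recorded in
WrittenSlots (targets of prim::SetAttr anywhere in the graph) are skipped rather than specialized.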
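
On the class side, ClassType::refine() produces an unregistered copy of a class whose slot types
have been narrowed, and prim::GetAttr then simply propagates the refined slot type during shape
analysis. A minimal sketch, assuming a ClassTypePtr `cls` with two Tensor slots is already in hand
(the variable names are illustrative, not part of the patch; the create() arguments mirror the
three-argument form used in argument_spec.h):

    // Refine both Tensor slots to the 2-D float CPU tensors observed at runtime.
    TypePtr two_d = DimensionedTensorType::create(
        at::kFloat, at::Device(at::DeviceType::CPU), /*dim=*/2);
    ClassTypePtr refined = cls->refine({two_d, two_d});
    // prim::GetAttr on a value of this refined type now propagates the
    // DimensionedTensorType of the slot instead of plain Tensor.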

---

diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 487107e..dd4ed54 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -704,7 +704,9 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
   Symbol name() const {
     return typename_;
   }
-
+  const std::vector<IValue>& slots() const {
+    return slots_;
+  }
  private:
   const Symbol typename_;
   std::vector<IValue> slots_;
diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 2acba07..be99f93 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -1203,10 +1203,21 @@ struct CAFFE2_API ClassType : public Type {
     attributeTypes_.push_back(type);
   }
 
+  at::ArrayRef<std::string> attributeNames() const {
+    return attributeNames_;
+  }
+
   at::ArrayRef<TypePtr> containedTypes() const override {
     return attributeTypes_;
   }
+  // Generate a refined version of this class.
+  // It has the same name, but the slot types are subtypes of
+  // the original slots. It is only valid to refine a class type in a context
+  // where it is known that there are no assignments to the object's slots
+  // that would invalidate the refinement.
+  // These variants are not registered in the global class table.
+  ClassTypePtr refine(at::ArrayRef<TypePtr> refined_slots) const;
 
   static const TypeKind Kind = TypeKind::ClassType;
 
  private:
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index fc741e1..fe93a2c 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -478,6 +478,16 @@ ClassTypePtr ClassType::create(
   return ptr;
 }
 
+ClassTypePtr ClassType::refine(at::ArrayRef<TypePtr> refined_slots) const {
+  auto ptr = ClassTypePtr(new ClassType(typename_, module_));
+  AT_ASSERT(numAttributes() == refined_slots.size());
+  for(size_t i = 0; i < attributeNames_.size(); ++i) {
+    AT_ASSERT(refined_slots[i]->isSubtypeOf(attributeTypes_[i]));
+    ptr->addAttribute(attributeNames_[i], refined_slots[i]);
+  }
+  return ptr;
+}
+
 ClassTypePtr ClassType::get(const std::string& name) {
   return getRegistry().getType(name);
 }
diff --git a/test/cpp/jit/test_autodiff.h b/test/cpp/jit/test_autodiff.h
index 3ec4929..00c7624 100644
--- a/test/cpp/jit/test_autodiff.h
+++ b/test/cpp/jit/test_autodiff.h
@@ -208,7 +208,10 @@ void testDifferentiateWithRequiresGrad() {
       at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
   auto b_var = autograd::make_variable(
       at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);
-  setInputTypes(*graph, ArgumentSpec(true, {a_var, b_var}, 2));
+
+  ArgumentSpecCreator asc(*graph);
+  asc.setInputTypes(*graph, asc.create(true, {a_var, b_var}));
+
   PropagateInputShapes(graph);
   PropagateRequiresGrad(graph);
diff --git a/test/test_jit.py b/test/test_jit.py
index 89d3592..22ce30a 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1893,17 +1893,18 @@ class TestJit(JitTestCase):
 
     def test_tuple_specialization(self):
         @torch.jit.script
-        def f(t):
-            # type: (Tuple[Tensor, Tensor]) -> Tensor
-            x, y = t
+        def f(t, s):
+            # type: (Tuple[Tensor, Tuple[int, Tensor]], str) -> Tensor
+            x, t2 = t
+            _, y = t2
             return x + y
 
-        t = torch.randn(2, 2), torch.randn(2, 2)
-        f(t)
-        graph = f.graph_for(t)
+        t = torch.randn(2, 2), (1, torch.randn(2, 2)),
+        f(t, "hi")
+        graph = f.graph_for(t, "hi")
         input_types = list(next(graph.inputs()).type().elements())
-        for t in input_types:
-            self.assertEqual(t.kind(), 'DimensionedTensorType')
+        self.assertEqual(input_types[0].kind(), 'DimensionedTensorType')
+        self.assertEqual(input_types[1].elements()[1].kind(), 'DimensionedTensorType')
 
     def test_constant_prop_simple(self):
         @torch.jit.script
@@ -3450,13 +3451,11 @@ a")
         # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
         self.run_pass('constant_propagation', func.graph)
         self.run_pass('constant_propagation', func2.graph)
-        torch._C._jit_pass_shape_analysis(
-            func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
-        torch._C._jit_pass_shape_analysis(
-            func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
-        self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
+        g = func._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
+        g2 = func2._get_method('forward').propagate_shapes((torch.zeros(1, 1, 1, 1, 4),), False)
+        self.assertTrue(g.findNode("aten::sum").output().type().kind()
                         == "DimensionedTensorType")
-        self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
+        self.assertTrue(g2.findNode("aten::sum").output().type().kind()
                         == "DimensionedTensorType")
 
     def test_cat(self):
@@ -4154,9 +4153,9 @@ a")
                 torch.mul(x, y, out=z)
                 return z
 
-            torch._C._jit_pass_shape_analysis(
-                test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
-            self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
+            graph = test._get_method('forward').propagate_shapes(
+                (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
+            self.assertTrue(next(graph.outputs()).type() == TensorType.get())
         out_op_graph_input()

         def test_resize():
@@ -4173,10 +4172,8 @@ a")
                 after_resize_alias = b.add_(1)
                 return after_resize_alias
 
-            g = test.graph
-            self.run_pass('constant_propagation', g)
-            torch._C._jit_pass_shape_analysis(
-                g, (torch.zeros(1, 1),), False)
+            self.run_pass('constant_propagation', test.graph)
+            g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
             resize_node = g.findNode("aten::resize_")
             # first input and output of b.resize_ is b
             self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
@@ -4200,8 +4197,7 @@ a")
             g = test.graph
             self.run_pass('constant_propagation', g)
-            torch._C._jit_pass_shape_analysis(
-                g, (torch.zeros(1, 1),), False)
+            g = test._get_method('forward').propagate_shapes((torch.zeros(1, 1),), False)
             # x doesn't alias a resized op so it shouldn't be set to base Tensor type
             self.assertTrue(next(g.inputs()).type() != TensorType.get())
@@ -4255,8 +4251,8 @@ a")
             return x.view(T, B, C)
 
         x = torch.randn(3, 1, 5, requires_grad=True)
-        graph = torch.jit.script(fn).graph
-        torch._C._jit_pass_shape_analysis(graph, (x,), False)
+        fn = torch.jit.script(fn)
+        graph = fn._get_method('forward').propagate_shapes((x,), False)
         a = next(graph.outputs()).type().kind()
         self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
@@ -6677,7 +6673,7 @@ a")
             return torch.cat(c)
 
         b = torch.zeros(2, 4)
-        test_list.graph.propagate_shapes((b,), False)
+        test_list._get_method('forward').propagate_shapes((b,), False)
 
     def test_if_supertype(self):
         @torch.jit.script
@@ -6694,8 +6690,8 @@ a")
         b = torch.zeros(2, 4, dtype=torch.long)
         c = torch.zeros(2, 4, dtype=torch.float)
 
-        tensor_unifying.graph.propagate_shapes((a, b, c), False)
-        if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
+        graph = tensor_unifying._get_method('forward').propagate_shapes((a, b, c), False)
+        if_outputs = list(graph.findNode("prim::If").outputs())
         self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
         self.assertTrue(if_outputs[1].type().str() == "Tensor")
         self.assertTrue(if_outputs[2].type().str() == "Tensor")
@@ -13303,6 +13299,30 @@ class TestClassType(JitTestCase):
         self.assertEqual(x, f2.x)
         self.assertEqual(y, f2.y)
 
+    def test_class_specialization(self):
+        @torch.jit.script  # noqa: B903
+        class Foo(object):
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+        def use_foo(foo, foo2, tup):
+            # type: (Foo, Foo, Tuple[Foo, Foo]) -> Tensor
+            a, b = tup
+            return foo.x + foo2.y + a.x + b.y
+
+        # create from python
+        x = torch.ones(2, 3)
+        y = torch.zeros(2, 3)
+        f = Foo(x, y)
+        f2 = Foo(x * 2, y * 3)
+        f3 = Foo(x * 4, y * 4)
+
+        input = (f, f2, (f, f3))
+        sfoo = self.checkScript(use_foo, input)
+        graphstr = str(sfoo.graph_for(*input))
+        FileCheck().check_count("Double(*, *) = prim::GetAttr", 4).run(graphstr)
+
 
 class TestLogging(JitTestCase):
     def test_bump_numeric_counter(self):
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 2236172..89a5ed8 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -51,6 +51,7 @@ libtorch_sources = [
     "torch/csrc/Exceptions.cpp",
     "torch/csrc/jit/autodiff.cpp",
     "torch/csrc/jit/attributes.cpp",
+    "torch/csrc/jit/argument_spec.cpp",
     "torch/csrc/jit/constants.cpp",
     "torch/csrc/jit/node_hashing.cpp",
     "torch/csrc/jit/export.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 6cd237f..60f883a 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -123,6 +123,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp
   ${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp
   ${TORCH_SRC_DIR}/csrc/jit/attributes.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp
   ${TORCH_SRC_DIR}/csrc/jit/export.cpp
   ${TORCH_SRC_DIR}/csrc/jit/pickler.cpp
   ${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp
diff --git a/torch/csrc/jit/argument_spec.cpp b/torch/csrc/jit/argument_spec.cpp
new file mode 100644
index 0000000..a31a91e
--- /dev/null
+++ b/torch/csrc/jit/argument_spec.cpp
@@ -0,0 +1,229 @@
+
+#include <torch/csrc/jit/argument_spec.h>
+
+namespace torch {
+namespace jit {
+
+void ArgumentSpecCreator::scan(
+    const TypePtr& typ,
+    size_t depth,
+    const WrittenSlots& written_slots) {
+  auto finishAggregate = [&](size_t pos) {
+    // It is possible that, after all the work we did to scan this aggregate,
+    // we found no tensors to specialize. In this case, just generate
+    // a skip for the whole aggregate.
+    bool any_spec = std::any_of(
+        instructions_.begin() + pos, instructions_.end(), [](Inst i) {
+          return i == SPECIALIZE_TENSOR;
+        });
+    if (!any_spec) {
+      instructions_[pos] = SKIP;
+      instructions_.resize(pos + 1);
+    } else {
+      instructions_.emplace_back(LEAVE);
+    }
+  };
+  // The simple VM that scans instructions_ has a limited stack depth;
+  // this prevents going deeper than that.
+  if (depth >= DEPTH_LIMIT) {
+    instructions_.emplace_back(SKIP);
+  }
+  if (typ->isSubtypeOf(TensorType::get())) {
+    num_tensors_++;
+    instructions_.emplace_back(SPECIALIZE_TENSOR);
+  } else if (auto tup = typ->cast<TupleType>()) {
+    size_t pos = instructions_.size();
+    instructions_.emplace_back(ENTER_TUPLE);
+    for (const auto& elem : tup->containedTypes()) {
+      scan(elem, depth + 1, written_slots);
+    }
+    finishAggregate(pos);
+  } else if (auto cls = typ->cast<ClassType>()) {
+    size_t pos = instructions_.size();
+    instructions_.emplace_back(ENTER_OBJECT);
+    for (size_t i = 0; i < cls->numAttributes(); ++i) {
+      auto key = cls->name() + cls->attributeNames().at(i);
+      // it is only safe to specialize a slot if no one has written to it
+      if (!written_slots.count(key)) {
+        scan(cls->containedTypes().at(i), depth + 1, written_slots);
+      } else {
+        instructions_.emplace_back(SKIP);
+      }
+    }
+    finishAggregate(pos);
+  } else {
+    instructions_.emplace_back(SKIP);
+  }
+};
+
+// This is a coarse-grained guarantee that the slots of a class will not be
+// modified by the function. It works fine for things that used to be
+// read-only modules, but it will be overly conservative when some classes
+// are written to. Doing alias analysis and looking for writes to the class
+// would be more accurate.
+static void scanWrittenSlots(
+    Block* block,
+    ArgumentSpecCreator::WrittenSlots& written_slots) {
+  for (Node* n : block->nodes()) {
+    if (n->kind() == prim::SetAttr) {
+      if (auto cls = n->inputs().at(0)->type()->cast<ClassType>()) {
+        written_slots.insert(cls->name() + n->s(attr::name));
+      }
+    }
+    for (Block* subblock : n->blocks()) {
+      scanWrittenSlots(subblock, written_slots);
+    }
+    if (n->hasAttribute(attr::Subgraph)) {
+      scanWrittenSlots(n->g(attr::Subgraph)->block(), written_slots);
+    }
+  }
+}
+
+ArgumentSpecCreator::ArgumentSpecCreator(Graph& graph)
+    : num_inputs_(graph.inputs().size()) {
+  WrittenSlots written_slots;
+  scanWrittenSlots(graph.block(), written_slots);
+  for (Value* input : graph.inputs()) {
+    scan(input->type(), 0, written_slots);
+  }
+}
+
+void ArgumentSpecCreator::dump() const {
+  for (Inst inst : instructions_) {
+    switch (inst) {
+      case LEAVE:
+        std::cout << "] ";
+        break;
+      case ENTER_TUPLE:
+        std::cout << "Tuple[";
+        break;
+      case ENTER_OBJECT:
+        std::cout << "Object[";
+        break;
+      case SKIP:
+        std::cout << "Skip ";
+        break;
+      case SPECIALIZE_TENSOR:
+        std::cout << "SpecializeTensor ";
+        break;
+    }
+  }
+  std::cout << "\n";
+}
+
+ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
+    const {
+  ArgumentSpec spec(num_tensors_);
+  const IValue* stack[DEPTH_LIMIT]; // The stack of IValue lists
+  // The stack gets initialized with the input list
+  stack[0] = last(input, num_inputs_).begin();
+  size_t stack_top = 0; // offset to the top of the stack
+  for (Inst inst : instructions_) {
+    switch (inst) {
+      case SPECIALIZE_TENSOR:
+        // consume a tensor and add to the argspec
+        spec.addTensor(*stack[stack_top]++, with_grad);
+        break;
+      case ENTER_TUPLE: {
+        // consume tuple
+        const IValue* iv = stack[stack_top]++;
+        AT_ASSERT(iv->isTuple());
+        // see [argspec refcounting]
+        auto p = *reinterpret_cast(iv);
+        auto tup_ptr = &p->elements()[0];
+        // push list of tuple elements to the stack
+        stack[++stack_top] = tup_ptr;
+      } break;
+      case ENTER_OBJECT: {
+        // consume object
+        const IValue* iv = stack[stack_top]++;
+        AT_ASSERT(iv->isObject());
+        iv->toObject();
+        // see [argspec refcounting]
+        auto p = *reinterpret_cast(iv);
+        auto obj_ptr = &p->slots()[0];
+        // push list of object elements to the stack
+        stack[++stack_top] = obj_ptr;
+      } break;
+      case SKIP:
+        // consume and skip an element
+        stack[stack_top]++;
+        break;
+      case LEAVE:
+        --stack_top;
+        break;
+    }
+  }
+  return spec;
+}
+
+// For every input of a given graph, returns the most detailed type that can be
+// inferred for it based on this ArgumentSpec.
+std::vector<TypePtr> ArgumentSpecCreator::getSpecializedTypes(
+    Graph& graph,
+    const ArgumentSpec& spec) const {
+  auto input_types =
+      fmap(graph.inputs(), [](Value* input) { return input->type(); });
+  std::vector<std::vector<TypePtr>> result_stack;
+  result_stack.emplace_back();
+  std::vector<const TypePtr*> input_stack = {input_types.data()};
+  std::vector<std::function<TypePtr()>> aggregate_creators;
+
+  size_t arg_spec_offset = 0; // number of specialized tensors seen so far
+
+  for (Inst inst : instructions_) {
+    switch (inst) {
+      case SPECIALIZE_TENSOR: {
+        input_stack.back()++;
+        auto& arg = spec.at(arg_spec_offset++);
+        if (!arg.defined()) {
+          result_stack.back().emplace_back(AutogradZeroTensorType::get());
+        } else {
+          result_stack.back().emplace_back(DimensionedTensorType::create(
+              arg.type(),
+              ConvertIntToCPUOrCUDA(arg.device()),
+              arg.dim(),
+              arg.requires_grad()));
+        }
+      } break;
+      case ENTER_TUPLE: {
+        auto tup = (*input_stack.back()++)->expect<TupleType>();
+        input_stack.emplace_back(tup->elements().data());
+        result_stack.emplace_back();
+        aggregate_creators.emplace_back(
+            [&] { return TupleType::create(result_stack.back()); });
+      } break;
+      case ENTER_OBJECT: {
+        auto cls = (*input_stack.back()++)->expect<ClassType>();
+        input_stack.emplace_back(cls->containedTypes().data());
+        result_stack.emplace_back();
+        aggregate_creators.emplace_back(
+            [&result_stack, cls] { return cls->refine(result_stack.back()); });
+      } break;
+      case SKIP:
+        result_stack.back().emplace_back(*input_stack.back()++);
+        break;
+      case LEAVE:
+        TypePtr result = aggregate_creators.back()();
+        result_stack.pop_back();
+        aggregate_creators.pop_back();
+        input_stack.pop_back();
+        result_stack.back().emplace_back(std::move(result));
+        break;
+    }
+  }
+  AT_ASSERT(result_stack.size() == 1);
+  return result_stack.back();
+}
+
+void ArgumentSpecCreator::setInputTypes(Graph& g, const ArgumentSpec& spec)
+    const {
+  auto input_types = getSpecializedTypes(g, spec);
+  auto inputs = g.inputs();
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    inputs[i]->setType(input_types[i]);
+  }
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index a345c6b..ca88836 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -1,9 +1,9 @@
 #pragma once
 
+#include
+#include
 #include
 #include
-#include
-#include
 #include
 #include
 #include
@@ -22,9 +22,6 @@ struct ArgumentInfo {
   friend struct ArgumentSpec;
   using plain_data_type = uint32_t;
 
-  bool isTensor() const {
-    return is_tensor_;
-  }
   bool defined() const {
     return defined_;
   }
@@ -45,11 +42,11 @@ struct ArgumentInfo {
   operator TypePtr() const {
     if (!defined())
       return TensorType::get();
-    return DimensionedTensorType::create(type(), ConvertIntToCPUOrCUDA(device()), dim());
+    return DimensionedTensorType::create(
+        type(), ConvertIntToCPUOrCUDA(device()), dim());
   }
 
  private:
-  unsigned is_tensor_ : 1;
   unsigned defined_ : 1;
   unsigned requires_grad_ : 1;
   unsigned : 5;
@@ -67,48 +64,32 @@ static_assert(
     "ArgumentInfo is expected to be a 32-bit struct");
 
 struct ArgumentSpec {
-  ArgumentSpec(
-      bool with_grad,
-      at::ArrayRef<IValue> inputs,
-      size_t num_flat_inputs) {
+  ArgumentSpec(size_t num_flat_inputs) {
     hash_code = num_flat_inputs;
-    args.resize(num_flat_inputs);
-    size_t offset = 0;
-    for (const auto& i : inputs) {
-      addInput(i, offset, with_grad);
-    }
-    AT_ASSERT(offset <= num_flat_inputs);
+    args.reserve(num_flat_inputs);
   }
 
-  void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto& arg = args.at(offset);
+  void addTensor(const IValue& input, bool with_grad) {
+    AT_ASSERT(input.isTensor());
+    args.emplace_back();
+    auto& arg = args.back();
     // Initialize all fields to 0. This is convenient, because e.g.
     // requires_grad() can be checked even on tensors AND will make
     // padding bits all 0s.
     std::memset(&arg, 0, sizeof(ArgumentInfo));
-    if (input.isTensor()) {
-      at::Tensor t = input.toTensor();
-      if ((arg.defined_ = t.defined())) {
-        arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
-        arg.dim_ = t.dim();
-        arg.device_ = t.is_cuda() ? t.get_device() : -1;
-        arg.type_ = static_cast<unsigned>(t.scalar_type());
-      }
-
-      arg.is_tensor_ = true;
-      combineHash(arg);
-      offset++;
-    } else if (input.isTuple()) {
-      for (const IValue& elem : input.toTuple()->elements()) {
-        addInput(elem, offset, with_grad);
-      }
-    } else {
-      // NB: no need to set is_tensor to false, because we memset the struct to
-      // 0 above
-      combineHash(arg);
-      offset++;
+    // [argspec refcounting] reinterpret the IValue to avoid having to refcount
+    // the Tensor; microbenchmarks
+    // https://github.com/zdevito/pytorch/commit/21e7200a0a0fc456bea2f10e95b1781f83933d10
+    // show overhead in the extra refcounting along this path.
+    const at::Tensor* t = reinterpret_cast<const at::Tensor*>(&input);
+    if ((arg.defined_ = t->defined())) {
+      arg.requires_grad_ = with_grad && autograd::Variable(*t).requires_grad();
+      arg.dim_ = t->dim();
+      arg.device_ = t->is_cuda() ? t->get_device() : -1;
+      arg.type_ = static_cast<unsigned>(t->scalar_type());
     }
+    combineHash(arg);
   }
 
   void combineHash(const ArgumentInfo& arg) {
@@ -143,38 +124,49 @@ struct ArgumentSpec {
   size_t hashCode() const {
     return hash_code;
   }
-  // For every input of a given graph, returns a most detailed type that can be
-  // inferred for it based on this ArgumentSpec.
-  std::vector<TypePtr> getTypes(Graph& graph) const {
-    size_t offset = 0;
-    return fmap(
-        graph.inputs(), [&](Value* v) { return fillType(v->type(), offset); });
-  }
 
  private:
-  TypePtr fillType(TypePtr original, size_t& offset) const {
-    if (original->isSubtypeOf(TensorType::get())) {
-      auto& arg = args.at(offset++);
-      if (!arg.defined())
-        return AutogradZeroTensorType::get();
-      return DimensionedTensorType::create(
-          arg.type(),
-          ConvertIntToCPUOrCUDA(arg.device()),
-          arg.dim(),
-          arg.requires_grad());
-    } else if (auto tuple_type = original->cast<TupleType>()) {
-      return TupleType::create(fmap(
-          tuple_type->elements(),
-          [&](const TypePtr& subtype) { return fillType(subtype, offset); }));
-    } else {
-      offset++;
-      return original;
-    }
-  }
   size_t hash_code; // precomputed on construction
   std::vector args;
 };
 
+// ArgumentSpecCreator takes an initial graph and comes up with a set
+// of simple instructions to compute the ArgumentSpec given a set of
+// input tensors.
+struct ArgumentSpecCreator {
+  // The instructions act on a stack of lists of input IValues.
+  // At the beginning the stack contains a single list of the inputs to the
+  // function; the ENTER_ instructions descend into subobjects and push new
+  // lists onto the stack.
+  enum Inst : char {
+    ENTER_TUPLE, // consume a tuple ivalue from the top-most list, and push the
+                 // list of its elements onto the stack as a new list
+    ENTER_OBJECT, // same as ENTER_TUPLE, but the input is a class
+    LEAVE, // pop the top-most list from the stack
+    SKIP, // consume an element from the top-most list, and discard it
+    SPECIALIZE_TENSOR, // consume a tensor from the top-most list, and
+                       // add it to the ArgSpec key being created
+  };
+  ArgumentSpecCreator(Graph& graph);
+  ArgumentSpec create(bool with_grad, const Stack& stack) const;
+  void setInputTypes(Graph& g, const ArgumentSpec& spec) const;
+  std::vector<TypePtr> getSpecializedTypes(
+      Graph& graph,
+      const ArgumentSpec& spec) const;
+  void dump() const;
+  using WrittenSlots = std::unordered_set<std::string>;
+
+ private:
+  static constexpr size_t DEPTH_LIMIT = 128;
+  void scan(
+      const TypePtr& typ,
+      size_t depth,
+      const WrittenSlots& written_slots);
+  size_t num_inputs_;
+  size_t num_tensors_ = 0;
+  std::vector<Inst> instructions_;
+};
+
 // CompleteArgumentSpec represents one particular specialization.
 // It is designed so that it can be created, hashed, and compared quickly
 // since it is used along the hot-path of the JIT to check if the code
 //
@@ -398,14 +390,6 @@ inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const {
   return CompleteArgumentInfo(*this, i);
 }
 
-inline void setInputTypes(Graph& g, const ArgumentSpec& spec) {
-  auto input_types = spec.getTypes(g);
-  auto inputs = g.inputs();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    inputs[i]->setType(input_types[i]);
-  }
-}
-
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 1e993da..4d3f848 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -322,28 +322,6 @@ struct GraphExecutorImpl {
     return copy;
   }
 
-  static size_t countFlatInputs(const TypePtr& ptr) {
-    if (auto optional_type = ptr->cast()) {
-      return countFlatInputs(optional_type->getElementType());
-    }
-    if (auto tuple_type = ptr->cast()) {
-      size_t total = 0;
-      for (auto& elem : tuple_type->elements()) {
-        total += countFlatInputs(elem);
-      }
-      return total;
-    }
-    return 1;
-  }
-
-  static size_t countFlatInputs(const std::shared_ptr& graph) {
-    size_t total = 0;
-    for (Value* input : graph->inputs()) {
-      total += countFlatInputs(input->type());
-    }
-    return total;
-  }
-
   inline bool hasMutableOperators(Block* block) {
     for (auto n : block->nodes()) {
       if (n->kind().is_aten() && n->schema().is_mutable())
@@ -362,11 +340,11 @@ struct GraphExecutorImpl {
         // disables all optimization
         optimize(optimize),
         num_inputs(this->graph->inputs().size()),
-        num_flat_inputs(countFlatInputs(graph)),
+        arg_spec_creator_(*graph),
         num_outputs(this->graph->outputs().size()) {
-        logging::getLogger()->addStatValue(
-            logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
-      }
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+  }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -391,9 +369,9 @@ struct GraphExecutorImpl {
   std::shared_ptr graphFor(const Stack& stack) const {
     AT_ASSERT(stack.size() >= num_inputs);
-    auto inputs = last(stack, num_inputs);
-    ArgumentSpec spec(
-        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+
+    ArgumentSpec spec =
+        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
 
     if (!optimize) {
       AT_CHECK(fallback, "No graph found for given inputs");
@@ -441,10 +419,8 @@ struct GraphExecutorImpl {
   const ExecutionPlan& getOrCompile(const Stack& stack) {
     // outside lock guard, to minimize the time holding the lock on the fast
     // path ArgumentSpec even computes its hashCode here.
-    ArgumentSpec spec(
-        autograd::GradMode::is_enabled(),
-        last(stack, num_inputs),
-        num_flat_inputs);
+    ArgumentSpec spec =
+        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
     {
       std::lock_guard lock(compile_mutex);
       auto it = plan_cache.find(spec);
@@ -463,7 +439,7 @@ struct GraphExecutorImpl {
 
   ExecutionPlan compileSpec(const ArgumentSpec& spec) {
     auto opt_graph = graph->copy();
-    setInputTypes(*opt_graph, spec);
+    arg_spec_creator_.setInputTypes(*opt_graph, spec);
 
     // Phase 1. Specialize to input definedness (this is very important for
     //          gradient graphs), and run required passes to bring the graph
@@ -562,8 +538,8 @@ struct GraphExecutorImpl {
     auto input_values = fmap(
         inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });
 
-    ArgumentSpec spec(
-        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
+    ArgumentSpec spec =
+        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
     // NB: we could just run the fallback in here and call it a day, but that
     // would loose all the control flow information we have in the graph. Thus,
     // we run the fallback to get the correct output values, but we will
@@ -580,7 +556,7 @@ struct GraphExecutorImpl {
     // tracing and so we only do the type propgation if no concrete types have
     // been set.
     auto local_graph = this->graph->copy();
-    setInputTypes(*local_graph, spec);
+    arg_spec_creator_.setInputTypes(*local_graph, spec);
     PropagateInputShapes(local_graph);
     auto output_values =
         inlineCallTo(*state->graph, *local_graph, input_values);
@@ -600,8 +576,7 @@ struct GraphExecutorImpl {
   // Useful for debugging.
   const bool optimize;
   const size_t num_inputs;
-  const size_t num_flat_inputs; // Number of inputs, assuming all tuples would
-                                // be flattened.
+  ArgumentSpecCreator arg_spec_creator_;
   const size_t num_outputs;
 
   // Populated only when optimize is false (and in that case plan_cache will be
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 6ee8d9a..e693a15 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -156,16 +156,6 @@ void initJITBindings(PyObject* module) {
           [](const std::shared_ptr& g) { return Canonicalize(g); })
       .def("_jit_pass_lint", LintGraph)
       .def(
-          "_jit_pass_shape_analysis",
-          [](std::shared_ptr graph,
-             std::vector inputs,
-             bool with_grad) {
-            setInputTypes(
-                *graph,
-                ArgumentSpec(with_grad, fmap(inputs), inputs.size()));
-            PropagateInputShapes(graph);
-          })
-      .def(
           "_jit_pass_complete_shape_analysis",
           [](std::shared_ptr graph, py::tuple inputs, bool with_grad) {
             CompleteArgumentSpec spec(
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index bd162cf..5f63188 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -528,6 +528,12 @@ class ShapePropagator {
         setUnshapedType(node);
         return;
       }
+      case prim::GetAttr: {
+        auto cls = node->input()->type()->expect<ClassType>();
+        // propagate any type specializations encoded in the type of the class
+        node->output()->setType(cls->getAttribute(node->s(attr::name)));
+        return;
+      }
       case aten::_unwrap_optional: {
         auto input_ivalue = toIValue(node->input());
         if (input_ivalue && input_ivalue->isNone()) {
@@ -997,11 +1003,9 @@ class ShapePropagator {
   };
 
   // Requirements:
-  //   dims           : 0 if dim is None, otherwise preserved if keepdim == false or 1 smaller otherwise
-  //   scalar type    : preserved
-  //   device         : preserved
-  //   tensor inputs  : 1
-  //   tensor outputs : 1
+  //   dims           : 0 if dim is None, otherwise preserved if keepdim ==
+  //   false or 1 smaller otherwise scalar type : preserved device :
+  //   preserved tensor inputs : 1 tensor outputs : 1
   // Additionally:
   //   - First input should be the only tensor input
   //   - Has a bool keepdim argument
@@ -1094,7 +1098,9 @@ class ShapePropagator {
       [](Node* node) -> type_vec_t {
         if (auto dim = node->get>(attr::dim)) {
           return multidim_reduce_with_postprocess(
-              node, /*num_reduced_dim=*/dim->size(), /*upcast_integer=*/false);
+              node,
+              /*num_reduced_dim=*/dim->size(),
+              /*upcast_integer=*/false);
         }
         return {};
       }};
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index a6a900b..4ffaa94 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -209,16 +209,6 @@ void initPythonIRBindings(PyObject* module_) {
         db.dump();
       })
       .def(
-          "propagate_shapes",
-          [](std::shared_ptr g,
-             std::vector inputs,
-             bool with_grad) {
-            setInputTypes(
-                *g,
-                ArgumentSpec(with_grad, fmap(inputs), inputs.size()));
-            PropagateInputShapes(g);
-          })
-      .def(
           "_export_onnx",
           [](const std::shared_ptr g,
              const std::map& initializers,
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index d4417a7..ba99b2c 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -1,15 +1,14 @@
 #pragma once
-#include
+#include
 #include
+#include
 #include
-#include
 #include
 #include
 #include
 #include
-#include
 #include
-
+#include
 #include
 #include
@@ -51,8 +50,8 @@ using ExtraFilesMap = std::unordered_map;
 
 struct Module;
 
-using ModuleLookup = std::function<std::shared_ptr<Module>(
-    const std::vector<std::string>&)>;
+using ModuleLookup =
+    std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
 
 struct Method {
   Method(
@@ -137,6 +136,14 @@ struct Method {
     return graph()->addInput()->setType(type);
   }
 
+  static void setInputTensorTypes(Graph& g, const Stack& stack) {
+    AT_ASSERT(stack.size() == g.inputs().size());
+    for (size_t i = 0; i < stack.size(); ++i) {
+      g.inputs().at(i)->setType(
+          DimensionedTensorType::create(stack.at(i).toTensor()));
+    }
+  }
+
   std::shared_ptr propagate_shapes(
       std::vector inputs,
       bool with_grad = false) {
@@ -149,8 +156,7 @@ struct Method {
     for (const Slot& inp : initial_ivalues_) {
       stack.push_back(*inp);
     }
-    const auto size = stack.size();
-    setInputTypes(*retval, ArgumentSpec(with_grad, stack, size));
+    setInputTensorTypes(*retval, stack);
     PropagateInputShapes(retval);
     return retval;
   }
@@ -167,9 +173,7 @@ struct Method {
       }
     }
     if (propagate) {
-      setInputTypes(
-          *retval,
-          ArgumentSpec(with_grad, fmap(inputs), inputs.size()));
+      setInputTensorTypes(*retval, fmap(inputs));
       PropagateInputShapes(retval);
     }
     AT_ASSERT(retval->inputs().size() == inputs.size());
@@ -288,16 +292,16 @@ struct Method {
       if (pos < inputs.size()) {
         if (!isSubvalueOf(inputs[pos], argument.type())) {
           AT_ERROR(
-              "Expected value of type ",
-              *argument.type(),
-              " for argument '",
-              argument.name(),
-              "' in position ",
-              pos,
-              ", but instead got value of type ",
-              attemptToRecoverType(inputs[pos])->str(),
-              ". Declaration: ",
-              schema);
+            "Expected value of type ",
+            *argument.type(),
+            " for argument '",
+            argument.name(),
+            "' in position ",
+            pos,
+            ", but instead got value of type ",
+            attemptToRecoverType(inputs[pos])->str(),
+            ". Declaration: ",
+            schema);
         }
       } else if (argument.default_value()) {
         inputs.push_back(*argument.default_value());
@@ -375,7 +379,8 @@ struct NamedIValue {
   const TypePtr& type() const {
     return type_;
   }
-private:
+
+ private:
   const std::string name_;
   const TypePtr type_;
   std::unique_ptr ivalue_;
@@ -497,12 +502,10 @@ struct Module {
   const torch::OrderedDict& get_modules() const {
     return modules;
   }
-  const torch::OrderedDict& get_parameters()
-      const {
+  const torch::OrderedDict& get_parameters() const {
     return parameters;
   }
-  const torch::OrderedDict& get_attributes()
-      const {
+  const torch::OrderedDict& get_attributes() const {
     return attributes;
  }
   const torch::OrderedDict>& get_methods()
@@ -630,9 +633,7 @@ struct Module {
       if (!kv.value().type()->isSubtypeOf(TensorType::get())) {
         continue;
       }
-      curr->register_buffer(
-          kv.key(),
-          kv.value().slot()->toTensor());
+      curr->register_buffer(kv.key(), kv.value().slot()->toTensor());
       parameter_remap[kv.value().slot()] = curr->find_buffer(kv.key())->slot();
     }
     for (auto& kv : modules) {