From: Zachary DeVito Date: Fri, 5 Apr 2019 20:33:14 +0000 (-0700) Subject: slots with explicit value/setValue make more sense in future patches (#18468) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~369 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f6f34b3f4cc3368e13e0f0d0144f9a73653a6986;p=platform%2Fupstream%2Fpytorch.git slots with explicit value/setValue make more sense in future patches (#18468) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18468 ghimport-source-id: d4b41c521f2269a695e03c8e7d05d5542731ee48 Stack from [ghstack](https://github.com/ezyang/ghstack): * #18469 Create Object that represents a Module * **#18468 slots with explicit value/setValue make more sense in future patches** * #18467 Make Object hold its ClassType * #18379 Enforce single parent for script submodules * #18378 Unify namespace of script::Module * #18314 Add ability to specialize class types to ArgumentSpec * #18226 Add Slot type to abstract the raw pointers being used for slots. Reviewed By: suo Differential Revision: D14613509 fbshipit-source-id: 9f2208d0efd01465c78cebdc3e8365a9e0adf9ff --- diff --git a/test/custom_operator/test_custom_ops.cpp b/test/custom_operator/test_custom_ops.cpp index 1b64171..a943a08 100644 --- a/test/custom_operator/test_custom_ops.cpp +++ b/test/custom_operator/test_custom_ops.cpp @@ -15,10 +15,10 @@ void check_all_parameters( const torch::jit::script::Module& module, Predicate predicate) { for (const auto& parameter : module.get_parameters()) { - AT_ASSERT(predicate(parameter.slot()->toTensor())); + AT_ASSERT(predicate(parameter.slot().value().toTensor())); } for (const auto& child : module.get_modules()) { - check_all_parameters(module, predicate); + check_all_parameters(*child, predicate); } } } // namespace helpers diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index e8243cd..0355dc5 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -32,7 +32,7 @@ void InputArchive::read( "'"); // clang-format off auto read_param = is_buffer ? buffer : param; - auto read_tensor = read_param->slot()->toTensor(); + auto read_tensor = read_param->slot().value().toTensor(); AT_CHECK( bool(buffer) == is_buffer, "Expected deserialized tensor for key '", key, diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp index 4f27838..69399f9 100644 --- a/torch/csrc/jit/export.cpp +++ b/torch/csrc/jit/export.cpp @@ -718,7 +718,7 @@ void ScriptModuleSerializer::writeAttributeTable() { } pickler.finish(); writer_.writeRecord( - "attributes.pkl", pickler.stack().data(), pickler.stack().size()); + "attributes.pkl", pickler.stack().data(), pickler.stack().size()); } void ScriptModuleSerializer::convertModule( @@ -739,7 +739,7 @@ void ScriptModuleSerializer::convertModule( attribute_def->set_name(attribute.name()); attribute_def->set_type(attribute.type()->python_str()); - attribute_table_.push_back(*attribute.slot()); + attribute_table_.push_back(attribute.slot().value()); attribute_def->set_id(attribute_table_.size() - 1); } @@ -779,7 +779,7 @@ void ScriptModuleSerializer::convertParameter( bool is_parameter) { param_def->set_name(param.name()); param_def->set_is_buffer(is_parameter); - param_def->set_tensor_id(addTensor(param.slot()->toTensor())); + param_def->set_tensor_id(addTensor(param.slot().value().toTensor())); } // Pretty printing for ONNX diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index ea69b11..79f7537 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -784,7 +784,8 @@ void initJitScriptBindings(PyObject* module) { auto& p = parameters[i]; py::tuple r(2); result[i] = std::make_tuple( - p.name(), autograd::as_variable_ref(p.slot()->toTensor())); + p.name(), + autograd::as_variable_ref(p.slot().value().toTensor())); } return result; }) @@ -796,7 +797,7 @@ void initJitScriptBindings(PyObject* module) { for (size_t i = 0; i < attributes.size(); ++i) { auto& buffer = attributes[i]; py::tuple r(3); - IValue v = *buffer.slot(); + IValue v = buffer.slot().value(); result[i] = std::make_tuple( buffer.name(), buffer.type(), toPyObject(std::move(v))); } @@ -856,7 +857,7 @@ void initJitScriptBindings(PyObject* module) { gatherParametersAndBuffers(parameters, *self); Stack inputs = toStack(input_tuple); for (const Slot& param : parameters) { - inputs.emplace_back(*param); + inputs.emplace_back(param.value()); } auto graph = tracer::createGraphByTracing( func, @@ -980,7 +981,7 @@ void initJitScriptBindings(PyObject* module) { [](Method& m) { std::vector tensors; for (auto& t : m.initial_ivalues()) { - tensors.push_back(t->toTensor()); + tensors.push_back(t.value().toTensor()); } return tensors; }) diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index cb9cdc7..1e80a0b 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -56,7 +56,7 @@ Value* try_emit_call_to( " from a raw graph. File a bug report"; } // TODO: preserve the type information so we don't have to infer it here - auto type = incompleteInferTypeFrom(*member); + auto type = incompleteInferTypeFrom(member.value()); matched_schema->inputs.push_back( caller->get_or_add_attribute(type, member)); } @@ -128,7 +128,7 @@ void Module::to_impl( // Then convert every of our parameters. for (auto& parameter : get_parameters()) { // Need to access the `at::Tensor` as a `Variable` here. - autograd::Variable variable = parameter.slot()->toTensor(); + autograd::Variable variable = parameter.slot().value().toTensor(); at::Tensor data = variable.data(); // Use the data's original device or dtype if not supplied here. auto new_data = data.to( diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index e7211e0..e674b02 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -76,7 +76,7 @@ struct Method { void run(Stack& stack) { for (auto input : initial_ivalues_) { - push(stack, *input); + push(stack, input.value()); } get_executor().run(stack); } @@ -93,7 +93,7 @@ struct Method { std::shared_ptr graph_for(Stack inputs) { for (auto tp : initial_ivalues_) { - inputs.emplace_back(*tp); + inputs.emplace_back(tp.value()); } return get_executor().graphFor(inputs); } @@ -122,7 +122,7 @@ struct Method { return graph()->inputs().size() - initial_ivalues_.size(); } TORCH_API Value* get_or_add_parameter(Slot slot) { - AT_ASSERT(slot->isTensor()); + AT_ASSERT(slot.value().isTensor()); return get_or_add_attribute(TensorType::get(), slot); } @@ -154,7 +154,7 @@ struct Method { stack.emplace_back(std::move(i)); } for (const Slot& inp : initial_ivalues_) { - stack.push_back(*inp); + stack.push_back(inp.value()); } setInputTensorTypes(*retval, stack); PropagateInputShapes(retval); @@ -168,8 +168,8 @@ struct Method { bool propagate = true) { auto retval = graph_->copy(); for (auto inp : initial_ivalues_) { - if (inp->isTensor()) { - inputs.push_back(inp->toTensor()); + if (inp.value().isTensor()) { + inputs.push_back(inp.value().toTensor()); } } if (propagate) { @@ -406,7 +406,7 @@ struct Module { void register_buffer(const std::string& name, autograd::Variable v) { if (auto b = find_attribute(name)) { AT_ASSERT(b->type()->isSubtypeOf(TensorType::get())); - *b->slot() = v; + b->slot().setValue(v); return; } insert( @@ -424,7 +424,7 @@ struct Module { return; } if (auto p = find_parameter(name)) { - *p->slot() = v; + p->slot().setValue(v); return; } insert( @@ -497,15 +497,15 @@ struct Module { } void set_parameter(const std::string& name, at::Tensor v) { - *parameter_slot(name) = std::move(v); + parameter_slot(name).setValue(std::move(v)); } autograd::Variable get_parameter(const std::string& name) const { - return autograd::as_variable_ref(parameter_slot(name)->toTensor()); + return autograd::as_variable_ref(parameter_slot(name).value().toTensor()); } IValue get_attribute(const std::string& name) const { - return *attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot(); + return attributes_[get_offset(name, EntityType::ATTRIBUTE)].slot().value(); } autograd::Variable get_buffer(const std::string& name) const { @@ -579,7 +579,7 @@ struct Module { /// True if the module is in training mode. bool is_training() { if (auto p = find_buffer("training")) { - return p->slot()->toTensor().item() == 1; + return p->slot().value().toTensor().item() == 1; } // We are in training mode by default return true; @@ -648,7 +648,7 @@ struct Module { for (auto& param : get_parameters()) { curr->register_parameter( param.name(), - param.slot()->toTensor(), + param.slot().value().toTensor(), /*is_buffer=*/false); parameter_remap[param.slot()] = curr->parameter_slot(param.name()); } @@ -656,7 +656,7 @@ struct Module { if (!attr.type()->isSubtypeOf(TensorType::get())) { continue; } - curr->register_buffer(attr.name(), attr.slot()->toTensor()); + curr->register_buffer(attr.name(), attr.slot().value().toTensor()); parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot(); } for (auto& mod : get_modules()) { diff --git a/torch/csrc/jit/script/slot.h b/torch/csrc/jit/script/slot.h index dc70e7b..67fd170 100644 --- a/torch/csrc/jit/script/slot.h +++ b/torch/csrc/jit/script/slot.h @@ -7,36 +7,39 @@ namespace script { // a stable location that can hold an IValue. // Currently this is internally implemented as a pointer, but when -// modules become first-class this will be a pair of +// modules become first-class this will be a pair of struct Slot { friend struct NamedIValue; - Slot() - : slot_(nullptr) {} - Slot(at::IValue* slot) - : slot_(slot) {} - at::IValue& operator*() const { - return *slot_; - } - at::IValue* operator->() const { - return slot_; - } + Slot() : slot_(nullptr) {} + Slot(at::IValue* slot) : slot_(slot) {} + bool operator==(const Slot& rhs) const { return slot_ == rhs.slot_; } -private: + void setValue(at::IValue v) { + *slot_ = std::move(v); + } + const at::IValue& value() const { + return *slot_; + } + + private: at::IValue* slot_; friend struct std::hash; }; -}}} +} // namespace script +} // namespace jit +} // namespace torch // slots are hashable, because they are often used as keys in maps // for remapping uses of a slot from one model to another namespace std { - template <> - struct hash { - size_t operator()(const torch::jit::script::Slot& s) const noexcept { - return std::hash{}(s.slot_); - } - }; +template <> +struct hash { + size_t operator()(const torch::jit::script::Slot& s) const noexcept { + return std::hash{}(s.slot_); + } +}; } // namespace std