-#include "import_source.h"
-
+#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/script/parser.h>
namespace torch {
if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleAccessorValue>(v->module);
} else if (NamedIValue* v = module->find_parameter(field)) {
- return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+ return std::make_shared<SimpleValue>(m.get_or_add_initial_ivalue(v));
} else if (NamedIValue* v = module->find_buffer(field)) {
- return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+ return std::make_shared<SimpleValue>(m.get_or_add_initial_ivalue(v));
} else if (script::NamedIValue* v = module->find_attribute(field)) {
return std::make_shared<script::SimpleValue>(
- m.get_or_add_attribute(v->type(), v->slot()));
+ m.get_or_add_initial_ivalue(v));
} else if (Method* m = module->find_method(field)) {
return std::make_shared<MethodValue>(shared_from_this(), *m);
} else {
const std::unordered_map<script::Slot, QualifiedNamePtr>&
extra_ivalue_names) {
std::vector<std::string> ivalue_names =
- fmap(method.initial_ivalues(), [&](const script::Slot& slot) {
- return extra_ivalue_names.at(slot)->str();
+ fmap(method.initial_ivalues(), [&](const script::NamedIValue* value) {
+ auto entry = extra_ivalue_names.find(value->slot());
+ AT_CHECK(
+ entry != extra_ivalue_names.end(),
+ "Could not find named IValue '",
+ value->name(),
+ "' while pretty printing");
+ return entry->second->str();
});
const std::string& name = method.name();
Graph& graph = *method.graph();
const auto& param_list = module_->get_parameters();
for (auto it = param_list.rbegin(); it != param_list.rend(); ++it) {
auto& param = *it;
- params.push_back(caller.get_or_add_parameter(param.slot()));
+ params.push_back(caller.get_or_add_initial_ivalue(&param));
}
auto list = caller.graph()->createList(TensorType::get(), params);
caller.graph()->insertNode(list);
module->register_buffer("training", std::move(t));
v = module->find_buffer(field);
}
- Value* the_tensor = m.get_or_add_parameter(v->slot());
+ Value* the_tensor = m.get_or_add_initial_ivalue(v);
Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
return std::make_shared<SimpleValue>(the_bool);
}
} else if (Method* v = module->find_method(field)) {
return std::make_shared<MethodValue>(shared_from_this(), *v);
} else if (NamedIValue* v = module->find_parameter(field)) {
- return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
+ return std::make_shared<SimpleValue>(m.get_or_add_initial_ivalue(v));
} else if (NamedIValue* v = module->find_attribute(field)) {
return std::make_shared<SimpleValue>(
- m.get_or_add_attribute(v->type(), v->slot()));
+ m.get_or_add_initial_ivalue(v));
}
// This can also be a call to a non-script module, or a plain
}
static void gatherParametersAndBuffers(
- std::vector<Slot>& values,
+ std::vector<const NamedIValue*>& values,
const Module& m) {
for (auto& param : m.get_parameters()) {
- values.push_back(param.slot());
+ values.push_back(&param);
}
for (auto& param : m.get_attributes()) {
if (param.type()->isSubtypeOf(TensorType::get())) {
- values.push_back(param.slot());
+ values.push_back(&param);
}
}
for (const auto& sub : m.get_modules()) {
bool force_outplace) {
// prereq: Module's buffers and parameters are unique
// this was ensured in python before calling this function
- std::vector<Slot> parameters;
+ std::vector<const NamedIValue*> parameters;
gatherParametersAndBuffers(parameters, *self);
Stack inputs = toStack(input_tuple);
- for (const Slot& param : parameters) {
- inputs.emplace_back(*param);
+ for (const NamedIValue* param : parameters) {
+ inputs.emplace_back(*param->slot());
}
auto graph = tracer::createGraphByTracing(
func,
std::vector<std::tuple<std::shared_ptr<Module>, std::string>>
params,
std::shared_ptr<Module> orig) {
- std::vector<Slot> member_inputs;
+ std::vector<const NamedIValue*> member_inputs;
for (auto& p : params) {
- NamedIValue* np = std::get<0>(p)->find_parameter(std::get<1>(p));
- if (np == nullptr) {
- np = std::get<0>(p)->find_buffer(std::get<1>(p));
+ auto named_param = std::get<0>(p)->find_parameter(std::get<1>(p));
+ if (named_param == nullptr) {
+ named_param = std::get<0>(p)->find_buffer(std::get<1>(p));
}
- AT_ASSERT(np != nullptr);
- member_inputs.push_back(np->slot());
+ AT_ASSERT(named_param != nullptr);
+ member_inputs.push_back(named_param);
}
Method* orig_method = orig->find_method(name);
.def(
"propagate_and_assign_input_and_output_shapes",
&Method::propagate_and_assign_input_and_output_shapes)
- .def(
- "initial_ivalues",
- [](Method& m) {
- std::vector<at::Tensor> tensors;
- for (auto& t : m.initial_ivalues()) {
- tensors.push_back(t->toTensor());
- }
- return tensors;
- })
+ .def("initial_ivalues", [](Method& m) {
+ std::vector<at::Tensor> result;
+ result.reserve(m.initial_ivalues().size());
+
+ for (auto named_ivalue : m.initial_ivalues()) {
+ AT_CHECK(
+ named_ivalue->slot()->isTensor(),
+ "Cannot get initial"
+ " IValues if any are not Tensors (found ",
+ named_ivalue->type()->python_str(),
+ ")");
+ result.push_back(named_ivalue->slot()->toTensor());
+ }
+ return result;
+ })
.def(
"graph_for",
[](py::args args, py::kwargs kwargs) {
<< " attempting to call a method with parameters/attributes"
" from a raw graph. File a bug report";
}
- // TODO: preserve the type information so we don't have to infer it here
- auto type = incompleteInferTypeFrom(*member);
- matched_schema->inputs.push_back(
- caller->get_or_add_attribute(type, member));
+ matched_schema->inputs.push_back(caller->get_or_add_initial_ivalue(member));
}
callee.check_single_output();
return inlineCallTo(graph, *callee.graph(), matched_schema->inputs).at(0);
using ModuleLookup =
std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
+// One named, typed value owned by a script module — a parameter, buffer,
+// or attribute. The IValue is heap-allocated behind a unique_ptr so the
+// Slot handed out by slot() stays valid for the lifetime of this entry,
+// even if the container holding the NamedIValue reallocates or moves.
+struct NamedIValue {
+ NamedIValue(std::string name, TypePtr type, IValue ivalue)
+ : name_(name),
+ type_(type),
+ ivalue_(torch::make_unique<IValue>(std::move(ivalue))) {}
+
+ // Handle to the stored IValue; stable because ivalue_ is uniquely
+ // owned on the heap (moving this NamedIValue does not move the IValue).
+ Slot slot() const {
+ return Slot(ivalue_.get());
+ }
+ const std::string& name() const {
+ return name_;
+ }
+ // Static type of the stored value (e.g. TensorType for parameters).
+ const TypePtr& type() const {
+ return type_;
+ }
+
+ private:
+ const std::string name_;
+ const TypePtr type_;
+ std::unique_ptr<IValue> ivalue_;
+};
+
struct Method {
Method(
Module* owner,
std::string name,
bool optimize,
std::shared_ptr<Graph> graph,
- std::vector<Slot> initial_members,
+ std::vector<const NamedIValue*> initial_members,
std::function<void(Method&)> method_creator)
: owner_(owner),
name_(std::move(name)),
void run(Stack& stack) {
for (auto input : initial_ivalues_) {
- push(stack, *input);
+ push(stack, *input->slot());
}
get_executor().run(stack);
}
std::shared_ptr<Graph> graph_for(Stack inputs) {
for (auto tp : initial_ivalues_) {
- inputs.emplace_back(*tp);
+ inputs.emplace_back(*tp->slot());
}
return get_executor().graphFor(inputs);
}
size_t num_inputs() const {
return graph()->inputs().size() - initial_ivalues_.size();
}
- TORCH_API Value* get_or_add_parameter(Slot slot) {
- AT_ASSERT(slot->isTensor());
- return get_or_add_attribute(TensorType::get(), slot);
- }
- TORCH_API Value* get_or_add_attribute(TypePtr type, Slot slot) {
- auto it = initial_ivalue_index.find(slot);
+ TORCH_API Value* get_or_add_initial_ivalue(const NamedIValue* value) {
+ auto it = initial_ivalue_index.find(value);
if (it != initial_ivalue_index.end()) {
return graph()->inputs().at(it->second);
}
- initial_ivalues_.push_back(slot);
- initial_ivalue_index[slot] = graph()->inputs().size();
- return graph()->addInput()->setType(type);
+ initial_ivalues_.push_back(value);
+ initial_ivalue_index[value] = graph()->inputs().size();
+ return graph()->addInput()->setType(value->type());
}
static void setInputTensorTypes(Graph& g, const Stack& stack) {
for (at::Tensor& i : inputs) {
stack.emplace_back(std::move(i));
}
- for (const Slot& inp : initial_ivalues_) {
- stack.push_back(*inp);
+ for (const NamedIValue* inp : initial_ivalues_) {
+ stack.push_back(*inp->slot());
}
setInputTensorTypes(*retval, stack);
PropagateInputShapes(retval);
bool propagate = true) {
auto retval = graph_->copy();
for (auto inp : initial_ivalues_) {
- if (inp->isTensor()) {
- inputs.push_back(inp->toTensor());
+ if (inp->slot()->isTensor()) {
+ inputs.push_back(inp->slot()->toTensor());
}
}
if (propagate) {
return retval;
}
- const std::vector<Slot>& initial_ivalues() const {
+ const std::vector<const NamedIValue*>& initial_ivalues() const {
return initial_ivalues_;
}
// each is a pointer to a slot in the module that owns this parameter
// parameters and submodules can only be _added_ to script Modules to ensure
// these pointers always stay valid
- std::vector<Slot> initial_ivalues_;
+ std::vector<const NamedIValue*> initial_ivalues_;
- // map from a IValue* in initial_ivalues to the offset it appears at
+ // map from a const NamedIValue* in initial_ivalues to the offset it appears at
// in graph. used to accelerate get_or_add_parameter
- std::unordered_map<Slot, size_t> initial_ivalue_index;
+ std::unordered_map<const NamedIValue*, size_t> initial_ivalue_index;
// TODO: support that case where we allow _writes_ to parameters from
// compiled functions.
std::shared_ptr<Module> module;
};
-struct NamedIValue {
- NamedIValue(std::string name, TypePtr type, IValue ivalue)
- : name_(name),
- type_(type),
- ivalue_(torch::make_unique<IValue>(std::move(ivalue))) {}
-
- Slot slot() const {
- return Slot(ivalue_.get());
- }
- const std::string& name() const {
- return name_;
- }
- const TypePtr& type() const {
- return type_;
- }
-
- private:
- const std::string name_;
- const TypePtr type_;
- std::unique_ptr<IValue> ivalue_;
-};
-
struct Module {
TH_DISALLOW_COPY_AND_ASSIGN(Module);
Module() : optimize(true) {}
Method& create_method(
const std::string& name,
std::shared_ptr<Graph> graph,
- std::vector<Slot> member_inputs) {
+ std::vector<const NamedIValue*> member_inputs) {
AT_ASSERT(graph);
std::unique_ptr<Method> method(new Method(
this,
ModuleLookup module_lookup,
// parameter_remap is needed when a parent module uses a parameter of a
// submodule
- std::unordered_map<Slot, Slot>& parameter_remap,
+ std::unordered_map<const NamedIValue*, const NamedIValue*>&
+ parameter_remap,
std::vector<std::string> names = {}) const {
auto curr = module_lookup(names);
+ curr->parameters_.reserve(get_parameters().size() + get_attributes().size());
+
for (auto& param : get_parameters()) {
curr->register_parameter(
param.name(),
param.slot()->toTensor(),
/*is_buffer=*/false);
- parameter_remap[param.slot()] = curr->parameter_slot(param.name());
+ parameter_remap[&param] = curr->find_parameter(param.name());
}
for (auto& attr : get_attributes()) {
if (!attr.type()->isSubtypeOf(TensorType::get())) {
continue;
}
curr->register_buffer(attr.name(), attr.slot()->toTensor());
- parameter_remap[attr.slot()] = curr->find_buffer(attr.name())->slot();
+ parameter_remap[&attr] = curr->find_buffer(attr.name());
}
for (auto& mod : get_modules()) {
names.push_back(mod.name);
mod.module->copy_into(module_lookup, parameter_remap, names);
names.pop_back();
}
+
for (auto& method : get_methods()) {
- std::vector<Slot> initial_ivalues;
+ std::vector<const NamedIValue*> initial_ivalues;
for (auto& p : method->initial_ivalues()) {
initial_ivalues.push_back(parameter_remap.at(p));
}