_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
+ _(namespaces, user) \
_(namespaces, namespaces) \
_(prim, Assign) \
_(prim, BroadcastingChunk) \
_(prim, fork) \
_(prim, RaiseException) \
_(prim, Function) \
+ _(prim, CreateUserObject) \
+ _(prim, SetAttr) \
+ _(prim, GetAttr) \
_(aten, append) \
_(aten, format) \
_(aten, __not__) \
_(attr, b) \
_(attr, beg) \
_(attr, idx) \
- _(attr, split)
+ _(attr, split) \
+ _(attr, slot)
#else
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
_(namespaces, onnx) \
_(namespaces, attr) \
_(namespaces, scope) \
+ _(namespaces, user) \
_(namespaces, namespaces)
#endif
static Symbol aten(const std::string & s);
static Symbol onnx(const std::string & s);
static Symbol prim(const std::string & s);
+ static Symbol user(const std::string & s);
// TODO: eliminate me
static Symbol scope(const std::string & s);
bool is_aten() const;
bool is_prim() const;
bool is_onnx() const;
+ bool is_user() const;
// So we can switch on this
constexpr operator unique_t() const {
inline Symbol Symbol::onnx(const std::string & s) { return Symbol::fromQualString("onnx::" + s); }
inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualString("prim::" + s); }
inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); }
+inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); }
inline bool Symbol::is_attr() const { return ns() == namespaces::attr; }
inline bool Symbol::is_aten() const { return ns() == namespaces::aten; }
inline bool Symbol::is_prim() const { return ns() == namespaces::prim; }
inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; }
+inline bool Symbol::is_user() const { return ns() == namespaces::user; }
} // namespace c10
return out << v.toDevice();
case IValue::Tag::GenericDict:
return printDict(out, v.toGenericDict());
+ case IValue::Tag::UserObject:
+ // TODO we should print the object contents
+ return out << "UserObject<" << v.toUserObject()->name().toUnqualString()
+ << ">";
}
AT_ERROR("Tag not found\n");
}
using BoolList = List<bool>;
using GenericList = List<IValue>;
-
+struct UserObject;
}
// IValue is the generic tagged union used by the interpreter to hold
_(GenericList) \
_(GenericDict) \
_(Future) \
- _(Device)
+ _(Device) \
+ _(UserObject)
struct CAFFE2_API IValue final {
IValue()
return toIntrusivePtr<ivalue::GenericDict>();
}
+ // UserType
+ IValue(c10::intrusive_ptr<ivalue::UserObject> v);
+ bool isUserObject() const { return tag == Tag::UserObject; }
+ c10::intrusive_ptr<ivalue::UserObject> toUserObject() && {
+ AT_ASSERT(isUserObject());
+ return toIntrusivePtr<ivalue::UserObject>();
+ }
+ c10::intrusive_ptr<ivalue::UserObject> toUserObject() const & {
+ AT_ASSERT(isUserObject());
+ return toIntrusivePtr<ivalue::UserObject>();
+ }
+
// None
bool isNone() const {
return Tag::None == tag;
FutureError error;
};
+// User-defined object.
+//
+// Runtime representation of an instance of a user-defined TorchScript class.
+// Attributes are stored positionally in `slots_`; the compiler translates
+// attribute names into slot indices (see UserType::getAttributeSlot) so that
+// runtime attribute access is a constant-time vector lookup.
+struct C10_EXPORT ivalue::UserObject final : c10::intrusive_ptr_target {
+ public:
+  // `name` is the globally-unique type symbol; `numSlots` must equal the
+  // number of attributes declared on the corresponding UserType.
+  UserObject(Symbol name, size_t numSlots) : typename_(std::move(name)) {
+    slots_.resize(numSlots);
+  }
+
+  static c10::intrusive_ptr<UserObject> create(
+      Symbol name,
+      size_t numSlots) {
+    return c10::make_intrusive<UserObject>(std::move(name), numSlots);
+  }
+
+  // Stores `v` into `slot`. Bounds-checked (throws std::out_of_range on an
+  // invalid slot, matching getSlot) instead of silent UB, and the value is
+  // moved into place rather than copied.
+  void setSlot(size_t slot, IValue v) {
+    slots_.at(slot) = std::move(v);
+  }
+
+  IValue getSlot(size_t slot) const {
+    return slots_.at(slot);
+  }
+
+  // The type symbol this object was created with (user::<classname>).
+  Symbol name() const {
+    return typename_;
+  }
+
+ private:
+  const Symbol typename_;
+  std::vector<IValue> slots_;
+};
+
struct C10_EXPORT ivalue::GenericDict : c10::intrusive_ptr_target {
private:
UnorderedMap elements_;
DEFINE_TO(c10::intrusive_ptr<ivalue::GenericList>, toGenericList)
DEFINE_TO(c10::intrusive_ptr<ivalue::GenericDict>, toGenericDict)
DEFINE_TO(c10::intrusive_ptr<ivalue::ConstantString>, toString)
+DEFINE_TO(c10::intrusive_ptr<ivalue::UserObject>, toUserObject)
DEFINE_TO(at::Scalar, toScalar)
DEFINE_TO(std::vector<int64_t>, toIntListRef)
DEFINE_TO(std::vector<double>, toDoubleListRef)
inline IValue::IValue(ivalue::UnorderedMap v)
: IValue(ivalue::GenericDict::create(std::move(v))) {}
+// Takes ownership of the UserObject: release() transfers the reference
+// held by `v` into the tagged payload (balanced by IValue's destructor).
+inline IValue::IValue(c10::intrusive_ptr<ivalue::UserObject> v)
+: tag(Tag::UserObject), is_intrusive_ptr(true) {
+  payload.as_intrusive_ptr = v.release();
+}
inline IValue::IValue(c10::intrusive_ptr<ivalue::Future> v)
: tag(Tag::Future), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.release();
#include <iostream>
#include <type_traits>
+namespace torch {
+namespace jit {
+namespace script {
+struct Module;
+struct Method;
+}
+} // namespace jit
+} // namespace torch
+
namespace c10 {
#define C10_FORALL_TYPES(_) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
+_(UserType) \
enum class TypeKind {
#define DEFINE_TYPE(T) T,
CAFFE2_API TypePtr evalTypeVariables(TypePtr type, TypeEnv & type_env);
+/**
+ * User Defined Types
+ */
+
+struct UserType;
+using UserTypePtr = std::shared_ptr<UserType>;
+using ::torch::jit::script::Module;
+using ::torch::jit::script::Method;
+
+// This represents a user-defined type in TorchScript.
+// This represents a user-defined type in TorchScript.
+struct CAFFE2_API UserType : public Type {
+  // Create a user type and register it globally.
+  // Note: re-creating a name replaces the previous registration.
+  static UserTypePtr create(const std::string& name, std::shared_ptr<Module> module);
+  // returns nullptr if there is no type with that name
+  static UserTypePtr get(const std::string& name);
+
+  DEFINE_IS_SUBCLASS(UserType);
+  // Two user types are equal iff they share the same (globally unique) name.
+  bool operator==(const Type& rhs) const override {
+    if (auto user_rhs = rhs.cast<UserType>()) {
+      return typename_ == user_rhs->typename_;
+    }
+    return false;
+  }
+
+  bool isSubtypeOf(const TypePtr rhs) const override {
+    // XXX: We do not have inheritance implemented, only types that are the
+    // same can subtype from each other.
+    return *this == *rhs;
+  }
+  std::string str() const override {
+    return std::string("UserType<") + typename_ + ">";
+  }
+
+  // Returns the declared type of attribute `name`, or nullptr if absent.
+  TypePtr getAttribute(const std::string& name) const {
+    const auto it = findAttribute(name);
+    if (it == attributes_.cend()) {
+      return nullptr;
+    }
+    return it->type;
+  }
+
+  Method* getMethod(const std::string& name) const;
+
+  std::string name() const {
+    return typename_;
+  }
+
+  size_t numAttributes() const {
+    return attributes_.size();
+  }
+
+  // Attributes are stored in a specific slot at runtime for efficiency.
+  // When emitting instructions we specify the slot so that attribute access is
+  // a constant lookup. Throws if `name` is not a declared attribute.
+  size_t getAttributeSlot(const std::string& name) const {
+    const auto it = findAttribute(name);
+    if (it == attributes_.cend()) {
+      throw std::runtime_error("Couldn't find attribute: " + name);
+    }
+    return static_cast<size_t>(std::distance(attributes_.cbegin(), it));
+  }
+
+  bool hasAttribute(const std::string& name) const {
+    return findAttribute(name) != attributes_.cend();
+  }
+
+  // Declares a new attribute; its slot is the current end of the list.
+  void addAttribute(const std::string& name, TypePtr type) {
+    attributes_.emplace_back(name, type);
+  }
+
+  static const TypeKind Kind = TypeKind::UserType;
+
+ private:
+  UserType(std::string name, std::shared_ptr<Module> module)
+      : Type(TypeKind::UserType),
+        typename_(std::move(name)),
+        module_(std::move(module)) {}
+
+  // Name of type (note that this has to be globally unique).
+  std::string typename_;
+
+  // Mapping of attribute names -> their type.
+  // NOTE: this does not contain methods, which are stored in the module
+  // TODO: once modules support arbitrary ivalue attributes, we don't need this
+  // anymore.
+  struct Attribute {
+    Attribute(std::string n, TypePtr t)
+        : name(std::move(n)), type(std::move(t)) {}
+    std::string name;
+    TypePtr type;
+  };
+
+  // Single linear search shared by getAttribute/hasAttribute/getAttributeSlot.
+  std::vector<Attribute>::const_iterator findAttribute(
+      const std::string& name) const {
+    return std::find_if(
+        attributes_.cbegin(), attributes_.cend(), [&](const Attribute& attr) {
+          return attr.name == name;
+        });
+  }
+
+  std::vector<Attribute> attributes_;
+  // Holds method attributes
+  std::shared_ptr<Module> module_;
+};
} // namespace c10
return *this == *rhs;
}
+namespace {
+// Process-wide name -> UserType map. All access is serialized by mutex_ so
+// registration/lookup may happen from multiple threads.
+class UserTypeRegistry {
+ public:
+  void registerType(std::string name, UserTypePtr type) {
+    std::lock_guard<std::mutex> g(mutex_);
+    // TODO: new type registrations will override the old ones. Is this safe?
+    reg_[name] = std::move(type);
+  }
+
+  // Returns nullptr if `name` was never registered.
+  UserTypePtr getType(const std::string& name) {
+    std::lock_guard<std::mutex> g(mutex_);
+    // Single lookup (find) instead of count() followed by at().
+    const auto it = reg_.find(name);
+    if (it != reg_.end()) {
+      return it->second;
+    }
+    return nullptr;
+  }
+
+ private:
+  std::mutex mutex_;
+  std::unordered_map<std::string, UserTypePtr> reg_;
+};
+
+// Meyers singleton: constructed on first use, avoiding static-init ordering
+// issues across translation units.
+UserTypeRegistry& getRegistry() {
+  static UserTypeRegistry r;
+  return r;
+}
+} // namespace
+
+UserTypePtr UserType::create(
+    const std::string& name,
+    std::shared_ptr<Module> module) {
+  // Can't use make_shared here: the UserType constructor is private.
+  auto ptr = UserTypePtr(new UserType(name, std::move(module)));
+  // Register globally so the compiler can later resolve `name` as a type
+  // (see UserType::get). Creating the same name again replaces the entry.
+  getRegistry().registerType(name, ptr);
+  return ptr;
+}
+
+UserTypePtr UserType::get(const std::string& name) {
+  // Returns nullptr when no type with this name has been registered.
+  return getRegistry().getType(name);
+}
} // namespace c10
assert 1 == 1, "hello"
return x
- ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
+ ast = torch.jit.frontend.get_jit_def(fn)
self.assertExpected(str(ast))
@unittest.skipIf(not PY2, "Requires python 2")
def test_python_frontend_py2(self):
def fn():
raise Exception("hello")
- ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
+ ast = torch.jit.frontend.get_jit_def(fn)
self.assertExpected(str(ast))
@unittest.skipIf(PY2, "Requires python 3")
def test_python_frontend_py3(self):
def fn():
raise Exception("hello")
- ast = torch.jit.frontend.get_jit_ast(fn, is_method=False)
+ ast = torch.jit.frontend.get_jit_def(fn)
self.assertExpected(str(ast))
def _make_scalar_vars(self, arr, dtype):
self.assertEqual(first_forward, r1_forward)
+class TestUserType(JitTestCase):
+    """End-to-end tests for TorchScript user-defined classes (UserType):
+    construction, attribute get/set, method calls, and compile-time errors.
+    """
+
+    def test_get_with_method(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    self.foo = x
+
+                def getFoo(self):
+                    return self.foo
+
+            @torch.jit.script
+            def fn(x):
+                foo = Foo(x)
+                return foo.getFoo()
+
+            input = torch.ones(2, 3)
+            self.assertEqual(fn(input), input)
+
+    def test_get_attr(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    self.foo = x
+
+            @torch.jit.script
+            def fn(x):
+                foo = Foo(x)
+                return foo.foo
+
+            input = torch.ones(2, 3)
+            self.assertEqual(fn(input), input)
+
+    def test_set_attr_in_method(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    # type: (int)
+                    self.foo = x
+
+                def incFoo(self, y):
+                    # type: (int)
+                    self.foo = self.foo + y
+
+            @torch.jit.script
+            def fn(x):
+                # type: (int)
+                foo = Foo(x)
+                foo.incFoo(2)
+                return foo.foo
+
+            self.assertEqual(fn(1), 3)
+
+    def test_set_attr_type_mismatch(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        self.foo = x
+                        self.foo = 10  # should error since int != Tensor
+
+    def test_get_attr_not_initialized(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        self.foo = x
+
+                    def get_non_initialized(self):
+                        return self.asdf  # asdf isn't an attr
+
+    def test_set_attr_non_initialized(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        self.foo = x
+
+                    def set_non_initialized(self, y):
+                        self.bar = y  # can't assign to non-initialized attr
+
+    def test_type_annotations(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        # type: (bool)
+                        self.foo = x
+
+                @torch.jit.script
+                def fn(x):
+                    Foo(x)
+
+                fn(2)
+
+    def test_conditional_set_attr(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"):
+                @torch.jit.script
+                class Foo:
+                    def __init__(self, x):
+                        if True:
+                            self.attr = x
+
+    def test_user_type_as_param(self):
+        # Remove this when import/export is implemented for classes
+        with self.disableModuleHook():
+            @torch.jit.script
+            class Foo:
+                def __init__(self, x):
+                    self.attr = x
+
+            @torch.jit.script
+            def fn(foo):
+                # type: (Foo)
+                return foo.attr
+
+            @torch.jit.script
+            def fn2(x):
+                foo = Foo(x)
+                return fn(foo)
+
+            input = torch.ones(1)
+            self.assertEqual(fn2(input), input)
+
+
for test in autograd_method_tests():
add_autograd_test(*test)
"torch/csrc/jit/script/script_type_parser.cpp",
"torch/csrc/jit/script/sugared_value.cpp",
"torch/csrc/jit/script/schema_matching.cpp",
+ "torch/csrc/jit/script/user_type.cpp",
"torch/csrc/jit/script/parser.cpp",
"torch/csrc/jit/testing/file_check.cpp",
"torch/csrc/jit/import_method.cpp",
${TORCH_SRC_DIR}/csrc/jit/script/schema_type_parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/script/user_type.cpp
${TORCH_SRC_DIR}/csrc/jit/script/parser.cpp
${TORCH_SRC_DIR}/csrc/jit/script/builtin_functions.cpp
${TORCH_SRC_DIR}/csrc/jit/script/edit_distance.cpp
return std::make_shared<script::SimpleValue>(
m.get_or_add_parameter(v->slot()));
} else if (script::Method* m = module->find_method(field)) {
- return std::make_shared<script::MethodValue>(module, *m);
+ return std::make_shared<script::MethodValue>(shared_from_this(), *m);
} else {
throw script::ErrorReport(loc) << "unknown attr: " << field;
}
case prim::IgnoredPythonOp:
case prim::Print:
case prim::RaiseException:
+ case prim::SetAttr:
case aten::warn:
return true;
}
return result;
}
+// Creates a prim::CreateUserObject node producing a fresh, uninitialized
+// instance of `type` (initialization is done by a subsequent __init__ call).
+Node* Graph::createUserObject(const UserTypePtr& type) {
+  auto result = create(prim::CreateUserObject);
+  result->output()->setType(type);
+  return result;
+}
+
+// Creates a prim::SetAttr node implementing `obj.field = newValue`.
+// The node has no outputs; the attribute name is carried in attr::name.
+Node* Graph::createSetAttr(
+    Value* obj,
+    const std::string& field,
+    Value* newValue) {
+  // Validation only: expect() throws if `obj` is not a user-defined type.
+  // Result intentionally discarded (avoids an unused-variable warning).
+  obj->type()->expect<UserType>();
+
+  auto n = create(prim::SetAttr, {obj, newValue}, /*num_outputs=*/0);
+  n->s_(attr::name, field);
+  return n;
+}
+
+// Creates a prim::GetAttr node implementing `obj.field`. The output is
+// typed with the attribute's declared type.
+Node* Graph::createGetAttr(Value* obj, const std::string& field) {
+  const auto userType = obj->type()->expect<UserType>();
+
+  auto n = create(prim::GetAttr, {obj}, /*num_outputs=*/1);
+  n->s_(attr::name, field);
+
+  const auto outputType = userType->getAttribute(field);
+  // getAttribute() returns nullptr for unknown fields; fail loudly here
+  // rather than silently producing an untyped output.
+  AT_ASSERT(outputType != nullptr);
+  n->output()->setType(outputType);
+  return n;
+}
+
Node* Graph::createClone(
Node* n,
const std::function<Value*(Value*)>& value_map,
TORCH_API Node* createDictIndex(Value* dict, Value* index);
TORCH_API Node* createNumToTensor(Value* value);
TORCH_API Node* createImplicitTensorToNum(const TypePtr& type, Value* value);
+ TORCH_API Node* createUserObject(const UserTypePtr& type);
+ TORCH_API Node* createSetAttr(
+ Value* obj,
+ const std::string& field,
+ Value* newValue);
+ TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
Node* createPythonOp(
THPObjectPtr&& pyobj,
const std::string& cconv,
type->kind() == TypeKind::TupleType ||
type->kind() == TypeKind::DictType || type->kind() == TypeKind::VarType ||
type->kind() == TypeKind::FutureType ||
+ type->kind() == TypeKind::UserType ||
(type->kind() == TypeKind::OptionalType &&
shouldAnnotate(type->cast<OptionalType>()->getElementType()));
}
std::map<TypeKind, std::vector<Value*>> listTypes;
std::unordered_map<TupleTypePtr, std::vector<Value*>> tupleTypes;
std::unordered_map<DictTypePtr, std::vector<Value*>> dictTypes;
+ std::unordered_map<UserTypePtr, std::vector<Value*>> userTypes;
std::vector<Value*> tensors;
for (auto input : graph->inputs()) {
} else if (inputType->kind() == TypeKind::DictType) {
auto dictType = inputType->cast<DictType>();
dictTypes[dictType].push_back(input);
+ } else if (inputType->kind() == TypeKind::UserType) {
+ auto userType = inputType->cast<UserType>();
+ userTypes[userType].push_back(input);
} else {
AT_ASSERT(!shouldAnnotate(input));
}
for (const auto& pr : dictTypes) {
makeAllAlias(pr.second, *aliasTracker_);
}
+ for (const auto& pr : userTypes) {
+ makeAllAlias(pr.second, *aliasTracker_);
+ }
makeAllAlias(tensors, *aliasTracker_);
analyze(graph->block());
case prim::BroadcastSizes:
case prim::ChunkSizes:
case prim::Function:
+ case prim::CreateUserObject:
return analyzeCreator(node);
case prim::TupleUnpack:
case prim::TupleIndex:
case prim::TupleSlice:
case prim::ListUnpack:
case prim::PythonOp:
+ case prim::GetAttr:
return analyzeExtractor(node);
case prim::ConstantChunk:
return analyzeChunk(node);
case prim::BroadcastingChunk:
return analyzeBroadcastingChunk(node);
+ case prim::SetAttr:
+ return analyzeSetAttr(node);
case aten::add:
case aten::sub:
case aten::mul:
// gives up and creates wildcards for everything.
void AliasDb::analyzeExtractor(Node* node) {
  for (const auto output : node->outputs()) {
-    aliasTracker_->setWildcard(output);
+    // Only outputs of mutable/annotatable types participate in alias
+    // tracking; skipping the rest keeps the wildcard set small.
+    if (shouldAnnotate(output)) {
+      aliasTracker_->setWildcard(output);
+    }
  }
}
}
}
+// SetAttr: writes to the `self` field
+void AliasDb::analyzeSetAttr(Node* node) {
+  // Input 0 is the object being mutated (node inputs are {obj, newValue}).
+  const auto self = node->inputs().at(0);
+  AT_ASSERT(self->type()->kind() == TypeKind::UserType);
+  aliasTracker_->registerWrite(self, node);
+}
+
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
// This is an intermediate node used only in the graph fuser.
void AliasDb::analyzeBroadcastingChunk(Node* node) {
prim::ConstantChunk,
prim::BroadcastingChunk,
prim::fork,
+ prim::CreateUserObject,
+ prim::GetAttr,
+ prim::SetAttr,
aten::wait,
aten::add,
aten::sub,
void analyzeBroadcastingChunk(Node* node);
void analyzeFork(Node* node);
void analyzeWait(Node* node);
+ void analyzeSetAttr(Node* node);
void makeAliasOf(const Value* value, const Value* to);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
if (!aliasDb_) {
// If we don't have alias information, all mutable ops have unknown
// effects and can't be considered for elimination.
- if (!node->kind().is_aten()) {
+ if (!node->kind().is_aten() && !node->kind().is_prim()) {
return false;
}
// onnx export calls EliminateDeadCode but sometimes passes invalid
[graph, name, this] { printFunctionDefinition(*graph, name); });
stmt << "self." << name;
} break;
+ case prim::CreateUserObject:
+ case prim::SetAttr:
+ case prim::GetAttr:
+ throw std::runtime_error("NYI");
default: {
Symbol kind = node->kind();
if (kind.is_aten()) {
prim::TupleSlice,
prim::TupleUnpack,
prim::Undefined,
+ prim::CreateUserObject,
+ prim::GetAttr,
+ prim::SetAttr,
};
// WARNING: by adding a value to this set, you are asserting that your
case TypeKind::GeneratorType:
case TypeKind::VarType:
case TypeKind::FutureType:
+ case TypeKind::UserType:
break;
}
AT_ERROR(
}
return 0;
}),
-});
+    Operator(
+        prim::CreateUserObject,
+        [](const Node* node) {
+          // Resolve the type name and slot count at compile time from the
+          // node's output type; the returned closure only allocates.
+          const auto type = node->output()->type()->expect<UserType>();
+          const auto name = Symbol::user(type->name());
+          const size_t numAttrs = type->numAttributes();
+          return [name, numAttrs](Stack& stack) {
+            auto userObj =
+                c10::ivalue::UserObject::create(name, numAttrs);
+            push(stack, std::move(userObj));
+            return 0;
+          };
+        }),
+    Operator(
+        prim::GetAttr,
+        [](const Node* node) {
+          // Translate the attribute name into a slot index once, at
+          // compile time; runtime access is a constant-index read.
+          const auto type = node->input()->type()->expect<UserType>();
+          const auto& field = node->s(attr::name);
+          const auto slot = type->getAttributeSlot(field);
+          return [slot](Stack& stack) {
+            auto userObj = pop(stack).toUserObject();
+            auto value = userObj->getSlot(slot);
+            push(stack, std::move(value));
+            return 0;
+          };
+        }),
+    Operator(prim::SetAttr, [](const Node* node) {
+      const auto type = node->inputs().at(0)->type()->expect<UserType>();
+      const auto& field = node->s(attr::name);
+      const auto slot = type->getAttributeSlot(field);
+      return [slot](Stack& stack) {
+        // Stack order mirrors node inputs {obj, newValue}: the new value
+        // is on top, the object underneath it.
+        auto v = pop(stack);
+        auto userObj = pop(stack).toUserObject();
+        userObj->setSlot(slot, std::move(v));
+        return 0;
+      };
+    })});
// define implementations for primitive number ops
#define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \
#include <torch/csrc/jit/script/compiler.h>
+
#include <c10/util/Exception.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/interpreter.h>
const SugaredValuePtr& self,
Block* block) {
auto schema = extractSchemaFromDef(def, self);
+ // TODO need guards on init returning none
if (schema.returns().size() == 1) {
def_stack_.back().declared_return_type_ = schema.returns().at(0).type();
}
const SugaredValuePtr& self) {
auto params_begin = decl.params().begin();
auto params_end = decl.params().end();
- if (self)
+ if (self) {
++params_begin;
+ }
std::vector<Argument> retval;
std::vector<Expr> default_types;
FunctionSchema extractSchemaFromDef(
const Def& def,
const SugaredValuePtr& self) {
- auto name = def.name().name();
+ const auto name = def.name().name();
std::vector<Argument> args = parseArgsFromDecl(def.decl(), self);
std::vector<Argument> returns = parseReturnFromDecl(def.decl());
return FunctionSchema(
// inputs
auto it = def.decl().params().begin();
auto end = def.decl().params().end();
- auto expected_annotation_size =
- self ? def.decl().params().size() - 1 : def.decl().params().size();
+ auto expected_annotation_size = def.decl().params().size();
+ if (self) {
+ expected_annotation_size--;
+ }
if (schema.arguments().size() != expected_annotation_size) {
throw ErrorReport(def.decl().params().range())
<< "Number of type annotations for"
<< " does not match the number of parameters on the function ("
<< expected_annotation_size << ")!";
}
+
if (self) {
AT_ASSERT(it != end);
- environment_stack->setSugaredVar(def.range(), (*it).ident().name(), self);
+ const auto& name = (*it).ident().name();
+ if (auto userType = dynamic_cast<UserTypeValue*>(self.get())) {
+ const auto type = userType->type_;
+ Value* new_input =
+ block->addInput()->setUniqueName(name)->setType(type);
+ environment_stack->setVar((*it).ident().range(), name, new_input);
+ arguments.emplace_back(name, type);
+ } else {
+ environment_stack->setSugaredVar(def.range(), name, self);
+ }
++it;
}
size_t arg_annotation_idx = 0;
case TK_TUPLE_LITERAL:
emitTupleAssign(TupleLiteral(stmt.lhs()), stmt.rhs());
break;
+ case '.':
+ emitSelectAssign(stmt);
+ break;
case TK_SUBSCRIPT:
emitSubscriptAssign(stmt.range(), Subscript(stmt.lhs()), stmt.rhs());
break;
}
}
+  // Emits `lhs.selector = rhs`. Inside `__init__`, assignments to
+  // `self.<attr>` may define new attributes on the type; elsewhere they
+  // must match an already-declared attribute (see SugaredValue::setAttr).
+  void emitSelectAssign(const Assign& stmt) {
+    const auto lhs = Select(stmt.lhs());
+    // NOTE(review): assumes the base of the select is a simple Var
+    // (`x.attr = ...`, not `x.y.attr = ...`) — confirm the grammar
+    // guarantees this before Var() is applied.
+    const auto basename = Var(lhs.value()).name();
+    const auto rhsValue =
+        emitSugaredExpr(stmt.rhs(), 1)->asValue(stmt.rhs().range(), method);
+    auto userObject = environment_stack->getSugaredVar(basename);
+    // Attribute definition is only permitted on `self` within __init__.
+    const bool shouldDefine =
+        method.name() == "__init__" && basename.name() == "self";
+    userObject->setAttr(
+        stmt.range(), method, lhs.selector().name(), rhsValue, shouldDefine);
+  }
+
NodeKind getNodeKind(int kind, int ninputs) {
switch (kind) {
case '+':
if (NamedModule* v = module->find_module(field)) {
return std::make_shared<ModuleValue>(v->module);
} else if (Method* v = module->find_method(field)) {
- return std::make_shared<MethodValue>(module, *v);
+ return std::make_shared<MethodValue>(shared_from_this(), *v);
} else if (NamedParameter* v = module->find_parameter(field)) {
return std::make_shared<SimpleValue>(m.get_or_add_parameter(v->slot()));
}
<< "Attempted to inline a Module with parameters. "
"Stateful modules to be inlined must be submodules of the callee.";
}
+ const auto script_class_type =
+ py::module::import("torch.jit").attr("ScriptClass");
+ const bool is_user_type = py::isinstance(obj, script_class_type);
+ if (is_user_type) {
+ const auto classname = py::cast<std::string>(py::getattr(obj, "_name"));
+ auto userType = UserType::get(classname);
+ AT_ASSERT(userType);
+ return std::make_shared<UserTypeValue>(std::move(userType));
+ }
return std::make_shared<ModuleValue>(mod);
} else if (py::isinstance<py::module>(obj)) {
return std::make_shared<PythonModuleValue>(obj);
return mod;
});
+  m.def(
+      "_jit_script_class_compile",
+      [](std::shared_ptr<Module> module,
+         const ClassDef& classDef,
+         ResolutionCallback rcb) {
+        // Register the class name globally first, then compile each method
+        // into `module` with `self` bound to the new user type.
+        auto userType = UserType::create(classDef.name().name(), module);
+        std::vector<Resolver> rcbs;
+        std::vector<Def> methodDefs;
+        // One resolver per method, all backed by the same Python callback.
+        for (const auto& def : classDef.defs()) {
+          methodDefs.push_back(def);
+          rcbs.push_back(pythonResolver(rcb));
+        }
+        defineMethodsInModule(
+            module,
+            methodDefs,
+            rcbs,
+            std::make_shared<UserTypeValue>(userType));
+        return module;
+      });
+
m.def("parse_type_comment", [](const std::string& comment) {
Parser p(comment);
return Decl(p.parseTypeComment());
_(TK_RAISE, "raise", "raise") \
_(TK_ASSERT, "assert", "assert") \
_(TK_DOTS, "dots", "...") \
- _(TK_PASS, "pass", "pass")
+ _(TK_PASS, "pass", "pass") \
+ _(TK_CLASS_DEF, "class", "class")
static const char* valid_single_char_tokens = "+-*/%@()[]:,={}><.?!&^|";
const auto& r = name.range();
return Def::create(r, name, decl, wrap_list(r, std::move(body)));
}));
+  // Python binding for constructing a ClassDef AST node from an identifier
+  // and a list of method Defs; the class's source range is taken from its
+  // name.
+  py::class_<ClassDef, TreeView>(m, "ClassDef")
+      .def(py::init([](const Ident& name, std::vector<Def> body) {
+        const auto& r = name.range();
+        return ClassDef::create(r, name, wrap_list(r, std::move(body)));
+      }));
py::class_<Decl, TreeView>(m, "Decl").def(py::init(
[](const SourceRange& r, std::vector<Param> params, Expr* return_type) {
return Decl::create(
if (itr != ident_to_type_lut().end()) {
return itr->second;
}
+ if (auto typePtr = UserType::get(*name)) {
+ return typePtr;
+ }
throw ErrorReport(expr) << "Unknown type name " << *name;
}
throw ErrorReport(expr.range())
Method& m,
const std::string& field) {
// Allow method-style casts on Tensor types. e.g. x.int()
- if (value->type()->isSubtypeOf(TensorType::get())) {
+ if (value_->type()->isSubtypeOf(TensorType::get())) {
if (builtin_cast_methods().count(field)) {
return std::make_shared<BuiltinFunction>(
Symbol::aten(builtin_cast_methods().at(field)),
- NamedValue(loc, "self", value));
+ NamedValue(loc, "self", value_));
}
// functions that are just direct property lookups on tensor
// must be registered as prim::<name>(Tensor t) -> <return_type>
};
if (fields.count(field)) {
auto r =
- m.graph()->insert(Symbol::fromQualString("prim::" + field), {value});
+ m.graph()->insert(Symbol::fromQualString("prim::" + field), {value_});
return std::make_shared<SimpleValue>(r);
}
}
- if (getValue()->type()->isSubtypeOf(NumberType::get())) {
+ if (value_->type()->isSubtypeOf(NumberType::get())) {
throw ErrorReport(loc) << "Cannot call methods on numbers";
}
- if (getValue()->type()->kind() == TypeKind::TupleType) {
- auto tuple_type = getValue()->type()->expect<TupleType>();
+ if (auto tuple_type = value_->type()->cast<TupleType>()) {
if (!tuple_type->hasNames()) {
throw ErrorReport(loc) << "Getting attributes of tuples is not supported";
}
auto names = tuple_type->names();
- for (int i = 0; i < names.size(); i++) {
+ for (size_t i = 0; i < names.size(); i++) {
if (names[i] == field) {
auto r = m.graph()
- ->insertNode(m.graph()->createTupleIndex(getValue(), i))
+ ->insertNode(m.graph()->createTupleIndex(value_, i))
->output();
return std::make_shared<SimpleValue>(r);
}
}
throw ErrorReport(loc) << "Unknown attribute to named tuple";
}
+
+ if (auto userType = value_->type()->cast<UserType>()) {
+ // This is a user-defined type, emit the proper attribute lookup
+ if (auto method = userType->getMethod(field)) {
+ return std::make_shared<MethodValue>(shared_from_this(), *method);
+ }
+
+ if (!userType->hasAttribute(field)) {
+ throw ErrorReport(loc)
+ << "Tried to access to nonexistent attribute " << field
+ << ". Did you forget to initialize it in __init__()?";
+ }
+ auto& g = *m.graph();
+ auto n = g.insertNode(g.createGetAttr(value_, field));
+ return std::make_shared<SimpleValue>(n->output());
+ }
+
return std::make_shared<BuiltinFunction>(
- Symbol::aten(field), NamedValue(loc, "self", value));
+ Symbol::aten(field), NamedValue(loc, "self", value_));
}
std::vector<std::shared_ptr<SugaredValue>> SimpleValue::asTuple(
[](Value* v) -> std::shared_ptr<SugaredValue> {
return std::make_shared<SimpleValue>(v);
};
- if (value->type()->kind() == TypeKind::TupleType) {
- auto outputs = createTupleUnpack(value);
+ if (value_->type()->kind() == TypeKind::TupleType) {
+ auto outputs = createTupleUnpack(value_);
return fmap(outputs, make_simple_value);
- } else if (value->type()->kind() == TypeKind::ListType) {
+ } else if (value_->type()->kind() == TypeKind::ListType) {
if (!size_hint) {
throw ErrorReport(loc)
<< "cannot statically infer the expected size of a list in this context";
}
- auto graph = value->owningGraph();
+ auto graph = value_->owningGraph();
Node* unpack =
- graph->insertNode(graph->createListUnpack(value, *size_hint));
+ graph->insertNode(graph->createListUnpack(value_, *size_hint));
return fmap(unpack->outputs(), make_simple_value);
}
- throw ErrorReport(loc) << value->type()->str()
+ throw ErrorReport(loc) << value_->type()->str()
<< " cannot be used as a tuple";
}
+
+// Implements `obj.field = newValue` for user-defined types.
+// `shouldDefine` permits declaring a new attribute (used for first
+// assignments to `self.<attr>` inside __init__); otherwise the attribute
+// must already exist and the new value must be a subtype of its type.
+void SimpleValue::setAttr(
+    const SourceRange& loc,
+    Method& m,
+    const std::string& field,
+    Value* newValue,
+    bool shouldDefine) {
+  const auto userType = value_->type()->cast<UserType>();
+  if (!userType) {
+    throw ErrorReport(loc) << "Tried to set an attribute: " << field
+                           << " on a non-user-defined type: "
+                           << value_->type()->str();
+  }
+
+  auto expectedType = userType->getAttribute(field);
+  if (!expectedType) {
+    // We don't have an attribute with this name, either add it to the type
+    // definition or throw an error
+    if (shouldDefine) {
+      // Validate the insertion point BEFORE mutating the type: if the first
+      // assignment sits inside a control-flow block we must throw without
+      // leaving a half-registered attribute behind on the type.
+      const auto insertPoint = m.graph()->insertPoint();
+      const auto topLevelBlock = m.graph()->block();
+      if (insertPoint->owningBlock() != topLevelBlock) {
+        throw ErrorReport(loc)
+            << "First assignment cannot be in a control-flow block. "
+            << "Initialize the field at the top level first.";
+      }
+      userType->addAttribute(field, newValue->type());
+      expectedType = newValue->type();
+    } else {
+      throw ErrorReport(loc)
+          << "Tried to set nonexistent attribute: " << field
+          << ". Did you forget to initialize it in __init__()?";
+    }
+  }
+
+  // Check type correctness
+  const auto newType = newValue->type();
+  if (!newType->isSubtypeOf(expectedType)) {
+    throw ErrorReport(loc) << "Wrong type for attribute assignment. Expected "
+                           << expectedType->str() << " but got "
+                           << newType->str();
+  }
+
+  auto& g = *m.graph();
+  g.insertNode(g.createSetAttr(value_, field, newValue));
+}
+
+// Calling a user type, e.g. `Foo(x)`, allocates a new object and then
+// invokes `Foo.__init__` on it; the initialized object is the result.
+std::shared_ptr<SugaredValue> UserTypeValue::call(
+    const SourceRange& loc,
+    Method& m,
+    // note: names for args will be 'argument 0', 'argument 1', etc..
+    at::ArrayRef<NamedValue> inputs,
+    at::ArrayRef<NamedValue> attributes,
+    size_t n_binders) {
+  // A constructor call yields exactly one value (the new instance).
+  AT_ASSERT(n_binders <= 1);
+
+  // Generate a new object of the right type, then call `__init__` on it
+  auto& g = *m.graph();
+  auto createNode = g.insertNode(g.createUserObject(type_));
+  auto self = std::make_shared<SimpleValue>(createNode->output());
+
+  auto initMethod = type_->getMethod("__init__");
+  AT_ASSERT(initMethod);
+
+  // Call the init function
+  // NOTE(review): __init__'s return value is discarded — presumably it is
+  // expected to return None; confirm a guard exists (see compiler TODO).
+  MethodValue(self, *initMethod).call(loc, m, inputs, attributes, n_binders);
+
+  return self;
+}
} // namespace script
} // namespace jit
} // namespace torch
const std::string& field) {
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
}
+
+  // assign an attribute on it, e.g. `this.field = newValue`
+  // `shouldDefine` permits creating a previously-unseen attribute (used for
+  // first assignments inside __init__). This base implementation always
+  // errors; SimpleValue overrides it for user-defined types.
+  virtual void setAttr(
+      const SourceRange& loc,
+      Method& m,
+      const std::string& field,
+      Value* newValue,
+      bool shouldDefine) {
+    throw ErrorReport(loc) << "attribute assignment is not defined on "
+                           << kind();
+  }
virtual NoneStatus isNone() {
return NEVER;
}
// most things in the environment are just simple value types
// and not special python syntax sugar types
struct TORCH_API SimpleValue : public SugaredValue {
- SimpleValue(Value* value) : value(value) {}
+ SimpleValue(Value* value) : value_(value) {}
std::string kind() const override {
return "value";
}
Value* asValue(const SourceRange& range, Method& m) override {
- return value;
+ return value_;
}
NoneStatus isNone() override {
- if (value->mustBeNone())
+ if (value_->mustBeNone())
return ALWAYS;
- else if (value->type()->cast<OptionalType>())
+ else if (value_->type()->cast<OptionalType>())
return MAYBE;
else
return NEVER;
const SourceRange& loc,
Method& m,
const std::string& field) override;
+
+ void setAttr(
+ const SourceRange& loc,
+ Method& m,
+ const std::string& field,
+ Value* newValue,
+ bool shouldDefine) override;
+
Value* getValue() const {
- return value;
+ return value_;
}
private:
- Value* value;
+ Value* value_;
};
struct TORCH_API BuiltinFunction : public SugaredValue {
c10::optional<int64_t> version;
};
+// Represents a user type, analogous to `int` or `dict`
+struct TORCH_API UserTypeValue : public SugaredValue {
+ UserTypeValue(UserTypePtr type) : type_(std::move(type)) {}
+
+ // Call the type's constructor, as in:
+ // n = Foo(constructor_arg)
+ std::shared_ptr<SugaredValue> call(
+ const SourceRange& loc,
+ Method& m,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override;
+
+ std::string kind() const override {
+ return type_->str();
+ }
+
+ UserTypePtr type_;
+};
+
// defines how a method obtained from a module behaves in script
struct MethodValue : public SugaredValue {
- MethodValue(std::shared_ptr<Module> module, Method& method)
- : module(std::move(module)) // insurance that method stays alive
- ,
- method(method) {}
+ MethodValue(std::shared_ptr<SugaredValue> self, Method& method)
+ : self_(std::move(self)), method(method) {}
std::string kind() const override {
return "method";
}
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) override {
+ if (auto userType = dynamic_cast<SimpleValue*>(self_.get())) {
+ // If self_ is a user-defined type, then it will be expected as part of
+ // the schema. Add it to the front of the inputs.
+ std::vector<NamedValue> inputsWithSelf;
+ inputsWithSelf.emplace_back(loc, userType->getValue());
+ inputsWithSelf.insert(inputsWithSelf.end(), inputs.begin(), inputs.end());
+ return std::make_shared<SimpleValue>(
+ caller.emit_call_to(loc, method, inputsWithSelf, attributes));
+ }
+
return std::make_shared<SimpleValue>(
caller.emit_call_to(loc, method, inputs, attributes));
}
private:
- std::shared_ptr<Module> module;
+ std::shared_ptr<SugaredValue> self_;
Method& method;
};
//
// Decl = Decl(List<Param> params, Maybe<Expr> return_type) TK_DECL
// Def = Def(Ident name, Decl decl, List<Stmt> body) TK_DEF
+// ClassDef = ClassDef(Ident name, List<Def> body) TK_CLASS_DEF
//
// Stmt = If(Expr cond, List<Stmt> true_body, List<Stmt> false_body) TK_IF
// | For(List<Expr> targets, List<Expr> iters, List<Stmt> body) TK_FOR
}
};
+struct ClassDef : public TreeView {
+ explicit ClassDef(const TreeRef& tree) : TreeView(tree) {
+ tree->match(TK_CLASS_DEF);
+ }
+ ClassDef withName(std::string new_name) const {
+ auto new_ident = Ident::create(name().range(), std::move(new_name));
+ return create(range(), new_ident, defs());
+ }
+ Ident name() const {
+ return Ident(subtree(0));
+ }
+ List<Def> defs() const {
+ return List<Def>(subtree(1));
+ }
+ static ClassDef create(
+ const SourceRange& range,
+ const Ident& name,
+ const List<Def>& defs) {
+ return ClassDef(Compound::create(TK_CLASS_DEF, range, {name, defs}));
+ }
+};
+
////////////////////////////////////////////////////////////////////////////////
// Statements
////////////////////////////////////////////////////////////////////////////////
--- /dev/null
+#include <ATen/core/jit_type.h>
+#include <torch/csrc/jit/script/module.h>
+
+namespace c10 {
+
+// This file exists because we need to reference module.h, which we can't from
+// c10. Sigh...
+Method* UserType::getMethod(const std::string& name) const {
+ return module_->find_method(name);
+}
+
+} // namespace c10
from torch.autograd import Variable, function
from torch.serialization import validate_cuda_device
from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
-from torch.jit.frontend import get_jit_ast, get_default_args
+from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args
import torch.backends.cudnn as cudnn
import torch.jit.annotations
import torch._jit_internal as _jit_internal
_flatten = torch._C._jit_flatten
_unflatten = torch._C._jit_unflatten
_jit_script_compile = torch._C._jit_script_compile
+_jit_script_class_compile = torch._C._jit_script_class_compile
BatchTensor = torch._C._jit.BatchTensor
Future = torch._C.Future
return entry["compiled_fn"]
-def script(fn, optimize=True, _frames_up=0, _rcb=None):
+def script(obj, optimize=True, _frames_up=0, _rcb=None):
if not _enabled:
- return fn
+ return obj
if _rcb is None:
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
- ast = get_jit_ast(fn, is_method=False)
- mod = ScriptModule()
- _jit_script_compile(mod, ast, _rcb, get_default_args(fn))
+ if inspect.isclass(obj):
+ mod = ScriptClass(obj.__name__)
+ ast = get_jit_class_def(obj)
+ _jit_script_class_compile(mod, ast, _rcb)
+ else:
+ mod = ScriptModule()
+ ast = get_jit_def(obj)
+ _jit_script_compile(mod, ast, _rcb, get_default_args(obj))
# Forward docstrings
- mod.__doc__ = fn.__doc__
+ mod.__doc__ = obj.__doc__
return mod
+
ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
# function (the calling function). Adding 2 gets us to the proper surrounding scope.
if _rcb is None:
_rcb = _jit_internal.createResolutionCallback(frames_up=2)
- ast = get_jit_ast(fn, is_method=True)
+ ast = get_jit_def(fn, self_name="ScriptModule")
return ScriptMethodStub(_rcb, ast, fn)
"weak script module once it has been "
"created".format(attr))
+ class ScriptClass(ScriptModule):
+ def __init__(self, name):
+ super(ScriptClass, self).__init__()
+ self._name = name
+
else:
class ScriptModule(torch.nn.Module):
def __init__(self, optimize=True):
super(ScriptModule, self).__init__()
+ class ScriptClass(ScriptModule):
+ def __init__(self, name):
+ super(ScriptClass, self).__init__()
+
def _get_weak_stubs(cls):
"""
'_uses_true_division: expected function or method, got {}'.format(type(fn)))
-def get_jit_ast(fn, is_method):
+def get_jit_class_def(cls, self_name=None):
+ # Get defs for each method independently
+ methods = inspect.getmembers(
+ cls, predicate=lambda m: inspect.ismethod(m) or inspect.isfunction(m))
+ method_defs = [get_jit_def(method[1],
+ self_name=cls.__name__) for method in methods]
+
+ source = dedent(inspect.getsource(cls))
+ py_ast = ast.parse(source)
+ ctx = SourceContext(source, False)
+ return build_class_def(ctx, py_ast.body[0], method_defs)
+
+
+def get_jit_def(fn, self_name=None):
source = dedent(inspect.getsource(fn))
py_ast = ast.parse(source)
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
raise RuntimeError("expected a single top-level function")
type_line = torch.jit.annotations.get_type_line(source)
ctx = SourceContext(source, _uses_true_division(fn))
- return build_def(ctx, py_ast.body[0], type_line, is_method)
+ return build_def(ctx, py_ast.body[0], type_line, self_name)
# Thin wrapper around SourceRangeFactory to store extra metadata
return method(ctx, node)
-def build_def(ctx, py_def, type_line, is_method):
- returns = []
- ret_body = []
+def build_class_def(ctx, py_def, methods):
+ r = ctx.make_range(py_def.lineno, py_def.col_offset,
+ py_def.col_offset + len("class"))
+ return ClassDef(Ident(r, py_def.name), methods)
+
+
+def build_def(ctx, py_def, type_line, self_name=None):
body = py_def.body
r = ctx.make_range(py_def.lineno, py_def.col_offset,
py_def.col_offset + len("def"))
- param_list = build_param_list(ctx, py_def.args)
+ param_list = build_param_list(ctx, py_def.args, self_name)
return_type = None
if getattr(py_def, 'returns', None) is not None:
return_type = build_expr(ctx, py_def.returns)
decl = Decl(r, param_list, return_type)
+ is_method = self_name is not None
if type_line is not None:
type_comment_decl = torch._C.parse_type_comment(type_line)
decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
"or use keyword-only arguments with defaults")
-def build_param_list(ctx, py_args):
+def build_param_list(ctx, py_args, self_name):
if py_args.vararg is not None or py_args.kwarg is not None:
raise ValueError(_vararg_kwarg_err)
if not PY2 and py_args.kw_defaults:
raise ValueError(_vararg_kwarg_err)
- result = [build_param(ctx, arg, False) for arg in py_args.args]
+ result = [build_param(ctx, arg, self_name, False) for arg in py_args.args]
if not PY2:
- result += [build_params(ctx, arg, True) for arg in py_args.kwonlyargs]
+        result += [build_param(ctx, arg, self_name, True) for arg in py_args.kwonlyargs]
return result
-def build_param(ctx, py_arg, kwarg_only):
+def build_param(ctx, py_arg, self_name, kwarg_only):
# NB: In Python3 py_arg is a pair of (str arg, expr? annotation)
# In Python2 py_arg is a Name (Expr subclass)
name = py_arg.id if PY2 else py_arg.arg
r = ctx.make_range(py_arg.lineno, py_arg.col_offset, py_arg.col_offset + len(name))
if getattr(py_arg, 'annotation', None) is not None:
annotation_expr = build_expr(ctx, py_arg.annotation)
+ elif self_name is not None and name == 'self':
+ annotation_expr = Var(Ident(r, self_name))
else:
annotation_expr = Var(Ident(r, 'Tensor'))
return Param(annotation_expr, Ident(r, name), kwarg_only)