_(aten, warn) \
_(aten, floordiv) \
_(aten, __round_to_zero_floordiv)\
+ _(aten, _unwrap_optional) \
_(prim, fork) \
_(prim, RaiseException) \
_(prim, Function) \
// Note: NoneType is NOT a subtype of any optional.
// instead NoneType is convertable in schema matching to any Optional[T]
// it is handled this way because it is not possible to match None to Optional[T]
-// and extract T. Intead, we always create an instance of the prim::None instruction
-// with a particular type: v: Optional[int] = prim::None()
+// and extract T. Instead, we always create a None constant instruction
+// with a particular type: v: Optional[int] = None()
struct CAFFE2_API OptionalType: public SingleElementType<TypeKind::OptionalType, OptionalType> {
static OptionalTypePtr create(TypePtr element) {
return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
JIT_TEST(THNNConv)
JIT_TEST(ATenNativeBatchNorm)
+JIT_TEST(NoneSchemaMatch)
#define JIT_TEST_CUDA(name) \
TEST(JitTest, name##_CUDA) { \
namespace jit {
std::string runJITCPPTests() {
std::stringstream out;
+ testNoneSchemaMatch();
testAutogradProfiler();
testADFormulas();
testArgumentSpec();
ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
- ASSERT_EQ(
- op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+ ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
ASSERT_EQ(op->schema().arguments()[0].name(), "a");
ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
ASSERT_EQ(op->schema().arguments()[1].name(), "b");
- ASSERT_EQ(
- op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+ ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
ASSERT_EQ(op->schema().returns().size(), 1);
ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
}
void testEvalModeForLoadedModule() {
- if (isSandcastle()) return; // The module file to load is not generated in Sandcastle
+ if (isSandcastle())
+ return; // The module file to load is not generated in Sandcastle
std::string module_path = "dropout_model.pt";
- std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(module_path);
+ std::shared_ptr<torch::jit::script::Module> module =
+ torch::jit::load(module_path);
AT_ASSERT(module->get_module("dropout")->is_training());
module->eval();
AT_ASSERT(!module->get_module("dropout")->is_training());
AT_CHECK(count == 200);
}
+void testNoneSchemaMatch() {
+  // Register two throwaway operators: a producer whose schema returns int? and
+  // which always pushes None, and a consumer that reports whether its int?
+  // argument is None.
+  RegisterOperators reg({
+      Operator(
+          "test::test_none() -> int?",
+          [](const Node*) {
+            return [](Stack& stack) {
+              push(stack, IValue());
+              return 0;
+            };
+          }),
+      Operator(
+          "test::is_none(int? a) -> bool",
+          [](const Node*) {
+            return [](Stack& stack) {
+              const IValue arg = pop(stack);
+              push(stack, arg.isNone());
+              return 0;
+            };
+          }),
+  });
+
+  // Constant propagation evaluates test_none and materializes the resulting
+  // None as a constant. This checks that the constant's type is set
+  // appropriately, so that schema matching does not fail when is_none is
+  // subsequently run on it.
+
+  auto graph = std::make_shared<Graph>();
+  auto maybe_int = graph->insert(Symbol::fromQualString("test::test_none"), {});
+  auto is_none =
+      graph->insert(Symbol::fromQualString("test::is_none"), {maybe_int});
+  graph->registerOutput(is_none);
+  ConstantPropagation(graph);
+
+  // Constant propagation ran w/o failure iff exactly one node is left in the
+  // top-level block.
+  auto remaining = graph->block()->nodes();
+  AT_ASSERT(std::distance(remaining.begin(), remaining.end()) == 1);
+}
} // namespace
} // namespace jit
graph(%0 : Double(20, 16, 50, 40),
%1 : Double(13, 16, 3, 3)):
- %2 : Tensor? = prim::None(), scope: Conv2d
+ %2 : Tensor? = prim::Constant(), scope: Conv2d
%3 : int = prim::Constant[value=1](), scope: Conv2d
%4 : int = prim::Constant[value=1](), scope: Conv2d
%5 : int[] = prim::ListConstruct(%3, %4), scope: Conv2d
--- /dev/null
+graph(%input : Tensor,
+ %opt.1 : Tensor?):
+ %2 : None = prim::Constant()
+ %3 : int = prim::Constant[value=1]()
+ %4 : int = prim::Constant[value=2]()
+ %5 : int = prim::Constant[value=4]()
+ %x.1 : Tensor = aten::add(%input, %4, %3)
+ %7 : bool = aten::__isnot__(%opt.1, %2)
+ %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
+ block0():
+ %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
+ %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
+ %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
+ -> (%opt.3, %x.2)
+ block1():
+ -> (%opt.1, %x.1)
+ %13 : bool = aten::__is__(%opt.4, %2)
+ %x : Tensor = prim::If(%13)
+ block0():
+ %x.4 : Tensor = aten::add(%x.3, %5, %3)
+ -> (%x.4)
+ block1():
+ -> (%x.3)
+ return (%x)
graph(%a.1 : Tensor):
- %1 : None = prim::None()
+ %1 : None = prim::Constant()
%2 : int = prim::Constant[value=1]()
%3 : Device = prim::Constant[value="cpu"]()
%4 : int = prim::Constant[value=0]()
self.run_pass('constant_propagation', constant_prop.graph)
graph_str = str(constant_prop.graph)
- self.assertTrue(graph_str.count("prim::None") == 0)
+ self.assertTrue(graph_str.count("prim::Constant") == 1)
def test_constant_prop_if_inline(self):
@torch.jit.script
print(typed_nones())
graph_str = str(test.graph)
- self.assertTrue(graph_str.count("bool? = prim::None") == 1)
- self.assertTrue(graph_str.count("int? = prim::None") == 1)
- self.assertTrue(graph_str.count("None = prim::None") == 1)
+ self.assertTrue(graph_str.count("bool? = prim::Constant") == 1)
+ self.assertTrue(graph_str.count("int? = prim::Constant") == 1)
+ self.assertTrue(graph_str.count("None = prim::Constant") == 1)
def test_literal(self):
def func1(a, b):
// Tensor", "aten::min(Tensor self) -> Tensor"
if (n->kind() == prim::Constant || n->kind() == prim::Undefined ||
- n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
- n->kind() == prim::None)
+ n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
return true;
if (differentiable_ops.find(n))
return true;
if (n->matches(
"aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
// TODO(asuhan): support weight
- return n->namedInput(attr::weight)->node()->kind() == prim::None;
+ return n->namedInput(attr::weight)->node()->mustBeNone();
}
// linear blocks may appear as inputs to graph executors, but they are removed
if (node->matches(
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
- return {
- gradSumToSizeOf(grads.at(0), attr::self),
- gradSumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
- nullptr};
+ return {gradSumToSizeOf(grads.at(0), attr::self),
+ gradSumToSizeOf(
+ grads.at(0) * node->namedInput(attr::alpha), attr::other),
+ nullptr};
} else if (
node->matches(
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
// handle the case that min/max is None
Value* min = inputs.at(1);
- bool min_must_be_none = min->node()->kind() == prim::None;
+ bool min_must_be_none = min->mustBeNone();
Value* max = inputs.at(2);
- bool max_must_be_none = max->node()->kind() == prim::None;
- // XXX - this formula is wrong when min or max are not stricly prim::None
- // but may be None dynamically. In this case an internal compiler error
- // will get thrown when trying to generate expressions involving the
+ bool max_must_be_none = max->mustBeNone();
+ // XXX - this formula is wrong when min or max are not strictly a constant
+ // None but may be None dynamically. In this case an internal compiler
+ // error will get thrown when trying to generate expressions involving the
// values of min/max
if (!min_must_be_none && !max_must_be_none) {
return {grads.at(0) *
} else if (
node->matches(
"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
- return {
- gradSumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
- grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
- inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
- nullptr,
- nullptr};
+ return {gradSumToSizeOf(
+ grads.at(0) * node->namedInput(attr::beta), attr::self),
+ grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
+ inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
+ nullptr,
+ nullptr};
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
return {grads.at(0).mm(inputs.at(1).t()),
return {backward_value->node()->output(0), nullptr};
} else if (
- node->kind() == prim::Constant || node->kind() == prim::Undefined ||
- node->kind() == prim::None) {
+ node->kind() == prim::Constant || node->kind() == prim::Undefined) {
return {};
}
throw std::runtime_error(
Value* insertConstant(
Graph& g,
const IValue& val,
+ const c10::TypePtr& result_type,
c10::optional<SourceRange> loc,
c10::optional<ScopePtr> scope) {
Node* n = g.create(prim::Constant);
n->s_(attr::value, ss.str());
n->output()->setType(DeviceObjType::get());
} else if (val.isNone()) {
- n->destroy();
- n = g.create(prim::None);
n->output()->setType(NoneType::get());
} else {
throw constant_not_supported_error(
n->setSourceLocation(std::make_shared<SourceRange>(*loc));
if (scope)
n->setScope(*scope);
+ if (result_type) {
+ auto inferred_type = n->output()->type();
+ // Retain more type information in case of tensor constant
+ if (!(inferred_type->isSubtypeOf(TensorType::get()) && result_type->isSubtypeOf(inferred_type))) {
+ n->output()->setType(result_type);
+ }
+ }
return g.insertNode(n)->output();
}
push(stack, d);
return 0;
};
+ } else if (node->mustBeNone()) {
+ return [](Stack& stack) {
+ push(stack, IValue());
+ return 0;
+ };
} else {
std::stringstream ss;
ss << "constant literal not supported for: " << type->str();
#pragma once
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/scope.h>
#include <torch/csrc/jit/source_range.h>
TORCH_API Value* insertConstant(
Graph& g,
const IValue& val,
+ const c10::TypePtr& result_type = nullptr,
c10::optional<SourceRange> loc = c10::nullopt,
c10::optional<ScopePtr> scope = c10::nullopt);
bool is_aten_enabled = operator_export_type ==
onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
- if (!node->kind().is_onnx() && !is_aten_enabled &&
- node->kind() != prim::None) {
+ if (!node->kind().is_onnx() && !is_aten_enabled && !node->mustBeNone()) {
FAIL_EXPORT(
"Couldn't export operator " + node->kind().toDisplayString() +
"\n\nDefined at:\n" + getNodeStackTraceString(node));
for (auto node : block->nodes()) {
bool is_raw_export =
operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
- if (node->kind() == prim::None && !is_raw_export) {
+ if (node->mustBeNone() && !is_raw_export) {
// None nodes are used to implement optional inputs. One
// way to "not provide" an optional input is to create an
// Undefined node, and pass its output as that input.
p_n->set_doc_string(ss.str());
}
for (auto input : node->inputs()) {
- if (input->node()->kind() == prim::None && !is_raw_export) {
+ if (input->node()->mustBeNone() && !is_raw_export) {
p_n->add_input("");
} else {
p_n->add_input(input->uniqueName());
<< " is out of bounds (constant table has "
<< constants_.size() << " entries).";
}
- Value* value = m.graph()->insertConstant(constants_[offset], loc);
+ Value* value = m.graph()->insertConstant(constants_[offset], nullptr, loc);
return std::make_shared<script::SimpleValue>(value);
}
}
bool Value::mustBeNone() const {
- return node_->kind() == prim::None;
+ return node_->mustBeNone();
}
std::string Value::uniqueNameBase() const {
return true;
}
+// True when this node is a prim::Constant that carries no attributes and whose
+// single output is typed either as an Optional or as NoneType.
+bool Node::mustBeNone() const {
+  // Reject anything that is not an attribute-free constant before looking at
+  // the output, preserving the original short-circuit order.
+  if (kind_ != prim::Constant || this->hasAttributes()) {
+    return false;
+  }
+  const auto& out_type = output()->type();
+  return out_type->cast<OptionalType>() || out_type == NoneType::get();
+}
+
void Node::dump() const {
std::cout << *this << "\n";
}
}
Node* Graph::createNone(TypePtr typ) {
- Node* n = create(prim::None);
+ Node* n = create(prim::Constant);
n->output()->setType(OptionalType::create(std::move(typ)));
return n;
}
Value* Graph::insertConstant(
IValue val,
+ const TypePtr& result_type,
c10::optional<SourceRange> loc,
c10::optional<ScopePtr> scope) {
return jit::insertConstant(
- *this, std::move(val), std::move(loc), std::move(scope));
+ *this, std::move(val), result_type, std::move(loc), std::move(scope));
}
std::string Graph::toString() const {
bool is_constant(Symbol name) const {
return static_cast<bool>(get(name));
}
+ TORCH_API bool mustBeNone() const;
TORCH_API bool isNondeterministic() const;
TORCH_API bool hasSideEffects() const;
const std::function<Value*(Value*)>& value_map,
bool copy_blocks = true);
+
+ // Insert constant IValue into the graph. If the type cannot be fully deduced
+ // from the ivalue, as with a None that is set to t?, use result_type
TORCH_API Value* insertConstant(
IValue val,
+ const TypePtr& result_type = nullptr,
c10::optional<SourceRange> loc = c10::nullopt,
c10::optional<ScopePtr> scope = c10::nullopt);
case prim::FusedConcat:
case prim::MMTreeReduce:
case prim::MMBatchSide:
- case prim::None:
case prim::BroadcastSizes:
case prim::ChunkSizes:
case prim::Function:
continue;
}
- if (node->kind() != prim::Constant && node->kind() != prim::None) {
+ if (node->kind() != prim::Constant) {
continue;
}
prim::Constant,
prim::Undefined,
prim::unchecked_unwrap_optional, // TODO remove
- prim::None, // it is already a constant and propagating it will lose
- // important type information about which Optional type it is
// TODO (zach): we should consider skipping tensor factories in the cases
// where the constant tensor would be large but cheap to create.
};
auto op = getOperation(n);
Stack stack;
for (auto input : n->inputs()) {
- if (input->node()->kind() == prim::None) {
- stack.emplace_back(IValue());
- } else {
- stack.push_back(*(toIValue(input)));
- }
+ stack.push_back(*(toIValue(input)));
}
op(stack);
auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
for (size_t i = 0; i < outputs.size(); ++i) {
try {
auto new_output = graph->insertConstant(outputs[i]);
+ if (outputs[i].isNone()) {
+ new_output->setType(n->outputs()[i]->type());
+ }
n->outputs()[i]->replaceAllUsesWith(new_output);
} catch (constant_not_supported_error& err) {
// we cannot actually represent the IValue as a constant node,
void ConstantPropagation(Node* n, const AliasDb& aliasDb) {
bool constant_inputs =
std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
- return v->node()->kind() == prim::Constant ||
- v->node()->kind() == prim::None;
+ return v->node()->kind() == prim::Constant;
});
bool supported_node = !n->kind().is_onnx() &&
skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
auto s = *constant_as<at::Scalar>(it->output());
WithInsertPoint guard(*it);
Value* r = block->owningGraph()->insertConstant(
- scalar_to_tensor(s), c10::nullopt, it->scope());
+ scalar_to_tensor(s), nullptr, c10::nullopt, it->scope());
it->output()->replaceAllUsesWith(r);
}
} break;
if (tensor->type()->isSubtypeOf(TensorType::get())) {
return true;
}
- if (tensor->node()->kind() == prim::None ||
+ if (tensor->node()->mustBeNone() ||
tensor->node()->kind() == prim::Undefined) {
return false;
}
switch (n->kind()) {
case prim::Constant:
case prim::Undefined:
- case prim::None:
return true;
default:
return false;
if (node->kind() == prim::PythonOp) {
auto value = static_cast<const PythonOp*>(node);
if (enforce_importable_ && value->ignore_on_export) {
- // Op has been marked as ignored, so insert an error in its place
- indent();
- out << "ops.prim.IgnoredPythonOp()\n";
- return;
+ // Op has been marked as ignored, so insert an error in its place
+ indent();
+ out << "ops.prim.IgnoredPythonOp()\n";
+ return;
}
}
switch (node->kind()) {
}
}
+ // Writes the source form of a None-producing node to `stmt`: either a bare
+ // `None`, or `annotate(<type>, None)` when the type would otherwise be lost.
+ void printNone(std::ostream& stmt, const Node* node) {
+   const auto& none_type = node->output()->type();
+
+   // A use lets the type be recovered on parsing only if its user has a schema
+   // whose argument at that position exists and has no free variables.
+   const auto recoverable_from_schema = [](const Use& use) {
+     const auto schema = use.user->maybeSchema();
+     if (!schema || use.offset >= schema->arguments().size()) {
+       return false;
+     }
+     return !schema->arguments().at(use.offset).type()->hasFreeVariables();
+   };
+
+   // A value typed as plain None never needs an annotation.
+   bool plain_none = none_type->isSubtypeOf(NoneType::get());
+   if (!plain_none) {
+     // XXX - when None has an Optional[T] type, we must ensure that type
+     // can be recovered on parsing. It cannot be recovered if it will be
+     // matched to schema with free variables. If every use is matched against
+     // a schema without free variables, the type can be recovered without an
+     // annotation. Otherwise, we annotate None with the right optional type.
+     const auto& uses = node->output()->uses();
+     plain_none =
+         std::all_of(uses.begin(), uses.end(), recoverable_from_schema);
+   }
+
+   if (plain_none) {
+     stmt << "None";
+   } else {
+     stmt << "annotate(" << none_type->python_str() << ", None)";
+   }
+ }
+
// Prints the RHS value of a Node, e.g. `aten.add(x, y)`
void printRHS(std::ostream& stmt, Node* node) {
switch (node->kind()) {
value->writeScalars(stmt);
printValueList(stmt, node->inputs(), "(", ")");
} break;
- case prim::Constant: {
- IValue v = toIValue(node->output()).value();
- printConstant(stmt, v);
- } break;
- case prim::Undefined:
- case prim::None: {
- if (node->output()->type()->isSubtypeOf(NoneType::get())) {
- stmt << "None";
- break;
- }
- // XXX - when None has an Optional[T] type, we must ensure that type
- // can be recovered on parsing. It cannot be recovered if it will be
- // matched to schema with free variables. If it is used only in places
- // where there is schema and the scheme has no free variables, then we
- // can recover it without annotation. Otherwise, we annotate None with
- // the right optional type
- const auto& uses = node->output()->uses();
- bool all_usable_schema =
- std::all_of(uses.begin(), uses.end(), [](const Use& u) {
- if (auto schema = u.user->maybeSchema()) {
- if (u.offset >= schema->arguments().size()) {
- return false;
- }
- return !schema->arguments()
- .at(u.offset)
- .type()
- ->hasFreeVariables();
- }
- return false;
- });
-
- if (all_usable_schema) {
- stmt << "None";
+ case prim::Constant:
+ case prim::Undefined: {
+ if (node->kind() == prim::Constant && !node->mustBeNone()) {
+ IValue v = toIValue(node->output()).value();
+ printConstant(stmt, v);
} else {
- stmt << "annotate(" << node->output()->type()->python_str()
- << ", None)";
+ printNone(stmt, node);
}
} break;
case prim::ImplicitTensorToNum: {
if (node->inputs().size() == 0 &&
!dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
!dict_type->getValueType()->isSubtypeOf(TensorType::get())) {
- stmt << "annotate(" << node->output()->type()->python_str() << ", {})";
+ stmt << "annotate(" << node->output()->type()->python_str()
+ << ", {})";
} else {
printDict(stmt, node->inputs());
}
prim::ListConstruct,
prim::DictConstruct,
prim::ListUnpack,
- prim::None,
prim::Print,
prim::PythonOp,
prim::TupleConstruct,
TypePtr typ = node->input()->type();
if (typ->isSubtypeOf(IntType::get()) ||
typ->isSubtypeOf(BoolType::get())) {
- node->output()->setType(DimensionedTensorType::create(at::kLong, at::kCPU, 0));
+ node->output()->setType(
+ DimensionedTensorType::create(at::kLong, at::kCPU, 0));
} else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
- node->output()->setType(DimensionedTensorType::create(at::kDouble, at::kCPU, 0));
+ node->output()->setType(
+ DimensionedTensorType::create(at::kDouble, at::kCPU, 0));
}
return;
}
setUnshapedType(node);
return;
}
+ case aten::_unwrap_optional: {
+ auto input_ivalue = toIValue(node->input());
+ if (input_ivalue && input_ivalue->isNone()) {
+ return;
+ }
+ }
default:
break; // fall-through
}
// primitive/tensor outputs.
bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
- static const auto broadcast = [](std::vector<DimensionedTensorTypePtr>& tensor_types,
- size_t arg_for_type) -> DimensionedTensorTypePtr {
+ static const auto broadcast =
+ [](std::vector<DimensionedTensorTypePtr>& tensor_types,
+ size_t arg_for_type) -> DimensionedTensorTypePtr {
if (tensor_types.size() == 1) {
return tensor_types[0];
}
"aten::zeros_like(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
- auto input_type = node->input(0)->type()->cast<DimensionedTensorType>();
+ auto input_type =
+ node->input(0)->type()->cast<DimensionedTensorType>();
return input_type ? type_vec_t{input_type} : type_vec_t{};
}};
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
},
[this](Node* node) -> type_vec_t {
- if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+ if (auto maybe_tensor_types =
+ gatherTensorTypes<DimensionedTensorType>(node)) {
return {broadcast(*maybe_tensor_types, 0)};
}
return {};
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
- if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+ if (auto maybe_tensor_types =
+ gatherTensorTypes<DimensionedTensorType>(node)) {
return {broadcast(*maybe_tensor_types, 1)};
}
return {};
}};
- static const auto any_tensor_type = [](Node* node) -> DimensionedTensorTypePtr {
+ static const auto any_tensor_type =
+ [](Node* node) -> DimensionedTensorTypePtr {
for (Value* input : node->inputs()) {
if (auto type = input->type()->cast<DimensionedTensorType>()) {
return type;
"aten::ne(Tensor self, Scalar other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
- if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+ if (auto maybe_tensor_types =
+ gatherTensorTypes<DimensionedTensorType>(node)) {
return {broadcast(*maybe_tensor_types, 0)->toScalarType(at::kByte)};
}
return {};
"aten::prelu(Tensor self, Tensor weight) -> Tensor",
},
[](Node* node) -> type_vec_t {
- if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
+ if (auto type =
+ node->input(0)->type()->cast<DimensionedTensorType>()) {
return {type};
}
return {};
"aten::any(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
- if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
+ if (auto type =
+ node->input(0)->type()->cast<DimensionedTensorType>()) {
return {type->withDim(0)};
}
return {};
"aten::prod(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
- if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
+ if (auto type =
+ node->input(0)->type()->cast<DimensionedTensorType>()) {
return {at::isFloatingType(type->scalarType())
? type->withDim(0)
: type->withDim(0)->toScalarType(at::kLong)};
auto maybe_scalar_type = node->get<at::ScalarType>(attr::dtype);
if (!maybe_scalar_type)
return {};
- return {DimensionedTensorType::create(*maybe_scalar_type, *maybe_device, dim)};
+ return {DimensionedTensorType::create(
+ *maybe_scalar_type, *maybe_device, dim)};
};
// Requirements:
"aten::zeros_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
},
[](Node* node) -> type_vec_t {
- if (auto type =
- node->namedInput(attr::self)->type()->cast<DimensionedTensorType>()) {
+ if (auto type = node->namedInput(attr::self)
+ ->type()
+ ->cast<DimensionedTensorType>()) {
return factory_with_ndim(node, type->dim());
}
return {};
"aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
},
[](Node* node) -> type_vec_t {
- if (auto type =
- node->namedInput(attr::self)->type()->cast<DimensionedTensorType>()) {
+ if (auto type = node->namedInput(attr::self)
+ ->type()
+ ->cast<DimensionedTensorType>()) {
return {type->toScalarType(get_cast_scalar_type(node))};
}
return {};
static const auto reshape_prop =
[](Node* node,
Symbol shape_input,
- const std::vector<DimensionedTensorTypePtr>& tensor_types) -> DimensionedTensorTypePtr {
+ const std::vector<DimensionedTensorTypePtr>& tensor_types)
+ -> DimensionedTensorTypePtr {
if (auto list_size = determineListSize(node->namedInput(shape_input))) {
return tensor_types.at(0)->withDim(*list_size);
}
}
return nullptr;
};
- if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+ if (auto maybe_tensor_types =
+ gatherTensorTypes<DimensionedTensorType>(node)) {
tensor_types = std::move(*maybe_tensor_types);
} else {
return false;
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/to_batch.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/script/compiler.h>
namespace torch {
} else if (n->kind().is_prim()) {
switch (n->kind()) {
case prim::Constant:
- case prim::None:
visitConstant(n, block, res_block);
break;
case prim::NumToTensor:
// enable more optimizations
if (auto value = toIValue(input)) {
auto nv = subgraph->insertConstant(*value);
+ nv->setType(input->type()); // Need to retain type information on Nones
inputsMap[input] = nv;
} else {
// The common case: this is a regular input, so just register it with
return ret;
}
-std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) {
+std::vector<Node*> findAllNodes(
+ Block* block,
+ Symbol kind,
+ bool recurse = true) {
std::vector<Block*> blocks = {block};
return findAllNodes(blocks, kind, recurse);
}
"findNode",
[](Graph& g, const std::string& kind, bool recurse) {
return findNode(g.block(), Symbol::fromQualString(kind), recurse);
- }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+ },
+ "Find Node",
+ py::arg("kind"),
+ py::arg("recurse") = true)
.def(
"findAllNodes",
[](Graph& g, const std::string& kind, bool recurse) {
return findAllNodes(
g.block(), Symbol::fromQualString(kind), recurse);
- }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
+ },
+ "Find all nodes",
+ py::arg("kind"),
+ py::arg("recurse") = true)
.def("addInput", [](Graph& g) { return g.addInput(); })
.def("copy", [](Graph& g) { return g.copy(); })
.GS(eraseInput)
#undef VS
py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
- .def("nodes", [](Block& b) {
- return py::make_iterator(b.nodes().begin(), b.nodes().end());
- })
+ .def(
+ "nodes",
+ [](Block& b) {
+ return py::make_iterator(b.nodes().begin(), b.nodes().end());
+ })
.def(
"findNode",
[](Block& b, const std::string& kind, bool recurse) {
return findNode(&b, Symbol::fromQualString(kind), recurse);
- }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+ },
+ "Find Node",
+ py::arg("kind"),
+ py::arg("recurse") = true)
.def(
"findAllNodes",
[](Block& b, const std::string& kind, bool recurse) {
return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
- }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true);
-
+ },
+ "Find all nodes",
+ py::arg("kind"),
+ py::arg("recurse") = true);
#define NS(name) def(#name, &Node ::name)
py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
"findNode",
[](Node& n, const std::string& kind, bool recurse) {
return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
- }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+ },
+ "Find Node",
+ py::arg("kind"),
+ py::arg("recurse") = true)
.def(
"findAllNodes",
[](Node& n, const std::string& kind, bool recurse) {
return findAllNodes(
n.blocks(), Symbol::fromQualString(kind), recurse);
- }, "Find all nodes", py::arg("kind"), py::arg("recurse") = true)
+ },
+ "Find all nodes",
+ py::arg("kind"),
+ py::arg("recurse") = true)
.def("input", [](Node& n) { return n.input(); })
.def("output", [](Node& n) { return n.output(); })
.NS(addInput)
return py::make_iterator(n.blocks().begin(), n.blocks().end());
})
.NS(addBlock)
+ .NS(mustBeNone)
#define AS(name) def(#name, &Node::name)
// methods from Attributes
})
.def(
"t",
- [](Node& n, const char* name) {
- return n.t(Symbol::attr(name));
- })
+ [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
// Tensors (ts_) -- manually written to unwrap variables into tensors.
.def(
"ts_",
"z_",
[](Node& n, const char* name, at::Tensor v) {
return n.t_(
- Symbol::attr(name), autograd::Variable(v.view({})).set_requires_grad(false));
+ Symbol::attr(name),
+ autograd::Variable(v.view({})).set_requires_grad(false));
})
.def(
"z",
return s.str();
})
.def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
- .def("dim", [](const Type& t) { return t.expect<DimensionedTensorType>()->dim(); })
+ .def(
+ "dim",
+ [](const Type& t) {
+ return t.expect<DimensionedTensorType>()->dim();
+ })
.def(
"sizes",
[](Type& t) { return t.expect<CompleteTensorType>()->sizes(); })
.def(py::init([](TypePtr key, TypePtr value) {
return DictType::create(key, value);
}));
+ py::class_<OptionalType, Type, std::shared_ptr<OptionalType>>(
+ m, "OptionalType")
+ .def(py::init([](TypePtr a) { return OptionalType::create(a); }))
+ .def_static("ofTensor", &OptionalType::ofTensor)
+ .def("getElementType", &OptionalType::getElementType);
py::class_<Use>(m, "Use")
.def_readonly("user", &Use::user)
};
}),
Operator(
- prim::None,
- [](const Node* node) {
- return [](Stack& stack) {
- stack.emplace_back(IValue());
- return 0;
- };
- }),
- Operator(
prim::Print,
[](const Node* node) {
size_t num_inputs = node->inputs().size();
}
WithInsertPoint guard(graph.block()->nodes().front());
- auto new_constant = graph.insertConstant(val, r);
+ auto new_constant = graph.insertConstant(val, nullptr, r);
map[val] = new_constant;
return new_constant;
if (cond) {
cond_val = emitCond(cond.value());
} else {
- cond_val = graph->insertConstant(true, range);
+ cond_val = graph->insertConstant(true, nullptr, range);
}
}
n->addInput(max_trip_count_val);
Value* body_cond_value = emitCond(cond.value());
body_block->registerOutput(body_cond_value);
} else {
- Value* cond_value_dummy = graph->insertConstant(true, range);
+ Value* cond_value_dummy = graph->insertConstant(true, nullptr, range);
body_block->registerOutput(cond_value_dummy);
}
// print(a)
void emitRaise(const SourceRange& loc) {
const std::string exception = "Exception";
- auto string_input = insertConstant(*graph, exception, loc);
+ auto string_input = insertConstant(*graph, exception, nullptr, loc);
graph->insert(prim::RaiseException, {string_input}, {}, loc);
}
bool is_instance_val =
isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
return std::make_shared<SimpleValue>(
- graph->insertConstant(is_instance_val, loc));
+ graph->insertConstant(is_instance_val, nullptr, loc));
} else {
auto inputs = getNamedValues(apply.inputs(), true);
auto attributes = emitAttributes(apply.attributes());
stack.push_back(*maybe_constant_input);
op(stack);
AT_ASSERT(stack.size() == 1);
- return graph->insertConstant(stack[0], tree->range());
+ return graph->insertConstant(stack[0], nullptr, tree->range());
}
// This function extract a new graph from its original subgraph
return emitConst(Const(tree));
} break;
case TK_TRUE: {
- return graph->insertConstant(true, tree->range());
+ return graph->insertConstant(true, nullptr, tree->range());
} break;
case TK_FALSE: {
- return graph->insertConstant(false, tree->range());
+ return graph->insertConstant(false, nullptr, tree->range());
} break;
case TK_NONE: {
- return graph->insertConstant(IValue(), tree->range());
+ return graph->insertConstant(IValue(), nullptr, tree->range());
} break;
case TK_SUBSCRIPT: {
return emitSubscript(Subscript(tree));
}
Value* emitStringLiteral(const StringLiteral& c) {
- return insertConstant(*graph, c.text(), c.range());
+ return insertConstant(*graph, c.text(), nullptr, c.range());
}
// Desugars select indexing: tensor[i] -> tensor.select(dim, i)
*graph,
aten::select,
c10::nullopt,
- {input, graph->insertConstant(dim, loc), index},
+ {input, graph->insertConstant(dim, nullptr, loc), index},
{},
true);
}
// aten::slice, we should separate it from this function.
if (dim) {
AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
- args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), loc));
+ args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
} else {
AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
}
return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
}
}
- NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, loc));
+ NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
return emitBuiltinCall(
loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
}
auto& g = *m.graph();
if (is_constant) {
if (py::isinstance<py::bool_>(obj)) {
- return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
+ return toSimple(g.insertConstant(py::cast<bool>(obj), nullptr, loc));
} else if (py::isinstance<py::int_>(obj)) {
- return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
+ return toSimple(g.insertConstant(py::cast<int64_t>(obj), nullptr, loc));
} else if (py::isinstance<py::float_>(obj)) {
- return toSimple(g.insertConstant(py::cast<double>(obj), loc));
+ return toSimple(g.insertConstant(py::cast<double>(obj), nullptr, loc));
} else if (py::isinstance<py::str>(obj)) {
- return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
+ return toSimple(g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
} else if (obj.is(py::none())) {
- return toSimple(g.insertConstant(IValue(), loc));
+ return toSimple(g.insertConstant(IValue(), nullptr, loc));
} else if (THPDevice_Check(obj.ptr())) {
auto device = reinterpret_cast<THPDevice*>(obj.ptr());
return toSimple(g.insertConstant(device->device));
} else if (THPLayout_Check(obj.ptr())) {
auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
const auto v = static_cast<int64_t>(layout->layout);
- return toSimple(g.insertConstant(v, loc));
+ return toSimple(g.insertConstant(v, nullptr, loc));
} else if (THPDtype_Check(obj.ptr())) {
auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
const auto v = static_cast<int64_t>(dtype->scalar_type);
- return toSimple(g.insertConstant(v, loc));
+ return toSimple(g.insertConstant(v, nullptr, loc));
} else if (py::isinstance<py::tuple>(obj)) {
return std::make_shared<ConstantPythonTupleValue>(obj);
}
else:
grad_self = torch._adaptive_avg_pool2d_backward(grad_output, self)
return grad_self, None
-
return torch.adaptive_avg_pool2d(self, output_size), backward
def batch_norm(input : Tensor,
return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
+ def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
+ result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
+ def backward(grad):
+ return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
+ return result, backward
)"};
std::unordered_map<std::string, GradientPair> schema_to_graphs;
if n_mod.kind() != n_check.kind():
break # Graphs have already diverged
- if n_mod.kind() == n_check.kind() and n_mod.kind() == 'prim::Constant':
+ if n_mod.kind() == 'prim::Constant' and not (n_mod.mustBeNone() or n_check.mustBeNone()):
if n_mod.kindOf('value') != 't' or n_check.kindOf('value') != 't':
continue
import numbers
import torch
-from torch._C import DynamicType, ListType
+from torch._C import DynamicType, ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils.rnn import PackedSequence
import warnings
# used to represent "missing" optional inputs
def unused(g):
# Returns a graph value that stands in for an omitted optional input.
# The removed form emitted a dedicated "prim::None" op. The added form
# emits a "prim::Constant" op with no attributes and then sets its type to
# OptionalType.ofTensor().
# NOTE(review): presumably callers detect this placeholder via
# node().mustBeNone() -- confirm that a value-less prim::Constant satisfies it.
- return g.op("prim::None")
+ n = g.op("prim::Constant")
+ n.setType(OptionalType.ofTensor())
+ return n
def _shape_as_tensor(g, input):
args = [input, weight]
# ONNX only supports 1D bias
- if bias.node().kind() != "prim::None" and bias.type().dim() == 1:
+ if not bias.node().mustBeNone() and bias.type().dim() == 1:
args.append(bias)
kwargs = {"kernel_shape_i": weight_size[2:],
n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
- if bias.node().kind() != "prim::None" and bias.type().dim() != 1:
+ if not bias.node().mustBeNone() and bias.type().dim() != 1:
return g.op("Add", n, bias)
else:
return n
# batchnorm1d accepts 2d and 3d array, but ONNX only accepts 3d
input = g.op("Unsqueeze", input, axes_i=[2])
- if weight is None or weight.node().kind() == "prim::None":
+ if weight is None or weight.node().mustBeNone():
assert len(input_sizes) > 1
weight_value = torch.tensor([1.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
weight = g.op("Constant", value_t=weight_value)
- if bias is None or bias.node().kind() == "prim::None":
+ if bias is None or bias.node().mustBeNone():
assert len(input_sizes) > 1
bias_value = torch.tensor([0.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
@parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
input_sizes = input.type().sizes()
- if weight is None or weight.node().kind() == "prim::None":
+ if weight is None or weight.node().mustBeNone():
assert len(input_sizes) > 1
weight_value = torch.tensor([1.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
weight = g.op("Constant", value_t=weight_value)
- if bias is None or bias.node().kind() == "prim::None":
+ if bias is None or bias.node().mustBeNone():
assert len(input_sizes) > 1
bias_value = torch.tensor([0.] * input_sizes[1]).type(
'torch.' + input.type().scalarType() + 'Tensor')
def clamp(g, self, min, max):
- # min or max may be prim::None that we need to dispatch to
+ # min or max may be None that we need to dispatch to
# Clip separately, as ONNX does not have None syntax
- if min.node().kind() == "prim::None":
+ if min.node().mustBeNone():
return clamp_max(g, self, max)
- elif max.node().kind() == "prim::None":
+ elif max.node().mustBeNone():
return clamp_min(g, self, min)
else:
min = _parse_arg(min, 'f')
return fn(g, *inputs, **attrs)
elif ns == "prim":
- if op_name == "Constant":
+ if op_name == "Constant" and not n.mustBeNone():
if n.kindOf("value") == "t":
return g.op("Constant", value_t=n["value"])
elif n.kindOf("value") == "is":
else:
raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
n.kindOf("value")))
- elif op_name == "None" or op_name == "ListConstruct" or op_name == "ListUnpack":
- # None is not an ONNX operator; keep it as prim::None
+ elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
+ # None is not an ONNX operator; keep it as None
# let the exporter handle finally eliminating these
# For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass