move prim::None to prim::Constant (again) (#17186)
authoreellison <elias_ellison@brown.edu>
Tue, 19 Feb 2019 19:34:46 +0000 (11:34 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 19 Feb 2019 19:45:50 +0000 (11:45 -0800)
Summary:
Trying to land again, make prim::None into a case of prim::Constant. Reverted the previous landing because it broke an important onnx export test.

https://github.com/pytorch/pytorch/pull/16160
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17186

Differential Revision: D14115304

Pulled By: eellison

fbshipit-source-id: 161435fc30460b4e116cdd62c7b2e5b94581dcb7

34 files changed:
aten/src/ATen/core/interned_strings.h
aten/src/ATen/core/jit_type.h
test/cpp/jit/gtest.cpp
test/cpp/jit/no-gtest.cpp
test/cpp/jit/test_misc.h
test/expect/TestJit.test_conv.expect
test/expect/TestScript.test_if_is_none_dispatch.expect [new file with mode: 0644]
test/expect/TestScript.test_mutable_dce_graph_input.expect
test/test_jit.py
torch/csrc/jit/autodiff.cpp
torch/csrc/jit/constants.cpp
torch/csrc/jit/constants.h
torch/csrc/jit/export.cpp
torch/csrc/jit/import_method.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/constant_pooling.cpp
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/erase_number_types.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/passes/to_batch.cpp
torch/csrc/jit/passes/utils/subgraph_utils.cpp
torch/csrc/jit/python_ir.cpp
torch/csrc/jit/register_prim_ops.cpp
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/file_check.cpp [new file with mode: 0644]
torch/csrc/jit/script/init.cpp
torch/csrc/jit/symbolic_script.cpp
torch/jit/__init__.py
torch/onnx/symbolic.py
torch/onnx/utils.py

index 50a2084..527a604 100644 (file)
@@ -78,6 +78,7 @@ namespace c10 {
   _(aten, warn)                    \
   _(aten, floordiv)                \
   _(aten, __round_to_zero_floordiv)\
+  _(aten, _unwrap_optional)        \
   _(prim, fork)                    \
   _(prim, RaiseException)          \
   _(prim, Function)                \
index f23bae4..f115b35 100644 (file)
@@ -211,8 +211,8 @@ using OptionalTypePtr = std::shared_ptr<OptionalType>;
 // Note: NoneType is NOT a subtype of any optional.
 // instead NoneType is convertable in schema matching to any Optional[T]
 // it is handled this way because it is not possible to match None to Optional[T]
-// and extract T. Intead, we always create an instance of the prim::None instruction
-// with a particular type: v: Optional[int] = prim::None()
+// and extract T. Intead, we always create a None constant instruction
+// with a particular type: v: Optional[int] = None()
 struct CAFFE2_API OptionalType: public SingleElementType<TypeKind::OptionalType, OptionalType> {
   static OptionalTypePtr create(TypePtr element) {
     return OptionalTypePtr(new OptionalType(std::move(element))); // NOLINT(modernize-make-shared)
index a3aaa4a..65bc780 100644 (file)
@@ -41,6 +41,7 @@ JIT_TEST(NetDefConverter)
 
 JIT_TEST(THNNConv)
 JIT_TEST(ATenNativeBatchNorm)
+JIT_TEST(NoneSchemaMatch)
 
 #define JIT_TEST_CUDA(name)    \
   TEST(JitTest, name##_CUDA) { \
index 654dd30..8072ef8 100644 (file)
@@ -10,6 +10,7 @@ namespace torch {
 namespace jit {
 std::string runJITCPPTests() {
   std::stringstream out;
+  testNoneSchemaMatch();
   testAutogradProfiler();
   testADFormulas();
   testArgumentSpec();
index 6f50469..5dd7a65 100644 (file)
@@ -1266,8 +1266,7 @@ void testCustomOperators() {
     ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
     ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
     ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
-    ASSERT_EQ(
-        op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
 
     ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
 
@@ -1295,8 +1294,7 @@ void testCustomOperators() {
     ASSERT_EQ(op->schema().arguments()[0].name(), "a");
     ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
     ASSERT_EQ(op->schema().arguments()[1].name(), "b");
-    ASSERT_EQ(
-        op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
+    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);
 
     ASSERT_EQ(op->schema().returns().size(), 1);
     ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);
@@ -1454,9 +1452,11 @@ void testCustomOperators() {
 }
 
 void testEvalModeForLoadedModule() {
-  if (isSandcastle()) return;  // The module file to load is not generated in Sandcastle
+  if (isSandcastle())
+    return; // The module file to load is not generated in Sandcastle
   std::string module_path = "dropout_model.pt";
-  std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(module_path);
+  std::shared_ptr<torch::jit::script::Module> module =
+      torch::jit::load(module_path);
   AT_ASSERT(module->get_module("dropout")->is_training());
   module->eval();
   AT_ASSERT(!module->get_module("dropout")->is_training());
@@ -1816,6 +1816,46 @@ void testAutogradProfiler() {
   AT_CHECK(count == 200);
 }
 
+void testNoneSchemaMatch() {
+  RegisterOperators reg({
+      Operator(
+          "test::test_none() -> int?",
+          [](const Node* node) {
+            return [](Stack& stack) {
+              push(stack, IValue());
+              return 0;
+            };
+          }),
+      Operator(
+          "test::is_none(int? a) -> bool",
+          [](const Node* node) {
+            return [](Stack& stack) {
+              IValue a = pop(stack);
+              if (a.isNone()) {
+                push(stack, true);
+              } else {
+                push(stack, false);
+              }
+              return 0;
+            };
+          }),
+  });
+
+  // Constant propagation will run test_none and produce a None,
+  // testing that its type is set appropriately and schema matching  doesn't
+  // fail when running is_none
+
+  auto r = std::make_shared<Graph>();
+  auto& g = *r;
+  auto opt_int = g.insert(Symbol::fromQualString("test::test_none"), {});
+  auto out_bool = g.insert(Symbol::fromQualString("test::is_none"), {opt_int});
+  g.registerOutput(out_bool);
+  ConstantPropagation(r);
+
+  auto nodes = r->block()->nodes();
+  // checking that constant propagation ran wo/failure
+  AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
+}
 
 } // namespace
 } // namespace jit
index 4293354..3bf9312 100644 (file)
@@ -1,6 +1,6 @@
 graph(%0 : Double(20, 16, 50, 40),
       %1 : Double(13, 16, 3, 3)):
-  %2 : Tensor? = prim::None(), scope: Conv2d
+  %2 : Tensor? = prim::Constant(), scope: Conv2d
   %3 : int = prim::Constant[value=1](), scope: Conv2d
   %4 : int = prim::Constant[value=1](), scope: Conv2d
   %5 : int[] = prim::ListConstruct(%3, %4), scope: Conv2d
diff --git a/test/expect/TestScript.test_if_is_none_dispatch.expect b/test/expect/TestScript.test_if_is_none_dispatch.expect
new file mode 100644 (file)
index 0000000..1ba11bc
--- /dev/null
@@ -0,0 +1,24 @@
+graph(%input : Tensor,
+      %opt.1 : Tensor?):
+  %2 : None = prim::Constant()
+  %3 : int = prim::Constant[value=1]()
+  %4 : int = prim::Constant[value=2]()
+  %5 : int = prim::Constant[value=4]()
+  %x.1 : Tensor = aten::add(%input, %4, %3)
+  %7 : bool = aten::__isnot__(%opt.1, %2)
+  %opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
+    block0():
+      %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
+      %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
+      %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
+      -> (%opt.3, %x.2)
+    block1():
+      -> (%opt.1, %x.1)
+  %13 : bool = aten::__is__(%opt.4, %2)
+  %x : Tensor = prim::If(%13)
+    block0():
+      %x.4 : Tensor = aten::add(%x.3, %5, %3)
+      -> (%x.4)
+    block1():
+      -> (%x.3)
+  return (%x)
index 488ba65..0bdea96 100644 (file)
@@ -1,5 +1,5 @@
 graph(%a.1 : Tensor):
-  %1 : None = prim::None()
+  %1 : None = prim::Constant()
   %2 : int = prim::Constant[value=1]()
   %3 : Device = prim::Constant[value="cpu"]()
   %4 : int = prim::Constant[value=0]()
index 6e6b4f5..143ba0c 100644 (file)
@@ -1743,7 +1743,7 @@ class TestJit(JitTestCase):
 
         self.run_pass('constant_propagation', constant_prop.graph)
         graph_str = str(constant_prop.graph)
-        self.assertTrue(graph_str.count("prim::None") == 0)
+        self.assertTrue(graph_str.count("prim::Constant") == 1)
 
     def test_constant_prop_if_inline(self):
         @torch.jit.script
@@ -3627,9 +3627,9 @@ a")
                 print(typed_nones())
 
         graph_str = str(test.graph)
-        self.assertTrue(graph_str.count("bool? = prim::None") == 1)
-        self.assertTrue(graph_str.count("int? = prim::None") == 1)
-        self.assertTrue(graph_str.count("None = prim::None") == 1)
+        self.assertTrue(graph_str.count("bool? = prim::Constant") == 1)
+        self.assertTrue(graph_str.count("int? = prim::Constant") == 1)
+        self.assertTrue(graph_str.count("None = prim::Constant") == 1)
 
     def test_literal(self):
         def func1(a, b):
index 5e248f8..5a3e8fa 100644 (file)
@@ -119,8 +119,7 @@ bool isDifferentiable(Node* n) {
   // Tensor", "aten::min(Tensor self) -> Tensor"
 
   if (n->kind() == prim::Constant || n->kind() == prim::Undefined ||
-      n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
-      n->kind() == prim::None)
+      n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk)
     return true;
   if (differentiable_ops.find(n))
     return true;
@@ -143,7 +142,7 @@ bool isDifferentiable(Node* n) {
   if (n->matches(
           "aten::nll_loss(Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index) -> Tensor")) {
     // TODO(asuhan): support weight
-    return n->namedInput(attr::weight)->node()->kind() == prim::None;
+    return n->namedInput(attr::weight)->node()->mustBeNone();
   }
 
   // linear blocks may appear as inputs to graph executors, but they are removed
@@ -284,10 +283,10 @@ class GradientHelper {
 
     if (node->matches(
             "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor")) {
-      return {
-          gradSumToSizeOf(grads.at(0), attr::self),
-          gradSumToSizeOf(grads.at(0) * node->namedInput(attr::alpha), attr::other),
-          nullptr};
+      return {gradSumToSizeOf(grads.at(0), attr::self),
+              gradSumToSizeOf(
+                  grads.at(0) * node->namedInput(attr::alpha), attr::other),
+              nullptr};
 
     } else if (
         node->matches(
@@ -379,12 +378,12 @@ class GradientHelper {
             "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor")) {
       // handle the case that min/max is None
       Value* min = inputs.at(1);
-      bool min_must_be_none = min->node()->kind() == prim::None;
+      bool min_must_be_none = min->mustBeNone();
       Value* max = inputs.at(2);
-      bool max_must_be_none = max->node()->kind() == prim::None;
-      // XXX - this formula is wrong when min or max are not stricly prim::None
-      // but may be None dynamically. In this case an internal compiler error
-      // will get thrown when trying to generate expressions involving the
+      bool max_must_be_none = max->mustBeNone();
+      // XXX - this formula is wrong when min or max are not stricly a constant
+      // None but may be None dynamically. In this case an internal compiler
+      // error will get thrown when trying to generate expressions involving the
       // values of min/max
       if (!min_must_be_none && !max_must_be_none) {
         return {grads.at(0) *
@@ -544,12 +543,12 @@ class GradientHelper {
     } else if (
         node->matches(
             "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor")) {
-      return {
-          gradSumToSizeOf(grads.at(0) * node->namedInput(attr::beta), attr::self),
-          grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
-          inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
-          nullptr,
-          nullptr};
+      return {gradSumToSizeOf(
+                  grads.at(0) * node->namedInput(attr::beta), attr::self),
+              grads.at(0).mm(inputs.at(2).t()) * node->namedInput(attr::alpha),
+              inputs.at(1).t().mm(grads.at(0)) * node->namedInput(attr::alpha),
+              nullptr,
+              nullptr};
 
     } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
       return {grads.at(0).mm(inputs.at(1).t()),
@@ -739,8 +738,7 @@ class GradientHelper {
       return {backward_value->node()->output(0), nullptr};
 
     } else if (
-        node->kind() == prim::Constant || node->kind() == prim::Undefined ||
-        node->kind() == prim::None) {
+        node->kind() == prim::Constant || node->kind() == prim::Undefined) {
       return {};
     }
     throw std::runtime_error(
index ffd8f40..67400e8 100644 (file)
@@ -11,6 +11,7 @@ namespace jit {
 Value* insertConstant(
     Graph& g,
     const IValue& val,
+    const c10::TypePtr& result_type,
     c10::optional<SourceRange> loc,
     c10::optional<ScopePtr> scope) {
   Node* n = g.create(prim::Constant);
@@ -64,8 +65,6 @@ Value* insertConstant(
     n->s_(attr::value, ss.str());
     n->output()->setType(DeviceObjType::get());
   } else if (val.isNone()) {
-    n->destroy();
-    n = g.create(prim::None);
     n->output()->setType(NoneType::get());
   } else {
     throw constant_not_supported_error(
@@ -75,6 +74,13 @@ Value* insertConstant(
     n->setSourceLocation(std::make_shared<SourceRange>(*loc));
   if (scope)
     n->setScope(*scope);
+  if (result_type) {
+    auto inferred_type = n->output()->type();
+    // Retain more type information in case of tensor constant
+    if (!(inferred_type->isSubtypeOf(TensorType::get()) && result_type->isSubtypeOf(inferred_type))) {
+      n->output()->setType(result_type);
+    }
+  }
   return g.insertNode(n)->output();
 }
 
@@ -146,6 +152,11 @@ RegisterOperators reg({
               push(stack, d);
               return 0;
             };
+          } else if (node->mustBeNone()) {
+            return [](Stack& stack) {
+              push(stack, IValue());
+              return 0;
+            };
           } else {
             std::stringstream ss;
             ss << "constant literal not supported for: " << type->str();
index badfefd..5f61178 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/scope.h>
 #include <torch/csrc/jit/source_range.h>
 
@@ -27,6 +28,7 @@ struct TORCH_API constant_not_supported_error : public std::runtime_error {
 TORCH_API Value* insertConstant(
     Graph& g,
     const IValue& val,
+    const c10::TypePtr& result_type = nullptr,
     c10::optional<SourceRange> loc = c10::nullopt,
     c10::optional<ScopePtr> scope = c10::nullopt);
 
index 004a64d..bda3e6e 100644 (file)
@@ -87,8 +87,7 @@ void validateBlock(
       bool is_aten_enabled = operator_export_type ==
               onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
           operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
-      if (!node->kind().is_onnx() && !is_aten_enabled &&
-          node->kind() != prim::None) {
+      if (!node->kind().is_onnx() && !is_aten_enabled && !node->mustBeNone()) {
         FAIL_EXPORT(
             "Couldn't export operator " + node->kind().toDisplayString() +
             "\n\nDefined at:\n" + getNodeStackTraceString(node));
@@ -236,7 +235,7 @@ void EncoderBase::EncodeBlock(
   for (auto node : block->nodes()) {
     bool is_raw_export =
         operator_export_type_ == onnx_torch::OperatorExportTypes::RAW;
-    if (node->kind() == prim::None && !is_raw_export) {
+    if (node->mustBeNone() && !is_raw_export) {
       // None nodes are used to implement optional inputs. One
       // way to "not provide" an optional input is to create an
       // Undefined node, and pass its output as that input.
@@ -249,7 +248,7 @@ void EncoderBase::EncodeBlock(
       p_n->set_doc_string(ss.str());
     }
     for (auto input : node->inputs()) {
-      if (input->node()->kind() == prim::None && !is_raw_export) {
+      if (input->node()->mustBeNone() && !is_raw_export) {
         p_n->add_input("");
       } else {
         p_n->add_input(input->uniqueName());
index a51e32b..a5a3bc6 100644 (file)
@@ -81,7 +81,7 @@ struct ConstantTableValue : public script::SugaredValue {
                                      << " is out of bounds (constant table has "
                                      << constants_.size() << " entries).";
     }
-    Value* value = m.graph()->insertConstant(constants_[offset], loc);
+    Value* value = m.graph()->insertConstant(constants_[offset], nullptr, loc);
     return std::make_shared<script::SimpleValue>(value);
   }
 
index 3e75221..069d2d8 100644 (file)
@@ -637,7 +637,7 @@ std::shared_ptr<Graph> Graph::copy() {
 }
 
 bool Value::mustBeNone() const {
-  return node_->kind() == prim::None;
+  return node_->mustBeNone();
 }
 
 std::string Value::uniqueNameBase() const {
@@ -755,6 +755,12 @@ bool Node::matches(
   return true;
 }
 
+bool Node::mustBeNone() const {
+  return kind_ == prim::Constant && !this->hasAttributes() &&
+      (output()->type()->cast<OptionalType>() ||
+       output()->type() == NoneType::get());
+}
+
 void Node::dump() const {
   std::cout << *this << "\n";
 }
@@ -1179,7 +1185,7 @@ Node* Graph::createUndefined() {
 }
 
 Node* Graph::createNone(TypePtr typ) {
-  Node* n = create(prim::None);
+  Node* n = create(prim::Constant);
   n->output()->setType(OptionalType::create(std::move(typ)));
   return n;
 }
@@ -1312,10 +1318,11 @@ Node* Graph::createClone(
 
 Value* Graph::insertConstant(
     IValue val,
+    const TypePtr& result_type,
     c10::optional<SourceRange> loc,
     c10::optional<ScopePtr> scope) {
   return jit::insertConstant(
-      *this, std::move(val), std::move(loc), std::move(scope));
+      *this, std::move(val), result_type, std::move(loc), std::move(scope));
 }
 
 std::string Graph::toString() const {
index 430d385..f2ed570 100644 (file)
@@ -392,6 +392,7 @@ struct Node {
   bool is_constant(Symbol name) const {
     return static_cast<bool>(get(name));
   }
+  TORCH_API bool mustBeNone() const;
 
   TORCH_API bool isNondeterministic() const;
   TORCH_API bool hasSideEffects() const;
@@ -1071,8 +1072,12 @@ struct Graph {
       const std::function<Value*(Value*)>& value_map,
       bool copy_blocks = true);
 
+
+  // Insert constant IValue into the graph. If the type cannot be fully deduced
+  // from the ivalue, as with a None that is set to t?, use result_type
   TORCH_API Value* insertConstant(
       IValue val,
+      const TypePtr& result_type = nullptr,
       c10::optional<SourceRange> loc = c10::nullopt,
       c10::optional<ScopePtr> scope = c10::nullopt);
 
index 0abd12c..e4b44e3 100644 (file)
@@ -286,7 +286,6 @@ void AliasDb::analyzeImpl(Node* node) {
     case prim::FusedConcat:
     case prim::MMTreeReduce:
     case prim::MMBatchSide:
-    case prim::None:
     case prim::BroadcastSizes:
     case prim::ChunkSizes:
     case prim::Function:
index ab70489..5421c8c 100644 (file)
@@ -26,7 +26,7 @@ void ConstantPooling(
       continue;
     }
 
-    if (node->kind() != prim::Constant && node->kind() != prim::None) {
+    if (node->kind() != prim::Constant) {
       continue;
     }
 
index 7cdb1e5..56a4db7 100644 (file)
@@ -20,8 +20,6 @@ std::unordered_set<Symbol> skip_list = {
     prim::Constant,
     prim::Undefined,
     prim::unchecked_unwrap_optional, // TODO remove
-    prim::None, // it is already a constant and propagating it will lose
-                // important type information about which Optional type it is
     // TODO (zach): we should consider skipping tensor factories in the cases
     // where the constant tensor would be large but cheap to create.
 };
@@ -30,11 +28,7 @@ std::vector<IValue> runNode(Node* n) {
   auto op = getOperation(n);
   Stack stack;
   for (auto input : n->inputs()) {
-    if (input->node()->kind() == prim::None) {
-      stack.emplace_back(IValue());
-    } else {
-      stack.push_back(*(toIValue(input)));
-    }
+    stack.push_back(*(toIValue(input)));
   }
   op(stack);
   auto var_outputs = fmap(stack, [&](IValue v) -> IValue {
@@ -66,6 +60,9 @@ void propagateNode(Node* n) {
   for (size_t i = 0; i < outputs.size(); ++i) {
     try {
       auto new_output = graph->insertConstant(outputs[i]);
+      if (outputs[i].isNone()) {
+        new_output->setType(n->outputs()[i]->type());
+      }
       n->outputs()[i]->replaceAllUsesWith(new_output);
     } catch (constant_not_supported_error& err) {
       // we cannot actually represent the IValue as a constant node,
@@ -170,8 +167,7 @@ void removeExtraLoopOutputs(Node* node) {
 void ConstantPropagation(Node* n, const AliasDb& aliasDb) {
   bool constant_inputs =
       std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
-        return v->node()->kind() == prim::Constant ||
-            v->node()->kind() == prim::None;
+        return v->node()->kind() == prim::Constant;
       });
   bool supported_node = !n->kind().is_onnx() &&
       skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
index 98dbf25..838b9ba 100644 (file)
@@ -24,7 +24,7 @@ static void EraseNumberTypesOnBlock(Block* block) {
           auto s = *constant_as<at::Scalar>(it->output());
           WithInsertPoint guard(*it);
           Value* r = block->owningGraph()->insertConstant(
-              scalar_to_tensor(s), c10::nullopt, it->scope());
+              scalar_to_tensor(s), nullptr, c10::nullopt, it->scope());
           it->output()->replaceAllUsesWith(r);
         }
       } break;
index 01b7eff..b217dce 100644 (file)
@@ -136,7 +136,7 @@ c10::optional<bool> isDefined(Value* tensor) {
   if (tensor->type()->isSubtypeOf(TensorType::get())) {
     return true;
   }
-  if (tensor->node()->kind() == prim::None ||
+  if (tensor->node()->mustBeNone() ||
       tensor->node()->kind() == prim::Undefined) {
     return false;
   }
index 7925917..7d249cd 100644 (file)
@@ -243,7 +243,6 @@ struct PythonPrintPass {
     switch (n->kind()) {
       case prim::Constant:
       case prim::Undefined:
-      case prim::None:
         return true;
       default:
         return false;
@@ -600,10 +599,10 @@ struct PythonPrintPass {
     if (node->kind() == prim::PythonOp) {
       auto value = static_cast<const PythonOp*>(node);
       if (enforce_importable_ && value->ignore_on_export) {
-          // Op has been marked as ignored, so insert an error in its place
-          indent();
-          out << "ops.prim.IgnoredPythonOp()\n";
-          return;
+        // Op has been marked as ignored, so insert an error in its place
+        indent();
+        out << "ops.prim.IgnoredPythonOp()\n";
+        return;
       }
     }
     switch (node->kind()) {
@@ -705,6 +704,36 @@ struct PythonPrintPass {
     }
   }
 
+  void printNone(std::ostream& stmt, const Node* node) {
+    if (node->output()->type()->isSubtypeOf(NoneType::get())) {
+      stmt << "None";
+      return;
+    }
+    // XXX - when None has an Optional[T] type, we must ensure that type
+    // can be recovered on parsing. It cannot be recovered if it will be
+    // matched to schema with free variables. If it is used only in places
+    // where there is schema and the scheme has no free variables, then we
+    // can recover it without annotation. Otherwise, we annotate None with
+    // the right optional type
+    const auto& uses = node->output()->uses();
+    bool all_usable_schema =
+        std::all_of(uses.begin(), uses.end(), [](const Use& u) {
+          if (auto schema = u.user->maybeSchema()) {
+            if (u.offset >= schema->arguments().size()) {
+              return false;
+            }
+            return !schema->arguments().at(u.offset).type()->hasFreeVariables();
+          }
+          return false;
+        });
+
+    if (all_usable_schema) {
+      stmt << "None";
+    } else {
+      stmt << "annotate(" << node->output()->type()->python_str() << ", None)";
+    }
+  }
+
   // Prints the RHS value of a Node, e.g. `aten.add(x, y)`
   void printRHS(std::ostream& stmt, Node* node) {
     switch (node->kind()) {
@@ -721,42 +750,13 @@ struct PythonPrintPass {
         value->writeScalars(stmt);
         printValueList(stmt, node->inputs(), "(", ")");
       } break;
-      case prim::Constant: {
-        IValue v = toIValue(node->output()).value();
-        printConstant(stmt, v);
-      } break;
-      case prim::Undefined:
-      case prim::None: {
-        if (node->output()->type()->isSubtypeOf(NoneType::get())) {
-          stmt << "None";
-          break;
-        }
-        // XXX - when None has an Optional[T] type, we must ensure that type
-        // can be recovered on parsing. It cannot be recovered if it will be
-        // matched to schema with free variables. If it is used only in places
-        // where there is schema and the scheme has no free variables, then we
-        // can recover it without annotation. Otherwise, we annotate None with
-        // the right optional type
-        const auto& uses = node->output()->uses();
-        bool all_usable_schema =
-            std::all_of(uses.begin(), uses.end(), [](const Use& u) {
-              if (auto schema = u.user->maybeSchema()) {
-                if (u.offset >= schema->arguments().size()) {
-                  return false;
-                }
-                return !schema->arguments()
-                            .at(u.offset)
-                            .type()
-                            ->hasFreeVariables();
-              }
-              return false;
-            });
-
-        if (all_usable_schema) {
-          stmt << "None";
+      case prim::Constant:
+      case prim::Undefined: {
+        if (node->kind() == prim::Constant && !node->mustBeNone()) {
+          IValue v = toIValue(node->output()).value();
+          printConstant(stmt, v);
         } else {
-          stmt << "annotate(" << node->output()->type()->python_str()
-               << ", None)";
+          printNone(stmt, node);
         }
       } break;
       case prim::ImplicitTensorToNum: {
@@ -804,7 +804,8 @@ struct PythonPrintPass {
         if (node->inputs().size() == 0 &&
             !dict_type->getKeyType()->isSubtypeOf(StringType::get()) &&
             !dict_type->getValueType()->isSubtypeOf(TensorType::get())) {
-          stmt << "annotate(" << node->output()->type()->python_str() << ", {})";
+          stmt << "annotate(" << node->output()->type()->python_str()
+               << ", {})";
         } else {
           printDict(stmt, node->inputs());
         }
@@ -1072,7 +1073,6 @@ TORCH_API bool printerHasSpecialCaseFor(Symbol sym) {
       prim::ListConstruct,
       prim::DictConstruct,
       prim::ListUnpack,
-      prim::None,
       prim::Print,
       prim::PythonOp,
       prim::TupleConstruct,
index 41bbd18..e4f6e59 100644 (file)
@@ -397,9 +397,11 @@ class ShapePropagator {
         TypePtr typ = node->input()->type();
         if (typ->isSubtypeOf(IntType::get()) ||
             typ->isSubtypeOf(BoolType::get())) {
-          node->output()->setType(DimensionedTensorType::create(at::kLong, at::kCPU, 0));
+          node->output()->setType(
+              DimensionedTensorType::create(at::kLong, at::kCPU, 0));
         } else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
-          node->output()->setType(DimensionedTensorType::create(at::kDouble, at::kCPU, 0));
+          node->output()->setType(
+              DimensionedTensorType::create(at::kDouble, at::kCPU, 0));
         }
         return;
       }
@@ -442,6 +444,12 @@ class ShapePropagator {
         setUnshapedType(node);
         return;
       }
+      case aten::_unwrap_optional: {
+        auto input_ivalue = toIValue(node->input());
+        if (input_ivalue && input_ivalue->isNone()) {
+          return;
+        }
+      }
       default:
         break; // fall-through
     }
@@ -499,8 +507,9 @@ class ShapePropagator {
   // primitive/tensor outputs.
 
   bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
-    static const auto broadcast = [](std::vector<DimensionedTensorTypePtr>& tensor_types,
-                                     size_t arg_for_type) -> DimensionedTensorTypePtr {
+    static const auto broadcast =
+        [](std::vector<DimensionedTensorTypePtr>& tensor_types,
+           size_t arg_for_type) -> DimensionedTensorTypePtr {
       if (tensor_types.size() == 1) {
         return tensor_types[0];
       }
@@ -629,7 +638,8 @@ class ShapePropagator {
             "aten::zeros_like(Tensor self) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
-          auto input_type = node->input(0)->type()->cast<DimensionedTensorType>();
+          auto input_type =
+              node->input(0)->type()->cast<DimensionedTensorType>();
           return input_type ? type_vec_t{input_type} : type_vec_t{};
         }};
 
@@ -691,7 +701,8 @@ class ShapePropagator {
             "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
         },
         [this](Node* node) -> type_vec_t {
-          if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+          if (auto maybe_tensor_types =
+                  gatherTensorTypes<DimensionedTensorType>(node)) {
             return {broadcast(*maybe_tensor_types, 0)};
           }
           return {};
@@ -704,13 +715,15 @@ class ShapePropagator {
             "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
         },
         [this](Node* node) -> type_vec_t {
-          if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+          if (auto maybe_tensor_types =
+                  gatherTensorTypes<DimensionedTensorType>(node)) {
             return {broadcast(*maybe_tensor_types, 1)};
           }
           return {};
         }};
 
-    static const auto any_tensor_type = [](Node* node) -> DimensionedTensorTypePtr {
+    static const auto any_tensor_type =
+        [](Node* node) -> DimensionedTensorTypePtr {
       for (Value* input : node->inputs()) {
         if (auto type = input->type()->cast<DimensionedTensorType>()) {
           return type;
@@ -760,7 +773,8 @@ class ShapePropagator {
             "aten::ne(Tensor self, Scalar other) -> Tensor",
         },
         [this](Node* node) -> type_vec_t {
-          if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+          if (auto maybe_tensor_types =
+                  gatherTensorTypes<DimensionedTensorType>(node)) {
             return {broadcast(*maybe_tensor_types, 0)->toScalarType(at::kByte)};
           }
           return {};
@@ -812,7 +826,8 @@ class ShapePropagator {
             "aten::prelu(Tensor self, Tensor weight) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
-          if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
+          if (auto type =
+                  node->input(0)->type()->cast<DimensionedTensorType>()) {
             return {type};
           }
           return {};
@@ -845,7 +860,8 @@ class ShapePropagator {
             "aten::any(Tensor self) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
-          if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
+          if (auto type =
+                  node->input(0)->type()->cast<DimensionedTensorType>()) {
             return {type->withDim(0)};
           }
           return {};
@@ -865,7 +881,8 @@ class ShapePropagator {
             "aten::prod(Tensor self) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
-          if (auto type = node->input(0)->type()->cast<DimensionedTensorType>()) {
+          if (auto type =
+                  node->input(0)->type()->cast<DimensionedTensorType>()) {
             return {at::isFloatingType(type->scalarType())
                         ? type->withDim(0)
                         : type->withDim(0)->toScalarType(at::kLong)};
@@ -1003,7 +1020,8 @@ class ShapePropagator {
       auto maybe_scalar_type = node->get<at::ScalarType>(attr::dtype);
       if (!maybe_scalar_type)
         return {};
-      return {DimensionedTensorType::create(*maybe_scalar_type, *maybe_device, dim)};
+      return {DimensionedTensorType::create(
+          *maybe_scalar_type, *maybe_device, dim)};
     };
 
     // Requirements:
@@ -1026,8 +1044,9 @@ class ShapePropagator {
             "aten::zeros_like(Tensor self, *, int dtype, int layout, Device device) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
-          if (auto type =
-                  node->namedInput(attr::self)->type()->cast<DimensionedTensorType>()) {
+          if (auto type = node->namedInput(attr::self)
+                              ->type()
+                              ->cast<DimensionedTensorType>()) {
             return factory_with_ndim(node, type->dim());
           }
           return {};
@@ -1097,8 +1116,9 @@ class ShapePropagator {
             "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
         },
         [](Node* node) -> type_vec_t {
-          if (auto type =
-                  node->namedInput(attr::self)->type()->cast<DimensionedTensorType>()) {
+          if (auto type = node->namedInput(attr::self)
+                              ->type()
+                              ->cast<DimensionedTensorType>()) {
             return {type->toScalarType(get_cast_scalar_type(node))};
           }
           return {};
@@ -1243,7 +1263,8 @@ class ShapePropagator {
     static const auto reshape_prop =
         [](Node* node,
            Symbol shape_input,
-           const std::vector<DimensionedTensorTypePtr>& tensor_types) -> DimensionedTensorTypePtr {
+           const std::vector<DimensionedTensorTypePtr>& tensor_types)
+        -> DimensionedTensorTypePtr {
       if (auto list_size = determineListSize(node->namedInput(shape_input))) {
         return tensor_types.at(0)->withDim(*list_size);
       }
@@ -1340,7 +1361,8 @@ class ShapePropagator {
       }
       return nullptr;
     };
-    if (auto maybe_tensor_types = gatherTensorTypes<DimensionedTensorType>(node)) {
+    if (auto maybe_tensor_types =
+            gatherTensorTypes<DimensionedTensorType>(node)) {
       tensor_types = std::move(*maybe_tensor_types);
     } else {
       return false;
index dda796a..445e02b 100644 (file)
@@ -1,5 +1,5 @@
-#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/to_batch.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/script/compiler.h>
 
 namespace torch {
@@ -524,7 +524,6 @@ void ToBatch::toBatch(Block* block, Block* res_block) {
     } else if (n->kind().is_prim()) {
       switch (n->kind()) {
         case prim::Constant:
-        case prim::None:
           visitConstant(n, block, res_block);
           break;
         case prim::NumToTensor:
index 5e21f02..10db432 100644 (file)
@@ -71,6 +71,7 @@ void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
       // enable more optimizations
       if (auto value = toIValue(input)) {
         auto nv = subgraph->insertConstant(*value);
+        nv->setType(input->type()); // Need to retain type information on Nones
         inputsMap[input] = nv;
       } else {
         // The common case: this is a regular input, so just register it with
index c2c350d..87ce018 100644 (file)
@@ -88,7 +88,10 @@ std::vector<Node*> findAllNodes(
   return ret;
 }
 
-std::vector<Node*> findAllNodes(Block* block, Symbol kind, bool recurse = true) {
+std::vector<Node*> findAllNodes(
+    Block* block,
+    Symbol kind,
+    bool recurse = true) {
   std::vector<Block*> blocks = {block};
   return findAllNodes(blocks, kind, recurse);
 }
@@ -291,13 +294,19 @@ void initPythonIRBindings(PyObject* module_) {
           "findNode",
           [](Graph& g, const std::string& kind, bool recurse) {
             return findNode(g.block(), Symbol::fromQualString(kind), recurse);
-          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+          },
+          "Find Node",
+          py::arg("kind"),
+          py::arg("recurse") = true)
       .def(
           "findAllNodes",
           [](Graph& g, const std::string& kind, bool recurse) {
             return findAllNodes(
                 g.block(), Symbol::fromQualString(kind), recurse);
-          }, "Find all nodes",  py::arg("kind"), py::arg("recurse") = true)
+          },
+          "Find all nodes",
+          py::arg("kind"),
+          py::arg("recurse") = true)
       .def("addInput", [](Graph& g) { return g.addInput(); })
       .def("copy", [](Graph& g) { return g.copy(); })
       .GS(eraseInput)
@@ -380,20 +389,27 @@ void initPythonIRBindings(PyObject* module_) {
 #undef VS
 
   py::class_<Block, std::unique_ptr<Block, py::nodelete>>(m, "Block")
-      .def("nodes", [](Block& b) {
-        return py::make_iterator(b.nodes().begin(), b.nodes().end());
-      })
+      .def(
+          "nodes",
+          [](Block& b) {
+            return py::make_iterator(b.nodes().begin(), b.nodes().end());
+          })
       .def(
           "findNode",
           [](Block& b, const std::string& kind, bool recurse) {
             return findNode(&b, Symbol::fromQualString(kind), recurse);
-          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+          },
+          "Find Node",
+          py::arg("kind"),
+          py::arg("recurse") = true)
       .def(
           "findAllNodes",
           [](Block& b, const std::string& kind, bool recurse) {
             return findAllNodes(&b, Symbol::fromQualString(kind), recurse);
-          }, "Find all nodes",  py::arg("kind"), py::arg("recurse") = true);
-
+          },
+          "Find all nodes",
+          py::arg("kind"),
+          py::arg("recurse") = true);
 
 #define NS(name) def(#name, &Node ::name)
   py::class_<Node, std::unique_ptr<Node, py::nodelete>>(m, "Node")
@@ -434,13 +450,19 @@ void initPythonIRBindings(PyObject* module_) {
           "findNode",
           [](Node& n, const std::string& kind, bool recurse) {
             return findNode(n.blocks(), Symbol::fromQualString(kind), recurse);
-          }, "Find Node", py::arg("kind"), py::arg("recurse") = true)
+          },
+          "Find Node",
+          py::arg("kind"),
+          py::arg("recurse") = true)
       .def(
           "findAllNodes",
           [](Node& n, const std::string& kind, bool recurse) {
             return findAllNodes(
                 n.blocks(), Symbol::fromQualString(kind), recurse);
-          }, "Find all nodes",  py::arg("kind"), py::arg("recurse") = true)
+          },
+          "Find all nodes",
+          py::arg("kind"),
+          py::arg("recurse") = true)
       .def("input", [](Node& n) { return n.input(); })
       .def("output", [](Node& n) { return n.output(); })
       .NS(addInput)
@@ -465,6 +487,7 @@ void initPythonIRBindings(PyObject* module_) {
             return py::make_iterator(n.blocks().begin(), n.blocks().end());
           })
       .NS(addBlock)
+      .NS(mustBeNone)
 
 #define AS(name) def(#name, &Node::name)
       // methods from Attributes
@@ -506,9 +529,7 @@ void initPythonIRBindings(PyObject* module_) {
           })
       .def(
           "t",
-          [](Node& n, const char* name) {
-            return n.t(Symbol::attr(name));
-          })
+          [](Node& n, const char* name) { return n.t(Symbol::attr(name)); })
       // Tensors (ts_) -- manually written to unwrap variables into tensors.
       .def(
           "ts_",
@@ -538,7 +559,8 @@ void initPythonIRBindings(PyObject* module_) {
           "z_",
           [](Node& n, const char* name, at::Tensor v) {
             return n.t_(
-                Symbol::attr(name), autograd::Variable(v.view({})).set_requires_grad(false));
+                Symbol::attr(name),
+                autograd::Variable(v.view({})).set_requires_grad(false));
           })
       .def(
           "z",
@@ -583,7 +605,11 @@ void initPythonIRBindings(PyObject* module_) {
             return s.str();
           })
       .def("kind", [](const Type& t) { return typeKindToString(t.kind()); })
-      .def("dim", [](const Type& t) { return t.expect<DimensionedTensorType>()->dim(); })
+      .def(
+          "dim",
+          [](const Type& t) {
+            return t.expect<DimensionedTensorType>()->dim();
+          })
       .def(
           "sizes",
           [](Type& t) { return t.expect<CompleteTensorType>()->sizes(); })
@@ -644,6 +670,11 @@ void initPythonIRBindings(PyObject* module_) {
       .def(py::init([](TypePtr key, TypePtr value) {
         return DictType::create(key, value);
       }));
+  py::class_<OptionalType, Type, std::shared_ptr<OptionalType>>(
+      m, "OptionalType")
+      .def(py::init([](TypePtr a) { return OptionalType::create(a); }))
+      .def_static("ofTensor", &OptionalType::ofTensor)
+      .def("getElementType", &OptionalType::getElementType);
 
   py::class_<Use>(m, "Use")
       .def_readonly("user", &Use::user)
index 969c2ab..222156c 100644 (file)
@@ -417,14 +417,6 @@ RegisterOperators reg({
           };
         }),
     Operator(
-        prim::None,
-        [](const Node* node) {
-          return [](Stack& stack) {
-            stack.emplace_back(IValue());
-            return 0;
-          };
-        }),
-    Operator(
         prim::Print,
         [](const Node* node) {
           size_t num_inputs = node->inputs().size();
index 9d3fcdd..25f9607 100644 (file)
@@ -474,7 +474,7 @@ static Value* materializeConstant(
   }
 
   WithInsertPoint guard(graph.block()->nodes().front());
-  auto new_constant = graph.insertConstant(val, r);
+  auto new_constant = graph.insertConstant(val, nullptr, r);
   map[val] = new_constant;
 
   return new_constant;
@@ -1264,7 +1264,7 @@ struct to_ir {
       if (cond) {
         cond_val = emitCond(cond.value());
       } else {
-        cond_val = graph->insertConstant(true, range);
+        cond_val = graph->insertConstant(true, nullptr, range);
       }
     }
     n->addInput(max_trip_count_val);
@@ -1299,7 +1299,7 @@ struct to_ir {
         Value* body_cond_value = emitCond(cond.value());
         body_block->registerOutput(body_cond_value);
       } else {
-        Value* cond_value_dummy = graph->insertConstant(true, range);
+        Value* cond_value_dummy = graph->insertConstant(true, nullptr, range);
         body_block->registerOutput(cond_value_dummy);
       }
 
@@ -1422,7 +1422,7 @@ struct to_ir {
   // print(a)
   void emitRaise(const SourceRange& loc) {
     const std::string exception = "Exception";
-    auto string_input = insertConstant(*graph, exception, loc);
+    auto string_input = insertConstant(*graph, exception, nullptr, loc);
     graph->insert(prim::RaiseException, {string_input}, {}, loc);
   }
 
@@ -2013,7 +2013,7 @@ struct to_ir {
       bool is_instance_val =
           isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
       return std::make_shared<SimpleValue>(
-          graph->insertConstant(is_instance_val, loc));
+          graph->insertConstant(is_instance_val, nullptr, loc));
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
@@ -2135,7 +2135,7 @@ struct to_ir {
     stack.push_back(*maybe_constant_input);
     op(stack);
     AT_ASSERT(stack.size() == 1);
-    return graph->insertConstant(stack[0], tree->range());
+    return graph->insertConstant(stack[0], nullptr, tree->range());
   }
 
   // This function extract a new graph from its original subgraph
@@ -2221,13 +2221,13 @@ struct to_ir {
         return emitConst(Const(tree));
       } break;
       case TK_TRUE: {
-        return graph->insertConstant(true, tree->range());
+        return graph->insertConstant(true, nullptr, tree->range());
       } break;
       case TK_FALSE: {
-        return graph->insertConstant(false, tree->range());
+        return graph->insertConstant(false, nullptr, tree->range());
       } break;
       case TK_NONE: {
-        return graph->insertConstant(IValue(), tree->range());
+        return graph->insertConstant(IValue(), nullptr, tree->range());
       } break;
       case TK_SUBSCRIPT: {
         return emitSubscript(Subscript(tree));
@@ -2315,7 +2315,7 @@ struct to_ir {
   }
 
   Value* emitStringLiteral(const StringLiteral& c) {
-    return insertConstant(*graph, c.text(), c.range());
+    return insertConstant(*graph, c.text(), nullptr, c.range());
   }
 
   // Desugars select indexing: tensor[i] -> tensor.select(dim, i)
@@ -2329,7 +2329,7 @@ struct to_ir {
         *graph,
         aten::select,
         c10::nullopt,
-        {input, graph->insertConstant(dim, loc), index},
+        {input, graph->insertConstant(dim, nullptr, loc), index},
         {},
         true);
   }
@@ -2349,7 +2349,7 @@ struct to_ir {
     // aten::slice, we should separate it from this function.
     if (dim) {
       AT_ASSERT(input->type()->isSubtypeOf(TensorType::get()));
-      args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), loc));
+      args.emplace_back(loc, "dim", graph->insertConstant(dim.value(), nullptr, loc));
     } else {
       AT_ASSERT(!input->type()->isSubtypeOf(TensorType::get()));
     }
@@ -2366,7 +2366,7 @@ struct to_ir {
         return emitTupleSlice(loc, args[0], args[1], c10::nullopt);
       }
     }
-    NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, loc));
+    NamedValue step = NamedValue(loc, "step", graph->insertConstant(1, nullptr, loc));
     return emitBuiltinCall(
         loc, *graph, aten::slice, c10::nullopt, args, {step}, true);
   }
diff --git a/torch/csrc/jit/script/file_check.cpp b/torch/csrc/jit/script/file_check.cpp
new file mode 100644 (file)
index 0000000..e69de29
index 1a2b89b..ac44724 100644 (file)
@@ -444,26 +444,26 @@ std::shared_ptr<SugaredValue> toSugaredValue(
   auto& g = *m.graph();
   if (is_constant) {
     if (py::isinstance<py::bool_>(obj)) {
-      return toSimple(g.insertConstant(py::cast<bool>(obj), loc));
+      return toSimple(g.insertConstant(py::cast<bool>(obj), nullptr, loc));
     } else if (py::isinstance<py::int_>(obj)) {
-      return toSimple(g.insertConstant(py::cast<int64_t>(obj), loc));
+      return toSimple(g.insertConstant(py::cast<int64_t>(obj), nullptr, loc));
     } else if (py::isinstance<py::float_>(obj)) {
-      return toSimple(g.insertConstant(py::cast<double>(obj), loc));
+      return toSimple(g.insertConstant(py::cast<double>(obj), nullptr, loc));
     } else if (py::isinstance<py::str>(obj)) {
-      return toSimple(g.insertConstant(py::cast<std::string>(obj), loc));
+      return toSimple(g.insertConstant(py::cast<std::string>(obj), nullptr, loc));
     } else if (obj.is(py::none())) {
-      return toSimple(g.insertConstant(IValue(), loc));
+      return toSimple(g.insertConstant(IValue(), nullptr, loc));
     } else if (THPDevice_Check(obj.ptr())) {
       auto device = reinterpret_cast<THPDevice*>(obj.ptr());
       return toSimple(g.insertConstant(device->device));
     } else if (THPLayout_Check(obj.ptr())) {
       auto layout = reinterpret_cast<THPLayout*>(obj.ptr());
       const auto v = static_cast<int64_t>(layout->layout);
-      return toSimple(g.insertConstant(v, loc));
+      return toSimple(g.insertConstant(v, nullptr, loc));
     } else if (THPDtype_Check(obj.ptr())) {
       auto dtype = reinterpret_cast<THPDtype*>(obj.ptr());
       const auto v = static_cast<int64_t>(dtype->scalar_type);
-      return toSimple(g.insertConstant(v, loc));
+      return toSimple(g.insertConstant(v, nullptr, loc));
     } else if (py::isinstance<py::tuple>(obj)) {
       return std::make_shared<ConstantPythonTupleValue>(obj);
     }
index cfc1c2c..a9e7b09 100644 (file)
@@ -332,7 +332,6 @@ const std::vector<std::string> functions = {
                 else:
                     grad_self = torch._adaptive_avg_pool2d_backward(grad_output, self)
                 return grad_self, None
-
             return torch.adaptive_avg_pool2d(self, output_size), backward
 
         def batch_norm(input : Tensor,
@@ -371,6 +370,11 @@ const std::vector<std::string> functions = {
 
             return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
 
+        def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
+            result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
+            def backward(grad):
+                return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
+            return result, backward
       )"};
 std::unordered_map<std::string, GradientPair> schema_to_graphs;
 
index 6bd08bd..7498f8e 100644 (file)
@@ -490,7 +490,7 @@ def _check_trace(check_inputs, func, executor_options, module, check_tolerance,
                 if n_mod.kind() != n_check.kind():
                     break  # Graphs have already diverged
 
-                if n_mod.kind() == n_check.kind() and n_mod.kind() == 'prim::Constant':
+                if n_mod.kind() == 'prim::Constant' and not (n_mod.mustBeNone() or n_check.mustBeNone()):
                     if n_mod.kindOf('value') != 't' or n_check.kindOf('value') != 't':
                         continue
 
index 8dc72c8..e594b1a 100644 (file)
@@ -1,7 +1,7 @@
 import numbers
 
 import torch
-from torch._C import DynamicType, ListType
+from torch._C import DynamicType, ListType, OptionalType
 from torch.nn.modules.utils import _single, _pair, _triple
 from torch.nn.utils.rnn import PackedSequence
 import warnings
@@ -230,7 +230,9 @@ _onnx_opset_version = 9
 
 # used to represent "missing" optional inputs
 def unused(g):
-    return g.op("prim::None")
+    n = g.op("prim::Constant")
+    n.setType(OptionalType.ofTensor())
+    return n
 
 
 def _shape_as_tensor(g, input):
@@ -800,7 +802,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation,
 
     args = [input, weight]
     # ONNX only supports 1D bias
-    if bias.node().kind() != "prim::None" and bias.type().dim() == 1:
+    if not bias.node().mustBeNone() and bias.type().dim() == 1:
         args.append(bias)
 
     kwargs = {"kernel_shape_i": weight_size[2:],
@@ -821,7 +823,7 @@ def _convolution(g, input, weight, bias, stride, padding, dilation,
 
     n = g.op("ConvTranspose" if transposed else "Conv", *args, **kwargs)
 
-    if bias.node().kind() != "prim::None" and bias.type().dim() != 1:
+    if not bias.node().mustBeNone() and bias.type().dim() != 1:
         return g.op("Add", n, bias)
     else:
         return n
@@ -834,12 +836,12 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
         # batchnorm1d accepts 2d and 3d array, but ONNX only accepts 3d
         input = g.op("Unsqueeze", input, axes_i=[2])
 
-    if weight is None or weight.node().kind() == "prim::None":
+    if weight is None or weight.node().mustBeNone():
         assert len(input_sizes) > 1
         weight_value = torch.tensor([1.] * input_sizes[1]).type(
             'torch.' + input.type().scalarType() + 'Tensor')
         weight = g.op("Constant", value_t=weight_value)
-    if bias is None or bias.node().kind() == "prim::None":
+    if bias is None or bias.node().mustBeNone():
         assert len(input_sizes) > 1
         bias_value = torch.tensor([0.] * input_sizes[1]).type(
             'torch.' + input.type().scalarType() + 'Tensor')
@@ -866,12 +868,12 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome
 @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i')
 def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
     input_sizes = input.type().sizes()
-    if weight is None or weight.node().kind() == "prim::None":
+    if weight is None or weight.node().mustBeNone():
         assert len(input_sizes) > 1
         weight_value = torch.tensor([1.] * input_sizes[1]).type(
             'torch.' + input.type().scalarType() + 'Tensor')
         weight = g.op("Constant", value_t=weight_value)
-    if bias is None or bias.node().kind() == "prim::None":
+    if bias is None or bias.node().mustBeNone():
         assert len(input_sizes) > 1
         bias_value = torch.tensor([0.] * input_sizes[1]).type(
             'torch.' + input.type().scalarType() + 'Tensor')
@@ -951,11 +953,11 @@ def pow(g, self, exponent):
 
 
 def clamp(g, self, min, max):
-    # min or max may be prim::None that we need to dispatch to
+    # min or max may be None that we need to dispatch to
     # Clip separately, as ONNX does not have None syntax
-    if min.node().kind() == "prim::None":
+    if min.node().mustBeNone():
         return clamp_max(g, self, max)
-    elif max.node().kind() == "prim::None":
+    elif max.node().mustBeNone():
         return clamp_min(g, self, min)
     else:
         min = _parse_arg(min, 'f')
index 0568e60..9eb953b 100644 (file)
@@ -509,7 +509,7 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor
                 return fn(g, *inputs, **attrs)
 
         elif ns == "prim":
-            if op_name == "Constant":
+            if op_name == "Constant" and not n.mustBeNone():
                 if n.kindOf("value") == "t":
                     return g.op("Constant", value_t=n["value"])
                 elif n.kindOf("value") == "is":
@@ -520,8 +520,8 @@ def _run_symbolic_function(g, n, inputs, env, operator_export_type=OperatorExpor
                 else:
                     raise RuntimeError("Unsupported prim::Constant kind: `{}`. Send a bug report.".format(
                         n.kindOf("value")))
-            elif op_name == "None" or op_name == "ListConstruct" or op_name == "ListUnpack":
-                # None is not an ONNX operator; keep it as prim::None
+            elif n.mustBeNone() or op_name == "ListConstruct" or op_name == "ListUnpack":
+                # None is not an ONNX operator; keep it as None
                 # let the exporter handle finally eliminating these
 
                 # For ListConstruct/ListUnpack, it will be erased in the ONNX peephole pass