Revert D13272203: [pytorch][PR] [jit] Meta programming on If Stmt cond to enable...
authorMichael Suo <suo@fb.com>
Mon, 3 Dec 2018 21:24:46 +0000 (13:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 21:28:52 +0000 (13:28 -0800)
Differential Revision:
D13272203

Original commit changeset: 44a545abb766

fbshipit-source-id: 8861eb4810a6c9ea4aba8427b3a07d2fa0d69a15

test/expect/TestScript.test_if_is_none_dispatch.expect [deleted file]
test/test_jit.py
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/compiler.h

diff --git a/test/expect/TestScript.test_if_is_none_dispatch.expect b/test/expect/TestScript.test_if_is_none_dispatch.expect
deleted file mode 100644 (file)
index 118aec9..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-graph(%input : Dynamic
-      %opt.1 : Dynamic?) {
-  %2 : int = prim::Constant[value=1]()
-  %3 : int = prim::Constant[value=2]()
-  %4 : int = prim::Constant[value=4]()
-  %x.1 : Dynamic = aten::add(%input, %3, %2)
-  %6 : None = prim::None()
-  %7 : bool = aten::__isnot__(%opt.1, %6)
-  %opt : Dynamic?, %x.3 : Dynamic = prim::If(%7)
-    block0() {
-      %opt.2 : Dynamic = aten::_unwrap_optional(%opt.1)
-      %x.2 : Dynamic = aten::add(%opt.2, %x.1, %2)
-      -> (%opt.2, %x.2)
-    }
-    block1() {
-      -> (%opt.1, %x.1)
-    }
-  %12 : None = prim::None()
-  %13 : bool = aten::__is__(%opt, %12)
-  %x : Dynamic = prim::If(%13)
-    block0() {
-      %x.4 : Dynamic = aten::add(%x.3, %4, %2)
-      -> (%x.4)
-    }
-    block1() {
-      -> (%x.3)
-    }
-  return (%x);
-}
index 438bc45..208a649 100644 (file)
@@ -4379,38 +4379,6 @@ a")
         inputs = self._make_scalar_vars([-1, 1], torch.int64)
         self.checkScript(func, inputs, optimize=True)
 
-    def test_if_is_none_dispatch(self):
-        class Test(torch.jit.ScriptModule):
-            __constants__ = ['b']
-
-            def __init__(self, b=None):
-                super(Test, self).__init__()
-                self.b = b
-
-            @torch.jit.script_method
-            def forward(self, input, opt=None):
-                # type: (Tensor, Optional[Tensor]) -> Tensor
-                x = input
-                if self.b is not None:
-                    x = self.b(input)
-
-                if self.b is None:
-                    x = input + 2
-
-                if opt is not None:
-                    opt = torch.jit._unwrap_optional(opt)
-                    x = opt + x
-
-                if opt is None:
-                    x = x + 4
-
-                return x
-
-        inputs = torch.zeros(1, 2)
-        self.assertExpectedGraph(Test().graph)
-        out = Test()(inputs)
-        self.assertEqual(out, inputs + 6)
-
     def test_explicit_bool_cast(self):
         with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
             @torch.jit.script
index 9cc75f5..3a674c0 100644 (file)
@@ -522,10 +522,6 @@ std::shared_ptr<Graph> Graph::copy() {
   return new_g;
 }
 
-bool Value::mustBeNone() const {
-  return node_->kind() == prim::None;
-}
-
 std::string Value::uniqueNameBase() const {
   std::string name = uniqueName();
   std::string name_base = name;
index fab58f7..f4cf0d4 100644 (file)
@@ -138,7 +138,9 @@ public:
   bool isTensor() const {
     return type()->kind() == TypeKind::CompleteTensorType;
   }
-  TORCH_API bool mustBeNone() const;
+  bool isNone() const {
+    return type()->kind() == TypeKind::NoneType;
+  }
   size_t unique() const {
     return unique_;
   }
index c491ea3..d22f886 100644 (file)
@@ -131,12 +131,6 @@ private:
   TypePtr type;
 };
 
-static Value* asSimple(SugaredValuePtr value) {
-  if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
-    return sv->getValue();
-  }
-  return nullptr;
-}
 // we consider _N where N is a number, to be a non-meaningful name
 // and do not record it as a unique name. This allows python printing to
 // be able to export and import more consistently named graphs
@@ -293,6 +287,12 @@ struct Environment {
   void setVar(const SourceRange& loc, const std::string& name, Value* value) {
     setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
   }
+  static Value* asSimple(SugaredValuePtr value) {
+    if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
+      return sv->getValue();
+    }
+    return nullptr;
+  }
 
   void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
     Value* as_simple_value = asSimple(value);
@@ -1256,7 +1256,9 @@ private:
     return v;
   }
 
-  void emitIfElseBlocks(Value* cond_value, const If& stmt) {
+  void emitIf(const If& stmt) {
+    Value* cond_value = emitCond(stmt.cond());
+
     Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
     n->addInput(cond_value);
     auto* true_block = n->addBlock();
@@ -1341,62 +1343,6 @@ private:
     }
   }
 
-  void emitIf(const If& stmt) {
-    // NOTE: emitIf checks on If stmt condition to see if the cond AST kind == is/is not,
-    // for such cases we do meta programming and disable emitting the corresponding branches
-    Expr cond = stmt.cond();
-
-    if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
-      // emit normal IF stmt for cases except TK_IS and TK_ISNOT
-      Value* cond_value = emitCond(cond);
-      emitIfElseBlocks(cond_value, stmt);
-      return;
-    }
-    // meta programming on AST for is/is not cases and emit branches base on the possible output of cond
-    auto cond_op = BinOp(cond);
-    SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
-    SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);
-
-    List<Stmt> always_none_branch = cond.kind() == TK_IS? stmt.trueBranch(): stmt.falseBranch();
-    List<Stmt> never_none_branch = cond.kind() == TK_IS? stmt.falseBranch(): stmt.trueBranch();
-
-    auto lhs_none= lhs_val->isNone();
-    auto rhs_none= rhs_val->isNone();
-
-    // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
-    //
-    // AA, -> emit always_none_branch
-    // AN , NA-> emit never_none_branch
-    // MA, MM, MN, NM, NN, AM -> emit both conditional branches
-
-    if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
-      // None is/is not None: only emit the always_none_branch
-      emitStatements(always_none_branch);
-    } else if ((lhs_none == ALWAYS && rhs_none == NEVER) ||
-        (lhs_none == NEVER && rhs_none == ALWAYS)){
-      // lhs_val/rhs_val with A/M: only emit never_none_branch
-      emitStatements(never_none_branch);
-    }
-    else {
-      // all other cases for lhs_val and rhs_val
-      // emit the whole If stmt as usual, finish emitCond first
-      auto lhs_range = cond_op.lhs().get()->range();
-      auto rhs_range = cond_op.rhs().get()->range();
-      auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
-      Value* cond_value = emitBuiltinCall(
-          cond.get()->range(),
-          *method.graph(),
-          kind,
-          c10::nullopt,
-          {lhs_val->asValue(lhs_range, method), rhs_val->asValue(rhs_range, method)},
-          {},
-          /*required=*/true);
-      emitIfElseBlocks(cond_value, stmt);
-
-    }
-
-  }
-
   // *********************** Loop Operators ************************************
   // Emits a loop operators conforming to the semantics specified at
   // https://github.com/onnx/onnx/blob/master/docs/Operators.md#experimental-loop
index 364a0d9..e6f1a70 100644 (file)
@@ -26,12 +26,6 @@ static inline std::vector<Value*> toValues(Graph& g, at::ArrayRef<NamedValue> nv
 // that separates their behavior from the AST -> IR converter itself.
 // This allows us to keep dependencies on python minimal.
 
-enum NoneStatus {
- ALWAYS,
- MAYBE,
- NEVER
-};
-
 struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   // what is this node? for error reporting (e.g. Module, python function)
   virtual std::string kind() const = 0;
@@ -46,9 +40,6 @@ struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
   virtual std::shared_ptr<SugaredValue> attr(SourceRange loc, Method & m, const std::string& field) {
     throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
   }
-  virtual NoneStatus isNone() {
-    return NEVER;
-  }
 
   // use it as a vector of values, e.g. a tuple of values as return value from
   // a method invocation
@@ -99,14 +90,6 @@ struct TORCH_API SimpleValue : public SugaredValue {
   Value * asValue(SourceRange range, Method & m) override {
     return value;
   }
-  NoneStatus isNone() override {
-    if (value->mustBeNone())
-      return ALWAYS;
-    else if (value->type()->cast<OptionalType>())
-      return MAYBE;
-    else
-      return NEVER;
-  }
   std::vector<std::shared_ptr<SugaredValue>> asTuple(
       SourceRange loc,
       Method& m,