add isinstance static type checking for jit (#15076)
authorWanchao Liang <wanchaol@users.noreply.github.com>
Mon, 17 Dec 2018 23:18:51 +0000 (15:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 17 Dec 2018 23:21:49 +0000 (15:21 -0800)
Summary:
This PR add isinstance to do static type checking in JIT.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15076

Differential Revision: D13471067

Pulled By: wanchaol

fbshipit-source-id: d39b7ed5db9fcca4b503659d02cf7795950ea8ea

test/test_jit.py
torch/csrc/jit/script/compiler.cpp
torch/csrc/jit/script/compiler.h

index ac3a8c3..72bdf53 100644 (file)
@@ -4450,6 +4450,43 @@ a")
         for op, lhs, rhs in product(ops, type_literals, type_literals):
             test(op, [lhs, rhs])
 
+    def test_isinstance(self):
+        # test isinstance operator for static type checking
+        template = dedent('''
+        def func(x):
+            # type: ({type_hint}) -> bool
+            return isinstance(x, {typ})
+        ''')
+
+        def test(inp, typ, type_hint):
+            code = template.format(typ=typ, type_hint=type_hint)
+            scope = {}
+            exec(code, globals(), scope)
+            cu = torch.jit.CompilationUnit(code)
+            self.assertEqual(
+                cu.func(inp),
+                scope['func'](inp),
+                "Failed with typ: {}"
+                .format(typ)
+            )
+
+        inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
+        type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
+                         '(list, tuple)', '(int, float, bool)']
+        type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
+                            'List[int]', 'int']
+
+        # do zipping to try different types
+        for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
+            test(inp, typ, type_hint)
+
+        # test optional isintance check
+        with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"):
+            @torch.jit.script
+            def opt_func(x):
+                # type: (Optional[int]) -> bool
+                return isinstance(x, int)
+
     def test_python_call(self):
         def pyfunc(a):
             return a * 3.0
index 7466e53..88d0844 100644 (file)
@@ -35,13 +35,6 @@ struct NoneValue : SugaredValue {
   }
 };
 
-// matched against for special handling of getattr expressions
-struct GetAttrValue : SugaredValue {
-  std::string kind() const override {
-    return "getattr";
-  }
-};
-
 struct PrintValue : public SugaredValue {
   std::string kind() const override {
     return "print";
@@ -328,6 +321,7 @@ struct Environment {
         {"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
         {"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
         {"getattr", std::make_shared<GetAttrValue>()},
+        {"isinstance", std::make_shared<IsInstanceValue>()},
         // todo(zach): remove when we can correctly export torch.full via ONNX
         // or we have implicit conversion that can convert numbers to tensors
         {"_to_tensor", std::make_shared<CastValue>(DynamicType::get(), prim::NumToTensor)},
@@ -815,6 +809,7 @@ inline bool isSupportedListElementType(const TypePtr& type) {
       type->isSubtypeOf(NumberType::get());
 }
 
+c10::optional<std::string> parseBaseTypeName(const Expr& expr);
 TypePtr parseTypeFromExpr(const Expr& expr);
 c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr);
 
@@ -2020,6 +2015,21 @@ private:
       return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value()));
     });
   }
+
+  void checkApplyExpr(Apply& apply, SourceRange& loc) {
+    if (apply.inputs().size() != 2) {
+      throw ErrorReport(loc)
+          << Var(apply.callee()).name().name()
+          << " expected exactly two arguments but found "
+          << apply.inputs().size();
+    }
+    if (apply.attributes().size() > 0) {
+      throw ErrorReport(loc)
+          << Var(apply.callee()).name().name()
+          << " takes no keyword arguments";
+    }
+  }
+
   std::shared_ptr<SugaredValue> emitApplyExpr(Apply &apply, size_t n_binders) {
     auto sv = emitSugaredExpr(apply.callee(), 1);
     auto loc = apply.callee().range();
@@ -2035,14 +2045,7 @@ private:
       auto attributes = emitAttributes(apply.attributes());
       return emitForkExpr(loc, forked, inputs, attributes);
     } else if (auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
-      if (apply.inputs().size() != 2) {
-        throw ErrorReport(loc)
-            << "expected exactly two arguments to attribute but found "
-            << apply.inputs().size();
-      }
-      if (apply.attributes().size() > 0) {
-        throw ErrorReport(loc) << "attribute takes no keyword arguments";
-      }
+      checkApplyExpr(apply, loc);
       TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
       Value* expr = tryConvertToType(
           apply.range(),
@@ -2057,12 +2060,7 @@ private:
       }
       return std::make_shared<SimpleValue>(expr);
     } else if(auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
-      if (apply.attributes().size() > 0) {
-        throw ErrorReport(loc) << "getattr takes no keyword arguments";
-      }
-      if (apply.inputs().size() != 2) {
-        throw ErrorReport(loc) << "getattr expects 2 inputs";
-      }
+      checkApplyExpr(apply, loc);
       auto obj = emitSugaredExpr(apply.inputs()[0], 1);
       auto selector = apply.inputs()[1];
       if (selector.kind() != TK_STRINGLITERAL) {
@@ -2070,6 +2068,45 @@ private:
       }
       const std::string& name = StringLiteral(selector).text();
       return obj->attr(apply.range(), method, name);
+    } else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
+      // NOTE: for `isinstance` builtin call in JIT, we only check the static types
+      // on the inputs to evaluate, and insert the corresponding constant node
+      std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj, Expr classinfo) {
+        if (classinfo.kind() == TK_TUPLE_LITERAL) {
+          // handle the case for recursive tuple classinfo
+          // return true if obj is an instance of any of the types
+          for (Expr e: TupleLiteral(classinfo).inputs()) {
+            if (isInstanceCheck(obj, e)) {
+              return true;
+            }
+          }
+          return false;
+        }
+        auto type_name = parseBaseTypeName(classinfo);
+        if (!type_name) {
+          throw ErrorReport(classinfo.range()) << "type must be a type identifier";
+        }
+        auto val = emitExpr(obj);
+        // Special casing for list and tuple since isintance(x, list) and isinstance(x, tuple)
+        // does not accept List[int] / Tuple[int] like subscript type annotation in python
+        if (*type_name == "list" && val->type()->cast<ListType>()) {
+          return true;
+        } else if (*type_name == "tuple" && val->type()->cast<TupleType>()) {
+          return true;
+        } else if (val->type()->cast<OptionalType>()) {
+          throw ErrorReport(loc)
+                << "Optional isinstance check is not supported, consider use is/isnot None instead";
+        } else {
+          TypePtr type = parseTypeFromExpr(classinfo);
+          if (val->type()->isSubtypeOf(type)) {
+            return true;
+          }
+        }
+        return false;
+      };
+      checkApplyExpr(apply, loc);
+      bool is_instance_val = isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
+      return std::make_shared<SimpleValue>(graph->insertConstant(is_instance_val, loc));
     } else {
       auto inputs = getNamedValues(apply.inputs(), true);
       auto attributes = emitAttributes(apply.attributes());
index e3e28aa..7abe842 100644 (file)
@@ -177,6 +177,22 @@ struct TORCH_API AnnotateValue : public SugaredValue {
   }
 };
 
+// matched against for special handling of getattr expressions
+struct TORCH_API GetAttrValue : SugaredValue {
+  GetAttrValue() = default;
+  std::string kind() const override {
+    return "getattr";
+  }
+};
+
+// matched against for special handling of isinstance expressions
+struct TORCH_API IsInstanceValue : SugaredValue {
+  IsInstanceValue() = default;
+  std::string kind() const override {
+    return "isinstance";
+  }
+};
+
 using Resolver = std::function<std::shared_ptr<SugaredValue>(const std::string& name, Method& m, const SourceRange& loc)>;
 
 inline std::shared_ptr<SugaredValue> nativeResolver(const std::string& name, Method& m, const SourceRange& loc){