for op, lhs, rhs in product(ops, type_literals, type_literals):
test(op, [lhs, rhs])
+ def test_isinstance(self):
+ # test isinstance operator for static type checking
+ template = dedent('''
+ def func(x):
+ # type: ({type_hint}) -> bool
+ return isinstance(x, {typ})
+ ''')
+
+ def test(inp, typ, type_hint):
+ code = template.format(typ=typ, type_hint=type_hint)
+ scope = {}
+ exec(code, globals(), scope)
+ cu = torch.jit.CompilationUnit(code)
+ self.assertEqual(
+ cu.func(inp),
+ scope['func'](inp),
+ "Failed with typ: {}"
+ .format(typ)
+ )
+
+ inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
+ type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
+ '(list, tuple)', '(int, float, bool)']
+ type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
+ 'List[int]', 'int']
+
+ # do zipping to try different types
+ for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
+ test(inp, typ, type_hint)
+
+ # test optional isintance check
+ with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"):
+ @torch.jit.script
+ def opt_func(x):
+ # type: (Optional[int]) -> bool
+ return isinstance(x, int)
+
def test_python_call(self):
def pyfunc(a):
return a * 3.0
}
};
-// matched against for special handling of getattr expressions
-struct GetAttrValue : SugaredValue {
- std::string kind() const override {
- return "getattr";
- }
-};
-
struct PrintValue : public SugaredValue {
std::string kind() const override {
return "print";
{"int", std::make_shared<CastValue>(IntType::get(), prim::Int)},
{"bool", std::make_shared<CastValue>(BoolType::get(), prim::Bool)},
{"getattr", std::make_shared<GetAttrValue>()},
+ {"isinstance", std::make_shared<IsInstanceValue>()},
// todo(zach): remove when we can correctly export torch.full via ONNX
// or we have implicit conversion that can convert numbers to tensors
{"_to_tensor", std::make_shared<CastValue>(DynamicType::get(), prim::NumToTensor)},
type->isSubtypeOf(NumberType::get());
}
+c10::optional<std::string> parseBaseTypeName(const Expr& expr);
TypePtr parseTypeFromExpr(const Expr& expr);
c10::optional<std::pair<TypePtr, int32_t>> handleBroadcastList(const Expr& expr);
return NamedValue(attr.range(), attr.name().name(), emitExpr(attr.value()));
});
}
+
+ void checkApplyExpr(Apply& apply, SourceRange& loc) {
+ if (apply.inputs().size() != 2) {
+ throw ErrorReport(loc)
+ << Var(apply.callee()).name().name()
+ << " expected exactly two arguments but found "
+ << apply.inputs().size();
+ }
+ if (apply.attributes().size() > 0) {
+ throw ErrorReport(loc)
+ << Var(apply.callee()).name().name()
+ << " takes no keyword arguments";
+ }
+ }
+
std::shared_ptr<SugaredValue> emitApplyExpr(Apply &apply, size_t n_binders) {
auto sv = emitSugaredExpr(apply.callee(), 1);
auto loc = apply.callee().range();
auto attributes = emitAttributes(apply.attributes());
return emitForkExpr(loc, forked, inputs, attributes);
} else if (auto annotate_value = dynamic_cast<AnnotateValue*>(sv.get())) {
- if (apply.inputs().size() != 2) {
- throw ErrorReport(loc)
- << "expected exactly two arguments to attribute but found "
- << apply.inputs().size();
- }
- if (apply.attributes().size() > 0) {
- throw ErrorReport(loc) << "attribute takes no keyword arguments";
- }
+ checkApplyExpr(apply, loc);
TypePtr type = parseTypeFromExpr(apply.inputs()[0]);
Value* expr = tryConvertToType(
apply.range(),
}
return std::make_shared<SimpleValue>(expr);
} else if(auto getattr = dynamic_cast<GetAttrValue*>(sv.get())) {
- if (apply.attributes().size() > 0) {
- throw ErrorReport(loc) << "getattr takes no keyword arguments";
- }
- if (apply.inputs().size() != 2) {
- throw ErrorReport(loc) << "getattr expects 2 inputs";
- }
+ checkApplyExpr(apply, loc);
auto obj = emitSugaredExpr(apply.inputs()[0], 1);
auto selector = apply.inputs()[1];
if (selector.kind() != TK_STRINGLITERAL) {
}
const std::string& name = StringLiteral(selector).text();
return obj->attr(apply.range(), method, name);
+ } else if (auto isinstance = dynamic_cast<IsInstanceValue*>(sv.get())) {
+ // NOTE: for `isinstance` builtin call in JIT, we only check the static types
+ // on the inputs to evaluate, and insert the corresponding constant node
+ std::function<bool(Expr, Expr)> isInstanceCheck = [&](Expr obj, Expr classinfo) {
+ if (classinfo.kind() == TK_TUPLE_LITERAL) {
+ // handle the case for recursive tuple classinfo
+ // return true if obj is an instance of any of the types
+ for (Expr e: TupleLiteral(classinfo).inputs()) {
+ if (isInstanceCheck(obj, e)) {
+ return true;
+ }
+ }
+ return false;
+ }
+ auto type_name = parseBaseTypeName(classinfo);
+ if (!type_name) {
+ throw ErrorReport(classinfo.range()) << "type must be a type identifier";
+ }
+ auto val = emitExpr(obj);
+ // Special casing for list and tuple since isintance(x, list) and isinstance(x, tuple)
+ // does not accept List[int] / Tuple[int] like subscript type annotation in python
+ if (*type_name == "list" && val->type()->cast<ListType>()) {
+ return true;
+ } else if (*type_name == "tuple" && val->type()->cast<TupleType>()) {
+ return true;
+ } else if (val->type()->cast<OptionalType>()) {
+ throw ErrorReport(loc)
+ << "Optional isinstance check is not supported, consider use is/isnot None instead";
+ } else {
+ TypePtr type = parseTypeFromExpr(classinfo);
+ if (val->type()->isSubtypeOf(type)) {
+ return true;
+ }
+ }
+ return false;
+ };
+ checkApplyExpr(apply, loc);
+ bool is_instance_val = isInstanceCheck(apply.inputs()[0], apply.inputs()[1]);
+ return std::make_shared<SimpleValue>(graph->insertConstant(is_instance_val, loc));
} else {
auto inputs = getNamedValues(apply.inputs(), true);
auto attributes = emitAttributes(apply.attributes());