+++ /dev/null
-graph(%input : Dynamic
- %opt.1 : Dynamic?) {
- %2 : int = prim::Constant[value=1]()
- %3 : int = prim::Constant[value=2]()
- %4 : int = prim::Constant[value=4]()
- %x.1 : Dynamic = aten::add(%input, %3, %2)
- %6 : None = prim::None()
- %7 : bool = aten::__isnot__(%opt.1, %6)
- %opt : Dynamic?, %x.3 : Dynamic = prim::If(%7)
- block0() {
- %opt.2 : Dynamic = aten::_unwrap_optional(%opt.1)
- %x.2 : Dynamic = aten::add(%opt.2, %x.1, %2)
- -> (%opt.2, %x.2)
- }
- block1() {
- -> (%opt.1, %x.1)
- }
- %12 : None = prim::None()
- %13 : bool = aten::__is__(%opt, %12)
- %x : Dynamic = prim::If(%13)
- block0() {
- %x.4 : Dynamic = aten::add(%x.3, %4, %2)
- -> (%x.4)
- }
- block1() {
- -> (%x.3)
- }
- return (%x);
-}
inputs = self._make_scalar_vars([-1, 1], torch.int64)
self.checkScript(func, inputs, optimize=True)
- def test_if_is_none_dispatch(self):
- class Test(torch.jit.ScriptModule):
- __constants__ = ['b']
-
- def __init__(self, b=None):
- super(Test, self).__init__()
- self.b = b
-
- @torch.jit.script_method
- def forward(self, input, opt=None):
- # type: (Tensor, Optional[Tensor]) -> Tensor
- x = input
- if self.b is not None:
- x = self.b(input)
-
- if self.b is None:
- x = input + 2
-
- if opt is not None:
- opt = torch.jit._unwrap_optional(opt)
- x = opt + x
-
- if opt is None:
- x = x + 4
-
- return x
-
- inputs = torch.zeros(1, 2)
- self.assertExpectedGraph(Test().graph)
- out = Test()(inputs)
- self.assertEqual(out, inputs + 6)
-
def test_explicit_bool_cast(self):
with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
@torch.jit.script
TypePtr type;
};
-static Value* asSimple(SugaredValuePtr value) {
- if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
- return sv->getValue();
- }
- return nullptr;
-}
// we consider _N where N is a number, to be a non-meaningful name
// and do not record it as a unique name. This allows python printing to
// be able to export and import more consistently named graphs
void setVar(const SourceRange& loc, const std::string& name, Value* value) {
setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
}
+ static Value* asSimple(SugaredValuePtr value) {
+ if(SimpleValue* sv = dynamic_cast<SimpleValue*>(value.get())) {
+ return sv->getValue();
+ }
+ return nullptr;
+ }
void setSugaredVar(const SourceRange& loc, const std::string& name, SugaredValuePtr value) {
Value* as_simple_value = asSimple(value);
return v;
}
- void emitIfElseBlocks(Value* cond_value, const If& stmt) {
+ void emitIf(const If& stmt) {
+ Value* cond_value = emitCond(stmt.cond());
+
Node* n = graph->insertNode(create(prim::If, stmt.range(), 0));
n->addInput(cond_value);
auto* true_block = n->addBlock();
}
}
- void emitIf(const If& stmt) {
- // NOTE: emitIf checks on If stmt condition to see if the cond AST kind == is/is not,
- // for such cases we do meta programming and disable emitting the corresponding branches
- Expr cond = stmt.cond();
-
- if (cond.kind() != TK_IS && cond.kind() != TK_ISNOT) {
- // emit normal IF stmt for cases except TK_IS and TK_ISNOT
- Value* cond_value = emitCond(cond);
- emitIfElseBlocks(cond_value, stmt);
- return;
- }
- // meta programming on AST for is/is not cases and emit branches base on the possible output of cond
- auto cond_op = BinOp(cond);
- SugaredValuePtr lhs_val = emitSugaredExpr(cond_op.lhs(), 1);
- SugaredValuePtr rhs_val = emitSugaredExpr(cond_op.rhs(), 1);
-
- List<Stmt> always_none_branch = cond.kind() == TK_IS? stmt.trueBranch(): stmt.falseBranch();
- List<Stmt> never_none_branch = cond.kind() == TK_IS? stmt.falseBranch(): stmt.trueBranch();
-
- auto lhs_none= lhs_val->isNone();
- auto rhs_none= rhs_val->isNone();
-
- // Dispatch logic (A: ALWAYS, N: NEVER, M: MAYBE):
- //
- // AA, -> emit always_none_branch
- // AN , NA-> emit never_none_branch
- // MA, MM, MN, NM, NN, AM -> emit both conditional branches
-
- if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
- // None is/is not None: only emit the always_none_branch
- emitStatements(always_none_branch);
- } else if ((lhs_none == ALWAYS && rhs_none == NEVER) ||
- (lhs_none == NEVER && rhs_none == ALWAYS)){
- // lhs_val/rhs_val with A/M: only emit never_none_branch
- emitStatements(never_none_branch);
- }
- else {
- // all other cases for lhs_val and rhs_val
- // emit the whole If stmt as usual, finish emitCond first
- auto lhs_range = cond_op.lhs().get()->range();
- auto rhs_range = cond_op.rhs().get()->range();
- auto kind = getNodeKind(cond.kind(), cond.get()->trees().size());
- Value* cond_value = emitBuiltinCall(
- cond.get()->range(),
- *method.graph(),
- kind,
- c10::nullopt,
- {lhs_val->asValue(lhs_range, method), rhs_val->asValue(rhs_range, method)},
- {},
- /*required=*/true);
- emitIfElseBlocks(cond_value, stmt);
-
- }
-
- }
-
// *********************** Loop Operators ************************************
// Emits a loop operators conforming to the semantics specified at
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#experimental-loop
// that separates their behavior from the AST -> IR converter itself.
// This allows us to keep dependencies on python minimal.
-enum NoneStatus {
- ALWAYS,
- MAYBE,
- NEVER
-};
-
struct SugaredValue : public std::enable_shared_from_this<SugaredValue> {
// what is this node? for error reporting (e.g. Module, python function)
virtual std::string kind() const = 0;
virtual std::shared_ptr<SugaredValue> attr(SourceRange loc, Method & m, const std::string& field) {
throw ErrorReport(loc) << "attribute lookup is not defined on " << kind();
}
- virtual NoneStatus isNone() {
- return NEVER;
- }
// use it as a vector of values, e.g. a tuple of values as return value from
// a method invocation
Value * asValue(SourceRange range, Method & m) override {
return value;
}
- NoneStatus isNone() override {
- if (value->mustBeNone())
- return ALWAYS;
- else if (value->type()->cast<OptionalType>())
- return MAYBE;
- else
- return NEVER;
- }
std::vector<std::shared_ptr<SugaredValue>> asTuple(
SourceRange loc,
Method& m,