Refactor EmitLoopCommon to make it more amenable to future extensions (#19341)
authorNikolay Korovaiko <korovaikon@gmail.com>
Thu, 18 Apr 2019 16:56:02 +0000 (09:56 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 18 Apr 2019 16:59:21 +0000 (09:59 -0700)
Summary:
This PR paves the way for support more iterator types in for-in loops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19341

Differential Revision: D14992749

Pulled By: Krovatkin

fbshipit-source-id: e2d4c9465c8ec3fc74fbf23006dcb6783d91795f

torch/csrc/jit/script/compiler.cpp

index d17ce63..87152a2 100644 (file)
@@ -1301,44 +1301,26 @@ struct to_ir {
 
   void emitLoopCommon(
       SourceRange range,
-      c10::optional<Expr> max_trip_count,
-      c10::optional<Expr> cond,
       const List<Stmt>& body,
-      c10::optional<Ident> itr_ident,
-      bool in_list = false) {
+      const std::function<void(Value*, std::shared_ptr<Environment>)>&
+          current_element_assigner,
+      c10::optional<Expr> cond,
+      Value* max_trip_count_val = nullptr) {
+    Value* cond_val = nullptr;
     Node* n = graph->insertNode(create(prim::Loop, range, 0));
-    Value *max_trip_count_val, *cond_val;
+    WithInsertPoint guard(n);
+
+    if (!max_trip_count_val)
     {
-      WithInsertPoint guard(n);
-      if (max_trip_count) {
-        if (in_list) {
-          auto listArg = emitExpr(max_trip_count.value());
-
-          max_trip_count_val = emitBuiltinCall(
-              max_trip_count->range(),
-              *graph,
-              aten::len,
-              c10::nullopt,
-              {listArg},
-              {},
-              /*required=*/true);
-        } else {
-          max_trip_count_val = ensureInt(
-              max_trip_count->range(), emitExpr(max_trip_count.value()));
-        }
-      } else {
-        max_trip_count_val = materializeConstant(
-            std::numeric_limits<int64_t>::max(),
-            *graph,
-            range,
-            integral_constants);
-      }
-      if (cond) {
-        cond_val = emitCond(cond.value());
-      } else {
-        cond_val = graph->insertConstant(true, nullptr, range);
-      }
+      max_trip_count_val = materializeConstant(
+          std::numeric_limits<int64_t>::max(),
+          *graph,
+          range,
+          integral_constants);
     }
+
+    cond_val = (cond) ? emitCond(cond.value())
+                      : graph->insertConstant(true, nullptr, range);
     n->addInput(max_trip_count_val);
     n->addInput(cond_val);
     auto* body_block = n->addBlock();
@@ -1348,33 +1330,20 @@ struct to_ir {
     {
       pushFrame(body_block);
       WithInsertPoint guard(body_block);
-      if (itr_ident) {
-        if (in_list) {
-          // set user's iterator variable to the current element
-          auto listArg = emitExpr(max_trip_count.value());
-          trip_count = emitBuiltinCall(
-              max_trip_count->range(),
-              *graph,
-              aten::select,
-              c10::nullopt,
-              {listArg, trip_count},
-              {},
-              /*required=*/true);
-        }
-        environment_stack->setVar(
-            itr_ident->range(), itr_ident->name(), trip_count);
+
+      // current_element_assigner uses an induction variable
+      // to set a current element
+      if (current_element_assigner)
+      {
+        current_element_assigner(trip_count, environment_stack);
       }
+
       emitStatements(body);
 
       // Also emit the conditional
-      if (cond) {
-        Value* body_cond_value = emitCond(cond.value());
-        body_block->registerOutput(body_cond_value);
-      } else {
-        Value* cond_value_dummy = graph->insertConstant(true, nullptr, range);
-        body_block->registerOutput(cond_value_dummy);
-      }
-
+      cond_val = (cond) ? emitCond(cond.value())
+                        : graph->insertConstant(true, nullptr, range);
+      body_block->registerOutput(cond_val);
       auto body_frame = popFrame();
       auto outer_frame = environment_stack;
 
@@ -1411,7 +1380,46 @@ struct to_ir {
       throw ErrorReport(range)
           << "range() expects 1 argument but got " << args.size();
     }
-    emitLoopCommon(range, {args[0]}, {}, body, target);
+    auto max_trip_count_val = ensureInt(range, emitExpr(args[0]));
+    const auto& ident_name = target.name();
+    auto assigner = [ident_name, range](Value* index, std::shared_ptr<Environment> env) {
+      env->setVar(range, ident_name, index);
+    };
+    emitLoopCommon(range, body, assigner, {}, max_trip_count_val);
+  }
+
+  void emitForInListLoop(
+      const For& stmt,
+      const std::shared_ptr<torch::jit::script::SimpleValue>& siv) {
+    auto targets = stmt.targets();
+    auto itrs = stmt.itrs();
+    auto body = stmt.body();
+    auto& range = stmt.range();
+    auto target = Var(targets[0]).name();
+
+    auto listArg = siv->asValue(range, method);
+    auto max_trip_count_val = emitBuiltinCall(
+        range,
+        *graph,
+        aten::len,
+        c10::nullopt,
+        {listArg},
+        {},
+        /*required=*/true);
+    const auto& ident_name = target.name();
+    auto assigner = [ident_name, range, listArg, this](
+                        Value* index, std::shared_ptr<Environment> env) {
+      auto cur_elm = emitBuiltinCall(
+          range,
+          *this->graph,
+          aten::select,
+          c10::nullopt,
+          {listArg, index},
+          {},
+          /*required=*/true);
+      env->setVar(range, ident_name, cur_elm);
+    };
+    emitLoopCommon(range, body, assigner, {}, max_trip_count_val);
   }
 
   void emitFor(const For& stmt) {
@@ -1454,8 +1462,8 @@ struct to_ir {
     // check if a value is simple and list-like
     if (auto siv = std::dynamic_pointer_cast<SimpleValue>(sv)) {
       if (siv->getValue()->type()->kind() == TypeKind::ListType) {
-        return emitLoopCommon(
-            stmt.range(), {itrs[0]}, {}, body, {target}, true);
+        emitForInListLoop(stmt, siv);
+        return;
       }
     }
     auto instances = sv->asTuple(stmt.range(), method);
@@ -1477,7 +1485,7 @@ struct to_ir {
 
   void emitWhile(const While& stmt) {
     auto cond = stmt.cond();
-    emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
+    emitLoopCommon(stmt.range(), stmt.body(), nullptr, cond, nullptr);
   }
 
   // Currently we do not support assigning exceptions to variables,