From 2d0d153288c3b186c7232a71021b4f1648c14def Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 18 Apr 2019 09:56:02 -0700 Subject: [PATCH] Refactor EmitLoopCommon to make it more amenable to future extensions (#19341) Summary: This PR paves the way for support more iterator types in for-in loops. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19341 Differential Revision: D14992749 Pulled By: Krovatkin fbshipit-source-id: e2d4c9465c8ec3fc74fbf23006dcb6783d91795f --- torch/csrc/jit/script/compiler.cpp | 130 ++++++++++++++++++++----------------- 1 file changed, 69 insertions(+), 61 deletions(-) diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index d17ce63..87152a2 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -1301,44 +1301,26 @@ struct to_ir { void emitLoopCommon( SourceRange range, - c10::optional max_trip_count, - c10::optional cond, const List& body, - c10::optional itr_ident, - bool in_list = false) { + const std::function)>& + current_element_assigner, + c10::optional cond, + Value* max_trip_count_val = nullptr) { + Value* cond_val = nullptr; Node* n = graph->insertNode(create(prim::Loop, range, 0)); - Value *max_trip_count_val, *cond_val; + WithInsertPoint guard(n); + + if (!max_trip_count_val) { - WithInsertPoint guard(n); - if (max_trip_count) { - if (in_list) { - auto listArg = emitExpr(max_trip_count.value()); - - max_trip_count_val = emitBuiltinCall( - max_trip_count->range(), - *graph, - aten::len, - c10::nullopt, - {listArg}, - {}, - /*required=*/true); - } else { - max_trip_count_val = ensureInt( - max_trip_count->range(), emitExpr(max_trip_count.value())); - } - } else { - max_trip_count_val = materializeConstant( - std::numeric_limits::max(), - *graph, - range, - integral_constants); - } - if (cond) { - cond_val = emitCond(cond.value()); - } else { - cond_val = graph->insertConstant(true, nullptr, range); - } + max_trip_count_val = materializeConstant( + std::numeric_limits::max(), + *graph, + range, + integral_constants); } + + cond_val = (cond) ? emitCond(cond.value()) + : graph->insertConstant(true, nullptr, range); n->addInput(max_trip_count_val); n->addInput(cond_val); auto* body_block = n->addBlock(); @@ -1348,33 +1330,20 @@ struct to_ir { { pushFrame(body_block); WithInsertPoint guard(body_block); - if (itr_ident) { - if (in_list) { - // set user's iterator variable to the current element - auto listArg = emitExpr(max_trip_count.value()); - trip_count = emitBuiltinCall( - max_trip_count->range(), - *graph, - aten::select, - c10::nullopt, - {listArg, trip_count}, - {}, - /*required=*/true); - } - environment_stack->setVar( - itr_ident->range(), itr_ident->name(), trip_count); + + // current_element_assigner uses an induction variable + // to set a current element + if (current_element_assigner) + { + current_element_assigner(trip_count, environment_stack); } + emitStatements(body); // Also emit the conditional - if (cond) { - Value* body_cond_value = emitCond(cond.value()); - body_block->registerOutput(body_cond_value); - } else { - Value* cond_value_dummy = graph->insertConstant(true, nullptr, range); - body_block->registerOutput(cond_value_dummy); - } - + cond_val = (cond) ? emitCond(cond.value()) + : graph->insertConstant(true, nullptr, range); + body_block->registerOutput(cond_val); auto body_frame = popFrame(); auto outer_frame = environment_stack; @@ -1411,7 +1380,46 @@ struct to_ir { throw ErrorReport(range) << "range() expects 1 argument but got " << args.size(); } - emitLoopCommon(range, {args[0]}, {}, body, target); + auto max_trip_count_val = ensureInt(range, emitExpr(args[0])); + const auto& ident_name = target.name(); + auto assigner = [ident_name, range](Value* index, std::shared_ptr env) { + env->setVar(range, ident_name, index); + }; + emitLoopCommon(range, body, assigner, {}, max_trip_count_val); + } + + void emitForInListLoop( + const For& stmt, + const std::shared_ptr& siv) { + auto targets = stmt.targets(); + auto itrs = stmt.itrs(); + auto body = stmt.body(); + auto& range = stmt.range(); + auto target = Var(targets[0]).name(); + + auto listArg = siv->asValue(range, method); + auto max_trip_count_val = emitBuiltinCall( + range, + *graph, + aten::len, + c10::nullopt, + {listArg}, + {}, + /*required=*/true); + const auto& ident_name = target.name(); + auto assigner = [ident_name, range, listArg, this]( + Value* index, std::shared_ptr env) { + auto cur_elm = emitBuiltinCall( + range, + *this->graph, + aten::select, + c10::nullopt, + {listArg, index}, + {}, + /*required=*/true); + env->setVar(range, ident_name, cur_elm); + }; + emitLoopCommon(range, body, assigner, {}, max_trip_count_val); } void emitFor(const For& stmt) { @@ -1454,8 +1462,8 @@ struct to_ir { // check if a value is simple and list-like if (auto siv = std::dynamic_pointer_cast(sv)) { if (siv->getValue()->type()->kind() == TypeKind::ListType) { - return emitLoopCommon( - stmt.range(), {itrs[0]}, {}, body, {target}, true); + emitForInListLoop(stmt, siv); + return; } } auto instances = sv->asTuple(stmt.range(), method); @@ -1477,7 +1485,7 @@ struct to_ir { void emitWhile(const While& stmt) { auto cond = stmt.cond(); - emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {}); + emitLoopCommon(stmt.range(), stmt.body(), nullptr, cond, nullptr); } // Currently we do not support assigning exceptions to variables, -- 2.7.4