void emitLoopCommon(
SourceRange range,
- c10::optional<Expr> max_trip_count,
- c10::optional<Expr> cond,
const List<Stmt>& body,
- c10::optional<Ident> itr_ident,
- bool in_list = false) {
+ const std::function<void(Value*, std::shared_ptr<Environment>)>&
+ current_element_assigner,
+ c10::optional<Expr> cond,
+ Value* max_trip_count_val = nullptr) {
+ Value* cond_val = nullptr;
Node* n = graph->insertNode(create(prim::Loop, range, 0));
- Value *max_trip_count_val, *cond_val;
+ WithInsertPoint guard(n);
+
+ if (!max_trip_count_val)
{
- WithInsertPoint guard(n);
- if (max_trip_count) {
- if (in_list) {
- auto listArg = emitExpr(max_trip_count.value());
-
- max_trip_count_val = emitBuiltinCall(
- max_trip_count->range(),
- *graph,
- aten::len,
- c10::nullopt,
- {listArg},
- {},
- /*required=*/true);
- } else {
- max_trip_count_val = ensureInt(
- max_trip_count->range(), emitExpr(max_trip_count.value()));
- }
- } else {
- max_trip_count_val = materializeConstant(
- std::numeric_limits<int64_t>::max(),
- *graph,
- range,
- integral_constants);
- }
- if (cond) {
- cond_val = emitCond(cond.value());
- } else {
- cond_val = graph->insertConstant(true, nullptr, range);
- }
+ max_trip_count_val = materializeConstant(
+ std::numeric_limits<int64_t>::max(),
+ *graph,
+ range,
+ integral_constants);
}
+
+ cond_val = (cond) ? emitCond(cond.value())
+ : graph->insertConstant(true, nullptr, range);
n->addInput(max_trip_count_val);
n->addInput(cond_val);
auto* body_block = n->addBlock();
{
pushFrame(body_block);
WithInsertPoint guard(body_block);
- if (itr_ident) {
- if (in_list) {
- // set user's iterator variable to the current element
- auto listArg = emitExpr(max_trip_count.value());
- trip_count = emitBuiltinCall(
- max_trip_count->range(),
- *graph,
- aten::select,
- c10::nullopt,
- {listArg, trip_count},
- {},
- /*required=*/true);
- }
- environment_stack->setVar(
- itr_ident->range(), itr_ident->name(), trip_count);
+
+ // current_element_assigner uses an induction variable
+ // to set a current element
+ if (current_element_assigner)
+ {
+ current_element_assigner(trip_count, environment_stack);
}
+
emitStatements(body);
// Also emit the conditional
- if (cond) {
- Value* body_cond_value = emitCond(cond.value());
- body_block->registerOutput(body_cond_value);
- } else {
- Value* cond_value_dummy = graph->insertConstant(true, nullptr, range);
- body_block->registerOutput(cond_value_dummy);
- }
-
+ cond_val = (cond) ? emitCond(cond.value())
+ : graph->insertConstant(true, nullptr, range);
+ body_block->registerOutput(cond_val);
auto body_frame = popFrame();
auto outer_frame = environment_stack;
throw ErrorReport(range)
<< "range() expects 1 argument but got " << args.size();
}
- emitLoopCommon(range, {args[0]}, {}, body, target);
+ auto max_trip_count_val = ensureInt(range, emitExpr(args[0]));
+ const auto& ident_name = target.name();
+ auto assigner = [ident_name, range](Value* index, std::shared_ptr<Environment> env) {
+ env->setVar(range, ident_name, index);
+ };
+ emitLoopCommon(range, body, assigner, {}, max_trip_count_val);
+ }
+
+ void emitForInListLoop(
+ const For& stmt,
+ const std::shared_ptr<torch::jit::script::SimpleValue>& siv) {
+ auto targets = stmt.targets();
+ auto itrs = stmt.itrs();
+ auto body = stmt.body();
+ auto& range = stmt.range();
+ auto target = Var(targets[0]).name();
+
+ auto listArg = siv->asValue(range, method);
+ auto max_trip_count_val = emitBuiltinCall(
+ range,
+ *graph,
+ aten::len,
+ c10::nullopt,
+ {listArg},
+ {},
+ /*required=*/true);
+ const auto& ident_name = target.name();
+ auto assigner = [ident_name, range, listArg, this](
+ Value* index, std::shared_ptr<Environment> env) {
+ auto cur_elm = emitBuiltinCall(
+ range,
+ *this->graph,
+ aten::select,
+ c10::nullopt,
+ {listArg, index},
+ {},
+ /*required=*/true);
+ env->setVar(range, ident_name, cur_elm);
+ };
+ emitLoopCommon(range, body, assigner, {}, max_trip_count_val);
}
void emitFor(const For& stmt) {
// check if a value is simple and list-like
if (auto siv = std::dynamic_pointer_cast<SimpleValue>(sv)) {
if (siv->getValue()->type()->kind() == TypeKind::ListType) {
- return emitLoopCommon(
- stmt.range(), {itrs[0]}, {}, body, {target}, true);
+ emitForInListLoop(stmt, siv);
+ return;
}
}
auto instances = sv->asTuple(stmt.range(), method);
void emitWhile(const While& stmt) {
auto cond = stmt.cond();
- emitLoopCommon(stmt.range(), {}, {cond}, stmt.body(), {});
+ emitLoopCommon(stmt.range(), stmt.body(), nullptr, cond, nullptr);
}
// Currently we do not support assigning exceptions to variables,