if (op.elseBlock())
llvm::append_range(elseYield, op.elseYield().getOperands());
+ // A list of indices for which we should upgrade the value yielded
+ // in the else to a select.
+ SmallVector<unsigned> elseYieldsToUpgradeToSelect;
+
// If the outer scf.if yields a value produced by the inner scf.if,
// only permit combining if the value yielded when the condition
// is false in the outer scf.if is the same value yielded when the
// If the correctness test passes, we will yield
// corresponding value from the inner scf.if
thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+ continue;
+ }
+
+ // Otherwise, we need to ensure the else block of the combined
+ // condition still returns the same value when the outer condition is
+ // true and the inner condition is false. This can be accomplished if
+ // the then value is defined outside the outer scf.if and we replace the
+ // value with a select that considers just the outer condition. Since
+ // the else region contains just the yield, its yielded value is
+ // defined outside the scf.if, by definition.
+
+ // If the then value is defined within the scf.if, bail.
+ if (tup.value().getParentRegion() == &op.getThenRegion()) {
+ return failure();
+ } else {
+ elseYieldsToUpgradeToSelect.push_back(tup.index());
}
}
Value newCondition = rewriter.create<arith::AndIOp>(
loc, op.getCondition(), nestedIf.getCondition());
auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition);
+
+ SmallVector<Value> results;
+ llvm::append_range(results, newIf.getResults());
+ rewriter.setInsertionPoint(newIf);
+
+ for (auto idx : elseYieldsToUpgradeToSelect)
+ results[idx] = rewriter.create<arith::SelectOp>(
+ op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]);
+
Block *newIfBlock = newIf.thenBlock();
if (newIfBlock)
rewriter.eraseOp(newIfBlock->getTerminator());
rewriter.setInsertionPointToEnd(newIf.elseBlock());
rewriter.create<YieldOp>(loc, elseYield);
}
- rewriter.replaceOp(op, newIf.getResults());
+ rewriter.replaceOp(op, results);
return success();
}
};
// CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32
// CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
-// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]]
+// CHECK: %[[RES:.*]] = arith.select %[[ARG0]], %[[PRE0]], %[[PRE1]]
// CHECK: scf.if %[[COND]]
// CHECK: "test.run"() : () -> ()
// CHECK: }
}
return %r#0, %r#1, %r#2, %r#3 : i32, f32, i32, i8
}
+
// -----
// CHECK-LABEL: func @if_condition_swap