/// Detensorize linalg ops involved in control-flow within a function.
///
- /// This model starts from CondBranchOps within a function. For each cond_br,
- /// the model then walks the use-def chain for the branch's condition
- /// backwards in order to understand where the condition's value comes from.
- /// If the condition value is (indirectly) computed by a linalg op that can be
- /// detensored, the model then continues walking the use-def chain in order to
- /// understand where the linalg op's operands come from. This leads to
- /// discovering a "detensoring component". A detensoring component is the set
- /// of operations + block arguments that are involved in control-flow AND can
- /// be detensored.
- ///
- /// For examples where this model succeeds to discover a detensoring
- /// component, see:
- /// - test/Dialect/Linalg/detensorize_while.mlir
- /// - test/Dialect/Linalg/detesorize_while_pure_cf.mlir.
- ///
- /// For an example where this model marks control-flow as "non-detensorable",
- /// see:
- /// - test/Dialect/Linalg/detensorize_while_failure.mlir
- class PureControlFlowDetectionModel : public CostModel {
+ /// This model starts from BranchOps and CondBranchOps within a function. For
+ /// each such branch, the model then walks the use-def chain for the branch's
+ /// condition backwards in order to understand where the condition's value
+ /// comes from. If the condition value is (indirectly) computed by a linalg op
+ /// that can be detensored, the model then continues walking the use-def chain
+ /// in order to understand where the linalg op's operands come from. This
+ /// leads to discovering a "detensoring component". A detensoring component is
+ /// the set of operations + block arguments that are involved in control-flow
+ /// AND can be detensored.
+ class ControlFlowDetectionModel : public CostModel {
public:
void compute(FuncOp func, DetensorizeTypeConverter typeConverter,
DenseSet<Operation *> &opsToDetensor,
for (PredecessorIterator pred = ownerBlock->pred_begin();
pred != ownerBlock->pred_end(); ++pred) {
- BranchOpInterface terminator =
+ BranchOpInterface predTerminator =
dyn_cast<BranchOpInterface>((*pred)->getTerminator());
// TODO: For now, we give up if any of the control-flow components
// in a function is not detensorable. Fix that.
- if (!terminator) {
+ if (!predTerminator) {
opsToDetensor.clear();
blockArgsToDetensor.clear();
return;
}
auto ownerBlockOperands =
- terminator.getSuccessorOperands(pred.getSuccessorIndex());
+ predTerminator.getSuccessorOperands(pred.getSuccessorIndex());
if (!ownerBlockOperands || ownerBlockOperands->empty())
continue;
if (opsToDetensor.count(genericOp))
continue;
- // TODO: For now, we give up if any of the control-flow components
- // in a function is not detensorable. Fix that.
+ // The op should not be detensored, give up on it but continue with
+ // discovering the rest of the control-flow component.
if (!shouldBeDetensored(genericOp, typeConverter)) {
- opsToDetensor.clear();
- blockArgsToDetensor.clear();
- return;
+ continue;
}
opsToDetensor.insert(genericOp);
for (Value scalarOpOperand : currentItemDefiningOp->getOperands())
workList.push_back(scalarOpOperand);
}
+
+ // Since the cost model gives up on some ops (see the details of step 2.2
+ // above), block arguments that correspond to the values produced by those
+ // ops should not be detensored as well.
+
+ DenseSet<BlockArgument> blockArgsToRemove;
+
+ for (auto &blockArg : blockArgsToDetensor) {
+ Block *block = blockArg.getParentBlock();
+
+ // For the potentially detensorable block argument, find the
+ // correpsonding operands in predecessor blocks.
+ for (PredecessorIterator pred = block->pred_begin();
+ pred != block->pred_end(); ++pred) {
+ BranchOpInterface terminator =
+ dyn_cast<BranchOpInterface>((*pred)->getTerminator());
+ auto blockOperands =
+ terminator.getSuccessorOperands(pred.getSuccessorIndex());
+
+ if (!blockOperands || blockOperands->empty())
+ continue;
+
+ Operation *definingOp =
+ terminator
+ ->getOperand(blockOperands->getBeginOperandIndex() +
+ blockArg.getArgNumber())
+ .getDefiningOp();
+
+ // If the operand is defined by a GenericOp that will not be
+ // detensored, then do not detensor the corresponding block argument.
+ if (dyn_cast_or_null<GenericOp>(definingOp) &&
+ opsToDetensor.count(definingOp) == 0) {
+ blockArgsToRemove.insert(blockArg);
+ break;
+ }
+ }
+ }
+
+ for (auto &blockArg : blockArgsToRemove) {
+ blockArgsToDetensor.erase(blockArg);
+ }
}
};
blockArgsToDetensor);
} else {
- PureControlFlowDetectionModel costModel;
+ ControlFlowDetectionModel costModel;
costModel.compute(getFunction(), typeConverter, opsToDetensor,
blockArgsToDetensor);
}
// DET-ALL: return %{{.*}} : tensor<i32>
// DET-ALL: }
-// Try to detensor pure control-flow. However, that fails since the potential
-// detensorable component contains some ops that cannot be detensored.
-//
// DET-CF-LABEL: func @main
// DET-CF-SAME: (%{{.*}}: tensor<10xi32>, %{{.*}}: tensor<i32>)
// DET-CF: br ^[[bb1:.*]](%{{.*}} : tensor<10xi32>)
// DET-CF: ^bb1(%{{.*}}: tensor<10xi32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor<i32>) {
-// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>) outs(%{{.*}} : tensor<i1>) {
+// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-CF: tensor.extract %{{.*}}[] : tensor<i32>
+// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32
// DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor<i32>), ^bb3(%{{.*}} : tensor<i32>)
// DET-CF: ^bb2(%{{.*}}: tensor<i32>)
// DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {