/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
  // Verify the result type against the type inferred from the source type and
  // the mixed offsets/sizes/strides. inferResultType already returns a
  // RankedTensorType, so no cast to ShapedType is needed before the
  // rank-reduction compatibility check.
  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
  SliceVerificationResult result = isRankReducedType(expectedType, getType());
  return produceSliceErrorMsg(result, *this, expectedType);
}
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
+/// Rank-reducing type verification for both InsertSliceOp and
+/// ParallelInsertSliceOp.
static SliceVerificationResult
verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
ArrayAttr staticOffsets, ArrayAttr staticSizes,
ArrayAttr staticStrides,
ShapedType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, use the same type inference.
- auto expected = ExtractSliceOp::inferResultType(
- dstType, extractFromI64ArrayAttr(staticOffsets),
- extractFromI64ArrayAttr(staticSizes),
- extractFromI64ArrayAttr(staticStrides))
- .cast<ShapedType>();
+ RankedTensorType expected = ExtractSliceOp::inferResultType(
+ dstType, extractFromI64ArrayAttr(staticOffsets),
+ extractFromI64ArrayAttr(staticSizes),
+ extractFromI64ArrayAttr(staticStrides));
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
  // Compare the source type against the type inferred from the destination
  // type and the static offsets/sizes/strides, allowing rank reduction.
  ShapedType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  return produceSliceErrorMsg(result, *this, expectedType);
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
+/// This works similarly when the second op is a ParallelInsertSliceOp.
///
/// Example:
///
/// ```mlir
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
/// ```
-static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
- auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+static LogicalResult foldInsertAfterInsertSlice(InsertOpTy insertOp) {
+ auto prevInsertOp = insertOp.getDest().template getDefiningOp<InsertOpTy>();
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!prevInsertOp ||
return success();
}
-OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
- if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
- getSourceType() == getType() &&
- succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
- return this->getSource();
- if (succeeded(foldInsertAfterInsertSlice(*this)))
- return getResult();
- return OpFoldResult();
+/// Same logic for folding InsertSliceOp and ParallelInsertSliceOp, the return
+/// type varies though so we wrap it in a FailureOr.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+FailureOr<OpFoldResult> foldInsertOp(InsertOpTy insertOp, ArrayRef<Attribute>) {
+ if (insertOp.getSourceType().hasStaticShape() &&
+ insertOp.getDestType().hasStaticShape() &&
+ insertOp.getSourceType() == insertOp.getDestType() &&
+ succeeded(foldIdentityOffsetSizeAndStrideOpInterface(
+ insertOp, insertOp.getDestType())))
+ return static_cast<OpFoldResult>(insertOp.getSource());
+ if (succeeded(foldInsertAfterInsertSlice(insertOp))) {
+ // InsertSliceOp has 1 result but ParallelInsertSliceOp has none and should
+ // return OpFoldResult().
+ if (std::is_same<InsertOpTy, InsertSliceOp>::value)
+ return static_cast<OpFoldResult>(insertOp->getResult(0));
+ else
+ return OpFoldResult();
+ }
+ return failure();
+}
+
+OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute> operands) {
+ auto maybeOpFoldResult = foldInsertOp(*this, operands);
+ return failed(maybeOpFoldResult) ? OpFoldResult() : *maybeOpFoldResult;
}
LogicalResult InsertSliceOp::reifyResultShapes(
namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
-    : public OpRewritePattern<InsertSliceOp> {
+    : public OpRewritePattern<InsertOpTy> {
public:
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    // No constant operand, just return.
    if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
-        insertSliceOp.getSourceType().getRank(), insertSliceOp.getType(),
+        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
-    if (sourceType != insertSliceOp.getSourceType())
+    if (sourceType != insertSliceOp.getSourceType()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+      // that the insertion point is just before the ParallelCombiningOp in the
+      // parallel case.
+      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                sourceType, toInsert);
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+    }
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result did
/// not change.
-struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+///
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
-
-    auto srcType = src.getType().cast<ShapedType>();
-    auto dstType = dst.getType().cast<ShapedType>();
+    auto srcType = src.getType().template cast<ShapedType>();
+    auto dstType = dst.getType().template cast<ShapedType>();
    // Bail out unless the cast-stripped source/dest types still satisfy the
    // insert_slice rank-reduction invariants.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            insertSliceOp.getStaticSizes(),
                            insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
-    Value replacement = rewriter.create<InsertSliceOp>(
+    Operation *replacement = rewriter.create<InsertOpTy>(
        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-    if (replacement.getType() != insertSliceOp.getType()) {
-      replacement = rewriter.create<tensor::CastOp>(
-          insertSliceOp.getLoc(), insertSliceOp.getType(), replacement);
+    // In the parallel case there is no result and so nothing to cast.
+    bool isParallelInsert =
+        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
+    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
+      replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+                                                    insertSliceOp.getDestType(),
+                                                    replacement->getResult(0));
    }
-    rewriter.replaceOp(insertSliceOp, replacement);
+    // Replace with all of `replacement`'s results (none in the parallel case).
+    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
/// : tensor<64x64xf32> into ...
/// ```
+///
+/// This patterns works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
- : public OpRewritePattern<InsertSliceOp> {
- using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+ : public OpRewritePattern<InsertOpTy> {
+ using OpRewritePattern<InsertOpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+ LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType srcType = insertSliceOp.getSourceType();
- if (srcType.getRank() != insertSliceOp.getType().getRank())
+ if (srcType.getRank() != insertSliceOp.getDestType().getRank())
return failure();
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
srcType.getShape().end());
// 2) "More static" than srcType.
// 3) Cast-compatible with srcType.
// Insert the cast.
+ OpBuilder::InsertionGuard g(rewriter);
+ // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+ // the the insertion point is just before the ParallelCombiningOp in the
+ // parallel case.
+ if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = rewriter.create<tensor::CastOp>(
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
- rewriter.replaceOpWithNewOp<InsertSliceOp>(
+ rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
+ cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
return success();
}
};
/// Register the InsertSliceOp instantiations of the shared insert-slice
/// canonicalization patterns.
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
- return success();
+
+ ShapedType expectedType;
+ SliceVerificationResult result =
+ verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
+ getStaticSizes(), getStaticStrides(), &expectedType);
+ return produceSliceErrorMsg(result, *this, expectedType);
}
namespace {
canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
// Create the new op in canonical form.
+ auto sourceType =
+ tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
+ insertSliceOp.getSourceType().getRank(),
+ insertSliceOp.getDestType(), mixedOffsets, mixedSizes,
+ mixedStrides);
+ Value toInsert = insertSliceOp.getSource();
+ if (sourceType != insertSliceOp.getSourceType()) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+ toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
+ sourceType, toInsert);
+ }
rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
- insertSliceOp, insertSliceOp.getSource(), insertSliceOp.getDest(),
- mixedOffsets, mixedSizes, mixedStrides);
+ insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
+ mixedSizes, mixedStrides);
return success();
}
};
} // namespace
-/// Fold a parallel_insert_slice source coming from a tensor.cast op.
-///
-/// Example:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-/// %1 = compute_some_tensor() : tensor<64xf32>
-/// %2 = tensor.cast %1 : tensor<64xf32> to tensor<?xf32>
-/// scf.foreach_thread.perform_concurrently {
-/// scf.foreach_thread.parallel_insert_slice %2 into %out[...] [64] [1] :
-/// tensor<?xf32> into tensor<128xf32>
-/// }
-/// }
-/// ```
-///
-/// is folded into:
-/// ```
-/// %0 = scf.foreach_thread (%arg0) in (%c2) -> (tensor<128xf32>) {
-/// %1 = compute_some_tensor() : tensor<64xf32>
-/// scf.foreach_thread.perform_concurrently {
-/// scf.foreach_thread.parallel_insert_slice %1 into %out[...] [64] [1] :
-/// tensor<64xf32> into tensor<128xf32>
-/// }
-/// }
-/// ```
LogicalResult
ParallelInsertSliceOp::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
- auto sourceCast = getSource().getDefiningOp<tensor::CastOp>();
- if (!sourceCast)
- return failure();
- getSourceMutable().assign(sourceCast.getSource());
- return success();
+ return foldInsertOp(*this, operands);
}
/// Register the ParallelInsertSliceOp instantiations of the shared
/// insert-slice canonicalization patterns.
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
//===----------------------------------------------------------------------===//