return {};
}
+namespace {
+/// Replace tensor_cast + tensor_to_memref by tensor_to_memref + memref_cast.
+///
+/// When the operand of a tensor_to_memref is produced by a tensor.cast from a
+/// ranked tensor, materialize the memref of the (more static) cast-source type
+/// first and then memref_cast it to the originally requested memref type.
+struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
+ using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TensorToMemrefOp op,
+ PatternRewriter &rewriter) const final {
+ // Bail out unless the operand is defined by a tensor.cast.
+ auto castOp = op.getOperand().getDefiningOp<tensor::CastOp>();
+ if (!castOp)
+ return failure();
+ // The cast source must be ranked so its shape can form a memref type.
+ Value castSource = castOp.getOperand();
+ auto rankedSrcType = castSource.getType().dyn_cast<RankedTensorType>();
+ if (!rankedSrcType)
+ return failure();
+ // Materialize a memref of the source tensor type, then cast it back to
+ // the (possibly more dynamic) result type expected by existing users.
+ auto staticMemrefType = MemRefType::get(rankedSrcType.getShape(),
+ rankedSrcType.getElementType());
+ Value staticMemref = rewriter.create<TensorToMemrefOp>(
+ op.getLoc(), staticMemrefType, castSource);
+ rewriter.replaceOpWithNewOp<MemRefCastOp>(op, op.getType(), staticMemref);
+ return success();
+ }
+};
+} // namespace
+
+/// Registers canonicalization patterns for tensor_to_memref; currently only
+/// the tensor_cast folding pattern defined above.
+void TensorToMemrefOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<TensorCastToMemref>(context);
+}
+
//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
%2 = dim %0, %c1 : tensor<?x?xf32>
return %1, %2: index, index
}
+
+// Regression test for the TensorCastToMemref canonicalization: the
+// tensor.cast below should be folded away, leaving a tensor_to_memref of the
+// static source type followed by a memref_cast to the dynamic result type.
+// CHECK-LABEL: func @tensor_cast_to_memref
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref_cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+ memref<?x?x16x32xi8> {
+ %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+ %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
+ return %1 : memref<?x?x16x32xi8>
+}