[mlir][bufferize] Insert memref.cast ops during finalizing pass
authorMatthias Springer <springerm@google.com>
Thu, 27 Jan 2022 10:04:44 +0000 (19:04 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 27 Jan 2022 10:06:53 +0000 (19:06 +0900)
The pass can currently not handle to_memref(to_tensor(x)) folding where a cast is necessary. This is required with the new unified bufferization. There is already a canonicalization pattern that handles such foldings and it should be used during this pass.

Differential Revision: https://reviews.llvm.org/D117988

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir

index 2cbfc90..2a1edf8 100644 (file)
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.h.inc"
 
+namespace mlir {
+class RewritePatternSet;
+class MLIRContext;
+
+namespace bufferization {
+/// Populate patterns for folding to_memref and to_tensor ops.
+/// Note: to_memref(to_tensor(x)) without type changes are handled by a folder.
+void populateBufferizationOpFoldingPatterns(RewritePatternSet &patterns,
+                                            MLIRContext *context);
+} // namespace bufferization
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect Operations
 //===----------------------------------------------------------------------===//
index f1ec7bb..28a8e5d 100644 (file)
@@ -240,7 +240,8 @@ static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
       if (resultType.getShape()[i] != ShapedType::kDynamicSize)
         continue;
       auto index = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
-      Value size = rewriter.create<tensor::DimOp>(loc, memrefToTensor, index);
+      Value size =
+          rewriter.create<memref::DimOp>(loc, memrefToTensor.memref(), index);
       dynamicOperands.push_back(size);
     }
     // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
@@ -309,6 +310,11 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
       context);
 }
 
+void bufferization::populateBufferizationOpFoldingPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<TensorLoadToMemref>(context);
+}
+
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
                                     const BufferizationState &state) {
   // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
index eb7f0de..f202a7a 100644 (file)
@@ -92,6 +92,7 @@ void mlir::bufferization::populateEliminateBufferizeMaterializationsPatterns(
     BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) {
   patterns.add<BufferizeToTensorOp, BufferizeToMemrefOp>(typeConverter,
                                                          patterns.getContext());
+  populateBufferizationOpFoldingPatterns(patterns, patterns.getContext());
 }
 
 namespace {
index fac685a..66e1ccf 100644 (file)
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
 
-// CHECK-LABEL:   func @eliminate_materializations(
-// CHECK-SAME:                                     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
-// CHECK:           return %[[ARG]] : memref<f32>
+// CHECK-LABEL: func @eliminate_materializations(
+//  CHECK-SAME:     %[[ARG:.*]]: memref<f32>) -> memref<f32> {
 func @eliminate_materializations(%arg0: memref<f32>) -> memref<f32> {
   %0 = bufferization.to_tensor %arg0 : memref<f32>
   %1 = bufferization.to_memref %0 : memref<f32>
+  // CHECK: return %[[ARG]] : memref<f32>
   return %1 : memref<f32>
 }
 
@@ -26,3 +26,37 @@ func @unable_to_convert_lone_tensor_load(%arg0: memref<f32>) {
   "test.sink"(%0) : (tensor<f32>) -> ()
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+#map1 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @insert_memref_cast(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32>
+func @insert_memref_cast(%arg0: memref<?xf32>) -> memref<?xf32, #map1> {
+  %0 = bufferization.to_tensor %arg0 : memref<?xf32>
+  %1 = bufferization.to_memref %0 : memref<?xf32, #map1>
+  // CHECK: %[[r:.*]] = memref.cast %[[arg0]] : memref<?xf32> to memref<?xf32, #[[$MAP1]]>
+  // CHECK: return %[[r]]
+  return %1 : memref<?xf32, #map1>
+}
+
+// -----
+
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK-LABEL: func @insert_buffer_copy(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32, #[[$MAP2]]>
+func @insert_buffer_copy(%arg0: memref<?xf32, #map2>) -> memref<?xf32> {
+  // CHECK: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]]
+  // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]]) : memref<?xf32>
+  // CHECK: memref.copy %[[arg0]], %[[alloc]]
+  %0 = bufferization.to_tensor %arg0 : memref<?xf32, #map2>
+  %1 = bufferization.to_memref %0 : memref<?xf32>
+
+  // CHECK: return %[[alloc]]
+  return %1 : memref<?xf32>
+}