[mlir][MemRef] Fix canonicalization of BufferCast(TensorLoad).
authorAdrian Kuegel <akuegel@google.com>
Thu, 5 Aug 2021 09:30:09 +0000 (11:30 +0200)
committerAdrian Kuegel <akuegel@google.com>
Fri, 6 Aug 2021 06:32:35 +0000 (08:32 +0200)
CastOp::areCastCompatible does not check whether casts are definitely compatible.
When going from dynamic to static offset or stride, the canonicalization cannot
know whether it is really cast compatible. In that case, it can only canonicalize
to an alloc plus copy.

Differential Revision: https://reviews.llvm.org/D107545

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir

index 2d2d9ae..f3e44f2 100644 (file)
@@ -319,12 +319,54 @@ struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
     // types. `BufferCastOp::fold` handles the same type case.
     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
       return failure();
-    // If types are not cast-compatible, bail.
+    // If types are definitely not cast-compatible, bail.
     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                    bufferCast.getType()))
       return failure();
-    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
-                                        tensorLoad.memref());
+
+    // We already know that the types are potentially cast-compatible. However
+    // in case the affine maps are different, we may need to use a copy if we go
+    // from dynamic to static offset or stride (the canonicalization cannot know
+    // at this point that it is really cast compatible).
+    auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
+      int64_t sourceOffset, targetOffset;
+      SmallVector<int64_t, 4> sourceStrides, targetStrides;
+      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
+          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
+        return false;
+      auto dynamicToStatic = [](int64_t a, int64_t b) {
+        return a == MemRefType::getDynamicStrideOrOffset() &&
+               b != MemRefType::getDynamicStrideOrOffset();
+      };
+      if (dynamicToStatic(sourceOffset, targetOffset))
+        return false;
+      for (auto it : zip(sourceStrides, targetStrides))
+        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
+          return false;
+      return true;
+    };
+
+    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
+    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
+    if (tensorLoadType && bufferCastType &&
+        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
+      MemRefType resultType = bufferCastType;
+      auto loc = bufferCast.getLoc();
+      SmallVector<Value, 4> dynamicOperands;
+      for (int i = 0; i < resultType.getRank(); ++i) {
+        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
+          continue;
+        auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
+        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
+        dynamicOperands.push_back(size);
+      }
+      auto copy =
+          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
+      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
+      rewriter.replaceOp(bufferCast, {copy});
+    } else
+      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
+                                          tensorLoad.memref());
     return success();
   }
 };
index 886f21d..d73bb13 100644 (file)
@@ -46,16 +46,18 @@ func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf3
 // CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
 // CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
 
-// Test case: If the memrefs are cast-compatible, canonicalize.
+// Test case: If the memrefs are definitely cast-compatible, canonicalize to
+//            cast.
 // CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load(
 //  CHECK-SAME:   %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>)
-//  CHEKC-SAME:     -> memref<?xf32, #[[$OFF_UNK]]> {
+//  CHECK-SAME:     -> memref<?xf32, #[[$OFF_UNK]]> {
 //   CHECK-NOT: memref.tensor_load
 //   CHECK-NOT: memref.buffer_cast
 //       CHECK: %[[R:.*]] = memref.cast %[[M]]
 //  CHECK-SAME:   memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
 //       CHECK: return %[[R]]
-func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
+func @canonicalize_buffer_cast_of_tensor_load(
+  %arg0: memref<?xf32, offset: 3, strides: [1]>)
   -> memref<?xf32, offset: ?, strides: [1]>
 {
   %0 = memref.tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
@@ -65,6 +67,33 @@ func @canonicalize_buffer_cast_of_tensor_load(%arg0: memref<?xf32, offset: 3, st
 
 // -----
 
+// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
+
+// Test case: If the memrefs are potentially cast-compatible, canonicalize to
+//            copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_to_copy(
+//  CHECK-SAME:   %[[M:.*]]: memref<?xf32, #[[$OFF_UNK]]>)
+//  CHECK-SAME:     -> memref<?xf32, #[[$OFF_3]]> {
+//   CHECK-NOT: memref.tensor_load
+//   CHECK-NOT: memref.buffer_cast
+//       CHECK: %[[C0:.*]] = constant 0 : index
+//       CHECK: %[[DIM:.*]] = memref.dim %[[M]], %[[C0]] : memref<?xf32, #[[$OFF_UNK]]>
+//       CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, #[[$OFF_3]]>
+//       CHECK: memref.copy %[[M]], %[[ALLOC]]
+//  CHECK-SAME:   memref<?xf32, #[[$OFF_UNK]]> to memref<?xf32, #[[$OFF_3]]>
+//       CHECK: return %[[ALLOC]]
+func @canonicalize_buffer_cast_of_tensor_load_to_copy(
+  %arg0: memref<?xf32, offset: ?, strides: [1]>)
+  -> memref<?xf32, offset: 3, strides: [1]>
+{
+  %0 = memref.tensor_load %arg0 : memref<?xf32, offset: ?, strides: [1]>
+  %1 = memref.buffer_cast %0 : memref<?xf32, offset: 3, strides: [1]>
+  return %1 : memref<?xf32, offset: 3, strides: [1]>
+}
+
+// -----
+
 // CHECK-LABEL: func @subview_of_memcast
 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
 //       CHECK:   %[[S:.+]] = memref.subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>