[mlir] Fix memref_cast + subview folder when reducing rank
authorThomas Raoux <thomasraoux@google.com>
Tue, 16 Feb 2021 19:03:58 +0000 (11:03 -0800)
committerThomas Raoux <thomasraoux@google.com>
Tue, 16 Feb 2021 20:00:59 +0000 (12:00 -0800)
When the destination of the subview has a lower rank than its source we need to
fix the result type of the new subview op.

Differential Revision: https://reviews.llvm.org/D96804

mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir

index 4908291..5582c0b 100644 (file)
@@ -3058,7 +3058,23 @@ isRankReducedType(Type originalType, Type candidateReducedType,
     candidateLayout = getStridedLinearLayoutMap(candidateReduced);
   else
     candidateLayout = candidateReduced.getAffineMaps().front();
-  if (inferredType != candidateLayout) {
+  assert(inferredType.getNumResults() == 1 &&
+         candidateLayout.getNumResults() == 1);
+  if (inferredType.getNumSymbols() != candidateLayout.getNumSymbols() ||
+      inferredType.getNumDims() != candidateLayout.getNumDims()) {
+    if (errMsg) {
+      llvm::raw_string_ostream os(*errMsg);
+      os << "inferred type: " << inferredType;
+    }
+    return SubViewVerificationResult::AffineMapMismatch;
+  }
+  // Check that the difference of the affine maps simplifies to 0.
+  AffineExpr diffExpr =
+      inferredType.getResult(0) - candidateLayout.getResult(0);
+  diffExpr = simplifyAffineExpr(diffExpr, inferredType.getNumDims(),
+                                inferredType.getNumSymbols());
+  auto cst = diffExpr.dyn_cast<AffineConstantExpr>();
+  if (!(cst && cst.getValue() == 0)) {
     if (errMsg) {
       llvm::raw_string_ostream os(*errMsg);
       os << "inferred type: " << inferredType;
@@ -3344,11 +3360,29 @@ public:
     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
     /// the cast source operand type and the SubViewOp static information. This
     /// is the resulting type if the MemRefCastOp were folded.
-    Type resultType = SubViewOp::inferResultType(
-        castOp.source().getType().cast<MemRefType>(),
-        extractFromI64ArrayAttr(subViewOp.static_offsets()),
-        extractFromI64ArrayAttr(subViewOp.static_sizes()),
-        extractFromI64ArrayAttr(subViewOp.static_strides()));
+    auto resultType = SubViewOp::inferResultType(
+                          castOp.source().getType().cast<MemRefType>(),
+                          extractFromI64ArrayAttr(subViewOp.static_offsets()),
+                          extractFromI64ArrayAttr(subViewOp.static_sizes()),
+                          extractFromI64ArrayAttr(subViewOp.static_strides()))
+                          .cast<MemRefType>();
+    uint32_t rankDiff =
+        subViewOp.getSourceType().getRank() - subViewOp.getType().getRank();
+    if (rankDiff > 0) {
+      auto shape = resultType.getShape();
+      auto projectedShape = shape.drop_front(rankDiff);
+      AffineMap map;
+      auto maps = resultType.getAffineMaps();
+      if (!maps.empty() && maps.front()) {
+        auto optionalUnusedDimsMask =
+            computeRankReductionMask(shape, projectedShape);
+        llvm::SmallDenseSet<unsigned> dimsToProject =
+            optionalUnusedDimsMask.getValue();
+        map = getProjectedMap(maps.front(), dimsToProject);
+      }
+      resultType = MemRefType::get(projectedShape, resultType.getElementType(),
+                                   map, resultType.getMemorySpace());
+    }
     Value newSubView = rewriter.create<SubViewOp>(
         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
index 7b54938..c864af8 100644 (file)
@@ -143,3 +143,17 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
   %1 = tensor_to_memref %0 : memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
+
+// CHECK-LABEL: func @subview_of_memcast
+//  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
+//       CHECK:   %[[S:.+]] = subview %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   %[[M:.+]] = memref_cast %[[S]] : memref<16x32xi8, #{{.*}}> to memref<16x32xi8, #{{.*}}>
+//       CHECK:   return %[[M]] : memref<16x32xi8, #{{.*}}>
+func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
+  memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>{
+  %0 = memref_cast %arg : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = subview %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] :
+    memref<?x?x16x32xi8> to
+    memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+  return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
+}