[mlir] Fix subview verifier.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 28 Jan 2021 13:54:51 +0000 (13:54 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 28 Jan 2021 13:55:39 +0000 (13:55 +0000)
The subview verifier in the rank-reduced case is plainly skipping verification
when the resulting type is a memref with empty affine map. This is generally incorrect.

Instead, form the actual expected rank-reduced MemRefType that takes into account the projections of 1's dimensions. Then, check the canonicalized expected rank-reduced type against the canonicalized candidate type.

Differential Revision: https://reviews.llvm.org/D95316

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/IR/BuiltinTypes.cpp
mlir/test/IR/invalid-ops.mlir

index 5571b58..c59cbda 100644 (file)
@@ -212,7 +212,6 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   return LLVM::LLVMPointerType::get(converted);
 }
 
-
 // Function types are converted to LLVM Function types by recursively converting
 // argument and result types.  If MLIR Function has zero results, the LLVM
 // Function has one VoidType result.  If MLIR Function has more than one result,
@@ -525,10 +524,11 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
   auto result = getStridesAndOffset(type, strides, offset);
   (void)result;
   assert(succeeded(result) && "unexpected failure in stride computation");
-  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
+  assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
          "expected static offset");
-  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
-         "expected static strides");
+  assert(!llvm::any_of(strides, [](int64_t stride) {
+    return MemRefType::isDynamicStrideOrOffset(stride);
+  }) && "expected static strides");
 
   auto convertedType = typeConverter.convertType(type);
   assert(convertedType && "unexpected failure in memref type conversion");
@@ -1044,14 +1044,14 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
 
   Value index;
   if (offset != 0) // Skip if offset is zero.
-    index = offset == MemRefType::getDynamicStrideOrOffset()
+    index = MemRefType::isDynamicStrideOrOffset(offset)
                 ? memRefDescriptor.offset(rewriter, loc)
                 : createIndexConstant(rewriter, loc, offset);
 
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
-      Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
+      Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
                          ? memRefDescriptor.stride(rewriter, loc, i)
                          : createIndexConstant(rewriter, loc, strides[i]);
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
@@ -3308,7 +3308,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
         extracted);
     targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
 
-    // Copy the buffer pointer from the old descriptor to the new one.
+    // Copy the aligned pointer from the old descriptor to the new one.
     extracted = sourceMemRef.alignedPtr(rewriter, loc);
     bitcastPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
@@ -3487,7 +3487,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
                   ArrayRef<int64_t> strides, Value nextSize,
                   Value runningStride, unsigned idx) const {
     assert(idx < strides.size());
-    if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+    if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
       return createIndexConstant(rewriter, loc, strides[idx]);
     if (nextSize)
       return runningStride
index d5dac48..7226d89 100644 (file)
@@ -3078,29 +3078,32 @@ enum SubViewVerificationResult {
 /// This function is slight variant of `is subsequence` algorithm where
 /// not matching dimension must be 1.
 static SubViewVerificationResult isRankReducedType(Type originalType,
-                                                   Type reducedType) {
-  if (originalType == reducedType)
+                                                   Type candidateReducedType) {
+  if (originalType == candidateReducedType)
     return SubViewVerificationResult::Success;
   if (!originalType.isa<RankedTensorType>() && !originalType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
   if (originalType.isa<RankedTensorType>() &&
-      !reducedType.isa<RankedTensorType>())
+      !candidateReducedType.isa<RankedTensorType>())
     return SubViewVerificationResult::Success;
-  if (originalType.isa<MemRefType>() && !reducedType.isa<MemRefType>())
+  if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
     return SubViewVerificationResult::Success;
 
   ShapedType originalShapedType = originalType.cast<ShapedType>();
-  ShapedType reducedShapedType = reducedType.cast<ShapedType>();
+  ShapedType candidateReducedShapedType =
+      candidateReducedType.cast<ShapedType>();
 
   // Rank and size logic is valid for all ShapedTypes.
   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
-  ArrayRef<int64_t> reducedShape = reducedShapedType.getShape();
+  ArrayRef<int64_t> candidateReducedShape =
+      candidateReducedShapedType.getShape();
   unsigned originalRank = originalShape.size(),
-           reducedRank = reducedShape.size();
-  if (reducedRank > originalRank)
+           candidateReducedRank = candidateReducedShape.size();
+  if (candidateReducedRank > originalRank)
     return SubViewVerificationResult::RankTooLarge;
 
-  auto optionalMask = computeRankReductionMask(originalShape, reducedShape);
+  auto optionalMask =
+      computeRankReductionMask(originalShape, candidateReducedShape);
 
   // Sizes cannot be matched in case empty vector is returned.
   if (!optionalMask.hasValue())
@@ -3112,34 +3115,43 @@ static SubViewVerificationResult isRankReducedType(Type originalType,
 
   // Strided layout logic is relevant for MemRefType only.
   MemRefType original = originalType.cast<MemRefType>();
-  MemRefType reduced = reducedType.cast<MemRefType>();
+  MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
   MLIRContext *c = original.getContext();
-  int64_t originalOffset, reducedOffset;
-  SmallVector<int64_t, 4> originalStrides, reducedStrides, keepStrides;
+  int64_t originalOffset, candidateReducedOffset;
+  SmallVector<int64_t, 4> originalStrides, candidateReducedStrides, keepStrides;
   SmallVector<bool, 4> keepMask = optionalMask.getValue();
   getStridesAndOffset(original, originalStrides, originalOffset);
-  getStridesAndOffset(reduced, reducedStrides, reducedOffset);
+  getStridesAndOffset(candidateReduced, candidateReducedStrides,
+                      candidateReducedOffset);
 
   // Filter strides based on the mask and check that they are the same
-  // as reduced ones.
-  unsigned reducedIdx = 0;
+  // as candidateReduced ones.
+  unsigned candidateReducedIdx = 0;
   for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
     if (keepMask[originalIdx]) {
-      if (originalStrides[originalIdx] != reducedStrides[reducedIdx++])
+      if (originalStrides[originalIdx] !=
+          candidateReducedStrides[candidateReducedIdx++])
         return SubViewVerificationResult::StrideMismatch;
       keepStrides.push_back(originalStrides[originalIdx]);
     }
   }
 
-  if (original.getElementType() != reduced.getElementType())
+  if (original.getElementType() != candidateReduced.getElementType())
     return SubViewVerificationResult::ElemTypeMismatch;
 
-  if (original.getMemorySpace() != reduced.getMemorySpace())
+  if (original.getMemorySpace() != candidateReduced.getMemorySpace())
     return SubViewVerificationResult::MemSpaceMismatch;
 
+  // reducedMap is obtained by projecting away the dimensions inferred from
+  // matching the 1's positions in candidateReducedType.
   auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c);
-  if (!reduced.getAffineMaps().empty() &&
-      reducedMap != reduced.getAffineMaps().front())
+
+  MemRefType expectedReducedType = MemRefType::get(
+      candidateReduced.getShape(), candidateReduced.getElementType(),
+      reducedMap, candidateReduced.getMemorySpace());
+  expectedReducedType = canonicalizeStridedLayout(expectedReducedType);
+
+  if (expectedReducedType != canonicalizeStridedLayout(candidateReduced))
     return SubViewVerificationResult::AffineMapMismatch;
 
   return SubViewVerificationResult::Success;
index 23553a8..1414407 100644 (file)
@@ -745,12 +745,20 @@ MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
   if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
     return t;
 
+  // Corner-case for 0-D affine maps.
+  auto m = affineMaps[0];
+  if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
+    if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
+      if (cst.getValue() == 0)
+        return MemRefType::Builder(t).setAffineMaps({});
+    return t;
+  }
+
   // If the canonical strided layout for the sizes of `t` is equal to the
   // simplified layout of `t` we can just return an empty layout. Otherwise,
   // just simplify the existing layout.
   AffineExpr expr =
       makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
-  auto m = affineMaps[0];
   auto simplifiedLayoutExpr =
       simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
   if (expr != simplifiedLayoutExpr)
index a126107..595a950 100644 (file)
@@ -1011,6 +1011,16 @@ func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index)
 
 // -----
 
+func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %0 = alloc() : memref<8x16x4xf32>
+  // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)>>' or a rank-reduced version. (mismatch of result sizes)}}
+  %1 = subview %0[0, 2, 0][8, 16, 4][1, 1, 1]
+    : memref<8x16x4xf32> to memref<16x4xf32>
+  return
+}
+
+// -----
+
 func @invalid_rank_reducing_subview(%arg0 : memref<?x?xf32>, %arg1 : index, %arg2 : index) {
   // expected-error@+1 {{expected result type to be 'memref<?x1xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}}
   %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref<?x?xf32> to memref<?xf32>