[mlir][MemRef] Add pattern to resolve strided metadata of `memref.get_global` operation.
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 3 Apr 2023 17:31:36 +0000 (17:31 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 3 Apr 2023 17:41:35 +0000 (17:41 +0000)
This changes adds patterns to resolve the base pointer, offset, sizes
and strides of the result of a `memref.get_global` operation. Since
the operation can only result in static shaped memrefs, current
resolution kicks in only for non-zero offsets, and identity strides.

Also

- Add a separate `populateResolveExtractStridedMetadata` method that
  adds just the pattern to resolve `<memref op>` ->
  `memref.extract_strided_metadata` operations.
- Refactor the `SubviewFolder` pattern to allow resolving
  `memref.subview` -> `memref.extract_strided_metadata`.

This allows using these patterns for cases where there are already
existing `memref.extract_strided_metadata` operations.

Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D147393

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

index 3c6c12c..8d679cc 100644 (file)
@@ -53,6 +53,10 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 /// (sizes, offset, strides) of a memref into easier to analyze constructs.
 void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns for resolving `memref.extract_strided_metadata` into
+/// `memref.extract_strided_metadata` of its source.
+void populateResolveExtractStridedMetadataPatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for emulating wide integer memref operations with ops over
 /// narrower integer types.
 void populateMemRefWideIntEmulationPatterns(
index 0b01a1c..a60cd80 100644 (file)
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRArithDialect
   MLIRArithTransforms
   MLIRBufferizationDialect
+  MLIRDialectUtils
   MLIRFuncDialect
   MLIRGPUOps
   MLIRInferTypeOpInterface
index 918055d..64b9b04 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -34,6 +35,115 @@ namespace memref {
 using namespace mlir;
 
 namespace {
+
+struct StridedMetadata {
+  Value basePtr;
+  OpFoldResult offset;
+  SmallVector<OpFoldResult> sizes;
+  SmallVector<OpFoldResult> strides;
+};
+
+/// From `subview(memref, subOffset, subSizes, subStrides))` compute
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// and return {baseBuffer, offset, sizes, strides}
+static FailureOr<StridedMetadata>
+resolveSubviewStridedMetadata(RewriterBase &rewriter,
+                              memref::SubViewOp subview) {
+  // Build a plain extract_strided_metadata(memref) from subview(memref).
+  Location origLoc = subview.getLoc();
+  Value source = subview.getSource();
+  auto sourceType = source.getType().cast<MemRefType>();
+  unsigned sourceRank = sourceType.getRank();
+
+  auto newExtractStridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+  auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+
+  // Compute the new strides and offset from the base strides and offset:
+  // newStride#i = baseStride#i * subStride#i
+  // offset = baseOffset + sum(subOffsets#i * newStrides#i)
+  SmallVector<OpFoldResult> strides;
+  SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
+  auto origStrides = newExtractStridedMetadata.getStrides();
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  values[0] = ShapedType::isDynamic(sourceOffset)
+                  ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+                  : rewriter.getIndexAttr(sourceOffset);
+  SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride =
+        ShapedType::isDynamic(sourceStrides[i])
+            ? origStrides[i]
+            : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
+    strides.push_back(makeComposedFoldedAffineApply(
+        rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+    values[subOffsetForDim] = subOffsets[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  // Compute the offset.
+  OpFoldResult finalOffset =
+      makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
+
+  // The final result is  <baseBuffer, offset, sizes, strides>.
+  // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
+  // the values.
+  auto subType = subview.getType().cast<MemRefType>();
+  unsigned subRank = subType.getRank();
+
+  // The sizes of the final type are defined directly by the input sizes of
+  // the subview.
+  // Moreover subviews can drop some dimensions, some strides and sizes may
+  // not end up in the final <base, offset, sizes, strides> value that we are
+  // replacing.
+  // Do the filtering here.
+  SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
+  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+
+  SmallVector<OpFoldResult> finalSizes;
+  finalSizes.reserve(subRank);
+
+  SmallVector<OpFoldResult> finalStrides;
+  finalStrides.reserve(subRank);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    if (droppedDims.test(i))
+      continue;
+
+    finalSizes.push_back(subSizes[i]);
+    finalStrides.push_back(strides[i]);
+  }
+  assert(finalSizes.size() == subRank &&
+         "Should have populated all the values at this point");
+  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), finalOffset,
+                         finalSizes, finalStrides};
+}
+
 /// Replace `dst = subview(memref, subOffset, subSizes, subStrides))`
 /// With
 ///
@@ -54,96 +164,62 @@ public:
 
   LogicalResult matchAndRewrite(memref::SubViewOp subview,
                                 PatternRewriter &rewriter) const override {
-    // Build a plain extract_strided_metadata(memref) from subview(memref).
-    Location origLoc = subview.getLoc();
-    Value source = subview.getSource();
-    auto sourceType = source.getType().cast<MemRefType>();
-    unsigned sourceRank = sourceType.getRank();
-
-    auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
-    auto [sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
-
-    // Compute the new strides and offset from the base strides and offset:
-    // newStride#i = baseStride#i * subStride#i
-    // offset = baseOffset + sum(subOffsets#i * newStrides#i)
-    SmallVector<OpFoldResult> strides;
-    SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
-    auto origStrides = newExtractStridedMetadata.getStrides();
-
-    // Hold the affine symbols and values for the computation of the offset.
-    SmallVector<OpFoldResult> values(2 * sourceRank + 1);
-    SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
-
-    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
-    AffineExpr expr = symbols.front();
-    values[0] = ShapedType::isDynamic(sourceOffset)
-                    ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-                    : rewriter.getIndexAttr(sourceOffset);
-    SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
-
-    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
-    AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
-    for (unsigned i = 0; i < sourceRank; ++i) {
-      // Compute the stride.
-      OpFoldResult origStride =
-          ShapedType::isDynamic(sourceStrides[i])
-              ? origStrides[i]
-              : OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
-      strides.push_back(makeComposedFoldedAffineApply(
-          rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
-
-      // Build up the computation of the offset.
-      unsigned baseIdxForDim = 1 + 2 * i;
-      unsigned subOffsetForDim = baseIdxForDim;
-      unsigned origStrideForDim = baseIdxForDim + 1;
-      expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
-      values[subOffsetForDim] = subOffsets[i];
-      values[origStrideForDim] = origStride;
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveSubviewStridedMetadata(rewriter, subview);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(subview,
+                                         "failed to resolve subview metadata");
     }
 
-    // Compute the offset.
-    OpFoldResult finalOffset =
-        makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
-
-    // The final result is  <baseBuffer, offset, sizes, strides>.
-    // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
-    // the values.
-    auto subType = subview.getType().cast<MemRefType>();
-    unsigned subRank = subType.getRank();
-
-    // The sizes of the final type are defined directly by the input sizes of
-    // the subview.
-    // Moreover subviews can drop some dimensions, some strides and sizes may
-    // not end up in the final <base, offset, sizes, strides> value that we are
-    // replacing.
-    // Do the filtering here.
-    SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
-    llvm::SmallBitVector droppedDims = subview.getDroppedDims();
-
-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(subRank);
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        subview, subview.getType(), stridedMetadata->basePtr,
+        stridedMetadata->offset, stridedMetadata->sizes,
+        stridedMetadata->strides);
+    return success();
+  }
+};
 
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(subRank);
+/// Pattern to replace `extract_strided_metadata(subview)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subSizes#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \verbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpSubviewFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
 
-    for (unsigned i = 0; i < sourceRank; ++i) {
-      if (droppedDims.test(i))
-        continue;
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto subviewOp = op.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!subviewOp)
+      return failure();
 
-      finalSizes.push_back(subSizes[i]);
-      finalStrides.push_back(strides[i]);
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveSubviewStridedMetadata(rewriter, subviewOp);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source subview op");
     }
-    assert(finalSizes.size() == subRank &&
-           "Should have populated all the values at this point");
+    Location loc = subviewOp.getLoc();
+    SmallVector<Value> results;
+    results.reserve(subviewOp.getType().getRank() * 2 + 2);
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
 
-    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-        origLoc, subType, newExtractStridedMetadata.getBaseBuffer(),
-        finalOffset,
-        /*sizes=*/finalSizes,
-        /*strides=*/finalStrides);
-    rewriter.replaceOp(subview, memrefDesc.getResult());
     return success();
   }
 };
@@ -634,6 +710,77 @@ public:
   }
 };
 
+/// Replace `base, offset, sizes, strides =
+///              extract_strided_metadata(get_global)`
+///
+/// With
+///
+/// ```
+/// base = reinterpret_cast get_global to a flat memref<eltTy>
+/// offset = 0
+/// sizes = allocSizes
+/// strides#i = prod(allocSizes#j, for j in {i+1..rank-1})
+/// ```
+///
+/// It is expected that the memref.get_global op has static shapes
+/// and identity affine_map for the layout.
+struct ExtractStridedMetadataOpGetGlobalFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto getGlobalOp = op.getSource().getDefiningOp<memref::GetGlobalOp>();
+    if (!getGlobalOp)
+      return failure();
+
+    auto memRefType = getGlobalOp.getResult().getType().cast<MemRefType>();
+    if (!memRefType.getLayout().isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          getGlobalOp,
+          "get-global operation result should have been normalized");
+    }
+
+    Location loc = op.getLoc();
+    int rank = memRefType.getRank();
+
+    // Collect the sizes.
+    ArrayRef<int64_t> sizes = memRefType.getShape();
+    assert(!llvm::any_of(sizes, ShapedType::isDynamic) &&
+           "unexpected dynamic shape for result of `memref.get_global` op");
+
+    // Strides (just creates identity strides).
+    SmallVector<int64_t> strides = computeSuffixProduct(sizes);
+
+    // Put all the values together to replace the results.
+    SmallVector<Value> results;
+    results.reserve(rank * 2 + 2);
+
+    auto baseBufferType = op.getBaseBuffer().getType().cast<MemRefType>();
+    int64_t offset = 0;
+    if (getGlobalOp.getType() == baseBufferType)
+      results.push_back(getGlobalOp);
+    else
+      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+          loc, baseBufferType, getGlobalOp, offset,
+          /*sizes=*/ArrayRef<int64_t>(),
+          /*strides=*/ArrayRef<int64_t>()));
+
+    // Offset.
+    results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));
+
+    for (auto size : sizes)
+      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size));
+
+    for (auto stride : strides)
+      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride));
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+
 /// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the
 /// source of the ViewLikeOp.
 class RewriteExtractAlignedPointerAsIndexOfViewLikeOp
@@ -758,6 +905,19 @@ void memref::populateExpandStridedMetadataPatterns(
                              getCollapsedStride>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpGetGlobalFolder,
+               RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+               ExtractStridedMetadataOpReinterpretCastFolder,
+               ExtractStridedMetadataOpExtractStridedMetadataFolder>(
+      patterns.getContext());
+}
+
+void memref::populateResolveExtractStridedMetadataPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpGetGlobalFolder,
+               ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
index 0eaf7d1..a6303aa 100644 (file)
@@ -1290,3 +1290,82 @@ func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
       index, index,
       index, index
 }
+
+// -----
+
+// Check that for `memref.get_global` -> `memref.extract_strided_metadata` resolves
+// with the consumer replaced with the strides, sizes and offsets computed from
+// `memref.get_global`. Since the result of `memref.get_global is always static shaped
+// no need to check for dynamic shapes.
+
+// CHECK-LABEL: func @extract_strided_metadata_of_get_global()
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C384:.+]] = arith.constant 384 : index
+//   CHECK-DAG:   %[[C512:.+]] = arith.constant 512 : index
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
+//       CHECK:   %[[CAST:.+]] = memref.reinterpret_cast %[[GET_GLOBAL]]
+//  CHECK-SAME:       offset: [0], sizes: [], strides: []
+//       CHECK:   return %[[CAST]], %[[C0]], %[[C512]], %[[C384]], %[[C384]], %[[C1]]
+
+memref.global "private" constant @const_i32 : memref<512x384xi32> = dense<42>
+
+func.func @extract_strided_metadata_of_get_global()
+    -> (memref<i32>, index, index, index, index, index) {
+
+  %A = memref.get_global @const_i32 : memref<512x384xi32>
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
+    memref<512x384xi32> -> memref<i32>, index, index, index, index, index
+
+  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+      memref<i32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that for `memref.get_global` -> `memref.extract_strided_metadata` does not
+// resolve when the strides are not identity. This is an unhandled case that could
+// be covered in the future
+
+// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_strides()
+//       CHECK:   %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
+//       CHECK:   memref.extract_strided_metadata %[[GET_GLOBAL]]
+memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>> = dense<42>
+
+func.func @extract_strided_metadata_of_get_global_with_strides()
+    -> (memref<i32>, index, index, index, index, index) {
+
+  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[420, 1], offset: 0>>
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
+    memref<512x384xi32, strided<[420, 1], offset: 0>>
+    -> memref<i32>, index, index, index, index, index
+
+  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+      memref<i32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that for `memref.get_global` -> `memref.extract_strided_metadata` does not
+// resolve when the offset is non-zero. This is an unhandled case that could
+// be covered in the future
+
+// CHECK-LABEL: func @extract_strided_metadata_of_get_global_with_offset()
+//       CHECK:   %[[GET_GLOBAL:.+]] = memref.get_global @const_i32
+//       CHECK:   memref.extract_strided_metadata %[[GET_GLOBAL]]
+memref.global "private" constant @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>> = dense<42>
+
+func.func @extract_strided_metadata_of_get_global_with_offset()
+    -> (memref<i32>, index, index, index, index, index) {
+
+  %A = memref.get_global @const_i32 : memref<512x384xi32, strided<[384, 1], offset: 20>>
+
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %A :
+    memref<512x384xi32, strided<[384, 1], offset: 20>>
+    -> memref<i32>, index, index, index, index, index
+
+  return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+      memref<i32>, index, index, index, index, index
+}