[mlir][MemRef] Move the forwarding patterns for `extract_strided_metadata`
authorQuentin Colombet <quentin.colombet@gmail.com>
Wed, 12 Oct 2022 00:53:52 +0000 (00:53 +0000)
committerQuentin Colombet <quentin.colombet@gmail.com>
Tue, 18 Oct 2022 22:34:50 +0000 (22:34 +0000)
The `SimplifyExtractStridedMetadata` pass features a pattern that forward
statically known information (offset, sizes, strides) to their respective
users.

This patch moves this pattern from this pass to the
`extract_strided_metadata` folding patterns.

Differential Revision: https://reviews.llvm.org/D135797

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir

index ba8fe81..1f1b118 100644 (file)
@@ -912,6 +912,8 @@ def MemRef_ExtractStridedMetadataOp : MemRef_Op<"extract_strided_metadata", [
   let assemblyFormat = [{
     $source `:` type($source) `->` type(results) attr-dict
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
index ab7311b..9a6727d 100644 (file)
@@ -1286,6 +1286,58 @@ void ExtractStridedMetadataOp::getAsmResultNames(
   }
 }
 
+/// Helper function to perform the replacement of all constant uses of `values`
+/// by a materialized constant extracted from `maybeConstants`.
+/// `values` and `maybeConstants` are expected to have the same size.
+template <typename Container>
+static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
+                                  Container values,
+                                  ArrayRef<int64_t> maybeConstants,
+                                  llvm::function_ref<bool(int64_t)> isDynamic) {
+  assert(values.size() == maybeConstants.size() &&
+         " expected values and maybeConstants of the same size");
+  bool atLeastOneReplacement = false;
+  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
+    // Don't materialize a constant if there are no uses: this would indice
+    // infinite loops in the driver.
+    if (isDynamic(maybeConstant) || result.use_empty())
+      continue;
+    Value constantVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
+    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
+      // updateRootInplace: lambda cannot capture structured bindings in C++17
+      // yet.
+      op->replaceUsesOfWith(result, constantVal);
+      atLeastOneReplacement = true;
+    }
+  }
+  return atLeastOneReplacement;
+}
+
+LogicalResult
+ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
+                               SmallVectorImpl<OpFoldResult> &results) {
+  OpBuilder builder(*this);
+  auto memrefType = getSource().getType().cast<MemRefType>();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
+  (void)res;
+  assert(succeeded(res) && "must be a strided memref type");
+
+  bool atLeastOneReplacement = replaceConstantUsesOf(
+      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
+      ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
+  atLeastOneReplacement |=
+      replaceConstantUsesOf(builder, getLoc(), getSizes(),
+                            memrefType.getShape(), ShapedType::isDynamic);
+  atLeastOneReplacement |=
+      replaceConstantUsesOf(builder, getLoc(), getStrides(), strides,
+                            ShapedType::isDynamicStrideOrOffset);
+
+  return success(atLeastOneReplacement);
+}
+
 //===----------------------------------------------------------------------===//
 // GenericAtomicRMWOp
 //===----------------------------------------------------------------------===//
index 6e86103..1ebc2f6 100644 (file)
@@ -550,64 +550,6 @@ public:
   }
 };
 
-/// Helper function to perform the replacement of all constant uses of `values`
-/// by a materialized constant extracted from `maybeConstants`.
-/// `values` and `maybeConstants` are expected to have the same size.
-template <typename Container>
-bool replaceConstantUsesOf(PatternRewriter &rewriter, Location loc,
-                           Container values, ArrayRef<int64_t> maybeConstants,
-                           llvm::function_ref<bool(int64_t)> isDynamic) {
-  assert(values.size() == maybeConstants.size() &&
-         " expected values and maybeConstants of the same size");
-  bool atLeastOneReplacement = false;
-  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
-    // Don't materialize a constant if there are no uses: this would indice
-    // infinite loops in the driver.
-    if (isDynamic(maybeConstant) || result.use_empty())
-      continue;
-    Value constantVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, maybeConstant);
-    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
-      rewriter.startRootUpdate(op);
-      // updateRootInplace: lambda cannot capture structured bindings in C++17
-      // yet.
-      op->replaceUsesOfWith(result, constantVal);
-      rewriter.finalizeRootUpdate(op);
-      atLeastOneReplacement = true;
-    }
-  }
-  return atLeastOneReplacement;
-}
-
-// Forward propagate all constants information from an ExtractStridedMetadataOp.
-struct ForwardStaticMetadata
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp metadataOp,
-                                PatternRewriter &rewriter) const override {
-    auto memrefType = metadataOp.getSource().getType().cast<MemRefType>();
-    SmallVector<int64_t> strides;
-    int64_t offset;
-    LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
-    (void)res;
-    assert(succeeded(res) && "must be a strided memref type");
-
-    bool atLeastOneReplacement = replaceConstantUsesOf(
-        rewriter, metadataOp.getLoc(),
-        ArrayRef<TypedValue<IndexType>>(metadataOp.getOffset()),
-        ArrayRef<int64_t>(offset), ShapedType::isDynamicStrideOrOffset);
-    atLeastOneReplacement |= replaceConstantUsesOf(
-        rewriter, metadataOp.getLoc(), metadataOp.getSizes(),
-        memrefType.getShape(), ShapedType::isDynamic);
-    atLeastOneReplacement |= replaceConstantUsesOf(
-        rewriter, metadataOp.getLoc(), metadataOp.getStrides(), strides,
-        ShapedType::isDynamicStrideOrOffset);
-
-    return success(atLeastOneReplacement);
-  }
-};
-
 /// Replace `base, offset, sizes, strides =
 ///              extract_strided_metadata(allocLikeOp)`
 ///
@@ -753,7 +695,6 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns(
                memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
            ExtractStridedMetadataOpReshapeFolder<
                memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
-           ForwardStaticMetadata,
            ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
            ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
            RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
index f20838a..5a41802 100644 (file)
@@ -758,9 +758,13 @@ func.func @reinterpret_of_subview(%arg : memref<?xi8>, %size1: index, %size2: in
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is canonicalized to a plain cast when the destination type is different
 // than the type of the original memref.
+// This pattern is currently defeated by the constant folding that happens
+// with extract_strided_metadata.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_type_mistach
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref<?x?xf32,
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0
+//   CHECK-DAG: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]]
 //       CHECK: return %[[CAST]]
 func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -773,12 +777,12 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref
 // Check that a reinterpret cast of an equivalent extract strided metadata
 // is completely removed when the original memref has the same type.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_same_type
-//  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32
 //       CHECK: return %[[ARG]]
-func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2xf32>) -> memref<8x2xf32> {
-  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
-  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<8x2xf32>
-  return %m2 : memref<8x2xf32>
+func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?xf32, strided<[?,?], offset: ?>>) -> memref<?x?xf32, strided<[?,?], offset: ?>> {
+  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<?x?xf32, strided<[?,?], offset: ?>> -> memref<f32>, index, index, index, index, index
+  %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?,?], offset:?>>
+  return %m2 : memref<?x?xf32, strided<[?,?], offset:?>>
 }
 
 // -----
@@ -787,8 +791,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<8x2x
 // when the strides don't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [4, 2, 2], strides: [1, 1, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
 //       CHECK: return %[[RES]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
@@ -801,8 +807,11 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
 // when the offset doesn't match.
 // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
 //  CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
-//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
-//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[SIZES]]#0, %[[SIZES]]#1], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1]
+//   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//       CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
 //       CHECK: return %[[RES]]
 func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
   %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
index caa7efd..4648d9e 100644 (file)
@@ -193,10 +193,10 @@ func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
     -> (memref<f32>, index, index, index, index, index) {
 
   %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
-    memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+    memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
 
   %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
-    memref<6x3xf32, strided<[4, 1], offset: 210>>
+    memref<6x3xf32, strided<[?, 1], offset: 210>>
     -> memref<f32>, index, index, index, index, index
 
   return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :