[mlir][Linalg] Introduce folding patterns to remove certain MemRefCastOp
authorNicolas Vasilache <ntv@google.com>
Tue, 28 Jan 2020 18:44:37 +0000 (13:44 -0500)
committerNicolas Vasilache <ntv@google.com>
Wed, 29 Jan 2020 14:52:51 +0000 (09:52 -0500)
Summary:
Canonicalization and folding patterns in StandardOps may interfere with the needs
of Linalg. This revision introduces specific foldings for dynamic memrefs that can
be proven to be static.

Very concretely:

Determines whether it is possible to fold it away in the parent Linalg op:

```mlir
  %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
  %2 = linalg.slice %1 ... : memref<?x?xf32> ...
  // or
  %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
         to memref<?x?xf32>
  linalg.generic(%1 ...) : memref<?x?xf32> ...
```

into

```mlir
  %2 = linalg.slice %0 ... : memref<8x16xf32> ...
  // or
  linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
```

Reviewers: ftynse, aartbik, jsetoain, tetuante, asaadaldien

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D73565

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir [new file with mode: 0644]

index 4c23444..0dec1d1 100644 (file)
@@ -117,6 +117,8 @@ def Linalg_ReshapeOp : Linalg_Op<"reshape", [NoSideEffect]>,
     static StringRef getReassociationAttrName() { return "reassociation"; }
     MemRefType getViewType() { return view().getType().cast<MemRefType>(); }
   }];
+
+  let hasFolder = 1;
 }
 
 def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
@@ -188,6 +190,8 @@ def Linalg_SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
       return res;
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
@@ -222,6 +226,8 @@ def Linalg_TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
     static StringRef getPermutationAttrName() { return "permutation"; }
     ShapedType getShapedType() { return view().getType().cast<ShapedType>(); }
   }];
+
+  let hasFolder = 1;
 }
 
 def Linalg_YieldOp : Linalg_Op<"yield", [NativeOpTrait<"IsTerminator">]>,
index 2a2ef55..03318fa 100644 (file)
@@ -270,6 +270,8 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> {
     }
   }];
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -287,6 +289,8 @@ def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
     }
   }];
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
@@ -302,6 +306,8 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
         StringAttr::get(getReductionIteratorTypeName(), ctx), ctx);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
@@ -319,6 +325,8 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> {
       return ArrayAttr::get(iters, ctx);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
@@ -337,6 +345,8 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
       return ArrayAttr::get(iters, ctx);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
@@ -406,7 +416,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
         .cast<IntegerAttr>().getValue().getSExtValue();
     }
   }];
+
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def LinalgOperand: Type<
@@ -583,7 +596,10 @@ def GenericOp : GenericOpBase<"generic"> {
     tensor SSA values are expected to be useful and will be added in the near
     future.
   }];
+
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
@@ -710,7 +726,10 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
     tensor SSA values are expected to be useful and will be added in the near
     future.
   }];
+
   let verifier = [{ return ::verify(*this); }];
+
+  let hasFolder = 1;
 }
 
 #endif // LINALG_STRUCTURED_OPS
index b1ffce6..8ed7e79 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Determines whether it is possible to fold it away in the parent Linalg op:
+///
+/// ```mlir
+///   %1 = memref_cast %0 : memref<8x16xf32> to memref<?x?xf32>
+///   %2 = linalg.slice %1 ... : memref<?x?xf32> ...
+///   // or
+///   %1 = memref_cast %0 : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+///          to memref<?x?xf32>
+///   linalg.generic(%1 ...) : memref<?x?xf32> ...
+/// ```
+///
+/// into
+///
+/// ```mlir
+///   %2 = linalg.slice %0 ... : memref<8x16xf32> ...
+///   // or
+///   linalg.generic(%0 ... : memref<8x16xf32, affine_map<(i, j)->(16 * i + j)>>
+/// ```
+///
+static bool canFold(MemRefCastOp castOp) {
+  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
+  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
+
+  // If we don't have MemRefType as source and destination, bail out.
+  if (!sourceType || !resultType)
+    return false;
+
+  // If resultType has a map, it needs to be the same as the source type to
+  // canonicalize.
+  if (!resultType.getAffineMaps().empty() &&
+      sourceType.getAffineMaps() != resultType.getAffineMaps())
+    return false;
+
+  // Ensure that:
+  //   1. source is static
+  //   2. source and target have the same rank (will be extended when needed)
+  //   3. if result is partially static, ensure sizes match.
+  if (!sourceType.hasStaticShape() ||
+      sourceType.getRank() != resultType.getRank())
+    return false;
+
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto sourceSize = std::get<0>(it);
+    auto resultSize = std::get<1>(it);
+    if (ShapedType::isDynamic(resultSize))
+      continue;
+    if (sourceSize != resultSize)
+      return false;
+  }
+
+  // If source has a map, it can only canonicalize if it is the canonical
+  // strided layout map.
+  if (sourceType.getAffineMaps().empty())
+    return true;
+
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(sourceType, strides, offset);
+  (void)res;
+  assert(succeeded(res));
+  auto stridedMap =
+      makeStridedLinearLayoutMap(strides, offset, castOp.getContext());
+  AffineMap sourceMap = sourceType.getAffineMaps().front();
+  return sourceMap == stridedMap;
+}
+
+/// This is a common class used for patterns of the form
+/// ```
+///    someop(memrefcast) -> someop
+/// ```
+/// It folds the source of any memref_cast into the root operation directly.
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
+    if (castOp && canFold(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 ///////////////////// Operations defined with Tablegen /////////////////////////
 // For such operations that do not correspond to library calls (i.e. defined in
 // LinalgOps.td), we define an overloaded `print` function and a
@@ -1077,3 +1161,54 @@ ArrayAttr mlir::linalg::MatmulOp::indexing_maps() {
 ArrayAttr mlir::linalg::MatvecOp::indexing_maps() {
   return getIndexingMaps(getOperation());
 }
+
+// TODO(ntv, rriddle): Consider making all this boilerplate easy to autogenerate
+// with Tablegen. This seems a desirable property in the context of OpInterfaces
+// where a Linalg "named" op **isa** LinalgOp.
+LogicalResult ConvOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult CopyOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult DotOp::fold(ArrayRef<Attribute>,
+                          SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult FillOp::fold(ArrayRef<Attribute>,
+                           SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult GenericOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
+                                     SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
+                             SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
+OpFoldResult SliceOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
+OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return {};
+}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
new file mode 100644 (file)
index 0000000..370cf45
--- /dev/null
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-LABEL: func @memref_cast(
+func @memref_cast(%a: index, %b: index) -> memref<?x?xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c8 = constant 8 : index
+  %c16 = constant 16 : index
+  %1 = alloc (%b) : memref<?xi8>
+  %2 = view %1[][] : memref<?xi8> to memref<16x16xf32>
+  %3 = memref_cast %2 : memref<16x16xf32> to memref<?x?xf32>
+  %r0 = linalg.range %c0:%c8:%c1 : !linalg.range
+
+  // CHECK:  linalg.slice {{.*}} : memref<16x16xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
+  %4 = linalg.slice %3[%r0, %r0] : memref<?x?xf32>, !linalg.range, !linalg.range, memref<?x?xf32>
+
+  // CHECK:  linalg.matmul{{.*}}: memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>
+  linalg.matmul(%3, %3, %3) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+  return %4: memref<?x?xf32>
+}