[MLIR] MemRef Normalization for Dialects
authorAlexandre E. Eichenberger <alexe@us.ibm.com>
Thu, 27 Aug 2020 05:17:33 +0000 (10:47 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Thu, 27 Aug 2020 14:56:59 +0000 (20:26 +0530)
When dealing with dialects that will results in function calls to
external libraries, it is important to be able to handle maps as some
dialects may require mapped data.  Before this patch, the detection of
whether normalization can apply or not, operations are compared to an
explicit list of operations (`alloc`, `dealloc`, `return`) or to the
presence of specific operation interfaces (`AffineReadOpInterface`,
`AffineWriteOpInterface`, `AffineDMAStartOp`, or `AffineDMAWaitOp`).

This patch add a trait, `MemRefsNormalizable` to determine if an
operation can have its `memrefs` normalized.

This trait can be used in turn by dialects to assert that such
operations are compatible with normalization of `memrefs` with
nontrivial memory layout specification. An example is given in the
literal tests.

Differential Revision: https://reviews.llvm.org/D86236

mlir/docs/Traits.md
mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Transforms/NormalizeMemRefs.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/test/Transforms/normalize-memrefs-ops.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestOps.td

index c9ef132..5867f22 100644 (file)
@@ -247,6 +247,18 @@ foo.region_op {
 This trait is an important structural property of the IR, and enables operations
 to have [passes](PassManagement.md) scheduled under them.
 
+### MemRefsNormalizable
+
+* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
+
+This trait is used to flag operations that can accommodate `MemRefs` with
+non-identity memory-layout specifications. This trait indicates that the
+normalization of memory layout can be performed for such operations.
+`MemRefs` normalization consists of replacing an original memory reference
+with layout specifications to an equivalent memory reference where
+the specified memory layout is applied by rewritting accesses and types
+associated with that memory reference.
+
 ### Single Block with Implicit Terminator
 
 *   `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` :
index b880281..b8b29ff 100644 (file)
@@ -80,8 +80,9 @@ bool isTopLevelValue(Value value);
 // multiple stride levels (possibly using AffineMaps to specify multiple levels
 // of striding).
 // TODO: Consider replacing src/dst memref indices with view memrefs.
-class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
-                                   OpTrait::ZeroResult> {
+class AffineDmaStartOp
+    : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
+                OpTrait::VariadicOperands, OpTrait::ZeroResult> {
 public:
   using Op::Op;
 
@@ -268,8 +269,9 @@ public:
 //   ...
 //   affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
 //
-class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
-                                  OpTrait::ZeroResult> {
+class AffineDmaWaitOp
+    : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
+                OpTrait::VariadicOperands, OpTrait::ZeroResult> {
 public:
   using Op::Op;
 
index e5273a8..480e171 100644 (file)
@@ -405,7 +405,8 @@ def AffineIfOp : Affine_Op<"if",
 
 class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
-        [DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
+        [DeclareOpInterfaceMethods<AffineReadOpInterface>,
+        MemRefsNormalizable])> {
   let arguments = (ins Arg<AnyMemRef, "the reference to load from",
       [MemRead]>:$memref,
       Variadic<Index>:$indices);
@@ -732,7 +733,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
 
 class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
     Affine_Op<mnemonic, !listconcat(traits,
-        [DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
+    [DeclareOpInterfaceMethods<AffineWriteOpInterface>,
+    MemRefsNormalizable])> {
   code extraClassDeclarationBase = [{
     /// Returns the operand index of the value to be stored.
     unsigned getStoredValOperandIndex() { return 0; }
index b80da29..063c34c 100644 (file)
@@ -658,7 +658,7 @@ def BranchOp : Std_Op<"br",
 // CallOp
 //===----------------------------------------------------------------------===//
 
-def CallOp : Std_Op<"call", [CallOpInterface]> {
+def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
   let summary = "call operation";
   let description = [{
     The `call` operation represents a direct call to a function that is within
@@ -1388,7 +1388,8 @@ def SinOp : FloatUnaryOp<"sin"> {
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
-def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> {
+def DeallocOp : Std_Op<"dealloc",
+    [MemoryEffects<[MemFree]>, MemRefsNormalizable]> {
   let summary = "memory deallocation operation";
   let description = [{
     The `dealloc` operation frees the region of memory referenced by a memref
@@ -2144,8 +2145,8 @@ def RemFOp : FloatArithmeticOp<"remf"> {
 // ReturnOp
 //===----------------------------------------------------------------------===//
 
-def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, ReturnLike,
-                                 Terminator]> {
+def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
+                                MemRefsNormalizable, ReturnLike, Terminator]> {
   let summary = "return operation";
   let description = [{
     The `return` operation represents a return operation within a function.
index a28410f..8375f24 100644 (file)
@@ -1698,6 +1698,9 @@ def SameOperandsAndResultElementType :
   NativeOpTrait<"SameOperandsAndResultElementType">;
 // Op is a terminator.
 def Terminator : NativeOpTrait<"IsTerminator">;
+// Op can be safely normalized in the presence of MemRefs with
+// non-identity maps.
+def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
 
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>
index db77935..9579c81 100644 (file)
@@ -1212,6 +1212,20 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
   }
 };
 
+/// This trait is used to flag operations that can accommodate MemRefs with
+/// non-identity memory-layout specifications. This trait indicates that the
+/// normalization of memory layout can be performed for such operations.
+/// MemRefs normalization consists of replacing an original memory reference
+/// with layout specifications to an equivalent memory reference where the
+/// specified memory layout is applied by rewritting accesses and types
+/// associated with that memory reference.
+// TODO: Right now, the operands of an operation are either all normalizable,
+// or not. In the future, we may want to allow some of the operands to be
+// normalizable.
+template <typename ConcrentType>
+struct MemRefsNormalizable
+    : public TraitBase<ConcrentType, MemRefsNormalizable> {};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//
index 1736fa9..c4f91eb 100644 (file)
@@ -106,23 +106,15 @@ void NormalizeMemRefs::runOnOperation() {
     normalizeFuncOpMemRefs(funcOp, moduleOp);
 }
 
-/// Return true if this operation dereferences one or more memref's.
-/// TODO: Temporary utility, will be replaced when this is modeled through
-/// side-effects/op traits.
-static bool isMemRefDereferencingOp(Operation &op) {
-  return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
-             AffineDmaWaitOp>(op);
-}
-
 /// Check whether all the uses of oldMemRef are either dereferencing uses or the
 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
 /// are satisfied will the value become a candidate for replacement.
 /// TODO: Extend this for DimOps.
 static bool isMemRefNormalizable(Value::user_range opUsers) {
   if (llvm::any_of(opUsers, [](Operation *op) {
-        if (isMemRefDereferencingOp(*op))
+        if (op->hasTrait<OpTrait::MemRefsNormalizable>())
           return false;
-        return !isa<DeallocOp, CallOp, ReturnOp>(*op);
+        return true;
       }))
     return false;
   return true;
index c310702..516f8c0 100644 (file)
@@ -279,7 +279,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
       // Currently we support the following non-dereferencing ops to be a
       // candidate for replacement: Dealloc, CallOp and ReturnOp.
       // TODO: Add support for other kinds of ops.
-      if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
+      if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
         return failure();
     }
 
diff --git a/mlir/test/Transforms/normalize-memrefs-ops.mlir b/mlir/test/Transforms/normalize-memrefs-ops.mlir
new file mode 100644 (file)
index 0000000..8ce841e
--- /dev/null
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s
+
+// For all these cases, we test if MemRefs Normalization works with the test
+// operations.
+// * test.op_norm: this operation has the MemRefsNormalizable attribute. The tests
+//   that include this operation are constructed so that the normalization should
+//   happen.
+// * test_op_nonnorm: this operation does not have the MemRefsNormalization
+//   attribute. The tests that include this operation are contructed so that the
+//    normalization should not happen.
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>
+
+// Test with op_norm and maps in arguments and in the operations in the function.
+
+// CHECK-LABEL: test_norm
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>)
+func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map0>
+    "test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+    dealloc %0 :  memref<1x16x14x14xf32, #map0>
+
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
+    // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
+    return
+}
+
+// Same test with op_nonnorm, with maps in the argmentets and the operations in the function.
+
+// CHECK-LABEL: test_nonnorm
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map0>)
+func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map0>
+    "test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+    dealloc %0 :  memref<1x16x14x14xf32, #map0>
+
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map0>
+    // CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
+    // CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map0>
+    return
+}
+
+// Test with op_norm, with maps in the operations in the function.
+
+// CHECK-LABEL: test_norm_mix
+// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>
+func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
+    %0 = alloc() : memref<1x16x14x14xf32, #map0>
+    "test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
+    dealloc %0 :  memref<1x16x14x14xf32, #map0>
+
+    // CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
+    // CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
+    // CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
+    return
+}
index c938081..022732d 100644 (file)
@@ -618,6 +618,16 @@ def OpM : TEST_Op<"op_m"> {
   let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
   let results = (outs I32);
 }
+
+// Test for memrefs normalization of an op with normalizable memrefs.
+def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+}
+// Test for memrefs normalization of an op without normalizable memrefs.
+def OpNonNorm : TEST_Op<"op_nonnorm"> {
+  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
+}
+
 // Pattern add the argument plus a increasing static number hidden in
 // OpMTest function. That value is set into the optional argument.
 // That way, we will know if operations is called once or twice.