Properly clone Linalg ops with regions
author Nicolas Vasilache <ntv@google.com>
Tue, 3 Sep 2019 22:27:49 +0000 (15:27 -0700)
committer A. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Sep 2019 22:28:47 +0000 (15:28 -0700)
This CL adds support for proper cloning of Linalg ops that have regions (i.e. the generic linalg op). This is used to properly implement tiling and fusion for such ops. Adequate tests are added.

PiperOrigin-RevId: 267027176

mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/test/Linalg/fusion.mlir
mlir/test/Linalg/tile.mlir

index cac24ce..3ca0089 100644 (file)
@@ -110,6 +110,20 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
            "ArrayRef<NamedAttribute>":$attributes), [{
         return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
                                           attributes);
+      }]>,
+
+    /// Clone an operation with the given location and operands. This is used to
+    /// abstract away the optional underlying region creation.
+    InterfaceMethod<"Operation *", "clone",
+      (ins "OpBuilder &":$b, "Location":$loc, "ArrayRef<Value *>":$operands), [{
+        BlockAndValueMapping map;
+        unsigned numRegions = op.getOperation()->getNumRegions();
+        Operation *res = create(b, loc, operands, op.getAttrs());
+        assert(res->getNumRegions() == numRegions && "inconsistent # regions");
+        for (unsigned ridx = 0; ridx < numRegions; ++ridx)
+          op.getOperation()->getRegion(ridx).cloneInto(
+            &res->getRegion(ridx), map);
+        return res;
       }]>
   ];
 }
index e30c4d1..f9bcf77 100644 (file)
 #ifndef MLIR_DIALECT_LINALG_LINALGOPS_H_
 #define MLIR_DIALECT_LINALG_LINALGOPS_H_
 
+#include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
@@ -26,8 +29,6 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTraits.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir {
index 954f826..0ce6c82 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/EDSC/Helpers.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/OpImplementation.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
@@ -107,7 +107,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
   }
   auto operands = getAssumedNonViewOperands(op);
   clonedViews.append(operands.begin(), operands.end());
-  return op.create(b, loc, clonedViews, op.getAttrs());
+  return op.clone(b, loc, clonedViews);
 }
 
 struct ViewDimension {
index 99e42cf..cacec86 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/OpImplementation.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Utils/Intrinsics.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
@@ -397,7 +397,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
     if (!promote) {
       auto operands = getAssumedNonViewOperands(op);
       views.append(operands.begin(), operands.end());
-      res = op.create(b, loc, views, op.getAttrs());
+      res = op.clone(b, loc, views);
       return;
     }
 
@@ -429,7 +429,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
     }
     auto operands = getAssumedNonViewOperands(op);
     opViews.append(operands.begin(), operands.end());
-    res = op.create(b, loc, opViews, op.getAttrs());
+    res = op.clone(b, loc, opViews);
 
     // 6. Emit write-back for the promoted output views: copy the partial view.
     for (unsigned i = 0, e = writebackViews.size(); i < e; ++i) {
index 24e078d..c07d0a6 100644 (file)
@@ -236,3 +236,58 @@ func @f8(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<
 //
 // FUSE-234-LABEL: func @f8
 //   FUSE-234-NOT:   loop.for
+
+#id_2d = (i, j) -> (i, j)
+#pointwise_2d_trait = {
+  indexing_maps = [#id_2d, #id_2d, #id_2d],
+  n_loop_types = [2, 0, 0],
+  n_views = [2, 1]
+}
+
+func @pointwise(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?xf32>,
+                %arg2: !linalg.view<?x?xf32>, %arg3: !linalg.view<?x?xf32>) {
+  linalg.generic #pointwise_2d_trait %arg0, %arg0, %arg1 {
+  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
+    %4 = addf %arg4, %arg5 : f32
+    linalg.yield %4 : f32
+  }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  linalg.generic #pointwise_2d_trait %arg1, %arg2, %arg3 {
+  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
+    %4 = mulf %arg4, %arg5 : f32
+    linalg.yield %4 : f32
+  }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  return
+}
+// No tiling => no fusion
+// FUSE-0-LABEL: func @pointwise
+//   FUSE-0-NOT: loop.for
+//       FUSE-0: linalg.generic
+//       FUSE-0:   addf
+//       FUSE-0: linalg.generic
+//       FUSE-0:   mulf
+//
+// FUSE-2-LABEL: func @pointwise
+//       FUSE-2:   loop.for
+//   FUSE-2-NOT:   loop.for
+//       FUSE-2:     linalg.generic
+//       FUSE-2:       addf
+//       FUSE-2:     linalg.generic
+//       FUSE-2:       mulf
+//
+// FUSE-23-LABEL: func @pointwise
+//       FUSE-23:   loop.for
+//       FUSE-23:     loop.for
+//   FUSE-23-NOT:   loop.for
+//       FUSE-23:       linalg.generic
+//       FUSE-23:         addf
+//       FUSE-23:       linalg.generic
+//       FUSE-23:         mulf
+//
+// FUSE-234-LABEL: func @pointwise
+//       FUSE-234:   loop.for
+//       FUSE-234:     loop.for
+//   FUSE-234-NOT:   loop.for
+//       FUSE-234:       linalg.generic
+//       FUSE-234:         addf
+//       FUSE-234:       linalg.generic
+//       FUSE-234:         mulf
index 92898b7..c3d6826 100644 (file)
@@ -159,3 +159,39 @@ func @fill(%arg0: !linalg.view<?x?xf32>, %arg1: f32) {
 //       TILE-234:     for
 //   TILE-234-NOT:   for
 //       TILE-234:       fill{{.*}} f32
+
+#id_2d = (i, j) -> (i, j)
+#pointwise_2d_trait = {
+  indexing_maps = [#id_2d, #id_2d, #id_2d],
+  n_loop_types = [2, 0, 0],
+  n_views = [2, 1]
+}
+
+func @pointwise(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?xf32>,
+                %arg2: !linalg.view<?x?xf32>) {
+  linalg.generic #pointwise_2d_trait %arg0, %arg1, %arg2 {
+  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
+    %4 = addf %arg4, %arg5 : f32
+    linalg.yield %4 : f32
+  }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  return
+}
+// TILE-2-LABEL: func @pointwise
+//       TILE-2:   for
+//   TILE-2-NOT:   for
+//       TILE-2:   linalg.generic
+
+// TILE-02-LABEL: func @pointwise
+//       TILE-02:   for
+//   TILE-02-NOT:   for
+//       TILE-02:     linalg.generic
+
+// TILE-002-LABEL: func @pointwise
+//   TILE-002-NOT:   for
+//       TILE-002:     linalg.generic
+
+// TILE-234-LABEL: func @pointwise
+//       TILE-234:   for
+//       TILE-234:     for
+//   TILE-234-NOT:   for
+//       TILE-234:       linalg.generic