Add canonicalization pattern for linalg.dim
authorNicolas Vasilache <ntv@google.com>
Thu, 8 Aug 2019 16:09:29 +0000 (09:09 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Aug 2019 16:09:58 +0000 (09:09 -0700)
This CL introduces canonicalization patterns for linalg.dim.
This allows the dimenions of chains of view, slice and subview operations to simplify.
Down the line, when mixed with cse, this also allows better composition of linalg tiling and fusion by tracking operations that give the same result (not in this CL).

PiperOrigin-RevId: 262365865

mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/test/Linalg/canonicalize.mlir [new file with mode: 0644]

index 55a8108..bbbbfad 100644 (file)
@@ -146,6 +146,8 @@ def DimOp : Linalg_Op<"dim", [NoSideEffect]>,
     }
     ViewType getViewType() { return getOperand()->getType().cast<ViewType>(); }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
index 4feb22b..60820ae 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
 #include "mlir/Linalg/Utils/Utils.h"
@@ -42,6 +43,81 @@ using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;
 
+namespace {
+/// Fold constant dimensions into an alloc operation.
+struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> {
+  using OpRewritePattern<linalg::DimOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(linalg::DimOp dimOp,
+                                     PatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+PatternMatchResult
+SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
+                               PatternRewriter &rewriter) const {
+  auto *viewProducingOp = dimOp.view()->getDefiningOp();
+  auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp);
+  auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp);
+  auto view = dyn_cast_or_null<ViewOp>(viewProducingOp);
+  if (!subView && !slice && !view)
+    return matchFailure();
+
+  unsigned dim = dimOp.getIndex();
+  Value *min, *max, *step;
+  if (view) {
+    // Cannot traverse block arguments, fail.
+    if (isa<BlockArgument>(view.getIndexing(dim)))
+      return matchFailure();
+    // Record min, max, step for further processing.
+    auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
+    std::tie(min, max, step) =
+        std::make_tuple(range.min(), range.max(), range.step());
+  } else if (subView) {
+    // Record min, max, step for further processing.
+    auto range = subView.getRange(dim);
+    std::tie(min, max, step) =
+        std::make_tuple(range.min, range.max, range.step);
+  } else {
+    // Taking the dim of a slice must take a range (since other dims have been
+    // rank-reduced).
+    auto *rangeValue = slice.getRanges()[dim];
+    // Cannot traverse block arguments, fail.
+    if (isa<BlockArgument>(rangeValue))
+      return matchFailure();
+    auto range = cast<RangeOp>(rangeValue->getDefiningOp());
+    // Record min, max, step for further processing.
+    std::tie(min, max, step) =
+        std::make_tuple(range.min(), range.max(), range.step());
+  }
+
+  // Only support constant steps of 1 atm.
+  auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp());
+  if (!constant || constant.getValue() != 1)
+    return matchFailure();
+
+  // Circumvent affine constraints:
+  //   emit an affine_apply when possible, otherwise emit a `subi`.
+  bool validAffineMin = isValidDim(min) || isValidSymbol(min) ||
+                        isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp());
+  bool validAffineMax = isValidDim(max) || isValidSymbol(max) ||
+                        isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp());
+
+  OpBuilder b(dimOp);
+  ScopedContext scope(b, dimOp.getLoc());
+  // Emit `subi`.
+  if (!validAffineMin || !validAffineMax) {
+    rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()});
+    return matchSuccess();
+  }
+
+  // Emit affine_apply.
+  using edsc::op::operator-;
+  rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)},
+                     {dimOp.view()});
+  return matchSuccess();
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 // LoadOp.
 ////////////////////////////////////////////////////////////////////////////////
@@ -501,6 +577,14 @@ static ParseResult parseBufferSizeOp(OpAsmParser *parser,
                                        result->types));
 }
 
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::DimOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<SimplifyDimOp>(context);
+}
+
 static void print(OpAsmPrinter *p, linalg::DimOp op) {
   *p << op.getOperationName() << " " << *op.getOperand() << ", "
      << op.getIndex();
diff --git a/mlir/test/Linalg/canonicalize.mlir b/mlir/test/Linalg/canonicalize.mlir
new file mode 100644 (file)
index 0000000..65e1d54
--- /dev/null
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-DAG: #[[SUB:.*]] = ()[s0, s1] -> (s0 - s1)
+
+func @fold_constants(%arg0: !linalg.buffer<?xf32>) -> (index, index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %R02 = linalg.range %c0:%c2:%c1 : !linalg.range
+  %R03 = linalg.range %c0:%c3:%c1 : !linalg.range
+  %R04 = linalg.range %c0:%c4:%c1 : !linalg.range
+  %R12 = linalg.range %c1:%c2:%c1 : !linalg.range
+  %R13 = linalg.range %c1:%c3:%c1 : !linalg.range
+  %R14 = linalg.range %c1:%c4:%c1 : !linalg.range
+
+  %v = linalg.view %arg0[%R02, %R14] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  // Expected 2.
+  %v0 = linalg.dim %v, 0 : !linalg.view<?x?xf32>
+  // Expected 3.
+  %v1 = linalg.dim %v, 1 : !linalg.view<?x?xf32>
+
+  %s = linalg.slice %v[%c1, %R12] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+  // Expected 1.
+  %s0 = linalg.dim %s, 0 : !linalg.view<?xf32>
+
+  %sv = linalg.subview %v[%v0, %v1, %c1, %c2, %c4, %c1] : !linalg.view<?x?xf32>
+  // Expected 1.
+  %sv0 = linalg.dim %sv, 0 : !linalg.view<?x?xf32>
+  // Expected 2.
+  %sv1 = linalg.dim %sv, 1 : !linalg.view<?x?xf32>
+
+  return %v0, %v1, %s0, %sv0, %sv1 : index, index, index, index, index
+}
+
+// CHECK-LABEL: fold_constants
+//   CHECK-DAG:   %[[c1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[c2:.*]] = constant 2 : index
+//   CHECK-DAG:   %[[c3:.*]] = constant 3 : index
+//       CHECK:   return %[[c2]], %[[c3]], %[[c1]], %[[c1]], %[[c2]]
+
+
+func @fold_indices(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %R = linalg.range %arg1:%arg3:%c1 : !linalg.range
+
+  %v = linalg.view %arg0[%R, %R] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  // Expected %arg3 - %arg1.
+  %v0 = linalg.dim %v, 0 : !linalg.view<?x?xf32>
+  // Expected %arg3 - %arg1.
+  %v1 = linalg.dim %v, 1 : !linalg.view<?x?xf32>
+
+  %arg1_p_arg2 = addi %arg1, %arg2: index
+  %arg1_p_arg2_affine = affine.apply (i, j) -> (i + j) (%arg1, %arg2)
+  %sv = linalg.subview %v[%arg1, %arg1_p_arg2, %c1, %arg1, %arg1_p_arg2_affine, %c1] : !linalg.view<?x?xf32>
+  // Expected %arg2 but can't fold affine.apply with addi.
+  %sv0 = linalg.dim %sv, 0 : !linalg.view<?x?xf32>
+  // Expected %arg2.
+  %sv1 = linalg.dim %sv, 1 : !linalg.view<?x?xf32>
+
+  return %v0, %v1, %sv0, %sv1 : index, index, index, index
+}
+
+// CHECK-LABEL: fold_indices
+//       CHECK: (%[[arg0:.*]]: !linalg.buffer<?xf32>, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index
+//       CHECK:   %[[r0:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]]
+//       CHECK:   %[[r1:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]]
+//       CHECK:   %[[add:.*]] = addi %[[arg1]], %[[arg2]] : index
+//       CHECK:   %[[aff:.*]] = affine.apply #[[SUB]]()[%[[add]], %[[arg1]]]
+//       CHECK:   return %[[r0]], %[[r1]], %[[aff]], %[[arg2]]
\ No newline at end of file