From b0ea33a7c62654a6aec8de37156755abb3021da9 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 8 Aug 2019 09:09:29 -0700 Subject: [PATCH] Add canonicalization pattern for linalg.dim This CL introduces canonicalization patterns for linalg.dim. This allows the dimenions of chains of view, slice and subview operations to simplify. Down the line, when mixed with cse, this also allows better composition of linalg tiling and fusion by tracking operations that give the same result (not in this CL). PiperOrigin-RevId: 262365865 --- mlir/include/mlir/Linalg/IR/LinalgOps.td | 2 + mlir/lib/Linalg/IR/LinalgOps.cpp | 84 ++++++++++++++++++++++++++++++++ mlir/test/Linalg/canonicalize.mlir | 73 +++++++++++++++++++++++++++ 3 files changed, 159 insertions(+) create mode 100644 mlir/test/Linalg/canonicalize.mlir diff --git a/mlir/include/mlir/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Linalg/IR/LinalgOps.td index 55a8108..bbbbfad 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Linalg/IR/LinalgOps.td @@ -146,6 +146,8 @@ def DimOp : Linalg_Op<"dim", [NoSideEffect]>, } ViewType getViewType() { return getOperand()->getType().cast(); } }]; + + let hasCanonicalizer = 1; } def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>, diff --git a/mlir/lib/Linalg/IR/LinalgOps.cpp b/mlir/lib/Linalg/IR/LinalgOps.cpp index 4feb22b..60820ae 100644 --- a/mlir/lib/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Linalg/IR/LinalgOps.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Linalg/IR/LinalgTypes.h" #include "mlir/Linalg/Utils/Utils.h" @@ -42,6 +43,81 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; +namespace { +/// Fold constant dimensions into an alloc operation. +struct SimplifyDimOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(linalg::DimOp dimOp, + PatternRewriter &rewriter) const override; +}; +} // end namespace + +PatternMatchResult +SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp, + PatternRewriter &rewriter) const { + auto *viewProducingOp = dimOp.view()->getDefiningOp(); + auto subView = dyn_cast_or_null(viewProducingOp); + auto slice = dyn_cast_or_null(viewProducingOp); + auto view = dyn_cast_or_null(viewProducingOp); + if (!subView && !slice && !view) + return matchFailure(); + + unsigned dim = dimOp.getIndex(); + Value *min, *max, *step; + if (view) { + // Cannot traverse block arguments, fail. + if (isa(view.getIndexing(dim))) + return matchFailure(); + // Record min, max, step for further processing. + auto range = cast(view.getIndexing(dim)->getDefiningOp()); + std::tie(min, max, step) = + std::make_tuple(range.min(), range.max(), range.step()); + } else if (subView) { + // Record min, max, step for further processing. + auto range = subView.getRange(dim); + std::tie(min, max, step) = + std::make_tuple(range.min, range.max, range.step); + } else { + // Taking the dim of a slice must take a range (since other dims have been + // rank-reduced). + auto *rangeValue = slice.getRanges()[dim]; + // Cannot traverse block arguments, fail. + if (isa(rangeValue)) + return matchFailure(); + auto range = cast(rangeValue->getDefiningOp()); + // Record min, max, step for further processing. + std::tie(min, max, step) = + std::make_tuple(range.min(), range.max(), range.step()); + } + + // Only support constant steps of 1 atm. + auto constant = dyn_cast_or_null(step->getDefiningOp()); + if (!constant || constant.getValue() != 1) + return matchFailure(); + + // Circumvent affine constraints: + // emit an affine_apply when possible, otherwise emit a `subi`. + bool validAffineMin = isValidDim(min) || isValidSymbol(min) || + isa_and_nonnull(min->getDefiningOp()); + bool validAffineMax = isValidDim(max) || isValidSymbol(max) || + isa_and_nonnull(max->getDefiningOp()); + + OpBuilder b(dimOp); + ScopedContext scope(b, dimOp.getLoc()); + // Emit `subi`. + if (!validAffineMin || !validAffineMax) { + rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()}); + return matchSuccess(); + } + + // Emit affine_apply. + using edsc::op::operator-; + rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)}, + {dimOp.view()}); + return matchSuccess(); +} + //////////////////////////////////////////////////////////////////////////////// // LoadOp. //////////////////////////////////////////////////////////////////////////////// @@ -501,6 +577,14 @@ static ParseResult parseBufferSizeOp(OpAsmParser *parser, result->types)); } +//===----------------------------------------------------------------------===// +// DimOp +//===----------------------------------------------------------------------===// +void mlir::linalg::DimOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + static void print(OpAsmPrinter *p, linalg::DimOp op) { *p << op.getOperationName() << " " << *op.getOperand() << ", " << op.getIndex(); diff --git a/mlir/test/Linalg/canonicalize.mlir b/mlir/test/Linalg/canonicalize.mlir new file mode 100644 index 0000000..65e1d54 --- /dev/null +++ b/mlir/test/Linalg/canonicalize.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s -canonicalize | FileCheck %s + +// CHECK-DAG: #[[SUB:.*]] = ()[s0, s1] -> (s0 - s1) + +func @fold_constants(%arg0: !linalg.buffer) -> (index, index, index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %R02 = linalg.range %c0:%c2:%c1 : !linalg.range + %R03 = linalg.range %c0:%c3:%c1 : !linalg.range + %R04 = linalg.range %c0:%c4:%c1 : !linalg.range + %R12 = linalg.range %c1:%c2:%c1 : !linalg.range + %R13 = linalg.range %c1:%c3:%c1 : !linalg.range + %R14 = linalg.range %c1:%c4:%c1 : !linalg.range + + %v = linalg.view %arg0[%R02, %R14] : !linalg.buffer -> !linalg.view + // Expected 2. + %v0 = linalg.dim %v, 0 : !linalg.view + // Expected 3. + %v1 = linalg.dim %v, 1 : !linalg.view + + %s = linalg.slice %v[%c1, %R12] : !linalg.view, index, !linalg.range, !linalg.view + // Expected 1. + %s0 = linalg.dim %s, 0 : !linalg.view + + %sv = linalg.subview %v[%v0, %v1, %c1, %c2, %c4, %c1] : !linalg.view + // Expected 1. + %sv0 = linalg.dim %sv, 0 : !linalg.view + // Expected 2. + %sv1 = linalg.dim %sv, 1 : !linalg.view + + return %v0, %v1, %s0, %sv0, %sv1 : index, index, index, index, index +} + +// CHECK-LABEL: fold_constants +// CHECK-DAG: %[[c1:.*]] = constant 1 : index +// CHECK-DAG: %[[c2:.*]] = constant 2 : index +// CHECK-DAG: %[[c3:.*]] = constant 3 : index +// CHECK: return %[[c2]], %[[c3]], %[[c1]], %[[c1]], %[[c2]] + + +func @fold_indices(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %R = linalg.range %arg1:%arg3:%c1 : !linalg.range + + %v = linalg.view %arg0[%R, %R] : !linalg.buffer -> !linalg.view + // Expected %arg3 - %arg1. + %v0 = linalg.dim %v, 0 : !linalg.view + // Expected %arg3 - %arg1. + %v1 = linalg.dim %v, 1 : !linalg.view + + %arg1_p_arg2 = addi %arg1, %arg2: index + %arg1_p_arg2_affine = affine.apply (i, j) -> (i + j) (%arg1, %arg2) + %sv = linalg.subview %v[%arg1, %arg1_p_arg2, %c1, %arg1, %arg1_p_arg2_affine, %c1] : !linalg.view + // Expected %arg2 but can't fold affine.apply with addi. + %sv0 = linalg.dim %sv, 0 : !linalg.view + // Expected %arg2. + %sv1 = linalg.dim %sv, 1 : !linalg.view + + return %v0, %v1, %sv0, %sv1 : index, index, index, index +} + +// CHECK-LABEL: fold_indices +// CHECK: (%[[arg0:.*]]: !linalg.buffer, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index +// CHECK: %[[r0:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]] +// CHECK: %[[r1:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]] +// CHECK: %[[add:.*]] = addi %[[arg1]], %[[arg2]] : index +// CHECK: %[[aff:.*]] = affine.apply #[[SUB]]()[%[[add]], %[[arg1]]] +// CHECK: return %[[r0]], %[[r1]], %[[aff]], %[[arg2]] \ No newline at end of file -- 2.7.4