#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#include "mlir/Linalg/Utils/Utils.h"
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;
+namespace {
+/// Fold constant dimensions into an alloc operation.
+struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> {
+ using OpRewritePattern<linalg::DimOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(linalg::DimOp dimOp,
+ PatternRewriter &rewriter) const override;
+};
+} // end namespace
+
+PatternMatchResult
+SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
+ PatternRewriter &rewriter) const {
+ auto *viewProducingOp = dimOp.view()->getDefiningOp();
+ auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp);
+ auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp);
+ auto view = dyn_cast_or_null<ViewOp>(viewProducingOp);
+ if (!subView && !slice && !view)
+ return matchFailure();
+
+ unsigned dim = dimOp.getIndex();
+ Value *min, *max, *step;
+ if (view) {
+ // Cannot traverse block arguments, fail.
+ if (isa<BlockArgument>(view.getIndexing(dim)))
+ return matchFailure();
+ // Record min, max, step for further processing.
+ auto range = cast<RangeOp>(view.getIndexing(dim)->getDefiningOp());
+ std::tie(min, max, step) =
+ std::make_tuple(range.min(), range.max(), range.step());
+ } else if (subView) {
+ // Record min, max, step for further processing.
+ auto range = subView.getRange(dim);
+ std::tie(min, max, step) =
+ std::make_tuple(range.min, range.max, range.step);
+ } else {
+ // Taking the dim of a slice must take a range (since other dims have been
+ // rank-reduced).
+ auto *rangeValue = slice.getRanges()[dim];
+ // Cannot traverse block arguments, fail.
+ if (isa<BlockArgument>(rangeValue))
+ return matchFailure();
+ auto range = cast<RangeOp>(rangeValue->getDefiningOp());
+ // Record min, max, step for further processing.
+ std::tie(min, max, step) =
+ std::make_tuple(range.min(), range.max(), range.step());
+ }
+
+ // Only support constant steps of 1 atm.
+ auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp());
+ if (!constant || constant.getValue() != 1)
+ return matchFailure();
+
+ // Circumvent affine constraints:
+ // emit an affine_apply when possible, otherwise emit a `subi`.
+ bool validAffineMin = isValidDim(min) || isValidSymbol(min) ||
+ isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp());
+ bool validAffineMax = isValidDim(max) || isValidSymbol(max) ||
+ isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp());
+
+ OpBuilder b(dimOp);
+ ScopedContext scope(b, dimOp.getLoc());
+ // Emit `subi`.
+ if (!validAffineMin || !validAffineMax) {
+ rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()});
+ return matchSuccess();
+ }
+
+ // Emit affine_apply.
+ using edsc::op::operator-;
+ rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)},
+ {dimOp.view()});
+ return matchSuccess();
+}
+
////////////////////////////////////////////////////////////////////////////////
// LoadOp.
////////////////////////////////////////////////////////////////////////////////
result->types));
}
+//===----------------------------------------------------------------------===//
+// DimOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::DimOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<SimplifyDimOp>(context);
+}
+
static void print(OpAsmPrinter *p, linalg::DimOp op) {
*p << op.getOperationName() << " " << *op.getOperand() << ", "
<< op.getIndex();
--- /dev/null
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// CHECK-DAG: #[[SUB:.*]] = ()[s0, s1] -> (s0 - s1)
+
+func @fold_constants(%arg0: !linalg.buffer<?xf32>) -> (index, index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c2 = constant 2 : index
+ %c3 = constant 3 : index
+ %c4 = constant 4 : index
+ %c5 = constant 5 : index
+ %R02 = linalg.range %c0:%c2:%c1 : !linalg.range
+ %R03 = linalg.range %c0:%c3:%c1 : !linalg.range
+ %R04 = linalg.range %c0:%c4:%c1 : !linalg.range
+ %R12 = linalg.range %c1:%c2:%c1 : !linalg.range
+ %R13 = linalg.range %c1:%c3:%c1 : !linalg.range
+ %R14 = linalg.range %c1:%c4:%c1 : !linalg.range
+
+ %v = linalg.view %arg0[%R02, %R14] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+ // Expected 2.
+ %v0 = linalg.dim %v, 0 : !linalg.view<?x?xf32>
+ // Expected 3.
+ %v1 = linalg.dim %v, 1 : !linalg.view<?x?xf32>
+
+ %s = linalg.slice %v[%c1, %R12] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
+ // Expected 1.
+ %s0 = linalg.dim %s, 0 : !linalg.view<?xf32>
+
+ %sv = linalg.subview %v[%v0, %v1, %c1, %c2, %c4, %c1] : !linalg.view<?x?xf32>
+ // Expected 1.
+ %sv0 = linalg.dim %sv, 0 : !linalg.view<?x?xf32>
+ // Expected 2.
+ %sv1 = linalg.dim %sv, 1 : !linalg.view<?x?xf32>
+
+ return %v0, %v1, %s0, %sv0, %sv1 : index, index, index, index, index
+}
+
+// CHECK-LABEL: fold_constants
+// CHECK-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK-DAG: %[[c2:.*]] = constant 2 : index
+// CHECK-DAG: %[[c3:.*]] = constant 3 : index
+// CHECK: return %[[c2]], %[[c3]], %[[c1]], %[[c1]], %[[c2]]
+
+
+func @fold_indices(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index, index) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %R = linalg.range %arg1:%arg3:%c1 : !linalg.range
+
+ %v = linalg.view %arg0[%R, %R] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+ // Expected %arg3 - %arg1.
+ %v0 = linalg.dim %v, 0 : !linalg.view<?x?xf32>
+ // Expected %arg3 - %arg1.
+ %v1 = linalg.dim %v, 1 : !linalg.view<?x?xf32>
+
+ %arg1_p_arg2 = addi %arg1, %arg2: index
+ %arg1_p_arg2_affine = affine.apply (i, j) -> (i + j) (%arg1, %arg2)
+ %sv = linalg.subview %v[%arg1, %arg1_p_arg2, %c1, %arg1, %arg1_p_arg2_affine, %c1] : !linalg.view<?x?xf32>
+ // Expected %arg2 but can't fold affine.apply with addi.
+ %sv0 = linalg.dim %sv, 0 : !linalg.view<?x?xf32>
+ // Expected %arg2.
+ %sv1 = linalg.dim %sv, 1 : !linalg.view<?x?xf32>
+
+ return %v0, %v1, %sv0, %sv1 : index, index, index, index
+}
+
+// CHECK-LABEL: fold_indices
+// CHECK: (%[[arg0:.*]]: !linalg.buffer<?xf32>, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index
+// CHECK: %[[r0:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]]
+// CHECK: %[[r1:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]]
+// CHECK: %[[add:.*]] = addi %[[arg1]], %[[arg2]] : index
+// CHECK: %[[aff:.*]] = affine.apply #[[SUB]]()[%[[add]], %[[arg1]]]
+// CHECK: return %[[r0]], %[[r1]], %[[aff]], %[[arg2]]
\ No newline at end of file