[mlir] Add AffineMaxOp
authorOuHangKresnik <ouhang.kresnik.zero@gmail.com>
Thu, 6 Feb 2020 09:25:55 +0000 (10:25 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 6 Feb 2020 09:26:50 +0000 (10:26 +0100)
Differential Revision: https://reviews.llvm.org/D73848

mlir/include/mlir/Dialect/AffineOps/AffineOps.td
mlir/include/mlir/EDSC/Intrinsics.h
mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/test/AffineOps/canonicalize.mlir
mlir/test/AffineOps/invalid.mlir
mlir/test/AffineOps/ops.mlir
mlir/test/Transforms/lower-affine.mlir

index 3168cc2..6b55261 100644 (file)
@@ -234,7 +234,20 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
   let hasFolder = 1;
 }
 
-def AffineMinOp : Affine_Op<"min"> {
+class AffineMinMaxOpBase<string mnemonic, list<OpTrait> traits = []> :
+    Op<Affine_Dialect, mnemonic, traits> {
+  let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
+  let results = (outs Index);
+  let extraClassDeclaration = [{
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+  let verifier = [{ return ::verifyAffineMinMaxOp(*this); }];
+  let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
+  let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
+  let hasFolder = 1;
+}
+
+def AffineMinOp : AffineMinMaxOpBase<"min"> {
   let summary = "min operation";
   let description = [{
     The "min" operation computes the minimum value result from a multi-result
@@ -244,12 +257,18 @@ def AffineMinOp : Affine_Op<"min"> {
 
        %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) : index
   }];
-  let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
-  let results = (outs Index);
-  let extraClassDeclaration = [{
-    static StringRef getMapAttrName() { return "map"; }
+}
+
+def AffineMaxOp : AffineMinMaxOpBase<"max"> {
+  let summary = "max operation";
+  let description = [{
+    The "max" operation computes the maximum value result from a multi-result
+    affine map.
+
+    Example:
+
+       %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) : index
   }];
-  let hasFolder = 1;
 }
 
 def AffinePrefetchOp : Affine_Op<"prefetch"> {
index 4f1c8d8..d800f36 100644 (file)
@@ -196,6 +196,7 @@ using affine_apply = ValueBuilder<AffineApplyOp>;
 using affine_if = OperationBuilder<AffineIfOp>;
 using affine_load = ValueBuilder<AffineLoadOp>;
 using affine_min = ValueBuilder<AffineMinOp>;
+using affine_max = ValueBuilder<AffineMaxOp>;
 using affine_store = OperationBuilder<AffineStoreOp>;
 using alloc = ValueBuilder<AllocOp>;
 using call = OperationBuilder<mlir::CallOp>;
index 601bfe7..d4b5fce 100644 (file)
@@ -258,16 +258,13 @@ static Value buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
   return value;
 }
 
-/// Emit instructions that correspond to the affine map in the lower bound
-/// applied to the respective operands, and compute the maximum value across
-/// the results.
-Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
-  auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(),
-                                  op.getLowerBoundOperands());
-  if (!lbValues)
-    return nullptr;
-  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::sgt, *lbValues,
-                                 builder);
+/// Emit instructions that correspond to computing the maximum value amoung the
+/// values of a (potentially) multi-output affine map applied to `operands`.
+static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
+                               ValueRange operands) {
+  if (auto values = expandAffineMap(builder, loc, map, operands))
+    return buildMinMaxReductionSeq(loc, CmpIPredicate::sgt, *values, builder);
+  return nullptr;
 }
 
 /// Emit instructions that correspond to computing the minimum value amoung the
@@ -287,6 +284,14 @@ Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
                            op.getUpperBoundOperands());
 }
 
+/// Emit instructions that correspond to the affine map in the lower bound
+/// applied to the respective operands, and compute the maximum value across
+/// the results.
+Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
+  return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
+                           op.getLowerBoundOperands());
+}
+
 namespace {
 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
 public:
@@ -304,6 +309,22 @@ public:
   }
 };
 
+class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
+public:
+  using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineMaxOp op,
+                                     PatternRewriter &rewriter) const override {
+    Value reduced =
+        lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands());
+    if (!reduced)
+      return matchFailure();
+
+    rewriter.replaceOp(op, reduced);
+    return matchSuccess();
+  }
+};
+
 /// Affine terminators are removed.
 class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
 public:
@@ -546,6 +567,7 @@ void mlir::populateAffineToStdConversionPatterns(
       AffineDmaWaitLowering,
       AffineLoadLowering,
       AffineMinLowering,
+      AffineMaxLowering,
       AffinePrefetchLowering,
       AffineStoreLowering,
       AffineForLowering,
index 224049e..c51faae 100644 (file)
@@ -1935,22 +1935,41 @@ LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
 }
 
 //===----------------------------------------------------------------------===//
-// AffineMinOp
+// AffineMinMaxOpBase
 //===----------------------------------------------------------------------===//
-//
-//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
-//
 
-static ParseResult parseAffineMinOp(OpAsmParser &parser,
-                                    OperationState &result) {
+template <typename T>
+static LogicalResult verifyAffineMinMaxOp(T op) {
+  // Verify that operand count matches affine map dimension and symbol count.
+  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
+    return op.emitOpError(
+        "operand count and affine map dimension and symbol count must match");
+  return success();
+}
+
+template <typename T>
+static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
+  p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName());
+  auto operands = op.getOperands();
+  unsigned numDims = op.map().getNumDims();
+  p << '(' << operands.take_front(numDims) << ')';
+
+  if (operands.size() != numDims)
+    p << '[' << operands.drop_front(numDims) << ']';
+  p.printOptionalAttrDict(op.getAttrs(),
+                          /*elidedAttrs=*/{T::getMapAttrName()});
+}
+
+template <typename T>
+static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
+                                       OperationState &result) {
   auto &builder = parser.getBuilder();
   auto indexType = builder.getIndexType();
   SmallVector<OpAsmParser::OperandType, 8> dim_infos;
   SmallVector<OpAsmParser::OperandType, 8> sym_infos;
   AffineMapAttr mapAttr;
   return failure(
-      parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
-                            result.attributes) ||
+      parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
       parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
       parser.parseOperandList(sym_infos,
                               OpAsmParser::Delimiter::OptionalSquare) ||
@@ -1960,25 +1979,12 @@ static ParseResult parseAffineMinOp(OpAsmParser &parser,
       parser.addTypeToList(indexType, result.types));
 }
 
-static void print(OpAsmPrinter &p, AffineMinOp op) {
-  p << op.getOperationName() << ' '
-    << op.getAttr(AffineMinOp::getMapAttrName());
-  auto operands = op.getOperands();
-  unsigned numDims = op.map().getNumDims();
-  p << '(' << operands.take_front(numDims) << ')';
-
-  if (operands.size() != numDims)
-    p << '[' << operands.drop_front(numDims) << ']';
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
-}
-
-static LogicalResult verify(AffineMinOp op) {
-  // Verify that operand count matches affine map dimension and symbol count.
-  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
-    return op.emitOpError(
-        "operand count and affine map dimension and symbol count must match");
-  return success();
-}
+//===----------------------------------------------------------------------===//
+// AffineMinOp
+//===----------------------------------------------------------------------===//
+//
+//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
+//
 
 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
   // Fold the affine map.
@@ -2004,6 +2010,36 @@ OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// AffineMaxOp
+//===----------------------------------------------------------------------===//
+//
+//   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
+//
+
+OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
+  // Fold the affine map.
+  // TODO(andydavis, ntv, ouhang) Fold more cases: partial static information,
+  // max(some_affine, some_affine + constant, ...).
+  SmallVector<Attribute, 2> results;
+  if (failed(map().constantFold(operands, results)))
+    return {};
+
+  // Compute and return max of folded map results.
+  int64_t max = std::numeric_limits<int64_t>::min();
+  int maxIndex = -1;
+  for (unsigned i = 0, e = results.size(); i < e; ++i) {
+    auto intAttr = results[i].cast<IntegerAttr>();
+    if (intAttr.getInt() > max) {
+      max = intAttr.getInt();
+      maxIndex = i;
+    }
+  }
+  if (maxIndex < 0)
+    return {};
+  return results[maxIndex];
+}
+
+//===----------------------------------------------------------------------===//
 // AffinePrefetchOp
 //===----------------------------------------------------------------------===//
 
index 47477fd..5662693 100644 (file)
@@ -526,3 +526,29 @@ func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
   // CHECK-NEXT:  return
   return
 }
+
+// -----
+
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c511 = constant 511 : index
+  %c1 = constant 0 : index
+  %0 = affine.max affine_map<(d0)[s0] -> (1000, d0 + 512, s0 + 1)> (%c1)[%c511]
+  "op0"(%0) : (index) -> ()
+  // CHECK:       %[[CST:.*]] = constant 1000 : index
+  // CHECK-NEXT:  "op0"(%[[CST]]) : (index) -> ()
+  // CHECK-NEXT:  return
+  return
+}
+
+// -----
+
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c3 = constant 3 : index
+  %c20 = constant 20 : index
+  %0 = affine.max affine_map<(d0)[s0] -> (1000, d0 floordiv 4, (s0 mod 5) + 1)> (%c20)[%c3]
+  "op0"(%0) : (index) -> ()
+  // CHECK:       %[[CST:.*]] = constant 1000 : index
+  // CHECK-NEXT:  "op0"(%[[CST]]) : (index) -> ()
+  // CHECK-NEXT:  return
+  return
+}
\ No newline at end of file
index b36af2c..0a9fe86 100644 (file)
@@ -168,3 +168,33 @@ func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.max affine_map<(d0) -> (d0)> (%arg0, %arg1)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.max affine_map<()[s0] -> (s0)> (%arg0, %arg1)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.max affine_map<(d0) -> (d0)> ()
+
+  return
+}
\ No newline at end of file
index 48b21fa..c038524 100644 (file)
@@ -79,6 +79,19 @@ func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
   return
 }
 
+// CHECK-LABEL: @affine_max
+func @affine_max(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: affine.max #[[MAP0]](%arg0)[%arg1]
+  %0 = affine.max affine_map<(d0)[s0] -> (1000, d0 + 512, s0)> (%arg0)[%arg1]
+  // CHECK: affine.max #[[MAP1]](%arg0, %arg1)[%arg2]
+  %1 = affine.max affine_map<(d0, d1)[s0] -> (d0 - d1, s0 + 512)> (%arg0, %arg1)[%arg2]
+  // CHECK: affine.max #[[MAP2]]()[%arg1, %arg2]
+  %2 = affine.max affine_map<()[s0, s1] -> (s0 - s1, 11)> ()[%arg1, %arg2]
+  // CHECK: affine.max #[[MAP3]]()
+  %3 = affine.max affine_map<()[] -> (77, 78, 79)> ()[]
+  return
+}
+
 // -----
 
 func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
index 3eb12ed..18b410a 100644 (file)
@@ -605,3 +605,18 @@ func @affine_min(%arg0: index, %arg1: index) -> index{
   %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
   return %0 : index
 }
+
+// CHECK-LABEL: func @affine_max
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func @affine_max(%arg0: index, %arg1: index) -> index{
+  // CHECK: %[[Cm1:.*]] = constant -1
+  // CHECK: %[[neg1:.*]] = muli %[[ARG1]], %[[Cm1:.*]]
+  // CHECK: %[[first:.*]] = addi %[[ARG0]], %[[neg1]]
+  // CHECK: %[[Cm2:.*]] = constant -1
+  // CHECK: %[[neg2:.*]] = muli %[[ARG0]], %[[Cm2:.*]]
+  // CHECK: %[[second:.*]] = addi %[[ARG1]], %[[neg2]]
+  // CHECK: %[[cmp:.*]] = cmpi "sgt", %[[first]], %[[second]]
+  // CHECK: select %[[cmp]], %[[first]], %[[second]]
+  %0 = affine.max affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
+  return %0 : index
+}
\ No newline at end of file