Adds affine.min operation which returns the minimum value from a multi-result affine...
authorAndy Davis <andydavis@google.com>
Tue, 12 Nov 2019 15:08:23 +0000 (07:08 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 15:08:49 +0000 (07:08 -0800)
PiperOrigin-RevId: 279959714

mlir/g3doc/Dialects/Affine.md
mlir/include/mlir/Dialect/AffineOps/AffineOps.td
mlir/lib/Dialect/AffineOps/AffineOps.cpp
mlir/test/AffineOps/canonicalize.mlir
mlir/test/AffineOps/invalid.mlir
mlir/test/AffineOps/ops.mlir

index 541daf8..6049457 100644 (file)
@@ -560,6 +560,29 @@ Example:
 
 ```
 
+#### 'affine.min' operation
+
+Syntax:
+
+``` {.ebnf}
+operation ::= ssa-id `=` `affine.min` affine-map dim-and-symbol-use-list
+```
+
+The `affine.min` operation applies an
+[affine mapping](#affine-expressions) to a list of SSA values, and returns the
+minimum value of all result expressions. The number of dimension and symbol
+arguments to affine.min must be equal to the respective number of dimensional
+and symbolic inputs to the affine mapping; the `affine.min` operation always
+returns one value. The input operands and result must all have 'index' type.
+
+Example:
+
+```mlir {.mlir}
+
+%0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1]
+
+```
+
 #### `affine.terminator` operation
 
 Syntax:
index f54c514..1b6c777 100644 (file)
@@ -248,6 +248,24 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> {
   let hasCanonicalizer = 1;
 }
 
+def AffineMinOp : Affine_Op<"min"> {
+  let summary = "min operation";
+  let description = [{
+    The "min" operation computes the minimum value result from a multi-result
+    affine map.
+
+    Example:
+
+       %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) : index
+  }];
+  let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
+  let results = (outs Index);
+  let extraClassDeclaration = [{
+    static StringRef getMapAttrName() { return "map"; }
+  }];
+  let hasFolder = 1;
+}
+
 def AffineTerminatorOp :
     Affine_Op<"terminator", [Terminator]> {
   let summary = "affine terminator operation";
index acec1dd..77ee9cf 100644 (file)
@@ -1937,5 +1937,80 @@ void AffineStoreOp::getCanonicalizationPatterns(
   results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// AffineMinOp
+//===----------------------------------------------------------------------===//
+//
+//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
+//
+
+static ParseResult parseAffineMinOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  auto &builder = parser.getBuilder();
+  auto indexType = builder.getIndexType();
+  SmallVector<OpAsmParser::OperandType, 8> dim_infos;
+  SmallVector<OpAsmParser::OperandType, 8> sym_infos;
+  AffineMapAttr mapAttr;
+  return failure(
+      parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
+                            result.attributes) ||
+      parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOperandList(sym_infos,
+                              OpAsmParser::Delimiter::OptionalSquare) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.resolveOperands(dim_infos, indexType, result.operands) ||
+      parser.resolveOperands(sym_infos, indexType, result.operands) ||
+      parser.addTypeToList(indexType, result.types));
+}
+
+static void print(OpAsmPrinter &p, AffineMinOp op) {
+  p << op.getOperationName() << ' '
+    << op.getAttr(AffineMinOp::getMapAttrName());
+  auto begin = op.operand_begin();
+  auto end = op.operand_end();
+  unsigned numDims = op.map().getNumDims();
+  p << '(';
+  p.printOperands(begin, begin + numDims);
+  p << ')';
+
+  if (begin + numDims != end) {
+    p << '[';
+    p.printOperands(begin + numDims, end);
+    p << ']';
+  }
+  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
+}
+
+static LogicalResult verify(AffineMinOp op) {
+  // Verify that operand count matches affine map dimension and symbol count.
+  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
+    return op.emitOpError(
+        "operand count and affine map dimension and symbol count must match");
+  return success();
+}
+
+OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
+  // Fold the affine map.
+  // TODO(andydavis, ntv) Fold more cases: partial static information,
+  // min(some_affine, some_affine + constant, ...).
+  SmallVector<Attribute, 2> results;
+  if (failed(map().constantFold(operands, results)))
+    return {};
+
+  // Compute and return min of folded map results.
+  int64_t min = std::numeric_limits<int64_t>::max();
+  int minIndex = -1;
+  for (unsigned i = 0, e = results.size(); i < e; ++i) {
+    auto intAttr = results[i].cast<IntegerAttr>();
+    if (intAttr.getInt() < min) {
+      min = intAttr.getInt();
+      minIndex = i;
+    }
+  }
+  if (minIndex < 0)
+    return {};
+  return results[minIndex];
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
index e99a794..98b78f1 100644 (file)
@@ -500,3 +500,29 @@ func @compose_into_affine_load_store(%A : memref<1024xf32>, %u : index) {
   }
   return
 }
+
+// -----
+
+func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c511 = constant 511 : index
+  %c1 = constant 0 : index
+  %0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0 + 1) (%c1)[%c511]
+  "op0"(%0) : (index) -> ()
+  // CHECK:       %[[CST:.*]] = constant 512 : index
+  // CHECK-NEXT:  "op0"(%[[CST]]) : (index) -> ()
+  // CHECK-NEXT:  return
+  return
+}
+
+// -----
+
+func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
+  %c3 = constant 3 : index
+  %c20 = constant 20 : index
+  %0 = affine.min (d0)[s0] -> (1000, d0 floordiv 4, (s0 mod 5) + 1) (%c20)[%c3]
+  "op0"(%0) : (index) -> ()
+  // CHECK:       %[[CST:.*]] = constant 4 : index
+  // CHECK-NEXT:  "op0"(%[[CST]]) : (index) -> ()
+  // CHECK-NEXT:  return
+  return
+}
index 5a8c242..390b2ae 100644 (file)
@@ -151,3 +151,33 @@ func @affine_store_missing_l_square(%C: memref<4096x4096xf32>) {
   affine.store %9, %C : memref<4096x4096xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @affine_min
+func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.min (d0) -> (d0) (%arg0, %arg1)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_min
+func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.min ()[s0] -> (s0) (%arg0, %arg1)
+
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @affine_min
+func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // expected-error@+1 {{operand count and affine map dimension and symbol count must match}}
+  %0 = affine.min (d0) -> (d0) ()
+
+  return
+}
index 69dd4b8..795bce9 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s | FileCheck %s
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
 // RUN: mlir-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
 
 // Check that the attributes for the affine operations are round-tripped.
@@ -58,3 +58,23 @@ func @affine_terminator() {
   }
   return
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:map[0-9]+]] = (d0)[s0] -> (1000, d0 + 512, s0)
+// CHECK-DAG: #[[MAP1:map[0-9]+]] = (d0, d1)[s0] -> (d0 - d1, s0 + 512)
+// CHECK-DAG: #[[MAP2:map[0-9]+]] = ()[s0, s1] -> (s0 - s1, 11)
+// CHECK-DAG: #[[MAP3:map[0-9]+]] = () -> (77, 78, 79)
+
+// CHECK-LABEL: @affine_min
+func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
+  // CHECK: affine.min #[[MAP0]](%arg0)[%arg1]
+  %0 = affine.min (d0)[s0] -> (1000, d0 + 512, s0) (%arg0)[%arg1]
+  // CHECK: affine.min #[[MAP1]](%arg0, %arg1)[%arg2]
+  %1 = affine.min (d0, d1)[s0] -> (d0 - d1, s0 + 512) (%arg0, %arg1)[%arg2]
+  // CHECK: affine.min #[[MAP2]]()[%arg1, %arg2]
+  %2 = affine.min ()[s0, s1] -> (s0 - s1, 11) ()[%arg1, %arg2]
+  // CHECK: affine.min #[[MAP3]]()
+  %3 = affine.min ()[] -> (77, 78, 79) ()[]
+  return
+}