[mlir] Add offset/stride helper functions to OffsetSizeAndStrideOpInterface
authorMatthias Springer <springerm@google.com>
Mon, 7 Jun 2021 11:05:25 +0000 (20:05 +0900)
committerMatthias Springer <springerm@google.com>
Mon, 7 Jun 2021 11:11:41 +0000 (20:11 +0900)
* Add hasUnitStride and hasZeroOffset to OffsetSizeAndStrideOpInterface. These functions are useful for various patterns. E.g., some vectorization patterns apply only for tensor ops with zero offsets and/or unit stride.
* Add getConstantIntValue and isEqualConstantInt helper functions, which are useful for implementing the two above functions, as well as various patterns.

Differential Revision: https://reviews.llvm.org/D103763

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
mlir/include/mlir/Interfaces/ViewLikeInterface.h
mlir/include/mlir/Interfaces/ViewLikeInterface.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp

index 65cac17..ee1cc67 100644 (file)
@@ -122,6 +122,15 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
 
+/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
+/// IntegerAttr, return the integer.
+llvm::Optional<int64_t> getConstantIntValue(OpFoldResult ofr);
+
+/// Return true if ofr and value are the same integer.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
 /// Ignore integer bitwitdh and type mismatch that come from the fact there is
index 0094fff..8d58570 100644 (file)
@@ -30,6 +30,8 @@ struct Range {
 
 class OffsetSizeAndStrideOpInterface;
 
+bool isEqualConstantInt(OpFoldResult ofr, int64_t value);
+
 namespace detail {
 LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op);
 
index e26a02f..62f24f2 100644 (file)
@@ -436,6 +436,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface
             $_op.getOperation()), other, cmp);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{ Return true if all strides are guaranteed to be 1. }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasUnitStride",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) {
+          return ::mlir::isEqualConstantInt(ofr, 1);
+        });
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Return true if all offsets are guaranteed to be 0. }],
+      /*retTy=*/"bool",
+      /*methodName=*/"hasZeroOffset",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) {
+          return ::mlir::isEqualConstantInt(ofr, 0);
+        });
+      }]
+    >,
   ];
 
   let extraClassDeclaration = [{
index f115fd0..a3c2513 100644 (file)
@@ -60,24 +60,35 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
 }
 
+/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an
+/// IntegerAttr, return the integer.
+llvm::Optional<int64_t> mlir::getConstantIntValue(OpFoldResult ofr) {
+  Attribute attr = ofr.dyn_cast<Attribute>();
+  // Note: isa+cast-like pattern allows writing the condition below as 1 line.
+  if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+    attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+  if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+    return intAttr.getValue().getSExtValue();
+  return llvm::None;
+}
+
+/// Return true if ofr and value are the same integer.
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) {
+  auto ofrValue = getConstantIntValue(ofr);
+  return ofrValue && *ofrValue == value;
+}
+
 /// Return true if ofr1 and ofr2 are the same integer constant attribute values
 /// or the same SSA value.
-/// Ignore integer bitwitdh and type mismatch that come from the fact there is
-/// no IndexAttr and that IndexType have no bitwidth.
-bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
-  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
-    Attribute attr = ofr.dyn_cast<Attribute>();
-    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
-    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
-      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
-    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
-      return intAttr.getValue().getSExtValue();
-    return llvm::None;
-  };
-  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+/// Ignore integer bitwidth and type mismatch that come from the fact there is
+/// no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) {
+  auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2);
   if (cst1 && cst2 && *cst1 == *cst2)
     return true;
-  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+  auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
   return v1 && v2 && v1 == v2;
 }