From: Matthias Springer Date: Mon, 7 Jun 2021 11:05:25 +0000 (+0900) Subject: [mlir] Add offset/stride helper functions to OffsetSizeAndStrideOpInterface X-Git-Tag: llvmorg-14-init~4703 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6e7bbdd6e7f7649bccc4f981520ed916e21d7058;p=platform%2Fupstream%2Fllvm.git [mlir] Add offset/stride helper functions to OffsetSizeAndStrideOpInterface * Add hasUnitStride and hasZeroOffset to OffsetSizeAndStrideOpInterface. These functions are useful for various patterns. E.g., some vectorization patterns apply only for tensor ops with zero offsets and/or unit stride. * Add getConstantIntValue and isEqualConstantInt helper functions, which are useful for implementing the two above functions, as well as various patterns. Differential Revision: https://reviews.llvm.org/D103763 --- diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h index 65cac17..ee1cc67 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -122,6 +122,15 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); +/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an +/// IntegerAttr, return the integer. +llvm::Optional getConstantIntValue(OpFoldResult ofr); + +/// Return true if ofr and value are the same integer. +/// Ignore integer bitwidth and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType has no bitwidth. +bool isEqualConstantInt(OpFoldResult ofr, int64_t value); + /// Return true if ofr1 and ofr2 are the same integer constant attribute values /// or the same SSA value. /// Ignore integer bitwitdh and type mismatch that come from the fact there is diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index 0094fff..8d58570 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -30,6 +30,8 @@ struct Range { class OffsetSizeAndStrideOpInterface; +bool isEqualConstantInt(OpFoldResult ofr, int64_t value); + namespace detail { LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td index e26a02f..62f24f2 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -436,6 +436,30 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface $_op.getOperation()), other, cmp); }] >, + InterfaceMethod< + /*desc=*/[{ Return true if all strides are guaranteed to be 1. }], + /*retTy=*/"bool", + /*methodName=*/"hasUnitStride", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) { + return ::mlir::isEqualConstantInt(ofr, 1); + }); + }] + >, + InterfaceMethod< + /*desc=*/[{ Return true if all offsets are guaranteed to be 0. }], + /*retTy=*/"bool", + /*methodName=*/"hasZeroOffset", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) { + return ::mlir::isEqualConstantInt(ofr, 0); + }); + }] + >, ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index f115fd0..a3c2513 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -60,24 +60,35 @@ static void dispatchIndexOpFoldResults(ArrayRef ofrs, dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); } +/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an +/// IntegerAttr, return the integer. +llvm::Optional mlir::getConstantIntValue(OpFoldResult ofr) { + Attribute attr = ofr.dyn_cast(); + // Note: isa+cast-like pattern allows writing the condition below as 1 line. + if (!attr && ofr.get().getDefiningOp()) + attr = ofr.get().getDefiningOp().getValue(); + if (auto intAttr = attr.dyn_cast_or_null()) + return intAttr.getValue().getSExtValue(); + return llvm::None; +} + +/// Return true if ofr and value are the same integer. +/// Ignore integer bitwidth and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType has no bitwidth. +bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) { + auto ofrValue = getConstantIntValue(ofr); + return ofrValue && *ofrValue == value; +} + /// Return true if ofr1 and ofr2 are the same integer constant attribute values /// or the same SSA value. -/// Ignore integer bitwitdh and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType have no bitwidth. -bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) { - auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional { - Attribute attr = ofr.dyn_cast(); - // Note: isa+cast-like pattern allows writing the condition below as 1 line. - if (!attr && ofr.get().getDefiningOp()) - attr = ofr.get().getDefiningOp().getValue(); - if (auto intAttr = attr.dyn_cast_or_null()) - return intAttr.getValue().getSExtValue(); - return llvm::None; - }; - auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); +/// Ignore integer bitwidth and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType has no bitwidth. +bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { + auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); if (cst1 && cst2 && *cst1 == *cst2) return true; - auto v1 = op1.dyn_cast(), v2 = op2.dyn_cast(); + auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); return v1 && v2 && v1 == v2; }