From 0813700de1af72173ad18202fcbd3eafce90d184 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 27 Jun 2021 15:15:44 +0900 Subject: [PATCH] [mlir][NFC] Cleanup: Move helper functions to StaticValueUtils Reduce code duplication: Move various helper functions, that are duplicated in TensorDialect, MemRefDialect, LinalgDialect, StandardDialect, into a new StaticValueUtils.cpp. Differential Revision: https://reviews.llvm.org/D104687 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td | 4 +- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 1 + mlir/include/mlir/Dialect/StandardOps/IR/Ops.h | 15 ---- mlir/include/mlir/Dialect/Utils/StaticValueUtils.h | 58 ++++++++++++++++ mlir/include/mlir/Interfaces/ViewLikeInterface.h | 3 +- mlir/include/mlir/Interfaces/ViewLikeInterface.td | 4 +- .../Conversion/StandardToLLVM/StandardToLLVM.cpp | 9 +-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 27 +------- .../Linalg/Transforms/ComprehensiveBufferize.cpp | 1 + .../Dialect/Linalg/Transforms/Vectorization.cpp | 10 +-- mlir/lib/Dialect/MemRef/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 35 +--------- mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 32 --------- mlir/lib/Dialect/Tensor/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 42 ++---------- mlir/lib/Dialect/Utils/CMakeLists.txt | 1 + mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 79 ++++++++++++++++++++++ 17 files changed, 159 insertions(+), 164 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Utils/StaticValueUtils.h create mode 100644 mlir/lib/Dialect/Utils/StaticValueUtils.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index df568d6..ffd65f7 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -269,13 +269,13 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", // Return true if low padding is guaranteed to be 0. bool hasZeroLowPad() { return llvm::all_of(getMixedLowPad(), [](OpFoldResult ofr) { - return mlir::isEqualConstantInt(ofr, 0); + return getConstantIntValue(ofr) == static_cast(0); }); } // Return true if high padding is guaranteed to be 0. bool hasZeroHighPad() { return llvm::all_of(getMixedHighPad(), [](OpFoldResult ofr) { - return mlir::isEqualConstantInt(ofr, 0); + return getConstantIntValue(ofr) == static_cast(0); }); } }]; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index f6b78ae..5f533df 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/PatternMatch.h" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h index bff62c7..477474b 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -114,21 +114,6 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); -/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an -/// IntegerAttr, return the integer. -llvm::Optional getConstantIntValue(OpFoldResult ofr); - -/// Return true if ofr and value are the same integer. -/// Ignore integer bitwidth and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType has no bitwidth. -bool isEqualConstantInt(OpFoldResult ofr, int64_t value); - -/// Return true if ofr1 and ofr2 are the same integer constant attribute values -/// or the same SSA value. -/// Ignore integer bitwitdh and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType have no bitwidth. -bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); - /// Returns the identity value attribute associated with an AtomicRMWKind op. Attribute getIdentityValueAttr(AtomicRMWKind kind, Type resultType, OpBuilder &builder, Location loc); diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h new file mode 100644 index 0000000..3284c02 --- /dev/null +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -0,0 +1,58 @@ +//===- StaticValueUtils.h - Utilities for static values ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines utilities for dealing with static values, e.g., +// converting back and forth between Value and OpFoldResult. Such functionality +// is used in multiple dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_UTILS_STATICVALUEUTILS_H +#define MLIR_DIALECT_UTILS_STATICVALUEUTILS_H + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if +/// it is a Value or into `staticVec` if it is an IntegerAttr. +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. +void dispatchIndexOpFoldResult(OpFoldResult ofr, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel); + +/// Helper function to dispatch multiple OpFoldResults into either the +/// `dynamicVec` (for Values) or into `staticVec` (for IntegerAttrs). +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. +void dispatchIndexOpFoldResults(ArrayRef ofrs, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel); + +/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. +SmallVector extractFromI64ArrayAttr(Attribute attr); + +/// If ofr is a constant integer or an IntegerAttr, return the integer. +Optional getConstantIntValue(OpFoldResult ofr); + +/// Return true if ofr1 and ofr2 are the same integer constant attribute values +/// or the same SSA value. +/// Ignore integer bitwitdh and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType have no bitwidth. +bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2); + +} // namespace mlir + +#endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h index 8d58570..7df3d1e 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -13,6 +13,7 @@ #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -30,8 +31,6 @@ struct Range { class OffsetSizeAndStrideOpInterface; -bool isEqualConstantInt(OpFoldResult ofr, int64_t value); - namespace detail { LogicalResult verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op); diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td index 62f24f2..2ba9038 100644 --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -444,7 +444,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*methodBody=*/"", /*defaultImplementation=*/[{ return ::llvm::all_of(getMixedStrides(), [](OpFoldResult ofr) { - return ::mlir::isEqualConstantInt(ofr, 1); + return ::mlir::getConstantIntValue(ofr) == static_cast(1); }); }] >, @@ -456,7 +456,7 @@ def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface /*methodBody=*/"", /*defaultImplementation=*/[{ return ::llvm::all_of(getMixedOffsets(), [](OpFoldResult ofr) { - return ::mlir::isEqualConstantInt(ofr, 0); + return ::mlir::getConstantIntValue(ofr) == static_cast(0); }); }] >, diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 8e808d7..db5918e 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" @@ -3388,14 +3389,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern { } }; -/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - /// Conversion pattern that transforms a subview op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 9a1ceeb..109a1c6 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" @@ -116,24 +117,6 @@ static SmallVector getAsValues(OpBuilder &b, Location loc, })); } -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. -static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - /// This is a common class used for patterns of the form /// ``` /// someop(memrefcast(%src)) -> someop(%src) @@ -819,14 +802,6 @@ LogicalResult InitTensorOp::reifyReturnTypeShapesPerResultDim( // PadTensorOp //===----------------------------------------------------------------------===// -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - static LogicalResult verify(PadTensorOp op) { auto sourceType = op.source().getType().cast(); auto resultType = op.result().getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp index d02570a..c951e70 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -110,6 +110,7 @@ #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 829b988..92382a6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -814,8 +814,8 @@ struct GenericPadTensorOpVectorizationPattern readInBounds.push_back(false); // Write is out-of-bounds if low padding > 0. writeInBounds.push_back( - isEqualConstantIntOrValue(padOp.getMixedLowPad()[i], - rewriter.getIndexAttr(0))); + getConstantIntValue(padOp.getMixedLowPad()[i]) == + static_cast(0)); } else { // Neither source nor result dim of padOp is static. Cannot vectorize // the copy. @@ -1098,9 +1098,9 @@ struct PadTensorOpVectorizationWithInsertSlicePattern SmallVector expectedSizes(tensorRank - vecRank, 1); expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end()); if (!llvm::all_of( - llvm::zip(insertOp.getMixedSizes(), expectedSizes), - [](auto it) { return isEqualConstantInt(std::get<0>(it), - std::get<1>(it)); })) + llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) { + return getConstantIntValue(std::get<0>(it)) == std::get<1>(it); + })) return failure(); // Generate TransferReadOp: Read entire source tensor and add high padding. diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt index 6ac47b1..6f9aeaa 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRMemRef LINK_LIBS PUBLIC MLIRDialect + MLIRDialectUtils MLIRInferTypeOpInterface MLIRIR MLIRMemRefUtils diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 8d00357..cc4e7a4 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -32,40 +33,6 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, type, value); } -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. -static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - -static void dispatchIndexOpFoldResults(ArrayRef ofrs, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - for (auto ofr : ofrs) - dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); -} - //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 73c0c4a..837986f 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -33,38 +33,6 @@ using namespace mlir; -/// If ofr is a constant integer, i.e., an IntegerAttr or a ConstantOp with an -/// IntegerAttr, return the integer. -llvm::Optional mlir::getConstantIntValue(OpFoldResult ofr) { - Attribute attr = ofr.dyn_cast(); - // Note: isa+cast-like pattern allows writing the condition below as 1 line. - if (!attr && ofr.get().getDefiningOp()) - attr = ofr.get().getDefiningOp().getValue(); - if (auto intAttr = attr.dyn_cast_or_null()) - return intAttr.getValue().getSExtValue(); - return llvm::None; -} - -/// Return true if ofr and value are the same integer. -/// Ignore integer bitwidth and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType has no bitwidth. -bool mlir::isEqualConstantInt(OpFoldResult ofr, int64_t value) { - auto ofrValue = getConstantIntValue(ofr); - return ofrValue && *ofrValue == value; -} - -/// Return true if ofr1 and ofr2 are the same integer constant attribute values -/// or the same SSA value. -/// Ignore integer bitwidth and type mismatch that come from the fact there is -/// no IndexAttr and that IndexType has no bitwidth. -bool mlir::isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { - auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); - if (cst1 && cst2 && *cst1 == *cst2) - return true; - auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); - return v1 && v2 && v1 == v2; -} - //===----------------------------------------------------------------------===// // StandardOpsDialect Interfaces //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt index e1fad1b..4b6886e 100644 --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRTensor LINK_LIBS PUBLIC MLIRCastInterfaces + MLIRDialectUtils MLIRIR MLIRSideEffectInterfaces MLIRSupport diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 8a4b212..28a5f5d 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" @@ -516,32 +517,6 @@ static LogicalResult verify(ReshapeOp op) { // ExtractSliceOp //===----------------------------------------------------------------------===// -/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if -/// it is a Value or into `staticVec` if it is an IntegerAttr. -/// In the case of a Value, a copy of the `sentinel` value is also pushed to -/// `staticVec`. This is useful to extract mixed static and dynamic entries that -/// come from an AttrSizedOperandSegments trait. -static void dispatchIndexOpFoldResult(OpFoldResult ofr, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - if (auto v = ofr.dyn_cast()) { - dynamicVec.push_back(v); - staticVec.push_back(sentinel); - return; - } - APInt apInt = ofr.dyn_cast().cast().getValue(); - staticVec.push_back(apInt.getSExtValue()); -} - -static void dispatchIndexOpFoldResults(ArrayRef ofrs, - SmallVectorImpl &dynamicVec, - SmallVectorImpl &staticVec, - int64_t sentinel) { - for (auto ofr : ofrs) - dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); -} - /// An extract_slice op result type can be fully inferred from the source type /// and the static representation of offsets, sizes and strides. Special /// sentinels encode the dynamic case. @@ -563,14 +538,6 @@ Type ExtractSliceOp::inferResultType(RankedTensorType sourceRankedTensorType, sourceRankedTensorType.getElementType()); } -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - Type ExtractSliceOp::inferResultType( RankedTensorType sourceRankedTensorType, ArrayRef leadingStaticOffsets, @@ -890,17 +857,16 @@ foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType) { OpBuilder b(op.getContext()); for (OpFoldResult ofr : op.getMixedOffsets()) - if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0))) + if (getConstantIntValue(ofr) != static_cast(0)) return failure(); // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip // is appropriate. auto shape = shapedType.getShape(); for (auto it : llvm::zip(op.getMixedSizes(), shape)) - if (!isEqualConstantIntOrValue(std::get<0>(it), - b.getIndexAttr(std::get<1>(it)))) + if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it)) return failure(); for (OpFoldResult ofr : op.getMixedStrides()) - if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1))) + if (getConstantIntValue(ofr) != static_cast(1)) return failure(); return success(); } diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt index a640e35..098b6b4 100644 --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(MLIRDialectUtils StructuredOpsUtils.cpp + StaticValueUtils.cpp LINK_LIBS PUBLIC MLIRIR diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp new file mode 100644 index 0000000..bf7d662 --- /dev/null +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -0,0 +1,79 @@ +//===- StaticValueUtils.cpp - Utilities for dealing with static values ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/APSInt.h" + +namespace mlir { + +/// Helper function to dispatch an OpFoldResult into either the `dynamicVec` if +/// it is a Value or into `staticVec` if it is an IntegerAttr. +/// In the case of a Value, a copy of the `sentinel` value is also pushed to +/// `staticVec`. This is useful to extract mixed static and dynamic entries that +/// come from an AttrSizedOperandSegments trait. +void dispatchIndexOpFoldResult(OpFoldResult ofr, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel) { + if (auto v = ofr.dyn_cast()) { + dynamicVec.push_back(v); + staticVec.push_back(sentinel); + return; + } + APInt apInt = ofr.dyn_cast().cast().getValue(); + staticVec.push_back(apInt.getSExtValue()); +} + +void dispatchIndexOpFoldResults(ArrayRef ofrs, + SmallVectorImpl &dynamicVec, + SmallVectorImpl &staticVec, + int64_t sentinel) { + for (OpFoldResult ofr : ofrs) + dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel); +} + +/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. +SmallVector extractFromI64ArrayAttr(Attribute attr) { + return llvm::to_vector<4>( + llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); +} + +/// If ofr is a constant integer or an IntegerAttr, return the integer. +Optional getConstantIntValue(OpFoldResult ofr) { + // Case 1: Check for Constant integer. + if (auto val = ofr.dyn_cast()) { + APSInt intVal; + if (matchPattern(val, m_ConstantInt(&intVal))) + return intVal.getSExtValue(); + return llvm::None; + } + // Case 2: Check for IntegerAttr. + Attribute attr = ofr.dyn_cast(); + if (auto intAttr = attr.dyn_cast_or_null()) + return intAttr.getValue().getSExtValue(); + return llvm::None; +} + +/// Return true if ofr1 and ofr2 are the same integer constant attribute values +/// or the same SSA value. +/// Ignore integer bitwidth and type mismatch that come from the fact there is +/// no IndexAttr and that IndexType has no bitwidth. +bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2) { + auto cst1 = getConstantIntValue(ofr1), cst2 = getConstantIntValue(ofr2); + if (cst1 && cst2 && *cst1 == *cst2) + return true; + auto v1 = ofr1.dyn_cast(), v2 = ofr2.dyn_cast(); + return v1 && v1 == v2; +} + +} // namespace mlir + -- 2.7.4