From bbe5bf1788b55e3c7020d50ee0fd5956f261cfec Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 14 May 2023 22:39:50 -0700 Subject: [PATCH] Cleanup uses of getAttrDictionary() in MLIR to use getDiscardableAttrDictionary() when possible This also speeds up some benchmarks in compiling simple fortan file by 2x! Fixes #62687 Differential Revision: https://reviews.llvm.org/D150540 --- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 15 +++++++------- mlir/include/mlir/IR/OpDefinition.h | 12 ++++++----- mlir/include/mlir/Transforms/DialectConversion.h | 11 +++++++--- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 7 +++++-- .../Dialect/GPU/Transforms/AsyncRegionRewriter.cpp | 8 ++++---- .../lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 4 +++- mlir/lib/Dialect/Shape/IR/Shape.cpp | 3 ++- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 24 +++++++++++----------- .../Tosa/Transforms/TosaDecomposeTransposeConv.cpp | 8 ++++---- .../Dialect/Tosa/Transforms/TosaInferShapes.cpp | 7 ++++--- mlir/lib/IR/OperationSupport.cpp | 12 +++++++---- mlir/lib/IR/Verifier.cpp | 2 +- mlir/lib/Interfaces/InferTypeOpInterface.cpp | 4 ++-- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 2 +- mlir/test/lib/IR/TestOperationEquals.cpp | 2 +- 15 files changed, 70 insertions(+), 51 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f17a2e5..1a362f6 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -147,11 +147,11 @@ public: ConversionPatternRewriter &rewriter) const final { if constexpr (SourceOp::hasProperties()) rewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary(), + OpAdaptor(operands, op->getDiscardableAttrDictionary(), cast(op).getProperties()), rewriter); - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), - rewriter); + rewrite(cast(op), + OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter); } LogicalResult match(Operation *op) const final { return match(cast(op)); @@ -161,12 +161,13 @@ public: ConversionPatternRewriter &rewriter) const final { if constexpr (SourceOp::hasProperties()) return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary(), + OpAdaptor(operands, + op->getDiscardableAttrDictionary(), cast(op).getProperties()), rewriter); - return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary()), - rewriter); + return matchAndRewrite( + cast(op), + OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index d08d3de..71864def 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1890,11 +1890,12 @@ private: if constexpr (has_fold_adaptor_single_result_v) { if constexpr (hasProperties()) { result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( - operands, op->getAttrDictionary(), + operands, op->getDiscardableAttrDictionary(), cast(op).getProperties(), op->getRegions())); } else { result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( - operands, op->getAttrDictionary(), {}, op->getRegions())); + operands, op->getDiscardableAttrDictionary(), {}, + op->getRegions())); } } else { result = cast(op).fold(operands); @@ -1920,13 +1921,14 @@ private: if constexpr (hasProperties()) { result = cast(op).fold( typename ConcreteOpT::FoldAdaptor( - operands, op->getAttrDictionary(), + operands, op->getDiscardableAttrDictionary(), cast(op).getProperties(), op->getRegions()), results); } else { result = cast(op).fold( - typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), - {}, op->getRegions()), + typename ConcreteOpT::FoldAdaptor( + operands, op->getDiscardableAttrDictionary(), {}, + op->getRegions()), results); } } else { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 020c8ce9..f242eea 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -520,7 +520,10 @@ public: } void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), + auto sourceOp = cast(op); + rewrite(sourceOp, + OpAdaptor(operands, op->getDiscardableAttrDictionary(), + sourceOp.getProperties()), rewriter); } LogicalResult @@ -529,11 +532,13 @@ public: auto sourceOp = cast(op); if constexpr (SourceOp::hasProperties()) return matchAndRewrite(sourceOp, - OpAdaptor(operands, op->getAttrDictionary(), + OpAdaptor(operands, + op->getDiscardableAttrDictionary(), sourceOp.getProperties()), rewriter); return matchAndRewrite( - sourceOp, OpAdaptor(operands, op->getAttrDictionary()), rewriter); + sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 17b9b74..47c2cdb 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -132,8 +132,11 @@ struct ReallocOpLoweringBase : public AllocationOpLLVMLowering { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), - OpAdaptor(operands, op->getAttrDictionary()), + auto reallocOp = cast(op); + return matchAndRewrite(reallocOp, + OpAdaptor(operands, + op->getDiscardableAttrDictionary(), + reallocOp.getProperties()), rewriter); } diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp index 1fbe66f..40903f1 100644 --- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp @@ -111,10 +111,10 @@ private: resultTypes.reserve(1 + op->getNumResults()); copy(op->getResultTypes(), std::back_inserter(resultTypes)); resultTypes.push_back(tokenType); - auto *newOp = Operation::create(op->getLoc(), op->getName(), resultTypes, - op->getOperands(), op->getAttrDictionary(), - op->getPropertiesStorage(), - op->getSuccessors(), op->getNumRegions()); + auto *newOp = Operation::create( + op->getLoc(), op->getName(), resultTypes, op->getOperands(), + op->getDiscardableAttrDictionary(), op->getPropertiesStorage(), + op->getSuccessors(), op->getNumRegions()); // Clone regions into new op. IRMapping mapping; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 3aa1c3f..36e967d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -393,7 +393,9 @@ private: } bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { - return lhs->getAttrDictionary() == rhs->getAttrDictionary(); + return lhs->getDiscardableAttrDictionary() == + rhs->getDiscardableAttrDictionary() && + lhs->hashProperties() == rhs->hashProperties(); } // Returns a source value for the given block. diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 1a056a0..2430254 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -920,7 +920,8 @@ LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes( DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { Builder b(context); - auto shape = attributes.getAs("shape"); + Properties *prop = properties.as(); + DenseIntElementsAttr shape = prop->shape; if (!shape) return emitOptionalError(location, "missing shape attribute"); inferredReturnTypes.assign({RankedTensorType::get( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 1b063e7..1040d4c 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -383,7 +383,8 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); - IntegerAttr axis = llvm::cast(attributes.get("axis")); + auto *prop = properties.as(); + IntegerAttr axis = prop->axis; int32_t axisVal = axis.getValue().getSExtValue(); if (!inputShape.hasRank()) { @@ -446,8 +447,8 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents( OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. - int32_t axis = - llvm::cast(attributes.get("axis")).getValue().getSExtValue(); + auto *prop = properties.as(); + int32_t axis = prop->axis.getValue().getSExtValue(); llvm::SmallVector outputShape; bool hasRankedInput = false; for (auto operand : operands) { @@ -985,7 +986,7 @@ static LogicalResult ReduceInferReturnTypes( Type inputType = \ operands.getType()[0].cast().getElementType(); \ return ReduceInferReturnTypes(operands.getShape(0), inputType, \ - attributes.get("axis").cast(), \ + properties.as()->axis, \ inferredReturnShapes); \ } \ COMPATIBLE_RETURN_TYPES(OP) @@ -1062,6 +1063,7 @@ NARY_SHAPE_INFER(tosa::SigmoidOp) static LogicalResult poolingInferReturnTypes( const ValueShapeRange &operands, DictionaryAttr attributes, + ArrayRef kernel, ArrayRef stride, ArrayRef pad, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape = operands.getShape(0); llvm::SmallVector outputShape; @@ -1080,12 +1082,6 @@ static LogicalResult poolingInferReturnTypes( int64_t height = inputShape.getDimSize(1); int64_t width = inputShape.getDimSize(2); - ArrayRef kernel = - llvm::cast(attributes.get("kernel")); - ArrayRef stride = - llvm::cast(attributes.get("stride")); - ArrayRef pad = llvm::cast(attributes.get("pad")); - if (!ShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; outputShape[1] = padded / stride[0] + 1; @@ -1245,7 +1241,9 @@ LogicalResult AvgPool2dOp::inferReturnTypeComponents( ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); + Properties &prop = *properties.as(); + return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride, + prop.pad, inferredReturnShapes); } LogicalResult MaxPool2dOp::inferReturnTypeComponents( @@ -1253,7 +1251,9 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents( ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); + Properties &prop = *properties.as(); + return poolingInferReturnTypes(operands, attributes, prop.kernel, prop.stride, + prop.pad, inferredReturnShapes); } LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 87563c1..50a556d 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -37,10 +37,10 @@ TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy, SmallVector returnedShapes; if (shapeInterface - .inferReturnTypeComponents(op.getContext(), op.getLoc(), - op->getOperands(), op->getAttrDictionary(), - op->getPropertiesStorage(), - op->getRegions(), returnedShapes) + .inferReturnTypeComponents( + op.getContext(), op.getLoc(), op->getOperands(), + op->getDiscardableAttrDictionary(), op->getPropertiesStorage(), + op->getRegions(), returnedShapes) .failed()) return op; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp index 3e2da9d..65b66d2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp @@ -218,9 +218,10 @@ void propagateShapesInRegion(Region ®ion) { ValueShapeRange range(op.getOperands(), operandShape); if (shapeInterface - .inferReturnTypeComponents( - op.getContext(), op.getLoc(), range, op.getAttrDictionary(), - op.getPropertiesStorage(), op.getRegions(), returnedShapes) + .inferReturnTypeComponents(op.getContext(), op.getLoc(), range, + op.getDiscardableAttrDictionary(), + op.getPropertiesStorage(), + op.getRegions(), returnedShapes) .succeeded()) { for (auto it : llvm::zip(op.getResults(), returnedShapes)) { Value result = std::get<0>(it); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index b8f3601..716239a 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -653,7 +653,7 @@ llvm::hash_code OperationEquivalence::computeHash( // - Attributes // - Result Types llvm::hash_code hash = - llvm::hash_combine(op->getName(), op->getAttrDictionary(), + llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(), op->getResultTypes(), op->hashProperties()); // - Operands @@ -768,11 +768,13 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs, // 1. Compare the operation properties. if (lhs->getName() != rhs->getName() || - lhs->getAttrDictionary() != rhs->getAttrDictionary() || + lhs->getDiscardableAttrDictionary() != + rhs->getDiscardableAttrDictionary() || lhs->getNumRegions() != rhs->getNumRegions() || lhs->getNumSuccessors() != rhs->getNumSuccessors() || lhs->getNumOperands() != rhs->getNumOperands() || - lhs->getNumResults() != rhs->getNumResults()) + lhs->getNumResults() != rhs->getNumResults() || + lhs->hashProperties() != rhs->hashProperties()) return false; if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc()) return false; @@ -876,7 +878,9 @@ OperationFingerPrint::OperationFingerPrint(Operation *topOp) { // - Operation pointer addDataToHash(hasher, op); // - Attributes - addDataToHash(hasher, op->getAttrDictionary()); + addDataToHash(hasher, op->getDiscardableAttrDictionary()); + // - Properties + addDataToHash(hasher, op->hashProperties()); // - Blocks in Regions for (Region ®ion : op->getRegions()) { for (Block &block : region) { diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp index 68e498d..a7f84be 100644 --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -174,7 +174,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) { return op.emitError("null operand found"); /// Verify that all of the attributes are okay. - for (auto attr : op.getAttrs()) { + for (auto attr : op.getDiscardableAttrDictionary()) { // Check for any optional dialect specific attributes. if (auto *dialect = attr.getNameDialect()) if (failed(dialect->verifyOperationAttribute(&op, attr))) diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp index 80ed2cc..aaa1e1b 100644 --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -251,8 +251,8 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { auto retTypeFn = cast(op); auto result = retTypeFn.refineReturnTypes( op->getContext(), op->getLoc(), op->getOperands(), - op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions(), - inferredReturnTypes); + op->getDiscardableAttrDictionary(), op->getPropertiesStorage(), + op->getRegions(), inferredReturnTypes); if (failed(result)) op->emitOpError() << "failed to infer returned types"; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 82ae72a..3a1faea 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -436,7 +436,7 @@ static void invokeCreateWithInferredReturnType(Operation *op) { std::array values = {{fop.getArgument(i), fop.getArgument(j)}}; SmallVector inferredReturnTypes; if (succeeded(OpTy::inferReturnTypes( - context, std::nullopt, values, op->getAttrDictionary(), + context, std::nullopt, values, op->getDiscardableAttrDictionary(), op->getPropertiesStorage(), op->getRegions(), inferredReturnTypes))) { OperationState state(location, OpTy::getOperationName()); diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp index ef35896..03cf5f4 100644 --- a/mlir/test/lib/IR/TestOperationEquals.cpp +++ b/mlir/test/lib/IR/TestOperationEquals.cpp @@ -31,7 +31,7 @@ struct TestOperationEqualPass Operation *first = &module.getBody()->front(); llvm::outs() << first->getName().getStringRef() << " with attr " - << first->getAttrDictionary(); + << first->getDiscardableAttrDictionary(); OperationEquivalence::Flags flags{}; if (!first->hasAttr("strict_loc_check")) flags |= OperationEquivalence::IgnoreLocations; -- 2.7.4