From 75044e9b4f20d025295dbd56284435937cfb4de5 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 15 Feb 2022 09:48:51 -0800 Subject: [PATCH] [mlir] Flipping vector dialect to both prefixed form. Following https://discourse.llvm.org/t/psa-ods-generated-accessors-will-change-to-have-a-get-prefix-update-you-apis/4476 Mostly mechanical, avoiding function name conflicts. Differential Revision: https://reviews.llvm.org/D119607 --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 56 +++++++++------ mlir/include/mlir/Interfaces/VectorInterfaces.td | 4 +- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 83 ++++++++++------------ .../VectorTransferSplitRewritePatterns.cpp | 4 +- .../Dialect/Vector/Transforms/VectorTransforms.cpp | 6 +- 6 files changed, 79 insertions(+), 76 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 009df11..66d4a69 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -22,6 +22,7 @@ def Vector_Dialect : Dialect { let cppNamespace = "::mlir::vector"; let hasConstantMaterializer = 1; let dependentDialects = ["arith::ArithmeticDialect"]; + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } // Base class for Vector dialect ops. @@ -63,6 +64,15 @@ def Vector_CombiningKindAttr : DialectAttr< "::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())"; } +def Vector_AffineMapArrayAttr : TypedArrayAttrBase { + let returnType = [{ ::llvm::SmallVector<::mlir::AffineMap, 4> }]; + let convertFromStorage = [{ + llvm::to_vector<4>($_self.getAsValueRange<::mlir::AffineMapAttr>()); + }]; + let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; +} + // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -75,7 +85,8 @@ def Vector_ContractionOp : ]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc, Variadic>:$masks, - AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types, + Vector_AffineMapArrayAttr:$indexing_maps, + ArrayAttr:$iterator_types, DefaultValuedAttr:$kind)>, Results<(outs AnyType)> { @@ -223,7 +234,6 @@ def Vector_ContractionOp : } Type getResultType() { return getResult().getType(); } ArrayRef getTraitAttrNames(); - SmallVector getIndexingMaps(); static unsigned getAccOperandIndex() { return 2; } // Returns the bounds of each dimension in the iteration space spanned @@ -240,7 +250,7 @@ def Vector_ContractionOp : std::vector> getContractingDimMap(); std::vector> getBatchDimMap(); - static constexpr StringRef getKindAttrName() { return "kind"; } + static constexpr StringRef getKindAttrStrName() { return "kind"; } static CombiningKind getDefaultKind() { return CombiningKind::ADD; @@ -327,8 +337,8 @@ def Vector_MultiDimReductionOp : "CombiningKind":$kind)> ]; let extraClassDeclaration = [{ - static StringRef getKindAttrName() { return "kind"; } - static StringRef getReductionDimsAttrName() { return "reduction_dims"; } + static StringRef getKindAttrStrName() { return "kind"; } + static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; } VectorType getSourceVectorType() { return source().getType().cast(); @@ -474,7 +484,7 @@ def Vector_ShuffleOp : ]; let hasFolder = 1; let extraClassDeclaration = [{ - static StringRef getMaskAttrName() { return "mask"; } + static StringRef getMaskAttrStrName() { return "mask"; } VectorType getV1VectorType() { return v1().getType().cast(); } @@ -561,7 +571,7 @@ def Vector_ExtractOp : OpBuilder<(ins "Value":$source, "ValueRange":$position)> ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrName() { return "position"; } + static StringRef getPositionAttrStrName() { return "position"; } VectorType getVectorType() { return vector().getType().cast(); } @@ -754,7 +764,7 @@ def Vector_InsertOp : OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$position)> ]; let extraClassDeclaration = [{ - static StringRef getPositionAttrName() { return "position"; } + static StringRef getPositionAttrStrName() { return "position"; } Type getSourceType() { return source().getType(); } VectorType getDestVectorType() { return dest().getType().cast(); @@ -873,15 +883,15 @@ def Vector_InsertStridedSliceOp : "ArrayRef":$offsets, "ArrayRef":$strides)> ]; let extraClassDeclaration = [{ - static StringRef getOffsetsAttrName() { return "offsets"; } - static StringRef getStridesAttrName() { return "strides"; } + static StringRef getOffsetsAttrStrName() { return "offsets"; } + static StringRef getStridesAttrStrName() { return "strides"; } VectorType getSourceVectorType() { return source().getType().cast(); } VectorType getDestVectorType() { return dest().getType().cast(); } - bool hasNonUnitStrides() { + bool hasNonUnitStrides() { return llvm::any_of(strides(), [](Attribute attr) { return attr.cast().getInt() != 1; }); @@ -970,7 +980,7 @@ def Vector_OuterProductOp : VectorType getVectorType() { return getResult().getType().cast(); } - static constexpr StringRef getKindAttrName() { + static constexpr StringRef getKindAttrStrName() { return "kind"; } static CombiningKind getDefaultKind() { @@ -1089,11 +1099,11 @@ def Vector_ReshapeOp : void getFixedVectorSizes(SmallVectorImpl &results); - static StringRef getFixedVectorSizesAttrName() { + static StringRef getFixedVectorSizesAttrStrName() { return "fixed_vector_sizes"; } - static StringRef getInputShapeAttrName() { return "input_shape"; } - static StringRef getOutputShapeAttrName() { return "output_shape"; } + static StringRef getInputShapeAttrStrName() { return "input_shape"; } + static StringRef getOutputShapeAttrStrName() { return "output_shape"; } }]; let assemblyFormat = [{ @@ -1140,12 +1150,12 @@ def Vector_ExtractStridedSliceOp : "ArrayRef":$sizes, "ArrayRef":$strides)> ]; let extraClassDeclaration = [{ - static StringRef getOffsetsAttrName() { return "offsets"; } - static StringRef getSizesAttrName() { return "sizes"; } - static StringRef getStridesAttrName() { return "strides"; } + static StringRef getOffsetsAttrStrName() { return "offsets"; } + static StringRef getSizesAttrStrName() { return "sizes"; } + static StringRef getStridesAttrStrName() { return "strides"; } VectorType getVectorType(){ return vector().getType().cast(); } void getOffsets(SmallVectorImpl &results); - bool hasNonUnitStrides() { + bool hasNonUnitStrides() { return llvm::any_of(strides(), [](Attribute attr) { return attr.cast().getInt() != 1; }); @@ -2190,7 +2200,7 @@ def Vector_ConstantMaskOp : }]; let extraClassDeclaration = [{ - static StringRef getMaskDimSizesAttrName() { return "mask_dim_sizes"; } + static StringRef getMaskDimSizesAttrStrName() { return "mask_dim_sizes"; } }]; let assemblyFormat = "$mask_dim_sizes attr-dict `:` type(results)"; let hasVerifier = 1; @@ -2276,7 +2286,7 @@ def Vector_TransposeOp : return result().getType().cast(); } void getTransp(SmallVectorImpl &results); - static StringRef getTranspAttrName() { return "transp"; } + static StringRef getTranspAttrStrName() { return "transp"; } }]; let assemblyFormat = [{ $vector `,` $transp attr-dict `:` type($vector) `to` type($result) @@ -2537,8 +2547,8 @@ def Vector_ScanOp : CArg<"bool", "true">:$inclusive)> ]; let extraClassDeclaration = [{ - static StringRef getKindAttrName() { return "kind"; } - static StringRef getReductionDimAttrName() { return "reduction_dim"; } + static StringRef getKindAttrStrName() { return "kind"; } + static StringRef getReductionDimAttrStrName() { return "reduction_dim"; } VectorType getSourceType() { return source().getType().cast(); } diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td index 68b8886..ee6c638 100644 --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -55,7 +55,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { StaticInterfaceMethod< /*desc=*/"Return the `in_bounds` attribute name.", /*retTy=*/"::mlir::StringRef", - /*methodName=*/"getInBoundsAttrName", + /*methodName=*/"getInBoundsAttrStrName", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ [{ return "in_bounds"; }] @@ -63,7 +63,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { StaticInterfaceMethod< /*desc=*/"Return the `permutation_map` attribute name.", /*retTy=*/"::mlir::StringRef", - /*methodName=*/"getPermutationMapAttrName", + /*methodName=*/"getPermutationMapAttrStrName", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ [{ return "permutation_map"; }] diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index dec9eec..8650d57 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -318,7 +318,7 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write, write.insertSliceOp.destMutable().assign(read.extractSliceOp.source()); } else { newForOp.getResult(initArgNumber) - .replaceAllUsesWith(write.transferWriteOp.getResult(0)); + .replaceAllUsesWith(write.transferWriteOp.getResult()); write.transferWriteOp.sourceMutable().assign( newForOp.getResult(initArgNumber)); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2d504cb..4db1509 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -347,9 +347,9 @@ void vector::MultiDimReductionOp::build(OpBuilder &builder, for (const auto &en : llvm::enumerate(reductionMask)) if (en.value()) reductionDims.push_back(en.index()); - result.addAttribute(getReductionDimsAttrName(), + result.addAttribute(getReductionDimsAttrStrName(), builder.getI64ArrayAttr(reductionDims)); - result.addAttribute(getKindAttrName(), + result.addAttribute(getKindAttrStrName(), CombiningKindAttr::get(kind, builder.getContext())); } @@ -491,10 +491,10 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, ArrayRef iteratorTypes) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); - result.addAttribute(getIndexingMapsAttrName(), + result.addAttribute(::mlir::getIndexingMapsAttrName(), builder.getAffineMapArrayAttr( AffineMap::inferFromExprList(indexingExprs))); - result.addAttribute(getIteratorTypesAttrName(), + result.addAttribute(::mlir::getIteratorTypesAttrName(), builder.getStrArrayAttr(iteratorTypes)); } @@ -512,9 +512,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result, ArrayAttr iteratorTypes, CombiningKind kind) { result.addOperands({lhs, rhs, acc}); result.addTypes(acc.getType()); - result.addAttribute(getIndexingMapsAttrName(), indexingMaps); - result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); - result.addAttribute(ContractionOp::getKindAttrName(), + result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps); + result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes); + result.addAttribute(ContractionOp::getKindAttrStrName(), CombiningKindAttr::get(kind, builder.getContext())); } @@ -543,8 +543,8 @@ ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); result.attributes.assign(dictAttr.getValue().begin(), dictAttr.getValue().end()); - if (!result.attributes.get(ContractionOp::getKindAttrName())) { - result.addAttribute(ContractionOp::getKindAttrName(), + if (!result.attributes.get(ContractionOp::getKindAttrStrName())) { + result.addAttribute(ContractionOp::getKindAttrStrName(), CombiningKindAttr::get(ContractionOp::getDefaultKind(), result.getContext())); } @@ -698,7 +698,7 @@ LogicalResult ContractionOp::verify() { unsigned numIterators = iterator_types().getValue().size(); for (const auto &it : llvm::enumerate(indexing_maps())) { auto index = it.index(); - auto map = it.value().cast().getValue(); + auto map = it.value(); if (map.getNumSymbols() != 0) return emitOpError("expected indexing map ") << index << " to have no symbols"; @@ -759,9 +759,9 @@ LogicalResult ContractionOp::verify() { } ArrayRef ContractionOp::getTraitAttrNames() { - static constexpr StringRef names[3] = {getIndexingMapsAttrName(), - getIteratorTypesAttrName(), - ContractionOp::getKindAttrName()}; + static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(), + ::mlir::getIteratorTypesAttrName(), + ContractionOp::getKindAttrStrName()}; return llvm::makeArrayRef(names); } @@ -817,11 +817,11 @@ void ContractionOp::getIterationBounds( void ContractionOp::getIterationIndexMap( std::vector> &iterationIndexMap) { - unsigned numMaps = indexing_maps().getValue().size(); + unsigned numMaps = indexing_maps().size(); iterationIndexMap.resize(numMaps); for (const auto &it : llvm::enumerate(indexing_maps())) { auto index = it.index(); - auto map = it.value().cast().getValue(); + auto map = it.value(); for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { auto dim = map.getResult(i).cast(); iterationIndexMap[index][dim.getPosition()] = i; @@ -841,13 +841,6 @@ std::vector> ContractionOp::getBatchDimMap() { getParallelIteratorTypeName(), getContext()); } -SmallVector ContractionOp::getIndexingMaps() { - return llvm::to_vector<4>( - llvm::map_range(indexing_maps().getValue(), [](Attribute mapAttr) { - return mapAttr.cast().getValue(); - })); -} - Optional> ContractionOp::getShapeForUnroll() { SmallVector shape; getIterationBounds(shape); @@ -961,7 +954,7 @@ void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, auto positionAttr = getVectorSubscriptAttr(builder, position); result.addTypes(inferExtractOpResultType(source.getType().cast(), positionAttr)); - result.addAttribute(getPositionAttrName(), positionAttr); + result.addAttribute(getPositionAttrStrName(), positionAttr); } // Convenience builder which assumes the values are constant indices. @@ -1053,7 +1046,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(globalPosition)); return success(); } @@ -1295,7 +1288,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { extractOp.setOperand(source); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(extractPos)); return extractOp.getResult(); } @@ -1355,7 +1348,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) { SmallVector newPosition = delinearize(newStrides, position); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(newPosition)); extractOp.setOperand(shapeCastOp.source()); return extractOp.getResult(); @@ -1396,7 +1389,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) { extractOp.vectorMutable().assign(extractStridedSliceOp.vector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); - extractOp->setAttr(ExtractOp::getPositionAttrName(), + extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(extractedPos)); return extractOp.getResult(); } @@ -1453,7 +1446,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { op.vectorMutable().assign(insertOp.source()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(ExtractOp::getPositionAttrName(), + op->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(offsetDiffs)); return op.getResult(); } @@ -1736,7 +1729,7 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, auto shape = llvm::to_vector<4>(v1Type.getShape()); shape[0] = mask.size(); result.addTypes(VectorType::get(shape, v1Type.getElementType())); - result.addAttribute(getMaskAttrName(), maskAttr); + result.addAttribute(getMaskAttrStrName(), maskAttr); } void ShuffleOp::print(OpAsmPrinter &p) { @@ -1784,7 +1777,7 @@ ParseResult ShuffleOp::parse(OpAsmParser &parser, OperationState &result) { VectorType v1Type, v2Type; if (parser.parseOperand(v1) || parser.parseComma() || parser.parseOperand(v2) || - parser.parseAttribute(attr, ShuffleOp::getMaskAttrName(), + parser.parseAttribute(attr, ShuffleOp::getMaskAttrStrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(v1Type) || parser.parseComma() || @@ -1877,7 +1870,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, result.addOperands({source, dest}); auto positionAttr = getVectorSubscriptAttr(builder, position); result.addTypes(dest.getType()); - result.addAttribute(getPositionAttrName(), positionAttr); + result.addAttribute(getPositionAttrStrName(), positionAttr); } // Convenience builder which assumes the values are constant indices. @@ -1995,8 +1988,8 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result, auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); auto stridesAttr = getVectorSubscriptAttr(builder, strides); result.addTypes(dest.getType()); - result.addAttribute(getOffsetsAttrName(), offsetsAttr); - result.addAttribute(getStridesAttrName(), stridesAttr); + result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); + result.addAttribute(getStridesAttrStrName(), stridesAttr); } // TODO: Should be moved to Tablegen Confined attributes. @@ -2172,9 +2165,9 @@ ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { vLHS.getElementType()) : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); - if (!result.attributes.get(OuterProductOp::getKindAttrName())) { + if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { result.attributes.append( - OuterProductOp::getKindAttrName(), + OuterProductOp::getKindAttrStrName(), CombiningKindAttr::get(OuterProductOp::getDefaultKind(), result.getContext())); } @@ -2322,9 +2315,9 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, result.addTypes( inferStridedSliceOpResultType(source.getType().cast(), offsetsAttr, sizesAttr, stridesAttr)); - result.addAttribute(getOffsetsAttrName(), offsetsAttr); - result.addAttribute(getSizesAttrName(), sizesAttr); - result.addAttribute(getStridesAttrName(), stridesAttr); + result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); + result.addAttribute(getSizesAttrStrName(), sizesAttr); + result.addAttribute(getStridesAttrStrName(), stridesAttr); } LogicalResult ExtractStridedSliceOp::verify() { @@ -2412,7 +2405,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { op.setOperand(insertOp.source()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(ExtractStridedSliceOp::getOffsetsAttrName(), + op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), b.getI64ArrayAttr(offsetDiffs)); return success(); } @@ -2765,7 +2758,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { SmallVector elidedAttrs; elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); if (op.permutation_map().isMinorIdentity()) - elidedAttrs.push_back(op.getPermutationMapAttrName()); + elidedAttrs.push_back(op.getPermutationMapAttrStrName()); bool elideInBounds = true; if (auto inBounds = op.in_bounds()) { for (auto attr : *inBounds) { @@ -2776,7 +2769,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { } } if (elideInBounds) - elidedAttrs.push_back(op.getInBoundsAttrName()); + elidedAttrs.push_back(op.getInBoundsAttrStrName()); p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } @@ -2817,7 +2810,7 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { VectorType vectorType = types[1].dyn_cast(); if (!vectorType) return parser.emitError(typesLoc, "requires vector type"); - auto permutationAttrName = TransferReadOp::getPermutationMapAttrName(); + auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName(); Attribute mapAttr = result.attributes.get(permutationAttrName); if (!mapAttr) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); @@ -2963,7 +2956,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { return failure(); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); - op->setAttr(TransferOp::getInBoundsAttrName(), + op->setAttr(TransferOp::getInBoundsAttrStrName(), b.getBoolArrayAttr(newInBounds)); return success(); } @@ -3193,7 +3186,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser, ShapedType shapedType = types[1].dyn_cast(); if (!shapedType || !shapedType.isa()) return parser.emitError(typesLoc, "requires memref or ranked tensor type"); - auto permutationAttrName = TransferWriteOp::getPermutationMapAttrName(); + auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName(); auto attr = result.attributes.get(permutationAttrName); if (!attr) { auto permMap = getTransferMinorIdentityMap(shapedType, vectorType); @@ -4151,7 +4144,7 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, result.addOperands(vector); result.addTypes(VectorType::get(transposedShape, vt.getElementType())); - result.addAttribute(getTranspAttrName(), builder.getI64ArrayAttr(transp)); + result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); } // Eliminates transpose operations, which produce values identical to their diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index f574713..48470f7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -514,7 +514,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( SmallVector bools(xferOp.getTransferRank(), true); auto inBoundsAttr = b.getBoolArrayAttr(bools); if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) { - xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr); + xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); return success(); } @@ -585,7 +585,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( for (unsigned i = 0, e = returnTypes.size(); i != e; ++i) xferReadOp.setOperand(i, fullPartialIfOp.getResult(i)); - xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr); + xferOp->setAttr(xferOp.getInBoundsAttrStrName(), inBoundsAttr); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 226facc..f9413a7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1050,7 +1050,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, bindDims(rew.getContext(), m, n, k); // LHS must be A(m, k) or A(k, m). Value lhs = op.lhs(); - auto lhsMap = op.indexing_maps()[0].cast().getValue(); + auto lhsMap = op.indexing_maps()[0]; if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) lhs = rew.create(loc, lhs, ArrayRef{1, 0}); else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) @@ -1058,7 +1058,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, // RHS must be B(k, n) or B(n, k). Value rhs = op.rhs(); - auto rhsMap = op.indexing_maps()[1].cast().getValue(); + auto rhsMap = op.indexing_maps()[1]; if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) rhs = rew.create(loc, rhs, ArrayRef{1, 0}); else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) @@ -1088,7 +1088,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, mul); // ACC must be C(m, n) or C(n, m). - auto accMap = op.indexing_maps()[2].cast().getValue(); + auto accMap = op.indexing_maps()[2]; if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) mul = rew.create(loc, mul, ArrayRef{1, 0}); else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) -- 2.7.4