From: Jacques Pienaar Date: Mon, 28 Mar 2022 18:24:47 +0000 (-0700) Subject: [mlir] Flip Vector dialect accessors used to prefixed form. X-Git-Tag: upstream/15.0.7~12144 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7c38fd605ba85657a0ecbea75a8e3a68174d3dff;p=platform%2Fupstream%2Fllvm.git [mlir] Flip Vector dialect accessors used to prefixed form. This has been on _Both for a couple of weeks. Flip usages in core with intention to flip flag to _Prefixed in follow up. Needed to add a couple of helper methods in AffineOps and Linalg to facilitate a pure flag flip in follow up as some of these classes are used in templates and so sensitive to Vector dialect changes. Differential Revision: https://reviews.llvm.org/D122151 --- diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index f130143..3880ad8 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -493,6 +493,9 @@ class AffineLoadOpBase traits = []> : } static StringRef getMapAttrName() { return "map"; } + + // TODO: Remove once prefixing is flipped. + operand_range getIndices() { return indices(); } }]; } @@ -856,6 +859,9 @@ class AffineStoreOpBase traits = []> : } static StringRef getMapAttrName() { return "map"; } + + // TODO: Remove once prefixing is flipped. + operand_range getIndices() { return indices(); } }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index a551f40..4395383 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1131,6 +1131,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { LogicalResult reifyResultShapes(OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes); + // TODO: Remove once prefixing is flipped. + ArrayAttr getIteratorTypes() { return iterator_types(); } + //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute. // These are useful when cloning and changing operand types. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 824bfeb..779983f 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -163,7 +163,7 @@ public: StructuredGenerator(OpBuilder &builder, StructuredOpInterface op) : builder(builder), ctx(op.getContext()), loc(op.getLoc()), - iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {} + iterators(op.getIteratorTypes()), maps(op.getIndexingMaps()), op(op) {} bool iters(ArrayRef its) { if (its.size() != iterators.size()) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 69c2c92..005db9a 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -219,18 +219,18 @@ def Vector_ContractionOp : ]; let extraClassDeclaration = [{ VectorType getLhsType() { - return lhs().getType().cast(); + return getLhs().getType().cast(); } VectorType getRhsType() { - return rhs().getType().cast(); + return getRhs().getType().cast(); } - Type getAccType() { return acc().getType(); } + Type getAccType() { return getAcc().getType(); } VectorType getLHSVectorMaskType() { - if (llvm::size(masks()) != 2) return VectorType(); + if (llvm::size(getMasks()) != 2) return VectorType(); return getOperand(3).getType().cast(); } VectorType getRHSVectorMaskType() { - if (llvm::size(masks()) != 2) return VectorType(); + if (llvm::size(getMasks()) != 2) return VectorType(); return getOperand(4).getType().cast(); } Type getResultType() { return getResult().getType(); } @@ -296,7 +296,7 @@ def Vector_ReductionOp : }]; let extraClassDeclaration = [{ VectorType getVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } }]; let builders = [ @@ -347,10 +347,10 @@ def Vector_MultiDimReductionOp : static StringRef getReductionDimsAttrStrName() { return "reduction_dims"; } VectorType getSourceVectorType() { - return source().getType().cast(); + return getSource().getType().cast(); } Type getDestType() { - return dest().getType(); + return getDest().getType(); } bool isReducedDim(int64_t d) { @@ -361,7 +361,7 @@ def Vector_MultiDimReductionOp : SmallVector getReductionMask() { SmallVector res(getSourceVectorType().getRank(), false); - for (auto ia : reduction_dims().getAsRange()) + for (auto ia : getReductionDims().getAsRange()) res[ia.getInt()] = true; return res; } @@ -415,9 +415,9 @@ def Vector_BroadcastOp : ``` }]; let extraClassDeclaration = [{ - Type getSourceType() { return source().getType(); } + Type getSourceType() { return getSource().getType(); } VectorType getVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)"; @@ -472,13 +472,13 @@ def Vector_ShuffleOp : let extraClassDeclaration = [{ static StringRef getMaskAttrStrName() { return "mask"; } VectorType getV1VectorType() { - return v1().getType().cast(); + return getV1().getType().cast(); } VectorType getV2VectorType() { - return v2().getType().cast(); + return getV2().getType().cast(); } VectorType getVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } }]; let assemblyFormat = "operands $mask attr-dict `:` type(operands)"; @@ -526,7 +526,7 @@ def Vector_ExtractElementOp : ]; let extraClassDeclaration = [{ VectorType getVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } }]; let hasVerifier = 1; @@ -560,7 +560,7 @@ def Vector_ExtractOp : let extraClassDeclaration = [{ static StringRef getPositionAttrStrName() { return "position"; } VectorType getVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } static bool isCompatibleReturnTypes(TypeRange l, TypeRange r); }]; @@ -623,7 +623,7 @@ def Vector_ExtractMapOp : "AffineMap":$map)>]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } VectorType getResultType() { return getResult().getType().cast(); @@ -664,7 +664,7 @@ def Vector_FMAOp : }]; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs)"; let extraClassDeclaration = [{ - VectorType getVectorType() { return lhs().getType().cast(); } + VectorType getVectorType() { return getLhs().getType().cast(); } }]; } @@ -707,9 +707,9 @@ def Vector_InsertElementOp : OpBuilder<(ins "Value":$source, "Value":$dest)>, ]; let extraClassDeclaration = [{ - Type getSourceType() { return source().getType(); } + Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { - return dest().getType().cast(); + return getDest().getType().cast(); } }]; let hasVerifier = 1; @@ -747,9 +747,9 @@ def Vector_InsertOp : ]; let extraClassDeclaration = [{ static StringRef getPositionAttrStrName() { return "position"; } - Type getSourceType() { return source().getType(); } + Type getSourceType() { return getSource().getType(); } VectorType getDestVectorType() { - return dest().getType().cast(); + return getDest().getType().cast(); } }]; @@ -809,7 +809,7 @@ def Vector_InsertMapOp : }]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } VectorType getResultType() { return getResult().getType().cast(); @@ -866,13 +866,13 @@ def Vector_InsertStridedSliceOp : static StringRef getOffsetsAttrStrName() { return "offsets"; } static StringRef getStridesAttrStrName() { return "strides"; } VectorType getSourceVectorType() { - return source().getType().cast(); + return getSource().getType().cast(); } VectorType getDestVectorType() { - return dest().getType().cast(); + return getDest().getType().cast(); } bool hasNonUnitStrides() { - return llvm::any_of(strides(), [](Attribute attr) { + return llvm::any_of(getStrides(), [](Attribute attr) { return attr.cast().getInt() != 1; }); } @@ -947,15 +947,15 @@ def Vector_OuterProductOp : ]; let extraClassDeclaration = [{ VectorType getOperandVectorTypeLHS() { - return lhs().getType().cast(); + return getLhs().getType().cast(); } Type getOperandTypeRHS() { - return rhs().getType(); + return getRhs().getType(); } VectorType getOperandVectorTypeACC() { - return (llvm::size(acc()) == 0) + return (llvm::size(getAcc()) == 0) ? VectorType() - : (*acc().begin()).getType().cast(); + : (*getAcc().begin()).getType().cast(); } VectorType getVectorType() { return getResult().getType().cast(); @@ -1065,17 +1065,17 @@ def Vector_ReshapeOp : let extraClassDeclaration = [{ VectorType getInputVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } VectorType getOutputVectorType() { return getResult().getType().cast(); } /// Returns as integer value the number of input shape operands. - int64_t getNumInputShapeSizes() { return input_shape().size(); } + int64_t getNumInputShapeSizes() { return getInputShape().size(); } /// Returns as integer value the number of output shape operands. - int64_t getNumOutputShapeSizes() { return output_shape().size(); } + int64_t getNumOutputShapeSizes() { return getOutputShape().size(); } void getFixedVectorSizes(SmallVectorImpl &results); @@ -1133,10 +1133,10 @@ def Vector_ExtractStridedSliceOp : static StringRef getOffsetsAttrStrName() { return "offsets"; } static StringRef getSizesAttrStrName() { return "sizes"; } static StringRef getStridesAttrStrName() { return "strides"; } - VectorType getVectorType(){ return vector().getType().cast(); } + VectorType getVectorType(){ return getVector().getType().cast(); } void getOffsets(SmallVectorImpl &results); bool hasNonUnitStrides() { - return llvm::any_of(strides(), [](Attribute attr) { + return llvm::any_of(getStrides(), [](Attribute attr) { return attr.cast().getInt() != 1; }); } @@ -1558,11 +1558,11 @@ def Vector_LoadOp : Vector_Op<"load"> { let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getVectorType() { - return result().getType().cast(); + return getResult().getType().cast(); } }]; @@ -1635,11 +1635,11 @@ def Vector_StoreOp : Vector_Op<"store"> { let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getVectorType() { - return valueToStore().getType().cast(); + return getValueToStore().getType().cast(); } }]; @@ -1688,16 +1688,16 @@ def Vector_MaskedLoadOp : }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getMaskVectorType() { - return mask().getType().cast(); + return getMask().getType().cast(); } VectorType getPassThruVectorType() { - return pass_thru().getType().cast(); + return getPassThru().getType().cast(); } VectorType getVectorType() { - return result().getType().cast(); + return getResult().getType().cast(); } }]; let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " @@ -1744,13 +1744,13 @@ def Vector_MaskedStoreOp : }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getMaskVectorType() { - return mask().getType().cast(); + return getMask().getType().cast(); } VectorType getVectorType() { - return valueToStore().getType().cast(); + return getValueToStore().getType().cast(); } }]; let assemblyFormat = @@ -1803,19 +1803,19 @@ def Vector_GatherOp : }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getIndexVectorType() { - return index_vec().getType().cast(); + return getIndexVec().getType().cast(); } VectorType getMaskVectorType() { - return mask().getType().cast(); + return getMask().getType().cast(); } VectorType getPassThruVectorType() { - return pass_thru().getType().cast(); + return getPassThru().getType().cast(); } VectorType getVectorType() { - return result().getType().cast(); + return getResult().getType().cast(); } }]; let assemblyFormat = @@ -1870,16 +1870,16 @@ def Vector_ScatterOp : }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getIndexVectorType() { - return index_vec().getType().cast(); + return getIndexVec().getType().cast(); } VectorType getMaskVectorType() { - return mask().getType().cast(); + return getMask().getType().cast(); } VectorType getVectorType() { - return valueToStore().getType().cast(); + return getValueToStore().getType().cast(); } }]; let assemblyFormat = @@ -1931,16 +1931,16 @@ def Vector_ExpandLoadOp : }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getMaskVectorType() { - return mask().getType().cast(); + return getMask().getType().cast(); } VectorType getPassThruVectorType() { - return pass_thru().getType().cast(); + return getPassThru().getType().cast(); } VectorType getVectorType() { - return result().getType().cast(); + return getResult().getType().cast(); } }]; let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " @@ -1989,13 +1989,13 @@ def Vector_CompressStoreOp : }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { - return base().getType().cast(); + return getBase().getType().cast(); } VectorType getMaskVectorType() { - return mask().getType().cast(); + return getMask().getType().cast(); } VectorType getVectorType() { - return valueToStore().getType().cast(); + return getValueToStore().getType().cast(); } }]; let assemblyFormat = @@ -2045,7 +2045,7 @@ def Vector_ShapeCastOp : }]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return source().getType().cast(); + return getSource().getType().cast(); } VectorType getResultVectorType() { return getResult().getType().cast(); @@ -2086,7 +2086,7 @@ def Vector_BitCastOp : }]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { - return source().getType().cast(); + return getSource().getType().cast(); } VectorType getResultVectorType() { return getResult().getType().cast(); @@ -2129,13 +2129,13 @@ def Vector_TypeCastOp : let extraClassDeclaration = [{ MemRefType getMemRefType() { - return memref().getType().cast(); + return getMemref().getType().cast(); } MemRefType getResultMemRefType() { return getResult().getType().cast(); } // Implement ViewLikeOpInterface. - Value getViewSource() { return memref(); } + Value getViewSource() { return getMemref(); } }]; let assemblyFormat = [{ @@ -2260,10 +2260,10 @@ def Vector_TransposeOp : ]; let extraClassDeclaration = [{ VectorType getVectorType() { - return vector().getType().cast(); + return getVector().getType().cast(); } VectorType getResultType() { - return result().getType().cast(); + return getResult().getType().cast(); } void getTransp(SmallVectorImpl &results); static StringRef getTranspAttrStrName() { return "transp"; } @@ -2303,7 +2303,7 @@ def Vector_PrintOp : }]; let extraClassDeclaration = [{ Type getPrintType() { - return source().getType(); + return getSource().getType(); } }]; let assemblyFormat = "$source attr-dict `:` type($source)"; @@ -2530,16 +2530,16 @@ def Vector_ScanOp : static StringRef getKindAttrStrName() { return "kind"; } static StringRef getReductionDimAttrStrName() { return "reduction_dim"; } VectorType getSourceType() { - return source().getType().cast(); + return getSource().getType().cast(); } VectorType getDestType() { - return dest().getType().cast(); + return getDest().getType().cast(); } VectorType getAccumulatorType() { - return accumulated_value().getType().cast(); + return getAccumulatedValue().getType().cast(); } VectorType getInitialValueType() { - return initial_value().getType().cast(); + return getInitialValue().getType().cast(); } }]; let assemblyFormat = diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td index ee6c638..00e69e2 100644 --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -77,8 +77,8 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*methodBody=*/"", /*defaultImplementation=*/[{ return $_op.isBroadcastDim(dim) - || ($_op.in_bounds() - && $_op.in_bounds()->template cast<::mlir::ArrayAttr>()[dim] + || ($_op.getInBounds() + && $_op.getInBounds()->template cast<::mlir::ArrayAttr>()[dim] .template cast<::mlir::BoolAttr>().getValue()); }] >, @@ -87,7 +87,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*retTy=*/"::mlir::Value", /*methodName=*/"source", /*args=*/(ins), - /*methodBody=*/"return $_op.source();" + /*methodBody=*/"return $_op.getSource();" /*defaultImplementation=*/ >, InterfaceMethod< @@ -95,7 +95,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*retTy=*/"::mlir::Value", /*methodName=*/"vector", /*args=*/(ins), - /*methodBody=*/"return $_op.vector();" + /*methodBody=*/"return $_op.getVector();" /*defaultImplementation=*/ >, InterfaceMethod< @@ -103,7 +103,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*retTy=*/"::mlir::ValueRange", /*methodName=*/"indices", /*args=*/(ins), - /*methodBody=*/"return $_op.indices();" + /*methodBody=*/"return $_op.getIndices();" /*defaultImplementation=*/ >, InterfaceMethod< @@ -111,7 +111,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*retTy=*/"::mlir::AffineMap", /*methodName=*/"permutation_map", /*args=*/(ins), - /*methodBody=*/"return $_op.permutation_map();" + /*methodBody=*/"return $_op.getPermutationMap();" /*defaultImplementation=*/ >, InterfaceMethod< @@ -121,7 +121,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*args=*/(ins "unsigned":$idx), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto expr = $_op.permutation_map().getResult(idx); + auto expr = $_op.getPermutationMap().getResult(idx); return expr.template isa<::mlir::AffineConstantExpr>() && expr.template dyn_cast<::mlir::AffineConstantExpr>().getValue() == 0; }] @@ -146,7 +146,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*retTy=*/"::mlir::Optional<::mlir::ArrayAttr>", /*methodName=*/"in_bounds", /*args=*/(ins), - /*methodBody=*/"return $_op.in_bounds();" + /*methodBody=*/"return $_op.getInBounds();" /*defaultImplementation=*/ >, InterfaceMethod< @@ -156,7 +156,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ - "return $_op.source().getType().template cast<::mlir::ShapedType>();" + "return $_op.getSource().getType().template cast<::mlir::ShapedType>();" >, InterfaceMethod< /*desc=*/"Return the VectorType.", @@ -165,7 +165,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.vector().getType().template dyn_cast<::mlir::VectorType>(); + return $_op.getVector().getType().template dyn_cast<::mlir::VectorType>(); }] >, InterfaceMethod< @@ -175,9 +175,9 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.mask() + return $_op.getMask() ? ::mlir::vector::detail::transferMaskType( - $_op.getVectorType(), $_op.permutation_map()) + $_op.getVectorType(), $_op.getPermutationMap()) : ::mlir::VectorType(); }] >, @@ -189,7 +189,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/ - "return $_op.permutation_map().getNumResults();" + "return $_op.getPermutationMap().getNumResults();" >, InterfaceMethod< /*desc=*/[{ Return the number of leading shaped dimensions that do not diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 196e4f6..9ed1c34 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -32,14 +32,14 @@ using namespace mlir; // Return true if the contract op can be convert to MMA matmul. static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { - if (llvm::size(contract.masks()) != 0) + if (llvm::size(contract.getMasks()) != 0) return false; using MapList = ArrayRef>; auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; AffineExpr m, n, k; bindDims(contract.getContext(), m, n, k); - auto iteratorTypes = contract.iterator_types().getValue(); + auto iteratorTypes = contract.getIteratorTypes().getValue(); if (!(isParallelIterator(iteratorTypes[0]) && isParallelIterator(iteratorTypes[1]) && isReductionIterator(iteratorTypes[2]))) @@ -76,12 +76,12 @@ getMemrefConstantHorizontalStride(ShapedType type) { // Return true if the transfer op can be converted to a MMA matrix load. static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { - if (readOp.mask() || readOp.hasOutOfBoundsDim() || + if (readOp.getMask() || readOp.hasOutOfBoundsDim() || readOp.getVectorType().getRank() != 2) return false; if (!getMemrefConstantHorizontalStride(readOp.getShapedType())) return false; - AffineMap map = readOp.permutation_map(); + AffineMap map = readOp.getPermutationMap(); OpBuilder b(readOp.getContext()); AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1); AffineExpr zero = b.getAffineConstantExpr(0); @@ -99,13 +99,13 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) { if (writeOp.getTransferRank() == 0) return false; - if (writeOp.mask() || writeOp.hasOutOfBoundsDim() || + if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() || writeOp.getVectorType().getRank() != 2) return false; if (!getMemrefConstantHorizontalStride(writeOp.getShapedType())) return false; // TODO: Support transpose once it is added to GPU dialect ops. - if (!writeOp.permutation_map().isMinorIdentity()) + if (!writeOp.getPermutationMap().isMinorIdentity()) return false; return true; } @@ -122,7 +122,7 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) { /// Return true if this is a broadcast from scalar to a 2D vector. static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) { return broadcastOp.getVectorType().getRank() == 2 && - broadcastOp.source().getType().isa(); + broadcastOp.getSource().getType().isa(); } /// Return the MMA elementwise enum associated with `op` if it is supported. @@ -240,7 +240,7 @@ struct PrepareContractToGPUMMA LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc(); + Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc(); // Set up the parallel/reduction structure in right form. using MapList = ArrayRef>; @@ -248,7 +248,7 @@ struct PrepareContractToGPUMMA AffineExpr m, n, k; bindDims(rewriter.getContext(), m, n, k); static constexpr std::array perm = {1, 0}; - auto iteratorTypes = op.iterator_types().getValue(); + auto iteratorTypes = op.getIteratorTypes().getValue(); SmallVector maps = op.getIndexingMaps(); if (!(isParallelIterator(iteratorTypes[0]) && isParallelIterator(iteratorTypes[1]) && @@ -286,7 +286,7 @@ struct PrepareContractToGPUMMA rewriter.replaceOpWithNewOp( op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})), - op.iterator_types()); + op.getIteratorTypes()); return success(); } }; @@ -299,7 +299,8 @@ struct CombineTransferReadOpTranspose final LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { - auto transferReadOp = op.vector().getDefiningOp(); + auto transferReadOp = + op.getVector().getDefiningOp(); if (!transferReadOp) return failure(); @@ -307,7 +308,7 @@ struct CombineTransferReadOpTranspose final if (transferReadOp.getTransferRank() == 0) return failure(); - if (transferReadOp.mask() || transferReadOp.hasOutOfBoundsDim()) + if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim()) return failure(); SmallVector perm; op.getTransp(perm); @@ -316,11 +317,13 @@ struct CombineTransferReadOpTranspose final permU.push_back(unsigned(o)); AffineMap permutationMap = AffineMap::getPermutationMap(permU, op.getContext()); - AffineMap newMap = permutationMap.compose(transferReadOp.permutation_map()); + AffineMap newMap = + permutationMap.compose(transferReadOp.getPermutationMap()); rewriter.replaceOpWithNewOp( - op, op.getType(), transferReadOp.source(), transferReadOp.indices(), - AffineMapAttr::get(newMap), transferReadOp.padding(), - transferReadOp.mask(), transferReadOp.in_boundsAttr()); + op, op.getType(), transferReadOp.getSource(), + transferReadOp.getIndices(), AffineMapAttr::get(newMap), + transferReadOp.getPadding(), transferReadOp.getMask(), + transferReadOp.getInBoundsAttr()); return success(); } }; @@ -337,9 +340,9 @@ static const char *inferFragType(OpTy op) { auto contract = dyn_cast(users); if (!contract) continue; - if (contract.lhs() == op.getResult()) + if (contract.getLhs() == op.getResult()) return "AOp"; - if (contract.rhs() == op.getResult()) + if (contract.getRhs() == op.getResult()) return "BOp"; } return "COp"; @@ -351,7 +354,7 @@ static void convertTransferReadOp(vector::TransferReadOp op, assert(transferReadSupportsMMAMatrixType(op)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); - AffineMap map = op.permutation_map(); + AffineMap map = op.getPermutationMap(); // Handle broadcast by setting the stride to 0. if (map.getResult(0).isa()) { assert(map.getResult(0).cast().getValue() == 0); @@ -364,7 +367,8 @@ static void convertTransferReadOp(vector::TransferReadOp op, op.getVectorType().getElementType(), fragType); OpBuilder b(op); Value load = b.create( - op.getLoc(), type, op.source(), op.indices(), b.getIndexAttr(*stride)); + op.getLoc(), type, op.getSource(), op.getIndices(), + b.getIndexAttr(*stride)); valueMapping[op.getResult()] = load; } @@ -375,18 +379,19 @@ static void convertTransferWriteOp(vector::TransferWriteOp op, getMemrefConstantHorizontalStride(op.getShapedType()); assert(stride); OpBuilder b(op); - Value matrix = valueMapping.find(op.vector())->second; - b.create( - op.getLoc(), matrix, op.source(), op.indices(), b.getIndexAttr(*stride)); + Value matrix = valueMapping.find(op.getVector())->second; + b.create(op.getLoc(), matrix, op.getSource(), + op.getIndices(), + b.getIndexAttr(*stride)); op.erase(); } static void convertContractOp(vector::ContractionOp op, llvm::DenseMap &valueMapping) { OpBuilder b(op); - Value opA = valueMapping.find(op.lhs())->second; - Value opB = valueMapping.find(op.rhs())->second; - Value opC = valueMapping.find(op.acc())->second; + Value opA = valueMapping.find(op.getLhs())->second; + Value opB = valueMapping.find(op.getRhs())->second; + Value opC = valueMapping.find(op.getAcc())->second; Value matmul = b.create(op.getLoc(), opC.getType(), opA, opB, opC); valueMapping[op.getResult()] = matmul; @@ -420,7 +425,7 @@ static void convertBroadcastOp(vector::BroadcastOp op, gpu::MMAMatrixType type = gpu::MMAMatrixType::get( vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType)); auto matrix = b.create(op.getLoc(), type, - op.source()); + op.getSource()); valueMapping[op.getResult()] = matrix; } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 20e5100..3f6b3524 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -155,9 +155,9 @@ public: matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - matmulOp, typeConverter->convertType(matmulOp.res().getType()), - adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), - matmulOp.lhs_columns(), matmulOp.rhs_columns()); + matmulOp, typeConverter->convertType(matmulOp.getRes().getType()), + adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(), + matmulOp.getLhsColumns(), matmulOp.getRhsColumns()); return success(); } }; @@ -173,8 +173,8 @@ public: matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - transOp, typeConverter->convertType(transOp.res().getType()), - adaptor.matrix(), transOp.rows(), transOp.columns()); + transOp, typeConverter->convertType(transOp.getRes().getType()), + adaptor.getMatrix(), transOp.getRows(), transOp.getColumns()); return success(); } }; @@ -194,14 +194,14 @@ static void replaceLoadOrStoreOp(vector::MaskedLoadOp loadOp, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp( - loadOp, vectorTy, ptr, adaptor.mask(), adaptor.pass_thru(), align); + loadOp, vectorTy, ptr, adaptor.getMask(), adaptor.getPassThru(), align); } static void replaceLoadOrStoreOp(vector::StoreOp storeOp, vector::StoreOpAdaptor adaptor, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { - rewriter.replaceOpWithNewOp(storeOp, adaptor.valueToStore(), + rewriter.replaceOpWithNewOp(storeOp, adaptor.getValueToStore(), ptr, align); } @@ -210,7 +210,7 @@ static void replaceLoadOrStoreOp(vector::MaskedStoreOp storeOp, VectorType vectorTy, Value ptr, unsigned align, ConversionPatternRewriter &rewriter) { rewriter.replaceOpWithNewOp( - storeOp, adaptor.valueToStore(), ptr, adaptor.mask(), align); + storeOp, adaptor.getValueToStore(), ptr, adaptor.getMask(), align); } /// Conversion pattern for a vector.load, vector.store, vector.maskedload, and @@ -240,8 +240,8 @@ public: // Resolve address. auto vtype = this->typeConverter->convertType(loadOrStoreOp.getVectorType()) .template cast(); - Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.base(), - adaptor.indices(), rewriter); + Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), + adaptor.getIndices(), rewriter); Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefTy, vtype); replaceLoadOrStoreOp(loadOrStoreOp, adaptor, vtype, ptr, align, rewriter); @@ -269,16 +269,16 @@ public: // Resolve address. Value ptrs; VectorType vType = gather.getVectorType(); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); - if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, - adaptor.index_vec(), memRefType, vType, ptrs))) + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); + if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, + adaptor.getIndexVec(), memRefType, vType, ptrs))) return failure(); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp( - gather, typeConverter->convertType(vType), ptrs, adaptor.mask(), - adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); + gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(), + adaptor.getPassThru(), rewriter.getI32IntegerAttr(align)); return success(); } }; @@ -303,15 +303,15 @@ public: // Resolve address. Value ptrs; VectorType vType = scatter.getVectorType(); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); - if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr, - adaptor.index_vec(), memRefType, vType, ptrs))) + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); + if (failed(getIndexedPtrs(rewriter, loc, adaptor.getBase(), ptr, + adaptor.getIndexVec(), memRefType, vType, ptrs))) return failure(); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp( - scatter, adaptor.valueToStore(), ptrs, adaptor.mask(), + scatter, adaptor.getValueToStore(), ptrs, adaptor.getMask(), rewriter.getI32IntegerAttr(align)); return success(); } @@ -331,11 +331,11 @@ public: // Resolve address. auto vtype = typeConverter->convertType(expand.getVectorType()); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp( - expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru()); + expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru()); return success(); } }; @@ -353,11 +353,11 @@ public: MemRefType memRefType = compress.getMemRefType(); // Resolve address. - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(), - adaptor.indices(), rewriter); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + adaptor.getIndices(), rewriter); rewriter.replaceOpWithNewOp( - compress, adaptor.valueToStore(), ptr, adaptor.mask()); + compress, adaptor.getValueToStore(), ptr, adaptor.getMask()); return success(); } }; @@ -374,8 +374,8 @@ public: LogicalResult matchAndRewrite(vector::ReductionOp reductionOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto kind = reductionOp.kind(); - Type eltType = reductionOp.dest().getType(); + auto kind = reductionOp.getKind(); + Type eltType = reductionOp.getDest().getType(); Type llvmType = typeConverter->convertType(eltType); Value operand = adaptor.getOperands()[0]; if (eltType.isIntOrIndex()) { @@ -468,7 +468,7 @@ public: auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); Type llvmType = typeConverter->convertType(vectorType); - auto maskArrayAttr = shuffleOp.mask(); + auto maskArrayAttr = shuffleOp.getMask(); // Bail if result type cannot be lowered. if (!llvmType) @@ -484,7 +484,7 @@ public: // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { Value llvmShuffleOp = rewriter.create( - loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); + loc, adaptor.getV1(), adaptor.getV2(), maskArrayAttr); rewriter.replaceOp(shuffleOp, llvmShuffleOp); return success(); } @@ -499,10 +499,10 @@ public: int64_t insPos = 0; for (const auto &en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast().getInt(); - Value value = adaptor.v1(); + Value value = adaptor.getV1(); if (extPos >= v1Dim) { extPos -= v1Dim; - value = adaptor.v2(); + value = adaptor.getV2(); } Value extract = extractOne(rewriter, *getTypeConverter(), loc, value, eltType, rank, extPos); @@ -537,12 +537,12 @@ public: loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( - extractEltOp, llvmType, adaptor.vector(), zero); + extractEltOp, llvmType, adaptor.getVector(), zero); return success(); } rewriter.replaceOpWithNewOp( - extractEltOp, llvmType, adaptor.vector(), adaptor.position()); + extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); return success(); } }; @@ -559,7 +559,7 @@ public: auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter->convertType(resultType); - auto positionArrayAttr = extractOp.position(); + auto positionArrayAttr = extractOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -567,21 +567,21 @@ public: // Extract entire vector. Should be handled by folder, but just to be safe. if (positionArrayAttr.empty()) { - rewriter.replaceOp(extractOp, adaptor.vector()); + rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { Value extracted = rewriter.create( - loc, llvmResultType, adaptor.vector(), positionArrayAttr); + loc, llvmResultType, adaptor.getVector(), positionArrayAttr); rewriter.replaceOp(extractOp, extracted); return success(); } // Potential extraction of 1-D vector from array. auto *context = extractOp->getContext(); - Value extracted = adaptor.vector(); + Value extracted = adaptor.getVector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); @@ -628,8 +628,8 @@ public: VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) return failure(); - rewriter.replaceOpWithNewOp(fmaOp, adaptor.lhs(), - adaptor.rhs(), adaptor.acc()); + rewriter.replaceOpWithNewOp( + fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; @@ -656,13 +656,13 @@ public: loc, typeConverter->convertType(idxType), rewriter.getIntegerAttr(idxType, 0)); rewriter.replaceOpWithNewOp( - insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero); + insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); return success(); } rewriter.replaceOpWithNewOp( - insertEltOp, llvmType, adaptor.dest(), adaptor.source(), - adaptor.position()); + insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), + adaptor.getPosition()); return success(); } }; @@ -679,7 +679,7 @@ public: auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter->convertType(destVectorType); - auto positionArrayAttr = insertOp.position(); + auto positionArrayAttr = insertOp.getPosition(); // Bail if result type cannot be lowered. if (!llvmResultType) @@ -688,14 +688,14 @@ public: // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. if (positionArrayAttr.empty()) { - rewriter.replaceOp(insertOp, adaptor.source()); + rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { Value inserted = rewriter.create( - loc, llvmResultType, adaptor.dest(), adaptor.source(), + loc, llvmResultType, adaptor.getDest(), adaptor.getSource(), positionArrayAttr); rewriter.replaceOp(insertOp, inserted); return success(); @@ -703,7 +703,7 @@ public: // Potential extraction of 1-D vector from array. auto *context = insertOp->getContext(); - Value extracted = adaptor.dest(); + Value extracted = adaptor.getDest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast(); auto oneDVectorType = destVectorType; @@ -721,15 +721,15 @@ public: auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, - adaptor.source(), constant); + adaptor.getSource(), constant); // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = ArrayAttr::get(context, positionAttrs.drop_back()); - inserted = rewriter.create(loc, llvmResultType, - adaptor.dest(), inserted, - nMinusOnePositionAttrs); + inserted = rewriter.create( + loc, llvmResultType, adaptor.getDest(), inserted, + nMinusOnePositionAttrs); } rewriter.replaceOp(insertOp, inserted); @@ -780,9 +780,9 @@ public: loc, elemType, rewriter.getZeroAttr(elemType)); Value desc = rewriter.create(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { - Value extrLHS = rewriter.create(loc, op.lhs(), i); - Value extrRHS = rewriter.create(loc, op.rhs(), i); - Value extrACC = rewriter.create(loc, op.acc(), i); + Value extrLHS = rewriter.create(loc, op.getLhs(), i); + Value extrRHS = rewriter.create(loc, op.getRhs(), i); + Value extrACC = rewriter.create(loc, op.getAcc(), i); Value fma = rewriter.create(loc, extrLHS, extrRHS, extrACC); desc = rewriter.create(loc, fma, desc, i); } @@ -1009,7 +1009,7 @@ public: // Unroll vector into elementary print calls. int64_t rank = vectorType ? vectorType.getRank() : 0; Type type = vectorType ? vectorType : eltType; - emitRanks(rewriter, printOp, adaptor.source(), type, printer, rank, + emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, conversion); emitCall(rewriter, printOp->getLoc(), LLVM::lookupOrCreatePrintNewlineFn( @@ -1119,13 +1119,13 @@ struct VectorSplatOpLowering : public ConvertOpToLLVMPattern { // For 0-d vector, we simply do `insertelement`. if (resultType.getRank() == 0) { rewriter.replaceOpWithNewOp( - splatOp, vectorType, undef, adaptor.input(), zero); + splatOp, vectorType, undef, adaptor.getInput(), zero); return success(); } // For 1-d vector, we additionally do a `vectorshuffle`. auto v = rewriter.create( - splatOp.getLoc(), vectorType, undef, adaptor.input(), zero); + splatOp.getLoc(), vectorType, undef, adaptor.getInput(), zero); int64_t width = splatOp.getType().cast().getDimSize(0); SmallVector zeroValues(width, 0); @@ -1170,7 +1170,7 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern { loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create(loc, llvm1DVectorTy, vdesc, - adaptor.input(), zero); + adaptor.getInput(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp index fc906bf..a849246 100644 --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -44,7 +44,7 @@ static LogicalResult replaceTransferOpWithMubuf( Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, Value &glc, Value &slc) { auto adaptor = TransferWriteOpAdaptor(operands, xferOp->getAttrDictionary()); - rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), + rewriter.replaceOpWithNewOp(xferOp, adaptor.getVector(), dwordConfig, vindex, offsetSizeInBytes, glc, slc); return success(); @@ -68,10 +68,10 @@ public: return failure(); if (xferOp.getVectorType().getRank() > 1 || - llvm::size(xferOp.indices()) == 0) + llvm::size(xferOp.getIndices()) == 0) return failure(); - if (!xferOp.permutation_map().isMinorIdentity()) + if (!xferOp.getPermutationMap().isMinorIdentity()) return failure(); // Have it handled in vector->llvm conversion pass. @@ -105,7 +105,7 @@ public: // indices, so no need to calculate offset size in bytes again in // the MUBUF instruction. Value dataPtr = this->getStridedElementPtr( - loc, memRefType, adaptor.source(), adaptor.indices(), rewriter); + loc, memRefType, adaptor.getSource(), adaptor.getIndices(), rewriter); // 1. Create and fill a <4 x i32> dwordConfig with: // 1st two elements holding the address of dataPtr. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 0c79403..0f57c72 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -53,7 +53,7 @@ template static Optional unpackedDim(OpTy xferOp) { // TODO: support 0-d corner case. assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); - auto map = xferOp.permutation_map(); + auto map = xferOp.getPermutationMap(); if (auto expr = map.getResult(0).template dyn_cast()) { return expr.getPosition(); } @@ -69,7 +69,7 @@ template static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { // TODO: support 0-d corner case. assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); - auto map = xferOp.permutation_map(); + auto map = xferOp.getPermutationMap(); return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), b.getContext()); } @@ -86,7 +86,7 @@ static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv, typename OpTy::Adaptor adaptor(xferOp); // Corresponding memref dim of the vector dim that is unpacked. auto dim = unpackedDim(xferOp); - auto prevIndices = adaptor.indices(); + auto prevIndices = adaptor.getIndices(); indices.append(prevIndices.begin(), prevIndices.end()); Location loc = xferOp.getLoc(); @@ -94,7 +94,7 @@ static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv, if (!isBroadcast) { AffineExpr d0, d1; bindDims(xferOp.getContext(), d0, d1); - Value offset = adaptor.indices()[dim.getValue()]; + Value offset = adaptor.getIndices()[dim.getValue()]; indices[dim.getValue()] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); } @@ -118,7 +118,7 @@ static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, /// * The to-be-unpacked dim of xferOp is a broadcast. template static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { - if (!xferOp.mask()) + if (!xferOp.getMask()) return Value(); if (xferOp.getMaskType().getRank() != 1) return Value(); @@ -126,7 +126,7 @@ static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { return Value(); Location loc = xferOp.getLoc(); - return b.create(loc, xferOp.mask(), iv); + return b.create(loc, xferOp.getMask(), iv); } /// Helper function TransferOpConversion and TransferOp1dConversion. @@ -167,10 +167,11 @@ static Value generateInBoundsCheck( Location loc = xferOp.getLoc(); ImplicitLocOpBuilder lb(xferOp.getLoc(), b); if (!xferOp.isDimInBounds(0) && !isBroadcast) { - Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim); + Value memrefDim = + vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim); AffineExpr d0, d1; bindDims(xferOp.getContext(), d0, d1); - Value base = xferOp.indices()[dim.getValue()]; + Value base = xferOp.getIndices()[dim.getValue()]; Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); cond = lb.create(arith::CmpIPredicate::sgt, memrefDim, memrefIdx); @@ -289,11 +290,11 @@ static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { auto bufferType = MemRefType::get({}, xferOp.getVectorType()); result.dataBuffer = b.create(loc, bufferType); - if (xferOp.mask()) { - auto maskType = MemRefType::get({}, xferOp.mask().getType()); + if (xferOp.getMask()) { + auto maskType = MemRefType::get({}, xferOp.getMask().getType()); auto maskBuffer = b.create(loc, maskType); b.setInsertionPoint(xferOp); - b.create(loc, xferOp.mask(), maskBuffer); + b.create(loc, xferOp.getMask(), maskBuffer); result.maskBuffer = b.create(loc, maskBuffer); } @@ -319,8 +320,8 @@ static MemRefType unpackOneDim(MemRefType type) { /// is similar to Strategy::getBuffer. template static Value getMaskBuffer(OpTy xferOp) { - assert(xferOp.mask() && "Expected that transfer op has mask"); - auto loadOp = xferOp.mask().template getDefiningOp(); + assert(xferOp.getMask() && "Expected that transfer op has mask"); + auto loadOp = xferOp.getMask().template getDefiningOp(); assert(loadOp && "Expected transfer op mask produced by LoadOp"); return loadOp.getMemRef(); } @@ -401,15 +402,15 @@ struct Strategy { Location loc = xferOp.getLoc(); auto bufferType = buffer.getType().dyn_cast(); auto vecType = bufferType.getElementType().dyn_cast(); - auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( - loc, vecType, xferOp.source(), xferIndices, - AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(), - Value(), inBoundsAttr); + loc, vecType, xferOp.getSource(), xferIndices, + AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), + xferOp.getPadding(), Value(), inBoundsAttr); maybeApplyPassLabel(b, newXferOp, options.targetRank); - b.create(loc, newXferOp.vector(), buffer, storeIndices); + b.create(loc, newXferOp.getVector(), buffer, storeIndices); return newXferOp; } @@ -425,7 +426,7 @@ struct Strategy { Location loc = xferOp.getLoc(); auto bufferType = buffer.getType().dyn_cast(); auto vecType = bufferType.getElementType().dyn_cast(); - auto vec = b.create(loc, vecType, xferOp.padding()); + auto vec = b.create(loc, vecType, xferOp.getPadding()); b.create(loc, vec, buffer, storeIndices); return Value(); @@ -453,7 +454,7 @@ struct Strategy { /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ... /// ``` static Value getBuffer(TransferWriteOp xferOp) { - auto loadOp = xferOp.vector().getDefiningOp(); + auto loadOp = xferOp.getVector().getDefiningOp(); assert(loadOp && "Expected transfer op vector produced by LoadOp"); return loadOp.getMemRef(); } @@ -461,7 +462,7 @@ struct Strategy { /// Retrieve the indices of the current LoadOp that loads from the buffer. static void getBufferIndices(TransferWriteOp xferOp, SmallVector &indices) { - auto loadOp = xferOp.vector().getDefiningOp(); + auto loadOp = xferOp.getVector().getDefiningOp(); auto prevIndices = memref::LoadOpAdaptor(loadOp).indices(); indices.append(prevIndices.begin(), prevIndices.end()); } @@ -488,8 +489,8 @@ struct Strategy { Location loc = xferOp.getLoc(); auto vec = b.create(loc, buffer, loadIndices); - auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); - auto source = loopState.empty() ? xferOp.source() : loopState[0]; + auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); + auto source = loopState.empty() ? xferOp.getSource() : loopState[0]; Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); auto newXferOp = b.create( loc, type, vec, source, xferIndices, @@ -521,7 +522,7 @@ struct Strategy { /// Return the initial loop state for the generated scf.for loop. static Value initialLoopState(TransferWriteOp xferOp) { - return isTensorOp(xferOp) ? xferOp.source() : Value(); + return isTensorOp(xferOp) ? xferOp.getSource() : Value(); } }; @@ -576,8 +577,8 @@ struct PrepareTransferReadConversion auto buffers = allocBuffers(rewriter, xferOp); auto *newXfer = rewriter.clone(*xferOp.getOperation()); newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); - if (xferOp.mask()) { - dyn_cast(newXfer).maskMutable().assign( + if (xferOp.getMask()) { + dyn_cast(newXfer).getMaskMutable().assign( buffers.maskBuffer); } @@ -624,16 +625,18 @@ struct PrepareTransferWriteConversion Location loc = xferOp.getLoc(); auto buffers = allocBuffers(rewriter, xferOp); - rewriter.create(loc, xferOp.vector(), buffers.dataBuffer); + rewriter.create(loc, xferOp.getVector(), + buffers.dataBuffer); auto loadedVec = rewriter.create(loc, buffers.dataBuffer); rewriter.updateRootInPlace(xferOp, [&]() { - xferOp.vectorMutable().assign(loadedVec); + xferOp.getVectorMutable().assign(loadedVec); xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); }); - if (xferOp.mask()) { - rewriter.updateRootInPlace( - xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); }); + if (xferOp.getMask()) { + rewriter.updateRootInPlace(xferOp, [&]() { + xferOp.getMaskMutable().assign(buffers.maskBuffer); + }); } return success(); @@ -694,7 +697,7 @@ struct TransferOpConversion : public VectorToSCFPattern { // If the xferOp has a mask: Find and cast mask buffer. Value castedMaskBuffer; - if (xferOp.mask()) { + if (xferOp.getMask()) { auto maskBuffer = getMaskBuffer(xferOp); auto maskBufferType = maskBuffer.getType().template dyn_cast(); @@ -741,8 +744,8 @@ struct TransferOpConversion : public VectorToSCFPattern { // the // unpacked dim is not a broadcast, no mask is // needed on the new transfer op. - if (xferOp.mask() && (xferOp.isBroadcastDim(0) || - xferOp.getMaskType().getRank() > 1)) { + if (xferOp.getMask() && (xferOp.isBroadcastDim(0) || + xferOp.getMaskType().getRank() > 1)) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(newXfer); // Insert load before newXfer. @@ -755,8 +758,9 @@ struct TransferOpConversion : public VectorToSCFPattern { auto mask = b.create(loc, castedMaskBuffer, loadIndices); - rewriter.updateRootInPlace( - newXfer, [&]() { newXfer.maskMutable().assign(mask); }); + rewriter.updateRootInPlace(newXfer, [&]() { + newXfer.getMaskMutable().assign(mask); + }); } return loopState.empty() ? Value() : newXfer->getResult(0); @@ -784,13 +788,13 @@ namespace lowering_n_d_unrolled { template static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, int64_t i) { - if (!xferOp.mask()) + if (!xferOp.getMask()) return; if (xferOp.isBroadcastDim(0)) { // To-be-unpacked dimension is a broadcast, which does not have a // corresponding mask dimension. Mask attribute remains unchanged. - newXferOp.maskMutable().assign(xferOp.mask()); + newXferOp.getMaskMutable().assign(xferOp.getMask()); return; } @@ -801,8 +805,8 @@ static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, llvm::SmallVector indices({i}); Location loc = xferOp.getLoc(); - auto newMask = b.create(loc, xferOp.mask(), indices); - newXferOp.maskMutable().assign(newMask); + auto newMask = b.create(loc, xferOp.getMask(), indices); + newXferOp.getMaskMutable().assign(newMask); } // If we end up here: The mask of the old transfer op is 1D and the unpacked @@ -853,10 +857,10 @@ struct UnrollTransferReadConversion Value getResultVector(TransferReadOp xferOp, PatternRewriter &rewriter) const { if (auto insertOp = getInsertOp(xferOp)) - return insertOp.dest(); + return insertOp.getDest(); Location loc = xferOp.getLoc(); return rewriter.create(loc, xferOp.getVectorType(), - xferOp.padding()); + xferOp.getPadding()); } /// If the result of the TransferReadOp has exactly one user, which is a @@ -876,7 +880,7 @@ struct UnrollTransferReadConversion void getInsertionIndices(TransferReadOp xferOp, SmallVector &indices) const { if (auto insertOp = getInsertOp(xferOp)) { - llvm::for_each(insertOp.position(), [&](Attribute attr) { + llvm::for_each(insertOp.getPosition(), [&](Attribute attr) { indices.push_back(attr.dyn_cast().getInt()); }); } @@ -921,11 +925,11 @@ struct UnrollTransferReadConversion getInsertionIndices(xferOp, insertionIndices); insertionIndices.push_back(i); - auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( - loc, newXferVecType, xferOp.source(), xferIndices, + loc, newXferVecType, xferOp.getSource(), xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), - xferOp.padding(), Value(), inBoundsAttr); + xferOp.getPadding(), Value(), inBoundsAttr); maybeAssignMask(b, xferOp, newXferOp, i); return b.create(loc, newXferOp, vec, insertionIndices); @@ -988,13 +992,13 @@ struct UnrollTransferWriteConversion /// Return the vector from which newly generated ExtracOps will extract. Value getDataVector(TransferWriteOp xferOp) const { if (auto extractOp = getExtractOp(xferOp)) - return extractOp.vector(); - return xferOp.vector(); + return extractOp.getVector(); + return xferOp.getVector(); } /// If the input of the given TransferWriteOp is an ExtractOp, return it. vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { - if (auto *op = xferOp.vector().getDefiningOp()) + if (auto *op = xferOp.getVector().getDefiningOp()) return dyn_cast(op); return vector::ExtractOp(); } @@ -1004,7 +1008,7 @@ struct UnrollTransferWriteConversion void getExtractionIndices(TransferWriteOp xferOp, SmallVector &indices) const { if (auto extractOp = getExtractOp(xferOp)) { - llvm::for_each(extractOp.position(), [&](Attribute attr) { + llvm::for_each(extractOp.getPosition(), [&](Attribute attr) { indices.push_back(attr.dyn_cast().getInt()); }); } @@ -1026,7 +1030,7 @@ struct UnrollTransferWriteConversion auto vec = getDataVector(xferOp); auto xferVecType = xferOp.getVectorType(); int64_t dimSize = xferVecType.getShape()[0]; - auto source = xferOp.source(); // memref or tensor to be written to. + auto source = xferOp.getSource(); // memref or tensor to be written to. auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); // Generate fully unrolled loop of transfer ops. @@ -1050,7 +1054,7 @@ struct UnrollTransferWriteConversion auto extracted = b.create(loc, vec, extractionIndices); - auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); + auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); auto newXferOp = b.create( loc, sourceType, extracted, source, xferIndices, AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), @@ -1089,8 +1093,8 @@ template static Optional get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, SmallVector &memrefIndices) { - auto indices = xferOp.indices(); - auto map = xferOp.permutation_map(); + auto indices = xferOp.getIndices(); + auto map = xferOp.getPermutationMap(); assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); memrefIndices.append(indices.begin(), indices.end()); @@ -1132,7 +1136,8 @@ struct Strategy1d { b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), /*inBoundsCase=*/ [&](OpBuilder &b, Location loc) { - Value val = b.create(loc, xferOp.source(), indices); + Value val = + b.create(loc, xferOp.getSource(), indices); return b.create(loc, val, vec, iv); }, /*outOfBoundsCase=*/ @@ -1144,7 +1149,7 @@ struct Strategy1d { // Inititalize vector with padding value. Location loc = xferOp.getLoc(); return b.create(loc, xferOp.getVectorType(), - xferOp.padding()); + xferOp.getPadding()); } }; @@ -1162,8 +1167,8 @@ struct Strategy1d { b, xferOp, iv, dim, /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { auto val = - b.create(loc, xferOp.vector(), iv); - b.create(loc, val, xferOp.source(), indices); + b.create(loc, xferOp.getVector(), iv); + b.create(loc, val, xferOp.getSource(), indices); }); b.create(loc); } @@ -1221,7 +1226,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern { // TODO: support 0-d corner case. if (xferOp.getTransferRank() == 0) return failure(); - auto map = xferOp.permutation_map(); + auto map = xferOp.getPermutationMap(); auto memRefType = xferOp.getShapedType().template dyn_cast(); if (!memRefType) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 4061f57..5bdcc38 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -44,11 +44,11 @@ struct VectorBitcastConvert final if (!dstType) return failure(); - if (dstType == adaptor.source().getType()) - rewriter.replaceOp(bitcastOp, adaptor.source()); + if (dstType == adaptor.getSource().getType()) + rewriter.replaceOp(bitcastOp, adaptor.getSource()); else rewriter.replaceOpWithNewOp(bitcastOp, dstType, - adaptor.source()); + adaptor.getSource()); return success(); } @@ -61,11 +61,11 @@ struct VectorBroadcastConvert final LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (broadcastOp.source().getType().isa() || + if (broadcastOp.getSource().getType().isa() || !spirv::CompositeType::isValid(broadcastOp.getVectorType())) return failure(); SmallVector source(broadcastOp.getVectorType().getNumElements(), - adaptor.source()); + adaptor.getSource()); rewriter.replaceOpWithNewOp( broadcastOp, broadcastOp.getVectorType(), source); return success(); @@ -88,14 +88,14 @@ struct VectorExtractOpConvert final if (!dstType) return failure(); - if (adaptor.vector().getType().isa()) { - rewriter.replaceOp(extractOp, adaptor.vector()); + if (adaptor.getVector().getType().isa()) { + rewriter.replaceOp(extractOp, adaptor.getVector()); return success(); } - int32_t id = getFirstIntValue(extractOp.position()); + int32_t id = getFirstIntValue(extractOp.getPosition()); rewriter.replaceOpWithNewOp( - extractOp, adaptor.vector(), id); + extractOp, adaptor.getVector(), id); return success(); } }; @@ -111,10 +111,9 @@ struct VectorExtractStridedSliceOpConvert final if (!dstType) return failure(); - - uint64_t offset = getFirstIntValue(extractOp.offsets()); - uint64_t size = getFirstIntValue(extractOp.sizes()); - uint64_t stride = getFirstIntValue(extractOp.strides()); + uint64_t offset = getFirstIntValue(extractOp.getOffsets()); + uint64_t size = getFirstIntValue(extractOp.getSizes()); + uint64_t stride = getFirstIntValue(extractOp.getStrides()); if (stride != 1) return failure(); @@ -147,7 +146,8 @@ struct VectorFmaOpConvert final : public OpConversionPattern { if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( - fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); + fmaOp, fmaOp.getType(), adaptor.getLhs(), adaptor.getRhs(), + adaptor.getAcc()); return success(); } }; @@ -162,16 +162,16 @@ struct VectorInsertOpConvert final // Special case for inserting scalar values into size-1 vectors. if (insertOp.getSourceType().isIntOrFloat() && insertOp.getDestVectorType().getNumElements() == 1) { - rewriter.replaceOp(insertOp, adaptor.source()); + rewriter.replaceOp(insertOp, adaptor.getSource()); return success(); } if (insertOp.getSourceType().isa() || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); - int32_t id = getFirstIntValue(insertOp.position()); + int32_t id = getFirstIntValue(insertOp.getPosition()); rewriter.replaceOpWithNewOp( - insertOp, adaptor.source(), adaptor.dest(), id); + insertOp, adaptor.getSource(), adaptor.getDest(), id); return success(); } }; @@ -186,8 +186,8 @@ struct VectorExtractElementOpConvert final if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( - extractElementOp, extractElementOp.getType(), adaptor.vector(), - extractElementOp.position()); + extractElementOp, extractElementOp.getType(), adaptor.getVector(), + extractElementOp.getPosition()); return success(); } }; @@ -202,8 +202,8 @@ struct VectorInsertElementOpConvert final if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); rewriter.replaceOpWithNewOp( - insertElementOp, insertElementOp.getType(), insertElementOp.dest(), - adaptor.source(), insertElementOp.position()); + insertElementOp, insertElementOp.getType(), insertElementOp.getDest(), + adaptor.getSource(), insertElementOp.getPosition()); return success(); } }; @@ -218,10 +218,10 @@ struct VectorInsertStridedSliceOpConvert final Value srcVector = adaptor.getOperands().front(); Value dstVector = adaptor.getOperands().back(); - uint64_t stride = getFirstIntValue(insertOp.strides()); + uint64_t stride = getFirstIntValue(insertOp.getStrides()); if (stride != 1) return failure(); - uint64_t offset = getFirstIntValue(insertOp.offsets()); + uint64_t offset = getFirstIntValue(insertOp.getOffsets()); if (srcVector.getType().isa()) { assert(!dstVector.getType().isa()); @@ -259,7 +259,8 @@ public: VectorType dstVecType = op.getType(); if (!spirv::CompositeType::isValid(dstVecType)) return failure(); - SmallVector source(dstVecType.getNumElements(), adaptor.input()); + SmallVector source(dstVecType.getNumElements(), + adaptor.getInput()); rewriter.replaceOpWithNewOp(op, dstVecType, source); return success(); @@ -281,19 +282,19 @@ struct VectorShuffleOpConvert final auto oldSourceType = shuffleOp.getV1VectorType(); if (oldSourceType.getNumElements() > 1) { SmallVector components = llvm::to_vector<4>( - llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t { + llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t { return attr.cast().getValue().getZExtValue(); })); rewriter.replaceOpWithNewOp( - shuffleOp, newResultType, adaptor.v1(), adaptor.v2(), + shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(), rewriter.getI32ArrayAttr(components)); return success(); } - SmallVector oldOperands = {adaptor.v1(), adaptor.v2()}; + SmallVector oldOperands = {adaptor.getV1(), adaptor.getV2()}; SmallVector newOperands; newOperands.reserve(oldResultType.getNumElements()); - for (const APInt &i : shuffleOp.mask().getAsValueRange()) { + for (const APInt &i : shuffleOp.getMask().getAsValueRange()) { newOperands.push_back(oldOperands[i.getZExtValue()]); } rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 578d956..570359f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -148,7 +148,7 @@ static HoistableRead findMatchingTransferRead(HoistableWrite write, LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser << "\n"); auto read = dyn_cast(maybeTransferReadUser); - if (read && read.indices() == write.transferWriteOp.indices() && + if (read && read.getIndices() == write.transferWriteOp.getIndices() && read.getVectorType() == write.transferWriteOp.getVectorType()) return HoistableRead{read, sliceOp}; } @@ -223,7 +223,7 @@ getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp, Value v = yieldOperand.get(); if (auto write = v.getDefiningOp()) { // Indexing must not depend on `forOp`. - for (Value operand : write.indices()) + for (Value operand : write.getIndices()) if (!forOp.isDefinedOutsideOfLoop(operand)) return HoistableWrite(); @@ -286,7 +286,7 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write, read.extractSliceOp.sourceMutable().assign( forOp.getInitArgs()[initArgNumber]); else - read.transferReadOp.sourceMutable().assign( + read.transferReadOp.getSourceMutable().assign( forOp.getInitArgs()[initArgNumber]); // Hoist write after. @@ -299,12 +299,12 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write, if (write.insertSliceOp) yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest()); else - yieldOp->setOperand(initArgNumber, write.transferWriteOp.source()); + yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource()); // Rewrite `loop` with additional new yields. OpBuilder b(read.transferReadOp); - auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(), - write.transferWriteOp.vector()); + auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.getVector(), + write.transferWriteOp.getVector()); // Transfer write has been hoisted, need to update the vector and tensor // source. Replace the result of the loop to use the new tensor created // outside the loop. @@ -313,17 +313,18 @@ static void hoistReadWrite(HoistableRead read, HoistableWrite write, if (write.insertSliceOp) { newForOp.getResult(initArgNumber) .replaceAllUsesWith(write.insertSliceOp.getResult()); - write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result()); + write.transferWriteOp.getSourceMutable().assign( + read.extractSliceOp.result()); write.insertSliceOp.destMutable().assign(read.extractSliceOp.source()); } else { newForOp.getResult(initArgNumber) .replaceAllUsesWith(write.transferWriteOp.getResult()); - write.transferWriteOp.sourceMutable().assign( + write.transferWriteOp.getSourceMutable().assign( newForOp.getResult(initArgNumber)); } // Always update with the newly yield tensor and vector. - write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back()); + write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back()); } // To hoist transfer op on tensor the logic can be significantly simplified @@ -355,7 +356,7 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) { if (write.insertSliceOp) LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: " << *write.insertSliceOp.getOperation() << "\n"); - if (llvm::any_of(write.transferWriteOp.indices(), + if (llvm::any_of(write.transferWriteOp.getIndices(), [&forOp](Value index) { return !forOp.isDefinedOutsideOfLoop(index); })) @@ -422,7 +423,8 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { vector::TransferWriteOp transferWrite; for (auto *sliceOp : llvm::reverse(forwardSlice)) { auto candidateWrite = dyn_cast(sliceOp); - if (!candidateWrite || candidateWrite.source() != transferRead.source()) + if (!candidateWrite || + candidateWrite.getSource() != transferRead.getSource()) continue; transferWrite = candidateWrite; } @@ -444,7 +446,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { // 2. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. - if (transferRead.indices() != transferWrite.indices() && + if (transferRead.getIndices() != transferWrite.getIndices() && transferRead.getVectorType() == transferWrite.getVectorType()) return WalkResult::advance(); @@ -453,7 +455,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); - for (auto &use : transferRead.source().getUses()) { + for (auto &use : transferRead.getSource().getUses()) { if (!loop->isAncestor(use.getOwner())) continue; if (use.getOwner() == transferRead.getOperation() || @@ -488,12 +490,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { // Rewrite `loop` with new yields by cloning and erase the original loop. OpBuilder b(transferRead); - auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), - transferWrite.vector()); + auto newForOp = cloneWithNewYields(b, loop, transferRead.getVector(), + transferWrite.getVector()); // Transfer write has been hoisted, need to update the written value to // the value yielded by the newForOp. - transferWrite.vector().replaceAllUsesWith( + transferWrite.getVector().replaceAllUsesWith( newForOp.getResults().take_back()[0]); changed = true; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index a7ae791..4660baa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -846,15 +846,15 @@ struct PadOpVectorizationWithTransferReadPattern if (!padValue) return failure(); // Padding value of existing `xferOp` is unused. - if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) + if (xferOp.hasOutOfBoundsDim() || xferOp.getMask()) return failure(); rewriter.updateRootInPlace(xferOp, [&]() { SmallVector inBounds(xferOp.getVectorType().getRank(), false); xferOp->setAttr(xferOp.getInBoundsAttrName(), rewriter.getBoolArrayAttr(inBounds)); - xferOp.sourceMutable().assign(padOp.source()); - xferOp.paddingMutable().assign(padValue); + xferOp.getSourceMutable().assign(padOp.source()); + xferOp.getPaddingMutable().assign(padValue); }); return success(); @@ -929,8 +929,8 @@ struct PadOpVectorizationWithTransferWritePattern SmallVector inBounds(xferOp.getVectorType().getRank(), false); auto newXferOp = rewriter.replaceOpWithNewOp( - xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(), - xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(), + xferOp, padOp.source().getType(), xferOp.getVector(), padOp.source(), + xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getMask(), rewriter.getBoolArrayAttr(inBounds)); rewriter.replaceOp(trimPadding, newXferOp->getResult(0)); @@ -1174,11 +1174,11 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( vector::TransferReadOp xferOp, PatternRewriter &rewriter) const { // TODO: support mask. - if (xferOp.mask()) + if (xferOp.getMask()) return failure(); // Transfer into `view`. - Value viewOrAlloc = xferOp.source(); + Value viewOrAlloc = xferOp.getSource(); if (!viewOrAlloc.getDefiningOp() && !viewOrAlloc.getDefiningOp()) return failure(); @@ -1226,7 +1226,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( } } // Ensure padding matches. - if (maybeFillOp && xferOp.padding() != maybeFillOp.value()) + if (maybeFillOp && xferOp.getPadding() != maybeFillOp.value()) return failure(); if (maybeFillOp) LDBG("with maybeFillOp " << *maybeFillOp); @@ -1239,8 +1239,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( // When forwarding to vector.transfer_read, the attribute must be reset // conservatively. Value res = rewriter.create( - xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(), - xferOp.permutation_mapAttr(), xferOp.padding(), xferOp.mask(), + xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.getIndices(), + xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), // in_bounds is explicitly reset /*inBoundsAttr=*/ArrayAttr()); @@ -1257,11 +1257,11 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const { // TODO: support mask. - if (xferOp.mask()) + if (xferOp.getMask()) return failure(); // Transfer into `viewOrAlloc`. - Value viewOrAlloc = xferOp.source(); + Value viewOrAlloc = xferOp.getSource(); if (!viewOrAlloc.getDefiningOp() && !viewOrAlloc.getDefiningOp()) return failure(); @@ -1297,8 +1297,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( // When forwarding to vector.transfer_write, the attribute must be reset // conservatively. rewriter.create( - xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(), - xferOp.permutation_mapAttr(), xferOp.mask(), + xferOp.getLoc(), xferOp.getVector(), out, xferOp.getIndices(), + xferOp.getPermutationMapAttr(), xferOp.getMask(), // in_bounds is explicitly reset /*inBoundsAttr=*/ArrayAttr()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp index c54cda0..4e0438f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp @@ -96,10 +96,12 @@ static Value getMemRefOperand(LoadOrStoreOpTy op) { return op.memref(); } -static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } +static Value getMemRefOperand(vector::TransferReadOp op) { + return op.getSource(); +} static Value getMemRefOperand(vector::TransferWriteOp op) { - return op.source(); + return op.getSource(); } /// Given the permutation map of the original @@ -175,9 +177,9 @@ void LoadOpOfSubViewFolder::replaceOp( transferReadOp, transferReadOp.getVectorType(), subViewOp.source(), sourceIndices, getPermutationMapAttr(rewriter.getContext(), subViewOp, - transferReadOp.permutation_map()), - transferReadOp.padding(), - /*mask=*/Value(), transferReadOp.in_boundsAttr()); + transferReadOp.getPermutationMap()), + transferReadOp.getPadding(), + /*mask=*/Value(), transferReadOp.getInBoundsAttr()); } template @@ -196,11 +198,11 @@ void StoreOpOfSubViewFolder::replaceOp( if (transferWriteOp.getTransferRank() == 0) return; rewriter.replaceOpWithNewOp( - transferWriteOp, transferWriteOp.vector(), subViewOp.source(), + transferWriteOp, transferWriteOp.getVector(), subViewOp.source(), sourceIndices, getPermutationMapAttr(rewriter.getContext(), subViewOp, - transferWriteOp.permutation_map()), - transferWriteOp.in_boundsAttr()); + transferWriteOp.getPermutationMap()), + transferWriteOp.getInBoundsAttr()); } } // namespace @@ -215,7 +217,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(OpTy loadOp, SmallVector sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, - loadOp.indices(), sourceIndices))) + loadOp.getIndices(), sourceIndices))) return failure(); replaceOp(loadOp, subViewOp, sourceIndices, rewriter); @@ -233,7 +235,7 @@ StoreOpOfSubViewFolder::matchAndRewrite(OpTy storeOp, SmallVector sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, - storeOp.indices(), sourceIndices))) + storeOp.getIndices(), sourceIndices))) return failure(); replaceOp(storeOp, subViewOp, sourceIndices, rewriter); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 9cf1538..cddc034 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -76,7 +76,7 @@ static MaskFormat get1DMaskFormat(Value mask) { // Inspect constant mask index. If the index exceeds the // dimension size, all bits are set. If the index is zero // or less, no bits are set. - ArrayAttr masks = m.mask_dim_sizes(); + ArrayAttr masks = m.getMaskDimSizes(); assert(masks.size() == 1); int64_t i = masks[0].cast().getInt(); int64_t u = m.getType().getDimSize(0); @@ -140,18 +140,18 @@ AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite, vector::TransferReadOp read) { - return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() && - defWrite.indices() == read.indices() && + return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() && + !read.getMask() && defWrite.getIndices() == read.getIndices() && defWrite.getVectorType() == read.getVectorType() && - defWrite.permutation_map() == read.permutation_map(); + defWrite.getPermutationMap() == read.getPermutationMap(); } bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write, vector::TransferWriteOp priorWrite) { - return priorWrite.indices() == write.indices() && - priorWrite.mask() == write.mask() && + return priorWrite.getIndices() == write.getIndices() && + priorWrite.getMask() == write.getMask() && priorWrite.getVectorType() == write.getVectorType() && - priorWrite.permutation_map() == write.permutation_map(); + priorWrite.getPermutationMap() == write.getPermutationMap(); } bool mlir::vector::isDisjointTransferIndices( @@ -348,10 +348,10 @@ LogicalResult MultiDimReductionOp::inferReturnTypes( DictionaryAttr attributes, RegionRange, SmallVectorImpl &inferredReturnTypes) { MultiDimReductionOp::Adaptor op(operands, attributes); - auto vectorType = op.source().getType().cast(); + auto vectorType = op.getSource().getType().cast(); SmallVector targetShape; for (auto it : llvm::enumerate(vectorType.getShape())) - if (!llvm::any_of(op.reduction_dims().getValue(), [&](Attribute attr) { + if (!llvm::any_of(op.getReductionDims().getValue(), [&](Attribute attr) { return attr.cast().getValue() == it.index(); })) targetShape.push_back(it.value()); @@ -367,7 +367,7 @@ LogicalResult MultiDimReductionOp::inferReturnTypes( OpFoldResult MultiDimReductionOp::fold(ArrayRef operands) { // Single parallel dim, this is a noop. if (getSourceVectorType().getRank() == 1 && !isReducedDim(0)) - return source(); + return getSource(); return {}; } @@ -397,17 +397,17 @@ LogicalResult ReductionOp::verify() { return emitOpError("unsupported reduction rank: ") << rank; // Verify supported reduction kind. - Type eltType = dest().getType(); - if (!isSupportedCombiningKind(kind(), eltType)) + Type eltType = getDest().getType(); + if (!isSupportedCombiningKind(getKind(), eltType)) return emitOpError("unsupported reduction type '") - << eltType << "' for kind '" << stringifyCombiningKind(kind()) + << eltType << "' for kind '" << stringifyCombiningKind(getKind()) << "'"; // Verify optional accumulator. - if (acc()) { - if (kind() != CombiningKind::ADD && kind() != CombiningKind::MUL) + if (getAcc()) { + if (getKind() != CombiningKind::ADD && getKind() != CombiningKind::MUL) return emitOpError("no accumulator for reduction kind: ") - << stringifyCombiningKind(kind()); + << stringifyCombiningKind(getKind()); if (!eltType.isa()) return emitOpError("no accumulator for type: ") << eltType; } @@ -439,11 +439,11 @@ ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) { void ReductionOp::print(OpAsmPrinter &p) { p << " "; - kindAttr().print(p); - p << ", " << vector(); - if (acc()) - p << ", " << acc(); - p << " : " << vector().getType() << " into " << dest().getType(); + getKindAttr().print(p); + p << ", " << getVector(); + if (getAcc()) + p << ", " << getAcc(); + p << " : " << getVector().getType() << " into " << getDest().getType(); } Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, @@ -582,13 +582,13 @@ void ContractionOp::print(OpAsmPrinter &p) { attrs.push_back(attr); auto dictAttr = DictionaryAttr::get(getContext(), attrs); - p << " " << dictAttr << " " << lhs() << ", "; - p << rhs() << ", " << acc(); - if (masks().size() == 2) - p << ", " << masks(); + p << " " << dictAttr << " " << getLhs() << ", "; + p << getRhs() << ", " << getAcc(); + if (getMasks().size() == 2) + p << ", " << getMasks(); p.printOptionalAttrDict((*this)->getAttrs(), attrNames); - p << " : " << lhs().getType() << ", " << rhs().getType() << " into " + p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into " << getResultType(); } @@ -696,14 +696,14 @@ LogicalResult ContractionOp::verify() { auto resType = getResultType(); // Verify that an indexing map was specified for each vector operand. - if (indexing_maps().size() != 3) + if (getIndexingMaps().size() != 3) return emitOpError("expected an indexing map for each vector operand"); // Verify that each index map has 'numIterators' inputs, no symbols, and // that the number of map outputs equals the rank of its associated // vector operand. - unsigned numIterators = iterator_types().getValue().size(); - for (const auto &it : llvm::enumerate(indexing_maps())) { + unsigned numIterators = getIteratorTypes().getValue().size(); + for (const auto &it : llvm::enumerate(getIndexingMaps())) { auto index = it.index(); auto map = it.value(); if (map.getNumSymbols() != 0) @@ -759,7 +759,7 @@ LogicalResult ContractionOp::verify() { // Verify supported combining kind. auto vectorType = resType.dyn_cast(); auto elementType = vectorType ? vectorType.getElementType() : resType; - if (!isSupportedCombiningKind(kind(), elementType)) + if (!isSupportedCombiningKind(getKind(), elementType)) return emitOpError("unsupported contraction type"); return success(); @@ -803,7 +803,7 @@ void ContractionOp::getIterationBounds( auto resVectorType = getResultType().dyn_cast(); SmallVector indexingMaps(getIndexingMaps()); SmallVector iterationShape; - for (const auto &it : llvm::enumerate(iterator_types())) { + for (const auto &it : llvm::enumerate(getIteratorTypes())) { // Search lhs/rhs map results for 'targetExpr'. auto targetExpr = getAffineDimExpr(it.index(), getContext()); auto iteratorTypeName = it.value().cast().getValue(); @@ -824,9 +824,9 @@ void ContractionOp::getIterationBounds( void ContractionOp::getIterationIndexMap( std::vector> &iterationIndexMap) { - unsigned numMaps = indexing_maps().size(); + unsigned numMaps = getIndexingMaps().size(); iterationIndexMap.resize(numMaps); - for (const auto &it : llvm::enumerate(indexing_maps())) { + for (const auto &it : llvm::enumerate(getIndexingMaps())) { auto index = it.index(); auto map = it.value(); for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { @@ -838,13 +838,13 @@ void ContractionOp::getIterationIndexMap( std::vector> ContractionOp::getContractingDimMap() { SmallVector indexingMaps(getIndexingMaps()); - return getDimMap(indexingMaps, iterator_types(), + return getDimMap(indexingMaps, getIteratorTypes(), getReductionIteratorTypeName(), getContext()); } std::vector> ContractionOp::getBatchDimMap() { SmallVector indexingMaps(getIndexingMaps()); - return getDimMap(indexingMaps, iterator_types(), + return getDimMap(indexingMaps, getIteratorTypes(), getParallelIteratorTypeName(), getContext()); } @@ -886,11 +886,11 @@ struct CanonicalizeContractAdd : public OpRewritePattern { if (!contractionOp) return vector::ContractionOp(); if (auto maybeZero = dyn_cast_or_null( - contractionOp.acc().getDefiningOp())) { + contractionOp.getAcc().getDefiningOp())) { if (maybeZero.getValue() == - rewriter.getZeroAttr(contractionOp.acc().getType())) { + rewriter.getZeroAttr(contractionOp.getAcc().getType())) { BlockAndValueMapping bvm; - bvm.map(contractionOp.acc(), otherOperand); + bvm.map(contractionOp.getAcc(), otherOperand); auto newContraction = cast(rewriter.clone(*contractionOp, bvm)); rewriter.replaceOp(addOp, newContraction.getResult()); @@ -932,13 +932,13 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, LogicalResult vector::ExtractElementOp::verify() { VectorType vectorType = getVectorType(); if (vectorType.getRank() == 0) { - if (position()) + if (getPosition()) return emitOpError("expected position to be empty with 0-D vector"); return success(); } if (vectorType.getRank() != 1) return emitOpError("unexpected >1 vector rank"); - if (!position()) + if (!getPosition()) return emitOpError("expected position for 1-D vector"); return success(); } @@ -968,11 +968,12 @@ ExtractOp::inferReturnTypes(MLIRContext *, Optional, RegionRange, SmallVectorImpl &inferredReturnTypes) { ExtractOp::Adaptor op(operands, attributes); - auto vectorType = op.vector().getType().cast(); - if (static_cast(op.position().size()) == vectorType.getRank()) { + auto vectorType = op.getVector().getType().cast(); + if (static_cast(op.getPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { - auto n = std::min(op.position().size(), vectorType.getRank() - 1); + auto n = + std::min(op.getPosition().size(), vectorType.getRank() - 1); inferredReturnTypes.push_back(VectorType::get( vectorType.getShape().drop_front(n), vectorType.getElementType())); } @@ -993,7 +994,7 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { } LogicalResult vector::ExtractOp::verify() { - auto positionAttr = position().getValue(); + auto positionAttr = getPosition().getValue(); if (positionAttr.size() > static_cast(getVectorType().getRank())) return emitOpError( "expected position attribute of rank smaller than vector rank"); @@ -1019,19 +1020,19 @@ static SmallVector extractVector(ArrayAttr arrayAttr) { /// Fold the result of chains of ExtractOp in place by simply concatenating the /// positions. static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { - if (!extractOp.vector().getDefiningOp()) + if (!extractOp.getVector().getDefiningOp()) return failure(); SmallVector globalPosition; ExtractOp currentOp = extractOp; - auto extrPos = extractVector(currentOp.position()); + auto extrPos = extractVector(currentOp.getPosition()); globalPosition.append(extrPos.rbegin(), extrPos.rend()); - while (ExtractOp nextOp = currentOp.vector().getDefiningOp()) { + while (ExtractOp nextOp = currentOp.getVector().getDefiningOp()) { currentOp = nextOp; - auto extrPos = extractVector(currentOp.position()); + auto extrPos = extractVector(currentOp.getPosition()); globalPosition.append(extrPos.rbegin(), extrPos.rend()); } - extractOp.setOperand(currentOp.vector()); + extractOp.setOperand(currentOp.getVector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); std::reverse(globalPosition.begin(), globalPosition.end()); @@ -1143,12 +1144,12 @@ private: ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( ExtractOp e) : extractOp(e), vectorRank(extractOp.getVectorType().getRank()), - extractedRank(extractOp.position().size()) { + extractedRank(extractOp.getPosition().size()) { assert(vectorRank >= extractedRank && "extracted pos overflow"); sentinels.reserve(vectorRank - extractedRank); for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i) sentinels.push_back(-(i + 1)); - extractPosition = extractVector(extractOp.position()); + extractPosition = extractVector(extractOp.getPosition()); llvm::append_range(extractPosition, sentinels); } @@ -1157,7 +1158,7 @@ ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState( LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { if (!nextTransposeOp) return failure(); - auto permutation = extractVector(nextTransposeOp.transp()); + auto permutation = extractVector(nextTransposeOp.getTransp()); AffineMap m = inversePermutation( AffineMap::getPermutationMap(permutation, extractOp.getContext())); extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition)); @@ -1168,12 +1169,12 @@ LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( Value &res) { - auto insertedPos = extractVector(nextInsertOp.position()); + auto insertedPos = extractVector(nextInsertOp.getPosition()); if (makeArrayRef(insertedPos) != llvm::makeArrayRef(extractPosition).take_front(extractedRank)) return failure(); // Case 2.a. early-exit fold. - res = nextInsertOp.source(); + res = nextInsertOp.getSource(); // Case 2.b. if internal transposition is present, canFold will be false. return success(); } @@ -1183,7 +1184,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos( /// This method updates the internal state. LogicalResult ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { - auto insertedPos = extractVector(nextInsertOp.position()); + auto insertedPos = extractVector(nextInsertOp.getPosition()); if (!isContainedWithin(insertedPos, extractPosition)) return failure(); // Set leading dims to zero. @@ -1193,7 +1194,7 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { extractPosition.begin() + insertedPos.size()); extractedRank = extractPosition.size() - sentinels.size(); // Case 3.a. early-exit fold (break and delegate to post-while path). - res = nextInsertOp.source(); + res = nextInsertOp.getSource(); // Case 3.b. if internal transposition is present, canFold will be false. return success(); } @@ -1204,28 +1205,28 @@ ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) { Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace( Value source) { // If we can't fold (either internal transposition, or nothing to fold), bail. - bool nothingToFold = (source == extractOp.vector()); + bool nothingToFold = (source == extractOp.getVector()); if (nothingToFold || !canFold()) return Value(); // Otherwise, fold by updating the op inplace and return its result. OpBuilder b(extractOp.getContext()); extractOp->setAttr( - extractOp.positionAttrName(), + extractOp.getPositionAttrName(), b.getI64ArrayAttr( makeArrayRef(extractPosition).take_front(extractedRank))); - extractOp.vectorMutable().assign(source); + extractOp.getVectorMutable().assign(source); return extractOp.getResult(); } /// Iterate over producing insert and transpose ops until we find a fold. Value ExtractFromInsertTransposeChainState::fold() { - Value valueToExtractFrom = extractOp.vector(); + Value valueToExtractFrom = extractOp.getVector(); updateStateForNextIteration(valueToExtractFrom); while (nextInsertOp || nextTransposeOp) { // Case 1. If we hit a transpose, just compose the map and iterate. // Invariant: insert + transpose do not change rank, we can always compose. if (succeeded(handleTransposeOp())) { - valueToExtractFrom = nextTransposeOp.vector(); + valueToExtractFrom = nextTransposeOp.getVector(); updateStateForNextIteration(valueToExtractFrom); continue; } @@ -1242,13 +1243,13 @@ Value ExtractFromInsertTransposeChainState::fold() { // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel // values. This is a more difficult case and we bail. - auto insertedPos = extractVector(nextInsertOp.position()); + auto insertedPos = extractVector(nextInsertOp.getPosition()); if (isContainedWithin(extractPosition, insertedPos) || intersectsWhereNonNegative(extractPosition, insertedPos)) return Value(); // Case 5: No intersection, we forward the extract to insertOp.dest(). - valueToExtractFrom = nextInsertOp.dest(); + valueToExtractFrom = nextInsertOp.getDest(); updateStateForNextIteration(valueToExtractFrom); } // If after all this we can fold, go for it. @@ -1257,7 +1258,7 @@ Value ExtractFromInsertTransposeChainState::fold() { /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. static Value foldExtractFromBroadcast(ExtractOp extractOp) { - Operation *defOp = extractOp.vector().getDefiningOp(); + Operation *defOp = extractOp.getVector().getDefiningOp(); if (!defOp || !isa(defOp)) return Value(); Value source = defOp->getOperand(0); @@ -1269,7 +1270,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { unsigned broadcastSrcRank = getRank(source.getType()); unsigned extractResultRank = getRank(extractOp.getType()); if (extractResultRank < broadcastSrcRank) { - auto extractPos = extractVector(extractOp.position()); + auto extractPos = extractVector(extractOp.getPosition()); unsigned rankDiff = broadcastSrcRank - extractResultRank; extractPos.erase( extractPos.begin(), @@ -1286,7 +1287,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { // Fold extractOp with source coming from ShapeCast op. static Value foldExtractFromShapeCast(ExtractOp extractOp) { - auto shapeCastOp = extractOp.vector().getDefiningOp(); + auto shapeCastOp = extractOp.getVector().getDefiningOp(); if (!shapeCastOp) return Value(); // Get the nth dimension size starting from lowest dimension. @@ -1312,7 +1313,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) { } // Extract the strides associated with the extract op vector source. Then use // this to calculate a linearized position for the extract. - auto extractedPos = extractVector(extractOp.position()); + auto extractedPos = extractVector(extractOp.getPosition()); std::reverse(extractedPos.begin(), extractedPos.end()); SmallVector strides; int64_t stride = 1; @@ -1339,14 +1340,14 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) { OpBuilder b(extractOp.getContext()); extractOp->setAttr(ExtractOp::getPositionAttrStrName(), b.getI64ArrayAttr(newPosition)); - extractOp.setOperand(shapeCastOp.source()); + extractOp.setOperand(shapeCastOp.getSource()); return extractOp.getResult(); } /// Fold an ExtractOp from ExtractStridedSliceOp. static Value foldExtractFromExtractStrided(ExtractOp extractOp) { auto extractStridedSliceOp = - extractOp.vector().getDefiningOp(); + extractOp.getVector().getDefiningOp(); if (!extractStridedSliceOp) return Value(); // Return if 'extractStridedSliceOp' has non-unit strides. @@ -1354,7 +1355,8 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) { return Value(); // Trim offsets for dimensions fully extracted. - auto sliceOffsets = extractVector(extractStridedSliceOp.offsets()); + auto sliceOffsets = + extractVector(extractStridedSliceOp.getOffsets()); while (!sliceOffsets.empty()) { size_t lastOffset = sliceOffsets.size() - 1; if (sliceOffsets.back() != 0 || @@ -1371,11 +1373,11 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) { if (destinationRank > extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size()) return Value(); - auto extractedPos = extractVector(extractOp.position()); + auto extractedPos = extractVector(extractOp.getPosition()); assert(extractedPos.size() >= sliceOffsets.size()); for (size_t i = 0, e = sliceOffsets.size(); i < e; i++) extractedPos[i] = extractedPos[i] + sliceOffsets[i]; - extractOp.vectorMutable().assign(extractStridedSliceOp.vector()); + extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(extractOp.getContext()); extractOp->setAttr(ExtractOp::getPositionAttrStrName(), @@ -1388,16 +1390,16 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { int64_t destinationRank = op.getType().isa() ? op.getType().cast().getRank() : 0; - auto insertOp = op.vector().getDefiningOp(); + auto insertOp = op.getVector().getDefiningOp(); while (insertOp) { int64_t insertRankDiff = insertOp.getDestVectorType().getRank() - insertOp.getSourceVectorType().getRank(); if (destinationRank > insertOp.getSourceVectorType().getRank()) return Value(); - auto insertOffsets = extractVector(insertOp.offsets()); - auto extractOffsets = extractVector(op.position()); + auto insertOffsets = extractVector(insertOp.getOffsets()); + auto extractOffsets = extractVector(op.getPosition()); - if (llvm::any_of(insertOp.strides(), [](Attribute attr) { + if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) { return attr.cast().getInt() != 1; })) return Value(); @@ -1432,7 +1434,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { insertRankDiff)) return Value(); } - op.vectorMutable().assign(insertOp.source()); + op.getVectorMutable().assign(insertOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); op->setAttr(ExtractOp::getPositionAttrStrName(), @@ -1441,14 +1443,14 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp op) { } // If the chunk extracted is disjoint from the chunk inserted, keep // looking in the insert chain. - insertOp = insertOp.dest().getDefiningOp(); + insertOp = insertOp.getDest().getDefiningOp(); } return Value(); } OpFoldResult ExtractOp::fold(ArrayRef) { - if (position().empty()) - return vector(); + if (getPosition().empty()) + return getVector(); if (succeeded(foldExtractOpFromExtractChain(*this))) return getResult(); if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) @@ -1473,7 +1475,7 @@ public: LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { - Operation *defOp = extractOp.vector().getDefiningOp(); + Operation *defOp = extractOp.getVector().getDefiningOp(); if (!defOp || !isa(defOp)) return failure(); Value source = defOp->getOperand(0); @@ -1504,7 +1506,7 @@ public: PatternRewriter &rewriter) const override { // Return if 'extractStridedSliceOp' operand is not defined by a // ConstantOp. - auto constantOp = extractOp.vector().getDefiningOp(); + auto constantOp = extractOp.getVector().getDefiningOp(); if (!constantOp) return failure(); auto dense = constantOp.getValue().dyn_cast(); @@ -1566,18 +1568,18 @@ LogicalResult ExtractMapOp::verify() { if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) numId++; } - if (numId != ids().size()) + if (numId != getIds().size()) return emitOpError("expected number of ids must match the number of " "dimensions distributed"); return success(); } OpFoldResult ExtractMapOp::fold(ArrayRef operands) { - auto insert = vector().getDefiningOp(); - if (insert == nullptr || getType() != insert.vector().getType() || - ids() != insert.ids()) + auto insert = getVector().getDefiningOp(); + if (insert == nullptr || getType() != insert.getVector().getType() || + getIds() != insert.getIds()) return {}; - return insert.vector(); + return insert.getVector(); } void ExtractMapOp::getMultiplicity(SmallVectorImpl &multiplicity) { @@ -1670,7 +1672,7 @@ LogicalResult BroadcastOp::verify() { OpFoldResult BroadcastOp::fold(ArrayRef operands) { if (getSourceType() == getVectorType()) - return source(); + return getSource(); if (!operands[0]) return {}; auto vectorType = getVectorType(); @@ -1689,11 +1691,11 @@ struct BroadcastFolder : public OpRewritePattern { LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { - auto srcBroadcast = broadcastOp.source().getDefiningOp(); + auto srcBroadcast = broadcastOp.getSource().getDefiningOp(); if (!srcBroadcast) return failure(); rewriter.replaceOpWithNewOp( - broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source()); + broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource()); return success(); } }; @@ -1734,7 +1736,7 @@ LogicalResult ShuffleOp::verify() { return emitOpError("dimension mismatch"); } // Verify mask length. - auto maskAttr = mask().getValue(); + auto maskAttr = getMask().getValue(); int64_t maskLength = maskAttr.size(); if (maskLength <= 0) return emitOpError("invalid mask length"); @@ -1756,12 +1758,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, Optional, RegionRange, SmallVectorImpl &inferredReturnTypes) { ShuffleOp::Adaptor op(operands, attributes); - auto v1Type = op.v1().getType().cast(); + auto v1Type = op.getV1().getType().cast(); // Construct resulting type: leading dimension matches mask length, // all trailing dimensions match the operands. SmallVector shape; shape.reserve(v1Type.getRank()); - shape.push_back(std::max(1, op.mask().size())); + shape.push_back(std::max(1, op.getMask().size())); llvm::append_range(shape, v1Type.getShape().drop_front()); inferredReturnTypes.push_back( VectorType::get(shape, v1Type.getElementType())); @@ -1783,7 +1785,7 @@ OpFoldResult vector::ShuffleOp::fold(ArrayRef operands) { SmallVector results; auto lhsElements = lhs.cast().getValues(); auto rhsElements = rhs.cast().getValues(); - for (const auto &index : this->mask().getAsValueRange()) { + for (const auto &index : this->getMask().getAsValueRange()) { int64_t i = index.getZExtValue(); if (i >= lhsSize) { results.push_back(rhsElements[i - lhsSize]); @@ -1807,13 +1809,13 @@ void InsertElementOp::build(OpBuilder &builder, OperationState &result, LogicalResult InsertElementOp::verify() { auto dstVectorType = getDestVectorType(); if (dstVectorType.getRank() == 0) { - if (position()) + if (getPosition()) return emitOpError("expected position to be empty with 0-D vector"); return success(); } if (dstVectorType.getRank() != 1) return emitOpError("unexpected >1 vector rank"); - if (!position()) + if (!getPosition()) return emitOpError("expected position for 1-D vector"); return success(); } @@ -1841,7 +1843,7 @@ void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, } LogicalResult InsertOp::verify() { - auto positionAttr = position().getValue(); + auto positionAttr = getPosition().getValue(); auto destVectorType = getDestVectorType(); if (positionAttr.size() > static_cast(destVectorType.getRank())) return emitOpError( @@ -1883,7 +1885,7 @@ public: srcVecType.getNumElements()) return failure(); rewriter.replaceOpWithNewOp( - insertOp, insertOp.getDestVectorType(), insertOp.source()); + insertOp, insertOp.getDestVectorType(), insertOp.getSource()); return success(); } }; @@ -1899,8 +1901,8 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, // value. This happens when the source and destination vectors have identical // sizes. OpFoldResult vector::InsertOp::fold(ArrayRef operands) { - if (position().empty()) - return source(); + if (getPosition().empty()) + return getSource(); return {}; } @@ -1920,7 +1922,7 @@ LogicalResult InsertMapOp::verify() { if (getResultType().getDimSize(i) != getSourceVectorType().getDimSize(i)) numId++; } - if (numId != ids().size()) + if (numId != getIds().size()) return emitOpError("expected number of ids must match the number of " "dimensions distributed"); return success(); @@ -2037,8 +2039,8 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef values, LogicalResult InsertStridedSliceOp::verify() { auto sourceVectorType = getSourceVectorType(); auto destVectorType = getDestVectorType(); - auto offsets = offsetsAttr(); - auto strides = stridesAttr(); + auto offsets = getOffsetsAttr(); + auto strides = getStridesAttr(); if (offsets.size() != static_cast(destVectorType.getRank())) return emitOpError( "expected offsets of same size as destination vector rank"); @@ -2072,7 +2074,7 @@ LogicalResult InsertStridedSliceOp::verify() { OpFoldResult InsertStridedSliceOp::fold(ArrayRef operands) { if (getSourceVectorType() == getDestVectorType()) - return source(); + return getSource(); return {}; } @@ -2088,12 +2090,12 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result, } void OuterProductOp::print(OpAsmPrinter &p) { - p << " " << lhs() << ", " << rhs(); - if (!acc().empty()) { - p << ", " << acc(); + p << " " << getLhs() << ", " << getRhs(); + if (!getAcc().empty()) { + p << ", " << getAcc(); p.printOptionalAttrDict((*this)->getAttrs()); } - p << " : " << lhs().getType() << ", " << rhs().getType(); + p << " : " << getLhs().getType() << ", " << getRhs().getType(); } ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2163,7 +2165,7 @@ LogicalResult OuterProductOp::verify() { return emitOpError("expected operand #3 of same type as result type"); // Verify supported combining kind. - if (!isSupportedCombiningKind(kind(), vRES.getElementType())) + if (!isSupportedCombiningKind(getKind(), vRES.getElementType())) return emitOpError("unsupported outerproduct type"); return success(); @@ -2214,14 +2216,14 @@ LogicalResult ReshapeOp::verify() { auto isDefByConstant = [](Value operand) { return isa_and_nonnull(operand.getDefiningOp()); }; - if (llvm::all_of(input_shape(), isDefByConstant) && - llvm::all_of(output_shape(), isDefByConstant)) { + if (llvm::all_of(getInputShape(), isDefByConstant) && + llvm::all_of(getOutputShape(), isDefByConstant)) { int64_t numInputElements = 1; - for (auto operand : input_shape()) + for (auto operand : getInputShape()) numInputElements *= cast(operand.getDefiningOp()).value(); int64_t numOutputElements = 1; - for (auto operand : output_shape()) + for (auto operand : getOutputShape()) numOutputElements *= cast(operand.getDefiningOp()).value(); if (numInputElements != numOutputElements) @@ -2231,7 +2233,7 @@ LogicalResult ReshapeOp::verify() { } void ReshapeOp::getFixedVectorSizes(SmallVectorImpl &results) { - populateFromInt64AttrArray(fixed_vector_sizes(), results); + populateFromInt64AttrArray(getFixedVectorSizes(), results); } //===----------------------------------------------------------------------===// @@ -2274,9 +2276,9 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, LogicalResult ExtractStridedSliceOp::verify() { auto type = getVectorType(); - auto offsets = offsetsAttr(); - auto sizes = sizesAttr(); - auto strides = stridesAttr(); + auto offsets = getOffsetsAttr(); + auto sizes = getSizesAttr(); + auto strides = getStridesAttr(); if (offsets.size() != sizes.size() || offsets.size() != strides.size()) return emitOpError("expected offsets, sizes and strides attributes of same size"); @@ -2316,16 +2318,16 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { auto getElement = [](ArrayAttr array, int idx) { return array[idx].cast().getInt(); }; - ArrayAttr extractOffsets = op.offsets(); - ArrayAttr extractStrides = op.strides(); - ArrayAttr extractSizes = op.sizes(); - auto insertOp = op.vector().getDefiningOp(); + ArrayAttr extractOffsets = op.getOffsets(); + ArrayAttr extractStrides = op.getStrides(); + ArrayAttr extractSizes = op.getSizes(); + auto insertOp = op.getVector().getDefiningOp(); while (insertOp) { if (op.getVectorType().getRank() != insertOp.getSourceVectorType().getRank()) return failure(); - ArrayAttr insertOffsets = insertOp.offsets(); - ArrayAttr insertStrides = insertOp.strides(); + ArrayAttr insertOffsets = insertOp.getOffsets(); + ArrayAttr insertStrides = insertOp.getStrides(); // If the rank of extract is greater than the rank of insert, we are likely // extracting a partial chunk of the vector inserted. if (extractOffsets.size() > insertOffsets.size()) @@ -2354,7 +2356,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { } // The extract element chunk is a subset of the insert element. if (!disjoint && !patialoverlap) { - op.setOperand(insertOp.source()); + op.setOperand(insertOp.getSource()); // OpBuilder is only used as a helper to build an I64ArrayAttr. OpBuilder b(op.getContext()); op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), @@ -2364,7 +2366,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { // If the chunk extracted is disjoint from the chunk inserted, keep looking // in the insert chain. if (disjoint) - insertOp = insertOp.dest().getDefiningOp(); + insertOp = insertOp.getDest().getDefiningOp(); else { // The extracted vector partially overlap the inserted vector, we cannot // fold. @@ -2376,14 +2378,14 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { OpFoldResult ExtractStridedSliceOp::fold(ArrayRef operands) { if (getVectorType() == getResult().getType()) - return vector(); + return getVector(); if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); return {}; } void ExtractStridedSliceOp::getOffsets(SmallVectorImpl &results) { - populateFromInt64AttrArray(offsets(), results); + populateFromInt64AttrArray(getOffsets(), results); } namespace { @@ -2399,7 +2401,7 @@ public: PatternRewriter &rewriter) const override { // Return if 'extractStridedSliceOp' operand is not defined by a // ConstantMaskOp. - auto *defOp = extractStridedSliceOp.vector().getDefiningOp(); + auto *defOp = extractStridedSliceOp.getVector().getDefiningOp(); auto constantMaskOp = dyn_cast_or_null(defOp); if (!constantMaskOp) return failure(); @@ -2408,12 +2410,13 @@ public: return failure(); // Gather constant mask dimension sizes. SmallVector maskDimSizes; - populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); + populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes); // Gather strided slice offsets and sizes. SmallVector sliceOffsets; - populateFromInt64AttrArray(extractStridedSliceOp.offsets(), sliceOffsets); + populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(), + sliceOffsets); SmallVector sliceSizes; - populateFromInt64AttrArray(extractStridedSliceOp.sizes(), sliceSizes); + populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes); // Compute slice of vector mask region. SmallVector sliceMaskDimSizes; @@ -2452,7 +2455,7 @@ public: // Return if 'extractStridedSliceOp' operand is not defined by a // ConstantOp. auto constantOp = - extractStridedSliceOp.vector().getDefiningOp(); + extractStridedSliceOp.getVector().getDefiningOp(); if (!constantOp) return failure(); auto dense = constantOp.getValue().dyn_cast(); @@ -2475,10 +2478,10 @@ public: LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto broadcast = op.vector().getDefiningOp(); + auto broadcast = op.getVector().getDefiningOp(); if (!broadcast) return failure(); - auto srcVecType = broadcast.source().getType().dyn_cast(); + auto srcVecType = broadcast.getSource().getType().dyn_cast(); unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0; auto dstVecType = op.getType().cast(); unsigned dstRank = dstVecType.getRank(); @@ -2493,15 +2496,15 @@ public: break; } } - Value source = broadcast.source(); + Value source = broadcast.getSource(); if (!lowerDimMatch) { // The inner dimensions don't match, it means we need to extract from the // source of the orignal broadcast and then broadcast the extracted value. source = rewriter.create( op->getLoc(), source, - getI64SubArray(op.offsets(), /* dropFront=*/rankDiff), - getI64SubArray(op.sizes(), /* dropFront=*/rankDiff), - getI64SubArray(op.strides(), /* dropFront=*/rankDiff)); + getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff), + getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff), + getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff)); } rewriter.replaceOpWithNewOp(op, op.getType(), source); return success(); @@ -2515,10 +2518,10 @@ public: LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto splat = op.vector().getDefiningOp(); + auto splat = op.getVector().getDefiningOp(); if (!splat) return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), splat.input()); + rewriter.replaceOpWithNewOp(op, op.getType(), splat.getInput()); return success(); } }; @@ -2726,9 +2729,9 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { } void TransferReadOp::print(OpAsmPrinter &p) { - p << " " << source() << "[" << indices() << "], " << padding(); - if (mask()) - p << ", " << mask(); + p << " " << getSource() << "[" << getIndices() << "], " << getPadding(); + if (getMask()) + p << ", " << getMask(); printTransferAttrs(p, *this); p << " : " << getShapedType() << ", " << getVectorType(); } @@ -2798,16 +2801,16 @@ LogicalResult TransferReadOp::verify() { ShapedType shapedType = getShapedType(); VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); - auto paddingType = padding().getType(); - auto permutationMap = permutation_map(); + auto paddingType = getPadding().getType(); + auto permutationMap = getPermutationMap(); auto sourceElementType = shapedType.getElementType(); - if (static_cast(indices().size()) != shapedType.getRank()) + if (static_cast(getIndices().size()) != shapedType.getRank()) return emitOpError("requires ") << shapedType.getRank() << " indices"; if (failed(verifyTransferOp(cast(getOperation()), shapedType, vectorType, maskType, permutationMap, - in_bounds() ? *in_bounds() : ArrayAttr()))) + getInBounds() ? *getInBounds() : ArrayAttr()))) return failure(); if (auto sourceVectorElementType = sourceElementType.dyn_cast()) { @@ -2867,7 +2870,7 @@ static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` if (op.getShapedType().isDynamicDim(indicesIdx)) return false; - Value index = op.indices()[indicesIdx]; + Value index = op.getIndices()[indicesIdx]; auto cstOp = index.getDefiningOp(); if (!cstOp) return false; @@ -2884,7 +2887,7 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { // TODO: Be less conservative. if (op.getTransferRank() == 0) return failure(); - AffineMap permutationMap = op.permutation_map(); + AffineMap permutationMap = op.getPermutationMap(); bool changed = false; SmallVector newInBounds; newInBounds.reserve(op.getTransferRank()); @@ -2926,15 +2929,15 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) { static Value foldRAW(TransferReadOp readOp) { if (!readOp.getShapedType().isa()) return {}; - auto defWrite = readOp.source().getDefiningOp(); + auto defWrite = readOp.getSource().getDefiningOp(); while (defWrite) { if (checkSameValueRAW(defWrite, readOp)) - return defWrite.vector(); + return defWrite.getVector(); if (!isDisjointTransferIndices( cast(defWrite.getOperation()), cast(readOp.getOperation()))) break; - defWrite = defWrite.source().getDefiningOp(); + defWrite = defWrite.getSource().getDefiningOp(); } return {}; } @@ -2960,7 +2963,7 @@ void TransferReadOp::getEffects( SmallVectorImpl> &effects) { if (getShapedType().isa()) - effects.emplace_back(MemoryEffects::Read::get(), source(), + effects.emplace_back(MemoryEffects::Read::get(), getSource(), SideEffects::DefaultResource::get()); } @@ -2992,11 +2995,11 @@ public: return failure(); if (xferOp.hasOutOfBoundsDim()) return failure(); - if (!xferOp.permutation_map().isIdentity()) + if (!xferOp.getPermutationMap().isIdentity()) return failure(); - if (xferOp.mask()) + if (xferOp.getMask()) return failure(); - auto extractOp = xferOp.source().getDefiningOp(); + auto extractOp = xferOp.getSource().getDefiningOp(); if (!extractOp) return failure(); if (!extractOp.hasUnitStride()) @@ -3039,7 +3042,7 @@ public: newIndices.push_back(getValueOrCreateConstantIndexOp( rewriter, extractOp.getLoc(), offset)); } - for (const auto &it : llvm::enumerate(xferOp.indices())) { + for (const auto &it : llvm::enumerate(xferOp.getIndices())) { OpFoldResult offset = extractOp.getMixedOffsets()[it.index() + rankReduced]; newIndices.push_back(rewriter.create( @@ -3050,7 +3053,7 @@ public: SmallVector inBounds(xferOp.getTransferRank(), true); rewriter.replaceOpWithNewOp( xferOp, xferOp.getVectorType(), extractOp.source(), newIndices, - xferOp.padding(), ArrayRef{inBounds}); + xferOp.getPadding(), ArrayRef{inBounds}); return success(); } @@ -3165,9 +3168,9 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser, } void TransferWriteOp::print(OpAsmPrinter &p) { - p << " " << vector() << ", " << source() << "[" << indices() << "]"; - if (mask()) - p << ", " << mask(); + p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]"; + if (getMask()) + p << ", " << getMask(); printTransferAttrs(p, *this); p << " : " << getVectorType() << ", " << getShapedType(); } @@ -3177,9 +3180,9 @@ LogicalResult TransferWriteOp::verify() { ShapedType shapedType = getShapedType(); VectorType vectorType = getVectorType(); VectorType maskType = getMaskType(); - auto permutationMap = permutation_map(); + auto permutationMap = getPermutationMap(); - if (llvm::size(indices()) != shapedType.getRank()) + if (llvm::size(getIndices()) != shapedType.getRank()) return emitOpError("requires ") << shapedType.getRank() << " indices"; // We do not allow broadcast dimensions on TransferWriteOps for the moment, @@ -3189,7 +3192,7 @@ LogicalResult TransferWriteOp::verify() { if (failed(verifyTransferOp(cast(getOperation()), shapedType, vectorType, maskType, permutationMap, - in_bounds() ? *in_bounds() : ArrayAttr()))) + getInBounds() ? *getInBounds() : ArrayAttr()))) return failure(); return verifyPermutationMap(permutationMap, @@ -3219,20 +3222,21 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write, // TODO: support 0-d corner case. if (write.getTransferRank() == 0) return failure(); - auto rankedTensorType = write.source().getType().dyn_cast(); + auto rankedTensorType = + write.getSource().getType().dyn_cast(); // If not operating on tensors, bail. if (!rankedTensorType) return failure(); // If no read, bail. - auto read = write.vector().getDefiningOp(); + auto read = write.getVector().getDefiningOp(); if (!read) return failure(); // TODO: support 0-d corner case. if (read.getTransferRank() == 0) return failure(); // For now, only accept minor identity. Future: composition is minor identity. - if (!read.permutation_map().isMinorIdentity() || - !write.permutation_map().isMinorIdentity()) + if (!read.getPermutationMap().isMinorIdentity() || + !write.getPermutationMap().isMinorIdentity()) return failure(); // Bail on mismatching ranks. if (read.getTransferRank() != write.getTransferRank()) @@ -3241,7 +3245,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write, if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim()) return failure(); // Tensor types must be the same. - if (read.source().getType() != rankedTensorType) + if (read.getSource().getType() != rankedTensorType) return failure(); // Vector types must be the same. if (read.getVectorType() != write.getVectorType()) @@ -3254,20 +3258,21 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write, auto cstOp = v.getDefiningOp(); return !cstOp || cstOp.value() != 0; }; - if (llvm::any_of(read.indices(), isNotConstantZero) || - llvm::any_of(write.indices(), isNotConstantZero)) + if (llvm::any_of(read.getIndices(), isNotConstantZero) || + llvm::any_of(write.getIndices(), isNotConstantZero)) return failure(); // Success. - results.push_back(read.source()); + results.push_back(read.getSource()); return success(); } static bool checkSameValueWAR(vector::TransferReadOp read, vector::TransferWriteOp write) { - return read.source() == write.source() && read.indices() == write.indices() && - read.permutation_map() == write.permutation_map() && - read.getVectorType() == write.getVectorType() && !read.mask() && - !write.mask(); + return read.getSource() == write.getSource() && + read.getIndices() == write.getIndices() && + read.getPermutationMap() == write.getPermutationMap() && + read.getVectorType() == write.getVectorType() && !read.getMask() && + !write.getMask(); } /// Fold transfer_write write after read: /// ``` @@ -3285,15 +3290,15 @@ static bool checkSameValueWAR(vector::TransferReadOp read, /// ``` static LogicalResult foldWAR(TransferWriteOp write, SmallVectorImpl &results) { - if (!write.source().getType().isa()) + if (!write.getSource().getType().isa()) return failure(); - auto read = write.vector().getDefiningOp(); + auto read = write.getVector().getDefiningOp(); if (!read) return failure(); if (!checkSameValueWAR(read, write)) return failure(); - results.push_back(read.source()); + results.push_back(read.getSource()); return success(); } @@ -3316,7 +3321,7 @@ void TransferWriteOp::getEffects( SmallVectorImpl> &effects) { if (getShapedType().isa()) - effects.emplace_back(MemoryEffects::Write::get(), source(), + effects.emplace_back(MemoryEffects::Write::get(), getSource(), SideEffects::DefaultResource::get()); } @@ -3354,10 +3359,11 @@ public: return failure(); vector::TransferWriteOp writeToModify = writeOp; - auto defWrite = writeOp.source().getDefiningOp(); + auto defWrite = + writeOp.getSource().getDefiningOp(); while (defWrite) { if (checkSameValueWAW(writeOp, defWrite)) { - writeToModify.sourceMutable().assign(defWrite.source()); + writeToModify.getSourceMutable().assign(defWrite.getSource()); return success(); } if (!isDisjointTransferIndices( @@ -3369,7 +3375,7 @@ public: if (!defWrite->hasOneUse()) break; writeToModify = defWrite; - defWrite = defWrite.source().getDefiningOp(); + defWrite = defWrite.getSource().getDefiningOp(); } return failure(); } @@ -3410,7 +3416,7 @@ public: return failure(); if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank()) return failure(); - if (xferOp.mask()) + if (xferOp.getMask()) return failure(); // Fold only if the TransferWriteOp completely overwrites the `source` with // a vector. I.e., the result of the TransferWriteOp is a new tensor whose @@ -3418,7 +3424,7 @@ public: if (!llvm::equal(xferOp.getVectorType().getShape(), xferOp.getShapedType().getShape())) return failure(); - if (!xferOp.permutation_map().isIdentity()) + if (!xferOp.getPermutationMap().isIdentity()) return failure(); // Bail on illegal rank-reduction: we need to check that the rank-reduced @@ -3453,7 +3459,7 @@ public: SmallVector indices = getValueOrCreateConstantIndexOp( rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(xferOp.getTransferRank(), true); - rewriter.replaceOpWithNewOp(insertOp, xferOp.vector(), + rewriter.replaceOpWithNewOp(insertOp, xferOp.getVector(), insertOp.dest(), indices, ArrayRef{inBounds}); return success(); @@ -3494,7 +3500,7 @@ LogicalResult vector::LoadOp::verify() { if (resVecTy.getElementType() != memElemTy) return emitOpError("base and result element types should match"); - if (llvm::size(indices()) != memRefTy.getRank()) + if (llvm::size(getIndices()) != memRefTy.getRank()) return emitOpError("requires ") << memRefTy.getRank() << " indices"; return success(); } @@ -3527,7 +3533,7 @@ LogicalResult vector::StoreOp::verify() { if (valueVecTy.getElementType() != memElemTy) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(indices()) != memRefTy.getRank()) + if (llvm::size(getIndices()) != memRefTy.getRank()) return emitOpError("requires ") << memRefTy.getRank() << " indices"; return success(); } @@ -3549,7 +3555,7 @@ LogicalResult MaskedLoadOp::verify() { if (resVType.getElementType() != memType.getElementType()) return emitOpError("base and result element type should match"); - if (llvm::size(indices()) != memType.getRank()) + if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return emitOpError("expected result dim to match mask dim"); @@ -3564,13 +3570,13 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { - switch (get1DMaskFormat(load.mask())) { + switch (get1DMaskFormat(load.getMask())) { case MaskFormat::AllTrue: - rewriter.replaceOpWithNewOp(load, load.getType(), - load.base(), load.indices()); + rewriter.replaceOpWithNewOp( + load, load.getType(), load.getBase(), load.getIndices()); return success(); case MaskFormat::AllFalse: - rewriter.replaceOp(load, load.pass_thru()); + rewriter.replaceOp(load, load.getPassThru()); return success(); case MaskFormat::Unknown: return failure(); @@ -3602,7 +3608,7 @@ LogicalResult MaskedStoreOp::verify() { if (valueVType.getElementType() != memType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(indices()) != memType.getRank()) + if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) return emitOpError("expected valueToStore dim to match mask dim"); @@ -3615,10 +3621,10 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { - switch (get1DMaskFormat(store.mask())) { + switch (get1DMaskFormat(store.getMask())) { case MaskFormat::AllTrue: rewriter.replaceOpWithNewOp( - store, store.valueToStore(), store.base(), store.indices()); + store, store.getValueToStore(), store.getBase(), store.getIndices()); return success(); case MaskFormat::AllFalse: rewriter.eraseOp(store); @@ -3653,7 +3659,7 @@ LogicalResult GatherOp::verify() { if (resVType.getElementType() != memType.getElementType()) return emitOpError("base and result element type should match"); - if (llvm::size(indices()) != memType.getRank()) + if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != indVType.getDimSize(0)) return emitOpError("expected result dim to match indices dim"); @@ -3670,11 +3676,11 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { - switch (get1DMaskFormat(gather.mask())) { + switch (get1DMaskFormat(gather.getMask())) { case MaskFormat::AllTrue: return failure(); // no unmasked equivalent case MaskFormat::AllFalse: - rewriter.replaceOp(gather, gather.pass_thru()); + rewriter.replaceOp(gather, gather.getPassThru()); return success(); case MaskFormat::Unknown: return failure(); @@ -3701,7 +3707,7 @@ LogicalResult ScatterOp::verify() { if (valueVType.getElementType() != memType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(indices()) != memType.getRank()) + if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != indVType.getDimSize(0)) return emitOpError("expected valueToStore dim to match indices dim"); @@ -3716,7 +3722,7 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { - switch (get1DMaskFormat(scatter.mask())) { + switch (get1DMaskFormat(scatter.getMask())) { case MaskFormat::AllTrue: return failure(); // no unmasked equivalent case MaskFormat::AllFalse: @@ -3747,7 +3753,7 @@ LogicalResult ExpandLoadOp::verify() { if (resVType.getElementType() != memType.getElementType()) return emitOpError("base and result element type should match"); - if (llvm::size(indices()) != memType.getRank()) + if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (resVType.getDimSize(0) != maskVType.getDimSize(0)) return emitOpError("expected result dim to match mask dim"); @@ -3762,13 +3768,13 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { - switch (get1DMaskFormat(expand.mask())) { + switch (get1DMaskFormat(expand.getMask())) { case MaskFormat::AllTrue: rewriter.replaceOpWithNewOp( - expand, expand.getType(), expand.base(), expand.indices()); + expand, expand.getType(), expand.getBase(), expand.getIndices()); return success(); case MaskFormat::AllFalse: - rewriter.replaceOp(expand, expand.pass_thru()); + rewriter.replaceOp(expand, expand.getPassThru()); return success(); case MaskFormat::Unknown: return failure(); @@ -3794,7 +3800,7 @@ LogicalResult CompressStoreOp::verify() { if (valueVType.getElementType() != memType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(indices()) != memType.getRank()) + if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) return emitOpError("expected valueToStore dim to match mask dim"); @@ -3807,11 +3813,11 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { - switch (get1DMaskFormat(compress.mask())) { + switch (get1DMaskFormat(compress.getMask())) { case MaskFormat::AllTrue: rewriter.replaceOpWithNewOp( - compress, compress.valueToStore(), compress.base(), - compress.indices()); + compress, compress.getValueToStore(), compress.getBase(), + compress.getIndices()); return success(); case MaskFormat::AllFalse: rewriter.eraseOp(compress); @@ -3894,8 +3900,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op, } LogicalResult ShapeCastOp::verify() { - auto sourceVectorType = source().getType().dyn_cast_or_null(); - auto resultVectorType = result().getType().dyn_cast_or_null(); + auto sourceVectorType = getSource().getType().dyn_cast_or_null(); + auto resultVectorType = getResult().getType().dyn_cast_or_null(); // Check if source/result are of vector type. if (sourceVectorType && resultVectorType) @@ -3906,16 +3912,16 @@ LogicalResult ShapeCastOp::verify() { OpFoldResult ShapeCastOp::fold(ArrayRef operands) { // Nop shape cast. - if (source().getType() == result().getType()) - return source(); + if (getSource().getType() == getResult().getType()) + return getSource(); // Canceling shape casts. - if (auto otherOp = source().getDefiningOp()) { - if (result().getType() == otherOp.source().getType()) - return otherOp.source(); + if (auto otherOp = getSource().getDefiningOp()) { + if (getResult().getType() == otherOp.getSource().getType()) + return otherOp.getSource(); // Only allows valid transitive folding. - VectorType srcType = otherOp.source().getType().cast(); + VectorType srcType = otherOp.getSource().getType().cast(); VectorType resultType = getResult().getType().cast(); if (srcType.getRank() < resultType.getRank()) { if (!isValidShapeCast(srcType.getShape(), resultType.getShape())) @@ -3927,7 +3933,7 @@ OpFoldResult ShapeCastOp::fold(ArrayRef operands) { return {}; } - setOperand(otherOp.source()); + setOperand(otherOp.getSource()); return getResult(); } return {}; @@ -3941,7 +3947,8 @@ public: LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { - auto constantOp = shapeCastOp.source().getDefiningOp(); + auto constantOp = + shapeCastOp.getSource().getDefiningOp(); if (!constantOp) return failure(); // Only handle splat for now. @@ -3998,13 +4005,13 @@ LogicalResult BitCastOp::verify() { OpFoldResult BitCastOp::fold(ArrayRef operands) { // Nop cast. - if (source().getType() == result().getType()) - return source(); + if (getSource().getType() == getResult().getType()) + return getSource(); // Canceling bitcasts. - if (auto otherOp = source().getDefiningOp()) - if (result().getType() == otherOp.source().getType()) - return otherOp.source(); + if (auto otherOp = getSource().getDefiningOp()) + if (getResult().getType() == otherOp.getSource().getType()) + return otherOp.getSource(); Attribute sourceConstant = operands.front(); if (!sourceConstant) @@ -4113,7 +4120,7 @@ OpFoldResult vector::TransposeOp::fold(ArrayRef operands) { return {}; } - return vector(); + return getVector(); } LogicalResult vector::TransposeOp::verify() { @@ -4123,7 +4130,7 @@ LogicalResult vector::TransposeOp::verify() { if (vectorType.getRank() != rank) return emitOpError("vector result rank mismatch: ") << rank; // Verify transposition array. - auto transpAttr = transp().getValue(); + auto transpAttr = getTransp().getValue(); int64_t size = transpAttr.size(); if (rank != size) return emitOpError("transposition length mismatch: ") << size; @@ -4168,7 +4175,7 @@ public: // Return if the input of 'transposeOp' is not defined by another transpose. vector::TransposeOp parentTransposeOp = - transposeOp.vector().getDefiningOp(); + transposeOp.getVector().getDefiningOp(); if (!parentTransposeOp) return failure(); @@ -4177,7 +4184,7 @@ public: // Replace 'transposeOp' with a new transpose operation. rewriter.replaceOpWithNewOp( transposeOp, transposeOp.getResult().getType(), - parentTransposeOp.vector(), + parentTransposeOp.getVector(), vector::getVectorSubscriptAttr(rewriter, permutation)); return success(); } @@ -4191,7 +4198,7 @@ void vector::TransposeOp::getCanonicalizationPatterns( } void vector::TransposeOp::getTransp(SmallVectorImpl &results) { - populateFromInt64AttrArray(transp(), results); + populateFromInt64AttrArray(getTransp(), results); } //===----------------------------------------------------------------------===// @@ -4202,23 +4209,23 @@ LogicalResult ConstantMaskOp::verify() { auto resultType = getResult().getType().cast(); // Check the corner case of 0-D vectors first. if (resultType.getRank() == 0) { - if (mask_dim_sizes().size() != 1) + if (getMaskDimSizes().size() != 1) return emitError("array attr must have length 1 for 0-D vectors"); - auto dim = mask_dim_sizes()[0].cast().getInt(); + auto dim = getMaskDimSizes()[0].cast().getInt(); if (dim != 0 && dim != 1) return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); return success(); } // Verify that array attr size matches the rank of the vector result. - if (static_cast(mask_dim_sizes().size()) != resultType.getRank()) + if (static_cast(getMaskDimSizes().size()) != resultType.getRank()) return emitOpError( "must specify array attr of size equal vector result rank"); // Verify that each array attr element is in bounds of corresponding vector // result dimension size. auto resultShape = resultType.getShape(); SmallVector maskDimSizes; - for (const auto &it : llvm::enumerate(mask_dim_sizes())) { + for (const auto &it : llvm::enumerate(getMaskDimSizes())) { int64_t attrValue = it.value().cast().getInt(); if (attrValue < 0 || attrValue > resultShape[it.index()]) return emitOpError( @@ -4238,7 +4245,7 @@ LogicalResult ConstantMaskOp::verify() { // `vector.constant_mask`. In the future, a convention could be established // to decide if a specific dimension value could be considered as "all set". if (resultType.isScalable() && - mask_dim_sizes()[0].cast().getInt() != 0) + getMaskDimSizes()[0].cast().getInt() != 0) return emitOpError("expected mask dim sizes for scalable masks to be 0"); return success(); } @@ -4329,7 +4336,7 @@ LogicalResult ScanOp::verify() { VectorType initialType = getInitialValueType(); // Check reduction dimension < rank. int64_t srcRank = srcType.getRank(); - int64_t reductionDim = reduction_dim(); + int64_t reductionDim = getReductionDim(); if (reductionDim >= srcRank) return emitOpError("reduction dimension ") << reductionDim << " has to be less than " << srcRank; diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index c823f34..dd83430 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -55,9 +55,9 @@ struct TransferReadOpInterface Value buffer = *state.getBuffer(rewriter, readOp->getOpOperand(0) /*source*/); replaceOpWithNewBufferizedOp( - rewriter, readOp, readOp.getVectorType(), buffer, readOp.indices(), - readOp.permutation_map(), readOp.padding(), readOp.mask(), - readOp.in_boundsAttr()); + rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(), + readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), + readOp.getInBoundsAttr()); return success(); } }; @@ -107,8 +107,9 @@ struct TransferWriteOpInterface if (failed(resultBuffer)) return failure(); rewriter.create( - writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(), - writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); + writeOp.getLoc(), writeOp.getVector(), *resultBuffer, + writeOp.getIndices(), writeOp.getPermutationMapAttr(), + writeOp.getInBoundsAttr()); replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index e6afec8..d555c60 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -63,16 +63,16 @@ struct CastAwayExtractStridedSliceLeadingOneDim Location loc = extractOp.getLoc(); Value newSrcVector = rewriter.create( - loc, extractOp.vector(), splatZero(dropCount)); + loc, extractOp.getVector(), splatZero(dropCount)); // The offsets/sizes/strides attribute can have a less number of elements // than the input vector's rank: it is meant for the leading dimensions. auto newOffsets = rewriter.getArrayAttr( - extractOp.offsets().getValue().drop_front(dropCount)); + extractOp.getOffsets().getValue().drop_front(dropCount)); auto newSizes = rewriter.getArrayAttr( - extractOp.sizes().getValue().drop_front(dropCount)); + extractOp.getSizes().getValue().drop_front(dropCount)); auto newStrides = rewriter.getArrayAttr( - extractOp.strides().getValue().drop_front(dropCount)); + extractOp.getStrides().getValue().drop_front(dropCount)); auto newExtractOp = rewriter.create( loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides); @@ -106,14 +106,14 @@ struct CastAwayInsertStridedSliceLeadingOneDim Location loc = insertOp.getLoc(); Value newSrcVector = rewriter.create( - loc, insertOp.source(), splatZero(srcDropCount)); + loc, insertOp.getSource(), splatZero(srcDropCount)); Value newDstVector = rewriter.create( - loc, insertOp.dest(), splatZero(dstDropCount)); + loc, insertOp.getDest(), splatZero(dstDropCount)); auto newOffsets = rewriter.getArrayAttr( - insertOp.offsets().getValue().take_back(newDstType.getRank())); + insertOp.getOffsets().getValue().take_back(newDstType.getRank())); auto newStrides = rewriter.getArrayAttr( - insertOp.strides().getValue().take_back(newSrcType.getRank())); + insertOp.getStrides().getValue().take_back(newSrcType.getRank())); auto newInsertOp = rewriter.create( loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides); @@ -138,10 +138,10 @@ struct CastAwayTransferReadLeadingOneDim if (read.getTransferRank() == 0) return failure(); - if (read.mask()) + if (read.getMask()) return failure(); - auto shapedType = read.source().getType().cast(); + auto shapedType = read.getSource().getType().cast(); if (shapedType.getElementType() != read.getVectorType().getElementType()) return failure(); @@ -151,7 +151,7 @@ struct CastAwayTransferReadLeadingOneDim if (newType == oldType) return failure(); - AffineMap oldMap = read.permutation_map(); + AffineMap oldMap = read.getPermutationMap(); ArrayRef newResults = oldMap.getResults().take_back(newType.getRank()); AffineMap newMap = @@ -159,13 +159,13 @@ struct CastAwayTransferReadLeadingOneDim rewriter.getContext()); ArrayAttr inBoundsAttr; - if (read.in_bounds()) + if (read.getInBounds()) inBoundsAttr = rewriter.getArrayAttr( - read.in_boundsAttr().getValue().take_back(newType.getRank())); + read.getInBoundsAttr().getValue().take_back(newType.getRank())); auto newRead = rewriter.create( - read.getLoc(), newType, read.source(), read.indices(), - AffineMapAttr::get(newMap), read.padding(), /*mask=*/Value(), + read.getLoc(), newType, read.getSource(), read.getIndices(), + AffineMapAttr::get(newMap), read.getPadding(), /*mask=*/Value(), inBoundsAttr); rewriter.replaceOpWithNewOp(read, oldType, newRead); @@ -186,10 +186,10 @@ struct CastAwayTransferWriteLeadingOneDim if (write.getTransferRank() == 0) return failure(); - if (write.mask()) + if (write.getMask()) return failure(); - auto shapedType = write.source().getType().dyn_cast(); + auto shapedType = write.getSource().getType().dyn_cast(); if (shapedType.getElementType() != write.getVectorType().getElementType()) return failure(); @@ -199,7 +199,7 @@ struct CastAwayTransferWriteLeadingOneDim return failure(); int64_t dropDim = oldType.getRank() - newType.getRank(); - AffineMap oldMap = write.permutation_map(); + AffineMap oldMap = write.getPermutationMap(); ArrayRef newResults = oldMap.getResults().take_back(newType.getRank()); AffineMap newMap = @@ -207,14 +207,14 @@ struct CastAwayTransferWriteLeadingOneDim rewriter.getContext()); ArrayAttr inBoundsAttr; - if (write.in_bounds()) + if (write.getInBounds()) inBoundsAttr = rewriter.getArrayAttr( - write.in_boundsAttr().getValue().take_back(newType.getRank())); + write.getInBoundsAttr().getValue().take_back(newType.getRank())); auto newVector = rewriter.create( - write.getLoc(), write.vector(), splatZero(dropDim)); + write.getLoc(), write.getVector(), splatZero(dropDim)); rewriter.replaceOpWithNewOp( - write, newVector, write.source(), write.indices(), + write, newVector, write.getSource(), write.getIndices(), AffineMapAttr::get(newMap), inBoundsAttr); return success(); @@ -237,7 +237,7 @@ struct CastAwayContractionLeadingOneDim if (oldAccType.getRank() < 2) return failure(); // TODO: implement masks. - if (llvm::size(contractOp.masks()) != 0) + if (llvm::size(contractOp.getMasks()) != 0) return failure(); if (oldAccType.getShape()[0] != 1) return failure(); @@ -248,7 +248,7 @@ struct CastAwayContractionLeadingOneDim auto oldIndexingMaps = contractOp.getIndexingMaps(); SmallVector newIndexingMaps; - auto oldIteratorTypes = contractOp.iterator_types(); + auto oldIteratorTypes = contractOp.getIteratorTypes(); SmallVector newIteratorTypes; int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); @@ -264,8 +264,8 @@ struct CastAwayContractionLeadingOneDim newIteratorTypes.push_back(it.value()); } - SmallVector operands = {contractOp.lhs(), contractOp.rhs(), - contractOp.acc()}; + SmallVector operands = {contractOp.getLhs(), contractOp.getRhs(), + contractOp.getAcc()}; SmallVector newOperands; for (const auto &it : llvm::enumerate(oldIndexingMaps)) { @@ -336,7 +336,7 @@ struct CastAwayContractionLeadingOneDim auto newContractOp = rewriter.create( contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], rewriter.getAffineMapArrayAttr(newIndexingMaps), - rewriter.getArrayAttr(newIteratorTypes), contractOp.kind()); + rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); rewriter.replaceOpWithNewOp( contractOp, contractOp->getResultTypes()[0], newContractOp); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 4308fa6..2a384c3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -62,7 +62,7 @@ public: auto srcType = op.getSourceVectorType(); auto dstType = op.getDestVectorType(); - if (op.offsets().getValue().empty()) + if (op.getOffsets().getValue().empty()) return failure(); auto loc = op.getLoc(); @@ -74,21 +74,21 @@ public: int64_t rankRest = dstType.getRank() - rankDiff; // Extract / insert the subvector of matching rank and InsertStridedSlice // on it. - Value extracted = - rewriter.create(loc, op.dest(), - getI64SubArray(op.offsets(), /*dropFront=*/0, - /*dropBack=*/rankRest)); + Value extracted = rewriter.create( + loc, op.getDest(), + getI64SubArray(op.getOffsets(), /*dropFront=*/0, + /*dropBack=*/rankRest)); // A different pattern will kick in for InsertStridedSlice with matching // ranks. auto stridedSliceInnerOp = rewriter.create( - loc, op.source(), extracted, - getI64SubArray(op.offsets(), /*dropFront=*/rankDiff), - getI64SubArray(op.strides(), /*dropFront=*/0)); + loc, op.getSource(), extracted, + getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff), + getI64SubArray(op.getStrides(), /*dropFront=*/0)); rewriter.replaceOpWithNewOp( - op, stridedSliceInnerOp.getResult(), op.dest(), - getI64SubArray(op.offsets(), /*dropFront=*/0, + op, stridedSliceInnerOp.getResult(), op.getDest(), + getI64SubArray(op.getOffsets(), /*dropFront=*/0, /*dropBack=*/rankRest)); return success(); } @@ -118,7 +118,7 @@ public: auto srcType = op.getSourceVectorType(); auto dstType = op.getDestVectorType(); - if (op.offsets().getValue().empty()) + if (op.getOffsets().getValue().empty()) return failure(); int64_t srcRank = srcType.getRank(); @@ -128,18 +128,18 @@ public: return failure(); if (srcType == dstType) { - rewriter.replaceOp(op, op.source()); + rewriter.replaceOp(op, op.getSource()); return success(); } int64_t offset = - op.offsets().getValue().front().cast().getInt(); + op.getOffsets().getValue().front().cast().getInt(); int64_t size = srcType.getShape().front(); int64_t stride = - op.strides().getValue().front().cast().getInt(); + op.getStrides().getValue().front().cast().getInt(); auto loc = op.getLoc(); - Value res = op.dest(); + Value res = op.getDest(); if (srcRank == 1) { int nSrc = srcType.getShape().front(); @@ -148,8 +148,8 @@ public: SmallVector offsets(nDest, 0); for (int64_t i = 0; i < nSrc; ++i) offsets[i] = i; - Value scaledSource = - rewriter.create(loc, op.source(), op.source(), offsets); + Value scaledSource = rewriter.create(loc, op.getSource(), + op.getSource(), offsets); // 2. Create a mask where we take the value from scaledSource of dest // depending on the offset. @@ -162,7 +162,7 @@ public: } // 3. Replace with a ShuffleOp. - rewriter.replaceOpWithNewOp(op, scaledSource, op.dest(), + rewriter.replaceOpWithNewOp(op, scaledSource, op.getDest(), offsets); return success(); @@ -172,17 +172,17 @@ public: for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { // 1. extract the proper subvector (or element) from source - Value extractedSource = extractOne(rewriter, loc, op.source(), idx); + Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx); if (extractedSource.getType().isa()) { // 2. If we have a vector, extract the proper subvector from destination // Otherwise we are at the element level and no need to recurse. - Value extractedDest = extractOne(rewriter, loc, op.dest(), off); + Value extractedDest = extractOne(rewriter, loc, op.getDest(), off); // 3. Reduce the problem to lowering a new InsertStridedSlice op with // smaller rank. extractedSource = rewriter.create( loc, extractedSource, extractedDest, - getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); + getI64SubArray(op.getOffsets(), /* dropFront=*/1), + getI64SubArray(op.getStrides(), /* dropFront=*/1)); } // 4. Insert the extractedSource into the res vector. res = insertOne(rewriter, loc, extractedSource, res, off); @@ -212,27 +212,28 @@ public: PatternRewriter &rewriter) const override { auto dstType = op.getType(); - assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); + assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = - op.offsets().getValue().front().cast().getInt(); - int64_t size = op.sizes().getValue().front().cast().getInt(); + op.getOffsets().getValue().front().cast().getInt(); + int64_t size = + op.getSizes().getValue().front().cast().getInt(); int64_t stride = - op.strides().getValue().front().cast().getInt(); + op.getStrides().getValue().front().cast().getInt(); auto loc = op.getLoc(); auto elemType = dstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); // Single offset can be more efficiently shuffled. - if (op.offsets().getValue().size() == 1) { + if (op.getOffsets().getValue().size() == 1) { SmallVector offsets; offsets.reserve(size); for (int64_t off = offset, e = offset + size * stride; off < e; off += stride) offsets.push_back(off); - rewriter.replaceOpWithNewOp(op, dstType, op.vector(), - op.vector(), + rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), + op.getVector(), rewriter.getI64ArrayAttr(offsets)); return success(); } @@ -243,11 +244,11 @@ public: Value res = rewriter.create(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { - Value one = extractOne(rewriter, loc, op.vector(), off); + Value one = extractOne(rewriter, loc, op.getVector(), off); Value extracted = rewriter.create( - loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), - getI64SubArray(op.sizes(), /* dropFront=*/1), - getI64SubArray(op.strides(), /* dropFront=*/1)); + loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1), + getI64SubArray(op.getSizes(), /* dropFront=*/1), + getI64SubArray(op.getStrides(), /* dropFront=*/1)); res = insertOne(rewriter, loc, extracted, res, idx); } rewriter.replaceOp(op, res); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp index db5c667..07e24de 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -38,13 +38,13 @@ public: LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { - auto src = multiReductionOp.source(); + auto src = multiReductionOp.getSource(); auto loc = multiReductionOp.getLoc(); auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Separate reduction and parallel dims auto reductionDimsRange = - multiReductionOp.reduction_dims().getAsValueRange(); + multiReductionOp.getReductionDims().getAsValueRange(); auto reductionDims = llvm::to_vector<4>(llvm::map_range( reductionDimsRange, [](const APInt &a) { return a.getZExtValue(); })); llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), @@ -86,8 +86,8 @@ public: reductionMask[i] = true; } rewriter.replaceOpWithNewOp( - multiReductionOp, transposeOp.result(), reductionMask, - multiReductionOp.kind()); + multiReductionOp, transposeOp.getResult(), reductionMask, + multiReductionOp.getKind()); return success(); } @@ -186,17 +186,17 @@ public: auto castedType = VectorType::get( vectorShape, multiReductionOp.getSourceVectorType().getElementType()); Value cast = rewriter.create( - loc, castedType, multiReductionOp.source()); + loc, castedType, multiReductionOp.getSource()); // 5. Creates the flattened form of vector.multi_reduction with inner/outer // most dim as reduction. auto newOp = rewriter.create( - loc, cast, mask, multiReductionOp.kind()); + loc, cast, mask, multiReductionOp.getKind()); // 6. If there are no parallel shapes, the result is a scalar. // TODO: support 0-d vectors when available. if (parallelShapes.empty()) { - rewriter.replaceOp(multiReductionOp, newOp.dest()); + rewriter.replaceOp(multiReductionOp, newOp.getDest()); return success(); } @@ -205,7 +205,7 @@ public: parallelShapes, multiReductionOp.getSourceVectorType().getElementType()); rewriter.replaceOpWithNewOp( - multiReductionOp, outputCastedType, newOp.dest()); + multiReductionOp, outputCastedType, newOp.getDest()); return success(); } @@ -238,12 +238,12 @@ struct TwoDimMultiReductionToElementWise return failure(); Value result = - rewriter.create(loc, multiReductionOp.source(), 0) + rewriter.create(loc, multiReductionOp.getSource(), 0) .getResult(); for (int64_t i = 1; i < srcShape[0]; i++) { - auto operand = - rewriter.create(loc, multiReductionOp.source(), i); - result = makeArithReduction(rewriter, loc, multiReductionOp.kind(), + auto operand = rewriter.create( + loc, multiReductionOp.getSource(), i); + result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), operand, result); } @@ -275,9 +275,9 @@ struct TwoDimMultiReductionToReduction for (int i = 0; i < outerDim; ++i) { auto v = rewriter.create( - loc, multiReductionOp.source(), ArrayRef{i}); - auto reducedValue = - rewriter.create(loc, multiReductionOp.kind(), v); + loc, multiReductionOp.getSource(), ArrayRef{i}); + auto reducedValue = rewriter.create( + loc, multiReductionOp.getKind(), v); result = rewriter.create( loc, reducedValue, result, rewriter.create(loc, i)); @@ -317,9 +317,9 @@ struct OneDimMultiReductionToTwoDim /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) Value cast = rewriter.create( - loc, castedType, multiReductionOp.source()); + loc, castedType, multiReductionOp.getSource()); Value reduced = rewriter.create( - loc, cast, mask, multiReductionOp.kind()); + loc, cast, mask, multiReductionOp.getKind()); rewriter.replaceOpWithNewOp(multiReductionOp, reduced, ArrayRef{0}); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 9315746..364f09c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -96,7 +96,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { << "\n"); llvm::SmallVector reads; Operation *firstOverwriteCandidate = nullptr; - for (auto *user : write.source().getUsers()) { + for (auto *user : write.getSource().getUsers()) { if (user == write.getOperation()) continue; if (auto nextWrite = dyn_cast(user)) { @@ -163,7 +163,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { << "\n"); SmallVector blockingWrites; vector::TransferWriteOp lastwrite = nullptr; - for (Operation *user : read.source().getUsers()) { + for (Operation *user : read.getSource().getUsers()) { if (isa(user)) continue; if (auto write = dyn_cast(user)) { @@ -207,7 +207,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() << " to: " << *read.getOperation() << "\n"); - read.replaceAllUsesWith(lastwrite.vector()); + read.replaceAllUsesWith(lastwrite.getVector()); opToErase.push_back(read.getOperation()); } @@ -259,9 +259,9 @@ class TransferReadDropUnitDimsPattern LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); - Value vector = transferReadOp.vector(); + Value vector = transferReadOp.getVector(); VectorType vectorType = vector.getType().cast(); - Value source = transferReadOp.source(); + Value source = transferReadOp.getSource(); MemRefType sourceType = source.getType().dyn_cast(); // TODO: support tensor types. if (!sourceType || !sourceType.hasStaticShape()) @@ -271,7 +271,7 @@ class TransferReadDropUnitDimsPattern // TODO: generalize this pattern, relax the requirements here. if (transferReadOp.hasOutOfBoundsDim()) return failure(); - if (!transferReadOp.permutation_map().isMinorIdentity()) + if (!transferReadOp.getPermutationMap().isMinorIdentity()) return failure(); int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) @@ -279,7 +279,7 @@ class TransferReadDropUnitDimsPattern if (reducedRank != vectorType.getRank()) return failure(); // This pattern requires the vector shape to match the // reduced source shape. - if (llvm::any_of(transferReadOp.indices(), + if (llvm::any_of(transferReadOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); Value reducedShapeSource = @@ -302,9 +302,9 @@ class TransferWriteDropUnitDimsPattern LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); - Value vector = transferWriteOp.vector(); + Value vector = transferWriteOp.getVector(); VectorType vectorType = vector.getType().cast(); - Value source = transferWriteOp.source(); + Value source = transferWriteOp.getSource(); MemRefType sourceType = source.getType().dyn_cast(); // TODO: support tensor type. if (!sourceType || !sourceType.hasStaticShape()) @@ -314,7 +314,7 @@ class TransferWriteDropUnitDimsPattern // TODO: generalize this pattern, relax the requirements here. if (transferWriteOp.hasOutOfBoundsDim()) return failure(); - if (!transferWriteOp.permutation_map().isMinorIdentity()) + if (!transferWriteOp.getPermutationMap().isMinorIdentity()) return failure(); int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) @@ -322,7 +322,7 @@ class TransferWriteDropUnitDimsPattern if (reducedRank != vectorType.getRank()) return failure(); // This pattern requires the vector shape to match the // reduced source shape. - if (llvm::any_of(transferWriteOp.indices(), + if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); Value reducedShapeSource = @@ -366,9 +366,9 @@ class FlattenContiguousRowMajorTransferReadPattern LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, PatternRewriter &rewriter) const override { auto loc = transferReadOp.getLoc(); - Value vector = transferReadOp.vector(); + Value vector = transferReadOp.getVector(); VectorType vectorType = vector.getType().cast(); - Value source = transferReadOp.source(); + Value source = transferReadOp.getSource(); MemRefType sourceType = source.getType().dyn_cast(); // Contiguity check is valid on tensors only. if (!sourceType) @@ -386,11 +386,11 @@ class FlattenContiguousRowMajorTransferReadPattern // TODO: generalize this pattern, relax the requirements here. if (transferReadOp.hasOutOfBoundsDim()) return failure(); - if (!transferReadOp.permutation_map().isMinorIdentity()) + if (!transferReadOp.getPermutationMap().isMinorIdentity()) return failure(); - if (transferReadOp.mask()) + if (transferReadOp.getMask()) return failure(); - if (llvm::any_of(transferReadOp.indices(), + if (llvm::any_of(transferReadOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); Value c0 = rewriter.create(loc, 0); @@ -418,9 +418,9 @@ class FlattenContiguousRowMajorTransferWritePattern LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp, PatternRewriter &rewriter) const override { auto loc = transferWriteOp.getLoc(); - Value vector = transferWriteOp.vector(); + Value vector = transferWriteOp.getVector(); VectorType vectorType = vector.getType().cast(); - Value source = transferWriteOp.source(); + Value source = transferWriteOp.getSource(); MemRefType sourceType = source.getType().dyn_cast(); // Contiguity check is valid on tensors only. if (!sourceType) @@ -438,11 +438,11 @@ class FlattenContiguousRowMajorTransferWritePattern // TODO: generalize this pattern, relax the requirements here. if (transferWriteOp.hasOutOfBoundsDim()) return failure(); - if (!transferWriteOp.permutation_map().isMinorIdentity()) + if (!transferWriteOp.getPermutationMap().isMinorIdentity()) return failure(); - if (transferWriteOp.mask()) + if (transferWriteOp.getMask()) return failure(); - if (llvm::any_of(transferWriteOp.indices(), + if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); Value c0 = rewriter.create(loc, 0); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp index baf6973..9488145 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp @@ -62,7 +62,7 @@ struct TransferReadPermutationLowering return failure(); SmallVector permutation; - AffineMap map = op.permutation_map(); + AffineMap map = op.getPermutationMap(); if (map.getNumResults() == 0) return failure(); if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) @@ -85,7 +85,7 @@ struct TransferReadPermutationLowering // Transpose mask operand. Value newMask; - if (op.mask()) { + if (op.getMask()) { // Remove unused dims from the permutation map. E.g.: // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) @@ -99,22 +99,23 @@ struct TransferReadPermutationLowering maskTransposeIndices.push_back(expr.getPosition()); } - newMask = rewriter.create(op.getLoc(), op.mask(), + newMask = rewriter.create(op.getLoc(), op.getMask(), maskTransposeIndices); } // Transpose in_bounds attribute. ArrayAttr newInBoundsAttr = - op.in_bounds() ? transposeInBoundsAttr( - rewriter, op.in_bounds().getValue(), permutation) - : ArrayAttr(); + op.getInBounds() + ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(), + permutation) + : ArrayAttr(); // Generate new transfer_read operation. VectorType newReadType = VectorType::get(newVectorShape, op.getVectorType().getElementType()); Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), - AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr); + op.getLoc(), newReadType, op.getSource(), op.getIndices(), + AffineMapAttr::get(newMap), op.getPadding(), newMask, newInBoundsAttr); // Transpose result of transfer_read. SmallVector transposePerm(permutation.begin(), permutation.end()); @@ -151,7 +152,7 @@ struct TransferWritePermutationLowering return failure(); SmallVector permutation; - AffineMap map = op.permutation_map(); + AffineMap map = op.getPermutationMap(); if (map.isMinorIdentity()) return failure(); if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) @@ -169,23 +170,24 @@ struct TransferWritePermutationLowering }); // Transpose mask operand. - Value newMask = op.mask() ? rewriter.create( - op.getLoc(), op.mask(), indices) - : Value(); + Value newMask = op.getMask() ? rewriter.create( + op.getLoc(), op.getMask(), indices) + : Value(); // Transpose in_bounds attribute. ArrayAttr newInBoundsAttr = - op.in_bounds() ? transposeInBoundsAttr( - rewriter, op.in_bounds().getValue(), permutation) - : ArrayAttr(); + op.getInBounds() + ? transposeInBoundsAttr(rewriter, op.getInBounds().getValue(), + permutation) + : ArrayAttr(); // Generate new transfer_write operation. - Value newVec = - rewriter.create(op.getLoc(), op.vector(), indices); + Value newVec = rewriter.create( + op.getLoc(), op.getVector(), indices); auto newMap = AffineMap::getMinorIdentityMap( map.getNumDims(), map.getNumResults(), rewriter.getContext()); rewriter.replaceOpWithNewOp( - op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap), + op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap), newMask, newInBoundsAttr); return success(); @@ -209,7 +211,7 @@ struct TransferOpReduceRank : public OpRewritePattern { if (op.getTransferRank() == 0) return failure(); - AffineMap map = op.permutation_map(); + AffineMap map = op.getPermutationMap(); unsigned numLeadingBroadcast = 0; for (auto expr : map.getResults()) { auto dimExpr = expr.dyn_cast(); @@ -237,12 +239,12 @@ struct TransferOpReduceRank : public OpRewritePattern { if (reducedShapeRank == 0) { Value newRead; if (op.getShapedType().isa()) { - newRead = rewriter.create(op.getLoc(), op.source(), - op.indices()); + newRead = rewriter.create( + op.getLoc(), op.getSource(), op.getIndices()); } else { newRead = rewriter.create( - op.getLoc(), originalVecType.getElementType(), op.source(), - op.indices()); + op.getLoc(), originalVecType.getElementType(), op.getSource(), + op.getIndices()); } rewriter.replaceOpWithNewOp(op, originalVecType, newRead); @@ -256,13 +258,14 @@ struct TransferOpReduceRank : public OpRewritePattern { VectorType newReadType = VectorType::get(newShape, originalVecType.getElementType()); ArrayAttr newInBoundsAttr = - op.in_bounds() + op.getInBounds() ? rewriter.getArrayAttr( - op.in_boundsAttr().getValue().take_back(reducedShapeRank)) + op.getInBoundsAttr().getValue().take_back(reducedShapeRank)) : ArrayAttr(); Value newRead = rewriter.create( - op.getLoc(), newReadType, op.source(), op.indices(), - AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr); + op.getLoc(), newReadType, op.getSource(), op.getIndices(), + AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), + newInBoundsAttr); rewriter.replaceOpWithNewOp(op, originalVecType, newRead); return success(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index c457621..5e090a6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -249,7 +249,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); Value zero = b.create(loc, 0); - Value memref = xferOp.source(); + Value memref = xferOp.getSource(); return b.create( loc, returnTypes, inBoundsCond, [&](OpBuilder &b, Location loc) { @@ -257,12 +257,12 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp, if (compatibleMemRefType != xferOp.getShapedType()) res = b.create(loc, compatibleMemRefType, memref); scf::ValueVector viewAndIndices{res}; - viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), - xferOp.indices().end()); + viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(), + xferOp.getIndices().end()); b.create(loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { - b.create(loc, ValueRange{xferOp.padding()}, + b.create(loc, ValueRange{xferOp.getPadding()}, ValueRange{alloc}); // Take partial subview of memref which guarantees no dimension // overflows. @@ -304,7 +304,7 @@ static scf::IfOp createFullPartialVectorTransferRead( Location loc = xferOp.getLoc(); scf::IfOp fullPartialIfOp; Value zero = b.create(loc, 0); - Value memref = xferOp.source(); + Value memref = xferOp.getSource(); return b.create( loc, returnTypes, inBoundsCond, [&](OpBuilder &b, Location loc) { @@ -312,8 +312,8 @@ static scf::IfOp createFullPartialVectorTransferRead( if (compatibleMemRefType != xferOp.getShapedType()) res = b.create(loc, compatibleMemRefType, memref); scf::ValueVector viewAndIndices{res}; - viewAndIndices.insert(viewAndIndices.end(), xferOp.indices().begin(), - xferOp.indices().end()); + viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(), + xferOp.getIndices().end()); b.create(loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { @@ -354,7 +354,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, MemRefType compatibleMemRefType, Value alloc) { Location loc = xferOp.getLoc(); Value zero = b.create(loc, 0); - Value memref = xferOp.source(); + Value memref = xferOp.getSource(); return b .create( loc, returnTypes, inBoundsCond, @@ -364,8 +364,8 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, res = b.create(loc, compatibleMemRefType, memref); scf::ValueVector viewAndIndices{res}; viewAndIndices.insert(viewAndIndices.end(), - xferOp.indices().begin(), - xferOp.indices().end()); + xferOp.getIndices().begin(), + xferOp.getIndices().end()); b.create(loc, viewAndIndices); }, [&](OpBuilder &b, Location loc) { @@ -430,9 +430,10 @@ static void createFullPartialVectorTransferWrite(RewriterBase &b, b.create(loc, notInBounds, [&](OpBuilder &b, Location loc) { BlockAndValueMapping mapping; Value load = b.create( - loc, b.create( - loc, MemRefType::get({}, xferOp.vector().getType()), alloc)); - mapping.map(xferOp.vector(), load); + loc, + b.create( + loc, MemRefType::get({}, xferOp.getVector().getType()), alloc)); + mapping.map(xferOp.getVector(), load); b.clone(*xferOp.getOperation(), mapping); b.create(loc, ValueRange{}); }); @@ -530,9 +531,9 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( if (!(xferReadOp || xferWriteOp)) return failure(); - if (xferWriteOp && xferWriteOp.mask()) + if (xferWriteOp && xferWriteOp.getMask()) return failure(); - if (xferReadOp && xferReadOp.mask()) + if (xferReadOp && xferReadOp.getMask()) return failure(); } @@ -601,8 +602,8 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer( // The operation is cloned to prevent deleting information needed for the // later IR creation. BlockAndValueMapping mapping; - mapping.map(xferWriteOp.source(), memrefAndIndices.front()); - mapping.map(xferWriteOp.indices(), memrefAndIndices.drop_front()); + mapping.map(xferWriteOp.getSource(), memrefAndIndices.front()); + mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front()); auto *clone = b.clone(*xferWriteOp, mapping); clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 32e2fa7..2ca6481 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -168,19 +168,19 @@ struct ShapeCastOpFolder : public OpRewritePattern { PatternRewriter &rewriter) const override { // Check if 'shapeCastOp' has vector source/result type. auto sourceVectorType = - shapeCastOp.source().getType().dyn_cast_or_null(); + shapeCastOp.getSource().getType().dyn_cast_or_null(); auto resultVectorType = - shapeCastOp.result().getType().dyn_cast_or_null(); + shapeCastOp.getResult().getType().dyn_cast_or_null(); if (!sourceVectorType || !resultVectorType) return failure(); // Check if shape cast op source operand is also a shape cast op. auto sourceShapeCastOp = dyn_cast_or_null( - shapeCastOp.source().getDefiningOp()); + shapeCastOp.getSource().getDefiningOp()); if (!sourceShapeCastOp) return failure(); auto operandSourceVectorType = - sourceShapeCastOp.source().getType().cast(); + sourceShapeCastOp.getSource().getType().cast(); auto operandResultVectorType = sourceShapeCastOp.getType(); // Check if shape cast operations invert each other. @@ -188,7 +188,7 @@ struct ShapeCastOpFolder : public OpRewritePattern { operandResultVectorType != sourceVectorType) return failure(); - rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source()); + rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.getSource()); return success(); } }; @@ -207,7 +207,7 @@ public: // Scalar to any vector can use splat. if (!srcType) { - rewriter.replaceOpWithNewOp(op, dstType, op.source()); + rewriter.replaceOpWithNewOp(op, dstType, op.getSource()); return success(); } @@ -219,9 +219,9 @@ public: if (srcRank <= 1 && dstRank == 1) { Value ext; if (srcRank == 0) - ext = rewriter.create(loc, op.source()); + ext = rewriter.create(loc, op.getSource()); else - ext = rewriter.create(loc, op.source(), 0); + ext = rewriter.create(loc, op.getSource(), 0); rewriter.replaceOpWithNewOp(op, dstType, ext); return success(); } @@ -240,7 +240,7 @@ public: VectorType resType = VectorType::get(dstType.getShape().drop_front(), eltType); Value bcst = - rewriter.create(loc, resType, op.source()); + rewriter.create(loc, resType, op.getSource()); Value result = rewriter.create( loc, dstType, rewriter.getZeroAttr(dstType)); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) @@ -260,7 +260,7 @@ public: // All trailing dimensions are the same. Simply pass through. if (m == -1) { - rewriter.replaceOp(op, op.source()); + rewriter.replaceOp(op, op.getSource()); return success(); } @@ -285,14 +285,14 @@ public: loc, dstType, rewriter.getZeroAttr(dstType)); if (m == 0) { // Stetch at start. - Value ext = rewriter.create(loc, op.source(), 0); + Value ext = rewriter.create(loc, op.getSource(), 0); Value bcst = rewriter.create(loc, resType, ext); for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) result = rewriter.create(loc, bcst, result, d); } else { // Stetch not at start. for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { - Value ext = rewriter.create(loc, op.source(), d); + Value ext = rewriter.create(loc, op.getSource(), d); Value bcst = rewriter.create(loc, resType, ext); result = rewriter.create(loc, bcst, result, d); } @@ -338,13 +338,13 @@ public: PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value input = op.vector(); + Value input = op.getVector(); VectorType inputType = op.getVectorType(); VectorType resType = op.getResultType(); // Set up convenience transposition table. SmallVector transp; - for (auto attr : op.transp()) + for (auto attr : op.getTransp()) transp.push_back(attr.cast().getInt()); if (vectorTransformOptions.vectorTransposeLowering == @@ -433,7 +433,7 @@ public: return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); SmallVector transp; - for (auto attr : op.transp()) + for (auto attr : op.getTransp()) transp.push_back(attr.cast().getInt()); if (transp[0] != 1 && transp[1] != 0) return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); @@ -444,7 +444,8 @@ public: int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); Value casted = rewriter.create( - loc, VectorType::get({m * n}, srcType.getElementType()), op.vector()); + loc, VectorType::get({m * n}, srcType.getElementType()), + op.getVector()); SmallVector mask; mask.reserve(m * n); for (int64_t j = 0; j < n; ++j) @@ -490,15 +491,15 @@ public: VectorType resType = op.getVectorType(); Type eltType = resType.getElementType(); bool isInt = eltType.isa(); - Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; - vector::CombiningKind kind = op.kind(); + Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; + vector::CombiningKind kind = op.getKind(); if (!rhsType) { // Special case: AXPY operation. - Value b = rewriter.create(loc, lhsType, op.rhs()); + Value b = rewriter.create(loc, lhsType, op.getRhs()); Optional mult = - isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter) - : genMultF(loc, op.lhs(), b, acc, kind, rewriter); + isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter) + : genMultF(loc, op.getLhs(), b, acc, kind, rewriter); if (!mult.hasValue()) return failure(); rewriter.replaceOp(op, mult.getValue()); @@ -509,13 +510,15 @@ public: loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); - Value x = rewriter.create(loc, eltType, op.lhs(), pos); + Value x = + rewriter.create(loc, eltType, op.getLhs(), pos); Value a = rewriter.create(loc, rhsType, x); Value r = nullptr; if (acc) r = rewriter.create(loc, rhsType, acc, pos); - Optional m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter) - : genMultF(loc, a, op.rhs(), r, kind, rewriter); + Optional m = + isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter) + : genMultF(loc, a, op.getRhs(), r, kind, rewriter); if (!m.hasValue()) return failure(); result = rewriter.create(loc, resType, m.getValue(), @@ -588,7 +591,7 @@ public: auto loc = op.getLoc(); auto dstType = op.getType(); auto eltType = dstType.getElementType(); - auto dimSizes = op.mask_dim_sizes(); + auto dimSizes = op.getMaskDimSizes(); int64_t rank = dstType.getRank(); if (rank == 0) { @@ -715,7 +718,7 @@ public: loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { - Value vec = rewriter.create(loc, op.source(), i); + Value vec = rewriter.create(loc, op.getSource(), i); desc = rewriter.create( loc, vec, desc, /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); @@ -749,7 +752,7 @@ public: unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { Value vec = rewriter.create( - loc, op.source(), /*offsets=*/i * mostMinorVectorSize, + loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize, /*sizes=*/mostMinorVectorSize, /*strides=*/1); desc = rewriter.create(loc, vec, desc, i); @@ -804,7 +807,7 @@ public: incIdx(srcIdx, sourceVectorType, srcRank - 1); incIdx(resIdx, resultVectorType, resRank - 1); } - Value e = rewriter.create(loc, op.source(), srcIdx); + Value e = rewriter.create(loc, op.getSource(), srcIdx); result = rewriter.create(loc, e, result, resIdx); } rewriter.replaceOp(op, result); @@ -844,9 +847,9 @@ struct MultiReduceToContract LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { - if (reduceOp.kind() != vector::CombiningKind::ADD) + if (reduceOp.getKind() != vector::CombiningKind::ADD) return failure(); - Operation *mulOp = reduceOp.source().getDefiningOp(); + Operation *mulOp = reduceOp.getSource().getDefiningOp(); if (!mulOp || !isa(mulOp)) return failure(); SmallVector reductionMask = reduceOp.getReductionMask(); @@ -905,8 +908,8 @@ struct CombineContractTranspose PatternRewriter &rewriter) const override { SmallVector maps = llvm::to_vector<4>(contractOp.getIndexingMaps()); - Value lhs = contractOp.lhs(); - Value rhs = contractOp.rhs(); + Value lhs = contractOp.getLhs(); + Value rhs = contractOp.getRhs(); size_t index = 0; bool changed = false; for (Value *operand : {&lhs, &rhs}) { @@ -917,17 +920,17 @@ struct CombineContractTranspose SmallVector perm; transposeOp.getTransp(perm); AffineMap permutationMap = AffineMap::getPermutationMap( - extractVector(transposeOp.transp()), + extractVector(transposeOp.getTransp()), contractOp.getContext()); map = inversePermutation(permutationMap).compose(map); - *operand = transposeOp.vector(); + *operand = transposeOp.getVector(); changed = true; } if (!changed) return failure(); rewriter.replaceOpWithNewOp( - contractOp, lhs, rhs, contractOp.acc(), - rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types()); + contractOp, lhs, rhs, contractOp.getAcc(), + rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes()); return success(); } }; @@ -962,8 +965,8 @@ struct CombineContractBroadcast PatternRewriter &rewriter) const override { SmallVector maps = llvm::to_vector<4>(contractOp.getIndexingMaps()); - Value lhs = contractOp.lhs(); - Value rhs = contractOp.rhs(); + Value lhs = contractOp.getLhs(); + Value rhs = contractOp.getRhs(); size_t index = 0; bool changed = false; for (Value *operand : {&lhs, &rhs}) { @@ -996,14 +999,14 @@ struct CombineContractBroadcast AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims, contractOp.getContext()); map = broadcastMap.compose(map); - *operand = broadcast.source(); + *operand = broadcast.getSource(); changed = true; } if (!changed) return failure(); rewriter.replaceOpWithNewOp( - contractOp, lhs, rhs, contractOp.acc(), - rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types()); + contractOp, lhs, rhs, contractOp.getAcc(), + rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes()); return success(); } }; @@ -1036,8 +1039,9 @@ struct ReorderCastOpsOnBroadcast Type castResTy = getElementTypeOrSelf(op->getResult(0)); if (auto vecTy = bcastOp.getSourceType().dyn_cast()) castResTy = VectorType::get(vecTy.getShape(), castResTy); - auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), - bcastOp.source(), castResTy, op->getAttrs()); + auto castOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), + bcastOp.getSource(), castResTy, op->getAttrs()); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), castOp->getResult(0)); return success(); @@ -1075,8 +1079,9 @@ struct ReorderCastOpsOnTranspose auto castResTy = transpOp.getVectorType(); castResTy = VectorType::get(castResTy.getShape(), getElementTypeOrSelf(op->getResult(0))); - auto castOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), - transpOp.vector(), castResTy, op->getAttrs()); + auto castOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), + transpOp.getVector(), castResTy, op->getAttrs()); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), castOp->getResult(0), transpOp.getTransp()); @@ -1127,7 +1132,7 @@ LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rew) const { // TODO: implement masks - if (llvm::size(op.masks()) != 0) + if (llvm::size(op.getMasks()) != 0) return failure(); if (vectorTransformOptions.vectorContractLowering != vector::VectorContractLowering::Matmul) @@ -1135,7 +1140,7 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, if (failed(filter(op))) return failure(); - auto iteratorTypes = op.iterator_types().getValue(); + auto iteratorTypes = op.getIteratorTypes().getValue(); if (!isParallelIterator(iteratorTypes[0]) || !isParallelIterator(iteratorTypes[1]) || !isReductionIterator(iteratorTypes[2])) @@ -1152,16 +1157,16 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, AffineExpr m, n, k; bindDims(rew.getContext(), m, n, k); // LHS must be A(m, k) or A(k, m). - Value lhs = op.lhs(); - auto lhsMap = op.indexing_maps()[0]; + Value lhs = op.getLhs(); + auto lhsMap = op.getIndexingMaps()[0]; if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) lhs = rew.create(loc, lhs, ArrayRef{1, 0}); else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) return failure(); // RHS must be B(k, n) or B(n, k). - Value rhs = op.rhs(); - auto rhsMap = op.indexing_maps()[1]; + Value rhs = op.getRhs(); + auto rhsMap = op.getIndexingMaps()[1]; if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) rhs = rew.create(loc, rhs, ArrayRef{1, 0}); else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) @@ -1187,11 +1192,11 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, mul = rew.create( loc, VectorType::get({lhsRows, rhsColumns}, - getElementTypeOrSelf(op.acc().getType())), + getElementTypeOrSelf(op.getAcc().getType())), mul); // ACC must be C(m, n) or C(n, m). - auto accMap = op.indexing_maps()[2]; + auto accMap = op.getIndexingMaps()[2]; if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) mul = rew.create(loc, mul, ArrayRef{1, 0}); else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) @@ -1199,8 +1204,9 @@ ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, Value res = elementType.isa() - ? static_cast(rew.create(loc, op.acc(), mul)) - : static_cast(rew.create(loc, op.acc(), mul)); + ? static_cast(rew.create(loc, op.getAcc(), mul)) + : static_cast( + rew.create(loc, op.getAcc(), mul)); rew.replaceOp(op, res); return success(); @@ -1226,11 +1232,10 @@ struct Red : public IteratorType { /// This unrolls outer-products along the reduction dimension. struct UnrolledOuterProductGenerator : public StructuredGenerator { - UnrolledOuterProductGenerator(OpBuilder &builder, vector::ContractionOp op) : StructuredGenerator(builder, op), - kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()), - lhsType(op.getLhsType()) {} + kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), + res(op.getAcc()), lhsType(op.getLhsType()) {} Value t(Value v) { static constexpr std::array perm = {1, 0}; @@ -1356,7 +1361,7 @@ private: LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( vector::ContractionOp op, PatternRewriter &rewriter) const { // TODO: implement masks - if (llvm::size(op.masks()) != 0) + if (llvm::size(op.getMasks()) != 0) return failure(); if (vectorTransformOptions.vectorContractLowering != @@ -1390,7 +1395,7 @@ LogicalResult ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const { // TODO: implement masks - if (llvm::size(op.masks()) != 0) + if (llvm::size(op.getMasks()) != 0) return failure(); if (failed(filter(op))) @@ -1400,10 +1405,10 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, vector::VectorContractLowering::Dot) return failure(); - auto iteratorTypes = op.iterator_types().getValue(); + auto iteratorTypes = op.getIteratorTypes().getValue(); static constexpr std::array perm = {1, 0}; Location loc = op.getLoc(); - Value lhs = op.lhs(), rhs = op.rhs(); + Value lhs = op.getLhs(), rhs = op.getRhs(); using MapList = ArrayRef>; auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; @@ -1495,7 +1500,7 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, res = rewriter.create(op.getLoc(), reduced, res, pos); } } - if (auto acc = op.acc()) + if (auto acc = op.getAcc()) res = createAdd(op.getLoc(), res, acc, isInt, rewriter); rewriter.replaceOp(op, res); return success(); @@ -1522,7 +1527,7 @@ LogicalResult ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const { // TODO: implement masks. - if (llvm::size(op.masks()) != 0) + if (llvm::size(op.getMasks()) != 0) return failure(); if (failed(filter(op))) @@ -1627,15 +1632,15 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, adjustMap(iMap[2], iterIndex, rewriter)}; auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); auto lowIter = - rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); + rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); // Unroll into a series of lower dimensional vector.contract ops. Location loc = op.getLoc(); Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); for (int64_t d = 0; d < dimSize; ++d) { - auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); - auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); - auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); + auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); + auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); Value lowContract = rewriter.create( loc, lhs, rhs, acc, lowAffine, lowIter); result = @@ -1667,10 +1672,10 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, // Base case. if (lhsType.getRank() == 1) { assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value m = createMul(loc, op.lhs(), op.rhs(), isInt, rewriter); + Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); auto kind = vector::CombiningKind::ADD; Value res = rewriter.create(loc, kind, m); - if (auto acc = op.acc()) + if (auto acc = op.getAcc()) res = createAdd(op.getLoc(), res, acc, isInt, rewriter); return res; } @@ -1681,15 +1686,15 @@ Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, adjustMap(iMap[2], iterIndex, rewriter)}; auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); auto lowIter = - rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); + rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); // Unroll into a series of lower dimensional vector.contract ops. // By feeding the initial accumulator into the first contraction, // and the result of each contraction into the next, eventually // the sum of all reductions is computed. - Value result = op.acc(); + Value result = op.getAcc(); for (int64_t d = 0; d < dimSize; ++d) { - auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); - auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); + auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); result = rewriter.create(loc, lhs, rhs, result, lowAffine, lowIter); } @@ -1753,7 +1758,7 @@ struct TransferReadToVectorLoadLowering // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. // We let the 0-d corner case pass-through as it is supported. - if (!read.permutation_map().isMinorIdentityWithBroadcasting( + if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( &broadcastedDims)) return failure(); @@ -1792,16 +1797,16 @@ struct TransferReadToVectorLoadLowering // Create vector load op. Operation *loadOp; - if (read.mask()) { + if (read.getMask()) { Value fill = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.padding()); + read.getLoc(), unbroadcastedVectorType, read.getPadding()); loadOp = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.source(), read.indices(), - read.mask(), fill); + read.getLoc(), unbroadcastedVectorType, read.getSource(), + read.getIndices(), read.getMask(), fill); } else { - loadOp = rewriter.create(read.getLoc(), - unbroadcastedVectorType, - read.source(), read.indices()); + loadOp = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.getSource(), + read.getIndices()); } // Insert a broadcasting op if required. @@ -1836,7 +1841,7 @@ struct VectorLoadToMemrefLoadLowering if (vecType.getNumElements() != 1) return failure(); auto memrefLoad = rewriter.create( - loadOp.getLoc(), loadOp.base(), loadOp.indices()); + loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); rewriter.replaceOpWithNewOp(loadOp, vecType, memrefLoad); return success(); @@ -1857,15 +1862,15 @@ struct VectorStoreToMemrefStoreLowering if (vecType.getRank() == 0) { // TODO: Unifiy once ExtractOp supports 0-d vectors. extracted = rewriter.create( - storeOp.getLoc(), storeOp.valueToStore()); + storeOp.getLoc(), storeOp.getValueToStore()); } else { SmallVector indices(vecType.getRank(), 0); extracted = rewriter.create( - storeOp.getLoc(), storeOp.valueToStore(), indices); + storeOp.getLoc(), storeOp.getValueToStore(), indices); } rewriter.replaceOpWithNewOp( - storeOp, extracted, storeOp.base(), storeOp.indices()); + storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); return success(); } }; @@ -1893,7 +1898,7 @@ struct TransferWriteToVectorStoreLowering // Permutations are handled by VectorToSCF or // populateVectorTransferPermutationMapLoweringPatterns. if ( // pass-through for the 0-d corner case. - !write.permutation_map().isMinorIdentity()) + !write.getPermutationMap().isMinorIdentity()) return failure(); auto memRefType = write.getShapedType().dyn_cast(); @@ -1918,12 +1923,13 @@ struct TransferWriteToVectorStoreLowering // Out-of-bounds dims are handled by MaterializeTransferMask. if (write.hasOutOfBoundsDim()) return failure(); - if (write.mask()) { + if (write.getMask()) { rewriter.replaceOpWithNewOp( - write, write.source(), write.indices(), write.mask(), write.vector()); + write, write.getSource(), write.getIndices(), write.getMask(), + write.getVector()); } else { rewriter.replaceOpWithNewOp( - write, write.vector(), write.source(), write.indices()); + write, write.getVector(), write.getSource(), write.getIndices()); } return success(); } @@ -1957,7 +1963,7 @@ struct BubbleDownVectorBitCastForExtract if (extractOp.getVectorType().getRank() != 1) return failure(); - auto castOp = extractOp.vector().getDefiningOp(); + auto castOp = extractOp.getVector().getDefiningOp(); if (!castOp) return failure(); @@ -1983,14 +1989,14 @@ struct BubbleDownVectorBitCastForExtract return (*attr.getAsValueRange().begin()).getZExtValue(); }; - uint64_t index = getFirstIntValue(extractOp.position()); + uint64_t index = getFirstIntValue(extractOp.getPosition()); // Get the single scalar (as a vector) in the source value that packs the // desired scalar. E.g. extract vector<1xf32> from vector<4xf32> VectorType oneScalarType = VectorType::get({1}, castSrcType.getElementType()); Value packedValue = rewriter.create( - extractOp.getLoc(), oneScalarType, castOp.source(), + extractOp.getLoc(), oneScalarType, castOp.getSource(), rewriter.getI64ArrayAttr(index / expandRatio)); // Cast it to a vector with the desired scalar's type. @@ -2027,7 +2033,7 @@ struct BubbleDownBitCastForStridedSliceExtract LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { - auto castOp = extractOp.vector().getDefiningOp(); + auto castOp = extractOp.getVector().getDefiningOp(); if (!castOp) return failure(); @@ -2042,7 +2048,7 @@ struct BubbleDownBitCastForStridedSliceExtract return failure(); // Only accept all one strides for now. - if (llvm::any_of(extractOp.strides().getAsValueRange(), + if (llvm::any_of(extractOp.getStrides().getAsValueRange(), [](const APInt &val) { return !val.isOneValue(); })) return failure(); @@ -2054,7 +2060,7 @@ struct BubbleDownBitCastForStridedSliceExtract // are selecting the full range for the last bitcasted dimension; other // dimensions aren't affected. Otherwise, we need to scale down the last // dimension's offset given we are extracting from less elements now. - ArrayAttr newOffsets = extractOp.offsets(); + ArrayAttr newOffsets = extractOp.getOffsets(); if (newOffsets.size() == rank) { SmallVector offsets = getIntValueVector(newOffsets); if (offsets.back() % expandRatio != 0) @@ -2064,7 +2070,7 @@ struct BubbleDownBitCastForStridedSliceExtract } // Similarly for sizes. - ArrayAttr newSizes = extractOp.sizes(); + ArrayAttr newSizes = extractOp.getSizes(); if (newSizes.size() == rank) { SmallVector sizes = getIntValueVector(newSizes); if (sizes.back() % expandRatio != 0) @@ -2080,8 +2086,8 @@ struct BubbleDownBitCastForStridedSliceExtract VectorType::get(dims, castSrcType.getElementType()); auto newExtractOp = rewriter.create( - extractOp.getLoc(), newExtractType, castOp.source(), newOffsets, - newSizes, extractOp.strides()); + extractOp.getLoc(), newExtractType, castOp.getSource(), newOffsets, + newSizes, extractOp.getStrides()); rewriter.replaceOpWithNewOp( extractOp, extractOp.getType(), newExtractOp); @@ -2120,12 +2126,12 @@ struct BubbleUpBitCastForStridedSliceInsert int64_t shrinkRatio = castSrcLastDim / castDstLastDim; auto insertOp = - bitcastOp.source().getDefiningOp(); + bitcastOp.getSource().getDefiningOp(); if (!insertOp) return failure(); // Only accept all one strides for now. - if (llvm::any_of(insertOp.strides().getAsValueRange(), + if (llvm::any_of(insertOp.getStrides().getAsValueRange(), [](const APInt &val) { return !val.isOneValue(); })) return failure(); @@ -2135,7 +2141,7 @@ struct BubbleUpBitCastForStridedSliceInsert if (rank != insertOp.getDestVectorType().getRank()) return failure(); - ArrayAttr newOffsets = insertOp.offsets(); + ArrayAttr newOffsets = insertOp.getOffsets(); assert(newOffsets.size() == rank); SmallVector offsets = getIntValueVector(newOffsets); if (offsets.back() % shrinkRatio != 0) @@ -2150,7 +2156,7 @@ struct BubbleUpBitCastForStridedSliceInsert VectorType::get(srcDims, castDstType.getElementType()); auto newCastSrcOp = rewriter.create( - bitcastOp.getLoc(), newCastSrcType, insertOp.source()); + bitcastOp.getLoc(), newCastSrcType, insertOp.getSource()); SmallVector dstDims = llvm::to_vector<4>(insertOp.getDestVectorType().getShape()); @@ -2159,11 +2165,11 @@ struct BubbleUpBitCastForStridedSliceInsert VectorType::get(dstDims, castDstType.getElementType()); auto newCastDstOp = rewriter.create( - bitcastOp.getLoc(), newCastDstType, insertOp.dest()); + bitcastOp.getLoc(), newCastDstType, insertOp.getDest()); rewriter.replaceOpWithNewOp( bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets, - insertOp.strides()); + insertOp.getStrides()); return success(); } @@ -2229,7 +2235,7 @@ public: return failure(); if (xferOp.getVectorType().getRank() > 1 || - llvm::size(xferOp.indices()) == 0) + llvm::size(xferOp.getIndices()) == 0) return failure(); Location loc = xferOp->getLoc(); @@ -2240,24 +2246,24 @@ public: // // TODO: when the leaf transfer rank is k > 1, we need the last `k` // dimensions here. - unsigned lastIndex = llvm::size(xferOp.indices()) - 1; - Value off = xferOp.indices()[lastIndex]; + unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; + Value off = xferOp.getIndices()[lastIndex]; Value dim = - vector::createOrFoldDimOp(rewriter, loc, xferOp.source(), lastIndex); + vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); Value b = rewriter.create(loc, dim.getType(), dim, off); Value mask = rewriter.create( loc, VectorType::get(vtp.getShape(), rewriter.getI1Type(), vtp.getNumScalableDims()), b); - if (xferOp.mask()) { + if (xferOp.getMask()) { // Intersect the in-bounds with the mask specified as an op parameter. - mask = rewriter.create(loc, mask, xferOp.mask()); + mask = rewriter.create(loc, mask, xferOp.getMask()); } rewriter.updateRootInPlace(xferOp, [&]() { - xferOp.maskMutable().assign(mask); - xferOp.in_boundsAttr(rewriter.getBoolArrayAttr({true})); + xferOp.getMaskMutable().assign(mask); + xferOp.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); }); return success(); @@ -2306,14 +2312,14 @@ class DropInnerMostUnitDims : public OpRewritePattern { return failure(); // TODO: support mask. - if (readOp.mask()) + if (readOp.getMask()) return failure(); - auto srcType = readOp.source().getType().dyn_cast(); + auto srcType = readOp.getSource().getType().dyn_cast(); if (!srcType || !srcType.hasStaticShape()) return failure(); - if (!readOp.permutation_map().isMinorIdentity()) + if (!readOp.getPermutationMap().isMinorIdentity()) return failure(); auto targetType = readOp.getVectorType(); @@ -2366,19 +2372,19 @@ class DropInnerMostUnitDims : public OpRewritePattern { SmallVector strides(srcType.getRank(), 1); ArrayAttr inBoundsAttr = - readOp.in_bounds() + readOp.getInBounds() ? rewriter.getArrayAttr( - readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)) + readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop)) : ArrayAttr(); Value rankedReducedView = rewriter.create( - loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), + loc, resultMemrefType, readOp.getSource(), offsets, srcType.getShape(), strides); auto permMap = getTransferMinorIdentityMap( rankedReducedView.getType().cast(), resultTargetVecType); Value result = rewriter.create( loc, resultTargetVecType, rankedReducedView, - readOp.indices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), - readOp.padding(), + readOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap), + readOp.getPadding(), // TODO: support mask. /*mask=*/Value(), inBoundsAttr); rewriter.replaceOpWithNewOp(readOp, targetType, @@ -2514,14 +2520,14 @@ struct ScanToArithOps : public OpRewritePattern { ArrayRef destShape = destType.getShape(); auto elType = destType.getElementType(); bool isInt = elType.isIntOrIndex(); - if (!isValidKind(isInt, scanOp.kind())) + if (!isValidKind(isInt, scanOp.getKind())) return failure(); VectorType resType = VectorType::get(destShape, elType); Value result = rewriter.create( loc, resType, rewriter.getZeroAttr(resType)); - int64_t reductionDim = scanOp.reduction_dim(); - bool inclusive = scanOp.inclusive(); + int64_t reductionDim = scanOp.getReductionDim(); + bool inclusive = scanOp.getInclusive(); int64_t destRank = destType.getRank(); VectorType initialValueType = scanOp.getInitialValueType(); int64_t initialValueRank = initialValueType.getRank(); @@ -2541,7 +2547,7 @@ struct ScanToArithOps : public OpRewritePattern { offsets[reductionDim] = i; ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); Value input = rewriter.create( - loc, reductionType, scanOp.source(), scanOffsets, scanSizes, + loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, scanStrides); Value output; if (i == 0) { @@ -2551,15 +2557,15 @@ struct ScanToArithOps : public OpRewritePattern { if (initialValueRank == 0) { // ShapeCastOp cannot handle 0-D vectors output = rewriter.create( - loc, input.getType(), scanOp.initial_value()); + loc, input.getType(), scanOp.getInitialValue()); } else { output = rewriter.create( - loc, input.getType(), scanOp.initial_value()); + loc, input.getType(), scanOp.getInitialValue()); } } } else { Value y = inclusive ? input : lastInput; - output = genOperator(loc, lastOutput, y, scanOp.kind(), rewriter); + output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter); assert(output != nullptr); } result = rewriter.create( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp index 2d1e7c1..2b73018 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnrollDistribute.cpp @@ -112,7 +112,7 @@ struct UnrollTransferReadPattern // TODO: support 0-d corner case. if (readOp.getTransferRank() == 0) return failure(); - if (readOp.mask()) + if (readOp.getMask()) return failure(); auto targetShape = getTargetShape(options, readOp); if (!targetShape) @@ -129,16 +129,16 @@ struct UnrollTransferReadPattern loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType)); auto targetType = VectorType::get(*targetShape, sourceVectorType.getElementType()); - SmallVector originalIndices(readOp.indices().begin(), - readOp.indices().end()); + SmallVector originalIndices(readOp.getIndices().begin(), + readOp.getIndices().end()); for (int64_t i = 0; i < sliceCount; i++) { SmallVector indices = sliceTransferIndices(i, originalSize, *targetShape, originalIndices, - readOp.permutation_map(), loc, rewriter); + readOp.getPermutationMap(), loc, rewriter); auto slicedRead = rewriter.create( - loc, targetType, readOp.source(), indices, - readOp.permutation_mapAttr(), readOp.padding(), readOp.mask(), - readOp.in_boundsAttr()); + loc, targetType, readOp.getSource(), indices, + readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(), + readOp.getInBoundsAttr()); SmallVector elementOffsets = getVectorOffset(originalSize, *targetShape, i); @@ -165,7 +165,7 @@ struct UnrollTransferWritePattern if (writeOp.getTransferRank() == 0) return failure(); - if (writeOp.mask()) + if (writeOp.getMask()) return failure(); auto targetShape = getTargetShape(options, writeOp); if (!targetShape) @@ -177,21 +177,21 @@ struct UnrollTransferWritePattern SmallVector ratio = *shapeRatio(originalSize, *targetShape); // Compute shape ratio of 'shape' and 'sizes'. int64_t sliceCount = computeMaxLinearIndex(ratio); - SmallVector originalIndices(writeOp.indices().begin(), - writeOp.indices().end()); + SmallVector originalIndices(writeOp.getIndices().begin(), + writeOp.getIndices().end()); Value resultTensor; for (int64_t i = 0; i < sliceCount; i++) { SmallVector elementOffsets = getVectorOffset(originalSize, *targetShape, i); Value slicedVector = rewriter.create( - loc, writeOp.vector(), elementOffsets, *targetShape, strides); + loc, writeOp.getVector(), elementOffsets, *targetShape, strides); SmallVector indices = sliceTransferIndices(i, originalSize, *targetShape, originalIndices, - writeOp.permutation_map(), loc, rewriter); + writeOp.getPermutationMap(), loc, rewriter); Operation *slicedWrite = rewriter.create( - loc, slicedVector, resultTensor ? resultTensor : writeOp.source(), - indices, writeOp.permutation_mapAttr(), writeOp.in_boundsAttr()); + loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(), + indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); // For the tensor case update the destination for the next transfer write. if (!slicedWrite->getResults().empty()) resultTensor = slicedWrite->getResult(0); @@ -267,19 +267,21 @@ struct UnrollContractionPattern AffineMap lhsPermutationMap = contractOp.getIndexingMaps()[0]; SmallVector lhsOffets = applyPermutationMap(lhsPermutationMap, ArrayRef(offsets)); - extractOperand(0, contractOp.lhs(), lhsPermutationMap, lhsOffets); + extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets); // If there is a mask associated to lhs, extract it as well. if (slicesOperands.size() > 3) - extractOperand(3, contractOp.masks()[0], lhsPermutationMap, lhsOffets); + extractOperand(3, contractOp.getMasks()[0], lhsPermutationMap, + lhsOffets); // Extract the new rhs operand. AffineMap rhsPermutationMap = contractOp.getIndexingMaps()[1]; SmallVector rhsOffets = applyPermutationMap(rhsPermutationMap, ArrayRef(offsets)); - extractOperand(1, contractOp.rhs(), rhsPermutationMap, rhsOffets); + extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets); // If there is a mask associated to rhs, extract it as well. if (slicesOperands.size() > 4) - extractOperand(4, contractOp.masks()[1], rhsPermutationMap, rhsOffets); + extractOperand(4, contractOp.getMasks()[1], rhsPermutationMap, + rhsOffets); AffineMap accPermutationMap = contractOp.getIndexingMaps()[2]; SmallVector accOffets = @@ -290,7 +292,7 @@ struct UnrollContractionPattern if (accIt != accCache.end()) slicesOperands[2] = accIt->second; else - extractOperand(2, contractOp.acc(), accPermutationMap, accOffets); + extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets); SmallVector dstShape = applyPermutationMap(dstAffineMap, ArrayRef(*targetShape)); @@ -367,8 +369,8 @@ struct UnrollMultiReductionPattern // reduction loop keeps updating the accumulator. auto accIt = accCache.find(destOffset); if (accIt != accCache.end()) - result = makeArithReduction(rewriter, loc, reductionOp.kind(), result, - accIt->second); + result = makeArithReduction(rewriter, loc, reductionOp.getKind(), + result, accIt->second); accCache[destOffset] = result; } // Assemble back the accumulator into a single vector. @@ -451,7 +453,7 @@ struct PointwiseExtractPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ExtractMapOp extract, PatternRewriter &rewriter) const override { - Operation *definedOp = extract.vector().getDefiningOp(); + Operation *definedOp = extract.getVector().getDefiningOp(); if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) || definedOp->getNumResults() != 1) return failure(); @@ -467,7 +469,7 @@ struct PointwiseExtractPattern : public OpRewritePattern { loc, VectorType::get(extract.getResultType().getShape(), vecType.getElementType()), - operand.get(), extract.ids())); + operand.get(), extract.getIds())); } Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, definedOp, extractOperands, extract.getResultType()); @@ -482,7 +484,7 @@ struct ContractExtractPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ExtractMapOp extract, PatternRewriter &rewriter) const override { - Operation *definedOp = extract.vector().getDefiningOp(); + Operation *definedOp = extract.getVector().getDefiningOp(); auto contract = dyn_cast_or_null(definedOp); if (!contract) return failure(); @@ -514,7 +516,7 @@ struct ContractExtractPattern : public OpRewritePattern { VectorType newVecType = VectorType::get(operandShape, vecType.getElementType()); extractOperands.push_back(rewriter.create( - loc, newVecType, operand, extract.ids())); + loc, newVecType, operand, extract.getIds())); } Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands, @@ -554,11 +556,12 @@ struct TransferReadExtractPattern dyn_cast(*read.getResult().getUsers().begin()); if (!extract) return failure(); - if (read.mask()) + if (read.getMask()) return failure(); - SmallVector indices(read.indices().begin(), read.indices().end()); - AffineMap indexMap = extract.map().compose(read.permutation_map()); + SmallVector indices(read.getIndices().begin(), + read.getIndices().end()); + AffineMap indexMap = extract.map().compose(read.getPermutationMap()); unsigned idCount = 0; ImplicitLocOpBuilder lb(read.getLoc(), rewriter); for (auto it : @@ -574,14 +577,15 @@ struct TransferReadExtractPattern extract.getResultType().getDimSize(vectorPos), read.getContext()); indices[indexPos] = makeComposedAffineApply( rewriter, read.getLoc(), d0 + scale * d1, - {indices[indexPos], extract.ids()[idCount++]}); + {indices[indexPos], extract.getIds()[idCount++]}); } Value newRead = lb.create( - extract.getType(), read.source(), indices, read.permutation_mapAttr(), - read.padding(), read.mask(), read.in_boundsAttr()); + extract.getType(), read.getSource(), indices, + read.getPermutationMapAttr(), read.getPadding(), read.getMask(), + read.getInBoundsAttr()); Value dest = lb.create( read.getType(), rewriter.getZeroAttr(read.getType())); - newRead = lb.create(newRead, dest, extract.ids()); + newRead = lb.create(newRead, dest, extract.getIds()); rewriter.replaceOp(read, newRead); return success(); } @@ -597,14 +601,14 @@ struct TransferWriteInsertPattern if (write.getTransferRank() == 0) return failure(); - auto insert = write.vector().getDefiningOp(); + auto insert = write.getVector().getDefiningOp(); if (!insert) return failure(); - if (write.mask()) + if (write.getMask()) return failure(); - SmallVector indices(write.indices().begin(), - write.indices().end()); - AffineMap indexMap = insert.map().compose(write.permutation_map()); + SmallVector indices(write.getIndices().begin(), + write.getIndices().end()); + AffineMap indexMap = insert.map().compose(write.getPermutationMap()); unsigned idCount = 0; Location loc = write.getLoc(); for (auto it : @@ -619,13 +623,13 @@ struct TransferWriteInsertPattern auto scale = getAffineConstantExpr( insert.getSourceVectorType().getDimSize(vectorPos), write.getContext()); - indices[indexPos] = - makeComposedAffineApply(rewriter, loc, d0 + scale * d1, - {indices[indexPos], insert.ids()[idCount++]}); + indices[indexPos] = makeComposedAffineApply( + rewriter, loc, d0 + scale * d1, + {indices[indexPos], insert.getIds()[idCount++]}); } rewriter.create( - loc, insert.vector(), write.source(), indices, - write.permutation_mapAttr(), write.in_boundsAttr()); + loc, insert.getVector(), write.getSource(), indices, + write.getPermutationMapAttr(), write.getInBoundsAttr()); rewriter.eraseOp(write); return success(); } @@ -654,7 +658,7 @@ struct UnrollReductionPattern : public OpRewritePattern { getVectorOffset(originalSize, *targetShape, i); SmallVector strides(offsets.size(), 1); Value slicedOperand = rewriter.create( - loc, reductionOp.vector(), offsets, *targetShape, strides); + loc, reductionOp.getVector(), offsets, *targetShape, strides); Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, reductionOp, slicedOperand, reductionOp.getType()); Value result = newOp->getResult(0); @@ -664,7 +668,7 @@ struct UnrollReductionPattern : public OpRewritePattern { accumulator = result; } else { // On subsequent reduction, combine with the accumulator. - accumulator = makeArithReduction(rewriter, loc, reductionOp.kind(), + accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(), accumulator, result); } } diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp index 065848d..1346256 100644 --- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp @@ -264,7 +264,7 @@ public: return rewriter.notifyMatchFailure(op, "Unsupported vector type"); SmallVector transp; - for (auto attr : op.transp()) + for (auto attr : op.getTransp()) transp.push_back(attr.cast().getInt()); // Check whether the two source vector dimensions that are greater than one @@ -289,7 +289,7 @@ public: VectorType::get({n * m}, op.getVectorType().getElementType()); auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); auto reshInput = - ib.create(flattenedType, op.vector()); + ib.create(flattenedType, op.getVector()); reshInput = ib.create(reshInputType, reshInput); // Extract 1-D vectors from the higher-order dimension of the input diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 59b5891..03fc3a8 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -86,7 +86,7 @@ private: dstVec.getShape().end()); } if (auto writeOp = dyn_cast(op)) { - auto insert = writeOp.vector().getDefiningOp(); + auto insert = writeOp.getVector().getDefiningOp(); if (!insert) return llvm::None; ArrayRef shape = insert.getSourceVectorType().getShape();