[mlir][vector] NFC: Improve vector type accessor methods
authorLei Zhang <antiagainst@google.com>
Thu, 16 Feb 2023 04:08:15 +0000 (04:08 +0000)
committerLei Zhang <antiagainst@google.com>
Thu, 16 Feb 2023 04:08:33 +0000 (04:08 +0000)
Plain `getVectorType()` can be quite confusing and error-prone
given that, well, vector ops always work on vector types, and
it can commonly involve both source and result vectors. So this
commit makes various such accessor methods to be explicit w.r.t.
source or result vectors.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144159

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp

index 94d4b64..c5ebe9f 100644 (file)
@@ -321,7 +321,7 @@ def Vector_ReductionOp :
     ```
   }];
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
   }];
@@ -449,7 +449,7 @@ def Vector_BroadcastOp :
   }];
   let extraClassDeclaration = [{
     Type getSourceType() { return getSource().getType(); }
-    VectorType getVectorType() {
+    VectorType getResultVectorType() {
       return getVector().getType().cast<VectorType>();
     }
 
@@ -466,7 +466,7 @@ def Vector_BroadcastOp :
     /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
     /// the helper will assert. This means:
     ///   1. `dstShape` must not be empty.
-    ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
+    ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getResultVectorType)]
     ///   2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
     //       must match the `value` shape.
     static Value createOrFoldBroadcastOp(
@@ -537,7 +537,7 @@ def Vector_ShuffleOp :
     VectorType getV2VectorType() {
       return getV2().getType().cast<VectorType>();
     }
-    VectorType getVectorType() {
+    VectorType getResultVectorType() {
       return getVector().getType().cast<VectorType>();
     }
   }];
@@ -584,7 +584,7 @@ def Vector_ExtractElementOp :
     OpBuilder<(ins "Value":$source)>,
   ];
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
   }];
@@ -619,7 +619,7 @@ def Vector_ExtractOp :
   ];
   let extraClassDeclaration = [{
     static StringRef getPositionAttrStrName() { return "position"; }
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
@@ -996,7 +996,7 @@ def Vector_OuterProductOp :
         ? VectorType()
         : (*getAcc().begin()).getType().cast<VectorType>();
     }
-    VectorType getVectorType() {
+    VectorType getResultVectorType() {
       return getResult().getType().cast<VectorType>();
     }
     static constexpr StringRef getKindAttrStrName() {
@@ -1172,7 +1172,9 @@ def Vector_ExtractStridedSliceOp :
     static StringRef getOffsetsAttrStrName() { return "offsets"; }
     static StringRef getSizesAttrStrName() { return "sizes"; }
     static StringRef getStridesAttrStrName() { return "strides"; }
-    VectorType getVectorType(){ return getVector().getType().cast<VectorType>(); }
+    VectorType getSourceVectorType() {
+      return getVector().getType().cast<VectorType>(); 
+    }
     void getOffsets(SmallVectorImpl<int64_t> &results);
     bool hasNonUnitStrides() {
       return llvm::any_of(getStrides(), [](Attribute attr) {
@@ -2424,10 +2426,10 @@ def Vector_TransposeOp :
     OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
   ];
   let extraClassDeclaration = [{
-    VectorType getVectorType() {
+    VectorType getSourceVectorType() {
       return getVector().getType().cast<VectorType>();
     }
-    VectorType getResultType() {
+    VectorType getResultVectorType() {
       return getResult().getType().cast<VectorType>();
     }
     void getTransp(SmallVectorImpl<int64_t> &results);
index ece1751..d9533b4 100644 (file)
@@ -203,7 +203,7 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
 
 /// Return true if this is a broadcast from scalar to a 2D vector.
 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
-  return broadcastOp.getVectorType().getRank() == 2;
+  return broadcastOp.getResultVectorType().getRank() == 2;
 }
 
 /// Return true if this integer extend op can be folded into a contract op.
@@ -949,7 +949,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
 
   SmallVector<int64_t> sizes;
   populateFromInt64AttrArray(op.getSizes(), sizes);
-  ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
+  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
 
   // Compute offset in vector registers. Note that the mma.sync vector registers
   // are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
@@ -1045,7 +1045,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
   assert(broadcastSupportsMMAMatrixType(op));
 
   const char *fragType = inferFragType(op);
-  auto vecType = op.getVectorType();
+  auto vecType = op.getResultVectorType();
   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
       vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
   auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
index a1c1c3d..159bae8 100644 (file)
@@ -939,7 +939,7 @@ public:
     auto loc = shuffleOp->getLoc();
     auto v1Type = shuffleOp.getV1VectorType();
     auto v2Type = shuffleOp.getV2VectorType();
-    auto vectorType = shuffleOp.getVectorType();
+    auto vectorType = shuffleOp.getResultVectorType();
     Type llvmType = typeConverter->convertType(vectorType);
     auto maskArrayAttr = shuffleOp.getMask();
 
@@ -1002,7 +1002,7 @@ public:
   LogicalResult
   matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto vectorType = extractEltOp.getVectorType();
+    auto vectorType = extractEltOp.getSourceVectorType();
     auto llvmType = typeConverter->convertType(vectorType.getElementType());
 
     // Bail if result type cannot be lowered.
index 0dbf67e..a3a3a61 100644 (file)
@@ -83,7 +83,8 @@ struct VectorBroadcastConvert final
   LogicalResult
   matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
+    Type resultType =
+        getTypeConverter()->convertType(castOp.getResultVectorType());
     if (!resultType)
       return failure();
 
@@ -92,10 +93,10 @@ struct VectorBroadcastConvert final
       return success();
     }
 
-    SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
+    SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
                                  adaptor.getSource());
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
-        castOp, castOp.getVectorType(), source);
+        castOp, castOp.getResultVectorType(), source);
     return success();
   }
 };
@@ -405,7 +406,7 @@ struct VectorShuffleOpConvert final
   LogicalResult
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto oldResultType = shuffleOp.getVectorType();
+    auto oldResultType = shuffleOp.getResultVectorType();
     if (!spirv::CompositeType::isValid(oldResultType))
       return failure();
     Type newResultType = getTypeConverter()->convertType(oldResultType);
index 64125ef..8c6609d 100644 (file)
@@ -416,7 +416,7 @@ void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
 
 LogicalResult ReductionOp::verify() {
   // Verify for 0-D and 1-D vector.
-  int64_t rank = getVectorType().getRank();
+  int64_t rank = getSourceVectorType().getRank();
   if (rank > 1)
     return emitOpError("unsupported reduction rank: ") << rank;
 
@@ -465,7 +465,7 @@ void ReductionOp::print(OpAsmPrinter &p) {
 
 /// Returns the mask type expected by this operation.
 Type ReductionOp::getExpectedMaskType() {
-  auto vecType = getVectorType();
+  auto vecType = getSourceVectorType();
   return vecType.cloneWith(std::nullopt,
                            IntegerType::get(vecType.getContext(), /*width=*/1));
 }
@@ -515,7 +515,7 @@ Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
 }
 
 std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getVectorType().getShape());
+  return llvm::to_vector<4>(getSourceVectorType().getShape());
 }
 
 namespace {
@@ -530,7 +530,7 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
     if (maskableOp.isMasked())
       return failure();
 
-    auto vectorType = reductionOp.getVectorType();
+    auto vectorType = reductionOp.getSourceVectorType();
     if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
       return failure();
 
@@ -1074,7 +1074,7 @@ void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult vector::ExtractElementOp::verify() {
-  VectorType vectorType = getVectorType();
+  VectorType vectorType = getSourceVectorType();
   if (vectorType.getRank() == 0) {
     if (getPosition())
       return emitOpError("expected position to be empty with 0-D vector");
@@ -1167,13 +1167,14 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 
 LogicalResult vector::ExtractOp::verify() {
   auto positionAttr = getPosition().getValue();
-  if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
+  if (positionAttr.size() >
+      static_cast<unsigned>(getSourceVectorType().getRank()))
     return emitOpError(
         "expected position attribute of rank smaller than vector rank");
   for (const auto &en : llvm::enumerate(positionAttr)) {
     auto attr = en.value().dyn_cast<IntegerAttr>();
     if (!attr || attr.getInt() < 0 ||
-        attr.getInt() >= getVectorType().getDimSize(en.index()))
+        attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
       return emitOpError("expected position attribute #")
              << (en.index() + 1)
              << " to be a non-negative integer smaller than the corresponding "
@@ -1314,7 +1315,7 @@ private:
 
 ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
     ExtractOp e)
-    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
+    : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
       extractedRank(extractOp.getPosition().size()) {
   assert(vectorRank >= extractedRank && "extracted pos overflow");
   sentinels.reserve(vectorRank - extractedRank);
@@ -1510,7 +1511,8 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   int64_t stride = 1;
   for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
     strides.push_back(stride);
-    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
+    stride *=
+        getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
   }
 
   int64_t position = linearize(extractedPos, strides);
@@ -1552,7 +1554,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     size_t lastOffset = sliceOffsets.size() - 1;
     if (sliceOffsets.back() != 0 ||
         extractStridedSliceOp.getType().getDimSize(lastOffset) !=
-            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
+            extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
       break;
     sliceOffsets.pop_back();
   }
@@ -1561,8 +1563,8 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
     destinationRank = vecType.getRank();
   // The dimensions of the result need to be untouched by the
   // extractStridedSlice op.
-  if (destinationRank >
-      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
+  if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
+                            sliceOffsets.size())
     return Value();
   auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
   assert(extractedPos.size() >= sliceOffsets.size());
@@ -1827,7 +1829,7 @@ llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
   if (!srcVectorType)
     return {};
   return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
-                                      getVectorType().getShape());
+                                      getResultVectorType().getShape());
 }
 
 /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
@@ -1973,8 +1975,8 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
 
 LogicalResult BroadcastOp::verify() {
   std::pair<int, int> mismatchingDims;
-  BroadcastableToResult res =
-      isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
+  BroadcastableToResult res = isBroadcastableTo(
+      getSourceType(), getResultVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
@@ -1988,11 +1990,11 @@ LogicalResult BroadcastOp::verify() {
 }
 
 OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
-  if (getSourceType() == getVectorType())
+  if (getSourceType() == getResultVectorType())
     return getSource();
   if (!adaptor.getSource())
     return {};
-  auto vectorType = getVectorType();
+  auto vectorType = getResultVectorType();
   if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
     return DenseElementsAttr::get(vectorType, adaptor.getSource());
   if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
@@ -2011,8 +2013,9 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
     if (!srcBroadcast)
       return failure();
-    rewriter.replaceOpWithNewOp<BroadcastOp>(
-        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
+                                             broadcastOp.getResultVectorType(),
+                                             srcBroadcast.getSource());
     return success();
   }
 };
@@ -2035,7 +2038,7 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
 }
 
 LogicalResult ShuffleOp::verify() {
-  VectorType resultType = getVectorType();
+  VectorType resultType = getResultVectorType();
   VectorType v1Type = getV1VectorType();
   VectorType v2Type = getV2VectorType();
   // Verify ranks.
@@ -2143,7 +2146,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
     }
   }
 
-  return DenseElementsAttr::get(getVectorType(), results);
+  return DenseElementsAttr::get(getResultVectorType(), results);
 }
 
 namespace {
@@ -2764,7 +2767,7 @@ LogicalResult OuterProductOp::verify() {
   Type tRHS = getOperandTypeRHS();
   VectorType vLHS = getOperandVectorTypeLHS(),
              vRHS = tRHS.dyn_cast<VectorType>(),
-             vACC = getOperandVectorTypeACC(), vRES = getVectorType();
+             vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
 
   if (vLHS.getRank() != 1)
     return emitOpError("expected 1-d vector for operand #1");
@@ -2805,7 +2808,7 @@ LogicalResult OuterProductOp::verify() {
 /// Returns the mask type expected by this operation. Mostly used for
 /// verification purposes. It requires the operation to be vectorized."
 Type OuterProductOp::getExpectedMaskType() {
-  auto vecType = this->getVectorType();
+  auto vecType = this->getResultVectorType();
   return VectorType::get(vecType.getShape(),
                          IntegerType::get(vecType.getContext(), /*width=*/1));
 }
@@ -2913,7 +2916,7 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult ExtractStridedSliceOp::verify() {
-  auto type = getVectorType();
+  auto type = getSourceVectorType();
   auto offsets = getOffsetsAttr();
   auto sizes = getSizesAttr();
   auto strides = getStridesAttr();
@@ -2944,8 +2947,8 @@ LogicalResult ExtractStridedSliceOp::verify() {
                                                     /*halfOpen=*/false)))
     return failure();
 
-  auto resultType =
-      inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
+  auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
+                                                  offsets, sizes, strides);
   if (getResult().getType() != resultType)
     return emitOpError("expected result type to be ") << resultType;
 
@@ -2966,7 +2969,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
   ArrayAttr extractSizes = op.getSizes();
   auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
   while (insertOp) {
-    if (op.getVectorType().getRank() !=
+    if (op.getSourceVectorType().getRank() !=
         insertOp.getSourceVectorType().getRank())
       return failure();
     ArrayAttr insertOffsets = insertOp.getOffsets();
@@ -3020,7 +3023,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
 }
 
 OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
-  if (getVectorType() == getResult().getType())
+  if (getSourceVectorType() == getResult().getType())
     return getVector();
   if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
     return getResult();
@@ -5113,7 +5116,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // Eliminate splat constant transpose ops.
   if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
     if (attr.isSplat())
-      return attr.reshape(getResultType());
+      return attr.reshape(getResultVectorType());
 
   // Eliminate identity transpose ops. This happens when the dimensions of the
   // input vector remain in their original order after the transpose operation.
@@ -5131,8 +5134,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
 }
 
 LogicalResult vector::TransposeOp::verify() {
-  VectorType vectorType = getVectorType();
-  VectorType resultType = getResultType();
+  VectorType vectorType = getSourceVectorType();
+  VectorType resultType = getResultVectorType();
   int64_t rank = resultType.getRank();
   if (vectorType.getRank() != rank)
     return emitOpError("vector result rank mismatch: ") << rank;
@@ -5156,7 +5159,7 @@ LogicalResult vector::TransposeOp::verify() {
 }
 
 std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
-  return llvm::to_vector<4>(getResultType().getShape());
+  return llvm::to_vector<4>(getResultVectorType().getShape());
 }
 
 namespace {
@@ -5215,7 +5218,7 @@ struct FoldTransposedScalarBroadcast final
     auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
     if (!srcVectorType || srcVectorType.getNumElements() == 1) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          transposeOp, transposeOp.getResultType(), bcastOp.getSource());
+          transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
       return success();
     }
 
@@ -5235,7 +5238,7 @@ public:
       return failure();
 
     rewriter.replaceOpWithNewOp<vector::SplatOp>(
-        transposeOp, transposeOp.getResultType(), splatOp.getInput());
+        transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
     return success();
   }
 };
index 6005f37..a242cca 100644 (file)
@@ -897,7 +897,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
-    VectorType extractSrcType = extractOp.getVectorType();
+    VectorType extractSrcType = extractOp.getSourceVectorType();
     Location loc = extractOp.getLoc();
 
     // "vector.extract %v[] : vector<f32>" is an invalid op.
@@ -930,7 +930,7 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
       SmallVector<size_t> newRetIndices;
       WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
           rewriter, warpOp, {extractOp.getVector()},
-          {extractOp.getVectorType()}, newRetIndices);
+          {extractOp.getSourceVectorType()}, newRetIndices);
       rewriter.setInsertionPointAfter(newWarpOp);
       Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
       // Extract from distributed vector.
@@ -994,7 +994,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
       return failure();
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
-    VectorType extractSrcType = extractOp.getVectorType();
+    VectorType extractSrcType = extractOp.getSourceVectorType();
     bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
     Type elType = extractSrcType.getElementType();
     VectorType distributedVecType;
index 4918abb..722db6a 100644 (file)
@@ -48,7 +48,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
     // vector.extract_strided_slice requires the input and output vector to have
     // the same rank. Here we drop leading one dimensions from the input vector
     // type to make sure we don't cause mismatch.
-    VectorType oldSrcType = extractOp.getVectorType();
+    VectorType oldSrcType = extractOp.getSourceVectorType();
     VectorType newSrcType = trimLeadingOneDims(oldSrcType);
 
     if (newSrcType.getRank() == oldSrcType.getRank())
index 357582b..b510246 100644 (file)
@@ -266,7 +266,7 @@ public:
   LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    VectorType dstType = op.getVectorType();
+    VectorType dstType = op.getResultVectorType();
     VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
     Type eltType = dstType.getElementType();
 
@@ -404,8 +404,8 @@ public:
     auto loc = op.getLoc();
 
     Value input = op.getVector();
-    VectorType inputType = op.getVectorType();
-    VectorType resType = op.getResultType();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
 
     // Set up convenience transposition table.
     SmallVector<int64_t> transp;
@@ -492,7 +492,7 @@ public:
                                 PatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
 
-    VectorType srcType = op.getVectorType();
+    VectorType srcType = op.getSourceVectorType();
     if (srcType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
 
@@ -518,8 +518,8 @@ public:
 
     Value shuffled =
         rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     shuffled);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), shuffled);
 
     return success();
   }
@@ -552,7 +552,7 @@ public:
 
     VectorType lhsType = op.getOperandVectorTypeLHS();
     VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
-    VectorType resType = op.getVectorType();
+    VectorType resType = op.getResultVectorType();
     Type eltType = resType.getElementType();
     bool isInt = eltType.isa<IntegerType, IndexType>();
     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
@@ -1208,15 +1208,16 @@ struct CombineContractBroadcast
         continue;
       // contractionOp can only take vector as operands.
       auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
-      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
+      if (!srcType ||
+          srcType.getRank() == broadcast.getResultVectorType().getRank())
         continue;
       int64_t rankDiff =
-          broadcast.getVectorType().getRank() - srcType.getRank();
+          broadcast.getResultVectorType().getRank() - srcType.getRank();
       bool innerDimBroadcast = false;
       SmallVector<AffineExpr> originalDims;
       for (const auto &dim : llvm::enumerate(srcType.getShape())) {
-        if (dim.value() !=
-            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
+        if (dim.value() != broadcast.getResultVectorType().getDimSize(
+                               rankDiff + dim.index())) {
           innerDimBroadcast = true;
           break;
         }
@@ -1232,7 +1233,7 @@ struct CombineContractBroadcast
       // of non-unit size.
       bool nonUnitDimReductionBroadcast = false;
       for (int64_t i = 0; i < rankDiff; ++i) {
-        if (broadcast.getVectorType().getDimSize(i) != 1 &&
+        if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
             isReductionIterator(contractOp.getIteratorTypes()
                                     .getValue()[map.getDimPosition(i)])) {
           nonUnitDimReductionBroadcast = true;
@@ -1243,8 +1244,8 @@ struct CombineContractBroadcast
         continue;
 
       AffineMap broadcastMap =
-          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
-                         contractOp.getContext());
+          AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
+                         originalDims, contractOp.getContext());
       map = broadcastMap.compose(map);
       *operand = broadcast.getSource();
       changed = true;
@@ -1363,7 +1364,7 @@ struct ReorderElementwiseOpsOnTranspose final
       auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
       if (transposeOp) {
         transposeMaps.push_back(transposeOp.getTransp());
-        srcType = transposeOp.getVectorType();
+        srcType = transposeOp.getSourceVectorType();
       } else if (!matchPattern(operand, m_Constant())) {
         return failure();
       }
@@ -2376,7 +2377,7 @@ struct BubbleDownVectorBitCastForExtract
   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                 PatternRewriter &rewriter) const override {
     // Only support extracting scalars for now.
-    if (extractOp.getVectorType().getRank() != 1)
+    if (extractOp.getSourceVectorType().getRank() != 1)
       return failure();
 
     auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
@@ -2468,7 +2469,7 @@ struct BubbleDownBitCastForStridedSliceExtract
                      [](const APInt &val) { return !val.isOneValue(); }))
       return failure();
 
-    unsigned rank = extractOp.getVectorType().getRank();
+    unsigned rank = extractOp.getSourceVectorType().getRank();
     assert(castDstLastDim % castSrcLastDim == 0);
     int64_t expandRatio = castDstLastDim / castSrcLastDim;
 
index 0872d10..b2d0025 100644 (file)
@@ -602,12 +602,12 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    if (transposeOp.getResultType().getRank() == 0)
+    if (transposeOp.getResultVectorType().getRank() == 0)
       return failure();
     auto targetShape = getTargetShape(options, transposeOp);
     if (!targetShape)
       return failure();
-    auto originalVectorType = transposeOp.getResultType();
+    auto originalVectorType = transposeOp.getResultVectorType();
     SmallVector<int64_t> strides(targetShape->size(), 1);
     Location loc = transposeOp.getLoc();
     ArrayRef<int64_t> originalSize = originalVectorType.getShape();
index 8705329..1558000 100644 (file)
@@ -252,7 +252,7 @@ public:
 
     // Check if the source vector type is supported. AVX2 patterns can only be
     // applied to f32 vector types with two dimensions greater than one.
-    VectorType srcType = op.getVectorType();
+    VectorType srcType = op.getSourceVectorType();
     if (!srcType.getElementType().isF32())
       return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
 
@@ -287,7 +287,7 @@ public:
       // Reshape the n-D input vector with only two dimensions greater than one
       // to a 2-D vector.
       auto flattenedType =
-          VectorType::get({n * m}, op.getVectorType().getElementType());
+          VectorType::get({n * m}, op.getSourceVectorType().getElementType());
       auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
       auto reshInput =
           ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
@@ -315,7 +315,7 @@ public:
       // We have to transpose their dimensions and retrieve its original rank
       // (e.g., 1x8x1x4x1).
       res = ib.create<vector::ShapeCastOp>(flattenedType, res);
-      res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
+      res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
       rewriter.replaceOp(op, res);
       return success();
     };