[mlir][vector] Fix folding of vector.extract from vector.broadcast
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 28 Nov 2022 14:12:03 +0000 (06:12 -0800)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 28 Nov 2022 15:17:31 +0000 (07:17 -0800)
This revision fixes a bug in the vector.extract folding that was missing
handling the "dim-1" broadcasting case in vector.broadcast.

Differential Revision: https://reviews.llvm.org/D138804

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 5060d8c..edaf78b 100644 (file)
@@ -445,6 +445,10 @@ def Vector_BroadcastOp :
     VectorType getVectorType() {
       return getVector().getType().cast<VectorType>();
     }
+
+    /// Return the dimensions of the result vector that were formerly ones in the
+    /// source tensor and thus correspond to "dim-1" broadcasting.
+    llvm::SetVector<int64_t> computeBroadcastedUnitDims();
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
index 2f9bca6..9e1b630 100644 (file)
@@ -1351,7 +1351,11 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   auto getRank = [](Type type) {
     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
   };
+  // If splat or broadcast from a scalar, just return the source scalar.
   unsigned broadcastSrcRank = getRank(source.getType());
+  if (broadcastSrcRank == 0)
+    return source;
+
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
@@ -1362,13 +1366,25 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
       extractVecType.getShape() !=
           broadcastVecType.getShape().take_back(extractResultRank))
     return Value();
+
+  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
+  int64_t rankDiff = broadcastSrcRank - extractResultRank;
+  // Detect all the positions that come from "dim-1" broadcasting.
+  // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
+  // extract position to `0` when extracting from the source operand.
+  llvm::SetVector<int64_t> broadcastedUnitDims =
+      broadcastOp.computeBroadcastedUnitDims();
   auto extractPos = extractVector<int64_t>(extractOp.getPosition());
-  unsigned rankDiff = broadcastSrcRank - extractResultRank;
+  for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
+    if (broadcastedUnitDims.contains(i))
+      extractPos[i] = 0;
+  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
+  // matching extract position when extracting from the source operand.
   extractPos.erase(extractPos.begin(),
                    std::next(extractPos.begin(), extractPos.size() - rankDiff));
-  extractOp.setOperand(source);
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
+  extractOp.setOperand(source);
   extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                      b.getI64ArrayAttr(extractPos));
   return extractOp.getResult();
@@ -1683,6 +1699,28 @@ Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
+/// Return the dimensions of the result vector that were formerly ones in the
+/// source tensor and thus correspond to "dim-1" broadcasting.
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+  VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
+  // Scalar broadcast is without any unit dim broadcast.
+  if (!srcVectorType)
+    return {};
+  ArrayRef<int64_t> srcShape = srcVectorType.getShape();
+  ArrayRef<int64_t> dstShape = getVectorType().getShape();
+  int64_t rankDiff = dstShape.size() - srcShape.size();
+  int64_t dstDim = rankDiff;
+  llvm::SetVector<int64_t> res;
+  for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) {
+    if (s1 != s2) {
+      assert(s1 == 1 && "expected dim-1 broadcasting");
+      res.insert(dstDim);
+    }
+    ++dstDim;
+  }
+  return res;
+}
+
 BroadcastableToResult
 mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
                                 std::pair<int, int> *mismatchingDims) {
index 7aabcec..872767c 100644 (file)
@@ -2020,3 +2020,15 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
   %1 = vector.transfer_read %0[%c0, %i4, %c0], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
   return %1 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+  %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+
+  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1x1x1xf32>
+  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  %1 = vector.extract %0[0, 0, 31] : vector<1x1x32x1xf32>
+  return %1: vector<1xf32>
+}