auto getRank = [](Type type) {
return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
};
+ // If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
+ if (broadcastSrcRank == 0)
+ return source;
+
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank >= broadcastSrcRank)
return Value();
extractVecType.getShape() !=
broadcastVecType.getShape().take_back(extractResultRank))
return Value();
+
+ auto broadcastOp = cast<vector::BroadcastOp>(defOp);
+ int64_t rankDiff = broadcastSrcRank - extractResultRank;
+ // Detect all the positions that come from "dim-1" broadcasting.
+ // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
+ // extract position to `0` when extracting from the source operand.
+ llvm::SetVector<int64_t> broadcastedUnitDims =
+ broadcastOp.computeBroadcastedUnitDims();
auto extractPos = extractVector<int64_t>(extractOp.getPosition());
- unsigned rankDiff = broadcastSrcRank - extractResultRank;
+ for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
+ if (broadcastedUnitDims.contains(i))
+ extractPos[i] = 0;
+ // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
+ // matching extract position when extracting from the source operand.
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
- extractOp.setOperand(source);
// OpBuilder is only used as a helper to build an I64ArrayAttr.
OpBuilder b(extractOp.getContext());
+ extractOp.setOperand(source);
extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
b.getI64ArrayAttr(extractPos));
return extractOp.getResult();
// BroadcastOp
//===----------------------------------------------------------------------===//
+/// Return the dimensions of the result vector that were formerly ones in the
+/// source tensor and thus correspond to "dim-1" broadcasting.
+llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
+ VectorType srcVectorType = getSourceType().dyn_cast<VectorType>();
+ // Scalar broadcast is without any unit dim broadcast.
+ if (!srcVectorType)
+ return {};
+ ArrayRef<int64_t> srcShape = srcVectorType.getShape();
+ ArrayRef<int64_t> dstShape = getVectorType().getShape();
+ int64_t rankDiff = dstShape.size() - srcShape.size();
+ int64_t dstDim = rankDiff;
+ llvm::SetVector<int64_t> res;
+ for (auto [s1, s2] : llvm::zip(srcShape, dstShape.drop_front(rankDiff))) {
+ if (s1 != s2) {
+ assert(s1 == 1 && "expected dim-1 broadcasting");
+ res.insert(dstDim);
+ }
+ ++dstDim;
+ }
+ return res;
+}
+
BroadcastableToResult
mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
std::pair<int, int> *mismatchingDims) {
%1 = vector.transfer_read %0[%c0, %i4, %c0], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
return %1 : vector<4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @extract_from_broadcast
+func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
+ %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
+
+ // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1x1x1xf32>
+ // CHECK-NEXT: return %0 : vector<1xf32>
+ %1 = vector.extract %0[0, 0, 31] : vector<1x1x32x1xf32>
+ return %1: vector<1xf32>
+}