unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
VectorType extractSrcType = extractOp.getVectorType();
- bool is0dExtract = extractSrcType.getRank() == 0;
+ bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
Type elType = extractSrcType.getElementType();
VectorType distributedVecType;
- if (!is0dExtract) {
+ if (!is0dOrVec1Extract) {
assert(extractSrcType.getRank() == 1 &&
"expected that extractelement src rank is 0 or 1");
+ if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
+ return failure();
int64_t elementsPerLane =
extractSrcType.getShape()[0] / warpOp.getWarpSize();
distributedVecType = VectorType::get({elementsPerLane}, elType);
} else {
distributedVecType = extractSrcType;
}
-
// Yield source vector from warp op.
Location loc = extractOp.getLoc();
SmallVector<size_t> newRetIndices;
// 0d or single-element 1d extract: The new warp op broadcasts the source
// vector to all lanes. All lanes extract the scalar.
- if (is0dExtract) {
- Value newExtract =
- rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ if (is0dOrVec1Extract) {
+ Value newExtract;
+ if (extractSrcType.getRank() == 1) {
+ newExtract = rewriter.create<vector::ExtractElementOp>(
+ loc, distributedVec,
+ rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+ } else {
+ newExtract =
+ rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+ }
newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
return success();
}
// -----
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
+// CHECK-PROP: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-PROP: %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<1xf32>
+// CHECK-PROP: vector.yield %[[V]] : vector<1xf32>
+// CHECK-PROP: }
+// CHECK-PROP: %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32>
+// CHECK-PROP: return %[[E]] : f32
+func.func @vector_extractelement_1element(%laneid: index) -> (f32) { // Test input: extractelement from a single-element 1-D vector (vector<1xf32>) inside a warp op.
+ %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) { // Before the rewrite the region yields the extracted f32; the CHECK-PROP lines expect it to yield the whole vector<1xf32> instead.
+ %0 = "some_def"() : () -> (vector<1xf32>) // Opaque producer of the source vector.
+ %c0 = arith.constant 0 : index // Index 0 is the only in-bounds position of vector<1xf32>.
+ %1 = vector.extractelement %0[%c0 : index] : vector<1xf32> // The CHECK-PROP lines expect this extract to end up after the warp op, applied to its vector result.
+ vector.yield %1 : f32
+ }
+ return %r : f32
+}
+
+// -----
+
// CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
// CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
// CHECK-PROP-LABEL: func.func @vector_extractelement_1d(