[mlir][vector] Fix crash in extractelement vec distribution
authorThomas Raoux <thomasraoux@google.com>
Wed, 11 Jan 2023 01:42:59 +0000 (01:42 +0000)
committerThomas Raoux <thomasraoux@google.com>
Wed, 11 Jan 2023 02:35:12 +0000 (02:35 +0000)
Prevent creating a vector of size 0 that would fail verifier.
Vector 1d with a single element should be treated like 0d vectors.

Differential Revision: https://reviews.llvm.org/D141452

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
mlir/test/Dialect/Vector/vector-warp-distribute.mlir

index 07e608a..c8b0fc4 100644 (file)
@@ -995,19 +995,20 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
     unsigned int operandNumber = operand->getOperandNumber();
     auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
     VectorType extractSrcType = extractOp.getVectorType();
-    bool is0dExtract = extractSrcType.getRank() == 0;
+    bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
     Type elType = extractSrcType.getElementType();
     VectorType distributedVecType;
-    if (!is0dExtract) {
+    if (!is0dOrVec1Extract) {
       assert(extractSrcType.getRank() == 1 &&
              "expected that extractelement src rank is 0 or 1");
+      if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
+        return failure();
       int64_t elementsPerLane =
           extractSrcType.getShape()[0] / warpOp.getWarpSize();
       distributedVecType = VectorType::get({elementsPerLane}, elType);
     } else {
       distributedVecType = extractSrcType;
     }
-
     // Yield source vector from warp op.
     Location loc = extractOp.getLoc();
     SmallVector<size_t> newRetIndices;
@@ -1019,9 +1020,17 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
 
     // 0d extract: The new warp op broadcasts the source vector to all lanes.
     // All lanes extract the scalar.
-    if (is0dExtract) {
-      Value newExtract =
-          rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+    if (is0dOrVec1Extract) {
+      Value newExtract;
+      if (extractSrcType.getRank() == 1) {
+        newExtract = rewriter.create<vector::ExtractElementOp>(
+            loc, distributedVec,
+            rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+      } else {
+        newExtract =
+            rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
+      }
       newWarpOp->getResult(operandNumber).replaceAllUsesWith(newExtract);
       return success();
     }
index 2dd5477..b5087fe 100644 (file)
@@ -761,6 +761,26 @@ func.func @vector_extractelement_0d(%laneid: index) -> (f32) {
 
 // -----
 
+// CHECK-PROP-LABEL: func.func @vector_extractelement_1element(
+//       CHECK-PROP:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK-PROP:   %[[R:.*]] = vector.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<1xf32>) {
+//       CHECK-PROP:     %[[V:.*]] = "some_def"() : () -> vector<1xf32>
+//       CHECK-PROP:     vector.yield %[[V]] : vector<1xf32>
+//       CHECK-PROP:   }
+//       CHECK-PROP:   %[[E:.*]] = vector.extractelement %[[R]][%[[C0]] : index] : vector<1xf32>
+//       CHECK-PROP:   return %[[E]] : f32
+func.func @vector_extractelement_1element(%laneid: index) -> (f32) {
+  %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
+    %0 = "some_def"() : () -> (vector<1xf32>)
+    %c0 = arith.constant 0 : index
+    %1 = vector.extractelement %0[%c0 : index] : vector<1xf32>
+    vector.yield %1 : f32
+  }
+  return %r : f32
+}
+
+// -----
+
 //       CHECK-PROP: #[[$map:.*]] = affine_map<()[s0] -> (s0 ceildiv 3)>
 //       CHECK-PROP: #[[$map1:.*]] = affine_map<()[s0] -> (s0 mod 3)>
 // CHECK-PROP-LABEL: func.func @vector_extractelement_1d(