[mlir][vector] Fix extract op canonicalization for 0d vector
authorThomas Raoux <thomasraoux@google.com>
Tue, 17 Jan 2023 17:25:51 +0000 (17:25 +0000)
committerThomas Raoux <thomasraoux@google.com>
Tue, 17 Jan 2023 17:25:51 +0000 (17:25 +0000)
Fix ExtractOpFromBroadcast when the broadcast source is a 0d vector.

Differential Revision: https://reviews.llvm.org/D141735

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index db2cd82..387f003 100644 (file)
@@ -1632,6 +1632,13 @@ public:
     // folding patterns.
     if (extractResultRank < broadcastSrcRank)
       return failure();
+
+    // Special case if broadcast src is a 0D vector.
+    if (extractResultRank == 0) {
+      assert(broadcastSrcRank == 0 && source.getType().isa<VectorType>());
+      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      return success();
+    }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
index 1990b89..081d94e 100644 (file)
@@ -519,6 +519,18 @@ func.func @fold_extract_broadcast(%a : f32) -> f32 {
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast_0dvec
+//  CHECK-SAME:   %[[A:.*]]: vector<f32>
+//       CHECK:   %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+//       CHECK:   return %[[B]] : f32
+func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+  %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_negative
 //       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 //       CHECK:   vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>