[mlir][vector] Fold extractOp coming from broadcastOp
authorThomas Raoux <thomasraoux@google.com>
Tue, 6 Oct 2020 16:56:35 +0000 (09:56 -0700)
committerThomas Raoux <thomasraoux@google.com>
Tue, 6 Oct 2020 17:27:39 +0000 (10:27 -0700)
Combine ExtractOp with scalar result with BroadcastOp source. This is useful to
be able to incrementally convert degenerated vector of one element into scalar.

Differential Revision: https://reviews.llvm.org/D88751

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 672ad40..b71102c 100644 (file)
@@ -812,6 +812,37 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
   return Value();
 }
 
+/// Fold extractOp with scalar result coming from BroadcastOp.
+static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcastOp)
+    return Value();
+  if (extractOp.getType() == broadcastOp.getSourceType())
+    return broadcastOp.source();
+  auto getRank = [](Type type) {
+    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+  };
+  unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
+  unsigned extractResultRank = getRank(extractOp.getType());
+  if (extractResultRank < broadcasrSrcRank) {
+    auto extractPos = extractVector<int64_t>(extractOp.position());
+    unsigned rankDiff = broadcasrSrcRank - extractResultRank;
+    extractPos.erase(
+        extractPos.begin(),
+        std::next(extractPos.begin(), extractPos.size() - rankDiff));
+    extractOp.setOperand(broadcastOp.source());
+    // OpBuilder is only used as a helper to build an I64ArrayAttr.
+    OpBuilder b(extractOp.getContext());
+    extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                      b.getI64ArrayAttr(extractPos));
+    return extractOp.getResult();
+  }
+  // TODO: In case the rank of the broadcast source is greater than the rank of
+  // the extract result this can be combined into a new broadcast op. This needs
+  // to be added a canonicalization pattern if needed.
+  return Value();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -819,6 +850,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return getResult();
   if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
     return val;
+  if (auto val = foldExtractFromBroadcast(*this))
+    return val;
   return OpFoldResult();
 }
 
index 9c36f76..2f927a1 100644 (file)
@@ -348,6 +348,54 @@ func @fold_extract_transpose(
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   return %[[A]] : f32
+func @fold_extract_broadcast(%a : f32) -> f32 {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast_vector
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   return %[[A]] : vector<4xf32>
+func @fold_extract_broadcast_vector(%a : vector<4xf32>) -> vector<4xf32> {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_broadcast
+//  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[A]][2] : vector<4xf32>
+//       CHECK:   return %[[R]] : f32
+func @fold_extract_broadcast(%a : vector<4xf32>) -> f32 {
+  %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
+// Negative test for extract_op folding when the type of broadcast source
+// doesn't match the type of vector.extract.
+// CHECK-LABEL: fold_extract_broadcast_negative
+//       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
+//       CHECK:   return %[[R]] : vector<4xf32>
+func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+  %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = constant 0 : index