[mlir][vector] Add folding for ExtractOp with ShapeCastOp source
authorThomas Raoux <thomasraoux@google.com>
Fri, 23 Oct 2020 18:53:38 +0000 (11:53 -0700)
committerThomas Raoux <thomasraoux@google.com>
Fri, 23 Oct 2020 19:06:18 +0000 (12:06 -0700)
Differential Revision: https://reviews.llvm.org/D89853

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index b71102c..d1deb5a 100644 (file)
@@ -843,6 +843,61 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
   return Value();
 }
 
+// Fold extractOp with source coming from ShapeCast op.
+static Value foldExtractFromShapeCast(ExtractOp extractOp) {
+  auto shapeCastOp = extractOp.vector().getDefiningOp<vector::ShapeCastOp>();
+  if (!shapeCastOp)
+    return Value();
+  // Get the nth dimension size starting from lowest dimension.
+  auto getDimReverse = [](VectorType type, int64_t n) {
+    return type.getDimSize(type.getRank() - n - 1);
+  };
+  int64_t destinationRank =
+      extractOp.getVectorType().getRank() - extractOp.position().size();
+  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
+    return Value();
+  if (destinationRank > 0) {
+    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
+    for (int64_t i = 0; i < destinationRank; i++) {
+      // The lowest dimension of of the destination must match the lowest
+      // dimension of the shapecast op source.
+      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
+          getDimReverse(destinationType, i))
+        return Value();
+    }
+  }
+  // Extract the strides associated with the extract op vector source. Then use
+  // this to calculate a linearized position for the extract.
+  auto extractedPos = extractVector<int64_t>(extractOp.position());
+  std::reverse(extractedPos.begin(), extractedPos.end());
+  SmallVector<int64_t, 4> strides;
+  int64_t stride = 1;
+  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
+    strides.push_back(stride);
+    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
+  }
+
+  int64_t position = linearize(extractedPos, strides);
+  // Then extract the strides assoociated to the shapeCast op vector source and
+  // delinearize the position using those strides.
+  SmallVector<int64_t, 4> newStrides;
+  int64_t numDimension =
+      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
+  stride = 1;
+  for (int64_t i = 0; i < numDimension; i++) {
+    newStrides.push_back(stride);
+    stride *=
+        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
+  }
+  std::reverse(newStrides.begin(), newStrides.end());
+  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+  OpBuilder b(extractOp.getContext());
+  extractOp.setAttr(ExtractOp::getPositionAttrName(),
+                    b.getI64ArrayAttr(newPosition));
+  extractOp.setOperand(shapeCastOp.source());
+  return extractOp.getResult();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -852,6 +907,8 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return val;
   if (auto val = foldExtractFromBroadcast(*this))
     return val;
+  if (auto val = foldExtractFromShapeCast(*this))
+    return val;
   return OpFoldResult();
 }
 
index 2f927a1..66bad06 100644 (file)
@@ -396,6 +396,39 @@ func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
 
 // -----
 
+// CHECK-LABEL: func @fold_extract_shapecast
+//  CHECK-SAME: (%[[A0:.*]]: vector<5x1x3x2xf32>, %[[A1:.*]]: vector<8x4x2xf32>
+//       CHECK:   %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
+//       CHECK:   %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
+//       CHECK:   %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
+//       CHECK:   return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
+func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
+                             %arg1 : vector<8x4x2xf32>)
+  -> (f32, vector<2xf32>, vector<4x2xf32>) {
+  %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
+  %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
+  %r1 = vector.extract %0[4, 1] : vector<15x2xf32>
+  %r2 = vector.extract %0[5] : vector<15x2xf32>
+  %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
+  return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: fold_extract_shapecast_negative
+//       CHECK:   %[[V:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<2x4x2xf32>
+//       CHECK:   %[[R:.*]] = vector.extract %[[V]][1] : vector<2x4x2xf32>
+//       CHECK:   return %[[R]] : vector<4x2xf32>
+func @fold_extract_shapecast_negative(%arg0 : vector<16xf32>,
+                             %arg1 : vector<8x4x2xf32>) -> vector<4x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<2x4x2xf32>
+  %r = vector.extract %0[1] : vector<2x4x2xf32>
+  return %r : vector<4x2xf32>
+}
+
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = constant 0 : index