[mlir][vector] Address post-commit review comments on vector ops folding patterns
authorThomas Raoux <thomasraoux@google.com>
Mon, 2 Nov 2020 18:18:38 +0000 (10:18 -0800)
committerThomas Raoux <thomasraoux@google.com>
Mon, 2 Nov 2020 18:57:32 +0000 (10:57 -0800)
Differential Revision: https://reviews.llvm.org/D90183

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index faae278..53cdf3f 100644 (file)
@@ -850,10 +850,12 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
     return Value();
   // Get the nth dimension size starting from lowest dimension.
   auto getDimReverse = [](VectorType type, int64_t n) {
-    return type.getDimSize(type.getRank() - n - 1);
+    return type.getShape().take_back(n+1).front();
   };
   int64_t destinationRank =
-      extractOp.getVectorType().getRank() - extractOp.position().size();
+      extractOp.getType().isa<VectorType>()
+          ? extractOp.getType().cast<VectorType>().getRank()
+          : 0;
   if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
     return Value();
   if (destinationRank > 0) {
@@ -861,6 +863,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
     for (int64_t i = 0; i < destinationRank; i++) {
       // The lowest dimension of of the destination must match the lowest
       // dimension of the shapecast op source.
+      // TODO: This case could be support in a canonicalization pattern.
       if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
           getDimReverse(destinationType, i))
         return Value();
@@ -891,6 +894,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
   }
   std::reverse(newStrides.begin(), newStrides.end());
   SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
+  // OpBuilder is only used as a helper to build an I64ArrayAttr.
   OpBuilder b(extractOp.getContext());
   extractOp.setAttr(ExtractOp::getPositionAttrName(),
                     b.getI64ArrayAttr(newPosition));
@@ -1632,8 +1636,8 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
 }
 
 // When the source of ExtractStrided comes from a chain of InsertStrided ops try
-// to use the source o the InsertStrided ops if we can detect that the extracted
-// vector is a subset of one of the vector inserted.
+// to use the source of the InsertStrided ops if we can detect that the
+// extracted vector is a subset of one of the vector inserted.
 static LogicalResult
 foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
   // Helper to extract integer out of ArrayAttr.
index b20accc..0090542 100644 (file)
@@ -160,20 +160,20 @@ func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
 
 // Case where we need to go through 2 level of insert element.
 // CHECK-LABEL: extract_strided_fold_insert
-//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
 //  CHECK-NEXT:   %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
-//  CHECK-SAME:     {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
+//  CHECK-SAME:     {offsets = [0, 0], sizes = [1, 1], strides = [1, 1]}
 //  CHECK-SAME:       : vector<1x4xf32> to vector<1x1xf32>
 //  CHECK-NEXT:   return %[[EXT]] : vector<1x1xf32>
-func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
+func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
                                   %c : vector<1x4xf32>) -> (vector<1x1xf32>) {
-  %0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
-    : vector<1x4xf32> into vector<2x4xf32>
+  %0 = vector.insert_strided_slice %b, %a {offsets = [0, 1], strides = [1, 1]}
+    : vector<1x4xf32> into vector<2x8xf32>
   %1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
-    : vector<1x4xf32> into vector<2x4xf32>
+    : vector<1x4xf32> into vector<2x8xf32>
   %2 = vector.extract_strided_slice %1
       {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
-        : vector<2x4xf32> to vector<1x1xf32>
+        : vector<2x8xf32> to vector<1x1xf32>
   return %2 : vector<1x1xf32>
 }