[mlir][Vector] Generalize and improve folding of ExtractOp from Insert/Transpose...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 14 Jan 2022 16:08:14 +0000 (16:08 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Mon, 17 Jan 2022 16:05:23 +0000 (16:05 +0000)
This revision fixes a bug where the iterative algorithm would walk back def-use chains to an incorrect operand.
This exposed opportunities for a larger refactoring and behavior improvement.
The new algorithm has improved folding behavior and proceeds by tracking both the
permutation of the extraction position and the internal vector permutation.
Multiple partial intersection cases with a candidate insertOp are supported.

The refactoring of the implementation should also help it generalize to strided insert/extract op.

This also subsumes the previous `foldExtractOpFromTranspose` which is now a simple special case and can be deleted.

Differential Revision: https://reviews.llvm.org/D117322

mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 224bccf..eaa4f4e 100644 (file)
@@ -946,7 +946,7 @@ static LogicalResult verify(vector::ExtractOp op) {
 }
 
 template <typename IntType>
-static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(llvm::map_range(
       arrayAttr.getAsRange<IntegerAttr>(),
       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
@@ -960,12 +960,12 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
 
   SmallVector<int64_t, 4> globalPosition;
   ExtractOp currentOp = extractOp;
-  auto extractedPos = extractVector<int64_t>(currentOp.position());
-  globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
+  auto extrPos = extractVector<int64_t>(currentOp.position());
+  globalPosition.append(extrPos.rbegin(), extrPos.rend());
   while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
     currentOp = nextOp;
-    auto extractedPos = extractVector<int64_t>(currentOp.position());
-    globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
+    auto extrPos = extractVector<int64_t>(currentOp.position());
+    globalPosition.append(extrPos.rbegin(), extrPos.rend());
   }
   extractOp.setOperand(currentOp.vector());
   // OpBuilder is only used as a helper to build an I64ArrayAttr.
@@ -976,144 +976,219 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
   return success();
 }
 
-/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
-static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
-  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
-  if (!transposeOp)
-    return failure();
+namespace {
+/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
+/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
+/// Compose TransposeOp permutations as we walk back.
+/// This helper class keeps an updated extraction position `extractPosition`
+/// with extra trailing sentinels.
+/// The sentinels encode the internal transposition status of the result vector.
+/// As we iterate, extractPosition is permuted and updated.
+class ExtractFromInsertTransposeChainState {
+public:
+  ExtractFromInsertTransposeChainState(ExtractOp e);
+
+  /// Iterate over producing insert and transpose ops until we find a fold.
+  Value fold();
+
+private:
+  /// Return true if the vector at position `a` is contained within the vector
+  /// at position `b`. Under insert/extract semantics, this is the same as `a`
+  /// is a prefix of `b`.
+  template <typename ContainerA, typename ContainerB>
+  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
+    return a.size() <= b.size() &&
+           std::equal(a.begin(), a.begin() + a.size(), b.begin());
+  }
 
-  auto permutation = extractVector<unsigned>(transposeOp.transp());
-  auto extractedPos = extractVector<int64_t>(extractOp.position());
+  /// Return true if the vector at position `a` intersects the vector at
+  /// position `b`. Under insert/extract semantics, this is the same as equality
+  /// of all entries of `a` that are >=0 with the corresponding entries of b.
+  /// Comparison is on the common prefix (i.e. zip).
+  template <typename ContainerA, typename ContainerB>
+  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
+    for (auto it : llvm::zip(a, b)) {
+      if (std::get<0>(it) < 0 || std::get<0>(it) < 0)
+        continue;
+      if (std::get<0>(it) != std::get<1>(it))
+        return false;
+    }
+    return true;
+  }
+
+  /// Folding is only possible in the absence of an internal permutation in the
+  /// result vector.
+  bool canFold() {
+    return (sentinels ==
+            makeArrayRef(extractPosition).drop_front(extractedRank));
+  }
+
+  // Helper to get the next defining op of interest.
+  void updateStateForNextIteration(Value v) {
+    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
+    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
+  };
+
+  // Case 1. If we hit a transpose, just compose the map and iterate.
+  // Invariant: insert + transpose do not change rank, we can always compose.
+  LogicalResult handleTransposeOp();
+
+  // Case 2: the insert position matches extractPosition exactly, early return.
+  LogicalResult handleInsertOpWithMatchingPos(Value &res);
+
+  /// Case 3: if the insert position is a prefix of extractPosition, extract a
+  /// portion of the source of the insert.
+  /// Example:
+  /// ```
+  /// %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
+  /// // extractPosition == [1, 2, 3]
+  /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
+  /// // can fold to vector.extract %source[0, 3]
+  /// %ext = vector.extract %source[3]: vector<5x6>
+  /// ```
+  /// To traverse through %source, we need to set the leading dims to 0 and
+  /// drop the extra leading dims.
+  /// This method updates the internal state.
+  LogicalResult handleInsertOpWithPrefixPos(Value &res);
+
+  /// Try to fold in place to extract(source, extractPosition) and return the
+  /// folded result. Return null if folding is not possible (e.g. due to an
+  /// internal tranposition in the result).
+  Value tryToFoldExtractOpInPlace(Value source);
+
+  ExtractOp extractOp;
+  int64_t vectorRank;
+  int64_t extractedRank;
+
+  InsertOp nextInsertOp;
+  TransposeOp nextTransposeOp;
+
+  /// Sentinel values that encode the internal permutation status of the result.
+  /// They are set to (-1, ... , -k) at the beginning and appended to
+  /// `extractPosition`.
+  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
+  /// ensure that there is no internal transposition.
+  /// Internal transposition cannot be accounted for with a folding pattern.
+  // TODO: We could relax the internal transposition with an extra transposition
+  // operation in a future canonicalizer.
+  SmallVector<int64_t> sentinels;
+  SmallVector<int64_t> extractPosition;
+};
+} // namespace
 
-  // If transposition permutation is larger than the ExtractOp, all minor
-  // dimensions must be an identity for folding to occur. If not, individual
-  // elements within the extracted value are transposed and this is not just a
-  // simple folding.
-  unsigned minorRank = permutation.size() - extractedPos.size();
-  MLIRContext *ctx = extractOp.getContext();
-  AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
-  AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
-  if (minorMap && !minorMap.isMinorIdentity())
+ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
+    ExtractOp e)
+    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
+      extractedRank(extractOp.position().size()) {
+  assert(vectorRank >= extractedRank && "extracted pos overflow");
+  sentinels.reserve(vectorRank - extractedRank);
+  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
+    sentinels.push_back(-(i + 1));
+  extractPosition = extractVector<int64_t>(extractOp.position());
+  llvm::append_range(extractPosition, sentinels);
+}
+
+// Case 1. If we hit a transpose, just compose the map and iterate.
+// Invariant: insert + transpose do not change rank, we can always compose.
+LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
+  if (!nextTransposeOp)
     return failure();
+  auto permutation = extractVector<unsigned>(nextTransposeOp.transp());
+  AffineMap m = inversePermutation(
+      AffineMap::getPermutationMap(permutation, extractOp.getContext()));
+  extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
+  return success();
+}
 
-  //   %1 = transpose %0[x, y, z] : vector<axbxcxf32>
-  //   %2 = extract %1[u, v] : vector<..xf32>
-  // may turn into:
-  //   %2 = extract %0[w, x] : vector<..xf32>
-  // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
-  // -1 denotes the inverse.
-  permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
-  // The major submap has fewer results but the same number of dims. To compose
-  // cleanly, we need to drop dims to form a "square matrix". This is possible
-  // because:
-  //   (a) this is a permutation map and
-  //   (b) the minor map has already been checked to be identity.
-  // Therefore, the major map cannot contain dims of position greater or equal
-  // than the number of results.
-  assert(llvm::all_of(permutationMap.getResults(),
-                      [&](AffineExpr e) {
-                        auto dim = e.dyn_cast<AffineDimExpr>();
-                        return dim && dim.getPosition() <
-                                          permutationMap.getNumResults();
-                      }) &&
-         "Unexpected map results depend on higher rank positions");
-  // Project on the first domain dimensions to allow composition.
-  permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
-                                  permutationMap.getResults(), ctx);
-
-  extractOp.setOperand(transposeOp.vector());
-  // Compose the inverse permutation map with the extractedPos.
-  auto newExtractedPos =
-      inversePermutation(permutationMap).compose(extractedPos);
-  // OpBuilder is only used as a helper to build an I64ArrayAttr.
-  OpBuilder b(extractOp.getContext());
-  extractOp->setAttr(ExtractOp::getPositionAttrName(),
-                     b.getI64ArrayAttr(newExtractedPos));
+// Case 2: the insert position matches extractPosition exactly, early return.
+LogicalResult
+ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
+    Value &res) {
+  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+  if (makeArrayRef(insertedPos) !=
+      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
+    return failure();
+  // Case 2.a. early-exit fold.
+  res = nextInsertOp.source();
+  // Case 2.b. if internal transposition is present, canFold will be false.
+  return success();
+}
 
+/// Case 3: if inserted position is a prefix of extractPosition,
+/// extract a portion of the source of the insertion.
+/// This method updates the internal state.
+LogicalResult
+ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
+  auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+  if (!isContainedWithin(insertedPos, extractPosition))
+    return failure();
+  // Set leading dims to zero.
+  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
+  // Drop extra leading dims.
+  extractPosition.erase(extractPosition.begin(),
+                        extractPosition.begin() + insertedPos.size());
+  extractedRank = extractPosition.size() - sentinels.size();
+  // Case 3.a. early-exit fold (break and delegate to post-while path).
+  res = nextInsertOp.source();
+  // Case 3.b. if internal transposition is present, canFold will be false.
   return success();
 }
 
-/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
-/// result is always the input to some InsertOp.
-static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
-  MLIRContext *context = extractOp.getContext();
-  AffineMap permutationMap;
-  auto extractedPos = extractVector<unsigned>(extractOp.position());
-  // Walk back a chain of InsertOp/TransposeOp until we hit a match.
-  // Compose TransposeOp permutations as we walk back.
-  auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
-  auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
-  while (insertOp || transposeOp) {
-    if (transposeOp) {
-      // If it is transposed, compose the map and iterate.
-      auto permutation = extractVector<unsigned>(transposeOp.transp());
-      AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
-      if (!permutationMap)
-        permutationMap = newMap;
-      else if (newMap.getNumInputs() != permutationMap.getNumResults())
-        return Value();
-      else
-        permutationMap = newMap.compose(permutationMap);
-      // Compute insert/transpose for the next iteration.
-      Value transposed = transposeOp.vector();
-      insertOp = transposed.getDefiningOp<vector::InsertOp>();
-      transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
-      continue;
-    }
+/// Try to fold in place to extract(source, extractPosition) and return the
+/// folded result. Return null if folding is not possible (e.g. due to an
+/// internal tranposition in the result).
+Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
+    Value source) {
+  // If we can't fold (either internal transposition, or nothing to fold), bail.
+  bool nothingToFold = (source == extractOp.vector());
+  if (nothingToFold || !canFold())
+    return Value();
+  // Otherwise, fold by updating the op inplace and return its result.
+  OpBuilder b(extractOp.getContext());
+  extractOp->setAttr(
+      extractOp.positionAttrName(),
+      b.getI64ArrayAttr(
+          makeArrayRef(extractPosition).take_front(extractedRank)));
+  extractOp.vectorMutable().assign(source);
+  return extractOp.getResult();
+}
 
-    assert(insertOp);
-    Value insertionDest = insertOp.dest();
-    // If it is inserted into, either the position matches and we have a
-    // successful folding; or we iterate until we run out of
-    // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
-    // produces a new vector with 1 modified value/slice in exactly the static
-    // position we need to match.
-    auto insertedPos = extractVector<unsigned>(insertOp.position());
-    // Trivial permutations are solved with position equality checks.
-    if (!permutationMap || permutationMap.isIdentity()) {
-      if (extractedPos == insertedPos)
-        return insertOp.source();
-      // Fallthrough: if the position does not match, just skip to the next
-      // producing `vector.insert` / `vector.transpose`.
-      // Compute insert/transpose for the next iteration.
-      insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
-      transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
+/// Iterate over producing insert and transpose ops until we find a fold.
+Value ExtractFromInsertTransposeChainState::fold() {
+  Value valueToExtractFrom = extractOp.vector();
+  updateStateForNextIteration(valueToExtractFrom);
+  while (nextInsertOp || nextTransposeOp) {
+    // Case 1. If we hit a transpose, just compose the map and iterate.
+    // Invariant: insert + transpose do not change rank, we can always compose.
+    if (succeeded(handleTransposeOp())) {
+      valueToExtractFrom = nextTransposeOp.vector();
+      updateStateForNextIteration(valueToExtractFrom);
       continue;
     }
 
-    // More advanced permutations require application of the permutation.
-    // However, the rank of `insertedPos` may be different from that of the
-    // `permutationMap`. To support such case, we need to:
-    //   1. apply on the `insertedPos.size()` major dimensions
-    //   2. check the other dimensions of the permutation form a minor identity.
-    assert(permutationMap.isPermutation() && "expected a permutation");
-    if (insertedPos.size() == extractedPos.size()) {
-      bool fold = true;
-      for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
-        auto pos = permutationMap.getDimPosition(idx);
-        if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
-          fold = false;
-          break;
-        }
-      }
-      if (fold) {
-        assert(permutationMap.getNumResults() >= insertedPos.size() &&
-               "expected map of rank larger than insert indexing");
-        unsigned minorRank =
-            permutationMap.getNumResults() - insertedPos.size();
-        AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
-        if (!minorMap || minorMap.isMinorIdentity())
-          return insertOp.source();
-      }
-    }
+    Value result;
+    // Case 2: the position match exactly.
+    if (succeeded(handleInsertOpWithMatchingPos(result)))
+      return result;
+
+    // Case 3: if the inserted position is a prefix of extractPosition, we can
+    // just extract a portion of the source of the insert.
+    if (succeeded(handleInsertOpWithPrefixPos(result)))
+      return tryToFoldExtractOpInPlace(result);
+
+    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
+    // values. This is a more difficult case and we bail.
+    auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+    if (isContainedWithin(extractPosition, insertedPos) ||
+        intersectsWhereNonNegative(extractPosition, insertedPos))
+      return Value();
 
-    // If we haven't found a match, just continue to the next producing
-    // `vector.insert` / `vector.transpose`.
-    // Compute insert/transpose for the next iteration.
-    insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
-    transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
+    // Case 5: No intersection, we forward the extract to insertOp.dest().
+    valueToExtractFrom = nextInsertOp.dest();
+    updateStateForNextIteration(valueToExtractFrom);
   }
-  return Value();
+  // If after all this we can fold, go for it.
+  return tryToFoldExtractOpInPlace(valueToExtractFrom);
 }
 
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
@@ -1312,14 +1387,12 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
     return vector();
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
-  if (succeeded(foldExtractOpFromTranspose(*this)))
-    return getResult();
-  if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
-    return val;
-  if (auto val = foldExtractFromBroadcast(*this))
-    return val;
-  if (auto val = foldExtractFromShapeCast(*this))
-    return val;
+  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
+    return res;
+  if (auto res = foldExtractFromBroadcast(*this))
+    return res;
+  if (auto res = foldExtractFromShapeCast(*this))
+    return res;
   if (auto val = foldExtractFromExtractStrided(*this))
     return val;
   if (auto val = foldExtractStridedOpFromInsertChain(*this))
index ba0e0a2..5da3382 100644 (file)
@@ -316,87 +316,103 @@ func @insert_extract_transpose_2d(
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_transpose_3d(
-//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
-//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
-//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32,
-//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32,
-//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32
-func @insert_extract_transpose_3d(
-    %v: vector<2x3x4xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32)
--> (f32, f32, f32, f32)
-{
-  %0 = vector.insert %f0, %v[0, 0, 0] : f32 into vector<2x3x4xf32>
-  %1 = vector.insert %f1, %0[0, 1, 0] : f32 into vector<2x3x4xf32>
-  %2 = vector.insert %f2, %1[1, 0, 0] : f32 into vector<2x3x4xf32>
-  %3 = vector.insert %f3, %2[0, 0, 1] : f32 into vector<2x3x4xf32>
-  %4 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
-  %5 = vector.insert %f3, %4[1, 0, 0] : f32 into vector<3x4x2xf32>
-  %6 = vector.transpose %5, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
-  %7 = vector.insert %f3, %6[1, 0, 0] : f32 into vector<4x2x3xf32>
-  %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
-
-  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0].
-  %r1 = vector.extract %3[1, 0, 0] : vector<2x3x4xf32>
-
-  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1, 0] followed by
-  // transpose[1, 2, 0].
-  %r2 = vector.extract %4[1, 0, 0] : vector<3x4x2xf32>
-
-  // Expected %f3 from %3 = vector.insert %f3, %0[0, 0, 1] followed by double
-  // transpose[1, 2, 0].
-  %r3 = vector.extract %6[1, 0, 0] : vector<4x2x3xf32>
-
-  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple
-  // transpose[1, 2, 0].
-  %r4 = vector.extract %8[1, 0, 0] : vector<2x3x4xf32>
-
-  // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F3]], %[[F2]] : f32, f32, f32
-  return %r1, %r2, %r3, %r4 : f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_extract_transpose_3d_2d(
-//  CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
-//  CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: vector<4xf32>,
-//  CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: vector<4xf32>,
-//  CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: vector<4xf32>,
-//  CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: vector<4xf32>
-func @insert_extract_transpose_3d_2d(
-    %v: vector<2x3x4xf32>,
-    %f0: vector<4xf32>, %f1: vector<4xf32>, %f2: vector<4xf32>, %f3: vector<4xf32>)
--> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
-{
-  %0 = vector.insert %f0, %v[0, 0] : vector<4xf32> into vector<2x3x4xf32>
-  %1 = vector.insert %f1, %0[0, 1] : vector<4xf32> into vector<2x3x4xf32>
-  %2 = vector.insert %f2, %1[1, 0] : vector<4xf32> into vector<2x3x4xf32>
-  %3 = vector.insert %f3, %2[1, 1] : vector<4xf32> into vector<2x3x4xf32>
-  %4 = vector.transpose %3, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
-  %5 = vector.transpose %4, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
-
-  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0].
-  %r1 = vector.extract %3[1, 0] : vector<2x3x4xf32>
-
-  // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by
-  // transpose[1, 0, 2].
-  %r2 = vector.extract %4[1, 0] : vector<3x2x4xf32>
-
-  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by double
-  // transpose[1, 0, 2].
-  %r3 = vector.extract %5[1, 0] : vector<2x3x4xf32>
-
-  %6 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
-  %7 = vector.transpose %6, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
-  %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
+// CHECK-LABEL: insert_extract_chain
+//  CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32>
+//  CHECK-SAME: %[[V34:[a-zA-Z0-9]*]]: vector<3x4xf32>
+//  CHECK-SAME: %[[V4:[a-zA-Z0-9]*]]: vector<4xf32>
+func @insert_extract_chain(%v234: vector<2x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>)
+    -> (vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>) {
+  // CHECK-NEXT: %[[A34:.*]] = vector.insert
+  %A34 = vector.insert %v34, %v234[0]: vector<3x4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: %[[B34:.*]] = vector.insert
+  %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: %[[A4:.*]] = vector.insert
+  %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<2x3x4xf32>
+  // CHECK-NEXT: %[[B4:.*]] = vector.insert
+  %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<2x3x4xf32>
+
+  // Case 2.a. [1, 1] == insertpos ([1, 1])
+  // Match %A4 insertionpos and fold to its source(i.e. %V4).
+   %r0 = vector.extract %B4[1, 1]: vector<2x3x4xf32>
+
+  // Case 3.a. insertpos ([1]) is a prefix of [1, 0].
+  // Traverse %B34 to its source(i.e. %V34@[*0*]).
+  // CHECK-NEXT: %[[R1:.*]] = vector.extract %[[V34]][0]
+   %r1 = vector.extract %B34[1, 0]: vector<2x3x4xf32>
+
+  // Case 4. [1] is a prefix of insertpos ([1, 1]).
+  // Cannot traverse %B4.
+  // CHECK-NEXT: %[[R2:.*]] = vector.extract %[[B4]][1]
+   %r2 = vector.extract %B4[1]: vector<2x3x4xf32>
+
+  // Case 5. [0] is disjoint from insertpos ([1, 1]).
+  // Traverse %B4 to its dest(i.e. %A4@[0]).
+  // Traverse %A4 to its dest(i.e. %B34@[0]).
+  // Traverse %B34 to its dest(i.e. %A34@[0]).
+  // Match %A34 insertionpos and fold to its source(i.e. %V34).
+   %r3 = vector.extract %B4[0]: vector<2x3x4xf32>
+
+  // CHECK: return %[[V4]], %[[R1]], %[[R2]], %[[V34]]
+  return %r0, %r1, %r2, %r3:
+    vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>
+}
 
-  // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple
-  // transpose[1, 2, 0].
-  %r4 = vector.extract %8[1, 0] : vector<2x3x4xf32>
+// -----
 
-  //      CHECK: return %[[F2]], %[[F1]], %[[F2]], %[[F2]]
-  // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
-  return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+// CHECK-LABEL: func @insert_extract_transpose_3d(
+//  CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32>
+func @insert_extract_transpose_3d(
+  %v234: vector<2x3x4xf32>, %v43: vector<4x3xf32>, %f0: f32)
+    -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<3x4xf32>) {
+
+  %a432 = vector.transpose %v234, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+  %b432 = vector.insert %f0, %a432[0, 0, 1] : f32 into vector<4x3x2xf32>
+  %c234 = vector.transpose %b432, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+  // Case 1. %c234 = transpose [2,1,0] posWithSentinels [1,2,-1] -> [-1,2,1]
+  // Case 5. %b432 = insert [0,0,1] (inter([.,2,1], [.,0,1]) == 0) prop to %v432
+  // Case 1. %a432 = transpose [2,1,0] posWithSentinels [-1,2,1] -> [1,2,-1]
+  // can extract directly from %v234, the rest folds.
+  // CHECK: %[[R0:.*]] = vector.extract %[[V234]][1, 2]
+  %r0 = vector.extract %c234[1, 2] : vector<2x3x4xf32>
+
+  // CHECK-NEXT: vector.transpose
+  // CHECK-NEXT: vector.insert
+  // CHECK-NEXT: %[[F234:.*]] = vector.transpose
+  %d432 = vector.transpose %v234, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+  %e432 = vector.insert %f0, %d432[0, 2, 1] : f32 into vector<4x3x2xf32>
+  %f234 = vector.transpose %e432, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+  // Case 1. %c234 = transpose [2,1,0] posWithSentinels [1,2,-1] -> [-1,2,1]
+  // Case 4. %b432 = insert [0,0,1] (inter([.,2,1], [.,2,1]) != 0)
+  // Bail, cannot do better than the current.
+  // CHECK: %[[R1:.*]] = vector.extract %[[F234]]
+  %r1 = vector.extract %f234[1, 2] : vector<2x3x4xf32>
+
+  // CHECK-NEXT: vector.transpose
+  // CHECK-NEXT: vector.insert
+  // CHECK-NEXT: %[[H234:.*]] = vector.transpose
+  %g243 = vector.transpose %v234, [0, 2, 1] : vector<2x3x4xf32> to vector<2x4x3xf32>
+  %h243 = vector.insert %v43, %g243[0] : vector<4x3xf32> into vector<2x4x3xf32>
+  %i234 = vector.transpose %h243, [0, 2, 1] : vector<2x4x3xf32> to vector<2x3x4xf32>
+  // Case 1. %i234 = transpose [0,2,1] posWithSentinels [0,-1,-2] -> [0,-2,-1]
+  // Case 3.b. %b432 = insert [0] is prefix of [0,.,.] but internal transpose.
+  // Bail, cannot do better than the current.
+  // CHECK: %[[R2:.*]] = vector.extract %[[H234]][0, 1]
+  %r2 = vector.extract %i234[0, 1] : vector<2x3x4xf32>
+
+  // CHECK-NEXT: vector.transpose
+  // CHECK-NEXT: vector.insert
+  // CHECK-NEXT: %[[K234:.*]] = vector.transpose
+  %j243 = vector.transpose %v234, [0, 2, 1] : vector<2x3x4xf32> to vector<2x4x3xf32>
+  %k243 = vector.insert %v43, %j243[0] : vector<4x3xf32> into vector<2x4x3xf32>
+  %l234 = vector.transpose %k243, [0, 2, 1] : vector<2x4x3xf32> to vector<2x3x4xf32>
+  // Case 1. %i234 = transpose [0,2,1] posWithSentinels [0,-1,-2] -> [0,-2,-1]
+  // Case 2.b. %b432 = insert [0] == [0,.,.] but internal transpose.
+  // Bail, cannot do better than the current.
+  // CHECK: %[[R3:.*]] = vector.extract %[[K234]][0]
+  %r3 = vector.extract %l234[0] : vector<2x3x4xf32>
+
+  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+  return %r0, %r1, %r2, %r3: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<3x4xf32>
 }
 
 // -----