}
template <typename IntType>
-static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(llvm::map_range(
arrayAttr.getAsRange<IntegerAttr>(),
[](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
SmallVector<int64_t, 4> globalPosition;
ExtractOp currentOp = extractOp;
- auto extractedPos = extractVector<int64_t>(currentOp.position());
- globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
+ auto extrPos = extractVector<int64_t>(currentOp.position());
+ globalPosition.append(extrPos.rbegin(), extrPos.rend());
while (ExtractOp nextOp = currentOp.vector().getDefiningOp<ExtractOp>()) {
currentOp = nextOp;
- auto extractedPos = extractVector<int64_t>(currentOp.position());
- globalPosition.append(extractedPos.rbegin(), extractedPos.rend());
+ auto extrPos = extractVector<int64_t>(currentOp.position());
+ globalPosition.append(extrPos.rbegin(), extrPos.rend());
}
extractOp.setOperand(currentOp.vector());
// OpBuilder is only used as a helper to build an I64ArrayAttr.
return success();
}
-/// Fold the result of an ExtractOp in place when it comes from a TransposeOp.
-static LogicalResult foldExtractOpFromTranspose(ExtractOp extractOp) {
- auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
- if (!transposeOp)
- return failure();
+namespace {
+/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
+/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
+/// Compose TransposeOp permutations as we walk back.
+/// This helper class keeps an updated extraction position `extractPosition`
+/// with extra trailing sentinels.
+/// The sentinels encode the internal transposition status of the result vector.
+/// As we iterate, extractPosition is permuted and updated.
+class ExtractFromInsertTransposeChainState {
+public:
+ ExtractFromInsertTransposeChainState(ExtractOp e);
+
+ /// Iterate over producing insert and transpose ops until we find a fold.
+ Value fold();
+
+private:
+ /// Return true if position `a` is a prefix of position `b`, i.e. `a` is no
+ /// longer than `b` and every entry of `a` equals the entry of `b` at the same
+ /// index.
+ template <typename ContainerA, typename ContainerB>
+ bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
+ if (a.size() > b.size())
+ return false;
+ return std::equal(a.begin(), a.end(), b.begin());
+ }
- auto permutation = extractVector<unsigned>(transposeOp.transp());
- auto extractedPos = extractVector<int64_t>(extractOp.position());
+ /// Return true if the vector at position `a` intersects the vector at
+ /// position `b`. Under insert/extract semantics, this is the same as equality
+ /// of all entries of `a` and `b` that are both >= 0.
+ /// Pairs in which either entry is negative (a sentinel) are skipped.
+ /// Comparison is on the common prefix (i.e. zip).
+ template <typename ContainerA, typename ContainerB>
+ bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
+ for (auto it : llvm::zip(a, b)) {
+ // Ignore the pair when either side is negative.
+ if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
+ continue;
+ if (std::get<0>(it) != std::get<1>(it))
+ return false;
+ }
+ return true;
+ }
+
+ /// Folding requires that the result vector carries no internal permutation:
+ /// the entries of `extractPosition` past the first `extractedRank` ones must
+ /// be exactly `sentinels`.
+ bool canFold() {
+ auto trailingPositions =
+ makeArrayRef(extractPosition).drop_front(extractedRank);
+ return sentinels == trailingPositions;
+ }
+
+ /// Helper to get the next defining op of interest: set `nextInsertOp` and
+ /// `nextTransposeOp` from the op defining `v`. Each is left null when `v` is
+ /// not produced by an op of that kind.
+ void updateStateForNextIteration(Value v) {
+ nextInsertOp = v.getDefiningOp<vector::InsertOp>();
+ nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
+ }
+
+ // Case 1. If we hit a transpose, just compose the map and iterate.
+ // Invariant: insert + transpose do not change rank, we can always compose.
+ LogicalResult handleTransposeOp();
+
+ // Case 2: the insert position matches extractPosition exactly, early return.
+ LogicalResult handleInsertOpWithMatchingPos(Value &res);
+
+ /// Case 3: if the insert position is a prefix of extractPosition, extract a
+ /// portion of the source of the insert.
+ /// Example:
+ /// ```
+ /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
+ /// // extractPosition == [1, 2, 3]
+ /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
+ /// // can fold to vector.extract %source[0, 3]
+ /// %ext = vector.extract %source[3]: vector<5x6>
+ /// ```
+ /// To traverse through %source, we need to set the leading dims to 0 and
+ /// drop the extra leading dims.
+ /// This method updates the internal state.
+ LogicalResult handleInsertOpWithPrefixPos(Value &res);
+
+ /// Try to fold in place to extract(source, extractPosition) and return the
+ /// folded result. Return null if folding is not possible (e.g. due to an
+ /// internal transposition in the result).
+ Value tryToFoldExtractOpInPlace(Value source);
+
+ ExtractOp extractOp;
+ int64_t vectorRank;
+ int64_t extractedRank;
+
+ InsertOp nextInsertOp;
+ TransposeOp nextTransposeOp;
+
+ /// Sentinel values that encode the internal permutation status of the result.
+ /// They are set to (-1, ... , -k) at the beginning and appended to
+ /// `extractPosition`.
+ /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
+ /// ensure that there is no internal transposition.
+ /// Internal transposition cannot be accounted for with a folding pattern.
+ // TODO: We could relax the internal transposition with an extra transposition
+ // operation in a future canonicalizer.
+ SmallVector<int64_t> sentinels;
+ SmallVector<int64_t> extractPosition;
+};
+} // namespace
- // If transposition permutation is larger than the ExtractOp, all minor
- // dimensions must be an identity for folding to occur. If not, individual
- // elements within the extracted value are transposed and this is not just a
- // simple folding.
- unsigned minorRank = permutation.size() - extractedPos.size();
- MLIRContext *ctx = extractOp.getContext();
- AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx);
- AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
- if (minorMap && !minorMap.isMinorIdentity())
+/// Record the ranks of `e` and seed `extractPosition` with the op's static
+/// position followed by one sentinel per dimension that is not extracted.
+ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
+ ExtractOp e)
+ : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
+ extractedRank(extractOp.position().size()) {
+ assert(vectorRank >= extractedRank && "extracted pos overflow");
+ // Sentinels are -1, -2, ..., -(vectorRank - extractedRank). The count gets
+ // its own name so it does not shadow the constructor parameter `e`.
+ const int64_t numSentinels = vectorRank - extractedRank;
+ sentinels.reserve(numSentinels);
+ for (int64_t i = 0; i < numSentinels; ++i)
+ sentinels.push_back(-(i + 1));
+ extractPosition = extractVector<int64_t>(extractOp.position());
+ llvm::append_range(extractPosition, sentinels);
+}
+
+// Case 1: the value is produced by a transpose. Apply the inverse of its
+// permutation to `extractPosition` and report success so the caller can iterate.
+// Invariant: insert + transpose do not change rank, we can always compose.
+LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
+ if (!nextTransposeOp)
return failure();
+ MLIRContext *ctx = extractOp.getContext();
+ AffineMap permutationMap = AffineMap::getPermutationMap(
+ extractVector<unsigned>(nextTransposeOp.transp()), ctx);
+ extractPosition = applyPermutationMap(inversePermutation(permutationMap),
+ makeArrayRef(extractPosition));
+ return success();
+}
- // %1 = transpose %0[x, y, z] : vector<axbxcxf32>
- // %2 = extract %1[u, v] : vector<..xf32>
- // may turn into:
- // %2 = extract %0[w, x] : vector<..xf32>
- // iff z == 2 and [w, x] = [x, y]^-1 o [u, v] here o denotes composition and
- // -1 denotes the inverse.
- permutationMap = permutationMap.getMajorSubMap(extractedPos.size());
- // The major submap has fewer results but the same number of dims. To compose
- // cleanly, we need to drop dims to form a "square matrix". This is possible
- // because:
- // (a) this is a permutation map and
- // (b) the minor map has already been checked to be identity.
- // Therefore, the major map cannot contain dims of position greater or equal
- // than the number of results.
- assert(llvm::all_of(permutationMap.getResults(),
- [&](AffineExpr e) {
- auto dim = e.dyn_cast<AffineDimExpr>();
- return dim && dim.getPosition() <
- permutationMap.getNumResults();
- }) &&
- "Unexpected map results depend on higher rank positions");
- // Project on the first domain dimensions to allow composition.
- permutationMap = AffineMap::get(permutationMap.getNumResults(), 0,
- permutationMap.getResults(), ctx);
-
- extractOp.setOperand(transposeOp.vector());
- // Compose the inverse permutation map with the extractedPos.
- auto newExtractedPos =
- inversePermutation(permutationMap).compose(extractedPos);
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
- OpBuilder b(extractOp.getContext());
- extractOp->setAttr(ExtractOp::getPositionAttrName(),
- b.getI64ArrayAttr(newExtractedPos));
+// Case 2: the insert position matches extractPosition exactly, early return.
+// On success `res` is set to the source of the insert. Fails when the positions
+// differ, or when an internal transposition is pending (canFold() is false):
+// in that case the inserted source is not the value being extracted.
+LogicalResult
+ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
+ Value &res) {
+ auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+ if (makeArrayRef(insertedPos) !=
+ llvm::makeArrayRef(extractPosition).take_front(extractedRank))
+ return failure();
+ // Case 2.b. if internal transposition is present, canFold is false and the
+ // source cannot be returned as-is.
+ if (!canFold())
+ return failure();
+ // Case 2.a. early-exit fold.
+ res = nextInsertOp.source();
+ return success();
+}
+/// Case 3: if inserted position is a prefix of extractPosition,
+/// extract a portion of the source of the insertion.
+/// This method updates the internal state: the leading entries of
+/// `extractPosition` covered by the insert position are dropped and
+/// `extractedRank` is recomputed. On success `res` is the insert source.
+LogicalResult
+ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
+ auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+ if (!isContainedWithin(insertedPos, extractPosition))
+ return failure();
+ // Drop the leading dims addressed by the insert position. They do not need
+ // to be zeroed first, since the erased entries are never read again.
+ extractPosition.erase(extractPosition.begin(),
+ extractPosition.begin() + insertedPos.size());
+ extractedRank = extractPosition.size() - sentinels.size();
+ // Case 3.a. early-exit fold (break and delegate to post-while path).
+ res = nextInsertOp.source();
+ // Case 3.b. if internal transposition is present, canFold will be false.
return success();
}
-/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps. The
-/// result is always the input to some InsertOp.
-static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
- MLIRContext *context = extractOp.getContext();
- AffineMap permutationMap;
- auto extractedPos = extractVector<unsigned>(extractOp.position());
- // Walk back a chain of InsertOp/TransposeOp until we hit a match.
- // Compose TransposeOp permutations as we walk back.
- auto insertOp = extractOp.vector().getDefiningOp<vector::InsertOp>();
- auto transposeOp = extractOp.vector().getDefiningOp<vector::TransposeOp>();
- while (insertOp || transposeOp) {
- if (transposeOp) {
- // If it is transposed, compose the map and iterate.
- auto permutation = extractVector<unsigned>(transposeOp.transp());
- AffineMap newMap = AffineMap::getPermutationMap(permutation, context);
- if (!permutationMap)
- permutationMap = newMap;
- else if (newMap.getNumInputs() != permutationMap.getNumResults())
- return Value();
- else
- permutationMap = newMap.compose(permutationMap);
- // Compute insert/transpose for the next iteration.
- Value transposed = transposeOp.vector();
- insertOp = transposed.getDefiningOp<vector::InsertOp>();
- transposeOp = transposed.getDefiningOp<vector::TransposeOp>();
- continue;
- }
+/// Attempt the in-place fold of the op to extract(source, extractPosition).
+/// Returns the op's result on success. Returns a null Value when `source` is
+/// already the op's vector operand (nothing to fold) or when an internal
+/// transposition remains in the result.
+Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
+ Value source) {
+ // Extracting from the value we already extract from changes nothing.
+ if (source == extractOp.vector())
+ return Value();
+ // A pending internal permutation cannot be folded away.
+ if (!canFold())
+ return Value();
+ // Rewrite the position attribute and the vector operand in place.
+ auto newPosition = makeArrayRef(extractPosition).take_front(extractedRank);
+ OpBuilder b(extractOp.getContext());
+ extractOp->setAttr(extractOp.positionAttrName(),
+ b.getI64ArrayAttr(newPosition));
+ extractOp.vectorMutable().assign(source);
+ return extractOp.getResult();
+}
- assert(insertOp);
- Value insertionDest = insertOp.dest();
- // If it is inserted into, either the position matches and we have a
- // successful folding; or we iterate until we run out of
- // InsertOp/TransposeOp. This is because `vector.insert %scalar, %vector`
- // produces a new vector with 1 modified value/slice in exactly the static
- // position we need to match.
- auto insertedPos = extractVector<unsigned>(insertOp.position());
- // Trivial permutations are solved with position equality checks.
- if (!permutationMap || permutationMap.isIdentity()) {
- if (extractedPos == insertedPos)
- return insertOp.source();
- // Fallthrough: if the position does not match, just skip to the next
- // producing `vector.insert` / `vector.transpose`.
- // Compute insert/transpose for the next iteration.
- insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
- transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
+/// Iterate over producing insert and transpose ops until we find a fold.
+Value ExtractFromInsertTransposeChainState::fold() {
+ Value valueToExtractFrom = extractOp.vector();
+ updateStateForNextIteration(valueToExtractFrom);
+ while (nextInsertOp || nextTransposeOp) {
+ // Case 1. If we hit a transpose, just compose the map and iterate.
+ // Invariant: insert + transpose do not change rank, we can always compose.
+ if (succeeded(handleTransposeOp())) {
+ valueToExtractFrom = nextTransposeOp.vector();
+ updateStateForNextIteration(valueToExtractFrom);
continue;
}
- // More advanced permutations require application of the permutation.
- // However, the rank of `insertedPos` may be different from that of the
- // `permutationMap`. To support such case, we need to:
- // 1. apply on the `insertedPos.size()` major dimensions
- // 2. check the other dimensions of the permutation form a minor identity.
- assert(permutationMap.isPermutation() && "expected a permutation");
- if (insertedPos.size() == extractedPos.size()) {
- bool fold = true;
- for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
- auto pos = permutationMap.getDimPosition(idx);
- if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
- fold = false;
- break;
- }
- }
- if (fold) {
- assert(permutationMap.getNumResults() >= insertedPos.size() &&
- "expected map of rank larger than insert indexing");
- unsigned minorRank =
- permutationMap.getNumResults() - insertedPos.size();
- AffineMap minorMap = permutationMap.getMinorSubMap(minorRank);
- if (!minorMap || minorMap.isMinorIdentity())
- return insertOp.source();
- }
- }
+ Value result;
+ // Case 2: the position match exactly.
+ if (succeeded(handleInsertOpWithMatchingPos(result)))
+ return result;
+
+ // Case 3: if the inserted position is a prefix of extractPosition, we can
+ // just extract a portion of the source of the insert.
+ if (succeeded(handleInsertOpWithPrefixPos(result)))
+ return tryToFoldExtractOpInPlace(result);
+
+ // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
+ // values. This is a more difficult case and we bail.
+ auto insertedPos = extractVector<int64_t>(nextInsertOp.position());
+ if (isContainedWithin(extractPosition, insertedPos) ||
+ intersectsWhereNonNegative(extractPosition, insertedPos))
+ return Value();
- // If we haven't found a match, just continue to the next producing
- // `vector.insert` / `vector.transpose`.
- // Compute insert/transpose for the next iteration.
- insertOp = insertionDest.getDefiningOp<vector::InsertOp>();
- transposeOp = insertionDest.getDefiningOp<vector::TransposeOp>();
+ // Case 5: No intersection, we forward the extract to insertOp.dest().
+ valueToExtractFrom = nextInsertOp.dest();
+ updateStateForNextIteration(valueToExtractFrom);
}
- return Value();
+ // If after all this we can fold, go for it.
+ return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
return vector();
if (succeeded(foldExtractOpFromExtractChain(*this)))
return getResult();
- if (succeeded(foldExtractOpFromTranspose(*this)))
- return getResult();
- if (auto val = foldExtractOpFromInsertChainAndTranspose(*this))
- return val;
- if (auto val = foldExtractFromBroadcast(*this))
- return val;
- if (auto val = foldExtractFromShapeCast(*this))
- return val;
+ if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
+ return res;
+ if (auto res = foldExtractFromBroadcast(*this))
+ return res;
+ if (auto res = foldExtractFromShapeCast(*this))
+ return res;
if (auto val = foldExtractFromExtractStrided(*this))
return val;
if (auto val = foldExtractStridedOpFromInsertChain(*this))
// -----
-// CHECK-LABEL: func @insert_extract_transpose_3d(
-// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
-// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: f32,
-// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: f32,
-// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: f32,
-// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: f32
-func @insert_extract_transpose_3d(
- %v: vector<2x3x4xf32>, %f0: f32, %f1: f32, %f2: f32, %f3: f32)
--> (f32, f32, f32, f32)
-{
- %0 = vector.insert %f0, %v[0, 0, 0] : f32 into vector<2x3x4xf32>
- %1 = vector.insert %f1, %0[0, 1, 0] : f32 into vector<2x3x4xf32>
- %2 = vector.insert %f2, %1[1, 0, 0] : f32 into vector<2x3x4xf32>
- %3 = vector.insert %f3, %2[0, 0, 1] : f32 into vector<2x3x4xf32>
- %4 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
- %5 = vector.insert %f3, %4[1, 0, 0] : f32 into vector<3x4x2xf32>
- %6 = vector.transpose %5, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
- %7 = vector.insert %f3, %6[1, 0, 0] : f32 into vector<4x2x3xf32>
- %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
-
- // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0].
- %r1 = vector.extract %3[1, 0, 0] : vector<2x3x4xf32>
-
- // Expected %f1 from %1 = vector.insert %f1, %0[0, 1, 0] followed by
- // transpose[1, 2, 0].
- %r2 = vector.extract %4[1, 0, 0] : vector<3x4x2xf32>
-
- // Expected %f3 from %3 = vector.insert %f3, %0[0, 0, 1] followed by double
- // transpose[1, 2, 0].
- %r3 = vector.extract %6[1, 0, 0] : vector<4x2x3xf32>
-
- // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple
- // transpose[1, 2, 0].
- %r4 = vector.extract %8[1, 0, 0] : vector<2x3x4xf32>
-
- // CHECK-NEXT: return %[[F2]], %[[F1]], %[[F3]], %[[F2]] : f32, f32, f32
- return %r1, %r2, %r3, %r4 : f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_extract_transpose_3d_2d(
-// CHECK-SAME: %[[V:[a-zA-Z0-9]*]]: vector<2x3x4xf32>,
-// CHECK-SAME: %[[F0:[a-zA-Z0-9]*]]: vector<4xf32>,
-// CHECK-SAME: %[[F1:[a-zA-Z0-9]*]]: vector<4xf32>,
-// CHECK-SAME: %[[F2:[a-zA-Z0-9]*]]: vector<4xf32>,
-// CHECK-SAME: %[[F3:[a-zA-Z0-9]*]]: vector<4xf32>
-func @insert_extract_transpose_3d_2d(
- %v: vector<2x3x4xf32>,
- %f0: vector<4xf32>, %f1: vector<4xf32>, %f2: vector<4xf32>, %f3: vector<4xf32>)
--> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>)
-{
- %0 = vector.insert %f0, %v[0, 0] : vector<4xf32> into vector<2x3x4xf32>
- %1 = vector.insert %f1, %0[0, 1] : vector<4xf32> into vector<2x3x4xf32>
- %2 = vector.insert %f2, %1[1, 0] : vector<4xf32> into vector<2x3x4xf32>
- %3 = vector.insert %f3, %2[1, 1] : vector<4xf32> into vector<2x3x4xf32>
- %4 = vector.transpose %3, [1, 0, 2] : vector<2x3x4xf32> to vector<3x2x4xf32>
- %5 = vector.transpose %4, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
-
- // Expected %f2 from %2 = vector.insert %f2, %1[1, 0].
- %r1 = vector.extract %3[1, 0] : vector<2x3x4xf32>
-
- // Expected %f1 from %1 = vector.insert %f1, %0[0, 1] followed by
- // transpose[1, 0, 2].
- %r2 = vector.extract %4[1, 0] : vector<3x2x4xf32>
-
- // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by double
- // transpose[1, 0, 2].
- %r3 = vector.extract %5[1, 0] : vector<2x3x4xf32>
-
- %6 = vector.transpose %3, [1, 2, 0] : vector<2x3x4xf32> to vector<3x4x2xf32>
- %7 = vector.transpose %6, [1, 2, 0] : vector<3x4x2xf32> to vector<4x2x3xf32>
- %8 = vector.transpose %7, [1, 2, 0] : vector<4x2x3xf32> to vector<2x3x4xf32>
+// CHECK-LABEL: insert_extract_chain
+// CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32>
+// CHECK-SAME: %[[V34:[a-zA-Z0-9]*]]: vector<3x4xf32>
+// CHECK-SAME: %[[V4:[a-zA-Z0-9]*]]: vector<4xf32>
+func @insert_extract_chain(%v234: vector<2x3x4xf32>, %v34: vector<3x4xf32>, %v4: vector<4xf32>)
+ -> (vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>) {
+ // CHECK-NEXT: %[[A34:.*]] = vector.insert
+ %A34 = vector.insert %v34, %v234[0]: vector<3x4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: %[[B34:.*]] = vector.insert
+ %B34 = vector.insert %v34, %A34[1]: vector<3x4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: %[[A4:.*]] = vector.insert
+ %A4 = vector.insert %v4, %B34[1, 0]: vector<4xf32> into vector<2x3x4xf32>
+ // CHECK-NEXT: %[[B4:.*]] = vector.insert
+ %B4 = vector.insert %v4, %A4[1, 1]: vector<4xf32> into vector<2x3x4xf32>
+
+ // Case 2.a. [1, 1] == insertpos ([1, 1])
+ // Match %B4 insertionpos and fold to its source(i.e. %V4).
+ %r0 = vector.extract %B4[1, 1]: vector<2x3x4xf32>
+
+ // Case 3.a. insertpos ([1]) is a prefix of [1, 0].
+ // Traverse %B34 to its source(i.e. %V34@[*0*]).
+ // CHECK-NEXT: %[[R1:.*]] = vector.extract %[[V34]][0]
+ %r1 = vector.extract %B34[1, 0]: vector<2x3x4xf32>
+
+ // Case 4. [1] is a prefix of insertpos ([1, 1]).
+ // Cannot traverse %B4.
+ // CHECK-NEXT: %[[R2:.*]] = vector.extract %[[B4]][1]
+ %r2 = vector.extract %B4[1]: vector<2x3x4xf32>
+
+ // Case 5. [0] is disjoint from insertpos ([1, 1]).
+ // Traverse %B4 to its dest(i.e. %A4@[0]).
+ // Traverse %A4 to its dest(i.e. %B34@[0]).
+ // Traverse %B34 to its dest(i.e. %A34@[0]).
+ // Match %A34 insertionpos and fold to its source(i.e. %V34).
+ %r3 = vector.extract %B4[0]: vector<2x3x4xf32>
+
+ // CHECK: return %[[V4]], %[[R1]], %[[R2]], %[[V34]]
+ return %r0, %r1, %r2, %r3:
+ vector<4xf32>, vector<4xf32>, vector<3x4xf32>, vector<3x4xf32>
+}
- // Expected %f2 from %2 = vector.insert %f2, %1[1, 0, 0] followed by triple
- // transpose[1, 2, 0].
- %r4 = vector.extract %8[1, 0] : vector<2x3x4xf32>
+// -----
- // CHECK: return %[[F2]], %[[F1]], %[[F2]], %[[F2]]
- // CHECK-SAME: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
- return %r1, %r2, %r3, %r4 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+// CHECK-LABEL: func @insert_extract_transpose_3d(
+// CHECK-SAME: %[[V234:[a-zA-Z0-9]*]]: vector<2x3x4xf32>
+func @insert_extract_transpose_3d(
+ %v234: vector<2x3x4xf32>, %v43: vector<4x3xf32>, %f0: f32)
+ -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<3x4xf32>) {
+
+ %a432 = vector.transpose %v234, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+ %b432 = vector.insert %f0, %a432[0, 0, 1] : f32 into vector<4x3x2xf32>
+ %c234 = vector.transpose %b432, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+ // Case 1. %c234 = transpose [2,1,0] posWithSentinels [1,2,-1] -> [-1,2,1]
+ // Case 5. %b432 = insert [0,0,1] (inter([.,2,1], [.,0,1]) == 0) prop to %a432
+ // Case 1. %a432 = transpose [2,1,0] posWithSentinels [-1,2,1] -> [1,2,-1]
+ // can extract directly from %v234, the rest folds.
+ // CHECK: %[[R0:.*]] = vector.extract %[[V234]][1, 2]
+ %r0 = vector.extract %c234[1, 2] : vector<2x3x4xf32>
+
+ // CHECK-NEXT: vector.transpose
+ // CHECK-NEXT: vector.insert
+ // CHECK-NEXT: %[[F234:.*]] = vector.transpose
+ %d432 = vector.transpose %v234, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
+ %e432 = vector.insert %f0, %d432[0, 2, 1] : f32 into vector<4x3x2xf32>
+ %f234 = vector.transpose %e432, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
+ // Case 1. %f234 = transpose [2,1,0] posWithSentinels [1,2,-1] -> [-1,2,1]
+ // Case 4. %e432 = insert [0,2,1] (inter([.,2,1], [.,2,1]) != 0)
+ // Bail, cannot do better than the current.
+ // CHECK: %[[R1:.*]] = vector.extract %[[F234]]
+ %r1 = vector.extract %f234[1, 2] : vector<2x3x4xf32>
+
+ // CHECK-NEXT: vector.transpose
+ // CHECK-NEXT: vector.insert
+ // CHECK-NEXT: %[[H234:.*]] = vector.transpose
+ %g243 = vector.transpose %v234, [0, 2, 1] : vector<2x3x4xf32> to vector<2x4x3xf32>
+ %h243 = vector.insert %v43, %g243[0] : vector<4x3xf32> into vector<2x4x3xf32>
+ %i234 = vector.transpose %h243, [0, 2, 1] : vector<2x4x3xf32> to vector<2x3x4xf32>
+ // Case 1. %i234 = transpose [0,2,1] posWithSentinels [0,1,-1] -> [0,-1,1]
+ // Case 3.b. %h243 = insert [0] is prefix of [0,.,.] but internal transpose.
+ // Bail, cannot do better than the current.
+ // CHECK: %[[R2:.*]] = vector.extract %[[H234]][0, 1]
+ %r2 = vector.extract %i234[0, 1] : vector<2x3x4xf32>
+
+ // CHECK-NEXT: vector.transpose
+ // CHECK-NEXT: vector.insert
+ // CHECK-NEXT: %[[K234:.*]] = vector.transpose
+ %j243 = vector.transpose %v234, [0, 2, 1] : vector<2x3x4xf32> to vector<2x4x3xf32>
+ %k243 = vector.insert %v43, %j243[0] : vector<4x3xf32> into vector<2x4x3xf32>
+ %l234 = vector.transpose %k243, [0, 2, 1] : vector<2x4x3xf32> to vector<2x3x4xf32>
+ // Case 1. %l234 = transpose [0,2,1] posWithSentinels [0,-1,-2] -> [0,-2,-1]
+ // Case 2.b. %k243 = insert [0] == [0,.,.] but internal transpose.
+ // Bail, cannot do better than the current.
+ // CHECK: %[[R3:.*]] = vector.extract %[[K234]][0]
+ %r3 = vector.extract %l234[0] : vector<2x3x4xf32>
+
+ // CHECK-NEXT: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
+ return %r0, %r1, %r2, %r3: vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<3x4xf32>
}
// -----