/// Returns the producer Value of the same type as 'consumerValue', by tracking
/// the tuple index and offsets of the consumer vector value through the
-/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
-/// from consumer to producer. Each operation in the chain is structured, and
-/// so the tuple index and offsets can be mapped from result to input, while
-/// visiting each operation in the chain.
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
+/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
+/// structured, and so the tuple index and offsets can be mapped from result to
+/// input, while visiting each operation in the chain.
/// Returns nullptr on failure.
static Value getProducerValue(Value consumerValue) {
auto consumerVectorType = consumerValue.getType().cast<VectorType>();
// Update 'tupleIndex' and next defining 'op' to visit.
tupleIndex = -1;
op = value.getDefiningOp();
+ } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
+ if (shapeCastOp.source().getType().isa<TupleType>())
+ return nullptr;
+ assert(tupleIndex == -1);
+ auto sourceVectorType = shapeCastOp.getSourceVectorType();
+ auto sourceVectorShape = sourceVectorType.getShape();
+ unsigned sourceVectorRank = sourceVectorType.getRank();
+ auto resultVectorType = shapeCastOp.getResultVectorType();
+ auto resultVectorShape = resultVectorType.getShape();
+ unsigned resultVectorRank = resultVectorType.getRank();
+
+ int i = sourceVectorRank - 1;
+ int j = resultVectorRank - 1;
+
+ // Check that the trailing dimensions (suffixes) of the source/result
+ // vector shapes match while updating 'newOffsets'.
+ bool canShapeCastFold = true;
+ SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
+
+ // Visits one (source, result) dimension pair, innermost dimension first,
+ // and copies the matching offset from 'offsets' into 'newOffsets'.
+ // The flag is accumulated (not overwritten) so that a size mismatch in
+ // any pair is remembered rather than only the last pair visited.
+ auto apply = [&](int64_t sourceSize, int64_t resultSize) {
+ canShapeCastFold &= (sourceSize == resultSize);
+ newOffsets[i--] = offsets[j--];
+ };
+ functional::zipApply(apply, llvm::reverse(sourceVectorShape),
+ llvm::reverse(resultVectorShape));
+ if (!canShapeCastFold)
+ return nullptr;
+
+ // Check that the remaining prefix of the source/result vector shape is all 1s.
+ // Currently we only support producer/consumer tracking through trivial
+ // shape cast ops. Examples:
+ // %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
+ // %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
+ assert(i == -1 || j == -1);
+ if (i >= 0 &&
+ !std::all_of(sourceVectorShape.begin(), sourceVectorShape.begin() + i,
+ [](int64_t v) { return v == 1; }))
+ return nullptr;
+ if (j >= 0 &&
+ !std::all_of(resultVectorShape.begin(), resultVectorShape.begin() + j,
+ [](int64_t v) { return v == 1; }))
+ return nullptr;
+
+ offsets.swap(newOffsets);
+ op = shapeCastOp.source().getDefiningOp();
} else {
- break;
+ // Check if 'op' produces a Value with the same type as 'consumerValue'.
+ if (op->getNumResults() == 1 &&
+ op->getResult(0).getType() == consumerVectorType)
+ return op->getResult(0);
+ return nullptr;
}
}
return nullptr;
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
+ // Check if we can replace 'shapeCastOp' result with its producer.
+ if (auto producer = getProducerValue(shapeCastOp.getResult())) {
+ rewriter.replaceOp(shapeCastOp, producer);
+ return success();
+ }
+
// Check if 'shapeCastOp' has vector source/result type.
auto sourceVectorType =
shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
%2 = vector.extract_slices %1, [4, 8], [1, 1]
: vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
// %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
- %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
- // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
- %4 = vector.extract_slices %3, [2, 4], [1, 1]
+ %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+ tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+ // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
+ %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+ // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
+ %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
+ // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
+ %6 = vector.extract_slices %5, [2, 4], [1, 1]
: vector<4x8xf32> into
tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
- // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
- %5 = vector.tuple_get %4, 3
+ // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
+ %7 = vector.tuple_get %6, 3
: tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
- // %arg7 == %5
- return %5 : vector<2x4xf32>
+ // %arg7 == %7
+ return %7 : vector<2x4xf32>
}
// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
%2 = vector.extract_slices %1, [4, 8], [1, 1]
: vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
// %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+ %3= vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+ tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+ // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
// Extract tuple elements.
- %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
- %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
- // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+ %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+ %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+ // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]
// Swizzle tuple elements.
- %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
- // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
- %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
- // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
- %7 = vector.extract_slices %6, [2, 4], [1, 1]
+ %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
+ // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
+ %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
+ tuple<vector<4x8xf32>, vector<4x8xf32>>
+ // %arg7 == %7 at tupleIndex = 0, offsets = [2, 4]
+ %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+ // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
+ %9 = vector.extract_slices %8, [2, 4], [1, 1]
: vector<4x8xf32> into
tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
- // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
- %8 = vector.tuple_get %7, 3
+ // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
+ %10 = vector.tuple_get %9, 3
: tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
- // %arg7 == %8
- return %8 : vector<2x4xf32>
+ // %arg7 == %10
+ return %10 : vector<2x4xf32>
+}
+
+// Casts %arg0 from vector<2x4xf32> to vector<8xf32> and back again. The CHECK
+// lines below expect the function to return the original argument directly.
+// CHECK-LABEL: func @cancelling_shape_cast_ops
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+// CHECK: return %[[A0]] : vector<2x4xf32>
+func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+ %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+ return %1 : vector<2x4xf32>
}
// CHECK-LABEL: func @vector_transfers_vector_element_type