if (rank != insertOp.getDestVectorType().getRank())
return failure();
+ // Requires that shape of insert op src is castable to dstType.
+ unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+ unsigned destinationWidth =
+ castDstType.getElementType().getIntOrFloatBitWidth();
+ unsigned numElements = destinationWidth / sourceWidth;
+ if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
+ return failure();
+
ArrayAttr newOffsets = insertOp.getOffsets();
assert(newOffsets.size() == rank);
SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
%cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
return %cast: vector<16x4x4xf32>
}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
+ // CHECK: vector.insert_strided_slice
+ // CHECK-NEXT: vector.bitcast
+ %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
+ %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
+ return %cast: vector<1xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
+ // CHECK: vector.insert_strided_slice
+ // CHECK-NEXT: vector.bitcast
+ %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
+ %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+ return %cast: vector<4xf32>
+}