From dc26c030661a763bdc50c759576fc3c34f3c496a Mon Sep 17 00:00:00 2001 From: stanley-nod Date: Thu, 10 Nov 2022 16:41:59 -0800 Subject: [PATCH] [mlir][vector] Add insertOp src shape check for BubbleUpBitCastForStridedSliceInsert Not all vector shapes can be cast to other element types, so we add a check to avoid folding the insertOp into the bitcast if the shape of the insertOp source does not support it. Examples of unsupported shape casts are f16 vectors to f32 if the number of elements is not a multiple of 2, or int8 to int32 if the number of elements is not a multiple of 4. Reviewed By: antiagainst, ThomasRaoux Differential Revision: https://reviews.llvm.org/D137802 --- .../lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 8 ++++++++ mlir/test/Dialect/Vector/vector-transforms.mlir | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 0bdaf7b..db80474 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -2503,6 +2503,14 @@ struct BubbleUpBitCastForStridedSliceInsert if (rank != insertOp.getDestVectorType().getRank()) return failure(); + // Requires that shape of insert op src is castable to dstType. 
+ unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth(); + unsigned destinationWidth = + castDstType.getElementType().getIntOrFloatBitWidth(); + unsigned numElements = destinationWidth / sourceWidth; + if (insertOp.getSourceVectorType().getNumElements() % numElements != 0) + return failure(); + ArrayAttr newOffsets = insertOp.getOffsets(); assert(newOffsets.size() == rank); SmallVector offsets = getIntValueVector(newOffsets); diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir index 0c86c7c..e8f5cf7 100644 --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -507,3 +507,21 @@ func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32> return %cast: vector<16x4x4xf32> } + +// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape +func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> { + // CHECK: vector.insert_strided_slice + // CHECK-NEXT: vector.bitcast + %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16> + %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32> + return %cast: vector<1xf32> +} + +// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape +func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> { + // CHECK: vector.insert_strided_slice + // CHECK-NEXT: vector.bitcast + %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16> + %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32> + return %cast: vector<4xf32> +} -- 2.7.4