[mlir][vector] Add insertOp src shape check for BubbleUpBitCastForStridedSliceInsert
authorstanley-nod <stanley@nod-labs.com>
Fri, 11 Nov 2022 00:41:59 +0000 (16:41 -0800)
committerstanley-nod <stanley@nod-labs.com>
Fri, 11 Nov 2022 00:41:59 +0000 (16:41 -0800)
Not all shape of vectors can be casted into other types, we add a check
to not fold insertOp into bitcast if the shape does not support it.

Examples of unsupported shape castings are f16 vectors to f32 if the
shape is not multiple of 2s. or int8 to int32 if shapes are not multiple
of 4.

Reviewed By: antiagainst, ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137802

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir

index 0bdaf7b..db80474 100644 (file)
@@ -2503,6 +2503,14 @@ struct BubbleUpBitCastForStridedSliceInsert
     if (rank != insertOp.getDestVectorType().getRank())
       return failure();
 
+    // Requires that shape of insert op src is castable to dstType.
+    unsigned sourceWidth = castSrcType.getElementType().getIntOrFloatBitWidth();
+    unsigned destinationWidth =
+        castDstType.getElementType().getIntOrFloatBitWidth();
+    unsigned numElements = destinationWidth / sourceWidth;
+    if (insertOp.getSourceVectorType().getNumElements() % numElements != 0)
+      return failure();
+
     ArrayAttr newOffsets = insertOp.getOffsets();
     assert(newOffsets.size() == rank);
     SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
index 0c86c7c..e8f5cf7 100644 (file)
@@ -507,3 +507,21 @@ func.func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector
   %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
   return %cast: vector<16x4x4xf32>
 }
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_odd_shape(%dst: vector<2xf16>, %src: vector<1xf16>) -> vector<1xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<1xf16> into vector<2xf16>
+  %cast = vector.bitcast %0: vector<2xf16> to vector<1xf32>
+  return %cast: vector<1xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape
+func.func @bubble_up_bitcast_in_strided_slice_insert_larger_odd_shape(%dst: vector<8xf16>, %src: vector<3xf16>) -> vector<4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0], strides = [1]} : vector<3xf16> into vector<8xf16>
+  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+  return %cast: vector<4xf32>
+}