[mlir][Vector] Prevent AVX2 lowering for non-f32 transpose ops
authorDiego Caballero <diegocaballero@google.com>
Fri, 25 Feb 2022 18:27:43 +0000 (18:27 +0000)
committerDiego Caballero <diegocaballero@google.com>
Fri, 25 Feb 2022 19:27:32 +0000 (19:27 +0000)
The AVX2 lowering for transpose operations is only applicable to f32 vector types.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D120427

mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

index 27272d1..065848d 100644 (file)
@@ -250,8 +250,11 @@ public:
     auto loc = op.getLoc();
 
     // Check if the source vector type is supported. AVX2 patterns can only be
-    // applied if the vector type has two dimensions greater than one.
+    // applied to f32 vector types with two dimensions greater than one.
     VectorType srcType = op.getVectorType();
+    if (!srcType.getElementType().isF32())
+      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
+
     SmallVector<int64_t> srcGtOneDims;
     for (auto &en : llvm::enumerate(srcType.getShape()))
       if (en.value() > 1)
index 44a59a2..651006b 100644 (file)
@@ -548,6 +548,15 @@ func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> {
 
 // -----
 
+func @do_not_lower_nonf32_to_avx2(%arg0: vector<4x8xi32>) -> vector<8x4xi32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xi32> to vector<8x4xi32>
+  return %0 : vector<8x4xi32>
+}
+
+// AVX2-NOT: vector.shuffle
+
+// -----
+
 // AVX2-LABEL: func @transpose021_8x1x8
 func @transpose021_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> {
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<8x1x8xf32> to vector<8x8x1xf32>