[mlir][Vector] Support 0-D vectors in TransposeOp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 26 Aug 2022 10:34:39 +0000 (03:34 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 26 Aug 2022 10:40:21 +0000 (03:40 -0700)
Co-authored-by: Michal Terepeta <michalt@google.com>
Reviewed-by: ftynse
Differential Revision: https://reviews.llvm.org/D115743

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir

index 033c29b..aa6624f 100644 (file)
@@ -2229,12 +2229,13 @@ def Vector_TransposeOp :
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     PredOpTrait<"operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(ins AnyVector:$vector, I64ArrayAttr:$transp)>,
-    Results<(outs AnyVector:$result)> {
+    Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$transp)>,
+    Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "vector transpose operation";
   let description = [{
     Takes a n-D vector and returns the transposed n-D vector defined by
-    the permutation of ranks in the n-sized integer array attribute.
+    the permutation of ranks in the n-sized integer array attribute (in case
+    of 0-D vectors the array attribute must be empty).
     In the operation
 
     ```mlir
index 40e4022..828fc22 100644 (file)
@@ -1760,6 +1760,8 @@ func.func @create_mask_1d(%a : index) -> vector<4xi1> {
 // CHECK:  %[[result:.*]] = arith.cmpi slt, %[[indices]], %[[bounds]] : vector<4xi32>
 // CHECK:  return %[[result]] : vector<4xi1>
 
+// -----
+
 func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
   %v = vector.create_mask %a : vector<[4]xi1>
   return %v: vector<[4]xi1>
@@ -1776,6 +1778,17 @@ func.func @create_mask_1d_scalable(%a : index) -> vector<[4]xi1> {
 
 // -----
 
+func.func @transpose_0d(%arg0: vector<f32>) -> vector<f32> {
+  %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: func @transpose_0d
+// CHECK-SAME:  %[[A:.*]]: vector<f32>
+// CHECK:       return %[[A]] : vector<f32>
+
+// -----
+
 func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
   %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
      : vector<16xf32> -> vector<16xf32>
index 9a7e6f4..fa25164 100644 (file)
@@ -1145,12 +1145,26 @@ func.func @multi_reduce_invalid_type(%arg0: vector<4x16xf32>, %acc: vector<16xf3
 
 // -----
 
+func.func @transpose_rank_mismatch_0d(%arg0: vector<f32>) {
+  // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
+  %0 = vector.transpose %arg0, [] : vector<f32> to vector<100xf32>
+}
+
+// -----
+
 func.func @transpose_rank_mismatch(%arg0: vector<4x16x11xf32>) {
   // expected-error@+1 {{'vector.transpose' op vector result rank mismatch: 1}}
   %0 = vector.transpose %arg0, [2, 1, 0] : vector<4x16x11xf32> to vector<100xf32>
 }
 
 // -----
+func.func @transpose_length_mismatch_0d(%arg0: vector<f32>) {
+  // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 1}}
+  %0 = vector.transpose %arg0, [1] : vector<f32> to vector<f32>
+}
+
+// -----
 
 func.func @transpose_length_mismatch(%arg0: vector<4x4xf32>) {
   // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 3}}
index 4c3e322..e4e260a 100644 (file)
@@ -570,6 +570,22 @@ func.func @transpose_int(%arg0: vector<11x7x3x2xi32>) -> vector<2x11x7x3xi32> {
   return %0 : vector<2x11x7x3xi32>
 }
 
+// CHECK-LABEL: @transpose_fp_0d
+func.func @transpose_fp_0d(%arg0: vector<f32>) -> vector<f32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<f32> to vector<f32>
+  %0 = vector.transpose %arg0, [] : vector<f32> to vector<f32>
+  // CHECK: return %[[X]] : vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @transpose_int_0d
+func.func @transpose_int_0d(%arg0: vector<i32>) -> vector<i32> {
+  // CHECK: %[[X:.*]] = vector.transpose %{{.*}}, [] : vector<i32> to vector<i32>
+  %0 = vector.transpose %arg0, [] : vector<i32> to vector<i32>
+  // CHECK: return %[[X]] : vector<i32>
+  return %0 : vector<i32>
+}
+
 // CHECK-LABEL: @flat_transpose_fp
 func.func @flat_transpose_fp(%arg0: vector<16xf32>) -> vector<16xf32> {
   // CHECK: %[[X:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32>
index a29ab10..8a100dc 100644 (file)
@@ -120,6 +120,13 @@ func.func @fma_0d(%four: vector<f32>) {
   return
 }
 
+func.func @transpose_0d(%arg: vector<i32>) {
+  %1 = vector.transpose %arg, [] : vector<i32> to vector<i32>
+  // CHECK: ( 42 )
+  vector.print %1: vector<i32>
+  return
+}
+
 func.func @entry() {
   %0 = arith.constant 42.0 : f32
   %1 = arith.constant dense<0.0> : vector<f32>
@@ -151,6 +158,8 @@ func.func @entry() {
 
   %5 = arith.constant dense<4.0> : vector<f32>
   call  @fma_0d(%5) : (vector<f32>) -> ()
+  %6 = arith.constant dense<42> : vector<i32>
+  call @transpose_0d(%6) : (vector<i32>) -> ()
 
   return
 }