[mlir][linalg] Fix bug in elementwise vectorization
authorThomas Raoux <thomasraoux@google.com>
Fri, 11 Dec 2020 15:03:30 +0000 (07:03 -0800)
committerThomas Raoux <thomasraoux@google.com>
Mon, 14 Dec 2020 18:44:36 +0000 (10:44 -0800)
Fix a bug causing to pick the wrong vector size to broadcast to when the source
vectors have different ranks.

Differential Revision: https://reviews.llvm.org/D93118

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index a28b90b..2df1a94 100644 (file)
@@ -216,6 +216,7 @@ private:
       if (!vecType)
         continue;
       if (maxSize < vecType.getNumElements()) {
+        maxSize = vecType.getNumElements();
         largestShape.assign(vecType.getShape().begin(),
                             vecType.getShape().end());
       }
index 1c35332..6019dde 100644 (file)
@@ -169,7 +169,7 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
     %11 = mulf %arg5, %8 : f32
     %12 = rsqrt %arg5 : f32
     %13 = select %7, %arg5, %arg6 : f32
-    %14 = subf %arg5, %arg6 : f32
+    %14 = subf %arg5, %arg4 : f32
     %15 = tanh %arg5 : f32
     linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32,
       f32, f32, f32, f32, f32, f32, f32, f32
@@ -196,7 +196,8 @@ func @generic_vectorize(%arg0: memref<4x256xf32>, %arg1: memref<4x256xf32>,
 //       CHECK:   %[[MUL:.*]] = mulf %[[V3]], %[[CST0]] : vector<4x256xf32>
 //       CHECK:   %[[RSQRT:.*]] = rsqrt %[[V3]] : vector<4x256xf32>
 //       CHECK:   %[[SEL:.*]] = select %[[CMP]], %[[V3]], %[[V1]] : vector<4x256xi1>, vector<4x256xf32>
-//       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V1]] : vector<4x256xf32>
+//       CHECK:   %[[V0B:.*]] = vector.broadcast %[[V0]] : vector<256xf32> to vector<4x256xf32>
+//       CHECK:   %[[SUB:.*]] = subf %[[V3]], %[[V0B]] : vector<4x256xf32>
 //       CHECK:   %[[TAN:.*]] = tanh %[[V3]] : vector<4x256xf32>
 //       CHECK:   vector.transfer_write %[[ADD]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>
 //       CHECK:   vector.transfer_write %[[CST0]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, memref<4x256xf32>