[mlir][Vector] Fix 0-D tensor vectorization in Linalg
authorDiego Caballero <diegocaballero@google.com>
Fri, 16 Jun 2023 23:21:24 +0000 (23:21 +0000)
committerDiego Caballero <diegocaballero@google.com>
Fri, 16 Jun 2023 23:45:03 +0000 (23:45 +0000)
It looks like scalable vector support broke vectorization for 0-D
tensors and we didn't have any test coverting that case. This patch
provides a fix and a test.

Differential Revision: https://reviews.llvm.org/D153181

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 685567d..bbcde44 100644 (file)
@@ -1199,38 +1199,43 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   //   a. Get the first max ranked shape.
   VectorType firstMaxRankedType;
   for (Value operand : op->getOperands()) {
-    auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
+    auto vecOperand = bvm.lookup(operand);
+    assert(vecOperand && "Vector operand couldn't be found");
+
+    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
     if (vecType && (!firstMaxRankedType ||
                     firstMaxRankedType.getRank() < vecType.getRank()))
       firstMaxRankedType = vecType;
   }
   //   b. Broadcast each op if needed.
-  SmallVector<Value> vectorizedOperands;
+  SmallVector<Value> vecOperands;
   for (Value scalarOperand : op->getOperands()) {
-    Value vectorizedOperand = bvm.lookup(scalarOperand);
-    auto vecType =
-        VectorType::get(firstMaxRankedType.getShape(),
-                        getElementTypeOrSelf(vectorizedOperand.getType()),
-                        firstMaxRankedType.getNumScalableDims());
-    vectorizedOperands.push_back(
-        !firstMaxRankedType
-            ? vectorizedOperand
-            : broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
+    Value vecOperand = bvm.lookup(scalarOperand);
+    assert(vecOperand && "Vector operand couldn't be found");
+
+    if (firstMaxRankedType) {
+      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
+                                     getElementTypeOrSelf(vecOperand.getType()),
+                                     firstMaxRankedType.getNumScalableDims());
+      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
+    } else {
+      vecOperands.push_back(vecOperand);
+    }
   }
   //   c. for elementwise, the result is the vector with the firstMaxRankedShape
   SmallVector<Type> resultTypes;
   for (Type resultType : op->getResultTypes()) {
     resultTypes.push_back(
-        !firstMaxRankedType
-            ? resultType
-            : VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getNumScalableDims()));
+        firstMaxRankedType
+            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+                              firstMaxRankedType.getNumScalableDims())
+            : resultType);
   }
   //   d. Build and return the new op.
   return VectorizationResult{
       VectorizationStatus::NewOp,
-      rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                      vectorizedOperands, resultTypes, op->getAttrs())};
+      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
+                      resultTypes, op->getAttrs())};
 }
 
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
index 404c349..130c6bc 100644 (file)
@@ -1719,3 +1719,35 @@ transform.sequence failures(propagate) {
   %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op
   %2 = transform.structured.vectorize %1  { vectorize_padding } : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+func.func @zero_dim_tensor(%input: tensor<f32>, %output: tensor<f32>) -> tensor<f32>
+{
+  %0 = linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
+                        iterator_types = [] }
+                        ins(%input : tensor<f32>)
+                        outs(%output : tensor<f32>) {
+    ^bb0(%arg0: f32, %arg1: f32):
+      %2 = arith.addf %arg0, %arg1 : f32
+      linalg.yield %2 : f32
+    } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @zero_dim_tensor
+//       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+//       CHECK:     vector.extractelement
+//       CHECK:     vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+//       CHECK:     vector.extractelement
+//       CHECK:     arith.addf {{.*}} : f32
+//       CHECK:     vector.broadcast %{{.*}} : f32 to vector<f32>
+//       CHECK:     vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
+