[mlir][Vector][NFC] Small vector masking clean-up
authorDiego Caballero <diegocaballero@google.com>
Mon, 3 Apr 2023 21:34:38 +0000 (21:34 +0000)
committerDiego Caballero <diegocaballero@google.com>
Mon, 3 Apr 2023 21:58:27 +0000 (21:58 +0000)
We stored static (int) and dynamic (Value) iteration space dims separately
and then merged them by creating constant ops for the static ones. This
merge happened multiple times during vectorization. This PR changes that
to perform the merge once and store in the state instead of the dynamic
values in isolation.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D147351

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index 98ee5e2..7f8e033 100644 (file)
@@ -196,11 +196,11 @@ private:
     iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
   }
 
-  /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
-  /// iteration space to be vectorized and store them in
-  /// `iterSpaceDynamicSizes`.
-  LogicalResult precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
-                                                LinalgOp linalgOp);
+  /// Generates 'arith.constant' and 'tensor/memref.dim' operations for
+  /// all the static and dynamic dimensions of the iteration space to be
+  /// vectorized and store them in `iterSpaceValueSizes`.
+  LogicalResult precomputeIterSpaceValueSizes(RewriterBase &rewriter,
+                                              LinalgOp linalgOp);
 
   /// Create or retrieve an existing mask value to mask `opToMask` in the
   /// canonical vector iteration space. If `maybeMaskingMap` the mask is
@@ -214,9 +214,10 @@ private:
   // Dynamic dimensions are represented using ShapedType::kDynamic.
   SmallVector<int64_t> iterSpaceStaticSizes;
 
-  /// Holds the runtime sizes of the iteration spaces to vectorize. Static
-  /// dimensions are represented with a empty value.
-  SmallVector<Value> iterSpaceDynamicSizes;
+  /// Holds the value sizes of the iteration space to vectorize. Static
+  /// dimensions are represented by 'arith.constant' and dynamic
+  /// dimensions by 'tensor/memref.dim'.
+  SmallVector<Value> iterSpaceValueSizes;
 
   /// Holds the canonical vector shape used to vectorize the iteration space.
   SmallVector<int64_t> canonicalVecShape;
@@ -230,17 +231,15 @@ private:
   OpBuilder::InsertionGuard rewriterGuard;
 };
 
-/// Generates 'tensor.dim' operations for all the dynamic dimensions of the
-/// iteration space to be vectorized and store them in
-/// `iterSpaceDynamicSizes`.
 LogicalResult
-VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
-                                                    LinalgOp linalgOp) {
+VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter,
+                                                  LinalgOp linalgOp) {
   // TODO: Support 0-d vectors.
   for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
     if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
-      // Add a empty value for static dimensions.
-      iterSpaceDynamicSizes.push_back(Value());
+      // Create constant index op for static dimensions.
+      iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>(
+          linalgOp.getLoc(), iterSpaceStaticSizes[vecDim]));
       continue;
     }
 
@@ -257,7 +256,7 @@ VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
                                  linalgOp.getLoc(), operand, operandDimPos)
                            : (Value)rewriter.create<memref::DimOp>(
                                  linalgOp.getLoc(), operand, operandDimPos);
-    iterSpaceDynamicSizes.push_back(dynamicDim);
+    iterSpaceValueSizes.push_back(dynamicDim);
   }
 
   return success();
@@ -293,9 +292,10 @@ VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
   // Initialize iteration space static sizes.
   initIterSpaceStaticSizes(linalgOp);
 
-  // Extract and register the runtime value of any potential dynamic shape
-  // needed to compute a mask during vectorization.
-  if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
+  // Generate 'arith.constant' and 'tensor/memref.dim' operations for
+  // all the static and dynamic dimensions of the iteration space, needed to
+  // compute a mask during vectorization.
+  if (failed(precomputeIterSpaceValueSizes(rewriter, linalgOp)))
     return failure();
 
   return success();
@@ -355,18 +355,9 @@ Value VectorizationState::getOrCreateMaskFor(
     return Value();
   }
 
-  // Compute the mask upper bound values by combining the permuted iteration
-  // space static sizes and the dynamic values.
-  SmallVector<Value> permutedDynamicSizes =
-      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
-  SmallVector<Value> upperBounds;
-  for (auto [staticBound, dynBound] :
-       llvm::zip(permutedStaticSizes, permutedDynamicSizes))
-    upperBounds.push_back(ShapedType::isDynamic(staticBound)
-                              ? dynBound
-                              : rewriter.create<arith::ConstantIndexOp>(
-                                    linalgOp.getLoc(), staticBound));
-
+  // Permute the iteration space value sizes to compute the mask upper bounds.
+  SmallVector<Value> upperBounds =
+      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceValueSizes));
   assert(!maskShape.empty() && !upperBounds.empty() &&
          "Masked 0-d vectors are not supported yet");
 
@@ -651,19 +642,19 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
   // Compute a one-dimensional index vector for the index op dimension.
   SmallVector<int64_t> constantSeq =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
-  auto constantOp = rewriter.create<arith::ConstantOp>(
+  auto indexSteps = rewriter.create<arith::ConstantOp>(
       loc, rewriter.getIndexVectorAttr(constantSeq));
   // Return the one-dimensional index vector if it lives in the trailing
   // dimension of the iteration space since the vectorization algorithm in this
   // case can handle the broadcast.
   if (indexOp.getDim() == targetShape.size() - 1)
-    return VectorizationResult{VectorizationStatus::NewOp, constantOp};
+    return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
   // Otherwise permute the targetShape to move the index dimension last,
   // broadcast the one-dimensional index vector to the permuted shape, and
   // finally transpose the broadcasted index vector to undo the permutation.
   std::swap(targetShape[indexOp.getDim()], targetShape.back());
   auto broadCastOp = rewriter.create<vector::BroadcastOp>(
-      loc, VectorType::get(targetShape, rewriter.getIndexType()), constantOp);
+      loc, VectorType::get(targetShape, rewriter.getIndexType()), indexSteps);
   SmallVector<int64_t> transposition =
       llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
   std::swap(transposition.back(), transposition[indexOp.getDim()]);
index 105d952..6821dd4 100644 (file)
@@ -2397,11 +2397,11 @@ func.func @vectorize_partial_dynamic_identity(%arg0: tensor<8x?xf32>,
 
 // CHECK-LABEL:   func.func @vectorize_partial_dynamic_identity(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x?xf32>, %[[VAL_1:.*]]: tensor<8x?xf32>, %[[VAL_2:.*]]: tensor<8x?xf32>) -> tensor<8x?xf32> {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32>
-// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_7:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<8x?xf32>
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_4]] : vector<8x32xi1>
 // CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]][%[[VAL_5]], %[[VAL_5]]], %[[VAL_6]] {in_bounds = [true, true]} : tensor<8x?xf32>, vector<8x32xf32> } : vector<8x32xi1> -> vector<8x32xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
@@ -2516,10 +2516,10 @@ func.func @vectorize_static_shape_with_mask(%arg0: tensor<8x30xf32>,
 
 // CHECK-LABEL:   func.func @vectorize_static_shape_with_mask(
 // CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x30xf32>, %[[VAL_1:.*]]: tensor<8x30xf32>, %[[VAL_2:.*]]: tensor<8x30xf32>) -> tensor<8x30xf32> {
-// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_5:.*]] = arith.constant 8 : index
-// CHECK:           %[[VAL_6:.*]] = arith.constant 30 : index
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 30 : index
 // CHECK:           %[[VAL_7:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_6]] : vector<8x32xi1>
 // CHECK:           %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %[[VAL_0]][%[[VAL_3]], %[[VAL_3]]], %[[VAL_4]] {in_bounds = [true, true]} : tensor<8x30xf32>, vector<8x32xf32> } : vector<8x32xi1> -> vector<8x32xf32>
 // CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32