[MLIR] Vectorize tensor.extract on n-D tensor (n >= 2)
authorAndrzej Warzynski <andrzej.warzynski@arm.com>
Wed, 19 Oct 2022 17:04:31 +0000 (17:04 +0000)
committerAndrzej Warzynski <andrzej.warzynski@arm.com>
Mon, 12 Dec 2022 09:32:16 +0000 (09:32 +0000)
This patch implements the vectorization of tensor.extract for arbitrary
tensors. It basically extends https://reviews.llvm.org/D133786 by adding
support for n-D tensors (n >= 2). This is implemented by essentially
flattening the indices.

When benchmarking the vectorized code, we have observed that it is
slower than the scalar code. That's most likely due to sub-optimal (and,
in general slow) gather loads. More work is needed to identify an
implementation and/or a representation that would lead to better code.
In the meantime, the vectorization of n-D tensors (where n >= 2) has to
be explicitly enabled. This can be done either via:
  * transfer dialect's `vectorize_nd_extract` attribute,
  * dedicated bool argument in the `vectorize` method from
    "Vectorization.cpp".
The second option was added to control the new functionality through
means other than the transfer dialect.

Related discussion: https://github.com/iree-org/iree/issues/9198

Differential Revision: https://reviews.llvm.org/D137660

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir

index f2b3fb7..e6bfa08 100644 (file)
@@ -1091,6 +1091,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
 
   let arguments = (ins PDL_Operation:$target,
                    UnitAttr:$vectorize_padding,
+                   UnitAttr:$vectorize_nd_extract,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs PDL_Operation:$transformed);
@@ -1098,7 +1099,9 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   let assemblyFormat = "$target attr-dict";
 
   let builders = [
-    OpBuilder<(ins "Value":$target, CArg<"bool", "false">:$vectorizePadding)>
+    OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$vectorizePadding,
+               CArg<"bool", "false">:$vectorizeNDExtract)>,
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
index 81ec026..7d6a584 100644 (file)
@@ -345,7 +345,8 @@ FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp);
+LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp,
+                        bool vectorizeNDExtract = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
@@ -371,7 +372,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
 /// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp);
+LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                                            bool vectorizeNDExtract = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
@@ -865,6 +867,9 @@ protected:
 void populatePadOpVectorizationPatterns(RewritePatternSet &patterns,
                                         PatternBenefit baseBenefit = 1);
 
+void populateExtractOpVectorizationPatterns(RewritePatternSet &patterns,
+                                            PatternBenefit baseBenefit = 1);
+
 /// Match and rewrite for the pattern:
 /// ```
 ///    %alloc = ...
index 8fdd6cb..d35d96a 100644 (file)
@@ -1781,12 +1781,17 @@ void transform::TileToScfForOp::getEffects(
 //===----------------------------------------------------------------------===//
 
 void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
-                                   Value target, bool vectorizePadding) {
+                                   Value target, bool vectorizePadding,
+                                   bool vectorizeExtract) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
                         builder.getUnitAttr());
   }
+  if (vectorizeExtract) {
+    result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name),
+                        builder.getUnitAttr());
+  }
   result.addTypes(pdl::OperationType::get(builder.getContext()));
 }
 
@@ -1794,15 +1799,22 @@ namespace {
 /// This is an helper only to call vectorize via a pattern inside of
 /// VectorizeOp::applyToOne.
 struct VectorizationPattern : public RewritePattern {
-  explicit VectorizationPattern(MLIRContext *context)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  explicit VectorizationPattern(MLIRContext *context,
+                                bool vectorizeExtract = false)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
+        vectorizeNDExtract(vectorizeExtract) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
-    return vectorize(rewriter, linalgOp);
+    return vectorize(rewriter, linalgOp, vectorizeNDExtract);
   }
+
+private:
+  /// Controls whether to vectorize `tensor.extract` when the input tensor is
+  /// rank >= 2.
+  bool vectorizeNDExtract = false;
 };
 } // namespace
 
@@ -1818,7 +1830,7 @@ transform::VectorizeOp::applyToOne(Operation *target,
 
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<VectorizationPattern>(ctx);
+  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
 
   if (!getDisableTransferPermutationMapLoweringPatterns())
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
index ad52a5a..a7c3c00 100644 (file)
@@ -242,7 +242,7 @@ static Value buildVectorWrite(OpBuilder &b, Value value,
 // with CustomVectorizationHook. Returns success if the corresponding custom
 // hook can vectorize the op.
 using CustomVectorizationPrecondition =
-    std::function<LogicalResult(Operation *)>;
+    std::function<LogicalResult(Operation *, bool)>;
 
 // Custom vectorization function type. Produce a vector form of Operation*
 // assuming all its vectorized operands are already in the BlockAndValueMapping.
@@ -314,13 +314,13 @@ vectorizeLinalgIndex(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp) {
 
 /// Helper function to check if the tensor.extract can be vectorized by the
 /// custom hook vectorizeTensorExtract.
-static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
+static LogicalResult
+tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
   tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
   if (!extractOp)
     return failure();
 
-  // Currently only supports extraction with an 1-D index.
-  if (extractOp.getIndices().size() != 1)
+  if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
     return failure();
 
   if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
@@ -335,6 +335,51 @@ static LogicalResult tensorExtractVectorizationPrecondition(Operation *op) {
   return success();
 }
 
+/// Calculates the offsets (`$index_vec`) for `vector.gather` operations
+/// generated from `tensor.extract`. The offset is calculated as follows
+/// (example using scalar values):
+///
+///    offset = extractOp.indices[0]
+///    for (i = 1; i < numIndices; i++)
+///      offset = extractOp.dimSize[i] * offset + extractOp.indices[i];
+///
+/// For tensor<45 x 80 x 15 x f32> and index [1, 2, 3], this leads to:
+///  offset = ( ( 1 ) * 80 +  2 ) * 15  + 3
+static Value
+calculateGatherOffset(OpBuilder &b, tensor::ExtractOp extractOp,
+                      const BlockAndValueMapping &bvm,
+                      const SmallVectorImpl<int64_t> &targetShape) {
+  // The vector of indices for GatherOp should be shaped as the output vector
+  auto indexVecType = VectorType::get(targetShape, b.getIndexType());
+  auto loc = extractOp.getLoc();
+
+  Value offset = b.create<vector::BroadcastOp>(
+      loc, indexVecType, bvm.lookup(extractOp.getIndices()[0]));
+
+  const size_t numIndices = extractOp.getIndices().size();
+  for (size_t i = 1; i < numIndices; i++) {
+    auto dimSizeBcast = b.create<vector::BroadcastOp>(
+        loc, indexVecType,
+        b.create<arith::ConstantIndexOp>(
+            loc,
+            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)));
+    offset = b.create<arith::MulIOp>(loc, offset, dimSizeBcast);
+
+    auto originalIndexBcast = bvm.lookup(extractOp.getIndices()[i]);
+    if (i == numIndices - 1) {
+      // We only need an additional broadcast for the trailing index. All other
+      // indices have already been broadcast by `vectorizeLinalgIndex` to match
+      // the output size.
+      originalIndexBcast = b.create<vector::BroadcastOp>(
+          loc, indexVecType, bvm.lookup(extractOp.getIndices()[i]));
+    }
+
+    offset = b.create<arith::AddIOp>(loc, originalIndexBcast, offset);
+  }
+
+  return offset;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -347,29 +392,29 @@ vectorizeTensorExtract(RewriterBase &rewriter, Operation *op, LinalgOp linalgOp,
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
   auto loc = extractOp.getLoc();
 
-  // Currently only supports extraction with an 1-D index. Checked in the
-  // tensorExtractVectorizationPrecondition.
-  assert(extractOp.getIndices().size() == 1);
-
-  auto indexVec = bvm.lookup(extractOp.getIndices()[0]);
   // Compute the static loop sizes of the extract op.
   auto targetShape = linalgOp.computeStaticLoopSizes();
 
-  SmallVector<Value> gatherIndices;
-  gatherIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-
+  auto resultType =
+      VectorType::get(targetShape, extractOp.getResult().getType());
   auto maskConstantOp = rewriter.create<arith::ConstantOp>(
       loc, DenseIntElementsAttr::get(
                VectorType::get(targetShape, rewriter.getI1Type()),
                /*value=*/true));
-
-  auto resultType =
-      VectorType::get(targetShape, extractOp.getResult().getType());
   auto passThruConstantOp =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType));
 
+  // Base indices are currently set to 0. We will need to re-visit if more
+  // generic scenarios are to be supported.
+  SmallVector<Value> baseIndices(
+      extractOp.getIndices().size(),
+      rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+  // Generate the gather load
   auto gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), gatherIndices, indexVec,
+      loc, resultType, extractOp.getTensor(), baseIndices, offset,
       maskConstantOp, passThruConstantOp);
 
   return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
@@ -586,7 +631,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
   };
   hooks.push_back(vectorizeYield);
 
-  // 4rewriter. Register CustomVectorizationHook for indexOp.
+  // 4b. Register CustomVectorizationHook for indexOp.
   CustomVectorizationHook vectorizeIndex =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
@@ -642,7 +687,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 
 static LogicalResult vectorizeStaticLinalgOpPrecondition(
     linalg::LinalgOp op,
-    ArrayRef<CustomVectorizationPrecondition> customPreconditions) {
+    ArrayRef<CustomVectorizationPrecondition> customPreconditions,
+    bool vectorizeNDExtract) {
 
   // All types in the body should be a supported element type for VectorType.
   for (Operation &innerOp : op->getRegion(0).front()) {
@@ -650,7 +696,8 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(
     if (llvm::any_of(
             customPreconditions,
             [&](const CustomVectorizationPrecondition &customPrecondition) {
-              return succeeded(customPrecondition(&innerOp));
+              return succeeded(
+                  customPrecondition(&innerOp, vectorizeNDExtract));
             })) {
       continue;
     }
@@ -686,7 +733,9 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(
   return success();
 }
 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
+LogicalResult
+mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                                            bool vectorizeNDExtract) {
   // All types must be static shape to go to vector.
   if (linalgOp.hasDynamicShape()) {
     LDBG("precondition failed: dynamic shape");
@@ -698,12 +747,13 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
   // Register CustomVectorizationPrecondition for extractOp.
   customPreconditions.push_back(tensorExtractVectorizationPrecondition);
 
-  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions);
+  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions,
+                                             vectorizeNDExtract);
 }
 
-LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
-                                      LinalgOp linalgOp) {
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                                      bool vectorizeNDExtract) {
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
     return failure();
 
   SmallVector<Value> results;
@@ -713,7 +763,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter,
   if (succeeded(convOr)) {
     llvm::append_range(results, (*convOr)->getResults());
   } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp)))
+    if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
       return failure();
     LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
     if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
index 3b351dc..a9a536a 100644 (file)
@@ -1500,7 +1500,7 @@ transform.sequence failures(propagate) {
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
 #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
 #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
+func.func @vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor<4x3xi32>, %arg2: tensor<4x3xi32>, %arg3: tensor<4x7x2xf32>, %arg4: tensor<4x7x3x2xf32>) -> tensor<4x7x3x2xf32> {
   %2 = linalg.generic {
     indexing_maps = [#map0, #map0, #map1, #map2],
     iterator_types = ["parallel", "parallel", "parallel", "parallel"]
@@ -1513,14 +1513,34 @@ func.func @not_vectorize_nd_tensor_extract(%arg0: tensor<3x3xf32>, %arg1: tensor
   } -> tensor<4x7x3x2xf32>
   return %2 : tensor<4x7x3x2xf32>
 }
-// CHECK-LABEL: func.func @not_vectorize_nd_tensor_extract
-// CHECK: tensor.extract
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract
+// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3xf32>
+// CHECK-SAME: %[[ARG1:arg1]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG2:arg2]]: tensor<4x3xi32>
+// CHECK-SAME: %[[ARG3:.*]]: tensor<4x7x2xf32>
+// CHECK-SAME: %[[ARG4:.*]]: tensor<4x7x3x2xf32>
+// CHECK:    %[[C0:.*]] = arith.constant 0 : index
+// CHECK:    %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK:    %[[CST:.*]] = arith.constant dense<3> : vector<7x2x4x3xindex>
+// CHECK:    %[[CST_1:.*]] = arith.constant dense<true> : vector<4x7x3x2xi1>
+// CHECK:    %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<4x7x3x2xf32>
+// CHECK:    %[[V0:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK:    %[[V1:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[C0_i32]] {in_bounds = [true, true]} : tensor<4x3xi32>, vector<4x3xi32>
+// CHECK:    %[[CAST:.*]] = arith.index_cast %[[V0]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK:    %[[B1:.*]] = vector.broadcast %[[CAST]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK:    %[[CAST_1:.*]] = arith.index_cast %[[V1]] : vector<4x3xi32> to vector<4x3xindex>
+// CHECK:    %[[B2:.*]] = vector.broadcast %[[CAST_1]] : vector<4x3xindex> to vector<7x2x4x3xindex>
+// CHECK:    %[[MULI:.*]] = arith.muli %[[B1]], %[[CST]] : vector<7x2x4x3xindex>
+// CHECK:    %[[ADDI:.*]] = arith.addi %[[B2]], %[[MULI]] : vector<7x2x4x3xindex>
+// CHECK:    %[[T:.*]] = vector.transpose %[[ADDI]], [2, 0, 3, 1] : vector<7x2x4x3xindex> to vector<4x7x3x2xindex>
+// CHECK:    %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]]] [%[[T]]], %[[CST_1]], %[[PASSTHRU]] : tensor<3x3xf32>, vector<4x7x3x2xindex>, vector<4x7x3x2xi1>, vector<4x7x3x2xf32> into vector<4x7x3x2xf32>
+// CHECK:    vector.transfer_write %[[GATHER]], %[[ARG4]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x7x3x2xf32>, tensor<4x7x3x2xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-  %2 = transform.structured.vectorize %1
+  %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
 
 // -----