[MLIR] [Linalg] Add option to use the partial view after promotion.
authorPierre Oechsel <pierre.oechsel@gmail.com>
Mon, 18 May 2020 16:25:23 +0000 (18:25 +0200)
committerAlex Zinenko <zinenko@google.com>
Mon, 18 May 2020 16:28:18 +0000 (18:28 +0200)
For now the promoted buffer is indexed using the `full view`. The full view might be
slightly bigger than the partial view (which is accounting for boundaries).
Unfortunately this does not compose easily with other transformations when multiple buffers
with shapes related to each other are involved.
Take `linalg.matmul A B C` (with A of size MxK, B of size KxN and C of size MxN) and suppose we are:
- Tiling over M by 100
- Promoting A only

This is producing a `linalg.matmul promoted_A B subview_C` where `promoted_A` is a promoted buffer
of `A` of size (100xK) and `subview_C` is a subview of size mxK where m could be smaller than 100 due
to boundaries thus leading to a possible incorrect behavior.

We propose to:
- Add a new parameter to the tiling promotion allowing to enable the use of the full tile buffer.
- By default all promoted buffers will be indexed by the partial view.

Note that this could be considered as a breaking change in comparison to the way the tiling promotion
was working.

Differential Revision: https://reviews.llvm.org/D79927

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/test/Dialect/Linalg/promote.mlir
mlir/test/lib/Transforms/TestLinalgTransforms.cpp

index 70c3f00..e939771 100644 (file)
@@ -11,6 +11,7 @@
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
 namespace linalg {
@@ -97,6 +98,28 @@ struct LinalgPromotionOptions {
     operandsToPromote->insert(operands.begin(), operands.end());
     return *this;
   }
+  /// If ith element of `useFullTiles` is true the full view should be used for
+  /// the promoted buffer of the ith operand in `operandsToPromote`. Otherwise
+  /// the partial view will be used.
+  /// The decision is defaulted to `useFullTileBuffersDefault` when
+  /// `useFullTileBuffers` is None and for operands missing from
+  /// `useFullTileBuffers`.
+  Optional<llvm::SmallBitVector> useFullTileBuffers = None;
+  LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) {
+    unsigned size = useFullTiles.size();
+    llvm::SmallBitVector tmp(size, false);
+    for (unsigned i = 0; i < size; ++i)
+      tmp[i] = useFullTiles[i];
+    useFullTileBuffers = tmp;
+    return *this;
+  }
+  /// If true all operands unspecified by `useFullTileBuffers` will use the full
+  /// view, otherwise the partial view.
+  bool useFullTileBuffersDefault = false;
+  LinalgPromotionOptions &useFullTileBuffersByDefault() {
+    useFullTileBuffersDefault = true;
+    return *this;
+  }
   /// Allow the use of dynamicaly-sized buffers.
   bool dynamicBuffers = false;
   LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) {
index 7fa9099..6b5c4be 100644 (file)
@@ -16,6 +16,9 @@ namespace intrinsics {
 
 using vector_broadcast = ValueBuilder<vector::BroadcastOp>;
 using vector_contract = ValueBuilder<vector::ContractionOp>;
+using vector_insert = ValueBuilder<vector::InsertOp>;
+using vector_fma = ValueBuilder<vector::FMAOp>;
+using vector_extract = ValueBuilder<vector::ExtractOp>;
 using vector_matmul = ValueBuilder<vector::MatmulOp>;
 using vector_print = OperationBuilder<vector::PrintOp>;
 using vector_transfer_read = ValueBuilder<vector::TransferReadOp>;
index 5cbaa2f..44de2a1 100644 (file)
@@ -56,6 +56,8 @@ struct LinalgOpInstancePromotionOptions {
                                    const LinalgPromotionOptions &options);
   /// SubViews to promote.
   SetVector<Value> subViews;
+  /// True if the full view should be used for the promoted buffer.
+  DenseMap<Value, bool> useFullTileBuffers;
   /// Allow the use of dynamicaly-sized buffers.
   bool dynamicBuffers;
   /// Alignment of promoted buffer.
@@ -65,20 +67,28 @@ struct LinalgOpInstancePromotionOptions {
 
 LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
     LinalgOp linalgOp, const LinalgPromotionOptions &options)
-    : subViews(), dynamicBuffers(options.dynamicBuffers),
+    : subViews(), useFullTileBuffers(), dynamicBuffers(options.dynamicBuffers),
       alignment(options.alignment) {
+  unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
+  auto vUseFullTileBuffers =
+      options.useFullTileBuffers.getValueOr(llvm::SmallBitVector());
+  vUseFullTileBuffers.resize(nBuffers, options.useFullTileBuffersDefault);
+
   if (options.operandsToPromote.hasValue()) {
-    for (unsigned idx : options.operandsToPromote.getValue()) {
-      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
+    for (auto it : llvm::enumerate(options.operandsToPromote.getValue())) {
+      auto *op = linalgOp.getBuffer(it.value()).getDefiningOp();
+      if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
         subViews.insert(sv);
+        useFullTileBuffers[sv] = vUseFullTileBuffers[it.index()];
+      }
     }
   } else {
-    unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
     for (unsigned idx = 0; idx < nBuffers; ++idx) {
       auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
+      if (auto sv = dyn_cast_or_null<SubViewOp>(op)) {
         subViews.insert(sv);
+        useFullTileBuffers[sv] = vUseFullTileBuffers[idx];
+      }
     }
   }
 }
@@ -201,6 +211,9 @@ promoteSubViews(OpBuilder &b, Location loc,
     auto info = promotionInfoMap.find(v);
     if (info == promotionInfoMap.end())
       continue;
+    // Only fill the buffer if the full local view is used
+    if (!options.useFullTileBuffers[v])
+      continue;
     Value fillVal;
     if (auto t = subView.getType().getElementType().dyn_cast<FloatType>())
       fillVal = folded_std_constant(folder, FloatAttr::get(t, 0.0));
@@ -244,7 +257,10 @@ static void promoteSubViews(OpBuilder &b, LinalgOp op,
   unsigned promotedIdx = 0;
   for (auto view : op.getInputsAndOutputBuffers()) {
     if (options.subViews.count(view) != 0) {
-      opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
+      if (options.useFullTileBuffers[view])
+        opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
+      else
+        opViews.push_back(promotedBufferAndViews[promotedIdx].partialLocalView);
       writebackViews.emplace_back(std::make_pair(
           view, promotedBufferAndViews[promotedIdx].partialLocalView));
       promotedIdx++;
index 6453473..27364b0 100644 (file)
@@ -56,14 +56,11 @@ func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //     DYNAMIC:         std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf32>
 //       CHECK:         %[[partialC:.*]] = subview %[[fullC]]{{.*}} : memref<?x?xf32> to memref<?x?xf32, #[[strided2D_dynamic]]>
 
-//       CHECK:         linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32>, f32
-//       CHECK:         linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32>, f32
-//       CHECK:         linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32>, f32
 //       CHECK:         linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
+//       CHECK:         linalg.matmul(%[[partialA]], %[[partialB]], %[[partialC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D_dynamic]]>
 //
 //       CHECK:         linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2D_dynamic]]>, memref<?x?xf32, #[[strided2D]]>
 //
@@ -121,14 +118,11 @@ func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //     DYNAMIC:         std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xf64>
 //       CHECK:         %[[partialC_f64:.*]] = subview %[[fullC_f64]][%{{.*}}, %{{.*}}] : memref<?x?xf64> to memref<?x?xf64, #[[strided2D_dynamic]]>
 
-//       CHECK:         linalg.fill(%[[fullA_f64]], {{.*}}) : memref<?x?xf64>, f64
-//       CHECK:         linalg.fill(%[[fullB_f64]], {{.*}}) : memref<?x?xf64>, f64
-//       CHECK:         linalg.fill(%[[fullC_f64]], {{.*}}) : memref<?x?xf64>, f64
 //       CHECK:         linalg.copy(%[[vA_f64]], %[[partialA_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vB_f64]], %[[partialB_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul(%[[fullA_f64]], %[[fullB_f64]], %[[fullC_f64]]) : memref<?x?xf64>, memref<?x?xf64>, memref<?x?xf64>
+//       CHECK:         linalg.matmul(%[[partialA_f64]], %[[partialB_f64]], %[[partialC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D_dynamic]]>
 //
 //       CHECK:         linalg.copy(%[[partialC_f64]], %[[vC_f64]]) : memref<?x?xf64, #[[strided2D_dynamic]]>, memref<?x?xf64, #[[strided2D]]>
 //
@@ -186,14 +180,11 @@ func @matmul_i32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
 //     DYNAMIC:         std.view %{{.*}}[{{.*}}][{{.*}}] : memref<?xi8> to memref<?x?xi32>
 //       CHECK:         %[[partialC_i32:.*]] = subview %[[fullC_i32]][%{{.*}}, %{{.*}}] : memref<?x?xi32> to memref<?x?xi32, #[[strided2D_dynamic]]>
 
-//       CHECK:         linalg.fill(%[[fullA_i32]], {{.*}}) : memref<?x?xi32>, i32
-//       CHECK:         linalg.fill(%[[fullB_i32]], {{.*}}) : memref<?x?xi32>, i32
-//       CHECK:         linalg.fill(%[[fullC_i32]], {{.*}}) : memref<?x?xi32>, i32
 //       CHECK:         linalg.copy(%[[vA_i32]], %[[partialA_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vB_i32]], %[[partialB_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //       CHECK:         linalg.copy(%[[vC_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //
-//       CHECK:         linalg.matmul(%[[fullA_i32]], %[[fullB_i32]], %[[fullC_i32]]) : memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>
+//       CHECK:         linalg.matmul(%[[partialA_i32]], %[[partialB_i32]], %[[partialC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D_dynamic]]>
 //
 //       CHECK:         linalg.copy(%[[partialC_i32]], %[[vC_i32]]) : memref<?x?xi32, #[[strided2D_dynamic]]>, memref<?x?xi32, #[[strided2D]]>
 //
index 0390ac9..87191d3 100644 (file)
@@ -132,13 +132,20 @@ static void applyPatterns(FuncOp funcOp) {
   // Linalg subview operands promotion.
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx, LinalgPromotionOptions(),
+      ctx, LinalgPromotionOptions().useFullTileBuffersByDefault(),
       LinalgMarker({"_promote_views_"}, "_views_promoted_"));
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx, LinalgPromotionOptions().setOperandsToPromote({0}),
+      ctx,
+      LinalgPromotionOptions()
+          .setOperandsToPromote({0})
+          .useFullTileBuffersByDefault(),
       LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
   patterns.insert<LinalgPromotionPattern<FillOp>>(
-      ctx, LinalgPromotionOptions().setOperandsToPromote({0}).setAlignment(32),
+      ctx,
+      LinalgPromotionOptions()
+          .setOperandsToPromote({0})
+          .setUseFullTileBuffers({true})
+          .setAlignment(32),
       LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
 
   applyPatternsAndFoldGreedily(funcOp, patterns);
@@ -171,7 +178,8 @@ void fillL1TilingAndMatmulToVectorPatterns(
       LinalgMarker({startMarker}, "L1")));
 
   patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
-      context, LinalgPromotionOptions(), LinalgMarker({"L1"}, "VEC")));
+      context, LinalgPromotionOptions().useFullTileBuffersByDefault(),
+      LinalgMarker({"L1"}, "VEC")));
 
   patternsVector.emplace_back(
       LinalgVectorizationPattern<MatmulOp>(context, LinalgMarker({"VEC"})));