[mlir][tensor][bufferize] Lower tensor.generate to linalg.map
authorMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 09:54:01 +0000 (11:54 +0200)
committerMatthias Springer <springerm@google.com>
Thu, 27 Oct 2022 10:03:13 +0000 (12:03 +0200)
There is no memref equivalent of tensor.generate. The purpose of this change is to avoid creating scf.parallel loops during bufferization.

Differential Revision: https://reviews.llvm.org/D136767

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/test/Dialect/Tensor/bufferize.mlir
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index ac24d0d..5cddd3c 100644 (file)
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -487,6 +488,57 @@ struct FromElementsOpInterface
   }
 };
 
+/// Lower the body of a tensor.generate like op (one index-typed bbArg per dim).
+/// Such ops are lowered to linalg.map with the given tensor as a destination.
+///
+/// Example:
+/// ```
+/// %r = tensor.generate %x, %y {
+///   ^bb0(%arg0: index, %arg1: index):
+///   %0 = "some_op"(%arg0, %arg1) : (index, index) -> (index)
+///   tensor.yield %0 : index
+/// } : tensor<?x?xindex>
+/// ```
+///
+/// Is lowered to:
+/// ```
+/// linalg.map ins() outs(%dest) {
+///   %d0 = linalg.index 0 : index
+///   %d1 = linalg.index 1 : index
+///   %0 = "some_op"(%d0, %d1) : (index, index) -> (index)
+///   linalg.yield %0 : index
+/// }
+/// ```
+static Value lowerGenerateLikeOpBody(RewriterBase &rewriter, Location loc,
+                                     Value tensorDestination,
+                                     ValueRange dynamicSizes,
+                                     Region &generateBody) {
+  assert(generateBody.hasOneBlock() && "expected body with single block");
+  auto tensorType = tensorDestination.getType().cast<RankedTensorType>();
+  assert(generateBody.getNumArguments() == tensorType.getRank() &&
+         "rank mismatch");
+
+  // Create linalg::MapOp.
+  OpBuilder::InsertionGuard g(rewriter);
+  auto linalgOp =
+      rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
+                                     /*init=*/tensorDestination);
+  Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+
+  // Create linalg::IndexOps.
+  rewriter.setInsertionPointToStart(&linalgBody);
+  SmallVector<Value> indices;
+  for (int64_t dim = 0; dim < tensorType.getRank(); ++dim)
+    indices.push_back(rewriter.create<linalg::IndexOp>(loc, dim));
+
+  // Move over body.
+  rewriter.mergeBlocks(&generateBody.front(), &linalgBody, indices);
+  auto yieldOp = cast<tensor::YieldOp>(linalgBody.getTerminator());
+  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());
+
+  return linalgOp.getResult()[0];
+}
+
 /// Bufferization of tensor.generate.
 struct GenerateOpInterface
     : public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
@@ -507,57 +559,19 @@ struct GenerateOpInterface
     if (options.defaultMemorySpace != static_cast<unsigned>(0))
       return op->emitError("memory space not implemented yet");
 
-    auto tensorType = generateOp.getType().cast<RankedTensorType>();
     // Allocate memory.
     Location loc = op->getLoc();
-    // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     FailureOr<Value> tensorAlloc =
         allocateTensorForShapedValue(rewriter, loc, generateOp.getResult(),
                                      /*escape=*/!dealloc, options,
                                      /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    auto memrefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-    Value buffer = rewriter.create<bufferization::ToMemrefOp>(
-        op->getLoc(), memrefType, *tensorAlloc);
-
-    // Collect loop bounds.
-    int64_t rank = memrefType.getRank();
-    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    SmallVector<Value, 4> lowerBounds(rank, zero);
-    SmallVector<Value, 4> steps(rank, one);
-    SmallVector<Value, 4> upperBounds;
-    int nextDynamicIndex = 0;
-    for (int i = 0; i < rank; i++) {
-      Value upperBound =
-          memrefType.isDynamicDim(i)
-              ? generateOp.getDynamicExtents()[nextDynamicIndex++]
-              : rewriter.create<arith::ConstantIndexOp>(
-                    loc, memrefType.getDimSize(i));
-      upperBounds.push_back(upperBound);
-    }
-
-    // Generate tensor elements with a parallel loop that stores into
-    // each element of the resulting memref. We use mergeBlockBefore to "move"
-    // this op's body into the scf.parallel's body.
-    auto parallel =
-        rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, steps);
-    Block *parallelBody = parallel.getBody();
-    rewriter.mergeBlockBefore(&generateOp.getBody().front(),
-                              parallelBody->getTerminator(),
-                              parallelBody->getArguments());
-    // Replace the inlined yield op with a store op. The scf.parallel's builder
-    // already populated an scf.yield at the end, so we don't need to worry
-    // about creating that.
-    Operation *elementYield = parallelBody->getTerminator()->getPrevNode();
-    rewriter.setInsertionPointAfter(elementYield);
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        elementYield, elementYield->getOperands()[0], buffer,
-        parallelBody->getArguments());
 
-    replaceOpWithBufferizedValues(rewriter, op, buffer);
+    Value result = lowerGenerateLikeOpBody(rewriter, loc, *tensorAlloc,
+                                           generateOp.getDynamicExtents(),
+                                           generateOp.getBody());
+    rewriter.replaceOp(generateOp, result);
 
     return success();
   }
@@ -1019,6 +1033,6 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
 
     // Load additional dialects of which ops may get created.
-    ctx->loadDialect<arith::ArithDialect, scf::SCFDialect>();
+    ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
   });
 }
index a18d1de..75216e7 100644 (file)
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRIR
+  MLIRLinalgDialect
   MLIRMemRefDialect
   MLIRPass
   MLIRSCFDialect
index 2584044..d112725 100644 (file)
@@ -186,19 +186,17 @@ func.func @tensor.from_elements_3d(%f0 : f32) -> tensor<3x2x2xf32> {
 // -----
 
 // CHECK-LABEL:   func @tensor.generate(
-// CHECK-SAME:                                       %[[ARG:.*]]: tensor<*xf32>,
-// CHECK-SAME:                                       %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[CASTED:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
-// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
-// CHECK:           scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) {
-// CHECK:             %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32>
-// CHECK:             store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref<?xindex>
-// CHECK:             scf.yield
+// CHECK-SAME:        %[[ARG:.*]]: tensor<*xf32>,
+// CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<?xindex> {
+// CHECK-DAG:       %[[ARG_M:.*]] = bufferization.to_memref %[[ARG]] : memref<*xf32>
+// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<?xindex>
+// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<?xindex>)() {
+// CHECK:             %[[INDEX:.*]] = linalg.index 0 : index
+// CHECK:             %[[ELEM:.*]] = memref.dim %[[ARG_M]], %[[INDEX]] : memref<*xf32>
+// CHECK:             linalg.yield %[[ELEM]]
 // CHECK:           }
-// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xindex>
-// CHECK:           return %[[RET]] : tensor<?xindex>
+// CHECK:           return %[[MAPPED]] : tensor<?xindex>
 // CHECK:         }
 func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tensor<?xindex> {
   %result = tensor.generate %dynamic_extent {
@@ -216,17 +214,15 @@ func.func @tensor.generate(%arg: tensor<*xf32>, %dynamic_extent: index) -> tenso
 //
 // CHECK-LABEL:   func @tensor.generate_static_and_dynamic(
 // CHECK-SAME:        %[[DYNAMIC_EXTENT:.*]]: index) -> tensor<16x?xindex> {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
-// CHECK-DAG:       %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
-// CHECK:           scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]]) to (%[[C16]], %[[DYNAMIC_EXTENT]]) step (%[[C1]], %[[C1]]) {
-// CHECK:             %[[VAL_7:.*]] = arith.addi %[[I]], %[[J]] : index
-// CHECK:             store %[[VAL_7]], %[[MEMREF]][%[[I]], %[[J]]] : memref<16x?xindex>
-// CHECK:             scf.yield
+// CHECK:           %[[ALLOC:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) {{.*}} : memref<16x?xindex>
+// CHECK:           %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           %[[MAPPED:.*]] = linalg.map outs(%[[ALLOC_T]] : tensor<16x?xindex>)() {
+// CHECK:             %[[INDEX0:.*]] = linalg.index 0
+// CHECK:             %[[INDEX1:.*]] = linalg.index 1
+// CHECK:             %[[ADD:.*]] = arith.addi %[[INDEX0]], %[[INDEX1]]
+// CHECK:             linalg.yield %[[ADD]]
 // CHECK:           }
-// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<16x?xindex>
-// CHECK:           return %[[RET]] : tensor<16x?xindex>
+// CHECK:           return %[[MAPPED]] : tensor<16x?xindex>
 // CHECK:         }
 func.func @tensor.generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> {
   %result = tensor.generate %arg0 {
@@ -541,7 +537,7 @@ func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
 
 // -----
 
-// CHECK:       #[[$sum_map:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
+// CHECK:       #[[$sum_map:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
 // CHECK-LABEL: func @tensor.pad(
 //  CHECK-SAME:   %[[t1:.*]]: tensor<?x10xindex>, %[[l2:.*]]: index, %[[h1:.*]]: index, %[[h2:.*]]: index
 func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
@@ -555,10 +551,15 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   // CHECK-DAG: %[[size0:.*]] = affine.apply #[[$sum_map]]()[%[[dim0]], %[[c5]], %[[h1]]]
   // CHECK-DAG: %[[size1:.*]] = affine.apply #[[$sum_map]]()[%[[dim1]], %[[l2]], %[[h2]]]
   // CHECK:     %[[alloc:.*]] = memref.alloc(%[[size0]], %[[size1]]) {{.*}} : memref<?x?xindex>
-  // CHECK:     scf.parallel ({{.*}}) = (%[[c0]], %[[c0]]) to (%[[size0]], %[[size1]]) step (%[[c1]], %[[c1]]) {
-  // CHECK:       memref.store
+  // CHECK:     %[[alloc_t:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK:     %[[mapped:.*]] = linalg.map outs(%[[alloc_t]] : tensor<?x?xindex>)() {
+  // CHECK:       %[[index0:.*]] = linalg.index 0
+  // CHECK:       %[[index1:.*]] = linalg.index 1
+  // CHECK:       %[[mul:.*]] = arith.muli %[[index0]], %[[index1]]
+  // CHECK:       linalg.yield %[[mul]]
   // CHECK:     }
-  // CHECK:     %[[subview:.*]] = memref.subview %[[alloc]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
+  // CHECK:     %[[mapped_m:.*]] = bufferization.to_memref %[[mapped]]
+  // CHECK:     %[[subview:.*]] = memref.subview %[[mapped_m]][5, %[[l2]]] [%[[dim0]], 10] [1, 1]
   // CHECK:     memref.copy %[[m1]], %[[subview]]
   %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
   ^bb0(%arg0: index, %arg1: index):
@@ -566,7 +567,7 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
     tensor.yield %m : index
   } : tensor<?x10xindex> to tensor<?x?xindex>
 
-  // CHECK:     %[[r:.*]] = bufferization.to_tensor %[[alloc]]
+  // CHECK:     %[[r:.*]] = bufferization.to_tensor %[[mapped_m]]
   // CHECK:     return %[[r]] : tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
index e800270..0d9ade7 100644 (file)
@@ -211,8 +211,7 @@ func.func @dealloc_generate_buffer(%arg: tensor<*xf32>, %sz: index, %idx: index)
   -> index
 {
   // CHECK: memref.alloc
-  // CHECK: scf.parallel
-  // CHECK: memref.load
+  // CHECK: linalg.map
   // CHECK: memref.dealloc
   %0 = tensor.generate %sz {
   ^bb0(%i : index):
@@ -229,8 +228,7 @@ func.func @dealloc_generate_buffer(%arg: tensor<*xf32>, %sz: index, %idx: index)
 func.func @dealloc_pad_buffer(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
                               %h2: index, %idx: index) -> index {
   // CHECK: memref.alloc
-  // CHECK: scf.parallel
-  // CHECK: memref.load
+  // CHECK: linalg.map
   // CHECK: memref.dealloc
   %0 = tensor.pad %t1 low[5, %l2] high[%h1, %h2] {
   ^bb0(%arg0: index, %arg1: index):
index e60e712..112caa5 100644 (file)
@@ -5293,6 +5293,7 @@ cc_library(
         ":DialectUtils",
         ":FuncDialect",
         ":IR",
+        ":LinalgDialect",
         ":MemRefDialect",
         ":Pass",
         ":SCFDialect",