[mlir] Make linalg-bufferize a composable bufferization pass
authorSean Silva <silvasean@google.com>
Tue, 3 Nov 2020 01:18:46 +0000 (17:18 -0800)
committerSean Silva <silvasean@google.com>
Wed, 4 Nov 2020 18:16:55 +0000 (10:16 -0800)
Previously, linalg-bufferize was a "finalizing" bufferization pass (it
did a "full" conversion). This wasn't great because it couldn't be used
composably with other bufferization passes like std-bufferize and
scf-bufferize.

This patch makes linalg-bufferize a composable bufferization pass.
Notice that the integration tests are switched over to using a pipeline
of std-bufferize, linalg-bufferize, and (to finalize the conversion)
func-bufferize. It all "just works" together.

While doing this transition, I ran into a nasty bug in the 1-use special
case logic for forwarding init tensors. That logic, while
well-intentioned, was fundamentally flawed, because it assumed that if
the original tensor value had one use, then the converted memref could
be mutated in place. That assumption is wrong in many cases. For
example:

```
  %0 = some_tensor : tensor<4xf32>
  br ^bb0(%0, %0: tensor<4xf32>, tensor<4xf32>)
^bb0(%bbarg0: tensor<4xf32>, %bbarg1: tensor<4xf32>)
  // %bbarg0 is an alias of %bbarg1. We cannot safely write
  // to it without analyzing uses of %bbarg1.
  linalg.generic ... init(%bbarg0) {...}
```

A similar example can happen in many scenarios with function arguments.
Even more sinister, if the converted memref is produced by a
`std.get_global_memref` of a constant global memref, then we might
attempt to write into read-only statically allocated storage! Not all
memrefs are writable!

Clearly, this 1-use check is not a local transformation that we can do
on the fly in this pattern, so I removed it.

The test is now drastically shorter and I basically rewrote the CHECK
lines from scratch because:
- the new composable linalg-bufferize just doesn't do as much, so there
is less to test
- a lot of the tests were related to the 1-use check, which is now gone,
so there is less to test
- the `-buffer-hoisting -buffer-deallocation` is no longer mixed in, so
the checks related to that had to be rewritten

Differential Revision: https://reviews.llvm.org/D90657

mlir/integration_test/Dialect/Linalg/CPU/test-tensor-e2e.mlir
mlir/integration_test/Dialect/Linalg/CPU/test-tensor-matmul.mlir
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
mlir/test/Dialect/Linalg/bufferize.mlir

index b14418e..1ac09d6 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
index 548fab3..df099a3 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-opt %s -std-bufferize -linalg-bufferize -func-bufferize -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
@@ -25,4 +25,3 @@ func @main() {
 }
 
 func @print_memref_f32(%ptr : tensor<*xf32>)
-
index 2ad287e..026424b 100644 (file)
@@ -67,14 +67,8 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp,
     // under linalg on tensor based transformations.
     bool foldedInitTensor = resultIndex < linalgOp.getNumInitTensors();
     if (foldedInitTensor) {
-      // Dealing with an init tensor requires distinguishing between 1-use
-      // and many-use cases which would create aliasing and WAR hazards.
       Value initTensor = linalgOp.getInitTensor(resultIndex);
       Value initBuffer = adaptor.init_tensors()[resultIndex];
-      if (initTensor.hasOneUse()) {
-        resultBuffers.push_back(initBuffer);
-        continue;
-      }
       SmallVector<Value, 4> dynOperands;
       for (auto dim : llvm::enumerate(tensorShape)) {
         if (dim.value() == TensorType::kDynamicSize) {
@@ -187,17 +181,16 @@ static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
 }
 
 //===----------------------------------------------------------------------===//
-// Buffer allocation patterns.
+// Bufferization patterns.
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Generic BufferizeConversionPattern that matches any Operation* and
-/// dispatches internally. This avoids template instantiating one pattern for
-/// each LinalgOp op.
-class LinalgOpConverter : public BufferizeConversionPattern {
+/// Generic conversion pattern that matches any LinalgOp. This avoids template
+/// instantiating one pattern for each LinalgOp.
+class BufferizeAnyLinalgOp : public ConversionPattern {
 public:
-  LinalgOpConverter(MLIRContext *context, BufferizeTypeConverter &converter)
-      : BufferizeConversionPattern(context, converter) {}
+  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
+      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -212,17 +205,6 @@ public:
     // init_tensors for all linalg::LinalgOp interface ops.
     linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary());
 
-    // All inputs need to be turned into buffers first. Until then, bail out.
-    if (llvm::any_of(adaptor.inputs(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
-
-    // All init_tensors need to be turned into buffers first. Until then, bail
-    // out.
-    if (llvm::any_of(adaptor.init_tensors(),
-                     [](Value in) { return !in.getType().isa<MemRefType>(); }))
-      return failure();
-
     Location loc = linalgOp.getLoc();
     SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                            adaptor.output_buffers().end());
@@ -252,10 +234,9 @@ namespace {
 /// TensorConstantOp conversion inserts a linearized 1-D vector constant that is
 /// stored in memory. A linalg.reshape is introduced to convert to the desired
 /// n-D buffer form.
-class TensorConstantOpConverter
-    : public BufferizeOpConversionPattern<ConstantOp> {
+class TensorConstantOpConverter : public OpConversionPattern<ConstantOp> {
 public:
-  using BufferizeOpConversionPattern<ConstantOp>::BufferizeOpConversionPattern;
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(ConstantOp op, ArrayRef<Value> operands,
@@ -275,7 +256,7 @@ public:
       nElements *= s;
     Type elementType = rankedTensorType.getElementType();
     MemRefType memrefType =
-        converter.convertType(op.getType()).cast<MemRefType>();
+        getTypeConverter()->convertType(op.getType()).cast<MemRefType>();
     VectorType flatVectorType = VectorType::get({nElements}, elementType);
     MemRefType memrefOfFlatVectorType = MemRefType::get({}, flatVectorType);
     MemRefType flatMemrefType = MemRefType::get({nElements}, elementType);
@@ -316,64 +297,21 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
     BufferizeTypeConverter converter;
 
     // Mark all Standard operations legal.
+    // TODO: Remove after TensorConstantOpConverter moves to std-bufferize.
     target.addLegalDialect<StandardOpsDialect, vector::VectorDialect>();
-    target.addLegalOp<ModuleOp>();
-    target.addLegalOp<ModuleTerminatorOp>();
 
     // Mark all Linalg operations illegal as long as they work on tensors.
     auto isLegalOperation = [&](Operation *op) {
       return converter.isLegal(op);
     };
-    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
-        Optional<ConversionTarget::DynamicLegalityCallbackFn>(
-            isLegalOperation));
-
-    // Mark operations that consume or return tensors illegal.
-    auto isLegal = [&](Operation *op) {
-      if (llvm::any_of(op->getOperandTypes(),
-                       [&](Type t) { return !converter.isLegal(t); }))
-        return false;
-      if (llvm::any_of(op->getResultTypes(),
-                       [&](Type t) { return !converter.isLegal(t); }))
-        return false;
-      return true;
-    };
-    target.addDynamicallyLegalOp<
-        // clang-format off
-        CallOp,
-        ConstantOp,
-        ConstantIntOp,
-        ConstantIndexOp,
-        ConstantFloatOp,
-        ReturnOp,
-        TensorCastOp
-        // clang-format on
-        >(isLegal);
-
-    // Mark the function operation illegal as long as an argument is tensor.
-    // TODO: if the FuncOp is a FuncOp that only has a declaration (e.g. to an
-    // externally defined symbol like an external library calls), only convert
-    // if some special attribute is set. This will allow more control of interop
-    // across ABI boundaries.
-    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
-      return converter.isSignatureLegal(funcOp.getType()) &&
-             llvm::none_of(funcOp.getType().getResults(),
-                           [&](Type type) { return type.isa<MemRefType>(); }) &&
-             converter.isLegal(&funcOp.getBody());
-    });
-
-    converter.setResultConversionKind<RankedTensorType, MemRefType>(
-        BufferizeTypeConverter::AppendToArgumentsList);
+    target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
+    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
 
     OwningRewritePatternList patterns;
     populateLinalgBufferizePatterns(&context, converter, patterns);
-    populateStdBufferizePatterns(&context, converter, patterns);
-    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
-                                              linalg::CopyOp>(
-        &context, converter, patterns);
-    if (failed(applyFullConversion(this->getOperation(), target,
-                                   std::move(patterns))))
-      this->signalPassFailure();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
   }
 };
 } // end anonymous namespace
@@ -384,10 +322,7 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createLinalgBufferizePass() {
 void mlir::linalg::populateLinalgBufferizePatterns(
     MLIRContext *context, BufferizeTypeConverter &converter,
     OwningRewritePatternList &patterns) {
-  patterns.insert<
-      // clang-format off
-      LinalgOpConverter,
-      TensorConstantOpConverter
-      // clang-format on
-      >(context, converter);
+
+  patterns.insert<BufferizeAnyLinalgOp>(converter);
+  patterns.insert<TensorConstantOpConverter>(converter, context);
 }
index f0b7c3d..50ac4b7 100644 (file)
@@ -1,94 +1,84 @@
-// RUN: mlir-opt -linalg-bufferize -buffer-hoisting -buffer-deallocation -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-bufferize -split-input-file %s | FileCheck %s
 
 #map0 = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: func @multiple_results
-func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
-    %0, %1 = linalg.generic {
-      indexing_maps = [#map0, #map0, #map0],
+// In-depth checking of a basic case, this is testing
+// - tensor_to_memref / tensor_load materializations are properly inserted
+// - payload is correctly carried over
+// - affine maps are correctly carried over
+// Later tests will not check all these details.
+
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK-LABEL:   func @basic(
+// CHECK-SAME:                %[[TENSOR:.*]]: tensor<4xf32>) -> tensor<4xf32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<4xf32>
+// CHECK:           %[[RESULT_MEMREF:.*]] = alloc() : memref<4xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+// CHECK-SAME:      ins(%[[MEMREF]] : memref<4xf32>)
+// CHECK-SAME:      outs(%[[RESULT_MEMREF]] : memref<4xf32>) {
+// CHECK:           ^bb0(%[[RESULT1:.*]]: f32, %[[UNUSED:.*]]: f32):
+// CHECK:             %[[DIM1:.*]] = exp %[[RESULT1]] : f32
+// CHECK:             linalg.yield %[[DIM1]] : f32
+// CHECK:           }
+// CHECK:           %[[RESULT:.*]] = tensor_load %[[RESULT_MEMREF]] : memref<4xf32>
+// CHECK:           return %[[RESULT]] : tensor<4xf32>
+func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+    %0 = linalg.generic {
+      indexing_maps = [#map0, #map0],
       iterator_types = ["parallel"]
     } ins(%arg0 : tensor<4xf32>) {
       ^bb0(%gen_arg1: f32):
         %tmp1 = exp %gen_arg1 : f32
-        linalg.yield %tmp1, %tmp1 : f32, f32
-    } -> tensor<4xf32>, tensor<4xf32>
-    return %0, %1 : tensor<4xf32>, tensor<4xf32>
+        linalg.yield %tmp1 : f32
+    } -> tensor<4xf32>
+    return %0 : tensor<4xf32>
 }
-//      CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]], %[[ARG2_RESULT:.*]]: [[TYPE]])
-//      CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
-//      CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
-//      CHECK: linalg.generic
-// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
-// CHECK-SAME: outs(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]] : [[TYPE]], [[TYPE]]
-// CHECK-NEXT: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32
-// CHECK-NEXT: %{{.*}} = exp
-// CHECK-NEXT: linalg.yield
-//      CHECK: linalg.copy(%[[FIRST_ALLOC]], %[[ARG1_RESULT]])
-//      CHECK: dealloc %[[FIRST_ALLOC]]
-//      CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG2_RESULT]])
-//      CHECK: dealloc %[[SECOND_ALLOC]]
-//      CHECK: return
+
 
 // -----
 
 #map0 = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: func @chained_operations
-func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> {
-    %0 = linalg.generic {
-      indexing_maps = [#map0, #map0],
+// CHECK-LABEL:   func @multiple_results
+// CHECK:           %[[RESULT0:.*]] = alloc() : memref<4xf32>
+// CHECK:           %[[RESULT1:.*]] = alloc() : memref<4xf32>
+// CHECK:           linalg.generic
+// CHECK-SAME:      ins(%{{.*}} : memref<4xf32>)
+// CHECK-SAME:      outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>)
+func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+    %0, %1 = linalg.generic {
+      indexing_maps = [#map0, #map0, #map0],
       iterator_types = ["parallel"]
     } ins(%arg0 : tensor<4xf32>) {
       ^bb0(%gen_arg1: f32):
         %tmp1 = exp %gen_arg1 : f32
-        linalg.yield %tmp1 : f32
-    } -> tensor<4xf32>
-
-    %1 = linalg.generic {
-      indexing_maps = [#map0, #map0],
-      iterator_types = ["parallel"]
-    } ins(%0 : tensor<4xf32>) {
-      ^bb0(%gen_arg2: f32):
-        %tmp2 = exp %gen_arg2 : f32
-        linalg.yield %tmp2 : f32
-    } -> tensor<4xf32>
-    return %1 : tensor<4xf32>
-}
-//      CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]])
-//      CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
-//      CHECK: linalg.generic
-// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
-// CHECK-SAME: outs(%[[FIRST_ALLOC]] : [[TYPE]]
-//      CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
-//      CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
-//      CHECK: linalg.generic
-// CHECK-SAME: ins(%[[FIRST_ALLOC]] : [[TYPE]]
-// CHECK-SAME: outs(%[[SECOND_ALLOC]] : [[TYPE]]
-//      CHECK: ^{{[a-z0-9_]*}}
-// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
-//      CHECK: dealloc %[[FIRST_ALLOC]]
-//      CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG1_RESULT]])
-//      CHECK: dealloc %[[SECOND_ALLOC]]
-//      CHECK: return
-
-// -----
-
-// CHECK-LABEL: func @no_linalg_op
-func @no_linalg_op(%arg0: f32) -> (f32, f32) {
-  %0 = mulf %arg0, %arg0 : f32
-  return %0, %0 : f32, f32
+        linalg.yield %tmp1, %tmp1 : f32, f32
+    } -> tensor<4xf32>, tensor<4xf32>
+    return %0, %1 : tensor<4xf32>, tensor<4xf32>
 }
-// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]]) -> ([[TYPE]], [[TYPE]])
-// CHECK: %[[RESULT:.*]] = mulf %[[NEW_ARG0]], %[[NEW_ARG0]] : [[TYPE]]
-// CHECK: return %[[RESULT]], %[[RESULT]] : [[TYPE]], [[TYPE]]
 
 // -----
 
 #map_2d = affine_map<(d0, d1) -> (d0, d1)>
 #map_2d_inv = affine_map<(d0, d1) -> (d1, d0)>
 
+// Check that the allocs properly consider the different shapes of the output
+// operands. The permuted indexing maps translate to different output shapes.
+
+// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL:   func @dynamic_results(
+// CHECK-SAME:                          %[[ARG:.*]]: tensor<?x?xf32>
+// CHECK:           %[[MEMREF_ARG:.*]] = tensor_to_memref %[[ARG]] : memref<?x?xf32>
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           %[[DIM0:.*]] = dim %[[ARG]], %[[C0]] : tensor<?x?xf32>
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           %[[DIM1:.*]] = dim %[[ARG]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[RESULT0:.*]] = alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32>
+// CHECK:           %[[RESULT1:.*]] = alloc(%[[DIM1]], %[[DIM0]]) : memref<?x?xf32>
+// CHECK:           linalg.generic {indexing_maps = [#map0, #map0, #map1]
+// CHECK-SAME:      ins(%[[MEMREF_ARG]] : memref<?x?xf32>)
+// CHECK-SAME:      outs(%[[RESULT0]], %[[RESULT1]] : memref<?x?xf32>, memref<?x?xf32>)
 func @dynamic_results(%arg0: tensor<?x?xf32>)
          -> (tensor<?x?xf32>, tensor<?x?xf32>) {
     %0, %1 = linalg.generic {
@@ -102,79 +92,24 @@ func @dynamic_results(%arg0: tensor<?x?xf32>)
     return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
 }
 
-// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK: #map1 = affine_map<(d0, d1) -> (d1, d0)>
-
-// CHECK-LABEL: func @dynamic_results
-// CHECK-SAME: (%[[INPUT:.*]]: [[TYPE:.*]], %[[OUT_1:.*]]: [[TYPE]], %[[OUT_2:.*]]: [[TYPE]]) {
-// CHECK: %[[C0:.*]] = constant 0 : index
-// CHECK: %[[DIM_0:.*]] = dim %[[INPUT]], %[[C0]] : [[TYPE]]
-// CHECK: %[[C1:.*]] = constant 1 : index
-// CHECK: %[[DIM_1:.*]] = dim %[[INPUT]], %[[C1]] : [[TYPE]]
-// CHECK: %[[OUT_BUF_1:.*]] = alloc(%[[DIM_0]], %[[DIM_1]]) : [[TYPE]]
-// CHECK: %[[OUT_BUF_2:.*]] = alloc(%[[DIM_1]], %[[DIM_0]]) : [[TYPE]]
-
-// CHECK: linalg.generic {indexing_maps = [#map0, #map0, #map1], {{.*}}}
-// CHECK-SAME: ins(%[[INPUT]] : [[TYPE]])
-// CHECK-SAME: outs(%[[OUT_BUF_1]], %[[OUT_BUF_2]] : [[TYPE]], [[TYPE]]) {
-
-// CHECK: linalg.copy(%[[OUT_BUF_1]], %[[OUT_1]]) : [[TYPE]], [[TYPE]]
-// CHECK: dealloc %[[OUT_BUF_1]] : [[TYPE]]
-// CHECK: linalg.copy(%[[OUT_BUF_2]], %[[OUT_2]]) : [[TYPE]], [[TYPE]]
-// CHECK: dealloc %[[OUT_BUF_2]] : [[TYPE]]
-// CHECK: return
-
 // -----
 
-func @foo() -> tensor<2x3xf32> {
-// CHECK-LABEL: func @foo(
-//  CHECK-SAME:   %[[A:[0-9a-z]*]]: memref<2x3xf32>) {
-
+// Check lowering of tensor-valued std.constant's
+// TODO: Move this to std-bufferize.
+
+// CHECK-LABEL:   func @constant() -> tensor<2x3xf32> {
+// CHECK:           %[[VECTOR_MEMREF:.*]] = alloc() : memref<vector<6xf32>>
+// CHECK:           %[[VECTOR_CONST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
+// CHECK:           store %[[VECTOR_CONST]], %[[VECTOR_MEMREF]][] : memref<vector<6xf32>>
+// CHECK:           %[[MEMREF:.*]] = vector.type_cast %[[VECTOR_MEMREF]] : memref<vector<6xf32>> to memref<6xf32>
+// CHECK:           %[[FINAL_SHAPE:.*]] = linalg.reshape %[[MEMREF]] [#map] : memref<6xf32> into memref<2x3xf32>
+// CHECK:           %[[RESULT:.*]] = tensor_load %[[FINAL_SHAPE]] : memref<2x3xf32>
+// CHECK:           return %[[RESULT]] : tensor<2x3xf32>
+func @constant() -> tensor<2x3xf32> {
   %0 = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
-//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<vector<6xf32>>
-//  CHECK-NEXT:   %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : vector<6xf32>
-//  CHECK-NEXT:   store %[[CST]], %[[ALLOC]][] : memref<vector<6xf32>>
-//  CHECK-NEXT:   %[[FLAT:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<6xf32>> to memref<6xf32>
-//  CHECK-NEXT:   %[[RES:.*]] = linalg.reshape %[[FLAT]] {{.*}} : memref<6xf32> into memref<2x3xf32>
-
-  return %0 : tensor<2x3xf32>
-//  CHECK-NEXT:   linalg.copy(%[[RES]], %[[A]]) : memref<2x3xf32>, memref<2x3xf32>
-//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<vector<6xf32>>
-//  CHECK-NEXT:   return
+  return %0: tensor<2x3xf32>
 }
 
-func @bar() {
-// CHECK-LABEL: func @bar() {
-
-  %0 = call @foo() : () -> tensor<2x3xf32>
-//  CHECK-NEXT:   %[[ALLOC:.*]] = alloc() : memref<2x3xf32>
-//  CHECK-NEXT:   call @foo(%[[ALLOC]]) : (memref<2x3xf32>) -> ()
-
-  // Instead of relying on tensor_store which introduces aliasing, we rely on
-  // the conversion of print_memref_f32(tensor<*xf32>) to
-  // print_memref_f32(memref<*xf32>).
-  // Note that this is skipping a step and we would need at least some function
-  // attribute to declare that this conversion is valid (e.g. when we statically
-  // know that things will play nicely at the C ABI boundary).
-  %unranked = tensor_cast %0 : tensor<2x3xf32> to tensor<*xf32>
-//  CHECK-NEXT:   %[[UNRANKED:.*]] = memref_cast %[[ALLOC]] :
-//  CHECK-SAME:     memref<2x3xf32> to memref<*xf32>
-
-  call @print_memref_f32(%unranked) : (tensor<*xf32>) -> ()
-//  CHECK-NEXT:   call @print_memref_f32(%[[UNRANKED]]) : (memref<*xf32>) -> ()
-
-  return
-//  CHECK-NEXT:   dealloc %[[ALLOC]] : memref<2x3xf32>
-//  CHECK-NEXT:   return
-}
-
-// This gets converted to a function operating on memref<*xf32>.
-// Note that this is skipping a step and we would need at least some function
-// attribute to declare that this conversion is valid (e.g. when we statically
-// know that things will play nicely at the C ABI boundary).
-func @print_memref_f32(%ptr : tensor<*xf32>)
-// CHECK-LABEL: func @print_memref_f32(memref<*xf32>)
-
 // -----
 
 #accesses = [
@@ -187,6 +122,18 @@ func @print_memref_f32(%ptr : tensor<*xf32>)
   iterator_types = ["parallel", "parallel", "reduction"]
 }
 
+// Check the bufferization of init tensors.
+
+// CHECK-LABEL:   func @generic_with_init_tensor(
+// CHECK-SAME:                                   %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>,
+// CHECK-SAME:                                   %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> {
+// CHECK:           %[[ARG0_MEMREF:.*]] = tensor_to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>>
+// CHECK:           %[[ARG1_MEMREF:.*]] = tensor_to_memref %[[ARG1_TENSOR]] : memref<3x2xf32>
+// CHECK:           %[[INIT_BUFFER:.*]] = alloc() : memref<3x2xf32>
+// CHECK:           linalg.copy(%[[ARG1_MEMREF]], %[[INIT_BUFFER]]) : memref<3x2xf32>, memref<3x2xf32>
+// CHECK:           linalg.generic
+// CHECK-SAME:      ins(%[[ARG0_MEMREF]] : memref<2x3x4xvector<3x4xi4>>)
+// CHECK-SAME:      outs(%[[INIT_BUFFER]] : memref<3x2xf32>) {
 func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>,
   %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) {
 
@@ -200,116 +147,3 @@ func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>,
 
   return %0 : tensor<3x2xf32>
 }
-// CHECK-LABEL: func @generic_with_init_tensor
-//  CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) {
-//  CHECK-NEXT: linalg.generic
-//       CHECK: linalg.copy(%[[ARG1]], %[[RESULT0]])
-//  CHECK-NEXT: return
-//   CHECK-NOT: %
-
-// -----
-
-#accesses = [
-  affine_map<(i, j, k) -> (j, i, k)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-
-#trait = {
-  indexing_maps = #accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-func @init_tensor_with_2_uses(%arg0: tensor<2x3x4xvector<3x4xi4>>,
-  %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>, tensor<3x2xf32>) {
-
-  %0 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%arg1 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  %1 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%arg1 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  return %0, %1 : tensor<3x2xf32>, tensor<3x2xf32>
-}
-// CHECK-LABEL: func @init_tensor_with_2_uses
-//  CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>, %[[RESULT1:.*]]: memref<3x2xf32>) {
-//  CHECK-NEXT: %[[ALLOC0:.*]] = alloc
-//  CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]])
-//  CHECK-NEXT: linalg.generic
-//  CHECK-SAME: outs(%[[ALLOC0]]
-//  CHECK-NEXT: ^bb
-//  CHECK-NEXT:   constant
-//  CHECK-NEXT:   yield
-//  CHECK-NEXT: }
-//  CHECK-NEXT: %[[ALLOC1:.*]] = alloc
-//  CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC1]])
-//  CHECK-NEXT: linalg.generic
-//  CHECK-SAME: outs(%[[ALLOC1]]
-//  CHECK-NEXT: ^bb
-//  CHECK-NEXT:   constant
-//  CHECK-NEXT:   yield
-//  CHECK-NEXT: }
-//  CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[RESULT0]])
-//  CHECK-NEXT: dealloc
-//  CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[RESULT1]])
-//  CHECK-NEXT: dealloc
-//  CHECK-NEXT: return
-//   CHECK-NOT: %
-
-// -----
-
-#accesses = [
-  affine_map<(i, j, k) -> (j, i, k)>,
-  affine_map<(i, j, k) -> (i, j)>
-]
-
-#trait = {
-  indexing_maps = #accesses,
-  iterator_types = ["parallel", "parallel", "reduction"]
-}
-
-func @init_tensor_with_1_use_def_chain(%arg0: tensor<2x3x4xvector<3x4xi4>>,
-  %arg1: tensor<3x2xf32>) -> (tensor<3x2xf32>) {
-
-  %0 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%arg1 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  %1 = linalg.generic #trait
-    ins(%arg0 : tensor<2x3x4xvector<3x4xi4>>)
-   init(%0 : tensor<3x2xf32>) {
-    ^bb(%v0: vector<3x4xi4>, %v1: f32) :
-      %f0 = constant 0.0 : f32
-      linalg.yield %f0 : f32
-  } -> tensor<3x2xf32>
-
-  return %1 : tensor<3x2xf32>
-}
-// CHECK-LABEL: func @init_tensor_with_1_use_def_chain
-//  CHECK-SAME: (%[[ARG0:.*]]: memref<2x3x4xvector<3x4xi4>>, %[[ARG1:.*]]: memref<3x2xf32>, %[[RESULT0:.*]]: memref<3x2xf32>) {
-//  CHECK-NEXT: linalg.generic
-//  CHECK-NEXT: ^bb
-//  CHECK-NEXT:   constant
-//  CHECK-NEXT:   yield
-//  CHECK-NEXT: }
-//  CHECK-NEXT: linalg.generic
-//  CHECK-NEXT: ^bb
-//  CHECK-NEXT:   constant
-//  CHECK-NEXT:   yield
-//  CHECK-NEXT: }
-//  CHECK-NEXT: linalg.copy(%[[ARG1]], %[[RESULT0]])
-//  CHECK-NEXT: return
-//   CHECK-NOT: %