[mlir][bufferization] Add bufferization.eliminate_empty_tensors transform op
authorMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 05:17:54 +0000 (14:17 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 05:22:47 +0000 (14:22 +0900)
Differential Revision: https://reviews.llvm.org/D144401

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

index 880781b..692f0f0 100644 (file)
@@ -17,6 +17,13 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">;
+def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">;
+
+//===----------------------------------------------------------------------===//
+// OneShotBufferizeOp
+//===----------------------------------------------------------------------===//
+
 def OneShotBufferizeOp
     : Op<Transform_Dialect, "bufferization.one_shot_bufferize",
         [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
@@ -63,12 +70,65 @@ def OneShotBufferizeOp
 }
 
 //===----------------------------------------------------------------------===//
-// EmptyTensorToAllocTensorOp
+// EliminateEmptyTensorsOp
 //===----------------------------------------------------------------------===//
 
+def EliminateEmptyTensorsOp
+    : Op<Transform_Dialect, "bufferization.eliminate_empty_tensors",
+        [DeclareOpInterfaceMethods<TransformOpInterface>,
+         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
+  let description = [{
+    Try to eliminate all `tensor.empty` ops within the targeted op by replacing
+    them with a destination tensor.
 
-def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">;
-def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">;
+    `tensor.empty` ops cannot be bufferized. They can either be converted to
+    `bufferization.alloc_tensor` or replaced with another tensor (via this
+    transform). `tensor.empty` ops do not specify the contents of the returned
+    tensor, so their results can be replaced with arbitrary tensor values as
+    long as the dimensions match.
+
+    This transform looks for `tensor.empty` ops where the SSA use-def chain of
+    the result ends in a supported "anchor op" (always following the aliasing
+    OpOperand/OpResult chain). Currently supported anchor ops are:
+    - `tensor.insert_slice`
+    - `bufferization.yield` (inside `bufferization.alloc_tensor`)
+
+    Example:
+
+    ```
+    %0 = tensor.empty() : tensor<5xf32>
+    %1 = linalg.fill ... outs(%0)
+    %2 = tensor.insert_slice %1 into %t[1][5][1]
+    ```
+
+    Is rewritten with:
+    ```
+    %0 = tensor.extract_slice %t[1][5][1]
+    %1 = linalg.fill ... outs(%0)
+    %2 = tensor.insert_slice %1 into %t[1][5][1]
+    ```
+
+    The above example can bufferize without an allocation (in the absence of
+    other conflicts) because there is no longer a `tensor.empty` op.
+
+    See `-eliminate-empty-tensors` for more details.
+
+    #### Return modes
+
+    This transform reads the target handle and modifies the payload. It does
+    not produce any handle.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
 
 def EmptyTensorToAllocTensorOp
     : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
index 8f56357..58766e8 100644 (file)
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
+#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
@@ -63,6 +64,37 @@ transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
 }
 
 //===----------------------------------------------------------------------===//
+// EliminateEmptyTensorsOp
+//===----------------------------------------------------------------------===//
+
+/// Declares the side effects of this transform op: the target handle is only
+/// read, and the payload IR is modified.
+void transform::EliminateEmptyTensorsOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTarget(), effects);
+  modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults,
+                                          TransformState &state) {
+  IRRewriter rewriter(getContext());
+  OneShotBufferizationOptions options;
+  options.allowReturnAllocs = true;
+
+  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+  for (Operation *target : payloadOps) {
+    OneShotAnalysisState state(target, options);
+    if (failed(analyzeOp(target, state)))
+      return mlir::emitSilenceableFailure(target->getLoc())
+             << "failed to analyze op";
+    if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
+            rewriter, target, state)))
+      return mlir::emitSilenceableFailure(target->getLoc())
+             << "failed to eliminate insert_slice anchored tensor.empty ops";
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
 // EmptyTensorToAllocTensorOp
 //===----------------------------------------------------------------------===//
 
index 7e62500..05d6a1b 100644 (file)
@@ -130,3 +130,24 @@ func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
   %0 = tensor.empty() : tensor<2x2xf32>
   return %0 : tensor<2x2xf32>
 }
+
+// -----
+
+// Transform script for this test case: match every `func.func` op in the
+// payload and apply `transform.bufferization.eliminate_empty_tensors` to the
+// resulting handle. Failures are propagated.
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  transform.bufferization.eliminate_empty_tensors %0
+}
+
+// Payload for the transform above: a `tensor.empty` is filled by
+// `linalg.fill` and the result is inserted into %t at [1][5][1]. The
+// FileCheck directives below expect a `tensor.extract_slice` to appear ahead
+// of the fill and the insert in the output.
+// CHECK-LABEL: func @empty_tensor_elimination(
+//       CHECK:   tensor.extract_slice
+//       CHECK:   linalg.fill
+//       CHECK:   tensor.insert_slice
+func.func @empty_tensor_elimination(
+    %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> {
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %2 = tensor.insert_slice %1 into %t [1][5][1]
+      : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}