[MLIR][Bufferization] Introduce `EmptyTensorToAllocTensorOp`
authorLorenzo Chelini <l.chelini@icloud.com>
Wed, 14 Dec 2022 13:47:01 +0000 (14:47 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Mon, 19 Dec 2022 08:12:10 +0000 (09:12 +0100)
Introduce a new transform operation to replace `tensor.empty` with
`alloc_tensor` operations. The operation is a pass-through if the target
operation is already a `alloc_tensor`; otherwise, it expects a
`tensor.empty` as a target. Currently, it does not return any results.

The operation is expected to run before `one_shot_bufferize` as
`one_shot_bufferize` rejects `tensor.empty`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D140026

mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

index 0aab581..06204b6 100644 (file)
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"
 
+namespace mlir {
+namespace tensor {
+class EmptyOp;
+} // namespace tensor
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Bufferization Transform Operations
 //===----------------------------------------------------------------------===//
index e63ecbf..5135764 100644 (file)
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -60,4 +61,41 @@ def OneShotBufferizeOp
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+
+def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">;
+def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">;
+
+def EmptyTensorToAllocTensorOp 
+    : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
+        [FunctionalStyleTransformOpTrait, 
+         MemoryEffectsOpInterface,
+         TransformOpInterface,
+         TransformEachOpTrait]> {
+  let description = [{
+    Replace a tensor.empty with a bufferization.tensor_alloc.
+    
+    ### Return modes
+
+    This operation consumes the `target` handle and produces the `transformed`
+    handle. `target` is expected to be a `tensor.empty` operation. The transform
+    always succeeds.
+  }];
+
+  let arguments = (ins Transform_EmptyOp:$target);
+  let results = (outs Transform_AllocTensorOp:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::tensor::EmptyOp target,
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // BUFFERIZATION_TRANSFORM_OPS
index 9415bf7..f7706d7 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 
 using namespace mlir;
@@ -67,6 +68,23 @@ void transform::OneShotBufferizeOp::getEffects(
     effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
                          TransformMappingResource::get());
 }
+
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target,
+                                       SmallVector<Operation *> &results,
+                                       transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
+      target, target.getType(), target.getDynamicSizes());
+  results.push_back(alloc);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
index 4ff8a23..ec537b6 100644 (file)
@@ -118,3 +118,19 @@ func.func @matmul(%A: tensor<12x9xf32>, %B: tensor<9x6xf32>, %C: tensor<12x6xf32
   // CHECK: return %[[C]] : memref<12x6xf32>
   return %D : tensor<12x6xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.empty"]} in %arg1
+    %1 = transform.cast %0 : !pdl.operation to !transform.op<"tensor.empty">
+    transform.bufferization.empty_tensor_to_alloc_tensor %1 : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+}
+
+// Expect `bufferization.empty_tensor_to_alloc_tensor` to replace the tensor.empty.
+func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
+  // CHECK: bufferization.alloc_tensor
+  %0 = tensor.empty() : tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}