#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpImplementation.h"
+namespace mlir {
+namespace tensor {
+class EmptyOp;
+} // namespace tensor
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// Bufferization Transform Operations
//===----------------------------------------------------------------------===//
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
}];
}
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+
+def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">;
+def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">;
+
+def EmptyTensorToAllocTensorOp
+ : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
+ [FunctionalStyleTransformOpTrait,
+ MemoryEffectsOpInterface,
+ TransformOpInterface,
+ TransformEachOpTrait]> {
+ let description = [{
+ Replace a tensor.empty with a bufferization.tensor_alloc.
+
+ ### Return modes
+
+ This operation consumes the `target` handle and produces the `transformed`
+ handle. `target` is expected to be a `tensor.empty` operation. The transform
+ always succeeds.
+ }];
+
+ let arguments = (ins Transform_EmptyOp:$target);
+ let results = (outs Transform_AllocTensorOp:$transformed);
+
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::tensor::EmptyOp target,
+ ::llvm::SmallVector<::mlir::Operation *> &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
#endif // BUFFERIZATION_TRANSFORM_OPS
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
using namespace mlir;
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
TransformMappingResource::get());
}
+
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target,
+ SmallVector<Operation *> &results,
+ transform::TransformState &state) {
+ IRRewriter rewriter(target->getContext());
+ rewriter.setInsertionPoint(target);
+ auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
+ target, target.getType(), target.getDynamicSizes());
+ results.push_back(alloc);
+ return DiagnosedSilenceableFailure::success();
+}
+
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
// CHECK: return %[[C]] : memref<12x6xf32>
return %D : tensor<12x6xf32>
}
+
+// -----
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["tensor.empty"]} in %arg1
+ %1 = transform.cast %0 : !pdl.operation to !transform.op<"tensor.empty">
+ transform.bufferization.empty_tensor_to_alloc_tensor %1 : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+}
+
+// Expect `bufferization.empty_tensor_to_alloc_tensor` to replace the tensor.empty.
+func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
+ // CHECK: bufferization.alloc_tensor
+ %0 = tensor.empty() : tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}