include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">;
+
def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
- TransformOpInterface, TransformEachOpTrait]> {
+ DeclareOpInterfaceMethods<TransformOpInterface>]> {
let summary = "Multibuffers an allocation";
let description = [{
Transformation to do multi-buffering/array expansion to remove
}];
let arguments =
- (ins PDL_Operation:$target,
+ (ins Transform_MemRefAllocOp:$target,
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
let results = (outs PDL_Operation:$transformed);
- let assemblyFormat = "$target attr-dict";
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- memref::AllocOp target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
- }];
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type(operands, results)";
}
#endif // MEMREF_TRANSFORM_OPS
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
using namespace mlir;
// MemRefMultiBufferOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne(
- memref::AllocOp target, transform::ApplyToEachResultList &results,
+DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
+ transform::TransformResults &transformResults,
transform::TransformState &state) {
- auto newBuffer = memref::multiBuffer(target, getFactor());
- if (failed(newBuffer))
- return emitSilenceableFailure(target->getLoc())
- << "op failed to multibuffer";
+ SmallVector<Operation *> results;
+ ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
+ for (auto *op : payloadOps) {
+ bool canApplyMultiBuffer = true;
+ auto target = cast<memref::AllocOp>(op);
+ // Skip allocations not used in a loop.
+ for (Operation *user : target->getUsers()) {
+ auto loop = user->getParentOfType<LoopLikeOpInterface>();
+ if (!loop) {
+ canApplyMultiBuffer = false;
+ break;
+ }
+ }
+ if (!canApplyMultiBuffer)
+ continue;
- results.push_back(*newBuffer);
+ auto newBuffer = memref::multiBuffer(target, getFactor());
+ if (failed(newBuffer))
+ return emitSilenceableFailure(target->getLoc())
+ << "op failed to multibuffer";
+
+ results.push_back(*newBuffer);
+ }
+ transformResults.set(getResult().cast<OpResult>(), results);
return DiagnosedSilenceableFailure::success();
}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+ %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+ %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
// Verify that the returned handle is usable.
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
}
// -----
-// Trying to use multibuffer on alloc that are used outside of loops is
-// going to fail.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer_on_affine_loop
+func.func @multi_buffer_on_affine_loop(%in: memref<16xf32>) {
+ // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+ // expected-remark @below {{transformed}}
+ %tmp = memref.alloc() : memref<4xf32>
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+
+ // CHECK: affine.for %[[IV:.*]] = 0
+ affine.for %i0 = 0 to 16 step 4 {
+ // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]])
+ // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+ %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+ memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+ "some_use"(%tmp) : (memref<4xf32>) ->()
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+ %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
+ // Verify that the returned handle is usable.
+ transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
+}
+
+// -----
+
+// Trying to use multibuffer on allocs that are used in different loops
+// with none dominating the other is going to fail.
// Check that we emit a proper error for that.
-func.func @multi_buffer_uses_outside_of_loop(%in: memref<16xf32>) {
+func.func @multi_buffer_uses_with_no_loop_dominator(%in: memref<16xf32>, %cond: i1) {
// expected-error @below {{op failed to multibuffer}}
%tmp = memref.alloc() : memref<4xf32>
- "some_outside_loop_use"(%tmp) : (memref<4xf32>) -> ()
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c16 = arith.constant 16 : index
+ scf.if %cond {
+ scf.for %i0 = %c0 to %c16 step %c4 {
+ %var = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ memref.copy %var, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+ "some_use"(%tmp) : (memref<4xf32>) ->()
+ }
+ }
+
+ scf.for %i0 = %c0 to %c16 step %c4 {
+ %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+ "some_use"(%tmp) : (memref<4xf32>) ->()
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+ %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
+}
+
+// -----
+
+// Make sure the multibuffer operation is typed so that it only supports
+// memref.alloc.
+// Check that we emit an error if we try to match something else.
+func.func @multi_buffer_reject_alloca(%in: memref<16xf32>, %cond: i1) {
+ %tmp = memref.alloca() : memref<4xf32>
%c0 = arith.constant 0 : index
%c4 = arith.constant 4 : index
%c16 = arith.constant 16 : index
+ scf.if %cond {
+ scf.for %i0 = %c0 to %c16 step %c4 {
+ %var = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ memref.copy %var, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+ "some_use"(%tmp) : (memref<4xf32>) ->()
+ }
+ }
scf.for %i0 = %c0 to %c16 step %c4 {
%1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1 = transform.memref.multibuffer %0 {factor = 2 : i64}
+ %0 = transform.structured.match ops{["memref.alloca"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloca">
+ // expected-error @below {{'transform.memref.multibuffer' op operand #0 must be Transform IR handle to memref.alloc operations, but got '!transform.op<"memref.alloca">'}}
+ %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloca">) -> !pdl.operation
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+
+// CHECK-LABEL: func @multi_buffer_one_alloc_with_use_outside_of_loop
+// Make sure we manage to apply multi_buffer to the memref that is used in
+// the loop (%tmp) and don't error out for the one that is not (%tmp2).
+func.func @multi_buffer_one_alloc_with_use_outside_of_loop(%in: memref<16xf32>) {
+ // CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
+ // expected-remark @below {{transformed}}
+ %tmp = memref.alloc() : memref<4xf32>
+ %tmp2 = memref.alloc() : memref<4xf32>
+
+ "some_use_outside_of_loop"(%tmp2) : (memref<4xf32>) -> ()
+
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c16 = arith.constant 16 : index
+
+ // CHECK: scf.for %[[IV:.*]] = %[[C0]]
+ scf.for %i0 = %c0 to %c16 step %c4 {
+ // CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]])
+ // CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
+ %1 = memref.subview %in[%i0] [4] [1] : memref<16xf32> to memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+ // CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4xf32, #[[$MAP1]]> to memref<4xf32, strided<[1], offset: ?>>
+ memref.copy %1, %tmp : memref<4xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<4xf32>
+
+ "some_use"(%tmp) : (memref<4xf32>) ->()
+ }
+ return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
+ %1 = transform.memref.multibuffer %0 {factor = 2 : i64} : (!transform.op<"memref.alloc">) -> !pdl.operation
+ // Verify that the returned handle is usable.
+ transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
}