Add types to the Loop (SCF) extension of the transform dialect.
See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D135587
#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
-#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpImplementation.h"
namespace mlir {
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
-include "mlir/Dialect/PDL/IR/PDLTypes.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
+
def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
}];
let arguments =
- (ins PDL_Operation:$target,
+ (ins TransformTypeInterface:$target,
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
"1">:$num_loops);
- let results = (outs PDL_Operation:$parent);
+ let results = (outs TransformTypeInterface:$parent);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type(operands, results)";
}
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
order as the operand handle.
}];
- let arguments = (ins PDL_Operation:$target,
+ // Note that despite the name of the transform operation and related utility
+ // functions, the actual implementation does not require the operation to be
+ // a loop.
+ let arguments = (ins TransformTypeInterface:$target,
StrAttr:$func_name);
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformTypeInterface:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type(operands, results)";
}
def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
}];
let arguments =
- (ins PDL_Operation:$target,
+ (ins Transform_ScfForOp:$target,
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
// TODO: Return both the peeled loop and the remainder loop.
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformTypeInterface:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type(operands, results)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
pipelined loops, which can be empty.
}];
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins Transform_ScfForOp:$target,
DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
DefaultValuedAttr<I64Attr, "10">:$read_latency);
- let results = (outs PDL_Operation:$transformed);
+ let results = (outs TransformTypeInterface:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat =
+ "$target attr-dict `:` functional-type(operands, results)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
removed after a full unrolling.
}];
- let arguments = (ins PDL_Operation:$target,
+ let arguments = (ins Transform_ScfForOp:$target,
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat = "$target attr-dict `:` type($target)";
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
MLIRAffineDialect
MLIRFuncDialect
MLIRIR
- MLIRPDLDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
using Base::Base;
void init() {
- declareDependentDialect<pdl::PDLDialect>();
-
declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<func::FuncDialect>();
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
- from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
"""Extension for GetParentForOp."""
def __init__(self,
+ result_type: Type,
target: Union[Operation, Value],
*,
num_loops: int = 1,
ip=None,
loc=None):
super().__init__(
- pdl.OperationType.get(),
+ result_type,
_get_op_result_or_value(target),
num_loops=_get_int64_attr(num_loops, default_value=1),
ip=ip,
"""Extension for LoopOutlineOp."""
def __init__(self,
+ result_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None):
super().__init__(
- pdl.OperationType.get(),
+ result_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
"""Extension for LoopPeelOp."""
def __init__(self,
+ result_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None):
super().__init__(
- pdl.OperationType.get(),
+ result_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
fail_if_already_divisible, BoolAttr) else
"""Extension for LoopPipelineOp."""
def __init__(self,
+ result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
ip=None,
loc=None):
super().__init__(
- pdl.OperationType.get(),
+ result_type,
_get_op_result_or_value(target),
iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
read_latency=_get_int64_attr(read_latency, default_value=10),
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
- transform.loop.peel %loops#0
+ %loop = transform.cast %loops#0 : !pdl.operation to !transform.op<"scf.for">
+ transform.loop.peel %loop : (!transform.op<"scf.for">) -> !pdl.operation
}
}
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
- transform.test_print_remark_at_operand %match_name, "matched op name"
+ transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
transform.test_consume_operand %match_name
%match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1
- transform.test_print_remark_at_operand %match_attr, "matched attr name"
+ transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation
transform.test_consume_operand %match_attr
}
}
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match
ops{["arith.constant"]} filter_result_type = f32 in %arg1
- transform.test_print_remark_at_operand %match_name, "matched op name"
+ transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
transform.test_consume_operand %match_name
}
}
ops{["linalg.generic"]}
attributes{iterator_types = ["parallel", "parallel", "parallel"]}
in %arg1
- transform.test_print_remark_at_operand %match_attr, "matched complex attr"
+ transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
transform.test_consume_operand %match_attr
%no_match = transform.structured.match
%0 = transform.structured.match ops{["memref.alloc"]} in %arg1
%1 = transform.memref.multibuffer %0 {factor = 2 : i64}
// Verify that the returned handle is usable.
- transform.test_print_remark_at_operand %1, "transformed"
+ transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
}
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// CHECK: = transform.loop.get_parent_for
- %1 = transform.loop.get_parent_for %0
- %2 = transform.loop.get_parent_for %0 { num_loops = 2 }
- %3 = transform.loop.get_parent_for %0 { num_loops = 3 }
- transform.test_print_remark_at_operand %1, "third loop"
- transform.test_print_remark_at_operand %2, "second loop"
- transform.test_print_remark_at_operand %3, "first loop"
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
+ %2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"scf.for">
+ %3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"scf.for">
+ transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for">
+ transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for">
+ transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for">
}
}
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// expected-error @below {{could not find an 'scf.for' parent}}
- %1 = transform.loop.get_parent_for %0
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
}
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
- %1 = transform.loop.get_parent_for %0
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
// CHECK: = transform.loop.outline %{{.*}}
- transform.loop.outline %1 {func_name = "foo"}
+ transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
}
}
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["scf.while"]} in %arg1
// expected-error @below {{failed to outline}}
- transform.loop.outline %0 {func_name = "foo"}
+ transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
}
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
- %1 = transform.loop.get_parent_for %0
- transform.loop.peel %1
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
+ transform.loop.peel %1 : (!transform.op<"scf.for">) -> !pdl.operation
}
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
- %1 = transform.loop.get_parent_for %0
- %2 = transform.loop.pipeline %1
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
+ %2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !pdl.operation
// Verify that the returned handle is usable.
- transform.test_print_remark_at_operand %2, "transformed"
+ transform.test_print_remark_at_operand %2, "transformed" : !pdl.operation
}
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
- %1 = transform.loop.get_parent_for %0
- transform.loop.unroll %1 { factor = 4 }
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
+ transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
}
}
// expected-note @below {{invalidated by this transform op that consumes its operand #0}}
test_consume_operand %1
// expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
- test_print_remark_at_operand %0, "remark"
+ test_print_remark_at_operand %0, "remark" : !pdl.operation
}
}
%2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
test_consume_operand %2
- test_print_remark_at_operand %0, "remark"
+ test_print_remark_at_operand %0, "remark" : !pdl.operation
}
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
- test_print_remark_at_operand %0, "matched"
+ test_print_remark_at_operand %0, "matched" : !pdl.operation
}
pdl.pattern @some : benefit(1) {
%f = pdl_match @const in %arg1 : (!pdl.operation) -> !pdl.operation
// CHECK: %{{.+}} = get_closest_isolated_parent %{{.+}}
%m = get_closest_isolated_parent %f : (!pdl.operation) -> !pdl.operation
- test_print_remark_at_operand %m, "parent function"
+ test_print_remark_at_operand %m, "parent function" : !pdl.operation
}
}
}, {
^bb2(%arg2: !pdl.operation):
%2 = transform.pdl_match @match_call in %arg2 : (!pdl.operation) -> !pdl.operation
- transform.test_print_remark_at_operand %2, "still here"
+ transform.test_print_remark_at_operand %2, "still here" : !pdl.operation
// This alternative succeeds.
}, {
^bb2(%arg2: !pdl.operation):
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.pdl_match @match_const in %arg1 : (!pdl.operation) -> !pdl.operation
- %1 = transform.loop.get_parent_for %0
+ %1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !pdl.operation
// expected-error @below {{only isolated-from-above ops can be alternative scopes}}
alternatives %1 : !pdl.operation {
^bb2(%arg2: !pdl.operation):
%0 = pdl_match @addi in %arg1 : (!pdl.operation) -> !pdl.operation
%1 = pdl_match @subi in %arg1 : (!pdl.operation) -> !pdl.operation
%2 = merge_handles %0, %1 : !pdl.operation
- test_print_remark_at_operand %2, "matched"
+ test_print_remark_at_operand %2, "matched" : !pdl.operation
}
}
^bb2(%arg2: !pdl.operation):
// expected-remark @below {{1}}
transform.test_print_number_of_associated_payload_ir_ops %arg2
- transform.test_print_remark_at_operand %arg2, "transform applied"
+ transform.test_print_remark_at_operand %arg2, "transform applied" : !pdl.operation
}
}
}
// expected-remark @below {{3}}
transform.test_print_number_of_associated_payload_ir_ops %results
- transform.test_print_remark_at_operand %results, "transform applied"
+ transform.test_print_remark_at_operand %results, "transform applied" : !pdl.operation
}
}
^bb1(%arg1: !pdl.operation):
%addi = transform.structured.match ops{["arith.addi"]} in %arg1
%muli = get_producer_of_operand %addi[0] : (!pdl.operation) -> !pdl.operation
- transform.test_print_remark_at_operand %muli, "found muli"
+ transform.test_print_remark_at_operand %muli, "found muli" : !pdl.operation
}
// -----
: Op<Transform_Dialect, "test_print_remark_at_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins
- Arg<PDL_Operation, "",
+ Arg<TransformTypeInterface, "",
[TransformMappingRead, PayloadIRRead]>:$operand,
StrAttr:$message);
- let assemblyFormat = "$operand `,` $message attr-dict";
+ let assemblyFormat =
+ "$operand `,` $message attr-dict `:` type($operand)";
let cppNamespace = "::mlir::test";
}
@run
def getParentLoop():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
+ sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+ [], pdl.OperationType.get())
with InsertionPoint(sequence.body):
- loop.GetParentForOp(sequence.bodyTarget, num_loops=2)
+ loop.GetParentForOp(transform.OperationType.get("scf.for"), sequence.bodyTarget, num_loops=2)
transform.YieldOp()
# CHECK-LABEL: TEST: getParentLoop
# CHECK: = transform.loop.get_parent_for %
@run
def loopOutline():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
+ sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+ [], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
- loop.LoopOutlineOp(sequence.bodyTarget, func_name="foo")
+ loop.LoopOutlineOp(pdl.OperationType.get(), sequence.bodyTarget, func_name="foo")
transform.YieldOp()
# CHECK-LABEL: TEST: loopOutline
# CHECK: = transform.loop.outline %
@run
def loopPeel():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
+ sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+ [], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
- loop.LoopPeelOp(sequence.bodyTarget)
+ loop.LoopPeelOp(pdl.OperationType.get(), sequence.bodyTarget)
transform.YieldOp()
# CHECK-LABEL: TEST: loopPeel
# CHECK: = transform.loop.peel %
@run
def loopPipeline():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
+ sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+ [], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
- loop.LoopPipelineOp(sequence.bodyTarget, iteration_interval=3)
+ loop.LoopPipelineOp(pdl.OperationType.get(), sequence.bodyTarget, iteration_interval=3)
transform.YieldOp()
# CHECK-LABEL: TEST: loopPipeline
# CHECK: = transform.loop.pipeline %
@run
def loopUnroll():
- sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], pdl.OperationType.get())
+ sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
+ [], transform.OperationType.get("scf.for"))
with InsertionPoint(sequence.body):
loop.LoopUnrollOp(sequence.bodyTarget, factor=42)
transform.YieldOp()