include "mlir/IR/OpBase.td"
include "mlir/IR/RegionKindInterface.td"
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
}];
}
+//===----------------------------------------------------------------------===//
+// FuseOp
+//===----------------------------------------------------------------------===//
+
def FuseOp : Op<Transform_Dialect, "structured.fuse",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
def FuseIntoContainingOp :
Op<Transform_Dialect, "structured.fuse_into_containing_op",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
];
}
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
}];
}
+//===----------------------------------------------------------------------===//
+// InterchangeOp
+//===----------------------------------------------------------------------===//
+
def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
let arguments =
(ins PDL_Operation:$target,
- DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+ ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">,
+ [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$iterator_interchange);
let results = (outs PDL_Operation:$transformed);
- let assemblyFormat = "$target attr-dict";
+ let assemblyFormat = [{
+ $target
+ (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
+ }];
let hasVerifier = 1;
let extraClassDeclaration = [{
}];
}
+//===----------------------------------------------------------------------===//
+// MatchOp
+//===----------------------------------------------------------------------===//
+
def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
[
I32EnumAttrCase<"LinalgOp", 0>,
}];
}
+//===----------------------------------------------------------------------===//
+// MultiTileSizesOp
+//===----------------------------------------------------------------------===//
+
def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface, TransformEachOpTrait]> {
}];
}
+//===----------------------------------------------------------------------===//
+// PadOp
+//===----------------------------------------------------------------------===//
+
def PadOp : Op<Transform_Dialect, "structured.pad",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
}];
}
+//===----------------------------------------------------------------------===//
+// PromoteOp
+//===----------------------------------------------------------------------===//
+
def PromoteOp : Op<Transform_Dialect, "structured.promote",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
}];
}
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
def ReplaceOp : Op<Transform_Dialect, "structured.replace",
[IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait]> {
}];
}
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
def SplitOp : Op<Transform_Dialect, "structured.split",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let hasCustomAssemblyFormat = 1;
}
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
}];
}
+//===----------------------------------------------------------------------===//
+// TileReductionUsingScfOp
+//===----------------------------------------------------------------------===//
+
def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
}];
}
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
def TileReductionUsingForeachThreadOp :
Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
}
+//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
def TileOp : Op<Transform_Dialect, "structured.tile",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
}];
}
+//===----------------------------------------------------------------------===//
+// TileToForeachThreadOp
+//===----------------------------------------------------------------------===//
+
def TileToForeachThreadOp :
Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
[AttrSizedOperandSegments,
}];
}
+//===----------------------------------------------------------------------===//
+// TileToScfForOp
+//===----------------------------------------------------------------------===//
+
def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
}];
}
+//===----------------------------------------------------------------------===//
+// VectorizeOp
+//===----------------------------------------------------------------------===//
+
def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformEachOpTrait, TransformOpInterface]> {
#define DEBUG_TYPE "linalg-transforms"
-/// Extracts a vector of unsigned from an array attribute. Asserts if the
-/// attribute contains values other than intergers. May truncate.
-static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
- SmallVector<unsigned> result;
- result.reserve(attr.size());
- for (APInt value : attr.getAsValueRange<IntegerAttr>())
- result.push_back(value.getZExtValue());
- return result;
-}
-
/// Attempts to apply the pattern specified as template argument to the given
/// operation. The pattern is expected to have a `returningMatchAndRewrite`
/// function that returns the "main" result or failure. Returns failure if the
transform::InterchangeOp::applyToOne(linalg::GenericOp target,
SmallVectorImpl<Operation *> &results,
transform::TransformState &state) {
- SmallVector<unsigned> interchangeVector =
- extractUIntArray(getIteratorInterchange());
+ ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
// Exit early if no transformation is needed.
if (interchangeVector.empty()) {
results.push_back(target);
}
TrivialPatternRewriter rewriter(target->getContext());
FailureOr<GenericOp> res =
- interchangeGenericOp(rewriter, target, interchangeVector);
+ interchangeGenericOp(rewriter, target,
+ SmallVector<unsigned>(interchangeVector.begin(),
+ interchangeVector.end()));
if (failed(res))
return DiagnosedSilenceableFailure::definiteFailure();
results.push_back(res->getOperation());
}
LogicalResult transform::InterchangeOp::verify() {
- SmallVector<unsigned> permutation =
- extractUIntArray(getIteratorInterchange());
- auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+ ArrayRef<int64_t> permutation = getIteratorInterchange();
+ auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
if (!std::is_permutation(sequence.begin(), sequence.end(),
permutation.begin(), permutation.end())) {
return emitOpError()