From 2e5fe721724446265d1ea48267b6a34d33fca14b Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Fri, 9 Dec 2022 18:50:36 +0100 Subject: [PATCH] [MLIR][Linalg] Use `DenseI64ArrayAttr` in `InterchangeOp` (NFC) Use op separator to improve code navigation. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D139917 --- .../Linalg/TransformOps/LinalgTransformOps.td | 84 +++++++++++++++++++++- .../Linalg/TransformOps/LinalgTransformOps.cpp | 22 ++---- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 2 +- .../Dialect/Linalg/transform-op-interchange.mlir | 4 +- .../test/Dialect/Linalg/transform-ops-invalid.mlir | 12 +++- mlir/test/Dialect/Linalg/transform-patterns.mlir | 2 +- 6 files changed, 102 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 1cac6b8..1cb321d 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -19,6 +19,10 @@ include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" include "mlir/IR/RegionKindInterface.td" +//===----------------------------------------------------------------------===// +// DecomposeOp +//===----------------------------------------------------------------------===// + def DecomposeOp : Op { @@ -48,6 +52,10 @@ def DecomposeOp : Op]> { @@ -67,6 +75,10 @@ def FuseOp : Op]> { @@ -120,6 +132,10 @@ def FuseIntoContainingOp : ]; } +//===----------------------------------------------------------------------===// +// GeneralizeOp +//===----------------------------------------------------------------------===// + def GeneralizeOp : Op { @@ -149,6 +165,10 @@ def GeneralizeOp : Op { @@ -169,10 +189,14 @@ def InterchangeOp : Op:$iterator_interchange); + ConfinedAttr, + [DenseArrayNonNegative]>:$iterator_interchange); let results = (outs PDL_Operation:$transformed); - let assemblyFormat = "$target attr-dict"; + let assemblyFormat = [{ + $target + (`iterator_interchange` `=` $iterator_interchange^)? attr-dict + }]; let hasVerifier = 1; let extraClassDeclaration = [{ @@ -183,6 +207,10 @@ def InterchangeOp : Op, @@ -245,6 +273,10 @@ def MatchOp : Op, TransformOpInterface, TransformEachOpTrait]> { @@ -309,6 +341,10 @@ def MultiTileSizesOp : Op { @@ -349,6 +385,10 @@ def PadOp : Op { @@ -388,6 +428,10 @@ def PromoteOp : Op, DeclareOpInterfaceMethods] # GraphRegionNoTerminator.traits> { @@ -410,6 +454,10 @@ def ReplaceOp : Op { @@ -449,6 +497,10 @@ def ScalarizeOp : Op, DeclareOpInterfaceMethods]> { @@ -481,6 +533,10 @@ def SplitOp : Op { @@ -649,6 +705,10 @@ def SplitReductionOp : Op { @@ -748,6 +808,10 @@ def TileReductionUsingScfOp : Op, DeclareOpInterfaceMethods]> { @@ -910,6 +978,10 @@ def TileOp : Op, DeclareOpInterfaceMethods]> { @@ -1080,6 +1156,10 @@ def TileToScfForOp : Op { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index c8995e6..3138268 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -34,16 +34,6 @@ using namespace mlir::transform; #define DEBUG_TYPE "linalg-transforms" -/// Extracts a vector of unsigned from an array attribute. Asserts if the -/// attribute contains values other than intergers. May truncate. -static SmallVector extractUIntArray(ArrayAttr attr) { - SmallVector result; - result.reserve(attr.size()); - for (APInt value : attr.getAsValueRange()) - result.push_back(value.getZExtValue()); - return result; -} - /// Attempts to apply the pattern specified as template argument to the given /// operation. The pattern is expected to have a `returningMatchAndRewrite` /// function that returns the "main" result or failure. Returns failure if the @@ -604,8 +594,7 @@ DiagnosedSilenceableFailure transform::InterchangeOp::applyToOne(linalg::GenericOp target, SmallVectorImpl &results, transform::TransformState &state) { - SmallVector interchangeVector = - extractUIntArray(getIteratorInterchange()); + ArrayRef interchangeVector = getIteratorInterchange(); // Exit early if no transformation is needed. if (interchangeVector.empty()) { results.push_back(target); @@ -613,7 +602,9 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target, } TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = - interchangeGenericOp(rewriter, target, interchangeVector); + interchangeGenericOp(rewriter, target, + SmallVector(interchangeVector.begin(), + interchangeVector.end())); if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); @@ -621,9 +612,8 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target, } LogicalResult transform::InterchangeOp::verify() { - SmallVector permutation = - extractUIntArray(getIteratorInterchange()); - auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); + ArrayRef permutation = getIteratorInterchange(); + auto sequence = llvm::to_vector(llvm::seq(0, permutation.size())); if (!std::is_permutation(sequence.begin(), sequence.end(), permutation.begin(), permutation.end())) { return emitOpError() diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index e3607b2..402e80b 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -257,7 +257,7 @@ LogicalResult transform::AlternativesOp::verify() { } //===----------------------------------------------------------------------===// -// ForeachOp +// CastOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir index 0f3a9fc..3b480d7 100644 --- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir @@ -21,7 +21,7 @@ func.func @interchange_generic(%arg0: tensor, %arg1: tensor) - transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.interchange %0 { iterator_interchange = [1, 0]} + transform.structured.interchange %0 iterator_interchange = [1, 0] } // ----- @@ -36,5 +36,5 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 // expected-error @below {{transform applied to the wrong op kind}} - transform.structured.interchange %0 { iterator_interchange = [1, 0]} + transform.structured.interchange %0 iterator_interchange = [1, 0] } diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir index 01bb8e8..e21b21a 100644 --- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir +++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir @@ -2,8 +2,8 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !pdl.operation): - // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}} - transform.structured.interchange %arg0 {iterator_interchange = [1, 1]} + // expected-error@below {{'transform.structured.interchange' op expects iterator_interchange to be a permutation, found 1, 1}} + transform.structured.interchange %arg0 iterator_interchange = [1, 1] } // ----- @@ -37,3 +37,11 @@ transform.sequence failures(propagate) { // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}} transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]} } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !pdl.operation): + // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}} + transform.structured.interchange %arg0 iterator_interchange = [-3, 1] +} diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 482cbc7..65ff4d6 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref>, transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]} + transform.structured.interchange %0 iterator_interchange = [1, 2, 0] } // CHECK-LABEL: func @permute_generic -- 2.7.4