[MLIR][Linalg] Use `DenseI64ArrayAttr` in `InterchangeOp` (NFC)
authorLorenzo Chelini <l.chelini@icloud.com>
Fri, 9 Dec 2022 17:50:36 +0000 (18:50 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Fri, 16 Dec 2022 15:37:33 +0000 (16:37 +0100)
Use op separator to improve code navigation.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D139917

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-interchange.mlir
mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
mlir/test/Dialect/Linalg/transform-patterns.mlir

index 1cac6b8..1cb321d 100644 (file)
@@ -19,6 +19,10 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
 def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -48,6 +52,10 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FuseOp
+//===----------------------------------------------------------------------===//
+
 def FuseOp : Op<Transform_Dialect, "structured.fuse",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -67,6 +75,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// FuseIntoContainingOp
+//===----------------------------------------------------------------------===//
+
 def FuseIntoContainingOp :
     Op<Transform_Dialect, "structured.fuse_into_containing_op",
       [DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -120,6 +132,10 @@ def FuseIntoContainingOp :
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
 def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -149,6 +165,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// InterchangeOp
+//===----------------------------------------------------------------------===//
+
 def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait]> {
@@ -169,10 +189,14 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
 
   let arguments =
     (ins PDL_Operation:$target,
-         DefaultValuedAttr<I64ArrayAttr, "{}">:$iterator_interchange);
+         ConfinedAttr<DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">,
+                      [DenseArrayNonNegative<DenseI64ArrayAttr>]>:$iterator_interchange);
   let results = (outs PDL_Operation:$transformed);
 
-  let assemblyFormat = "$target attr-dict";
+  let assemblyFormat = [{ 
+    $target 
+    (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
+  }];
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
@@ -183,6 +207,10 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MatchOp
+//===----------------------------------------------------------------------===//
+
 def MatchInterfaceEnum : I32EnumAttr<"MatchInterfaceEnum", "An interface to match",
     [
       I32EnumAttrCase<"LinalgOp", 0>,
@@ -245,6 +273,10 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MultiTileSizesOp
+//===----------------------------------------------------------------------===//
+
 def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -309,6 +341,10 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PadOp
+//===----------------------------------------------------------------------===//
+
 def PadOp : Op<Transform_Dialect, "structured.pad",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -349,6 +385,10 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// PromoteOp
+//===----------------------------------------------------------------------===//
+
 def PromoteOp : Op<Transform_Dialect, "structured.promote",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     TransformOpInterface, TransformEachOpTrait]> {
@@ -388,6 +428,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ReplaceOp
+//===----------------------------------------------------------------------===//
+
 def ReplaceOp : Op<Transform_Dialect, "structured.replace",
     [IsolatedFromAbove, DeclareOpInterfaceMethods<TransformOpInterface>,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>] # GraphRegionNoTerminator.traits> {
@@ -410,6 +454,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ScalarizeOp
+//===----------------------------------------------------------------------===//
+
 def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
@@ -449,6 +497,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// SplitOp
+//===----------------------------------------------------------------------===//
+
 def SplitOp : Op<Transform_Dialect, "structured.split",
     [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TransformOpInterface>]> {
@@ -481,6 +533,10 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// SplitReductionOp
+//===----------------------------------------------------------------------===//
+
 def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
@@ -649,6 +705,10 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingScfOp
+//===----------------------------------------------------------------------===//
+
 def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_using_scf",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
         TransformEachOpTrait, TransformOpInterface]> {
@@ -748,6 +808,10 @@ def TileReductionUsingScfOp : Op<Transform_Dialect, "structured.tile_reduction_u
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileReductionUsingForeachThreadOp
+//===----------------------------------------------------------------------===//
+
 def TileReductionUsingForeachThreadOp :
   Op<Transform_Dialect, "structured.tile_reduction_using_foreach_thread",
        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
@@ -853,6 +917,10 @@ def TileReductionUsingForeachThreadOp :
 
 }
 
+//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
 def TileOp : Op<Transform_Dialect, "structured.tile",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -910,6 +978,10 @@ def TileOp : Op<Transform_Dialect, "structured.tile",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileToForeachThreadOp
+//===----------------------------------------------------------------------===//
+
 def TileToForeachThreadOp :
     Op<Transform_Dialect, "structured.tile_to_foreach_thread_op",
       [AttrSizedOperandSegments,
@@ -1023,6 +1095,10 @@ def TileToForeachThreadOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TileToScfForOp
+//===----------------------------------------------------------------------===//
+
 def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
@@ -1080,6 +1156,10 @@ def TileToScfForOp : Op<Transform_Dialect, "structured.tile_to_scf_for",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// VectorizeOp
+//===----------------------------------------------------------------------===//
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
index c8995e6..3138268 100644 (file)
@@ -34,16 +34,6 @@ using namespace mlir::transform;
 
 #define DEBUG_TYPE "linalg-transforms"
 
-/// Extracts a vector of unsigned from an array attribute. Asserts if the
-/// attribute contains values other than intergers. May truncate.
-static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
-  SmallVector<unsigned> result;
-  result.reserve(attr.size());
-  for (APInt value : attr.getAsValueRange<IntegerAttr>())
-    result.push_back(value.getZExtValue());
-  return result;
-}
-
 /// Attempts to apply the pattern specified as template argument to the given
 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
 /// function that returns the "main" result or failure. Returns failure if the
@@ -604,8 +594,7 @@ DiagnosedSilenceableFailure
 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
                                      SmallVectorImpl<Operation *> &results,
                                      transform::TransformState &state) {
-  SmallVector<unsigned> interchangeVector =
-      extractUIntArray(getIteratorInterchange());
+  ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
   // Exit early if no transformation is needed.
   if (interchangeVector.empty()) {
     results.push_back(target);
@@ -613,7 +602,9 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
   }
   TrivialPatternRewriter rewriter(target->getContext());
   FailureOr<GenericOp> res =
-      interchangeGenericOp(rewriter, target, interchangeVector);
+      interchangeGenericOp(rewriter, target,
+                           SmallVector<unsigned>(interchangeVector.begin(),
+                                                 interchangeVector.end()));
   if (failed(res))
     return DiagnosedSilenceableFailure::definiteFailure();
   results.push_back(res->getOperation());
@@ -621,9 +612,8 @@ transform::InterchangeOp::applyToOne(linalg::GenericOp target,
 }
 
 LogicalResult transform::InterchangeOp::verify() {
-  SmallVector<unsigned> permutation =
-      extractUIntArray(getIteratorInterchange());
-  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
+  ArrayRef<int64_t> permutation = getIteratorInterchange();
+  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
   if (!std::is_permutation(sequence.begin(), sequence.end(),
                            permutation.begin(), permutation.end())) {
     return emitOpError()
index e3607b2..402e80b 100644 (file)
@@ -257,7 +257,7 @@ LogicalResult transform::AlternativesOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// ForeachOp
+// CastOp
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
index 0f3a9fc..3b480d7 100644 (file)
@@ -21,7 +21,7 @@ func.func @interchange_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  transform.structured.interchange %0 iterator_interchange = [1, 0] 
 }
 
 // -----
@@ -36,5 +36,5 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
   // expected-error @below {{transform applied to the wrong op kind}}
-  transform.structured.interchange %0 { iterator_interchange = [1, 0]}
+  transform.structured.interchange %0 iterator_interchange = [1, 0]
 }
index 01bb8e8..e21b21a 100644 (file)
@@ -2,8 +2,8 @@
 
 transform.sequence failures(propagate) {
 ^bb0(%arg0: !pdl.operation):
-  // expected-error@below {{expects iterator_interchange to be a permutation, found [1, 1]}}
-  transform.structured.interchange %arg0 {iterator_interchange = [1, 1]}
+  // expected-error@below {{'transform.structured.interchange' op expects iterator_interchange to be a permutation, found 1, 1}}
+  transform.structured.interchange %arg0 iterator_interchange = [1, 1] 
 }
 
 // -----
@@ -37,3 +37,11 @@ transform.sequence failures(propagate) {
   // expected-error@below {{expects transpose_paddings to be a permutation, found [1, 1]}}
   transform.structured.pad %arg0 {transpose_paddings=[[1, 1]]}
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !pdl.operation):
+  // expected-error@below {{'transform.structured.interchange' op attribute 'iterator_interchange' failed to satisfy constraint: i64 dense array attribute whose value is non-negative}}
+  transform.structured.interchange %arg0 iterator_interchange = [-3, 1]
+}
index 482cbc7..65ff4d6 100644 (file)
@@ -138,7 +138,7 @@ func.func @permute_generic(%A: memref<?x?xf32, strided<[?, 1], offset: ?>>,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  transform.structured.interchange %0 {iterator_interchange = [1, 2, 0]}
+  transform.structured.interchange %0 iterator_interchange = [1, 2, 0]
 }
 
 // CHECK-LABEL:  func @permute_generic