[mlir][Linalg] Introduce a ContractionOpInterface
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 4 Feb 2021 16:49:09 +0000 (16:49 +0000)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Thu, 4 Feb 2021 16:53:24 +0000 (16:53 +0000)
This revision takes advantage of recent extensions to vectorization to refactor contraction detection into a bona fide Linalg interface.
The mlit-linalg-ods-gen parser is extended to support adding such interfaces.
The detection that was originally enabling vectorization is refactored to serve as both a test on a generic LinalgOp as well as to verify ops that declare to conform to that interface.

This is plugged through Linalg transforms and strategies but it quickly becomes evident that the complexity and rigidity of the C++ class based templating does not pay for itself.
Therefore, this revision changes the API for vectorization patterns to get rid of templates as much as possible.
Variadic templates are relegated to the internals of LinalgTransformationFilter as much as possible and away from the user-facing APIs.

It is expected other patterns / transformations will follow the same path and drop as much C++ templating as possible from the class definition.

Differential revision: https://reviews.llvm.org/D95973

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/lib/Transforms/TestLinalgCodegenStrategy.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

index e4fddd5..fb419d0 100644 (file)
 
 namespace mlir {
 namespace linalg {
+class LinalgOp;
 
 /// Returns the values obtained by applying `map` to the list of values.
 SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
                                        AffineMap map, ValueRange values);
 
+/// Checks whether `linalgOp` conforms to ContractionOpInterface.
+// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
+bool isaContractionOpInterface(LinalgOp linalgOp);
+
 namespace detail {
 
+/// Verify that `op` conforms to ContractionOpInterface.
+LogicalResult verifyContractionInterface(Operation *op);
+
 /// Verify that `op` conforms to the invariants of StructuredOpInterface
 LogicalResult verifyStructuredOpInterface(Operation *op);
 
index a38b04c..dd15f22 100644 (file)
 
 include "mlir/IR/OpBase.td"
 
-// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp'
-// interface.
+// The 'LinalgContractionOpInterface' provides access to the
+// 'ContractionOpInterface'.
+def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> {
+  let description = [{
+   A Linalg contraction is defined in general terms:
+     1. Has 2 input and 1 output shapes.
+     2. Has at least one reduction dimension.
+     3. Has only projected permutation indexing maps.
+     4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
+     (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
+     operations that may change the type (e.g. for mixed-precision).
+   As a consequence, when vectorization of such an op occurs, the only special
+   behavior is that the (unique) MulOpType is vectorized into a
+   `vector.contract`. All other ops are handled in a generic fashion.
+   In the future, we may wish to allow more input arguments and elementwise and
+   constant operations that do not involve the reduction dimension(s).
+  }];
+  let cppNamespace = "::mlir::linalg";
+  let verify = [{ return detail::verifyContractionInterface($_op); }];
+}
+
+// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
 def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let cppNamespace = "::mlir::linalg";
   let methods = [
index fc09243..5406b30 100644 (file)
@@ -1,36 +1,43 @@
-ods_def<MatmulOp>:
+ods_def<MatmulOp>
+implements_interface<LinalgContractionOpInterface> :
 def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
   C(m, n) = std_addf<k>(std_mulf(A(m, k), B(k, n)));
 }
 
-ods_def<MatmulColumnMajorOp>:
+ods_def<MatmulColumnMajorOp>
+implements_interface<LinalgContractionOpInterface> :
 def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
   C(n, m) = std_addf<k>(std_mulf(A(k, m), B(n, k)));
 }
 
-ods_def<MatmulI8I8I32Op>:
+ods_def<MatmulI8I8I32Op>
+implements_interface<LinalgContractionOpInterface> :
 def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
   // TODO: ideally something closer to
   //   C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
   C(m, n) = std_addi<k>(std_sexti32(std_muli(A(m, k), B(k, n))));
 }
 
-ods_def<MatvecOp>:
+ods_def<MatvecOp>
+implements_interface<LinalgContractionOpInterface> :
 def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
   x(m) = std_addf<n>(std_mulf(A(m, n), y(n)));
 }
 
-ods_def<VecmatOp>:
+ods_def<VecmatOp>
+implements_interface<LinalgContractionOpInterface> :
 def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
   x(n) = std_addf<m>(std_mulf(y(m), A(m, n)));
 }
 
-ods_def<DotOp>:
+ods_def<DotOp>
+implements_interface<LinalgContractionOpInterface> :
 def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
   C() = std_addf<m>(std_mulf(A(m), B(m)));
 }
 
-ods_def<BatchMatmulOp>:
+ods_def<BatchMatmulOp>
+implements_interface<LinalgContractionOpInterface> :
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(b, k, n)));
 }
index 2e0796d..fe481bb 100644 (file)
@@ -35,23 +35,33 @@ template <template <typename> class PatternType, typename ConcreteOpType,
           typename OptionsType,
           typename = std::enable_if_t<std::is_member_function_pointer<
               decltype(&ConcreteOpType::getOperationName)>::value>>
-void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
+void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
                     MLIRContext *context, StringRef opName,
                     linalg::LinalgTransformationFilter m) {
   assert(opName == ConcreteOpType::getOperationName() &&
          "explicit name must match ConcreteOpType::getOperationName");
-  patterList.insert<PatternType<ConcreteOpType>>(context, options, m);
+  patternList.insert<PatternType<ConcreteOpType>>(context, options, m);
 }
 
 /// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
 /// (e.g. LinalgOp, other interfaces, Operation*).
 template <template <typename> class PatternType, typename OpType,
           typename OptionsType>
-void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options,
+void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
                     MLIRContext *context, StringRef opName,
                     linalg::LinalgTransformationFilter m) {
   assert(!opName.empty() && "opName must not be empty");
-  patterList.insert<PatternType<OpType>>(opName, context, options, m);
+  patternList.insert<PatternType<OpType>>(opName, context, options, m);
+}
+
+template <typename PatternType, typename OpType, typename OptionsType>
+void enqueue(OwningRewritePatternList &patternList, OptionsType options,
+             MLIRContext *context, StringRef opName,
+             linalg::LinalgTransformationFilter m) {
+  if (!opName.empty())
+    patternList.insert<PatternType>(opName, context, options, m);
+  else
+    patternList.insert<PatternType>(m.addOpFilter<OpType>(), options);
 }
 
 /// Promotion transformation enqueues a particular stage-1 pattern for
@@ -112,13 +122,12 @@ private:
 /// Vectorization transformation enqueues a particular stage-1 pattern for
 /// `LinalgVectorizationPattern<LinalgOpType>` as well as copy to vector
 /// transfer rewrite forwarding patterns.
-template <typename LinalgOpType>
+template <typename LinalgOpType = LinalgOp>
 struct Vectorize : public Transformation {
   explicit Vectorize(
       linalg::LinalgVectorizationOptions options,
       linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(f), opName(LinalgOpType::getOperationName()),
-        options(options) {}
+      : Transformation(f), opName(), options(options) {}
 
   Vectorize(StringRef name, linalg::LinalgVectorizationOptions options,
             linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -128,7 +137,7 @@ struct Vectorize : public Transformation {
   buildRewritePatterns(MLIRContext *context,
                        linalg::LinalgTransformationFilter m) override {
     OwningRewritePatternList vectorizationPatterns;
-    sfinae_enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
+    enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
         vectorizationPatterns, options, context, opName, m);
     vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
                                  linalg::LinalgCopyVTWForwardingPattern>(
@@ -235,16 +244,6 @@ struct CodegenStrategy {
             linalg::LinalgVectorizationOptions(), f));
     return *this;
   }
-  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
-  template <typename LinalgOpType>
-  CodegenStrategy &
-  vectorize(StringRef opName,
-            linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    transformationSequence.emplace_back(
-        std::make_unique<Vectorize<LinalgOpType>>(
-            opName, linalg::LinalgVectorizationOptions(), f));
-    return *this;
-  }
   /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
   /// operation.
   template <typename LinalgOpType>
@@ -254,13 +253,21 @@ struct CodegenStrategy {
     return b ? vectorize<LinalgOpType>(f) : *this;
     return *this;
   }
+  /// Append a pattern to rewrite `LinalgOpType` as a vector operation.
+  CodegenStrategy &
+  vectorize(StringRef opName,
+            linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
+    assert(!opName.empty() && "expected an op name");
+    transformationSequence.emplace_back(std::make_unique<Vectorize<LinalgOp>>(
+        opName, linalg::LinalgVectorizationOptions(), f));
+    return *this;
+  }
   /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector
   /// operation.
-  template <typename LinalgOpType>
   CodegenStrategy &
   vectorizeIf(bool b, StringRef opName,
               linalg::LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? vectorize<LinalgOpType>(opName, f) : *this;
+    return b ? vectorize(opName, f) : *this;
     return *this;
   }
   /// Configure the post staged-patterns late vector transformations.
index 18cb91e..16203e5 100644 (file)
@@ -340,9 +340,20 @@ struct LinalgTransformationFilter {
   void replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                          Operation *op) const;
 
+  LinalgTransformationFilter &addFilter(FilterFunction f) {
+    if (f)
+      filters.push_back(f);
+    return *this;
+  }
+  template <typename... OpTypes>
+  LinalgTransformationFilter &addOpFilter() {
+    return addFilter(
+        [](Operation *op) { return success(isa<OpTypes...>(op)); });
+  }
+
 private:
-  FilterFunction filter;
-  SmallVector<Identifier, 4> matchDisjunction;
+  SmallVector<FilterFunction> filters;
+  SmallVector<Identifier> matchDisjunction;
   Optional<Identifier> replacement;
 };
 
@@ -350,7 +361,7 @@ private:
 /// Linalg tiling patterns.
 ///
 /// Apply the `tileLinalgOp` transformation as a pattern.
-/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `tileLinalgOp` for more details.
 enum class LinalgTilingLoopType {
   Loops = 0,
@@ -443,19 +454,19 @@ struct LinalgBaseTilingPattern : public RewritePattern {
   // Entry point to match any LinalgOp OpInterface.
   LinalgBaseTilingPattern(
       LinalgTilingOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   // Entry point to match a specific Linalg op.
   LinalgBaseTilingPattern(
       StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
                                     TiledLinalgOp &result) const;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter marker;
+  LinalgTransformationFilter filter;
   /// Options to control tiling;
   LinalgTilingOptions options;
 };
@@ -467,17 +478,17 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
   template <typename ConcreateOpTy = OpTy>
   LinalgTilingPattern(
       MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
-                                options, marker, benefit) {}
+                                options, filter, benefit) {}
 
   /// This constructor is available to anyone.
   LinalgTilingPattern(
       StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : LinalgBaseTilingPattern(opName, context, options, marker, benefit) {}
+      : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
@@ -507,7 +518,7 @@ struct LinalgBaseTileAndFusePattern : public RewritePattern {
       StringRef opName, MLIRContext *context,
       const LinalgDependenceGraph &dependenceGraph,
       LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
       LinalgTransformationFilter originalOpMarker =
           LinalgTransformationFilter(),
@@ -523,13 +534,13 @@ private:
   /// Options to control fusion.
   LinalgFusionOptions fusionOptions;
   /// Marker to control application of the pattern.
-  LinalgTransformationFilter marker;
+  LinalgTransformationFilter filter;
   /// Marker set on the fused op after tile and fuse.
   LinalgTransformationFilter fusedOpMarker;
   /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
   /// to build the dependence graph changes then the dependenceGraph needs to be
   /// recomputed right now. To not invalidate the dependenceGraph as
-  /// transformation happens, the original producer can be tagged with a marker
+  /// transformation happens, the original producer can be tagged with a filter
   /// that can be later used to delete the original operations.
   LinalgTransformationFilter originalOpMarker;
 };
@@ -539,34 +550,34 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
   LinalgTileAndFusePattern(
       MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
       LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
       LinalgTransformationFilter originalOpMarker =
           LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : LinalgBaseTileAndFusePattern(
             OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
-            fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
+            fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {}
 };
 
 ///
 /// Linalg interchange patterns.
 ///
 /// Apply the `interchange` transformation as a pattern.
-/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `interchange` for more details.
 struct LinalgBaseInterchangePattern : public RewritePattern {
   LinalgBaseInterchangePattern(
       StringRef opName, MLIRContext *context,
       ArrayRef<unsigned> interchangeVector,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter marker;
+  LinalgTransformationFilter filter;
   /// The interchange vector to reorder the iterators and indexing_maps dims.
   SmallVector<unsigned, 8> interchangeVector;
 };
@@ -575,22 +586,22 @@ template <typename OpTy>
 struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
   LinalgInterchangePattern(
       MLIRContext *context, ArrayRef<unsigned> interchangeVector,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : LinalgBaseInterchangePattern(OpTy::getOperationName(), context,
-                                     interchangeVector, marker, benefit) {}
+                                     interchangeVector, filter, benefit) {}
 };
 
 ///
 /// Linalg promotion patterns.
 ///
 /// Apply the `promoteSubViews` transformation as a pattern.
-/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `promoteSubViews` for more details.
 struct LinalgBasePromotionPattern : public RewritePattern {
   LinalgBasePromotionPattern(
       StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
 
   LogicalResult matchAndRewrite(Operation *op,
@@ -598,7 +609,7 @@ struct LinalgBasePromotionPattern : public RewritePattern {
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter marker;
+  LinalgTransformationFilter filter;
   /// Promotion options.
   LinalgPromotionOptions options;
 };
@@ -610,67 +621,112 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
   template <typename ConcreateOpTy = OpTy>
   LinalgPromotionPattern(
       MLIRContext *context, LinalgPromotionOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
-                                   marker, benefit) {}
+                                   filter, benefit) {}
   /// This constructor is available to anyone.
   LinalgPromotionPattern(
       StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : LinalgBasePromotionPattern(opName, context, options, marker, benefit) {}
+      : LinalgBasePromotionPattern(opName, context, options, filter, benefit) {}
 };
 
 ///
 /// Linalg vectorization patterns.
 ///
 /// Apply the `vectorizeLinalgOp` transformation as a pattern.
-/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `vectorizeLinalgOp` for more details.
 
 /// Empty for now, used for SFINAE purposes only.
 struct LinalgVectorizationOptions {};
 
 struct LinalgBaseVectorizationPattern : public RewritePattern {
+  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
+  LinalgBaseVectorizationPattern(LinalgTransformationFilter filter,
+                                 PatternBenefit benefit = 1);
+  /// Name-based constructor with an optional `filter`.
   LinalgBaseVectorizationPattern(
       StringRef opName, MLIRContext *context,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter marker;
+  LinalgTransformationFilter filter;
 };
 
-template <typename OpTy>
 struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
-  /// SFINAE: This constructor can only trigger for concrete ops that have a
-  /// static `getOperationName` method.
-  template <typename ConcreateOpTy = OpTy>
+  /// These constructors are available to anyone.
+  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
   LinalgVectorizationPattern(
-      MLIRContext *context,
+      LinalgTransformationFilter filter,
       LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : LinalgBaseVectorizationPattern(OpTy::getOperationName(), context,
-                                       marker, benefit) {}
-  /// This constructor is available to anyone.
+      : LinalgBaseVectorizationPattern(filter, benefit) {}
+  /// Name-based constructor with an optional `filter`.
   LinalgVectorizationPattern(
       StringRef opName, MLIRContext *context,
       LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : LinalgBaseVectorizationPattern(opName, context, marker, benefit) {}
+      : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
 };
 
+/// Trait to check if T provides a `getOperationName` method.
+template <typename T, typename... Args>
+using has_get_operation_name = decltype(T::getOperationName());
+template <typename T>
+using detect_has_get_operation_name =
+    llvm::is_detected<has_get_operation_name, T>;
+
+/// SFINAE helper for single C++ op with a `getOperationName` method.
+template <
+    typename OpType,
+    typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
+    typename = void>
+void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
+                                    MLIRContext *context,
+                                    linalg::LinalgVectorizationOptions options,
+                                    linalg::LinalgTransformationFilter f) {
+  patternList.insert<linalg::LinalgVectorizationPattern>(
+      OpType::getOperationName(), context, options, f);
+}
+
+/// SFINAE helper for single C++ class without a `getOperationName` method (e.g.
+/// an OpInterface).
+template <typename OpType, typename = std::enable_if_t<
+                               !detect_has_get_operation_name<OpType>::value>>
+void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
+                                    MLIRContext *context,
+                                    linalg::LinalgVectorizationOptions options,
+                                    linalg::LinalgTransformationFilter f) {
+  patternList.insert<linalg::LinalgVectorizationPattern>(
+      f.addOpFilter<OpType>(), options);
+}
+
+/// Variadic helper function to insert vectorization patterns for C++ ops.
+template <typename... OpTypes>
+void insertVectorizationPatterns(OwningRewritePatternList &patternList,
+                                 MLIRContext *context,
+                                 linalg::LinalgVectorizationOptions options,
+                                 linalg::LinalgTransformationFilter f =
+                                     linalg::LinalgTransformationFilter()) {
+  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+  (void)std::initializer_list<int>{0, (insertVectorizationPatternImpl<OpTypes>(
+                                           patternList, context, options, f),
+                                       0)...};
+}
+
 ///
 /// Linalg lowering patterns.
 ///
 /// Apply the `linalgLowerOpToLoops` transformation as a pattern.
-/// `marker` controls LinalgTransformMarker matching and update when specified.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `linalgLowerOpToLoops` for more details.
 enum class LinalgLoweringType {
   LibraryCall = 0,
@@ -683,10 +739,10 @@ template <typename OpTy>
 struct LinalgLoweringPattern : public RewritePattern {
   LinalgLoweringPattern(
       MLIRContext *context, LinalgLoweringType loweringType,
-      LinalgTransformationFilter marker = LinalgTransformationFilter(),
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
       ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
       : RewritePattern(OpTy::getOperationName(), {}, benefit, context),
-        marker(marker), loweringType(loweringType),
+        filter(filter), loweringType(loweringType),
         interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 
   // TODO: Move implementation to .cpp once named ops are auto-generated.
@@ -695,7 +751,7 @@ struct LinalgLoweringPattern : public RewritePattern {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return failure();
-    if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    if (failed(filter.checkAndNotify(rewriter, linalgOp)))
       return failure();
 
     switch (loweringType) {
@@ -722,7 +778,7 @@ struct LinalgLoweringPattern : public RewritePattern {
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter marker;
+  LinalgTransformationFilter filter;
   /// Controls whether the pattern lowers to library calls, scf.for, affine.for
   /// or scf.parallel.
   LinalgLoweringType loweringType;
@@ -736,13 +792,13 @@ private:
 /// linalg.generic ops.
 void populateLinalgNamedOpsGeneralizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
-    LinalgTransformationFilter marker = LinalgTransformationFilter());
+    LinalgTransformationFilter filter = LinalgTransformationFilter());
 
 /// Populates `patterns` with patterns to convert linalg.conv ops to
 /// linalg.generic ops.
 void populateLinalgConvGeneralizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
-    LinalgTransformationFilter marker = LinalgTransformationFilter());
+    LinalgTransformationFilter filter = LinalgTransformationFilter());
 
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
index f9b17dd..0aafc4f 100644 (file)
@@ -19,6 +19,139 @@ using namespace mlir::linalg;
 /// Include the definitions of the copy operation interface.
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// ContractionOpInterface implementation
+//===----------------------------------------------------------------------===//
+
+/// Return true if the use-def chain from `v` to `from` consists of 0 or more
+/// unary single-operand operations.
+// TODO: relax to multi-operands with constants, which are technically unary ops
+// as needed (e.g. add5).
+static bool isChainOfUnaryOpsFrom(Value v, Value from) {
+  while (true) {
+    if (v == from)
+      return true;
+    Operation *op = v.getDefiningOp();
+    if (!op || op->getNumOperands() != 1)
+      return false;
+    v = op->getOperand(0);
+  };
+}
+
+/// Return the unique instance of OpType in `block` if it is indeed unique.
+/// Return null if none or more than 1 instances exist.
+template <typename OpType>
+static OpType getSingleOpOfType(Block &block) {
+  OpType res = nullptr;
+  block.walk([&](OpType op) {
+    if (res) {
+      res = nullptr;
+      return WalkResult::interrupt();
+    }
+    res = op;
+    return WalkResult::advance();
+  });
+  return res;
+}
+
+/// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
+/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
+/// unary operations that may change the type.
+template <typename AddOpType, typename MulOpType>
+static bool isAddMul(Block &block) {
+  if (block.getNumArguments() != 3)
+    return false;
+  Operation *yieldOp = block.getTerminator();
+  if (yieldOp->getNumOperands() != 1)
+    return false;
+
+  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
+  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
+  if (!addOp || !mulOp)
+    return false;
+
+  Value argA = block.getArgument(0), argB = block.getArgument(1);
+  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
+  Value mul = mulOp->getResult(0);
+  Value argC = block.getArgument(2);
+  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
+  Value add = addOp->getResult(0);
+  Value res = yieldOp->getOperand(0);
+  // Result traces back to add.
+  auto un = isChainOfUnaryOpsFrom;
+  bool success = un(res, add);
+  // One of the operands of add traces back to argC, the other to the mul.
+  success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
+  // One of the operands of mul traces back to argA, the other to argB.
+  success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
+  return success;
+}
+
+enum MatchContractionResult {
+  Success = 0,
+  NotLinalgOp,
+  WrongNumOperands,
+  NoReduction,
+  NotProjectedPermutations,
+  NotAddMul
+};
+static MatchContractionResult isContractionInterfaceImpl(Operation *op) {
+  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
+  if (!linalgOp)
+    return MatchContractionResult::NotLinalgOp;
+  if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
+    return MatchContractionResult::WrongNumOperands;
+  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
+  if (linalgOp.getNumReductionLoops() == 0)
+    return MatchContractionResult::NoReduction;
+  if (llvm::any_of(mapRange,
+                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
+    return MatchContractionResult::NotProjectedPermutations;
+  // TODO: more fields than add/mul.
+  if (!isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) &&
+      !isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front()))
+    return MatchContractionResult::NotAddMul;
+  return MatchContractionResult::Success;
+}
+
+bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
+  Operation *op = linalgOp.getOperation();
+  return isa<ContractionOpInterface>(op) ||
+         (isContractionInterfaceImpl(op) == MatchContractionResult::Success);
+}
+
+/// Verify that a LinalgOp `op` is a contraction.
+/// A Linalg contraction is defined in general terms:
+///   1. Has 2 input and 1 output shapes.
+///   2. Has at least one reduction dimension.
+///   3. Has only projected permutation indexing maps.
+///   4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
+///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
+///   operations that may change the type (e.g. for mixed-precision).
+/// As a consequence, when vectorization of such an op occurs, the only special
+/// behavior is that the (unique) MulOpType is vectorized into a
+/// `vector.contract`. All other ops are handled in a generic fashion.
+/// In the future, we may wish to allow more input arguments and elementwise and
+/// constant operations that do not involve the reduction dimension(s).
+LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
+  auto res = isContractionInterfaceImpl(op);
+  if (res == MatchContractionResult::NotLinalgOp)
+    return op->emitError("expected a LinalgOp");
+  if (res == MatchContractionResult::WrongNumOperands)
+    return op->emitError("expected op with 2 inputs and 1 outputs");
+  if (res == MatchContractionResult::NoReduction)
+    return op->emitError("expected at least a reduction loop");
+  if (res == MatchContractionResult::NotProjectedPermutations)
+    return op->emitError("expected all indexings to be projected permutations");
+  if (res == MatchContractionResult::NotAddMul)
+    return op->emitError("(add, mul) operations not found");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// StructuredOpInterface implementation
+//===----------------------------------------------------------------------===//
+
 /// Fully compose map with operands and canonicalize the result.
 /// Return the `createOrFold`'ed AffineApply op.
 static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc,
index fc647ea..8dac82a 100644 (file)
@@ -48,44 +48,48 @@ const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
 
 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
     ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
-    : LinalgTransformationFilter([](Operation *) { return success(); },
-                                 matchDisjunction, replacement) {}
+    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
+      replacement(replacement) {}
 
 mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
     FilterFunction f, ArrayRef<Identifier> matchDisjunction,
     Optional<Identifier> replacement)
-    : filter(f),
+    : filters(),
       matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
-      replacement(replacement) {}
+      replacement(replacement) {
+  if (f)
+    filters.push_back(f);
+}
 
 LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
     PatternRewriter &rewriter, Operation *op) const {
-  if (filter && failed(filter(op)))
+  if (llvm::any_of(filters,
+                   [&](const FilterFunction &f) { return failed(f(op)); }))
     return failure();
 
   auto attr = op->template getAttrOfType<StringAttr>(
       LinalgTransforms::kLinalgTransformMarker);
 
   if (!attr) {
-    // 1. Has no marker case and matchDisjunction is empty.
+    // 1. Has no filter case and matchDisjunction is empty.
     if (matchDisjunction.empty())
       return success();
 
-    // 2. Has no marker but was expecting a marker.
+    // 2. Has no filter but was expecting a filter.
     return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-      diag << " does not have any marker from list: ";
+      diag << " does not have any filter from list: ";
       interleaveComma(matchDisjunction, diag);
     });
   }
 
-  // 4. Match explicit marker.
-  for (auto marker : matchDisjunction)
-    if (attr.getValue() == marker)
+  // 4. Match explicit filter.
+  for (auto filter : matchDisjunction)
+    if (attr.getValue() == filter)
       return success();
 
   // 5. Fail to match.
   return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
-    diag << " does not have any marker from list: ";
+    diag << " does not have any filter from list: ";
     interleaveComma(matchDisjunction, diag);
   });
 }
@@ -229,14 +233,14 @@ static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
 /// Linalg base tiling pattern.
 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-    LinalgTransformationFilter marker, PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), marker(marker),
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context), filter(filter),
       options(options) {}
 
 mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
-    LinalgTilingOptions options, LinalgTransformationFilter marker,
+    LinalgTilingOptions options, LinalgTransformationFilter filter,
     PatternBenefit benefit)
-    : RewritePattern(benefit, MatchAnyOpTypeTag()), marker(marker),
+    : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter),
       options(options) {}
 
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
@@ -244,7 +248,7 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
-  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
   Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
@@ -260,10 +264,10 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
       return;
     // Return relevant information to derived pattern.
     result = *res;
-    // Replace marker on both tiledOp and tiledAndPaddedOp, if necessary.
-    marker.replaceLinalgTransformationFilter(rewriter, tiledOp);
+    // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
+    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
     if (tiledOp != res->op)
-      marker.replaceLinalgTransformationFilter(rewriter, res->op);
+      filter.replaceLinalgTransformationFilter(rewriter, res->op);
   });
 
   // Consider padding on the fly only if the op has tensor semantics.
@@ -300,11 +304,11 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
     StringRef opName, MLIRContext *context,
     const LinalgDependenceGraph &dependenceGraph,
     LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-    LinalgTransformationFilter marker, LinalgTransformationFilter fusedOpMarker,
+    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
     LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
     : RewritePattern(opName, {}, benefit, context),
       dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
-      fusionOptions(fusionOptions), marker(marker),
+      fusionOptions(fusionOptions), filter(filter),
       fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
 
 LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
@@ -312,7 +316,7 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
-  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
   DenseSet<Operation *> producers;
@@ -376,7 +380,7 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   }
   op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));
 
-  marker.replaceLinalgTransformationFilter(rewriter,
+  filter.replaceLinalgTransformationFilter(rewriter,
                                            tiledAndFusedOps->op.getOperation());
   for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
     fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
@@ -395,9 +399,9 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
 /// Linalg base interchange pattern.
 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
     StringRef opName, MLIRContext *context,
-    ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter marker,
+    ArrayRef<unsigned> interchangeVector, LinalgTransformationFilter filter,
     PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), marker(marker),
+    : RewritePattern(opName, {}, benefit, context), filter(filter),
       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 
 LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
@@ -405,7 +409,7 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
-  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
   if (failed(interchangeGenericLinalgOpPrecondition(op, interchangeVector)))
     return failure();
@@ -414,21 +418,21 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
   // should break the named op property.
   rewriter.updateRootInPlace(op, [&]() {
     interchange(linalgOp, interchangeVector);
-    // New marker if specified.
-    marker.replaceLinalgTransformationFilter(rewriter, op);
+    // New filter if specified.
+    filter.replaceLinalgTransformationFilter(rewriter, op);
   });
   return success();
 }
 
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
-    LinalgTransformationFilter marker, PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), marker(marker),
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context), filter(filter),
       options(options) {}
 
 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
-  if (failed(marker.checkAndNotify(rewriter, op)))
+  if (failed(filter.checkAndNotify(rewriter, op)))
     return failure();
   if (failed(promoteSubviewsPrecondition(op, options)))
     return failure();
@@ -444,21 +448,25 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
     return op->emitError("subview promotion failed");
   }
   rewriter.finalizeRootUpdate(op);
-  marker.replaceLinalgTransformationFilter(rewriter, op);
+  filter.replaceLinalgTransformationFilter(rewriter, op);
   return success();
 }
 
 mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
-    StringRef opName, MLIRContext *context, LinalgTransformationFilter marker,
+    LinalgTransformationFilter filter, PatternBenefit benefit)
+    : RewritePattern(benefit, MatchAnyOpTypeTag()), filter(filter) {}
+
+mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
+    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
     PatternBenefit benefit)
-    : RewritePattern(opName, {}, benefit, context), marker(marker) {}
+    : RewritePattern(opName, {}, benefit, context), filter(filter) {}
 
 LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
-  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
   if (failed(vectorizeLinalgOpPrecondition(op)))
     return failure();
index fb9d452..516606b 100644 (file)
@@ -369,36 +369,6 @@ static LogicalResult vectorizeAsLinalgGeneric(
   return success();
 }
 
-/// Detect whether the LinalgOp `op` is a contraction.
-/// A Linalg contraction is defined in general terms:
-///   1. Has 2 input and 1 output shapes.
-///   2. Has at least one reduction dimension.
-///   3. Has only projected permutation indexing maps.
-///   4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
-///   (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
-///   operations that may change the type (e.g. for mixed-precision).
-/// As a consequence, when vectorization of such an op occurs, the only special
-/// behavior is that the (unique) MulOpType is vectorized into a
-/// `vector.contract`. All other ops are handled in a generic fashion.
-/// In the future, we may wish to allow more input arguments and elementwise and
-/// constant operations that do not involve the reduction dimension(s).
-static LogicalResult isContraction(Operation *op) {
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: isContraction: "; op->dump());
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
-
-  auto mapRange = linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>();
-  return success(
-      linalgOp.getNumInputs() == 2 && linalgOp.getNumOutputs() == 1 &&
-      linalgOp.getNumReductionLoops() > 0 &&
-      llvm::all_of(mapRange,
-                   [](AffineMap m) { return m.isProjectedPermutation(); }) &&
-      // TODO: more fields than add/mul.
-      (isAddMul<AddFOp, MulFOp>(linalgOp->getRegion(0).front()) ||
-       isAddMul<AddIOp, MulIOp>(linalgOp->getRegion(0).front())));
-}
-
 /// Detect whether `r` has only ConstantOp, ElementwiseMappable and YieldOp.
 static bool hasOnlyScalarElementwiseOp(Region &r) {
   if (!llvm::hasSingleElement(r))
@@ -435,6 +405,40 @@ static bool isElementwise(Operation *op) {
   return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
+static void vectorizeContraction(OpBuilder &builder, LinalgOp linalgOp) {
+  assert(isaContractionOpInterface(linalgOp) &&
+         "expected vectorizeContraction preconditions to be met");
+  Location loc = linalgOp.getLoc();
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+                    << "Rewrite linalg op as vector.contract: ";
+             linalgOp.dump());
+  // Special function that describes how to vectorize the multiplication op in a
+  // linalg contraction.
+  CustomVectorizationHook vectorizeContraction =
+      [&](Operation *op,
+          const BlockAndValueMapping &bvm) -> VectorizationResult {
+    if (!isa<MulIOp, MulFOp>(op))
+      return VectorizationResult{VectorizationStatus::Failure, nullptr};
+    auto outShape = linalgOp.getOutputShapedType(0).getShape();
+    auto vType = outShape.empty()
+                     ? op->getResult(0).getType()
+                     : VectorType::get(outShape, op->getResult(0).getType());
+    auto zero =
+        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
+    Operation *contract = builder.create<vector::ContractionOp>(
+        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
+        linalgOp.indexing_maps(), linalgOp.iterator_types());
+    return VectorizationResult{VectorizationStatus::NewOp, contract};
+  };
+  auto status =
+      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
+  (void)status;
+  assert(succeeded(status) &&
+         "Unexpected vectorization failed despite preconditions");
+}
+
 LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   auto linalgOp = cast<linalg::LinalgOp>(op);
   // All types must be static shape to go to vector.
@@ -449,27 +453,25 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
     return success();
   if (isElementwise(op))
     return success();
-  return isContraction(op);
+  return success(isaContractionOpInterface(linalgOp));
 }
 
 void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   assert(succeeded(vectorizeLinalgOpPrecondition(op)));
 
-  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
-  (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
   // In the case of 0-D memrefs, return null and special case to scalar load or
   // store later.
   if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
     // Vectorize fill as a vector.broadcast.
-    LLVM_DEBUG(dbgs() << dbgPref
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Rewrite linalg.fill as vector.broadcast: " << *op);
     buildVectorWrite(builder, fillOp.value(), fillOp.output());
     return;
   }
   if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
     // Vectorize copy as a vector.transfer_read+vector.transfer_write.
-    LLVM_DEBUG(dbgs() << dbgPref
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Rewrite linalg.copy as vector.transfer_read + "
                          "vector.transfer_write: "
                       << *op);
@@ -478,48 +480,17 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
     return;
   }
 
-  auto linalgOp = cast<linalg::LinalgOp>(op);
-  Location loc = linalgOp.getLoc();
-
   if (isElementwise(op)) {
-    LLVM_DEBUG(dbgs() << dbgPref
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Rewrite linalg op as vector.transfer_read + " << *op);
-    auto status = vectorizeAsLinalgGeneric(builder, linalgOp);
+    auto status = vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
     (void)status;
     assert(succeeded(status) &&
            "Unexpected vectorization failed despite preconditions");
     return;
   }
 
-  assert(succeeded(isContraction(op)) && "Expected contraction");
-
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LLVM_DEBUG(dbgs() << dbgPref
-                    << "Rewrite linalg op as vector.contract: " << *op);
-  // Special function that describes how to vectorize the multiplication op in a
-  // linalg contraction.
-  CustomVectorizationHook vectorizeContraction =
-      [&](Operation *op,
-          const BlockAndValueMapping &bvm) -> VectorizationResult {
-    if (!isa<MulIOp, MulFOp>(op))
-      return VectorizationResult{VectorizationStatus::Failure, nullptr};
-    auto outShape = linalgOp.getOutputShapedType(0).getShape();
-    auto vType = outShape.empty()
-                     ? op->getResult(0).getType()
-                     : VectorType::get(outShape, op->getResult(0).getType());
-    auto zero =
-        builder.create<ConstantOp>(loc, vType, builder.getZeroAttr(vType));
-    Operation *contract = builder.create<vector::ContractionOp>(
-        loc, bvm.lookup(op->getOperand(0)), bvm.lookup(op->getOperand(1)), zero,
-        linalgOp.indexing_maps(), linalgOp.iterator_types());
-    return VectorizationResult{VectorizationStatus::NewOp, contract};
-  };
-  auto status =
-      vectorizeAsLinalgGeneric(builder, linalgOp, {vectorizeContraction});
-  (void)status;
-  assert(succeeded(status) &&
-         "Unexpected vectorization failed despite preconditions");
+  vectorizeContraction(builder, cast<LinalgOp>(op));
 }
 
 //----------------------------------------------------------------------------//
@@ -671,13 +642,11 @@ void mlir::linalg::populateConvVectorizationPatterns(
 /// different block.
 static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
                                     ValueRange values) {
-  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
-  (void)dbgPref;
   if (firstOp->getBlock() != secondOp->getBlock() ||
       !firstOp->isBeforeInBlock(secondOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << dbgPref << "interleavedUses precondition failed, firstOp: "
-               << *firstOp << ", second op: " << *secondOp);
+    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
+                            << "interleavedUses precondition failed, firstOp: "
+                            << *firstOp << ", second op: " << *secondOp);
     return true;
   }
   for (auto v : values) {
@@ -690,7 +659,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
         continue;
       LLVM_DEBUG(llvm::dbgs()
-                 << dbgPref << " found interleaved op " << *owner
+                 << "\n[" DEBUG_TYPE "]: "
+                 << " found interleaved op " << *owner
                  << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
       return true;
     }
@@ -722,16 +692,15 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
       !viewOrAlloc.getDefiningOp<AllocOp>())
     return failure();
 
-  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
-  (void)dbgPref;
-  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);
+  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: " << viewOrAlloc);
 
   // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
   SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
   if (!subViewOp)
     return failure();
   Value subView = subViewOp.getResult();
-  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);
+  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
+                          << "with subView " << subView);
 
   // Find the copy into `subView` without interleaved uses.
   CopyOp copyOp;
@@ -739,7 +708,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
       if (newCopyOp.getOutputBuffer(0) != subView)
         continue;
-      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
+      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
+                              << "copy candidate " << *newCopyOp);
       if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
         continue;
       copyOp = newCopyOp;
@@ -748,7 +718,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   }
   if (!copyOp)
     return failure();
-  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);
+  LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
+                          << "with copy " << *copyOp);
 
   // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
   FillOp maybeFillOp;
@@ -756,7 +727,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
       if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
         continue;
-      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
+      LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
+                              << "fill candidate " << *newFillOp);
       if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
         continue;
       maybeFillOp = newFillOp;
@@ -767,7 +739,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
   if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
     return failure();
   if (maybeFillOp)
-    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);
+    LLVM_DEBUG(llvm::dbgs() << "\n[" DEBUG_TYPE "]: "
+                            << "with maybeFillOp " << *maybeFillOp);
 
   // `in` is the subview that linalg.copy reads. Replace it.
   Value in = copyOp.getInput(0);
index d125e06..566128c 100644 (file)
@@ -132,7 +132,7 @@ void TestLinalgCodegenStrategy::runStrategy<LinalgOp>(
           LinalgPromotionOptions()
               .setAlignment(16)
               .setUseFullTileBuffersByDefault(registerPromoteFullTile))
-      .vectorizeIf<LinalgOp>(vectorize, anchorOpName)
+      .vectorizeIf(vectorize, anchorOpName)
       .setVectorTransformsOptions(
           vector::VectorTransformsOptions()
               .setVectorTransformsOptions(vectorContractLowering)
index 27ca994..3a6ac1d 100644 (file)
@@ -181,12 +181,9 @@ static void applyPatterns(FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   // Linalg to vector contraction patterns.
   //===--------------------------------------------------------------------===//
-  patterns.insert<LinalgVectorizationPattern<MatmulOp>,
-                  LinalgVectorizationPattern<FillOp>,
-                  LinalgVectorizationPattern<CopyOp>,
-                  LinalgVectorizationPattern<GenericOp>>(
-      ctx, LinalgVectorizationOptions(),
-      LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx)));
+  patterns.insert<LinalgVectorizationPattern>(
+      LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
+          .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());
 
   //===--------------------------------------------------------------------===//
   // Linalg generic permutation patterns.
@@ -251,13 +248,12 @@ static void fillL1TilingAndMatmulToVectorPatterns(
           LinalgTransformationFilter(Identifier::get("L1", ctx),
                                      Identifier::get("VEC", ctx))));
 
-  patternsVector.emplace_back(
-      std::make_unique<LinalgVectorizationPattern<MatmulOp>>(
-          ctx, LinalgVectorizationOptions(),
-          LinalgTransformationFilter(Identifier::get("VEC", ctx))));
-  patternsVector.back()
-      .insert<LinalgVectorizationPattern<FillOp>,
-              LinalgVectorizationPattern<CopyOp>>(ctx);
+  patternsVector.emplace_back(std::make_unique<LinalgVectorizationPattern>(
+      MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
+      LinalgTransformationFilter(Identifier::get("VEC", ctx))));
+  patternsVector.back().insert<LinalgVectorizationPattern>(
+      LinalgTransformationFilter().addFilter(
+          [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
 }
 
 //===----------------------------------------------------------------------===//
@@ -493,15 +489,9 @@ static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
 
 static void applyLinalgToVectorPatterns(FuncOp funcOp) {
   OwningRewritePatternList patterns;
-  // TODO: remove all this in favor of a single LinalgOp.
-  patterns.insert<
-      LinalgVectorizationPattern<BatchMatmulOp>,
-      LinalgVectorizationPattern<MatmulOp>,
-      LinalgVectorizationPattern<MatmulI8I8I32Op>,
-      LinalgVectorizationPattern<MatvecOp>,
-      LinalgVectorizationPattern<VecmatOp>, LinalgVectorizationPattern<DotOp>,
-      LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>,
-      LinalgVectorizationPattern<GenericOp>>(funcOp.getContext());
+  patterns.insert<LinalgVectorizationPattern>(
+      LinalgTransformationFilter()
+          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
   applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
index 9bf7630..3ca9909 100644 (file)
@@ -106,6 +106,7 @@ public:
     kw_def,
     FIRST_KEYWORD = kw_def,
     kw_ods_def,
+    kw_implements_interface,
     kw_attr_def,
     kw_floordiv,
     kw_ceildiv,
@@ -319,14 +320,16 @@ Token Lexer::lexIdentifier(const char *tokStart) {
 
   // Check to see if this identifier is a keyword.
   StringRef str(tokStart, curPtr - tokStart);
-  Token::Kind kind = StringSwitch<Token::Kind>(str)
-                         .Case("attr", Token::Kind::kw_attr_def)
-                         .Case("def", Token::Kind::kw_def)
-                         .Case("ods_def", Token::Kind::kw_ods_def)
-                         .Case("floordiv", Token::Kind::kw_floordiv)
-                         .Case("ceildiv", Token::Kind::kw_ceildiv)
-                         .Case("mod", Token::Kind::kw_mod)
-                         .Default(Token::Kind::id);
+  Token::Kind kind =
+      StringSwitch<Token::Kind>(str)
+          .Case("attr", Token::Kind::kw_attr_def)
+          .Case("def", Token::Kind::kw_def)
+          .Case("ods_def", Token::Kind::kw_ods_def)
+          .Case("implements_interface", Token::Kind::kw_implements_interface)
+          .Case("floordiv", Token::Kind::kw_floordiv)
+          .Case("ceildiv", Token::Kind::kw_ceildiv)
+          .Case("mod", Token::Kind::kw_mod)
+          .Default(Token::Kind::id);
 
   return Token(kind, str);
 }
@@ -1111,7 +1114,8 @@ public:
 
   /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
   void printODS(llvm::raw_ostream &os, StringRef cppOpName,
-                StringRef linalgOpName, ComprehensionParsingState &state);
+                StringRef linalgOpName, ArrayRef<StringRef> interfaces,
+                ComprehensionParsingState &state);
 
   /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
   void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
@@ -1635,29 +1639,54 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
 ///     (tc-attr-def)?
 ///     `{` comprehension-list `}`
 ///
-///   ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
+///   implements-interface ::=
+///     `implements_interface` `<` bare-id (`,` bare-id)* `>` `:` tc-def
+///
+///   ods-def ::= `ods_def` `<` bare-id `>`
+///               (implements-interface)? `:`
+///               tc-def
 ///
 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
 /// contain only expressions involving symbols and constants), but can
 /// otherwise contain arbitrary affine expressions.
 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
-  // Parse def header (including C++ op name)
+  // Parse ods-def header (including C++ op name)
   if (failed(parser.parseToken(Token::Kind::kw_ods_def,
                                "expected 'ods_def' to define a TC ODS")) ||
       failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
     return failure();
   StringRef cppOpName = parser.curToken.getSpelling();
   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
-
   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
-      failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
-      failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
+      failed(parser.parseToken(Token::Kind::gt, "expected '>'")))
     return failure();
 
+  // Parse optional implements-interface header (including C++ op names)
+  SmallVector<StringRef> interfaces;
+  bool implementsInterface = succeeded(
+      parser.parseOptionalToken(Token::Kind::kw_implements_interface));
+  if (implementsInterface) {
+    auto parseInterfaceString = [&]() -> LogicalResult {
+      StringRef interfaceName = parser.curToken.getSpelling();
+      if (failed(parser.parseToken(Token::Kind::id, "expected id")))
+        return failure();
+      interfaces.push_back(interfaceName);
+      return success();
+    };
+    if (failed(parser.parseToken(Token::Kind::lt, "expected '<'")) ||
+        failed(parser.parseCommaSeparatedListUntil(
+            Token::Kind::gt, parseInterfaceString, /*allowEmptyList=*/false)))
+      return failure();
+  }
+
+  // Parse column.
+  if (failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
+    return failure();
+
+  // Parse TC op name.
   if (failed(parser.parseToken(Token::Kind::kw_def,
                                "expected 'def' to define a TC")))
     return failure();
-
   StringRef tcName = parser.curToken.getSpelling();
   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
 
@@ -1734,7 +1763,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
   }
   if (genODSDecl) {
     auto &state = perComprehensionStates.back();
-    printODS(os, cppOpName, tcName, state);
+    printODS(os, cppOpName, tcName, interfaces, state);
     os << "\n";
   }
   if (genODSImpl) {
@@ -1758,7 +1787,7 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
 
 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
-                        StringRef linalgOpName,
+                        StringRef linalgOpName, ArrayRef<StringRef> interfaces,
                         ComprehensionParsingState &state) {
   SmallVector<std::string, 4> attributes;
   for (const auto &attr : registeredAttrs) {
@@ -1802,11 +1831,12 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
     AttrSizedOperandSegments,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-    SingleBlockImplicitTerminator<"YieldOp">]> {
-      {2}
+    SingleBlockImplicitTerminator<"YieldOp">
+    /*extraInterfaces=*/{2}]> {
+      {3}
       let arguments = (ins
         Variadic<AnyShaped>:$inputs,
-        Variadic<AnyShaped>:$outputs{3}
+        Variadic<AnyShaped>:$outputs{4}
       );
       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
       let regions = (region AnyRegion:$region);
@@ -1856,7 +1886,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_state.addTypes(resultTensorTypes);
           (void)$_state.addRegion();
         }]>
-        {5}
+        {6}
       ];
       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
       let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
@@ -1873,13 +1903,22 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         }
 
         // Generic methods.
-        static unsigned getNumRegionArgs() {{ return {4}; }
+        static unsigned getNumRegionArgs() {{ return {5}; }
         std::string getLibraryCallName() {{
           return generateLibraryCallName(getOperation());
         }
       }];
   })FMT";
 
+  // Generate the list of extra implemented interfaces.
+  std::string interfaceNameList;
+  if (!interfaces.empty()) {
+    llvm::raw_string_ostream ss(interfaceNameList);
+    ss << ", "; // Leading comma to concat to existing list of interfaces.
+    llvm::interleaveComma(interfaces, ss);
+    ss.flush();
+  }
+
   // Generate documentation.
   std::string doc;
   if (!docString.empty()) {
@@ -1934,8 +1973,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
   }
 
   // Finally put everything together.
-  os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList,
-                      state.orderedTensorArgs.size(), attrBuilder);
+  os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
+                      attrList, state.orderedTensorArgs.size(), attrBuilder);
 }
 
 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2085,8 +2124,8 @@ void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
     // Note that we use `0` as the result affine map's number of symbols. All
     // symbols representing attribute usages should be folded away. But there
     // may exist additional symbols for tensor dimension upper bounds. Linalg
-    // does not handle such cases right now. This needs to be fixed once we need
-    // that.
+    // does not handle such cases right now. This needs to be fixed once we
+    // need that.
     const char *replaceFmt =
         "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
     mapsStringStream << llvm::formatv(replaceFmt, tensorUse.index(),