Add support for a Linalg base op class
authorNicolas Vasilache <ntv@google.com>
Wed, 15 May 2019 02:37:48 +0000 (19:37 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:42:21 +0000 (13:42 -0700)
    This CL uses a pattern proposed by aminim@ to add a base Linalg op that further dispatches to the proper op implementation.
    This CL adds a LinalgOp which implements isclassof for only a subset of the linalg ops: the ops that behave like a library call for the purpose of transformations like tiling.
    This uses a static dispatch mechanism based on the LinalgLibraryOps.td ops declarations to avoid switch or visitor patterns. This may later be replaced by Tablegen'd dispatch when it is available.

    As a consequence, the list of library like operations in Linalg may now grow without having to modify any of the dispatch or transformation support.

    More details in the concept-based dispatch, as explained by aminim@
    ```
    This is inspired by Sean Parent's: https://sean-parent.stlab.cc/papers-and-presentations/#value-semantics-and-concept-based-polymorphism

    A key difference is that the set of classes erased is statically known, which avoids to use dynamic memory allocation.
    We use a zero-sized templated class to emit the virtual table and generate a singleton object for each instantiation of this class. We pay the cost of initialization once on construction (find which class to dispatch to) and then a virtual dispatch on every call.
    ```

--

PiperOrigin-RevId: 248258921

mlir/include/mlir/Linalg/IR/LinalgBase.td [new file with mode: 0644]
mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td [new file with mode: 0644]
mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/include/mlir/Linalg/IR/LinalgOps.td
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/IR/LinalgTypes.cpp
mlir/lib/Linalg/Transforms/Tiling.cpp

diff --git a/mlir/include/mlir/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Linalg/IR/LinalgBase.td
new file mode 100644 (file)
index 0000000..42e5bcd
--- /dev/null
@@ -0,0 +1,42 @@
+//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the definition file for base linear algebra support.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LINALG_OPS
+#else
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Linalg_Dialect : Dialect {
+  let name = "linalg";
+}
+
+// Whether a type is a BufferType.
+def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
+def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
+
+// Whether a type is a ViewType.
+def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
+def View : Type<LinalgIsViewTypePred, "view">;
+
+#endif // LINALG_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
new file mode 100644 (file)
index 0000000..15b9fab
--- /dev/null
@@ -0,0 +1,91 @@
+//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for linear algebra operations that 
+// correspond to underlying library calls (e.g. BLAS).
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LINALG_OPS
+#else
+
+#ifdef LINALG_BASE
+#else
+include "mlir/Linalg/IR/LinalgBase.td"
+#endif // LINALG_BASE
+
+class LinalgParametricNativeOpTrait<string prop, string parameters> :
+  NativeOpTrait<"linalg::" # prop # parameters>
+{}
+
+class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
+  LinalgParametricNativeOpTrait<
+    prop,
+    !strconcat("<",
+               !cast<string>(!head(parameters)),
+               !foldl("",
+                      !tail(parameters),
+                      sum,
+                      param,
+                      sum # "," # !cast<string>(param)),
+               ">::Impl")>
+{}
+
+// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
+// to have a specified number of inputs and outputs, all passed as operands.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class NInputsAndOutputs<int n_ins, int n_outs> :
+  LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
+{}
+
+// The linalg `NLoopTypes` trait provides the API for ops that are known to have
+// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
+// loops.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class NLoopTypes<int n_par, int n_red, int n_win> :
+LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
+{}
+
+// The linalg `ViewRanks` trait the API for ops that are known to have a
+// specified list of view ranks.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class ViewRanks<list<int> ranks> :
+LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
+{}
+
+// Base Tablegen class for Linalg ops.
+class LinalgOp<string mnemonic, list<OpTrait> props> :
+Op<Linalg_Dialect, mnemonic, props> {
+  let arguments = (ins Variadic<View>); // default variadic builder
+  let parser = [{ return parseLinalgLibraryOp(parser, result); }];
+  let printer = [{ printLinalgLibraryOp(p, *this); }];
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Concrete Linalg ops.
+////////////////////////////////////////////////////////////////////////////////
+def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
+                             NLoopTypes<0, 1, 0>,
+                             ViewRanks<[1, 1, 0]>]> {}
+def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
+                                   NLoopTypes<1, 1, 0>,
+                                   ViewRanks<[2, 1, 1]>]> {}
+def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
+                                   NLoopTypes<2, 1, 0>,
+                                   ViewRanks<[2, 2, 2]>]> {}
+
+#endif // LINALG_OPS
\ No newline at end of file
index 6a6c953..0c48cb0 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef MLIR_LINALG_LINALGOPS_H_
 #define MLIR_LINALG_LINALGOPS_H_
 
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Linalg/IR/LinalgTraits.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
@@ -274,6 +275,9 @@ public:
 #define GET_OP_CLASSES
 #include "mlir/Linalg/IR/LinalgOps.h.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
+
 /// Returns the list of maps that map loops to operands of a Linalg op.
 /// The i-th affine map identifies loop indices to subscripts that are used when
 /// accessing the i-th operand.
@@ -292,6 +296,116 @@ public:
 /// Only permutation maps are currently supported.
 SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
 
+/// A LinalgOp behaves like a base class for the Linalg operations that are
+/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
+/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
+/// method and dispatches to the appropriate LinalgLibraryOp.
+/// This allows writing generic passes, like tiling, for all current and future
+/// LinalgOps without requiring templating and dispatch in multiple places.
+class LinalgOp : public Op<LinalgOp> {
+public:
+  using Op::Op;
+
+  LinalgOp(Operation *op) : Op<LinalgOp>(op) {
+    impl = ModelDispatch<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+        >::dispatch(op);
+  }
+
+  static bool classof(Operation *op) {
+    return ModelDispatch<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+        >::classof(op);
+  }
+
+  unsigned getNumParallelLoops() {
+    return impl->getNumParallelLoops(getOperation());
+  }
+  unsigned getNumReductionLoops() {
+    return impl->getNumReductionLoops(getOperation());
+  }
+  unsigned getNumWindowLoops() {
+    return impl->getNumWindowLoops(getOperation());
+  }
+  unsigned getNumInputsAndOutputs() {
+    return impl->getNumInputsAndOutputs(getOperation());
+  }
+  Operation *create(FuncBuilder &builder, Location loc,
+                    ArrayRef<Value *> operands) {
+    return impl->create(builder, loc, operands);
+  }
+
+private:
+  struct Concept {
+    virtual ~Concept() = default;
+    virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
+    virtual unsigned getNumParallelLoops(Operation *op) = 0;
+    virtual unsigned getNumReductionLoops(Operation *op) = 0;
+    virtual unsigned getNumWindowLoops(Operation *op) = 0;
+    virtual unsigned getNumLoops(Operation *op) = 0;
+    virtual Operation *create(FuncBuilder &builder, Location loc,
+                              ArrayRef<Value *> operands) = 0;
+  };
+
+  /// The implementation is inspired from Sean Parent's concept-based
+  /// polymorphism. A key difference is that the set of classes erased is
+  /// statically known, which alleviates the need for using dynamic memory
+  /// allocation.
+  /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
+  /// virtual table and generate a singleton object for each instantiation of
+  /// this class.
+  /// We pay the cost of initialization once on construction (find which class
+  /// to dispatch to) and then a virtual dispatch on every call.
+  template <typename ConcreteOp> struct Model : public Concept {
+    static Model<ConcreteOp> &instance() {
+      static Model<ConcreteOp> singleton;
+      return singleton;
+    }
+    unsigned getNumInputsAndOutputs(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumInputsAndOutputs();
+    }
+    unsigned getNumParallelLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumParallelLoops();
+    }
+    unsigned getNumReductionLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumReductionLoops();
+    }
+    unsigned getNumWindowLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumWindowLoops();
+    }
+    unsigned getNumLoops(Operation *op) override {
+      return cast<ConcreteOp>(op).getNumLoops();
+    }
+    Operation *create(FuncBuilder &builder, Location loc,
+                      ArrayRef<Value *> operands) override {
+      return builder.create<ConcreteOp>(loc, operands);
+    }
+  };
+  Concept *impl;
+
+  template <typename... Types> struct ModelDispatch;
+
+  template <typename First, typename... Rest>
+  struct ModelDispatch<First, Rest...> {
+    static bool classof(Operation *op) {
+      return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
+    }
+    static Concept *dispatch(Operation *op) {
+      return isa<First>(op) ? &Model<First>::instance()
+                            : ModelDispatch<Rest...>::dispatch(op);
+    }
+  };
+
+  template <typename...> struct ModelDispatch {
+    static bool classof(Operation *op) { return false; }
+    static Concept *dispatch(Operation *op) {
+      llvm_unreachable("Invalid LinalgOp");
+    }
+  };
+};
+
 } // namespace linalg
 } // namespace mlir
 
index 58eb3f0..ecdb111 100644 (file)
@@ -1,4 +1,4 @@
-//===- LinalgOps.td - Linear algebra dialect ops -----------*- tablegen -*-===//
+//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 #ifdef LINALG_OPS
 #else
 
-#ifdef OP_BASE
+#ifdef LINALG_BASE
 #else
-include "mlir/IR/OpBase.td"
-#endif // OP_BASE
-
-def Linalg_Dialect : Dialect {
-  let name = "linalg";
-}
-
-// Whether a type is a BufferType.
-def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
-def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
-
-// Whether a type is a ViewType.
-def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
-def View : Type<LinalgIsViewTypePred, "view">;
-
-class LinalgParametricNativeOpTrait<string prop, string parameters> :
-  NativeOpTrait<"linalg::" # prop # parameters>
-{}
-
-class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
-  LinalgParametricNativeOpTrait<
-    prop,
-    !strconcat("<",
-               !cast<string>(!head(parameters)),
-               !foldl("",
-                      !tail(parameters),
-                      sum,
-                      param,
-                      sum # "," # !cast<string>(param)),
-               ">::Impl")>
-{}
-
-// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
-// to have a specified number of inputs and outputs, all passed as operands.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class NInputsAndOutputs<int n_ins, int n_outs> :
-  LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
-{}
-
-// The linalg `NLoopTypes` trait provides the API for ops that are known to have
-// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
-// loops.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class NLoopTypes<int n_par, int n_red, int n_win> :
-LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
-{}
-
-// The linalg `ViewRanks` trait the API for ops that are known to have a
-// specified list of view ranks.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class ViewRanks<list<int> ranks> :
-LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
-{}
-
-// Base Tablegen class for Linalg ops.
-class LinalgOp<string mnemonic, list<OpTrait> props> :
-Op<Linalg_Dialect, mnemonic, props> {
-  let arguments = (ins Variadic<View>); // default variadic builder
-  let parser = [{ return parseLinalgLibraryOp(parser, result); }];
-  let printer = [{ printLinalgLibraryOp(p, *this); }];
-}
+include "mlir/Linalg/IR/LinalgBase.td"
+#endif // LINALG_BASE
 
 def BufferSizeOp :
     Op<Linalg_Dialect, "buffer_size", [NoSideEffect]>,
@@ -127,17 +68,4 @@ def DimOp : Op<Linalg_Dialect, "dim", [NoSideEffect]>,
   }];
 }
 
-////////////////////////////////////////////////////////////////////////////////
-// Concrete Linalg ops.
-////////////////////////////////////////////////////////////////////////////////
-def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
-                             NLoopTypes<0, 1, 0>,
-                             ViewRanks<[1, 1, 0]>]> {}
-def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
-                                   NLoopTypes<1, 1, 0>,
-                                   ViewRanks<[2, 1, 1]>]> {}
-def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
-                                   NLoopTypes<2, 1, 0>,
-                                   ViewRanks<[2, 2, 2]>]> {}
-
 #endif // LINALG_OPS
\ No newline at end of file
index e6e18bb..d077927 100644 (file)
@@ -593,6 +593,9 @@ namespace mlir {
 #define GET_OP_CLASSES
 #include "mlir/Linalg/IR/LinalgOps.cpp.inc"
 
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+
 } // namespace mlir
 
 // Ideally this should all be Tablegen'd but there is no good story for
index 19105e8..0e20eb8 100644 (file)
@@ -37,6 +37,10 @@ mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
 #define GET_OP_LIST
 #include "mlir/Linalg/IR/LinalgOps.cpp.inc"
       >();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+      >();
 }
 
 struct mlir::linalg::BufferTypeStorage : public TypeStorage {
index ba1fdbe..ff6f02f 100644 (file)
@@ -248,7 +248,6 @@ static SmallVector<Value *, 4> makeTiledViews(FuncBuilder *b, Location loc,
   return res;
 }
 
-template <class LinalgOp>
 static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
                                   PerFunctionState &state) {
   // Enforce the convention that "tiling by zero" skips tiling a particular
@@ -278,7 +277,7 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
     assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
     auto views =
         makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state);
-    b->create<LinalgOp>(loc, views);
+    op.create(*b, loc, views);
     /// NestedBuilders expect handles, we thus return an IndexHandle.
     return IndexHandle();
   }()});
@@ -286,7 +285,6 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
   return success();
 }
 
-template <class LinalgOp>
 static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
                                   PerFunctionState &state) {
   if (tileSizes.empty())
@@ -319,13 +317,8 @@ static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
 // TODO(ntv) expose as a primitive for other passes.
 static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes,
                                   PerFunctionState &state) {
-  if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
-    return tileLinalgOp(matmulOp, tileSizes, state);
-  } else if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
-    return tileLinalgOp(matvecOp, tileSizes, state);
-  } else if (auto dotOp = dyn_cast<DotOp>(op)) {
-    return tileLinalgOp(dotOp, tileSizes, state);
-  }
+  if (auto linalgOp = dyn_cast<LinalgOp>(op))
+    return tileLinalgOp(linalgOp, tileSizes, state);
   return failure();
 }