--- /dev/null
+//===- LinalgBase.td - Linalg dialect base support ---------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the definition file for base linear algebra support.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LINALG_OPS
+#else
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+def Linalg_Dialect : Dialect {
+ let name = "linalg";
+}
+
+// Whether a type is a BufferType.
+def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
+def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
+
+// Whether a type is a ViewType.
+def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
+def View : Type<LinalgIsViewTypePred, "view">;
+
+#endif // LINALG_OPS
\ No newline at end of file
--- /dev/null
+//===- LinalgLibraryOps.td - Linalg dialect library ops -*- tablegen ----*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for linear algebra operations that
+// correspond to underlying library calls (e.g. BLAS).
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef LINALG_OPS
+#else
+
+#ifdef LINALG_BASE
+#else
+include "mlir/Linalg/IR/LinalgBase.td"
+#endif // LINALG_BASE
+
+class LinalgParametricNativeOpTrait<string prop, string parameters> :
+ NativeOpTrait<"linalg::" # prop # parameters>
+{}
+
+class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
+ LinalgParametricNativeOpTrait<
+ prop,
+ !strconcat("<",
+ !cast<string>(!head(parameters)),
+ !foldl("",
+ !tail(parameters),
+ sum,
+ param,
+ sum # "," # !cast<string>(param)),
+ ">::Impl")>
+{}
+
+// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
+// to have a specified number of inputs and outputs, all passed as operands.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class NInputsAndOutputs<int n_ins, int n_outs> :
+ LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
+{}
+
+// The linalg `NLoopTypes` trait provides the API for ops that are known to have
+// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
+// loops.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class NLoopTypes<int n_par, int n_red, int n_win> :
+LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
+{}
+
+// The linalg `ViewRanks` trait the API for ops that are known to have a
+// specified list of view ranks.
+// See Linalg/LinalgTraits.h for implementation details an usage.
+class ViewRanks<list<int> ranks> :
+LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
+{}
+
+// Base Tablegen class for Linalg ops.
+class LinalgOp<string mnemonic, list<OpTrait> props> :
+Op<Linalg_Dialect, mnemonic, props> {
+ let arguments = (ins Variadic<View>); // default variadic builder
+ let parser = [{ return parseLinalgLibraryOp(parser, result); }];
+ let printer = [{ printLinalgLibraryOp(p, *this); }];
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Concrete Linalg ops.
+////////////////////////////////////////////////////////////////////////////////
+def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
+ NLoopTypes<0, 1, 0>,
+ ViewRanks<[1, 1, 0]>]> {}
+def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
+ NLoopTypes<1, 1, 0>,
+ ViewRanks<[2, 1, 1]>]> {}
+def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
+ NLoopTypes<2, 1, 0>,
+ ViewRanks<[2, 2, 2]>]> {}
+
+#endif // LINALG_OPS
\ No newline at end of file
#ifndef MLIR_LINALG_LINALGOPS_H_
#define MLIR_LINALG_LINALGOPS_H_
+#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Linalg/IR/LinalgTraits.h"
#include "mlir/Linalg/IR/LinalgTypes.h"
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.h.inc"
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.h.inc"
+
/// Returns the list of maps that map loops to operands of a Linalg op.
/// The i-th affine map identifies loop indices to subscripts that are used when
/// accessing the i-th operand.
/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
+/// A LinalgOp behaves like a base class for the Linalg operations that are
+/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
+/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
+/// method and dispatches to the appropriate LinalgLibraryOp.
+/// This allows writing generic passes, like tiling, for all current and future
+/// LinalgOps without requiring templating and dispatch in multiple places.
+class LinalgOp : public Op<LinalgOp> {
+public:
+ using Op::Op;
+
+ LinalgOp(Operation *op) : Op<LinalgOp>(op) {
+ impl = ModelDispatch<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+ >::dispatch(op);
+ }
+
+ static bool classof(Operation *op) {
+ return ModelDispatch<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+ >::classof(op);
+ }
+
+ unsigned getNumParallelLoops() {
+ return impl->getNumParallelLoops(getOperation());
+ }
+ unsigned getNumReductionLoops() {
+ return impl->getNumReductionLoops(getOperation());
+ }
+ unsigned getNumWindowLoops() {
+ return impl->getNumWindowLoops(getOperation());
+ }
+ unsigned getNumInputsAndOutputs() {
+ return impl->getNumInputsAndOutputs(getOperation());
+ }
+ Operation *create(FuncBuilder &builder, Location loc,
+ ArrayRef<Value *> operands) {
+ return impl->create(builder, loc, operands);
+ }
+
+private:
+ struct Concept {
+ virtual ~Concept() = default;
+ virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
+ virtual unsigned getNumParallelLoops(Operation *op) = 0;
+ virtual unsigned getNumReductionLoops(Operation *op) = 0;
+ virtual unsigned getNumWindowLoops(Operation *op) = 0;
+ virtual unsigned getNumLoops(Operation *op) = 0;
+ virtual Operation *create(FuncBuilder &builder, Location loc,
+ ArrayRef<Value *> operands) = 0;
+ };
+
+ /// The implementation is inspired from Sean Parent's concept-based
+ /// polymorphism. A key difference is that the set of classes erased is
+ /// statically known, which alleviates the need for using dynamic memory
+ /// allocation.
+ /// We use a zero-sized templated class `Model<ConcreteOp>` to emit the
+ /// virtual table and generate a singleton object for each instantiation of
+ /// this class.
+ /// We pay the cost of initialization once on construction (find which class
+ /// to dispatch to) and then a virtual dispatch on every call.
+ template <typename ConcreteOp> struct Model : public Concept {
+ static Model<ConcreteOp> &instance() {
+ static Model<ConcreteOp> singleton;
+ return singleton;
+ }
+ unsigned getNumInputsAndOutputs(Operation *op) override {
+ return cast<ConcreteOp>(op).getNumInputsAndOutputs();
+ }
+ unsigned getNumParallelLoops(Operation *op) override {
+ return cast<ConcreteOp>(op).getNumParallelLoops();
+ }
+ unsigned getNumReductionLoops(Operation *op) override {
+ return cast<ConcreteOp>(op).getNumReductionLoops();
+ }
+ unsigned getNumWindowLoops(Operation *op) override {
+ return cast<ConcreteOp>(op).getNumWindowLoops();
+ }
+ unsigned getNumLoops(Operation *op) override {
+ return cast<ConcreteOp>(op).getNumLoops();
+ }
+ Operation *create(FuncBuilder &builder, Location loc,
+ ArrayRef<Value *> operands) override {
+ return builder.create<ConcreteOp>(loc, operands);
+ }
+ };
+ Concept *impl;
+
+ template <typename... Types> struct ModelDispatch;
+
+ template <typename First, typename... Rest>
+ struct ModelDispatch<First, Rest...> {
+ static bool classof(Operation *op) {
+ return isa<First>(op) || ModelDispatch<Rest...>::classof(op);
+ }
+ static Concept *dispatch(Operation *op) {
+ return isa<First>(op) ? &Model<First>::instance()
+ : ModelDispatch<Rest...>::dispatch(op);
+ }
+ };
+
+ template <typename...> struct ModelDispatch {
+ static bool classof(Operation *op) { return false; }
+ static Concept *dispatch(Operation *op) {
+ llvm_unreachable("Invalid LinalgOp");
+ }
+ };
+};
+
} // namespace linalg
} // namespace mlir
-//===- LinalgOps.td - Linear algebra dialect ops -----------*- tablegen -*-===//
+//===- LinalgOps.td - Linalg dialect ops -------------------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
#ifdef LINALG_OPS
#else
-#ifdef OP_BASE
+#ifdef LINALG_BASE
#else
-include "mlir/IR/OpBase.td"
-#endif // OP_BASE
-
-def Linalg_Dialect : Dialect {
- let name = "linalg";
-}
-
-// Whether a type is a BufferType.
-def LinalgIsBufferTypePred : CPred<"$_self.isa<BufferType>()">;
-def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
-
-// Whether a type is a ViewType.
-def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
-def View : Type<LinalgIsViewTypePred, "view">;
-
-class LinalgParametricNativeOpTrait<string prop, string parameters> :
- NativeOpTrait<"linalg::" # prop # parameters>
-{}
-
-class LinalgParametricIntNativeOpTrait<string prop, list<int> parameters> :
- LinalgParametricNativeOpTrait<
- prop,
- !strconcat("<",
- !cast<string>(!head(parameters)),
- !foldl("",
- !tail(parameters),
- sum,
- param,
- sum # "," # !cast<string>(param)),
- ">::Impl")>
-{}
-
-// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known
-// to have a specified number of inputs and outputs, all passed as operands.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class NInputsAndOutputs<int n_ins, int n_outs> :
- LinalgParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]>
-{}
-
-// The linalg `NLoopTypes` trait provides the API for ops that are known to have
-// a specified number of parallel (n_par), reduction (n_red) and window (n_win)
-// loops.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class NLoopTypes<int n_par, int n_red, int n_win> :
-LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
-{}
-
-// The linalg `ViewRanks` trait the API for ops that are known to have a
-// specified list of view ranks.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class ViewRanks<list<int> ranks> :
-LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
-{}
-
-// Base Tablegen class for Linalg ops.
-class LinalgOp<string mnemonic, list<OpTrait> props> :
-Op<Linalg_Dialect, mnemonic, props> {
- let arguments = (ins Variadic<View>); // default variadic builder
- let parser = [{ return parseLinalgLibraryOp(parser, result); }];
- let printer = [{ printLinalgLibraryOp(p, *this); }];
-}
+include "mlir/Linalg/IR/LinalgBase.td"
+#endif // LINALG_BASE
def BufferSizeOp :
Op<Linalg_Dialect, "buffer_size", [NoSideEffect]>,
}];
}
-////////////////////////////////////////////////////////////////////////////////
-// Concrete Linalg ops.
-////////////////////////////////////////////////////////////////////////////////
-def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>,
- NLoopTypes<0, 1, 0>,
- ViewRanks<[1, 1, 0]>]> {}
-def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>,
- NLoopTypes<1, 1, 0>,
- ViewRanks<[2, 1, 1]>]> {}
-def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>,
- NLoopTypes<2, 1, 0>,
- ViewRanks<[2, 2, 2]>]> {}
-
#endif // LINALG_OPS
\ No newline at end of file
#define GET_OP_CLASSES
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
+#define GET_OP_CLASSES
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+
} // namespace mlir
// Ideally this should all be Tablegen'd but there is no good story for
#define GET_OP_LIST
#include "mlir/Linalg/IR/LinalgOps.cpp.inc"
>();
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Linalg/IR/LinalgLibraryOps.cpp.inc"
+ >();
}
struct mlir::linalg::BufferTypeStorage : public TypeStorage {
return res;
}
-template <class LinalgOp>
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<Value *> tileSizes,
PerFunctionState &state) {
// Enforce the convention that "tiling by zero" skips tiling a particular
assert(op.getNumInputsAndOutputs() == op.getOperation()->getNumOperands());
auto views =
makeTiledViews(b, loc, op.getOperation(), ivValues, tileSizes, state);
- b->create<LinalgOp>(loc, views);
+ op.create(*b, loc, views);
/// NestedBuilders expect handles, we thus return an IndexHandle.
return IndexHandle();
}()});
return success();
}
-template <class LinalgOp>
static LogicalResult tileLinalgOp(LinalgOp &op, ArrayRef<int64_t> tileSizes,
PerFunctionState &state) {
if (tileSizes.empty())
// TODO(ntv) expose as a primitive for other passes.
static LogicalResult tileLinalgOp(Operation *op, ArrayRef<int64_t> tileSizes,
PerFunctionState &state) {
- if (auto matmulOp = dyn_cast<MatmulOp>(op)) {
- return tileLinalgOp(matmulOp, tileSizes, state);
- } else if (auto matvecOp = dyn_cast<MatvecOp>(op)) {
- return tileLinalgOp(matvecOp, tileSizes, state);
- } else if (auto dotOp = dyn_cast<DotOp>(op)) {
- return tileLinalgOp(dotOp, tileSizes, state);
- }
+ if (auto linalgOp = dyn_cast<LinalgOp>(op))
+ return tileLinalgOp(linalgOp, tileSizes, state);
return failure();
}