From a2e7775441bf86897cb46c0dd451b21a33a69628 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 23 Apr 2019 13:21:40 -0700 Subject: [PATCH] [Linalg] Add basic linalg ops This CL adds linalg.dot, linalg.matvec and linalg.matmul ops with the proper roundtripping test. These are the first LinalgOp that operate on views and that will lower to library calls. Linalg ops exhibit some common properties and behavior that are modeled with Traits. A LinalgOp is defined as a generic Op that operates on input and output views (passed as operands) and has the following properties: 1. a number of input and outputs captured by the `NInputsAndOutputs` trait. 2. a list of ranks for each operand captured by the `ViewRanks` trait. 3. a set of parallel, reduction and windowing loops captured by `NLoopTypes` trait. These represent are a first set of generic properties that will enable the definition of generic linear algebra operations and the properties necessary for upcoming transformations. -- PiperOrigin-RevId: 244912754 --- mlir/include/mlir/CMakeLists.txt | 3 +- mlir/include/mlir/Linalg/CMakeLists.txt | 4 ++ mlir/include/mlir/Linalg/LinalgOps.h | 4 ++ mlir/include/mlir/Linalg/LinalgOps.td | 96 ++++++++++++++++++++++++++ mlir/include/mlir/Linalg/LinalgTraits.h | 117 ++++++++++++++++++++++++++++++++ mlir/include/mlir/Linalg/LinalgTypes.h | 8 +-- mlir/lib/Linalg/CMakeLists.txt | 1 + mlir/lib/Linalg/LinalgOps.cpp | 47 +++++++++++++ mlir/lib/Linalg/LinalgTypes.cpp | 4 ++ mlir/test/Linalg/roundtrip.mlir | 16 ++++- 10 files changed, 292 insertions(+), 8 deletions(-) create mode 100644 mlir/include/mlir/Linalg/CMakeLists.txt create mode 100644 mlir/include/mlir/Linalg/LinalgOps.td create mode 100644 mlir/include/mlir/Linalg/LinalgTraits.h diff --git a/mlir/include/mlir/CMakeLists.txt b/mlir/include/mlir/CMakeLists.txt index 344a5d0..21a4b21 100644 --- a/mlir/include/mlir/CMakeLists.txt +++ b/mlir/include/mlir/CMakeLists.txt @@ -1,5 +1,6 @@ +add_subdirectory(EDSC) add_subdirectory(FxpMathOps) +add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(Quantization) add_subdirectory(StandardOps) -add_subdirectory(EDSC) diff --git a/mlir/include/mlir/Linalg/CMakeLists.txt b/mlir/include/mlir/Linalg/CMakeLists.txt new file mode 100644 index 0000000..d3ed75c --- /dev/null +++ b/mlir/include/mlir/Linalg/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS LinalgOps.td) +mlir_tablegen(LinalgOps.h.inc -gen-op-decls) +mlir_tablegen(LinalgOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRLinalgOpsIncGen) diff --git a/mlir/include/mlir/Linalg/LinalgOps.h b/mlir/include/mlir/Linalg/LinalgOps.h index 2142459..9406feb 100644 --- a/mlir/include/mlir/Linalg/LinalgOps.h +++ b/mlir/include/mlir/Linalg/LinalgOps.h @@ -19,6 +19,7 @@ #define MLIR_LINALG_LINALGOPS_H_ #include "mlir/IR/OpDefinition.h" +#include "mlir/Linalg/LinalgTraits.h" #include "mlir/Linalg/LinalgTypes.h" #include "mlir/Support/LLVM.h" @@ -219,6 +220,9 @@ public: } }; +#define GET_OP_CLASSES +#include "mlir/Linalg/LinalgOps.h.inc" + } // namespace mlir #endif // MLIR_LINALG_LINALGOPS_H_ diff --git a/mlir/include/mlir/Linalg/LinalgOps.td b/mlir/include/mlir/Linalg/LinalgOps.td new file mode 100644 index 0000000..4198f91 --- /dev/null +++ b/mlir/include/mlir/Linalg/LinalgOps.td @@ -0,0 +1,96 @@ +//===- LinalgOps.td - Linear algebra dialect ops -----------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This is the operation definition file for linear algebra operations. +// +//===----------------------------------------------------------------------===// + +#ifdef LINALG_OPS +#else + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +// Whether a type is a ViewType. +def LinalgIsViewTypePred : CPred<"$_self.isa()">; +def View : Type; + +class ParametricNativeOpTrait : + NativeOpTrait +{} + +class ParametricIntNativeOpTrait parameters> : + ParametricNativeOpTrait< + prop, + !strconcat("<", + !cast(!head(parameters)), + !foldl("", + !tail(parameters), + sum, + param, + sum # "," # !cast(param)), + ">::Impl")> +{} + +// The Linalg `NInputsAndOutputs` trait provides the API for ops that are known +// to have a specified number of inputs and outputs, all passed as operands. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NInputsAndOutputs : + ParametricIntNativeOpTrait<"NInputsAndOutputs", [n_ins, n_outs]> +{} + +// The linalg `NLoopTypes` trait provides the API for ops that are known to have +// a specified number of parallel (n_par), reduction (n_red) and window (n_win) +// loops. +// See Linalg/LinalgTraits.h for implementation details an usage. +class NLoopTypes : +ParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]> +{} + +// The linalg `ViewRanks` trait the API for ops that are known to have a +// specified list of view ranks. +// See Linalg/LinalgTraits.h for implementation details an usage. +class ViewRanks ranks> : +ParametricIntNativeOpTrait<"ViewRanks", ranks> +{} + +// Base Tablegen class for Linalg ops. +class LinalgOp props> : +Op { + let arguments = (ins Variadic); // default variadic builder + + let parser = [{ return impl::parseLinalgLibraryOp(parser, result); }]; + + let printer = [{ impl::printLinalgLibraryOp(p, *this); }]; +} + +//////////////////////////////////////////////////////////////////////////////// +// Concrete Linalg ops. +//////////////////////////////////////////////////////////////////////////////// +def DotOp : LinalgOp<"dot", [NInputsAndOutputs<2, 1>, + NLoopTypes<0, 1, 0>, + ViewRanks<[1, 1, 0]>]> {} +def MatvecOp : LinalgOp<"matvec", [NInputsAndOutputs<2, 1>, + NLoopTypes<1, 1, 0>, + ViewRanks<[2, 1, 1]>]> {} +def MatmulOp : LinalgOp<"matmul", [NInputsAndOutputs<2, 1>, + NLoopTypes<2, 1, 0>, + ViewRanks<[2, 2, 2]>]> {} + +#endif // LINALG_OPS \ No newline at end of file diff --git a/mlir/include/mlir/Linalg/LinalgTraits.h b/mlir/include/mlir/Linalg/LinalgTraits.h new file mode 100644 index 0000000..94620a6 --- /dev/null +++ b/mlir/include/mlir/Linalg/LinalgTraits.h @@ -0,0 +1,117 @@ +//===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_LINALG_LINALGTRAITS_H_ +#define MLIR_LINALG_LINALGTRAITS_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Linalg/LinalgTypes.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace OpTrait { + +/// This class provides the API for ops that are known to have a specified +/// number of inputs and outputs, all passed as operands. This is used as a +/// trait like this: +/// +/// class DotOp : public Op::Impl> { +/// +template class NInputsAndOutputs { +public: + template + class Impl : public OpTrait::detail::MultiOperandTraitBase< + ConcreteType, NInputsAndOutputs::Impl> { + public: + static unsigned getNumInputs() { return NInputs; } + static unsigned getNumOutputs() { return NOutputs; } + static unsigned getNumInputsAndOutputs() { return NInputs + NOutputs; } + Value *getInput(unsigned i) { return this->getOperand(i); } + Value *getOutput(unsigned i) { + return this->getOperand(getNumInputs() + i); + } + ViewType getInputViewType(unsigned i) { + return this->getOperand(i)->getType().template cast(); + } + ViewType getOutputViewType(unsigned i) { + return this->getOperand(getNumInputs() + i) + ->getType() + .template cast(); + } + ViewType getViewType(unsigned i) { + return this->getOperand(i)->getType().template cast(); + } + static LogicalResult verifyTrait(Operation *op) { + return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs); + } + }; +}; + +/// This class provides the API for ops that are known to have a specified +/// number of parallel, reduction and window loops. This is used as a trait like +/// this: +/// +/// class MatmulOp : public Op::Impl> { +/// +template +class NLoopTypes { +public: + template + class Impl + : public OpTrait::TraitBase< + ConcreteType, NLoopTypes::Impl> { + public: + static unsigned getNumParallelLoops() { return NParallel; } + static unsigned getNumReductionLoops() { return NReduction; } + static unsigned getNumWindowLoops() { return NWindow; } + static unsigned getNumLoops() { return NParallel + NReduction + NWindow; } + }; +}; + +/// This class provides the API for ops that are known to have a specified +/// list of view ranks. This is used as a trait like this: +/// +/// class MatvecOp : public Op::Impl> { +/// +template class ViewRanks { +public: + template + class Impl + : public OpTrait::TraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + ArrayRef ranks{Ranks...}; + if (op->getNumOperands() != ranks.size()) + return op->emitError("expected " + Twine(ranks.size()) + " operands"); + for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) { + auto viewType = op->getOperand(i)->getType().dyn_cast(); + if (!viewType) + return op->emitOpError("operand " + Twine(i) + + " must have view type "); + if (ranks[i] != viewType.getRank()) + return op->emitOpError("operand " + Twine(i) + " must have rank " + + Twine(ranks[i])); + } + return success(); + } + }; +}; + +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_LINALG_LINALGTRAITS_H_ diff --git a/mlir/include/mlir/Linalg/LinalgTypes.h b/mlir/include/mlir/Linalg/LinalgTypes.h index fbb1cbf..64f86d4 100644 --- a/mlir/include/mlir/Linalg/LinalgTypes.h +++ b/mlir/include/mlir/Linalg/LinalgTypes.h @@ -93,20 +93,18 @@ public: /// %3 = linalg.view %1[%2, %2] : !linalg.view /// ``` class ViewTypeStorage; -class ViewType - : public mlir::Type::TypeBase { +class ViewType : public Type::TypeBase { public: // Used for generic hooks in TypeBase. using Base::Base; /// Construction hook. - static ViewType get(mlir::MLIRContext *context, mlir::Type elementType, - unsigned rank); + static ViewType get(MLIRContext *context, Type elementType, unsigned rank); // Used to implement llvm-style cast. static bool kindof(unsigned kind) { return kind == LinalgTypes::View; } // Type-specific functionality. /// Return the underlying elemental type. - mlir::Type getElementType(); + Type getElementType(); /// Return the rank of the view. /// This is the number of indexings needed to reach an underlying element. unsigned getRank(); diff --git a/mlir/lib/Linalg/CMakeLists.txt b/mlir/lib/Linalg/CMakeLists.txt index b1df307..50af3cc 100644 --- a/mlir/lib/Linalg/CMakeLists.txt +++ b/mlir/lib/Linalg/CMakeLists.txt @@ -6,3 +6,4 @@ add_llvm_library(MLIRLinalg ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg ) +add_dependencies(MLIRLinalg MLIRLinalgOpsIncGen) diff --git a/mlir/lib/Linalg/LinalgOps.cpp b/mlir/lib/Linalg/LinalgOps.cpp index 2db7696..423fed0 100644 --- a/mlir/lib/Linalg/LinalgOps.cpp +++ b/mlir/lib/Linalg/LinalgOps.cpp @@ -20,6 +20,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Linalg/LinalgOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" @@ -354,3 +356,48 @@ void mlir::ViewOp::print(OpAsmPrinter *p) { [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); *p << "] : " << getType(); } + +namespace mlir { +namespace impl { + +// A LinalgLibraryOp prints as: +// +// ```{.mlir} +// concrete_op_name (ssa-inputs, ssa-outputs) : view-types +// ``` +// +// for example: +// +// ``` +// linalg.matmul(%0, %1, %2) : +// !linalg.view, !linalg.view, !linalg.view +// ``` +// +// Where %0, %1 and %2 are ssa-values of type ViewType. +void printLinalgLibraryOp(mlir::OpAsmPrinter *p, Operation *op) { + assert(op->getAbstractOperation() && "unregistered operation"); + *p << op->getName().getStringRef() << "("; + interleave( + op->getOperands().begin(), op->getOperands().end(), + [&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; }); + *p << ") : "; + interleave( + op->getOperands().begin(), op->getOperands().end(), + [&](mlir::Value *v) { *p << v->getType(); }, [&]() { *p << ", "; }); +} + +bool parseLinalgLibraryOp(OpAsmParser *parser, OperationState *result) { + SmallVector ops; + SmallVector types; + return parser->parseOperandList(ops, -1, OpAsmParser::Delimiter::Paren) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonTypeList(types) || + parser->resolveOperands(ops, types, parser->getNameLoc(), + result->operands); +} +} // namespace impl + +#define GET_OP_CLASSES +#include "mlir/Linalg/LinalgOps.cpp.inc" + +} // namespace mlir diff --git a/mlir/lib/Linalg/LinalgTypes.cpp b/mlir/lib/Linalg/LinalgTypes.cpp index fa08f75..a507fa8 100644 --- a/mlir/lib/Linalg/LinalgTypes.cpp +++ b/mlir/lib/Linalg/LinalgTypes.cpp @@ -31,6 +31,10 @@ mlir::LinalgDialect::LinalgDialect(MLIRContext *context) : Dialect("linalg", context) { addTypes(); addOperations(); + addOperations< +#define GET_OP_LIST +#include "mlir/Linalg/LinalgOps.cpp.inc" + >(); } struct mlir::BufferTypeStorage : public mlir::TypeStorage { diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index 4327e5d..a0e02d0 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -verify | FileCheck %s +// RUN: mlir-opt %s -verify | mlir-opt %s -verify | FileCheck %s func @range(%arg0: index, %arg1: index, %arg2: index) { %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range @@ -39,4 +39,16 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index // CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view, !linalg.range, index, !linalg.view // CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view, index, !linalg.range, !linalg.view // CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view, index, index, !linalg.view -// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer \ No newline at end of file +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer + +func @ops(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg.view, %arg3: !linalg.view) { + linalg.matmul(%arg0, %arg0, %arg0) : !linalg.view, !linalg.view, !linalg.view + linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view, !linalg.view, !linalg.view + linalg.dot(%arg1, %arg2, %arg3) : !linalg.view, !linalg.view, !linalg.view + return +} +// CHECK-LABEL: func @ops(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg.view, %arg3: !linalg.view) { +// CHECK-NEXT: linalg.matmul(%arg0, %arg0, %arg0) : !linalg.view, !linalg.view, !linalg.view +// CHECK-NEXT: linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view, !linalg.view, !linalg.view +// CHECK-NEXT: linalg.dot(%arg1, %arg2, %arg3) : !linalg.view, !linalg.view, !linalg.view + -- 2.7.4