From: Nicolas Vasilache Date: Wed, 12 Feb 2020 19:41:11 +0000 (-0500) Subject: [mlir][Linalg] Refactor in preparation for automatic Linalg "named" ops. X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=bfaf535791897f3cc2af40d4f5a677489ad25940;p=platform%2Fupstream%2Fllvm.git [mlir][Linalg] Refactor in preparation for automatic Linalg "named" ops. This revision prepares the ground for declaratively defining Linalg "named" ops. Such named ops form the backbone of operations that are ubiquitous in the ML application domain. This revision closely related to the definition of a "Tensor Computation Primitives Dialect" and demonstrates that ops can be expressed as declarative configurations of the `linalg.generic` op. Differential Revision: https://reviews.llvm.org/D74491 --- diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h index 305159d..2e26989 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTraits.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index d31424a..8914bcf 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -39,19 +39,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // Loop types handling. //========================================================================// InterfaceMethod< - "Query the number of parallel loops within the current operation.", + "Return the number of parallel loops within the current operation.", "unsigned", "getNumParallelLoops" >, InterfaceMethod< - "Query the number of reduction loops within the current operation.", + "Return the number of reduction loops within the current operation.", "unsigned", "getNumReductionLoops" >, InterfaceMethod< - "Query the number of window loops within the current operation.", + "Return the number of window loops within the current operation.", "unsigned", "getNumWindowLoops" >, InterfaceMethod< - "Query the number of loops within the current operation.", + "Return the number of loops within the current operation.", "unsigned", "getNumLoops">, InterfaceMethod< @@ -63,10 +63,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // Input arguments handling. //========================================================================// InterfaceMethod< - "Query the number of inputs from the current operation.", + "Return the number of inputs from the current operation.", "unsigned", "getNumInputs" >, - InterfaceMethod<"Query the input view at the given index.", + InterfaceMethod<"Return the input view at the given index.", "Value ", "getInput", (ins "unsigned":$i) >, InterfaceMethod<[{ @@ -76,41 +76,40 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { "llvm::Optional", "getIndexOfInput", (ins "Value ":$v) >, InterfaceMethod< - "Query the input operands from the current operation.", + "Return the input operands from the current operation.", "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ - Query the type of the input shape at the given index. + Return the type of the input shape at the given index. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the subset of input operands that are of ranked tensor type. + Return the subset of input operands that are of ranked tensor type. }], "SmallVector", "getInputTensorTypes">, - //========================================================================// // Output arguments handling. //========================================================================// InterfaceMethod< - "Query the number of outputs from the current operation.", + "Return the number of outputs from the current operation.", "unsigned", "getNumOutputs" >, - InterfaceMethod<"Query the output buffer at the given index.", + InterfaceMethod<"Return the output buffer at the given index.", "Value ", "getOutputBuffer", (ins "unsigned":$i) >, InterfaceMethod<[{ - Query the index of the given buffer value, or `None` if the value is not - part of the output buffers. + Return the index of the given buffer value, or `None` if the value is + not part of the output buffers. }], "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value ":$view) >, InterfaceMethod<[{ - Query the type of the output buffer at the given index. + Return the type of the output buffer at the given index. }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the results that are of ranked tensor type. + Return the results that are of ranked tensor type. }], "SmallVector", "getOutputTensorTypes">, InterfaceMethod< - "Query the output buffers (operands) from the current operation.", + "Return the output buffers (operands) from the current operation.", "Operation::operand_range", "getOutputBuffers" >, @@ -136,18 +135,44 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // Other interface methods. //========================================================================// InterfaceMethod< - "Query the iterator types attribute within the current operation.", + "Return the reference iterators for this named op (if any are specied). " + "These reference iterators are used to specify the default behavior of " + "the op. Typically this would be a static method but in order to allow " + "rank-polymorphic ops, this needs to be per object instance. Named ops " + "must define referenceIterators, even if empty for the 0-D case. " + "Generic ops on the other hand have a None `referenceIterators`", + "llvm::Optional>", "referenceIterators" + >, + InterfaceMethod< + "Return the reference indexing maps for this named op (if any are " + "specified). Typically this would be a static method but in order to " + "allow rank-polymorphic ops, this needs to be per object instance. Named " + "ops must define referenceIterators, even if empty for the 0-D case. " + "Generic ops on the other hand have a None `referenceIndexingMaps`", + "llvm::Optional>", "referenceIndexingMaps" + >, + InterfaceMethod< + "Return the iterator types attribute within the current operation.", "ArrayAttr", "iterator_types" >, InterfaceMethod< - "Query the indexing maps attribute within the current operation.", + "Return the indexing maps attribute within the current operation.", "ArrayAttr", "indexing_maps" >, + InterfaceMethod<"Return the input or output indexing map at index `i`.", + "AffineMap", "getIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Return the input indexing map at index `i`.", + "AffineMap", "getInputIndexingMap", (ins "unsigned":$i) + >, + InterfaceMethod<"Return the output indexing map at index `i`.", + "AffineMap", "getOutputIndexingMap", (ins "unsigned":$i) + >, InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. + Return whether the op has only MemRef input and outputs. }], "bool", "hasBufferSemantics">, InterfaceMethod<[{ - Query whether the op has only RankedTensor input and outputs. + Return whether the op has only RankedTensor input and outputs. }], "bool", "hasTensorSemantics">, //========================================================================// @@ -204,7 +229,7 @@ class LinalgStructured_Op props> } //////////////////////////////////////////////////////////////////////////////// -// Concrete Linalg ops. +// Named Linalg ops, implemented as special configurations of a generic op. //////////////////////////////////////////////////////////////////////////////// def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { let description = [{ @@ -266,14 +291,19 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { builder, result, input, output, AffineMapAttr(), AffineMapAttr()); }]>]; let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { + // Rank-polymorphic. + // filling_value -> O(ivs) with parallel iterators. + llvm::Optional> referenceIterators() { unsigned nPar = input().getType().cast().getRank(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + return SmallVector(nPar, getParallelIteratorTypeName()); + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for CopyOp"); } }]; let verifier = [{ return ::verify(*this); }]; @@ -282,21 +312,24 @@ def CopyOp : LinalgStructured_Op<"copy", [NInputs<1>, NOutputs<1>]> { } def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> { - let arguments = (ins AnyStridedMemRef:$input, + let arguments = (ins AnyStridedMemRef:$output, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>:$value); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - unsigned nPar = input().getType().cast().getRank(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + // Rank-polymorphic. + // filling_value -> O(ivs) with parallel iterators. + llvm::Optional> referenceIterators() { + unsigned nPar = output().getType().cast().getRank(); + return SmallVector(nPar, getParallelIteratorTypeName()); } - }]; - let verifier = [{ return ::verify(*this); }]; + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for CopyOp"); + } + }]; let hasFolder = 1; } @@ -305,12 +338,16 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> { AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<0>); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - return ArrayAttr::get( - StringAttr::get(getReductionIteratorTypeName(), ctx), ctx); + llvm::Optional> referenceIterators() { + return SmallVector{getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for DotOp"); } }]; @@ -322,14 +359,18 @@ def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> { AnyStridedMemRefOfRank<1>, AnyStridedMemRefOfRank<1>); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - Attribute iters[2]{ - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getReductionIteratorTypeName(), ctx)}; - return ArrayAttr::get(iters, ctx); + llvm::Optional> referenceIterators() { + return SmallVector{ + getParallelIteratorTypeName(), + getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatvecOp"); } }]; @@ -341,15 +382,19 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> { AnyStridedMemRefOfRank<2>, AnyStridedMemRefOfRank<2>); let extraClassDeclaration = libraryCallName # [{ + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { - MLIRContext *ctx = getContext(); - Attribute iters[3]{ - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getParallelIteratorTypeName(), ctx), - StringAttr::get(getReductionIteratorTypeName(), ctx)}; - return ArrayAttr::get(iters, ctx); + llvm::Optional> referenceIterators() { + return SmallVector{ + getParallelIteratorTypeName(), + getParallelIteratorTypeName(), + getReductionIteratorTypeName()}; + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); } }]; @@ -387,11 +432,13 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { unsigned getNumInputFeatureDimensions() { return 1; } unsigned getNumOutputFeatureDimensions() { return 1; } + // Defined in C++ for now. + // TODO(ntv): auto-generate. ArrayAttr indexing_maps(); - ArrayAttr iterator_types() { + llvm::Optional> referenceIterators() { // Outer parallel loops are always the number of output dimensions; i.e. - // [ b, xs, q] in the TF notation above. + // [b, xs, q] in the TF notation above. unsigned nPar = getOutputShapedType(0).getRank(); unsigned nRed = getNumInputFeatureDimensions(); // Window loops are a special kind of reduction that is never tiled or @@ -400,13 +447,11 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { // This may evolve in the future. unsigned nWin = nPar - getNumBatchDimensions() - getNumInputFeatureDimensions(); - MLIRContext *ctx = getContext(); - SmallVector iters( - nPar, StringAttr::get(getParallelIteratorTypeName(), ctx)); + SmallVector iters(nPar, getParallelIteratorTypeName()); iters.reserve(nPar + nRed + nWin); - iters.append(nRed, StringAttr::get(getReductionIteratorTypeName(), ctx)); - iters.append(nWin, StringAttr::get(getWindowIteratorTypeName(), ctx)); - return ArrayAttr::get(iters, ctx); + iters.append(nRed, getReductionIteratorTypeName()); + iters.append(nWin, getWindowIteratorTypeName()); + return iters; } int64_t getStride(unsigned i) { @@ -422,6 +467,10 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> { return dilations()->getValue()[i] .cast().getValue().getSExtValue(); } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable("NYI referenceIndexingMaps for MatmulOp"); + } }]; let verifier = [{ return ::verify(*this); }]; @@ -438,6 +487,9 @@ class LinalgOperandOfRank: Type< CPred<"$_self.cast().getRank() == " # rank>] >>; +//////////////////////////////////////////////////////////////////////////////// +// Generic Linalg ops. +//////////////////////////////////////////////////////////////////////////////// class GenericOpBase : LinalgStructuredBase_Op { let arguments = (ins Variadic:$views, I64Attr:$args_in, @@ -457,34 +509,36 @@ class GenericOpBase : LinalgStructuredBase_Op { getIteratorTypesAttrName() }; } + unsigned getNumInputs() { return args_in().getSExtValue(); } + unsigned getNumOutputs() { return args_out().getSExtValue(); } + FuncOp getFunction() { auto moduleOp = getParentOfType(); return fun().hasValue() ? moduleOp.lookupSymbol(fun().getValue()) : FuncOp(); } + StringRef getLibraryCallName() { return library_call().hasValue() ? library_call().getValue() : ""; } - AffineMap getIndexingMap(unsigned i) { - assert(i < getNumInputsAndOutputs()); - return indexing_maps().getValue()[i].cast().getValue(); - } - AffineMap getInputIndexingMap(unsigned i) { - assert(i < getNumInputs()); - return indexing_maps().getValue()[i].cast().getValue(); - } - AffineMap getOutputIndexingMap(unsigned i) { - assert(i < getNumOutputs()); - return indexing_maps().getValue()[i + getNumInputs()] - .cast().getValue(); - } + + llvm::Optional> referenceIterators() { + llvm_unreachable( + "No such thing as reference iterator types for a generic op."); + } + + llvm::Optional> referenceIndexingMaps() { + llvm_unreachable( + "No such thing as reference indexing maps for a generic op."); + } }]; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseGenericOp(parser, result); }]; } +/// Index-free GenericOp. def GenericOp : GenericOpBase<"generic"> { let description = [{ Generic Linalg op form where the key properties of the computation are @@ -609,6 +663,8 @@ def GenericOp : GenericOpBase<"generic"> { let hasFolder = 1; } +/// GenericOp with Indexing (i.e. multi-for style in which the region is passed +/// the enclosing loop induction variables) def IndexedGenericOp : GenericOpBase<"indexed_generic"> { let description = [{ Indexed Generic Linalg op form where the key properties of the computation diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index bfe528d..31b462a 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" @@ -214,6 +215,103 @@ public: //==========================================================================// // Other interface methods. //==========================================================================// + + // Get or build the indexing_maps ArrayAttr. + ArrayAttr iterator_types() { + // Return the attribute if it is present. + if (auto attr = this->getOperation()->getAttr("iterator_types")) + return attr.template cast(); + + // If not, form the attribute using the reference iterator types for the + // ConcreteType. + auto maybeReferenceIteratorTypes = + cast(this->getOperation()).referenceIterators(); + + // If there is no reference, this must be a generic op. + // TODO(ntv): Traits are used to define ops. Split into cpp to avoid + // cyclic dependency. + auto name = this->getOperation()->getName().getStringRef(); + if (!maybeReferenceIteratorTypes && name != "generic" && + name != "indexed_generic") { + this->getOperation()->dump(); + llvm_unreachable("Op missing "); + } + + // If we have a reference, build the reference attribute. + auto *ctx = this->getOperation()->getContext(); + auto attrRange = llvm::map_range(*maybeReferenceIteratorTypes, + [ctx](StringRef str) -> Attribute { + return StringAttr::get(str, ctx); + }); + auto attr = ArrayAttr::get(llvm::to_vector<4>(attrRange), ctx); + // TODO(ntv): Need to memoize this. Can't just store as an attribute atm as + // it will impact parser, printer and tests. + // this->getOperation()->setAttr("iterator_types", attr); + return attr; + } + + // Get or build the indexing_maps ArrayAttr. + ArrayAttr indexing_maps() { + // Return the attribute if it is present. + if (auto attr = this->getOperation()->getAttr("indexing_maps")) + return attr.template cast(); + + // If not, form the attribute using the reference indexing map for the + // ConcreteType. + auto maybeReferenceIndexingMaps = + cast(this->getOperation()).referenceIndexingMaps(); + + // If there is no reference, this must be a generic op. + auto name = this->getOperation()->getName().getStringRef(); + if (!maybeReferenceIndexingMaps && name != "generic" && + name != "indexed_generic") { + this->getOperation()->dump(); + llvm_unreachable("Op missing referenceIndexingMaps"); + } + + // If we have a reference, build the reference attribute and set it in the + // op before returning. + auto *ctx = this->getOperation()->getContext(); + auto attrRange = + llvm::map_range(*maybeReferenceIndexingMaps, [ctx](AffineMap map) { + // 0-D corner case because there is no such thing as a concrete empty + // map type. + if (!map) + map = AffineMap::get(0, 0, getAffineConstantExpr(0, ctx)); + return AffineMapAttr::get(map); + }); + SmallVector attrs{attrRange.begin(), attrRange.end()}; + auto attr = ArrayAttr::get(attrs, ctx); + // TODO(ntv): Need to memoize this. Can't just store as an attribute atm as + // it will impact parser, printer and tests. + // this->getOperation()->setAttr("indexing_maps", attr); + return attr; + } + + AffineMap getIndexingMap(unsigned i) { + assert(i < getNumInputsAndOutputs()); + return indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + } + + AffineMap getInputIndexingMap(unsigned i) { + assert(i < nInputs()); + return indexing_maps() + .getValue()[i] + .template cast() + .getValue(); + } + + AffineMap getOutputIndexingMap(unsigned i) { + assert(i < nOutputs()); + return indexing_maps() + .getValue()[i + nInputs()] + .template cast() + .getValue(); + } + /// Query whether the op has only buffer inputs and no returns. bool hasBufferSemantics() { return this->getOperation()->getNumResults() == 0 && diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index fb18fbf..c5fbea9 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -866,6 +866,15 @@ static LogicalResult verify(ConvOp op) { return success(); } +static AffineMap extractOrIdentityMap(Optional maybeMap, + unsigned rank, MLIRContext *context) { + if (maybeMap) + return maybeMap.getValue(); + if (rank == 0) + return AffineMap(); + return AffineMap::getMultiDimIdentityMap(rank, context); +} + namespace mlir { namespace linalg { @@ -880,15 +889,6 @@ namespace linalg { } // namespace linalg } // namespace mlir -static AffineMap extractOrIdentityMap(Optional maybeMap, - unsigned rank, MLIRContext *context) { - if (maybeMap) - return maybeMap.getValue(); - if (rank == 0) - return AffineMap(); - return AffineMap::getMultiDimIdentityMap(rank, context); -} - // Returns `num` AffineDimExpr dimensions at positions [curIdx, curIdx + num) // and increments `curIdx` to `curIdx + num`. static SmallVector @@ -997,23 +997,15 @@ SmallVector mlir::linalg::loopToOperandRangesMaps(Operation *op) { AffineMap::get(idx, 0, concat(concat(bs, ws), qs)), // output[b, x[0], ..., x[N-1], k] AffineMap::get(idx, 0, concat(concat(bs, xs), ks))}; - } else if (auto genericOp = dyn_cast(op)) { - SmallVector res; - unsigned nViews = genericOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) { - res.push_back(genericOp.getIndexingMap(i)); - } - return res; - } else if (auto indexedGenericOp = dyn_cast(op)) { - SmallVector res; - unsigned nViews = indexedGenericOp.getNumInputsAndOutputs(); - res.reserve(nViews); - for (unsigned i = 0, e = nViews; i < e; ++i) - res.push_back(indexedGenericOp.getIndexingMap(i)); - return res; } - llvm_unreachable("Missing loopToOperandRangesMaps for op"); + SmallVector res; + auto linalgOp = cast(op); + unsigned nViews = linalgOp.getNumInputsAndOutputs(); + res.reserve(nViews); + for (unsigned i = 0, e = nViews; i < e; ++i) + res.push_back(linalgOp.getIndexingMap(i)); + assert(nViews == linalgOp.indexing_maps().size()); + return res; } static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {