From f52d71736b10e87b1aa1880b777dc9462a0085ce Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Sat, 11 Jan 2020 02:22:00 -0500 Subject: [PATCH] [mlir][Linalg] Update the semantics, verifier and test for Linalg with tensors. Summary: This diff fixes issues with the semantics of linalg.generic on tensors that appeared when converting directly from HLO to linalg.generic. The changes are self-contained within MLIR and can be captured and tested independently of XLA. The linalg.generic and indexed_generic are updated to: To allow progressive lowering from the value world (a.k.a tensor values) to the buffer world (a.k.a memref values), a linalg.generic op accepts mixing input and output ranked tensor values with input and output memrefs. ``` %1 = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor, memref -> (tensor) ``` In this case, the number of outputs (args_out) must match the sum of (1) the number of output buffer operands and (2) the number of tensor return values. The semantics is that the linalg.indexed_generic op produces (i.e. allocates and fills) its return values. Tensor values must be legalized by a buffer allocation pass before most transformations can be applied. Such legalization moves tensor return values into output buffer operands and updates the region argument accordingly. Transformations that create control-flow around linalg.indexed_generic operations are not expected to mix with tensors because SSA values do not escape naturally. Still, transformations and rewrites that take advantage of tensor SSA values are expected to be useful and will be added in the near future. Subscribers: bmahjour, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72555 --- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 178 ++++++++++++-------- mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h | 167 +++++++++++------- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 2 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 10 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 187 ++++++++++----------- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 31 +++- .../Dialect/Linalg/Transforms/LinalgToLoops.cpp | 43 +++-- .../Dialect/Linalg/Transforms/LinalgTransforms.cpp | 14 +- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 11 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 9 +- mlir/test/Dialect/Linalg/invalid.mlir | 94 +++++------ mlir/test/Dialect/Linalg/roundtrip.mlir | 24 +-- 12 files changed, 458 insertions(+), 312 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index d8c657c..44735a0 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -35,38 +35,9 @@ def StructuredOpTraits : NativeOpTrait<"linalg::StructuredOpTraits">; // interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let methods = [ - InterfaceMethod< - "Query the number of inputs from the current operation.", - "unsigned", "getNumInputs" - >, - InterfaceMethod< - "Query the number of outputs from the current operation.", - "unsigned", "getNumOutputs" - >, - InterfaceMethod< - "Query the number of inputs and outputs from the current operation.", - "unsigned", "getNumInputsAndOutputs" - >, - InterfaceMethod< - "Query the input operands from the current operation.", - "Operation::operand_range", "getInputs" - >, - InterfaceMethod< - "Query the output operands from the current operation.", - "Operation::operand_range", "getOutputs" - >, - InterfaceMethod< - "Query the input and output operands from the current operation.", - "Operation::operand_range", "getInputsAndOutputs" - >, - InterfaceMethod< - "Query the iterator types attribute within the current operation.", - "ArrayAttr", "iterator_types" - >, - InterfaceMethod< - "Query the indexing maps attribute within the current operation.", - "ArrayAttr", "indexing_maps" - >, + //========================================================================// + // Loop types handling. + //========================================================================// InterfaceMethod< "Query the number of parallel loops within the current operation.", "unsigned", "getNumParallelLoops" @@ -82,40 +53,98 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { InterfaceMethod< "Query the number of loops within the current operation.", "unsigned", "getNumLoops">, + + //========================================================================// + // Input arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of inputs from the current operation.", + "unsigned", "getNumInputs" + >, InterfaceMethod<"Query the input view at the given index.", "Value ", "getInput", (ins "unsigned":$i) >, - InterfaceMethod<"Query the output view at the given index.", - "Value ", "getOutput", (ins "unsigned":$i) - >, InterfaceMethod<[{ Return the index of the given input value `v`, or `None` if the value is not an input. }], "llvm::Optional", "getIndexOfInput", (ins "Value ":$v) >, - InterfaceMethod<[{ - Query the index of the given view value, or `None` if the value is not - a view. - }], - "llvm::Optional", "getIndexOfOutput", (ins "Value ":$view) + InterfaceMethod< + "Query the input operands from the current operation.", + "Operation::operand_range", "getInputs" >, InterfaceMethod<[{ Query the type of the input shape at the given index. }], "ShapedType", "getInputShapedType", (ins "unsigned":$i)>, InterfaceMethod<[{ - Query the type of the output view at the given index. - }], "ShapedType", "getOutputShapedType", (ins "unsigned":$i)>, - InterfaceMethod<[{ - Query whether the op has only MemRef input and outputs. - }], "bool", "hasBufferSemantics">, - InterfaceMethod<[{ Query the subset of input operands that are of ranked tensor type. }], "SmallVector", "getInputTensorTypes">, + + + //========================================================================// + // Output arguments handling. + //========================================================================// + InterfaceMethod< + "Query the number of outputs from the current operation.", + "unsigned", "getNumOutputs" + >, + InterfaceMethod<"Query the output buffer at the given index.", + "Value ", "getOutputBuffer", (ins "unsigned":$i) + >, InterfaceMethod<[{ - Query the subset of output operands that are of ranked tensor type. + Query the index of the given buffer value, or `None` if the value is not + part of the output buffers. + }], + "llvm::Optional", "getIndexOfOutputBuffer", (ins "Value ":$view) + >, + InterfaceMethod<[{ + Query the type of the output buffer at the given index. + }], "MemRefType", "getOutputBufferType", (ins "unsigned":$i)>, + InterfaceMethod<[{ + Query the results that are of ranked tensor type. }], "SmallVector", "getOutputTensorTypes">, + InterfaceMethod< + "Query the output buffers (operands) from the current operation.", + "Operation::operand_range", "getOutputBuffers" + >, + //========================================================================// + // Input and Output arguments handling. + //========================================================================// + InterfaceMethod< + "Return the number of inputs and outputs, irrespective of their buffer " + "or tensor type.", + "unsigned", "getNumInputsAndOutputs" + >, + InterfaceMethod< + "Return the number of inputs, irrespective of their buffer or tensor " + "type, and output buffers", + "unsigned", "getNumInputsAndOutputBuffers" + >, + InterfaceMethod< + "Return the range over inputs (irrespective of type) and output buffers.", + "Operation::operand_range", "getInputsAndOutputBuffers" + >, + + //========================================================================// + // Other interface methods. + //========================================================================// + InterfaceMethod< + "Query the iterator types attribute within the current operation.", + "ArrayAttr", "iterator_types" + >, + InterfaceMethod< + "Query the indexing maps attribute within the current operation.", + "ArrayAttr", "indexing_maps" + >, + InterfaceMethod<[{ + Query whether the op has only MemRef input and outputs. + }], "bool", "hasBufferSemantics">, + + //========================================================================// + // Other static interface methods. + //========================================================================// StaticInterfaceMethod<[{ Create an operation of the current type with the given location, operands, and attributes. @@ -128,9 +157,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { attributes); }] >, - - /// Clone an operation with the given location and operands. This is used to - /// abstract away the optional underlying region creation. InterfaceMethod<[{ Clone the current operation with the given location and operands. This is used to abstract away the optional underlying region creation. @@ -536,22 +562,26 @@ def GenericOp : GenericOpBase<"generic"> { mixing input and output ranked tensor values with input and output memrefs. ```mlir - %1 = linalg.generic #trait_attribute %A, %B, %C {other-attributes} : + %C = linalg.generic #trait_attribute %A, %B {other-attributes} : tensor, - memref, - tensor + memref -> (tensor) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its tensor return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region arguments accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } @@ -659,22 +689,26 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { memrefs. ```mlir - %1 = linalg.indexed_generic #trait_attribute %A, %B, %C {other-attributes} + %C = linalg.indexed_generic #trait_attribute %A, %B {other-attributes} : tensor, - memref, - tensor + memref -> (tensor) ``` - In this case, the number of return values must match the number of output - tensor arguments. The semantics is that the `linalg.indexed_generic` op - produces (i.e. allocates and fills) its return values. + In this case, the number of outputs (args_out) must match the sum of (1) the + number of output buffer operands and (2) the number of tensor return values. + The semantics is that the `linalg.indexed_generic` op produces (i.e. + allocates and fills) its return values. + Tensor values must be legalized by a buffer allocation pass before most - transformations can be applied. In particular, transformations that create - control flow around linalg.generic operations are not expected to mix with - tensors because SSA values do not escape naturally. Still, transformations - and rewrites that take advantage of tensor SSA values are expected to be - useful and will be added in the near future. + transformations can be applied. Such legalization moves tensor return values + into output buffer operands and updates the region argument accordingly. + + Transformations that create control-flow around linalg.indexed_generic + operations are not expected to work with tensors because SSA values do not + escape naturally. Still, transformations and rewrites that take advantage of + tensor SSA values are expected to be useful and will be added in the near + future. }]; let verifier = [{ return ::verify(*this); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h index 7275863..2284616 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -58,16 +58,47 @@ template class StructuredOpTraits : public OpTrait::TraitBase { private: - /// Return the number of inputs. For internal use only. + /// Return the number of inputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nInputs() { return cast(this->getOperation()).getNumInputs(); } - /// Return the number of outputs. For internal use only. + /// Return the number of outputs, irrespective of their buffer or tensor type. + /// For internal use only. unsigned nOutputs() { return cast(this->getOperation()).getNumOutputs(); } public: + //==========================================================================// + // Loop types handling. + //==========================================================================// + unsigned getNumParallelLoops() { + return getNumIterators( + getParallelIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumReductionLoops() { + return getNumIterators( + getReductionIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumWindowLoops() { + return getNumIterators( + getWindowIteratorTypeName(), + cast(this->getOperation()).iterator_types()); + } + unsigned getNumLoops() { + return getNumIterators( + cast(this->getOperation()).iterator_types()); + } + + //==========================================================================// + // Input arguments handling. + //==========================================================================// + // The `i^th` input argument is always the `i^th` operand regardless of + // whether we have tensors or buffers. + // /// Return the `i`-th input value. Value getInput(unsigned i) { assert(i < nInputs()); @@ -90,28 +121,6 @@ public: auto range = this->getOperation()->getOperands(); return {range.begin(), range.begin() + nInputs()}; } - /// Return the `i`-th output. - Value getOutput(unsigned i) { - return this->getOperation()->getOperand(nInputs() + i); - } - /// Return the index of `value` in the list of output values if found, - /// llvm::None otherwise. - Optional getIndexOfOutput(Value value) { - auto it = llvm::find(getOutputs(), value); - if (it != getOutputs().end()) - return it - getOutputs().begin(); - return llvm::None; - } - /// Return the `i`-th output buffer type. - ShapedType getOutputShapedType(unsigned i) { - return getOutput(i).getType().template cast(); - } - /// Query whether the op has only MemRef input and outputs. - bool hasBufferSemantics() { - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(getInputsAndOutputs(), - [](Value v) { return v.getType().isa(); }); - } /// Query the subset of input operands that are of ranked tensor type. SmallVector getInputTensorTypes() { SmallVector res; @@ -120,53 +129,97 @@ public: res.push_back(t); return res; } - /// Query the subset of output operands that are of ranked tensor type. + + //==========================================================================// + // Output arguments handling. + //==========================================================================// + // The `i^th` output argument is an operand (resp. a return value) iff it is + // a value of buffer type (resp. a return value of tensor type). + + /// Return the `i`-th output, asserts that this is a buffer operand and not + /// a tensor result. + Value getOutputBuffer(unsigned i) { + assert(i + this->getOperation()->getNumResults() < nOutputs() && + "overflowing output buffer index"); + return this->getOperation()->getOperand(nInputs() + i); + } + /// Return the index of `value` in the list of output buffers if found, + /// llvm::None otherwise. + Optional getIndexOfOutputBuffer(Value value) { + auto it = llvm::find(getOutputBuffers(), value); + if (it != getOutputBuffers().end()) + return it - getOutputBuffers().begin(); + return llvm::None; + } + /// Return the `i`-th output buffer type. + MemRefType getOutputBufferType(unsigned i) { + return getOutputBuffer(i).getType().template cast(); + } + /// Return the `i`-th output shaped type, irrespective of buffer of tensor + /// type. + ShapedType getOutputShapedType(unsigned i) { + return getShapedType(i + nInputs()); + } + /// Query the subset of results that are of ranked tensor type. SmallVector getOutputTensorTypes() { SmallVector res; - for (Type type : getOutputs().getTypes()) - if (auto t = type.template dyn_cast()) - res.push_back(t); + for (Type type : this->getOperation()->getResults().getTypes()) + res.push_back(type.template cast()); return res; } /// Return the range over outputs. - Operation::operand_range getOutputs() { + Operation::operand_range getOutputBuffers() { auto range = this->getOperation()->getOperands(); return {range.begin() + nInputs(), - range.begin() + getNumInputsAndOutputs()}; + range.begin() + getNumInputsAndOutputBuffers()}; } - /// Return the number of inputs and outputs. + + //==========================================================================// + // Input and Output arguments handling. + //==========================================================================// + /// Return the number of inputs and outputs, irrespective of their buffer or + /// tensor type. unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); } - /// Return the `i`-th buffer type. - ShapedType getShapedType(unsigned i) { - return (i < nInputs()) ? getInputShapedType(i) - : getOutputShapedType(i - nInputs()); - } - /// Return the range over inputs and outputs. - Operation::operand_range getInputsAndOutputs() { + /// Return the number of inputs, irrespective of their buffer or tensor type, + /// and output buffers. + unsigned getNumInputsAndOutputBuffers() { + assert(this->getOperation()->getNumResults() <= nOutputs()); + return nInputs() + nOutputs() - this->getOperation()->getNumResults(); + } + /// Return the range over inputs (irrespective of type) and output buffers. + Operation::operand_range getInputsAndOutputBuffers() { auto range = this->getOperation()->getOperands(); - return {range.begin(), range.begin() + getNumInputsAndOutputs()}; + return {range.begin(), range.begin() + getNumInputsAndOutputBuffers()}; } - unsigned getNumParallelLoops() { - return getNumIterators( - getParallelIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumReductionLoops() { - return getNumIterators( - getReductionIteratorTypeName(), - cast(this->getOperation()).iterator_types()); - } - unsigned getNumWindowLoops() { - return getNumIterators( - getWindowIteratorTypeName(), - cast(this->getOperation()).iterator_types()); + /// Return the `i`-th shaped type, there are 3 cases: + /// 1. if `i < nInputs()` then return `getInputShapedType(i)`; otherwise + /// 2. if `i < getNumInputsAndOutputBuffers()` then return the + /// `getOutputBufferType(i - nInputs())`; otherwise + /// 3. return the `i - getNumInputsAndOutputBuffers()` result type. + ShapedType getShapedType(unsigned i) { + if (i < nInputs()) + return getInputShapedType(i); + if (i < getNumInputsAndOutputBuffers()) + return getOutputBufferType(i - nInputs()).template cast(); + return getOutputTensorTypes()[i - getNumInputsAndOutputBuffers()] + .template cast(); } - unsigned getNumLoops() { - return getNumIterators( - cast(this->getOperation()).iterator_types()); + + //==========================================================================// + // Other interface methods. + //==========================================================================// + /// Query whether the op has only buffer inputs and no returns. + bool hasBufferSemantics() { + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(getInputs(), + [](Value v) { return v.getType().isa(); }); } + + //==========================================================================// + // Other static interface methods. + //==========================================================================// static LogicalResult verifyTrait(Operation *op) { - auto nOperands = cast(op).getNumInputsAndOutputs(); + auto nOperands = cast(op).getNumInputsAndOutputBuffers(); if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nOperands))) return failure(); return success(); diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 2bd4a0d..f559ba4 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -119,7 +119,7 @@ Optional fuseProducerOf(OpBuilder &b, LinalgOp consumer, template SmallVector getViewSizes(ConcreteOp linalgOp) { SmallVector res; - for (auto v : linalgOp.getInputsAndOutputs()) { + for (auto v : linalgOp.getInputsAndOutputBuffers()) { MemRefType t = v.getType().template cast(); for (unsigned i = 0; i < t.getRank(); ++i) res.push_back(edsc::intrinsics::dim(v, i)); diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 144afa4..109a35b 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -139,7 +139,11 @@ LinalgDependenceGraph::getDependencesInto( } void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { - for (auto srcView : src.getOutputs()) { // W + assert(src.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(dst.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + for (auto srcView : src.getOutputBuffers()) { // W // RAW graph for (auto dstView : dst.getInputs()) { // R if (aliases.alias(srcView, dstView)) { // if alias, fill RAW @@ -149,7 +153,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } // WAW graph - for (auto dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputBuffers()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAW addDependenceElem(DependenceType::WAW, LinalgOpView{src.getOperation(), srcView}, @@ -167,7 +171,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } } // WAR graph - for (auto dstView : dst.getOutputs()) { // W + for (auto dstView : dst.getOutputBuffers()) { // W if (aliases.alias(srcView, dstView)) { // if alias, fill WAR addDependenceElem(DependenceType::WAR, LinalgOpView{src.getOperation(), srcView}, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 384d249..7850dd60 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -112,19 +112,20 @@ template static LogicalResult verifyBlockArgs(GenericOpType op, Block &block); template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (block.getNumArguments() != nViews) - return op.emitOpError( - "expected number of block arguments to match number of views"); + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands) + return op.emitOpError("expected number of block arguments to match number " + "of operands"); - for (unsigned i = 0; i < nViews; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + auto nInputViews = op.getNumInputs(); + for (unsigned i = 0; i < nOperands; ++i) { auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(i).getType()) return op.emitOpError("expected block argument ") - << i << " of the same type as elemental type of " + << (i + 1) << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -132,27 +133,28 @@ template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) { template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (block.getNumArguments() != nViews + nLoops) + auto nOperands = op.getNumOperands(); + if (block.getNumArguments() != nOperands + nLoops) return op.emitOpError( - "expected number of block arguments to match number of views + " + "expected number of block arguments to match number of operands + " "number of loops"); - for (unsigned i = 0; i < nLoops; ++i) { + // Note: the number and type of yield values are checked in the YieldOp. + for (unsigned i = 0; i < nLoops; ++i) if (!block.getArgument(i).getType().isIndex()) return op.emitOpError("expected block argument ") - << i << " to be of IndexType"; - } + << (i + 1) << " to be an index"; - for (unsigned i = 0; i < nViews; ++i) { + for (unsigned i = 0; i < nOperands; ++i) { unsigned memrefArgIndex = i + nLoops; auto viewType = op.getShapedType(i); if (viewType.getElementType() != block.getArgument(memrefArgIndex).getType()) return op.emitOpError("expected block argument ") - << memrefArgIndex << " of the same type as elemental type of " + << (memrefArgIndex + 1) + << " of the same type as elemental type of " << ((i < nInputViews) ? "input " : "output ") - << "view: " << viewType; + << "operand: " << viewType; } return success(); } @@ -160,70 +162,74 @@ template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) { template static LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType); +template +LogicalResult verifyFuncArgsGeneric(GenericOpType op, FunctionType funType) { + auto res = verifyFuncArgs(op, funType); + if (failed(res)) + return res; + + auto nInputs = op.getNumInputs(); + auto nOutputs = op.getNumOutputs(); + // linalg.generic output element types are exactly the function results. + for (unsigned idx = 0; idx < nOutputs; ++idx) { + ShapedType shapedType = op.getShapedType(nInputs + idx); + if (funType.getResult(idx) != shapedType.getElementType()) + return op.emitOpError("expected function result ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of output " << (idx + 1); + } + return success(); +} + template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) { - auto nViews = op.getNumInputsAndOutputs(); - auto nInputViews = op.getNumInputs(); - if (funType.getNumInputs() != nViews) - return op.emitOpError("expected fun arguments to match number of views"); - if (funType.getNumResults() != op.getNumOutputs()) + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands) return op.emitOpError( - "expected fun results to match number of output views"); - - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(idx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << idx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + "expected function arguments to match number of operands"); + if (funType.getNumResults() != op.getNumOutputs()) + return op.emitOpError("expected function results(") + << funType.getNumResults() << ") to match number of outputs(" + << op.getNumOutputs() << ")"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of operand " << (idx + 1); } + return success(); } template <> LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) { auto nLoops = op.getNumLoops(); - auto nInputViews = op.getNumInputs(); auto nOutputs = op.getNumOutputs(); - auto nViews = op.getNumInputsAndOutputs(); - if (funType.getNumInputs() != nViews + nLoops) - return op.emitOpError( - "expected fun arguments to match number of views + number of loops"); + auto nOperands = op.getNumOperands(); + if (funType.getNumInputs() != nOperands + nLoops) + return op.emitOpError("expected function arguments to match number of " + "loops + number of operands"); if (funType.getNumResults() != nOutputs) return op.emitOpError( - "expected fun results to match number of output views"); - for (unsigned i = 0; i < nLoops; ++i) { + "expected function results to match number of outputs"); + for (unsigned i = 0; i < nLoops; ++i) if (!funType.getInput(i).isIndex()) - return op.emitOpError("expected fun argument ") - << i << " to be of IndexType"; - } - for (auto en : llvm::enumerate(op.indexing_maps())) { - auto idx = en.index(); - auto funIdx = nLoops + idx; - auto view = (idx < nInputViews) ? op.getInputShapedType(idx) - : op.getOutputShapedType(idx - nInputViews); - if (funType.getInput(funIdx) != view.getElementType()) - return op.emitOpError("expected fun argument ") - << funIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - - if (idx >= nInputViews) { - auto resultIdx = idx - nInputViews; - if (funType.getResult(resultIdx) != view.getElementType()) - return op.emitOpError("expected fun result ") - << resultIdx << " of the same type as elemental type " - << view.getElementType() << " of view " << idx; - } + return op.emitOpError("expected function argument ") + << (i + 1) << " to be an index"; + + // linalg.generic operands element types are exactly the first function + // arguments. + for (unsigned idx = 0; idx < nOperands; ++idx) { + ShapedType shapedType = op.getShapedType(idx); + if (funType.getInput(idx + nLoops) != shapedType.getElementType()) + return op.emitOpError("expected function argument ") + << (idx + nLoops + 1) << " of the same type as elemental type " + << shapedType.getElementType() << " of input " << (idx + 1); } + return success(); } @@ -231,9 +237,11 @@ template static LogicalResult verifyGenericOp(GenericOpType op) { auto nInputViews = op.getNumInputs(); auto nLoops = op.getNumLoops(); - auto nViews = op.getNumInputsAndOutputs(); - if (nViews != llvm::size(op.views())) - return op.emitOpError("expected exactly ") << nViews << " view operands"; + auto nInputsAndOutputBuffers = op.getNumInputsAndOutputBuffers(); + if (nInputsAndOutputBuffers != llvm::size(op.views())) + return op.emitOpError("expected exactly ") + << nInputsAndOutputBuffers + << " inputs (tensor or buffer) and output buffer operands"; auto ®ion = op.region(); auto funOp = op.getFunction(); @@ -246,8 +254,8 @@ static LogicalResult verifyGenericOp(GenericOpType op) { } else { if (!funOp || !funOp.getType()) return op.emitOpError( - "expected fun attribute to refer to a defined symbol"); - if (failed(verifyFuncArgs(op, funType))) + "expected function attribute to refer to a defined symbol"); + if (failed(verifyFuncArgsGeneric(op, funType))) return failure(); } @@ -287,22 +295,6 @@ static LogicalResult verifyGenericOp(GenericOpType op) { return op.emitOpError("expected the concatenation of maps in indexing_map " "to be invertible"); - auto outputTensorTypes = op.getOutputTensorTypes(); - if (outputTensorTypes.size() != op.getNumResults()) - return op.emitOpError("expected #output tensor operands (") - << outputTensorTypes.size() << ") to match #results (" - << op.getNumResults() << ")"; - - unsigned index = 0; - for (auto it : llvm::zip(op.getResultTypes(), outputTensorTypes)) { - auto resTy = std::get<0>(it); - auto outOpTy = std::get<1>(it); - if (resTy != outOpTy) - return op.emitOpError("result #") - << index << " must be " << outOpTy << ", but got " << resTy; - ++index; - } - return success(); } @@ -731,17 +723,20 @@ static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { template static LogicalResult verifyYield(YieldOp op, GenericOpType genericOp) { // The operand number and types must match the view element types. - auto nOutputViews = genericOp.getNumOutputs(); - if (op.getNumOperands() != nOutputViews) - return op.emitOpError("expected ") - << nOutputViews << " operand to match enclosing linalg.generic op"; + auto nOutputs = genericOp.getNumOutputs(); + if (op.getNumOperands() != nOutputs) + return op.emitOpError("expected number of yield values (") + << nOutputs << ") to match the number of operands of the enclosing " + << "linalg.generic op (" << op.getNumOperands() << ")"; - for (unsigned i = 0; i != nOutputViews; ++i) { + for (unsigned i = 0; i != nOutputs; ++i) { auto elementType = genericOp.getOutputShapedType(i).getElementType(); if (op.getOperand(i).getType() != elementType) - return op.emitOpError("type of return operand ") - << i << " (" << op.getOperand(i).getType() - << ") doesn't match view element type (" << elementType << ")"; + return op.emitOpError("type of yield operand ") + << (i + 1) << " (" << op.getOperand(i).getType() + << ") doesn't match " + << "the element type of the enclosing linalg.generic op (" + << elementType << ")"; } return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 043d9c0..6ad73ee 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -67,12 +67,13 @@ static llvm::cl::list clTileSizes( // to the `loopRanges` in order to obtain view ranges. static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, ArrayRef loopRanges) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto maps = loopToOperandRangesMaps(op); SmallVector clonedViews; clonedViews.reserve(op.getNumInputsAndOutputs()); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; @@ -118,10 +119,11 @@ struct ViewDimension { // they must agree by construction (i.e. have the same size) and we just return // the first one. static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto maps = loopToOperandRangesMaps(op); // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - SmallVector ios(op.getInputsAndOutputs()); + SmallVector ios(op.getInputsAndOutputBuffers()); for (auto en : llvm::enumerate(ios)) { unsigned idx = en.index(); auto map = maps[idx]; @@ -144,6 +146,10 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, unsigned consumerIdx, unsigned producerIdx, OperationFolder *folder) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto subView = dyn_cast_or_null( consumer.getInput(consumerIdx).getDefiningOp()); auto slice = @@ -197,6 +203,10 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, // Some of these will be lifted in the future with better analysis. static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, LinalgOp consumer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (producer.getNumOutputs() != 1) { LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); return false; @@ -217,6 +227,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); // Make some simple structural checks that alleviate the need for more // complex analyses. if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { @@ -236,6 +250,10 @@ bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer) { + assert(producer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) return false; // Check for any fusion-preventing dependence to any view read/written that @@ -252,6 +270,8 @@ bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, Optional mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, const LinalgDependenceGraph &graph, OperationFolder *folder) { + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); LLVM_DEBUG(dbgs() << "\nStart examining consumer: " << *consumer.getOperation()); for (auto dependence : graph.getDependencesInto( @@ -268,7 +288,7 @@ Optional mlir::linalg::fuseProducerOf( // Consumer consumes this view, `isStructurallyFusableProducer` also checks // whether it is a strict subview of the producer view. auto producedView = dependence.dependentOpView.view; - auto producerIdx = producer.getIndexOfOutput(producedView).getValue(); + auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); // `consumerIdx` and `producerIdx` exist by construction. LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() << " view: " << producedView @@ -309,7 +329,10 @@ static void fuseLinalgOpsGreedily(FuncOp f) { // Save original Linalg ops, we only want to make a pass over those. SmallVector linalgOps; - f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); + f.walk([&](LinalgOp op) { + if (op.hasBufferSemantics()) + linalgOps.push_back(op); + }); Aliases aliases; LinalgDependenceGraph G(aliases, linalgOps); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp index f5dac8a..2f97b62 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -90,6 +90,8 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, CopyOp copyOp) { + assert(copyOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = copyOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto inputIvs = @@ -98,7 +100,7 @@ public: permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); SmallVector iivs(inputIvs.begin(), inputIvs.end()); SmallVector oivs(outputIvs.begin(), outputIvs.end()); - IndexedValueType O(copyOp.getOutput(0)), I(copyOp.getInput(0)); + IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. // clang-format off @@ -112,11 +114,13 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, FillOp fillOp) { + assert(fillOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); - IndexedValueType O(fillOp.getOutput(0)); + IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. nPar > 0 ? O(ivs) = ValueHandle(fillOp.value()) @@ -128,10 +132,12 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { + assert(dotOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); IndexHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutput(0)); + C(dotOp.getOutputBuffer(0)); // Emit scalar form. C() = C() + A(r_i) * B(r_i); } @@ -142,10 +148,12 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, MatvecOp matvecOp) { + assert(matvecOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); IndexHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutput(0)); + C(matvecOp.getOutputBuffer(0)); // Emit scalar form. C(i) = C(i) + A(i, r_j) * B(r_j); } @@ -156,10 +164,12 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, MatmulOp matmulOp) { + assert(matmulOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), - C(matmulOp.getOutput(0)); + C(matmulOp.getOutputBuffer(0)); // Emit scalar form. C(i, j) = C(i, j) + A(i, r_k) * B(r_k, j); } @@ -169,6 +179,8 @@ template class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { + assert(convOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto maps = loopToOperandRangesMaps(convOp); @@ -219,6 +231,8 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, GenericOp genericOp) { + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -237,7 +251,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); + indexedValues[nInputs + i] = + std_load(genericOp.getOutputBuffer(i), indexing); } auto funcOp = genericOp.getFunction(); @@ -250,7 +265,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + std_store(callOp->getResult(i), genericOp.getOutputBuffer(i), indexing); } return; } @@ -273,8 +288,8 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); - std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), - indexing); + std_store(map.lookup(yieldOp->getOperand(i)), + genericOp.getOutputBuffer(i), indexing); } } }; @@ -314,6 +329,8 @@ class LinalgScopedEmitter { public: static void emitScalarImplementation(ArrayRef allIvs, IndexedGenericOp indexedGenericOp) { + assert(indexedGenericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); using edsc::intrinsics::detail::ValueHandleArray; @@ -339,7 +356,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); indexedValues[nLoops + nInputs + i] = - std_load(indexedGenericOp.getOutput(i), indexing); + std_load(indexedGenericOp.getOutputBuffer(i), indexing); } if (auto funcOp = indexedGenericOp.getFunction()) { @@ -351,7 +368,7 @@ public: for (unsigned i = 0; i < nOutputs; ++i) { ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); - std_store(callOp->getResult(i), indexedGenericOp.getOutput(i), + std_store(callOp->getResult(i), indexedGenericOp.getOutputBuffer(i), indexing); } return; @@ -376,7 +393,7 @@ public: ValueHandleArray indexing(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); std_store(map.lookup(yieldOp->getOperand(i)), - indexedGenericOp.getOutput(i), indexing); + indexedGenericOp.getOutputBuffer(i), indexing); } } }; @@ -404,6 +421,8 @@ LogicalResult LinalgOpToLoopsImpl::doit( // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (which is asserted in the inverse calculation). auto linalgOp = cast(op); + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); auto invertedMap = inversePermutation(concatAffineMaps(loopToOperandRangesMaps(linalgOp))); if (!invertedMap) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp index 9657daf..10c537e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -93,6 +93,8 @@ bool mlir::linalg::detail::isProducedByOpOfTypeImpl( Operation *consumerOp, Value consumedView, function_ref isaOpType) { LinalgOp consumer = dyn_cast(consumerOp); + assert(consumer.hasBufferSemantics() && + "expected linalg op with buffer semantics"); if (!consumer) return false; @@ -171,7 +173,7 @@ mlir::linalg::vectorizeGenericLinalgOpPrecondition(Operation *op) { return false; return true; }; - if (!llvm::all_of(genericOp.getInputsAndOutputs(), + if (!llvm::all_of(genericOp.getInputsAndOutputBuffers(), isStaticMemRefWithIdentityLayout)) return failure(); return success(); @@ -188,6 +190,8 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, "DRR failure case must be a precondition"); auto genericOp = cast(op); + assert(genericOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); edsc::ScopedContext scope(rewriter, op->getLoc()); using edsc::intrinsics::std_load; using edsc::intrinsics::std_store; @@ -195,7 +199,7 @@ mlir::linalg::vectorizeGenericLinalgOp(PatternRewriter &rewriter, using vector_type_cast = edsc::intrinsics::ValueBuilder; auto vA = std_load(vector_type_cast(genericOp.getInput(0))); auto vB = std_load(vector_type_cast(genericOp.getInput(1))); - auto vectorMemRefC = vector_type_cast(genericOp.getOutput(0)); + auto vectorMemRefC = vector_type_cast(genericOp.getOutputBuffer(0)); auto vC = std_load(vectorMemRefC); auto vRes = vector_contract(vA, vB, vC, genericOp.indexing_maps(), genericOp.iterator_types()); @@ -262,7 +266,7 @@ LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(Operation *op) { // Transformation applies to buffers only. if (!linOp || !linOp.hasBufferSemantics()) return failure(); - if (llvm::none_of(linOp.getInputsAndOutputs(), [](Value v) { + if (llvm::none_of(linOp.getInputsAndOutputBuffers(), [](Value v) { return isa_and_nonnull(v.getDefiningOp()); })) return failure(); @@ -279,8 +283,10 @@ mlir::linalg::promoteSubviewsLinalgOp(PatternRewriter &rewriter, "DRR failure case must be a precondition"); LinalgOp linOp = cast(op); + assert(linOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); SetVector subViews; - for (auto it : linOp.getInputsAndOutputs()) + for (auto it : linOp.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index eb60569..3caa8c8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -155,6 +155,8 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, SetVector subViews, bool dynamicBuffers, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // 1. Promote the specified views and use them in the new op. ScopedContext scope(b, op.getLoc()); auto promotedBufferAndViews = promoteSubViews( @@ -164,7 +166,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, SmallVector, 8> writebackViews; writebackViews.reserve(subViews.size()); unsigned promotedIdx = 0; - for (auto view : op.getInputsAndOutputs()) { + for (auto view : op.getInputsAndOutputBuffers()) { if (subViews.count(view) != 0) { opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView); writebackViews.emplace_back(std::make_pair( @@ -187,7 +189,7 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op, // WARNING: MUST use the old op to determine whether the operand view is an // output. bool isOutput = - op.getIndexOfOutput(viewAndPartialLocalView.first).hasValue(); + op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue(); if (isOutput) copy(viewAndPartialLocalView.second, viewAndPartialLocalView.first); } @@ -203,11 +205,14 @@ static void promoteSubViews(FuncOp f, bool dynamicBuffers) { SmallVector toErase; OperationFolder folder(f.getContext()); f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) { + if (!op.hasBufferSemantics()) + return; + // TODO(ntv) some heuristic here to decide what to promote. Atm it is all or // nothing. SetVector subViews; OpBuilder b(op); - for (auto it : op.getInputsAndOutputs()) + for (auto it : op.getInputsAndOutputBuffers()) if (auto sv = dyn_cast_or_null(it.getDefiningOp())) subViews.insert(sv); if (!subViews.empty()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index bcf6576a..ed9553d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -173,6 +173,7 @@ struct TileCheck : public AffineExprVisitor { static void transformIndexedGenericOpIndices( OpBuilder &b, LinalgOp op, ArrayRef pivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto indexedGenericOp = dyn_cast(op.getOperation()); if (!indexedGenericOp) return; @@ -232,6 +233,8 @@ static SmallVector makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, ArrayRef ivs, ArrayRef tileSizes, ArrayRef viewSizes, OperationFolder *folder) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), [](Value v) { return !isZero(v); })) && @@ -254,7 +257,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, SmallVector res; res.reserve(op->getNumOperands()); - auto viewIteratorBegin = linalgOp.getInputsAndOutputs().begin(); + auto viewIteratorBegin = linalgOp.getInputsAndOutputBuffers().begin(); for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { Value view = *(viewIteratorBegin + viewIndex); @@ -309,6 +312,7 @@ Optional mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); // 1. Enforce the convention that "tiling by zero" skips tiling a particular // dimension. This convention is significantly simpler to handle instead of // adjusting affine maps to account for missing dimensions. @@ -383,6 +387,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, ArrayRef tileSizes, Optional mlir::linalg::tileLinalgOp( OpBuilder &b, LinalgOp op, ArrayRef tileSizes, ArrayRef permutation, OperationFolder *folder) { + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); if (tileSizes.empty()) return llvm::None; @@ -419,6 +424,8 @@ static void tileLinalgOps(FuncOp f, ArrayRef tileSizes) { OpBuilder b(f); OperationFolder folder(f.getContext()); f.walk([tileSizes, &b, &folder](LinalgOp op) { + if (!op.hasBufferSemantics()) + return; auto opLoopsPair = tileLinalgOp(b, op, tileSizes, /*permutation=*/{}, &folder); // If tiling occurred successfully, erase old op. diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 748ab05..c7c0752 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -68,7 +68,7 @@ func @generic_at_least_2_operands(%arg0: memref) { // ----- func @generic_exactly_2_views(%arg0: memref) { - // expected-error @+1 {{op expected exactly 2 view operands}} + // expected-error @+1 {{op expected exactly 2 inputs (tensor or buffer) and output buffer operands}} linalg.generic { args_in = 1, args_out = 1, @@ -81,7 +81,7 @@ func @generic_exactly_2_views(%arg0: memref) { // ----- func @generic_undefined_fun(%arg0: memref) { - // expected-error @+1 {{op expected fun attribute to refer to a defined symbol}} + // expected-error @+1 {{op expected function attribute to refer to a defined symbol}} linalg.generic { args_in = 1, args_out = 1, @@ -96,7 +96,7 @@ func @generic_undefined_fun(%arg0: memref) { func @foo() { return } func @generic_mismatched_num_arguments(%arg0: memref) { - // expected-error @+1 {{op expected fun arguments to match number of views}} + // expected-error @+1 {{op expected function arguments to match number of operands}} linalg.generic { args_in = 0, args_out = 1, @@ -111,7 +111,7 @@ func @generic_mismatched_num_arguments(%arg0: memref) { func @foo(%0: i32) { return } func @generic_mismatched_num_returns(%arg0: memref) { - // expected-error @+1 {{op expected fun results to match number of output views}} + // expected-error @+1 {{op expected function results(0) to match number of outputs(1)}} linalg.generic { args_in = 0, args_out = 1, @@ -123,6 +123,36 @@ func @generic_mismatched_num_returns(%arg0: memref) { // ----- +func @foo(%0: i32, %1: i32, %2: i32) { return } + +func @generic_mismatched_num_returns(%0: memref, %1: memref) { + // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of operand 2}} + linalg.generic { + args_in = 3, + args_out = 0, + fun = @foo, + indexing_maps = [ affine_map<() -> (0)> ], + iterator_types = [] + } %0, %1, %1: memref, memref, memref +} + +// ----- + +func @foo(%0: i32, %1: i32, %2: f32) -> i32 { return %1: i32} + +func @generic_mismatched_num_returns(%0: memref, %1: memref) { + // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}} + linalg.generic { + args_in = 2, + args_out = 1, + fun = @foo, + indexing_maps = [ affine_map<() -> (0)> ], + iterator_types = [] + } %0, %0, %1: memref, memref, memref +} + +// ----- + func @foo(%0: i32) -> i32 { return %0: i32 } func @generic_symbol_in_map(%arg0: memref) { @@ -189,7 +219,7 @@ func @foo(%0: i32) -> f32 { } func @generic_fun_arg_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}} + // expected-error @+1 {{op expected function argument 1 of the same type as elemental type 'f32' of operand 1}} linalg.generic { args_in = 0, args_out = 1, @@ -207,7 +237,7 @@ func @foo(%0: f32) -> i4 { } func @generic_fun_result_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}} + // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'f32' of output 1}} linalg.generic { args_in = 0, args_out = 1, @@ -257,7 +287,7 @@ func @generic_empty_region(%arg0: memref) { // ----- func @generic_mismatched_num_arguments(%arg0: memref) { - // expected-error @+1 {{op expected number of block arguments to match number of views}} + // expected-error @+1 {{op expected number of block arguments to match number of operands}} linalg.generic { args_in = 0, args_out = 1, @@ -271,7 +301,7 @@ func @generic_mismatched_num_arguments(%arg0: memref) { // ----- func @generic_block_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref'}} + // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output operand: 'memref'}} linalg.generic { args_in = 0, args_out = 1, @@ -285,7 +315,7 @@ func @generic_block_arg_type(%arg0: memref) { // ----- func @indexed_generic_block_arg_count(%arg0: memref) { - // expected-error @+1 {{op expected number of block arguments to match number of views + number of loops}} + // expected-error @+1 {{op expected number of block arguments to match number of operands + number of loops}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -299,7 +329,7 @@ func @indexed_generic_block_arg_count(%arg0: memref) { // ----- func @indexed_generic_block_induction_var_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected block argument 0 to be of IndexType}} + // expected-error @+1 {{op expected block argument 1 to be an index}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -313,7 +343,7 @@ func @indexed_generic_block_induction_var_arg_type(%arg0: memref) { // ----- func @indexed_generic_block_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected block argument 1 of the same type as elemental type of output view: 'memref'}} + // expected-error @+1 {{op expected block argument 2 of the same type as elemental type of output operand: 'memref'}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -330,7 +360,7 @@ func @foo(%f: f32) -> (f32) { return %f : f32 } func @indexed_generic_fun_arg_count(%arg0: memref) { - // expected-error @+1 {{op expected fun arguments to match number of views + number of loops}} + // expected-error @+1 {{op expected function arguments to match number of loops + number of operands}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -346,7 +376,7 @@ func @foo(%i: i32, %val: f32) -> (f32) { return %val : f32 } func @indexed_generic_fun_induction_var_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected fun argument 0 to be of IndexType}} + // expected-error @+1 {{op expected function argument 1 to be an index}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -362,7 +392,7 @@ func @foo(%i: index, %val: i1) -> (i1) { return %val : i1 } func @indexed_generic_fun_arg_type(%arg0: memref) { - // expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}} + // expected-error @+1 {{op expected function argument 2 of the same type as elemental type 'f32' of input 1}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -378,7 +408,7 @@ func @foo(%i: index, %val: i1) -> (i1, i1) { return %val, %val : i1, i1 } func @indexed_generic_fun_result_count(%arg0: memref) { - // expected-error @+1 {{op expected fun results to match number of output views}} + // expected-error @+1 {{op expected function results to match number of outputs}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -395,7 +425,7 @@ func @foo(%i: index, %val: i32) -> (f32) { return %val_float : f32 } func @indexed_generic_fun_result_count(%arg0: memref) { - // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}} + // expected-error @+1 {{op expected function result 1 of the same type as elemental type 'i32' of output 1}} linalg.indexed_generic { args_in = 0, args_out = 1, @@ -408,7 +438,7 @@ func @indexed_generic_fun_result_count(%arg0: memref) { // ----- func @generic_fun_result_0_element_type(%arg0: memref(off + i)>>) { - // expected-error @+9 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}} + // expected-error @+9 {{type of yield operand 1 ('i1') doesn't match the element type of the enclosing linalg.generic op ('f32')}} linalg.generic { args_in = 0, args_out = 1, @@ -438,36 +468,6 @@ func @generic_result_tensor_type(%arg0: memref(off // ----- -func @generic_result_tensor_count(%arg0: memref(off + i)>>) { - // expected-error @+1 {{op expected #output tensor operands (0) to match #results (1)}} - %0 = linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { - ^bb(%i: f32): - linalg.yield %i: f32 - }: memref(off + i)>> -> tensor -} - -// ----- - -func @generic_result_tensor_type(%arg0: tensor) { - // expected-error @+1 {{op result #0 must be 'tensor', but got 'tensor'}} - %0 = linalg.generic { - args_in = 0, - args_out = 1, - indexing_maps = [ affine_map<(i) -> (i)> ], - iterator_types = ["parallel"] - } %arg0 { - ^bb(%i: f32): - linalg.yield %i: f32 - }: tensor -> tensor -} - -// ----- - func @generic_fun_result_0_element_type(%arg0: memref) { // expected-error @+1 {{'linalg.dot' op expected 3 or more operands}} linalg.dot(%arg0, %arg0): memref, memref diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index f9dbab9..75a09eb 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -157,23 +157,23 @@ func @generic_with_tensor_input(%arg0: tensor>, %arg1: memref // CHECK-LABEL: func @generic_with_tensor_input // CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, memref -func @generic_with_tensor_output(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: tensor) -> (tensor) { - %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : memref, offset: ?, strides: [?, 1]>, tensor -> tensor - return %0 : tensor +#trait2 = { + args_in = 2, + args_out = 1, + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel", "parallel"], + fun = @foo, + library_call = "some_external_function_name_1" } -// CHECK-LABEL: func @generic_with_tensor_output -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref, #[[strided2D]]>, tensor -> tensor -// CHECK: return {{.*}} : tensor - func @generic_with_tensor_input_and_output(%arg0: tensor>, %arg1: tensor) -> (tensor) { - %0 = linalg.generic #trait %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor + %0 = linalg.generic #trait2 %arg0, %arg1 {foo = 1} : tensor>, tensor -> tensor return %0 : tensor } // CHECK-LABEL: func @generic_with_tensor_input_and_output -// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, tensor -> tensor +// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], library_call = "some_external_function_name_1"} %{{.*}}, %{{.*}} {foo = 1 : i64}: tensor>, tensor -> tensor // CHECK: return {{.*}} : tensor -#trait2 = { +#trait3 = { args_in = 1, args_out = 1, indexing_maps = #accesses, @@ -181,7 +181,7 @@ func @generic_with_tensor_input_and_output(%arg0: tensor>, %a library_call = "some_external_function_name_2" } func @generic_region(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.generic #trait2 %arg0, %arg1 { + linalg.generic #trait3 %arg0, %arg1 { ^bb(%a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 } {foo = 1}: memref, offset: ?, strides: [?, 1]>, memref @@ -194,7 +194,7 @@ func @generic_region(%arg0: memref, offset: ?, strides: [?, 1 // CHECK: } {foo = 1 : i64}: memref, #[[strided2D]]>, memref func @indexed_generic(%arg0: memref, offset: ?, strides: [?, 1]>, %arg1: memref) { - linalg.indexed_generic #trait2 %arg0, %arg1 { + linalg.indexed_generic #trait3 %arg0, %arg1 { ^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 } {foo = 1}: memref, offset: ?, strides: [?, 1]>, memref -- 2.7.4