From d9cbefc4c738dc976b6b20fb1eb82a55d78dd801 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 23 Aug 2022 12:34:58 +0200 Subject: [PATCH] [mlir] Extract DestinationStyleOpInterface from LinalgStructuredInterface. There are several use cases where a destination style operation needs an interface that contains a subset of the methods from LinalgStructuredInterface. In this change, we move all such methods to a new interface, and add forwarding methods to LinalgStructuredInterface to make the change the less invasive. It may be possible to refactor the code later to get rid of (some or all) of the forwarding methods. This change also removes the cloneWithMapper interface methods, as it is not used anywhere. RFC: https://discourse.llvm.org/t/rfc-interface-for-destination-style-ops/64056 Differential Revision: https://reviews.llvm.org/D132125 --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 3 + .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 1187 +++++++++++--------- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 1 + mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 135 +-- mlir/test/lib/Dialect/Test/TestOps.td | 6 +- 5 files changed, 753 insertions(+), 579 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index 8397293..39ca855 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -57,6 +57,9 @@ LogicalResult verifyFillInterface(Operation *op); /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); +/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface +LogicalResult verifyDestinationStyleOpInterface(Operation *op); + } // namespace detail } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 1a404b3..8bcde8a 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -286,736 +286,986 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, //===------------------------------------------------------------------===// - // Num input/output arguments handling. + // Input and Output arguments handling. //===------------------------------------------------------------------===// - // `inputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the input shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"inputs", - /*args=*/(ins) - >, - // These special methods rely on `inputs` and `outputs` being defined by - // each op that wants to implement the LinalgStructuredInterface. InterfaceMethod< /*desc=*/[{ - Return the number of inputs. + Return true if the payload uses the value loaded from `opOperand`. This + is useful to avoid loading from "write-only" memory that may be + uninitialized, as well as properly cloning "read-write" operands. }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"payloadUsesValueFromOperand", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getInputs().size(); + unsigned bbArgNumber = opOperand->getOperandNumber(); + // Init tensors have uses. + return !getBlock()->getArgument(bbArgNumber).use_empty(); }] >, - // `outputs` must be defined by each op that wants to implement the - // LinalgStructuredInterface. - InterfaceMethod< - /*desc=*/[{ - Return the output shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"outputs", - /*args=*/(ins) - >, InterfaceMethod< /*desc=*/[{ - Return the number of outputs. + Return true if `opOperand` is an init tensor. This is true when it is + an output tensor operand whose value is used in the payload region. }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isInitTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.outputs().size(); + if (!$_op.isOutputTensor(opOperand)) + return false; + return payloadUsesValueFromOperand(opOperand); }] >, InterfaceMethod< /*desc=*/[{ - Return the number of inputs and outputs. + Return the `opOperand` rank or zero for scalars. }], /*retTy=*/"int64_t", - /*methodName=*/"getNumInputsAndOutputs", - /*args=*/(ins), + /*methodName=*/"getRank", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumOperands(); + assert(opOperand->getOwner() == this->getOperation()); + if (auto shapedType = + opOperand->get().getType().template dyn_cast()) + return shapedType.getRank(); + return 0; }] >, - //===------------------------------------------------------------------===// - // Input operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the input operands. + Return the output block arguments of the region. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", + /*retTy=*/"Block::BlockArgListType", + /*methodName=*/"getRegionOutputArgs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numInputs = getNumInputs(); - OpOperandVector result; - result.reserve(numInputs); - llvm::transform( - this->getOperation()->getOpOperands().take_front(numInputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with StructuredOpInterface must + // implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + return getBlock()->getArguments().take_back( + cast(*this->getOperation()) + .getNumOutputs()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `i`-th input operand. + Return the `opOperand` shape or an empty vector for scalars. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), + /*retTy=*/"ArrayRef", + /*methodName=*/"getShape", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - return &this->getOperation()->getOpOperand(i); + assert(opOperand->getOwner() == this->getOperation()); + if (auto shapedType = + opOperand->get().getType().template dyn_cast()) + return shapedType.getShape(); + return {}; }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of input operands that are of buffer type. + Return the block argument for an `opOperand`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputBufferOperands", - /*args=*/(ins), + /*retTy=*/"BlockArgument", + /*methodName=*/"getTiedBlockArgument", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + assert(opOperand->getOwner() == this->getOperation()); + return getBlock()->getArgument(opOperand->getOperandNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of input operands that are of tensor type. + Return the operand for a `blockArgument`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputTensorOperands", - /*args=*/(ins), + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedOpOperand", + /*args=*/(ins "BlockArgument":$blockArgument), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + assert(blockArgument.getOwner() == getBlock()); + return &this->getOperation()->getOpOperand( + blockArgument.getArgNumber()); }] >, - //===------------------------------------------------------------------===// - // Output operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the output operands. + Return the input or output indexing map for `opOperand`. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", - /*args=*/(ins), + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMap", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numOutputs = getNumOutputs(); - OpOperandVector result; - result.reserve(numOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .take_back(numOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + assert(opOperand->getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + return *(indexingMaps.begin() + opOperand->getOperandNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Return the `i`-th output operand. + Return the indexing map for a `result`. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getOutputOperand", - /*args=*/(ins "int64_t":$i), + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMapForResult", + /*args=*/(ins "OpResult":$result), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - return &this->getOperation()->getOpOperand(getNumInputs() + i); + assert(result.getOwner() == this->getOperation()); + auto indexingMaps = + $_op.getIndexingMaps().template getAsValueRange(); + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with StructuredOpInterface must + // implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + return *(indexingMaps.begin() + + cast(*this->getOperation()) + .getNumInputs() + + result.getResultNumber()); }] >, InterfaceMethod< /*desc=*/[{ - Set the `i`-th output operand. + Return the value yielded by the region corresponding to an output + `opOperand`. }], - /*retTy=*/"void", - /*methodName=*/"setOutputOperand", - /*args=*/(ins "int64_t":$i, "Value":$value), + /*retTy=*/"OpOperand *", + /*methodName=*/"getTiedYieldValue", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - this->getOperation()->setOperand(getNumInputs() + i, value); + assert(opOperand->getOwner() == this->getOperation()); + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with StructuredOpInterface must + // implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + int64_t resultIndex = + opOperand->getOperandNumber() - + cast(*this->getOperation()) + .getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < this->getOperation()->getNumResults()); + Operation *yieldOp = getBlock()->getTerminator(); + return &yieldOp->getOpOperand(resultIndex); }] >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the subset of output operands that are of buffer type. + Return the single block constituting the body of the operation by + calling the getBody method on the concrete operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputBufferOperands", + /*retTy=*/"Block*", + /*methodName=*/"getBlock", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + // Assume the concrete operation implements the + // SingleBlockImplicitTerminator trait. + return $_op.getBody(); }] >, InterfaceMethod< /*desc=*/[{ - Return the subset of output operands that are of tensor type. + Return the iterator types attribute within the current operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputTensorOperands", + /*retTy=*/"ArrayAttr", + /*methodName=*/"iterator_types", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + return $_op.iterator_types(); }] >, InterfaceMethod< /*desc=*/[{ - Return the types of the subset of output operands that are of buffer type. + Return true if the indexing map is depending on the current op instance. + This means that the indexing map is dynamically synthesized by using the + op instance's concrete attributes, instead of being static for all + instances of the same op kind. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputBufferTypes", + /*retTy=*/"bool", + /*methodName=*/"hasDynamicIndexingMaps", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputBufferOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] + /*defaultImplementation=*/[{ return false; }] >, InterfaceMethod< /*desc=*/[{ - Return the types of the subset of output operands that are of tensor type. + Verify all attributes used by indexing maps are valid. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputTensorTypes", + /*retTy=*/"LogicalResult", + /*methodName=*/"verifyIndexingMapRequiredAttributes", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputTensorOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; - }] + /*defaultImplementation=*/[{ return success(); }] >, - //===------------------------------------------------------------------===// - // Input and Output arguments handling. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the range over input and output operands. + Return the indexing maps attribute within the current operation. }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputAndOutputOperands", + /*retTy=*/"ArrayAttr", + /*methodName=*/"getIndexingMaps" + >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing maps within the current operation. + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getIndexingMapsArray", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - OpOperandVector result; - result.reserve(numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands(), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; + auto range = $_op.getIndexingMaps() + .template getAsValueRange(); + return {range.begin(), range.end()}; }] >, InterfaceMethod< /*desc=*/[{ - Return true if the payload uses the value loaded from `opOperand`. This - is useful to avoid loading from "write-only" memory that may be - uninitialized, as well as properly cloning "read-write" operands. + Return true if any of the operands has a dynamic shape. }], /*retTy=*/"bool", - /*methodName=*/"payloadUsesValueFromOperand", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"hasDynamicShape", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - unsigned bbArgNumber = opOperand->getOperandNumber(); - // Init tensors have uses. - return !getBlock()->getArgument(bbArgNumber).use_empty(); + return llvm::any_of(getStaticShape(), ShapedType::isDynamic); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an input tensor. + Return the name registered for this op when lowering to an external + library call. }], - /*retTy=*/"bool", - /*methodName=*/"isInputTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"std::string", + /*methodName=*/"getLibraryCallName", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() < $_op.getNumInputs()) - return true; - return false; + return $_op.getLibraryCallName(); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an output tensor. + Return whether the op accesses the iteration indices. }], /*retTy=*/"bool", - /*methodName=*/"isOutputTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"hasIndexSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + //===------------------------------------------------------------------===// + // Linalg generalization hooks. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Hook to provide a custom AffineMap used to compute all the operand + subshapes given loop bounds. This is used to answer the question: "given + an iteration space over the codomain, what are the subshapes of the + operands involved in the computation". + The default behavior is to just concatenate all the indexing maps. + A custom AffineMap allows providing a map that can be used to + compute subshapes even in cases where the concatenation of indexing maps + (i.e. the data traversal order) is not a simple permutation of the loop + traversal order. It is then possible to define ops with skewed data + traversal order for which we can still easily compute hyperrectangular + loop bounds and subviews. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getLoopsToShapesMap", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() >= $_op.getNumInputs()) - return true; - return false; + auto maps = $_op.getIndexingMapsArray(); + return concatAffineMaps(maps); }] >, InterfaceMethod< /*desc=*/[{ - Return true if `opOperand` is an init tensor. This is true when it is - an output tensor operand whose value is used in the payload region. + Hook to provide a custom AffineMap used to construct the + hyperrectangular loop iteration space given all the operand subshapes. + This is used to answer the question: + "Given a list of operand ranges, what is the subportion of the iteration + space involved in the computation". + This is the inverse problem of `getLoopsToShapesMap`. + Return the empty AffineMap when such an AffineMap cannot be constructed. + The default behavior is based on a very simple inference procedure that + only works with permutation affine maps. + A more advanced Tensor-Comprehension like inference is possible but has + proven to be ambiguous in unfavorable case. + A safer and more robust alternative is to allow each op to define + its own AffineMap. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getShapesToLoopsMap", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return inversePermutation(getLoopsToShapesMap()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Checks if the given operands can be dropped, and the remaining + operands can still compute the bounds of the op. }], /*retTy=*/"bool", - /*methodName=*/"isInitTensor", - /*args=*/(ins "OpOperand *":$opOperand), + /*methodName=*/"canOpOperandsBeDropped", + /*args=*/(ins "ArrayRef":$droppedOperands), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!$_op.isOutputTensor(opOperand)) - return false; - return payloadUsesValueFromOperand(opOperand); + return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); }] >, InterfaceMethod< /*desc=*/[{ - Return the `opOperand` rank or zero for scalars. + Like `getShape`, but only returns statically-known information, without + generating any new IR. For each shape dimension, returns >=0 if that + dimension is statically known, or ShapeType::kDynamicSize otherwise. }], - /*retTy=*/"int64_t", - /*methodName=*/"getRank", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticShape", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getRank(); - return 0; + SmallVector res; + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with StructuredOpInterface must + // implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + auto iface = cast(*this->getOperation()); + for (OpOperand *opOperand : iface.getInputAndOutputOperands()) + llvm::append_range(res, getShape(opOperand)); + return res; }] >, InterfaceMethod< /*desc=*/[{ - Return the output block arguments of the region. + Returns the statically-known loop ranges. Composes + `getShapesToLoopsMap()` with the result of `getStaticShape`. + Returns ShapeType::kDynamicSize for non-statically-known loop ranges. + This is expected to be called by a valid Linalg op + }], + /*retTy=*/"SmallVector", + /*methodName=*/"getStaticLoopRanges", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + SmallVector viewSizes = getStaticShape(); + AffineMap invertedMap = getShapesToLoopsMap(); + assert(invertedMap && "expected a valid Linalg op to call the method"); + return invertedMap.compose(viewSizes); + }] + >, + //===------------------------------------------------------------------===// + // Other static interface methods. + //===------------------------------------------------------------------===// + StaticInterfaceMethod< + /*desc=*/[{ + Returns the region builder for constructing the body for linalg.generic. + Returns a null function if this named op does not define a region + builder. + }], + /*retTy=*/"std::function)>", + /*methodName=*/"getRegionBuilder", + (ins), + [{ return ConcreteOp::getRegionBuilder(); }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if all the indexing maps are projected permutations. + Otherwise return false. + }], + /*retTy=*/"bool", + /*methodName=*/"hasOnlyProjectedPermutations", + (ins), + [{ + return llvm::all_of($_op.getIndexingMapsArray(), + [](AffineMap map) { return map.isProjectedPermutation(); }); + }] + > + ]; + + let extraClassDeclaration = [{ + /// Return the flat list of all operand dimension sizes in the order they + /// appear in the operands. + SmallVector createFlatListOfOperandDims(OpBuilder &, Location); + + /// Return the flat list of all operands' static dimension sizes in the + /// order they appear in the operands. All operand dimension sizes have to + /// be statically known. + SmallVector createFlatListOfOperandStaticDims(); + + /// Create the loop ranges to materialize the computation over the current + /// operands. This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandDims`. + SmallVector createLoopRanges(OpBuilder &b, Location loc); + + /// Compute the static loop sizes necessary to vectorize the computation. + /// This is done by applying `getShapesToLoopsMap` to + /// `createFlatListOfOperandStaticDims`. + SmallVector computeStaticLoopSizes(); + + /// Returns the value that expresses the shape of the output in terms of + /// shape of the input operands where possible + LogicalResult reifyResultShapes(OpBuilder &b, + ReifiedRankedShapedTypeDims &reifiedReturnShapes); + + // TODO: Remove once prefixing is flipped. + ArrayAttr getIteratorTypes() { return iterator_types(); } + + //========================================================================// + // Forwarding functions to access interface methods from the + // DestinationStyleOpInterface. + // MLIR currently does not support dependent interfaces or interface + // inheritance. By construction all ops with StructuredOpInterface must + // implement DestinationStyleOpInterface. + // TODO: reevalute the need for a cast when a better mechanism exists. + //========================================================================// + + ValueRange inputs() { + return cast(*this->getOperation()).inputs(); + } + + int64_t getNumInputs() { + return cast(*this->getOperation()) + .getNumInputs(); + } + + ValueRange outputs() { + return cast(*this->getOperation()).outputs(); + } + + int64_t getNumOutputs() { + return cast(*this->getOperation()) + .getNumOutputs(); + } + + int64_t getNumInputsAndOutputs() { + return cast(*this->getOperation()) + .getNumInputsAndOutputs(); + } + + OpOperandVector getInputOperands() { + return cast(*this->getOperation()) + .getInputOperands(); + } + + OpOperand *getInputOperand(int64_t i) { + return cast(*this->getOperation()) + .getInputOperand(i); + } + + OpOperandVector getInputBufferOperands() { + return cast(*this->getOperation()) + .getInputBufferOperands(); + } + + OpOperandVector getInputTensorOperands() { + return cast(*this->getOperation()) + .getInputTensorOperands(); + } + + OpOperandVector getOutputOperands() { + return cast(*this->getOperation()) + .getOutputOperands(); + } + + OpOperand *getOutputOperand(int64_t i) { + return cast(*this->getOperation()) + .getOutputOperand(i); + } + + void setOutputOperand(int64_t i, Value value) { + return cast(*this->getOperation()) + .setOutputOperand(i, value); + } + + OpOperandVector getOutputBufferOperands() { + return cast(*this->getOperation()) + .getOutputBufferOperands(); + } + + OpOperandVector getOutputTensorOperands() { + return cast(*this->getOperation()) + .getOutputTensorOperands(); + } + + SmallVector getOutputBufferTypes() { + return cast(*this->getOperation()) + .getOutputBufferTypes(); + } + + SmallVector getOutputTensorTypes() { + return cast(*this->getOperation()) + .getOutputTensorTypes(); + } + + OpOperandVector getInputAndOutputOperands() { + return cast(*this->getOperation()) + .getInputAndOutputOperands(); + } + + bool isInputTensor(OpOperand *opOperand) { + return cast(*this->getOperation()) + .isInputTensor(opOperand); + } + + bool isOutputTensor(OpOperand *opOperand) { + return cast(*this->getOperation()) + .isOutputTensor(opOperand); + } + + bool isScalar(OpOperand *opOperand) { + return cast(*this->getOperation()) + .isScalar(opOperand); + } + + OpResult getTiedOpResult(OpOperand *opOperand) { + return cast(*this->getOperation()) + .getTiedOpResult(opOperand); + } + + bool hasBufferSemantics() { + return cast(*this->getOperation()) + .hasBufferSemantics(); + } + + bool hasTensorSemantics() { + return cast(*this->getOperation()) + .hasTensorSemantics(); + } + + Operation *clone(OpBuilder & b, Location loc, TypeRange resultTypes, + ValueRange operands) { + return cast(*this->getOperation()) + .clone(b, loc, resultTypes, operands); + } + + Operation *cloneWithoutRegions(OpBuilder & b, Location loc, + TypeRange resultTypes, ValueRange operands) { + return cast(*this->getOperation()) + .cloneWithoutRegions(b, loc, resultTypes, operands); + } + + //========================================================================// + // Helper functions to mutate the `operand_segment_sizes` attribute. + // These are useful when cloning and changing operand types. + //========================================================================// + void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } + void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } + + private: + void setOperandSegmentAt(unsigned idx, unsigned val) { + auto attr = (*this)->getAttr("operand_segment_sizes") + .cast(); + unsigned i = 0; + auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), + [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); + getOperation()->setAttr("operand_segment_sizes", newAttr); + } + }]; + + let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; + let verifyWithRegions = 1; +} + +// The 'DestinationStyleOpInterface' provides access to the methods relevant +// for destination-style ops. A destination-style operation has 'n' input +// arguments and 'm' output arguments. Each op that wants to implement +// DestinationStyleOpInterface needs to define inputs() and outputs() methods. +def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { + let cppNamespace = "::mlir::linalg"; + let methods = [ + //===------------------------------------------------------------------===// + // Num input/output arguments handling. + //===------------------------------------------------------------------===// + // `inputs` must be defined by each op that wants to implement the + // DestinationStyleOpInterface. + InterfaceMethod< + /*desc=*/[{ + Return the input shape operands. }], - /*retTy=*/"Block::BlockArgListType", - /*methodName=*/"getRegionOutputArgs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return getBlock()->getArguments().take_back(this->getNumOutputs()); - }] + /*retTy=*/"ValueRange", + /*methodName=*/"inputs", + /*args=*/(ins) >, + // These special methods rely on `inputs` and `outputs` being defined by + // each op that wants to implement the DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ - Return the `opOperand` shape or an empty vector for scalars. + Return the number of inputs. }], - /*retTy=*/"ArrayRef", - /*methodName=*/"getShape", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - if (auto shapedType = - opOperand->get().getType().template dyn_cast()) - return shapedType.getShape(); - return {}; + return $_op.getInputs().size(); }] >, + // `outputs` must be defined by each op that wants to implement the + // DestinationStyleOpInterface. InterfaceMethod< /*desc=*/[{ - Return true if the `opOperand` is a scalar value. + Return the output shape operands. }], - /*retTy=*/"bool", - /*methodName=*/"isScalar", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return !opOperand->get().getType().template isa(); - }] + /*retTy=*/"ValueRange", + /*methodName=*/"outputs", + /*args=*/(ins) >, InterfaceMethod< /*desc=*/[{ - Return the block argument for an `opOperand`. + Return the number of outputs. }], - /*retTy=*/"BlockArgument", - /*methodName=*/"getTiedBlockArgument", - /*args=*/(ins "OpOperand *":$opOperand), + /*retTy=*/"int64_t", + /*methodName=*/"getNumOutputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return getBlock()->getArgument(opOperand->getOperandNumber()); + return $_op.outputs().size(); }] >, InterfaceMethod< /*desc=*/[{ - Return the operand for a `blockArgument`. + Return the number of inputs and outputs. }], - /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedOpOperand", - /*args=*/(ins "BlockArgument":$blockArgument), + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputsAndOutputs", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(blockArgument.getOwner() == getBlock()); - return &this->getOperation()->getOpOperand( - blockArgument.getArgNumber()); + return this->getOperation()->getNumOperands(); }] >, + //===------------------------------------------------------------------===// + // Input operands handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the input or output indexing map for `opOperand`. + Return the input operands. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMap", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + opOperand->getOperandNumber()); + int64_t numInputs = getNumInputs(); + OpOperandVector result; + result.reserve(numInputs); + llvm::transform( + this->getOperation()->getOpOperands().take_front(numInputs), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing map for a `result`. + Return the `i`-th input operand. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getTiedIndexingMapForResult", - /*args=*/(ins "OpResult":$result), + /*retTy=*/"OpOperand*", + /*methodName=*/"getInputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(result.getOwner() == this->getOperation()); - auto indexingMaps = - $_op.getIndexingMaps().template getAsValueRange(); - return *(indexingMaps.begin() + getNumInputs() + - result.getResultNumber()); + assert(i >= 0 && i < getNumInputs()); + return &this->getOperation()->getOpOperand(i); }] >, InterfaceMethod< /*desc=*/[{ - Return the result tied to `opOperand`. + Return the subset of input operands that are of buffer type. }], - /*retTy=*/"OpResult", - /*methodName=*/"getTiedOpResult", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputBufferOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); - assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults() ); - return this->getOperation()->getResult(resultIndex); + OpOperandVector result; + result.reserve(getNumInputs()); + llvm::copy_if(getInputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the value yielded by the region corresponding to an output - `opOperand`. + Return the subset of input operands that are of tensor type. }], - /*retTy=*/"OpOperand *", - /*methodName=*/"getTiedYieldValue", - /*args=*/(ins "OpOperand*":$opOperand), + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputTensorOperands", + /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); - assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults()); - Operation *yieldOp = getBlock()->getTerminator(); - return &yieldOp->getOpOperand(resultIndex); + OpOperandVector result; + result.reserve(getNumInputs()); + llvm::copy_if(getInputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, //===------------------------------------------------------------------===// - // Other interface methods. + // Output operands handling. //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the single block constituting the body of the operation by - calling the getBody method on the concrete operation. + Return the output operands. }], - /*retTy=*/"Block*", - /*methodName=*/"getBlock", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - // Assume the concrete operation implements the - // SingleBlockImplicitTerminator trait. - return $_op.getBody(); + int64_t numOutputs = getNumOutputs(); + OpOperandVector result; + result.reserve(numOutputs); + llvm::transform( + this->getOperation()->getOpOperands() + .take_back(numOutputs), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return the iterator types attribute within the current operation. + Return the `i`-th output operand. }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"iterator_types", - /*args=*/(ins), + /*retTy=*/"OpOperand*", + /*methodName=*/"getOutputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.iterator_types(); + assert(i >= 0 && i < getNumOutputs()); + return &this->getOperation()->getOpOperand(getNumInputs() + i); }] >, InterfaceMethod< /*desc=*/[{ - Return true if the indexing map is depending on the current op instance. - This means that the indexing map is dynamically synthesized by using the - op instance's concrete attributes, instead of being static for all - instances of the same op kind. - }], - /*retTy=*/"bool", - /*methodName=*/"hasDynamicIndexingMaps", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ return false; }] - >, - InterfaceMethod< - /*desc=*/[{ - Verify all attributes used by indexing maps are valid. + Set the `i`-th output operand. }], - /*retTy=*/"LogicalResult", - /*methodName=*/"verifyIndexingMapRequiredAttributes", - /*args=*/(ins), + /*retTy=*/"void", + /*methodName=*/"setOutputOperand", + /*args=*/(ins "int64_t":$i, "Value":$value), /*methodBody=*/"", - /*defaultImplementation=*/[{ return success(); }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the indexing maps attribute within the current operation. - }], - /*retTy=*/"ArrayAttr", - /*methodName=*/"getIndexingMaps" + /*defaultImplementation=*/[{ + assert(i >= 0 && i < getNumOutputs()); + this->getOperation()->setOperand(getNumInputs() + i, value); + }] >, InterfaceMethod< /*desc=*/[{ - Return the indexing maps within the current operation. + Return the subset of output operands that are of buffer type. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getIndexingMapsArray", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputBufferOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto range = $_op.getIndexingMaps() - .template getAsValueRange(); - return {range.begin(), range.end()}; + OpOperandVector result; + result.reserve(getNumOutputs()); + llvm::copy_if(getOutputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return true if any of the operands has a dynamic shape. + Return the subset of output operands that are of tensor type. }], - /*retTy=*/"bool", - /*methodName=*/"hasDynamicShape", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputTensorOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::any_of(getStaticShape(), ShapedType::isDynamic); + OpOperandVector result; + result.reserve(getNumOutputs()); + llvm::copy_if(getOutputOperands(), + std::back_inserter(result), + [](OpOperand *opOperand) { + return opOperand->get().getType().template isa(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op has only MemRef input and outputs. + Return the types of the subset of output operands that are of buffer type. }], - /*retTy=*/"bool", - /*methodName=*/"hasBufferSemantics", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputBufferTypes", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(this->getOperation()->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); + SmallVector result; + result.reserve(getNumOutputs()); + llvm::transform(getOutputBufferOperands(), + std::back_inserter(result), + [](OpOperand *opOperands) { + return opOperands->get().getType().cast(); + }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op has only RankedTensor input and outputs. + Return the types of the subset of output operands that are of tensor type. }], - /*retTy=*/"bool", - /*methodName=*/"hasTensorSemantics", + /*retTy=*/"SmallVector", + /*methodName=*/"getOutputTensorTypes", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::all_of(this->getOperation()->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); + SmallVector result; + result.reserve(getNumOutputs()); + llvm::transform(getOutputTensorOperands(), + std::back_inserter(result), + [](OpOperand *opOperands) { + return opOperands->get().getType().cast(); }); + return result; }] >, + //===------------------------------------------------------------------===// + // Input and Output arguments handling. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Return the name registered for this op when lowering to an external - library call. + Return the range over input and output operands. }], - /*retTy=*/"std::string", - /*methodName=*/"getLibraryCallName", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputAndOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getLibraryCallName(); + int64_t numInputsAndOutputs = getNumInputsAndOutputs(); + OpOperandVector result; + result.reserve(numInputsAndOutputs); + llvm::transform( + this->getOperation()->getOpOperands(), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, InterfaceMethod< /*desc=*/[{ - Return whether the op accesses the iteration indices. + Return true if `opOperand` is an input tensor. }], /*retTy=*/"bool", - /*methodName=*/"hasIndexSemantics", - /*args=*/(ins), + /*methodName=*/"isInputTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", - /*defaultImplementation=*/"" + /*defaultImplementation=*/[{ + if (!opOperand->get().getType().template isa()) + return false; + if (opOperand->getOperandNumber() < $_op.getNumInputs()) + return true; + return false; + }] >, - //===------------------------------------------------------------------===// - // Linalg generalization hooks. - //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Hook to provide a custom AffineMap used to compute all the operand - subshapes given loop bounds. This is used to answer the question: "given - an iteration space over the codomain, what are the subshapes of the - operands involved in the computation". - The default behavior is to just concatenate all the indexing maps. - A custom AffineMap allows providing a map that can be used to - compute subshapes even in cases where the concatenation of indexing maps - (i.e. the data traversal order) is not a simple permutation of the loop - traversal order. It is then possible to define ops with skewed data - traversal order for which we can still easily compute hyperrectangular - loop bounds and subviews. + Return true if `opOperand` is an output tensor. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getLoopsToShapesMap", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isOutputTensor", + /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto maps = $_op.getIndexingMapsArray(); - return concatAffineMaps(maps); + if (!opOperand->get().getType().template isa()) + return false; + if (opOperand->getOperandNumber() >= $_op.getNumInputs()) + return true; + return false; }] >, InterfaceMethod< /*desc=*/[{ - Hook to provide a custom AffineMap used to construct the - hyperrectangular loop iteration space given all the operand subshapes. - This is used to answer the question: - "Given a list of operand ranges, what is the subportion of the iteration - space involved in the computation". - This is the inverse problem of `getLoopsToShapesMap`. - Return the empty AffineMap when such an AffineMap cannot be constructed. - The default behavior is based on a very simple inference procedure that - only works with permutation affine maps. - A more advanced Tensor-Comprehension like inference is possible but has - proven to be ambiguous in unfavorable case. - A safer and more robust alternative is to allow each op to define - its own AffineMap. + Return true if the `opOperand` is a scalar value. }], - /*retTy=*/"AffineMap", - /*methodName=*/"getShapesToLoopsMap", - /*args=*/(ins), + /*retTy=*/"bool", + /*methodName=*/"isScalar", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return inversePermutation(getLoopsToShapesMap()); + assert(opOperand->getOwner() == this->getOperation()); + return !opOperand->get().getType().template isa(); }] >, InterfaceMethod< /*desc=*/[{ - Checks if the given operands can be dropped, and the remaining - operands can still compute the bounds of the op. + Return the result tied to `opOperand`. }], - /*retTy=*/"bool", - /*methodName=*/"canOpOperandsBeDropped", - /*args=*/(ins "ArrayRef":$droppedOperands), + /*retTy=*/"OpResult", + /*methodName=*/"getTiedOpResult", + /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - return detail::canOpOperandsBeDroppedImpl($_op, droppedOperands); + assert(opOperand->getOwner() == this->getOperation()); + int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); + assert(resultIndex >= 0 && + resultIndex < this->getOperation()->getNumResults() ); + return this->getOperation()->getResult(resultIndex); }] >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ - Like `getShape`, but only returns statically-known information, without - generating any new IR. For each shape dimension, returns >=0 if that - dimension is statically known, or ShapeType::kDynamicSize otherwise. + Return whether the op has only MemRef input and outputs. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticShape", + /*retTy=*/"bool", + /*methodName=*/"hasBufferSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector res; - for (OpOperand *opOperand : getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); - return res; + return this->getOperation()->getNumResults() == 0 && + llvm::all_of(this->getOperation()->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); }] >, InterfaceMethod< /*desc=*/[{ - Returns the statically-known loop ranges. Composes - `getShapesToLoopsMap()` with the result of `getStaticShape`. - Returns ShapeType::kDynamicSize for non-statically-known loop ranges. - This is expected to be called by a valid Linalg op + Return whether the op has only RankedTensor input and outputs. }], - /*retTy=*/"SmallVector", - /*methodName=*/"getStaticLoopRanges", + /*retTy=*/"bool", + /*methodName=*/"hasTensorSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector viewSizes = getStaticShape(); - AffineMap invertedMap = getShapesToLoopsMap(); - assert(invertedMap && "expected a valid Linalg op to call the method"); - return invertedMap.compose(viewSizes); + return llvm::all_of(this->getOperation()->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); }] >, //===------------------------------------------------------------------===// @@ -1045,27 +1295,6 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { InterfaceMethod< /*desc=*/[{ Clone the current operation with the given location, operands - and BlockAndValueMapping. This is used to abstract away the - optional underlying region creation. This does not change the - balance between input, output_buffer and init_tensors - operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"cloneWithMapper", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands, "BlockAndValueMapping &":$bvm), - [{ - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.create(state); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location, operands and BlockAndValueMapping but leave the regions empty. This is used to abstract away the optional underlying region creation. This does not change the balance between input, output_buffer @@ -1083,80 +1312,10 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { state.addRegion(); return b.create(state); }] - >, - StaticInterfaceMethod< - /*desc=*/[{ - Returns the region builder for constructing the body for linalg.generic. - Returns a null function if this named op does not define a region - builder. - }], - /*retTy=*/"std::function)>", - /*methodName=*/"getRegionBuilder", - (ins), - [{ return ConcreteOp::getRegionBuilder(); }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if all the indexing maps are projected permutations. - Otherwise return false. - }], - /*retTy=*/"bool", - /*methodName=*/"hasOnlyProjectedPermutations", - (ins), - [{ - return llvm::all_of($_op.getIndexingMapsArray(), - [](AffineMap map) { return map.isProjectedPermutation(); }); - }] > ]; - let extraClassDeclaration = [{ - /// Return the flat list of all operand dimension sizes in the order they - /// appear in the operands. - SmallVector createFlatListOfOperandDims(OpBuilder &, Location); - - /// Return the flat list of all operands' static dimension sizes in the - /// order they appear in the operands. All operand dimension sizes have to - /// be statically known. - SmallVector createFlatListOfOperandStaticDims(); - - /// Create the loop ranges to materialize the computation over the current - /// operands. This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandDims`. - SmallVector createLoopRanges(OpBuilder &b, Location loc); - - /// Compute the static loop sizes necessary to vectorize the computation. - /// This is done by applying `getShapesToLoopsMap` to - /// `createFlatListOfOperandStaticDims`. - SmallVector computeStaticLoopSizes(); - - /// Returns the value that expresses the shape of the output in terms of - /// shape of the input operands where possible - LogicalResult reifyResultShapes(OpBuilder &b, - ReifiedRankedShapedTypeDims &reifiedReturnShapes); - - // TODO: Remove once prefixing is flipped. - ArrayAttr getIteratorTypes() { return iterator_types(); } - - //========================================================================// - // Helper functions to mutate the `operand_segment_sizes` attribute. - // These are useful when cloning and changing operand types. - //========================================================================// - void setNumInputs(unsigned num) { setOperandSegmentAt(0, num); } - void setNumOutputBuffers(unsigned num) { setOperandSegmentAt(1, num); } - - private: - void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); - unsigned i = 0; - auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), - [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); - getOperation()->setAttr("operand_segment_sizes", newAttr); - } - }]; - - let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; + let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }]; let verifyWithRegions = 1; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 3a5b87a..02df051 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -28,6 +28,7 @@ class LinalgStructuredBase_Op props> : Op, DeclareOpInterfaceMethods, + DestinationStyleOpInterface, LinalgStructuredInterface, RegionBranchOpInterface, ReifyRankedShapedTypeOpInterface], props)> { diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 3b63824..b940e41 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -633,22 +633,6 @@ LinalgOp::reifyResultShapes(OpBuilder &b, LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { LinalgOp linalgOp = cast(op); - // Expect at least one output operand. - // This means an op that constructs a tensor out of indices cannot be a - // LinalgOp at the moment. For now this will have to be a special op until we - // have output shape operands that are not tensors. - int64_t numInputs = linalgOp.getNumInputs(); - int64_t numOutputs = linalgOp.getNumOutputs(); - if (numOutputs == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != linalgOp.getOutputTensorOperands().size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() - << ") to be equal to the number of output tensors (" - << linalgOp.getOutputTensorOperands().size() << ")"; // Check all iterator types are known. auto iteratorTypesRange = @@ -699,26 +683,6 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { SmallVector redDims; linalgOp.getReductionDims(redDims); - // Simplifying assumption: either full tensor or full buffer mode. - // This allows simpler verification of output operands vs result types - // without premature tracking of which operand is what in mixed-mode. - // TODO: relax when mixed-mode needs to pass verification. - if (!linalgOp.getOutputBufferOperands().empty() && - !linalgOp.getOutputTensorOperands().empty()) - return op->emitOpError( - "expected output operands to all have tensor type or " - "all have buffer type"); - - for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) { - OpResult result = linalgOp.getTiedOpResult(opOperand); - if (result.getType() != opOperand->get().getType()) - return op->emitOpError("expected type of operand #") - << opOperand->getOperandNumber() << " (" - << opOperand->get().getType() << ")" - << " to match type of corresponding result (" << result.getType() - << ")"; - } - // Output tensor indexing map may not depend on reduction indices. for (OpOperand *opOperand : linalgOp.getOutputOperands()) { AffineMap indexingMap = linalgOp.getTiedIndexingMap(opOperand); @@ -740,36 +704,9 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { } } - // Check the region has exactly one block. - if (linalgOp->getNumRegions() != 1 || - !llvm::hasSingleElement(linalgOp->getRegion(0))) - return op->emitOpError("expects to have 1 region with 1 block"); - if (!linalgOp.getShapesToLoopsMap()) return op->emitOpError("expected the shape-to-loops map to be non-null"); - // Simplifying assumption: bbargs match 1-1 with shape operands elemental - // types. - // TODO: once ranked shape types are plugged in, we may want to drop the - // corresponding bbargs, that can never be read from. This will be subject to - // consistency discussions (i.e. what to do with output tensors whose bbarg is - // not used). - Block &block = linalgOp->getRegion(0).front(); - - if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments()) - return op->emitOpError("expected as many non-induction variable region " - "arguments as the number of input/output operands"); - - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - Type elementType = getElementTypeOrSelf(opOperand->get()); - Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); - if (elementType != argType) - return op->emitOpError("expected type of bb argument #") - << opOperand->getOperandNumber() << " (" << argType << ")" - << " to match element or self type of the corresponding operand (" - << elementType << ")"; - } - // Check if given shapes match to inferred shapes. SmallVector endLoopRangeValues = linalgOp.getStaticLoopRanges(); SmallVector startLoopRangeValues(endLoopRangeValues.size(), 0); @@ -835,3 +772,75 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { return success(); } + +LogicalResult +mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { + DestinationStyleOpInterface dstStyleOp = + cast(op); + + // Expect at least one output operand. + // This means an op that constructs a tensor out of indices cannot be a + // LinalgOp at the moment. For now this will have to be a special op until we + // have output shape operands that are not tensors. + int64_t numInputs = dstStyleOp.getNumInputs(); + int64_t numOutputs = dstStyleOp.getNumOutputs(); + if (numOutputs == 0) + return op->emitOpError("expected at least one output operand"); + if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) + return failure(); + // Verify the number of results matches the number of output tensors. + if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size()) + return op->emitOpError("expected the number of results (") + << op->getNumResults() + << ") to be equal to the number of output tensors (" + << dstStyleOp.getOutputTensorOperands().size() << ")"; + + // Simplifying assumption: either full tensor or full buffer mode. + // This allows simpler verification of output operands vs result types + // without premature tracking of which operand is what in mixed-mode. + // TODO: relax when mixed-mode needs to pass verification. + if (!dstStyleOp.getOutputBufferOperands().empty() && + !dstStyleOp.getOutputTensorOperands().empty()) + return op->emitOpError( + "expected output operands to all have tensor type or " + "all have buffer type"); + + for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) { + OpResult result = dstStyleOp.getTiedOpResult(opOperand); + if (result.getType() != opOperand->get().getType()) + return op->emitOpError("expected type of operand #") + << opOperand->getOperandNumber() << " (" + << opOperand->get().getType() << ")" + << " to match type of corresponding result (" << result.getType() + << ")"; + } + + // Check the region has exactly one block. + if (dstStyleOp->getNumRegions() != 1 || + !llvm::hasSingleElement(dstStyleOp->getRegion(0))) + return op->emitOpError("expects to have 1 region with 1 block"); + + // Simplifying assumption: bbargs match 1-1 with shape operands elemental + // types. + // TODO: once ranked shape types are plugged in, we may want to drop the + // corresponding bbargs, that can never be read from. This will be subject to + // consistency discussions (i.e. what to do with output tensors whose bbarg is + // not used). + Block &block = dstStyleOp->getRegion(0).front(); + + if (dstStyleOp.getNumInputsAndOutputs() != block.getNumArguments()) + return op->emitOpError("expected as many non-induction variable region " + "arguments as the number of input/output operands"); + + for (OpOperand *opOperand : dstStyleOp.getInputAndOutputOperands()) { + Type elementType = getElementTypeOrSelf(opOperand->get()); + Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); + if (elementType != argType) + return op->emitOpError("expected type of bb argument #") + << opOperand->getOperandNumber() << " (" << argType << ")" + << " to match element or self type of the corresponding operand (" + << elementType << ")"; + } + + return success(); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a506b1c..6732f34 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2741,7 +2741,8 @@ def TestLinalgConvOpNotLinalgOp : TEST_Op<"conv_op_not_linalg_op", [ def TestLinalgConvOp : TEST_Op<"linalg_conv_op", [AttrSizedOperandSegments, SingleBlock, - LinalgStructuredInterface, LinalgConvolutionOpInterface]> { + DestinationStyleOpInterface, LinalgStructuredInterface, + LinalgConvolutionOpInterface]> { let arguments = (ins Variadic:$inputs, Variadic:$outputs); @@ -2799,7 +2800,8 @@ def TestLinalgFillOpNotLinalgOp : TEST_Op<"fill_op_not_linalg_op", [ def TestLinalgFillOp : TEST_Op<"linalg_fill_op", [AttrSizedOperandSegments, SingleBlock, - LinalgStructuredInterface, LinalgFillOpInterface]> { + DestinationStyleOpInterface, LinalgStructuredInterface, + LinalgFillOpInterface]> { let arguments = (ins Variadic:$inputs, Variadic:$outputs); -- 2.7.4