From cfc9ddaafc5d4c2b293c8e6b7c9244c4844c7c89 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 18 Oct 2022 17:23:42 +0200 Subject: [PATCH] [mlir][interfaces][NFC] Move DestinationStyleOpInterface to mlir/Interfaces This is the second (and final) step of making "destination style" usable without depending on the Linalg dialect. (The first step was D135129.) This change allows us to provide default bufferization implementations for all destination-style ops. It also allows us to simplify `TilingInterface`. (E.g., `getDestinationOperands` can be removed.) Differential Revision: https://reviews.llvm.org/D136179 --- mlir/include/mlir/Dialect/Linalg/IR/Linalg.h | 1 + .../mlir/Dialect/Linalg/IR/LinalgInterfaces.h | 9 +- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 287 ------------------- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 3 +- mlir/include/mlir/Interfaces/CMakeLists.txt | 1 + .../mlir/Interfaces/DestinationStyleOpInterface.h | 34 +++ .../mlir/Interfaces/DestinationStyleOpInterface.td | 306 +++++++++++++++++++++ mlir/lib/Dialect/Linalg/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 60 ---- .../Transforms/BufferizableOpInterfaceImpl.cpp | 3 +- mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt | 1 + mlir/lib/Interfaces/CMakeLists.txt | 2 + .../lib/Interfaces/DestinationStyleOpInterface.cpp | 71 +++++ mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 + mlir/test/lib/Dialect/Test/TestOps.td | 1 + utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 41 +++ .../llvm-project-overlay/mlir/test/BUILD.bazel | 2 + 17 files changed, 467 insertions(+), 357 deletions(-) create mode 100644 mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h create mode 100644 mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td create mode 100644 mlir/lib/Interfaces/DestinationStyleOpInterface.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h index 70e7fc9..28c75fc 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -20,6 +20,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h index 39ca855..8e3df10 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -26,11 +27,6 @@ namespace mlir { namespace linalg { class LinalgOp; -/// OpOperand vector that implicitly converts to a Value vector. -struct OpOperandVector : public SmallVector { - operator SmallVector(); -}; - namespace detail { /// Implementation of the method that that check if given operands /// can be dropped, i.e. the remaining operands can compute the loop @@ -57,9 +53,6 @@ LogicalResult verifyFillInterface(Operation *op); /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); -/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface -LogicalResult verifyDestinationStyleOpInterface(Operation *op); - } // namespace detail } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 995ced5..28bf0b0 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -879,291 +879,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let verifyWithRegions = 1; } -// Ops that are in destination style have designated output operands, which act -// as initial tensor values for the results of the operation or the output -// buffers to which the results of the op will be written. -// -// Output operands must be tensors or memrefs. Input operands can have any -// type. All non-output operands are inputs. - -// It is assumed that the output operands of the op are the operands at -// position [start, end). The positions are defined by getOutputsPositionRange -// method. All non-output operands are "inputs" of the DPS op. - -// If the op has "tensor semantics", then the input operands are either scalars -// or tensors. The output operands are tensors and every tensor output is tied -// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output -// tensor is tied to the i-th OpResult. The op may not have any additional -// OpResults. Output operands and their tied OpResults have the same type. -// -// If the op has "buffer semantics", then the input operands are either memrefs -// or other non-tensor types, e.g. scalar types. Furthermore, the output -// operands are memrefs and the op has no results. -// -// Destination-passing style abstraction makes certain transformations easier. -// For example, tiling implementation can extract/insert slices from/into the -// destination of an op and use the resulting shaped value as an iter_arg in -// the surrounding loop structure. As another example, bufferization does not -// have to allocate new buffers for destinations (in case of in-place -// bufferization) and can directly reuse the existing destination buffer. -// -// Example of a destination style op: `%r = tensor.insert_slice %t into %d`, -// where `%t` is the single input and `%d` is the single output. `%d` is tied -// to `%r`. -// -// Example of an op that is not in destination style: `%r = tensor.pad %t`. -// This op is not in destination style because `%r` and `%t` have different -// shape. -// -// Each op that wants to implement DestinationStyleOpInterface needs to define -// the getOutputsPositionRange() method. -def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { - let cppNamespace = "::mlir::linalg"; - let methods = [ - // This method has to be defined for every DPS op. - InterfaceMethod< - /*desc=*/"Return start and end indices of the output operands range.", - /*retTy=*/"std::pair", - /*methodName=*/"getOutputsPositionRange", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/"" - >, - //===------------------------------------------------------------------===// - // Operands handling. - //===------------------------------------------------------------------===// - // The operand list is assumed to start with the input operands and end - // with the output operands. Therefore, all methods to access the inputs - // and outputs can be expressed if the number of output operands is know. - InterfaceMethod< - /*desc=*/"Return the number of outputs.", - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - return end - start; - }] - >, - InterfaceMethod< - /*desc=*/"Return the output operands.", - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - - OpOperandVector result; - result.reserve(end - start); - for (int i = start; i < end; ++i) - result.push_back(&$_op->getOpOperand(i)); - return result; - }] - >, - InterfaceMethod< - /*desc=*/"Return the `i`-th output operand.", - /*retTy=*/"OpOperand*", - /*methodName=*/"getOutputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < $_op.getNumOutputs()); - auto [start, end] = $_op.getOutputsPositionRange(); - return &$_op->getOpOperand(start + i); - }] - >, - InterfaceMethod< - /*desc=*/"Set the `i`-th output operand.", - /*retTy=*/"void", - /*methodName=*/"setOutputOperand", - /*args=*/(ins "int64_t":$i, "Value":$value), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < $_op.getNumOutputs()); - auto [start, end] = $_op.getOutputsPositionRange(); - $_op->setOperand(start + i, value); - }] - >, - InterfaceMethod< - /*desc=*/"Return the number of inputs.", - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getNumOperands() - $_op.getNumOutputs(); - }] - >, - InterfaceMethod< - /*desc=*/"Return the input operands.", - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - int64_t numOutputs = end - start; - int64_t numOperands = $_op.getNumOperands(); - - OpOperandVector result; - result.reserve(numOperands - numOutputs); - for (int i = 0; i < start; ++i) - result.push_back(&$_op->getOpOperand(i)); - for (int i = end; i < numOperands; ++i) - result.push_back(&$_op->getOpOperand(end + i)); - - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ Return the `i`-th input operand. }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - auto [start, end] = $_op.getOutputsPositionRange(); - return &$_op->getOpOperand(i < start ? i : i + end - start) ; - }] - >, - //===------------------------------------------------------------------===// - // Input and Output arguments handling. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/"Return true if `opOperand` is an input.", - /*retTy=*/"bool", - /*methodName=*/"isInput", - /*args=*/(ins "OpOperand *":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - auto operandNumber = opOperand->getOperandNumber(); - return operandNumber < start || operandNumber >= end; - }] - >, - InterfaceMethod< - /*desc=*/"Return true if `opOperand` is an output.", - /*retTy=*/"bool", - /*methodName=*/"isOutput", - /*args=*/(ins "OpOperand *":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - auto operandNumber = opOperand->getOperandNumber(); - return operandNumber >= start && operandNumber < end; - }] - >, - InterfaceMethod< - /*desc=*/"Return true if the `opOperand` is a scalar value.", - /*retTy=*/"bool", - /*methodName=*/"isScalar", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - return !opOperand->get().getType().template isa(); - }] - >, - InterfaceMethod< - /*desc=*/"Return the result tied to `opOperand`.", - /*retTy=*/"OpResult", - /*methodName=*/"getTiedOpResult", - /*args=*/(ins "OpOperand*":$opOperand), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(opOperand->getOwner() == this->getOperation()); - - auto [start, end] = $_op.getOutputsPositionRange(); - int64_t resultIndex = opOperand->getOperandNumber() - start; - assert(resultIndex >= 0 && - resultIndex < $_op->getNumResults() ); - return $_op->getResult(resultIndex); - }] - >, - //===------------------------------------------------------------------===// - // Other interface methods. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/"Return whether the op has only MemRef input and outputs.", - /*retTy=*/"bool", - /*methodName=*/"hasBufferSemantics", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op->getNumResults() == 0 && - llvm::all_of($_op->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); - }] - >, - InterfaceMethod< - /*desc=*/"Return whether the op has only RankedTensor input and outputs.", - /*retTy=*/"bool", - /*methodName=*/"hasTensorSemantics", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::all_of($_op->getOpOperands(), - [&](OpOperand &opOperand) { - return isScalar(&opOperand) || - opOperand.get().getType().template isa(); - }); - }] - >, - //===------------------------------------------------------------------===// - // Other static interface methods. - //===------------------------------------------------------------------===// - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location and operands. This - is used to abstract away the optional underlying region creation. This - does not change the balance between input, output_buffer and - init_tensors operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"clone", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands), - [{ - BlockAndValueMapping bvm; - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (Region &r : $_op->getRegions()) - r.cloneInto(state.addRegion(), bvm); - return b.create(state); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Clone the current operation with the given location, operands - and BlockAndValueMapping but leave the regions empty. This is - used to abstract away the optional underlying region creation. - This does not change the balance between input, output_buffer - and init_tensors operands. - }], - /*retTy=*/"Operation *", - /*methodName=*/"cloneWithoutRegions", - (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, - "ValueRange":$operands), - [{ - OperationState state( - loc, ConcreteOp::getOperationName(), operands, resultTypes, - $_op->getAttrs()); - for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) - state.addRegion(); - return b.create(state); - }] - > - ]; - - let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }]; - let verifyWithRegions = 1; -} - #endif // LINALG_IR_LINALGINTERFACES diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 3b06a59..4b83de1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpAsmInterface.td" @@ -279,7 +280,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [ int64_t getNumOperands = this->getNumOperands(); return {getNumOperands - 1, getNumOperands}; } - linalg::OpOperandVector getOpOperandsMatchingBBargs() { + OpOperandVector getOpOperandsMatchingBBargs() { return getInputOperands(); } diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index e471b9e..721a9de 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces) add_mlir_interface(ControlFlowInterfaces) add_mlir_interface(CopyOpInterface) add_mlir_interface(DerivedAttributeOpInterface) +add_mlir_interface(DestinationStyleOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h new file mode 100644 index 0000000..fe8478e --- /dev/null +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h @@ -0,0 +1,34 @@ +//===- DestinationStyleOpInterface.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_ +#define MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_ + +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { +/// OpOperand vector that implicitly converts to a Value vector. +struct OpOperandVector : public llvm::SmallVector { + operator SmallVector(); +}; + +namespace detail { +/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface +LogicalResult verifyDestinationStyleOpInterface(Operation *op); +} // namespace detail +} // namespace mlir + +/// Include the generated interface declarations. +#include "mlir/Interfaces/DestinationStyleOpInterface.h.inc" + +#endif // MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td new file mode 100644 index 0000000..718b32d --- /dev/null +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -0,0 +1,306 @@ +//===- DestinationStyleOpInterface.td ----------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DESTINATIONSTYLEOPINTERFACE +#define MLIR_DESTINATIONSTYLEOPINTERFACE + +include "mlir/IR/OpBase.td" + +def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { + let description = [{ + Ops that are in destination style have designated output operands, which act + as initial tensor values for the results of the operation or the output + buffers to which the results of the op will be written. + + Output operands must be tensors or memrefs. Input operands can have any + type. All non-output operands are inputs. + + It is assumed that the output operands of the op are the operands at + position [start, end). The positions are defined by getOutputsPositionRange + method. All non-output operands are "inputs" of the DPS op. + + If the op has "tensor semantics", then the input operands are either scalars + or tensors. The output operands are tensors and every tensor output is tied + to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output + tensor is tied to the i-th OpResult. The op may not have any additional + OpResults. Output operands and their tied OpResults have the same type. + + If the op has "buffer semantics", then the input operands are either memrefs + or other non-tensor types, e.g. scalar types. Furthermore, the output + operands are memrefs and the op has no results. + + Destination-passing style abstraction makes certain transformations easier. + For example, tiling implementation can extract/insert slices from/into the + destination of an op and use the resulting shaped value as an iter_arg in + the surrounding loop structure. As another example, bufferization does not + have to allocate new buffers for destinations (in case of in-place + bufferization) and can directly reuse the existing destination buffer. + + Example of a destination style op: `%r = tensor.insert_slice %t into %d`, + where `%t` is the single input and `%d` is the single output. `%d` is tied + to `%r`. + + Example of an op that is not in destination style: `%r = tensor.pad %t`. + This op is not in destination style because `%r` and `%t` have different + shape. + + Each op that wants to implement DestinationStyleOpInterface needs to define + the getOutputsPositionRange() method. + }]; + + let cppNamespace = "::mlir"; + + let methods = [ + // This method has to be defined for every DPS op. + InterfaceMethod< + /*desc=*/"Return start and end indices of the output operands range.", + /*retTy=*/"std::pair", + /*methodName=*/"getOutputsPositionRange", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"" + >, + //===------------------------------------------------------------------===// + // Operands handling. + //===------------------------------------------------------------------===// + // The operand list is assumed to start with the input operands and end + // with the output operands. Therefore, all methods to access the inputs + // and outputs can be expressed if the number of output operands is know. + InterfaceMethod< + /*desc=*/"Return the number of outputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumOutputs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + return end - start; + }] + >, + InterfaceMethod< + /*desc=*/"Return the output operands.", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOutputOperands", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + + OpOperandVector result; + result.reserve(end - start); + for (int i = start; i < end; ++i) + result.push_back(&$_op->getOpOperand(i)); + return result; + }] + >, + InterfaceMethod< + /*desc=*/"Return the `i`-th output operand.", + /*retTy=*/"OpOperand *", + /*methodName=*/"getOutputOperand", + /*args=*/(ins "int64_t":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i >= 0 && i < $_op.getNumOutputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + return &$_op->getOpOperand(start + i); + }] + >, + InterfaceMethod< + /*desc=*/"Set the `i`-th output operand.", + /*retTy=*/"void", + /*methodName=*/"setOutputOperand", + /*args=*/(ins "int64_t":$i, "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i >= 0 && i < $_op.getNumOutputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + $_op->setOperand(start + i, value); + }] + >, + InterfaceMethod< + /*desc=*/"Return the number of inputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getNumOperands() - $_op.getNumOutputs(); + }] + >, + InterfaceMethod< + /*desc=*/"Return the input operands.", + /*retTy=*/"OpOperandVector", + /*methodName=*/"getInputOperands", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + int64_t numOutputs = end - start; + int64_t numOperands = $_op.getNumOperands(); + + OpOperandVector result; + result.reserve(numOperands - numOutputs); + for (int i = 0; i < start; ++i) + result.push_back(&$_op->getOpOperand(i)); + for (int i = end; i < numOperands; ++i) + result.push_back(&$_op->getOpOperand(end + i)); + + return result; + }] + >, + InterfaceMethod< + /*desc=*/[{ Return the `i`-th input operand. }], + /*retTy=*/"OpOperand *", + /*methodName=*/"getInputOperand", + /*args=*/(ins "int64_t":$i), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(i >= 0 && i < getNumInputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + return &$_op->getOpOperand(i < start ? i : i + end - start) ; + }] + >, + //===------------------------------------------------------------------===// + // Input and Output arguments handling. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/"Return true if `opOperand` is an input.", + /*retTy=*/"bool", + /*methodName=*/"isInput", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + auto operandNumber = opOperand->getOperandNumber(); + return operandNumber < start || operandNumber >= end; + }] + >, + InterfaceMethod< + /*desc=*/"Return true if `opOperand` is an output.", + /*retTy=*/"bool", + /*methodName=*/"isOutput", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + auto operandNumber = opOperand->getOperandNumber(); + return operandNumber >= start && operandNumber < end; + }] + >, + InterfaceMethod< + /*desc=*/"Return true if the `opOperand` is a scalar value.", + /*retTy=*/"bool", + /*methodName=*/"isScalar", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + return !opOperand->get().getType().template isa(); + }] + >, + InterfaceMethod< + /*desc=*/"Return the result tied to `opOperand`.", + /*retTy=*/"OpResult", + /*methodName=*/"getTiedOpResult", + /*args=*/(ins "OpOperand *":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(opOperand->getOwner() == this->getOperation()); + + auto [start, end] = $_op.getOutputsPositionRange(); + int64_t resultIndex = opOperand->getOperandNumber() - start; + assert(resultIndex >= 0 && + resultIndex < $_op->getNumResults() ); + return $_op->getResult(resultIndex); + }] + >, + //===------------------------------------------------------------------===// + // Other interface methods. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/"Return whether the op has only MemRef input and outputs.", + /*retTy=*/"bool", + /*methodName=*/"hasBufferSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op->getNumResults() == 0 && + llvm::all_of($_op->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); + }] + >, + InterfaceMethod< + /*desc=*/"Return whether the op has only RankedTensor input and outputs.", + /*retTy=*/"bool", + /*methodName=*/"hasTensorSemantics", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::all_of($_op->getOpOperands(), + [&](OpOperand &opOperand) { + return isScalar(&opOperand) || + opOperand.get().getType().template isa(); + }); + }] + >, + //===------------------------------------------------------------------===// + // Other static interface methods. + //===------------------------------------------------------------------===// + InterfaceMethod< + /*desc=*/[{ + Clone the current operation with the given location and operands. This + is used to abstract away the optional underlying region creation. This + does not change the balance between input, output_buffer and + init_tensors operands. + }], + /*retTy=*/"Operation *", + /*methodName=*/"clone", + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ + BlockAndValueMapping bvm; + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (Region &r : $_op->getRegions()) + r.cloneInto(state.addRegion(), bvm); + return b.create(state); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Clone the current operation with the given location, operands + and BlockAndValueMapping but leave the regions empty. This is + used to abstract away the optional underlying region creation. + This does not change the balance between input, output_buffer + and init_tensors operands. + }], + /*retTy=*/"Operation *", + /*methodName=*/"cloneWithoutRegions", + (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes, + "ValueRange":$operands), + [{ + OperationState state( + loc, ConcreteOp::getOperationName(), operands, resultTypes, + $_op->getAttrs()); + for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt) + state.addRegion(); + return b.create(state); + }] + > + ]; + + let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }]; + let verifyWithRegions = 1; +} + + +#endif // MLIR_DESTINATIONSTYLEOPINTERFACE diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt index 247eceb..85412db 100644 --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRLinalgDialect MLIRArithDialect MLIRArithUtils MLIRBufferizationDialect + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 88fd71c..78e8490 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -462,14 +462,6 @@ LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) { // StructuredOpInterface implementation //===----------------------------------------------------------------------===// -OpOperandVector::operator SmallVector() { - SmallVector result; - result.reserve(this->size()); - llvm::transform(*this, std::back_inserter(result), - [](OpOperand *opOperand) { return opOperand->get(); }); - return result; -} - /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, @@ -770,55 +762,3 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { return success(); } - -LogicalResult -mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { - DestinationStyleOpInterface dstStyleOp = - cast(op); - - SmallVector outputBufferOperands, outputTensorOperands; - for (OpOperand *operand : dstStyleOp.getOutputOperands()) { - Type type = operand->get().getType(); - if (type.isa()) - outputBufferOperands.push_back(operand); - if (type.isa()) - outputTensorOperands.push_back(operand); - } - - // Expect at least one output operand. - // This means an op that constructs a tensor out of indices cannot be a - // LinalgOp at the moment. For now this will have to be a special op until we - // have output shape operands that are not tensors. - int64_t numInputs = dstStyleOp.getNumInputs(); - int64_t numOutputs = dstStyleOp.getNumOutputs(); - if (numOutputs == 0) - return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) - return failure(); - // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != outputTensorOperands.size()) - return op->emitOpError("expected the number of results (") - << op->getNumResults() - << ") to be equal to the number of output tensors (" - << outputTensorOperands.size() << ")"; - - // Simplifying assumption: either full tensor or full buffer mode. - // This allows simpler verification of output operands vs result types - // without premature tracking of which operand is what in mixed-mode. - // TODO: relax when mixed-mode needs to pass verification. - if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) - return op->emitOpError( - "expected output operands to all have tensor type or " - "all have buffer type"); - - for (OpOperand *opOperand : outputTensorOperands) { - OpResult result = dstStyleOp.getTiedOpResult(opOperand); - if (result.getType() != opOperand->get().getType()) - return op->emitOpError("expected type of operand #") - << opOperand->getOperandNumber() << " (" - << opOperand->get().getType() << ")" - << " to match type of corresponding result (" << result.getType() - << ")"; - } - return success(); -} diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index bb38004..ca17a21 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" using namespace mlir; using namespace linalg; @@ -115,7 +116,7 @@ struct LinalgOpInterface SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - auto genericOp = cast(op); + auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. if (genericOp.isOutput(&opOperand)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 24048c2..c69f55c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -43,6 +43,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRBufferizationDialect MLIRBufferizationTransforms MLIRComplexDialect + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRFuncDialect MLIRFuncToLLVM diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index 83b2fab..fb4958a 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES CopyOpInterface.cpp DataLayoutInterfaces.cpp DerivedAttributeOpInterface.cpp + DestinationStyleOpInterface.cpp InferIntRangeInterface.cpp InferTypeOpInterface.cpp LoopLikeInterface.cpp @@ -38,6 +39,7 @@ add_mlir_interface_library(ControlFlowInterfaces) add_mlir_interface_library(CopyOpInterface) add_mlir_interface_library(DataLayoutInterfaces) add_mlir_interface_library(DerivedAttributeOpInterface) +add_mlir_interface_library(DestinationStyleOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp new file mode 100644 index 0000000..104b0fa --- /dev/null +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -0,0 +1,71 @@ +//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +using namespace mlir; + +namespace mlir { +#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc" +} // namespace mlir + +OpOperandVector::operator SmallVector() { + SmallVector result; + result.reserve(this->size()); + llvm::transform(*this, std::back_inserter(result), + [](OpOperand *opOperand) { return opOperand->get(); }); + return result; +} + +LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { + DestinationStyleOpInterface dstStyleOp = + cast(op); + + SmallVector outputBufferOperands, outputTensorOperands; + for (OpOperand *operand : dstStyleOp.getOutputOperands()) { + Type type = operand->get().getType(); + if (type.isa()) + outputBufferOperands.push_back(operand); + if (type.isa()) + outputTensorOperands.push_back(operand); + } + + // Expect at least one output operand. + int64_t numInputs = dstStyleOp.getNumInputs(); + int64_t numOutputs = dstStyleOp.getNumOutputs(); + if (numOutputs == 0) + return op->emitOpError("expected at least one output operand"); + if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) + return failure(); + // Verify the number of results matches the number of output tensors. + if (op->getNumResults() != outputTensorOperands.size()) + return op->emitOpError("expected the number of results (") + << op->getNumResults() + << ") to be equal to the number of output tensors (" + << outputTensorOperands.size() << ")"; + + // Simplifying assumption: either full tensor or full buffer mode. + // This allows simpler verification of output operands vs result types + // without premature tracking of which operand is what in mixed-mode. + // TODO: relax when mixed-mode needs to pass verification. + if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) + return op->emitOpError( + "expected output operands to all have tensor type or " + "all have buffer type"); + + for (OpOperand *opOperand : outputTensorOperands) { + OpResult result = dstStyleOp.getTiedOpResult(opOperand); + if (result.getType() != opOperand->get().getType()) + return op->emitOpError("expected type of operand #") + << opOperand->getOperandNumber() << " (" + << opOperand->get().getType() << ")" + << " to match type of corresponding result (" << result.getType() + << ")"; + } + return success(); +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index 141d618..2c8719d 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -54,6 +54,7 @@ add_mlir_library(MLIRTestDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces MLIRDerivedAttributeOpInterface + MLIRDestinationStyleOpInterface MLIRDialect MLIRDLTIDialect MLIRFuncDialect diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 9888923..e4206f9 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 3666aa8..024ba9c 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -996,6 +996,13 @@ td_library( ) td_library( + name = "DestinationStyleOpInterfaceTdFiles", + srcs = ["include/mlir/Interfaces/DestinationStyleOpInterface.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + +td_library( name = "InferIntRangeInterfaceTdFiles", srcs = ["include/mlir/Interfaces/InferIntRangeInterface.td"], includes = ["include"], @@ -5322,6 +5329,36 @@ cc_library( ) gentbl_cc_library( + name = "DestinationStyleOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Interfaces/DestinationStyleOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Interfaces/DestinationStyleOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/DestinationStyleOpInterface.td", + deps = [":DestinationStyleOpInterfaceTdFiles"], +) + +cc_library( + name = "DestinationStyleOpInterface", + srcs = ["lib/Interfaces/DestinationStyleOpInterface.cpp"], + hdrs = ["include/mlir/Interfaces/DestinationStyleOpInterface.h"], + includes = ["include"], + deps = [ + ":DestinationStyleOpInterfaceIncGen", + ":IR", + "//llvm:Support", + ], +) + +gentbl_cc_library( name = "InferIntRangeInterfaceIncGen", strip_include_prefix = "include", tbl_outs = [ @@ -7437,6 +7474,7 @@ td_library( includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + ":DestinationStyleOpInterfaceTdFiles", ":DialectUtilsTdFiles", ":InferTypeOpInterfaceTdFiles", ":LoopLikeInterfaceTdFiles", @@ -7571,6 +7609,7 @@ td_library( includes = ["include"], deps = [ ":CopyOpInterfaceTdFiles", + ":DestinationStyleOpInterface", ":LinalgOpsTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", @@ -7768,6 +7807,7 @@ cc_library( ":ComplexDialect", ":ControlFlowInterfaces", ":CopyOpInterface", + ":DestinationStyleOpInterface", ":DialectUtils", ":FuncDialect", ":IR", @@ -7925,6 +7965,7 @@ cc_library( ":BufferizationTransforms", ":ComplexDialect", ":ControlFlowDialect", + ":DestinationStyleOpInterface", ":DialectUtils", ":FuncDialect", ":FuncTransforms", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 02ae5f1..c11dee8 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -94,6 +94,7 @@ td_library( "//mlir:CopyOpInterfaceTdFiles", "//mlir:DLTIDialectTdFiles", "//mlir:DataLayoutInterfacesTdFiles", + "//mlir:DestinationStyleOpInterfaceTdFiles", "//mlir:InferIntRangeInterfaceTdFiles", "//mlir:InferTypeOpInterfaceTdFiles", "//mlir:LinalgStructuredOpsTdFiles", @@ -325,6 +326,7 @@ cc_library( "//mlir:DLTIDialect", "//mlir:DataLayoutInterfaces", "//mlir:DerivedAttributeOpInterface", + "//mlir:DestinationStyleOpInterface", "//mlir:Dialect", "//mlir:FuncDialect", "//mlir:FuncTransforms", -- 2.7.4