From b4db15a949646f45011f31c58133adab59f8ddb0 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 28 Oct 2022 15:24:34 +0200 Subject: [PATCH] [mlir] Rename getInputs->getDpsInputs and getOutputs->getDpsInits in DPS interface. https://discourse.llvm.org/t/rfc-interface-for-destination-style-ops/64056 Differential Revision: https://reviews.llvm.org/D136943 --- .../IR/DstBufferizableOpInterfaceImpl.h | 4 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.h | 2 +- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 48 +++++----- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 12 +-- mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td | 4 +- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 2 +- .../mlir/Interfaces/DestinationStyleOpInterface.td | 106 ++++++++++----------- .../lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp | 2 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 16 ++-- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 14 +-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 91 +++++++++--------- .../Linalg/Transforms/BubbleUpExtractSlice.cpp | 6 +- .../Transforms/BufferizableOpInterfaceImpl.cpp | 10 +- .../lib/Dialect/Linalg/Transforms/ConstantFold.cpp | 10 +- .../Linalg/Transforms/DecomposeLinalgOps.cpp | 11 ++- .../lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 5 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 76 ++++++++------- .../Transforms/FusePadOpWithLinalgProducer.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 6 +- .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 6 +- .../Dialect/Linalg/Transforms/Generalization.cpp | 4 +- .../lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 2 +- .../Linalg/Transforms/InlineScalarOperands.cpp | 8 +- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 6 +- .../Linalg/Transforms/NamedOpConversions.cpp | 16 ++-- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 2 +- .../Dialect/Linalg/Transforms/SplitReduction.cpp | 24 ++--- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 10 +- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 4 +- .../Dialect/Linalg/Transforms/Vectorization.cpp | 24 ++--- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 18 ++-- .../Dialect/SCF/Transforms/TileUsingInterface.cpp | 4 +- .../Transforms/SparseTensorRewriting.cpp | 32 +++---- .../SparseTensor/Transforms/Sparsification.cpp | 16 ++-- .../lib/Interfaces/DestinationStyleOpInterface.cpp | 10 +- .../Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 4 +- .../Dialect/Linalg/TestLinalgFusionTransforms.cpp | 2 +- mlir/test/lib/Dialect/Test/TestOps.td | 4 +- .../test-linalg-ods-yaml-gen.yaml | 2 +- .../mlir-linalg-ods-yaml-gen.cpp | 6 +- 41 files changed, 319 insertions(+), 314 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h index f1b1c65..4fc88eb 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h @@ -33,14 +33,14 @@ struct DstBufferizableOpInterfaceExternalModel const AnalysisState &state) const { // Only outputs bufferize to a memory write. auto dstOp = cast(op); - return dstOp.isOutput(&opOperand); + return dstOp.isDpsInit(&opOperand); } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // Output operands alias with their respective tied OpResults. 
auto dstOp = cast(op); - if (dstOp.isOutput(&opOperand)) + if (dstOp.isDpsInit(&opOperand)) return {dstOp.getTiedOpResult(&opOperand)}; return {}; } diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h index e49fe0c..b34973c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -83,7 +83,7 @@ public: return llvm::None; if (OpOperand *operand = opView.dyn_cast()) return owner.getMatchingIndexingMap(operand); - return owner.getMatchingIndexingMap(owner.getOutputOperand( + return owner.getMatchingIndexingMap(owner.getDpsInitOperand( opView.get().cast().getResultNumber())); } // Return the operand number if the `opView` is an OpOperand *. Otherwise diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 28bf0b0..533a52f 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -289,7 +289,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { getNumIterators(getReductionIteratorTypeName(), iters) == 1; }]>, //===------------------------------------------------------------------===// - // Input and Output arguments handling. + // Input and Init arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/[{ @@ -317,7 +317,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!$_op.isOutput(opOperand)) + if (!$_op.isDpsInit(opOperand)) return false; return payloadUsesValueFromOperand(opOperand); }] @@ -353,7 +353,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // TODO: reevalute the need for a cast when a better mechanism exists. return getBlock()->getArguments().take_front( cast(*this->getOperation()) - .getNumInputs()); + .getNumDpsInputs()); }] >, InterfaceMethod< @@ -371,7 +371,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // TODO: reevalute the need for a cast when a better mechanism exists. return getBlock()->getArguments().take_back( cast(*this->getOperation()) - .getNumOutputs()); + .getNumDpsInits()); }] >, InterfaceMethod< @@ -450,7 +450,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // TODO: reevalute the need for a cast when a better mechanism exists. return *(indexingMaps.begin() + cast(*this->getOperation()) - .getNumInputs() + + .getNumDpsInputs() + result.getResultNumber()); }] >, @@ -472,7 +472,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { int64_t resultIndex = opOperand->getOperandNumber() - cast(*this->getOperation()) - .getNumInputs(); + .getNumDpsInputs(); assert(resultIndex >= 0 && resultIndex < this->getOperation()->getNumResults()); Operation *yieldOp = getBlock()->getTerminator(); @@ -780,49 +780,49 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // TODO: reevalute the need for a cast when a better mechanism exists. 
//========================================================================// - int64_t getNumInputs() { + int64_t getNumDpsInputs() { return cast(*this->getOperation()) - .getNumInputs(); + .getNumDpsInputs(); } - int64_t getNumOutputs() { + int64_t getNumDpsInits() { return cast(*this->getOperation()) - .getNumOutputs(); + .getNumDpsInits(); } - OpOperandVector getInputOperands() { + OpOperandVector getDpsInputOperands() { return cast(*this->getOperation()) - .getInputOperands(); + .getDpsInputOperands(); } - OpOperand *getInputOperand(int64_t i) { + OpOperand *getDpsInputOperand(int64_t i) { return cast(*this->getOperation()) - .getInputOperand(i); + .getDpsInputOperand(i); } - void setOutputOperand(int64_t i, Value value) { + void setDpsInitOperand(int64_t i, Value value) { return cast(*this->getOperation()) - .setOutputOperand(i, value); + .setDpsInitOperand(i, value); } - OpOperandVector getOutputOperands() { + OpOperandVector getDpsInitOperands() { return cast(*this->getOperation()) - .getOutputOperands(); + .getDpsInitOperands(); } - OpOperand *getOutputOperand(int64_t i) { + OpOperand *getDpsInitOperand(int64_t i) { return cast(*this->getOperation()) - .getOutputOperand(i); + .getDpsInitOperand(i); } - bool isInput(OpOperand *opOperand) { + bool isDpsInput(OpOperand *opOperand) { return cast(*this->getOperation()) - .isInput(opOperand); + .isDpsInput(opOperand); } - bool isOutput(OpOperand *opOperand) { + bool isDpsInit(OpOperand *opOperand) { return cast(*this->getOperation()) - .isOutput(opOperand); + .isDpsInit(opOperand); } bool isScalar(OpOperand *opOperand) { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 510f883..b067a1d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -216,7 +216,7 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [ getRegionBuilder() { return nullptr; } - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { int64_t getNumOperands = this->getNumOperands(); return {getNumOperands - getOutputs().size(), getNumOperands}; } @@ -282,16 +282,16 @@ def MapOp : LinalgStructuredBase_Op<"map", [ } // Implement functions necessary for DestinationStyleOpInterface. - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { int64_t getNumOperands = this->getNumOperands(); return {getNumOperands - 1, getNumOperands}; } OpOperandVector getOpOperandsMatchingBBargs() { - return getInputOperands(); + return getDpsInputOperands(); } bool payloadUsesValueFromOperand(OpOperand * opOperand) { - if (isOutput(opOperand)) return false; + if (isDpsInit(opOperand)) return false; return !getMatchingBlockArgument(opOperand).use_empty(); } @@ -368,7 +368,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ getRegionBuilder() { return nullptr; } - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { return {getInits().size(), getNumOperands()}; } }]; @@ -433,7 +433,7 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [ } // Implement functions necessary for DestinationStyleOpInterface. 
- std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { int64_t getNumOperands = this->getNumOperands(); return {getNumOperands - 1, getNumOperands}; } diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index e2451f9..2cfdc6d 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -723,7 +723,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [ }]>]; let extraClassDeclaration = [{ - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { return {1, 2}; // `dest` operand } }]; @@ -868,7 +868,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [ /// and `strides` operands. static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; } - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { return {1, 2}; // `dest` operand } }]; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index e0ec6f4..b47c5fa 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1397,7 +1397,7 @@ def Vector_TransferWriteOp : /// ops of other dialects. Value getValue() { return getVector(); } - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { return {1, 2}; // `source` operand } }]; diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td index 0e7662d..75f7477 100644 --- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -13,27 +13,27 @@ include "mlir/IR/OpBase.td" def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { let description = [{ - Ops that are in destination style have designated output operands, which act - as initial tensor values for the results of the operation or the output + Ops that are in destination style have designated init operands, which act + as initial tensor values for the results of the operation or the init buffers to which the results of the op will be written. - Output operands must be ranked tensors or ranked memrefs. Input operands can - have any type. All non-output operands are inputs. + Init operands must be ranked tensors or ranked memrefs. Input operands can + have any type. All non-init operands are DPS inputs. - It is assumed that the output operands of the op are the operands at - position [start, end). The positions are defined by getOutputsPositionRange - method. All non-output operands are "inputs" of the DPS op. + It is assumed that the init operands of the op are the operands at + position [start, end). The positions are defined by getDpsInitsPositionRange + method. If the op has "tensor semantics", then the input operands are either scalars - or ranked tensors. The output operands are ranked tensors and every tensor - output is tied to a corresponding tensor OpResult in a 1-to-1 fashion. - The i-th output tensor is tied to the i-th OpResult. The op may not have any - additional OpResults. Output operands and their tied OpResults have the same + or ranked tensors. The init operands are ranked tensors and every tensor + init is tied to a corresponding tensor OpResult in a 1-to-1 fashion. + The i-th init tensor is tied to the i-th OpResult. The op may not have any + additional OpResults. 
Init operands and their tied OpResults have the same type. If the op has "buffer semantics", then the input operands are either ranked memrefs or other non-tensor types, e.g. scalar types. Furthermore, the - output operands are ranked memrefs and the op has no results. + init operands are ranked memrefs and the op has no results. Destination-passing style abstraction makes certain transformations easier. For example, tiling implementation can extract/insert slices from/into the @@ -43,7 +43,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { bufferization) and can directly reuse the existing destination buffer. Example of a destination style op: `%r = tensor.insert_slice %t into %d`, - where `%t` is the single input and `%d` is the single output. `%d` is tied + where `%t` is the single input and `%d` is the single init. `%d` is tied to `%r`. Example of an op that is not in destination style: `%r = tensor.pad %t`. @@ -51,7 +51,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { shape. Each op that wants to implement DestinationStyleOpInterface needs to define - the getOutputsPositionRange() method. + the getDpsInitsPositionRange() method. }]; let cppNamespace = "::mlir"; @@ -59,9 +59,9 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { let methods = [ // This method has to be defined for every DPS op. InterfaceMethod< - /*desc=*/"Return start and end indices of the output operands range.", + /*desc=*/"Return start and end indices of the init operands range.", /*retTy=*/"std::pair", - /*methodName=*/"getOutputsPositionRange", + /*methodName=*/"getDpsInitsPositionRange", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/"" @@ -70,27 +70,27 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { // Operands handling. //===------------------------------------------------------------------===// // The operand list is assumed to start with the input operands and end - // with the output operands. Therefore, all methods to access the inputs - // and outputs can be expressed if the number of output operands is know. + // with the init operands. Therefore, all methods to access the inputs + // and inits can be expressed if the number of init operands is know. 
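As an illustrative aside (not part of this patch): all of the renamed accessors below derive from getDpsInitsPositionRange(). A minimal client-side sketch of how they relate, assuming an `Operation *op` that implements the interface and has tensor semantics (the variable names are hypothetical):

    #include "mlir/Interfaces/DestinationStyleOpInterface.h"
    #include <cassert>

    // Init ("out") operands occupy the contiguous range reported by
    // getDpsInitsPositionRange(); every other operand is a DPS input.
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    auto [start, end] = dpsOp.getDpsInitsPositionRange();
    assert(dpsOp.getNumDpsInits() == end - start);
    for (int64_t i = 0, e = dpsOp.getNumDpsInits(); i < e; ++i) {
      OpOperand *init = dpsOp.getDpsInitOperand(i);
      assert(dpsOp.isDpsInit(init) && !dpsOp.isDpsInput(init));
      // Under tensor semantics, the i-th init is tied 1-to-1 to the i-th
      // OpResult and shares its type.
      assert(dpsOp.getTiedOpResult(init) == op->getResult(i));
    }
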
InterfaceMethod< - /*desc=*/"Return the number of outputs.", + /*desc=*/"Return the number of inits.", /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", + /*methodName=*/"getNumDpsInits", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); + auto [start, end] = $_op.getDpsInitsPositionRange(); return end - start; }] >, InterfaceMethod< - /*desc=*/"Return the output operands.", + /*desc=*/"Return the init operands.", /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputOperands", + /*methodName=*/"getDpsInitOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); + auto [start, end] = $_op.getDpsInitsPositionRange(); OpOperandVector result; result.reserve(end - start); @@ -100,52 +100,52 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { }] >, InterfaceMethod< - /*desc=*/"Return the `i`-th output operand.", + /*desc=*/"Return the `i`-th init operand.", /*retTy=*/"OpOperand *", - /*methodName=*/"getOutputOperand", + /*methodName=*/"getDpsInitOperand", /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < $_op.getNumOutputs()); - auto [start, end] = $_op.getOutputsPositionRange(); + assert(i >= 0 && i < $_op.getNumDpsInits()); + auto [start, end] = $_op.getDpsInitsPositionRange(); return &$_op->getOpOperand(start + i); }] >, InterfaceMethod< - /*desc=*/"Set the `i`-th output operand.", + /*desc=*/"Set the `i`-th init operand.", /*retTy=*/"void", - /*methodName=*/"setOutputOperand", + /*methodName=*/"setDpsInitOperand", /*args=*/(ins "int64_t":$i, "Value":$value), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < $_op.getNumOutputs()); - auto [start, end] = $_op.getOutputsPositionRange(); + assert(i >= 0 && i < $_op.getNumDpsInits()); + auto [start, end] = $_op.getDpsInitsPositionRange(); $_op->setOperand(start + i, value); }] >, InterfaceMethod< /*desc=*/"Return the number of inputs.", /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", + /*methodName=*/"getNumDpsInputs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getNumOperands() - $_op.getNumOutputs(); + return $_op.getNumOperands() - $_op.getNumDpsInits(); }] >, InterfaceMethod< /*desc=*/"Return the input operands.", /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", + /*methodName=*/"getDpsInputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); - int64_t numOutputs = end - start; + auto [start, end] = $_op.getDpsInitsPositionRange(); + int64_t numInits = end - start; int64_t numOperands = $_op.getNumOperands(); OpOperandVector result; - result.reserve(numOperands - numOutputs); + result.reserve(numOperands - numInits); for (int i = 0; i < start; ++i) result.push_back(&$_op->getOpOperand(i)); for (int i = end; i < numOperands; ++i) @@ -157,38 +157,38 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { InterfaceMethod< /*desc=*/[{ Return the `i`-th input operand. 
}], /*retTy=*/"OpOperand *", - /*methodName=*/"getInputOperand", + /*methodName=*/"getDpsInputOperand", /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - auto [start, end] = $_op.getOutputsPositionRange(); + assert(i >= 0 && i < getNumDpsInputs()); + auto [start, end] = $_op.getDpsInitsPositionRange(); return &$_op->getOpOperand(i < start ? i : i + end - start) ; }] >, //===------------------------------------------------------------------===// - // Input and Output arguments handling. + // Input and DpsInit arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< /*desc=*/"Return true if `opOperand` is an input.", /*retTy=*/"bool", - /*methodName=*/"isInput", + /*methodName=*/"isDpsInput", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); + auto [start, end] = $_op.getDpsInitsPositionRange(); auto operandNumber = opOperand->getOperandNumber(); return operandNumber < start || operandNumber >= end; }] >, InterfaceMethod< - /*desc=*/"Return true if `opOperand` is an output.", + /*desc=*/"Return true if `opOperand` is an init.", /*retTy=*/"bool", - /*methodName=*/"isOutput", + /*methodName=*/"isDpsInit", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - auto [start, end] = $_op.getOutputsPositionRange(); + auto [start, end] = $_op.getDpsInitsPositionRange(); auto operandNumber = opOperand->getOperandNumber(); return operandNumber >= start && operandNumber < end; }] @@ -213,7 +213,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { /*defaultImplementation=*/[{ assert(opOperand->getOwner() == $_op.getOperation()); - auto [start, end] = $_op.getOutputsPositionRange(); + auto [start, end] = $_op.getDpsInitsPositionRange(); int64_t resultIndex = opOperand->getOperandNumber() - start; assert(resultIndex >= 0 && resultIndex < $_op->getNumResults() ); @@ -228,14 +228,14 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opResult.getDefiningOp() == $_op.getOperation()); - return $_op.getOutputOperand(opResult.getResultNumber()); + return $_op.getDpsInitOperand(opResult.getResultNumber()); }] >, //===------------------------------------------------------------------===// // Other interface methods. //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/"Return whether the op has only ranked MemRef inputs/outputs.", + /*desc=*/"Return whether the op has only ranked MemRef input/inits.", /*retTy=*/"bool", /*methodName=*/"hasBufferSemantics", /*args=*/(ins), @@ -250,7 +250,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { }] >, InterfaceMethod< - /*desc=*/"Return whether the op has only ranked tensor inputs/outputs.", + /*desc=*/"Return whether the op has only ranked tensor inputs/inits.", /*retTy=*/"bool", /*methodName=*/"hasTensorSemantics", /*args=*/(ins), @@ -270,7 +270,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { /*desc=*/[{ Clone the current operation with the given location and operands. This is used to abstract away the optional underlying region creation. This - does not change the balance between input, output_buffer and + does not change the balance between input, init_buffer and init_tensors operands. 
}], /*retTy=*/"Operation *", @@ -292,7 +292,7 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { Clone the current operation with the given location, operands and BlockAndValueMapping but leave the regions empty. This is used to abstract away the optional underlying region creation. - This does not change the balance between input, output_buffer + This does not change the balance between input, init_buffer and init_tensors operands. }], /*retTy=*/"Operation *", diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp index 305ea9c..866d414 100644 --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -70,7 +70,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction( return llvm::None; // Make sure this is reduction with one input and one output. - if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1) + if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1) return llvm::None; auto originalInputType = op->getOperand(0).getType().cast(); diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index c4c7efb..e72cf5d 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -165,7 +165,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation() << " and " << *dst.getOperation() << "\n"); if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { - for (OpOperand *dstOpOperand : dst.getInputOperands()) { + for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { if (!dstOpOperand->get().getType().isa()) continue; // Check if the operand is defined by the src. @@ -174,7 +174,7 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { addDependenceElem(DependenceType::RAW, dstOpOperand->get(), dstOpOperand); } - for (OpOperand *dstOpOperand : dst.getOutputOperands()) { + for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) { // Check if the operand is defined by the src. 
auto definingOp = dstOpOperand->get().getDefiningOp(); if (definingOp && definingOp == src) { @@ -190,31 +190,31 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && "unhandled dependence tracking for mixed buffer/tensor operations"); - for (OpOperand *srcOpOperand : src.getOutputOperands()) { // W + for (OpOperand *srcOpOperand : src.getDpsInitOperands()) { // W // RAW graph - for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R + for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R if (!dstOpOperand->get().getType().isa()) continue; if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); } // WAW graph - for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W + for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); } - for (OpOperand *srcOpOperand : src.getInputOperands()) { // R + for (OpOperand *srcOpOperand : src.getDpsInputOperands()) { // R if (!srcOpOperand->get().getType().isa()) continue; // RAR graph - for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R + for (OpOperand *dstOpOperand : dst.getDpsInputOperands()) { // R if (!dstOpOperand->get().getType().isa()) continue; if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); } // WAR graph - for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W + for (OpOperand *dstOpOperand : dst.getDpsInitOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 78e8490..78a29f4 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -119,7 +119,7 @@ static MatchContractionResult isContractionInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchContractionResult::NotLinalgOp; - if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) + if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) return MatchContractionResult::WrongNumOperands; auto mapRange = linalgOp.getIndexingMapsArray(); if (linalgOp.getNumReductionLoops() == 0) @@ -278,7 +278,7 @@ static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchConvolutionResult::NotLinalgOp; - if (linalgOp.getNumInputs() < 2 || linalgOp.getNumOutputs() != 1) + if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1) return MatchConvolutionResult::WrongNumOperands; auto indexingMaps = linalgOp.getIndexingMapsArray(); @@ -436,10 +436,10 @@ static MatchFillResult isFillInterfaceImpl(Operation *op) { auto linalgOp = dyn_cast(op); if (!linalgOp) return MatchFillResult::NotLinalgOp; - if (linalgOp.getNumInputs() != 1 || linalgOp.getNumOutputs() != 1) + if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1) return MatchFillResult::WrongNumOperands; - OpOperand *value = linalgOp.getInputOperand(0); + OpOperand *value = linalgOp.getDpsInputOperand(0); if (!linalgOp.isScalar(value)) return MatchFillResult::NotScalarInput; 
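A condensed sketch of how such structural checks read after the rename (the helper name is an assumption for illustration, not code from this patch):

    #include "mlir/Dialect/Linalg/IR/Linalg.h"

    // DPS inits are the operands an op writes into; DPS inputs are what it
    // reads, which is why the dependence graph above uses init operands as
    // the "W" side and input operands as the "R" side.
    static bool isSingleInputSingleInitScalarFill(mlir::linalg::LinalgOp linalgOp) {
      return linalgOp.getNumDpsInputs() == 1 &&  // was getNumInputs()
             linalgOp.getNumDpsInits() == 1 &&   // was getNumOutputs()
             linalgOp.isScalar(linalgOp.getDpsInputOperand(0));
    }
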
@@ -555,9 +555,9 @@ static std::pair getResultsPositionInLoopsToShapeMap(LinalgOp &op) { int64_t inputRankSum = 0; int64_t outputRankSum = 0; - for (OpOperand *input : op.getInputOperands()) + for (OpOperand *input : op.getDpsInputOperands()) inputRankSum += op.getRank(input); - for (OpOperand *output : op.getOutputOperands()) + for (OpOperand *output : op.getDpsInitOperands()) outputRankSum += op.getRank(output); return {inputRankSum, inputRankSum + outputRankSum}; } @@ -601,7 +601,7 @@ LinalgOp::reifyResultShapes(OpBuilder &b, createFlatListOfOperandDims(b, loc)); int64_t pos = 0; ArrayRef shapeExprs = resultShapesFromInputShapesMap.getResults(); - for (OpOperand *opOperand : getOutputOperands()) { + for (OpOperand *opOperand : getDpsInitOperands()) { SmallVector shapes; for (int64_t dim : llvm::seq(0, getRank(opOperand))) { if (checkDimExpr.visit(shapeExprs[pos])) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index e9f2630..00893e6 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -767,8 +767,8 @@ void GenericOp::print(OpAsmPrinter &p) { } // Printing is shared with named ops, except for the region and attributes - printCommonStructuredOpParts(p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + printCommonStructuredOpParts(p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); genericAttrNames.push_back("operand_segment_sizes"); genericAttrNamesSet.insert(genericAttrNames.back()); @@ -858,7 +858,7 @@ void GenericOp::getEffects( SmallVectorImpl> &effects) { getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputOperands(), getOutputOperands()); + getDpsInputOperands(), getDpsInitOperands()); } static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) { @@ -866,7 +866,7 @@ static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) { return false; // If out operand not used in payload, we can drop it. OpOperand *outputOpOperand = - genericOp.getOutputOperand(result.getResultNumber()); + genericOp.getDpsInitOperand(result.getResultNumber()); if (!genericOp.payloadUsesValueFromOperand(outputOpOperand)) return true; @@ -981,7 +981,7 @@ private: SmallVector &newIndexingMaps) const { llvm::SmallDenseMap origToNewPos; llvm::SmallDenseMap, unsigned> dedupedInputs; - for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) { + for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) { OpOperand *inputOpOperand = en.value(); // Check if operand is dead and if dropping the indexing map makes the // loops to shape computation invalid. @@ -1029,7 +1029,7 @@ private: // If the op doesnt have tensor semantics, keep all the outputs as // preserved. if (!genericOp.hasTensorSemantics()) { - for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) { + for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) { origToNewPos[en.index()] = newOutputOperands.size(); newOutputOperands.push_back(en.value()->get()); newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value())); @@ -1043,7 +1043,7 @@ private: // computation. 
auto yieldOp = cast(genericOp.getBody()->getTerminator()); for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getOutputOperands())) { + llvm::enumerate(genericOp.getDpsInitOperands())) { OpResult result = genericOp.getTiedOpResult(outputOpOperand.value()); AffineMap indexingMap = genericOp.getMatchingIndexingMap(outputOpOperand.value()); @@ -1111,22 +1111,22 @@ private: } }; - OpOperandVector origInputOperands = genericOp.getInputOperands(); - OpOperandVector newInputOperands = newOp.getInputOperands(); + OpOperandVector origInputOperands = genericOp.getDpsInputOperands(); + OpOperandVector newInputOperands = newOp.getDpsInputOperands(); updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); - OpOperandVector origOutputOperands = genericOp.getOutputOperands(); - OpOperandVector newOutputOperands = newOp.getOutputOperands(); + OpOperandVector origOutputOperands = genericOp.getDpsInitOperands(); + OpOperandVector newOutputOperands = newOp.getDpsInitOperands(); updateReplacements(origOutputOperands, newOutputOperands, origOutsToNewOutsPos); // Drop the unused yield args. - if (newOp.getNumOutputs() != genericOp.getNumOutputs()) { + if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) { OpBuilder::InsertionGuard g(rewriter); YieldOp origYieldOp = cast(origOpBlock->getTerminator()); rewriter.setInsertionPoint(origYieldOp); - SmallVector newYieldVals(newOp.getNumOutputs(), nullptr); + SmallVector newYieldVals(newOp.getNumDpsInits(), nullptr); for (const auto &yieldOpOperands : llvm::enumerate(origYieldOp.getValues())) { auto it = origOutsToNewOutsPos.find(yieldOpOperands.index()); @@ -1167,9 +1167,9 @@ struct EraseIdentityGenericOp : public OpRewritePattern { // In the buffer case, we need to check exact buffer equality. if (genericOp.hasBufferSemantics()) { - if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 && - genericOp.getInputOperand(0)->get() == - genericOp.getOutputOperand(0)->get()) { + if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 && + genericOp.getDpsInputOperand(0)->get() == + genericOp.getDpsInitOperand(0)->get()) { rewriter.eraseOp(genericOp); return success(); } @@ -1238,7 +1238,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern { bool hasRemovedCycles = false; // Iterate over output operands and remove any unused cycles. for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getOutputOperands())) { + llvm::enumerate(genericOp.getDpsInitOperands())) { // Check that result from out operand is dead. Value result = genericOp.getResult(outputOpOperand.index()); @@ -1370,8 +1370,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { void MapOp::print(OpAsmPrinter &p) { p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); p.printOptionalAttrDict((*this)->getAttrs()); p.printNewline(); @@ -1405,8 +1405,7 @@ LogicalResult MapOp::verify() { } // The shape of each input must match the shape of the output. 
- auto outputShape = - getOutputOperand(0)->get().getType().cast().getShape(); + auto outputShape = getInit().getType().cast().getShape(); for (Type inputArgType : TypeRange{getInputs()}) { auto inputElemShape = inputArgType.cast().getShape(); if (inputElemShape != outputShape) { @@ -1436,7 +1435,7 @@ void MapOp::getEffects( SmallVectorImpl> &effects) { getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputOperands(), getOutputOperands()); + getDpsInputOperands(), getDpsInitOperands()); } //===----------------------------------------------------------------------===// @@ -1488,12 +1487,12 @@ SmallVector ReduceOp::getIteratorTypesArray() { ArrayAttr ReduceOp::getIndexingMaps() { int64_t inputRank = getInputs()[0].getType().cast().getRank(); SmallVector affineMaps( - getNumInputs(), + getNumDpsInputs(), AffineMap::getMultiDimIdentityMap(inputRank, getContext())); AffineMap resultMap = AffineMap::getMultiDimIdentityMap(inputRank, getContext()) .dropResults(getDimensions()); - for (int64_t i = 0, e = getNumOutputs(); i < e; ++i) + for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i) affineMaps.push_back(resultMap); return Builder(getContext()).getAffineMapArrayAttr(affineMaps); } @@ -1502,7 +1501,7 @@ void ReduceOp::getEffects( SmallVectorImpl> &effects) { getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputOperands(), getOutputOperands()); + getDpsInputOperands(), getDpsInitOperands()); } static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, @@ -1543,9 +1542,10 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, void ReduceOp::print(OpAsmPrinter &p) { p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); p.printNewline(); + printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); @@ -1562,7 +1562,7 @@ void ReduceOp::print(OpAsmPrinter &p) { LogicalResult ReduceOp::verify() { ArrayRef dimensionsRef = getDimensions(); - for (int64_t i = 1; i < getNumInputs(); ++i) { + for (int64_t i = 1; i < getNumDpsInputs(); ++i) { if (getInputs()[i].getType().cast().getShape() != getInputs()[0].getType().cast().getShape()) { return emitOpError() << "expects all inputs to have the same shapes. " @@ -1571,7 +1571,7 @@ LogicalResult ReduceOp::verify() { << " is not equal to the shape at input-index 0."; } } - for (int64_t i = 1; i < getNumOutputs(); ++i) { + for (int64_t i = 1; i < getNumDpsInits(); ++i) { if (getInits()[i].getType().cast().getShape() != getInits()[0].getType().cast().getShape()) { return emitOpError() << "expects all outputs to have the same shapes. " @@ -1632,8 +1632,8 @@ LogicalResult ReduceOp::verify() { // Check that the last block arguments match the element type of the outputs. 
for (auto [output, bbArg] : - llvm::zip(getOutputOperands(), - block->getArguments().take_back(getNumOutputs()))) { + llvm::zip(getDpsInitOperands(), + block->getArguments().take_back(getNumDpsInits()))) { auto outputElementType = output->get().getType().cast().getElementType(); if (outputElementType != bbArg.getType()) @@ -1712,9 +1712,10 @@ void TransposeOp::getAsmResultNames( void TransposeOp::print(OpAsmPrinter &p) { p.increaseIndent(); printCommonStructuredOpPartsWithNewLine( - p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + p, SmallVector(getDpsInputOperands()), + SmallVector(getDpsInitOperands())); p.printNewline(); + printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); p.decreaseIndent(); @@ -1774,7 +1775,7 @@ void TransposeOp::getEffects( SmallVectorImpl> &effects) { getGenericEffectsImpl(effects, getOperation()->getResults(), - getInputOperands(), getOutputOperands()); + getDpsInputOperands(), getDpsInitOperands()); } //===----------------------------------------------------------------------===// @@ -1802,15 +1803,15 @@ ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) { // Check the operand number and types must match the element types of the // LinalgOp interface's shaped operands. static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { - if (op.getNumOperands() != linalgOp.getNumOutputs()) + if (op.getNumOperands() != linalgOp.getNumDpsInits()) return op.emitOpError("expected number of yield values (") - << linalgOp.getNumOutputs() + << linalgOp.getNumDpsInits() << ") to match the number of operands of the enclosing " << "LinalgOp (" << op.getNumOperands() << ")"; for (OpOperand &opOperand : op->getOpOperands()) { OpOperand *outputOperand = - linalgOp.getOutputOperand(opOperand.getOperandNumber()); + linalgOp.getDpsInitOperand(opOperand.getOperandNumber()); Type elementType = getElementTypeOrSelf(outputOperand->get().getType()); if (opOperand.get().getType() != elementType) return op.emitOpError("type of yield operand ") @@ -1981,14 +1982,14 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern { SmallVector newOperands; newOperands.reserve(op->getNumOperands()); // Inputs may fold. - for (auto *input : op.getInputOperands()) { + for (auto *input : op.getDpsInputOperands()) { auto tensorCastOp = input->get().getDefiningOp(); newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.getSource() : input->get()); } // Init tensors may fold, in which case the resultType must also change. - for (auto *output : op.getOutputOperands()) { + for (auto *output : op.getDpsInitOperands()) { auto tensorCastOp = output->get().getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get()); @@ -2047,11 +2048,11 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern { // for this cast, i.e. producer of the out operand, is also an operation // that folds with tensor.cast consumer (like this pattern), the cast will // continue to propagate as far up the stack as it can go. 
- OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); + OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); Value newOperand = rewriter.create(loc, resultType, outOperand->get()); - SmallVector newOperands{linalgOp.getInputOperands()}; - SmallVector outputOperands{linalgOp.getOutputOperands()}; + SmallVector newOperands{linalgOp.getDpsInputOperands()}; + SmallVector outputOperands{linalgOp.getDpsInitOperands()}; outputOperands[resultNumber] = newOperand; newOperands.append(outputOperands.begin(), outputOperands.end()); @@ -2124,7 +2125,7 @@ static void createNewOperandWithStaticSizes( return; auto sourceType = src.getType().cast(); Type resultType = sourceType; - if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) { + if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) { resultTypes.push_back(resultType); return; } @@ -2157,7 +2158,7 @@ static void createNewOperandWithStaticSizes( unsigned index = opOperand->getOperandNumber(); newOperands[index] = newOperand; } - if (linalgOp.isOutput(opOperand)) + if (linalgOp.isDpsInit(opOperand)) resultTypes.push_back(resultType); } @@ -2193,7 +2194,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern { // change in their types. bool changeNeeded = false; newOperands.reserve(linalgOp->getNumOperands()); - resultTypes.reserve(linalgOp.getNumOutputs()); + resultTypes.reserve(linalgOp.getNumDpsInits()); // Iterate over all the operands and update the static sizes. for (OpOperand &opOperand : linalgOp->getOpOperands()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp index 383a926..c75151a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp @@ -63,7 +63,7 @@ struct BubbleUpExtractSliceOpPattern "expected single use of linalg op"); } - if (linalgOp.getNumOutputs() != 1) { + if (linalgOp.getNumDpsInits() != 1) { return rewriter.notifyMatchFailure(sliceOp, "expected single output of linalg op"); } @@ -80,7 +80,7 @@ struct BubbleUpExtractSliceOpPattern return rewriter.notifyMatchFailure(sliceOp, "expected no rank reduction"); } - OpOperand *outOperand = linalgOp.getOutputOperand(0); + OpOperand *outOperand = linalgOp.getDpsInitOperand(0); AffineMap indexingMap = linalgOp.getMatchingIndexingMap(outOperand); if (!indexingMap.isProjectedPermutation()) { return rewriter.notifyMatchFailure( @@ -119,7 +119,7 @@ struct BubbleUpExtractSliceOpPattern /*omitPartialTileCheck=*/true); SmallVector resultTensorTypes; - for (OpOperand *opOperand : linalgOp.getOutputOperands()) + for (OpOperand *opOperand : linalgOp.getDpsInitOperands()) resultTensorTypes.push_back( tiledOperands[opOperand->getOperandNumber()].getType()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index ca17a21..f954ac0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,8 +41,8 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, // New input operands for the cloned op. 
SmallVector newInputBuffers; - newInputBuffers.reserve(op.getNumInputs()); - for (OpOperand *opOperand : op.getInputOperands()) { + newInputBuffers.reserve(op.getNumDpsInputs()); + for (OpOperand *opOperand : op.getDpsInputOperands()) { if (op.isScalar(opOperand)) { newInputBuffers.push_back(opOperand->get()); continue; @@ -56,7 +56,7 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, // New output operands for the cloned op. SmallVector newOutputBuffers; for (OpResult opResult : op->getOpResults()) { - OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber()); + OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); FailureOr resultBuffer = getBuffer(rewriter, opOperand->get(), options); if (failed(resultBuffer)) @@ -111,7 +111,7 @@ struct LinalgOpInterface auto genericOp = cast(op); // The i-th OpResult may alias with the i-th "out" tensor. - return {genericOp.getOutputOperand(opResult.getResultNumber())}; + return {genericOp.getDpsInitOperand(opResult.getResultNumber())}; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, @@ -119,7 +119,7 @@ struct LinalgOpInterface auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. - if (genericOp.isOutput(&opOperand)) + if (genericOp.isDpsInit(&opOperand)) return {genericOp.getTiedOpResult(&opOperand)}; return {}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index a21e0fc..d0efba5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -59,7 +59,7 @@ public: return failure(); // Only support ops generating one output for now. - if (genericOp.getNumOutputs() != 1) + if (genericOp.getNumDpsInits() != 1) return failure(); auto outputType = genericOp.getResultTypes().front().dyn_cast(); @@ -95,7 +95,7 @@ public: [](AffineMap map) { return map.isPermutation(); })) return failure(); - for (OpOperand *operand : genericOp.getOutputOperands()) { + for (OpOperand *operand : genericOp.getDpsInitOperands()) { if (genericOp.payloadUsesValueFromOperand(operand)) return failure(); } @@ -112,9 +112,9 @@ public: return failure(); // All inputs should be constants. - int numInputs = genericOp.getNumInputs(); + int numInputs = genericOp.getNumDpsInputs(); SmallVector inputValues(numInputs); - for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) { + for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) { if (!matchPattern(en.value()->get(), m_Constant(&inputValues[en.index()]))) return failure(); @@ -122,7 +122,7 @@ public: // Identified this as a potential candidate for folding. Now check the // policy to see whether we are allowed to proceed. 
- for (OpOperand *operand : genericOp.getInputOperands()) { + for (OpOperand *operand : genericOp.getDpsInputOperands()) { if (!controlFn(operand)) return failure(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index 327e8be..bbba218 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -176,7 +176,8 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, } } if (resultNumber) { - newInitValues.push_back(genericOp.getOutputOperand(*resultNumber)->get()); + newInitValues.push_back( + genericOp.getDpsInitOperand(*resultNumber)->get()); OpResult result = genericOp.getResult(*resultNumber).cast(); newResultTypes.push_back(result.getType()); peeledGenericOpIndexingMaps.push_back( @@ -224,7 +225,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, /// Add indexing maps for the newly added operands. Use the same map /// as those used for the new results of the peeledGenericOp. auto indexingMaps = llvm::to_vector( - llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *operand) { + llvm::map_range(genericOp.getDpsInputOperands(), [&](OpOperand *operand) { return genericOp.getMatchingIndexingMap(operand); })); for (auto resultNum : @@ -233,7 +234,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, indexingMaps.push_back( peeledGenericOp.getIndexingMapMatchingResult(result)); } - for (OpOperand *outOperand : genericOp.getOutputOperands()) + for (OpOperand *outOperand : genericOp.getDpsInitOperands()) indexingMaps.push_back(genericOp.getMatchingIndexingMap(outOperand)); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); @@ -261,7 +262,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, genericOp, "only operations with tensor semantics are handled"); } - if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *outOperand) { + if (llvm::any_of(genericOp.getDpsInitOperands(), [&](OpOperand *outOperand) { return !genericOp.getMatchingIndexingMap(outOperand).isPermutation(); })) { return rewriter.notifyMatchFailure( @@ -322,7 +323,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, /// In the split operations, replace block arguments uses that refer to /// original operation to the block arguments of the newly created operation. - unsigned origNumInputs = genericOp.getNumInputs(); + unsigned origNumInputs = genericOp.getNumDpsInputs(); for (const auto &inputBlockArg : llvm::enumerate(genericOp.getBody()->getArguments())) { Value residualOpReplacementArg = diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 2fcb2ac..7b9b735 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -435,7 +435,8 @@ struct ReplaceUnitExtents : public OpRewritePattern { SmallVector resultTypes; resultTypes.reserve(genericOp.getNumResults()); for (unsigned i : llvm::seq(0, genericOp.getNumResults())) - resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]); + resultTypes.push_back( + newInputOutputTypes[i + genericOp.getNumDpsInputs()]); GenericOp replacementOp = rewriter.create( loc, resultTypes, newInputs, newOutputs, newIndexingMaps, genericOp.getIteratorTypesArray()); @@ -447,7 +448,7 @@ struct ReplaceUnitExtents : public OpRewritePattern { // the original shape. 
SmallVector resultReplacements; for (const auto &result : llvm::enumerate(replacementOp.getResults())) { - unsigned index = result.index() + replacementOp.getNumInputs(); + unsigned index = result.index() + replacementOp.getNumDpsInputs(); auto origResultType = genericOp.getResult(result.index()).getType(); auto newResult = maybeExpand(result.value(), origResultType, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 80cef16..6a9c4e3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -90,7 +90,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { // Only allow fusing the producer of an input operand for now. // TODO: allow fusing the producer of an output operand. - if (!consumer.isInput(fusedOperand)) + if (!consumer.isDpsInput(fusedOperand)) return false; // Get the consumer index map. The number of results of the consumer index @@ -102,7 +102,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { // Finally the index_map for the result must be invertible. For now just // verify it is a permutation. AffineMap producerResultIndexMap = - producer.getMatchingIndexingMap(producer.getOutputOperand(0)); + producer.getMatchingIndexingMap(producer.getDpsInitOperand(0)); if (!producerResultIndexMap.isPermutation()) return false; @@ -128,7 +128,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { addToCoveredDims(operandMap); } - for (OpOperand *operand : producer.getInputOperands()) { + for (OpOperand *operand : producer.getDpsInputOperands()) { AffineMap newIndexingMap = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( operand, producerResultIndexMap, consumerIndexMap); @@ -179,7 +179,7 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, } } // TODO: allow fusing the producer of an output operand. - assert(consumer.isInput(fusedOperand) && + assert(consumer.isDpsInput(fusedOperand) && "expected producer of input operand"); // 3. Consumer input operands up to consumerIdx (exclusive). for (BlockArgument bbArg : consumerBlock.getArguments().take_front( @@ -191,24 +191,24 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, // 4. Splice in producer's input operands. for (BlockArgument bbArg : - producerBlock.getArguments().take_front(producer.getNumInputs())) + producerBlock.getArguments().take_front(producer.getNumDpsInputs())) mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 5. Remaining consumer's input operands (drop past index `consumerIdx`). for (BlockArgument bbArg : consumerBlock.getArguments() - .take_front(consumer.getNumInputs()) + .take_front(consumer.getNumDpsInputs()) .drop_front(fusedOperand->getOperandNumber() + 1)) mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 6. All of the producer's output operands for (BlockArgument bbArg : - producerBlock.getArguments().take_back(producer.getNumOutputs())) + producerBlock.getArguments().take_back(producer.getNumDpsInits())) mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 7. All of consumer's output operands. for (BlockArgument bbArg : - consumerBlock.getArguments().take_back(consumer.getNumOutputs())) + consumerBlock.getArguments().take_back(consumer.getNumDpsInits())) mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 8. 
Clone all producer operations except for the yield and index operations @@ -267,22 +267,24 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, auto producer = cast(producerResult.getOwner()); auto consumer = cast(fusedOperand->getOwner()); // TODO: allow fusing the producer of an output operand. - assert(consumer.isInput(fusedOperand) && + assert(consumer.isDpsInput(fusedOperand) && "expected producer of input operand"); // Compute the fused operands list and indexing maps. SmallVector fusedInputOperands, fusedOutputOperands; SmallVector fusedResultTypes; SmallVector fusedIndexMaps; - fusedInputOperands.reserve(producer.getNumInputs() + consumer.getNumInputs()); - fusedOutputOperands.reserve(producer.getNumOutputs() + - consumer.getNumOutputs()); - fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs()); + fusedInputOperands.reserve(producer.getNumDpsInputs() + + consumer.getNumDpsInputs()); + fusedOutputOperands.reserve(producer.getNumDpsInits() + + consumer.getNumDpsInits()); + fusedResultTypes.reserve(producer.getNumDpsInits() + + consumer.getNumDpsInits()); fusedIndexMaps.reserve(producer->getNumOperands() + consumer->getNumOperands()); // In the following, numbering matches that of `generateFusedTensorOpRegion`. // 3. Consumer input operands/maps up to consumerIdx (exclusive). - auto consumerInputs = consumer.getInputOperands(); + auto consumerInputs = consumer.getDpsInputOperands(); auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) { return operand == fusedOperand; }); @@ -294,7 +296,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, // 4. Splice in producer's input operands/maps. AffineMap producerResultIndexMap = producer.getIndexingMapMatchingResult(producerResult); - for (OpOperand *opOperand : producer.getInputOperands()) { + for (OpOperand *opOperand : producer.getDpsInputOperands()) { fusedInputOperands.push_back(opOperand->get()); // Compute indexing maps for the producer args in the fused operation. AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( @@ -311,7 +313,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, } // 6. Collect all of the producer outputs. - for (OpOperand *opOperand : producer.getOutputOperands()) { + for (OpOperand *opOperand : producer.getDpsInitOperands()) { fusedOutputOperands.push_back(opOperand->get()); AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( opOperand, producerResultIndexMap, @@ -321,7 +323,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, } // 7. All of consumer's output operands (skip operands: added by the builder). - for (OpOperand *opOperand : consumer.getOutputOperands()) { + for (OpOperand *opOperand : consumer.getDpsInitOperands()) { fusedOutputOperands.push_back(opOperand->get()); fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand)); fusedResultTypes.push_back(opOperand->get().getType()); @@ -721,8 +723,8 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, rewriter.setInsertionPoint(genericOp); SmallVector expandedOpOperands; - expandedOpOperands.reserve(genericOp.getNumInputs()); - for (OpOperand *opOperand : genericOp.getInputOperands()) { + expandedOpOperands.reserve(genericOp.getNumDpsInputs()); + for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { if (opOperand == fusableOpOperand) { expandedOpOperands.push_back(isExpanding ? 
expandingReshapeOp.getSrc() : collapsingReshapeOp.getSrc()); @@ -756,7 +758,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, Location loc = genericOp.getLoc(); SmallVector outputs; - for (OpOperand *opOperand : genericOp.getOutputOperands()) { + for (OpOperand *opOperand : genericOp.getDpsInitOperands()) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); auto opOperandType = opOperand->get().getType().cast(); RankedTensorType expandedOutputType = @@ -805,7 +807,7 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, SmallVector reassociation = getReassociationForExpansion( genericOp.getMatchingIndexingMap( - genericOp.getOutputOperand(resultNumber)), + genericOp.getDpsInitOperand(resultNumber)), expansionInfo); resultVals.push_back(rewriter.create( genericOp.getLoc(), opResult.getType(), @@ -834,7 +836,7 @@ public: LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { tensor::CollapseShapeOp reshapeOp = opOperand->get().getDefiningOp(); if (!reshapeOp) @@ -888,7 +890,7 @@ struct FoldReshapeWithGenericOpByExpansion if (!isFusableWithReshapeByDimExpansion( producer, - producer.getOutputOperand(producerResult.getResultNumber()))) { + producer.getDpsInitOperand(producerResult.getResultNumber()))) { return rewriter.notifyMatchFailure( reshapeOp, "failed preconditions of fusion with producer generic op"); } @@ -900,7 +902,7 @@ struct FoldReshapeWithGenericOpByExpansion Optional> replacementValues = fuseWithReshapeByExpansion( producer, reshapeOp, - producer.getOutputOperand(producerResult.getResultNumber()), rewriter); + producer.getDpsInitOperand(producerResult.getResultNumber()), rewriter); if (!replacementValues) { return rewriter.notifyMatchFailure(reshapeOp, "fusion by expansion failed"); @@ -1046,7 +1048,7 @@ static SmallVector getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, ArrayRef reassociation) { // Some basic checks for this fusion to be valid. - if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1) + if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1) return {}; if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) { @@ -1416,8 +1418,8 @@ static FailureOr> collapseGenericOpIterationDims( Location loc = genericOp->getLoc(); // Get the input operands. - auto inputOperands = llvm::to_vector( - llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) { + auto inputOperands = llvm::to_vector(llvm::map_range( + genericOp.getDpsInputOperands(), [&](OpOperand *opOperand) { return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, rewriter); })); @@ -1425,9 +1427,9 @@ static FailureOr> collapseGenericOpIterationDims( // Get the output operands and result types. 
SmallVector resultTypes; SmallVector outputOperands; - resultTypes.reserve(genericOp.getNumOutputs()); - outputOperands.reserve(genericOp.getNumOutputs()); - for (OpOperand *output : genericOp.getOutputOperands()) { + resultTypes.reserve(genericOp.getNumDpsInits()); + outputOperands.reserve(genericOp.getNumDpsInits()); + for (OpOperand *output : genericOp.getDpsInitOperands()) { Value newOutput = getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter); outputOperands.push_back(newOutput); @@ -1575,7 +1577,7 @@ public: PatternRewriter &rewriter) const override { if (!genericOp.hasTensorSemantics()) return failure(); - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { Operation *def = opOperand->get().getDefiningOp(); TypedAttr constantAttr; auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool { @@ -1616,9 +1618,9 @@ public: SmallVector fusedOperands; SmallVector fusedLocs{genericOp.getLoc()}; fusedIndexMaps.reserve(genericOp->getNumOperands()); - fusedOperands.reserve(genericOp.getNumInputs()); - fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); - for (OpOperand *inputOperand : genericOp.getInputOperands()) { + fusedOperands.reserve(genericOp.getNumDpsInputs()); + fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs()); + for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { if (inputOperand == opOperand) continue; Value inputValue = inputOperand->get(); @@ -1627,7 +1629,7 @@ public: fusedOperands.push_back(inputValue); fusedLocs.push_back(inputValue.getLoc()); } - for (OpOperand *outputOperand : genericOp.getOutputOperands()) + for (OpOperand *outputOperand : genericOp.getDpsInitOperands()) fusedIndexMaps.push_back( genericOp.getMatchingIndexingMap(outputOperand)); @@ -1687,7 +1689,7 @@ struct RemoveOutsDependency : public OpRewritePattern { rewriter.startRootUpdate(op); bool modifiedOutput = false; Location loc = op.getLoc(); - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : op.getDpsInitOperands()) { if (!op.payloadUsesValueFromOperand(opOperand)) { Value operandVal = opOperand->get(); auto operandType = operandVal.getType().dyn_cast(); @@ -1735,7 +1737,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern { return failure(); bool fillFound = false; Block &payload = genericOp.getRegion().front(); - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { if (!genericOp.payloadUsesValueFromOperand(opOperand)) continue; FillOp fillOp = opOperand->get().getDefiningOp(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp index 866b411..4736165 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -110,7 +110,7 @@ struct FusePadOp : OpRewritePattern { // Clone the generic op. auto clonedOp = cast(rewriter.clone(*linalgOp.getOperation())); - clonedOp.setOutputOperand(resultNumber, slice.getResult()); + clonedOp.setDpsInitOperand(resultNumber, slice.getResult()); // Insert it back into the result of the fill. 
rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 5c2c987..2d51b8d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -150,7 +150,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, // fully dynamic at construction time. SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); - for (OpOperand *operand : producer.getOutputOperands()) { + for (OpOperand *operand : producer.getDpsInitOperands()) { auto tensorType = operand->get().getType().dyn_cast(); if (!tensorType) continue; @@ -211,7 +211,7 @@ static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, "expected linalg op with buffer semantics"); assert(consumer.hasBufferSemantics() && "expected linalg op with buffer semantics"); - if (producer.getNumOutputs() != 1) { + if (producer.getNumDpsInits() != 1) { LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); return false; } @@ -443,7 +443,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, b.setInsertionPoint(consumerOp); LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); OpOperand *opOperand = - producerOp.getOutputOperand(producerOpResult.getResultNumber()); + producerOp.getDpsInitOperand(producerOpResult.getResultNumber()); LinalgOp fusedProducer = fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand), consumerOpOperand); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 1dd6c35..d0cf684 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -69,7 +69,7 @@ getTiledProducerLoops(OpResult producerResult, // Get the indexing map of the `producerOp` output operand that matches // ´producerResult´. AffineMap producerIndexingMap = producerOp.getMatchingIndexingMap( - producerOp.getOutputOperand(producerResult.getResultNumber())); + producerOp.getDpsInitOperand(producerResult.getResultNumber())); // Keep only the tiled result slice dimensions of `producerIndexingMap`. AffineMap tiledProducerIndexingSubMap = @@ -173,14 +173,14 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, // output operand. if (iterArg) { OpOperand *outputOperand = - producerOp.getOutputOperand(producerResult.getResultNumber()); + producerOp.getDpsInitOperand(producerResult.getResultNumber()); iterArg->set(outputOperand->get()); tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); } // Clone the producer using the tiled producer operands. 
TypeRange resultTypes = ValueRange(tiledOperands) - .take_back(producerOp.getNumOutputs()) + .take_back(producerOp.getNumDpsInits()) .getTypes(); LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index ea6ce39..da43b49 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -50,8 +50,8 @@ FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, if (failed(generalizeNamedOpPrecondition(linalgOp))) return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); - SmallVector inputs = linalgOp.getInputOperands(); - SmallVector outputs = linalgOp.getOutputOperands(); + SmallVector inputs = linalgOp.getDpsInputOperands(); + SmallVector outputs = linalgOp.getDpsInitOperands(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector iterators = linalgOp.getIteratorTypesArray(); SmallVector resultTypes = linalgOp.hasTensorSemantics() diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 7515e30..baeb5c2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -111,7 +111,7 @@ private: static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) { for (OpOperand &use : padOp.getResult().getUses()) { auto linalgUser = dyn_cast(use.getOwner()); - if (!linalgUser || !linalgUser.isInput(&use)) { + if (!linalgUser || !linalgUser.isDpsInput(&use)) { LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp) << "\nthat is not an input tensor of a LinalgOp, " << "cannot hoist\n" diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp index 4ea889d..c9b118b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -41,9 +41,9 @@ struct InlineScalarOperands : public OpRewritePattern { SmallVector scalarOperands; SmallVector newIndexingMaps; SmallVector newOperands; - for (OpOperand *opOperand : genericOp.getInputOperands()) { + for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { AffineMap map = genericOp.getMatchingIndexingMap(opOperand); - if (genericOp.isInput(opOperand) && map.isConstant()) { + if (genericOp.isDpsInput(opOperand) && map.isConstant()) { scalarOperands.emplace_back(opOperand->getOperandNumber()); } else { newIndexingMaps.emplace_back(map); @@ -54,7 +54,7 @@ struct InlineScalarOperands : public OpRewritePattern { if (scalarOperands.empty()) return failure(); - for (OpOperand *opOperand : genericOp.getOutputOperands()) + for (OpOperand *opOperand : genericOp.getDpsInitOperands()) newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand)); Location loc = genericOp->getLoc(); @@ -70,7 +70,7 @@ struct InlineScalarOperands : public OpRewritePattern { rewriter.setInsertionPointToStart(body); for (auto idx : llvm::reverse(scalarOperands)) { - OpOperand *opOperand = genericOp.getInputOperand(idx); + OpOperand *opOperand = genericOp.getDpsInputOperand(idx); AffineMap map = genericOp.getMatchingIndexingMap(opOperand); SmallVector indices = map.getConstantResults(); SmallVector indicesValues; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 4fc9149..b8d0b09 100644 --- 
a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -138,7 +138,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, // TODO: Avoid the loads if the corresponding argument of the // region has no uses. // 1.a. Emit load from input operand or for scalars access the operand itself. - for (OpOperand *inputOperand : linalgOp.getInputOperands()) { + for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) { if (linalgOp.isScalar(inputOperand)) { indexedValues.push_back(inputOperand->get()); continue; @@ -149,7 +149,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, b.create(loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. - for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { + for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) { SmallVector indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims); indexedValues.push_back( @@ -161,7 +161,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, // 3. Emit store. SmallVector, 8> indexing; SmallVector outputBuffers; - for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { + for (OpOperand *outputOperand : linalgOp.getDpsInitOperands()) { if (!outputOperand->get().getType().isa()) continue; indexing.push_back(makeCanonicalAffineApplies( diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp index c2a8dbe..cabd342 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -108,9 +108,9 @@ struct SimplifyDepthwiseConvOp LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op, PatternRewriter &rewriter) const override { Operation *operation = op.getOperation(); - Value input = op.getInputOperand(0)->get(); - Value kernel = op.getInputOperand(1)->get(); - Value init = op.getOutputOperand(0)->get(); + Value input = op.getDpsInputOperand(0)->get(); + Value kernel = op.getDpsInputOperand(1)->get(); + Value init = op.getDpsInitOperand(0)->get(); auto stride = op.getStrides(); auto dilation = op.getDilations(); @@ -128,11 +128,11 @@ struct SimplifyDepthwiseConvQOp LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op, PatternRewriter &rewriter) const override { Operation *operation = op.getOperation(); - Value input = op.getInputOperand(0)->get(); - Value kernel = op.getInputOperand(1)->get(); - Value iZp = op.getInputOperand(2)->get(); - Value kZp = op.getInputOperand(3)->get(); - Value init = op.getOutputOperand(0)->get(); + Value input = op.getDpsInputOperand(0)->get(); + Value kernel = op.getDpsInputOperand(1)->get(); + Value iZp = op.getDpsInputOperand(2)->get(); + Value kZp = op.getDpsInputOperand(3)->get(); + Value init = op.getDpsInitOperand(0)->get(); auto stride = op.getStrides(); auto dilation = op.getDilations(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 6f642ea..1d966b3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -339,7 +339,7 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op, else opViews.push_back( (*promotedBuffersAndViews)[operandNumber].partialLocalView); - if (operandNumber >= op.getNumInputs()) + if (operandNumber >= op.getNumDpsInputs()) writebackViews.emplace_back(std::make_pair( opOperand.get(), 
(*promotedBuffersAndViews)[operandNumber].partialLocalView)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 92d04c1..32d05c5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -96,7 +96,7 @@ FailureOr mlir::linalg::splitReduction( SmallVector newInputs; SmallVector newMaps; // Calculate the new shapes and indexing maps of the input operands. - for (OpOperand *operand : op.getInputOperands()) { + for (OpOperand *operand : op.getDpsInputOperands()) { AffineMap map = op.getMatchingIndexingMap(operand); SmallVector newShape; SmallVector exprs; @@ -151,8 +151,8 @@ FailureOr mlir::linalg::splitReduction( // Calculate the new output map and shape, we insert the new dimension based // on the index returned by `controlSplitReductionFn`. SmallVector newOutputShape; - AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getOutputOperand(0)); - ArrayRef oldShape = op.getShape(op.getOutputOperand(0)); + AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0)); + ArrayRef oldShape = op.getShape(op.getDpsInitOperand(0)); SmallVector outputExpr; for (unsigned idx : llvm::seq(0, oldOutputMap.getNumResults() + 1)) { @@ -229,7 +229,7 @@ FailureOr mlir::linalg::splitReduction( auto reduction = b.create( loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), - SmallVector{op.getOutputOperands()}, reductionMaps, + SmallVector{op.getDpsInitOperands()}, reductionMaps, reductionIteratorTypes, [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { Operation *clonedReductionOp = b.clone(*reductionOp); @@ -317,7 +317,7 @@ FailureOr mlir::linalg::splitReductionByScaling( return b.notifyMatchFailure(op, "unknown reduction neutral"); // TODO: relax this when multi-reduction support is available. - if (op.getNumOutputs() != static_cast(neutralElements.size())) + if (op.getNumDpsInits() != static_cast(neutralElements.size())) return b.notifyMatchFailure(op, "expect one reduction per output"); // Rewrite part. @@ -337,11 +337,11 @@ FailureOr mlir::linalg::splitReductionByScaling( // For now assume outputs are 1-1 with reduction neutralElements. // TODO: generalize when multi-reduction support is available. SmallVector newOutputs; - newOutputs.reserve(op.getNumOutputs()); + newOutputs.reserve(op.getNumDpsInits()); SmallVector emptyOrAllocTensorOps; SmallVector fillOps; - fillOps.reserve(op.getNumOutputs()); - for (auto it : llvm::zip(op.getOutputOperands(), neutralElements)) { + fillOps.reserve(op.getNumDpsInits()); + for (auto it : llvm::zip(op.getDpsInitOperands(), neutralElements)) { Value rankedTensor = std::get<0>(it)->get(); auto t = rankedTensor.getType().cast(); RankedTensorType newT = RankedTensorType::Builder(t).insertDim( @@ -367,7 +367,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // Reindex existing input indexings: k -> k * splitFactor + k'. SmallVector newMaps; newMaps.reserve(op->getNumOperands() + 1); - for (OpOperand *o : op.getInputOperands()) + for (OpOperand *o : op.getDpsInputOperands()) newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor)); // Provision a new indexing for the shape-only tensor. auto nDims = op.getNumLoops() + 1; @@ -378,13 +378,13 @@ FailureOr mlir::linalg::splitReductionByScaling( // TODO: a subset of these may not reduce along reducePos and should be // reindexed: k -> k * splitFactor + k', when multi-reduction support is // available. 
- for (OpOperand *o : op.getOutputOperands()) + for (OpOperand *o : op.getDpsInitOperands()) newMaps.push_back(insertParallelDim(op, *o, reductionDimPos, reductionDimSize / splitFactor)); // Step 3. Handle operands. // Compute the new input tensors. - SmallVector newInputs(op.getInputOperands()); + SmallVector newInputs(op.getDpsInputOperands()); // Add a single shape-only tensor to carry the dimensions without resorting to // more complex inversions. newInputs.push_back(b.create( @@ -413,7 +413,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // TODO: all results can be handled in a single GenericOp, when // multi-reduction support is available. SmallVector results; - for (auto it : llvm::zip(genericOp->getResults(), op.getOutputOperands(), + for (auto it : llvm::zip(genericOp->getResults(), op.getDpsInitOperands(), combinerOps)) { Value reindexedOutput = std::get<0>(it); Value originalOutput = std::get<1>(it)->get(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index b5d83b8..5937da3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -324,7 +324,7 @@ static FailureOr tileToForeachThreadOpImpl( Operation *clonedOp = b.clone(*op.getOperation()); auto destinationStyleOp = dyn_cast(clonedOp); if (destinationStyleOp) { - for (OpOperand *outOperand : destinationStyleOp.getOutputOperands()) { + for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) { auto *it = llvm::find(dest, outOperand->get()); assert(it != dest.end() && "dest operand not found in dest"); unsigned destNum = std::distance(dest.begin(), it); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 86dae14..c843f0f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -60,12 +60,12 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, Location loc = terminator->getLoc(); for (const auto &operand : llvm::enumerate(terminator->getOperands())) { Value toStore = map.lookupOrDefault(operand.value()); - OpOperand *storeInto = linalgOp.getOutputOperand(operand.index()); + OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index()); auto indices = getIndicesForAccess( b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); - b.create(loc, toStore, - linalgOp.getOutputOperand(operand.index())->get(), - indices); + b.create( + loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(), + indices); } return success(); } @@ -152,7 +152,7 @@ struct LinalgOpTilingInterface return makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr); })); - OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); + OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); SliceParameters sliceParams = computeSliceParameters( b, loc, outOperand->get(), sizes, linalgOp.getMatchingIndexingMap(outOperand), offsets, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index eee454b..415d87a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -138,7 +138,7 @@ static FailureOr padOperandToSmallestStaticBoundingBox( OpOperand *currOpOperand = opOperand; while (auto linalgOp = currOpOperand->get().getDefiningOp()) { OpResult result = currOpOperand->get().cast(); - currOpOperand = 
linalgOp.getOutputOperand(result.getResultNumber()); + currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber()); } // Fail if `currOpOperand` is not defined by an ExtractSliceOp. @@ -222,7 +222,7 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = - ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); + ValueRange(newOperands).take_back(opToPad.getNumDpsInits()).getTypes(); paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands); // Recover the slice out of the new static results. This keeps the original diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 356ba00..d565efb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -150,7 +150,7 @@ mlir::linalg::getCombinerOpKind(Operation *combinerOp) { static Operation *matchLinalgReduction(OpOperand *outputOperand) { auto linalgOp = cast(outputOperand->getOwner()); unsigned outputPos = - outputOperand->getOperandNumber() - linalgOp.getNumInputs(); + outputOperand->getOperandNumber() - linalgOp.getNumDpsInputs(); // Only single combiner operations are supported for now. SmallVector combinerOps; if (!matchReduction(linalgOp.getRegionOutputArgs(), outputPos, combinerOps) || @@ -263,7 +263,7 @@ vectorizeLinalgYield(OpBuilder &b, Operation *op, // TODO: use a map. Value vectorValue = bvm.lookup(outputs.value()); Value newResult = buildVectorWrite( - b, vectorValue, linalgOp.getOutputOperand(outputs.index())); + b, vectorValue, linalgOp.getDpsInitOperand(outputs.index())); if (newResult) newResults.push_back(newResult); } @@ -435,12 +435,12 @@ vectorizeOneOp(OpBuilder &b, LinalgOp linalgOp, Operation *op, SmallVector> reductionOperands; for (Value operand : op->getOperands()) { auto arg = operand.dyn_cast(); - if (!arg || arg.getArgNumber() < linalgOp.getNumInputs()) + if (!arg || arg.getArgNumber() < linalgOp.getNumDpsInputs()) continue; SmallVector reductionOps; Value reduceValue = matchReduction( linalgOp.getRegionOutputArgs(), - arg.getArgNumber() - linalgOp.getNumInputs(), reductionOps); + arg.getArgNumber() - linalgOp.getNumDpsInputs(), reductionOps); if (!reduceValue) continue; reductionOperands.push_back(std::make_pair(reduceValue, operand)); @@ -517,7 +517,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, mlir::getUsedValuesDefinedAbove(linalgOp->getRegion(0), valuesSet); bvm.map(valuesSet.getArrayRef(), valuesSet.getArrayRef()); - if (linalgOp.getNumOutputs() == 0) + if (linalgOp.getNumDpsInits() == 0) return failure(); // TODO: the common vector shape is equal to the static loop sizes only when @@ -540,7 +540,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, // if (linalgOp.getShape(&opOperand).empty()) { // readType = VectorType::get({}, bbarg.getType()); // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { + if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) { map = inverseAndBroadcastProjectedPermutation( linalgOp.getMatchingIndexingMap(opOperand)); readType = VectorType::get(commonVectorShape, @@ -615,7 +615,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) { LDBG("reduction precondition failed: no reduction iterator"); return failure(); } - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : op.getDpsInitOperands()) { AffineMap indexingMap = 
op.getMatchingIndexingMap(opOperand); if (indexingMap.isPermutation()) continue; @@ -1426,11 +1426,11 @@ struct Conv1DGenerator : public StructuredGenerator { : StructuredGenerator(builder, linalgOp), strideW(strideW), dilationW(dilationW) { // Determine whether `linalgOp` can be generated with this generator - if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) + if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1) return; - lhsShaped = linalgOp.getInputOperand(0)->get(); - rhsShaped = linalgOp.getInputOperand(1)->get(); - resShaped = linalgOp.getOutputOperand(0)->get(); + lhsShaped = linalgOp.getDpsInputOperand(0)->get(); + rhsShaped = linalgOp.getDpsInputOperand(1)->get(); + resShaped = linalgOp.getDpsInitOperand(0)->get(); lhsShapedType = lhsShaped.getType().dyn_cast(); rhsShapedType = rhsShaped.getType().dyn_cast(); resShapedType = resShaped.getType().dyn_cast(); @@ -1442,7 +1442,7 @@ struct Conv1DGenerator : public StructuredGenerator { return; // Check for reduction `add` preceded by `mul`. - Operation *reduceOp = matchLinalgReduction(linalgOp.getOutputOperand(0)); + Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0)); if (!reduceOp) return; llvm::Optional maybeKind; diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index af5a201..ce15c67 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -179,7 +179,7 @@ bool isElementwise(LinalgOp op) { return false; // TODO: relax the restrictions on indexing map. - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : op.getDpsInitOperands()) { if (!op.getMatchingIndexingMap(opOperand).isPermutation()) return false; } @@ -357,7 +357,7 @@ Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, if (!linalgOp) break; OpResult opResult = current.cast(); - current = linalgOp.getOutputOperand(opResult.getResultNumber())->get(); + current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get(); } auto padOp = current ? current.getDefiningOp() : nullptr; @@ -479,7 +479,7 @@ void GenerateLoopNest::doit( "they are null entries"); SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() ? SmallVector{} - : linalgOp.getOutputOperands(); + : linalgOp.getDpsInitOperands(); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); @@ -490,7 +490,7 @@ void GenerateLoopNest::doit( "expect the number of output tensors and iter args to match"); SmallVector operandValuesToUse = linalgOp->getOperands(); if (!iterArgs.empty()) { - operandValuesToUse = linalgOp.getInputOperands(); + operandValuesToUse = linalgOp.getDpsInputOperands(); operandValuesToUse.append(iterArgs.begin(), iterArgs.end()); } return bodyBuilderFn(b, loc, ivs, operandValuesToUse); @@ -520,7 +520,7 @@ void GenerateLoopNest::doit( ArrayRef /*procInfo*/) { SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() ? SmallVector{} - : linalgOp.getOutputOperands(); + : linalgOp.getDpsInitOperands(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); @@ -686,7 +686,7 @@ void GenerateLoopNest::doit( ArrayRef procInfo) { SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() ? 
SmallVector{} - : linalgOp.getOutputOperands(); + : linalgOp.getDpsInitOperands(); assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges. assert(iteratorTypes.size() >= loopRanges.size() && @@ -897,7 +897,7 @@ SmallVector getTensorOutputTypes(LinalgOp op, ValueRange operands) { if (op.hasBufferSemantics()) return {}; return llvm::to_vector( - llvm::map_range(op.getOutputOperands(), [&](OpOperand *opOperand) { + llvm::map_range(op.getDpsInitOperands(), [&](OpOperand *opOperand) { return operands[opOperand->getOperandNumber()].getType(); })); } @@ -911,7 +911,7 @@ SmallVector insertSlicesBack(OpBuilder &builder, Location loc, tensorResults.reserve(results.size()); // Insert a insert_slice for each output tensor. unsigned resultIdx = 0; - for (OpOperand *opOperand : op.getOutputOperands()) { + for (OpOperand *opOperand : op.getDpsInitOperands()) { // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. Value outputTensor = operands[opOperand->getOperandNumber()]; @@ -965,7 +965,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, Type operandType = opOperand.get().getType(); if (!isTiled(map, tileSizes) && !(operandType.isa() && - linalgOp.isOutput(&opOperand))) { + linalgOp.isDpsInit(&opOperand))) { allSliceParams.push_back(llvm::None); LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: " << operandType << "\n"); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 8801fee..e3ab722 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -394,7 +394,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, if (auto dstOp = dyn_cast(tilingResult.tiledOp)) { auto innerMostLoop = tilingResult.loops.back(); - SmallVector destinationTensors = dstOp.getOutputOperands(); + SmallVector destinationTensors = dstOp.getDpsInitOperands(); assert(destinationTensors.size() == innerMostLoop.getRegionIterArgs().size() && "unexpected number of outputs"); @@ -588,7 +588,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( .getDefiningOp()) { scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); updateDestinationOperandsForTiledOp( - rewriter, dstOp.getOutputOperand(resultNumber)->get(), + rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); } } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 48cec34..4946256 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -169,13 +169,13 @@ public: LogicalResult matchAndRewrite(GenericOp op, PatternRewriter &rewriter) const override { if (!op.hasTensorSemantics() || op.getNumResults() != 1 || - !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op)) + !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) || !isZeroYield(op)) return failure(); auto outputType = op.getResult(0).getType().cast(); // Yielding zero on newly allocated (all-zero) sparse tensors can be // optimized out directly (regardless of dynamic or static size). 
if (getSparseTensorEncoding(outputType)) { - rewriter.replaceOp(op, op.getOutputOperand(0)->get()); + rewriter.replaceOp(op, op.getDpsInitOperand(0)->get()); return success(); } // Incorporate zero value into allocation copy. @@ -183,9 +183,9 @@ public: return failure(); Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType()); AllocTensorOp a = - op.getOutputOperand(0)->get().getDefiningOp(); + op.getDpsInitOperand(0)->get().getDefiningOp(); rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); }); - rewriter.replaceOp(op, op.getOutputOperand(0)->get()); + rewriter.replaceOp(op, op.getDpsInitOperand(0)->get()); return success(); } }; @@ -212,31 +212,31 @@ public: LogicalResult matchAndRewrite(GenericOp op, PatternRewriter &rewriter) const override { // Check consumer. - if (!op.hasTensorSemantics() || op.getNumInputs() != 2 || + if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 2 || op.getNumResults() != 1 || op.getNumParallelLoops() != op.getNumLoops() || - !op.getMatchingIndexingMap(op.getOutputOperand(0)).isIdentity() || - !op.getMatchingIndexingMap(op.getInputOperand(0)).isIdentity() || - !op.getMatchingIndexingMap(op.getInputOperand(1)).isIdentity()) + !op.getMatchingIndexingMap(op.getDpsInitOperand(0)).isIdentity() || + !op.getMatchingIndexingMap(op.getDpsInputOperand(0)).isIdentity() || + !op.getMatchingIndexingMap(op.getDpsInputOperand(1)).isIdentity()) return failure(); // Find consuming OP2(sparse, other) or OP2(other, sparse). The other // operand can be sparse or dense, since the point of this rewriting rule // is detecting a situation in which *more* sparsity is introduced into // a computation, be it already sparse or still dense. unsigned other = 0; - if (isSparseTensor(op.getInputOperand(0))) + if (isSparseTensor(op.getDpsInputOperand(0))) other = 1; - else if (!isSparseTensor(op.getInputOperand(1))) + else if (!isSparseTensor(op.getDpsInputOperand(1))) return failure(); // Check producer. auto prod = dyn_cast_or_null( - op.getInputOperand(other)->get().getDefiningOp()); + op.getDpsInputOperand(other)->get().getDefiningOp()); if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 || !prod.getResult(0).hasOneUse()) return failure(); // Sampling consumer and sum of multiplication chain producer. - if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) || - !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) || + if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) || + !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) || !isSampling(op) || !isSumOfMul(prod)) return failure(); // Modify operand structure of producer and consumer. @@ -244,7 +244,7 @@ public: SmallVector inputOps = prod.getInputs(); SmallVector outputOps = op.getOutputs(); SmallVector fusedIndexMaps = prod.getIndexingMapsArray(); - inputOps.push_back(op.getInputOperand(1 - other)->get()); + inputOps.push_back(op.getDpsInputOperand(1 - other)->get()); fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other // Fuse producer and consumer into a new generic op. auto fusedOp = rewriter.create( @@ -277,12 +277,12 @@ public: rewriter.create(loc, last); // Force initial value on merged allocation for dense outputs. 
if (!getSparseTensorEncoding(op.getResult(0).getType())) { - Value init = prod.getOutputOperand(0) + Value init = prod.getDpsInitOperand(0) ->get() .getDefiningOp() .getCopy(); AllocTensorOp a = - op.getOutputOperand(0)->get().getDefiningOp(); + op.getDpsInitOperand(0)->get().getDefiningOp(); rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); }); } // Replace consumer with fused operation. Old producer diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index efef4ff..f77fdb5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -311,7 +311,7 @@ static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op, std::vector &topSort, unsigned exp, OpOperand **sparseOut, unsigned &outerParNest) { - OpOperand *lhs = op.getOutputOperand(0); + OpOperand *lhs = op.getDpsInitOperand(0); unsigned tensor = lhs->getOperandNumber(); auto enc = getSparseTensorEncoding(lhs->get().getType()); // An non-annotated output tensor is assumed dense, and becomes a random @@ -410,7 +410,7 @@ static Value getCustomRedId(Operation *op) { static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op) { Location loc = op.getLoc(); - assert(op.getNumOperands() == op.getNumInputs() + 1); + assert(op.getNumOperands() == op.getNumDpsInputs() + 1); codegen.loopEmitter.initializeLoopEmit( builder, loc, @@ -425,7 +425,7 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, Value tensor) -> Value { // Must not be a sparse tensor. assert(!getSparseTensorEncoding(tensor.getType())); - OpOperand *lhs = op.getOutputOperand(0); + OpOperand *lhs = op.getDpsInitOperand(0); // Two output tensors references should pointed to the same object. assert(lhs->get() == tensor); bool isInit = op.isInitTensor(lhs); @@ -626,7 +626,7 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder, return; } // Store during insertion. - OpOperand *t = op.getOutputOperand(0); + OpOperand *t = op.getDpsInitOperand(0); if (t == codegen.sparseOut) { if (!rhs) { // Only unary and binary are allowed to return uninitialized rhs @@ -768,7 +768,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, // All exhausted at this level (atLevel denotes exactly at this level). if (!atLevel) return; - OpOperand *lhs = op.getOutputOperand(0); + OpOperand *lhs = op.getDpsInitOperand(0); if (lhs == &t) { // Start or end a scalarized reduction if (atStart) { @@ -1248,7 +1248,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, /// Converts the result computed by the sparse kernel into the required form. static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, linalg::GenericOp op) { - OpOperand *lhs = op.getOutputOperand(0); + OpOperand *lhs = op.getDpsInitOperand(0); Type resType = lhs->get().getType(); if (getSparseTensorEncoding(resType)) { // The sparse tensor rematerializes from the original sparse tensor's @@ -1279,7 +1279,7 @@ public: PatternRewriter &rewriter) const override { // Detects sparse annotations and translate the per-dimension sparsity // information for all tensors to loop indices in the kernel. 
- if (op.getNumOutputs() != 1) + if (op.getNumDpsInits() != 1) return failure(); unsigned numTensors = op->getNumOperands(); unsigned numLoops = op.getNumLoops(); @@ -1349,7 +1349,7 @@ private: // sparse input tensor in succession until an acylic // iteration graph results. std::vector topSort; - for (OpOperand *t : op.getInputOperands()) { + for (OpOperand *t : op.getDpsInputOperands()) { unsigned tensor = t->getOperandNumber(); Value tval = t->get(); auto srcEnc = getSparseTensorEncoding(tval.getType()); diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp index d89914e..b334eed 100644 --- a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp +++ b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp @@ -27,7 +27,7 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { cast(op); SmallVector outputBufferOperands, outputTensorOperands; - for (OpOperand *operand : dstStyleOp.getOutputOperands()) { + for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) { Type type = operand->get().getType(); if (type.isa()) { outputBufferOperands.push_back(operand); @@ -41,11 +41,11 @@ LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) { } // Expect at least one output operand. - int64_t numInputs = dstStyleOp.getNumInputs(); - int64_t numOutputs = dstStyleOp.getNumOutputs(); - if (numOutputs == 0) + int64_t numInputs = dstStyleOp.getNumDpsInputs(); + int64_t numInits = dstStyleOp.getNumDpsInits(); + if (numInits == 0) return op->emitOpError("expected at least one output operand"); - if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) + if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numInits))) return failure(); // Verify the number of results matches the number of output tensors. if (op->getNumResults() != outputTensorOperands.size()) diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 5807726..4bd6d43 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -26,7 +26,7 @@ static void addOperands(Operation *op, SetVector &operandSet) { return; TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { - SmallVector inputOperands{linalgOp.getInputOperands()}; + SmallVector inputOperands{linalgOp.getDpsInputOperands()}; operandSet.insert(inputOperands.begin(), inputOperands.end()); }) .Default([&](Operation *operation) { @@ -147,7 +147,7 @@ struct TestLinalgElementwiseFusion if (expandOp->hasOneUse()) { OpOperand &use = *expandOp->getUses().begin(); auto linalgOp = dyn_cast(use.getOwner()); - if (linalgOp && linalgOp.isOutput(&use)) + if (linalgOp && linalgOp.isDpsInit(&use)) return true; } return false; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index 3c62496..561b5bc 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -56,7 +56,7 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { changed = true; } else if (opOperand.get().getType().isa()) { // Tile and Fuse tensor input. 
- if (opOperand.getOperandNumber() >= linalgOp.getNumInputs()) + if (opOperand.getOperandNumber() >= linalgOp.getNumDpsInputs()) continue; auto info = fuseProducerOfTensor(b, opOperand); if (failed(info)) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index e4206f9..84bbe24 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2836,7 +2836,7 @@ def TestLinalgConvOp : return ""; } - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { int64_t getNumOperands = this->getNumOperands(); return {getNumOperands - 1, getNumOperands}; } @@ -2896,7 +2896,7 @@ def TestLinalgFillOp : return ""; } - std::pair getOutputsPositionRange() { + std::pair getDpsInitsPositionRange() { int64_t getNumOperands = this->getNumOperands(); return {getNumOperands - 1, getNumOperands}; } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml index 51557f5..de15df5 100644 --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -236,7 +236,7 @@ structured_op: !LinalgStructuredOpConfig scalar_arg: value # IMPL: Test3Op::getIteratorTypesArray() { -# IMPL-NEXT: int64_t rank = getRank(getOutputOperand(0)); +# IMPL-NEXT: int64_t rank = getRank(getDpsInitOperand(0)); # IMPL: Test3Op::getIndexingMaps() { # IMPL-NEXT: MLIRContext *context = getContext(); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index d531d8b..0a482cc 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -563,7 +563,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], return regionBuilder; } - std::pair getOutputsPositionRange() {{ + std::pair getDpsInitsPositionRange() {{ int64_t getNumOperands = this->getNumOperands(); return {{getNumOperands - 1, getNumOperands}; } @@ -608,7 +608,7 @@ SmallVector {0}::getIteratorTypesArray() {{ static const char rankPolyStructuredOpIteratorTypesFormat[] = R"FMT( SmallVector {0}::getIteratorTypesArray() {{ - int64_t rank = getRank(getOutputOperand(0)); + int64_t rank = getRank(getDpsInitOperand(0)); return SmallVector(rank, getParallelIteratorTypeName()); } )FMT"; @@ -661,7 +661,7 @@ void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ if (hasTensorSemantics()) return; getGenericEffectsImpl(effects, - getOperation()->getResults(), getInputOperands(), getOutputOperands()); + getOperation()->getResults(), getDpsInputOperands(), getDpsInitOperands()); } )FMT"; -- 2.7.4