From a7cccb9cbb2b9954684cbea37615303a59719973 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 14 Oct 2022 19:59:55 +0200 Subject: [PATCH] [mlir] Simplify DestinationStyleOpInterface. Differential Revision: https://reviews.llvm.org/D135348 --- .../mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 453 ++++++--------------- .../mlir/Dialect/Linalg/IR/LinalgStructuredOps.td | 19 +- mlir/lib/CAPI/Dialect/Linalg.cpp | 6 +- .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp | 22 +- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 70 ++-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 140 ++++--- .../Linalg/Transforms/BubbleUpExtractSlice.cpp | 4 +- .../Transforms/BufferizableOpInterfaceImpl.cpp | 2 +- .../lib/Dialect/Linalg/Transforms/ConstantFold.cpp | 24 +- .../Linalg/Transforms/DecomposeLinalgOps.cpp | 6 +- mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp | 7 +- .../lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 12 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 43 +- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 25 +- .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 2 +- .../Dialect/Linalg/Transforms/Generalization.cpp | 14 +- .../lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 2 +- .../Linalg/Transforms/InlineScalarOperands.cpp | 4 +- mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp | 4 +- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 6 +- mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp | 26 +- .../Dialect/Linalg/Transforms/SplitReduction.cpp | 18 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 22 +- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 14 +- .../Dialect/Linalg/Transforms/Vectorization.cpp | 34 +- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 53 +-- .../Transforms/SparseTensorRewriting.cpp | 7 +- .../SparseTensor/Transforms/Sparsification.cpp | 74 ++-- mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp | 6 +- mlir/test/Dialect/Linalg/canonicalize.mlir | 2 +- mlir/test/Dialect/Linalg/roundtrip.mlir | 32 +- .../Dialect/Linalg/TestLinalgElementwiseFusion.cpp | 4 +- .../Dialect/Linalg/TestLinalgFusionTransforms.cpp | 12 +- mlir/test/lib/Dialect/Test/TestOps.td | 14 +- .../mlir-linalg-ods-yaml-gen.cpp | 14 +- 36 files changed, 501 insertions(+), 698 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 69871fa..995ced5 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -317,7 +317,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!$_op.isOutputTensor(opOperand)) + if (!$_op.isOutput(opOperand)) return false; return payloadUsesValueFromOperand(opOperand); }] @@ -606,7 +606,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op.getInputAndOutputOperands(); + OpOperandVector result; + result.reserve($_op->getNumOperands()); + llvm::transform( + this->getOperation()->getOpOperands(), + std::back_inserter(result), + [](OpOperand &opOperand) { return &opOperand; }); + return result; }] >, //===------------------------------------------------------------------===// @@ -684,13 +690,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ SmallVector res; - // MLIR currently does not support dependent interfaces or interface - // inheritance. By construction all ops with StructuredOpInterface must - // implement DestinationStyleOpInterface. - // TODO: reevalute the need for a cast when a better mechanism exists. - auto iface = cast(*this->getOperation()); - for (OpOperand *opOperand : iface.getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); + for (OpOperand &opOperand : this->getOperation()->getOpOperands()) + llvm::append_range(res, getShape(&opOperand)); return res; }] >, @@ -779,31 +780,16 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { // TODO: reevalute the need for a cast when a better mechanism exists. //========================================================================// - ValueRange getInputs() { - return cast(*this->getOperation()) - .getInputs(); - } - int64_t getNumInputs() { return cast(*this->getOperation()) .getNumInputs(); } - ValueRange getOutputs() { - return cast(*this->getOperation()) - .getOutputs(); - } - int64_t getNumOutputs() { return cast(*this->getOperation()) .getNumOutputs(); } - int64_t getNumInputsAndOutputs() { - return cast(*this->getOperation()) - .getNumInputsAndOutputs(); - } - OpOperandVector getInputOperands() { return cast(*this->getOperation()) .getInputOperands(); @@ -814,14 +800,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { .getInputOperand(i); } - OpOperandVector getInputBufferOperands() { - return cast(*this->getOperation()) - .getInputBufferOperands(); - } - - OpOperandVector getInputTensorOperands() { + void setOutputOperand(int64_t i, Value value) { return cast(*this->getOperation()) - .getInputTensorOperands(); + .setOutputOperand(i, value); } OpOperandVector getOutputOperands() { @@ -834,44 +815,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { .getOutputOperand(i); } - void setOutputOperand(int64_t i, Value value) { + bool isInput(OpOperand *opOperand) { return cast(*this->getOperation()) - .setOutputOperand(i, value); + .isInput(opOperand); } - OpOperandVector getOutputBufferOperands() { + bool isOutput(OpOperand *opOperand) { return cast(*this->getOperation()) - .getOutputBufferOperands(); - } - - OpOperandVector getOutputTensorOperands() { - return cast(*this->getOperation()) - .getOutputTensorOperands(); - } - - SmallVector getOutputBufferTypes() { - return cast(*this->getOperation()) - .getOutputBufferTypes(); - } - - SmallVector getOutputTensorTypes() { - return cast(*this->getOperation()) - .getOutputTensorTypes(); - } - - OpOperandVector getInputAndOutputOperands() { - return cast(*this->getOperation()) - .getInputAndOutputOperands(); - } - - bool isInputTensor(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isInputTensor(opOperand); - } - - bool isOutputTensor(OpOperand *opOperand) { - return cast(*this->getOperation()) - .isOutputTensor(opOperand); + .isOutput(opOperand); } bool isScalar(OpOperand *opOperand) { @@ -928,331 +879,185 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let verifyWithRegions = 1; } -// The 'DestinationStyleOpInterface' provides access to the methods relevant -// for destination-style ops. A destination-style operation has 'n' input -// arguments and 'm' output arguments. Each op that wants to implement -// DestinationStyleOpInterface needs to define getInputs() and getOutputs() -// methods. +// Ops that are in destination style have designated output operands, which act +// as initial tensor values for the results of the operation or the output +// buffers to which the results of the op will be written. +// +// Output operands must be tensors or memrefs. Input operands can have any +// type. All non-output operands are inputs. + +// It is assumed that the output operands of the op are the operands at +// position [start, end). The positions are defined by getOutputsPositionRange +// method. All non-output operands are "inputs" of the DPS op. + +// If the op has "tensor semantics", then the input operands are either scalars +// or tensors. The output operands are tensors and every tensor output is tied +// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output +// tensor is tied to the i-th OpResult. The op may not have any additional +// OpResults. Output operands and their tied OpResults have the same type. +// +// If the op has "buffer semantics", then the input operands are either memrefs +// or other non-tensor types, e.g. scalar types. Furthermore, the output +// operands are memrefs and the op has no results. +// +// Destination-passing style abstraction makes certain transformations easier. +// For example, tiling implementation can extract/insert slices from/into the +// destination of an op and use the resulting shaped value as an iter_arg in +// the surrounding loop structure. As another example, bufferization does not +// have to allocate new buffers for destinations (in case of in-place +// bufferization) and can directly reuse the existing destination buffer. +// +// Example of a destination style op: `%r = tensor.insert_slice %t into %d`, +// where `%t` is the single input and `%d` is the single output. `%d` is tied +// to `%r`. +// +// Example of an op that is not in destination style: `%r = tensor.pad %t`. +// This op is not in destination style because `%r` and `%t` have different +// shape. +// +// Each op that wants to implement DestinationStyleOpInterface needs to define +// the getOutputsPositionRange() method. def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { let cppNamespace = "::mlir::linalg"; let methods = [ - //===------------------------------------------------------------------===// - // Num input/output arguments handling. - //===------------------------------------------------------------------===// - // `getInputs` must be defined by each op that wants to implement the - // DestinationStyleOpInterface. + // This method has to be defined for every DPS op. InterfaceMethod< - /*desc=*/[{ - Return the input shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"getInputs", - /*args=*/(ins) - >, - // These special methods rely on `getInputs` and `getOutputs` being defined - // by each op that wants to implement the DestinationStyleOpInterface. - InterfaceMethod< - /*desc=*/[{ - Return the number of inputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getInputs().size(); - }] - >, - // `getOutputs` must be defined by each op that wants to implement the - // DestinationStyleOpInterface. - InterfaceMethod< - /*desc=*/[{ - Return the output shape operands. - }], - /*retTy=*/"ValueRange", - /*methodName=*/"getOutputs", - /*args=*/(ins) - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of outputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumOutputs", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return $_op.getOutputs().size(); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the number of inputs and outputs. - }], - /*retTy=*/"int64_t", - /*methodName=*/"getNumInputsAndOutputs", + /*desc=*/"Return start and end indices of the output operands range.", + /*retTy=*/"std::pair", + /*methodName=*/"getOutputsPositionRange", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/[{ - return this->getOperation()->getNumOperands(); - }] + /*defaultImplementation=*/"" >, //===------------------------------------------------------------------===// - // Input operands handling. + // Operands handling. //===------------------------------------------------------------------===// + // The operand list is assumed to start with the input operands and end + // with the output operands. Therefore, all methods to access the inputs + // and outputs can be expressed if the number of output operands is know. InterfaceMethod< - /*desc=*/[{ - Return the input operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputs = getNumInputs(); - OpOperandVector result; - result.reserve(numInputs); - llvm::transform( - this->getOperation()->getOpOperands().take_front(numInputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the `i`-th input operand. - }], - /*retTy=*/"OpOperand*", - /*methodName=*/"getInputOperand", - /*args=*/(ins "int64_t":$i), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumInputs()); - return &this->getOperation()->getOpOperand(i); - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of input operands that are of buffer type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputBufferOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the subset of input operands that are of tensor type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputTensorOperands", + /*desc=*/"Return the number of outputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumOutputs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumInputs()); - llvm::copy_if(getInputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + auto [start, end] = $_op.getOutputsPositionRange(); + return end - start; }] >, - //===------------------------------------------------------------------===// - // Output operands handling. - //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/[{ - Return the output operands. - }], + /*desc=*/"Return the output operands.", /*retTy=*/"OpOperandVector", /*methodName=*/"getOutputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - int64_t numOutputs = getNumOutputs(); + auto [start, end] = $_op.getOutputsPositionRange(); + OpOperandVector result; - result.reserve(numOutputs); - llvm::transform( - this->getOperation()->getOpOperands() - .take_back(numOutputs), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); + result.reserve(end - start); + for (int i = start; i < end; ++i) + result.push_back(&$_op->getOpOperand(i)); return result; }] >, InterfaceMethod< - /*desc=*/[{ - Return the `i`-th output operand. - }], + /*desc=*/"Return the `i`-th output operand.", /*retTy=*/"OpOperand*", /*methodName=*/"getOutputOperand", /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - return &this->getOperation()->getOpOperand(getNumInputs() + i); + assert(i >= 0 && i < $_op.getNumOutputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + return &$_op->getOpOperand(start + i); }] >, InterfaceMethod< - /*desc=*/[{ - Set the `i`-th output operand. - }], + /*desc=*/"Set the `i`-th output operand.", /*retTy=*/"void", /*methodName=*/"setOutputOperand", /*args=*/(ins "int64_t":$i, "Value":$value), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(i >= 0 && i < getNumOutputs()); - this->getOperation()->setOperand(getNumInputs() + i, value); + assert(i >= 0 && i < $_op.getNumOutputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + $_op->setOperand(start + i, value); }] >, InterfaceMethod< - /*desc=*/[{ - Return the subset of output operands that are of buffer type. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputBufferOperands", + /*desc=*/"Return the number of inputs.", + /*retTy=*/"int64_t", + /*methodName=*/"getNumInputs", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; + return $_op.getNumOperands() - $_op.getNumOutputs(); }] >, InterfaceMethod< - /*desc=*/[{ - Return the subset of output operands that are of tensor type. - }], + /*desc=*/"Return the input operands.", /*retTy=*/"OpOperandVector", - /*methodName=*/"getOutputTensorOperands", + /*methodName=*/"getInputOperands", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ + auto [start, end] = $_op.getOutputsPositionRange(); + int64_t numOutputs = end - start; + int64_t numOperands = $_op.getNumOperands(); + OpOperandVector result; - result.reserve(getNumOutputs()); - llvm::copy_if(getOutputOperands(), - std::back_inserter(result), - [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); - }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the types of the subset of output operands that are of buffer type. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputBufferTypes", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputBufferOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); + result.reserve(numOperands - numOutputs); + for (int i = 0; i < start; ++i) + result.push_back(&$_op->getOpOperand(i)); + for (int i = end; i < numOperands; ++i) + result.push_back(&$_op->getOpOperand(end + i)); + return result; }] >, InterfaceMethod< - /*desc=*/[{ - Return the types of the subset of output operands that are of tensor type. - }], - /*retTy=*/"SmallVector", - /*methodName=*/"getOutputTensorTypes", - /*args=*/(ins), + /*desc=*/[{ Return the `i`-th input operand. }], + /*retTy=*/"OpOperand*", + /*methodName=*/"getInputOperand", + /*args=*/(ins "int64_t":$i), /*methodBody=*/"", /*defaultImplementation=*/[{ - SmallVector result; - result.reserve(getNumOutputs()); - llvm::transform(getOutputTensorOperands(), - std::back_inserter(result), - [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); - }); - return result; + assert(i >= 0 && i < getNumInputs()); + auto [start, end] = $_op.getOutputsPositionRange(); + return &$_op->getOpOperand(i < start ? i : i + end - start) ; }] >, //===------------------------------------------------------------------===// // Input and Output arguments handling. //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/[{ - Return the range over input and output operands. - }], - /*retTy=*/"OpOperandVector", - /*methodName=*/"getInputAndOutputOperands", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - int64_t numInputsAndOutputs = getNumInputsAndOutputs(); - OpOperandVector result; - result.reserve(numInputsAndOutputs); - llvm::transform( - this->getOperation()->getOpOperands(), - std::back_inserter(result), - [](OpOperand &opOperand) { return &opOperand; }); - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return true if `opOperand` is an input tensor. - }], + /*desc=*/"Return true if `opOperand` is an input.", /*retTy=*/"bool", - /*methodName=*/"isInputTensor", + /*methodName=*/"isInput", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() < $_op.getNumInputs()) - return true; - return false; + auto [start, end] = $_op.getOutputsPositionRange(); + auto operandNumber = opOperand->getOperandNumber(); + return operandNumber < start || operandNumber >= end; }] >, InterfaceMethod< - /*desc=*/[{ - Return true if `opOperand` is an output tensor. - }], + /*desc=*/"Return true if `opOperand` is an output.", /*retTy=*/"bool", - /*methodName=*/"isOutputTensor", + /*methodName=*/"isOutput", /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) - return false; - if (opOperand->getOperandNumber() >= $_op.getNumInputs()) - return true; - return false; + auto [start, end] = $_op.getOutputsPositionRange(); + auto operandNumber = opOperand->getOperandNumber(); + return operandNumber >= start && operandNumber < end; }] >, InterfaceMethod< - /*desc=*/[{ - Return true if the `opOperand` is a scalar value. - }], + /*desc=*/"Return true if the `opOperand` is a scalar value.", /*retTy=*/"bool", /*methodName=*/"isScalar", /*args=*/(ins "OpOperand*":$opOperand), @@ -1263,35 +1068,33 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { }] >, InterfaceMethod< - /*desc=*/[{ - Return the result tied to `opOperand`. - }], + /*desc=*/"Return the result tied to `opOperand`.", /*retTy=*/"OpResult", /*methodName=*/"getTiedOpResult", /*args=*/(ins "OpOperand*":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); - int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs(); + + auto [start, end] = $_op.getOutputsPositionRange(); + int64_t resultIndex = opOperand->getOperandNumber() - start; assert(resultIndex >= 0 && - resultIndex < this->getOperation()->getNumResults() ); - return this->getOperation()->getResult(resultIndex); + resultIndex < $_op->getNumResults() ); + return $_op->getResult(resultIndex); }] >, //===------------------------------------------------------------------===// // Other interface methods. //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/[{ - Return whether the op has only MemRef input and outputs. - }], + /*desc=*/"Return whether the op has only MemRef input and outputs.", /*retTy=*/"bool", /*methodName=*/"hasBufferSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return this->getOperation()->getNumResults() == 0 && - llvm::all_of(this->getOperation()->getOpOperands(), + return $_op->getNumResults() == 0 && + llvm::all_of($_op->getOpOperands(), [&](OpOperand &opOperand) { return isScalar(&opOperand) || opOperand.get().getType().template isa(); @@ -1299,15 +1102,13 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> { }] >, InterfaceMethod< - /*desc=*/[{ - Return whether the op has only RankedTensor input and outputs. - }], + /*desc=*/"Return whether the op has only RankedTensor input and outputs.", /*retTy=*/"bool", /*methodName=*/"hasTensorSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return llvm::all_of(this->getOperation()->getOpOperands(), + return llvm::all_of($_op->getOpOperands(), [&](OpOperand &opOperand) { return isScalar(&opOperand) || opOperand.get().getType().template isa(); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 2619ad1..3b06a59 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -215,6 +215,10 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [ getRegionBuilder() { return nullptr; } + std::pair getOutputsPositionRange() { + int64_t getNumOperands = this->getNumOperands(); + return {getNumOperands - getOutputs().size(), getNumOperands}; + } }]; let hasCanonicalizer = 1; @@ -271,11 +275,10 @@ def MapOp : LinalgStructuredBase_Op<"map", [ } // Implement functions necessary for DestinationStyleOpInterface. - unsigned getNumInputs() { - return this->getOperation()->getNumOperands() - getNumOutputs(); - }; - unsigned getNumOutputs() { return 1; }; - mlir::ValueRange getOutputs() { return getOperands().take_back(1); } + std::pair getOutputsPositionRange() { + int64_t getNumOperands = this->getNumOperands(); + return {getNumOperands - 1, getNumOperands}; + } linalg::OpOperandVector getOpOperandsMatchingBBargs() { return getInputOperands(); } @@ -341,14 +344,14 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ } // Implement functions necessary for DestinationStyleOpInterface. - mlir::ValueRange getOutputs() { return getInits(); } - unsigned getNumInputs() { return getInputs().size(); }; - unsigned getNumOutputs() { return getInits().size(); }; static std::function)> getRegionBuilder() { return nullptr; } + std::pair getOutputsPositionRange() { + return {getInits().size(), getNumOperands()}; + } }]; let hasCustomAssemblyFormat = 1; diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index bfb3313..2fb5bc6 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -29,9 +29,9 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { SmallVector argTypes; SmallVector argLocs; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); - argLocs.push_back(opOperand->get().getLoc()); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType())); + argLocs.push_back(opOperand.get().getLoc()); } ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp index 5a6d6788..c4c7efb 100644 --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -166,6 +166,8 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { << " and " << *dst.getOperation() << "\n"); if (src.hasTensorSemantics() && dst.hasTensorSemantics()) { for (OpOperand *dstOpOperand : dst.getInputOperands()) { + if (!dstOpOperand->get().getType().isa()) + continue; // Check if the operand is defined by the src. auto definingOp = dstOpOperand->get().getDefiningOp(); if (definingOp && definingOp == src) @@ -188,23 +190,31 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) { } assert(src.hasBufferSemantics() && dst.hasBufferSemantics() && "unhandled dependence tracking for mixed buffer/tensor operations"); - for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W + for (OpOperand *srcOpOperand : src.getOutputOperands()) { // W // RAW graph - for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R + for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R + if (!dstOpOperand->get().getType().isa()) + continue; if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand); + } // WAW graph - for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W + for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand); } - for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R + for (OpOperand *srcOpOperand : src.getInputOperands()) { // R + if (!srcOpOperand->get().getType().isa()) + continue; // RAR graph - for (OpOperand *dstOpOperand : dst.getInputBufferOperands()) // R + for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R + if (!dstOpOperand->get().getType().isa()) + continue; if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand); + } // WAR graph - for (OpOperand *dstOpOperand : dst.getOutputBufferOperands()) // W + for (OpOperand *dstOpOperand : dst.getOutputOperands()) // W if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index 9a62a40..88fd71c 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -31,10 +31,10 @@ using namespace mlir::linalg; bool linalg::detail::canOpOperandsBeDroppedImpl( linalg::LinalgOp linalgOp, ArrayRef droppedOperands) { SmallVector indexingMaps; - for (auto *opOperand : linalgOp.getInputAndOutputOperands()) { - if (llvm::is_contained(droppedOperands, opOperand)) + for (auto &opOperand : linalgOp->getOpOperands()) { + if (llvm::is_contained(droppedOperands, &opOperand)) continue; - indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand)); + indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); } return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } @@ -491,9 +491,9 @@ static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, Location loc) { SmallVector res; - for (OpOperand *opOperand : getInputAndOutputOperands()) { - for (int64_t i = 0, e = getRank(opOperand); i < e; ++i) - res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i)); + for (OpOperand &opOperand : getOperation()->getOpOperands()) { + for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i) + res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i)); } return res; } @@ -501,8 +501,8 @@ SmallVector LinalgOp::createFlatListOfOperandDims(OpBuilder &b, SmallVector LinalgOp::createFlatListOfOperandStaticDims() { SmallVector res; assert(!hasDynamicShape() && "expected operands to have static shapes"); - for (OpOperand *opOperand : getInputAndOutputOperands()) - llvm::append_range(res, getShape(opOperand)); + for (OpOperand &opOperand : getOperation()->getOpOperands()) + llvm::append_range(res, getShape(&opOperand)); return res; } @@ -644,32 +644,32 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { // All input/output operands must be indexed. if (static_cast(linalgOp.getIndexingMapsArray().size()) != - linalgOp.getNumInputsAndOutputs()) + linalgOp->getNumOperands()) return op->emitOpError("expected the number of indexing_map (") << linalgOp.getIndexingMapsArray().size() << ") to be equal to the number of input/output operands (" - << linalgOp.getNumInputsAndOutputs() << ")"; + << linalgOp->getNumOperands() << ")"; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); // Symbols disallowed. if (indexingMap.getNumSymbols() != 0) return op->emitOpError("unexpected symbols in indexing_map #") - << opOperand->getOperandNumber(); + << opOperand.getOperandNumber(); // Domain must be consistent. unsigned numLoops = linalgOp.getNumLoops(); if (indexingMap.getNumDims() != numLoops) return op->emitOpError("expected indexing_map #") - << opOperand->getOperandNumber() << " to have " << numLoops + << opOperand.getOperandNumber() << " to have " << numLoops << " dim(s) to match the number of loops"; - int64_t rank = linalgOp.getRank(opOperand); + int64_t rank = linalgOp.getRank(&opOperand); if (indexingMap.getNumResults() != rank) return op->emitOpError("expected operand rank (") << rank << ") to match the result rank of indexing_map #" - << opOperand->getOperandNumber() << " (" + << opOperand.getOperandNumber() << " (" << indexingMap.getNumResults() << ")"; } @@ -688,13 +688,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) { for (int64_t &range : endLoopRangeValues) range -= 1; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); SmallVector startIndices = indexingMap.compose(startLoopRangeValues); SmallVector endIndices = indexingMap.compose(endLoopRangeValues); - ArrayRef shape = linalgOp.getShape(opOperand); + ArrayRef shape = linalgOp.getShape(&opOperand); for (auto dim : llvm::seq(0, shape.size())) { // Ignore dynamic dimension or the case that the dimension size is 0 if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0) @@ -725,17 +725,16 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { if (indexingMap.getResult(dim).dyn_cast()) { if (inferredDimSize != shape[dim]) { return op->emitOpError("inferred input/output operand #") - << opOperand->getOperandNumber() - << " has shape's dimension #" << dim << " to be " - << inferredDimSize << ", but found " << shape[dim]; + << opOperand.getOperandNumber() << " has shape's dimension #" + << dim << " to be " << inferredDimSize << ", but found " + << shape[dim]; } } else { if (inferredDimSize > shape[dim]) { return op->emitOpError("inferred input/output operand #") - << opOperand->getOperandNumber() - << " has shape's dimension #" << dim - << " to be greater than or equal to " << inferredDimSize - << ", but found " << shape[dim]; + << opOperand.getOperandNumber() << " has shape's dimension #" + << dim << " to be greater than or equal to " + << inferredDimSize << ", but found " << shape[dim]; } } } @@ -777,6 +776,15 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { DestinationStyleOpInterface dstStyleOp = cast(op); + SmallVector outputBufferOperands, outputTensorOperands; + for (OpOperand *operand : dstStyleOp.getOutputOperands()) { + Type type = operand->get().getType(); + if (type.isa()) + outputBufferOperands.push_back(operand); + if (type.isa()) + outputTensorOperands.push_back(operand); + } + // Expect at least one output operand. // This means an op that constructs a tensor out of indices cannot be a // LinalgOp at the moment. For now this will have to be a special op until we @@ -788,23 +796,22 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs))) return failure(); // Verify the number of results matches the number of output tensors. - if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size()) + if (op->getNumResults() != outputTensorOperands.size()) return op->emitOpError("expected the number of results (") << op->getNumResults() << ") to be equal to the number of output tensors (" - << dstStyleOp.getOutputTensorOperands().size() << ")"; + << outputTensorOperands.size() << ")"; // Simplifying assumption: either full tensor or full buffer mode. // This allows simpler verification of output operands vs result types // without premature tracking of which operand is what in mixed-mode. // TODO: relax when mixed-mode needs to pass verification. - if (!dstStyleOp.getOutputBufferOperands().empty() && - !dstStyleOp.getOutputTensorOperands().empty()) + if (!outputBufferOperands.empty() && !outputTensorOperands.empty()) return op->emitOpError( "expected output operands to all have tensor type or " "all have buffer type"); - for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) { + for (OpOperand *opOperand : outputTensorOperands) { OpResult result = dstStyleOp.getTiedOpResult(opOperand); if (result.getType() != opOperand->get().getType()) return op->emitOpError("expected type of operand #") @@ -813,6 +820,5 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) { << " to match type of corresponding result (" << result.getType() << ")"; } - return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index c2705a3..586d198 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -767,7 +767,8 @@ void GenericOp::print(OpAsmPrinter &p) { } // Printing is shared with named ops, except for the region and attributes - printCommonStructuredOpParts(p, getInputs(), getOutputs()); + printCommonStructuredOpParts(p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); genericAttrNames.push_back("operand_segment_sizes"); genericAttrNamesSet.insert(genericAttrNames.back()); @@ -835,15 +836,20 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) { static void getGenericEffectsImpl( SmallVectorImpl> &effects, - ValueRange results, ValueRange inputBuffers, ValueRange outputs) { - for (Value value : inputBuffers) { - effects.emplace_back(MemoryEffects::Read::get(), value, + ValueRange results, OpOperandVector inputOperands, + OpOperandVector outputOperands) { + for (auto *operand : inputOperands) { + if (!operand->get().getType().isa()) + continue; + effects.emplace_back(MemoryEffects::Read::get(), operand->get(), SideEffects::DefaultResource::get()); } - for (Value value : outputs) { - effects.emplace_back(MemoryEffects::Read::get(), value, + for (auto *operand : outputOperands) { + if (!operand->get().getType().isa()) + continue; + effects.emplace_back(MemoryEffects::Read::get(), operand->get(), SideEffects::DefaultResource::get()); - effects.emplace_back(MemoryEffects::Write::get(), value, + effects.emplace_back(MemoryEffects::Write::get(), operand->get(), SideEffects::DefaultResource::get()); } } @@ -851,10 +857,8 @@ static void getGenericEffectsImpl( void GenericOp::getEffects( SmallVectorImpl> &effects) { - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, - outputBuffers); + getGenericEffectsImpl(effects, getOperation()->getResults(), + getInputOperands(), getOutputOperands()); } static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) { @@ -925,7 +929,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults // Check if there is any change to operands. if (newInputOperands.size() + newOutputOperands.size() == - static_cast(genericOp.getNumInputsAndOutputs())) + genericOp->getNumOperands()) return failure(); // Create the new op with the body being empty. @@ -977,35 +981,34 @@ private: SmallVector &newIndexingMaps) const { llvm::SmallDenseMap origToNewPos; llvm::SmallDenseMap, unsigned> dedupedInputs; - for (const auto &inputOpOperand : - llvm::enumerate(genericOp.getInputOperands())) { + for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) { + OpOperand *inputOpOperand = en.value(); // Check if operand is dead and if dropping the indexing map makes the // loops to shape computation invalid. - if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) { + if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) { // Add the current operands to the list of potentially droppable // operands. If it cannot be dropped, this needs to be popped back. - droppedOpOperands.push_back(inputOpOperand.value()); + droppedOpOperands.push_back(inputOpOperand); if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) continue; droppedOpOperands.pop_back(); } // Check if this operand is a duplicate. - AffineMap indexingMap = - genericOp.getMatchingIndexingMap(inputOpOperand.value()); + AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand); auto it = dedupedInputs.find( - std::make_pair(inputOpOperand.value()->get(), indexingMap)); + std::make_pair(inputOpOperand->get(), indexingMap)); if (it != dedupedInputs.end()) { - origToNewPos[inputOpOperand.index()] = it->second; - droppedOpOperands.push_back(inputOpOperand.value()); + origToNewPos[en.index()] = it->second; + droppedOpOperands.push_back(inputOpOperand); continue; } // This is a preserved argument. - origToNewPos[inputOpOperand.index()] = newInputOperands.size(); - dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] = + origToNewPos[en.index()] = newInputOperands.size(); + dedupedInputs[{inputOpOperand->get(), indexingMap}] = newInputOperands.size(); - newInputOperands.push_back(inputOpOperand.value()->get()); + newInputOperands.push_back(inputOpOperand->get()); newIndexingMaps.push_back(indexingMap); } return origToNewPos; @@ -1026,12 +1029,10 @@ private: // If the op doesnt have tensor semantics, keep all the outputs as // preserved. if (!genericOp.hasTensorSemantics()) { - for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getOutputOperands())) { - origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); - newOutputOperands.push_back(outputOpOperand.value()->get()); - newIndexingMaps.push_back( - genericOp.getMatchingIndexingMap(outputOpOperand.value())); + for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) { + origToNewPos[en.index()] = newOutputOperands.size(); + newOutputOperands.push_back(en.value()->get()); + newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value())); } return origToNewPos; } @@ -1347,7 +1348,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { } void MapOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, getInputs(), getOutputs()); + printCommonStructuredOpParts(p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); p.printOptionalAttrDict((*this)->getAttrs()); p << "("; @@ -1380,7 +1382,7 @@ LogicalResult MapOp::verify() { // The shape of each input must match the shape of the output. auto outputShape = - getOutputs().front().getType().cast().getShape(); + getOutputOperand(0)->get().getType().cast().getShape(); for (Type inputArgType : TypeRange{getInputs()}) { auto inputElemShape = inputArgType.cast().getShape(); if (inputElemShape != outputShape) { @@ -1409,10 +1411,8 @@ ArrayAttr MapOp::getIndexingMaps() { void MapOp::getEffects( SmallVectorImpl> &effects) { - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, - outputBuffers); + getGenericEffectsImpl(effects, getOperation()->getResults(), + getInputOperands(), getOutputOperands()); } //===----------------------------------------------------------------------===// @@ -1458,10 +1458,8 @@ ArrayAttr ReduceOp::getIndexingMaps() { void ReduceOp::getEffects( SmallVectorImpl> &effects) { - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); - getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers, - outputBuffers); + getGenericEffectsImpl(effects, getOperation()->getResults(), + getInputOperands(), getOutputOperands()); } static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, @@ -1500,7 +1498,8 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, } void ReduceOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, getInputs(), getOutputs()); + printCommonStructuredOpParts(p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); @@ -1584,10 +1583,11 @@ LogicalResult ReduceOp::verify() { } // Check that the last block arguments match the element type of the outputs. - for (auto [output, bbArg] : llvm::zip( - getOutputs(), block->getArguments().take_back(getNumOutputs()))) { + for (auto [output, bbArg] : + llvm::zip(getOutputOperands(), + block->getArguments().take_back(getNumOutputs()))) { auto outputElementType = - output.getType().cast().getElementType(); + output->get().getType().cast().getElementType(); if (outputElementType != bbArg.getType()) return emitOpError() << "output element type " << outputElementType @@ -1751,14 +1751,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern { LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand &opOperand : op->getOpOperands()) { // Linalg "inputs" may be either tensor or memref type. // tensor<0xelt_type> is a convention that may not always mean // "0 iterations". Only erase in cases we see memref<...x0x...>. - auto mt = opOperand->get().getType().dyn_cast(); + auto mt = opOperand.get().getType().dyn_cast(); if (!mt) continue; - if (llvm::is_contained(op.getShape(opOperand), 0)) { + if (llvm::is_contained(op.getShape(&opOperand), 0)) { rewriter.eraseOp(op); return success(); } @@ -1774,10 +1774,10 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern { PatternRewriter &rewriter) const override { // If no operand comes from a tensor::CastOp and can be folded then fail. bool hasTensorCastOperand = - llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - if (opOperand->get().isa()) + llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { + if (opOperand.get().isa()) return false; - auto castOp = opOperand->get().getDefiningOp(); + auto castOp = opOperand.get().getDefiningOp(); return castOp && canFoldIntoConsumerOp(castOp); }); if (!hasTensorCastOperand) @@ -1788,18 +1788,17 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern { SmallVector newOperands; newOperands.reserve(op->getNumOperands()); // Inputs may fold. - for (OpOperand *opOperand : op.getInputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); + for (auto *input : op.getInputOperands()) { + auto tensorCastOp = input->get().getDefiningOp(); newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.getSource() - : opOperand->get()); + : input->get()); } // Init tensors may fold, in which case the resultType must also change. - for (OpOperand *opOperand : op.getOutputOperands()) { - auto tensorCastOp = opOperand->get().getDefiningOp(); + for (auto *output : op.getOutputOperands()) { + auto tensorCastOp = output->get().getDefiningOp(); bool fold = canFoldIntoConsumerOp(tensorCastOp); - newOperands.push_back(fold ? tensorCastOp.getOperand() - : opOperand->get()); + newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get()); newResultTypes.push_back(newOperands.back().getType()); } // Clone op. @@ -1858,8 +1857,8 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern { OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber); Value newOperand = rewriter.create(loc, resultType, outOperand->get()); - SmallVector newOperands = linalgOp.getInputOperands(); - SmallVector outputOperands = linalgOp.getOutputOperands(); + SmallVector newOperands{linalgOp.getInputOperands()}; + SmallVector outputOperands{linalgOp.getOutputOperands()}; outputOperands[resultNumber] = newOperand; newOperands.append(outputOperands.begin(), outputOperands.end()); @@ -1882,14 +1881,14 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern { /// For each of the operand in `operands` this function maps the static sizes of /// dimensions to their affine dim expressions. -static void populateMap(LinalgOp linalgOp, ArrayRef operands, +static void populateMap(LinalgOp linalgOp, MutableArrayRef operands, llvm::DenseMap &affineExprToSize) { - for (OpOperand *opOperand : operands) { - if (linalgOp.isScalar(opOperand)) + for (OpOperand &opOperand : operands) { + if (linalgOp.isScalar(&opOperand)) continue; - Value src = opOperand->get(); + Value src = opOperand.get(); auto sourceType = src.getType().cast(); - auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand); + auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand); // Get the `sourceShape` of the `sourceType`. If the operand is a result of // `tensor.cast` operation and source of the cast operation has a static @@ -1932,7 +1931,7 @@ static void createNewOperandWithStaticSizes( return; auto sourceType = src.getType().cast(); Type resultType = sourceType; - if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) { + if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) { resultTypes.push_back(resultType); return; } @@ -1965,7 +1964,7 @@ static void createNewOperandWithStaticSizes( unsigned index = opOperand->getOperandNumber(); newOperands[index] = newOperand; } - if (linalgOp.isOutputTensor(opOperand)) + if (linalgOp.isOutput(opOperand)) resultTypes.push_back(resultType); } @@ -1992,8 +1991,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern { // For each of the affine dim expression, check if the size is known. If // known add that in the map. - populateMap(linalgOp, linalgOp.getInputAndOutputOperands(), - affineExprToSize); + populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize); SmallVector newOperands; SmallVector resultTypes; @@ -2001,12 +1999,12 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern { // `changeNeeded` is `false` if the operands of `linalgOp` require no // change in their types. bool changeNeeded = false; - newOperands.reserve(linalgOp.getNumInputsAndOutputs()); + newOperands.reserve(linalgOp->getNumOperands()); resultTypes.reserve(linalgOp.getNumOutputs()); // Iterate over all the operands and update the static sizes. - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - createNewOperandWithStaticSizes(loc, rewriter, opOperand, + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + createNewOperandWithStaticSizes(loc, rewriter, &opOperand, affineExprToSize, linalgOp, newOperands, resultTypes, changeNeeded); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp index 2cf8a57..383a926 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp @@ -112,14 +112,14 @@ struct BubbleUpExtractSliceOpPattern tileSizes[position] = sliceOp.getMixedSizes()[result.index()]; } - SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector valuesToTile = linalgOp->getOperands(); SmallVector tiledOperands = makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile, tileOffsets, tileSizes, sizeBounds, /*omitPartialTileCheck=*/true); SmallVector resultTensorTypes; - for (OpOperand *opOperand : linalgOp.getOutputTensorOperands()) + for (OpOperand *opOperand : linalgOp.getOutputOperands()) resultTensorTypes.push_back( tiledOperands[opOperand->getOperandNumber()].getType()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index abc430f..bb38004 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -118,7 +118,7 @@ struct LinalgOpInterface auto genericOp = cast(op); // The i-th "out" tensor may alias with the i-th OpResult. - if (genericOp.isOutputTensor(&opOperand)) + if (genericOp.isOutput(&opOperand)) return {genericOp.getTiedOpResult(&opOperand)}; return {}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index 58a54ba..a21e0fc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -68,17 +68,17 @@ public: if (!outputType || !outputType.hasStaticShape()) return failure(); - if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) { - return operand->get().getType().isa(); + if (!llvm::all_of(genericOp.getInputs(), [](Value input) { + return input.getType().isa(); })) return failure(); // Make sure all element types are the same. - auto getOperandElementType = [](OpOperand *operand) { - return operand->get().getType().cast().getElementType(); + auto getOperandElementType = [](Value value) { + return value.getType().cast().getElementType(); }; - if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(), - getOperandElementType))) + if (!llvm::all_equal( + llvm::map_range(genericOp->getOperands(), getOperandElementType))) return failure(); // We can only handle the case where we have int/float elements. @@ -114,15 +114,15 @@ public: // All inputs should be constants. int numInputs = genericOp.getNumInputs(); SmallVector inputValues(numInputs); - for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) { - if (!matchPattern(operand.value()->get(), - m_Constant(&inputValues[operand.index()]))) + for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) { + if (!matchPattern(en.value()->get(), + m_Constant(&inputValues[en.index()]))) return failure(); } // Identified this as a potential candidate for folding. Now check the // policy to see whether we are allowed to proceed. - for (auto *operand : genericOp.getInputOperands()) { + for (OpOperand *operand : genericOp.getInputOperands()) { if (!controlFn(operand)) return failure(); } @@ -171,8 +171,8 @@ public: APIntOrFloatArray computeFnInputs; auto inputShapes = llvm::to_vector<4>( - llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { - return operand->get().getType().cast().getShape(); + llvm::map_range(genericOp.getInputs(), [](Value value) { + return value.getType().cast().getShape(); })); // Given a `linearIndex`, remap it to a linear index to access linalg op diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index cebc978..327e8be 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -194,7 +194,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, } /// Create the peeled generic op with an empty body. - SmallVector outsOperands = genericOp.getOutputOperands(); + SmallVector outsOperands = genericOp.getOutputs(); outsOperands.append(newInitValues.begin(), newInitValues.end()); SmallVector resultTypes = llvm::to_vector(genericOp.getResultTypes()); resultTypes.append(newResultTypes.begin(), newResultTypes.end()); @@ -212,9 +212,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, PatternRewriter &rewriter) const { /// Append all results from the peeledGenericOps as `ins` operand for the /// residual generic op. - SmallVector residualGenericOpOperands = llvm::to_vector( - llvm::map_range(genericOp.getInputOperands(), - [](OpOperand *operand) { return operand->get(); })); + SmallVector residualGenericOpOperands = genericOp.getInputs(); unsigned origNumResults = genericOp.getNumResults(); unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults(); SmallVector extraIns; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index baef90c..acc0126 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -55,10 +55,9 @@ bool canBeDetensored(TensorType tensorType) { bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) { GenericOp genericOp = dyn_cast_or_null(op); return genericOp && - llvm::all_of( - genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) { - return !typeConverter.isLegal(opOperand->get().getType()); - }); + llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) { + return !typeConverter.isLegal(opOperand.get().getType()); + }); } /// A conversion patttern for detensoring `linalg.generic` ops. diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 361c85a..2fcb2ac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -377,21 +377,21 @@ struct ReplaceUnitExtents : public OpRewritePattern { SmallVector reassociationMaps; SmallVector newInputOutputTypes; bool doCanonicalization = false; - for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context); + for (OpOperand &opOperand : genericOp->getOpOperands()) { + auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context); if (replacementInfo) { reassociationMaps.push_back(replacementInfo->reassociation); newIndexingMaps.push_back(replacementInfo->indexMap); newInputOutputTypes.push_back(replacementInfo->type); doCanonicalization |= - replacementInfo->type != opOperand->get().getType(); + replacementInfo->type != opOperand.get().getType(); } else { // If replaceUnitExtents cannot handle this case, maintain the same // type, indexing map, and create a set of mappings representing an // identity matrix. - newInputOutputTypes.push_back(opOperand->get().getType()); - newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand)); - int64_t origRank = genericOp.getRank(opOperand); + newInputOutputTypes.push_back(opOperand.get().getType()); + newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand)); + int64_t origRank = genericOp.getRank(&opOperand); auto maps = llvm::to_vector<8>(llvm::map_range( llvm::seq(0, origRank), [&](int64_t dim) -> Attribute { return AffineMapAttr::get( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 05dce4c..80cef16 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -90,7 +90,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) { // Only allow fusing the producer of an input operand for now. // TODO: allow fusing the producer of an output operand. - if (!consumer.isInputTensor(fusedOperand)) + if (!consumer.isInput(fusedOperand)) return false; // Get the consumer index map. The number of results of the consumer index @@ -179,7 +179,7 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, } } // TODO: allow fusing the producer of an output operand. - assert(consumer.isInputTensor(fusedOperand) && + assert(consumer.isInput(fusedOperand) && "expected producer of input operand"); // 3. Consumer input operands up to consumerIdx (exclusive). for (BlockArgument bbArg : consumerBlock.getArguments().take_front( @@ -267,7 +267,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, auto producer = cast(producerResult.getOwner()); auto consumer = cast(fusedOperand->getOwner()); // TODO: allow fusing the producer of an output operand. - assert(consumer.isInputTensor(fusedOperand) && + assert(consumer.isInput(fusedOperand) && "expected producer of input operand"); // Compute the fused operands list and indexing maps. @@ -278,13 +278,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, fusedOutputOperands.reserve(producer.getNumOutputs() + consumer.getNumOutputs()); fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs()); - fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() + - consumer.getNumInputsAndOutputs()); + fusedIndexMaps.reserve(producer->getNumOperands() + + consumer->getNumOperands()); // In the following, numbering matches that of `generateFusedTensorOpRegion`. // 3. Consumer input operands/maps up to consumerIdx (exclusive). - SmallVector consumerInputs = consumer.getInputOperands(); - SmallVector::iterator it = - llvm::find(consumerInputs, fusedOperand); + auto consumerInputs = consumer.getInputOperands(); + auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) { + return operand == fusedOperand; + }); assert(it != consumerInputs.end() && "expected to find the consumer operand"); for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { fusedInputOperands.push_back(opOperand->get()); @@ -373,13 +374,13 @@ public: LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Find the first operand that is defined by another generic op on tensors. - for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - if (!areElementwiseOpsFusable(opOperand)) + for (OpOperand &opOperand : genericOp->getOpOperands()) { + if (!areElementwiseOpsFusable(&opOperand)) continue; - if (!controlFn(opOperand)) + if (!controlFn(&opOperand)) continue; - FailureOr fusedOp = fuseElementwiseOps(rewriter, opOperand); + FailureOr fusedOp = fuseElementwiseOps(rewriter, &opOperand); if (succeeded(fusedOp)) { auto replacements = fusedOp.value()->getResults().take_back(genericOp.getNumResults()); @@ -727,9 +728,9 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp, : collapsingReshapeOp.getSrc()); continue; } - if (genericOp.isInputTensor(opOperand)) { + if (auto opOperandType = + opOperand->get().getType().dyn_cast()) { AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - auto opOperandType = opOperand->get().getType().cast(); RankedTensorType expandedOperandType = getExpandedType(opOperandType, indexingMap, expansionInfo); if (expandedOperandType != opOperand->get().getType()) { @@ -833,7 +834,7 @@ public: LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { + for (OpOperand *opOperand : genericOp.getInputOperands()) { tensor::CollapseShapeOp reshapeOp = opOperand->get().getDefiningOp(); if (!reshapeOp) @@ -1494,17 +1495,17 @@ public: LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { + for (OpOperand &opOperand : genericOp->getOpOperands()) { tensor::ExpandShapeOp reshapeOp = - opOperand->get().getDefiningOp(); + opOperand.get().getDefiningOp(); if (!reshapeOp) continue; SmallVector collapsableIterationDims = - getCollapsableIterationSpaceDims(genericOp, opOperand, + getCollapsableIterationSpaceDims(genericOp, &opOperand, reshapeOp.getReassociationIndices()); if (collapsableIterationDims.empty() || - !controlFoldingReshapes(opOperand)) { + !controlFoldingReshapes(&opOperand)) { continue; } @@ -1614,7 +1615,7 @@ public: SmallVector fusedIndexMaps; SmallVector fusedOperands; SmallVector fusedLocs{genericOp.getLoc()}; - fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); + fusedIndexMaps.reserve(genericOp->getNumOperands()); fusedOperands.reserve(genericOp.getNumInputs()); fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs()); for (OpOperand *inputOperand : genericOp.getInputOperands()) { @@ -1640,7 +1641,7 @@ public: Value scalarConstant = rewriter.create( def->getLoc(), constantAttr, constantAttr.getType()); - SmallVector outputOperands = genericOp.getOutputOperands(); + SmallVector outputOperands = genericOp.getOutputs(); auto fusedOp = rewriter.create( rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), /*inputs=*/fusedOperands, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 5738d51..5c2c987 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -68,7 +68,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, bool fromSubViewOpOnly = false) { // Iterate over the inputs and outputs in order. // Extract the subranges from the linearized ranges. - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { + for (OpOperand &opOperand : op->getOpOperands()) { // The method `getRangeFromOperandShape` requires using SubViewOp or // ExtractSliceOps. If the value isn't defined from there continue. // todo: The method should be adapted to get the values from @@ -77,12 +77,12 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, // `std` dialect and add the method to `ViewInterface`. if (fromSubViewOpOnly && !isa_and_nonnull( - opOperand->get().getDefiningOp())) + opOperand.get().getDefiningOp())) continue; - AffineMap map = op.getMatchingIndexingMap(opOperand); + AffineMap map = op.getMatchingIndexingMap(&opOperand); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: " - << opOperand->getOperandNumber() << "\n"); + << opOperand.getOperandNumber() << "\n"); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange map: " << map << "\n"); SmallVector shapeRanges(map.getNumResults(), nullptr); @@ -94,8 +94,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " << loopDepth << "\n"); LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: " - << opOperand->get() << "\n"); - return ShapeDimension{opOperand->get(), + << opOperand.get() << "\n"); + return ShapeDimension{opOperand.get(), static_cast(en.index())}; } } @@ -104,7 +104,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, } static SmallVector getTiledOperands(LinalgOp producer) { - return producer.getInputAndOutputOperands(); + return producer->getOperands(); } /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` @@ -137,7 +137,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, } SmallVector clonedShapes; - clonedShapes.reserve(producer.getNumInputsAndOutputs()); + clonedShapes.reserve(producer->getNumOperands()); // Compute subranges for all tensor input/output operands. clonedShapes.append(makeTiledShapes( @@ -150,15 +150,18 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, // fully dynamic at construction time. SmallVector resultTypes; resultTypes.reserve(producer->getNumResults()); - for (RankedTensorType t : producer.getOutputTensorTypes()) { - unsigned rank = t.getRank(); + for (OpOperand *operand : producer.getOutputOperands()) { + auto tensorType = operand->get().getType().dyn_cast(); + if (!tensorType) + continue; + unsigned rank = tensorType.getRank(); SmallVector staticOffsetsVector( rank, ShapedType::kDynamicStrideOrOffset); SmallVector staticSizesVector(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector( rank, ShapedType::kDynamicStrideOrOffset); resultTypes.push_back(tensor::ExtractSliceOp::inferResultType( - t.cast(), staticOffsetsVector, staticSizesVector, + tensorType, staticOffsetsVector, staticSizesVector, staticStridesVector)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 2451c79..1dd6c35 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -161,7 +161,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; } erase_value(tileIvs, OpFoldResult()); - SmallVector tiledOperands = producerOp.getInputAndOutputOperands(); + SmallVector tiledOperands = producerOp->getOperands(); tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, tileSizes, producerLoopBounds, /**omitPartialTileCheck=*/false); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index d656e92..ea6ce39 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -50,19 +50,19 @@ FailureOr mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, if (failed(generalizeNamedOpPrecondition(linalgOp))) return rewriter.notifyMatchFailure(linalgOp, "preconditions not met"); - SmallVector inputOperands = linalgOp.getInputOperands(); - SmallVector outputOperands = linalgOp.getOutputOperands(); + SmallVector inputs = linalgOp.getInputOperands(); + SmallVector outputs = linalgOp.getOutputOperands(); SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector iterators = linalgOp.getIteratorTypesArray(); - SmallVector resultTypes = linalgOp.getOutputTensorTypes(); - SmallVector types(resultTypes.begin(), resultTypes.end()); + SmallVector resultTypes = linalgOp.hasTensorSemantics() + ? TypeRange(ValueRange(outputs)) + : TypeRange{}; // All named ops have a region attached that can be inlined. assert(linalgOp->getNumRegions() == 1 && "expect named op to have one region attached"); - GenericOp genericOp = - rewriter.create(linalgOp.getLoc(), types, inputOperands, - outputOperands, indexingMaps, iterators); + GenericOp genericOp = rewriter.create( + linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators); rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); rewriter.replaceOp(linalgOp, genericOp->getResults()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 2ca15dc..7515e30 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -111,7 +111,7 @@ private: static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) { for (OpOperand &use : padOp.getResult().getUses()) { auto linalgUser = dyn_cast(use.getOwner()); - if (!linalgUser || !linalgUser.isInputTensor(&use)) { + if (!linalgUser || !linalgUser.isInput(&use)) { LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp) << "\nthat is not an input tensor of a LinalgOp, " << "cannot hoist\n" diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp index 04e94b1..4ea889d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -43,7 +43,7 @@ struct InlineScalarOperands : public OpRewritePattern { SmallVector newOperands; for (OpOperand *opOperand : genericOp.getInputOperands()) { AffineMap map = genericOp.getMatchingIndexingMap(opOperand); - if (genericOp.isInputTensor(opOperand) && map.isConstant()) { + if (genericOp.isInput(opOperand) && map.isConstant()) { scalarOperands.emplace_back(opOperand->getOperandNumber()); } else { newIndexingMaps.emplace_back(map); @@ -58,7 +58,7 @@ struct InlineScalarOperands : public OpRewritePattern { newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand)); Location loc = genericOp->getLoc(); - SmallVector outputOperands = genericOp.getOutputOperands(); + SmallVector outputOperands = genericOp.getOutputs(); auto newOp = rewriter.create( loc, genericOp->getResultTypes(), newOperands, outputOperands, newIndexingMaps, genericOp.getIteratorTypesArray()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index 8641e11..a745387 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -67,8 +67,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, // 2. Compute the interchanged indexing maps. SmallVector newIndexingMaps; - for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { - AffineMap m = genericOp.getMatchingIndexingMap(opOperand); + for (OpOperand &opOperand : genericOp->getOpOperands()) { + AffineMap m = genericOp.getMatchingIndexingMap(&opOperand); if (!permutationMap.isEmpty()) m = m.compose(permutationMap); newIndexingMaps.push_back(m); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 3052a4d..4fc9149 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -131,7 +131,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, assert(linalgOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); SmallVector indexedValues; - indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + indexedValues.reserve(linalgOp->getNumOperands()); auto allIvsPlusDims = SmallVector(allIvs.begin(), allIvs.end()); @@ -161,7 +161,9 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, // 3. Emit store. SmallVector, 8> indexing; SmallVector outputBuffers; - for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) { + for (OpOperand *outputOperand : linalgOp.getOutputOperands()) { + if (!outputOperand->get().getType().isa()) + continue; indexing.push_back(makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(outputOperand), allIvsPlusDims)); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 17d74fa..0995b01 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -145,15 +145,15 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand"); auto vUseFullTileBuffers = options.useFullTileBuffers.value_or(llvm::SmallBitVector()); - vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(), + vUseFullTileBuffers.resize(linalgOp->getNumOperands(), options.useFullTileBuffersDefault); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - int64_t operandNumber = opOperand->getOperandNumber(); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + int64_t operandNumber = opOperand.getOperandNumber(); if (options.operandsToPromote && !options.operandsToPromote->count(operandNumber)) continue; - Operation *op = opOperand->get().getDefiningOp(); + Operation *op = opOperand.get().getDefiningOp(); if (auto sv = dyn_cast_or_null(op)) { subViews[operandNumber] = sv; useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber]; @@ -326,13 +326,13 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op, // operands are not views. This is to support cases such as FillOp taking // extra scalars etc. Keep a reference to output buffers; SmallVector opViews; - opViews.reserve(op.getNumInputsAndOutputs()); + opViews.reserve(op->getNumOperands()); SmallVector, 8> writebackViews; writebackViews.reserve(promotedBuffersAndViews->size()); - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { - int64_t operandNumber = opOperand->getOperandNumber(); + for (OpOperand &opOperand : op->getOpOperands()) { + int64_t operandNumber = opOperand.getOperandNumber(); if (options.subViews.count(operandNumber) != 0) { - if (options.useFullTileBuffers[opOperand->get()]) + if (options.useFullTileBuffers[opOperand.get()]) opViews.push_back( (*promotedBuffersAndViews)[operandNumber].fullLocalView); else @@ -340,10 +340,10 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op, (*promotedBuffersAndViews)[operandNumber].partialLocalView); if (operandNumber >= op.getNumInputs()) writebackViews.emplace_back(std::make_pair( - opOperand->get(), + opOperand.get(), (*promotedBuffersAndViews)[operandNumber].partialLocalView)); } else { - opViews.push_back(opOperand->get()); + opViews.push_back(opOperand.get()); } } op->setOperands(0, opViews.size(), opViews); @@ -371,12 +371,12 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op, if (!linalgOp || !linalgOp.hasBufferSemantics()) return failure(); // Check that at least one of the requested operands is indeed a subview. - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand &opOperand : linalgOp->getOpOperands()) { auto sv = - isa_and_nonnull(opOperand->get().getDefiningOp()); + isa_and_nonnull(opOperand.get().getDefiningOp()); if (sv) { if (!options.operandsToPromote || - options.operandsToPromote->count(opOperand->getOperandNumber())) + options.operandsToPromote->count(opOperand.getOperandNumber())) return success(); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 7df65c8..92d04c1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -214,7 +214,6 @@ FailureOr mlir::linalg::splitReduction( // from the previous op. unsigned intermRank = newOutputShape.size(); AffineMap inputMap = b.getMultiDimIdentityMap(intermRank); - SmallVector outputOperands = op.getOutputOperands(); SmallVector reductionIteratorTypes; SmallVector exprs; for (unsigned i : llvm::seq(0, intermRank)) { @@ -230,7 +229,8 @@ FailureOr mlir::linalg::splitReduction( auto reduction = b.create( loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), - outputOperands, reductionMaps, reductionIteratorTypes, + SmallVector{op.getOutputOperands()}, reductionMaps, + reductionIteratorTypes, [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { Operation *clonedReductionOp = b.clone(*reductionOp); clonedReductionOp->setOperand(0, inputs[0]); @@ -341,8 +341,8 @@ FailureOr mlir::linalg::splitReductionByScaling( SmallVector emptyOrAllocTensorOps; SmallVector fillOps; fillOps.reserve(op.getNumOutputs()); - for (auto it : llvm::zip(op.getOutputs(), neutralElements)) { - Value rankedTensor = std::get<0>(it); + for (auto it : llvm::zip(op.getOutputOperands(), neutralElements)) { + Value rankedTensor = std::get<0>(it)->get(); auto t = rankedTensor.getType().cast(); RankedTensorType newT = RankedTensorType::Builder(t).insertDim( reductionDimSize / splitFactor, insertSplitDimension); @@ -366,7 +366,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // Step 2. Reindex / expand indexing maps. // Reindex existing input indexings: k -> k * splitFactor + k'. SmallVector newMaps; - newMaps.reserve(op.getNumInputsAndOutputs() + 1); + newMaps.reserve(op->getNumOperands() + 1); for (OpOperand *o : op.getInputOperands()) newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor)); // Provision a new indexing for the shape-only tensor. @@ -384,7 +384,7 @@ FailureOr mlir::linalg::splitReductionByScaling( // Step 3. Handle operands. // Compute the new input tensors. - auto newInputs = llvm::to_vector<4>(op.getInputs()); + SmallVector newInputs(op.getInputOperands()); // Add a single shape-only tensor to carry the dimensions without resorting to // more complex inversions. newInputs.push_back(b.create( @@ -413,10 +413,10 @@ FailureOr mlir::linalg::splitReductionByScaling( // TODO: all results can be handled in a single GenericOp, when // multi-reduction support is available. SmallVector results; - for (auto it : - llvm::zip(genericOp->getResults(), op.getOutputs(), combinerOps)) { + for (auto it : llvm::zip(genericOp->getResults(), op.getOutputOperands(), + combinerOps)) { Value reindexedOutput = std::get<0>(it); - Value originalOutput = std::get<1>(it); + Value originalOutput = std::get<1>(it)->get(); auto originalOutputType = originalOutput.getType().cast(); Operation *combinerOp = std::get<2>(it); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index dd04d00..b66a718 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -503,7 +503,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef tileSizes, // Tile the `operandValuesToUse` that either match the `op` operands // themselves or the tile loop arguments forwarding them. assert(operandValuesToUse.size() == - static_cast(op.getNumInputsAndOutputs()) && + static_cast(op->getNumOperands()) && "expect the number of operands and inputs and outputs to match"); SmallVector valuesToTile = operandValuesToUse; SmallVector sizeBounds = diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 66d55dc..d88b2c5 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -125,14 +125,12 @@ struct LinalgOpTilingInterface // specified could lead to out of bounds accesses. Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); - SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector valuesToTile = linalgOp->getOperands(); SmallVector tiledOperands = makeTiledShapes( b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); - SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( - linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) { - return tiledOperands[opOperand->getOperandNumber()].getType(); - })); + SmallVector resultTensorTypes = + getTensorOutputTypes(linalgOp, tiledOperands); Operation *tiledOp = linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); @@ -222,23 +220,23 @@ struct LinalgOpTilingInterface return op->emitOpError("expected operation to have buffer semantics"); SmallVector indexedValues; - indexedValues.reserve(linalgOp.getNumInputsAndOutputs()); + indexedValues.reserve(linalgOp->getNumOperands()); Location linalgOpLoc = op->getLoc(); /// Load the data corresponding to the block arguments that /// represent input operands. - for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) { - if (!linalgOp.payloadUsesValueFromOperand(operand)) { + for (OpOperand &operand : linalgOp->getOpOperands()) { + if (!linalgOp.payloadUsesValueFromOperand(&operand)) { indexedValues.push_back(nullptr); continue; } - if (linalgOp.isScalar(operand)) { - indexedValues.push_back(operand->get()); + if (linalgOp.isScalar(&operand)) { + indexedValues.push_back(operand.get()); continue; } SmallVector indices = getIndicesForAccess( - builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs); + builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); Value load = - builder.create(linalgOpLoc, operand->get(), indices); + builder.create(linalgOpLoc, operand.get(), indices); indexedValues.push_back(load); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 8eb41c5..eee454b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -203,10 +203,10 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, b.setInsertionPointAfter(opToPad); // Make a copy of the shaped operands and update it. SmallVector newOperands; - newOperands.reserve(opToPad.getNumInputsAndOutputs()); - for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) { + newOperands.reserve(opToPad->getNumOperands()); + for (OpOperand &opOperand : opToPad->getOpOperands()) { FailureOr paddedOperand = padOperandToSmallestStaticBoundingBox( - b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings); + b, opToPad, &opOperand, paddingDimensions, paddingValues, packPaddings); // Exit if `paddingDimensions` cannot be bounded statically. if (failed(paddedOperand)) return failure(); @@ -327,15 +327,15 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite( // Hoist the padding. for (const auto &en : enumerate(options.hoistPaddings)) { - if (static_cast(en.index()) >= paddedOp.getNumInputsAndOutputs()) + if (static_cast(en.index()) >= paddedOp->getNumOperands()) break; - OpOperand *opOperand = &paddedOp->getOpOperand(en.index()); - auto padOp = opOperand->get().getDefiningOp(); + OpOperand &opOperand = paddedOp->getOpOperand(en.index()); + auto padOp = opOperand.get().getDefiningOp(); if (!padOp || en.value() == 0) continue; // Fail hoisting if the operand shape is not fully static. - if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic)) + if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic)) return failure(); tensor::PadOp hoistedOp; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 5623a16..2b70155 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -459,35 +459,35 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, // 3. Turn all BBArgs into vector.transfer_read / load. Location loc = linalgOp.getLoc(); Value zero = b.create(loc, 0); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber()); - if (linalgOp.isScalar(opOperand)) { - bvm.map(bbarg, opOperand->get()); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber()); + if (linalgOp.isScalar(&opOperand)) { + bvm.map(bbarg, opOperand.get()); continue; } VectorType readType; AffineMap map; // TODO: can we keep this simplification? - // if (linalgOp.getShape(opOperand).empty()) { + // if (linalgOp.getShape(&opOperand).empty()) { // readType = VectorType::get({}, bbarg.getType()); // } else { - if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) { + if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) { map = inverseAndBroadcastProjectedPermutation( - linalgOp.getMatchingIndexingMap(opOperand)); + linalgOp.getMatchingIndexingMap(&opOperand)); readType = VectorType::get(commonVectorShape, - getElementTypeOrSelf(opOperand->get())); + getElementTypeOrSelf(opOperand.get())); } else { map = inversePermutation( - reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand))); - readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)), - getElementTypeOrSelf(opOperand->get())); + reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand))); + readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)), + getElementTypeOrSelf(opOperand.get())); } // } - auto shape = linalgOp.getShape(opOperand); + auto shape = linalgOp.getShape(&opOperand); SmallVector indices(shape.size(), zero); Value readValue = b.create( - loc, readType, opOperand->get(), indices, map); + loc, readType, opOperand.get(), indices, map); // Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readValue.getType().cast().getRank() == 0) @@ -495,7 +495,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp, LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue); bvm.map(bbarg, readValue); - bvm.map(opOperand->get(), readValue); + bvm.map(opOperand.get(), readValue); } SmallVector hooks; @@ -1342,9 +1342,9 @@ struct Conv1DGenerator : public StructuredGenerator { // Determine whether `linalgOp` can be generated with this generator if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1) return; - lhsShaped = linalgOp.getInputs()[0]; - rhsShaped = linalgOp.getInputs()[1]; - resShaped = linalgOp.getOutputs()[0]; + lhsShaped = linalgOp.getInputOperand(0)->get(); + rhsShaped = linalgOp.getInputOperand(1)->get(); + resShaped = linalgOp.getOutputOperand(0)->get(); lhsShapedType = lhsShaped.getType().dyn_cast(); rhsShapedType = rhsShaped.getType().dyn_cast(); resShapedType = resShaped.getType().dyn_cast(); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 999034b..119c3db 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -490,17 +490,18 @@ void GenerateLoopNest::doit( assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) && "expected as many entries for proc info as number of loops, even if " "they are null entries"); - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() + ? SmallVector{} + : linalgOp.getOutputOperands(); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); LoopNest loopNest = mlir::scf::buildLoopNest( b, loc, lbs, ubs, steps, iterArgInitValues, [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) { - assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() && + assert(iterArgs.size() == iterArgInitValues.size() && "expect the number of output tensors and iter args to match"); - SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); + SmallVector operandValuesToUse = linalgOp->getOperands(); if (!iterArgs.empty()) { operandValuesToUse = linalgOp.getInputOperands(); operandValuesToUse.append(iterArgs.begin(), iterArgs.end()); @@ -530,7 +531,9 @@ void GenerateLoopNest::doit( ValueRange)> bodyBuilderFn, ArrayRef /*procInfo*/) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() + ? SmallVector{} + : linalgOp.getOutputOperands(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); @@ -546,9 +549,8 @@ void GenerateLoopNest::doit( mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps, [&](OpBuilder &b, Location loc, ValueRange ivs) { - SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); - bodyBuilderFn(b, loc, ivs, operandValuesToUse); + bodyBuilderFn(b, loc, ivs, + linalgOp->getOperands()); }); } @@ -695,7 +697,9 @@ void GenerateLoopNest::doit( ValueRange)> bodyBuilderFn, ArrayRef procInfo) { - SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); + SmallVector iterArgInitValues = linalgOp.hasBufferSemantics() + ? SmallVector{} + : linalgOp.getOutputOperands(); assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges. assert(iteratorTypes.size() >= loopRanges.size() && @@ -725,9 +729,7 @@ void GenerateLoopNest::doit( generateParallelLoopNest( b, loc, lbs, ubs, steps, iteratorTypes, procInfo, [&](OpBuilder &b, Location loc, ValueRange ivs) { - SmallVector operandValuesToUse = - linalgOp.getInputAndOutputOperands(); - bodyBuilderFn(b, loc, ivs, operandValuesToUse); + bodyBuilderFn(b, loc, ivs, linalgOp->getOperands()); }, ivs); @@ -905,10 +907,10 @@ SmallVector computeTileSizes(OpBuilder &b, Location loc, } SmallVector getTensorOutputTypes(LinalgOp op, ValueRange operands) { - // TODO: use an interface/adaptor to avoid leaking position in - // `tiledOperands`. + if (op.hasBufferSemantics()) + return {}; return llvm::to_vector( - llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) { + llvm::map_range(op.getOutputOperands(), [&](OpOperand *opOperand) { return operands[opOperand->getOperandNumber()].getType(); })); } @@ -916,11 +918,13 @@ SmallVector getTensorOutputTypes(LinalgOp op, ValueRange operands) { SmallVector insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results) { + if (op.hasBufferSemantics()) + return {}; SmallVector tensorResults; tensorResults.reserve(results.size()); // Insert a insert_slice for each output tensor. unsigned resultIdx = 0; - for (OpOperand *opOperand : op.getOutputTensorOperands()) { + for (OpOperand *opOperand : op.getOutputOperands()) { // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. Value outputTensor = operands[opOperand->getOperandNumber()]; @@ -958,23 +962,26 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, computeTileSizes(builder, loc, tileSizes, sizeBounds); assert(static_cast(valuesToTile.size()) == - linalgOp.getNumInputsAndOutputs() && + linalgOp->getNumOperands() && "expected one value to tile for every operand"); SmallVector> allSliceParams; allSliceParams.reserve(valuesToTile.size()); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - Value shapedOp = valuesToTile[opOperand->getOperandNumber()]; + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + Value shapedOp = valuesToTile[opOperand.getOperandNumber()]; LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp); - AffineMap map = linalgOp.getMatchingIndexingMap(opOperand); + AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand); // Use `opOperand` as is if it is not tiled and not an output tensor. Having // an extract/insert slice pair for all output tensors simplifies follow up // transformations such as padding and bufferization since the // extract/insert slice pairs make the accessed iteration argument // subdomains explicit. - if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) { + + Type operandType = opOperand.get().getType(); + if (!isTiled(map, tileSizes) && !(operandType.isa() && + linalgOp.isOutput(&opOperand))) { allSliceParams.push_back(llvm::None); - LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: " - << opOperand->get().getType() << "\n"); + LLVM_DEBUG(llvm::dbgs() + << ": not tiled: use shape: " << operandType << "\n"); continue; } LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n"); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 2458dab..73f428a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -105,8 +105,7 @@ static bool isZeroYield(GenericOp op) { auto yieldOp = cast(op.getRegion().front().getTerminator()); if (auto arg = yieldOp.getOperand(0).dyn_cast()) { if (arg.getOwner()->getParentOp() == op) { - OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()]; - return isZeroValue(t->get()); + return isZeroValue(op->getOperand(arg.getArgNumber())); } } return isZeroValue(yieldOp.getOperand(0)); @@ -242,8 +241,8 @@ public: return failure(); // Modify operand structure of producer and consumer. Location loc = prod.getLoc(); - SmallVector inputOps = prod.getInputOperands(); - SmallVector outputOps = op.getOutputOperands(); + SmallVector inputOps = prod.getInputs(); + SmallVector outputOps = op.getOutputs(); SmallVector fusedIndexMaps = prod.getIndexingMapsArray(); inputOps.push_back(op.getInputOperand(1 - other)->get()); fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp index 1418ed4..e512723 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -194,14 +194,14 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a, /// no annotations are found or inadmissible constructs occur. static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) { bool annotated = false; - for (OpOperand *t : op.getInputAndOutputOperands()) { - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + for (OpOperand &t : op->getOpOperands()) { + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); if (enc) annotated = true; - assert(map.getNumResults() == op.getRank(t)); + assert(map.getNumResults() == op.getRank(&t)); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { - unsigned tensor = t->getOperandNumber(); + unsigned tensor = t.getOperandNumber(); AffineExpr a = map.getResult(toOrigDim(enc, d)); if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d))) return false; // inadmissible affine expression @@ -291,13 +291,13 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, std::vector inDegree(n, 0); // in-degree of each node. auto iteratorTypes = op.getIteratorTypesArray(); // Iterate over the indexing maps of every tensor in the tensor expression. - for (OpOperand *t : op.getInputAndOutputOperands()) { + for (OpOperand &t : op->getOpOperands()) { // Skip tensor during cycle resolution. - if (t == skip) + if (&t == skip) continue; // Get map and encoding. - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); assert(map.getNumDims() == n); // Skip dense tensor constraints when not requested. if (!(mask & SortMask::kIncludeDense) && !enc) @@ -314,7 +314,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op, // Push unrelated loops into sparse iteration space, so these // will be skipped more often. if (mask & SortMask::kIncludeUndef) { - unsigned tensor = t->getOperandNumber(); + unsigned tensor = t.getOperandNumber(); for (unsigned i = 0; i < n; i++) if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) || merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) { @@ -534,16 +534,16 @@ static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder, static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, linalg::GenericOp op) { Location loc = op.getLoc(); - assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1); + assert(op->getNumOperands() == op.getNumInputs() + 1); // For every tensor, find lower and upper bound on dimensions, set the // same bounds on loop indices, and obtain dense or sparse buffer(s). auto dynShape = {ShapedType::kDynamicSize}; SmallVector args; - for (OpOperand *t : op.getInputAndOutputOperands()) { - unsigned tensor = t->getOperandNumber(); - auto shape = op.getShape(t); - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + for (OpOperand &t : op->getOpOperands()) { + unsigned tensor = t.getOperandNumber(); + auto shape = op.getShape(&t); + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); // Scan all dimensions of current tensor. args.clear(); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { @@ -560,23 +560,23 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, MemRefType::get(dynShape, getIndexOverheadType(builder, enc)); auto dim = builder.getIndexAttr(d); codegen.pointers[tensor][idx] = - builder.create(loc, ptrTp, t->get(), dim); + builder.create(loc, ptrTp, t.get(), dim); codegen.indices[tensor][idx] = - builder.create(loc, indTp, t->get(), dim); + builder.create(loc, indTp, t.get(), dim); } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) { // Singleton dimension, fetch indices. auto indTp = MemRefType::get(dynShape, getIndexOverheadType(builder, enc)); auto dim = builder.getIndexAttr(d); codegen.indices[tensor][idx] = - builder.create(loc, indTp, t->get(), dim); + builder.create(loc, indTp, t.get(), dim); } else { // Dense dimension, nothing to fetch. assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense)); } // Find upper bound in current dimension. unsigned p = toOrigDim(enc, d); - Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p); + Value up = linalg::createOrFoldDimOp(builder, loc, t.get(), p); if (ShapedType::isDynamic(shape[p])) args.push_back(up); assert(codegen.highs[tensor][idx] == nullptr); @@ -585,21 +585,21 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder, // Perform the required bufferization. Dense inputs materialize // from the input tensors. Dense outputs need special handling. // Sparse inputs use sparse primitives to obtain the values. - Type elementType = getElementTypeOrSelf(t->get().getType()); + Type elementType = getElementTypeOrSelf(t.get().getType()); if (!enc) { // Non-annotated dense tensors. auto denseTp = MemRefType::get(shape, elementType); if (tensor < op.getNumInputs()) codegen.buffers[tensor] = - builder.create(loc, denseTp, t->get()); + builder.create(loc, denseTp, t.get()); else codegen.buffers[tensor] = genOutputBuffer(codegen, builder, op, denseTp, args); - } else if (t != codegen.sparseOut) { + } else if (&t != codegen.sparseOut) { // Annotated sparse tensors (not involved in output). auto sparseTp = MemRefType::get(dynShape, elementType); codegen.buffers[tensor] = - builder.create(loc, sparseTp, t->get()); + builder.create(loc, sparseTp, t.get()); } } } @@ -845,15 +845,15 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder, return val; } // Load during insertion. - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; - if (t == codegen.sparseOut) { + OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); + if (&t == codegen.sparseOut) { if (codegen.redCustom != -1u) - return genInsertionLoadReduce(merger, codegen, builder, op, t); - return genInsertionLoad(codegen, builder, op, t); + return genInsertionLoadReduce(merger, codegen, builder, op, &t); + return genInsertionLoad(codegen, builder, op, &t); } // Actual load. SmallVector args; - Value ptr = genSubscript(codegen, builder, op, t, args); + Value ptr = genSubscript(codegen, builder, op, &t, args); if (codegen.curVecLength > 1) return genVectorLoad(codegen, builder, ptr, args); return builder.create(op.getLoc(), ptr, args); @@ -1093,9 +1093,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, if (merger.exp(exp).kind == Kind::kTensor) { // Inspect tensor indices. bool atLevel = ldx == -1u; - OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; - auto map = op.getMatchingIndexingMap(t); - auto enc = getSparseTensorEncoding(t->get().getType()); + OpOperand &t = op->getOpOperand(merger.exp(exp).tensor); + auto map = op.getMatchingIndexingMap(&t); + auto enc = getSparseTensorEncoding(t.get().getType()); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(toOrigDim(enc, d)); if (!isInvariantAffine(codegen, a, ldx, atLevel)) @@ -1105,7 +1105,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, if (!atLevel) return; OpOperand *lhs = op.getOutputOperand(0); - if (lhs == t) { + if (lhs == &t) { // Start or end a scalarized reduction if (atStart) { Kind kind = merger.exp(last).kind; @@ -1288,9 +1288,9 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, /// This prevents effective vectorization. static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, unsigned idx) { - for (OpOperand *t : op.getInputAndOutputOperands()) { - if (!getSparseTensorEncoding(t->get().getType())) { - auto map = op.getMatchingIndexingMap(t); + for (OpOperand &t : op->getOpOperands()) { + if (!getSparseTensorEncoding(t.get().getType())) { + auto map = op.getMatchingIndexingMap(&t); for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { AffineExpr a = map.getResult(d); // Report non-unit stride if innermost index appears at an outer @@ -1856,7 +1856,7 @@ public: // information for all tensors to loop indices in the kernel. if (op.getNumOutputs() != 1) return failure(); - unsigned numTensors = op.getNumInputsAndOutputs(); + unsigned numTensors = op->getNumOperands(); unsigned numLoops = op.getNumLoops(); Merger merger(numTensors, numLoops); if (!findSparseAnnotations(merger, op)) diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp index 187a6c0..b8f6a93 100644 --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -910,10 +910,10 @@ Optional Merger::buildTensorExp(linalg::GenericOp op, Value v) { // argument is considered a tensor, indexed by the implicit loop // bounds. This includes rank-0 tensor arguments. if (arg.getOwner()->getParentOp() == op) { - OpOperand *t = op.getInputAndOutputOperands()[argN]; - if (!op.isScalar(t)) + OpOperand &t = op->getOpOperand(argN); + if (!op.isScalar(&t)) return addExp(kTensor, argN); - v = t->get(); // get scalar value + v = t.get(); // get scalar value } // Any other argument (marked as scalar argument for the generic op // or belonging to an enveloping op) is considered invariant. diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 43589f7..2062c65 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -275,7 +275,7 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) { // ----- // CHECK-LABEL: func @remove_deadargs_generic_basic -// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { +// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[GENERIC_OP:.*]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]] : tensor) // CHECK-SAME: outs({{.*}} : tensor) { diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 02471b1..f751ddf 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -121,26 +121,6 @@ func.func @generic(%arg0: memref, strided<[?, 1], offset: ?>> // CHECK-SAME: outs({{.*}} : memref>) // CHECK-SAME: {foo = 1 : i64} -func.func @generic_with_tensor_input(%arg0: tensor>, - %arg1: memref>) { - %cst = arith.constant 0.0 : f32 - linalg.generic #trait_0 - ins(%arg0, %cst : tensor>, f32) - outs(%arg1 : memref>) - attrs = {foo = 1} { - ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) : - linalg.yield %1 : f32 - } - return -} -// CHECK-LABEL: func @generic_with_tensor_input -// CHECK: linalg.generic { -// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"], -// CHECK-SAME: library_call = "some_external_function_name_1"} -// CHECK-SAME: ins({{.*}}, {{.*}} : tensor>, f32) -// CHECK-SAME: outs({{.*}} : memref>) -// CHECK-SAME: {foo = 1 : i64} - // ----- #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -300,27 +280,19 @@ func.func @generic_region(%arg0: memref, strided<[?, 1], offs func.func @named_ops(%a3: memref, %b3: memref, %c3: memref, %ta3: tensor, %tb3: tensor, %tc3: tensor) - -> (tensor, tensor) + -> (tensor) { linalg.batch_matmul ins(%a3, %b3: memref, memref) outs(%c3: memref) - linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor) - outs(%c3: memref) %res1 = linalg.batch_matmul ins(%ta3, %tb3: tensor, tensor) outs(%tc3: tensor) -> tensor - %res2 = linalg.batch_matmul - ins(%ta3, %b3: tensor, memref) - outs(%tc3: tensor) - -> tensor - return %res1, %res2 : tensor, tensor + return %res1 : tensor } // CHECK-LABEL: func @named_ops // CHECK: linalg.batch_matmul // CHECK: linalg.batch_matmul -// CHECK: linalg.batch_matmul -// CHECK: linalg.batch_matmul // ----- diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 0119516..5807726 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -26,7 +26,7 @@ static void addOperands(Operation *op, SetVector &operandSet) { return; TypeSwitch(op) .Case([&](linalg::LinalgOp linalgOp) { - SmallVector inputOperands = linalgOp.getInputOperands(); + SmallVector inputOperands{linalgOp.getInputOperands()}; operandSet.insert(inputOperands.begin(), inputOperands.end()); }) .Default([&](Operation *operation) { @@ -147,7 +147,7 @@ struct TestLinalgElementwiseFusion if (expandOp->hasOneUse()) { OpOperand &use = *expandOp->getUses().begin(); auto linalgOp = dyn_cast(use.getOwner()); - if (linalgOp && linalgOp.isOutputTensor(&use)) + if (linalgOp && linalgOp.isOutput(&use)) return true; } return false; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index c5b27c5..3c62496 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -38,14 +38,14 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { // Tile and Fuse for tensors inputs (TODO: all tensor operands). bool changed = false; for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - if (opOperand->get().getType().isa()) { + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + if (opOperand.get().getType().isa()) { // TODO: LinalgDependenceGraph should be able to update itself. // The current naive and expensive reconstruction of the graph should be // removed. linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); - auto info = fuseProducerOfBuffer(b, *opOperand, graph); + auto info = fuseProducerOfBuffer(b, opOperand, graph); if (failed(info)) continue; auto *originalOp = info->originalProducer.getOperation(); @@ -54,11 +54,11 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { std::find(linalgOps.begin(), linalgOps.end(), originalOp); *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); changed = true; - } else if (opOperand->get().getType().isa()) { + } else if (opOperand.get().getType().isa()) { // Tile and Fuse tensor input. - if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) + if (opOperand.getOperandNumber() >= linalgOp.getNumInputs()) continue; - auto info = fuseProducerOfTensor(b, *opOperand); + auto info = fuseProducerOfTensor(b, opOperand); if (failed(info)) continue; auto *originalOp = info->originalProducer.getOperation(); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 85c5f32..79ed068 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2835,9 +2835,10 @@ def TestLinalgConvOp : return ""; } - // To conform with interface requirement on operand naming. - mlir::ValueRange inputs() { return getInputs(); } - mlir::ValueRange outputs() { return getOutputs(); } + std::pair getOutputsPositionRange() { + int64_t getNumOperands = this->getNumOperands(); + return {getNumOperands - 1, getNumOperands}; + } }]; } @@ -2894,9 +2895,10 @@ def TestLinalgFillOp : return ""; } - // To conform with interface requirement on operand naming. - mlir::ValueRange inputs() { return getInputs(); } - mlir::ValueRange outputs() { return getOutputs(); } + std::pair getOutputsPositionRange() { + int64_t getNumOperands = this->getNumOperands(); + return {getNumOperands - 1, getNumOperands}; + } }]; } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index 8156bb9..d8e10ef 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -563,6 +563,11 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments], return regionBuilder; } + std::pair getOutputsPositionRange() {{ + int64_t getNumOperands = this->getNumOperands(); + return {{getNumOperands - 1, getNumOperands}; + } + // Generic methods. static unsigned getNumRegionArgs(); std::string getLibraryCallName(); @@ -638,8 +643,8 @@ ArrayAttr {0}::getIndexingMaps() {{ AffineMap tensorMap = AffineMap::getMultiDimIdentityMap( getNumParallelLoops(), context); SmallVector indexingMaps; - for (OpOperand *opOperand : getInputAndOutputOperands()) - indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap); + for (OpOperand &opOperand : getOperation()->getOpOperands()) + indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap); return Builder(getContext()).getAffineMapArrayAttr(indexingMaps); } )FMT"; @@ -654,10 +659,9 @@ LogicalResult {0}::fold(ArrayRef, } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{ - SmallVector inputBuffers = getInputBufferOperands(); - SmallVector outputBuffers = getOutputBufferOperands(); + if (hasTensorSemantics()) return; getGenericEffectsImpl(effects, - getOperation()->getResults(), inputBuffers, outputBuffers); + getOperation()->getResults(), getInputOperands(), getOutputOperands()); } )FMT"; -- 2.7.4