[mlir] Simplify DestinationStyleOpInterface.
authorAlexander Belyaev <pifon@google.com>
Fri, 14 Oct 2022 17:59:55 +0000 (19:59 +0200)
committerAlexander Belyaev <pifon@google.com>
Mon, 17 Oct 2022 10:43:41 +0000 (12:43 +0200)
Differential Revision: https://reviews.llvm.org/D135348

36 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/CAPI/Dialect/Linalg.cpp
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index 69871fa..995ced5 100644 (file)
@@ -317,7 +317,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        if (!$_op.isOutputTensor(opOperand))
+        if (!$_op.isOutput(opOperand))
           return false;
         return payloadUsesValueFromOperand(opOperand);
       }]
@@ -606,7 +606,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return $_op.getInputAndOutputOperands();
+        OpOperandVector result;
+        result.reserve($_op->getNumOperands());
+        llvm::transform(
+          this->getOperation()->getOpOperands(),
+          std::back_inserter(result),
+          [](OpOperand &opOperand) { return &opOperand; });
+        return result;
       }]
     >,
     //===------------------------------------------------------------------===//
@@ -684,13 +690,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         SmallVector<int64_t> res;
-        // MLIR currently does not support dependent interfaces or interface
-        // inheritance. By construction all ops with StructuredOpInterface must
-        // implement DestinationStyleOpInterface.
-        // TODO: reevalute the need for a cast when a better mechanism exists.
-        auto iface = cast<DestinationStyleOpInterface>(*this->getOperation());
-        for (OpOperand *opOperand : iface.getInputAndOutputOperands())
-          llvm::append_range(res, getShape(opOperand));
+        for (OpOperand &opOperand : this->getOperation()->getOpOperands())
+          llvm::append_range(res, getShape(&opOperand));
         return res;
       }]
     >,
@@ -779,31 +780,16 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     // TODO: reevalute the need for a cast when a better mechanism exists.
     //========================================================================//
 
-    ValueRange getInputs() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getInputs();
-    }
-
     int64_t getNumInputs() {
       return cast<DestinationStyleOpInterface>(*this->getOperation())
           .getNumInputs();
     }
 
-    ValueRange getOutputs() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getOutputs();
-    }
-
     int64_t getNumOutputs() {
       return cast<DestinationStyleOpInterface>(*this->getOperation())
           .getNumOutputs();
     }
 
-    int64_t getNumInputsAndOutputs() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getNumInputsAndOutputs();
-    }
-
     OpOperandVector getInputOperands() {
       return cast<DestinationStyleOpInterface>(*this->getOperation())
           .getInputOperands();
@@ -814,14 +800,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           .getInputOperand(i);
     }
 
-    OpOperandVector getInputBufferOperands() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getInputBufferOperands();
-    }
-
-    OpOperandVector getInputTensorOperands() {
+    void setOutputOperand(int64_t i, Value value) {
       return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getInputTensorOperands();
+          .setOutputOperand(i, value);
     }
 
     OpOperandVector getOutputOperands() {
@@ -834,44 +815,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
           .getOutputOperand(i);
     }
 
-    void setOutputOperand(int64_t i, Value value) {
+    bool isInput(OpOperand *opOperand) {
       return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .setOutputOperand(i, value);
+          .isInput(opOperand);
     }
 
-    OpOperandVector getOutputBufferOperands() {
+    bool isOutput(OpOperand *opOperand) {
       return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getOutputBufferOperands();
-    }
-
-    OpOperandVector getOutputTensorOperands() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getOutputTensorOperands();
-    }
-
-    SmallVector<MemRefType> getOutputBufferTypes() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getOutputBufferTypes();
-    }
-
-    SmallVector<RankedTensorType> getOutputTensorTypes() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getOutputTensorTypes();
-    }
-
-    OpOperandVector getInputAndOutputOperands() {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .getInputAndOutputOperands();
-    }
-
-    bool isInputTensor(OpOperand *opOperand) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .isInputTensor(opOperand);
-    }
-
-    bool isOutputTensor(OpOperand *opOperand) {
-      return cast<DestinationStyleOpInterface>(*this->getOperation())
-          .isOutputTensor(opOperand);
+          .isOutput(opOperand);
     }
 
     bool isScalar(OpOperand *opOperand) {
@@ -928,331 +879,185 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let verifyWithRegions = 1;
 }
 
-// The 'DestinationStyleOpInterface' provides access to the methods relevant
-// for destination-style ops. A destination-style operation has 'n' input
-// arguments and 'm' output arguments. Each op that wants to implement
-// DestinationStyleOpInterface needs to define getInputs() and getOutputs()
-// methods.
+// Ops that are in destination style have designated output operands, which act
+// as initial tensor values for the results of the operation or the output
+// buffers to which the results of the op will be written.
+//
+// Output operands must be tensors or memrefs. Input operands can have any
+// type. All non-output operands are inputs.
+
+// It is assumed that the output operands of the op are the operands at
+// position [start, end). The positions are defined by getOutputsPositionRange
+// method. All non-output operands are "inputs" of the DPS op.
+
+// If the op has "tensor semantics", then the input operands are either scalars
+// or tensors. The output operands are tensors and every tensor output is tied
+// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
+// tensor is tied to the i-th OpResult. The op may not have any additional
+// OpResults. Output operands and their tied OpResults have the same type.
+//
+// If the op has "buffer semantics", then the input operands are either memrefs
+// or other non-tensor types, e.g. scalar types. Furthermore, the output
+// operands are memrefs and the op has no results.
+//
+// Destination-passing style abstraction makes certain transformations easier.
+// For example, tiling implementation can extract/insert slices from/into the
+// destination of an op and use the resulting shaped value as an iter_arg in
+// the surrounding loop structure. As another example, bufferization does not
+// have to allocate new buffers for destinations (in case of in-place
+// bufferization) and can directly reuse the existing destination buffer.
+//
+// Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
+// where `%t` is the single input and `%d` is the single output. `%d` is tied
+// to `%r`.
+//
+// Example of an op that is not in destination style: `%r = tensor.pad %t`.
+// This op is not in destination style because `%r` and `%t` have different
+// shape.
+//
+// Each op that wants to implement DestinationStyleOpInterface needs to define
+// the getOutputsPositionRange() method.
 def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
   let cppNamespace = "::mlir::linalg";
   let methods = [
-    //===------------------------------------------------------------------===//
-    // Num input/output arguments handling.
-    //===------------------------------------------------------------------===//
-    // `getInputs` must be defined by each op that wants to implement the
-    // DestinationStyleOpInterface.
+    // This method has to be defined for every DPS op.
     InterfaceMethod<
-      /*desc=*/[{
-        Return the input shape operands.
-      }],
-      /*retTy=*/"ValueRange",
-      /*methodName=*/"getInputs",
-      /*args=*/(ins)
-    >,
-    // These special methods rely on `getInputs` and `getOutputs` being defined
-    // by each op that wants to implement the DestinationStyleOpInterface.
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the number of inputs.
-      }],
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumInputs",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return $_op.getInputs().size();
-      }]
-    >,
-    // `getOutputs` must be defined by each op that wants to implement the
-    // DestinationStyleOpInterface.
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the output shape operands.
-      }],
-      /*retTy=*/"ValueRange",
-      /*methodName=*/"getOutputs",
-      /*args=*/(ins)
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the number of outputs.
-      }],
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumOutputs",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return $_op.getOutputs().size();
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the number of inputs and outputs.
-      }],
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumInputsAndOutputs",
+      /*desc=*/"Return start and end indices of the output operands range.",
+      /*retTy=*/"std::pair<int64_t, int64_t>",
+      /*methodName=*/"getOutputsPositionRange",
       /*args=*/(ins),
       /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return this->getOperation()->getNumOperands();
-      }]
+      /*defaultImplementation=*/""
     >,
     //===------------------------------------------------------------------===//
-    // Input operands handling.
+    // Operands handling.
     //===------------------------------------------------------------------===//
+    // The operand list is assumed to start with the input operands and end
+    // with the output operands. Therefore, all methods to access the inputs
+    // and outputs can be expressed if the number of output operands is know.
     InterfaceMethod<
-      /*desc=*/[{
-        Return the input operands.
-      }],
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getInputOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        int64_t numInputs = getNumInputs();
-        OpOperandVector result;
-        result.reserve(numInputs);
-        llvm::transform(
-          this->getOperation()->getOpOperands().take_front(numInputs),
-          std::back_inserter(result),
-          [](OpOperand &opOperand) { return &opOperand; });
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the `i`-th input operand.
-      }],
-      /*retTy=*/"OpOperand*",
-      /*methodName=*/"getInputOperand",
-      /*args=*/(ins "int64_t":$i),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < getNumInputs());
-        return &this->getOperation()->getOpOperand(i);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the subset of input operands that are of buffer type.
-      }],
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getInputBufferOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        OpOperandVector result;
-        result.reserve(getNumInputs());
-        llvm::copy_if(getInputOperands(),
-          std::back_inserter(result),
-          [](OpOperand *opOperand) {
-            return opOperand->get().getType().template isa<MemRefType>();
-          });
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the subset of input operands that are of tensor type.
-      }],
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getInputTensorOperands",
+      /*desc=*/"Return the number of outputs.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getNumOutputs",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperandVector result;
-        result.reserve(getNumInputs());
-        llvm::copy_if(getInputOperands(),
-          std::back_inserter(result),
-          [](OpOperand *opOperand) {
-            return opOperand->get().getType().template isa<RankedTensorType>();
-          });
-        return result;
+        auto [start, end] = $_op.getOutputsPositionRange();
+        return end - start;
       }]
     >,
-    //===------------------------------------------------------------------===//
-    // Output operands handling.
-    //===------------------------------------------------------------------===//
     InterfaceMethod<
-      /*desc=*/[{
-        Return the output operands.
-      }],
+      /*desc=*/"Return the output operands.",
       /*retTy=*/"OpOperandVector",
       /*methodName=*/"getOutputOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        int64_t numOutputs = getNumOutputs();
+        auto [start, end] = $_op.getOutputsPositionRange();
+
         OpOperandVector result;
-        result.reserve(numOutputs);
-        llvm::transform(
-          this->getOperation()->getOpOperands()
-            .take_back(numOutputs),
-          std::back_inserter(result),
-          [](OpOperand &opOperand) { return &opOperand; });
+        result.reserve(end - start);
+        for (int i = start; i < end; ++i)
+          result.push_back(&$_op->getOpOperand(i));
         return result;
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return the `i`-th output operand.
-      }],
+      /*desc=*/"Return the `i`-th output operand.",
       /*retTy=*/"OpOperand*",
       /*methodName=*/"getOutputOperand",
       /*args=*/(ins "int64_t":$i),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        assert(i >= 0 && i < getNumOutputs());
-        return &this->getOperation()->getOpOperand(getNumInputs() + i);
+        assert(i >= 0 && i < $_op.getNumOutputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        return &$_op->getOpOperand(start + i);
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Set the `i`-th output operand.
-      }],
+      /*desc=*/"Set the `i`-th output operand.",
       /*retTy=*/"void",
       /*methodName=*/"setOutputOperand",
       /*args=*/(ins "int64_t":$i, "Value":$value),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        assert(i >= 0 && i < getNumOutputs());
-        this->getOperation()->setOperand(getNumInputs() + i, value);
+        assert(i >= 0 && i < $_op.getNumOutputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        $_op->setOperand(start + i, value);
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return the subset of output operands that are of buffer type.
-      }],
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getOutputBufferOperands",
+      /*desc=*/"Return the number of inputs.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getNumInputs",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        OpOperandVector result;
-        result.reserve(getNumOutputs());
-        llvm::copy_if(getOutputOperands(),
-          std::back_inserter(result),
-          [](OpOperand *opOperand) {
-            return opOperand->get().getType().template isa<MemRefType>();
-          });
-        return result;
+        return $_op.getNumOperands() - $_op.getNumOutputs();
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return the subset of output operands that are of tensor type.
-      }],
+      /*desc=*/"Return the input operands.",
       /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getOutputTensorOperands",
+      /*methodName=*/"getInputOperands",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+        int64_t numOutputs = end - start;
+        int64_t numOperands = $_op.getNumOperands();
+
         OpOperandVector result;
-        result.reserve(getNumOutputs());
-        llvm::copy_if(getOutputOperands(),
-          std::back_inserter(result),
-          [](OpOperand *opOperand) {
-            return opOperand->get().getType().template isa<RankedTensorType>();
-          });
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return the types of the subset of output operands that are of buffer type.
-      }],
-      /*retTy=*/"SmallVector<MemRefType>",
-      /*methodName=*/"getOutputBufferTypes",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        SmallVector<MemRefType> result;
-        result.reserve(getNumOutputs());
-        llvm::transform(getOutputBufferOperands(),
-          std::back_inserter(result),
-          [](OpOperand *opOperands) {
-            return opOperands->get().getType().cast<MemRefType>();
-          });
+        result.reserve(numOperands - numOutputs);
+        for (int i = 0; i < start; ++i)
+          result.push_back(&$_op->getOpOperand(i));
+        for (int i = end; i < numOperands; ++i)
+          result.push_back(&$_op->getOpOperand(end + i));
+
         return result;
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return the types of the subset of output operands that are of tensor type.
-      }],
-      /*retTy=*/"SmallVector<RankedTensorType>",
-      /*methodName=*/"getOutputTensorTypes",
-      /*args=*/(ins),
+      /*desc=*/[{ Return the `i`-th input operand.  }],
+      /*retTy=*/"OpOperand*",
+      /*methodName=*/"getInputOperand",
+      /*args=*/(ins "int64_t":$i),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        SmallVector<RankedTensorType> result;
-        result.reserve(getNumOutputs());
-        llvm::transform(getOutputTensorOperands(),
-          std::back_inserter(result),
-          [](OpOperand *opOperands) {
-            return opOperands->get().getType().cast<RankedTensorType>();
-          });
-        return result;
+        assert(i >= 0 && i < getNumInputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        return &$_op->getOpOperand(i < start ? i : i + end - start) ;
       }]
     >,
     //===------------------------------------------------------------------===//
     // Input and Output arguments handling.
     //===------------------------------------------------------------------===//
     InterfaceMethod<
-      /*desc=*/[{
-        Return the range over input and output operands.
-      }],
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getInputAndOutputOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        int64_t numInputsAndOutputs = getNumInputsAndOutputs();
-        OpOperandVector result;
-        result.reserve(numInputsAndOutputs);
-        llvm::transform(
-          this->getOperation()->getOpOperands(),
-          std::back_inserter(result),
-          [](OpOperand &opOperand) { return &opOperand; });
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Return true if `opOperand` is an input tensor.
-      }],
+      /*desc=*/"Return true if `opOperand` is an input.",
       /*retTy=*/"bool",
-      /*methodName=*/"isInputTensor",
+      /*methodName=*/"isInput",
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        if (!opOperand->get().getType().template isa<RankedTensorType>())
-          return false;
-        if (opOperand->getOperandNumber() < $_op.getNumInputs())
-          return true;
-        return false;
+        auto [start, end] = $_op.getOutputsPositionRange();
+        auto operandNumber = opOperand->getOperandNumber();
+        return operandNumber < start || operandNumber >= end;
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return true if `opOperand` is an output tensor.
-      }],
+      /*desc=*/"Return true if `opOperand` is an output.",
       /*retTy=*/"bool",
-      /*methodName=*/"isOutputTensor",
+      /*methodName=*/"isOutput",
       /*args=*/(ins "OpOperand *":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        if (!opOperand->get().getType().template isa<RankedTensorType>())
-          return false;
-        if (opOperand->getOperandNumber() >= $_op.getNumInputs())
-          return true;
-        return false;
+        auto [start, end] = $_op.getOutputsPositionRange();
+        auto operandNumber = opOperand->getOperandNumber();
+        return operandNumber >= start && operandNumber < end;
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return true if the `opOperand` is a scalar value.
-      }],
+      /*desc=*/"Return true if the `opOperand` is a scalar value.",
       /*retTy=*/"bool",
       /*methodName=*/"isScalar",
       /*args=*/(ins "OpOperand*":$opOperand),
@@ -1263,35 +1068,33 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return the result tied to `opOperand`.
-      }],
+      /*desc=*/"Return the result tied to `opOperand`.",
       /*retTy=*/"OpResult",
       /*methodName=*/"getTiedOpResult",
       /*args=*/(ins "OpOperand*":$opOperand),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         assert(opOperand->getOwner() == this->getOperation());
-        int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs();
+
+        auto [start, end] = $_op.getOutputsPositionRange();
+        int64_t resultIndex = opOperand->getOperandNumber() - start;
         assert(resultIndex >= 0 &&
-               resultIndex < this->getOperation()->getNumResults() );
-        return this->getOperation()->getResult(resultIndex);
+               resultIndex < $_op->getNumResults() );
+        return $_op->getResult(resultIndex);
       }]
     >,
     //===------------------------------------------------------------------===//
     // Other interface methods.
     //===------------------------------------------------------------------===//
     InterfaceMethod<
-      /*desc=*/[{
-        Return whether the op has only MemRef input and outputs.
-      }],
+      /*desc=*/"Return whether the op has only MemRef input and outputs.",
       /*retTy=*/"bool",
       /*methodName=*/"hasBufferSemantics",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return this->getOperation()->getNumResults() == 0 &&
-          llvm::all_of(this->getOperation()->getOpOperands(),
+        return $_op->getNumResults() == 0 &&
+          llvm::all_of($_op->getOpOperands(),
             [&](OpOperand &opOperand) {
               return isScalar(&opOperand) ||
                      opOperand.get().getType().template isa<MemRefType>();
@@ -1299,15 +1102,13 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
       }]
     >,
     InterfaceMethod<
-      /*desc=*/[{
-        Return whether the op has only RankedTensor input and outputs.
-      }],
+      /*desc=*/"Return whether the op has only RankedTensor input and outputs.",
       /*retTy=*/"bool",
       /*methodName=*/"hasTensorSemantics",
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return llvm::all_of(this->getOperation()->getOpOperands(),
+        return llvm::all_of($_op->getOpOperands(),
           [&](OpOperand &opOperand) {
             return isScalar(&opOperand) ||
                    opOperand.get().getType().template isa<RankedTensorType>();
index 2619ad1..3b06a59 100644 (file)
@@ -215,6 +215,10 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     getRegionBuilder() {
       return nullptr;
     }
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      int64_t getNumOperands = this->getNumOperands();
+      return {getNumOperands - getOutputs().size(), getNumOperands};
+    }
   }];
 
   let hasCanonicalizer = 1;
@@ -271,11 +275,10 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     }
 
     // Implement functions necessary for DestinationStyleOpInterface.
-    unsigned getNumInputs() {
-      return this->getOperation()->getNumOperands() - getNumOutputs();
-    };
-    unsigned getNumOutputs() { return 1; };
-    mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      int64_t getNumOperands = this->getNumOperands();
+      return {getNumOperands - 1, getNumOperands};
+    }
     linalg::OpOperandVector getOpOperandsMatchingBBargs() {
       return getInputOperands();
     }
@@ -341,14 +344,14 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     }
 
     // Implement functions necessary for DestinationStyleOpInterface.
-    mlir::ValueRange getOutputs() { return getInits(); }
-    unsigned getNumInputs() { return getInputs().size(); };
-    unsigned getNumOutputs() { return getInits().size(); };
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                               mlir::ArrayRef<mlir::NamedAttribute>)>
     getRegionBuilder() {
       return nullptr;
     }
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      return {getInits().size(), getNumOperands()};
+    }
   }];
 
   let hasCustomAssemblyFormat = 1;
index bfb3313..2fb5bc6 100644 (file)
@@ -29,9 +29,9 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
 
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType()));
-    argLocs.push_back(opOperand->get().getLoc());
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType()));
+    argLocs.push_back(opOperand.get().getLoc());
   }
 
   ImplicitLocOpBuilder b(op->getLoc(), op->getContext());
index 5a6d678..c4c7efb 100644 (file)
@@ -166,6 +166,8 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
                     << " and " << *dst.getOperation() << "\n");
   if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
     for (OpOperand *dstOpOperand : dst.getInputOperands()) {
+      if (!dstOpOperand->get().getType().isa<RankedTensorType>())
+        continue;
       // Check if the operand is defined by the src.
       auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
       if (definingOp && definingOp == src)
@@ -188,23 +190,31 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
   }
   assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
          "unhandled dependence tracking for mixed buffer/tensor operations");
-  for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W
+  for (OpOperand *srcOpOperand : src.getOutputOperands()) { // W
     // RAW graph
-    for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
+    for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R
+      if (!dstOpOperand->get().getType().isa<MemRefType>())
+        continue;
       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
         addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
+    }
     // WAW graph
-    for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
+    for (OpOperand *dstOpOperand : dst.getOutputOperands())        // W
       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
         addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
   }
-  for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R
+  for (OpOperand *srcOpOperand : src.getInputOperands()) { // R
+    if (!srcOpOperand->get().getType().isa<MemRefType>())
+      continue;
     // RAR graph
-    for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
+    for (OpOperand *dstOpOperand : dst.getInputOperands()) { // R
+      if (!dstOpOperand->get().getType().isa<MemRefType>())
+        continue;
       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
         addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
+    }
     // WAR graph
-    for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
+    for (OpOperand *dstOpOperand : dst.getOutputOperands())        // W
       if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
         addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
   }
index 9a62a40..88fd71c 100644 (file)
@@ -31,10 +31,10 @@ using namespace mlir::linalg;
 bool linalg::detail::canOpOperandsBeDroppedImpl(
     linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
   SmallVector<AffineMap> indexingMaps;
-  for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
-    if (llvm::is_contained(droppedOperands, opOperand))
+  for (auto &opOperand : linalgOp->getOpOperands()) {
+    if (llvm::is_contained(droppedOperands, &opOperand))
       continue;
-    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand));
+    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
   }
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }
@@ -491,9 +491,9 @@ static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                 Location loc) {
   SmallVector<OpFoldResult> res;
-  for (OpOperand *opOperand : getInputAndOutputOperands()) {
-    for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
-      res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i));
+  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
+    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
+      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
   }
   return res;
 }
@@ -501,8 +501,8 @@ SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
   SmallVector<int64_t, 4> res;
   assert(!hasDynamicShape() && "expected operands to have static shapes");
-  for (OpOperand *opOperand : getInputAndOutputOperands())
-    llvm::append_range(res, getShape(opOperand));
+  for (OpOperand &opOperand : getOperation()->getOpOperands())
+    llvm::append_range(res, getShape(&opOperand));
   return res;
 }
 
@@ -644,32 +644,32 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
 
   // All input/output operands must be indexed.
   if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
-      linalgOp.getNumInputsAndOutputs())
+      linalgOp->getNumOperands())
     return op->emitOpError("expected the number of indexing_map (")
            << linalgOp.getIndexingMapsArray().size()
            << ") to be equal to the number of input/output operands ("
-           << linalgOp.getNumInputsAndOutputs() << ")";
+           << linalgOp->getNumOperands() << ")";
 
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
     // Symbols disallowed.
     if (indexingMap.getNumSymbols() != 0)
       return op->emitOpError("unexpected symbols in indexing_map #")
-             << opOperand->getOperandNumber();
+             << opOperand.getOperandNumber();
 
     // Domain must be consistent.
     unsigned numLoops = linalgOp.getNumLoops();
     if (indexingMap.getNumDims() != numLoops)
       return op->emitOpError("expected indexing_map #")
-             << opOperand->getOperandNumber() << " to have " << numLoops
+             << opOperand.getOperandNumber() << " to have " << numLoops
              << " dim(s) to match the number of loops";
 
-    int64_t rank = linalgOp.getRank(opOperand);
+    int64_t rank = linalgOp.getRank(&opOperand);
     if (indexingMap.getNumResults() != rank)
       return op->emitOpError("expected operand rank (")
              << rank << ") to match the result rank of indexing_map #"
-             << opOperand->getOperandNumber() << " ("
+             << opOperand.getOperandNumber() << " ("
              << indexingMap.getNumResults() << ")";
   }
 
@@ -688,13 +688,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
     for (int64_t &range : endLoopRangeValues)
       range -= 1;
-    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
       SmallVector<int64_t, 4> startIndices =
           indexingMap.compose(startLoopRangeValues);
       SmallVector<int64_t, 4> endIndices =
           indexingMap.compose(endLoopRangeValues);
-      ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
+      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
       for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
         // Ignore dynamic dimension or the case that the dimension size is 0
         if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
@@ -725,17 +725,16 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
         if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
           if (inferredDimSize != shape[dim]) {
             return op->emitOpError("inferred input/output operand #")
-                   << opOperand->getOperandNumber()
-                   << " has shape's dimension #" << dim << " to be "
-                   << inferredDimSize << ", but found " << shape[dim];
+                   << opOperand.getOperandNumber() << " has shape's dimension #"
+                   << dim << " to be " << inferredDimSize << ", but found "
+                   << shape[dim];
           }
         } else {
           if (inferredDimSize > shape[dim]) {
             return op->emitOpError("inferred input/output operand #")
-                   << opOperand->getOperandNumber()
-                   << " has shape's dimension #" << dim
-                   << " to be greater than or equal to " << inferredDimSize
-                   << ", but found " << shape[dim];
+                   << opOperand.getOperandNumber() << " has shape's dimension #"
+                   << dim << " to be greater than or equal to "
+                   << inferredDimSize << ", but found " << shape[dim];
           }
         }
       }
@@ -777,6 +776,15 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
   DestinationStyleOpInterface dstStyleOp =
       cast<DestinationStyleOpInterface>(op);
 
+  SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
+  for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
+    Type type = operand->get().getType();
+    if (type.isa<MemRefType>())
+      outputBufferOperands.push_back(operand);
+    if (type.isa<RankedTensorType>())
+      outputTensorOperands.push_back(operand);
+  }
+
   // Expect at least one output operand.
   // This means an op that constructs a tensor out of indices cannot be a
   // LinalgOp at the moment. For now this will have to be a special op until we
@@ -788,23 +796,22 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
   if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
     return failure();
   // Verify the number of results matches the number of output tensors.
-  if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size())
+  if (op->getNumResults() != outputTensorOperands.size())
     return op->emitOpError("expected the number of results (")
            << op->getNumResults()
            << ") to be equal to the number of output tensors ("
-           << dstStyleOp.getOutputTensorOperands().size() << ")";
+           << outputTensorOperands.size() << ")";
 
   // Simplifying assumption: either full tensor or full buffer mode.
   // This allows simpler verification of output operands vs result types
   // without premature tracking of which operand is what in mixed-mode.
   // TODO: relax when mixed-mode needs to pass verification.
-  if (!dstStyleOp.getOutputBufferOperands().empty() &&
-      !dstStyleOp.getOutputTensorOperands().empty())
+  if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
     return op->emitOpError(
         "expected output operands to all have tensor type or "
         "all have buffer type");
 
-  for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) {
+  for (OpOperand *opOperand : outputTensorOperands) {
     OpResult result = dstStyleOp.getTiedOpResult(opOperand);
     if (result.getType() != opOperand->get().getType())
       return op->emitOpError("expected type of operand #")
@@ -813,6 +820,5 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
              << " to match type of corresponding result (" << result.getType()
              << ")";
   }
-
   return success();
 }
index c2705a3..586d198 100644 (file)
@@ -767,7 +767,8 @@ void GenericOp::print(OpAsmPrinter &p) {
   }
 
   // Printing is shared with named ops, except for the region and attributes
-  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+                               SmallVector<Value>(getOutputOperands()));
 
   genericAttrNames.push_back("operand_segment_sizes");
   genericAttrNamesSet.insert(genericAttrNames.back());
@@ -835,15 +836,20 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
 static void getGenericEffectsImpl(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects,
-    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
-  for (Value value : inputBuffers) {
-    effects.emplace_back(MemoryEffects::Read::get(), value,
+    ValueRange results, OpOperandVector inputOperands,
+    OpOperandVector outputOperands) {
+  for (auto *operand : inputOperands) {
+    if (!operand->get().getType().isa<MemRefType>())
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
                          SideEffects::DefaultResource::get());
   }
-  for (Value value : outputs) {
-    effects.emplace_back(MemoryEffects::Read::get(), value,
+  for (auto *operand : outputOperands) {
+    if (!operand->get().getType().isa<MemRefType>())
+      continue;
+    effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
                          SideEffects::DefaultResource::get());
-    effects.emplace_back(MemoryEffects::Write::get(), value,
+    effects.emplace_back(MemoryEffects::Write::get(), operand->get(),
                          SideEffects::DefaultResource::get());
   }
 }
@@ -851,10 +857,8 @@ static void getGenericEffectsImpl(
 void GenericOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  SmallVector<Value> inputBuffers = getInputBufferOperands();
-  SmallVector<Value> outputBuffers = getOutputBufferOperands();
-  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
-                        outputBuffers);
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getInputOperands(), getOutputOperands());
 }
 
 static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
@@ -925,7 +929,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
 
     // Check if there is any change to operands.
     if (newInputOperands.size() + newOutputOperands.size() ==
-        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+        genericOp->getNumOperands())
       return failure();
 
     // Create the new op with the body being empty.
@@ -977,35 +981,34 @@ private:
                            SmallVector<AffineMap> &newIndexingMaps) const {
     llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
     llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
-    for (const auto &inputOpOperand :
-         llvm::enumerate(genericOp.getInputOperands())) {
+    for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
+      OpOperand *inputOpOperand = en.value();
       // Check if operand is dead and if dropping the indexing map makes the
       // loops to shape computation invalid.
-      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
         // Add the current operands to the list of potentially droppable
         // operands. If it cannot be dropped, this needs to be popped back.
-        droppedOpOperands.push_back(inputOpOperand.value());
+        droppedOpOperands.push_back(inputOpOperand);
         if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
           continue;
         droppedOpOperands.pop_back();
       }
 
       // Check if this operand is a duplicate.
-      AffineMap indexingMap =
-          genericOp.getMatchingIndexingMap(inputOpOperand.value());
+      AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
       auto it = dedupedInputs.find(
-          std::make_pair(inputOpOperand.value()->get(), indexingMap));
+          std::make_pair(inputOpOperand->get(), indexingMap));
       if (it != dedupedInputs.end()) {
-        origToNewPos[inputOpOperand.index()] = it->second;
-        droppedOpOperands.push_back(inputOpOperand.value());
+        origToNewPos[en.index()] = it->second;
+        droppedOpOperands.push_back(inputOpOperand);
         continue;
       }
 
       // This is a preserved argument.
-      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
-      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
+      origToNewPos[en.index()] = newInputOperands.size();
+      dedupedInputs[{inputOpOperand->get(), indexingMap}] =
           newInputOperands.size();
-      newInputOperands.push_back(inputOpOperand.value()->get());
+      newInputOperands.push_back(inputOpOperand->get());
       newIndexingMaps.push_back(indexingMap);
     }
     return origToNewPos;
@@ -1026,12 +1029,10 @@ private:
     // If the op doesnt have tensor semantics, keep all the outputs as
     // preserved.
     if (!genericOp.hasTensorSemantics()) {
-      for (const auto &outputOpOperand :
-           llvm::enumerate(genericOp.getOutputOperands())) {
-        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-        newOutputOperands.push_back(outputOpOperand.value()->get());
-        newIndexingMaps.push_back(
-            genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+      for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) {
+        origToNewPos[en.index()] = newOutputOperands.size();
+        newOutputOperands.push_back(en.value()->get());
+        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
       }
       return origToNewPos;
     }
@@ -1347,7 +1348,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MapOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+                               SmallVector<Value>(getOutputOperands()));
   p.printOptionalAttrDict((*this)->getAttrs());
 
   p << "(";
@@ -1380,7 +1382,7 @@ LogicalResult MapOp::verify() {
 
   // The shape of each input must match the shape of the output.
   auto outputShape =
-      getOutputs().front().getType().cast<ShapedType>().getShape();
+      getOutputOperand(0)->get().getType().cast<ShapedType>().getShape();
   for (Type inputArgType : TypeRange{getInputs()}) {
     auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
     if (inputElemShape != outputShape) {
@@ -1409,10 +1411,8 @@ ArrayAttr MapOp::getIndexingMaps() {
 void MapOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  SmallVector<Value> inputBuffers = getInputBufferOperands();
-  SmallVector<Value> outputBuffers = getOutputBufferOperands();
-  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
-                        outputBuffers);
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getInputOperands(), getOutputOperands());
 }
 
 //===----------------------------------------------------------------------===//
@@ -1458,10 +1458,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
 void ReduceOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
-  SmallVector<Value> inputBuffers = getInputBufferOperands();
-  SmallVector<Value> outputBuffers = getOutputBufferOperands();
-  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
-                        outputBuffers);
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getInputOperands(), getOutputOperands());
 }
 
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1500,7 +1498,8 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+                               SmallVector<Value>(getOutputOperands()));
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 
@@ -1584,10 +1583,11 @@ LogicalResult ReduceOp::verify() {
   }
 
   // Check that the last block arguments match the element type of the outputs.
-  for (auto [output, bbArg] : llvm::zip(
-           getOutputs(), block->getArguments().take_back(getNumOutputs()))) {
+  for (auto [output, bbArg] :
+       llvm::zip(getOutputOperands(),
+                 block->getArguments().take_back(getNumOutputs()))) {
     auto outputElementType =
-        output.getType().cast<ShapedType>().getElementType();
+        output->get().getType().cast<ShapedType>().getElementType();
     if (outputElementType != bbArg.getType())
       return emitOpError()
              << "output element type " << outputElementType
@@ -1751,14 +1751,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+    for (OpOperand &opOperand : op->getOpOperands()) {
       // Linalg "inputs" may be either tensor or memref type.
       // tensor<0xelt_type> is a convention that may not always mean
       // "0 iterations". Only erase in cases we see memref<...x0x...>.
-      auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
+      auto mt = opOperand.get().getType().dyn_cast<MemRefType>();
       if (!mt)
         continue;
-      if (llvm::is_contained(op.getShape(opOperand), 0)) {
+      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
         rewriter.eraseOp(op);
         return success();
       }
@@ -1774,10 +1774,10 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
                                 PatternRewriter &rewriter) const override {
     // If no operand comes from a tensor::CastOp and can be folded then fail.
     bool hasTensorCastOperand =
-        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
-          if (opOperand->get().isa<BlockArgument>())
+        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+          if (opOperand.get().isa<BlockArgument>())
             return false;
-          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
           return castOp && canFoldIntoConsumerOp(castOp);
         });
     if (!hasTensorCastOperand)
@@ -1788,18 +1788,17 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
     SmallVector<Value, 4> newOperands;
     newOperands.reserve(op->getNumOperands());
     // Inputs may fold.
-    for (OpOperand *opOperand : op.getInputOperands()) {
-      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+    for (auto *input : op.getInputOperands()) {
+      auto tensorCastOp = input->get().getDefiningOp<tensor::CastOp>();
       newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                 ? tensorCastOp.getSource()
-                                : opOperand->get());
+                                : input->get());
     }
     // Init tensors may fold, in which case the resultType must also change.
-    for (OpOperand *opOperand : op.getOutputOperands()) {
-      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
+    for (auto *output : op.getOutputOperands()) {
+      auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
       bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand()
-                                 : opOperand->get());
+      newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
       newResultTypes.push_back(newOperands.back().getType());
     }
     // Clone op.
@@ -1858,8 +1857,8 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
     OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
     Value newOperand =
         rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
-    SmallVector<Value> newOperands = linalgOp.getInputOperands();
-    SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+    SmallVector<Value> newOperands{linalgOp.getInputOperands()};
+    SmallVector<Value> outputOperands{linalgOp.getOutputOperands()};
     outputOperands[resultNumber] = newOperand;
     newOperands.append(outputOperands.begin(), outputOperands.end());
 
@@ -1882,14 +1881,14 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
 
 /// For each of the operand in `operands` this function maps the static sizes of
 /// dimensions to their affine dim expressions.
-static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
+static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                         llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
-  for (OpOperand *opOperand : operands) {
-    if (linalgOp.isScalar(opOperand))
+  for (OpOperand &opOperand : operands) {
+    if (linalgOp.isScalar(&opOperand))
       continue;
-    Value src = opOperand->get();
+    Value src = opOperand.get();
     auto sourceType = src.getType().cast<RankedTensorType>();
-    auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
+    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
     // Get the `sourceShape` of the `sourceType`. If the operand is a result of
     // `tensor.cast` operation and source of the cast operation has a static
@@ -1932,7 +1931,7 @@ static void createNewOperandWithStaticSizes(
     return;
   auto sourceType = src.getType().cast<RankedTensorType>();
   Type resultType = sourceType;
-  if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
+  if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) {
     resultTypes.push_back(resultType);
     return;
   }
@@ -1965,7 +1964,7 @@ static void createNewOperandWithStaticSizes(
     unsigned index = opOperand->getOperandNumber();
     newOperands[index] = newOperand;
   }
-  if (linalgOp.isOutputTensor(opOperand))
+  if (linalgOp.isOutput(opOperand))
     resultTypes.push_back(resultType);
 }
 
@@ -1992,8 +1991,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
 
     // For each of the affine dim expression, check if the size is known. If
     // known add that in the map.
-    populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
-                affineExprToSize);
+    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
 
     SmallVector<Value> newOperands;
     SmallVector<Type> resultTypes;
@@ -2001,12 +1999,12 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
     // `changeNeeded` is `false` if the operands of `linalgOp` require no
     // change in their types.
     bool changeNeeded = false;
-    newOperands.reserve(linalgOp.getNumInputsAndOutputs());
+    newOperands.reserve(linalgOp->getNumOperands());
     resultTypes.reserve(linalgOp.getNumOutputs());
 
     // Iterate over all the operands and update the static sizes.
-    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
+    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                       affineExprToSize, linalgOp, newOperands,
                                       resultTypes, changeNeeded);
     }
index 2cf8a57..383a926 100644 (file)
@@ -112,14 +112,14 @@ struct BubbleUpExtractSliceOpPattern
       tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
     }
 
-    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Value> valuesToTile = linalgOp->getOperands();
     SmallVector<Value> tiledOperands =
         makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
                         tileOffsets, tileSizes, sizeBounds,
                         /*omitPartialTileCheck=*/true);
 
     SmallVector<Type, 4> resultTensorTypes;
-    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
+    for (OpOperand *opOperand : linalgOp.getOutputOperands())
       resultTensorTypes.push_back(
           tiledOperands[opOperand->getOperandNumber()].getType());
 
index abc430f..bb38004 100644 (file)
@@ -118,7 +118,7 @@ struct LinalgOpInterface
     auto genericOp = cast<linalg::DestinationStyleOpInterface>(op);
 
     // The i-th "out" tensor may alias with the i-th OpResult.
-    if (genericOp.isOutputTensor(&opOperand))
+    if (genericOp.isOutput(&opOperand))
       return {genericOp.getTiedOpResult(&opOperand)};
     return {};
   }
index 58a54ba..a21e0fc 100644 (file)
@@ -68,17 +68,17 @@ public:
     if (!outputType || !outputType.hasStaticShape())
       return failure();
 
-    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
-          return operand->get().getType().isa<ShapedType>();
+    if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
+          return input.getType().isa<ShapedType>();
         }))
       return failure();
 
     // Make sure all element types are the same.
-    auto getOperandElementType = [](OpOperand *operand) {
-      return operand->get().getType().cast<ShapedType>().getElementType();
+    auto getOperandElementType = [](Value value) {
+      return value.getType().cast<ShapedType>().getElementType();
     };
-    if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(),
-                                         getOperandElementType)))
+    if (!llvm::all_equal(
+            llvm::map_range(genericOp->getOperands(), getOperandElementType)))
       return failure();
 
     // We can only handle the case where we have int/float elements.
@@ -114,15 +114,15 @@ public:
     // All inputs should be constants.
     int numInputs = genericOp.getNumInputs();
     SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
-    for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
-      if (!matchPattern(operand.value()->get(),
-                        m_Constant(&inputValues[operand.index()])))
+    for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
+      if (!matchPattern(en.value()->get(),
+                        m_Constant(&inputValues[en.index()])))
         return failure();
     }
 
     // Identified this as a potential candidate for folding. Now check the
     // policy to see whether we are allowed to proceed.
-    for (auto *operand : genericOp.getInputOperands()) {
+    for (OpOperand *operand : genericOp.getInputOperands()) {
       if (!controlFn(operand))
         return failure();
     }
@@ -171,8 +171,8 @@ public:
     APIntOrFloatArray computeFnInputs;
 
     auto inputShapes = llvm::to_vector<4>(
-        llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
-          return operand->get().getType().cast<ShapedType>().getShape();
+        llvm::map_range(genericOp.getInputs(), [](Value value) {
+          return value.getType().cast<ShapedType>().getShape();
         }));
 
     // Given a `linearIndex`, remap it to a linear index to access linalg op
index cebc978..327e8be 100644 (file)
@@ -194,7 +194,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
   }
 
   /// Create the peeled generic op with an empty body.
-  SmallVector<Value> outsOperands = genericOp.getOutputOperands();
+  SmallVector<Value> outsOperands = genericOp.getOutputs();
   outsOperands.append(newInitValues.begin(), newInitValues.end());
   SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
   resultTypes.append(newResultTypes.begin(), newResultTypes.end());
@@ -212,9 +212,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
                                            PatternRewriter &rewriter) const {
   /// Append all results from the peeledGenericOps as `ins` operand for the
   /// residual generic op.
-  SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
-      llvm::map_range(genericOp.getInputOperands(),
-                      [](OpOperand *operand) { return operand->get(); }));
+  SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
   unsigned origNumResults = genericOp.getNumResults();
   unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
   SmallVector<Value> extraIns;
index baef90c..acc0126 100644 (file)
@@ -55,10 +55,9 @@ bool canBeDetensored(TensorType tensorType) {
 bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
   GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
   return genericOp &&
-         llvm::all_of(
-             genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
-               return !typeConverter.isLegal(opOperand->get().getType());
-             });
+         llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
+           return !typeConverter.isLegal(opOperand.get().getType());
+         });
 }
 
 /// A conversion patttern for detensoring `linalg.generic` ops.
index 361c85a..2fcb2ac 100644 (file)
@@ -377,21 +377,21 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
     SmallVector<ArrayAttr> reassociationMaps;
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
-    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+    for (OpOperand &opOperand : genericOp->getOpOperands()) {
+      auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
       if (replacementInfo) {
         reassociationMaps.push_back(replacementInfo->reassociation);
         newIndexingMaps.push_back(replacementInfo->indexMap);
         newInputOutputTypes.push_back(replacementInfo->type);
         doCanonicalization |=
-            replacementInfo->type != opOperand->get().getType();
+            replacementInfo->type != opOperand.get().getType();
       } else {
         // If replaceUnitExtents cannot handle this case, maintain the same
         // type, indexing map, and create a set of mappings representing an
         // identity matrix.
-        newInputOutputTypes.push_back(opOperand->get().getType());
-        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand));
-        int64_t origRank = genericOp.getRank(opOperand);
+        newInputOutputTypes.push_back(opOperand.get().getType());
+        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
+        int64_t origRank = genericOp.getRank(&opOperand);
         auto maps = llvm::to_vector<8>(llvm::map_range(
             llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
               return AffineMapAttr::get(
index 05dce4c..80cef16 100644 (file)
@@ -90,7 +90,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
 
   // Only allow fusing the producer of an input operand for now.
   // TODO: allow fusing the producer of an output operand.
-  if (!consumer.isInputTensor(fusedOperand))
+  if (!consumer.isInput(fusedOperand))
     return false;
 
   // Get the consumer index map. The number of results of the consumer index
@@ -179,7 +179,7 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
     }
   }
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(fusedOperand) &&
+  assert(consumer.isInput(fusedOperand) &&
          "expected producer of input operand");
   // 3. Consumer input operands up to consumerIdx (exclusive).
   for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
@@ -267,7 +267,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   auto producer = cast<GenericOp>(producerResult.getOwner());
   auto consumer = cast<GenericOp>(fusedOperand->getOwner());
   // TODO: allow fusing the producer of an output operand.
-  assert(consumer.isInputTensor(fusedOperand) &&
+  assert(consumer.isInput(fusedOperand) &&
          "expected producer of input operand");
 
   // Compute the fused operands list and indexing maps.
@@ -278,13 +278,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
   fusedOutputOperands.reserve(producer.getNumOutputs() +
                               consumer.getNumOutputs());
   fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
-  fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
-                         consumer.getNumInputsAndOutputs());
+  fusedIndexMaps.reserve(producer->getNumOperands() +
+                         consumer->getNumOperands());
   // In the following, numbering matches that of `generateFusedTensorOpRegion`.
   // 3. Consumer input operands/maps up to consumerIdx (exclusive).
-  SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
-  SmallVector<OpOperand *>::iterator it =
-      llvm::find(consumerInputs, fusedOperand);
+  auto consumerInputs = consumer.getInputOperands();
+  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
+    return operand == fusedOperand;
+  });
   assert(it != consumerInputs.end() && "expected to find the consumer operand");
   for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
     fusedInputOperands.push_back(opOperand->get());
@@ -373,13 +374,13 @@ public:
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
-    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      if (!areElementwiseOpsFusable(opOperand))
+    for (OpOperand &opOperand : genericOp->getOpOperands()) {
+      if (!areElementwiseOpsFusable(&opOperand))
         continue;
-      if (!controlFn(opOperand))
+      if (!controlFn(&opOperand))
         continue;
 
-      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, opOperand);
+      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
       if (succeeded(fusedOp)) {
         auto replacements =
             fusedOp.value()->getResults().take_back(genericOp.getNumResults());
@@ -727,9 +728,9 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
                                                : collapsingReshapeOp.getSrc());
       continue;
     }
-    if (genericOp.isInputTensor(opOperand)) {
+    if (auto opOperandType =
+            opOperand->get().getType().dyn_cast<RankedTensorType>()) {
       AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
-      auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
       RankedTensorType expandedOperandType =
           getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
@@ -833,7 +834,7 @@ public:
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
       tensor::CollapseShapeOp reshapeOp =
           opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
@@ -1494,17 +1495,17 @@ public:
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+    for (OpOperand &opOperand : genericOp->getOpOperands()) {
       tensor::ExpandShapeOp reshapeOp =
-          opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
+          opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
       if (!reshapeOp)
         continue;
 
       SmallVector<ReassociationIndices> collapsableIterationDims =
-          getCollapsableIterationSpaceDims(genericOp, opOperand,
+          getCollapsableIterationSpaceDims(genericOp, &opOperand,
                                            reshapeOp.getReassociationIndices());
       if (collapsableIterationDims.empty() ||
-          !controlFoldingReshapes(opOperand)) {
+          !controlFoldingReshapes(&opOperand)) {
         continue;
       }
 
@@ -1614,7 +1615,7 @@ public:
       SmallVector<AffineMap> fusedIndexMaps;
       SmallVector<Value> fusedOperands;
       SmallVector<Location> fusedLocs{genericOp.getLoc()};
-      fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
+      fusedIndexMaps.reserve(genericOp->getNumOperands());
       fusedOperands.reserve(genericOp.getNumInputs());
       fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
       for (OpOperand *inputOperand : genericOp.getInputOperands()) {
@@ -1640,7 +1641,7 @@ public:
       Value scalarConstant = rewriter.create<arith::ConstantOp>(
           def->getLoc(), constantAttr, constantAttr.getType());
 
-      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+      SmallVector<Value> outputOperands = genericOp.getOutputs();
       auto fusedOp = rewriter.create<GenericOp>(
           rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
           /*inputs=*/fusedOperands,
index 5738d51..5c2c987 100644 (file)
@@ -68,7 +68,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                           bool fromSubViewOpOnly = false) {
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
+  for (OpOperand &opOperand : op->getOpOperands()) {
     // The method `getRangeFromOperandShape` requires using SubViewOp or
     // ExtractSliceOps. If the value isn't defined from there continue.
     // todo: The method should be adapted to get the values from
@@ -77,12 +77,12 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
     // `std` dialect and add the method to `ViewInterface`.
     if (fromSubViewOpOnly &&
         !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
-            opOperand->get().getDefiningOp()))
+            opOperand.get().getDefiningOp()))
       continue;
 
-    AffineMap map = op.getMatchingIndexingMap(opOperand);
+    AffineMap map = op.getMatchingIndexingMap(&opOperand);
     LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
-                            << opOperand->getOperandNumber() << "\n");
+                            << opOperand.getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
                << "getShapeDefiningLoopRange map: " << map << "\n");
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
@@ -94,8 +94,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
-                                << opOperand->get() << "\n");
-        return ShapeDimension{opOperand->get(),
+                                << opOperand.get() << "\n");
+        return ShapeDimension{opOperand.get(),
                               static_cast<unsigned>(en.index())};
       }
     }
@@ -104,7 +104,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
 }
 
 static SmallVector<Value> getTiledOperands(LinalgOp producer) {
-  return producer.getInputAndOutputOperands();
+  return producer->getOperands();
 }
 
 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
@@ -137,7 +137,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   }
 
   SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(producer.getNumInputsAndOutputs());
+  clonedShapes.reserve(producer->getNumOperands());
 
   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(
@@ -150,15 +150,18 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   // fully dynamic at construction time.
   SmallVector<Type, 4> resultTypes;
   resultTypes.reserve(producer->getNumResults());
-  for (RankedTensorType t : producer.getOutputTensorTypes()) {
-    unsigned rank = t.getRank();
+  for (OpOperand *operand : producer.getOutputOperands()) {
+    auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      continue;
+    unsigned rank = tensorType.getRank();
     SmallVector<int64_t, 4> staticOffsetsVector(
         rank, ShapedType::kDynamicStrideOrOffset);
     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
     SmallVector<int64_t, 4> staticStridesVector(
         rank, ShapedType::kDynamicStrideOrOffset);
     resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
-        t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
+        tensorType, staticOffsetsVector, staticSizesVector,
         staticStridesVector));
   }
 
index 2451c79..1dd6c35 100644 (file)
@@ -161,7 +161,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
     allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
   }
   erase_value(tileIvs, OpFoldResult());
-  SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
+  SmallVector<Value> tiledOperands = producerOp->getOperands();
   tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
                                   tileSizes, producerLoopBounds,
                                   /**omitPartialTileCheck=*/false);
index d656e92..ea6ce39 100644 (file)
@@ -50,19 +50,19 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
   if (failed(generalizeNamedOpPrecondition(linalgOp)))
     return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");
 
-  SmallVector<Value> inputOperands = linalgOp.getInputOperands();
-  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
+  SmallVector<Value> inputs = linalgOp.getInputOperands();
+  SmallVector<Value> outputs = linalgOp.getOutputOperands();
   SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
   SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
-  SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
+  SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
+                                      ? TypeRange(ValueRange(outputs))
+                                      : TypeRange{};
 
   // All named ops have a region attached that can be inlined.
   assert(linalgOp->getNumRegions() == 1 &&
          "expect named op to have one region attached");
-  GenericOp genericOp =
-      rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands,
-                                 outputOperands, indexingMaps, iterators);
+  GenericOp genericOp = rewriter.create<GenericOp>(
+      linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
   rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
                               genericOp.getRegion().begin());
   rewriter.replaceOp(linalgOp, genericOp->getResults());
index 2ca15dc..7515e30 100644 (file)
@@ -111,7 +111,7 @@ private:
 static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) {
   for (OpOperand &use : padOp.getResult().getUses()) {
     auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
-    if (!linalgUser || !linalgUser.isInputTensor(&use)) {
+    if (!linalgUser || !linalgUser.isInput(&use)) {
       LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp)
                         << "\nthat is not an input tensor of a LinalgOp, "
                         << "cannot hoist\n"
index 04e94b1..4ea889d 100644 (file)
@@ -43,7 +43,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
     SmallVector<Value> newOperands;
     for (OpOperand *opOperand : genericOp.getInputOperands()) {
       AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
-      if (genericOp.isInputTensor(opOperand) && map.isConstant()) {
+      if (genericOp.isInput(opOperand) && map.isConstant()) {
         scalarOperands.emplace_back(opOperand->getOperandNumber());
       } else {
         newIndexingMaps.emplace_back(map);
@@ -58,7 +58,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
       newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand));
 
     Location loc = genericOp->getLoc();
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+    SmallVector<Value> outputOperands = genericOp.getOutputs();
     auto newOp = rewriter.create<GenericOp>(
         loc, genericOp->getResultTypes(), newOperands, outputOperands,
         newIndexingMaps, genericOp.getIteratorTypesArray());
index 8641e11..a745387 100644 (file)
@@ -67,8 +67,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 
   // 2. Compute the interchanged indexing maps.
   SmallVector<AffineMap> newIndexingMaps;
-  for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-    AffineMap m = genericOp.getMatchingIndexingMap(opOperand);
+  for (OpOperand &opOperand : genericOp->getOpOperands()) {
+    AffineMap m = genericOp.getMatchingIndexingMap(&opOperand);
     if (!permutationMap.isEmpty())
       m = m.compose(permutationMap);
     newIndexingMaps.push_back(m);
index 3052a4d..4fc9149 100644 (file)
@@ -131,7 +131,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
   assert(linalgOp.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   SmallVector<Value> indexedValues;
-  indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
+  indexedValues.reserve(linalgOp->getNumOperands());
 
   auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
 
@@ -161,7 +161,9 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
   // 3. Emit store.
   SmallVector<SmallVector<Value>, 8> indexing;
   SmallVector<Value> outputBuffers;
-  for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
+  for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
+    if (!outputOperand->get().getType().isa<MemRefType>())
+      continue;
     indexing.push_back(makeCanonicalAffineApplies(
         b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
         allIvsPlusDims));
index 17d74fa..0995b01 100644 (file)
@@ -145,15 +145,15 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
   assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
   auto vUseFullTileBuffers =
       options.useFullTileBuffers.value_or(llvm::SmallBitVector());
-  vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(),
+  vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
                              options.useFullTileBuffersDefault);
 
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    int64_t operandNumber = opOperand->getOperandNumber();
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    int64_t operandNumber = opOperand.getOperandNumber();
     if (options.operandsToPromote &&
         !options.operandsToPromote->count(operandNumber))
       continue;
-    Operation *op = opOperand->get().getDefiningOp();
+    Operation *op = opOperand.get().getDefiningOp();
     if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
       subViews[operandNumber] = sv;
       useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
@@ -326,13 +326,13 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
   // operands are not views. This is to support cases such as FillOp taking
   // extra scalars etc.  Keep a reference to output buffers;
   SmallVector<Value, 8> opViews;
-  opViews.reserve(op.getNumInputsAndOutputs());
+  opViews.reserve(op->getNumOperands());
   SmallVector<std::pair<Value, Value>, 8> writebackViews;
   writebackViews.reserve(promotedBuffersAndViews->size());
-  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
-    int64_t operandNumber = opOperand->getOperandNumber();
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    int64_t operandNumber = opOperand.getOperandNumber();
     if (options.subViews.count(operandNumber) != 0) {
-      if (options.useFullTileBuffers[opOperand->get()])
+      if (options.useFullTileBuffers[opOperand.get()])
         opViews.push_back(
             (*promotedBuffersAndViews)[operandNumber].fullLocalView);
       else
@@ -340,10 +340,10 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
             (*promotedBuffersAndViews)[operandNumber].partialLocalView);
       if (operandNumber >= op.getNumInputs())
         writebackViews.emplace_back(std::make_pair(
-            opOperand->get(),
+            opOperand.get(),
             (*promotedBuffersAndViews)[operandNumber].partialLocalView));
     } else {
-      opViews.push_back(opOperand->get());
+      opViews.push_back(opOperand.get());
     }
   }
   op->setOperands(0, opViews.size(), opViews);
@@ -371,12 +371,12 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
   if (!linalgOp || !linalgOp.hasBufferSemantics())
     return failure();
   // Check that at least one of the requested operands is indeed a subview.
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
     auto sv =
-        isa_and_nonnull<memref::SubViewOp>(opOperand->get().getDefiningOp());
+        isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
     if (sv) {
       if (!options.operandsToPromote ||
-          options.operandsToPromote->count(opOperand->getOperandNumber()))
+          options.operandsToPromote->count(opOperand.getOperandNumber()))
         return success();
     }
   }
index 7df65c8..92d04c1 100644 (file)
@@ -214,7 +214,6 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
   // from the previous op.
   unsigned intermRank = newOutputShape.size();
   AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
-  SmallVector<Value> outputOperands = op.getOutputOperands();
   SmallVector<StringRef> reductionIteratorTypes;
   SmallVector<AffineExpr> exprs;
   for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
@@ -230,7 +229,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
 
   auto reduction = b.create<GenericOp>(
       loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
-      outputOperands, reductionMaps, reductionIteratorTypes,
+      SmallVector<Value>{op.getOutputOperands()}, reductionMaps,
+      reductionIteratorTypes,
       [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
         Operation *clonedReductionOp = b.clone(*reductionOp);
         clonedReductionOp->setOperand(0, inputs[0]);
@@ -341,8 +341,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   SmallVector<Operation *> emptyOrAllocTensorOps;
   SmallVector<linalg::FillOp> fillOps;
   fillOps.reserve(op.getNumOutputs());
-  for (auto it : llvm::zip(op.getOutputs(), neutralElements)) {
-    Value rankedTensor = std::get<0>(it);
+  for (auto it : llvm::zip(op.getOutputOperands(), neutralElements)) {
+    Value rankedTensor = std::get<0>(it)->get();
     auto t = rankedTensor.getType().cast<RankedTensorType>();
     RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
         reductionDimSize / splitFactor, insertSplitDimension);
@@ -366,7 +366,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // Step 2. Reindex / expand indexing maps.
   // Reindex existing input indexings: k -> k * splitFactor + k'.
   SmallVector<AffineMap> newMaps;
-  newMaps.reserve(op.getNumInputsAndOutputs() + 1);
+  newMaps.reserve(op->getNumOperands() + 1);
   for (OpOperand *o : op.getInputOperands())
     newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
   // Provision a new indexing for the shape-only tensor.
@@ -384,7 +384,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
 
   // Step 3. Handle operands.
   // Compute the new input tensors.
-  auto newInputs = llvm::to_vector<4>(op.getInputs());
+  SmallVector<Value> newInputs(op.getInputOperands());
   // Add a single shape-only tensor to carry the dimensions without resorting to
   // more complex inversions.
   newInputs.push_back(b.create<tensor::EmptyOp>(
@@ -413,10 +413,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
   // TODO: all results can be handled in a single GenericOp, when
   // multi-reduction support is available.
   SmallVector<LinalgOp> results;
-  for (auto it :
-       llvm::zip(genericOp->getResults(), op.getOutputs(), combinerOps)) {
+  for (auto it : llvm::zip(genericOp->getResults(), op.getOutputOperands(),
+                           combinerOps)) {
     Value reindexedOutput = std::get<0>(it);
-    Value originalOutput = std::get<1>(it);
+    Value originalOutput = std::get<1>(it)->get();
     auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
     Operation *combinerOp = std::get<2>(it);
 
index dd04d00..b66a718 100644 (file)
@@ -503,7 +503,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
     // Tile the `operandValuesToUse` that either match the `op` operands
     // themselves or the tile loop arguments forwarding them.
     assert(operandValuesToUse.size() ==
-               static_cast<size_t>(op.getNumInputsAndOutputs()) &&
+               static_cast<size_t>(op->getNumOperands()) &&
            "expect the number of operands and inputs and outputs to match");
     SmallVector<Value> valuesToTile = operandValuesToUse;
     SmallVector<OpFoldResult> sizeBounds =
index 66d55dc..d88b2c5 100644 (file)
@@ -125,14 +125,12 @@ struct LinalgOpTilingInterface
     // specified could lead to out of bounds accesses.
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
-    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Value> valuesToTile = linalgOp->getOperands();
     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
         b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
 
-    SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
-        linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
-          return tiledOperands[opOperand->getOperandNumber()].getType();
-        }));
+    SmallVector<Type> resultTensorTypes =
+        getTensorOutputTypes(linalgOp, tiledOperands);
 
     Operation *tiledOp =
         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
@@ -222,23 +220,23 @@ struct LinalgOpTilingInterface
       return op->emitOpError("expected operation to have buffer semantics");
 
     SmallVector<Value> indexedValues;
-    indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
+    indexedValues.reserve(linalgOp->getNumOperands());
     Location linalgOpLoc = op->getLoc();
     /// Load the data corresponding to the block arguments that
     /// represent input operands.
-    for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
-      if (!linalgOp.payloadUsesValueFromOperand(operand)) {
+    for (OpOperand &operand : linalgOp->getOpOperands()) {
+      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
         indexedValues.push_back(nullptr);
         continue;
       }
-      if (linalgOp.isScalar(operand)) {
-        indexedValues.push_back(operand->get());
+      if (linalgOp.isScalar(&operand)) {
+        indexedValues.push_back(operand.get());
         continue;
       }
       SmallVector<Value> indices = getIndicesForAccess(
-          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs);
+          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
       Value load =
-          builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
+          builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
       indexedValues.push_back(load);
     }
 
index 8eb41c5..eee454b 100644 (file)
@@ -203,10 +203,10 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
   b.setInsertionPointAfter(opToPad);
   // Make a copy of the shaped operands and update it.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad.getNumInputsAndOutputs());
-  for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
+  newOperands.reserve(opToPad->getNumOperands());
+  for (OpOperand &opOperand : opToPad->getOpOperands()) {
     FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
-        b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
+        b, opToPad, &opOperand, paddingDimensions, paddingValues, packPaddings);
     // Exit if `paddingDimensions` cannot be bounded statically.
     if (failed(paddedOperand))
       return failure();
@@ -327,15 +327,15 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
 
   // Hoist the padding.
   for (const auto &en : enumerate(options.hoistPaddings)) {
-    if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
+    if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
       break;
-    OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
-    auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
+    OpOperand &opOperand = paddedOp->getOpOperand(en.index());
+    auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
     if (!padOp || en.value() == 0)
       continue;
 
     // Fail hoisting if the operand shape is not fully static.
-    if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic))
+    if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic))
       return failure();
 
     tensor::PadOp hoistedOp;
index 5623a16..2b70155 100644 (file)
@@ -459,35 +459,35 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
   // 3. Turn all BBArgs into vector.transfer_read / load.
   Location loc = linalgOp.getLoc();
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber());
-    if (linalgOp.isScalar(opOperand)) {
-      bvm.map(bbarg, opOperand->get());
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber());
+    if (linalgOp.isScalar(&opOperand)) {
+      bvm.map(bbarg, opOperand.get());
       continue;
     }
     VectorType readType;
     AffineMap map;
     // TODO: can we keep this simplification?
-    // if (linalgOp.getShape(opOperand).empty()) {
+    // if (linalgOp.getShape(&opOperand).empty()) {
     //   readType = VectorType::get({}, bbarg.getType());
     // } else {
-    if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
+    if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) {
       map = inverseAndBroadcastProjectedPermutation(
-          linalgOp.getMatchingIndexingMap(opOperand));
+          linalgOp.getMatchingIndexingMap(&opOperand));
       readType = VectorType::get(commonVectorShape,
-                                 getElementTypeOrSelf(opOperand->get()));
+                                 getElementTypeOrSelf(opOperand.get()));
     } else {
       map = inversePermutation(
-          reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
-      readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
-                                 getElementTypeOrSelf(opOperand->get()));
+          reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand)));
+      readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)),
+                                 getElementTypeOrSelf(opOperand.get()));
     }
     // }
 
-    auto shape = linalgOp.getShape(opOperand);
+    auto shape = linalgOp.getShape(&opOperand);
     SmallVector<Value> indices(shape.size(), zero);
     Value readValue = b.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, map);
+        loc, readType, opOperand.get(), indices, map);
     // Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readValue.getType().cast<VectorType>().getRank() == 0)
@@ -495,7 +495,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
 
     LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
     bvm.map(bbarg, readValue);
-    bvm.map(opOperand->get(), readValue);
+    bvm.map(opOperand.get(), readValue);
   }
 
   SmallVector<CustomVectorizationHook> hooks;
@@ -1342,9 +1342,9 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
     // Determine whether `linalgOp` can be generated with this generator
     if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
       return;
-    lhsShaped = linalgOp.getInputs()[0];
-    rhsShaped = linalgOp.getInputs()[1];
-    resShaped = linalgOp.getOutputs()[0];
+    lhsShaped = linalgOp.getInputOperand(0)->get();
+    rhsShaped = linalgOp.getInputOperand(1)->get();
+    resShaped = linalgOp.getOutputOperand(0)->get();
     lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
     rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
     resShapedType = resShaped.getType().dyn_cast<ShapedType>();
index 999034b..119c3db 100644 (file)
@@ -490,17 +490,18 @@ void GenerateLoopNest<scf::ForOp>::doit(
   assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
          "expected as many entries for proc info as number of loops, even if "
          "they are null entries");
-  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
+  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
+                                             ? SmallVector<Value>{}
+                                             : linalgOp.getOutputOperands();
 
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
   LoopNest loopNest = mlir::scf::buildLoopNest(
       b, loc, lbs, ubs, steps, iterArgInitValues,
       [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
-        assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
+        assert(iterArgs.size() == iterArgInitValues.size() &&
                "expect the number of output tensors and iter args to match");
-        SmallVector<Value> operandValuesToUse =
-            linalgOp.getInputAndOutputOperands();
+        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
         if (!iterArgs.empty()) {
           operandValuesToUse = linalgOp.getInputOperands();
           operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
@@ -530,7 +531,9 @@ void GenerateLoopNest<AffineForOp>::doit(
                                   ValueRange)>
         bodyBuilderFn,
     ArrayRef<linalg::ProcInfo> /*procInfo*/) {
-  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
+  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
+                                             ? SmallVector<Value>{}
+                                             : linalgOp.getOutputOperands();
   assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
   SmallVector<Value, 4> lbs, ubs, steps;
   unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
@@ -546,9 +549,8 @@ void GenerateLoopNest<AffineForOp>::doit(
 
   mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                             [&](OpBuilder &b, Location loc, ValueRange ivs) {
-                              SmallVector<Value> operandValuesToUse =
-                                  linalgOp.getInputAndOutputOperands();
-                              bodyBuilderFn(b, loc, ivs, operandValuesToUse);
+                              bodyBuilderFn(b, loc, ivs,
+                                            linalgOp->getOperands());
                             });
 }
 
@@ -695,7 +697,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
                                   ValueRange)>
         bodyBuilderFn,
     ArrayRef<linalg::ProcInfo> procInfo) {
-  SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
+  SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
+                                             ? SmallVector<Value>{}
+                                             : linalgOp.getOutputOperands();
   assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
   // This function may be passed more iterator types than ranges.
   assert(iteratorTypes.size() >= loopRanges.size() &&
@@ -725,9 +729,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
   generateParallelLoopNest(
       b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
       [&](OpBuilder &b, Location loc, ValueRange ivs) {
-        SmallVector<Value> operandValuesToUse =
-            linalgOp.getInputAndOutputOperands();
-        bodyBuilderFn(b, loc, ivs, operandValuesToUse);
+        bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
       },
       ivs);
 
@@ -905,10 +907,10 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
 }
 
 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
-  // TODO: use an interface/adaptor to avoid leaking position in
-  // `tiledOperands`.
+  if (op.hasBufferSemantics())
+    return {};
   return llvm::to_vector(
-      llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
+      llvm::map_range(op.getOutputOperands(), [&](OpOperand *opOperand) {
         return operands[opOperand->getOperandNumber()].getType();
       }));
 }
@@ -916,11 +918,13 @@ SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
                                     LinalgOp op, ValueRange operands,
                                     ValueRange results) {
+  if (op.hasBufferSemantics())
+    return {};
   SmallVector<Value> tensorResults;
   tensorResults.reserve(results.size());
   // Insert a insert_slice for each output tensor.
   unsigned resultIdx = 0;
-  for (OpOperand *opOperand : op.getOutputTensorOperands()) {
+  for (OpOperand *opOperand : op.getOutputOperands()) {
     // TODO: use an interface/adaptor to avoid leaking position in
     // `tiledOperands`.
     Value outputTensor = operands[opOperand->getOperandNumber()];
@@ -958,23 +962,26 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
       computeTileSizes(builder, loc, tileSizes, sizeBounds);
 
   assert(static_cast<int64_t>(valuesToTile.size()) ==
-             linalgOp.getNumInputsAndOutputs() &&
+             linalgOp->getNumOperands() &&
          "expected one value to tile for every operand");
   SmallVector<Optional<SliceParameters>> allSliceParams;
   allSliceParams.reserve(valuesToTile.size());
-  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-    Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
-    AffineMap map = linalgOp.getMatchingIndexingMap(opOperand);
+    AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
     // Use `opOperand` as is if it is not tiled and not an output tensor. Having
     // an extract/insert slice pair for all output tensors simplifies follow up
     // transformations such as padding and bufferization since the
     // extract/insert slice pairs make the accessed iteration argument
     // subdomains explicit.
-    if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
+
+    Type operandType = opOperand.get().getType();
+    if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
+                                      linalgOp.isOutput(&opOperand))) {
       allSliceParams.push_back(llvm::None);
-      LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
-                              << opOperand->get().getType() << "\n");
+      LLVM_DEBUG(llvm::dbgs()
+                 << ": not tiled: use shape: " << operandType << "\n");
       continue;
     }
     LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
index 2458dab..73f428a 100644 (file)
@@ -105,8 +105,7 @@ static bool isZeroYield(GenericOp op) {
   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
   if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
     if (arg.getOwner()->getParentOp() == op) {
-      OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
-      return isZeroValue(t->get());
+      return isZeroValue(op->getOperand(arg.getArgNumber()));
     }
   }
   return isZeroValue(yieldOp.getOperand(0));
@@ -242,8 +241,8 @@ public:
       return failure();
     // Modify operand structure of producer and consumer.
     Location loc = prod.getLoc();
-    SmallVector<Value> inputOps = prod.getInputOperands();
-    SmallVector<Value> outputOps = op.getOutputOperands();
+    SmallVector<Value> inputOps = prod.getInputs();
+    SmallVector<Value> outputOps = op.getOutputs();
     SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
     inputOps.push_back(op.getInputOperand(1 - other)->get());
     fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
index 1418ed4..e512723 100644 (file)
@@ -194,14 +194,14 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
 /// no annotations are found or inadmissible constructs occur.
 static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
   bool annotated = false;
-  for (OpOperand *t : op.getInputAndOutputOperands()) {
-    auto map = op.getMatchingIndexingMap(t);
-    auto enc = getSparseTensorEncoding(t->get().getType());
+  for (OpOperand &t : op->getOpOperands()) {
+    auto map = op.getMatchingIndexingMap(&t);
+    auto enc = getSparseTensorEncoding(t.get().getType());
     if (enc)
       annotated = true;
-    assert(map.getNumResults() == op.getRank(t));
+    assert(map.getNumResults() == op.getRank(&t));
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
-      unsigned tensor = t->getOperandNumber();
+      unsigned tensor = t.getOperandNumber();
       AffineExpr a = map.getResult(toOrigDim(enc, d));
       if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
         return false; // inadmissible affine expression
@@ -291,13 +291,13 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
   auto iteratorTypes = op.getIteratorTypesArray();
   // Iterate over the indexing maps of every tensor in the tensor expression.
-  for (OpOperand *t : op.getInputAndOutputOperands()) {
+  for (OpOperand &t : op->getOpOperands()) {
     // Skip tensor during cycle resolution.
-    if (t == skip)
+    if (&t == skip)
       continue;
     // Get map and encoding.
-    auto map = op.getMatchingIndexingMap(t);
-    auto enc = getSparseTensorEncoding(t->get().getType());
+    auto map = op.getMatchingIndexingMap(&t);
+    auto enc = getSparseTensorEncoding(t.get().getType());
     assert(map.getNumDims() == n);
     // Skip dense tensor constraints when not requested.
     if (!(mask & SortMask::kIncludeDense) && !enc)
@@ -314,7 +314,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
     // Push unrelated loops into sparse iteration space, so these
     // will be skipped more often.
     if (mask & SortMask::kIncludeUndef) {
-      unsigned tensor = t->getOperandNumber();
+      unsigned tensor = t.getOperandNumber();
       for (unsigned i = 0; i < n; i++)
         if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
             merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
@@ -534,16 +534,16 @@ static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
 static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
                        linalg::GenericOp op) {
   Location loc = op.getLoc();
-  assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
+  assert(op->getNumOperands() == op.getNumInputs() + 1);
   // For every tensor, find lower and upper bound on dimensions, set the
   // same bounds on loop indices, and obtain dense or sparse buffer(s).
   auto dynShape = {ShapedType::kDynamicSize};
   SmallVector<Value, 4> args;
-  for (OpOperand *t : op.getInputAndOutputOperands()) {
-    unsigned tensor = t->getOperandNumber();
-    auto shape = op.getShape(t);
-    auto map = op.getMatchingIndexingMap(t);
-    auto enc = getSparseTensorEncoding(t->get().getType());
+  for (OpOperand &t : op->getOpOperands()) {
+    unsigned tensor = t.getOperandNumber();
+    auto shape = op.getShape(&t);
+    auto map = op.getMatchingIndexingMap(&t);
+    auto enc = getSparseTensorEncoding(t.get().getType());
     // Scan all dimensions of current tensor.
     args.clear();
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
@@ -560,23 +560,23 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
             MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
         auto dim = builder.getIndexAttr(d);
         codegen.pointers[tensor][idx] =
-            builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
+            builder.create<ToPointersOp>(loc, ptrTp, t.get(), dim);
         codegen.indices[tensor][idx] =
-            builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+            builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
       } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
         // Singleton dimension, fetch indices.
         auto indTp =
             MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
         auto dim = builder.getIndexAttr(d);
         codegen.indices[tensor][idx] =
-            builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
+            builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
       } else {
         // Dense dimension, nothing to fetch.
         assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
       }
       // Find upper bound in current dimension.
       unsigned p = toOrigDim(enc, d);
-      Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
+      Value up = linalg::createOrFoldDimOp(builder, loc, t.get(), p);
       if (ShapedType::isDynamic(shape[p]))
         args.push_back(up);
       assert(codegen.highs[tensor][idx] == nullptr);
@@ -585,21 +585,21 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     // Perform the required bufferization. Dense inputs materialize
     // from the input tensors. Dense outputs need special handling.
     // Sparse inputs use sparse primitives to obtain the values.
-    Type elementType = getElementTypeOrSelf(t->get().getType());
+    Type elementType = getElementTypeOrSelf(t.get().getType());
     if (!enc) {
       // Non-annotated dense tensors.
       auto denseTp = MemRefType::get(shape, elementType);
       if (tensor < op.getNumInputs())
         codegen.buffers[tensor] =
-            builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
+            builder.create<bufferization::ToMemrefOp>(loc, denseTp, t.get());
       else
         codegen.buffers[tensor] =
             genOutputBuffer(codegen, builder, op, denseTp, args);
-    } else if (t != codegen.sparseOut) {
+    } else if (&t != codegen.sparseOut) {
       // Annotated sparse tensors (not involved in output).
       auto sparseTp = MemRefType::get(dynShape, elementType);
       codegen.buffers[tensor] =
-          builder.create<ToValuesOp>(loc, sparseTp, t->get());
+          builder.create<ToValuesOp>(loc, sparseTp, t.get());
     }
   }
 }
@@ -845,15 +845,15 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     return val;
   }
   // Load during insertion.
-  OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
-  if (t == codegen.sparseOut) {
+  OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
+  if (&t == codegen.sparseOut) {
     if (codegen.redCustom != -1u)
-      return genInsertionLoadReduce(merger, codegen, builder, op, t);
-    return genInsertionLoad(codegen, builder, op, t);
+      return genInsertionLoadReduce(merger, codegen, builder, op, &t);
+    return genInsertionLoad(codegen, builder, op, &t);
   }
   // Actual load.
   SmallVector<Value, 4> args;
-  Value ptr = genSubscript(codegen, builder, op, t, args);
+  Value ptr = genSubscript(codegen, builder, op, &t, args);
   if (codegen.curVecLength > 1)
     return genVectorLoad(codegen, builder, ptr, args);
   return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
@@ -1093,9 +1093,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
   if (merger.exp(exp).kind == Kind::kTensor) {
     // Inspect tensor indices.
     bool atLevel = ldx == -1u;
-    OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
-    auto map = op.getMatchingIndexingMap(t);
-    auto enc = getSparseTensorEncoding(t->get().getType());
+    OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
+    auto map = op.getMatchingIndexingMap(&t);
+    auto enc = getSparseTensorEncoding(t.get().getType());
     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
       AffineExpr a = map.getResult(toOrigDim(enc, d));
       if (!isInvariantAffine(codegen, a, ldx, atLevel))
@@ -1105,7 +1105,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
     if (!atLevel)
       return;
     OpOperand *lhs = op.getOutputOperand(0);
-    if (lhs == t) {
+    if (lhs == &t) {
       // Start or end a scalarized reduction
       if (atStart) {
         Kind kind = merger.exp(last).kind;
@@ -1288,9 +1288,9 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
 /// This prevents effective vectorization.
 static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
                              unsigned idx) {
-  for (OpOperand *t : op.getInputAndOutputOperands()) {
-    if (!getSparseTensorEncoding(t->get().getType())) {
-      auto map = op.getMatchingIndexingMap(t);
+  for (OpOperand &t : op->getOpOperands()) {
+    if (!getSparseTensorEncoding(t.get().getType())) {
+      auto map = op.getMatchingIndexingMap(&t);
       for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
         AffineExpr a = map.getResult(d);
         // Report non-unit stride if innermost index appears at an outer
@@ -1856,7 +1856,7 @@ public:
     // information for all tensors to loop indices in the kernel.
     if (op.getNumOutputs() != 1)
       return failure();
-    unsigned numTensors = op.getNumInputsAndOutputs();
+    unsigned numTensors = op->getNumOperands();
     unsigned numLoops = op.getNumLoops();
     Merger merger(numTensors, numLoops);
     if (!findSparseAnnotations(merger, op))
index 187a6c0..b8f6a93 100644 (file)
@@ -910,10 +910,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
     // argument is considered a tensor, indexed by the implicit loop
     // bounds. This includes rank-0 tensor arguments.
     if (arg.getOwner()->getParentOp() == op) {
-      OpOperand *t = op.getInputAndOutputOperands()[argN];
-      if (!op.isScalar(t))
+      OpOperand &t = op->getOpOperand(argN);
+      if (!op.isScalar(&t))
         return addExp(kTensor, argN);
-      v = t->get(); // get scalar value
+      v = t.get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
index 43589f7..2062c65 100644 (file)
@@ -275,7 +275,7 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 // -----
 
 // CHECK-LABEL: func @remove_deadargs_generic_basic
-//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> { 
+//  CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 //       CHECK: %[[GENERIC_OP:.*]] = linalg.generic
 //  CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
 //  CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
index 02471b1..f751ddf 100644 (file)
@@ -121,26 +121,6 @@ func.func @generic(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offset: ?>>
 //  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
 //  CHECK-SAME:     {foo = 1 : i64}
 
-func.func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
-                                %arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %cst = arith.constant 0.0 : f32
-  linalg.generic #trait_0
-       ins(%arg0, %cst : tensor<?x?xvector<3x4xi4>>, f32)
-      outs(%arg1 : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
-      attrs = {foo = 1} {
-    ^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
-      linalg.yield %1 : f32
-  }
-  return
-}
-// CHECK-LABEL: func @generic_with_tensor_input
-//       CHECK:   linalg.generic {
-//  CHECK-SAME:     indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
-//  CHECK-SAME:     library_call = "some_external_function_name_1"}
-//  CHECK-SAME:     ins({{.*}}, {{.*}} : tensor<?x?xvector<3x4xi4>>, f32)
-//  CHECK-SAME:     outs({{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
-//  CHECK-SAME:     {foo = 1 : i64}
-
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -300,27 +280,19 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs
 
 func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
                 %ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
-  -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  -> (tensor<?x?x?xf32>)
 {
   linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%c3: memref<?x?x?xf32>)
-  linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-                     outs(%c3: memref<?x?x?xf32>)
   %res1 = linalg.batch_matmul
                       ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
                      outs(%tc3: tensor<?x?x?xf32>)
                   -> tensor<?x?x?xf32>
-  %res2 = linalg.batch_matmul
-                      ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
-                     outs(%tc3: tensor<?x?x?xf32>)
-                  -> tensor<?x?x?xf32>
-  return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+  return %res1 : tensor<?x?x?xf32>
 }
 // CHECK-LABEL: func @named_ops
 //       CHECK:   linalg.batch_matmul
 //       CHECK:   linalg.batch_matmul
-//       CHECK:   linalg.batch_matmul
-//       CHECK:   linalg.batch_matmul
 
 // -----
 
index 0119516..5807726 100644 (file)
@@ -26,7 +26,7 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
     return;
   TypeSwitch<Operation *, void>(op)
       .Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
-        SmallVector<Value> inputOperands = linalgOp.getInputOperands();
+        SmallVector<Value> inputOperands{linalgOp.getInputOperands()};
         operandSet.insert(inputOperands.begin(), inputOperands.end());
       })
       .Default([&](Operation *operation) {
@@ -147,7 +147,7 @@ struct TestLinalgElementwiseFusion
               if (expandOp->hasOneUse()) {
                 OpOperand &use = *expandOp->getUses().begin();
                 auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
-                if (linalgOp && linalgOp.isOutputTensor(&use))
+                if (linalgOp && linalgOp.isOutput(&use))
                   return true;
               }
               return false;
index c5b27c5..3c62496 100644 (file)
@@ -38,14 +38,14 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
-    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
-      if (opOperand->get().getType().isa<MemRefType>()) {
+    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+      if (opOperand.get().getType().isa<MemRefType>()) {
         // TODO: LinalgDependenceGraph should be able to update itself.
         // The current naive and expensive reconstruction of the graph should be
         // removed.
         linalg::Aliases aliases;
         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        auto info = fuseProducerOfBuffer(b, *opOperand, graph);
+        auto info = fuseProducerOfBuffer(b, opOperand, graph);
         if (failed(info))
           continue;
         auto *originalOp = info->originalProducer.getOperation();
@@ -54,11 +54,11 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
             std::find(linalgOps.begin(), linalgOps.end(), originalOp);
         *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
         changed = true;
-      } else if (opOperand->get().getType().isa<RankedTensorType>()) {
+      } else if (opOperand.get().getType().isa<RankedTensorType>()) {
         // Tile and Fuse tensor input.
-        if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
+        if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
           continue;
-        auto info = fuseProducerOfTensor(b, *opOperand);
+        auto info = fuseProducerOfTensor(b, opOperand);
         if (failed(info))
           continue;
         auto *originalOp = info->originalProducer.getOperation();
index 85c5f32..79ed068 100644 (file)
@@ -2835,9 +2835,10 @@ def TestLinalgConvOp :
       return "";
     }
 
-    // To conform with interface requirement on operand naming.
-    mlir::ValueRange inputs() { return getInputs(); }
-    mlir::ValueRange outputs() { return getOutputs(); }
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      int64_t getNumOperands = this->getNumOperands();
+      return {getNumOperands - 1, getNumOperands};
+    }
   }];
 }
 
@@ -2894,9 +2895,10 @@ def TestLinalgFillOp :
       return "";
     }
 
-    // To conform with interface requirement on operand naming.
-    mlir::ValueRange inputs() { return getInputs(); }
-    mlir::ValueRange outputs() { return getOutputs(); }
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      int64_t getNumOperands = this->getNumOperands();
+      return {getNumOperands - 1, getNumOperands};
+    }
   }];
 }
 
index 8156bb9..d8e10ef 100644 (file)
@@ -563,6 +563,11 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
         return regionBuilder;
       }
 
+      std::pair<int64_t, int64_t> getOutputsPositionRange() {{
+        int64_t getNumOperands = this->getNumOperands();
+        return {{getNumOperands - 1, getNumOperands};
+      }
+
       // Generic methods.
       static unsigned getNumRegionArgs();
       std::string getLibraryCallName();
@@ -638,8 +643,8 @@ ArrayAttr {0}::getIndexingMaps() {{
   AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
     getNumParallelLoops(), context);
   SmallVector<AffineMap> indexingMaps;
-  for (OpOperand *opOperand : getInputAndOutputOperands())
-    indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
+  for (OpOperand &opOperand : getOperation()->getOpOperands())
+    indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
   return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
 }
 )FMT";
@@ -654,10 +659,9 @@ LogicalResult {0}::fold(ArrayRef<Attribute>,
 }
 void {0}::getEffects(SmallVectorImpl<
     SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
-      SmallVector<Value> inputBuffers = getInputBufferOperands();
-      SmallVector<Value> outputBuffers = getOutputBufferOperands();
+      if (hasTensorSemantics()) return;
       getGenericEffectsImpl(effects,
-        getOperation()->getResults(), inputBuffers, outputBuffers);
+        getOperation()->getResults(), getInputOperands(), getOutputOperands());
 }
 )FMT";