[mlir][interfaces][NFC] Move DestinationStyleOpInterface to mlir/Interfaces
authorMatthias Springer <springerm@google.com>
Tue, 18 Oct 2022 15:23:42 +0000 (17:23 +0200)
committerMatthias Springer <springerm@google.com>
Tue, 18 Oct 2022 15:39:06 +0000 (17:39 +0200)
This is the second (and final) step of making "destination style" usable without depending on the Linalg dialect. (The first step was D135129.)

This change allows us to provide default bufferization implementations for all destination-style ops. It also allows us to simplify `TilingInterface`. (E.g., `getDestinationOperands` can be removed.)

Differential Revision: https://reviews.llvm.org/D136179

17 files changed:
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h [new file with mode: 0644]
mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td [new file with mode: 0644]
mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Interfaces/CMakeLists.txt
mlir/lib/Interfaces/DestinationStyleOpInterface.cpp [new file with mode: 0644]
mlir/test/lib/Dialect/Test/CMakeLists.txt
mlir/test/lib/Dialect/Test/TestOps.td
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

index 70e7fc9..28c75fc 100644 (file)
@@ -20,6 +20,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
index 39ca855..8e3df10 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
@@ -26,11 +27,6 @@ namespace mlir {
 namespace linalg {
 class LinalgOp;
 
-/// OpOperand vector that implicitly converts to a Value vector.
-struct OpOperandVector : public SmallVector<OpOperand *> {
-  operator SmallVector<Value>();
-};
-
 namespace detail {
 /// Implementation of the method that that check if given operands
 /// can be dropped, i.e. the remaining operands can compute the loop
@@ -57,9 +53,6 @@ LogicalResult verifyFillInterface(Operation *op);
 /// Verify that `op` conforms to the invariants of StructuredOpInterface
 LogicalResult verifyStructuredOpInterface(Operation *op);
 
-/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
-LogicalResult verifyDestinationStyleOpInterface(Operation *op);
-
 } // namespace detail
 } // namespace linalg
 } // namespace mlir
index 995ced5..28bf0b0 100644 (file)
@@ -879,291 +879,4 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
   let verifyWithRegions = 1;
 }
 
-// Ops that are in destination style have designated output operands, which act
-// as initial tensor values for the results of the operation or the output
-// buffers to which the results of the op will be written.
-//
-// Output operands must be tensors or memrefs. Input operands can have any
-// type. All non-output operands are inputs.
-
-// It is assumed that the output operands of the op are the operands at
-// position [start, end). The positions are defined by getOutputsPositionRange
-// method. All non-output operands are "inputs" of the DPS op.
-
-// If the op has "tensor semantics", then the input operands are either scalars
-// or tensors. The output operands are tensors and every tensor output is tied
-// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
-// tensor is tied to the i-th OpResult. The op may not have any additional
-// OpResults. Output operands and their tied OpResults have the same type.
-//
-// If the op has "buffer semantics", then the input operands are either memrefs
-// or other non-tensor types, e.g. scalar types. Furthermore, the output
-// operands are memrefs and the op has no results.
-//
-// Destination-passing style abstraction makes certain transformations easier.
-// For example, tiling implementation can extract/insert slices from/into the
-// destination of an op and use the resulting shaped value as an iter_arg in
-// the surrounding loop structure. As another example, bufferization does not
-// have to allocate new buffers for destinations (in case of in-place
-// bufferization) and can directly reuse the existing destination buffer.
-//
-// Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
-// where `%t` is the single input and `%d` is the single output. `%d` is tied
-// to `%r`.
-//
-// Example of an op that is not in destination style: `%r = tensor.pad %t`.
-// This op is not in destination style because `%r` and `%t` have different
-// shape.
-//
-// Each op that wants to implement DestinationStyleOpInterface needs to define
-// the getOutputsPositionRange() method.
-def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
-  let cppNamespace = "::mlir::linalg";
-  let methods = [
-    // This method has to be defined for every DPS op.
-    InterfaceMethod<
-      /*desc=*/"Return start and end indices of the output operands range.",
-      /*retTy=*/"std::pair<int64_t, int64_t>",
-      /*methodName=*/"getOutputsPositionRange",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/""
-    >,
-    //===------------------------------------------------------------------===//
-    // Operands handling.
-    //===------------------------------------------------------------------===//
-    // The operand list is assumed to start with the input operands and end
-    // with the output operands. Therefore, all methods to access the inputs
-    // and outputs can be expressed if the number of output operands is know.
-    InterfaceMethod<
-      /*desc=*/"Return the number of outputs.",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumOutputs",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getOutputsPositionRange();
-        return end - start;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the output operands.",
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getOutputOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getOutputsPositionRange();
-
-        OpOperandVector result;
-        result.reserve(end - start);
-        for (int i = start; i < end; ++i)
-          result.push_back(&$_op->getOpOperand(i));
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the `i`-th output operand.",
-      /*retTy=*/"OpOperand*",
-      /*methodName=*/"getOutputOperand",
-      /*args=*/(ins "int64_t":$i),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < $_op.getNumOutputs());
-        auto [start, end] = $_op.getOutputsPositionRange();
-        return &$_op->getOpOperand(start + i);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Set the `i`-th output operand.",
-      /*retTy=*/"void",
-      /*methodName=*/"setOutputOperand",
-      /*args=*/(ins "int64_t":$i, "Value":$value),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < $_op.getNumOutputs());
-        auto [start, end] = $_op.getOutputsPositionRange();
-        $_op->setOperand(start + i, value);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the number of inputs.",
-      /*retTy=*/"int64_t",
-      /*methodName=*/"getNumInputs",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return $_op.getNumOperands() - $_op.getNumOutputs();
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the input operands.",
-      /*retTy=*/"OpOperandVector",
-      /*methodName=*/"getInputOperands",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getOutputsPositionRange();
-        int64_t numOutputs = end - start;
-        int64_t numOperands = $_op.getNumOperands();
-
-        OpOperandVector result;
-        result.reserve(numOperands - numOutputs);
-        for (int i = 0; i < start; ++i)
-          result.push_back(&$_op->getOpOperand(i));
-        for (int i = end; i < numOperands; ++i)
-          result.push_back(&$_op->getOpOperand(end + i));
-
-        return result;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{ Return the `i`-th input operand.  }],
-      /*retTy=*/"OpOperand*",
-      /*methodName=*/"getInputOperand",
-      /*args=*/(ins "int64_t":$i),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(i >= 0 && i < getNumInputs());
-        auto [start, end] = $_op.getOutputsPositionRange();
-        return &$_op->getOpOperand(i < start ? i : i + end - start) ;
-      }]
-    >,
-    //===------------------------------------------------------------------===//
-    // Input and Output arguments handling.
-    //===------------------------------------------------------------------===//
-    InterfaceMethod<
-      /*desc=*/"Return true if `opOperand` is an input.",
-      /*retTy=*/"bool",
-      /*methodName=*/"isInput",
-      /*args=*/(ins "OpOperand *":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getOutputsPositionRange();
-        auto operandNumber = opOperand->getOperandNumber();
-        return operandNumber < start || operandNumber >= end;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return true if `opOperand` is an output.",
-      /*retTy=*/"bool",
-      /*methodName=*/"isOutput",
-      /*args=*/(ins "OpOperand *":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        auto [start, end] = $_op.getOutputsPositionRange();
-        auto operandNumber = opOperand->getOperandNumber();
-        return operandNumber >= start && operandNumber < end;
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return true if the `opOperand` is a scalar value.",
-      /*retTy=*/"bool",
-      /*methodName=*/"isScalar",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-        return !opOperand->get().getType().template isa<ShapedType>();
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return the result tied to `opOperand`.",
-      /*retTy=*/"OpResult",
-      /*methodName=*/"getTiedOpResult",
-      /*args=*/(ins "OpOperand*":$opOperand),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        assert(opOperand->getOwner() == this->getOperation());
-
-        auto [start, end] = $_op.getOutputsPositionRange();
-        int64_t resultIndex = opOperand->getOperandNumber() - start;
-        assert(resultIndex >= 0 &&
-               resultIndex < $_op->getNumResults() );
-        return $_op->getResult(resultIndex);
-      }]
-    >,
-    //===------------------------------------------------------------------===//
-    // Other interface methods.
-    //===------------------------------------------------------------------===//
-    InterfaceMethod<
-      /*desc=*/"Return whether the op has only MemRef input and outputs.",
-      /*retTy=*/"bool",
-      /*methodName=*/"hasBufferSemantics",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return $_op->getNumResults() == 0 &&
-          llvm::all_of($_op->getOpOperands(),
-            [&](OpOperand &opOperand) {
-              return isScalar(&opOperand) ||
-                     opOperand.get().getType().template isa<MemRefType>();
-            });
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/"Return whether the op has only RankedTensor input and outputs.",
-      /*retTy=*/"bool",
-      /*methodName=*/"hasTensorSemantics",
-      /*args=*/(ins),
-      /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-        return llvm::all_of($_op->getOpOperands(),
-          [&](OpOperand &opOperand) {
-            return isScalar(&opOperand) ||
-                   opOperand.get().getType().template isa<RankedTensorType>();
-          });
-      }]
-    >,
-    //===------------------------------------------------------------------===//
-    // Other static interface methods.
-    //===------------------------------------------------------------------===//
-    InterfaceMethod<
-      /*desc=*/[{
-        Clone the current operation with the given location and operands. This
-        is used to abstract away the optional underlying region creation. This
-        does not change the balance between input, output_buffer and
-        init_tensors operands.
-      }],
-      /*retTy=*/"Operation *",
-      /*methodName=*/"clone",
-      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
-           "ValueRange":$operands),
-      [{
-        BlockAndValueMapping bvm;
-        OperationState state(
-          loc, ConcreteOp::getOperationName(), operands, resultTypes,
-          $_op->getAttrs());
-        for (Region &r : $_op->getRegions())
-          r.cloneInto(state.addRegion(), bvm);
-        return b.create(state);
-      }]
-    >,
-    InterfaceMethod<
-      /*desc=*/[{
-        Clone the current operation with the given location, operands
-        and BlockAndValueMapping but leave the regions empty. This is
-        used to abstract away the optional underlying region creation.
-        This does not change the balance between input, output_buffer
-        and init_tensors operands.
-      }],
-      /*retTy=*/"Operation *",
-      /*methodName=*/"cloneWithoutRegions",
-      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
-           "ValueRange":$operands),
-      [{
-        OperationState state(
-          loc, ConcreteOp::getOperationName(), operands, resultTypes,
-          $_op->getAttrs());
-        for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
-          state.addRegion();
-        return b.create(state);
-      }]
-    >
-  ];
-
-  let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
-  let verifyWithRegions = 1;
-}
-
 #endif // LINALG_IR_LINALGINTERFACES
index 3b06a59..4b83de1 100644 (file)
@@ -17,6 +17,7 @@
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
@@ -279,7 +280,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
       int64_t getNumOperands = this->getNumOperands();
       return {getNumOperands - 1, getNumOperands};
     }
-    linalg::OpOperandVector getOpOperandsMatchingBBargs() {
+    OpOperandVector getOpOperandsMatchingBBargs() {
       return getInputOperands();
     }
 
index e471b9e..721a9de 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(CopyOpInterface)
 add_mlir_interface(DerivedAttributeOpInterface)
+add_mlir_interface(DestinationStyleOpInterface)
 add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.h
new file mode 100644 (file)
index 0000000..fe8478e
--- /dev/null
@@ -0,0 +1,34 @@
+//===- DestinationStyleOpInterface.h ----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
+#define MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
+
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Value.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+/// OpOperand vector that implicitly converts to a Value vector.
+struct OpOperandVector : public llvm::SmallVector<OpOperand *> {
+  operator SmallVector<Value>();
+};
+
+namespace detail {
+/// Verify that `op` conforms to the invariants of DestinationStyleOpInterface
+LogicalResult verifyDestinationStyleOpInterface(Operation *op);
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/DestinationStyleOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_DESTINATIONSTYLEOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
new file mode 100644 (file)
index 0000000..718b32d
--- /dev/null
@@ -0,0 +1,306 @@
+//===- DestinationStyleOpInterface.td ----------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DESTINATIONSTYLEOPINTERFACE
+#define MLIR_DESTINATIONSTYLEOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
+  let description = [{
+    Ops that are in destination style have designated output operands, which act
+    as initial tensor values for the results of the operation or the output
+    buffers to which the results of the op will be written.
+
+    Output operands must be tensors or memrefs. Input operands can have any
+    type. All non-output operands are inputs.
+
+    It is assumed that the output operands of the op are the operands at
+    position [start, end). The positions are defined by getOutputsPositionRange
+    method. All non-output operands are "inputs" of the DPS op.
+
+    If the op has "tensor semantics", then the input operands are either scalars
+    or tensors. The output operands are tensors and every tensor output is tied
+    to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
+    tensor is tied to the i-th OpResult. The op may not have any additional
+    OpResults. Output operands and their tied OpResults have the same type.
+
+    If the op has "buffer semantics", then the input operands are either memrefs
+    or other non-tensor types, e.g. scalar types. Furthermore, the output
+    operands are memrefs and the op has no results.
+
+    Destination-passing style abstraction makes certain transformations easier.
+    For example, tiling implementation can extract/insert slices from/into the
+    destination of an op and use the resulting shaped value as an iter_arg in
+    the surrounding loop structure. As another example, bufferization does not
+    have to allocate new buffers for destinations (in case of in-place
+    bufferization) and can directly reuse the existing destination buffer.
+
+    Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
+    where `%t` is the single input and `%d` is the single output. `%d` is tied
+    to `%r`.
+
+    Example of an op that is not in destination style: `%r = tensor.pad %t`.
+    This op is not in destination style because `%r` and `%t` have different
+    shape.
+
+    Each op that wants to implement DestinationStyleOpInterface needs to define
+    the getOutputsPositionRange() method.
+  }];
+
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    // This method has to be defined for every DPS op.
+    InterfaceMethod<
+      /*desc=*/"Return start and end indices of the output operands range.",
+      /*retTy=*/"std::pair<int64_t, int64_t>",
+      /*methodName=*/"getOutputsPositionRange",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/""
+    >,
+    //===------------------------------------------------------------------===//
+    // Operands handling.
+    //===------------------------------------------------------------------===//
+    // The operand list is assumed to start with the input operands and end
+    // with the output operands. Therefore, all methods to access the inputs
+    // and outputs can be expressed if the number of output operands is know.
+    InterfaceMethod<
+      /*desc=*/"Return the number of outputs.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getNumOutputs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+        return end - start;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the output operands.",
+      /*retTy=*/"OpOperandVector",
+      /*methodName=*/"getOutputOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+
+        OpOperandVector result;
+        result.reserve(end - start);
+        for (int i = start; i < end; ++i)
+          result.push_back(&$_op->getOpOperand(i));
+        return result;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the `i`-th output operand.",
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getOutputOperand",
+      /*args=*/(ins "int64_t":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i >= 0 && i < $_op.getNumOutputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        return &$_op->getOpOperand(start + i);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Set the `i`-th output operand.",
+      /*retTy=*/"void",
+      /*methodName=*/"setOutputOperand",
+      /*args=*/(ins "int64_t":$i, "Value":$value),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i >= 0 && i < $_op.getNumOutputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        $_op->setOperand(start + i, value);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the number of inputs.",
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getNumInputs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op.getNumOperands() - $_op.getNumOutputs();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the input operands.",
+      /*retTy=*/"OpOperandVector",
+      /*methodName=*/"getInputOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+        int64_t numOutputs = end - start;
+        int64_t numOperands = $_op.getNumOperands();
+
+        OpOperandVector result;
+        result.reserve(numOperands - numOutputs);
+        for (int i = 0; i < start; ++i)
+          result.push_back(&$_op->getOpOperand(i));
+        for (int i = end; i < numOperands; ++i)
+          result.push_back(&$_op->getOpOperand(end + i));
+
+        return result;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{ Return the `i`-th input operand.  }],
+      /*retTy=*/"OpOperand *",
+      /*methodName=*/"getInputOperand",
+      /*args=*/(ins "int64_t":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(i >= 0 && i < getNumInputs());
+        auto [start, end] = $_op.getOutputsPositionRange();
+        return &$_op->getOpOperand(i < start ? i : i + end - start) ;
+      }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Input and Output arguments handling.
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/"Return true if `opOperand` is an input.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isInput",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+        auto operandNumber = opOperand->getOperandNumber();
+        return operandNumber < start || operandNumber >= end;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return true if `opOperand` is an output.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isOutput",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto [start, end] = $_op.getOutputsPositionRange();
+        auto operandNumber = opOperand->getOperandNumber();
+        return operandNumber >= start && operandNumber < end;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return true if the `opOperand` is a scalar value.",
+      /*retTy=*/"bool",
+      /*methodName=*/"isScalar",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        return !opOperand->get().getType().template isa<ShapedType>();
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return the result tied to `opOperand`.",
+      /*retTy=*/"OpResult",
+      /*methodName=*/"getTiedOpResult",
+      /*args=*/(ins "OpOperand *":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+
+        auto [start, end] = $_op.getOutputsPositionRange();
+        int64_t resultIndex = opOperand->getOperandNumber() - start;
+        assert(resultIndex >= 0 &&
+               resultIndex < $_op->getNumResults() );
+        return $_op->getResult(resultIndex);
+      }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Other interface methods.
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/"Return whether the op has only MemRef input and outputs.",
+      /*retTy=*/"bool",
+      /*methodName=*/"hasBufferSemantics",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return $_op->getNumResults() == 0 &&
+          llvm::all_of($_op->getOpOperands(),
+            [&](OpOperand &opOperand) {
+              return isScalar(&opOperand) ||
+                     opOperand.get().getType().template isa<MemRefType>();
+            });
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/"Return whether the op has only RankedTensor input and outputs.",
+      /*retTy=*/"bool",
+      /*methodName=*/"hasTensorSemantics",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::all_of($_op->getOpOperands(),
+          [&](OpOperand &opOperand) {
+            return isScalar(&opOperand) ||
+                   opOperand.get().getType().template isa<RankedTensorType>();
+          });
+      }]
+    >,
+    //===------------------------------------------------------------------===//
+    // Other static interface methods.
+    //===------------------------------------------------------------------===//
+    InterfaceMethod<
+      /*desc=*/[{
+        Clone the current operation with the given location and operands. This
+        is used to abstract away the optional underlying region creation. This
+        does not change the balance between input, output_buffer and
+        init_tensors operands.
+      }],
+      /*retTy=*/"Operation *",
+      /*methodName=*/"clone",
+      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands),
+      [{
+        BlockAndValueMapping bvm;
+        OperationState state(
+          loc, ConcreteOp::getOperationName(), operands, resultTypes,
+          $_op->getAttrs());
+        for (Region &r : $_op->getRegions())
+          r.cloneInto(state.addRegion(), bvm);
+        return b.create(state);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Clone the current operation with the given location, operands
+        and BlockAndValueMapping but leave the regions empty. This is
+        used to abstract away the optional underlying region creation.
+        This does not change the balance between input, output_buffer
+        and init_tensors operands.
+      }],
+      /*retTy=*/"Operation *",
+      /*methodName=*/"cloneWithoutRegions",
+      (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands),
+      [{
+        OperationState state(
+          loc, ConcreteOp::getOperationName(), operands, resultTypes,
+          $_op->getAttrs());
+        for (size_t cnt = 0, e = $_op->getNumRegions(); cnt < e; ++cnt)
+          state.addRegion();
+        return b.create(state);
+      }]
+    >
+  ];
+
+  let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
+  let verifyWithRegions = 1;
+}
+
+
+#endif // MLIR_DESTINATIONSTYLEOPINTERFACE
index 247eceb..85412db 100644 (file)
@@ -18,6 +18,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   MLIRArithDialect
   MLIRArithUtils
   MLIRBufferizationDialect
+  MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR
index 88fd71c..78e8490 100644 (file)
@@ -462,14 +462,6 @@ LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-OpOperandVector::operator SmallVector<Value>() {
-  SmallVector<Value> result;
-  result.reserve(this->size());
-  llvm::transform(*this, std::back_inserter(result),
-                  [](OpOperand *opOperand) { return opOperand->get(); });
-  return result;
-}
-
 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
 /// the type of `source`.
 static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
@@ -770,55 +762,3 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
 
   return success();
 }
-
-LogicalResult
-mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
-  DestinationStyleOpInterface dstStyleOp =
-      cast<DestinationStyleOpInterface>(op);
-
-  SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
-  for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
-    Type type = operand->get().getType();
-    if (type.isa<MemRefType>())
-      outputBufferOperands.push_back(operand);
-    if (type.isa<RankedTensorType>())
-      outputTensorOperands.push_back(operand);
-  }
-
-  // Expect at least one output operand.
-  // This means an op that constructs a tensor out of indices cannot be a
-  // LinalgOp at the moment. For now this will have to be a special op until we
-  // have output shape operands that are not tensors.
-  int64_t numInputs = dstStyleOp.getNumInputs();
-  int64_t numOutputs = dstStyleOp.getNumOutputs();
-  if (numOutputs == 0)
-    return op->emitOpError("expected at least one output operand");
-  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
-    return failure();
-  // Verify the number of results matches the number of output tensors.
-  if (op->getNumResults() != outputTensorOperands.size())
-    return op->emitOpError("expected the number of results (")
-           << op->getNumResults()
-           << ") to be equal to the number of output tensors ("
-           << outputTensorOperands.size() << ")";
-
-  // Simplifying assumption: either full tensor or full buffer mode.
-  // This allows simpler verification of output operands vs result types
-  // without premature tracking of which operand is what in mixed-mode.
-  // TODO: relax when mixed-mode needs to pass verification.
-  if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
-    return op->emitOpError(
-        "expected output operands to all have tensor type or "
-        "all have buffer type");
-
-  for (OpOperand *opOperand : outputTensorOperands) {
-    OpResult result = dstStyleOp.getTiedOpResult(opOperand);
-    if (result.getType() != opOperand->get().getType())
-      return op->emitOpError("expected type of operand #")
-             << opOperand->getOperandNumber() << " ("
-             << opOperand->get().getType() << ")"
-             << " to match type of corresponding result (" << result.getType()
-             << ")";
-  }
-  return success();
-}
index bb38004..ca17a21 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
 
 using namespace mlir;
 using namespace linalg;
@@ -115,7 +116,7 @@ struct LinalgOpInterface
 
   SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                             const AnalysisState &state) const {
-    auto genericOp = cast<linalg::DestinationStyleOpInterface>(op);
+    auto genericOp = cast<DestinationStyleOpInterface>(op);
 
     // The i-th "out" tensor may alias with the i-th OpResult.
     if (genericOp.isOutput(&opOperand))
index 24048c2..c69f55c 100644 (file)
@@ -43,6 +43,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRComplexDialect
+  MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRFuncDialect
   MLIRFuncToLLVM
index 83b2fab..fb4958a 100644 (file)
@@ -5,6 +5,7 @@ set(LLVM_OPTIONAL_SOURCES
   CopyOpInterface.cpp
   DataLayoutInterfaces.cpp
   DerivedAttributeOpInterface.cpp
+  DestinationStyleOpInterface.cpp
   InferIntRangeInterface.cpp
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
@@ -38,6 +39,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(CopyOpInterface)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
+add_mlir_interface_library(DestinationStyleOpInterface)
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
diff --git a/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp b/mlir/lib/Interfaces/DestinationStyleOpInterface.cpp
new file mode 100644 (file)
index 0000000..104b0fa
--- /dev/null
@@ -0,0 +1,71 @@
+//===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
+} // namespace mlir
+
+OpOperandVector::operator SmallVector<Value>() {
+  SmallVector<Value> result;
+  result.reserve(this->size());
+  llvm::transform(*this, std::back_inserter(result),
+                  [](OpOperand *opOperand) { return opOperand->get(); });
+  return result;
+}
+
+LogicalResult detail::verifyDestinationStyleOpInterface(Operation *op) {
+  DestinationStyleOpInterface dstStyleOp =
+      cast<DestinationStyleOpInterface>(op);
+
+  SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
+  for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
+    Type type = operand->get().getType();
+    if (type.isa<MemRefType>())
+      outputBufferOperands.push_back(operand);
+    if (type.isa<RankedTensorType>())
+      outputTensorOperands.push_back(operand);
+  }
+
+  // Expect at least one output operand.
+  int64_t numInputs = dstStyleOp.getNumInputs();
+  int64_t numOutputs = dstStyleOp.getNumOutputs();
+  if (numOutputs == 0)
+    return op->emitOpError("expected at least one output operand");
+  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
+    return failure();
+  // Verify the number of results matches the number of output tensors.
+  if (op->getNumResults() != outputTensorOperands.size())
+    return op->emitOpError("expected the number of results (")
+           << op->getNumResults()
+           << ") to be equal to the number of output tensors ("
+           << outputTensorOperands.size() << ")";
+
+  // Simplifying assumption: either full tensor or full buffer mode.
+  // This allows simpler verification of output operands vs result types
+  // without premature tracking of which operand is what in mixed-mode.
+  // TODO: relax when mixed-mode needs to pass verification.
+  if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
+    return op->emitOpError(
+        "expected output operands to all have tensor type or "
+        "all have buffer type");
+
+  for (OpOperand *opOperand : outputTensorOperands) {
+    OpResult result = dstStyleOp.getTiedOpResult(opOperand);
+    if (result.getType() != opOperand->get().getType())
+      return op->emitOpError("expected type of operand #")
+             << opOperand->getOperandNumber() << " ("
+             << opOperand->get().getType() << ")"
+             << " to match type of corresponding result (" << result.getType()
+             << ")";
+  }
+  return success();
+}
index 141d618..2c8719d 100644 (file)
@@ -54,6 +54,7 @@ add_mlir_library(MLIRTestDialect
   MLIRControlFlowInterfaces
   MLIRDataLayoutInterfaces
   MLIRDerivedAttributeOpInterface
+  MLIRDestinationStyleOpInterface
   MLIRDialect
   MLIRDLTIDialect
   MLIRFuncDialect
index 9888923..e4206f9 100644 (file)
@@ -23,6 +23,7 @@ include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
index 3666aa8..024ba9c 100644 (file)
@@ -996,6 +996,13 @@ td_library(
 )
 
 td_library(
+    name = "DestinationStyleOpInterfaceTdFiles",
+    srcs = ["include/mlir/Interfaces/DestinationStyleOpInterface.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],
+)
+
+td_library(
     name = "InferIntRangeInterfaceTdFiles",
     srcs = ["include/mlir/Interfaces/InferIntRangeInterface.td"],
     includes = ["include"],
@@ -5322,6 +5329,36 @@ cc_library(
 )
 
 gentbl_cc_library(
+    name = "DestinationStyleOpInterfaceIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Interfaces/DestinationStyleOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Interfaces/DestinationStyleOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Interfaces/DestinationStyleOpInterface.td",
+    deps = [":DestinationStyleOpInterfaceTdFiles"],
+)
+
+cc_library(
+    name = "DestinationStyleOpInterface",
+    srcs = ["lib/Interfaces/DestinationStyleOpInterface.cpp"],
+    hdrs = ["include/mlir/Interfaces/DestinationStyleOpInterface.h"],
+    includes = ["include"],
+    deps = [
+        ":DestinationStyleOpInterfaceIncGen",
+        ":IR",
+        "//llvm:Support",
+    ],
+)
+
+gentbl_cc_library(
     name = "InferIntRangeInterfaceIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
@@ -7437,6 +7474,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
+        ":DestinationStyleOpInterfaceTdFiles",
         ":DialectUtilsTdFiles",
         ":InferTypeOpInterfaceTdFiles",
         ":LoopLikeInterfaceTdFiles",
@@ -7571,6 +7609,7 @@ td_library(
     includes = ["include"],
     deps = [
         ":CopyOpInterfaceTdFiles",
+        ":DestinationStyleOpInterface",
         ":LinalgOpsTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
@@ -7768,6 +7807,7 @@ cc_library(
         ":ComplexDialect",
         ":ControlFlowInterfaces",
         ":CopyOpInterface",
+        ":DestinationStyleOpInterface",
         ":DialectUtils",
         ":FuncDialect",
         ":IR",
@@ -7925,6 +7965,7 @@ cc_library(
         ":BufferizationTransforms",
         ":ComplexDialect",
         ":ControlFlowDialect",
+        ":DestinationStyleOpInterface",
         ":DialectUtils",
         ":FuncDialect",
         ":FuncTransforms",
index 02ae5f1..c11dee8 100644 (file)
@@ -94,6 +94,7 @@ td_library(
         "//mlir:CopyOpInterfaceTdFiles",
         "//mlir:DLTIDialectTdFiles",
         "//mlir:DataLayoutInterfacesTdFiles",
+        "//mlir:DestinationStyleOpInterfaceTdFiles",
         "//mlir:InferIntRangeInterfaceTdFiles",
         "//mlir:InferTypeOpInterfaceTdFiles",
         "//mlir:LinalgStructuredOpsTdFiles",
@@ -325,6 +326,7 @@ cc_library(
         "//mlir:DLTIDialect",
         "//mlir:DataLayoutInterfaces",
         "//mlir:DerivedAttributeOpInterface",
+        "//mlir:DestinationStyleOpInterface",
         "//mlir:Dialect",
         "//mlir:FuncDialect",
         "//mlir:FuncTransforms",