[mlir][Linalg] NFC - Add result and bbArg pretty printing to linalg.reduce
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 4 Oct 2022 12:34:12 +0000 (05:34 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 4 Oct 2022 16:27:18 +0000 (09:27 -0700)
Differential Revision: https://reviews.llvm.org/D135152

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

index d8c1e0b..268587e 100644 (file)
@@ -340,6 +340,24 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     >,
     InterfaceMethod<
       /*desc=*/[{
+        Return the input block arguments of the region.
+      }],
+      /*retTy=*/"Block::BlockArgListType",
+      /*methodName=*/"getRegionInputArgs",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // MLIR currently does not support dependent interfaces or interface
+        // inheritance. By construction all ops with StructuredOpInterface must
+        // implement DestinationStyleOpInterface.
+        // TODO: reevalute the need for a cast when a better mechanism exists.
+        return getBlock()->getArguments().take_front(
+            cast<DestinationStyleOpInterface>(*this->getOperation())
+                .getNumInputs());
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
         Return the output block arguments of the region.
       }],
       /*retTy=*/"Block::BlockArgListType",
index 6234d33..3d1ee2f 100644 (file)
@@ -19,6 +19,7 @@ include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
 
 // Base Tablegen class for Linalg ops.
 // Linalg ops that correspond to library calls operate on ShapedType as their
@@ -229,8 +230,10 @@ def TensorOrMemref :
   AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
 
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
-      SameVariadicOperandSize, SingleBlockImplicitTerminator<"YieldOp">
-    ]> {
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+    SameVariadicOperandSize,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Reduce operator";
   let description = [{
     Executes `combiner` on the `dimensions` of `inputs` and returns the
index 6289369..3741e7d 100644 (file)
@@ -1187,6 +1187,19 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
 // ReduceOp
 //===----------------------------------------------------------------------===//
 
+void ReduceOp::getAsmBlockArgumentNames(Region &region,
+                                        OpAsmSetValueNameFn setNameFn) {
+  for (Value v : getRegionInputArgs())
+    setNameFn(v, "in");
+  for (Value v : getRegionOutputArgs())
+    setNameFn(v, "init");
+}
+
+void ReduceOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResults().front(), "reduced");
+}
+
 ArrayAttr ReduceOp::getIteratorTypes() {
   int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
   SmallVector<StringRef> iteratorTypes(inputRank,