[mlir][linalg] Reuploading: Apply shortened printing/parsing form to linalg.reduce.
authorAliia Khasanova <aliia@google.com>
Mon, 9 Jan 2023 07:25:03 +0000 (08:25 +0100)
committerAliia Khasanova <aliia@google.com>
Mon, 9 Jan 2023 12:32:29 +0000 (13:32 +0100)
Differential Revision: https://reviews.llvm.org/D141259

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index 5456ca1..dd2a943 100644 (file)
@@ -255,6 +255,16 @@ def MapOp : LinalgStructuredBase_Op<"map", [
             linalg.yield %0: f32
           }
     ```
+
+    Shortened print form is available. Applies to simple maps with one 
+    non-yield operation inside the body.
+
+    The example above will be printed as:
+    ```
+      %add = linalg.map { arith.addf }
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init: tensor<64xf32>)
+    ```
   }];
 
   let arguments = (ins
@@ -329,10 +339,22 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
           outs(%init:tensor<16x64xf32>)
           dimensions = [1]
           (%in: f32, %out: f32) {
-            %0 = arith.addf %in, %out: f32
+            %0 = arith.addf %out, %in: f32
             linalg.yield %0: f32
           }
     ```
+
+    Shortened print form is available. Applies to simple (not variadic) reduces
+    with one non-yield operation inside the body. Applies only if the operation
+    takes `%out` as the first argument.
+
+    The example above will be printed as:
+    ```
+          %reduce = linalg.reduce { arith.addf }
+          ins(%input:tensor<16x32x64xf32>)
+          outs(%init:tensor<16x64xf32>)
+          dimensions = [1]
+    ```
   }];
 
   let arguments = (ins
index 48e7cfd..33f49c9 100644 (file)
@@ -1046,7 +1046,8 @@ void MapOp::build(
 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                  const OperationName &payloadOpName,
                                  const NamedAttrList &payloadOpAttrs,
-                                 ArrayRef<Value> operands) {
+                                 ArrayRef<Value> operands,
+                                 bool initFirst = false) {
   OpBuilder b(parser.getContext());
   Region *body = result.addRegion();
   Block &block = body->emplaceBlock();
@@ -1056,14 +1057,24 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
     block.addArgument(operand.getType().cast<ShapedType>().getElementType(),
                       b.getUnknownLoc());
   }
+  SmallVector<Value> payloadOpOperands;
+  // If initFirst flag is enabled, we consider init as the first position of
+  // payload operands.
+  if (initFirst) {
+    payloadOpOperands.push_back(block.getArguments().back());
+    for (const auto &arg : block.getArguments().drop_back())
+      payloadOpOperands.push_back(arg);
+  } else {
+    payloadOpOperands = {block.getArguments().begin(),
+                         block.getArguments().end()};
+  }
 
   Operation *payloadOp = b.create(
       result.location, b.getStringAttr(payloadOpName.getStringRef()),
-      block.getArguments(),
+      payloadOpOperands,
       TypeRange{
           result.operands.back().getType().cast<ShapedType>().getElementType()},
       payloadOpAttrs);
-
   b.create<YieldOp>(result.location, payloadOp->getResults());
 }
 
@@ -1102,7 +1113,9 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
 
 // Retrieve the operation from the body, if it is the only one (except
 // yield) and if it gets the same amount of arguments as the body does.
-static Operation *findPayloadOp(Block *body) {
+// If initFirst flag is enabled, we check that init takes the first position in
+// operands of payload.
+static Operation *findPayloadOp(Block *body, bool initFirst = false) {
   if (body->getOperations().size() != 2)
     return nullptr;
   Operation &payload = body->getOperations().front();
@@ -1111,10 +1124,22 @@ static Operation *findPayloadOp(Block *body) {
   if (payload.getNumOperands() == 0 ||
       payload.getNumOperands() != body->getNumArguments())
     return nullptr;
-  for (const auto &[bbArg, operand] :
-       llvm::zip(payload.getOperands(), body->getArguments())) {
-    if (bbArg != operand)
+  if (initFirst) {
+    // check init
+    if (payload.getOperands().back() != body->getArgument(0))
       return nullptr;
+    // check rest
+    for (const auto &[operand, bbArg] :
+         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
+      if (bbArg != operand)
+        return nullptr;
+    }
+  } else {
+    for (const auto &[operand, bbArg] :
+         llvm::zip(payload.getOperands(), body->getArguments())) {
+      if (bbArg != operand)
+        return nullptr;
+    }
   }
   return &payload;
 }
@@ -1313,7 +1338,7 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
 
   if (payloadOpName.has_value()) {
     addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
-                         makeArrayRef(result.operands));
+                         makeArrayRef(result.operands), /*initFirst=*/true);
   } else {
     SmallVector<OpAsmParser::Argument> regionArgs;
     if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1336,7 +1361,7 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
 
 void ReduceOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
+  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
   if (payloadOp) {
     printShortForm(p, payloadOp);
   }
index 87763c9..7795b63 100644 (file)
@@ -363,7 +363,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
       outs(%init:tensor<16x64xf32>)
       dimensions = [1]
       (%in: f32, %out: f32) {
-        %0 = arith.addf %in, %out: f32
+        %0 = arith.addf %out, %in: f32
         linalg.yield %0: f32
       }
   func.return %reduce : tensor<16x64xf32>
index 611d428..c665366 100644 (file)
@@ -414,7 +414,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
       outs(%init:tensor<16x64xf32>)
       dimensions = [1]
       (%in: f32, %out: f32) {
-        %0 = arith.addf %in, %out: f32
+        %0 = arith.addf %out, %in: f32
         linalg.yield %0: f32
       }
   func.return %reduce : tensor<16x64xf32>
@@ -433,7 +433,7 @@ func.func @reduce_memref(%input: memref<16x32x64xf32>,
       outs(%init:memref<16x64xf32>)
       dimensions = [1]
       (%in: f32, %out: f32) {
-        %0 = arith.addf %in, %out: f32
+        %0 = arith.addf %out, %in: f32
         linalg.yield %0: f32
       }
   func.return
@@ -587,7 +587,7 @@ func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
       outs(%init:tensor<16x64xf32>)
       dimensions = [1]
       (%in: f32, %out: f32) {
-        %0 = arith.addf %in, %out fastmath<fast> : f32
+        %0 = arith.addf %out, %in fastmath<fast> : f32
         linalg.yield %0: f32
       }
   func.return %reduce : tensor<16x64xf32>