[mlir] Fix canonicalization of tiled_loop if not all opresults fold.
authorAlexander Belyaev <pifon@google.com>
Wed, 28 Apr 2021 17:48:27 +0000 (19:48 +0200)
committerAlexander Belyaev <pifon@google.com>
Wed, 28 Apr 2021 17:57:48 +0000 (19:57 +0200)
The current canonicalization did not remap operation results correctly
and attempted to erase tiledLoop, which is incorrect if not all tensor
results are folded.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/canonicalize.mlir

index 17ecab1..750d818 100644 (file)
@@ -2094,7 +2094,7 @@ namespace {
 
 static constexpr int64_t kNoMatch = -1;
 
-// Folds away TiledLoopOp input tensors if they have no uses within the body.
+// Folds away TiledLoopOp inputs if they have no uses within the body.
 //
 // Example:
 //
@@ -2117,7 +2117,7 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
       Value in, bbArg;
       size_t index = en.index();
       std::tie(in, bbArg) = en.value();
-      if (!in.getType().isa<RankedTensorType>() || !bbArg.use_empty()) {
+      if (!bbArg.use_empty()) {
         oldInputIdToNew[index] = newInputs.size();
         newInputs.push_back(in);
       }
@@ -2142,7 +2142,7 @@ struct TiledLoopInputsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : *tiledLoop.getBody())
       innerBuilder.clone(op, bvm);
-    rewriter.eraseOp(tiledLoop);
+    rewriter.replaceOp(tiledLoop, newTiledLoop.getResults());
 
     return success();
   }
@@ -2184,6 +2184,10 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
     // Store ids of the corresponding old and new output operands.
     SmallVector<int64_t, 2> oldOutputIdToNew(tiledLoop.outputs().size(),
                                              kNoMatch);
+    // Store ids of the corresponding old and new results.
+    SmallVector<int64_t, 2> oldResultIdToNew(tiledLoop.getNumResults(),
+                                             kNoMatch);
+    SmallVector<Value, 2> resultReplacement(tiledLoop.getNumResults());
     for (auto en : llvm::enumerate(
              llvm::zip(tiledLoop.outputs(), tiledLoop.getRegionOutputArgs()))) {
       size_t index = en.index();
@@ -2199,6 +2203,8 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
       Value yieldArg = yieldOp.getOperand(resultId);
       if (yieldArg != outRegionArg || !result.use_empty()) {
         oldOutputIdToNew[index] = newOutputOperands.size();
+        oldResultIdToNew[resultId] = newYieldArgs.size();
+        resultReplacement[resultId] = out;
         newOutputOperands.push_back(out);
         newYieldArgs.push_back(yieldArg);
       }
@@ -2228,8 +2234,14 @@ struct TiledLoopResultsFolder : public OpRewritePattern<linalg::TiledLoopOp> {
         OpBuilder::atBlockEnd(newTiledLoop.getBody(), rewriter.getListener());
     for (auto &op : tiledLoop.getBody()->without_terminator())
       innerBuilder.clone(op, bvm);
-    innerBuilder.create<linalg::YieldOp>(loc, newYieldArgs);
-    rewriter.eraseOp(tiledLoop);
+    innerBuilder.create<linalg::YieldOp>(
+        loc, llvm::to_vector<2>(llvm::map_range(
+                 newYieldArgs, [&](Value arg) { return bvm.lookup(arg); })));
+
+    for (const auto &en : llvm::enumerate(oldResultIdToNew))
+      if (en.value() != kNoMatch)
+        resultReplacement[en.index()] = newTiledLoop.getResult(en.value());
+    rewriter.replaceOp(tiledLoop, resultReplacement);
 
     return success();
   }
index e66ee38..244b78f 100644 (file)
@@ -867,75 +867,71 @@ func @fold_fill_reshape_dynamic(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?xf32
   return %1 : tensor<?x?xf32>
 }
 
-// -----
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+// -----
 
-func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                  %C: memref<192x192xf32>) -> ()
+func private @foo(%A: memref<48xf32>, %B: tensor<48xf32>,
+                  %C: memref<48xf32>) -> (tensor<48xf32>)
 
-func @fold_tiled_loop_results(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
-                              %C: memref<192x192xf32>,
-                              %C_tensor: tensor<192x192xf32>) {
-  %cst = constant 0.000000e+00 : f32
-  %c24 = constant 24 : index
-  %c16 = constant 16 : index
+func @fold_tiled_loop_results(%A: memref<48xf32>, %B: tensor<48xf32>,
+    %C: memref<48xf32>, %C_tensor: tensor<48xf32>) -> tensor<48xf32> {
   %c0 = constant 0 : index
-  %c192 = constant 192 : index
-  %useless = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
-      step (%c24, %c16)
-      ins (%A_ = %A: memref<192x192xf32>, %B_ = %B: memref<192x192xf32>)
-      outs (%CT_ = %C_tensor: tensor<192x192xf32>,
-            %C_ = %C: memref<192x192xf32>) {
-        call @foo(%A_, %B_, %C_)
-          : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>)-> ()
-    linalg.yield %CT_ : tensor<192x192xf32>
+  %c24 = constant 24 : index
+  %c48 = constant 48 : index
+  %useful, %useless = linalg.tiled_loop (%i) = (%c0) to (%c48) step (%c24)
+      ins (%A_ = %A: memref<48xf32>)
+      outs (%B_ = %B: tensor<48xf32>,
+            %CT_ = %C_tensor: tensor<48xf32>,
+            %C_ = %C: memref<48xf32>) {
+        %result = call @foo(%A_, %B_, %C_)
+          : (memref<48xf32>, tensor<48xf32>, memref<48xf32>)-> (tensor<48xf32>)
+    linalg.yield %result, %CT_ : tensor<48xf32>, tensor<48xf32>
   }
-  return
+  return %useful : tensor<48xf32>
 }
 
 // CHECK-LABEL: func @fold_tiled_loop_results(
-// CHECK-SAME:    %[[A:.*]]: [[TY:.*]], %[[B:.*]]: [[TY]], %[[C:.*]]: [[TY]],
-// CHECK-SAME:    %[[C_TENSOR:.*]]: tensor<{{.*}}>) {
-// CHECK:  %[[C24:.*]] = constant 24 : index
-// CHECK:  %[[C16:.*]] = constant 16 : index
+// CHECK-SAME:   %[[A:.*]]: [[BUF_TY:memref<48xf32>]], %[[B:.*]]: [[TY:tensor<48xf32>]],
+// CHECK-SAME:   %[[C:.*]]: [[BUF_TY]],  %[[C_TENSOR:.*]]: [[TY]]) -> [[TY]] {
+
 // CHECK:  %[[C0:.*]] = constant 0 : index
-// CHECK:  %[[C192:.*]] = constant 192 : index
+// CHECK:  %[[C24:.*]] = constant 24 : index
+// CHECK:  %[[C48:.*]] = constant 48 : index
 
 // CHECK-NOT: %{{.*}} = linalg.tiled_loop
-// CHECK:  linalg.tiled_loop (%{{.*}}, %{{.*}}) = (%[[C0]], %[[C0]])
-// CHECK-SAME: to (%[[C192]], %[[C192]]) step (%[[C24]], %[[C16]])
-// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: memref<192x192xf32>, %[[B_:.*]] = %[[B]]: memref<192x192xf32>)
-// CHECK-SAME: outs (%[[C_:.*]] = %[[C]]: memref<192x192xf32>) {
-// CHECK-NEXT:   call @foo(%[[A_]], %[[B_]], %[[C_]])
-// CHECK-NEXT:   linalg.yield
+// CHECK:  %[[RESULT:.*]] = linalg.tiled_loop (%{{.*}}) = (%[[C0]])
+// CHECK-SAME: to (%[[C48]]) step (%[[C24]])
+// CHECK-SAME: ins (%[[A_:.*]] = %[[A]]: [[BUF_TY]])
+// CHECK-SAME: outs (%[[B_:.*]] = %[[B]]: [[TY]], %[[C_:.*]] = %[[C]]: [[BUF_TY]]) {
+// CHECK-NEXT:   %[[RES:.*]] = call @foo(%[[A_]], %[[B_]], %[[C_]])
+// CHECK-NEXT:   linalg.yield %[[RES]] :
 
-// -----
+// CHECK: return %[[RESULT]]
 
-#map0 = affine_map<(d0) -> (24, -d0 + 192)>
-#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
-#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+// -----
 
-func private @foo(%A: memref<192xf32>) -> ()
+func private @foo(%A: memref<192xf32>, %B: tensor<192xf32>) -> tensor<192xf32>
 
-func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>) {
+func @fold_tiled_loop_inputs(%A: memref<192xf32>, %A_tensor: tensor<192xf32>,
+                             %B_tensor: tensor<192xf32>) -> tensor<192xf32> {
   %c0 = constant 0 : index
   %c24 = constant 24 : index
   %c192 = constant 192 : index
-  linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
-      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>) {
-        call @foo(%A_) : (memref<192xf32>)-> ()
-    linalg.yield
+  %result = linalg.tiled_loop (%i) = (%c0) to (%c192) step (%c24)
+      ins (%A_ = %A: memref<192xf32>, %AT_ = %A_tensor: tensor<192xf32>)
+      outs (%BT_ = %B_tensor: tensor<192xf32>) {
+    %0 = call @foo(%A_, %BT_) : (memref<192xf32>, tensor<192xf32>) -> tensor<192xf32>
+    linalg.yield %0 : tensor<192xf32>
   }
-  return
+  return %result : tensor<192xf32>
 }
 
 // CHECK-LABEL: func @fold_tiled_loop_inputs
-// CHECK: linalg.tiled_loop
+// CHECK: %[[RESULT:.*]] = linalg.tiled_loop
 // CHECK-SAME: ins (%{{.*}} = %{{.*}}: memref<192xf32>)
 
+// CHECK: return %[[RESULT]]
+
 // -----
 
 func @dim_of_pad_op(%arg0 : tensor<2x?x?xf32>, %arg1 : index, %arg2 : index,