[MLIR][SCF] Remove loop invariant arguments of scf.while
authorAbhishek Varma <abhishek.varma@cerebras.net>
Thu, 3 Feb 2022 16:09:51 +0000 (17:09 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 3 Feb 2022 16:13:25 +0000 (17:13 +0100)
-- This commit adds a canonicalization pattern on scf.while to remove
   the loop invariant arguments.
-- An argument is considered loop invariant if the iteration argument value is
   the same as the corresponding one being yielded (at the same position) in both
   the before/after block of scf.while.
-- For the arguments removed, their use within scf.while and their corresponding
   scf.while's result are replaced with their corresponding initial value.

Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com>
Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D116923

mlir/lib/Dialect/SCF/SCF.cpp
mlir/test/Dialect/SCF/canonicalize.mlir

index 28c50ba..0e2759f 100644 (file)
@@ -2343,6 +2343,297 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
   }
 };
 
+/// Remove loop invariant arguments from `before` block of scf.while.
+/// A before block argument is considered loop invariant if :-
+///   1. i-th yield operand is equal to the i-th while operand.
+///   2. i-th yield operand is k-th after block argument which is (k+1)-th
+///      condition operand AND this (k+1)-th condition operand is equal to i-th
+///      iter argument/while operand.
+/// For the arguments which are removed, their uses inside scf.while
+/// are replaced with their corresponding initial value.
+///
+/// Eg:
+///    INPUT :-
+///    %res = scf.while <...> iter_args(%arg0_before = %a, %arg1_before = %b,
+///                                     ..., %argN_before = %N)
+///           {
+///                ...
+///                scf.condition(%cond) %arg1_before, %arg0_before,
+///                                     %arg2_before, %arg0_before, ...
+///           } do {
+///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
+///                  ..., %argK_after):
+///                ...
+///                scf.yield %arg0_after_2, %b, %arg1_after, ..., %argN
+///           }
+///
+///    OUTPUT :-
+///    %res = scf.while <...> iter_args(%arg2_before = %c, ..., %argN_before =
+///                                     %N)
+///           {
+///                ...
+///                scf.condition(%cond) %b, %a, %arg2_before, %a, ...
+///           } do {
+///             ^bb0(%arg1_after, %arg0_after_1, %arg2_after, %arg0_after_2,
+///                  ..., %argK_after):
+///                ...
+///                scf.yield %arg1_after, ..., %argN
+///           }
+///
+///    EXPLANATION:
+///      We iterate over each yield operand.
+///        1. 0-th yield operand %arg0_after_2 is 4-th condition operand
+///           %arg0_before, which in turn is the 0-th iter argument. So we
+///           remove 0-th before block argument and yield operand, and replace
+///           all uses of the 0-th before block argument with its initial value
+///           %a.
+///        2. 1-th yield operand %b is equal to the 1-th iter arg's initial
+///           value. So we remove this operand and the corresponding before
+///           block argument and replace all uses of 1-th before block argument
+///           with %b.
+struct RemoveLoopInvariantArgsFromBeforeBlock
+    : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &afterBlock = op.getAfter().front();
+    Block::BlockArgListType beforeBlockArgs = op.getBeforeArguments();
+    ConditionOp condOp = op.getConditionOp();
+    OperandRange condOpArgs = condOp.getArgs();
+    Operation *yieldOp = afterBlock.getTerminator();
+    ValueRange yieldOpArgs = yieldOp->getOperands();
+
+    bool canSimplify = false;
+    for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
+      auto index = static_cast<unsigned>(it.index());
+      Value initVal, yieldOpArg;
+      std::tie(initVal, yieldOpArg) = it.value();
+      // If i-th yield operand is equal to the i-th operand of the scf.while,
+      // the i-th before block argument is a loop invariant.
+      if (yieldOpArg == initVal) {
+        canSimplify = true;
+        break;
+      }
+      // If the i-th yield operand is k-th after block argument, then we check
+      // if the (k+1)-th condition op operand is equal to either the i-th before
+      // block argument or the initial value of i-th before block argument. If
+      // the comparison results `true`, i-th before block argument is a loop
+      // invariant.
+      auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
+      if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
+        Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
+        if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
+          canSimplify = true;
+          break;
+        }
+      }
+    }
+
+    if (!canSimplify)
+      return failure();
+
+    SmallVector<Value> newInitArgs, newYieldOpArgs;
+    DenseMap<unsigned, Value> beforeBlockInitValMap;
+    SmallVector<Location> newBeforeBlockArgLocs;
+    for (auto it : llvm::enumerate(llvm::zip(op.getOperands(), yieldOpArgs))) {
+      auto index = static_cast<unsigned>(it.index());
+      Value initVal, yieldOpArg;
+      std::tie(initVal, yieldOpArg) = it.value();
+
+      // If i-th yield operand is equal to the i-th operand of the scf.while,
+      // the i-th before block argument is a loop invariant.
+      if (yieldOpArg == initVal) {
+        beforeBlockInitValMap.insert({index, initVal});
+        continue;
+      } else {
+        // If the i-th yield operand is k-th after block argument, then we check
+        // if the (k+1)-th condition op operand is equal to either the i-th
+        // before block argument or the initial value of i-th before block
+        // argument. If the comparison results `true`, i-th before block
+        // argument is a loop invariant.
+        auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
+        if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
+          Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
+          if (condOpArg == beforeBlockArgs[index] || condOpArg == initVal) {
+            beforeBlockInitValMap.insert({index, initVal});
+            continue;
+          }
+        }
+      }
+      newInitArgs.emplace_back(initVal);
+      newYieldOpArgs.emplace_back(yieldOpArg);
+      newBeforeBlockArgLocs.emplace_back(beforeBlockArgs[index].getLoc());
+    }
+
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(yieldOp);
+      rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs);
+    }
+
+    auto newWhile =
+        rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs);
+
+    Block &newBeforeBlock = *rewriter.createBlock(
+        &newWhile.getBefore(), /*insertPt*/ {},
+        ValueRange(newYieldOpArgs).getTypes(), newBeforeBlockArgLocs);
+
+    Block &beforeBlock = op.getBefore().front();
+    SmallVector<Value> newBeforeBlockArgs(beforeBlock.getNumArguments());
+    // For each i-th before block argument we find it's replacement value as :-
+    //   1. If i-th before block argument is a loop invariant, we fetch it's
+    //      initial value from `beforeBlockInitValMap` by querying for key `i`.
+    //   2. Else we fetch j-th new before block argument as the replacement
+    //      value of i-th before block argument.
+    for (unsigned i = 0, j = 0, n = beforeBlock.getNumArguments(); i < n; i++) {
+      // If the index 'i' argument was a loop invariant we fetch it's initial
+      // value from `beforeBlockInitValMap`.
+      if (beforeBlockInitValMap.count(i) != 0)
+        newBeforeBlockArgs[i] = beforeBlockInitValMap[i];
+      else
+        newBeforeBlockArgs[i] = newBeforeBlock.getArgument(j++);
+    }
+
+    rewriter.mergeBlocks(&beforeBlock, &newBeforeBlock, newBeforeBlockArgs);
+    rewriter.inlineRegionBefore(op.getAfter(), newWhile.getAfter(),
+                                newWhile.getAfter().begin());
+
+    rewriter.replaceOp(op, newWhile.getResults());
+    return success();
+  }
+};
+
+/// Remove loop invariant value from result (condition op) of scf.while.
+/// A value is considered loop invariant if the final value yielded by
+/// scf.condition is defined outside of the `before` block. We remove the
+/// corresponding argument in `after` block and replace the use with the value.
+/// We also replace the use of the corresponding result of scf.while with the
+/// value.
+///
+/// Eg:
+///    INPUT :-
+///    %res_input:K = scf.while <...> iter_args(%arg0_before = , ...,
+///                                             %argN_before = %N) {
+///                ...
+///                scf.condition(%cond) %arg0_before, %a, %b, %arg1_before, ...
+///           } do {
+///             ^bb0(%arg0_after, %arg1_after, %arg2_after, ..., %argK_after):
+///                ...
+///                some_func(%arg1_after)
+///                ...
+///                scf.yield %arg0_after, %arg2_after, ..., %argN_after
+///           }
+///
+///    OUTPUT :-
+///    %res_output:M = scf.while <...> iter_args(%arg0 = , ..., %argN = %N) {
+///                ...
+///                scf.condition(%cond) %arg0, %arg1, ..., %argM
+///           } do {
+///             ^bb0(%arg0, %arg3, ..., %argM):
+///                ...
+///                some_func(%a)
+///                ...
+///                scf.yield %arg0, %b, ..., %argN
+///           }
+///
+///     EXPLANATION:
+///       1. The 1-th and 2-th operand of scf.condition are defined outside the
+///          before block of scf.while, so they get removed.
+///       2. %res_input#1's uses are replaced by %a and %res_input#2's uses are
+///          replaced by %b.
+///       3. The corresponding after block argument %arg1_after's uses are
+///          replaced by %a and %arg2_after's uses are replaced by %b.
+struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> {
+  using OpRewritePattern<WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WhileOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &beforeBlock = op.getBefore().front();
+    ConditionOp condOp = op.getConditionOp();
+    OperandRange condOpArgs = condOp.getArgs();
+
+    bool canSimplify = false;
+    for (Value condOpArg : condOpArgs) {
+      // Those values not defined within `before` block will be considered as
+      // loop invariant values. We map the corresponding `index` with their
+      // value.
+      if (condOpArg.getParentBlock() != &beforeBlock) {
+        canSimplify = true;
+        break;
+      }
+    }
+
+    if (!canSimplify)
+      return failure();
+
+    Block::BlockArgListType afterBlockArgs = op.getAfterArguments();
+
+    SmallVector<Value> newCondOpArgs;
+    SmallVector<Type> newAfterBlockType;
+    DenseMap<unsigned, Value> condOpInitValMap;
+    SmallVector<Location> newAfterBlockArgLocs;
+    for (auto it : llvm::enumerate(condOpArgs)) {
+      auto index = static_cast<unsigned>(it.index());
+      Value condOpArg = it.value();
+      // Those values not defined within `before` block will be considered as
+      // loop invariant values. We map the corresponding `index` with their
+      // value.
+      if (condOpArg.getParentBlock() != &beforeBlock) {
+        condOpInitValMap.insert({index, condOpArg});
+      } else {
+        newCondOpArgs.emplace_back(condOpArg);
+        newAfterBlockType.emplace_back(condOpArg.getType());
+        newAfterBlockArgLocs.emplace_back(afterBlockArgs[index].getLoc());
+      }
+    }
+
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPoint(condOp);
+      rewriter.replaceOpWithNewOp<ConditionOp>(condOp, condOp.getCondition(),
+                                               newCondOpArgs);
+    }
+
+    auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType,
+                                             op.getOperands());
+
+    Block &newAfterBlock =
+        *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {},
+                              newAfterBlockType, newAfterBlockArgLocs);
+
+    Block &afterBlock = op.getAfter().front();
+    // Since a new scf.condition op was created, we need to fetch the new
+    // `after` block arguments which will be used while replacing operations of
+    // previous scf.while's `after` blocks. We'd also be fetching new result
+    // values too.
+    SmallVector<Value> newAfterBlockArgs(afterBlock.getNumArguments());
+    SmallVector<Value> newWhileResults(afterBlock.getNumArguments());
+    for (unsigned i = 0, j = 0, n = afterBlock.getNumArguments(); i < n; i++) {
+      Value afterBlockArg, result;
+      // If index 'i' argument was loop invariant we fetch it's value from the
+      // `condOpInitMap` map.
+      if (condOpInitValMap.count(i) != 0) {
+        afterBlockArg = condOpInitValMap[i];
+        result = afterBlockArg;
+      } else {
+        afterBlockArg = newAfterBlock.getArgument(j);
+        result = newWhile.getResult(j);
+        j++;
+      }
+      newAfterBlockArgs[i] = afterBlockArg;
+      newWhileResults[i] = result;
+    }
+
+    rewriter.mergeBlocks(&afterBlock, &newAfterBlock, newAfterBlockArgs);
+    rewriter.inlineRegionBefore(op.getBefore(), newWhile.getBefore(),
+                                newWhile.getBefore().begin());
+
+    rewriter.replaceOp(op, newWhileResults);
+    return success();
+  }
+};
+
 /// Remove WhileOp results that are also unused in 'after' block.
 ///
 ///  %0:2 = scf.while () : () -> (i32, i64) {
@@ -2552,8 +2843,9 @@ struct WhileUnusedArg : public OpRewritePattern<WhileOp> {
 
 void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.insert<WhileConditionTruth, WhileUnusedResult, WhileCmpCond,
-                 WhileUnusedArg>(context);
+  results.insert<RemoveLoopInvariantArgsFromBeforeBlock,
+                 RemoveLoopInvariantValueYielded, WhileConditionTruth,
+                 WhileCmpCond, WhileUnusedResult>(context);
 }
 
 //===----------------------------------------------------------------------===//
index a9ad591..1563349 100644 (file)
@@ -870,6 +870,74 @@ func @while_unused_arg(%x : i32, %y : f64) -> i32 {
 
 // -----
 
+// CHECK-LABEL: @invariant_loop_args_in_same_order
+// CHECK-SAME: (%[[FUNC_ARG0:.*]]: tensor<i32>)
+func @invariant_loop_args_in_same_order(%f_arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+  %cst_0 = arith.constant dense<0> : tensor<i32>
+  %cst_1 = arith.constant dense<1> : tensor<i32>
+  %cst_42 = arith.constant dense<42> : tensor<i32>
+
+  %0:5 = scf.while (%arg0 = %cst_0, %arg1 = %f_arg0, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+    %2 = tensor.extract %1[] : tensor<i1>
+    scf.condition(%2) %arg0, %arg1, %arg2, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+  } do {
+  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
+    // %arg1 here will get replaced by %cst_1
+    %1 = arith.addi %arg0, %arg1 : tensor<i32>
+    %2 = arith.addi %arg2, %arg3 : tensor<i32>
+    scf.yield %1, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+  }
+  return %0#0, %0#1, %0#2, %0#3, %0#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+}
+// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
+// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
+// CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK:    %[[WHILE:.*]]:3 = scf.while (%[[ARG0:.*]] = %[[ZERO]], %[[ARG2:.*]] = %[[ONE]], %[[ARG3:.*]] = %[[ONE]])
+// CHECK:       arith.cmpi slt, %[[ARG0]], %{{.*}}
+// CHECK:       tensor.extract %{{.*}}[]
+// CHECK:       scf.condition(%{{.*}}) %[[ARG0]], %[[ARG2]], %[[ARG3]]
+// CHECK:    } do {
+// CHECK:     ^{{.*}}(%[[ARG0:.*]]: tensor<i32>, %[[ARG2:.*]]: tensor<i32>, %[[ARG3:.*]]: tensor<i32>):
+// CHECK:       %[[VAL0:.*]] = arith.addi %[[ARG0]], %[[FUNC_ARG0]]
+// CHECK:       %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG3]]
+// CHECK:       scf.yield %[[VAL0]], %[[VAL1]], %[[VAL1]]
+// CHECK:    }
+// CHECK:    return %[[WHILE]]#0, %[[FUNC_ARG0]], %[[WHILE]]#1, %[[WHILE]]#2, %[[ZERO]]
+
+// CHECK-LABEL: @while_loop_invariant_argument_different_order
+func @while_loop_invariant_argument_different_order() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+  %cst_0 = arith.constant dense<0> : tensor<i32>
+  %cst_1 = arith.constant dense<1> : tensor<i32>
+  %cst_42 = arith.constant dense<42> : tensor<i32>
+
+  %0:6 = scf.while (%arg0 = %cst_0, %arg1 = %cst_1, %arg2 = %cst_1, %arg3 = %cst_1, %arg4 = %cst_0) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
+    %1 = arith.cmpi slt, %arg0, %cst_42 : tensor<i32>
+    %2 = tensor.extract %1[] : tensor<i1>
+    scf.condition(%2) %arg1, %arg0, %arg2, %arg0, %arg3, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+  } do {
+  ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>, %arg5: tensor<i32>): // no predecessors
+    %1 = arith.addi %arg0, %cst_1 : tensor<i32>
+    %2 = arith.addi %arg2, %arg3 : tensor<i32>
+    scf.yield %arg3, %arg1, %2, %2, %arg4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+  }
+  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
+}
+// CHECK:    %[[CST42:.*]] = arith.constant dense<42>
+// CHECK:    %[[ONE:.*]] = arith.constant dense<1>
+// CHECK:    %[[ZERO:.*]] = arith.constant dense<0>
+// CHECK:    %[[WHILE:.*]]:2 = scf.while (%[[ARG1:.*]] = %[[ONE]], %[[ARG4:.*]] = %[[ZERO]])
+// CHECK:       arith.cmpi slt, %[[ZERO]], %[[CST42]]
+// CHECK:       tensor.extract %{{.*}}[]
+// CHECK:       scf.condition(%{{.*}}) %[[ARG1]], %[[ARG4]]
+// CHECK:    } do {
+// CHECK:     ^{{.*}}(%{{.*}}: tensor<i32>, %{{.*}}: tensor<i32>):
+// CHECK:       scf.yield %[[ZERO]], %[[ONE]]
+// CHECK:    }
+// CHECK:    return %[[WHILE]]#0, %[[ZERO]], %[[ONE]], %[[ZERO]], %[[ONE]], %[[WHILE]]#1
+
+// -----
+
 // CHECK-LABEL: @while_unused_result
 func @while_unused_result() -> i32 {
   %0:2 = scf.while () : () -> (i32, i64) {