#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
return success();
}
};
+
+// Results of an assuming op that are defined outside its body are available
+// indepentently of the assuming op. There is no need to yield such values. This
+// canonicalization replaces such results with their definition.
+struct AssumingBypassIndependentResult : public OpRewritePattern<AssumingOp> {
+ using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AssumingOp op,
+ PatternRewriter &rewriter) const override {
+ Block *body = op.getBody();
+ auto yieldOp = llvm::dyn_cast<AssumingYieldOp>(body->getTerminator());
+ if (!yieldOp)
+ return failure();
+
+ // See if there is at least one result that can bypass the assuming op.
+ auto isDefinedInBody = [&](Value val) {
+ Operation *def = val.getDefiningOp();
+ return def && op->isAncestor(def);
+ };
+ if (llvm::all_of(yieldOp.operands(), isDefinedInBody))
+ return failure();
+
+ SmallVector<Value, 2> replacementValues;
+ auto newAssumingOp = rewriter.create<shape::AssumingOp>(
+ op.getLoc(), op.witness(), [&](OpBuilder &b, Location loc) {
+ // Copy body.
+ BlockAndValueMapping mapping;
+ for (auto &nested : body->without_terminator())
+ b.clone(nested, mapping);
+
+ // Collect new yielded values.
+ SmallVector<Value, 2> mappedResults;
+ for (auto result : yieldOp.getOperands()) {
+ if (isDefinedInBody(result)) {
+ // This value is a result of the assuming op. We can obtain the
+ // replacement value only after the new op is fully constructed.
+ mappedResults.push_back(mapping.lookup(result));
+ replacementValues.push_back(nullptr);
+ } else {
+ // When defined outside of the assuming block, we can use it
+ // direclty. There is no need to yield the value from within the
+ // block.
+ replacementValues.push_back(result);
+ }
+ }
+ return mappedResults;
+ });
+
+ // Use the assuming op's results for the missing replacement values, which
+ // could not bypass the op.
+ auto src = newAssumingOp.getResults().begin();
+ for (auto &dst : replacementValues) {
+ if (dst)
+ continue;
+ dst = *src++;
+ }
+
+ rewriter.replaceOp(op, replacementValues);
+ return success();
+ }
+};
} // namespace
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- // If taking a passing witness, inline region.
- patterns.add<AssumingWithTrue>(context);
+ patterns.add<AssumingBypassIndependentResult, AssumingWithTrue>(context);
}
// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
"use"(%0) : (tensor<?xindex>) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: @bypass_assmunig
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>)
+func @bypass_assmunig(%arg : tensor<2x3xf32>)
+ -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
+ // CHECK: %[[OUTER:.*]] = "some.tensor"
+ // CHECK: %[[WITNESS:.*]] = "some.witness"
+ // CHECK: %[[YIELDED:.*]] = shape.assuming %[[WITNESS]] -> (tensor<2x3xf32>) {
+ // CHECK: %[[INNER:.*]] = "some.tensor"
+ // CHECK: shape.assuming_yield %[[INNER]] : tensor<2x3xf32>
+ // CHECK: }
+ // CHECK: return %[[YIELDED]], %[[OUTER]], %[[ARG]]
+ %outer = "some.tensor"() : () -> tensor<2x3xf32>
+ %witness = "some.witness"() : () -> !shape.witness
+ %results:3 = shape.assuming %witness
+ -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
+ %inner = "some.tensor"() : () -> tensor<2x3xf32>
+ shape.assuming_yield %inner, %outer, %arg
+ : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+ }
+ return %results#0, %results#1, %results#2
+ : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+}