class Operation;
class Value;
-
//===--------------------------------------------------------------------===//
// OperationFolder
//===--------------------------------------------------------------------===//
public:
OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
+ /// Scan the specified region for constants that can be used in folding,
+ /// moving them to the entry block and adding them to our known-constants
+ /// table.
+ void processExistingConstants(Region &region);
+
/// Tries to perform folding on the given `op`, including unifying
/// deduplicated constants. If successful, replaces `op`'s uses with
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
// OperationFolder
//===----------------------------------------------------------------------===//
+/// Scan the specified region for constants that can be used in folding,
+/// moving them to the entry block and adding them to our known-constants
+/// table.
+void OperationFolder::processExistingConstants(Region &region) {
+  if (region.empty())
+    return;
+
+  // March the constant insertion point forward, moving all constants to the
+  // top of the block, but keeping them in their order of discovery.
+  Region *insertRegion = getInsertionRegion(interfaces, &region.front());
+  auto &uniquedConstants = foldScopes[insertRegion];
+
+  Block &insertBlock = insertRegion->front();
+  Block::iterator constantIterator = insertBlock.begin();
+
+  // Process each constant that we discover in this region.
+  auto processConstant = [&](Operation *op, Attribute value) {
+    // Check to see if we already have an instance of this constant.
+    Operation *&constOp = uniquedConstants[std::make_tuple(
+        op->getDialect(), value, op->getResult(0).getType())];
+
+    // If we already have an instance of this constant, CSE/delete this one as
+    // we go.
+    if (constOp) {
+      if (constantIterator == Block::iterator(op))
+        ++constantIterator; // Don't invalidate our iterator when scanning.
+      op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
+      op->erase();
+      return;
+    }
+
+    // Otherwise, remember that we have this constant.
+    constOp = op;
+    referencedDialects[op].push_back(op->getDialect());
+
+    // If the constant isn't already at the insertion point then move it up.
+    if (constantIterator == insertBlock.end() || &*constantIterator != op)
+      op->moveBefore(&insertBlock, constantIterator);
+    else
+      ++constantIterator; // It was pointing at the constant.
+  };
+
+  SmallVector<Operation *> isolatedOps;
+  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // If this is a constant, process it.
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      processConstant(op, value);
+      // We may have deleted the operation, don't check it for regions.
+      return WalkResult::skip();
+    }
+
+    // If the operation has regions and is isolated, don't recurse into it.
+    if (op->getNumRegions() != 0) {
+      auto hasDifferentInsertRegion = [&](Region &region) {
+        return !region.empty() &&
+               getInsertionRegion(interfaces, &region.front()) != insertRegion;
+      };
+      if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
+        isolatedOps.push_back(op);
+        return WalkResult::skip();
+      }
+    }
+
+    // Otherwise keep going.
+    return WalkResult::advance();
+  });
+
+  // Process regions in any isolated ops separately.
+  for (Operation *isolated : isolatedOps) {
+    for (Region &region : isolated->getRegions())
+      processExistingConstants(region);
+  }
+}
+
LogicalResult OperationFolder::tryToFold(
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
Attribute value, Type type, Location loc) {
// Check if an existing mapping already exists.
auto constKey = std::make_tuple(dialect, value, type);
- auto *&constInst = uniquedConstants[constKey];
- if (constInst)
- return constInst;
+ auto *&constOp = uniquedConstants[constKey];
+ if (constOp)
+ return constOp;
// If one doesn't exist, try to materialize one.
- if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+ if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
return nullptr;
// Check to see if the generated constant is in the expected dialect.
- auto *newDialect = constInst->getDialect();
+ auto *newDialect = constOp->getDialect();
if (newDialect == dialect) {
- referencedDialects[constInst].push_back(dialect);
- return constInst;
+ referencedDialects[constOp].push_back(dialect);
+ return constOp;
}
// If it isn't, then we also need to make sure that the mapping for the new
// If an existing operation in the new dialect already exists, delete the
// materialized operation in favor of the existing one.
if (auto *existingOp = uniquedConstants.lookup(newKey)) {
- constInst->erase();
+ constOp->erase();
referencedDialects[existingOp].push_back(dialect);
- return constInst = existingOp;
+ return constOp = existingOp;
}
// Otherwise, update the new dialect to the materialized operation.
- referencedDialects[constInst].assign({dialect, newDialect});
- auto newIt = uniquedConstants.insert({newKey, constInst});
+ referencedDialects[constOp].assign({dialect, newDialect});
+ auto newIt = uniquedConstants.insert({newKey, constOp});
return newIt.first->second;
}
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications.
- template <typename Operands> void addToWorklist(Operands &&operands) {
+ template <typename Operands>
+ void addToWorklist(Operands &&operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
/// if the rewrite converges in `maxIterations`.
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
int maxIterations) {
- // Add the given operation to the worklist.
- auto collectOps = [this](Operation *op) { addToWorklist(op); };
+ // Perform a prepass over the IR to discover constants.
+ for (auto &region : regions)
+ folder.processExistingConstants(region);
bool changed = false;
- int i = 0;
+ int iteration = 0;
do {
- // Add all nested operations to the worklist.
+ worklist.clear();
+ worklistMap.clear();
+
+ // Add all nested operations to the worklist in preorder.
for (auto &region : regions)
- region.walk(collectOps);
+ region.walk<WalkOrder::PreOrder>(
+ [this](Operation *op) { worklist.push_back(op); });
+
+ // Reverse the list so our pop-back loop processes them in-order.
+ std::reverse(worklist.begin(), worklist.end());
+ // Remember the reverse index.
+ for (unsigned i = 0, e = worklist.size(); i != e; ++i)
+ worklistMap[worklist[i]] = i;
// These are scratch vectors used in the folding loop below.
SmallVector<Value, 8> originalOperands, resultValues;
notifyOperationRemoved(op);
};
+ // Add the given operation to the worklist.
+ auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
// Try to fold this op.
bool inPlaceUpdate;
if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
// After applying patterns, make sure that the CFG of each of the regions is
// kept up to date.
changed |= succeeded(simplifyRegions(*this, regions));
- } while (changed && ++i < maxIterations);
+ } while (changed && ++iteration < maxIterations);
+
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return !changed;
}
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// CHECK-DAG: %[[alloc:.*]] = memref.alloca() : memref<3xvector<15xf32>>
+ // CHECK-DAG: [[CST:%.*]] = constant 7.000000e+00 : f32
// CHECK-DAG: %[[dim:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 3 {
// CHECK: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cond1:.*]] = cmpi slt, %[[add]], %[[dim]] : index
// CHECK: scf.if %[[cond1]] {
- // CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+ // CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
// CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
// CHECK: } else {
// CHECK: store %[[splat]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
// CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
// CHECK: %[[cst:.*]] = memref.load %[[vmemref]][] : memref<vector<3x15xf32>>
- // FULL-UNROLL: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
- // FULL-UNROLL: %[[C0:.*]] = constant 0 : index
- // FULL-UNROLL: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
+ // FULL-UNROLL-DAG: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
+ // FULL-UNROLL-DAG: %[[C0:.*]] = constant 0 : index
+ // FULL-UNROLL-DAG: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
+ // FULL-UNROLL-DAG: [[CST:%.*]] = constant 7.000000e+00 : f32
// FULL-UNROLL: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32>
// FULL-UNROLL: cmpi slt, %[[base]], %[[DIM]] : index
// FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
- // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+ // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
// FULL-UNROLL: vector.insert %{{.*}}, %[[VEC0]] [0] : vector<15xf32> into vector<3x15xf32>
// FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
// FULL-UNROLL: } else {
// FULL-UNROLL: affine.apply #[[$MAP1]]()[%[[base]]]
// FULL-UNROLL: cmpi slt, %{{.*}}, %[[DIM]] : index
// FULL-UNROLL: %[[VEC2:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
- // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+ // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
// FULL-UNROLL: vector.insert %{{.*}}, %[[VEC1]] [1] : vector<15xf32> into vector<3x15xf32>
// FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
// FULL-UNROLL: } else {
// FULL-UNROLL: affine.apply #[[$MAP2]]()[%[[base]]]
// FULL-UNROLL: cmpi slt, %{{.*}}, %[[DIM]] : index
// FULL-UNROLL: %[[VEC3:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
- // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+ // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
// FULL-UNROLL: vector.insert %{{.*}}, %[[VEC2]] [2] : vector<15xf32> into vector<3x15xf32>
// FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
// FULL-UNROLL: } else {
// CHECK-LABEL: transfer_read_minor_identity(
// CHECK-SAME: %[[A:.*]]: memref<?x?x?x?xf32>) -> vector<3x3xf32>
-// CHECK-DAG: %[[c0:.*]] = constant 0 : index
-// CHECK-DAG: %cst = constant 0.000000e+00 : f32
// CHECK-DAG: %[[c2:.*]] = constant 2 : index
// CHECK-DAG: %[[cst0:.*]] = constant dense<0.000000e+00> : vector<3xf32>
// CHECK: %[[m:.*]] = memref.alloca() : memref<3xvector<3xf32>>
+// CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK: %[[d:.*]] = memref.dim %[[A]], %[[c2]] : memref<?x?x?x?xf32>
// CHECK: affine.for %[[arg1:.*]] = 0 to 3 {
// CHECK: %[[cmp:.*]] = cmpi slt, %[[arg1]], %[[d]] : index
// CHECK: scf.if %[[cmp]] {
-// CHECK: %[[tr:.*]] = vector.transfer_read %[[A]][%c0, %c0, %[[arg1]], %c0], %cst : memref<?x?x?x?xf32>, vector<3xf32>
+// CHECK: %[[tr:.*]] = vector.transfer_read %[[A]][%c0, %c0, %[[arg1]], %c0], %[[cst]] : memref<?x?x?x?xf32>, vector<3xf32>
// CHECK: store %[[tr]], %[[m]][%[[arg1]]] : memref<3xvector<3xf32>>
// CHECK: } else {
// CHECK: store %[[cst0]], %[[m]][%[[arg1]]] : memref<3xvector<3xf32>>
// CHECK-SAME: %[[A:.*]]: vector<3x3xf32>,
// CHECK-SAME: %[[B:.*]]: memref<?x?x?x?xf32>)
// CHECK-DAG: %[[c2:.*]] = constant 2 : index
-// CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK: %[[m:.*]] = memref.alloca() : memref<3xvector<3xf32>>
+// CHECK-DAG: %[[c0:.*]] = constant 0 : index
// CHECK: %[[cast:.*]] = vector.type_cast %[[m]] : memref<3xvector<3xf32>> to memref<vector<3x3xf32>>
// CHECK: store %[[A]], %[[cast]][] : memref<vector<3x3xf32>>
// CHECK: %[[d:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?x?xf32>
// -----
-// CHECK-DAG: #[[$MAP14:.*]] = affine_map<()[s0, s1] -> (((s1 + s0) * 4) floordiv s0)>
+// CHECK-DAG: #[[$MAP14:.*]] = affine_map<()[s0, s1] -> ((s0 * 4 + s1 * 4) floordiv s0)>
// CHECK-LABEL: func @compose_affine_maps_multiple_symbols
func @compose_affine_maps_multiple_symbols(%arg0: index, %arg1: index) -> index {
// -----
-// CHECK-DAG: #[[$MAP_symbolic_composition_d:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-DAG: #[[$MAP_symbolic_composition_d:.*]] = affine_map<()[s0, s1] -> (s0 * 3 + s1)>
// CHECK-LABEL: func @symbolic_composition_d(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: index
%0 = affine.apply affine_map<(d0) -> (d0)>(%arg0)
%1 = affine.apply affine_map<()[s0] -> (s0)>()[%arg1]
%2 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 + s2 + s3)>()[%0, %0, %0, %1]
- // CHECK: %{{.*}} = affine.apply #[[$MAP_symbolic_composition_d]]()[%[[ARG1]], %[[ARG0]]]
+ // CHECK: %{{.*}} = affine.apply #[[$MAP_symbolic_composition_d]]()[%[[ARG0]], %[[ARG1]]]
return %2 : index
}
return
}
// CHECK-LABEL: func @aligned_promote_fill
-// CHECK: %[[cf:.*]] = constant {{.*}} : f32
+// CHECK: %[[cf:.*]] = constant 1.0{{.*}} : f32
// CHECK: %[[s0:.*]] = memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref<?x?xf32, #map{{.*}}> to memref<?x?xf32, #map{{.*}}>
// CHECK: %[[a0:.*]] = memref.alloc({{%.*}}) {alignment = 32 : i64} : memref<?xi8>
// CHECK: %[[v0:.*]] = memref.view %[[a0]][{{.*}}][{{%.*}}, {{%.*}}] : memref<?xi8> to memref<?x?xf32>
// CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
%0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
%1 = vector.transpose %0, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
- // CHECK-NOT: transpose
+ // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
%2 = vector.transpose %1, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
%3 = vector.transpose %2, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
- // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T0]]
+ // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T1]]
%4 = mulf %1, %3 : vector<2x3x4xf32>
// CHECK: [[T5:%.*]] = vector.transpose [[MUL]], [2, 1, 0]
%5 = vector.transpose %4, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
//
// CHECK-LABEL: func @lowered_affine_ceildiv
func @lowered_affine_ceildiv() -> (index, index) {
-// CHECK-NEXT: %c-1 = constant -1 : index
+// CHECK-DAG: %c-1 = constant -1 : index
%c-43 = constant -43 : index
%c42 = constant 42 : index
%c0 = constant 0 : index
%5 = subi %c0, %4 : index
%6 = addi %4, %c1 : index
%7 = select %0, %5, %6 : index
-// CHECK-NEXT: %c2 = constant 2 : index
+// CHECK-DAG: %c2 = constant 2 : index
%c43 = constant 43 : index
%c42_0 = constant 42 : index
%c0_1 = constant 0 : index
%0 = "test.op_a"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
%result = "test.op_a"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b")
- // CHECK: "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
- // CHECK: "test.op_b"(%arg0) {attr = 20 : i32} : (i32) -> i32 loc(fused["b", "a"])
+ // CHECK: %0 = "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
+ // CHECK: %1 = "test.op_b"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b")
return %result : i32
}
%2 = "test.op_g"(%1) : (i32) -> i32
// CHECK: "test.op_f"(%arg0)
- // CHECK: "test.op_b"(%arg0) {attr = 34 : i32}
+ // CHECK: "test.op_b"(%arg0) {attr = 20 : i32}
return %0 : i32
}