});
}
-/// A conversion patttern for detensoring `linalg.generic` ops.
+/// A conversion pattern for detensoring `linalg.generic` ops.
class DetensorizeGenericOp : public OpConversionPattern<GenericOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(GenericOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const override {
Block *originalBlock = op->getBlock();
- // Gather some information about the op before inling its region.
+ // Gather some information about the op before inlining its region.
Block *opEntryBlock = &*op.getRegion().begin();
YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());
DenseSet<BlockArgument> blockArgsToDetensor;
FunctionOpInterface funcOp = getOperation();
+ // Make sure the entry block of the function doesn't contain any Linalg ops.
+ // Otherwise, it may lead to the signature of the block being changed by the
+ // dialect conversion below, which would make the function op invalid
+ // because its type shouldn't change.
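+ // To work around this, temporarily split everything in the entry block off
+ // into a new block and leave only an unconditional branch behind, e.g.
+ // (illustrative IR only):
+ //
+ //   func.func @f(%arg0: tensor<f32>) -> tensor<f32> {
+ //     cf.br ^bb1          // dummy entry block; signature stays intact
+ //   ^bb1:                 // holds the original entry block contents
+ //     %0 = linalg.generic ... ins(%arg0 : tensor<f32>) ...
+ //     ...
+ //
+ // The split is undone again after the conversion has run (see below).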
+ IRRewriter rewriter(funcOp->getContext());
+ Block *entryBlock = &funcOp.getFunctionBody().front();
+ Block *postEntryBlock =
+ rewriter.splitBlock(entryBlock, entryBlock->begin());
+ rewriter.setInsertionPointToStart(entryBlock);
+ auto branch =
+ rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock);
+
if (aggressiveMode.getValue()) {
AggressiveDetensoringModel costModel;
costModel.compute(funcOp, typeConverter, opsToDetensor,
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(canonPatterns))))
signalPassFailure();
+
+ // Get rid of the dummy entry block we created in the beginning to work
+ // around dialect conversion signature rewriting.
+ rewriter.eraseOp(branch);
+ rewriter.mergeBlocks(postEntryBlock, entryBlock);
}
};
} // namespace
--- /dev/null
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+#map = affine_map<() -> ()>
+func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
+ %0 = tensor.empty() : tensor<f32>
+ %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = []} ins(%arg0 : tensor<f32>) outs(%0 : tensor<f32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<f32>
+ cf.br ^bb1(%1 : tensor<f32>)
+^bb1(%2: tensor<f32>): // pred: ^bb0
+ return %2 : tensor<f32>
+}
+
+// CHECK-LABEL: @main
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
+// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
+// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
+// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
+// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: return %[[ELEMENTS]] : tensor<f32>
}
// CHECK-LABEL: func @main
-// CHECK-NEXT: arith.constant 0 : i32
-// CHECK-NEXT: arith.constant 10
+// CHECK-DAG: arith.constant 0 : i32
+// CHECK-DAG: arith.constant 10
// CHECK-NEXT: cf.br ^[[bb1:.*]](%{{.*}} : i32)
// CHECK-NEXT: ^[[bb1]](%{{.*}}: i32)
// CHECK-NEXT: %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}}