namespace mlir {
+/// Creates a pass that bufferizes the SCF dialect.
+std::unique_ptr<Pass> createSCFBufferizePass();
+
/// Creates a pass that specializes `for` loops for unrolling and
/// vectorization.
std::unique_ptr<Pass> createForLoopSpecializationPass();
include "mlir/Pass/PassBase.td"
+// Registers the -scf-bufferize pass; constructed via
+// mlir::createSCFBufferizePass().
+def SCFBufferize : FunctionPass<"scf-bufferize"> {
+  let summary = "Bufferize the scf dialect.";
+  let constructor = "mlir::createSCFBufferizePass()";
+}
+
def SCFForLoopSpecialization
: FunctionPass<"for-loop-specialization"> {
let summary = "Specialize `for` loops for vectorization";
namespace mlir {
+class ConversionTarget;
+class MLIRContext;
+class OwningRewritePatternList;
class Region;
+class TypeConverter;
namespace scf {
/// The old loop is replaced with the new one.
void tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
+/// Populates patterns for SCF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is scf.if -- the scf.if op and the
+/// corresponding scf.yield ops need to update their types according to the
+/// TypeConverter, but otherwise don't care what type conversions are happening.
+void populateSCFStructuralTypeConversionsAndLegality(
+ MLIRContext *context, TypeConverter &typeConverter,
+ OwningRewritePatternList &patterns, ConversionTarget &target);
+
} // namespace scf
} // namespace mlir
SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
};
+/// Marks ops used by bufferization for type conversion materializations as
+/// "legal" in the given ConversionTarget.
+///
+/// This function should be called by all bufferization passes using
+/// BufferizeTypeConverter so that materializations work properly. One exception
+/// is bufferization passes doing "full" conversions, where it can be desirable
+/// for even the materializations to remain illegal so that they are eliminated.
+void populateBufferizeMaterializationLegality(ConversionTarget &target);
+
/// Helper conversion pattern that encapsulates a BufferizeTypeConverter
/// instance.
template <typename SourceOp>
--- /dev/null
+//===- Bufferize.cpp - scf bufferize pass ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+/// Function pass that bufferizes SCF-dialect ops by applying the SCF
+/// structural type conversion patterns with a BufferizeTypeConverter.
+struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
+  void runOnFunction() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    // Keep the materialization ops legal so this partial bufferization can
+    // insert conversions at the tensor/memref boundaries.
+    populateBufferizeMaterializationLegality(target);
+    populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
+                                                    patterns, target);
+    if (failed(applyPartialConversion(func, target, patterns)))
+      return signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+// Factory for the SCF bufferization pass; referenced as the `constructor`
+// of the SCFBufferize tablegen pass definition.
+std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
+  return std::make_unique<SCFBufferizePass>();
+}
add_mlir_dialect_library(MLIRSCFTransforms
+ Bufferize.cpp
LoopSpecialization.cpp
ParallelLoopFusion.cpp
ParallelLoopTiling.cpp
+ StructuralTypeConversions.cpp
Utils.cpp
ADDITIONAL_HEADER_DIRS
--- /dev/null
+//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+// Converts the result and iter-arg types of an scf.for op according to the
+// pattern's TypeConverter; the op is otherwise agnostic to the types involved.
+class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ForOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Convert every result type; bail out if any type does not convert 1:1.
+    SmallVector<Type, 6> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+
+    // Clone and replace.
+    ForOp newOp = cast<ForOp>(rewriter.clone(*op.getOperation()));
+    newOp.getOperation()->setOperands(operands);
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    // Update the iter-arg block argument types to match the new result types.
+    // drop_begin(_, 1) skips the induction variable, which is unaffected by
+    // the conversion.
+    auto bodyArgs = newOp.getBody()->getArguments();
+    for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// Converts the result types of an scf.if op according to the pattern's
+// TypeConverter; the op itself is structurally agnostic to the types involved.
+class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: Generalize this to any type conversion, not just 1:1.
+    //
+    // We need to implement something more sophisticated here that tracks which
+    // types convert to which other types and does the appropriate
+    // materialization logic.
+    // For example, it's possible that one result type converts to 0 types and
+    // another to 2 types, so newResultTypes would at least be the right size to
+    // not crash in the llvm::zip call below, but then we would set the
+    // wrong type on the SSA values! These edge cases are also why we cannot
+    // safely use the TypeConverter::convertTypes helper here.
+    SmallVector<Type, 6> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+
+    // TODO: Write this with updateRootInPlace once the conversion infra
+    // supports source materializations on ops updated in place.
+    // Clone the op, swap in the converted operands, and retype the results.
+    IfOp newOp = cast<IfOp>(rewriter.clone(*op.getOperation()));
+    newOp.getOperation()->setOperands(operands);
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// When the result types of a ForOp/IfOp get changed, the operand types of the
+// corresponding yield op need to be changed. In order to trigger the
+// appropriate type conversions / materializations, we need a dummy pattern.
+class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(scf::YieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Recreate the yield with the already-converted operands; the op carries
+    // no other state.
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  // All three patterns share the provided type converter.
+  patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+      typeConverter, context);
+  // scf.for / scf.if are legal once all of their result types are legal
+  // under the converter.
+  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
+    return typeConverter.isLegal(op->getResultTypes());
+  });
+  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
+    // We only have conversions for a subset of ops that use scf.yield
+    // terminators.
+    if (!isa<ForOp, IfOp>(op.getParentOp()))
+      return true;
+    return typeConverter.isLegal(op.getOperandTypes());
+  });
+}
return KeepAsFunctionResult;
}
+// Marks the ops used for bufferization materializations as legal so that
+// partial bufferization passes can leave tensor/memref conversions in place
+// at type-conversion boundaries.
+void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
+  target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
+}
+
//===----------------------------------------------------------------------===//
// BufferizeFuncOpConverter
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt %s -scf-bufferize | FileCheck %s
+
+// CHECK-LABEL: func @if(
+// CHECK-SAME: %[[PRED:.*]]: i1,
+// CHECK-SAME: %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
+// CHECK-SAME: %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
+// CHECK: %[[TRUE_MEMREF:.*]] = tensor_to_memref %[[TRUE_TENSOR]] : memref<?xf32>
+// CHECK: scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
+// CHECK: } else {
+// CHECK: %[[FALSE_MEMREF:.*]] = tensor_to_memref %[[FALSE_TENSOR]] : memref<?xf32>
+// CHECK: scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
+// CHECK: }
+// CHECK: %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF:.*]] : memref<?xf32>
+// CHECK: return %[[RESULT_TENSOR]] : tensor<?xf32>
+// CHECK: }
+// Checks structural conversion of scf.if and its scf.yield terminators from
+// tensor<?xf32> to memref<?xf32>.
+func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = scf.if %pred -> (tensor<?xf32>) {
+    scf.yield %true_val : tensor<?xf32>
+  } else {
+    scf.yield %false_val : tensor<?xf32>
+  }
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @for(
+// CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index,
+// CHECK-SAME: %[[STEP:.*]]: index) -> tensor<f32> {
+// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
+// CHECK: scf.yield %[[ITER]] : memref<f32>
+// CHECK: }
+// CHECK: %[[VAL_8:.*]] = tensor_load %[[VAL_9:.*]] : memref<f32>
+// CHECK: return %[[VAL_8]] : tensor<f32>
+// CHECK: }
+// Checks structural conversion of scf.for (including its iter_args and yield)
+// from tensor<f32> to memref<f32>.
+func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
+  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
+    scf.yield %iter : tensor<f32>
+  }
+  return %ret : tensor<f32>
+}