[mlir] Add structural type conversions for SCF dialect.
authorSean Silva <silvasean@google.com>
Fri, 16 Oct 2020 03:17:25 +0000 (20:17 -0700)
committerSean Silva <silvasean@google.com>
Wed, 21 Oct 2020 18:58:27 +0000 (11:58 -0700)
A "structural" type conversion is one where the underlying ops are
completely agnostic to the actual types involved and simply need to update
their types. An example of this is scf.if -- the scf.if op and the
corresponding scf.yield ops need to update their types accordingly to the
TypeConverter, but otherwise don't care what type conversions are happening.

To test the structural type conversions, it is convenient to define a
bufferize pass for a dialect, which exercises them nicely.

Differential Revision: https://reviews.llvm.org/D89757

mlir/include/mlir/Dialect/SCF/Passes.h
mlir/include/mlir/Dialect/SCF/Passes.td
mlir/include/mlir/Dialect/SCF/Transforms.h
mlir/include/mlir/Transforms/Bufferize.h
mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp [new file with mode: 0644]
mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp [new file with mode: 0644]
mlir/lib/Transforms/Bufferize.cpp
mlir/test/Dialect/SCF/bufferize.mlir [new file with mode: 0644]

index 7edb244..f3dda9b 100644 (file)
@@ -17,6 +17,9 @@
 
 namespace mlir {
 
+/// Creates a pass that bufferizes the SCF dialect.
+std::unique_ptr<Pass> createSCFBufferizePass();
+
 /// Creates a pass that specializes for loop for unrolling and
 /// vectorization.
 std::unique_ptr<Pass> createForLoopSpecializationPass();
index 6f3cf0e..6118694 100644 (file)
 
 include "mlir/Pass/PassBase.td"
 
+def SCFBufferize : FunctionPass<"scf-bufferize"> {
+  let summary = "Bufferize the scf dialect.";
+  let constructor = "mlir::createSCFBufferizePass()";
+}
+
 def SCFForLoopSpecialization
     : FunctionPass<"for-loop-specialization"> {
   let summary = "Specialize `for` loops for vectorization";
index 222ad6b..3164d33 100644 (file)
 
 namespace mlir {
 
+class ConversionTarget;
+class MLIRContext;
+class OwningRewritePatternList;
 class Region;
+class TypeConverter;
 
 namespace scf {
 
@@ -42,6 +46,19 @@ void naivelyFuseParallelOps(Region &region);
 /// The old loop is replaced with the new one.
 void tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
 
+/// Populates patterns for SCF structural type conversions and sets up the
+/// provided ConversionTarget with the appropriate legality configuration for
+/// the ops to get converted properly.
+///
+/// A "structural" type conversion is one where the underlying ops are
+/// completely agnostic to the actual types involved and simply need to update
+/// their types. An example of this is scf.if -- the scf.if op and the
+/// corresponding scf.yield ops need to update their types accordingly to the
+/// TypeConverter, but otherwise don't care what type conversions are happening.
+void populateSCFStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
+
 } // namespace scf
 } // namespace mlir
 
index ddc0089..5bee53e 100644 (file)
@@ -143,6 +143,15 @@ private:
   SmallVector<DecomposeTypeConversionCallFn, 2> decomposeTypeConversions;
 };
 
+/// Marks ops used by bufferization for type conversion materializations as
+/// "legal" in the given ConversionTarget.
+///
+/// This function should be called by all bufferization passes using
+/// BufferizeTypeConverter so that materializations work proprely. One exception
+/// is bufferization passes doing "full" conversions, where it can be desirable
+/// for even the materializations to remain illegal so that they are eliminated.
+void populateBufferizeMaterializationLegality(ConversionTarget &target);
+
 /// Helper conversion pattern that encapsulates a BufferizeTypeConverter
 /// instance.
 template <typename SourceOp>
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
new file mode 100644 (file)
index 0000000..23cf72f
--- /dev/null
@@ -0,0 +1,41 @@
+//===- Bufferize.cpp - scf bufferize pass ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
+  void runOnFunction() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    populateBufferizeMaterializationLegality(target);
+    populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
+                                                    patterns, target);
+    if (failed(applyPartialConversion(func, target, patterns)))
+      return signalPassFailure();
+  };
+};
+} // end anonymous namespace
+
+std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
+  return std::make_unique<SCFBufferizePass>();
+}
index b3b2002..6b516de 100644 (file)
@@ -1,7 +1,9 @@
 add_mlir_dialect_library(MLIRSCFTransforms
+  Bufferize.cpp
   LoopSpecialization.cpp
   ParallelLoopFusion.cpp
   ParallelLoopTiling.cpp
+  StructuralTypeConversions.cpp
   Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
new file mode 100644 (file)
index 0000000..30a2272
--- /dev/null
@@ -0,0 +1,117 @@
+//===- StructuralTypeConversions.cpp - scf structural type conversions ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SCF/Transforms.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+namespace {
+class ConvertForOpTypes : public OpConversionPattern<ForOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ForOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Type, 6> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+
+    // Clone and replace.
+    ForOp newOp = cast<ForOp>(rewriter.clone(*op.getOperation()));
+    newOp.getOperation()->setOperands(operands);
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    auto bodyArgs = newOp.getBody()->getArguments();
+    for (auto t : llvm::zip(llvm::drop_begin(bodyArgs, 1), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertIfOpTypes : public OpConversionPattern<IfOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: Generalize this to any type conversion, not just 1:1.
+    //
+    // We need to implement something more sophisticated here that tracks which
+    // types convert to which other types and does the appropriate
+    // materialization logic.
+    // For example, it's possible that one result type converts to 0 types and
+    // another to 2 types, so newResultTypes would at least be the right size to
+    // not crash in the llvm::zip call below, but then we would set the the
+    // wrong type on the SSA values! These edge cases are also why we cannot
+    // safely use the TypeConverter::convertTypes helper here.
+    SmallVector<Type, 6> newResultTypes;
+    for (auto type : op.getResultTypes()) {
+      Type newType = typeConverter->convertType(type);
+      if (!newType)
+        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
+      newResultTypes.push_back(newType);
+    }
+
+    // TODO: Write this with updateRootInPlace once the conversion infra
+    // supports source materializations on ops updated in place.
+    IfOp newOp = cast<IfOp>(rewriter.clone(*op.getOperation()));
+    newOp.getOperation()->setOperands(operands);
+    for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
+      std::get<0>(t).setType(std::get<1>(t));
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+// When the result types of a ForOp/IfOp get changed, the operand types of the
+// corresponding yield op need to be changed. In order to trigger the
+// appropriate type conversions / materializations, we need a dummy pattern.
+class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(scf::YieldOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
+      typeConverter, context);
+  target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
+    return typeConverter.isLegal(op->getResultTypes());
+  });
+  target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
+    // We only have conversions for a subset of ops that use scf.yield
+    // terminators.
+    if (!isa<ForOp, IfOp>(op.getParentOp()))
+      return true;
+    return typeConverter.isLegal(op.getOperandTypes());
+  });
+}
index 682fd9f..26eabe2 100644 (file)
@@ -72,6 +72,10 @@ BufferizeTypeConverter::getResultConversionKind(Type origin, Type converted) {
   return KeepAsFunctionResult;
 }
 
+void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
+  target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
+};
+
 //===----------------------------------------------------------------------===//
 // BufferizeFuncOpConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir
new file mode 100644 (file)
index 0000000..01b353d
--- /dev/null
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -scf-bufferize | FileCheck %s
+
+// CHECK-LABEL:   func @if(
+// CHECK-SAME:             %[[PRED:.*]]: i1,
+// CHECK-SAME:             %[[TRUE_TENSOR:.*]]: tensor<?xf32>,
+// CHECK-SAME:             %[[FALSE_TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref<?xf32>) {
+// CHECK:             %[[TRUE_MEMREF:.*]] = tensor_to_memref %[[TRUE_TENSOR]] : memref<?xf32>
+// CHECK:             scf.yield %[[TRUE_MEMREF]] : memref<?xf32>
+// CHECK:           } else {
+// CHECK:             %[[FALSE_MEMREF:.*]] = tensor_to_memref %[[FALSE_TENSOR]] : memref<?xf32>
+// CHECK:             scf.yield %[[FALSE_MEMREF]] : memref<?xf32>
+// CHECK:           }
+// CHECK:           %[[RESULT_TENSOR:.*]] = tensor_load %[[RESULT_MEMREF:.*]] : memref<?xf32>
+// CHECK:           return %[[RESULT_TENSOR]] : tensor<?xf32>
+// CHECK:         }
+func @if(%pred: i1, %true_val: tensor<?xf32>, %false_val: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = scf.if %pred -> (tensor<?xf32>) {
+    scf.yield %true_val : tensor<?xf32>
+  } else {
+    scf.yield %false_val : tensor<?xf32>
+  }
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func @for(
+// CHECK-SAME:              %[[TENSOR:.*]]: tensor<f32>,
+// CHECK-SAME:              %[[LB:.*]]: index, %[[UB:.*]]: index,
+// CHECK-SAME:              %[[STEP:.*]]: index) -> tensor<f32> {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<f32>
+// CHECK:           %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref<f32>) {
+// CHECK:             scf.yield %[[ITER]] : memref<f32>
+// CHECK:           }
+// CHECK:           %[[VAL_8:.*]] = tensor_load %[[VAL_9:.*]] : memref<f32>
+// CHECK:           return %[[VAL_8]] : tensor<f32>
+// CHECK:         }
+func @for(%arg0: tensor<f32>, %lb: index, %ub: index, %step: index) -> tensor<f32> {
+  %ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %arg0) -> tensor<f32> {
+    scf.yield %iter : tensor<f32>
+  }
+  return %ret : tensor<f32>
+}