From 586cebef271f627e80c3535e7cd201915f88b349 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Ingo=20M=C3=BCller?= Date: Tue, 21 Feb 2023 11:21:25 +0000 Subject: [PATCH] [mlir][scf] Implement structural conversion for 1:N type conversions. This patch implements patterns for the newly introduced 1:N type conversion utils for several ops of the SCF dialect. It also adds an option to the existing test pass as well as test cases that applies the patterns through the test pass. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D146959 --- .../mlir/Dialect/SCF/Transforms/Transforms.h | 6 + mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 + .../SCF/Transforms/OneToNTypeConversion.cpp | 161 +++++++++++++++++++++ .../scf-structural-one-to-n-type-conversion.mlir | 118 +++++++++++++++ .../Conversion/OneToNTypeConversion/CMakeLists.txt | 3 + .../TestOneToNTypeConversionPass.cpp | 7 + 6 files changed, 296 insertions(+) create mode 100644 mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp create mode 100644 mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h index bfeab9d..fbe73a2 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -120,6 +120,12 @@ void populateSCFStructuralTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target); +/// Populates the provided pattern set with patterns that do 1:N type +/// conversions on (some) SCF ops. This is intended to be used with +/// applyPartialOneToNConversion. +void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter, + RewritePatternSet &patterns); + /// Options to dictate how loops should be pipelined. struct PipeliningOption { /// Lambda returning all the operation in the forOp, with their stage, in the diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index 3dd9099..20abf2b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSCFTransforms LoopPipelining.cpp LoopRangeFolding.cpp LoopSpecialization.cpp + OneToNTypeConversion.cpp ParallelLoopCollapsing.cpp ParallelLoopFusion.cpp ParallelLoopTiling.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp new file mode 100644 index 0000000..74207e6 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/OneToNTypeConversion.cpp @@ -0,0 +1,161 @@ +//===-- OneToNTypeConversion.cpp - SCF 1:N type conversion ------*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The patterns in this file are heavily inspired (and copied from) +// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N +// type conversions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SCF/Transforms/Transforms.h" + +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Transforms/OneToNTypeConversion.h" + +using namespace mlir; +using namespace mlir::scf; + +class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(IfOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping & /*operandMapping*/, + const OneToNTypeMapping &resultMapping, + const ValueRange /*convertedOperands*/) const override { + Location loc = op->getLoc(); + + // Nothing to do if there is no non-identity conversion. + if (!resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new IfOp. + TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); + auto newOp = rewriter.create(loc, convertedResultTypes, + op.getCondition(), true); + newOp->setAttrs(op->getAttrs()); + + // We do not need the empty blocks created by rewriter. + rewriter.eraseBlock(newOp.elseBlock()); + rewriter.eraseBlock(newOp.thenBlock()); + + // Inlines block from the original operation. + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + rewriter.replaceOp(op, SmallVector(newOp->getResults()), + resultMapping); + return success(); + } +}; + +class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + const ValueRange convertedOperands) const override { + Location loc = op->getLoc(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!operandMapping.hasNonIdentityConversion() && + !resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new WhileOp. + TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); + + auto newOp = + rewriter.create(loc, convertedResultTypes, convertedOperands); + newOp->setAttrs(op->getAttrs()); + + // Update block signatures. + std::array blockMappings = {operandMapping, + resultMapping}; + for (unsigned int i : {0u, 1u}) { + Region *region = &op.getRegion(i); + Block *block = ®ion->front(); + + rewriter.applySignatureConversion(block, blockMappings[i]); + + // Move updated region to new WhileOp. + Region &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + } + + rewriter.replaceOp(op, SmallVector(newOp->getResults()), + resultMapping); + return success(); + } +}; + +class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + const ValueRange convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands. + rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +class ConvertTypesInSCFConditionOp + : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + const ValueRange convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands. + rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +namespace mlir { +namespace scf { + +void populateSCFStructuralOneToNTypeConversions(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertTypesInSCFConditionOp, + ConvertTypesInSCFIfOp, + ConvertTypesInSCFWhileOp, + ConvertTypesInSCFYieldOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +} // namespace scf +} // namespace mlir diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir new file mode 100644 index 0000000..dd2013c --- /dev/null +++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-one-to-n-type-conversion.mlir @@ -0,0 +1,118 @@ +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \ +// RUN: | FileCheck %s + +// Test case: Nested 1:N type conversion is carried through scf.if and +// scf.yield. + +// CHECK-LABEL: func.func @if_result( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i2, +// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) { +// CHECK-NEXT: %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) { +// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]]#0, %[[V0]]#1 : i1, i2 +func.func @if_result(%arg0: tuple, i1, tuple>, %arg1: i1) -> tuple, i1, tuple> { + %0 = scf.if %arg1 -> (tuple, i1, tuple>) { + scf.yield %arg0 : tuple, i1, tuple> + } else { + scf.yield %arg0 : tuple, i1, tuple> + } + return %0 : tuple, i1, tuple> +} + +// ----- + +// Test case: Nested 1:N type conversion is carried through scf.if and +// scf.yield and unconverted ops inside have proper materializations. + +// CHECK-LABEL: func.func @if_tuple_ops( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 { +// CHECK-NEXT: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> +// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple, i1> +// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) { +// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple, i1>) -> tuple, i1> +// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple, i1>) -> tuple<> +// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple, i1>) -> i1 +// CHECK-NEXT: scf.yield %[[V5]] : i1 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple, i1> +// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple, i1>) -> tuple<> +// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple, i1>) -> i1 +// CHECK-NEXT: scf.yield %[[V8]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V2]] : i1 +func.func @if_tuple_ops(%arg0: tuple, i1>, %arg1: i1) -> tuple, i1> { + %0 = scf.if %arg1 -> (tuple, i1>) { + %1 = "test.op"(%arg0) : (tuple, i1>) -> tuple, i1> + scf.yield %1 : tuple, i1> + } else { + %1 = "test.source"() : () -> tuple, i1> + scf.yield %1 : tuple, i1> + } + return %0 : tuple, i1> +} +// ----- + +// Test case: Nested 1:N type conversion is carried through scf.while, +// scf.condition, and scf.yield. + +// CHECK-LABEL: func.func @while_operands_results( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i2, +// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) { +// %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) { +// scf.condition(%arg2) %[[ARG3]], %[[ARG4]] : i1, i2 +// } do { +// ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2): +// scf.yield %[[ARG5]], %[[ARG4]] : i1, i2 +// } +// return %[[V0]]#0, %[[V0]]#1 : i1, i2 +func.func @while_operands_results(%arg0: tuple, i1, tuple>, %arg1: i1) -> tuple, i1, tuple> { + %0 = scf.while (%arg2 = %arg0) : (tuple, i1, tuple>) -> tuple, i1, tuple> { + scf.condition(%arg1) %arg2 : tuple, i1, tuple> + } do { + ^bb0(%arg2: tuple, i1, tuple>): + scf.yield %arg2 : tuple, i1, tuple> + } + return %0 : tuple, i1, tuple> +} + +// ----- + +// Test case: Nested 1:N type conversion is carried through scf.while, +// scf.condition, and unconverted ops inside have proper materializations. + +// CHECK-LABEL: func.func @while_tuple_ops( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 { +// CHECK-NEXT: %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 { +// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<> +// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple, i1> +// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple, i1>) -> tuple, i1> +// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple, i1>) -> tuple<> +// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple, i1>) -> i1 +// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1): +// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple, i1> +// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple, i1>) -> tuple<> +// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple, i1>) -> i1 +// CHECK-NEXT: scf.yield %[[V8]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]] : i1 +func.func @while_tuple_ops(%arg0: tuple, i1>, %arg1: i1) -> tuple, i1> { + %0 = scf.while (%arg2 = %arg0) : (tuple, i1>) -> tuple, i1> { + %1 = "test.op"(%arg2) : (tuple, i1>) -> tuple, i1> + scf.condition(%arg1) %1 : tuple, i1> + } do { + ^bb0(%arg2: tuple, i1>): + %1 = "test.source"() : () -> tuple, i1> + scf.yield %1 : tuple, i1> + } + return %0 : tuple, i1> +} diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt index 4189786..b723022 100644 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt @@ -7,6 +7,9 @@ add_mlir_library(MLIRTestOneToNTypeConversionPass MLIRFuncDialect MLIRFuncTransforms MLIRIR + MLIRPass + MLIRSCFDialect + MLIRSCFTransforms MLIRTestDialect MLIRTransformUtils ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp index 220bcb5..c60c323 100644 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -8,6 +8,7 @@ #include "TestDialect.h" #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/OneToNTypeConversion.h" @@ -43,6 +44,10 @@ struct TestOneToNTypeConversionPass llvm::cl::desc("Enable conversion on func ops"), llvm::cl::init(false)}; + Option convertSCFOps{*this, "convert-scf-ops", + llvm::cl::desc("Enable conversion on scf ops"), + llvm::cl::init(false)}; + Option convertTupleOps{*this, "convert-tuple-ops", llvm::cl::desc("Enable conversion on tuple ops"), llvm::cl::init(false)}; @@ -237,6 +242,8 @@ void TestOneToNTypeConversionPass::runOnOperation() { populateDecomposeTuplesTestPatterns(typeConverter, patterns); if (convertFuncOps) populateFuncTypeConversionPatterns(typeConverter, patterns); + if (convertSCFOps) + scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); // Run conversion. if (failed(applyPartialOneToNConversion(module, typeConverter, -- 2.7.4