From: Ingo Müller Date: Tue, 21 Feb 2023 11:21:25 +0000 (+0000) Subject: [mlir] Implement pass utils for 1:N type conversions. X-Git-Tag: upstream/17.0.6~13612 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0;p=platform%2Fupstream%2Fllvm.git [mlir] Implement pass utils for 1:N type conversions. The current dialect conversion does not support 1:N type conversions. This commit implements a (poor-man's) dialect conversion pass that does just that. To keep the pass independent of the "real" dialect conversion infrastructure, it provides a specialization of the TypeConverter class that allows for N:1 target materializations, a specialization of the RewritePattern and PatternRewriter classes that automatically add appropriate unrealized casts supporting 1:N type conversions and provide converted operands for implementing subclasses, and a conversion driver that applies the provided patterns and replaces the unrealized casts that haven't folded away with user-provided materializations. The current pass is powerful enough to express many existing manual solutions for 1:N type conversions or extend transforms that previously didn't support them, out of which this patch implements call graph type decomposition (which is currently implemented with a ValueDecomposer that is only used there). The goal of this pass is to illustrate the effect that 1:N type conversions could have, gain experience in how patterns should be written that achieve that effect, and get feedback on how the APIs of the dialect conversion should be extended or changed to support such patterns. The hope is that the "real" dialect conversion eventually supports such patterns, at which point, this pass could be removed again. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D144469 --- diff --git a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h new file mode 100644 index 0000000..2fba342 --- /dev/null +++ b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h @@ -0,0 +1,26 @@ +//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H +#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H + +namespace mlir { +class TypeConverter; +class RewritePatternSet; +} // namespace mlir + +namespace mlir { + +// Populates the provided pattern set with patterns that do 1:N type conversions +// on func ops. This is intended to be used with `applyPartialOneToNConversion`. +void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h new file mode 100644 index 0000000..25beee2 --- /dev/null +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -0,0 +1,256 @@ +//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utils for implementing (poor-man's) dialect conversion +// passes with 1:N type conversions. +// +// The main function, `applyPartialOneToNConversion`, first applies a set of +// `RewritePattern`s, which produce unrealized casts to convert the operands and +// results from and to the source types, and then replaces all newly added +// unrealized casts by user-provided materializations. For this to work, the +// main function requires a special `TypeConverter`, a special +// `PatternRewriter`, and special RewritePattern`s, which extend their +// respective base classes for 1:N type converions. +// +// Note that this is much more simple-minded than the "real" dialect conversion, +// which checks for legality before applying patterns and does probably many +// other additional things. Ideally, some of the extensions here could be +// integrated there. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H +#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +/// Extends `TypeConverter` with 1:N target materializations. Such +/// materializations have to provide the "reverse" of 1:N type conversions, +/// i.e., they need to materialize N values with target types into one value +/// with a source type (which isn't possible in the base class currently). +class OneToNTypeConverter : public TypeConverter { +public: + /// Callback that expresses user-provided materialization logic from the given + /// value to N values of the given types. This is useful for expressing target + /// materializations for 1:N type conversions, which materialize one value in + /// a source type as N values in target types. + using OneToNMaterializationCallbackFn = + std::function>(OpBuilder &, TypeRange, + Value, Location)>; + + /// Creates the mapping of the given range of original types to target types + /// of the conversion and stores that mapping in the given (signature) + /// conversion. This function simply calls + /// `TypeConverter::convertSignatureArgs` and exists here with a different + /// name to reflect the broader semantic. + LogicalResult computeTypeMapping(TypeRange types, + SignatureConversion &result) { + return convertSignatureArgs(types, result); + } + + /// Applies one of the user-provided 1:N target materializations. If several + /// exists, they are tried out in the reverse order in which they have been + /// added until the first one succeeds. If none succeeds, the functions + /// returns `std::nullopt`. + std::optional> + materializeTargetConversion(OpBuilder &builder, Location loc, + TypeRange resultTypes, Value input) const; + + /// Adds a 1:N target materialization to the converter. Such materializations + /// build IR that converts N values with target types into 1 value of the + /// source type. + void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { + oneToNTargetMaterializations.emplace_back(std::move(callback)); + } + +private: + SmallVector oneToNTargetMaterializations; +}; + +/// Stores a 1:N mapping of types and provides several useful accessors. This +/// class extends `SignatureConversion`, which already supports 1:N type +/// mappings but lacks some accessors into the mapping as well as access to the +/// original types. +class OneToNTypeMapping : public TypeConverter::SignatureConversion { +public: + OneToNTypeMapping(TypeRange originalTypes) + : TypeConverter::SignatureConversion(originalTypes.size()), + originalTypes(originalTypes) {} + + using TypeConverter::SignatureConversion::getConvertedTypes; + + /// Returns the list of types that corresponds to the original type at the + /// given index. + TypeRange getConvertedTypes(unsigned originalTypeNo) const; + + /// Returns the list of original types. + TypeRange getOriginalTypes() const { return originalTypes; } + + /// Returns the slice of converted values that corresponds the original value + /// at the given index. + ValueRange getConvertedValues(ValueRange convertedValues, + unsigned originalValueNo) const; + + /// Fills the given result vector with as many copies of the location of the + /// original value as the number of values it is converted to. + void convertLocation(Value originalValue, unsigned originalValueNo, + llvm::SmallVectorImpl &result) const; + + /// Fills the given result vector with as many copies of the lociation of each + /// original value as the number of values they are respectively converted to. + void convertLocations(ValueRange originalValues, + llvm::SmallVectorImpl &result) const; + + /// Returns true iff at least one type conversion maps an input type to a type + /// that is different from itself. + bool hasNonIdentityConversion() const; + +private: + llvm::SmallVector originalTypes; +}; + +/// Extends the basic `RewritePattern` class with a type converter member and +/// some accessors to it. This is useful for patterns that are not +/// `ConversionPattern`s but still require access to a type converter. +class RewritePatternWithConverter : public mlir::RewritePattern { +public: + /// Construct a conversion pattern with the given converter, and forward the + /// remaining arguments to RewritePattern. + template + RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args) + : RewritePattern(std::forward(args)...), + typeConverter(&typeConverter) {} + + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + + template + std::enable_if_t::value, + ConverterTy *> + getTypeConverter() const { + return static_cast(typeConverter); + } + +protected: + /// A type converter for use by this pattern. + TypeConverter *const typeConverter; +}; + +/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The +/// class provides additional rewrite methods that are specific to 1:N type +/// conversions. +class OneToNPatternRewriter : public PatternRewriter { +public: + OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {} + + /// Replaces the results of the operation with the specified list of values + /// mapped back to the original types as specified in the provided type + /// mapping. That type mapping must match the replaced op (i.e., the original + /// types must be the same as the result types of the op) and the new values + /// (i.e., the converted types must be the same as the types of the new + /// values). + void replaceOp(Operation *op, ValueRange newValues, + const OneToNTypeMapping &resultMapping); + using PatternRewriter::replaceOp; + + /// Applies the given argument conversion to the given block. This consists of + /// replacing each original argument with N arguments as specified in the + /// argument conversion and inserting unrealized casts from the converted + /// values to the original types, which are then used in lieu of the original + /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts + /// with a user-provided argument materialization if necessary.) This is + /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N + /// type conversion properly and probably (2) doesn't handle many other edge + /// cases. + Block *applySignatureConversion(Block *block, + OneToNTypeMapping &argumentConversion); +}; + +/// Base class for patterns with 1:N type conversions. Derived classes have to +/// overwrite the `matchAndRewrite` overlaod that provides additional +/// information for 1:N type conversions. +class OneToNConversionPattern : public RewritePatternWithConverter { +public: + using RewritePatternWithConverter::RewritePatternWithConverter; + + /// This function has to be implemented by base classes and is called from the + /// usual overloads. Like in "normal" `DialectConversion`, the function is + /// provided with the converted operands (which thus have target types). Since + /// 1:N conversion are supported, there is usually no 1:1 relationship between + /// the original and the converted operands. Instead, the provided + /// `operandMapping` can be used to access the converted operands that + /// correspond to a particular original operand. Similarly, `resultMapping` + /// is provided to help with assembling the result values, which may have 1:N + /// correspondences as well. In that case, the original op should be replaced + /// with the overload of `replaceOp` that takes the provided `resultMapping` + /// in order to deal with the mapping of converted result values to their + /// usages in the original types correctly. + virtual LogicalResult matchAndRewrite(Operation *op, + OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + ValueRange convertedOperands) const = 0; + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; +}; + +/// This class is a wrapper around `OneToNConversionPattern` for matching +/// against instances of a particular op class. +template +class OneToNOpConversionPattern : public OneToNConversionPattern { +public: + OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1, + ArrayRef generatedNames = {}) + : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), + benefit, context, generatedNames) {} + + using OneToNConversionPattern::matchAndRewrite; + + /// Overload that derived classes have to override for their op type. + virtual LogicalResult matchAndRewrite(SourceOp op, + OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + ValueRange convertedOperands) const = 0; + + LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + ValueRange convertedOperands) const final { + return matchAndRewrite(cast(op), rewriter, operandMapping, + resultMapping, convertedOperands); + } +}; + +/// Applies the given set of patterns recursively on the given op and adds user +/// materializations where necessary. The patterns are expected to be +/// `OneToNConversionPattern`, which help converting the types of the operands +/// and results of the matched ops. The provided type converter is used to +/// convert the operands of matched ops from their original types to operands +/// with different types. Unlike in `DialectConversion`, this supports 1:N type +/// conversions. Those conversions at the "boundary" of the pattern application, +/// where converted results are not consumed by replaced ops that expect the +/// converted operands or vice versa, the function inserts user materializations +/// from the type converter. Also unlike `DialectConversion`, there are no legal +/// or illegal types; the function simply applies the given patterns and does +/// not fail if some ops or types remain unconverted (i.e., the conversion is +/// only "partial"). +LogicalResult +applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, + const FrozenRewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt index 9a5b38b..1720199 100644 --- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRFuncTransforms DuplicateFunctionElimination.cpp FuncBufferize.cpp FuncConversions.cpp + OneToNFuncConversions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp new file mode 100644 index 0000000..5e8125c --- /dev/null +++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp @@ -0,0 +1,132 @@ +//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The patterns in this file are heavily inspired (and copied from) +// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the +// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N +// type conversions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Transforms/OneToNTypeConversion.h" + +using namespace mlir; +using namespace mlir::func; + +namespace { + +class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + ValueRange convertedOperands) const override { + Location loc = op->getLoc(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!operandMapping.hasNonIdentityConversion() && + !resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new CallOp. + auto newOp = rewriter.create(loc, resultMapping.getConvertedTypes(), + convertedOperands); + newOp->setAttrs(op->getAttrs()); + + rewriter.replaceOp(op, newOp->getResults(), resultMapping); + return success(); + } +}; + +class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping & /*operandMapping*/, + const OneToNTypeMapping & /*resultMapping*/, + ValueRange /*convertedOperands*/) const override { + auto *typeConverter = getTypeConverter(); + + // Construct mapping for function arguments. + OneToNTypeMapping argumentMapping(op.getArgumentTypes()); + if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), + argumentMapping))) + return failure(); + + // Construct mapping for function results. + OneToNTypeMapping funcResultMapping(op.getResultTypes()); + if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), + funcResultMapping))) + return failure(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!argumentMapping.hasNonIdentityConversion() && + !funcResultMapping.hasNonIdentityConversion()) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + argumentMapping.getConvertedTypes(), + funcResultMapping.getConvertedTypes()); + rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); + + // Update block signatures. + if (!op.isExternal()) { + Region *region = &op.getBody(); + Block *block = ®ion->front(); + rewriter.applySignatureConversion(block, argumentMapping); + } + + return success(); + } +}; + +class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + ValueRange convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands. + rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +} // namespace + +namespace mlir { + +void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertTypesInFuncCallOp, + ConvertTypesInFuncFuncOp, + ConvertTypesInFuncReturnOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +} // namespace mlir diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt index ba8fa20..6892d00 100644 --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_library(MLIRTransformUtils GreedyPatternRewriteDriver.cpp InliningUtils.cpp LoopInvariantCodeMotionUtils.cpp + OneToNTypeConversion.cpp RegionUtils.cpp TopologicalSortUtils.cpp diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp new file mode 100644 index 0000000..c0866f8 --- /dev/null +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -0,0 +1,405 @@ +//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/OneToNTypeConversion.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace llvm; +using namespace mlir; + +std::optional> +OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, + Location loc, + TypeRange resultTypes, + Value input) const { + for (const OneToNMaterializationCallbackFn &fn : + llvm::reverse(oneToNTargetMaterializations)) { + if (std::optional> result = + fn(builder, resultTypes, input, loc)) + return *result; + } + return std::nullopt; +} + +TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { + TypeRange convertedTypes = getConvertedTypes(); + if (auto mapping = getInputMapping(originalTypeNo)) + return convertedTypes.slice(mapping->inputNo, mapping->size); + return {}; +} + +ValueRange +OneToNTypeMapping::getConvertedValues(ValueRange convertedValues, + unsigned originalValueNo) const { + if (auto mapping = getInputMapping(originalValueNo)) + return convertedValues.slice(mapping->inputNo, mapping->size); + return {}; +} + +void OneToNTypeMapping::convertLocation( + Value originalValue, unsigned originalValueNo, + llvm::SmallVectorImpl &result) const { + if (auto mapping = getInputMapping(originalValueNo)) + result.append(mapping->size, originalValue.getLoc()); +} + +void OneToNTypeMapping::convertLocations( + ValueRange originalValues, llvm::SmallVectorImpl &result) const { + assert(originalValues.size() == getOriginalTypes().size()); + for (auto [i, value] : llvm::enumerate(originalValues)) + convertLocation(value, i, result); +} + +static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { + return convertedTypes.size() == 1 && convertedTypes[0] == originalType; +} + +bool OneToNTypeMapping::hasNonIdentityConversion() const { + // XXX: I think that the original types and the converted types are the same + // iff there was no non-identity type conversion. If that is true, the + // patterns could actually test whether there is anything useful to do + // without having access to the signature conversion. + for (auto [i, originalType] : llvm::enumerate(originalTypes)) { + TypeRange types = getConvertedTypes(i); + if (!isIdentityConversion(originalType, types)) { + assert(TypeRange(originalTypes) != getConvertedTypes()); + return true; + } + } + assert(TypeRange(originalTypes) == getConvertedTypes()); + return false; +} + +namespace { +enum class CastKind { + // Casts block arguments in the target type back to the source type. (If + // necessary, this cast becomes an argument materialization.) + Argument, + + // Casts other values in the target type back to the source type. (If + // necessary, this cast becomes a source materialization.) + Source, + + // Casts values in the source type to the target type. (If necessary, this + // cast becomes a target materialization.) + Target +}; +} + +/// Mapping of enum values to string values. +StringRef getCastKindName(CastKind kind) { + static const std::unordered_map castKindNames = { + {CastKind::Argument, "argument"}, + {CastKind::Source, "source"}, + {CastKind::Target, "target"}}; + return castKindNames.at(kind); +} + +/// Attribute name that is used to annotate inserted unrealized casts with their +/// kind (source, argument, or target). +static const char *const castKindAttrName = + "__one-to-n-type-conversion_cast-kind__"; + +/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given +/// result types. Returns the result values of the cast. +static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, + ValueRange inputs, CastKind kind) { + // Create cast. + Location loc = builder.getUnknownLoc(); + if (!inputs.empty()) + loc = inputs.front().getLoc(); + auto castOp = + builder.create(loc, resultTypes, inputs); + + // Store cast kind as attribute. + auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind)); + castOp->setAttr(castKindAttrName, kindAttr); + + return castOp->getResults(); +} + +/// Builds one `UnrealizedConversionCastOp` for each of the given original +/// values using the respective target types given in the provided conversion +/// mapping and returns the results of these casts. If the conversion mapping of +/// a value maps a type to itself (i.e., is an identity conversion), then no +/// cast is inserted and the original value is returned instead. +/// Note that these unrealized casts are different from target materializations +/// in that they are *always* inserted, even if they immediately fold away, such +/// that patterns always see valid intermediate IR, whereas materializations are +/// only used in the places where the unrealized casts *don't* fold away. +static SmallVector +buildUnrealizedForwardCasts(ValueRange originalValues, + OneToNTypeMapping &conversion, + RewriterBase &rewriter, CastKind kind) { + + // Convert each operand one by one. + SmallVector convertedValues; + convertedValues.reserve(conversion.getConvertedTypes().size()); + for (auto [idx, originalValue] : llvm::enumerate(originalValues)) { + TypeRange convertedTypes = conversion.getConvertedTypes(idx); + + // Identity conversion: keep operand as is. + if (isIdentityConversion(originalValue.getType(), convertedTypes)) { + convertedValues.push_back(originalValue); + continue; + } + + // Non-identity conversion: materialize target types. + ValueRange castResult = + buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind); + convertedValues.append(castResult.begin(), castResult.end()); + } + + return convertedValues; +} + +/// Builds one `UnrealizedConversionCastOp` for each sequence of the given +/// original values to one value of the type they originated from, i.e., a +/// "reverse" conversion from N converted values back to one value of the +/// original type, using the given (forward) type conversion. If a given value +/// was mapped to a value of the same type (i.e., the conversion in the mapping +/// is an identity conversion), then the "converted" value is returned without +/// cast. +/// Note that these unrealized casts are different from source materializations +/// in that they are *always* inserted, even if they immediately fold away, such +/// that patterns always see valid intermediate IR, whereas materializations are +/// only used in the places where the unrealized casts *don't* fold away. +static SmallVector +buildUnrealizedBackwardsCasts(ValueRange convertedValues, + const OneToNTypeMapping &typeConversion, + RewriterBase &rewriter) { + assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); + + // Create unrealized cast op for each converted result of the op. + SmallVector recastValues; + TypeRange originalTypes = typeConversion.getOriginalTypes(); + recastValues.reserve(originalTypes.size()); + auto convertedValueIt = convertedValues.begin(); + for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { + TypeRange convertedTypes = typeConversion.getConvertedTypes(idx); + size_t numConvertedValues = convertedTypes.size(); + if (isIdentityConversion(originalType, convertedTypes)) { + // Identity conversion: take result as is. + recastValues.push_back(*convertedValueIt); + } else { + // Non-identity conversion: cast back to source type. + ValueRange recastValue = buildUnrealizedCast( + rewriter, originalType, + ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}, + CastKind::Source); + assert(recastValue.size() == 1); + recastValues.push_back(recastValue.front()); + } + convertedValueIt += numConvertedValues; + } + + return recastValues; +} + +void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues, + const OneToNTypeMapping &resultMapping) { + // Create a cast back to the original types and replace the results of the + // original op with those. + assert(newValues.size() == resultMapping.getConvertedTypes().size()); + assert(op->getResultTypes() == resultMapping.getOriginalTypes()); + SmallVector castResults = + buildUnrealizedBackwardsCasts(newValues, resultMapping, *this); + replaceOp(op, castResults); +} + +Block *OneToNPatternRewriter::applySignatureConversion( + Block *block, OneToNTypeMapping &argumentConversion) { + // Split the block at the beginning to get a new block to use for the + // updated signature. + SmallVector locs; + argumentConversion.convertLocations(block->getArguments(), locs); + Block *newBlock = + createBlock(block, argumentConversion.getConvertedTypes(), locs); + replaceAllUsesWith(block, newBlock); + + // Create necessary casts in new block. + SmallVector castResults; + for (auto [i, arg] : llvm::enumerate(block->getArguments())) { + TypeRange convertedTypes = argumentConversion.getConvertedTypes(i); + ValueRange newArgs = + argumentConversion.getConvertedValues(newBlock->getArguments(), i); + if (isIdentityConversion(arg.getType(), convertedTypes)) { + // Identity conversion: take argument as is. + assert(newArgs.size() == 1); + castResults.push_back(newArgs.front()); + } else { + // Non-identity conversion: cast the converted arguments to the original + // type. + PatternRewriter::InsertionGuard g(*this); + setInsertionPointToStart(newBlock); + ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs, + CastKind::Argument); + assert(castResult.size() == 1); + castResults.push_back(castResult.front()); + } + } + + // Merge old block into new block such that we only have the latter with the + // new signature. + mergeBlocks(block, newBlock, castResults); + + return newBlock; +} + +LogicalResult +OneToNConversionPattern::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto *typeConverter = getTypeConverter(); + + // Construct conversion mapping for results. + Operation::result_type_range originalResultTypes = op->getResultTypes(); + OneToNTypeMapping resultMapping(originalResultTypes); + if (failed(typeConverter->computeTypeMapping(originalResultTypes, + resultMapping))) + return failure(); + + // Construct conversion mapping for operands. + Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); + OneToNTypeMapping operandMapping(originalOperandTypes); + if (failed(typeConverter->computeTypeMapping(originalOperandTypes, + operandMapping))) + return failure(); + + // Cast operands to target types. + SmallVector convertedOperands = buildUnrealizedForwardCasts( + op->getOperands(), operandMapping, rewriter, CastKind::Target); + + // Create a `OneToNPatternRewriter` for the pattern, which provides additional + // functionality. + // TODO(ingomueller): I guess it would be better to use only one rewriter + // throughout the whole pass, but that would require to + // drive the pattern application ourselves, which is a lot + // of additional boilerplate code. This seems to work fine, + // so I leave it like this for the time being. + OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext()); + oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint()); + oneToNPatternRewriter.setListener(rewriter.getListener()); + + // Apply actual pattern. + if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping, + resultMapping, convertedOperands))) + return failure(); + + return success(); +} + +namespace mlir { + +// This function applies the provided patterns using +// `applyPatternsAndFoldGreedily` and then replaces all newly inserted +// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts +// from target to source types inserted by a `OneToNConversionPattern` normally +// fold away with the "forward" casts from source to target types inserted by +// the next pattern.) To understand which casts are "newly inserted", all casts +// inserted by this pass are annotated with a string attribute that also +// documents which kind of the cast (source, argument, or target). +LogicalResult +applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, + const FrozenRewritePatternSet &patterns) { +#ifndef NDEBUG + // Remember existing unrealized casts. This data structure is only used in + // asserts; building it only for that purpose may be an overkill. + SmallSet existingCasts; + op->walk([&](UnrealizedConversionCastOp castOp) { + assert(!castOp->hasAttr(castKindAttrName)); + existingCasts.insert(castOp); + }); +#endif // NDEBUG + + // Apply provided conversion patterns. + if (failed(applyPatternsAndFoldGreedily(op, patterns))) { + emitError(op->getLoc()) << "failed to apply conversion patterns"; + return failure(); + } + + // Find all unrealized casts inserted by the pass that haven't folded away. + SmallVector worklist; + op->walk([&](UnrealizedConversionCastOp castOp) { + if (castOp->hasAttr(castKindAttrName)) { + assert(!existingCasts.contains(castOp)); + worklist.push_back(castOp); + } + }); + + // Replace new casts with user materializations. + IRRewriter rewriter(op->getContext()); + for (UnrealizedConversionCastOp castOp : worklist) { + TypeRange resultTypes = castOp->getResultTypes(); + ValueRange operands = castOp->getOperands(); + StringRef castKind = + castOp->getAttrOfType(castKindAttrName).getValue(); + rewriter.setInsertionPoint(castOp); + +#ifndef NDEBUG + // Determine whether operands or results are already legal to test some + // assumptions for the different kind of materializations. These properties + // are only used it asserts and it may be overkill to compute them. + bool areOperandTypesLegal = llvm::all_of( + operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); + bool areResultsTypesLegal = llvm::all_of( + resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); +#endif // NDEBUG + + // Add materialization and remember materialized results. + SmallVector materializedResults; + if (castKind == getCastKindName(CastKind::Target)) { + // Target materialization. + assert(!areOperandTypesLegal && areResultsTypesLegal && + operands.size() == 1 && "found unexpected target cast"); + std::optional> maybeResults = + typeConverter.materializeTargetConversion( + rewriter, castOp->getLoc(), resultTypes, operands.front()); + if (!maybeResults) { + emitError(castOp->getLoc()) + << "failed to create target materialization"; + return failure(); + } + materializedResults = maybeResults.value(); + } else { + // Source and argument materializations. + assert(areOperandTypesLegal && !areResultsTypesLegal && + resultTypes.size() == 1 && "found unexpected cast"); + std::optional maybeResult; + if (castKind == getCastKindName(CastKind::Source)) { + // Source materialization. + maybeResult = typeConverter.materializeSourceConversion( + rewriter, castOp->getLoc(), resultTypes.front(), + castOp.getOperands()); + } else { + // Argument materialization. + assert(castKind == getCastKindName(CastKind::Argument) && + "unexpected value of cast kind attribute"); + assert(llvm::all_of(operands, + [&](Value v) { return v.isa(); })); + maybeResult = typeConverter.materializeArgumentConversion( + rewriter, castOp->getLoc(), resultTypes.front(), + castOp.getOperands()); + } + if (!maybeResult.has_value() || !maybeResult.value()) { + emitError(castOp->getLoc()) + << "failed to create " << castKind << " materialization"; + return failure(); + } + materializedResults = {maybeResult.value()}; + } + + // Replace the cast with the result of the materialization. + rewriter.replaceOp(castOp, materializedResults); + } + + return success(); +} + +} // namespace mlir diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir index 604e948..51b63ba 100644 --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ +// RUN: | FileCheck %s --check-prefix=CHECK-12N + // Test case: Most basic case of a 1:N decomposition, an identity function. // CHECK-LABEL: func @identity( @@ -9,6 +13,10 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @identity( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32 func.func @identity(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -20,6 +28,9 @@ func.func @identity(%arg0: tuple) -> tuple { // CHECK-LABEL: func @identity_1_to_1_no_materializations( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -31,6 +42,9 @@ func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { // CHECK-LABEL: func @recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>> { return %arg0 : tuple>> } @@ -54,6 +68,10 @@ func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>) -> tuple // CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple) -> i2 // CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK-12N-LABEL: func @mixed_recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2 func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple>>) -> tuple, tuple, tuple>> { return %arg0 : tuple, tuple, tuple>> } @@ -63,6 +81,7 @@ func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple< // Test case: Check decomposition of calls. // CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) +// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32) func.func private @callee(tuple) -> tuple // CHECK-LABEL: func @caller( @@ -76,6 +95,11 @@ func.func private @callee(tuple) -> tuple // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32 func.func @caller(%arg0: tuple) -> tuple { %0 = call @callee(%arg0) : (tuple) -> tuple return %0 : tuple @@ -86,10 +110,15 @@ func.func @caller(%arg0: tuple) -> tuple { // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). // CHECK-LABEL: func private @callee() +// CHECK-12N-LABEL: func private @callee() func.func private @callee(tuple<>) -> tuple<> + // CHECK-LABEL: func @caller() { // CHECK: call @callee() : () -> () // CHECK: return +// CHECK-12N-LABEL: func @caller() { +// CHECK-12N: call @callee() : () -> () +// CHECK-12N: return func.func @caller(%arg0: tuple<>) -> tuple<> { %0 = call @callee(%arg0) : (tuple<>) -> (tuple<>) return %0 : tuple<> @@ -105,6 +134,11 @@ func.func @caller(%arg0: tuple<>) -> tuple<> { // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) { +// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple +// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32 func.func @unconverted_op_result() -> tuple { %0 = "test.source"() : () -> (tuple) return %0 : tuple @@ -125,6 +159,16 @@ func.func @unconverted_op_result() -> tuple { // CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple // CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 // CHECK: return %[[V3]], %[[V5]] : i1, i32 +// CHECK-12N-LABEL: func @nested_unconverted_op_result( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple +// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple) -> tuple> +// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple>) -> tuple> +// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple>) -> i1 +// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple +// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32 func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple> { %0 = "test.op"(%arg) : (tuple>) -> (tuple>) return %0 : tuple> @@ -136,6 +180,7 @@ func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple (i1, i2, i3, i4, i5, i6) +// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) // CHECK-LABEL: func @caller( @@ -153,6 +198,15 @@ func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tup // CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple) -> i4 // CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple) -> i5 // CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[I1:.*]]: i1, +// CHECK-12N-SAME: %[[I2:.*]]: i2, +// CHECK-12N-SAME: %[[I3:.*]]: i3, +// CHECK-12N-SAME: %[[I4:.*]]: i4, +// CHECK-12N-SAME: %[[I5:.*]]: i5, +// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { +// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple, %arg3: i3, %arg4: tuple, %arg5: i6) -> (tuple<>, i1, tuple, i3, tuple, i6) { %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple, i3, tuple, i6 diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt index 14f0e0d..14df652 100644 --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(FuncToLLVM) +add_subdirectory(OneToNTypeConversion) add_subdirectory(VectorToSPIRV) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt new file mode 100644 index 0000000..4189786 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_library(MLIRTestOneToNTypeConversionPass + TestOneToNTypeConversionPass.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRFuncTransforms + MLIRIR + MLIRTestDialect + MLIRTransformUtils + ) + +target_include_directories(MLIRTestOneToNTypeConversionPass + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp new file mode 100644 index 0000000..220bcb5 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -0,0 +1,245 @@ +//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/OneToNTypeConversion.h" + +using namespace mlir; + +namespace { +/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms +/// in `applyPartialOneToNConversion` by converting built-in tuples to the +/// elements they consist of as well as some dummy ops operating on these +/// tuples. +struct TestOneToNTypeConversionPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass) + + TestOneToNTypeConversionPass() = default; + TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + StringRef getArgument() const final { + return "test-one-to-n-type-conversion"; + } + + StringRef getDescription() const final { + return "Test pass for 1:N type conversion"; + } + + Option convertFuncOps{*this, "convert-func-ops", + llvm::cl::desc("Enable conversion on func ops"), + llvm::cl::init(false)}; + + Option convertTupleOps{*this, "convert-tuple-ops", + llvm::cl::desc("Enable conversion on tuple ops"), + llvm::cl::init(false)}; + + void runOnOperation() override; +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestOneToNTypeConversionPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir + +namespace { + +/// Test pattern on for the `make_tuple` op from the test dialect that converts +/// this kind of op into it's "decomposed" form, i.e., the elements of the tuple +/// that is being produced by `test.make_tuple`, which are really just the +/// operands of this op. +class ConvertMakeTupleOp + : public OneToNOpConversionPattern<::test::MakeTupleOp> { +public: + using OneToNOpConversionPattern< + ::test::MakeTupleOp>::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite(::test::MakeTupleOp op, + OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + ValueRange convertedOperands) const override { + // Simply replace the current op with the converted operands. + rewriter.replaceOp(op, convertedOperands, resultMapping); + return success(); + } +}; + +/// Test pattern on for the `get_tuple_element` op from the test dialect that +/// converts this kind of op into it's "decomposed" form, i.e., instead of +/// "physically" extracting one element from the tuple, we forward the one +/// element of the decomposed form that is being extracted (or the several +/// elements in case that element is a nested tuple). +class ConvertGetTupleElementOp + : public OneToNOpConversionPattern<::test::GetTupleElementOp> { +public: + using OneToNOpConversionPattern< + ::test::GetTupleElementOp>::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite(::test::GetTupleElementOp op, + OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + ValueRange convertedOperands) const override { + // Construct mapping for tuple element types. + auto stateType = op->getOperand(0).getType().cast(); + TypeRange originalElementTypes = stateType.getTypes(); + OneToNTypeMapping elementMapping(originalElementTypes); + if (failed(typeConverter->convertSignatureArgs(originalElementTypes, + elementMapping))) + return failure(); + + // Compute converted operands corresponding to original input tuple. + ValueRange convertedTuple = + operandMapping.getConvertedValues(convertedOperands, 0); + + // Got those converted operands that correspond to the index-th element of + // the original input tuple. + size_t index = op.getIndex(); + ValueRange extractedElement = + elementMapping.getConvertedValues(convertedTuple, index); + + rewriter.replaceOp(op, extractedElement, resultMapping); + + return success(); + } +}; + +} // namespace + +static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertMakeTupleOp, + ConvertGetTupleElementOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +/// Creates a sequence of `test.get_tuple_element` ops for all elements of a +/// given tuple value. If some tuple elements are, in turn, tuples, the elements +/// of those are extracted recursively such that the returned values have the +/// same types as `resultTypes.getFlattenedTypes()`. +/// +/// This function has been copied (with small adaptions) from +/// TestDecomposeCallGraphTypes.cpp. +static std::optional> +buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input, + Location loc) { + TupleType inputType = input.getType().dyn_cast(); + if (!inputType) + return {}; + + SmallVector values; + for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) { + Value element = builder.create<::test::GetTupleElementOp>( + loc, elementType, input, builder.getI32IntegerAttr(idx)); + if (auto nestedTupleType = elementType.dyn_cast()) { + // Recurse if the current element is also a tuple. + SmallVector flatRecursiveTypes; + nestedTupleType.getFlattenedTypes(flatRecursiveTypes); + std::optional> resursiveValues = + buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc); + if (!resursiveValues.has_value()) + return {}; + values.append(resursiveValues.value()); + } else { + values.push_back(element); + } + } + return values; +} + +/// Creates a `test.make_tuple` op out of the given inputs building a tuple of +/// type `resultType`. If that type is nested, each nested tuple is built +/// recursively with another `test.make_tuple` op. +/// +/// This function has been copied (with small adaptions) from +/// TestDecomposeCallGraphTypes.cpp. +static std::optional buildMakeTupleOp(OpBuilder &builder, + TupleType resultType, + ValueRange inputs, Location loc) { + // Build one value for each element at this nesting level. + SmallVector elements; + elements.reserve(resultType.getTypes().size()); + ValueRange::iterator inputIt = inputs.begin(); + for (Type elementType : resultType.getTypes()) { + if (auto nestedTupleType = elementType.dyn_cast()) { + // Determine how many input values are needed for the nested elements of + // the nested TupleType and advance inputIt by that number. + // TODO: We only need the *number* of nested types, not the types itself. + // Maybe it's worth adding a more efficient overload? + SmallVector nestedFlattenedTypes; + nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); + size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); + ValueRange nestedFlattenedelements(inputIt, + inputIt + numNestedFlattenedTypes); + inputIt += numNestedFlattenedTypes; + + // Recurse on the values for the nested TupleType. + std::optional res = buildMakeTupleOp(builder, nestedTupleType, + nestedFlattenedelements, loc); + if (!res.has_value()) + return {}; + + // The tuple constructed by the conversion is the element value. + elements.push_back(res.value()); + } else { + // Base case: take one input as is. + elements.push_back(*inputIt++); + } + } + + // Assemble the tuple from the elements. + return builder.create<::test::MakeTupleOp>(loc, resultType, elements); +} + +void TestOneToNTypeConversionPass::runOnOperation() { + ModuleOp module = getOperation(); + auto *context = &getContext(); + + // Assemble type converter. + OneToNTypeConverter typeConverter; + + typeConverter.addConversion([](Type type) { return type; }); + typeConverter.addConversion( + [](TupleType tupleType, SmallVectorImpl &types) { + tupleType.getFlattenedTypes(types); + return success(); + }); + + typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addSourceMaterialization(buildMakeTupleOp); + typeConverter.addTargetMaterialization(buildGetTupleElementOps); + + // Assemble patterns. + RewritePatternSet patterns(context); + if (convertTupleOps) + populateDecomposeTuplesTestPatterns(typeConverter, patterns); + if (convertFuncOps) + populateFuncTypeConversionPatterns(typeConverter, patterns); + + // Run conversion. + if (failed(applyPartialOneToNConversion(module, typeConverter, + std::move(patterns)))) + return signalPassFailure(); +} diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index f84fbe6..c430569 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -33,6 +33,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTestDialect MLIRTestDynDialect MLIRTestIR + MLIRTestOneToNTypeConversionPass MLIRTestPass MLIRTestPDLL MLIRTestReducer diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index e7ca06b..855c6a6 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -107,6 +107,7 @@ void registerTestMathAlgebraicSimplificationPass(); void registerTestMathPolynomialApproximationPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); +void registerTestOneToNTypeConversionPass(); void registerTestOpaqueLoc(); void registerTestPadFusion(); void registerTestPDLByteCodePass(); @@ -218,6 +219,7 @@ void registerTestPasses() { mlir::test::registerTestMathPolynomialApproximationPass(); mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); + mlir::test::registerTestOneToNTypeConversionPass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestPadFusion(); mlir::test::registerTestPDLByteCodePass();