From a8416e3c047a7d590fe3884ed49d965c4425d5c3 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Ingo=20M=C3=BCller?= Date: Mon, 27 Mar 2023 09:23:57 +0000 Subject: [PATCH] Revert "[mlir] Implement pass utils for 1:N type conversions." This reverts commit 9c4611f9c7a7055b18f0a30a4c9074b9917e4ab0. --- .../Func/Transforms/OneToNFuncConversions.h | 26 -- .../include/mlir/Transforms/OneToNTypeConversion.h | 256 ------------- mlir/lib/Dialect/Func/Transforms/CMakeLists.txt | 1 - .../Func/Transforms/OneToNFuncConversions.cpp | 132 ------- mlir/lib/Transforms/Utils/CMakeLists.txt | 1 - mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp | 405 --------------------- .../Transforms/decompose-call-graph-types.mlir | 54 --- mlir/test/lib/Conversion/CMakeLists.txt | 1 - .../Conversion/OneToNTypeConversion/CMakeLists.txt | 18 - .../TestOneToNTypeConversionPass.cpp | 245 ------------- mlir/tools/mlir-opt/CMakeLists.txt | 1 - mlir/tools/mlir-opt/mlir-opt.cpp | 2 - 12 files changed, 1142 deletions(-) delete mode 100644 mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h delete mode 100644 mlir/include/mlir/Transforms/OneToNTypeConversion.h delete mode 100644 mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp delete mode 100644 mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp delete mode 100644 mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt delete mode 100644 mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp diff --git a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h b/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h deleted file mode 100644 index 2fba342..0000000 --- a/mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- OneToNTypeFuncConversions.h - 1:N type conv. for Func ----*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H -#define MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H - -namespace mlir { -class TypeConverter; -class RewritePatternSet; -} // namespace mlir - -namespace mlir { - -// Populates the provided pattern set with patterns that do 1:N type conversions -// on func ops. This is intended to be used with `applyPartialOneToNConversion`. -void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns); - -} // namespace mlir - -#endif // MLIR_DIALECT_FUNC_TRANSFORMS_ONETONTYPEFUNCCONVERSIONS_H diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h deleted file mode 100644 index 25beee2..0000000 --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ /dev/null @@ -1,256 +0,0 @@ -//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides utils for implementing (poor-man's) dialect conversion -// passes with 1:N type conversions. -// -// The main function, `applyPartialOneToNConversion`, first applies a set of -// `RewritePattern`s, which produce unrealized casts to convert the operands and -// results from and to the source types, and then replaces all newly added -// unrealized casts by user-provided materializations. For this to work, the -// main function requires a special `TypeConverter`, a special -// `PatternRewriter`, and special RewritePattern`s, which extend their -// respective base classes for 1:N type converions. -// -// Note that this is much more simple-minded than the "real" dialect conversion, -// which checks for legality before applying patterns and does probably many -// other additional things. Ideally, some of the extensions here could be -// integrated there. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TRANSFORMS_ONETONTYPECONVERSION_H -#define MLIR_TRANSFORMS_ONETONTYPECONVERSION_H - -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/SmallVector.h" - -namespace mlir { - -/// Extends `TypeConverter` with 1:N target materializations. Such -/// materializations have to provide the "reverse" of 1:N type conversions, -/// i.e., they need to materialize N values with target types into one value -/// with a source type (which isn't possible in the base class currently). -class OneToNTypeConverter : public TypeConverter { -public: - /// Callback that expresses user-provided materialization logic from the given - /// value to N values of the given types. This is useful for expressing target - /// materializations for 1:N type conversions, which materialize one value in - /// a source type as N values in target types. - using OneToNMaterializationCallbackFn = - std::function>(OpBuilder &, TypeRange, - Value, Location)>; - - /// Creates the mapping of the given range of original types to target types - /// of the conversion and stores that mapping in the given (signature) - /// conversion. This function simply calls - /// `TypeConverter::convertSignatureArgs` and exists here with a different - /// name to reflect the broader semantic. - LogicalResult computeTypeMapping(TypeRange types, - SignatureConversion &result) { - return convertSignatureArgs(types, result); - } - - /// Applies one of the user-provided 1:N target materializations. If several - /// exists, they are tried out in the reverse order in which they have been - /// added until the first one succeeds. If none succeeds, the functions - /// returns `std::nullopt`. - std::optional> - materializeTargetConversion(OpBuilder &builder, Location loc, - TypeRange resultTypes, Value input) const; - - /// Adds a 1:N target materialization to the converter. Such materializations - /// build IR that converts N values with target types into 1 value of the - /// source type. - void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { - oneToNTargetMaterializations.emplace_back(std::move(callback)); - } - -private: - SmallVector oneToNTargetMaterializations; -}; - -/// Stores a 1:N mapping of types and provides several useful accessors. This -/// class extends `SignatureConversion`, which already supports 1:N type -/// mappings but lacks some accessors into the mapping as well as access to the -/// original types. -class OneToNTypeMapping : public TypeConverter::SignatureConversion { -public: - OneToNTypeMapping(TypeRange originalTypes) - : TypeConverter::SignatureConversion(originalTypes.size()), - originalTypes(originalTypes) {} - - using TypeConverter::SignatureConversion::getConvertedTypes; - - /// Returns the list of types that corresponds to the original type at the - /// given index. - TypeRange getConvertedTypes(unsigned originalTypeNo) const; - - /// Returns the list of original types. - TypeRange getOriginalTypes() const { return originalTypes; } - - /// Returns the slice of converted values that corresponds the original value - /// at the given index. - ValueRange getConvertedValues(ValueRange convertedValues, - unsigned originalValueNo) const; - - /// Fills the given result vector with as many copies of the location of the - /// original value as the number of values it is converted to. - void convertLocation(Value originalValue, unsigned originalValueNo, - llvm::SmallVectorImpl &result) const; - - /// Fills the given result vector with as many copies of the lociation of each - /// original value as the number of values they are respectively converted to. - void convertLocations(ValueRange originalValues, - llvm::SmallVectorImpl &result) const; - - /// Returns true iff at least one type conversion maps an input type to a type - /// that is different from itself. - bool hasNonIdentityConversion() const; - -private: - llvm::SmallVector originalTypes; -}; - -/// Extends the basic `RewritePattern` class with a type converter member and -/// some accessors to it. This is useful for patterns that are not -/// `ConversionPattern`s but still require access to a type converter. -class RewritePatternWithConverter : public mlir::RewritePattern { -public: - /// Construct a conversion pattern with the given converter, and forward the - /// remaining arguments to RewritePattern. - template - RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args) - : RewritePattern(std::forward(args)...), - typeConverter(&typeConverter) {} - - /// Return the type converter held by this pattern, or nullptr if the pattern - /// does not require type conversion. - TypeConverter *getTypeConverter() const { return typeConverter; } - - template - std::enable_if_t::value, - ConverterTy *> - getTypeConverter() const { - return static_cast(typeConverter); - } - -protected: - /// A type converter for use by this pattern. - TypeConverter *const typeConverter; -}; - -/// Specialization of `PatternRewriter` that `OneToNConversionPattern`s use. The -/// class provides additional rewrite methods that are specific to 1:N type -/// conversions. -class OneToNPatternRewriter : public PatternRewriter { -public: - OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {} - - /// Replaces the results of the operation with the specified list of values - /// mapped back to the original types as specified in the provided type - /// mapping. That type mapping must match the replaced op (i.e., the original - /// types must be the same as the result types of the op) and the new values - /// (i.e., the converted types must be the same as the types of the new - /// values). - void replaceOp(Operation *op, ValueRange newValues, - const OneToNTypeMapping &resultMapping); - using PatternRewriter::replaceOp; - - /// Applies the given argument conversion to the given block. This consists of - /// replacing each original argument with N arguments as specified in the - /// argument conversion and inserting unrealized casts from the converted - /// values to the original types, which are then used in lieu of the original - /// ones. (Eventually, `applyPartialOneToNConversion` replaces these casts - /// with a user-provided argument materialization if necessary.) This is - /// similar to `ArgConverter::applySignatureConversion` but (1) handles 1:N - /// type conversion properly and probably (2) doesn't handle many other edge - /// cases. - Block *applySignatureConversion(Block *block, - OneToNTypeMapping &argumentConversion); -}; - -/// Base class for patterns with 1:N type conversions. Derived classes have to -/// overwrite the `matchAndRewrite` overlaod that provides additional -/// information for 1:N type conversions. -class OneToNConversionPattern : public RewritePatternWithConverter { -public: - using RewritePatternWithConverter::RewritePatternWithConverter; - - /// This function has to be implemented by base classes and is called from the - /// usual overloads. Like in "normal" `DialectConversion`, the function is - /// provided with the converted operands (which thus have target types). Since - /// 1:N conversion are supported, there is usually no 1:1 relationship between - /// the original and the converted operands. Instead, the provided - /// `operandMapping` can be used to access the converted operands that - /// correspond to a particular original operand. Similarly, `resultMapping` - /// is provided to help with assembling the result values, which may have 1:N - /// correspondences as well. In that case, the original op should be replaced - /// with the overload of `replaceOp` that takes the provided `resultMapping` - /// in order to deal with the mapping of converted result values to their - /// usages in the original types correctly. - virtual LogicalResult matchAndRewrite(Operation *op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const = 0; - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final; -}; - -/// This class is a wrapper around `OneToNConversionPattern` for matching -/// against instances of a particular op class. -template -class OneToNOpConversionPattern : public OneToNConversionPattern { -public: - OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1, - ArrayRef generatedNames = {}) - : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), - benefit, context, generatedNames) {} - - using OneToNConversionPattern::matchAndRewrite; - - /// Overload that derived classes have to override for their op type. - virtual LogicalResult matchAndRewrite(SourceOp op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const = 0; - - LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const final { - return matchAndRewrite(cast(op), rewriter, operandMapping, - resultMapping, convertedOperands); - } -}; - -/// Applies the given set of patterns recursively on the given op and adds user -/// materializations where necessary. The patterns are expected to be -/// `OneToNConversionPattern`, which help converting the types of the operands -/// and results of the matched ops. The provided type converter is used to -/// convert the operands of matched ops from their original types to operands -/// with different types. Unlike in `DialectConversion`, this supports 1:N type -/// conversions. Those conversions at the "boundary" of the pattern application, -/// where converted results are not consumed by replaced ops that expect the -/// converted operands or vice versa, the function inserts user materializations -/// from the type converter. Also unlike `DialectConversion`, there are no legal -/// or illegal types; the function simply applies the given patterns and does -/// not fail if some ops or types remain unconverted (i.e., the conversion is -/// only "partial"). -LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, - const FrozenRewritePatternSet &patterns); - -} // namespace mlir - -#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt index 1720199..9a5b38b 100644 --- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt @@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRFuncTransforms DuplicateFunctionElimination.cpp FuncBufferize.cpp FuncConversions.cpp - OneToNFuncConversions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Transforms diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp deleted file mode 100644 index 5e8125c..0000000 --- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp +++ /dev/null @@ -1,132 +0,0 @@ -//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The patterns in this file are heavily inspired (and copied from) -// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the -// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N -// type conversions. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/OneToNTypeConversion.h" - -using namespace mlir; -using namespace mlir::func; - -namespace { - -class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern { -public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; - - LogicalResult matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { - Location loc = op->getLoc(); - - // Nothing to do if the op doesn't have any non-identity conversions for its - // operands or results. - if (!operandMapping.hasNonIdentityConversion() && - !resultMapping.hasNonIdentityConversion()) - return failure(); - - // Create new CallOp. - auto newOp = rewriter.create(loc, resultMapping.getConvertedTypes(), - convertedOperands); - newOp->setAttrs(op->getAttrs()); - - rewriter.replaceOp(op, newOp->getResults(), resultMapping); - return success(); - } -}; - -class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern { -public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; - - LogicalResult - matchAndRewrite(FuncOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping & /*operandMapping*/, - const OneToNTypeMapping & /*resultMapping*/, - ValueRange /*convertedOperands*/) const override { - auto *typeConverter = getTypeConverter(); - - // Construct mapping for function arguments. - OneToNTypeMapping argumentMapping(op.getArgumentTypes()); - if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), - argumentMapping))) - return failure(); - - // Construct mapping for function results. - OneToNTypeMapping funcResultMapping(op.getResultTypes()); - if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), - funcResultMapping))) - return failure(); - - // Nothing to do if the op doesn't have any non-identity conversions for its - // operands or results. - if (!argumentMapping.hasNonIdentityConversion() && - !funcResultMapping.hasNonIdentityConversion()) - return failure(); - - // Update the function signature in-place. - auto newType = FunctionType::get(rewriter.getContext(), - argumentMapping.getConvertedTypes(), - funcResultMapping.getConvertedTypes()); - rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); - - // Update block signatures. - if (!op.isExternal()) { - Region *region = &op.getBody(); - Block *block = ®ion->front(); - rewriter.applySignatureConversion(block, argumentMapping); - } - - return success(); - } -}; - -class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern { -public: - using OneToNOpConversionPattern::OneToNOpConversionPattern; - - LogicalResult matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping & /*resultMapping*/, - ValueRange convertedOperands) const override { - // Nothing to do if there is no non-identity conversion. - if (!operandMapping.hasNonIdentityConversion()) - return failure(); - - // Convert operands. - rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); - - return success(); - } -}; - -} // namespace - -namespace mlir { - -void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add< - // clang-format off - ConvertTypesInFuncCallOp, - ConvertTypesInFuncFuncOp, - ConvertTypesInFuncReturnOp - // clang-format on - >(typeConverter, patterns.getContext()); -} - -} // namespace mlir diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt index 6892d00..ba8fa20 100644 --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -6,7 +6,6 @@ add_mlir_library(MLIRTransformUtils GreedyPatternRewriteDriver.cpp InliningUtils.cpp LoopInvariantCodeMotionUtils.cpp - OneToNTypeConversion.cpp RegionUtils.cpp TopologicalSortUtils.cpp diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp deleted file mode 100644 index c0866f8..0000000 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ /dev/null @@ -1,405 +0,0 @@ -//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===// -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Transforms/OneToNTypeConversion.h" - -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/SmallSet.h" - -using namespace llvm; -using namespace mlir; - -std::optional> -OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, - Location loc, - TypeRange resultTypes, - Value input) const { - for (const OneToNMaterializationCallbackFn &fn : - llvm::reverse(oneToNTargetMaterializations)) { - if (std::optional> result = - fn(builder, resultTypes, input, loc)) - return *result; - } - return std::nullopt; -} - -TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { - TypeRange convertedTypes = getConvertedTypes(); - if (auto mapping = getInputMapping(originalTypeNo)) - return convertedTypes.slice(mapping->inputNo, mapping->size); - return {}; -} - -ValueRange -OneToNTypeMapping::getConvertedValues(ValueRange convertedValues, - unsigned originalValueNo) const { - if (auto mapping = getInputMapping(originalValueNo)) - return convertedValues.slice(mapping->inputNo, mapping->size); - return {}; -} - -void OneToNTypeMapping::convertLocation( - Value originalValue, unsigned originalValueNo, - llvm::SmallVectorImpl &result) const { - if (auto mapping = getInputMapping(originalValueNo)) - result.append(mapping->size, originalValue.getLoc()); -} - -void OneToNTypeMapping::convertLocations( - ValueRange originalValues, llvm::SmallVectorImpl &result) const { - assert(originalValues.size() == getOriginalTypes().size()); - for (auto [i, value] : llvm::enumerate(originalValues)) - convertLocation(value, i, result); -} - -static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { - return convertedTypes.size() == 1 && convertedTypes[0] == originalType; -} - -bool OneToNTypeMapping::hasNonIdentityConversion() const { - // XXX: I think that the original types and the converted types are the same - // iff there was no non-identity type conversion. If that is true, the - // patterns could actually test whether there is anything useful to do - // without having access to the signature conversion. - for (auto [i, originalType] : llvm::enumerate(originalTypes)) { - TypeRange types = getConvertedTypes(i); - if (!isIdentityConversion(originalType, types)) { - assert(TypeRange(originalTypes) != getConvertedTypes()); - return true; - } - } - assert(TypeRange(originalTypes) == getConvertedTypes()); - return false; -} - -namespace { -enum class CastKind { - // Casts block arguments in the target type back to the source type. (If - // necessary, this cast becomes an argument materialization.) - Argument, - - // Casts other values in the target type back to the source type. (If - // necessary, this cast becomes a source materialization.) - Source, - - // Casts values in the source type to the target type. (If necessary, this - // cast becomes a target materialization.) - Target -}; -} - -/// Mapping of enum values to string values. -StringRef getCastKindName(CastKind kind) { - static const std::unordered_map castKindNames = { - {CastKind::Argument, "argument"}, - {CastKind::Source, "source"}, - {CastKind::Target, "target"}}; - return castKindNames.at(kind); -} - -/// Attribute name that is used to annotate inserted unrealized casts with their -/// kind (source, argument, or target). -static const char *const castKindAttrName = - "__one-to-n-type-conversion_cast-kind__"; - -/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given -/// result types. Returns the result values of the cast. -static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, - ValueRange inputs, CastKind kind) { - // Create cast. - Location loc = builder.getUnknownLoc(); - if (!inputs.empty()) - loc = inputs.front().getLoc(); - auto castOp = - builder.create(loc, resultTypes, inputs); - - // Store cast kind as attribute. - auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind)); - castOp->setAttr(castKindAttrName, kindAttr); - - return castOp->getResults(); -} - -/// Builds one `UnrealizedConversionCastOp` for each of the given original -/// values using the respective target types given in the provided conversion -/// mapping and returns the results of these casts. If the conversion mapping of -/// a value maps a type to itself (i.e., is an identity conversion), then no -/// cast is inserted and the original value is returned instead. -/// Note that these unrealized casts are different from target materializations -/// in that they are *always* inserted, even if they immediately fold away, such -/// that patterns always see valid intermediate IR, whereas materializations are -/// only used in the places where the unrealized casts *don't* fold away. -static SmallVector -buildUnrealizedForwardCasts(ValueRange originalValues, - OneToNTypeMapping &conversion, - RewriterBase &rewriter, CastKind kind) { - - // Convert each operand one by one. - SmallVector convertedValues; - convertedValues.reserve(conversion.getConvertedTypes().size()); - for (auto [idx, originalValue] : llvm::enumerate(originalValues)) { - TypeRange convertedTypes = conversion.getConvertedTypes(idx); - - // Identity conversion: keep operand as is. - if (isIdentityConversion(originalValue.getType(), convertedTypes)) { - convertedValues.push_back(originalValue); - continue; - } - - // Non-identity conversion: materialize target types. - ValueRange castResult = - buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind); - convertedValues.append(castResult.begin(), castResult.end()); - } - - return convertedValues; -} - -/// Builds one `UnrealizedConversionCastOp` for each sequence of the given -/// original values to one value of the type they originated from, i.e., a -/// "reverse" conversion from N converted values back to one value of the -/// original type, using the given (forward) type conversion. If a given value -/// was mapped to a value of the same type (i.e., the conversion in the mapping -/// is an identity conversion), then the "converted" value is returned without -/// cast. -/// Note that these unrealized casts are different from source materializations -/// in that they are *always* inserted, even if they immediately fold away, such -/// that patterns always see valid intermediate IR, whereas materializations are -/// only used in the places where the unrealized casts *don't* fold away. -static SmallVector -buildUnrealizedBackwardsCasts(ValueRange convertedValues, - const OneToNTypeMapping &typeConversion, - RewriterBase &rewriter) { - assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); - - // Create unrealized cast op for each converted result of the op. - SmallVector recastValues; - TypeRange originalTypes = typeConversion.getOriginalTypes(); - recastValues.reserve(originalTypes.size()); - auto convertedValueIt = convertedValues.begin(); - for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { - TypeRange convertedTypes = typeConversion.getConvertedTypes(idx); - size_t numConvertedValues = convertedTypes.size(); - if (isIdentityConversion(originalType, convertedTypes)) { - // Identity conversion: take result as is. - recastValues.push_back(*convertedValueIt); - } else { - // Non-identity conversion: cast back to source type. - ValueRange recastValue = buildUnrealizedCast( - rewriter, originalType, - ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}, - CastKind::Source); - assert(recastValue.size() == 1); - recastValues.push_back(recastValue.front()); - } - convertedValueIt += numConvertedValues; - } - - return recastValues; -} - -void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues, - const OneToNTypeMapping &resultMapping) { - // Create a cast back to the original types and replace the results of the - // original op with those. - assert(newValues.size() == resultMapping.getConvertedTypes().size()); - assert(op->getResultTypes() == resultMapping.getOriginalTypes()); - SmallVector castResults = - buildUnrealizedBackwardsCasts(newValues, resultMapping, *this); - replaceOp(op, castResults); -} - -Block *OneToNPatternRewriter::applySignatureConversion( - Block *block, OneToNTypeMapping &argumentConversion) { - // Split the block at the beginning to get a new block to use for the - // updated signature. - SmallVector locs; - argumentConversion.convertLocations(block->getArguments(), locs); - Block *newBlock = - createBlock(block, argumentConversion.getConvertedTypes(), locs); - replaceAllUsesWith(block, newBlock); - - // Create necessary casts in new block. - SmallVector castResults; - for (auto [i, arg] : llvm::enumerate(block->getArguments())) { - TypeRange convertedTypes = argumentConversion.getConvertedTypes(i); - ValueRange newArgs = - argumentConversion.getConvertedValues(newBlock->getArguments(), i); - if (isIdentityConversion(arg.getType(), convertedTypes)) { - // Identity conversion: take argument as is. - assert(newArgs.size() == 1); - castResults.push_back(newArgs.front()); - } else { - // Non-identity conversion: cast the converted arguments to the original - // type. - PatternRewriter::InsertionGuard g(*this); - setInsertionPointToStart(newBlock); - ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs, - CastKind::Argument); - assert(castResult.size() == 1); - castResults.push_back(castResult.front()); - } - } - - // Merge old block into new block such that we only have the latter with the - // new signature. - mergeBlocks(block, newBlock, castResults); - - return newBlock; -} - -LogicalResult -OneToNConversionPattern::matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { - auto *typeConverter = getTypeConverter(); - - // Construct conversion mapping for results. - Operation::result_type_range originalResultTypes = op->getResultTypes(); - OneToNTypeMapping resultMapping(originalResultTypes); - if (failed(typeConverter->computeTypeMapping(originalResultTypes, - resultMapping))) - return failure(); - - // Construct conversion mapping for operands. - Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); - OneToNTypeMapping operandMapping(originalOperandTypes); - if (failed(typeConverter->computeTypeMapping(originalOperandTypes, - operandMapping))) - return failure(); - - // Cast operands to target types. - SmallVector convertedOperands = buildUnrealizedForwardCasts( - op->getOperands(), operandMapping, rewriter, CastKind::Target); - - // Create a `OneToNPatternRewriter` for the pattern, which provides additional - // functionality. - // TODO(ingomueller): I guess it would be better to use only one rewriter - // throughout the whole pass, but that would require to - // drive the pattern application ourselves, which is a lot - // of additional boilerplate code. This seems to work fine, - // so I leave it like this for the time being. - OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext()); - oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint()); - oneToNPatternRewriter.setListener(rewriter.getListener()); - - // Apply actual pattern. - if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping, - resultMapping, convertedOperands))) - return failure(); - - return success(); -} - -namespace mlir { - -// This function applies the provided patterns using -// `applyPatternsAndFoldGreedily` and then replaces all newly inserted -// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts -// from target to source types inserted by a `OneToNConversionPattern` normally -// fold away with the "forward" casts from source to target types inserted by -// the next pattern.) To understand which casts are "newly inserted", all casts -// inserted by this pass are annotated with a string attribute that also -// documents which kind of the cast (source, argument, or target). -LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, - const FrozenRewritePatternSet &patterns) { -#ifndef NDEBUG - // Remember existing unrealized casts. This data structure is only used in - // asserts; building it only for that purpose may be an overkill. - SmallSet existingCasts; - op->walk([&](UnrealizedConversionCastOp castOp) { - assert(!castOp->hasAttr(castKindAttrName)); - existingCasts.insert(castOp); - }); -#endif // NDEBUG - - // Apply provided conversion patterns. - if (failed(applyPatternsAndFoldGreedily(op, patterns))) { - emitError(op->getLoc()) << "failed to apply conversion patterns"; - return failure(); - } - - // Find all unrealized casts inserted by the pass that haven't folded away. - SmallVector worklist; - op->walk([&](UnrealizedConversionCastOp castOp) { - if (castOp->hasAttr(castKindAttrName)) { - assert(!existingCasts.contains(castOp)); - worklist.push_back(castOp); - } - }); - - // Replace new casts with user materializations. - IRRewriter rewriter(op->getContext()); - for (UnrealizedConversionCastOp castOp : worklist) { - TypeRange resultTypes = castOp->getResultTypes(); - ValueRange operands = castOp->getOperands(); - StringRef castKind = - castOp->getAttrOfType(castKindAttrName).getValue(); - rewriter.setInsertionPoint(castOp); - -#ifndef NDEBUG - // Determine whether operands or results are already legal to test some - // assumptions for the different kind of materializations. These properties - // are only used it asserts and it may be overkill to compute them. - bool areOperandTypesLegal = llvm::all_of( - operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); - bool areResultsTypesLegal = llvm::all_of( - resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); -#endif // NDEBUG - - // Add materialization and remember materialized results. - SmallVector materializedResults; - if (castKind == getCastKindName(CastKind::Target)) { - // Target materialization. - assert(!areOperandTypesLegal && areResultsTypesLegal && - operands.size() == 1 && "found unexpected target cast"); - std::optional> maybeResults = - typeConverter.materializeTargetConversion( - rewriter, castOp->getLoc(), resultTypes, operands.front()); - if (!maybeResults) { - emitError(castOp->getLoc()) - << "failed to create target materialization"; - return failure(); - } - materializedResults = maybeResults.value(); - } else { - // Source and argument materializations. - assert(areOperandTypesLegal && !areResultsTypesLegal && - resultTypes.size() == 1 && "found unexpected cast"); - std::optional maybeResult; - if (castKind == getCastKindName(CastKind::Source)) { - // Source materialization. - maybeResult = typeConverter.materializeSourceConversion( - rewriter, castOp->getLoc(), resultTypes.front(), - castOp.getOperands()); - } else { - // Argument materialization. - assert(castKind == getCastKindName(CastKind::Argument) && - "unexpected value of cast kind attribute"); - assert(llvm::all_of(operands, - [&](Value v) { return v.isa(); })); - maybeResult = typeConverter.materializeArgumentConversion( - rewriter, castOp->getLoc(), resultTypes.front(), - castOp.getOperands()); - } - if (!maybeResult.has_value() || !maybeResult.value()) { - emitError(castOp->getLoc()) - << "failed to create " << castKind << " materialization"; - return failure(); - } - materializedResults = {maybeResult.value()}; - } - - // Replace the cast with the result of the materialization. - rewriter.replaceOp(castOp, materializedResults); - } - - return success(); -} - -} // namespace mlir diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir index 51b63ba..604e948 100644 --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -1,9 +1,5 @@ // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s -// RUN: mlir-opt %s -split-input-file \ -// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ -// RUN: | FileCheck %s --check-prefix=CHECK-12N - // Test case: Most basic case of a 1:N decomposition, an identity function. // CHECK-LABEL: func @identity( @@ -13,10 +9,6 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 -// CHECK-12N-LABEL: func @identity( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32 func.func @identity(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -28,9 +20,6 @@ func.func @identity(%arg0: tuple) -> tuple { // CHECK-LABEL: func @identity_1_to_1_no_materializations( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 -// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { -// CHECK-12N: return %[[ARG0]] : i1 func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -42,9 +31,6 @@ func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { // CHECK-LABEL: func @recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 -// CHECK-12N-LABEL: func @recursive_decomposition( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { -// CHECK-12N: return %[[ARG0]] : i1 func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>> { return %arg0 : tuple>> } @@ -68,10 +54,6 @@ func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>) -> tuple // CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple) -> i2 // CHECK: return %[[V7]], %[[V10]] : i1, i2 -// CHECK-12N-LABEL: func @mixed_recursive_decomposition( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { -// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2 func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple>>) -> tuple, tuple, tuple>> { return %arg0 : tuple, tuple, tuple>> } @@ -81,7 +63,6 @@ func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple< // Test case: Check decomposition of calls. // CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) -// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32) func.func private @callee(tuple) -> tuple // CHECK-LABEL: func @caller( @@ -95,11 +76,6 @@ func.func private @callee(tuple) -> tuple // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 -// CHECK-12N-LABEL: func @caller( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) -// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32 func.func @caller(%arg0: tuple) -> tuple { %0 = call @callee(%arg0) : (tuple) -> tuple return %0 : tuple @@ -110,15 +86,10 @@ func.func @caller(%arg0: tuple) -> tuple { // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). // CHECK-LABEL: func private @callee() -// CHECK-12N-LABEL: func private @callee() func.func private @callee(tuple<>) -> tuple<> - // CHECK-LABEL: func @caller() { // CHECK: call @callee() : () -> () // CHECK: return -// CHECK-12N-LABEL: func @caller() { -// CHECK-12N: call @callee() : () -> () -// CHECK-12N: return func.func @caller(%arg0: tuple<>) -> tuple<> { %0 = call @callee(%arg0) : (tuple<>) -> (tuple<>) return %0 : tuple<> @@ -134,11 +105,6 @@ func.func @caller(%arg0: tuple<>) -> tuple<> { // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 -// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) { -// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple -// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 -// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 -// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32 func.func @unconverted_op_result() -> tuple { %0 = "test.source"() : () -> (tuple) return %0 : tuple @@ -159,16 +125,6 @@ func.func @unconverted_op_result() -> tuple { // CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple // CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 // CHECK: return %[[V3]], %[[V5]] : i1, i32 -// CHECK-12N-LABEL: func @nested_unconverted_op_result( -// CHECK-12N-SAME: %[[ARG0:.*]]: i1, -// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { -// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple -// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple) -> tuple> -// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple>) -> tuple> -// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple>) -> i1 -// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple -// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 -// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32 func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple> { %0 = "test.op"(%arg) : (tuple>) -> (tuple>) return %0 : tuple> @@ -180,7 +136,6 @@ func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple (i1, i2, i3, i4, i5, i6) -// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) // CHECK-LABEL: func @caller( @@ -198,15 +153,6 @@ func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tup // CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple) -> i4 // CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple) -> i5 // CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 -// CHECK-12N-LABEL: func @caller( -// CHECK-12N-SAME: %[[I1:.*]]: i1, -// CHECK-12N-SAME: %[[I2:.*]]: i2, -// CHECK-12N-SAME: %[[I3:.*]]: i3, -// CHECK-12N-SAME: %[[I4:.*]]: i4, -// CHECK-12N-SAME: %[[I5:.*]]: i5, -// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { -// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) -// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple, %arg3: i3, %arg4: tuple, %arg5: i6) -> (tuple<>, i1, tuple, i3, tuple, i6) { %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple, i3, tuple, i6 diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt index 14df652..14f0e0d 100644 --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1,3 +1,2 @@ add_subdirectory(FuncToLLVM) -add_subdirectory(OneToNTypeConversion) add_subdirectory(VectorToSPIRV) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt deleted file mode 100644 index 4189786..0000000 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_mlir_library(MLIRTestOneToNTypeConversionPass - TestOneToNTypeConversionPass.cpp - - EXCLUDE_FROM_LIBMLIR - - LINK_LIBS PUBLIC - MLIRFuncDialect - MLIRFuncTransforms - MLIRIR - MLIRTestDialect - MLIRTransformUtils - ) - -target_include_directories(MLIRTestOneToNTypeConversionPass - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test - ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test - ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp deleted file mode 100644 index 220bcb5..0000000 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ /dev/null @@ -1,245 +0,0 @@ -//===- TestOneToNTypeConversionPass.cpp - Test pass 1:N type conv. utils --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TestDialect.h" -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/OneToNTypeConversion.h" - -using namespace mlir; - -namespace { -/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms -/// in `applyPartialOneToNConversion` by converting built-in tuples to the -/// elements they consist of as well as some dummy ops operating on these -/// tuples. -struct TestOneToNTypeConversionPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass) - - TestOneToNTypeConversionPass() = default; - TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass) - : PassWrapper(pass) {} - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - StringRef getArgument() const final { - return "test-one-to-n-type-conversion"; - } - - StringRef getDescription() const final { - return "Test pass for 1:N type conversion"; - } - - Option convertFuncOps{*this, "convert-func-ops", - llvm::cl::desc("Enable conversion on func ops"), - llvm::cl::init(false)}; - - Option convertTupleOps{*this, "convert-tuple-ops", - llvm::cl::desc("Enable conversion on tuple ops"), - llvm::cl::init(false)}; - - void runOnOperation() override; -}; - -} // namespace - -namespace mlir { -namespace test { -void registerTestOneToNTypeConversionPass() { - PassRegistration(); -} -} // namespace test -} // namespace mlir - -namespace { - -/// Test pattern on for the `make_tuple` op from the test dialect that converts -/// this kind of op into it's "decomposed" form, i.e., the elements of the tuple -/// that is being produced by `test.make_tuple`, which are really just the -/// operands of this op. -class ConvertMakeTupleOp - : public OneToNOpConversionPattern<::test::MakeTupleOp> { -public: - using OneToNOpConversionPattern< - ::test::MakeTupleOp>::OneToNOpConversionPattern; - - LogicalResult matchAndRewrite(::test::MakeTupleOp op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { - // Simply replace the current op with the converted operands. - rewriter.replaceOp(op, convertedOperands, resultMapping); - return success(); - } -}; - -/// Test pattern on for the `get_tuple_element` op from the test dialect that -/// converts this kind of op into it's "decomposed" form, i.e., instead of -/// "physically" extracting one element from the tuple, we forward the one -/// element of the decomposed form that is being extracted (or the several -/// elements in case that element is a nested tuple). -class ConvertGetTupleElementOp - : public OneToNOpConversionPattern<::test::GetTupleElementOp> { -public: - using OneToNOpConversionPattern< - ::test::GetTupleElementOp>::OneToNOpConversionPattern; - - LogicalResult matchAndRewrite(::test::GetTupleElementOp op, - OneToNPatternRewriter &rewriter, - const OneToNTypeMapping &operandMapping, - const OneToNTypeMapping &resultMapping, - ValueRange convertedOperands) const override { - // Construct mapping for tuple element types. - auto stateType = op->getOperand(0).getType().cast(); - TypeRange originalElementTypes = stateType.getTypes(); - OneToNTypeMapping elementMapping(originalElementTypes); - if (failed(typeConverter->convertSignatureArgs(originalElementTypes, - elementMapping))) - return failure(); - - // Compute converted operands corresponding to original input tuple. - ValueRange convertedTuple = - operandMapping.getConvertedValues(convertedOperands, 0); - - // Got those converted operands that correspond to the index-th element of - // the original input tuple. - size_t index = op.getIndex(); - ValueRange extractedElement = - elementMapping.getConvertedValues(convertedTuple, index); - - rewriter.replaceOp(op, extractedElement, resultMapping); - - return success(); - } -}; - -} // namespace - -static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add< - // clang-format off - ConvertMakeTupleOp, - ConvertGetTupleElementOp - // clang-format on - >(typeConverter, patterns.getContext()); -} - -/// Creates a sequence of `test.get_tuple_element` ops for all elements of a -/// given tuple value. If some tuple elements are, in turn, tuples, the elements -/// of those are extracted recursively such that the returned values have the -/// same types as `resultTypes.getFlattenedTypes()`. -/// -/// This function has been copied (with small adaptions) from -/// TestDecomposeCallGraphTypes.cpp. -static std::optional> -buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input, - Location loc) { - TupleType inputType = input.getType().dyn_cast(); - if (!inputType) - return {}; - - SmallVector values; - for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) { - Value element = builder.create<::test::GetTupleElementOp>( - loc, elementType, input, builder.getI32IntegerAttr(idx)); - if (auto nestedTupleType = elementType.dyn_cast()) { - // Recurse if the current element is also a tuple. - SmallVector flatRecursiveTypes; - nestedTupleType.getFlattenedTypes(flatRecursiveTypes); - std::optional> resursiveValues = - buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc); - if (!resursiveValues.has_value()) - return {}; - values.append(resursiveValues.value()); - } else { - values.push_back(element); - } - } - return values; -} - -/// Creates a `test.make_tuple` op out of the given inputs building a tuple of -/// type `resultType`. If that type is nested, each nested tuple is built -/// recursively with another `test.make_tuple` op. -/// -/// This function has been copied (with small adaptions) from -/// TestDecomposeCallGraphTypes.cpp. -static std::optional buildMakeTupleOp(OpBuilder &builder, - TupleType resultType, - ValueRange inputs, Location loc) { - // Build one value for each element at this nesting level. - SmallVector elements; - elements.reserve(resultType.getTypes().size()); - ValueRange::iterator inputIt = inputs.begin(); - for (Type elementType : resultType.getTypes()) { - if (auto nestedTupleType = elementType.dyn_cast()) { - // Determine how many input values are needed for the nested elements of - // the nested TupleType and advance inputIt by that number. - // TODO: We only need the *number* of nested types, not the types itself. - // Maybe it's worth adding a more efficient overload? - SmallVector nestedFlattenedTypes; - nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); - size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); - ValueRange nestedFlattenedelements(inputIt, - inputIt + numNestedFlattenedTypes); - inputIt += numNestedFlattenedTypes; - - // Recurse on the values for the nested TupleType. - std::optional res = buildMakeTupleOp(builder, nestedTupleType, - nestedFlattenedelements, loc); - if (!res.has_value()) - return {}; - - // The tuple constructed by the conversion is the element value. - elements.push_back(res.value()); - } else { - // Base case: take one input as is. - elements.push_back(*inputIt++); - } - } - - // Assemble the tuple from the elements. - return builder.create<::test::MakeTupleOp>(loc, resultType, elements); -} - -void TestOneToNTypeConversionPass::runOnOperation() { - ModuleOp module = getOperation(); - auto *context = &getContext(); - - // Assemble type converter. - OneToNTypeConverter typeConverter; - - typeConverter.addConversion([](Type type) { return type; }); - typeConverter.addConversion( - [](TupleType tupleType, SmallVectorImpl &types) { - tupleType.getFlattenedTypes(types); - return success(); - }); - - typeConverter.addArgumentMaterialization(buildMakeTupleOp); - typeConverter.addSourceMaterialization(buildMakeTupleOp); - typeConverter.addTargetMaterialization(buildGetTupleElementOps); - - // Assemble patterns. - RewritePatternSet patterns(context); - if (convertTupleOps) - populateDecomposeTuplesTestPatterns(typeConverter, patterns); - if (convertFuncOps) - populateFuncTypeConversionPatterns(typeConverter, patterns); - - // Run conversion. - if (failed(applyPartialOneToNConversion(module, typeConverter, - std::move(patterns)))) - return signalPassFailure(); -} diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index c430569..f84fbe6 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -33,7 +33,6 @@ if(MLIR_INCLUDE_TESTS) MLIRTestDialect MLIRTestDynDialect MLIRTestIR - MLIRTestOneToNTypeConversionPass MLIRTestPass MLIRTestPDLL MLIRTestReducer diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 855c6a6..e7ca06b 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -107,7 +107,6 @@ void registerTestMathAlgebraicSimplificationPass(); void registerTestMathPolynomialApproximationPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); -void registerTestOneToNTypeConversionPass(); void registerTestOpaqueLoc(); void registerTestPadFusion(); void registerTestPDLByteCodePass(); @@ -219,7 +218,6 @@ void registerTestPasses() { mlir::test::registerTestMathPolynomialApproximationPass(); mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); - mlir::test::registerTestOneToNTypeConversionPass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestPadFusion(); mlir::test::registerTestPDLByteCodePass(); -- 2.7.4