namespace arith {
#define GEN_PASS_DECL_ARITHMETICBUFFERIZE
-#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT
#define GEN_PASS_DECL_ARITHMETICEXPANDOPS
#define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
-class WideIntEmulationConverter;
-
/// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass();
/// Create a pass to bufferize arith.constant ops.
std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
-/// Adds patterns to emulate wide Arithmetic and Function ops over integer
-/// types into supported ones. This is done by splitting original power-of-two
-/// i2N integer types into two iN halves.
-void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
- RewritePatternSet &patterns);
-
/// Add patterns to expand Arithmetic ops for LLVM lowering.
void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
}
-def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> {
- let summary = "Emulate 2*N-bit integer operations using N-bit operations";
- let description = [{
- Emulate integer operations that use too wide integer types with equivalent
- operations on supported narrow integer types. This is done by splitting
- original integer values into two halves.
-
- This pass is intended preserve semantics but not necessarily provide the
- most efficient implementation.
- TODO: Optimize op emulation.
-
- Currently, only power-of-two integer bitwidths are supported.
- }];
- let options = [
- Option<"widestIntSupported", "widest-int-supported", "unsigned",
- /*default=*/"32", "Widest integer type supported by the target">,
- ];
- let dependentDialects = ["vector::VectorDialect"];
-}
-
#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
+++ /dev/null
-//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
-#define MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
-
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir::arith {
-/// Converts integer types that are too wide for the target by splitting them in
-/// two halves and thus turning into supported ones, i.e., i2*N --> iN, where N
-/// is the widest integer bitwidth supported by the target.
-/// Currently, we only handle power-of-two integer types and support conversions
-/// of integers twice as wide as the maxium supported by the target. Wide
-/// integers are represented as vectors, e.g., i64 --> vector<2xi32>, where the
-/// first element is the low half of the original integer, and the second
-/// element the high half.
-class WideIntEmulationConverter : public TypeConverter {
-public:
- explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget);
-
- unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; }
-
-private:
- unsigned maxIntWidth;
-};
-} // namespace mlir::arith
-
-#endif // MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
- EmulateWideInt.cpp
ExpandOps.cpp
UnsignedWhenEquivalent.cpp
MLIRAnalysis
MLIRArithmeticDialect
MLIRBufferizationDialect
- MLIRFuncDialect
MLIRBufferizationTransforms
- MLIRVectorDialect
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRTransforms
- MLIRTransformUtils
)
+++ /dev/null
-//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
-
-#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/MathExtras.h"
-#include <cassert>
-
-namespace mlir::arith {
-#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT
-#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
-} // namespace mlir::arith
-
-using namespace mlir;
-
-namespace {
-struct EmulateWideIntPass final
- : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
- using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
-
- void runOnOperation() override {
- if (!llvm::isPowerOf2_32(widestIntSupported)) {
- signalPassFailure();
- return;
- }
-
- Operation *op = getOperation();
- MLIRContext *ctx = op->getContext();
-
- arith::WideIntEmulationConverter typeConverter(widestIntSupported);
- ConversionTarget target(*ctx);
- target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
- return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
- });
- target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
- [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
-
- RewritePatternSet patterns(ctx);
- arith::populateWideIntEmulationPatterns(typeConverter, patterns);
-
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
- signalPassFailure();
- }
-};
-} // end anonymous namespace
-
-arith::WideIntEmulationConverter::WideIntEmulationConverter(
- unsigned widestIntSupportedByTarget)
- : maxIntWidth(widestIntSupportedByTarget) {
- assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
- "Only power-of-two integers are supported");
-
- // Scalar case.
- addConversion([this](IntegerType ty) -> Optional<Type> {
- unsigned width = ty.getWidth();
- if (width <= maxIntWidth)
- return ty;
-
- // i2N --> vector<2xiN>
- if (width == 2 * maxIntWidth)
- return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
-
- return None;
- });
-
- // Vector case.
- addConversion([this](VectorType ty) -> Optional<Type> {
- auto intTy = ty.getElementType().dyn_cast<IntegerType>();
- if (!intTy)
- return ty;
-
- unsigned width = intTy.getWidth();
- if (width <= maxIntWidth)
- return ty;
-
- // vector<...xi2N> --> vector<...x2xiN>
- if (width == 2 * maxIntWidth) {
- auto newShape = to_vector(ty.getShape());
- newShape.push_back(2);
- return VectorType::get(newShape,
- IntegerType::get(ty.getContext(), maxIntWidth));
- }
-
- return None;
- });
-
- // Function case.
- addConversion([this](FunctionType ty) -> Optional<Type> {
- // Convert inputs and results, e.g.:
- // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
- SmallVector<Type> inputs;
- if (failed(convertTypes(ty.getInputs(), inputs)))
- return None;
-
- SmallVector<Type> results;
- if (failed(convertTypes(ty.getResults(), results)))
- return None;
-
- return FunctionType::get(ty.getContext(), inputs, results);
- });
-}
-
-void arith::populateWideIntEmulationPatterns(
- WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
- // Populate `func.*` conversion patterns.
- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
- typeConverter);
- populateCallOpTypeConversionPattern(patterns, typeConverter);
- populateReturnOpTypeConversionPattern(patterns, typeConverter);
-}
+++ /dev/null
-// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
-
-// Expect no conversions, i32 is supported.
-// CHECK-LABEL: func @addi_same_i32
-// CHECK-SAME: ([[ARG:%.+]]: i32) -> i32
-// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32
-// CHECK-NEXT: return [[X]] : i32
-func.func @addi_same_i32(%a : i32) -> i32 {
- %x = arith.addi %a, %a : i32
- return %x : i32
-}
-
-// Expect no conversions, i32 is supported.
-// CHECK-LABEL: func @addi_same_vector_i32
-// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
-// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32>
-// CHECK-NEXT: return [[X]] : vector<2xi32>
-func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> {
- %x = arith.addi %a, %a : vector<2xi32>
- return %x : vector<2xi32>
-}
-
-// CHECK-LABEL: func @identity_scalar
-// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
-// CHECK-NEXT: return [[ARG]] : vector<2xi32>
-func.func @identity_scalar(%x : i64) -> i64 {
- return %x : i64
-}
-
-// CHECK-LABEL: func @identity_vector
-// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
-// CHECK-NEXT: return [[ARG]] : vector<4x2xi32>
-func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> {
- return %x : vector<4xi64>
-}
-
-// CHECK-LABEL: func @identity_vector2d
-// CHECK-SAME: ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32>
-// CHECK-NEXT: return [[ARG]] : vector<3x4x2xi32>
-func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> {
- return %x : vector<3x4xi64>
-}
-
-// CHECK-LABEL: func @call
-// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
-// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32>
-// CHECK-NEXT: return [[RES]] : vector<4x2xi32>
-func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
- %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
- return %res : vector<4xi64>
-}