namespace arith {
#define GEN_PASS_DECL_ARITHMETICBUFFERIZE
+#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT
#define GEN_PASS_DECL_ARITHMETICEXPANDOPS
#define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+class WideIntEmulationConverter;
+
/// Create a pass to bufferize Arithmetic ops.
std::unique_ptr<Pass> createArithmeticBufferizePass();
/// Create a pass to bufferize arith.constant ops.
std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
+/// Adds patterns to emulate wide Arithmetic and Function ops over integer
+/// types into supported ones. This is done by splitting original power-of-two
+/// i2N integer types into two iN halves.
+void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
/// Add patterns to expand Arithmetic ops for LLVM lowering.
void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
}
+def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> {
+ let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+ let description = [{
+ Emulate integer operations that use too wide integer types with equivalent
+ operations on supported narrow integer types. This is done by splitting
+ original integer values into two halves.
+
+ This pass is intended preserve semantics but not necessarily provide the
+ most efficient implementation.
+ TODO: Optimize op emulation.
+
+ Currently, only power-of-two integer bitwidths are supported.
+ }];
+ let options = [
+ Option<"widestIntSupported", "widest-int-supported", "unsigned",
+ /*default=*/"32", "Widest integer type supported by the target">,
+ ];
+ let dependentDialects = ["vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
--- /dev/null
+//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts integer types that are too wide for the target by splitting them in
+/// two halves and thus turning into supported ones, i.e., i2*N --> iN, where N
+/// is the widest integer bitwidth supported by the target.
+/// Currently, we only handle power-of-two integer types and support conversions
+/// of integers twice as wide as the maxium supported by the target. Wide
+/// integers are represented as vectors, e.g., i64 --> vector<2xi32>, where the
+/// first element is the low half of the original integer, and the second
+/// element the high half.
+class WideIntEmulationConverter : public TypeConverter {
+public:
+ explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget);
+
+ unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; }
+
+private:
+ unsigned maxIntWidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
add_mlir_dialect_library(MLIRArithmeticTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ EmulateWideInt.cpp
ExpandOps.cpp
UnsignedWhenEquivalent.cpp
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
+ MLIRFuncDialect
+ MLIRFuncTransforms
MLIRInferIntRangeInterface
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRTransforms
+ MLIRTransformUtils
+ MLIRVectorDialect
)
--- /dev/null
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+
+namespace {
+struct EmulateWideIntPass final
+ : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
+ using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
+
+ void runOnOperation() override {
+ if (!llvm::isPowerOf2_32(widestIntSupported)) {
+ signalPassFailure();
+ return;
+ }
+
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+
+ arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+ return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+ });
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+ [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+ RewritePatternSet patterns(ctx);
+ arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // end anonymous namespace
+
+arith::WideIntEmulationConverter::WideIntEmulationConverter(
+ unsigned widestIntSupportedByTarget)
+ : maxIntWidth(widestIntSupportedByTarget) {
+ assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
+ "Only power-of-two integers are supported");
+
+ // Scalar case.
+ addConversion([this](IntegerType ty) -> Optional<Type> {
+ unsigned width = ty.getWidth();
+ if (width <= maxIntWidth)
+ return ty;
+
+ // i2N --> vector<2xiN>
+ if (width == 2 * maxIntWidth)
+ return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
+
+ return None;
+ });
+
+ // Vector case.
+ addConversion([this](VectorType ty) -> Optional<Type> {
+ auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ if (!intTy)
+ return ty;
+
+ unsigned width = intTy.getWidth();
+ if (width <= maxIntWidth)
+ return ty;
+
+ // vector<...xi2N> --> vector<...x2xiN>
+ if (width == 2 * maxIntWidth) {
+ auto newShape = to_vector(ty.getShape());
+ newShape.push_back(2);
+ return VectorType::get(newShape,
+ IntegerType::get(ty.getContext(), maxIntWidth));
+ }
+
+ return None;
+ });
+
+ // Function case.
+ addConversion([this](FunctionType ty) -> Optional<Type> {
+ // Convert inputs and results, e.g.:
+ // (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
+ SmallVector<Type> inputs;
+ if (failed(convertTypes(ty.getInputs(), inputs)))
+ return None;
+
+ SmallVector<Type> results;
+ if (failed(convertTypes(ty.getResults(), results)))
+ return None;
+
+ return FunctionType::get(ty.getContext(), inputs, results);
+ });
+}
+
+void arith::populateWideIntEmulationPatterns(
+ WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+ // Populate `func.*` conversion patterns.
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
--- /dev/null
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_i32
+// CHECK-SAME: ([[ARG:%.+]]: i32) -> i32
+// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32
+// CHECK-NEXT: return [[X]] : i32
+func.func @addi_same_i32(%a : i32) -> i32 {
+ %x = arith.addi %a, %a : i32
+ return %x : i32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_vector_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32>
+// CHECK-NEXT: return [[X]] : vector<2xi32>
+func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> {
+ %x = arith.addi %a, %a : vector<2xi32>
+ return %x : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
+func.func @identity_scalar(%x : i64) -> i64 {
+ return %x : i64
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<4x2xi32>
+func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> {
+ return %x : vector<4xi64>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<3x4x2xi32>
+func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> {
+ return %x : vector<3x4xi64>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT: return [[RES]] : vector<4x2xi32>
+func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
+ %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
+ return %res : vector<4xi64>
+}