Reland "[mlir][arith] Add wide integer emulation pass"
author Jakub Kuderski <kubak@google.com>
Fri, 9 Sep 2022 03:23:44 +0000 (23:23 -0400)
committer Jakub Kuderski <kubak@google.com>
Fri, 9 Sep 2022 03:30:47 +0000 (23:30 -0400)
This reverts commit 45b5e8abe56d7f28c88b0c6cdd60ff741874fb1d.

Relands https://reviews.llvm.org/D133135 after fixing shared libs
builds.

mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h
mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td
mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h [new file with mode: 0644]
mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt
mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp [new file with mode: 0644]
mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir [new file with mode: 0644]

index 922d653..5ee3fb0 100644 (file)
@@ -15,16 +15,25 @@ namespace mlir {
 namespace arith {
 
 #define GEN_PASS_DECL_ARITHMETICBUFFERIZE
+#define GEN_PASS_DECL_ARITHMETICEMULATEWIDEINT
 #define GEN_PASS_DECL_ARITHMETICEXPANDOPS
 #define GEN_PASS_DECL_ARITHMETICUNSIGNEDWHENEQUIVALENT
 #include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
 
+class WideIntEmulationConverter;
+
 /// Create a pass to bufferize Arithmetic ops.
 std::unique_ptr<Pass> createArithmeticBufferizePass();
 
 /// Create a pass to bufferize arith.constant ops.
 std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
 
+/// Adds patterns to emulate wide Arithmetic and Function ops over integer
+/// types into supported ones. This is done by splitting original power-of-two
+/// i2N integer types into two iN halves.
+void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
+                                      RewritePatternSet &patterns);
+
 /// Add patterns to expand Arithmetic ops for LLVM lowering.
 void populateArithmeticExpandOpsPatterns(RewritePatternSet &patterns);
 
index 752d715..3895562 100644 (file)
@@ -49,4 +49,24 @@ def ArithmeticUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
   let constructor = "mlir::arith::createArithmeticUnsignedWhenEquivalentPass()";
 }
 
+def ArithmeticEmulateWideInt : Pass<"arith-emulate-wide-int"> {
+  let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+  let description = [{
+    Emulate integer operations that use too wide integer types with equivalent
+    operations on supported narrow integer types. This is done by splitting
+    original integer values into two halves.
+
+    This pass is intended to preserve semantics but not necessarily provide the
+    most efficient implementation.
+    TODO: Optimize op emulation.
+
+    Currently, only power-of-two integer bitwidths are supported.
+  }];
+  let options = [
+    Option<"widestIntSupported", "widest-int-supported", "unsigned",
+           /*default=*/"32", "Widest integer type supported by the target">,
+  ];
+  let dependentDialects = ["vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_ARITHMETIC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h
new file mode 100644 (file)
index 0000000..814db23
--- /dev/null
@@ -0,0 +1,34 @@
+//===- WideIntEmulationConverter.h - Type Converter for WIE -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts integer types that are too wide for the target by splitting them in
+/// two halves and thus turning into supported ones, i.e., i2*N --> iN, where N
+/// is the widest integer bitwidth supported by the target.
+/// Currently, we only handle power-of-two integer types and support conversions
+/// of integers twice as wide as the maximum supported by the target. Wide
+/// integers are represented as vectors, e.g., i64 --> vector<2xi32>, where the
+/// first element is the low half of the original integer, and the second
+/// element the high half.
+class WideIntEmulationConverter : public TypeConverter {
+public:
+  explicit WideIntEmulationConverter(unsigned widestIntSupportedByTarget);
+
+  unsigned getMaxTargetIntBitWidth() const { return maxIntWidth; }
+
+private:
+  unsigned maxIntWidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITHMETIC_WIDE_INT_EMULATION_CONVERTER_H_
index f140715..ba68d36 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArithmeticTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  EmulateWideInt.cpp
   ExpandOps.cpp
   UnsignedWhenEquivalent.cpp
 
@@ -15,9 +16,13 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
   MLIRArithmeticDialect
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
+  MLIRFuncDialect
+  MLIRFuncTransforms
   MLIRInferIntRangeInterface
   MLIRIR
   MLIRMemRefDialect
   MLIRPass
   MLIRTransforms
+  MLIRTransformUtils
+  MLIRVectorDialect
   )
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/EmulateWideInt.cpp
new file mode 100644 (file)
index 0000000..94e321b
--- /dev/null
@@ -0,0 +1,120 @@
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arithmetic/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::arith {
+#define GEN_PASS_DEF_ARITHMETICEMULATEWIDEINT
+#include "mlir/Dialect/Arithmetic/Transforms/Passes.h.inc"
+} // namespace mlir::arith
+
+using namespace mlir;
+
+namespace {
+struct EmulateWideIntPass final
+    : arith::impl::ArithmeticEmulateWideIntBase<EmulateWideIntPass> {
+  using ArithmeticEmulateWideIntBase::ArithmeticEmulateWideIntBase;
+
+  void runOnOperation() override {
+    if (!llvm::isPowerOf2_32(widestIntSupported)) {
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+arith::WideIntEmulationConverter::WideIntEmulationConverter(
+    unsigned widestIntSupportedByTarget)
+    : maxIntWidth(widestIntSupportedByTarget) {
+  assert(llvm::isPowerOf2_32(widestIntSupportedByTarget) &&
+         "Only power-of-two integers are supported");
+
+  // Scalar case.
+  addConversion([this](IntegerType ty) -> Optional<Type> {
+    unsigned width = ty.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+
+    // i2N --> vector<2xiN>
+    if (width == 2 * maxIntWidth)
+      return VectorType::get(2, IntegerType::get(ty.getContext(), maxIntWidth));
+
+    return None;
+  });
+
+  // Vector case.
+  addConversion([this](VectorType ty) -> Optional<Type> {
+    auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+    if (!intTy)
+      return ty;
+
+    unsigned width = intTy.getWidth();
+    if (width <= maxIntWidth)
+      return ty;
+
+    // vector<...xi2N> --> vector<...x2xiN>
+    if (width == 2 * maxIntWidth) {
+      auto newShape = to_vector(ty.getShape());
+      newShape.push_back(2);
+      return VectorType::get(newShape,
+                             IntegerType::get(ty.getContext(), maxIntWidth));
+    }
+
+    return None;
+  });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> Optional<Type> {
+    // Convert inputs and results, e.g.:
+    //   (i2N, i2N) -> i2N --> (vector<2xiN>, vector<2xiN>) -> vector<2xiN>
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return None;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return None;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+void arith::populateWideIntEmulationPatterns(
+    WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
diff --git a/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir b/mlir/test/Dialect/Arithmetic/emulate-wide-int.mlir
new file mode 100644 (file)
index 0000000..aafeb5b
--- /dev/null
@@ -0,0 +1,51 @@
+// RUN: mlir-opt --arith-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_i32
+// CHECK-SAME:    ([[ARG:%.+]]: i32) -> i32
+// CHECK-NEXT:    [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : i32
+// CHECK-NEXT:    return [[X]] : i32
+func.func @addi_same_i32(%a : i32) -> i32 {
+    %x = arith.addi %a, %a : i32
+    return %x : i32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @addi_same_vector_i32
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    [[X:%.+]] = arith.addi [[ARG]], [[ARG]] : vector<2xi32>
+// CHECK-NEXT:    return [[X]] : vector<2xi32>
+func.func @addi_same_vector_i32(%a : vector<2xi32>) -> vector<2xi32> {
+    %x = arith.addi %a, %a : vector<2xi32>
+    return %x : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME:     ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:     return [[ARG]] : vector<2xi32>
+func.func @identity_scalar(%x : i64) -> i64 {
+    return %x : i64
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:     return [[ARG]] : vector<4x2xi32>
+func.func @identity_vector(%x : vector<4xi64>) -> vector<4xi64> {
+    return %x : vector<4xi64>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME:     ([[ARG:%.+]]: vector<3x4x2xi32>) -> vector<3x4x2xi32>
+// CHECK-NEXT:     return [[ARG]] : vector<3x4x2xi32>
+func.func @identity_vector2d(%x : vector<3x4xi64>) -> vector<3x4xi64> {
+    return %x : vector<3x4xi64>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:     [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4x2xi32>) -> vector<4x2xi32>
+// CHECK-NEXT:     return [[RES]] : vector<4x2xi32>
+func.func @call(%a : vector<4xi64>) -> vector<4xi64> {
+    %res = func.call @identity_vector(%a) : (vector<4xi64>) -> vector<4xi64>
+    return %res : vector<4xi64>
+}