[mlir][memref] Add initial Wide Int Emulation pass and patterns
authorJakub Kuderski <kubak@google.com>
Fri, 14 Oct 2022 15:36:47 +0000 (11:36 -0400)
committerJakub Kuderski <kubak@google.com>
Fri, 14 Oct 2022 15:37:52 +0000 (11:37 -0400)
Add a new pass and conversions to emulate wide integer operations over memrefs.
The emulation is implemented on top of the existing pass to emulate wide integer arith ops.

Improve naming in the arith pass to avoid potential name clashes.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D135722

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp [new file with mode: 0644]
mlir/test/Dialect/MemRef/emulate-wide-int.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Arith/TestEmulateWideInt.cpp

index 5e441a7020708e3a799d8abddd7a7ba38cddade6..d087ac69828a963abcb700b99cf512645a6a7507 100644 (file)
@@ -28,8 +28,8 @@ std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
 /// Adds patterns to emulate wide Arith and Function ops over integer
 /// types into supported ones. This is done by splitting original power-of-two
 /// i2N integer types into two iN halves.
-void populateWideIntEmulationPatterns(WideIntEmulationConverter &typeConverter,
-                                      RewritePatternSet &patterns);
+void populateArithWideIntEmulationPatterns(
+    WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
index e642dc572a0af4f1115c8b75848537edab119e39..16ef294a90d28d5fa6a950edb34a9c8319baba26 100644 (file)
@@ -52,9 +52,9 @@ def ArithUnsignedWhenEquivalent : Pass<"arith-unsigned-when-equivalent"> {
 def ArithEmulateWideInt : Pass<"arith-emulate-wide-int"> {
   let summary = "Emulate 2*N-bit integer operations using N-bit operations";
   let description = [{
-    Emulate integer operations that use too wide integer types with equivalent
-    operations on supported narrow integer types. This is done by splitting
-    original integer values into two halves.
+    Emulate arith integer operations that use too wide integer types with
+    equivalent operations on supported narrow integer types. This is done by
+    splitting original integer values into two halves.
 
     This pass is intended preserve semantics but not necessarily provide the
     most efficient implementation.
index 2a7b5d82a4cdbcde816079e30eb800e38655ed93..ee30e6e252dff058f97aab26c864d4dc615fc600 100644 (file)
@@ -20,6 +20,10 @@ namespace mlir {
 class AffineDialect;
 class ModuleOp;
 
+namespace arith {
+class WideIntEmulationConverter;
+} // namespace arith
+
 namespace func {
 class FuncDialect;
 } // namespace func
@@ -60,6 +64,17 @@ void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 void populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns for emulating wide integer memref operations with ops over
+/// narrower integer types.
+void populateMemRefWideIntEmulationPatterns(
+    arith::WideIntEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type converions for emulating wide integer memref operations with
+/// ops over narrowe integer types.
+void populateMemRefWideIntEmulationConversions(
+    arith::WideIntEmulationConverter &typeConverter);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered
index 64045033cabef5edbe1c52b2294738ae99a9bfea..b41676482a889dce0fba2853769b3f79b11853f2 100644 (file)
@@ -28,6 +28,22 @@ def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
   ];
 }
 
+def MemRefEmulateWideInt : Pass<"memref-emulate-wide-int"> {
+  let summary = "Emulate 2*N-bit integer operations using N-bit operations";
+  let description = [{
+    Emulate memref integer operations that use too wide integer types with
+    equivalent operations on supported narrow integer types. This is done by
+    splitting original integer values into two halves.
+
+    Currently, only power-of-two integer bitwidths are supported.
+  }];
+  let options = [
+    Option<"widestIntSupported", "widest-int-supported", "unsigned",
+           /*default=*/"32", "Widest integer type supported by the target">,
+  ];
+  let dependentDialects = ["vector::VectorDialect"];
+}
+
 def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
   let summary = "Normalize memrefs";
    let description = [{
index 9784f0d4f92e20b514215e1378e48560d5b40c37..826c8ee96d4192e662792fd38274009484d81bc7 100644 (file)
@@ -745,7 +745,7 @@ struct EmulateWideIntPass final
             opLegalCallback);
 
     RewritePatternSet patterns(ctx);
-    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+    arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
 
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
@@ -817,7 +817,7 @@ arith::WideIntEmulationConverter::WideIntEmulationConverter(
   });
 }
 
-void arith::populateWideIntEmulationPatterns(
+void arith::populateArithWideIntEmulationPatterns(
     WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) {
   // Populate `func.*` conversion patterns.
   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
index 967a85d053bfe528aae77336d6df8c1a2b3975dc..2e2ffb491fb75c3e1ccd7f56d2b7e2d29629d47e 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   ComposeSubView.cpp
   ExpandOps.cpp
+  EmulateWideInt.cpp
   FoldMemRefAliasOps.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
@@ -17,6 +18,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MLIRAffineDialect
   MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithTransforms
   MLIRFuncDialect
   MLIRInferTypeOpInterface
   MLIRLoopLikeInterface
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
new file mode 100644 (file)
index 0000000..02c6e58
--- /dev/null
@@ -0,0 +1,163 @@
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::memref {
+#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace mlir::memref
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(
+        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newResTy = getTypeConverter()->convertType(op.getType());
+    if (!newResTy)
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemRefType()));
+
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(
+        op, newResTy, adaptor.getMemref(), adaptor.getIndices());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+    if (!newTy)
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemRefType()));
+
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct EmulateWideIntPass final
+    : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
+  using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
+
+  void runOnOperation() override {
+    if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+    memref::populateMemRefWideIntEmulationConversions(typeConverter);
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+    // Add common pattenrs to support contants, functions, etc.
+    arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
+
+    memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefWideIntEmulationPatterns(
+    arith::WideIntEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+  // Populate `memref.*` conversion patterns.
+  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
+      typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefWideIntEmulationConversions(
+    arith::WideIntEmulationConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType ty) -> Optional<Type> {
+        auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+        if (!intTy)
+          return ty;
+
+        if (intTy.getIntOrFloatBitWidth() <=
+            typeConverter.getMaxTargetIntBitWidth())
+          return ty;
+
+        Type newElemTy = typeConverter.convertType(intTy);
+        if (!newElemTy)
+          return None;
+
+        return ty.cloneWith(None, newElemTy);
+      });
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-wide-int.mlir b/mlir/test/Dialect/MemRef/emulate-wide-int.mlir
new file mode 100644 (file)
index 0000000..de1cba5
--- /dev/null
@@ -0,0 +1,46 @@
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT:    return
+func.func @memref_i32() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i32
+    %m = memref.alloc() : memref<4xi32, 1>
+    %v = memref.load %m[%c0] : memref<4xi32, 1>
+    memref.store %c1, %m[%c0] : memref<4xi32, 1>
+    return
+}
+
+// Expect no conversions, f64 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT:    return
+func.func @memref_f32() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1.0 : f32
+    %m = memref.alloc() : memref<4xf32, 1>
+    %v = memref.load %m[%c0] : memref<4xf32, 1>
+    memref.store %c1, %m[%c0] : memref<4xf32, 1>
+    return
+}
+
+// CHECK-LABEL: func @alloc_load_store_i64
+// CHECK:         [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
+// CHECK-NEXT:    [[M:%.+]]  = memref.alloc() : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT:    [[V:%.+]]  = memref.load [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT:    memref.store [[C1]], [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT:    return
+func.func @alloc_load_store_i64() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i64
+    %m = memref.alloc() : memref<4xi64, 1>
+    %v = memref.load %m[%c0] : memref<4xi64, 1>
+    memref.store %c1, %m[%c0] : memref<4xi64, 1>
+    return
+}
index ee84eadcbcdf470f26c3be630cfba273760dfae4..c1ae321711fcdf3051059bb24d9596ac4717052a 100644 (file)
@@ -74,7 +74,7 @@ struct TestEmulateWideIntPass
             });
 
     RewritePatternSet patterns(ctx);
-    arith::populateWideIntEmulationPatterns(typeConverter, patterns);
+    arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       signalPassFailure();
   }