--- /dev/null
+//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+namespace mlir::memref {
+#define GEN_PASS_DEF_MEMREFEMULATEWIDEINT
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace mlir::memref
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+
+ rewriter.replaceOpWithNewOp<memref::AllocOp>(
+ op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newResTy = getTypeConverter()->convertType(op.getType());
+ if (!newResTy)
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemRefType()));
+
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ op, newResTy, adaptor.getMemref(), adaptor.getIndices());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+ if (!newTy)
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemRefType()));
+
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct EmulateWideIntPass final
+ : memref::impl::MemRefEmulateWideIntBase<EmulateWideIntPass> {
+ using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase;
+
+ void runOnOperation() override {
+ if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
+ signalPassFailure();
+ return;
+ }
+
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+
+ arith::WideIntEmulationConverter typeConverter(widestIntSupported);
+ memref::populateMemRefWideIntEmulationConversions(typeConverter);
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalDialect<
+ arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>(
+ [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+ RewritePatternSet patterns(ctx);
+ // Add common pattenrs to support contants, functions, etc.
+ arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
+
+ memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns);
+
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefWideIntEmulationPatterns(
+ arith::WideIntEmulationConverter &typeConverter,
+ RewritePatternSet &patterns) {
+ // Populate `memref.*` conversion patterns.
+ patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefStore>(
+ typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefWideIntEmulationConversions(
+ arith::WideIntEmulationConverter &typeConverter) {
+ typeConverter.addConversion(
+ [&typeConverter](MemRefType ty) -> Optional<Type> {
+ auto intTy = ty.getElementType().dyn_cast<IntegerType>();
+ if (!intTy)
+ return ty;
+
+ if (intTy.getIntOrFloatBitWidth() <=
+ typeConverter.getMaxTargetIntBitWidth())
+ return ty;
+
+ Type newElemTy = typeConverter.convertType(intTy);
+ if (!newElemTy)
+ return None;
+
+ return ty.cloneWith(None, newElemTy);
+ });
+}
--- /dev/null
+// RUN: mlir-opt --memref-emulate-wide-int="widest-int-supported=32" %s | FileCheck %s
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: return
+func.func @memref_i32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i32
+ %m = memref.alloc() : memref<4xi32, 1>
+ %v = memref.load %m[%c0] : memref<4xi32, 1>
+ memref.store %c1, %m[%c0] : memref<4xi32, 1>
+ return
+}
+
+// Expect no conversions, f64 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1.0 : f32
+ %m = memref.alloc() : memref<4xf32, 1>
+ %v = memref.load %m[%c0] : memref<4xf32, 1>
+ memref.store %c1, %m[%c0] : memref<4xf32, 1>
+ return
+}
+
+// CHECK-LABEL: func @alloc_load_store_i64
+// CHECK: [[C1:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
+// CHECK-NEXT: [[M:%.+]] = memref.alloc() : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: memref.store [[C1]], [[M]][{{%.+}}] : memref<4xvector<2xi32>, 1>
+// CHECK-NEXT: return
+func.func @alloc_load_store_i64() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i64
+ %m = memref.alloc() : memref<4xi64, 1>
+ %v = memref.load %m[%c0] : memref<4xi64, 1>
+ memref.store %c1, %m[%c0] : memref<4xi64, 1>
+ return
+}