This patch adds support for narrow bitwidth storage emulation. The goal is to support sub-byte type
codegen for LLVM CPU. Specifically, a type converter is added to convert memref of narrow bitwidth
(e.g., i4) into supported wider bitwidth (e.g., i8). Another focus of this patch is to populate the
pattern for int4 memref.load. memref.store pattern should be added in a seperate patch.
Reviewed By: hanchung, mravishankar
Differential Revision: https://reviews.llvm.org/D151519
--- /dev/null
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE -----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, we only
+/// handle power-of-two integer types and convert them to wider
+/// integers that are equal or larger than 8 bits.
+class NarrowTypeEmulationConverter : public TypeConverter {
+public:
+ explicit NarrowTypeEmulationConverter(unsigned targetBitwidth);
+
+ unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }
+
+private:
+ unsigned loadStoreBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
/// Create a pass to bufferize Arith ops.
std::unique_ptr<Pass> createArithBufferizePass();
void populateArithWideIntEmulationPatterns(
WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
+/// Adds patterns to emulate narrow Arith and Function ops into wide
+/// supported types. Users need to add conversions about the computation
+/// domain of narrow types.
+void populateArithNarrowTypeEmulationPatterns(
+ NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
/// Add patterns to expand Arith ceil/floor division ops.
void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
namespace arith {
class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
} // namespace arith
namespace memref {
void populateMemRefWideIntEmulationConversions(
arith::WideIntEmulationConverter &typeConverter);
+/// Appends patterns for emulating memref operations over narrow types with ops
+/// over wider types.
+void populateMemRefNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating memref operations over narrow types
+/// with ops over wider types.
+void populateMemRefNarrowTypeEmulationConversions(
+ arith::NarrowTypeEmulationConverter &typeConverter);
+
/// Transformation to do multi-buffering/array expansion to remove dependencies
/// on the temporary allocation between consecutive loop iterations.
/// It returns the new allocation if the original allocation was multi-buffered
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
EmulateWideInt.cpp
+ EmulateNarrowType.cpp
ExpandOps.cpp
IntNarrowing.cpp
IntRangeOptimizations.cpp
--- /dev/null
+//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
+ unsigned targetBitwidth)
+ : loadStoreBitwidth(targetBitwidth) {
+ assert(llvm::isPowerOf2_32(targetBitwidth) &&
+ "Only power-of-two integers are supported");
+
+ // Allow unknown types.
+ addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+ // Function case.
+ addConversion([this](FunctionType ty) -> std::optional<Type> {
+ SmallVector<Type> inputs;
+ if (failed(convertTypes(ty.getInputs(), inputs)))
+ return std::nullopt;
+
+ SmallVector<Type> results;
+ if (failed(convertTypes(ty.getResults(), results)))
+ return std::nullopt;
+
+ return FunctionType::get(ty.getContext(), inputs, results);
+ });
+}
+
+void arith::populateArithNarrowTypeEmulationPatterns(
+ NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+ // Populate `func.*` conversion patterns.
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
ExpandOps.cpp
ExpandStridedMetadata.cpp
EmulateWideInt.cpp
+ EmulateNarrowType.cpp
ExtractAddressComputations.cpp
FoldMemRefAliasOps.cpp
IndependenceTransforms.cpp
--- /dev/null
+//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// The emulation only works on 1D memref types.
+/// To make this work on N-D memref, we need to linearize the offset.
+///
+/// For example, to emulate i4 to i8, the following op:
+///
+/// %0 = memref.load %arg0[%v0, %v1] :
+/// memref<?x?xi4, strided<[?, ?], offset: ?>>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+///
+/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+/// %linearized_size = %size0 * %size1
+/// %scaled_linear_offset = %linearized_offset / 8 * 4
+/// %scaled_base_offset = %offset / 8 * 4
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+/// sizes = [%linearized_size], strides = [%stride#1]
+///
+/// %new_load = memref.load %linearized[%scaled_linear_offset] :
+/// memref<?xi8, strided<[?], offset: ?>>
+
+static Value
+linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
+ int dstBits, SmallVector<Value> indices,
+ memref::ExtractStridedMetadataOp stridedMetadata,
+ OpBuilder &builder) {
+ auto srcElementType = sourceType.getElementType();
+ unsigned sourceRank = indices.size();
+
+ Value baseBuffer = stridedMetadata.getBaseBuffer();
+ SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+ SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+ Value baseOffset = stridedMetadata.getOffset();
+ assert(indices.size() == baseStrides.size());
+
+ // Create the affine symbols and values for linearization.
+ SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+ bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+ symbols[0] = builder.getAffineSymbolExpr(0);
+ AffineExpr addMulMap = symbols.front();
+ AffineExpr mulMap = symbols.front();
+
+ SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+ offsetValues[0] = builder.getIndexAttr(0);
+ SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+ sizeValues[0] = builder.getIndexAttr(1);
+
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ unsigned offsetIdx = 2 * i + 1;
+ addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+ offsetValues[offsetIdx] = indices[i];
+ offsetValues[offsetIdx + 1] = baseStrides[i];
+
+ unsigned sizeIdx = i + 1;
+ mulMap = mulMap * symbols[sizeIdx];
+ sizeValues[sizeIdx] = baseSizes[i];
+ }
+
+ // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+ OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+ AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+ offsetValues.back() = scaler;
+
+ OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+ builder, loc, scaledAddMulMap, offsetValues);
+ OpFoldResult linearizedSize =
+ affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+ // Adjust baseOffset by the scale factor (dstBits / srcBits).
+ AffineExpr s0, s1;
+ bindSymbols(builder.getContext(), s0, s1);
+ OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+ // Flatten n-D MemRef to 1-D MemRef.
+ auto layoutAttr = StridedLayoutAttr::get(
+ sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+ int64_t staticShape = sourceType.hasStaticShape()
+ ? sourceType.getNumElements()
+ : ShapedType::kDynamic;
+ auto flattenMemrefType = MemRefType::get(
+ staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+ auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+ loc, flattenMemrefType, baseBuffer,
+ getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+ getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+ baseStrides.back());
+
+ return builder.create<memref::LoadOp>(
+ loc, srcElementType, reinterpret.getResult(),
+ getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
+
+/// When data is loaded/stored in `targetBits` granularity, but is used in
+/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
+/// treated as an array of elements of width `sourceBits`.
+/// Return the bit offset of the value at position `srcIdx`. For example, if
+/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
+/// located at (x % 2) * 4. Because there are two elements in one i8, and one
+/// element has 4 bits.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+ int targetBits, OpBuilder &builder) {
+ assert(targetBits % sourceBits == 0);
+ IntegerType targetType = builder.getIntegerType(targetBits);
+ IntegerAttr idxAttr =
+ builder.getIntegerAttr(targetType, targetBits / sourceBits);
+ auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+ IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+ auto srcBitsValue =
+ builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+ auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+ return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ rewriter.replaceOpWithNewOp<memref::AllocOp>(
+ op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAssumeAlignment
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAssumeAlignment final
+ : OpConversionPattern<memref::AssumeAlignmentOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemref().getType()));
+ }
+
+ rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
+ op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+ op.getMemRefType()));
+ }
+
+ if (op.getMemRefType() == newTy)
+ return failure();
+
+ auto loc = op.getLoc();
+ auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
+ unsigned sourceRank = sourceType.getRank();
+ SmallVector<Value> indices = adaptor.getIndices();
+ assert(indices.size() == sourceRank);
+
+ auto srcElementType = sourceType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = srcElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+ loc, adaptor.getMemref());
+
+ Value newLoad, lastIdx;
+ if (sourceRank == 0) {
+ newLoad = rewriter.create<memref::LoadOp>(
+ loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
+
+ lastIdx = stridedMetadata.getOffset();
+ } else {
+ newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
+ stridedMetadata, rewriter);
+
+ lastIdx = adaptor.getIndices().back();
+ }
+
+ // Get the offset and shift the bits to the rightmost.
+ // Note, currently only the big-endian is supported.
+ auto castLastIdx =
+ rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
+
+ Value BitwidthOffset =
+ getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
+ auto bitsLoad =
+ rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
+
+ // Get the corresponding bits. If the arith computation bitwidth equals
+ // to the emulated bitwidth, we apply a mask to extract the low bits.
+ // It is not clear if this case actually happens in practice, but we keep
+ // the operations just in case. Otherwise, if the arith computation bitwidth
+ // is different from the emulated bitwidth we truncate the result.
+ Operation *result;
+ auto resultTy = getTypeConverter()->convertType(oldElementType);
+ if (resultTy == srcElementType) {
+ auto mask = rewriter.create<arith::ConstantOp>(
+ loc, srcElementType,
+ rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+
+ result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+ } else {
+ result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+ }
+
+ rewriter.replaceOp(op, result->getResult(0));
+ return success();
+ }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowTypeEmulationPatterns(
+ arith::NarrowTypeEmulationConverter &typeConverter,
+ RewritePatternSet &patterns) {
+
+ // Populate `memref.*` conversion patterns.
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
+ typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefNarrowTypeEmulationConversions(
+ arith::NarrowTypeEmulationConverter &typeConverter) {
+ typeConverter.addConversion(
+ [&typeConverter](MemRefType ty) -> std::optional<Type> {
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+ if (!intTy)
+ return ty;
+
+ unsigned width = intTy.getWidth();
+ unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
+ if (width >= loadStoreWidth)
+ return ty;
+
+ auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
+ intTy.getSignedness());
+ if (!newElemTy)
+ return std::nullopt;
+
+ return ty.cloneWith(std::nullopt, newElemTy);
+ });
+}
--- /dev/null
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME: ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT: return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+ return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32
+// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT: return [[ARG]] : vector<2xi32>
+func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> {
+ return %a : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME: ([[ARG:%.+]]: i8) -> i8
+// CHECK-NEXT: return [[ARG]] : i8
+func.func @identity_scalar(%x : i4) -> i4 {
+ return %x : i4
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<4xi8>
+func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> {
+ return %x : vector<4xi4>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME: ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8>
+// CHECK-NEXT: return [[ARG]] : vector<3x4xi8>
+func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> {
+ return %x : vector<3x4xi4>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT: return [[RES]] : vector<4xi8>
+func.func @call(%a : vector<4xi4>) -> vector<4xi4> {
+ %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4>
+ return %res : vector<4xi4>
+}
--- /dev/null
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT: return
+func.func @memref_i32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i32
+ %m = memref.alloc() : memref<4xi32, 1>
+ %v = memref.load %m[%c0] : memref<4xi32, 1>
+ memref.store %c1, %m[%c0] : memref<4xi32, 1>
+ return
+}
+
+// -----
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT: return
+func.func @memref_f32() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1.0 : f32
+ %m = memref.alloc() : memref<4xf32, 1>
+ %v = memref.load %m[%c0] : memref<4xf32, 1>
+ memref.store %c1, %m[%c0] : memref<4xf32, 1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_zero_rank
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<i8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_zero_rank() {
+ %0 = memref.alloc() : memref<i4>
+ %1 = memref.load %0[] : memref<i4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+ %0 = memref.alloc() : memref<4xi4>
+ %1 = memref.load %0[%arg0] : memref<4xi4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+ memref.assume_alignment %0, 64 : memref<4x128xi4>
+ %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+ return
+}
--- /dev/null
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions.
+// CHECK-LABEL: func @memref_i8
+// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
+// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT: return
+func.func @memref_i8() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : i8
+ %m = memref.alloc() : memref<4xi8, 1>
+ %v = memref.load %m[%c0] : memref<4xi8, 1>
+ memref.store %c1, %m[%c0] : memref<4xi8, 1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME: (%[[ARG:.*]]: index)
+// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: return
+func.func @memref_load_i4(%arg0: index) {
+ %0 = memref.alloc() : memref<4xi4>
+ %1 = memref.load %0[%arg0] : memref<4xi4>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT: return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+ memref.assume_alignment %0, 64 : memref<4x128xi4>
+ %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+ return
+}
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMemRefTestPasses
TestComposeSubView.cpp
+ TestEmulateNarrowType.cpp
TestMultiBuffer.cpp
EXCLUDE_FROM_LIBMLIR
--- /dev/null
+//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ------*- c++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestEmulateNarrowTypePass
+ : public PassWrapper<TestEmulateNarrowTypePass,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
+
+ TestEmulateNarrowTypePass() = default;
+ TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry
+ .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+ vector::VectorDialect, affine::AffineDialect>();
+ }
+ StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+ StringRef getDescription() const final {
+ return "Function pass to test Narrow Integer Emulation";
+ }
+
+ void runOnOperation() override {
+ if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
+ loadStoreEmulateBitwidth < 8) {
+ signalPassFailure();
+ return;
+ }
+
+ Operation *op = getOperation();
+ MLIRContext *ctx = op->getContext();
+
+ arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth);
+
+ // Convert scalar type.
+ typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+ unsigned width = ty.getWidth();
+ if (width >= arithComputeBitwidth)
+ return ty;
+
+ return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+ });
+
+ // Convert vector type.
+ typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+ auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+ if (!intTy)
+ return ty;
+
+ unsigned width = intTy.getWidth();
+ if (width >= arithComputeBitwidth)
+ return ty;
+
+ return VectorType::get(
+ to_vector(ty.getShape()),
+ IntegerType::get(ty.getContext(), arithComputeBitwidth));
+ });
+
+ memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+ return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+ });
+ auto opLegalCallback = [&typeConverter](Operation *op) {
+ return typeConverter.isLegal(op);
+ };
+ target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+ target.addDynamicallyLegalDialect<
+ arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+ affine::AffineDialect>(
+ [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+ RewritePatternSet patterns(ctx);
+
+ arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+ memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+ if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ signalPassFailure();
+ }
+
+ Option<unsigned> loadStoreEmulateBitwidth{
+ *this, "memref-load-bitwidth",
+ llvm::cl::desc("memref load/store emulation bit width"),
+ llvm::cl::init(8)};
+
+ Option<unsigned> arithComputeBitwidth{
+ *this, "arith-compute-bitwidth",
+ llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowTypePass() {
+ PassRegistration<TestEmulateNarrowTypePass>();
+}
+} // namespace mlir::test
void registerTestDialectConversionPasses();
void registerTestDominancePass();
void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowTypePass();
void registerTestExpandMathPass();
void registerTestFooAnalysisPass();
void registerTestComposeSubView();
mlir::test::registerTestDeadCodeAnalysisPass();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
+ mlir::test::registerTestEmulateNarrowTypePass();
mlir::test::registerTestExpandMathPass();
mlir::test::registerTestFooAnalysisPass();
mlir::test::registerTestComposeSubView();