From 5a1cdcbd8698cd263696b38e2672fccac9ec793c Mon Sep 17 00:00:00 2001
From: yzhang93
Date: Mon, 26 Jun 2023 14:18:15 -0700
Subject: [PATCH] [mlir] Narrow bitwidth emulation for MemRef load

This patch adds support for narrow bitwidth storage emulation. The goal is to
support sub-byte type codegen for LLVM CPU. Specifically, a type converter is
added to convert memrefs of a narrow bitwidth (e.g., i4) into a supported
wider bitwidth (e.g., i8). Another focus of this patch is to populate the
conversion pattern for i4 memref.load; the memref.store pattern should be
added in a separate patch.

Reviewed By: hanchung, mravishankar

Differential Revision: https://reviews.llvm.org/D151519
---
 .../Transforms/NarrowTypeEmulationConverter.h | 31 ++
 .../include/mlir/Dialect/Arith/Transforms/Passes.h | 7 +
 .../mlir/Dialect/MemRef/Transforms/Transforms.h | 12 +
 mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt | 1 +
 .../Dialect/Arith/Transforms/EmulateNarrowType.cpp | 61 ++++
 mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 1 +
 .../MemRef/Transforms/EmulateNarrowType.cpp | 315 +++++++++++++++++++++
 mlir/test/Dialect/Arith/emulate-narrow-type.mlir | 47 +++
 .../emulate-narrow-type-diff-load-compute.mlir | 107 +++++++
 .../emulate-narrow-type-same-load-compute.mlir | 72 +++++
 mlir/test/lib/Dialect/MemRef/CMakeLists.txt | 1 +
 .../lib/Dialect/MemRef/TestEmulateNarrowType.cpp | 118 ++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
 13 files changed, 775 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
 create mode 100644 mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
 create mode 100644 mlir/test/Dialect/Arith/emulate-narrow-type.mlir
 create mode 100644 mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
 create mode 100644 mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
 create mode 100644 mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
new file mode 100644
index 0000000..528bb51
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
@@ -0,0 +1,31 @@
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE -----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, we only
+/// handle power-of-two integer types and convert them to wider
+/// integers that are equal or larger than 8 bits.
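+///
+/// For example (illustrative, with a load/store bitwidth of 8): i1, i2, and
+/// i4 values would all be stored and loaded as i8, while i8 and wider
+/// integer types are left unchanged.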
+class NarrowTypeEmulationConverter : public TypeConverter { +public: + explicit NarrowTypeEmulationConverter(unsigned targetBitwidth); + + unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; } + +private: + unsigned loadStoreBitwidth; +}; +} // namespace mlir::arith + +#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_ diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index c4010b7..de36cb4 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -22,6 +22,7 @@ namespace arith { #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" class WideIntEmulationConverter; +class NarrowTypeEmulationConverter; /// Create a pass to bufferize Arith ops. std::unique_ptr createArithBufferizePass(); @@ -35,6 +36,12 @@ std::unique_ptr createConstantBufferizePass(uint64_t alignment = 0); void populateArithWideIntEmulationPatterns( WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns); +/// Adds patterns to emulate narrow Arith and Function ops into wide +/// supported types. Users need to add conversions about the computation +/// domain of narrow types. +void populateArithNarrowTypeEmulationPatterns( + NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns); + /// Add patterns to expand Arith ceil/floor division ops. void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index 91ef162..0b1af47 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -25,6 +25,7 @@ class ValueRange; namespace arith { class WideIntEmulationConverter; +class NarrowTypeEmulationConverter; } // namespace arith namespace memref { @@ -73,6 +74,17 @@ void populateMemRefWideIntEmulationPatterns( void populateMemRefWideIntEmulationConversions( arith::WideIntEmulationConverter &typeConverter); +/// Appends patterns for emulating memref operations over narrow types with ops +/// over wider types. +void populateMemRefNarrowTypeEmulationPatterns( + arith::NarrowTypeEmulationConverter &typeConverter, + RewritePatternSet &patterns); + +/// Appends type conversions for emulating memref operations over narrow types +/// with ops over wider types. +void populateMemRefNarrowTypeEmulationConversions( + arith::NarrowTypeEmulationConverter &typeConverter); + /// Transformation to do multi-buffering/array expansion to remove dependencies /// on the temporary allocation between consecutive loop iterations. 
/// It returns the new allocation if the original allocation was multi-buffered diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 87d9beb..b969389 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp EmulateWideInt.cpp + EmulateNarrowType.cpp ExpandOps.cpp IntNarrowing.cpp IntRangeOptimizations.cpp diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp new file mode 100644 index 0000000..e0e1385 --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp @@ -0,0 +1,61 @@ +//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APInt.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Public Interface Definition +//===----------------------------------------------------------------------===// + +arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter( + unsigned targetBitwidth) + : loadStoreBitwidth(targetBitwidth) { + assert(llvm::isPowerOf2_32(targetBitwidth) && + "Only power-of-two integers are supported"); + + // Allow unknown types. + addConversion([](Type ty) -> std::optional { return ty; }); + + // Function case. + addConversion([this](FunctionType ty) -> std::optional { + SmallVector inputs; + if (failed(convertTypes(ty.getInputs(), inputs))) + return std::nullopt; + + SmallVector results; + if (failed(convertTypes(ty.getResults(), results))) + return std::nullopt; + + return FunctionType::get(ty.getContext(), inputs, results); + }); +} + +void arith::populateArithNarrowTypeEmulationPatterns( + NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) { + // Populate `func.*` conversion patterns. 
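+  // Illustrative example (assuming the caller has registered a scalar
+  // conversion that widens i4 to i8): the signature of
+  //   func.func @f(%arg0: i4) -> i4
+  // is rewritten, together with its call and return ops, into
+  //   func.func @f(%arg0: i8) -> i8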
+ populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index a16d850..10ca179 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms ExpandOps.cpp ExpandStridedMetadata.cpp EmulateWideInt.cpp + EmulateNarrowType.cpp ExtractAddressComputations.cpp FoldMemRefAliasOps.cpp IndependenceTransforms.cpp diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp new file mode 100644 index 0000000..a876bc7 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -0,0 +1,315 @@ +//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Utility functions +//===----------------------------------------------------------------------===// + +/// The emulation only works on 1D memref types. +/// To make this work on N-D memref, we need to linearize the offset. +/// +/// For example, to emulate i4 to i8, the following op: +/// +/// %0 = memref.load %arg0[%v0, %v1] : +/// memref> +/// +/// can be replaced with +/// +/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0 +/// +/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1 +/// %linearized_size = %size0 * %size1 +/// %scaled_linear_offset = %linearized_offset / 8 * 4 +/// %scaled_base_offset = %offset / 8 * 4 +/// +/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset], +/// sizes = [%linearized_size], strides = [%stride#1] +/// +/// %new_load = memref.load %linearized[%scaled_linear_offset] : +/// memref> + +static Value +linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits, + int dstBits, SmallVector indices, + memref::ExtractStridedMetadataOp stridedMetadata, + OpBuilder &builder) { + auto srcElementType = sourceType.getElementType(); + unsigned sourceRank = indices.size(); + + Value baseBuffer = stridedMetadata.getBaseBuffer(); + SmallVector baseSizes = stridedMetadata.getSizes(); + SmallVector baseStrides = stridedMetadata.getStrides(); + Value baseOffset = stridedMetadata.getOffset(); + assert(indices.size() == baseStrides.size()); + + // Create the affine symbols and values for linearization. 
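+  // Illustration for a rank-2 source with indices (%i0, %i1), strides
+  // (%st0, %st1), and scale factor dstBits / srcBits = 2 (i4 packed into i8):
+  //   linearized offset: (s0 + s1 * s2 + s3 * s4) floordiv s5
+  //                      applied to (0, %i0, %st0, %i1, %st1, 2)
+  //   linearized size:   s0 * s1 * s2 applied to (1, %size0, %size1)
+  // After folding, these become the maps checked in the tests, e.g.
+  //   affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>.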
+ SmallVector symbols(2 * sourceRank + 2); + bindSymbolsList(builder.getContext(), MutableArrayRef{symbols}); + symbols[0] = builder.getAffineSymbolExpr(0); + AffineExpr addMulMap = symbols.front(); + AffineExpr mulMap = symbols.front(); + + SmallVector offsetValues(2 * sourceRank + 2); + offsetValues[0] = builder.getIndexAttr(0); + SmallVector sizeValues(sourceRank + 1); + sizeValues[0] = builder.getIndexAttr(1); + + for (unsigned i = 0; i < sourceRank; ++i) { + unsigned offsetIdx = 2 * i + 1; + addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1]; + offsetValues[offsetIdx] = indices[i]; + offsetValues[offsetIdx + 1] = baseStrides[i]; + + unsigned sizeIdx = i + 1; + mulMap = mulMap * symbols[sizeIdx]; + sizeValues[sizeIdx] = baseSizes[i]; + } + + // Adjust linearizedOffset by the scale factor (dstBits / srcBits). + OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits); + AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back()); + offsetValues.back() = scaler; + + OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply( + builder, loc, scaledAddMulMap, offsetValues); + OpFoldResult linearizedSize = + affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues); + + // Adjust baseOffset by the scale factor (dstBits / srcBits). + AffineExpr s0, s1; + bindSymbols(builder.getContext(), s0, s1); + OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply( + builder, loc, s0.floorDiv(s1), {baseOffset, scaler}); + + // Flatten n-D MemRef to 1-D MemRef. + auto layoutAttr = StridedLayoutAttr::get( + sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic}); + int64_t staticShape = sourceType.hasStaticShape() + ? sourceType.getNumElements() + : ShapedType::kDynamic; + auto flattenMemrefType = MemRefType::get( + staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace()); + + auto reinterpret = builder.create( + loc, flattenMemrefType, baseBuffer, + getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset), + getValueOrCreateConstantIndexOp(builder, loc, linearizedSize), + baseStrides.back()); + + return builder.create( + loc, srcElementType, reinterpret.getResult(), + getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset)); +} + +/// When data is loaded/stored in `targetBits` granularity, but is used in +/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is +/// treated as an array of elements of width `sourceBits`. +/// Return the bit offset of the value at position `srcIdx`. For example, if +/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is +/// located at (x % 2) * 4. Because there are two elements in one i8, and one +/// element has 4 bits. 
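+///
+/// Worked example with `sourceBits` = 4 and `targetBits` = 8: the value at
+/// `srcIdx` = 5 lives in byte 5 / 2 = 2, and its bit offset within that byte
+/// is (5 % 2) * 4 = 4, so the loaded byte is shifted right by 4 before the
+/// low 4 bits are extracted.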
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits, + int targetBits, OpBuilder &builder) { + assert(targetBits % sourceBits == 0); + IntegerType targetType = builder.getIntegerType(targetBits); + IntegerAttr idxAttr = + builder.getIntegerAttr(targetType, targetBits / sourceBits); + auto idx = builder.create(loc, targetType, idxAttr); + IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits); + auto srcBitsValue = + builder.create(loc, targetType, srcBitsAttr); + auto m = builder.create(loc, srcIdx, idx); + return builder.create(loc, targetType, m, srcBitsValue); +} + +namespace { + +//===----------------------------------------------------------------------===// +// ConvertMemRefAlloc +//===----------------------------------------------------------------------===// + +struct ConvertMemRefAlloc final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getType()); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert memref type: {0}", op.getType())); + } + + rewriter.replaceOpWithNewOp( + op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), + adaptor.getAlignmentAttr()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertMemRefAssumeAlignment +//===----------------------------------------------------------------------===// + +struct ConvertMemRefAssumeAlignment final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getMemref().getType()); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), llvm::formatv("failed to convert memref type: {0}", + op.getMemref().getType())); + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getMemref(), adaptor.getAlignmentAttr()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertMemRefLoad +//===----------------------------------------------------------------------===// + +struct ConvertMemRefLoad final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type newTy = getTypeConverter()->convertType(op.getMemRefType()); + if (!newTy) { + return rewriter.notifyMatchFailure( + op->getLoc(), llvm::formatv("failed to convert memref type: {0}", + op.getMemRefType())); + } + + if (op.getMemRefType() == newTy) + return failure(); + + auto loc = op.getLoc(); + auto sourceType = cast(adaptor.getMemref().getType()); + unsigned sourceRank = sourceType.getRank(); + SmallVector indices = adaptor.getIndices(); + assert(indices.size() == sourceRank); + + auto srcElementType = sourceType.getElementType(); + auto oldElementType = op.getMemRefType().getElementType(); + int srcBits = oldElementType.getIntOrFloatBitWidth(); + int dstBits = srcElementType.getIntOrFloatBitWidth(); + if (dstBits % srcBits != 0) { + return rewriter.notifyMatchFailure( + op, "only dstBits % srcBits == 0 supported"); + } + + auto stridedMetadata = rewriter.create( + loc, 
adaptor.getMemref()); + + Value newLoad, lastIdx; + if (sourceRank == 0) { + newLoad = rewriter.create( + loc, srcElementType, adaptor.getMemref(), adaptor.getIndices()); + + lastIdx = stridedMetadata.getOffset(); + } else { + newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices, + stridedMetadata, rewriter); + + lastIdx = adaptor.getIndices().back(); + } + + // Get the offset and shift the bits to the rightmost. + // Note, currently only the big-endian is supported. + auto castLastIdx = + rewriter.create(loc, srcElementType, lastIdx); + + Value BitwidthOffset = + getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter); + auto bitsLoad = + rewriter.create(loc, newLoad, BitwidthOffset); + + // Get the corresponding bits. If the arith computation bitwidth equals + // to the emulated bitwidth, we apply a mask to extract the low bits. + // It is not clear if this case actually happens in practice, but we keep + // the operations just in case. Otherwise, if the arith computation bitwidth + // is different from the emulated bitwidth we truncate the result. + Operation *result; + auto resultTy = getTypeConverter()->convertType(oldElementType); + if (resultTy == srcElementType) { + auto mask = rewriter.create( + loc, srcElementType, + rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1)); + + result = rewriter.create(loc, bitsLoad, mask); + } else { + result = rewriter.create(loc, resultTy, bitsLoad); + } + + rewriter.replaceOp(op, result->getResult(0)); + return success(); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Public Interface Definition +//===----------------------------------------------------------------------===// + +void memref::populateMemRefNarrowTypeEmulationPatterns( + arith::NarrowTypeEmulationConverter &typeConverter, + RewritePatternSet &patterns) { + + // Populate `memref.*` conversion patterns. + patterns + .add( + typeConverter, patterns.getContext()); +} + +void memref::populateMemRefNarrowTypeEmulationConversions( + arith::NarrowTypeEmulationConverter &typeConverter) { + typeConverter.addConversion( + [&typeConverter](MemRefType ty) -> std::optional { + auto intTy = dyn_cast(ty.getElementType()); + if (!intTy) + return ty; + + unsigned width = intTy.getWidth(); + unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth(); + if (width >= loadStoreWidth) + return ty; + + auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth, + intTy.getSignedness()); + if (!newElemTy) + return std::nullopt; + + return ty.cloneWith(std::nullopt, newElemTy); + }); +} diff --git a/mlir/test/Dialect/Arith/emulate-narrow-type.mlir b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir new file mode 100644 index 0000000..7120882 --- /dev/null +++ b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s + +// Expect no conversions, f32 is not an integer type. +// CHECK-LABEL: func @identity_f32 +// CHECK-SAME: ([[ARG:%.+]]: f32) -> f32 +// CHECK-NEXT: return [[ARG]] : f32 +func.func @identity_f32(%a : f32) -> f32 { + return %a : f32 +} + +// Expect no conversions, i32 is supported. 
+// CHECK-LABEL: func @identity_i32 +// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32> +// CHECK-NEXT: return [[ARG]] : vector<2xi32> +func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> { + return %a : vector<2xi32> +} + +// CHECK-LABEL: func @identity_scalar +// CHECK-SAME: ([[ARG:%.+]]: i8) -> i8 +// CHECK-NEXT: return [[ARG]] : i8 +func.func @identity_scalar(%x : i4) -> i4 { + return %x : i4 +} + +// CHECK-LABEL: func @identity_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8> +// CHECK-NEXT: return [[ARG]] : vector<4xi8> +func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> { + return %x : vector<4xi4> +} + +// CHECK-LABEL: func @identity_vector2d +// CHECK-SAME: ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8> +// CHECK-NEXT: return [[ARG]] : vector<3x4xi8> +func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> { + return %x : vector<3x4xi4> +} + +// CHECK-LABEL: func @call +// CHECK-SAME: ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8> +// CHECK-NEXT: [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8> +// CHECK-NEXT: return [[RES]] : vector<4xi8> +func.func @call(%a : vector<4xi4>) -> vector<4xi4> { + %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4> + return %res : vector<4xi4> +} diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir new file mode 100644 index 0000000..85d4cc1 --- /dev/null +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir @@ -0,0 +1,107 @@ +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> + +// Expect no conversions, i32 is supported. +// CHECK-LABEL: func @memref_i32 +// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi32, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1> +// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1> +// CHECK-NEXT: return +func.func @memref_i32() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i32 + %m = memref.alloc() : memref<4xi32, 1> + %v = memref.load %m[%c0] : memref<4xi32, 1> + memref.store %c1, %m[%c0] : memref<4xi32, 1> + return +} + +// ----- + +// Expect no conversions, f32 is not an integer type. 
+// CHECK-LABEL: func @memref_f32 +// CHECK: [[M:%.+]] = memref.alloc() : memref<4xf32, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1> +// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1> +// CHECK-NEXT: return +func.func @memref_f32() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1.0 : f32 + %m = memref.alloc() : memref<4xf32, 1> + %v = memref.load %m[%c0] : memref<4xf32, 1> + memref.store %c1, %m[%c0] : memref<4xf32, 1> + return +} + +// ----- + +// CHECK-LABEL: func @memref_load_i4_zero_rank +// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref +// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref -> memref, index +// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[M]][] : memref +// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8 +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 +// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 +// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 +// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 +// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 +// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4 +// CHECK-NEXT: return +func.func @memref_load_i4_zero_rank() { + %0 = memref.alloc() : memref + %1 = memref.load %0[] : memref + return +} + +// ----- + +// CHECK-LABEL: func @memref_load_i4 +// CHECK-SAME: (%[[ARG:.*]]: index) +// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8> +// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref, index, index, index +// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]] +// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] +// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref to memref<4xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8 +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 +// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 +// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 +// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 +// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 +// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4 +// CHECK-NEXT: return +func.func @memref_load_i4(%arg0: index) { + %0 = memref.alloc() : memref<4xi4> + %1 = memref.load %0[%arg0] : memref<4xi4> + return +} + +// ----- + +// CHECK-LABEL: func @memref_load_i4_rank2 +// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) +// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8> +// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref, index, index, index, index, index +// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] +// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1] +// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] +// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref to memref<512xi8, 
strided<[?], offset: ?>> +// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8 +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 +// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 +// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 +// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 +// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 +// CHECK-NEXT: %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4 +// CHECK-NEXT: return +func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) { + memref.assume_alignment %0, 64 : memref<4x128xi4> + %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4> + return +} diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir new file mode 100644 index 0000000..9d63b9d --- /dev/null +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> + +// Expect no conversions. +// CHECK-LABEL: func @memref_i8 +// CHECK: [[M:%.+]] = memref.alloc() : memref<4xi8, 1> +// CHECK-NEXT: [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1> +// CHECK-NEXT: memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1> +// CHECK-NEXT: return +func.func @memref_i8() { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : i8 + %m = memref.alloc() : memref<4xi8, 1> + %v = memref.load %m[%c0] : memref<4xi8, 1> + memref.store %c1, %m[%c0] : memref<4xi8, 1> + return +} + +// ----- + +// CHECK-LABEL: func @memref_load_i4 +// CHECK-SAME: (%[[ARG:.*]]: index) +// CHECK-NEXT: %[[M:.*]] = memref.alloc() : memref<4xi8> +// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref, index, index, index +// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]] +// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] +// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref to memref<4xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8 +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 +// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 +// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 +// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 +// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 +// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8 +// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8 +// CHECK-NEXT: return +func.func @memref_load_i4(%arg0: index) { + %0 = memref.alloc() : memref<4xi4> + %1 = memref.load %0[%arg0] : memref<4xi4> + return +} + +// ----- + +// CHECK-LABEL: func @memref_load_i4_rank2 +// CHECK-SAME: (%[[ARG:.*]]: memref<4x128xi8>, 
%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) +// CHECK-NEXT: memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8> +// CHECK-NEXT: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref, index, index, index, index, index +// CHECK-NEXT: %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] +// CHECK-NEXT: %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1] +// CHECK-NEXT: %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]] +// CHECK-NEXT: %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref to memref<512xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>> +// CHECK-NEXT: %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8 +// CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : i8 +// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : i8 +// CHECK-NEXT: %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8 +// CHECK-NEXT: %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8 +// CHECK-NEXT: %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8 +// CHECK-NEXT: %[[MASK:.*]] = arith.constant 15 : i8 +// CHECK-NEXT: %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8 +// CHECK-NEXT: return +func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) { + memref.assume_alignment %0, 64 : memref<4x128xi4> + %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4> + return +} diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt index df3fdacd..0498de3 100644 --- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -1,6 +1,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRMemRefTestPasses TestComposeSubView.cpp + TestEmulateNarrowType.cpp TestMultiBuffer.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp new file mode 100644 index 0000000..b1f2308 --- /dev/null +++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp @@ -0,0 +1,118 @@ +//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation ------*- c++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +struct TestEmulateNarrowTypePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass) + + TestEmulateNarrowTypePass() = default; + TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert(); + } + StringRef getArgument() const final { return "test-emulate-narrow-int"; } + StringRef getDescription() const final { + return "Function pass to test Narrow Integer Emulation"; + } + + void runOnOperation() override { + if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) || + loadStoreEmulateBitwidth < 8) { + signalPassFailure(); + return; + } + + Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); + + arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth); + + // Convert scalar type. + typeConverter.addConversion([this](IntegerType ty) -> std::optional { + unsigned width = ty.getWidth(); + if (width >= arithComputeBitwidth) + return ty; + + return IntegerType::get(ty.getContext(), arithComputeBitwidth); + }); + + // Convert vector type. 
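+    // For example, with arith-compute-bitwidth=8 a vector<4xi4> operand
+    // becomes vector<4xi8>, while vector<2xi32> is already legal and is
+    // left untouched (see emulate-narrow-type.mlir).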
+ typeConverter.addConversion([this](VectorType ty) -> std::optional { + auto intTy = dyn_cast(ty.getElementType()); + if (!intTy) + return ty; + + unsigned width = intTy.getWidth(); + if (width >= arithComputeBitwidth) + return ty; + + return VectorType::get( + to_vector(ty.getShape()), + IntegerType::get(ty.getContext(), arithComputeBitwidth)); + }); + + memref::populateMemRefNarrowTypeEmulationConversions(typeConverter); + ConversionTarget target(*ctx); + target.addDynamicallyLegalOp([&typeConverter](Operation *op) { + return typeConverter.isLegal(cast(op).getFunctionType()); + }); + auto opLegalCallback = [&typeConverter](Operation *op) { + return typeConverter.isLegal(op); + }; + target.addDynamicallyLegalOp(opLegalCallback); + target.addDynamicallyLegalDialect< + arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect, + affine::AffineDialect>( + [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); + + RewritePatternSet patterns(ctx); + + arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns); + memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + signalPassFailure(); + } + + Option loadStoreEmulateBitwidth{ + *this, "memref-load-bitwidth", + llvm::cl::desc("memref load/store emulation bit width"), + llvm::cl::init(8)}; + + Option arithComputeBitwidth{ + *this, "arith-compute-bitwidth", + llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)}; +}; +} // namespace + +namespace mlir::test { +void registerTestEmulateNarrowTypePass() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index d75b54e..5b95663 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -87,6 +87,7 @@ void registerTestDiagnosticsPass(); void registerTestDialectConversionPasses(); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); +void registerTestEmulateNarrowTypePass(); void registerTestExpandMathPass(); void registerTestFooAnalysisPass(); void registerTestComposeSubView(); @@ -205,6 +206,7 @@ void registerTestPasses() { mlir::test::registerTestDeadCodeAnalysisPass(); mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); + mlir::test::registerTestEmulateNarrowTypePass(); mlir::test::registerTestExpandMathPass(); mlir::test::registerTestFooAnalysisPass(); mlir::test::registerTestComposeSubView(); -- 2.7.4