[mlir] Narrow bitwidth emulation for MemRef load
authoryzhang93 <zhyuhang88@gmail.com>
Mon, 26 Jun 2023 21:18:15 +0000 (14:18 -0700)
committerHanhan Wang <hanchung@google.com>
Mon, 26 Jun 2023 21:18:30 +0000 (14:18 -0700)
This patch adds support for narrow bitwidth storage emulation. The goal is to support sub-byte type
codegen for LLVM CPU. Specifically, a type converter is added to convert memref of narrow bitwidth
(e.g., i4) into supported wider bitwidth (e.g., i8). Another focus of this patch is to populate the
pattern for int4 memref.load. memref.store pattern should be added in a seperate patch.

Reviewed By: hanchung, mravishankar

Differential Revision: https://reviews.llvm.org/D151519

13 files changed:
mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h [new file with mode: 0644]
mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp [new file with mode: 0644]
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp [new file with mode: 0644]
mlir/test/Dialect/Arith/emulate-narrow-type.mlir [new file with mode: 0644]
mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir [new file with mode: 0644]
mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/MemRef/CMakeLists.txt
mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h b/mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
new file mode 100644 (file)
index 0000000..528bb51
--- /dev/null
@@ -0,0 +1,31 @@
+//===- NarrowTypeEmulationConverter.h - Type Converter for NTE -----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+#define MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir::arith {
+/// Converts narrow integer or float types that are not supported
+/// by the target hardware to wider types. Currently, we only
+/// handle power-of-two integer types and convert them to wider
+/// integers that are equal or larger than 8 bits.
+class NarrowTypeEmulationConverter : public TypeConverter {
+public:
+  explicit NarrowTypeEmulationConverter(unsigned targetBitwidth);
+
+  unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }
+
+private:
+  unsigned loadStoreBitwidth;
+};
+} // namespace mlir::arith
+
+#endif // MLIR_DIALECT_ARITH_NARROW_TYPE_EMULATION_CONVERTER_H_
index c4010b7..de36cb4 100644 (file)
@@ -22,6 +22,7 @@ namespace arith {
 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
 
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 
 /// Create a pass to bufferize Arith ops.
 std::unique_ptr<Pass> createArithBufferizePass();
@@ -35,6 +36,12 @@ std::unique_ptr<Pass> createConstantBufferizePass(uint64_t alignment = 0);
 void populateArithWideIntEmulationPatterns(
     WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns);
 
+/// Adds patterns to emulate narrow Arith and Function ops into wide
+/// supported types. Users need to add conversions about the computation
+/// domain of narrow types.
+void populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ceil/floor division ops.
 void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
 
index 91ef162..0b1af47 100644 (file)
@@ -25,6 +25,7 @@ class ValueRange;
 
 namespace arith {
 class WideIntEmulationConverter;
+class NarrowTypeEmulationConverter;
 } // namespace arith
 
 namespace memref {
@@ -73,6 +74,17 @@ void populateMemRefWideIntEmulationPatterns(
 void populateMemRefWideIntEmulationConversions(
     arith::WideIntEmulationConverter &typeConverter);
 
+/// Appends patterns for emulating memref operations over narrow types with ops
+/// over wider types.
+void populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns);
+
+/// Appends type conversions for emulating memref operations over narrow types
+/// with ops over wider types.
+void populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It returns the new allocation if the original allocation was multi-buffered
index 87d9beb..b969389 100644 (file)
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExpandOps.cpp
   IntNarrowing.cpp
   IntRangeOptimizations.cpp
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
new file mode 100644 (file)
index 0000000..e0e1385
--- /dev/null
@@ -0,0 +1,61 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+arith::NarrowTypeEmulationConverter::NarrowTypeEmulationConverter(
+    unsigned targetBitwidth)
+    : loadStoreBitwidth(targetBitwidth) {
+  assert(llvm::isPowerOf2_32(targetBitwidth) &&
+         "Only power-of-two integers are supported");
+
+  // Allow unknown types.
+  addConversion([](Type ty) -> std::optional<Type> { return ty; });
+
+  // Function case.
+  addConversion([this](FunctionType ty) -> std::optional<Type> {
+    SmallVector<Type> inputs;
+    if (failed(convertTypes(ty.getInputs(), inputs)))
+      return std::nullopt;
+
+    SmallVector<Type> results;
+    if (failed(convertTypes(ty.getResults(), results)))
+      return std::nullopt;
+
+    return FunctionType::get(ty.getContext(), inputs, results);
+  });
+}
+
+void arith::populateArithNarrowTypeEmulationPatterns(
+    NarrowTypeEmulationConverter &typeConverter, RewritePatternSet &patterns) {
+  // Populate `func.*` conversion patterns.
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+}
index a16d850..10ca179 100644 (file)
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   ExpandOps.cpp
   ExpandStridedMetadata.cpp
   EmulateWideInt.cpp
+  EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
new file mode 100644 (file)
index 0000000..a876bc7
--- /dev/null
@@ -0,0 +1,315 @@
+//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MathExtras.h"
+#include <cassert>
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// The emulation only works on 1D memref types.
+/// To make this work on N-D memref, we need to linearize the offset.
+///
+/// For example, to emulate i4 to i8, the following op:
+///
+/// %0 = memref.load %arg0[%v0, %v1] :
+///                  memref<?x?xi4, strided<[?, ?], offset: ?>>
+///
+/// can be replaced with
+///
+/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
+///
+/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
+/// %linearized_size = %size0 * %size1
+/// %scaled_linear_offset = %linearized_offset / 8 * 4
+/// %scaled_base_offset = %offset / 8 * 4
+///
+/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
+///                      sizes = [%linearized_size], strides = [%stride#1]
+///
+/// %new_load = memref.load %linearized[%scaled_linear_offset] :
+///                         memref<?xi8, strided<[?], offset: ?>>
+
+static Value
+linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
+                    int dstBits, SmallVector<Value> indices,
+                    memref::ExtractStridedMetadataOp stridedMetadata,
+                    OpBuilder &builder) {
+  auto srcElementType = sourceType.getElementType();
+  unsigned sourceRank = indices.size();
+
+  Value baseBuffer = stridedMetadata.getBaseBuffer();
+  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
+  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
+  Value baseOffset = stridedMetadata.getOffset();
+  assert(indices.size() == baseStrides.size());
+
+  // Create the affine symbols and values for linearization.
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  symbols[0] = builder.getAffineSymbolExpr(0);
+  AffineExpr addMulMap = symbols.front();
+  AffineExpr mulMap = symbols.front();
+
+  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
+  offsetValues[0] = builder.getIndexAttr(0);
+  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
+  sizeValues[0] = builder.getIndexAttr(1);
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    unsigned offsetIdx = 2 * i + 1;
+    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
+    offsetValues[offsetIdx] = indices[i];
+    offsetValues[offsetIdx + 1] = baseStrides[i];
+
+    unsigned sizeIdx = i + 1;
+    mulMap = mulMap * symbols[sizeIdx];
+    sizeValues[sizeIdx] = baseSizes[i];
+  }
+
+  // Adjust linearizedOffset by the scale factor (dstBits / srcBits).
+  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
+  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
+  offsetValues.back() = scaler;
+
+  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, scaledAddMulMap, offsetValues);
+  OpFoldResult linearizedSize =
+      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
+
+  // Adjust baseOffset by the scale factor (dstBits / srcBits).
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
+
+  // Flatten n-D MemRef to 1-D MemRef.
+  auto layoutAttr = StridedLayoutAttr::get(
+      sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
+  int64_t staticShape = sourceType.hasStaticShape()
+                            ? sourceType.getNumElements()
+                            : ShapedType::kDynamic;
+  auto flattenMemrefType = MemRefType::get(
+      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
+
+  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
+      loc, flattenMemrefType, baseBuffer,
+      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
+      baseStrides.back());
+
+  return builder.create<memref::LoadOp>(
+      loc, srcElementType, reinterpret.getResult(),
+      getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
+}
+
+/// When data is loaded/stored in `targetBits` granularity, but is used in
+/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
+/// treated as an array of elements of width `sourceBits`.
+/// Return the bit offset of the value at position `srcIdx`. For example, if
+/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
+/// located at (x % 2) * 4. Because there are two elements in one i8, and one
+/// element has 4 bits.
+static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
+                                  int targetBits, OpBuilder &builder) {
+  assert(targetBits % sourceBits == 0);
+  IntegerType targetType = builder.getIntegerType(targetBits);
+  IntegerAttr idxAttr =
+      builder.getIntegerAttr(targetType, targetBits / sourceBits);
+  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
+  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
+  auto srcBitsValue =
+      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
+  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
+  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAlloc
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AllocOp>(
+        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefAssumeAlignment
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefAssumeAlignment final
+    : OpConversionPattern<memref::AssumeAlignmentOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemref().getType()));
+    }
+
+    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
+        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemRefLoad
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
+                                      op.getMemRefType()));
+    }
+
+    if (op.getMemRefType() == newTy)
+      return failure();
+
+    auto loc = op.getLoc();
+    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Value> indices = adaptor.getIndices();
+    assert(indices.size() == sourceRank);
+
+    auto srcElementType = sourceType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, adaptor.getMemref());
+
+    Value newLoad, lastIdx;
+    if (sourceRank == 0) {
+      newLoad = rewriter.create<memref::LoadOp>(
+          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
+
+      lastIdx = stridedMetadata.getOffset();
+    } else {
+      newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
+                                    stridedMetadata, rewriter);
+
+      lastIdx = adaptor.getIndices().back();
+    }
+
+    // Get the offset and shift the bits to the rightmost.
+    // Note, currently only the big-endian is supported.
+    auto castLastIdx =
+        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
+
+    Value BitwidthOffset =
+        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
+    auto bitsLoad =
+        rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
+
+    // Get the corresponding bits. If the arith computation bitwidth equals
+    // to the emulated bitwidth, we apply a mask to extract the low bits.
+    // It is not clear if this case actually happens in practice, but we keep
+    // the operations just in case. Otherwise, if the arith computation bitwidth
+    // is different from the emulated bitwidth we truncate the result.
+    Operation *result;
+    auto resultTy = getTypeConverter()->convertType(oldElementType);
+    if (resultTy == srcElementType) {
+      auto mask = rewriter.create<arith::ConstantOp>(
+          loc, srcElementType,
+          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+
+      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
+    } else {
+      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+    }
+
+    rewriter.replaceOp(op, result->getResult(0));
+    return success();
+  }
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Public Interface Definition
+//===----------------------------------------------------------------------===//
+
+void memref::populateMemRefNarrowTypeEmulationPatterns(
+    arith::NarrowTypeEmulationConverter &typeConverter,
+    RewritePatternSet &patterns) {
+
+  // Populate `memref.*` conversion patterns.
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
+          typeConverter, patterns.getContext());
+}
+
+void memref::populateMemRefNarrowTypeEmulationConversions(
+    arith::NarrowTypeEmulationConverter &typeConverter) {
+  typeConverter.addConversion(
+      [&typeConverter](MemRefType ty) -> std::optional<Type> {
+        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+        if (!intTy)
+          return ty;
+
+        unsigned width = intTy.getWidth();
+        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
+        if (width >= loadStoreWidth)
+          return ty;
+
+        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
+                                          intTy.getSignedness());
+        if (!newElemTy)
+          return std::nullopt;
+
+        return ty.cloneWith(std::nullopt, newElemTy);
+      });
+}
diff --git a/mlir/test/Dialect/Arith/emulate-narrow-type.mlir b/mlir/test/Dialect/Arith/emulate-narrow-type.mlir
new file mode 100644 (file)
index 0000000..7120882
--- /dev/null
@@ -0,0 +1,47 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8" %s | FileCheck %s
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @identity_f32
+// CHECK-SAME:    ([[ARG:%.+]]: f32) -> f32
+// CHECK-NEXT:    return [[ARG]] : f32
+func.func @identity_f32(%a : f32) -> f32 {
+    return %a : f32
+}
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @identity_i32
+// CHECK-SAME:    ([[ARG:%.+]]: vector<2xi32>) -> vector<2xi32>
+// CHECK-NEXT:    return [[ARG]] : vector<2xi32>
+func.func @identity_i32(%a : vector<2xi32>) -> vector<2xi32> {
+    return %a : vector<2xi32>
+}
+
+// CHECK-LABEL: func @identity_scalar
+// CHECK-SAME:     ([[ARG:%.+]]: i8) -> i8
+// CHECK-NEXT:     return [[ARG]] : i8
+func.func @identity_scalar(%x : i4) -> i4 {
+    return %x : i4
+}
+
+// CHECK-LABEL: func @identity_vector
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT:     return [[ARG]] : vector<4xi8>
+func.func @identity_vector(%x : vector<4xi4>) -> vector<4xi4> {
+    return %x : vector<4xi4>
+}
+
+// CHECK-LABEL: func @identity_vector2d
+// CHECK-SAME:     ([[ARG:%.+]]: vector<3x4xi8>) -> vector<3x4xi8>
+// CHECK-NEXT:     return [[ARG]] : vector<3x4xi8>
+func.func @identity_vector2d(%x : vector<3x4xi4>) -> vector<3x4xi4> {
+    return %x : vector<3x4xi4>
+}
+
+// CHECK-LABEL: func @call
+// CHECK-SAME:     ([[ARG:%.+]]: vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT:     [[RES:%.+]] = call @identity_vector([[ARG]]) : (vector<4xi8>) -> vector<4xi8>
+// CHECK-NEXT:     return [[RES]] : vector<4xi8>
+func.func @call(%a : vector<4xi4>) -> vector<4xi4> {
+    %res = func.call @identity_vector(%a) : (vector<4xi4>) -> vector<4xi4>
+    return %res : vector<4xi4>
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-diff-load-compute.mlir
new file mode 100644 (file)
index 0000000..85d4cc1
--- /dev/null
@@ -0,0 +1,107 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=4 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions, i32 is supported.
+// CHECK-LABEL: func @memref_i32
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi32, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi32, 1>
+// CHECK-NEXT:    return
+func.func @memref_i32() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i32
+    %m = memref.alloc() : memref<4xi32, 1>
+    %v = memref.load %m[%c0] : memref<4xi32, 1>
+    memref.store %c1, %m[%c0] : memref<4xi32, 1>
+    return
+}
+
+// -----
+
+// Expect no conversions, f32 is not an integer type.
+// CHECK-LABEL: func @memref_f32
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xf32, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xf32, 1>
+// CHECK-NEXT:    return
+func.func @memref_f32() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1.0 : f32
+    %m = memref.alloc() : memref<4xf32, 1>
+    %v = memref.load %m[%c0] : memref<4xf32, 1>
+    memref.store %c1, %m[%c0] : memref<4xf32, 1>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_zero_rank
+// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<i8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[M]] : memref<i8> -> memref<i8>, index
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[M]][] : memref<i8>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[OFFSET]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT:    return
+func.func @memref_load_i4_zero_rank() {
+    %0 = memref.alloc() : memref<i4>
+    %1 = memref.load %0[] : memref<i4>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME:    (%[[ARG:.*]]: index)
+// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<4xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT:    return
+func.func @memref_load_i4(%arg0: index) {
+    %0 = memref.alloc() : memref<4xi4>
+    %1 = memref.load %0[%arg0] : memref<4xi4>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME:    (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:    memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT:    %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.trunci %[[SHIFT]] : i8 to i4
+// CHECK-NEXT:    return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+    memref.assume_alignment %0, 64 : memref<4x128xi4>
+    %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+    return
+}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type-same-load-compute.mlir
new file mode 100644 (file)
index 0000000..9d63b9d
--- /dev/null
@@ -0,0 +1,72 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8" %s | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2, s3] -> ((s0 * s1 + s2 * s3) floordiv 2)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+
+// Expect no conversions.
+// CHECK-LABEL: func @memref_i8
+// CHECK:         [[M:%.+]] = memref.alloc() : memref<4xi8, 1>
+// CHECK-NEXT:    [[V:%.+]] = memref.load [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT:    memref.store {{%.+}}, [[M]][{{%.+}}] : memref<4xi8, 1>
+// CHECK-NEXT:    return
+func.func @memref_i8() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : i8
+    %m = memref.alloc() : memref<4xi8, 1>
+    %v = memref.load %m[%c0] : memref<4xi8, 1>
+    memref.store %c1, %m[%c0] : memref<4xi8, 1>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4
+// CHECK-SAME:    (%[[ARG:.*]]: index)
+// CHECK-NEXT:    %[[M:.*]]  = memref.alloc() : memref<4xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[M]] : memref<4xi8> -> memref<i8>, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP0]]()[%[[ARG]], %[[STRIDES]]]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[SIZES]]], strides: [%[[STRIDES]]] : memref<i8> to memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<4xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT:    return
+func.func @memref_load_i4(%arg0: index) {
+    %0 = memref.alloc() : memref<4xi4>
+    %1 = memref.load %0[%arg0] : memref<4xi4>
+    return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_load_i4_rank2
+// CHECK-SAME:    (%[[ARG:.*]]: memref<4x128xi8>, %[[ARG0:.*]]: index, %[[ARG1:.*]]: index)
+// CHECK-NEXT:    memref.assume_alignment %[[ARG]], 64 : memref<4x128xi8>
+// CHECK-NEXT:    %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<4x128xi8> -> memref<i8>, index, index, index, index, index
+// CHECK-NEXT:    %[[INDEX:.*]] = affine.apply #[[$MAP2]]()[%[[ARG0]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK-NEXT:    %[[LSIZE:.*]] = affine.apply #[[$MAP3]]()[%[[SIZES]]#0, %[[SIZES]]#1]
+// CHECK-NEXT:    %[[AOFF:.*]] = affine.apply #[[$MAP1]]()[%[[OFFSET]]]
+// CHECK-NEXT:    %[[CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[AOFF]]], sizes: [%[[LSIZE]]], strides: [%[[STRIDES]]#1] : memref<i8> to memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[LOAD:.*]] = memref.load %[[CAST]][%[[INDEX]]] : memref<512xi8, strided<[?], offset: ?>>
+// CHECK-NEXT:    %[[I:.*]] = arith.index_castui %[[ARG1]] : index to i8
+// CHECK-NEXT:    %[[C2:.*]] = arith.constant 2 : i8
+// CHECK-NEXT:    %[[C4:.*]] = arith.constant 4 : i8
+// CHECK-NEXT:    %[[REM:.*]] = arith.remui %[[I]], %[[C2]] : i8
+// CHECK-NEXT:    %[[STEP:.*]] = arith.muli %[[REM]], %[[C4]] : i8
+// CHECK-NEXT:    %[[SHIFT:.*]] = arith.shrsi %[[LOAD]], %[[STEP]] : i8
+// CHECK-NEXT:    %[[MASK:.*]] = arith.constant 15 : i8
+// CHECK-NEXT:    %[[RES:.*]] = arith.andi %[[SHIFT]], %[[MASK]] : i8
+// CHECK-NEXT:    return
+func.func @memref_load_i4_rank2(%0: memref<4x128xi4>, %arg0: index, %arg1: index) {
+    memref.assume_alignment %0, 64 : memref<4x128xi4>
+    %1 = memref.load %0[%arg0,%arg1] : memref<4x128xi4>
+    return
+}
index df3fdac..0498de3 100644 (file)
@@ -1,6 +1,7 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestEmulateNarrowType.cpp
   TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
new file mode 100644 (file)
index 0000000..b1f2308
--- /dev/null
@@ -0,0 +1,118 @@
+//===- TestEmulateNarrowType.cpp - Test Narrow Type Emulation  ------*- c++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestEmulateNarrowTypePass
+    : public PassWrapper<TestEmulateNarrowTypePass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateNarrowTypePass)
+
+  TestEmulateNarrowTypePass() = default;
+  TestEmulateNarrowTypePass(const TestEmulateNarrowTypePass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+                vector::VectorDialect, affine::AffineDialect>();
+  }
+  StringRef getArgument() const final { return "test-emulate-narrow-int"; }
+  StringRef getDescription() const final {
+    return "Function pass to test Narrow Integer Emulation";
+  }
+
+  void runOnOperation() override {
+    if (!llvm::isPowerOf2_32(loadStoreEmulateBitwidth) ||
+        loadStoreEmulateBitwidth < 8) {
+      signalPassFailure();
+      return;
+    }
+
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+
+    arith::NarrowTypeEmulationConverter typeConverter(loadStoreEmulateBitwidth);
+
+    // Convert scalar type.
+    typeConverter.addConversion([this](IntegerType ty) -> std::optional<Type> {
+      unsigned width = ty.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return IntegerType::get(ty.getContext(), arithComputeBitwidth);
+    });
+
+    // Convert vector type.
+    typeConverter.addConversion([this](VectorType ty) -> std::optional<Type> {
+      auto intTy = dyn_cast<IntegerType>(ty.getElementType());
+      if (!intTy)
+        return ty;
+
+      unsigned width = intTy.getWidth();
+      if (width >= arithComputeBitwidth)
+        return ty;
+
+      return VectorType::get(
+          to_vector(ty.getShape()),
+          IntegerType::get(ty.getContext(), arithComputeBitwidth));
+    });
+
+    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+    ConversionTarget target(*ctx);
+    target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
+      return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
+    });
+    auto opLegalCallback = [&typeConverter](Operation *op) {
+      return typeConverter.isLegal(op);
+    };
+    target.addDynamicallyLegalOp<func::CallOp, func::ReturnOp>(opLegalCallback);
+    target.addDynamicallyLegalDialect<
+        arith::ArithDialect, vector::VectorDialect, memref::MemRefDialect,
+        affine::AffineDialect>(
+        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
+
+    RewritePatternSet patterns(ctx);
+
+    arith::populateArithNarrowTypeEmulationPatterns(typeConverter, patterns);
+    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
+
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  Option<unsigned> loadStoreEmulateBitwidth{
+      *this, "memref-load-bitwidth",
+      llvm::cl::desc("memref load/store emulation bit width"),
+      llvm::cl::init(8)};
+
+  Option<unsigned> arithComputeBitwidth{
+      *this, "arith-compute-bitwidth",
+      llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestEmulateNarrowTypePass() {
+  PassRegistration<TestEmulateNarrowTypePass>();
+}
+} // namespace mlir::test
index d75b54e..5b95663 100644 (file)
@@ -87,6 +87,7 @@ void registerTestDiagnosticsPass();
 void registerTestDialectConversionPasses();
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
+void registerTestEmulateNarrowTypePass();
 void registerTestExpandMathPass();
 void registerTestFooAnalysisPass();
 void registerTestComposeSubView();
@@ -205,6 +206,7 @@ void registerTestPasses() {
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
+  mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestExpandMathPass();
   mlir::test::registerTestFooAnalysisPass();
   mlir::test::registerTestComposeSubView();