From: Cullen Rhodes
Date: Tue, 25 Jul 2023 08:28:36 +0000 (+0000)
Subject: [mlir][ArmSME] Add tile load op and extend tile store tile size support
X-Git-Tag: upstream/17.0.6~429
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ca9a3354d04b15366088d7831b40f891e3d77b95;p=platform%2Fupstream%2Fllvm.git

[mlir][ArmSME] Add tile load op and extend tile store tile size support

This extends the existing 'arm_sme.tile_store' op to support all tile sizes
and adds a new op 'arm_sme.tile_load', as well as lowerings from
vector -> custom ops and custom ops -> intrinsics. Currently there's no
lowering for i128.

Depends on D154867

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D155306
---

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 09f8bfb..caa6e38 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -224,21 +224,74 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+def TileLoadOp : ArmSME_Op<"tile_load"> {
+  let summary = "Tile load operation";
+  let description = [{
+    Loads a 2D SME "virtual tile" from memory defined by a base and indices,
+    with the shape defined by the 2D scalable vector type of the result tile.
+    The slice of memory must be contiguous. The memref must be either rank 1
+    or rank 2 with dynamic dimensions, since the operation is scalable, and
+    the element type must be a scalar that matches the element type of the
+    result.
+
+    Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    ```
+
+    Example 2: Load an FP 32-bit element ZA tile from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
+
+    Example 3: Load a 128-bit element ZA tile from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    ```
+  }];
+  let arguments = (ins
+    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
+    Variadic<Index>:$indices);
+  let results = (outs SMETile:$result);
+
+  let extraClassDeclaration = [{
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getBase().getType());
+    }
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+
+  let assemblyFormat =
+      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+}
+
 def TileStoreOp : ArmSME_Op<"tile_store"> {
   let summary = "Tile store operation";
   let description = [{
-    Store a 2D SME "virtual tile" to memory.
-
-    NOTE: At the moment it is assumed that the element type is `i8` and that
-    there's only one "virtual tile".
+    Stores a 2D SME "virtual tile" to memory defined by a base and indices,
+    with the shape defined by the 2D scalable vector type of the tile being
+    stored. The slice of memory must be contiguous. The memref must be either
+    rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
+    and the element type must be a scalar that matches the element type of
+    the vector being stored.
+
+    Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    ```
 
-    Example:
+    Example 2: Store an FP 32-bit element ZA tile to memory.
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    ```
+    Example 3: Store a 128-bit element ZA tile to memory.
```mlir - arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref + arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref ``` }]; - let arguments = (ins nxnxv16i8:$valueToStore, + let arguments = (ins SMETile:$valueToStore, Arg:$base, Variadic:$indices); let extraClassDeclaration = [{ @@ -304,7 +357,7 @@ def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">; class ArmSME_IntrLoadOp : ArmSME_IntrOp, Arguments<(ins Arg, - Arg, + Arg, Arg, Arg)>; diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h new file mode 100644 index 0000000..554b9f1 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h @@ -0,0 +1,38 @@ +//===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes for various utilities for the ArmSME +// dialect. These are not passes by themselves but are used either by passes, +// optimization sequences, or in turn by other transformation utilities. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_ +#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_ + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" + +namespace mlir { +namespace arm_sme { + +/// Return minimum number of elements for the given element `type` in +/// a vector of SVL bits. +unsigned getSMETileSliceMinNumElts(Type type); + +/// Returns true if `type` is a valid element type for an SME tile or false +/// otherwise. +bool isValidSMETileElementType(Type type); + +/// Returns true if `vType` is a valid vector type for an SME tile or false +/// otherwise. +bool isValidSMETileVectorType(VectorType vType); + +} // namespace arm_sme +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_ diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt index b062f65e..715816a 100644 --- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt @@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME LINK_LIBS PUBLIC MLIRArmSMEDialect + MLIRArmSMEUtils MLIRLLVMCommonConversion ) diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp index cd0d99c..4106b04 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Casting.h" @@ -76,9 +77,42 @@ struct TransferWriteToArmSMELowering } }; +/// Conversion pattern for vector.load. 
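+///
+/// A sketch of the rewrite this pattern performs (shapes are illustrative and
+/// assume a 32-bit element type; any valid SME tile vector type is handled):
+///
+///  BEFORE:
+///  ```mlir
+///  %tile = vector.load %base[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```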
+struct VectorLoadToArmSMELowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::LoadOp load, + PatternRewriter &rewriter) const override { + if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) + return failure(); + + rewriter.replaceOpWithNewOp( + load, load.getVectorType(), load.getBase(), load.getIndices()); + + return success(); + } +}; + +/// Conversion pattern for vector.store. +struct VectorStoreToArmSMELowering : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::StoreOp store, + PatternRewriter &rewriter) const override { + if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) + return failure(); + + rewriter.replaceOpWithNewOp( + store, store.getValueToStore(), store.getBase(), store.getIndices()); + + return success(); + } +}; + } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { - patterns.add(&ctx); + patterns.add(&ctx); } diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt index 9f57627..31167e6 100644 --- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Transforms) +add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt index 991beae..8f485db 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms LINK_LIBS PUBLIC MLIRArmSMEDialect + MLIRArmSMEUtils MLIRFuncDialect MLIRLLVMCommonConversion MLIRVectorDialect diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp index e837432..b3a747a 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -19,7 +20,6 @@ using namespace mlir; using namespace mlir::arm_sme; -static constexpr unsigned kMinNumElts = 16; static constexpr unsigned kZeroZAMask = 255; namespace { @@ -50,7 +50,6 @@ struct DisableZAPattern : public OpRewritePattern { return success(); } }; -} // namespace /// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return /// value. The latter is a nop, which should be folded away (e.g. during @@ -95,68 +94,285 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern { } }; -/// Lower 'arm_sme.store_tile' to a loop over the rows of ZA and store each row -/// using 'arm_sme.intr.str'. +/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or +/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar +/// integer, to an i32 that can be passed as the `tile` parameter to the SME +/// intrinsics. Or returns `tile` if already i32. 
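+///
+/// For example, the i8 tile ID of an 8-bit element tile is zero-extended
+/// (a sketch):
+///
+///  ```mlir
+///  %tile_id = arm_sme.get_tile_id : i8
+///  %tile_id_i32 = arith.extui %tile_id : i8 to i32
+///  ```
+///
+/// whereas a 64-bit tile ID is truncated to i32 with `arith.trunci`.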
+Value castTileIDToI32(Value tile, Location loc, + ConversionPatternRewriter &rewriter) { + assert((isa( + tile.getDefiningOp())) && + "expected ArmSME GetTileID or CastVectorToTile op!"); + unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth(); + if (tileElementWidth < 32) + return rewriter.create(loc, rewriter.getI32Type(), tile); + if (tileElementWidth > 32) + return rewriter.create(loc, rewriter.getI32Type(), tile); + return tile; +} + +/// Returns the following +/// * for rank 2 memrefs `tileSliceIndex`, since `getStridedElementPtr` does +/// the arithmetic. +/// * for rank 1 memrefs `tileSliceIndex * tileSliceNumElts`, adjusting the +/// index by the number of elements in a vector of SVL bits. +/// * otherwise throws an unreachable error. +Value getTileSlicePtrIndex(unsigned rank, Value tileSliceIndex, + Value tileSliceNumElts, Location loc, + ConversionPatternRewriter &rewriter) { + assert((rank == 1 || rank == 2) && "memref has unexpected rank!"); + + auto tileSliceIndexI64 = rewriter.create( + loc, rewriter.getI64Type(), tileSliceIndex); + + if (rank == 1) { + auto tileSliceNumEltsI64 = rewriter.create( + loc, rewriter.getI64Type(), tileSliceNumElts); + return rewriter.create(loc, tileSliceIndexI64, + tileSliceNumEltsI64); + } + + if (rank == 2) + return tileSliceIndexI64; + + llvm_unreachable("memref has unexpected rank!"); +} + +/// Conversion pattern for `arm_sme.tile_load` to SME intrinsics. +/// +/// Lower `arm_sme.tile_load` to a loop over the rows of ZA and load each row +/// using `arm_sme.intr.ld1*.horiz`. +/// +/// BEFORE: +/// ```mlir +/// %tile = arm_sme.tile_load %base[%c0, %c0] : +/// memref, vector<[4]x[4]xi32> +/// ``` +/// +/// AFTER: +/// ```mlir +/// %tile_id = arm_sme.get_tile_id : i32 +/// %vscale = vector.vscale +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %min_svl_s = arith.constant 4 : index +/// %svl_s = arith.muli %min_svl_s, %vscale : index +/// scf.for %tile_slice = %c0 to %svl_s step %c1 { +/// // (...) +/// "arm_sme.intr.ld1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) : +/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +/// } +/// %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> +/// ``` +struct TileLoadToArmSMELowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(arm_sme::TileLoadOp tileLoadOp, + arm_sme::TileLoadOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = tileLoadOp.getLoc(); + auto tileType = tileLoadOp.getVectorType(); + auto tileElementType = tileType.getElementType(); + unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + + // Create 'arm_sme.get_tile_id' op. + auto tile = rewriter.create( + loc, rewriter.getIntegerType(tileElementWidth)); + + // Create a loop that loads each ZA tile slice from memory. + auto step = rewriter.create(loc, 1); + auto minTileSlices = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + auto lowerBound = rewriter.create(loc, 0); + // This describes both the number of ZA tile slices and the number of + // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, + // ..., SVL_Q). 
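+    // E.g. for a 32-bit element tile (SVL_S) the minimum is 4; with
+    // vscale = 4 (a 512-bit SVL) this gives 4 * 4 = 16 tile slices of 16
+    // elements each.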
+ auto numTileSlices = + rewriter.create(loc, minTileSlices, vscale); + auto forOp = + rewriter.create(loc, lowerBound, numTileSlices, step); + rewriter.setInsertionPointToStart(forOp.getBody()); + + // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice. + auto memRefType = tileLoadOp.getMemRefType(); + auto tileSlice = forOp.getInductionVar(); + // TODO: The 'indices' argument for the 'base' memref is currently ignored, + // 'tileSliceIndex' should be added to 'indices[0]'. + Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice, + numTileSlices, loc, rewriter); + Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(), + {tileSliceIndex}, rewriter); + + // Cast tile slice to i32 for intrinsic. + auto tileSliceI32 = rewriter.create( + loc, rewriter.getI32Type(), tileSlice); + + // Create all active predicate mask. + auto one = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), + /*scalableDims=*/{true}); + auto allActiveMask = rewriter.create(loc, predTy, one); + + auto tileI32 = castTileIDToI32(tile, loc, rewriter); + switch (tileElementWidth) { + default: + llvm_unreachable("unexpected element type!"); + case 8: + rewriter.create(loc, allActiveMask, ptr, + tileI32, tileSliceI32); + break; + case 16: + rewriter.create(loc, allActiveMask, ptr, + tileI32, tileSliceI32); + break; + case 32: + rewriter.create(loc, allActiveMask, ptr, + tileI32, tileSliceI32); + break; + case 64: + rewriter.create(loc, allActiveMask, ptr, + tileI32, tileSliceI32); + break; + } + + rewriter.setInsertionPointAfter(forOp); + + // The load intrinsics have no result, replace 'arm_sme.tile_load' with + // 'arm_sme.cast_tile_to_vector' to preserve dataflow. + rewriter.replaceOpWithNewOp(tileLoadOp, tileType, + tile); + + return success(); + } +}; + +/// Conversion pattern for `arm_sme.tile_store` to SME intrinsics. +/// +/// Lower `arm_sme.tile_store` to a loop over the rows of ZA and store each row +/// using `arm_sme.intr.st1*.horiz`. /// /// BEFORE: /// ```mlir -/// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref, -/// vector<[16]x[16]xi8 +/// arm_sme.tile_store %value, %base[%c0, %c0] : memref, +/// vector<[4]x[4]xi32 /// ``` /// /// AFTER: /// ```mlir -/// %vscale = "llvm.intr.vscale"() : () -> index -/// %c0 = arith.constant 0 : index -/// %c1 = arith.constant 1 : index -/// %c16 = arith.constant 16 : index -/// %vec_size = arith.muli %c16, %vscale : index -/// scf.for %row_idx = %c0 to %vec_size step %c1 { -/// // (...) -/// "arm_sme.intr.str"(%row_idx, %addr) : (i32, !llvm.ptr) -> () +/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32 +/// %vscale = vector.vscale +/// %c0 = arith.constant 0 : index +/// %c1 = arith.constant 1 : index +/// %min_svl_s = arith.constant 4 : index +/// %svl_s = arith.muli %min_svl_s, %vscale : index +/// scf.for %tile_slice = %c0 to %svl_s step %c1 { +/// // (...) 
+/// "arm_sme.intr.st1w.horiz"(%ptrue_s, %ptr, %tile_id, %tile_slice) : +/// (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () +/// } /// ``` -struct TileStoreOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct TileStoreToArmSMELowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileStoreOp store, OpAdaptor adaptor, + matchAndRewrite(arm_sme::TileStoreOp tileStoreOp, + arm_sme::TileStoreOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = store.getLoc(); + auto loc = tileStoreOp.getLoc(); + auto tileType = tileStoreOp.getVectorType(); + auto tileElementType = tileType.getElementType(); + unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth(); + + // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector + // being stored. + auto tile = rewriter.create( + loc, rewriter.getIntegerType(tileElementWidth), + tileStoreOp.getValueToStore()); - // Create loop that iterates from 0 to SVLB-1 inclusive (the number of - // vectors in ZA) and stores each ZA vector to memory. + // Create a loop that stores each ZA tile slice to memory. auto step = rewriter.create(loc, 1); - auto minElems = rewriter.create(loc, kMinNumElts); + auto minTileSlices = rewriter.create( + loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); auto vscale = rewriter.create(loc, rewriter.getIndexType()); auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, minElems, vscale); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + // This describes both the number of ZA tile slices and the number of + // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H, + // ..., SVL_Q). + auto numTileSlices = + rewriter.create(loc, minTileSlices, vscale); + auto forOp = + rewriter.create(loc, lowerBound, numTileSlices, step); rewriter.setInsertionPointToStart(forOp.getBody()); - // Create 'arm_sme.intr.str' intrinsic to store ZA vector. - auto vnumI64 = rewriter.create( - loc, rewriter.getI64Type(), forOp.getInductionVar()); - auto offset = - rewriter.create(loc, rewriter.getI64Type(), 0); - Value ptr = - getStridedElementPtr(loc, store.getMemRefType(), adaptor.getBase(), - ValueRange{vnumI64, offset}, rewriter); - auto vnumI32 = rewriter.create( - loc, rewriter.getI32Type(), forOp.getInductionVar()); - rewriter.create(loc, vnumI32, ptr); - - rewriter.eraseOp(store); + // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice. + auto memRefType = tileStoreOp.getMemRefType(); + auto tileSlice = forOp.getInductionVar(); + // TODO: The 'indices' argument for the 'base' memref is currently ignored, + // 'tileSliceIndex' should be added to 'indices[0]'. + Value tileSliceIndex = getTileSlicePtrIndex(memRefType.getRank(), tileSlice, + numTileSlices, loc, rewriter); + Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.getBase(), + {tileSliceIndex}, rewriter); + + // Cast tile slice to i32 for intrinsic. + auto tileSliceI32 = rewriter.create( + loc, rewriter.getI32Type(), tileSlice); + + // Create all active predicate mask. 
+ auto one = rewriter.create( + loc, rewriter.getI1Type(), + rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); + auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), + /*scalableDims=*/{true}); + auto allActiveMask = rewriter.create(loc, predTy, one); + + Value tileI32 = castTileIDToI32(tile, loc, rewriter); + switch (tileElementWidth) { + default: + llvm_unreachable("unexpected element type!"); + case 8: + rewriter.replaceOpWithNewOp( + tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32); + break; + case 16: + rewriter.replaceOpWithNewOp( + tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32); + break; + case 32: + rewriter.replaceOpWithNewOp( + tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32); + break; + case 64: + rewriter.replaceOpWithNewOp( + tileStoreOp, allActiveMask, ptr, tileI32, tileSliceI32); + break; + } + return success(); } }; +} // namespace + void mlir::configureArmSMELegalizeForExportTarget( LLVMConversionTarget &target) { - target.addLegalOp(); + target.addLegalOp< + scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector, + arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero, + arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz, + arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz, + arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_st1b_horiz, + arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, + arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_za_enable, + arm_sme::aarch64_sme_za_disable>(); target.addLegalOp(); // Mark 'func.func' ops as legal if either: @@ -187,5 +403,6 @@ void mlir::configureArmSMELegalizeForExportTarget( void mlir::populateArmSMELegalizeForLLVMExportPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.add(patterns.getContext()); - patterns.add(converter); + patterns.add(converter); } diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt new file mode 100644 index 0000000..da8517a --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library(MLIRArmSMEUtils + Utils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils + + LINK_LIBS PUBLIC + MLIRArmSMEDialect + MLIRDialect + MLIRIR + ) diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp new file mode 100644 index 0000000..a5908a5 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp @@ -0,0 +1,48 @@ +//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for the ArmSME dialect. 
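+//
+// The main entry points are getSMETileSliceMinNumElts(), which returns the
+// minimum number of elements in a ZA tile slice for a given element type
+// (128 / element bit-width), and isValidSMETileVectorType(), which checks
+// whether a 2-D scalable vector type matches a legal SME tile shape.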
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/Utils/Utils.h" + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +static constexpr unsigned MinStreamingVectorLengthInBits = 128; + +unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) { + assert(isValidSMETileElementType(type) && "invalid tile type!"); + return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth(); +} + +bool mlir::arm_sme::isValidSMETileElementType(Type type) { + // TODO: add support for i128. + return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) || + type.isInteger(64) || type.isF16() || type.isBF16() || type.isF32() || + type.isF64(); +} + +bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) { + if ((vType.getRank() != 2) && vType.allDimsScalable()) + return false; + + // TODO: add support for i128. + auto elemType = vType.getElementType(); + if (!isValidSMETileElementType(elemType)) + return false; + + unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType); + if (vType.getShape() != ArrayRef({minNumElts, minNumElts})) + return false; + + return true; +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 5c1f3a9..66be8a2 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -1,7 +1,5 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s -// ----- - func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> { // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8> %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8> @@ -194,6 +192,87 @@ func.func @arm_sme_zero() -> () { // ----- +func.func @arm_sme_tile_load_i8(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[16]x[16]xi8> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[16]x[16]xi8> + return +} + +// ----- + +func.func @arm_sme_tile_load_i16(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[8]x[8]xi16> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[8]x[8]xi16> + return +} + +// ----- + +func.func @arm_sme_tile_load_i32(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[4]x[4]xi32> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[4]x[4]xi32> + return +} + +// ----- + +func.func @arm_sme_tile_load_i64(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[2]x[2]xi64> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[2]x[2]xi64> + return +} + +// ----- + +func.func @arm_sme_tile_load_i128(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[1]x[1]xi128> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[1]x[1]xi128> + return +} + +// ----- + +func.func @arm_sme_tile_load_f16(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[8]x[8]xf16> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[8]x[8]xf16> + return +} + +// ----- + +func.func @arm_sme_tile_load_bf16(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[8]x[8]xbf16> + %c0 = arith.constant 0 : index + %tile = 
arm_sme.tile_load %src[%c0, %c0] : memref, vector<[8]x[8]xbf16> + return +} + +// ----- + +func.func @arm_sme_tile_load_f32(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[4]x[4]xf32> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[4]x[4]xf32> + return +} + +// ----- + +func.func @arm_sme_tile_load_f64(%src : memref) -> () { + // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[2]x[2]xf64> + %c0 = arith.constant 0 : index + %tile = arm_sme.tile_load %src[%c0, %c0] : memref, vector<[2]x[2]xf64> + return +} + +// ----- + func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref) -> () { // CHECK: arm_sme.tile_store {{.*}} : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir index cb52ab5..9c76a4c 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -1,28 +1,30 @@ // RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s -// CHECK-LABEL: @transfer_write_2d_zero_i8 -// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-LABEL: @transfer_write_2d_zero_i8( +// CHECK-SAME: %[[ARG0:.*]]: memref) // CHECK-DAG: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32 // CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 -// CHECK-DAG: %[[CAST_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-DAG: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index // CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index -// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { -// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 -// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 -// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr 
%[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 -// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32 -// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () func.func @transfer_write_2d_zero_i8(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi8> @@ -30,3 +32,329 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref) { return } +// ----- + +// CHECK-LABEL: @vector_load_i8( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> +func.func @vector_load_i8(%arg0 : memref) -> vector<[16]x[16]xi8> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[16]x[16]xi8> + return %tile : vector<[16]x[16]xi8> +} + +// ----- + +// CHECK-LABEL: @vector_load_i8_from_rank_1_memref( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = 
builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[SVL_B_I64:.*]] = arith.index_castui %[[SVL_B]] : index to i64 +// CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE_I64]], %[[SVL_B_I64]] : i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> +// CHECK-NEXT: return %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> +func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref) -> vector<[16]x[16]xi8> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0] : memref, vector<[16]x[16]xi8> + return %tile : vector<[16]x[16]xi8> +} + + +// ----- + +// CHECK-LABEL: @vector_load_i16( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 +// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index +// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32 +// CHECK: arm_sme.intr.ld1h.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16> +func.func @vector_load_i16(%arg0 : memref) -> vector<[8]x[8]xi16> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[8]x[8]xi16> + return %tile : vector<[8]x[8]xi16> +} + +// ----- + +// CHECK-LABEL: @vector_load_i32( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 +// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index +// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index +// CHECK-NOT: arith.extui %[[TILE_ID]] +// CHECK-NOT: arith.trunci %[[TILE_ID]] +// CHECK: arm_sme.intr.ld1w.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32> +func.func @vector_load_i32(%arg0 : memref) -> vector<[4]x[4]xi32> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[4]x[4]xi32> + return %tile : vector<[4]x[4]xi32> +} + 
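+// Note: for 32-bit element types the tile ID returned by 'arm_sme.get_tile_id'
+// is already an i32, so no 'arith.extui'/'arith.trunci' is expected (hence the
+// CHECK-NOTs above).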
+// ----- + +// CHECK-LABEL: @vector_load_i64( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64 +// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index +// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32 +// CHECK: arm_sme.intr.ld1d.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64> +func.func @vector_load_i64(%arg0 : memref) -> vector<[2]x[2]xi64> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[2]x[2]xi64> + return %tile : vector<[2]x[2]xi64> +} + +// ----- + +// CHECK-LABEL: @vector_load_f16( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 +// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index +// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32 +// CHECK: arm_sme.intr.ld1h.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16> +func.func @vector_load_f16(%arg0 : memref) -> vector<[8]x[8]xf16> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[8]x[8]xf16> + return %tile : vector<[8]x[8]xf16> +} + +// ----- + +// CHECK-LABEL: @vector_load_bf16( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16 +// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index +// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32 +// CHECK: arm_sme.intr.ld1h.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16> +func.func @vector_load_bf16(%arg0 : memref) -> vector<[8]x[8]xbf16> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[8]x[8]xbf16> + return %tile : vector<[8]x[8]xbf16> +} + +// ----- + +// CHECK-LABEL: @vector_load_f32( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32 +// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index +// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index +// CHECK-NOT: arith.extui %[[TILE_ID]] +// CHECK-NOT: arith.trunci %[[TILE_ID]] +// CHECK: arm_sme.intr.ld1w.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32> +func.func @vector_load_f32(%arg0 : memref) -> vector<[4]x[4]xf32> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[4]x[4]xf32> + return %tile : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @vector_load_f64( +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64 +// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index +// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32 +// CHECK: arm_sme.intr.ld1d.horiz +// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64> +func.func @vector_load_f64(%arg0 : memref) -> vector<[2]x[2]xf64> { + %c0 = arith.constant 0 : index + %tile = vector.load %arg0[%c0, %c0] : memref, vector<[2]x[2]xf64> + return %tile : vector<[2]x[2]xf64> +} + +// ----- + +// CHECK-LABEL: @vector_store_i8( +// 
CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8 +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_SVL_B:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[MIN_SVL_B]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0_1]] to %[[SVL_B]] step %[[C1]] { +// CHECK-NEXT: %[[TILE_SLICE_I64:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32 +// CHECK-NEXT: %[[TRUE:.*]] = arith.constant true +// CHECK-NEXT: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> +// CHECK-NEXT: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32 +// CHECK-NEXT: "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () +// CHECK-NEXT: } +// CHECK-NEXT: return +func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[16]x[16]xi8> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_i16( +// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xi16>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16 +// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index +// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32 +// CHECK: arm_sme.intr.st1h.horiz +func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[8]x[8]xi16> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_i32( +// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32 +// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index +// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index +// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]] +// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]] +// CHECK: arm_sme.intr.st1w.horiz +func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[4]x[4]xi32> + return +} + +// 
----- + +// CHECK-LABEL: @vector_store_i64( +// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xi64>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64 +// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index +// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32 +// CHECK: arm_sme.intr.st1d.horiz +func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[2]x[2]xi64> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_f16( +// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xf16>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16 +// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index +// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32 +// CHECK: arm_sme.intr.st1h.horiz +func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[8]x[8]xf16> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_bf16( +// CHECK-SAME: %[[TILE:.*]]: vector<[8]x[8]xbf16>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16 +// CHECK: %[[MIN_SVL_H:.*]] = arith.constant 8 : index +// CHECK: %[[SVL_H:.*]] = arith.muli %[[MIN_SVL_H]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32 +// CHECK: arm_sme.intr.st1h.horiz +func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[8]x[8]xbf16> + return +} +// ----- + +// CHECK-LABEL: @vector_store_f32( +// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32 +// CHECK: %[[MIN_SVL_S:.*]] = arith.constant 4 : index +// CHECK: %[[SVL_S:.*]] = arith.muli %[[MIN_SVL_S]], %{{.*}} : index +// CHECK-NOT: arith.extui %[[CAST_VECTOR_TO_TILE]] +// CHECK-NOT: arith.trunci %[[CAST_VECTOR_TO_TILE]] +// CHECK: arm_sme.intr.st1w.horiz +func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[4]x[4]xf32> + return +} + +// ----- + +// CHECK-LABEL: @vector_store_f64( +// CHECK-SAME: %[[TILE:.*]]: vector<[2]x[2]xf64>, +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64 +// CHECK: %[[MIN_SVL_D:.*]] = arith.constant 2 : index +// CHECK: %[[SVL_D:.*]] = arith.muli %[[MIN_SVL_D]], %{{.*}} : index +// CHECK: %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32 +// CHECK: arm_sme.intr.st1d.horiz +func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref) { + %c0 = arith.constant 0 : index + vector.store %tile, %arg0[%c0, %c0] : memref, vector<[2]x[2]xf64> + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir new file mode 100644 index 0000000..f0db752 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -0,0 +1,192 @@ +// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \ +// RUN: -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \ +// RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \ +// RUN: --entry-function=za0_d_f64 \ +// RUN: --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s --check-prefix=CHECK-ZA0_D + +// Integration test demonstrating load/store to/from SME ZA tile. + +llvm.func @printF64(f64) +llvm.func @printOpen() +llvm.func @printClose() +llvm.func @printComma() +llvm.func @printNewline() + +func.func @za0_d_f64() -> i32 { + %c0 = arith.constant 0 : index + %c0_f64 = arith.constant 0.0 : f64 + %c1_f64 = arith.constant 1.0 : f64 + %c1_index = arith.constant 1 : index + + %min_elts_d = arith.constant 2 : index + %vscale = vector.vscale + + // "svl" refers to the Streaming Vector Length and "svl_d" the number of + // 64-bit elements in a vector of SVL bits. + %svl_d = arith.muli %min_elts_d, %vscale : index + + // Allocate "mem1" and fill each "row" with row number. + // + // For example, assuming an SVL of 256-bits: + // + // 0.1, 0.1, 0.1, 0.1 + // 1.1, 1.1, 1.1, 1.1 + // 2.1, 2.1, 2.1, 2.1 + // 3.1, 3.1, 3.1, 3.1 + // + %tilesize = arith.muli %svl_d, %svl_d : index + %mem1 = memref.alloca(%tilesize) : memref + %init_0 = arith.constant 0.1 : f64 + scf.for %i = %c0 to %tilesize step %svl_d iter_args(%val = %init_0) -> (f64) { + %splat_val = vector.broadcast %val : f64 to vector<[2]xf64> + vector.store %splat_val, %mem1[%i] : memref, vector<[2]xf64> + %val_next = arith.addf %val, %c1_f64 : f64 + scf.yield %val_next : f64 + } + + // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least + // 2x2xi64. 
+ // + // CHECK-ZA0_D: ( 0.1, 0.1 + // CHECK-ZA0_D-NEXT: ( 1.1, 1.1 + scf.for %i = %c0 to %tilesize step %svl_d { + %tileslice = vector.load %mem1[%i] : memref, vector<[2]xf64> + + llvm.call @printOpen() : () -> () + scf.for %i2 = %c0 to %svl_d step %c1_index { + %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64> + llvm.call @printF64(%elem) : (f64) -> () + %last_i = arith.subi %svl_d, %c1_index : index + %isNotLastIter = arith.cmpi ult, %i2, %last_i : index + scf.if %isNotLastIter { + llvm.call @printComma() : () -> () + } + } + llvm.call @printClose() : () -> () + llvm.call @printNewline() : () -> () + } + + // Load ZA0.D from "mem1" + %za0_d = vector.load %mem1[%c0] : memref, vector<[2]x[2]xf64> + + // Allocate "mem2" to store ZA0.D to + %mem2 = memref.alloca(%tilesize) : memref + + // Zero "mem2" + scf.for %i = %c0 to %tilesize step %c1_index { + memref.store %c0_f64, %mem2[%i] : memref + } + + // Verify "mem2" is zeroed by doing an add reduction with initial value of + // zero + %init_0_f64 = arith.constant 0.0 : f64 + %add_reduce = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_0_f64) -> (f64) { + %row = vector.load %mem2[%vnum] : memref, vector<[2]xf64> + + %inner_add_reduce = scf.for %offset = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_0_f64) -> (f64) { + %t = vector.extractelement %row[%offset : index] : vector<[2]xf64> + %inner_add_reduce_next = arith.addf %inner_iter, %t : f64 + scf.yield %inner_add_reduce_next : f64 + } + + %add_reduce_next = arith.addf %iter, %inner_add_reduce : f64 + scf.yield %add_reduce_next : f64 + } + + // CHECK-ZA0_D: 0 + vector.print %add_reduce : f64 + + // Dump zeroed "mem2". The smallest SVL is 128-bits so the tile will be at + // least 2x2xi64. + // + // CHECK-ZA0_D-NEXT: ( 0, 0 + // CHECK-ZA0_D-NEXT: ( 0, 0 + scf.for %i = %c0 to %tilesize step %svl_d { + %tileslice = vector.load %mem2[%i] : memref, vector<[2]xf64> + + llvm.call @printOpen() : () -> () + scf.for %i2 = %c0 to %svl_d step %c1_index { + %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64> + llvm.call @printF64(%elem) : (f64) -> () + %last_i = arith.subi %svl_d, %c1_index : index + %isNotLastIter = arith.cmpi ult, %i2, %last_i : index + scf.if %isNotLastIter { + llvm.call @printComma() : () -> () + } + } + llvm.call @printClose() : () -> () + llvm.call @printNewline() : () -> () + } + + // Verify "mem1" != "mem2" + %init_1 = arith.constant 1 : i64 + %mul_reduce_0 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) { + %row_1 = vector.load %mem1[%vnum] : memref, vector<[2]xf64> + %row_2 = vector.load %mem2[%vnum] : memref, vector<[2]xf64> + %cmp = arith.cmpf one, %row_1, %row_2 : vector<[2]xf64> + + %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { + %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t_i64 = arith.extui %t : i1 to i64 + %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 + scf.yield %inner_mul_reduce_next : i64 + } + + %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64 + scf.yield %mul_reduce_next : i64 + } + + // CHECK-ZA0_D: 1 + vector.print %mul_reduce_0 : i64 + + // Store ZA0.D to "mem2" + vector.store %za0_d, %mem2[%c0] : memref, vector<[2]x[2]xf64> + + // Verify "mem1" == "mem2" + %mul_reduce_1 = scf.for %vnum = %c0 to %tilesize step %svl_d iter_args(%iter = %init_1) -> (i64) { + %row_1 = vector.load %mem1[%vnum] : memref, vector<[2]xf64> + %row_2 = vector.load 
%mem2[%vnum] : memref, vector<[2]xf64> + %cmp = arith.cmpf oeq, %row_1, %row_2 : vector<[2]xf64> + + %inner_mul_reduce = scf.for %i = %c0 to %svl_d step %c1_index iter_args(%inner_iter = %init_1) -> (i64) { + %t = vector.extractelement %cmp[%i : index] : vector<[2]xi1> + %t_i64 = arith.extui %t : i1 to i64 + %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64 + scf.yield %inner_mul_reduce_next : i64 + } + + %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64 + scf.yield %mul_reduce_next : i64 + } + + // CHECK-ZA0_D-NEXT: 1 + vector.print %mul_reduce_1 : i64 + + // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least + // 2x2xi64. + // + // CHECK-ZA0_D-NEXT: ( 0.1, 0.1 + // CHECK-ZA0_D-NEXT: ( 1.1, 1.1 + scf.for %i = %c0 to %tilesize step %svl_d { + %tileslice = vector.load %mem2[%i] : memref, vector<[2]xf64> + + llvm.call @printOpen() : () -> () + scf.for %i2 = %c0 to %svl_d step %c1_index { + %elem = vector.extractelement %tileslice[%i2 : index] : vector<[2]xf64> + llvm.call @printF64(%elem) : (f64) -> () + %last_i = arith.subi %svl_d, %c1_index : index + %isNotLastIter = arith.cmpi ult, %i2, %last_i : index + scf.if %isNotLastIter { + llvm.call @printComma() : () -> () + } + } + llvm.call @printClose() : () -> () + llvm.call @printNewline() : () -> () + } + + %c0_i32 = arith.constant 0 : i32 + return %c0_i32 : i32 +}