From 65305aeab99ad8ea09dd85e28a41c657152a08fb Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Fri, 16 Jun 2023 09:27:20 +0000 Subject: [PATCH] [mlir][ArmSME] Insert intrinsics to enable/disable ZA This patch adds two LLVM intrinsics to the ArmSME dialect: * llvm.aarch64.sme.za.enable * llvm.aarch64.sme.za.disable for enabling the ZA storage array [1], as well as patterns for inserting them during legalization to LLVM at the start and end of functions if the function has the 'arm_za' attribute (D152695). In the future ZA should probably be automatically enabled/disabled when lowering from vector to SME, but this should be sufficient for now at least until we have patterns lowering to SME instructions that use ZA. N.B. The backend function attribute 'aarch64_pstate_za_new' can be used manage ZA state (as was originally tried in D152694), but it emits calls to the following SME support routines [2] for the lazy-save mechanism [3]: * __arm_tpidr2_restore * __arm_tpidr2_save These will soon be added to compiler-rt but there's currently no public implementation, and using this attribute would introduce an MLIR dependency on compiler-rt. Furthermore, this mechanism is for routines with ZA enabled calling other routines with it also enabled. We can choose not to enable ZA in the compiler when this is case. Depends on D152695 [1] https://developer.arm.com/documentation/ddi0616/aa [2] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#sme-support-routines [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#the-za-lazy-saving-scheme Reviewed By: awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D153050 --- mlir/include/mlir/Conversion/Passes.td | 4 ++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 3 + .../mlir/Dialect/ArmSME/Transforms/Transforms.h | 29 ++++++++ mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt | 2 + .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 8 +++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt | 3 + .../ArmSME/Transforms/LegalizeForLLVMExport.cpp | 78 ++++++++++++++++++++++ mlir/test/Dialect/ArmSME/enable-arm-za.mlir | 15 +++++ mlir/test/Target/LLVMIR/arm-sme.mlir | 11 +++ 9 files changed, 153 insertions(+) create mode 100644 mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h create mode 100644 mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp create mode 100644 mlir/test/Dialect/ArmSME/enable-arm-za.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 9e39137..7a6c6c7 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1092,6 +1092,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm", "ModuleOp"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, + Option<"armSME", "enable-arm-sme", + "bool", /*default=*/"false", + "Enables the use of ArmSME dialect while lowering the vector " + "dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td index 45a0ad7..d0072b6 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -119,4 +119,7 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">; def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">; def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">; +def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">; +def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">; + #endif // ARMSME_OPS diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h new file mode 100644 index 0000000..c8d9e67 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h @@ -0,0 +1,29 @@ +//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H +#define MLIR_DIALECT_ARMSME_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; + +/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM +/// intrinsics. +void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering ArmSME ops to ops that map to LLVM +/// intrinsics. +void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt index b7fadea..e4a5528 100644 --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -15,6 +15,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM LINK_LIBS PUBLIC MLIRArithDialect MLIRArmNeonDialect + MLIRArmSMEDialect + MLIRArmSMETransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 3f1b107..acc4244 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,8 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -49,6 +51,8 @@ struct LowerVectorToLLVMPass registry.insert(); if (armSVE) registry.insert(); + if (armSME) + registry.insert(); if (amx) registry.insert(); if (x86Vector) @@ -102,6 +106,10 @@ void LowerVectorToLLVMPass::runOnOperation() { configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } + if (armSME) { + configureArmSMELegalizeForExportTarget(target); + populateArmSMELegalizeForLLVMExportPatterns(converter, patterns); + } if (amx) { configureAMXLegalizeForExportTarget(target); populateAMXLegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt index 2b616b5..efcb17f 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms EnableArmStreaming.cpp + LegalizeForLLVMExport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms @@ -8,6 +9,8 @@ add_mlir_dialect_library(MLIRArmSMETransforms MLIRArmSMETransformsIncGen LINK_LIBS PUBLIC + MLIRArmSMEDialect MLIRFuncDialect + MLIRLLVMCommonConversion MLIRPass ) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 index 0000000..3fe9e78 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,78 @@ +//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +namespace { +/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' +/// ops to enable the ZA storage array. +struct EnableZAPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const final { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(&op.front()); + rewriter.create(op->getLoc()); + rewriter.updateRootInPlace(op, [] {}); + return success(); + } +}; + +/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to +/// disable the ZA storage array. +struct DisableZAPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(func::ReturnOp op, + PatternRewriter &rewriter) const final { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + rewriter.create(op->getLoc()); + rewriter.updateRootInPlace(op, [] {}); + return success(); + } +}; +} // namespace + +void mlir::populateArmSMELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void mlir::configureArmSMELegalizeForExportTarget( + LLVMConversionTarget &target) { + target.addLegalOp(); + + // Mark 'func.func' ops as legal if either: + // 1. no 'arm_za' function attribute is present. + // 2. the 'arm_za' function attribute is present and the first op in the + // function is an 'arm_sme::aarch64_sme_za_enable' intrinsic. + target.addDynamicallyLegalOp([&](func::FuncOp funcOp) { + auto firstOp = funcOp.getBody().front().begin(); + return !funcOp->hasAttr("arm_za") || + isa(firstOp); + }); + + // Mark 'func.return' ops as legal if either: + // 1. no 'arm_za' function attribute is present. + // 2. the 'arm_za' function attribute is present and there's a preceding + // 'arm_sme::aarch64_sme_za_disable' intrinsic. + target.addDynamicallyLegalOp([&](func::ReturnOp returnOp) { + bool hasDisableZA = false; + auto funcOp = returnOp->getParentOp(); + funcOp->walk( + [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; }); + return !funcOp->hasAttr("arm_za") || hasDisableZA; + }); +} diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir new file mode 100644 index 0000000..ae0bbdc --- /dev/null +++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA +// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING + +// CHECK-LABEL: @arm_za +func.func @arm_za() { + // ENABLE-ZA: arm_sme.intr.za.enable + // ENABLE-ZA-NEXT: arm_sme.intr.za.disable + // ENABLE-ZA-NEXT: return + // DISABLE-ZA-NOT: arm_sme.intr.za.enable + // DISABLE-ZA-NOT: arm_sme.intr.za.disable + // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable + // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable + return +} diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir index 096d619..453a887 100644 --- a/mlir/test/Target/LLVMIR/arm-sme.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme.mlir @@ -223,3 +223,14 @@ llvm.func @arm_sme_store(%nxv1i1 : vector<[1]xi1>, (vector<[16]xi1>, !llvm.ptr, i32, i32) -> () llvm.return } + +// ----- + +// CHECK-LABEL: @arm_sme_toggle_za +llvm.func @arm_sme_toggle_za() { + // CHECK: call void @llvm.aarch64.sme.za.enable() + "arm_sme.intr.za.enable"() : () -> () + // CHECK: call void @llvm.aarch64.sme.za.disable() + "arm_sme.intr.za.disable"() : () -> () + llvm.return +} -- 2.7.4