[mlir][ArmSME] Insert intrinsics to enable/disable ZA
authorCullen Rhodes <cullen.rhodes@arm.com>
Fri, 16 Jun 2023 09:27:20 +0000 (09:27 +0000)
committerCullen Rhodes <cullen.rhodes@arm.com>
Fri, 16 Jun 2023 09:40:48 +0000 (09:40 +0000)
This patch adds two LLVM intrinsics to the ArmSME dialect:

  * llvm.aarch64.sme.za.enable
  * llvm.aarch64.sme.za.disable

for enabling the ZA storage array [1], as well as patterns for inserting
them during legalization to LLVM at the start and end of functions if
the function has the 'arm_za' attribute (D152695).

In the future ZA should probably be automatically enabled/disabled when
lowering from vector to SME, but this should be sufficient for now at
least until we have patterns lowering to SME instructions that use ZA.

N.B. The backend function attribute 'aarch64_pstate_za_new' can be used
manage ZA state (as was originally tried in D152694), but it emits calls
to the following SME support routines [2] for the lazy-save mechanism
[3]:

  * __arm_tpidr2_restore
  * __arm_tpidr2_save

These will soon be added to compiler-rt but there's currently no public
implementation, and using this attribute would introduce an MLIR
dependency on compiler-rt. Furthermore, this mechanism is for routines
with ZA enabled calling other routines with it also enabled. We can
choose not to enable ZA in the compiler when this is case.

Depends on D152695

[1] https://developer.arm.com/documentation/ddi0616/aa
[2] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#sme-support-routines
[3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#the-za-lazy-saving-scheme

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D153050

mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h [new file with mode: 0644]
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp [new file with mode: 0644]
mlir/test/Dialect/ArmSME/enable-arm-za.mlir [new file with mode: 0644]
mlir/test/Target/LLVMIR/arm-sme.mlir

index 9e39137..7a6c6c7 100644 (file)
@@ -1092,6 +1092,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm", "ModuleOp"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armSME", "enable-arm-sme",
+           "bool", /*default=*/"false",
+           "Enables the use of ArmSME dialect while lowering the vector "
+       "dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
index 45a0ad7..d0072b6 100644 (file)
@@ -119,4 +119,7 @@ def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
 def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
 def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
 
+def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
+def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
+
 #endif // ARMSME_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
new file mode 100644 (file)
index 0000000..c8d9e67
--- /dev/null
@@ -0,0 +1,29 @@
+//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
+/// intrinsics.
+void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+                                                 RewritePatternSet &patterns);
+
+/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
+/// intrinsics.
+void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
index b7fadea..e4a5528 100644 (file)
@@ -15,6 +15,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRArmNeonDialect
+  MLIRArmSMEDialect
+  MLIRArmSMETransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
index 3f1b107..acc4244 100644 (file)
@@ -14,6 +14,8 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -49,6 +51,8 @@ struct LowerVectorToLLVMPass
       registry.insert<arm_neon::ArmNeonDialect>();
     if (armSVE)
       registry.insert<arm_sve::ArmSVEDialect>();
+    if (armSME)
+      registry.insert<arm_sme::ArmSMEDialect>();
     if (amx)
       registry.insert<amx::AMXDialect>();
     if (x86Vector)
@@ -102,6 +106,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
     configureArmSVELegalizeForExportTarget(target);
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
+  if (armSME) {
+    configureArmSMELegalizeForExportTarget(target);
+    populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+  }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
     populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
index 2b616b5..efcb17f 100644 (file)
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
+  LegalizeForLLVMExport.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -8,6 +9,8 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
   MLIRFuncDialect
+  MLIRLLVMCommonConversion
   MLIRPass
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644 (file)
index 0000000..3fe9e78
--- /dev/null
@@ -0,0 +1,78 @@
+//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
+/// ops to enable the ZA storage array.
+struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::FuncOp op,
+                                PatternRewriter &rewriter) const final {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(&op.front());
+    rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
+/// disable the ZA storage array.
+struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(func::ReturnOp op,
+                                PatternRewriter &rewriter) const final {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(op);
+    rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateArmSMELegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
+}
+
+void mlir::configureArmSMELegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  target.addLegalOp<arm_sme::aarch64_sme_za_enable,
+                    arm_sme::aarch64_sme_za_disable>();
+
+  // Mark 'func.func' ops as legal if either:
+  //   1. no 'arm_za' function attribute is present.
+  //   2. the 'arm_za' function attribute is present and the first op in the
+  //      function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
+  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
+    auto firstOp = funcOp.getBody().front().begin();
+    return !funcOp->hasAttr("arm_za") ||
+           isa<arm_sme::aarch64_sme_za_enable>(firstOp);
+  });
+
+  // Mark 'func.return' ops as legal if either:
+  //   1. no 'arm_za' function attribute is present.
+  //   2. the 'arm_za' function attribute is present and there's a preceding
+  //      'arm_sme::aarch64_sme_za_disable' intrinsic.
+  target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
+    bool hasDisableZA = false;
+    auto funcOp = returnOp->getParentOp();
+    funcOp->walk<WalkOrder::PreOrder>(
+        [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
+    return !funcOp->hasAttr("arm_za") || hasDisableZA;
+  });
+}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
new file mode 100644 (file)
index 0000000..ae0bbdc
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+
+// CHECK-LABEL: @arm_za
+func.func @arm_za() {
+  // ENABLE-ZA: arm_sme.intr.za.enable
+  // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
+  // ENABLE-ZA-NEXT: return
+  // DISABLE-ZA-NOT: arm_sme.intr.za.enable
+  // DISABLE-ZA-NOT: arm_sme.intr.za.disable
+  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
+  // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
+  return
+}
index 096d619..453a887 100644 (file)
@@ -223,3 +223,14 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
               (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_toggle_za
+llvm.func @arm_sme_toggle_za() {
+  // CHECK: call void @llvm.aarch64.sme.za.enable()
+  "arm_sme.intr.za.enable"() : () -> ()
+  // CHECK: call void @llvm.aarch64.sme.za.disable()
+  "arm_sme.intr.za.disable"() : () -> ()
+  llvm.return
+}