[mlir][ArmSME] Extend streaming-mode pass to support enabling ZA
authorCullen Rhodes <cullen.rhodes@arm.com>
Fri, 16 Jun 2023 09:14:53 +0000 (09:14 +0000)
committerCullen Rhodes <cullen.rhodes@arm.com>
Fri, 16 Jun 2023 09:26:42 +0000 (09:26 +0000)
This patch extends the 'enable-arm-streaming' pass with a new option to
enable the ZA storage array by adding the 'arm_za' attribute to
'func.func' ops.

A later patch will insert `llvm.aarch64.sme.za.enable` at the beginning
of 'func.func' ops and `llvm.aarch64.sme.za.disable` before
`func.return` statements when lowering to LLVM dialect.

Currently the pass only supports enabling ZA with streaming-mode on but
the SME LDR, STR and ZERO instructions can access ZA when not in
streaming-mode (section B1.1.1, IDGNQM [1]), so it may be worth making
these options independent in the future.

N.B. This patch is generally useful in the context of SME enablement in
MLIR, but it will help enable writing an integration test for rewrite
pattern that lowers `vector.transfer_write` -> `zero {za}` (D152508).

[1] https://developer.arm.com/documentation/ddi0616/aa

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D152695

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
mlir/test/Dialect/ArmSME/enable-arm-streaming.mlir

index b51b0a3..00ac537 100644 (file)
@@ -27,7 +27,8 @@ enum class ArmStreaming { Default = 0, Locally = 1 };
 
 /// Pass to enable Armv9 Streaming SVE mode.
 std::unique_ptr<Pass>
-createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default);
+createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
+                             const bool enableZA = false);
 
 //===----------------------------------------------------------------------===//
 // Registration
index 8c9455d..7bc39e0 100644 (file)
@@ -33,6 +33,8 @@ def EnableArmStreaming
                                                   "Streaming mode is internal to the function, callee "
                                                   "manages PSTATE.SM on entry/exit.")
           )}]>,
+    Option<"enableZA", "enable-za", "bool", /*default=*/"false",
+           "Enable ZA storage array.">,
   ];
   let dependentDialects = ["func::FuncDialect"];
 }
index 0f35b22..97c38b5 100644 (file)
@@ -13,6 +13,8 @@
 //   * 'arm_streaming' (default)
 //   * 'arm_locally_streaming'
 //
+// It can also optionally enable the ZA storage array.
+//
 // Streaming-mode is part of the interface (ABI) for functions with the
 // first attribute and it's the responsibility of the caller to manage
 // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
@@ -49,11 +51,15 @@ using namespace mlir::arm_sme;
 
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
+static constexpr char kArmZAAttr[] = "arm_za";
 
 namespace {
 struct EnableArmStreamingPass
     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
-  EnableArmStreamingPass(ArmStreaming mode) { this->mode = mode; }
+  EnableArmStreamingPass(ArmStreaming mode, bool enableZA) {
+    this->mode = mode;
+    this->enableZA = enableZA;
+  }
   void runOnOperation() override {
     std::string attr;
     switch (mode) {
@@ -65,11 +71,19 @@ struct EnableArmStreamingPass
       break;
     }
     getOperation()->setAttr(attr, UnitAttr::get(&getContext()));
+
+    // The pass currently only supports enabling ZA when in streaming-mode, but
+    // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
+    // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
+    // supporting this later.
+    if (enableZA)
+      getOperation()->setAttr(kArmZAAttr, UnitAttr::get(&getContext()));
   }
 };
 } // namespace
 
 std::unique_ptr<Pass>
-mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode) {
-  return std::make_unique<EnableArmStreamingPass>(mode);
+mlir::arm_sme::createEnableArmStreamingPass(const ArmStreaming mode,
+                                            const bool enableZA) {
+  return std::make_unique<EnableArmStreamingPass>(mode, enableZA);
 }
index 0c24f8c..f5cc831 100644 (file)
@@ -1,8 +1,11 @@
 // RUN: mlir-opt %s -enable-arm-streaming -verify-diagnostics | FileCheck %s
 // RUN: mlir-opt %s -enable-arm-streaming=mode=locally -verify-diagnostics | FileCheck %s -check-prefix=CHECK-LOCALLY
+// RUN: mlir-opt %s -enable-arm-streaming=enable-za -verify-diagnostics | FileCheck %s -check-prefix=CHECK-ENABLE-ZA
 
 // CHECK-LABEL: @arm_streaming
 // CHECK-SAME: attributes {arm_streaming}
 // CHECK-LOCALLY-LABEL: @arm_streaming
 // CHECK-LOCALLY-SAME: attributes {arm_locally_streaming}
+// CHECK-ENABLE-ZA-LABEL: @arm_streaming
+// CHECK-ENABLE-ZA-SAME: attributes {arm_streaming, arm_za}
 func.func @arm_streaming() { return }