"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+ Option<"armSME", "enable-arm-sme",
+ "bool", /*default=*/"false",
+ "Enables the use of ArmSME dialect while lowering the vector "
+ "dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
+def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
+
#endif // ARMSME_OPS
--- /dev/null
+//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class LLVMTypeConverter;
+class RewritePatternSet;
+
+/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
+/// intrinsics.
+void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+
+/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
+/// intrinsics.
+void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRArmNeonDialect
+ MLIRArmSMEDialect
+ MLIRArmSMETransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
registry.insert<arm_neon::ArmNeonDialect>();
if (armSVE)
registry.insert<arm_sve::ArmSVEDialect>();
+ if (armSME)
+ registry.insert<arm_sme::ArmSMEDialect>();
if (amx)
registry.insert<amx::AMXDialect>();
if (x86Vector)
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
+ if (armSME) {
+ configureArmSMELegalizeForExportTarget(target);
+ populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+ }
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
add_mlir_dialect_library(MLIRArmSMETransforms
EnableArmStreaming.cpp
+ LegalizeForLLVMExport.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
MLIRArmSMETransformsIncGen
LINK_LIBS PUBLIC
+ MLIRArmSMEDialect
MLIRFuncDialect
+ MLIRLLVMCommonConversion
MLIRPass
)
--- /dev/null
+//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+/// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func'
+/// ops to enable the ZA storage array.
+struct EnableZAPattern : public OpRewritePattern<func::FuncOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(func::FuncOp op,
+ PatternRewriter &rewriter) const final {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPointToStart(&op.front());
+ rewriter.create<arm_sme::aarch64_sme_za_enable>(op->getLoc());
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+
+/// Insert 'llvm.aarch64.sme.za.disable' intrinsic before 'func.return' ops to
+/// disable the ZA storage array.
+struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(func::ReturnOp op,
+ PatternRewriter &rewriter) const final {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+ rewriter.create<arm_sme::aarch64_sme_za_disable>(op->getLoc());
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+} // namespace
+
+void mlir::populateArmSMELegalizeForLLVMExportPatterns(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
+}
+
+void mlir::configureArmSMELegalizeForExportTarget(
+ LLVMConversionTarget &target) {
+ target.addLegalOp<arm_sme::aarch64_sme_za_enable,
+ arm_sme::aarch64_sme_za_disable>();
+
+ // Mark 'func.func' ops as legal if either:
+ // 1. no 'arm_za' function attribute is present.
+ // 2. the 'arm_za' function attribute is present and the first op in the
+ // function is an 'arm_sme::aarch64_sme_za_enable' intrinsic.
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp funcOp) {
+ auto firstOp = funcOp.getBody().front().begin();
+ return !funcOp->hasAttr("arm_za") ||
+ isa<arm_sme::aarch64_sme_za_enable>(firstOp);
+ });
+
+ // Mark 'func.return' ops as legal if either:
+ // 1. no 'arm_za' function attribute is present.
+ // 2. the 'arm_za' function attribute is present and there's a preceding
+ // 'arm_sme::aarch64_sme_za_disable' intrinsic.
+ target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp returnOp) {
+ bool hasDisableZA = false;
+ auto funcOp = returnOp->getParentOp();
+ funcOp->walk<WalkOrder::PreOrder>(
+ [&](arm_sme::aarch64_sme_za_disable op) { hasDisableZA = true; });
+ return !funcOp->hasAttr("arm_za") || hasDisableZA;
+ });
+}
--- /dev/null
+// RUN: mlir-opt %s -enable-arm-streaming=enable-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING
+
+// CHECK-LABEL: @arm_za
+func.func @arm_za() {
+ // ENABLE-ZA: arm_sme.intr.za.enable
+ // ENABLE-ZA-NEXT: arm_sme.intr.za.disable
+ // ENABLE-ZA-NEXT: return
+ // DISABLE-ZA-NOT: arm_sme.intr.za.enable
+ // DISABLE-ZA-NOT: arm_sme.intr.za.disable
+ // NO-ARM-STREAMING-NOT: arm_sme.intr.za.enable
+ // NO-ARM-STREAMING-NOT: arm_sme.intr.za.disable
+ return
+}
(vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_toggle_za
+llvm.func @arm_sme_toggle_za() {
+ // CHECK: call void @llvm.aarch64.sme.za.enable()
+ "arm_sme.intr.za.enable"() : () -> ()
+ // CHECK: call void @llvm.aarch64.sme.za.disable()
+ "arm_sme.intr.za.disable"() : () -> ()
+ llvm.return
+}