[mlir][ArmSME] Add initial dialect with basic lowering of vector.transfer write to...
authorCullen Rhodes <cullen.rhodes@arm.com>
Wed, 14 Jun 2023 08:26:44 +0000 (08:26 +0000)
committerCullen Rhodes <cullen.rhodes@arm.com>
Wed, 14 Jun 2023 08:46:53 +0000 (08:46 +0000)
This patch adds support for lowering a `vector.transfer_write` of zeroes
and type `vector<[16x16]xi8>` to the SME `zero {za}` instruction [1],
which zeroes the entire accumulator.

This contributes to supporting a path from `linalg.fill` to SME.

[1] https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-

Reviewed By: awarzynski, dcaballe

Differential Revision: https://reviews.llvm.org/D152508

24 files changed:
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td [new file with mode: 0644]
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEDialect.h [new file with mode: 0644]
mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h [new file with mode: 0644]
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Target/LLVMIR/Dialect/All.h
mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h [new file with mode: 0644]
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/ArmSME/CMakeLists.txt
mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp [new file with mode: 0644]
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp [new file with mode: 0644]
mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp [new file with mode: 0644]
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/lib/Target/LLVMIR/CMakeLists.txt
mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
mlir/test/Dialect/ArmSME/vector_ops.mlir [new file with mode: 0644]
mlir/test/Target/LLVMIR/arm-sme.mlir [new file with mode: 0644]

index 9e39137..7a6c6c7 100644 (file)
@@ -1092,6 +1092,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm", "ModuleOp"> {
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
        "dialect.">,
+    Option<"armSME", "enable-arm-sme",
+           "bool", /*default=*/"false",
+           "Enables the use of ArmSME dialect while lowering the vector "
+       "dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
new file mode 100644 (file)
index 0000000..d6f7df1
--- /dev/null
@@ -0,0 +1,50 @@
+//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the ArmSME dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ArmSME
+#define ArmSME
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME dialect definition.
+//===----------------------------------------------------------------------===//
+
+def ArmSME_Dialect : Dialect {
+  let name = "arm_sme";
+  let cppNamespace = "::mlir::arm_sme";
+  let summary = "Dialect to target the Armv9 Scalable Matrix Extension (SME)";
+  let description = [{
+    This dialect contains the definitions necessary to target specific Arm SME
+    operations.
+
+    For more details on the architecture, see the Arm documentation:
+    https://developer.arm.com/documentation/ddi0616
+  }];
+  let usePropertiesForAttributes = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// LLVMIR Intrinsics
+//===----------------------------------------------------------------------===//
+
+class ArmSME_IntrOp<string mnemonic, int numResults = 1,
+                    list<Trait> traits = []> :
+  LLVM_IntrOpBase<ArmSME_Dialect, "intr." # mnemonic,
+                  "aarch64_sme_" # !subst(".", "_", mnemonic), [], [], traits,
+                  numResults>;
+
+/// Create a call to aarch64_sme_zero intrinsic.
+def LLVM_aarch64_sme_zero
+    : ArmSME_IntrOp<"zero", 0>, Arguments<(ins I32:$imm)>;
+
+#endif // ArmSME
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEDialect.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEDialect.h
new file mode 100644 (file)
index 0000000..46ee791
--- /dev/null
@@ -0,0 +1,26 @@
+//===- ArmSMEDialect.h - MLIR Dialect for Arm SME ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSME in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H
+#define MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
+
+#endif // MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..d20ee65
--- /dev/null
@@ -0,0 +1,6 @@
+add_mlir_dialect(ArmSME arm_sme ArmSME)
+add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
+
+set(LLVM_TARGET_DEFINITIONS ArmSME.td)
+mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
new file mode 100644 (file)
index 0000000..f3eb839
--- /dev/null
@@ -0,0 +1,27 @@
+//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class RewritePatternSet;
+
+namespace arm_sme {
+void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+} // namespace arm_sme
+
+/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
+/// intrinsics.
+void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
index 0baaa7b..ad706d5 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -118,6 +119,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
                   arm_sve::ArmSVEDialect,
+                  arm_sme::ArmSMEDialect,
                   vector::VectorDialect,
                   NVVM::NVVMDialect,
                   ROCDL::ROCDLDialect,
index cd7f76f..c0f7e70 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -36,6 +37,7 @@ static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
   registerAMXDialectTranslation(registry);
   registerArmSVEDialectTranslation(registry);
+  registerArmSMEDialectTranslation(registry);
   registerBuiltinDialectTranslation(registry);
   registerGPUDialectTranslation(registry);
   registerLLVMDialectTranslation(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h
new file mode 100644 (file)
index 0000000..205d9b6
--- /dev/null
@@ -0,0 +1,31 @@
+//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for ArmSME dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the ArmSME dialect and the translation from it to the LLVM IR in
+/// the given registry;
+void registerArmSMEDialectTranslation(DialectRegistry &registry);
+
+/// Register the ArmSME dialect and the translation from it in the registry
+/// associated with the given context.
+void registerArmSMEDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
index b7fadea..5822208 100644 (file)
@@ -17,6 +17,8 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   MLIRArmNeonDialect
   MLIRArmSVEDialect
   MLIRArmSVETransforms
+  MLIRArmSMEDialect
+  MLIRArmSMETransforms
   MLIRAMXDialect
   MLIRAMXTransforms
   MLIRLLVMCommonConversion
index 3f1b107..3e0cfd4 100644 (file)
@@ -14,6 +14,8 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -49,6 +51,8 @@ struct LowerVectorToLLVMPass
       registry.insert<arm_neon::ArmNeonDialect>();
     if (armSVE)
       registry.insert<arm_sve::ArmSVEDialect>();
+    if (armSME)
+      registry.insert<arm_sme::ArmSMEDialect>();
     if (amx)
       registry.insert<amx::AMXDialect>();
     if (x86Vector)
@@ -102,6 +106,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
     configureArmSVELegalizeForExportTarget(target);
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
+  if (armSME) {
+    configureArmSMELegalizeForExportTarget(target);
+    arm_sme::populateVectorTransferLoweringPatterns(patterns);
+  }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
     populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp
new file mode 100644 (file)
index 0000000..fad040b
--- /dev/null
@@ -0,0 +1,33 @@
+//===- ArmSMEDialect.cpp - MLIR ArmSME dialect implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmSME dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+//===----------------------------------------------------------------------===//
+// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+
+void ArmSMEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4318a7a
--- /dev/null
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRArmSMEDialect
+  ArmSMEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
+
+  DEPENDS
+  MLIRArmSMEIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  )
index 2b616b5..3927c38 100644 (file)
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
+  LegalizeForLLVMExport.cpp
+  LowerVectorOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -8,6 +10,11 @@ add_mlir_dialect_library(MLIRArmSMETransforms
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
   MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRLLVMCommonConversion
+  MLIRIR
   MLIRPass
   )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644 (file)
index 0000000..00512cb
--- /dev/null
@@ -0,0 +1,19 @@
+//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+void mlir::configureArmSMELegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  target.addLegalOp<aarch64_sme_zero>();
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
new file mode 100644 (file)
index 0000000..288623f
--- /dev/null
@@ -0,0 +1,55 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned kZeroZAMask = 255;
+
+namespace {
+/// Lower `vector.transfer_write` op to `arm_sme.intr.zero` op. Currently only
+/// supports 2d scalable vector type `vector<[16x16]xi8>` that maps to the ZA0.B
+/// SME tile. This will be extended to support more element types.
+struct TransferWriteToArmSMEZeroLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToArmSMEZeroLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto vType = write.getVectorType();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({16, 16}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    if (vType.getNumScalableDims() != 2)
+      return failure();
+    auto tile = rewriter.create<arith::ConstantOp>(
+        write.getLoc(), rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(kZeroZAMask));
+    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_zero>(write, tile);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferWriteToArmSMEZeroLowering>(patterns.getContext());
+}
index 2d269ca..8708f40 100644 (file)
@@ -44,4 +44,5 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRVectorDialect
   MLIRVectorInterfaces
   MLIRVectorUtils
+  MLIRArmSMEDialect
   )
index f2d9594..196b8f1 100644 (file)
@@ -47,6 +47,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
   LINK_LIBS PUBLIC
   MLIRArmNeonToLLVMIRTranslation
   MLIRArmSVEToLLVMIRTranslation
+  MLIRArmSMEToLLVMIRTranslation
   MLIRAMXToLLVMIRTranslation
   MLIRBuiltinToLLVMIRTranslation
   MLIRGPUToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
new file mode 100644 (file)
index 0000000..a09d49e
--- /dev/null
@@ -0,0 +1,56 @@
+//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR -=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the ArmSME dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the ArmSME dialect to LLVM IR.
+class ArmSMEDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    Operation &opInst = *op;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc"
+
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::registerArmSMEDialectTranslation(DialectRegistry &registry) {
+  registry.insert<arm_sme::ArmSMEDialect>();
+  registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) {
+    dialect->addInterfaces<ArmSMEDialectLLVMIRTranslationInterface>();
+  });
+}
+
+void mlir::registerArmSMEDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry;
+  registerArmSMEDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
new file mode 100644 (file)
index 0000000..d34cebf
--- /dev/null
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation
+  ArmSMEToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRArmSMEConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRArmSMEDialect
+  MLIRLLVMDialect
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )
index f27810f..5d3d7ab 100644 (file)
@@ -1,5 +1,6 @@
 add_subdirectory(ArmNeon)
 add_subdirectory(ArmSVE)
+add_subdirectory(ArmSME)
 add_subdirectory(AMX)
 add_subdirectory(Builtin)
 add_subdirectory(GPU)
diff --git a/mlir/test/Dialect/ArmSME/vector_ops.mlir b/mlir/test/Dialect/ArmSME/vector_ops.mlir
new file mode 100644 (file)
index 0000000..afe8d1a
--- /dev/null
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+func.func @transfer_write_2d_zero_i8() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi4>
+  %cst = arith.constant dense<0> : vector<[16x16]xi4>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi4>, memref<?x?xi4>
+  memref.dealloc %0 : memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape() {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c8, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[8x8]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[8x8]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim, %dim) : memref<?x?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16x16x16]xi8>, memref<?x?x?xi8>
+  memref.dealloc %0 : memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_num_scalable_dims
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_num_scalable_dims() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim) : memref<16x?xi8>
+  %cst = arith.constant dense<0> : vector<16x[16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<16x[16]xi8>, memref<16x?xi8>
+  memref.dealloc %0 : memref<16x?xi8>
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
new file mode 100644 (file)
index 0000000..b39ae26
--- /dev/null
@@ -0,0 +1,11 @@
+// RUN: mlir-translate --mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: define void @arm_sme_zero
+// CHECK: call void @llvm.aarch64.sme.zero(i32 255)
+llvm.func @arm_sme_zero() {
+  %mask = llvm.mlir.constant(255 : i32) : i32
+  "arm_sme.intr.zero"(%mask) : (i32) -> ()
+  llvm.return
+}
+
+// -----