[mlir][ArmSME] Dialect and Intrinsic Op Definition
authorFrank (Fang) Gao <fang.gao1@huawei.com>
Wed, 14 Jun 2023 21:03:36 +0000 (17:03 -0400)
committerPrabhdeep Singh Soni <prabhdeep.singh.soni3@huawei.com>
Wed, 14 Jun 2023 21:11:49 +0000 (17:11 -0400)
This patch creates the ArmSME dialect, and provides the intrinsic op
definition necessary for lowering to LLVM IR.

This will cover most instructions interacting with the ZA tile register,
not covering SME2 instructions.

Source: https://developer.arm.com/documentation/ddi0616/latest

Reviewed By: awarzynski, c-rhodes

Differential Revision: https://reviews.llvm.org/D152878

16 files changed:
mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h [new file with mode: 0644]
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td [new file with mode: 0644]
mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt [new file with mode: 0644]
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Target/LLVMIR/Dialect/All.h
mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h [new file with mode: 0644]
mlir/lib/Dialect/ArmSME/CMakeLists.txt
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp [new file with mode: 0644]
mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Target/LLVMIR/CMakeLists.txt
mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
mlir/test/Target/LLVMIR/arm-sme.mlir [new file with mode: 0644]
mlir/test/mlir-opt/commandline.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
new file mode 100644 (file)
index 0000000..a69d326
--- /dev/null
@@ -0,0 +1,27 @@
+//===- ArmSMEDialect.h - MLIR Dialect for Arm SME ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSME in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_IR_ARMSME_H
+#define MLIR_DIALECT_ARMSME_IR_ARMSME_H
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc"
+
+#endif // MLIR_DIALECT_ARMSME_IR_ARMSME_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
new file mode 100644 (file)
index 0000000..45a0ad7
--- /dev/null
@@ -0,0 +1,122 @@
+//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ArmSME dialect and contains intrinsic ops to lower to
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARMSME_OPS
+#define ARMSME_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ArmSME dialect definition
+//===----------------------------------------------------------------------===//
+
+def ArmSME_Dialect : Dialect {
+  let name = "arm_sme";
+  let cppNamespace = "::mlir::arm_sme";
+  let summary = "Basic dialect to target Arm SME architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target Arm SME
+    scalable matrix operations.
+
+    Sources:
+    https://developer.arm.com/documentation/ddi0616
+    https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ArmSME Intrinsic op definitions
+//===----------------------------------------------------------------------===//
+
+def MOPPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>;
+def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
+                                              [I8, I16, BF16, F16, F32, F64]>;
+def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
+
+class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+                    list<Trait> traits = []>
+    : LLVM_IntrOpBase<
+          /*Dialect dialect=*/ArmSME_Dialect,
+          /*string opName=*/"intr." # mnemonic,
+          /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
+          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedOperands=*/overloadedOperands,
+          /*list<Trait> traits=*/traits,
+          /*int numResults=*/0>;
+
+// Zero
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
+                            Arguments<(ins Arg<I32, "Tile mask">)>;
+
+// MOP's
+class ArmSME_IntrMopOverloadedOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic, [4]>,
+      Arguments<(ins Arg<I32, "Virtual tile ID">,
+                 Arg<MOPPredicate, "LHS predicate">,
+                 Arg<MOPPredicate, "RHS predicate">,
+                 Arg<MOPVector, "LHS vector operand">,
+                 Arg<MOPVector, "RHS vector operand">)>;
+
+def LLVM_aarch64_sme_mopa : ArmSME_IntrMopOverloadedOp<"mopa">;
+def LLVM_aarch64_sme_mops : ArmSME_IntrMopOverloadedOp<"mops">;
+def LLVM_aarch64_sme_mopa_wide : ArmSME_IntrMopOverloadedOp<"mopa.wide">;
+def LLVM_aarch64_sme_mops_wide : ArmSME_IntrMopOverloadedOp<"mops.wide">;
+def LLVM_aarch64_sme_smopa_wide : ArmSME_IntrMopOverloadedOp<"smopa.wide">;
+def LLVM_aarch64_sme_smops_wide : ArmSME_IntrMopOverloadedOp<"smops.wide">;
+def LLVM_aarch64_sme_umopa_wide : ArmSME_IntrMopOverloadedOp<"umopa.wide">;
+def LLVM_aarch64_sme_umops_wide : ArmSME_IntrMopOverloadedOp<"umops.wide">;
+def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
+def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
+def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
+def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+
+// Loads
+class ArmSME_IntrLoadOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Load address", [MemRead]>,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
+def LLVM_aarch64_sme_ld1h_horiz : ArmSME_IntrLoadOp<"ld1h.horiz">;
+def LLVM_aarch64_sme_ld1w_horiz : ArmSME_IntrLoadOp<"ld1w.horiz">;
+def LLVM_aarch64_sme_ld1d_horiz : ArmSME_IntrLoadOp<"ld1d.horiz">;
+def LLVM_aarch64_sme_ld1q_horiz : ArmSME_IntrLoadOp<"ld1q.horiz">;
+def LLVM_aarch64_sme_ld1b_vert : ArmSME_IntrLoadOp<"ld1b.vert">;
+def LLVM_aarch64_sme_ld1h_vert : ArmSME_IntrLoadOp<"ld1h.vert">;
+def LLVM_aarch64_sme_ld1w_vert : ArmSME_IntrLoadOp<"ld1w.vert">;
+def LLVM_aarch64_sme_ld1d_vert : ArmSME_IntrLoadOp<"ld1d.vert">;
+def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
+
+// Stores
+class ArmSME_IntrStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic>,
+      Arguments<(ins Arg<LDSTPredicate, "Vector predicate">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>,
+                 Arg<I32, "Virtual tile ID">,
+                 Arg<I32, "Tile slice">)>;
+
+def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
+def LLVM_aarch64_sme_st1h_horiz : ArmSME_IntrStoreOp<"st1h.horiz">;
+def LLVM_aarch64_sme_st1w_horiz : ArmSME_IntrStoreOp<"st1w.horiz">;
+def LLVM_aarch64_sme_st1d_horiz : ArmSME_IntrStoreOp<"st1d.horiz">;
+def LLVM_aarch64_sme_st1q_horiz : ArmSME_IntrStoreOp<"st1q.horiz">;
+def LLVM_aarch64_sme_st1b_vert : ArmSME_IntrStoreOp<"st1b.vert">;
+def LLVM_aarch64_sme_st1h_vert : ArmSME_IntrStoreOp<"st1h.vert">;
+def LLVM_aarch64_sme_st1w_vert : ArmSME_IntrStoreOp<"st1w.vert">;
+def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
+def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
+
+#endif // ARMSME_OPS
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..d20ee65
--- /dev/null
@@ -0,0 +1,6 @@
+add_mlir_dialect(ArmSME arm_sme ArmSME)
+add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
+
+set(LLVM_TARGET_DEFINITIONS ArmSME.td)
+mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
index 0baaa7b..db15dff 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/Async/IR/Async.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -117,6 +118,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
                   pdl_interp::PDLInterpDialect,
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
+                  arm_sme::ArmSMEDialect,
                   arm_sve::ArmSVEDialect,
                   vector::VectorDialect,
                   NVVM::NVVMDialect,
index cd7f76f..65c1c51 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
@@ -35,6 +36,7 @@ class DialectRegistry;
 static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
   registerArmNeonDialectTranslation(registry);
   registerAMXDialectTranslation(registry);
+  registerArmSMEDialectTranslation(registry);
   registerArmSVEDialectTranslation(registry);
   registerBuiltinDialectTranslation(registry);
   registerGPUDialectTranslation(registry);
diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h
new file mode 100644 (file)
index 0000000..205d9b6
--- /dev/null
@@ -0,0 +1,31 @@
+//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for ArmSME dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the ArmSME dialect and the translation from it to the LLVM IR in
+/// the given registry;
+void registerArmSMEDialectTranslation(DialectRegistry &registry);
+
+/// Register the ArmSME dialect and the translation from it in the registry
+/// associated with the given context.
+void registerArmSMEDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
new file mode 100644 (file)
index 0000000..7f5aa61
--- /dev/null
@@ -0,0 +1,36 @@
+//===- ArmSMEDialect.cpp - MLIR ArmSME dialect implementation -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the ArmSME dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+//===----------------------------------------------------------------------===//
+// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
+
+void ArmSMEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..afe69de
--- /dev/null
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRArmSMEDialect
+  ArmSME.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
+
+  DEPENDS
+  MLIRArmSMEIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMDialect
+  MLIRSideEffectInterfaces
+)
index f2d9594..868ccbb 100644 (file)
@@ -46,6 +46,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
 
   LINK_LIBS PUBLIC
   MLIRArmNeonToLLVMIRTranslation
+  MLIRArmSMEToLLVMIRTranslation
   MLIRArmSVEToLLVMIRTranslation
   MLIRAMXToLLVMIRTranslation
   MLIRBuiltinToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp
new file mode 100644 (file)
index 0000000..1b57b99
--- /dev/null
@@ -0,0 +1,56 @@
+//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR -=======//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the ArmSME dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the ArmSME dialect to LLVM IR.
+class ArmSMEDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    Operation &opInst = *op;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEConversions.inc"
+
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::registerArmSMEDialectTranslation(DialectRegistry &registry) {
+  registry.insert<arm_sme::ArmSMEDialect>();
+  registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) {
+    dialect->addInterfaces<ArmSMEDialectLLVMIRTranslationInterface>();
+  });
+}
+
+void mlir::registerArmSMEDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry;
+  registerArmSMEDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt
new file mode 100644 (file)
index 0000000..d34cebf
--- /dev/null
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation
+  ArmSMEToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRArmSMEConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRArmSMEDialect
+  MLIRLLVMDialect
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )
index f27810f..fb0e5cd 100644 (file)
@@ -1,4 +1,5 @@
 add_subdirectory(ArmNeon)
+add_subdirectory(ArmSME)
 add_subdirectory(ArmSVE)
 add_subdirectory(AMX)
 add_subdirectory(Builtin)
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
new file mode 100644 (file)
index 0000000..096d619
--- /dev/null
@@ -0,0 +1,225 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @arm_sme_zero
+llvm.func @arm_sme_zero() {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.zero(i32 0)
+  "arm_sme.intr.zero"(%c0) : (i32) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_fmopa
+llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>,
+                         %nxv4f32 : vector<[4]xf32>,
+                         %nxv8f16 : vector<[8]xf16>,
+                         %nxv8bf16: vector<[8]xbf16>,
+                         %nxv2i1  : vector<[2]xi1>,
+                         %nxv4i1  : vector<[4]xi1>,
+                         %nxv8i1  : vector<[8]xi1>) {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64
+  "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
+    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32
+  "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
+    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16
+  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16
+  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_imopa
+llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
+                         %nxv16i8 : vector<[16]xi8>,
+                         %nxv8i1  : vector<[8]xi1>,
+                         %nxv16i1 : vector<[16]xi1>) {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16
+  "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16
+  "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16
+  "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16
+  "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8
+  "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8
+  "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8
+  "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
+  "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_fmops
+llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>,
+                         %nxv4f32 : vector<[4]xf32>,
+                         %nxv8f16 : vector<[8]xf16>,
+                         %nxv8bf16: vector<[8]xbf16>,
+                         %nxv2i1  : vector<[2]xi1>,
+                         %nxv4i1  : vector<[4]xi1>,
+                         %nxv8i1  : vector<[8]xi1>) {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.mops.nxv2f64
+  "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
+    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mops.nxv4f32
+  "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
+    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16
+  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16
+  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_imops
+llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
+                         %nxv16i8 : vector<[16]xi8>,
+                         %nxv8i1  : vector<[8]xi1>,
+                         %nxv16i1 : vector<[16]xi1>) {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16
+  "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16
+  "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16
+  "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16
+  "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
+    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8
+  "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8
+  "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8
+  "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
+  "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
+    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load
+llvm.func @arm_sme_load(%nxv1i1  : vector<[1]xi1>,
+                        %nxv2i1  : vector<[2]xi1>,
+                        %nxv4i1  : vector<[4]xi1>,
+                        %nxv8i1  : vector<[8]xi1>,
+                        %nxv16i1 : vector<[16]xi1>,
+                        %p8      : !llvm.ptr<i8>,
+                        %p16     : !llvm.ptr<i16>,
+                        %p32     : !llvm.ptr<i32>,
+                        %p64     : !llvm.ptr<i64>,
+                        %p128    : !llvm.ptr<i128>) {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.ld1q.horiz
+  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1d.horiz
+  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1w.horiz
+  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1h.horiz
+  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1b.horiz
+  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1q.vert
+  "arm_sme.intr.ld1q.vert"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1d.vert
+  "arm_sme.intr.ld1d.vert"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1w.vert
+  "arm_sme.intr.ld1w.vert"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1h.vert
+  "arm_sme.intr.ld1h.vert"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.ld1b.vert
+  "arm_sme.intr.ld1b.vert"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store
+llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
+                         %nxv2i1  : vector<[2]xi1>,
+                         %nxv4i1  : vector<[4]xi1>,
+                         %nxv8i1  : vector<[8]xi1>,
+                         %nxv16i1 : vector<[16]xi1>,
+                         %p8      : !llvm.ptr<i8>,
+                         %p16     : !llvm.ptr<i16>,
+                         %p32     : !llvm.ptr<i32>,
+                         %p64     : !llvm.ptr<i64>,
+                         %p128    : !llvm.ptr<i128>) {
+  %c0 = llvm.mlir.constant(0 : index) : i32
+  // CHECK: call void @llvm.aarch64.sme.st1q.horiz
+  "arm_sme.intr.st1q.horiz"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1d.horiz
+  "arm_sme.intr.st1d.horiz"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1w.horiz
+  "arm_sme.intr.st1w.horiz"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1h.horiz
+  "arm_sme.intr.st1h.horiz"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1b.horiz
+  "arm_sme.intr.st1b.horiz"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1q.vert
+  "arm_sme.intr.st1q.vert"(%nxv1i1, %p128, %c0, %c0) :
+              (vector<[1]xi1>, !llvm.ptr<i128>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1d.vert
+  "arm_sme.intr.st1d.vert"(%nxv2i1, %p64, %c0, %c0) :
+              (vector<[2]xi1>, !llvm.ptr<i64>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1w.vert
+  "arm_sme.intr.st1w.vert"(%nxv4i1, %p32, %c0, %c0) :
+              (vector<[4]xi1>, !llvm.ptr<i32>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1h.vert
+  "arm_sme.intr.st1h.vert"(%nxv8i1, %p16, %c0, %c0) :
+              (vector<[8]xi1>, !llvm.ptr<i16>, i32, i32) -> ()
+  // CHECK: call void @llvm.aarch64.sme.st1b.vert
+  "arm_sme.intr.st1b.vert"(%nxv16i1, %p8, %c0, %c0) :
+              (vector<[16]xi1>, !llvm.ptr<i8>, i32, i32) -> ()
+  llvm.return
+}
index 1b2ab3f..7400f46 100644 (file)
@@ -6,6 +6,7 @@
 // CHECK-SAME: amx
 // CHECK-SAME: arith
 // CHECK-SAME: arm_neon
+// CHECK-SAME: arm_sme
 // CHECK-SAME: arm_sve
 // CHECK-SAME: async
 // CHECK-SAME: bufferization