[mlir][ArmSME] Introduce custom TypeConverter for ArmSME
authorAndrzej Warzynski <andrzej.warzynski@arm.com>
Fri, 14 Jul 2023 11:55:13 +0000 (11:55 +0000)
committerAndrzej Warzynski <andrzej.warzynski@arm.com>
Tue, 18 Jul 2023 09:35:32 +0000 (09:35 +0000)
At the moment, SME-to-LLVM lowerings rely entirely on
`LLVMTypeConverter`. This patch introduces a dedicated `TypeConverter`
that inherits from `LLVMTypeConverter` (it will also be used when
lowering ArmSME Ops to LLVM).

The new type converter merely disables lowerings for `VectorType` to
prevent 2-d scalable vectors (common in the context of ArmSME), e.g.

   `vector<[16]x[16]xi8>`,

entering the LLVM Type converter. LLVM does not support arrays of
scalable vectors and hence the need for specialisation. In the case of
SME such types are effectively eliminated when emitting LLVM IR
intrinsics for SME.

Differential Revision: https://reviews.llvm.org/D155365

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp [new file with mode: 0644]
mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

index 133968b..ab5c179 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_H
 
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -16,6 +17,9 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arm_sme {
+//===----------------------------------------------------------------------===//
+// The EnableArmStreaming pass.
+//===----------------------------------------------------------------------===//
 // Options for Armv9 Streaming SVE mode. By default, streaming-mode is part of
 // the function interface (ABI) and the caller manages PSTATE.SM on entry/exit.
 // In a locally streaming function PSTATE.SM is kept internal and the callee
@@ -34,6 +38,14 @@ createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
 std::unique_ptr<Pass> createTileAllocationPass();
 
 //===----------------------------------------------------------------------===//
+// Type ArmSMETypeConverter pass.
+//===----------------------------------------------------------------------===//
+class ArmSMETypeConverter : public LLVMTypeConverter {
+public:
+  ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
+};
+
+//===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
 
index e4a5528..bb92d65 100644 (file)
@@ -17,6 +17,7 @@ add_mlir_conversion_library(MLIRVectorToLLVM
   MLIRArmNeonDialect
   MLIRArmSMEDialect
   MLIRArmSMETransforms
+  MLIRVectorToArmSME
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
index acc4244..04570a7 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms.h"
@@ -96,6 +97,8 @@ void LowerVectorToLLVMPass::runOnOperation() {
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalDialect<memref::MemRefDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
+  arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options);
+
   if (armNeon) {
     // TODO: we may or may not want to include in-dialect lowering to
     // LLVM-compatible operations here. So far, all operations in the dialect
@@ -108,7 +111,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
   }
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
-    populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+    populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
new file mode 100644 (file)
index 0000000..1cefc22
--- /dev/null
@@ -0,0 +1,22 @@
+//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+
+using namespace mlir;
+arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
+    MLIRContext *ctx, const LowerToLLVMOptions &options)
+    : LLVMTypeConverter(ctx, options) {
+  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
+  // vectors (common in the context of ArmSME), e.g.
+  //    `vector<[16]x[16]xi8>`,
+  // entering the LLVM Type converter. LLVM does not support arrays of scalable
+  // vectors, but in the case of SME such types are effectively eliminated when
+  // emitting ArmSME LLVM IR intrinsics.
+  addConversion([&](VectorType type) { return type; });
+}
index 247da2a..991beae 100644 (file)
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
+  ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
   TileAllocation.cpp