[mlir] Avoid needlessly converting LLVM named structs with compatible elements
authorAlex Zinenko <zinenko@google.com>
Fri, 3 Dec 2021 12:13:43 +0000 (13:13 +0100)
committerAlex Zinenko <zinenko@google.com>
Mon, 6 Dec 2021 12:42:11 +0000 (13:42 +0100)
Conversion of LLVM named structs leads to them being renamed since we cannot
modify the body of the struct type once it is set. Previously, this applied to
all named struct types, even if their element types were not affected by the
conversion. Make this behvaior only applicable when element types are changed.
This requires making the LLVM dialect type-compatibility check recursively look
at the element types (arguably, it should have been doing than since the moment
the LLVM dialect type system stopped being closed). In addition, have a more
lax check for outer types only to avoid repeated check when necessary (e.g.,
parser, verifiers that are going to also look at the inner type).

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D115037

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/test/Conversion/StandardToLLVM/convert-types.mlir

index 3de4f3ab60dbe5c8319ef29fdea5baa101734813..6abede4b8c557620a0f38ec4fd6cc30a74dfee36 100644 (file)
@@ -429,6 +429,10 @@ void printType(Type type, AsmPrinter &printer);
 /// Returns `true` if the given type is compatible with the LLVM dialect.
 bool isCompatibleType(Type type);
 
+/// Returns `true` if the given outer type is compatible with the LLVM dialect
+/// without checking its potential nested types such as struct elements.
+bool isCompatibleOuterType(Type type);
+
 /// Returns `true` if the given type is a floating-point type compatible with
 /// the LLVM dialect.
 bool isCompatibleFloatingPointType(Type type);
index cd6651cbcf6eb52de22cb9b18e6160bfd607ac6a..5175b93a930312c9bc6a156dcabdbabfd0d7b1d4 100644 (file)
@@ -55,6 +55,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   });
   addConversion([&](LLVM::LLVMStructType type, SmallVectorImpl<Type> &results,
                     ArrayRef<Type> callStack) -> llvm::Optional<LogicalResult> {
+    // Fastpath for types that won't be converted by this callback anyway.
+    if (LLVM::isCompatibleType(type)) {
+      results.push_back(type);
+      return success();
+    }
+
     if (type.isIdentified()) {
       auto convertedType = LLVM::LLVMStructType::getIdentified(
           type.getContext(), ("_Converted_" + type.getName()).str());
index ca1bbbf59a6946d0d0a74ae194cd5c8493d39012..bca0a5800d91965a13e1f4e2ae6108cc8650d76c 100644 (file)
@@ -468,7 +468,7 @@ Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
   Type type = dispatchParse(parser, /*allowAny=*/false);
   if (!type)
     return type;
-  if (!isCompatibleType(type)) {
+  if (!isCompatibleOuterType(type)) {
     parser.emitError(loc) << "unexpected type, expected keyword";
     return nullptr;
   }
index ea4d7a69c0633b9190235974b2e66cf3a31f2505..1ce9e4481e9c0ff8a73131d720e907db21143eed 100644 (file)
@@ -1,4 +1,3 @@
-//===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -19,6 +18,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/TypeSupport.h"
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/TypeSize.h"
 
@@ -120,9 +120,10 @@ LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
 //===----------------------------------------------------------------------===//
 
 bool LLVMPointerType::isValidElementType(Type type) {
-  return isCompatibleType(type) ? !type.isa<LLVMVoidType, LLVMTokenType,
-                                            LLVMMetadataType, LLVMLabelType>()
-                                : type.isa<PointerElementTypeInterface>();
+  return isCompatibleOuterType(type)
+             ? !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
+                         LLVMLabelType>()
+             : type.isa<PointerElementTypeInterface>();
 }
 
 LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
@@ -483,17 +484,9 @@ LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
 // Utility functions.
 //===----------------------------------------------------------------------===//
 
-bool mlir::LLVM::isCompatibleType(Type type) {
-  // Only signless integers are compatible.
-  if (auto intType = type.dyn_cast<IntegerType>())
-    return intType.isSignless();
-
-  // 1D vector types are compatible if their element types are.
-  if (auto vecType = type.dyn_cast<VectorType>())
-    return vecType.getRank() == 1 && isCompatibleType(vecType.getElementType());
-
+bool mlir::LLVM::isCompatibleOuterType(Type type) {
   // clang-format off
-  return type.isa<
+  if (type.isa<
       BFloat16Type,
       Float16Type,
       Float32Type,
@@ -512,8 +505,75 @@ bool mlir::LLVM::isCompatibleType(Type type) {
       LLVMScalableVectorType,
       LLVMVoidType,
       LLVMX86MMXType
-  >();
-  // clang-format on
+    >()) {
+    // clang-format on
+    return true;
+  }
+
+  // Only signless integers are compatible.
+  if (auto intType = type.dyn_cast<IntegerType>())
+    return intType.isSignless();
+
+  // 1D vector types are compatible.
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getRank() == 1;
+
+  return false;
+}
+
+static bool isCompatibleImpl(Type type, SetVector<Type> &callstack) {
+  if (callstack.contains(type))
+    return true;
+
+  callstack.insert(type);
+  auto stackPopper = llvm::make_scope_exit([&] { callstack.pop_back(); });
+
+  auto isCompatible = [&](Type type) {
+    return isCompatibleImpl(type, callstack);
+  };
+
+  return llvm::TypeSwitch<Type, bool>(type)
+      .Case<LLVMStructType>([&](auto structType) {
+        return llvm::all_of(structType.getBody(), isCompatible);
+      })
+      .Case<LLVMFunctionType>([&](auto funcType) {
+        return isCompatible(funcType.getReturnType()) &&
+               llvm::all_of(funcType.getParams(), isCompatible);
+      })
+      .Case<IntegerType>([](auto intType) { return intType.isSignless(); })
+      .Case<VectorType>([&](auto vecType) {
+        return vecType.getRank() == 1 && isCompatible(vecType.getElementType());
+      })
+      // clang-format off
+      .Case<
+          LLVMPointerType,
+          LLVMFixedVectorType,
+          LLVMScalableVectorType,
+          LLVMArrayType
+      >([&](auto containerType) {
+        return isCompatible(containerType.getElementType());
+      })
+      .Case<
+        BFloat16Type,
+        Float16Type,
+        Float32Type,
+        Float64Type,
+        Float80Type,
+        Float128Type,
+        LLVMLabelType,
+        LLVMMetadataType,
+        LLVMPPCFP128Type,
+        LLVMTokenType,
+        LLVMVoidType,
+        LLVMX86MMXType
+      >([](Type) { return true; })
+      // clang-format on
+      .Default([](Type) { return false; });
+}
+
+bool mlir::LLVM::isCompatibleType(Type type) {
+  SetVector<Type> callstack;
+  return isCompatibleImpl(type, callstack);
 }
 
 bool mlir::LLVM::isCompatibleFloatingPointType(Type type) {
index 298ba1a6a0991bb0d646a244c6b62ed1067e8d4d..8c214a8f03e9c2797a46157a425e08d063c8f43e 100644 (file)
@@ -16,6 +16,10 @@ func private @struct_ptr() -> !llvm.struct<(ptr<!test.smpla>)>
 // CHECK: !llvm.struct<"_Converted_named", (ptr<i42>)>
 func private @named_struct_ptr() -> !llvm.struct<"named", (ptr<!test.smpla>)>
 
+// CHECK-LABEL: @named_no_convert
+// CHECK: !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
+func private @named_no_convert() -> !llvm.struct<"no_convert", (ptr<struct<"no_convert">>)>
+
 // CHECK-LABEL: @array_ptr()
 // CHECK: !llvm.array<10 x ptr<i42>> 
 func private @array_ptr() -> !llvm.array<10 x ptr<!test.smpla>>