[mlir] translate types between MLIR LLVM dialect and LLVM IR
authorAlex Zinenko <zinenko@google.com>
Tue, 4 Aug 2020 09:37:25 +0000 (11:37 +0200)
committerAlex Zinenko <zinenko@google.com>
Tue, 4 Aug 2020 11:42:43 +0000 (13:42 +0200)
With new LLVM dialect type modeling, the dialect types no longer wrap LLVM IR
types. Therefore, they need to be translated to and from LLVM IR during export
and import. Introduce the relevant functionality for translating types. It is
currently exercised by an ad-hoc type translation roundtripping test that will
be subsumed by the actual translation test when the type system transition is
complete.

Depends On D84339

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D85019

mlir/include/mlir/Target/LLVMIR/TypeTranslation.h [new file with mode: 0644]
mlir/lib/Target/CMakeLists.txt
mlir/lib/Target/LLVMIR/TypeTranslation.cpp [new file with mode: 0644]
mlir/test/Target/llvmir-types.mlir [new file with mode: 0644]
mlir/test/lib/CMakeLists.txt
mlir/test/lib/Target/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/Target/TestLLVMTypeTranslation.cpp [new file with mode: 0644]
mlir/tools/mlir-translate/CMakeLists.txt
mlir/tools/mlir-translate/mlir-translate.cpp

diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
new file mode 100644 (file)
index 0000000..5a82f0a
--- /dev/null
@@ -0,0 +1,36 @@
+//===- TypeTranslation.h - Translate types between MLIR & LLVM --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the type translation function going from MLIR LLVM dialect
+// to LLVM IR and back.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
+
+namespace llvm {
+class LLVMContext;
+class Type;
+} // namespace llvm
+
+namespace mlir {
+
+class MLIRContext;
+
+namespace LLVM {
+
+class LLVMTypeNew;
+
+llvm::Type *translateTypeToLLVMIR(LLVMTypeNew type, llvm::LLVMContext &context);
+LLVMTypeNew translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context);
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_TYPETRANSLATION_H
index 4a0af66..5ca335b 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation
   LLVMIR/DebugTranslation.cpp
   LLVMIR/ModuleTranslation.cpp
+  LLVMIR/TypeTranslation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
new file mode 100644 (file)
index 0000000..6163334
--- /dev/null
@@ -0,0 +1,309 @@
+//===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Type.h"
+
+using namespace mlir;
+
+namespace {
+/// Support for translating MLIR LLVM dialect types to LLVM IR.
+class TypeToLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given LLVM context.
+  TypeToLLVMIRTranslator(llvm::LLVMContext &context) : context(context) {}
+
+  /// Translates a single type.
+  llvm::Type *translateType(LLVM::LLVMTypeNew type) {
+    // If the conversion is already known, just return it.
+    if (knownTranslations.count(type))
+      return knownTranslations.lookup(type);
+
+    // Dispatch to an appropriate function.
+    llvm::Type *translated =
+        llvm::TypeSwitch<LLVM::LLVMTypeNew, llvm::Type *>(type)
+            .Case([this](LLVM::LLVMVoidType) {
+              return llvm::Type::getVoidTy(context);
+            })
+            .Case([this](LLVM::LLVMHalfType) {
+              return llvm::Type::getHalfTy(context);
+            })
+            .Case([this](LLVM::LLVMBFloatType) {
+              return llvm::Type::getBFloatTy(context);
+            })
+            .Case([this](LLVM::LLVMFloatType) {
+              return llvm::Type::getFloatTy(context);
+            })
+            .Case([this](LLVM::LLVMDoubleType) {
+              return llvm::Type::getDoubleTy(context);
+            })
+            .Case([this](LLVM::LLVMFP128Type) {
+              return llvm::Type::getFP128Ty(context);
+            })
+            .Case([this](LLVM::LLVMX86FP80Type) {
+              return llvm::Type::getX86_FP80Ty(context);
+            })
+            .Case([this](LLVM::LLVMPPCFP128Type) {
+              return llvm::Type::getPPC_FP128Ty(context);
+            })
+            .Case([this](LLVM::LLVMX86MMXType) {
+              return llvm::Type::getX86_MMXTy(context);
+            })
+            .Case([this](LLVM::LLVMTokenType) {
+              return llvm::Type::getTokenTy(context);
+            })
+            .Case([this](LLVM::LLVMLabelType) {
+              return llvm::Type::getLabelTy(context);
+            })
+            .Case([this](LLVM::LLVMMetadataType) {
+              return llvm::Type::getMetadataTy(context);
+            })
+            .Case<LLVM::LLVMArrayType, LLVM::LLVMIntegerType,
+                  LLVM::LLVMFunctionType, LLVM::LLVMPointerType,
+                  LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
+                  LLVM::LLVMScalableVectorType>(
+                [this](auto array) { return translate(array); })
+            .Default([](LLVM::LLVMTypeNew t) -> llvm::Type * {
+              llvm_unreachable("unknown LLVM dialect type");
+            });
+
+    // Cache the result of the conversion and return.
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given array type.
+  llvm::Type *translate(LLVM::LLVMArrayType type) {
+    return llvm::ArrayType::get(translateType(type.getElementType()),
+                                type.getNumElements());
+  }
+
+  /// Translates the given function type.
+  llvm::Type *translate(LLVM::LLVMFunctionType type) {
+    SmallVector<llvm::Type *, 8> paramTypes;
+    translateTypes(type.getParams(), paramTypes);
+    return llvm::FunctionType::get(translateType(type.getReturnType()),
+                                   paramTypes, type.isVarArg());
+  }
+
+  /// Translates the given integer type.
+  llvm::Type *translate(LLVM::LLVMIntegerType type) {
+    return llvm::IntegerType::get(context, type.getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  llvm::Type *translate(LLVM::LLVMPointerType type) {
+    return llvm::PointerType::get(translateType(type.getElementType()),
+                                  type.getAddressSpace());
+  }
+
+  /// Translates the given structure type, supports both identified and literal
+  /// structs. This will _create_ a new identified structure every time, use
+  /// `convertType` if a structure with the same name must be looked up instead.
+  llvm::Type *translate(LLVM::LLVMStructType type) {
+    SmallVector<llvm::Type *, 8> subtypes;
+    if (!type.isIdentified()) {
+      translateTypes(type.getBody(), subtypes);
+      return llvm::StructType::get(context, subtypes, type.isPacked());
+    }
+
+    llvm::StructType *structType =
+        llvm::StructType::create(context, type.getName());
+    // Mark the type we just created as known so that recursive calls can pick
+    // it up and use directly.
+    knownTranslations.try_emplace(type, structType);
+    if (type.isOpaque())
+      return structType;
+
+    translateTypes(type.getBody(), subtypes);
+    structType->setBody(subtypes, type.isPacked());
+    return structType;
+  }
+
+  /// Translates the given fixed-vector type.
+  llvm::Type *translate(LLVM::LLVMFixedVectorType type) {
+    return llvm::FixedVectorType::get(translateType(type.getElementType()),
+                                      type.getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  llvm::Type *translate(LLVM::LLVMScalableVectorType type) {
+    return llvm::ScalableVectorType::get(translateType(type.getElementType()),
+                                         type.getMinNumElements());
+  }
+
+  /// Translates a list of types.
+  void translateTypes(ArrayRef<LLVM::LLVMTypeNew> types,
+                      SmallVectorImpl<llvm::Type *> &result) {
+    result.reserve(result.size() + types.size());
+    for (auto type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Reference to the context in which the LLVM IR types are created.
+  llvm::LLVMContext &context;
+
+  /// Map of known translation. This serves a double purpose: caches translation
+  /// results to avoid repeated recursive calls and makes sure identified
+  /// structs with the same name (that is, equal) are resolved to an existing
+  /// type instead of creating a new type.
+  llvm::DenseMap<LLVM::LLVMTypeNew, llvm::Type *> knownTranslations;
+};
+} // end namespace
+
+/// Translates a type from MLIR LLVM dialect to LLVM IR. This does not maintain
+/// the mapping for identified structs so new structs will be created with
+/// auto-renaming on each call. This is intended exclusively for testing.
+llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMTypeNew type,
+                                              llvm::LLVMContext &context) {
+  return TypeToLLVMIRTranslator(context).translateType(type);
+}
+
+namespace {
+/// Support for translating LLVM IR types to MLIR LLVM dialect types.
+class TypeFromLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given MLIR context.
+  TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
+
+  /// Translates the given type.
+  LLVM::LLVMTypeNew translateType(llvm::Type *type) {
+    if (knownTranslations.count(type))
+      return knownTranslations.lookup(type);
+
+    LLVM::LLVMTypeNew translated =
+        llvm::TypeSwitch<llvm::Type *, LLVM::LLVMTypeNew>(type)
+            .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
+                  llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
+                  llvm::ScalableVectorType>(
+                [this](auto *type) { return translate(type); })
+            .Default([this](llvm::Type *type) {
+              return translatePrimitiveType(type);
+            });
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
+  /// type.
+  LLVM::LLVMTypeNew translatePrimitiveType(llvm::Type *type) {
+    if (type->isVoidTy())
+      return LLVM::LLVMVoidType::get(&context);
+    if (type->isHalfTy())
+      return LLVM::LLVMHalfType::get(&context);
+    if (type->isBFloatTy())
+      return LLVM::LLVMBFloatType::get(&context);
+    if (type->isFloatTy())
+      return LLVM::LLVMFloatType::get(&context);
+    if (type->isDoubleTy())
+      return LLVM::LLVMDoubleType::get(&context);
+    if (type->isFP128Ty())
+      return LLVM::LLVMFP128Type::get(&context);
+    if (type->isX86_FP80Ty())
+      return LLVM::LLVMX86FP80Type::get(&context);
+    if (type->isPPC_FP128Ty())
+      return LLVM::LLVMPPCFP128Type::get(&context);
+    if (type->isX86_MMXTy())
+      return LLVM::LLVMX86MMXType::get(&context);
+    if (type->isLabelTy())
+      return LLVM::LLVMLabelType::get(&context);
+    if (type->isMetadataTy())
+      return LLVM::LLVMMetadataType::get(&context);
+    llvm_unreachable("not a primitive type");
+  }
+
+  /// Translates the given array type.
+  LLVM::LLVMTypeNew translate(llvm::ArrayType *type) {
+    return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
+                                    type->getNumElements());
+  }
+
+  /// Translates the given function type.
+  LLVM::LLVMTypeNew translate(llvm::FunctionType *type) {
+    SmallVector<LLVM::LLVMTypeNew, 8> paramTypes;
+    translateTypes(type->params(), paramTypes);
+    return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
+                                       paramTypes, type->isVarArg());
+  }
+
+  /// Translates the given integer type.
+  LLVM::LLVMTypeNew translate(llvm::IntegerType *type) {
+    return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  LLVM::LLVMTypeNew translate(llvm::PointerType *type) {
+    return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
+                                      type->getAddressSpace());
+  }
+
+  /// Translates the given structure type.
+  LLVM::LLVMTypeNew translate(llvm::StructType *type) {
+    SmallVector<LLVM::LLVMTypeNew, 8> subtypes;
+    if (type->isLiteral()) {
+      translateTypes(type->subtypes(), subtypes);
+      return LLVM::LLVMStructType::getLiteral(&context, subtypes,
+                                              type->isPacked());
+    }
+
+    if (type->isOpaque())
+      return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
+
+    LLVM::LLVMStructType translated =
+        LLVM::LLVMStructType::getIdentified(&context, type->getName());
+    knownTranslations.try_emplace(type, translated);
+    translateTypes(type->subtypes(), subtypes);
+    LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
+    assert(succeeded(bodySet) &&
+           "could not set the body of an identified struct");
+    (void)bodySet;
+    return translated;
+  }
+
+  /// Translates the given fixed-vector type.
+  LLVM::LLVMTypeNew translate(llvm::FixedVectorType *type) {
+    return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
+                                          type->getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  LLVM::LLVMTypeNew translate(llvm::ScalableVectorType *type) {
+    return LLVM::LLVMScalableVectorType::get(
+        translateType(type->getElementType()), type->getMinNumElements());
+  }
+
+  /// Translates a list of types.
+  void translateTypes(ArrayRef<llvm::Type *> types,
+                      SmallVectorImpl<LLVM::LLVMTypeNew> &result) {
+    result.reserve(result.size() + types.size());
+    for (llvm::Type *type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Map of known translations. Serves as a cache and as recursion stopper for
+  /// translating recursive structs.
+  llvm::DenseMap<llvm::Type *, LLVM::LLVMTypeNew> knownTranslations;
+
+  /// The context in which MLIR types are created.
+  MLIRContext &context;
+};
+} // end namespace
+
+/// Translates a type from LLVM IR to MLIR LLVM dialect. This is intended
+/// exclusively for testing.
+LLVM::LLVMTypeNew mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type,
+                                                      MLIRContext &context) {
+  return TypeFromLLVMIRTranslator(context).translateType(type);
+}
diff --git a/mlir/test/Target/llvmir-types.mlir b/mlir/test/Target/llvmir-types.mlir
new file mode 100644 (file)
index 0000000..d807562
--- /dev/null
@@ -0,0 +1,228 @@
+// RUN: mlir-translate -test-mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @primitives() {
+  // CHECK: declare void @return_void()
+  // CHECK: declare void @return_void_round()
+  "llvm.test_introduce_func"() { name = "return_void", type = !llvm2.void } : () -> ()
+  // CHECK: declare half @return_half()
+  // CHECK: declare half @return_half_round()
+  "llvm.test_introduce_func"() { name = "return_half", type = !llvm2.half } : () -> ()
+  // CHECK: declare bfloat @return_bfloat()
+  // CHECK: declare bfloat @return_bfloat_round()
+  "llvm.test_introduce_func"() { name = "return_bfloat", type = !llvm2.bfloat } : () -> ()
+  // CHECK: declare float @return_float()
+  // CHECK: declare float @return_float_round()
+  "llvm.test_introduce_func"() { name = "return_float", type = !llvm2.float } : () -> ()
+  // CHECK: declare double @return_double()
+  // CHECK: declare double @return_double_round()
+  "llvm.test_introduce_func"() { name = "return_double", type = !llvm2.double } : () -> ()
+  // CHECK: declare fp128 @return_fp128()
+  // CHECK: declare fp128 @return_fp128_round()
+  "llvm.test_introduce_func"() { name = "return_fp128", type = !llvm2.fp128 } : () -> ()
+  // CHECK: declare x86_fp80 @return_x86_fp80()
+  // CHECK: declare x86_fp80 @return_x86_fp80_round()
+  "llvm.test_introduce_func"() { name = "return_x86_fp80", type = !llvm2.x86_fp80 } : () -> ()
+  // CHECK: declare ppc_fp128 @return_ppc_fp128()
+  // CHECK: declare ppc_fp128 @return_ppc_fp128_round()
+  "llvm.test_introduce_func"() { name = "return_ppc_fp128", type = !llvm2.ppc_fp128 } : () -> ()
+  // CHECK: declare x86_mmx @return_x86_mmx()
+  // CHECK: declare x86_mmx @return_x86_mmx_round()
+  "llvm.test_introduce_func"() { name = "return_x86_mmx", type = !llvm2.x86_mmx } : () -> ()
+  llvm.return
+}
+
+llvm.func @funcs() {
+  // CHECK: declare void @f_void_i32(i32)
+  // CHECK: declare void @f_void_i32_round(i32)
+  "llvm.test_introduce_func"() { name ="f_void_i32", type = !llvm2.func<void (i32)> } : () -> ()
+  // CHECK: declare i32 @f_i32_empty()
+  // CHECK: declare i32 @f_i32_empty_round()
+  "llvm.test_introduce_func"() { name ="f_i32_empty", type = !llvm2.func<i32 ()> } : () -> ()
+  // CHECK: declare i32 @f_i32_half_bfloat_float_double(half, bfloat, float, double)
+  // CHECK: declare i32 @f_i32_half_bfloat_float_double_round(half, bfloat, float, double)
+  "llvm.test_introduce_func"() { name ="f_i32_half_bfloat_float_double", type = !llvm2.func<i32 (half, bfloat, float, double)> } : () -> ()
+  // CHECK: declare i32 @f_i32_i32_i32(i32, i32)
+  // CHECK: declare i32 @f_i32_i32_i32_round(i32, i32)
+  "llvm.test_introduce_func"() { name ="f_i32_i32_i32", type = !llvm2.func<i32 (i32, i32)> } : () -> ()
+  // CHECK: declare void @f_void_variadic(...)
+  // CHECK: declare void @f_void_variadic_round(...)
+  "llvm.test_introduce_func"() { name ="f_void_variadic", type = !llvm2.func<void (...)> } : () -> ()
+  // CHECK: declare void @f_void_i32_i32_variadic(i32, i32, ...)
+  // CHECK: declare void @f_void_i32_i32_variadic_round(i32, i32, ...)
+  "llvm.test_introduce_func"() { name ="f_void_i32_i32_variadic", type = !llvm2.func<void (i32, i32, ...)> } : () -> ()
+  llvm.return
+}
+
+llvm.func @ints() {
+  // CHECK: declare i1 @return_i1()
+  // CHECK: declare i1 @return_i1_round()
+  "llvm.test_introduce_func"() { name = "return_i1", type = !llvm2.i1 } : () -> ()
+  // CHECK: declare i8 @return_i8()
+  // CHECK: declare i8 @return_i8_round()
+  "llvm.test_introduce_func"() { name = "return_i8", type = !llvm2.i8 } : () -> ()
+  // CHECK: declare i16 @return_i16()
+  // CHECK: declare i16 @return_i16_round()
+  "llvm.test_introduce_func"() { name = "return_i16", type = !llvm2.i16 } : () -> ()
+  // CHECK: declare i32 @return_i32()
+  // CHECK: declare i32 @return_i32_round()
+  "llvm.test_introduce_func"() { name = "return_i32", type = !llvm2.i32 } : () -> ()
+  // CHECK: declare i64 @return_i64()
+  // CHECK: declare i64 @return_i64_round()
+  "llvm.test_introduce_func"() { name = "return_i64", type = !llvm2.i64 } : () -> ()
+  // CHECK: declare i57 @return_i57()
+  // CHECK: declare i57 @return_i57_round()
+  "llvm.test_introduce_func"() { name = "return_i57", type = !llvm2.i57 } : () -> ()
+  // CHECK: declare i129 @return_i129()
+  // CHECK: declare i129 @return_i129_round()
+  "llvm.test_introduce_func"() { name = "return_i129", type = !llvm2.i129 } : () -> ()
+  llvm.return
+}
+
+llvm.func @pointers() {
+  // CHECK: declare i8* @return_pi8()
+  // CHECK: declare i8* @return_pi8_round()
+  "llvm.test_introduce_func"() { name = "return_pi8", type = !llvm2.ptr<i8> } : () -> ()
+  // CHECK: declare float* @return_pfloat()
+  // CHECK: declare float* @return_pfloat_round()
+  "llvm.test_introduce_func"() { name = "return_pfloat", type = !llvm2.ptr<float> } : () -> ()
+  // CHECK: declare i8** @return_ppi8()
+  // CHECK: declare i8** @return_ppi8_round()
+  "llvm.test_introduce_func"() { name = "return_ppi8", type = !llvm2.ptr<ptr<i8>> } : () -> ()
+  // CHECK: declare i8***** @return_pppppi8()
+  // CHECK: declare i8***** @return_pppppi8_round()
+  "llvm.test_introduce_func"() { name = "return_pppppi8", type = !llvm2.ptr<ptr<ptr<ptr<ptr<i8>>>>> } : () -> ()
+  // CHECK: declare i8* @return_pi8_0()
+  // CHECK: declare i8* @return_pi8_0_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_0", type = !llvm2.ptr<i8, 0> } : () -> ()
+  // CHECK: declare i8 addrspace(1)* @return_pi8_1()
+  // CHECK: declare i8 addrspace(1)* @return_pi8_1_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_1", type = !llvm2.ptr<i8, 1> } : () -> ()
+  // CHECK: declare i8 addrspace(42)* @return_pi8_42()
+  // CHECK: declare i8 addrspace(42)* @return_pi8_42_round()
+  "llvm.test_introduce_func"() { name = "return_pi8_42", type = !llvm2.ptr<i8, 42> } : () -> ()
+  // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9()
+  // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9_round()
+  "llvm.test_introduce_func"() { name = "return_ppi8_42_9", type = !llvm2.ptr<ptr<i8, 42>, 9> } : () -> ()
+  llvm.return
+}
+
+llvm.func @vectors() {
+  // CHECK: declare <4 x i32> @return_v4_i32()
+  // CHECK: declare <4 x i32> @return_v4_i32_round()
+  "llvm.test_introduce_func"() { name = "return_v4_i32", type = !llvm2.vec<4 x i32> } : () -> ()
+  // CHECK: declare <4 x float> @return_v4_float()
+  // CHECK: declare <4 x float> @return_v4_float_round()
+  "llvm.test_introduce_func"() { name = "return_v4_float", type = !llvm2.vec<4 x float> } : () -> ()
+  // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32()
+  // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32_round()
+  "llvm.test_introduce_func"() { name = "return_vs_4_i32", type = !llvm2.vec<? x 4 x i32> } : () -> ()
+  // CHECK: declare <vscale x 8 x half> @return_vs_8_half()
+  // CHECK: declare <vscale x 8 x half> @return_vs_8_half_round()
+  "llvm.test_introduce_func"() { name = "return_vs_8_half", type = !llvm2.vec<? x 8 x half> } : () -> ()
+  // CHECK: declare <4 x i8*> @return_v_4_pi8()
+  // CHECK: declare <4 x i8*> @return_v_4_pi8_round()
+  "llvm.test_introduce_func"() { name = "return_v_4_pi8", type = !llvm2.vec<4 x ptr<i8>> } : () -> ()
+  llvm.return
+}
+
+llvm.func @arrays() {
+  // CHECK: declare [10 x i32] @return_a10_i32()
+  // CHECK: declare [10 x i32] @return_a10_i32_round()
+  "llvm.test_introduce_func"() { name = "return_a10_i32", type = !llvm2.array<10 x i32> } : () -> ()
+  // CHECK: declare [8 x float] @return_a8_float()
+  // CHECK: declare [8 x float] @return_a8_float_round()
+  "llvm.test_introduce_func"() { name = "return_a8_float", type = !llvm2.array<8 x float> } : () -> ()
+  // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4()
+  // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4_round()
+  "llvm.test_introduce_func"() { name = "return_a10_pi32_4", type = !llvm2.array<10 x ptr<i32, 4>> } : () -> ()
+  // CHECK: declare [10 x [4 x float]] @return_a10_a4_float()
+  // CHECK: declare [10 x [4 x float]] @return_a10_a4_float_round()
+  "llvm.test_introduce_func"() { name = "return_a10_a4_float", type = !llvm2.array<10 x array<4 x float>> } : () -> ()
+  llvm.return
+}
+
+llvm.func @literal_structs() {
+  // CHECK: declare {} @return_struct_empty()
+  // CHECK: declare {} @return_struct_empty_round()
+  "llvm.test_introduce_func"() { name = "return_struct_empty", type = !llvm2.struct<()> } : () -> ()
+  // CHECK: declare { i32 } @return_s_i32()
+  // CHECK: declare { i32 } @return_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_i32", type = !llvm2.struct<(i32)> } : () -> ()
+  // CHECK: declare { float, i32 } @return_s_float_i32()
+  // CHECK: declare { float, i32 } @return_s_float_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_float_i32", type = !llvm2.struct<(float, i32)> } : () -> ()
+  // CHECK: declare { { i32 } } @return_s_s_i32()
+  // CHECK: declare { { i32 } } @return_s_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_s_i32", type = !llvm2.struct<(struct<(i32)>)> } : () -> ()
+  // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float()
+  // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float_round()
+  "llvm.test_introduce_func"() { name = "return_s_i32_s_i32_float", type = !llvm2.struct<(i32, struct<(i32)>, float)> } : () -> ()
+
+  // CHECK: declare <{}> @return_sp_empty()
+  // CHECK: declare <{}> @return_sp_empty_round()
+  "llvm.test_introduce_func"() { name = "return_sp_empty", type = !llvm2.struct<packed ()> } : () -> ()
+  // CHECK: declare <{ i32 }> @return_sp_i32()
+  // CHECK: declare <{ i32 }> @return_sp_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_i32", type = !llvm2.struct<packed (i32)> } : () -> ()
+  // CHECK: declare <{ float, i32 }> @return_sp_float_i32()
+  // CHECK: declare <{ float, i32 }> @return_sp_float_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_float_i32", type = !llvm2.struct<packed (float, i32)> } : () -> ()
+  // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float()
+  // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float_round()
+  "llvm.test_introduce_func"() { name = "return_sp_i32_s_i31_1_float", type = !llvm2.struct<packed (i32, struct<(i32, i1)>, float)> } : () -> ()
+
+  // CHECK: declare { <{ i32 }> } @return_s_sp_i32()
+  // CHECK: declare { <{ i32 }> } @return_s_sp_i32_round()
+  "llvm.test_introduce_func"() { name = "return_s_sp_i32", type = !llvm2.struct<(struct<packed (i32)>)> } : () -> ()
+  // CHECK: declare <{ { i32 } }> @return_sp_s_i32()
+  // CHECK: declare <{ { i32 } }> @return_sp_s_i32_round()
+  "llvm.test_introduce_func"() { name = "return_sp_s_i32", type = !llvm2.struct<packed (struct<(i32)>)> } : () -> ()
+  llvm.return
+}
+
+// -----
+// Put structs into a separate split so that we can match their declarations
+// locally.
+
+// CHECK: %empty = type {}
+// CHECK: %opaque = type opaque
+// CHECK: %long = type { i32, { i32, i1 }, float, void ()* }
+// CHECK: %self-recursive = type { %self-recursive* }
+// CHECK: %unpacked = type { i32 }
+// CHECK: %packed = type <{ i32 }>
+// CHECK: %"name with spaces and !^$@$#" = type <{ i32 }>
+// CHECK: %mutually-a = type { %mutually-b* }
+// CHECK: %mutually-b = type { %mutually-a addrspace(3)* }
+// CHECK: %struct-of-arrays = type { [10 x i32] }
+// CHECK: %array-of-structs = type { i32 }
+// CHECK: %ptr-to-struct = type { i8 }
+
+llvm.func @identified_structs() {
+  // CHECK: declare %empty
+  "llvm.test_introduce_func"() { name = "return_s_empty", type = !llvm2.struct<"empty", ()> } : () -> ()
+  // CHECK: declare %opaque
+  "llvm.test_introduce_func"() { name = "return_s_opaque", type = !llvm2.struct<"opaque", opaque> } : () -> ()
+  // CHECK: declare %long
+  "llvm.test_introduce_func"() { name = "return_s_long", type = !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr<func<void ()>>)> } : () -> ()
+  // CHECK: declare %self-recursive
+  "llvm.test_introduce_func"() { name = "return_s_self_recurisve", type = !llvm2.struct<"self-recursive", (ptr<struct<"self-recursive">>)> } : () -> ()
+  // CHECK: declare %unpacked
+  "llvm.test_introduce_func"() { name = "return_s_unpacked", type = !llvm2.struct<"unpacked", (i32)> } : () -> ()
+  // CHECK: declare %packed
+  "llvm.test_introduce_func"() { name = "return_s_packed", type = !llvm2.struct<"packed", packed (i32)> } : () -> ()
+  // CHECK: declare %"name with spaces and !^$@$#"
+  "llvm.test_introduce_func"() { name = "return_s_symbols", type = !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> } : () -> ()
+
+  // CHECK: declare %mutually-a
+  "llvm.test_introduce_func"() { name = "return_s_mutually_a", type = !llvm2.struct<"mutually-a", (ptr<struct<"mutually-b", (ptr<struct<"mutually-a">, 3>)>>)> } : () -> ()
+  // CHECK: declare %mutually-b
+  "llvm.test_introduce_func"() { name = "return_s_mutually_b", type = !llvm2.struct<"mutually-b", (ptr<struct<"mutually-a", (ptr<struct<"mutually-b">>)>, 3>)> } : () -> ()
+
+  // CHECK: declare %struct-of-arrays
+  "llvm.test_introduce_func"() { name = "return_s_struct_of_arrays", type = !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> } : () -> ()
+  // CHECK: declare [10 x %array-of-structs]
+  "llvm.test_introduce_func"() { name = "return_s_array_of_structs", type = !llvm2.array<10 x struct<"array-of-structs", (i32)>> } : () -> ()
+  // CHECK: declare %ptr-to-struct*
+  "llvm.test_introduce_func"() { name = "return_s_ptr_to_struct", type = !llvm2.ptr<struct<"ptr-to-struct", (i8)>> } : () -> ()
+  llvm.return
+}
index 0df357c..ec9e5cd 100644 (file)
@@ -2,4 +2,5 @@ add_subdirectory(Dialect)
 add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(Reducer)
+add_subdirectory(Target)
 add_subdirectory(Transforms)
diff --git a/mlir/test/lib/Target/CMakeLists.txt b/mlir/test/lib/Target/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cb8f206
--- /dev/null
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIRTestLLVMTypeTranslation
+  TestLLVMTypeTranslation.cpp
+
+  LINK_COMPONENTS
+  Core
+  TransformUtils
+
+  LINK_LIBS PUBLIC
+  MLIRLLVMIR
+  MLIRTargetLLVMIRModuleTranslation
+  MLIRTestIR
+  MLIRTranslation
+  )
diff --git a/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp
new file mode 100644 (file)
index 0000000..b76ac2a
--- /dev/null
@@ -0,0 +1,79 @@
+//===- TestLLVMTypeTranslation.cpp - Test MLIR/LLVM IR type translation ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Target/LLVMIR/TypeTranslation.h"
+#include "mlir/Translation.h"
+
+using namespace mlir;
+
+namespace {
+class TestLLVMTypeTranslation : public LLVM::ModuleTranslation {
+  // Allow access to the constructors under MSVC.
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  /// Simple test facility for translating types from MLIR LLVM dialect to LLVM
+  /// IR. This converts the "llvm.test_introduce_func" operation into an LLVM IR
+  /// function with the name extracted from the `name` attribute that returns
+  /// the type contained in the `type` attribute if it is a non-function type or
+  /// that has the signature obtained by converting `type` if it is a function
+  /// type. This is a temporary check before type translation is substituted
+  /// into the main translation flow and exercised here.
+  LogicalResult convertOperation(Operation &op,
+                                 llvm::IRBuilder<> &builder) override {
+    if (op.getName().getStringRef() == "llvm.test_introduce_func") {
+      auto attr = op.getAttrOfType<TypeAttr>("type");
+      assert(attr && "expected 'type' attribute");
+      auto type = attr.getValue().cast<LLVM::LLVMTypeNew>();
+
+      auto nameAttr = op.getAttrOfType<StringAttr>("name");
+      assert(nameAttr && "expected 'name' attributes");
+
+      llvm::Type *translated =
+          LLVM::translateTypeToLLVMIR(type, builder.getContext());
+
+      llvm::Module *module = builder.GetInsertBlock()->getModule();
+      if (auto *funcType = dyn_cast<llvm::FunctionType>(translated))
+        module->getOrInsertFunction(nameAttr.getValue(), funcType);
+      else
+        module->getOrInsertFunction(nameAttr.getValue(), translated);
+
+      std::string roundtripName = (Twine(nameAttr.getValue()) + "_round").str();
+      LLVM::LLVMTypeNew translatedBack =
+          LLVM::translateTypeFromLLVMIR(translated, *op.getContext());
+      llvm::Type *translatedBackAndForth =
+          LLVM::translateTypeToLLVMIR(translatedBack, builder.getContext());
+      if (auto *funcType = dyn_cast<llvm::FunctionType>(translatedBackAndForth))
+        module->getOrInsertFunction(roundtripName, funcType);
+      else
+        module->getOrInsertFunction(roundtripName, translatedBackAndForth);
+      return success();
+    }
+
+    return LLVM::ModuleTranslation::convertOperation(op, builder);
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestLLVMTypeTranslation() {
+  TranslateFromMLIRRegistration reg(
+      "test-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+        std::unique_ptr<llvm::Module> llvmModule =
+            LLVM::ModuleTranslation::translateModule<TestLLVMTypeTranslation>(
+                module.getOperation());
+        llvmModule->print(output, nullptr);
+        return success();
+      });
+}
+} // namespace mlir
index 897e7ad..1e6cdfe 100644 (file)
@@ -13,7 +13,11 @@ target_link_libraries(mlir-translate
   PRIVATE
   ${dialect_libs}
   ${translation_libs}
+  ${test_libs}
   MLIRIR
+  # TODO: remove after LLVM dialect transition is complete; translation uses a
+  # registration function defined in this library unconditionally.
+  MLIRLLVMTypeTestDialect
   MLIRParser
   MLIRPass
   MLIRSPIRV
index 914bd34..70bf285 100644 (file)
@@ -49,17 +49,21 @@ static llvm::cl::opt<bool> verifyDiagnostics(
 
 namespace mlir {
 // Defined in the test directory, no public header.
+void registerLLVMTypeTestDialect();
+void registerTestLLVMTypeTranslation();
 void registerTestRoundtripSPIRV();
 void registerTestRoundtripDebugSPIRV();
 } // namespace mlir
 
 static void registerTestTranslations() {
+  registerTestLLVMTypeTranslation();
   registerTestRoundtripSPIRV();
   registerTestRoundtripDebugSPIRV();
 }
 
 int main(int argc, char **argv) {
   registerAllDialects();
+  registerLLVMTypeTestDialect();
   registerAllTranslations();
   registerTestTranslations();
   llvm::InitLLVM y(argc, argv);