From d4fbbab2e494a59480096a257136ed2b75d07e87 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 4 Aug 2020 11:37:25 +0200 Subject: [PATCH] [mlir] translate types between MLIR LLVM dialect and LLVM IR With new LLVM dialect type modeling, the dialect types no longer wrap LLVM IR types. Therefore, they need to be translated to and from LLVM IR during export and import. Introduce the relevant functionality for translating types. It is currently exercised by an ad-hoc type translation roundtripping test that will be subsumed by the actual translation test when the type system transition is complete. Depends On D84339 Reviewed By: herhut Differential Revision: https://reviews.llvm.org/D85019 --- mlir/include/mlir/Target/LLVMIR/TypeTranslation.h | 36 +++ mlir/lib/Target/CMakeLists.txt | 1 + mlir/lib/Target/LLVMIR/TypeTranslation.cpp | 309 ++++++++++++++++++++++ mlir/test/Target/llvmir-types.mlir | 228 ++++++++++++++++ mlir/test/lib/CMakeLists.txt | 1 + mlir/test/lib/Target/CMakeLists.txt | 13 + mlir/test/lib/Target/TestLLVMTypeTranslation.cpp | 79 ++++++ mlir/tools/mlir-translate/CMakeLists.txt | 4 + mlir/tools/mlir-translate/mlir-translate.cpp | 4 + 9 files changed, 675 insertions(+) create mode 100644 mlir/include/mlir/Target/LLVMIR/TypeTranslation.h create mode 100644 mlir/lib/Target/LLVMIR/TypeTranslation.cpp create mode 100644 mlir/test/Target/llvmir-types.mlir create mode 100644 mlir/test/lib/Target/CMakeLists.txt create mode 100644 mlir/test/lib/Target/TestLLVMTypeTranslation.cpp diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h new file mode 100644 index 0000000..5a82f0a0 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h @@ -0,0 +1,36 @@ +//===- TypeTranslation.h - Translate types between MLIR & LLVM --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the type translation function going from MLIR LLVM dialect +// to LLVM IR and back. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_TYPETRANSLATION_H +#define MLIR_TARGET_LLVMIR_TYPETRANSLATION_H + +namespace llvm { +class LLVMContext; +class Type; +} // namespace llvm + +namespace mlir { + +class MLIRContext; + +namespace LLVM { + +class LLVMTypeNew; + +llvm::Type *translateTypeToLLVMIR(LLVMTypeNew type, llvm::LLVMContext &context); +LLVMTypeNew translateTypeFromLLVMIR(llvm::Type *type, MLIRContext &context); + +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_TYPETRANSLATION_H diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index 4a0af66..5ca335b 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_translation_library(MLIRTargetLLVMIRModuleTranslation LLVMIR/DebugTranslation.cpp LLVMIR/ModuleTranslation.cpp + LLVMIR/TypeTranslation.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp new file mode 100644 index 0000000..6163334 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp @@ -0,0 +1,309 @@ +//===- TypeTranslation.cpp - type translation between MLIR LLVM & LLVM IR -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/TypeTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/MLIRContext.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Type.h" + +using namespace mlir; + +namespace { +/// Support for translating MLIR LLVM dialect types to LLVM IR. +class TypeToLLVMIRTranslator { +public: + /// Constructs a class creating types in the given LLVM context. + TypeToLLVMIRTranslator(llvm::LLVMContext &context) : context(context) {} + + /// Translates a single type. + llvm::Type *translateType(LLVM::LLVMTypeNew type) { + // If the conversion is already known, just return it. + if (knownTranslations.count(type)) + return knownTranslations.lookup(type); + + // Dispatch to an appropriate function. + llvm::Type *translated = + llvm::TypeSwitch(type) + .Case([this](LLVM::LLVMVoidType) { + return llvm::Type::getVoidTy(context); + }) + .Case([this](LLVM::LLVMHalfType) { + return llvm::Type::getHalfTy(context); + }) + .Case([this](LLVM::LLVMBFloatType) { + return llvm::Type::getBFloatTy(context); + }) + .Case([this](LLVM::LLVMFloatType) { + return llvm::Type::getFloatTy(context); + }) + .Case([this](LLVM::LLVMDoubleType) { + return llvm::Type::getDoubleTy(context); + }) + .Case([this](LLVM::LLVMFP128Type) { + return llvm::Type::getFP128Ty(context); + }) + .Case([this](LLVM::LLVMX86FP80Type) { + return llvm::Type::getX86_FP80Ty(context); + }) + .Case([this](LLVM::LLVMPPCFP128Type) { + return llvm::Type::getPPC_FP128Ty(context); + }) + .Case([this](LLVM::LLVMX86MMXType) { + return llvm::Type::getX86_MMXTy(context); + }) + .Case([this](LLVM::LLVMTokenType) { + return llvm::Type::getTokenTy(context); + }) + .Case([this](LLVM::LLVMLabelType) { + return llvm::Type::getLabelTy(context); + }) + .Case([this](LLVM::LLVMMetadataType) { + return llvm::Type::getMetadataTy(context); + }) + .Case( + [this](auto array) { return translate(array); }) + .Default([](LLVM::LLVMTypeNew t) -> llvm::Type * { + llvm_unreachable("unknown LLVM dialect type"); + }); + + // Cache the result of the conversion and return. + knownTranslations.try_emplace(type, translated); + return translated; + } + +private: + /// Translates the given array type. + llvm::Type *translate(LLVM::LLVMArrayType type) { + return llvm::ArrayType::get(translateType(type.getElementType()), + type.getNumElements()); + } + + /// Translates the given function type. + llvm::Type *translate(LLVM::LLVMFunctionType type) { + SmallVector paramTypes; + translateTypes(type.getParams(), paramTypes); + return llvm::FunctionType::get(translateType(type.getReturnType()), + paramTypes, type.isVarArg()); + } + + /// Translates the given integer type. + llvm::Type *translate(LLVM::LLVMIntegerType type) { + return llvm::IntegerType::get(context, type.getBitWidth()); + } + + /// Translates the given pointer type. + llvm::Type *translate(LLVM::LLVMPointerType type) { + return llvm::PointerType::get(translateType(type.getElementType()), + type.getAddressSpace()); + } + + /// Translates the given structure type, supports both identified and literal + /// structs. This will _create_ a new identified structure every time, use + /// `convertType` if a structure with the same name must be looked up instead. + llvm::Type *translate(LLVM::LLVMStructType type) { + SmallVector subtypes; + if (!type.isIdentified()) { + translateTypes(type.getBody(), subtypes); + return llvm::StructType::get(context, subtypes, type.isPacked()); + } + + llvm::StructType *structType = + llvm::StructType::create(context, type.getName()); + // Mark the type we just created as known so that recursive calls can pick + // it up and use directly. + knownTranslations.try_emplace(type, structType); + if (type.isOpaque()) + return structType; + + translateTypes(type.getBody(), subtypes); + structType->setBody(subtypes, type.isPacked()); + return structType; + } + + /// Translates the given fixed-vector type. + llvm::Type *translate(LLVM::LLVMFixedVectorType type) { + return llvm::FixedVectorType::get(translateType(type.getElementType()), + type.getNumElements()); + } + + /// Translates the given scalable-vector type. + llvm::Type *translate(LLVM::LLVMScalableVectorType type) { + return llvm::ScalableVectorType::get(translateType(type.getElementType()), + type.getMinNumElements()); + } + + /// Translates a list of types. + void translateTypes(ArrayRef types, + SmallVectorImpl &result) { + result.reserve(result.size() + types.size()); + for (auto type : types) + result.push_back(translateType(type)); + } + + /// Reference to the context in which the LLVM IR types are created. + llvm::LLVMContext &context; + + /// Map of known translation. This serves a double purpose: caches translation + /// results to avoid repeated recursive calls and makes sure identified + /// structs with the same name (that is, equal) are resolved to an existing + /// type instead of creating a new type. + llvm::DenseMap knownTranslations; +}; +} // end namespace + +/// Translates a type from MLIR LLVM dialect to LLVM IR. This does not maintain +/// the mapping for identified structs so new structs will be created with +/// auto-renaming on each call. This is intended exclusively for testing. +llvm::Type *mlir::LLVM::translateTypeToLLVMIR(LLVM::LLVMTypeNew type, + llvm::LLVMContext &context) { + return TypeToLLVMIRTranslator(context).translateType(type); +} + +namespace { +/// Support for translating LLVM IR types to MLIR LLVM dialect types. +class TypeFromLLVMIRTranslator { +public: + /// Constructs a class creating types in the given MLIR context. + TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {} + + /// Translates the given type. + LLVM::LLVMTypeNew translateType(llvm::Type *type) { + if (knownTranslations.count(type)) + return knownTranslations.lookup(type); + + LLVM::LLVMTypeNew translated = + llvm::TypeSwitch(type) + .Case( + [this](auto *type) { return translate(type); }) + .Default([this](llvm::Type *type) { + return translatePrimitiveType(type); + }); + knownTranslations.try_emplace(type, translated); + return translated; + } + +private: + /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, + /// type. + LLVM::LLVMTypeNew translatePrimitiveType(llvm::Type *type) { + if (type->isVoidTy()) + return LLVM::LLVMVoidType::get(&context); + if (type->isHalfTy()) + return LLVM::LLVMHalfType::get(&context); + if (type->isBFloatTy()) + return LLVM::LLVMBFloatType::get(&context); + if (type->isFloatTy()) + return LLVM::LLVMFloatType::get(&context); + if (type->isDoubleTy()) + return LLVM::LLVMDoubleType::get(&context); + if (type->isFP128Ty()) + return LLVM::LLVMFP128Type::get(&context); + if (type->isX86_FP80Ty()) + return LLVM::LLVMX86FP80Type::get(&context); + if (type->isPPC_FP128Ty()) + return LLVM::LLVMPPCFP128Type::get(&context); + if (type->isX86_MMXTy()) + return LLVM::LLVMX86MMXType::get(&context); + if (type->isLabelTy()) + return LLVM::LLVMLabelType::get(&context); + if (type->isMetadataTy()) + return LLVM::LLVMMetadataType::get(&context); + llvm_unreachable("not a primitive type"); + } + + /// Translates the given array type. + LLVM::LLVMTypeNew translate(llvm::ArrayType *type) { + return LLVM::LLVMArrayType::get(translateType(type->getElementType()), + type->getNumElements()); + } + + /// Translates the given function type. + LLVM::LLVMTypeNew translate(llvm::FunctionType *type) { + SmallVector paramTypes; + translateTypes(type->params(), paramTypes); + return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), + paramTypes, type->isVarArg()); + } + + /// Translates the given integer type. + LLVM::LLVMTypeNew translate(llvm::IntegerType *type) { + return LLVM::LLVMIntegerType::get(&context, type->getBitWidth()); + } + + /// Translates the given pointer type. + LLVM::LLVMTypeNew translate(llvm::PointerType *type) { + return LLVM::LLVMPointerType::get(translateType(type->getElementType()), + type->getAddressSpace()); + } + + /// Translates the given structure type. + LLVM::LLVMTypeNew translate(llvm::StructType *type) { + SmallVector subtypes; + if (type->isLiteral()) { + translateTypes(type->subtypes(), subtypes); + return LLVM::LLVMStructType::getLiteral(&context, subtypes, + type->isPacked()); + } + + if (type->isOpaque()) + return LLVM::LLVMStructType::getOpaque(type->getName(), &context); + + LLVM::LLVMStructType translated = + LLVM::LLVMStructType::getIdentified(&context, type->getName()); + knownTranslations.try_emplace(type, translated); + translateTypes(type->subtypes(), subtypes); + LogicalResult bodySet = translated.setBody(subtypes, type->isPacked()); + assert(succeeded(bodySet) && + "could not set the body of an identified struct"); + (void)bodySet; + return translated; + } + + /// Translates the given fixed-vector type. + LLVM::LLVMTypeNew translate(llvm::FixedVectorType *type) { + return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()), + type->getNumElements()); + } + + /// Translates the given scalable-vector type. + LLVM::LLVMTypeNew translate(llvm::ScalableVectorType *type) { + return LLVM::LLVMScalableVectorType::get( + translateType(type->getElementType()), type->getMinNumElements()); + } + + /// Translates a list of types. + void translateTypes(ArrayRef types, + SmallVectorImpl &result) { + result.reserve(result.size() + types.size()); + for (llvm::Type *type : types) + result.push_back(translateType(type)); + } + + /// Map of known translations. Serves as a cache and as recursion stopper for + /// translating recursive structs. + llvm::DenseMap knownTranslations; + + /// The context in which MLIR types are created. + MLIRContext &context; +}; +} // end namespace + +/// Translates a type from LLVM IR to MLIR LLVM dialect. This is intended +/// exclusively for testing. +LLVM::LLVMTypeNew mlir::LLVM::translateTypeFromLLVMIR(llvm::Type *type, + MLIRContext &context) { + return TypeFromLLVMIRTranslator(context).translateType(type); +} diff --git a/mlir/test/Target/llvmir-types.mlir b/mlir/test/Target/llvmir-types.mlir new file mode 100644 index 0000000..d807562 --- /dev/null +++ b/mlir/test/Target/llvmir-types.mlir @@ -0,0 +1,228 @@ +// RUN: mlir-translate -test-mlir-to-llvmir -split-input-file %s | FileCheck %s + +llvm.func @primitives() { + // CHECK: declare void @return_void() + // CHECK: declare void @return_void_round() + "llvm.test_introduce_func"() { name = "return_void", type = !llvm2.void } : () -> () + // CHECK: declare half @return_half() + // CHECK: declare half @return_half_round() + "llvm.test_introduce_func"() { name = "return_half", type = !llvm2.half } : () -> () + // CHECK: declare bfloat @return_bfloat() + // CHECK: declare bfloat @return_bfloat_round() + "llvm.test_introduce_func"() { name = "return_bfloat", type = !llvm2.bfloat } : () -> () + // CHECK: declare float @return_float() + // CHECK: declare float @return_float_round() + "llvm.test_introduce_func"() { name = "return_float", type = !llvm2.float } : () -> () + // CHECK: declare double @return_double() + // CHECK: declare double @return_double_round() + "llvm.test_introduce_func"() { name = "return_double", type = !llvm2.double } : () -> () + // CHECK: declare fp128 @return_fp128() + // CHECK: declare fp128 @return_fp128_round() + "llvm.test_introduce_func"() { name = "return_fp128", type = !llvm2.fp128 } : () -> () + // CHECK: declare x86_fp80 @return_x86_fp80() + // CHECK: declare x86_fp80 @return_x86_fp80_round() + "llvm.test_introduce_func"() { name = "return_x86_fp80", type = !llvm2.x86_fp80 } : () -> () + // CHECK: declare ppc_fp128 @return_ppc_fp128() + // CHECK: declare ppc_fp128 @return_ppc_fp128_round() + "llvm.test_introduce_func"() { name = "return_ppc_fp128", type = !llvm2.ppc_fp128 } : () -> () + // CHECK: declare x86_mmx @return_x86_mmx() + // CHECK: declare x86_mmx @return_x86_mmx_round() + "llvm.test_introduce_func"() { name = "return_x86_mmx", type = !llvm2.x86_mmx } : () -> () + llvm.return +} + +llvm.func @funcs() { + // CHECK: declare void @f_void_i32(i32) + // CHECK: declare void @f_void_i32_round(i32) + "llvm.test_introduce_func"() { name ="f_void_i32", type = !llvm2.func } : () -> () + // CHECK: declare i32 @f_i32_empty() + // CHECK: declare i32 @f_i32_empty_round() + "llvm.test_introduce_func"() { name ="f_i32_empty", type = !llvm2.func } : () -> () + // CHECK: declare i32 @f_i32_half_bfloat_float_double(half, bfloat, float, double) + // CHECK: declare i32 @f_i32_half_bfloat_float_double_round(half, bfloat, float, double) + "llvm.test_introduce_func"() { name ="f_i32_half_bfloat_float_double", type = !llvm2.func } : () -> () + // CHECK: declare i32 @f_i32_i32_i32(i32, i32) + // CHECK: declare i32 @f_i32_i32_i32_round(i32, i32) + "llvm.test_introduce_func"() { name ="f_i32_i32_i32", type = !llvm2.func } : () -> () + // CHECK: declare void @f_void_variadic(...) + // CHECK: declare void @f_void_variadic_round(...) + "llvm.test_introduce_func"() { name ="f_void_variadic", type = !llvm2.func } : () -> () + // CHECK: declare void @f_void_i32_i32_variadic(i32, i32, ...) + // CHECK: declare void @f_void_i32_i32_variadic_round(i32, i32, ...) + "llvm.test_introduce_func"() { name ="f_void_i32_i32_variadic", type = !llvm2.func } : () -> () + llvm.return +} + +llvm.func @ints() { + // CHECK: declare i1 @return_i1() + // CHECK: declare i1 @return_i1_round() + "llvm.test_introduce_func"() { name = "return_i1", type = !llvm2.i1 } : () -> () + // CHECK: declare i8 @return_i8() + // CHECK: declare i8 @return_i8_round() + "llvm.test_introduce_func"() { name = "return_i8", type = !llvm2.i8 } : () -> () + // CHECK: declare i16 @return_i16() + // CHECK: declare i16 @return_i16_round() + "llvm.test_introduce_func"() { name = "return_i16", type = !llvm2.i16 } : () -> () + // CHECK: declare i32 @return_i32() + // CHECK: declare i32 @return_i32_round() + "llvm.test_introduce_func"() { name = "return_i32", type = !llvm2.i32 } : () -> () + // CHECK: declare i64 @return_i64() + // CHECK: declare i64 @return_i64_round() + "llvm.test_introduce_func"() { name = "return_i64", type = !llvm2.i64 } : () -> () + // CHECK: declare i57 @return_i57() + // CHECK: declare i57 @return_i57_round() + "llvm.test_introduce_func"() { name = "return_i57", type = !llvm2.i57 } : () -> () + // CHECK: declare i129 @return_i129() + // CHECK: declare i129 @return_i129_round() + "llvm.test_introduce_func"() { name = "return_i129", type = !llvm2.i129 } : () -> () + llvm.return +} + +llvm.func @pointers() { + // CHECK: declare i8* @return_pi8() + // CHECK: declare i8* @return_pi8_round() + "llvm.test_introduce_func"() { name = "return_pi8", type = !llvm2.ptr } : () -> () + // CHECK: declare float* @return_pfloat() + // CHECK: declare float* @return_pfloat_round() + "llvm.test_introduce_func"() { name = "return_pfloat", type = !llvm2.ptr } : () -> () + // CHECK: declare i8** @return_ppi8() + // CHECK: declare i8** @return_ppi8_round() + "llvm.test_introduce_func"() { name = "return_ppi8", type = !llvm2.ptr> } : () -> () + // CHECK: declare i8***** @return_pppppi8() + // CHECK: declare i8***** @return_pppppi8_round() + "llvm.test_introduce_func"() { name = "return_pppppi8", type = !llvm2.ptr>>>> } : () -> () + // CHECK: declare i8* @return_pi8_0() + // CHECK: declare i8* @return_pi8_0_round() + "llvm.test_introduce_func"() { name = "return_pi8_0", type = !llvm2.ptr } : () -> () + // CHECK: declare i8 addrspace(1)* @return_pi8_1() + // CHECK: declare i8 addrspace(1)* @return_pi8_1_round() + "llvm.test_introduce_func"() { name = "return_pi8_1", type = !llvm2.ptr } : () -> () + // CHECK: declare i8 addrspace(42)* @return_pi8_42() + // CHECK: declare i8 addrspace(42)* @return_pi8_42_round() + "llvm.test_introduce_func"() { name = "return_pi8_42", type = !llvm2.ptr } : () -> () + // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9() + // CHECK: declare i8 addrspace(42)* addrspace(9)* @return_ppi8_42_9_round() + "llvm.test_introduce_func"() { name = "return_ppi8_42_9", type = !llvm2.ptr, 9> } : () -> () + llvm.return +} + +llvm.func @vectors() { + // CHECK: declare <4 x i32> @return_v4_i32() + // CHECK: declare <4 x i32> @return_v4_i32_round() + "llvm.test_introduce_func"() { name = "return_v4_i32", type = !llvm2.vec<4 x i32> } : () -> () + // CHECK: declare <4 x float> @return_v4_float() + // CHECK: declare <4 x float> @return_v4_float_round() + "llvm.test_introduce_func"() { name = "return_v4_float", type = !llvm2.vec<4 x float> } : () -> () + // CHECK: declare @return_vs_4_i32() + // CHECK: declare @return_vs_4_i32_round() + "llvm.test_introduce_func"() { name = "return_vs_4_i32", type = !llvm2.vec } : () -> () + // CHECK: declare @return_vs_8_half() + // CHECK: declare @return_vs_8_half_round() + "llvm.test_introduce_func"() { name = "return_vs_8_half", type = !llvm2.vec } : () -> () + // CHECK: declare <4 x i8*> @return_v_4_pi8() + // CHECK: declare <4 x i8*> @return_v_4_pi8_round() + "llvm.test_introduce_func"() { name = "return_v_4_pi8", type = !llvm2.vec<4 x ptr> } : () -> () + llvm.return +} + +llvm.func @arrays() { + // CHECK: declare [10 x i32] @return_a10_i32() + // CHECK: declare [10 x i32] @return_a10_i32_round() + "llvm.test_introduce_func"() { name = "return_a10_i32", type = !llvm2.array<10 x i32> } : () -> () + // CHECK: declare [8 x float] @return_a8_float() + // CHECK: declare [8 x float] @return_a8_float_round() + "llvm.test_introduce_func"() { name = "return_a8_float", type = !llvm2.array<8 x float> } : () -> () + // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4() + // CHECK: declare [10 x i32 addrspace(4)*] @return_a10_pi32_4_round() + "llvm.test_introduce_func"() { name = "return_a10_pi32_4", type = !llvm2.array<10 x ptr> } : () -> () + // CHECK: declare [10 x [4 x float]] @return_a10_a4_float() + // CHECK: declare [10 x [4 x float]] @return_a10_a4_float_round() + "llvm.test_introduce_func"() { name = "return_a10_a4_float", type = !llvm2.array<10 x array<4 x float>> } : () -> () + llvm.return +} + +llvm.func @literal_structs() { + // CHECK: declare {} @return_struct_empty() + // CHECK: declare {} @return_struct_empty_round() + "llvm.test_introduce_func"() { name = "return_struct_empty", type = !llvm2.struct<()> } : () -> () + // CHECK: declare { i32 } @return_s_i32() + // CHECK: declare { i32 } @return_s_i32_round() + "llvm.test_introduce_func"() { name = "return_s_i32", type = !llvm2.struct<(i32)> } : () -> () + // CHECK: declare { float, i32 } @return_s_float_i32() + // CHECK: declare { float, i32 } @return_s_float_i32_round() + "llvm.test_introduce_func"() { name = "return_s_float_i32", type = !llvm2.struct<(float, i32)> } : () -> () + // CHECK: declare { { i32 } } @return_s_s_i32() + // CHECK: declare { { i32 } } @return_s_s_i32_round() + "llvm.test_introduce_func"() { name = "return_s_s_i32", type = !llvm2.struct<(struct<(i32)>)> } : () -> () + // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float() + // CHECK: declare { i32, { i32 }, float } @return_s_i32_s_i32_float_round() + "llvm.test_introduce_func"() { name = "return_s_i32_s_i32_float", type = !llvm2.struct<(i32, struct<(i32)>, float)> } : () -> () + + // CHECK: declare <{}> @return_sp_empty() + // CHECK: declare <{}> @return_sp_empty_round() + "llvm.test_introduce_func"() { name = "return_sp_empty", type = !llvm2.struct } : () -> () + // CHECK: declare <{ i32 }> @return_sp_i32() + // CHECK: declare <{ i32 }> @return_sp_i32_round() + "llvm.test_introduce_func"() { name = "return_sp_i32", type = !llvm2.struct } : () -> () + // CHECK: declare <{ float, i32 }> @return_sp_float_i32() + // CHECK: declare <{ float, i32 }> @return_sp_float_i32_round() + "llvm.test_introduce_func"() { name = "return_sp_float_i32", type = !llvm2.struct } : () -> () + // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float() + // CHECK: declare <{ i32, { i32, i1 }, float }> @return_sp_i32_s_i31_1_float_round() + "llvm.test_introduce_func"() { name = "return_sp_i32_s_i31_1_float", type = !llvm2.struct, float)> } : () -> () + + // CHECK: declare { <{ i32 }> } @return_s_sp_i32() + // CHECK: declare { <{ i32 }> } @return_s_sp_i32_round() + "llvm.test_introduce_func"() { name = "return_s_sp_i32", type = !llvm2.struct<(struct)> } : () -> () + // CHECK: declare <{ { i32 } }> @return_sp_s_i32() + // CHECK: declare <{ { i32 } }> @return_sp_s_i32_round() + "llvm.test_introduce_func"() { name = "return_sp_s_i32", type = !llvm2.struct)> } : () -> () + llvm.return +} + +// ----- +// Put structs into a separate split so that we can match their declarations +// locally. + +// CHECK: %empty = type {} +// CHECK: %opaque = type opaque +// CHECK: %long = type { i32, { i32, i1 }, float, void ()* } +// CHECK: %self-recursive = type { %self-recursive* } +// CHECK: %unpacked = type { i32 } +// CHECK: %packed = type <{ i32 }> +// CHECK: %"name with spaces and !^$@$#" = type <{ i32 }> +// CHECK: %mutually-a = type { %mutually-b* } +// CHECK: %mutually-b = type { %mutually-a addrspace(3)* } +// CHECK: %struct-of-arrays = type { [10 x i32] } +// CHECK: %array-of-structs = type { i32 } +// CHECK: %ptr-to-struct = type { i8 } + +llvm.func @identified_structs() { + // CHECK: declare %empty + "llvm.test_introduce_func"() { name = "return_s_empty", type = !llvm2.struct<"empty", ()> } : () -> () + // CHECK: declare %opaque + "llvm.test_introduce_func"() { name = "return_s_opaque", type = !llvm2.struct<"opaque", opaque> } : () -> () + // CHECK: declare %long + "llvm.test_introduce_func"() { name = "return_s_long", type = !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr>)> } : () -> () + // CHECK: declare %self-recursive + "llvm.test_introduce_func"() { name = "return_s_self_recurisve", type = !llvm2.struct<"self-recursive", (ptr>)> } : () -> () + // CHECK: declare %unpacked + "llvm.test_introduce_func"() { name = "return_s_unpacked", type = !llvm2.struct<"unpacked", (i32)> } : () -> () + // CHECK: declare %packed + "llvm.test_introduce_func"() { name = "return_s_packed", type = !llvm2.struct<"packed", packed (i32)> } : () -> () + // CHECK: declare %"name with spaces and !^$@$#" + "llvm.test_introduce_func"() { name = "return_s_symbols", type = !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> } : () -> () + + // CHECK: declare %mutually-a + "llvm.test_introduce_func"() { name = "return_s_mutually_a", type = !llvm2.struct<"mutually-a", (ptr, 3>)>>)> } : () -> () + // CHECK: declare %mutually-b + "llvm.test_introduce_func"() { name = "return_s_mutually_b", type = !llvm2.struct<"mutually-b", (ptr>)>, 3>)> } : () -> () + + // CHECK: declare %struct-of-arrays + "llvm.test_introduce_func"() { name = "return_s_struct_of_arrays", type = !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> } : () -> () + // CHECK: declare [10 x %array-of-structs] + "llvm.test_introduce_func"() { name = "return_s_array_of_structs", type = !llvm2.array<10 x struct<"array-of-structs", (i32)>> } : () -> () + // CHECK: declare %ptr-to-struct* + "llvm.test_introduce_func"() { name = "return_s_ptr_to_struct", type = !llvm2.ptr> } : () -> () + llvm.return +} diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt index 0df357c..ec9e5cd 100644 --- a/mlir/test/lib/CMakeLists.txt +++ b/mlir/test/lib/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(Dialect) add_subdirectory(IR) add_subdirectory(Pass) add_subdirectory(Reducer) +add_subdirectory(Target) add_subdirectory(Transforms) diff --git a/mlir/test/lib/Target/CMakeLists.txt b/mlir/test/lib/Target/CMakeLists.txt new file mode 100644 index 0000000..cb8f206 --- /dev/null +++ b/mlir/test/lib/Target/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(MLIRTestLLVMTypeTranslation + TestLLVMTypeTranslation.cpp + + LINK_COMPONENTS + Core + TransformUtils + + LINK_LIBS PUBLIC + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + MLIRTestIR + MLIRTranslation + ) diff --git a/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp new file mode 100644 index 0000000..b76ac2a --- /dev/null +++ b/mlir/test/lib/Target/TestLLVMTypeTranslation.cpp @@ -0,0 +1,79 @@ +//===- TestLLVMTypeTranslation.cpp - Test MLIR/LLVM IR type translation ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Target/LLVMIR/TypeTranslation.h" +#include "mlir/Translation.h" + +using namespace mlir; + +namespace { +class TestLLVMTypeTranslation : public LLVM::ModuleTranslation { + // Allow access to the constructors under MSVC. + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + /// Simple test facility for translating types from MLIR LLVM dialect to LLVM + /// IR. This converts the "llvm.test_introduce_func" operation into an LLVM IR + /// function with the name extracted from the `name` attribute that returns + /// the type contained in the `type` attribute if it is a non-function type or + /// that has the signature obtained by converting `type` if it is a function + /// type. This is a temporary check before type translation is substituted + /// into the main translation flow and exercised here. + LogicalResult convertOperation(Operation &op, + llvm::IRBuilder<> &builder) override { + if (op.getName().getStringRef() == "llvm.test_introduce_func") { + auto attr = op.getAttrOfType("type"); + assert(attr && "expected 'type' attribute"); + auto type = attr.getValue().cast(); + + auto nameAttr = op.getAttrOfType("name"); + assert(nameAttr && "expected 'name' attributes"); + + llvm::Type *translated = + LLVM::translateTypeToLLVMIR(type, builder.getContext()); + + llvm::Module *module = builder.GetInsertBlock()->getModule(); + if (auto *funcType = dyn_cast(translated)) + module->getOrInsertFunction(nameAttr.getValue(), funcType); + else + module->getOrInsertFunction(nameAttr.getValue(), translated); + + std::string roundtripName = (Twine(nameAttr.getValue()) + "_round").str(); + LLVM::LLVMTypeNew translatedBack = + LLVM::translateTypeFromLLVMIR(translated, *op.getContext()); + llvm::Type *translatedBackAndForth = + LLVM::translateTypeToLLVMIR(translatedBack, builder.getContext()); + if (auto *funcType = dyn_cast(translatedBackAndForth)) + module->getOrInsertFunction(roundtripName, funcType); + else + module->getOrInsertFunction(roundtripName, translatedBackAndForth); + return success(); + } + + return LLVM::ModuleTranslation::convertOperation(op, builder); + } +}; +} // namespace + +namespace mlir { +void registerTestLLVMTypeTranslation() { + TranslateFromMLIRRegistration reg( + "test-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) { + std::unique_ptr llvmModule = + LLVM::ModuleTranslation::translateModule( + module.getOperation()); + llvmModule->print(output, nullptr); + return success(); + }); +} +} // namespace mlir diff --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt index 897e7ad..1e6cdfe 100644 --- a/mlir/tools/mlir-translate/CMakeLists.txt +++ b/mlir/tools/mlir-translate/CMakeLists.txt @@ -13,7 +13,11 @@ target_link_libraries(mlir-translate PRIVATE ${dialect_libs} ${translation_libs} + ${test_libs} MLIRIR + # TODO: remove after LLVM dialect transition is complete; translation uses a + # registration function defined in this library unconditionally. + MLIRLLVMTypeTestDialect MLIRParser MLIRPass MLIRSPIRV diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp index 914bd34..70bf285 100644 --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -49,17 +49,21 @@ static llvm::cl::opt verifyDiagnostics( namespace mlir { // Defined in the test directory, no public header. +void registerLLVMTypeTestDialect(); +void registerTestLLVMTypeTranslation(); void registerTestRoundtripSPIRV(); void registerTestRoundtripDebugSPIRV(); } // namespace mlir static void registerTestTranslations() { + registerTestLLVMTypeTranslation(); registerTestRoundtripSPIRV(); registerTestRoundtripDebugSPIRV(); } int main(int argc, char **argv) { registerAllDialects(); + registerLLVMTypeTestDialect(); registerAllTranslations(); registerTestTranslations(); llvm::InitLLVM y(argc, argv); -- 2.7.4