From 85285be9c37ad0b6e3dabe82248d8917a6ebd5ec Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 2 May 2022 13:59:37 -0700 Subject: [PATCH] [DirectX backend] Add pass to lower llvm intrinsic into dxil op function. A new pass DXILOpLowering was added. It scans all llvm intrinsics and creates a dxil op function for each intrinsic that can be mapped to one. Then it translates call instructions on the intrinsic into calls on the dxil op function. A dxil op function takes an extra i32 argument for the dxil opcode at the beginning of its argument list, so setCalledFunction cannot be used to update the call instruction on the intrinsic. This commit only supports sin to start the work. Reviewed By: kuhar, beanz Differential Revision: https://reviews.llvm.org/D124805 --- llvm/lib/Target/DirectX/CMakeLists.txt | 1 + llvm/lib/Target/DirectX/DXILConstants.h | 29 +++ llvm/lib/Target/DirectX/DXILOpLowering.cpp | 279 +++++++++++++++++++++++ llvm/lib/Target/DirectX/DirectX.h | 7 + llvm/lib/Target/DirectX/DirectXTargetMachine.cpp | 2 + llvm/test/CodeGen/DirectX/sin.ll | 43 ++++ llvm/tools/opt/opt.cpp | 2 +- 7 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 llvm/lib/Target/DirectX/DXILConstants.h create mode 100644 llvm/lib/Target/DirectX/DXILOpLowering.cpp create mode 100644 llvm/test/CodeGen/DirectX/sin.ll diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index d76eb97..f2bcbf4 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -9,6 +9,7 @@ add_public_tablegen_target(DirectXCommonTableGen) add_llvm_target(DirectXCodeGen DirectXSubtarget.cpp DirectXTargetMachine.cpp + DXILOpLowering.cpp DXILPointerType.cpp DXILPrepare.cpp PointerTypeAnalysis.cpp diff --git a/llvm/lib/Target/DirectX/DXILConstants.h b/llvm/lib/Target/DirectX/DXILConstants.h new file mode 100644 index 0000000..c7b2be6 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILConstants.h @@ -0,0 +1,29 @@ +//===- DXILConstants.h - Essential DXIL constants 
-------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains essential DXIL constants. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H +#define LLVM_LIB_TARGET_DIRECTX_DXILCONSTANTS_H + +namespace llvm { +namespace DXIL { +// Enumeration for operations specified by DXIL +enum class OpCode : unsigned { + Sin = 13, // returns sine(theta) for theta in radians. +}; +// Groups for DXIL operations with equivalent function templates +enum class OpCodeClass : unsigned { + Unary, +}; + +} // namespace DXIL +} // namespace llvm + +#endif diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp new file mode 100644 index 0000000..f7925b5 --- /dev/null +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -0,0 +1,279 @@ +//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains passes and utilities to lower llvm intrinsic call +/// to DXILOp function call. 
+//===----------------------------------------------------------------------===// + +#include "DXILConstants.h" +#include "DirectX.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "dxil-op-lower" + +using namespace llvm; +using namespace llvm::DXIL; + +constexpr StringLiteral DXILOpNamePrefix = "dx.op."; + +enum OverloadKind : uint16_t { + VOID = 1, + HALF = 1 << 1, + FLOAT = 1 << 2, + DOUBLE = 1 << 3, + I1 = 1 << 4, + I8 = 1 << 5, + I16 = 1 << 6, + I32 = 1 << 7, + I64 = 1 << 8, + UserDefineType = 1 << 9, + ObjectType = 1 << 10, +}; + +static const char *getOverloadTypeName(OverloadKind Kind) { + switch (Kind) { + case OverloadKind::HALF: + return "f16"; + case OverloadKind::FLOAT: + return "f32"; + case OverloadKind::DOUBLE: + return "f64"; + case OverloadKind::I1: + return "i1"; + case OverloadKind::I8: + return "i8"; + case OverloadKind::I16: + return "i16"; + case OverloadKind::I32: + return "i32"; + case OverloadKind::I64: + return "i64"; + case OverloadKind::VOID: + case OverloadKind::ObjectType: + case OverloadKind::UserDefineType: + llvm_unreachable("invalid overload type for name"); + break; + } +} + +static OverloadKind getOverloadKind(Type *Ty) { + Type::TypeID T = Ty->getTypeID(); + switch (T) { + case Type::VoidTyID: + return OverloadKind::VOID; + case Type::HalfTyID: + return OverloadKind::HALF; + case Type::FloatTyID: + return OverloadKind::FLOAT; + case Type::DoubleTyID: + return OverloadKind::DOUBLE; + case Type::IntegerTyID: { + IntegerType *ITy = cast<IntegerType>(Ty); + unsigned Bits = ITy->getBitWidth(); + switch (Bits) { + case 1: + return OverloadKind::I1; + case 8: + return OverloadKind::I8; + case 16: + return OverloadKind::I16; + case 32: + return OverloadKind::I32; + case 64: + return 
OverloadKind::I64; + default: + llvm_unreachable("invalid overload type"); + return OverloadKind::VOID; + } + } + case Type::PointerTyID: + return OverloadKind::UserDefineType; + case Type::StructTyID: + return OverloadKind::ObjectType; + default: + llvm_unreachable("invalid overload type"); + return OverloadKind::VOID; + } +} + +static std::string getTypeName(OverloadKind Kind, Type *Ty) { + if (Kind < OverloadKind::UserDefineType) { + return getOverloadTypeName(Kind); + } else if (Kind == OverloadKind::UserDefineType) { + StructType *ST = cast<StructType>(Ty); + return ST->getStructName().str(); + } else if (Kind == OverloadKind::ObjectType) { + StructType *ST = cast<StructType>(Ty); + return ST->getStructName().str(); + } else { + std::string Str; + raw_string_ostream OS(Str); + Ty->print(OS); + return OS.str(); + } +} + +// Static properties. +struct OpCodeProperty { + DXIL::OpCode OpCode; + // FIXME: change OpCodeName into index to a large string constant when move to + // tableGen. + const char *OpCodeName; + DXIL::OpCodeClass OpCodeClass; + uint16_t OverloadTys; + llvm::Attribute::AttrKind FuncAttr; +}; + +static const char *getOpCodeClassName(const OpCodeProperty &Prop) { + // FIXME: generate this table with tableGen. + static const char *OpCodeClassNames[] = { + "unary", + }; + unsigned Index = static_cast<unsigned>(Prop.OpCodeClass); + assert(Index < (sizeof(OpCodeClassNames) / sizeof(OpCodeClassNames[0])) && + "Out of bound OpCodeClass"); + return OpCodeClassNames[Index]; +} + +static std::string constructOverloadName(OverloadKind Kind, Type *Ty, + const OpCodeProperty &Prop) { + if (Kind == OverloadKind::VOID) { + return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); + } + return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + + getTypeName(Kind, Ty)) + .str(); +} + +static const OpCodeProperty *getOpCodeProperty(DXIL::OpCode DXILOp) { + // FIXME: generate this table with tableGen. 
+ static const OpCodeProperty OpCodeProps[] = { + {DXIL::OpCode::Sin, "Sin", OpCodeClass::Unary, + OverloadKind::FLOAT | OverloadKind::HALF, Attribute::AttrKind::ReadNone}, + }; + // FIXME: change search to indexing with + // DXILOp once all DXIL op is added. + OpCodeProperty TmpProp; + TmpProp.OpCode = DXILOp; + const OpCodeProperty *Prop = + llvm::lower_bound(OpCodeProps, TmpProp, + [](const OpCodeProperty &A, const OpCodeProperty &B) { + return A.OpCode < B.OpCode; + }); + return Prop; +} + +static FunctionCallee createDXILOpFunction(DXIL::OpCode DXILOp, Function &F, + Module &M) { + const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); + + // Get return type as overload type for DXILOp. + // Only simple mapping case here, so return type is good enough. + Type *OverloadTy = F.getReturnType(); + + OverloadKind Kind = getOverloadKind(OverloadTy); + // FIXME: find the issue and report error in clang instead of check it in + // backend. + if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { + llvm_unreachable("invalid overload"); + } + + std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); + assert(!M.getFunction(FnName) && "Function already exists"); + + auto &Ctx = M.getContext(); + Type *OpCodeTy = Type::getInt32Ty(Ctx); + + SmallVector<Type *> ArgTypes; + // DXIL has i32 opcode as first arg. 
+ ArgTypes.emplace_back(OpCodeTy); + FunctionType *FT = F.getFunctionType(); + ArgTypes.append(FT->param_begin(), FT->param_end()); + FunctionType *DXILOpFT = FunctionType::get(OverloadTy, ArgTypes, false); + return M.getOrInsertFunction(FnName, DXILOpFT); +} + +static void lowerIntrinsic(DXIL::OpCode DXILOp, Function &F, Module &M) { + auto DXILOpFn = createDXILOpFunction(DXILOp, F, M); + IRBuilder<> B(M.getContext()); + Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); + for (User *U : make_early_inc_range(F.users())) { + CallInst *CI = dyn_cast<CallInst>(U); + if (!CI) + continue; + + SmallVector<Value *> Args; + Args.emplace_back(DXILOpArg); + Args.append(CI->arg_begin(), CI->arg_end()); + B.SetInsertPoint(CI); + CallInst *DXILCI = B.CreateCall(DXILOpFn, Args); + CI->replaceAllUsesWith(DXILCI); + CI->eraseFromParent(); + } + if (F.user_empty()) + F.eraseFromParent(); +} + +static bool lowerIntrinsics(Module &M) { + bool Updated = false; + static SmallDenseMap<Intrinsic::ID, DXIL::OpCode> LowerMap = { + {Intrinsic::sin, DXIL::OpCode::Sin}}; + for (Function &F : make_early_inc_range(M.functions())) { + if (!F.isDeclaration()) + continue; + Intrinsic::ID ID = F.getIntrinsicID(); + auto LowerIt = LowerMap.find(ID); + if (LowerIt == LowerMap.end()) + continue; + lowerIntrinsic(LowerIt->second, F, M); + Updated = true; + } + return Updated; +} + +namespace { +/// A pass that lowers LLVM intrinsic calls to DXIL op function calls. +class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { +public: + PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { + if (lowerIntrinsics(M)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); + } +}; +} // namespace + +namespace { +class DXILOpLoweringLegacy : public ModulePass { +public: + bool runOnModule(Module &M) override { return lowerIntrinsics(M); } + StringRef getPassName() const override { return "DXIL Op Lowering"; } + DXILOpLoweringLegacy() : ModulePass(ID) {} + + static char ID; // Pass identification. 
+}; +char DXILOpLoweringLegacy::ID = 0; + +} // end anonymous namespace + +INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", + false, false) +INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, + false) + +ModulePass *llvm::createDXILOpLoweringLegacyPass() { + return new DXILOpLoweringLegacy(); +} diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h index 73932ae..73cd553 100644 --- a/llvm/lib/Target/DirectX/DirectX.h +++ b/llvm/lib/Target/DirectX/DirectX.h @@ -23,6 +23,13 @@ void initializeDXILPrepareModulePass(PassRegistry &); /// Pass to convert modules into DXIL-compatable modules ModulePass *createDXILPrepareModulePass(); + +/// Initializer for DXILOpLowering +void initializeDXILOpLoweringLegacyPass(PassRegistry &); + +/// Pass to lower LLVM intrinsic calls to DXIL op function calls. +ModulePass *createDXILOpLoweringLegacyPass(); + } // namespace llvm #endif // LLVM_LIB_TARGET_DIRECTX_DIRECTX_H diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp index 98adfbf..a12d87f 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp @@ -34,6 +34,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() { RegisterTargetMachine<DirectXTargetMachine> X(getTheDirectXTarget()); auto *PR = PassRegistry::getPassRegistry(); initializeDXILPrepareModulePass(*PR); + initializeDXILOpLoweringLegacyPass(*PR); } class DXILTargetObjectFile : public TargetLoweringObjectFile { @@ -84,6 +85,7 @@ bool DirectXTargetMachine::addPassesToEmitFile( PassManagerBase &PM, raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut, CodeGenFileType FileType, bool DisableVerify, MachineModuleInfoWrapperPass *MMIWP) { + PM.add(createDXILOpLoweringLegacyPass()); PM.add(createDXILPrepareModulePass()); switch (FileType) { case CGFT_AssemblyFile: diff --git a/llvm/test/CodeGen/DirectX/sin.ll 
b/llvm/test/CodeGen/DirectX/sin.ll new file mode 100644 index 0000000..bb31d28 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/sin.ll @@ -0,0 +1,43 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Make sure dxil operation function calls for sin are generated for float and half. +; CHECK:call float @dx.op.unary.f32(i32 13, float %{{.*}}) +; CHECK:call half @dx.op.unary.f16(i32 13, half %{{.*}}) + +target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" +target triple = "dxil-pc-shadermodel6.7-library" + +; Function Attrs: noinline nounwind optnone +define noundef float @_Z3foof(float noundef %a) #0 { +entry: + %a.addr = alloca float, align 4 + store float %a, ptr %a.addr, align 4 + %0 = load float, ptr %a.addr, align 4 + %1 = call float @llvm.sin.f32(float %0) + ret float %1 +} + +; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn +declare float @llvm.sin.f32(float) #1 + +; Function Attrs: noinline nounwind optnone +define noundef half @_Z3barDh(half noundef %a) #0 { +entry: + %a.addr = alloca half, align 2 + store half %a, ptr %a.addr, align 2 + %0 = load half, ptr %a.addr, align 2 + %1 = call half @llvm.sin.f16(half %0) + ret half %1 +} + +; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn +declare half @llvm.sin.f16(half) #1 + +attributes #0 = { noinline nounwind optnone "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 15.0.0 (https://github.com/llvm/llvm-project.git 73417c517644db5c419c85c0b3cb6750172fcab5)"} diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp index bd4738d..e43590d 100644 --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -476,7 +476,7 @@ static 
bool shouldPinPassToLegacyPM(StringRef Pass) { "x86-", "xcore-", "wasm-", "systemz-", "ppc-", "nvvm-", "nvptx-", "mips-", "lanai-", "hexagon-", "bpf-", "avr-", "thumb2-", "arm-", "si-", "gcn-", "amdgpu-", "aarch64-", - "amdgcn-", "polly-", "riscv-"}; + "amdgcn-", "polly-", "riscv-", "dxil-"}; std::vector<StringRef> PassNameContain = {"ehprepare"}; std::vector<StringRef> PassNameExact = { "safe-stack", "cost-model", -- 2.7.4